Spaces:
Running
Running
File size: 4,114 Bytes
# tests/test_endpoints.py
# Basic endpoint tests for the environment.
# Run: python -m pytest tests/ -v
import requests
import pytest
# Root URL of the server under test; every test in this module builds its
# endpoint URLs from it, so the server must already be listening here.
BASE_URL = 'http://localhost:7860'
def test_health_check():
    """GET / should return 200 with status 'ok' and a task count of 9."""
    # Bounded timeout: an unreachable or hung server should fail the test
    # quickly instead of blocking the whole run indefinitely.
    r = requests.get(f'{BASE_URL}/', timeout=10)
    assert r.status_code == 200
    data = r.json()
    assert data['status'] == 'ok'
    assert data['tasks'] == 9
def test_reset_valid_task():
    """POST /reset with a valid task_id should return episode_id and observation.

    Also checks that the observation for 'sec_easy' reports the
    'security' task type.
    """
    # Bounded timeout so a hung server fails the test instead of stalling it.
    r = requests.post(
        f'{BASE_URL}/reset',
        json={'task_id': 'sec_easy'},
        timeout=10,
    )
    assert r.status_code == 200
    data = r.json()
    assert 'episode_id' in data
    assert 'observation' in data
    assert data['observation']['task_type'] == 'security'
def test_reset_all_tasks():
    """POST /reset should work for all 9 task IDs.

    Each response must be a 200 whose JSON body carries both an
    'episode_id' and an 'observation'.
    """
    tasks = [
        'sec_easy', 'sec_medium', 'sec_hard',
        'dep_easy', 'dep_medium', 'dep_hard',
        'cli_easy', 'cli_medium', 'cli_hard',
    ]
    for task_id in tasks:
        # Bounded timeout so one hung request cannot stall the whole loop.
        r = requests.post(
            f'{BASE_URL}/reset',
            json={'task_id': task_id},
            timeout=10,
        )
        # Name the task in every failure message; otherwise a failure
        # inside the loop does not say which task_id triggered it.
        assert r.status_code == 200, (
            f'Unexpected status {r.status_code} for {task_id}'
        )
        data = r.json()
        assert 'episode_id' in data, f'No episode_id for {task_id}'
        assert 'observation' in data, f'No observation for {task_id}'
def test_reset_invalid_task():
    """POST /reset with an invalid task_id should still return 200."""
    # Bounded timeout so a hung server fails the test instead of stalling it.
    r = requests.post(
        f'{BASE_URL}/reset',
        json={'task_id': 'nonexistent'},
        timeout=10,
    )
    assert r.status_code == 200
def test_step_valid_action():
    """POST /step with a valid action should return reward, done and observation.

    The reward in the response must lie in [0.0, 1.0].
    """
    # Reset first to obtain an episode to act on.
    r = requests.post(
        f'{BASE_URL}/reset',
        json={'task_id': 'sec_easy'},
        timeout=10,
    )
    # Fail here with a clear status message rather than with a KeyError
    # on 'episode_id' below when the reset itself did not succeed.
    assert r.status_code == 200, f'Reset failed with status {r.status_code}'
    ep_id = r.json()['episode_id']

    # Step
    action = {
        'episode_id': ep_id,
        'action_type': 'identify_vulnerability',
        'vuln_type': 'sql_injection',
        'cvss_score': 9.1,
        'severity': 'critical',
        'affected_line': 1,
    }
    r = requests.post(f'{BASE_URL}/step', json=action, timeout=10)
    assert r.status_code == 200
    data = r.json()
    assert 'reward' in data
    assert 'done' in data
    assert 'observation' in data
    assert 0.0 <= data['reward'] <= 1.0
def test_step_invalid_episode():
    """POST /step with an invalid episode_id should return 200 with done=True."""
    payload = {
        'episode_id': 'nonexistent',
        'action_type': 'identify_vulnerability',
    }
    # Bounded timeout so a hung server fails the test instead of stalling it.
    r = requests.post(f'{BASE_URL}/step', json=payload, timeout=10)
    assert r.status_code == 200
    data = r.json()
    assert data['done'] is True
def test_state_endpoint():
    """GET /state should return episode info.

    For a freshly reset episode, the state must echo the episode_id
    and report done as False.
    """
    r = requests.post(
        f'{BASE_URL}/reset',
        json={'task_id': 'sec_easy'},
        timeout=10,
    )
    # Surface a failed reset directly instead of as a KeyError below.
    assert r.status_code == 200, f'Reset failed with status {r.status_code}'
    ep_id = r.json()['episode_id']

    r = requests.get(
        f'{BASE_URL}/state',
        params={'episode_id': ep_id},
        timeout=10,
    )
    assert r.status_code == 200
    data = r.json()
    assert data['episode_id'] == ep_id
    assert data['done'] is False
def test_reward_range():
    """Rewards should always be in [0.0, 1.0].

    Sends an unrecognised action type to one task of each family and
    checks that the returned reward stays within bounds.
    """
    tasks = ['sec_easy', 'dep_easy', 'cli_easy']
    for task_id in tasks:
        r = requests.post(
            f'{BASE_URL}/reset',
            json={'task_id': task_id},
            timeout=10,
        )
        # Report a failed reset with the offending task instead of an
        # anonymous KeyError on 'episode_id'.
        assert r.status_code == 200, (
            f'Reset failed with status {r.status_code} for {task_id}'
        )
        ep_id = r.json()['episode_id']

        # Send an invalid action
        r = requests.post(
            f'{BASE_URL}/step',
            json={
                'episode_id': ep_id,
                'action_type': 'invalid_action_type',
            },
            timeout=10,
        )
        data = r.json()
        # A missing key would fail anyway; this names the task that caused it.
        assert 'reward' in data, f'No reward for {task_id}'
        assert 0.0 <= data['reward'] <= 1.0, f'Reward out of range for {task_id}'
def test_step_enriched_observation():
    """Step observations should include task context fields.

    The observation returned by /step must carry 'task_type',
    'max_steps' and 'steps_remaining'.
    """
    r = requests.post(
        f'{BASE_URL}/reset',
        json={'task_id': 'sec_easy'},
        timeout=10,
    )
    # Surface a failed reset directly instead of as a KeyError below.
    assert r.status_code == 200, f'Reset failed with status {r.status_code}'
    ep_id = r.json()['episode_id']

    action = {
        'episode_id': ep_id,
        'action_type': 'identify_vulnerability',
        'vuln_type': 'sql_injection',
        'cvss_score': 9.1,
        'severity': 'critical',
        'affected_line': 1,
    }
    r = requests.post(f'{BASE_URL}/step', json=action, timeout=10)
    # Check the step succeeded before indexing into its body, so a server
    # error is reported as such rather than as a missing key.
    assert r.status_code == 200, f'Step failed with status {r.status_code}'
    data = r.json()
    assert 'observation' in data
    obs = data['observation']
    assert 'task_type' in obs
    assert 'max_steps' in obs
    assert 'steps_remaining' in obs
|