File size: 4,114 Bytes
4ec75cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# tests/test_endpoints.py
# Basic endpoint tests for the environment.
# Run: python -m pytest tests/ -v

import requests
import pytest

BASE_URL = 'http://localhost:7860'


def test_health_check():
    """GET / should return 200 with status 'ok' and a task count of 9."""
    # requests has no default timeout; bound the wait so an unresponsive
    # server fails this test instead of hanging the whole run.
    r = requests.get(f'{BASE_URL}/', timeout=10)
    assert r.status_code == 200
    data = r.json()
    assert data['status'] == 'ok'
    assert data['tasks'] == 9


def test_reset_valid_task():
    """POST /reset with a valid task_id should return episode_id and observation.

    Also checks that the observation for 'sec_easy' reports task_type
    'security'.
    """
    # requests has no default timeout; bound the wait so an unresponsive
    # server fails this test instead of hanging the whole run.
    r = requests.post(
        f'{BASE_URL}/reset',
        json={'task_id': 'sec_easy'},
        timeout=10,
    )
    assert r.status_code == 200
    data = r.json()
    assert 'episode_id' in data
    assert 'observation' in data
    assert data['observation']['task_type'] == 'security'


def test_reset_all_tasks():
    """POST /reset should work for all 9 task IDs.

    Each task ID must yield HTTP 200 and a JSON body containing both
    'episode_id' and 'observation'.
    """
    tasks = [
        'sec_easy', 'sec_medium', 'sec_hard',
        'dep_easy', 'dep_medium', 'dep_hard',
        'cli_easy', 'cli_medium', 'cli_hard',
    ]
    for task_id in tasks:
        # requests has no default timeout; bound the wait so an unresponsive
        # server fails this test instead of hanging the whole run.
        r = requests.post(
            f'{BASE_URL}/reset',
            json={'task_id': task_id},
            timeout=10,
        )
        # Name the task in every failure message so a broken task is
        # identifiable without re-running the loop by hand.
        assert r.status_code == 200, (
            f'Unexpected status {r.status_code} for {task_id}'
        )
        data = r.json()
        assert 'episode_id' in data, f'No episode_id for {task_id}'
        assert 'observation' in data, f'No observation for {task_id}'


def test_reset_invalid_task():
    """POST /reset with an invalid task_id should still return 200."""
    # requests has no default timeout; bound the wait so an unresponsive
    # server fails this test instead of hanging the whole run.
    r = requests.post(
        f'{BASE_URL}/reset',
        json={'task_id': 'nonexistent'},
        timeout=10,
    )
    assert r.status_code == 200


def test_step_valid_action():
    """POST /step with a valid action should return reward, done and observation.

    The episode is created via /reset first; the returned reward must lie
    in [0.0, 1.0].
    """
    # Reset first. requests has no default timeout, so every call bounds its
    # wait to avoid hanging the run on an unresponsive server.
    r = requests.post(
        f'{BASE_URL}/reset',
        json={'task_id': 'sec_easy'},
        timeout=10,
    )
    # Fail with a clear message here rather than with a KeyError below.
    assert r.status_code == 200, f'reset failed with status {r.status_code}'
    ep_id = r.json()['episode_id']

    # Step
    action = {
        'episode_id': ep_id,
        'action_type': 'identify_vulnerability',
        'vuln_type': 'sql_injection',
        'cvss_score': 9.1,
        'severity': 'critical',
        'affected_line': 1,
    }
    r = requests.post(f'{BASE_URL}/step', json=action, timeout=10)
    assert r.status_code == 200
    data = r.json()
    assert 'reward' in data
    assert 'done' in data
    assert 'observation' in data
    assert 0.0 <= data['reward'] <= 1.0


def test_step_invalid_episode():
    """POST /step with an invalid episode_id should return 200 with done=True."""
    payload = {
        'episode_id': 'nonexistent',
        'action_type': 'identify_vulnerability',
    }
    # requests has no default timeout; bound the wait so an unresponsive
    # server fails this test instead of hanging the whole run.
    r = requests.post(f'{BASE_URL}/step', json=payload, timeout=10)
    assert r.status_code == 200
    data = r.json()
    assert data['done'] is True


def test_state_endpoint():
    """GET /state should return episode info for a freshly reset episode.

    The response must echo the episode_id and report done as False.
    """
    # requests has no default timeout, so every call bounds its wait to
    # avoid hanging the run on an unresponsive server.
    r = requests.post(
        f'{BASE_URL}/reset',
        json={'task_id': 'sec_easy'},
        timeout=10,
    )
    # Fail with a clear message here rather than with a KeyError below.
    assert r.status_code == 200, f'reset failed with status {r.status_code}'
    ep_id = r.json()['episode_id']

    r = requests.get(
        f'{BASE_URL}/state',
        params={'episode_id': ep_id},
        timeout=10,
    )
    assert r.status_code == 200
    data = r.json()
    assert data['episode_id'] == ep_id
    assert data['done'] is False


def test_reward_range():
    """Rewards should always be in [0.0, 1.0].

    For each listed task, resets an episode and submits a step with an
    unrecognised action_type, then checks the reported reward is in range.
    """
    tasks = ['sec_easy', 'dep_easy', 'cli_easy']
    for task_id in tasks:
        # requests has no default timeout, so every call bounds its wait to
        # avoid hanging the run on an unresponsive server.
        r = requests.post(
            f'{BASE_URL}/reset',
            json={'task_id': task_id},
            timeout=10,
        )
        # Fail with a clear message here rather than with a KeyError below.
        assert r.status_code == 200, (
            f'reset failed with status {r.status_code} for {task_id}'
        )
        ep_id = r.json()['episode_id']

        # Send an invalid action
        r = requests.post(
            f'{BASE_URL}/step',
            json={
                'episode_id': ep_id,
                'action_type': 'invalid_action_type',
            },
            timeout=10,
        )
        data = r.json()
        assert 0.0 <= data['reward'] <= 1.0, f'Reward out of range for {task_id}'


def test_step_enriched_observation():
    """Step observations should include task context fields.

    Checks that the observation returned by /step contains 'task_type',
    'max_steps' and 'steps_remaining'.
    """
    # requests has no default timeout, so every call bounds its wait to
    # avoid hanging the run on an unresponsive server.
    r = requests.post(
        f'{BASE_URL}/reset',
        json={'task_id': 'sec_easy'},
        timeout=10,
    )
    # Fail with a clear message here rather than with a KeyError below.
    assert r.status_code == 200, f'reset failed with status {r.status_code}'
    ep_id = r.json()['episode_id']

    action = {
        'episode_id': ep_id,
        'action_type': 'identify_vulnerability',
        'vuln_type': 'sql_injection',
        'cvss_score': 9.1,
        'severity': 'critical',
        'affected_line': 1,
    }
    r = requests.post(f'{BASE_URL}/step', json=action, timeout=10)
    # Same payload as the valid-action test, so a 200 is expected; check it
    # before indexing into the body so failures are self-explanatory.
    assert r.status_code == 200, f'step failed with status {r.status_code}'
    obs = r.json()['observation']
    assert 'task_type' in obs
    assert 'max_steps' in obs
    assert 'steps_remaining' in obs