ykjung Claude Sonnet 4.6 commited on
Commit ·
8acad99
1
Parent(s): d14a456
fix(rl): stage 저장 누락·WebSocket 진행률 미전송 수정
Browse files- rl_service: train_rl에 stage 파라미터 추가, model_json에 stage 저장
- rl_service: predict_rl에서 model_json.stage 자동 로드 (차원 불일치 방지)
- rl_service: SB3 ProgressCallback 추가 — 10,000 스텝마다 진행률 콜백
- rl_service: 스레드 콜백 → asyncio 이벤트루프 브리지 (run_coroutine_threadsafe)
- routers/rl: train_progress_callback 연결, 10k 스텝마다 WS 진행률 전송
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- routers/rl.py +11 -1
- services/rl_service.py +42 -5
routers/rl.py
CHANGED
|
@@ -160,7 +160,17 @@ async def websocket_rl_train(websocket: WebSocket):
|
|
| 160 |
|
| 161 |
# ── 2단계: PPO 학습 (blocking → executor) ───────────
|
| 162 |
# 브라우저가 닫혀도 서버에서 계속 실행됩니다.
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
_set_job(status="complete", train_progress=100, result=result)
|
| 166 |
await _send({"type": "training", "progress": 100})
|
|
|
|
| 160 |
|
| 161 |
# ── 2단계: PPO 학습 (blocking → executor) ───────────
|
| 162 |
# 브라우저가 닫혀도 서버에서 계속 실행됩니다.
|
| 163 |
+
# 10,000 스텝마다 진행률을 클라이언트에 전송합니다.
|
| 164 |
+
async def on_train_progress(pct: int):
|
| 165 |
+
_set_job(train_progress=pct)
|
| 166 |
+
await _send({"type": "training", "progress": pct,
|
| 167 |
+
"message": f"PPO 학습 중... {pct}%"})
|
| 168 |
+
|
| 169 |
+
result = await rl_service.train_rl(
|
| 170 |
+
episodes, model_name, total_timesteps,
|
| 171 |
+
stage=stage,
|
| 172 |
+
train_progress_callback=on_train_progress,
|
| 173 |
+
)
|
| 174 |
|
| 175 |
_set_job(status="complete", train_progress=100, result=result)
|
| 176 |
await _send({"type": "training", "progress": 100})
|
services/rl_service.py
CHANGED
|
@@ -113,13 +113,37 @@ async def collect_rl_episodes(
|
|
| 113 |
# PPO 학습 (동기 → executor에서 호출)
|
| 114 |
# ─────────────────────────────────────────────────────────
|
| 115 |
|
| 116 |
-
def _train_ppo_sync(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
"""stable-baselines3 PPO 학습 (blocking). asyncio executor에서 실행됩니다."""
|
|
|
|
| 118 |
from services.rl_environment import StockTradingEnv
|
| 119 |
|
| 120 |
PPO = _get_ppo()
|
| 121 |
env = StockTradingEnv(episodes)
|
| 122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
model = PPO(
|
| 124 |
"MlpPolicy",
|
| 125 |
env,
|
|
@@ -132,8 +156,9 @@ def _train_ppo_sync(episodes: list[dict], total_timesteps: int) -> object:
|
|
| 132 |
verbose=0,
|
| 133 |
)
|
| 134 |
|
|
|
|
| 135 |
logger.info(f"[RL:Train] PPO 학습 시작: timesteps={total_timesteps}, episodes={len(episodes)}")
|
| 136 |
-
model.learn(total_timesteps=total_timesteps)
|
| 137 |
logger.info("[RL:Train] PPO 학습 완료")
|
| 138 |
return model
|
| 139 |
|
|
@@ -184,13 +209,21 @@ async def train_rl(
|
|
| 184 |
episodes: list[dict],
|
| 185 |
model_name: str,
|
| 186 |
total_timesteps: int = 300_000,
|
|
|
|
|
|
|
| 187 |
) -> dict:
|
| 188 |
"""PPO 학습 후 Supabase ml_models 테이블에 저장합니다."""
|
| 189 |
from services import supabase_service
|
| 190 |
|
| 191 |
loop = asyncio.get_event_loop()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
model = await loop.run_in_executor(
|
| 193 |
-
None, _train_ppo_sync, episodes, total_timesteps
|
| 194 |
)
|
| 195 |
|
| 196 |
model_b64 = await loop.run_in_executor(None, _serialize_model, model)
|
|
@@ -207,10 +240,11 @@ async def train_rl(
|
|
| 207 |
"auc": 0.0,
|
| 208 |
"feature_count": n_features,
|
| 209 |
"sample_count": total_steps,
|
| 210 |
-
"stage":
|
| 211 |
"model_json": {
|
| 212 |
"type": "rl",
|
| 213 |
"algorithm": "PPO",
|
|
|
|
| 214 |
"n_episodes": len(episodes),
|
| 215 |
"total_timesteps": total_timesteps,
|
| 216 |
"model_b64": model_b64,
|
|
@@ -234,7 +268,7 @@ async def predict_rl(
|
|
| 234 |
model_id: str,
|
| 235 |
ticker: str,
|
| 236 |
days: int = 500,
|
| 237 |
-
stage: int =
|
| 238 |
) -> dict:
|
| 239 |
"""RL 모델로 종목의 날짜별 BUY/HOLD/SELL 시퀀스를 반환합니다."""
|
| 240 |
from services import supabase_service, data_collector
|
|
@@ -246,6 +280,9 @@ async def predict_rl(
|
|
| 246 |
if not isinstance(model_json, dict) or model_json.get("type") != "rl":
|
| 247 |
raise ValueError(f"모델 {model_id}는 RL 모델이 아닙니다 (type={model_json.get('type')})")
|
| 248 |
|
|
|
|
|
|
|
|
|
|
| 249 |
loop = asyncio.get_event_loop()
|
| 250 |
model = await loop.run_in_executor(
|
| 251 |
None, _deserialize_model, model_json["model_b64"]
|
|
|
|
| 113 |
# PPO 학습 (동기 → executor에서 호출)
|
| 114 |
# ─────────────────────────────────────────────────────────
|
| 115 |
|
| 116 |
+
def _train_ppo_sync(
|
| 117 |
+
episodes: list[dict],
|
| 118 |
+
total_timesteps: int,
|
| 119 |
+
on_progress=None, # callable(pct: int) — 스레드에서 호출 (동기)
|
| 120 |
+
) -> object:
|
| 121 |
"""stable-baselines3 PPO 학습 (blocking). asyncio executor에서 실행됩니다."""
|
| 122 |
+
from stable_baselines3.common.callbacks import BaseCallback
|
| 123 |
from services.rl_environment import StockTradingEnv
|
| 124 |
|
| 125 |
PPO = _get_ppo()
|
| 126 |
env = StockTradingEnv(episodes)
|
| 127 |
|
| 128 |
+
class _ProgressCb(BaseCallback):
|
| 129 |
+
"""10,000 스텝마다 진행률을 on_progress 콜백으로 전달합니다."""
|
| 130 |
+
def __init__(self, total, cb):
|
| 131 |
+
super().__init__()
|
| 132 |
+
self._total = total
|
| 133 |
+
self._cb = cb
|
| 134 |
+
self._last_pct = -1
|
| 135 |
+
|
| 136 |
+
def _on_step(self) -> bool:
|
| 137 |
+
if self._cb and self.num_timesteps % 10_000 == 0:
|
| 138 |
+
pct = min(99, int(self.num_timesteps / self._total * 100))
|
| 139 |
+
if pct != self._last_pct:
|
| 140 |
+
self._last_pct = pct
|
| 141 |
+
try:
|
| 142 |
+
self._cb(pct)
|
| 143 |
+
except Exception:
|
| 144 |
+
pass
|
| 145 |
+
return True
|
| 146 |
+
|
| 147 |
model = PPO(
|
| 148 |
"MlpPolicy",
|
| 149 |
env,
|
|
|
|
| 156 |
verbose=0,
|
| 157 |
)
|
| 158 |
|
| 159 |
+
cb = _ProgressCb(total_timesteps, on_progress) if on_progress else None
|
| 160 |
logger.info(f"[RL:Train] PPO 학습 시작: timesteps={total_timesteps}, episodes={len(episodes)}")
|
| 161 |
+
model.learn(total_timesteps=total_timesteps, callback=cb)
|
| 162 |
logger.info("[RL:Train] PPO 학습 완료")
|
| 163 |
return model
|
| 164 |
|
|
|
|
| 209 |
episodes: list[dict],
|
| 210 |
model_name: str,
|
| 211 |
total_timesteps: int = 300_000,
|
| 212 |
+
stage: int = 6,
|
| 213 |
+
train_progress_callback=None, # async callable(pct: int)
|
| 214 |
) -> dict:
|
| 215 |
"""PPO 학습 후 Supabase ml_models 테이블에 저장합니다."""
|
| 216 |
from services import supabase_service
|
| 217 |
|
| 218 |
loop = asyncio.get_event_loop()
|
| 219 |
+
|
| 220 |
+
# 동기 콜백 → asyncio 이벤트 루프로 브리지
|
| 221 |
+
def _sync_progress(pct: int):
|
| 222 |
+
if train_progress_callback:
|
| 223 |
+
asyncio.run_coroutine_threadsafe(train_progress_callback(pct), loop)
|
| 224 |
+
|
| 225 |
model = await loop.run_in_executor(
|
| 226 |
+
None, _train_ppo_sync, episodes, total_timesteps, _sync_progress
|
| 227 |
)
|
| 228 |
|
| 229 |
model_b64 = await loop.run_in_executor(None, _serialize_model, model)
|
|
|
|
| 240 |
"auc": 0.0,
|
| 241 |
"feature_count": n_features,
|
| 242 |
"sample_count": total_steps,
|
| 243 |
+
"stage": stage, # 학습에 사용된 피처 stage (예측 시 동일 stage 필요)
|
| 244 |
"model_json": {
|
| 245 |
"type": "rl",
|
| 246 |
"algorithm": "PPO",
|
| 247 |
+
"stage": stage,
|
| 248 |
"n_episodes": len(episodes),
|
| 249 |
"total_timesteps": total_timesteps,
|
| 250 |
"model_b64": model_b64,
|
|
|
|
| 268 |
model_id: str,
|
| 269 |
ticker: str,
|
| 270 |
days: int = 500,
|
| 271 |
+
stage: int | None = None, # None이면 model_json에서 자동 로드
|
| 272 |
) -> dict:
|
| 273 |
"""RL 모델로 종목의 날짜별 BUY/HOLD/SELL 시퀀스를 반환합니다."""
|
| 274 |
from services import supabase_service, data_collector
|
|
|
|
| 280 |
if not isinstance(model_json, dict) or model_json.get("type") != "rl":
|
| 281 |
raise ValueError(f"모델 {model_id}는 RL 모델이 아닙니다 (type={model_json.get('type')})")
|
| 282 |
|
| 283 |
+
# stage: model_json 우선 (학습 시 저장된 값), 없으면 파라미터, 최후 기본값 6
|
| 284 |
+
stage = model_json.get("stage") or stage or model_record.get("stage") or 6
|
| 285 |
+
|
| 286 |
loop = asyncio.get_event_loop()
|
| 287 |
model = await loop.run_in_executor(
|
| 288 |
None, _deserialize_model, model_json["model_b64"]
|