ykjung Claude Sonnet 4.6 committed on
Commit
8acad99
·
1 Parent(s): d14a456

fix(rl): stage 저장 누락·WebSocket 진행률 미전송 수정

Browse files

- rl_service: train_rl에 stage 파라미터 추가, model_json에 stage 저장
- rl_service: predict_rl에서 model_json.stage 자동 로드 (차원 불일치 방지)
- rl_service: SB3 ProgressCallback 추가 — 10,000 스텝마다 진행률 콜백
- rl_service: 스레드 콜백 → asyncio 이벤트루프 브리지 (run_coroutine_threadsafe)
- routers/rl: train_progress_callback 연결, 10k 스텝마다 WS 진행률 전송

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. routers/rl.py +11 -1
  2. services/rl_service.py +42 -5
routers/rl.py CHANGED
@@ -160,7 +160,17 @@ async def websocket_rl_train(websocket: WebSocket):
160
 
161
  # ── 2단계: PPO 학습 (blocking → executor) ───────────
162
  # 브라우저가 닫혀도 서버에서 계속 실행됩니다.
163
- result = await rl_service.train_rl(episodes, model_name, total_timesteps)
 
 
 
 
 
 
 
 
 
 
164
 
165
  _set_job(status="complete", train_progress=100, result=result)
166
  await _send({"type": "training", "progress": 100})
 
160
 
161
  # ── 2단계: PPO 학습 (blocking → executor) ───────────
162
  # 브라우저가 닫혀도 서버에서 계속 실행됩니다.
163
+ # 10,000 스텝마다 진행률을 클라이언트에 전송합니다.
164
+ async def on_train_progress(pct: int):
165
+ _set_job(train_progress=pct)
166
+ await _send({"type": "training", "progress": pct,
167
+ "message": f"PPO 학습 중... {pct}%"})
168
+
169
+ result = await rl_service.train_rl(
170
+ episodes, model_name, total_timesteps,
171
+ stage=stage,
172
+ train_progress_callback=on_train_progress,
173
+ )
174
 
175
  _set_job(status="complete", train_progress=100, result=result)
176
  await _send({"type": "training", "progress": 100})
services/rl_service.py CHANGED
@@ -113,13 +113,37 @@ async def collect_rl_episodes(
113
  # PPO 학습 (동기 → executor에서 호출)
114
  # ─────────────────────────────────────────────────────────
115
 
116
- def _train_ppo_sync(episodes: list[dict], total_timesteps: int) -> object:
 
 
 
 
117
  """stable-baselines3 PPO 학습 (blocking). asyncio executor에서 실행됩니다."""
 
118
  from services.rl_environment import StockTradingEnv
119
 
120
  PPO = _get_ppo()
121
  env = StockTradingEnv(episodes)
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  model = PPO(
124
  "MlpPolicy",
125
  env,
@@ -132,8 +156,9 @@ def _train_ppo_sync(episodes: list[dict], total_timesteps: int) -> object:
132
  verbose=0,
133
  )
134
 
 
135
  logger.info(f"[RL:Train] PPO 학습 시작: timesteps={total_timesteps}, episodes={len(episodes)}")
136
- model.learn(total_timesteps=total_timesteps)
137
  logger.info("[RL:Train] PPO 학습 완료")
138
  return model
139
 
@@ -184,13 +209,21 @@ async def train_rl(
184
  episodes: list[dict],
185
  model_name: str,
186
  total_timesteps: int = 300_000,
 
 
187
  ) -> dict:
188
  """PPO 학습 후 Supabase ml_models 테이블에 저장합니다."""
189
  from services import supabase_service
190
 
191
  loop = asyncio.get_event_loop()
 
 
 
 
 
 
192
  model = await loop.run_in_executor(
193
- None, _train_ppo_sync, episodes, total_timesteps
194
  )
195
 
196
  model_b64 = await loop.run_in_executor(None, _serialize_model, model)
@@ -207,10 +240,11 @@ async def train_rl(
207
  "auc": 0.0,
208
  "feature_count": n_features,
209
  "sample_count": total_steps,
210
- "stage": 0,
211
  "model_json": {
212
  "type": "rl",
213
  "algorithm": "PPO",
 
214
  "n_episodes": len(episodes),
215
  "total_timesteps": total_timesteps,
216
  "model_b64": model_b64,
@@ -234,7 +268,7 @@ async def predict_rl(
234
  model_id: str,
235
  ticker: str,
236
  days: int = 500,
237
- stage: int = 6,
238
  ) -> dict:
239
  """RL 모델로 종목의 날짜별 BUY/HOLD/SELL 시퀀스를 반환합니다."""
240
  from services import supabase_service, data_collector
@@ -246,6 +280,9 @@ async def predict_rl(
246
  if not isinstance(model_json, dict) or model_json.get("type") != "rl":
247
  raise ValueError(f"모델 {model_id}는 RL 모델이 아닙니다 (type={model_json.get('type')})")
248
 
 
 
 
249
  loop = asyncio.get_event_loop()
250
  model = await loop.run_in_executor(
251
  None, _deserialize_model, model_json["model_b64"]
 
113
  # PPO 학습 (동기 → executor에서 호출)
114
  # ─────────────────────────────────────────────────────────
115
 
116
+ def _train_ppo_sync(
117
+ episodes: list[dict],
118
+ total_timesteps: int,
119
+ on_progress=None, # callable(pct: int) — 스레드에서 호출 (동기)
120
+ ) -> object:
121
  """stable-baselines3 PPO 학습 (blocking). asyncio executor에서 실행됩니다."""
122
+ from stable_baselines3.common.callbacks import BaseCallback
123
  from services.rl_environment import StockTradingEnv
124
 
125
  PPO = _get_ppo()
126
  env = StockTradingEnv(episodes)
127
 
128
+ class _ProgressCb(BaseCallback):
129
+ """10,000 스텝마다 진행률을 on_progress 콜백으로 전달합니다."""
130
+ def __init__(self, total, cb):
131
+ super().__init__()
132
+ self._total = total
133
+ self._cb = cb
134
+ self._last_pct = -1
135
+
136
+ def _on_step(self) -> bool:
137
+ if self._cb and self.num_timesteps % 10_000 == 0:
138
+ pct = min(99, int(self.num_timesteps / self._total * 100))
139
+ if pct != self._last_pct:
140
+ self._last_pct = pct
141
+ try:
142
+ self._cb(pct)
143
+ except Exception:
144
+ pass
145
+ return True
146
+
147
  model = PPO(
148
  "MlpPolicy",
149
  env,
 
156
  verbose=0,
157
  )
158
 
159
+ cb = _ProgressCb(total_timesteps, on_progress) if on_progress else None
160
  logger.info(f"[RL:Train] PPO 학습 시작: timesteps={total_timesteps}, episodes={len(episodes)}")
161
+ model.learn(total_timesteps=total_timesteps, callback=cb)
162
  logger.info("[RL:Train] PPO 학습 완료")
163
  return model
164
 
 
209
  episodes: list[dict],
210
  model_name: str,
211
  total_timesteps: int = 300_000,
212
+ stage: int = 6,
213
+ train_progress_callback=None, # async callable(pct: int)
214
  ) -> dict:
215
  """PPO 학습 후 Supabase ml_models 테이블에 저장합니다."""
216
  from services import supabase_service
217
 
218
  loop = asyncio.get_event_loop()
219
+
220
+ # 동기 콜백 → asyncio 이벤트 루프로 브리지
221
+ def _sync_progress(pct: int):
222
+ if train_progress_callback:
223
+ asyncio.run_coroutine_threadsafe(train_progress_callback(pct), loop)
224
+
225
  model = await loop.run_in_executor(
226
+ None, _train_ppo_sync, episodes, total_timesteps, _sync_progress
227
  )
228
 
229
  model_b64 = await loop.run_in_executor(None, _serialize_model, model)
 
240
  "auc": 0.0,
241
  "feature_count": n_features,
242
  "sample_count": total_steps,
243
+ "stage": stage, # 학습에 사용된 피처 stage (예측 시 동일 stage 필요)
244
  "model_json": {
245
  "type": "rl",
246
  "algorithm": "PPO",
247
+ "stage": stage,
248
  "n_episodes": len(episodes),
249
  "total_timesteps": total_timesteps,
250
  "model_b64": model_b64,
 
268
  model_id: str,
269
  ticker: str,
270
  days: int = 500,
271
+ stage: int | None = None, # None이면 model_json에서 자동 로드
272
  ) -> dict:
273
  """RL 모델로 종목의 날짜별 BUY/HOLD/SELL 시퀀스를 반환합니다."""
274
  from services import supabase_service, data_collector
 
280
  if not isinstance(model_json, dict) or model_json.get("type") != "rl":
281
  raise ValueError(f"모델 {model_id}는 RL 모델이 아닙니다 (type={model_json.get('type')})")
282
 
283
+ # stage: model_json 우선 (학습 시 저장된 값), 없으면 파라미터, 최후 기본값 6
284
+ stage = model_json.get("stage") or stage or model_record.get("stage") or 6
285
+
286
  loop = asyncio.get_event_loop()
287
  model = await loop.run_in_executor(
288
  None, _deserialize_model, model_json["model_b64"]