"""
Train E-E-A-T signal regressors (one XGBoost model per target column) from
sentence-transformer embeddings of page content plus optional link features.
"""
import json
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
import xgboost as xgb
from sentence_transformers import SentenceTransformer
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split

from config import (
    CONTENT_COLUMN,
    DATA_PATH,
    MODEL_DIR,
    OPTIONAL_FEATURES,
    RANDOM_STATE,
    TARGET_COLUMNS,
)
def _encode_content(df: pd.DataFrame) -> "tuple[SentenceTransformer, np.ndarray]":
    """Embed ``df[CONTENT_COLUMN]`` with all-MiniLM-L6-v2; return (encoder, embeddings)."""
    encoder = SentenceTransformer("all-MiniLM-L6-v2")
    # Missing content is embedded as "" so every row keeps a feature vector.
    texts = df[CONTENT_COLUMN].fillna("").astype(str).tolist()
    return encoder, np.asarray(encoder.encode(texts))


def _optional_feature_matrix(df: pd.DataFrame, columns: list) -> np.ndarray:
    """Return ``df[columns]`` as a float matrix; non-numeric or missing cells become 0."""
    # Coerce first so a stray text value cannot turn the stacked matrix into dtype=object.
    numeric = df[columns].apply(pd.to_numeric, errors="coerce").fillna(0)
    return numeric.to_numpy(dtype=float)


def _fit_and_score(X: np.ndarray, y: np.ndarray) -> "tuple[xgb.XGBRegressor, dict]":
    """Fit an XGBRegressor on an 80/20 split of (X, y); return (model, {"mae", "r2"})."""
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, random_state=RANDOM_STATE
    )
    model = xgb.XGBRegressor(random_state=RANDOM_STATE, n_estimators=100)
    model.fit(X_train, y_train)
    pred = model.predict(X_val)
    # r2_score is not well-defined for a single validation sample (it warns and
    # returns NaN, which json.dump would emit as invalid JSON), so record None.
    r2 = float(r2_score(y_val, pred)) if len(y_val) >= 2 else None
    return model, {"mae": float(mean_absolute_error(y_val, pred)), "r2": r2}


def main():
    """Train one regressor per available E-E-A-T target and save the artifacts.

    Reads the CSV at DATA_PATH and embeds CONTENT_COLUMN with a SentenceTransformer.
    It then appends any OPTIONAL_FEATURES columns that are present. For each
    TARGET_COLUMNS entry present in the data, it fits an XGBRegressor on the rows
    where that target is labeled, using an 80/20 train/validation split.

    The following files are written into MODEL_DIR, which is created if needed:
    ``eeat_<target>.joblib``, ``encoder.joblib``, ``extra_features.joblib`` and
    ``metrics.json`` (MAE and R² per trained target).

    If DATA_PATH does not exist, a message is printed and the function returns
    None without training. A target with fewer than two labeled rows is skipped
    with a message.

    Raises:
        ValueError: if CONTENT_COLUMN is missing, or none of TARGET_COLUMNS exist.
    """
    if not Path(DATA_PATH).exists():
        print(
            f"Data not found at {DATA_PATH}. "
            f"Create it with columns: {CONTENT_COLUMN}, {TARGET_COLUMNS}"
        )
        return
    df = pd.read_csv(DATA_PATH)
    if CONTENT_COLUMN not in df.columns:
        raise ValueError(f"Missing column: {CONTENT_COLUMN}")
    targets = [c for c in TARGET_COLUMNS if c in df.columns]
    if not targets:
        raise ValueError(f"Need at least one of {TARGET_COLUMNS}")

    model_dir = Path(MODEL_DIR)
    # joblib.dump and open() do not create missing parent directories.
    model_dir.mkdir(parents=True, exist_ok=True)

    encoder, X = _encode_content(df)
    extra = [c for c in OPTIONAL_FEATURES if c in df.columns]
    if extra:
        X = np.hstack([X, _optional_feature_matrix(df, extra)])

    metrics = {}
    for t in targets:
        # XGBoost and the sklearn metrics reject NaN labels, so each target is
        # trained only on the rows where it is present (and numeric).
        y = pd.to_numeric(df[t], errors="coerce").to_numpy(dtype=float, na_value=np.nan)
        labeled = ~np.isnan(y)
        n_labeled = int(labeled.sum())
        if n_labeled < 2:
            # train_test_split needs at least one training and one validation row.
            print(f"Skipping {t}: only {n_labeled} labeled row(s)")
            continue
        model, metrics[t] = _fit_and_score(X[labeled], y[labeled])
        joblib.dump(model, model_dir / f"eeat_{t}.joblib")

    # NOTE(review): SentenceTransformer.save() is the library's native persistence;
    # joblib is kept so existing loaders of encoder.joblib keep working.
    joblib.dump(encoder, model_dir / "encoder.joblib")
    # Order of the optional feature columns stacked after the embedding —
    # presumably needed at inference time to rebuild the same layout; confirm.
    joblib.dump(extra, model_dir / "extra_features.joblib")
    with open(model_dir / "metrics.json", "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)
    print("Metrics:", metrics)
if __name__ == "__main__":
    # Train only when run as a script; importing this module has no side effects.
    main()