File size: 3,275 Bytes
02c6351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python3
"""
Set up the Diamond CSGO demo: look for local model checkpoints (downloading them from the Hugging Face Hub is not implemented yet) and create placeholder spawn data.
"""

import logging
import os
from pathlib import Path

# Module logger for this script; the root handler below is set to INFO so the
# progress messages emitted by the setup steps are visible on the console.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Checkpoint locations probed by default, in priority order.
_DEFAULT_CHECKPOINT_CANDIDATES = (
    Path("agent_epoch_00206.pt"),
    Path("agent_epoch_00003.pt"),
    Path("checkpoints/agent_epoch_00206.pt"),
    Path("checkpoints/agent_epoch_00003.pt"),
)


def download_checkpoint_if_needed(candidates=None) -> bool:
    """Return whether a model checkpoint is available locally.

    Each candidate path is checked in order and ``True`` is returned for the
    first one that exists. If none exist, the function only verifies that
    ``huggingface_hub`` can be imported: the actual download is not
    implemented yet, so that path always returns ``False``.

    Args:
        candidates: Iterable of paths (``str`` or ``Path``) to check, in
            priority order. Defaults to the epoch-206 and epoch-3 agent
            checkpoints in the working directory and in ``checkpoints/``.

    Returns:
        ``True`` if a local checkpoint was found, ``False`` otherwise.
    """
    if candidates is None:
        candidates = _DEFAULT_CHECKPOINT_CANDIDATES

    for candidate in candidates:
        ckpt_path = Path(candidate)
        if ckpt_path.exists():
            logger.info("Found local checkpoint: %s", ckpt_path)
            return True

    logger.info("No local checkpoint found, attempting to download from Hugging Face Hub...")

    try:
        # Only probes that the download API is importable; nothing is fetched yet.
        from huggingface_hub import hf_hub_download  # noqa: F401
    except ImportError:
        logger.error("huggingface_hub not installed. Cannot download models.")
        return False
    except Exception as e:
        # A broken installation can fail with more than ImportError; this
        # check is best-effort, so report and fall back instead of crashing.
        logger.error("Failed to download models: %s", e)
        return False

    # TODO: once the checkpoints are published on the Hub, fetch
    # "agent_epoch_00206.pt" with hf_hub_download(repo_id=..., filename=...,
    # cache_dir="./checkpoints") and return True on success.
    logger.warning("Model download not implemented yet.")
    logger.warning("Please ensure you have model checkpoints available locally.")
    return False

def setup_demo_data(spawn_dir="csgo/spawn/0", num_steps=100):
    """Create any missing placeholder spawn data files.

    Writes zero-filled ``.npy`` arrays and an ``info.json`` into
    *spawn_dir*, creating the directory if needed. Existing files are never
    overwritten. This function does not check whether model checkpoints
    exist; the caller decides when to invoke it.

    Args:
        spawn_dir: Directory (``str`` or ``Path``) to populate. Defaults to
            ``csgo/spawn/0``.
        num_steps: Length of the leading axis of every dummy array; also
            recorded as ``episode_length`` in ``info.json``. Defaults to 100.
    """
    # Imported lazily so numpy is only needed when demo data is generated.
    import json

    import numpy as np

    spawn_dir = Path(spawn_dir)
    spawn_dir.mkdir(parents=True, exist_ok=True)

    # Keep only shapes here and allocate an array only when its file is
    # missing: at the default size full_res alone is ~864 MB of float64, so
    # building every array up front is wasteful when the files already exist.
    # NOTE(review): arrays use numpy's default float64 dtype; confirm whether
    # the consumer of these files expects a different dtype (e.g. uint8 frames).
    array_shapes = {
        "act.npy": (num_steps, 51),  # num_steps timesteps, 51 actions
        "low_res.npy": (num_steps, 3, 150, 600),  # low-res frames
        "full_res.npy": (num_steps, 3, 300, 1200),  # high-res frames
        "next_act.npy": (num_steps, 51),
    }

    for filename, shape in array_shapes.items():
        file_path = spawn_dir / filename
        if not file_path.exists():
            np.save(file_path, np.zeros(shape))
            logger.info("Created dummy file: %s", file_path)

    info_path = spawn_dir / "info.json"
    if not info_path.exists():
        info_data = {
            "episode_length": num_steps,
            "total_reward": 0.0,
            "demo": True,
        }
        with open(info_path, "w", encoding="utf-8") as f:
            json.dump(info_data, f)
        logger.info("Created info file: %s", info_path)

if __name__ == "__main__":
    logger.info("Setting up Diamond CSGO demo...")

    # Look for trained checkpoints first (the Hub download is only a stub).
    models_ready = download_checkpoint_if_needed()

    # Spawn data is created unconditionally, whether or not checkpoints exist.
    setup_demo_data()

    if not models_ready:
        demo_mode_warnings = (
            "Running in demo mode without trained models.",
            "The AI agent will not function properly without model checkpoints.",
        )
        for message in demo_mode_warnings:
            logger.warning(message)

    logger.info("Setup complete!")