midah commited on
Commit
3e85304
·
1 Parent(s): 4b829ab

Add network pre-computation, styling improvements, and theme toggle

Browse files

- Add precompute_network.py and upload_network_to_hf.py scripts
- Update network endpoints to support pre-computed graphs and HF Hub download
- Add theme toggle with system preference detection
- Improve force-directed graph styling and controls
- Add rounded point rendering for embeddings view
- Update API retry logic for 429 errors
- Add comprehensive UI controls for FDG (color by, size by, family filter, search)
- Update documentation

Files changed (39) hide show
  1. APP_ANALYSIS.md +1 -0
  2. DEPLOYMENT_COMPLETE.md +1 -0
  3. DEPLOYMENT_STATUS.md +1 -0
  4. DEPLOY_TO_HF_SPACES.md +1 -0
  5. FDG_DESIGN_ANALYSIS.md +113 -0
  6. FORCE_DIRECTED_STATUS.md +1 -0
  7. HF_SPACES_DEPLOYMENT.md +1 -0
  8. HF_SPACES_READY.md +1 -0
  9. HOW_TO_RUN.md +1 -0
  10. PRODUCTION_DEPLOYMENT.md +1 -0
  11. README_SPACE.md +1 -0
  12. SCALING_EMBEDDINGS_STRATEGY.md +1 -0
  13. SCALING_QUICKSTART.md +1 -0
  14. SCALING_SUMMARY.md +1 -0
  15. TOGGLE_GUIDE.md +240 -0
  16. app.py +1 -0
  17. auto_deploy.sh +1 -0
  18. backend/api/main.py +88 -17
  19. backend/requirements.txt +1 -0
  20. backend/scripts/precompute_network.py +211 -0
  21. backend/scripts/upload_network_to_hf.py +130 -0
  22. backend/utils/chunked_loader.py +1 -0
  23. backend/utils/precomputed_loader.py +85 -0
  24. check_and_deploy.sh +1 -0
  25. frontend/src/App.css +82 -12
  26. frontend/src/App.tsx +220 -35
  27. frontend/src/components/controls/EdgeTypeFilter.css +50 -11
  28. frontend/src/components/controls/EdgeTypeFilter.tsx +1 -0
  29. frontend/src/components/controls/ForceParameterControls.css +41 -13
  30. frontend/src/components/controls/ForceParameterControls.tsx +1 -0
  31. frontend/src/components/controls/ThemeToggle.tsx +20 -2
  32. frontend/src/components/visualizations/ForceDirectedGraph.css +44 -2
  33. frontend/src/components/visualizations/ForceDirectedGraph3D.tsx +162 -24
  34. frontend/src/components/visualizations/ForceDirectedGraph3DInstanced.tsx +166 -39
  35. frontend/src/stores/filterStore.ts +7 -3
  36. frontend/src/utils/api/graphApi.ts +142 -36
  37. precomputed_data/network_metadata.json +10 -0
  38. requirements.txt +1 -0
  39. upload_to_hf_dataset.py +1 -0
APP_ANALYSIS.md CHANGED
@@ -269,3 +269,4 @@
269
  **Last Updated**: Based on current codebase analysis
270
  **Status**: Ready for deployment pending data generation completion
271
 
 
 
269
  **Last Updated**: Based on current codebase analysis
270
  **Status**: Ready for deployment pending data generation completion
271
 
272
+
DEPLOYMENT_COMPLETE.md CHANGED
@@ -178,3 +178,4 @@ ps aux | grep precompute_data.py
178
 
179
  The chunked embedding system is fully deployed and ready. The server will automatically use chunked mode once production data completes. You can start using it now with test data!
180
 
 
 
178
 
179
  The chunked embedding system is fully deployed and ready. The server will automatically use chunked mode once production data completes. You can start using it now with test data!
180
 
181
+
DEPLOYMENT_STATUS.md CHANGED
@@ -134,3 +134,4 @@ nohup python scripts/precompute_data.py --sample-size 0 --chunked --chunk-size 5
134
 
135
  The chunked embedding system is fully implemented and tested. The full precompute is running and will complete in a few hours. Once complete, the server will automatically use chunked mode for fast startup and efficient memory usage.
136
 
 
 
134
 
135
  The chunked embedding system is fully implemented and tested. The full precompute is running and will complete in a few hours. Once complete, the server will automatically use chunked mode for fast startup and efficient memory usage.
136
 
137
+
DEPLOY_TO_HF_SPACES.md CHANGED
@@ -159,3 +159,4 @@ When you update data:
159
 
160
  **Note**: The Space automatically downloads chunked data from the Hugging Face Dataset. No need to include data files in the Space repository!
161
 
 
 
159
 
160
  **Note**: The Space automatically downloads chunked data from the Hugging Face Dataset. No need to include data files in the Space repository!
161
 
162
+
FDG_DESIGN_ANALYSIS.md ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Force-Directed Graph Design Analysis
2
+
3
+ ## ✅ IMPLEMENTATION COMPLETE
4
+
5
+ All priority improvements have been implemented. The force-directed graph is now fully harmonized with the embeddings view.
6
+
7
+ ---
8
+
9
+ ## Current State (After Improvements)
10
+
11
+ ### ✅ What Works Well
12
+
13
+ 1. **3D Visualization**
14
+ - Fully 3D using Three.js/React Three Fiber
15
+ - Same Canvas setup as embeddings view
16
+ - OrbitControls for navigation (pan, zoom, rotate)
17
+ - Consistent camera controls
18
+
19
+ 2. **Design Consistency**
20
+ - Same background color (`#1a1a1a`)
21
+ - Same hover/selection states (red for selected, yellow for hovered, cyan for highlighted)
22
+ - Same container styling
23
+ - Consistent with app theme
24
+
25
+ 3. **Performance**
26
+ - Instanced rendering for large graphs (>10k nodes)
27
+ - Efficient filtering by edge types, family, and search
28
+ - Optimized for up to 500k nodes
29
+
30
+ 4. **✅ Full Filtering (NEW)**
31
+ - Edge type filter (show/hide relationship types)
32
+ - Family/organization filter dropdown
33
+ - Search by model ID with highlighting
34
+ - Nodes filtered to only show connected nodes when edge types are filtered
35
+
36
+ 5. **✅ Color/Style Options (NEW)**
37
+ - Color By: ML Library, Task Type, Downloads, Likes, Edge Type
38
+ - Size By: Downloads, Likes, Uniform Size
39
+ - Color Schemes: Viridis, Plasma, Inferno, Cool-Warm
40
+ - Uses same color utilities as embeddings view
41
+
42
+ ## Comparison: Embeddings vs FDG (Updated)
43
+
44
+ | Feature | Embeddings View | FDG View | Status |
45
+ |---------|----------------|----------|--------|
46
+ | 3D Visualization | ✅ | ✅ | ✅ Consistent |
47
+ | Color By Options | ✅ (5 options) | ✅ (5 options) | ✅ Harmonized |
48
+ | Size By Options | ✅ (3 options) | ✅ (3 options) | ✅ Harmonized |
49
+ | Color Schemes | ✅ (4 options) | ✅ (4 options) | ✅ Harmonized |
50
+ | Search Integration | ✅ | ✅ | ✅ Added |
51
+ | Filter by Family | ✅ | ✅ | ✅ Added |
52
+ | Highlight Model | ✅ | ✅ | ✅ Added |
53
+ | Edge Type Filter | N/A | ✅ | ✅ Unique |
54
+ | Force Parameters | N/A | ✅ | ✅ Unique |
55
+
56
+ ## Implementation Status
57
+
58
+ ### ✅ Phase 1: Design Harmony (COMPLETE)
59
+ - [x] Add Color By selector to FDG controls
60
+ - [x] Add Size By selector to FDG controls
61
+ - [x] Integrate color utilities from embeddings view
62
+ - [x] Add color scheme selector for gradients
63
+
64
+ ### ✅ Phase 2: Search & Filtering (COMPLETE)
65
+ - [x] Add search bar to FDG view
66
+ - [x] Implement model search with highlighting
67
+ - [x] Add family/organization filter dropdown
68
+ - [x] Filters apply to both nodes AND edges
69
+
70
+ ### 📋 Phase 3: Future Enhancements (Optional)
71
+ - [ ] Implement N-hop neighbor filtering
72
+ - [ ] Add "focus on node" mode (zoom to node)
73
+ - [ ] Add path highlighting between nodes
74
+ - [ ] Add subgraph isolation controls
75
+
76
+ ## Files Modified
77
+
78
+ - `ForceDirectedGraph3D.tsx` - Added colorBy, sizeBy, colorScheme, familyFilter, searchQuery, highlightedNodeId props
79
+ - `ForceDirectedGraph3DInstanced.tsx` - Same props + filtering logic
80
+ - `App.tsx` - Added controls UI, state management, props passing
81
+ - `App.css` - Added graph search input styles
82
+
83
+ ## User Guide
84
+
85
+ ### Controls in Force-directed Graph Mode:
86
+
87
+ **Color By** - Change node colors
88
+ - ML Library: Different colors per framework
89
+ - Task Type: Different colors per pipeline tag
90
+ - Downloads/Likes: Gradient color scale
91
+ - Edge Type: Color by relationship type
92
+
93
+ **Size By** - Change node sizes
94
+ - By Downloads: Larger = more downloads
95
+ - By Likes: Larger = more likes
96
+ - Uniform: Same size for all
97
+
98
+ **Family Filter** - Filter by organization
99
+ - Select from top 100 organizations
100
+ - Filters both nodes AND edges
101
+
102
+ **Search** - Find and highlight models
103
+ - Type to filter matching nodes
104
+ - First match is highlighted in cyan
105
+ - Clear button to reset
106
+
107
+ **Edge Types** - Show/hide relationship types
108
+ - Fine-tuned, Quantized, Adapter, Merged, Parent
109
+
110
+ **Settings** - Force simulation parameters
111
+ - Link Distance, Charge Strength, Collision Radius
112
+ - Node Size Multiplier, Edge Opacity
113
+
FORCE_DIRECTED_STATUS.md CHANGED
@@ -167,3 +167,4 @@
167
  - Force parameters (hardcoded): `ForceDirectedGraph.tsx` lines 148-179
168
  - Edge type controls (reference): `GraphPage.tsx` lines 562-598
169
 
 
 
167
  - Force parameters (hardcoded): `ForceDirectedGraph.tsx` lines 148-179
168
  - Edge type controls (reference): `GraphPage.tsx` lines 562-598
169
 
170
+
HF_SPACES_DEPLOYMENT.md CHANGED
@@ -228,3 +228,4 @@ When you need to update the precomputed data:
228
 
229
  **Note**: The Space will automatically download chunked data from the Hugging Face Dataset on startup. No manual data upload to the Space repository is needed!
230
 
 
 
228
 
229
  **Note**: The Space will automatically download chunked data from the Hugging Face Dataset on startup. No manual data upload to the Space repository is needed!
230
 
231
+
HF_SPACES_READY.md CHANGED
@@ -150,3 +150,4 @@ After deployment, check:
150
 
151
  **Everything is ready!** Once the precompute completes and data is uploaded, you can deploy to Hugging Face Spaces and it will work without any local access needed.
152
 
 
 
150
 
151
  **Everything is ready!** Once the precompute completes and data is uploaded, you can deploy to Hugging Face Spaces and it will work without any local access needed.
152
 
153
+
HOW_TO_RUN.md CHANGED
@@ -115,3 +115,4 @@ Press `Ctrl+C` in the terminal where the server is running, or:
115
  pkill -f "uvicorn api.main:app"
116
  ```
117
 
 
 
115
  pkill -f "uvicorn api.main:app"
116
  ```
117
 
118
+
PRODUCTION_DEPLOYMENT.md CHANGED
@@ -219,3 +219,4 @@ After successful deployment:
219
  - `SCALING_QUICKSTART.md` - Quick start guide
220
  - `SCALING_SUMMARY.md` - Implementation summary
221
 
 
 
219
  - `SCALING_QUICKSTART.md` - Quick start guide
220
  - `SCALING_SUMMARY.md` - Implementation summary
221
 
222
+
README_SPACE.md CHANGED
@@ -76,3 +76,4 @@ This Space automatically:
76
  - **Paper**: [arXiv:2508.06811](https://arxiv.org/abs/2508.06811)
77
  - **Dataset**: [modelbiome/ai_ecosystem](https://huggingface.co/datasets/modelbiome/ai_ecosystem)
78
 
 
 
76
  - **Paper**: [arXiv:2508.06811](https://arxiv.org/abs/2508.06811)
77
  - **Dataset**: [modelbiome/ai_ecosystem](https://huggingface.co/datasets/modelbiome/ai_ecosystem)
78
 
79
+
SCALING_EMBEDDINGS_STRATEGY.md CHANGED
@@ -287,3 +287,4 @@ async def get_models(
287
  4. **Compression**: Use better compression (zstd) for parquet files
288
  5. **Quantization**: Use int8 embeddings (50% memory reduction)
289
 
 
 
287
  4. **Compression**: Use better compression (zstd) for parquet files
288
  5. **Quantization**: Use int8 embeddings (50% memory reduction)
289
 
290
+
SCALING_QUICKSTART.md CHANGED
@@ -149,3 +149,4 @@ Follow the complete strategy in `SCALING_EMBEDDINGS_STRATEGY.md`:
149
  3. Start with Option A (minimal changes) for quick wins
150
  4. Gradually implement Option B for full optimization
151
 
 
 
149
  3. Start with Option A (minimal changes) for quick wins
150
  4. Gradually implement Option B for full optimization
151
 
152
+
SCALING_SUMMARY.md CHANGED
@@ -200,3 +200,4 @@ embeddings, found_ids = chunked_loader.load_embeddings_for_models(filtered_model
200
 
201
  See `SCALING_EMBEDDINGS_STRATEGY.md` for details.
202
 
 
 
200
 
201
  See `SCALING_EMBEDDINGS_STRATEGY.md` for details.
202
 
203
+
TOGGLE_GUIDE.md ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Toggle Guide - What Each Toggle Does
2
+
3
+ ## Navigation Tabs (Left Sidebar)
4
+
5
+ ### 1. **Visualization** Tab (Default)
6
+ **What it does:**
7
+ - Shows the main 3D interactive visualization
8
+ - Displays models as points in a 3D embedding space
9
+ - Models closer together are more similar (based on embeddings)
10
+ - You can zoom, pan, rotate, and click on models
11
+
12
+ **What you see:**
13
+ - 3D scatter plot with thousands of models
14
+ - Color-coded by your selected attribute (library, pipeline, downloads, etc.)
15
+ - Size varies based on downloads/likes (if enabled)
16
+ - Mini-map in bottom-right corner
17
+
18
+ ---
19
+
20
+ ### 2. **Families** Tab
21
+ **What it does:**
22
+ - Switches to a different view focused on model families/organizations
23
+ - Shows adoption curves and family statistics
24
+ - Groups models by organization (e.g., "meta-llama", "google", "microsoft")
25
+
26
+ **What you see:**
27
+ - List of top model families by count
28
+ - Adoption curves showing how families grew over time
29
+ - Comparison mode to compare top 5 families
30
+ - Family depth distribution
31
+
32
+ **Key Features:**
33
+ - **Compare Top 5 Toggle**: Switches between single family view and comparison of top 5 families
34
+ - Click on a family to see its adoption curve
35
+ - Shows how models in each family were created over time
36
+
37
+ ---
38
+
39
+ ### 3. **Analytics** Tab
40
+ **What it does:**
41
+ - Shows statistics and rankings
42
+ - Displays top models by different metrics
43
+ - Shows trends and growth rates
44
+
45
+ **What you see:**
46
+ - Top models by downloads
47
+ - Top models by likes
48
+ - Trending models
49
+ - Newest models
50
+ - Largest families
51
+ - Fastest growing families (with growth rate %)
52
+
53
+ **Key Features:**
54
+ - Time range selector (24h, 7d, 30d) - filters data by time period
55
+ - Growth rate calculation shows % of models created in last 30 days
56
+
57
+ ---
58
+
59
+ ## Visualization Mode Toggle (Top Control Bar)
60
+
61
+ ### **Embeddings** Mode (Default)
62
+ **What it does:**
63
+ - Shows models in semantic embedding space
64
+ - Uses UMAP coordinates (pre-computed 3D positions)
65
+ - Models positioned based on similarity (tags, descriptions, metadata)
66
+
67
+ **Controls available:**
68
+ - **Color By**: Change what determines point color
69
+ - Family Depth
70
+ - ML Library (transformers, pytorch, etc.)
71
+ - Task Type (text-generation, image-classification, etc.)
72
+ - Downloads (gradient)
73
+ - Likes (gradient)
74
+
75
+ - **Size By**: Change what determines point size
76
+ - Downloads
77
+ - Likes
78
+ - Uniform Size
79
+
80
+ - **Show All Models**: Toggle between sampled (150k) and all models
81
+
82
+ **What changes:**
83
+ - Point colors change based on selected attribute
84
+ - Point sizes change based on selected metric
85
+ - Visual grouping changes (e.g., libraries cluster together)
86
+
87
+ ---
88
+
89
+ ### **Force-directed Graph** Mode
90
+ **What it does:**
91
+ - Shows model relationships as a network graph
92
+ - Displays parent-child relationships (fine-tuning, quantization, etc.)
93
+ - Uses force-directed layout (models connected by edges)
94
+
95
+ **Controls available:**
96
+ - **Edge Type Filter**: Toggle which relationship types to show
97
+ - Fine-tuned (blue) - models fine-tuned from a base model
98
+ - Quantized (green) - quantized versions of models
99
+ - Adapter (orange) - models with adapters added
100
+ - Merged (purple) - merged models
101
+ - Parent (gray) - generic parent relationships
102
+
103
+ - **Force Parameters** (Settings icon):
104
+ - Link Distance: How far apart connected models are (50-200)
105
+ - Charge Strength: Repulsion between models (-500 to -100)
106
+ - Collision Radius: How much models avoid overlapping (0.5x to 2x)
107
+ - Node Size: Size multiplier for nodes (0.5x to 2x)
108
+ - Edge Opacity: Transparency of edges (0.1 to 1.0)
109
+
110
+ **What changes:**
111
+ - Graph layout changes based on force parameters
112
+ - Different relationship types appear/disappear based on edge type filter
113
+ - Node sizes and edge visibility adjust
114
+
115
+ ---
116
+
117
+ ## Color By Options (Embeddings Mode)
118
+
119
+ ### **Family Depth**
120
+ - Colors models by their depth in the family tree
121
+ - Base models (depth 0) vs. fine-tuned models (depth 1, 2, 3...)
122
+ - Shows how models are related hierarchically
123
+
124
+ ### **ML Library**
125
+ - Colors by library: transformers, pytorch, tensorflow, diffusers, etc.
126
+ - Each library gets a distinct color
127
+ - Shows which libraries are most popular
128
+
129
+ ### **Task Type** (Pipeline Tag)
130
+ - Colors by task: text-generation, image-classification, etc.
131
+ - Groups models by what they're designed to do
132
+ - Shows task distribution in the ecosystem
133
+
134
+ ### **Downloads** / **Likes**
135
+ - Uses a color gradient (viridis, plasma, inferno, cool-warm)
136
+ - Darker/lighter colors represent higher/lower values
137
+ - Shows popularity distribution
138
+
139
+ ---
140
+
141
+ ## Size By Options (Embeddings Mode)
142
+
143
+ ### **By Downloads**
144
+ - Larger points = more downloads
145
+ - Logarithmic scaling (so differences are visible)
146
+ - Popular models stand out visually
147
+
148
+ ### **By Likes**
149
+ - Larger points = more likes
150
+ - Shows community favorites
151
+ - Similar to downloads but reflects user engagement
152
+
153
+ ### **Uniform Size**
154
+ - All points same size
155
+ - Useful when you want to focus on color patterns
156
+ - Better for seeing density patterns
157
+
158
+ ---
159
+
160
+ ## Edge Type Filter (Force-directed Graph Mode)
161
+
162
+ Each toggle button controls which relationship types are visible:
163
+
164
+ - **Fine-tuned** (Blue): Shows fine-tuning relationships
165
+ - When OFF: Hides all fine-tuning edges
166
+ - When ON: Shows fine-tuning connections
167
+
168
+ - **Quantized** (Green): Shows quantization relationships
169
+ - When OFF: Hides quantized model connections
170
+ - When ON: Shows quantization relationships
171
+
172
+ - **Adapter** (Orange): Shows adapter-based models
173
+ - When OFF: Hides adapter relationships
174
+ - When ON: Shows adapter connections
175
+
176
+ - **Merged** (Purple): Shows merged models
177
+ - When OFF: Hides merge relationships
178
+ - When ON: Shows merged model connections
179
+
180
+ - **Parent** (Gray): Shows generic parent relationships
181
+ - When OFF: Hides parent connections
182
+ - When ON: Shows parent-child relationships
183
+
184
+ **Effect:**
185
+ - Graph becomes simpler/hidden when types are disabled
186
+ - Helps focus on specific relationship types
187
+ - Reduces visual clutter
188
+
189
+ ---
190
+
191
+ ## Force Parameters (Force-directed Graph Mode)
192
+
193
+ ### **Link Distance** (50-200)
194
+ - Controls how far apart connected nodes are
195
+ - Higher = more spread out graph
196
+ - Lower = more compact, clustered graph
197
+
198
+ ### **Charge Strength** (-500 to -100)
199
+ - Controls repulsion between nodes
200
+ - More negative = stronger repulsion (nodes push apart)
201
+ - Less negative = weaker repulsion (nodes can cluster)
202
+
203
+ ### **Collision Radius** (0.5x to 2x)
204
+ - Controls how much nodes avoid overlapping
205
+ - Higher = more spacing between nodes
206
+ - Lower = nodes can get closer together
207
+
208
+ ### **Node Size** (0.5x to 2x)
209
+ - Multiplies the size of all nodes
210
+ - Useful for adjusting visibility
211
+ - Doesn't change graph layout, just appearance
212
+
213
+ ### **Edge Opacity** (0.1 to 1.0)
214
+ - Controls transparency of edges
215
+ - Lower = more transparent (less visual clutter)
216
+ - Higher = more visible edges
217
+ - Useful for dense graphs
218
+
219
+ ---
220
+
221
+ ## Summary
222
+
223
+ **Navigation Tabs:**
224
+ - **Visualization**: Main 3D embedding space view
225
+ - **Families**: Family/organization analysis with adoption curves
226
+ - **Analytics**: Statistics, rankings, and trends
227
+
228
+ **Visualization Modes:**
229
+ - **Embeddings**: Semantic similarity space (default)
230
+ - **Force-directed Graph**: Relationship network view
231
+
232
+ **Key Toggles:**
233
+ - Color/Size controls change visual encoding
234
+ - Edge type filters show/hide relationship types
235
+ - Force parameters adjust graph layout
236
+ - Comparison mode (Families page) switches between single/comparison view
237
+
238
+ All toggles update the visualization in real-time without page reload!
239
+
240
+
app.py CHANGED
@@ -23,3 +23,4 @@ if __name__ == "__main__":
23
  port = int(os.environ.get("PORT", 7860))
24
  uvicorn.run(app, host="0.0.0.0", port=port)
25
 
 
 
23
  port = int(os.environ.get("PORT", 7860))
24
  uvicorn.run(app, host="0.0.0.0", port=port)
25
 
26
+
auto_deploy.sh CHANGED
@@ -100,3 +100,4 @@ main() {
100
 
101
  main "$@"
102
 
 
 
100
 
101
  main "$@"
102
 
103
+
backend/api/main.py CHANGED
@@ -1638,7 +1638,10 @@ async def get_family_network(
1638
  async def get_full_derivative_network(
1639
  edge_types: Optional[str] = Query(None, description="Comma-separated list of edge types to include (finetune,quantized,adapter,merge,parent). If None, includes all types."),
1640
  include_edge_attributes: bool = Query(False, description="Whether to include edge attributes (change in likes, downloads, etc.). Default False for performance."),
1641
- include_positions: bool = Query(True, description="Whether to include pre-computed 3D positions for each node. Default True for faster rendering.")
 
 
 
1642
  ):
1643
  """
1644
  Build full derivative relationship network for ALL models in the database.
@@ -1656,8 +1659,8 @@ async def get_full_derivative_network(
1656
 
1657
  try:
1658
  import time
 
1659
  start_time = time.time()
1660
- logger.info(f"Building full derivative network for {len(deps.df):,} models...")
1661
 
1662
  # Check if dataframe has required columns
1663
  required_columns = ['model_id']
@@ -1668,26 +1671,94 @@ async def get_full_derivative_network(
1668
  detail=f"Missing required columns: {missing_columns}"
1669
  )
1670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1671
  filter_types = None
1672
  if edge_types:
1673
  filter_types = [t.strip() for t in edge_types.split(',') if t.strip()]
1674
 
1675
- try:
1676
- network_builder = ModelNetworkBuilder(deps.df)
1677
- logger.info("Calling build_full_derivative_network...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1678
 
1679
- # Disable edge attributes for very large graphs to improve performance
1680
- # They can be slow to compute for 100k+ edges
1681
- graph = network_builder.build_full_derivative_network(
1682
- include_edge_attributes=include_edge_attributes,
1683
- filter_edge_types=filter_types
1684
- )
1685
- except Exception as build_error:
1686
- logger.error(f"Error in build_full_derivative_network: {build_error}", exc_info=True)
1687
- raise HTTPException(
1688
- status_code=500,
1689
- detail=f"Failed to build network graph: {str(build_error)}"
1690
- )
1691
 
1692
  build_time = time.time() - start_time
1693
  logger.info(f"Graph built in {build_time:.2f}s: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges")
 
1638
  async def get_full_derivative_network(
1639
  edge_types: Optional[str] = Query(None, description="Comma-separated list of edge types to include (finetune,quantized,adapter,merge,parent). If None, includes all types."),
1640
  include_edge_attributes: bool = Query(False, description="Whether to include edge attributes (change in likes, downloads, etc.). Default False for performance."),
1641
+ include_positions: bool = Query(True, description="Whether to include pre-computed 3D positions for each node. Default True for faster rendering."),
1642
+ min_downloads: int = Query(0, description="Minimum downloads to include a model. Use this to reduce network size."),
1643
+ max_nodes: Optional[int] = Query(None, ge=100, le=1000000, description="Maximum number of nodes to include. Models are sorted by downloads. Use this to reduce network size."),
1644
+ use_precomputed: bool = Query(True, description="Try to load pre-computed network graph from disk if available.")
1645
  ):
1646
  """
1647
  Build full derivative relationship network for ALL models in the database.
 
1659
 
1660
  try:
1661
  import time
1662
+ import networkx as nx
1663
  start_time = time.time()
 
1664
 
1665
  # Check if dataframe has required columns
1666
  required_columns = ['model_id']
 
1671
  detail=f"Missing required columns: {missing_columns}"
1672
  )
1673
 
1674
+ # Try to load pre-computed network graph
1675
+ graph = None
1676
+ if use_precomputed:
1677
+ try:
1678
+ backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
1679
+ root_dir = os.path.dirname(backend_dir)
1680
+ precomputed_dir = os.path.join(root_dir, "precomputed_data")
1681
+ graph_file = os.path.join(precomputed_dir, "full_derivative_network.pkl")
1682
+
1683
+ # Try to download from HF Hub if not found locally (for Spaces deployment)
1684
+ if not os.path.exists(graph_file):
1685
+ logger.info("Pre-computed network not found locally. Attempting to download from HF Hub...")
1686
+ from utils.precomputed_loader import download_network_from_hf_hub
1687
+ download_network_from_hf_hub(precomputed_dir, version="v1")
1688
+
1689
+ if os.path.exists(graph_file):
1690
+ logger.info(f"Loading pre-computed network graph from {graph_file}...")
1691
+ with open(graph_file, 'rb') as f:
1692
+ graph = pickle.load(f)
1693
+ logger.info(f"Loaded pre-computed graph: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges")
1694
+ else:
1695
+ logger.info("Pre-computed network graph not available. Will build from scratch.")
1696
+ except Exception as e:
1697
+ logger.warning(f"Could not load pre-computed network graph: {e}. Will build from scratch.")
1698
+
1699
+ # Filter dataframe if needed
1700
+ filtered_df = deps.df.copy()
1701
+ if min_downloads > 0:
1702
+ filtered_df = filtered_df[filtered_df.get('downloads', 0) >= min_downloads]
1703
+ logger.info(f"Filtered to {len(filtered_df):,} models with >= {min_downloads} downloads")
1704
+
1705
+ if max_nodes and len(filtered_df) > max_nodes:
1706
+ # Sort by downloads and take top N
1707
+ filtered_df = filtered_df.nlargest(max_nodes, 'downloads', keep='first')
1708
+ logger.info(f"Limited to top {max_nodes:,} models by downloads")
1709
+
1710
+ logger.info(f"Building full derivative network for {len(filtered_df):,} models...")
1711
+
1712
  filter_types = None
1713
  if edge_types:
1714
  filter_types = [t.strip() for t in edge_types.split(',') if t.strip()]
1715
 
1716
+ # Build graph if not loaded from disk
1717
+ if graph is None:
1718
+ try:
1719
+ network_builder = ModelNetworkBuilder(filtered_df)
1720
+ logger.info("Calling build_full_derivative_network...")
1721
+
1722
+ # Disable edge attributes for very large graphs to improve performance
1723
+ # They can be slow to compute for 100k+ edges
1724
+ graph = network_builder.build_full_derivative_network(
1725
+ include_edge_attributes=include_edge_attributes,
1726
+ filter_edge_types=filter_types
1727
+ )
1728
+ except Exception as build_error:
1729
+ logger.error(f"Error in build_full_derivative_network: {build_error}", exc_info=True)
1730
+ raise HTTPException(
1731
+ status_code=500,
1732
+ detail=f"Failed to build network graph: {str(build_error)}"
1733
+ )
1734
+ else:
1735
+ # Filter pre-computed graph if needed
1736
+ if filter_types:
1737
+ # Remove edges that don't match filter
1738
+ edges_to_remove = []
1739
+ for source, target, attrs in graph.edges(data=True):
1740
+ edge_types_list = attrs.get('edge_types', [])
1741
+ if not isinstance(edge_types_list, list):
1742
+ edge_types_list = [edge_types_list] if edge_types_list else []
1743
+ if not any(et in filter_types for et in edge_types_list):
1744
+ edges_to_remove.append((source, target))
1745
+ graph.remove_edges_from(edges_to_remove)
1746
+ # Remove isolated nodes
1747
+ isolated = list(nx.isolates(graph))
1748
+ graph.remove_nodes_from(isolated)
1749
+ logger.info(f"Filtered graph: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges")
1750
 
1751
+ # Filter nodes by downloads if needed
1752
+ if min_downloads > 0 or max_nodes:
1753
+ nodes_to_remove = []
1754
+ for node_id in graph.nodes():
1755
+ if node_id in filtered_df.index:
1756
+ continue
1757
+ nodes_to_remove.append(node_id)
1758
+ graph.remove_nodes_from(nodes_to_remove)
1759
+ isolated = list(nx.isolates(graph))
1760
+ graph.remove_nodes_from(isolated)
1761
+ logger.info(f"Filtered graph by model selection: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges")
 
1762
 
1763
  build_time = time.time() - start_time
1764
  logger.info(f"Graph built in {build_time:.2f}s: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges")
backend/requirements.txt CHANGED
@@ -14,4 +14,5 @@ pydantic>=2.0.0
14
  ormsgpack>=1.2.0
15
  rapidfuzz>=3.0.0
16
  pyarrow>=12.0.0
 
17
 
 
14
  ormsgpack>=1.2.0
15
  rapidfuzz>=3.0.0
16
  pyarrow>=12.0.0
17
+ networkx>=3.0.0
18
 
backend/scripts/precompute_network.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pre-compute the full derivative network graph and save it to disk.
3
+ This allows the API to load the network instantly instead of building it on-demand.
4
+
5
+ Usage:
6
+ python scripts/precompute_network.py [--output-dir precomputed_data] [--version v1]
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import pickle
12
+ import argparse
13
+ import logging
14
+ import time
15
+ from pathlib import Path
16
+ from typing import Optional
17
+
18
+ # Add backend to path
19
+ sys.path.insert(0, str(Path(__file__).parent.parent))
20
+
21
+ import pandas as pd
22
+ from utils.network_analysis import ModelNetworkBuilder
23
+ from utils.precomputed_loader import PrecomputedDataLoader
24
+ from utils.data_loader import ModelDataLoader
25
+
26
+ logging.basicConfig(
27
+ level=logging.INFO,
28
+ format='%(asctime)s - %(levelname)s - %(message)s'
29
+ )
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ def precompute_network(
34
+ output_dir: str = "precomputed_data",
35
+ version: str = "v1",
36
+ include_edge_attributes: bool = False,
37
+ min_downloads: int = 0,
38
+ max_nodes: Optional[int] = None,
39
+ load_from_hf: bool = False,
40
+ sample_size: Optional[int] = None
41
+ ):
42
+ """
43
+ Pre-compute the full derivative network graph for the force-directed visualization.
44
+
45
+ Args:
46
+ output_dir: Directory to save the network file
47
+ version: Version tag for the data
48
+ include_edge_attributes: Whether to calculate edge attributes
49
+ min_downloads: Minimum downloads to include a model
50
+ max_nodes: Maximum number of nodes (top N by downloads)
51
+ load_from_hf: If True, load directly from HF dataset (includes parent relationships)
52
+ sample_size: If load_from_hf=True, sample this many models (None = all models)
53
+ """
54
+ start_time = time.time()
55
+
56
+ # Create output directory
57
+ output_path = Path(output_dir)
58
+ output_path.mkdir(parents=True, exist_ok=True)
59
+
60
+ logger.info("=" * 60)
61
+ logger.info("PRE-COMPUTING FULL DERIVATIVE NETWORK")
62
+ logger.info("=" * 60)
63
+
64
+ # Step 1: Load model data
65
+ logger.info("Step 1/3: Loading model data...")
66
+
67
+ if load_from_hf:
68
+ # Load directly from HF dataset (includes parent relationships)
69
+ logger.info(f"Loading directly from Hugging Face dataset (sample_size={sample_size if sample_size else 'ALL'})...")
70
+ data_loader = ModelDataLoader()
71
+ df = data_loader.load_data(sample_size=sample_size, prioritize_base_models=False)
72
+ df = data_loader.preprocess_for_embedding(df)
73
+
74
+ # Ensure model_id is set as index
75
+ if 'model_id' in df.columns:
76
+ df.set_index('model_id', drop=False, inplace=True)
77
+
78
+ # Ensure numeric columns
79
+ for col in ['downloads', 'likes']:
80
+ if col in df.columns:
81
+ df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0).astype(int)
82
+
83
+ logger.info(f"Loaded {len(df):,} models from HF dataset")
84
+ logger.info(f"Columns: {list(df.columns)}")
85
+
86
+ # Check if parent_model column exists (needed for network edges)
87
+ if 'parent_model' not in df.columns:
88
+ logger.warning("'parent_model' column not found - network will have 0 edges!")
89
+ else:
90
+ parent_count = df['parent_model'].notna().sum()
91
+ logger.info(f"Models with parent relationships: {parent_count:,}")
92
+ else:
93
+ # Load from pre-computed files
94
+ loader = PrecomputedDataLoader(data_dir=output_dir, version=version)
95
+
96
+ if not loader.check_available():
97
+ logger.error(f"Pre-computed data not found in {output_dir}")
98
+ logger.info("Please run precompute_data.py first, download from HF Hub, or use --load-from-hf flag")
99
+ return False
100
+
101
+ try:
102
+ df, embeddings, metadata = loader.load_all()
103
+ logger.info(f"Loaded {len(df):,} models from pre-computed data")
104
+
105
+ # Check if parent_model column exists
106
+ if 'parent_model' not in df.columns:
107
+ logger.warning("'parent_model' column not found in pre-computed data - network will have 0 edges!")
108
+ except Exception as e:
109
+ logger.error(f"Failed to load data: {e}")
110
+ return False
111
+
112
+ # Step 2: Filter data if needed
113
+ if min_downloads > 0:
114
+ df = df[df.get('downloads', 0) >= min_downloads]
115
+ logger.info(f"Filtered to {len(df):,} models with >= {min_downloads} downloads")
116
+
117
+ if max_nodes and len(df) > max_nodes:
118
+ df = df.nlargest(max_nodes, 'downloads', keep='first')
119
+ logger.info(f"Limited to top {max_nodes:,} models by downloads")
120
+
121
+ # Step 3: Build network graph
122
+ logger.info("Step 2/3: Building network graph (this may take 10-30 minutes)...")
123
+ try:
124
+ network_builder = ModelNetworkBuilder(df)
125
+ graph = network_builder.build_full_derivative_network(
126
+ include_edge_attributes=include_edge_attributes,
127
+ filter_edge_types=None # Include all edge types
128
+ )
129
+
130
+ logger.info(f"Graph built: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges")
131
+ except Exception as e:
132
+ logger.error(f"Failed to build network graph: {e}", exc_info=True)
133
+ return False
134
+
135
+ # Step 4: Save network graph
136
+ logger.info("Step 3/3: Saving network graph to disk...")
137
+ network_file = output_path / "full_derivative_network.pkl"
138
+
139
+ try:
140
+ with open(network_file, 'wb') as f:
141
+ pickle.dump(graph, f, protocol=pickle.HIGHEST_PROTOCOL)
142
+
143
+ file_size_mb = network_file.stat().st_size / (1024 * 1024)
144
+ logger.info(f"Saved network graph to {network_file}")
145
+ logger.info(f"File size: {file_size_mb:.2f} MB")
146
+ except Exception as e:
147
+ logger.error(f"Failed to save network graph: {e}", exc_info=True)
148
+ return False
149
+
150
+ # Save metadata
151
+ metadata_file = output_path / "network_metadata.json"
152
+ import json
153
+ from datetime import datetime
154
+
155
+ network_metadata = {
156
+ "created_at": datetime.now().isoformat(),
157
+ "version": version,
158
+ "nodes": graph.number_of_nodes(),
159
+ "edges": graph.number_of_edges(),
160
+ "include_edge_attributes": include_edge_attributes,
161
+ "min_downloads": min_downloads,
162
+ "max_nodes": max_nodes,
163
+ "file_size_mb": round(file_size_mb, 2)
164
+ }
165
+
166
+ with open(metadata_file, 'w') as f:
167
+ json.dump(network_metadata, f, indent=2)
168
+
169
+ total_time = time.time() - start_time
170
+ logger.info("=" * 60)
171
+ logger.info(f"PRE-COMPUTATION COMPLETE in {total_time:.2f} seconds")
172
+ logger.info(f"Network graph: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges")
173
+ logger.info(f"Saved to: {network_file}")
174
+ logger.info("=" * 60)
175
+
176
+ return True
177
+
178
+
179
+ if __name__ == "__main__":
180
+ import time
181
+
182
+ parser = argparse.ArgumentParser(description="Pre-compute full derivative network graph")
183
+ parser.add_argument("--output-dir", type=str, default="precomputed_data",
184
+ help="Output directory for pre-computed files")
185
+ parser.add_argument("--version", type=str, default="v1",
186
+ help="Version tag for the data")
187
+ parser.add_argument("--include-edge-attributes", action="store_true",
188
+ help="Include edge attributes (slower but more detailed)")
189
+ parser.add_argument("--min-downloads", type=int, default=0,
190
+ help="Minimum downloads to include a model")
191
+ parser.add_argument("--max-nodes", type=int, default=None,
192
+ help="Maximum number of nodes (top N by downloads)")
193
+ parser.add_argument("--load-from-hf", action="store_true",
194
+ help="Load directly from HF dataset instead of pre-computed files (includes parent relationships)")
195
+ parser.add_argument("--sample-size", type=int, default=None,
196
+ help="If --load-from-hf, sample this many models (default: all models, use 0 for all)")
197
+
198
+ args = parser.parse_args()
199
+
200
+ success = precompute_network(
201
+ output_dir=args.output_dir,
202
+ version=args.version,
203
+ include_edge_attributes=args.include_edge_attributes,
204
+ min_downloads=args.min_downloads,
205
+ max_nodes=args.max_nodes,
206
+ load_from_hf=args.load_from_hf,
207
+ sample_size=None if args.sample_size == 0 else args.sample_size
208
+ )
209
+
210
+ sys.exit(0 if success else 1)
211
+
backend/scripts/upload_network_to_hf.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Upload pre-computed network graph to Hugging Face Hub dataset.
3
+
4
+ Usage:
5
+ python scripts/upload_network_to_hf.py [--network-file precomputed_data/full_derivative_network.pkl] [--version v1]
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import argparse
11
+ import logging
12
+ from pathlib import Path
13
+
14
+ # Add backend to path
15
+ sys.path.insert(0, str(Path(__file__).parent.parent))
16
+
17
+ from utils.precomputed_loader import HF_PRECOMPUTED_DATASET
18
+
19
+ logging.basicConfig(
20
+ level=logging.INFO,
21
+ format='%(asctime)s - %(levelname)s - %(message)s'
22
+ )
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ def upload_network_to_hf(
27
+ network_file: str,
28
+ version: str = "v1",
29
+ dataset_id: str = None
30
+ ):
31
+ """
32
+ Upload pre-computed network graph to Hugging Face Hub.
33
+
34
+ Args:
35
+ network_file: Path to the network pickle file
36
+ version: Version tag for the data
37
+ dataset_id: HF dataset ID (defaults to HF_PRECOMPUTED_DATASET)
38
+ """
39
+ try:
40
+ from huggingface_hub import HfApi, upload_file
41
+
42
+ if dataset_id is None:
43
+ dataset_id = HF_PRECOMPUTED_DATASET
44
+
45
+ network_path = Path(network_file)
46
+ if not network_path.exists():
47
+ logger.error(f"Network file not found: {network_file}")
48
+ return False
49
+
50
+ logger.info(f"Uploading network graph to {dataset_id}...")
51
+ logger.info(f"File: {network_file}")
52
+ logger.info(f"Version: {version}")
53
+
54
+ api = HfApi()
55
+
56
+ # Check if repository exists, create if it doesn't
57
+ try:
58
+ api.dataset_info(dataset_id)
59
+ logger.info(f"Repository {dataset_id} exists")
60
+ except Exception:
61
+ logger.info(f"Repository {dataset_id} not found. Creating it...")
62
+ try:
63
+ api.create_repo(
64
+ repo_id=dataset_id,
65
+ repo_type="dataset",
66
+ exist_ok=True
67
+ )
68
+ logger.info(f"Created repository {dataset_id}")
69
+ except Exception as create_error:
70
+ logger.error(f"Could not create repository: {create_error}")
71
+ logger.info("You may need to create it manually at https://huggingface.co/new-dataset")
72
+ return False
73
+
74
+ # Upload network file
75
+ filename = f"full_derivative_network_{version}.pkl"
76
+ upload_file(
77
+ path_or_fileobj=str(network_path),
78
+ path_in_repo=filename,
79
+ repo_id=dataset_id,
80
+ repo_type="dataset",
81
+ commit_message=f"Upload pre-computed network graph (version {version})"
82
+ )
83
+
84
+ logger.info(f"Successfully uploaded {filename} to {dataset_id}")
85
+
86
+ # Try to upload metadata if it exists
87
+ metadata_file = network_path.parent / "network_metadata.json"
88
+ if metadata_file.exists():
89
+ try:
90
+ upload_file(
91
+ path_or_fileobj=str(metadata_file),
92
+ path_in_repo=f"network_metadata_{version}.json",
93
+ repo_id=dataset_id,
94
+ repo_type="dataset",
95
+ commit_message=f"Upload network metadata (version {version})"
96
+ )
97
+ logger.info("Successfully uploaded network metadata")
98
+ except Exception as e:
99
+ logger.warning(f"Could not upload metadata: {e}")
100
+
101
+ return True
102
+
103
+ except ImportError:
104
+ logger.error("huggingface_hub not installed. Install it with: pip install huggingface_hub")
105
+ return False
106
+ except Exception as e:
107
+ logger.error(f"Error uploading network to HF Hub: {e}", exc_info=True)
108
+ return False
109
+
110
+
111
+ if __name__ == "__main__":
112
+ parser = argparse.ArgumentParser(description="Upload pre-computed network graph to HF Hub")
113
+ parser.add_argument("--network-file", type=str,
114
+ default="precomputed_data/full_derivative_network.pkl",
115
+ help="Path to network pickle file")
116
+ parser.add_argument("--version", type=str, default="v1",
117
+ help="Version tag for the data")
118
+ parser.add_argument("--dataset-id", type=str, default=None,
119
+ help="HF dataset ID (defaults to HF_PRECOMPUTED_DATASET)")
120
+
121
+ args = parser.parse_args()
122
+
123
+ success = upload_network_to_hf(
124
+ network_file=args.network_file,
125
+ version=args.version,
126
+ dataset_id=args.dataset_id
127
+ )
128
+
129
+ sys.exit(0 if success else 1)
130
+
backend/utils/chunked_loader.py CHANGED
@@ -216,3 +216,4 @@ def create_chunk_index(
216
 
217
  return chunk_index
218
 
 
 
216
 
217
  return chunk_index
218
 
219
+
backend/utils/precomputed_loader.py CHANGED
@@ -181,6 +181,79 @@ class PrecomputedDataLoader:
181
  return df, embeddings, metadata
182
 
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  def download_from_hf_hub(data_dir: str, version: str = "v1") -> bool:
185
  """
186
  Download precomputed data from HuggingFace Hub.
@@ -283,6 +356,18 @@ def download_from_hf_hub(data_dir: str, version: str = "v1") -> bool:
283
  except Exception:
284
  logger.info("Single embeddings file not available either")
285
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  return True
287
 
288
  except ImportError:
 
181
  return df, embeddings, metadata
182
 
183
 
184
+ def download_network_from_hf_hub(data_dir: str, version: str = "v1") -> bool:
185
+ """
186
+ Download pre-computed network graph from Hugging Face Hub.
187
+
188
+ Args:
189
+ data_dir: Directory to save the network file
190
+ version: Version tag for the data
191
+
192
+ Returns:
193
+ True if download successful, False otherwise
194
+ """
195
+ try:
196
+ from huggingface_hub import hf_hub_download
197
+ import os
198
+ from pathlib import Path
199
+
200
+ data_path = Path(data_dir)
201
+ data_path.mkdir(parents=True, exist_ok=True)
202
+
203
+ network_file = data_path / "full_derivative_network.pkl"
204
+ metadata_file = data_path / "network_metadata.json"
205
+
206
+ # Skip if already exists
207
+ if network_file.exists():
208
+ logger.info(f"Network file already exists: {network_file}")
209
+ return True
210
+
211
+ logger.info(f"Downloading pre-computed network graph from HF Hub...")
212
+ logger.info(f"Dataset: {HF_PRECOMPUTED_DATASET}, Version: {version}")
213
+
214
+ try:
215
+ # Try to download network file
216
+ downloaded_path = hf_hub_download(
217
+ repo_id=HF_PRECOMPUTED_DATASET,
218
+ filename=f"full_derivative_network_{version}.pkl",
219
+ repo_type="dataset",
220
+ local_dir=str(data_path),
221
+ local_dir_use_symlinks=False
222
+ )
223
+
224
+ # Rename to standard name if needed
225
+ if downloaded_path != str(network_file):
226
+ import shutil
227
+ shutil.move(downloaded_path, str(network_file))
228
+
229
+ logger.info(f"Successfully downloaded network graph: {network_file}")
230
+
231
+ # Try to download metadata
232
+ try:
233
+ hf_hub_download(
234
+ repo_id=HF_PRECOMPUTED_DATASET,
235
+ filename=f"network_metadata_{version}.json",
236
+ repo_type="dataset",
237
+ local_dir=str(data_path),
238
+ local_dir_use_symlinks=False
239
+ )
240
+ except Exception as e:
241
+ logger.warning(f"Could not download network metadata: {e}")
242
+
243
+ return True
244
+
245
+ except Exception as e:
246
+ logger.warning(f"Network file not found in HF Hub (this is optional): {e}")
247
+ return False
248
+
249
+ except ImportError:
250
+ logger.warning("huggingface_hub not available. Cannot download network from HF Hub.")
251
+ return False
252
+ except Exception as e:
253
+ logger.error(f"Error downloading network from HF Hub: {e}")
254
+ return False
255
+
256
+
257
  def download_from_hf_hub(data_dir: str, version: str = "v1") -> bool:
258
  """
259
  Download precomputed data from HuggingFace Hub.
 
356
  except Exception:
357
  logger.info("Single embeddings file not available either")
358
 
359
+ # Try to download pre-computed network graph (optional)
360
+ try:
361
+ network_path = hf_hub_download(
362
+ repo_id=dataset_id,
363
+ filename=f"full_derivative_network_{version}.pkl",
364
+ repo_type="dataset",
365
+ local_dir=data_dir
366
+ )
367
+ logger.info(f"Downloaded pre-computed network graph to {network_path}")
368
+ except Exception as e:
369
+ logger.info(f"Pre-computed network graph not available (optional): {e}")
370
+
371
  return True
372
 
373
  except ImportError:
check_and_deploy.sh CHANGED
@@ -41,3 +41,4 @@ else
41
  fi
42
  fi
43
 
 
 
41
  fi
42
  fi
43
 
44
+
frontend/src/App.css CHANGED
@@ -223,6 +223,9 @@
223
  padding: 1.5rem;
224
  border-top: 1px solid var(--border-light);
225
  margin-top: auto;
 
 
 
226
  }
227
 
228
  .nav-links {
@@ -244,6 +247,14 @@
244
  text-decoration: underline;
245
  }
246
 
 
 
 
 
 
 
 
 
247
  /* ============================================
248
  APP MAIN CONTENT
249
  ============================================ */
@@ -928,6 +939,57 @@
928
  flex-shrink: 0;
929
  }
930
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
931
  .semantic-search-toggle {
932
  padding: 0.5rem 0.75rem;
933
  background: var(--bg-secondary);
@@ -2032,24 +2094,32 @@
2032
  }
2033
 
2034
  .theme-toggle {
 
 
 
 
2035
  background: var(--bg-secondary);
2036
  border: 1px solid var(--border-medium);
2037
- border-radius: 0;
2038
- padding: 0.5rem;
2039
  cursor: pointer;
2040
- font-size: 1.2rem;
2041
- display: flex;
2042
- align-items: center;
2043
- justify-content: center;
2044
- min-width: 40px;
2045
- height: 40px;
2046
  transition: all var(--transition-base);
 
 
 
2047
  }
2048
 
2049
- .theme-toggle:hover {
2050
- background: var(--bg-tertiary);
2051
- border-color: var(--accent-blue);
2052
- }
 
 
 
 
 
2053
 
2054
  /* ============================================
2055
  ZOOM CONTROLS
 
223
  padding: 1.5rem;
224
  border-top: 1px solid var(--border-light);
225
  margin-top: auto;
226
+ display: flex;
227
+ flex-direction: column;
228
+ gap: 1rem;
229
  }
230
 
231
  .nav-links {
 
247
  text-decoration: underline;
248
  }
249
 
250
+ .nav-theme-toggle {
251
+ display: flex;
252
+ align-items: center;
253
+ justify-content: flex-start;
254
+ padding-top: 0.5rem;
255
+ border-top: 1px solid var(--border-light);
256
+ }
257
+
258
  /* ============================================
259
  APP MAIN CONTENT
260
  ============================================ */
 
939
  flex-shrink: 0;
940
  }
941
 
942
+ /* Graph Search Input - for force graph mode */
943
+ .graph-search {
944
+ display: flex;
945
+ align-items: center;
946
+ gap: 0.4rem;
947
+ position: relative;
948
+ }
949
+
950
+ .graph-search-input {
951
+ padding: 0.4rem 0.6rem;
952
+ border: 1px solid var(--border-medium);
953
+ background: var(--bg-primary);
954
+ color: var(--text-primary);
955
+ font-size: 0.85rem;
956
+ border-radius: 4px;
957
+ width: 180px;
958
+ transition: all var(--transition-base);
959
+ outline: none;
960
+ }
961
+
962
+ .graph-search-input:hover {
963
+ border-color: var(--border-dark);
964
+ }
965
+
966
+ .graph-search-input:focus {
967
+ border-color: var(--accent-blue);
968
+ box-shadow: 0 0 0 2px rgba(74, 144, 226, 0.1);
969
+ width: 220px;
970
+ }
971
+
972
+ .graph-search-input::placeholder {
973
+ color: var(--text-tertiary);
974
+ }
975
+
976
+ .graph-search-clear {
977
+ position: absolute;
978
+ right: 0.4rem;
979
+ background: none;
980
+ border: none;
981
+ color: var(--text-tertiary);
982
+ cursor: pointer;
983
+ padding: 0.2rem;
984
+ font-size: 1rem;
985
+ line-height: 1;
986
+ transition: color var(--transition-fast);
987
+ }
988
+
989
+ .graph-search-clear:hover {
990
+ color: var(--text-primary);
991
+ }
992
+
993
  .semantic-search-toggle {
994
  padding: 0.5rem 0.75rem;
995
  background: var(--bg-secondary);
 
2094
  }
2095
 
2096
  .theme-toggle {
2097
+ display: flex;
2098
+ align-items: center;
2099
+ gap: 0.5rem;
2100
+ padding: 0.5rem 0.75rem;
2101
  background: var(--bg-secondary);
2102
  border: 1px solid var(--border-medium);
2103
+ border-radius: 4px;
 
2104
  cursor: pointer;
2105
+ font-size: 0.875rem;
2106
+ font-weight: 500;
2107
+ color: var(--text-primary);
 
 
 
2108
  transition: all var(--transition-base);
2109
+ font-family: var(--font-primary);
2110
+ width: 100%;
2111
+ justify-content: flex-start;
2112
  }
2113
 
2114
+ .theme-toggle:hover {
2115
+ background: var(--bg-tertiary);
2116
+ border-color: var(--accent-blue);
2117
+ color: var(--accent-blue);
2118
+ }
2119
+
2120
+ .theme-toggle-label {
2121
+ font-size: 0.875rem;
2122
+ }
2123
 
2124
  /* ============================================
2125
  ZOOM CONTROLS
frontend/src/App.tsx CHANGED
@@ -1,5 +1,6 @@
1
  import React, { useState, useEffect, useCallback, useRef, useMemo } from 'react';
2
- import { ChevronLeft, ChevronRight, Palette, Maximize2, Eye, GitBranch } from 'lucide-react';
 
3
  import IntroModal from './components/ui/IntroModal';
4
  import ScatterPlot3D from './components/visualizations/ScatterPlot3D';
5
  import NetworkGraph from './components/visualizations/NetworkGraph';
@@ -111,6 +112,19 @@ function App() {
111
  const [nodeSizeMultiplier, setNodeSizeMultiplier] = useState(1.0);
112
  const [edgeOpacity, setEdgeOpacity] = useState(0.6);
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  // Threshold for using instanced rendering
115
  const INSTANCED_THRESHOLD = 10000;
116
 
@@ -481,14 +495,42 @@ function App() {
481
  setGraphLoading(true);
482
  setGraphError(null);
483
  try {
 
 
484
  const data = await fetchFullDerivativeNetwork({
485
  edgeTypes: undefined,
486
  includeEdgeAttributes: false,
 
 
 
487
  });
488
  setGraphNodes(data.nodes || []);
489
  setGraphLinks(data.links || []);
490
  setGraphStats(data.statistics ? { nodes: data.statistics.nodes, edges: data.statistics.edges } : null);
491
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
492
  if (data.links && data.links.length > 0) {
493
  const availableTypes = getAvailableEdgeTypes(data.links);
494
  if (availableTypes.size > 0) {
@@ -498,7 +540,21 @@ function App() {
498
  }
499
  graphLoadedRef.current = true;
500
  } catch (err) {
501
- setGraphError(err instanceof Error ? err.message : 'Failed to load graph');
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
  } finally {
503
  setGraphLoading(false);
504
  }
@@ -664,6 +720,9 @@ function App() {
664
  <a href="https://github.com/bendlaufer/ai-ecosystem" target="_blank" rel="noopener noreferrer" title="View source code on GitHub">GitHub</a>
665
  <a href="https://huggingface.co/modelbiome" target="_blank" rel="noopener noreferrer" title="Access the dataset on Hugging Face">Dataset</a>
666
  </div>
 
 
 
667
  </div>
668
  )}
669
  </aside>
@@ -795,6 +854,76 @@ function App() {
795
  {/* Force graph controls - only show for force-graph mode */}
796
  {vizMode === 'force-graph' && !showAnalytics && !showFamilies && !showGraph && (
797
  <>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
798
  {/* Edge type filter */}
799
  {availableEdgeTypes.length > 0 && (
800
  <>
@@ -854,39 +983,83 @@ function App() {
854
 
855
  {/* Right: Integrated Search */}
856
  <div className="control-bar-right">
857
- <IntegratedSearch
858
- value={localSearchQuery}
859
- onChange={(value) => {
860
- setLocalSearchQuery(value);
861
- debouncedSetSearchQuery(value);
862
- }}
863
- onSelect={(result) => {
864
- const modelPoint: ModelPoint = {
865
- model_id: result.model_id,
866
- x: result.x || 0,
867
- y: result.y || 0,
868
- z: result.z || 0,
869
- downloads: result.downloads || 0,
870
- likes: result.likes || 0,
871
- trending_score: null,
872
- tags: null,
873
- licenses: null,
874
- cluster_id: null,
875
- created_at: null,
876
- library_name: result.library_name || null,
877
- pipeline_tag: result.pipeline_tag || null,
878
- parent_model: null,
879
- family_depth: result.family_depth || null,
880
- };
881
- setSelectedModel(modelPoint);
882
- setIsModalOpen(true);
883
- setLocalSearchQuery('');
884
- setSearchQuery('');
885
- }}
886
- onZoomTo={(x, y, z) => {
887
- // Zoom to point - reserved for future implementation
888
- }}
889
- />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
890
  </div>
891
  </div>
892
  </div>
@@ -965,6 +1138,12 @@ function App() {
965
  collisionRadius={collisionRadius}
966
  nodeSizeMultiplier={nodeSizeMultiplier}
967
  edgeOpacity={edgeOpacity}
 
 
 
 
 
 
968
  />
969
  ) : (
970
  <ForceDirectedGraph3D
@@ -981,6 +1160,12 @@ function App() {
981
  collisionRadius={collisionRadius}
982
  nodeSizeMultiplier={nodeSizeMultiplier}
983
  edgeOpacity={edgeOpacity}
 
 
 
 
 
 
984
  />
985
  )}
986
  </>
 
1
  import React, { useState, useEffect, useCallback, useRef, useMemo } from 'react';
2
+ import { ChevronLeft, ChevronRight, Palette, Maximize2, Eye, GitBranch, Filter, Search, Moon, Sun } from 'lucide-react';
3
+ import ThemeToggle from './components/controls/ThemeToggle';
4
  import IntroModal from './components/ui/IntroModal';
5
  import ScatterPlot3D from './components/visualizations/ScatterPlot3D';
6
  import NetworkGraph from './components/visualizations/NetworkGraph';
 
112
  const [nodeSizeMultiplier, setNodeSizeMultiplier] = useState(1.0);
113
  const [edgeOpacity, setEdgeOpacity] = useState(0.6);
114
 
115
+ // Force graph visual options (harmonized with embeddings view)
116
+ const [graphColorBy, setGraphColorBy] = useState<'library' | 'pipeline' | 'downloads' | 'likes' | 'edge_type'>('library');
117
+ const [graphSizeBy, setGraphSizeBy] = useState<'downloads' | 'likes' | 'uniform'>('downloads');
118
+ const [graphColorScheme, setGraphColorScheme] = useState<'viridis' | 'plasma' | 'inferno' | 'coolwarm'>('viridis');
119
+
120
+ // Force graph search and filtering
121
+ const [graphSearchQuery, setGraphSearchQuery] = useState('');
122
+ const [graphFamilyFilter, setGraphFamilyFilter] = useState('');
123
+ const [highlightedNodeId, setHighlightedNodeId] = useState<string | null>(null);
124
+
125
+ // Available families for the filter dropdown
126
+ const [availableFamilies, setAvailableFamilies] = useState<string[]>([]);
127
+
128
  // Threshold for using instanced rendering
129
  const INSTANCED_THRESHOLD = 10000;
130
 
 
495
  setGraphLoading(true);
496
  setGraphError(null);
497
  try {
498
+ // Use filtering to reduce network size for better performance
499
+ // Start with top 50k models by downloads to avoid 500 errors
500
  const data = await fetchFullDerivativeNetwork({
501
  edgeTypes: undefined,
502
  includeEdgeAttributes: false,
503
+ minDownloads: 0, // Can be increased to reduce size further
504
+ maxNodes: 50000, // Limit to top 50k models by downloads
505
+ usePrecomputed: true, // Try to use pre-computed graph if available
506
  });
507
  setGraphNodes(data.nodes || []);
508
  setGraphLinks(data.links || []);
509
  setGraphStats(data.statistics ? { nodes: data.statistics.nodes, edges: data.statistics.edges } : null);
510
 
511
+ // Extract unique families (organizations) from node IDs for filtering
512
+ if (data.nodes && data.nodes.length > 0) {
513
+ const familySet = new Set<string>();
514
+ data.nodes.forEach((node: GraphNode) => {
515
+ if (node.id && node.id.includes('/')) {
516
+ const family = node.id.split('/')[0];
517
+ if (family) familySet.add(family);
518
+ }
519
+ });
520
+ // Sort by frequency (most common families first)
521
+ const familyCount = new Map<string, number>();
522
+ data.nodes.forEach((node: GraphNode) => {
523
+ if (node.id && node.id.includes('/')) {
524
+ const family = node.id.split('/')[0];
525
+ familyCount.set(family, (familyCount.get(family) || 0) + 1);
526
+ }
527
+ });
528
+ const sortedFamilies = Array.from(familySet)
529
+ .sort((a, b) => (familyCount.get(b) || 0) - (familyCount.get(a) || 0))
530
+ .slice(0, 100); // Keep top 100 families
531
+ setAvailableFamilies(sortedFamilies);
532
+ }
533
+
534
  if (data.links && data.links.length > 0) {
535
  const availableTypes = getAvailableEdgeTypes(data.links);
536
  if (availableTypes.size > 0) {
 
540
  }
541
  graphLoadedRef.current = true;
542
  } catch (err) {
543
+ const errorMessage = err instanceof Error ? err.message : 'Failed to load graph';
544
+ // Check if it's a rate limit error
545
+ if (errorMessage.includes('Rate limit') || errorMessage.includes('429')) {
546
+ setGraphError(`${errorMessage} The graph will automatically retry.`);
547
+ // Reset the loaded flag so it can retry
548
+ graphLoadedRef.current = false;
549
+ // Retry after a delay
550
+ setTimeout(() => {
551
+ if (vizMode === 'force-graph' && !graphLoadedRef.current) {
552
+ loadForceGraph();
553
+ }
554
+ }, 5000);
555
+ } else {
556
+ setGraphError(errorMessage);
557
+ }
558
  } finally {
559
  setGraphLoading(false);
560
  }
 
720
  <a href="https://github.com/bendlaufer/ai-ecosystem" target="_blank" rel="noopener noreferrer" title="View source code on GitHub">GitHub</a>
721
  <a href="https://huggingface.co/modelbiome" target="_blank" rel="noopener noreferrer" title="Access the dataset on Hugging Face">Dataset</a>
722
  </div>
723
+ <div className="nav-theme-toggle">
724
+ <ThemeToggle />
725
+ </div>
726
  </div>
727
  )}
728
  </aside>
 
854
  {/* Force graph controls - only show for force-graph mode */}
855
  {vizMode === 'force-graph' && !showAnalytics && !showFamilies && !showGraph && (
856
  <>
857
+ {/* Color by - harmonized with embeddings view */}
858
+ <div className="control-group">
859
+ <Palette size={14} className="control-icon" />
860
+ <select
861
+ value={graphColorBy}
862
+ onChange={(e) => setGraphColorBy(e.target.value as typeof graphColorBy)}
863
+ className="control-select"
864
+ title="Color nodes by attribute"
865
+ >
866
+ <option value="library">ML Library</option>
867
+ <option value="pipeline">Task Type</option>
868
+ <option value="downloads">Downloads</option>
869
+ <option value="likes">Likes</option>
870
+ <option value="edge_type">Edge Type</option>
871
+ </select>
872
+ {(graphColorBy === 'downloads' || graphColorBy === 'likes') && (
873
+ <select
874
+ value={graphColorScheme}
875
+ onChange={(e) => setGraphColorScheme(e.target.value as typeof graphColorScheme)}
876
+ className="control-select control-select-small"
877
+ title="Color gradient style"
878
+ >
879
+ <option value="viridis">Viridis</option>
880
+ <option value="plasma">Plasma</option>
881
+ <option value="inferno">Inferno</option>
882
+ <option value="coolwarm">Cool-Warm</option>
883
+ </select>
884
+ )}
885
+ </div>
886
+
887
+ <span className="control-divider" />
888
+
889
+ {/* Size by */}
890
+ <div className="control-group">
891
+ <Maximize2 size={14} className="control-icon" />
892
+ <select
893
+ value={graphSizeBy}
894
+ onChange={(e) => setGraphSizeBy(e.target.value as typeof graphSizeBy)}
895
+ className="control-select"
896
+ title="Size nodes by attribute"
897
+ >
898
+ <option value="downloads">By Downloads</option>
899
+ <option value="likes">By Likes</option>
900
+ <option value="uniform">Uniform Size</option>
901
+ </select>
902
+ </div>
903
+
904
+ <span className="control-divider" />
905
+
906
+ {/* Family filter */}
907
+ <div className="control-group">
908
+ <Filter size={14} className="control-icon" />
909
+ <select
910
+ value={graphFamilyFilter}
911
+ onChange={(e) => {
912
+ setGraphFamilyFilter(e.target.value);
913
+ setHighlightedNodeId(null);
914
+ }}
915
+ className="control-select"
916
+ title="Filter by organization/family"
917
+ >
918
+ <option value="">All Families</option>
919
+ {availableFamilies.map(family => (
920
+ <option key={family} value={family}>{family}</option>
921
+ ))}
922
+ </select>
923
+ </div>
924
+
925
+ <span className="control-divider" />
926
+
927
  {/* Edge type filter */}
928
  {availableEdgeTypes.length > 0 && (
929
  <>
 
983
 
984
  {/* Right: Integrated Search */}
985
  <div className="control-bar-right">
986
+ {vizMode === 'force-graph' ? (
987
+ // Simple search input for force graph mode
988
+ <div className="control-group graph-search">
989
+ <Search size={14} className="control-icon" />
990
+ <input
991
+ type="text"
992
+ value={graphSearchQuery}
993
+ onChange={(e) => {
994
+ setGraphSearchQuery(e.target.value);
995
+ // Auto-highlight the first matching node
996
+ if (e.target.value.trim()) {
997
+ const query = e.target.value.toLowerCase();
998
+ const match = graphNodes.find(n =>
999
+ n.id.toLowerCase().includes(query) ||
1000
+ n.title?.toLowerCase().includes(query)
1001
+ );
1002
+ if (match) {
1003
+ setHighlightedNodeId(match.id);
1004
+ } else {
1005
+ setHighlightedNodeId(null);
1006
+ }
1007
+ } else {
1008
+ setHighlightedNodeId(null);
1009
+ }
1010
+ }}
1011
+ placeholder="Search models..."
1012
+ className="graph-search-input"
1013
+ title="Type to filter and highlight models in the graph"
1014
+ />
1015
+ {graphSearchQuery && (
1016
+ <button
1017
+ className="graph-search-clear"
1018
+ onClick={() => {
1019
+ setGraphSearchQuery('');
1020
+ setHighlightedNodeId(null);
1021
+ }}
1022
+ title="Clear search"
1023
+ >
1024
+ ×
1025
+ </button>
1026
+ )}
1027
+ </div>
1028
+ ) : (
1029
+ <IntegratedSearch
1030
+ value={localSearchQuery}
1031
+ onChange={(value) => {
1032
+ setLocalSearchQuery(value);
1033
+ debouncedSetSearchQuery(value);
1034
+ }}
1035
+ onSelect={(result) => {
1036
+ const modelPoint: ModelPoint = {
1037
+ model_id: result.model_id,
1038
+ x: result.x || 0,
1039
+ y: result.y || 0,
1040
+ z: result.z || 0,
1041
+ downloads: result.downloads || 0,
1042
+ likes: result.likes || 0,
1043
+ trending_score: null,
1044
+ tags: null,
1045
+ licenses: null,
1046
+ cluster_id: null,
1047
+ created_at: null,
1048
+ library_name: result.library_name || null,
1049
+ pipeline_tag: result.pipeline_tag || null,
1050
+ parent_model: null,
1051
+ family_depth: result.family_depth || null,
1052
+ };
1053
+ setSelectedModel(modelPoint);
1054
+ setIsModalOpen(true);
1055
+ setLocalSearchQuery('');
1056
+ setSearchQuery('');
1057
+ }}
1058
+ onZoomTo={(x, y, z) => {
1059
+ // Zoom to point - reserved for future implementation
1060
+ }}
1061
+ />
1062
+ )}
1063
  </div>
1064
  </div>
1065
  </div>
 
1138
  collisionRadius={collisionRadius}
1139
  nodeSizeMultiplier={nodeSizeMultiplier}
1140
  edgeOpacity={edgeOpacity}
1141
+ colorBy={graphColorBy}
1142
+ sizeBy={graphSizeBy}
1143
+ colorScheme={graphColorScheme}
1144
+ highlightedNodeId={highlightedNodeId}
1145
+ familyFilter={graphFamilyFilter}
1146
+ searchQuery={graphSearchQuery}
1147
  />
1148
  ) : (
1149
  <ForceDirectedGraph3D
 
1160
  collisionRadius={collisionRadius}
1161
  nodeSizeMultiplier={nodeSizeMultiplier}
1162
  edgeOpacity={edgeOpacity}
1163
+ colorBy={graphColorBy}
1164
+ sizeBy={graphSizeBy}
1165
+ colorScheme={graphColorScheme}
1166
+ highlightedNodeId={highlightedNodeId}
1167
+ familyFilter={graphFamilyFilter}
1168
+ searchQuery={graphSearchQuery}
1169
  />
1170
  )}
1171
  </>
frontend/src/components/controls/EdgeTypeFilter.css CHANGED
@@ -1,19 +1,29 @@
1
  .edge-type-filter {
2
  padding: 12px;
3
- background: rgba(255, 255, 255, 0.05);
 
4
  border-radius: 8px;
5
  margin-bottom: 12px;
6
  }
7
 
 
 
 
 
 
8
  .edge-type-filter h4 {
9
  margin: 0 0 8px 0;
10
  font-size: 12px;
11
  font-weight: 600;
12
- color: var(--text-secondary, #9ca3af);
13
  text-transform: uppercase;
14
  letter-spacing: 0.5px;
15
  }
16
 
 
 
 
 
17
  .edge-type-item {
18
  display: flex;
19
  align-items: center;
@@ -21,11 +31,15 @@
21
  margin-bottom: 4px;
22
  border-radius: 4px;
23
  cursor: pointer;
24
- transition: all 0.2s;
25
  user-select: none;
26
  }
27
 
28
  .edge-type-item:hover {
 
 
 
 
29
  background: rgba(255, 255, 255, 0.05);
30
  }
31
 
@@ -43,6 +57,10 @@
43
 
44
  .edge-type-label {
45
  font-size: 13px;
 
 
 
 
46
  color: var(--text-primary, #ffffff);
47
  }
48
 
@@ -62,24 +80,45 @@
62
  }
63
 
64
  .edge-type-toggle {
65
- padding: 4px 10px;
66
- border: 1px solid;
 
 
 
67
  border-radius: 4px;
68
- font-size: 11px;
69
  font-weight: 500;
70
  cursor: pointer;
71
- transition: all 0.2s;
72
- color: var(--text-primary, #ffffff);
73
- background: transparent;
74
  white-space: nowrap;
 
75
  }
76
 
77
- .edge-type-toggle:hover {
78
- opacity: 0.8;
 
 
79
  }
80
 
81
  .edge-type-toggle.active {
82
  opacity: 1;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  }
84
 
85
  .edge-type-toggle-label {
 
1
  .edge-type-filter {
2
  padding: 12px;
3
+ background: var(--bg-secondary, #fafafa);
4
+ border: 1px solid var(--border-light, #e8e8e8);
5
  border-radius: 8px;
6
  margin-bottom: 12px;
7
  }
8
 
9
+ [data-theme="dark"] .edge-type-filter {
10
+ background: rgba(255, 255, 255, 0.05);
11
+ border-color: rgba(255, 255, 255, 0.1);
12
+ }
13
+
14
  .edge-type-filter h4 {
15
  margin: 0 0 8px 0;
16
  font-size: 12px;
17
  font-weight: 600;
18
+ color: var(--text-secondary, #666666);
19
  text-transform: uppercase;
20
  letter-spacing: 0.5px;
21
  }
22
 
23
+ [data-theme="dark"] .edge-type-filter h4 {
24
+ color: var(--text-secondary, #9ca3af);
25
+ }
26
+
27
  .edge-type-item {
28
  display: flex;
29
  align-items: center;
 
31
  margin-bottom: 4px;
32
  border-radius: 4px;
33
  cursor: pointer;
34
+ transition: all var(--transition-base, 200ms);
35
  user-select: none;
36
  }
37
 
38
  .edge-type-item:hover {
39
+ background: var(--bg-tertiary, #f5f5f5);
40
+ }
41
+
42
+ [data-theme="dark"] .edge-type-item:hover {
43
  background: rgba(255, 255, 255, 0.05);
44
  }
45
 
 
57
 
58
  .edge-type-label {
59
  font-size: 13px;
60
+ color: var(--text-primary, #1a1a1a);
61
+ }
62
+
63
+ [data-theme="dark"] .edge-type-label {
64
  color: var(--text-primary, #ffffff);
65
  }
66
 
 
80
  }
81
 
82
  .edge-type-toggle {
83
+ display: flex;
84
+ align-items: center;
85
+ gap: 0.35rem;
86
+ padding: 0.35rem 0.65rem;
87
+ border: 1px solid var(--border-medium, #d0d0d0);
88
  border-radius: 4px;
89
+ font-size: 0.8rem;
90
  font-weight: 500;
91
  cursor: pointer;
92
+ transition: all var(--transition-base, 200ms);
93
+ color: var(--text-secondary, #666666);
94
+ background: var(--bg-primary, #ffffff);
95
  white-space: nowrap;
96
+ font-family: var(--font-primary, -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif);
97
  }
98
 
99
+ .edge-type-toggle:hover:not(.active) {
100
+ background: var(--bg-secondary, #fafafa);
101
+ color: var(--text-primary, #1a1a1a);
102
+ border-color: var(--border-dark, #b0b0b0);
103
  }
104
 
105
  .edge-type-toggle.active {
106
  opacity: 1;
107
+ color: white;
108
+ border-color: transparent;
109
+ font-weight: 600;
110
+ }
111
+
112
+ [data-theme="dark"] .edge-type-toggle {
113
+ color: var(--text-secondary, #cccccc);
114
+ background: var(--bg-primary, #1a1a1a);
115
+ border-color: var(--border-medium, #3a3a3a);
116
+ }
117
+
118
+ [data-theme="dark"] .edge-type-toggle:hover:not(.active) {
119
+ background: var(--bg-secondary, #2d2d2d);
120
+ color: var(--text-primary, #ffffff);
121
+ border-color: var(--border-dark, #555555);
122
  }
123
 
124
  .edge-type-toggle-label {
frontend/src/components/controls/EdgeTypeFilter.tsx CHANGED
@@ -72,3 +72,4 @@ export default function EdgeTypeFilter({
72
  );
73
  }
74
 
 
 
72
  );
73
  }
74
 
75
+
frontend/src/components/controls/ForceParameterControls.css CHANGED
@@ -5,18 +5,32 @@
5
  .force-parameter-toggle {
6
  display: flex;
7
  align-items: center;
8
- gap: 6px;
9
- padding: 6px 12px;
10
- background: rgba(255, 255, 255, 0.05);
11
- border: 1px solid rgba(255, 255, 255, 0.1);
12
- border-radius: 6px;
13
- color: var(--text-primary, #ffffff);
14
- font-size: 12px;
 
15
  cursor: pointer;
16
- transition: all 0.2s;
 
17
  }
18
 
19
  .force-parameter-toggle:hover {
 
 
 
 
 
 
 
 
 
 
 
 
20
  background: rgba(255, 255, 255, 0.1);
21
  border-color: rgba(255, 255, 255, 0.2);
22
  }
@@ -27,14 +41,20 @@
27
  left: 0;
28
  margin-top: 8px;
29
  padding: 16px;
30
- background: var(--bg-secondary, #1f2937);
31
- border: 1px solid rgba(255, 255, 255, 0.1);
32
  border-radius: 8px;
33
- box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
34
  min-width: 240px;
35
  z-index: 1000;
36
  }
37
 
 
 
 
 
 
 
38
  .force-parameter-group {
39
  margin-bottom: 16px;
40
  }
@@ -47,19 +67,27 @@
47
  display: block;
48
  font-size: 12px;
49
  font-weight: 500;
50
- color: var(--text-primary, #ffffff);
51
  margin-bottom: 8px;
52
  }
53
 
 
 
 
 
54
  .force-parameter-group input[type="range"] {
55
  width: 100%;
56
  height: 4px;
57
  border-radius: 2px;
58
- background: rgba(255, 255, 255, 0.1);
59
  outline: none;
60
  -webkit-appearance: none;
61
  }
62
 
 
 
 
 
63
  .force-parameter-group input[type="range"]::-webkit-slider-thumb {
64
  -webkit-appearance: none;
65
  appearance: none;
 
5
  .force-parameter-toggle {
6
  display: flex;
7
  align-items: center;
8
+ gap: 0.35rem;
9
+ padding: 0.35rem 0.65rem;
10
+ border: 1px solid var(--border-medium, #d0d0d0);
11
+ border-radius: 4px;
12
+ background: var(--bg-primary, #ffffff);
13
+ color: var(--text-secondary, #666666);
14
+ font-size: 0.8rem;
15
+ font-weight: 500;
16
  cursor: pointer;
17
+ transition: all var(--transition-base, 200ms);
18
+ font-family: var(--font-primary, -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif);
19
  }
20
 
21
  .force-parameter-toggle:hover {
22
+ background: var(--bg-secondary, #fafafa);
23
+ color: var(--text-primary, #1a1a1a);
24
+ border-color: var(--border-dark, #b0b0b0);
25
+ }
26
+
27
+ [data-theme="dark"] .force-parameter-toggle {
28
+ background: rgba(255, 255, 255, 0.05);
29
+ border-color: rgba(255, 255, 255, 0.1);
30
+ color: var(--text-primary, #ffffff);
31
+ }
32
+
33
+ [data-theme="dark"] .force-parameter-toggle:hover {
34
  background: rgba(255, 255, 255, 0.1);
35
  border-color: rgba(255, 255, 255, 0.2);
36
  }
 
41
  left: 0;
42
  margin-top: 8px;
43
  padding: 16px;
44
+ background: var(--bg-elevated, #ffffff);
45
+ border: 1px solid var(--border-medium, #d0d0d0);
46
  border-radius: 8px;
47
+ box-shadow: var(--shadow-lg, 0 2px 8px rgba(0, 0, 0, 0.12));
48
  min-width: 240px;
49
  z-index: 1000;
50
  }
51
 
52
+ [data-theme="dark"] .force-parameter-panel {
53
+ background: var(--bg-secondary, #2d2d2d);
54
+ border-color: rgba(255, 255, 255, 0.1);
55
+ box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
56
+ }
57
+
58
  .force-parameter-group {
59
  margin-bottom: 16px;
60
  }
 
67
  display: block;
68
  font-size: 12px;
69
  font-weight: 500;
70
+ color: var(--text-primary, #1a1a1a);
71
  margin-bottom: 8px;
72
  }
73
 
74
+ [data-theme="dark"] .force-parameter-group label {
75
+ color: var(--text-primary, #ffffff);
76
+ }
77
+
78
  .force-parameter-group input[type="range"] {
79
  width: 100%;
80
  height: 4px;
81
  border-radius: 2px;
82
+ background: var(--border-light, #e8e8e8);
83
  outline: none;
84
  -webkit-appearance: none;
85
  }
86
 
87
+ [data-theme="dark"] .force-parameter-group input[type="range"] {
88
+ background: rgba(255, 255, 255, 0.1);
89
+ }
90
+
91
  .force-parameter-group input[type="range"]::-webkit-slider-thumb {
92
  -webkit-appearance: none;
93
  appearance: none;
frontend/src/components/controls/ForceParameterControls.tsx CHANGED
@@ -117,3 +117,4 @@ export default function ForceParameterControls({
117
  );
118
  }
119
 
 
 
117
  );
118
  }
119
 
120
+
frontend/src/components/controls/ThemeToggle.tsx CHANGED
@@ -1,13 +1,30 @@
1
  /**
2
  * Toggle button for switching between light and dark themes.
3
  */
4
- import React from 'react';
 
5
  import { useFilterStore } from '../../stores/filterStore';
6
 
7
  export default function ThemeToggle() {
8
  const theme = useFilterStore((state) => state.theme);
9
  const toggleTheme = useFilterStore((state) => state.toggleTheme);
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  return (
12
  <button
13
  onClick={toggleTheme}
@@ -15,7 +32,8 @@ export default function ThemeToggle() {
15
  title={`Switch to ${theme === 'light' ? 'dark' : 'light'} mode`}
16
  aria-label={`Current theme: ${theme}. Click to switch to ${theme === 'light' ? 'dark' : 'light'} mode`}
17
  >
18
- {theme === 'light' ? 'Dark' : 'Light'}
 
19
  </button>
20
  );
21
  }
 
1
  /**
2
  * Toggle button for switching between light and dark themes.
3
  */
4
+ import React, { useEffect } from 'react';
5
+ import { Moon, Sun } from 'lucide-react';
6
  import { useFilterStore } from '../../stores/filterStore';
7
 
8
  export default function ThemeToggle() {
9
  const theme = useFilterStore((state) => state.theme);
10
  const toggleTheme = useFilterStore((state) => state.toggleTheme);
11
 
12
+ // Listen for system theme changes
13
+ useEffect(() => {
14
+ const mediaQuery = window.matchMedia('(prefers-color-scheme: dark)');
15
+ const handleChange = (e: MediaQueryListEvent) => {
16
+ // Only auto-switch if user hasn't manually set a preference
17
+ const saved = localStorage.getItem('theme');
18
+ if (!saved) {
19
+ const newTheme = e.matches ? 'dark' : 'light';
20
+ useFilterStore.getState().setTheme(newTheme);
21
+ }
22
+ };
23
+
24
+ mediaQuery.addEventListener('change', handleChange);
25
+ return () => mediaQuery.removeEventListener('change', handleChange);
26
+ }, []);
27
+
28
  return (
29
  <button
30
  onClick={toggleTheme}
 
32
  title={`Switch to ${theme === 'light' ? 'dark' : 'light'} mode`}
33
  aria-label={`Current theme: ${theme}. Click to switch to ${theme === 'light' ? 'dark' : 'light'} mode`}
34
  >
35
+ {theme === 'light' ? <Moon size={16} /> : <Sun size={16} />}
36
+ <span className="theme-toggle-label">{theme === 'light' ? 'Dark' : 'Light'}</span>
37
  </button>
38
  );
39
  }
frontend/src/components/visualizations/ForceDirectedGraph.css CHANGED
@@ -2,11 +2,15 @@
2
  position: relative;
3
  width: 100%;
4
  height: 100%;
5
- background: var(--bg-secondary, #1a1a1a);
6
  border-radius: 8px;
7
  overflow: hidden;
8
  }
9
 
 
 
 
 
10
  .force-directed-graph {
11
  width: 100%;
12
  height: 100%;
@@ -140,6 +144,44 @@
140
  justify-content: center;
141
  width: 100%;
142
  height: 100%;
143
- color: var(--text-secondary, #999);
144
  font-size: 14px;
145
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  position: relative;
3
  width: 100%;
4
  height: 100%;
5
+ background: var(--bg-secondary, #fafafa);
6
  border-radius: 8px;
7
  overflow: hidden;
8
  }
9
 
10
+ [data-theme="dark"] .force-directed-graph-container {
11
+ background: var(--bg-secondary, #1a1a1a);
12
+ }
13
+
14
  .force-directed-graph {
15
  width: 100%;
16
  height: 100%;
 
144
  justify-content: center;
145
  width: 100%;
146
  height: 100%;
147
+ color: var(--text-secondary, #666666);
148
  font-size: 14px;
149
  }
150
+
151
+ [data-theme="dark"] .graph-empty {
152
+ color: var(--text-secondary, #999999);
153
+ }
154
+
155
+ /* Performance info overlay */
156
+ .graph-performance-info {
157
+ position: absolute;
158
+ top: 10px;
159
+ right: 10px;
160
+ padding: 8px 12px;
161
+ background: rgba(255, 255, 255, 0.9);
162
+ color: var(--text-primary, #1a1a1a);
163
+ border-radius: 4px;
164
+ font-size: 12px;
165
+ font-family: monospace;
166
+ box-shadow: var(--shadow-md, 0 2px 4px rgba(0, 0, 0, 0.08));
167
+ z-index: 10;
168
+ }
169
+
170
+ [data-theme="dark"] .graph-performance-info {
171
+ background: rgba(0, 0, 0, 0.7);
172
+ color: var(--text-primary, #ffffff);
173
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.3);
174
+ }
175
+
176
+ .graph-performance-info div {
177
+ margin-bottom: 4px;
178
+ }
179
+
180
+ .graph-performance-info div:last-child {
181
+ margin-bottom: 0;
182
+ }
183
+
184
+ .graph-performance-warning {
185
+ color: #f59e0b;
186
+ font-weight: 500;
187
+ }
frontend/src/components/visualizations/ForceDirectedGraph3D.tsx CHANGED
@@ -4,12 +4,17 @@
4
  * with color-coded edges and interactive nodes in 3D space.
5
  */
6
  import React, { useMemo, useRef, useEffect, useState, useCallback } from 'react';
7
- import { Canvas, useFrame } from '@react-three/fiber';
8
  import { OrbitControls } from '@react-three/drei';
9
  import * as THREE from 'three';
10
  import { GraphNode, GraphLink, EdgeType } from './ForceDirectedGraph';
 
11
  import './ForceDirectedGraph.css';
12
 
 
 
 
 
13
  export interface ForceDirectedGraph3DProps {
14
  width: number;
15
  height: number;
@@ -25,6 +30,13 @@ export interface ForceDirectedGraph3DProps {
25
  collisionRadius?: number;
26
  nodeSizeMultiplier?: number;
27
  edgeOpacity?: number;
 
 
 
 
 
 
 
28
  }
29
 
30
  // Color scheme for different edge types
@@ -265,26 +277,72 @@ function Graph3DScene({
265
  collisionRadius = 1.0,
266
  nodeSizeMultiplier = 1.0,
267
  edgeOpacity = 0.6,
 
 
 
 
 
 
268
  }: ForceDirectedGraph3DProps) {
269
  const simulationRef = useRef<ForceSimulation3D | null>(null);
270
  const edgeRefsRef = useRef<Map<string, THREE.BufferGeometry>>(new Map());
271
  const [hoveredNodeId, setHoveredNodeId] = useState<string | null>(null);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
  // Filter links based on enabled edge types
274
  const filteredLinks = useMemo((): GraphLink[] => {
275
- if (!enabledEdgeTypes || enabledEdgeTypes.size === 0) {
276
- return links;
 
 
 
 
 
 
277
  }
278
- return links.filter(link => {
279
- const linkTypes = link.edge_types || [link.edge_type];
280
- return linkTypes.some(type => enabledEdgeTypes.has(type));
281
- });
282
- }, [links, enabledEdgeTypes]);
 
 
 
 
 
 
 
 
283
 
284
  // Filter nodes to only include those connected by filtered links
285
  const filteredNodes = useMemo(() => {
286
  if (!enabledEdgeTypes || enabledEdgeTypes.size === 0) {
287
- return nodes;
288
  }
289
  const connectedNodeIds = new Set<string>();
290
  filteredLinks.forEach(link => {
@@ -297,8 +355,62 @@ function Graph3DScene({
297
  connectedNodeIds.add(sourceId);
298
  connectedNodeIds.add(targetId);
299
  });
300
- return nodes.filter(node => connectedNodeIds.has(node.id));
301
- }, [nodes, filteredLinks, enabledEdgeTypes]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
 
303
  // Initialize simulation
304
  useEffect(() => {
@@ -415,17 +527,31 @@ function Graph3DScene({
415
  {/* Nodes */}
416
  <group>
417
  {filteredNodes.map((node) => {
418
- const downloads = node.downloads || 0;
419
- const baseRadius = 0.3 + Math.sqrt(downloads) / 8000;
420
- const radius = baseRadius * nodeSizeMultiplier;
421
  const isSelected = selectedNodeId === node.id;
422
  const isHovered = hoveredNodeId === node.id;
423
-
424
- // Color by library
425
- const colors = ['#3b82f6', '#10b981', '#f59e0b', '#8b5cf6', '#ef4444', '#06b6d4'];
426
- const libraries = Array.from(new Set(filteredNodes.map(n => n.library).filter(Boolean)));
427
- const libIndex = libraries.indexOf(node.library);
428
- const color = colors[libIndex % colors.length] || '#6b7280';
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
 
430
  return (
431
  <mesh
@@ -442,11 +568,11 @@ function Graph3DScene({
442
  if (onNodeHover) onNodeHover(null);
443
  }}
444
  >
445
- <sphereGeometry args={[radius, 12, 12]} />
446
  <meshStandardMaterial
447
- color={isSelected ? '#ef4444' : isHovered ? '#fbbf24' : color}
448
- emissive={isSelected ? '#ef4444' : isHovered ? '#fbbf24' : color}
449
- emissiveIntensity={isSelected ? 0.5 : isHovered ? 0.3 : 0.1}
450
  />
451
  </mesh>
452
  );
@@ -488,6 +614,12 @@ export default function ForceDirectedGraph3D({
488
  collisionRadius = 1.0,
489
  nodeSizeMultiplier = 1.0,
490
  edgeOpacity = 0.6,
 
 
 
 
 
 
491
  }: ForceDirectedGraph3DProps) {
492
  // Calculate bounds for camera
493
  const bounds = useMemo(() => {
@@ -582,6 +714,12 @@ export default function ForceDirectedGraph3D({
582
  collisionRadius={collisionRadius}
583
  nodeSizeMultiplier={nodeSizeMultiplier}
584
  edgeOpacity={edgeOpacity}
 
 
 
 
 
 
585
  />
586
  </Canvas>
587
  </div>
 
4
  * with color-coded edges and interactive nodes in 3D space.
5
  */
6
  import React, { useMemo, useRef, useEffect, useState, useCallback } from 'react';
7
+ import { Canvas, useFrame, useThree } from '@react-three/fiber';
8
  import { OrbitControls } from '@react-three/drei';
9
  import * as THREE from 'three';
10
  import { GraphNode, GraphLink, EdgeType } from './ForceDirectedGraph';
11
+ import { getCategoricalColorMap, getContinuousColorScale, getDepthColorScale, LIBRARY_COLORS, PIPELINE_COLORS } from '../../utils/rendering/colors';
12
  import './ForceDirectedGraph.css';
13
 
14
+ export type ColorByOption = 'library' | 'pipeline' | 'downloads' | 'likes' | 'edge_type';
15
+ export type SizeByOption = 'downloads' | 'likes' | 'uniform';
16
+ export type ColorScheme = 'viridis' | 'plasma' | 'inferno' | 'coolwarm';
17
+
18
  export interface ForceDirectedGraph3DProps {
19
  width: number;
20
  height: number;
 
30
  collisionRadius?: number;
31
  nodeSizeMultiplier?: number;
32
  edgeOpacity?: number;
33
+ colorBy?: ColorByOption;
34
+ sizeBy?: SizeByOption;
35
+ colorScheme?: ColorScheme;
36
+ highlightedNodeId?: string | null;
37
+ familyFilter?: string;
38
+ searchQuery?: string;
39
+ onZoomToNode?: (nodeId: string) => void;
40
  }
41
 
42
  // Color scheme for different edge types
 
277
  collisionRadius = 1.0,
278
  nodeSizeMultiplier = 1.0,
279
  edgeOpacity = 0.6,
280
+ colorBy = 'library',
281
+ sizeBy = 'downloads',
282
+ colorScheme = 'viridis',
283
+ highlightedNodeId,
284
+ familyFilter,
285
+ searchQuery,
286
  }: ForceDirectedGraph3DProps) {
287
  const simulationRef = useRef<ForceSimulation3D | null>(null);
288
  const edgeRefsRef = useRef<Map<string, THREE.BufferGeometry>>(new Map());
289
  const [hoveredNodeId, setHoveredNodeId] = useState<string | null>(null);
290
+ const { camera, controls } = useThree();
291
+
292
+ // Filter nodes by family and search query first
293
+ const preFilteredNodes = useMemo(() => {
294
+ let result = nodes;
295
+
296
+ // Filter by family (organization prefix)
297
+ if (familyFilter && familyFilter.trim()) {
298
+ const filter = familyFilter.toLowerCase();
299
+ result = result.filter(node => {
300
+ const nodeId = node.id.toLowerCase();
301
+ return nodeId.startsWith(filter + '/') || nodeId.includes('/' + filter + '/');
302
+ });
303
+ }
304
+
305
+ // Filter by search query
306
+ if (searchQuery && searchQuery.trim()) {
307
+ const query = searchQuery.toLowerCase();
308
+ result = result.filter(node =>
309
+ node.id.toLowerCase().includes(query) ||
310
+ node.title?.toLowerCase().includes(query)
311
+ );
312
+ }
313
+
314
+ return result;
315
+ }, [nodes, familyFilter, searchQuery]);
316
 
317
  // Filter links based on enabled edge types
318
  const filteredLinks = useMemo((): GraphLink[] => {
319
+ let result = links;
320
+
321
+ // Filter by edge types
322
+ if (enabledEdgeTypes && enabledEdgeTypes.size > 0) {
323
+ result = result.filter(link => {
324
+ const linkTypes = link.edge_types || [link.edge_type];
325
+ return linkTypes.some(type => enabledEdgeTypes.has(type));
326
+ });
327
  }
328
+
329
+ // Filter to only include links where both source and target are in preFilteredNodes
330
+ if (familyFilter || searchQuery) {
331
+ const nodeIds = new Set(preFilteredNodes.map(n => n.id));
332
+ result = result.filter(link => {
333
+ const sourceId = typeof link.source === 'string' ? link.source : (link.source as GraphNode).id;
334
+ const targetId = typeof link.target === 'string' ? link.target : (link.target as GraphNode).id;
335
+ return nodeIds.has(sourceId) && nodeIds.has(targetId);
336
+ });
337
+ }
338
+
339
+ return result;
340
+ }, [links, enabledEdgeTypes, preFilteredNodes, familyFilter, searchQuery]);
341
 
342
  // Filter nodes to only include those connected by filtered links
343
  const filteredNodes = useMemo(() => {
344
  if (!enabledEdgeTypes || enabledEdgeTypes.size === 0) {
345
+ return preFilteredNodes;
346
  }
347
  const connectedNodeIds = new Set<string>();
348
  filteredLinks.forEach(link => {
 
355
  connectedNodeIds.add(sourceId);
356
  connectedNodeIds.add(targetId);
357
  });
358
+ return preFilteredNodes.filter(node => connectedNodeIds.has(node.id));
359
+ }, [preFilteredNodes, filteredLinks, enabledEdgeTypes]);
360
+
361
+ // Create color scale based on colorBy option
362
+ const getNodeColor = useMemo(() => {
363
+ if (colorBy === 'downloads' || colorBy === 'likes') {
364
+ const values = filteredNodes.map(n => colorBy === 'downloads' ? (n.downloads || 0) : (n.likes || 0));
365
+ const min = Math.min(...values, 0);
366
+ const max = Math.max(...values, 1);
367
+ const scale = getContinuousColorScale(min, max, colorScheme, true);
368
+ return (node: GraphNode) => scale(colorBy === 'downloads' ? (node.downloads || 0) : (node.likes || 0));
369
+ }
370
+
371
+ if (colorBy === 'library') {
372
+ return (node: GraphNode) => LIBRARY_COLORS[node.library?.toLowerCase() || 'unknown'] || '#6b7280';
373
+ }
374
+
375
+ if (colorBy === 'pipeline') {
376
+ return (node: GraphNode) => PIPELINE_COLORS[node.pipeline?.toLowerCase() || 'unknown'] || '#6b7280';
377
+ }
378
+
379
+ // Default: edge_type based on most common edge type connected to this node
380
+ return (node: GraphNode) => {
381
+ const nodeLinks = filteredLinks.filter(link => {
382
+ const sourceId = typeof link.source === 'string' ? link.source : link.source.id;
383
+ const targetId = typeof link.target === 'string' ? link.target : link.target.id;
384
+ return sourceId === node.id || targetId === node.id;
385
+ });
386
+ if (nodeLinks.length === 0) return '#6b7280';
387
+ const edgeType = nodeLinks[0].edge_type || 'parent';
388
+ const edgeColors: Record<string, string> = {
389
+ finetune: '#3b82f6',
390
+ quantized: '#10b981',
391
+ adapter: '#f59e0b',
392
+ merge: '#8b5cf6',
393
+ parent: '#6b7280',
394
+ };
395
+ return edgeColors[edgeType] || '#6b7280';
396
+ };
397
+ }, [colorBy, colorScheme, filteredNodes, filteredLinks]);
398
+
399
+ // Calculate node size based on sizeBy option
400
+ const getNodeSize = useCallback((node: GraphNode, baseMultiplier: number = 1.0) => {
401
+ let baseSize = 0.3;
402
+
403
+ if (sizeBy === 'downloads') {
404
+ baseSize = 0.3 + Math.sqrt(node.downloads || 0) / 8000;
405
+ } else if (sizeBy === 'likes') {
406
+ baseSize = 0.3 + Math.sqrt(node.likes || 0) / 500;
407
+ } else {
408
+ // uniform
409
+ baseSize = 0.5;
410
+ }
411
+
412
+ return baseSize * baseMultiplier;
413
+ }, [sizeBy]);
414
 
415
  // Initialize simulation
416
  useEffect(() => {
 
527
  {/* Nodes */}
528
  <group>
529
  {filteredNodes.map((node) => {
530
+ const radius = getNodeSize(node, nodeSizeMultiplier);
 
 
531
  const isSelected = selectedNodeId === node.id;
532
  const isHovered = hoveredNodeId === node.id;
533
+ const isHighlighted = highlightedNodeId === node.id;
534
+
535
+ // Get color from colorBy option
536
+ const baseColor = getNodeColor(node);
537
+
538
+ // Determine final color based on state
539
+ let finalColor = baseColor;
540
+ let emissiveIntensity = 0.1;
541
+
542
+ if (isSelected) {
543
+ finalColor = '#ef4444'; // Red for selected
544
+ emissiveIntensity = 0.5;
545
+ } else if (isHighlighted) {
546
+ finalColor = '#22d3ee'; // Cyan for highlighted (search result)
547
+ emissiveIntensity = 0.6;
548
+ } else if (isHovered) {
549
+ finalColor = '#fbbf24'; // Yellow for hovered
550
+ emissiveIntensity = 0.3;
551
+ }
552
+
553
+ // Scale up highlighted/selected nodes
554
+ const finalRadius = (isHighlighted || isSelected) ? radius * 1.5 : radius;
555
 
556
  return (
557
  <mesh
 
568
  if (onNodeHover) onNodeHover(null);
569
  }}
570
  >
571
+ <sphereGeometry args={[finalRadius, 12, 12]} />
572
  <meshStandardMaterial
573
+ color={finalColor}
574
+ emissive={finalColor}
575
+ emissiveIntensity={emissiveIntensity}
576
  />
577
  </mesh>
578
  );
 
614
  collisionRadius = 1.0,
615
  nodeSizeMultiplier = 1.0,
616
  edgeOpacity = 0.6,
617
+ colorBy = 'library',
618
+ sizeBy = 'downloads',
619
+ colorScheme = 'viridis',
620
+ highlightedNodeId,
621
+ familyFilter,
622
+ searchQuery,
623
  }: ForceDirectedGraph3DProps) {
624
  // Calculate bounds for camera
625
  const bounds = useMemo(() => {
 
714
  collisionRadius={collisionRadius}
715
  nodeSizeMultiplier={nodeSizeMultiplier}
716
  edgeOpacity={edgeOpacity}
717
+ colorBy={colorBy}
718
+ sizeBy={sizeBy}
719
+ colorScheme={colorScheme}
720
+ highlightedNodeId={highlightedNodeId}
721
+ familyFilter={familyFilter}
722
+ searchQuery={searchQuery}
723
  />
724
  </Canvas>
725
  </div>
frontend/src/components/visualizations/ForceDirectedGraph3DInstanced.tsx CHANGED
@@ -11,8 +11,13 @@ import { Canvas, useFrame, useThree } from '@react-three/fiber';
11
  import { OrbitControls } from '@react-three/drei';
12
  import * as THREE from 'three';
13
  import { GraphNode, GraphLink, EdgeType } from './ForceDirectedGraph';
 
14
  import './ForceDirectedGraph.css';
15
 
 
 
 
 
16
  export interface ForceDirectedGraph3DInstancedProps {
17
  width: number;
18
  height: number;
@@ -30,20 +35,14 @@ export interface ForceDirectedGraph3DInstancedProps {
30
  collisionRadius?: number;
31
  nodeSizeMultiplier?: number;
32
  edgeOpacity?: number;
 
 
 
 
 
 
33
  }
34
 
35
- // Color scheme for different libraries
36
- const LIBRARY_COLORS: Record<string, string> = {
37
- transformers: '#3b82f6', // Blue
38
- pytorch: '#ef4444', // Red
39
- tensorflow: '#f97316', // Orange
40
- diffusers: '#8b5cf6', // Purple
41
- 'sentence-transformers': '#10b981', // Green
42
- timm: '#06b6d4', // Cyan
43
- peft: '#ec4899', // Pink
44
- default: '#6b7280', // Gray
45
- };
46
-
47
  // Color scheme for different edge types
48
  const EDGE_COLORS: Record<EdgeType, THREE.Color> = {
49
  finetune: new THREE.Color('#3b82f6'), // Blue
@@ -54,18 +53,40 @@ const EDGE_COLORS: Record<EdgeType, THREE.Color> = {
54
  };
55
 
56
  /**
57
- * Get color for a node based on its library
58
  */
59
- function getNodeColor(library: string | undefined): THREE.Color {
60
- const colorHex = LIBRARY_COLORS[library?.toLowerCase() || ''] || LIBRARY_COLORS.default;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  return new THREE.Color(colorHex);
62
  }
63
 
64
  /**
65
- * Calculate node size based on downloads (log scale)
66
  */
67
- function getNodeSize(downloads: number): number {
68
- return 0.3 + Math.log10(Math.max(downloads, 1)) * 0.15;
 
 
 
 
 
 
 
69
  }
70
 
71
  /**
@@ -78,6 +99,12 @@ function InstancedNodes({
78
  onNodeHover,
79
  maxVisible = 500000,
80
  nodeSizeMultiplier = 1.0,
 
 
 
 
 
 
81
  }: {
82
  nodes: GraphNode[];
83
  selectedNodeId?: string | null;
@@ -85,19 +112,50 @@ function InstancedNodes({
85
  onNodeHover?: (node: GraphNode | null) => void;
86
  maxVisible?: number;
87
  nodeSizeMultiplier?: number;
 
 
 
 
 
 
88
  }) {
89
  const meshRef = useRef<THREE.InstancedMesh>(null);
90
  const { camera, raycaster, pointer } = useThree();
91
  const [hoveredIndex, setHoveredIndex] = useState<number | null>(null);
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  // Limit nodes for performance
94
  const visibleNodes = useMemo(() => {
95
- if (nodes.length <= maxVisible) return nodes;
96
  // Sort by downloads and take top N
97
- return [...nodes]
98
  .sort((a, b) => (b.downloads || 0) - (a.downloads || 0))
99
  .slice(0, maxVisible);
100
- }, [nodes, maxVisible]);
101
 
102
  // Node ID to index map for lookup
103
  const nodeIndexMap = useMemo(() => {
@@ -106,6 +164,17 @@ function InstancedNodes({
106
  return map;
107
  }, [visibleNodes]);
108
 
 
 
 
 
 
 
 
 
 
 
 
109
  // Pre-compute matrices and colors
110
  const { matrices, colors, sizes } = useMemo(() => {
111
  const matrices: THREE.Matrix4[] = [];
@@ -118,18 +187,18 @@ function InstancedNodes({
118
  const x = node.x || 0;
119
  const y = node.y || 0;
120
  const z = node.z || 0;
121
- const size = getNodeSize(node.downloads || 0) * nodeSizeMultiplier;
122
 
123
  tempMatrix.makeScale(size, size, size);
124
  tempMatrix.setPosition(x, y, z);
125
  matrices.push(tempMatrix.clone());
126
 
127
- colors.push(getNodeColor(node.library));
128
  sizes.push(size);
129
  });
130
 
131
  return { matrices, colors, sizes };
132
- }, [visibleNodes]);
133
 
134
  // Update instance attributes when data changes
135
  useEffect(() => {
@@ -141,12 +210,15 @@ function InstancedNodes({
141
  matrices.forEach((matrix, i) => {
142
  mesh.setMatrixAt(i, matrix);
143
 
144
- // Highlight selected/hovered nodes
145
  const isSelected = visibleNodes[i]?.id === selectedNodeId;
 
146
  const isHovered = i === hoveredIndex;
147
 
148
  if (isSelected) {
149
  tempColor.set('#ef4444'); // Red for selected
 
 
150
  } else if (isHovered) {
151
  tempColor.set('#fbbf24'); // Yellow for hovered
152
  } else {
@@ -158,7 +230,7 @@ function InstancedNodes({
158
 
159
  mesh.instanceMatrix.needsUpdate = true;
160
  if (mesh.instanceColor) mesh.instanceColor.needsUpdate = true;
161
- }, [matrices, colors, selectedNodeId, hoveredIndex, visibleNodes]);
162
 
163
  // Raycasting for hover/click
164
  useFrame(() => {
@@ -219,15 +291,44 @@ function Edges({
219
  enabledEdgeTypes,
220
  maxVisible = 100000,
221
  edgeOpacity = 0.6,
 
 
222
  }: {
223
  nodes: GraphNode[];
224
  links: GraphLink[];
225
  enabledEdgeTypes?: Set<EdgeType>;
226
  maxVisible?: number;
227
  edgeOpacity?: number;
 
 
228
  }) {
229
  const lineRef = useRef<THREE.LineSegments>(null);
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  // Create node lookup map
232
  const nodeMap = useMemo(() => {
233
  const map = new Map<string, GraphNode>();
@@ -239,6 +340,7 @@ function Edges({
239
  const visibleLinks = useMemo(() => {
240
  let filtered = links;
241
 
 
242
  if (enabledEdgeTypes && enabledEdgeTypes.size > 0) {
243
  filtered = links.filter(link => {
244
  const linkTypes = link.edge_types || [link.edge_type];
@@ -246,12 +348,21 @@ function Edges({
246
  });
247
  }
248
 
 
 
 
 
 
 
 
 
 
249
  if (filtered.length > maxVisible) {
250
  return filtered.slice(0, maxVisible);
251
  }
252
 
253
  return filtered;
254
- }, [links, enabledEdgeTypes, maxVisible]);
255
 
256
  // Build geometry
257
  const geometry = useMemo(() => {
@@ -316,6 +427,12 @@ function Scene({
316
  maxVisibleEdges = 100000,
317
  nodeSizeMultiplier = 1.0,
318
  edgeOpacity = 0.6,
 
 
 
 
 
 
319
  }: ForceDirectedGraph3DInstancedProps) {
320
  return (
321
  <>
@@ -325,6 +442,8 @@ function Scene({
325
  enabledEdgeTypes={enabledEdgeTypes}
326
  maxVisible={maxVisibleEdges}
327
  edgeOpacity={edgeOpacity}
 
 
328
  />
329
  <InstancedNodes
330
  nodes={nodes}
@@ -333,6 +452,12 @@ function Scene({
333
  onNodeHover={onNodeHover}
334
  maxVisible={maxVisibleNodes}
335
  nodeSizeMultiplier={nodeSizeMultiplier}
 
 
 
 
 
 
336
  />
337
  </>
338
  );
@@ -358,6 +483,12 @@ export default function ForceDirectedGraph3DInstanced({
358
  collisionRadius = 1.0,
359
  nodeSizeMultiplier = 1.0,
360
  edgeOpacity = 0.6,
 
 
 
 
 
 
361
  }: ForceDirectedGraph3DInstancedProps) {
362
  // Calculate bounds for camera positioning
363
  const bounds = useMemo(() => {
@@ -458,25 +589,21 @@ export default function ForceDirectedGraph3DInstanced({
458
  height={height}
459
  nodeSizeMultiplier={nodeSizeMultiplier}
460
  edgeOpacity={edgeOpacity}
 
 
 
 
 
 
461
  />
462
  </Canvas>
463
 
464
  {/* Performance info overlay */}
465
- <div className="graph-performance-info" style={{
466
- position: 'absolute',
467
- top: '10px',
468
- right: '10px',
469
- padding: '8px 12px',
470
- background: 'rgba(0,0,0,0.7)',
471
- color: '#fff',
472
- borderRadius: '4px',
473
- fontSize: '12px',
474
- fontFamily: 'monospace',
475
- }}>
476
  <div>Nodes: {nodes.length.toLocaleString()}</div>
477
  <div>Edges: {links.length.toLocaleString()}</div>
478
  {nodes.length > maxVisibleNodes && (
479
- <div style={{ color: '#f59e0b' }}>
480
  Showing top {maxVisibleNodes.toLocaleString()} by popularity
481
  </div>
482
  )}
 
11
  import { OrbitControls } from '@react-three/drei';
12
  import * as THREE from 'three';
13
  import { GraphNode, GraphLink, EdgeType } from './ForceDirectedGraph';
14
+ import { getContinuousColorScale, LIBRARY_COLORS, PIPELINE_COLORS } from '../../utils/rendering/colors';
15
  import './ForceDirectedGraph.css';
16
 
17
+ export type ColorByOption = 'library' | 'pipeline' | 'downloads' | 'likes' | 'edge_type';
18
+ export type SizeByOption = 'downloads' | 'likes' | 'uniform';
19
+ export type ColorScheme = 'viridis' | 'plasma' | 'inferno' | 'coolwarm';
20
+
21
  export interface ForceDirectedGraph3DInstancedProps {
22
  width: number;
23
  height: number;
 
35
  collisionRadius?: number;
36
  nodeSizeMultiplier?: number;
37
  edgeOpacity?: number;
38
+ colorBy?: ColorByOption;
39
+ sizeBy?: SizeByOption;
40
+ colorScheme?: ColorScheme;
41
+ highlightedNodeId?: string | null;
42
+ familyFilter?: string;
43
+ searchQuery?: string;
44
  }
45
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  // Color scheme for different edge types
47
  const EDGE_COLORS: Record<EdgeType, THREE.Color> = {
48
  finetune: new THREE.Color('#3b82f6'), // Blue
 
53
  };
54
 
55
  /**
56
+ * Get color for a node based on colorBy option
57
  */
58
+ function getNodeColorByOption(
59
+ node: GraphNode,
60
+ colorBy: ColorByOption,
61
+ colorScale?: (value: number) => string
62
+ ): THREE.Color {
63
+ if (colorBy === 'downloads' && colorScale) {
64
+ return new THREE.Color(colorScale(node.downloads || 0));
65
+ }
66
+ if (colorBy === 'likes' && colorScale) {
67
+ return new THREE.Color(colorScale(node.likes || 0));
68
+ }
69
+ if (colorBy === 'pipeline') {
70
+ const colorHex = PIPELINE_COLORS[node.pipeline?.toLowerCase() || 'unknown'] || '#6b7280';
71
+ return new THREE.Color(colorHex);
72
+ }
73
+ // Default: library
74
+ const colorHex = LIBRARY_COLORS[node.library?.toLowerCase() || 'unknown'] || '#6b7280';
75
  return new THREE.Color(colorHex);
76
  }
77
 
78
  /**
79
+ * Calculate node size based on sizeBy option
80
  */
81
+ function getNodeSizeByOption(node: GraphNode, sizeBy: SizeByOption): number {
82
+ if (sizeBy === 'downloads') {
83
+ return 0.3 + Math.log10(Math.max(node.downloads || 1, 1)) * 0.15;
84
+ }
85
+ if (sizeBy === 'likes') {
86
+ return 0.3 + Math.log10(Math.max(node.likes || 1, 1)) * 0.2;
87
+ }
88
+ // uniform
89
+ return 0.5;
90
  }
91
 
92
  /**
 
99
  onNodeHover,
100
  maxVisible = 500000,
101
  nodeSizeMultiplier = 1.0,
102
+ colorBy = 'library',
103
+ sizeBy = 'downloads',
104
+ colorScheme = 'viridis',
105
+ highlightedNodeId,
106
+ familyFilter,
107
+ searchQuery,
108
  }: {
109
  nodes: GraphNode[];
110
  selectedNodeId?: string | null;
 
112
  onNodeHover?: (node: GraphNode | null) => void;
113
  maxVisible?: number;
114
  nodeSizeMultiplier?: number;
115
+ colorBy?: ColorByOption;
116
+ sizeBy?: SizeByOption;
117
+ colorScheme?: ColorScheme;
118
+ highlightedNodeId?: string | null;
119
+ familyFilter?: string;
120
+ searchQuery?: string;
121
  }) {
122
  const meshRef = useRef<THREE.InstancedMesh>(null);
123
  const { camera, raycaster, pointer } = useThree();
124
  const [hoveredIndex, setHoveredIndex] = useState<number | null>(null);
125
 
126
+ // Filter nodes by family and search query first
127
+ const preFilteredNodes = useMemo(() => {
128
+ let result = nodes;
129
+
130
+ // Filter by family (organization prefix)
131
+ if (familyFilter && familyFilter.trim()) {
132
+ const filter = familyFilter.toLowerCase();
133
+ result = result.filter(node => {
134
+ const nodeId = node.id.toLowerCase();
135
+ return nodeId.startsWith(filter + '/') || nodeId.includes('/' + filter + '/');
136
+ });
137
+ }
138
+
139
+ // Filter by search query
140
+ if (searchQuery && searchQuery.trim()) {
141
+ const query = searchQuery.toLowerCase();
142
+ result = result.filter(node =>
143
+ node.id.toLowerCase().includes(query) ||
144
+ node.title?.toLowerCase().includes(query)
145
+ );
146
+ }
147
+
148
+ return result;
149
+ }, [nodes, familyFilter, searchQuery]);
150
+
151
  // Limit nodes for performance
152
  const visibleNodes = useMemo(() => {
153
+ if (preFilteredNodes.length <= maxVisible) return preFilteredNodes;
154
  // Sort by downloads and take top N
155
+ return [...preFilteredNodes]
156
  .sort((a, b) => (b.downloads || 0) - (a.downloads || 0))
157
  .slice(0, maxVisible);
158
+ }, [preFilteredNodes, maxVisible]);
159
 
160
  // Node ID to index map for lookup
161
  const nodeIndexMap = useMemo(() => {
 
164
  return map;
165
  }, [visibleNodes]);
166
 
167
+ // Create color scale for continuous colorBy options
168
+ const colorScale = useMemo(() => {
169
+ if (colorBy === 'downloads' || colorBy === 'likes') {
170
+ const values = visibleNodes.map(n => colorBy === 'downloads' ? (n.downloads || 0) : (n.likes || 0));
171
+ const min = Math.min(...values, 0);
172
+ const max = Math.max(...values, 1);
173
+ return getContinuousColorScale(min, max, colorScheme, true);
174
+ }
175
+ return undefined;
176
+ }, [colorBy, colorScheme, visibleNodes]);
177
+
178
  // Pre-compute matrices and colors
179
  const { matrices, colors, sizes } = useMemo(() => {
180
  const matrices: THREE.Matrix4[] = [];
 
187
  const x = node.x || 0;
188
  const y = node.y || 0;
189
  const z = node.z || 0;
190
+ const size = getNodeSizeByOption(node, sizeBy) * nodeSizeMultiplier;
191
 
192
  tempMatrix.makeScale(size, size, size);
193
  tempMatrix.setPosition(x, y, z);
194
  matrices.push(tempMatrix.clone());
195
 
196
+ colors.push(getNodeColorByOption(node, colorBy, colorScale));
197
  sizes.push(size);
198
  });
199
 
200
  return { matrices, colors, sizes };
201
+ }, [visibleNodes, colorBy, sizeBy, colorScale, nodeSizeMultiplier]);
202
 
203
  // Update instance attributes when data changes
204
  useEffect(() => {
 
210
  matrices.forEach((matrix, i) => {
211
  mesh.setMatrixAt(i, matrix);
212
 
213
+ // Highlight selected/hovered/highlighted nodes
214
  const isSelected = visibleNodes[i]?.id === selectedNodeId;
215
+ const isHighlighted = visibleNodes[i]?.id === highlightedNodeId;
216
  const isHovered = i === hoveredIndex;
217
 
218
  if (isSelected) {
219
  tempColor.set('#ef4444'); // Red for selected
220
+ } else if (isHighlighted) {
221
+ tempColor.set('#22d3ee'); // Cyan for highlighted (search result)
222
  } else if (isHovered) {
223
  tempColor.set('#fbbf24'); // Yellow for hovered
224
  } else {
 
230
 
231
  mesh.instanceMatrix.needsUpdate = true;
232
  if (mesh.instanceColor) mesh.instanceColor.needsUpdate = true;
233
+ }, [matrices, colors, selectedNodeId, highlightedNodeId, hoveredIndex, visibleNodes]);
234
 
235
  // Raycasting for hover/click
236
  useFrame(() => {
 
291
  enabledEdgeTypes,
292
  maxVisible = 100000,
293
  edgeOpacity = 0.6,
294
+ familyFilter,
295
+ searchQuery,
296
  }: {
297
  nodes: GraphNode[];
298
  links: GraphLink[];
299
  enabledEdgeTypes?: Set<EdgeType>;
300
  maxVisible?: number;
301
  edgeOpacity?: number;
302
+ familyFilter?: string;
303
+ searchQuery?: string;
304
  }) {
305
  const lineRef = useRef<THREE.LineSegments>(null);
306
 
307
+ // Filter nodes by family and search query first
308
+ const filteredNodeIds = useMemo(() => {
309
+ let result = nodes;
310
+
311
+ // Filter by family (organization prefix)
312
+ if (familyFilter && familyFilter.trim()) {
313
+ const filter = familyFilter.toLowerCase();
314
+ result = result.filter(node => {
315
+ const nodeId = node.id.toLowerCase();
316
+ return nodeId.startsWith(filter + '/') || nodeId.includes('/' + filter + '/');
317
+ });
318
+ }
319
+
320
+ // Filter by search query
321
+ if (searchQuery && searchQuery.trim()) {
322
+ const query = searchQuery.toLowerCase();
323
+ result = result.filter(node =>
324
+ node.id.toLowerCase().includes(query) ||
325
+ node.title?.toLowerCase().includes(query)
326
+ );
327
+ }
328
+
329
+ return new Set(result.map(n => n.id));
330
+ }, [nodes, familyFilter, searchQuery]);
331
+
332
  // Create node lookup map
333
  const nodeMap = useMemo(() => {
334
  const map = new Map<string, GraphNode>();
 
340
  const visibleLinks = useMemo(() => {
341
  let filtered = links;
342
 
343
+ // Filter by edge types
344
  if (enabledEdgeTypes && enabledEdgeTypes.size > 0) {
345
  filtered = links.filter(link => {
346
  const linkTypes = link.edge_types || [link.edge_type];
 
348
  });
349
  }
350
 
351
+ // Filter to only include links where both source and target are in filtered nodes
352
+ if (familyFilter || searchQuery) {
353
+ filtered = filtered.filter(link => {
354
+ const sourceId = typeof link.source === 'string' ? link.source : link.source?.id;
355
+ const targetId = typeof link.target === 'string' ? link.target : link.target?.id;
356
+ return filteredNodeIds.has(sourceId || '') && filteredNodeIds.has(targetId || '');
357
+ });
358
+ }
359
+
360
  if (filtered.length > maxVisible) {
361
  return filtered.slice(0, maxVisible);
362
  }
363
 
364
  return filtered;
365
+ }, [links, enabledEdgeTypes, maxVisible, familyFilter, searchQuery, filteredNodeIds]);
366
 
367
  // Build geometry
368
  const geometry = useMemo(() => {
 
427
  maxVisibleEdges = 100000,
428
  nodeSizeMultiplier = 1.0,
429
  edgeOpacity = 0.6,
430
+ colorBy = 'library',
431
+ sizeBy = 'downloads',
432
+ colorScheme = 'viridis',
433
+ highlightedNodeId,
434
+ familyFilter,
435
+ searchQuery,
436
  }: ForceDirectedGraph3DInstancedProps) {
437
  return (
438
  <>
 
442
  enabledEdgeTypes={enabledEdgeTypes}
443
  maxVisible={maxVisibleEdges}
444
  edgeOpacity={edgeOpacity}
445
+ familyFilter={familyFilter}
446
+ searchQuery={searchQuery}
447
  />
448
  <InstancedNodes
449
  nodes={nodes}
 
452
  onNodeHover={onNodeHover}
453
  maxVisible={maxVisibleNodes}
454
  nodeSizeMultiplier={nodeSizeMultiplier}
455
+ colorBy={colorBy}
456
+ sizeBy={sizeBy}
457
+ colorScheme={colorScheme}
458
+ highlightedNodeId={highlightedNodeId}
459
+ familyFilter={familyFilter}
460
+ searchQuery={searchQuery}
461
  />
462
  </>
463
  );
 
483
  collisionRadius = 1.0,
484
  nodeSizeMultiplier = 1.0,
485
  edgeOpacity = 0.6,
486
+ colorBy = 'library',
487
+ sizeBy = 'downloads',
488
+ colorScheme = 'viridis',
489
+ highlightedNodeId,
490
+ familyFilter,
491
+ searchQuery,
492
  }: ForceDirectedGraph3DInstancedProps) {
493
  // Calculate bounds for camera positioning
494
  const bounds = useMemo(() => {
 
589
  height={height}
590
  nodeSizeMultiplier={nodeSizeMultiplier}
591
  edgeOpacity={edgeOpacity}
592
+ colorBy={colorBy}
593
+ sizeBy={sizeBy}
594
+ colorScheme={colorScheme}
595
+ highlightedNodeId={highlightedNodeId}
596
+ familyFilter={familyFilter}
597
+ searchQuery={searchQuery}
598
  />
599
  </Canvas>
600
 
601
  {/* Performance info overlay */}
602
+ <div className="graph-performance-info">
 
 
 
 
 
 
 
 
 
 
603
  <div>Nodes: {nodes.length.toLocaleString()}</div>
604
  <div>Edges: {links.length.toLocaleString()}</div>
605
  {nodes.length > maxVisibleNodes && (
606
+ <div className="graph-performance-warning">
607
  Showing top {maxVisibleNodes.toLocaleString()} by popularity
608
  </div>
609
  )}
frontend/src/stores/filterStore.ts CHANGED
@@ -57,17 +57,20 @@ export interface FilterState {
57
  getActiveFilterCount: () => number;
58
  }
59
 
60
- // Load theme from localStorage or default to light
61
  const getInitialTheme = (): Theme => {
62
  if (typeof window !== 'undefined') {
63
  const saved = localStorage.getItem('theme');
64
  if (saved === 'dark' || saved === 'light') return saved;
65
- // Check system preference
66
  if (window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches) {
67
  return 'dark';
68
  }
 
 
 
69
  }
70
- return 'light';
71
  };
72
 
73
  export const useFilterStore = create<FilterState>((set, get) => ({
@@ -111,6 +114,7 @@ export const useFilterStore = create<FilterState>((set, get) => ({
111
  set({ theme });
112
  if (typeof window !== 'undefined') {
113
  localStorage.setItem('theme', theme);
 
114
  document.documentElement.setAttribute('data-theme', theme);
115
  }
116
  },
 
57
  getActiveFilterCount: () => number;
58
  }
59
 
60
+ // Load theme from localStorage or system preference, default to dark
61
  const getInitialTheme = (): Theme => {
62
  if (typeof window !== 'undefined') {
63
  const saved = localStorage.getItem('theme');
64
  if (saved === 'dark' || saved === 'light') return saved;
65
+ // Check system preference (prefers-color-scheme)
66
  if (window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches) {
67
  return 'dark';
68
  }
69
+ // If system prefers light, check if user has explicitly set dark before
70
+ // Otherwise default to dark mode
71
+ return 'dark';
72
  }
73
+ return 'dark';
74
  };
75
 
76
  export const useFilterStore = create<FilterState>((set, get) => ({
 
114
  set({ theme });
115
  if (typeof window !== 'undefined') {
116
  localStorage.setItem('theme', theme);
117
+ localStorage.setItem('theme-preference-set', 'true');
118
  document.documentElement.setAttribute('data-theme', theme);
119
  }
120
  },
frontend/src/utils/api/graphApi.ts CHANGED
@@ -22,6 +22,7 @@ export interface NetworkGraphResponse {
22
 
23
  /**
24
  * Fetch family network graph for a specific model
 
25
  */
26
  export async function fetchFamilyNetwork(
27
  modelId: string,
@@ -46,33 +47,84 @@ export async function fetchFamilyNetwork(
46
 
47
  const url = `${API_BASE}/api/network/family/${encodeURIComponent(modelId)}${params.toString() ? '?' + params.toString() : ''}`;
48
 
49
- const response = await fetch(url);
50
- if (!response.ok) {
51
- throw new Error(`Failed to fetch network graph: ${response.statusText}`);
52
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- const data = await response.json();
 
 
 
 
 
 
 
 
 
55
 
56
- // Transform the response to match our types
57
- return {
58
- nodes: data.nodes || [],
59
- links: data.links || [],
60
- statistics: data.statistics,
61
- root_model: data.root_model || modelId,
62
- };
63
  }
64
 
65
  /**
66
  * Fetch full derivative network graph for ALL models in the database
 
 
 
67
  */
68
  export async function fetchFullDerivativeNetwork(
69
  options: {
70
  edgeTypes?: EdgeType[];
71
  includeEdgeAttributes?: boolean;
 
 
 
72
  } = {}
73
  ): Promise<NetworkGraphResponse> {
74
  // Default to false for performance with large graphs
75
- const { edgeTypes, includeEdgeAttributes = false } = options;
 
 
 
 
 
 
76
 
77
  const params = new URLSearchParams();
78
  if (edgeTypes && edgeTypes.length > 0) {
@@ -81,38 +133,92 @@ export async function fetchFullDerivativeNetwork(
81
  if (includeEdgeAttributes !== undefined) {
82
  params.append('include_edge_attributes', includeEdgeAttributes.toString());
83
  }
 
 
 
 
 
 
 
 
 
84
 
85
  const url = `${API_BASE}/api/network/full-derivatives${params.toString() ? '?' + params.toString() : ''}`;
86
 
87
- let response: Response;
88
- try {
89
- response = await fetch(url);
90
- } catch (error: any) {
91
- throw new Error(`Network error: ${error.message || 'Failed to connect to server'}`);
92
- }
93
 
94
- if (!response.ok) {
95
- let errorMessage = `Failed to fetch full derivative network: ${response.statusText}`;
96
  try {
97
- const errorData = await response.json();
98
- if (errorData.detail) {
99
- errorMessage = errorData.detail;
 
100
  }
101
- } catch {
102
- // If response is not JSON, use status text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  }
104
- throw new Error(errorMessage);
105
- }
106
 
107
- const data = await response.json();
 
 
 
 
 
 
 
 
 
108
 
109
- // Transform the response to match our types
110
- return {
111
- nodes: data.nodes || [],
112
- links: data.links || [],
113
- statistics: data.statistics,
114
- root_model: '', // No root model for full network
115
- };
116
  }
117
 
118
  /**
 
22
 
23
  /**
24
  * Fetch family network graph for a specific model
25
+ * Includes retry logic for rate limiting (429 errors)
26
  */
27
  export async function fetchFamilyNetwork(
28
  modelId: string,
 
47
 
48
  const url = `${API_BASE}/api/network/family/${encodeURIComponent(modelId)}${params.toString() ? '?' + params.toString() : ''}`;
49
 
50
+ // Retry logic for rate limiting
51
+ const maxRetries = 3;
52
+ const baseDelay = 2000;
53
+
54
+ for (let attempt = 0; attempt < maxRetries; attempt++) {
55
+ const response = await fetch(url);
56
+
57
+ // Handle 429 (Too Many Requests) with exponential backoff
58
+ if (response.status === 429) {
59
+ if (attempt === maxRetries - 1) {
60
+ let errorMessage = 'Rate limit exceeded. Please wait a moment and try again.';
61
+ try {
62
+ const errorData = await response.json();
63
+ if (errorData.detail) {
64
+ errorMessage = errorData.detail;
65
+ }
66
+ const retryAfter = response.headers.get('Retry-After');
67
+ if (retryAfter) {
68
+ errorMessage += ` Please wait ${retryAfter} seconds.`;
69
+ }
70
+ } catch {
71
+ // If response is not JSON, use default message
72
+ }
73
+ throw new Error(errorMessage);
74
+ }
75
+
76
+ const retryAfter = response.headers.get('Retry-After');
77
+ const delay = retryAfter
78
+ ? parseInt(retryAfter) * 1000
79
+ : baseDelay * Math.pow(2, attempt) + Math.random() * 1000;
80
+
81
+ console.warn(`Rate limit hit (429). Retrying in ${Math.round(delay / 1000)}s... (attempt ${attempt + 1}/${maxRetries})`);
82
+ await new Promise(resolve => setTimeout(resolve, delay));
83
+ continue;
84
+ }
85
+
86
+ if (!response.ok) {
87
+ throw new Error(`Failed to fetch network graph: ${response.statusText}`);
88
+ }
89
 
90
+ const data = await response.json();
91
+
92
+ // Transform the response to match our types
93
+ return {
94
+ nodes: data.nodes || [],
95
+ links: data.links || [],
96
+ statistics: data.statistics,
97
+ root_model: data.root_model || modelId,
98
+ };
99
+ }
100
 
101
+ // Should never reach here
102
+ throw new Error('Failed to fetch network graph after retries');
 
 
 
 
 
103
  }
104
 
105
  /**
106
  * Fetch full derivative network graph for ALL models in the database
107
+ * Includes retry logic for rate limiting (429 errors)
108
+ *
109
+ * Use minDownloads and maxNodes to reduce network size for better performance.
110
  */
111
  export async function fetchFullDerivativeNetwork(
112
  options: {
113
  edgeTypes?: EdgeType[];
114
  includeEdgeAttributes?: boolean;
115
+ minDownloads?: number;
116
+ maxNodes?: number;
117
+ usePrecomputed?: boolean;
118
  } = {}
119
  ): Promise<NetworkGraphResponse> {
120
  // Default to false for performance with large graphs
121
+ const {
122
+ edgeTypes,
123
+ includeEdgeAttributes = false,
124
+ minDownloads = 0,
125
+ maxNodes,
126
+ usePrecomputed = true
127
+ } = options;
128
 
129
  const params = new URLSearchParams();
130
  if (edgeTypes && edgeTypes.length > 0) {
 
133
  if (includeEdgeAttributes !== undefined) {
134
  params.append('include_edge_attributes', includeEdgeAttributes.toString());
135
  }
136
+ if (minDownloads > 0) {
137
+ params.append('min_downloads', minDownloads.toString());
138
+ }
139
+ if (maxNodes !== undefined) {
140
+ params.append('max_nodes', maxNodes.toString());
141
+ }
142
+ if (usePrecomputed !== undefined) {
143
+ params.append('use_precomputed', usePrecomputed.toString());
144
+ }
145
 
146
  const url = `${API_BASE}/api/network/full-derivatives${params.toString() ? '?' + params.toString() : ''}`;
147
 
148
+ // Retry logic for rate limiting
149
+ const maxRetries = 3;
150
+ const baseDelay = 2000; // Start with 2 seconds
 
 
 
151
 
152
+ for (let attempt = 0; attempt < maxRetries; attempt++) {
153
+ let response: Response;
154
  try {
155
+ response = await fetch(url);
156
+ } catch (error: any) {
157
+ if (attempt === maxRetries - 1) {
158
+ throw new Error(`Network error: ${error.message || 'Failed to connect to server'}`);
159
  }
160
+ // Wait before retrying
161
+ await new Promise(resolve => setTimeout(resolve, baseDelay * (attempt + 1)));
162
+ continue;
163
+ }
164
+
165
+ // Handle 429 (Too Many Requests) with exponential backoff
166
+ if (response.status === 429) {
167
+ if (attempt === maxRetries - 1) {
168
+ let errorMessage = 'Rate limit exceeded. Please wait a moment and try again.';
169
+ try {
170
+ const errorData = await response.json();
171
+ if (errorData.detail) {
172
+ errorMessage = errorData.detail;
173
+ }
174
+ // Check for Retry-After header
175
+ const retryAfter = response.headers.get('Retry-After');
176
+ if (retryAfter) {
177
+ errorMessage += ` Please wait ${retryAfter} seconds.`;
178
+ }
179
+ } catch {
180
+ // If response is not JSON, use default message
181
+ }
182
+ throw new Error(errorMessage);
183
+ }
184
+
185
+ // Calculate delay: exponential backoff with jitter
186
+ const retryAfter = response.headers.get('Retry-After');
187
+ const delay = retryAfter
188
+ ? parseInt(retryAfter) * 1000
189
+ : baseDelay * Math.pow(2, attempt) + Math.random() * 1000;
190
+
191
+ console.warn(`Rate limit hit (429). Retrying in ${Math.round(delay / 1000)}s... (attempt ${attempt + 1}/${maxRetries})`);
192
+ await new Promise(resolve => setTimeout(resolve, delay));
193
+ continue;
194
+ }
195
+
196
+ if (!response.ok) {
197
+ let errorMessage = `Failed to fetch full derivative network: ${response.statusText}`;
198
+ try {
199
+ const errorData = await response.json();
200
+ if (errorData.detail) {
201
+ errorMessage = errorData.detail;
202
+ }
203
+ } catch {
204
+ // If response is not JSON, use status text
205
+ }
206
+ throw new Error(errorMessage);
207
  }
 
 
208
 
209
+ const data = await response.json();
210
+
211
+ // Transform the response to match our types
212
+ return {
213
+ nodes: data.nodes || [],
214
+ links: data.links || [],
215
+ statistics: data.statistics,
216
+ root_model: '', // No root model for full network
217
+ };
218
+ }
219
 
220
+ // Should never reach here, but TypeScript needs it
221
+ throw new Error('Failed to fetch full derivative network after retries');
 
 
 
 
 
222
  }
223
 
224
  /**
precomputed_data/network_metadata.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "created_at": "2026-01-11T03:06:01.412270",
3
+ "version": "v1",
4
+ "nodes": 1860411,
5
+ "edges": 573653,
6
+ "include_edge_attributes": false,
7
+ "min_downloads": 0,
8
+ "max_nodes": null,
9
+ "file_size_mb": 281.78
10
+ }
requirements.txt CHANGED
@@ -7,3 +7,4 @@
7
  # Additional Space-specific requirements
8
  gradio>=4.0.0
9
 
 
 
7
  # Additional Space-specific requirements
8
  gradio>=4.0.0
9
 
10
+
upload_to_hf_dataset.py CHANGED
@@ -130,3 +130,4 @@ if __name__ == "__main__":
130
  token=args.token
131
  )
132
 
 
 
130
  token=args.token
131
  )
132
 
133
+