midah committed on
Commit
2ba2072
·
1 Parent(s): 2f01c6a

Add 3D force-directed graph with instanced rendering and pre-computed positions

Browse files

- Add ForceDirectedGraph3DInstanced.tsx for rendering millions of nodes efficiently
- Add precompute_force_layout.py script to generate positions offline
- Update backend to serve pre-computed positions in full-derivatives endpoint
- Update GraphPage to use instanced rendering for large graphs (>10k nodes)
- Include force_layout_3d.pkl with 1.86M pre-computed node positions

backend/api/main.py CHANGED
@@ -1618,7 +1618,8 @@ async def get_family_network(
1618
  @cached_response(ttl=3600, key_prefix="full_derivatives_network")
1619
  async def get_full_derivative_network(
1620
  edge_types: Optional[str] = Query(None, description="Comma-separated list of edge types to include (finetune,quantized,adapter,merge,parent). If None, includes all types."),
1621
- include_edge_attributes: bool = Query(False, description="Whether to include edge attributes (change in likes, downloads, etc.). Default False for performance.")
 
1622
  ):
1623
  """
1624
  Build full derivative relationship network for ALL models in the database.
@@ -1626,6 +1627,7 @@ async def get_full_derivative_network(
1626
  This computes over every single model in the database.
1627
 
1628
  Note: Edge attributes are disabled by default for performance with large datasets.
 
1629
  """
1630
  if df is None:
1631
  raise DataNotLoadedError()
@@ -1652,10 +1654,26 @@ async def get_full_derivative_network(
1652
  build_time = time.time() - start_time
1653
  logger.info(f"Graph built in {build_time:.2f}s: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges")
1654
 
1655
- # Build nodes list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1656
  nodes = []
1657
  for node_id, attrs in graph.nodes(data=True):
1658
- nodes.append({
1659
  "id": node_id,
1660
  "title": attrs.get('title', node_id),
1661
  "freq": attrs.get('freq', 0),
@@ -1663,7 +1681,16 @@ async def get_full_derivative_network(
1663
  "downloads": attrs.get('downloads', 0),
1664
  "library": attrs.get('library', ''),
1665
  "pipeline": attrs.get('pipeline', '')
1666
- })
 
 
 
 
 
 
 
 
 
1667
 
1668
  logger.info(f"Processed {len(nodes):,} nodes")
1669
 
 
1618
  @cached_response(ttl=3600, key_prefix="full_derivatives_network")
1619
  async def get_full_derivative_network(
1620
  edge_types: Optional[str] = Query(None, description="Comma-separated list of edge types to include (finetune,quantized,adapter,merge,parent). If None, includes all types."),
1621
+ include_edge_attributes: bool = Query(False, description="Whether to include edge attributes (change in likes, downloads, etc.). Default False for performance."),
1622
+ include_positions: bool = Query(True, description="Whether to include pre-computed 3D positions for each node. Default True for faster rendering.")
1623
  ):
1624
  """
1625
  Build full derivative relationship network for ALL models in the database.
 
1627
  This computes over every single model in the database.
1628
 
1629
  Note: Edge attributes are disabled by default for performance with large datasets.
1630
+ If pre-computed positions exist, they will be included in the response.
1631
  """
1632
  if df is None:
1633
  raise DataNotLoadedError()
 
1654
  build_time = time.time() - start_time
1655
  logger.info(f"Graph built in {build_time:.2f}s: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges")
1656
 
1657
+ # Load pre-computed positions if available
1658
+ precomputed_positions = {}
1659
+ if include_positions:
1660
+ try:
1661
+ backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
1662
+ root_dir = os.path.dirname(backend_dir)
1663
+ layout_file = os.path.join(root_dir, "precomputed_data", "force_layout_3d.pkl")
1664
+
1665
+ if os.path.exists(layout_file):
1666
+ with open(layout_file, 'rb') as f:
1667
+ layout_data = pickle.load(f)
1668
+ precomputed_positions = layout_data.get('positions', {})
1669
+ logger.info(f"Loaded {len(precomputed_positions):,} pre-computed positions")
1670
+ except Exception as e:
1671
+ logger.warning(f"Could not load pre-computed positions: {e}")
1672
+
1673
+ # Build nodes list with optional pre-computed positions
1674
  nodes = []
1675
  for node_id, attrs in graph.nodes(data=True):
1676
+ node_data = {
1677
  "id": node_id,
1678
  "title": attrs.get('title', node_id),
1679
  "freq": attrs.get('freq', 0),
 
1681
  "downloads": attrs.get('downloads', 0),
1682
  "library": attrs.get('library', ''),
1683
  "pipeline": attrs.get('pipeline', '')
1684
+ }
1685
+
1686
+ # Add pre-computed position if available
1687
+ if node_id in precomputed_positions:
1688
+ pos = precomputed_positions[node_id]
1689
+ node_data['x'] = pos[0]
1690
+ node_data['y'] = pos[1]
1691
+ node_data['z'] = pos[2]
1692
+
1693
+ nodes.append(node_data)
1694
 
1695
  logger.info(f"Processed {len(nodes):,} nodes")
1696
 
backend/scripts/precompute_force_layout.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pre-compute force-directed layout positions for the full model network.
3
+ Uses networkx spring_layout by default, or ForceAtlas2 (with Barnes-Hut optimization) when --algorithm fa2 is selected and the fa2 package is installed.
4
+
5
+ This script generates x, y, z coordinates for all nodes so the frontend
6
+ doesn't need to compute force simulation in real-time.
7
+
8
+ Usage:
9
+ python precompute_force_layout.py [--output force_layout_3d.pkl] [--iterations N] [--algorithm {spring,fa2}] [--seed N] [--graph PATH]
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ import time
15
+ import pickle
16
+ import argparse
17
+ import logging
18
+ from pathlib import Path
19
+ from typing import Dict, Tuple, Optional
20
+ import numpy as np
21
+
22
+ # Add backend to path
23
+ sys.path.insert(0, str(Path(__file__).parent.parent))
24
+
25
+ logging.basicConfig(
26
+ level=logging.INFO,
27
+ format='%(asctime)s - %(levelname)s - %(message)s'
28
+ )
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
def load_model_data() -> 'pd.DataFrame':
    """Load the model table from the first data source found on disk.

    Locations are resolved relative to the repository root, taken to be the
    directory two levels above the one containing this script. Sources are
    tried in this order:

    1. a ``*.parquet`` file in ``precomputed_data/``
    2. ``precomputed_data/models.csv``
    3. a ``*.csv`` file in ``data/`` whose name contains "model"
       (case-insensitive)

    When several files match, the first in sorted order is used, so repeated
    runs pick the same file regardless of directory listing order.

    Returns:
        The loaded pandas DataFrame.

    Raises:
        FileNotFoundError: If none of the sources is found.
    """
    import pandas as pd

    root_dir = Path(__file__).parent.parent.parent

    # Try precomputed data first
    precomputed_dir = root_dir / "precomputed_data"
    if precomputed_dir.exists():
        # glob() yields entries in directory order; sort for a stable choice.
        parquet_files = sorted(precomputed_dir.glob("*.parquet"))
        if parquet_files:
            logger.info(f"Loading from precomputed parquet: {parquet_files[0]}")
            return pd.read_parquet(parquet_files[0])

    # Try CSV data
    csv_path = precomputed_dir / "models.csv"
    if csv_path.exists():
        logger.info(f"Loading from CSV: {csv_path}")
        return pd.read_csv(csv_path)

    # Try data directory
    data_dir = root_dir / "data"
    if data_dir.exists():
        for csv_file in sorted(data_dir.glob("*.csv")):
            if "model" in csv_file.name.lower():
                logger.info(f"Loading from {csv_file}")
                return pd.read_csv(csv_file)

    raise FileNotFoundError("No model data found")
63
+
64
+
65
def load_existing_graph(graph_path: Optional[str] = None) -> Optional['nx.DiGraph']:
    """Load a previously pickled graph, if one can be found.

    Args:
        graph_path: Optional path to a pickle file. When given and present on
            disk it is loaded directly. When given but missing, a warning is
            logged and the default locations are searched instead.

    Returns:
        The unpickled object (expected to be a networkx graph), or ``None``
        when no graph file exists at ``graph_path`` or any default location.

    Note:
        ``pickle.load`` can execute arbitrary code contained in the file.
        Only point this at graph files from a trusted source.
    """
    if graph_path:
        if Path(graph_path).exists():
            logger.info(f"Loading existing graph from {graph_path}")
            with open(graph_path, 'rb') as f:
                return pickle.load(f)
        # Previously a mistyped path was ignored without any message.
        logger.warning(f"Graph file not found at {graph_path}, searching default locations")

    # Search for graph file
    script_path = Path(__file__)
    search_paths = [
        script_path.parent.parent.parent / "ai-ecosystem" / "data" / "ai_ecosystem_graph.pkl",
        script_path.parent.parent.parent.parent / "ai-ecosystem" / "data" / "ai_ecosystem_graph.pkl",
        Path.home() / "ai-ecosystem-v2" / "ai-ecosystem" / "data" / "ai_ecosystem_graph.pkl",
    ]

    for path in search_paths:
        if path.exists():
            logger.info(f"Found existing graph at {path}")
            with open(path, 'rb') as f:
                return pickle.load(f)

    return None
88
+
89
+
90
def build_network_graph(df: 'pd.DataFrame') -> 'nx.DiGraph':
    """Build a directed parent -> derivative graph from the model dataframe.

    Every row with a non-empty model id (column ``model_id``, falling back to
    ``modelId``) becomes a node carrying ``downloads``, ``likes``,
    ``library`` and ``pipeline`` attributes. A row whose parent (column
    ``parent_model``, falling back to ``base_model``) is non-null and
    non-blank contributes an edge ``parent -> model`` with
    ``edge_type='derivative'``, but only if the parent is itself a node.

    Args:
        df: Model table, one row per model.

    Returns:
        The populated ``networkx.DiGraph``.
    """
    import networkx as nx
    # pd.notna is used below; this script only imports pandas inside
    # functions, so without this import the call raised NameError.
    import pandas as pd

    logger.info(f"Building network from {len(df):,} models...")
    G = nx.DiGraph()

    # Single pass over the rows: add nodes now and remember parent links.
    # Edges are added afterwards because a parent may appear after its child.
    candidate_edges = []
    for _, row in df.iterrows():
        model_id = str(row.get('model_id', row.get('modelId', '')))
        if not model_id:
            continue

        G.add_node(
            model_id,
            downloads=row.get('downloads', 0),
            likes=row.get('likes', 0),
            library=row.get('library_name', row.get('library', '')),
            pipeline=row.get('pipeline_tag', ''),
        )

        parent_id = row.get('parent_model', row.get('base_model', None))
        if pd.notna(parent_id) and str(parent_id).strip() and str(parent_id) != 'nan':
            candidate_edges.append((str(parent_id).strip(), model_id))

    # Add edges based on parent relationships
    edge_count = 0
    for parent_id, model_id in candidate_edges:
        if parent_id in G:
            G.add_edge(parent_id, model_id, edge_type='derivative')
            edge_count += 1

    logger.info(f"Network: {G.number_of_nodes():,} nodes, {edge_count:,} edges")
    return G
127
+
128
+
129
def compute_force_layout_3d(
    G: 'nx.Graph',
    iterations: int = 100,
    seed: int = 42
) -> Dict[str, Tuple[float, float, float]]:
    """
    Compute a 3D force-directed layout using networkx ``spring_layout``.

    Graphs with at most 100,000 nodes are laid out in one ``spring_layout``
    call. Larger graphs are split into connected components (a directed graph
    is treated as undirected); each component is laid out separately and
    shifted along the x axis so components do not overlap. Single-node
    components are scattered randomly around the current x offset.

    If ``spring_layout`` raises, the affected nodes get random positions
    instead. All random placements draw from one generator seeded with
    ``seed``, so a given graph and seed produce the same random placements.

    Args:
        G: Graph to lay out.
        iterations: Maximum number of spring-layout iterations. Large
            components use fewer (never below 20, unless ``iterations``
            itself is lower).
        seed: Seed for the layout and for random fallback placement.

    Returns:
        Mapping of node to an ``(x, y, z)`` tuple of plain Python floats.
        Empty when the graph has no nodes.
    """
    import networkx as nx

    n_nodes = G.number_of_nodes()
    logger.info(f"Computing 3D layout for {n_nodes:,} nodes...")

    if n_nodes == 0:
        return {}

    start_time = time.time()

    # RandomState(seed) produces the same stream as np.random.seed(seed)
    # followed by np.random.randn(), without mutating NumPy's global state.
    rng = np.random.RandomState(seed)

    def random_point(x_center, x_spread, yz_spread):
        """Return a Gaussian-scattered point around (x_center, 0, 0)."""
        return (
            float(x_center + rng.randn() * x_spread),
            float(rng.randn() * yz_spread),
            float(rng.randn() * yz_spread),
        )

    G_undirected = G.to_undirected() if isinstance(G, nx.DiGraph) else G
    positions = {}

    if n_nodes > 100000:
        logger.info("Large graph detected - using optimized approach...")

        # Largest components first so they sit nearest the origin.
        components = sorted(nx.connected_components(G_undirected), key=len, reverse=True)
        logger.info(f"Found {len(components):,} connected components")

        offset_x = 0.0
        for i, component in enumerate(components):
            size = len(component)

            if size < 2:
                # Isolated nodes - place randomly
                for node in component:
                    positions[node] = random_point(offset_x, 10, 100)
                continue

            subgraph = G_undirected.subgraph(component)

            # Use spring layout with reduced iterations for large components
            iter_count = min(iterations, max(20, 100 - size // 10000))

            logger.info(f"  Component {i+1}/{len(components)}: {size:,} nodes, {iter_count} iterations")

            try:
                layout = nx.spring_layout(
                    subgraph,
                    dim=3,
                    k=1.0 / np.sqrt(size),
                    iterations=iter_count,
                    seed=seed + i,
                    scale=100 * np.log10(max(size, 10))
                )
            except Exception as e:
                logger.warning(f"Layout failed for component {i}: {e}")
                # Fallback: random positions (the x offset is not advanced)
                for node in component:
                    positions[node] = random_point(offset_x, 50, 50)
                continue

            # Apply offset to separate components; store plain floats so the
            # pickled result holds no NumPy scalar objects.
            for node, (x, y, z) in layout.items():
                positions[node] = (float(x + offset_x), float(y), float(z))

            # Move offset for next component
            offset_x += float(300 * np.log10(max(size, 10)))
    else:
        # Standard approach for smaller graphs
        try:
            layout = nx.spring_layout(
                G_undirected,
                dim=3,
                k=2.0 / np.sqrt(n_nodes),
                iterations=iterations,
                seed=seed,
                scale=200
            )
            positions = {
                node: tuple(float(c) for c in pos)
                for node, pos in layout.items()
            }
        except Exception as e:
            logger.warning(f"Spring layout failed: {e}, using random positions")
            positions = {node: random_point(0.0, 100, 100) for node in G.nodes()}

    elapsed = time.time() - start_time
    logger.info(f"Layout computed in {elapsed:.1f}s")

    return positions
236
+
237
+
238
def compute_force_layout_fa2(
    G: 'nx.Graph',
    iterations: int = 100,
    seed: int = 42
) -> Dict[str, Tuple[float, float, float]]:
    """
    Compute a layout using the ForceAtlas2 algorithm from the ``fa2`` package.

    ForceAtlas2 produces 2D positions, which are scaled by 100. The z
    coordinate is derived from each node's ``downloads`` attribute
    (``log10(downloads) * 10``) plus Gaussian jitter seeded with ``seed``.
    Missing, non-numeric, NaN or sub-1 download counts are treated as 1.

    Falls back to ``compute_force_layout_3d`` if ``fa2`` is not installed.

    Args:
        G: Graph to lay out; a directed graph is converted to undirected.
        iterations: Number of ForceAtlas2 iterations.
        seed: Seed for the z jitter (and for the fallback layout).

    Returns:
        Mapping of node to an ``(x, y, z)`` tuple of floats. Empty when the
        graph has no nodes.
    """
    # Imported before first use: the original imported networkx only after
    # referencing nx.DiGraph, which raised UnboundLocalError.
    import networkx as nx

    # Only the fa2 import is guarded, so unrelated ImportErrors are no longer
    # misreported as "fa2 not installed".
    try:
        from fa2 import ForceAtlas2
    except ImportError:
        logger.warning("fa2 not installed, falling back to spring_layout")
        return compute_force_layout_3d(G, iterations, seed)

    n_nodes = G.number_of_nodes()
    logger.info(f"Computing FA2 layout for {n_nodes:,} nodes...")

    if n_nodes == 0:
        return {}

    # Convert to undirected for layout
    G_layout = G.to_undirected() if isinstance(G, nx.DiGraph) else G

    # Initialize ForceAtlas2
    forceatlas2 = ForceAtlas2(
        outboundAttractionDistribution=True,
        linLogMode=False,
        adjustSizes=False,
        edgeWeightInfluence=1.0,
        jitterTolerance=1.0,
        barnesHutOptimize=True,
        barnesHutTheta=1.2,
        multiThreaded=False,
        scalingRatio=2.0,
        strongGravityMode=False,
        gravity=1.0,
        verbose=False
    )

    # Compute 2D positions
    positions_2d = forceatlas2.forceatlas2_networkx_layout(
        G_layout,
        iterations=iterations
    )

    # Add 3rd dimension based on hierarchy/properties.
    # RandomState(seed) gives the same stream as np.random.seed(seed) without
    # mutating NumPy's global state.
    rng = np.random.RandomState(seed)
    positions = {}
    for node, (x, y) in positions_2d.items():
        # Z based on downloads (popular models higher)
        try:
            downloads = float(G.nodes[node].get('downloads', 0))
        except (TypeError, ValueError):
            downloads = 1.0
        if not downloads >= 1:  # also true for NaN, which would poison z
            downloads = 1.0
        z = np.log10(downloads) * 10 + rng.randn() * 5
        positions[node] = (float(x * 100), float(y * 100), float(z))

    return positions
299
+
300
+
301
def save_layout(
    positions: Dict[str, Tuple[float, float, float]],
    output_path: str,
    graph: 'nx.Graph' = None
):
    """Pickle the layout and summary metadata to ``output_path``.

    The pickled dict holds ``positions``, ``n_nodes`` and a local-time
    ``timestamp``; ``n_edges`` is added when ``graph`` is given, and
    ``bounds`` (per-axis min/max) when ``positions`` is non-empty.

    Args:
        positions: Mapping of node to ``(x, y, z)``.
        output_path: Destination file; overwritten if it exists.
        graph: Optional graph, used only for its edge count.
    """
    payload = {
        'positions': positions,
        'n_nodes': len(positions),
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    }

    if graph is not None:
        payload['n_edges'] = graph.number_of_edges()

    # Calculate bounds, one axis at a time
    if positions:
        extents = {}
        for axis, label in enumerate(('x', 'y', 'z')):
            coords = [point[axis] for point in positions.values()]
            extents[label + '_min'] = min(coords)
            extents[label + '_max'] = max(coords)
        payload['bounds'] = extents

    with open(output_path, 'wb') as out_file:
        pickle.dump(payload, out_file)

    logger.info(f"Saved layout to {output_path}")
    logger.info(f"  Nodes: {len(positions):,}")
    b = payload.get('bounds')
    if b is not None:
        logger.info(f"  Bounds: X[{b['x_min']:.1f}, {b['x_max']:.1f}], Y[{b['y_min']:.1f}, {b['y_max']:.1f}], Z[{b['z_min']:.1f}, {b['z_max']:.1f}]")
336
+
337
+
338
def main():
    """Command-line entry point: obtain a graph, lay it out, save the result.

    The graph comes from an existing pickle when one is found; otherwise it
    is built from the model data. The layout is written into the
    ``precomputed_data`` directory at the repository root (created if
    missing) under the name given by ``--output``.
    """
    cli = argparse.ArgumentParser(description='Pre-compute force-directed layout')
    cli.add_argument('--output', '-o', type=str, default='force_layout_3d.pkl',
                     help='Output pickle file path')
    cli.add_argument('--iterations', '-i', type=int, default=100,
                     help='Number of layout iterations')
    cli.add_argument('--algorithm', '-a', choices=['spring', 'fa2'], default='spring',
                     help='Layout algorithm to use')
    cli.add_argument('--seed', '-s', type=int, default=42,
                     help='Random seed for reproducibility')
    cli.add_argument('--graph', '-g', type=str, default=None,
                     help='Path to existing networkx graph pickle file')
    options = cli.parse_args()

    # Determine output path
    precomputed_dir = Path(__file__).parent.parent.parent / "precomputed_data"
    precomputed_dir.mkdir(exist_ok=True)
    output_path = precomputed_dir / options.output

    banner = "=" * 60
    logger.info(banner)
    logger.info("Pre-computing Force-Directed Layout")
    logger.info(banner)

    # Try to load existing graph first (faster)
    G = load_existing_graph(options.graph)
    if G is not None:
        logger.info(f"Using existing graph: {G.number_of_nodes():,} nodes, {G.number_of_edges():,} edges")
    else:
        # Load data and build graph
        df = load_model_data()
        logger.info(f"Loaded {len(df):,} models")
        G = build_network_graph(df)

    # Compute layout
    layout_fn = compute_force_layout_fa2 if options.algorithm == 'fa2' else compute_force_layout_3d
    positions = layout_fn(G, options.iterations, options.seed)

    # Save
    save_layout(positions, str(output_path), G)

    logger.info(banner)
    logger.info("Done!")
    logger.info(banner)


if __name__ == "__main__":
    main()
392
+
frontend/src/components/visualizations/ForceDirectedGraph3DInstanced.tsx ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Highly optimized 3D Force-directed graph visualization using instanced rendering.
3
+ * Designed to handle millions of nodes efficiently by:
4
+ * 1. Using pre-computed positions from server
5
+ * 2. Instanced mesh rendering (single draw call for all nodes)
6
+ * 3. GPU-based picking for interactions
7
+ * 4. Level-of-detail (LOD) for labels
8
+ */
9
+ import React, { useMemo, useRef, useEffect, useState, useCallback } from 'react';
10
+ import { Canvas, useFrame, useThree } from '@react-three/fiber';
11
+ import { OrbitControls } from '@react-three/drei';
12
+ import * as THREE from 'three';
13
+ import { GraphNode, GraphLink, EdgeType } from './ForceDirectedGraph';
14
+ import './ForceDirectedGraph.css';
15
+
16
+ export interface ForceDirectedGraph3DInstancedProps {
17
+ width: number;
18
+ height: number;
19
+ nodes: GraphNode[];
20
+ links: GraphLink[];
21
+ onNodeClick?: (node: GraphNode) => void;
22
+ onNodeHover?: (node: GraphNode | null) => void;
23
+ selectedNodeId?: string | null;
24
+ enabledEdgeTypes?: Set<EdgeType>;
25
+ showLabels?: boolean;
26
+ maxVisibleNodes?: number;
27
+ maxVisibleEdges?: number;
28
+ }
29
+
30
+ // Color scheme for different libraries
31
+ const LIBRARY_COLORS: Record<string, string> = {
32
+ transformers: '#3b82f6', // Blue
33
+ pytorch: '#ef4444', // Red
34
+ tensorflow: '#f97316', // Orange
35
+ diffusers: '#8b5cf6', // Purple
36
+ 'sentence-transformers': '#10b981', // Green
37
+ timm: '#06b6d4', // Cyan
38
+ peft: '#ec4899', // Pink
39
+ default: '#6b7280', // Gray
40
+ };
41
+
42
+ // Color scheme for different edge types
43
+ const EDGE_COLORS: Record<EdgeType, THREE.Color> = {
44
+ finetune: new THREE.Color('#3b82f6'), // Blue
45
+ quantized: new THREE.Color('#10b981'), // Green
46
+ adapter: new THREE.Color('#f59e0b'), // Orange
47
+ merge: new THREE.Color('#8b5cf6'), // Purple
48
+ parent: new THREE.Color('#6b7280'), // Gray
49
+ };
50
+
51
/**
 * Map a node's library name (case-insensitive) to its display colour.
 * Unknown or missing libraries use the default grey.
 */
function getNodeColor(library: string | undefined): THREE.Color {
  const key = library?.toLowerCase() || '';
  const hex = LIBRARY_COLORS[key] || LIBRARY_COLORS.default;
  return new THREE.Color(hex);
}
58
+
59
/**
 * Node radius from download count: a base of 0.3 plus 0.15 per factor of
 * ten in downloads. Counts below 1 are treated as 1.
 */
function getNodeSize(downloads: number): number {
  const decades = Math.log10(Math.max(downloads, 1));
  return 0.3 + decades * 0.15;
}
65
+
66
/**
 * Instanced nodes component - renders all nodes in a single draw call.
 *
 * When there are more nodes than `maxVisible`, only the `maxVisible` nodes
 * with the most downloads are drawn.
 *
 * Hover and click are driven by react-three-fiber pointer events, which
 * report the hit instance as `event.instanceId`. This replaces a raycast in
 * `useFrame`: the parent canvas uses `frameloop="demand"`, so frame callbacks
 * do not run while the scene is idle, and a raycast over every instance on
 * every rendered frame is expensive.
 *
 * Instance matrices and base colours are written once per node set. Hover and
 * selection changes recolour only the affected instances and then request a
 * frame with `invalidate()`.
 */
function InstancedNodes({
  nodes,
  selectedNodeId,
  onNodeClick,
  onNodeHover,
  maxVisible = 500000,
}: {
  nodes: GraphNode[];
  selectedNodeId?: string | null;
  onNodeClick?: (node: GraphNode) => void;
  onNodeHover?: (node: GraphNode | null) => void;
  maxVisible?: number;
}) {
  const meshRef = useRef<THREE.InstancedMesh>(null);
  // Instance indices whose colour is currently overridden by a highlight.
  const highlightedRef = useRef<number[]>([]);
  // Last hover index reported to the parent; lets handlers skip duplicates.
  const hoveredRef = useRef<number | null>(null);
  const { invalidate } = useThree();
  const [hoveredIndex, setHoveredIndex] = useState<number | null>(null);

  // Limit nodes for performance
  const visibleNodes = useMemo(() => {
    if (nodes.length <= maxVisible) return nodes;
    // Sort by downloads and take top N
    return [...nodes]
      .sort((a, b) => (b.downloads || 0) - (a.downloads || 0))
      .slice(0, maxVisible);
  }, [nodes, maxVisible]);

  // Node ID to instance index, used to find the selected node's instance
  const nodeIndexMap = useMemo(() => {
    const map = new Map<string, number>();
    visibleNodes.forEach((node, i) => map.set(node.id, i));
    return map;
  }, [visibleNodes]);

  // Write transforms and base colours whenever the node set changes.
  // A single scratch matrix is reused instead of allocating one per node.
  useEffect(() => {
    const mesh = meshRef.current;
    if (!mesh) return;

    const scratch = new THREE.Matrix4();
    visibleNodes.forEach((node, i) => {
      const size = getNodeSize(node.downloads || 0);
      scratch.makeScale(size, size, size);
      scratch.setPosition(node.x || 0, node.y || 0, node.z || 0);
      mesh.setMatrixAt(i, scratch);
      mesh.setColorAt(i, getNodeColor(node.library));
    });

    // Every colour was just reset, so no highlight override remains.
    highlightedRef.current = [];
    mesh.instanceMatrix.needsUpdate = true;
    if (mesh.instanceColor) mesh.instanceColor.needsUpdate = true;
    invalidate();
  }, [visibleNodes, invalidate]);

  // Apply hover/selection highlights, touching only the instances involved.
  // Declared after the effect above so it runs second when both fire.
  useEffect(() => {
    const mesh = meshRef.current;
    if (!mesh) return;

    // Restore the base colour of previously highlighted instances.
    highlightedRef.current.forEach((i) => {
      const node = visibleNodes[i];
      if (node) mesh.setColorAt(i, getNodeColor(node.library));
    });

    const highlighted: number[] = [];
    if (hoveredIndex !== null && hoveredIndex < visibleNodes.length) {
      mesh.setColorAt(hoveredIndex, new THREE.Color('#fbbf24')); // Yellow for hovered
      highlighted.push(hoveredIndex);
    }
    // Selection is applied last so it wins over hover on the same node.
    const selectedIndex =
      selectedNodeId != null ? nodeIndexMap.get(selectedNodeId) : undefined;
    if (selectedIndex !== undefined) {
      mesh.setColorAt(selectedIndex, new THREE.Color('#ef4444')); // Red for selected
      highlighted.push(selectedIndex);
    }
    highlightedRef.current = highlighted;

    if (mesh.instanceColor) mesh.instanceColor.needsUpdate = true;
    invalidate();
  }, [visibleNodes, nodeIndexMap, selectedNodeId, hoveredIndex, invalidate]);

  // Record a hover change and notify the parent once per change.
  const updateHover = useCallback(
    (index: number | null) => {
      if (hoveredRef.current === index) return;
      hoveredRef.current = index;
      setHoveredIndex(index);
      if (onNodeHover) {
        onNodeHover(index !== null ? visibleNodes[index] ?? null : null);
      }
    },
    [onNodeHover, visibleNodes]
  );

  const handlePointerMove = useCallback(
    (event: { instanceId?: number }) => {
      updateHover(event.instanceId ?? null);
    },
    [updateHover]
  );

  const handlePointerOut = useCallback(() => {
    updateHover(null);
  }, [updateHover]);

  // Use the instance reported by the click itself, falling back to the
  // hovered instance if the event carries none.
  const handleClick = useCallback(
    (event: { instanceId?: number }) => {
      const index = event.instanceId ?? hoveredRef.current;
      if (index === null || index === undefined) return;
      const node = visibleNodes[index];
      if (node && onNodeClick) {
        onNodeClick(node);
      }
    },
    [onNodeClick, visibleNodes]
  );

  if (visibleNodes.length === 0) return null;

  return (
    <instancedMesh
      ref={meshRef}
      args={[undefined, undefined, visibleNodes.length]}
      onClick={handleClick}
      onPointerMove={handlePointerMove}
      onPointerOut={handlePointerOut}
      frustumCulled={false}
    >
      <sphereGeometry args={[1, 8, 8]} />
      {/* Colours come from setColorAt (instanceColor). vertexColors is left
          off because the sphere geometry has no per-vertex color attribute. */}
      <meshStandardMaterial
        roughness={0.4}
        metalness={0.1}
      />
    </instancedMesh>
  );
}
205
+
206
/**
 * Edges component using line segments.
 *
 * Links are filtered by `enabledEdgeTypes` (when non-empty), capped at
 * `maxVisible`, and drawn as one `LineSegments` object coloured by each
 * link's `edge_type`. Links whose source or target is not in `nodes` are
 * skipped.
 */
function Edges({
  nodes,
  links,
  enabledEdgeTypes,
  maxVisible = 100000,
}: {
  nodes: GraphNode[];
  links: GraphLink[];
  enabledEdgeTypes?: Set<EdgeType>;
  maxVisible?: number;
}) {
  // Create node lookup map
  const nodeMap = useMemo(() => {
    const map = new Map<string, GraphNode>();
    nodes.forEach(node => map.set(node.id, node));
    return map;
  }, [nodes]);

  // Filter and limit links
  const visibleLinks = useMemo(() => {
    let filtered = links;

    if (enabledEdgeTypes && enabledEdgeTypes.size > 0) {
      filtered = links.filter(link => {
        const linkTypes = link.edge_types || [link.edge_type];
        return linkTypes.some(type => enabledEdgeTypes.has(type));
      });
    }

    if (filtered.length > maxVisible) {
      return filtered.slice(0, maxVisible);
    }

    return filtered;
  }, [links, enabledEdgeTypes, maxVisible]);

  // Build geometry. Buffers are sized for the worst case (two vertices of
  // three floats per link) and trimmed afterwards, which avoids growing
  // large number[] arrays by pushing.
  const geometry = useMemo(() => {
    const positions = new Float32Array(visibleLinks.length * 6);
    const colors = new Float32Array(visibleLinks.length * 6);
    let offset = 0;

    visibleLinks.forEach(link => {
      const sourceId = typeof link.source === 'string' ? link.source : link.source?.id;
      const targetId = typeof link.target === 'string' ? link.target : link.target?.id;

      const source = nodeMap.get(sourceId || '');
      const target = nodeMap.get(targetId || '');

      if (!source || !target) return;

      // Source position
      positions[offset] = source.x || 0;
      positions[offset + 1] = source.y || 0;
      positions[offset + 2] = source.z || 0;
      // Target position
      positions[offset + 3] = target.x || 0;
      positions[offset + 4] = target.y || 0;
      positions[offset + 5] = target.z || 0;

      // Edge color based on type
      const edgeType = link.edge_type || 'parent';
      const color = EDGE_COLORS[edgeType] || EDGE_COLORS.parent;

      // Same color for both vertices
      colors[offset] = color.r;
      colors[offset + 1] = color.g;
      colors[offset + 2] = color.b;
      colors[offset + 3] = color.r;
      colors[offset + 4] = color.g;
      colors[offset + 5] = color.b;

      offset += 6;
    });

    const geom = new THREE.BufferGeometry();
    // slice() drops the unused tail left by skipped links.
    geom.setAttribute('position', new THREE.BufferAttribute(positions.slice(0, offset), 3));
    geom.setAttribute('color', new THREE.BufferAttribute(colors.slice(0, offset), 3));

    return geom;
  }, [visibleLinks, nodeMap]);

  // Release the geometry's buffers when it is replaced or on unmount;
  // otherwise each change of links or filters leaves the old one allocated.
  useEffect(() => {
    return () => {
      geometry.dispose();
    };
  }, [geometry]);

  if (visibleLinks.length === 0) return null;

  return (
    <lineSegments geometry={geometry}>
      <lineBasicMaterial
        vertexColors
        transparent
        opacity={0.3}
        depthWrite={false}
      />
    </lineSegments>
  );
}
295
+
296
/**
 * Scene contents: the edge line segments followed by the instanced nodes.
 * Takes the full component props but uses only the graph data, callbacks,
 * selection, edge filter and visibility limits.
 */
function Scene(props: ForceDirectedGraph3DInstancedProps) {
  const {
    nodes,
    links,
    onNodeClick,
    onNodeHover,
    selectedNodeId,
    enabledEdgeTypes,
    maxVisibleNodes = 500000,
    maxVisibleEdges = 100000,
  } = props;

  const edgeLayer = (
    <Edges
      nodes={nodes}
      links={links}
      enabledEdgeTypes={enabledEdgeTypes}
      maxVisible={maxVisibleEdges}
    />
  );

  const nodeLayer = (
    <InstancedNodes
      nodes={nodes}
      selectedNodeId={selectedNodeId}
      onNodeClick={onNodeClick}
      onNodeHover={onNodeHover}
      maxVisible={maxVisibleNodes}
    />
  );

  return (
    <>
      {edgeLayer}
      {nodeLayer}
    </>
  );
}
327
+
328
/**
 * Main component with Canvas wrapper.
 *
 * Computes the bounding box of all node positions to place the camera and
 * configure the orbit controls, renders the scene on demand, and overlays
 * node/edge counts. `showLabels` is accepted for interface compatibility
 * but is not used by this component.
 */
export default function ForceDirectedGraph3DInstanced({
  width,
  height,
  nodes,
  links,
  onNodeClick,
  onNodeHover,
  selectedNodeId,
  enabledEdgeTypes,
  showLabels = false,
  maxVisibleNodes = 500000,
  maxVisibleEdges = 100000,
}: ForceDirectedGraph3DInstancedProps) {
  // Calculate bounds for camera positioning
  const bounds = useMemo(() => {
    if (nodes.length === 0) {
      return { center: [0, 0, 0] as [number, number, number], radius: 100 };
    }

    let minX = Infinity, maxX = -Infinity;
    let minY = Infinity, maxY = -Infinity;
    let minZ = Infinity, maxZ = -Infinity;

    // Scan every node. The earlier sampling still walked the whole array
    // (via filter) and could miss the extreme nodes, giving bounds that
    // were too small for the camera range.
    nodes.forEach(node => {
      const x = node.x || 0;
      const y = node.y || 0;
      const z = node.z || 0;
      if (x < minX) minX = x;
      if (x > maxX) maxX = x;
      if (y < minY) minY = y;
      if (y > maxY) maxY = y;
      if (z < minZ) minZ = z;
      if (z > maxZ) maxZ = z;
    });

    const center: [number, number, number] = [
      (minX + maxX) / 2,
      (minY + maxY) / 2,
      (minZ + maxZ) / 2,
    ];
    // Falls back to 100 when all nodes share one position (extent 0).
    const radius = Math.max(
      maxX - minX,
      maxY - minY,
      maxZ - minZ
    ) / 2 || 100;

    return { center, radius };
  }, [nodes]);

  if (nodes.length === 0) {
    return (
      <div className="force-directed-graph-container">
        <div className="graph-empty">No nodes to display</div>
      </div>
    );
  }

  return (
    <div className="force-directed-graph-container" style={{ width, height }}>
      <Canvas
        dpr={[1, 1.5]}
        gl={{
          antialias: true,
          alpha: false,
          powerPreference: 'high-performance',
          stencil: false,
          depth: true,
        }}
        camera={{
          position: [
            bounds.center[0] + bounds.radius * 1.5,
            bounds.center[1] + bounds.radius * 1.5,
            bounds.center[2] + bounds.radius * 1.5,
          ],
          fov: 45,
          near: 0.1,
          far: bounds.radius * 20,
        }}
        frameloop="demand"
      >
        <color attach="background" args={['#1a1a1a']} />

        <OrbitControls
          target={bounds.center}
          enableDamping={true}
          dampingFactor={0.05}
          minDistance={bounds.radius * 0.1}
          maxDistance={bounds.radius * 5}
          makeDefault
        />

        <ambientLight intensity={0.8} />
        <directionalLight position={[1, 1, 1]} intensity={0.5} />

        <Scene
          nodes={nodes}
          links={links}
          onNodeClick={onNodeClick}
          onNodeHover={onNodeHover}
          selectedNodeId={selectedNodeId}
          enabledEdgeTypes={enabledEdgeTypes}
          maxVisibleNodes={maxVisibleNodes}
          maxVisibleEdges={maxVisibleEdges}
          width={width}
          height={height}
        />
      </Canvas>

      {/* Performance info overlay */}
      <div className="graph-performance-info" style={{
        position: 'absolute',
        top: '10px',
        right: '10px',
        padding: '8px 12px',
        background: 'rgba(0,0,0,0.7)',
        color: '#fff',
        borderRadius: '4px',
        fontSize: '12px',
        fontFamily: 'monospace',
      }}>
        <div>Nodes: {nodes.length.toLocaleString()}</div>
        <div>Edges: {links.length.toLocaleString()}</div>
        {nodes.length > maxVisibleNodes && (
          <div style={{ color: '#f59e0b' }}>
            Showing top {maxVisibleNodes.toLocaleString()} by popularity
          </div>
        )}
      </div>
    </div>
  );
}
467
+
frontend/src/pages/GraphPage.tsx CHANGED
@@ -1,6 +1,7 @@
1
  import React, { useState, useEffect, useCallback } from 'react';
2
  import ForceDirectedGraph, { EdgeType, GraphNode } from '../components/visualizations/ForceDirectedGraph';
3
  import ForceDirectedGraph3D from '../components/visualizations/ForceDirectedGraph3D';
 
4
  import ScatterPlot3D from '../components/visualizations/ScatterPlot3D';
5
  import { fetchFamilyNetwork, fetchFullDerivativeNetwork, getAvailableEdgeTypes } from '../utils/api/graphApi';
6
  import LoadingProgress from '../components/ui/LoadingProgress';
@@ -11,6 +12,9 @@ import './GraphPage.css';
11
 
12
  const ALL_EDGE_TYPES: EdgeType[] = ['finetune', 'quantized', 'adapter', 'merge', 'parent'];
13
 
 
 
 
14
  type ViewMode = 'graph' | 'embedding' | 'graph3d';
15
  type GraphMode = 'family' | 'full';
16
 
@@ -451,16 +455,32 @@ export default function GraphPage() {
451
  ) : viewMode === 'graph3d' ? (
452
  <>
453
  <div style={{ width: '100%', height: '100%', position: 'relative' }}>
454
- <ForceDirectedGraph3D
455
- width={dimensions.width}
456
- height={dimensions.height}
457
- nodes={nodes}
458
- links={links}
459
- onNodeClick={handleNodeClick}
460
- selectedNodeId={selectedNodeId}
461
- enabledEdgeTypes={enabledEdgeTypes}
462
- showLabels={true}
463
- />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
  </div>
465
  <EdgeTypeLegend
466
  edgeTypes={ALL_EDGE_TYPES}
@@ -471,11 +491,11 @@ export default function GraphPage() {
471
  <div className="graph-stats">
472
  <div className="stat-item">
473
  <span className="stat-label">Nodes:</span>
474
- <span className="stat-value">{graphStats.nodes || nodes.length}</span>
475
  </div>
476
  <div className="stat-item">
477
  <span className="stat-label">Edges:</span>
478
- <span className="stat-value">{graphStats.edges || links.length}</span>
479
  </div>
480
  {graphStats.avg_degree && (
481
  <div className="stat-item">
@@ -483,6 +503,12 @@ export default function GraphPage() {
483
  <span className="stat-value">{graphStats.avg_degree.toFixed(2)}</span>
484
  </div>
485
  )}
 
 
 
 
 
 
486
  </div>
487
  )}
488
  </>
 
1
  import React, { useState, useEffect, useCallback } from 'react';
2
  import ForceDirectedGraph, { EdgeType, GraphNode } from '../components/visualizations/ForceDirectedGraph';
3
  import ForceDirectedGraph3D from '../components/visualizations/ForceDirectedGraph3D';
4
+ import ForceDirectedGraph3DInstanced from '../components/visualizations/ForceDirectedGraph3DInstanced';
5
  import ScatterPlot3D from '../components/visualizations/ScatterPlot3D';
6
  import { fetchFamilyNetwork, fetchFullDerivativeNetwork, getAvailableEdgeTypes } from '../utils/api/graphApi';
7
  import LoadingProgress from '../components/ui/LoadingProgress';
 
12
 
13
  const ALL_EDGE_TYPES: EdgeType[] = ['finetune', 'quantized', 'adapter', 'merge', 'parent'];
14
 
15
+ // Node-count threshold above which the 3D graph view switches to instanced rendering
16
+ const INSTANCED_THRESHOLD = 10000;
17
+
18
  type ViewMode = 'graph' | 'embedding' | 'graph3d';
19
  type GraphMode = 'family' | 'full';
20
 
 
455
  ) : viewMode === 'graph3d' ? (
456
  <>
457
  <div style={{ width: '100%', height: '100%', position: 'relative' }}>
458
+ {/* Use instanced rendering for large graphs (>10k nodes) */}
459
+ {nodes.length > INSTANCED_THRESHOLD ? (
460
+ <ForceDirectedGraph3DInstanced
461
+ width={dimensions.width}
462
+ height={dimensions.height}
463
+ nodes={nodes}
464
+ links={links}
465
+ onNodeClick={handleNodeClick}
466
+ selectedNodeId={selectedNodeId}
467
+ enabledEdgeTypes={enabledEdgeTypes}
468
+ showLabels={false}
469
+ maxVisibleNodes={500000}
470
+ maxVisibleEdges={200000}
471
+ />
472
+ ) : (
473
+ <ForceDirectedGraph3D
474
+ width={dimensions.width}
475
+ height={dimensions.height}
476
+ nodes={nodes}
477
+ links={links}
478
+ onNodeClick={handleNodeClick}
479
+ selectedNodeId={selectedNodeId}
480
+ enabledEdgeTypes={enabledEdgeTypes}
481
+ showLabels={true}
482
+ />
483
+ )}
484
  </div>
485
  <EdgeTypeLegend
486
  edgeTypes={ALL_EDGE_TYPES}
 
491
  <div className="graph-stats">
492
  <div className="stat-item">
493
  <span className="stat-label">Nodes:</span>
494
+ <span className="stat-value">{(graphStats.nodes || nodes.length).toLocaleString()}</span>
495
  </div>
496
  <div className="stat-item">
497
  <span className="stat-label">Edges:</span>
498
+ <span className="stat-value">{(graphStats.edges || links.length).toLocaleString()}</span>
499
  </div>
500
  {graphStats.avg_degree && (
501
  <div className="stat-item">
 
503
  <span className="stat-value">{graphStats.avg_degree.toFixed(2)}</span>
504
  </div>
505
  )}
506
+ {nodes.length > INSTANCED_THRESHOLD && (
507
+ <div className="stat-item stat-info">
508
+ <span className="stat-label">Rendering:</span>
509
+ <span className="stat-value">Instanced (optimized)</span>
510
+ </div>
511
+ )}
512
  </div>
513
  )}
514
  </>