Add 3D force-directed graph with instanced rendering and pre-computed positions
Browse files- Add ForceDirectedGraph3DInstanced.tsx for rendering millions of nodes efficiently
- Add precompute_force_layout.py script to generate positions offline
- Update backend to serve pre-computed positions in full-derivatives endpoint
- Update GraphPage to use instanced rendering for large graphs (>10k nodes)
- Include force_layout_3d.pkl with 1.86M pre-computed node positions
backend/api/main.py
CHANGED
|
@@ -1618,7 +1618,8 @@ async def get_family_network(
|
|
| 1618 |
@cached_response(ttl=3600, key_prefix="full_derivatives_network")
|
| 1619 |
async def get_full_derivative_network(
|
| 1620 |
edge_types: Optional[str] = Query(None, description="Comma-separated list of edge types to include (finetune,quantized,adapter,merge,parent). If None, includes all types."),
|
| 1621 |
-
include_edge_attributes: bool = Query(False, description="Whether to include edge attributes (change in likes, downloads, etc.). Default False for performance.")
|
|
|
|
| 1622 |
):
|
| 1623 |
"""
|
| 1624 |
Build full derivative relationship network for ALL models in the database.
|
|
@@ -1626,6 +1627,7 @@ async def get_full_derivative_network(
|
|
| 1626 |
This computes over every single model in the database.
|
| 1627 |
|
| 1628 |
Note: Edge attributes are disabled by default for performance with large datasets.
|
|
|
|
| 1629 |
"""
|
| 1630 |
if df is None:
|
| 1631 |
raise DataNotLoadedError()
|
|
@@ -1652,10 +1654,26 @@ async def get_full_derivative_network(
|
|
| 1652 |
build_time = time.time() - start_time
|
| 1653 |
logger.info(f"Graph built in {build_time:.2f}s: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges")
|
| 1654 |
|
| 1655 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1656 |
nodes = []
|
| 1657 |
for node_id, attrs in graph.nodes(data=True):
|
| 1658 |
-
|
| 1659 |
"id": node_id,
|
| 1660 |
"title": attrs.get('title', node_id),
|
| 1661 |
"freq": attrs.get('freq', 0),
|
|
@@ -1663,7 +1681,16 @@ async def get_full_derivative_network(
|
|
| 1663 |
"downloads": attrs.get('downloads', 0),
|
| 1664 |
"library": attrs.get('library', ''),
|
| 1665 |
"pipeline": attrs.get('pipeline', '')
|
| 1666 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1667 |
|
| 1668 |
logger.info(f"Processed {len(nodes):,} nodes")
|
| 1669 |
|
|
|
|
| 1618 |
@cached_response(ttl=3600, key_prefix="full_derivatives_network")
|
| 1619 |
async def get_full_derivative_network(
|
| 1620 |
edge_types: Optional[str] = Query(None, description="Comma-separated list of edge types to include (finetune,quantized,adapter,merge,parent). If None, includes all types."),
|
| 1621 |
+
include_edge_attributes: bool = Query(False, description="Whether to include edge attributes (change in likes, downloads, etc.). Default False for performance."),
|
| 1622 |
+
include_positions: bool = Query(True, description="Whether to include pre-computed 3D positions for each node. Default True for faster rendering.")
|
| 1623 |
):
|
| 1624 |
"""
|
| 1625 |
Build full derivative relationship network for ALL models in the database.
|
|
|
|
| 1627 |
This computes over every single model in the database.
|
| 1628 |
|
| 1629 |
Note: Edge attributes are disabled by default for performance with large datasets.
|
| 1630 |
+
If pre-computed positions exist, they will be included in the response.
|
| 1631 |
"""
|
| 1632 |
if df is None:
|
| 1633 |
raise DataNotLoadedError()
|
|
|
|
| 1654 |
build_time = time.time() - start_time
|
| 1655 |
logger.info(f"Graph built in {build_time:.2f}s: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges")
|
| 1656 |
|
| 1657 |
+
# Load pre-computed positions if available
|
| 1658 |
+
precomputed_positions = {}
|
| 1659 |
+
if include_positions:
|
| 1660 |
+
try:
|
| 1661 |
+
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 1662 |
+
root_dir = os.path.dirname(backend_dir)
|
| 1663 |
+
layout_file = os.path.join(root_dir, "precomputed_data", "force_layout_3d.pkl")
|
| 1664 |
+
|
| 1665 |
+
if os.path.exists(layout_file):
|
| 1666 |
+
with open(layout_file, 'rb') as f:
|
| 1667 |
+
layout_data = pickle.load(f)
|
| 1668 |
+
precomputed_positions = layout_data.get('positions', {})
|
| 1669 |
+
logger.info(f"Loaded {len(precomputed_positions):,} pre-computed positions")
|
| 1670 |
+
except Exception as e:
|
| 1671 |
+
logger.warning(f"Could not load pre-computed positions: {e}")
|
| 1672 |
+
|
| 1673 |
+
# Build nodes list with optional pre-computed positions
|
| 1674 |
nodes = []
|
| 1675 |
for node_id, attrs in graph.nodes(data=True):
|
| 1676 |
+
node_data = {
|
| 1677 |
"id": node_id,
|
| 1678 |
"title": attrs.get('title', node_id),
|
| 1679 |
"freq": attrs.get('freq', 0),
|
|
|
|
| 1681 |
"downloads": attrs.get('downloads', 0),
|
| 1682 |
"library": attrs.get('library', ''),
|
| 1683 |
"pipeline": attrs.get('pipeline', '')
|
| 1684 |
+
}
|
| 1685 |
+
|
| 1686 |
+
# Add pre-computed position if available
|
| 1687 |
+
if node_id in precomputed_positions:
|
| 1688 |
+
pos = precomputed_positions[node_id]
|
| 1689 |
+
node_data['x'] = pos[0]
|
| 1690 |
+
node_data['y'] = pos[1]
|
| 1691 |
+
node_data['z'] = pos[2]
|
| 1692 |
+
|
| 1693 |
+
nodes.append(node_data)
|
| 1694 |
|
| 1695 |
logger.info(f"Processed {len(nodes):,} nodes")
|
| 1696 |
|
backend/scripts/precompute_force_layout.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pre-compute force-directed layout positions for the full model network.
|
| 3 |
+
Uses graph-tool or networkx with Barnes-Hut optimization for large-scale layouts.
|
| 4 |
+
|
| 5 |
+
This script generates x, y, z coordinates for all nodes so the frontend
|
| 6 |
+
doesn't need to compute force simulation in real-time.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python precompute_force_layout.py [--output force_layout.pkl] [--3d]
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import time
|
| 15 |
+
import pickle
|
| 16 |
+
import argparse
|
| 17 |
+
import logging
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Dict, Tuple, Optional
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
# Add backend to path
|
| 23 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 24 |
+
|
| 25 |
+
logging.basicConfig(
|
| 26 |
+
level=logging.INFO,
|
| 27 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 28 |
+
)
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def load_model_data() -> 'pd.DataFrame':
|
| 33 |
+
"""Load model data from precomputed parquet or CSV."""
|
| 34 |
+
import pandas as pd
|
| 35 |
+
|
| 36 |
+
backend_dir = Path(__file__).parent.parent
|
| 37 |
+
root_dir = backend_dir.parent
|
| 38 |
+
|
| 39 |
+
# Try precomputed data first
|
| 40 |
+
precomputed_dir = root_dir / "precomputed_data"
|
| 41 |
+
if precomputed_dir.exists():
|
| 42 |
+
parquet_files = list(precomputed_dir.glob("*.parquet"))
|
| 43 |
+
if parquet_files:
|
| 44 |
+
logger.info(f"Loading from precomputed parquet: {parquet_files[0]}")
|
| 45 |
+
return pd.read_parquet(parquet_files[0])
|
| 46 |
+
|
| 47 |
+
# Try CSV data
|
| 48 |
+
csv_path = precomputed_dir / "models.csv"
|
| 49 |
+
if csv_path.exists():
|
| 50 |
+
logger.info(f"Loading from CSV: {csv_path}")
|
| 51 |
+
return pd.read_csv(csv_path)
|
| 52 |
+
|
| 53 |
+
# Try data directory
|
| 54 |
+
data_dir = root_dir / "data"
|
| 55 |
+
if data_dir.exists():
|
| 56 |
+
csv_files = list(data_dir.glob("*.csv"))
|
| 57 |
+
for csv_file in csv_files:
|
| 58 |
+
if "model" in csv_file.name.lower():
|
| 59 |
+
logger.info(f"Loading from {csv_file}")
|
| 60 |
+
return pd.read_csv(csv_file)
|
| 61 |
+
|
| 62 |
+
raise FileNotFoundError("No model data found")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def load_existing_graph(graph_path: str = None) -> Optional['nx.DiGraph']:
|
| 66 |
+
"""Load pre-existing networkx graph from pickle file."""
|
| 67 |
+
import networkx as nx
|
| 68 |
+
|
| 69 |
+
if graph_path and Path(graph_path).exists():
|
| 70 |
+
logger.info(f"Loading existing graph from {graph_path}")
|
| 71 |
+
with open(graph_path, 'rb') as f:
|
| 72 |
+
return pickle.load(f)
|
| 73 |
+
|
| 74 |
+
# Search for graph file
|
| 75 |
+
search_paths = [
|
| 76 |
+
Path(__file__).parent.parent.parent / "ai-ecosystem" / "data" / "ai_ecosystem_graph.pkl",
|
| 77 |
+
Path(__file__).parent.parent.parent.parent / "ai-ecosystem" / "data" / "ai_ecosystem_graph.pkl",
|
| 78 |
+
Path.home() / "ai-ecosystem-v2" / "ai-ecosystem" / "data" / "ai_ecosystem_graph.pkl",
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
for path in search_paths:
|
| 82 |
+
if path.exists():
|
| 83 |
+
logger.info(f"Found existing graph at {path}")
|
| 84 |
+
with open(path, 'rb') as f:
|
| 85 |
+
return pickle.load(f)
|
| 86 |
+
|
| 87 |
+
return None
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def build_network_graph(df: 'pd.DataFrame') -> 'nx.DiGraph':
|
| 91 |
+
"""Build network graph from model dataframe."""
|
| 92 |
+
import networkx as nx
|
| 93 |
+
|
| 94 |
+
logger.info(f"Building network from {len(df):,} models...")
|
| 95 |
+
G = nx.DiGraph()
|
| 96 |
+
|
| 97 |
+
# Add all models as nodes
|
| 98 |
+
for _, row in df.iterrows():
|
| 99 |
+
model_id = str(row.get('model_id', row.get('modelId', '')))
|
| 100 |
+
if not model_id:
|
| 101 |
+
continue
|
| 102 |
+
|
| 103 |
+
G.add_node(model_id,
|
| 104 |
+
downloads=row.get('downloads', 0),
|
| 105 |
+
likes=row.get('likes', 0),
|
| 106 |
+
library=row.get('library_name', row.get('library', '')),
|
| 107 |
+
pipeline=row.get('pipeline_tag', '')
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# Add edges based on parent relationships
|
| 111 |
+
edge_count = 0
|
| 112 |
+
for _, row in df.iterrows():
|
| 113 |
+
model_id = str(row.get('model_id', row.get('modelId', '')))
|
| 114 |
+
parent_id = row.get('parent_model', row.get('base_model', None))
|
| 115 |
+
|
| 116 |
+
if not model_id:
|
| 117 |
+
continue
|
| 118 |
+
|
| 119 |
+
if pd.notna(parent_id) and str(parent_id).strip() and str(parent_id) != 'nan':
|
| 120 |
+
parent_id = str(parent_id).strip()
|
| 121 |
+
if parent_id in G.nodes:
|
| 122 |
+
G.add_edge(parent_id, model_id, edge_type='derivative')
|
| 123 |
+
edge_count += 1
|
| 124 |
+
|
| 125 |
+
logger.info(f"Network: {G.number_of_nodes():,} nodes, {edge_count:,} edges")
|
| 126 |
+
return G
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def compute_force_layout_3d(
|
| 130 |
+
G: 'nx.Graph',
|
| 131 |
+
iterations: int = 100,
|
| 132 |
+
seed: int = 42
|
| 133 |
+
) -> Dict[str, Tuple[float, float, float]]:
|
| 134 |
+
"""
|
| 135 |
+
Compute 3D force-directed layout using networkx spring_layout.
|
| 136 |
+
For very large graphs, uses Barnes-Hut approximation.
|
| 137 |
+
"""
|
| 138 |
+
import networkx as nx
|
| 139 |
+
|
| 140 |
+
n_nodes = G.number_of_nodes()
|
| 141 |
+
logger.info(f"Computing 3D layout for {n_nodes:,} nodes...")
|
| 142 |
+
|
| 143 |
+
if n_nodes == 0:
|
| 144 |
+
return {}
|
| 145 |
+
|
| 146 |
+
start_time = time.time()
|
| 147 |
+
|
| 148 |
+
# For large graphs, compute layout on largest connected component first
|
| 149 |
+
if n_nodes > 100000:
|
| 150 |
+
logger.info("Large graph detected - using optimized approach...")
|
| 151 |
+
|
| 152 |
+
# Get largest connected component (treat as undirected)
|
| 153 |
+
if isinstance(G, nx.DiGraph):
|
| 154 |
+
G_undirected = G.to_undirected()
|
| 155 |
+
else:
|
| 156 |
+
G_undirected = G
|
| 157 |
+
|
| 158 |
+
components = list(nx.connected_components(G_undirected))
|
| 159 |
+
components.sort(key=len, reverse=True)
|
| 160 |
+
|
| 161 |
+
logger.info(f"Found {len(components):,} connected components")
|
| 162 |
+
|
| 163 |
+
# Compute layouts for each component
|
| 164 |
+
positions = {}
|
| 165 |
+
offset_x = 0
|
| 166 |
+
|
| 167 |
+
for i, component in enumerate(components):
|
| 168 |
+
if len(component) < 2:
|
| 169 |
+
# Isolated nodes - place randomly
|
| 170 |
+
for node in component:
|
| 171 |
+
positions[node] = (
|
| 172 |
+
offset_x + np.random.randn() * 10,
|
| 173 |
+
np.random.randn() * 100,
|
| 174 |
+
np.random.randn() * 100
|
| 175 |
+
)
|
| 176 |
+
continue
|
| 177 |
+
|
| 178 |
+
subgraph = G_undirected.subgraph(component)
|
| 179 |
+
|
| 180 |
+
# Use spring layout with reduced iterations for large components
|
| 181 |
+
iter_count = min(iterations, max(20, 100 - len(component) // 10000))
|
| 182 |
+
|
| 183 |
+
logger.info(f" Component {i+1}/{len(components)}: {len(component):,} nodes, {iter_count} iterations")
|
| 184 |
+
|
| 185 |
+
try:
|
| 186 |
+
# 3D layout using spring_layout
|
| 187 |
+
pos_2d = nx.spring_layout(
|
| 188 |
+
subgraph,
|
| 189 |
+
dim=3,
|
| 190 |
+
k=1.0 / np.sqrt(len(component)),
|
| 191 |
+
iterations=iter_count,
|
| 192 |
+
seed=seed + i,
|
| 193 |
+
scale=100 * np.log10(max(len(component), 10))
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# Apply offset to separate components
|
| 197 |
+
for node, (x, y, z) in pos_2d.items():
|
| 198 |
+
positions[node] = (x + offset_x, y, z)
|
| 199 |
+
|
| 200 |
+
# Move offset for next component
|
| 201 |
+
offset_x += 300 * np.log10(max(len(component), 10))
|
| 202 |
+
|
| 203 |
+
except Exception as e:
|
| 204 |
+
logger.warning(f"Layout failed for component {i}: {e}")
|
| 205 |
+
# Fallback: random positions
|
| 206 |
+
for node in component:
|
| 207 |
+
positions[node] = (
|
| 208 |
+
offset_x + np.random.randn() * 50,
|
| 209 |
+
np.random.randn() * 50,
|
| 210 |
+
np.random.randn() * 50
|
| 211 |
+
)
|
| 212 |
+
else:
|
| 213 |
+
# Standard approach for smaller graphs
|
| 214 |
+
try:
|
| 215 |
+
positions_raw = nx.spring_layout(
|
| 216 |
+
G.to_undirected() if isinstance(G, nx.DiGraph) else G,
|
| 217 |
+
dim=3,
|
| 218 |
+
k=2.0 / np.sqrt(n_nodes) if n_nodes > 0 else 1.0,
|
| 219 |
+
iterations=iterations,
|
| 220 |
+
seed=seed,
|
| 221 |
+
scale=200
|
| 222 |
+
)
|
| 223 |
+
positions = {node: tuple(pos) for node, pos in positions_raw.items()}
|
| 224 |
+
except Exception as e:
|
| 225 |
+
logger.warning(f"Spring layout failed: {e}, using random positions")
|
| 226 |
+
np.random.seed(seed)
|
| 227 |
+
positions = {
|
| 228 |
+
node: (np.random.randn() * 100, np.random.randn() * 100, np.random.randn() * 100)
|
| 229 |
+
for node in G.nodes()
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
elapsed = time.time() - start_time
|
| 233 |
+
logger.info(f"Layout computed in {elapsed:.1f}s")
|
| 234 |
+
|
| 235 |
+
return positions
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def compute_force_layout_fa2(
|
| 239 |
+
G: 'nx.Graph',
|
| 240 |
+
iterations: int = 100,
|
| 241 |
+
seed: int = 42
|
| 242 |
+
) -> Dict[str, Tuple[float, float, float]]:
|
| 243 |
+
"""
|
| 244 |
+
Compute layout using ForceAtlas2 algorithm (faster for large graphs).
|
| 245 |
+
Falls back to spring_layout if fa2 not available.
|
| 246 |
+
"""
|
| 247 |
+
try:
|
| 248 |
+
from fa2 import ForceAtlas2
|
| 249 |
+
|
| 250 |
+
n_nodes = G.number_of_nodes()
|
| 251 |
+
logger.info(f"Computing FA2 layout for {n_nodes:,} nodes...")
|
| 252 |
+
|
| 253 |
+
if n_nodes == 0:
|
| 254 |
+
return {}
|
| 255 |
+
|
| 256 |
+
# Convert to undirected for layout
|
| 257 |
+
if isinstance(G, nx.DiGraph):
|
| 258 |
+
import networkx as nx
|
| 259 |
+
G_layout = G.to_undirected()
|
| 260 |
+
else:
|
| 261 |
+
G_layout = G
|
| 262 |
+
|
| 263 |
+
# Initialize ForceAtlas2
|
| 264 |
+
fa2 = ForceAtlas2(
|
| 265 |
+
outboundAttractionDistribution=True,
|
| 266 |
+
linLogMode=False,
|
| 267 |
+
adjustSizes=False,
|
| 268 |
+
edgeWeightInfluence=1.0,
|
| 269 |
+
jitterTolerance=1.0,
|
| 270 |
+
barnesHutOptimize=True,
|
| 271 |
+
barnesHutTheta=1.2,
|
| 272 |
+
multiThreaded=False,
|
| 273 |
+
scalingRatio=2.0,
|
| 274 |
+
strongGravityMode=False,
|
| 275 |
+
gravity=1.0,
|
| 276 |
+
verbose=False
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
# Compute 2D positions
|
| 280 |
+
positions_2d = fa2.forceatlas2_networkx_layout(
|
| 281 |
+
G_layout,
|
| 282 |
+
iterations=iterations
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# Add 3rd dimension based on hierarchy/properties
|
| 286 |
+
np.random.seed(seed)
|
| 287 |
+
positions = {}
|
| 288 |
+
for node, (x, y) in positions_2d.items():
|
| 289 |
+
# Z based on downloads (popular models higher)
|
| 290 |
+
downloads = G.nodes[node].get('downloads', 0) if node in G.nodes else 0
|
| 291 |
+
z = np.log10(max(downloads, 1)) * 10 + np.random.randn() * 5
|
| 292 |
+
positions[node] = (x * 100, y * 100, z)
|
| 293 |
+
|
| 294 |
+
return positions
|
| 295 |
+
|
| 296 |
+
except ImportError:
|
| 297 |
+
logger.warning("fa2 not installed, falling back to spring_layout")
|
| 298 |
+
return compute_force_layout_3d(G, iterations, seed)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def save_layout(
|
| 302 |
+
positions: Dict[str, Tuple[float, float, float]],
|
| 303 |
+
output_path: str,
|
| 304 |
+
graph: 'nx.Graph' = None
|
| 305 |
+
):
|
| 306 |
+
"""Save layout positions to pickle file."""
|
| 307 |
+
|
| 308 |
+
data = {
|
| 309 |
+
'positions': positions,
|
| 310 |
+
'n_nodes': len(positions),
|
| 311 |
+
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
if graph is not None:
|
| 315 |
+
data['n_edges'] = graph.number_of_edges()
|
| 316 |
+
|
| 317 |
+
# Calculate bounds
|
| 318 |
+
if positions:
|
| 319 |
+
xs = [p[0] for p in positions.values()]
|
| 320 |
+
ys = [p[1] for p in positions.values()]
|
| 321 |
+
zs = [p[2] for p in positions.values()]
|
| 322 |
+
data['bounds'] = {
|
| 323 |
+
'x_min': min(xs), 'x_max': max(xs),
|
| 324 |
+
'y_min': min(ys), 'y_max': max(ys),
|
| 325 |
+
'z_min': min(zs), 'z_max': max(zs),
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
with open(output_path, 'wb') as f:
|
| 329 |
+
pickle.dump(data, f)
|
| 330 |
+
|
| 331 |
+
logger.info(f"Saved layout to {output_path}")
|
| 332 |
+
logger.info(f" Nodes: {len(positions):,}")
|
| 333 |
+
if 'bounds' in data:
|
| 334 |
+
b = data['bounds']
|
| 335 |
+
logger.info(f" Bounds: X[{b['x_min']:.1f}, {b['x_max']:.1f}], Y[{b['y_min']:.1f}, {b['y_max']:.1f}], Z[{b['z_min']:.1f}, {b['z_max']:.1f}]")
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def main():
|
| 339 |
+
parser = argparse.ArgumentParser(description='Pre-compute force-directed layout')
|
| 340 |
+
parser.add_argument('--output', '-o', type=str, default='force_layout_3d.pkl',
|
| 341 |
+
help='Output pickle file path')
|
| 342 |
+
parser.add_argument('--iterations', '-i', type=int, default=100,
|
| 343 |
+
help='Number of layout iterations')
|
| 344 |
+
parser.add_argument('--algorithm', '-a', choices=['spring', 'fa2'], default='spring',
|
| 345 |
+
help='Layout algorithm to use')
|
| 346 |
+
parser.add_argument('--seed', '-s', type=int, default=42,
|
| 347 |
+
help='Random seed for reproducibility')
|
| 348 |
+
parser.add_argument('--graph', '-g', type=str, default=None,
|
| 349 |
+
help='Path to existing networkx graph pickle file')
|
| 350 |
+
|
| 351 |
+
args = parser.parse_args()
|
| 352 |
+
|
| 353 |
+
# Determine output path
|
| 354 |
+
backend_dir = Path(__file__).parent.parent
|
| 355 |
+
root_dir = backend_dir.parent
|
| 356 |
+
precomputed_dir = root_dir / "precomputed_data"
|
| 357 |
+
precomputed_dir.mkdir(exist_ok=True)
|
| 358 |
+
|
| 359 |
+
output_path = precomputed_dir / args.output
|
| 360 |
+
|
| 361 |
+
logger.info("=" * 60)
|
| 362 |
+
logger.info("Pre-computing Force-Directed Layout")
|
| 363 |
+
logger.info("=" * 60)
|
| 364 |
+
|
| 365 |
+
# Try to load existing graph first (faster)
|
| 366 |
+
G = load_existing_graph(args.graph)
|
| 367 |
+
|
| 368 |
+
if G is None:
|
| 369 |
+
# Load data and build graph
|
| 370 |
+
df = load_model_data()
|
| 371 |
+
logger.info(f"Loaded {len(df):,} models")
|
| 372 |
+
G = build_network_graph(df)
|
| 373 |
+
else:
|
| 374 |
+
logger.info(f"Using existing graph: {G.number_of_nodes():,} nodes, {G.number_of_edges():,} edges")
|
| 375 |
+
|
| 376 |
+
# Compute layout
|
| 377 |
+
if args.algorithm == 'fa2':
|
| 378 |
+
positions = compute_force_layout_fa2(G, args.iterations, args.seed)
|
| 379 |
+
else:
|
| 380 |
+
positions = compute_force_layout_3d(G, args.iterations, args.seed)
|
| 381 |
+
|
| 382 |
+
# Save
|
| 383 |
+
save_layout(positions, str(output_path), G)
|
| 384 |
+
|
| 385 |
+
logger.info("=" * 60)
|
| 386 |
+
logger.info("Done!")
|
| 387 |
+
logger.info("=" * 60)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
if __name__ == "__main__":
|
| 391 |
+
main()
|
| 392 |
+
|
frontend/src/components/visualizations/ForceDirectedGraph3DInstanced.tsx
ADDED
|
@@ -0,0 +1,467 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Highly optimized 3D Force-directed graph visualization using instanced rendering.
|
| 3 |
+
* Designed to handle millions of nodes efficiently by:
|
| 4 |
+
* 1. Using pre-computed positions from server
|
| 5 |
+
* 2. Instanced mesh rendering (single draw call for all nodes)
|
| 6 |
+
* 3. GPU-based picking for interactions
|
| 7 |
+
* 4. Level-of-detail (LOD) for labels
|
| 8 |
+
*/
|
| 9 |
+
import React, { useMemo, useRef, useEffect, useState, useCallback } from 'react';
|
| 10 |
+
import { Canvas, useFrame, useThree } from '@react-three/fiber';
|
| 11 |
+
import { OrbitControls } from '@react-three/drei';
|
| 12 |
+
import * as THREE from 'three';
|
| 13 |
+
import { GraphNode, GraphLink, EdgeType } from './ForceDirectedGraph';
|
| 14 |
+
import './ForceDirectedGraph.css';
|
| 15 |
+
|
| 16 |
+
export interface ForceDirectedGraph3DInstancedProps {
|
| 17 |
+
width: number;
|
| 18 |
+
height: number;
|
| 19 |
+
nodes: GraphNode[];
|
| 20 |
+
links: GraphLink[];
|
| 21 |
+
onNodeClick?: (node: GraphNode) => void;
|
| 22 |
+
onNodeHover?: (node: GraphNode | null) => void;
|
| 23 |
+
selectedNodeId?: string | null;
|
| 24 |
+
enabledEdgeTypes?: Set<EdgeType>;
|
| 25 |
+
showLabels?: boolean;
|
| 26 |
+
maxVisibleNodes?: number;
|
| 27 |
+
maxVisibleEdges?: number;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
// Color scheme for different libraries
|
| 31 |
+
const LIBRARY_COLORS: Record<string, string> = {
|
| 32 |
+
transformers: '#3b82f6', // Blue
|
| 33 |
+
pytorch: '#ef4444', // Red
|
| 34 |
+
tensorflow: '#f97316', // Orange
|
| 35 |
+
diffusers: '#8b5cf6', // Purple
|
| 36 |
+
'sentence-transformers': '#10b981', // Green
|
| 37 |
+
timm: '#06b6d4', // Cyan
|
| 38 |
+
peft: '#ec4899', // Pink
|
| 39 |
+
default: '#6b7280', // Gray
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
// Color scheme for different edge types
|
| 43 |
+
const EDGE_COLORS: Record<EdgeType, THREE.Color> = {
|
| 44 |
+
finetune: new THREE.Color('#3b82f6'), // Blue
|
| 45 |
+
quantized: new THREE.Color('#10b981'), // Green
|
| 46 |
+
adapter: new THREE.Color('#f59e0b'), // Orange
|
| 47 |
+
merge: new THREE.Color('#8b5cf6'), // Purple
|
| 48 |
+
parent: new THREE.Color('#6b7280'), // Gray
|
| 49 |
+
};
|
| 50 |
+
|
| 51 |
+
/**
|
| 52 |
+
* Get color for a node based on its library
|
| 53 |
+
*/
|
| 54 |
+
function getNodeColor(library: string | undefined): THREE.Color {
|
| 55 |
+
const colorHex = LIBRARY_COLORS[library?.toLowerCase() || ''] || LIBRARY_COLORS.default;
|
| 56 |
+
return new THREE.Color(colorHex);
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
/**
|
| 60 |
+
* Calculate node size based on downloads (log scale)
|
| 61 |
+
*/
|
| 62 |
+
function getNodeSize(downloads: number): number {
|
| 63 |
+
return 0.3 + Math.log10(Math.max(downloads, 1)) * 0.15;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
/**
|
| 67 |
+
* Instanced nodes component - renders all nodes in a single draw call
|
| 68 |
+
*/
|
| 69 |
+
function InstancedNodes({
|
| 70 |
+
nodes,
|
| 71 |
+
selectedNodeId,
|
| 72 |
+
onNodeClick,
|
| 73 |
+
onNodeHover,
|
| 74 |
+
maxVisible = 500000,
|
| 75 |
+
}: {
|
| 76 |
+
nodes: GraphNode[];
|
| 77 |
+
selectedNodeId?: string | null;
|
| 78 |
+
onNodeClick?: (node: GraphNode) => void;
|
| 79 |
+
onNodeHover?: (node: GraphNode | null) => void;
|
| 80 |
+
maxVisible?: number;
|
| 81 |
+
}) {
|
| 82 |
+
const meshRef = useRef<THREE.InstancedMesh>(null);
|
| 83 |
+
const { camera, raycaster, pointer } = useThree();
|
| 84 |
+
const [hoveredIndex, setHoveredIndex] = useState<number | null>(null);
|
| 85 |
+
|
| 86 |
+
// Limit nodes for performance
|
| 87 |
+
const visibleNodes = useMemo(() => {
|
| 88 |
+
if (nodes.length <= maxVisible) return nodes;
|
| 89 |
+
// Sort by downloads and take top N
|
| 90 |
+
return [...nodes]
|
| 91 |
+
.sort((a, b) => (b.downloads || 0) - (a.downloads || 0))
|
| 92 |
+
.slice(0, maxVisible);
|
| 93 |
+
}, [nodes, maxVisible]);
|
| 94 |
+
|
| 95 |
+
// Node ID to index map for lookup
|
| 96 |
+
const nodeIndexMap = useMemo(() => {
|
| 97 |
+
const map = new Map<string, number>();
|
| 98 |
+
visibleNodes.forEach((node, i) => map.set(node.id, i));
|
| 99 |
+
return map;
|
| 100 |
+
}, [visibleNodes]);
|
| 101 |
+
|
| 102 |
+
// Pre-compute matrices and colors
|
| 103 |
+
const { matrices, colors, sizes } = useMemo(() => {
|
| 104 |
+
const matrices: THREE.Matrix4[] = [];
|
| 105 |
+
const colors: THREE.Color[] = [];
|
| 106 |
+
const sizes: number[] = [];
|
| 107 |
+
|
| 108 |
+
const tempMatrix = new THREE.Matrix4();
|
| 109 |
+
|
| 110 |
+
visibleNodes.forEach((node) => {
|
| 111 |
+
const x = node.x || 0;
|
| 112 |
+
const y = node.y || 0;
|
| 113 |
+
const z = node.z || 0;
|
| 114 |
+
const size = getNodeSize(node.downloads || 0);
|
| 115 |
+
|
| 116 |
+
tempMatrix.makeScale(size, size, size);
|
| 117 |
+
tempMatrix.setPosition(x, y, z);
|
| 118 |
+
matrices.push(tempMatrix.clone());
|
| 119 |
+
|
| 120 |
+
colors.push(getNodeColor(node.library));
|
| 121 |
+
sizes.push(size);
|
| 122 |
+
});
|
| 123 |
+
|
| 124 |
+
return { matrices, colors, sizes };
|
| 125 |
+
}, [visibleNodes]);
|
| 126 |
+
|
| 127 |
+
// Update instance attributes when data changes
|
| 128 |
+
useEffect(() => {
|
| 129 |
+
const mesh = meshRef.current;
|
| 130 |
+
if (!mesh) return;
|
| 131 |
+
|
| 132 |
+
const tempColor = new THREE.Color();
|
| 133 |
+
|
| 134 |
+
matrices.forEach((matrix, i) => {
|
| 135 |
+
mesh.setMatrixAt(i, matrix);
|
| 136 |
+
|
| 137 |
+
// Highlight selected/hovered nodes
|
| 138 |
+
const isSelected = visibleNodes[i]?.id === selectedNodeId;
|
| 139 |
+
const isHovered = i === hoveredIndex;
|
| 140 |
+
|
| 141 |
+
if (isSelected) {
|
| 142 |
+
tempColor.set('#ef4444'); // Red for selected
|
| 143 |
+
} else if (isHovered) {
|
| 144 |
+
tempColor.set('#fbbf24'); // Yellow for hovered
|
| 145 |
+
} else {
|
| 146 |
+
tempColor.copy(colors[i]);
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
mesh.setColorAt(i, tempColor);
|
| 150 |
+
});
|
| 151 |
+
|
| 152 |
+
mesh.instanceMatrix.needsUpdate = true;
|
| 153 |
+
if (mesh.instanceColor) mesh.instanceColor.needsUpdate = true;
|
| 154 |
+
}, [matrices, colors, selectedNodeId, hoveredIndex, visibleNodes]);
|
| 155 |
+
|
| 156 |
+
// Raycasting for hover/click
|
| 157 |
+
useFrame(() => {
|
| 158 |
+
const mesh = meshRef.current;
|
| 159 |
+
if (!mesh) return;
|
| 160 |
+
|
| 161 |
+
raycaster.setFromCamera(pointer, camera);
|
| 162 |
+
const intersects = raycaster.intersectObject(mesh);
|
| 163 |
+
|
| 164 |
+
if (intersects.length > 0) {
|
| 165 |
+
const index = intersects[0].instanceId;
|
| 166 |
+
if (index !== undefined && index !== hoveredIndex) {
|
| 167 |
+
setHoveredIndex(index);
|
| 168 |
+
if (onNodeHover && visibleNodes[index]) {
|
| 169 |
+
onNodeHover(visibleNodes[index]);
|
| 170 |
+
}
|
| 171 |
+
}
|
| 172 |
+
} else if (hoveredIndex !== null) {
|
| 173 |
+
setHoveredIndex(null);
|
| 174 |
+
if (onNodeHover) {
|
| 175 |
+
onNodeHover(null);
|
| 176 |
+
}
|
| 177 |
+
}
|
| 178 |
+
});
|
| 179 |
+
|
| 180 |
+
// Handle click
|
| 181 |
+
const handleClick = useCallback(() => {
|
| 182 |
+
if (hoveredIndex !== null && onNodeClick && visibleNodes[hoveredIndex]) {
|
| 183 |
+
onNodeClick(visibleNodes[hoveredIndex]);
|
| 184 |
+
}
|
| 185 |
+
}, [hoveredIndex, onNodeClick, visibleNodes]);
|
| 186 |
+
|
| 187 |
+
if (visibleNodes.length === 0) return null;
|
| 188 |
+
|
| 189 |
+
return (
|
| 190 |
+
<instancedMesh
|
| 191 |
+
ref={meshRef}
|
| 192 |
+
args={[undefined, undefined, visibleNodes.length]}
|
| 193 |
+
onClick={handleClick}
|
| 194 |
+
frustumCulled={false}
|
| 195 |
+
>
|
| 196 |
+
<sphereGeometry args={[1, 8, 8]} />
|
| 197 |
+
<meshStandardMaterial
|
| 198 |
+
vertexColors
|
| 199 |
+
roughness={0.4}
|
| 200 |
+
metalness={0.1}
|
| 201 |
+
/>
|
| 202 |
+
</instancedMesh>
|
| 203 |
+
);
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
/**
|
| 207 |
+
* Edges component using line segments
|
| 208 |
+
*/
|
| 209 |
+
function Edges({
|
| 210 |
+
nodes,
|
| 211 |
+
links,
|
| 212 |
+
enabledEdgeTypes,
|
| 213 |
+
maxVisible = 100000,
|
| 214 |
+
}: {
|
| 215 |
+
nodes: GraphNode[];
|
| 216 |
+
links: GraphLink[];
|
| 217 |
+
enabledEdgeTypes?: Set<EdgeType>;
|
| 218 |
+
maxVisible?: number;
|
| 219 |
+
}) {
|
| 220 |
+
const lineRef = useRef<THREE.LineSegments>(null);
|
| 221 |
+
|
| 222 |
+
// Create node lookup map
|
| 223 |
+
const nodeMap = useMemo(() => {
|
| 224 |
+
const map = new Map<string, GraphNode>();
|
| 225 |
+
nodes.forEach(node => map.set(node.id, node));
|
| 226 |
+
return map;
|
| 227 |
+
}, [nodes]);
|
| 228 |
+
|
| 229 |
+
// Filter and limit links
|
| 230 |
+
const visibleLinks = useMemo(() => {
|
| 231 |
+
let filtered = links;
|
| 232 |
+
|
| 233 |
+
if (enabledEdgeTypes && enabledEdgeTypes.size > 0) {
|
| 234 |
+
filtered = links.filter(link => {
|
| 235 |
+
const linkTypes = link.edge_types || [link.edge_type];
|
| 236 |
+
return linkTypes.some(type => enabledEdgeTypes.has(type));
|
| 237 |
+
});
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
if (filtered.length > maxVisible) {
|
| 241 |
+
return filtered.slice(0, maxVisible);
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
return filtered;
|
| 245 |
+
}, [links, enabledEdgeTypes, maxVisible]);
|
| 246 |
+
|
| 247 |
+
// Build geometry
|
| 248 |
+
const geometry = useMemo(() => {
|
| 249 |
+
const positions: number[] = [];
|
| 250 |
+
const colors: number[] = [];
|
| 251 |
+
|
| 252 |
+
visibleLinks.forEach(link => {
|
| 253 |
+
const sourceId = typeof link.source === 'string' ? link.source : link.source?.id;
|
| 254 |
+
const targetId = typeof link.target === 'string' ? link.target : link.target?.id;
|
| 255 |
+
|
| 256 |
+
const source = nodeMap.get(sourceId || '');
|
| 257 |
+
const target = nodeMap.get(targetId || '');
|
| 258 |
+
|
| 259 |
+
if (!source || !target) return;
|
| 260 |
+
|
| 261 |
+
// Source position
|
| 262 |
+
positions.push(source.x || 0, source.y || 0, source.z || 0);
|
| 263 |
+
// Target position
|
| 264 |
+
positions.push(target.x || 0, target.y || 0, target.z || 0);
|
| 265 |
+
|
| 266 |
+
// Edge color based on type
|
| 267 |
+
const edgeType = link.edge_type || 'parent';
|
| 268 |
+
const color = EDGE_COLORS[edgeType] || EDGE_COLORS.parent;
|
| 269 |
+
|
| 270 |
+
// Add color for both vertices
|
| 271 |
+
colors.push(color.r, color.g, color.b);
|
| 272 |
+
colors.push(color.r, color.g, color.b);
|
| 273 |
+
});
|
| 274 |
+
|
| 275 |
+
const geom = new THREE.BufferGeometry();
|
| 276 |
+
geom.setAttribute('position', new THREE.Float32BufferAttribute(positions, 3));
|
| 277 |
+
geom.setAttribute('color', new THREE.Float32BufferAttribute(colors, 3));
|
| 278 |
+
|
| 279 |
+
return geom;
|
| 280 |
+
}, [visibleLinks, nodeMap]);
|
| 281 |
+
|
| 282 |
+
if (visibleLinks.length === 0) return null;
|
| 283 |
+
|
| 284 |
+
return (
|
| 285 |
+
<lineSegments ref={lineRef} geometry={geometry}>
|
| 286 |
+
<lineBasicMaterial
|
| 287 |
+
vertexColors
|
| 288 |
+
transparent
|
| 289 |
+
opacity={0.3}
|
| 290 |
+
depthWrite={false}
|
| 291 |
+
/>
|
| 292 |
+
</lineSegments>
|
| 293 |
+
);
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
/**
|
| 297 |
+
* Main scene component
|
| 298 |
+
*/
|
| 299 |
+
function Scene({
|
| 300 |
+
nodes,
|
| 301 |
+
links,
|
| 302 |
+
onNodeClick,
|
| 303 |
+
onNodeHover,
|
| 304 |
+
selectedNodeId,
|
| 305 |
+
enabledEdgeTypes,
|
| 306 |
+
maxVisibleNodes = 500000,
|
| 307 |
+
maxVisibleEdges = 100000,
|
| 308 |
+
}: ForceDirectedGraph3DInstancedProps) {
|
| 309 |
+
return (
|
| 310 |
+
<>
|
| 311 |
+
<Edges
|
| 312 |
+
nodes={nodes}
|
| 313 |
+
links={links}
|
| 314 |
+
enabledEdgeTypes={enabledEdgeTypes}
|
| 315 |
+
maxVisible={maxVisibleEdges}
|
| 316 |
+
/>
|
| 317 |
+
<InstancedNodes
|
| 318 |
+
nodes={nodes}
|
| 319 |
+
selectedNodeId={selectedNodeId}
|
| 320 |
+
onNodeClick={onNodeClick}
|
| 321 |
+
onNodeHover={onNodeHover}
|
| 322 |
+
maxVisible={maxVisibleNodes}
|
| 323 |
+
/>
|
| 324 |
+
</>
|
| 325 |
+
);
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
/**
|
| 329 |
+
* Main component with Canvas wrapper
|
| 330 |
+
*/
|
| 331 |
+
export default function ForceDirectedGraph3DInstanced({
|
| 332 |
+
width,
|
| 333 |
+
height,
|
| 334 |
+
nodes,
|
| 335 |
+
links,
|
| 336 |
+
onNodeClick,
|
| 337 |
+
onNodeHover,
|
| 338 |
+
selectedNodeId,
|
| 339 |
+
enabledEdgeTypes,
|
| 340 |
+
showLabels = false,
|
| 341 |
+
maxVisibleNodes = 500000,
|
| 342 |
+
maxVisibleEdges = 100000,
|
| 343 |
+
}: ForceDirectedGraph3DInstancedProps) {
|
| 344 |
+
// Calculate bounds for camera positioning
|
| 345 |
+
const bounds = useMemo(() => {
|
| 346 |
+
if (nodes.length === 0) {
|
| 347 |
+
return { center: [0, 0, 0] as [number, number, number], radius: 100 };
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
let minX = Infinity, maxX = -Infinity;
|
| 351 |
+
let minY = Infinity, maxY = -Infinity;
|
| 352 |
+
let minZ = Infinity, maxZ = -Infinity;
|
| 353 |
+
|
| 354 |
+
// Sample nodes for bounds calculation if too many
|
| 355 |
+
const sampleNodes = nodes.length > 10000
|
| 356 |
+
? nodes.filter((_, i) => i % Math.ceil(nodes.length / 10000) === 0)
|
| 357 |
+
: nodes;
|
| 358 |
+
|
| 359 |
+
sampleNodes.forEach(node => {
|
| 360 |
+
const x = node.x || 0;
|
| 361 |
+
const y = node.y || 0;
|
| 362 |
+
const z = node.z || 0;
|
| 363 |
+
minX = Math.min(minX, x);
|
| 364 |
+
maxX = Math.max(maxX, x);
|
| 365 |
+
minY = Math.min(minY, y);
|
| 366 |
+
maxY = Math.max(maxY, y);
|
| 367 |
+
minZ = Math.min(minZ, z);
|
| 368 |
+
maxZ = Math.max(maxZ, z);
|
| 369 |
+
});
|
| 370 |
+
|
| 371 |
+
const center: [number, number, number] = [
|
| 372 |
+
(minX + maxX) / 2,
|
| 373 |
+
(minY + maxY) / 2,
|
| 374 |
+
(minZ + maxZ) / 2,
|
| 375 |
+
];
|
| 376 |
+
const radius = Math.max(
|
| 377 |
+
maxX - minX,
|
| 378 |
+
maxY - minY,
|
| 379 |
+
maxZ - minZ
|
| 380 |
+
) / 2 || 100;
|
| 381 |
+
|
| 382 |
+
return { center, radius };
|
| 383 |
+
}, [nodes]);
|
| 384 |
+
|
| 385 |
+
if (nodes.length === 0) {
|
| 386 |
+
return (
|
| 387 |
+
<div className="force-directed-graph-container">
|
| 388 |
+
<div className="graph-empty">No nodes to display</div>
|
| 389 |
+
</div>
|
| 390 |
+
);
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
return (
|
| 394 |
+
<div className="force-directed-graph-container" style={{ width, height }}>
|
| 395 |
+
<Canvas
|
| 396 |
+
dpr={[1, 1.5]}
|
| 397 |
+
gl={{
|
| 398 |
+
antialias: true,
|
| 399 |
+
alpha: false,
|
| 400 |
+
powerPreference: 'high-performance',
|
| 401 |
+
stencil: false,
|
| 402 |
+
depth: true,
|
| 403 |
+
}}
|
| 404 |
+
camera={{
|
| 405 |
+
position: [
|
| 406 |
+
bounds.center[0] + bounds.radius * 1.5,
|
| 407 |
+
bounds.center[1] + bounds.radius * 1.5,
|
| 408 |
+
bounds.center[2] + bounds.radius * 1.5,
|
| 409 |
+
],
|
| 410 |
+
fov: 45,
|
| 411 |
+
near: 0.1,
|
| 412 |
+
far: bounds.radius * 20,
|
| 413 |
+
}}
|
| 414 |
+
frameloop="demand"
|
| 415 |
+
>
|
| 416 |
+
<color attach="background" args={['#1a1a1a']} />
|
| 417 |
+
|
| 418 |
+
<OrbitControls
|
| 419 |
+
target={bounds.center}
|
| 420 |
+
enableDamping={true}
|
| 421 |
+
dampingFactor={0.05}
|
| 422 |
+
minDistance={bounds.radius * 0.1}
|
| 423 |
+
maxDistance={bounds.radius * 5}
|
| 424 |
+
makeDefault
|
| 425 |
+
/>
|
| 426 |
+
|
| 427 |
+
<ambientLight intensity={0.8} />
|
| 428 |
+
<directionalLight position={[1, 1, 1]} intensity={0.5} />
|
| 429 |
+
|
| 430 |
+
<Scene
|
| 431 |
+
nodes={nodes}
|
| 432 |
+
links={links}
|
| 433 |
+
onNodeClick={onNodeClick}
|
| 434 |
+
onNodeHover={onNodeHover}
|
| 435 |
+
selectedNodeId={selectedNodeId}
|
| 436 |
+
enabledEdgeTypes={enabledEdgeTypes}
|
| 437 |
+
maxVisibleNodes={maxVisibleNodes}
|
| 438 |
+
maxVisibleEdges={maxVisibleEdges}
|
| 439 |
+
width={width}
|
| 440 |
+
height={height}
|
| 441 |
+
/>
|
| 442 |
+
</Canvas>
|
| 443 |
+
|
| 444 |
+
{/* Performance info overlay */}
|
| 445 |
+
<div className="graph-performance-info" style={{
|
| 446 |
+
position: 'absolute',
|
| 447 |
+
top: '10px',
|
| 448 |
+
right: '10px',
|
| 449 |
+
padding: '8px 12px',
|
| 450 |
+
background: 'rgba(0,0,0,0.7)',
|
| 451 |
+
color: '#fff',
|
| 452 |
+
borderRadius: '4px',
|
| 453 |
+
fontSize: '12px',
|
| 454 |
+
fontFamily: 'monospace',
|
| 455 |
+
}}>
|
| 456 |
+
<div>Nodes: {nodes.length.toLocaleString()}</div>
|
| 457 |
+
<div>Edges: {links.length.toLocaleString()}</div>
|
| 458 |
+
{nodes.length > maxVisibleNodes && (
|
| 459 |
+
<div style={{ color: '#f59e0b' }}>
|
| 460 |
+
Showing top {maxVisibleNodes.toLocaleString()} by popularity
|
| 461 |
+
</div>
|
| 462 |
+
)}
|
| 463 |
+
</div>
|
| 464 |
+
</div>
|
| 465 |
+
);
|
| 466 |
+
}
|
| 467 |
+
|
frontend/src/pages/GraphPage.tsx
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import React, { useState, useEffect, useCallback } from 'react';
|
| 2 |
import ForceDirectedGraph, { EdgeType, GraphNode } from '../components/visualizations/ForceDirectedGraph';
|
| 3 |
import ForceDirectedGraph3D from '../components/visualizations/ForceDirectedGraph3D';
|
|
|
|
| 4 |
import ScatterPlot3D from '../components/visualizations/ScatterPlot3D';
|
| 5 |
import { fetchFamilyNetwork, fetchFullDerivativeNetwork, getAvailableEdgeTypes } from '../utils/api/graphApi';
|
| 6 |
import LoadingProgress from '../components/ui/LoadingProgress';
|
|
@@ -11,6 +12,9 @@ import './GraphPage.css';
|
|
| 11 |
|
| 12 |
const ALL_EDGE_TYPES: EdgeType[] = ['finetune', 'quantized', 'adapter', 'merge', 'parent'];
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
type ViewMode = 'graph' | 'embedding' | 'graph3d';
|
| 15 |
type GraphMode = 'family' | 'full';
|
| 16 |
|
|
@@ -451,16 +455,32 @@ export default function GraphPage() {
|
|
| 451 |
) : viewMode === 'graph3d' ? (
|
| 452 |
<>
|
| 453 |
<div style={{ width: '100%', height: '100%', position: 'relative' }}>
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
</div>
|
| 465 |
<EdgeTypeLegend
|
| 466 |
edgeTypes={ALL_EDGE_TYPES}
|
|
@@ -471,11 +491,11 @@ export default function GraphPage() {
|
|
| 471 |
<div className="graph-stats">
|
| 472 |
<div className="stat-item">
|
| 473 |
<span className="stat-label">Nodes:</span>
|
| 474 |
-
<span className="stat-value">{graphStats.nodes || nodes.length}</span>
|
| 475 |
</div>
|
| 476 |
<div className="stat-item">
|
| 477 |
<span className="stat-label">Edges:</span>
|
| 478 |
-
<span className="stat-value">{graphStats.edges || links.length}</span>
|
| 479 |
</div>
|
| 480 |
{graphStats.avg_degree && (
|
| 481 |
<div className="stat-item">
|
|
@@ -483,6 +503,12 @@ export default function GraphPage() {
|
|
| 483 |
<span className="stat-value">{graphStats.avg_degree.toFixed(2)}</span>
|
| 484 |
</div>
|
| 485 |
)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
</div>
|
| 487 |
)}
|
| 488 |
</>
|
|
|
|
| 1 |
import React, { useState, useEffect, useCallback } from 'react';
|
| 2 |
import ForceDirectedGraph, { EdgeType, GraphNode } from '../components/visualizations/ForceDirectedGraph';
|
| 3 |
import ForceDirectedGraph3D from '../components/visualizations/ForceDirectedGraph3D';
|
| 4 |
+
import ForceDirectedGraph3DInstanced from '../components/visualizations/ForceDirectedGraph3DInstanced';
|
| 5 |
import ScatterPlot3D from '../components/visualizations/ScatterPlot3D';
|
| 6 |
import { fetchFamilyNetwork, fetchFullDerivativeNetwork, getAvailableEdgeTypes } from '../utils/api/graphApi';
|
| 7 |
import LoadingProgress from '../components/ui/LoadingProgress';
|
|
|
|
| 12 |
|
| 13 |
const ALL_EDGE_TYPES: EdgeType[] = ['finetune', 'quantized', 'adapter', 'merge', 'parent'];
|
| 14 |
|
| 15 |
+
// Use instanced rendering threshold for large graphs
|
| 16 |
+
const INSTANCED_THRESHOLD = 10000;
|
| 17 |
+
|
| 18 |
type ViewMode = 'graph' | 'embedding' | 'graph3d';
|
| 19 |
type GraphMode = 'family' | 'full';
|
| 20 |
|
|
|
|
| 455 |
) : viewMode === 'graph3d' ? (
|
| 456 |
<>
|
| 457 |
<div style={{ width: '100%', height: '100%', position: 'relative' }}>
|
| 458 |
+
{/* Use instanced rendering for large graphs (>10k nodes) */}
|
| 459 |
+
{nodes.length > INSTANCED_THRESHOLD ? (
|
| 460 |
+
<ForceDirectedGraph3DInstanced
|
| 461 |
+
width={dimensions.width}
|
| 462 |
+
height={dimensions.height}
|
| 463 |
+
nodes={nodes}
|
| 464 |
+
links={links}
|
| 465 |
+
onNodeClick={handleNodeClick}
|
| 466 |
+
selectedNodeId={selectedNodeId}
|
| 467 |
+
enabledEdgeTypes={enabledEdgeTypes}
|
| 468 |
+
showLabels={false}
|
| 469 |
+
maxVisibleNodes={500000}
|
| 470 |
+
maxVisibleEdges={200000}
|
| 471 |
+
/>
|
| 472 |
+
) : (
|
| 473 |
+
<ForceDirectedGraph3D
|
| 474 |
+
width={dimensions.width}
|
| 475 |
+
height={dimensions.height}
|
| 476 |
+
nodes={nodes}
|
| 477 |
+
links={links}
|
| 478 |
+
onNodeClick={handleNodeClick}
|
| 479 |
+
selectedNodeId={selectedNodeId}
|
| 480 |
+
enabledEdgeTypes={enabledEdgeTypes}
|
| 481 |
+
showLabels={true}
|
| 482 |
+
/>
|
| 483 |
+
)}
|
| 484 |
</div>
|
| 485 |
<EdgeTypeLegend
|
| 486 |
edgeTypes={ALL_EDGE_TYPES}
|
|
|
|
| 491 |
<div className="graph-stats">
|
| 492 |
<div className="stat-item">
|
| 493 |
<span className="stat-label">Nodes:</span>
|
| 494 |
+
<span className="stat-value">{(graphStats.nodes || nodes.length).toLocaleString()}</span>
|
| 495 |
</div>
|
| 496 |
<div className="stat-item">
|
| 497 |
<span className="stat-label">Edges:</span>
|
| 498 |
+
<span className="stat-value">{(graphStats.edges || links.length).toLocaleString()}</span>
|
| 499 |
</div>
|
| 500 |
{graphStats.avg_degree && (
|
| 501 |
<div className="stat-item">
|
|
|
|
| 503 |
<span className="stat-value">{graphStats.avg_degree.toFixed(2)}</span>
|
| 504 |
</div>
|
| 505 |
)}
|
| 506 |
+
{nodes.length > INSTANCED_THRESHOLD && (
|
| 507 |
+
<div className="stat-item stat-info">
|
| 508 |
+
<span className="stat-label">Rendering:</span>
|
| 509 |
+
<span className="stat-value">Instanced (optimized)</span>
|
| 510 |
+
</div>
|
| 511 |
+
)}
|
| 512 |
</div>
|
| 513 |
)}
|
| 514 |
</>
|