midah commited on
Commit
4a759fa
·
1 Parent(s): e725546

Add force-directed graph visualization and embedding view toggle

Browse files

- Add ForceDirectedGraph component with color-coded edges for different relationship types (finetune, quantized, adapter, merge, parent)
- Add GraphPage with toggle between force-directed graph and 3D embedding view
- Add API utilities for fetching family network data
- Add 'Show All Models' toggle to disable sampling (show all models instead of 150k limit)
- Update backend to handle unlimited model requests (max_points >= 1M treated as unlimited)
- Add edge type filtering and legend in graph view
- Add color/size controls for embedding view
- Integrate GraphPage into main navigation

backend/api/main.py CHANGED
@@ -360,20 +360,23 @@ async def get_models(
360
  "returned_count": 0
361
  }
362
 
363
- if max_points is not None and len(filtered_df) > max_points:
 
 
 
364
  if 'library_name' in filtered_df.columns and filtered_df['library_name'].notna().any():
365
  # Sample proportionally by library, preserving all columns
366
  sampled_dfs = []
367
  for lib_name, group in filtered_df.groupby('library_name', group_keys=False):
368
- n_samples = max(1, int(max_points * len(group) / len(filtered_df)))
369
  sampled_dfs.append(group.sample(min(len(group), n_samples), random_state=42))
370
  filtered_df = pd.concat(sampled_dfs, ignore_index=True)
371
- if len(filtered_df) > max_points:
372
- filtered_df = filtered_df.sample(n=max_points, random_state=42).reset_index(drop=True)
373
  else:
374
  filtered_df = filtered_df.reset_index(drop=True)
375
  else:
376
- filtered_df = filtered_df.sample(n=max_points, random_state=42).reset_index(drop=True)
377
 
378
  # Determine which embeddings to use
379
  if use_graph_embeddings and combined_embeddings is not None:
 
360
  "returned_count": 0
361
  }
362
 
363
+ # Handle max_points: None means no limit, very large number also means no limit
364
+ effective_max_points = None if max_points is None or max_points >= 1000000 else max_points
365
+
366
+ if effective_max_points is not None and len(filtered_df) > effective_max_points:
367
  if 'library_name' in filtered_df.columns and filtered_df['library_name'].notna().any():
368
  # Sample proportionally by library, preserving all columns
369
  sampled_dfs = []
370
  for lib_name, group in filtered_df.groupby('library_name', group_keys=False):
371
+ n_samples = max(1, int(effective_max_points * len(group) / len(filtered_df)))
372
  sampled_dfs.append(group.sample(min(len(group), n_samples), random_state=42))
373
  filtered_df = pd.concat(sampled_dfs, ignore_index=True)
374
+ if len(filtered_df) > effective_max_points:
375
+ filtered_df = filtered_df.sample(n=effective_max_points, random_state=42).reset_index(drop=True)
376
  else:
377
  filtered_df = filtered_df.reset_index(drop=True)
378
  else:
379
+ filtered_df = filtered_df.sample(n=effective_max_points, random_state=42).reset_index(drop=True)
380
 
381
  # Determine which embeddings to use
382
  if use_graph_embeddings and combined_embeddings is not None:
backend/api/routes/models.py CHANGED
@@ -88,19 +88,22 @@ async def get_models(
88
  "returned_count": 0
89
  }
90
 
91
- if max_points is not None and len(filtered_df) > max_points:
 
 
 
92
  if 'library_name' in filtered_df.columns and filtered_df['library_name'].notna().any():
93
  sampled_dfs = []
94
  for lib_name, group in filtered_df.groupby('library_name', group_keys=False):
95
- n_samples = max(1, int(max_points * len(group) / len(filtered_df)))
96
  sampled_dfs.append(group.sample(min(len(group), n_samples), random_state=42))
97
  filtered_df = pd.concat(sampled_dfs, ignore_index=True)
98
- if len(filtered_df) > max_points:
99
- filtered_df = filtered_df.sample(n=max_points, random_state=42).reset_index(drop=True)
100
  else:
101
  filtered_df = filtered_df.reset_index(drop=True)
102
  else:
103
- filtered_df = filtered_df.sample(n=max_points, random_state=42).reset_index(drop=True)
104
 
105
  # Determine which embeddings to use
106
  if use_graph_embeddings and deps.combined_embeddings is not None:
 
88
  "returned_count": 0
89
  }
90
 
91
+ # Handle max_points: None means no limit, very large number also means no limit
92
+ effective_max_points = None if max_points is None or max_points >= 1000000 else max_points
93
+
94
+ if effective_max_points is not None and len(filtered_df) > effective_max_points:
95
  if 'library_name' in filtered_df.columns and filtered_df['library_name'].notna().any():
96
  sampled_dfs = []
97
  for lib_name, group in filtered_df.groupby('library_name', group_keys=False):
98
+ n_samples = max(1, int(effective_max_points * len(group) / len(filtered_df)))
99
  sampled_dfs.append(group.sample(min(len(group), n_samples), random_state=42))
100
  filtered_df = pd.concat(sampled_dfs, ignore_index=True)
101
+ if len(filtered_df) > effective_max_points:
102
+ filtered_df = filtered_df.sample(n=effective_max_points, random_state=42).reset_index(drop=True)
103
  else:
104
  filtered_df = filtered_df.reset_index(drop=True)
105
  else:
106
+ filtered_df = filtered_df.sample(n=effective_max_points, random_state=42).reset_index(drop=True)
107
 
108
  # Determine which embeddings to use
109
  if use_graph_embeddings and deps.combined_embeddings is not None:
frontend/src/App.css CHANGED
@@ -630,6 +630,34 @@
630
  font-weight: 500;
631
  }
632
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
633
  .control-btn {
634
  padding: 0.35rem 0.75rem;
635
  border: 1px solid var(--border-medium);
 
630
  font-weight: 500;
631
  }
632
 
633
+ .control-toggle {
634
+ display: flex;
635
+ align-items: center;
636
+ gap: 0.5rem;
637
+ cursor: pointer;
638
+ user-select: none;
639
+ }
640
+
641
+ .control-checkbox {
642
+ width: 16px;
643
+ height: 16px;
644
+ cursor: pointer;
645
+ accent-color: var(--accent-blue);
646
+ }
647
+
648
+ .control-toggle-label {
649
+ font-size: 0.85rem;
650
+ color: var(--text-primary);
651
+ white-space: nowrap;
652
+ }
653
+
654
+ .control-info {
655
+ font-size: 0.75rem;
656
+ color: var(--text-tertiary);
657
+ font-style: italic;
658
+ margin-left: 0.25rem;
659
+ }
660
+
661
  .control-btn {
662
  padding: 0.35rem 0.75rem;
663
  border: 1px solid var(--border-medium);
frontend/src/App.tsx CHANGED
@@ -10,6 +10,7 @@ import LiveModelCounter from './components/ui/LiveModelCounter';
10
  import ModelPopup from './components/ui/ModelPopup';
11
  import AnalyticsPage from './pages/AnalyticsPage';
12
  import FamiliesPage from './pages/FamiliesPage';
 
13
  // Types & Utils
14
  import { ModelPoint, Stats, SearchResult } from './types';
15
  import IntegratedSearch from './components/controls/IntegratedSearch';
@@ -82,6 +83,8 @@ function App() {
82
  const [semanticQueryModel] = useState<string | null>(null);
83
  const [showAnalytics, setShowAnalytics] = useState(false);
84
  const [showFamilies, setShowFamilies] = useState(false);
 
 
85
 
86
  const [, setSearchResults] = useState<SearchResult[]>([]);
87
  const [searchInput] = useState('');
@@ -195,8 +198,15 @@ function App() {
195
  params.append('search_query', searchQuery);
196
  }
197
 
198
- // Request up to 150k models for scatter plots, limit network graph for performance
199
- params.append('max_points', viewMode === 'network' ? '500' : '150000');
 
 
 
 
 
 
 
200
  // Add format parameter for MessagePack support
201
  params.append('format', 'msgpack');
202
 
@@ -270,7 +280,7 @@ function App() {
270
  setLoading(false);
271
  fetchDataAbortRef.current = null;
272
  }
273
- }, [minDownloads, minLikes, searchQuery, projectionMethod, baseModelsOnly, semanticSimilarityMode, semanticQueryModel, useGraphEmbeddings, selectedClusters, viewMode]);
274
 
275
  // Debounce times for different control types
276
  const SLIDER_DEBOUNCE_MS = 500; // Sliders need longer debounce
@@ -524,8 +534,9 @@ function App() {
524
  onClick={() => {
525
  setShowAnalytics(false);
526
  setShowFamilies(false);
 
527
  }}
528
- className={`nav-tab ${!showAnalytics && !showFamilies ? 'active' : ''}`}
529
  title="3D scatter plot of model embeddings — explore the model space interactively"
530
  >
531
  Visualization
@@ -534,6 +545,7 @@ function App() {
534
  onClick={() => {
535
  setShowAnalytics(false);
536
  setShowFamilies(true);
 
537
  }}
538
  className={`nav-tab ${showFamilies ? 'active' : ''}`}
539
  title="Browse model families and their lineage trees"
@@ -543,6 +555,18 @@ function App() {
543
  <button
544
  onClick={() => {
545
  setShowFamilies(false);
 
 
 
 
 
 
 
 
 
 
 
 
546
  setShowAnalytics(true);
547
  }}
548
  className={`nav-tab ${showAnalytics ? 'active' : ''}`}
@@ -570,6 +594,8 @@ function App() {
570
  <AnalyticsPage />
571
  ) : showFamilies ? (
572
  <FamiliesPage />
 
 
573
  ) : (
574
  <div className="visualization-layout">
575
  <div className="control-bar">
@@ -641,6 +667,26 @@ function App() {
641
  </span>
642
  </div>
643
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
644
  </div>
645
 
646
  {/* Right: Integrated Search */}
 
10
  import ModelPopup from './components/ui/ModelPopup';
11
  import AnalyticsPage from './pages/AnalyticsPage';
12
  import FamiliesPage from './pages/FamiliesPage';
13
+ import GraphPage from './pages/GraphPage';
14
  // Types & Utils
15
  import { ModelPoint, Stats, SearchResult } from './types';
16
  import IntegratedSearch from './components/controls/IntegratedSearch';
 
83
  const [semanticQueryModel] = useState<string | null>(null);
84
  const [showAnalytics, setShowAnalytics] = useState(false);
85
  const [showFamilies, setShowFamilies] = useState(false);
86
+ const [showGraph, setShowGraph] = useState(false);
87
+ const [showAllModels, setShowAllModels] = useState(false);
88
 
89
  const [, setSearchResults] = useState<SearchResult[]>([]);
90
  const [searchInput] = useState('');
 
198
  params.append('search_query', searchQuery);
199
  }
200
 
201
+ // Request models - either all or sampled
202
+ if (showAllModels) {
203
+ // Request all models (backend will handle if too many)
204
+ // Use a very large number to effectively mean "all"
205
+ params.append('max_points', '10000000'); // Effectively unlimited
206
+ } else {
207
+ // Request up to 150k models for scatter plots (sampled), limit network graph for performance
208
+ params.append('max_points', viewMode === 'network' ? '500' : '150000');
209
+ }
210
  // Add format parameter for MessagePack support
211
  params.append('format', 'msgpack');
212
 
 
280
  setLoading(false);
281
  fetchDataAbortRef.current = null;
282
  }
283
+ }, [minDownloads, minLikes, searchQuery, projectionMethod, baseModelsOnly, semanticSimilarityMode, semanticQueryModel, useGraphEmbeddings, selectedClusters, viewMode, showAllModels]);
284
 
285
  // Debounce times for different control types
286
  const SLIDER_DEBOUNCE_MS = 500; // Sliders need longer debounce
 
534
  onClick={() => {
535
  setShowAnalytics(false);
536
  setShowFamilies(false);
537
+ setShowGraph(false);
538
  }}
539
+ className={`nav-tab ${!showAnalytics && !showFamilies && !showGraph ? 'active' : ''}`}
540
  title="3D scatter plot of model embeddings — explore the model space interactively"
541
  >
542
  Visualization
 
545
  onClick={() => {
546
  setShowAnalytics(false);
547
  setShowFamilies(true);
548
+ setShowGraph(false);
549
  }}
550
  className={`nav-tab ${showFamilies ? 'active' : ''}`}
551
  title="Browse model families and their lineage trees"
 
555
  <button
556
  onClick={() => {
557
  setShowFamilies(false);
558
+ setShowAnalytics(false);
559
+ setShowGraph(true);
560
+ }}
561
+ className={`nav-tab ${showGraph ? 'active' : ''}`}
562
+ title="Force-directed graph showing model relationships and derivatives"
563
+ >
564
+ Graph
565
+ </button>
566
+ <button
567
+ onClick={() => {
568
+ setShowFamilies(false);
569
+ setShowGraph(false);
570
  setShowAnalytics(true);
571
  }}
572
  className={`nav-tab ${showAnalytics ? 'active' : ''}`}
 
594
  <AnalyticsPage />
595
  ) : showFamilies ? (
596
  <FamiliesPage />
597
+ ) : showGraph ? (
598
+ <GraphPage />
599
  ) : (
600
  <div className="visualization-layout">
601
  <div className="control-bar">
 
667
  </span>
668
  </div>
669
 
670
+ <span className="control-divider" />
671
+
672
+ {/* Show All Models Toggle */}
673
+ <div className="control-group">
674
+ <label className="control-toggle" title="Show all models (no sampling). When disabled, shows up to 150k models sampled proportionally by library, prioritizing base models and popular models.">
675
+ <input
676
+ type="checkbox"
677
+ checked={showAllModels}
678
+ onChange={(e) => setShowAllModels(e.target.checked)}
679
+ className="control-checkbox"
680
+ />
681
+ <span className="control-toggle-label">Show All Models</span>
682
+ </label>
683
+ {!showAllModels && (
684
+ <span className="control-info" title="Sampling strategy: Includes all base models, then adds popular derived models and diverse samples across libraries proportionally. Max 150k models for performance.">
685
+ (Sampled)
686
+ </span>
687
+ )}
688
+ </div>
689
+
690
  </div>
691
 
692
  {/* Right: Integrated Search */}
frontend/src/components/visualizations/ForceDirectedGraph.css ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .force-directed-graph-container {
2
+ position: relative;
3
+ width: 100%;
4
+ height: 100%;
5
+ background: var(--bg-secondary, #1a1a1a);
6
+ border-radius: 8px;
7
+ overflow: hidden;
8
+ }
9
+
10
+ .force-directed-graph {
11
+ width: 100%;
12
+ height: 100%;
13
+ display: block;
14
+ }
15
+
16
+ .force-directed-graph .links line {
17
+ transition: stroke-opacity 0.2s, stroke-width 0.2s;
18
+ }
19
+
20
+ .force-directed-graph .nodes circle {
21
+ transition: stroke-width 0.2s, r 0.2s;
22
+ }
23
+
24
+ .force-directed-graph .nodes circle:hover {
25
+ filter: brightness(1.2);
26
+ }
27
+
28
+ .force-directed-graph .labels text {
29
+ user-select: none;
30
+ pointer-events: none;
31
+ }
32
+
33
+ /* Edge type legend */
34
+ .edge-type-legend {
35
+ position: absolute;
36
+ top: 16px;
37
+ right: 16px;
38
+ background: var(--bg-primary, #0a0a0a);
39
+ border: 1px solid var(--border-color, #333);
40
+ border-radius: 8px;
41
+ padding: 12px;
42
+ font-size: 12px;
43
+ z-index: 10;
44
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.3);
45
+ }
46
+
47
+ .edge-type-legend h4 {
48
+ margin: 0 0 8px 0;
49
+ font-size: 13px;
50
+ font-weight: 600;
51
+ color: var(--text-primary, #fff);
52
+ }
53
+
54
+ .edge-type-item {
55
+ display: flex;
56
+ align-items: center;
57
+ gap: 8px;
58
+ margin-bottom: 6px;
59
+ cursor: pointer;
60
+ padding: 4px;
61
+ border-radius: 4px;
62
+ transition: background-color 0.2s;
63
+ }
64
+
65
+ .edge-type-item:hover {
66
+ background-color: var(--bg-secondary, #2a2a2a);
67
+ }
68
+
69
+ .edge-type-item:last-child {
70
+ margin-bottom: 0;
71
+ }
72
+
73
+ .edge-type-color {
74
+ width: 20px;
75
+ height: 3px;
76
+ border-radius: 2px;
77
+ flex-shrink: 0;
78
+ }
79
+
80
+ .edge-type-label {
81
+ flex: 1;
82
+ color: var(--text-secondary, #ccc);
83
+ text-transform: capitalize;
84
+ }
85
+
86
+ .edge-type-item.disabled .edge-type-color {
87
+ opacity: 0.3;
88
+ }
89
+
90
+ .edge-type-item.disabled .edge-type-label {
91
+ opacity: 0.5;
92
+ text-decoration: line-through;
93
+ }
94
+
95
+ /* Graph controls */
96
+ .graph-controls {
97
+ position: absolute;
98
+ bottom: 16px;
99
+ left: 16px;
100
+ display: flex;
101
+ gap: 8px;
102
+ z-index: 10;
103
+ }
104
+
105
+ .graph-control-btn {
106
+ background: var(--bg-primary, #0a0a0a);
107
+ border: 1px solid var(--border-color, #333);
108
+ border-radius: 6px;
109
+ padding: 8px 12px;
110
+ color: var(--text-primary, #fff);
111
+ font-size: 12px;
112
+ cursor: pointer;
113
+ transition: all 0.2s;
114
+ }
115
+
116
+ .graph-control-btn:hover {
117
+ background: var(--bg-secondary, #2a2a2a);
118
+ border-color: var(--border-hover, #555);
119
+ }
120
+
121
+ .graph-control-btn:active {
122
+ transform: scale(0.95);
123
+ }
124
+
125
+ /* Loading state */
126
+ .graph-loading {
127
+ display: flex;
128
+ align-items: center;
129
+ justify-content: center;
130
+ width: 100%;
131
+ height: 100%;
132
+ color: var(--text-secondary, #999);
133
+ font-size: 14px;
134
+ }
135
+
136
+ /* Empty state */
137
+ .graph-empty {
138
+ display: flex;
139
+ align-items: center;
140
+ justify-content: center;
141
+ width: 100%;
142
+ height: 100%;
143
+ color: var(--text-secondary, #999);
144
+ font-size: 14px;
145
+ }
frontend/src/components/visualizations/ForceDirectedGraph.tsx ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Force-directed graph visualization showing model relationships.
3
+ * Displays different types of derivatives (finetunes, adapters, quantizations, merges)
4
+ * with color-coded edges and interactive nodes.
5
+ */
6
+ import React, { useMemo, useRef, useEffect, useState, useCallback } from 'react';
7
+ import * as d3 from 'd3';
8
+ import './ForceDirectedGraph.css';
9
+
10
+ export type EdgeType = 'finetune' | 'quantized' | 'adapter' | 'merge' | 'parent';
11
+
12
+ export interface GraphNode {
13
+ id: string;
14
+ title: string;
15
+ downloads: number;
16
+ likes: number;
17
+ library: string;
18
+ pipeline: string;
19
+ x?: number;
20
+ y?: number;
21
+ fx?: number | null;
22
+ fy?: number | null;
23
+ }
24
+
25
+ export interface GraphLink {
26
+ source: string | GraphNode;
27
+ target: string | GraphNode;
28
+ edge_type: EdgeType;
29
+ edge_types?: EdgeType[];
30
+ change_in_downloads?: number;
31
+ change_in_likes?: number;
32
+ }
33
+
34
+ interface ProcessedLink {
35
+ source: string;
36
+ target: string;
37
+ edge_type: EdgeType;
38
+ edge_types?: EdgeType[];
39
+ change_in_downloads?: number;
40
+ change_in_likes?: number;
41
+ }
42
+
43
+ export interface ForceDirectedGraphProps {
44
+ width: number;
45
+ height: number;
46
+ nodes: GraphNode[];
47
+ links: GraphLink[];
48
+ onNodeClick?: (node: GraphNode) => void;
49
+ onNodeHover?: (node: GraphNode | null) => void;
50
+ selectedNodeId?: string | null;
51
+ enabledEdgeTypes?: Set<EdgeType>;
52
+ showLabels?: boolean;
53
+ }
54
+
55
+ // Color scheme for different edge types
56
+ const EDGE_COLORS: Record<EdgeType, string> = {
57
+ finetune: '#3b82f6', // Blue - fine-tuning
58
+ quantized: '#10b981', // Green - quantization
59
+ adapter: '#f59e0b', // Orange - adapters
60
+ merge: '#8b5cf6', // Purple - merges
61
+ parent: '#6b7280', // Gray - generic parent
62
+ };
63
+
64
+ const EDGE_STROKE_WIDTH: Record<EdgeType, number> = {
65
+ finetune: 2,
66
+ quantized: 1.5,
67
+ adapter: 1.5,
68
+ merge: 2,
69
+ parent: 1,
70
+ };
71
+
72
+ export default function ForceDirectedGraph({
73
+ width,
74
+ height,
75
+ nodes,
76
+ links,
77
+ onNodeClick,
78
+ onNodeHover,
79
+ selectedNodeId,
80
+ enabledEdgeTypes,
81
+ showLabels = true,
82
+ }: ForceDirectedGraphProps) {
83
+ const svgRef = useRef<SVGSVGElement>(null);
84
+ const simulationRef = useRef<d3.Simulation<GraphNode, GraphLink> | null>(null);
85
+ const [hoveredNodeId, setHoveredNodeId] = useState<string | null>(null);
86
+
87
+ // Filter links based on enabled edge types and ensure source/target are node IDs
88
+ const filteredLinks = useMemo((): ProcessedLink[] => {
89
+ let filtered = links;
90
+ if (enabledEdgeTypes && enabledEdgeTypes.size > 0) {
91
+ filtered = links.filter(link => {
92
+ const linkTypes = link.edge_types || [link.edge_type];
93
+ return linkTypes.some(type => enabledEdgeTypes.has(type));
94
+ });
95
+ }
96
+ // Ensure source and target are strings (node IDs)
97
+ return filtered.map(link => ({
98
+ ...link,
99
+ source: typeof link.source === 'string' ? link.source : (link.source as GraphNode).id,
100
+ target: typeof link.target === 'string' ? link.target : (link.target as GraphNode).id,
101
+ }));
102
+ }, [links, enabledEdgeTypes]);
103
+
104
+ // Filter nodes to only include those connected by filtered links
105
+ const filteredNodes = useMemo(() => {
106
+ if (!enabledEdgeTypes || enabledEdgeTypes.size === 0) {
107
+ return nodes;
108
+ }
109
+ const connectedNodeIds = new Set<string>();
110
+ filteredLinks.forEach(link => {
111
+ const sourceId = typeof link.source === 'string' ? link.source : link.source.id;
112
+ const targetId = typeof link.target === 'string' ? link.target : link.target.id;
113
+ connectedNodeIds.add(sourceId);
114
+ connectedNodeIds.add(targetId);
115
+ });
116
+ return nodes.filter(node => connectedNodeIds.has(node.id));
117
+ }, [nodes, filteredLinks, enabledEdgeTypes]);
118
+
119
+ useEffect(() => {
120
+ if (!svgRef.current || filteredNodes.length === 0) return;
121
+
122
+ const svg = d3.select(svgRef.current);
123
+ svg.selectAll('*').remove();
124
+
125
+ // Create container
126
+ const g = svg.append('g');
127
+
128
+ // Set up zoom behavior
129
+ const zoom = d3.zoom<SVGSVGElement, unknown>()
130
+ .scaleExtent([0.1, 4])
131
+ .on('zoom', (event) => {
132
+ g.attr('transform', event.transform);
133
+ });
134
+
135
+ svg.call(zoom as any);
136
+
137
+ // Set initial transform
138
+ const initialTransform = d3.zoomIdentity.translate(width / 2, height / 2).scale(0.8);
139
+ svg.call(zoom.transform as any, initialTransform);
140
+
141
+ // Set up force simulation
142
+ const simulation = d3
143
+ .forceSimulation<GraphNode>(filteredNodes)
144
+ .force(
145
+ 'link',
146
+ d3
147
+ .forceLink<GraphNode, ProcessedLink>(filteredLinks)
148
+ .id((d) => d.id)
149
+ .distance((d) => {
150
+ // Adjust distance based on edge type
151
+ const link = d as unknown as ProcessedLink;
152
+ const edgeType = link.edge_type;
153
+ switch (edgeType) {
154
+ case 'merge':
155
+ return 120; // Merges are more distinct
156
+ case 'finetune':
157
+ return 80; // Fine-tunes are closer
158
+ case 'quantized':
159
+ return 60; // Quantizations are very close
160
+ case 'adapter':
161
+ return 70;
162
+ default:
163
+ return 100;
164
+ }
165
+ })
166
+ )
167
+ .force('charge', d3.forceManyBody().strength(-300))
168
+ .force('center', d3.forceCenter(0, 0))
169
+ .force('collision', d3.forceCollide().radius((d) => {
170
+ // Node size based on downloads
171
+ const downloads = d.downloads || 0;
172
+ return 5 + Math.sqrt(downloads) / 200;
173
+ }));
174
+
175
+ simulationRef.current = simulation;
176
+
177
+ // Create arrow markers for directed edges
178
+ const defs = svg.append('defs');
179
+ Object.entries(EDGE_COLORS).forEach(([type, color]) => {
180
+ defs
181
+ .append('marker')
182
+ .attr('id', `arrow-${type}`)
183
+ .attr('viewBox', '0 -5 10 10')
184
+ .attr('refX', 15)
185
+ .attr('refY', 0)
186
+ .attr('markerWidth', 6)
187
+ .attr('markerHeight', 6)
188
+ .attr('orient', 'auto')
189
+ .append('path')
190
+ .attr('d', 'M0,-5L10,0L0,5')
191
+ .attr('fill', color);
192
+ });
193
+
194
+ // Create links
195
+ const link = g
196
+ .append('g')
197
+ .attr('class', 'links')
198
+ .selectAll('line')
199
+ .data(filteredLinks)
200
+ .join('line')
201
+ .attr('stroke', (d) => {
202
+ const edgeType = d.edge_type;
203
+ return EDGE_COLORS[edgeType] || EDGE_COLORS.parent;
204
+ })
205
+ .attr('stroke-width', (d) => {
206
+ const edgeType = d.edge_type;
207
+ return EDGE_STROKE_WIDTH[edgeType] || 1;
208
+ })
209
+ .attr('stroke-opacity', 0.6)
210
+ .attr('marker-end', (d) => {
211
+ const edgeType = d.edge_type;
212
+ return `url(#arrow-${edgeType})`;
213
+ })
214
+ .style('cursor', 'pointer')
215
+ .on('mouseenter', function(event, d) {
216
+ d3.select(this).attr('stroke-opacity', 1).attr('stroke-width', (d) => {
217
+ const edgeType = d.edge_type;
218
+ return (EDGE_STROKE_WIDTH[edgeType] || 1) + 1;
219
+ });
220
+ })
221
+ .on('mouseleave', function(event, d) {
222
+ d3.select(this).attr('stroke-opacity', 0.6).attr('stroke-width', (d) => {
223
+ const edgeType = d.edge_type;
224
+ return EDGE_STROKE_WIDTH[edgeType] || 1;
225
+ });
226
+ });
227
+
228
+ // Create nodes
229
+ const node = g
230
+ .append('g')
231
+ .attr('class', 'nodes')
232
+ .selectAll('circle')
233
+ .data(filteredNodes)
234
+ .join('circle')
235
+ .attr('r', (d) => {
236
+ const downloads = d.downloads || 0;
237
+ return 3 + Math.sqrt(downloads) / 200;
238
+ })
239
+ .attr('fill', (d) => {
240
+ // Color by library if available
241
+ if (d.library) {
242
+ const colors = d3.schemeCategory10;
243
+ const libraries = Array.from(new Set(filteredNodes.map(n => n.library).filter(Boolean)));
244
+ const libIndex = libraries.indexOf(d.library);
245
+ return colors[libIndex % colors.length];
246
+ }
247
+ return '#6b7280';
248
+ })
249
+ .attr('stroke', (d) => {
250
+ if (selectedNodeId === d.id) {
251
+ return '#ef4444';
252
+ }
253
+ if (hoveredNodeId === d.id) {
254
+ return '#fbbf24';
255
+ }
256
+ return '#fff';
257
+ })
258
+ .attr('stroke-width', (d) => {
259
+ if (selectedNodeId === d.id) {
260
+ return 3;
261
+ }
262
+ if (hoveredNodeId === d.id) {
263
+ return 2;
264
+ }
265
+ return 1.5;
266
+ })
267
+ .style('cursor', 'pointer')
268
+ .on('click', (event, d) => {
269
+ event.stopPropagation();
270
+ if (onNodeClick) onNodeClick(d);
271
+ })
272
+ .on('mouseenter', (event, d) => {
273
+ setHoveredNodeId(d.id);
274
+ if (onNodeHover) onNodeHover(d);
275
+ d3.select(event.currentTarget as SVGCircleElement)
276
+ .attr('stroke-width', selectedNodeId === d.id ? 3 : 2.5);
277
+ })
278
+ .on('mouseleave', (event, d) => {
279
+ setHoveredNodeId(null);
280
+ if (onNodeHover) onNodeHover(null);
281
+ d3.select(event.currentTarget as SVGCircleElement)
282
+ .attr('stroke-width', selectedNodeId === d.id ? 3 : 1.5);
283
+ })
284
+ .call(drag(simulation) as any);
285
+
286
+ // Create labels
287
+ if (showLabels) {
288
+ const label = g
289
+ .append('g')
290
+ .attr('class', 'labels')
291
+ .selectAll('text')
292
+ .data(filteredNodes.filter((d) => {
293
+ // Show labels for nodes with high downloads or selected/hovered nodes
294
+ return (d.downloads || 0) > 10000 || selectedNodeId === d.id || hoveredNodeId === d.id;
295
+ }))
296
+ .join('text')
297
+ .text((d) => d.title || d.id.split('/').pop() || d.id)
298
+ .attr('font-size', '10px')
299
+ .attr('dx', 8)
300
+ .attr('dy', 4)
301
+ .attr('fill', '#fff')
302
+ .attr('stroke', '#000')
303
+ .attr('stroke-width', '0.5px')
304
+ .attr('paint-order', 'stroke')
305
+ .style('pointer-events', 'none');
306
+ }
307
+
308
+ // Update positions on simulation tick
309
+ simulation.on('tick', () => {
310
+ link
311
+ .attr('x1', (d) => {
312
+ const sourceId = (d as unknown as ProcessedLink).source;
313
+ const source = filteredNodes.find(n => n.id === sourceId);
314
+ return source?.x || 0;
315
+ })
316
+ .attr('y1', (d) => {
317
+ const sourceId = (d as unknown as ProcessedLink).source;
318
+ const source = filteredNodes.find(n => n.id === sourceId);
319
+ return source?.y || 0;
320
+ })
321
+ .attr('x2', (d) => {
322
+ const targetId = (d as unknown as ProcessedLink).target;
323
+ const target = filteredNodes.find(n => n.id === targetId);
324
+ return target?.x || 0;
325
+ })
326
+ .attr('y2', (d) => {
327
+ const targetId = (d as unknown as ProcessedLink).target;
328
+ const target = filteredNodes.find(n => n.id === targetId);
329
+ return target?.y || 0;
330
+ });
331
+
332
+ node.attr('cx', (d) => d.x || 0).attr('cy', (d) => d.y || 0);
333
+
334
+ if (showLabels) {
335
+ const label = g.selectAll<SVGTextElement, GraphNode>('.labels text');
336
+ label.attr('x', (d) => d.x || 0).attr('y', (d) => d.y || 0);
337
+ }
338
+ });
339
+
340
+ // Cleanup
341
+ return () => {
342
+ simulation.stop();
343
+ };
344
+ }, [filteredNodes, filteredLinks, width, height, onNodeClick, onNodeHover, selectedNodeId, hoveredNodeId, showLabels]);
345
+
346
+ return (
347
+ <div className="force-directed-graph-container">
348
+ <svg ref={svgRef} width={width} height={height} className="force-directed-graph" />
349
+ </div>
350
+ );
351
+ }
352
+
353
+ function drag(simulation: d3.Simulation<GraphNode, undefined>) {
354
+ function dragstarted(event: d3.D3DragEvent<SVGCircleElement, GraphNode, GraphNode>) {
355
+ if (!event.active) simulation.alphaTarget(0.3).restart();
356
+ event.subject.fx = event.subject.x;
357
+ event.subject.fy = event.subject.y;
358
+ }
359
+
360
+ function dragged(event: d3.D3DragEvent<SVGCircleElement, GraphNode, GraphNode>) {
361
+ event.subject.fx = event.x;
362
+ event.subject.fy = event.y;
363
+ }
364
+
365
+ function dragended(event: d3.D3DragEvent<SVGCircleElement, GraphNode, GraphNode>) {
366
+ if (!event.active) simulation.alphaTarget(0);
367
+ event.subject.fx = null;
368
+ event.subject.fy = null;
369
+ }
370
+
371
+ return d3
372
+ .drag<SVGCircleElement, GraphNode>()
373
+ .on('start', dragstarted)
374
+ .on('drag', dragged)
375
+ .on('end', dragended);
376
+ }
frontend/src/pages/GraphPage.css ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .graph-page {
2
+ padding: 24px;
3
+ max-width: 1400px;
4
+ margin: 0 auto;
5
+ }
6
+
7
+ .page-header {
8
+ margin-bottom: 24px;
9
+ }
10
+
11
+ .page-header h1 {
12
+ margin: 0 0 8px 0;
13
+ font-size: 28px;
14
+ font-weight: 600;
15
+ color: var(--text-primary, #fff);
16
+ }
17
+
18
+ .page-description {
19
+ margin: 0;
20
+ color: var(--text-secondary, #999);
21
+ font-size: 14px;
22
+ line-height: 1.5;
23
+ }
24
+
25
+ .graph-controls-panel {
26
+ background: var(--bg-secondary, #1a1a1a);
27
+ border: 1px solid var(--border-color, #333);
28
+ border-radius: 8px;
29
+ padding: 16px;
30
+ margin-bottom: 16px;
31
+ }
32
+
33
+ .search-section {
34
+ position: relative;
35
+ margin-bottom: 16px;
36
+ }
37
+
38
+ .graph-search-input {
39
+ width: 100%;
40
+ background: var(--bg-primary, #0a0a0a);
41
+ border: 1px solid var(--border-color, #333);
42
+ border-radius: 6px;
43
+ padding: 10px 14px;
44
+ color: var(--text-primary, #fff);
45
+ font-size: 14px;
46
+ font-family: inherit;
47
+ }
48
+
49
+ .graph-search-input:focus {
50
+ outline: none;
51
+ border-color: var(--accent-color, #3b82f6);
52
+ }
53
+
54
+ .graph-search-input::placeholder {
55
+ color: var(--text-secondary, #999);
56
+ }
57
+
58
+ .search-results-dropdown {
59
+ position: absolute;
60
+ top: 100%;
61
+ left: 0;
62
+ right: 0;
63
+ background: var(--bg-primary, #0a0a0a);
64
+ border: 1px solid var(--border-color, #333);
65
+ border-radius: 8px;
66
+ margin-top: 4px;
67
+ max-height: 300px;
68
+ overflow-y: auto;
69
+ z-index: 100;
70
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.3);
71
+ }
72
+
73
+ .search-result-item {
74
+ padding: 12px 16px;
75
+ cursor: pointer;
76
+ border-bottom: 1px solid var(--border-color, #333);
77
+ transition: background-color 0.2s;
78
+ }
79
+
80
+ .search-result-item:last-child {
81
+ border-bottom: none;
82
+ }
83
+
84
+ .search-result-item:hover {
85
+ background-color: var(--bg-secondary, #2a2a2a);
86
+ }
87
+
88
+ .result-title {
89
+ font-weight: 500;
90
+ color: var(--text-primary, #fff);
91
+ margin-bottom: 4px;
92
+ }
93
+
94
+ .result-meta {
95
+ display: flex;
96
+ gap: 12px;
97
+ font-size: 12px;
98
+ color: var(--text-secondary, #999);
99
+ }
100
+
101
+ .graph-settings {
102
+ display: flex;
103
+ gap: 24px;
104
+ flex-wrap: wrap;
105
+ }
106
+
107
+ .setting-group {
108
+ display: flex;
109
+ align-items: center;
110
+ gap: 8px;
111
+ }
112
+
113
+ .setting-group label {
114
+ font-size: 13px;
115
+ color: var(--text-secondary, #ccc);
116
+ font-weight: 500;
117
+ }
118
+
119
+ .depth-input {
120
+ background: var(--bg-primary, #0a0a0a);
121
+ border: 1px solid var(--border-color, #333);
122
+ border-radius: 6px;
123
+ padding: 6px 10px;
124
+ color: var(--text-primary, #fff);
125
+ font-size: 13px;
126
+ width: 80px;
127
+ }
128
+
129
+ .depth-input:focus {
130
+ outline: none;
131
+ border-color: var(--accent-color, #3b82f6);
132
+ }
133
+
134
+ .view-mode-select,
135
+ .color-by-select,
136
+ .size-by-select {
137
+ background: var(--bg-primary, #0a0a0a);
138
+ border: 1px solid var(--border-color, #333);
139
+ border-radius: 6px;
140
+ padding: 6px 10px;
141
+ color: var(--text-primary, #fff);
142
+ font-size: 13px;
143
+ cursor: pointer;
144
+ }
145
+
146
+ .view-mode-select:focus,
147
+ .color-by-select:focus,
148
+ .size-by-select:focus {
149
+ outline: none;
150
+ border-color: var(--accent-color, #3b82f6);
151
+ }
152
+
153
+ .embedding-info {
154
+ position: absolute;
155
+ bottom: 16px;
156
+ left: 16px;
157
+ background: var(--bg-primary, #0a0a0a);
158
+ border: 1px solid var(--border-color, #333);
159
+ border-radius: 8px;
160
+ padding: 12px 16px;
161
+ display: flex;
162
+ gap: 16px;
163
+ font-size: 12px;
164
+ z-index: 10;
165
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.3);
166
+ }
167
+
168
+ .info-item {
169
+ display: flex;
170
+ gap: 6px;
171
+ }
172
+
173
+ .info-label {
174
+ color: var(--text-secondary, #999);
175
+ }
176
+
177
+ .info-value {
178
+ color: var(--text-primary, #fff);
179
+ font-weight: 600;
180
+ }
181
+
182
+ .current-model {
183
+ font-family: monospace;
184
+ font-size: 13px;
185
+ color: var(--text-primary, #fff);
186
+ background: var(--bg-primary, #0a0a0a);
187
+ padding: 6px 10px;
188
+ border-radius: 6px;
189
+ border: 1px solid var(--border-color, #333);
190
+ max-width: 400px;
191
+ overflow: hidden;
192
+ text-overflow: ellipsis;
193
+ white-space: nowrap;
194
+ }
195
+
196
+ .graph-container {
197
+ position: relative;
198
+ width: 100%;
199
+ min-height: 600px;
200
+ background: var(--bg-secondary, #1a1a1a);
201
+ border: 1px solid var(--border-color, #333);
202
+ border-radius: 8px;
203
+ overflow: hidden;
204
+ }
205
+
206
+ .graph-error {
207
+ display: flex;
208
+ flex-direction: column;
209
+ align-items: center;
210
+ justify-content: center;
211
+ padding: 48px;
212
+ color: var(--text-error, #ef4444);
213
+ text-align: center;
214
+ }
215
+
216
+ .error-hint {
217
+ margin-top: 8px;
218
+ color: var(--text-secondary, #999);
219
+ font-size: 14px;
220
+ }
221
+
222
+ .graph-empty {
223
+ display: flex;
224
+ flex-direction: column;
225
+ align-items: center;
226
+ justify-content: center;
227
+ padding: 48px;
228
+ color: var(--text-secondary, #999);
229
+ text-align: center;
230
+ }
231
+
232
+ .empty-hint {
233
+ margin-top: 16px;
234
+ font-size: 13px;
235
+ }
236
+
237
+ .empty-hint code {
238
+ background: var(--bg-primary, #0a0a0a);
239
+ padding: 2px 6px;
240
+ border-radius: 4px;
241
+ font-family: monospace;
242
+ color: var(--accent-color, #3b82f6);
243
+ }
244
+
245
+ .graph-stats {
246
+ position: absolute;
247
+ bottom: 16px;
248
+ left: 16px;
249
+ background: var(--bg-primary, #0a0a0a);
250
+ border: 1px solid var(--border-color, #333);
251
+ border-radius: 8px;
252
+ padding: 12px 16px;
253
+ display: flex;
254
+ gap: 16px;
255
+ font-size: 12px;
256
+ z-index: 10;
257
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.3);
258
+ }
259
+
260
+ .stat-item {
261
+ display: flex;
262
+ gap: 6px;
263
+ }
264
+
265
+ .stat-label {
266
+ color: var(--text-secondary, #999);
267
+ }
268
+
269
+ .stat-value {
270
+ color: var(--text-primary, #fff);
271
+ font-weight: 600;
272
+ }
273
+
274
+ /* Responsive */
275
+ @media (max-width: 768px) {
276
+ .graph-page {
277
+ padding: 16px;
278
+ }
279
+
280
+ .graph-settings {
281
+ flex-direction: column;
282
+ gap: 12px;
283
+ }
284
+
285
+ .graph-stats {
286
+ flex-direction: column;
287
+ gap: 8px;
288
+ }
289
+ }
frontend/src/pages/GraphPage.tsx ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useState, useEffect, useCallback } from 'react';
2
+ import ForceDirectedGraph, { EdgeType, GraphNode } from '../components/visualizations/ForceDirectedGraph';
3
+ import ScatterPlot3D from '../components/visualizations/ScatterPlot3D';
4
+ import { fetchFamilyNetwork, getAvailableEdgeTypes } from '../utils/api/graphApi';
5
+ import LoadingProgress from '../components/ui/LoadingProgress';
6
+ import { ModelPoint } from '../types';
7
+ // Simple search input for graph page
8
+ import { API_BASE } from '../config/api';
9
+ import './GraphPage.css';
10
+
11
+ const ALL_EDGE_TYPES: EdgeType[] = ['finetune', 'quantized', 'adapter', 'merge', 'parent'];
12
+
13
+ type ViewMode = 'graph' | 'embedding';
14
+
15
+ export default function GraphPage() {
16
+ const [modelId, setModelId] = useState<string>('');
17
+ const [viewMode, setViewMode] = useState<ViewMode>('graph');
18
+ const [nodes, setNodes] = useState<GraphNode[]>([]);
19
+ const [links, setLinks] = useState<any[]>([]);
20
+ const [embeddingData, setEmbeddingData] = useState<ModelPoint[]>([]);
21
+ const [loading, setLoading] = useState(false);
22
+ const [loadingEmbedding, setLoadingEmbedding] = useState(false);
23
+ const [error, setError] = useState<string | null>(null);
24
+ const [selectedNodeId, setSelectedNodeId] = useState<string | null>(null);
25
+ const [selectedModel, setSelectedModel] = useState<ModelPoint | null>(null);
26
+ const [enabledEdgeTypes, setEnabledEdgeTypes] = useState<Set<EdgeType>>(
27
+ new Set(ALL_EDGE_TYPES)
28
+ );
29
+ const [maxDepth, setMaxDepth] = useState<number | undefined>(5);
30
+ const [graphStats, setGraphStats] = useState<any>(null);
31
+ const [searchResults, setSearchResults] = useState<any[]>([]);
32
+ const [showSearchResults, setShowSearchResults] = useState(false);
33
+ const [colorBy, setColorBy] = useState<string>('library_name');
34
+ const [sizeBy, setSizeBy] = useState<string>('downloads');
35
+
36
+ // Load graph when modelId or maxDepth changes
37
+ useEffect(() => {
38
+ if (!modelId.trim()) {
39
+ setNodes([]);
40
+ setLinks([]);
41
+ setEmbeddingData([]);
42
+ setGraphStats(null);
43
+ return;
44
+ }
45
+
46
+ const loadGraph = async () => {
47
+ setLoading(true);
48
+ setError(null);
49
+ try {
50
+ // Load all edge types initially, filtering happens client-side
51
+ const data = await fetchFamilyNetwork(modelId, {
52
+ maxDepth,
53
+ edgeTypes: undefined, // Get all types, filter client-side
54
+ includeEdgeAttributes: true,
55
+ });
56
+
57
+ setNodes(data.nodes || []);
58
+ setLinks(data.links || []);
59
+ setGraphStats(data.statistics);
60
+
61
+ // Update enabled edge types based on available types (only on first load)
62
+ if (data.links && data.links.length > 0) {
63
+ const availableTypes = getAvailableEdgeTypes(data.links);
64
+ // If no types are currently enabled, enable all available
65
+ if (enabledEdgeTypes.size === 0 && availableTypes.size > 0) {
66
+ setEnabledEdgeTypes(new Set(availableTypes));
67
+ }
68
+ }
69
+ } catch (err: any) {
70
+ setError(err.message || 'Failed to load graph');
71
+ setNodes([]);
72
+ setLinks([]);
73
+ setEmbeddingData([]);
74
+ } finally {
75
+ setLoading(false);
76
+ }
77
+ };
78
+
79
+ loadGraph();
80
+ }, [modelId, maxDepth]); // Only reload when modelId or maxDepth changes
81
+
82
+ // Load embedding data when switching to embedding view or when nodes change
83
+ useEffect(() => {
84
+ if (viewMode !== 'embedding' || nodes.length === 0) {
85
+ setEmbeddingData([]);
86
+ return;
87
+ }
88
+
89
+ const loadEmbeddingData = async () => {
90
+ setLoadingEmbedding(true);
91
+ try {
92
+ // Fetch embedding data for all models in the graph
93
+ const modelIds = nodes.map(n => n.id);
94
+ const params = new URLSearchParams({
95
+ max_points: '10000', // Limit for performance
96
+ format: 'json',
97
+ });
98
+
99
+ // Add search query to filter to our models
100
+ // Since we can't filter by exact model IDs easily, we'll fetch and filter client-side
101
+ const response = await fetch(`${API_BASE}/api/models?${params}`);
102
+ if (!response.ok) throw new Error('Failed to fetch embedding data');
103
+
104
+ const data = await response.json();
105
+ const allModels: ModelPoint[] = Array.isArray(data) ? data : (data.models || []);
106
+
107
+ // Filter to only models in our graph
108
+ const modelIdSet = new Set(modelIds);
109
+ const filteredModels = allModels.filter(m => modelIdSet.has(m.model_id));
110
+
111
+ // If we don't have all models, try fetching them individually or use what we have
112
+ setEmbeddingData(filteredModels);
113
+ } catch (err: any) {
114
+ console.error('Failed to load embedding data:', err);
115
+ // Fallback: create ModelPoint objects from graph nodes (without coordinates)
116
+ const fallbackData: ModelPoint[] = nodes.map(node => ({
117
+ model_id: node.id,
118
+ x: 0,
119
+ y: 0,
120
+ z: 0,
121
+ library_name: node.library || null,
122
+ pipeline_tag: node.pipeline || null,
123
+ downloads: node.downloads || 0,
124
+ likes: node.likes || 0,
125
+ trending_score: null,
126
+ tags: null,
127
+ parent_model: null,
128
+ licenses: null,
129
+ family_depth: null,
130
+ cluster_id: null,
131
+ created_at: null,
132
+ }));
133
+ setEmbeddingData(fallbackData);
134
+ } finally {
135
+ setLoadingEmbedding(false);
136
+ }
137
+ };
138
+
139
+ loadEmbeddingData();
140
+ }, [viewMode, nodes]);
141
+
142
+ // Handle search
143
+ const handleSearch = useCallback(async (query: string) => {
144
+ if (!query.trim()) {
145
+ setSearchResults([]);
146
+ setShowSearchResults(false);
147
+ return;
148
+ }
149
+
150
+ try {
151
+ const response = await fetch(
152
+ `${API_BASE}/api/search?q=${encodeURIComponent(query)}&limit=10`
153
+ );
154
+ if (!response.ok) throw new Error('Search failed');
155
+ const data = await response.json();
156
+ const results = Array.isArray(data) ? data : (data.models || []);
157
+ setSearchResults(results);
158
+ setShowSearchResults(true);
159
+ } catch (err) {
160
+ setSearchResults([]);
161
+ setShowSearchResults(false);
162
+ }
163
+ }, []);
164
+
165
+ const handleSearchResultClick = useCallback((model: any) => {
166
+ setModelId(model.model_id);
167
+ setShowSearchResults(false);
168
+ }, []);
169
+
170
+ const toggleEdgeType = useCallback((type: EdgeType) => {
171
+ setEnabledEdgeTypes(prev => {
172
+ const newSet = new Set(prev);
173
+ if (newSet.has(type)) {
174
+ newSet.delete(type);
175
+ } else {
176
+ newSet.add(type);
177
+ }
178
+ return newSet;
179
+ });
180
+ }, []);
181
+
182
+ const handleNodeClick = useCallback((node: GraphNode) => {
183
+ setSelectedNodeId(node.id);
184
+ setModelId(node.id);
185
+ // Find corresponding model in embedding data if available
186
+ if (embeddingData.length > 0) {
187
+ const model = embeddingData.find(m => m.model_id === node.id);
188
+ if (model) {
189
+ setSelectedModel(model);
190
+ }
191
+ }
192
+ }, [embeddingData]);
193
+
194
+ const handleEmbeddingPointClick = useCallback((model: ModelPoint) => {
195
+ setSelectedModel(model);
196
+ setSelectedNodeId(model.model_id);
197
+ setModelId(model.model_id);
198
+ }, []);
199
+
200
+ const containerRef = React.useRef<HTMLDivElement>(null);
201
+ const [dimensions, setDimensions] = useState({ width: 1000, height: 600 });
202
+
203
+ useEffect(() => {
204
+ const updateDimensions = () => {
205
+ if (containerRef.current) {
206
+ const rect = containerRef.current.getBoundingClientRect();
207
+ setDimensions({
208
+ width: rect.width,
209
+ height: Math.max(600, rect.height - 200),
210
+ });
211
+ }
212
+ };
213
+
214
+ updateDimensions();
215
+ window.addEventListener('resize', updateDimensions);
216
+ return () => window.removeEventListener('resize', updateDimensions);
217
+ }, []);
218
+
219
+ return (
220
+ <div className="graph-page">
221
+ <div className="page-header">
222
+ <h1>Model Relationship Graph</h1>
223
+ <p className="page-description">
224
+ Visualize model derivatives and relationships. Switch between force-directed graph view and embedding space view.
225
+ Explore how models are connected through fine-tuning, quantization, adapters, and merges.
226
+ </p>
227
+ </div>
228
+
229
+ <div className="graph-controls-panel">
230
+ <div className="search-section">
231
+ <input
232
+ type="text"
233
+ className="graph-search-input"
234
+ placeholder="Search for a model to visualize its relationships..."
235
+ value={modelId}
236
+ onChange={(e) => {
237
+ const value = e.target.value;
238
+ setModelId(value);
239
+ if (value.trim()) {
240
+ handleSearch(value);
241
+ } else {
242
+ setSearchResults([]);
243
+ setShowSearchResults(false);
244
+ }
245
+ }}
246
+ onFocus={() => {
247
+ if (searchResults.length > 0) {
248
+ setShowSearchResults(true);
249
+ }
250
+ }}
251
+ />
252
+ {showSearchResults && searchResults.length > 0 && (
253
+ <div className="search-results-dropdown">
254
+ {searchResults.map((model) => (
255
+ <div
256
+ key={model.model_id}
257
+ className="search-result-item"
258
+ onClick={() => handleSearchResultClick(model)}
259
+ >
260
+ <div className="result-title">{model.model_id}</div>
261
+ <div className="result-meta">
262
+ {model.library_name && <span>{model.library_name}</span>}
263
+ {model.downloads > 0 && (
264
+ <span>{model.downloads.toLocaleString()} downloads</span>
265
+ )}
266
+ </div>
267
+ </div>
268
+ ))}
269
+ </div>
270
+ )}
271
+ </div>
272
+
273
+ <div className="graph-settings">
274
+ <div className="setting-group">
275
+ <label>View Mode:</label>
276
+ <select
277
+ value={viewMode}
278
+ onChange={(e) => setViewMode(e.target.value as ViewMode)}
279
+ className="view-mode-select"
280
+ >
281
+ <option value="graph">Force-Directed Graph</option>
282
+ <option value="embedding">Embedding Space (3D)</option>
283
+ </select>
284
+ </div>
285
+
286
+ {viewMode === 'embedding' && (
287
+ <>
288
+ <div className="setting-group">
289
+ <label>Color By:</label>
290
+ <select
291
+ value={colorBy}
292
+ onChange={(e) => setColorBy(e.target.value)}
293
+ className="color-by-select"
294
+ >
295
+ <option value="library_name">Library</option>
296
+ <option value="pipeline_tag">Task Type</option>
297
+ <option value="downloads">Downloads</option>
298
+ <option value="likes">Likes</option>
299
+ <option value="family_depth">Family Depth</option>
300
+ </select>
301
+ </div>
302
+
303
+ <div className="setting-group">
304
+ <label>Size By:</label>
305
+ <select
306
+ value={sizeBy}
307
+ onChange={(e) => setSizeBy(e.target.value)}
308
+ className="size-by-select"
309
+ >
310
+ <option value="downloads">Downloads</option>
311
+ <option value="likes">Likes</option>
312
+ <option value="none">Uniform</option>
313
+ </select>
314
+ </div>
315
+ </>
316
+ )}
317
+
318
+ <div className="setting-group">
319
+ <label>Max Depth:</label>
320
+ <input
321
+ type="number"
322
+ min="1"
323
+ max="20"
324
+ value={maxDepth || ''}
325
+ onChange={(e) => setMaxDepth(e.target.value ? parseInt(e.target.value) : undefined)}
326
+ className="depth-input"
327
+ />
328
+ </div>
329
+
330
+ <div className="setting-group">
331
+ <label>Current Model:</label>
332
+ <div className="current-model">{modelId || 'None selected'}</div>
333
+ </div>
334
+ </div>
335
+ </div>
336
+
337
+ <div className="graph-container" ref={containerRef}>
338
+ {loading ? (
339
+ <LoadingProgress message="Loading graph..." progress={0} />
340
+ ) : error ? (
341
+ <div className="graph-error">
342
+ <p>Error: {error}</p>
343
+ {error.includes('not found') && (
344
+ <p className="error-hint">Try searching for a different model.</p>
345
+ )}
346
+ </div>
347
+ ) : nodes.length === 0 ? (
348
+ <div className="graph-empty">
349
+ <p>Enter a model ID above to visualize its relationship graph.</p>
350
+ <p className="empty-hint">
351
+ Try popular models like: <code>bert-base-uncased</code>, <code>gpt2</code>, or <code>t5-base</code>
352
+ </p>
353
+ </div>
354
+ ) : viewMode === 'graph' ? (
355
+ <>
356
+ <ForceDirectedGraph
357
+ width={dimensions.width}
358
+ height={dimensions.height}
359
+ nodes={nodes}
360
+ links={links}
361
+ onNodeClick={handleNodeClick}
362
+ selectedNodeId={selectedNodeId}
363
+ enabledEdgeTypes={enabledEdgeTypes}
364
+ showLabels={true}
365
+ />
366
+ <EdgeTypeLegend
367
+ edgeTypes={ALL_EDGE_TYPES}
368
+ enabledTypes={enabledEdgeTypes}
369
+ onToggle={toggleEdgeType}
370
+ />
371
+ {graphStats && (
372
+ <div className="graph-stats">
373
+ <div className="stat-item">
374
+ <span className="stat-label">Nodes:</span>
375
+ <span className="stat-value">{graphStats.nodes || nodes.length}</span>
376
+ </div>
377
+ <div className="stat-item">
378
+ <span className="stat-label">Edges:</span>
379
+ <span className="stat-value">{graphStats.edges || links.length}</span>
380
+ </div>
381
+ {graphStats.avg_degree && (
382
+ <div className="stat-item">
383
+ <span className="stat-label">Avg Degree:</span>
384
+ <span className="stat-value">{graphStats.avg_degree.toFixed(2)}</span>
385
+ </div>
386
+ )}
387
+ </div>
388
+ )}
389
+ </>
390
+ ) : (
391
+ <>
392
+ {loadingEmbedding ? (
393
+ <LoadingProgress message="Loading embedding data..." progress={0} />
394
+ ) : embeddingData.length === 0 ? (
395
+ <div className="graph-empty">
396
+ <p>No embedding data available for these models.</p>
397
+ <p className="empty-hint">Try switching to graph view or selecting a different model.</p>
398
+ </div>
399
+ ) : (
400
+ <>
401
+ <ScatterPlot3D
402
+ data={embeddingData}
403
+ colorBy={colorBy}
404
+ sizeBy={sizeBy}
405
+ colorScheme="viridis"
406
+ onPointClick={handleEmbeddingPointClick}
407
+ hoveredModel={selectedModel}
408
+ />
409
+ <div className="embedding-info">
410
+ <div className="info-item">
411
+ <span className="info-label">Models:</span>
412
+ <span className="info-value">{embeddingData.length}</span>
413
+ </div>
414
+ <div className="info-item">
415
+ <span className="info-label">View:</span>
416
+ <span className="info-value">Embedding Space</span>
417
+ </div>
418
+ </div>
419
+ </>
420
+ )}
421
+ </>
422
+ )}
423
+ </div>
424
+ </div>
425
+ );
426
+ }
427
+
428
+ interface EdgeTypeLegendProps {
429
+ edgeTypes: EdgeType[];
430
+ enabledTypes: Set<EdgeType>;
431
+ onToggle: (type: EdgeType) => void;
432
+ }
433
+
434
+ const EDGE_COLORS: Record<EdgeType, string> = {
435
+ finetune: '#3b82f6',
436
+ quantized: '#10b981',
437
+ adapter: '#f59e0b',
438
+ merge: '#8b5cf6',
439
+ parent: '#6b7280',
440
+ };
441
+
442
+ const EDGE_LABELS: Record<EdgeType, string> = {
443
+ finetune: 'Fine-tuned',
444
+ quantized: 'Quantized',
445
+ adapter: 'Adapter',
446
+ merge: 'Merged',
447
+ parent: 'Parent',
448
+ };
449
+
450
+ function EdgeTypeLegend({ edgeTypes, enabledTypes, onToggle }: EdgeTypeLegendProps) {
451
+ return (
452
+ <div className="edge-type-legend">
453
+ <h4>Relationship Types</h4>
454
+ {edgeTypes.map((type) => (
455
+ <div
456
+ key={type}
457
+ className={`edge-type-item ${!enabledTypes.has(type) ? 'disabled' : ''}`}
458
+ onClick={() => onToggle(type)}
459
+ >
460
+ <div
461
+ className="edge-type-color"
462
+ style={{ backgroundColor: EDGE_COLORS[type] }}
463
+ />
464
+ <span className="edge-type-label">{EDGE_LABELS[type]}</span>
465
+ </div>
466
+ ))}
467
+ </div>
468
+ );
469
+ }
frontend/src/utils/api/graphApi.ts ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * API utilities for fetching graph/network data
3
+ */
4
+ import { API_BASE } from '../../config/api';
5
+ import { GraphNode, GraphLink, EdgeType } from '../../components/visualizations/ForceDirectedGraph';
6
+
7
+ export interface NetworkGraphResponse {
8
+ nodes: GraphNode[];
9
+ links: GraphLink[];
10
+ statistics?: {
11
+ nodes: number;
12
+ edges: number;
13
+ density: number;
14
+ avg_degree: number;
15
+ clustering: number;
16
+ };
17
+ root_model: string;
18
+ }
19
+
20
+ /**
21
+ * Fetch family network graph for a specific model
22
+ */
23
+ export async function fetchFamilyNetwork(
24
+ modelId: string,
25
+ options: {
26
+ maxDepth?: number;
27
+ edgeTypes?: EdgeType[];
28
+ includeEdgeAttributes?: boolean;
29
+ } = {}
30
+ ): Promise<NetworkGraphResponse> {
31
+ const { maxDepth, edgeTypes, includeEdgeAttributes = true } = options;
32
+
33
+ const params = new URLSearchParams();
34
+ if (maxDepth !== undefined) {
35
+ params.append('max_depth', maxDepth.toString());
36
+ }
37
+ if (edgeTypes && edgeTypes.length > 0) {
38
+ params.append('edge_types', edgeTypes.join(','));
39
+ }
40
+ if (includeEdgeAttributes !== undefined) {
41
+ params.append('include_edge_attributes', includeEdgeAttributes.toString());
42
+ }
43
+
44
+ const url = `${API_BASE}/api/network/family/${encodeURIComponent(modelId)}${params.toString() ? '?' + params.toString() : ''}`;
45
+
46
+ const response = await fetch(url);
47
+ if (!response.ok) {
48
+ throw new Error(`Failed to fetch network graph: ${response.statusText}`);
49
+ }
50
+
51
+ const data = await response.json();
52
+
53
+ // Transform the response to match our types
54
+ return {
55
+ nodes: data.nodes || [],
56
+ links: data.links || [],
57
+ statistics: data.statistics,
58
+ root_model: data.root_model || modelId,
59
+ };
60
+ }
61
+
62
+ /**
63
+ * Get all available edge types from a graph response
64
+ */
65
+ export function getAvailableEdgeTypes(links: GraphLink[]): Set<EdgeType> {
66
+ const types = new Set<EdgeType>();
67
+ links.forEach(link => {
68
+ if (link.edge_types && link.edge_types.length > 0) {
69
+ link.edge_types.forEach(type => types.add(type));
70
+ } else if (link.edge_type) {
71
+ types.add(link.edge_type);
72
+ }
73
+ });
74
+ return types;
75
+ }