Premchan369 commited on
Commit
e221672
·
verified ·
1 Parent(s): c2a5894

Upload tests/test_v4.py

Browse files
Files changed (1) hide show
  1. tests/test_v4.py +245 -0
tests/test_v4.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ V4 integration tests for Q-TensorFormer.
3
+
4
+ Tests QKAN DARUAN activations, energy-aware training,
5
+ and the combined v4 pipeline.
6
+ """
7
+
8
+ import torch
9
+ import sys
10
+ import os
11
+
12
+ # Add src to path for testing
13
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14
+
15
+ from src.qkan import DARUAN, QKANLayer, HQKANFFN, create_qkan_ffn
16
+ from src.energy_v4 import (
17
+ EnergyEstimatorV4, ParetoTracker, HARDWARE_PROFILES,
18
+ estimate_model_energy, HardwareProfile
19
+ )
20
+ from src.config import ModelConfig, TrainingConfig, BudgetConfig
21
+
22
+
23
+ def test_daruan_basic():
24
+ """Test DARUAN activation on scalar input."""
25
+ daruan = DARUAN(n_repeats=3, base_activation="silu")
26
+ x = torch.randn(10)
27
+ out = daruan(x)
28
+ assert out.shape == (10,), f"Expected (10,), got {out.shape}"
29
+ assert not torch.isnan(out).any(), "NaN in DARUAN output"
30
+ print("✓ DARUAN basic: PASSED")
31
+
32
+
33
+ def test_daruan_batched():
34
+ """Test DARUAN on batched tensor."""
35
+ daruan = DARUAN(n_repeats=5, base_activation="gelu")
36
+ x = torch.randn(32, 128)
37
+ out = daruan(x)
38
+ assert out.shape == (32, 128), f"Expected (32, 128), got {out.shape}"
39
+ assert not torch.isnan(out).any(), "NaN in DARUAN output"
40
+ print("✓ DARUAN batched: PASSED")
41
+
42
+
43
+ def test_qkan_layer():
44
+ """Test QKANLayer as drop-in for Linear + Activation."""
45
+ layer = QKANLayer(128, 256, n_repeats=3)
46
+ x = torch.randn(16, 128)
47
+ out = layer(x)
48
+ assert out.shape == (16, 256), f"Expected (16, 256), got {out.shape}"
49
+
50
+ params = layer.parameter_count()
51
+ dense_params = 128 * 256 + 256 # weight + bias
52
+ print(f" QKAN params: {params} vs dense: {dense_params} ({(1 - params/dense_params)*100:.1f}% reduction)")
53
+ print("✓ QKANLayer: PASSED")
54
+
55
+
56
+ def test_hqkan_ffn():
57
+ """Test HQKAN FFN as drop-in for transformer FFN."""
58
+ ffn = HQKANFFN(hidden_dim=128, ff_multiplier=4, n_repeats=3)
59
+ x = torch.randn(8, 64, 128) # (batch, seq_len, d_model)
60
+ out = ffn(x)
61
+ assert out.shape == (8, 64, 128), f"Expected (8, 64, 128), got {out.shape}"
62
+ print(f" HQKAN FFN params: {ffn.total_params}")
63
+ print("✓ HQKAN FFN: PASSED")
64
+
65
+
66
+ def test_create_qkan_ffn():
67
+ """Test factory function for all QKAN FFN variants."""
68
+ # Standard
69
+ ffn_std = create_qkan_ffn(128, 4, n_repeats=3)
70
+ x = torch.randn(4, 32, 128)
71
+ out = ffn_std(x)
72
+ assert out.shape == (4, 32, 128)
73
+ print("✓ create_qkan_ffn (standard): PASSED")
74
+
75
+ # TT-QKAN hybrid
76
+ ffn_tt = create_qkan_ffn(128, 4, n_repeats=3, use_tt=True, tt_rank=4)
77
+ out = ffn_tt(x)
78
+ assert out.shape == (4, 32, 128), f"Expected (4, 32, 128), got {out.shape}"
79
+ print("✓ create_qkan_ffn (TT-hybrid): PASSED")
80
+
81
+
82
+ def test_energy_estimator():
83
+ """Test hardware-aware energy estimator."""
84
+ est = EnergyEstimatorV4("cpu_intel_xeon")
85
+
86
+ # Compute energy for a model forward pass
87
+ flops = 1e9 # 1 GFLOP
88
+ energy = est.compute_energy(flops, batch_size=16, memory_gb=0.5)
89
+ assert energy > 0, f"Energy should be positive, got {energy}"
90
+ print(f" Energy for 1 GFLOP on CPU: {energy:.2f} μJ")
91
+
92
+ # Carbon footprint
93
+ carbon = est.carbon_footprint(energy)
94
+ assert carbon > 0, f"Carbon should be positive"
95
+ print(f" Carbon: {carbon:.6f} g CO2")
96
+ print("✓ EnergyEstimator: PASSED")
97
+
98
+
99
+ def test_energy_all_hardware():
100
+ """Test energy estimation across all hardware targets."""
101
+ est = EnergyEstimatorV4()
102
+ flops = 1e9
103
+
104
+ print(" Hardware comparison (1 GFLOP):")
105
+ for hw_name in ["cpu_intel_xeon", "cpu_apple_m2", "gpu_a100", "edge_tpu", "edge_mobile"]:
106
+ est.set_hardware(hw_name)
107
+ energy = est.compute_energy(flops, batch_size=16)
108
+ print(f" {HARDWARE_PROFILES[hw_name].name}: {energy:.4f} μJ")
109
+ print("✓ All hardware targets: PASSED")
110
+
111
+
112
+ def test_quantum_energy():
113
+ """Test quantum circuit energy estimation."""
114
+ est = EnergyEstimatorV4("cpu_intel_xeon")
115
+ energy = est.quantum_energy(n_qubits=4, n_layers=2, n_tokens=100)
116
+ assert energy > 0
117
+ print(f" Quantum energy (4 qubits, 2 layers, 100 tokens): {energy:.2f} μJ")
118
+ print("✓ Quantum energy estimation: PASSED")
119
+
120
+
121
+ def test_training_energy():
122
+ """Test total training energy estimate."""
123
+ est = EnergyEstimatorV4("gpu_a100")
124
+ result = est.training_energy_estimate(
125
+ total_flops=1e9,
126
+ n_epochs=10,
127
+ batch_size=16,
128
+ dataset_size=10000,
129
+ quantum_tokens_per_batch=128,
130
+ n_qubits=4,
131
+ n_qlayers=2,
132
+ )
133
+ assert "total_energy_uj" in result
134
+ print(f" Total training energy: {result['total_energy_j']:.4f} J")
135
+ print(f" Carbon: {result['carbon_g']:.4f} g CO2")
136
+ print(f" Equivalent smartphone charges: {result['equivalent_smartphone_charges']:.4f}")
137
+ print("✓ Training energy estimate: PASSED")
138
+
139
+
140
+ def test_pareto_tracker():
141
+ """Test Pareto frontier tracking."""
142
+ tracker = ParetoTracker()
143
+
144
+ # Add some points
145
+ assert tracker.record(ppl=100, energy_uj=1000, step=0) # First point always Pareto
146
+ assert tracker.record(ppl=80, energy_uj=900, step=1) # Better both → Pareto
147
+ assert not tracker.record(ppl=90, energy_uj=950, step=2) # Dominated by (80, 900)
148
+ assert tracker.record(ppl=75, energy_uj=1200, step=3) # Better ppl, worse energy → Pareto
149
+
150
+ summary = tracker.summary()
151
+ assert summary["points"] == 3, f"Expected 3 Pareto points, got {summary['points']}"
152
+ print(f" Pareto frontier: {summary['frontier']}")
153
+ print("✓ ParetoTracker: PASSED")
154
+
155
+
156
+ def test_budget_integration():
157
+ """Test budget constraints with energy-aware optimization."""
158
+ config = ModelConfig(
159
+ d_model=64, n_layers=2, n_heads=4, tt_rank=4,
160
+ vocab_size=5000, use_quantum=False,
161
+ )
162
+ budget = BudgetConfig(
163
+ max_params=500000,
164
+ max_latency_ms=50.0,
165
+ max_energy_per_query=100.0,
166
+ )
167
+
168
+ # Validate configs
169
+ config.validate()
170
+ budget.validate()
171
+
172
+ print(f" Model config: d={config.d_model}, layers={config.n_layers}")
173
+ print(f" Budget: params≤{budget.max_params}, latency≤{budget.max_latency_ms}ms, energy≤{budget.max_energy_per_query}μJ")
174
+ print("✓ Budget integration: PASSED")
175
+
176
+
177
+ def test_e2e_v4_pipeline():
178
+ """End-to-end v4 pipeline test."""
179
+ from src.models import create_model
180
+ from src.config import ModelConfig
181
+ from src.energy_v4 import estimate_model_energy, EnergyEstimatorV4
182
+
183
+ config = ModelConfig(
184
+ vocab_size=1000,
185
+ d_model=64,
186
+ n_layers=2,
187
+ n_heads=4,
188
+ tt_rank=4,
189
+ max_seq_len=64,
190
+ n_qubits=4,
191
+ use_quantum=False, # Skip quantum for basic test
192
+ )
193
+
194
+ model = create_model(config, model_type="qtensor")
195
+
196
+ # Forward pass
197
+ x = torch.randint(0, 1000, (2, 16))
198
+ logits = model(x)
199
+ assert logits.shape == (2, 16, 1000), f"Expected (2, 16, 1000), got {logits.shape}"
200
+
201
+ # Energy estimate
202
+ est = EnergyEstimatorV4("cpu_apple_m2")
203
+ est_result = estimate_model_energy(model, est, seq_len=64, batch_size=2)
204
+ print(f" E2E energy: {est_result['energy_uj']:.2f} μJ")
205
+ print(f" E2E carbon: {est_result['carbon_per_query_ug']:.4f} μg CO2")
206
+ print(f" E2E params: {est_result['params']}")
207
+ print("✓ End-to-end v4 pipeline: PASSED")
208
+
209
+
210
+ if __name__ == "__main__":
211
+ print("=" * 60)
212
+ print("Q-TensorFormer v4 — Integration Tests")
213
+ print("=" * 60)
214
+
215
+ tests = [
216
+ ("DARUAN basic", test_daruan_basic),
217
+ ("DARUAN batched", test_daruan_batched),
218
+ ("QKANLayer", test_qkan_layer),
219
+ ("HQKAN FFN", test_hqkan_ffn),
220
+ ("create_qkan_ffn", test_create_qkan_ffn),
221
+ ("EnergyEstimator", test_energy_estimator),
222
+ ("All Hardware", test_energy_all_hardware),
223
+ ("Quantum Energy", test_quantum_energy),
224
+ ("Training Energy", test_training_energy),
225
+ ("ParetoTracker", test_pareto_tracker),
226
+ ("Budget Integration", test_budget_integration),
227
+ ("E2E v4 Pipeline", test_e2e_v4_pipeline),
228
+ ]
229
+
230
+ passed = 0
231
+ failed = 0
232
+ for name, test_fn in tests:
233
+ try:
234
+ test_fn()
235
+ passed += 1
236
+ except Exception as e:
237
+ print(f"✗ {name}: FAILED — {e}")
238
+ failed += 1
239
+
240
+ print(f"\n{'=' * 60}")
241
+ print(f"Results: {passed}/{passed + failed} tests passed")
242
+ if failed:
243
+ print(f"FAILED: {failed} test(s)")
244
+ else:
245
+ print("✅ ALL TESTS PASSED")