FL Hybrid Eigendecomposition: Beating cuSOLVER's Mathematical Purity with Compilable PyTorch
Faddeev-LeVerrier + Laguerre + Newton-Schulz + Rayleigh Quotient
A compilable eigendecomposition pipeline that wins 84/84 mathematical purity metrics against cuSOLVER across matrix sizes n=3-12, uses 40× less memory, and achieves zero graph breaks under torch.compile(fullgraph=True).
Part of the GeoLIP geometric deep learning ecosystem.
Repository: AbstractEyes/geolip-core · Related: geofractal · wide-compiler
Date: April 1, 2026 · Hardware: NVIDIA RTX PRO 6000 Blackwell Server Edition (95GB VRAM)
The Problem
torch.linalg.eigh calls cuSOLVER under the hood. cuSOLVER is fast, but it:
- Breaks compilation: introduces 2+ graph breaks per call under torch.compile, preventing full-graph optimization of training loops that include eigendecomposition
- Consumes massive memory: 1,099 MB peak allocation at B=4096, n=6 — dominated by workspace buffers
- Cannot be fused: Lives behind an opaque CUDA library call, invisible to the compiler's fusion passes
For GeoLIP's Cayley-Menger validated geometric transformer, eigendecomposition runs inside every forward pass on batched 5×5 and 6×6 pentachoron matrices. Graph breaks and memory pressure directly impact training throughput.
The Algorithm
The FL Hybrid pipeline replaces the opaque cuSOLVER call with five transparent, compilable phases:
Phase 1: Faddeev-LeVerrier Characteristic Polynomial
The Faddeev-LeVerrier recurrence computes the characteristic polynomial coefficients directly from the matrix:
M[k] = A·M[k-1] + c[n-k+1]·I,    c[n-k] = −tr(A·M[k]) / k
Starting from M[0] = 0 and c[n] = 1, n iterations of batched matrix multiplication yield the monic characteristic polynomial p(λ) = λⁿ + c[n-1]λⁿ⁻¹ + ... + c[0], whose roots are the eigenvalues.
Implementation: n torch.bmm calls in fp64. Fully compilable. The M[k] matrices are stored for eigenvector extraction.
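In PyTorch, the Phase-1 recurrence is a short bmm loop. A minimal sketch under those definitions (the function name fl_coefficients is illustrative, not the geolip.linalg API):

```python
import torch

def fl_coefficients(A):
    """Faddeev-LeVerrier recurrence (illustrative sketch of Phase 1).

    A: [B, n, n] batched matrices. Returns (c, Ms) where c[:, j] is the
    coefficient of lambda^j in the monic characteristic polynomial and
    Ms holds the intermediate M[k] matrices kept for eigenvector extraction.
    """
    B, n, _ = A.shape
    A = A.to(torch.float64)
    I = torch.eye(n, dtype=torch.float64).expand(B, n, n)
    c = torch.zeros(B, n + 1, dtype=torch.float64)
    c[:, n] = 1.0                                   # monic leading coefficient
    M = torch.zeros_like(A)                         # M[0] = 0
    Ms = [M]
    for k in range(1, n + 1):
        # M[k] = A M[k-1] + c[n-k+1] I
        M = torch.bmm(A, M) + c[:, n - k + 1].view(B, 1, 1) * I
        # c[n-k] = -tr(A M[k]) / k  (einsum 'bii->b' is the batched trace)
        c[:, n - k] = -torch.einsum('bii->b', torch.bmm(A, M)) / k
        Ms.append(M)
    return c, Ms
```

As a sanity check, c[n-1] is −tr(A) and c[0] is (−1)ⁿ det(A), both of which follow directly from p(λ) = det(λI − A).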
Phase 2: Laguerre Root-Finding with Synthetic Deflation
Laguerre's method finds each eigenvalue sequentially. For a polynomial with all real roots (guaranteed for symmetric matrices), Laguerre converges cubically from any starting point:
λ ← λ − n / (G ± √((n−1)(nH − G²)))
where G = p'(λ)/p(λ) and H = G² − p''(λ)/p(λ), evaluated via Horner's scheme. The sign is chosen to maximize the magnitude of the denominator.
After finding each root, synthetic division deflates the polynomial. Newton polish on the original (undeflated) polynomial corrects any accumulated deflation error.
Precision strategy: fp32 Laguerre for n ≤ 6 (sufficient), fp64 for n > 6 (deeper deflation chains). 5 Laguerre iterations per root, 3-5 Newton polish passes.
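A single Laguerre update can be sketched as follows, assuming a monic real-rooted polynomial; deflation, Newton polish, and the fp32/fp64 switch are omitted, and laguerre_step is an illustrative name:

```python
import torch

def laguerre_step(x, coeffs):
    """One Laguerre iteration (sketch of the Phase-2 update).

    x: [B] current root estimates. coeffs: [B, n+1], coeffs[:, j] multiplies x^j,
    with coeffs[:, n] == 1 (monic). Real-rooted polynomials assumed.
    """
    n = coeffs.shape[-1] - 1
    # Simultaneous Horner for p, p', and p''/2
    p = coeffs[:, n].clone()
    dp = torch.zeros_like(x)
    hp = torch.zeros_like(x)            # accumulates p''/2
    for j in range(n - 1, -1, -1):
        hp = hp * x + dp
        dp = dp * x + p
        p = p * x + coeffs[:, j]
    G = dp / p
    H = G * G - 2 * hp / p
    # For real-rooted polynomials the discriminant is nonnegative; clamp guards rounding
    disc = torch.sqrt(torch.clamp((n - 1) * (n * H - G * G), min=0.0))
    denom = torch.where(G >= 0, G + disc, G - disc)   # larger-|denominator| sign
    # Guard: if p(x) is exactly zero, x is already a root
    return torch.where(p == 0, x, x - n / denom)
```

From a start far outside the root cluster, a handful of these steps lands on a root to machine precision, which is the cubic convergence the pipeline relies on.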
Phase 3: FL Adjugate Eigenvectors
The stored M[k] matrices from Phase 1 yield eigenvectors via Horner evaluation of the matrix polynomial:
R(λ) = adj(λI − A) = λⁿ⁻¹ M[1] + λⁿ⁻² M[2] + ... + M[n]
The columns of R(λᵢ) are scalar multiples of the eigenvector for λᵢ. We extract the maximum-norm column for numerical stability.
Adaptive computation: Broadcast [B, n, n, n] Horner for n ≤ 6 (3.4 MB, fast). Per-eigenvalue [B, n, n] loop for n > 6 (avoids 566 MB allocation at n=12).
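The per-eigenvalue loop variant (the n > 6 path) can be sketched as below; fl_eigenvectors and the Ms argument layout are illustrative assumptions, with Ms the stored Faddeev-LeVerrier matrices and Ms[1] = I:

```python
import torch

def fl_eigenvectors(A, eigvals, Ms):
    """Phase-3 sketch: eigenvectors from adj(lambda*I - A) via Horner.

    adj(lambda*I - A) = lambda^{n-1} Ms[1] + lambda^{n-2} Ms[2] + ... + Ms[n];
    at lambda = lambda_i its columns are multiples of the i-th eigenvector.
    """
    B, n, _ = A.shape
    vecs = []
    for i in range(n):
        lam = eigvals[:, i].view(B, 1, 1)
        R = Ms[1].clone()                 # Ms[1] = I
        for k in range(2, n + 1):
            R = R * lam + Ms[k]           # Horner over the matrix polynomial
        norms = R.norm(dim=1)             # [B, n] column norms
        idx = norms.argmax(dim=1)         # max-norm column per batch element
        v = R[torch.arange(B), :, idx]
        vecs.append(v / v.norm(dim=1, keepdim=True))
    return torch.stack(vecs, dim=-1)      # [B, n, n], column i pairs with eigvals[:, i]
```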
Phase 4: Newton-Schulz Orthogonalization
Two iterations of the Newton-Schulz polar iteration orthogonalize the eigenvector matrix:
V ← ½ V (3I − VᵀV)
Pure torch.bmm — no eigensolvers, no SVD. Quadratic convergence from the near-orthogonal FL eigenvectors.
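A minimal sketch of the iteration, assuming the input is already near-orthogonal (the function name is illustrative):

```python
import torch

def newton_schulz_orthogonalize(V, iters=2):
    """Newton-Schulz polar iteration: V <- V (3I - V^T V) / 2.

    Converges quadratically to the nearest orthogonal matrix when
    ||V^T V - I|| < 1, which the FL eigenvectors already satisfy.
    Pure bmm, so it compiles cleanly.
    """
    B, n, _ = V.shape
    I = torch.eye(n, dtype=V.dtype, device=V.device).expand(B, n, n)
    for _ in range(iters):
        V = torch.bmm(V, 1.5 * I - 0.5 * torch.bmm(V.transpose(1, 2), V))
    return V
```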
Phase 5: Rayleigh Quotient Refinement
The Rayleigh quotient produces eigenvalues that are optimal for the given eigenvectors:
λᵢ = vᵢᵀ A vᵢ
This fuses the algebraic accuracy of FL root-finding with the geometric accuracy of Newton-Schulz orthogonalization. Two batched matrix multiplies.
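For orthonormal V the refinement is just the diagonal of VᵀAV; a sketch (name illustrative):

```python
import torch

def rayleigh_refine(A, V):
    """Rayleigh-quotient eigenvalues for orthonormal eigenvector columns V:
    lambda_i = v_i^T A v_i, i.e. the diagonal of V^T A V. One bmm plus a
    reduction, matching the two-batched-matmul cost stated above."""
    return torch.einsum('bij,bij->bj', V, torch.bmm(A, V))
```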
Results
Mathematical Purity
We test against pure mathematical definitions — not against a reference implementation. Twelve metrics evaluated on 2,048 random symmetric matrices:
| Metric | cuSOLVER | FL Precise | Winner |
|---|---|---|---|
| Eigenpair residual max | 5.8e-07 | 1.7e-07 | FL |
| Eigenpair residual mean | 1.3e-07 | 2.9e-08 | FL |
| Orthogonality max | 2.0e-06 | 3.0e-07 | FL |
| Orthogonality mean | 9.0e-07 | 1.5e-07 | FL |
| Reconstruction max | 1.3e-06 | 2.9e-07 | FL |
| Reconstruction mean | 4.9e-07 | 1.1e-07 | FL |
| Trace error max | 2.9e-06 | 1.2e-06 | FL |
| Trace error mean | 5.1e-07 | 2.4e-07 | FL |
| Determinant max | 3.1e-02 | 7.4e-04 | FL |
| Determinant mean | 1.8e-05 | 8.9e-07 | FL |
| Char. polynomial max | 1.0e-02 | 1.9e-03 | FL |
| Char. polynomial mean | 4.4e-05 | 1.7e-05 | FL |
FL Precise wins 12/12 metrics at n=6.
Purity Across Matrix Sizes
| n | FL wins | cuSOLVER wins | Best eigenpair | Best orthogonality |
|---|---|---|---|---|
| 3 | 12/12 | 0/12 | FL 1.8e-07 | FL 2.4e-07 |
| 4 | 12/12 | 0/12 | FL 2.1e-07 | FL 2.6e-07 |
| 5 | 12/12 | 0/12 | FL 1.9e-07 | FL 2.7e-07 |
| 6 | 12/12 | 0/12 | FL 2.1e-07 | FL 3.0e-07 |
| 8 | 12/12 | 0/12 | FL 1.7e-07 | FL 3.4e-07 |
| 10 | 12/12 | 0/12 | FL 1.8e-07 | FL 3.6e-07 |
| 12 | 12/12 | 0/12 | FL 1.4e-07 | FL 4.3e-07 |
| 16 | 2/12 | 10/12 | cuS 5.8e-07 | cuS 3.1e-06 |
84/84 across n=3-12. FL degrades at n=16 due to deep deflation chain conditioning.
Throughput (n=6, B=4,096)
| Method | Time | vs cuSOLVER | Matrices/sec |
|---|---|---|---|
| cuSOLVER | 241 µs | 1.00× | 17.0M/s |
| FL Fast compiled | 346 µs | 0.70× | 11.9M/s |
| FL Precise compiled | 350 µs | 0.69× | 11.7M/s |
| FL Fast + CUDA Graph | 279 µs | 0.86× | 14.7M/s |
| FL Precise + CUDA Graph | 287 µs | 0.84× | 14.3M/s |
| CUDA kernel (CuPy) | 425 µs | 0.56× | 9.6M/s |
Memory
| Method | Peak Allocation |
|---|---|
| cuSOLVER | 1,098.7 MB |
| FL Fast | 26.2 MB |
| FL Precise | 32.3 MB |
| CUDA kernel | 0.7 MB |
40× less memory than cuSOLVER. The CUDA kernel uses 1,500× less — everything lives in thread-local registers.
Accuracy Across All Sizes
| n | Fast val_err | Fast align | Precise val_err | Precise align |
|---|---|---|---|---|
| 3 | 1.7e-06 | 1.000000 | 1.4e-06 | 1.000000 |
| 4 | 2.1e-06 | 0.999999 | 1.7e-06 | 0.999999 |
| 5 | 2.1e-06 | 0.999999 | 2.4e-06 | 0.999999 |
| 6 | 2.9e-06 | 0.999999 | 2.6e-06 | 0.999999 |
| 8 | 2.9e-06 | 0.999999 | 2.4e-06 | 0.999999 |
| 10 | 4.8e-06 | 0.999999 | 3.7e-06 | 0.999999 |
| 12 | 5.5e-06 | 0.999999 | 5.5e-06 | 0.999999 |
| 16 | 3.1e-04 | 0.999933 | 5.2e-06 | 0.999999 |
CUDA Kernel (CuPy/NVRTC)
A standalone CUDA kernel compiled at runtime via CuPy's NVRTC interface. One thread per matrix, entire pipeline in registers. No ninja, no C++ compiler — just the CUDA driver.
Architecture
- Thread allocation: 1 thread = 1 matrix (all n² elements in registers)
- Register budget: ~232 fp64 registers at n=6 (fits the 255-register limit)
- Memory: zero global intermediates (interleaved FL+Horner)
- Compilation: NVRTC at first call, cached at ~/.cupy/kernel_cache/
The kernel is generated per matrix size via a Python template. NVRTC compiles each size variant separately, enabling full loop unrolling for each n.
Batch Scaling (n=6)
| Batch Size | cuSOLVER | CUDA FL | Ratio |
|---|---|---|---|
| 256 | 102 µs | 426 µs | 0.24× |
| 1,024 | 116 µs | 424 µs | 0.27× |
| 4,096 | 240 µs | 425 µs | 0.56× |
| 8,192 | 409 µs | 427 µs | 0.96× |
| 16,384 | 743 µs | 429 µs | 1.73× |
| 32,768 | OOM | 810 µs | ∞ |
Flat scaling: ~425 µs from B=256 through B=16,384, doubling only at B=32,768. cuSOLVER scales roughly linearly with B due to workspace allocation, while the CUDA kernel has zero intermediate memory — everything is per-thread registers. At B=16,384 the kernel is 1.73× faster than cuSOLVER. At B=32,768, cuSOLVER cannot run (OOM at 95GB VRAM) while the kernel completes in 810 µs using 10 MB.
Size Scaling (B=4,096)
| n | cuSOLVER | CUDA FL | Ratio |
|---|---|---|---|
| 3 | 137 µs | 45 µs | 3.06× |
| 4 | 164 µs | 73 µs | 2.24× |
| 5 | 201 µs | 217 µs | 0.93× |
| 6 | 239 µs | 426 µs | 0.56× |
| 8 | 324 µs | 1.52 ms | 0.21× |
Dominates at n ≤ 4 (up to 3× cuSOLVER at this batch size). Register spilling degrades performance above n=6.
Batch × Size Matrix
| B | n=3 | n=5 | n=6 | n=8 | n=12 |
|---|---|---|---|---|---|
| 512 | 1.89× | 0.44× | 0.25× | 0.08× | 0.02× |
| 2,048 | 2.29× | 0.65× | 0.37× | 0.13× | 0.04× |
| 8,192 | 4.62× | 1.60× | 0.98× | 0.38× | 0.12× |
| 16,384 | 7.64× | 2.77× | 1.74× | 0.69× | 0.21× |
The CUDA kernel's domain: n ≤ 5 at any batch size, n=6 at B ≥ 8,192.
Package Architecture: geolip.linalg
A drop-in replacement for torch.linalg with automatic dispatch to optimized implementations:
```python
import geolip.linalg as LA

# Our implementations (override torch.linalg)
vals, vecs = LA.eigh(A)   # FL pipeline n≤12, cuSOLVER n>12
U, S, Vh = LA.svd(A)      # Triton n=2,3 → FL n≤12 → cuSOLVER

# Passthrough to torch.linalg (zero overhead)
x = LA.solve(A, b)        # torch.linalg.solve
L = LA.cholesky(A)        # torch.linalg.cholesky
```
Module Structure
```
geolip/linalg/
    __init__.py       — Public API + torch.linalg proxy via __getattr__
    _backend.py       — Singleton: CUDA/Triton detection, one-time warning
    eigh.py           — FLEigh class + eigh() auto-dispatcher
    svd.py            — batched_svd with FL eigh integration
    newton_schulz.py  — Pure bmm inverse square root
    procrustes.py     — Subspace-preserving Procrustes alignment
```
The __getattr__ proxy transparently forwards any function we haven't overridden to torch.linalg:
```python
import torch.linalg as _torch_linalg

def __getattr__(name):
    if hasattr(_torch_linalg, name):
        return getattr(_torch_linalg, name)
    raise AttributeError(f"module 'geolip.linalg' has no attribute '{name}'")
```
Backend Configuration
```python
from geolip.linalg import backend

backend.status()              # print available features
backend.use_fl_eigh = True    # toggle FL pipeline (default: True on CUDA)
backend.use_triton = True     # toggle Triton SVD kernels
```
One warning on first fallback. No warning on subsequent calls. Degrades gracefully to pure PyTorch on CPU.
Dispatch Table
| Condition | Implementation | Reason |
|---|---|---|
| n ≤ 4, CuPy, B ≥ 512 | CUDA kernel | 2-7× cuSOLVER |
| n ≤ 6, CuPy, B ≥ 8,192 | CUDA kernel | 1.7× cuSOLVER, flat scaling |
| n ≤ 12, CUDA | FL Precise (compiled) | 84/84 purity, 40× less memory |
| n > 12 | torch.linalg.eigh | FL conditioning degrades |
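The table reduces to a plain dispatcher. Everything below (function name, return labels, flag names) is an illustrative sketch, not the actual geolip.linalg dispatch code:

```python
def dispatch_eigh(batch, n, has_cupy, on_cuda):
    """Hypothetical sketch of the dispatch policy in the table above."""
    if has_cupy and on_cuda and ((n <= 4 and batch >= 512) or
                                 (n <= 6 and batch >= 8192)):
        return 'cuda_kernel'      # 2-7x cuSOLVER at small n / large B, flat scaling
    if on_cuda and n <= 12:
        return 'fl_precise'       # compiled FL pipeline: 84/84 purity, 40x less memory
    return 'torch_eigh'           # cuSOLVER fallback: FL conditioning degrades at n > 12
```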
What We Tried and What Failed
This section documents approaches that were explored, benchmarked, and found not to outperform the sequential Laguerre baseline. They are preserved for future reference.
Parallel Newton-Aberth (Converges, Not Faster)
Newton step + Aberth repulsion finds all n roots simultaneously:
zᵢ ← zᵢ − 1 / ( p'(zᵢ)/p(zᵢ) − Σⱼ≠ᵢ 1/(zᵢ − zⱼ) )
Result: Converges to 1.05e-06 at n=6 after 20 iterations. But each iteration builds a [B, n, n] repulsion matrix (1.2 MB per iteration at B=4096, n=6). 0.89× sequential Laguerre at n=6. Marginal improvement at n=8 (1.08×). The [B, n, n] tensor overhead negates the parallelization benefit.
Parallel Laguerre-Aberth (Diverges)
Laguerre step + Aberth repulsion. In principle, Laguerre's cubic convergence should dominate Newton's quadratic convergence.
Result: Diverges at all sizes. Laguerre's aggressive step size amplifies through the Aberth correction, causing exponential error growth. The diagnostic showed val_err increasing from 7.8e+01 to 5.6e+04 over 20 iterations. The Laguerre step is too large for real-line root-finding with simultaneous repulsion — it overshoots past neighboring roots.
Pure Parallel Laguerre (Collapses)
Laguerre without any repulsion term. Each root independently converges using the undivided polynomial.
Result: Polynomial residual |p(z)| converges to 2.88e-15 (machine precision), but all roots collapse to the same eigenvalue (min_gap = 0 by iteration 2). Without deflation or repulsion, all starting points converge to the nearest dominant root. This is a fundamental limitation: parallel root-finding on real-rooted polynomials requires either deflation (sequential) or repulsion (which causes instability).
Ternary Spectral Bisection (Initialization Problem)
Bisection is guaranteed to converge. Ternary subdivision (divide into thirds, identify the occupied third, recurse) tracks intervals exactly in int64.
Result: The ternary refinement works perfectly — bracket width shrinks as 3^{-k} per level. The problem is initialization: finding which intervals contain which roots. A dense grid requires:
| n | Grid points needed |
|---|---|
| 3 | 167 |
| 6 | 1,838 |
| 10 | 4,398 |
The grid must resolve the worst-case eigenvalue gap across the entire batch. At n=6, one near-degenerate pair in B=2,048 matrices forces 1,838 grid evaluations — more expensive than 30 sequential Laguerre steps.
Wormhole Bracket Initialization (Collapses + Incompilable)
Newton steps from coarse grid points as "wormholes" to find hidden close root pairs. Grid-point derivatives predict root locations without dense scanning.
Result: Wormhole targets collapse to the same roots (same failure mode as parallel Newton). Additionally, the data-dependent if missing.any() branch prevents torch.compile(fullgraph=True).
Learned Root Predictor (Ill-Conditioned Input Space)
A 5,062-parameter MLP trained online to predict roots from FL polynomial coefficients.
Result: val_err stuck at ~2.5 after 500 training batches. The coefficient-to-root mapping (Vieta's inverse) is ill-conditioned — Wilkinson's polynomial shows that small coefficient changes cause large root movements. The predictor forward pass is fast (57 µs) but cannot learn the inverse mapping.
Triton Kernel (Wrong Abstraction)
A 7,461-line auto-generated Triton kernel for n=6: a Python code generator produces fully-unrolled Triton source with explicit named variables.
Result: Triton's tile-based programming model doesn't fit small-matrix scalar FMA patterns. The generated kernel is too large for Triton's JIT (compilation time > 60s). Triton is designed for large matmuls and attention kernels, not 6×6 register-level computation. The CuPy/NVRTC CUDA kernel achieves what Triton cannot because CUDA gives direct register control.
Key Insights
Why FL Wins Mathematical Purity
cuSOLVER computes eigenvalues via tridiagonal reduction + implicit QR. This preserves geometric properties (orthogonality) by construction but introduces algebraic error in eigenvalue computation.
FL computes eigenvalues as exact roots of the characteristic polynomial. This preserves algebraic properties (trace, determinant, characteristic polynomial evaluation) by construction but requires explicit orthogonalization.
The Rayleigh quotient refinement (λᵢ = vᵢᵀ A vᵢ) bridges the gap: it produces eigenvalues that are optimal for the given eigenvectors, fusing algebraic and geometric accuracy. This is why FL wins eigenpair residual — the eigenvalue-eigenvector pair is jointly optimal.
The Sequential Laguerre Barrier
Root-finding for real-rooted polynomials has a fundamental sequential dependency: finding root k+1 requires knowledge of root k (via deflation). Without deflation, parallel methods either collapse to duplicates or require repulsion that causes instability.
Laguerre's method with synthetic deflation is the correct production algorithm for this problem. It converges cubically from diagonal initialization, handles close eigenvalue pairs via deflation, and compiles cleanly under torch.compile.
The CUDA Kernel's Register Advantage
The CUDA kernel's flat batch scaling comes from zero intermediate memory. Sequential Laguerre in PyTorch creates M_store [n+1, B, n, n] in global memory — at B=16,384, n=6, that's 128 MB of bandwidth per FL iteration. The CUDA kernel keeps everything in thread-local registers (232 registers per thread), touching global memory only for input load and output store.
This is a structural advantage that cannot be replicated in PyTorch's programming model. torch.compile optimizes graph execution but cannot convert global memory tensors to per-thread registers. The CUDA kernel is the right tool when batch size is large and matrix size is small.
Usage
Compiled Training Loop
```python
from geolip.linalg import FLEigh

solver = torch.compile(FLEigh(mode='precise'), fullgraph=True)

for batch in dataloader:
    A = compute_cm_matrices(batch)            # [B, 6, 6] symmetric
    eigenvalues, eigenvectors = solver(A)     # zero graph breaks
    loss = geometric_loss(eigenvalues, eigenvectors)
    loss.backward()
    optimizer.step()
```
High-Batch Inference
```python
from fl_eigh_cuda import fl_eigh_cuda

# B=16384: 1.73× cuSOLVER, 0.7 MB memory
eigenvalues, eigenvectors = fl_eigh_cuda(A)
```
Drop-In Replacement
```python
import geolip.linalg as LA

# Everything you use from torch.linalg, plus our optimizations
vals, vecs = LA.eigh(A)   # auto-dispatches
x = LA.solve(A, b)        # passthrough to torch.linalg
```
Mathematical Lineage
- Faddeev-LeVerrier (1840): Characteristic polynomial via matrix recurrence
- Laguerre (1834): Cubically convergent root-finding for real-rooted polynomials
- Newton (1669): Quadratic root polishing
- Schulz (1933): Iterative matrix inverse square root
- Rayleigh (1877): Optimal eigenvalue from eigenvector
- Eckart-Young (1936): SVD optimality
- Jacobi (1846): Cyclic rotation eigendecomposition (used in Triton SVD kernels for n=2,3)
Citation
```bibtex
@software{geolip_fl_eigh_2026,
  author = {AbstractPhil},
  title  = {FL Hybrid Eigendecomposition: Compilable, Mathematically Superior Alternative to cuSOLVER},
  year   = {2026},
  url    = {https://github.com/AbstractEyes/geolip-core},
  note   = {Part of the GeoLIP geometric deep learning ecosystem}
}
```
License
Apache 2.0