Model Card for Guided-Chest-CT-LeJEPA-V2

This repository hosts the backbone weights for a foundational Vision Transformer (ViT-Large) trained on Chest CT scans using a highly customized Latent-Euclidean Joint-Embedding Predictive Architecture (LeJEPA).

This V2 model improves upon the original by introducing a native 512px architecture, a refined patch size (16), and a more sophisticated dense auxiliary supervision objective that balances macroscopic and fine-grained pathology tracking.

This model was developed by the Institute for Biomedical Informatics Center for Applied AI (IBI-CAAI) at the University of Kentucky to serve as a robust feature extractor for downstream medical imaging tasks, including segmentation, multi-instance learning (MIL), and anomaly detection.

Model Details

  • Model Type: Vision Transformer (ViT-Large) for Chest CT analysis.
  • Developed by: Institute for Biomedical Informatics Center for Applied AI (IBI-CAAI)
  • Model Date: 04/2026
  • Base Model Architecture: vit_large_patch14_dinov2 (via timm). Note: This model was randomly initialized and trained entirely from scratch with modified architecture arguments.
  • Input: 1-channel Grayscale CT Image.
  • Output: Class token and patch tokens. These can be used for various downstream tasks (e.g., classification, anomaly detection, multi-instance learning).
  • Embedding Dimension: 1024
  • Patch Size: 16
  • Image Size Compatibility: Native 512x512 resolution. The architecture supports variable input sizes dynamically, provided the height and width are divisible by the patch size (16).
  • License: CC BY-NC-SA 4.0 (Inherited from the CT-RATE dataset terms).
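Since the backbone accepts any input whose height and width are divisible by the patch size, the token count varies with resolution. A minimal helper (not part of the repository) for computing the patch grid and expected token count, assuming one class token as described in the Output field above:

```python
def patch_grid(height, width, patch_size=16):
    """Return (rows, cols, tokens): the patch grid for a given input size
    and the total token count (patch tokens + 1 class token)."""
    if height % patch_size or width % patch_size:
        raise ValueError("height and width must be divisible by the patch size")
    rows, cols = height // patch_size, width // patch_size
    return rows, cols, rows * cols + 1  # +1 for the class token

# Native 512x512 input -> 32x32 grid of patch tokens
print(patch_grid(512, 512))  # (32, 32, 1025)
```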

Intended Uses

This model is intended for research purposes in the field of medical imaging and radiology.

  • Primary Intended Uses:
    • Feature extraction for quantitative analysis of Chest CT scans.
    • Foundational backbone for downstream models predicting organ anomalies, segmentation, or volume-level analysis via MIL.

Training Data

  • Dataset(s): The model was trained exclusively on the train split of the CT-RATE dataset, leveraging annotations for targeted cropping from ReXGroundingCT.
  • Sampling Strategy: To prioritize pathological representation, scans containing ReX masks were oversampled such that 20% of every training batch comprised slices from verified anomalous regions.
  • Preprocessing: Hounsfield Units (HU) were clipped to the range [-997.0, 888.0], corresponding to the 0.5th and 99.5th percentiles of foreground voxel intensities computed on a subset of the CT-RATE dataset. The clipped values were mapped to a [0, 1] range, followed by Z-score normalization using a dataset mean of -142.39 HU and standard deviation of 360.97 HU.
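As a quick sanity check, the two-step scheme above (clip, scale to [0, 1], then Z-score) implies the HU-space statistics must themselves be rescaled into [0, 1] space before standardization:

```python
clip_min, clip_max = -997.0, 888.0
mean_hu, std_hu = -142.39, 360.97

hu_range = clip_max - clip_min               # 1885.0
norm_mean = (mean_hu - clip_min) / hu_range  # mean in [0, 1] space
norm_std = std_hu / hu_range                 # std in [0, 1] space

print(round(norm_mean, 4), round(norm_std, 4))  # 0.4534 0.1915
```

These are the same constants the inference transform further down this card derives internally.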

Training Procedure

  • Training System/Framework: Distributed Data Parallel (DDP) utilizing bf16 mixed precision.
  • Hardware & Scale: The model was trained for a total of 50,000 iterations. The configuration utilized a batch size of 64 per GPU, a peak learning rate of 3.0e-04 (decaying to 3.0e-05), and a 5,000-step warmup.
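The card states the endpoints of the learning-rate schedule but not its shape. Assuming a linear warmup followed by cosine decay (a common choice, not confirmed by the source), the schedule can be sketched as:

```python
import math

def lr_at_step(step, total=50_000, warmup=5_000, peak=3.0e-4, floor=3.0e-5):
    """Linear warmup to `peak`, then cosine decay to `floor`.
    The decay shape is an assumption; the card only states the endpoints."""
    if step < warmup:
        return peak * step / warmup
    progress = (step - warmup) / (total - warmup)
    return floor + 0.5 * (peak - floor) * (1 + math.cos(math.pi * progress))

print(f"{lr_at_step(5_000):.1e}")   # peak
print(f"{lr_at_step(50_000):.1e}")  # floor
```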
  • Training Strategy: Global and local crops were sampled from within a 12mm physical slab rather than a single 2D plane to ensure anatomical awareness.

LeJEPA Formulation and Dense Auxiliary Supervision

The total loss combines the self-supervised LeJEPA objective with a dense auxiliary head: $\mathcal{L}_{\text{Total}}=\mathcal{L}_{\text{LeJEPA}}+\lambda_{\text{aux}}\mathcal{L}_{\text{Aux}}$, where $\lambda_{\text{aux}}=0.1$.

Self-Supervised Objective: LeJEPA

LeJEPA combines a standard prediction loss ($\mathcal{L}_{\text{pred}}$) with Sketched Isotropic Gaussian Regularization ($\mathcal{L}_{\text{SIGReg}}$): $\mathcal{L}_{\text{LeJEPA}}=(1-\lambda)\mathcal{L}_{\text{pred}}+\lambda\mathcal{L}_{\text{SIGReg}}$, where $\lambda=0.02$. The SIGReg formulation projects embeddings onto a set of random 1D directions to enforce normality via empirical characteristic functions.
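A simplified, illustrative sketch of the SIGReg idea: project embeddings onto random unit directions and penalize the gap between each 1D projection's empirical characteristic function and that of a standard Gaussian, $e^{-t^2/2}$. This is not the training implementation; the direction count, frequency grid, and uniform weighting are arbitrary choices for illustration.

```python
import torch

def sigreg_loss(embeddings, num_directions=64, num_freqs=17):
    """Sketched isotropic Gaussian regularization (simplified sketch):
    compare the empirical characteristic function of random 1D projections
    against the standard Gaussian CF, exp(-t^2 / 2)."""
    n, d = embeddings.shape
    dirs = torch.randn(d, num_directions)
    dirs = dirs / dirs.norm(dim=0, keepdim=True)   # random unit directions
    proj = embeddings @ dirs                       # (n, num_directions)
    t = torch.linspace(-4, 4, num_freqs)           # illustrative freq grid
    tx = proj.unsqueeze(-1) * t                    # (n, dirs, freqs)
    ecf_real = torch.cos(tx).mean(dim=0)           # Re E[exp(i t x)]
    ecf_imag = torch.sin(tx).mean(dim=0)           # Im E[exp(i t x)]
    target = torch.exp(-0.5 * t**2)                # Gaussian CF (real-valued)
    gap = (ecf_real - target) ** 2 + ecf_imag ** 2
    return gap.mean()

z = torch.randn(256, 1024)  # roughly Gaussian embeddings -> small loss
print(float(sigreg_loss(z)))
```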

Auxiliary Supervision Objective

The auxiliary loss evaluates soft, fractional labels describing the proportional composition of anatomy/pathology within a crop. It is split between macroscopic and fine-grained views, and further divided between TotalSegmentator (TS) and ReX predictions: $\mathcal{L}_{\text{Aux}}=0.5\cdot\mathcal{L}_{\text{Global}}+0.5\cdot\mathcal{L}_{\text{Patch}}$. Because ReX labels represent sparse abnormalities, their loss is strictly masked to prevent penalizing the model on unverified scans.
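One way such a masked soft-label objective could look, as a hedged sketch: a cross-entropy against fractional composition targets where the ReX term only contributes on crops with verified annotations. The loss form and all names below are assumptions, not the training code.

```python
import torch
import torch.nn.functional as F

def masked_aux_loss(ts_logits, ts_targets, rex_logits, rex_targets, rex_valid):
    """Soft-label cross-entropy sketch. `*_targets` hold fractional
    compositions summing to 1; `rex_valid` is a (batch,) bool mask that
    zeroes the ReX term for crops without verified annotations."""
    ts_loss = -(ts_targets * F.log_softmax(ts_logits, dim=-1)).sum(-1).mean()
    rex_ce = -(rex_targets * F.log_softmax(rex_logits, dim=-1)).sum(-1)
    denom = rex_valid.float().sum().clamp(min=1.0)   # avoid divide-by-zero
    rex_loss = (rex_ce * rex_valid.float()).sum() / denom
    return ts_loss + rex_loss

# Usage: only the first crop in the batch has a verified ReX mask
logits = torch.zeros(2, 4)
targets = torch.full((2, 4), 0.25)       # uniform fractional labels
valid = torch.tensor([True, False])
loss = masked_aux_loss(logits, targets, logits, targets, valid)
```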

Data Augmentation Pipeline

A specialized, GPU-accelerated augmentation pipeline generated the multi-crop views required for the LeJEPA architecture.

1. Spatial & Targeted Cropping

  • Global Crops: 2 global crops generated per volume, sized at 256x256 pixels, with a random scale between 60% and 100% of the original image dimensions.
  • Local Crops: 8 local crops generated per volume, sized at 144x144 pixels, with a random scale between 30% and 60%.
  • Anatomy & Label Guidance: Local crops had an 80% probability of being explicitly centered on physical anatomy (probabilistically targeted by TS or ReX masks).
  • Flip & Resize: Crops were resized using nearest-neighbor interpolation to prevent artificial averaging of HU values, followed by a random horizontal flip (50% probability).
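The mask-guided centering step above can be sketched as follows; the function name, the uniform fallback, and the boundary clamping are illustrative assumptions (the upstream choice between TS and ReX masks is not shown):

```python
import numpy as np

def sample_crop_center(mask, img_shape, crop_size, p_guided=0.8, rng=None):
    """Pick a local-crop center: with probability `p_guided`, sample a
    voxel from the anatomy/pathology mask; otherwise sample uniformly.
    The center is clamped so the crop stays inside the image."""
    rng = rng or np.random.default_rng()
    h, w = img_shape
    half = crop_size // 2
    ys, xs = np.nonzero(mask)
    if len(ys) > 0 and rng.random() < p_guided:
        i = rng.integers(len(ys))          # sample a masked voxel
        cy, cx = ys[i], xs[i]
    else:
        cy, cx = rng.integers(h), rng.integers(w)
    cy = int(np.clip(cy, half, h - half))
    cx = int(np.clip(cx, half, w - half))
    return cy, cx
```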

2. Intensity & Noise Augmentations

  • Random Gamma Correction: A random gamma shift (range: 0.9 to 1.1) was applied independently to all crops with an 80% probability.

How to Get Started with the Model

Because CT scans require strict Hounsfield Unit (HU) windowing and normalization to match the training distribution, you must apply the specific preprocessing logic below. Note: With V2's native 512px architecture, standard CT slices (512x512) no longer require dynamic padding.

import torch
import numpy as np
import timm
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

class CTInferenceTransform:
    """ 
    Applies the exact HU windowing and Z-score normalization used during LeJEPA V2 training. 
    Assumes standard 512x512 CT slice inputs.
    """
    def __init__(self):
        self.clip_min = -997.0
        self.clip_max = 888.0
        self.mean_hu = -142.39
        self.std_hu = 360.97
        
        # Calculate 0-1 scaled mean and std
        range_val = self.clip_max - self.clip_min
        self.norm_mean = (self.mean_hu - self.clip_min) / range_val
        self.norm_std = self.std_hu / range_val

    def __call__(self, volume):
        # Expects a 2D numpy array or torch tensor (H, W) in Hounsfield Units
        if isinstance(volume, np.ndarray):
            volume = torch.from_numpy(volume).float()
        if volume.ndim == 2:
            volume = volume.unsqueeze(0) # Add channel dim: (1, H, W)

        # 1. Clamp HU values and map strictly to [0, 1]
        volume = torch.clamp(volume, self.clip_min, self.clip_max)
        range_val = self.clip_max - self.clip_min
        volume = (volume - self.clip_min) / range_val

        # 2. Z-score standardization
        volume = (volume - self.norm_mean) / self.norm_std

        # Returns (1, 1, H, W). For batched inference, stack these along dim=0.
        return volume.unsqueeze(0) 

def load_guided_ct_model_v2(repo_id="IBI-CAAI/Guided-Chest-CT-LeJEPA-V2"):
    """
    Downloads and initializes the ViT-Large backbone using timm and safetensors.
    """
    # 1. Initialize the base architecture with V2 overrides (patch_size=16, img_size=512)
    model = timm.create_model(
        "vit_large_patch14_dinov2",
        pretrained=False,
        num_classes=0,
        in_chans=1,          
        patch_size=16,       
        img_size=512,        
        dynamic_img_size=True 
    )
    
    # 2. Download and load the custom safetensors weights
    model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    state_dict = load_file(model_path)
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    
    return model

if __name__ == "__main__":
    # Initialize the transform and the model
    transform = CTInferenceTransform()
    model = load_guided_ct_model_v2()
    
    # Simulate a raw CT slice (Replace this with an actual NIfTI/DICOM load in Hounsfield Units)
    raw_ct_slice = np.random.uniform(-1000, 1000, size=(512, 512)) 
    
    # Process the image to ensure correct normalization
    input_tensor = transform(raw_ct_slice)
    
    # Extract embeddings
    with torch.no_grad():
        # Option A: Get the single pooled global feature for the entire slice
        global_feature = model(input_tensor)
        
        # Option B: Get the unpooled, dense spatial patch tokens (for fine-grained tasks like Segmentation)
        patch_tokens = model.forward_features(input_tensor)
        
    print(f"Input tensor shape: {input_tensor.shape}")  
    print(f"Extracted features shape: {global_feature.shape}")
    print(f"Dense patch tokens shape: {patch_tokens.shape}")