Model Card for Chest-CT-LeJEPA

This repository hosts the backbone weights for a foundational Vision Transformer (ViT-Large) trained on Chest CT scans using a Latent-Euclidean Joint-Embedding Predictive Architecture (LeJEPA).

Unlike its "Guided" counterparts, this model was trained solely with the self-supervised LeJEPA objective, with no auxiliary dense supervision or anatomical mask guidance. It therefore learns general-purpose representations of CT volumes without any labels.

This model was developed by the Institute for Biomedical Informatics Center for Applied AI (IBI-CAAI) at the University of Kentucky to serve as a robust feature extractor for downstream medical imaging tasks, including segmentation, multi-instance learning (MIL), and anomaly detection.

Model Details

  • Model Type: Vision Transformer (ViT-Large) for Chest CT analysis.
  • Developed by: Institute for Biomedical Informatics Center for Applied AI (IBI-CAAI)
  • Model Date: 04/2026
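  • Base Model Architecture: vit_large_patch14_dinov2 (via timm). Note: This model was randomly initialized and trained entirely from scratch, with the architecture arguments overridden to a patch size of 16, a 512x512 input size, and a single input channel.
  • Input: Single-channel (grayscale) 2D CT slice in Hounsfield Units, preprocessed as described below.
  • Output: One class token and (H/16) x (W/16) patch tokens (1,024 patch tokens for a 512x512 slice), each 1024-dimensional. These can be used for various downstream tasks (e.g., classification, anomaly detection, multi-instance learning).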
  • Embedding Dimension: 1024
  • Patch Size: 16
  • Image Size Compatibility: Native 512x512 resolution. The architecture supports variable input sizes dynamically, provided the height and width are divisible by the patch size (16).
  • License: CC BY-NC-SA 4.0 (Inherited from the CT-RATE dataset terms).
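The patch size and resolution above determine how many tokens the backbone produces. The small helper below, `patch_grid`, is hypothetical and not part of this repository. It computes the token grid for a given slice size and rejects sizes that are not divisible by 16. It assumes the DINOv2 ViT-L layout with one class token and no register tokens.

```python
def patch_grid(height: int, width: int, patch_size: int = 16) -> tuple[int, int, int]:
    """Return (rows, cols, num_patch_tokens) for a ViT with the given patch size.

    Raises ValueError when the slice is not divisible by the patch size,
    mirroring the requirement stated above for dynamic input sizes.
    """
    if height % patch_size or width % patch_size:
        raise ValueError(
            f"Input {height}x{width} is not divisible by patch size {patch_size}"
        )
    rows, cols = height // patch_size, width // patch_size
    return rows, cols, rows * cols


# Native resolution: a 32 x 32 grid of patch tokens (plus one class token).
print(patch_grid(512, 512))
# Training crop sizes used by the augmentation pipeline.
print(patch_grid(256, 256))
print(patch_grid(144, 144))
```

At 512x512 this yields a 32 x 32 grid, so the full token sequence has 1 + 1024 entries of dimension 1024.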

Intended Uses

This model is intended for research purposes in the field of medical imaging and radiology.

  • Primary Intended Uses:
    • Feature extraction for quantitative analysis of Chest CT scans.
    • Foundational backbone for downstream models for organ anomaly prediction, segmentation, or volume-level analysis via MIL.

Training Data

  • Dataset(s): The model was trained exclusively on the train split of the CT-RATE dataset.
  • Preprocessing: Hounsfield Unit (HU) values were clipped to [-997.0, 888.0], the 0.5th and 99.5th percentiles of foreground voxel intensities computed on a subset of the CT-RATE dataset. The clipped values were linearly rescaled to [0, 1] and then Z-score normalized using the dataset mean (-142.39 HU) and standard deviation (360.97 HU), converted to the same [0, 1] scale.
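The two-step normalization collapses to a single affine map in HU. With clip bounds $a = -997$, $b = 888$ and HU statistics $\mu = -142.39$, $\sigma = 360.97$, the rescaled mean and standard deviation are $(\mu - a)/(b - a) \approx 0.4534$ and $\sigma/(b - a) \approx 0.1915$. For a clipped value $x$:

$$\frac{\frac{x-a}{b-a} - \frac{\mu-a}{b-a}}{\frac{\sigma}{b-a}} = \frac{x-\mu}{\sigma}.$$

The NumPy sketch below checks this numerically. Both function names are illustrative and not part of the released code.

```python
import numpy as np

CLIP_MIN, CLIP_MAX = -997.0, 888.0  # HU window (0.5th / 99.5th percentiles)
MEAN_HU, STD_HU = -142.39, 360.97   # dataset statistics, expressed in HU


def normalize_two_step(hu: np.ndarray) -> np.ndarray:
    """Clip, rescale to [0, 1], then z-score with statistics rescaled the same way."""
    span = CLIP_MAX - CLIP_MIN
    scaled = (np.clip(hu, CLIP_MIN, CLIP_MAX) - CLIP_MIN) / span
    norm_mean = (MEAN_HU - CLIP_MIN) / span
    norm_std = STD_HU / span
    return (scaled - norm_mean) / norm_std


def normalize_direct(hu: np.ndarray) -> np.ndarray:
    """Equivalent single affine map: (clip(x) - mean) / std, all in HU."""
    return (np.clip(hu, CLIP_MIN, CLIP_MAX) - MEAN_HU) / STD_HU


rng = np.random.default_rng(0)
slice_hu = rng.uniform(-1500, 1500, size=(512, 512))
print(np.allclose(normalize_two_step(slice_hu), normalize_direct(slice_hu)))
```

One consequence is that every normalized value lies in roughly [-2.37, 2.85], the images of the two clip bounds.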

Training Procedure

  • Training System/Framework: Distributed Data Parallel (DDP) utilizing bf16 mixed precision.
  • Hardware & Scale: The model was trained for 50,000 iterations with a per-GPU batch size of 64, a peak learning rate of 3.0e-04 decaying to 3.0e-05, and a 5,000-step warmup.
  • Training Strategy: Global and local crops were sampled from within a 12 mm physical slab rather than a single 2D plane, so the views of one sample can come from neighboring slices. This promotes 3D anatomical awareness during the self-supervised prediction task.
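The card gives the warmup length and the peak and final learning rates, but not the shape of the decay. The sketch below assumes a linear warmup followed by cosine decay, a common pairing; the actual schedule used for this model may differ. The function `lr_at` is a hypothetical helper.

```python
import math


def lr_at(step: int,
          total_steps: int = 50_000,
          warmup_steps: int = 5_000,
          peak_lr: float = 3.0e-4,
          final_lr: float = 3.0e-5) -> float:
    """Linear warmup to peak_lr, then (assumed) cosine decay to final_lr."""
    if step < warmup_steps:
        # Linear ramp from ~0 up to peak_lr over the warmup period.
        return peak_lr * (step + 1) / warmup_steps
    # Fraction of the post-warmup schedule completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return final_lr + (peak_lr - final_lr) * cosine


for s in (0, 4_999, 5_000, 27_500, 50_000):
    print(s, f"{lr_at(s):.2e}")
```

The schedule reaches the 3.0e-04 peak at the end of warmup and lands exactly on 3.0e-05 at the last step. Under the cosine assumption, it sits halfway between the two at the midpoint of the decay phase.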
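The card does not specify how the 12 mm slab is sampled. The sketch below is one hypothetical reading of it:

- the slab is centered on a random slice;
- each of the 10 views (2 global + 8 local) draws its own slice index uniformly from the slices within ±6 mm of the center;
- the slice spacing comes from the volume's metadata.

The function name and all of these details are assumptions.

```python
from typing import Optional

import numpy as np


def sample_slab_slices(num_slices: int,
                       slice_spacing_mm: float,
                       num_views: int,
                       slab_mm: float = 12.0,
                       rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Pick one slice index per view from a slab of thickness `slab_mm`.

    Hypothetical sketch: the slab is centred on a random slice, and every view
    draws its own slice uniformly from the slices lying inside the slab.
    """
    rng = rng if rng is not None else np.random.default_rng()
    # Number of neighbouring slices on each side that fit in half the slab.
    half_width = int(np.floor((slab_mm / 2) / slice_spacing_mm))
    center = int(rng.integers(0, num_slices))
    lo = max(0, center - half_width)
    hi = min(num_slices - 1, center + half_width)
    return rng.integers(lo, hi + 1, size=num_views)


rng = np.random.default_rng(0)
# A 300-slice volume with 1.5 mm spacing: the slab covers at most 9 slices.
idx = sample_slab_slices(300, 1.5, num_views=10, rng=rng)
print(idx, idx.max() - idx.min())
```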

Self-Supervised Objective: LeJEPA

The model was trained solely with the self-supervised LeJEPA objective, which combines a prediction loss ($\mathcal{L}_{\text{pred}}$) with Sketched Isotropic Gaussian Regularization ($\mathcal{L}_{\text{SIGReg}}$):

$$\mathcal{L}_{\text{LeJEPA}} = (1-\lambda)\,\mathcal{L}_{\text{pred}} + \lambda\,\mathcal{L}_{\text{SIGReg}}$$

where $\lambda = 0.02$. SIGReg projects the embeddings onto a set of random 1D directions. For each projection, it penalizes the distance between the empirical characteristic function and that of a standard Gaussian, which pushes the embedding distribution toward an isotropic Gaussian.
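The NumPy sketch below illustrates the objective. It follows the reference formulation in the original LeJEPA paper, with two choices worth noting:

- SIGReg uses an Epps-Pulley-style comparison of characteristic functions.
- $\mathcal{L}_{\text{pred}}$ is taken as the mean squared distance of every view's embedding to the mean of the global-view embeddings.

Whether this model's training code matches in these details is an assumption, as are the number of slices, the $t$-grid, and the Gaussian weighting. `sigreg` and `lejepa_loss` are illustrative names.

```python
import numpy as np


def sigreg(z: np.ndarray, num_slices: int = 256, num_t: int = 17,
           t_max: float = 5.0, seed: int = 0) -> float:
    """Sketched isotropic Gaussian regularizer (Epps-Pulley style), NumPy sketch.

    z: (N, K) embeddings. Each is projected onto `num_slices` random unit
    directions; each 1D projection's empirical characteristic function (ECF)
    is compared with the standard normal CF exp(-t^2 / 2).
    """
    rng = np.random.default_rng(seed)
    n, k = z.shape
    directions = rng.standard_normal((k, num_slices))
    directions /= np.linalg.norm(directions, axis=0, keepdims=True)
    proj = z @ directions                        # (N, M) 1D projections
    t = np.linspace(-t_max, t_max, num_t)        # integration grid
    target = np.exp(-0.5 * t**2)                 # N(0, 1) CF, also the weight
    xt = proj[:, :, None] * t                    # (N, M, T)
    ecf_re = np.cos(xt).mean(axis=0)             # real part of the ECF, (M, T)
    ecf_im = np.sin(xt).mean(axis=0)             # imaginary part of the ECF
    err = ((ecf_re - target) ** 2 + ecf_im**2) * target
    # Trapezoidal rule over the uniform t-grid.
    dt = t[1] - t[0]
    integral = dt * (err.sum(axis=1) - 0.5 * (err[:, 0] + err[:, -1]))
    return float((n * integral).mean())          # average over directions


def lejepa_loss(global_emb: np.ndarray, all_emb: np.ndarray,
                lam: float = 0.02) -> float:
    """(1 - lam) * L_pred + lam * L_SIGReg.

    global_emb: (V_g, B, K) embeddings of the global views.
    all_emb:    (V, B, K) embeddings of all views (global + local).
    """
    centers = global_emb.mean(axis=0)            # (B, K) global-view centres
    l_pred = float(((all_emb - centers) ** 2).mean())
    l_sig = float(np.mean([sigreg(v) for v in all_emb]))  # per view, averaged
    return (1 - lam) * l_pred + lam * l_sig


rng = np.random.default_rng(1)
B, K = 64, 1024  # per-GPU batch size and embedding dimension from this card
g = rng.standard_normal((2, B, K))               # 2 global views
a = np.concatenate([g, rng.standard_normal((8, B, K))])  # + 8 local views
loss = lejepa_loss(g, a)
print(loss)
```

The prediction term alone is minimized by collapsing every embedding to one point. The SIGReg term penalizes exactly that kind of degenerate distribution, which is why a small $\lambda$ is enough to keep the representation spread out.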

Data Augmentation Pipeline

A specialized, GPU-accelerated augmentation pipeline generated the multi-crop views required for the LeJEPA architecture.

1. Spatial Cropping

  • Global Crops: 2 global crops generated per volume, sized at 256x256 pixels, with a random scale between 60% and 100% of the original image dimensions.
  • Local Crops: 8 local crops generated per volume, sized at 144x144 pixels, with a random scale between 30% and 60%.
  • Flip & Resize: Resized using nearest-neighbor interpolation to prevent artificial averaging of HU values, followed by a random horizontal flip (50% probability).
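The cropping steps above can be sketched on a single 2D slice as follows. This is a NumPy illustration, not the GPU pipeline used in training. It assumes square crops whose side is a random fraction of the image side; the card does not say whether "scale" refers to side length or area. It also ignores the 12 mm slab, so all crops come from one slice. `nearest_resize` and `random_crop` are hypothetical helpers.

```python
import numpy as np


def nearest_resize(img: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbour resize: copies existing values, never averages them."""
    h, w = img.shape
    rows = np.arange(out_h) * h // out_h
    cols = np.arange(out_w) * w // out_w
    return img[rows[:, None], cols[None, :]]


def random_crop(img: np.ndarray, out_size: int, scale: tuple[float, float],
                rng: np.random.Generator) -> np.ndarray:
    """Square crop (side = random fraction of the image side, assumed),
    nearest-neighbour resize, then horizontal flip with probability 0.5."""
    h, w = img.shape
    frac = rng.uniform(*scale)
    side = max(1, int(round(frac * min(h, w))))
    top = int(rng.integers(0, h - side + 1))
    left = int(rng.integers(0, w - side + 1))
    crop = nearest_resize(img[top:top + side, left:left + side], out_size, out_size)
    if rng.random() < 0.5:
        # Contiguous copy so the result can go straight into torch.from_numpy.
        crop = np.ascontiguousarray(crop[:, ::-1])
    return crop


rng = np.random.default_rng(0)
slice_hu = rng.uniform(-1000, 1000, size=(512, 512))
global_crops = [random_crop(slice_hu, 256, (0.6, 1.0), rng) for _ in range(2)]
local_crops = [random_crop(slice_hu, 144, (0.3, 0.6), rng) for _ in range(8)]
print(global_crops[0].shape, local_crops[0].shape)
```

Every output value is copied from the source slice, so no new HU values are created by interpolation.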

2. Intensity & Noise Augmentations

  • Random Gamma Correction: With 80% probability, each crop independently received a gamma correction with an exponent sampled from [0.9, 1.1].
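A gamma correction is only well defined on non-negative intensities. This sketch therefore assumes it is applied to the [0, 1]-rescaled values, before the Z-score step, where $x^\gamma$ stays in [0, 1]. The card does not state at which stage it was applied. `random_gamma` is a hypothetical helper.

```python
import numpy as np


def random_gamma(x01: np.ndarray, rng: np.random.Generator,
                 gamma_range: tuple[float, float] = (0.9, 1.1),
                 p: float = 0.8) -> np.ndarray:
    """With probability p, raise [0, 1]-scaled intensities to a random gamma.

    Assumes the input is already clipped and rescaled to [0, 1], so
    x ** gamma is well defined and stays inside [0, 1].
    """
    if rng.random() >= p:
        return x01
    gamma = rng.uniform(*gamma_range)
    return np.power(x01, gamma)


rng = np.random.default_rng(0)
crops = [rng.uniform(0.0, 1.0, size=(144, 144)) for _ in range(8)]
# Independent draw per crop: each has its own 80% chance and its own gamma.
augmented = [random_gamma(c, rng) for c in crops]
```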

How to Get Started with the Model

Because CT scans require strict Hounsfield Unit (HU) windowing and normalization to match the training distribution, you must apply the specific preprocessing logic below. Note: With the native 512px architecture, standard CT slices (512x512) do not require dynamic padding.

import torch
import numpy as np
import timm
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

class CTInferenceTransform:
    """ 
    Applies the exact HU windowing and Z-score normalization used during LeJEPA training. 
    Assumes standard 512x512 CT slice inputs.
    """
    def __init__(self):
        self.clip_min = -997.0
        self.clip_max = 888.0
        self.mean_hu = -142.39
        self.std_hu = 360.97
        
        # Calculate 0-1 scaled mean and std
        range_val = self.clip_max - self.clip_min
        self.norm_mean = (self.mean_hu - self.clip_min) / range_val
        self.norm_std = self.std_hu / range_val

    def __call__(self, volume):
        # Expects a 2D numpy array or torch tensor (H, W) in Hounsfield Units
        if isinstance(volume, np.ndarray):
            volume = torch.from_numpy(volume).float()
        if volume.ndim == 2:
            volume = volume.unsqueeze(0) # Add channel dim: (1, H, W)

        # 1. Clamp HU values and map strictly to [0, 1]
        volume = torch.clamp(volume, self.clip_min, self.clip_max)
        range_val = self.clip_max - self.clip_min
        volume = (volume - self.clip_min) / range_val

        # 2. Z-score standardization
        volume = (volume - self.norm_mean) / self.norm_std

        # Returns (1, 1, H, W). For batched inference, concatenate these along
        # dim=0 with torch.cat, which yields (B, 1, H, W).
        return volume.unsqueeze(0)

def load_ct_model(repo_id="IBI-CAAI/Chest-CT-LeJEPA"):
    """
    Downloads and initializes the ViT-Large backbone using timm and safetensors.
    """
    # 1. Initialize the base architecture with overrides (patch_size=16, img_size=512)
    model = timm.create_model(
        "vit_large_patch14_dinov2",
        pretrained=False,
        num_classes=0,
        in_chans=1,          
        patch_size=16,       
        img_size=512,        
        dynamic_img_size=True 
    )
    
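    # 2. Download and load the custom safetensors weights
    model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    state_dict = load_file(model_path)
    # strict=False tolerates key mismatches, so report them instead of failing silently
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        print(f"Warning: {len(missing)} missing and {len(unexpected)} unexpected keys "
              "when loading the checkpoint")
    model.eval()

    return model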

if __name__ == "__main__":
    # Initialize the transform and the model
    transform = CTInferenceTransform()
    model = load_ct_model()
    
    # Simulate a raw CT slice (Replace this with an actual NIfTI/DICOM load in Hounsfield Units)
    raw_ct_slice = np.random.uniform(-1000, 1000, size=(512, 512)) 
    
    # Process the image to ensure correct normalization
    input_tensor = transform(raw_ct_slice)
    
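    # Extract embeddings with a single forward pass
    with torch.no_grad():
        # Full token sequence after the final norm: (1, 1 + N, 1024), class token first
        features = model.forward_features(input_tensor)

        # Option A: the pooled global feature (class token) for the entire slice
        global_feature = model.forward_head(features)

        # Option B: the dense spatial patch tokens (for fine-grained tasks like
        # segmentation), with the class token stripped off
        patch_tokens = features[:, model.num_prefix_tokens:]

        # Reshape into a (1, H/16, W/16, 1024) spatial grid for dense prediction heads
        H, W = input_tensor.shape[-2:]
        patch_map = patch_tokens.reshape(1, H // 16, W // 16, -1)

    print(f"Input tensor shape: {input_tensor.shape}")        # (1, 1, 512, 512)
    print(f"Global feature shape: {global_feature.shape}")    # (1, 1024)
    print(f"Dense patch tokens shape: {patch_tokens.shape}")  # (1, 1024, 1024)
    print(f"Patch grid shape: {patch_map.shape}")             # (1, 32, 32, 1024)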