Model Card for Guided-Chest-CT-LeJEPA

This repository hosts the backbone weights for a foundational Vision Transformer (ViT-Large) trained on Chest CT scans using a highly customized Latent-Euclidean Joint-Embedding Predictive Architecture (LeJEPA).

This model was developed by the Institute for Biomedical Informatics Center for Applied AI (IBI-CAAI) at the University of Kentucky to serve as a robust feature extractor for downstream medical imaging tasks, including segmentation, multi-instance learning (MIL), and anomaly detection.

Model Details

  • Model Type: Vision Transformer (ViT-Large) for Chest CT analysis.
  • Developed by: Institute for Biomedical Informatics Center for Applied AI (IBI-CAAI)
  • Model Date: 03/2026
  • Base Model Architecture: vit_large_patch14_dinov2 (via timm). Note: This model was randomly initialized and trained entirely from scratch.
  • Input: 1-channel Grayscale CT Image.
  • Output: Class token and patch tokens. These can be used for various downstream tasks (e.g., classification, anomaly detection, multi-instance learning).
  • Embedding Dimension: 1024
  • Patch Size: 14
  • Image Size Compatibility: The model was trained using global crops of 224x224 and local crops of 140x140.
    • The architecture supports variable input sizes natively. It can accept images of any resolution, provided the height and width are divisible by the patch size (14).
  • License: CC BY-NC-SA 4.0 (Inherited from the CT-RATE dataset terms).
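The divisibility rule above can be satisfied by rounding a spatial dimension down to the nearest multiple of the patch size (a minimal sketch; 512 is just an example input side):

```python
def nearest_valid_size(size, patch_size=14):
    """Round a spatial dimension down to the nearest multiple of the patch size."""
    return (size // patch_size) * patch_size

print(nearest_valid_size(512))  # 504: a 512px side must be resized to 504px
print(nearest_valid_size(224))  # 224: already divisible by 14
```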

Intended Uses

This model is intended for research purposes in the field of medical imaging and radiology.

  • Primary Intended Uses:
    • Feature extraction for quantitative analysis of Chest CT scans.
    • Foundational backbone for downstream models predicting organ anomalies, segmentation, or volume-level analysis via MIL.

Training Data

  • Dataset(s): The model was trained exclusively on the train split of the CT-RATE dataset, leveraging annotations for guided cropping from ReXGroundingCT.
  • Preprocessing: Hounsfield Units (HU) were clipped to the range [-997.0, 888.0], corresponding to the 0.5th and 99.5th percentiles of foreground voxel intensities computed on a subset of the CT-RATE dataset. Clipped values were mapped to a [0, 1] range and then Z-score normalized using a dataset mean of -142.39 HU and standard deviation of 360.97 HU.

Training Procedure

  • Training System/Framework: LeJEPA with a guided semi-3D cropping strategy, along with an auxiliary supervised loss. Distributed Data Parallel (DDP) was utilized for scaling.
  • Hardware & Scale: The model was trained across 8x NVIDIA H200 GPUs for a total of 50,000 steps (executed in two 25,000-step runs) with a global batch size of 512.
  • Training Strategy: Unlike standard 2D slice-by-slice SSL training, this model was trained with spatial and anatomical awareness. Global and local crops were sampled from within a 12mm physical slab (using dynamic Z-spacing calculations) rather than a single 2D plane.
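One way to realize the 12mm slab sampling described above is to convert the physical slab thickness into a slice count using the volume's Z-spacing (a minimal sketch; the function name and edge handling are illustrative, not the training implementation):

```python
import numpy as np

def slab_slice_indices(center_idx, z_spacing_mm, num_slices, slab_mm=12.0):
    """Return the slice indices falling within a physical slab (in mm)
    centered on center_idx, given the volume's Z-spacing."""
    half_slices = int(round((slab_mm / 2.0) / z_spacing_mm))
    lo = max(0, center_idx - half_slices)
    hi = min(num_slices - 1, center_idx + half_slices)
    return np.arange(lo, hi + 1)

# e.g. with 1.5mm Z-spacing, a 12mm slab spans ~9 slices (center +/- 4)
print(slab_slice_indices(center_idx=50, z_spacing_mm=1.5, num_slices=300))
```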
  • Base Model: Initialized from scratch using the vit_large_patch14_dinov2 architecture defined in the timm library. No pre-trained weights were used.
  • Training Objective(s): Self-supervised LeJEPA loss combined with an auxiliary multi-hot supervised loss predicting the presence of 118 TotalSegmentator classes within the cropped regions.
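The auxiliary multi-hot objective above is a standard multi-label setup; a minimal sketch with binary cross-entropy (the linear head and batch size are illustrative assumptions, not the actual training head):

```python
import torch
import torch.nn as nn

NUM_CLASSES = 118  # TotalSegmentator classes, per the model card
EMBED_DIM = 1024   # ViT-Large embedding dimension

# Hypothetical auxiliary head: a linear probe on the backbone embedding
aux_head = nn.Linear(EMBED_DIM, NUM_CLASSES)
bce_loss = nn.BCEWithLogitsLoss()

embeddings = torch.randn(4, EMBED_DIM)                   # backbone features for 4 crops
targets = torch.randint(0, 2, (4, NUM_CLASSES)).float()  # multi-hot organ-presence labels

logits = aux_head(embeddings)
loss = bce_loss(logits, targets)  # scalar auxiliary loss
```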

Training Convergence

[Figure: Training curves] Training metrics from the first 25,000 steps, demonstrating highly stable convergence across the LeJEPA objectives and a steady climb to an F1 score above 0.8 on the auxiliary organ-prediction task.

Data Augmentation Pipeline

During training, a specialized, GPU-accelerated augmentation pipeline was applied to generate the multi-crop views required for the LeJEPA architecture.

1. Spatial & Guided Cropping

  • Global Crops: 2 global crops were generated per volume, sized at 224x224 pixels, with a random scale between 60% and 100% of the original image dimensions.
  • Local Crops: 8 local crops were generated per volume, sized at 140x140 pixels, with a random scale between 30% and 60% of the original image dimensions.
  • Anatomy & Label Guidance: Local crops had an 80% probability of being explicitly centered on physical anatomy (probabilistically guided by TotalSegmentator or ReXGroundingCT masks) to ensure the model learns from tissue rather than background air.
  • Flip & Resize: All crops were resized using nearest-neighbor interpolation to prevent the artificial averaging of Hounsfield Units, followed by a random horizontal flip applied with a 50% probability.
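The crop geometry above can be sketched as follows (an illustrative sampler, assuming "scale" means a fraction of the shorter image side; the actual pipeline is GPU-accelerated and mask-guided):

```python
import random

def sample_crop_box(h, w, scale_range):
    """Sample a square crop whose side is a random fraction of the shorter
    image dimension. Per the card: global crops use (0.6, 1.0) and are
    resized to 224px; local crops use (0.3, 0.6) and are resized to 140px.
    Returns (top, left, side)."""
    frac = random.uniform(*scale_range)
    side = max(1, int(min(h, w) * frac))
    top = random.randint(0, h - side)
    left = random.randint(0, w - side)
    return top, left, side

print(sample_crop_box(512, 512, (0.6, 1.0)))  # a global-crop box
print(sample_crop_box(512, 512, (0.3, 0.6)))  # a local-crop box
```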

2. Intensity & Noise Augmentations

  • Random Gamma Correction: A random gamma shift (range: 0.9 to 1.1) was applied independently to all crops with an 80% probability.
  • Gaussian Blur: Applied selectively to global crops using a 3x3 kernel. The first global crop was always blurred (100% probability), while the second global crop was blurred with only a 10% probability. Local crops were not blurred.
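The gamma augmentation above can be sketched as below (a minimal version; it assumes gamma is applied to the [0, 1]-scaled intensities, since gamma correction is undefined for negative values):

```python
import torch

def random_gamma(img, gamma_range=(0.9, 1.1), p=0.8):
    """Apply random gamma correction with probability p.
    Assumes img is already scaled to [0, 1]."""
    if torch.rand(1).item() < p:
        gamma = torch.empty(1).uniform_(*gamma_range).item()
        img = img.clamp(min=0).pow(gamma)
    return img

x = torch.rand(224, 224)      # stand-in for a [0, 1]-scaled crop
aug = random_gamma(x)         # gamma-shifted with 80% probability
```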

How to Get Started with the Model

Because CT scans require strict Hounsfield Unit (HU) windowing and normalization to match the training distribution, you must apply the specific preprocessing logic below before passing tensors to the model.

import torch
import torch.nn.functional as F
import numpy as np
import timm
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

class CTInferenceTransform:
    """ 
    Applies the exact HU windowing and Z-score normalization used during LeJEPA training. 
    """
    def __init__(self):
        # Calculated from the 0.5% and 99.5% foreground pixel intensities of CT-RATE
        self.clip_min = -997.0
        self.clip_max = 888.0
        self.mean_hu = -142.39
        self.std_hu = 360.97
        self.patch_size = 14
        
        # Calculate 0-1 scaled mean and std
        range_val = self.clip_max - self.clip_min
        self.norm_mean = (self.mean_hu - self.clip_min) / range_val
        self.norm_std = self.std_hu / range_val

    def __call__(self, volume):
        # Expects a 2D numpy array or torch tensor (H, W) in Hounsfield Units
        if isinstance(volume, np.ndarray):
            volume = torch.from_numpy(volume).float()
        if volume.ndim == 2:
            volume = volume.unsqueeze(0) # Add channel dim: (1, H, W)

        # 1. Clamp HU values and map strictly to [0, 1]
        volume = torch.clamp(volume, self.clip_min, self.clip_max)
        range_val = self.clip_max - self.clip_min
        volume = (volume - self.clip_min) / range_val

        # 2. Z-score standardization
        volume = (volume - self.norm_mean) / self.norm_std

        # 3. Padding/Interpolation for strict patch size alignment
        C, H, W = volume.shape
        target_h = int((H // self.patch_size) * self.patch_size)
        target_w = int((W // self.patch_size) * self.patch_size)

        if target_h != H or target_w != W:
            volume = volume.unsqueeze(0) # (1, C, H, W)
            # Use nearest interpolation to prevent averaging of exact HU values
            volume = F.interpolate(volume, size=(target_h, target_w), mode='nearest')
            volume = volume.squeeze(0)

        # Returns (1, 1, H, W). For batched inference, stack these along dim=0.
        return volume.unsqueeze(0) 

def load_guided_ct_model(repo_id="IBI-CAAI/Guided-Chest-CT-LeJEPA"):
    """
    Downloads and initializes the ViT-Large backbone using timm and safetensors.
    """
    # 1. Initialize the base architecture
    model = timm.create_model(
        "vit_large_patch14_dinov2",
        pretrained=False,
        num_classes=0,
        in_chans=1,          # Grayscale CT inputs
        img_size=518,        # Base initialization size
        dynamic_img_size=True # Allows native processing of variable resolutions (e.g., 504x504)
    )
    
    # 2. Download and load the custom safetensors weights
    model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    state_dict = load_file(model_path)
    # strict=False tolerates checkpoint keys outside the backbone (e.g. auxiliary-head weights)
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    
    return model

if __name__ == "__main__":
    # Initialize the transform and the model
    transform = CTInferenceTransform()
    model = load_guided_ct_model()
    
    # Simulate a raw CT slice (Replace this with an actual NIfTI/DICOM load in Hounsfield Units)
    raw_ct_slice = np.random.uniform(-1000, 1000, size=(512, 512)) 
    
    # Process the image to ensure correct normalization and patch dimension alignment
    # A 512x512 image will be automatically resized to 504x504 (closest multiple of 14)
    input_tensor = transform(raw_ct_slice)
    
    # Extract embeddings
    with torch.no_grad():
        # Option A: Get the single pooled global feature for the entire slice
        global_feature = model(input_tensor)
        
        # Option B: Get the full token sequence (for fine-grained tasks like segmentation).
        # Note: the first token(s) are prefix tokens (the class token); strip them
        # via tokens[:, model.num_prefix_tokens:, :] to keep only spatial patch tokens.
        patch_tokens = model.forward_features(input_tensor)
        
    print(f"Input tensor shape: {input_tensor.shape}")  
    print(f"Extracted features shape: {global_feature.shape}")
    print(f"Dense patch tokens shape: {patch_tokens.shape}")