# granite-4.0-3b-vision / downsampling.py
# Uploaded to the Hugging Face Hub by Avihu (commit a0bef21, verified).
import torch
from torch import nn
import math
from fractions import Fraction
from transformers.models.blip_2.configuration_blip_2 import Blip2QFormerConfig
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel
class InterpolateDownsampler:
    """Shrink the spatial token grid with torch interpolation (default "area").

    The input is a flattened raster of vision tokens; it is folded back into a
    square grid, interpolated down to ``orig_side * downsample_rate`` per side,
    and flattened again.
    """
    def __init__(self, config, mode="area"):
        side = config.vision_config.image_size // config.vision_config.patch_size
        self.orig_image_side = side
        # downsample_rate is a fraction string such as "1/2".
        self.new_image_side = int(side * Fraction(config.downsample_rate))
        self.mode = mode
    def __call__(self, image_features):
        batch, _, channels = image_features.size()
        # (B, H*W, C) -> (B, C, H, W) for F.interpolate.
        grid = image_features.view(batch, self.orig_image_side, self.orig_image_side, channels)
        grid = grid.permute(0, 3, 1, 2)
        shrunk = torch.nn.functional.interpolate(
            grid,
            size=(self.new_image_side, self.new_image_side),
            mode=self.mode,
        )
        # Back to (B, H'*W', C) raster order.
        return shrunk.permute(0, 2, 3, 1).flatten(1, 2)
class SpatialOffsetDownsampler:
    """
    Downsampler that samples one position from each 2x2 block across the image.
    Maintains full spatial coverage while creating local continuity.
    """
    def __init__(self, config, offset=0):
        """
        Args:
            config: Model configuration
            offset: Integer offset (0, 1, 2, or 3) for position within each 2x2 block
                    0: top-left, 1: top-right, 2: bottom-left, 3: bottom-right
        """
        side = config.vision_config.image_size // config.vision_config.patch_size
        self.orig_image_side = side
        self.new_image_side = side // 2
        self.offset = offset
        # (row, col) within each 2x2 block for offsets 0..3.
        self.offsets = [(0, 0), (0, 1), (1, 0), (1, 1)]
        self.offset_h, self.offset_w = self.offsets[offset]
    def __call__(self, image_features):
        batch, _, channels = image_features.shape
        half = self.new_image_side
        # Fold the raster into 2x2 blocks and keep one position per block.
        blocks = image_features.reshape(batch, half, 2, half, 2, channels)
        picked = blocks[:, :, self.offset_h, :, self.offset_w, :]
        return picked.reshape(batch, -1, channels)
class WindowQFormerDownsampler(nn.Module):
    """Window-based QFormer downsampler that processes image patches in windows.

    Each non-overlapping ``window_side x window_side`` window of vision tokens is
    compressed to ``query_side ** 2`` tokens by a single-layer BLIP-2 Q-Former:
    learned queries (seeded with a spatially downsampled copy of the input
    features) cross-attend to the window's tokens, and the per-window outputs
    are stitched back into raster order and projected to the LLM hidden size.
    """
    def __init__(self, config, spatial_offset=None):
        # config is expected to expose text_config.hidden_size,
        # vision_config.hidden_size / image_size / patch_size,
        # projector_dropout, and a downsample_rate string of the form "q/w".
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        vision_hidden_size = config.vision_config.hidden_size
        self.dropout = nn.Dropout(config.projector_dropout)
        # Queries are seeded from a cheap spatial downsample of the input:
        # one token per 2x2 block when spatial_offset is given, otherwise
        # area interpolation.
        if spatial_offset is not None:
            self.downsampler = SpatialOffsetDownsampler(config, offset=spatial_offset)
        else:
            self.downsampler = InterpolateDownsampler(config)
        # Single cross-attention layer; queries and encoder states share the
        # vision hidden size, so 64-dim heads give hidden_size // 64 heads.
        configuration = Blip2QFormerConfig(
            hidden_size=vision_hidden_size,
            num_attention_heads=vision_hidden_size // 64,
            intermediate_size=3072,
            num_hidden_layers=1,
            encoder_hidden_size=vision_hidden_size,
            cross_attention_frequency=1,
            max_position_embeddings=2048,
            use_qformer_text_input=False,
        )
        self.qformer = Blip2QFormerModel(configuration)
        # Number of vision tokens along one image side.
        self.image_side = config.vision_config.image_size // config.vision_config.patch_size
        # downsample_rate "q/w": each w x w token window becomes q x q tokens.
        q, w = config.downsample_rate.split("/")
        self.query_side, self.window_side = int(q), int(w)
        self.query_length = self.query_side ** 2
        embed_std = 1 / math.sqrt(vision_hidden_size)
        self.norm = nn.LayerNorm(vision_hidden_size, eps=1e-6)
        # Learned query tokens (one set, broadcast across all windows) and
        # learned positional embeddings for tokens inside a window.
        self.query = nn.Parameter(torch.randn(1, self.query_length, vision_hidden_size) * embed_std)
        self.image_positions = nn.Parameter(torch.randn(1, self.window_side ** 2, vision_hidden_size) * embed_std)
        self.out_linear = nn.Linear(vision_hidden_size, llm_hidden_size, bias=True)
    def _win(self, x, side, win):
        """
        (B, side*side, C) raster -> (B*n*n, win*win, C) where n=side//win
        windows are raster-ordered, and tokens inside each window are raster-ordered.
        """
        B, _, C = x.shape
        n = side // win
        return (
            x.view(B, side, side, C)
            .view(B, n, win, n, win, C)
            .transpose(2, 3)  # (B, n, n, win, win, C)
            .flatten(0, 2)  # (B*n*n, win, win, C)
            .flatten(1, 2)  # (B*n*n, win*win, C)
        )
    def _unwin(self, xw, n, win):
        """
        (B*n*n, win*win, C) -> (B, (n*win)^2, C) raster

        Inverse of _win: regroups per-window token sequences back into one
        raster-ordered sequence per batch element.
        """
        Bnn, _, C = xw.shape
        assert Bnn % (n * n) == 0
        B = Bnn // (n * n)
        side = n * win
        return (
            xw.view(B, n, n, win, win, C)
            .transpose(2, 3)  # (B, n, win, n, win, C)
            .contiguous()
            .view(B, side, side, C)
            .flatten(1, 2)
        )
    def forward(self, image_features):
        """Compress (B, image_side**2, C_vision) features to
        (B, (image_side // window_side * query_side)**2, C_llm)."""
        B, HW, C = image_features.shape
        assert HW == self.image_side * self.image_side
        n = self.image_side // self.window_side  # windows per image side
        image_features = self.norm(image_features)
        # Encoder input: one group of win*win tokens per window, plus learned
        # intra-window positional embeddings.
        enc = self._win(image_features, self.image_side, self.window_side)
        # Query seed: downsampled features, windowed with query_side so each
        # window's queries line up with its encoder tokens. Note new_side here
        # equals the downsampler's output side (n * query_side).
        downsampled = self.downsampler(image_features)
        new_side = n * self.query_side
        downsampled_w = self._win(downsampled, new_side, self.query_side)
        query_embeds = self.query + downsampled_w
        encoder_embeds = self.dropout(enc + self.image_positions)
        out_w = self.qformer(
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_embeds,
            return_dict=True,
        ).last_hidden_state
        # Stitch per-window outputs back into a raster-ordered sequence.
        out = self._unwin(out_w, n=n, win=self.query_side)
        out = self.dropout(out)
        return self.out_linear(out)