---
license: apache-2.0
tags:
- mlx
- mlx-image
- vision
- image-classification
datasets:
- imagenet-1k
library_name: mlx-image
---
# swin_tiny_patch4_window7_224

A [Swin Transformer](https://arxiv.org/abs/2103.14030) image classification model. The weights were trained on ImageNet-1k.

Disclaimer: This is a port of the torchvision model weights to the Apple MLX framework.

## How to use
```bash
pip install mlx-image
```

Here is how to use this model for image classification:

```python
import mlx.core as mx

from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform

# Preprocess the image with the ImageNet evaluation transform
transform = ImageNetTransform(train=False, img_size=224)
x = transform(read_rgb("cat.png"))
# Add a leading batch dimension
x = mx.expand_dims(x, 0)

model = create_model("swin_tiny_patch4_window7_224")
model.eval()

logits = model(x)
```
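
As a rough illustration of what to do with `logits`, the sketch below turns one row of logits into top-k class probabilities using a numerically stable softmax. It relies only on the Python standard library, so it runs without MLX. `top_k_probs` is a hypothetical helper and not part of mlx-image, and passing `logits[0].tolist()` assumes the model returns an `mx.array` of shape `(1, num_classes)`.

```python
import math


def top_k_probs(logits, k=5):
    """Return the k most likely (class_index, probability) pairs for one image.

    `logits` is a flat sequence of floats, one entry per class.
    """
    # Subtract the largest logit before exponentiating for numerical stability.
    largest = max(logits)
    exps = [math.exp(value - largest) for value in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Sort class indices by descending probability and keep the first k.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in order[:k]]


# Stand-in logits for a toy 4-class problem. With the model above you would
# pass something like `logits[0].tolist()` instead (an assumption, see above).
example_logits = [2.0, 0.5, -1.0, 3.0]
for class_index, prob in top_k_probs(example_logits, k=2):
    print(class_index, round(prob, 3))
```

The returned indices are positions in the model's output vector; mapping them to human-readable ImageNet-1k labels needs a separate label list, which is not shown here.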

You can also use the embeddings from the layer before the classification head:

```python
import mlx.core as mx

from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform

transform = ImageNetTransform(train=False, img_size=224)
x = transform(read_rgb("cat.png"))
x = mx.expand_dims(x, 0)

# first option: create the model without a classification head
model = create_model("swin_tiny_patch4_window7_224", num_classes=0)
model.eval()

embeds = model(x)

# second option: keep the head and ask for the features explicitly
model = create_model("swin_tiny_patch4_window7_224")
model.eval()

embeds = model.get_features(x)
```
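
One common use of such embeddings is comparing two images by cosine similarity. The sketch below is a minimal standard-library version, assuming each embedding has already been flattened to a 1-D list of floats (for example via `.tolist()` on a single row). `cosine_similarity` is a hypothetical helper, not an mlx-image function, and the exact shape returned by `get_features` is not assumed here.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors of floats."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Toy 3-dimensional vectors standing in for two image embeddings.
emb_cat = [0.2, 0.9, 0.1]
emb_other = [0.1, 0.8, 0.3]
print(cosine_similarity(emb_cat, emb_other))
```

Values close to 1 indicate embeddings pointing in nearly the same direction, while values near 0 indicate unrelated directions.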

## Model Comparison

Explore the metrics of this model in [mlx-image model results](https://github.com/riccardomusmeci/mlx-image/blob/main/results/results-imagenet-1k.csv).