---
license: apache-2.0
library_name: mlx-image
tags:
- mlx
- mlx-image
- vision
- image-classification
datasets:
- imagenet-1k
---

# efficientnet_b0

An EfficientNet-B0 image classification model, pretrained on ImageNet-1K.

Disclaimer: this is a port of the Torchvision model weights to Apple's MLX framework.

See the [mlx-convert-scripts](https://github.com/lextoumbourou/mlx-convert-scripts) repo for the conversion script used.

## How to use

```bash
pip install mlx-image
```

Here is how to use this model for image classification:

```python
import mlx.core as mx
from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform
from mlxim.utils.imagenet import IMAGENET2012_CLASSES

# Preprocess the image with the evaluation transform and add a batch dimension
transform = ImageNetTransform(train=False, img_size=224)
x = transform(read_rgb("cat.jpg"))
x = mx.array(x)
x = mx.expand_dims(x, 0)

# Load the pretrained model and switch it to inference mode
model = create_model("efficientnet_b0")
model.eval()

# Run the model and map the highest-scoring index to its ImageNet label
logits = model(x)
predicted_idx = mx.argmax(logits, axis=-1).item()
predicted_class = list(IMAGENET2012_CLASSES.values())[predicted_idx]

print(f"Predicted class: {predicted_class}")
```
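`mx.argmax` returns only the single best class. To also see how confident the model is, the logits can be turned into softmax probabilities and ranked. The sketch below does this in plain Python and is not part of mlx-image; `top_k_probs` is a hypothetical helper, and it assumes the logits for one image are first copied into a Python list, e.g. with `logits[0].tolist()`.

```python
import math


def top_k_probs(logits, k=5):
    """Hypothetical helper: softmax over a list of logits, then return the
    k most likely (index, probability) pairs, highest probability first."""
    # Subtract the max logit before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank class indices by probability (descending) and keep the first k
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]


# Usage with the classification example (assumes `logits` and
# IMAGENET2012_CLASSES are defined as above):
# labels = list(IMAGENET2012_CLASSES.values())
# for idx, p in top_k_probs(logits[0].tolist(), k=5):
#     print(f"{labels[idx]}: {p:.3f}")
```

Subtracting the maximum logit leaves the probabilities unchanged but avoids overflow for large logits. If you prefer to stay on the MLX side, `mx.softmax` computes the same probabilities without the list conversion.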

You can also extract embeddings from the layer before the classification head:

```python
import mlx.core as mx
from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform

# Preprocess the image with the evaluation transform and add a batch dimension
transform = ImageNetTransform(train=False, img_size=224)
x = transform(read_rgb("cat.jpg"))
x = mx.array(x)
x = mx.expand_dims(x, 0)

# Option 1: create the model without a classification head (num_classes=0),
# so the forward pass returns the embeddings directly
model = create_model("efficientnet_b0", num_classes=0)
model.eval()

embeds = model(x)

# Option 2: keep the full classifier and call get_features()
# to get the output of the layer before the head
model = create_model("efficientnet_b0")
model.eval()

embeds = model.get_features(x)
```
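A common use of these embeddings is comparing images by cosine similarity. The plain-Python sketch below is illustrative only; `cosine_similarity` is a hypothetical helper, not an mlx-image function. It assumes each image yields a single 1-D embedding vector; check `embeds.shape` first, since a spatial feature map would need to be pooled before comparison.

```python
import math


def cosine_similarity(a, b):
    """Hypothetical helper: cosine similarity of two equal-length vectors.
    The result lies in [-1, 1]; higher means more similar."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("cannot compare a zero vector")
    return dot / (norm_a * norm_b)


# Usage (hypothetical names: embeds_cat and embeds_dog each hold one pooled
# embedding computed as above, with a leading batch dimension of 1):
# score = cosine_similarity(embeds_cat[0].tolist(), embeds_dog[0].tolist())
```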