# Plant_Disease / inference.py
# Uploaded by kero2111 with huggingface_hub (commit 0b88a08, verified).
import tensorflow as tf
import numpy as np
from PIL import Image
import os
# ========================
# Custom layer (required to load the saved model)
# ========================
from tensorflow.keras.layers import Layer
class CustomScaleLayer(Layer):
    """Keras layer that multiplies its input by a constant factor.

    It must be defined here so ``load_model`` can pass it in
    ``custom_objects`` when deserializing the saved model.

    If called with a list or tuple of tensors, the tensors are first summed
    element-wise with ``tf.add_n``, and the sum is then scaled.
    """

    def __init__(self, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        # Multiplicative factor applied in ``call``; round-tripped by get_config.
        self.scale = scale

    def call(self, inputs):
        combined = tf.add_n(inputs) if isinstance(inputs, (list, tuple)) else inputs
        return combined * self.scale

    def get_config(self):
        # Extend the base layer config so ``scale`` survives save/load.
        return {**super().get_config(), "scale": self.scale}
# ========================
# Load the model
# ========================
def load_model(model_path):
    """Load a saved Keras model, registering ``CustomScaleLayer``.

    Args:
        model_path: Path to the saved model file (for example an ``.h5`` file).

    Returns:
        The loaded model, or ``None`` if loading failed for any reason.
        Errors are printed, not raised.
    """
    try:
        loaded = tf.keras.models.load_model(
            model_path,
            custom_objects={'CustomScaleLayer': CustomScaleLayer},
        )
        print("Model loaded successfully!")
    except Exception as err:
        print(f"Error loading model: {err}")
        loaded = None
    return loaded
# ========================
# Image preprocessing
# ========================
def preprocess_image(image_path, target_size=(299, 299), normalize=True):
    """Load an image and turn it into a one-image batch for the model.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.
        normalize: If true, divide pixel values by 255 to map 0-255 into
            [0, 1]. Pass ``False`` when the model rescales its own inputs.

    Returns:
        A NumPy array of shape ``(1, height, width, 3)``, or ``None`` if the
        image could not be read or processed. Errors are printed, not raised.
        The dtype is float64 when ``normalize`` is true and uint8 otherwise.
    """
    try:
        # Open the image as a context manager so the file handle is always
        # released, including on error paths and for multi-frame formats.
        # PIL reads files lazily and may otherwise keep them open.
        with Image.open(image_path) as img:
            rgb = img if img.mode == 'RGB' else img.convert('RGB')
            # PIL sizes are (width, height); the array comes out (height, width, 3).
            img_array = np.array(rgb.resize(target_size))
        if normalize:
            img_array = img_array / 255.0
        # Add a leading batch dimension of size 1.
        return np.expand_dims(img_array, axis=0)
    except Exception as e:
        print(f"Error preprocessing image: {e}")
        return None
# ========================
# Prediction
# ========================
def predict_disease(model, image_array):
    """Run the model on a batch and report the top class of the first sample.

    Args:
        model: Object with a Keras-style ``predict`` method.
        image_array: Batch of preprocessed images (e.g. from ``preprocess_image``).

    Returns:
        A tuple for the first sample in the batch, or ``(None, None, None)``
        on any error (the error is printed):
        - the index with the highest score,
        - the score at that index,
        - the full score vector.
    """
    try:
        scores = model.predict(image_array)[0]
        best = np.argmax(scores)
        return best, scores[best], scores
    except Exception as err:
        print(f"Error making prediction: {err}")
        return None, None, None
# ========================
# Class names
# ========================
# Class labels in model-output order: index i is the label of output unit i.
# Defined once at module level rather than rebuilt on every lookup.
# NOTE(review): these look like the 38 PlantVillage classes in alphabetical
# order; confirm this matches the label order used when the model was trained.
CLASS_NAMES = (
    "Apple___Apple_scab", "Apple___Black_rot", "Apple___Cedar_apple_rust", "Apple___healthy",
    "Blueberry___healthy", "Cherry_(including_sour)___Powdery_mildew", "Cherry_(including_sour)___healthy",
    "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot", "Corn_(maize)___Common_rust_",
    "Corn_(maize)___Northern_Leaf_Blight", "Corn_(maize)___healthy", "Grape___Black_rot",
    "Grape___Esca_(Black_Measles)", "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)", "Grape___healthy",
    "Orange___Haunglongbing_(Citrus_greening)", "Peach___Bacterial_spot", "Peach___healthy",
    "Pepper,_bell___Bacterial_spot", "Pepper,_bell___healthy", "Potato___Early_blight",
    "Potato___Late_blight", "Potato___healthy", "Raspberry___healthy", "Soybean___healthy",
    "Squash___Powdery_mildew", "Strawberry___Leaf_scorch", "Strawberry___healthy",
    "Tomato___Bacterial_spot", "Tomato___Early_blight", "Tomato___Late_blight", "Tomato___Leaf_Mold",
    "Tomato___Septoria_leaf_spot", "Tomato___Spider_mites Two-spotted_spider_mite",
    "Tomato___Target_Spot", "Tomato___Tomato_Yellow_Leaf_Curl_Virus", "Tomato___Tomato_mosaic_virus",
    "Tomato___healthy",
)


def get_class_name(class_index, class_names=CLASS_NAMES):
    """Map a model output index to its class label.

    Args:
        class_index: Integer index (Python int or NumPy integer) into ``class_names``.
        class_names: Sequence of labels in output order. Defaults to
            ``CLASS_NAMES``; pass your own for a model with a different label set.

    Returns:
        The label at ``class_index``, or ``"Unknown"`` if the index is out of range.
    """
    # Explicit bounds check: negative indices would otherwise wrap around silently.
    if 0 <= class_index < len(class_names):
        return class_names[class_index]
    return "Unknown"
# ========================
# Main function
# ========================
def _has_rescaling_layer(model):
    """Return True if ``model`` or any nested sub-model has a Rescaling layer.

    Walks ``.layers`` recursively, so a Rescaling layer inside a nested
    Sequential or functional sub-model is found too. Checking only the
    top-level ``model.layers`` would miss it, and the image would then be
    scaled twice.
    """
    pending = list(model.layers)
    while pending:
        layer = pending.pop()
        if isinstance(layer, tf.keras.layers.Rescaling):
            return True
        # Nested models expose their own ``layers``; ordinary layers normally do not.
        pending.extend(getattr(layer, "layers", None) or ())
    return False


def main(model_path="Pretrained_model.h5", sample_image_path="sample_image.jpg", top_k=3):
    """Classify one image with the saved model and print the results.

    Args:
        model_path: Path to the saved Keras model.
        sample_image_path: Path to the image to classify.
        top_k: Number of highest-scoring classes to list.

    Returns ``None``. Results and errors are printed; the function returns
    early if a file is missing or a step fails.
    """
    if not os.path.exists(model_path):
        print(f"Model file not found at: {model_path}")
        return
    if not os.path.exists(sample_image_path):
        print(f"Image file not found at: {sample_image_path}")
        return

    model = load_model(model_path)
    if model is None:
        return

    # If the model rescales its own inputs, feed raw 0-255 pixels so the
    # image is not normalized twice.
    normalize = not _has_rescaling_layer(model)
    image_array = preprocess_image(sample_image_path, target_size=(299, 299), normalize=normalize)
    if image_array is None:
        return

    predicted_class, confidence, all_predictions = predict_disease(model, image_array)
    if predicted_class is None:
        return

    print("\nPrediction Results:")
    print(f"Predicted Class: {get_class_name(predicted_class)}")
    print(f"Confidence: {confidence:.2%}")
    print(f"Class Index: {predicted_class}")

    # Sort descending, then take the first top_k; unlike a [-top_k:] slice,
    # this yields nothing (rather than every class) when top_k is 0.
    top_indices = np.argsort(all_predictions)[::-1][:top_k]
    print(f"\nTop {top_k} Predictions:")
    for rank, idx in enumerate(top_indices, start=1):
        print(f"{rank}. {get_class_name(idx)}: {all_predictions[idx]:.2%}")


if __name__ == "__main__":
    import sys

    # Optional CLI overrides: python inference.py [MODEL_PATH [IMAGE_PATH]]
    main(*sys.argv[1:3])