| import tensorflow as tf
|
| import numpy as np
|
| from PIL import Image
|
| import os
|
|
|
|
|
|
|
|
|
| from tensorflow.keras.layers import Layer
|
|
|
class CustomScaleLayer(Layer):
    """Keras layer that multiplies its input by a constant factor.

    When called with a list or tuple of tensors, the tensors are first summed
    element-wise with ``tf.add_n`` and that sum is then scaled.

    Args:
        scale: Multiplicative factor applied to the (summed) input.
        **kwargs: Forwarded unchanged to ``Layer.__init__`` (e.g. ``name``).
    """

    def __init__(self, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def call(self, inputs):
        """Return ``inputs * scale``, summing a list/tuple of inputs first."""
        merged = tf.add_n(inputs) if isinstance(inputs, (list, tuple)) else inputs
        return merged * self.scale

    def get_config(self):
        """Return the base layer config extended with ``scale`` for serialization."""
        return {**super().get_config(), "scale": self.scale}
|
|
|
|
|
|
|
|
|
def load_model(model_path):
    """Load a saved Keras model that may contain ``CustomScaleLayer`` layers.

    Args:
        model_path: Path to the saved model file.

    Returns:
        The loaded model, or ``None`` if loading raised any exception. The
        error is printed rather than propagated.
    """
    # The custom layer must be registered by name so Keras can deserialize it.
    registry = {'CustomScaleLayer': CustomScaleLayer}
    try:
        loaded = tf.keras.models.load_model(model_path, custom_objects=registry)
        print("Model loaded successfully!")
        return loaded
    except Exception as err:
        print(f"Error loading model: {err}")
        return None
|
|
|
|
|
|
|
|
|
def preprocess_image(image_path, target_size=(299, 299), normalize=True):
    """Load an image file and convert it into a one-image model input batch.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        target_size: Output size as ``(width, height)``. This is PIL's ordering,
            the reverse of Keras' ``(height, width)``; it makes no difference
            for the square default.
        normalize: If true, divide pixel values by 255 so they lie in
            ``[0, 1]``. Pass ``False`` when the model rescales inputs itself.

    Returns:
        A NumPy array of shape ``(1, height, width, 3)``. It is ``float32``
        when normalized and ``uint8`` otherwise. Returns ``None`` if the image
        could not be opened or processed; the error is printed.
    """
    try:
        # Image.open is lazy and keeps the file handle open until the image is
        # garbage-collected; the context manager closes it deterministically.
        # convert()/resize() force the pixel data to load, so all reads
        # happen before the file is closed.
        with Image.open(image_path) as src:
            rgb = src if src.mode == 'RGB' else src.convert('RGB')
            resized = rgb.resize(target_size)
            img_array = np.array(resized)

        if normalize:
            # Keras models compute in float32; avoid a float64 intermediate.
            img_array = img_array.astype(np.float32) / 255.0

        # Add the leading batch dimension expected by model.predict.
        return np.expand_dims(img_array, axis=0)

    except Exception as e:
        print(f"Error preprocessing image: {e}")
        return None
|
|
|
|
|
|
|
|
|
def predict_disease(model, image_array):
    """Run the model on a batch and summarize the output for its first sample.

    Args:
        model: Object with a Keras-style ``predict(batch)`` method. Its result
            is indexed with ``[0]``, and that row is treated as the per-class
            score vector.
        image_array: Input batch passed straight to ``model.predict``.

    Returns:
        ``(predicted_class, confidence, scores)`` for the first sample: the
        argmax index, the score at that index, and the full score vector.
        On any exception, prints the error and returns ``(None, None, None)``.
    """
    try:
        scores = model.predict(image_array)[0]
        best = np.argmax(scores)
        return best, scores[best], scores
    except Exception as exc:
        print(f"Error making prediction: {exc}")
        return None, None, None
|
|
|
|
|
|
|
|
|
# Class labels indexed by model output position: label ``i`` names output ``i``.
# The order must match the label order used when the model was trained.
# NOTE(review): these strings look like PlantVillage dataset folder names in
# sorted order (presumably from a directory-based loader) — confirm against
# the training pipeline.
PLANT_DISEASE_CLASSES = (
    "Apple___Apple_scab", "Apple___Black_rot", "Apple___Cedar_apple_rust", "Apple___healthy",
    "Blueberry___healthy", "Cherry_(including_sour)___Powdery_mildew", "Cherry_(including_sour)___healthy",
    "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot", "Corn_(maize)___Common_rust_",
    "Corn_(maize)___Northern_Leaf_Blight", "Corn_(maize)___healthy", "Grape___Black_rot",
    "Grape___Esca_(Black_Measles)", "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)", "Grape___healthy",
    "Orange___Haunglongbing_(Citrus_greening)", "Peach___Bacterial_spot", "Peach___healthy",
    "Pepper,_bell___Bacterial_spot", "Pepper,_bell___healthy", "Potato___Early_blight",
    "Potato___Late_blight", "Potato___healthy", "Raspberry___healthy", "Soybean___healthy",
    "Squash___Powdery_mildew", "Strawberry___Leaf_scorch", "Strawberry___healthy",
    "Tomato___Bacterial_spot", "Tomato___Early_blight", "Tomato___Late_blight", "Tomato___Leaf_Mold",
    "Tomato___Septoria_leaf_spot", "Tomato___Spider_mites Two-spotted_spider_mite",
    "Tomato___Target_Spot", "Tomato___Tomato_Yellow_Leaf_Curl_Virus", "Tomato___Tomato_mosaic_virus",
    "Tomato___healthy",
)


def get_class_name(class_index, *, class_names=None):
    """Map a model output index to its class label.

    The default label table is a module-level tuple, so it is built once
    instead of on every call (``main`` calls this in a loop).

    Args:
        class_index: Integer index; a Python ``int`` or a NumPy integer such
            as the result of ``np.argmax``.
        class_names: Optional sequence of labels to use instead of
            ``PLANT_DISEASE_CLASSES``.

    Returns:
        The label at ``class_index``, or ``"Unknown"`` when the index falls
        outside ``[0, len(labels))``.
    """
    labels = PLANT_DISEASE_CLASSES if class_names is None else class_names
    # Explicit range check: negative indices would otherwise wrap around.
    if 0 <= class_index < len(labels):
        return labels[class_index]
    return "Unknown"
|
|
|
|
|
|
|
|
|
def _iter_layers(layers):
    """Yield every layer in *layers*, recursing into nested sub-models."""
    for layer in layers:
        yield layer
        # Nested Model/Sequential objects expose their own ``layers`` list;
        # plain layers have no such attribute, so recursion stops there.
        yield from _iter_layers(getattr(layer, "layers", ()))


def _model_rescales_inputs(model):
    """Return True if *model* or any nested sub-model has a Rescaling layer."""
    return any(
        isinstance(layer, tf.keras.layers.Rescaling)
        for layer in _iter_layers(model.layers)
    )


def _print_results(predicted_class, confidence, all_predictions, top_k=3):
    """Print the top prediction, then the *top_k* highest-scoring classes."""
    print("\nPrediction Results:")
    print(f"Predicted Class: {get_class_name(predicted_class)}")
    print(f"Confidence: {confidence:.2%}")
    print(f"Class Index: {predicted_class}")

    # argsort is ascending: take the last k indices and reverse them.
    top_indices = np.argsort(all_predictions)[-top_k:][::-1]
    print(f"\nTop {top_k} Predictions:")
    for rank, idx in enumerate(top_indices, start=1):
        print(f"{rank}. {get_class_name(idx)}: {all_predictions[idx]:.2%}")


def main(model_path="Pretrained_model.h5", sample_image_path="sample_image.jpg",
         target_size=(299, 299)):
    """Classify one image with a saved model and print the results.

    Args:
        model_path: Saved Keras model to load.
        sample_image_path: Image file to classify.
        target_size: ``(width, height)`` the image is resized to before
            inference. It must match the model's expected input size.

    Missing files and load/preprocess/predict failures are reported by
    printing a message (or by the helper that failed) and returning early.
    """
    if not os.path.exists(model_path):
        print(f"Model file not found at: {model_path}")
        return

    if not os.path.exists(sample_image_path):
        print(f"Image file not found at: {sample_image_path}")
        return

    model = load_model(model_path)
    if model is None:
        return

    # If the model contains a Rescaling layer anywhere, including inside a
    # nested backbone sub-model, it is assumed to handle pixel scaling
    # itself. Dividing by 255 here as well would scale the input twice.
    # NOTE(review): a 299x299 input suggests an Inception-style backbone, whose
    # stock preprocessing maps pixels to [-1, 1] rather than [0, 1]. Confirm
    # that this scaling matches how the model was trained.
    image_array = preprocess_image(
        sample_image_path,
        target_size=target_size,
        normalize=not _model_rescales_inputs(model),
    )
    if image_array is None:
        return

    predicted_class, confidence, all_predictions = predict_disease(model, image_array)
    if predicted_class is None:
        return

    _print_results(predicted_class, confidence, all_predictions)


if __name__ == "__main__":
    main()
|
|
|