Update app.py
Browse files
app.py
CHANGED
|
@@ -83,6 +83,7 @@ def register_embedding_model(model_name: str = "open-clip") -> Any:
|
|
| 83 |
Returns:
|
| 84 |
Embedding model instance
|
| 85 |
"""
|
|
|
|
| 86 |
# TODO: Get the registry instance
|
| 87 |
registry = EmbeddingFunctionRegistry().get_instance()
|
| 88 |
|
|
@@ -98,9 +99,11 @@ def register_embedding_model(model_name: str = "open-clip") -> Any:
|
|
| 98 |
|
| 99 |
|
| 100 |
# Global embedding model
|
|
|
|
| 101 |
clip_model = register_embedding_model()
|
| 102 |
|
| 103 |
|
|
|
|
| 104 |
class FashionItem(LanceModel):
|
| 105 |
"""
|
| 106 |
Schema for fashion items in vector database
|
|
@@ -113,9 +116,10 @@ class FashionItem(LanceModel):
|
|
| 113 |
2. image_uri: String field for image file paths
|
| 114 |
3. description: Optional string field for text descriptions
|
| 115 |
"""
|
| 116 |
-
|
| 117 |
# TODO: Add vector field for embeddings
|
| 118 |
vector: Vector(clip_model.ndims()) = clip_model.VectorField()
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
# TODO: Add image field
|
|
@@ -151,7 +155,7 @@ def setup_fashion_database(
|
|
| 151 |
4. Process and save images locally
|
| 152 |
5. Create vector database table
|
| 153 |
"""
|
| 154 |
-
|
| 155 |
# TODO: Connect to LanceDB
|
| 156 |
db = lancedb.connect(database_path)
|
| 157 |
|
|
@@ -406,6 +410,9 @@ def setup_llm_model(model_name: str = "Qwen/Qwen2.5-0.5B-Instruct") -> Tuple[Any
|
|
| 406 |
low_cpu_mem_usage=True
|
| 407 |
)
|
| 408 |
|
|
|
|
|
|
|
|
|
|
| 409 |
if device == "cpu":
|
| 410 |
model = model.to(device)
|
| 411 |
|
|
@@ -556,14 +563,9 @@ def run_fashion_rag_pipeline(
|
|
| 556 |
Phase 4 - STORAGE: Save retrieved images
|
| 557 |
"""
|
| 558 |
|
| 559 |
-
global clip_model
|
| 560 |
print("π Starting Fashion RAG Pipeline")
|
| 561 |
print("=" * 50)
|
| 562 |
|
| 563 |
-
# Initialize models if not already loaded
|
| 564 |
-
if clip_model is None:
|
| 565 |
-
clip_model = register_embedding_model()
|
| 566 |
-
|
| 567 |
try:
|
| 568 |
# PHASE 1: RETRIEVAL
|
| 569 |
print("π PHASE 1: RETRIEVAL")
|
|
@@ -744,6 +746,7 @@ def launch_gradio_app():
|
|
| 744 |
|
| 745 |
|
| 746 |
def main():
|
|
|
|
| 747 |
"""Main function to handle command line arguments and run the pipeline"""
|
| 748 |
|
| 749 |
# If running in Hugging Face Spaces, automatically launch the app
|
|
|
|
| 83 |
Returns:
|
| 84 |
Embedding model instance
|
| 85 |
"""
|
| 86 |
+
print("Registering embedding is called...")
|
| 87 |
# TODO: Get the registry instance
|
| 88 |
registry = EmbeddingFunctionRegistry().get_instance()
|
| 89 |
|
|
|
|
| 99 |
|
| 100 |
|
| 101 |
# Global embedding model
|
| 102 |
+
print("Global variable is called...")
|
| 103 |
clip_model = register_embedding_model()
|
| 104 |
|
| 105 |
|
| 106 |
+
|
| 107 |
class FashionItem(LanceModel):
|
| 108 |
"""
|
| 109 |
Schema for fashion items in vector database
|
|
|
|
| 116 |
2. image_uri: String field for image file paths
|
| 117 |
3. description: Optional string field for text descriptions
|
| 118 |
"""
|
| 119 |
+
print ("Class Fashion Item is called...")
|
| 120 |
# TODO: Add vector field for embeddings
|
| 121 |
vector: Vector(clip_model.ndims()) = clip_model.VectorField()
|
| 122 |
+
# This assignement is also required although AI show incorrect. As without this program show column mismatch error.
|
| 123 |
|
| 124 |
|
| 125 |
# TODO: Add image field
|
|
|
|
| 155 |
4. Process and save images locally
|
| 156 |
5. Create vector database table
|
| 157 |
"""
|
| 158 |
+
print("Setup fashion database is called...that uses FashionItem class")
|
| 159 |
# TODO: Connect to LanceDB
|
| 160 |
db = lancedb.connect(database_path)
|
| 161 |
|
|
|
|
| 410 |
low_cpu_mem_usage=True
|
| 411 |
)
|
| 412 |
|
| 413 |
+
# Set model to eval mode for consistent outputs
|
| 414 |
+
model.eval()
|
| 415 |
+
|
| 416 |
if device == "cpu":
|
| 417 |
model = model.to(device)
|
| 418 |
|
|
|
|
| 563 |
Phase 4 - STORAGE: Save retrieved images
|
| 564 |
"""
|
| 565 |
|
|
|
|
| 566 |
print("π Starting Fashion RAG Pipeline")
|
| 567 |
print("=" * 50)
|
| 568 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
try:
|
| 570 |
# PHASE 1: RETRIEVAL
|
| 571 |
print("π PHASE 1: RETRIEVAL")
|
|
|
|
| 746 |
|
| 747 |
|
| 748 |
def main():
|
| 749 |
+
print("Main method is called...")
|
| 750 |
"""Main function to handle command line arguments and run the pipeline"""
|
| 751 |
|
| 752 |
# If running in Hugging Face Spaces, automatically launch the app
|