Spaces:
Running
on
Zero
Running
on
Zero
Update query_understanding.py
Browse files- query_understanding.py +25 -8
query_understanding.py
CHANGED
|
@@ -43,11 +43,12 @@ class QueryUnderstandingEngine:
|
|
| 43 |
def __init__(self):
|
| 44 |
"""初始化查詢理解引擎"""
|
| 45 |
self.sbert_model = None
|
|
|
|
| 46 |
self.breed_list = self._load_breed_list()
|
| 47 |
self.synonyms = self._initialize_synonyms()
|
| 48 |
self.semantic_templates = {}
|
| 49 |
-
|
| 50 |
-
|
| 51 |
|
| 52 |
def _load_breed_list(self) -> List[str]:
|
| 53 |
"""載入品種清單"""
|
|
@@ -66,25 +67,35 @@ class QueryUnderstandingEngine:
|
|
| 66 |
'Bulldog', 'Poodle', 'Beagle', 'Border_Collie', 'Yorkshire_Terrier']
|
| 67 |
|
| 68 |
def _initialize_sbert_model(self):
|
| 69 |
-
"""初始化 SBERT 模型"""
|
|
|
|
|
|
|
|
|
|
| 70 |
try:
|
|
|
|
| 71 |
model_options = ['all-MiniLM-L6-v2', 'all-mpnet-base-v2', 'all-MiniLM-L12-v2']
|
| 72 |
|
| 73 |
for model_name in model_options:
|
| 74 |
try:
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
| 78 |
except Exception as e:
|
| 79 |
print(f"Failed to load {model_name}: {str(e)}")
|
| 80 |
continue
|
| 81 |
|
| 82 |
print("All SBERT models failed to load. Using keyword-only analysis.")
|
| 83 |
self.sbert_model = None
|
|
|
|
| 84 |
|
| 85 |
except Exception as e:
|
| 86 |
print(f"Failed to initialize SBERT model: {str(e)}")
|
| 87 |
self.sbert_model = None
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
def _initialize_synonyms(self) -> DimensionalSynonyms:
|
| 90 |
"""初始化多維度同義詞字典"""
|
|
@@ -143,6 +154,10 @@ class QueryUnderstandingEngine:
|
|
| 143 |
|
| 144 |
def _build_semantic_templates(self):
|
| 145 |
"""建立語義模板向量(僅在 SBERT 可用時)"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
if not self.sbert_model:
|
| 147 |
return
|
| 148 |
|
|
@@ -192,6 +207,9 @@ class QueryUnderstandingEngine:
|
|
| 192 |
dimensions = self._extract_keyword_dimensions(normalized_input)
|
| 193 |
|
| 194 |
# 如果 SBERT 可用,進行語義分析增強
|
|
|
|
|
|
|
|
|
|
| 195 |
if self.sbert_model:
|
| 196 |
semantic_dimensions = self._extract_semantic_dimensions(user_input)
|
| 197 |
dimensions = self._merge_dimensions(dimensions, semantic_dimensions)
|
|
@@ -435,7 +453,6 @@ class QueryUnderstandingEngine:
|
|
| 435 |
])
|
| 436 |
}
|
| 437 |
|
| 438 |
-
# 便利函數
|
| 439 |
def analyze_user_query(user_input: str) -> QueryDimensions:
|
| 440 |
"""
|
| 441 |
便利函數:分析使用者查詢
|
|
@@ -461,4 +478,4 @@ def get_query_summary(user_input: str) -> Dict[str, Any]:
|
|
| 461 |
"""
|
| 462 |
engine = QueryUnderstandingEngine()
|
| 463 |
dimensions = engine.analyze_query(user_input)
|
| 464 |
-
return engine.get_dimension_summary(dimensions)
|
|
|
|
| 43 |
def __init__(self):
    """Initialize the query understanding engine.

    The breed list and the synonym table are built here. The SBERT model
    is not loaded at construction time; `sbert_model` starts as None and
    `_sbert_loading_attempted` starts as False.
    """
    # No model yet, and no load attempt has been made so far.
    self.sbert_model, self._sbert_loading_attempted = None, False

    self.breed_list = self._load_breed_list()
    self.synonyms = self._initialize_synonyms()
    self.semantic_templates = {}

    # SBERT loading is deferred until it is needed, so that it takes
    # place inside the GPU environment.
    print("QueryUnderstandingEngine initialized (SBERT loading deferred)")
|
| 52 |
|
| 53 |
def _load_breed_list(self) -> List[str]:
|
| 54 |
"""載入品種清單"""
|
|
|
|
| 67 |
'Bulldog', 'Poodle', 'Beagle', 'Border_Collie', 'Yorkshire_Terrier']
|
| 68 |
|
| 69 |
def _initialize_sbert_model(self):
|
| 70 |
+
"""初始化 SBERT 模型 - 延遲載入以避免ZeroGPU CUDA初始化問題"""
|
| 71 |
+
if self.sbert_model is not None or getattr(self, '_sbert_loading_attempted', False):
|
| 72 |
+
return self.sbert_model
|
| 73 |
+
|
| 74 |
try:
|
| 75 |
+
print("Loading SBERT model for query understanding in GPU context...")
|
| 76 |
model_options = ['all-MiniLM-L6-v2', 'all-mpnet-base-v2', 'all-MiniLM-L12-v2']
|
| 77 |
|
| 78 |
for model_name in model_options:
|
| 79 |
try:
|
| 80 |
+
import torch
|
| 81 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 82 |
+
self.sbert_model = SentenceTransformer(model_name, device=device)
|
| 83 |
+
print(f"SBERT model {model_name} loaded successfully for query understanding on {device}")
|
| 84 |
+
return self.sbert_model
|
| 85 |
except Exception as e:
|
| 86 |
print(f"Failed to load {model_name}: {str(e)}")
|
| 87 |
continue
|
| 88 |
|
| 89 |
print("All SBERT models failed to load. Using keyword-only analysis.")
|
| 90 |
self.sbert_model = None
|
| 91 |
+
return None
|
| 92 |
|
| 93 |
except Exception as e:
|
| 94 |
print(f"Failed to initialize SBERT model: {str(e)}")
|
| 95 |
self.sbert_model = None
|
| 96 |
+
return None
|
| 97 |
+
finally:
|
| 98 |
+
self._sbert_loading_attempted = True
|
| 99 |
|
| 100 |
def _initialize_synonyms(self) -> DimensionalSynonyms:
|
| 101 |
"""初始化多維度同義詞字典"""
|
|
|
|
| 154 |
|
| 155 |
def _build_semantic_templates(self):
|
| 156 |
"""建立語義模板向量(僅在 SBERT 可用時)"""
|
| 157 |
+
# Initialize SBERT model if needed
|
| 158 |
+
if self.sbert_model is None:
|
| 159 |
+
self._initialize_sbert_model()
|
| 160 |
+
|
| 161 |
if not self.sbert_model:
|
| 162 |
return
|
| 163 |
|
|
|
|
| 207 |
dimensions = self._extract_keyword_dimensions(normalized_input)
|
| 208 |
|
| 209 |
# 如果 SBERT 可用,進行語義分析增強
|
| 210 |
+
if self.sbert_model is None:
|
| 211 |
+
self._initialize_sbert_model()
|
| 212 |
+
|
| 213 |
if self.sbert_model:
|
| 214 |
semantic_dimensions = self._extract_semantic_dimensions(user_input)
|
| 215 |
dimensions = self._merge_dimensions(dimensions, semantic_dimensions)
|
|
|
|
| 453 |
])
|
| 454 |
}
|
| 455 |
|
|
|
|
| 456 |
def analyze_user_query(user_input: str) -> QueryDimensions:
|
| 457 |
"""
|
| 458 |
便利函數:分析使用者查詢
|
|
|
|
| 478 |
"""
|
| 479 |
engine = QueryUnderstandingEngine()
|
| 480 |
dimensions = engine.analyze_query(user_input)
|
| 481 |
+
return engine.get_dimension_summary(dimensions)
|