Spaces:
Sleeping
Sleeping
| from flask import Flask, request, jsonify | |
| from dotenv import load_dotenv | |
| import os | |
| import pymongo | |
| import google.generativeai as genai | |
| from flask_cors import CORS | |
| from tqdm import tqdm | |
# Load environment variables from a local .env file, if one is present.
load_dotenv()

# Runtime configuration, read from the environment.
MONGODB_URI = os.getenv('MONGODB_URI')
# Fall back to a default model when EMBEDDING_MODEL is unset or empty.
EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL') or 'keepitreal/vietnamese-sbert'
DB_NAME = os.getenv('DB_NAME')
DB_COLLECTION = os.getenv('DB_COLLECTION')
GEMINI_KEY = os.getenv('GEMINI_KEY')

# Fail fast with a readable message: the database and collection are looked
# up by name below, and a missing name would otherwise surface as an
# unhelpful error from inside pymongo.
_missing_settings = [
    name
    for name, value in (('DB_NAME', DB_NAME), ('DB_COLLECTION', DB_COLLECTION))
    if not value
]
if _missing_settings:
    raise RuntimeError(
        'Missing required environment variable(s): '
        + ', '.join(_missing_settings)
    )

# Gemini client used to generate the final answer.
genai.configure(api_key=GEMINI_KEY)
model = genai.GenerativeModel('gemini-1.5-pro')

# MongoDB handles shared by every request handler in this module.
client = pymongo.MongoClient(MONGODB_URI)
db = client[DB_NAME]
collection = db[DB_COLLECTION]

app = Flask(__name__)
CORS(app)

# Sentence-embedding model, loaded once at import time and reused.
from sentence_transformers import SentenceTransformer
embedding_model = SentenceTransformer(EMBEDDING_MODEL)
def vector_search(user_query, collection, limit=4):
    """
    Perform a vector search in the MongoDB collection based on the user query.

    Args:
        user_query (str): The user's query string.
        collection (MongoCollection): The MongoDB collection to search.
        limit (int): Maximum number of documents to return.

    Returns:
        list: A list of matching documents (dicts without the ``_id`` and
        ``embedding`` fields). Empty when no embedding could be produced
        for the query.
    """
    # Generate embedding for the user query
    query_embedding = get_embedding(user_query)
    # get_embedding reports failure with an empty list, so test truthiness
    # (this also covers None). Return an empty list rather than an error
    # string: callers iterate over the result as documents.
    if not query_embedding:
        return []

    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",
                "queryVector": query_embedding,
                "path": "embedding",
                # numCandidates must not be smaller than limit; keep the
                # original 150 as the floor.
                "numCandidates": max(150, limit),
                "limit": limit,
            }
        },
        # Inclusion projection: only the listed fields are returned, which
        # already leaves out "embedding", so no separate $unset is needed.
        {
            "$project": {
                "_id": 0,
                "title": 1,
                "details": 1,
                "price": 1,
                "promotion_price": 1,
                "size_options": 1,
                "gender_options": 1,
                "quantity": 1,
                "stock": 1,
                "is_shoes": 1,
                "is_sandals": 1,
            }
        },
    ]

    # Execute the search
    return list(collection.aggregate(pipeline))
def _format_product(position, product):
    """Render one product document as a single text entry for the prompt."""
    # Every product gets its own header so that optional attributes are
    # never appended to the previous product's entry.
    parts = [f"\n\nSản phẩm {position}: {product.get('title')}"]
    if product.get('price'):
        parts.append(f", Giá: {product.get('price')}")
    if product.get('promotion_price'):
        parts.append(f", Giá ưu đãi: {product.get('promotion_price')}")
    if product.get('stock'):
        parts.append(f", Trạng thái: {product.get('stock')}")
    if product.get('is_shoes') is True:
        parts.append(", Loại: Giày")
    if product.get('is_sandals') is True:
        parts.append(", Loại: Dép")
    if product.get('size_options'):
        parts.append(f", Size: {product.get('size_options')}")
    if product.get('gender_options'):
        parts.append(f", Dành cho: {product.get('gender_options')}")
    if product.get('details'):
        parts.append(f", Chi tiết sản phẩm: {product.get('details')}")
    return "".join(parts)


def get_search_result(query, collection):
    """
    Build the product-knowledge text that is inserted into the LLM prompt.

    Runs a vector search for ``query`` (up to 10 results) and renders each
    hit as one numbered entry, starting at 1.

    Args:
        query (str): The user's query string.
        collection (MongoCollection): The MongoDB collection to search.

    Returns:
        str: The concatenated product descriptions, or an empty string when
        the search produced nothing usable.
    """
    products = vector_search(query, collection, 10)
    # vector_search may hand back an error message instead of documents;
    # iterating over a string would yield characters, not products.
    if isinstance(products, str):
        return ""
    return "".join(
        _format_product(position, product)
        for position, product in enumerate(products, start=1)
    )
def get_embedding(text):
    """
    Encode ``text`` with the module-level sentence-embedding model.

    Args:
        text (str): The text to embed.

    Returns:
        list: The embedding converted to a plain list via ``tolist()``, or
        an empty list when ``text`` is not a string or is blank.
    """
    # Guard against None / non-string input as well as blank strings, so a
    # missing field does not raise AttributeError on .strip().
    if not isinstance(text, str) or not text.strip():
        print("Attempted to get embedding for empty text.")
        return []
    return embedding_model.encode(text).tolist()
def process_query(query):
    """
    Normalize a raw user question before it is searched.

    Args:
        query: The raw question; expected to be a string.

    Returns:
        str: The question lower-cased with surrounding whitespace removed,
        or an empty string when ``query`` is not a string (e.g. None because
        the field was missing from the request).
    """
    if not isinstance(query, str):
        return ""
    # Strip so that a whitespace-only question normalizes to "" and is
    # treated as "no query" by callers that test truthiness.
    return query.strip().lower()
def handle_query():
    """
    Answer a customer question using retrieved product data and Gemini.

    Reads a JSON body with a ``question`` field, retrieves matching products
    from the vector database, and asks the generative model to answer based
    on them.

    Returns:
        A JSON response ``{'content': <model text>}``, or
        ``({'error': 'No query provided'}, 400)`` when the body is missing,
        is not a JSON object, or carries no usable question.
    """
    # NOTE(review): no @app.route decorator is visible for this function in
    # this file view, so it is not registered as an endpoint here - confirm
    # the routing is set up elsewhere.

    # silent=True yields None instead of raising on a missing/invalid JSON
    # body; anything that is not a JSON object is treated as empty.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    # Validate before normalizing so a missing question produces the 400
    # below rather than an AttributeError.
    question = data.get('question')
    if not isinstance(question, str) or not question.strip():
        return jsonify({'error': 'No query provided'}), 400

    query = process_query(question)
    if not query:
        return jsonify({'error': 'No query provided'}), 400

    # Retrieve data from vector database
    source_information = get_search_result(query, collection).replace('<br>', '\n')
    combined_information = f"Hãy trở thành chuyên gia tư vấn bán hàng cho một website bán giày dép ThuThaoShoes. Câu hỏi của khách hàng: {query}\nTrả lời câu hỏi dựa vào các thông tin sản phẩm dưới đây: {source_information}."
    response = model.generate_content(combined_information)
    return jsonify({
        'content': response.text
    })
def get_embedding_api():
    """
    Recompute and store the embedding of every document in the collection.

    For each document the text ``"<title> Danh mục: <category>"`` is
    embedded and written to the document's ``embedding`` field.

    Returns:
        A JSON response with a success message once all documents have been
        processed.
    """
    # NOTE(review): no @app.route decorator is visible for this function in
    # this file view - confirm how (and whether) it is exposed.

    # Fetch all documents from the collection, but only the fields needed to
    # build the embedding text (``_id`` is included by default), so existing
    # embedding vectors are not pulled into memory.
    documents = list(collection.find({}, {'title': 1, 'category': 1}))
    for doc in tqdm(documents, desc="Processing documents"):
        # ``or ''`` also covers fields that are present but None, which
        # would otherwise break the string concatenation below.
        product_specs = str(doc.get('title') or '')
        product_cat = str(doc.get('category') or '')
        print(product_specs + ' ' + product_cat)
        embedding = get_embedding(product_specs + ' Danh mục: ' + product_cat)
        # get_embedding reports failure with an empty list; never store it.
        if not embedding:
            continue
        # Update the document with the new embedding
        collection.update_one(
            {'_id': doc['_id']},
            {'$set': {'embedding': embedding}}
        )
    return jsonify({'message': 'Embedding cập nhật thành công cho tất cả các tài liệu.'})
if __name__ == '__main__':
    # Debug mode turns on the interactive debugger, which must not be
    # reachable on a deployed instance. Make it opt-in through FLASK_DEBUG
    # instead of always on.
    debug_enabled = os.getenv('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(debug=debug_enabled)