# torch, the Bert* classes, softmax, and numpy are only needed by the
# commented-out manual inference path inside predict_answer below.
import torch
from transformers import (
    BertForQuestionAnswering,
    BertTokenizerFast,
)
from transformers import pipeline
from scipy.special import softmax
import numpy as np

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Extractive QA model: BERT base (uncased) fine-tuned on SQuAD 2.0.
model_name = 'deepset/bert-base-uncased-squad2'
pipe = pipeline("question-answering", model=model_name)
# model = BertForQuestionAnswering.from_pretrained(model_name)
# tokenizer = BertTokenizerFast.from_pretrained(model_name)
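# The question-answering pipeline returns a dict with "score", "start",
# "end", and "answer" keys, e.g. (illustrative values):
#   pipe({"context": "Paris is the capital of France.",
#         "question": "What is the capital of France?"})
#   -> {'score': 0.98, 'start': 0, 'end': 5, 'answer': 'Paris'}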

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allow all HTTP methods
    allow_headers=["*"],  # Allow all headers
)

def predict_answer(context, question):
    """Run extractive QA over the context and return the best answer span with its score."""
    response = pipe({"context": context, "question": question})
    return {
        "answer": response['answer'],
        "score": response['score']
    }
    # Alternative: manual inference without the pipeline, kept for reference.
    # inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512)
    # with torch.no_grad():
    #     outputs = model(**inputs)
    # start_scores = softmax(outputs.start_logits.numpy(), axis=-1)[0]
    # end_scores = softmax(outputs.end_logits.numpy(), axis=-1)[0]
    # start_idx = np.argmax(start_scores)
    # end_idx = np.argmax(end_scores)
    # confidence_score = (start_scores[start_idx] + end_scores[end_idx]) / 2
    # answer_ids = inputs.input_ids[0][start_idx: end_idx + 1]
    # answer_tokens = tokenizer.convert_ids_to_tokens(answer_ids)
    # answer = tokenizer.convert_tokens_to_string(answer_tokens)
    # # SQuAD 2.0 models point at [CLS] when the context holds no answer.
    # if answer != tokenizer.cls_token:
    #     return {
    #         "answer": answer,
    #         "score": confidence_score
    #     }
    # else:
    #     return {
    #         "answer": "No answer found.",
    #         "score": confidence_score
    #     }
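
# Example call (illustrative; the exact score depends on the model):
#   predict_answer("The Eiffel Tower is in Paris.", "Where is the Eiffel Tower?")
#   -> {"answer": "Paris", "score": 0.99...}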

# Define the request model
class QnARequest(BaseModel):
    context: str
    question: str


# Define the response model
class QnAResponse(BaseModel):
    answer: str
    confidence: float
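
# Example JSON bodies for these models (illustrative):
#   request:  {"context": "...", "question": "..."}
#   response: {"answer": "...", "confidence": 0.87}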

# NOTE: the route decorator is missing in the source; "/qna" is an assumed
# path, so adjust it to whatever the client expects.
@app.post("/qna", response_model=QnAResponse)
async def extractive_qna(request: QnARequest):
    context = request.context
    question = request.question
    # print(context, question)
    if not context or not question:
        raise HTTPException(status_code=400, detail="Context and question cannot be empty.")
    try:
        result = predict_answer(context, question)
        print(result)
        return QnAResponse(answer=result["answer"], confidence=result["score"])
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing QnA: {str(e)}")
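
# Minimal sketch for running locally (assumes uvicorn is installed; the host
# and port are arbitrary choices, not part of the original file):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request against the assumed route:
#   curl -X POST http://localhost:7860/qna \
#     -H "Content-Type: application/json" \
#     -d '{"context": "Paris is the capital of France.", "question": "What is the capital of France?"}'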