Spaces: Runtime error
| from fastapi import FastAPI | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig | |
| import torch | |
| from pydantic import BaseModel, Field | |
class RequestGenerate(BaseModel):
    """Request body for text generation.

    Apart from ``prompt``, every field is forwarded unchanged into the
    ``transformers.GenerationConfig`` argument of the same name.

    ``example=`` is kept so the schema works on both pydantic v1 and v2. On v2
    it is accepted but deprecated in favour of ``examples`` / ``json_schema_extra``.
    """

    prompt: str
    # NOTE: with top_k=1 only the single most likely token survives filtering,
    # so output is effectively greedy even though do_sample defaults to True.
    do_sample: bool = Field(default=True, example=True)
    # No trailing commas below: ``x: int = Field(...),`` would make the default
    # the 1-tuple ``(FieldInfo(...),)`` instead of the intended value.
    top_k: int = Field(default=1, example=1)
    temperature: float = Field(default=0.9, example=0.9)
    max_new_tokens: int = Field(default=500, example=500)
    repetition_penalty: float = Field(default=1.5, example=1.5)
app = FastAPI()

# Hugging Face Hub checkpoint served by this app. The larger
# "AI4Chem/ChemLLM-7B-Chat" checkpoint was used previously and can be swapped
# in by changing this string.
model_name_or_id = "AI4Chem/CHEMLLM-2b-1_5"

# Weights and tokenizer come from the same repo and are loaded once at import
# time. trust_remote_code=True lets transformers execute Python code shipped
# inside the repo, so only point this at a repo you trust.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_id,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_id,
    trust_remote_code=True,
)
def greet_json():
    """Return a fixed JSON greeting, ``{"Hello": "World!"}``.

    A new dict is built on every call.

    NOTE(review): no route decorator (e.g. ``@app.get("/")``) is visible on
    this function, so FastAPI will not expose it unless one is applied
    elsewhere. Confirm.
    """
    return dict(Hello="World!")
def generate(req: RequestGenerate):
    """Generate text for ``req.prompt`` with the module-level model and tokenizer.

    The request's generation knobs are copied by name into a
    ``GenerationConfig``. The tokenizer's EOS id is reused as the pad id.

    Args:
        req: Validated request carrying the prompt and the generation settings.

    Returns:
        ``{"text": <decoded string>}``. The first returned sequence is decoded
        in full with special tokens stripped. For a causal LM, ``generate``
        returns the prompt ids followed by the new ids, so the text presumably
        begins with the prompt. Confirm if callers expect only the completion.

    NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on
    this function, so FastAPI will not expose it unless one is applied
    elsewhere. Confirm.
    """
    encoded = tokenizer(req.prompt, return_tensors="pt")

    # Request fields that map 1:1 onto GenerationConfig arguments.
    forwarded_names = (
        "do_sample",
        "top_k",
        "temperature",
        "max_new_tokens",
        "repetition_penalty",
    )
    settings = {name: getattr(req, name) for name in forwarded_names}
    config = GenerationConfig(pad_token_id=tokenizer.eos_token_id, **settings)

    sequences = model.generate(**encoded, generation_config=config)
    text = tokenizer.decode(sequences[0], skip_special_tokens=True)
    return {"text": text}