Adedoyinjames committed
Commit 900a36d · verified · 1 Parent(s): 5252ee6

Update app.py

Files changed (1):
  1. app.py +45 -13
app.py CHANGED
@@ -1,8 +1,23 @@
+# Install required dependencies (pip names mapped to import names, e.g. protobuf -> google.protobuf)
+import subprocess
+import sys
+
+def install_packages():
+    packages = {"sentencepiece": "sentencepiece", "protobuf": "google.protobuf", "transformers": "transformers", "torch": "torch", "accelerate": "accelerate"}
+    for pip_name, import_name in packages.items():
+        try:
+            __import__(import_name)
+        except ImportError:
+            print(f"Installing {pip_name}...")
+            subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name])
+
+install_packages()
+
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 import uvicorn
 import torch
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import time
 from fastapi.middleware.cors import CORSMiddleware
@@ -24,8 +39,7 @@ app.add_middleware(
 
 class YAHBot:
     def __init__(self):
-        # ✅ Changed to load from your HF repo instead of direct model name
-        self.repo_id = "Adedoyinjames/brain-ai"  # Your HF repo
+        self.repo_id = "Adedoyinjames/brain-ai"
         self.tokenizer = None
         self.model = None
         self._load_model()
@@ -34,8 +48,16 @@ class YAHBot:
         """Load the model from your Hugging Face repo"""
         try:
             print(f"🔄 Loading AI model from {self.repo_id}...")
-            self.tokenizer = AutoTokenizer.from_pretrained(self.repo_id)
-            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.repo_id)
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.repo_id,
+                trust_remote_code=True  # Required for phi-3
+            )
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.repo_id,
+                trust_remote_code=True,  # Required for phi-3
+                torch_dtype=torch.float16,
+                device_map="auto"
+            )
             print("✅ AI model loaded successfully from HF repo!")
         except Exception as e:
             print(f"❌ Failed to load AI model from repo: {e}")
@@ -43,12 +65,12 @@ class YAHBot:
             self.tokenizer = None
 
     def generate_response(self, user_input):
-        """Generate response using AI model"""
+        """Generate response using causal language model"""
         if self.model and self.tokenizer:
             try:
-                prompt = f"Question: {user_input}\nAnswer: "
+                # Format prompt for phi-3 (causal LM)
+                prompt = f"<|user|>\n{user_input}<|end|>\n<|assistant|>\n"
 
-                # Tokenize
                 inputs = self.tokenizer(
                     prompt,
                     return_tensors="pt",
@@ -57,18 +79,27 @@ class YAHBot:
                     padding=True
                 )
 
-                # Generate response
+                # Move to same device as model
+                device = next(self.model.parameters()).device
+                inputs = {k: v.to(device) for k, v in inputs.items()}
+
                 with torch.no_grad():
                     outputs = self.model.generate(
-                        inputs.input_ids,
-                        max_length=150,
+                        inputs["input_ids"],  # inputs is a plain dict after the device move, so index by key
+                        max_new_tokens=150,
                         num_return_sequences=1,
                         temperature=0.7,
                         do_sample=True,
-                        pad_token_id=self.tokenizer.pad_token_id,
+                        pad_token_id=self.tokenizer.eos_token_id,  # Use EOS token for padding
+                        eos_token_id=self.tokenizer.eos_token_id,
                     )
 
                 response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+                # Remove the prompt from the response for cleaner output
+                if prompt in response:
+                    response = response.replace(prompt, "").strip()
+
                 return response
 
             except Exception as e:
@@ -100,7 +131,8 @@ async def root():
     return {
         "message": "YAH Tech AI API is running",
         "status": "active",
-        "model_repo": yah_bot.repo_id,  # Show which repo is being used
+        "model_repo": yah_bot.repo_id,
+        "model_type": "causal_lm",
         "endpoints": {
             "chat": "POST /api/chat",
             "health": "GET /api/health"
 
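On the hand-built <|user|>...<|end|>...<|assistant|> string: recent transformers releases can render the Phi-3 prompt from the tokenizer's own chat template, which keeps the formatting in sync with the checkpoint. A minimal sketch, assuming the repo's tokenizer ships a chat template (standard for Phi-3 checkpoints); the question is only illustrative:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Adedoyinjames/brain-ai", trust_remote_code=True)

    # Let the tokenizer build the prompt instead of hard-coding the template string
    messages = [{"role": "user", "content": "What services does YAH Tech offer?"}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,             # return the formatted string, not token ids
        add_generation_prompt=True  # end with the assistant header so the model answers next
    )
    print(prompt)  # for Phi-3 this matches the hand-built string above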