helloperson123 committed on
Commit
e4bb209
·
verified ·
1 Parent(s): f5e2a29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -10
app.py CHANGED
@@ -1,25 +1,36 @@
1
  from flask import Flask, request, jsonify
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
  app = Flask(__name__)
6
 
 
7
  model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- model = AutoModelForCausalLM.from_pretrained(model_name)
10
 
11
  @app.route("/api/ask", methods=["POST"])
12
  def ask():
13
  data = request.get_json()
14
  prompt = data.get("prompt", "")
15
- inputs = tokenizer(prompt, return_tensors="pt")
16
- outputs = model.generate(**inputs, max_new_tokens=100)
17
- reply = tokenizer.decode(outputs[0], skip_special_tokens=True)
18
- return jsonify({"reply": reply})
19
-
20
- @app.route("/")
21
- def home():
22
- return "✅ TinyLlama API is running!"
 
 
 
 
 
 
 
 
 
 
23
 
24
  if __name__ == "__main__":
25
  app.run(host="0.0.0.0", port=7860)
 
1
  from flask import Flask, request, jsonify
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
5
  app = Flask(__name__)
6
 
7
+ # Load TinyLlama model
8
  model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
11
 
12
  @app.route("/api/ask", methods=["POST"])
13
  def ask():
14
  data = request.get_json()
15
  prompt = data.get("prompt", "")
16
+
17
+ # Make it respond like a chatbot
18
+ chat_prompt = f"### Instruction:\nYou are a helpful, friendly chatbot named Acla. Reply conversationally.\n\n### Input:\n{prompt}\n\n### Response:"
19
+
20
+ inputs = tokenizer(chat_prompt, return_tensors="pt").to(model.device)
21
+ outputs = model.generate(
22
+ **inputs,
23
+ max_new_tokens=150,
24
+ temperature=0.7,
25
+ top_p=0.9,
26
+ do_sample=True
27
+ )
28
+
29
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
30
+ # Strip off the system prompt from output
31
+ response = response.split("### Response:")[-1].strip()
32
+
33
+ return jsonify({"reply": response})
34
 
35
  if __name__ == "__main__":
36
  app.run(host="0.0.0.0", port=7860)