helloperson123 committed on
Commit
e08d99f
·
verified ·
1 Parent(s): 437a17e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -36
app.py CHANGED
@@ -1,44 +1,28 @@
1
- from transformers import pipeline
2
- from flask import Flask, request, render_template_string
 
3
 
4
  app = Flask(__name__)
5
 
6
- pipe = pipeline("text-generation", model="sshleifer/tiny-gpt2")
 
 
 
7
 
8
- HTML_PAGE = """
9
- <!DOCTYPE html>
10
- <html>
11
- <head>
12
- <title>TinyGPT2 Chat</title>
13
- <style>
14
- body { font-family: sans-serif; margin: 40px; }
15
- textarea { width: 100%; height: 100px; }
16
- button { margin-top: 10px; padding: 8px 16px; }
17
- .output { margin-top: 20px; white-space: pre-wrap; }
18
- </style>
19
- </head>
20
- <body>
21
- <h1>🤖 TinyGPT2 Chat</h1>
22
- <form method="POST">
23
- <textarea name="prompt" placeholder="Type your message here...">{{prompt}}</textarea><br>
24
- <button type="submit">Generate</button>
25
- </form>
26
- {% if output %}
27
- <div class="output"><strong>AI:</strong> {{output}}</div>
28
- {% endif %}
29
- </body>
30
- </html>
31
- """
32
 
33
- @app.route("/", methods=["GET", "POST"])
34
- def chat():
35
- output = ""
36
- prompt = ""
37
- if request.method == "POST":
38
- prompt = request.form["prompt"]
39
- result = pipe(prompt, max_length=100)[0]["generated_text"]
40
- output = result
41
- return render_template_string(HTML_PAGE, output=output, prompt=prompt)
42
 
43
  if __name__ == "__main__":
44
  app.run(host="0.0.0.0", port=7860)
 
1
# Third-party dependencies: torch for inference, Flask for the HTTP layer,
# and transformers for the model/tokenizer pair.
import torch
from flask import Flask, jsonify, request
from transformers import AutoModelForCausalLM, AutoTokenizer

app = Flask(__name__)

# Checkpoint served by this app; swap in any causal-LM checkpoint
# (e.g. a Tiny LLaMA variant) by changing this one name.
MODEL_NAME = "openai-community/gpt2"
# Loaded once at import time so every request reuses the same weights;
# device_map="auto" lets transformers place the model on the best device.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
11
 
12
@app.route("/api/ask", methods=["POST"])
def ask():
    """Generate a text completion for the JSON field ``prompt``.

    Expects a JSON body ``{"prompt": "..."}`` and returns
    ``{"reply": "<prompt + generated continuation>"}``.
    Responds 400 when the body is missing, is not valid JSON, or the
    prompt is empty.
    """
    # silent=True: a missing/invalid JSON body yields None instead of
    # aborting; without it, `data.get` below would raise AttributeError
    # and the client would see an opaque 500.
    data = request.get_json(silent=True) or {}
    prompt = data.get("prompt", "")
    if not prompt:
        return jsonify({"error": "missing 'prompt'"}), 400

    inputs = tokenizer(prompt, return_tensors="pt")
    # device_map="auto" may place the model off-CPU; move the input
    # tensors to the model's device to avoid a device-mismatch error.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
    # Inference only — skip autograd bookkeeping during generation.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=50)
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return jsonify({"reply": text})
22
+
23
@app.route("/")
def home():
    """Landing/health endpoint describing how to call the API."""
    usage = "✅ Model API running! POST JSON to /api/ask with {'prompt': 'your text'}"
    return usage
26
 
27
# Bind to all interfaces on port 7860 (the conventional Hugging Face
# Spaces port) when run directly as a script.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)