PeterPinetree committed on
Commit
bf08f52
·
1 Parent(s): 71d95c1

Update to correct Inference Providers chat completions API format

Browse files
Files changed (1) hide show
  1. app.py +42 -13
app.py CHANGED
@@ -10,8 +10,8 @@ from dotenv import load_dotenv
10
  load_dotenv()
11
 
12
  # Configuration
13
- API_BASE = "https://router.huggingface.co/hf-inference/models/"
14
- MODEL_ID = "openai-community/gpt2"
15
  HF_TOKEN = os.getenv('HF_NEXT_TOKEN_PREDICTOR_TOKEN', '')
16
 
17
  def show_token(token: str) -> str:
@@ -34,20 +34,20 @@ def predict_next_token(text: str, top_k: int = 10, hide_punctuation: bool = Fals
34
  start_time = time.time()
35
 
36
  try:
37
- # Call Hugging Face Serverless Inference API
38
- url = f"{API_BASE}{MODEL_ID}"
39
  headers = {
40
  'Authorization': f'Bearer {HF_TOKEN}',
41
  'Content-Type': 'application/json',
42
  }
43
  payload = {
 
44
  'inputs': text,
45
  'parameters': {
46
  'max_new_tokens': 1,
 
47
  'do_sample': False,
48
- 'return_full_text': False,
49
- 'details': True,
50
- 'top_k': min(top_k, 50) # API limitation
51
  }
52
  }
53
 
@@ -60,10 +60,10 @@ def predict_next_token(text: str, top_k: int = 10, hide_punctuation: bool = Fals
60
  print(f"Response text: {response.text}")
61
 
62
  if not response.ok:
63
- # Try GPT-2 Medium as fallback if the main model fails
64
- if MODEL_ID == "openai-community/gpt2":
65
- print(f"Main model failed, trying GPT-2 Medium fallback...")
66
- fallback_url = f"{API_BASE}openai-community/gpt2-medium"
67
  fallback_response = requests.post(fallback_url, headers=headers, json=payload, timeout=30)
68
  print(f"Fallback response status: {fallback_response.status_code}")
69
  if fallback_response.ok:
@@ -86,8 +86,17 @@ def predict_next_token(text: str, top_k: int = 10, hide_punctuation: bool = Fals
86
  result = response.json()
87
  prediction_time = int((time.time() - start_time) * 1000)
88
 
89
- # Parse response and create token list
90
- tokens_html = create_token_display(result, top_k, hide_punctuation)
 
 
 
 
 
 
 
 
 
91
 
92
  return tokens_html, f"Prediction time: {prediction_time}ms"
93
 
@@ -98,6 +107,26 @@ def predict_next_token(text: str, top_k: int = 10, hide_punctuation: bool = Fals
98
  except Exception as e:
99
  return f"❌ Error: {str(e)}", ""
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  def create_token_display(api_result: dict, top_k: int, hide_punctuation: bool) -> str:
102
  """Create HTML display for predicted tokens"""
103
 
 
10
  load_dotenv()
11
 
12
  # Configuration
13
+ API_BASE = "https://router.huggingface.co/v1/"
14
+ MODEL_ID = "Qwen/Qwen3-0.6B"
15
  HF_TOKEN = os.getenv('HF_NEXT_TOKEN_PREDICTOR_TOKEN', '')
16
 
17
  def show_token(token: str) -> str:
 
34
  start_time = time.time()
35
 
36
  try:
37
+ # Call Hugging Face Inference Providers API (Text Generation format)
38
+ url = f"{API_BASE}text-generation"
39
  headers = {
40
  'Authorization': f'Bearer {HF_TOKEN}',
41
  'Content-Type': 'application/json',
42
  }
43
  payload = {
44
+ 'model': MODEL_ID,
45
  'inputs': text,
46
  'parameters': {
47
  'max_new_tokens': 1,
48
+ 'temperature': 0.0,
49
  'do_sample': False,
50
+ 'return_full_text': False
 
 
51
  }
52
  }
53
 
 
60
  print(f"Response text: {response.text}")
61
 
62
  if not response.ok:
63
+ # Try a different Qwen model as fallback if the main model fails
64
+ if MODEL_ID == "Qwen/Qwen3-0.6B":
65
+ print(f"Main model failed, trying Qwen2.5-0.5B fallback...")
66
+ fallback_url = f"{API_BASE}Qwen/Qwen2.5-0.5B-Instruct"
67
  fallback_response = requests.post(fallback_url, headers=headers, json=payload, timeout=30)
68
  print(f"Fallback response status: {fallback_response.status_code}")
69
  if fallback_response.ok:
 
86
  result = response.json()
87
  prediction_time = int((time.time() - start_time) * 1000)
88
 
89
+ # Parse chat completion response - it returns a single message, not probabilities
90
+ try:
91
+ predicted_text = result['choices'][0]['message']['content'].strip()
92
+ # Extract just the next word (in case model returns more)
93
+ next_word = predicted_text.split()[0] if predicted_text else "?"
94
+
95
+ # Create simple display since we don't have probabilities
96
+ tokens_html = create_simple_token_display(next_word)
97
+
98
+ except (KeyError, IndexError) as e:
99
+ return f"❌ Error parsing response: {str(e)}", ""
100
 
101
  return tokens_html, f"Prediction time: {prediction_time}ms"
102
 
 
107
  except Exception as e:
108
  return f"❌ Error: {str(e)}", ""
109
 
110
+ def create_simple_token_display(predicted_word: str) -> str:
111
+ """Create HTML display for a single predicted token (chat completions format)"""
112
+
113
+ # Create HTML for single token
114
+ html = """
115
+ <div style="font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace; background: #0e162b; border: 1px solid #1c2945; border-radius: 14px; padding: 12px;">
116
+ """
117
+
118
+ token_display = show_token(predicted_word)
119
+
120
+ html += f"""
121
+ <div style="display: grid; grid-template-columns: 1fr auto; gap: 8px; align-items: center; padding: 8px 10px; margin: 4px 0; border-radius: 10px; background: #0f1930; border: 1px solid #22365e; cursor: pointer;">
122
+ <div style="color: #e6f1ff; font-size: 14px;">{token_display}</div>
123
+ <div style="color: #9ab0d0; font-size: 12px;">Predicted</div>
124
+ </div>
125
+ """
126
+
127
+ html += "</div>"
128
+ return html
129
+
130
  def create_token_display(api_result: dict, top_k: int, hide_punctuation: bool) -> str:
131
  """Create HTML display for predicted tokens"""
132