Update gen_api_answer.py
gen_api_answer.py  (+27 -0)  CHANGED
@@ -1,6 +1,7 @@
 from openai import OpenAI
 import anthropic
 from together import Together
+import cohere
 import json
 import re
 import os
@@ -11,6 +12,7 @@ anthropic_client = anthropic.Anthropic()
 openai_client = OpenAI()
 together_client = Together()
 hf_api_key = os.getenv("HF_API_KEY")
+cohere_client = cohere.ClientV2(os.getenv("CO_API_KEY"))
 huggingface_client = OpenAI(
     base_url="https://otb7jglxy6r37af6.us-east-1.aws.endpoints.huggingface.cloud/v1/",
     api_key=hf_api_key
@@ -93,6 +95,27 @@ def get_hf_response(model_name, prompt, max_tokens=500):
     except Exception as e:
         return f"Error with Hugging Face model {model_name}: {str(e)}"
 
+def get_cohere_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
+    """Get response from Cohere API"""
+    try:
+        response = cohere_client.chat(
+            model=model_name,
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": prompt}
+            ],
+            max_tokens=max_tokens,
+            temperature=temperature
+        )
+        # Extract the text from the content items
+        content_items = response.message.content
+        if isinstance(content_items, list):
+            # Get the text from the first content item
+            return content_items[0].text
+        return str(content_items)  # Fallback if it's not a list
+    except Exception as e:
+        return f"Error with Cohere model {model_name}: {str(e)}"
+
 def get_model_response(
     model_name,
     model_info,
@@ -127,6 +150,10 @@ def get_model_response(
         return get_hf_response(
             api_model, prompt, max_tokens
         )
+    elif organization == "Cohere":
+        return get_cohere_response(
+            api_model, prompt, system_prompt, max_tokens, temperature
+        )
     else:
         # All other organizations use Together API
         return get_together_response(
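Reviewer note: the added get_cohere_response mirrors the other provider helpers. Below is a minimal standalone sketch of the same V2 chat call, for trying the path outside the app; the model id and prompts are placeholders (assumptions, not pinned by this commit), and CO_API_KEY must be set in the environment.

# Standalone sketch of the new Cohere path (not part of the commit).
# The model id and prompts below are placeholder assumptions.
import os
import cohere

client = cohere.ClientV2(os.getenv("CO_API_KEY"))

response = client.chat(
    model="command-r-plus-08-2024",  # placeholder model id, an assumption
    messages=[
        {"role": "system", "content": "You are a strict judge."},
        {"role": "user", "content": "Rate this answer: 2+2=4"},
    ],
    max_tokens=100,
    temperature=0,
)

# V2 responses carry a list of content items; take the first text item,
# exactly as the new get_cohere_response does.
items = response.message.content
print(items[0].text if isinstance(items, list) else str(items))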
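One operational note: like hf_api_key, the new cohere_client is built at module import time from an environment variable. A pre-flight check along these lines (a sketch, not part of the commit) surfaces a missing key before a judge run starts:

# Pre-flight sketch: confirm both keys the module reads at import time.
import os

for var in ("HF_API_KEY", "CO_API_KEY"):
    if not os.getenv(var):
        raise SystemExit(f"{var} is not set; the corresponding client cannot authenticate")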