| from helpers import get_credentials | |
| import requests | |
def hf_inference(prompt, model_id, temperature, max_new_tokens):
    """Run a single-turn chat completion via the Hugging Face router (Together backend).

    Args:
        prompt: User message text, sent as the sole chat turn.
        model_id: Model identifier understood by the router.
        temperature: Sampling temperature forwarded to the API.
        max_new_tokens: Generation length cap (forwarded as the API's "max_tokens").

    Returns:
        The first choice's message dict, e.g. {"role": "assistant", "content": "..."}.

    Raises:
        requests.HTTPError: If the API responds with a non-2xx status.
    """
    hf_token, _ = get_credentials.get_hf_credentials()
    api_url = "https://router.huggingface.co/together/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {hf_token}",
    }
    payload = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt,
                    },
                ],
            }
        ],
        "model": model_id,
        "temperature": temperature,
        # The OpenAI-compatible chat-completions schema names this parameter
        # "max_tokens"; "max_new_tokens" is not part of the schema and the cap
        # was being silently ignored.
        "max_tokens": max_new_tokens,
    }
    # Explicit timeout: requests never times out by default.
    response = requests.post(api_url, headers=headers, json=payload, timeout=120)
    # Fail loudly on API errors instead of raising KeyError on "choices" below.
    response.raise_for_status()
    return response.json()["choices"][0]["message"]
def replicate_inference(prompt, model_id, temperature, max_new_tokens):
    """Run a synchronous prediction against the Replicate models API.

    Args:
        prompt: Prompt text passed to the model's "input.prompt".
        model_id: Replicate model path, e.g. "owner/name", interpolated into the URL.
        temperature: Sampling temperature forwarded to the model input.
        max_new_tokens: Generation length cap (forwarded as "max_tokens").

    Returns:
        A dict of the form {"content": "<joined model output>"}, mirroring the
        message shape returned by the chat-style handlers.

    Raises:
        requests.HTTPError: If the API responds with a non-2xx status.
    """
    repl_token = get_credentials.get_replicate_credentials()
    api_url = f"https://api.replicate.com/v1/models/{model_id}/predictions"
    headers = {
        "Authorization": f"Bearer {repl_token}",
        "Content-Type": "application/json",
        # Ask Replicate to hold the connection until the prediction completes,
        # so "output" is already populated in this single response.
        "Prefer": "wait",
    }
    payload = {
        "input": {
            "prompt": prompt,
            "temperature": temperature,
            "max_tokens": max_new_tokens,
        }
    }
    # Explicit timeout: requests never times out by default.
    response = requests.post(api_url, headers=headers, json=payload, timeout=120)
    # Fail loudly on API errors instead of raising KeyError on "output" below.
    response.raise_for_status()
    # Replicate returns text output as a list of string chunks; join them.
    return {
        "content": "".join(response.json()["output"])
    }
# Dispatch table: provider key -> inference callable. Every handler shares the
# (prompt, model_id, temperature, max_new_tokens) signature, so callers can
# select a backend by name without branching.
INFERENCE_HANDLER = {
    "huggingface": hf_inference,
    "replicate": replicate_inference,
}