import json

import requests
import openai

from utils.config import config


class LLMClient:
    """Minimal client that routes text generation to Ollama or Hugging Face."""

    def __init__(self, provider="ollama", model_name=None):
        self.provider = provider
        self.model_name = model_name or config.local_model_name

        if self.provider == "huggingface":
            # Build the OpenAI-compatible client only when the Hugging Face
            # provider is selected, so Ollama-only use does not require
            # Hugging Face credentials in the config.
            self.hf_client = openai.OpenAI(
                base_url=config.hf_api_url,
                api_key=config.hf_token,
            )

    def generate(self, prompt, max_tokens=8192, stream=True):
        if self.provider == "ollama":
            return self._generate_ollama(prompt, max_tokens, stream)
        elif self.provider == "huggingface":
            return self._generate_hf(prompt, max_tokens, stream)
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    def _generate_ollama(self, prompt, max_tokens, stream):
        url = f"{config.ollama_host}/api/generate"
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": stream,
            # Ollama reads the generation-length cap from options.num_predict.
            "options": {"num_predict": max_tokens},
        }

        try:
            # No context manager here: when streaming, the connection must stay
            # open while the caller consumes the generator returned below.
            response = requests.post(url, json=payload, stream=stream)
            if response.status_code != 200:
                raise Exception(f"Ollama API error: {response.text}")

            if stream:
                # Ollama streams newline-delimited JSON objects; yield only the
                # generated text from each one.
                return (
                    json.loads(line)["response"]
                    for line in response.iter_lines()
                    if line
                )
            else:
                return response.json()["response"]
        except Exception as e:
            raise Exception(f"Ollama request failed: {e}")

    def _generate_hf(self, prompt, max_tokens, stream):
        try:
            response = self.hf_client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                stream=stream,
            )
            if stream:
                # Streaming chunks carry incremental text in delta.content.
                return (chunk.choices[0].delta.content or "" for chunk in response)
            else:
                # Chat completions expose the text via message.content, not .text.
                return response.choices[0].message.content
        except Exception as e:
            raise Exception(f"Hugging Face API error: {e}")
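

# Usage sketch (illustrative only, not part of the class): assumes a local
# Ollama server is reachable at config.ollama_host and that
# config.local_model_name names a pulled model; the "huggingface" provider
# additionally assumes config.hf_api_url and config.hf_token are set.
if __name__ == "__main__":
    client = LLMClient(provider="ollama")

    # Non-streaming call: returns the full completion as a single string.
    print(client.generate("Say hello in one sentence.", stream=False))

    # Streaming call: returns a generator of text fragments.
    for fragment in client.generate("Count to five.", stream=True):
        print(fragment, end="", flush=True)
    print()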