import json

import requests
import openai
from utils.config import config


class LLMClient:
    """Thin abstraction over a local Ollama server and a Hugging Face inference endpoint."""

    def __init__(self, provider="ollama", model_name=None):
        self.provider = provider
        self.model_name = model_name or config.local_model_name
        # Set up an OpenAI-compatible client for the Hugging Face endpoint
        self.hf_client = openai.OpenAI(
            base_url=config.hf_api_url,
            api_key=config.hf_token,
        )

    def generate(self, prompt, max_tokens=8192, stream=True):
        if self.provider == "ollama":
            return self._generate_ollama(prompt, max_tokens, stream)
        elif self.provider == "huggingface":
            return self._generate_hf(prompt, max_tokens, stream)
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    def _generate_ollama(self, prompt, max_tokens, stream):
        url = f"{config.ollama_host}/api/generate"
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": stream,
            # Ollama takes the generation limit via options.num_predict
            "options": {"num_predict": max_tokens},
        }
        try:
            response = requests.post(url, json=payload, stream=stream)
            if response.status_code != 200:
                raise Exception(f"Ollama API error: {response.text}")
            if stream:
                # Ollama streams newline-delimited JSON objects; yield the text piece of each line
                return (
                    json.loads(line).get("response", "")
                    for line in response.iter_lines()
                    if line
                )
            return response.json()["response"]
        except Exception as e:
            raise Exception(f"Ollama request failed: {e}") from e

    def _generate_hf(self, prompt, max_tokens, stream):
        try:
            response = self.hf_client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                stream=stream,
            )
            if stream:
                # Each streamed chunk carries a delta; empty deltas become ""
                return (chunk.choices[0].delta.content or "" for chunk in response)
            # Chat completions expose the text as message.content, not .text
            return response.choices[0].message.content
        except Exception as e:
            raise Exception(f"Hugging Face API error: {e}") from e