import json

import requests
import openai

from utils.config import config

class LLMClient:
    """Thin wrapper around a local Ollama server and a Hugging Face
    OpenAI-compatible endpoint, selected by `provider`."""

    def __init__(self, provider="ollama", model_name=None):
        self.provider = provider
        self.model_name = model_name or config.local_model_name

        # OpenAI-compatible client pointed at the Hugging Face endpoint.
        # Created eagerly so either provider works without reconfiguring.
        self.hf_client = openai.OpenAI(
            base_url=config.hf_api_url,
            api_key=config.hf_token
        )

    def generate(self, prompt, max_tokens=8192, stream=True):
        """Generate a completion. Returns the full response string, or a
        generator of text chunks when `stream` is True."""
        if self.provider == "ollama":
            return self._generate_ollama(prompt, max_tokens, stream)
        elif self.provider == "huggingface":
            return self._generate_hf(prompt, max_tokens, stream)
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    def _generate_ollama(self, prompt, max_tokens, stream):
        url = f"{config.ollama_host}/api/generate"
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": stream,
            # Ollama reads the generation cap from options.num_predict;
            # without this, max_tokens would be silently ignored.
            "options": {"num_predict": max_tokens}
        }

        try:
            response = requests.post(url, json=payload, stream=stream)
        except requests.RequestException as e:
            raise Exception(f"Ollama request failed: {e}") from e

        if response.status_code != 200:
            raise Exception(f"Ollama API error: {response.text}")

        if stream:
            # Ollama streams newline-delimited JSON objects; yield the text
            # of each one. Note: returning this generator from inside a
            # `with requests.post(...)` block would close the connection
            # before the caller could iterate, so the response is kept open
            # here and released once the generator is exhausted.
            return (
                json.loads(line)["response"]
                for line in response.iter_lines()
                if line
            )
        else:
            return response.json()["response"]

    def _generate_hf(self, prompt, max_tokens, stream):
        try:
            response = self.hf_client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                stream=stream
            )
            if stream:
                # Each streamed chunk carries a delta; content may be None
                # (e.g. on the final chunk), so substitute an empty string.
                return (chunk.choices[0].delta.content or "" for chunk in response)
            else:
                # Chat completions return text under message.content, not
                # the legacy completions-style `.text` field.
                return response.choices[0].message.content
        except openai.OpenAIError as e:
            raise Exception(f"Hugging Face API error: {e}") from e