sairampillai commited on
Commit
54db4e6
·
unverified ·
1 Parent(s): 70de74a

Fix review comments and formatting

Browse files
Files changed (3) hide show
  1. app.py +4 -45
  2. helpers/chat_helper.py +60 -0
  3. helpers/llm_helper.py +37 -48
app.py CHANGED
@@ -16,52 +16,11 @@ import ollama
16
  import requests
17
  import streamlit as st
18
  from dotenv import load_dotenv
19
- # Custom message classes to replace LangChain components
20
- class ChatMessage:
21
- def __init__(self, content: str, role: str):
22
- self.content = content
23
- self.role = role
24
- self.type = role # For compatibility with existing code
25
-
26
- class HumanMessage(ChatMessage):
27
- def __init__(self, content: str):
28
- super().__init__(content, "user")
29
-
30
- class AIMessage(ChatMessage):
31
- def __init__(self, content: str):
32
- super().__init__(content, "ai")
33
-
34
- class StreamlitChatMessageHistory:
35
- def __init__(self, key: str):
36
- self.key = key
37
- if key not in st.session_state:
38
- st.session_state[key] = []
39
-
40
- @property
41
- def messages(self):
42
- return st.session_state[self.key]
43
-
44
- def add_user_message(self, content: str):
45
- st.session_state[self.key].append(HumanMessage(content))
46
-
47
- def add_ai_message(self, content: str):
48
- st.session_state[self.key].append(AIMessage(content))
49
-
50
- class ChatPromptTemplate:
51
- def __init__(self, template: str):
52
- self.template = template
53
-
54
- @classmethod
55
- def from_template(cls, template: str):
56
- return cls(template)
57
-
58
- def format(self, **kwargs):
59
- return self.template.format(**kwargs)
60
 
61
  import global_config as gcfg
62
  import helpers.file_manager as filem
63
  from global_config import GlobalConfig
64
- from helpers import llm_helper, pptx_helper, text_helper
65
 
66
  load_dotenv()
67
 
@@ -333,8 +292,8 @@ def set_up_chat_ui():
333
  st.info(APP_TEXT['like_feedback'])
334
  st.chat_message('ai').write(random.choice(APP_TEXT['ai_greetings']))
335
 
336
- history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
337
- prompt_template = ChatPromptTemplate.from_template(
338
  _get_prompt_template(
339
  is_refinement=_is_it_refinement()
340
  )
@@ -653,7 +612,7 @@ def _get_user_messages() -> List[str]:
653
  """
654
 
655
  return [
656
- msg.content for msg in st.session_state[CHAT_MESSAGES] if isinstance(msg, HumanMessage)
657
  ]
658
 
659
 
 
16
  import requests
17
  import streamlit as st
18
  from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  import global_config as gcfg
21
  import helpers.file_manager as filem
22
  from global_config import GlobalConfig
23
+ from helpers import chat_helper, llm_helper, pptx_helper, text_helper
24
 
25
  load_dotenv()
26
 
 
292
  st.info(APP_TEXT['like_feedback'])
293
  st.chat_message('ai').write(random.choice(APP_TEXT['ai_greetings']))
294
 
295
+ history = chat_helper.StreamlitChatMessageHistory(key=CHAT_MESSAGES)
296
+ prompt_template = chat_helper.ChatPromptTemplate.from_template(
297
  _get_prompt_template(
298
  is_refinement=_is_it_refinement()
299
  )
 
612
  """
613
 
614
  return [
615
+ msg.content for msg in st.session_state[CHAT_MESSAGES] if isinstance(msg, chat_helper.HumanMessage)
616
  ]
617
 
618
 
helpers/chat_helper.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chat helper classes to replace LangChain components.
3
+ """
4
+ import streamlit as st
5
+
6
+
7
+ class ChatMessage:
8
+ """Base class for chat messages."""
9
+
10
+ def __init__(self, content: str, role: str):
11
+ self.content = content
12
+ self.role = role
13
+ self.type = role # For compatibility with existing code
14
+
15
+
16
+ class HumanMessage(ChatMessage):
17
+ """Message from human user."""
18
+
19
+ def __init__(self, content: str):
20
+ super().__init__(content, 'user')
21
+
22
+
23
+ class AIMessage(ChatMessage):
24
+ """Message from AI assistant."""
25
+
26
+ def __init__(self, content: str):
27
+ super().__init__(content, 'ai')
28
+
29
+
30
+ class StreamlitChatMessageHistory:
31
+ """Chat message history stored in Streamlit session state."""
32
+
33
+ def __init__(self, key: str):
34
+ self.key = key
35
+ if key not in st.session_state:
36
+ st.session_state[key] = []
37
+
38
+ @property
39
+ def messages(self):
40
+ return st.session_state[self.key]
41
+
42
+ def add_user_message(self, content: str):
43
+ st.session_state[self.key].append(HumanMessage(content))
44
+
45
+ def add_ai_message(self, content: str):
46
+ st.session_state[self.key].append(AIMessage(content))
47
+
48
+
49
+ class ChatPromptTemplate:
50
+ """Template for chat prompts."""
51
+
52
+ def __init__(self, template: str):
53
+ self.template = template
54
+
55
+ @classmethod
56
+ def from_template(cls, template: str):
57
+ return cls(template)
58
+
59
+ def format(self, **kwargs):
60
+ return self.template.format(**kwargs)
helpers/llm_helper.py CHANGED
@@ -73,17 +73,23 @@ def get_provider_model(provider_model: str, use_ollama: bool) -> Tuple[str, str]
73
 
74
  # Validate that the provider is in the valid providers list
75
  if inside_brackets not in GlobalConfig.VALID_PROVIDERS:
76
- logger.warning(f"Provider '{inside_brackets}' not in VALID_PROVIDERS: {GlobalConfig.VALID_PROVIDERS}")
 
 
 
77
  return '', ''
78
 
79
  # Validate that the model name is not empty
80
  if not outside_brackets.strip():
81
- logger.warning(f"Empty model name for provider '{inside_brackets}'")
82
  return '', ''
83
 
84
  return inside_brackets, outside_brackets
85
 
86
- logger.warning(f"Could not parse provider_model: '{provider_model}' (use_ollama={use_ollama})")
 
 
 
87
  return '', ''
88
 
89
 
@@ -135,38 +141,20 @@ def get_litellm_model_name(provider: str, model: str) -> str:
135
  Convert provider and model to LiteLLM model name format.
136
  """
137
  provider_prefix_map = {
138
- GlobalConfig.PROVIDER_HUGGING_FACE: "huggingface",
139
- GlobalConfig.PROVIDER_GOOGLE_GEMINI: "gemini",
140
- GlobalConfig.PROVIDER_AZURE_OPENAI: "azure",
141
- GlobalConfig.PROVIDER_OPENROUTER: "openrouter",
142
- GlobalConfig.PROVIDER_COHERE: "cohere",
143
- GlobalConfig.PROVIDER_TOGETHER_AI: "together_ai",
144
- GlobalConfig.PROVIDER_OLLAMA: "ollama",
145
  }
146
  prefix = provider_prefix_map.get(provider)
147
  if prefix:
148
- return f"{prefix}/{model}"
149
  return model
150
 
151
 
152
- def get_litellm_api_key(provider: str, api_key: str) -> str:
153
- """
154
- Get the appropriate API key for LiteLLM based on provider.
155
- """
156
- # All listed providers just return the api_key, so we can use a set for clarity
157
- providers_with_api_key = {
158
- GlobalConfig.PROVIDER_OPENROUTER,
159
- GlobalConfig.PROVIDER_COHERE,
160
- GlobalConfig.PROVIDER_TOGETHER_AI,
161
- GlobalConfig.PROVIDER_GOOGLE_GEMINI,
162
- GlobalConfig.PROVIDER_AZURE_OPENAI,
163
- GlobalConfig.PROVIDER_HUGGING_FACE,
164
- }
165
- if provider in providers_with_api_key:
166
- return api_key
167
- return api_key
168
-
169
-
170
  def stream_litellm_completion(
171
  provider: str,
172
  model: str,
@@ -200,34 +188,32 @@ def stream_litellm_completion(
200
  # This is consistent with Azure OpenAI's requirement to use deployment names
201
  if not azure_deployment_name:
202
  raise ValueError("Azure deployment name is required for Azure OpenAI provider")
203
- litellm_model = f"azure/{azure_deployment_name}"
204
  else:
205
  litellm_model = get_litellm_model_name(provider, model)
206
 
207
  # Prepare the request parameters
208
  request_params = {
209
- "model": litellm_model,
210
- "messages": messages,
211
- "max_tokens": max_tokens,
212
- "temperature": GlobalConfig.LLM_MODEL_TEMPERATURE,
213
- "stream": True,
214
  }
215
 
216
  # Set API key and any provider-specific params
217
  if provider != GlobalConfig.PROVIDER_OLLAMA:
218
- # For OpenRouter, set environment variable as per documentation
219
  if provider == GlobalConfig.PROVIDER_OPENROUTER:
220
- os.environ["OPENROUTER_API_KEY"] = api_key
221
- # Don't add API key to request_params for OpenRouter
222
  elif provider == GlobalConfig.PROVIDER_AZURE_OPENAI:
223
- # For Azure OpenAI, set environment variables as per documentation
224
- os.environ["AZURE_API_KEY"] = api_key
225
- os.environ["AZURE_API_BASE"] = azure_endpoint_url
226
- os.environ["AZURE_API_VERSION"] = azure_api_version
227
  else:
228
  # For other providers, pass API key as parameter
229
- api_key_to_use = get_litellm_api_key(provider, api_key)
230
- request_params["api_key"] = api_key_to_use
231
 
232
  logger.debug('Streaming completion via LiteLLM: %s', litellm_model)
233
 
@@ -245,7 +231,7 @@ def stream_litellm_completion(
245
  yield choice.message.content
246
 
247
  except Exception as e:
248
- logger.error(f"Error in LiteLLM completion: {e}")
249
  raise
250
 
251
 
@@ -277,7 +263,10 @@ def get_litellm_llm(
277
 
278
  # Create a simple wrapper object that mimics the LangChain streaming interface
279
  class LiteLLMWrapper:
280
- def __init__(self, provider, model, max_tokens, api_key, azure_endpoint_url, azure_deployment_name, azure_api_version):
 
 
 
281
  self.provider = provider
282
  self.model = model
283
  self.max_tokens = max_tokens
@@ -287,7 +276,7 @@ def get_litellm_llm(
287
  self.azure_api_version = azure_api_version
288
 
289
  def stream(self, prompt: str):
290
- messages = [{"role": "user", "content": prompt}]
291
  return stream_litellm_completion(
292
  provider=self.provider,
293
  model=self.model,
@@ -323,4 +312,4 @@ if __name__ == '__main__':
323
  ]
324
 
325
  for text in inputs:
326
- print(get_provider_model(text, use_ollama=False))
 
73
 
74
  # Validate that the provider is in the valid providers list
75
  if inside_brackets not in GlobalConfig.VALID_PROVIDERS:
76
+ logger.warning(
77
+ "Provider '%s' not in VALID_PROVIDERS: %s",
78
+ inside_brackets, GlobalConfig.VALID_PROVIDERS
79
+ )
80
  return '', ''
81
 
82
  # Validate that the model name is not empty
83
  if not outside_brackets.strip():
84
+ logger.warning("Empty model name for provider '%s'", inside_brackets)
85
  return '', ''
86
 
87
  return inside_brackets, outside_brackets
88
 
89
+ logger.warning(
90
+ "Could not parse provider_model: '%s' (use_ollama=%s)",
91
+ provider_model, use_ollama
92
+ )
93
  return '', ''
94
 
95
 
 
141
  Convert provider and model to LiteLLM model name format.
142
  """
143
  provider_prefix_map = {
144
+ GlobalConfig.PROVIDER_HUGGING_FACE: 'huggingface',
145
+ GlobalConfig.PROVIDER_GOOGLE_GEMINI: 'gemini',
146
+ GlobalConfig.PROVIDER_AZURE_OPENAI: 'azure',
147
+ GlobalConfig.PROVIDER_OPENROUTER: 'openrouter',
148
+ GlobalConfig.PROVIDER_COHERE: 'cohere',
149
+ GlobalConfig.PROVIDER_TOGETHER_AI: 'together_ai',
150
+ GlobalConfig.PROVIDER_OLLAMA: 'ollama',
151
  }
152
  prefix = provider_prefix_map.get(provider)
153
  if prefix:
154
+ return '%s/%s' % (prefix, model)
155
  return model
156
 
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  def stream_litellm_completion(
159
  provider: str,
160
  model: str,
 
188
  # This is consistent with Azure OpenAI's requirement to use deployment names
189
  if not azure_deployment_name:
190
  raise ValueError("Azure deployment name is required for Azure OpenAI provider")
191
+ litellm_model = 'azure/%s' % azure_deployment_name
192
  else:
193
  litellm_model = get_litellm_model_name(provider, model)
194
 
195
  # Prepare the request parameters
196
  request_params = {
197
+ 'model': litellm_model,
198
+ 'messages': messages,
199
+ 'max_tokens': max_tokens,
200
+ 'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
201
+ 'stream': True,
202
  }
203
 
204
  # Set API key and any provider-specific params
205
  if provider != GlobalConfig.PROVIDER_OLLAMA:
206
+ # For OpenRouter, pass API key as parameter
207
  if provider == GlobalConfig.PROVIDER_OPENROUTER:
208
+ request_params['api_key'] = api_key
 
209
  elif provider == GlobalConfig.PROVIDER_AZURE_OPENAI:
210
+ # For Azure OpenAI, pass credentials as parameters
211
+ request_params['api_key'] = api_key
212
+ request_params['azure_api_base'] = azure_endpoint_url
213
+ request_params['azure_api_version'] = azure_api_version
214
  else:
215
  # For other providers, pass API key as parameter
216
+ request_params['api_key'] = api_key
 
217
 
218
  logger.debug('Streaming completion via LiteLLM: %s', litellm_model)
219
 
 
231
  yield choice.message.content
232
 
233
  except Exception as e:
234
+ logger.error('Error in LiteLLM completion: %s', e)
235
  raise
236
 
237
 
 
263
 
264
  # Create a simple wrapper object that mimics the LangChain streaming interface
265
  class LiteLLMWrapper:
266
+ def __init__(
267
+ self, provider, model, max_tokens, api_key, azure_endpoint_url,
268
+ azure_deployment_name, azure_api_version
269
+ ):
270
  self.provider = provider
271
  self.model = model
272
  self.max_tokens = max_tokens
 
276
  self.azure_api_version = azure_api_version
277
 
278
  def stream(self, prompt: str):
279
+ messages = [{'role': 'user', 'content': prompt}]
280
  return stream_litellm_completion(
281
  provider=self.provider,
282
  model=self.model,
 
312
  ]
313
 
314
  for text in inputs:
315
+ print(get_provider_model(text, use_ollama=False))