Sofia Santos committed
Commit b31e7ff · 1 Parent(s): f02ee48

feat: adds langchain-openai
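Adds an Azure OpenAI provider to the Gradio chat agent: a create_azure_llm factory built on langchain_openai.AzureChatOpenAI, a gr_connect_to_azure handler, Azure endpoint/token/API-version textboxes, and a model-ID textbox that mirrors the dropdown selection. The Hugging Face path also gains temperature/max_tokens plumbing. A minimal sketch of how the new factory is meant to be called (the endpoint and key values are placeholders, not part of the commit):

    from tdagent.grchat import create_azure_llm

    llm, error = create_azure_llm(
        model_id="gpt-4.5-preview",                        # deployment name from MODEL_OPTIONS
        api_version="2024-12-01-preview",                  # default offered by the new textbox
        endpoint="https://<resource>.openai.azure.com/",   # placeholder endpoint
        token_id="<azure-api-key>",                        # placeholder key
        temperature=0.8,
        max_tokens=512,
    )
    if llm is None:
        print(f"Connection failed: {error}")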

Files changed (1)
  1. tdagent/grchat.py +183 -24
tdagent/grchat.py CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import os
 from collections import OrderedDict
 from collections.abc import Mapping, Sequence
 from types import MappingProxyType
@@ -14,6 +15,7 @@ from langchain_aws import ChatBedrock
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 from langchain_mcp_adapters.client import MultiServerMCPClient
+from langchain_openai import AzureChatOpenAI
 from langgraph.prebuilt import create_react_agent
 from openai import OpenAI
 from openai.types.chat import ChatCompletion
@@ -77,6 +79,14 @@ MODEL_OPTIONS = OrderedDict( # Initialize with tuples to preserve options order
                 # ),
             },
         ),
+        (
+            "Azure OpenAI",
+            {
+                "GPT-4o": ("ggpt-4o-global-standard"),
+                "GPT-4o Mini": ("o4-mini"),
+                "GPT-4.5 Preview": ("gpt-4.5-preview"),
+            },
+        ),
     ),
 )
 
@@ -128,12 +138,15 @@ def create_bedrock_llm(
 def create_hf_llm(
     hf_model_id: str,
     huggingfacehub_api_token: str | None = None,
+    temperature: float = 0.8,
+    max_tokens: int = 512,
 ) -> tuple[ChatHuggingFace | None, str]:
     """Create a LangGraph Hugging Face agent."""
     try:
         llm = HuggingFaceEndpoint(
             model=hf_model_id,
-            temperature=0.8,
+            temperature=temperature,
+            max_new_tokens=max_tokens,
             task="text-generation",
             huggingfacehub_api_token=huggingfacehub_api_token,
         )
@@ -166,6 +179,34 @@ def create_openai_llm(
     return llm, ""
 
 
+def create_azure_llm(
+    model_id: str,
+    api_version: str,
+    endpoint: str,
+    token_id: str,
+    temperature: float = 0.8,
+    max_tokens: int = 512,
+) -> tuple[AzureChatOpenAI | None, str]:
+    """Create a LangGraph Azure OpenAI agent."""
+    try:
+        os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint
+        os.environ["AZURE_OPENAI_API_KEY"] = token_id
+        if "o4-mini" in model_id:
+            kwargs = {"max_completion_tokens": max_tokens}
+        else:
+            kwargs = {"max_tokens": max_tokens}
+        llm = AzureChatOpenAI(
+            azure_deployment=model_id,
+            api_key=token_id,
+            api_version=api_version,
+            temperature=temperature,
+            **kwargs,
+        )
+    except Exception as e:  # noqa: BLE001
+        return None, str(e)
+    return llm, ""
+
+
 #### UI functionality ####
 async def gr_connect_to_bedrock(  # noqa: PLR0913
     model_id: str,
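(Note: create_azure_llm publishes the endpoint and key via the AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY environment variables while also passing api_key explicitly; AzureChatOpenAI then picks the endpoint up from the environment, since no azure_endpoint argument is supplied. The "o4-mini" branch exists because o-series deployments reject max_tokens in favor of max_completion_tokens.)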
@@ -230,11 +271,18 @@ async def gr_connect_to_hf(
     model_id: str,
     hf_access_token_textbox: str | None,
     mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
+    temperature: float = 0.8,
+    max_tokens: int = 512,
 ) -> str:
     """Initialize Hugging Face agent."""
     global llm_agent  # noqa: PLW0603
 
-    llm, error = create_hf_llm(model_id, hf_access_token_textbox)
+    llm, error = create_hf_llm(
+        model_id,
+        hf_access_token_textbox,
+        temperature=temperature,
+        max_tokens=max_tokens,
+    )
 
     if llm is None:
         return f"❌ Connection failed: {error}"
@@ -260,6 +308,51 @@ async def gr_connect_to_hf(
     return "✅ Successfully connected to Hugging Face!"
 
 
+async def gr_connect_to_azure(
+    model_id: str,
+    azure_endpoint: str,
+    api_key: str,
+    api_version: str,
+    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
+    temperature: float = 0.8,
+    max_tokens: int = 512,
+) -> str:
+    """Initialize Azure OpenAI agent."""
+    global llm_agent  # noqa: PLW0603
+
+    llm, error = create_azure_llm(
+        model_id,
+        api_version=api_version,
+        endpoint=azure_endpoint,
+        token_id=api_key,
+        temperature=temperature,
+        max_tokens=max_tokens,
+    )
+
+    if llm is None:
+        return f"❌ Connection failed: {error}"
+    tools = []
+    if mcp_servers:
+        client = MultiServerMCPClient(
+            {
+                server.name.replace(" ", "-"): {
+                    "url": server.value,
+                    "transport": "sse",
+                }
+                for server in mcp_servers
+            },
+        )
+        tools = await client.get_tools()
+
+    llm_agent = create_react_agent(
+        model=llm,
+        tools=tools,
+        prompt=SYSTEM_MESSAGE,
+    )
+
+    return "✅ Successfully connected to Azure OpenAI!"
+
+
 async def gr_connect_to_nebius(
     model_id: str,
     nebius_access_token_textbox: str,
@@ -334,6 +427,9 @@ def toggle_model_fields(
     dict[str, Any],
     dict[str, Any],
     dict[str, Any],
+    dict[str, Any],
+    dict[str, Any],
+    dict[str, Any],
 ]:  # ignore: F821
     """Toggle visibility of model fields based on the selected provider."""
     # Update model choices based on the selected provider
@@ -351,6 +447,8 @@ def toggle_model_fields(
     # Visibility settings for fields specific to each provider
     is_aws = provider == "AWS Bedrock"
     is_hf = provider == "HuggingFace"
+    is_azure = provider == "Azure OpenAI"
+    # is_nebius = provider == "Nebius"
     return (
         model_pretty,
         gr.update(visible=is_aws, interactive=is_aws),
@@ -358,43 +456,62 @@ def toggle_model_fields(
         gr.update(visible=is_aws, interactive=is_aws),
         gr.update(visible=is_aws, interactive=is_aws),
         gr.update(visible=is_hf, interactive=is_hf),
+        gr.update(visible=is_azure, interactive=is_azure),
+        gr.update(visible=is_azure, interactive=is_azure),
+        gr.update(visible=is_azure, interactive=is_azure),
     )
 
 
 async def update_connection_status(  # noqa: PLR0913
     provider: str,
-    pretty_model: str,
+    model_id: str,
     mcp_list_state: Sequence[MutableCheckBoxGroupEntry] | None,
     aws_access_key_textbox: str,
     aws_secret_key_textbox: str,
     aws_session_token_textbox: str,
     aws_region_dropdown: str,
     hf_token: str,
+    azure_endpoint: str,
+    azure_api_token: str,
+    azure_api_version: str,
     temperature: float,
     max_tokens: int,
 ) -> str:
     """Update the connection status based on the selected provider and model."""
-    if not provider or not pretty_model:
+    if not provider or not model_id:
         return "❌ Please select a provider and model."
-
-    model_id = MODEL_OPTIONS.get(provider, {}).get(pretty_model)
     connection = "❌ Invalid provider"
-    if model_id:
-        if provider == "AWS Bedrock":
-            connection = await gr_connect_to_bedrock(
-                model_id,
-                aws_access_key_textbox,
-                aws_secret_key_textbox,
-                aws_session_token_textbox,
-                aws_region_dropdown,
-                mcp_list_state,
-                temperature,
-                max_tokens,
-            )
-        elif provider == "HuggingFace":
-            connection = await gr_connect_to_hf(model_id, hf_token, mcp_list_state)
-        elif provider == "Nebius":
-            connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
+    if provider == "AWS Bedrock":
+        connection = await gr_connect_to_bedrock(
+            model_id,
+            aws_access_key_textbox,
+            aws_secret_key_textbox,
+            aws_session_token_textbox,
+            aws_region_dropdown,
+            mcp_list_state,
+            temperature,
+            max_tokens,
+        )
+    elif provider == "HuggingFace":
+        connection = await gr_connect_to_hf(
+            model_id,
+            hf_token,
+            mcp_list_state,
+            temperature,
+            max_tokens,
+        )
+    elif provider == "Azure OpenAI":
+        connection = await gr_connect_to_azure(
+            model_id,
+            azure_endpoint,
+            azure_api_token,
+            azure_api_version,
+            mcp_list_state,
+            temperature,
+            max_tokens,
+        )
+    elif provider == "Nebius":
+        connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
 
     return connection
 
@@ -468,13 +585,39 @@ with (
                 placeholder="Enter your Hugging Face Access Token",
                 visible=False,
             )
+            azure_endpoint = gr.Textbox(
+                label="Azure OpenAI Endpoint",
+                type="text",
+                placeholder="Enter your Azure OpenAI Endpoint",
+                visible=False,
+            )
+            azure_api_token = gr.Textbox(
+                label="Azure Access Token",
+                type="password",
+                placeholder="Enter your Azure OpenAI Access Token",
+                visible=False,
+            )
+            azure_api_version = gr.Textbox(
+                label="Azure OpenAI API Version",
+                type="text",
+                placeholder="Enter your Azure OpenAI API Version",
+                value="2024-12-01-preview",
+                visible=False,
+            )
 
         with gr.Accordion("🧠 Model Configuration", open=True):
             model_display_id = gr.Dropdown(
-                label="Select Model ID",
+                label="Select Model from the list",
                 choices=[],
                 visible=False,
             )
+            model_id_textbox = gr.Textbox(
+                label="Model ID",
+                type="text",
+                placeholder="Enter the model ID",
+                visible=False,
+                interactive=True,
+            )
             model_provider.change(
                 toggle_model_fields,
                 inputs=[model_provider],
@@ -485,8 +628,21 @@ with (
                     aws_session_token_textbox,
                     aws_region_dropdown,
                     hf_token,
+                    azure_endpoint,
+                    azure_api_token,
+                    azure_api_version,
                 ],
             )
+            model_display_id.change(
+                lambda x, y: gr.update(
+                    value=MODEL_OPTIONS.get(y, {}).get(x),
+                    visible=True,
+                )
+                if x
+                else model_id_textbox.value,
+                inputs=[model_display_id, model_provider],
+                outputs=[model_id_textbox],
+            )
             # Initialize the temperature and max tokens based on model specifications
             temperature = gr.Slider(
                 label="Temperature",
@@ -510,13 +666,16 @@ with (
         update_connection_status,
         inputs=[
             model_provider,
-            model_display_id,
+            model_id_textbox,
             mcp_list.state,
             aws_access_key_textbox,
             aws_secret_key_textbox,
             aws_session_token_textbox,
             aws_region_dropdown,
             hf_token,
+            azure_endpoint,
+            azure_api_token,
+            azure_api_version,
             temperature,
             max_tokens,
         ],
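With this wiring, picking a pretty name in model_display_id writes the resolved deployment ID into model_id_textbox, and update_connection_status consumes that ID directly instead of resolving it itself. A rough standalone equivalent of the lookup the lambda performs (values from the new MODEL_OPTIONS entry):

    provider = "Azure OpenAI"
    pretty_name = "GPT-4.5 Preview"
    model_id = MODEL_OPTIONS.get(provider, {}).get(pretty_name)  # "gpt-4.5-preview"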