pradeei committed on
Commit
d3df84f
·
verified ·
1 Parent(s): 1c41095

Update app.py

Browse files

Added IBM endpoints

Files changed (1) hide show
  1. app.py +147 -77
app.py CHANGED
@@ -1,77 +1,147 @@
1
- # Import necessary libraries
2
-
3
- import streamlit as st
4
- import os
5
- from openai import OpenAI
6
- import json
7
-
8
- def clear_chat():
9
- st.session_state.messages = []
10
-
11
- st.title("Intel® AI for Enterprise Inference")
12
- st.header("LLM chatbot")
13
-
14
- with st.sidebar:
15
- api_key = st.session_state.api_key = st.secrets["openai_apikey"] #Enter openai_api key under "Secrets " in HF settings
16
- base_url = st.session_state.base_url = os.environ.get("base_url") #Enter base_url under "Variables" in HF settings
17
- client = OpenAI(api_key=api_key, base_url=base_url)
18
- models = client.models.list()
19
- model_names = sorted([model.id for model in models]) # Extract 'id' from each model object
20
- default_model_name = "meta-llama/Llama-3.3-70B-Instruct" # Replace with your desired default model name
21
-
22
- # Use st.session_state to persist the selected model
23
- if "selected_model" not in st.session_state:
24
- st.session_state.selected_model = default_model_name if default_model_name in model_names else model_names[0]
25
-
26
- # Create the selectbox without the `index` parameter
27
- modelname = st.selectbox(
28
- "Select an LLM model (running on Intel® Gaudi®). Hosted on Denvr Dataworks.",
29
- model_names,
30
- key="selected_model", # This ties the widget to st.session_state["selected_model"]
31
- )
32
- st.write(f"You selected: {modelname}")
33
- st.button("Start New Chat", on_click=clear_chat)
34
-
35
- st.markdown("---") # Add a horizontal line for separation
36
- st.markdown(
37
- """
38
- Check the latest models hosted on [Denvr Dataworks](https://www.denvrdata.com/intel), and get your own OpenAI-compatible API key.
39
-
40
- Come and chat with other AI developers on [Intel’s DevHub Discord server](https://discord.gg/kfJ3NKEw5t).
41
- """
42
- )
43
-
44
- try:
45
- if "messages" not in st.session_state:
46
- st.session_state.messages = []
47
-
48
- for message in st.session_state.messages:
49
- with st.chat_message(message["role"]):
50
- st.markdown(message["content"])
51
-
52
- if prompt := st.chat_input("What is up?"):
53
- st.session_state.messages.append({"role": "user", "content": prompt})
54
- with st.chat_message("user"):
55
- st.markdown(prompt)
56
-
57
- with st.chat_message("assistant"):
58
- try:
59
- stream = client.chat.completions.create(
60
- model=modelname,
61
- messages=[
62
- {"role": m["role"], "content": m["content"]}
63
- for m in st.session_state.messages
64
- ],
65
- max_tokens=4096,
66
- stream=True,
67
- )
68
- response = st.write_stream(stream)
69
- except Exception as e:
70
- st.error(f"An error occurred while generating the response: {e}")
71
- response = "An error occurred while generating the response."
72
-
73
- st.session_state.messages.append({"role": "assistant", "content": response})
74
- except KeyError as e:
75
- st.error(f"Key error: {e}")
76
- except Exception as e:
77
- st.error(f"An unexpected error occurred: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ from openai import OpenAI
4
+ import json
5
+
6
def clear_chat():
    """Wipe the stored conversation so the next rerun starts a fresh chat."""
    st.session_state["messages"] = []
8
+
9
def initialize_provider_settings(provider_choice):
    """Look up the connection settings for a given inference provider.

    Args:
        provider_choice: Display name of the provider ("Denvr Dataworks"
            or "IBM").

    Returns:
        A dict with "api_key_source", "base_url_source" and
        "fallback_model" entries, or an empty dict for an unknown
        provider name.
    """
    # Both entries are built eagerly (as in a single literal) so the
    # secrets/environment lookups happen on every call regardless of
    # which provider was picked.
    denvr_settings = {
        "api_key_source": st.secrets.get("openai_apikey", ""),
        "base_url_source": os.environ.get("base_url", ""),
        "fallback_model": "meta-llama/Llama-3.3-70B-Instruct",
    }
    ibm_settings = {
        "api_key_source": os.environ.get("ibm_openai_apikey", ""),
        "base_url_source": os.environ.get("ibm_base_url", ""),
        "fallback_model": None,
    }
    registry = {
        "Denvr Dataworks": denvr_settings,
        "IBM": ibm_settings,
    }
    return registry.get(provider_choice, {})
25
+
26
st.title("Intel® AI for Enterprise Inference")
st.header("LLM chatbot")

with st.sidebar:
    # Provider picker; the widget key ties the choice to session state so
    # it survives Streamlit reruns.
    provider_options = ["Denvr Dataworks", "IBM"]
    if "current_provider_choice" not in st.session_state:
        st.session_state.current_provider_choice = provider_options[0]

    provider_selection = st.selectbox(
        "Choose AI Provider:",
        provider_options,
        key="current_provider_choice",
    )

    # Resolve credentials/endpoint for the chosen provider; bail out early
    # when either piece of configuration is absent.
    provider_settings = initialize_provider_settings(provider_selection)
    key_missing = not provider_settings.get("api_key_source")
    url_missing = not provider_settings.get("base_url_source")
    if key_missing or url_missing:
        st.error(f"Configuration missing for {provider_selection}. Check environment variables.")
        st.stop()

    try:
        # Build the OpenAI-compatible client and list the models it serves.
        api_client = OpenAI(
            api_key=provider_settings["api_key_source"],
            base_url=provider_settings["base_url_source"],
        )
        model_ids = sorted(entry.id for entry in api_client.models.list())

        # Remember one model per provider; re-seed the stored choice when
        # the user switches providers (or on first visit).
        model_state_key = f"model_for_{provider_selection}"
        provider_changed = st.session_state.get("last_provider") != provider_selection
        if model_state_key not in st.session_state or provider_changed:
            default_model = provider_settings.get("fallback_model")
            if default_model and default_model in model_ids:
                st.session_state[model_state_key] = default_model
            elif model_ids:
                st.session_state[model_state_key] = model_ids[0]
            st.session_state.last_provider = provider_selection

        if not model_ids:
            st.error(f"No models found for {provider_selection}")
            st.stop()

        # Model picker, backed by the per-provider session-state key above.
        chosen_model = st.selectbox(
            f"Available models from {provider_selection}:",
            model_ids,
            key=model_state_key,
        )
        st.info(f"Active model: {chosen_model}")

    except Exception as connection_error:
        st.error(f"Connection failed for {provider_selection}: {connection_error}")
        st.stop()

    st.button("Reset Conversation", on_click=clear_chat)

    st.markdown("---")

    # Provider-specific blurb below the divider.
    if provider_selection == "Denvr Dataworks":
        st.markdown(
            """
            **Denvr Dataworks Integration**

            Visit [Denvr Dataworks](https://www.denvrdata.com/intel) for model information and API access.

            Join the community: [Intel's DevHub Discord](https://discord.gg/kfJ3NKEw5t)
            """
        )
    elif provider_selection == "IBM":
        st.markdown(
            """
            **IBM AI Services**

            Connected to IBM's AI infrastructure. Ensure your credentials are properly configured.
            """
        )
108
+
109
# Main chat interface
try:
    # Lazily create the conversation store on the first run.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the transcript accumulated so far.
    for entry in st.session_state.messages:
        with st.chat_message(entry["role"]):
            st.markdown(entry["content"])

    # Handle a new user turn, if any.
    prompt = st.chat_input("Enter your message...")
    if prompt:
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            try:
                # Stream the completion into the UI token by token; the
                # full history is sent so the model keeps context.
                stream = api_client.chat.completions.create(
                    model=chosen_model,
                    messages=[
                        {"role": entry["role"], "content": entry["content"]}
                        for entry in st.session_state.messages
                    ],
                    max_tokens=4096,
                    stream=True,
                )
                reply = st.write_stream(stream)
            except Exception as gen_exc:
                # On failure, surface the error and record a placeholder
                # so the transcript stays consistent.
                st.error(f"Response generation failed: {gen_exc}")
                reply = "Unable to generate response due to an error."

        st.session_state.messages.append({"role": "assistant", "content": reply})

except KeyError as key_exc:
    st.error(f"Configuration key error: {key_exc}")
except Exception as other_exc:
    st.error(f"Unexpected error occurred: {other_exc}")