Starberry15 committed on
Commit
f62d086
·
verified ·
1 Parent(s): 64b21b3

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +11 -75
src/streamlit_app.py CHANGED
@@ -29,9 +29,8 @@ else:
29
  login(token=HF_TOKEN)
30
 
31
  if GEMINI_API_KEY:
32
- gemini_client = genai.Client(api_key=GEMINI_API_KEY)
33
  else:
34
- gemini_client = None
35
  st.warning("⚠️ Gemini API key missing. Gemini 2.5 Flash will not work.")
36
 
37
  # ======================================================
@@ -63,9 +62,9 @@ with st.sidebar:
63
  )
64
 
65
  temperature = st.slider("Temperature", 0.0, 1.0, 0.3)
66
- max_tokens = st.slider("Max Tokens", 128, 2048, 512)
67
 
68
- # Initialize inference clients
69
  hf_cleaner_client = InferenceClient(model=CLEANER_MODEL, token=HF_TOKEN)
70
  hf_analyst_client = None
71
  if ANALYST_MODEL != "Gemini 2.5 Flash (Google)":
@@ -75,7 +74,6 @@ if ANALYST_MODEL != "Gemini 2.5 Flash (Google)":
75
  # 🧩 SAFE GENERATION FUNCTION
76
  # ======================================================
77
  def safe_hf_generate(client, prompt, temperature=0.3, max_tokens=512):
78
- """HF text generation fallback to chat_completion"""
79
  try:
80
  resp = client.text_generation(
81
  prompt,
@@ -113,9 +111,7 @@ def fallback_clean(df: pd.DataFrame) -> pd.DataFrame:
113
  df.drop_duplicates(inplace=True)
114
  return df
115
 
116
-
117
  def ai_clean_dataset(df: pd.DataFrame) -> pd.DataFrame:
118
- """Clean dataset using AI. Full dataset sent for thorough cleaning."""
119
  csv_text = df.to_csv(index=False)
120
  prompt = f"""
121
  You are a professional data cleaning assistant.
@@ -129,7 +125,6 @@ Return ONLY a valid CSV text (no markdown, no explanations).
129
  Dataset:
130
  {csv_text}
131
  """
132
-
133
  try:
134
  cleaned_str = safe_hf_generate(hf_cleaner_client, prompt, temperature=0.1, max_tokens=4096)
135
  except Exception as e:
@@ -145,32 +140,16 @@ Dataset:
145
  st.warning(f"⚠️ AI CSV parse failed: {e}")
146
  return fallback_clean(df)
147
 
148
-
149
  # ======================================================
150
  # 🧩 DATA ANALYSIS
151
  # ======================================================
152
- def summarize_dataframe(df: pd.DataFrame) -> str:
153
- lines = [f"Rows: {len(df)} | Columns: {len(df.columns)}", "Column summaries:"]
154
- for col in df.columns[:10]:
155
- non_null = int(df[col].notnull().sum())
156
- if pd.api.types.is_numeric_dtype(df[col]):
157
- desc = df[col].describe().to_dict()
158
- mean = float(desc.get("mean", np.nan))
159
- median = float(df[col].median()) if non_null > 0 else None
160
- lines.append(f"- {col}: mean={mean:.3f}, median={median}, non_null={non_null}")
161
- else:
162
- top = df[col].value_counts().head(3).to_dict()
163
- lines.append(f"- {col}: top_values={top}, non_null={non_null}")
164
- return "\n".join(lines)
165
-
166
-
167
  def query_analysis_model(df: pd.DataFrame, user_query: str, dataset_name: str) -> str:
168
  csv_text = df.to_csv(index=False)
169
  prompt = f"""
170
  You are a professional data analyst.
171
  Analyze the dataset '{dataset_name}' and answer the user's question.
172
 
173
- --- FULL DATA SAMPLE ---
174
  {csv_text}
175
 
176
  --- USER QUESTION ---
@@ -184,19 +163,20 @@ Respond with:
184
  """
185
  try:
186
  if ANALYST_MODEL == "Gemini 2.5 Flash (Google)":
187
- if gemini_client is None:
188
  return "⚠️ Gemini API key missing."
189
- response = gemini_client.models.generate_content(
190
  model="gemini-2.5-flash",
191
- contents=[prompt]
 
 
192
  )
193
- return getattr(response, "text", "No response from Gemini.")
194
  else:
195
  return safe_hf_generate(hf_analyst_client, prompt, temperature=temperature, max_tokens=max_tokens)
196
  except Exception as e:
197
  return f"⚠️ Analysis failed: {e}"
198
 
199
-
200
  # ======================================================
201
  # 🚀 MAIN APP LOGIC
202
  # ======================================================
@@ -212,51 +192,7 @@ if uploaded:
212
  st.dataframe(cleaned_df.head(), use_container_width=True)
213
 
214
  with st.expander("📋 Cleaning Summary", expanded=False):
215
- st.text(summarize_dataframe(cleaned_df))
216
-
217
- with st.expander("📈 Quick Visualizations", expanded=True):
218
- numeric_cols = cleaned_df.select_dtypes(include="number").columns.tolist()
219
- categorical_cols = cleaned_df.select_dtypes(exclude="number").columns.tolist()
220
-
221
- viz_type = st.selectbox(
222
- "Visualization Type",
223
- ["Scatter Plot", "Histogram", "Box Plot", "Correlation Heatmap", "Categorical Count"]
224
- )
225
-
226
- if viz_type == "Scatter Plot" and len(numeric_cols) >= 2:
227
- x = st.selectbox("X-axis", numeric_cols)
228
- y = st.selectbox("Y-axis", numeric_cols, index=min(1, len(numeric_cols)-1))
229
- color = st.selectbox("Color", ["None"] + categorical_cols)
230
- fig = px.scatter(cleaned_df, x=x, y=y, color=None if color=="None" else color)
231
- st.plotly_chart(fig, use_container_width=True)
232
-
233
- elif viz_type == "Histogram" and numeric_cols:
234
- col = st.selectbox("Column", numeric_cols)
235
- fig = px.histogram(cleaned_df, x=col, nbins=30)
236
- st.plotly_chart(fig, use_container_width=True)
237
-
238
- elif viz_type == "Box Plot" and numeric_cols:
239
- col = st.selectbox("Column", numeric_cols)
240
- fig = px.box(cleaned_df, y=col)
241
- st.plotly_chart(fig, use_container_width=True)
242
-
243
- elif viz_type == "Correlation Heatmap" and len(numeric_cols) > 1:
244
- corr = cleaned_df[numeric_cols].corr()
245
- fig = ff.create_annotated_heatmap(
246
- z=corr.values,
247
- x=list(corr.columns),
248
- y=list(corr.index),
249
- annotation_text=corr.round(2).values,
250
- showscale=True
251
- )
252
- st.plotly_chart(fig, use_container_width=True)
253
-
254
- elif viz_type == "Categorical Count" and categorical_cols:
255
- cat = st.selectbox("Category", categorical_cols)
256
- fig = px.bar(cleaned_df[cat].value_counts().reset_index(), x="index", y=cat)
257
- st.plotly_chart(fig, use_container_width=True)
258
- else:
259
- st.warning("⚠️ Not enough columns for this visualization type.")
260
 
261
  st.subheader("💬 Ask AI About Your Data")
262
  user_query = st.text_area("Enter your question:", placeholder="e.g. What factors influence sales the most?")
 
29
  login(token=HF_TOKEN)
30
 
31
  if GEMINI_API_KEY:
32
+ genai.api_key = GEMINI_API_KEY
33
  else:
 
34
  st.warning("⚠️ Gemini API key missing. Gemini 2.5 Flash will not work.")
35
 
36
  # ======================================================
 
62
  )
63
 
64
  temperature = st.slider("Temperature", 0.0, 1.0, 0.3)
65
+ max_tokens = st.slider("Max Tokens", 128, 4096, 1024)
66
 
67
+ # Initialize HF clients
68
  hf_cleaner_client = InferenceClient(model=CLEANER_MODEL, token=HF_TOKEN)
69
  hf_analyst_client = None
70
  if ANALYST_MODEL != "Gemini 2.5 Flash (Google)":
 
74
  # 🧩 SAFE GENERATION FUNCTION
75
  # ======================================================
76
  def safe_hf_generate(client, prompt, temperature=0.3, max_tokens=512):
 
77
  try:
78
  resp = client.text_generation(
79
  prompt,
 
111
  df.drop_duplicates(inplace=True)
112
  return df
113
 
 
114
  def ai_clean_dataset(df: pd.DataFrame) -> pd.DataFrame:
 
115
  csv_text = df.to_csv(index=False)
116
  prompt = f"""
117
  You are a professional data cleaning assistant.
 
125
  Dataset:
126
  {csv_text}
127
  """
 
128
  try:
129
  cleaned_str = safe_hf_generate(hf_cleaner_client, prompt, temperature=0.1, max_tokens=4096)
130
  except Exception as e:
 
140
  st.warning(f"⚠️ AI CSV parse failed: {e}")
141
  return fallback_clean(df)
142
 
 
143
  # ======================================================
144
  # 🧩 DATA ANALYSIS
145
  # ======================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  def query_analysis_model(df: pd.DataFrame, user_query: str, dataset_name: str) -> str:
147
  csv_text = df.to_csv(index=False)
148
  prompt = f"""
149
  You are a professional data analyst.
150
  Analyze the dataset '{dataset_name}' and answer the user's question.
151
 
152
+ --- FULL DATA ---
153
  {csv_text}
154
 
155
  --- USER QUESTION ---
 
163
  """
164
  try:
165
  if ANALYST_MODEL == "Gemini 2.5 Flash (Google)":
166
+ if GEMINI_API_KEY is None:
167
  return "⚠️ Gemini API key missing."
168
+ response = genai.generate_text(
169
  model="gemini-2.5-flash",
170
+ prompt=prompt,
171
+ temperature=temperature,
172
+ max_output_tokens=max_tokens
173
  )
174
+ return getattr(response, "candidates", [{"content": "No response from Gemini."}])[0]["content"]
175
  else:
176
  return safe_hf_generate(hf_analyst_client, prompt, temperature=temperature, max_tokens=max_tokens)
177
  except Exception as e:
178
  return f"⚠️ Analysis failed: {e}"
179
 
 
180
  # ======================================================
181
  # 🚀 MAIN APP LOGIC
182
  # ======================================================
 
192
  st.dataframe(cleaned_df.head(), use_container_width=True)
193
 
194
  with st.expander("📋 Cleaning Summary", expanded=False):
195
+ st.text(f"Rows: {len(cleaned_df)} | Columns: {len(cleaned_df.columns)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  st.subheader("💬 Ask AI About Your Data")
198
  user_query = st.text_area("Enter your question:", placeholder="e.g. What factors influence sales the most?")