Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +11 -75
src/streamlit_app.py
CHANGED
|
@@ -29,9 +29,8 @@ else:
|
|
| 29 |
login(token=HF_TOKEN)
|
| 30 |
|
| 31 |
if GEMINI_API_KEY:
|
| 32 |
-
|
| 33 |
else:
|
| 34 |
-
gemini_client = None
|
| 35 |
st.warning("⚠️ Gemini API key missing. Gemini 2.5 Flash will not work.")
|
| 36 |
|
| 37 |
# ======================================================
|
|
@@ -63,9 +62,9 @@ with st.sidebar:
|
|
| 63 |
)
|
| 64 |
|
| 65 |
temperature = st.slider("Temperature", 0.0, 1.0, 0.3)
|
| 66 |
-
max_tokens = st.slider("Max Tokens", 128,
|
| 67 |
|
| 68 |
-
# Initialize
|
| 69 |
hf_cleaner_client = InferenceClient(model=CLEANER_MODEL, token=HF_TOKEN)
|
| 70 |
hf_analyst_client = None
|
| 71 |
if ANALYST_MODEL != "Gemini 2.5 Flash (Google)":
|
|
@@ -75,7 +74,6 @@ if ANALYST_MODEL != "Gemini 2.5 Flash (Google)":
|
|
| 75 |
# 🧩 SAFE GENERATION FUNCTION
|
| 76 |
# ======================================================
|
| 77 |
def safe_hf_generate(client, prompt, temperature=0.3, max_tokens=512):
|
| 78 |
-
"""HF text generation fallback to chat_completion"""
|
| 79 |
try:
|
| 80 |
resp = client.text_generation(
|
| 81 |
prompt,
|
|
@@ -113,9 +111,7 @@ def fallback_clean(df: pd.DataFrame) -> pd.DataFrame:
|
|
| 113 |
df.drop_duplicates(inplace=True)
|
| 114 |
return df
|
| 115 |
|
| 116 |
-
|
| 117 |
def ai_clean_dataset(df: pd.DataFrame) -> pd.DataFrame:
|
| 118 |
-
"""Clean dataset using AI. Full dataset sent for thorough cleaning."""
|
| 119 |
csv_text = df.to_csv(index=False)
|
| 120 |
prompt = f"""
|
| 121 |
You are a professional data cleaning assistant.
|
|
@@ -129,7 +125,6 @@ Return ONLY a valid CSV text (no markdown, no explanations).
|
|
| 129 |
Dataset:
|
| 130 |
{csv_text}
|
| 131 |
"""
|
| 132 |
-
|
| 133 |
try:
|
| 134 |
cleaned_str = safe_hf_generate(hf_cleaner_client, prompt, temperature=0.1, max_tokens=4096)
|
| 135 |
except Exception as e:
|
|
@@ -145,32 +140,16 @@ Dataset:
|
|
| 145 |
st.warning(f"⚠️ AI CSV parse failed: {e}")
|
| 146 |
return fallback_clean(df)
|
| 147 |
|
| 148 |
-
|
| 149 |
# ======================================================
|
| 150 |
# 🧩 DATA ANALYSIS
|
| 151 |
# ======================================================
|
| 152 |
-
def summarize_dataframe(df: pd.DataFrame) -> str:
|
| 153 |
-
lines = [f"Rows: {len(df)} | Columns: {len(df.columns)}", "Column summaries:"]
|
| 154 |
-
for col in df.columns[:10]:
|
| 155 |
-
non_null = int(df[col].notnull().sum())
|
| 156 |
-
if pd.api.types.is_numeric_dtype(df[col]):
|
| 157 |
-
desc = df[col].describe().to_dict()
|
| 158 |
-
mean = float(desc.get("mean", np.nan))
|
| 159 |
-
median = float(df[col].median()) if non_null > 0 else None
|
| 160 |
-
lines.append(f"- {col}: mean={mean:.3f}, median={median}, non_null={non_null}")
|
| 161 |
-
else:
|
| 162 |
-
top = df[col].value_counts().head(3).to_dict()
|
| 163 |
-
lines.append(f"- {col}: top_values={top}, non_null={non_null}")
|
| 164 |
-
return "\n".join(lines)
|
| 165 |
-
|
| 166 |
-
|
| 167 |
def query_analysis_model(df: pd.DataFrame, user_query: str, dataset_name: str) -> str:
|
| 168 |
csv_text = df.to_csv(index=False)
|
| 169 |
prompt = f"""
|
| 170 |
You are a professional data analyst.
|
| 171 |
Analyze the dataset '{dataset_name}' and answer the user's question.
|
| 172 |
|
| 173 |
-
--- FULL DATA
|
| 174 |
{csv_text}
|
| 175 |
|
| 176 |
--- USER QUESTION ---
|
|
@@ -184,19 +163,20 @@ Respond with:
|
|
| 184 |
"""
|
| 185 |
try:
|
| 186 |
if ANALYST_MODEL == "Gemini 2.5 Flash (Google)":
|
| 187 |
-
if
|
| 188 |
return "⚠️ Gemini API key missing."
|
| 189 |
-
response =
|
| 190 |
model="gemini-2.5-flash",
|
| 191 |
-
|
|
|
|
|
|
|
| 192 |
)
|
| 193 |
-
return getattr(response, "
|
| 194 |
else:
|
| 195 |
return safe_hf_generate(hf_analyst_client, prompt, temperature=temperature, max_tokens=max_tokens)
|
| 196 |
except Exception as e:
|
| 197 |
return f"⚠️ Analysis failed: {e}"
|
| 198 |
|
| 199 |
-
|
| 200 |
# ======================================================
|
| 201 |
# 🚀 MAIN APP LOGIC
|
| 202 |
# ======================================================
|
|
@@ -212,51 +192,7 @@ if uploaded:
|
|
| 212 |
st.dataframe(cleaned_df.head(), use_container_width=True)
|
| 213 |
|
| 214 |
with st.expander("📋 Cleaning Summary", expanded=False):
|
| 215 |
-
st.text(
|
| 216 |
-
|
| 217 |
-
with st.expander("📈 Quick Visualizations", expanded=True):
|
| 218 |
-
numeric_cols = cleaned_df.select_dtypes(include="number").columns.tolist()
|
| 219 |
-
categorical_cols = cleaned_df.select_dtypes(exclude="number").columns.tolist()
|
| 220 |
-
|
| 221 |
-
viz_type = st.selectbox(
|
| 222 |
-
"Visualization Type",
|
| 223 |
-
["Scatter Plot", "Histogram", "Box Plot", "Correlation Heatmap", "Categorical Count"]
|
| 224 |
-
)
|
| 225 |
-
|
| 226 |
-
if viz_type == "Scatter Plot" and len(numeric_cols) >= 2:
|
| 227 |
-
x = st.selectbox("X-axis", numeric_cols)
|
| 228 |
-
y = st.selectbox("Y-axis", numeric_cols, index=min(1, len(numeric_cols)-1))
|
| 229 |
-
color = st.selectbox("Color", ["None"] + categorical_cols)
|
| 230 |
-
fig = px.scatter(cleaned_df, x=x, y=y, color=None if color=="None" else color)
|
| 231 |
-
st.plotly_chart(fig, use_container_width=True)
|
| 232 |
-
|
| 233 |
-
elif viz_type == "Histogram" and numeric_cols:
|
| 234 |
-
col = st.selectbox("Column", numeric_cols)
|
| 235 |
-
fig = px.histogram(cleaned_df, x=col, nbins=30)
|
| 236 |
-
st.plotly_chart(fig, use_container_width=True)
|
| 237 |
-
|
| 238 |
-
elif viz_type == "Box Plot" and numeric_cols:
|
| 239 |
-
col = st.selectbox("Column", numeric_cols)
|
| 240 |
-
fig = px.box(cleaned_df, y=col)
|
| 241 |
-
st.plotly_chart(fig, use_container_width=True)
|
| 242 |
-
|
| 243 |
-
elif viz_type == "Correlation Heatmap" and len(numeric_cols) > 1:
|
| 244 |
-
corr = cleaned_df[numeric_cols].corr()
|
| 245 |
-
fig = ff.create_annotated_heatmap(
|
| 246 |
-
z=corr.values,
|
| 247 |
-
x=list(corr.columns),
|
| 248 |
-
y=list(corr.index),
|
| 249 |
-
annotation_text=corr.round(2).values,
|
| 250 |
-
showscale=True
|
| 251 |
-
)
|
| 252 |
-
st.plotly_chart(fig, use_container_width=True)
|
| 253 |
-
|
| 254 |
-
elif viz_type == "Categorical Count" and categorical_cols:
|
| 255 |
-
cat = st.selectbox("Category", categorical_cols)
|
| 256 |
-
fig = px.bar(cleaned_df[cat].value_counts().reset_index(), x="index", y=cat)
|
| 257 |
-
st.plotly_chart(fig, use_container_width=True)
|
| 258 |
-
else:
|
| 259 |
-
st.warning("⚠️ Not enough columns for this visualization type.")
|
| 260 |
|
| 261 |
st.subheader("💬 Ask AI About Your Data")
|
| 262 |
user_query = st.text_area("Enter your question:", placeholder="e.g. What factors influence sales the most?")
|
|
|
|
| 29 |
login(token=HF_TOKEN)
|
| 30 |
|
| 31 |
if GEMINI_API_KEY:
|
| 32 |
+
genai.api_key = GEMINI_API_KEY
|
| 33 |
else:
|
|
|
|
| 34 |
st.warning("⚠️ Gemini API key missing. Gemini 2.5 Flash will not work.")
|
| 35 |
|
| 36 |
# ======================================================
|
|
|
|
| 62 |
)
|
| 63 |
|
| 64 |
temperature = st.slider("Temperature", 0.0, 1.0, 0.3)
|
| 65 |
+
max_tokens = st.slider("Max Tokens", 128, 4096, 1024)
|
| 66 |
|
| 67 |
+
# Initialize HF clients
|
| 68 |
hf_cleaner_client = InferenceClient(model=CLEANER_MODEL, token=HF_TOKEN)
|
| 69 |
hf_analyst_client = None
|
| 70 |
if ANALYST_MODEL != "Gemini 2.5 Flash (Google)":
|
|
|
|
| 74 |
# 🧩 SAFE GENERATION FUNCTION
|
| 75 |
# ======================================================
|
| 76 |
def safe_hf_generate(client, prompt, temperature=0.3, max_tokens=512):
|
|
|
|
| 77 |
try:
|
| 78 |
resp = client.text_generation(
|
| 79 |
prompt,
|
|
|
|
| 111 |
df.drop_duplicates(inplace=True)
|
| 112 |
return df
|
| 113 |
|
|
|
|
| 114 |
def ai_clean_dataset(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
|
| 115 |
csv_text = df.to_csv(index=False)
|
| 116 |
prompt = f"""
|
| 117 |
You are a professional data cleaning assistant.
|
|
|
|
| 125 |
Dataset:
|
| 126 |
{csv_text}
|
| 127 |
"""
|
|
|
|
| 128 |
try:
|
| 129 |
cleaned_str = safe_hf_generate(hf_cleaner_client, prompt, temperature=0.1, max_tokens=4096)
|
| 130 |
except Exception as e:
|
|
|
|
| 140 |
st.warning(f"⚠️ AI CSV parse failed: {e}")
|
| 141 |
return fallback_clean(df)
|
| 142 |
|
|
|
|
| 143 |
# ======================================================
|
| 144 |
# 🧩 DATA ANALYSIS
|
| 145 |
# ======================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
def query_analysis_model(df: pd.DataFrame, user_query: str, dataset_name: str) -> str:
|
| 147 |
csv_text = df.to_csv(index=False)
|
| 148 |
prompt = f"""
|
| 149 |
You are a professional data analyst.
|
| 150 |
Analyze the dataset '{dataset_name}' and answer the user's question.
|
| 151 |
|
| 152 |
+
--- FULL DATA ---
|
| 153 |
{csv_text}
|
| 154 |
|
| 155 |
--- USER QUESTION ---
|
|
|
|
| 163 |
"""
|
| 164 |
try:
|
| 165 |
if ANALYST_MODEL == "Gemini 2.5 Flash (Google)":
|
| 166 |
+
if GEMINI_API_KEY is None:
|
| 167 |
return "⚠️ Gemini API key missing."
|
| 168 |
+
response = genai.generate_text(
|
| 169 |
model="gemini-2.5-flash",
|
| 170 |
+
prompt=prompt,
|
| 171 |
+
temperature=temperature,
|
| 172 |
+
max_output_tokens=max_tokens
|
| 173 |
)
|
| 174 |
+
return getattr(response, "candidates", [{"content": "No response from Gemini."}])[0]["content"]
|
| 175 |
else:
|
| 176 |
return safe_hf_generate(hf_analyst_client, prompt, temperature=temperature, max_tokens=max_tokens)
|
| 177 |
except Exception as e:
|
| 178 |
return f"⚠️ Analysis failed: {e}"
|
| 179 |
|
|
|
|
| 180 |
# ======================================================
|
| 181 |
# 🚀 MAIN APP LOGIC
|
| 182 |
# ======================================================
|
|
|
|
| 192 |
st.dataframe(cleaned_df.head(), use_container_width=True)
|
| 193 |
|
| 194 |
with st.expander("📋 Cleaning Summary", expanded=False):
|
| 195 |
+
st.text(f"Rows: {len(cleaned_df)} | Columns: {len(cleaned_df.columns)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
st.subheader("💬 Ask AI About Your Data")
|
| 198 |
user_query = st.text_area("Enter your question:", placeholder="e.g. What factors influence sales the most?")
|