data_analysis / src /streamlit_app.py
Starberry15's picture
Update src/streamlit_app.py
8558e29 verified
raw
history blame
10.1 kB
# streamlit_data_analysis_app.py
# Streamlit Data Analysis App using Gemini 2.0 Flash (Free-tier)
# Features:
# - Upload CSV / Excel
# - Automatic cleaning & standardization
# - Preprocessing (imputation, encoding, scaling)
# - Quick visualizations
# - Dataset summary + preview
# - Insights powered by Gemini 2.0 Flash (Google AI)
import os
import tempfile

import google.generativeai as genai
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler
# ---------- CONFIGURATION ----------
st.set_page_config(page_title="Data Analysis App", layout="wide")

# Load the Gemini API key: Streamlit secrets take precedence, then the
# GEMINI_API_KEY environment variable. The secrets lookup can fail (e.g. no
# secrets configured or the key is absent), so any failure falls back to the
# environment. No .env file is read here; only the process environment is.
try:
    GEMINI_API_KEY = st.secrets["GEMINI_API_KEY"]
except Exception:
    GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
    st.success("✅ Gemini API key loaded successfully.")
else:
    st.warning("⚠️ No Gemini API key found. Set the GEMINI_API_KEY environment variable or add it to Streamlit secrets.")
# ---------- UTILITIES ----------
def read_file(uploaded_file):
    """Read an uploaded CSV/TXT or Excel file and return it as a DataFrame.

    The parser is chosen from the case-insensitive extension of
    ``uploaded_file.name``. Text files are decoded as UTF-8 first; if that
    fails they are re-read from the beginning as Latin-1.

    Args:
        uploaded_file: A binary, seekable file-like object with a ``name``
            attribute (e.g. a Streamlit upload or a file opened in ``"rb"``).

    Returns:
        pd.DataFrame: The parsed data.

    Raises:
        ValueError: If the extension is not .csv, .txt, .xls or .xlsx.
        Exception: Any parsing error. Every failure is shown with ``st.error``
            before being re-raised.
    """
    name = uploaded_file.name.lower()
    is_text = name.endswith(('.csv', '.txt'))
    try:
        if is_text:
            return pd.read_csv(uploaded_file, encoding="utf-8")
        elif name.endswith(('.xls', '.xlsx')):
            return pd.read_excel(uploaded_file)
        else:
            raise ValueError("Unsupported file type. Please upload CSV or Excel.")
    except UnicodeDecodeError as e:
        if not is_text:
            # A Latin-1 CSV retry makes no sense for a spreadsheet.
            st.error(f"❌ File reading failed: {e}")
            raise
        # The failed UTF-8 attempt has already consumed (part of) the stream;
        # rewind so the Latin-1 retry parses the file from its first byte
        # instead of seeing a truncated or empty buffer.
        uploaded_file.seek(0)
        try:
            return pd.read_csv(uploaded_file, encoding="latin1")
        except Exception as retry_error:
            st.error(f"❌ File reading failed: {retry_error}")
            raise
    except Exception as e:
        st.error(f"❌ File reading failed: {e}")
        raise
def clean_column_name(col: str) -> str:
    """Normalise one column label into a lowercase snake_case identifier.

    Runs of whitespace (spaces, tabs, newlines) become a single underscore,
    characters that are neither alphanumeric nor ``_`` are dropped, and
    repeated underscores are collapsed. Non-string labels go through ``str``.

    >>> clean_column_name("  Unit Price ($) ")
    'unit_price_'
    """
    # str.split() with no argument splits on, and discards, every kind of
    # whitespace, so no separate strip / newline / tab replacement is needed.
    joined = "_".join(str(col).lower().split())
    kept = []
    for ch in joined:
        if not (ch.isalnum() or ch == "_"):
            continue
        # Skip an underscore that would directly follow another one.
        if ch == "_" and kept and kept[-1] == "_":
            continue
        kept.append(ch)
    return "".join(kept)


def _dedupe_column_names(names):
    """Make labels unique by suffixing later repeats with ``_1``, ``_2``, ..."""
    seen = set()
    unique = []
    for name in names:
        candidate, k = name, 1
        while candidate in seen:
            candidate = f"{name}_{k}"
            k += 1
        seen.add(candidate)
        unique.append(candidate)
    return unique


def standardize_dataframe(df: pd.DataFrame, drop_all_nan_cols: bool = True) -> pd.DataFrame:
    """Return a cleaned copy of *df*; the input frame is not modified.

    Steps:
      * column labels are normalised with ``clean_column_name``; labels that
        collide after cleaning get ``_1``, ``_2``, ... suffixes;
      * leading/trailing whitespace is stripped from string cells;
      * all-NaN columns are dropped when ``drop_all_nan_cols`` is true;
      * text columns whose first 20 non-null values mostly (> 60 %) parse as
        dates are converted to datetimes (unparseable cells become NaT).
        Columns whose sample is mostly numeric are left alone.
    """
    df = df.copy()
    # Different raw labels can clean to the same name ("Name" / "name "); a
    # duplicated label makes df[c] return a DataFrame and breaks every
    # per-column step below, so make the names unique first.
    df.columns = _dedupe_column_names([clean_column_name(c) for c in df.columns])
    for c in df.select_dtypes(include=['object', 'string']).columns:
        df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)
    if drop_all_nan_cols:
        df = df.dropna(axis=1, how='all')
    for c in df.columns:
        dtype = df[c].dtype
        if not (dtype == object or isinstance(dtype, pd.StringDtype)):
            continue
        sample = df[c].dropna().astype(str).head(20)
        if sample.empty:
            continue
        # Plain numbers ("1", "2", ...) are accepted by the date parser and
        # would silently become dates, so skip mostly-numeric columns.
        if pd.to_numeric(sample, errors='coerce').notna().mean() > 0.6:
            continue
        parsed = pd.to_datetime(sample, errors='coerce')
        if parsed.notna().sum() / len(sample) > 0.6:
            df[c] = pd.to_datetime(df[c], errors='coerce')
    return df
def summarize_dataframe(df: pd.DataFrame, max_rows: int = 5):
    """Collect a JSON-friendly overview of *df*.

    Returns a dict with:
      * ``shape`` - ``df.shape``;
      * ``columns`` - one dict per column with ``name``, ``dtype``,
        ``n_missing`` and ``n_unique``, plus ``summary`` (``describe()`` for
        numeric columns, min/max strings for datetimes) or ``top_values``
        (five most frequent values, as strings) for everything else;
      * ``preview`` - the first ``max_rows`` rows as a list of records.
    """
    def describe_column(label):
        series = df[label]
        entry = {
            'name': label,
            'dtype': str(series.dtype),
            'n_missing': int(series.isna().sum()),
            'n_unique': int(series.nunique(dropna=True)),
        }
        if pd.api.types.is_numeric_dtype(series):
            entry['summary'] = series.describe().to_dict()
        elif pd.api.types.is_datetime64_any_dtype(series):
            entry['summary'] = {'min': str(series.min()), 'max': str(series.max())}
        else:
            entry['top_values'] = series.astype(str).value_counts().head(5).to_dict()
        return entry

    preview = df.head(max_rows).to_dict(orient='records')
    return {
        'shape': df.shape,
        'columns': [describe_column(label) for label in df.columns],
        'preview': preview,
    }
def _make_onehot_encoder():
    """Build a dense-output, unknown-tolerant OneHotEncoder on any scikit-learn version."""
    try:
        # scikit-learn >= 1.2 names the option ``sparse_output``; the old
        # ``sparse`` keyword was removed in 1.4.
        return OneHotEncoder(handle_unknown='ignore', sparse_output=False)
    except TypeError:
        # Older releases only accept ``sparse``.
        return OneHotEncoder(handle_unknown='ignore', sparse=False)


def prepare_preprocessing_pipeline(df: pd.DataFrame, impute_strategy_num='median', scale_numeric=True, encode_categorical='onehot'):
    """Build an unfitted ColumnTransformer for *df*.

    Numeric columns are imputed with ``impute_strategy_num`` and, when
    ``scale_numeric`` is true, standard-scaled (transformer name ``'num'``).
    Object/string/category/bool columns are imputed with the most frequent
    value and then one-hot encoded when ``encode_categorical == 'onehot'``,
    otherwise ordinal encoded (transformer name ``'cat'``). Columns of any
    other dtype (e.g. datetimes) are in no transformer and are therefore
    dropped by ColumnTransformer's default ``remainder='drop'``.

    Returns:
        tuple: ``(preprocessor, numeric_cols + cat_cols)``.
    """
    numeric_cols = list(df.select_dtypes(include=[np.number]).columns)
    cat_cols = list(df.select_dtypes(include=['object', 'string', 'category', 'bool']).columns)
    transformers = []
    if numeric_cols:
        num_pipe = [('imputer', SimpleImputer(strategy=impute_strategy_num))]
        if scale_numeric:
            num_pipe.append(('scaler', StandardScaler()))
        transformers.append(('num', Pipeline(num_pipe), numeric_cols))
    if cat_cols:
        if encode_categorical == 'onehot':
            # Step name 'onehot' is what apply_preprocessing looks for when
            # naming the expanded output columns.
            encoder_step = ('onehot', _make_onehot_encoder())
        else:
            encoder_step = ('ord', OrdinalEncoder())
        cat_pipe = Pipeline([
            ('imputer', SimpleImputer(strategy='most_frequent')),
            encoder_step,
        ])
        transformers.append(('cat', cat_pipe, cat_cols))
    return ColumnTransformer(transformers), numeric_cols + cat_cols
def apply_preprocessing(df: pd.DataFrame, preprocessor: "ColumnTransformer") -> pd.DataFrame:
    """Fit *preprocessor* on *df* and return the transformed data as a DataFrame.

    Output column names follow the layout built by
    ``prepare_preprocessing_pipeline``:
      * transformer ``'num'`` keeps its column names;
      * transformer ``'cat'`` yields ``<column>__<category>`` names when its
        pipeline has a ``'onehot'`` step, otherwise the original column names
        (e.g. ordinal encoding);
      * transformers with any other name (such as ``'remainder'``) contribute
        no names. This assumes they emit no columns, which holds for the
        default ``remainder='drop'``.

    The result has a fresh RangeIndex.
    """
    X = preprocessor.fit_transform(df)
    # A sparse result (e.g. an encoder configured for sparse output) cannot
    # be passed to pd.DataFrame directly, so densify it first.
    if hasattr(X, "toarray"):
        X = X.toarray()
    feature_names = []
    for name, trans, cols in preprocessor.transformers_:
        if name == 'num':
            feature_names += list(cols)
        elif name == 'cat':
            ohe = getattr(trans, 'named_steps', {}).get('onehot')
            if ohe is None:
                feature_names += list(cols)
            else:
                for col, cats in zip(cols, ohe.categories_):
                    feature_names += [f"{col}__{c}" for c in cats]
    return pd.DataFrame(X, columns=feature_names)
# ---------- LLM (Gemini only) ----------
def build_dataset_prompt(summary, user_question=None):
    """Render a dataset summary as a plain-text LLM prompt, one item per line.

    Args:
        summary: Dict shaped like the output of ``summarize_dataframe``; the
            ``shape``, ``columns`` and ``preview`` entries are used.
        user_question: Optional question. When falsy, a generic request for a
            summary, patterns and visualization ideas is appended instead.

    Returns:
        str: The newline-joined prompt.
    """
    shape = summary['shape']
    header = f"Dataset shape: {shape[0]} rows, {shape[1]} columns."
    column_lines = [
        f"- {col['name']} ({col['dtype']}) missing={col['n_missing']} unique={col['n_unique']}"
        for col in summary['columns']
    ]
    preview_lines = [str(record) for record in summary['preview']]
    if user_question:
        closing = f"User question: {user_question}"
    else:
        closing = "Please provide a summary, notable patterns, and suggestions for visualizations."
    return "\n".join([header, *column_lines, "Preview:", *preview_lines, closing])
def call_llm_gemini(prompt: str, model="gemini-2.0-flash"):
    """Send *prompt* to a Gemini model and return the reply text.

    Failures are returned as user-facing strings rather than raised, so the
    caller can display the result directly.

    Args:
        prompt: Text sent to the model.
        model: Model name passed to ``genai.GenerativeModel``.

    Returns:
        str: The response text, a warning when no API key is configured, or an
        error message describing why the call failed.
    """
    if not GEMINI_API_KEY:
        return "⚠️ Gemini API key not found."
    try:
        model_obj = genai.GenerativeModel(model)
        response = model_obj.generate_content(prompt)
        # Reading .text is kept inside the try: it may raise when the response
        # carries no text (e.g. a blocked reply) - presumably; confirm with the SDK docs.
        return response.text
    except Exception as e:
        return f"❌ Gemini call failed: {e}"
# ---------- STREAMLIT UI ----------
st.title("📊 Data Analysis & Cleaning App (Gemini-Powered)")
st.markdown("Upload CSV or Excel, clean and preprocess it, visualize data, and get insights powered by **Gemini 2.0 Flash**.")

# Sidebar options; their values are read further down when building the
# preprocessing pipeline and deciding whether to show the raw preview.
with st.sidebar:
    st.header("⚙️ Options")
    st.info("Using **Gemini 2.0 Flash (Google AI)** for insights.")
    impute_strategy_num = st.selectbox("Numeric imputation", ['mean', 'median', 'most_frequent'])
    encode_categorical = st.selectbox("Categorical encoding", ['onehot', 'ordinal'])
    scale_numeric = st.checkbox("Scale numeric features", True)
    show_raw_preview = st.checkbox("Show raw preview", True)

uploaded_file = st.file_uploader("📂 Upload CSV or Excel file", type=['csv', 'xls', 'xlsx', 'txt'])
if uploaded_file:
    # Spool the upload into a uniquely named file in the system temp directory
    # (writable on Hugging Face Spaces) and parse it from there. A generated
    # name never trusts the client-supplied filename as a path, cannot collide
    # with another session's upload, and the file is deleted when the context
    # manager exits. Only the extension is kept because read_file chooses the
    # parser from it.
    suffix = os.path.splitext(os.path.basename(uploaded_file.name))[1]
    with tempfile.NamedTemporaryFile(suffix=suffix) as tmp:
        tmp.write(uploaded_file.getbuffer())
        tmp.flush()
        tmp.seek(0)
        try:
            raw_df = read_file(tmp)
        except Exception:
            # read_file reports parse failures with st.error; end this run
            # cleanly instead of dumping a traceback into the page.
            raw_df = None
    if raw_df is None:
        st.stop()

    if show_raw_preview:
        st.subheader("Raw Data Preview")
        st.dataframe(raw_df.head())

    st.subheader("Data Cleaning & Standardization")
    cleaned_df = standardize_dataframe(raw_df)
    st.write(f"✅ Cleaned data shape: {cleaned_df.shape}")
    st.dataframe(cleaned_df.head())

    st.subheader("Summary")
    summary = summarize_dataframe(cleaned_df)
    st.write(f"Shape: {summary['shape']}")
    st.json(summary['columns'])

    st.subheader("Preprocessing")
    if st.button("Generate Preprocessing Pipeline"):
        try:
            preproc, _ = prepare_preprocessing_pipeline(cleaned_df, impute_strategy_num, scale_numeric, encode_categorical)
            processed_df = apply_preprocessing(cleaned_df, preproc)
        except Exception as e:
            # Report fitting/encoding failures in the page rather than crashing it.
            st.error(f"Preprocessing failed: {e}")
        else:
            st.success("Preprocessing complete!")
            st.dataframe(processed_df.head())
            st.download_button("⬇️ Download Processed CSV", processed_df.to_csv(index=False), "processed_data.csv")

    st.subheader("Visualizations")
    viz_col = st.selectbox("Select column", options=cleaned_df.columns)
    viz_type = st.selectbox("Visualization type", ['Histogram', 'Boxplot', 'Bar (categorical)', 'Scatter', 'Correlation heatmap'])
    if viz_type == 'Scatter':
        second_col = st.selectbox("Second column", options=[c for c in cleaned_df.columns if c != viz_col])
    if st.button("Show Visualization"):
        fig, ax = plt.subplots(figsize=(8, 5))
        try:
            if viz_type == 'Histogram':
                sns.histplot(cleaned_df[viz_col], kde=True, ax=ax)
            elif viz_type == 'Boxplot':
                sns.boxplot(x=cleaned_df[viz_col], ax=ax)
            elif viz_type == 'Bar (categorical)':
                counts = cleaned_df[viz_col].astype(str).value_counts().head(20)
                sns.barplot(x=counts.values, y=counts.index, ax=ax)
            elif viz_type == 'Scatter':
                sns.scatterplot(x=cleaned_df[viz_col], y=cleaned_df[second_col], ax=ax)
            elif viz_type == 'Correlation heatmap':
                corr = cleaned_df.select_dtypes(include=[np.number]).corr()
                sns.heatmap(corr, annot=True, cmap='coolwarm', ax=ax)
            st.pyplot(fig)
        except Exception as e:
            st.error(f"Visualization failed: {e}")
        finally:
            # pyplot keeps every figure alive until it is closed; this server
            # process creates one per click, so release it explicitly.
            plt.close(fig)

    st.subheader("🧠 Ask Gemini for Insights")
    user_q = st.text_area("Enter your question (optional):")
    if st.button("Get Insights"):
        with st.spinner("Generating insights via Gemini..."):
            prompt = build_dataset_prompt(summary, user_q or None)
            llm_resp = call_llm_gemini(prompt)
        st.write(llm_resp)
else:
    st.info("📥 Upload a file to begin.")