import os
from datetime import datetime, timedelta
import argilla as rg
import gradio as gr
import pandas as pd
import plotly.colors as colors
import plotly.graph_objects as go
from cachetools import TTLCache, cached
# Shared Argilla client; endpoint and credentials come from the environment.
client = rg.Argilla(
    api_url=os.getenv("ARGILLA_API_URL"),
    api_key=os.getenv("ARGILLA_API_KEY"),
)

# Time-based memo store for `fetch_data`: up to 100 entries, each kept for
# 10 minutes. `timer=datetime.now` makes the TTL comparable to a timedelta.
cache = TTLCache(
    maxsize=100,
    ttl=timedelta(minutes=10),
    timer=datetime.now,
)
@cached(cache)
def fetch_data(dataset_name: str, workspace: str):
    """Look up the Argilla dataset ``dataset_name`` inside ``workspace``.

    The result is memoised in the module-level ``cache``, keyed on the call
    arguments, so repeated calls within the TTL reuse the same object.
    """
    dataset = client.datasets(dataset_name, workspace=workspace)
    return dataset
def get_progress(dataset) -> dict:
    """Summarise how many records of ``dataset`` are completed.

    ``dataset.records`` is consumed exactly once and materialised, then each
    record's ``status`` is compared against ``"completed"``.

    Returns:
        ``{"total": int, "annotated": int, "progress": float | int}`` where
        ``progress`` is the completed share as a percentage, or ``0`` when the
        dataset has no records.
    """
    all_records = [*dataset.records]
    total = len(all_records)
    completed = sum(1 for rec in all_records if rec.status == "completed")
    if total > 0:
        pct = completed / total * 100
    else:
        pct = 0
    return dict(total=total, annotated=completed, progress=pct)
def get_leaderboard(dataset) -> dict:
    """Count responses per annotator across every record of ``dataset``.

    Each response on each record adds one to the tally of the user who gave
    it. User ids are translated to usernames via ``client.users(id=...)``.

    Args:
        dataset: Dataset whose ``records`` yield objects with a ``responses``
            iterable; each response exposes ``user_id``.

    Returns:
        Dict mapping username -> number of responses, in first-seen order.
    """
    user_annotations = {}
    # Resolve every user id only once per call: looking the user up for each
    # individual response costs one API round-trip per response.
    usernames = {}
    for record in dataset.records:
        for response in record.responses:
            user_id = response.user_id
            if user_id not in usernames:
                retrieved_user = client.users(id=user_id)
                # If the lookup yields no user (assumed to be ``None`` for an
                # unknown/deleted account — confirm against the Argilla SDK),
                # fall back to the raw id instead of crashing.
                usernames[user_id] = (
                    retrieved_user.username
                    if retrieved_user is not None
                    else str(user_id)
                )
            username = usernames[user_id]
            user_annotations[username] = user_annotations.get(username, 0) + 1
    return user_annotations
def create_gauge_chart(progress):
    """Build a gauge figure visualising overall annotation progress.

    Args:
        progress: Mapping with ``"total"`` (record count), ``"annotated"``
            (completed record count) and ``"progress"`` (percentage, 0-100),
            i.e. the dict produced by ``get_progress``.

    Returns:
        A ``go.Figure`` with one gauge ``Indicator`` trace plus two text
        annotations summarising the counts.
    """
    pct = progress["progress"]
    total = progress["total"]
    annotated = progress["annotated"]

    fig = go.Figure(
        go.Indicator(
            mode="gauge+number+delta",
            value=pct,
            title={"text": "Dataset Annotation Progress", "font": {"size": 24}},
            # Delta is taken against 100%, i.e. it shows the gap to completion.
            delta={"reference": 100, "increasing": {"color": "RebeccaPurple"}},
            number={"font": {"size": 40}, "valueformat": ".1f", "suffix": "%"},
            gauge={
                "axis": {"range": [None, 100], "tickwidth": 1, "tickcolor": "darkblue"},
                "bar": {"color": "deepskyblue"},
                "bgcolor": "white",
                "borderwidth": 2,
                "bordercolor": "gray",
                # Colour the completed arc and grey out the remainder.
                "steps": [
                    {"range": [0, pct], "color": "royalblue"},
                    {"range": [pct, 100], "color": "lightgray"},
                ],
                "threshold": {
                    "line": {"color": "red", "width": 4},
                    "thickness": 0.75,
                    "value": 100,
                },
            },
        )
    )
    # Plotly text uses HTML-like markup: line breaks must be written as "<br>".
    fig.update_layout(
        annotations=[
            dict(
                text=(
                    f"Total records: {total}<br>"
                    f"Annotated: {annotated} ({pct:.1f}%)<br>"
                    f"Remaining: {total - annotated} ({100 - pct:.1f}%)"
                ),
                # No x/y given, so Plotly's default paper position is used.
                showarrow=False,
                xref="paper",
                yref="paper",
                font=dict(size=16),
            )
        ],
    )
    # Headline placed just above the plotting area (y > 1 in paper coords).
    fig.add_annotation(
        text=(
            f"Current Progress: {pct:.1f}% complete<br>"
            f"({annotated} out of {total} records annotated)"
        ),
        xref="paper",
        yref="paper",
        x=0.5,
        y=1.1,
        showarrow=False,
        font=dict(size=18),
        align="center",
    )
    return fig
def create_treemap(user_annotations, total_records):
    """Build a treemap of each user's contribution.

    Args:
        user_annotations: Mapping of username -> annotation count.
        total_records: Total record count; used as the root node's value and
            as the denominator of each user's percentage.

    Returns:
        A ``go.Figure`` containing one ``Treemap`` trace: a root node
        ``"Annotations"`` with one child tile per user, largest first.
    """
    sorted_users = sorted(user_annotations.items(), key=lambda x: x[1], reverse=True)
    color_scale = colors.qualitative.Pastel + colors.qualitative.Set3
    labels, parents, values, text, user_colors = [], [], [], [], []
    for i, (user, contribution) in enumerate(sorted_users):
        # Guard the division so an empty dataset cannot raise ZeroDivisionError.
        # NOTE(review): if counts are per response while total_records counts
        # records, these percentages can sum past 100% — confirm intended.
        percentage = (contribution / total_records) * 100 if total_records else 0.0
        labels.append(user)
        parents.append("Annotations")
        values.append(contribution)
        # "<br>" is Plotly's line-break markup inside trace text.
        text.append(f"{contribution} annotations<br>{percentage:.2f}%")
        # Cycle through the palette when there are more users than colours.
        user_colors.append(color_scale[i % len(color_scale)])
    # Root tile that every user tile is parented to.
    labels.append("Annotations")
    parents.append("")
    values.append(total_records)
    text.append(f"Total: {total_records} annotations")
    user_colors.append("#FFFFFF")
    fig = go.Figure(
        go.Treemap(
            labels=labels,
            parents=parents,
            values=values,
            text=text,
            textinfo="label+text",
            hoverinfo="label+text+value",
            marker=dict(colors=user_colors, line=dict(width=2)),
        )
    )
    fig.update_layout(
        title_text="User contributions to the total end dataset",
        height=500,
        margin=dict(l=10, r=10, t=50, b=10),
        paper_bgcolor="#F0F0F0",  # Light gray background
        plot_bgcolor="#F0F0F0",  # Light gray background
    )
    return fig
def update_dashboard():
    """Fetch the configured dataset and build every dashboard output.

    The dataset name and workspace are read from the ``DATASET_NAME`` and
    ``WORKSPACE`` environment variables.

    Returns:
        ``(gauge_chart, treemap, leaderboard_df)``, in the order the Gradio
        output components are wired up. The leaderboard is sorted by
        annotation count, descending, with a fresh 0..n-1 index.
    """
    name = os.getenv("DATASET_NAME")
    workspace = os.getenv("WORKSPACE")
    dataset = fetch_data(name, workspace)

    stats = get_progress(dataset)
    per_user = get_leaderboard(dataset)

    gauge = create_gauge_chart(stats)
    tree = create_treemap(per_user, stats["total"])

    table = pd.DataFrame(list(per_user.items()), columns=["User", "Annotations"])
    table = table.sort_values("Annotations", ascending=False, ignore_index=True)

    return gauge, tree, table
with gr.Blocks() as demo:
    gr.Markdown("# Argilla Dataset Dashboard")
    with gr.Row():
        gauge_output = gr.Plot(label="Overall Progress")
        treemap_output = gr.Plot(label="User contributions")
    with gr.Row():
        leaderboard_output = gr.Dataframe(
            label="Leaderboard", headers=["User", "Annotations"]
        )
    # Page load and the Refresh button both fill the same three components,
    # in the order update_dashboard returns its values.
    dashboard_outputs = [gauge_output, treemap_output, leaderboard_output]
    demo.load(update_dashboard, inputs=None, outputs=dashboard_outputs)
    refresh_button = gr.Button("Refresh")
    refresh_button.click(update_dashboard, inputs=None, outputs=dashboard_outputs)
if __name__ == "__main__":
    demo.launch()