Anonumous's picture
Update benchmark
20b4d4f
raw
history blame
8.77 kB
import gradio as gr
import json
from constants import INTRODUCTION_TEXT
from utils import (
init_repo,
load_data,
process_submit,
get_datasets_description,
get_metrics_html,
compute_wer_cer,
get_submit_html,
DATASETS,
)
from styles import LEADERBOARD_CSS
init_repo()
gr.set_static_paths(paths=["."])
with gr.Blocks(css=LEADERBOARD_CSS, theme=gr.themes.Soft()) as demo:
gr.HTML(
'<img src="/gradio_api/file=Logo.png" '
'style="display:block; margin:0 auto; width:34%; height:auto;">'
)
gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
with gr.Tabs():
with gr.Tab("🏅 Лидерборд"):
leaderboard_html = gr.HTML(value=load_data(), every=60)
with gr.Tab("📈 Метрики"):
gr.HTML(get_metrics_html())
with gr.Group():
gr.Markdown("### Песочница: посчитайте WER/CER на своих строках")
with gr.Row():
ref = gr.Textbox(
label="Референсный текст",
placeholder="например: я люблю машинное обучение",
lines=2,
)
hyp = gr.Textbox(
label="Гипотеза (распознанный текст)",
placeholder="например: я люблю мощинное обучение",
lines=2,
)
with gr.Row():
normalize = gr.Checkbox(
value=True,
label="Нормализовать (нижний регистр, без пунктуации)",
)
btn_calc = gr.Button("Посчитать")
with gr.Row():
out_wer = gr.Number(label="WER, %", precision=2)
out_cer = gr.Number(label="CER, %", precision=2)
def _ui_compute(ref_text, hyp_text, norm):
wer, cer = compute_wer_cer(ref_text or "", hyp_text or "", norm)
return wer, cer
btn_calc.click(
_ui_compute,
inputs=[ref, hyp, normalize],
outputs=[out_wer, out_cer],
)
with gr.Tab("📊 Датасеты"):
gr.HTML(get_datasets_description())
with gr.Tab("✉️ Отправить результат"):
gr.HTML(get_submit_html())
with gr.Row():
with gr.Column():
model_name = gr.Textbox(
label="Название модели *", placeholder="MyAwesomeASRModel"
)
link = gr.Textbox(
label="Ссылка на модель *",
placeholder="https://huggingface.co/username/model",
)
license_field = gr.Textbox(
label="Лицензия *", placeholder="MIT / Apache-2.0 / Closed"
)
with gr.Column():
metrics_json = gr.TextArea(
label="Метрики JSON *",
placeholder='{"Russian_LibriSpeech": {"wer": 0.1234, "cer": 0.0567}, ...}',
lines=16,
)
submit_btn = gr.Button("🚀 Отправить", elem_classes="full-width-btn")
output_msg = gr.HTML()
def _alert(kind, text):
return f'<div class="alert {kind}">{text}</div>'
def build_json_and_submit(name, link_, lic, metrics_str):
name = (name or "").strip()
link_ = (link_ or "").strip()
lic = (lic or "").strip()
if not name:
return (
gr.update(),
_alert("error", "Укажите название модели."),
metrics_str,
)
if not link_ or not (
link_.startswith("http://") or link_.startswith("https://")
):
return (
gr.update(),
_alert(
"error", "Ссылка должна начинаться с http:// или https://"
),
metrics_str,
)
if not lic:
return (
gr.update(),
_alert("error", "Укажите лицензию модели."),
metrics_str,
)
try:
metrics = json.loads(metrics_str)
except Exception as e:
return (
gr.update(),
_alert("error", f"Невалидный JSON метрик: {e}"),
metrics_str,
)
if not isinstance(metrics, dict):
return (
gr.update(),
_alert(
"error",
"Метрики должны быть объектом JSON с датасетами верхнего уровня.",
),
metrics_str,
)
missing = [ds for ds in DATASETS if ds not in metrics]
extra = [k for k in metrics.keys() if k not in DATASETS]
if missing:
return (
gr.update(),
_alert("error", f"Отсутствуют датасеты: {', '.join(missing)}"),
metrics_str,
)
if extra:
return (
gr.update(),
_alert("error", f"Лишние ключи в метриках: {', '.join(extra)}"),
metrics_str,
)
for ds in DATASETS:
entry = metrics.get(ds)
if not isinstance(entry, dict):
return (
gr.update(),
_alert(
"error",
f"{ds}: значение должно быть объектом с полями wer и cer",
),
metrics_str,
)
for k in ("wer", "cer"):
v = entry.get(k)
if not isinstance(v, (int, float)):
return (
gr.update(),
_alert("error", f"{ds}: поле {k} должно быть числом"),
metrics_str,
)
if not (0 <= float(v) <= 1):
return (
gr.update(),
_alert(
"error",
f"{ds}: поле {k} должно быть в диапазоне [0, 1]",
),
metrics_str,
)
payload = json.dumps(
{
"model_name": name,
"link": link_,
"license": lic,
"metrics": metrics,
},
ensure_ascii=False,
)
updated_html, status_msg, cleared = process_submit(payload)
if updated_html is None:
msg = status_msg.replace("Ошибка:", "").strip()
return (
gr.update(),
_alert("error", f"Не удалось добавить: {msg}"),
metrics_str,
)
return (
updated_html,
_alert("success", "✅ Результат добавлен в лидерборд."),
"",
)
submit_btn.click(
build_json_and_submit,
inputs=[model_name, link, license_field, metrics_json],
outputs=[leaderboard_html, output_msg, metrics_json],
)
demo.launch()