Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 7,670 Bytes
9a5d8ec ba84108 9a5d8ec ba84108 9a5d8ec ba84108 406793a bb3d05e ba84108 bb3d05e ba84108 bb3d05e 9cfe75e 3d475b8 f677a94 406793a 3d475b8 f677a94 3d475b8 4646b78 3d475b8 9cfe75e 325ed03 b6ecc6e 6370ba9 ba84108 b6ecc6e 9ea7979 ba84108 406793a 71c9493 406793a 9cfe75e defc447 c921016 9ea7979 406793a 1e3cd55 9ea7979 f1b3d54 f35e5d2 1e3cd55 177a8e6 f1b3d54 ba84108 f1b3d54 ba84108 1e3cd55 5a001f3 406793a f1b3d54 177a8e6 d5e14ad 177a8e6 f1b3d54 f35e5d2 1e3cd55 c921016 598b6e6 4c56373 f677a94 4c56373 5a001f3 4c56373 c921016 1e3cd55 cd07d29 79079d1 fa07bfe 2954446 79079d1 4612b23 f2be8d8 4612b23 ba84108 32c703c 1529238 32c703c 1529238 32c703c 9cfe75e b0c4a5b f0fd942 06353d7 296279c 1f7a544 296279c 06353d7 be1d31d fe494ee 2954446 c50c843 406793a c50c843 4612b23 c50c843 5a001f3 ec632c7 c50c843 ad4809c c50c843 1e3cd55 c50c843 79079d1 1e3cd55 fa07bfe ac240ce f0fd942 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
# Standardbibliotheken
import os # Umgebungsvariablen (z.B. HF_TOKEN)
import time # Timing / Performance-Messung
import random # Zufallswerte (z.B. Beispiel-Reviews)
import html # HTML-Escaping für sichere Ausgabe in Gradio
import types # Monkeypatching von Instanzen (fastText .predict)
import numpy as np # Numerische Arrays und Wahrscheinlichkeiten
# Machine Learning / NLP
import torch # PyTorch: Modelle, Tensoren, Devices
import fasttext # Sprach-ID-Modell (lid.176)
# Folgende sind notwendig, auch wenn sie nicht explizit genutzt werden:
import sentencepiece # Pflicht für SentencePiece-basierte Tokenizer (z.B. DeBERTa v3)
import tiktoken # Optionaler Converter (verhindert Fallback-Fehler bei Tokenizer)
from langid.langid import LanguageIdentifier, model # Alternative Sprach-ID
# Hugging Face Ökosystem
import spaces # HF Spaces-Dekoratoren (@spaces.GPU)
from transformers import AutoTokenizer # Tokenizer laden (use_fast=False für DeBERTa v3)
from huggingface_hub import hf_hub_download # Download von Dateien/Weights aus dem HF Hub
from safetensors.torch import load_file # Sicheres & schnelles Laden von Weights (.safetensors)
# Übersetzung
import deepl # DeepL API für automatische Übersetzung
# UI / Serving
import gradio as gr # Web-UI für Demo/Spaces
# Projektspezifische Module
from lib.bert_regressor import BertMultiHeadRegressor
from lib.bert_regressor_utils import (
predict_flavours, # Hauptfunktion: Vorhersage der 8 Aromenachsen
)
from lib.wheel import build_svg_with_values # SVG-Rendering für Flavour Wheel
from lib.examples import EXAMPLES # Beispiel-Reviews (vordefiniert)
### Settings #####################################################################
# Base checkpoint used for both the tokenizer and the regressor backbone.
MODEL_BASE = "microsoft/deberta-v3-base"
# Hub repository the fine-tuned weights are downloaded from.
REPO_ID = "ziem-io/deberta_flavour_regressor_multi_head"

# The values below are read from the environment; store them as Space secrets.
HF_TOKEN = os.environ.get("HF_TOKEN")  # optional: only needed if the model repo is private
MODEL_FILE = os.environ.get("MODEL_FILE")  # name of the weights file inside REPO_ID
DEEPL_API_KEY = os.environ.get("DEEPL_API_KEY")  # key handed to the DeepL translator
##################################################################################
# --- Download weights ---
# Fail fast with an actionable message: without MODEL_FILE the download call
# below would be handed filename=None and fail with a far less helpful error.
if not MODEL_FILE:
    raise RuntimeError(
        "Environment variable MODEL_FILE is not set; it must name the weights "
        f"file to download from '{REPO_ID}'."
    )

weights_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILE,
    token=HF_TOKEN,
)

# --- Tokenizer (SentencePiece!) ---
tokenizer_flavours = AutoTokenizer.from_pretrained(
    MODEL_BASE,
    use_fast=False,
)

# --- Model ---
model_flavours = BertMultiHeadRegressor(
    pretrained_model_name=MODEL_BASE,
)

state = load_file(weights_path)  # safetensors -> dict[str, Tensor]

# strict=False tolerates key mismatches between checkpoint and model. Report
# them instead of discarding the result, otherwise a wrong checkpoint would
# go unnoticed. (Use strict=True once the keys are known to match exactly.)
# NOTE(review): assumes load_state_dict follows the torch.nn.Module contract
# (result exposes missing_keys / unexpected_keys); getattr keeps this safe if not.
_load_result = model_flavours.load_state_dict(state, strict=False)
_missing_keys = list(getattr(_load_result, "missing_keys", ()))
_unexpected_keys = list(getattr(_load_result, "unexpected_keys", ()))
if _missing_keys or _unexpected_keys:
    print(
        f"[warning] State dict of '{MODEL_FILE}' does not fully match the model: "
        f"missing={_missing_keys}, unexpected={_unexpected_keys}"
    )

device = "cuda" if torch.cuda.is_available() else "cpu"
model_flavours.to(device).eval()
### Check if lang is english #####################################################
# Shared langid identifier, created once at import time from the `model`
# imported from langid.langid. norm_probs=True is requested so that the
# probability returned by classify() can be compared against a threshold
# in [0, 1] (see _is_eng).
ID = LanguageIdentifier.from_modelstring(model, norm_probs=True)
def _is_eng(text: str, min_chars: int = 6, threshold: float = 0.1):
    """Decide whether ``text`` should be treated as English.

    Args:
        text: Text to classify; ``None`` or empty is handled like "".
        min_chars: Stripped texts shorter than this are not classified at all
            and are treated as English.
        threshold: Minimum probability required for an "en" classification
            to count.

    Returns:
        A pair ``(is_english, probability)``. For texts that are too short
        the pair is ``(True, 0.0)``; otherwise ``probability`` is the value
        reported by the language identifier for its detected language.
    """
    cleaned = text.strip() if text else ""

    if len(cleaned) >= min_chars:
        detected, confidence = ID.classify(cleaned)  # confidence is in [0, 1]
        is_english = detected == "en" and confidence >= threshold
        return is_english, float(confidence)

    # Too little text to classify: let it through as English.
    return True, 0.0
def _translate_en(text: str, target_lang: str = "EN-GB"):
    """Translate ``text`` with DeepL and return the translated string.

    Args:
        text: Text to translate.
        target_lang: DeepL target language code; defaults to "EN-GB".

    Returns:
        The ``.text`` attribute of the DeepL translation result.
    """
    # A translator is created per call using the module-level API key.
    translator = deepl.Translator(DEEPL_API_KEY)
    return translator.translate_text(text, target_lang=target_lang).text
### Do actual prediction #########################################################
@spaces.GPU(duration=10)  # seconds of GPU time per call
def predict(review: str):
    """Predict flavour intensities for a review and build the three UI outputs.

    Input that ``_is_eng`` does not accept as English is first translated
    with ``_translate_en``; the translated text is what gets scored.

    Args:
        review: Raw review text from the UI; ``None`` is treated as empty.

    Returns:
        A 3-tuple in the order of the Gradio outputs:

        * info HTML: a notice containing the (HTML-escaped) translated text,
          or "" when no translation took place;
        * wheel HTML: whatever ``build_svg_with_values`` returns for the
          predicted values;
        * JSON dict with the keys ``result``, ``review``, ``model``,
          ``device``, ``translated`` and ``duration`` (seconds, rounded to
          three decimals).

        For empty input the tuple is ``("Please enter a review.", "", {})``.
    """
    review = (review or "").strip()

    # Abort if no text is given; always return three outputs so the Gradio
    # output wiring stays satisfied.
    if not review:
        return "Please enter a review.", "", {}

    is_translated = False
    html_info_out = ""

    # Translate (rather than reject) text that is not detected as English.
    review_is_eng, _ = _is_eng(review)
    if not review_is_eng:
        review = _translate_en(review)
        is_translated = True
        # The text originates from the user (via the translation service) and
        # is rendered as HTML, so it must be escaped to prevent markup/script
        # injection into the page.
        html_info_out = (
            "<strong style='margin-bottom: 0.5em'>"
            "Your text has been automatically translated:</strong>\n"
            f"<p>{html.escape(review)}</p>\n"
        )

    # Run the actual prediction; perf_counter is monotonic and therefore the
    # appropriate clock for measuring a duration.
    t_start_flavours = time.perf_counter()
    prediction_flavours = predict_flavours(
        review, model_flavours, tokenizer_flavours, device
    )
    prediction_flavours_list = list(prediction_flavours.values())
    t_end_flavours = time.perf_counter()

    html_wheel_out = build_svg_with_values(prediction_flavours_list)

    json_out = {
        "result": dict(prediction_flavours),
        "review": review,  # unescaped: shown by the JSON component, not as HTML
        "model": MODEL_FILE,
        "device": device,
        "translated": is_translated,
        "duration": round(t_end_flavours - t_start_flavours, 3),
    }

    return html_info_out, html_wheel_out, json_out
##################################################################################
def random_text():
    """Return one randomly chosen entry from the predefined example reviews."""
    example = random.choice(EXAMPLES)
    return example
def _start_text(index: int = 20):
    """Return the example review used to pre-fill the input box.

    Args:
        index: Position in ``EXAMPLES`` of the preferred start text. The
            default of 20 preserves the previously hard-coded choice.

    Returns:
        ``EXAMPLES[index]`` when that entry exists; otherwise the first
        example, or "" when there are no examples at all. The fallback keeps
        the UI from crashing at build time if the example list shrinks.
    """
    if not EXAMPLES:
        return ""
    try:
        return EXAMPLES[index]
    except IndexError:
        return EXAMPLES[0]
def _get_device_info():
    """Return a one-line, human-readable label for the compute device in use."""
    if not torch.cuda.is_available():
        return "◎ Runs on CPU (May be slower)"

    gpu_name = torch.cuda.get_device_name(0)
    return f"◉ Runs on GPU: {gpu_name}"
### Create Form interface with Gradio Framework ##################################
# Extra CSS handed to gr.Blocks: under a dark colour scheme, recolour the
# <text> children of the element selected by `svg#wheel`.
custom_css = """
@media (prefers-color-scheme: dark) {
svg#wheel > text {
fill: rgb(200, 200, 200);
}
}
"""
# Page layout: static intro HTML on top, then one row with the input column
# (text box + buttons) and the output column (info, wheel, JSON).
with gr.Blocks(css=custom_css) as demo:
    gr.HTML("<h2>Multi-Axis Regression of Whisky Tasting Notes</h2>")
    gr.HTML("""
<h3>Automatically turns Whisky Tasting Notes into Flavour Wheels.</h3>
<p>This model is a fine-tuned version of <a href='https://huggingface.co/microsoft/deberta-v3-base'>microsoft/deberta-v3-base</a> designed to analyze English whisky tasting notes. It predicts the intensity of eight sensory categories — <strong>grainy</strong>, <strong>grassy</strong>, <strong>fragrant</strong>, <strong>fruity</strong>, <strong>peated</strong>, <strong>woody</strong>, <strong>winey</strong> and <strong>off-notes</strong> — on a continuous scale from 0 (none) to 4 (extreme).</p>
""")
    gr.HTML("""
<p style='color: var(--block-title-text-color)'>Learn more about use cases and get in touch at <a href='https://www.whisky-wheel.com'>www.whisky-wheel.com</a></p>
""")
    # Disabled device banner (kept for reference):
    #gr.HTML(f"<span style='color: var(--block-title-text-color)'>{_get_device_info()}</span>")
    with gr.Row():  # everything side by side
        with gr.Column(scale=1):  # left side: input
            review_box = gr.Textbox(
                label="Whisky Review",
                lines=8,
                placeholder="Enter whisky review",
                value=_start_text(),
            )
            gr.HTML("<div style='color: gray; font-size: 0.9em;'>Note: Non-English texts will be automatically translated.</div>")
            with gr.Row():
                replace_btn = gr.Button("Load Example", variant="secondary", scale=1)
                submit_btn = gr.Button("Submit", variant="primary", scale=1)
        with gr.Column(scale=1):  # right side: output
            html_info_out = gr.HTML(label="Info")
            html_wheel_out = gr.HTML(label="Flavour Wheel")
            json_out = gr.JSON(label="JSON")
    # Events
    # Submit runs predict() on the text box and fills the three output components
    # in the same order predict() returns them.
    submit_btn.click(predict, inputs=review_box, outputs=[html_info_out, html_wheel_out, json_out])
    # "Load Example" replaces the text box content with a random example review.
    replace_btn.click(random_text, outputs=review_box)
# Start the app; show_api=False is passed to launch().
demo.launch(show_api=False)