Spaces:
Running on CPU Upgrade

File size: 7,670 Bytes
9a5d8ec
ba84108
 
 
 
 
 
9a5d8ec
 
ba84108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a5d8ec
 
ba84108
406793a
bb3d05e
 
 
ba84108
bb3d05e
ba84108
 
bb3d05e
9cfe75e
 
3d475b8
 
 
 
 
f677a94
406793a
3d475b8
 
 
 
 
 
f677a94
3d475b8
 
 
 
 
 
 
 
 
 
 
 
4646b78
3d475b8
 
 
 
 
9cfe75e
325ed03
b6ecc6e
6370ba9
ba84108
b6ecc6e
 
 
 
 
9ea7979
ba84108
406793a
71c9493
 
406793a
9cfe75e
defc447
c921016
9ea7979
 
406793a
1e3cd55
9ea7979
f1b3d54
f35e5d2
1e3cd55
 
177a8e6
f1b3d54
ba84108
f1b3d54
 
 
ba84108
1e3cd55
5a001f3
 
406793a
f1b3d54
177a8e6
d5e14ad
177a8e6
f1b3d54
 
 
 
 
f35e5d2
1e3cd55
c921016
 
598b6e6
4c56373
f677a94
4c56373
5a001f3
4c56373
c921016
 
1e3cd55
cd07d29
79079d1
 
fa07bfe
2954446
79079d1
4612b23
f2be8d8
4612b23
ba84108
32c703c
1529238
32c703c
1529238
32c703c
9cfe75e
b0c4a5b
 
 
 
 
 
 
 
f0fd942
06353d7
296279c
1f7a544
296279c
06353d7
be1d31d
 
 
 
 
fe494ee
2954446
c50c843
 
 
406793a
c50c843
 
4612b23
c50c843
5a001f3
ec632c7
c50c843
ad4809c
c50c843
 
 
1e3cd55
 
c50c843
79079d1
 
1e3cd55
fa07bfe
ac240ce
f0fd942
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
# Standardbibliotheken
import os          # Umgebungsvariablen (z.B. HF_TOKEN)
import time        # Timing / Performance-Messung
import random      # Zufallswerte (z.B. Beispiel-Reviews)
import html        # HTML-Escaping für sichere Ausgabe in Gradio
import types       # Monkeypatching von Instanzen (fastText .predict)
import numpy as np # Numerische Arrays und Wahrscheinlichkeiten

# Machine Learning / NLP
import torch       # PyTorch: Modelle, Tensoren, Devices
import fasttext    # Sprach-ID-Modell (lid.176)
# Folgende sind notwendig, auch wenn sie nicht explizit genutzt werden:
import sentencepiece  # Pflicht für SentencePiece-basierte Tokenizer (z.B. DeBERTa v3)
import tiktoken       # Optionaler Converter (verhindert Fallback-Fehler bei Tokenizer)
from langid.langid import LanguageIdentifier, model  # Alternative Sprach-ID

# Hugging Face Ökosystem
import spaces                                     # HF Spaces-Dekoratoren (@spaces.GPU)
from transformers import AutoTokenizer            # Tokenizer laden (use_fast=False für DeBERTa v3)
from huggingface_hub import hf_hub_download       # Download von Dateien/Weights aus dem HF Hub
from safetensors.torch import load_file           # Sicheres & schnelles Laden von Weights (.safetensors)

# Übersetzung
import deepl  # DeepL API für automatische Übersetzung

# UI / Serving
import gradio as gr  # Web-UI für Demo/Spaces

# Projektspezifische Module
from lib.bert_regressor import BertMultiHeadRegressor
from lib.bert_regressor_utils import (
    predict_flavours,  # Hauptfunktion: Vorhersage der 8 Aromenachsen
)
from lib.wheel import build_svg_with_values       # SVG-Rendering für Flavour Wheel
from lib.examples import EXAMPLES                 # Beispiel-Reviews (vordefiniert)

### Settings #####################################################################

# Base checkpoint: used for the tokenizer and as the backbone of BertMultiHeadRegressor.
MODEL_BASE = "microsoft/deberta-v3-base"
# Hugging Face Hub repository holding the fine-tuned regression weights.
REPO_ID    = "ziem-io/deberta_flavour_regressor_multi_head"

# Runtime configuration read from environment variables (store them as Space secrets).
# os.getenv returns None for any variable that is not set.
# HF_TOKEN is optional: only needed if the model repo is private.
HF_TOKEN = os.getenv("HF_TOKEN")  # store in Space secrets
# Name of the weights file inside REPO_ID (passed as `filename` to hf_hub_download).
# NOTE(review): there is no default; if unset, None is passed to the download call — TODO confirm it is always configured.
MODEL_FILE = os.getenv("MODEL_FILE")  # store in Space secrets
# DeepL API key used by _translate_en for non-English input.
DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")  # store in Space secrets

##################################################################################

# --- Download weights ---
# Fetch the fine-tuned weights file MODEL_FILE from REPO_ID on the Hugging Face Hub;
# the returned local file path is loaded with safetensors below.
weights_path = hf_hub_download(
    repo_id=REPO_ID, 
    filename=MODEL_FILE, 
    token=HF_TOKEN
)

# --- Tokenizer (SentencePiece!) ---
# use_fast=False selects the slow, SentencePiece-based tokenizer for DeBERTa v3.
tokenizer_flavours = AutoTokenizer.from_pretrained(
    MODEL_BASE, 
    use_fast=False
)

# Build the multi-head regressor on top of the base checkpoint, then load the fine-tuned weights.
model_flavours = BertMultiHeadRegressor(
    pretrained_model_name=MODEL_BASE
)
state = load_file(weights_path)                          # safetensors -> dict[str, Tensor]
# NOTE(review): strict=False silently ignores missing/unexpected keys, and the returned
# key report is discarded into `_`; consider strict=True or inspecting the result.
_ = model_flavours.load_state_dict(state, strict=False)  # use strict=True if keys match exactly

# Use the GPU when available, otherwise the CPU; eval() switches the model to inference mode.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_flavours.to(device).eval()

### Check if lang is english #####################################################

# Module-wide langid identifier built from langid's bundled model. With norm_probs=True,
# classify() returns a normalized probability in [0, 1] instead of a raw score.
ID = LanguageIdentifier.from_modelstring(model, norm_probs=True)

def _is_eng(text: str, min_chars: int = 6, threshold: float = 0.1):
    """Decide whether *text* should be treated as English.

    Inputs shorter than ``min_chars`` characters (after stripping whitespace)
    are not classified at all and are reported as English with probability 0.0.

    Args:
        text: Text to check; ``None`` is treated like an empty string.
        min_chars: Minimum stripped length required before running langid.
        threshold: Minimum langid probability required to accept "en".

    Returns:
        Tuple ``(is_english, probability)``. ``probability`` is the langid
        probability of the detected language, which is not necessarily English.
    """
    cleaned = text.strip() if text else ""
    if len(cleaned) >= min_chars:
        detected, confidence = ID.classify(cleaned)  # confidence in [0, 1]
        accepted = detected == "en" and confidence >= threshold
        return accepted, float(confidence)
    return True, 0.0

# Lazily created DeepL client, reused by every _translate_en call.
_deepl_client = None


def _translate_en(text: str, target_lang: str = "EN-GB"):
    """Translate *text* into English using the DeepL API.

    The ``deepl.Translator`` is created on first use and then reused, instead
    of building a new client (and HTTP session) for every request.

    Args:
        text: Text to translate; no source language is passed, so DeepL
            auto-detects it.
        target_lang: DeepL target language code (default ``"EN-GB"``).

    Returns:
        The translated text.

    Raises:
        ValueError: If ``DEEPL_API_KEY`` is not configured.
        Errors raised by the DeepL client (presumably ``deepl.DeepLException``
        subclasses) propagate unchanged.
    """
    global _deepl_client
    if _deepl_client is None:
        # Fail with a clear message rather than an opaque client/auth error.
        if not DEEPL_API_KEY:
            raise ValueError("DEEPL_API_KEY is not set; cannot translate non-English input.")
        _deepl_client = deepl.Translator(DEEPL_API_KEY)
    result = _deepl_client.translate_text(text, target_lang=target_lang)
    return result.text

### Do actual prediction #########################################################
@spaces.GPU(duration=10)  # seconds of GPU time per call
def predict(review: str):
    """Predict the eight flavour-axis intensities for a whisky tasting note.

    Input that ``_is_eng`` does not accept as English is first translated via
    DeepL, because the model is designed for English tasting notes.

    Args:
        review: Raw tasting note entered by the user; may be ``None`` or blank.

    Returns:
        A 3-tuple matching the Gradio outputs ``[info, wheel, json]``:
        an info HTML snippet (translation notice or ``""``), the flavour-wheel
        SVG markup, and a dict with predictions plus run metadata.
        For blank input, returns ``("Please enter a review.", "", {})``.
    """
    review = (review or "").strip()
    is_translated = False
    html_info_out = ""

    # Nothing to predict on; still return all three outputs Gradio expects.
    if not review:
        return "Please enter a review.", "", {}

    # Translate non-English text before prediction (the probability is not needed here).
    review_is_eng, _ = _is_eng(review)
    if not review_is_eng:
        review = _translate_en(review)
        # The translated text derives from user input and is rendered as raw
        # HTML by gr.HTML, so escape it to prevent markup/script injection.
        html_info_out = f"""<strong style='margin-bottom: 0.5em'>Your text has been automatically translated:</strong>
        <p>{html.escape(review)}</p>
        """
        is_translated = True

    # Time only the model inference; perf_counter is monotonic and high-resolution.
    t_start_flavours = time.perf_counter()
    prediction_flavours = predict_flavours(review, model_flavours, tokenizer_flavours, device)
    t_end_flavours = time.perf_counter()

    # NOTE(review): relies on predict_flavours returning its axes in the order
    # build_svg_with_values expects — TODO confirm.
    prediction_flavours_list = list(prediction_flavours.values())
    html_wheel_out = build_svg_with_values(prediction_flavours_list)

    json_out = {
        "result": dict(prediction_flavours.items()),
        "review": review,
        "model": MODEL_FILE,
        "device": device,
        "translated": is_translated,
        "duration": round(t_end_flavours - t_start_flavours, 3),
    }

    return html_info_out, html_wheel_out, json_out

##################################################################################

def random_text():
    """Return one entry of the predefined EXAMPLES reviews, picked at random."""
    candidates = EXAMPLES
    picked = random.choice(candidates)
    return picked

def _start_text():
    """Return the example review that pre-fills the input box when the UI is built."""
    default_example_index = 20
    return EXAMPLES[default_example_index]

def _get_device_info():
    """Return a short human-readable label describing the compute device."""
    if not torch.cuda.is_available():
        return "◎ Runs on CPU (May be slower)"
    gpu_name = torch.cuda.get_device_name(0)
    return f"◉ Runs on GPU: {gpu_name}"

### Create Form interface with Gradio Framework ##################################
# Dark-mode override: lighten the <text> labels inside the SVG with id "wheel"
# (presumably the flavour wheel produced by build_svg_with_values) so they stay readable.
custom_css = """
@media (prefers-color-scheme: dark) {
    svg#wheel > text {
        fill: rgb(200, 200, 200);
    }
}
"""

with gr.Blocks(css=custom_css) as demo:
    # Header and short model description.
    gr.HTML("<h2>Multi-Axis Regression of Whisky Tasting Notes</h2>")
    gr.HTML("""
    <h3>Automatically turns Whisky Tasting Notes into Flavour Wheels.</h3>
    <p>This model is a fine-tuned version of <a href='https://huggingface.co/microsoft/deberta-v3-base'>microsoft/deberta-v3-base</a> designed to analyze English whisky tasting notes. It predicts the intensity of eight sensory categories — <strong>grainy</strong>, <strong>grassy</strong>, <strong>fragrant</strong>, <strong>fruity</strong>, <strong>peated</strong>, <strong>woody</strong>, <strong>winey</strong> and <strong>off-notes</strong> — on a continuous scale from 0 (none) to 4 (extreme).</p>
""")

    gr.HTML("""
    <p style='color: var(--block-title-text-color)'>Learn more about use cases and get in touch at <a href='https://www.whisky-wheel.com'>www.whisky-wheel.com</a></p>
""")
    
    # NOTE(review): disabled device banner; while it stays commented out, _get_device_info() is unused.
    #gr.HTML(f"<span style='color: var(--block-title-text-color)'>{_get_device_info()}</span>")

    with gr.Row():  # everything side by side
        with gr.Column(scale=1):  # left side: input
            # Pre-filled with a fixed example review so the page is usable immediately.
            review_box = gr.Textbox(
                label="Whisky Review",
                lines=8,
                placeholder="Enter whisky review",
                value=_start_text(),
            )
            gr.HTML("<div style='color: gray; font-size: 0.9em;'>Note: Non-English texts will be automatically translated.</div>")
            
            with gr.Row():
                replace_btn = gr.Button("Load Example", variant="secondary", scale=1)
                submit_btn  = gr.Button("Submit", variant="primary", scale=1)

        with gr.Column(scale=1):  # right side: output
            # Order must match predict()'s return tuple: (info HTML, wheel SVG, JSON dict).
            html_info_out = gr.HTML(label="Info")
            html_wheel_out = gr.HTML(label="Flavour Wheel")
            json_out = gr.JSON(label="JSON")

    # Events
    # "Submit" runs predict on the textbox content and fills the three outputs in order.
    submit_btn.click(predict, inputs=review_box, outputs=[html_info_out, html_wheel_out, json_out])
    # "Load Example" replaces the textbox content with a randomly chosen example.
    replace_btn.click(random_text, outputs=review_box)

# Start the app at import time; show_api=False hides Gradio's API documentation link.
demo.launch(show_api=False)