import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
import os
from threading import Thread
import spaces
import time
import subprocess

print("\n=== Environment Setup ===")

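# Install flash-attn at runtime only when a CUDA GPU is present; the wheel build requires CUDA.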
if torch.cuda.is_available():
    print(f"GPU detected: {torch.cuda.get_device_name(0)}")
    try:
        subprocess.run(
            "pip install flash-attn --no-build-isolation",
            shell=True,
            check=True,
        )
        print("✅ flash-attn installed successfully")
    except subprocess.CalledProcessError as e:
        print("⚠️ flash-attn installation failed:", e)
else:
    print("⚙️ CPU detected — skipping flash-attn installation")
    # Disable flash-attn references safely
    os.environ["DISABLE_FLASH_ATTN"] = "1"
    os.environ["FLASH_ATTENTION_SKIP_CUDA_BUILD"] = "TRUE"
try:
    from transformers.utils import import_utils
    if "flash_attn" not in import_utils.PACKAGE_DISTRIBUTION_MAPPING:
        import_utils.PACKAGE_DISTRIBUTION_MAPPING["flash_attn"] = "flash-attn"
except Exception as e:
    print("⚠️ Patch skipped:", e)

# HF_TOKEN is optional for the public phi-4 checkpoint; pass it along if it is set.
token = os.environ.get("HF_TOKEN")


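# Load microsoft/phi-4 in bfloat16; trust_remote_code allows the repo's custom model code to run.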
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-4",
    token=token,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
)
tok = AutoTokenizer.from_pretrained("microsoft/phi-4", token=token)
terminators = [
    tok.eos_token_id,
]

if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

# Move the model onto the selected device explicitly (avoids device-dispatch errors).
model = model.to(device)


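# Streaming chat handler: builds the prompt from the history, runs generation on a
# background thread, and yields partial text to the UI as tokens arrive.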
@spaces.GPU(duration=60)
def chat(message, history, temperature, do_sample, max_tokens):
    # Rebuild the running conversation as role/content messages for the chat template.
    conversation = []
    for user_msg, assistant_msg in history:
        conversation.append({"role": "user", "content": user_msg})
        if assistant_msg is not None:
            conversation.append({"role": "assistant", "content": assistant_msg})
    conversation.append({"role": "user", "content": message})
    messages = tok.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    model_inputs = tok([messages], return_tensors="pt").to(device)
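    # TextIteratorStreamer yields decoded text chunks as generate() runs; skip_prompt
    # drops the echoed input and the timeout guards against a stalled generation thread.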
    streamer = TextIteratorStreamer(
        tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=do_sample,
        temperature=temperature,
        eos_token_id=terminators,
    )

    if temperature == 0:
        generate_kwargs["do_sample"] = False

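    # Run generate() on a worker thread so this generator can stream partial text meanwhile.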
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    yield partial_text


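# ChatInterface passes the accordion inputs below to chat() as extra positional
# arguments, in the same order as additional_inputs.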
demo = gr.ChatInterface(
    fn=chat,
    examples=[["Write me a poem about Machine Learning."]],
    # multimodal=False,
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Slider(
            minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
        ),
        gr.Checkbox(label="Sampling", value=True, render=False),
        gr.Slider(
            minimum=128,
            maximum=4096,
            step=1,
            value=512,
            label="Max new tokens",
            render=False,
        ),
    ],
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    description="Now Running [microsoft/phi-4](https://huggingface.co/microsoft/phi-4)",
)
demo.launch()