import {
  AutoModelForCausalLM,
  AutoTokenizer,
  TextStreamer,
} from "@huggingface/transformers";

// Worker state
let model: any = null;
let tokenizer: any = null;
let pastKeyValues: any = null;
let isGenerating = false;

// Cache for loaded models
const modelCache: {
  [modelId: string]: {
    model: any;
    tokenizer: any;
  };
} = {};

// Message types from main thread
interface LoadMessage {
  type: "load";
  modelId: string;
}

interface GenerateMessage {
  type: "generate";
  messages: Array<{ role: string; content: string }>;
  tools: Array<any>;
}

interface InterruptMessage {
  type: "interrupt";
}

interface ResetMessage {
  type: "reset";
}

type WorkerMessage = LoadMessage | GenerateMessage | InterruptMessage | ResetMessage;

// Message types to main thread
interface ProgressMessage {
  type: "progress";
  progress: number;
  file?: string;
}

interface ReadyMessage {
  type: "ready";
}

interface UpdateMessage {
  type: "update";
  token: string;
  tokensPerSecond: number;
  numTokens: number;
}

interface CompleteMessage {
  type: "complete";
  text: string;
}

interface ErrorMessage {
  type: "error";
  error: string;
}

type WorkerResponse = ProgressMessage | ReadyMessage | UpdateMessage | CompleteMessage | ErrorMessage;

function postMessage(message: WorkerResponse) {
  self.postMessage(message);
}
// Load model
async function loadModel(modelId: string) {
  try {
    // Check cache first
    if (modelCache[modelId]) {
      model = modelCache[modelId].model;
      tokenizer = modelCache[modelId].tokenizer;
      postMessage({ type: "ready" });
      return;
    }

    const progressCallback = (progress: any) => {
      if (
        progress.status === "progress" &&
        progress.file.endsWith(".onnx_data")
      ) {
        const percentage = Math.round(
          (progress.loaded / progress.total) * 100
        );
        postMessage({
          type: "progress",
          progress: percentage,
          file: progress.file,
        });
      }
    };

    // Load tokenizer
    tokenizer = await AutoTokenizer.from_pretrained(modelId, {
      progress_callback: progressCallback,
    });

    // Load model
    model = await AutoModelForCausalLM.from_pretrained(modelId, {
      dtype: "q4f16",
      device: "webgpu",
      progress_callback: progressCallback,
    });

    // Pre-warm the model with a dummy input for shader compilation
    const dummyInput = tokenizer("Hello", {
      return_tensors: "pt",
      padding: false,
      truncation: false,
    });
    await model.generate({
      ...dummyInput,
      max_new_tokens: 1,
      do_sample: false,
    });

    // Cache the loaded model
    modelCache[modelId] = { model, tokenizer };

    postMessage({ type: "ready" });
  } catch (error) {
    postMessage({
      type: "error",
      error: error instanceof Error ? error.message : "Failed to load model",
    });
  }
}
// Generate response
async function generate(
  messages: Array<{ role: string; content: string }>,
  tools: Array<any>
) {
  if (!model || !tokenizer) {
    postMessage({ type: "error", error: "Model not loaded" });
    return;
  }

  try {
    isGenerating = true;

    // Apply chat template with tools
    const input = tokenizer.apply_chat_template(messages, {
      tools,
      add_generation_prompt: true,
      return_dict: true,
    });

    // Track tokens and timing
    const startTime = performance.now();
    let tokenCount = 0;

    const streamer = new TextStreamer(tokenizer, {
      skip_prompt: true,
      skip_special_tokens: false,
      callback_function: (token: string) => {
        if (!isGenerating) return; // Check if interrupted
        tokenCount++;
        const elapsed = (performance.now() - startTime) / 1000;
        const tps = tokenCount / elapsed;
        postMessage({
          type: "update",
          token,
          tokensPerSecond: tps,
          numTokens: tokenCount,
        });
      },
    });

    // Generate the response
    const { sequences, past_key_values } = await model.generate({
      ...input,
      past_key_values: pastKeyValues,
      max_new_tokens: 1024,
      do_sample: false,
      streamer,
      return_dict_in_generate: true,
    });
    pastKeyValues = past_key_values;

    // Decode the generated text
    const response = tokenizer
      .batch_decode(sequences.slice(null, [input.input_ids.dims[1], null]), {
        skip_special_tokens: false,
      })[0]
      .replace(/<\|im_end\|>$/, "")
      .replace(/<\|end_of_text\|>$/, "");

    if (isGenerating) {
      postMessage({ type: "complete", text: response });
    }
    isGenerating = false;
  } catch (error) {
    isGenerating = false;
    postMessage({
      type: "error",
      error: error instanceof Error ? error.message : "Generation failed",
    });
  }
}
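// Illustrative only: one possible shape for the `tools` argument passed to
// generate() above. This follows the common Hugging Face "function" tool
// format, but the exact schema a chat template accepts depends on the model,
// so treat the structure and field names below as an assumption.
// const exampleTools = [
//   {
//     type: "function",
//     function: {
//       name: "get_weather",
//       description: "Get the current weather for a city",
//       parameters: {
//         type: "object",
//         properties: {
//           city: { type: "string", description: "City name" },
//         },
//         required: ["city"],
//       },
//     },
//   },
// ];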
// Interrupt generation
function interrupt() {
  isGenerating = false;
  // Send a completion message with empty text to resolve the promise
  postMessage({ type: "complete", text: "" });
}

// Reset past key values
function reset() {
  pastKeyValues = null;
}

// Handle messages from main thread
self.onmessage = async (e: MessageEvent<WorkerMessage>) => {
  const message = e.data;

  switch (message.type) {
    case "load":
      await loadModel(message.modelId);
      break;
    case "generate":
      await generate(message.messages, message.tools);
      break;
    case "interrupt":
      interrupt();
      break;
    case "reset":
      reset();
      break;
  }
};

// Export for TypeScript
export {};
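// --- Hypothetical main-thread usage (not part of this worker file) ---
// A minimal sketch of how an app might drive this worker, assuming the file
// is bundled as a module worker at "./worker.ts" and that the message
// interfaces are mirrored (or exported) on the main-thread side. The path
// and the "<model-id>" placeholder below are assumptions, not taken from
// the original code.
//
// const worker = new Worker(new URL("./worker.ts", import.meta.url), {
//   type: "module",
// });
//
// worker.onmessage = (e: MessageEvent<WorkerResponse>) => {
//   const msg = e.data;
//   switch (msg.type) {
//     case "progress":
//       console.log(`Downloading ${msg.file}: ${msg.progress}%`);
//       break;
//     case "ready":
//       // Model is loaded (or pulled from cache); start a generation
//       worker.postMessage({
//         type: "generate",
//         messages: [{ role: "user", content: "Hello!" }],
//         tools: [],
//       });
//       break;
//     case "update":
//       console.log(msg.token, `${msg.tokensPerSecond.toFixed(1)} tok/s`);
//       break;
//     case "complete":
//       console.log("Done:", msg.text);
//       break;
//     case "error":
//       console.error(msg.error);
//       break;
//   }
// };
//
// worker.postMessage({ type: "load", modelId: "<model-id>" });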