// src/workers/llm.worker.ts
// Web Worker that handles LLM model loading and response generation, communicating
// with the main thread through messages.
import {
AutoModelForCausalLM,
AutoTokenizer,
TextStreamer,
} from "@huggingface/transformers";
// Worker state: the active model and tokenizer, the KV cache for the current
// conversation, and a flag that is cleared to interrupt streaming.
let model: any = null;
let tokenizer: any = null;
let pastKeyValues: any = null;
let isGenerating = false;
// Cache of loaded models, keyed by model id, so switching back to a previously
// loaded model skips re-downloading.
const modelCache: {
[modelId: string]: {
model: any;
tokenizer: any;
};
} = {};
// Message types from main thread
interface LoadMessage {
type: "load";
modelId: string;
}
interface GenerateMessage {
type: "generate";
messages: Array<{ role: string; content: string }>;
tools: Array<any>;
}
interface InterruptMessage {
type: "interrupt";
}
interface ResetMessage {
type: "reset";
}
type WorkerMessage = LoadMessage | GenerateMessage | InterruptMessage | ResetMessage;
// Message types to main thread
interface ProgressMessage {
type: "progress";
progress: number;
file?: string;
}
interface ReadyMessage {
type: "ready";
}
interface UpdateMessage {
type: "update";
token: string;
tokensPerSecond: number;
numTokens: number;
}
interface CompleteMessage {
type: "complete";
text: string;
}
interface ErrorMessage {
type: "error";
error: string;
}
type WorkerResponse = ProgressMessage | ReadyMessage | UpdateMessage | CompleteMessage | ErrorMessage;
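// Typed wrapper around self.postMessage so every outgoing message conforms to WorkerResponse.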
function postMessage(message: WorkerResponse) {
self.postMessage(message);
}
// Load model
async function loadModel(modelId: string) {
try {
// Check cache first
if (modelCache[modelId]) {
model = modelCache[modelId].model;
tokenizer = modelCache[modelId].tokenizer;
postMessage({ type: "ready" });
return;
}
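// Forward download progress only for the large .onnx_data weight files.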
const progressCallback = (progress: any) => {
if (
progress.status === "progress" &&
progress.file.endsWith(".onnx_data")
) {
const percentage = Math.round(
(progress.loaded / progress.total) * 100
);
postMessage({
type: "progress",
progress: percentage,
file: progress.file,
});
}
};
// Load tokenizer
tokenizer = await AutoTokenizer.from_pretrained(modelId, {
progress_callback: progressCallback,
});
// Load model
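// "q4f16" selects the 4-bit quantized weights (fp16 variant); "webgpu" runs inference on the GPU.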
model = await AutoModelForCausalLM.from_pretrained(modelId, {
dtype: "q4f16",
device: "webgpu",
progress_callback: progressCallback,
});
// Pre-warm the model with a dummy input for shader compilation
const dummyInput = tokenizer("Hello", {
return_tensors: "pt",
padding: false,
truncation: false,
});
await model.generate({
...dummyInput,
max_new_tokens: 1,
do_sample: false,
});
// Cache the loaded model
modelCache[modelId] = { model, tokenizer };
postMessage({ type: "ready" });
} catch (error) {
postMessage({
type: "error",
error: error instanceof Error ? error.message : "Failed to load model",
});
}
}
// Generate response
async function generate(
messages: Array<{ role: string; content: string }>,
tools: Array<any>
) {
if (!model || !tokenizer) {
postMessage({ type: "error", error: "Model not loaded" });
return;
}
try {
isGenerating = true;
// Apply chat template with tools
const input = tokenizer.apply_chat_template(messages, {
tools,
add_generation_prompt: true,
return_dict: true,
});
// Track tokens and timing
const startTime = performance.now();
let tokenCount = 0;
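// TextStreamer invokes callback_function with each decoded chunk as it is generated.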
const streamer = new TextStreamer(tokenizer, {
skip_prompt: true,
skip_special_tokens: false,
callback_function: (token: string) => {
if (!isGenerating) return; // Check if interrupted
tokenCount++;
const elapsed = (performance.now() - startTime) / 1000;
const tps = tokenCount / elapsed;
postMessage({
type: "update",
token,
tokensPerSecond: tps,
numTokens: tokenCount,
});
},
});
// Generate the response
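// Reuse the cached past_key_values so tokens from earlier turns are not re-processed.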
const { sequences, past_key_values } = await model.generate({
...input,
past_key_values: pastKeyValues,
max_new_tokens: 1024,
do_sample: false,
streamer,
return_dict_in_generate: true,
});
pastKeyValues = past_key_values;
// Decode the generated text
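// Keep only the newly generated tokens (everything after the prompt) and strip trailing end-of-turn markers.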
const response = tokenizer
.batch_decode(sequences.slice(null, [input.input_ids.dims[1], null]), {
skip_special_tokens: false,
})[0]
.replace(/<\|im_end\|>$/, "")
.replace(/<\|end_of_text\|>$/, "");
if (isGenerating) {
postMessage({ type: "complete", text: response });
}
isGenerating = false;
} catch (error) {
isGenerating = false;
postMessage({
type: "error",
error: error instanceof Error ? error.message : "Generation failed",
});
}
}
// Interrupt generation
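// Note: this only stops the streamed updates and resolves the caller early; the
// underlying model.generate() call still runs to completion in the background.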
function interrupt() {
isGenerating = false;
// Send a completion message with empty text to resolve the promise
postMessage({ type: "complete", text: "" });
}
// Reset past key values
function reset() {
pastKeyValues = null;
}
// Handle messages from main thread
self.onmessage = async (e: MessageEvent<WorkerMessage>) => {
const message = e.data;
switch (message.type) {
case "load":
await loadModel(message.modelId);
break;
case "generate":
await generate(message.messages, message.tools);
break;
case "interrupt":
interrupt();
break;
case "reset":
reset();
break;
}
};
// Export for TypeScript
export {};
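/*
 * Example: driving this worker from the main thread. This is an illustrative
 * sketch only; the worker URL pattern, the model id placeholder, and the handler
 * bodies below are assumptions, not part of this module.
 *
 *   const worker = new Worker(new URL("./llm.worker.ts", import.meta.url), {
 *     type: "module",
 *   });
 *
 *   worker.onmessage = (e: MessageEvent) => {
 *     const msg = e.data;
 *     if (msg.type === "progress") {
 *       // msg.progress: download percentage for the .onnx_data weights
 *     } else if (msg.type === "update") {
 *       // append msg.token to the UI; msg.tokensPerSecond / msg.numTokens for stats
 *     } else if (msg.type === "complete") {
 *       // msg.text holds the full decoded response
 *     } else if (msg.type === "error") {
 *       // msg.error holds the failure reason
 *     }
 *   };
 *
 *   worker.postMessage({ type: "load", modelId: "<model-id>" });
 *   worker.postMessage({
 *     type: "generate",
 *     messages: [{ role: "user", content: "Hello!" }],
 *     tools: [],
 *   });
 */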