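// useLLM: React hook that loads an LFM2 ONNX model in the browser with
// Transformers.js (WebGPU, q4f16) and exposes streaming, tool-aware text
// generation with a reusable KV cache.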
import { useState, useEffect, useRef, useCallback } from "react";
import {
  AutoModelForCausalLM,
  AutoTokenizer,
  TextStreamer,
} from "@huggingface/transformers";

interface LLMState {
  isLoading: boolean;
  isReady: boolean;
  error: string | null;
  progress: number;
}

interface LLMInstance {
  model: any;
  tokenizer: any;
}
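
// Module-level cache: keeps loaded model/tokenizer instances and in-flight
// loading promises per modelId, so they survive component unmounts and are
// shared across hook instances.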
let moduleCache: {
  [modelId: string]: {
    instance: LLMInstance | null;
    loadingPromise: Promise<LLMInstance> | null;
  };
} = {};

export const useLLM = (modelId?: string) => {
  const [state, setState] = useState<LLMState>({
    isLoading: false,
    isReady: false,
    error: null,
    progress: 0,
  });

  const instanceRef = useRef<LLMInstance | null>(null);
  const loadingPromiseRef = useRef<Promise<LLMInstance> | null>(null);
  const abortControllerRef = useRef<AbortController | null>(null);
  // KV cache from the previous turn, reused so earlier tokens are not re-processed
  const pastKeyValuesRef = useRef<any>(null);

  const loadModel = useCallback(async () => {
    if (!modelId) {
      throw new Error("Model ID is required");
    }

    const MODEL_ID = `onnx-community/LFM2-${modelId}-ONNX`;

    if (!moduleCache[modelId]) {
      moduleCache[modelId] = {
        instance: null,
        loadingPromise: null,
      };
    }
    const cache = moduleCache[modelId];

    // Reuse an instance that has already been loaded (by this hook or another mount)
    const existingInstance = instanceRef.current || cache.instance;
    if (existingInstance) {
      instanceRef.current = existingInstance;
      cache.instance = existingInstance;
      setState((prev) => ({ ...prev, isReady: true, isLoading: false }));
      return existingInstance;
    }

    // If a load for this model is already in flight, await it instead of starting another
    const existingPromise = loadingPromiseRef.current || cache.loadingPromise;
    if (existingPromise) {
      try {
        const instance = await existingPromise;
        instanceRef.current = instance;
        cache.instance = instance;
        setState((prev) => ({ ...prev, isReady: true, isLoading: false }));
        return instance;
      } catch (error) {
        setState((prev) => ({
          ...prev,
          isLoading: false,
          error:
            error instanceof Error ? error.message : "Failed to load model",
        }));
        throw error;
      }
    }

    setState((prev) => ({
      ...prev,
      isLoading: true,
      error: null,
      progress: 0,
    }));

    abortControllerRef.current = new AbortController();

    const loadingPromise = (async () => {
      try {
        const progressCallback = (progress: any) => {
          // Only update progress for the model weights file
          if (
            progress.status === "progress" &&
            progress.file.endsWith(".onnx_data")
          ) {
            const percentage = Math.round(
              (progress.loaded / progress.total) * 100,
            );
            setState((prev) => ({ ...prev, progress: percentage }));
          }
        };

        const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID, {
          progress_callback: progressCallback,
        });

        const model = await AutoModelForCausalLM.from_pretrained(MODEL_ID, {
          dtype: "q4f16",
          device: "webgpu",
          progress_callback: progressCallback,
        });

        const instance = { model, tokenizer };
        instanceRef.current = instance;
        cache.instance = instance;
        loadingPromiseRef.current = null;
        cache.loadingPromise = null;

        setState((prev) => ({
          ...prev,
          isLoading: false,
          isReady: true,
          progress: 100,
        }));

        return instance;
      } catch (error) {
        loadingPromiseRef.current = null;
        cache.loadingPromise = null;
        setState((prev) => ({
          ...prev,
          isLoading: false,
          error:
            error instanceof Error ? error.message : "Failed to load model",
        }));
        throw error;
      }
    })();

    loadingPromiseRef.current = loadingPromise;
    cache.loadingPromise = loadingPromise;
    return loadingPromise;
  }, [modelId]);

  const generateResponse = useCallback(
    async (
      messages: Array<{ role: string; content: string }>,
      tools: Array<any>,
      onToken?: (token: string) => void,
    ): Promise<string> => {
      const instance = instanceRef.current;
      if (!instance) {
        throw new Error("Model not loaded. Call loadModel() first.");
      }

      const { model, tokenizer } = instance;

      // Apply chat template with tools
      const input = tokenizer.apply_chat_template(messages, {
        tools,
        add_generation_prompt: true,
        return_dict: true,
      });

      // Stream tokens to the caller as they are produced, if a callback was given
      const streamer = onToken
        ? new TextStreamer(tokenizer, {
            skip_prompt: true,
            skip_special_tokens: false,
            callback_function: (token: string) => {
              onToken(token);
            },
          })
        : undefined;

      // Generate the response, reusing the KV cache from previous turns
      const { sequences, past_key_values } = await model.generate({
        ...input,
        past_key_values: pastKeyValuesRef.current,
        max_new_tokens: 512,
        do_sample: false,
        streamer,
        return_dict_in_generate: true,
      });
      pastKeyValuesRef.current = past_key_values;

      // Decode only the newly generated tokens, preserving special tokens
      // (except the trailing <|im_end|>) so tool calls can be detected
      const response = tokenizer
        .batch_decode(sequences.slice(null, [input.input_ids.dims[1], null]), {
          skip_special_tokens: false,
        })[0]
        .replace(/<\|im_end\|>$/, "");

      return response;
    },
    [],
  );

  // Drop the KV cache (e.g. when starting a new conversation)
  const clearPastKeyValues = useCallback(() => {
    pastKeyValuesRef.current = null;
  }, []);

  const cleanup = useCallback(() => {
    if (abortControllerRef.current) {
      abortControllerRef.current.abort();
    }
  }, []);

  useEffect(() => {
    return cleanup;
  }, [cleanup]);

  // If this model was already loaded by another hook instance, mark it ready immediately
  useEffect(() => {
    if (modelId && moduleCache[modelId]) {
      const existingInstance =
        instanceRef.current || moduleCache[modelId].instance;
      if (existingInstance) {
        instanceRef.current = existingInstance;
        setState((prev) => ({ ...prev, isReady: true }));
      }
    }
  }, [modelId]);

  return {
    ...state,
    loadModel,
    generateResponse,
    clearPastKeyValues,
    cleanup,
  };
};
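
// Usage sketch (illustrative only): the component wiring and the "350M" size
// suffix below are assumptions for this example, not part of the hook itself.
//
//   const { isReady, progress, loadModel, generateResponse } = useLLM("350M");
//
//   useEffect(() => {
//     loadModel().catch(console.error);
//   }, [loadModel]);
//
//   // Later, once isReady is true:
//   const reply = await generateResponse(
//     [{ role: "user", content: "What's the weather in Paris?" }],
//     [],                              // tool definitions (empty: no tools exposed)
//     (token) => console.log(token),   // streamed output tokens
//   );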