import type {
  ChatCompletionMessageParam,
  ChatOptions,
  InitProgressCallback,
  MLCEngineConfig,
} from "@mlc-ai/web-llm";
import { addLogEntry } from "./logEntries";
import {
  getSettings,
  updateModelLoadingProgress,
  updateModelSizeInMegabytes,
  updateResponse,
  updateTextGenerationState,
} from "./pubSub";
import {
  canStartResponding,
  defaultContextSize,
  getDefaultChatCompletionCreateParamsStreaming,
  getDefaultChatMessages,
  getFormattedSearchResults,
  handleStreamingResponse,
} from "./textGenerationUtilities";
import type { ChatMessage } from "./types";
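
/**
 * Generates an AI response for the current search results with WebLLM.
 * Skips generation when the AI response setting is disabled, but always
 * logs runtime stats and unloads the engine when done.
 */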
export async function generateTextWithWebLlm() {
  const engine = await initializeWebLlmEngine();

  if (getSettings().enableAiResponse) {
    await canStartResponding();
    updateTextGenerationState("preparingToGenerate");

    const completion = await engine.chat.completions.create({
      ...getDefaultChatCompletionCreateParamsStreaming(),
      messages: getDefaultChatMessages(
        getFormattedSearchResults(true),
      ) as ChatCompletionMessageParam[],
    });

    await handleStreamingResponse(completion, updateResponse, {
      shouldUpdateGeneratingState: true,
    });
  }

  addLogEntry(
    `WebLLM finished generating the response. Stats: ${await engine.runtimeStatsText()}`,
  );

  await engine.unload();
}
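
/**
 * Streams a chat completion for the given message history with WebLLM,
 * calling `onUpdate` with each partial response, and returns the full text.
 */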
export async function generateChatWithWebLlm(
  messages: ChatMessage[],
  onUpdate: (partialResponse: string) => void,
) {
  const engine = await initializeWebLlmEngine();

  const completion = await engine.chat.completions.create({
    ...getDefaultChatCompletionCreateParamsStreaming(),
    messages: messages as ChatCompletionMessageParam[],
  });

  const response = await handleStreamingResponse(completion, onUpdate);

  addLogEntry(
    `WebLLM finished generating the chat response. Stats: ${await engine.runtimeStatsText()}`,
  );

  await engine.unload();

  return response;
}
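
/**
 * Lazily imports @mlc-ai/web-llm and creates an engine for the model
 * selected in settings, reporting download progress only when the model
 * is not already cached.
 */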
async function initializeWebLlmEngine() {
  const {
    CreateWebWorkerMLCEngine,
    CreateMLCEngine,
    hasModelInCache,
    prebuiltAppConfig,
  } = await import("@mlc-ai/web-llm");

  const selectedModelId = getSettings().webLlmModelId;

  // Surface the model's VRAM requirement (in MB) so the UI can show its size.
  updateModelSizeInMegabytes(
    prebuiltAppConfig.model_list.find((m) => m.model_id === selectedModelId)
      ?.vram_required_MB || 0,
  );

  addLogEntry(`Selected WebLLM model: ${selectedModelId}`);

  // Only report loading progress when the model still has to be downloaded.
  const isModelCached = await hasModelInCache(selectedModelId);

  let initProgressCallback: InitProgressCallback | undefined;

  if (!isModelCached) {
    initProgressCallback = (report) => {
      updateModelLoadingProgress(Math.round(report.progress * 100));
    };
  }

  const engineConfig: MLCEngineConfig = {
    initProgressCallback,
    logLevel: "SILENT",
  };

  // Use the full context window (-1 leaves sliding-window attention and
  // attention sinks disabled).
  const chatOptions: ChatOptions = {
    context_window_size: defaultContextSize,
    sliding_window_size: -1,
    attention_sink_size: -1,
  };
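
  // Prefer running the engine in a Web Worker so model inference does not
  // block the main thread; fall back to a main-thread engine when Workers
  // are unavailable.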
  return typeof Worker !== "undefined"
    ? await CreateWebWorkerMLCEngine(
        new Worker(new URL("./webLlmWorker.ts", import.meta.url), {
          type: "module",
        }),
        selectedModelId,
        engineConfig,
        chatOptions,
      )
    : await CreateMLCEngine(selectedModelId, engineConfig, chatOptions);
}