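// Browser-side benchmark runners for @huggingface/transformers pipelines.
// Each run measures pipeline load time, first-inference latency, and
// subsequent-inference latencies, in cold (caches cleared) or warm
// (prefetched) mode, and reports them as a BenchmarkResult.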
import { pipeline } from "@huggingface/transformers";
import { BenchmarkRawResult, aggregateMetrics } from "../core/metrics.js";
import { BenchmarkResult } from "../core/types.js";
import { clearCaches } from "./cache.js";
import { getBrowserEnvInfo } from "./envinfo.js";
import { getTaskInput } from "../core/task-inputs.js";
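
// Millisecond timestamp from the browser's high-resolution monotonic clock.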
function now() {
return performance.now();
}
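
// Classify an error from its message so that out-of-memory and network
// failures can be told apart from generic runtime errors. The substring
// checks are heuristics matched against messages the browser runtime is
// known to emit.
function classifyError(error: any): { type: string; message: string } {
  const message = error?.message || String(error);
  let type = "runtime_error";
  if (message.includes("Aborted") || message.includes("out of memory")) {
    type = "memory_error";
  } else if (message.includes("Failed to fetch") || message.includes("network")) {
    type = "network_error";
  }
  return { type, message };
}

/**
 * Run one benchmark iteration: construct the pipeline, time the first
 * inference, then time three subsequent inferences. Returns raw timings in
 * milliseconds, or a structured error annotated with the failing stage.
 */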
async function benchOnce(
modelId: string,
task: string,
device: string,
dtype: string | undefined,
batchSize: number
): Promise<BenchmarkRawResult | { error: { type: string; message: string; stage: "load" | "inference" } }> {
  // Track how far we got so a thrown error can be attributed to the right stage.
  let stage: "load" | "inference" = "load";
  try {
    const t0 = now();
    const options: any = { device };
    if (dtype) options.dtype = dtype;
    const pipe = await pipeline(task, modelId, options);
    const t1 = now();
    stage = "inference";
    // Build task-appropriate inputs for the requested batch size.
    const { inputs, options: taskOptions } = getTaskInput(task, batchSize);
const t2 = now();
await pipe(inputs, taskOptions);
const t3 = now();
    // Run three more inferences to measure steady-state (post-first-run) latency.
const subsequentTimes: number[] = [];
for (let i = 0; i < 3; i++) {
const t4 = now();
await pipe(inputs, taskOptions);
const t5 = now();
subsequentTimes.push(+(t5 - t4).toFixed(1));
}
return {
load_ms: +(t1 - t0).toFixed(1),
first_infer_ms: +(t3 - t2).toFixed(1),
subsequent_infer_ms: subsequentTimes,
};
  } catch (error: any) {
    // Attribute the failure to the stage that was in progress when it was thrown.
    const { type, message } = classifyError(error);
    return {
      error: { type, message, stage },
    };
}
}
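
/**
 * Cold-mode benchmark: caches are cleared before the first iteration, so it
 * includes model download and initialization. Later iterations within the
 * same page session may hit browser caches, hence the note in the result.
 */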
export async function runWebBenchmarkCold(
modelId: string,
task: string,
repeats: number,
device: string,
dtype?: string,
batchSize: number = 1
): Promise<BenchmarkResult> {
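  // Clear caches up front so the first iteration starts cold.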
await clearCaches();
const results: BenchmarkRawResult[] = [];
let error: { type: string; message: string; stage: "load" | "inference" } | undefined;
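  // Each repeat reloads the pipeline, but only the first iteration is strictly
  // cold: caches repopulate over the rest of the page session.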
for (let i = 0; i < repeats; i++) {
const r = await benchOnce(modelId, task, device, dtype, batchSize);
    if ("error" in r) {
error = r.error;
break;
}
results.push(r);
}
const envInfo = await getBrowserEnvInfo();
const result: BenchmarkResult = {
platform: "browser",
runtime: navigator.userAgent,
mode: "cold",
repeats,
batchSize,
model: modelId,
task,
device,
environment: envInfo,
    notes: "Only the first iteration is strictly cold in a single page session.",
};
if (error) {
result.error = error;
} else {
const metrics = aggregateMetrics(results);
result.metrics = metrics;
}
if (dtype) result.dtype = dtype;
return result;
}
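
/**
 * Warm-mode benchmark: the pipeline is constructed and run once up front so
 * model files are cached and the measured iterations start warm.
 */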
export async function runWebBenchmarkWarm(
modelId: string,
task: string,
repeats: number,
device: string,
dtype?: string,
batchSize: number = 1
): Promise<BenchmarkResult> {
let error: { type: string; message: string; stage: "load" | "inference" } | undefined;
  // Prefetch the model and run one warmup inference outside the measured loop.
try {
const options: any = { device };
if (dtype) options.dtype = dtype;
const p = await pipeline(task, modelId, options);
const { inputs: warmupInputs, options: taskOptions } = getTaskInput(task, batchSize);
await p(warmupInputs, taskOptions);
  } catch (err: any) {
    // Any warmup failure is reported as a load-stage error.
    const { type, message } = classifyError(err);
    error = { type, message, stage: "load" };
}
const results: BenchmarkRawResult[] = [];
if (!error) {
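    // Measured iterations: model files are already cached, so load_ms here
    // reflects warm (cache-hit) pipeline construction.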
for (let i = 0; i < repeats; i++) {
const r = await benchOnce(modelId, task, device, dtype, batchSize);
      if ("error" in r) {
error = r.error;
break;
}
results.push(r);
}
}
const envInfo = await getBrowserEnvInfo();
const result: BenchmarkResult = {
platform: "browser",
runtime: navigator.userAgent,
mode: "warm",
repeats,
batchSize,
model: modelId,
task,
device,
environment: envInfo,
};
if (error) {
result.error = error;
} else {
const metrics = aggregateMetrics(results);
result.metrics = metrics;
}
if (dtype) result.dtype = dtype;
return result;
}
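
// Illustrative usage (the model id, task, device, and dtype below are
// examples, not the only supported values):
//   const result = await runWebBenchmarkWarm(
//     "Xenova/all-MiniLM-L6-v2", // model
//     "feature-extraction",      // task
//     3,                         // repeats
//     "wasm",                    // device
//     "q8"                       // dtype (optional)
//   );
//   console.log(result.metrics);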