whitphx's picture
whitphx HF Staff
Add server and its client utility
0310320
raw
history blame
6.39 kB
#!/usr/bin/env node
import yargs from "yargs";
import { hideBin } from "yargs/helpers";
/**
 * Base URL of the benchmark server.
 *
 * Taken from `BENCH_SERVER_URL`; an unset *or empty* variable falls back to a
 * local server (hence `||` rather than `??`). Trailing slashes are stripped
 * because every request path in this client starts with "/", and a value like
 * "http://host:3000/" would otherwise produce "http://host:3000//api/...".
 */
const SERVER_URL = (process.env.BENCH_SERVER_URL || "http://localhost:3000").replace(/\/+$/, "");
/**
 * Request body sent as JSON to `POST /api/benchmark` by `submitBenchmark`.
 *
 * Only `modelId` and `task` are required. Optional members left `undefined`
 * are dropped by `JSON.stringify`, so the server decides their values
 * (server-side defaults are not visible here — confirm against the server).
 */
interface SubmitOptions {
  /** Model identifier to benchmark. */
  modelId: string;
  /** Task to perform, e.g. "feature-extraction" or "fill-mask". */
  task: string;
  /** Runtime to run the benchmark on. */
  platform?: "node" | "web";
  /** Cache mode. */
  mode?: "warm" | "cold";
  /** Number of times to repeat the benchmark. */
  repeats?: number;
  /** Data type, e.g. "fp32", "fp16", "q8". */
  dtype?: string;
  /** Batch size for inference. */
  batchSize?: number;
  /** Device for the web platform. */
  device?: string;
  /** Browser for the web platform. */
  browser?: string;
  /** Run the browser in headed mode. */
  headed?: boolean;
}
/**
 * Queues a benchmark run on the server.
 *
 * @param options - Request body, POSTed as JSON to `/api/benchmark`.
 * @returns The parsed JSON response (the CLI reads `id` and `position` from it).
 * @throws Error on a non-2xx response; the message includes the numeric status
 *   and, when present, the (truncated) response body.
 */
async function submitBenchmark(options: SubmitOptions) {
  const response = await fetch(`${SERVER_URL}/api/benchmark`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(options),
  });
  if (!response.ok) {
    // statusText may be empty (e.g. HTTP/2 carries no reason phrase), so the
    // numeric status is always reported; the body may explain the rejection.
    const status = `${response.status} ${response.statusText}`.trim();
    const body = (await response.text().catch(() => "")).trim().slice(0, 500);
    throw new Error(`Failed to submit benchmark: ${status}${body ? ` - ${body}` : ""}`);
  }
  return await response.json();
}
/**
 * Fetches a single benchmark record by ID.
 *
 * @param id - Benchmark ID (as returned by `submitBenchmark`). It is
 *   URL-encoded, so characters like "/", "?" or "#" cannot alter the path.
 * @returns The parsed JSON record (callers read `status`, `error`, `result`).
 * @throws Error on a non-2xx response; the message includes the numeric status
 *   and, when present, the (truncated) response body.
 */
async function getBenchmark(id: string) {
  const response = await fetch(`${SERVER_URL}/api/benchmark/${encodeURIComponent(id)}`);
  if (!response.ok) {
    // statusText may be empty (e.g. HTTP/2 carries no reason phrase), so the
    // numeric status is always reported; the body may explain the failure.
    const status = `${response.status} ${response.statusText}`.trim();
    const body = (await response.text().catch(() => "")).trim().slice(0, 500);
    throw new Error(`Failed to get benchmark: ${status}${body ? ` - ${body}` : ""}`);
  }
  return await response.json();
}
/**
 * Fetches every benchmark record known to the server.
 *
 * @returns The parsed JSON from `GET /api/benchmarks` (the CLI reads `total`
 *   and `results` from it).
 * @throws Error on a non-2xx response; the message includes the numeric status
 *   and, when present, the (truncated) response body.
 */
async function listBenchmarks() {
  const response = await fetch(`${SERVER_URL}/api/benchmarks`);
  if (!response.ok) {
    // statusText may be empty (e.g. HTTP/2 carries no reason phrase), so the
    // numeric status is always reported; the body may explain the failure.
    const status = `${response.status} ${response.statusText}`.trim();
    const body = (await response.text().catch(() => "")).trim().slice(0, 500);
    throw new Error(`Failed to list benchmarks: ${status}${body ? ` - ${body}` : ""}`);
  }
  return await response.json();
}
/**
 * Fetches the server's queue summary.
 *
 * @returns The parsed JSON from `GET /api/queue` (the CLI reads the `status`
 *   counters and the `queue` array from it).
 * @throws Error on a non-2xx response; the message includes the numeric status
 *   and, when present, the (truncated) response body.
 */
async function getQueueStatus() {
  const response = await fetch(`${SERVER_URL}/api/queue`);
  if (!response.ok) {
    // statusText may be empty (e.g. HTTP/2 carries no reason phrase), so the
    // numeric status is always reported; the body may explain the failure.
    const status = `${response.status} ${response.statusText}`.trim();
    const body = (await response.text().catch(() => "")).trim().slice(0, 500);
    throw new Error(`Failed to get queue status: ${status}${body ? ` - ${body}` : ""}`);
  }
  return await response.json();
}
/**
 * Polls `getBenchmark(id)` until the benchmark reaches a terminal status.
 *
 * The first check happens immediately. While the run is not finished, the
 * current status is logged and the server is re-queried every `interval` ms.
 *
 * @param id - Benchmark ID returned by `submitBenchmark`.
 * @param interval - Delay between checks, in milliseconds.
 * @param timeoutMs - Stop waiting after this many milliseconds. The default
 *   waits forever, matching the previous behavior.
 * @returns The benchmark record once its status is "completed".
 * @throws Error in three cases: the status is "failed" (the message is the
 *   server's `error` text, or a fallback when that is missing), the timeout
 *   elapses, or `getBenchmark` itself throws.
 */
async function pollBenchmark(id: string, interval = 2000, timeoutMs = Infinity): Promise<any> {
  const deadline = Date.now() + timeoutMs;
  for (;;) {
    const result = await getBenchmark(id);
    if (result.status === "completed") {
      return result;
    }
    if (result.status === "failed") {
      // Without a fallback, a failure lacking an `error` field would surface
      // as an Error with an empty message.
      throw new Error(result.error ? String(result.error) : `Benchmark ${id} failed`);
    }
    if (Date.now() >= deadline) {
      throw new Error(`Timed out waiting for benchmark ${id} (last status: ${result.status})`);
    }
    console.log(`Status: ${result.status}...`);
    await new Promise((resolve) => setTimeout(resolve, interval));
  }
}
// CLI entry point: parses process.argv with yargs and dispatches to one of the
// commands below (submit, get, list, queue).
// NOTE(review): this chain is top-level code, so it also runs when the module
// is merely imported for the helpers exported on the last line. It would then
// parse the importer's argv, and `.demandCommand(1, ...)` / `.strict()` may
// print usage and exit. Consider guarding it behind a "run as main script" check.
yargs(hideBin(process.argv))
// `submit`: queue a benchmark; with --wait, poll until it finishes.
.command(
"submit <modelId> <task>",
"Submit a new benchmark request",
(yargs) => {
return yargs
.positional("modelId", {
describe: "Model ID to benchmark",
type: "string",
demandOption: true,
})
.positional("task", {
describe: "Task to perform (e.g., feature-extraction, fill-mask)",
type: "string",
demandOption: true,
})
.option("platform", {
describe: "Platform to run on",
choices: ["node", "web"] as const,
default: "node" as const,
})
.option("mode", {
describe: "Cache mode",
choices: ["warm", "cold"] as const,
default: "warm" as const,
})
.option("repeats", {
describe: "Number of times to repeat the benchmark",
type: "number",
default: 3,
})
// Declared as --batch-size; the handler reads it back as argv.batchSize
// (yargs camel-case expansion).
.option("batch-size", {
describe: "Batch size for inference",
type: "number",
default: 1,
})
// The only submit option without a default; see the handler.
.option("dtype", {
describe: "Data type (fp32, fp16, q8, etc.)",
type: "string",
})
// --device, --browser and --headed have defaults and are copied into every
// request, even when --platform is "node".
.option("device", {
describe: "Device for web platform",
type: "string",
default: "webgpu",
})
.option("browser", {
describe: "Browser for web platform",
choices: ["chromium", "firefox", "webkit"] as const,
default: "chromium" as const,
})
.option("headed", {
describe: "Run browser in headed mode",
type: "boolean",
default: false,
})
// Client-side only: controls polling here and is not sent to the server.
.option("wait", {
describe: "Wait for benchmark completion",
type: "boolean",
default: false,
});
},
async (argv) => {
const options: SubmitOptions = {
modelId: argv.modelId,
task: argv.task,
platform: argv.platform,
mode: argv.mode,
repeats: argv.repeats,
batchSize: argv.batchSize,
device: argv.device,
browser: argv.browser,
headed: argv.headed,
};
// dtype has no default, so it is included in the request only when given.
if (argv.dtype) {
options.dtype = argv.dtype;
}
console.log("Submitting benchmark...");
const result = await submitBenchmark(options);
console.log(`✓ Benchmark queued: ${result.id}`);
console.log(` Position in queue: ${result.position}`);
if (argv.wait) {
console.log("\nWaiting for completion...");
// pollBenchmark rejects if the run ends in "failed"; that rejection
// propagates out of this async handler.
const completed = await pollBenchmark(result.id);
console.log("\n✅ Benchmark completed!");
console.log(JSON.stringify(completed.result, null, 2));
} else {
console.log(`\nCheck status with: bench-client get ${result.id}`);
}
}
)
// `get <id>`: print one benchmark record as pretty JSON.
.command(
"get <id>",
"Get benchmark result by ID",
(yargs) => {
return yargs.positional("id", {
describe: "Benchmark ID",
type: "string",
demandOption: true,
});
},
async (argv) => {
const result = await getBenchmark(argv.id);
console.log(JSON.stringify(result, null, 2));
}
)
// `list`: print one summary line per benchmark, using the `total` and
// `results` fields of the server response.
.command(
"list",
"List all benchmark results",
() => {},
async () => {
const result = await listBenchmarks();
console.log(`Total benchmarks: ${result.total}\n`);
result.results.forEach((b: any) => {
console.log(`${b.id} - ${b.status} - ${b.platform}/${b.modelId}/${b.task}`);
});
}
)
// `queue`: print the server's status counters, then the queued entries if
// the `queue` array is non-empty.
.command(
"queue",
"Show queue status",
() => {},
async () => {
const result = await getQueueStatus();
console.log("Queue Status:");
console.log(` Pending: ${result.status.pending}`);
console.log(` Running: ${result.status.running}`);
console.log(` Completed: ${result.status.completed}`);
console.log(` Failed: ${result.status.failed}`);
if (result.queue.length > 0) {
console.log("\nCurrent Queue:");
result.queue.forEach((b: any) => {
console.log(` [${b.status}] ${b.id} - ${b.platform}/${b.modelId}`);
});
}
}
)
// Require a command and reject unknown commands/options.
.demandCommand(1, "You need to specify a command")
.help()
.alias("h", "help")
.strict()
// NOTE(review): the handlers are async and the value returned by parse() is
// not awaited here. Confirm that handler errors (e.g. network failures)
// reach the user with a non-zero exit code.
.parse();
// Programmatic API. Importing this module also runs the CLI chain above (see
// the NOTE at its top).
export { submitBenchmark, getBenchmark, listBenchmarks, getQueueStatus, pollBenchmark };