whitphx HF Staff commited on
Commit
bf33293
·
1 Parent(s): 5d65665

update client to accept multiple queries

Browse files
Files changed (1) hide show
  1. client/src/index.ts +45 -13
client/src/index.ts CHANGED
@@ -269,7 +269,7 @@ yargs(hideBin(process.argv))
269
  }
270
  )
271
  .command(
272
- "batch <task> [query]",
273
  "Search HuggingFace models and submit benchmarks for them",
274
  (yargs) => {
275
  return yargs
@@ -279,8 +279,9 @@ yargs(hideBin(process.argv))
279
  demandOption: true,
280
  })
281
  .positional("query", {
282
- describe: "Optional search query to filter model names",
283
  type: "string",
 
284
  })
285
  .option("limit", {
286
  describe: "Maximum number of models to benchmark",
@@ -329,21 +330,52 @@ yargs(hideBin(process.argv))
329
  });
330
  },
331
  async (argv) => {
332
- console.log(`Searching for ${argv.task} models${argv.query ? ` matching "${argv.query}"` : ""}...\n`);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
- const models = await searchModels({
335
- task: argv.task as keyof typeof PIPELINE_DATA,
336
- search: argv.query,
337
- limit: argv.limit,
338
- });
 
 
 
 
 
 
 
 
 
 
 
339
 
340
- if (models.length === 0) {
341
  console.log("No models found.");
342
  return;
343
  }
344
 
345
- console.log(`Found ${models.length} models:\n`);
346
- models.forEach((model, index) => {
347
  console.log(`${index + 1}. ${formatModel(model)}`);
348
  });
349
 
@@ -365,7 +397,7 @@ yargs(hideBin(process.argv))
365
  dtype: string;
366
  }> = [];
367
 
368
- for (const model of models) {
369
  for (const platform of platforms) {
370
  for (const mode of modes) {
371
  for (const batchSize of batchSizes) {
@@ -390,7 +422,7 @@ yargs(hideBin(process.argv))
390
  }
391
 
392
  console.log(`\n📊 Benchmark Plan:`);
393
- console.log(` Models: ${models.length}`);
394
  console.log(` Platforms: ${platforms.join(", ")}`);
395
  console.log(` Modes: ${modes.join(", ")}`);
396
  console.log(` Batch Sizes: ${batchSizes.join(", ")}`);
 
269
  }
270
  )
271
  .command(
272
+ "batch <task> [query...]",
273
  "Search HuggingFace models and submit benchmarks for them",
274
  (yargs) => {
275
  return yargs
 
279
  demandOption: true,
280
  })
281
  .positional("query", {
282
+ describe: "Optional search queries to filter model names (can specify multiple)",
283
  type: "string",
284
+ array: true,
285
  })
286
  .option("limit", {
287
  describe: "Maximum number of models to benchmark",
 
330
  });
331
  },
332
  async (argv) => {
333
+ const queries = argv.query && argv.query.length > 0 ? argv.query : undefined;
334
+ const queryText = queries && queries.length > 0
335
+ ? ` matching [${queries.join(", ")}]`
336
+ : "";
337
+
338
+ console.log(`Searching for ${argv.task} models${queryText}...\n`);
339
+
340
+ let allModels: ModelEntry[] = [];
341
+
342
+ if (queries && queries.length > 0) {
343
+ // Search with each query and combine results
344
+ const modelSets: ModelEntry[][] = [];
345
+ for (const query of queries) {
346
+ const models = await searchModels({
347
+ task: argv.task as keyof typeof PIPELINE_DATA,
348
+ search: query,
349
+ limit: argv.limit,
350
+ });
351
+ modelSets.push(models);
352
+ console.log(` Found ${models.length} models for query "${query}"`);
353
+ }
354
 
355
+ // Deduplicate models by ID
356
+ const modelMap = new Map<string, ModelEntry>();
357
+ for (const models of modelSets) {
358
+ for (const model of models) {
359
+ modelMap.set(model.id, model);
360
+ }
361
+ }
362
+ allModels = Array.from(modelMap.values());
363
+ console.log(` Total unique models: ${allModels.length}\n`);
364
+ } else {
365
+ // No query specified, search all
366
+ allModels = await searchModels({
367
+ task: argv.task as keyof typeof PIPELINE_DATA,
368
+ limit: argv.limit,
369
+ });
370
+ }
371
 
372
+ if (allModels.length === 0) {
373
  console.log("No models found.");
374
  return;
375
  }
376
 
377
+ console.log(`Found ${allModels.length} models:\n`);
378
+ allModels.forEach((model, index) => {
379
  console.log(`${index + 1}. ${formatModel(model)}`);
380
  });
381
 
 
397
  dtype: string;
398
  }> = [];
399
 
400
+ for (const model of allModels) {
401
  for (const platform of platforms) {
402
  for (const mode of modes) {
403
  for (const batchSize of batchSizes) {
 
422
  }
423
 
424
  console.log(`\n📊 Benchmark Plan:`);
425
+ console.log(` Models: ${allModels.length}`);
426
  console.log(` Platforms: ${platforms.join(", ")}`);
427
  console.log(` Modes: ${modes.join(", ")}`);
428
  console.log(` Batch Sizes: ${batchSizes.join(", ")}`);