Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| /** | |
| * LeaderboardTable.tsx | |
| * | |
| * This component displays a structured table with hierarchical data (groups, subgroups, metrics) | |
| * and provides two independent sorting mechanisms: | |
| * | |
| * 1. Row Sorting: Clicking on a column header (model name) sorts the rows | |
| * - Implemented manually: rows are ordered from rowSortState while the table data is built (TanStack Table only renders the result) | |
| * - Controls which rows appear first in the table | |
| * - Groups sort against other groups based on their values | |
| * - Subgroups stay with their parent group but sort within the group | |
| * | |
| * 2. Column Sorting: Clicking on a row header sorts the columns (models) | |
| * - Custom implementation using modelOrderByOverallMetric | |
| * - Controls the order of models (columns) for each metric | |
| * - Completely independent of row sorting | |
| * | |
| * Both sorting mechanisms operate independently and can be used simultaneously. | |
| */ | |
| import React, { useEffect, useState, useMemo, useCallback, useRef } from 'react' | |
| import { ArrowDownTrayIcon } from '@heroicons/react/24/solid' | |
| import QualityMetricsTable from './QualityMetricsTable' | |
| import MetricInfoIcon from './MetricInfoIcon' | |
| import Descriptions from '../Descriptions' | |
| import { | |
| createColumnHelper, | |
| flexRender, | |
| getCoreRowModel, | |
| useReactTable, | |
| ColumnDef, | |
| } from '@tanstack/react-table' | |
/** Props accepted by the LeaderboardTable component. */
interface LeaderboardTableProps {
  /**
   * Raw benchmark payload. The component reads its `rows` and `groups`
   * entries when parsing; the exact schema is not declared in this file.
   */
  benchmarkData: any
  /** Model names whose columns should be shown. */
  selectedModels: Set<string>
}
/**
 * Raw data row as delivered in the benchmark payload: the metric identifier
 * plus arbitrary further keys (every key other than `metric` is treated as a
 * column header by the parser).
 */
interface OriginalRow extends Record<string, string | number> {
  metric: string
}
/** Hierarchy level a rendered table row belongs to. */
type TableRowKind = 'group' | 'subgroup' | 'metric'

/**
 * Row of the structured, hierarchical table handed to TanStack Table.
 * Besides the fixed fields below, a row carries dynamic `${metric}-${model}`
 * keys holding the aggregated cell values.
 */
interface TableRow {
  /** Unique row identifier. */
  id: string
  /** Hierarchy level of the row. */
  type: TableRowKind
  /** Parent group name (set on subgroup rows). */
  groupId?: string
  /** Parent subgroup name, when applicable. */
  subgroupId?: string
  /** Underlying metric name, when applicable. */
  metricName?: string
  /** Label shown in the first column. */
  name: string
  visible: boolean
  /** Nesting level of the row. */
  depth: number
  /** Whether the row's children are currently shown. */
  isExpanded?: boolean
  /** Dynamic per-column values. */
  [key: string]: any
}
/** Metric names bucketed by group name, then by subgroup name. */
type Groups = Record<string, Record<string, string[]>>
/** Direction shared by both sorting mechanisms. */
type SortDirection = 'asc' | 'desc'

/** Active row sort (set by clicking a column header). */
interface RowSortState {
  /** Identifier of the column whose values order the rows. */
  columnId: string
  direction: SortDirection
}

/** Active column sort (set by clicking a row header). */
interface ColumnSortState {
  /** Composite key of the row whose values order the model columns. */
  rowKey: string
  direction: SortDirection
}
/** Name of the group that is always listed first and seeds the overall metric names. */
const OVERALL_ROW = 'Overall'

/** Overall metrics that are ticked by default whenever benchmark data is (re)loaded. */
const DEFAULT_SELECTED_METRICS = new Set<string>(['log10_p_value'])
/**
 * Checkbox grid for choosing which overall metrics are displayed.
 *
 * The legend shows "selected / available" counts. Toggling a checkbox hands a
 * new Set to `setSelectedOverallMetrics`; the Set passed in is never mutated.
 */
const OverallMetricFilter: React.FC<{
  overallMetrics: string[]
  selectedOverallMetrics: Set<string>
  setSelectedOverallMetrics: (metrics: Set<string>) => void
}> = (props) => {
  const { overallMetrics, selectedOverallMetrics, setSelectedOverallMetrics } = props

  // Flip membership of one metric on a copy of the current selection
  const handleToggle = (name: string) => {
    const next = new Set(selectedOverallMetrics)
    // delete() reports whether the entry existed; if it did not, this is a "check" action
    if (!next.delete(name)) {
      next.add(name)
    }
    setSelectedOverallMetrics(next)
  }

  const checkboxes = overallMetrics.map((name) => (
    <label key={name} className="flex items-center gap-2 text-sm">
      <input
        type="checkbox"
        className="form-checkbox h-4 w-4"
        checked={selectedOverallMetrics.has(name)}
        onChange={() => handleToggle(name)}
      />
      <div className="flex items-center truncate">
        <span className="truncate" title={name}>
          {name}
        </span>
        <MetricInfoIcon metricName={name} />
      </div>
    </label>
  ))

  return (
    <div className="w-full">
      <fieldset className="fieldset w-full p-4 rounded border border-gray-700 bg-base-200">
        <legend className="fieldset-legend font-semibold">
          Metrics ({selectedOverallMetrics.size}/{overallMetrics.length})
        </legend>
        <div className="grid grid-cols-2 md:grid-cols-4 lg:grid-cols-6 gap-1 max-h-48 overflow-y-auto pr-2">
          {checkboxes}
        </div>
      </fieldset>
    </div>
  )
}
| const LeaderboardTable: React.FC<LeaderboardTableProps> = ({ benchmarkData, selectedModels }) => { | |
// Raw rows from the benchmark payload and the column headers derived from them
const [rawRows, setRawRows] = useState<OriginalRow[]>([])
const [tableHeader, setTableHeader] = useState<string[]>([])
// Parse error shown instead of the table (null when parsing succeeded)
const [error, setError] = useState<string | null>(null)
// Metrics bucketed by group and subgroup, plus the expanded/collapsed flag per group
const [groupRows, setGroupRows] = useState<Groups>({})
const [openGroupRows, setOpenGroupRows] = useState<{ [key: string]: boolean }>({})
// Flips to true once the Descriptions singleton has finished loading
const [descriptionsLoaded, setDescriptionsLoaded] = useState(false)
const descriptions = useRef(Descriptions.getInstance())
// Individual metrics that may contribute to aggregates
const [selectedMetrics, setSelectedMetrics] = useState<Set<string>>(new Set())
// Overall metric names and the subset currently ticked in the filter
const [overallMetrics, setOverallMetrics] = useState<string[]>([])
// Start from a private copy so the module-level default Set is never shared as state
const [selectedOverallMetrics, setSelectedOverallMetrics] = useState<Set<string>>(
  () => new Set(DEFAULT_SELECTED_METRICS)
)
// Independent sort states: rows (column header clicks) and columns (row header clicks)
const [rowSortState, setRowSortState] = useState<RowSortState | null>(null)
const [columnSortState, setColumnSortState] = useState<ColumnSortState | null>(null)
// Model (column) order per overall metric
const [modelOrderByOverallMetric, setModelOrderByOverallMetric] = useState<{
  [key: string]: string[]
}>({})
// Columns to display: header order from the parsed data, restricted to the selected models
const models = useMemo(
  () => tableHeader.filter((name) => selectedModels.has(name)),
  [tableHeader, selectedModels]
)
// Load descriptions once on mount
useEffect(() => {
  // Guards against setting state after the component has been unmounted
  let cancelled = false
  descriptions.current
    .load()
    .then(() => {
      if (!cancelled) {
        setDescriptionsLoaded(true)
      }
    })
    .catch((err: unknown) => {
      // Without descriptions the headers fall back to raw model names; surface the
      // failure instead of leaving an unhandled promise rejection.
      console.error('Failed to load descriptions', err)
    })
  return () => {
    cancelled = true
  }
}, [])
// Parse benchmark data when it changes.
// Everything is derived first and committed to state only at the end, so a
// payload that fails half-way through parsing cannot leave partly updated state.
useEffect(() => {
  if (!benchmarkData) {
    return
  }
  try {
    const rows: OriginalRow[] = benchmarkData['rows']
    const allGroups = benchmarkData['groups'] as { [key: string]: string[] }

    // Overall metric names: entries of the Overall group with the first '_' segment removed
    const uniqueMetrics = new Set<string>()
    allGroups[OVERALL_ROW]?.forEach((metric) => {
      if (metric.includes('_')) {
        uniqueMetrics.add(metric.split('_').slice(1).join('_'))
      }
    })

    // Groups with Overall first and the rest alphabetical; inside each group the
    // metrics are bucketed by their first '_' segment and the buckets sorted by name
    const groupsData: { [key: string]: { [key: string]: string[] } } = {}
    Object.entries(allGroups)
      .sort(([groupA], [groupB]) => {
        if (groupA === OVERALL_ROW) return -1
        if (groupB === OVERALL_ROW) return 1
        return groupA.localeCompare(groupB)
      })
      .forEach(([group, metrics]) => {
        const buckets: { [key: string]: string[] } = {}
        for (const metric of [...metrics].sort()) {
          const bucketName = metric.split('_')[0]
          if (!buckets[bucketName]) {
            buckets[bucketName] = []
          }
          buckets[bucketName].push(metric)
        }
        groupsData[group] = Object.fromEntries(
          Object.entries(buckets).sort(([subGroupA], [subGroupB]) =>
            subGroupA.localeCompare(subGroupB)
          )
        )
      })

    // Column headers: every key that appears in any row, except the metric name itself
    const allKeys: string[] = Array.from(new Set(rows.flatMap((row) => Object.keys(row))))
    const headers = allKeys.filter((key) => key !== 'metric')

    // All groups start collapsed
    const initialOpenGroups: { [key: string]: boolean } = {}
    Object.keys(groupsData).forEach((group) => {
      initialOpenGroups[group] = false
    })

    const allMetrics = Object.values(allGroups).flat()

    // Default model order for every overall metric is the header order
    const metricOrders: { [key: string]: string[] } = {}
    uniqueMetrics.forEach((metric) => {
      metricOrders[metric] = [...headers]
    })

    // Parsing succeeded: commit everything together
    setOverallMetrics(Array.from(uniqueMetrics).sort())
    setSelectedOverallMetrics(new Set(DEFAULT_SELECTED_METRICS))
    setSelectedMetrics(new Set(allMetrics))
    setTableHeader(headers)
    setRawRows(rows)
    setGroupRows(groupsData)
    setOpenGroupRows(initialOpenGroups)
    // Initially order the model columns by the Overall group row, ascending
    setColumnSortState({
      rowKey: getColumnSortRowKey(OVERALL_ROW, null, null),
      direction: 'asc',
    })
    setModelOrderByOverallMetric(metricOrders)
    setError(null)
  } catch (err: unknown) {
    const reason = err instanceof Error ? err.message : String(err)
    setError('Failed to parse benchmark data, please try again: ' + reason)
  }
}, [benchmarkData])
// Cycle the row sort for one metric/model column: unsorted -> ascending -> descending -> unsorted
const handleRowSort = (overallMetric: string, model: string) => {
  const columnId = `${overallMetric}-${model}`
  // Only a sort that is already on this very column can advance in the cycle
  const activeSort = rowSortState && rowSortState.columnId === columnId ? rowSortState : null

  if (!activeSort) {
    setRowSortState({ columnId, direction: 'asc' })
  } else if (activeSort.direction === 'asc') {
    setRowSortState({ columnId, direction: 'desc' })
  } else {
    setRowSortState(null)
  }
}
/**
 * Builds the composite key that identifies the row driving the column sort.
 * The three parts are joined with '||'; a null part contributes an empty string.
 */
function getColumnSortRowKey(
  group: string | null,
  subGroup: string | null,
  metric: string | null
): string {
  const parts = [group, subGroup, metric].map((part) => part ?? '')
  return parts.join('||')
}
// Cycle the column sort driven by one row: unsorted -> ascending -> descending -> unsorted.
// Clearing the sort restores the default (header) model order. Previously the
// reset only ran when `metric` was set, so clearing a group or subgroup sort
// left the columns in their sorted order while the icon showed "unsorted".
const handleColumnSort = (
  group: string | null,
  subGroup: string | null,
  metric: string | null
) => {
  const rowKey = getColumnSortRowKey(group, subGroup, metric)
  // Only a sort that is already on this very row can advance in the cycle
  const activeSort = columnSortState && columnSortState.rowKey === rowKey ? columnSortState : null

  if (!activeSort) {
    setColumnSortState({ rowKey, direction: 'asc' })
    return
  }
  if (activeSort.direction === 'asc') {
    setColumnSortState({ rowKey, direction: 'desc' })
    return
  }

  // Descending -> cleared
  setColumnSortState(null)
  setModelOrderByOverallMetric((prev) => {
    const reset: { [key: string]: string[] } = {}
    Object.keys(prev).forEach((metricKey) => {
      reset[metricKey] = [...models]
    })
    // Keep the earlier behaviour of (re)creating the entry for the clicked metric
    if (metric) {
      reset[metric] = [...models]
    }
    return reset
  })
}
// Find all raw metrics belonging to an overall metric name (like "log10_p_value").
// A raw metric matches when the part after its first '_' ends with the requested
// name. Because that is a suffix match, a short name such as "p_value" would also
// match rows of a longer overall metric such as "log10_p_value"; those rows are
// excluded so each row is attributed to the most specific overall metric.
const findAllMetricsForName = useCallback(
  (metricName: string): string[] => {
    // Longer overall metric names that themselves end with the requested name
    const shadowingNames = overallMetrics.filter(
      (other) => other.length > metricName.length && other.endsWith(metricName)
    )
    return rawRows
      .map((row) => row.metric as string)
      .filter((metric) => {
        if (!metric.includes('_')) {
          return false
        }
        const extractedName = metric.split('_').slice(1).join('_')
        if (!extractedName.endsWith(metricName)) {
          return false
        }
        return !shadowingNames.some((longer) => extractedName.endsWith(longer))
      })
  },
  [rawRows, overallMetrics]
)
// Collect the raw metrics that are not tied to any overall metric, i.e. whose name
// neither equals an overall metric nor ends with "_<overall metric>"
const findQualityMetrics = useCallback((): string[] => {
  const belongsToOverallMetric = (metric: string): boolean =>
    overallMetrics.some((overall) => metric.endsWith(`_${overall}`) || metric === overall)

  return rawRows
    .map((row) => row.metric as string)
    .filter((metric) => !belongsToOverallMetric(metric))
}, [rawRows, overallMetrics])
// Index of raw rows by metric name (first occurrence wins, like Array.find)
const rowByMetric = useMemo(() => {
  const index = new Map<string, OriginalRow>()
  rawRows.forEach((row) => {
    if (!index.has(row.metric)) {
      index.set(row.metric, row)
    }
  })
  return index
}, [rawRows])

// Calculate the average and (population) standard deviation of the given metrics
// for one column. Cells that are missing, null, blank or not numeric are skipped;
// note that Number(null) and Number('') are 0, so they must be filtered explicitly
// or missing data would be averaged in as zeros. Returns NaN for both values when
// no usable cell exists.
const calculateStats = useCallback(
  (metricNames: string[], columnKey: string): { avg: number; stdDev: number } => {
    const toNumber = (raw: unknown): number => {
      if (raw === null || raw === undefined) return NaN
      if (typeof raw === 'string' && raw.trim() === '') return NaN
      return Number(raw)
    }

    const values = metricNames
      .map((metricName) => {
        const row = rowByMetric.get(metricName)
        return row ? toNumber(row[columnKey]) : NaN
      })
      .filter((value) => !isNaN(value))

    if (values.length === 0) return { avg: NaN, stdDev: NaN }

    const avg = values.reduce((sum, val) => sum + val, 0) / values.length
    // Variance divides by N (population), not N - 1
    const variance =
      values.reduce((sum, val) => sum + (val - avg) * (val - avg), 0) / values.length
    return { avg, stdDev: Math.sqrt(variance) }
  },
  [rowByMetric]
)
// Restrict a list of metrics to those that belong to the given group (or to one
// of its subgroups when that subgroup exists) and are currently selected.
// Without a group the list is returned untouched.
const filterMetricsByGroupAndSubgroup = useCallback(
  (
    metricNames: string[],
    group: string | null = null,
    subgroup: string | null = null
  ): string[] => {
    if (!group) {
      return metricNames
    }

    const subgroupsOfGroup = groupRows[group] || {}
    // Unknown subgroup names fall back to the whole group
    const allowed: string[] =
      subgroup && subgroupsOfGroup[subgroup]
        ? subgroupsOfGroup[subgroup]
        : (Object.values(subgroupsOfGroup).flat() as string[])

    return metricNames.filter((name) => allowed.includes(name) && selectedMetrics.has(name))
  },
  [groupRows, selectedMetrics]
)
// Overall metrics currently ticked in the filter, in overallMetrics order
const visibleMetrics = useMemo(
  () => overallMetrics.filter((metric) => selectedOverallMetrics.has(metric)),
  [overallMetrics, selectedOverallMetrics]
)

// Generate the rows for the table: one row per group that has visible metrics,
// followed by its subgroup rows when the group is expanded. Each row carries a
// `${metric}-${model}` entry ({ avg, stdDev } or null) for every visible overall
// metric and every selected model.
const tableData = useMemo(() => {
  const rows: TableRow[] = []

  // Raw metrics per overall metric name, looked up once per recomputation
  const memberCache = new Map<string, Set<string>>()
  const membersOf = (metricName: string): Set<string> => {
    let members = memberCache.get(metricName)
    if (!members) {
      members = new Set(findAllMetricsForName(metricName))
      memberCache.set(metricName, members)
    }
    return members
  }

  // Stats over the candidates that belong to one overall metric, for one model
  const statsFor = (candidates: string[], metricName: string, model: string) => {
    const members = membersOf(metricName)
    return calculateStats(
      candidates.filter((candidate) => members.has(candidate)),
      model
    )
  }

  // Column ids are `${metric}-${model}`. Model names may contain '-', so the id
  // cannot simply be split on every hyphen: match a known overall metric prefix
  // (preferring one whose remainder is a known column) and fall back to
  // splitting at the first hyphen only.
  const splitColumnId = (columnId: string): [string, string] => {
    const prefixes = overallMetrics.filter((name) => columnId.startsWith(`${name}-`))
    const metric =
      prefixes.find((name) => tableHeader.includes(columnId.slice(name.length + 1))) ??
      prefixes[0]
    if (metric !== undefined) {
      return [metric, columnId.slice(metric.length + 1)]
    }
    const cut = columnId.indexOf('-')
    return cut === -1 ? [columnId, ''] : [columnId.slice(0, cut), columnId.slice(cut + 1)]
  }

  // Comparator over two metric lists for the active row sort (null when unsorted)
  const buildComparator = (sort: RowSortState) => {
    const [sortMetric, sortModel] = splitColumnId(sort.columnId)
    const sortValue = (candidates: string[]): number => {
      const { avg } = statsFor(candidates, sortMetric, sortModel)
      // Entries without data sort as the smallest possible value
      return isNaN(avg) ? -Infinity : avg
    }
    return (metricsA: string[], metricsB: string[]): number => {
      const valueA = sortValue(metricsA)
      const valueB = sortValue(metricsB)
      // Explicit tie handling avoids the NaN produced by (-Infinity) - (-Infinity)
      if (valueA === valueB) return 0
      return sort.direction === 'asc' ? valueA - valueB : valueB - valueA
    }
  }
  const compareBySortColumn = rowSortState ? buildComparator(rowSortState) : null

  // Write the aggregated cell values for every visible metric and model into a row
  const fillAggregates = (row: TableRow, candidates: string[]) => {
    visibleMetrics.forEach((metric) => {
      models.forEach((model) => {
        const stats = statsFor(candidates, metric, model)
        row[`${metric}-${model}`] = !isNaN(stats.avg)
          ? { avg: stats.avg, stdDev: stats.stdDev }
          : null
      })
    })
  }

  // --- Manual row sorting using rowSortState: groups compete with other groups ---
  let groupEntries = Object.entries(groupRows)
  if (compareBySortColumn) {
    groupEntries = [...groupEntries].sort(([, subGroupsA], [, subGroupsB]) =>
      compareBySortColumn(Object.values(subGroupsA).flat(), Object.values(subGroupsB).flat())
    )
  }

  groupEntries.forEach(([group, subGroups]) => {
    const allGroupMetrics = Object.values(subGroups).flat()
    const visibleGroupMetrics = filterMetricsByGroupAndSubgroup(allGroupMetrics, group)
    if (visibleGroupMetrics.length === 0) return

    const groupRow: TableRow = {
      id: `group-${group}`,
      type: 'group',
      name: group,
      visible: true,
      depth: 0,
      isExpanded: openGroupRows[group],
    }
    fillAggregates(groupRow, visibleGroupMetrics)
    rows.push(groupRow)

    if (!openGroupRows[group]) return

    // Subgroups stay under their parent: alphabetical by default, re-ordered by
    // value within the group when a row sort is active
    let subGroupEntries = Object.entries(subGroups).sort(([a], [b]) => a.localeCompare(b))
    if (compareBySortColumn) {
      subGroupEntries = [...subGroupEntries].sort(([, metricsA], [, metricsB]) =>
        compareBySortColumn(metricsA, metricsB)
      )
    }

    subGroupEntries.forEach(([subGroup, metrics]) => {
      const visibleSubgroupMetrics = filterMetricsByGroupAndSubgroup(metrics, group, subGroup)
      if (visibleSubgroupMetrics.length === 0) return

      const subgroupRow: TableRow = {
        id: `group-${group}-subgroup-${subGroup}`,
        type: 'subgroup',
        groupId: group,
        name: subGroup,
        visible: true,
        depth: 1,
        isExpanded: false,
      }
      fillAggregates(subgroupRow, visibleSubgroupMetrics)
      rows.push(subgroupRow)
    })
  })

  return rows
}, [
  groupRows,
  openGroupRows,
  overallMetrics,
  visibleMetrics,
  tableHeader,
  models,
  rowSortState,
  findAllMetricsForName,
  calculateStats,
  filterMetricsByGroupAndSubgroup,
])
// Effect: recompute the model (column) order whenever the column sort or the data
// it is based on changes. The order of every selected overall metric is derived
// from the values in the row that drives the sort.
useEffect(() => {
  if (!columnSortState) return

  // Parse out group, subGroup, metric from rowKey (empty parts become null)
  const [group, subGroup, metric] = columnSortState.rowKey.split('||').map((v) => v || null)
  const { direction } = columnSortState
  const metricsToUpdate = Array.from(selectedOverallMetrics)

  // Find the group/subgroup row that was clicked. Metric rows are not part of
  // tableData and are scored from the raw rows below.
  let rowToSort: TableRow | undefined
  if (group && !metric) {
    rowToSort = subGroup
      ? tableData.find(
          (row) => row.type === 'subgroup' && row.groupId === group && row.name === subGroup
        )
      : tableData.find((row) => row.type === 'group' && row.name === group)
  }

  if (!rowToSort && !metric) {
    // The driving row is currently not rendered (e.g. its parent group is
    // collapsed). Keep the existing order: writing a fresh order here on every
    // run, while the effect also re-ran on every order change, made the
    // component update in an endless loop.
    return
  }

  // Score every model for one overall metric
  const scoreModel = (metricName: string, model: string): number => {
    if (rowToSort) {
      // Group/subgroup rows: use the aggregated value already stored in the row
      const value: { avg: number; stdDev: number } | null =
        rowToSort[`${metricName}-${model}`] ?? null
      return value && !isNaN(value.avg) ? value.avg : -Infinity
    }
    if (metric) {
      // Metric rows: average the raw values of all metrics matching the name
      const { avg } = calculateStats(findAllMetricsForName(metric), model)
      return !isNaN(avg) ? avg : -Infinity
    }
    return -Infinity
  }

  const newOrders: { [key: string]: string[] } = {}
  metricsToUpdate.forEach((metricName) => {
    const scored = models.map((model) => ({ model, score: scoreModel(metricName, model) }))
    scored.sort((a, b) => {
      // Explicit tie handling avoids the NaN produced by (-Infinity) - (-Infinity)
      if (a.score === b.score) return 0
      return direction === 'asc' ? a.score - b.score : b.score - a.score
    })
    newOrders[metricName] = scored.map((item) => item.model)
  })

  // Only produce a new state object if some order actually changed; returning
  // `prev` keeps this effect from triggering further renders
  const sameOrder = (a: string[], b: string[]) =>
    a.length === b.length && a.every((v, i) => v === b[i])
  setModelOrderByOverallMetric((prev) => {
    let changed = false
    const next = { ...prev }
    metricsToUpdate.forEach((metricName) => {
      const newOrder = newOrders[metricName] || []
      if (!sameOrder(prev[metricName] || [], newOrder)) {
        next[metricName] = [...newOrder]
        changed = true
      }
    })
    return changed ? next : prev
  })
}, [
  columnSortState,
  models,
  tableData,
  selectedOverallMetrics,
  findAllMetricsForName,
  calculateStats,
])
// Export the rows currently in the table as a CSV download ("leaderboard_metrics.csv").
// Columns follow the on-screen order: selected overall metrics, then the models
// in that metric's current order.
const exportToCsv = () => {
  const exportedMetrics = overallMetrics.filter((metric) => selectedOverallMetrics.has(metric))
  const modelsFor = (metric: string): string[] => modelOrderByOverallMetric[metric] || models

  const formatCell = (value: { avg: number; stdDev: number } | null): string =>
    value ? `${value.avg.toFixed(3)} ± ${value.stdDev.toFixed(3)}` : 'N/A'

  // Header row: first column label, then "<metric> - <model alias or name>"
  const header: (string | number)[] = ['Attack Categories']
  for (const metric of exportedMetrics) {
    for (const model of modelsFor(metric)) {
      header.push(`${metric} - ${descriptions.current.getModelAlias(model) || model}`)
    }
  }

  // One CSV row per table row
  const rows: (string | number)[][] = tableData.map((row) => {
    const csvRow: (string | number)[] = [row.name]
    for (const metric of exportedMetrics) {
      for (const model of modelsFor(metric)) {
        csvRow.push(formatCell(row[`${metric}-${model}`] as { avg: number; stdDev: number } | null))
      }
    }
    return csvRow
  })

  // Quote every cell and double embedded quotes
  const quote = (cell: string | number): string => `"${String(cell).replace(/"/g, '""')}"`
  const csv = [header, ...rows].map((cells) => cells.map(quote).join(',')).join('\n')

  // Trigger the download through a temporary anchor element
  const url = URL.createObjectURL(new Blob([csv], { type: 'text/csv' }))
  const anchor = document.createElement('a')
  anchor.href = url
  anchor.download = 'leaderboard_metrics.csv'
  document.body.appendChild(anchor)
  anchor.click()
  document.body.removeChild(anchor)
  URL.revokeObjectURL(url)
}
// Expand a collapsed group or collapse an expanded one
const toggleGroup = (group: string) => {
  setOpenGroupRows((current) => {
    const wasOpen = current[group]
    return { ...current, [group]: !wasOpen }
  })
}
// Returns the active column sort if it is driven by the given row, otherwise null
function getColumnSort(group: string | null, subGroup: string | null, metric: string | null) {
  if (!columnSortState) {
    return null
  }
  const isDrivingRow = columnSortState.rowKey === getColumnSortRowKey(group, subGroup, metric)
  return isDrivingRow ? columnSortState : null
}
// Prepare columns for TanStack Table: the sticky category column followed by one
// column per (selected overall metric, model) pair in the current model order.
const columns = useMemo<ColumnDef<TableRow, any>[]>(() => {
  const columnHelper = createColumnHelper<TableRow>()
  const cols: ColumnDef<TableRow, any>[] = []

  // Clickable icon that cycles the column sort driven by one row.
  // `noun` only varies the tooltip wording ("this row" / "this subgroup" / "this metric").
  const renderColumnSortToggle = (
    group: string | null,
    subGroup: string | null,
    metric: string | null,
    noun: 'row' | 'subgroup' | 'metric'
  ) => {
    const activeSort = getColumnSort(group, subGroup, metric)
    let title = `Click to sort models by values in this ${noun} (independent of row sorting)`
    let icon = '⇆'
    if (activeSort) {
      if (activeSort.direction === 'asc') {
        title = `Currently sorting models by this ${noun} in ascending order (low to high). Click for descending order.`
        icon = '→'
      } else {
        title = `Currently sorting models by this ${noun} in descending order (high to low). Click to clear sort.`
        icon = '←'
      }
    }
    return (
      <span
        className="ml-1 cursor-pointer font-bold"
        onClick={(e) => {
          // Do not let the click also toggle the surrounding group row
          e.stopPropagation()
          handleColumnSort(group, subGroup, metric)
        }}
        title={title}
      >
        {icon}
      </span>
    )
  }

  cols.push(
    columnHelper.accessor((row) => row.name, {
      id: 'category',
      header: () => 'Attack Categories',
      cell: (info) => {
        const row = info.row.original as TableRow
        if (row.type === 'group') {
          return (
            <div
              className="sticky left-0 font-medium cursor-pointer select-none flex items-center"
              onClick={() => toggleGroup(row.name)}
            >
              <span>{row.isExpanded ? '▼ ' : '▶ '}</span>
              <span className="flex-1">{row.name}</span>{' '}
              {renderColumnSortToggle(row.name, null, null, 'row')}
            </div>
          )
        }
        if (row.type === 'subgroup') {
          return (
            <div className="sticky left-0 pl-6 font-medium flex items-center gap-1">
              <span className="flex-1">{row.name}</span>
              {renderColumnSortToggle(row.groupId!, row.name, null, 'subgroup')}
            </div>
          )
        }
        // Metric row (add column sorting for model order)
        return (
          <div className="sticky left-0 pl-12 font-medium flex items-center gap-1">
            <span className="flex-1">{row.name}</span>
            {renderColumnSortToggle(
              row.groupId ?? null,
              row.subgroupId ?? null,
              row.metricName ?? row.name,
              'metric'
            )}
          </div>
        )
      },
    })
  )

  overallMetrics
    .filter((metric) => selectedOverallMetrics.has(metric))
    .forEach((metric) => {
      const metricModels = modelOrderByOverallMetric[metric] || models
      metricModels.forEach((model: string) => {
        const columnId = `${metric}-${model}`
        cols.push(
          columnHelper.accessor((row) => row[columnId], {
            id: columnId,
            header: () => {
              const isSorted = rowSortState !== null && rowSortState.columnId === columnId
              const direction = rowSortState ? rowSortState.direction : 'desc'
              let title =
                'Click to sort rows by values in this column (subgroups always stay with their parent group)'
              let icon = '⇅'
              if (isSorted) {
                if (direction === 'asc') {
                  title =
                    'Currently sorting rows by this column in ascending order (low to high). Click for descending order.'
                  icon = '↑'
                } else {
                  title =
                    'Currently sorting rows by this column in descending order (high to low). Click to clear sort.'
                  icon = '↓'
                }
              }
              return (
                <div
                  className="cursor-pointer select-none"
                  onClick={() => handleRowSort(metric, model)}
                >
                  {/* Alias once descriptions are loaded, raw model name otherwise */}
                  {(descriptionsLoaded && descriptions.current.getModelAlias(model)) || model}
                  <span className="ml-1 font-bold" title={title}>
                    {icon}
                  </span>
                </div>
              )
            },
            cell: (info) => {
              const value = info.getValue() as { avg: number; stdDev: number } | null
              if (!value) return 'N/A'
              return `${value.avg.toFixed(3)} ± ${value.stdDev.toFixed(3)}`
            },
          })
        )
      })
    })

  return cols
}, [
  selectedOverallMetrics,
  overallMetrics,
  modelOrderByOverallMetric,
  rowSortState,
  columnSortState,
  models,
  // Header renderers read this flag; without it the aliases stayed stale after loading
  descriptionsLoaded,
])
// Create the table instance (core row model only; ordering is done when the data is built)
const table = useReactTable({
  columns,
  data: tableData,
  getCoreRowModel: getCoreRowModel(),
})

// The tables are only shown when there is at least one model, metric and overall metric
const nothingToShow =
  selectedModels.size === 0 || selectedMetrics.size === 0 || visibleMetrics.length === 0

// First header row: one cell per selected overall metric spanning its model columns
const renderMetricHeaderCells = () =>
  overallMetrics
    .filter((metric) => selectedOverallMetrics.has(metric))
    .map((metric) => (
      <th
        key={`header-metric-${metric}`}
        className="sticky top-0 bg-base-100 z-10 text-center text-xs border border-gray-700 select-none"
        colSpan={(modelOrderByOverallMetric[metric] || models).length}
      >
        <div className="flex items-center justify-center">
          <span>{metric}</span>
          <MetricInfoIcon metricName={metric} />
        </div>
      </th>
    ))

// Second header row: the model headers (the category column's header is skipped)
const renderModelHeaderCells = () =>
  table
    .getHeaderGroups()[0]
    .headers.slice(1)
    .map((header) => (
      <th
        key={header.id}
        className="sticky top-12 bg-base-100 z-10 text-center text-xs border border-gray-700"
      >
        {header.isPlaceholder
          ? null
          : flexRender(header.column.columnDef.header, header.getContext())}
      </th>
    ))

// Body rows; group rows get a darker background than subgroup rows
const renderBodyRows = () =>
  table.getRowModel().rows.map((row) => {
    const isGroupRow = row.original.type === 'group'
    const rowClass = isGroupRow
      ? 'bg-base-200 hover:bg-base-300'
      : 'bg-base-100 hover:bg-base-200'
    const stickyCellClass = `sticky left-0 ${isGroupRow ? 'bg-base-200' : 'bg-base-100'} z-10`
    return (
      <tr key={row.id} className={rowClass}>
        {row.getVisibleCells().map((cell) => (
          <td
            key={cell.id}
            className={`${
              cell.column.id === 'category' ? stickyCellClass : 'font-medium text-center'
            } border-gray-700 border`}
          >
            {flexRender(cell.column.columnDef.cell, cell.getContext())}
          </td>
        ))}
      </tr>
    )
  })

return (
  <div className="rounded">
    {error && <div className="text-red-500">{error}</div>}
    {!error && (
      <div className="flex flex-col gap-4">
        <div className="flex flex-col gap-4">
          <OverallMetricFilter
            overallMetrics={overallMetrics}
            selectedOverallMetrics={selectedOverallMetrics}
            setSelectedOverallMetrics={setSelectedOverallMetrics}
          />
          {/* <LeaderboardFilter
            groups={groupRows}
            selectedMetrics={selectedMetrics}
            setSelectedMetrics={setSelectedMetrics}
          /> */}
        </div>
        {nothingToShow ? (
          <div className="text-center p-4 text-lg">
            Please select at least one model and one metric to display the data
          </div>
        ) : (
          <>
            {/* Quality metrics table */}
            <QualityMetricsTable
              qualityMetrics={findQualityMetrics()}
              tableHeader={tableHeader}
              selectedModels={selectedModels}
              tableRows={rawRows}
            />
            {/* Main metrics table */}
            <div className="relative flex justify-end mb-6">
              <button
                className="absolute top-0 right-0 btn btn-ghost btn-circle"
                title="Export CSV"
                onClick={exportToCsv}
              >
                <ArrowDownTrayIcon className="h-6 w-6" />
              </button>
            </div>
            <div className="overflow-x-auto max-h-[80vh] overflow-y-auto">
              <table className="table w-full min-w-max border-separate border-spacing-0 border-gray-700 border">
                <thead>
                  <tr>
                    <th className="sticky left-0 top-0 bg-base-100 z-20 border border-gray-700">
                      Attack Categories
                    </th>
                    {/* Add metric group headers */}
                    {renderMetricHeaderCells()}
                  </tr>
                  {/* Add model headers */}
                  <tr>
                    <th className="sticky left-0 top-12 bg-base-100 z-30 border border-gray-700"></th>
                    {renderModelHeaderCells()}
                  </tr>
                </thead>
                <tbody>{renderBodyRows()}</tbody>
              </table>
            </div>
          </>
        )}
      </div>
    )}
  </div>
)
| } | |
| export default LeaderboardTable | |