""" |
|
|
eval.py - Evaluation script for OLMoE models using lm-evaluation-harness |
|
|
|
|
|
This script supports evaluation of both: |
|
|
1. Standard Transformers OLMoE models |
|
|
2. Custom MyOLMoE models with modified routing |
|
|
|
|
|
Usage Examples: |
|
|
# Evaluate standard OLMoE model |
|
|
python eval.py --model_type transformers --tasks mmlu hellaswag |
|
|
|
|
|
# Evaluate custom MyOLMoE model with non-deterministic routing |
|
|
python eval.py --model_type custom --routing_type non_deterministic --tasks mmlu |
|
|
""" |
|
|
|
|
|
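
# The imports below assume the usual dependencies: torch, numpy, transformers,
# and the lm-evaluation-harness package (imported as `lm_eval`).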

import argparse
import json
import os
import sys
import logging
from typing import Dict, List, Optional, Any

import numpy as np
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Evaluate OLMoE models using lm-evaluation-harness",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Standard OLMoE evaluation
  python eval.py --model_type transformers --tasks mmlu arc_easy

  # Custom MyOLMoE with non-deterministic (multinomial) routing
  python eval.py --model_type custom --routing_type multinomial \\
      --router_temperature 0.8 --tasks mmlu hellaswag

  # Bottom-k routing evaluation
  python eval.py --model_type custom --routing_type botk --tasks gsm8k
"""
    )

    parser.add_argument(
        "--model_path",
        type=str,
        default="allenai/OLMoE-1B-7B-0924",
        help="Path or name of the pretrained model"
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="transformers",
        choices=["transformers", "custom"],
        help="Model type: 'transformers' for standard OLMoE, 'custom' for MyOLMoE"
    )
    parser.add_argument(
        "--custom_model_path",
        type=str,
        default="./myolmoe_model",
        help="Path to custom MyOLMoE model code (when using --model_type custom)"
    )

    parser.add_argument(
        "--routing_type",
        type=str,
        default="topk",
        choices=["topk", "multinomial", "botk", "topk+botk", "nth-descending",
                 "depthconstant", "depthlatter"],
        help="Routing type (only used with custom models)"
    )
    parser.add_argument(
        "--router_temperature",
        type=float,
        default=1.0,
        help="Temperature for non-deterministic routing"
    )
    parser.add_argument(
        "--num_experts_per_tok",
        type=int,
        default=None,
        help="Number of experts per token (overrides the model config when provided)"
    )

    parser.add_argument(
        "--tasks",
        type=str,
        nargs="+",
        default=["mmlu"],
        help="Tasks to evaluate on (e.g., mmlu, hellaswag, arc_easy, gsm8k)"
    )
    parser.add_argument(
        "--num_fewshot",
        type=int,
        default=0,
        help="Number of few-shot examples"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size for evaluation"
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        help="Maximum batch size (auto if None)"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help="Device to use ('auto', 'cuda', 'cpu')"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "float16", "bfloat16", "float32"],
        help="Data type for model weights"
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        default="./eval_results",
        help="Directory to save evaluation results"
    )
    parser.add_argument(
        "--output_filename",
        type=str,
        default=None,
        help="Custom filename for results (auto-generated if not provided)"
    )

    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Limit the number of examples per task (for testing)"
    )
    parser.add_argument(
        "--write_out",
        action="store_true",
        help="Write out individual predictions to files"
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Trust remote code when loading the model"
    )
    parser.add_argument(
        "--verbosity",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging verbosity level"
    )

    return parser.parse_args()
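
# Illustrative invocation (all flags shown exist above):
#   python eval.py --model_type custom --routing_type multinomial \
#       --router_temperature 0.8 --tasks mmlu hellaswag --limit 50
# parse_args() returns an argparse.Namespace that run_evaluation() consumes below.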


def load_transformers_model(args) -> HFLM:
    """
    Load standard Transformers OLMoE model.

    Args:
        args: Parsed command line arguments

    Returns:
        HFLM: Wrapped model ready for evaluation
    """
    logger.info(f"Loading Transformers OLMoE model: {args.model_path}")

    model = HFLM(
        pretrained=args.model_path,
        device=args.device,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        dtype=args.dtype,
        trust_remote_code=args.trust_remote_code
    )

    logger.info("Transformers model loaded successfully")
    return model
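
# Note: HFLM (lm-eval's Hugging Face wrapper) accepts either a model name/path,
# as above, or an already-instantiated model object, which load_custom_model()
# below relies on to pass in a MyOlmoeForCausalLM instance.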


def load_custom_model(args) -> HFLM:
    """
    Load custom MyOLMoE model with routing configuration.

    Args:
        args: Parsed command line arguments

    Returns:
        HFLM: Wrapped model ready for evaluation
    """
    logger.info(f"Loading custom MyOLMoE model: {args.model_path}")
    logger.info(f"Routing configuration: {args.routing_type}")

    # Make the custom model code importable.
    if os.path.exists(args.custom_model_path):
        sys.path.insert(0, args.custom_model_path)
        logger.info(f"Added {args.custom_model_path} to Python path")
    else:
        logger.warning(f"Custom model path not found: {args.custom_model_path}")

    try:
        from modeling_myolmoe import MyOlmoeForCausalLM
        logger.info("Successfully imported MyOlmoeForCausalLM")
    except ImportError as e:
        logger.error(f"Failed to import custom model: {e}")
        logger.error("Make sure the custom model code is available in the specified path")
        raise

    config = AutoConfig.from_pretrained(
        args.model_path,
        trust_remote_code=args.trust_remote_code
    )

    # Apply the routing configuration; only override num_experts_per_tok when
    # the user explicitly provided a value, otherwise keep the config default.
    config.routing_type = args.routing_type
    config.router_temperature = args.router_temperature
    if args.num_experts_per_tok is not None:
        config.num_experts_per_tok = args.num_experts_per_tok

    logger.info("Model config updated:")
    logger.info(f"  - routing_type: {config.routing_type}")
    logger.info(f"  - router_temperature: {config.router_temperature}")
    logger.info(f"  - num_experts_per_tok: {config.num_experts_per_tok}")

    # Resolve the torch dtype requested on the command line.
    if args.dtype == "auto":
        torch_dtype = "auto"
    else:
        torch_dtype = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32
        }[args.dtype]

    hf_model = MyOlmoeForCausalLM.from_pretrained(
        args.model_path,
        config=config,
        torch_dtype=torch_dtype,
        device_map="auto" if args.device == "auto" else None,
        trust_remote_code=args.trust_remote_code
    ).eval()

    # Wrap the preloaded model for lm-evaluation-harness.
    model = HFLM(
        pretrained=hf_model,
        device=args.device,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        dtype=args.dtype
    )

    logger.info("Custom model loaded successfully")
    return model


def validate_model_config(model_path: str, trust_remote_code: bool = False) -> Dict[str, Any]:
    """
    Validate model configuration and return key information.

    Args:
        model_path: Path to the model
        trust_remote_code: Whether to trust remote code

    Returns:
        Dict containing model configuration information
    """
    try:
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        # Loading the tokenizer only checks that it resolves; the object itself is not used.
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=trust_remote_code)

        model_info = {
            "model_type": getattr(config, "model_type", "unknown"),
            "vocab_size": getattr(config, "vocab_size", "unknown"),
            "hidden_size": getattr(config, "hidden_size", "unknown"),
            "num_layers": getattr(config, "num_hidden_layers", "unknown"),
            "num_experts": getattr(config, "num_experts", "not specified"),
            "routing_type": getattr(config, "routing_type", "default"),
        }

        logger.info("Model validation successful:")
        for key, value in model_info.items():
            logger.info(f"  {key}: {value}")

        return model_info

    except Exception as e:
        logger.error(f"Model validation failed: {e}")
        raise


def make_serializable(obj: Any) -> Any:
    """
    Convert objects to JSON-serializable format.

    Args:
        obj: Object to convert

    Returns:
        JSON-serializable version of the object
    """
    if isinstance(obj, dict):
        return {k: make_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [make_serializable(v) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(make_serializable(v) for v in obj)
    elif isinstance(obj, (np.integer, np.floating)):
        return obj.item()
    elif isinstance(obj, np.dtype):
        return str(obj)
    elif isinstance(obj, torch.Tensor):
        return obj.tolist()
    elif isinstance(obj, torch.dtype):
        return str(obj)
    else:
        return obj
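
# Illustrative example (values chosen here, not from a real run):
#   make_serializable({"acc": np.float64(0.5), "dtype": torch.float16})
# returns {"acc": 0.5, "dtype": "torch.float16"}, since numpy scalars are
# unwrapped with .item() and torch dtypes are stringified.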


def run_evaluation(args) -> Dict[str, Any]:
    """
    Run evaluation on the specified model.

    Args:
        args: Parsed command line arguments

    Returns:
        Dict containing evaluation results
    """
    logger.info("Starting evaluation...")

    validate_model_config(args.model_path, args.trust_remote_code)

    if args.model_type == "transformers":
        model = load_transformers_model(args)
    elif args.model_type == "custom":
        model = load_custom_model(args)
    else:
        raise ValueError(f"Unknown model type: {args.model_type}")

    logger.info(f"Running evaluation on tasks: {args.tasks}")
    logger.info(f"Few-shot examples: {args.num_fewshot}")
    logger.info(f"Batch size: {args.batch_size}")

    results = evaluator.simple_evaluate(
        model=model,
        tasks=args.tasks,
        num_fewshot=args.num_fewshot,
        limit=args.limit,
        write_out=args.write_out,
    )

    logger.info("Evaluation completed successfully")
    return results
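
# The dict returned by evaluator.simple_evaluate() carries a "results" key
# mapping each task name to its metric dict; save_results() and
# print_summary() below rely only on that structure.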


def save_results(results: Dict[str, Any], args) -> str:
    """
    Save evaluation results to file.

    Args:
        results: Evaluation results
        args: Parsed command line arguments

    Returns:
        str: Path to saved results file
    """
    os.makedirs(args.output_dir, exist_ok=True)

    # Build a descriptive filename unless the user supplied one.
    if args.output_filename is None:
        model_name = os.path.basename(args.model_path.rstrip('/'))
        tasks_str = "_".join(args.tasks[:3])
        if len(args.tasks) > 3:
            tasks_str += f"_and_{len(args.tasks) - 3}_more"

        if args.model_type == "custom":
            filename = f"{model_name}_{args.routing_type}_{tasks_str}_results.json"
        else:
            filename = f"{model_name}_transformers_{tasks_str}_results.json"
    else:
        filename = args.output_filename

    if not filename.endswith('.json'):
        filename += '.json'

    output_path = os.path.join(args.output_dir, filename)

    # Record the run configuration alongside the raw results.
    metadata = {
        "model_path": args.model_path,
        "model_type": args.model_type,
        "tasks": args.tasks,
        "num_fewshot": args.num_fewshot,
        "batch_size": args.batch_size,
        "device": args.device,
        "dtype": args.dtype,
        "limit": args.limit,
    }

    if args.model_type == "custom":
        metadata.update({
            "routing_type": args.routing_type,
            "router_temperature": args.router_temperature,
            "num_experts_per_tok": args.num_experts_per_tok,
        })

    results_with_metadata = {
        "metadata": metadata,
        "results": results
    }

    serializable_results = make_serializable(results_with_metadata)

    with open(output_path, 'w') as f:
        json.dump(serializable_results, f, indent=2)

    logger.info(f"Results saved to {output_path}")
    return output_path
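
# Example output path (illustrative, not from a run): with the defaults and
# --model_type custom --routing_type multinomial --tasks mmlu, the file is
# written to ./eval_results/OLMoE-1B-7B-0924_multinomial_mmlu_results.json.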


def print_summary(results: Dict[str, Any], args) -> None:
    """
    Print a formatted summary of evaluation results.

    Args:
        results: Evaluation results
        args: Parsed command line arguments
    """
    print(f"\n{'=' * 80}")
    print("EVALUATION SUMMARY")
    print(f"Model: {args.model_path}")
    print(f"Type: {args.model_type.upper()}")
    if args.model_type == "custom":
        print(f"Routing: {args.routing_type.upper()}")
    print(f"Tasks: {', '.join(args.tasks)}")
    print(f"{'=' * 80}")

    if "results" in results:
        for task, metrics in results["results"].items():
            if isinstance(metrics, dict):
                print(f"\n📊 {task.upper()}:")
                for metric, value in metrics.items():
                    if isinstance(value, (int, float)) and not metric.endswith('_stderr'):
                        stderr_key = f"{metric}_stderr"
                        stderr = metrics.get(stderr_key, 0)
                        # Guard against non-numeric stderr values (some tasks may
                        # report e.g. "N/A") so the formatting below does not fail.
                        if not isinstance(stderr, (int, float)):
                            stderr = 0.0
                        print(f"  {metric:.<20} {value:.4f} (±{stderr:.4f})")
    else:
        print("\n⚠️ No results found in evaluation output")

    print(f"\n{'=' * 80}")


def main():
    """Main evaluation function."""
    args = parse_args()

    # Apply the requested logging verbosity.
    numeric_level = getattr(logging, args.verbosity.upper(), None)
    if isinstance(numeric_level, int):
        logging.getLogger().setLevel(numeric_level)
        logger.setLevel(numeric_level)

    try:
        logger.info("=" * 80)
        logger.info("Starting OLMoE Model Evaluation")
        logger.info("=" * 80)

        results = run_evaluation(args)

        output_path = save_results(results, args)

        print_summary(results, args)

        logger.info("✅ Evaluation completed successfully!")
        logger.info(f"📁 Results saved to: {output_path}")

    except KeyboardInterrupt:
        logger.info("Evaluation interrupted by user")
        sys.exit(1)
    except Exception as e:
        logger.error(f"❌ Evaluation failed: {e}")
        logger.debug("Full traceback:", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()