import time
from typing import List

import gradio as gr
from huggingface_hub import list_models
from transformers import AutoConfig

# Credits: This implementation is derived from and builds upon the excellent work by gaunernst
# Original implementation: https://huggingface.co/spaces/gaunernst/kv-cache-calculator

# Simple in-memory cache for Hub search results: {cache_key: (results, timestamp)}
search_cache = {}

POPULAR_MODELS = [
    "Qwen/Qwen3-30B-A3B",
    "meta-llama/Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.1-70B-Instruct",
    "microsoft/DialoGPT-medium",
    "microsoft/DialoGPT-large",
    "mistralai/Mistral-7B-Instruct-v0.3",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "deepseek-ai/DeepSeek-V2-Chat",
    "deepseek-ai/DeepSeek-V3-Base",
    "google/gemma-2-9b",
    "google/gemma-2-27b",
    "Qwen/QwQ-32B-Preview",
    "Qwen/Qwen2.5-72B-Instruct",
    "anthropic/claude-3-haiku-20240307",
]


def search_models(query: str, max_results: int = 50) -> List[str]:
    """Search the Hugging Face Hub for model IDs matching `query`, with a 5-minute result cache."""
    if not query or len(query.strip()) < 1:
        return POPULAR_MODELS[:15]

    query = query.strip()
    cache_key = f"{query.lower()}_{max_results}"
    current_time = time.time()

    # Return cached results if they are less than 5 minutes old.
    if cache_key in search_cache:
        cached_result, cache_time = search_cache[cache_key]
        if current_time - cache_time < 300:
            return cached_result

    try:
        print(f"Searching HF Hub for: {query}")
        models = list_models(
            search=query,
            task="text-generation",
            library="transformers",
            sort="downloads",
            direction=-1,
            limit=max_results * 2,
            full=False,
        )

        all_matches = []
        seen_models = set()

        # Surface curated popular models that match the query first.
        for model in POPULAR_MODELS:
            if query.lower() in model.lower() and model not in seen_models:
                all_matches.append(model)
                seen_models.add(model)

        for model in models:
            if model.id not in seen_models and len(all_matches) < max_results:
                all_matches.append(model.id)
                seen_models.add(model.id)

        # If the text-generation search was too narrow, retry without the task filter.
        if len(all_matches) < max_results // 2:
            try:
                broader_models = list_models(
                    search=query,
                    library="transformers",
                    sort="downloads",
                    direction=-1,
                    limit=max_results * 2,
                )
                for model in broader_models:
                    if model.id not in seen_models and len(all_matches) < max_results:
                        model_id_lower = model.id.lower()
                        if any(keyword in model_id_lower for keyword in ["chat", "instruct", "base", "model"]):
                            all_matches.append(model.id)
                            seen_models.add(model.id)
            except Exception as e:
                print(f"Broader search failed: {e}")

        result = all_matches[:max_results]
        search_cache[cache_key] = (result, current_time)

        # Evict the oldest entry once the cache grows beyond 20 queries.
        if len(search_cache) > 20:
            oldest_key = min(search_cache.keys(), key=lambda k: search_cache[k][1])
            del search_cache[oldest_key]

        return result

    except Exception as e:
        print(f"Search error: {e}")
        popular_matches = [model for model in POPULAR_MODELS if query.lower() in model.lower()]
        return popular_matches if popular_matches else POPULAR_MODELS[:15]


def calculate(name: str, ctx_len: int, num_users: int, dtype: str, hf_token: str):
    """Estimate the KV cache size (in GB) for a model at the given context length and user count."""
    hf_token = hf_token.strip()
    try:
        cfg = AutoConfig.from_pretrained(
            name,
            trust_remote_code=True,
            token=hf_token or None,
        )
    except Exception as e:
        raise gr.Error(str(e))

    architectures = getattr(cfg, "architectures", None) or [""]
    use_mla = architectures[0].startswith(("DeepseekV2", "DeepseekV3"))

    # Multimodal configs nest the language-model settings under text_config.
    if hasattr(cfg, "text_config"):
        cfg = cfg.text_config

    num_layers = cfg.num_hidden_layers
    num_attention_heads = cfg.num_attention_heads
    num_kv_heads = getattr(cfg, "num_key_value_heads", num_attention_heads)

    if use_mla:
        attention_type = "MLA"
    elif num_kv_heads == num_attention_heads:
        attention_type = "MHA"
    else:
        attention_type = "GQA"

    model_config = [
        ["num_layers", num_layers],
        ["max_ctx_len", cfg.max_position_embeddings],
        ["attention_type", attention_type],
        ["num_attention_heads", num_attention_heads],
        ["num_kv_heads", num_kv_heads],
    ]

    if ctx_len > cfg.max_position_embeddings:
        gr.Warning("Requested context length is larger than the max value supported by the model")

    if use_mla:
        # MLA caches a compressed latent plus the RoPE key dimensions per layer.
        kv_lora_rank = cfg.kv_lora_rank
        qk_rope_head_dim = cfg.qk_rope_head_dim
        nelems_per_token = num_layers * (kv_lora_rank + qk_rope_head_dim)
        model_config.append(["kv_lora_rank", kv_lora_rank])
        model_config.append(["qk_rope_head_dim", qk_rope_head_dim])
        model_config.append(["calc_formula", f"{num_layers} * ({kv_lora_rank} + {qk_rope_head_dim})"])
    else:
        # MHA/GQA cache both K and V for every KV head in every layer.
        head_dim = getattr(cfg, "head_dim", cfg.hidden_size // num_attention_heads)
        nelems_per_token = num_layers * num_kv_heads * head_dim * 2
        model_config.append(["head_dim", head_dim])
        if attention_type == "GQA":
            kv_ratio = num_attention_heads // num_kv_heads
            model_config.append(["gqa_ratio", f"{kv_ratio}:1"])
        model_config.append(["calc_formula", f"{num_layers} * {num_kv_heads} * {head_dim} * 2"])

    if dtype == "fp16/bf16":
        nbytes_per_elem = 2
    elif dtype == "fp8":
        nbytes_per_elem = 1 + 2 / cfg.hidden_size  # assume per-token scaling
    elif dtype == "fp4":
        nbytes_per_elem = 0.5 + 2 / 32  # 4-bit values + a scaling factor every 32 elements (MXFP4)
    else:
        raise gr.Error(f"Unsupported KV cache data type: {dtype}")

    kv_cache_size = nelems_per_token * ctx_len * num_users * nbytes_per_elem / 1e9
    return kv_cache_size, model_config


DESCRIPTION = (
    "Calculate KV cache memory requirements for transformer models. "
    "Supports MHA, GQA, and MLA attention mechanisms with fp16/bf16, fp8, and fp4 data types."
)


def search_and_update_models(query):
    """Refresh the model dropdown with Hub search results for the current query."""
    if not query or len(query.strip()) < 2:
        return gr.Dropdown(choices=POPULAR_MODELS)

    search_results = search_models(query.strip(), max_results=50)
    if query.strip() not in search_results:
        search_results.insert(0, query.strip())

    return gr.Dropdown(choices=search_results, value=query.strip())


with gr.Blocks(title="KV Cache Calculator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# KV Cache Calculator")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            model_search = gr.Textbox(
                label="🔍 Search Models",
                placeholder="Type model name (e.g., llama, qwen, mistral...)",
                value="Qwen/Qwen3-30B-A3B",
                info="Search the entire HuggingFace Hub database",
            )
            model_dropdown = gr.Dropdown(
                label="📋 Select Model",
                choices=POPULAR_MODELS,
                value="Qwen/Qwen3-30B-A3B",
                allow_custom_value=True,
                info="Models matching your search - or type a custom model ID",
            )
            with gr.Row():
                gr.Markdown("**💡 Tip:** Search updates the dropdown with real HF Hub results")

            ctx_len = gr.Number(label="Context Length", value=128_000, minimum=1)
            num_users = gr.Number(label="Number of Users", value=1, minimum=1)
            dtype = gr.Dropdown(
                label="KV Cache Data Type",
                choices=["fp16/bf16", "fp8", "fp4"],
                value="fp16/bf16",
            )
            hf_token = gr.Textbox(
                label="HuggingFace Token (optional)",
                type="password",
                placeholder="For gated models",
            )
            calculate_btn = gr.Button("Calculate KV Cache Size", variant="primary")

        with gr.Column():
            cache_size = gr.Number(label="KV Cache Size (GB)", precision=2)
            model_config = gr.Dataframe(
                label="Model Configuration",
                headers=["Parameter", "Value"],
                datatype=["str", "str"],
                wrap=True,
            )

    model_search.change(
        fn=search_and_update_models,
        inputs=[model_search],
        outputs=[model_dropdown],
        show_progress=False,
    )

    calculate_btn.click(
        fn=calculate,
        inputs=[model_dropdown, ctx_len, num_users, dtype, hf_token],
        outputs=[cache_size, model_config],
    )

demo.css = """
.gradio-container {
    max-width: 1000px !important;
    margin: 0 auto !important;
}
"""

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        allowed_paths=[],
        app_kwargs={"docs_url": None, "redoc_url": None},
    )
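
# A minimal sketch of exercising calculate() directly, without the Gradio UI: it assumes
# network access to the Hugging Face Hub and that the illustrative model ID below resolves
# to a config this calculator supports; dtype must be one of the strings the dropdown offers.
#
#   size_gb, cfg_rows = calculate(
#       "Qwen/Qwen3-30B-A3B", ctx_len=32_768, num_users=4, dtype="fp16/bf16", hf_token=""
#   )
#   print(f"Estimated KV cache: {size_gb:.2f} GB")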