# model_runner.py
import os
import sys
from typing import List, Optional

from llama_cpp import Llama

print(f"[BOOT] model_runner from {__file__}", file=sys.stderr)

# ---- Phase 2: flags (no behavior change) ------------------------------------
# Reads LAB_* env toggles; all defaults preserve current behavior.
try:
    from utils import flags  # if your package path is different, adjust import
except Exception:
    # Fallback inline flags if utils.flags isn't available in this lab
    def _as_bool(val: Optional[str], default: bool) -> bool:
        if val is None:
            return default
        return val.strip().lower() in {"1", "true", "yes", "on", "y", "t"}

    class _F:
        SANITIZE_ENABLED = _as_bool(os.getenv("LAB_SANITIZE_ENABLED"), False)  # you don't sanitize today
        STOPSEQ_ENABLED = _as_bool(os.getenv("LAB_STOPSEQ_ENABLED"), False)    # extra stops only; defaults off
        CRITIC_ENABLED = _as_bool(os.getenv("LAB_CRITIC_ENABLED"), False)
        JSON_MODE = _as_bool(os.getenv("LAB_JSON_MODE"), False)
        EVIDENCE_GATE = _as_bool(os.getenv("LAB_EVIDENCE_GATE"), False)

        @staticmethod
        def snapshot():
            return {
                "LAB_SANITIZE_ENABLED": _F.SANITIZE_ENABLED,
                "LAB_STOPSEQ_ENABLED": _F.STOPSEQ_ENABLED,
                "LAB_CRITIC_ENABLED": _F.CRITIC_ENABLED,
                "LAB_JSON_MODE": _F.JSON_MODE,
                "LAB_EVIDENCE_GATE": _F.EVIDENCE_GATE,
            }

    flags = _F()

print("[flags] snapshot:", getattr(flags, "snapshot", lambda: {})(), file=sys.stderr)


# Optional sanitizer hook (kept no-op unless enabled later)
def _sanitize(text: str) -> str:
    # Phase 2: default False -> no behavior change
    if getattr(flags, "SANITIZE_ENABLED", False):
        # TODO: wire your real sanitizer in Phase 3+
        return text.strip()
    return text


# Stop sequences: keep today's defaults ALWAYS.
# If LAB_STOPSEQ_ENABLED=true, add *extra* stops from STOP_SEQUENCES env (comma-separated).
DEFAULT_STOPS: List[str] = ["\nUser:", "\nAssistant:"]


def _extra_stops_from_env() -> List[str]:
    if not getattr(flags, "STOPSEQ_ENABLED", False):
        return []
    raw = os.getenv("STOP_SEQUENCES", "")
    toks = [t.strip() for t in raw.split(",") if t.strip()]
    return toks


# ---- Model cache / load ------------------------------------------------------
_model = None  # module-level cache


def load_model():
    global _model
    if _model is not None:
        return _model

    model_path = os.getenv("MODEL_PATH")
    if not model_path or not os.path.exists(model_path):
        raise ValueError(f"Model path does not exist or is not set: {model_path}")

    print(f"[INFO] Loading model from {model_path}")
    _model = Llama(
        model_path=model_path,
        n_ctx=1024,      # short context to reduce memory use
        n_threads=4,     # number of CPU threads
        n_gpu_layers=0,  # CPU only (Hugging Face free tier)
    )
    return _model


# ---- Inference ---------------------------------------------------------------
def generate(prompt: str, max_tokens: int = 256) -> str:
    model = load_model()

    # Preserve existing default stops; optionally extend with extra ones if flag is on
    stops = DEFAULT_STOPS + _extra_stops_from_env()

    output = model(
        prompt,
        max_tokens=max_tokens,
        stop=stops,  # unchanged defaults; may include extra stops if enabled
        echo=False,
        temperature=0.7,
        top_p=0.95,
    )
    raw_text = output["choices"][0]["text"]

    # Preserve current manual truncation by the same default stops (kept intentionally)
    # Extra stops are also applied here if enabled for consistency.
    for stop_token in stops:
        if stop_token and stop_token in raw_text:
            raw_text = raw_text.split(stop_token)[0]

    final = _sanitize(raw_text)
    return final.strip()
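

# ---- Optional local smoke test -----------------------------------------------
# A minimal sketch for exercising this module by hand. The env var names below
# (MODEL_PATH, LAB_STOPSEQ_ENABLED, STOP_SEQUENCES) are the ones this module
# already reads; the model path and prompt are placeholder values you would
# substitute for your own setup.
if __name__ == "__main__":
    # Assumed environment setup (shell), adjust the path to a real GGUF file:
    #   export MODEL_PATH=/path/to/model.gguf
    #   export LAB_STOPSEQ_ENABLED=true
    #   export STOP_SEQUENCES="###,</s>"
    demo_prompt = (
        "User: In one sentence, what does a stop sequence do?\n"
        "Assistant:"
    )
    # generate() loads the cached model on first call, then applies the
    # default stops (plus any extra ones if the flag is enabled).
    print(generate(demo_prompt, max_tokens=64))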