TemplateA / model /model_runner.py
Dan Flower
deploy: sync model/utils into TemplateA and update Dockerfile (canonical COPY + cache-bust)
8667228
# model_runner.py
import os
import sys
from typing import List, Optional
from llama_cpp import Llama
print(f"[BOOT] model_runner from {__file__}", file=sys.stderr)

# ---- Phase 2: flags (no behavior change) ------------------------------------
# LAB_* environment toggles. Every default reproduces today's behaviour, so an
# unset environment changes nothing.
try:
    from utils import flags  # adjust this import if your package layout differs
except Exception:
    # utils.flags is not importable in this lab: build an inline stand-in that
    # exposes the same flag attributes and snapshot() helper.

    def _as_bool(val: Optional[str], default: bool) -> bool:
        """Interpret an environment-variable string as a boolean.

        ``None`` (variable unset) yields *default*. Any other value is True
        only if, once trimmed and lower-cased, it is one of
        1 / true / yes / on / y / t; everything else is False.
        """
        if val is not None:
            return val.strip().lower() in {"1", "true", "yes", "on", "y", "t"}
        return default

    class _F:
        """Inline flag holder; values are read from the environment once, when this class body runs."""

        # No sanitizing happens today, hence off by default.
        SANITIZE_ENABLED = _as_bool(os.getenv("LAB_SANITIZE_ENABLED"), False)
        # Only adds *extra* stop sequences; off by default.
        STOPSEQ_ENABLED = _as_bool(os.getenv("LAB_STOPSEQ_ENABLED"), False)
        CRITIC_ENABLED = _as_bool(os.getenv("LAB_CRITIC_ENABLED"), False)
        JSON_MODE = _as_bool(os.getenv("LAB_JSON_MODE"), False)
        EVIDENCE_GATE = _as_bool(os.getenv("LAB_EVIDENCE_GATE"), False)

        @staticmethod
        def snapshot():
            """Return every flag value keyed by its LAB_* environment-variable name."""
            pairs = (
                ("LAB_SANITIZE_ENABLED", _F.SANITIZE_ENABLED),
                ("LAB_STOPSEQ_ENABLED", _F.STOPSEQ_ENABLED),
                ("LAB_CRITIC_ENABLED", _F.CRITIC_ENABLED),
                ("LAB_JSON_MODE", _F.JSON_MODE),
                ("LAB_EVIDENCE_GATE", _F.EVIDENCE_GATE),
            )
            return dict(pairs)

    flags = _F()

# Log the effective flag values; a flags object without snapshot() logs {}.
print(
    "[flags] snapshot:",
    flags.snapshot() if hasattr(flags, "snapshot") else {},
    file=sys.stderr,
)
# Optional sanitizer hook: a pass-through unless the SANITIZE_ENABLED flag is on.
def _sanitize(text: str) -> str:
    """Return *text*, trimmed of surrounding whitespace only when sanitizing is enabled.

    The flag defaults to False (Phase 2), in which case the input is returned
    unchanged.
    """
    # TODO: wire your real sanitizer in Phase 3+
    return text.strip() if getattr(flags, "SANITIZE_ENABLED", False) else text
# Stop sequences: today's defaults are ALWAYS applied.
# When LAB_STOPSEQ_ENABLED=true, extra stops taken from the comma-separated
# STOP_SEQUENCES env var are appended after these (see _extra_stops_from_env).
DEFAULT_STOPS: List[str] = [
    "\nUser:",
    "\nAssistant:",
]
def _extra_stops_from_env() -> List[str]:
    """Return additional stop sequences configured through the environment.

    Only active when the STOPSEQ_ENABLED flag is on; otherwise returns an empty
    list. STOP_SEQUENCES is split on commas, each piece is whitespace-trimmed,
    and empty pieces are dropped. Escape sequences such as ``\\n`` are NOT
    decoded; pieces are used literally.
    """
    if getattr(flags, "STOPSEQ_ENABLED", False):
        pieces = (piece.strip() for piece in os.getenv("STOP_SEQUENCES", "").split(","))
        return [piece for piece in pieces if piece]
    return []
# ---- Model cache / load ------------------------------------------------------
_model = None  # module-level cache; populated by the first successful load_model()


def load_model():
    """Return the process-wide Llama instance, loading it on first use.

    The model file path is read from the ``MODEL_PATH`` environment variable.
    The loaded model is cached in the module-level ``_model``, so later calls
    return the same object without reloading.

    Raises:
        ValueError: if ``MODEL_PATH`` is unset or empty, or does not name an
            existing regular file. A directory is rejected here instead of
            being handed to ``Llama``.
    """
    global _model
    if _model is not None:
        return _model

    model_path = os.getenv("MODEL_PATH")
    # isfile (not exists): a directory path must not reach Llama(model_path=...).
    if not model_path or not os.path.isfile(model_path):
        raise ValueError(
            f"Model path does not exist, is not a file, or is not set: {model_path}"
        )

    # stderr, consistent with the module's other [BOOT]/[flags] diagnostics.
    print(f"[INFO] Loading model from {model_path}", file=sys.stderr)
    _model = Llama(
        model_path=model_path,
        n_ctx=1024,  # short context to reduce memory use
        n_threads=4,  # number of CPU threads
        n_gpu_layers=0,  # CPU only (Hugging Face free tier)
    )
    return _model
# ---- Inference ---------------------------------------------------------------
def _truncate_at_stops(text: str, stops: List[str]) -> str:
    """Return *text* cut just before the earliest occurrence of any non-empty stop."""
    cut = len(text)
    for stop in stops:
        if not stop:
            continue  # "" would match at index 0 and erase everything
        idx = text.find(stop)
        if 0 <= idx < cut:
            cut = idx
    return text[:cut]


def generate(
    prompt: str,
    max_tokens: int = 256,
    *,
    temperature: float = 0.7,
    top_p: float = 0.95,
) -> str:
    """Run a single completion for *prompt* and return the cleaned-up text.

    Args:
        prompt: Raw prompt text passed to the model unchanged.
        max_tokens: Generation budget forwarded to the model call.
        temperature: Sampling temperature (keyword-only; default unchanged at 0.7).
        top_p: Nucleus-sampling cutoff (keyword-only; default unchanged at 0.95).

    Returns:
        The completion text truncated at the earliest stop sequence, passed
        through ``_sanitize`` and stripped of surrounding whitespace.

    Raises:
        ValueError: propagated from ``load_model`` when the model path is invalid.
    """
    model = load_model()
    # Default stops always apply; extra env-configured ones are appended when enabled.
    stops = DEFAULT_STOPS + _extra_stops_from_env()
    output = model(
        prompt,
        max_tokens=max_tokens,
        stop=stops,
        echo=False,
        temperature=temperature,
        top_p=top_p,
    )
    raw_text = output["choices"][0]["text"]
    # The model call already receives `stops`; this manual pass is kept
    # intentionally. It cuts at the EARLIEST stop, so overlapping stops cannot
    # leave a partial stop fragment, as the old one-stop-at-a-time split could.
    text = _truncate_at_stops(raw_text, stops)
    return _sanitize(text).strip()