# handler.py — custom inference handler (upload: ace-1, commit cef3e59)
from typing import Dict, Any, List, Union
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TextClassificationPipeline,
)
class EndpointHandler:
    """Inference Endpoints entry point for a text-classification model.

    The endpoint runtime constructs this class exactly once at startup,
    passing the directory that holds the model artefacts, and then invokes
    the instance for every incoming request.
    """

    def __init__(self, path: str = "", **kwargs):
        """Load the tokenizer and model from *path* and assemble the pipeline.

        Args:
            path: Directory containing the model artefacts. Falls back to the
                current directory when empty (handy for local testing).
            **kwargs: Accepted for runtime compatibility; unused.
        """
        model_dir = path if path else "."

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)

        # Multi-label heads score each label independently (sigmoid); the
        # single-label case normalizes across labels (softmax).
        is_multi_label = (
            getattr(self.model.config, "problem_type", None)
            == "multi_label_classification"
        )

        # NOTE(review): `return_all_scores` is deprecated in newer transformers
        # releases in favour of `top_k`, but the two emit different output
        # shapes — kept as-is so the response format stays unchanged.
        # NOTE(review): device=-1 pins the pipeline to CPU; confirm the
        # endpoint runtime really overrides this when a GPU is present.
        self.pipeline = TextClassificationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            device=-1,
            return_all_scores=False,
            function_to_apply="sigmoid" if is_multi_label else "softmax",
        )

    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Run classification on the request payload.

        Expected payload shape (per the Inference Endpoints runtime)::

            {
                "inputs": "some text" | ["text 1", "text 2", ...],
                "parameters": {...}   # optional pipeline kwargs (e.g. top_k)
            }

        Returns:
            A single score dict when exactly one prediction is produced,
            otherwise a list of score dicts (one per input).
        """
        # A payload without an "inputs" key is forwarded to the pipeline
        # unchanged (permissive fallback for raw payloads).
        texts = data.get("inputs", data)
        if isinstance(texts, str):
            texts = [texts]

        extra = data.get("parameters", {})
        predictions = self.pipeline(texts, **extra)

        # Collapse the singleton case so "one text in, one dict out" holds.
        return predictions[0] if len(predictions) == 1 else predictions