Spaces:
Running
on
Zero
Running
on
Zero
| import json | |
| import logging | |
| import regex | |
| import time | |
| from pathlib import Path | |
| from typing import Annotated, Iterator | |
| import ijson | |
| import outlines | |
| import torch | |
| from pydantic import BaseModel, StringConstraints, conlist, conset | |
| from outlines import generate, models | |
| from outlines.generate.api import SequenceGenerator | |
| from transformers import AutoTokenizer | |
| from fsm import replace_fields | |
| from samplers import PenalizedMultinomialSampler | |
| from utils import StringIteratorIO | |
logger = logging.getLogger(__name__)
logger.warning("Loading model...")

model_id = "google/gemma-2b-it"
# Smaller alternative chat model: model_id = "Qwen/Qwen1.5-0.5B-Chat"

# Use Apple's Metal backend when it is available, otherwise CUDA.
device = "mps" if torch.backends.mps.is_available() else "cuda"
model = models.transformers(model_id, device=device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Default sampler for the sample rows, plus a cooler one (temperature 0.3) for column names.
sampler = PenalizedMultinomialSampler()
low_temperature_sampler = PenalizedMultinomialSampler(temperature=0.3)

# Decode each vocabulary entry once; both token filters below reuse this list, which
# avoids decoding the whole vocabulary twice at import time.
_decoded_vocab = [
    tokenizer.decode([token_id], skip_special_tokens=True)
    for token_id in range(tokenizer.vocab_size)
]

# Tokens that decode to an empty or whitespace-only string.
# NOTE(review): set_max_repeats() lives in samplers.py; presumably a limit of 1 stops long
# runs of blank tokens — confirm its exact semantics there.
empty_tokens = [token_id for token_id, text in enumerate(_decoded_vocab) if not text.strip()]
sampler.set_max_repeats(empty_tokens, 1)

# Focus on English for now: flag every token that contains a Han character *anywhere*.
# `search` is used instead of `match`, because `match` only looks at the first character and
# would miss tokens whose decoded text starts with a space or another non-Han prefix.
disallowed_patterns = [regex.compile(r"\p{Han}")]
disallowed_tokens = [
    token_id
    for token_id, text in enumerate(_decoded_vocab)
    if any(pattern.search(text) for pattern in disallowed_patterns)
]
# NOTE(review): presumably a limit of 0 bans these tokens outright — confirm in samplers.py.
sampler.set_max_repeats(disallowed_tokens, 0)
del _decoded_vocab  # only needed to build the two token lists above
# NOTE(review): low_temperature_sampler receives none of these limits; confirm that is intended.
# These Sample & Dataset models are just templates with placeholder fields
class Sample(BaseModel):
    # Template for one generated row. The five field names are placeholders only:
    # get_samples_generator() passes this model to replace_fields(), which replaces the
    # placeholder fields with the requested column names.
    # Field order matters: stream_jsonl_file() zips the requested columns with
    # Sample.model_fields, so at most five columns are used.
    # NOTE(review): the mixed-case 10-character names presumably just need to be unique,
    # unlikely-to-collide strings — confirm against fsm.replace_fields.
    ABCDabcd12: str
    EFGHefgh34: str
    IJKLijkl56: str
    MNOPmnop78: str
    QRSTqrst90: str
# PS: don't use StringConstraints with max_length here, since it creates an FSM that is too big
class Dataset(BaseModel):
    # Top-level JSON object of the generated output: {"data": [<Sample>, ...]}.
    # The 2..3 length bounds apply only to the template FSM compiled from this schema;
    # get_samples_generator() rebuilds it with make_infinite_loop=True so the "data"
    # list has no fixed length.
    data: conlist(Sample, min_length=2, max_length=3) # type: ignore
# JSON generator compiled once for the placeholder `Dataset` schema. get_samples_generator()
# reuses its FSM, model and sampler to build per-request generators with the real columns.
samples_generator_template = generate.json(
    model,
    Dataset,
    sampler=sampler,
)
class Columns(BaseModel):
    # Schema for the column-name step: {"columns": [<name>, ...]}, each name matching
    # [a-z0-9_]+, between 2 and len(Sample.model_fields) - 1 entries, declared as a set.
    # NOTE(review): the reason for the "- 1" is not visible here — confirm it is intentional.
    columns: conset(Annotated[str, StringConstraints(pattern=r'[a-z0-9_]+')], min_length=2, max_length=len(Sample.model_fields) - 1) # type: ignore
# JSON generator for the column-name step; it samples with the 0.3-temperature sampler.
columns_generator = generate.json(
    model,
    Columns,
    sampler=low_temperature_sampler,
)
def get_samples_generator(new_fields: list[str]) -> SequenceGenerator:
    """Build a sample generator whose rows use ``new_fields`` as their keys.

    Reuses the FSM already compiled for the placeholder ``Dataset``/``Sample`` template and
    lets ``replace_fields`` substitute the placeholder field names with ``new_fields``.
    The FSM is built with ``make_infinite_loop=True`` so the ``data`` list can keep
    producing rows for as long as the caller consumes them.

    Args:
        new_fields: Column names to use in place of ``Sample``'s placeholder fields.

    Returns:
        A ``SequenceGenerator`` on ``device`` that shares the template's model and sampler.
    """
    fsm = replace_fields(  # replace the placeholder fields by the real fields
        fsm=samples_generator_template.fsm,
        model=Sample,
        new_fields=new_fields,
        tokenizer=tokenizer,
        make_infinite_loop=True,  # to generate as many samples as we want
    )
    return SequenceGenerator(
        fsm=fsm,
        model=samples_generator_template.model,
        sampler=samples_generator_template.sampler,
        device=device,
    )
def columns_prompt(filename: str) -> str:
    """Return the user message asking the model to propose column names.

    Args:
        filename: Dataset name without extension; rendered as ``<filename>.json``.

    Returns:
        The prompt text (three lines, no trailing newline).
    """
    return (
        f"I would like to create a JSON file named {filename}.json for a dataset of realistic data.\n"
        "Give an example of column names / columns for this dataset to populate a SQL schema.\n"
        'Please reply in JSON format and place the columns in a field named "columns".'
    )
def samples_prommpt(filename: str, prompt: str, columns: str) -> str:
    """Return the user message asking the model for sample rows.

    (The misspelled name is kept so existing callers keep working.)

    Args:
        filename: Dataset name without extension; rendered as ``<filename>.json``.
        prompt: Extra free-form instructions, appended verbatim as the last line.
        columns: Pre-formatted column list, e.g. ``"'id', 'name'"``.

    Returns:
        The prompt text (no trailing newline).
    """
    return (
        f"I would like to create a JSON file named {filename}.json for a dataset of realistic data.\n"
        f'Give an example of content using a JSON field named "data" with samples with columns {columns}.\n'
        f"{prompt}"
    )
def _render_chat_prompt(content: str) -> str:
    """Render ``content`` as one user turn plus the generation prompt, as untokenized text."""
    messages = [{"role": "user", "content": content}]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)


def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
    """Generate a synthetic dataset and stream it as JSON Lines.

    Args:
        filename: Dataset name; only its stem is used (``"users.jsonl"`` -> ``"users"``).
        prompt: Extra free-form instructions added to the sample-generation prompt.
        columns: Column names for the rows. When empty (or ``None``) the model is first
            asked to propose them. The caller's list is never modified. At most
            ``len(Sample.model_fields)`` columns are used, and each name is cut to at most
            ``len(placeholder)`` tokens.
        seed: Seed for the ``torch.Generator`` shared by both generation steps.
        size: Maximum number of rows to yield.

    Yields:
        One JSON-encoded row per item, each terminated by a newline.
    """
    filename = Path(filename).stem
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
    _start = time.time()
    rng = torch.Generator(device=model.device)
    rng.manual_seed(seed)

    if not columns:
        logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) Generating columns...")
        columns_generator_tokens = columns_generator.stream(
            _render_chat_prompt(columns_prompt(filename=filename)), rng=rng
        )
        # Bind a new list instead of appending to the argument, so the caller's list
        # (possibly a shared empty default) is never mutated.
        columns = list(
            ijson.items(StringIteratorIO(columns_generator_tokens), "columns.item", buf_size=16)
        )
        logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) Generating columns... DONE (total={time.time() - _start:.02f}s)")

    # replace_fields() substitutes Sample's placeholder fields, so zip() drops any extra
    # columns and each name is cut to at most len(placeholder) tokens.
    # NOTE(review): the token cap presumably keeps the new key within what replace_fields()
    # supports — confirm in fsm.py.
    columns = [
        tokenizer.decode(
            tokenizer.encode(column, add_special_tokens=False)[: len(orig_field)],
            skip_special_tokens=True,
        )
        for column, orig_field in zip(columns, Sample.model_fields)
    ]

    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide...")
    samples_generator = get_samples_generator(new_fields=columns)
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide... DONE (total={time.time() - _start:.02f}s)")

    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples...")
    quoted_columns = "'" + "', '".join(columns) + "'"
    text = _render_chat_prompt(samples_prommpt(filename=filename, prompt=prompt, columns=quoted_columns))
    samples_generator_tokens = samples_generator.stream(text, rng=rng)
    samples = ijson.items(StringIteratorIO(samples_generator_tokens), "data.item", buf_size=4)
    # The samples FSM loops forever (make_infinite_loop=True), so cap the output at `size`.
    # range() comes first in zip() so no extra row is generated once the cap is reached.
    for _, sample in zip(range(size), samples):
        yield json.dumps(sample, ensure_ascii=False) + "\n"
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples... DONE (total={time.time() - _start:.02f}s)")