	lower columns temperature
- generate.py +5 -2
- samplers.py +1 -0
    	
generate.py CHANGED

@@ -2,6 +2,7 @@
 import json
 import logging
 import time
+from pathlib import Path
 from typing import Annotated, Iterator
 
 import ijson
@@ -31,6 +32,7 @@ else:
 
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 sampler = PenalizedMultinomialSampler()
+low_temperature_sampler = PenalizedMultinomialSampler(temperature=0.3)
 empty_tokens = [token_id for token_id in range(tokenizer.vocab_size) if not tokenizer.decode([token_id]).strip()]
 sampler.set_max_repeats(empty_tokens, 1)
 
@@ -56,7 +58,7 @@ samples_generator_template = generate.json(model, Dataset, sampler=sampler)
 class Columns(BaseModel):
     columns: conset(Annotated[str, StringConstraints(pattern=r'[a-z0-9_]+')], min_length=2, max_length=len(Sample.model_fields) - 1)  # type: ignore
 
-columns_generator = generate.json(model, Columns, sampler=sampler)
+columns_generator = generate.json(model, Columns, sampler=low_temperature_sampler)
 
 def get_samples_generator(new_fields: list[str]) -> SequenceGenerator:
     fsm=samples_generator_template.fsm
@@ -89,7 +91,8 @@ def samples_prommpt(filename: str, prompt: str, columns: str):
     {{ prompt }}
     """
 
-def …
+def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
+    filename = Path(filename).stem
     logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
     _start = time.time()
     rng = torch.Generator(device=model.device)
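In short: the Columns generator now gets its own PenalizedMultinomialSampler(temperature=0.3), so column-name generation is sharper and more deterministic, while the sample generator keeps the default temperature. Below is a minimal, illustrative sketch (not code from this repo) of why dividing the logits by a temperature below 1 has that effect, assuming the standard logits/temperature softmax that multinomial samplers such as outlines' MultinomialSampler apply:

# Illustrative sketch only: standard temperature scaling assumed, not taken
# from this diff. Lower T sharpens the distribution the sampler draws from.
import torch

logits = torch.tensor([2.0, 1.0, 0.5, 0.1])
for temperature in (1.0, 0.3):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(temperature, [round(p, 3) for p in probs.tolist()])

# 1.0 [0.575, 0.211, 0.128, 0.086]  -> off-mode tokens still get sampled
# 0.3 [0.958, 0.034, 0.006, 0.002]  -> nearly always the top token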
    	
samplers.py CHANGED

@@ -6,6 +6,7 @@ from outlines.samplers import MultinomialSampler
 
 logger = logging.getLogger(__name__)
 
+
 class PenalizedMultinomialSampler(MultinomialSampler):
 
     def __init__(self, **kwargs):
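The samplers.py hunk only adds a blank line, but this class is where temperature=0.3 ends up: PenalizedMultinomialSampler takes **kwargs, which are presumably forwarded to outlines' MultinomialSampler. A hedged sketch of that assumed forwarding (the diff shows only the signature; the penalty bookkeeping behind set_max_repeats is elided):

# Assumed forwarding, not the repo's actual __init__ body.
from outlines.samplers import MultinomialSampler

class PenalizedMultinomialSampler(MultinomialSampler):
    def __init__(self, **kwargs):
        # Pass temperature=0.3 (and any other sampling options) through to
        # the base multinomial sampler; the penalty state backing
        # set_max_repeats() would also be initialized here.
        super().__init__(**kwargs)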
