import torch
from generate import generate
from transformers import AutoTokenizer, AutoModel


def chat():
    device = 'cuda'
    # Load the LLaDA-8B-Instruct model and tokenizer (trust_remote_code is required for the custom model class).
    model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)

    gen_length = 128
    steps = 128
    print('*' * 66)
    print(f'** Answer Length: {gen_length} | Sampling Steps: {steps} **')
    print('*' * 66)

    conversation_num = 0
    while True:
        user_input = input("Enter your question: ")

        # Wrap the question in the chat template and tokenize it.
        m = [{"role": "user", "content": user_input}]
        user_input = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
        input_ids = tokenizer(user_input)['input_ids']
        input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

        # First turn: the prompt is just this message; later turns: append the new
        # message to the running conversation, dropping its leading special token.
        if conversation_num == 0:
            prompt = input_ids
        else:
            prompt = torch.cat([prompt, input_ids[:, 1:]], dim=1)

        out = generate(model, prompt, steps=steps, gen_length=gen_length, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence')

        answer = tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0]
        print(f"Bot's reply: {answer}")

        # Remove the <EOS> tokens (id 126081) so the conversation can be extended on the next turn.
        prompt = out[out != 126081].unsqueeze(0)
        conversation_num += 1
        print('-----------------------------------------------------------------------')


if __name__ == "__main__":
    chat()
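
If you only need a single answer rather than the interactive multi-turn loop, a minimal one-shot sketch using the same generate call might look like the following. It assumes the generate.py from this repository is on the Python path, reuses the exact sampling settings from the chat loop above, and the question string is just an illustrative placeholder.

import torch
from generate import generate
from transformers import AutoTokenizer, AutoModel

device = 'cuda'
model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)

# Build a single-turn chat prompt (the question below is a hypothetical example).
m = [{"role": "user", "content": "What is diffusion-based language modeling?"}]
prompt_text = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
prompt = torch.tensor(tokenizer(prompt_text)['input_ids']).to(device).unsqueeze(0)

# Same settings as the chat loop: 128 sampling steps, 128 generated tokens, blocks of 32.
out = generate(model, prompt, steps=128, gen_length=128, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence')
print(tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0])

The decoding step is the same in both scripts: generate returns the prompt followed by the sampled tokens, so slicing off the first prompt.shape[1] positions leaves only the model's reply.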