Commit 
							
							·
						
						1357761
	
1
								Parent(s):
							
							31c9df4
								
Update cyg_conversation.py
Browse files- cyg_conversation.py +24 -0
    	
        cyg_conversation.py
    CHANGED
    
    | @@ -126,6 +126,30 @@ conv_templates = { | |
| 126 | 
             
                "bair_v1": conv_bair_v1,
         | 
| 127 | 
             
            }
         | 
| 128 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 129 |  | 
| 130 | 
             
            if __name__ == "__main__":
         | 
| 131 | 
             
                print(default_conversation.get_prompt())
         | 
|  | |
| 126 | 
             
                "bair_v1": conv_bair_v1,
         | 
| 127 | 
             
            }
         | 
| 128 |  | 
| 129 | 
            +
            def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token):
         | 
| 130 | 
            +
                conv = default_conversation.copy()
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                conv.append_message(conv.roles[1], None)
         | 
| 133 | 
            +
                conv.append_message(conv.roles[0], text)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                example = tokenizer.encode_plus(f"{conv.get_prompt()}", None, max_length=None)['input_ids']
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                while(len(history) > 0 and (len(example) < max_token)):
         | 
| 138 | 
            +
                    tmp = history.pop()
         | 
| 139 | 
            +
                    if tmp[0] == 'ASSISTANT':
         | 
| 140 | 
            +
                        conv.append_message(conv.roles[1], tmp[1])
         | 
| 141 | 
            +
                    else:
         | 
| 142 | 
            +
                        conv.append_message(conv.roles[0], tmp[1])
         | 
| 143 | 
            +
                    example = tokenizer.encode_plus(f"{conv.get_prompt()}", None, max_length=None)['input_ids']
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                if len(example) >= max_token:
         | 
| 146 | 
            +
                    conv.messages.pop()
         | 
| 147 | 
            +
                conv.messages = conv.messages[::-1]
         | 
| 148 | 
            +
                print('model in:', conv.get_prompt())
         | 
| 149 | 
            +
                example = tokenizer.encode_plus(f"{conv.get_prompt()}", None, max_length=None)['input_ids']
         | 
| 150 | 
            +
                example = example[1:-1]
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                return example
         | 
| 153 |  | 
| 154 | 
             
            if __name__ == "__main__":
         | 
| 155 | 
             
                print(default_conversation.get_prompt())
         | 
