File size: 1,922 Bytes
97d9cf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

def organize_messages(message, history):
    """Build a chat message list from the current message and prior turns.

    Args:
        message: Text of the newest user turn; it is appended last.
        history: Iterable of ``(user, assistant)`` pairs for earlier turns.
            When ``assistant`` is falsy (e.g. ``None`` or ``""``) only the
            user entry for that turn is emitted.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts that begins with a
        fixed system prompt.
    """
    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    for user_text, reply_text in history:
        messages.append({"role": "user", "content": user_text})
        # A turn may not have a reply yet; skip the assistant entry then.
        if reply_text:
            messages.append({"role": "assistant", "content": reply_text})
    messages.append({"role": "user", "content": message})
    return messages

def stream2display_text(stream_text, token_per_sec):
    """Render streamed model output as Markdown for display.

    Text that does not start with a ``<think>`` tag is returned unchanged.
    Otherwise the text inside the think section is rendered as a ``>``
    blockquote. Any answer after the first ``</think>`` follows it, with
    ``<|im_end|>`` markers removed. A code-formatted throughput line comes
    last.

    Args:
        stream_text: Text accumulated from the stream so far; the closing
            ``</think>`` may not have arrived yet.
        token_per_sec: Generation speed, shown with two decimals.

    Returns:
        The Markdown string to display.
    """
    # Restore a missing leading "<" when the text begins with "think>"
    # (presumably the "<" was dropped upstream; TODO confirm).
    if stream_text.startswith("think>"):
        stream_text = f"<{stream_text}"

    if not stream_text.startswith("<think>"):
        return stream_text

    # partition() splits on the FIRST "</think>" only. The previous two-way
    # unpacking of split() raised ValueError when the tag occurred more than
    # once. With no closing tag yet, result_text is simply "".
    think_text, _, result_text = stream_text.partition("</think>")
    think_text = think_text.replace("<think>", "")
    result_text = result_text.replace("<|im_end|>", "")

    # Quote every reasoning line; blank lines become a bare ">" so the
    # blockquote stays continuous.
    think_block = "\n".join(
        f"> {line}" if line else ">" for line in think_text.rstrip().splitlines()
    )

    display_parts = [think_block]
    if result_text:
        display_parts.append(result_text)
    display_parts.append(f"```{token_per_sec:.2f} token/s```")

    return "\n\n".join(display_parts)

def mtp_new_tokens(pred_ids, gen_tk_count, existing_tk_count, stop_token_ids):
    """Return newly generated tokens, truncated at the first stop token.

    Args:
        pred_ids: Predicted id sequences; only ``pred_ids[0]`` is read.
        gen_tk_count: Number of generated tokens already consumed by earlier
            calls; the returned tokens start after them.
        existing_tk_count: Number of leading ids in ``pred_ids[0]`` that are
            skipped before generation output begins (presumably the prompt
            length; confirm with caller).
        stop_token_ids: Ids that end generation; a falsy value disables
            truncation.

    Returns:
        ``(new_tokens, total)``:
        - ``total`` is the length of the generated output after stop-token
          truncation.
        - ``new_tokens`` is that output sliced from ``gen_tk_count``.
    """
    output_ids = pred_ids[0][existing_tk_count:]

    if stop_token_ids:
        # Only the first stop token matters, so stop scanning at the first hit.
        # Membership is tested against stop_token_ids as given (not a set):
        # elements with identity-based hashing (e.g. tensor scalars) would
        # otherwise fail to match by value.
        first_stop = next(
            (idx for idx, tok in enumerate(output_ids) if tok in stop_token_ids),
            None,
        )
        if first_stop is not None:
            output_ids = output_ids[:first_stop]

    new_tokens = output_ids[gen_tk_count:]
    return new_tokens, len(output_ids)