# Gradio demo: generate clinical-trial eligibility criteria with a Llama-3 model
# fine-tuned via a PEFT (LoRA) adapter and loaded/run with unsloth.
import os
from threading import Thread

import gradio as gr
from unsloth import FastLanguageModel
from transformers import TextIteratorStreamer

# Hugging Face tokens (environment variables / Space secrets):
# token_r needs read access to the adapter repo, token_w write access for the feedback dataset.
token_r = os.environ["token_r"]
token_w = os.environ["token_w"]

# Base model plus the fine-tuned PEFT (LoRA) adapter on the Hugging Face Hub.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
peft_model_adapter_id = "nttwt1597/test_v2_cancer_v4_checkpoint2900"

# Load the base model in 4-bit and attach the fine-tuned adapter on top of it.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_id,
    max_seq_length = 4096,
    dtype = None,
    load_in_4bit = True,
)
model.load_adapter(peft_model_adapter_id, token=token_r)

# Stop generation at the EOS token or at Llama-3's end-of-turn token <|eot_id|>.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

# Switch unsloth into its optimized inference mode.
FastLanguageModel.for_inference(model)

criteria_prompt = """Based on the provided instructions and clinical trial information, generate the eligibility criteria for the study.

### Instruction:
As a clinical researcher, generate comprehensive eligibility criteria to be used in clinical research based on the given clinical trial information. Ensure the criteria are clear, specific, and suitable for a clinical research setting.

### Clinical trial information:
{}

### Eligibility criteria:
{}"""

def format_prompt(text):
    return criteria_prompt.format(text, "")

def run_model_on_text(text):
    prompt = format_prompt(text)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Stream generated tokens back to the UI as they are produced.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        eos_token_id=terminators,
        max_new_tokens=1024,
        repetition_penalty=1.175,
    )
    # Run generation in a background thread so the streamer can be consumed here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
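
# The two strings below feed the input textbox: place_holder is a fully worked example trial shown
# as placeholder (ghost) text, and prefilled_value is the blank template that pre-populates the box.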

place_holder = f"""Study Objectives
The purpose of this study is to evaluate the safety, tolerance and efficacy of Liposomal Paclitaxel With Nedaplatin as First-line in patients with Advanced or Recurrent Esophageal Carcinoma 

Conditions: Esophageal Carcinoma 

Intervention / Treatment: 
DRUG: Liposomal Paclitaxel, 
DRUG: Nedaplatin 

Location: China 

Study Design and Phases 
Study Type: INTERVENTIONAL 
Phase: PHASE2
Primary Purpose: TREATMENT
Allocation: NA
Interventional Model: SINGLE_GROUP
Masking: NONE
"""

prefilled_value = """Study Objectives
[Brief Summary] and/or [Detailed Description]

Conditions: [Disease]

Intervention / Treatment 
[DRUGs]

Location
[Location]

Study Design and Phases
Study Type:
Phase:
Primary Purpose:
Allocation: 
Interventional Model:
Masking:"""

   
hf_writer = gr.HuggingFaceDatasetSaver(hf_token=token_w, dataset_name="criteria-feedback-demo", private=True)
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(
                label="Research Information",
                placeholder=place_holder,
                value=prefilled_value,
                lines=10)
            submit_button = gr.Button("Generate")
        with gr.Column():
            output_box = gr.Textbox(
                label="Eligiblecriteria Criteria",
                lines=21,
                interactive=False)
    with gr.Row():
        with gr.Column():
            feedback_box = gr.Textbox(label="Feedback", placeholder="Enter your feedback here...", lines=3, interactive=True)
            feedback_button = gr.Button("Submit Feedback")
            status_text = gr.Textbox(label="Status", lines=1, interactive=False)

    submit_button.click(
        run_model_on_text,
        inputs=prompt_box,
        outputs=output_box
    )
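    # run_model_on_text is a generator, so the click handler streams partial output into output_box.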

    # Tell the dataset saver which components it logs before flag() can be used.
    hf_writer.setup([prompt_box, output_box, feedback_box], "flagged")

    def submit_feedback(prompt, generated_text, feedback):
        # flag() expects the values as a list, in the same order as the components given to setup().
        hf_writer.flag([prompt, generated_text, feedback])
        return "Feedback submitted."

    feedback_button.click(
        submit_feedback,
        inputs=[prompt_box, output_box, feedback_box],
        outputs=status_text
    )
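    # Each flagged row contains the prompt, the generated criteria, and the user's feedback text.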


demo.launch()
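# Note: on some Gradio versions, streaming generator outputs only updates incrementally when the
# queue is enabled, e.g. demo.queue().launch() instead of a plain demo.launch().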


