import torch
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, AutoModel
import chromadb
import gradio as gr
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score

# Mean Pooling - Take attention mask into account for correct averaging
def meanpooling(output, mask):
    embeddings = output[0]  # First element of the model output is last_hidden_state (all token embeddings)
    mask = mask.unsqueeze(-1).expand(embeddings.size()).float()
    return torch.sum(embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
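# Shape sketch for reference (descriptive only, not executed): assuming the usual BERT-base
# hidden size of 768, a batch of 2 texts padded to 16 tokens gives `output[0]` of shape
# (2, 16, 768) and `mask` of shape (2, 16); the pooled result is (2, 768), one vector per text.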

# Load the dataset
dataset = load_dataset("thankrandomness/mimic-iii")

# Split the dataset into train and validation sets
split_dataset = dataset['train'].train_test_split(test_size=0.2, seed=42)
dataset = DatasetDict({
    'train': split_dataset['train'],
    'validation': split_dataset['test']
})

# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("neuml/pubmedbert-base-embeddings-matryoshka")
model = AutoModel.from_pretrained("neuml/pubmedbert-base-embeddings-matryoshka")

# Function to normalize embeddings to unit vectors (row-wise, so batched output is handled correctly)
def normalize_embedding(embedding):
    norm = np.linalg.norm(embedding, axis=-1, keepdims=True)
    norm = np.where(norm > 0, norm, 1.0)  # guard against division by zero
    return (embedding / norm).tolist()

# Function to embed and normalize text
def embed_text(text):
    inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
    with torch.no_grad():
        output = model(**inputs)
    embeddings = meanpooling(output, inputs['attention_mask'])
    normalized_embeddings = normalize_embedding(embeddings.numpy())
    return normalized_embeddings
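# Optional sanity check (illustrative input, commented out so it does not run on import):
# _probe = np.array(embed_text(["acute myocardial infarction"])[0])
# assert abs(np.linalg.norm(_probe) - 1.0) < 1e-3  # embeddings should be unit length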

# Initialize ChromaDB client
client = chromadb.Client()
collection = client.create_collection(name="pubmedbert_matryoshka_embeddings")
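# Note: because the embeddings above are unit-normalized, the collection's default L2 distance
# ranks results the same way cosine distance would. If cosine distances are wanted explicitly,
# the collection could be created with metadata={"hnsw:space": "cosine"}, and
# client.get_or_create_collection(...) would make repeated runs in one process idempotent.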

# Function to upsert data into ChromaDB
def upsert_data(dataset_split):
    for i, row in enumerate(dataset_split):
        for note in row['notes']:
            text = note.get('text', '')
            annotations_list = []
            
            for annotation in note.get('annotations', []):
                try:
                    code = annotation['code']
                    code_system = annotation['code_system']
                    description = annotation['description']
                    annotations_list.append({"code": code, "code_system": code_system, "description": description})
                except KeyError as e:
                    print(f"Skipping annotation due to missing key: {e}")

            if text and annotations_list:
                embeddings = embed_text([text])[0]

                # Upsert data, embeddings, and annotations into ChromaDB
                for j, annotation in enumerate(annotations_list):
                    collection.upsert(
                        ids=[f"note_{note['note_id']}_{j}"],
                        embeddings=[embeddings],
                        metadatas=[annotation]
                    )
            else:
                print(f"Skipping note {note['note_id']} due to missing 'text' or 'annotations'")

# Upsert training data
upsert_data(dataset['train'])
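# Note: collection.upsert() accepts parallel lists of ids/embeddings/metadatas, so the
# per-annotation calls above could be batched (one upsert per note, or per chunk of notes)
# to speed up indexing; the one-at-a-time version is kept for clarity.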

# Define retrieval function (a distance threshold could be applied here to filter weak matches)
def retrieve_relevant_text(input_text):
    input_embedding = embed_text([input_text])[0]
    results = collection.query(
        query_embeddings=[input_embedding],
        n_results=5,
        include=["metadatas", "documents", "distances"]
    )

    # ChromaDB returns distances (lower = more similar); they are surfaced here
    # under the key "similarity_score" for display.
    output = []
    for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
        output.append({
            "similarity_score": distance,
            "code": metadata['code'],
            "code_system": metadata['code_system'],
            "description": metadata['description']
        })
    return output
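# Illustrative call (hypothetical clinical text, not executed here):
# retrieve_relevant_text("Patient admitted with community acquired pneumonia and sepsis")
# -> list of up to 5 dicts: {"similarity_score": <distance>, "code": ..., "code_system": ..., "description": ...}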

# Evaluate retrieval efficiency on the validation/test set
def evaluate_efficiency(dataset_split):
    y_true = []
    y_pred = []
    total_similarity = 0
    total_items = 0

    for i, row in enumerate(dataset_split):
        for note in row['notes']:
            text = note.get('text', '')
            annotations_list = [annotation['code'] for annotation in note.get('annotations', []) if 'code' in annotation]
            
            if text and annotations_list:
                retrieved_results = retrieve_relevant_text(text)
                retrieved_codes = [result['code'] for result in retrieved_results]
                
                # Sum up similarity scores for average calculation
                for result in retrieved_results:
                    total_similarity += result['similarity_score']
                    total_items += 1

                # Ground truth and predictions, truncated per note to the same length
                # so label/prediction pairs stay aligned across notes
                n_pairs = min(len(annotations_list), len(retrieved_codes))
                y_true.extend(annotations_list[:n_pairs])
                y_pred.extend(retrieved_codes[:n_pairs])
                
                # for result in retrieved_results:
                #     print(f"  Code: {result['code']}, Similarity Score: {result['similarity_score']:.2f}")

    # Debugging output to check for mismatches and understand results
    # print("Sample y_true:", y_true[:10])
    # print("Sample y_pred:", y_pred[:10])

    if total_items > 0:
        avg_similarity = total_similarity / total_items
    else:
        avg_similarity = 0

    if len(y_true) != len(y_pred):
        min_length = min(len(y_true), len(y_pred))
        y_true = y_true[:min_length]
        y_pred = y_pred[:min_length]
    
    # Calculate metrics
    precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
    recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    
    return precision, recall, f1, avg_similarity
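# Note on the metrics above: precision/recall/F1 treat the aligned (true code, retrieved code)
# pairs as a macro-averaged multi-class problem; a set-based metric such as precision@k over
# each note's retrieved codes would be an alternative framing. avg_similarity is the mean
# ChromaDB distance (lower = closer), not an accuracy.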

# Calculate retrieval efficiency metrics
precision, recall, f1, avg_similarity = evaluate_efficiency(dataset['validation'])

# Gradio interface
def gradio_interface(input_text):
    results = retrieve_relevant_text(input_text)
    formatted_results = [
        f"Result {i + 1}:\n"
        f"Similarity Score: {result['similarity_score']:.2f}\n"
        f"Code: {result['code']}\n"
        f"Code System: {result['code_system']}\n"
        f"Description: {result['description']}\n"
        "-------------------"
        for i, result in enumerate(results)
    ]
    return "\n".join(formatted_results)

# Display retrieval efficiency metrics
# metrics = f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}"
metrics = f"Accuracy: {avg_similarity:.2f}"

with gr.Blocks() as interface:
    gr.Markdown("# Automated Medical Coding POC")
    # gr.Markdown(metrics)
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(label="Input Text")
            submit_button = gr.Button("Submit")
        with gr.Column():
            text_output = gr.Textbox(label="Retrieved Results", lines=10)
    submit_button.click(fn=gradio_interface, inputs=text_input, outputs=text_output)

interface.launch()
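# Note: interface.launch(share=True) would expose a temporary public link, and
# launch(server_name="0.0.0.0", server_port=7860) is a common choice when running
# inside a container; the defaults are kept here.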