# medchat/data_loader.py
"""
Module to load and prepare medical data from Hugging Face
"""
import re

import pandas as pd
from datasets import load_dataset


def clean_text(text):
    """Clean and preprocess text."""
    if pd.isna(text):
        return ""
    # Collapse runs of whitespace into single spaces
    text = re.sub(r'\s+', ' ', str(text))
    # Remove special characters but keep punctuation common in medical text
    text = re.sub(r'[^\w\s\.\,\?\!\-\:]', '', text)
    return text.strip()
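
# Illustrative example of clean_text (whitespace is collapsed and characters
# outside the allowed set, such as parentheses and slashes, are dropped):
#   clean_text("Take 5 mg  (oral),  twice daily.")  ->  "Take 5 mg oral, twice daily."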


def load_medical_datasets():
    """
    Load medical datasets from the Hugging Face MultiMedQA collection.
    Reference: https://huggingface.co/collections/openlifescienceai/multimedqa

    Returns a list of medical text documents.
    """
    print("Loading MultiMedQA datasets from Hugging Face...")
    print("Source: https://huggingface.co/collections/openlifescienceai/multimedqa")
    documents = []

    # Medical datasets to pull from Hugging Face, with a per-dataset item limit.
    # Reference: https://huggingface.co/collections/openlifescienceai/multimedqa
    # Reference: https://huggingface.co/collections/openlifescienceai/life-science-health-and-medical-models-for-ml
    datasets_to_load = [
        # MMLU medical subsets
        ("openlifescienceai/mmlu_clinical_knowledge", 299),
        ("openlifescienceai/mmlu_college_medicine", 200),
        ("openlifescienceai/mmlu_college_biology", 165),
        ("openlifescienceai/mmlu_professional_medicine", 308),
        ("openlifescienceai/mmlu_anatomy", 154),
        ("openlifescienceai/mmlu_medical_genetics", 116),
        # Medical QA datasets
        ("openlifescienceai/pubmedqa", 2000),
        ("openlifescienceai/medmcqa", 5000),
        ("openlifescienceai/medqa", 2000),
        # Additional medical datasets
        ("bigbio/medical_questions_pairs", 1000),
        ("luffycodes/medical_textbooks", 1000),
        ("Clinical-AI-Apollo/medical-knowledge", 1000),
        # Medical note datasets
        ("iampiccardo/medical_consultations", 1000),
        ("medalpaca/medical_meadow_mmmlu", 1000),
        # Wikipedia medical datasets
        ("sentence-transformers/wikipedia-sections", 500),
    ]
    for dataset_name, limit in datasets_to_load:
        try:
            print(f"\nLoading {dataset_name}...")
            # Try the common split names until one yields data
            dataset = None
            for split_name in ['train', 'test', 'validation', 'all']:
                try:
                    if split_name == 'all':
                        # Concatenate all splits, then cap at the limit.
                        # (Slicing "train+test+validation[:n]" would only
                        # slice the last split, not the concatenation.)
                        dataset = load_dataset(dataset_name, split="train+test+validation")
                        dataset = dataset.select(range(min(limit, len(dataset))))
                    else:
                        dataset = load_dataset(dataset_name, split=f"{split_name}[:{limit}]")
                    break
                except Exception:
                    continue
            if dataset is None:
                print(f"  Could not load any data from {dataset_name}")
                continue
            for item in dataset:
                # Extract question and answer based on dataset structure
                question = ""
                answer = ""
                context = ""
                # Handle the different field names used across datasets
                if 'question' in item:
                    question = str(item.get('question', ''))
                if 'answer' in item:
                    answer = str(item.get('answer', ''))
                if 'input' in item:
                    question = str(item.get('input', ''))
                if 'target' in item:
                    answer = str(item.get('target', ''))
                if 'final_decision' in item:
                    answer = str(item.get('final_decision', ''))
                if 'exp' in item and not answer:
                    answer = str(item.get('exp', ''))
                if 'text' in item and not question:
                    question = str(item.get('text', ''))
                if 'context' in item and not answer:
                    answer = str(item.get('context', ''))
                if 'label' in item and not answer:
                    answer = str(item.get('label', ''))
                # Handle MMLU/MedMCQA-style multiple-choice options
                if 'options' in item:
                    options = item.get('options', [])
                    if isinstance(options, list) and len(options) >= 2:
                        options_str = f"Choices: {' | '.join(str(o) for o in options)}"
                        answer = f"{answer} {options_str}" if answer else options_str
                    elif isinstance(options, dict):
                        options_str = ", ".join(f"{k}: {v}" for k, v in options.items())
                        answer = f"{answer} {options_str}" if answer else options_str
                # Prepend the correct option for multiple-choice items.
                # Note: some datasets store 'cop' as an integer index, which
                # may legitimately be 0, so avoid a plain truthiness check.
                if 'cop' in item and answer:
                    cop = item.get('cop', None)
                    if cop is not None and cop != '':
                        answer = f"Correct answer: {cop}. {answer}"
                # Combine question and answer into one retrievable passage
                if question and answer:
                    context = f"Question: {question}\n\nAnswer: {answer}"
                elif question:
                    context = f"Question: {question}"
                elif answer:
                    context = f"Medical Information: {answer}"
                else:
                    continue
                context = clean_text(context)
                if context and len(context) > 20:  # Filter out very short texts
                    documents.append({
                        'text': context,
                        'source': dataset_name.split('/')[-1],
                        'metadata': {
                            'question': question[:200] if question else '',
                            'answer': answer[:200] if answer else '',
                            'type': dataset_name.split('/')[-1]
                        }
                    })
print(f"✓ Loaded {dataset_name.split('/')[-1]}: {len([d for d in documents if d['source'] == dataset_name.split('/')[-1]])} items")
except Exception as e:
print(f"✗ Error loading {dataset_name}: {e}")
continue
print(f"\n{'='*50}")
print(f"Successfully loaded {len(documents)} total medical documents")
print(f"{'='*50}\n")
return documents


def chunk_text(text, chunk_size=512, overlap=50):
    """
    Split text into overlapping word chunks for better retrieval.
    """
    # Guard against a non-advancing (infinite) loop
    if chunk_size <= overlap:
        raise ValueError("chunk_size must be greater than overlap")
    words = text.split()
    chunks = []
    for i in range(0, len(words), chunk_size - overlap):
        chunk = ' '.join(words[i:i + chunk_size])
        chunks.append(chunk)
        if i + chunk_size >= len(words):
            break
    return chunks
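

# Minimal usage sketch (assumes network access to the Hugging Face Hub and
# that the listed datasets are publicly available; anything that fails to
# load is simply skipped by load_medical_datasets).
if __name__ == "__main__":
    docs = load_medical_datasets()
    # Chunk every document for downstream indexing/retrieval
    all_chunks = []
    for doc in docs:
        for chunk in chunk_text(doc['text'], chunk_size=512, overlap=50):
            all_chunks.append({'text': chunk, 'source': doc['source']})
    print(f"Prepared {len(all_chunks)} chunks from {len(docs)} documents")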