# Provenance (Hugging Face file-view header captured with this file):
# uploaded by AhmedElTaher - commit 7b998cb "Upload 43 files" (verified)
"""
AbMelt Complete Pipeline - Enhanced Hugging Face Space Implementation
Full molecular dynamics simulation pipeline for antibody thermostability prediction with advanced visualization
"""
import gradio as gr
import os
import sys
import logging
import tempfile
import threading
import time
import json
import shutil
from pathlib import Path
import pandas as pd
import traceback
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import numpy as np
# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent / "src"))
from structure_generator import StructureGenerator
from gromacs_pipeline import GromacsPipeline, GromacsError
from descriptor_calculator import DescriptorCalculator
from ml_predictor import ThermostabilityPredictor
# Setup logging: configure the root logger for INFO output and create the
# module-level logger that the rest of this file writes to.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class EnhancedVisualization:
    """Enhanced visualization for MD analysis.

    Produces a Plotly bar chart of thermostability predictions and a
    self-contained HTML snippet with an interactive 3Dmol.js structure view.
    """

    # Metrics whose bars switch to the warning color below the low threshold.
    _THRESHOLDED_METRICS = ('tagg', 'tm')
    _LOW_COLOR = '#FF6B6B'
    _OK_COLOR = '#4ECDC4'
    _NEUTRAL_COLOR = '#45B7D1'
    _VIEWER_SCRIPT_URL = "https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"

    def __init__(self):
        pass

    def create_stability_summary_plot(self, predictions, low_threshold=60):
        """Create a summary bar chart of stability predictions.

        Args:
            predictions (dict): Maps a metric name (e.g. 'tagg', 'tmon', 'tm')
                to a dict holding a numeric 'value' in °C. Entries that are not
                dicts or whose 'value' is missing/None are skipped.
            low_threshold (float): 'tagg' and 'tm' bars below this value (°C)
                are drawn in the warning color; other metrics use a neutral color.

        Returns:
            plotly.graph_objects.Figure, or None when ``predictions`` is empty.
        """
        if not predictions:
            return None

        metrics, values, colors = [], [], []
        for metric, data in predictions.items():
            # A None value (e.g. a model that failed) would crash the
            # f"{v:.1f}" bar labels below, so such entries are skipped.
            if not isinstance(data, dict) or data.get('value') is None:
                continue
            value = data['value']
            metrics.append(metric.upper())
            values.append(value)
            if metric in self._THRESHOLDED_METRICS:
                colors.append(self._LOW_COLOR if value < low_threshold else self._OK_COLOR)
            else:
                colors.append(self._NEUTRAL_COLOR)

        fig = go.Figure()
        if values:
            fig.add_trace(go.Bar(
                x=metrics,
                y=values,
                marker_color=colors,
                text=[f"{v:.1f}°C" for v in values],
                textposition='auto'
            ))
        fig.update_layout(
            title="Antibody Thermostability Predictions",
            xaxis_title="Stability Metrics",
            yaxis_title="Temperature (°C)",
            showlegend=False,
            height=400
        )
        return fig

    def create_3d_structure_viewer(self, pdb_path, height=400):
        """Create self-contained HTML showing an interactive 3D view of a PDB file.

        The PDB text is embedded directly in the page. A path on the server's
        filesystem is not a URL the user's browser can fetch. The viewer page
        is wrapped in an ``<iframe srcdoc=...>`` because ``<script>`` elements
        inserted via innerHTML (as gr.HTML does) are not executed.

        Args:
            pdb_path (str): Path to a PDB file on this machine.
            height (int): Viewer height in pixels.

        Returns:
            str: HTML snippet, or "" when ``pdb_path`` is empty or does not exist.
        """
        from html import escape

        if not pdb_path or not os.path.exists(pdb_path):
            return ""

        with open(pdb_path, encoding="utf-8", errors="replace") as fh:
            pdb_text = fh.read()

        # json.dumps produces a valid JS string literal. Escaping "</" stops
        # the embedded data from terminating the <script> element early.
        pdb_literal = json.dumps(pdb_text).replace("</", "<\\/")

        page = f"""<!DOCTYPE html>
<html>
<head>
<script src="{self._VIEWER_SCRIPT_URL}"></script>
</head>
<body style="margin: 0;">
<div id="viewer" style="width: 100%; height: {height}px; position: relative;"></div>
<script>
let viewer = $3Dmol.createViewer('viewer');
viewer.addModel({pdb_literal}, 'pdb');
viewer.setStyle({{}}, {{cartoon: {{color: 'spectrum'}}}});
viewer.zoomTo();
viewer.render();
</script>
</body>
</html>"""

        # quote=True escapes both quote characters, so the page is safe inside
        # a double-quoted attribute.
        return (
            f'<iframe style="width: 100%; height: {height}px; border: none;" '
            f'srcdoc="{escape(page, quote=True)}"></iframe>'
        )
class AbMeltPipeline:
    """Enhanced AbMelt pipeline for HF Space with visualization.

    Chains structure generation, GROMACS system preparation and MD runs,
    descriptor calculation and ML prediction. Files offered for download are
    written to (or copied into) ``<cwd>/outputs`` so they outlive the
    components' cleanup() calls.
    """

    def __init__(self):
        self.structure_gen = StructureGenerator()
        self.predictor = None
        self.visualizer = EnhancedVisualization()
        self.current_job = None
        self.job_status = {}
        # Initialize ML predictor. On failure the error is logged and predictor
        # stays None; run_complete_pipeline then skips the prediction step.
        try:
            models_dir = Path(__file__).parent / "models"
            self.predictor = ThermostabilityPredictor(models_dir)
            logger.info("ML predictor initialized")
        except Exception as e:
            logger.error(f"Failed to initialize ML predictor: {e}")

    @staticmethod
    def _parse_temperatures(temperatures):
        """Parse a comma-separated string such as "300,350,400" into a list of ints.

        Blank entries (e.g. from a trailing comma) are ignored.

        Raises:
            ValueError: if an entry is not an integer or no temperature is given.
        """
        stripped = (part.strip() for part in temperatures.split(','))
        temp_list = [int(part) for part in stripped if part]
        if not temp_list:
            raise ValueError("At least one simulation temperature is required")
        return temp_list

    @staticmethod
    def _output_dir():
        """Return ``<cwd>/outputs``, creating it if necessary."""
        persistent_dir = os.path.join(os.getcwd(), "outputs")
        os.makedirs(persistent_dir, exist_ok=True)
        return persistent_dir

    def _persist_file(self, src_path, filename):
        """Copy ``src_path`` into the persistent output directory and return the copy's path."""
        dest_path = os.path.join(self._output_dir(), filename)
        shutil.copy2(src_path, dest_path)
        return dest_path

    @staticmethod
    def _safe_cleanup(component, label):
        """Call ``component.cleanup()``, logging instead of raising on failure.

        Cleanup is best-effort and must not mask the pipeline's own result.
        Only Exception is caught, so KeyboardInterrupt/SystemExit still propagate.
        """
        try:
            component.cleanup()
        except Exception as cleanup_error:
            logger.warning(f"{label} cleanup failed: {cleanup_error}")

    def run_complete_pipeline(self, heavy_chain, light_chain, sim_time_ns=10,
                              temperatures="300,350,400", progress_callback=None,
                              generate_visuals=True):
        """
        Run the complete AbMelt pipeline with enhanced visualization

        Args:
            heavy_chain (str): Heavy chain variable region sequence
            light_chain (str): Light chain variable region sequence
            sim_time_ns (int): Simulation time in nanoseconds
            temperatures (str): Comma-separated temperatures (blank entries ignored)
            progress_callback (callable): Function called as (percent, message)
            generate_visuals (bool): Whether to generate visualizations
        Returns:
            dict: Results with keys 'success', 'predictions', 'intermediate_files',
            'descriptors', 'visualizations', 'error' and 'logs'. Failures
            (including malformed temperatures) are reported through 'error'
            and 'logs' rather than raised.
        """
        import uuid

        results = {
            'success': False,
            'predictions': {},
            'intermediate_files': {},
            'descriptors': {},
            'visualizations': {},
            'error': None,
            'logs': []
        }

        def report(percent, message):
            # progress_callback is optional; forward progress only when given.
            if progress_callback:
                progress_callback(percent, message)

        # Timestamp plus a random suffix, so runs started in the same second
        # do not overwrite each other's files in ./outputs.
        job_id = f"job_{int(time.time())}_{uuid.uuid4().hex[:8]}"

        try:
            # Parsed inside the try so malformed input is reported in results
            # like every other failure, instead of escaping as a ValueError.
            temp_list = self._parse_temperatures(temperatures)

            report(0, "Starting AbMelt pipeline...")

            # Step 1: Generate structure (10% progress)
            report(10, "Generating antibody structure...")
            structure_path = self.structure_gen.generate_structure(
                heavy_chain, light_chain
            )
            if not structure_path or not os.path.exists(structure_path):
                raise FileNotFoundError("Structure generation failed. PDB file not created.")

            # Copy structure file to persistent location before cleanup
            persistent_structure = self._persist_file(structure_path, f"structure_{job_id}.pdb")
            results['intermediate_files']['structure'] = persistent_structure
            logger.info(f"Structure copied to persistent location: {persistent_structure}")
            results['logs'].append("✓ Structure generation completed")

            # Build the 3D view from the persistent copy; the generator's own
            # file is subject to structure_gen.cleanup() below.
            if generate_visuals:
                results['visualizations']['structure_3d'] = (
                    self.visualizer.create_3d_structure_viewer(persistent_structure)
                )

            # Step 2: Setup MD system (20% progress)
            report(20, "Preparing GROMACS molecular dynamics system...")
            md_pipeline = GromacsPipeline()
            try:
                prepared_system = md_pipeline.prepare_system(structure_path)
                # NOTE(review): this path presumably lives in md_pipeline.work_dir
                # and may not survive cleanup; it is not offered for download.
                results['intermediate_files']['prepared_system'] = prepared_system
                results['logs'].append("✓ GROMACS system preparation completed")

                # Step 3: Run MD simulations (30-80% progress)
                report(30, f"Running MD simulations at {len(temp_list)} temperatures...")
                trajectories = md_pipeline.run_md_simulations(
                    temperatures=temp_list,
                    sim_time_ns=sim_time_ns
                )

                # Copy trajectories out before md_pipeline.cleanup(); missing or
                # None paths are passed through unchanged.
                persistent_trajectories = {}
                for temp, traj_path in trajectories.items():
                    if traj_path and os.path.exists(traj_path):
                        persistent_trajectories[temp] = self._persist_file(
                            traj_path, f"trajectory_{temp}K_{job_id}.xtc"
                        )
                    else:
                        persistent_trajectories[temp] = traj_path
                results['intermediate_files']['trajectories'] = persistent_trajectories
                results['logs'].append(f"✓ MD simulations completed for {len(temp_list)} temperatures")

                # Step 4: Calculate descriptors (80-90% progress)
                report(80, "Calculating molecular descriptors...")
                descriptor_calc = DescriptorCalculator(md_pipeline.work_dir)
                # Create topology file mapping
                topology_files = {temp: os.path.join(md_pipeline.work_dir, f"md_{temp}.tpr")
                                  for temp in temp_list}
                descriptors = descriptor_calc.calculate_all_descriptors(
                    trajectories, topology_files
                )
                results['descriptors'] = descriptors
                results['logs'].append("✓ Descriptor calculation completed")

                # Export straight to the persistent directory. A CSV inside
                # md_pipeline.work_dir would be subject to the cleanup below and
                # leave the download link dangling.
                desc_csv_path = os.path.join(self._output_dir(), f"descriptors_{job_id}.csv")
                descriptor_calc.export_descriptors_csv(descriptors, desc_csv_path)
                results['intermediate_files']['descriptors_csv'] = desc_csv_path

                # Step 5: Make predictions (90-100% progress)
                report(90, "Making thermostability predictions...")
                if self.predictor:
                    predictions = self.predictor.predict_thermostability(descriptors)
                    results['predictions'] = predictions
                    results['logs'].append("✓ Thermostability predictions completed")
                    # Generate visualization plots
                    if generate_visuals:
                        results['visualizations']['stability_summary'] = (
                            self.visualizer.create_stability_summary_plot(predictions)
                        )
                else:
                    results['logs'].append("⚠ ML predictor not available")

                report(100, "Pipeline completed successfully!")
                results['success'] = True

            except GromacsError as e:
                error_msg = f"GROMACS error: {str(e)}"
                results['error'] = error_msg
                results['logs'].append(f"✗ {error_msg}")
                logger.error(error_msg)
                # Add specific check for NaN coordinate error
                if "NaN" in str(e) or "invalid coordinates" in str(e):
                    results['logs'].append(" Hint: This error is often caused by problems in the fallback "
                                           "structure generator. Check sequence validity.")
            finally:
                self._safe_cleanup(md_pipeline, "MD pipeline")

        # GromacsError is listed explicitly so it is caught even when raised
        # outside the MD block (e.g. by GromacsPipeline()).
        except (Exception, GromacsError) as e:
            error_msg = f"Pipeline error: {str(e)}"
            results['error'] = error_msg
            results['logs'].append(f"✗ {error_msg}")
            logger.error(f"Pipeline failed: {traceback.format_exc()}")
        finally:
            self._safe_cleanup(self.structure_gen, "Structure generator")

        return results
_HEADER_MD = """
# 🧬 AbMelt: Enhanced Molecular Dynamics Pipeline

**Predict antibody thermostability through multi-temperature molecular dynamics simulations**

This enhanced space implements the complete AbMelt protocol with advanced visualizations:

- Structure generation with ImmuneBuilder and 3D visualization
- Multi-temperature MD simulations with trajectory analysis
- Comprehensive descriptor calculation and plotting
- Machine learning predictions with uncertainty quantification
- Interactive visualization of results

⚠️ **Note**: Full pipeline takes 2-4 hours per antibody due to MD simulation requirements.
"""

_QUICK_PREDICTION_MD = """
### Upload Pre-calculated Descriptors

If you have already calculated MD descriptors, upload them here for quick predictions.
"""

_DOCUMENTATION_MD = """
## How to Use AbMelt

### 1. Input Requirements

- **Heavy Chain**: Variable region sequence (VH) of your antibody
- **Light Chain**: Variable region sequence (VL) of your antibody
- **Simulation Time**: 10-100 ns (longer = more accurate but slower)
- **Temperatures**: Comma-separated list (default: 300,350,400 K)

### 2. Pipeline Steps

1. **Structure Generation**: Uses ImmuneBuilder to create 3D structure
2. **System Preparation**: Sets up GROMACS molecular dynamics system
3. **MD Simulations**: Runs simulations at specified temperatures
4. **Descriptor Calculation**: Extracts molecular features from trajectories
5. **ML Prediction**: Uses trained models to predict thermostability

### 3. Output Interpretation

- **Tagg**: Aggregation onset temperature - higher is better
- **Tm,on**: On-pathway melting temperature - thermal stability indicator
- **Tm**: Overall melting temperature - global stability measure

### 4. Tips for Best Results

- Use complete variable regions (110-130 residues typical)
- Longer simulations provide more reliable predictions
- Check logs for any warnings or errors
- Download trajectories for further analysis

### Additional Resources

- [GitHub Repository](https://github.com/MSDLLCpapers/AbMelt)
- [Original Paper](https://doi.org/10.1016/j.bpj.2024.06.003)
- [GROMACS Documentation](https://manual.gromacs.org/)
- [ImmuneBuilder](https://github.com/oxpig/ImmuneBuilder)
"""

_FAQ_MD = """
## Frequently Asked Questions

**Q: How long does the full pipeline take?**
A: Approximately 15-20 minutes per nanosecond per temperature. For 10ns at 3 temperatures,
expect 2-3 hours. For 100ns simulations, expect 20-30 hours.

**Q: Can I use this for nanobodies or single-domain antibodies?**
A: Yes, but you'll need to provide both VH and VL sequences. For nanobodies,
you can duplicate the VH sequence as VL, though predictions may be less accurate.

**Q: What if I only have the full antibody sequence?**
A: You need to extract just the variable regions (VH and VL). Use IMGT/V-QUEST
or similar tools to identify variable region boundaries.

**Q: Why are my predictions very low/high?**
A: Check that:

- Sequences are correct variable regions
- No unusual amino acids or modifications
- Proper sequence length (typically 110-130 residues)

**Q: Can I compare multiple antibodies?**
A: Currently, process one at a time. Save results for comparison.
Batch processing may be added in future versions.

**Q: What does "GROMACS error" mean?**
A: Usually indicates:

- Severely malformed structure
- System preparation issues
- Memory constraints

Try reducing simulation time or contact support.

**Q: How accurate are the predictions?**
A: Based on published validation:

- Tagg: R² = 0.57 ± 0.11
- Tm,on: R² = 0.56 ± 0.01
- Tm: R² = 0.60 ± 0.06

**Q: Can I use the trajectories for other analyses?**
A: Yes! Download the trajectory files and use with VMD, PyMOL, or other MD analysis tools.
"""

_EXAMPLE_SEQUENCES = [
    [
        "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYYMHWVRQAPGQGLEWMGIINPSGGSTSYAQKFQGRVTMTRDTSTSTVYMELSSLRSEDTAVYYCARGGNSAFYSSWFAYWGQGTLVTVSS",
        "DIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIYAASSLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYSTPPITFGQGTRLEIKR"
    ],
    [
        "EVQLLESGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLEWVSAISGSGGSTYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCAKDRGYSSGWYLFDYWGQGTLVTVSS",
        "EIVLTQSPGTLSLSPGERATLSCRASQSVSSSYLAWYQQKPGQAPRLLIYGASSRATGIPDRFSGSGSGTDFTLTISRLEPEDFAVYYCQQYGSSPLTFGGGTKVEIKR"
    ]
]

_VALID_AAS = frozenset('ACDEFGHIKLMNPQRSTVWY')
_MIN_SEQUENCE_LENGTH = 50
# Number of components wired to run_button.click; must match that outputs list.
_N_PIPELINE_OUTPUTS = 14

# Rough per-stage durations (minutes) used for the UI runtime estimate.
_SETUP_MINUTES = 10
_MD_MINUTES_PER_NS_PER_TEMP = 12
_ANALYSIS_MINUTES = 5
_LONG_RUN_WARNING_MINUTES = 240


def _clean_sequence(seq):
    """Upper-case a sequence and drop ALL whitespace.

    Internal line breaks and spaces come from pasting into the multi-line
    textboxes; without this they would be reported as invalid residues.
    """
    return "".join((seq or "").split()).upper()


def _error_outputs(message):
    """Build the run_button output tuple for a rejected or failed run.

    Only the log box (index 3) and the 3D-viewer HTML (index 8, set to "")
    are filled; every other output is None.
    """
    outputs = [None] * _N_PIPELINE_OUTPUTS
    outputs[3] = message
    outputs[8] = ""
    return tuple(outputs)


def _existing_file(path):
    """Return ``path`` if it names an existing file, else None.

    This leaves the download widget empty rather than pointing at a missing file.
    """
    return path if path and os.path.isfile(path) else None


def _describe_trajectories(traj_dict):
    """Summarize trajectory files (temperature, file name, size in MB) for display."""
    if not traj_dict:
        return "No trajectory files generated"
    lines = ["Generated trajectories:"]
    for temp, traj_path in traj_dict.items():
        # Paths may be None when a simulation produced no output.
        if traj_path and os.path.exists(traj_path):
            size_mb = os.path.getsize(traj_path) / (1024 * 1024)
            lines.append(f"- {temp}K: {os.path.basename(traj_path)} ({size_mb:.1f} MB)")
    return "\n".join(lines) + "\n"


def _estimate_runtime(sim_time_val, temps_str):
    """Estimate wall-clock runtime as a "~Xh Ym" string; long runs get a warning suffix."""
    try:
        temp_count = len([t for t in temps_str.split(',') if t.strip()])
        total_md_time = sim_time_val * temp_count * _MD_MINUTES_PER_NS_PER_TEMP
        total_time = _SETUP_MINUTES + total_md_time + _ANALYSIS_MINUTES
        hours = int(total_time // 60)
        minutes = int(total_time % 60)
        time_str = f"~{hours}h {minutes}m" if hours > 0 else f"~{minutes}m"
        if total_time > _LONG_RUN_WARNING_MINUTES:
            time_str += " ⚠️ Consider using shorter simulation time"
        return time_str
    except Exception:
        return "Unable to estimate"


def create_interface():
    """Create the enhanced Gradio interface.

    Builds the Complete Pipeline, Quick Prediction, Documentation and FAQ tabs.
    All handlers share one AbMeltPipeline instance. Returns the gr.Blocks app
    (not launched).
    """
    pipeline = AbMeltPipeline()
    with gr.Blocks(title="AbMelt: Enhanced MD Pipeline", theme=gr.themes.Soft()) as demo:
        gr.Markdown(_HEADER_MD)

        with gr.Tab("🚀 Complete Pipeline"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Input Sequences")
                    heavy_chain = gr.Textbox(
                        label="Heavy Chain Variable Region",
                        placeholder="Enter VH amino acid sequence (e.g., QVQLVQSGAEVKKPG...)",
                        lines=3,
                        info="Variable region of heavy chain (VH)"
                    )
                    light_chain = gr.Textbox(
                        label="Light Chain Variable Region",
                        placeholder="Enter VL amino acid sequence (e.g., DIQMTQSPSSLSASVGDR...)",
                        lines=3,
                        info="Variable region of light chain (VL)"
                    )
                    gr.Markdown("### Simulation Parameters")
                    sim_time = gr.Slider(
                        minimum=10,
                        maximum=100,
                        value=10,
                        step=10,
                        label="Simulation time (ns)",
                        info="Longer simulations are more accurate but take more time"
                    )
                    temperatures = gr.Textbox(
                        label="Temperatures (K)",
                        value="300,350,400",
                        info="Comma-separated temperatures for MD simulations"
                    )
                    generate_visuals = gr.Checkbox(
                        label="Generate Advanced Visualizations",
                        value=True,
                        info="Create 3D structure viewer and analysis plots"
                    )
                with gr.Column(scale=1):
                    gr.Markdown("### Pipeline Progress")
                    run_button = gr.Button("🔬 Run Complete Pipeline", variant="primary")
                    gr.Markdown("### Estimated Time")
                    time_estimate = gr.Textbox(
                        label="Estimated Completion Time",
                        value="Not calculated",
                        interactive=False
                    )

            with gr.Row():
                gr.Markdown("### 📊 Results")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Thermostability Predictions")
                    tagg_result = gr.Number(
                        label="Tagg - Aggregation Temperature (°C)",
                        info="Temperature at which aggregation begins",
                        interactive=False
                    )
                    tmon_result = gr.Number(
                        label="Tm,on - Melting Temperature On-pathway (°C)",
                        info="On-pathway melting temperature",
                        interactive=False
                    )
                    tm_result = gr.Number(
                        label="Tm - Overall Melting Temperature (°C)",
                        info="Overall thermal melting temperature",
                        interactive=False
                    )
                with gr.Column():
                    gr.Markdown("#### Pipeline Logs")
                    pipeline_logs = gr.Textbox(
                        label="Execution Log",
                        lines=8,
                        info="Real-time pipeline progress and status",
                        interactive=False
                    )

            with gr.Row():
                gr.Markdown("### 📈 Visualization")
            stability_plot = gr.Plot(label="Stability Summary")

            with gr.Row():
                gr.Markdown("### 📁 Download Results")
            with gr.Row():
                structure_download = gr.File(
                    label="Generated Structure (PDB)"
                )
                descriptors_download = gr.File(
                    label="Calculated Descriptors (CSV)"
                )
                trajectory_info = gr.Textbox(
                    label="Trajectory Information",
                    interactive=False
                )

            with gr.Row():
                gr.Markdown("### 🔬 3D Structure Viewer")
            structure_viewer = gr.HTML(label="Interactive 3D Structure")

            with gr.Row():
                gr.Markdown("### 📊 Analysis Plots")
            with gr.Row():
                rmsd_plot = gr.Plot(label="RMSD Analysis")
                flexibility_plot = gr.Plot(label="Flexibility Heatmap")
            with gr.Row():
                rg_plot = gr.Plot(label="Radius of Gyration")
                contact_plot = gr.Plot(label="Contact Maps")
            with gr.Row():
                trajectory_snapshots = gr.Plot(label="Trajectory Snapshots")

        with gr.Tab("⚡ Quick Prediction"):
            gr.Markdown(_QUICK_PREDICTION_MD)
            descriptor_upload = gr.File(
                label="Upload Descriptor CSV",
                file_types=[".csv"]
            )
            quick_predict_btn = gr.Button("🔮 Make Quick Prediction", variant="secondary")
            with gr.Row():
                quick_tagg = gr.Number(label="Tagg (°C)", interactive=False)
                quick_tmon = gr.Number(label="Tm,on (°C)", interactive=False)
                quick_tm = gr.Number(label="Tm (°C)", interactive=False)
            quick_summary_plot = gr.Plot(label="Quick Prediction Summary")

        with gr.Tab("📚 Documentation"):
            gr.Markdown(_DOCUMENTATION_MD)

        with gr.Tab("❓ FAQ"):
            gr.Markdown(_FAQ_MD)

        # Event handlers
        def update_time_estimate(sim_time_val, temps_str):
            """Calculate and update time estimate"""
            return _estimate_runtime(sim_time_val, temps_str)

        def run_pipeline_wrapper(heavy, light, sim_time_val, temps_str, gen_visuals):
            """Validate the inputs, run the pipeline and map its results onto the UI outputs."""
            if not heavy or not light:
                return _error_outputs("❌ Error: Both heavy and light chain sequences are required")

            heavy_clean = _clean_sequence(heavy)
            light_clean = _clean_sequence(light)

            if len(heavy_clean) < _MIN_SEQUENCE_LENGTH or len(light_clean) < _MIN_SEQUENCE_LENGTH:
                return _error_outputs(
                    "❌ Error: Sequences too short. Provide complete variable regions (>50 residues)"
                )
            if not set(heavy_clean) <= _VALID_AAS:
                return _error_outputs("❌ Error: Invalid characters in heavy chain sequence")
            if not set(light_clean) <= _VALID_AAS:
                return _error_outputs("❌ Error: Invalid characters in light chain sequence")

            def progress_callback(percent, message):
                # Results reach the UI only when the run finishes, so progress
                # is written to the server log.
                logger.info(f"[{percent}%] {message}")

            try:
                results = pipeline.run_complete_pipeline(
                    heavy_clean, light_clean, sim_time_val, temps_str,
                    progress_callback, gen_visuals
                )

                predictions = results.get('predictions') or {}
                files = results.get('intermediate_files') or {}
                visuals = results.get('visualizations') or {}
                logs = "\n".join(results.get('logs', []))
                if results.get('error'):
                    logs += f"\n❌ {results['error']}"

                # Order must match the outputs list wired to run_button.click.
                return (
                    predictions.get('tagg', {}).get('value'),
                    predictions.get('tmon', {}).get('value'),
                    predictions.get('tm', {}).get('value'),
                    logs,
                    visuals.get('stability_summary'),
                    _existing_file(files.get('structure')),
                    _existing_file(files.get('descriptors_csv')),
                    _describe_trajectories(files.get('trajectories')),
                    visuals.get('structure_3d', ""),
                    visuals.get('rmsd_plot'),
                    visuals.get('flexibility_heatmap'),
                    visuals.get('rg_plot'),
                    visuals.get('contact_maps'),
                    visuals.get('trajectory_snapshots'),
                )
            except Exception as e:
                logger.error(f"Pipeline wrapper failed: {traceback.format_exc()}")
                return _error_outputs(f"❌ Pipeline failed: {str(e)}")

        def quick_prediction(desc_file):
            """Handle quick prediction from an uploaded descriptor CSV (first row is used)."""
            if desc_file is None:
                return None, None, None, None
            try:
                # Depending on the Gradio version, gr.File passes either a
                # tempfile-like object (with .name) or a plain filepath string.
                csv_path = getattr(desc_file, "name", desc_file)
                df = pd.read_csv(csv_path)
                if len(df) == 0:
                    return None, None, None, None
                descriptors = df.iloc[0].to_dict()

                if not pipeline.predictor:
                    return None, None, None, None
                predictions = pipeline.predictor.predict_thermostability(descriptors)
                summary_plot = pipeline.visualizer.create_stability_summary_plot(predictions)
                return (
                    predictions.get('tagg', {}).get('value'),
                    predictions.get('tmon', {}).get('value'),
                    predictions.get('tm', {}).get('value'),
                    summary_plot,
                )
            except Exception as e:
                logger.error(f"Quick prediction failed: {e}")
                return None, None, None, None

        # Connect event handlers
        for estimate_trigger in (sim_time, temperatures):
            estimate_trigger.change(
                update_time_estimate,
                inputs=[sim_time, temperatures],
                outputs=time_estimate
            )

        run_button.click(
            run_pipeline_wrapper,
            inputs=[heavy_chain, light_chain, sim_time, temperatures, generate_visuals],
            outputs=[
                tagg_result, tmon_result, tm_result,  # predictions
                pipeline_logs,  # logs
                stability_plot,  # stability summary
                structure_download, descriptors_download,  # downloads
                trajectory_info,  # trajectory info
                structure_viewer,  # 3D viewer
                rmsd_plot, flexibility_plot, rg_plot,  # analysis plots
                contact_plot, trajectory_snapshots  # more plots
            ]
        )

        quick_predict_btn.click(
            quick_prediction,
            inputs=descriptor_upload,
            outputs=[quick_tagg, quick_tmon, quick_tm, quick_summary_plot]
        )

        # Add example sequences
        gr.Examples(
            examples=_EXAMPLE_SEQUENCES,
            inputs=[heavy_chain, light_chain],
            label="Example Antibody Sequences"
        )

    return demo
if __name__ == "__main__":
    # Create and launch the enhanced interface
    demo = create_interface()

    # Configure for HF Space or local deployment
    if os.getenv("SPACE_ID"):
        # Running in a HF Space: the Space already serves the app publicly,
        # and Gradio does not support share links there.
        demo.queue(max_size=3)
        demo.launch()
    else:
        # Running locally. NOTE(review): share=True together with 0.0.0.0
        # exposes the app publicly; confirm this is intended.
        demo.queue(max_size=5)
        demo.launch(
            share=True,
            server_name="0.0.0.0",
            server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
        )