Spaces:
Sleeping
Sleeping
| """ | |
| AbMelt Complete Pipeline - Hugging Face Space Implementation | |
| Full molecular dynamics simulation pipeline for antibody thermostability prediction | |
| """ | |
| import gradio as gr | |
| import os | |
| import sys | |
| import logging | |
| import tempfile | |
| import threading | |
| import time | |
| import json | |
| from pathlib import Path | |
| import pandas as pd | |
| import traceback | |
| # Add src to path for imports | |
| sys.path.insert(0, str(Path(__file__).parent / "src")) | |
| from structure_generator import StructureGenerator | |
| from gromacs_pipeline import GromacsPipeline, GromacsError | |
| from descriptor_calculator import DescriptorCalculator | |
| from ml_predictor import ThermostabilityPredictor | |
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class AbMeltPipeline:
    """Complete AbMelt pipeline for HF Space.

    Runs structure generation, GROMACS MD at several temperatures, descriptor
    calculation and ML thermostability prediction for one VH/VL pair through
    the single entry point :meth:`run_complete_pipeline`.
    """

    def __init__(self):
        self.structure_gen = StructureGenerator()
        self.predictor = None
        # Not read anywhere in the visible code; kept for external callers.
        self.current_job = None
        self.job_status = {}

        # The ML predictor is optional: if the models cannot be loaded the
        # pipeline still runs and reports that predictions are unavailable.
        try:
            models_dir = Path(__file__).parent / "models"
            self.predictor = ThermostabilityPredictor(models_dir)
            logger.info("ML predictor initialized")
        except Exception as e:
            logger.error("Failed to initialize ML predictor: %s", e, exc_info=True)

    @staticmethod
    def _parse_temperatures(temperatures):
        """Return the integer temperatures listed in a comma-separated string.

        Blank entries (e.g. from a trailing comma) are skipped; ``int`` itself
        tolerates surrounding whitespace.

        Raises:
            ValueError: if an entry is not an integer or no temperature is given.
        """
        temp_list = [int(part) for part in temperatures.split(',') if part.strip()]
        if not temp_list:
            raise ValueError("At least one simulation temperature is required")
        return temp_list

    @staticmethod
    def _preserve_file(files, key, dest_dir):
        """Copy ``files[key]`` into *dest_dir* and repoint the entry at the copy.

        NOTE(review): the ``cleanup()`` implementations are not visible here; if
        they remove their working directories, the paths handed back to the
        caller would dangle. Copying these small artifacts keeps them usable
        either way. Trajectories are deliberately not copied (the Space notes
        ~2GB of trajectory data), so their paths may still dangle after cleanup.
        The copies are not removed here -- TODO confirm temp-dir housekeeping.

        Leaves the entry untouched when there is nothing to copy or the copy fails.
        """
        import shutil  # only needed for artifact preservation

        src = files.get(key)
        if dest_dir is None or not isinstance(src, (str, os.PathLike)) or not os.path.isfile(src):
            return
        try:
            files[key] = shutil.copy2(src, os.path.join(dest_dir, os.path.basename(src)))
        except OSError:
            logger.warning("Could not preserve %s artifact %s", key, src, exc_info=True)

    @staticmethod
    def _safe_cleanup(resource, label):
        """Call ``resource.cleanup()`` best-effort: failures are logged, never raised."""
        try:
            resource.cleanup()
        except Exception:  # deliberate best-effort; must not mask the pipeline result
            logger.warning("Cleanup of %s failed", label, exc_info=True)

    def run_complete_pipeline(self, heavy_chain, light_chain, sim_time_ns=10,
                              temperatures="300,350,400", progress_callback=None):
        """
        Run the complete AbMelt pipeline.

        Args:
            heavy_chain (str): Heavy chain variable region sequence
            light_chain (str): Light chain variable region sequence
            sim_time_ns (int): Simulation time in nanoseconds
            temperatures (str): Comma-separated temperatures (blank entries ignored)
            progress_callback (callable): Optional ``f(percent, message)`` used
                to report progress

        Returns:
            dict: ``success`` (bool), ``predictions``, ``intermediate_files``,
            ``descriptors``, ``error`` (str or None) and ``logs`` (list of str).
            Failures -- including a malformed ``temperatures`` string -- are
            reported through ``error``/``logs`` rather than raised.
        """
        results = {
            'success': False,
            'predictions': {},
            'intermediate_files': {},
            'descriptors': {},
            'error': None,
            'logs': []
        }

        def report(percent, message):
            # The callback is optional; silently skip progress when absent.
            if progress_callback:
                progress_callback(percent, message)

        artifact_dir = None
        try:
            # Parsed inside the try so bad input is reported via `results`.
            temp_list = self._parse_temperatures(temperatures)
            artifact_dir = tempfile.mkdtemp(prefix="abmelt_")

            report(0, "Starting AbMelt pipeline...")

            # Step 1: Generate structure (10% progress)
            report(10, "Generating antibody structure with ImmuneBuilder...")
            structure_path = self.structure_gen.generate_structure(
                heavy_chain, light_chain
            )
            results['intermediate_files']['structure'] = structure_path
            results['logs'].append("β Structure generation completed")

            # Step 2: Setup MD system (20% progress)
            report(20, "Preparing GROMACS molecular dynamics system...")
            md_pipeline = GromacsPipeline()
            try:
                prepared_system = md_pipeline.prepare_system(structure_path)
                results['intermediate_files']['prepared_system'] = prepared_system
                results['logs'].append("β GROMACS system preparation completed")

                # Step 3: Run MD simulations (30-80% progress)
                report(30, f"Running MD simulations at {len(temp_list)} temperatures...")
                trajectories = md_pipeline.run_md_simulations(
                    temperatures=temp_list,
                    sim_time_ns=sim_time_ns
                )
                results['intermediate_files']['trajectories'] = trajectories
                results['logs'].append(f"β MD simulations completed for {len(temp_list)} temperatures")

                # Step 4: Calculate descriptors (80-90% progress)
                report(80, "Calculating molecular descriptors...")
                descriptor_calc = DescriptorCalculator(md_pipeline.work_dir)
                # Topology file per temperature; assumes run_md_simulations
                # writes md_<T>.tpr into work_dir -- TODO confirm.
                topology_files = {
                    temp: os.path.join(md_pipeline.work_dir, f"md_{temp}.tpr")
                    for temp in temp_list
                }
                descriptors = descriptor_calc.calculate_all_descriptors(
                    trajectories, topology_files
                )
                results['descriptors'] = descriptors
                results['logs'].append("β Descriptor calculation completed")

                # Export descriptors
                desc_csv_path = os.path.join(md_pipeline.work_dir, "descriptors.csv")
                descriptor_calc.export_descriptors_csv(descriptors, desc_csv_path)
                results['intermediate_files']['descriptors_csv'] = desc_csv_path

                # Step 5: Make predictions (90-100% progress)
                report(90, "Making thermostability predictions...")
                if self.predictor:
                    results['predictions'] = self.predictor.predict_thermostability(descriptors)
                    results['logs'].append("β Thermostability predictions completed")
                else:
                    results['logs'].append("β ML predictor not available")

                report(100, "Pipeline completed successfully!")
                results['success'] = True
            except GromacsError as e:
                error_msg = f"GROMACS error: {str(e)}"
                results['error'] = error_msg
                results['logs'].append(f"β {error_msg}")
                logger.error(error_msg)
            finally:
                # Save outputs living in the MD work dir before it is cleaned up.
                self._preserve_file(results['intermediate_files'], 'descriptors_csv', artifact_dir)
                self._safe_cleanup(md_pipeline, "GROMACS pipeline")
        except Exception as e:
            error_msg = f"Pipeline error: {str(e)}"
            results['error'] = error_msg
            results['logs'].append(f"β {error_msg}")
            logger.exception("Pipeline failed")
        finally:
            # Save the structure before the generator's cleanup runs.
            self._preserve_file(results['intermediate_files'], 'structure', artifact_dir)
            self._safe_cleanup(self.structure_gen, "structure generator")

        return results
def create_interface():
    """Create the Gradio interface.

    Builds a Blocks app with three tabs -- the complete sequence-to-prediction
    pipeline, a quick prediction from an uploaded descriptor CSV, and an
    information page -- and wires their handlers to one shared AbMeltPipeline.

    Returns:
        The assembled ``gr.Blocks`` demo (not yet queued or launched).
    """
    pipeline = AbMeltPipeline()

    # Defaults shared by the input widgets and the initial time estimate.
    default_sim_time = 10
    default_temperatures = "300,350,400"

    def _prediction_values(predictions):
        """Return the (tagg, tmon, tm) 'value' entries of a predictions dict, None when absent."""
        predictions = predictions or {}
        return tuple(predictions.get(key, {}).get('value') for key in ('tagg', 'tmon', 'tm'))

    def _existing_file(path):
        """Return *path* only if it names an existing file, else None.

        A cleanup step may have removed the file; a missing path handed to a
        gr.File output would presumably fail the whole update -- TODO confirm.
        """
        if path and isinstance(path, (str, os.PathLike)) and os.path.isfile(path):
            return path
        return None

    def update_time_estimate(sim_time_val, temps_str):
        """Return a rough wall-clock estimate such as '~8h 0m'.

        Heuristic: 15 minutes per simulated ns per temperature, plus 30 minutes
        of fixed overhead. Rounded to whole minutes so float slider values do
        not render as '8.0h 0.0m'.
        """
        try:
            temp_count = len([t for t in temps_str.split(',') if t.strip()])
            total_minutes = int(round(sim_time_val * temp_count * 15)) + 30
        except (AttributeError, TypeError, ValueError, OverflowError):
            return "Unable to estimate"
        hours, minutes = divmod(total_minutes, 60)
        if hours > 0:
            return f"~{hours}h {minutes}m"
        return f"~{minutes}m"

    def run_pipeline_wrapper(heavy, light, sim_time_val, temps_str):
        """Validate inputs, run the full pipeline and map results to UI outputs.

        Returns a 7-tuple matching the run button's outputs: (tagg, tmon, tm,
        log text, structure file, descriptor CSV, trajectory info).
        """
        # Pasted sequences often contain line breaks or spaces; residues never do.
        heavy = "".join((heavy or "").split())
        light = "".join((light or "").split())

        # Validate inputs
        if not heavy or not light:
            return (
                None, None, None,  # predictions
                "β Error: Both heavy and light chain sequences are required",  # logs
                None, None, None  # files
            )
        if len(heavy) < 50 or len(light) < 50:
            return (
                None, None, None,
                "β Error: Sequences seem too short. Please provide complete variable regions (>50 residues each)",
                None, None, None
            )

        def progress_callback(percent, message):
            # Progress is not streamed to the UI; record it in the server log.
            logger.info("[%s%%] %s", percent, message)

        try:
            results = pipeline.run_complete_pipeline(
                heavy, light, sim_time_val, temps_str, progress_callback
            )

            log_lines = list(results.get('logs', []))
            error = results.get('error')
            # run_complete_pipeline normally logs its error already; don't echo it twice.
            if error and not any(error in line for line in log_lines):
                log_lines.append(f"β {error}")
            logs = "\n".join(log_lines)

            files = results.get('intermediate_files', {})
            trajectories = files.get('trajectories')
            traj_info = f"Generated {len(trajectories)} trajectory files" if trajectories else None

            return (
                *_prediction_values(results.get('predictions')),  # predictions
                logs,  # pipeline logs
                _existing_file(files.get('structure')),
                _existing_file(files.get('descriptors_csv')),
                traj_info  # files
            )
        except Exception as e:
            logger.exception("Pipeline wrapper failed")
            return (
                None, None, None,  # predictions
                f"β Pipeline failed: {str(e)}",  # logs
                None, None, None  # files
            )

    def quick_prediction(desc_file):
        """Predict (tagg, tmon, tm) from the first row of an uploaded descriptor CSV."""
        if desc_file is None:
            # Only three outputs exist, so surface the message as a UI error
            # instead of returning a fourth value.
            raise gr.Error("Please upload a descriptor CSV file")
        # Older Gradio passes a tempfile wrapper (.name); newer passes a path string.
        csv_path = getattr(desc_file, "name", desc_file)
        try:
            df = pd.read_csv(csv_path)
            descriptors = df.iloc[0].to_dict()  # Use first row
            if not pipeline.predictor:
                logger.warning("Quick prediction requested but the ML predictor is not available")
                return None, None, None
            return _prediction_values(pipeline.predictor.predict_thermostability(descriptors))
        except Exception:
            logger.exception("Quick prediction failed")
            return None, None, None

    with gr.Blocks(title="AbMelt: Complete MD Pipeline", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 𧬠AbMelt: Complete Molecular Dynamics Pipeline

        **Predict antibody thermostability through multi-temperature molecular dynamics simulations**

        This space implements the complete AbMelt protocol from sequence to thermostability predictions:

        - Structure generation with ImmuneBuilder
        - Multi-temperature MD simulations (300K, 350K, 400K)
        - Comprehensive descriptor calculation
        - Machine learning predictions for Tagg, Tm,on, and Tm

        β οΈ **Note**: Full pipeline takes 2-4 hours per antibody due to MD simulation requirements.
        """)

        with gr.Tab("π Complete Pipeline"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Input Sequences")
                    heavy_chain = gr.Textbox(
                        label="Heavy Chain Variable Region",
                        placeholder="Enter VH amino acid sequence (e.g., QVQLVQSGAEVKKPG...)",
                        lines=3,
                        info="Variable region of heavy chain (VH)"
                    )
                    light_chain = gr.Textbox(
                        label="Light Chain Variable Region",
                        placeholder="Enter VL amino acid sequence (e.g., DIQMTQSPSSLSASVGDR...)",
                        lines=3,
                        info="Variable region of light chain (VL)"
                    )

                    gr.Markdown("### Simulation Parameters")
                    sim_time = gr.Slider(
                        minimum=10,
                        maximum=100,
                        value=default_sim_time,
                        step=10,
                        label="Simulation time (ns)",
                        info="Longer simulations are more accurate but take more time"
                    )
                    temperatures = gr.Textbox(
                        label="Temperatures (K)",
                        value=default_temperatures,
                        info="Comma-separated temperatures for MD simulations"
                    )

                with gr.Column(scale=1):
                    gr.Markdown("### Pipeline Progress")
                    status_text = gr.Textbox(
                        label="Current Status",
                        value="Ready to start...",
                        interactive=False
                    )
                    run_button = gr.Button("π¬ Run Complete Pipeline", variant="primary")

                    gr.Markdown("### Estimated Time")
                    time_estimate = gr.Textbox(
                        label="Estimated Completion Time",
                        # Show the estimate for the defaults instead of a placeholder.
                        value=update_time_estimate(default_sim_time, default_temperatures),
                        interactive=False
                    )

            with gr.Row():
                gr.Markdown("### π Results")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Thermostability Predictions")
                    tagg_result = gr.Number(
                        label="Tagg - Aggregation Temperature (Β°C)",
                        info="Temperature at which aggregation begins",
                        interactive=False
                    )
                    tmon_result = gr.Number(
                        label="Tm,on - Melting Temperature On-pathway (Β°C)",
                        info="On-pathway melting temperature",
                        interactive=False
                    )
                    tm_result = gr.Number(
                        label="Tm - Overall Melting Temperature (Β°C)",
                        info="Overall thermal melting temperature",
                        interactive=False
                    )

                with gr.Column():
                    gr.Markdown("#### Pipeline Logs")
                    pipeline_logs = gr.Textbox(
                        label="Execution Log",
                        lines=8,
                        info="Real-time pipeline progress and status",
                        interactive=False
                    )

            with gr.Row():
                gr.Markdown("### π Download Results")

            with gr.Row():
                structure_download = gr.File(
                    label="Generated Structure (PDB)"
                )
                descriptors_download = gr.File(
                    label="Calculated Descriptors (CSV)"
                )
                trajectory_info = gr.Textbox(
                    label="Trajectory Information",
                    interactive=False
                )

        with gr.Tab("β‘ Quick Prediction"):
            gr.Markdown("""
            ### Upload Pre-calculated Descriptors

            If you have already calculated MD descriptors, upload them here for quick predictions.
            """)

            descriptor_upload = gr.File(
                label="Upload Descriptor CSV",
                file_types=[".csv"]
            )
            quick_predict_btn = gr.Button("π― Quick Predict", variant="secondary")

            with gr.Row():
                quick_tagg = gr.Number(label="Tagg (Β°C)", interactive=False)
                quick_tmon = gr.Number(label="Tm,on (Β°C)", interactive=False)
                quick_tm = gr.Number(label="Tm (Β°C)", interactive=False)

        with gr.Tab("π Information"):
            gr.Markdown("""
            ### About AbMelt

            AbMelt is a computational protocol for predicting antibody thermostability using molecular dynamics simulations and machine learning.

            #### Method Overview:

            1. **Structure Generation**: Uses ImmuneBuilder to generate 3D antibody structures from sequences
            2. **System Preparation**: Prepares molecular dynamics simulation system with GROMACS
            3. **Multi-temperature MD**: Runs simulations at 300K, 350K, and 400K
            4. **Descriptor Calculation**: Computes structural and dynamic descriptors
            5. **ML Prediction**: Uses Random Forest models to predict thermostability

            #### Predictions:

            - **Tagg**: Aggregation temperature - when antibodies start to clump together
            - **Tm,on**: On-pathway melting temperature - structured unfolding temperature
            - **Tm**: Overall melting temperature - general thermal stability

            #### Citation:

            ```
            @article{rollins2024,
                title = {{AbMelt}: {Learning} {antibody} {thermostability} from {molecular} {dynamics}},
                journal = {preprint},
                author = {Rollins, Zachary A and Widatalla, Talal and Cheng, Alan C and Metwally, Essam},
                month = feb,
                year = {2024}
            }
            ```

            #### Computational Requirements:

            - Full pipeline: 2-4 hours per antibody
            - Memory: ~8GB for typical antibody
            - Storage: ~2GB for trajectory files
            """)

        # Event handlers (attached inside the Blocks context).
        # Refresh the estimate whenever either of its inputs changes.
        for estimate_input in (sim_time, temperatures):
            estimate_input.change(
                update_time_estimate,
                inputs=[sim_time, temperatures],
                outputs=time_estimate
            )

        run_button.click(
            run_pipeline_wrapper,
            inputs=[heavy_chain, light_chain, sim_time, temperatures],
            outputs=[
                tagg_result, tmon_result, tm_result,  # predictions
                pipeline_logs,  # logs
                structure_download, descriptors_download, trajectory_info  # files
            ]
        )

        quick_predict_btn.click(
            quick_prediction,
            inputs=descriptor_upload,
            outputs=[quick_tagg, quick_tmon, quick_tm]
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app, bound the pending-request
    # queue to three jobs, then start serving it.
    app = create_interface()
    app.queue(max_size=3)
    app.launch(share=True)