Spaces:
Sleeping
Sleeping
| """ | |
| AbMelt Complete Pipeline - Enhanced Hugging Face Space Implementation | |
| Full molecular dynamics simulation pipeline for antibody thermostability prediction with advanced visualization | |
| """ | |
| import gradio as gr | |
| import os | |
| import sys | |
| import logging | |
| import tempfile | |
| import threading | |
| import time | |
| import json | |
| import shutil | |
| from pathlib import Path | |
| import pandas as pd | |
| import traceback | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| from plotly.subplots import make_subplots | |
| import numpy as np | |
| # Add src to path for imports | |
| sys.path.insert(0, str(Path(__file__).parent / "src")) | |
| from structure_generator import StructureGenerator | |
| from gromacs_pipeline import GromacsPipeline, GromacsError | |
| from descriptor_calculator import DescriptorCalculator | |
| from ml_predictor import ThermostabilityPredictor | |
| # Setup logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
class EnhancedVisualization:
    """Build the plots and the 3D structure viewer shown alongside pipeline results."""

    # Predicted Tagg / Tm values (°C) below this threshold are drawn in the
    # warning colour; values at or above it use the "stable" colour.
    LOW_STABILITY_THRESHOLD_C = 60
    _COLOR_LOW = '#FF6B6B'
    _COLOR_STABLE = '#4ECDC4'
    _COLOR_NEUTRAL = '#45B7D1'
    # Metrics that get threshold-based colouring; everything else is neutral.
    _THRESHOLD_METRICS = ('tagg', 'tm')

    def __init__(self):
        pass

    def create_stability_summary_plot(self, predictions):
        """Create a bar chart summarising the stability predictions.

        Args:
            predictions: Mapping of metric name -> dict containing a 'value'
                entry (temperature in °C). Entries that are not dicts, lack
                'value', or whose value is None are skipped.

        Returns:
            A plotly Figure, or None when ``predictions`` is empty/None. If no
            entry carries a usable value the figure has a layout but no trace.
        """
        if not predictions:
            return None

        metrics = []
        values = []
        colors = []
        for metric, data in predictions.items():
            if not isinstance(data, dict) or data.get('value') is None:
                # A missing prediction must not crash the comparison or the
                # ":.1f" formatting below.
                continue
            value = data['value']
            metrics.append(metric.upper())
            values.append(value)
            if metric in self._THRESHOLD_METRICS:
                is_low = value < self.LOW_STABILITY_THRESHOLD_C
                colors.append(self._COLOR_LOW if is_low else self._COLOR_STABLE)
            else:
                colors.append(self._COLOR_NEUTRAL)

        fig = go.Figure()
        if values:
            fig.add_trace(go.Bar(
                x=metrics,
                y=values,
                marker_color=colors,
                text=[f"{v:.1f}°C" for v in values],
                textposition='auto'
            ))
        fig.update_layout(
            title="Antibody Thermostability Predictions",
            xaxis_title="Stability Metrics",
            yaxis_title="Temperature (°C)",
            showlegend=False,
            height=400
        )
        return fig

    def create_3d_structure_viewer(self, pdb_path):
        """Create a self-contained 3Dmol.js viewer for a PDB file.

        The PDB text is read on the server and embedded in the returned HTML,
        because the browser cannot fetch a server-side filesystem path.

        Args:
            pdb_path: Path to a PDB file on the server.

        Returns:
            An HTML string containing an ``<iframe>`` with the viewer, or ""
            when ``pdb_path`` is empty, is not a file, or cannot be read.
        """
        import html as html_lib  # local import: `html` is not a module-level import here

        if not pdb_path or not os.path.isfile(pdb_path):
            return ""
        try:
            with open(pdb_path, "r", encoding="utf-8", errors="replace") as handle:
                pdb_text = handle.read()
        except OSError:
            return ""

        # json.dumps yields a valid JavaScript string literal; "</" is broken
        # up so the data can never terminate the surrounding <script> element.
        pdb_literal = json.dumps(pdb_text).replace("</", "<\\/")
        page = (
            "<!DOCTYPE html><html><head>"
            '<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>'
            '</head><body style="margin: 0;">'
            '<div id="viewer" style="width: 100%; height: 400px; position: relative;"></div>'
            "<script>"
            "const viewer = $3Dmol.createViewer('viewer');"
            f"viewer.addModel({pdb_literal}, 'pdb');"
            "viewer.setStyle({}, {cartoon: {color: 'spectrum'}});"
            "viewer.zoomTo();"
            "viewer.render();"
            "</script></body></html>"
        )
        # NOTE(review): the page is wrapped in an iframe so its scripts run in
        # their own document instead of relying on the host page executing an
        # injected <script>; confirm rendering on the deployed Gradio version.
        return (
            f'<iframe srcdoc="{html_lib.escape(page, quote=True)}" '
            'style="width: 100%; height: 420px; border: none;"></iframe>'
        )
class AbMeltPipeline:
    """Run structure generation, MD simulation, descriptor calculation and
    thermostability prediction for one antibody, collecting results for the UI."""

    def __init__(self):
        self.structure_gen = StructureGenerator()
        self.predictor = None
        self.visualizer = EnhancedVisualization()
        self.current_job = None
        self.job_status = {}
        # The predictor is optional: if it fails to load, the pipeline still
        # runs and only the prediction step is skipped.
        try:
            models_dir = Path(__file__).parent / "models"
            self.predictor = ThermostabilityPredictor(models_dir)
            logger.info("ML predictor initialized")
        except Exception as e:
            logger.error(f"Failed to initialize ML predictor: {e}")

    @staticmethod
    def _parse_temperatures(temperatures):
        """Parse simulation temperatures into a list of positive ints.

        Accepts a comma-separated string (blank entries such as a trailing
        comma are ignored) or an iterable of values.

        Raises:
            ValueError: If an entry is not a positive whole number or no
                temperature is given at all.
        """
        if temperatures is None:
            raise ValueError("no temperatures were provided")
        tokens = temperatures.split(',') if isinstance(temperatures, str) else list(temperatures)

        parsed = []
        for token in tokens:
            text = str(token).strip()
            if not text:
                continue
            try:
                value = int(text)
            except ValueError:
                raise ValueError(f"'{text}' is not a whole number") from None
            if value <= 0:
                raise ValueError(f"'{text}' is not a positive temperature")
            parsed.append(value)
        if not parsed:
            raise ValueError("no temperatures were provided")
        return parsed

    @staticmethod
    def _persist_output(source_path, file_name):
        """Copy ``source_path`` into ``<cwd>/outputs/<file_name>`` and return the new path."""
        persistent_dir = os.path.join(os.getcwd(), "outputs")
        os.makedirs(persistent_dir, exist_ok=True)
        destination = os.path.join(persistent_dir, file_name)
        shutil.copy2(source_path, destination)
        return destination

    def run_complete_pipeline(self, heavy_chain, light_chain, sim_time_ns=10,
                              temperatures="300,350,400", progress_callback=None,
                              generate_visuals=True):
        """
        Run the complete AbMelt pipeline with enhanced visualization

        Failures are reported through the returned dict ('error' and 'logs')
        rather than raised, including malformed ``temperatures`` input.

        Args:
            heavy_chain (str): Heavy chain variable region sequence
            light_chain (str): Light chain variable region sequence
            sim_time_ns (int): Simulation time in nanoseconds
            temperatures (str): Comma-separated temperatures
            progress_callback (callable): Called as ``progress_callback(percent, message)``
            generate_visuals (bool): Whether to generate visualizations

        Returns:
            dict: Keys 'success', 'predictions', 'intermediate_files',
            'descriptors', 'visualizations', 'error' and 'logs'.
        """
        results = {
            'success': False,
            'predictions': {},
            'intermediate_files': {},
            'descriptors': {},
            'visualizations': {},
            'error': None,
            'logs': []
        }

        def report(percent, message):
            if progress_callback:
                progress_callback(percent, message)

        # Reject bad input before any expensive work is started.
        try:
            temp_list = self._parse_temperatures(temperatures)
        except ValueError as exc:
            error_msg = f"Invalid temperatures: {exc}"
            results['error'] = error_msg
            results['logs'].append(f"✗ {error_msg}")
            return results

        job_id = f"job_{int(time.time())}"
        try:
            report(0, "Starting AbMelt pipeline...")

            # Step 1: Generate structure (10% progress)
            report(10, "Generating antibody structure...")
            structure_path = self.structure_gen.generate_structure(
                heavy_chain, light_chain
            )
            if not structure_path or not os.path.exists(structure_path):
                raise FileNotFoundError("Structure generation failed. PDB file not created.")

            # Copy structure file to persistent location before cleanup
            persistent_structure = self._persist_output(structure_path, f"structure_{job_id}.pdb")
            results['intermediate_files']['structure'] = persistent_structure
            logger.info(f"Structure copied to persistent location: {persistent_structure}")
            results['logs'].append("✓ Structure generation completed")

            # Build the viewer from the persisted copy so it does not depend
            # on the generator's temporary file.
            if generate_visuals:
                results['visualizations']['structure_3d'] = (
                    self.visualizer.create_3d_structure_viewer(persistent_structure)
                )

            # Step 2: Setup MD system (20% progress)
            report(20, "Preparing GROMACS molecular dynamics system...")
            md_pipeline = GromacsPipeline()
            try:
                prepared_system = md_pipeline.prepare_system(structure_path)
                results['intermediate_files']['prepared_system'] = prepared_system
                results['logs'].append("✓ GROMACS system preparation completed")

                # Step 3: Run MD simulations (30-80% progress)
                report(30, f"Running MD simulations at {len(temp_list)} temperatures...")
                trajectories = md_pipeline.run_md_simulations(
                    temperatures=temp_list,
                    sim_time_ns=sim_time_ns
                )

                # Copy important MD files to persistent location
                persistent_trajectories = {}
                for temp, traj_path in trajectories.items():
                    if traj_path and os.path.exists(traj_path):
                        persistent_trajectories[temp] = self._persist_output(
                            traj_path, f"trajectory_{temp}K_{job_id}.xtc"
                        )
                    else:
                        # Keep the entry so callers can see which temperature
                        # produced no trajectory file.
                        logger.warning(f"No trajectory file found for {temp}K: {traj_path}")
                        persistent_trajectories[temp] = traj_path
                results['intermediate_files']['trajectories'] = persistent_trajectories
                results['logs'].append(f"✓ MD simulations completed for {len(temp_list)} temperatures")

                # Step 4: Calculate descriptors (80-90% progress)
                report(80, "Calculating molecular descriptors...")
                descriptor_calc = DescriptorCalculator(md_pipeline.work_dir)
                # Create topology file mapping
                topology_files = {
                    temp: os.path.join(md_pipeline.work_dir, f"md_{temp}.tpr")
                    for temp in temp_list
                }
                descriptors = descriptor_calc.calculate_all_descriptors(
                    trajectories, topology_files
                )
                results['descriptors'] = descriptors
                results['logs'].append("✓ Descriptor calculation completed")

                # Export descriptors, then move the CSV out of the work dir.
                # NOTE(review): presumably md_pipeline.cleanup() removes
                # work_dir, which would invalidate the download link; confirm.
                desc_csv_path = os.path.join(md_pipeline.work_dir, "descriptors.csv")
                descriptor_calc.export_descriptors_csv(descriptors, desc_csv_path)
                if os.path.exists(desc_csv_path):
                    desc_csv_path = self._persist_output(
                        desc_csv_path, f"descriptors_{job_id}.csv"
                    )
                results['intermediate_files']['descriptors_csv'] = desc_csv_path

                # Step 5: Make predictions (90-100% progress)
                report(90, "Making thermostability predictions...")
                if self.predictor:
                    predictions = self.predictor.predict_thermostability(descriptors)
                    results['predictions'] = predictions
                    results['logs'].append("✓ Thermostability predictions completed")
                    # Generate visualization plots
                    if generate_visuals:
                        results['visualizations']['stability_summary'] = (
                            self.visualizer.create_stability_summary_plot(predictions)
                        )
                else:
                    results['logs'].append("⚠ ML predictor not available")

                report(100, "Pipeline completed successfully!")
                results['success'] = True
            except GromacsError as e:
                error_msg = f"GROMACS error: {str(e)}"
                results['error'] = error_msg
                results['logs'].append(f"✗ {error_msg}")
                logger.error(error_msg)
                # Add specific check for NaN coordinate error
                if "NaN" in str(e) or "invalid coordinates" in str(e):
                    results['logs'].append("  Hint: This error is often caused by problems in the fallback "
                                           "structure generator. Check sequence validity.")
            finally:
                # Cleanup is best-effort, but a failure is logged, not hidden.
                try:
                    md_pipeline.cleanup()
                except Exception:
                    logger.warning("GROMACS pipeline cleanup failed", exc_info=True)
        except Exception as e:
            error_msg = f"Pipeline error: {str(e)}"
            results['error'] = error_msg
            results['logs'].append(f"✗ {error_msg}")
            logger.error(f"Pipeline failed: {traceback.format_exc()}")
        finally:
            # Cleanup structure generator (best-effort, logged on failure)
            try:
                self.structure_gen.cleanup()
            except Exception:
                logger.warning("Structure generator cleanup failed", exc_info=True)
        return results
# Static Markdown shown in the interface. Kept at module level so the
# layout code in create_interface() stays readable.
_INTRO_MD = """
# 🧬 AbMelt: Enhanced Molecular Dynamics Pipeline

**Predict antibody thermostability through multi-temperature molecular dynamics simulations**

This enhanced space implements the complete AbMelt protocol with advanced visualizations:

- Structure generation with ImmuneBuilder and 3D visualization
- Multi-temperature MD simulations with trajectory analysis
- Comprehensive descriptor calculation and plotting
- Machine learning predictions with uncertainty quantification
- Interactive visualization of results

⚠️ **Note**: Full pipeline takes 2-4 hours per antibody due to MD simulation requirements.
"""

_QUICK_PREDICTION_MD = """
### Upload Pre-calculated Descriptors

If you have already calculated MD descriptors, upload them here for quick predictions.
"""

_DOCUMENTATION_MD = """
## How to Use AbMelt

### 1. Input Requirements

- **Heavy Chain**: Variable region sequence (VH) of your antibody
- **Light Chain**: Variable region sequence (VL) of your antibody
- **Simulation Time**: 10-100 ns (longer = more accurate but slower)
- **Temperatures**: Comma-separated list (default: 300,350,400 K)

### 2. Pipeline Steps

1. **Structure Generation**: Uses ImmuneBuilder to create 3D structure
2. **System Preparation**: Sets up GROMACS molecular dynamics system
3. **MD Simulations**: Runs simulations at specified temperatures
4. **Descriptor Calculation**: Extracts molecular features from trajectories
5. **ML Prediction**: Uses trained models to predict thermostability

### 3. Output Interpretation

- **Tagg**: Aggregation onset temperature - higher is better
- **Tm,on**: On-pathway melting temperature - thermal stability indicator
- **Tm**: Overall melting temperature - global stability measure

### 4. Tips for Best Results

- Use complete variable regions (110-130 residues typical)
- Longer simulations provide more reliable predictions
- Check logs for any warnings or errors
- Download trajectories for further analysis

### Additional Resources

- [GitHub Repository](https://github.com/MSDLLCpapers/AbMelt)
- [Original Paper](https://doi.org/10.1016/j.bpj.2024.06.003)
- [GROMACS Documentation](https://manual.gromacs.org/)
- [ImmuneBuilder](https://github.com/oxpig/ImmuneBuilder)
"""

_FAQ_MD = """
## Frequently Asked Questions

**Q: How long does the full pipeline take?**
A: Approximately 15-20 minutes per nanosecond per temperature. For 10ns at 3 temperatures,
expect 2-3 hours. For 100ns simulations, expect 20-30 hours.

**Q: Can I use this for nanobodies or single-domain antibodies?**
A: Yes, but you'll need to provide both VH and VL sequences. For nanobodies,
you can duplicate the VH sequence as VL, though predictions may be less accurate.

**Q: What if I only have the full antibody sequence?**
A: You need to extract just the variable regions (VH and VL). Use IMGT/V-QUEST
or similar tools to identify variable region boundaries.

**Q: Why are my predictions very low/high?**
A: Check that:

- Sequences are correct variable regions
- No unusual amino acids or modifications
- Proper sequence length (typically 110-130 residues)

**Q: Can I compare multiple antibodies?**
A: Currently, process one at a time. Save results for comparison.
Batch processing may be added in future versions.

**Q: What does "GROMACS error" mean?**
A: Usually indicates:

- Severely malformed structure
- System preparation issues
- Memory constraints

Try reducing simulation time or contact support.

**Q: How accurate are the predictions?**
A: Based on published validation:

- Tagg: R² = 0.57 ± 0.11
- Tm,on: R² = 0.56 ± 0.01
- Tm: R² = 0.60 ± 0.06

**Q: Can I use the trajectories for other analyses?**
A: Yes! Download the trajectory files and use with VMD, PyMOL, or other MD analysis tools.
"""

# The 20 standard amino acids accepted in input sequences.
_VALID_AMINO_ACIDS = frozenset('ACDEFGHIKLMNPQRSTVWY')
# Sequences shorter than this are rejected as incomplete variable regions.
_MIN_SEQUENCE_LENGTH = 50


def _normalize_sequence(raw):
    """Upper-case a sequence and drop all whitespace, including line breaks from pasting."""
    return "".join(raw.split()).upper()


def _validate_sequences(heavy, light):
    """Clean and validate the two chain sequences.

    Returns:
        tuple: ``(heavy_clean, light_clean, error)``. ``error`` is None when
        both sequences are acceptable; otherwise it is the message to show
        the user and the cleaned sequences are None.
    """
    if not heavy or not light:
        return None, None, "❌ Error: Both heavy and light chain sequences are required"

    heavy_clean = _normalize_sequence(heavy)
    light_clean = _normalize_sequence(light)
    if len(heavy_clean) < _MIN_SEQUENCE_LENGTH or len(light_clean) < _MIN_SEQUENCE_LENGTH:
        return None, None, ("❌ Error: Sequences too short. "
                            "Provide complete variable regions (>50 residues)")
    if not _VALID_AMINO_ACIDS.issuperset(heavy_clean):
        return None, None, "❌ Error: Invalid characters in heavy chain sequence"
    if not _VALID_AMINO_ACIDS.issuperset(light_clean):
        return None, None, "❌ Error: Invalid characters in light chain sequence"
    return heavy_clean, light_clean, None


def _pipeline_failure_outputs(message):
    """Return the 14 pipeline-tab outputs for a failed run, with ``message`` in the log slot."""
    return (
        None, None, None,              # predictions
        message,                       # logs
        None,                          # stability plot
        None, None,                    # downloads
        None,                          # trajectory info
        "",                            # structure viewer HTML
        None, None, None, None, None   # analysis plots
    )


def _estimate_runtime(sim_time_val, temps_str):
    """Return a human-readable runtime estimate for the chosen simulation settings."""
    setup_time = 10       # minutes for setup
    md_time_per_ns = 12   # minutes per ns per temperature (average)
    analysis_time = 5     # minutes for analysis
    try:
        temp_count = len([t for t in temps_str.split(',') if t.strip()])
        total_time = setup_time + sim_time_val * temp_count * md_time_per_ns + analysis_time
        hours = int(total_time // 60)
        minutes = int(total_time % 60)
    except (AttributeError, TypeError, ValueError):
        return "Unable to estimate"

    time_str = f"~{hours}h {minutes}m" if hours > 0 else f"~{minutes}m"
    if total_time > 240:  # More than 4 hours
        time_str += " ⚠️ Consider using shorter simulation time"
    return time_str


def _describe_trajectories(traj_dict):
    """Summarise the generated trajectory files (name and size) for display."""
    if not traj_dict:
        return "No trajectory files generated"
    lines = ["Generated trajectories:"]
    for temp, traj_path in traj_dict.items():
        # The path may be None/missing when a simulation produced no file.
        if traj_path and os.path.exists(traj_path):
            size_mb = os.path.getsize(traj_path) / (1024 * 1024)
            lines.append(f"- {temp}K: {os.path.basename(traj_path)} ({size_mb:.1f} MB)")
    return "\n".join(lines) + "\n"


def _existing_file(path):
    """Return ``path`` if it names an existing file, else None (so gr.File shows nothing)."""
    return path if path and os.path.exists(path) else None


def create_interface():
    """Create the enhanced Gradio interface.

    Builds the pipeline object, lays out the four tabs (complete pipeline,
    quick prediction, documentation, FAQ) and wires the event handlers.

    Returns:
        gr.Blocks: The assembled (not yet launched) application.
    """
    pipeline = AbMeltPipeline()

    with gr.Blocks(title="AbMelt: Enhanced MD Pipeline", theme=gr.themes.Soft()) as demo:
        gr.Markdown(_INTRO_MD)

        with gr.Tab("🚀 Complete Pipeline"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Input Sequences")
                    heavy_chain = gr.Textbox(
                        label="Heavy Chain Variable Region",
                        placeholder="Enter VH amino acid sequence (e.g., QVQLVQSGAEVKKPG...)",
                        lines=3,
                        info="Variable region of heavy chain (VH)"
                    )
                    light_chain = gr.Textbox(
                        label="Light Chain Variable Region",
                        placeholder="Enter VL amino acid sequence (e.g., DIQMTQSPSSLSASVGDR...)",
                        lines=3,
                        info="Variable region of light chain (VL)"
                    )
                    gr.Markdown("### Simulation Parameters")
                    sim_time = gr.Slider(
                        minimum=10,
                        maximum=100,
                        value=10,
                        step=10,
                        label="Simulation time (ns)",
                        info="Longer simulations are more accurate but take more time"
                    )
                    temperatures = gr.Textbox(
                        label="Temperatures (K)",
                        value="300,350,400",
                        info="Comma-separated temperatures for MD simulations"
                    )
                    generate_visuals = gr.Checkbox(
                        label="Generate Advanced Visualizations",
                        value=True,
                        info="Create 3D structure viewer and analysis plots"
                    )
                with gr.Column(scale=1):
                    gr.Markdown("### Pipeline Progress")
                    run_button = gr.Button("🔬 Run Complete Pipeline", variant="primary")
                    gr.Markdown("### Estimated Time")
                    time_estimate = gr.Textbox(
                        label="Estimated Completion Time",
                        value="Not calculated",
                        interactive=False
                    )

            with gr.Row():
                gr.Markdown("### 📊 Results")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Thermostability Predictions")
                    tagg_result = gr.Number(
                        label="Tagg - Aggregation Temperature (°C)",
                        info="Temperature at which aggregation begins",
                        interactive=False
                    )
                    tmon_result = gr.Number(
                        label="Tm,on - Melting Temperature On-pathway (°C)",
                        info="On-pathway melting temperature",
                        interactive=False
                    )
                    tm_result = gr.Number(
                        label="Tm - Overall Melting Temperature (°C)",
                        info="Overall thermal melting temperature",
                        interactive=False
                    )
                with gr.Column():
                    gr.Markdown("#### Pipeline Logs")
                    pipeline_logs = gr.Textbox(
                        label="Execution Log",
                        lines=8,
                        info="Real-time pipeline progress and status",
                        interactive=False
                    )

            with gr.Row():
                gr.Markdown("### 📈 Visualization")
            stability_plot = gr.Plot(label="Stability Summary")

            with gr.Row():
                gr.Markdown("### 📁 Download Results")
            with gr.Row():
                structure_download = gr.File(
                    label="Generated Structure (PDB)"
                )
                descriptors_download = gr.File(
                    label="Calculated Descriptors (CSV)"
                )
                trajectory_info = gr.Textbox(
                    label="Trajectory Information",
                    interactive=False
                )

            with gr.Row():
                gr.Markdown("### 🔬 3D Structure Viewer")
            structure_viewer = gr.HTML(label="Interactive 3D Structure")

            with gr.Row():
                gr.Markdown("### 📊 Analysis Plots")
            with gr.Row():
                rmsd_plot = gr.Plot(label="RMSD Analysis")
                flexibility_plot = gr.Plot(label="Flexibility Heatmap")
            with gr.Row():
                rg_plot = gr.Plot(label="Radius of Gyration")
                contact_plot = gr.Plot(label="Contact Maps")
            with gr.Row():
                trajectory_snapshots = gr.Plot(label="Trajectory Snapshots")

        with gr.Tab("⚡ Quick Prediction"):
            gr.Markdown(_QUICK_PREDICTION_MD)
            descriptor_upload = gr.File(
                label="Upload Descriptor CSV",
                file_types=[".csv"]
            )
            quick_predict_btn = gr.Button("🔮 Make Quick Prediction", variant="secondary")
            with gr.Row():
                quick_tagg = gr.Number(label="Tagg (°C)", interactive=False)
                quick_tmon = gr.Number(label="Tm,on (°C)", interactive=False)
                quick_tm = gr.Number(label="Tm (°C)", interactive=False)
            quick_summary_plot = gr.Plot(label="Quick Prediction Summary")

        with gr.Tab("📚 Documentation"):
            gr.Markdown(_DOCUMENTATION_MD)

        with gr.Tab("❓ FAQ"):
            gr.Markdown(_FAQ_MD)

        # Event handlers
        def update_time_estimate(sim_time_val, temps_str):
            """Calculate and update time estimate"""
            return _estimate_runtime(sim_time_val, temps_str)

        def run_pipeline_wrapper(heavy, light, sim_time_val, temps_str, gen_visuals,
                                 progress=gr.Progress()):
            """Validate the inputs, run the pipeline and map its results onto the 14 outputs."""
            heavy_clean, light_clean, problem = _validate_sequences(heavy, light)
            if problem:
                return _pipeline_failure_outputs(problem)

            def progress_callback(percent, message):
                # Forward pipeline progress (0-100) to the Gradio progress bar
                # and the server log instead of discarding it.
                logger.info("[%s%%] %s", percent, message)
                progress(percent / 100, desc=message)

            try:
                results = pipeline.run_complete_pipeline(
                    heavy_clean, light_clean, sim_time_val, temps_str,
                    progress_callback, gen_visuals
                )

                predictions = results.get('predictions', {})
                visuals = results.get('visualizations', {})
                files = results.get('intermediate_files', {})

                logs = "\n".join(results.get('logs', []))
                if results.get('error'):
                    logs += f"\n❌ {results['error']}"

                return (
                    predictions.get('tagg', {}).get('value'),
                    predictions.get('tmon', {}).get('value'),
                    predictions.get('tm', {}).get('value'),          # predictions
                    logs,                                            # pipeline logs
                    visuals.get('stability_summary'),                # stability plot
                    _existing_file(files.get('structure')),
                    _existing_file(files.get('descriptors_csv')),    # downloads
                    _describe_trajectories(files.get('trajectories')),  # trajectory info
                    visuals.get('structure_3d', ""),                 # 3D viewer
                    visuals.get('rmsd_plot'),
                    visuals.get('flexibility_heatmap'),
                    visuals.get('rg_plot'),
                    visuals.get('contact_maps'),
                    visuals.get('trajectory_snapshots')              # plots
                )
            except Exception as e:
                logger.error(f"Pipeline wrapper failed: {traceback.format_exc()}")
                return _pipeline_failure_outputs(f"❌ Pipeline failed: {str(e)}")

        def quick_prediction(desc_file):
            """Handle quick prediction from uploaded descriptors"""
            if desc_file is None or not pipeline.predictor:
                return None, None, None, None
            try:
                # Depending on the Gradio version the upload arrives either as
                # a path string or as a tempfile-like object exposing `.name`.
                if isinstance(desc_file, (str, os.PathLike)):
                    csv_path = desc_file
                else:
                    csv_path = desc_file.name
                df = pd.read_csv(csv_path)
                if df.empty:
                    return None, None, None, None

                # The predictor expects a flat dict; only the first row is used.
                descriptors = df.iloc[0].to_dict()
                predictions = pipeline.predictor.predict_thermostability(descriptors)
                summary_plot = pipeline.visualizer.create_stability_summary_plot(predictions)
                return (
                    predictions.get('tagg', {}).get('value'),
                    predictions.get('tmon', {}).get('value'),
                    predictions.get('tm', {}).get('value'),
                    summary_plot
                )
            except Exception as e:
                logger.error(f"Quick prediction failed: {e}")
                return None, None, None, None

        # Connect event handlers
        sim_time.change(
            update_time_estimate,
            inputs=[sim_time, temperatures],
            outputs=time_estimate
        )
        temperatures.change(
            update_time_estimate,
            inputs=[sim_time, temperatures],
            outputs=time_estimate
        )
        run_button.click(
            run_pipeline_wrapper,
            inputs=[heavy_chain, light_chain, sim_time, temperatures, generate_visuals],
            outputs=[
                tagg_result, tmon_result, tm_result,       # predictions
                pipeline_logs,                             # logs
                stability_plot,                            # stability summary
                structure_download, descriptors_download,  # downloads
                trajectory_info,                           # trajectory info
                structure_viewer,                          # 3D viewer
                rmsd_plot, flexibility_plot, rg_plot,      # analysis plots
                contact_plot, trajectory_snapshots         # more plots
            ]
        )
        quick_predict_btn.click(
            quick_prediction,
            inputs=descriptor_upload,
            outputs=[quick_tagg, quick_tmon, quick_tm, quick_summary_plot]
        )

        # Add example sequences
        gr.Examples(
            examples=[
                [
                    "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYYMHWVRQAPGQGLEWMGIINPSGGSTSYAQKFQGRVTMTRDTSTSTVYMELSSLRSEDTAVYYCARGGNSAFYSSWFAYWGQGTLVTVSS",
                    "DIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIYAASSLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYSTPPITFGQGTRLEIKR"
                ],
                [
                    "EVQLLESGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLEWVSAISGSGGSTYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCAKDRGYSSGWYLFDYWGQGTLVTVSS",
                    "EIVLTQSPGTLSLSPGERATLSCRASQSVSSSYLAWYQQKPGQAPRLLIYGASSRATGIPDRFSGSGSGTDFTLTISRLEPEDFAVYYCQQYGSSPLTFGGGTKVEIKR"
                ]
            ],
            inputs=[heavy_chain, light_chain],
            label="Example Antibody Sequences"
        )

    return demo
if __name__ == "__main__":
    # Create and launch the enhanced interface
    demo = create_interface()
    # The queue is kept small in both modes: every request starts a
    # long-running MD job.
    if os.getenv("SPACE_ID"):
        # Running in HF Space: the Space already serves the app publicly, and
        # share links are not supported there, so no `share=True`.
        demo.queue(max_size=3)
        demo.launch()
    else:
        # Running locally: listen on all interfaces and also create a
        # public share link.
        demo.queue(max_size=5)
        demo.launch(share=True, server_name="0.0.0.0", server_port=7860)