Spaces:
Running
Running
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import pypdb | |
| import biotite.database.rcsb as rcsb | |
| from MDAnalysis.analysis import rms | |
| from opencadd.structure.superposition.engines.mda import MDAnalysisAligner | |
| from lynxkite_core.ops import op | |
| import os | |
| import numpy as np | |
| from Bio.PDB import PDBList, PDBParser, Superimposer | |
def calc_rmsd(A, B):
    """
    Compute the RMSD between the matched atoms of two structures.

    The atom subsets to compare are chosen by
    ``MDAnalysisAligner.matching_selection``; the resulting coordinates are
    compared as-is, with no superposition step (``superposition=False``).

    Parameters
    ----------
    A : opencadd.structure.core.Structure
        Reference structure.
    B : opencadd.structure.core.Structure
        Mobile structure.

    Returns
    -------
    float
        RMSD value over the matched atoms.
    """
    matched, _ = MDAnalysisAligner().matching_selection(A, B)
    ref_atoms = A.select_atoms(matched["reference"])
    mobile_atoms = B.select_atoms(matched["mobile"])
    # Coordinates are used exactly as stored; any prior alignment is the caller's job.
    return rms.rmsd(ref_atoms.positions, mobile_atoms.positions, superposition=False)
def calc_rmsd_matrix(structures, names):
    """
    Build a symmetric pairwise RMSD table for a list of structures.

    Each unordered pair is scored once with ``calc_rmsd`` and the value is
    mirrored into both cells; diagonal cells are 0.0.

    Parameters
    ----------
    structures : list of opencadd.structure.core.Structure
        Structures to compare.
    names : list of str
        Labels for ``structures``, in the same order. They become the row and
        column labels of the result. Labels are expected to be unique:
        duplicates share a single entry.

    Returns
    -------
    pandas.DataFrame
        RMSD matrix labelled by ``names``.
    """
    labelled = list(zip(structures, names))
    table = {label: {} for label in names}
    for row, (ref, ref_name) in enumerate(labelled):
        table[ref_name][ref_name] = 0.0
        # Only pairs above the diagonal are computed; each value is mirrored.
        for mobile, mobile_name in labelled[row + 1:]:
            dist = calc_rmsd(ref, mobile)
            table[ref_name][mobile_name] = dist
            table[mobile_name][ref_name] = dist
    return pd.DataFrame.from_dict(table)
def get_pdb_count(
    *, ligand_id: str, experimental_method: str, max_resolution: float, polymer_count: int
) -> pd.DataFrame:
    """
    Query the RCSB PDB for entries matching all given criteria and return
    their IDs together with their titles.

    As a diagnostic, the number of matches for each individual criterion and
    for the combined query is printed.

    Parameters
    ----------
    ligand_id : str
        Non-polymer component ID to filter on (e.g., 'STI').
    experimental_method : str
        Experimental method to filter by (e.g., 'X-RAY DIFFRACTION').
    max_resolution : float
        Maximum resolution (Å) to include (<= this value).
    polymer_count : int
        Exact value of ``rcsb_entry_info.deposited_polymer_entity_instance_count``.

    Returns
    -------
    pandas.DataFrame
        One row per matching entry, with columns ``pdb_id`` and
        ``description``. ``description`` holds the entry's ``struct.title``,
        or None when that metadata could not be retrieved.
    """
    # Each criterion is counted on its own so the user can see which filter
    # is the most restrictive.
    # 1) Query by ligand ID
    q_ligand = rcsb.FieldQuery(
        "rcsb_nonpolymer_entity_container_identifiers.nonpolymer_comp_id", exact_match=ligand_id
    )
    count_ligand = rcsb.count(q_ligand)
    print(f"Number of matches for ligand '{ligand_id}': {count_ligand}")

    # 2) Query by experimental method
    q_method = rcsb.FieldQuery("exptl.method", exact_match=experimental_method)
    count_method = rcsb.count(q_method)
    print(f"Number of matches for experimental method '{experimental_method}': {count_method}")

    # 3) Query by resolution
    q_resolution = rcsb.FieldQuery(
        "rcsb_entry_info.resolution_combined", less_or_equal=max_resolution
    )
    count_resolution = rcsb.count(q_resolution)
    print(f"Number of matches with resolution ≤ {max_resolution}: {count_resolution}")

    # 4) Query by polymer chain count
    q_polymer = rcsb.FieldQuery(
        "rcsb_entry_info.deposited_polymer_entity_instance_count", equals=polymer_count
    )
    count_polymer = rcsb.count(q_polymer)
    print(f"Number of matches with polymer count == {polymer_count}: {count_polymer}")

    # 5) AND all criteria. The documented ``&`` operator builds an "and"
    # composite query without relying on CompositeQuery's constructor
    # argument layout. The search is a network call, so it is run only once.
    composite_q = q_ligand & q_method & q_resolution & q_polymer
    pdb_ids = list(rcsb.search(composite_q))
    print(f"Number of composite matches: {len(pdb_ids)}")

    # Fetch the PDBx descriptor of each entry to get its title.
    titles = []
    for pid in pdb_ids:
        info = pypdb.get_all_info(pid)
        # NOTE(review): pypdb seems to return None when a retrieval fails;
        # guard so a single bad entry does not abort the whole table.
        struct = (info or {}).get("struct") or {}
        titles.append(struct.get("title"))

    return pd.DataFrame({"pdb_id": pdb_ids, "description": titles})
def _load_ca_atoms(pid, pdbl, parser, pdb_dir):
    """Download entry `pid` in PDB format and map residue id -> Cα atom for one chain."""
    # retrieve_pdb_file returns the local file path, so it is used directly
    # rather than re-deriving the file-naming scheme here.
    path = pdbl.retrieve_pdb_file(pid, pdir=pdb_dir, file_format="pdb")
    model = parser.get_structure(pid, path)[0]
    # Prefer chain A; otherwise fall back to the first chain of the model.
    chain = model["A"] if "A" in model else next(model.get_chains(), None)
    if chain is None:
        # No chains at all: contribute no atoms (yields NaN against everyone).
        return {}
    return {
        residue.id: residue["CA"]
        for residue in chain
        # residue.id[0] == " " keeps only standard (ATOM) residues. Without it,
        # HETATM residues with an atom named "CA" (e.g. calcium ions) would be
        # mistaken for alpha carbons.
        if residue.id[0] == " " and residue.has_id("CA")
    }


def compute_pdb_rmsd(
    df: pd.DataFrame, *, pdb_id_col: str = "pdb_id", pdb_dir: str = "pdb_files"
) -> pd.DataFrame:
    """
    Compute a pairwise Cα RMSD matrix for the PDB entries listed in `df`.

    For each entry, the PDB-format file is downloaded into ``pdb_dir``. The
    code then takes the first model and, from chain A (or the first chain
    when there is no chain A), the Cα atoms of the standard residues.

    For each pair of entries, residues are matched by their residue
    identifier (numbering), not by sequence alignment. This assumes the
    entries share a numbering scheme. The matched atoms are superimposed with
    Biopython's ``Superimposer`` and the post-superposition RMSD is recorded.

    Parameters
    ----------
    df : pd.DataFrame
        Input DataFrame containing PDB IDs.
    pdb_id_col : str
        Name of the column in `df` with PDB IDs (default 'pdb_id').
    pdb_dir : str
        Directory where downloaded PDB files are stored (default 'pdb_files').
        Created if it does not exist.

    Returns
    -------
    pd.DataFrame
        Square DataFrame of pairwise RMSD values (Å), indexed and columned by
        PDB IDs.
        - Values are rounded to one decimal.
        - The diagonal is 0.
        - A cell is NaN when the two entries share no matched residues.
    """
    os.makedirs(pdb_dir, exist_ok=True)
    ids = df[pdb_id_col].tolist()
    n = len(ids)

    pdbl = PDBList()
    parser = PDBParser(QUIET=True)
    atom_dicts = [_load_ca_atoms(pid, pdbl, parser, pdb_dir) for pid in ids]

    rmsd_mat = np.zeros((n, n))
    sup = Superimposer()
    for i in range(n):
        for j in range(i + 1, n):
            # Sorting gives a deterministic pairing order for the superposition.
            common = sorted(atom_dicts[i].keys() & atom_dicts[j].keys())
            if not common:
                rmsd = np.nan
            else:
                fixed_atoms = [atom_dicts[i][k] for k in common]
                moving_atoms = [atom_dicts[j][k] for k in common]
                sup.set_atoms(fixed_atoms, moving_atoms)
                rmsd = round(sup.rms, 1)
            rmsd_mat[i, j] = rmsd_mat[j, i] = rmsd
    return pd.DataFrame(rmsd_mat, index=ids, columns=ids)
def plot_heatmap_from_df(
    df: pd.DataFrame, *, value_label: str = "Value", title: str = None
) -> plt.Figure:
    """
    Plot a DataFrame as an annotated heatmap using matplotlib.

    Each cell is drawn with ``imshow`` and labelled with its value formatted
    to one decimal. Column labels go on the x axis and index labels on the
    y axis.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame of numeric values to plot (typically square, e.g. an RMSD
        matrix).
    value_label : str
        Label for the color bar.
    title : str, optional
        Title for the plot; omitted when falsy.

    Returns
    -------
    plt.Figure
        The matplotlib Figure object containing the heatmap.
    """
    fig, ax = plt.subplots()
    im = ax.imshow(df.values)

    # create and label the colorbar
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(value_label)
    if title:
        ax.set_title(title)

    # imshow puts columns on the x axis and rows on the y axis, so each axis
    # is labelled from the matching DataFrame axis. This keeps labels correct
    # even when the index and columns differ.
    n_rows, n_cols = df.shape
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(df.columns.tolist(), rotation=90)
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(df.index.tolist())

    # Annotate each cell (text x = column, y = row).
    for i in range(n_rows):
        for j in range(n_cols):
            ax.text(j, i, f"{df.iat[i, j]:.1f}", ha="center", va="center")

    # Lay out the figure being returned, not whatever pyplot considers current.
    fig.tight_layout()
    return fig