import itertools
import os

import biotite.database.rcsb as rcsb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pypdb
from Bio.PDB import PDBList, PDBParser, Superimposer
from MDAnalysis.analysis import rms
from opencadd.structure.superposition.engines.mda import MDAnalysisAligner

from lynxkite_core.ops import op


def calc_rmsd(A, B):
    """
    Calculate the RMSD between two structures on their matching atoms.

    Matching atom selections are obtained from opencadd's ``MDAnalysisAligner``.
    The RMSD is computed on the coordinates as given (``superposition=False``),
    so the structures are presumably expected to be superposed already.

    Parameters
    ----------
    A : opencadd.structure.core.Structure
        Reference structure.
    B : opencadd.structure.core.Structure
        Mobile structure.

    Returns
    -------
    float
        RMSD value.
    """
    aligner = MDAnalysisAligner()
    selection, _ = aligner.matching_selection(A, B)
    atoms_a = A.select_atoms(selection["reference"])
    atoms_b = B.select_atoms(selection["mobile"])
    return rms.rmsd(atoms_a.positions, atoms_b.positions, superposition=False)


def calc_rmsd_matrix(structures, names):
    """
    Calculate the symmetric pairwise RMSD matrix for a list of structures.

    Parameters
    ----------
    structures : list of opencadd.structure.core.Structure
        List of structures.
    names : list of str
        Structure names, one per structure, used as row and column labels.

    Returns
    -------
    pandas.DataFrame
        Square RMSD matrix indexed and columned by ``names``, with zeros on
        the diagonal.

    Raises
    ------
    ValueError
        If ``structures`` and ``names`` have different lengths.
    """
    if len(structures) != len(names):
        raise ValueError(
            f"Got {len(structures)} structures but {len(names)} names; "
            "they must have the same length."
        )
    n = len(names)
    matrix = np.zeros((n, n))
    # RMSD is symmetric: compute each unordered pair once and mirror it.
    for i, j in itertools.combinations(range(n), 2):
        matrix[i, j] = matrix[j, i] = calc_rmsd(structures[i], structures[j])
    return pd.DataFrame(matrix, index=names, columns=names)


def _entry_title(info):
    """Return the entry title from a pypdb metadata dict, or None if unavailable."""
    # NOTE(review): pypdb.get_all_info appears to return None when retrieval
    # fails -- confirm for the pinned pypdb version. Either way, one failed
    # lookup should not abort the whole query.
    if not info:
        return None
    return info.get("struct", {}).get("title")


@op("LynxKite Graph Analytics", "PDB composite search")
def get_pdb_count(
    *, ligand_id: str, experimental_method: str, max_resolution: float, polymer_count: int
) -> pd.DataFrame:
    """
    Query the RCSB PDB for entries matching all of the given criteria.

    The number of matches for each individual criterion is printed for
    reference. The criteria are then combined with a logical AND, and the
    matching entries are returned together with their titles.

    Parameters
    ----------
    ligand_id : str
        Non-polymer component ID to filter on (e.g., 'STI').
    experimental_method : str
        Experimental method to filter by (e.g., 'X-RAY DIFFRACTION').
    max_resolution : float
        Maximum resolution (Å) to include (<= this value).
    polymer_count : int
        Exact number of deposited polymer entity instances in the entry.

    Returns
    -------
    pandas.DataFrame
        One row per matching entry, with columns ``pdb_id`` and
        ``description``. ``description`` holds the entry title, or None when
        its metadata could not be fetched.
    """
    q_ligand = rcsb.FieldQuery(
        "rcsb_nonpolymer_entity_container_identifiers.nonpolymer_comp_id",
        exact_match=ligand_id,
    )
    q_method = rcsb.FieldQuery("exptl.method", exact_match=experimental_method)
    q_resolution = rcsb.FieldQuery(
        "rcsb_entry_info.resolution_combined", less_or_equal=max_resolution
    )
    q_polymer = rcsb.FieldQuery(
        "rcsb_entry_info.deposited_polymer_entity_instance_count", equals=polymer_count
    )

    # Report how selective each criterion is on its own.
    criteria = [
        (q_ligand, f"Number of matches for ligand '{ligand_id}'"),
        (q_method, f"Number of matches for experimental method '{experimental_method}'"),
        (q_resolution, f"Number of matches with resolution ≤ {max_resolution}"),
        (q_polymer, f"Number of matches with polymer count == {polymer_count}"),
    ]
    for query, label in criteria:
        print(f"{label}: {rcsb.count(query)}")

    # AND all criteria together; search once, since every search is a network call.
    composite_q = rcsb.CompositeQuery([query for query, _ in criteria], "and")
    pdb_ids = rcsb.search(composite_q)

    titles = [_entry_title(pypdb.get_all_info(pid)) for pid in pdb_ids]
    return pd.DataFrame({"pdb_id": pdb_ids, "description": titles})


def _load_ca_atoms(pid, pdb_list, parser, pdb_dir):
    """Download entry ``pid`` in PDB format; map residue id -> CA atom for one chain."""
    path = pdb_list.retrieve_pdb_file(pid, pdir=pdb_dir, file_format="pdb")
    # retrieve_pdb_file reports download failures by printing rather than raising,
    # and some (large) entries have no legacy PDB-format file at all.
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"Could not retrieve a PDB-format file for '{pid}' (expected at {path})."
        )
    model = parser.get_structure(pid, path)[0]
    try:
        chain = model["A"]
    except KeyError:
        chain = next(model.get_chains())
    # Keep standard residues only (hetero flag " "). Otherwise hetero groups
    # such as calcium ions, whose atom is also named "CA", would be taken as CA atoms.
    return {
        residue.id: residue["CA"]
        for residue in chain
        if residue.id[0] == " " and residue.has_id("CA")
    }


@op("LynxKite Graph Analytics", "PDB alignment RMSD")
def compute_pdb_rmsd(df: pd.DataFrame, *, pdb_id_col: str = "pdb_id") -> pd.DataFrame:
    """
    Compute pairwise CA RMSDs between PDB entries listed in a DataFrame.

    The function downloads each entry in PDB format into ``pdb_files/``. For
    each entry it takes the CA atoms of standard residues from chain A, or
    from the first chain if there is no chain A. Each pair is superimposed
    with BioPython's ``Superimposer`` on their common residue ids, and the
    RMSD is recorded.

    Parameters
    ----------
    df : pd.DataFrame
        Input DataFrame containing PDB IDs.
    pdb_id_col : str
        Name of the column in `df` with PDB IDs (default 'pdb_id').

    Returns
    -------
    pd.DataFrame
        Square DataFrame of pairwise RMSD values (Å), rounded to one decimal
        and indexed and columned by PDB IDs. The diagonal is 0. Pairs without
        any common residue are NaN.

    Raises
    ------
    FileNotFoundError
        If a PDB-format file could not be retrieved for one of the IDs.
    """
    pdb_dir = "pdb_files"
    os.makedirs(pdb_dir, exist_ok=True)

    ids = df[pdb_id_col].tolist()
    pdb_list = PDBList()
    parser = PDBParser(QUIET=True)
    atom_maps = [_load_ca_atoms(pid, pdb_list, parser, pdb_dir) for pid in ids]

    n = len(ids)
    rmsd_mat = np.zeros((n, n))
    sup = Superimposer()
    for i, j in itertools.combinations(range(n), 2):
        # Compare only residues present in both structures, in a stable order.
        common = sorted(atom_maps[i].keys() & atom_maps[j].keys())
        if not common:
            rmsd = np.nan
        else:
            fixed_atoms = [atom_maps[i][k] for k in common]
            moving_atoms = [atom_maps[j][k] for k in common]
            sup.set_atoms(fixed_atoms, moving_atoms)
            rmsd = round(sup.rms, 1)
        rmsd_mat[i, j] = rmsd_mat[j, i] = rmsd

    return pd.DataFrame(rmsd_mat, index=ids, columns=ids)


@op("LynxKite Graph Analytics", "Plot matrix", view="matplotlib")
def plot_heatmap_from_df(
    df: pd.DataFrame, *, value_label: str = "Value", title: str = None
) -> plt.Figure:
    """
    Plot a DataFrame as an annotated heatmap using matplotlib.

    Parameters
    ----------
    df : pd.DataFrame
        Numeric DataFrame of values to plot. Column labels go on the x axis
        and index labels on the y axis.
    value_label : str
        Label for the color bar.
    title : str, optional
        Title for the plot. No title is drawn when it is empty or None.

    Returns
    -------
    plt.Figure
        The matplotlib Figure object containing the heatmap.
    """
    values = df.to_numpy()
    fig, ax = plt.subplots()
    im = ax.imshow(values)

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(value_label)

    if title:
        ax.set_title(title)

    # Columns label the x axis and the index labels the y axis. Using the index
    # for both would only be correct for square frames with identical labels.
    ax.set_xticks(range(df.shape[1]))
    ax.set_xticklabels(df.columns.tolist(), rotation=90)
    ax.set_yticks(range(df.shape[0]))
    ax.set_yticklabels(df.index.tolist())

    # Write each cell's value at its (column, row) position.
    for (i, j), value in np.ndenumerate(values):
        ax.text(j, i, f"{value:.1f}", ha="center", va="center")

    fig.tight_layout()
    return fig