Spaces:
Running
Running
File size: 7,549 Bytes
cc95e47 aa7be3b cc95e47 505dbb9 cc95e47 505dbb9 cc95e47 505dbb9 cc95e47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 |
import pandas as pd
import matplotlib.pyplot as plt
import pypdb
import biotite.database.rcsb as rcsb
from MDAnalysis.analysis import rms
from opencadd.structure.superposition.engines.mda import MDAnalysisAligner
from lynxkite_core.ops import op
import os
import numpy as np
from Bio.PDB import PDBList, PDBParser, Superimposer
def calc_rmsd(A, B):
    """
    Compute the RMSD between the matched atoms of two structures.

    The atoms to compare are chosen with the selection strings returned by
    ``MDAnalysisAligner.matching_selection``; the coordinates are then
    compared as given, without fitting one structure onto the other.

    Parameters
    ----------
    A : opencadd.structure.core.Structure
        Structure passed as the first argument to ``matching_selection``;
        filtered with the "reference" selection.
    B : opencadd.structure.core.Structure
        Structure passed as the second argument to ``matching_selection``;
        filtered with the "mobile" selection.

    Returns
    -------
    float
        RMSD value.
    """
    matched, _ = MDAnalysisAligner().matching_selection(A, B)
    reference_atoms = A.select_atoms(matched["reference"])
    mobile_atoms = B.select_atoms(matched["mobile"])
    # superposition=False: no fitting is performed before measuring.
    return rms.rmsd(
        reference_atoms.positions,
        mobile_atoms.positions,
        superposition=False,
    )
def calc_rmsd_matrix(structures, names):
    """
    Build the symmetric table of pairwise RMSD values for several structures.

    Every unordered pair is evaluated once with ``calc_rmsd`` and the result
    is mirrored across the diagonal; diagonal entries are set to 0.0.

    Parameters
    ----------
    structures : list of opencadd.structure.core.Structure
        Structures to compare.
    names : list of str
        Label of each structure, in the same order as ``structures``.

    Returns
    -------
    pandas.DataFrame
        RMSD values keyed by structure name on both axes.
    """
    table = {label: {} for label in names}
    labelled = list(zip(structures, names))
    for pos, (first, first_name) in enumerate(labelled):
        table[first_name][first_name] = 0.0
        # Only look "forward" so each pair is computed exactly once.
        for second, second_name in labelled[pos + 1:]:
            distance = calc_rmsd(first, second)
            table[first_name][second_name] = distance
            table[second_name][first_name] = distance
    return pd.DataFrame.from_dict(table)
@op("LynxKite Graph Analytics", "PDB composite search")
def get_pdb_count(
    *, ligand_id: str, experimental_method: str, max_resolution: float, polymer_count: int
) -> pd.DataFrame:
    """
    Search the RCSB PDB for entries that satisfy all of the given criteria.

    The number of matches of each individual criterion is printed for
    diagnostics. The criteria are then combined with a logical AND, and the
    title of every matching entry is looked up.

    Parameters
    ----------
    ligand_id : str
        Non-polymer component ID to filter on (e.g., 'STI').
    experimental_method : str
        Experimental method to filter by (e.g., 'X-RAY DIFFRACTION').
    max_resolution : float
        Maximum resolution (Å) to include (<= this value).
    polymer_count : int
        Exact number of polymer chains in the structure.

    Returns
    -------
    pd.DataFrame
        One row per matching entry with the columns ``pdb_id`` and
        ``description`` (the entry title, or None when no title could be
        retrieved for that entry).
    """
    # 1) Query by ligand ID
    q_ligand = rcsb.FieldQuery(
        "rcsb_nonpolymer_entity_container_identifiers.nonpolymer_comp_id",
        exact_match=ligand_id,
    )
    count_ligand = rcsb.count(q_ligand)
    print(f"Number of matches for ligand '{ligand_id}': {count_ligand}")

    # 2) Query by experimental method
    q_method = rcsb.FieldQuery("exptl.method", exact_match=experimental_method)
    count_method = rcsb.count(q_method)
    print(f"Number of matches for experimental method '{experimental_method}': {count_method}")

    # 3) Query by resolution
    q_resolution = rcsb.FieldQuery(
        "rcsb_entry_info.resolution_combined", less_or_equal=max_resolution
    )
    count_resolution = rcsb.count(q_resolution)
    print(f"Number of matches with resolution ≤ {max_resolution}: {count_resolution}")

    # 4) Query by polymer chain count
    q_polymer = rcsb.FieldQuery(
        "rcsb_entry_info.deposited_polymer_entity_instance_count", equals=polymer_count
    )
    count_polymer = rcsb.count(q_polymer)
    print(f"Number of matches with polymer count == {polymer_count}: {count_polymer}")

    # 5) Composite query (AND all criteria); executed once.
    composite_q = rcsb.CompositeQuery([q_ligand, q_method, q_resolution, q_polymer], "and")
    pdb_ids = rcsb.search(composite_q)

    # Look up the title of each hit. A lookup that yields no record (or a
    # record without a title) produces None instead of aborting the search.
    titles = []
    for pid in pdb_ids:
        info = pypdb.get_all_info(pid)
        struct = (info or {}).get("struct") or {}
        titles.append(struct.get("title"))

    return pd.DataFrame({"pdb_id": pdb_ids, "description": titles})
@op("LynxKite Graph Analytics", "PDB alignment RMSD")
def compute_pdb_rmsd(df: pd.DataFrame, *, pdb_id_col: str = "pdb_id") -> pd.DataFrame:
    """
    Compute the pairwise Cα RMSD matrix for the PDB entries listed in a DataFrame.

    Each entry is downloaded in PDB format into the local ``pdb_files``
    directory and parsed. The Cα atoms of chain "A" of the first model are
    collected (the first chain is used when there is no chain "A"). For every
    pair of entries, the Cα atoms of residues whose IDs occur in both chains
    are superimposed with BioPython's ``Superimposer`` and the resulting RMSD,
    rounded to one decimal, is stored. Pairs without any common residue ID
    get NaN; the diagonal is 0.

    Parameters
    ----------
    df : pd.DataFrame
        Input DataFrame containing PDB IDs. It is not modified.
    pdb_id_col : str
        Name of the column in `df` with PDB IDs (default 'pdb_id').

    Returns
    -------
    pd.DataFrame
        Square DataFrame of pairwise RMSD values (Å), indexed and columned by PDB IDs.

    Raises
    ------
    FileNotFoundError
        If the PDB-format file of an entry is not present after the download
        attempt.
    """
    # Prepare PDB directory
    pdb_dir = "pdb_files"
    os.makedirs(pdb_dir, exist_ok=True)

    ids = df[pdb_id_col].tolist()
    n = len(ids)

    pdbl = PDBList()
    parser = PDBParser(QUIET=True)

    atom_dicts = []
    for pid in ids:
        pdbl.retrieve_pdb_file(pid, pdir=pdb_dir, file_format="pdb")
        path = os.path.join(pdb_dir, f"pdb{pid.lower()}.ent")
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"PDB file for '{pid}' was not found at '{path}' after download"
            )
        model = parser.get_structure(pid, path)[0]
        try:
            chain = model["A"]
        except KeyError:
            chain = next(model.get_chains())

        ca_atoms = {}
        for residue in chain:
            if not residue.has_id("CA"):
                continue
            atom = residue["CA"]
            # A calcium ion is stored as a residue whose single atom is also
            # named "CA"; keep only genuine alpha carbons.
            if atom.element == "C":
                ca_atoms[residue.id] = atom
        atom_dicts.append(ca_atoms)

    rmsd_mat = np.zeros((n, n))
    sup = Superimposer()
    for i in range(n):
        for j in range(i + 1, n):
            # Residue IDs present in both chains, in a deterministic order.
            common = sorted(atom_dicts[i].keys() & atom_dicts[j].keys())
            if common:
                fixed_atoms = [atom_dicts[i][k] for k in common]
                moving_atoms = [atom_dicts[j][k] for k in common]
                sup.set_atoms(fixed_atoms, moving_atoms)
                value = round(sup.rms, 1)
            else:
                value = np.nan
            rmsd_mat[i, j] = rmsd_mat[j, i] = value

    return pd.DataFrame(rmsd_mat, index=ids, columns=ids)
@op("LynxKite Graph Analytics", "Plot matrix", view="matplotlib")
def plot_heatmap_from_df(
    df: pd.DataFrame, *, value_label: str = "Value", title: str = None
) -> plt.Figure:
    """
    Plot a DataFrame of numbers as an annotated heatmap using matplotlib.

    Rows are labelled with the DataFrame index and columns with the DataFrame
    column names, so the tick labels stay correct even when the two differ.
    Every cell is annotated with its value formatted to one decimal.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame of values to plot (typically a square matrix).
    value_label : str
        Label for the color bar.
    title : str, optional
        Title for the plot; omitted when empty or None.

    Returns
    -------
    plt.Figure
        The matplotlib Figure object containing the heatmap.
    """
    fig, ax = plt.subplots()
    im = ax.imshow(df.values)

    # create and label the colorbar
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(value_label)

    if title:
        ax.set_title(title)

    # x ticks follow the columns, y ticks follow the index.
    col_labels = df.columns.tolist()
    row_labels = df.index.tolist()
    ax.set_xticks(range(len(col_labels)))
    ax.set_xticklabels(col_labels, rotation=90)
    ax.set_yticks(range(len(row_labels)))
    ax.set_yticklabels(row_labels)

    # Annotate each cell; imshow puts column j on x and row i on y.
    n_rows, n_cols = df.shape
    for i in range(n_rows):
        for j in range(n_cols):
            ax.text(j, i, f"{df.iat[i, j]:.1f}", ha="center", va="center")

    fig.tight_layout()
    return fig
|