# darabos's picture
# Remove docs. "lynxkite.core" is now "lynxkite_core".
# aa7be3b
import pandas as pd
import matplotlib.pyplot as plt
import pypdb
import biotite.database.rcsb as rcsb
from MDAnalysis.analysis import rms
from opencadd.structure.superposition.engines.mda import MDAnalysisAligner
from lynxkite_core.ops import op
import os
import numpy as np
from Bio.PDB import PDBList, PDBParser, Superimposer
def calc_rmsd(A, B):
    """
    Compute the RMSD between the matched atoms of two structures.

    The atoms to compare are chosen with
    ``MDAnalysisAligner.matching_selection``. The resulting coordinates are
    compared as they are: no additional superposition is requested from
    ``rms.rmsd``.

    Parameters
    ----------
    A : opencadd.structure.core.Structure
        Structure providing the "reference" side of the selection.
    B : opencadd.structure.core.Structure
        Structure providing the "mobile" side of the selection.

    Returns
    -------
    float
        RMSD value.
    """
    selection_strings, _unused = MDAnalysisAligner().matching_selection(A, B)
    reference_atoms = A.select_atoms(selection_strings["reference"])
    mobile_atoms = B.select_atoms(selection_strings["mobile"])
    return rms.rmsd(
        reference_atoms.positions,
        mobile_atoms.positions,
        superposition=False,
    )
def calc_rmsd_matrix(structures, names):
    """
    Build the symmetric table of pairwise RMSD values for a set of structures.

    Parameters
    ----------
    structures : list of opencadd.structure.core.Structure
        Structures to compare with each other.
    names : list of str
        Label of each structure, in the same order as ``structures``.

    Returns
    -------
    pandas.DataFrame
        RMSD matrix labelled by ``names`` on both axes, with zeros on the
        diagonal.
    """
    table = {label: {} for label in names}
    labelled = list(zip(structures, names))
    for pos, (first, first_label) in enumerate(labelled):
        # A structure compared with itself has no deviation.
        table[first_label][first_label] = 0.0
        # Each unordered pair is computed once and mirrored into both cells.
        for second, second_label in labelled[pos + 1:]:
            distance = calc_rmsd(first, second)
            table[first_label][second_label] = distance
            table[second_label][first_label] = distance
    return pd.DataFrame.from_dict(table)
@op("LynxKite Graph Analytics", "PDB composite search")
def get_pdb_count(
    *, ligand_id: str, experimental_method: str, max_resolution: float, polymer_count: int
) -> pd.DataFrame:
    """
    Query the RCSB PDB for structures matching all given criteria and
    return the matching PDB IDs together with their titles.

    The number of matches for each individual criterion is printed as a
    diagnostic before the combined search is run.

    Parameters
    ----------
    ligand_id : str
        Non-polymer component ID to filter on (e.g., 'STI').
    experimental_method : str
        Experimental method to filter by (e.g., 'X-RAY DIFFRACTION').
    max_resolution : float
        Maximum resolution (Å) to include (<= this value).
    polymer_count : int
        Exact number of polymer chains in the structure.

    Returns
    -------
    pd.DataFrame
        One row per matching entry, with the columns ``pdb_id`` and
        ``description`` (the entry's ``struct.title``, or None when the
        metadata could not be obtained).
    """
    # 1) Query by ligand ID
    q_ligand = rcsb.FieldQuery(
        "rcsb_nonpolymer_entity_container_identifiers.nonpolymer_comp_id", exact_match=ligand_id
    )
    count_ligand = rcsb.count(q_ligand)
    print(f"Number of matches for ligand '{ligand_id}': {count_ligand}")

    # 2) Query by experimental method
    q_method = rcsb.FieldQuery("exptl.method", exact_match=experimental_method)
    count_method = rcsb.count(q_method)
    print(f"Number of matches for experimental method '{experimental_method}': {count_method}")

    # 3) Query by resolution
    q_resolution = rcsb.FieldQuery(
        "rcsb_entry_info.resolution_combined", less_or_equal=max_resolution
    )
    count_resolution = rcsb.count(q_resolution)
    print(f"Number of matches with resolution ≤ {max_resolution}: {count_resolution}")

    # 4) Query by polymer chain count
    q_polymer = rcsb.FieldQuery(
        "rcsb_entry_info.deposited_polymer_entity_instance_count", equals=polymer_count
    )
    count_polymer = rcsb.count(q_polymer)
    print(f"Number of matches with polymer count == {polymer_count}: {count_polymer}")

    # 5) Composite query (AND all criteria); executed exactly once.
    composite_q = rcsb.CompositeQuery([q_ligand, q_method, q_resolution, q_polymer], "and")
    pdb_ids = rcsb.search(composite_q)

    # Fetch the title of every hit.
    titles = []
    for pid in pdb_ids:
        info = pypdb.get_all_info(pid)
        # NOTE(review): get_all_info appears to return None when the lookup
        # fails -- confirm against the installed pypdb version. A single
        # failed lookup should not discard the whole search result, so such
        # entries get a None description.
        struct_info = (info or {}).get("struct") or {}
        titles.append(struct_info.get("title"))

    return pd.DataFrame({"pdb_id": pdb_ids, "description": titles})
@op("LynxKite Graph Analytics", "PDB alignment RMSD")
def compute_pdb_rmsd(df: pd.DataFrame, *, pdb_id_col: str = "pdb_id") -> pd.DataFrame:
    """
    Compute the pairwise Cα RMSD matrix for the PDB entries listed in `df`.

    Each entry is downloaded in PDB format into the local ``pdb_files``
    directory and parsed. From its first model, chain "A" is used (or the
    first chain when there is no chain "A"). For every pair of entries the
    "CA" atoms of the residue IDs present in both chains are superimposed
    with BioPython's Superimposer and the resulting RMSD is stored.

    Parameters
    ----------
    df : pd.DataFrame
        Input DataFrame containing PDB IDs. It is only read, never modified.
    pdb_id_col : str
        Name of the column in `df` with PDB IDs (default 'pdb_id').

    Returns
    -------
    pd.DataFrame
        Square DataFrame of pairwise RMSD values (Å) rounded to one decimal,
        indexed and columned by PDB IDs. The diagonal is 0 and pairs without
        any common residue ID are NaN.
    """
    # Prepare PDB directory
    pdb_dir = "pdb_files"
    os.makedirs(pdb_dir, exist_ok=True)

    # Only the ID column is read, so no copy of the input frame is needed.
    ids = df[pdb_id_col].tolist()
    n = len(ids)

    pdbl = PDBList()
    parser = PDBParser(QUIET=True)
    atom_dicts = []
    for pid in ids:
        pdbl.retrieve_pdb_file(pid, pdir=pdb_dir, file_format="pdb")
        path = os.path.join(pdb_dir, f"pdb{pid.lower()}.ent")
        struct = parser.get_structure(pid, path)
        model = struct[0]
        try:
            chain = model["A"]
        except KeyError:
            chain = next(model.get_chains())
        # Map residue ID -> its "CA" atom so chains can be matched by residue.
        # NOTE(review): has_id("CA") presumably also matches calcium ions
        # (hetero residues whose only atom is named CA) -- confirm whether
        # hetero residues should be filtered out here.
        ca_atoms = {residue.id: residue["CA"] for residue in chain if residue.has_id("CA")}
        atom_dicts.append(ca_atoms)

    rmsd_mat = np.zeros((n, n))
    sup = Superimposer()
    for i, atoms_i in enumerate(atom_dicts):
        for j in range(i + 1, n):
            atoms_j = atom_dicts[j]
            common = sorted(atoms_i.keys() & atoms_j.keys())
            if common:
                fixed_atoms = [atoms_i[key] for key in common]
                moving_atoms = [atoms_j[key] for key in common]
                sup.set_atoms(fixed_atoms, moving_atoms)
                rmsd = round(sup.rms, 1)
            else:
                # Nothing to superimpose: mark the pair as undefined.
                rmsd = np.nan
            rmsd_mat[i, j] = rmsd_mat[j, i] = rmsd

    return pd.DataFrame(rmsd_mat, index=ids, columns=ids)
@op("LynxKite Graph Analytics", "Plot matrix", view="matplotlib")
def plot_heatmap_from_df(
    df: pd.DataFrame, *, value_label: str = "Value", title: str = None
) -> plt.Figure:
    """
    Plot a heatmap of a DataFrame using matplotlib.

    Every cell is annotated with its value formatted to one decimal. The
    figure is created through ``plt.subplots()``.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame of numeric values to plot. Rows are drawn along the y axis
        and labelled with the index; columns are drawn along the x axis and
        labelled with the column names.
    value_label : str
        Label for the color bar.
    title : str, optional
        Title for the plot. No title is set when empty or None.

    Returns
    -------
    plt.Figure
        The matplotlib Figure object containing the heatmap.
    """
    fig, ax = plt.subplots()
    im = ax.imshow(df.values)

    # create and label the colorbar
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(value_label)
    if title:
        ax.set_title(title)

    n_rows, n_cols = df.shape
    # The x axis runs over columns and the y axis over rows, so each axis
    # takes its tick count and labels from the matching side of the frame.
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(df.columns.tolist(), rotation=90)
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(df.index.tolist())

    # Annotate each cell
    for i in range(n_rows):
        for j in range(n_cols):
            ax.text(j, i, f"{df.iat[i, j]:.1f}", ha="center", va="center")

    fig.tight_layout()
    return fig