File size: 7,549 Bytes
cc95e47
 
 
 
 
 
aa7be3b
cc95e47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505dbb9
cc95e47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505dbb9
cc95e47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505dbb9
cc95e47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import pandas as pd
import matplotlib.pyplot as plt
import pypdb
import biotite.database.rcsb as rcsb
from MDAnalysis.analysis import rms
from opencadd.structure.superposition.engines.mda import MDAnalysisAligner
from lynxkite_core.ops import op
import os
import numpy as np
from Bio.PDB import PDBList, PDBParser, Superimposer


def calc_rmsd(A, B):
    """
    Compute the RMSD between the matched atoms of two structures.

    The atoms to compare are chosen with
    ``MDAnalysisAligner.matching_selection``; the resulting coordinates are
    compared as they are (``superposition=False``), i.e. no fitting is done
    inside this function.

    Parameters
    ----------
    A : opencadd.structure.core.Structure
        Structure A (the "reference" selection is applied to it).
    B : opencadd.structure.core.Structure
        Structure B (the "mobile" selection is applied to it).

    Returns
    -------
    float
        RMSD value.
    """
    matched, _ = MDAnalysisAligner().matching_selection(A, B)
    reference_atoms = A.select_atoms(matched["reference"])
    mobile_atoms = B.select_atoms(matched["mobile"])
    return rms.rmsd(
        reference_atoms.positions,
        mobile_atoms.positions,
        superposition=False,
    )


def calc_rmsd_matrix(structures, names):
    """
    Build the symmetric pairwise RMSD table for a list of structures.

    Each unordered pair is evaluated once with ``calc_rmsd`` and the result is
    mirrored; the diagonal is set to 0.0.

    Parameters
    ----------
    structures : list of opencadd.structure.core.Structure
        List of structures.
    names : list of str
        List of structure names, in the same order as `structures`.

    Returns
    -------
    pandas.DataFrame
        RMSD matrix labelled by the structure names.
    """
    table = {label: {} for label in names}
    labelled = list(zip(structures, names))
    for pos, (first, first_label) in enumerate(labelled):
        # Diagonal entry: a structure compared with itself.
        table[first_label][first_label] = 0.0
        # Only the upper triangle is computed; the value is mirrored below.
        for second, second_label in labelled[pos + 1:]:
            distance = calc_rmsd(first, second)
            table[first_label][second_label] = distance
            table[second_label][first_label] = distance
    return pd.DataFrame.from_dict(table)


@op("LynxKite Graph Analytics", "PDB composite search")
def get_pdb_count(
    *, ligand_id: str, experimental_method: str, max_resolution: float, polymer_count: int
) -> pd.DataFrame:
    """
    Query the RCSB PDB for structures matching all given criteria and
    return the matching entries together with their titles.

    The number of matches for each individual criterion is printed as a
    side effect before the combined (AND) search is run.

    Parameters
    ----------
    ligand_id : str
        Non-polymer component ID to filter on (e.g., 'STI').
    experimental_method : str
        Experimental method to filter by (e.g., 'X-RAY DIFFRACTION').
    max_resolution : float
        Maximum resolution (Å) to include (<= this value).
    polymer_count : int
        Exact number of polymer chains in the structure.

    Returns
    -------
    pd.DataFrame
        One row per matching entry with the columns ``pdb_id`` and
        ``description`` (the entry title, or None when the metadata
        could not be retrieved).
    """

    # 1) Query by ligand ID
    q_ligand = rcsb.FieldQuery(
        "rcsb_nonpolymer_entity_container_identifiers.nonpolymer_comp_id", exact_match=ligand_id
    )
    count_ligand = rcsb.count(q_ligand)
    print(f"Number of matches for ligand '{ligand_id}': {count_ligand}")

    # 2) Query by experimental method
    q_method = rcsb.FieldQuery("exptl.method", exact_match=experimental_method)
    count_method = rcsb.count(q_method)
    print(f"Number of matches for experimental method '{experimental_method}': {count_method}")

    # 3) Query by resolution
    q_resolution = rcsb.FieldQuery(
        "rcsb_entry_info.resolution_combined", less_or_equal=max_resolution
    )
    count_resolution = rcsb.count(q_resolution)
    print(f"Number of matches with resolution ≤ {max_resolution}: {count_resolution}")

    # 4) Query by polymer chain count
    q_polymer = rcsb.FieldQuery(
        "rcsb_entry_info.deposited_polymer_entity_instance_count", equals=polymer_count
    )
    count_polymer = rcsb.count(q_polymer)
    print(f"Number of matches with polymer count == {polymer_count}: {count_polymer}")

    # 5) Composite query (AND all criteria); run the search exactly once.
    composite_q = rcsb.CompositeQuery([q_ligand, q_method, q_resolution, q_polymer], "and")
    pdb_ids = rcsb.search(composite_q)

    # Fetch the entry metadata and pull out the title of each structure.
    titles = []
    for pid in pdb_ids:
        info = pypdb.get_all_info(pid)
        # A failed lookup may yield None (or a record without "struct");
        # keep the row with a missing title instead of raising TypeError/KeyError.
        struct = (info or {}).get("struct") or {}
        titles.append(struct.get("title"))

    # Build DataFrame
    return pd.DataFrame({"pdb_id": pdb_ids, "description": titles})


@op("LynxKite Graph Analytics", "PDB alignment RMSD")
def compute_pdb_rmsd(df: pd.DataFrame, *, pdb_id_col: str = "pdb_id") -> pd.DataFrame:
    """
    Compute the pairwise Cα RMSD matrix for the PDB entries listed in a DataFrame.

    For every PDB ID the legacy-format PDB file is downloaded into the local
    ``pdb_files`` directory, the Cα atoms of chain A (or of the first chain
    when there is no chain A) of the first model are collected, and each pair
    of structures is fitted with BioPython's Superimposer over the residue
    ids they have in common.

    Parameters
    ----------
    df : pd.DataFrame
        Input DataFrame containing PDB IDs.
    pdb_id_col : str
        Name of the column in `df` with PDB IDs (default 'pdb_id').

    Returns
    -------
    pd.DataFrame
        Square DataFrame of pairwise RMSD values (Å, rounded to one decimal),
        indexed and columned by PDB IDs. Pairs without any common residue id
        are NaN; the diagonal is 0.

    Raises
    ------
    FileNotFoundError
        If the PDB file of an entry is not present after the download attempt.
    """

    # Prepare PDB directory
    pdb_dir = "pdb_files"
    os.makedirs(pdb_dir, exist_ok=True)

    # Only the ID column is read, so the input frame does not need to be copied.
    ids = df[pdb_id_col].tolist()
    n = len(ids)

    pdbl = PDBList()
    parser = PDBParser(QUIET=True)
    atom_dicts = []
    for pid in ids:
        pdbl.retrieve_pdb_file(pid, pdir=pdb_dir, file_format="pdb")
        path = os.path.join(pdb_dir, f"pdb{pid.lower()}.ent")
        if not os.path.isfile(path):
            # Fail with a message naming the entry rather than an opaque parser error.
            raise FileNotFoundError(
                f"PDB file for entry '{pid}' was not downloaded (expected at '{path}')"
            )
        struct = parser.get_structure(pid, path)
        model = struct[0]
        try:
            chain = model["A"]
        except KeyError:
            # No chain named "A": fall back to the first chain of the model.
            chain = next(model.get_chains())
        ca_atoms = {residue.id: residue["CA"] for residue in chain if residue.has_id("CA")}
        atom_dicts.append(ca_atoms)

    # Build the residue-id sets once instead of once per pair.
    key_sets = [set(ca_atoms) for ca_atoms in atom_dicts]

    rmsd_mat = np.zeros((n, n))
    sup = Superimposer()
    for i in range(n):
        for j in range(i + 1, n):
            common = sorted(key_sets[i] & key_sets[j])
            if common:
                fixed_atoms = [atom_dicts[i][k] for k in common]
                moving_atoms = [atom_dicts[j][k] for k in common]
                sup.set_atoms(fixed_atoms, moving_atoms)
                value = round(sup.rms, 1)
            else:
                # Nothing to fit on: mark the pair as undefined.
                value = np.nan
            rmsd_mat[i, j] = rmsd_mat[j, i] = value

    return pd.DataFrame(rmsd_mat, index=ids, columns=ids)


@op("LynxKite Graph Analytics", "Plot matrix", view="matplotlib")
def plot_heatmap_from_df(
    df: pd.DataFrame, *, value_label: str = "Value", title: str = None
) -> plt.Figure:
    """
    Plot a heatmap of a DataFrame of numeric values using matplotlib.

    Rows are labelled with the DataFrame index and columns with the DataFrame
    column names; every cell is annotated with its value formatted to one
    decimal.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame of values to plot (typically a square matrix).
    value_label : str
        Label for the color bar.
    title : str, optional
        Title for the plot. No title is set when it is None or empty.

    Returns
    -------
    plt.Figure
        The matplotlib Figure object containing the heatmap.
    """
    fig, ax = plt.subplots()
    im = ax.imshow(df.values)
    # create and label the colorbar
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(value_label)

    if title:
        ax.set_title(title)

    # The x axis runs over columns and the y axis over rows, so each axis
    # takes its labels from the matching side of the DataFrame.
    col_labels = df.columns.tolist()
    row_labels = df.index.tolist()
    ax.set_xticks(range(len(col_labels)))
    ax.set_xticklabels(col_labels, rotation=90)
    ax.set_yticks(range(len(row_labels)))
    ax.set_yticklabels(row_labels)

    # Annotate each cell
    n_rows, n_cols = df.shape
    for i in range(n_rows):
        for j in range(n_cols):
            ax.text(j, i, f"{df.iat[i, j]:.1f}", ha="center", va="center")

    # Lay out this figure explicitly rather than whatever pyplot considers current.
    fig.tight_layout()
    return fig