from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from collections import defaultdict
from tqdm import tqdm
import selfies as sf
from multiprocessing import Pool, cpu_count
from functools import partial
def generate_fingerprint_batch_selfies(selfies_batch):
    """Convert a batch of SELFIES strings into Morgan fingerprint arrays."""
    fps = []
    valid_selfies = []

    for selfies in tqdm(selfies_batch, desc="Generating fingerprints", leave=False):
        try:
            # Convert SELFIES to SMILES, then to an RDKit molecule
            smiles = sf.decoder(selfies)
            mol = Chem.MolFromSmiles(smiles)
            if mol is not None:
                # 2048-bit Morgan fingerprint, radius 2 (ECFP4-like)
                fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
                # ConvertToNumpyArray resizes arr to the fingerprint length
                arr = np.zeros((1,))
                DataStructs.ConvertToNumpyArray(fp, arr)
                fps.append(arr)
                valid_selfies.append(selfies)
        except Exception:
            # Skip sequences that fail SELFIES decoding or SMILES parsing
            continue

    return np.array(fps), valid_selfies
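
# Optional alternative (a sketch, assuming a recent RDKit, roughly 2022.09+):
# rdFingerprintGenerator.GetMorganGenerator replaces the deprecated
# GetMorganFingerprintAsBitVect call and yields NumPy arrays directly. This
# variant is not wired into the pipeline below; swap it in for the function
# above if your RDKit version supports it.
def generate_fingerprint_batch_selfies_generator(selfies_batch):
    from rdkit.Chem import rdFingerprintGenerator  # local import keeps older RDKit setups working
    gen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)
    fps = []
    valid_selfies = []
    for selfies in selfies_batch:
        try:
            mol = Chem.MolFromSmiles(sf.decoder(selfies))
            if mol is not None:
                fps.append(gen.GetFingerprintAsNumPy(mol).astype(np.float64))
                valid_selfies.append(selfies)
        except Exception:
            continue
    return np.array(fps), valid_selfies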

def process_batch(batch, n_clusters, seed):
    """Fingerprint one batch and cluster it with MiniBatchKMeans."""
    fps, valid_selfies = generate_fingerprint_batch_selfies(batch)
    if len(fps) > 0:
        # The final batch may contain fewer valid molecules than n_clusters
        clusterer = MiniBatchKMeans(n_clusters=min(n_clusters, len(fps)),
                                    random_state=seed)
        labels = clusterer.fit_predict(fps)
        return list(zip(labels, valid_selfies))
    return []
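
# Quick serial smoke test (hypothetical example, not in the original script):
# uncomment to cluster a toy batch before launching the full parallel run.
# toy_batch = ["[C][C][O]", "[C][C][C][O]", "[C][=C][C][O]"]
# print(process_batch(toy_batch, n_clusters=2, seed=0))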

def parallel_clustering_split_selfies(selfies_list, batch_size=10000, n_clusters=1000, train_ratio=0.9, seed=42):
    np.random.seed(seed)
    
    # Create batches
    batches = [selfies_list[i:i + batch_size] 
               for i in range(0, len(selfies_list), batch_size)]
    
    # Initialize parallel processing (cap the worker count at the machine's cores)
    n_cores = min(12, cpu_count())
    process_batch_partial = partial(process_batch, n_clusters=n_clusters, seed=seed)
    
    cluster_assignments = defaultdict(list)
    with Pool(n_cores) as pool:
        results = list(tqdm(
            pool.imap(process_batch_partial, batches),
            total=len(batches),
            desc="Processing batches"
        ))
    
    # Combine results; key clusters by (batch index, label) so that label 0
    # from one batch is not merged with label 0 from another batch
    for batch_idx, batch_results in enumerate(results):
        for label, selfies in batch_results:
            cluster_assignments[(batch_idx, label)].append(selfies)
    
    # Split into train/val
    clusters = list(cluster_assignments.values())
    np.random.shuffle(clusters)
    
    train_selfies = []
    val_selfies = []
    total_mols = sum(len(cluster) for cluster in clusters)
    
    for cluster in tqdm(clusters, desc="Splitting clusters"):
        if len(train_selfies) / total_mols < train_ratio:
            train_selfies.extend(cluster)
        else:
            val_selfies.extend(cluster)
    
    print(f"Final splits: Train={len(train_selfies)}, Validation={len(val_selfies)}")
    return train_selfies, val_selfies
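
# Post-split sanity check (a minimal sketch added for illustration, not part of
# the original pipeline): the cluster-based split only groups near-duplicates
# within a batch, so exact duplicates in the input can still straddle the two
# splits; this counts them.
def check_split_overlap(train_selfies, val_selfies):
    """Report how many identical sequences appear in both splits."""
    overlap = set(train_selfies) & set(val_selfies)
    print(f"Sequences shared between train and val: {len(overlap)}")
    return overlap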

if __name__ == "__main__":
    # The main guard is required for multiprocessing on platforms that spawn
    # worker processes (macOS, Windows) rather than fork them
    input_path = '/home/yz927/projects/peptune/tokens/filtered_peptides_selfies.txt'
    try:
        with open(input_path, 'r') as f:
            selfies_list = [line.strip() for line in f if line.strip()]
        print(f"Loaded {len(selfies_list)} SELFIES sequences from file")
    except FileNotFoundError:
        raise FileNotFoundError(f"Could not find the file at {input_path}")
    except Exception as e:
        raise RuntimeError(f"Error reading {input_path}: {e}") from e

    train_selfies, val_selfies = parallel_clustering_split_selfies(
        selfies_list,
        batch_size=10000,
        n_clusters=1000,
        train_ratio=0.8
    )
    with open('/home/yz927/projects/peptune/tokens/11M_selfies/train_selfies.txt', 'w') as f:
        for line in train_selfies:
            f.write(f"{line}\n")
    with open('/home/yz927/projects/peptune/tokens/11M_selfies/val_selfies.txt', 'w') as f:
        for line in val_selfies:
            f.write(f"{line}\n")