| """BioNeMo related operations | |
| The intention is to showcase how BioNeMo can be integrated with LynxKite. This should be | |
| considered as a reference implementation and not a production ready code. | |
| The operations are quite specific for this example notebook: | |
| https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/examples/bionemo-geneformer/geneformer-celltype-classification.ipynb | |
| """ | |
from lynxkite.core import ops
import requests
import tarfile
import os
from collections import Counter
from . import core
import joblib
import numpy as np
import torch
from pathlib import Path
import random
from contextlib import contextmanager
import cellxgene_census  # TODO: This needs numpy < 2
import tempfile
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.model_selection import StratifiedKFold, cross_validate, cross_val_predict
from sklearn.metrics import (
    make_scorer,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    confusion_matrix,
)
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder
from bionemo.scdl.io.single_cell_collection import SingleCellCollection
import scanpy

mem = joblib.Memory("joblib-cache")
op = ops.op_registration(core.ENV)
DATA_PATH = Path("/workspace")


@contextmanager
def random_seed(seed: int):
    """Temporarily seeds the global `random` module, restoring its state on exit."""
    state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        # Go back to the previous state.
        random.setstate(state)
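

# A minimal usage sketch (illustrative only, not one of the registered ops):
# shuffles inside the block are reproducible, while the surrounding RNG state
# is left exactly as it was.
#
#     with random_seed(42):
#         sample = random.sample(range(100), k=5)  # same result on every run
#     random.random()  # continues from the pre-existing RNG state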


def download_cellxgene_dataset(
    *,
    save_path: str,
    census_version: str = "2023-12-15",
    organism: str = "Homo sapiens",
    value_filter='dataset_id=="8e47ed12-c658-4252-b126-381df8d52a3d"',
    max_workers: int = 1,
    use_mp: bool = False,
) -> Path:
    """Downloads a CELLxGENE dataset and converts it to the SCDL format."""
    with cellxgene_census.open_soma(census_version=census_version) as census:
        adata = cellxgene_census.get_anndata(
            census,
            organism,
            obs_value_filter=value_filter,
        )
    with random_seed(32):
        indices = list(range(len(adata)))
        random.shuffle(indices)
    micro_batch_size: int = 32
    num_steps: int = 256
    selection = sorted(indices[: micro_batch_size * num_steps])
    # NOTE: there's a current constraint that predict_step needs to be a function of micro-batch-size.
    # This is something we are working on fixing. A quick hack is to set micro-batch-size=1, but this
    # is slow. In this notebook we are going to use mbs=32 and subsample the anndata.
    adata = adata[selection].copy()  # so it's not a view
    h5ad_outfile = DATA_PATH / Path("hs-celltype-bench.h5ad")
    adata.write_h5ad(h5ad_outfile)
    with tempfile.TemporaryDirectory() as temp_dir:
        coll = SingleCellCollection(temp_dir)
        coll.load_h5ad_multi(h5ad_outfile.parent, max_workers=max_workers, use_processes=use_mp)
        coll.flatten(DATA_PATH / save_path, destroy_on_copy=True)
    return DATA_PATH / save_path
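

# Hypothetical invocation (the save path is relative to DATA_PATH):
#
#     dataset_path = download_cellxgene_dataset(save_path="celltype-bench-dataset")
#     # -> Path("/workspace/celltype-bench-dataset"), in the SCDL memmap layout
#     #    that the Geneformer inference script expects.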


def import_h5ad(*, file_path: str):
    return scanpy.read_h5ad(DATA_PATH / Path(file_path))


def download_model(*, model_name: str) -> str:
    """Downloads and extracts a Geneformer checkpoint from NGC."""
    model_download_parameters = {
        "geneformer_100m": {
            "name": "geneformer_100m",
            "version": "2.0",
            "path": "geneformer_106M_240530_nemo2",
        },
        "geneformer_10m": {
            "name": "geneformer_10m",
            "version": "2.0",
            "path": "geneformer_10M_240530_nemo2",
        },
        "geneformer_10m2": {
            "name": "geneformer_10m",
            "version": "2.1",
            "path": "geneformer_10M_241113_nemo2",
        },
    }
    # Define the URL and output file.
    url_template = "https://api.ngc.nvidia.com/v2/models/org/nvidia/team/clara/{name}/{version}/files?redirect=true&path={path}.tar.gz"
    url = url_template.format(**model_download_parameters[model_name])
    model_filename = f"{DATA_PATH}/{model_download_parameters[model_name]['path']}"
    output_file = f"{model_filename}.tar.gz"
    # Send the request.
    response = requests.get(url, allow_redirects=True, stream=True)
    response.raise_for_status()  # Raise an error for bad responses (4xx and 5xx).
    # Save the file to disk.
    with open(output_file, "wb") as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
    # Extract the tar.gz file.
    os.makedirs(model_filename, exist_ok=True)
    with tarfile.open(output_file, "r:gz") as tar:
        tar.extractall(path=model_filename)
    return model_filename
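

# Example call (model_name must be a key of model_download_parameters above):
#
#     model_10m_path = download_model(model_name="geneformer_10m")
#     # -> "/workspace/geneformer_10M_240530_nemo2", the extracted NeMo2
#     #    checkpoint directory.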


def infer(dataset_path: str, model_path: str | None = None, *, results_path: str) -> Path:
    """Runs inference on a dataset."""
    # This import is slow, so we only do it when we need it.
    from bionemo.geneformer.scripts.infer_geneformer import infer_model

    infer_model(
        data_path=dataset_path,
        checkpoint_path=model_path,
        results_path=DATA_PATH / results_path,
        include_hiddens=False,
        micro_batch_size=32,
        include_embeddings=True,
        include_logits=False,
        seq_length=2048,
        precision="bf16-mixed",
        devices=1,
        num_nodes=1,
        num_dataset_workers=10,
    )
    return DATA_PATH / results_path
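

# Sketch of chaining the steps so far (hypothetical names from the calls above):
#
#     results_dir = infer(dataset_path, model_10m_path, results_path="results_10m")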


def load_results(results_path: str):
    # infer_geneformer writes one predictions file per rank; with devices=1,
    # all embeddings land in predictions__rank_0.pt.
    embeddings = (
        torch.load(f"{results_path}/predictions__rank_0.pt")["embeddings"].float().cpu().numpy()
    )
    return embeddings


def get_labels(adata):
    infer_metadata = adata.obs
    labels = infer_metadata["cell_type"].values
    label_encoder = LabelEncoder()
    integer_labels = label_encoder.fit_transform(labels)
    # Stash the encoded labels on the encoder, so downstream ops get both the
    # class mapping (classes_) and the per-cell labels from a single object.
    label_encoder.integer_labels = integer_labels
    return label_encoder
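

# The encoder doubles as a label container. A minimal sketch, assuming `adata`
# came from import_h5ad() above:
#
#     encoder = get_labels(adata)
#     encoder.integer_labels[:3]                             # e.g. array([0, 2, 1])
#     encoder.inverse_transform(encoder.integer_labels[:3])  # the cell-type names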


def plot_labels(adata):
    infer_metadata = adata.obs
    labels = infer_metadata["cell_type"].values
    label_counts = Counter(labels)
    labels = list(label_counts.keys())
    values = list(label_counts.values())
    options = {
        "title": {
            "text": "Cell type counts for classification dataset",
            "left": "center",
        },
        "tooltip": {"trigger": "axis", "axisPointer": {"type": "shadow"}},
        "xAxis": {
            "type": "category",
            "data": labels,
            "axisLabel": {"rotate": 45, "align": "right"},
        },
        "yAxis": {"type": "value"},
        "series": [
            {
                "name": "Count",
                "type": "bar",
                "data": values,
                "itemStyle": {"color": "#4285F4"},
            }
        ],
    }
    return options


def run_benchmark(data, labels, *, use_pca: bool = False):
    """Cross-validates a random forest classifier on the given features.

    data - contains the single cell expression (or whatever feature) in each row.
    labels - the LabelEncoder from get_labels(); its integer_labels hold the label for each cell.

    data shape: (R, C); labels shape: (R,).
    """
    np.random.seed(1337)
    # Define the target dimension 'n_components'.
    n_components = 10  # for example, adjust based on your specific needs
    # Create a pipeline with an optional PCA projection before the RandomForestClassifier.
    if use_pca:
        pipeline = Pipeline(
            [
                ("projection", PCA(n_components=n_components)),
                ("classifier", RandomForestClassifier(class_weight="balanced")),
            ]
        )
    else:
        pipeline = Pipeline([("classifier", RandomForestClassifier(class_weight="balanced"))])
    # Set up StratifiedKFold to ensure each fold reflects the overall distribution of labels.
    cv = StratifiedKFold(n_splits=5)
    # Define the scoring functions.
    scoring = {
        "accuracy": make_scorer(accuracy_score),
        "precision": make_scorer(precision_score, average="macro"),  # 'macro' averages over classes
        "recall": make_scorer(recall_score, average="macro"),
        "f1_score": make_scorer(f1_score, average="macro"),
        # Multiclass ROC AUC needs class probabilities, not hard predictions.
        # (response_method requires scikit-learn >= 1.4.)
        "roc_auc": make_scorer(roc_auc_score, multi_class="ovr", response_method="predict_proba"),
    }
    labels = labels.integer_labels
    # Perform stratified cross-validation with multiple metrics using the pipeline.
    results = cross_validate(
        pipeline, data, labels, cv=cv, scoring=scoring, return_train_score=False
    )
    # Print the cross-validation results.
    print("Cross-validation metrics:")
    results_out = {}
    for metric, scores in results.items():
        if metric.startswith("test_"):
            results_out[metric] = (scores.mean(), scores.std())
            print(f"{metric[5:]}: {scores.mean():.3f} (+/- {scores.std():.3f})")
    predictions = cross_val_predict(pipeline, data, labels, cv=cv)
    # Return the confusion matrix along with the metrics.
    conf_matrix = confusion_matrix(labels, predictions)
    return results_out, conf_matrix
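

# Putting the pieces together (illustrative; one embedding row per cell, with
# `results_dir` and `adata` from the hypothetical calls above):
#
#     embeddings = load_results(results_dir)  # shape (R, hidden_dim)
#     metrics, cm = run_benchmark(embeddings, get_labels(adata))
#     metrics["test_accuracy"]                # (mean, std) across the 5 folds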


def plot_confusion_matrix(benchmark_output, labels):
    cm = benchmark_output[1]
    labels = labels.classes_
    str_labels = [str(label) for label in labels]
    # Normalize each row, so the cells show the per-class prediction rates.
    norm_cm = [[float(val / sum(row)) if sum(row) else 0 for val in row] for row in cm]
    # The ECharts heatmap puts (0, 0) at the bottom-left corner, so flip the rows.
    num_rows = len(str_labels)
    heatmap_data = [
        [j, num_rows - i - 1, norm_cm[i][j]]
        for i in range(len(labels))
        for j in range(len(labels))
    ]
    options = {
        "title": {"text": "Confusion Matrix", "left": "center"},
        "tooltip": {"position": "top"},
        "xAxis": {
            "type": "category",
            "data": str_labels,
            "splitArea": {"show": True},
            "axisLabel": {"rotate": 70, "align": "right"},
        },
        "yAxis": {
            "type": "category",
            "data": list(reversed(str_labels)),
            "splitArea": {"show": True},
        },
        "grid": {
            "height": "70%",
            "width": "70%",
            "left": "20%",
            "right": "10%",
            "bottom": "10%",
            "top": "10%",
        },
        "visualMap": {
            "min": 0,
            "max": 1,
            "calculable": True,
            "orient": "vertical",
            "right": 10,
            "top": "center",
            "inRange": {"color": ["#E0F7FA", "#81D4FA", "#29B6F6", "#0288D1", "#01579B"]},
        },
        "series": [
            {
                "name": "Confusion matrix",
                "type": "heatmap",
                "data": heatmap_data,
                "emphasis": {"itemStyle": {"borderColor": "#333", "borderWidth": 1}},
                "itemStyle": {"borderColor": "#D3D3D3", "borderWidth": 2},
            }
        ],
    }
    return options


def accuracy_comparison(benchmark_output10m, benchmark_output100m):
    results_10m = benchmark_output10m[0]
    results_100m = benchmark_output100m[0]
    data = {
        "model": ["10M parameters", "106M parameters"],
        "accuracy_mean": [
            results_10m["test_accuracy"][0],
            results_100m["test_accuracy"][0],
        ],
        "accuracy_std": [
            results_10m["test_accuracy"][1],
            results_100m["test_accuracy"][1],
        ],
    }
    labels = data["model"]  # X-axis labels
    values = data["accuracy_mean"]  # Y-axis values
    error_bars = data["accuracy_std"]  # Standard deviation for error bars
    options = {
        "title": {
            "text": "Accuracy Comparison",
            "left": "center",
            "textStyle": {
                "fontSize": 20,  # Bigger font for the title.
                "fontWeight": "bold",
            },
        },
        "grid": {
            "height": "70%",
            "width": "70%",
            "left": "20%",
            "right": "10%",
            "bottom": "10%",
            "top": "10%",
        },
        "tooltip": {"trigger": "axis", "axisPointer": {"type": "shadow"}},
        "xAxis": {
            "type": "category",
            "data": labels,
            "axisLabel": {
                "rotate": 45,  # Rotate labels for better readability.
                "align": "right",
                "fontSize": 14,  # Bigger font for X-axis labels.
                "fontWeight": "bold",
            },
        },
        "yAxis": {
            "type": "value",
            "name": "Accuracy",
            "min": 0,
            "max": 1,
            "interval": 0.1,  # One tick every 0.1.
            "axisLabel": {
                "fontSize": 14,  # Bigger font for Y-axis labels.
                "fontWeight": "bold",
            },
        },
        "series": [
            {
                "name": "Accuracy",
                "type": "bar",
                "data": values,
                "itemStyle": {"color": "#440154"},  # Viridis color palette (dark purple).
            },
            {
                # NOTE: "errorbar" is not a built-in ECharts series type; this
                # assumes the renderer registers a custom errorbar series.
                "name": "Error Bars",
                "type": "errorbar",
                "data": [[val - err, val + err] for val, err in zip(values, error_bars)],
                "itemStyle": {"color": "#1f77b4"},
            },
        ],
    }
    return options


def f1_comparison(benchmark_output10m, benchmark_output100m):
    results_10m = benchmark_output10m[0]
    results_100m = benchmark_output100m[0]
    data = {
        "model": ["10M parameters", "106M parameters"],
        "f1_score_mean": [
            results_10m["test_f1_score"][0],
            results_100m["test_f1_score"][0],
        ],
        "f1_score_std": [
            results_10m["test_f1_score"][1],
            results_100m["test_f1_score"][1],
        ],
    }
    labels = data["model"]  # X-axis labels
    values = data["f1_score_mean"]  # Y-axis values
    error_bars = data["f1_score_std"]  # Standard deviation for error bars
    options = {
        "title": {
            "text": "F1 Score Comparison",
            "left": "center",
            "textStyle": {
                "fontSize": 20,  # Bigger font for the title.
                "fontWeight": "bold",
            },
        },
        "grid": {
            "height": "70%",
            "width": "70%",
            "left": "20%",
            "right": "10%",
            "bottom": "10%",
            "top": "10%",
        },
        "tooltip": {"trigger": "axis", "axisPointer": {"type": "shadow"}},
        "xAxis": {
            "type": "category",
            "data": labels,
            "axisLabel": {
                "rotate": 45,  # Rotate labels for better readability.
                "align": "right",
                "fontSize": 14,  # Bigger font for X-axis labels.
                "fontWeight": "bold",
            },
        },
        "yAxis": {
            "type": "value",
            "name": "F1 Score",
            "min": 0,
            "max": 1,
            "interval": 0.1,  # One tick every 0.1.
            "axisLabel": {
                "fontSize": 14,  # Bigger font for Y-axis labels.
                "fontWeight": "bold",
            },
        },
        "series": [
            {
                "name": "F1 Score",
                "type": "bar",
                "data": values,
                "itemStyle": {"color": "#440154"},  # Viridis color palette (dark purple).
            },
            {
                # NOTE: "errorbar" is not a built-in ECharts series type; this
                # assumes the renderer registers a custom errorbar series.
                "name": "Error Bars",
                "type": "errorbar",
                "data": [[val - err, val + err] for val, err in zip(values, error_bars)],
                "itemStyle": {"color": "#1f77b4"},
            },
        ],
    }
    return options