# NOTE(review): removed non-code paste artifacts ("Spaces:", "Runtime error")
# that leaked into this file from the original capture; no source content lost.
import pathlib
import shutil
from typing import List, Optional

import numpy as np
import torch
from datasets import Dataset, load_dataset
from sklearn.metrics import accuracy_score
from substra import BackendType, Client
from substra.sdk.models import ComputePlan
from substra.sdk.schemas import (
    DataSampleSpec,
    DatasetSpec,
    Permissions,
)
from substrafl.algorithms.pytorch import TorchFedAvgAlgo
from substrafl.dependency import Dependency
from substrafl.evaluation_strategy import EvaluationStrategy
from substrafl.experiment import execute_experiment
from substrafl.index_generator import NpIndexGenerator
from substrafl.nodes import AggregationNode, TestDataNode, TrainDataNode
from substrafl.remote.register import add_metric
from substrafl.strategies import Strategy
class SubstraRunner:
    """Orchestrate a local-subprocess Substra federated-learning run on MNIST.

    One "algo provider" client drives the experiment; the remaining clients
    each hold a shard of the MNIST training split plus a full copy of the
    test split. Methods are intended to be called in order: set_up_clients,
    prepare_data, register_data, register_metric, set_aggregation,
    set_testing, run_compute_plan.
    """

    def __init__(self):
        # Total organizations: one algo provider + (num_clients - 1) data orgs.
        self.num_clients = 3
        self.clients = {}
        self.algo_provider: Optional[Client] = None

        self.datasets: List[Dataset] = []
        self.test_dataset: Optional[Dataset] = None
        self.path = pathlib.Path(__file__).parent.resolve()

        # Substra asset keys, all indexed by organization id.
        self.dataset_keys = {}
        self.train_data_sample_keys = {}
        self.test_data_sample_keys = {}
        self.metric_key: Optional[str] = None

        # Mini-batch schedule shared by every local training step.
        batch_size = 32
        num_updates = 100
        self.index_generator = NpIndexGenerator(
            batch_size=batch_size,
            num_updates=num_updates,
        )

        self.algorithm: Optional[TorchFedAvgAlgo] = None
        self.strategy: Optional[Strategy] = None

        self.aggregation_node: Optional[AggregationNode] = None
        self.train_data_nodes = []
        self.test_data_nodes = []
        self.eval_strategy: Optional[EvaluationStrategy] = None

        self.NUM_ROUNDS = 3
        self.compute_plan: Optional[ComputePlan] = None
        self.experiment_folder = self.path / "experiment_summaries"

    def set_up_clients(self):
        """Create the algo-provider client and one client per data organization."""
        self.algo_provider = Client(backend_type=BackendType.LOCAL_SUBPROCESS)
        data_clients = [
            Client(backend_type=BackendType.LOCAL_SUBPROCESS)
            for _ in range(self.num_clients - 1)
        ]
        # Key each client by the organization id it reports.
        self.clients = {
            client.organization_info().organization_id: client
            for client in data_clients
        }

    def prepare_data(self):
        """Shard shuffled MNIST train data across clients and save splits to disk."""
        num_shards = self.num_clients - 1
        train_split = load_dataset("mnist", split="train").shuffle()
        self.datasets = [
            train_split.shard(num_shards=num_shards, index=i)
            for i in range(num_shards)
        ]
        self.test_dataset = load_dataset("mnist", split="test")

        data_path = self.path / "data"
        # Start from a clean directory on every run.
        if data_path.exists() and data_path.is_dir():
            shutil.rmtree(data_path)

        # One train shard per client; every client gets the full test split.
        for shard, client_id in zip(self.datasets, self.clients):
            shard.save_to_disk(data_path / client_id / "train")
            self.test_dataset.save_to_disk(data_path / client_id / "test")

    def register_data(self):
        """Register each organization's dataset and its train/test data samples."""
        provider_id = self.algo_provider.organization_info().organization_id
        for client_id, client in self.clients.items():
            # Only the algo provider may access this private dataset.
            permissions = Permissions(public=False, authorized_ids=[provider_id])
            spec = DatasetSpec(
                name="MNIST",
                type="npy",
                data_opener=self.path / pathlib.Path("dataset_assets/opener.py"),
                description=self.path / pathlib.Path("dataset_assets/description.md"),
                permissions=permissions,
                logs_permission=permissions,
            )
            dataset_key = client.add_dataset(spec)
            assert dataset_key, "Missing dataset key"
            self.dataset_keys[client_id] = dataset_key

            self.train_data_sample_keys[client_id] = client.add_data_sample(
                DataSampleSpec(
                    data_manager_keys=[dataset_key],
                    path=self.path / "data" / client_id / "train",
                )
            )
            self.test_data_sample_keys[client_id] = client.add_data_sample(
                DataSampleSpec(
                    data_manager_keys=[dataset_key],
                    path=self.path / "data" / client_id / "test",
                )
            )

    def register_metric(self):
        """Register an accuracy metric visible to the provider and all data orgs."""
        authorized = [
            self.algo_provider.organization_info().organization_id,
            *self.clients,
        ]
        permissions = Permissions(public=False, authorized_ids=authorized)
        metric_deps = Dependency(
            pypi_dependencies=["numpy==1.23.1", "scikit-learn==1.1.1"]
        )

        def accuracy(datasamples, predictions_path):
            # Compare ground-truth labels against argmax of the saved scores.
            y_true = datasamples["label"]
            y_pred = np.load(predictions_path)
            return accuracy_score(y_true, np.argmax(y_pred, axis=1))

        self.metric_key = add_metric(
            client=self.algo_provider,
            metric_function=accuracy,
            permissions=permissions,
            dependencies=metric_deps,
        )

    def set_aggregation(self):
        """Create the aggregation node and one train-data node per client org."""
        provider_id = self.algo_provider.organization_info().organization_id
        self.aggregation_node = AggregationNode(provider_id)
        self.train_data_nodes += [
            TrainDataNode(
                organization_id=org_id,
                data_manager_key=self.dataset_keys[org_id],
                data_sample_keys=[self.train_data_sample_keys[org_id]],
            )
            for org_id in self.clients
        ]

    def set_testing(self):
        """Create test-data nodes and evaluate at every round (rounds=1)."""
        for org_id in self.clients:
            node = TestDataNode(
                organization_id=org_id,
                data_manager_key=self.dataset_keys[org_id],
                test_data_sample_keys=[self.test_data_sample_keys[org_id]],
                metric_keys=[self.metric_key],
            )
            self.test_data_nodes.append(node)
        self.eval_strategy = EvaluationStrategy(
            test_data_nodes=self.test_data_nodes, rounds=1
        )

    def run_compute_plan(self):
        """Launch the federated experiment and keep the resulting compute plan."""
        algo_deps = Dependency(
            pypi_dependencies=["numpy==1.23.1", "torch==1.11.0"]
        )
        self.compute_plan = execute_experiment(
            client=self.algo_provider,
            algo=self.algorithm,
            strategy=self.strategy,
            train_data_nodes=self.train_data_nodes,
            evaluation_strategy=self.eval_strategy,
            aggregation_node=self.aggregation_node,
            num_rounds=self.NUM_ROUNDS,
            experiment_folder=self.experiment_folder,
            dependencies=algo_deps,
        )
def algo_generator(model, criterion, optimizer, index_generator, dataset, seed):
    """Build a ``TorchFedAvgAlgo`` subclass pre-bound to the given components.

    The returned class takes no constructor arguments: the model, criterion,
    optimizer, index generator, dataset and seed are captured by closure and
    forwarded to the base-class initializer.
    """

    class MyAlgo(TorchFedAvgAlgo):
        def __init__(self):
            # Forward the captured training components to TorchFedAvgAlgo.
            super().__init__(
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                index_generator=index_generator,
                dataset=dataset,
                seed=seed,
            )

    return MyAlgo