Spaces:
Running
Running
| """Boxes for defining PyTorch models.""" | |
import copy
import dataclasses
import graphlib
from dataclasses import dataclass

import pydantic
import torch
import torch_geometric as pyg

from lynxkite.core import ops, workspace
from lynxkite.core.ops import Parameter as P
from . import core
| ENV = "PyTorch model" | |
def reg(name, inputs=None, outputs=None, params=None):
    """Registers a passive op for the PyTorch model environment.

    Args:
        name: Title of the box.
        inputs: Names of the input connectors, placed at the bottom of the box.
        outputs: Names of the output connectors, placed at the top of the box.
            Defaults to the same names as the inputs.
        params: Parameters of the box.

    Returns:
        The result of ops.register_passive_op().
    """
    # None sentinels instead of mutable [] defaults (shared-list pitfall).
    inputs = [] if inputs is None else inputs
    params = [] if params is None else params
    if outputs is None:
        outputs = inputs
    return ops.register_passive_op(
        ENV,
        name,
        # Inputs at the bottom, outputs at the top: data flows upward in
        # these model diagrams.
        inputs=[ops.Input(name=i, position="bottom", type="tensor") for i in inputs],
        outputs=[ops.Output(name=o, position="top", type="tensor") for o in outputs],
        params=params,
    )
| reg("Input: embedding", outputs=["x"]) | |
| reg("Input: graph edges", outputs=["edges"]) | |
| reg("Input: label", outputs=["y"]) | |
| reg("Input: positive sample", outputs=["x_pos"]) | |
| reg("Input: negative sample", outputs=["x_neg"]) | |
| reg("Input: sequential", outputs=["y"]) | |
| reg("Input: zeros", outputs=["x"]) | |
| reg("LSTM", inputs=["x", "h"], outputs=["x", "h"]) | |
reg(
    "Neural ODE",
    inputs=["x"],
    params=[
        P.basic("relative_tolerance"),
        P.basic("absolute_tolerance"),
        P.options(
            "method",
            # Fixed- and adaptive-step solver names — presumably consumed by a
            # Neural ODE implementation elsewhere (torchdiffeq-style identifiers).
            "dopri8 dopri5 bosh3 fehlberg2 adaptive_heun euler midpoint rk4"
            " explicit_adams implicit_adams".split(),
        ),
    ],
)
| reg("Attention", inputs=["q", "k", "v"], outputs=["x", "weights"]) | |
| reg("LayerNorm", inputs=["x"]) | |
| reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)]) | |
| reg("Linear", inputs=["x"], params=[P.basic("output_dim", "same")]) | |
| reg("Softmax", inputs=["x"]) | |
| reg( | |
| "Graph conv", | |
| inputs=["x", "edges"], | |
| outputs=["x"], | |
| params=[P.options("type", ["GCNConv", "GATConv", "GATv2Conv", "SAGEConv"])], | |
| ) | |
| reg( | |
| "Activation", | |
| inputs=["x"], | |
| params=[P.options("type", ["ReLU", "Leaky ReLU", "Tanh", "Mish"])], | |
| ) | |
| reg("Concatenate", inputs=["a", "b"], outputs=["x"]) | |
| reg("Add", inputs=["a", "b"], outputs=["x"]) | |
| reg("Subtract", inputs=["a", "b"], outputs=["x"]) | |
| reg("Multiply", inputs=["a", "b"], outputs=["x"]) | |
| reg("MSE loss", inputs=["x", "y"], outputs=["loss"]) | |
| reg("Triplet margin loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"]) | |
| reg("Cross-entropy loss", inputs=["x", "y"], outputs=["loss"]) | |
# The Optimizer box is the root of the training graph: build_model() requires
# exactly one and reads its "loss" input to assemble the loss function.
reg(
    "Optimizer",
    inputs=["loss"],
    outputs=[],
    params=[
        # NOTE(review): build_model resolves this name via
        # getattr(torch.optim, ...). "Lion", "Paged AdamW" and "Galore AdamW"
        # are not plain torch.optim classes — confirm how those are provided.
        P.options(
            "type",
            [
                "AdamW",
                "Adafactor",
                "Adagrad",
                "SGD",
                "Lion",
                "Paged AdamW",
                "Galore AdamW",
            ],
        ),
        P.basic("lr", 0.001),
    ],
)
# Container boxes. Unlike reg() boxes, their input is at the top and their
# output at the bottom — presumably they visually wrap a vertical chain of
# boxes (see the "Dissolve repeat boxes" TODO in build_model).
for container_title, container_params in [
    (
        "Repeat",
        [
            ops.Parameter.basic("times", 1, int),
            ops.Parameter.basic("same_weights", True, bool),
        ],
    ),
    ("Recurrent chain", []),
]:
    ops.register_passive_op(
        ENV,
        container_title,
        inputs=[ops.Input(name="input", position="top", type="tensor")],
        outputs=[ops.Output(name="output", position="bottom", type="tensor")],
        params=container_params,
    )
| def _to_id(s: str) -> str: | |
| """Replaces all non-alphanumeric characters with underscores.""" | |
| return "".join(c if c.isalnum() else "_" for c in s) | |
class ColumnSpec(pydantic.BaseModel):
    """Identifies one column of one DataFrame (used by to_tensors to look up
    b.dfs[df][column])."""

    df: str
    column: str
class ModelMapping(pydantic.BaseModel):
    """Maps tensor names (presumably model input/output names — confirm with
    callers) to DataFrame columns in a Bundle."""

    map: dict[str, ColumnSpec]
@dataclass
class ModelConfig:
    """A built model: the PyTorch modules plus the names that wire them together.

    Fixes: the class is constructed as ModelConfig(**cfg) in build_model, so it
    needs the @dataclass-generated __init__ (the decorator was missing).
    """

    model: torch.nn.Module
    # Input names, in the order the model takes them as positional arguments.
    model_inputs: list[str]
    # Names assigned to the model's outputs, in order.
    model_outputs: list[str]
    # Names of the values (model outputs and raw inputs) fed to the loss, in order.
    loss_inputs: list[str]
    loss: torch.nn.Module
    optimizer: torch.optim.Optimizer

    def _forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Runs the model on `inputs` and returns its outputs keyed by name."""
        model_inputs = [inputs[i] for i in self.model_inputs]
        output = self.model(*model_inputs)
        # A single-output model returns a bare tensor; normalize to a tuple so
        # it can be zipped with the output names.
        if not isinstance(output, tuple):
            output = (output,)
        return dict(zip(self.model_outputs, output))

    def inference(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Runs the model in eval mode and returns the named outputs."""
        # TODO: Do multiple batches.
        self.model.eval()
        return self._forward(inputs)

    def train(self, inputs: dict[str, torch.Tensor]) -> float:
        """Train the model for one epoch. Returns the loss."""
        # TODO: Do multiple batches.
        self.model.train()
        self.optimizer.zero_grad()
        values = self._forward(inputs)
        # The loss may consume raw inputs (e.g. labels) as well as model outputs.
        values.update(inputs)
        loss_inputs = [values[i] for i in self.loss_inputs]
        loss = self.loss(*loss_inputs)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def copy(self):
        """Returns a copy of the config with an independent copy of the model.

        The previous implementation called the nonexistent super().copy() and
        torch.nn.Module.copy(); dataclasses.replace + copy.deepcopy do the job.
        NOTE(review): the optimizer is shared and still references the original
        model's parameters — confirm whether callers train the copy.
        """
        c = dataclasses.replace(self)
        c.model = copy.deepcopy(self.model)
        return c

    def default_display(self):
        """Returns the data for the default model visualization."""
        return {
            "type": "model",
            "model": {
                "inputs": self.model_inputs,
                "outputs": self.model_outputs,
                "loss_inputs": self.loss_inputs,
            },
        }
def build_model(
    ws: workspace.Workspace, inputs: dict[str, torch.Tensor]
) -> ModelConfig:
    """Builds the model described in the workspace.

    `inputs` provides example tensors; they are used only to derive feature
    sizes (the last dimension of each tensor) and the model input list.
    Requires exactly one Optimizer box in the workspace.
    """
    catalog = ops.CATALOGS[ENV]
    # Index nodes by ID and find the single Optimizer box — it anchors the loss.
    optimizers = []
    nodes = {}
    for node in ws.nodes:
        nodes[node.id] = node
        if node.data.title == "Optimizer":
            optimizers.append(node.id)
    assert optimizers, "No optimizer found."
    assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
    [optimizer] = optimizers
    # dependencies: target node -> list of source nodes.
    # edges: (target node, target handle) -> [(source node, source handle)].
    dependencies = {n.id: [] for n in ws.nodes}
    edges = {}
    # TODO: Dissolve repeat boxes here.
    for e in ws.edges:
        dependencies[e.target].append(e.source)
        edges.setdefault((e.target, e.targetHandle), []).append(
            (e.source, e.sourceHandle)
        )
    # Feature size per value name, seeded from the example inputs.
    sizes = {}
    for k, i in inputs.items():
        sizes[k] = i.shape[-1]
    ts = graphlib.TopologicalSorter(dependencies)
    # Each layer is a (module_or_function, "in1, in2 -> out") pair in
    # pyg.nn.Sequential's string-mapping format.
    layers = []
    loss_layers = []
    in_loss = set()  # Nodes that are downstream of a loss box.
    cfg = {}
    loss_inputs = set()
    used_inputs = set()
    # Walk nodes in dependency order so input sizes are known when needed.
    for node_id in ts.static_order():
        node = nodes[node_id]
        t = node.data.title
        op = catalog[t]
        p = op.convert_params(node.data.params)
        # A node belongs to the loss computation if any dependency does.
        for b in dependencies[node_id]:
            if b in in_loss:
                in_loss.add(node_id)
        ls = loss_layers if node_id in in_loss else layers
        nid = _to_id(node_id)  # Node IDs become identifier-safe value names.
        # NOTE(review): only these box types are translated into layers; other
        # registered boxes are silently skipped here — confirm that's intended.
        match t:
            case "Linear":
                [(ib, ih)] = edges[node_id, "x"]
                i = _to_id(ib) + "_" + ih
                used_inputs.add(i)
                isize = sizes[i]
                osize = isize if p["output_dim"] == "same" else int(p["output_dim"])
                ls.append((torch.nn.Linear(isize, osize), f"{i} -> {nid}_x"))
                sizes[f"{nid}_x"] = osize
            case "Activation":
                [(ib, ih)] = edges[node_id, "x"]
                i = _to_id(ib) + "_" + ih
                used_inputs.add(i)
                # E.g. "Leaky ReLU" -> torch.nn.functional.leaky_relu.
                f = getattr(
                    torch.nn.functional, p["type"].name.lower().replace(" ", "_")
                )
                ls.append((f, f"{i} -> {nid}_x"))
                sizes[f"{nid}_x"] = sizes[i]
            case "MSE loss":
                [(xb, xh)] = edges[node_id, "x"]
                xi = _to_id(xb) + "_" + xh
                [(yb, yh)] = edges[node_id, "y"]
                yi = _to_id(yb) + "_" + yh
                loss_inputs.add(xi)
                loss_inputs.add(yi)
                in_loss.add(node_id)
                loss_layers.append(
                    (torch.nn.functional.mse_loss, f"{xi}, {yi} -> {nid}_loss")
                )
    cfg["model_inputs"] = list(used_inputs & inputs.keys())
    # Loss inputs that are not raw inputs must come from the model.
    cfg["model_outputs"] = list(loss_inputs - inputs.keys())
    cfg["loss_inputs"] = list(loss_inputs)
    # Make sure the trained output is output from the last model layer.
    outputs = ", ".join(cfg["model_outputs"])
    layers.append((torch.nn.Identity(), f"{outputs} -> {outputs}"))
    # Create model.
    cfg["model"] = pyg.nn.Sequential(", ".join(used_inputs & inputs.keys()), layers)
    # Make sure the loss is output from the last loss layer.
    [(lossb, lossh)] = edges[optimizer, "loss"]
    lossi = _to_id(lossb) + "_" + lossh
    loss_layers.append((torch.nn.Identity(), f"{lossi} -> loss"))
    # Create loss function.
    cfg["loss"] = pyg.nn.Sequential(", ".join(loss_inputs), loss_layers)
    assert not list(cfg["loss"].parameters()), (
        f"loss should have no parameters: {list(cfg['loss'].parameters())}"
    )
    # Create optimizer.
    op = catalog["Optimizer"]
    p = op.convert_params(nodes[optimizer].data.params)
    # The "type" option is resolved by name from torch.optim.
    o = getattr(torch.optim, p["type"].name)
    cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
    return ModelConfig(**cfg)
def to_tensors(b: core.Bundle, m: ModelMapping) -> dict[str, torch.Tensor]:
    """Extracts the mapped DataFrame columns from the bundle as float32 tensors."""
    return {
        tensor_name: torch.tensor(
            b.dfs[spec.df][spec.column].to_list(), dtype=torch.float32
        )
        for tensor_name, spec in m.map.items()
    }