Spaces:
Running
Running
| '''API for implementing LynxKite operations.''' | |
| from __future__ import annotations | |
| import dataclasses | |
| import enum | |
| import functools | |
| import inspect | |
| import networkx as nx | |
| import pandas as pd | |
| import pydantic | |
| import typing | |
| from typing_extensions import Annotated | |
# Registry of every defined operation, keyed by operation name
# (filled in by the `op` decorator and the register_* helpers).
ALL_OPS = dict()
# Alias for the builtin, for use inside functions that have a parameter named "type".
typeof = type
def type_to_json(t):
    '''Converts a parameter/input/output type into a JSON-serializable dict.

    - Enum classes become {'enum': [member names]}.
    - ('collapsed', type) tuples become {'collapsed': str(type)}.
    - Anything else becomes {'type': str(t)}.
    '''
    if isinstance(t, type) and issubclass(t, enum.Enum):
        return {'enum': list(t.__members__.keys())}
    # Check the length first: an empty or 1-element tuple would otherwise
    # raise IndexError on t[0] / t[1].
    if isinstance(t, tuple) and len(t) >= 2 and t[0] == 'collapsed':
        return {'collapsed': str(t[1])}
    return {'type': str(t)}
# Annotation for fields that hold a type (or a type-like marker such as an Enum
# class or a ('collapsed', type) tuple). Pydantic accepts any value for such a
# field and serializes it by calling type_to_json.
_type_serializer = pydantic.PlainSerializer(type_to_json, return_type=dict)
Type = Annotated[typing.Any, _type_serializer]
class BaseConfig(pydantic.BaseModel):
    '''Common pydantic base for the operation metadata models.

    Lets fields hold arbitrary Python objects (such as types and functions)
    that pydantic has no built-in schema for.
    '''
    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
class Parameter(BaseConfig):
    '''Defines a parameter for an operation.'''

    name: str
    default: typing.Any
    # A Python type, an Enum class (see `options`) or a ('collapsed', type)
    # tuple (see `collapsed`); serialized through `type_to_json`.
    type: Type = None

    @staticmethod
    def options(name, options, default=None):
        '''Creates a parameter whose value is one of a fixed list of options.

        A new Enum class named `OptionsFor_<name>` is built from `options`. The
        default is the member named `default`, or the first option if not given.
        '''
        e = enum.Enum(f'OptionsFor_{name}', options)
        return Parameter.basic(name, e[default or options[0]], e)

    @staticmethod
    def collapsed(name, default, type=None):
        '''Creates a parameter whose type is tagged as ('collapsed', type).

        `type` defaults to the type of `default`.
        '''
        return Parameter.basic(name, default, ('collapsed', type or typeof(default)))

    @staticmethod
    def basic(name, default=None, type=None):
        '''Creates a parameter, inferring its type from the default when not given.

        `inspect._empty` (what `inspect.signature` reports for a missing default
        or annotation) is treated the same as "not given".
        '''
        if default is inspect._empty:
            default = None
        if type is None or type is inspect._empty:
            # Infer from any non-None default, including falsy ones such as 0,
            # 0.0, False or ''. A truthiness test would leave those untyped, and
            # then Op.__call__ would not convert their values.
            type = typeof(default) if default is not None else None
        return Parameter(name=name, default=default, type=type)
class Op(BaseConfig):
    '''A LynxKite operation: the function implementing it plus its UI metadata.

    Calling an instance runs `func` (see `__call__`).
    '''

    # The implementing function; excluded from serialization.
    func: typing.Callable = pydantic.Field(exclude=True)
    name: str
    params: dict[str, Parameter]
    inputs: dict[str, Type]  # name -> type
    outputs: dict[str, Type]  # name -> type
    type: str  # The UI to use for this operation.
    # If set, these nodes can be placed inside the operation's node.
    sub_nodes: typing.Optional[list[Op]] = None

    def __call__(self, *inputs, **params):
        '''Runs `func` on the given inputs and parameters and returns its result.

        Parameters declared with type `int` or `float` are first converted to that
        type (e.g. a numeric string becomes a number). Parameter names that are not
        declared are passed through untouched. A positional input is converted
        between `Bundle` and `nx.Graph` when its declared input type is the other one.
        '''
        # Convert parameters. Only values are replaced, so mutating `params`
        # while iterating over it is safe.
        for p, value in params.items():
            declared = self.params.get(p)
            if declared is not None and declared.type in (int, float):
                params[p] = declared.type(value)
        # Convert inputs. zip() pairs inputs with the declared input types in
        # declaration order; surplus inputs are passed through unchanged.
        inputs = list(inputs)
        for i, (x, t) in enumerate(zip(inputs, self.inputs.values())):
            if t == nx.Graph and isinstance(x, Bundle):
                inputs[i] = x.to_nx()
            elif t == Bundle and isinstance(x, nx.Graph):
                inputs[i] = Bundle.from_nx(x)
        return self.func(*inputs, **params)
@dataclasses.dataclass
class RelationDefinition:
    '''Defines a set of edges.

    Must be a dataclass: it is constructed with keyword arguments
    (e.g. in Bundle.from_nx), which a plain class does not accept.
    '''

    df: str  # The DataFrame that contains the edges.
    source_column: str  # The column in the edge DataFrame that contains the source node ID.
    target_column: str  # The column in the edge DataFrame that contains the target node ID.
    source_table: str  # The DataFrame that contains the source nodes.
    target_table: str  # The DataFrame that contains the target nodes.
    source_key: str  # The column in the source table that contains the node ID.
    target_key: str  # The column in the target table that contains the node ID.
@dataclasses.dataclass
class Bundle:
    '''A collection of DataFrames and other data.

    Can efficiently represent a knowledge graph (homogeneous or heterogeneous) or tabular data.
    It can also carry other data, such as a trained model.
    '''

    dfs: dict[str, pd.DataFrame] = dataclasses.field(default_factory=dict)
    relations: list[RelationDefinition] = dataclasses.field(default_factory=list)
    other: typing.Optional[dict[str, typing.Any]] = None

    @classmethod
    def from_nx(cls, graph: nx.Graph):
        '''Converts a NetworkX graph into a Bundle.

        Produces an 'edges' DataFrame (from nx.to_pandas_edgelist: 'source',
        'target' plus edge-attribute columns) and a 'nodes' DataFrame (one row per
        node, one column per node attribute, plus an 'id' column), linked by a
        single RelationDefinition.
        '''
        edges = nx.to_pandas_edgelist(graph)
        node_data = dict(graph.nodes(data=True))
        nodes = pd.DataFrame(list(node_data.values()), index=list(node_data.keys()))
        nodes['id'] = nodes.index
        return cls(
            dfs={'edges': edges, 'nodes': nodes},
            relations=[
                RelationDefinition(
                    df='edges',
                    source_column='source',
                    target_column='target',
                    source_table='nodes',
                    target_table='nodes',
                    source_key='id',
                    target_key='id',
                )
            ],
        )

    def to_nx(self):
        '''Builds an undirected nx.Graph from the 'edges' and 'nodes' DataFrames.

        Extra columns of 'edges' become edge attributes, and the non-'id' columns
        of 'nodes' become node attributes.
        '''
        # edge_attr=True keeps the edge-attribute columns written by from_nx.
        graph = nx.from_pandas_edgelist(self.dfs['edges'], edge_attr=True)
        node_attrs = self.dfs['nodes'].set_index('id').to_dict('index')
        # from_pandas_edgelist only creates nodes that appear in some edge, so add
        # every node (with its attributes) explicitly; this keeps isolated nodes.
        graph.add_nodes_from(node_attrs.items())
        return graph
def nx_node_attribute_func(name):
    '''Decorator for wrapping a function that adds a NetworkX node attribute.

    The decorated function is called with a copy of the input graph plus the
    keyword arguments. Its return value is stored on that copy as the node
    attribute `name` (via nx.set_node_attributes), and the copy is returned.

    functools.wraps keeps the wrapped function's name and docstring and sets
    `__wrapped__`. inspect.signature follows `__wrapped__`, so an `op`
    decorator applied on top sees the function's real parameters rather than
    (graph, **kwargs).
    '''
    def decorator(func):
        @functools.wraps(func)
        def wrapper(graph: nx.Graph, **kwargs):
            # Work on a copy so the caller's graph is not modified by the wrapper.
            graph = graph.copy()
            attr = func(graph, **kwargs)
            nx.set_node_attributes(graph, attr, name)
            return graph
        return wrapper
    return decorator
def op(name, *, view='basic', sub_nodes=None):
    '''Decorator that registers the decorated function as an operation in ALL_OPS.

    Keyword-only parameters of the function become the operation's parameters;
    all other parameters become its inputs. Passing `sub_nodes` turns the
    operation into a 'sub_flow'. The function itself is returned unchanged.
    '''
    def register(func):
        inputs, params = {}, {}
        for param_name, param in inspect.signature(func).parameters.items():
            if param.kind == param.KEYWORD_ONLY:
                params[param_name] = Parameter.basic(param_name, param.default, param.annotation)
            else:
                inputs[param_name] = param.annotation
        # Only the 'basic' view declares an output for now; other views have none.
        outputs = {'output': 'yes'} if view == 'basic' else {}
        registered = Op(
            func=func, name=name, params=params, inputs=inputs, outputs=outputs, type=view,
        )
        if sub_nodes is not None:
            registered.sub_nodes = sub_nodes
            registered.type = 'sub_flow'
        ALL_OPS[name] = registered
        return func
    return register
def no_op(*args, **kwargs):
    '''Returns the first positional argument unchanged, or an empty Bundle if none is given.

    Keyword arguments are accepted and ignored.
    '''
    return args[0] if args else Bundle()
def register_passive_op(name, inputs=None, outputs=None, params=None):
    '''Registers a passive operation, which has no associated code.

    The operation's function is `no_op`.

    Args:
        name: operation name; also its key in ALL_OPS.
        inputs: input name -> type; defaults to {'input': Bundle}.
        outputs: output name -> type; defaults to {'output': Bundle}.
        params: iterable of Parameter objects; defaults to none.

    Returns:
        The registered Op.
    '''
    # None defaults avoid shared mutable default arguments. Use `is None`, not
    # truthiness, so an explicitly passed {} (no inputs/outputs) is kept.
    if inputs is None:
        inputs = {'input': Bundle}
    if outputs is None:
        outputs = {'output': Bundle}
    # Pydantic models only accept keyword arguments; positional arguments raise TypeError.
    op = Op(
        func=no_op,
        name=name,
        params={p.name: p for p in (params or [])},
        inputs=inputs,
        outputs=outputs,
        type='basic',
    )
    ALL_OPS[name] = op
    return op
def register_area(name, params=None):
    '''Registers a node that represents an area.

    It can contain other nodes, but does not restrict movement in any way. The
    area has no inputs or outputs. Returns the registered Op.
    '''
    op = register_passive_op(
        name, params=params if params is not None else [], inputs={}, outputs={},
    )
    op.type = 'area'
    return op