Spaces:
Paused
Paused
| from __future__ import annotations | |
| import json | |
| import pathlib | |
| from typing import Any | |
| import anywidget | |
| import traitlets | |
| import altair as alt | |
| from altair import TopLevelSpec | |
| from altair.utils._vegafusion_data import ( | |
| compile_to_vegafusion_chart_state, | |
| using_vegafusion, | |
| ) | |
| from altair.utils.selection import IndexSelection, IntervalSelection, PointSelection | |
| _here = pathlib.Path(__file__).parent | |
class Params(traitlets.HasTraits):
    """Traitlet class storing a JupyterChart's params.

    One trait is added dynamically for every entry of ``trait_values``; the
    trait type is picked from the Python type of the entry's initial value.

    Parameters
    ----------
    trait_values : dict
        Mapping of param name to the param's initial value.
    """

    def __init__(self, trait_values):
        super().__init__()

        for key, value in trait_values.items():
            # ``bool`` is a subclass of ``int``, so it has to be tested before
            # the numeric branch; otherwise True/False would be handed to a
            # Float trait and stored as 1.0/0.0.
            if isinstance(value, bool):
                traitlet_type = traitlets.Bool()
            elif isinstance(value, (int, float)):
                # Ints and floats share a single Float trait.
                traitlet_type = traitlets.Float()
            elif isinstance(value, str):
                traitlet_type = traitlets.Unicode()
            elif isinstance(value, list):
                traitlet_type = traitlets.List()
            elif isinstance(value, dict):
                traitlet_type = traitlets.Dict()
            else:
                # Anything else (including None) is stored without validation.
                traitlet_type = traitlets.Any()

            # Add the new trait.
            self.add_traits(**{key: traitlet_type})

            # Set the trait's value.
            setattr(self, key, value)

    def __repr__(self):
        return f"Params({self.trait_values()})"
class Selections(traitlets.HasTraits):
    """Traitlet class storing a JupyterChart's selections.

    One ``Instance`` trait is added per entry of ``trait_values``. Each trait
    is guarded by an observer that reverts and rejects assignments, so values
    can only be changed through :meth:`_set_value`.

    Parameters
    ----------
    trait_values : dict
        Mapping of selection name to an ``IndexSelection``, ``PointSelection``
        or ``IntervalSelection`` instance.

    Raises
    ------
    ValueError
        If a value is not one of the supported selection types.
    """

    def __init__(self, trait_values):
        super().__init__()

        for key, value in trait_values.items():
            if isinstance(value, IndexSelection):
                traitlet_type = traitlets.Instance(IndexSelection)
            elif isinstance(value, PointSelection):
                traitlet_type = traitlets.Instance(PointSelection)
            elif isinstance(value, IntervalSelection):
                traitlet_type = traitlets.Instance(IntervalSelection)
            else:
                msg = f"Unexpected selection type: {type(value)}"
                raise ValueError(msg)

            # Add the new trait.
            self.add_traits(**{key: traitlet_type})

            # Set the trait's value.
            setattr(self, key, value)

            # Make read-only
            self.observe(self._make_read_only, names=key)

    def __repr__(self):
        return f"Selections({self.trait_values()})"

    def _make_read_only(self, change):
        """Work around to make traits read-only, but still allow us to change them internally."""
        if change["name"] in self.traits() and change["old"] != change["new"]:
            # Restore the previous value before reporting the error.
            self._set_value(change["name"], change["old"])
            msg = (
                "Selections may not be set from Python.\n"
                f"Attempted to set select: {change['name']}"
            )
            raise ValueError(msg)

    def _set_value(self, key, value):
        """Assign ``value`` to trait ``key`` while bypassing the read-only guard."""
        self.unobserve(self._make_read_only, names=key)
        try:
            setattr(self, key, value)
        finally:
            # Re-arm the guard even when the assignment fails validation, so a
            # rejected value cannot leave the trait writable.
            self.observe(self._make_read_only, names=key)
def load_js_src() -> str:
    """Return the text of ``js/index.js`` located next to this module.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default locale encoding.
    """
    return (_here / "js" / "index.js").read_text(encoding="utf-8")
class JupyterChart(anywidget.AnyWidget):
    """Jupyter Widget wrapping an Altair chart, exposing its params and selections."""

    _esm = load_js_src()
    _css = r"""
    .vega-embed {
        /* Make sure action menu isn't cut off */
        overflow: visible;
    }
    """

    # Public traitlets
    chart = traitlets.Instance(TopLevelSpec, allow_none=True)
    spec = traitlets.Dict(allow_none=True).tag(sync=True)
    debounce_wait = traitlets.Float(default_value=10).tag(sync=True)
    max_wait = traitlets.Bool(default_value=True).tag(sync=True)
    local_tz = traitlets.Unicode(default_value=None, allow_none=True).tag(sync=True)
    debug = traitlets.Bool(default_value=False)
    embed_options = traitlets.Dict(default_value=None, allow_none=True).tag(sync=True)

    # Internal selection traitlets
    _selection_types = traitlets.Dict()
    _vl_selections = traitlets.Dict().tag(sync=True)

    # Internal param traitlets
    _params = traitlets.Dict().tag(sync=True)

    # Internal comm traitlets for VegaFusion support
    _chart_state = traitlets.Any(allow_none=True)
    _js_watch_plan = traitlets.Any(allow_none=True).tag(sync=True)
    _js_to_py_updates = traitlets.Any(allow_none=True).tag(sync=True)
    _py_to_js_updates = traitlets.Any(allow_none=True).tag(sync=True)

    # Track whether charts are configured for offline use
    _is_offline = False

    @classmethod
    def enable_offline(cls, offline: bool = True):
        """
        Configure JupyterChart's offline behavior.

        Parameters
        ----------
        offline: bool
            If True, configure JupyterChart to operate in offline mode where JavaScript
            dependencies are loaded from vl-convert.
            If False, configure it to operate in online mode where JavaScript dependencies
            are loaded from CDN dynamically. This is the default behavior.
        """
        from altair.utils._importers import import_vl_convert, vl_version_for_vl_convert

        if offline:
            if cls._is_offline:
                # Already offline
                return

            vlc = import_vl_convert()

            src_lines = load_js_src().split("\n")

            # Remove leading lines with only whitespace, comments, or imports
            while src_lines and (
                not src_lines[0].strip()
                or src_lines[0].startswith(("import", "//"))
            ):
                src_lines.pop(0)

            src = "\n".join(src_lines)

            # vl-convert's javascript_bundle function creates a self-contained JavaScript bundle
            # for JavaScript snippets that import from a small set of dependencies that
            # vl-convert includes. To see the available imports and their imported names, run
            #       import vl_convert as vlc
            #       help(vlc.javascript_bundle)
            bundled_src = vlc.javascript_bundle(
                src, vl_version=vl_version_for_vl_convert()
            )
            cls._esm = bundled_src
            cls._is_offline = True
        else:
            cls._esm = load_js_src()
            cls._is_offline = False

    def __init__(
        self,
        chart: TopLevelSpec,
        debounce_wait: int = 10,
        max_wait: bool = True,
        debug: bool = False,
        embed_options: dict | None = None,
        **kwargs: Any,
    ):
        """
        Jupyter Widget for displaying and updating Altair Charts, and retrieving selection and parameter values.

        Parameters
        ----------
        chart: Chart
            Altair Chart instance
        debounce_wait: int
             Debouncing wait time in milliseconds. Updates will be sent from the client to the kernel
             after debounce_wait milliseconds of no chart interactions.
        max_wait: bool
             If True (default), updates will be sent from the client to the kernel every debounce_wait
             milliseconds even if there are ongoing chart interactions. If False, updates will not be
             sent until chart interactions have completed.
        debug: bool
             If True, debug messages will be printed
        embed_options: dict
             Options to pass to vega-embed.
             See https://github.com/vega/vega-embed?tab=readme-ov-file#options
        """
        # Start with empty containers; the ``chart`` observer replaces them
        # once the chart trait is assigned by the parent constructor.
        self.params = Params({})
        self.selections = Selections({})
        super().__init__(
            chart=chart,
            debounce_wait=debounce_wait,
            max_wait=max_wait,
            debug=debug,
            embed_options=embed_options,
            **kwargs,
        )

    @traitlets.observe("chart")
    def _on_change_chart(self, change):  # noqa: C901
        """Updates the JupyterChart's internal state when the wrapped Chart instance changes."""
        new_chart = change.new
        selection_types = {}
        initial_params = {}
        initial_vl_selections = {}
        empty_selections = {}

        if new_chart is None:
            # No chart: clear the spec and reset all synced state.
            with self.hold_sync():
                self.spec = None
                self._selection_types = selection_types
                self._vl_selections = initial_vl_selections
                self._params = initial_params
            return

        params = getattr(new_chart, "params", [])

        if params is not alt.Undefined:
            for param in params:
                if isinstance(param.name, alt.ParameterName):
                    clean_name = param.name.to_json().strip('"')
                else:
                    clean_name = param.name

                select = getattr(param, "select", alt.Undefined)

                if select is not alt.Undefined:
                    if not isinstance(select, dict):
                        select = select.to_dict()

                    select_type = select["type"]
                    if select_type == "point":
                        if not (
                            select.get("fields", None) or select.get("encodings", None)
                        ):
                            # Point selection with no associated fields or encodings specified.
                            # This is an index-based selection
                            selection_types[clean_name] = "index"
                            empty_selections[clean_name] = IndexSelection(
                                name=clean_name, value=[], store=[]
                            )
                        else:
                            selection_types[clean_name] = "point"
                            empty_selections[clean_name] = PointSelection(
                                name=clean_name, value=[], store=[]
                            )
                    elif select_type == "interval":
                        selection_types[clean_name] = "interval"
                        empty_selections[clean_name] = IntervalSelection(
                            name=clean_name, value={}, store=[]
                        )
                    else:
                        # ``select`` is a plain dict here, so report the
                        # already-extracted type rather than ``select.type``.
                        msg = f"Unexpected selection type {select_type}"
                        raise ValueError(msg)
                    initial_vl_selections[clean_name] = {"value": None, "store": []}
                else:
                    clean_value = (
                        param.value if param.value is not alt.Undefined else None
                    )
                    initial_params[clean_name] = clean_value

        # Handle the params generated by transforms
        for param_name in collect_transform_params(new_chart):
            initial_params[param_name] = None

        # Setup params
        self.params = Params(initial_params)

        def on_param_traitlet_changed(param_change):
            # Assign a fresh dict so the ``_params`` trait registers a change.
            new_params = dict(self._params)
            new_params[param_change["name"]] = param_change["new"]
            self._params = new_params

        self.params.observe(on_param_traitlet_changed)

        # Setup selections
        self.selections = Selections(empty_selections)

        # Update properties all together
        with self.hold_sync():
            if using_vegafusion():
                if self.local_tz is None:
                    # Defer VegaFusion initialisation until ``local_tz`` is set.
                    self.spec = None

                    def on_local_tz_change(change):
                        self._init_with_vegafusion(change["new"])

                    # NOTE(review): a new observer is registered every time this
                    # branch runs and none is removed here, so repeated chart
                    # changes while ``local_tz`` is None stack callbacks.
                    self.observe(on_local_tz_change, ["local_tz"])
                else:
                    self._init_with_vegafusion(self.local_tz)
            else:
                self.spec = new_chart.to_dict()
            self._selection_types = selection_types
            self._vl_selections = initial_vl_selections
            self._params = initial_params

    def _init_with_vegafusion(self, local_tz: str):
        """Compile the current chart with VegaFusion and wire up the update round-trip.

        Does nothing when no chart is set.

        Parameters
        ----------
        local_tz: str
            Time zone passed through to ``compile_to_vegafusion_chart_state``.
        """
        if self.chart is not None:
            vegalite_spec = self.chart.to_dict(context={"pre_transform": False})
            with self.hold_sync():
                self._chart_state = compile_to_vegafusion_chart_state(
                    vegalite_spec, local_tz
                )
                self._js_watch_plan = self._chart_state.get_watch_plan()[
                    "client_to_server"
                ]
                self.spec = self._chart_state.get_transformed_spec()

                # Callback to update chart state and send updates back to client
                def on_js_to_py_updates(change):
                    if self.debug:
                        updates_str = json.dumps(change["new"], indent=2)
                        print(
                            f"JavaScript to Python VegaFusion updates:\n {updates_str}"
                        )
                    updates = self._chart_state.update(change["new"])
                    if self.debug:
                        updates_str = json.dumps(updates, indent=2)
                        print(
                            f"Python to JavaScript VegaFusion updates:\n {updates_str}"
                        )
                    self._py_to_js_updates = updates

                # NOTE(review): registered on every call without removing the
                # callback from a previous call.
                self.observe(on_js_to_py_updates, ["_js_to_py_updates"])

    @traitlets.observe("_params")
    def _on_change_params(self, change):
        """Copy every entry of the internal ``_params`` dict onto the public ``params`` object."""
        for param_name, value in change.new.items():
            setattr(self.params, param_name, value)

    @traitlets.observe("_vl_selections")
    def _on_change_selections(self, change):
        """Updates the JupyterChart's public selections traitlet in response to changes that the JavaScript logic makes to the internal _vl_selections traitlet."""
        selection_classes = {
            "index": IndexSelection,
            "point": PointSelection,
            "interval": IntervalSelection,
        }
        for selection_name, selection_dict in change.new.items():
            value = selection_dict["value"]
            store = selection_dict["store"]
            selection_type = self._selection_types[selection_name]
            selection_cls = selection_classes.get(selection_type)
            if selection_cls is None:
                # Unrecognised selection types are ignored.
                continue
            # Go through ``_set_value`` because the selection traits reject
            # direct assignment.
            self.selections._set_value(
                selection_name,
                selection_cls.from_vega(selection_name, signal=value, store=store),
            )
def collect_transform_params(chart: TopLevelSpec) -> set[str]:
    """
    Collect the names of params that are defined by transforms.

    Sub-charts found under ``layer``, ``concat``, ``hconcat`` and ``vconcat``
    are searched recursively.

    Parameters
    ----------
    chart: Chart from which to extract transform params

    Returns
    -------
    set of param names
    """
    transform_params = set()

    # Handle recursive case
    for prop in ("layer", "concat", "hconcat", "vconcat"):
        children = getattr(chart, prop, [])
        if children is alt.Undefined:
            # Attribute exists but was never set: nothing to descend into.
            continue
        for child in children:
            transform_params.update(collect_transform_params(child))

    # Handle chart's own transforms
    transforms = getattr(chart, "transform", [])
    if transforms is alt.Undefined:
        transforms = []
    for tx in transforms:
        param = getattr(tx, "param", alt.Undefined)
        # Skip transforms without a param, or whose param was left unset.
        if param is not alt.Undefined:
            transform_params.add(param)

    return transform_params