Spaces:
Sleeping
Sleeping
| from datetime import datetime | |
| from flask import Flask | |
| from flask import make_response | |
| from flask import request | |
| from flask import send_from_directory, redirect | |
| from typing import Literal | |
| import json | |
| import logging | |
| import numpy as np | |
| import os | |
| import portpicker | |
| import requests | |
| import shutil | |
| import sys | |
| import threading | |
| import traceback | |
| import urllib.parse | |
| import zipfile | |
| _VISUAL_BLOCKS_BUNDLE_VERSION = "1716228179" | |
| # Disable logging from werkzeug. | |
| # | |
| # Without this, flask will show a warning about using dev server (which is OK | |
| # in our usecase). | |
| logging.getLogger("werkzeug").disabled = True | |
# Function registrations, keyed by each function's __name__. Populated by the
# @register_vb_fn decorator and by the matching arguments of Server().
GENERIC_FNS = {}
TEXT_TO_TEXT_FNS = {}
TEXT_TO_TENSORS_FNS = {}


def register_vb_fn(
    type: Literal["generic", "text_to_text", "text_to_tensors"] = "generic",
):
    """A function decorator to register a python function with Visual Blocks.

    Args:
      type:
        The type of function to register for. (The name shadows the builtin
        ``type``; it is kept because callers may pass it by keyword.)
        Currently, VB supports the following function types:
        generic:
          A function or iterable of functions, defined in the same Colab
          notebook, that Visual Blocks can call to implement a generic model
          runner block.
          A generic inference function must take a single argument, the input
          tensors as an iterable of numpy.ndarrays; run inference; and return
          the output tensors, also as an iterable of numpy.ndarrays.
        text_to_text:
          A function or iterable of functions, defined in the same Colab
          notebook, that Visual Blocks can call to implement a text-to-text
          model runner block.
          A text_to_text function must take a string and return a string.
        text_to_tensors:
          A function or iterable of functions, defined in the same Colab
          notebook, that Visual Blocks can call to implement a text-to-tensors
          model runner block.
          A text_to_tensors function must take a string and return the output
          tensors, as an iterable of numpy.ndarrays.

    Returns:
      A decorator that stores the decorated function in the registry for
      ``type``, keyed by its ``__name__`` (replacing any earlier function with
      the same name), and returns the function unchanged.

    Raises:
      ValueError: If ``type`` is not one of the supported function types.
    """
    registries = {
        "generic": GENERIC_FNS,
        "text_to_text": TEXT_TO_TEXT_FNS,
        "text_to_tensors": TEXT_TO_TENSORS_FNS,
    }
    if type not in registries:
        # Fail loudly: registering nothing would leave the function silently
        # missing from every registry the server lists.
        raise ValueError(
            "Unsupported Visual Blocks function type %r; expected one of: %s"
            % (type, ", ".join(registries))
        )
    registry = registries[type]

    def decorator_register_vb_fn(func):
        registry[func.__name__] = func
        return func

    return decorator_register_vb_fn
| def _json_to_ndarray(json_tensor): | |
| """Convert a JSON dictionary from the web app to an np.ndarray.""" | |
| array = np.array(json_tensor["tensorValues"]) | |
| array.shape = json_tensor["tensorShape"] | |
| return array | |
| def _ndarray_to_json(array): | |
| """Convert a np.ndarray to the JSON dictionary for the web app.""" | |
| values = array.ravel().tolist() | |
| shape = array.shape | |
| return { | |
| "tensorValues": values, | |
| "tensorShape": shape, | |
| } | |
def _make_json_response(obj):
    """Wrap *obj*, serialized with ``json.dumps``, in a Flask response.

    The body handed to ``make_response`` is a plain string, so the JSON
    Content-Type header is set explicitly afterwards.
    """
    payload = json.dumps(obj)
    response = make_response(payload)
    response.headers["Content-Type"] = "application/json"
    return response
| def _ensure_iterable(x): | |
| """Turn x into an iterable if not already iterable.""" | |
| if x is None: | |
| return () | |
| elif hasattr(x, "__iter__"): | |
| return x | |
| else: | |
| return (x,) | |
| def _add_to_registry(fns, registry): | |
| """Adds the functions to the given registry (dict).""" | |
| for fn in fns: | |
| registry[fn.__name__] = fn | |
| def _is_list_of_nd_array(obj): | |
| return isinstance(obj, list) and all(isinstance(elem, np.ndarray) for elem in obj) | |
def _prepare_web_app(tmp_dir, download_timeout=60):
    """Download and unpack the Visual Blocks web app into a fresh directory.

    Any existing ``{tmp_dir}/visual-blocks-colab`` directory is deleted first.

    Args:
      tmp_dir: Directory under which "visual-blocks-colab" is (re)created.
      download_timeout: Seconds to wait for the bundle server to respond.

    Returns:
      A ``(site_root_path, log_file_path)`` tuple: the directory that holds
      the unpacked static files, and an empty log file.

    Raises:
      requests.HTTPError: If the bundle download returns an error status.
      zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    base_path = os.path.join(tmp_dir, "visual-blocks-colab")
    if os.path.exists(base_path):
        shutil.rmtree(base_path)
    os.mkdir(base_path)
    log_file_path = os.path.join(base_path, "log")
    with open(log_file_path, "w"):
        pass  # Create an empty log file.

    # Download the zip file that bundles the visual blocks web app.
    url = (
        "https://storage.googleapis.com/tfweb/rapsai-colab-bundles/visual_blocks_%s.zip"
        % _VISUAL_BLOCKS_BUNDLE_VERSION
    )
    response = requests.get(url, timeout=download_timeout)
    # Surface HTTP failures directly instead of writing an error page to disk
    # and failing later with a confusing BadZipFile.
    response.raise_for_status()
    bundle_target_path = os.path.join(base_path, "visual_blocks.zip")
    with open(bundle_target_path, "wb") as zip_file:
        zip_file.write(response.content)

    # Unzip it. This will unzip all files to {base_path}/build.
    with zipfile.ZipFile(bundle_target_path, "r") as zip_ref:
        zip_ref.extractall(base_path)
    return os.path.join(base_path, "build"), log_file_path


def Server(
    host="0.0.0.0",
    port=7860,
    generic=None,
    text_to_text=None,
    text_to_tensors=None,
    height=900,
    tmp_dir="/tmp",
    read_saved_pipeline=True,
):
    """Creates a server that serves the Visual Blocks web app.

    Other than serving the web app, the server also listens to requests sent
    from the web app at various API endpoints. Once a request is received, it
    uses the data in the request body to call the corresponding function that
    users have registered with VB, either through the '@register_vb_fn'
    decorator or passed in when creating the server. Functions passed here are
    added to the module-level registries, which every server shares.

    Calling this function deletes and recreates
    ``{tmp_dir}/visual-blocks-colab``, downloads the web-app bundle into it,
    and unpacks it. The HTTP server only starts when ``run()`` is called on
    the returned object.

    Args:
      host:
        Interface address the Flask server binds to in ``run()``.
      port:
        Port the Flask server listens on in ``run()``.
      generic:
        A function or iterable of functions, defined in the same Colab notebook,
        that Visual Blocks can call to implement a generic model runner block.
        A generic inference function must take a single argument, the input
        tensors as an iterable of numpy.ndarrays; run inference; and return the
        output tensors, also as an iterable of numpy.ndarrays.
      text_to_text:
        A function or iterable of functions, defined in the same Colab notebook,
        that Visual Blocks can call to implement a text-to-text model runner
        block.
        A text_to_text function must take a string and return a string.
      text_to_tensors:
        A function or iterable of functions, defined in the same Colab notebook,
        that Visual Blocks can call to implement a text-to-tensors model runner
        block.
        A text_to_tensors function must take a string and return the output
        tensors, as an iterable of numpy.ndarrays.
      height:
        The height of the embedded iFrame.
        NOTE(review): not read anywhere in this function.
      tmp_dir:
        The tmp dir where the server stores the web app's static resources.
      read_saved_pipeline:
        Whether to read the saved pipeline in the notebook or not.
        NOTE(review): not read anywhere in this function.

    Returns:
      An object whose ``run()`` method prints the server address and runs the
      Flask app (blocking).

    Raises:
      requests.HTTPError: If downloading the web-app bundle fails.
      zipfile.BadZipFile: If the downloaded bundle is not a valid zip file.
    """
    _add_to_registry(_ensure_iterable(generic), GENERIC_FNS)
    _add_to_registry(_ensure_iterable(text_to_text), TEXT_TO_TEXT_FNS)
    _add_to_registry(_ensure_iterable(text_to_tensors), TEXT_TO_TENSORS_FNS)

    app = Flask(__name__)

    # Disable startup messages.
    cli = sys.modules["flask.cli"]
    cli.show_server_banner = lambda *x: None

    site_root_path, log_file_path = _prepare_web_app(tmp_dir)

    def log(msg):
        """Appends a timestamped message to the log file."""
        dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
        with open(log_file_path, "a") as log_file:
            log_file.write("{}: {}\n".format(dt_string, msg))

    def error_result(e):
        """Logs exception *e* and returns the JSON error payload for it."""
        msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
        log(msg)
        return {"error": msg}

    # NOTE(review): the URL rules below must match the endpoints requested by
    # the downloaded web-app bundle; confirm them against that bundle version.

    @app.route("/api/list_inference_functions")
    def list_inference_functions():
        """Lists the sorted names of registered functions, grouped by type.

        Types with no registered functions are omitted.
        """
        registries = (
            ("generic", GENERIC_FNS),
            ("text_to_text", TEXT_TO_TEXT_FNS),
            ("text_to_tensors", TEXT_TO_TENSORS_FNS),
        )
        result = {name: sorted(fns) for name, fns in registries if fns}
        return _make_json_response(result)

    # Note: using "/api/..." for POST requests is not allowed.
    @app.route("/apipost/inference", methods=["POST"])
    def inference_generic():
        """Handler for the generic api endpoint."""
        # Only Exception is caught, and the response is returned after the
        # try statement rather than from `finally`, so KeyboardInterrupt and
        # SystemExit still propagate.
        try:
            func_name = request.json["function"]
            inference_fn = GENERIC_FNS[func_name]
            input_tensors = [_json_to_ndarray(x) for x in request.json["tensors"]]
            output_tensors = inference_fn(input_tensors)
            if _is_list_of_nd_array(output_tensors):
                result = {"tensors": [_ndarray_to_json(x) for x in output_tensors]}
            else:
                result = {
                    "error": "The returned value from %s is not a list of ndarray"
                    % func_name
                }
        except Exception as e:
            result = error_result(e)
        return _make_json_response(result)

    # Note: using "/api/..." for POST requests is not allowed.
    @app.route("/apipost/inference_text_to_text", methods=["POST"])
    def inference_text_to_text():
        """Handler for the text_to_text api endpoint."""
        try:
            func_name = request.json["function"]
            inference_fn = TEXT_TO_TEXT_FNS[func_name]
            ret = inference_fn(request.json["text"])
            if isinstance(ret, str):
                result = {"text": ret}
            else:
                result = {
                    "error": "The returned value from %s is not a string" % func_name
                }
        except Exception as e:
            result = error_result(e)
        return _make_json_response(result)

    # Note: using "/api/..." for POST requests is not allowed.
    @app.route("/apipost/inference_text_to_tensors", methods=["POST"])
    def inference_text_to_tensors():
        """Handler for the text_to_tensors api endpoint."""
        try:
            func_name = request.json["function"]
            inference_fn = TEXT_TO_TENSORS_FNS[func_name]
            output_tensors = inference_fn(request.json["text"])
            if _is_list_of_nd_array(output_tensors):
                result = {"tensors": [_ndarray_to_json(x) for x in output_tensors]}
            else:
                result = {
                    "error": "The returned value from %s is not a list of ndarray"
                    % func_name
                }
        except Exception as e:
            result = error_result(e)
        return _make_json_response(result)

    @app.route("/")
    def redirect_to_edit_new():
        """Redirect the root URL to the web app's #/edit/new/ route.

        The redirect goes to index.html rather than "/": browsers never send
        the "#..." fragment to the server, so redirecting "/" to "/#/edit/new/"
        would request "/" again and loop.
        """
        return redirect("/index.html#/edit/new/")

    @app.route("/<path:path>")
    def get_static(path):
        """Handler for serving static resources from the unpacked bundle."""
        if path == "":
            path = "index.html"
        # send_from_directory refuses paths that escape site_root_path.
        return send_from_directory(site_root_path, path)

    class _Server:
        """Thin wrapper exposing a blocking run() method for the Flask app."""

        def run(self):
            print("Visual Blocks server started at http://%s:%s" % (host, port))
            app.run(host=host, port=port)

    return _Server()