Spaces: Running on CPU Upgrade
| from threading import Thread | |
| import falcon | |
| from falcon.http_status import HTTPStatus | |
| import json | |
| import requests | |
| import time | |
| from Model import generate_completion | |
| import sys | |
class AutoComplete(object):
    """Falcon resource serving text-completion requests.

    POST body is JSON with one required field, ``context`` (the prompt),
    plus optional sampling parameters. Responds with
    ``{"sentences": [...], "time": <seconds elapsed>}``.
    """

    def on_post(self, req, resp, single_endpoint=True, x=None, y=None):
        # bounded_stream respects Content-Length, so read() is safe here.
        json_data = json.loads(req.bounded_stream.read())
        start = time.time()

        # "context" is the only required field; reject the request without it.
        try:
            context = json_data["context"].rstrip()
        except KeyError:
            resp.body = "The context field is required"
            resp.status = falcon.HTTP_422
            return

        # Optional generation parameters with their defaults. dict.get()
        # replaces the original's fifteen try/except KeyError blocks and is
        # behaviorally identical for a plain JSON dict.
        n_samples = json_data.get('samples', 3)  # NOTE(review): parsed but never passed on — confirm intent
        length = json_data.get('gen_length', 20)
        max_time = json_data.get('max_time', -1)
        model_name = json_data.get('model_size', "small")
        temperature = json_data.get('temperature', 0.7)
        max_tokens = json_data.get('max_tokens', 256)
        top_p = json_data.get('top_p', 0.95)
        top_k = json_data.get('top_k', 40)
        # CTRL-specific parameter
        repetition_penalty = json_data.get('repetition_penalty', 0.02)
        # PPLM-specific parameters
        stepsize = json_data.get('step_size', 0.02)
        gm_scale = json_data.get('gm_scale')
        kl_scale = json_data.get('kl_scale')
        num_iterations = json_data.get('num_iterations')
        use_sampling = json_data.get('use_sampling')
        bag_of_words_or_discrim = json_data.get('bow_or_discrim', "kitchen")

        print(json_data)
        sentences = generate_completion(
            context,
            length=length,
            max_time=max_time,
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            top_k=top_k,
            # CTRL
            repetition_penalty=repetition_penalty,
            # PPLM
            stepsize=stepsize,
            bag_of_words_or_discrim=bag_of_words_or_discrim,
            gm_scale=gm_scale,
            kl_scale=kl_scale,
            num_iterations=num_iterations,
            use_sampling=use_sampling,
        )
        resp.body = json.dumps({"sentences": sentences, 'time': time.time() - start})
        resp.status = falcon.HTTP_200
        sys.stdout.flush()
class Request(Thread):
    """Worker thread that POSTs JSON to an endpoint.

    The HTTP response body is retrieved by calling ``join()``.
    """

    def __init__(self, end_point, data):
        Thread.__init__(self)
        self.end_point = end_point  # target URL for the POST
        self.data = data            # JSON-serializable request payload
        self.ret = None             # requests.Response, set by run()

    def run(self):
        print("Requesting with url", self.end_point)
        self.ret = requests.post(url=self.end_point, json=self.data)

    def join(self, timeout=None):
        """Wait for the request to finish and return the response text.

        Forwards ``timeout`` to ``Thread.join`` (the original override
        silently dropped it). Returns None if the request has not
        completed (timeout expired) or raised before assigning ``ret``.
        """
        Thread.join(self, timeout)
        return self.ret.text if self.ret is not None else None
class HandleCORS(object):
    """Falcon middleware applying a fully permissive CORS policy."""

    def process_request(self, req, resp):
        # Allow any origin, any method, any header.
        for header_name in ('Access-Control-Allow-Origin',
                            'Access-Control-Allow-Methods',
                            'Access-Control-Allow-Headers'):
            resp.set_header(header_name, '*')
        # Short-circuit CORS preflight: answer OPTIONS immediately with 200.
        if req.method == 'OPTIONS':
            raise HTTPStatus(falcon.HTTP_200, body='\n')
# Wire up the WSGI app: one AutoComplete resource behind permissive CORS
# middleware, reachable at three route shapes (the {x}/{y} path params are
# accepted by on_post but otherwise unused).
autocomplete = AutoComplete()
app = falcon.API(middleware=[HandleCORS()])
app.add_route('/autocomplete', autocomplete)
app.add_route('/autocomplete/{x}', autocomplete)
app.add_route('/autocomplete/{x}/{y}', autocomplete)
# WSGI servers (gunicorn/uwsgi) conventionally look for "application".
application = app