Spaces:
Runtime error
Runtime error
| import asyncio | |
| import websockets | |
| import logging | |
| import json | |
| from convaimodel_extended import ConvAIModelExtended | |
# Install a default stderr handler on the root logger.
logging.basicConfig()

# Shared counter adjusted by the "plus"/"minus" actions and broadcast to clients.
STATE = {"value": 0}
# Websocket connections currently being served by `chatbot`.
USERS = set()
# Extra options forwarded to the model constructor as ``args`` (currently empty).
train_args = dict()
# Single model instance, created at import time and shared by all connections.
model = ConvAIModelExtended(
    "gpt2",
    "cahya/gpt2-small-indonesian-personachat",
    args=train_args,
    use_cuda=False,
)
def connection_event():
    """Return the JSON ``connection`` event announcing a successful connect."""
    payload = dict(type="connection", value=True)
    return json.dumps(payload)
def state_event():
    """Return the shared ``STATE`` mapping as a JSON ``state`` event."""
    payload = {"type": "state"}
    # STATE's keys follow "type", overriding it only if STATE defines one.
    payload.update(STATE)
    return json.dumps(payload)
def dialog_event(message):
    """Wrap *message* in a JSON event of type ``dialog``."""
    payload = dict(type="dialog", message=message)
    return json.dumps(payload)
def personality_event(message):
    """Wrap *message* in a JSON event of type ``personality``."""
    payload = dict(type="personality", message=message)
    return json.dumps(payload)
def persona_list_event(message):
    """Wrap *message* (any JSON-serialisable value) in a ``persona_list`` event."""
    payload = dict(type="persona_list", message=message)
    return json.dumps(payload)
def personality_reply_event(message):
    """Wrap *message* in a JSON event of type ``personality_reply``."""
    payload = dict(type="personality_reply", message=message)
    return json.dumps(payload)
def persona_greeting_event(message):
    """Wrap *message* in a JSON event of type ``persona_greeting``."""
    payload = dict(type="persona_greeting", message=message)
    return json.dumps(payload)
def talk_event(message):
    """Wrap *message* (a model reply) in a JSON event of type ``talk``."""
    payload = dict(type="talk", message=message)
    return json.dumps(payload)
def users_event():
    """Return a JSON ``users`` event carrying the number of entries in ``USERS``."""
    connected = len(USERS)
    return json.dumps(dict(type="users", count=connected))
def _generation_options(data):
    """Return the sampling keyword arguments for ``model.talk`` taken from *data*.

    Every key is optional and falls back to the default shown here. Supplied
    values are coerced with ``bool``/``int``/``float``, so a value that cannot
    be converted raises ``ValueError`` or ``TypeError``. Note that ``bool`` of
    any non-empty string is ``True``: ``do_sample`` should be sent as a JSON
    boolean, not as the string ``"false"``.
    """
    return {
        "do_sample": bool(data.get("do_sample", True)),
        "min_length": int(data.get("min_length", 1)),
        "max_length": int(data.get("max_length", 20)),
        "temperature": float(data.get("temperature", 0.7)),
        "top_k": int(data.get("top_k", 0)),
        "top_p": float(data.get("top_p", 0.9)),
    }


async def chatbot(websocket, path=None):
    """Serve one websocket client for the lifetime of its connection.

    Registers the client in ``USERS``, sends it a connection event and the
    current counter state, then handles JSON requests of the form
    ``{"action": ..., ...}`` until the client disconnects. On exit the
    client's dialog (if one was created) is deleted, the client is
    unregistered, and the remaining clients receive the new user count.

    Args:
        websocket: The client connection.
        path: Request path. Unused; it defaults to ``None`` so the handler
            works whether websockets calls it with ``(connection, path)`` or
            with the connection only.
    """
    # 0 means "no dialog": a 0 from model.new_dialog() is treated as failure.
    dialog_id = 0
    USERS.add(websocket)
    try:
        await websocket.send(connection_event())
        websockets.broadcast(USERS, users_event())
        # Send the current counter state to the new client.
        await websocket.send(state_event())

        async for message in websocket:
            message = message.strip()
            if not message:
                continue
            try:
                data = json.loads(message)
            except json.JSONDecodeError as error:
                logging.warning("ignoring invalid JSON message: %s", error)
                continue

            # A malformed request (missing key, non-object payload,
            # unconvertible parameter) is logged and skipped instead of
            # tearing down the whole connection.
            try:
                action = data["action"]
                if action == "minus":
                    STATE["value"] -= 1
                    websockets.broadcast(USERS, state_event())
                elif action == "plus":
                    STATE["value"] += 1
                    websockets.broadcast(USERS, state_event())
                elif action == "get_users":
                    await websocket.send(users_event())
                elif action == "dialog":
                    # Create the dialog once; later requests resend its personas.
                    if dialog_id == 0:
                        dialog_id = model.new_dialog()
                    if dialog_id != 0:
                        await websocket.send(dialog_event("New dialog is created"))
                        persona_list = model.get_persona_list(dialog_id)
                        await websocket.send(persona_list_event(persona_list))
                    else:
                        await websocket.send(dialog_event("Dialog is not created"))
                elif action == "talk":
                    if dialog_id != 0:
                        # NOTE(review): model.talk is called synchronously, so the
                        # event loop (and every other client) waits while it runs.
                        reply = model.talk(
                            dialog_id,
                            persona_id=data["persona_id"],
                            utterance=data["utterance"],
                            **_generation_options(data),
                        )
                        await websocket.send(talk_event(reply))
                elif action == "personality":
                    if dialog_id != 0:
                        model.set_personality(
                            dialog_id,
                            persona_id=data["persona_id"],
                            personality=data["message"],
                        )
                        await websocket.send(
                            personality_reply_event("Personality has been updated")
                        )
                elif action == "persona_chosen":
                    if dialog_id != 0:
                        # Call through the shared instance like every other model call.
                        name = model.get_persona_name(dialog_id, data["persona_id"])
                        greeting = (
                            f"Hi, I am {name}. Nice to meet you. "
                            "Feel free to talk in English or Indonesian."
                        )
                        await websocket.send(persona_greeting_event(greeting))
                else:
                    logging.error("unsupported event: %s", data)
            except (KeyError, TypeError, ValueError):
                logging.exception("malformed request: %s", data)
    finally:
        # Unregister the client even if dialog cleanup fails, so a dead
        # socket never stays in USERS and receives later broadcasts.
        try:
            if dialog_id != 0:
                model.delete_dialog(dialog_id)
        finally:
            USERS.discard(websocket)
            websockets.broadcast(USERS, users_event())
async def main(host="0.0.0.0", port=8502):
    """Run the chatbot websocket server until the process is stopped.

    Args:
        host: Interface to bind; the default ``"0.0.0.0"`` listens on all
            IPv4 interfaces.
        port: TCP port to listen on.
    """
    async with websockets.serve(chatbot, host, port):
        print("Websocket is running")
        # A bare Future never completes, so the server keeps running.
        await asyncio.Future()
if __name__ == "__main__":
    # Script entry point: start the event loop and serve until interrupted.
    asyncio.run(main())