Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| import ipyleaflet | |
| import openai | |
| import solara | |
# Initial map view (centre stored as (lat, lon) at the origin, zoomed out)
# and an empty chat history.
center_default, zoom_default, messages_default = (0, 0), 2, []

# Solara reactive state shared by the components in this module.
messages, zoom_level, center, markers = (
    solara.reactive(messages_default),
    solara.reactive(zoom_default),
    solara.reactive(center_default),
    solara.reactive([]),
)

# Tile URL for the OpenStreetMap Mapnik basemap used by the map's TileLayer.
url = ipyleaflet.basemaps.OpenStreetMap.Mapnik.build_url()

# OpenAI setup: API key taken from the environment (None when unset) and the
# model name passed to every chat-completion request.
openai.api_key = os.environ.get("OPENAI_API_KEY")
model = "gpt-4-1106-preview"
def _tool(name, description, properties, required):
    """Wrap a JSON-Schema parameter spec in the OpenAI ``tools`` envelope.

    Produces ``{"type": "function", "function": {...}}`` whose ``parameters``
    entry is an object schema with the given ``properties`` and ``required``
    list.
    """
    parameters = {"type": "object", "properties": properties, "required": required}
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": parameters,
        },
    }


# Tool schemas advertised to the model. Each "name" must match a key of the
# `functions` dispatch table, since ai_call() looks handlers up by that name.
function_descriptions = [
    _tool(
        "update_map",
        "Update map to center on a particular location",
        {
            "longitude": {
                "type": "number",
                "description": "Longitude of the location to center the map on",
            },
            "latitude": {
                "type": "number",
                "description": "Latitude of the location to center the map on",
            },
            "zoom": {"type": "integer", "description": "Zoom level of the map"},
        },
        ["longitude", "latitude", "zoom"],
    ),
    _tool(
        "add_marker",
        "Add marker to the map",
        {
            "longitude": {
                "type": "number",
                "description": "Longitude of the location to the marker",
            },
            "latitude": {
                "type": "number",
                "description": "Latitude of the location to the marker",
            },
            "label": {
                "type": "string",
                "description": "Text to display on the marker",
            },
        },
        ["longitude", "latitude", "label"],
    ),
]
def update_map(longitude, latitude, zoom):
    """Tool handler for ``update_map``: recenter the map and set its zoom.

    The centre is stored as a ``(latitude, longitude)`` pair in the ``center``
    reactive; ``zoom`` goes straight into ``zoom_level``.

    Returns:
        A short confirmation string, used as the tool message content.
    """
    print("update_map", longitude, latitude, zoom)
    new_center = (latitude, longitude)
    center.set(new_center)
    zoom_level.set(zoom)
    return "Map updated"
def add_marker(longitude, latitude, label):
    """Tool handler for ``add_marker``: append a labelled marker.

    The marker is recorded as ``{"location": (lat, lon), "label": label}``.
    ``markers`` is replaced with a new list rather than mutated in place.

    Returns:
        A short confirmation string, used as the tool message content.
    """
    new_marker = {"location": (latitude, longitude), "label": label}
    markers.set([*markers.value, new_marker])
    return "Marker added"
# Dispatch table: tool name (as declared in `function_descriptions`) -> the
# Python handler that ai_call() invokes with the decoded arguments.
functions = dict(
    update_map=update_map,
    add_marker=add_marker,
)
def ai_call(tool_call):
    """Execute one model-requested tool call and build the reply message.

    Args:
        tool_call: One entry of an assistant message's ``tool_calls``, indexed
            as a mapping with ``"id"`` and ``"function"`` ->
            ``{"name": ..., "arguments": <JSON-encoded object>}``.

    Returns:
        A ``{"role": "tool", "tool_call_id", "name", "content"}`` message
        answering ``tool_call["id"]``. If the tool name is unknown, or the
        arguments are not valid JSON / do not fit the handler's signature,
        the error text is returned as ``content`` instead of raising, so the
        model learns what went wrong rather than the reply being dropped.
    """
    function = tool_call["function"]
    name = function["name"]
    handler = functions.get(name)
    if handler is None:
        content = f"Error: unknown function {name!r}"
    else:
        try:
            # An empty arguments string is treated as "no arguments".
            arguments = json.loads(function["arguments"] or "{}")
            content = str(handler(**arguments))
        except (json.JSONDecodeError, TypeError) as exc:
            # Malformed JSON, a non-object payload, or missing/unexpected
            # keyword arguments for the handler.
            content = f"Error calling {name}: {exc}"
    return {
        "role": "tool",
        "tool_call_id": tool_call["id"],
        "name": name,
        "content": content,
    }
@solara.component
def Map():
    """Render the Leaflet map from the shared reactive state.

    Reads ``zoom_level``, ``center`` and ``markers``. Each marker dict holds a
    ``"location"`` (lat, lon) pair and a ``"label"``, which is shown as the
    marker's hover title.
    """
    print("Map", zoom_level.value, center.value, markers.value)
    marker_layers = [
        ipyleaflet.Marker.element(
            location=marker["location"],
            title=marker.get("label", ""),
            draggable=False,
        )
        for marker in markers.value
    ]
    ipyleaflet.Map.element(  # type: ignore
        zoom=zoom_level.value,
        center=center.value,
        # NOTE(review): on_zoom / on_center are intentionally left unwired, so
        # user pans/zooms are not written back into the reactive state.
        scroll_wheel_zoom=True,
        layers=[ipyleaflet.TileLayer.element(url=url), *marker_layers],
    )
def _render_message(message):
    """Render one chat message into the enclosing Solara container.

    User messages become plain text. Assistant messages become Markdown, or a
    placeholder when they only carry tool calls. Tool results are hidden.
    Anything unrecognised is dumped with ``repr``.
    """
    role = message["role"]
    if role == "user":
        solara.Text(message["content"], classes=["chat-message", "user-message"])
    elif role == "assistant":
        if message.get("content"):
            solara.Markdown(message["content"])
        elif message.get("tool_calls"):
            solara.Markdown("*Calling map functions*")
        else:
            solara.Preformatted(
                repr(message),
                classes=["chat-message", "assistant-message"],
            )
    elif role == "tool":
        pass  # tool results are for the model only; no need to display
    else:
        solara.Preformatted(
            repr(message), classes=["chat-message", "assistant-message"]
        )


@solara.component
def ChatInterface():
    """Chat panel: message history, prompt input and the model/tool loop.

    Submitting a prompt appends a user message to ``messages``. Because
    ``use_thread`` depends on ``messages.value``, ``ask`` then runs in a
    background thread. It calls the model whenever the last message is from
    the user or a tool, executes any requested tool calls, and appends the
    reply plus the tool results. That change re-triggers ``ask`` until the
    model answers without requesting tools.
    """
    prompt = solara.use_reactive("")

    def add_message(value: str):
        # Ignore empty or whitespace-only submissions.
        if not value.strip():
            return
        messages.set(messages.value + [{"role": "user", "content": value}])
        prompt.set("")

    def handle_message(message):
        """Run the tool calls in an assistant message; return the tool replies."""
        print("handle", message)
        if message["role"] != "assistant":
            return []
        # `tool_calls` may be missing or None when the model answered in text.
        tool_calls = message.get("tool_calls") or []
        return [ai_call(tool_call) for tool_call in tool_calls]

    def ask():
        if not messages.value:
            return
        last_message = messages.value[-1]
        # Only query the model when it owes a response: after a user prompt
        # or after tool results have been appended.
        if last_message["role"] not in ("user", "tool"):
            return
        completion = openai.ChatCompletion.create(
            model=model,
            messages=messages.value,
            tools=function_descriptions,
            tool_choice="auto",
        )
        output = completion.choices[0].message
        print("received", output)
        try:
            tool_replies = handle_message(output)
        except Exception as e:
            # Log, then re-raise so use_thread records the error and the
            # panel shows it via solara.Error below instead of losing it.
            print("error while handling model reply:", repr(e))
            raise
        messages.value = [*messages.value, output, *tool_replies]

    def handle_initial():
        # Replay tool calls already in the history (e.g. after "Load" remounts
        # this panel) so the map state matches the restored conversation.
        print("handle initial", messages.value)
        for message in messages.value:
            handle_message(message)

    solara.use_effect(handle_initial, [])
    result = solara.use_thread(ask, dependencies=[messages.value])
    is_running = result.state == solara.ResultState.RUNNING

    with solara.Column(
        style={"height": "100%", "width": "38vw", "justify-content": "center"},
        classes=["chat-interface"],
    ):
        if messages.value:
            with solara.Column(style={"flex-grow": "1", "overflow-y": "auto"}):
                for message in messages.value:
                    _render_message(message)
        with solara.Column():
            solara.InputText(
                label="Ask your question",
                value=prompt,
                style={"flex-grow": "1"},
                on_value=add_message,
                disabled=is_running,
            )
            solara.ProgressLinear(is_running)
            if result.state == solara.ResultState.ERROR:
                solara.Error(repr(result.error))
@solara.component
def Page():
    """Top-level page: app bar (Save / Load / Soft reset), chat panel and map."""
    reset_counter, set_reset_counter = solara.use_state(0)
    print("reset", reset_counter, f"chat-{reset_counter}")

    def reset_ui():
        # The chat panel is keyed on this counter; bumping it gives the panel
        # a new key so it is rebuilt from scratch.
        set_reset_counter(reset_counter + 1)

    def save():
        # Persist the full chat history, including tool calls and results.
        with open("log.json", "w", encoding="utf-8") as f:
            json.dump(messages.value, f)

    def load():
        # Restore a saved history, then rebuild the chat panel so it can
        # replay the restored tool calls onto the map.
        with open("log.json", "r", encoding="utf-8") as f:
            saved_messages = json.load(f)
        messages.set(saved_messages)
        reset_ui()

    with solara.Column(style={"flex-grow": "1"}, gap=0):
        with solara.AppBar():
            solara.Button("Save", on_click=save)
            solara.Button("Load", on_click=load)
            solara.Button("Soft reset", on_click=reset_ui)
        with solara.Row(style={"height": "100%"}, justify="space-between"):
            ChatInterface().key(f"chat-{reset_counter}")
            with solara.Column(style={"width": "58vw", "justify-content": "center"}):
                Map()  # .key(f"map-{reset_counter}")
        solara.Style(
            """
            .jupyter-widgets.leaflet-widgets{
                height: 100%;
            }
            .solara-autorouter-content{
                display: flex;
                flex-direction: column;
                justify-content: stretch;
            }
            """
        )
| # TODO: custom layout | |
| # @solara.component | |
| # def Layout(children): | |
| # with solara.v.AppBar(): | |
| # with solara.Column(children=children): | |
| # pass | |