import gradio as gr
from urllib.parse import urlparse
import requests
import time
from PIL import Image
import base64
import io
import uuid
import os


def extract_property_info(prop):
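    """Flatten a JSON Schema property: merge any allOf/anyOf/oneOf sub-schemas
    into a single dict, preserving the top-level description and default."""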
    combined_prop = {}
    merge_keywords = ["allOf", "anyOf", "oneOf"]
    for keyword in merge_keywords:
        if keyword in prop:
            for subprop in prop[keyword]:
                combined_prop.update(subprop)
            del prop[keyword]
    if not combined_prop:
        combined_prop = prop.copy()
    for key in ["description", "default"]:
        if key in prop:
            combined_prop[key] = prop[key]
    return combined_prop


def detect_file_type(filename):
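    """Guess a coarse media type ("audio", "image", "video" or "string") from a
    filename/URL extension; list inputs are reported as "list"."""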
    audio_extensions = [".mp3", ".wav", ".flac", ".aac", ".ogg", ".m4a"]
    image_extensions = [
        ".jpg",
        ".jpeg",
        ".png",
        ".gif",
        ".bmp",
        ".tiff",
        ".svg",
        ".webp",
    ]
    video_extensions = [
        ".mp4",
        ".mov",
        ".wmv",
        ".flv",
        ".avi",
        ".avchd",
        ".mkv",
        ".webm",
    ]
    # Extract the file extension
    if isinstance(filename, str):
        extension = filename[filename.rfind(".") :].lower()
        # Check the extension against each list
        if extension in audio_extensions:
            return "audio"
        elif extension in image_extensions:
            return "image"
        elif extension in video_extensions:
            return "video"
        else:
            return "string"
    elif isinstance(filename, list):
        return "list"


def build_gradio_inputs(ordered_input_schema, example_inputs=None):
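    """Build Gradio input components from an ordered list of (name, property)
    schema pairs.

    Returns the component instances, a source-code string that recreates them
    (e.g. as the inputs_string of create_gradio_app_script), and the ordered
    list of input names."""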
    inputs = []
    input_field_strings = """inputs = []\n"""
    names = []
    for index, (name, prop) in enumerate(ordered_input_schema):
        names.append(name)
        prop = extract_property_info(prop)
        if "enum" in prop:
            input_field = gr.Dropdown(
                choices=prop["enum"],
                label=prop.get("title"),
                info=prop.get("description"),
                value=prop.get("default"),
            )
            input_field_string = f"""inputs.append(gr.Dropdown(
    choices={prop["enum"]}, label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value="{prop.get("default")}"
))\n"""
        elif prop["type"] == "integer":
            if prop.get("minimum") and prop.get("maximum"):
                input_field = gr.Slider(
                    label=prop.get("title"),
                    info=prop.get("description"),
                    value=prop.get("default"),
                    minimum=prop.get("minimum"),
                    maximum=prop.get("maximum"),
                    step=1,
                )
                input_field_string = f"""inputs.append(gr.Slider(
    label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")},
    minimum={prop.get("minimum")}, maximum={prop.get("maximum")}, step=1,
))\n"""
            else:
                input_field = gr.Number(
                    label=prop.get("title"),
                    info=prop.get("description"),
                    value=prop.get("default"),
                )
                input_field_string = f"""inputs.append(gr.Number(
    label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")}
))\n"""
        elif prop["type"] == "number":
            if prop.get("minimum") and prop.get("maximum"):
                input_field = gr.Slider(
                    label=prop.get("title"),
                    info=prop.get("description"),
                    value=prop.get("default"),
                    minimum=prop.get("minimum"),
                    maximum=prop.get("maximum"),
                )
                input_field_string = f"""inputs.append(gr.Slider(
    label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")},
    minimum={prop.get("minimum")}, maximum={prop.get("maximum")}
))\n"""
            else:
                input_field = gr.Number(
                    label=prop.get("title"),
                    info=prop.get("description"),
                    value=prop.get("default"),
                )
                input_field_string = f"""inputs.append(gr.Number(
    label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")}
))\n"""
        elif prop["type"] == "boolean":
            input_field = gr.Checkbox(
                label=prop.get("title"),
                info=prop.get("description"),
                value=prop.get("default"),
            )
            input_field_string = f"""inputs.append(gr.Checkbox(
    label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")}
))\n"""
        elif (
            prop["type"] == "string" and prop.get("format") == "uri" and example_inputs
        ):
            input_type_example = example_inputs.get(name, None)
            if input_type_example:
                input_type = detect_file_type(input_type_example)
            else:
                input_type = None
            if input_type == "image":
                input_field = gr.Image(label=prop.get("title"), type="filepath")
                input_field_string = f"""inputs.append(gr.Image(
    label="{prop.get("title")}", type="filepath"
))\n"""
            elif input_type == "audio":
                input_field = gr.Audio(label=prop.get("title"), type="filepath")
                input_field_string = f"""inputs.append(gr.Audio(
    label="{prop.get("title")}", type="filepath"
))\n"""
            elif input_type == "video":
                input_field = gr.Video(label=prop.get("title"))
                input_field_string = f"""inputs.append(gr.Video(
    label="{prop.get("title")}"
))\n"""
            else:
                input_field = gr.File(label=prop.get("title"))
                input_field_string = f"""inputs.append(gr.File(
    label="{prop.get("title")}"
))\n"""
        else:
            input_field = gr.Textbox(
                label=prop.get("title"),
                info=prop.get("description"),
            )
            input_field_string = f"""inputs.append(gr.Textbox(
    label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}
))\n"""
        inputs.append(input_field)
        input_field_strings += f"{input_field_string}\n"
    input_field_strings += f"names = {names}\n"
    return inputs, input_field_strings, names


def build_gradio_outputs_replicate(output_types):
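    """Map Replicate output type names ("image", "audio", "video", "string",
    "json", "list") to Gradio output components, returning the components and
    a matching source-code string; defaults to a single gr.JSON output."""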
    outputs = []
    output_field_strings = """outputs = []\n"""
    if output_types:
        for output in output_types:
            if output == "image":
                output_field = gr.Image()
                output_field_string = "outputs.append(gr.Image())"
            elif output == "audio":
                output_field = gr.Audio(type="filepath")
                output_field_string = "outputs.append(gr.Audio(type='filepath'))"
            elif output == "video":
                output_field = gr.Video()
                output_field_string = "outputs.append(gr.Video())"
            elif output == "string":
                output_field = gr.Textbox()
                output_field_string = "outputs.append(gr.Textbox())"
            elif output == "json":
                output_field = gr.JSON()
                output_field_string = "outputs.append(gr.JSON())"
            elif output == "list":
                output_field = gr.JSON()
                output_field_string = "outputs.append(gr.JSON())"
            outputs.append(output_field)
            output_field_strings += f"{output_field_string}\n"
    else:
        output_field = gr.JSON()
        output_field_string = "outputs.append(gr.JSON())"
        outputs.append(output_field)
        output_field_strings += f"{output_field_string}\n"
    return outputs, output_field_strings


def build_gradio_outputs_cog():
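    """Placeholder for building outputs from a Cog schema (not implemented)."""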
    pass


def process_outputs(outputs):
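    """Turn raw prediction outputs into Gradio-friendly values: decode data-URI
    images to PIL images, write data-URI audio/video to local files, and pass
    every other value through unchanged."""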
    output_values = []
    for output in outputs:
        if not output:
            continue
        if isinstance(output, str):
            if output.startswith("data:image"):
                base64_data = output.split(",", 1)[1]
                image_data = base64.b64decode(base64_data)
                image_stream = io.BytesIO(image_data)
                image = Image.open(image_stream)
                output_values.append(image)
            elif output.startswith("data:audio"):
                base64_data = output.split(",", 1)[1]
                audio_data = base64.b64decode(base64_data)
                audio_stream = io.BytesIO(audio_data)
                filename = f"{uuid.uuid4()}.wav"  # Change format as needed
                with open(filename, "wb") as audio_file:
                    audio_file.write(audio_stream.getbuffer())
                output_values.append(filename)
            elif output.startswith("data:video"):
                base64_data = output.split(",", 1)[1]
                video_data = base64.b64decode(base64_data)
                video_stream = io.BytesIO(video_data)
                # Here you can save the video or return the stream for further processing
                filename = f"{uuid.uuid4()}.mp4"  # Change format as needed
                with open(filename, "wb") as video_file:
                    video_file.write(video_stream.getbuffer())
                output_values.append(filename)
            else:
                output_values.append(output)
        else:
            output_values.append(output)
    return output_values


def parse_outputs(data):
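    """Recursively flatten the prediction's output JSON into a flat list of
    values (lists nested inside objects are kept as sub-lists)."""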
    if isinstance(data, dict):
        # Handle case where data is an object
        dict_values = []
        for value in data.values():
            extracted_values = parse_outputs(value)
            # For dict, we append instead of extend to maintain list structure within objects
            if isinstance(value, list):
                dict_values += [extracted_values]
            else:
                dict_values += extracted_values
        return dict_values
    elif isinstance(data, list):
        # Handle case where data is an array
        list_values = []
        for item in data:
            # Here we extend to flatten the list since we're already in an array context
            list_values += parse_outputs(item)
        return list_values
    else:
        # Handle primitive data types directly
        return [data]


def create_dynamic_gradio_app(
    inputs,
    outputs,
    api_url,
    api_id=None,
    replicate_token=None,
    title="",
    model_description="",
    names=[],
    local_base=False,
    hostname="0.0.0.0",
):
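    """Build a gr.Interface whose predict function POSTs the collected inputs
    to api_url (Replicate-style), polls the returned "get" URL until the
    prediction succeeds, and maps the processed outputs onto the output
    components."""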
    expected_outputs = len(outputs)

    def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):
        payload = {"input": {}}
        if api_id:
            payload["version"] = api_id
        parsed_url = urlparse(str(request.url))
        if local_base:
            base_url = f"http://{hostname}:7860"
        else:
            base_url = parsed_url.scheme + "://" + parsed_url.netloc
        for i, key in enumerate(names):
            value = args[i]
            if value and (os.path.exists(str(value))):
                value = f"{base_url}/file=" + value
            if value is not None and value != "":
                payload["input"][key] = value
        print(payload)
        headers = {"Content-Type": "application/json"}
        if replicate_token:
            headers["Authorization"] = f"Token {replicate_token}"
        print(headers)
        response = requests.post(api_url, headers=headers, json=payload)

        if response.status_code == 201:
            follow_up_url = response.json()["urls"]["get"]
            response = requests.get(follow_up_url, headers=headers)
            while response.json()["status"] != "succeeded":
                if response.json()["status"] == "failed":
                    raise gr.Error("The submission failed!")
                response = requests.get(follow_up_url, headers=headers)
                time.sleep(1)
            # TODO: Add a failing mechanism if the API gets stuck
        if response.status_code == 200:
            json_response = response.json()
            # If the output component is JSON return the entire output response
            if outputs[0].get_config()["name"] == "json":
                return json_response["output"]
            predict_outputs = parse_outputs(json_response["output"])
            processed_outputs = process_outputs(predict_outputs)
            difference_outputs = expected_outputs - len(processed_outputs)
            # If fewer outputs than expected, hide the extra ones
            if difference_outputs > 0:
                extra_outputs = [gr.update(visible=False)] * difference_outputs
                processed_outputs.extend(extra_outputs)
            # If more outputs than expected, cap the outputs to the expected number
            elif difference_outputs < 0:
                processed_outputs = processed_outputs[:difference_outputs]
            return (
                tuple(processed_outputs)
                if len(processed_outputs) > 1
                else processed_outputs[0]
            )
        else:
            if response.status_code == 409:
                raise gr.Error(
                    "Sorry, the Cog image is still processing. Try again in a bit."
                )
            raise gr.Error(f"The submission failed! Error: {response.status_code}")

    app = gr.Interface(
        fn=predict,
        inputs=inputs,
        outputs=outputs,
        title=title,
        description=model_description,
        allow_flagging="never",
    )
    return app


def create_gradio_app_script(
    inputs_string,
    outputs_string,
    api_url,
    api_id=None,
    replicate_token=None,
    title="",
    model_description="",
    local_base=False,
    hostname="0.0.0.0",
):
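    """Return the source code of a standalone Gradio app, as a string, that
    performs the same request/poll/process loop as create_dynamic_gradio_app."""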
    headers = {"Content-Type": "application/json"}
    if replicate_token:
        headers["Authorization"] = f"Token {replicate_token}"
    if local_base:
        base_url = f'base_url = "http://{hostname}:7860"'
    else:
        base_url = """parsed_url = urlparse(str(request.url))
    base_url = parsed_url.scheme + "://" + parsed_url.netloc"""
    headers_string = f"""    headers = {headers}\n"""
    api_id_value = f'payload["version"] = "{api_id}"' if api_id is not None else ""
    definition_string = """expected_outputs = len(outputs)
def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):"""
    payload_string = f"""    payload = {{"input": {{}}}}
    {api_id_value}
    {base_url}
    for i, key in enumerate(names):
        value = args[i]
        if value and (os.path.exists(str(value))):
            value = f"{{base_url}}/file=" + value
        if value is not None and value != "":
            payload["input"][key] = value\n"""
    request_string = (
        f"""    response = requests.post("{api_url}", headers=headers, json=payload)\n"""
    )
    result_string = f"""
    if response.status_code == 201:
        follow_up_url = response.json()["urls"]["get"]
        response = requests.get(follow_up_url, headers=headers)
        while response.json()["status"] != "succeeded":
            if response.json()["status"] == "failed":
                raise gr.Error("The submission failed!")
            response = requests.get(follow_up_url, headers=headers)
            time.sleep(1)
    if response.status_code == 200:
        json_response = response.json()
        # If the output component is JSON return the entire output response
        if outputs[0].get_config()["name"] == "json":
            return json_response["output"]
        predict_outputs = parse_outputs(json_response["output"])
        processed_outputs = process_outputs(predict_outputs)
        difference_outputs = expected_outputs - len(processed_outputs)
        # If fewer outputs than expected, hide the extra ones
        if difference_outputs > 0:
            extra_outputs = [gr.update(visible=False)] * difference_outputs
            processed_outputs.extend(extra_outputs)
        # If more outputs than expected, cap the outputs to the expected number
        elif difference_outputs < 0:
            processed_outputs = processed_outputs[:difference_outputs]
        return tuple(processed_outputs) if len(processed_outputs) > 1 else processed_outputs[0]
    else:
        if response.status_code == 409:
            raise gr.Error("Sorry, the Cog image is still processing. Try again in a bit.")
        raise gr.Error(f"The submission failed! Error: {{response.status_code}}")\n"""
    interface_string = f"""title = "{title}"
model_description = "{model_description}"
app = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=model_description,
    allow_flagging="never",
)
app.launch(share=True)
"""
    app_string = f"""import gradio as gr
from urllib.parse import urlparse
import requests
import time
import os
from utils.gradio_helpers import parse_outputs, process_outputs
{inputs_string}
{outputs_string}
{definition_string}
{headers_string}
{payload_string}
{request_string}
{result_string}
{interface_string}
"""
    return app_string
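

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the helpers above): builds a small
# interface from a hand-written schema. The property names, the version id and
# the environment variable are made-up placeholders; api_url is assumed to be
# a Replicate-style predictions endpoint, as expected by the code above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_schema = [
        ("prompt", {"type": "string", "title": "Prompt", "description": "Text prompt"}),
        (
            "num_inference_steps",
            {"type": "integer", "title": "Steps", "default": 25, "minimum": 1, "maximum": 100},
        ),
    ]
    demo_inputs, _, demo_names = build_gradio_inputs(example_schema)
    demo_outputs, _ = build_gradio_outputs_replicate(["image"])
    demo = create_dynamic_gradio_app(
        inputs=demo_inputs,
        outputs=demo_outputs,
        api_url="https://api.replicate.com/v1/predictions",  # assumed endpoint
        api_id="<model-version-id>",  # placeholder version id
        replicate_token=os.environ.get("REPLICATE_API_TOKEN"),  # assumed env var
        title="Demo",
        names=demo_names,
    )
    demo.launch()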