Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import logging | |
| from typing import Any, Generator | |
| from app_conf import ( | |
| GALLERY_PATH, | |
| GALLERY_PREFIX, | |
| POSTERS_PATH, | |
| POSTERS_PREFIX, | |
| UPLOADS_PATH, | |
| UPLOADS_PREFIX, | |
| ) | |
| from data.loader import preload_data | |
| from data.schema import schema | |
| from data.store import set_videos | |
| from flask import Flask, make_response, Request, request, Response, send_from_directory | |
| from flask_cors import CORS | |
| from inference.data_types import PropagateDataResponse, PropagateInVideoRequest | |
| from inference.multipart import MultipartResponseBuilder | |
| from inference.predictor import InferenceAPI | |
| from strawberry.flask.views import GraphQLView | |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Flask application object, with CORS enabled and credentialed requests
# allowed (supports_credentials=True).
app = Flask(__name__)
cors = CORS(app, supports_credentials=True)

# Load the video data once at import time and hand it to the data store.
videos = preload_data()
set_videos(videos)

# Single InferenceAPI instance shared by the streaming handler and by the
# GraphQL view context defined further down in this module.
inference_api = InferenceAPI()
# NOTE(review): no route registration for this handler is visible in this
# listing -- confirm it is attached to `app` (e.g. via a decorator).
def healthy() -> Response:
    """Health-check handler: reply with the body ``OK`` and status 200."""
    ok_response = make_response("OK", 200)
    return ok_response
# NOTE(review): no route registration for this handler is visible in this
# listing, and GALLERY_PREFIX is imported but unused -- confirm the URL rule.
def send_gallery_video(path: str) -> Response:
    """Serve a file from the gallery directory.

    Args:
        path: Location of the requested file, relative to ``GALLERY_PATH``.

    Returns:
        The response produced by ``send_from_directory``.

    Raises:
        ValueError: If ``send_from_directory`` raises any ``Exception``; the
            original error is attached as ``__cause__``.
    """
    try:
        return send_from_directory(GALLERY_PATH, path)
    except Exception as err:
        # Previously a bare ``except:``, which also intercepted
        # KeyboardInterrupt/SystemExit. Those now propagate untouched.
        # NOTE(review): a missing file therefore surfaces as ValueError rather
        # than an HTTP 404 -- confirm an error handler maps this as intended.
        raise ValueError("resource not found") from err
# NOTE(review): no route registration for this handler is visible in this
# listing, and POSTERS_PREFIX is imported but unused -- confirm the URL rule.
def send_poster_image(path: str) -> Response:
    """Serve a file from the posters directory.

    Args:
        path: Location of the requested file, relative to ``POSTERS_PATH``.

    Returns:
        The response produced by ``send_from_directory``.

    Raises:
        ValueError: If ``send_from_directory`` raises any ``Exception``; the
            original error is attached as ``__cause__``.
    """
    try:
        return send_from_directory(POSTERS_PATH, path)
    except Exception as err:
        # Previously a bare ``except:``, which also intercepted
        # KeyboardInterrupt/SystemExit. Those now propagate untouched.
        # NOTE(review): a missing file therefore surfaces as ValueError rather
        # than an HTTP 404 -- confirm an error handler maps this as intended.
        raise ValueError("resource not found") from err
# NOTE(review): no route registration for this handler is visible in this
# listing, and UPLOADS_PREFIX is imported but unused -- confirm the URL rule.
def send_uploaded_video(path: str) -> Response:
    """Serve a file from the uploads directory.

    Args:
        path: Location of the requested file, relative to ``UPLOADS_PATH``.

    Returns:
        The response produced by ``send_from_directory``.

    Raises:
        ValueError: If ``send_from_directory`` raises any ``Exception``; the
            original error is attached as ``__cause__``.
    """
    try:
        return send_from_directory(UPLOADS_PATH, path)
    except Exception as err:
        # Previously a bare ``except:``, which also intercepted
        # KeyboardInterrupt/SystemExit. Those now propagate untouched.
        # NOTE(review): a missing file therefore surfaces as ValueError rather
        # than an HTTP 404 -- confirm an error handler maps this as intended.
        raise ValueError("resource not found") from err
# TODO: Protect route with ToS permission check
# NOTE(review): no route registration for this handler is visible in this
# listing -- confirm it is attached to `app` (e.g. via a decorator).
def propagate_in_video() -> Response:
    """Stream propagation results for a session as a multipart response.

    Reads the JSON body of the current request. The body must be a JSON
    object containing ``session_id``; ``start_frame_index`` is optional and
    defaults to 0.

    Returns:
        A ``Response`` whose body is the generator returned by
        ``gen_track_with_mask_stream`` and whose mimetype is
        ``multipart/x-savi-stream`` with boundary ``frame``. If the body is
        not a JSON object or has no ``session_id``, a 400 response is
        returned instead.
    """
    data = request.json
    # Reject malformed payloads up front; indexing a non-dict or a dict
    # without the key would otherwise escape as TypeError/KeyError.
    if not isinstance(data, dict) or "session_id" not in data:
        return make_response("session_id is required", 400)

    boundary = "frame"
    frame = gen_track_with_mask_stream(
        boundary,
        session_id=data["session_id"],
        start_frame_index=data.get("start_frame_index", 0),
    )
    return Response(frame, mimetype="multipart/x-savi-stream; boundary=" + boundary)
def gen_track_with_mask_stream(
    boundary: str,
    session_id: str,
    start_frame_index: int,
) -> Generator[bytes, None, None]:
    """Yield one multipart message per chunk returned by propagation.

    Args:
        boundary: Multipart boundary handed to ``MultipartResponseBuilder``.
        session_id: Session identifier placed in the propagation request.
        start_frame_index: Frame index placed in the propagation request.

    Yields:
        The result of ``get_message()`` for each chunk; the message body is
        the chunk's ``to_json()`` output encoded as UTF-8.
    """
    # Request construction and iteration both happen inside the autocast
    # context. NOTE(review): nesting was reconstructed from a flattened
    # listing -- confirm the loop belongs inside the ``with`` block.
    with inference_api.autocast_context():
        propagate_request = PropagateInVideoRequest(
            type="propagate_in_video",
            session_id=session_id,
            start_frame_index=start_frame_index,
        )
        results = inference_api.propagate_in_video(request=propagate_request)

        for result in results:
            payload = result.to_json().encode("UTF-8")
            message = MultipartResponseBuilder.build(
                boundary=boundary,
                headers={
                    "Content-Type": "application/json; charset=utf-8",
                    # Both frame headers are sent as the fixed string "-1".
                    "Frame-Current": "-1",
                    # Total frames minus the reference frame
                    "Frame-Total": "-1",
                    "Mask-Type": "RLE[]",
                },
                body=payload,
            )
            yield message.get_message()
class MyGraphQLView(GraphQLView):
    """GraphQL view whose context exposes the module's ``inference_api``."""

    def get_context(self, request: Request, response: Response) -> Any:
        """Return the context mapping for a GraphQL operation.

        ``request`` and ``response`` are accepted but not used here; the
        context only carries the shared ``inference_api`` instance.
        """
        context = {"inference_api": inference_api}
        return context
# Build the GraphQL view function, then register it on the Flask app.
_graphql_view_func = MyGraphQLView.as_view(
    "graphql_view",
    schema=schema,
    # Disable GET queries
    # https://strawberry.rocks/docs/operations/deployment
    # https://strawberry.rocks/docs/integrations/flask
    allow_queries_via_get=False,
    # Strawberry recently changed multipart request handling, which now
    # requires enabling support explicitly for views.
    # https://github.com/strawberry-graphql/strawberry/issues/3655
    multipart_uploads_enabled=True,
)
app.add_url_rule("/graphql", view_func=_graphql_view_func)
# Start Flask's built-in server only when this file is executed directly.
if __name__ == "__main__":
    # Host 0.0.0.0 listens on every network interface; the port is fixed.
    app.run(port=5000, host="0.0.0.0")