Spaces:
Paused
Paused
| /** | |
| * Copyright (c) Meta Platforms, Inc. and affiliates. | |
| * | |
| * Licensed under the Apache License, Version 2.0 (the "License"); | |
| * you may not use this file except in compliance with the License. | |
| * You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, software | |
| * distributed under the License is distributed on an "AS IS" BASIS, | |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| * See the License for the specific language governing permissions and | |
| * limitations under the License. | |
| */ | |
| import {generateThumbnail} from '@/common/components/video/editor/VideoEditorUtils'; | |
| import VideoWorkerContext from '@/common/components/video/VideoWorkerContext'; | |
| import Logger from '@/common/logger/Logger'; | |
| import { | |
| SAM2ModelAddNewPointsMutation, | |
| SAM2ModelAddNewPointsMutation$data, | |
| } from '@/common/tracker/__generated__/SAM2ModelAddNewPointsMutation.graphql'; | |
| import {SAM2ModelCancelPropagateInVideoMutation} from '@/common/tracker/__generated__/SAM2ModelCancelPropagateInVideoMutation.graphql'; | |
| import {SAM2ModelClearPointsInFrameMutation} from '@/common/tracker/__generated__/SAM2ModelClearPointsInFrameMutation.graphql'; | |
| import {SAM2ModelClearPointsInVideoMutation} from '@/common/tracker/__generated__/SAM2ModelClearPointsInVideoMutation.graphql'; | |
| import {SAM2ModelCloseSessionMutation} from '@/common/tracker/__generated__/SAM2ModelCloseSessionMutation.graphql'; | |
| import {SAM2ModelRemoveObjectMutation} from '@/common/tracker/__generated__/SAM2ModelRemoveObjectMutation.graphql'; | |
| import {SAM2ModelStartSessionMutation} from '@/common/tracker/__generated__/SAM2ModelStartSessionMutation.graphql'; | |
| import { | |
| BaseTracklet, | |
| Mask, | |
| SegmentationPoint, | |
| StreamingState, | |
| Tracker, | |
| Tracklet, | |
| } from '@/common/tracker/Tracker'; | |
| import {TrackerOptions} from '@/common/tracker/Trackers'; | |
| import { | |
| ClearPointsInVideoResponse, | |
| SessionStartFailedResponse, | |
| SessionStartedResponse, | |
| StreamingCompletedResponse, | |
| StreamingStartedResponse, | |
| StreamingStateUpdateResponse, | |
| TrackletCreatedResponse, | |
| TrackletDeletedResponse, | |
| TrackletsUpdatedResponse, | |
| } from '@/common/tracker/TrackerTypes'; | |
| import {convertMaskToRGBA} from '@/common/utils/MaskUtils'; | |
| import multipartStream from '@/common/utils/MultipartStream'; | |
| import {Stats} from '@/debug/stats/Stats'; | |
| import {INFERENCE_API_ENDPOINT} from '@/demo/DemoConfig'; | |
| import {createEnvironment} from '@/graphql/RelayEnvironment'; | |
| import { | |
| DataArray, | |
| Masks, | |
| RLEObject, | |
| decode, | |
| encode, | |
| toBbox, | |
| } from '@/jscocotools/mask'; | |
| import {THEME_COLORS} from '@/theme/colors'; | |
| import invariant from 'invariant'; | |
| import {IEnvironment, commitMutation, graphql} from 'relay-runtime'; | |
/** Constructor options for the model: only the inference endpoint is picked. */
type Options = Pick<TrackerOptions, 'inferenceEndpoint'>;

/**
 * Client-side view of a session: the server-issued id (null while no session
 * is active) and the tracklets known locally, keyed by tracklet id.
 */
interface Session {
  id: string | null;
  tracklets: Record<number, Tracklet>;
}

/** RLE-encoded mask of a single object. */
interface StreamMaskEntry {
  objectId: number;
  rleMask: RLEObject;
}

/** One streamed result: the masks of every reported object for one frame. */
interface StreamMasksResult {
  frameIndex: number;
  rleMaskList: Array<StreamMaskEntry>;
}

/** Marker yielded by the mask stream when the request was aborted. */
interface StreamMasksAbortResult {
  aborted: boolean;
}
/**
 * Tracker that talks to a remote inference service.
 *
 * Session management, point prompts and object removal are sent as GraphQL
 * mutations through a Relay environment; mask propagation is read from the
 * `propagate_in_video` endpoint as a multipart stream. Results are stored in
 * the local session's tracklets and reported through `_sendResponse`.
 */
export class SAM2Model extends Tracker {
  private _endpoint: string;
  private _environment: IEnvironment;
  // Controller of the mask stream currently in flight, if any.
  private abortController: AbortController | null = null;
  private _session: Session = {
    id: null,
    tracklets: {},
  };
  private _streamingState: StreamingState = 'none';
  // RLE encoding of an all-empty mask, created lazily in updatePoints() and
  // used to flag incoming masks as empty.
  private _emptyMask: RLEObject | null = null;
  private _maskCanvas: OffscreenCanvas;
  private _maskCtx: OffscreenCanvasRenderingContext2D;
  private _stats?: Stats;

  /**
   * @param context Worker context the tracker reads frames from and pushes
   *   mask updates to.
   * @param options Inference endpoint; defaults to INFERENCE_API_ENDPOINT.
   */
  constructor(
    context: VideoWorkerContext,
    options: Options = {
      inferenceEndpoint: INFERENCE_API_ENDPOINT,
    },
  ) {
    super(context);
    this._endpoint = options.inferenceEndpoint;
    this._environment = createEnvironment(options.inferenceEndpoint);

    // Scratch canvas reused by _compressMaskForCanvas(); it is resized there.
    this._maskCanvas = new OffscreenCanvas(0, 0);
    const maskCtx = this._maskCanvas.getContext('2d');
    invariant(maskCtx != null, 'context cannot be null');
    this._maskCtx = maskCtx;
  }

  /**
   * Starts a session for the given video.
   *
   * The returned promise always resolves: success is reported with a
   * `sessionStarted` response, failure with `sessionStartFailed`.
   *
   * @param videoPath Path of the video, passed to the start-session mutation.
   */
  public startSession(videoPath: string): Promise<void> {
    // Reset streaming state. Force update with the true flag to make sure the
    // UI updates its state.
    this._updateStreamingState('none', true);

    return new Promise(resolve => {
      try {
        commitMutation<SAM2ModelStartSessionMutation>(this._environment, {
          mutation: graphql`
            mutation SAM2ModelStartSessionMutation($input: StartSessionInput!) {
              startSession(input: $input) {
                sessionId
              }
            }
          `,
          variables: {
            input: {
              path: videoPath,
            },
          },
          onCompleted: response => {
            const {sessionId} = response.startSession;
            this._session.id = sessionId;
            // The cached empty mask is sized from the context dimensions, so
            // drop it; updatePoints() rebuilds it for the new video.
            this._emptyMask = null;

            this._sendResponse<SessionStartedResponse>('sessionStarted', {
              sessionId,
            });

            // Clear any tracklets from the previous session when
            // a new session is started
            this._clearTracklets();

            // Make an empty tracklet
            this.createTracklet();
            resolve();
          },
          onError: error => {
            Logger.error(error);
            this._sendResponse<SessionStartFailedResponse>(
              'sessionStartFailed',
            );
            resolve();
          },
        });
      } catch (error) {
        Logger.error(error);
        this._sendResponse<SessionStartFailedResponse>('sessionStartFailed');
        resolve();
      }
    });
  }

  /**
   * Closes the active session, if any, and resets the local session state.
   *
   * Resolves immediately when no session is active. Rejects when the mutation
   * errors or the server reports `success: false`.
   */
  public closeSession(): Promise<void> {
    // Read the session id BEFORE cleanup: cleanup resets the id, and without
    // it the closeSession mutation would never be sent.
    const sessionId = this._session.id;
    this._cleanup();

    if (sessionId === null) {
      return Promise.resolve();
    }

    return new Promise((resolve, reject) => {
      commitMutation<SAM2ModelCloseSessionMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelCloseSessionMutation($input: CloseSessionInput!) {
            closeSession(input: $input) {
              success
            }
          }
        `,
        variables: {
          input: {
            sessionId,
          },
        },
        onCompleted: response => {
          const {success} = response.closeSession;
          if (success === false) {
            reject(new Error('Failed to close session'));
            return;
          }
          resolve();
        },
        onError: error => {
          Logger.error(error);
          reject(error);
        },
      });
    });
  }

  /**
   * Adds a new, uninitialized tracklet to the local session and announces it
   * with `trackletsUpdated` and `trackletCreated` responses.
   */
  public createTracklet(): void {
    // This will return 0 for empty tracklets and otherwise the next
    // largest number.
    const nextId =
      Object.values(this._session.tracklets).reduce(
        (prev, curr) => Math.max(prev, curr.id),
        -1,
      ) + 1;

    const newTracklet = {
      id: nextId,
      color: THEME_COLORS[nextId % THEME_COLORS.length],
      thumbnail: null,
      points: [],
      masks: [],
      isInitialized: false,
    };

    this._session.tracklets[nextId] = newTracklet;

    // Notify the main thread
    this._updateTracklets();
    this._sendResponse<TrackletCreatedResponse>('trackletCreated', {
      tracklet: newTracklet,
    });
  }

  /**
   * Removes an object on the server and then drops the tracklet locally.
   *
   * A `trackletDeleted` response reports the outcome. Rejects when there is no
   * active session or when the mutation errors.
   *
   * @param trackletId Id of the tracklet to remove; it must exist locally.
   */
  public deleteTracklet(trackletId: number): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }

    const tracklet = this._session.tracklets[trackletId];
    invariant(
      tracklet != null,
      'tracklet for tracklet id %s not initialized',
      trackletId,
    );

    return new Promise((resolve, reject) => {
      commitMutation<SAM2ModelRemoveObjectMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelRemoveObjectMutation($input: RemoveObjectInput!) {
            removeObject(input: $input) {
              frameIndex
              rleMaskList {
                objectId
                rleMask {
                  counts
                  size
                }
              }
            }
          }
        `,
        variables: {
          input: {objectId: trackletId, sessionId},
        },
        onCompleted: response => {
          const trackletUpdates = response.removeObject;
          this._sendResponse<TrackletDeletedResponse>('trackletDeleted', {
            isSuccessful: true,
          });

          for (const trackletUpdate of trackletUpdates) {
            // Fire-and-forget, as before; failures are logged so they do not
            // surface as unhandled rejections.
            this._updateTrackletMasks(
              trackletUpdate,
              trackletUpdate.frameIndex === this._context.frameIndex,
              false, // shouldGoToFrame
            ).catch(error => Logger.error(error));
          }

          this._removeTrackletMasks(tracklet);
          resolve();
        },
        onError: error => {
          this._sendResponse<TrackletDeletedResponse>('trackletDeleted', {
            isSuccessful: false,
          });
          Logger.error(error);
          reject(error);
        },
      });
    });
  }

  /**
   * Replaces the points of an object in a frame and applies the masks the
   * server returns.
   *
   * An empty `points` list is delegated to clearPointsInFrame(). Rejects when
   * there is no active session or when the mutation errors.
   *
   * @param frameIndex Frame the points belong to.
   * @param objectId Id of the tracklet; it must exist locally.
   * @param points Points as `[x, y, label]`; x and y are divided by the
   *   context width and height before being sent.
   */
  public updatePoints(
    frameIndex: number,
    objectId: number,
    points: SegmentationPoint[],
  ): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }

    // TODO: This is not the right place to initialize the empty mask.
    // Move this into the constructor and listen to events on the context.
    // Note, the initial context.width and context.height is 0, so it needs
    // to happen based on an event, so when the video is initialized, it needs
    // to notify the tracker to update the empty mask.
    if (this._emptyMask === null) {
      // Truncate the height/width to integers since Masks expects integer
      // values for the height/width.
      const tensor = new Masks(
        Math.trunc(this._context.height),
        Math.trunc(this._context.width),
        1,
      ).toDataArray();
      this._emptyMask = encode(tensor)[0];
    }

    const tracklet = this._session.tracklets[objectId];
    invariant(
      tracklet != null,
      'tracklet for object id %s not initialized',
      objectId,
    );

    // Mark session needing propagation when point is set
    this._updateStreamingState('required');

    // Clear all points in frame if no points are provided.
    if (points.length === 0) {
      return this.clearPointsInFrame(frameIndex, objectId);
    }

    return new Promise((resolve, reject) => {
      const normalizedPoints = points.map(p => [
        p[0] / this._context.width,
        p[1] / this._context.height,
      ]);
      const labels = points.map(p => p[2]);

      commitMutation<SAM2ModelAddNewPointsMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelAddNewPointsMutation($input: AddPointsInput!) {
            addPoints(input: $input) {
              frameIndex
              rleMaskList {
                objectId
                rleMask {
                  counts
                  size
                }
              }
            }
          }
        `,
        variables: {
          input: {
            sessionId,
            frameIndex,
            objectId,
            labels: labels,
            points: normalizedPoints,
            clearOldPoints: true,
          },
        },
        onCompleted: response => {
          tracklet.points[frameIndex] = points;
          tracklet.isInitialized = true;
          // Not awaited (the promise resolves right away, as before); a
          // failure while applying the masks is logged.
          this._updateTrackletMasks(response.addPoints, true).catch(error =>
            Logger.error(error),
          );
          resolve();
        },
        onError: error => {
          Logger.error(error);
          reject(error);
        },
      });
    });
  }

  /**
   * Clears all points of an object in a frame and applies the masks the server
   * returns.
   *
   * Rejects when there is no active session or when the mutation errors.
   *
   * @param frameIndex Frame whose points are cleared.
   * @param objectId Id of the tracklet; it must exist locally.
   */
  public clearPointsInFrame(
    frameIndex: number,
    objectId: number,
  ): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }

    const tracklet = this._session.tracklets[objectId];
    invariant(
      tracklet != null,
      'tracklet for object id %s not initialized',
      objectId,
    );

    // Mark session needing propagation when points are changed
    this._updateStreamingState('required');

    return new Promise((resolve, reject) => {
      commitMutation<SAM2ModelClearPointsInFrameMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelClearPointsInFrameMutation(
            $input: ClearPointsInFrameInput!
          ) {
            clearPointsInFrame(input: $input) {
              frameIndex
              rleMaskList {
                objectId
                rleMask {
                  counts
                  size
                }
              }
            }
          }
        `,
        variables: {
          input: {
            sessionId,
            frameIndex,
            objectId,
          },
        },
        onCompleted: response => {
          tracklet.points[frameIndex] = [];
          tracklet.isInitialized = true;
          // Not awaited (the promise resolves right away, as before); a
          // failure while applying the masks is logged.
          this._updateTrackletMasks(response.clearPointsInFrame, true).catch(
            error => Logger.error(error),
          );
          resolve();
        },
        onError: error => {
          Logger.error(error);
          reject(error);
        },
      });
    });
  }

  /**
   * Clears every point in the video on the server and, on success, resets the
   * local tracklets.
   *
   * Rejects only when there is no active session. Otherwise the promise always
   * resolves once the mutation settles; the outcome is reported through the
   * `clearPointsInVideo` response (`isSuccessful`).
   */
  public clearPointsInVideo(): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }

    // All points are being removed, so reset the streaming state.
    this._updateStreamingState('none');

    return new Promise(resolve => {
      commitMutation<SAM2ModelClearPointsInVideoMutation>(this._environment, {
        mutation: graphql`
          mutation SAM2ModelClearPointsInVideoMutation(
            $input: ClearPointsInVideoInput!
          ) {
            clearPointsInVideo(input: $input) {
              success
            }
          }
        `,
        variables: {
          input: {
            sessionId,
          },
        },
        onCompleted: response => {
          const {success} = response.clearPointsInVideo;
          if (!success) {
            this._sendResponse<ClearPointsInVideoResponse>(
              'clearPointsInVideo',
              {isSuccessful: false},
            );
            // Settle the promise; otherwise awaiting callers hang forever.
            resolve();
            return;
          }

          // Reset points and masks for each tracklet
          this._clearTracklets();

          // Notify the main thread
          this._context.goToFrame(this._context.frameIndex);
          this._updateTracklets();
          this._sendResponse<ClearPointsInVideoResponse>('clearPointsInVideo', {
            isSuccessful: true,
          });
          resolve();
        },
        onError: error => {
          this._sendResponse<ClearPointsInVideoResponse>('clearPointsInVideo', {
            isSuccessful: false,
          });
          Logger.error(error);
          // Settle the promise; the failure was reported via the response.
          resolve();
        },
      });
    });
  }

  /**
   * Requests mask propagation starting at `frameIndex` and applies every
   * streamed result to the tracklets.
   *
   * Sends `streamingStarted` up front and `streamingCompleted` once the stream
   * has ended or was aborted; the streaming state moves through `requesting`,
   * `partial` and `full` (or `aborting` / `aborted`). Errors are logged and
   * rethrown, in which case `streamingCompleted` is not sent.
   *
   * @param frameIndex Frame index sent as the start of the propagation.
   */
  public async streamMasks(frameIndex: number): Promise<void> {
    const sessionId = this._session.id;
    if (sessionId === null) {
      return Promise.reject('No active session');
    }

    const controller = new AbortController();
    try {
      this._sendResponse<StreamingStartedResponse>('streamingStarted');

      // 1. Clear previous masks
      this._context.clearMasks();
      this._clearTrackletMasks();

      // 2. Register the abort controller and create the async generator
      this.abortController = controller;
      this._updateStreamingState('requesting');
      const generator = this._streamMasksForSession(
        controller,
        sessionId,
        frameIndex,
      );

      // 3. Parse the stream response and update masks in session objects
      let isAborted = false;
      for await (const result of generator) {
        if ('aborted' in result) {
          this._updateStreamingState('aborting');
          await this._abortRequest();
          this._updateStreamingState('aborted');
          isAborted = true;
        } else {
          await this._updateTrackletMasks(result, false);
          this._updateStreamingState('partial');
        }
      }

      if (!isAborted) {
        // The stream ran to completion
        this._updateStreamingState('full');
      }
    } catch (error) {
      Logger.error(error);
      throw error;
    } finally {
      // Drop the controller once this stream is over, unless a newer call has
      // already installed its own.
      if (this.abortController === controller) {
        this.abortController = null;
      }
    }

    this._sendResponse<StreamingCompletedResponse>('streamingCompleted');
  }

  /**
   * Signals the mask stream in flight (if any) to stop and sends a
   * `streamingCompleted` response.
   */
  public abortStreamMasks() {
    this.abortController?.abort();
    this._sendResponse<StreamingCompletedResponse>('streamingCompleted');
  }

  /** Creates the Stats instance used while decoding masks. */
  public enableStats(): void {
    this._stats = new Stats('ms', 'D', 1000 / 25);
  }

  // PRIVATE

  /** Resets the local session: id, tracklets and the cached empty mask. */
  private _cleanup() {
    this._session.id = null;
    // Clear existing tracklets
    this._session.tracklets = {};
    // The empty mask depends on the video dimensions; rebuild it lazily.
    this._emptyMask = null;
  }

  /** Drops all local tracklets and clears the masks held by the context. */
  private _clearTracklets() {
    this._session.tracklets = {};
    this._context.clearMasks();
  }

  /**
   * Stores the streaming state and sends a `streamingStateUpdate` response.
   *
   * @param state New streaming state.
   * @param forceUpdate Send the response even when the state is unchanged.
   */
  private _updateStreamingState(
    state: StreamingState,
    forceUpdate: boolean = false,
  ) {
    if (!forceUpdate && this._streamingState === state) {
      return;
    }
    this._streamingState = state;
    this._sendResponse<StreamingStateUpdateResponse>('streamingStateUpdate', {
      state,
    });
  }

  /**
   * Removes a tracklet and its masks locally, then refreshes the current frame
   * and the tracklet list.
   */
  private async _removeTrackletMasks(tracklet: Tracklet) {
    this._context.clearTrackletMasks(tracklet);
    delete this._session.tracklets[tracklet.id];

    // Notify the main thread
    this._context.goToFrame(this._context.frameIndex);
    this._updateTracklets();
  }

  /**
   * Decodes the RLE masks of one frame, stores them on their tracklets and
   * pushes the update to the context and the main thread.
   *
   * @param data Frame index plus the RLE mask of every reported object.
   * @param updateThumbnails Regenerate the thumbnail of each non-empty mask.
   * @param shouldGoToFrame Forwarded to the context's updateTracklets().
   */
  private async _updateTrackletMasks(
    data: SAM2ModelAddNewPointsMutation$data['addPoints'],
    updateThumbnails: boolean,
    shouldGoToFrame: boolean = true,
  ) {
    const {frameIndex, rleMaskList} = data;

    // 1. parse and decode masks for all objects
    for (const {objectId, rleMask} of rleMaskList) {
      const track = this._session.tracklets[objectId];
      if (track == null) {
        // No local tracklet for this object id (e.g. it was removed in the
        // meantime); skip it instead of failing on `track.masks`.
        continue;
      }

      const {size, counts} = rleMask;
      const rleObject: RLEObject = {
        size: [size[0], size[1]],
        counts: counts,
      };
      // Empty when the counts equal those of the cached all-empty mask; this
      // is always false while that mask has not been created yet.
      const isEmpty = counts === this._emptyMask?.counts;

      // NOTE(review): no matching end() call is visible in this file —
      // confirm against the Stats API whether one is required.
      this._stats?.begin();

      const decodedMask = decode([rleObject]);
      const bbox = toBbox([rleObject]);
      const mask: Mask = {
        data: rleObject as RLEObject,
        shape: [...decodedMask.shape],
        bounds: [
          [bbox[0], bbox[1]],
          [bbox[0] + bbox[2], bbox[1] + bbox[3]],
        ],
        isEmpty,
      } as const;
      track.masks[frameIndex] = mask;

      if (updateThumbnails && !isEmpty) {
        const {ctx} = await this._compressMaskForCanvas(decodedMask);
        const frame = this._context.currentFrame as VideoFrame;
        await generateThumbnail(track, frameIndex, mask, frame, ctx);
      }
    }

    this._context.updateTracklets(
      frameIndex,
      Object.values(this._session.tracklets),
      shouldGoToFrame,
    );

    // Notify the main thread
    this._updateTracklets();
  }

  /**
   * Sends the current tracklets in a `trackletsUpdated` response. Masks are
   * reduced to shape, bounds and isEmpty; the mask data itself is not sent.
   */
  private _updateTracklets() {
    const tracklets: BaseTracklet[] = Object.values(
      this._session.tracklets,
    ).map(tracklet => {
      const {
        id,
        color,
        isInitialized,
        points: trackletPoints,
        thumbnail,
        masks,
      } = tracklet;
      return {
        id,
        color,
        isInitialized,
        points: trackletPoints,
        thumbnail,
        masks: masks.map(mask => ({
          shape: mask.shape,
          bounds: mask.bounds,
          isEmpty: mask.isEmpty,
        })),
      };
    });

    // Notify the main thread
    this._sendResponse<TrackletsUpdatedResponse>('trackletsUpdated', {
      tracklets,
    });
  }

  /**
   * Replaces every tracklet with a copy that has no masks and notifies the
   * main thread.
   */
  private _clearTrackletMasks() {
    const keys = Object.keys(this._session.tracklets);
    for (const key of keys) {
      const trackletId = Number(key);
      const tracklet = {...this._session.tracklets[trackletId], masks: []};
      this._session.tracklets[trackletId] = tracklet;
    }
    this._updateTracklets();
  }

  /**
   * Draws a decoded mask onto a fresh canvas and encodes that canvas as PNG.
   *
   * @param decodedMask Decoded mask; its data is converted to RGBA.
   * @returns The PNG blob and the 2D context of the canvas it was drawn on.
   */
  private async _compressMaskForCanvas(
    decodedMask: DataArray,
  ): Promise<{compressedData: Blob; ctx: OffscreenCanvasRenderingContext2D}> {
    const data = convertMaskToRGBA(decodedMask.data as Uint8Array);

    // The scratch canvas uses shape[0] x shape[1]; the output canvas below
    // swaps the two dimensions.
    this._maskCanvas.width = decodedMask.shape[0];
    this._maskCanvas.height = decodedMask.shape[1];
    const imageData = new ImageData(
      data,
      decodedMask.shape[0],
      decodedMask.shape[1],
    );
    this._maskCtx.putImageData(imageData, 0, 0);

    const canvas = new OffscreenCanvas(
      decodedMask.shape[1],
      decodedMask.shape[0],
    );
    const ctx = canvas.getContext('2d');
    invariant(ctx != null, 'context cannot be null');

    ctx.save();
    ctx.rotate(Math.PI / 2);
    // Since the image was previously rotated 90° clockwise, after the image is rotated,
    // we scale the canvas's width using scaleY and height using scaleX.
    ctx.scale(1, -1);
    ctx.drawImage(this._maskCanvas, 0, 0);
    ctx.restore();

    const compressedData = await canvas.convertToBlob({type: 'image/png'});
    return {compressedData, ctx};
  }

  /**
   * Posts a propagation request and yields one result per JSON part of the
   * multipart response.
   *
   * The abort signal is checked before every read; once it is set the stream
   * is cancelled and a single `{aborted: true}` marker is yielded. The reader
   * lock is released whenever the generator finishes, including when the
   * consumer stops early or a read throws.
   *
   * @param abortController Controller whose signal stops the stream.
   * @param sessionId Session to propagate.
   * @param startFrameIndex Frame index sent as `start_frame_index`.
   * @throws Error when the response is not a `multipart/x-savi-stream` or has
   *   no body.
   */
  private async *_streamMasksForSession(
    abortController: AbortController,
    sessionId: string,
    startFrameIndex: undefined | number = 0,
  ): AsyncGenerator<StreamMasksResult | StreamMasksAbortResult, undefined> {
    const url = `${this._endpoint}/propagate_in_video`;

    const requestBody = {
      session_id: sessionId,
      start_frame_index: startFrameIndex,
    };

    const headers: {[name: string]: string} = {
      'Content-Type': 'application/json',
    };

    const response = await fetch(url, {
      method: 'POST',
      body: JSON.stringify(requestBody),
      headers,
    });

    const contentType = response.headers.get('Content-Type');
    if (
      contentType == null ||
      !contentType.startsWith('multipart/x-savi-stream;')
    ) {
      throw new Error(
        'endpoint needs to support Content-Type "multipart/x-savi-stream"',
      );
    }

    const responseBody = response.body;
    if (responseBody == null) {
      throw new Error('response body is null');
    }

    const reader = multipartStream(contentType, responseBody).getReader();
    const textDecoder = new TextDecoder();

    try {
      while (true) {
        if (abortController.signal.aborted) {
          // Stop consuming the response; a failing cancel is irrelevant here.
          void reader.cancel().catch(() => undefined);
          yield {aborted: true};
          return;
        }

        const {done, value} = await reader.read();
        if (done) {
          return;
        }

        const {headers: partHeaders, body} = value;
        // Parts without a Content-Type are skipped like any non-JSON part.
        const partContentType = partHeaders.get('Content-Type') ?? '';
        if (partContentType.startsWith('application/json')) {
          const jsonResponse = JSON.parse(textDecoder.decode(body));
          const maskResults = jsonResponse.results;
          const rleMaskList = maskResults.map(
            (mask: {object_id: number; mask: RLEObject}) => {
              return {
                objectId: mask.object_id,
                rleMask: mask.mask,
              };
            },
          );
          yield {
            frameIndex: jsonResponse.frame_index,
            rleMaskList,
          };
        }
      }
    } finally {
      // No read is pending at this point, so the lock can always be released.
      reader.releaseLock();
    }
  }

  /**
   * Asks the server to cancel the running propagation of the active session.
   *
   * Rejects when the mutation errors or the server reports `success: false`.
   */
  private async _abortRequest(): Promise<void> {
    const sessionId = this._session.id;
    invariant(sessionId != null, 'session id cannot be empty');

    return new Promise((resolve, reject) => {
      try {
        commitMutation<SAM2ModelCancelPropagateInVideoMutation>(
          this._environment,
          {
            mutation: graphql`
              mutation SAM2ModelCancelPropagateInVideoMutation(
                $input: CancelPropagateInVideoInput!
              ) {
                cancelPropagateInVideo(input: $input) {
                  success
                }
              }
            `,
            variables: {
              input: {
                sessionId,
              },
            },
            onCompleted: response => {
              const {success} = response.cancelPropagateInVideo;
              if (!success) {
                reject(`could not abort session ${sessionId}`);
                return;
              }
              resolve();
            },
            onError: error => {
              Logger.error(error);
              reject(error);
            },
          },
        );
      } catch (error) {
        Logger.error(error);
        reject(error);
      }
    });
  }
}