Merge pull request #154 from jhj0517/feature/add-local-model-path
Changed files:
- app.py +4 -0
- modules/faster_whisper_inference.py +2 -2
- modules/whisper_Inference.py +2 -2
- user-start-webui.bat +9 -2
app.py CHANGED

@@ -17,8 +17,10 @@ class App:
         self.app = gr.Blocks(css=CSS, theme=self.args.theme)
         self.whisper_inf = WhisperInference() if self.args.disable_faster_whisper else FasterWhisperInference()
         if isinstance(self.whisper_inf, FasterWhisperInference):
+            self.whisper_inf.model_dir = args.faster_whisper_model_dir
             print("Use Faster Whisper implementation")
         else:
+            self.whisper_inf.model_dir = args.whisper_model_dir
             print("Use Open AI Whisper implementation")
         print(f"Device \"{self.whisper_inf.device}\" is detected")
         self.nllb_inf = NLLBInference()

@@ -296,6 +298,8 @@ parser.add_argument('--password', type=str, default=None, help='Gradio authentic
 parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme')
 parser.add_argument('--colab', type=bool, default=False, nargs='?', const=True, help='Is colab user or not')
 parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='enable api or not')
+parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
+parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
 _args = parser.parse_args()
 
 if __name__ == "__main__":
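The two new flags make the model download directories configurable from the command line, with defaults under the repository's models/ directory. A minimal standalone sketch of how they parse (not the full parser from app.py; the /data/whisper-models override is purely illustrative):

import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument('--whisper_model_dir', type=str,
                    default=os.path.join("models", "Whisper"),
                    help='Directory path of the whisper model')
parser.add_argument('--faster_whisper_model_dir', type=str,
                    default=os.path.join("models", "Whisper", "faster-whisper"),
                    help='Directory path of the faster-whisper model')

# Hypothetical invocation: override one directory, leave the other at its default.
args = parser.parse_args(['--whisper_model_dir', '/data/whisper-models'])
print(args.whisper_model_dir)         # /data/whisper-models
print(args.faster_whisper_model_dir)  # models/Whisper/faster-whisper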
modules/faster_whisper_inference.py CHANGED

@@ -32,7 +32,7 @@ class FasterWhisperInference(BaseInterface):
         self.available_compute_types = ctranslate2.get_supported_compute_types(
             "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
         self.current_compute_type = "float16" if self.device == "cuda" else "float32"
-        self.
+        self.model_dir = os.path.join("models", "Whisper", "faster-whisper")
 
     def transcribe_file(self,
                         files: list,

@@ -311,7 +311,7 @@ class FasterWhisperInference(BaseInterface):
         self.model = faster_whisper.WhisperModel(
             device=self.device,
             model_size_or_path=model_size,
-            download_root=
+            download_root=self.model_dir,
             compute_type=self.current_compute_type
         )
 
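download_root is the faster_whisper.WhisperModel parameter that decides where model files are downloaded and cached, and it is what self.model_dir now feeds. A minimal standalone sketch of the resulting call; the "base" model size, CPU settings, and audio path are placeholders:

import os
import faster_whisper

# Mirrors the patched call, with the PR's new default directory.
model = faster_whisper.WhisperModel(
    model_size_or_path="base",   # placeholder size
    device="cpu",
    compute_type="float32",
    download_root=os.path.join("models", "Whisper", "faster-whisper"),
)
segments, info = model.transcribe("audio.mp3")  # placeholder file
for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")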
modules/whisper_Inference.py CHANGED

@@ -26,7 +26,7 @@ class WhisperInference(BaseInterface):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.available_compute_types = ["float16", "float32"]
         self.current_compute_type = "float16" if self.device == "cuda" else "float32"
-        self.
+        self.model_dir = os.path.join("models", "Whisper")
 
     def transcribe_file(self,
                         files: list,

@@ -288,7 +288,7 @@ class WhisperInference(BaseInterface):
         self.model = whisper.load_model(
             name=model_size,
             device=self.device,
-            download_root=
+            download_root=self.model_dir
         )
 
     @staticmethod
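The vanilla backend gets the symmetric treatment: whisper.load_model() also accepts a download_root for its checkpoint cache. A minimal sketch, again with placeholder model size and audio path:

import os
import whisper

# Mirrors the patched call, with the PR's new default directory.
model = whisper.load_model(
    name="base",   # placeholder size
    device="cpu",
    download_root=os.path.join("models", "Whisper"),
)
result = model.transcribe("audio.mp3")  # placeholder file
print(result["text"])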
user-start-webui.bat CHANGED

@@ -10,9 +10,10 @@ set SHARE=
 set THEME=
 set DISABLE_FASTER_WHISPER=
 set API_OPEN=
+set WHISPER_MODEL_DIR=
+set FASTER_WHISPER_MODEL_DIR=
 
 
-:: Set args accordingly
 if not "%SERVER_NAME%"=="" (
     set SERVER_NAME_ARG=--server_name %SERVER_NAME%
 )

@@ -37,7 +38,13 @@ if /I "%DISABLE_FASTER_WHISPER%"=="true" (
 if /I "%API_OPEN%"=="true" (
     set API_OPEN=--api_open
 )
+if not "%WHISPER_MODEL_DIR%"=="" (
+    set WHISPER_MODEL_DIR_ARG=--whisper_model_dir "%WHISPER_MODEL_DIR%"
+)
+if not "%FASTER_WHISPER_MODEL_DIR%"=="" (
+    set FASTER_WHISPER_MODEL_DIR_ARG=--faster_whisper_model_dir "%FASTER_WHISPER_MODEL_DIR%"
+)
 
 :: Call the original .bat script with optional arguments
-start-webui.bat %SERVER_NAME_ARG% %SERVER_PORT_ARG% %USERNAME_ARG% %PASSWORD_ARG% %SHARE_ARG% %THEME_ARG% %DISABLE_FASTER_WHISPER_ARG% %API_OPEN%
+start-webui.bat %SERVER_NAME_ARG% %SERVER_PORT_ARG% %USERNAME_ARG% %PASSWORD_ARG% %SHARE_ARG% %THEME_ARG% %DISABLE_FASTER_WHISPER_ARG% %API_OPEN% %WHISPER_MODEL_DIR_ARG% %FASTER_WHISPER_MODEL_DIR_ARG%
 pause
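The script only forwards a flag when its variable is non-empty, so unset variables fall through to the argparse defaults added in app.py. The same conditional-forwarding pattern sketched in Python, assuming app.py as the eventual entry point; the D:\models path is illustrative:

# Hypothetical values; in user-start-webui.bat these come from the set lines.
WHISPER_MODEL_DIR = r"D:\models\whisper"
FASTER_WHISPER_MODEL_DIR = ""  # empty -> flag omitted, default directory is used

cmd = ["python", "app.py"]
if WHISPER_MODEL_DIR:
    cmd += ["--whisper_model_dir", WHISPER_MODEL_DIR]
if FASTER_WHISPER_MODEL_DIR:
    cmd += ["--faster_whisper_model_dir", FASTER_WHISPER_MODEL_DIR]

# Only the explicitly set flag makes it onto the command line:
print(cmd)  # ['python', 'app.py', '--whisper_model_dir', 'D:\\models\\whisper']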