Spaces:
Running
Running
jhj0517
commited on
Commit
·
25c9e51
1
Parent(s):
2415a05
add args for local model path
Browse files- app.py +4 -0
- modules/faster_whisper_inference.py +2 -2
- modules/whisper_Inference.py +2 -2
app.py
CHANGED
|
@@ -17,8 +17,10 @@ class App:
|
|
| 17 |
self.app = gr.Blocks(css=CSS, theme=self.args.theme)
|
| 18 |
self.whisper_inf = WhisperInference() if self.args.disable_faster_whisper else FasterWhisperInference()
|
| 19 |
if isinstance(self.whisper_inf, FasterWhisperInference):
|
|
|
|
| 20 |
print("Use Faster Whisper implementation")
|
| 21 |
else:
|
|
|
|
| 22 |
print("Use Open AI Whisper implementation")
|
| 23 |
print(f"Device \"{self.whisper_inf.device}\" is detected")
|
| 24 |
self.nllb_inf = NLLBInference()
|
|
@@ -296,6 +298,8 @@ parser.add_argument('--password', type=str, default=None, help='Gradio authentic
|
|
| 296 |
parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme')
|
| 297 |
parser.add_argument('--colab', type=bool, default=False, nargs='?', const=True, help='Is colab user or not')
|
| 298 |
parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='enable api or not')
|
|
|
|
|
|
|
| 299 |
_args = parser.parse_args()
|
| 300 |
|
| 301 |
if __name__ == "__main__":
|
|
|
|
| 17 |
self.app = gr.Blocks(css=CSS, theme=self.args.theme)
|
| 18 |
self.whisper_inf = WhisperInference() if self.args.disable_faster_whisper else FasterWhisperInference()
|
| 19 |
if isinstance(self.whisper_inf, FasterWhisperInference):
|
| 20 |
+
self.whisper_inf.model_dir = args.faster_whisper_model_dir
|
| 21 |
print("Use Faster Whisper implementation")
|
| 22 |
else:
|
| 23 |
+
self.whisper_inf.model_dir = args.whisper_model_dir
|
| 24 |
print("Use Open AI Whisper implementation")
|
| 25 |
print(f"Device \"{self.whisper_inf.device}\" is detected")
|
| 26 |
self.nllb_inf = NLLBInference()
|
|
|
|
| 298 |
parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme')
|
| 299 |
parser.add_argument('--colab', type=bool, default=False, nargs='?', const=True, help='Is colab user or not')
|
| 300 |
parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='enable api or not')
|
| 301 |
+
parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
|
| 302 |
+
parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
|
| 303 |
_args = parser.parse_args()
|
| 304 |
|
| 305 |
if __name__ == "__main__":
|
modules/faster_whisper_inference.py
CHANGED
|
@@ -32,7 +32,7 @@ class FasterWhisperInference(BaseInterface):
|
|
| 32 |
self.available_compute_types = ctranslate2.get_supported_compute_types(
|
| 33 |
"cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
|
| 34 |
self.current_compute_type = "float16" if self.device == "cuda" else "float32"
|
| 35 |
-
self.
|
| 36 |
|
| 37 |
def transcribe_file(self,
|
| 38 |
files: list,
|
|
@@ -311,7 +311,7 @@ class FasterWhisperInference(BaseInterface):
|
|
| 311 |
self.model = faster_whisper.WhisperModel(
|
| 312 |
device=self.device,
|
| 313 |
model_size_or_path=model_size,
|
| 314 |
-
download_root=
|
| 315 |
compute_type=self.current_compute_type
|
| 316 |
)
|
| 317 |
|
|
|
|
| 32 |
self.available_compute_types = ctranslate2.get_supported_compute_types(
|
| 33 |
"cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
|
| 34 |
self.current_compute_type = "float16" if self.device == "cuda" else "float32"
|
| 35 |
+
self.model_dir = os.path.join("models", "Whisper", "faster-whisper")
|
| 36 |
|
| 37 |
def transcribe_file(self,
|
| 38 |
files: list,
|
|
|
|
| 311 |
self.model = faster_whisper.WhisperModel(
|
| 312 |
device=self.device,
|
| 313 |
model_size_or_path=model_size,
|
| 314 |
+
download_root=self.model_dir,
|
| 315 |
compute_type=self.current_compute_type
|
| 316 |
)
|
| 317 |
|
modules/whisper_Inference.py
CHANGED
|
@@ -26,7 +26,7 @@ class WhisperInference(BaseInterface):
|
|
| 26 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 27 |
self.available_compute_types = ["float16", "float32"]
|
| 28 |
self.current_compute_type = "float16" if self.device == "cuda" else "float32"
|
| 29 |
-
self.
|
| 30 |
|
| 31 |
def transcribe_file(self,
|
| 32 |
files: list,
|
|
@@ -288,7 +288,7 @@ class WhisperInference(BaseInterface):
|
|
| 288 |
self.model = whisper.load_model(
|
| 289 |
name=model_size,
|
| 290 |
device=self.device,
|
| 291 |
-
download_root=
|
| 292 |
)
|
| 293 |
|
| 294 |
@staticmethod
|
|
|
|
| 26 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 27 |
self.available_compute_types = ["float16", "float32"]
|
| 28 |
self.current_compute_type = "float16" if self.device == "cuda" else "float32"
|
| 29 |
+
self.model_dir = os.path.join("models", "Whisper")
|
| 30 |
|
| 31 |
def transcribe_file(self,
|
| 32 |
files: list,
|
|
|
|
| 288 |
self.model = whisper.load_model(
|
| 289 |
name=model_size,
|
| 290 |
device=self.device,
|
| 291 |
+
download_root=self.model_dir
|
| 292 |
)
|
| 293 |
|
| 294 |
@staticmethod
|