Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	update
Browse files
    	
        examples/add_punctuation/add_punctuation.py
    ADDED
    
    | @@ -0,0 +1,43 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/python3
         | 
| 2 | 
            +
            # -*- coding: utf-8 -*-
         | 
| 3 | 
            +
            import argparse
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import sherpa_onnx
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from project_settings import project_path
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            def get_args():
         | 
| 11 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 12 | 
            +
                parser.add_argument(
         | 
| 13 | 
            +
                    "--model_file",
         | 
| 14 | 
            +
                    default=(project_path / "pretrained_models/huggingface/csukuangfj/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx").as_posix(),
         | 
| 15 | 
            +
                    type=str
         | 
| 16 | 
            +
                )
         | 
| 17 | 
            +
                parser.add_argument(
         | 
| 18 | 
            +
                    "--text",
         | 
| 19 | 
            +
                    default="i'm a google virtual assistant recording this call for the person you're trying to reach before i try to connect you can ask what you're calling about",
         | 
| 20 | 
            +
                    type=str
         | 
| 21 | 
            +
                )
         | 
| 22 | 
            +
                args = parser.parse_args()
         | 
| 23 | 
            +
                return args
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            def main():
         | 
| 27 | 
            +
                args = get_args()
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                config = sherpa_onnx.OfflinePunctuationConfig(
         | 
| 30 | 
            +
                    model=sherpa_onnx.OfflinePunctuationModelConfig(
         | 
| 31 | 
            +
                        ct_transformer=args.model_file
         | 
| 32 | 
            +
                    ),
         | 
| 33 | 
            +
                )
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                punctuation_model = sherpa_onnx.OfflinePunctuation(config)
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                text = punctuation_model.add_punctuation(args.text)
         | 
| 38 | 
            +
                print("text: {}".format(text))
         | 
| 39 | 
            +
                return
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            if __name__ == '__main__':
         | 
| 43 | 
            +
                main()
         | 
    	
        examples/add_punctuation/download_model.py
    ADDED
    
    | @@ -0,0 +1,58 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/python3
         | 
| 2 | 
            +
            # -*- coding: utf-8 -*-
         | 
| 3 | 
            +
            import argparse
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            from pathlib import Path
         | 
| 6 | 
            +
            import sys
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            pwd = os.path.abspath(os.path.dirname(__file__))
         | 
| 9 | 
            +
            sys.path.append(os.path.join(pwd, "../../"))
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import huggingface_hub
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from project_settings import project_path
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            def get_args():
         | 
| 17 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                parser.add_argument(
         | 
| 20 | 
            +
                    "--repo_id",
         | 
| 21 | 
            +
                    default="csukuangfj/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12",
         | 
| 22 | 
            +
                    type=str
         | 
| 23 | 
            +
                )
         | 
| 24 | 
            +
                parser.add_argument("--model_filename", default="model.onnx", type=str)
         | 
| 25 | 
            +
                parser.add_argument("--model_sub_folder", default=".", type=str)
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                parser.add_argument(
         | 
| 28 | 
            +
                    "--pretrained_model_dir",
         | 
| 29 | 
            +
                    default=(project_path / "pretrained_models").as_posix(),
         | 
| 30 | 
            +
                    type=str
         | 
| 31 | 
            +
                )
         | 
| 32 | 
            +
                args = parser.parse_args()
         | 
| 33 | 
            +
                return args
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            def main():
         | 
| 37 | 
            +
                args = get_args()
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                pretrained_model_dir = Path(args.pretrained_model_dir)
         | 
| 40 | 
            +
                pretrained_model_dir.mkdir(exist_ok=True)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                repo_id: Path = Path(args.repo_id)
         | 
| 43 | 
            +
                local_model_dir = pretrained_model_dir / "huggingface" / repo_id
         | 
| 44 | 
            +
                local_model_dir.mkdir(parents=True, exist_ok=True)
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                print("download model")
         | 
| 47 | 
            +
                model_filename = huggingface_hub.hf_hub_download(
         | 
| 48 | 
            +
                    repo_id=args.repo_id,
         | 
| 49 | 
            +
                    filename=args.model_filename,
         | 
| 50 | 
            +
                    subfolder=args.model_sub_folder,
         | 
| 51 | 
            +
                    local_dir=local_model_dir.as_posix(),
         | 
| 52 | 
            +
                )
         | 
| 53 | 
            +
                print(model_filename)
         | 
| 54 | 
            +
                return
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            if __name__ == "__main__":
         | 
| 58 | 
            +
                main()
         | 
    	
        examples/gradio_client/{predict.py → asr.py}
    RENAMED
    
    | 
            File without changes
         | 
    	
        examples/gradio_client/whisper_large_v3.py
    ADDED
    
    | @@ -0,0 +1,37 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/python3
         | 
| 2 | 
            +
            # -*- coding: utf-8 -*-
         | 
| 3 | 
            +
            import argparse
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from gradio_client import Client, file
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from project_settings import project_path
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            def get_args():
         | 
| 11 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 12 | 
            +
                parser.add_argument(
         | 
| 13 | 
            +
                    "--filename",
         | 
| 14 | 
            +
                    default=(project_path / "data/test_wavs/paraformer-zh/si_chuan_hua.wav").as_posix(),
         | 
| 15 | 
            +
                    type=str
         | 
| 16 | 
            +
                )
         | 
| 17 | 
            +
                args = parser.parse_args()
         | 
| 18 | 
            +
                return args
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def main():
         | 
| 22 | 
            +
                args = get_args()
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                filename = args.filename
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                client = Client("hf-audio/whisper-large-v3")
         | 
| 27 | 
            +
                result = client.predict(
         | 
| 28 | 
            +
                    inputs=file(filename),
         | 
| 29 | 
            +
                    task="transcribe",
         | 
| 30 | 
            +
                    api_name="/predict"
         | 
| 31 | 
            +
                )
         | 
| 32 | 
            +
                print(result)
         | 
| 33 | 
            +
                return
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            if __name__ == '__main__':
         | 
| 37 | 
            +
                main()
         | 
    	
        examples/wenet/toolbox_download.py
    DELETED
    
    | 
            File without changes
         | 
    	
        main.py
    CHANGED
    
    | @@ -148,10 +148,20 @@ def process( | |
| 148 | 
             
                                                   filename=out_filename.as_posix(),
         | 
| 149 | 
             
                                                   )
         | 
| 150 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 151 | 
             
                date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
         | 
| 152 | 
             
                end = time.time()
         | 
| 153 |  | 
| 154 | 
            -
                # statistics
         | 
| 155 | 
             
                metadata = torchaudio.info(out_filename.as_posix())
         | 
| 156 | 
             
                duration = metadata.num_frames / 16000
         | 
| 157 | 
             
                rtf = (end - start) / duration
         | 
|  | |
| 148 | 
             
                                                   filename=out_filename.as_posix(),
         | 
| 149 | 
             
                                                   )
         | 
| 150 |  | 
| 151 | 
            +
                # load_punctuation_model
         | 
| 152 | 
            +
                if add_punctuation == "Yes":
         | 
| 153 | 
            +
                    local_model_dir = pretrained_model_dir / "huggingface" / md5_encrypt("csukuangfj/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12")
         | 
| 154 | 
            +
                    punctuation_model = nn_models.load_punctuation_model(
         | 
| 155 | 
            +
                        local_model_dir=local_model_dir,
         | 
| 156 | 
            +
                        nn_model_file="model.onnx",
         | 
| 157 | 
            +
                        nn_model_file_sub_folder=".",
         | 
| 158 | 
            +
                    )
         | 
| 159 | 
            +
                    text = punctuation_model.add_punctuation(text)
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                # statistics
         | 
| 162 | 
             
                date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
         | 
| 163 | 
             
                end = time.time()
         | 
| 164 |  | 
|  | |
| 165 | 
             
                metadata = torchaudio.info(out_filename.as_posix())
         | 
| 166 | 
             
                duration = metadata.num_frames / 16000
         | 
| 167 | 
             
                rtf = (end - start) / duration
         | 
    	
        toolbox/k2_sherpa/nn_models.py
    CHANGED
    
    | @@ -764,7 +764,7 @@ def load_sherpa_onnx_online_recognizer_from_paraformer(encoder_model_file: str, | |
| 764 | 
             
            def load_recognizer(local_model_dir: Path,
         | 
| 765 | 
             
                                decoding_method: str = "greedy_search",
         | 
| 766 | 
             
                                num_active_paths: int = 4,
         | 
| 767 | 
            -
                                **kwargs
         | 
| 768 | 
             
                                ):
         | 
| 769 | 
             
                if not local_model_dir.exists():
         | 
| 770 | 
             
                    download_model(
         | 
| @@ -839,5 +839,29 @@ def load_recognizer(local_model_dir: Path, | |
| 839 | 
             
                return recognizer
         | 
| 840 |  | 
| 841 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 842 | 
             
            if __name__ == "__main__":
         | 
| 843 | 
             
                pass
         | 
|  | |
| 764 | 
             
            def load_recognizer(local_model_dir: Path,
         | 
| 765 | 
             
                                decoding_method: str = "greedy_search",
         | 
| 766 | 
             
                                num_active_paths: int = 4,
         | 
| 767 | 
            +
                                **kwargs,
         | 
| 768 | 
             
                                ):
         | 
| 769 | 
             
                if not local_model_dir.exists():
         | 
| 770 | 
             
                    download_model(
         | 
|  | |
| 839 | 
             
                return recognizer
         | 
| 840 |  | 
| 841 |  | 
| 842 | 
            +
            def load_punctuation_model(local_model_dir: Path,
         | 
| 843 | 
            +
                                       nn_model_file: str,
         | 
| 844 | 
            +
                                       nn_model_file_sub_folder: str,
         | 
| 845 | 
            +
                                       ):
         | 
| 846 | 
            +
                if not local_model_dir.exists():
         | 
| 847 | 
            +
                    download_model(
         | 
| 848 | 
            +
                        local_model_dir=local_model_dir.as_posix(),
         | 
| 849 | 
            +
                        nn_model_file=nn_model_file,
         | 
| 850 | 
            +
                        nn_model_file_sub_folder=nn_model_file_sub_folder,
         | 
| 851 | 
            +
                    )
         | 
| 852 | 
            +
             | 
| 853 | 
            +
                nn_model_file = (local_model_dir / nn_model_file_sub_folder / nn_model_file).as_posix()
         | 
| 854 | 
            +
             | 
| 855 | 
            +
                config = sherpa_onnx.OfflinePunctuationConfig(
         | 
| 856 | 
            +
                    model=sherpa_onnx.OfflinePunctuationModelConfig(
         | 
| 857 | 
            +
                        ct_transformer=nn_model_file
         | 
| 858 | 
            +
                    ),
         | 
| 859 | 
            +
                )
         | 
| 860 | 
            +
             | 
| 861 | 
            +
                punctuation_model = sherpa_onnx.OfflinePunctuation(config)
         | 
| 862 | 
            +
             | 
| 863 | 
            +
                return punctuation_model
         | 
| 864 | 
            +
             | 
| 865 | 
            +
             | 
| 866 | 
             
            if __name__ == "__main__":
         | 
| 867 | 
             
                pass
         | 
