initial import
Browse files- .gitignore +1 -0
 - .vscode/settings.json +4 -0
 - app.py +54 -0
 - convert.py +112 -0
 - requirements.txt +2 -0
 
    	
        .gitignore
    ADDED
    
    | 
         @@ -0,0 +1 @@ 
     | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            .env/
         
     | 
    	
        .vscode/settings.json
    ADDED
    
    | 
         @@ -0,0 +1,4 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {
         
     | 
| 2 | 
         
            +
            	"editor.formatOnSave": true,
         
     | 
| 3 | 
         
            +
            	"python.formatting.provider": "black"
         
     | 
| 4 | 
         
            +
            }
         
     | 
    	
        app.py
    ADDED
    
    | 
         @@ -0,0 +1,54 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import gradio as gr
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            from convert import convert
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            def run(token: str, model_id: str) -> str:
         
     | 
| 7 | 
         
            +
                if token == "" or model_id == "":
         
     | 
| 8 | 
         
            +
                    return """
         
     | 
| 9 | 
         
            +
                    ### Invalid input 🐞
         
     | 
| 10 | 
         
            +
                    
         
     | 
| 11 | 
         
            +
                    Please fill a token and model_id.
         
     | 
| 12 | 
         
            +
                    """
         
     | 
| 13 | 
         
            +
                try:
         
     | 
| 14 | 
         
            +
                    pr_url = convert(token=token, model_id=model_id)
         
     | 
| 15 | 
         
            +
                    return f"""
         
     | 
| 16 | 
         
            +
                    ### Success 🔥
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
                    Yay! This model was successfully converted and a PR was open using your token, here:
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
                    {pr_url}
         
     | 
| 21 | 
         
            +
                    """
         
     | 
| 22 | 
         
            +
                except Exception as e:
         
     | 
| 23 | 
         
            +
                    return f"""
         
     | 
| 24 | 
         
            +
                    ### Error 😢😢😢
         
     | 
| 25 | 
         
            +
                    
         
     | 
| 26 | 
         
            +
                    {e}
         
     | 
| 27 | 
         
            +
                    """
         
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            DESCRIPTION = """
         
     | 
| 31 | 
         
            +
            The steps are the following:
         
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
            - Paste a read-access token from hf.co/settings/tokens. Read access is enough given that we will open a PR against the source repo.
         
     | 
| 34 | 
         
            +
            - Input a model id from the Hub
         
     | 
| 35 | 
         
            +
            - Click "Submit"
         
     | 
| 36 | 
         
            +
            - That's it! You'll get feedback if it works or not, and if it worked, you'll get the URL of the opened PR 🔥
         
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
            ⚠️ For now only `pytorch_model.bin` files are supported but we'll extend in the future.
         
     | 
| 39 | 
         
            +
            """
         
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
            demo = gr.Interface(
         
     | 
| 42 | 
         
            +
                title="Convert any model to Safetensors and open a PR",
         
     | 
| 43 | 
         
            +
                description=DESCRIPTION,
         
     | 
| 44 | 
         
            +
                allow_flagging="never",
         
     | 
| 45 | 
         
            +
                article="Check out the [Safetensors repo on GitHub](https://github.com/huggingface/safetensors)",
         
     | 
| 46 | 
         
            +
                inputs=[
         
     | 
| 47 | 
         
            +
                    gr.Text(max_lines=1, label="your_hf_token"),
         
     | 
| 48 | 
         
            +
                    gr.Text(max_lines=1, label="model_id"),
         
     | 
| 49 | 
         
            +
                ],
         
     | 
| 50 | 
         
            +
                outputs=[gr.Markdown(label="output")],
         
     | 
| 51 | 
         
            +
                fn=run,
         
     | 
| 52 | 
         
            +
            )
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
            demo.launch()
         
     | 
    	
        convert.py
    ADDED
    
    | 
         @@ -0,0 +1,112 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import argparse
         
     | 
| 2 | 
         
            +
            import json
         
     | 
| 3 | 
         
            +
            import os
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            import torch
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download
         
     | 
| 8 | 
         
            +
            from safetensors.torch import save_file
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            def rename(pt_filename) -> str:
         
     | 
| 12 | 
         
            +
                local = pt_filename.replace(".bin", ".safetensors")
         
     | 
| 13 | 
         
            +
                local = local.replace("pytorch_model", "model")
         
     | 
| 14 | 
         
            +
                return local
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
            def convert_multi(model_id) -> str:
         
     | 
| 18 | 
         
            +
                local_filenames = []
         
     | 
| 19 | 
         
            +
                try:
         
     | 
| 20 | 
         
            +
                    filename = hf_hub_download(
         
     | 
| 21 | 
         
            +
                        repo_id=model_id, filename="pytorch_model.bin.index.json"
         
     | 
| 22 | 
         
            +
                    )
         
     | 
| 23 | 
         
            +
                    with open(filename, "r") as f:
         
     | 
| 24 | 
         
            +
                        data = json.load(f)
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
                    filenames = set(data["weight_map"].values())
         
     | 
| 27 | 
         
            +
                    for filename in filenames:
         
     | 
| 28 | 
         
            +
                        cached_filename = hf_hub_download(repo_id=model_id, filename=filename)
         
     | 
| 29 | 
         
            +
                        loaded = torch.load(cached_filename)
         
     | 
| 30 | 
         
            +
                        local = rename(filename)
         
     | 
| 31 | 
         
            +
                        save_file(loaded, local, metadata={"format": "pt"})
         
     | 
| 32 | 
         
            +
                        local_filenames.append(local)
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
                    index = "model.safetensors.index.json"
         
     | 
| 35 | 
         
            +
                    with open(index, "w") as f:
         
     | 
| 36 | 
         
            +
                        newdata = {k: v for k, v in data.items()}
         
     | 
| 37 | 
         
            +
                        newmap = {k: rename(v) for k, v in data["weight_map"].items()}
         
     | 
| 38 | 
         
            +
                        newdata["weight_map"] = newmap
         
     | 
| 39 | 
         
            +
                        json.dump(newdata, f)
         
     | 
| 40 | 
         
            +
                    local_filenames.append(index)
         
     | 
| 41 | 
         
            +
             
     | 
| 42 | 
         
            +
                    api = HfApi()
         
     | 
| 43 | 
         
            +
                    operations = [
         
     | 
| 44 | 
         
            +
                        CommitOperationAdd(path_in_repo=local, path_or_fileobj=local)
         
     | 
| 45 | 
         
            +
                        for local in local_filenames
         
     | 
| 46 | 
         
            +
                    ]
         
     | 
| 47 | 
         
            +
                    return api.create_commit(
         
     | 
| 48 | 
         
            +
                        repo_id=model_id,
         
     | 
| 49 | 
         
            +
                        operations=operations,
         
     | 
| 50 | 
         
            +
                        commit_message="Adding `safetensors` variant of this model",
         
     | 
| 51 | 
         
            +
                        create_pr=True,
         
     | 
| 52 | 
         
            +
                    )
         
     | 
| 53 | 
         
            +
                finally:
         
     | 
| 54 | 
         
            +
                    for local in local_filenames:
         
     | 
| 55 | 
         
            +
                        os.remove(local)
         
     | 
| 56 | 
         
            +
             
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
            def convert_single(model_id) -> str:
         
     | 
| 59 | 
         
            +
                local = "model.safetensors"
         
     | 
| 60 | 
         
            +
                try:
         
     | 
| 61 | 
         
            +
                    filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
         
     | 
| 62 | 
         
            +
                    loaded = torch.load(filename)
         
     | 
| 63 | 
         
            +
                    save_file(loaded, local, metadata={"format": "pt"})
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
                    api = HfApi()
         
     | 
| 66 | 
         
            +
             
     | 
| 67 | 
         
            +
                    return api.upload_file(
         
     | 
| 68 | 
         
            +
                        path_or_fileobj=local,
         
     | 
| 69 | 
         
            +
                        create_pr=True,
         
     | 
| 70 | 
         
            +
                        path_in_repo=local,
         
     | 
| 71 | 
         
            +
                        repo_id=model_id,
         
     | 
| 72 | 
         
            +
                    )
         
     | 
| 73 | 
         
            +
                finally:
         
     | 
| 74 | 
         
            +
                    os.remove(local)
         
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
            def convert(token: str, model_id: str) -> str:
         
     | 
| 78 | 
         
            +
                """
         
     | 
| 79 | 
         
            +
                returns url to the PR
         
     | 
| 80 | 
         
            +
                """
         
     | 
| 81 | 
         
            +
                api = HfApi(token=token)
         
     | 
| 82 | 
         
            +
                info = api.model_info(model_id)
         
     | 
| 83 | 
         
            +
                filenames = set(s.rfilename for s in info.siblings)
         
     | 
| 84 | 
         
            +
                if "pytorch_model.bin" in filenames:
         
     | 
| 85 | 
         
            +
                    return convert_single(model_id)
         
     | 
| 86 | 
         
            +
                elif "pytorch_model.bin.index.json" in filenames:
         
     | 
| 87 | 
         
            +
                    return convert_multi(model_id)
         
     | 
| 88 | 
         
            +
                raise ValueError("repo does not seem to have a pytorch_model in it")
         
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
             
     | 
| 91 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 92 | 
         
            +
                DESCRIPTION = """
         
     | 
| 93 | 
         
            +
                Simple utility tool to convert automatically some weights on the hub to `safetensors` format.
         
     | 
| 94 | 
         
            +
                It is PyTorch exclusive for now.
         
     | 
| 95 | 
         
            +
                It works by downloading the weights (PT), converting them locally, and uploading them back
         
     | 
| 96 | 
         
            +
                as a PR on the hub.
         
     | 
| 97 | 
         
            +
                """
         
     | 
| 98 | 
         
            +
                parser = argparse.ArgumentParser(description=DESCRIPTION)
         
     | 
| 99 | 
         
            +
                parser.add_argument(
         
     | 
| 100 | 
         
            +
                    "model_id",
         
     | 
| 101 | 
         
            +
                    type=str,
         
     | 
| 102 | 
         
            +
                    help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
         
     | 
| 103 | 
         
            +
                )
         
     | 
| 104 | 
         
            +
                args = parser.parse_args()
         
     | 
| 105 | 
         
            +
                model_id = args.model_id
         
     | 
| 106 | 
         
            +
                api = HfApi()
         
     | 
| 107 | 
         
            +
                info = api.model_info(model_id)
         
     | 
| 108 | 
         
            +
                filenames = set(s.rfilename for s in info.siblings)
         
     | 
| 109 | 
         
            +
                if "pytorch_model.bin" in filenames:
         
     | 
| 110 | 
         
            +
                    convert_single(model_id)
         
     | 
| 111 | 
         
            +
                else:
         
     | 
| 112 | 
         
            +
                    convert_multi(model_id)
         
     | 
    	
        requirements.txt
    ADDED
    
    | 
         @@ -0,0 +1,2 @@ 
     | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            git+https://github.com/huggingface/huggingface_hub@main
         
     | 
| 2 | 
         
            +
            safetensors
         
     |