"""Overwrite selected tensors in ``model*.safetensors`` shards with values from an ``.npz`` file.

For every shard matching ``SHARD_PATTERN``, each tensor whose name is also a key
in the patch archive is replaced (cast to the shard tensor's dtype). The result
is written next to the original as ``patched_<shard name>``.

NOTE: only the shard files themselves are rewritten; any separate shard index
file (if the checkpoint has one) is not created or updated for the new names.
"""

import sys
from glob import glob
from pathlib import Path

import numpy as np
import torch

PATCH_FILE = "589-20240113-071533.npz"
SHARD_PATTERN = "model*.safetensors"


def patch_tensors(weights, patches):
    """Replace entries of *weights* in place with matching arrays from *patches*.

    Args:
        weights: mapping of tensor name -> ``torch.Tensor`` (mutated in place).
        patches: mapping of tensor name -> array-like, e.g. an open ``NpzFile``.

    Returns:
        The names that were replaced, in ``weights`` iteration order.

    Raises:
        ValueError: if a patch array's shape differs from the tensor it replaces.
    """
    # Collect names first so the dict is not mutated while being iterated.
    names = [name for name in weights if name in patches]
    for name in names:
        original = weights[name]
        replacement = torch.from_numpy(np.asarray(patches[name]))
        if tuple(replacement.shape) != tuple(original.shape):
            raise ValueError(
                f"shape mismatch for {name}: checkpoint has {tuple(original.shape)}, "
                f"patch has {tuple(replacement.shape)}"
            )
        # Keep the checkpoint's dtype (avoids silently up/down-casting the model), and
        # force contiguity: Fortran-ordered .npz arrays give non-contiguous tensors,
        # which safetensors refuses to serialize.
        weights[name] = replacement.to(dtype=original.dtype).contiguous()
    return names


def main(patch_path=PATCH_FILE, pattern=SHARD_PATTERN):
    """Patch every shard matching *pattern* with tensors from *patch_path*."""
    # Imported here so patch_tensors() is usable without safetensors installed.
    from safetensors.torch import load_file, save_file

    # Context manager closes the underlying zip archive when done.
    with np.load(patch_path) as patches:
        unused = set(patches.files)
        for file in sorted(glob(pattern)):
            print(f"{file=}")
            weights = load_file(file)
            for name in patch_tensors(weights, patches):
                print(f"patching {name}")
                unused.discard(name)
            shard = Path(file)
            # Prefix only the file name, so paths containing directories stay valid.
            out_path = shard.with_name("patched_" + shard.name)
            save_file(weights, str(out_path), metadata={"format": "pt"})

    if unused:
        print(
            f"warning: {len(unused)} patch tensor(s) matched no shard: {sorted(unused)}",
            file=sys.stderr,
        )


if __name__ == "__main__":
    main()