import os
import argparse
import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm
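
# Example invocation (the script name and paths below are illustrative
# placeholders, not fixed by this file):
#   python resize_lora.py input.safetensors resized.safetensors 16
#   python resize_lora.py input.safetensors resized.safetensors 64 \
#       --device cuda --method randomized_svd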


def resize_lora_model(model_path, output_path, new_dim, device, method):
    """
    Resizes the LoRA dimension of a model using SVD or Randomized SVD.
    Also scales the alpha value(s) proportionally.

    Args:
        model_path (str): Path to the LoRA model to resize.
        output_path (str): Path to save the new resized model.
        new_dim (int): The target new dimension for the LoRA weights.
        device (str): The device to run calculations on ('cuda' or 'cpu').
        method (str): The resizing method to use ('svd' or 'randomized_svd').
    """
    print(f"Loading model from: {model_path}")
    # Load the model onto CPU memory first to avoid VRAM issues with large models
    model = load_file(model_path, device="cpu")
    new_model = {}

    # --- Metadata & Weight Inspection ---
    original_dim = None
    alpha = None
    try:
        with safe_open(model_path, framework="pt", device="cpu") as f:
            metadata = f.metadata()
            if metadata:
                if 'ss_network_dim' in metadata:
                    original_dim = int(metadata['ss_network_dim'])
                    print(f"Original dimension (from metadata): {original_dim}")
                if 'ss_network_alpha' in metadata:
                    alpha = float(metadata['ss_network_alpha'])
                    print(f"Original alpha (from metadata): {alpha}")
    except Exception as e:
        print(f"Could not read metadata: {e}. Dimension and alpha will be inferred.")

    # Infer original_dim from weights if not in metadata
    if original_dim is None:
        for key in model.keys():
            if key.endswith((".lora_down.weight", ".lora_A.weight")):
                original_dim = model[key].shape[0]
                print(f"Inferred original dimension from weights: {original_dim}")
                break

    if original_dim is None:
        print("Error: Could not determine original LoRA dimension.")
        return

    if original_dim == new_dim:
        print("Nothing to do: the new dimension equals the original dimension.")
        return

    # Infer alpha from weights if not in metadata
    if alpha is None:
        for key in model.keys():
            if key.endswith(".alpha"):
                alpha = model[key].item()
                print(f"Inferred alpha from weights: {alpha}")
                break

    # Fallback for alpha if still not found
    if alpha is None:
        alpha = float(original_dim)
        print(f"Alpha not found, falling back to using dimension value: {alpha}")

    # --- Tensor Processing ---
    # Calculate the scaling ratio for alpha
    ratio = new_dim / original_dim
    print(f"Dimension resize ratio: {ratio:.4f}")

    lora_keys_to_process = set()
    for key in model.keys():
        if 'lora_' in key and key.endswith('.weight'):
            base_key = key.split('.lora_')[0]
            lora_keys_to_process.add(base_key)

    if not lora_keys_to_process:
        print("Error: No LoRA weights found in the model.")
        return

    print(f"\nFound {len(lora_keys_to_process)} LoRA modules to resize...")
    print(f"Using '{method}' method for resizing.")

    for base_key in tqdm(sorted(list(lora_keys_to_process)), desc="Resizing modules"):
        try:
            down_key, up_key = None, None

            # Determine the correct key names for down and up weights
            if base_key + ".lora_down.weight" in model:
                down_key = base_key + ".lora_down.weight"
                up_key = base_key + ".lora_up.weight"
            elif base_key + ".lora_A.weight" in model:
                down_key = base_key + ".lora_A.weight"
                up_key = base_key + ".lora_B.weight"
            else:
                continue

            down_weight = model[down_key]
            up_weight = model[up_key]
            original_dtype = up_weight.dtype

            # Move weights to the selected device for processing
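            # (upcast to float32: SVD is unstable or unsupported in half precision)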
            down_weight = down_weight.to(device, dtype=torch.float32)
            up_weight = up_weight.to(device, dtype=torch.float32)

            # Handle both linear and convolutional layers
            conv2d = down_weight.ndim == 4
            if conv2d:
                conv_shape = down_weight.shape
                down_weight = down_weight.flatten(1)
                up_weight = up_weight.flatten(1)

            # Reconstruct the full weight matrix
            full_weight = up_weight @ down_weight
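            # Note: this is the raw up @ down product without the alpha / dim
            # scale factor; that factor is preserved separately by rescaling
            # alpha, so it does not need to enter the SVD.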

            if method == 'svd':
                # --- Full SVD Resizing (Accurate) ---
                # full_matrices=False skips the unused square U/Vh factors;
                # only min(m, n) components are ever needed here
                U, S, Vh = torch.linalg.svd(full_weight, full_matrices=False)

                # Truncate the SVD components to the new dimension; when
                # upscaling, the extra components carry (near-)zero singular
                # values, so the reconstructed product is effectively unchanged
                U = U[:, :new_dim]
                S = S[:new_dim]
                Vh = Vh[:new_dim, :]

                # Distribute singular values (S) back to the new matrices
                # A common practice is to take the square root for balanced distribution
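                # Since W ≈ U·diag(S)·Vh = (U·diag(√S))·(diag(√S)·Vh), the new
                # up/down pair reproduces the truncated SVD product exactly.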
                S_sqrt = torch.sqrt(S)
                new_up = U @ torch.diag(S_sqrt)
                new_down = torch.diag(S_sqrt) @ Vh

            elif method == 'randomized_svd':
                # --- Randomized SVD Resizing (Fast Approximation) ---
                # torch.svd_lowrank returns an approximate rank-q factorization
                U, S, V = torch.svd_lowrank(full_weight, q=new_dim)
                Vh = V.T

                # Distribute singular values like in the full SVD method
                S_sqrt = torch.sqrt(S)
                new_up = U @ torch.diag(S_sqrt)
                new_down = torch.diag(S_sqrt) @ Vh

            else:
                # Guard against unknown methods when called programmatically,
                # rather than failing later with an undefined new_up/new_down
                raise ValueError(f"Unknown resizing method: {method}")

            if conv2d:
                new_down = new_down.reshape(new_dim, conv_shape[1], conv_shape[2], conv_shape[3])
                # The up weight of a conv LoRA is a 1x1 conv; restore its 4D shape too
                new_up = new_up.reshape(new_up.shape[0], new_dim, 1, 1)

            # Store the new resized weights
            new_model[down_key] = new_down.contiguous().to(original_dtype)
            new_model[up_key] = new_up.contiguous().to(original_dtype)

            # If a per-module alpha exists, scale it proportionally
            alpha_key = base_key + ".alpha"
            if alpha_key in model:
                original_alpha_tensor = model[alpha_key]
                # Calculate the new alpha and keep the original dtype
                new_alpha_value = original_alpha_tensor.item() * ratio
                new_model[alpha_key] = torch.tensor(new_alpha_value, dtype=original_alpha_tensor.dtype)

        except Exception as e:
            print(f"Warning: Failed to process {base_key}. Error: {e}")
            continue

    # Copy all non-LoRA tensors from the original model, skipping any
    # per-module alpha that was already rescaled above
    for key, value in model.items():
        if ".lora_" not in key and key not in new_model:
            new_model[key] = value

    # Write fresh metadata with the new dimension and the globally scaled alpha
    # (note: other metadata keys from the original file are not carried over)
    new_metadata = {'ss_network_dim': str(new_dim)}
    new_alpha = alpha * ratio
    new_metadata['ss_network_alpha'] = str(new_alpha)
    print(f"\nNew global alpha scaled to: {new_alpha:.2f}")

    # Move all tensors to CPU before saving
    if device != 'cpu':
        print("\nMoving processed tensors to CPU for saving...")
        for key in tqdm(new_model.keys(), desc="Finalizing"):
            if isinstance(new_model[key], torch.Tensor):
                new_model[key] = new_model[key].cpu()

    print(f"\nSaving resized model to: {output_path}")
    save_file(new_model, output_path, metadata=new_metadata)
    print("Done! 🎉")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Resize a LoRA model to a new dimension and scales alpha proportionally.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("model_path", type=str, help="Path to the LoRA model (.safetensors).")
    parser.add_argument("output_path", type=str, help="Path to save the resized LoRA model.")
    parser.add_argument("new_dim", type=int, help="The new LoRA dimension (rank).")
    parser.add_argument("--device", type=str, default=None,
                        help="Device to use (e.g., 'cpu', 'cuda'). Autodetects if not specified.")
    parser.add_argument(
        "--method",
        type=str,
        default="svd",
        choices=["svd", "randomized_svd"],
        help="""Resizing method:
'svd' (default): Accurate but slower. Uses full SVD for optimal weight preservation.
'randomized_svd': Faster approximation of SVD. Excellent for speed on large models."""
    )

    args = parser.parse_args()

    if args.device:
        device = args.device
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Using device: {device}")

    resize_lora_model(args.model_path, args.output_path, args.new_dim, device, args.method)