diff --git a/onnx/builder.py b/onnx/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..fe7a1b97a263cc0b119987634c21ca724e48a7d8 --- /dev/null +++ b/onnx/builder.py @@ -0,0 +1,628 @@ +import argparse +import numpy as np +import onnx +import onnxruntime as ort +import onnxscript +import os +import requests +import shutil +import soundfile +import subprocess +import sys +import torch + +from onnx import helper, numpy_helper, TensorProto +from onnxruntime_genai.models.builder import create_model +from onnxruntime.transformers.dynamo_onnx_helper import DynamoOnnxHelper +from onnxscript import ir +from PIL import Image +from transformers import AutoConfig, AutoProcessor, AutoModelForCausalLM + + +def build_vision(args): + # Many images: + prompt = f"{user_prompt}<|image_1|>\n<|image_2|>\n<|image_3|>\n<|image_4|>\nWhat is shown in these four images?{prompt_suffix}{assistant_prompt}" + url = "https://www.ilankelman.org/stopsigns/australia.jpg" + image_1 = Image.open(requests.get(url, stream=True).raw) + url = "https://img.freepik.com/free-photo/painting-mountain-lake-with-mountain-background_188544-9126.jpg?w=2000" + image_2 = Image.open(requests.get(url, stream=True).raw) + url = "https://th.bing.com/th/id/OIP.gCvQ1vmPVJmrq1nnzM3ZHQHaEo?rs=1&pid=ImgDetMain" + image_3 = Image.open(requests.get(url, stream=True).raw) + url = "https://wallpaper.dog/large/10809054.jpg" + image_4 = Image.open(requests.get(url, stream=True).raw) + images = [image_1, image_2, image_3, image_4] + inputs = processor(prompt, images=images, return_tensors="pt").to(args.execution_provider.replace("dml", "cuda")) + inputs["input_image_embeds"] = inputs["input_image_embeds"].to(args.precision) + inputs["image_attention_mask"] = inputs["image_attention_mask"].to(args.precision) + + # TorchScript export + dummy_inputs = ( + inputs["input_image_embeds"], # image_embeds: torch.FloatTensor + inputs["image_attention_mask"], # image_attention_mask: torch.FloatTensor + inputs["image_sizes"], # image_sizes: torch.LongTensor + ) + dynamic_axes = { + "pixel_values": {0: "num_images", 1: "max_num_crops", 3: "height", 4: "width"}, + "image_attention_mask": {0: "num_images", 1: "max_num_crops"}, + "image_sizes": {0: "num_images"}, + "image_features": {0: "num_image_tokens"}, + } + filename = "phi-4-mm-vision.onnx" + + temp_folder_1 = os.path.join(args.output, "vision_init_export") + os.makedirs(temp_folder_1, exist_ok=True) + + fpath_1 = os.path.join(temp_folder_1, filename) + torch.onnx.export( + model.model.embed_tokens_extend.image_embed, + args=dummy_inputs, + f=fpath_1, + export_params=True, + input_names=["pixel_values", "image_attention_mask", "image_sizes"], + output_names=["image_features"], + dynamic_axes=dynamic_axes, + opset_version=14, + do_constant_folding=True, + ) + + onnx.checker.check_model(fpath_1) + onnx.shape_inference.infer_shapes_path(fpath_1) + onnx_model = onnx.load_model(fpath_1, load_external_data=True) + + temp_folder_2 = os.path.join(args.output, "vision_after_export") + os.makedirs(temp_folder_2, exist_ok=True) + + fpath_2 = os.path.join(temp_folder_2, filename) + onnx.save_model( + onnx_model, + fpath_2, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=f"{filename}.data", + size_threshold=0, + convert_attribute=False, + ) + shutil.rmtree(temp_folder_1) + + # ORT transformer optimizer + temp_folder_3 = os.path.join(args.output, "vision_after_opt") + fpath_3 = os.path.join(temp_folder_3, filename) + subprocess.run( + [ + f"{sys.executable}", "-m", 
"onnxruntime.transformers.optimizer", + "--input", fpath_2, + "--output", fpath_3, + "--model_type", "clip", + "--num_heads", str(16), + "--hidden_size", str(1152), + "--use_external_data_format", + "--opt_level", str(0), + "--disable_shape_inference", + ] + ) + shutil.rmtree(temp_folder_2) + + # ORT 4-bits quantizer + fpath_4 = os.path.join(args.output, filename) + cmd = [ + f"{sys.executable}", "-m", "onnxruntime.quantization.matmul_4bits_quantizer", + "--input_model", fpath_3, + "--output_model", fpath_4, + "--block_size", str(32), + ] + if args.precision == torch.float32: cmd.extend(["--accuracy_level", str(4)]) + subprocess.run(cmd) + shutil.rmtree(temp_folder_3) + + +def build_speech(args): + # Speech file: + prompt = f"{user_prompt}<|audio_1|>\n<|audio_2|>\nWhat are the stories that these audios come from?{prompt_suffix}{assistant_prompt}" + audio1 = soundfile.read(os.path.join(args.input, "examples", "1272-128104-0004.wav")) + audio2 = soundfile.read(os.path.join(args.input, "examples", "1272-128104-0009.wav")) + inputs = processor(prompt, audios=[audio1, audio2], return_tensors="pt").to(args.execution_provider.replace("dml", "cuda")) + inputs["input_audio_embeds"] = inputs["input_audio_embeds"].to(args.precision) + + # TorchScript export + dummy_inputs = ( + inputs["input_audio_embeds"], # audio_embeds: torch.FloatTensor + inputs["audio_attention_mask"], # audio_attention_mask: torch.BoolTensor + inputs["audio_embed_sizes"], # audio_sizes: torch.FloatTensor + inputs["input_mode"], # audio_projection_mode: int + ) + dynamic_axes = { + "audio_embeds": {0: "num_audios", 1: "num_frames", 2: "feature_size"}, + "audio_attention_mask": {0: "num_audios", 1: "num_frames"}, + "audio_sizes": {0: "num_audios"}, + "audio_features": {0: "num_audio_tokens"}, + } + filename = "phi-4-mm-speech.onnx" + + temp_folder_1 = os.path.join(args.output, "speech_init_export") + os.makedirs(temp_folder_1, exist_ok=True) + + fpath_1 = os.path.join(temp_folder_1, filename) + torch._dynamo.config.capture_scalar_outputs = True + ep = torch.export.export( + model.model.embed_tokens_extend.audio_embed, args=dummy_inputs, strict=False, + dynamic_shapes=[ + {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO, 2: torch.export.Dim.AUTO}, + {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO}, + {0: torch.export.Dim.AUTO}, + {0: torch.export.Dim.AUTO}, + ] + ) + onnx_program = torch.onnx.export(ep, (), input_names=["audio_embeds", "audio_attention_mask", "audio_sizes", "audio_projection_mode"], output_names=["audio_features"]) + onnx_program.optimize() + onnx_program.save(fpath_1, external_data=True) + + onnx.checker.check_model(fpath_1) + onnx.shape_inference.infer_shapes_path(fpath_1) + onnx_model = onnx.load_model(fpath_1, load_external_data=True) + + temp_folder_2 = os.path.join(args.output, "speech_after_export") + os.makedirs(temp_folder_2, exist_ok=True) + + fpath_2 = os.path.join(temp_folder_2, filename) + onnx.save_model( + onnx_model, + fpath_2, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=f"{filename}.data", + size_threshold=0, + convert_attribute=False, + ) + shutil.rmtree(temp_folder_1) + + # ONNX/ORT rewriter + temp_folder_3 = os.path.join(args.output, "speech_after_rewrite") + os.makedirs(temp_folder_3, exist_ok=True) + + onnx_model = ir.load(fpath_2) + DynamoOnnxHelper.fold_transpose_initializers(onnx_model) + onnxscript.rewriter.rewrite(onnx_model) + onnxscript.optimizer.optimize(onnx_model, onnx_shape_inference=False, input_size_limit=4*2048*2048, 
output_size_limit=4*2048*2048) + + fpath_3 = os.path.join(temp_folder_3, filename) + ir.save(onnx_model, fpath_3, external_data=f"{filename}.data") + shutil.rmtree(temp_folder_2) + + onnx_model = onnx.load_model(fpath_3, load_external_data=True) + # Fix labels of dynamic axes since they can't be specified during Dynamo export currently + onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = "num_audios" + onnx_model.graph.input[0].type.tensor_type.shape.dim[1].dim_param = "num_frames" + onnx_model.graph.input[1].type.tensor_type.shape.dim[0].dim_param = "num_audios" + onnx_model.graph.input[1].type.tensor_type.shape.dim[1].dim_param = "num_frames" + onnx_model.graph.input[2].type.tensor_type.shape.dim[0].dim_param = "num_audios" + onnx_model.graph.output[0].type.tensor_type.shape.dim[0].dim_param = "num_audio_tokens" + + onnx_model = DynamoOnnxHelper(onnx_model) + onnx_model.convert_constants_to_initializers() + onnx_model.clear_metadata() + + os.remove(fpath_3) + os.remove(fpath_3 + ".data") + onnx_model.model.save_model_to_file(fpath_3, use_external_data_format=True, all_tensors_to_one_file=True, convert_attribute=True) # convert_attribute = True needed because of ONNX/ORT rewriter + + # ORT transformer optimizer + temp_folder_4 = os.path.join(args.output, "speech_after_opt") + fpath_4 = os.path.join(temp_folder_4, filename) + subprocess.run( + [ + f"{sys.executable}", "-m", "onnxruntime.transformers.optimizer", + "--input", fpath_3, + "--output", fpath_4, + "--model_type", "conformer", + "--num_heads", str(16), + "--hidden_size", str(1024), + "--use_external_data_format", + "--opt_level", str(0), + "--disable_shape_inference", + "--convert_attribute", + ] + ) + shutil.rmtree(temp_folder_3) + + # ORT 4-bits quantizer + fpath_5 = os.path.join(args.output, filename) + cmd = [ + f"{sys.executable}", "-m", "onnxruntime.quantization.matmul_4bits_quantizer", + "--input_model", fpath_4, + "--output_model", fpath_5, + "--block_size", str(32), + ] + if args.precision == torch.float32: cmd.extend(["--accuracy_level", str(4)]) + subprocess.run(cmd) + shutil.rmtree(temp_folder_4) + + +def build_embedding(args): + # TorchScript export + batch_size, sequence_length, num_image_tokens, num_audio_tokens = 2, 8, 2, 2 + inputs = { + "input_ids": torch.randint(low=0, high=config.vocab_size, size=(batch_size, sequence_length), device=args.execution_provider.replace("dml", "cuda"), dtype=torch.int64), + "image_features": torch.randn(num_image_tokens, config.hidden_size, device=args.execution_provider.replace("dml", "cuda"), dtype=args.precision), + "audio_features": torch.randn(num_audio_tokens, config.hidden_size, device=args.execution_provider.replace("dml", "cuda"), dtype=args.precision), + } + inputs["input_ids"][0][0] = -1 + inputs["input_ids"][0][1] = -1 + inputs["input_ids"][0][2] = -10000 + inputs["input_ids"][0][3] = -10000 + dummy_inputs = ( + inputs["input_ids"], # input_ids: torch.LongTensor + inputs["image_features"], # image_features: Optional[torch.FloatTensor] = None, + inputs["audio_features"], # audio_features: Optional[torch.FloatTensor] = None, + ) + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "image_features": {0: "num_image_tokens"}, + "audio_features": {0: "num_audio_tokens"}, + "inputs_embeds": {0: "batch_size", 1: "sequence_length"}, + } + filename = "phi-4-mm-embedding.onnx" + + temp_folder_1 = os.path.join(args.output, "embedding_init_export") + os.makedirs(temp_folder_1, exist_ok=True) + + fpath_1 = os.path.join(temp_folder_1, filename) 
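+    # Note on the dummy input_ids above: -1 falls inside the image backward-compatibility
+    # id range [-9999, -1] and -10000 inside the audio backward-compatibility range, so the
+    # exported combined-embedding graph captures the remapping to the image/audio special
+    # tokens plus the index_put scatter for both modalities (see PhiOEmbedding,
+    # calculate_positions, and select_logic in modeling_phio.py), not just the plain wte
+    # lookup. The two ids of each kind match num_image_tokens=2 / num_audio_tokens=2 so the
+    # scatter is traced with consistent feature counts.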
+ torch.onnx.export( + model.model.combined_embed, + args=dummy_inputs, + f=fpath_1, + export_params=True, + input_names=["input_ids", "image_features", "audio_features"], + output_names=["inputs_embeds"], + dynamic_axes=dynamic_axes, + opset_version=14, + do_constant_folding=True, + ) + + onnx.checker.check_model(fpath_1) + onnx.shape_inference.infer_shapes_path(fpath_1) + onnx_model = onnx.load_model(fpath_1, load_external_data=True) + + fpath_2 = os.path.join(args.output, filename) + onnx.save_model( + onnx_model, + fpath_2, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=f"{filename}.data", + size_threshold=0, + convert_attribute=False, + ) + shutil.rmtree(temp_folder_1) + + +def build_text(args): + # Create ONNX model + model_name = None + precision = "int4" + extra_options = { + "exclude_embeds": "true", + "filename": "phi-4-mm-text.onnx", + } + if args.precision == torch.float32: extra_options["int4_accuracy_level"] = 4 + create_model(model_name, args.input, args.output, precision, args.execution_provider, args.cache_dir, **extra_options) + + +def build_adapters(args): + # setattr(args, 'use_ortvalue', True) + # build_float_adapters(args) + + setattr(args, 'use_ortvalue', False) + build_quantized_adapters(args) + + +def extract_adapters_from_torch(args): + # Extract LoRAs from PyTorch model + hidden_size = config.hidden_size + num_kv_heads = config.num_key_value_heads + num_attn_heads = config.num_attention_heads + head_size = hidden_size // num_attn_heads + + q_size = num_attn_heads * head_size + kv_size = num_kv_heads * head_size + intermediate_size = config.intermediate_size + + vision_scaling = config.vision_lora["lora_alpha"] / config.vision_lora["r"] + speech_scaling = config.speech_lora["lora_alpha"] / config.speech_lora["r"] + + vision_adapters = {} + speech_adapters = {} + for key, val in model.state_dict().items(): + # Map name in graph as key + new_dict = {} + key = key.replace("self_attn", "attn").replace("lora_A", "lora_A.MatMul").replace("lora_B", "lora_B.MatMul") + + if "lora_A" in key: + # LoRA_A is shared across projections + if "qkv_proj" in key: + new_dict[key.replace("qkv_proj", "q_proj")] = val + new_dict[key.replace("qkv_proj", "k_proj")] = val + new_dict[key.replace("qkv_proj", "v_proj")] = val + elif "gate_up_proj" in key: + new_dict[key.replace("gate_up_proj", "gate_proj")] = val + new_dict[key.replace("gate_up_proj", "up_proj")] = val + else: + new_dict[key] = val + + elif "lora_B" in key: + # LoRA_B is split across projections + if "qkv_proj" in key: + new_dict[key.replace("qkv_proj", "q_proj")] = val[: q_size, :] + new_dict[key.replace("qkv_proj", "k_proj")] = val[q_size : q_size + kv_size, :] + new_dict[key.replace("qkv_proj", "v_proj")] = val[q_size + kv_size :, :] + elif "gate_up_proj" in key: + new_dict[key.replace("gate_up_proj", "gate_proj")] = val[: intermediate_size, :] + new_dict[key.replace("gate_up_proj", "up_proj")] = val[intermediate_size :, :] + else: + new_dict[key] = val + + else: + continue + + for new_key, new_val in new_dict.items(): + new_key = new_key.replace(".vision", "").replace(".speech", "") + if "vision" in key: + np_data = new_val.detach().cpu().to(args.precision).numpy().transpose() + if "lora_B" in key: + np_data *= vision_scaling + vision_adapters[new_key] = ort.OrtValue.ortvalue_from_numpy(np_data) if args.use_ortvalue else np_data + elif "speech" in key: + np_data = new_val.detach().cpu().to(args.precision).numpy().transpose() + if "lora_B" in key: + np_data *= speech_scaling + 
speech_adapters[new_key] = ort.OrtValue.ortvalue_from_numpy(np_data) if args.use_ortvalue else np_data + else: + raise ValueError(f"Unknown LoRA key found: {key}") + + return vision_adapters, speech_adapters + + +def build_onnx_adapters(vision_adapters, speech_adapters): + # Convert vision LoRAs + adapter_format = ort.AdapterFormat() + adapter_format.set_adapter_version(1) + adapter_format.set_model_version(1) + adapter_format.set_parameters(vision_adapters) + adapter_format.export_adapter(os.path.join(args.output, "phi-4-mm-vision.onnx_adapter")) + + # Convert speech LoRAs + adapter_format = ort.AdapterFormat() + adapter_format.set_adapter_version(1) + adapter_format.set_model_version(1) + adapter_format.set_parameters(speech_adapters) + adapter_format.export_adapter(os.path.join(args.output, "phi-4-mm-speech.onnx_adapter")) + + # Convert LoRA weights in ONNX model to inputs + filename = "phi-4-mm-text.onnx" + fpath = os.path.join(args.output, filename) + onnx_model = onnx.load_model(fpath) + + to_proto = { + "tensor(int8)": TensorProto.INT8, + "tensor(uint8)": TensorProto.UINT8, + "tensor(float16)": TensorProto.FLOAT16, + "tensor(float)": TensorProto.FLOAT, + } + for key, val in vision_adapters.items(): + # Handle different sized feature dimensions between adapters by using dynamic axes + shape = val.shape() + if "lora_A.MatMul.weight_Q4" in key: + shape[0] = "out_features" + elif "lora_B.MatMul.weight_Q4" in key: + shape[1] = "(in_features + block_size - 1) // block_size" + elif "lora_A.MatMul.weight_scales" in key or "lora_B.MatMul.weight_scales" in key: + shape[0] = "in_features * out_features / block_size" + elif "lora_A.MatMul.weight" in key: + shape[1] = "out_features" + elif "lora_B.MatMul.weight" in key: + shape[0] = "in_features" + + new_input = helper.make_tensor_value_info(key, to_proto[val.data_type()], shape) + onnx_model.graph.input.extend([new_input]) + for initializer in onnx_model.graph.initializer: + if initializer.name == key: + # Add 0-filled static initializer for when LoRA isn't used + # since size of inner dims in LoRA path don't matter + zero_initializer = helper.make_tensor( + name=initializer.name, + data_type=initializer.data_type, + dims=val.shape(), + vals=np.zeros(val.shape(), dtype=helper.tensor_dtype_to_np_dtype(initializer.data_type)).flatten(), + ) + onnx_model.graph.initializer.remove(initializer) + onnx_model.graph.initializer.append(zero_initializer) + break + + os.remove(fpath) + os.remove(fpath + ".data") + onnx.save_model( + onnx_model, + fpath, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=f"{filename}.data", + size_threshold=0, + convert_attribute=False, + ) + + +def build_float_adapters(args): + vision_adapters, speech_adapters = extract_adapters_from_torch(args) + build_onnx_adapters(vision_adapters, speech_adapters) + + +def build_adapter_only_onnx_model(args, adapters, filename, fpath): + inputs, outputs, initializers, value_infos, nodes = [], [], [], [], [] + dtype = TensorProto.FLOAT16 if args.precision == torch.float16 else TensorProto.FLOAT + for key, val in adapters.items(): + # Create input and output + inputs.append(helper.make_tensor_value_info(f"input.{key}", dtype, ["batch_size", "sequence_length", val.shape[0]])) + outputs.append(helper.make_tensor_value_info(f"output.{key}", dtype, ["batch_size", "sequence_length", val.shape[1]])) + + # Create initializer data + tensor = numpy_helper.from_array(val) + tensor.name = key + initializers.append(tensor) + + # Create MatMul node + matmul_node = 
helper.make_node( + "MatMul", + inputs=[inputs[-1].name, tensor.name], + outputs=[outputs[-1].name], + name=f"node.{key}", + ) + nodes.append(matmul_node) + + model = helper.make_model( + opset_imports=[helper.make_operatorsetid('', 14)], + ir_version=7, + producer_name="onnxruntime-genai", + producer_version="0.0.0", + graph=helper.make_graph( + name="main_graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + value_info=value_infos, + nodes=nodes, + ) + ) + onnx.save_model( + model, + fpath, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=f"{filename}.data", + size_threshold=0, + convert_attribute=False, + ) + + +def extract_adapters_from_onnx(args, fpath): + adapters = {} + model = onnx.load_model(fpath) + for initializer in model.graph.initializer: + val = numpy_helper.to_array(initializer) + adapters[initializer.name] = ort.OrtValue.ortvalue_from_numpy(val) + return adapters + + +def build_quantized_adapters(args): + # 1. Extract LoRAs from PyTorch model + vision_adapters, speech_adapters = extract_adapters_from_torch(args) + + # 2. Put LoRAs into separate ONNX models + filename = "phi-4-mm-lora-vision.onnx" + fpath_1 = os.path.join(args.output, filename) + vision_model = build_adapter_only_onnx_model(args, vision_adapters, filename, fpath_1) + + filename = "phi-4-mm-lora-speech.onnx" + fpath_2 = os.path.join(args.output, filename) + speech_model = build_adapter_only_onnx_model(args, speech_adapters, filename, fpath_2) + + # 3. Quantize ONNX models to int4 + filename = "phi-4-mm-qlora-vision.onnx" + fpath_3 = os.path.join(args.output, filename) + cmd = [ + f"{sys.executable}", "-m", "onnxruntime.quantization.matmul_4bits_quantizer", + "--input_model", fpath_1, + "--output_model", fpath_3, + "--block_size", str(32), + ] + if args.precision == torch.float32: cmd.extend(["--accuracy_level", str(4)]) + subprocess.run(cmd) + + filename = "phi-4-mm-qlora-speech.onnx" + fpath_4 = os.path.join(args.output, filename) + cmd = [ + f"{sys.executable}", "-m", "onnxruntime.quantization.matmul_4bits_quantizer", + "--input_model", fpath_2, + "--output_model", fpath_4, + "--block_size", str(32), + ] + if args.precision == torch.float32: cmd.extend(["--accuracy_level", str(4)]) + subprocess.run(cmd) + + os.remove(fpath_1) + os.remove(fpath_1 + ".data") + os.remove(fpath_2) + os.remove(fpath_2 + ".data") + + # 4. Extract quantized LoRAs from ONNX models + vision_adapters = extract_adapters_from_onnx(args, fpath_3) + speech_adapters = extract_adapters_from_onnx(args, fpath_4) + + # 5. Store quantized LoRAs in adapter files + build_onnx_adapters(vision_adapters, speech_adapters) + + os.remove(fpath_3) + os.remove(fpath_3 + ".data") + os.remove(fpath_4) + os.remove(fpath_4 + ".data") + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-i", + "--input", + required=True, + help="Path to folder on disk containing the Hugging Face config, model, tokenizer, etc.", + ) + + parser.add_argument( + "-o", + "--output", + required=True, + help="Path to folder to store ONNX model and additional files (e.g. 
GenAI config, external data files, etc.)", + ) + + parser.add_argument( + "-p", + "--precision", + required=True, + choices=["fp16", "fp32"], + help="Precision to export PyTorch components with", + ) + + parser.add_argument( + "-e", + "--execution_provider", + required=True, + choices=["cpu", "cuda", "dml"], + help="Execution provider for Phi-3.5 vision components", + ) + + parser.add_argument( + "-c", + "--cache_dir", + required=False, + default=os.path.join('.', 'cache_dir'), + help="Cache directory for Hugging Face files and temporary ONNX external data files", + ) + + args = parser.parse_args() + args.precision = torch.float16 if args.precision == "fp16" else torch.float32 + return args + +if __name__ == "__main__": + user_prompt = '<|user|>\n' + assistant_prompt = '<|assistant|>\n' + prompt_suffix = "<|end|>\n" + + args = get_args() + config = AutoConfig.from_pretrained(args.input, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(args.input, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(args.input, trust_remote_code=True, torch_dtype=args.precision).to(args.execution_provider.replace("dml", "cuda")) + + # Build model components + build_vision(args) + build_speech(args) + build_embedding(args) + build_text(args) + build_adapters(args) diff --git a/onnx/config.json b/onnx/config.json new file mode 100644 index 0000000000000000000000000000000000000000..96ba4596d689797e5ed3e5e7178e175929c24cec --- /dev/null +++ b/onnx/config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16fb355ba07bea3ffdf794f297f2005aee4f4ee6aba9742e264ad4471535e966 +size 4585 diff --git a/onnx/modeling_phio.py b/onnx/modeling_phio.py new file mode 100644 index 0000000000000000000000000000000000000000..176305681675a34d5c94b5e05bfdba44e8c6b806 --- /dev/null +++ b/onnx/modeling_phio.py @@ -0,0 +1,2636 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
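+# NOTE: This file is an export-oriented variant of the original Hugging Face modeling code:
+# data-dependent Python loops in the image/audio embedding paths are factored into
+# @torch.jit.script_if_tracing helpers (get_image_embeddings, get_merged_audio_set_tensor,
+# calculate_positions, select_logic) so they survive ONNX export, and the original PyTorch
+# forward passes are kept below as comments for reference.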
+ +""" PyTorch Phi-O model.""" +import os +import math +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers import PretrainedConfig + +from .configuration_phio import PhiOConfig +from .processing_phio import InputMode +from .vision_siglip_navit import get_siglip_vision_model +from .speech_conformer_encoder import ConformerEncoder + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "TBA" +_CONFIG_FOR_DOC = "PhiOConfig" + +# Special token ids +_IMAGE_SPECIAL_TOKEN_ID = 200010 # '<|endoftext10|>', or we can better name it (in `tokenizer_config.json`) +_AUDIO_SPECIAL_TOKEN_ID = 200011 # '<|endoftext11|>' +_COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE = [-9999, -1] # For backward compatibility +_COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE = [float('-inf'), -10000] # For backward compatibility + +################################################## +# Functions that have to be outside of the class +# for torch.jit.script_if_tracing to work +################################################## + +@torch.jit.script_if_tracing +def get_image_embeddings( + image_features, # [bs, max_num_crops, base_feat_height * base_feat_width (24*24), C] + attention_mask, # 4D image attention mask + image_sizes, # [num_images, 2] + sub_GN, # [1, 1, 1, image_dim_out * base_feat_height_reduction**2] + glb_GN, # [1, 1, image_dim_out * base_feat_height_reduction**2] + bfht: int, # base_feat_height_target + crop_size: int, # base_resolution + bfhr: int, # base_feat_height_reduction + bfh: int, # base_feat_height + bfw: int, # base_feat_width + C: int, # Channels + H: int, # Height + device: torch.device, # Target device + dtype: torch.dtype, # Target dtype +): + """ + Compute HD feature transformation + """ + # Compute common constants used frequently in for-loop + H_bfhr = H // bfhr + bfhr_2_C = bfhr * bfhr * C + bfh_bfhr = bfh // bfhr + bfw_bfhr = bfw // bfhr + + all_image_embeddings = torch.empty(0, 1152).to(device) + for i, img_size in enumerate(image_sizes): + h, w = img_size[0], img_size[1] + h = torch.tensor(h // crop_size, dtype=torch.int64) + w = torch.tensor(w // crop_size, dtype=torch.int64) + B_ = h * w + + # Compute common constants used frequently that are dependent on values in for-loop + h_bfh_bfhr = h * bfh // bfhr # h * bfh_bfhr + w_bfw_bfhr = w * bfw // bfhr # w * bfw_bfhr + + # 1 x (24x24) x 1024 + global_img_feature = image_features[i, :1] + + # 1 x 12 x 12 x 4096 + glb_img = ( + global_img_feature.reshape(1, H, H, C) + .reshape(1, H_bfhr, bfhr, H_bfhr, bfhr, C).contiguous().permute(0, 1, 3, 2, 4, 5) + .reshape(1, H_bfhr, H_bfhr, bfhr_2_C).contiguous() + ) + temp_glb_GN = 
sub_GN.repeat(1, H_bfhr, 1, 1) + + # 1 x 156 x 4096 + glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(1, -1, bfhr_2_C) + + # (max_num_crops-1) x (12x12) x C + sub_img = image_features[i, 1:] + # 16x574x1024 + # get rid of padding sub_img + sub_img = sub_img[:B_] + + # (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024) + sub_img = ( + sub_img.reshape(B_, H, H, C) + .reshape(B_, H_bfhr, bfhr, H_bfhr, bfhr, C).contiguous().permute(0, 1, 3, 2, 4, 5) + .reshape(B_, -1, bfhr_2_C).contiguous() + ) + sub_img = ( + sub_img.reshape(1, h, w, bfh_bfhr, bfw_bfhr, -1).permute(0, 1, 3, 2, 4, 5) + .reshape(1, h_bfh_bfhr, w_bfw_bfhr, bfhr_2_C) + ) + + reshaped_attention_mask = ( + attention_mask[i, 1:B_+1, 0::2, 0::2] + .reshape(1, h, w, bfh_bfhr, bfw_bfhr) + .permute(0, 1, 3, 2, 4) + .reshape(1, h_bfh_bfhr, w_bfw_bfhr) + ) + useful_height = int(reshaped_attention_mask[0, :, 0].sum().to(torch.int64).item()) + useful_width = int(reshaped_attention_mask[0, 0, :].sum().to(torch.int64).item()) + sub_img = sub_img[:,:useful_height, :useful_width] + temp_sub_GN = sub_GN.repeat(1, useful_height, 1, 1) + + sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1, -1, bfhr_2_C) + # (1, num_img_tokens, 1024*4) + + # Apply hd_transform_order = 'sub_glb' + all_image_embeddings = torch.cat( + [ + all_image_embeddings, + sub_img.view(-1, 1152), + glb_GN.view(-1, 1152), + glb_img.view(-1, 1152), + ] + ) + + return all_image_embeddings + + +@torch.jit.script_if_tracing +def get_merged_audio_set_tensor(audio_set_tensor: torch.Tensor, audio_sizes: torch.Tensor, device: torch.device): + # audio_features_proj: (merged_N_tokens, C = 3072) + audio_features_proj = torch.empty(0, 3072).to(device) + for i in range(len(audio_sizes)): + t = audio_set_tensor[i, :audio_sizes[i], :] + audio_features_proj = torch.cat([audio_features_proj, t], dim=0) + return audio_features_proj + + +@torch.jit.script_if_tracing +def calculate_positions(input_ids: torch.LongTensor, features: torch.FloatTensor, special_token_id: int): + # Calculate positions for image/audio tokens + if features.numel(): + return torch.where(input_ids == special_token_id) + return torch.where(torch.zeros((1, 1), dtype=torch.bool)) + + +@torch.jit.script_if_tracing +def select_logic(hidden_states: torch.FloatTensor, features: torch.FloatTensor, positions: List[torch.LongTensor]): + if features.numel(): + # apply 'select' logic + hidden_states = hidden_states.index_put( + positions, features, accumulate=False + ) + + return hidden_states + + +class PhiOEmbedding(nn.Module): + """Phi-O embedding for text-only, vision + text, speech + text, and vision + speech + text""" + def __init__(self, wte): + super().__init__() + self.wte = wte + + def forward( + self, + input_ids: torch.LongTensor, + image_features: torch.FloatTensor, + audio_features: torch.FloatTensor, + ): + # Mask input ids for image and audio tokens + new_input_ids = input_ids.clone() + new_input_ids[(input_ids >= _COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE[0]) & + (input_ids <= _COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE[1])] = _IMAGE_SPECIAL_TOKEN_ID + new_input_ids[(input_ids >= _COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE[0]) & + (input_ids <= _COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE[1])] = _AUDIO_SPECIAL_TOKEN_ID + input_ids = new_input_ids + + # Calculate position masks + image_position_mask = input_ids == _IMAGE_SPECIAL_TOKEN_ID + non_image_position_mask = ~image_position_mask + + # Calculate base hidden states + hidden_states = self.wte(input_ids) + + # 
Calculate hidden states for image tokens + image_positions = calculate_positions(input_ids, image_features, _IMAGE_SPECIAL_TOKEN_ID) + image_hidden_states = select_logic(hidden_states, image_features, image_positions) + + # Calculate hidden states for audio tokens + audio_positions = calculate_positions(input_ids, audio_features, _AUDIO_SPECIAL_TOKEN_ID) + audio_hidden_states = select_logic(hidden_states, audio_features, audio_positions) + + # Merge image, audio, and text hidden states into final hidden states for language model + hidden_states = image_hidden_states * image_position_mask.unsqueeze(-1) + audio_hidden_states * non_image_position_mask.unsqueeze(-1) + return hidden_states + + +class PhiOImageEmbedding(nn.Module): + """Image embedding.""" + + def __init__(self, config: PretrainedConfig, **kwargs) -> None: + super().__init__() + + # n_embed or hidden_size + hidden_size = config.n_embd if hasattr(config, 'n_embd') else config.hidden_size + if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'): + embd_drop = config.embd_pdrop if hasattr(config, 'embd_pdrop') else config.embed_pdrop + self.drop = nn.Dropout(embd_drop) + else: + self.drop = None + + logger.info(f"create image tower {config.img_processor}") + enable_gradient_checkpointing = kwargs.get('enable_gradient_checkpointing', False) + + # Load SigLIP model + self.img_processor = get_siglip_vision_model(_flash_attn_2_enabled=False) + + pe_weight = self.img_processor.embeddings.position_embedding.weight + L, D = pe_weight.size() + H = int(math.sqrt(L)) + assert H**2 == L + if H % 2 != 0: #and kwargs.get('image_token_compression_cls', None) is None: + self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) + H += 1 + image_dim_out = D + # ((448/14)//2)**2 + self.num_img_tokens = (H//2)**2 + self.base_feat_height_target = H + + if enable_gradient_checkpointing: + self.img_processor.encoder.gradient_checkpointing = True + + self.image_dim_out = image_dim_out + self.img_sizes = None + self.image_attention_mask = None + + # global_gn and sub_gn for hd transform, serves as line separator + self.use_hd_transform = kwargs.get('use_hd_transform', False) + self.with_learnable_separator = kwargs.get('with_learnable_separator', False) + self.hd_transform_order = kwargs.get('hd_transform_order', 'glb_sub') + self.freeze_img_processor = kwargs.get('freeze_img_processor', False) + self.crop_size = kwargs.get('crop_size', 336) + logger.info(f'freeze_img_processor = {self.freeze_img_processor}') + + # image token compression + self.image_token_compression_cls = kwargs.get('image_token_compression_cls', None) + if self.image_token_compression_cls == 'avg_pool_2d': + self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) + self.base_feat_height_reduction = 1 + self.base_feat_height_target = self.base_feat_height_target // 2 + elif self.image_token_compression_cls is None: + self.image_token_compression = None + self.base_feat_height_reduction = 2 + else: + raise NotImplementedError(f'image_token_compression_cls = {self.image_token_compression_cls}, not implemented') + + # with_hd_transform and with_learnable_separator should have same value + assert self.use_hd_transform == self.with_learnable_separator, 'use_hd_transform and with_learnable_separator should have same value' + if self.with_learnable_separator: + assert self.use_hd_transform, 'learnable separator is only for hd transform' + # 1024 * 4, merge spatial to channel dimension + self.glb_GN = nn.Parameter(torch.zeros([1, 1, self.image_dim_out * 
self.base_feat_height_reduction**2])) + self.sub_GN = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out * self.base_feat_height_reduction**2])) + logger.info(f'learnable separator enabled for hd transform, hd_transform_order = {self.hd_transform_order}') + + projection_cls = kwargs.get('projection_cls', 'linear') + if projection_cls == 'linear': + self.img_projection = nn.Linear(image_dim_out, hidden_size) + elif projection_cls == 'mlp' and self.use_hd_transform: + dim_projection = hidden_size + depth = 2 + layers = [nn.Linear(image_dim_out * self.base_feat_height_reduction**2, dim_projection)] + for _ in range(1, depth): + layers.extend([nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + elif projection_cls == 'mlp': + # follow llava-v1.5's implementation + # (do not use image_projection and image_proj_norm) + dim_projection = hidden_size + depth = 2 + layers = [nn.Linear(image_dim_out, dim_projection)] + for _ in range(1, depth): + layers.extend([nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + else: + raise NotImplementedError(f'projection_cls = {projection_cls}, not implemented') + + self.vocab_size = config.vocab_size + self.img_features = None + + if isinstance(config.img_processor, dict): + self.layer_idx = config.img_processor.get('layer_idx', -2) + self.type_feature = config.img_processor.get('type_feature', 'patch') + else: + self.layer_idx = -2 + self.type_feature = 'patch' + + def set_img_features(self, img_features: torch.FloatTensor) -> None: + self.img_features = img_features + + def set_img_sizes(self, img_sizes: torch.LongTensor) -> None: + self.img_sizes = img_sizes + + def set_img_attn_mask(self, image_attention_mask: torch.FloatTensor) -> None: + self.image_attention_mask = image_attention_mask + + def get_img_features(self, img_embeds: torch.FloatTensor, attention_mask=None) -> torch.FloatTensor: + LAYER_IDX = self.layer_idx + TYPE_FEATURE = self.type_feature + + if self.freeze_img_processor: + with torch.no_grad(): + if attention_mask is not None: + img_processor_output = self.img_processor(img_embeds, output_hidden_states=True, patch_attention_mask=attention_mask) + else: + img_processor_output = self.img_processor(img_embeds, output_hidden_states=True) + img_feature = img_processor_output.hidden_states[LAYER_IDX] + else: + if attention_mask is not None: + img_processor_output = self.img_processor(img_embeds, output_hidden_states=True, patch_attention_mask=attention_mask) + else: + img_processor_output = self.img_processor(img_embeds, output_hidden_states=True) + img_feature = img_processor_output.hidden_states[LAYER_IDX] + + if TYPE_FEATURE == "patch": + patch_feature = img_feature + if self.image_token_compression is not None: + # reshape to 2D tensor + width = int(math.sqrt(patch_feature.size(1))) + patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) + # convert to NCHW + patch_feature = patch_feature.permute(0, 3, 1, 2) + if getattr(self, 'img_processor_padding', None) is not None: + patch_feature = self.img_processor_padding(patch_feature) + patch_feature = self.image_token_compression(patch_feature) + # convert to NHWC + patch_feature = patch_feature.permute(0, 2, 3, 1) + patch_feature = patch_feature.view(-1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1)) + elif getattr(self, 'img_processor_padding', None) is not None: + width = int(math.sqrt(patch_feature.size(1))) + patch_feature = 
patch_feature.view(-1, width, width, patch_feature.size(-1)) + # convert to NCHW + patch_feature = patch_feature.permute(0, 3, 1, 2) + patch_feature = self.img_processor_padding(patch_feature) + # convert to NHWC + patch_feature = patch_feature.permute(0, 2, 3, 1) + patch_feature = patch_feature.view(-1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1)) + return patch_feature + + if TYPE_FEATURE == "cls_patch": + if self.image_token_compression is not None: + # reshape to 2D tensor + patch_feature = img_feature[:, 1:] + cls_feature = img_feature[:, 0] + width = math.sqrt(patch_feature.size(1)) + patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) + patch_feature = self.image_token_compression(patch_feature) + patch_feature = patch_feature.view(-1, patch_feature.size(-2) * patch_feature.size(-1)) + img_feature = torch.cat([cls_feature, patch_feature], dim=1) + return img_feature + + logger.info(f'processed img feature size = {img_feature.size()}') + raise NotImplementedError + + def spatiotemporal_pool(self, x, num_img_tokens, batch_size=1, T=1): + if self.image_pos_embed is not None: + x = x.view(batch_size * T, -1, x.shape[-1]) + num_tokens = x.shape[-2] + h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5) + assert h * w == num_tokens, 'only support square feature maps for now' + x = x.view(batch_size * T, h, w, x.shape[-1]) + pos_embed = self.image_pos_embed(x) + x = x + pos_embed + x = x.view(batch_size, T * h * w, x.shape[-1]) + + if self.visual_temporal_embed is not None: + visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) + x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1]) + + new_x = [] + # [bsz, T * H' * W', C] -> [bsz, T, C] + spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) + new_x.append(spatial_avg_pool_x) + + # [bsz, T * H' * W', C] -> [bsz, H'*W', C] + temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1) + new_x.append(temporal_avg_pool_x) + + x = torch.cat(new_x, dim=1).view(-1, self.image_dim_out) + num_img_tokens += T + return x, num_img_tokens + + # # PyTorch forward pass: + # def forward(self, input_ids: torch.LongTensor, input_embeds: torch.FloatTensor, image_sizes=None, **kwargs) -> torch.FloatTensor: + + # if isinstance(input_ids, tuple): + # # # pipeline parallel + # input_ids, input_embeds = input_ids + + # img_embeds = input_embeds + # if image_sizes is None and 'image_sizes' in kwargs: + # image_sizes = kwargs['image_sizes'] + # img_sizes = image_sizes + + # if self.img_features is not None: + # img_embeds = self.img_features.clone() + # self.img_features = None + + # if self.img_sizes is not None: + # img_sizes = self.img_sizes + + # if img_embeds is not None: + # # convert to bf16 + # img_embeds = img_embeds.to(torch.bfloat16) + + # if self.image_attention_mask is not None: + # image_attention_mask = self.image_attention_mask.clone() + # self.image_attention_mask = None + # elif 'image_attention_mask' in kwargs: + # image_attention_mask = kwargs['image_attention_mask'] + # else: + # image_attention_mask = None + # input_shape = input_ids.size() + # input_ids = input_ids.view(-1, input_shape[-1]) + + # with torch.no_grad(): + # positions = torch.nonzero(input_ids == _IMAGE_SPECIAL_TOKEN_ID, as_tuple=False) + # positions_tuple = torch.nonzero(input_ids == _IMAGE_SPECIAL_TOKEN_ID, as_tuple=True) + + # # logger.info(f'position size: {positions.size()} ...') + # fake_image_forward = 
False + # select = False + # hd_transform = False + + # if isinstance(self.img_projection, nn.Sequential): + # target_device = self.img_projection[0].bias.device + # target_dtype = self.img_projection[0].bias.dtype + # else: # It's a single nn.Linear layer + # target_device = self.img_projection.bias.device + # target_dtype = self.img_projection.bias.dtype + + # num_img_tokens = self.num_img_tokens + # if len(positions.tolist()) > 0: + # if self.use_hd_transform and img_sizes is not None and len(img_sizes): + # hd_transform = True + # assert img_embeds.ndim == 5, f'(branch 1) img_embeds size: {img_embeds.size()}, expect 5D tensor for hd transform' + # # img_embeds: (num_images, max_num_crops, 3, H, W) + # # img_sizes: (num_images, 2).view(1, -1) + + # bs = img_embeds.shape[0] + # # Nx(HW)xC + # if image_attention_mask is not None and len(image_attention_mask) > 0: + # img_features = self.get_img_features(img_embeds.flatten(0, 1), attention_mask=image_attention_mask.type(torch.BoolTensor).flatten(0,1).to(target_device)) + # else: + # img_features = self.get_img_features(img_embeds.flatten(0, 1)) + + # base_feat_height_target = self.base_feat_height_target + # base_resolution = self.crop_size + # base_feat_height_reduction = self.base_feat_height_reduction + + # base_feat_height = base_feat_width = int(np.sqrt(img_features.shape[1])) + + # assert base_feat_height == base_feat_height_target and base_feat_width == base_feat_height_target, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect {base_feat_height_target} features for hd transform' + + # # bs x max_num_crops x (24x24) x C + # img_features = img_features.view(bs, -1, base_feat_height * base_feat_width, self.image_dim_out) + # C = self.image_dim_out + # H = base_feat_height + + # output_imgs = [] + # output_len = [] + # # training is tensor, inference is list + # if isinstance(img_sizes, torch.Tensor): + # img_sizes = img_sizes.view(-1, 2) + # for _bs in range(bs): + # h, w = img_sizes[_bs] + # h = h // base_resolution + # w = w // base_resolution + # B_ = h * w + + # # 1 x (24x24) x 1024 + # global_img_feature = img_features[_bs, :1] + + # # 1 x 12 x 12 x 4096 + # glb_img = global_img_feature.reshape(1,H,H,C).reshape(1,H//base_feat_height_reduction,base_feat_height_reduction,H//base_feat_height_reduction,base_feat_height_reduction,C).contiguous().permute(0,1,3,2,4,5).reshape(1,H//base_feat_height_reduction,H//base_feat_height_reduction,base_feat_height_reduction*base_feat_height_reduction*C).contiguous() + # temp_glb_GN = self.sub_GN.repeat(1, H//base_feat_height_reduction, 1, 1) + + # # 1 x 156 x 4096 + # glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(1,-1,base_feat_height_reduction*base_feat_height_reduction*C) + + # # (max_num_crops-1) x (12x12) x C + # sub_img = img_features[_bs, 1:] + # # 16x574x1024 + # # get rid of padding sub_img + # sub_img = sub_img[:B_] + + # # (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024) + # sub_img = sub_img.reshape(B_,H,H,C).reshape(B_,H//base_feat_height_reduction,base_feat_height_reduction,H//base_feat_height_reduction,base_feat_height_reduction,C).contiguous().permute(0,1,3,2,4,5).reshape(B_,-1,base_feat_height_reduction*base_feat_height_reduction*C).contiguous() + # sub_img = sub_img.reshape(1, h, w, base_feat_height // base_feat_height_reduction, base_feat_width // base_feat_height_reduction, 
-1).permute(0,1,3,2,4,5).reshape(1,h*base_feat_height//base_feat_height_reduction,w*base_feat_width//base_feat_height_reduction,base_feat_height_reduction*base_feat_height_reduction*C) + + # if image_attention_mask is not None and len(image_attention_mask) > 0: + # reshaped_image_attention_mask = image_attention_mask[_bs,1:B_+1,0::2,0::2].reshape(1, h, w, base_feat_height // base_feat_height_reduction, base_feat_width // base_feat_height_reduction).permute(0,1,3,2,4).reshape(1,h*base_feat_height//base_feat_height_reduction,w*base_feat_width//base_feat_height_reduction) + # useful_height = int(reshaped_image_attention_mask[0,:,0].sum().item()) + # useful_width = int(reshaped_image_attention_mask[0,0,:].sum().item()) + # sub_img = sub_img[:,:useful_height, :useful_width] + # temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1) + # temp_len = int(image_attention_mask[_bs,:B_+1,0::2,0::2].sum().item()) + (useful_height+1) + base_feat_height//base_feat_height_reduction + # else: + # temp_sub_GN = self.sub_GN.repeat(1, h*base_feat_height//base_feat_height_reduction, 1, 1) + # temp_len = int((h*w+1)*self.num_img_tokens+ 1 + (h+1)*base_feat_height//base_feat_height_reduction) + + # sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1,-1,base_feat_height_reduction*base_feat_height_reduction*C) + # # (1, num_img_tokens, 1024*4) + + # # glb + sub + # if self.hd_transform_order == 'glb_sub': + # output_imgs.append(torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) + # elif self.hd_transform_order == 'sub_glb': + # output_imgs.append(torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) + # else: + # raise NotImplementedError(f'hd_transform_order = {self.hd_transform_order}, not implemented') + + # #temp_len = int((h*w+1)*144 + 1 + (h+1)*12) + # assert temp_len == output_imgs[-1].shape[1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}' + # output_len.append(temp_len) + + # num_img_tokens = output_len + # img_set_tensor = [] + # for _output_img in output_imgs: + # img_feature_proj = self.img_projection(_output_img.to(target_device).to(target_dtype)) + # img_set_tensor.append(img_feature_proj) + # #logger.info(f'img_embeds size: {img_embeds.size()}, image sizes: {img_sizes} loading time {datetime.now() - start_time}') + # #assert sum(num_img_tokens) == len(g_values), f'(branch 1) sum(num_img_tokens): {sum(num_img_tokens)}, g_values size: {len(g_values)}, g_values {g_values}' + + # else: + # raise NotImplementedError + # select = True + # else: + # # # create a fake image tensor + # # # TODO: need define image size for different vision model + # if self.training: + # img_embeds = torch.zeros(1, 3, self.crop_size, self.crop_size, dtype=torch.bfloat16, device=input_ids.device) + + # tt = ( + # self.get_img_features(img_embeds) + # .to(target_device) + # .to(target_dtype) + # .reshape(-1, 1024) + # ) + # if self.use_hd_transform: + # img_set_tensor = self.img_projection(tt.reshape(-1, self.image_dim_out*self.base_feat_height_reduction**2) * self.glb_GN[0] * self.sub_GN[0, 0]) + # else: + # img_set_tensor = self.img_projection(tt) # adapted visual features. + # fake_image_forward = True + + # # we use the token embedding layer from the huggingface model, this is REQUIRED to make sure we are using the loaded weights. 
+ # hidden_states = kwargs['wte'](input_ids) + + # if select: + # if hd_transform: + # # new implementation without in-place operation + # # Ref: https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/4a0d683eba9f1d0cbfb6151705d1ee73c25a80ca/modeling_phi3_v.py#L233 + # # Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.index_put.html + # # Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html#torch.Tensor.index_put_ + # # img_set_tensor: a list of tensors, each tensor has shape (1, N_tokens, C) + # assert all([_img_set_tensor.shape[0] == 1 for _img_set_tensor in img_set_tensor]), 'img_set_tensor should have shape (1, N_tokens, C)' + # # Shape: (merged_N_tokens, C) + # merged_img_set_tensor = torch.cat(img_set_tensor, dim=1).squeeze(0) + # merged_img_set_tensor = merged_img_set_tensor.to(hidden_states.dtype).to(hidden_states.device) + # # Temporarily disable autocast to avoid issue on bf16 tensors + # # Ref: https://github.com/pytorch/pytorch/issues/132715 + # with torch.autocast(device_type=hidden_states.device.type, enabled=False): + # new_hidden_states = hidden_states.index_put( + # indices=positions_tuple, + # values=merged_img_set_tensor, + # accumulate=False + # ) + # hidden_states = new_hidden_states + # else: + # raise NotImplementedError + + # if fake_image_forward and self.training: + # hidden_states = hidden_states + (0 * img_set_tensor[0].to(hidden_states.dtype).to(hidden_states.device)).sum() + + # if self.drop is not None: + # hidden_states = self.drop(hidden_states) + + # return hidden_states + + # ONNX forward pass: + def forward(self, pixel_values: torch.FloatTensor, attention_mask: torch.FloatTensor, image_sizes: torch.LongTensor) -> torch.FloatTensor: + # pixel_values: (num_images, max_num_crops, 3, H, W) + # image_sizes: (num_images, 2).view(1, -1) + if isinstance(self.img_projection, nn.Sequential): + target_device = self.img_projection[0].bias.device + target_dtype = self.img_projection[0].bias.dtype + else: # It's a single nn.Linear layer + target_device = self.img_projection.bias.device + target_dtype = self.img_projection.bias.dtype + assert pixel_values.ndim == 5, f'(branch 1) pixel_values size: {pixel_values.size()}, expect 5D tensor for hd transform' + + # Compute image features: Nx(HW)xC + image_features = self.get_img_features(pixel_values.flatten(0, 1), attention_mask=attention_mask.type(torch.BoolTensor).flatten(0, 1).to(target_device)) + + # Calculate height and width of base feature + base_feat_height = base_feat_width = int(torch.sqrt(image_features.shape[1])) + assert base_feat_height == self.base_feat_height_target and base_feat_width == self.base_feat_height_target, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect {self.base_feat_height_target} features for hd transform' + + # bs x max_num_crops x (bfh*bfw) x C + bs = pixel_values.shape[0] + image_features = image_features.view(bs, -1, base_feat_height * base_feat_width, self.image_dim_out) + + # all_image_embeddings: (num_img_tokens, 1152) + all_image_embeddings = get_image_embeddings( + image_features, attention_mask, image_sizes, self.sub_GN, self.glb_GN, + bfht=self.base_feat_height_target, crop_size=self.crop_size, bfhr=self.base_feat_height_reduction, + bfh=base_feat_height, bfw=base_feat_width, C=self.image_dim_out, H=base_feat_height, + device=target_device, dtype=target_dtype, + ) + + # image_features_proj: (num_img_tokens, 3072) + image_features_proj = self.img_projection( + 
all_image_embeddings.unsqueeze(0).to(device=target_device, dtype=target_dtype) + ) + return image_features_proj.squeeze() + + +class PhiOAudioEmbedding(nn.Module): + """Audio embedding.""" + + def __init__(self, config: PretrainedConfig, **kwargs) -> None: + super().__init__() + self.config = config + # n_embed or hidden_size for text LM + hidden_size = config.n_embd if hasattr(config, 'n_embd') else config.hidden_size + + if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'): + embd_drop = config.embd_pdrop if hasattr(config, 'embd_pdrop') else config.embed_pdrop + self.drop = nn.Dropout(embd_drop) + else: + self.drop = None + + audio_dim_out = None # Set this variable according to the actual audio processor + logger.info(f"create audio processor {config.audio_processor}") + self.layer_idx = -2 + + if isinstance(config.audio_processor, dict) and config.audio_processor.get('name', None) == "cascades": + encoder_config = config.audio_processor.get("config", None) + assert encoder_config is not None + self.encoder = ConformerEncoder(**encoder_config) + + # fake initialization, create encoder_embedding layer only so that + # in decoding, all parameters can be loaded in from_pretrained_function + # in training, we do post init after from_pretrained function to make sure the correct initialization + self.encoder.post_init({}) + + audio_dim_out = encoder_config["attention_dim"] + n_mels = encoder_config["input_size"] + else: + raise NotImplementedError + + assert audio_dim_out is not None, "Remember to set values for audio_dim_out" + self.audio_dim_out = audio_dim_out + self.audio_dim_in = n_mels + + self.freeze_audio_processor = kwargs.get('freeze_audio_processor', False) + logger.info(f'freeze_audio_processor = {self.freeze_audio_processor}') + + self.downsample_rate = kwargs.get('downsample_rate', 1) + + enable_gradient_checkpointing = kwargs.get('enable_gradient_checkpointing', False) + if enable_gradient_checkpointing: + self.encoder.gradient_checkpointing_enable() + logger.info(f'gradient checkpointing enabled for audio processor') + + projection_cls = kwargs.get('projection_cls', 'linear') + if projection_cls == 'linear': + self.audio_projection = nn.Linear(audio_dim_out, hidden_size) + elif projection_cls == 'mlp': + # follow llava-v1.5's implementation + # (do not use image_projection and image_proj_norm) + dim_projection = hidden_size + depth = 2 + self.linear_downsample_rate = self.downsample_rate + + layers_for_speech = [nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection)] + for _ in range(1, depth): + layers_for_speech.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) + audio_projection_for_speech = nn.Sequential(*layers_for_speech) + + layers_for_vision = [nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection)] + for _ in range(1, depth): + layers_for_vision.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) + audio_projection_for_vision = nn.Sequential(*layers_for_vision) + + self.audio_projection = nn.ModuleDict({ + 'speech': audio_projection_for_speech, + 'vision': audio_projection_for_vision + }) + else: + raise NotImplementedError(f'projection_cls = {projection_cls}, not implemented') + + self.vocab_size = config.vocab_size + self.input_embeds = None + self.audio_embed_sizes = None + + def post_init(self, audio_config): + # execute after the from_pretrained() initialization of the phio model + if audio_config.get('name', None) == "cascades": + init_model_config = audio_config.get("init_model", {}) + 
self.encoder.post_init(init_model_config) + # remove the init model in config so it is not saved in the config. + # This might affect the model loading in resuming training and decoding. + if "init_model" in audio_config: + audio_config.pop("init_model") + + def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None: + self.input_embeds = input_embeds + + def set_audio_embed_sizes(self, audio_embed_sizes: torch.LongTensor) -> None: + self.audio_embed_sizes = audio_embed_sizes + + def get_audio_features(self, input_embeds: torch.FloatTensor, audio_attention_mask: torch.LongTensor, audio_projection_mode: torch.LongTensor): + if self.freeze_audio_processor: + with torch.no_grad(): + audio_features, masks = self.encoder(input_embeds, audio_attention_mask) + else: + audio_features, masks = self.encoder(input_embeds, audio_attention_mask) + + if isinstance(self.audio_projection, nn.Sequential): + audio_set_tensor = self.audio_projection(audio_features) + elif isinstance(self.audio_projection, nn.ModuleDict): + audio_set_tensor = torch.cond( + audio_projection_mode == torch.tensor([InputMode.SPEECH.value], device=audio_projection_mode.device).long(), + lambda x: self.audio_projection['speech'](x), + lambda x: self.audio_projection['vision'](x), + (audio_features,) + ) + else: + raise NotImplementedError + + return audio_set_tensor + + # # PyTorch forward pass: + # def forward(self, input_ids: torch.LongTensor, input_embeds: torch.FloatTensor, audio_embed_sizes=None, audio_attention_mask=None, audio_projection_mode='speech', **kwargs) -> torch.FloatTensor: + # ''' + # arguments: + # input_ids: input text ids (B, U) + # input_embeds: audio features (B, T, D) B: num audios in a sequence + # ''' + # if self.input_embeds is not None: + # input_embeds = self.input_embeds.clone() + # if self.audio_embed_sizes is not None: + # audio_embed_sizes = self.audio_embed_sizes.clone() + + # input_shape = input_ids.size() + # input_ids = input_ids.view(-1, input_shape[-1]) + # MAX_INPUT_ID = int(1e9) + + # with torch.no_grad(): + # positions = torch.nonzero(input_ids == _AUDIO_SPECIAL_TOKEN_ID, as_tuple=False) + # positions_tuple = torch.nonzero(input_ids == _AUDIO_SPECIAL_TOKEN_ID, as_tuple=True) + + # if isinstance(self.audio_projection, nn.Sequential): + # target_device = self.audio_projection[0].bias.device + # target_dtype = self.audio_projection[0].bias.dtype + # elif isinstance(self.audio_projection, nn.ModuleDict): + # target_device = self.audio_projection[audio_projection_mode][0].bias.device + # target_dtype = self.audio_projection[audio_projection_mode][0].bias.dtype + # else: # It's a single nn.Linear layer + # target_device = self.audio_projection.bias.device + # target_dtype = self.audio_projection.bias.dtype + + # if input_embeds is not None: + # input_embeds = input_embeds.to(target_device).to(target_dtype) + + # if len(positions.tolist()) > 0: + # audio_set_tensor = self.get_audio_features(input_embeds, audio_attention_mask, audio_projection_mode) + # else: + # # # create an audio tensor + # # To do: not sure if this is required for text only input + # if self.training: + # audio_embeds = torch.zeros(1, 500, self.audio_dim_in).to(target_device).to(target_dtype) + # audio_attention_mask = audio_embeds.new_ones(audio_embeds.size()[:2]).long() + # audio_set_tensor = self.get_audio_features(audio_embeds, audio_attention_mask, audio_projection_mode) + + # hidden_states = kwargs['wte'](input_ids) + + # if len(positions.tolist()) > 0: + + # assert audio_embed_sizes.sum().item() == 
len(positions), \ + # f"please ensure the encoder outputs have the same length as defined in input_ids! \n audio_embed_sizes.sum().item(): {audio_embed_sizes.sum().item()} \n len(positions): {len(positions)} \n audio_embed_sizes: {audio_embed_sizes} \n positions: {positions} \n input_ids.shape \n {input_ids.shape}" + + # # new implementation without in-place operation + # # Ref: https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/4a0d683eba9f1d0cbfb6151705d1ee73c25a80ca/modeling_phi3_v.py#L233 + # # Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.index_put.html + # # Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html#torch.Tensor.index_put_ + # # audio_set_tensor: shape (N_audios, N_padded_tokens, C) + # # Shape: (merged_N_tokens, C) + # merged_audio_set_tensor = torch.cat([ + # audio_set_tensor[i, :audio_embed_sizes[i], :] + # for i in range(len(audio_embed_sizes)) + # ], dim=0) + # merged_audio_set_tensor = merged_audio_set_tensor.to(hidden_states.dtype).to(hidden_states.device) + # # Temporarily disable autocast to avoid issue on bf16 tensors + # # Ref: https://github.com/pytorch/pytorch/issues/132715 + # with torch.autocast(device_type=hidden_states.device.type, enabled=False): + # new_hidden_states = hidden_states.index_put( + # indices=positions_tuple, + # values=merged_audio_set_tensor, + # accumulate=False + # ) + # hidden_states = new_hidden_states + # else: + # if self.training: + # hidden_states = hidden_states + (0 * audio_set_tensor[:,0].to(hidden_states.dtype).to(hidden_states.device)).sum() + + # if self.drop is not None: + # hidden_states = self.drop(hidden_states) + + # return hidden_states + + def forward(self, audio_embeds: torch.FloatTensor, attention_mask: torch.LongTensor, audio_sizes: torch.LongTensor, audio_projection_mode: torch.LongTensor) -> torch.FloatTensor: + # audio_embeds: (batch_size, num_frames, feature_size) + # audio_attention_mask: (batch_size, feature_size) + # audio_sizes: (batch_size) + # audio_projection_mode: (1) - InputMode.SPEECH or InputMode.VISION + if isinstance(self.audio_projection, nn.Sequential): + target_device = self.audio_projection[0].bias.device + target_dtype = self.audio_projection[0].bias.dtype + elif isinstance(self.audio_projection, nn.ModuleDict): + target_device = self.audio_projection['speech'][0].bias.device + target_dtype = self.audio_projection['speech'][0].bias.dtype + else: # It's a single nn.Linear layer + target_device = self.audio_projection.bias.device + target_dtype = self.audio_projection.bias.dtype + + # audio_set_tensor: (N_audios, N_padded_tokens, C = 3072) + audio_set_tensor = self.get_audio_features(audio_embeds, attention_mask, audio_projection_mode) + + # mask: (N_audios, N_padded_tokens) + # audio_sizes: (N_audios, 1) + # Create a mask to select the valid audio features + mask = torch.arange(audio_set_tensor.size(1)).expand(audio_sizes.size(0), -1).to(target_device) < audio_sizes.unsqueeze(1) + + # audio_features_proj: (merged_N_tokens, C = 3072) + # Use the mask to select and concatenate the valid audio features + audio_features_proj = audio_set_tensor[mask].view(-1, audio_set_tensor.size(2)) + + return audio_features_proj + + +class PhiOImageAudioEmbedding(nn.Module): + """Image-audio embedding.""" + + def __init__(self, config: PretrainedConfig, **kwargs) -> None: + super().__init__() + + self.vocab_size = config.vocab_size + + self.image_input_id = kwargs.get('image_input_id', -1) + self.audio_input_id = kwargs.get('audio_input_id', -10000) + assert 
self.image_input_id != self.audio_input_id, 'image_input_id and audio_input_id should be different' + + self.image_embd_layer_kwargs = kwargs['image_embd_layer'] + self.image_embed = PhiOImageEmbedding(config, **self.image_embd_layer_kwargs) + self.audio_embd_layer_kwargs = kwargs['audio_embd_layer'] + self.audio_embed = PhiOAudioEmbedding(config, **self.audio_embd_layer_kwargs) + + self.input_image_embeds = None + self.image_sizes = None + self.image_attention_mask = None + self.input_audio_embeds = None + self.audio_embed_sizes = None + + def post_init(self, audio_config): + # post init for audio embedding + # ref: model.model.embed_tokens_extend.post_init(audio_config) in phyagi/getters/model.py + self.audio_embed.post_init(audio_config) + + def set_input_image_embeds(self, input_image_embeds: torch.FloatTensor) -> None: + self.input_image_embeds = input_image_embeds + + def set_image_sizes(self, image_sizes: torch.LongTensor) -> None: + self.image_sizes = image_sizes + + def set_img_attn_mask(self, image_attention_mask: torch.FloatTensor) -> None: + self.image_attention_mask = image_attention_mask + + def set_input_audio_embeds(self, input_audio_embeds: torch.FloatTensor) -> None: + self.input_audio_embeds = input_audio_embeds + + def set_audio_embed_sizes(self, audio_embed_sizes: torch.LongTensor) -> None: + self.audio_embed_sizes = audio_embed_sizes + + def forward( + self, + input_ids: torch.LongTensor, + input_embeds: torch.FloatTensor, + input_image_embeds: torch.FloatTensor=None, + input_audio_embeds: torch.FloatTensor=None, + image_sizes=None, + image_attention_mask=None, + audio_embed_sizes=None, + audio_attention_mask=None, + audio_projection_mode='speech', + wte=None, + ) -> torch.FloatTensor: + MAX_INPUT_ID = int(1e9) + assert -MAX_INPUT_ID < self.audio_input_id < self.image_input_id + + # override image and audio embeddings and sizes from object itself + # this is for inference + # ref: phyagi/eval/utils/text_generation_vision_audio_pipeline.py + if self.input_image_embeds is not None: + assert input_image_embeds is None + input_image_embeds = self.input_image_embeds.clone() + # NOTE weijian: set input_image_embeds to None after first call in for eval stage + # during evaluation, it will call model's forward() multiple times + # the first time input_ids contains the prompt (including <|image_{}|>) and input_embeds exists + # from the second time, the input_ids will only contain the generated text + # thus, the input_image_embeds is no longer needed + self.input_image_embeds = None + + if self.image_sizes is not None: + assert image_sizes is None + image_sizes = self.image_sizes + + if self.input_audio_embeds is not None: + assert input_audio_embeds is None + input_audio_embeds = self.input_audio_embeds.clone() + self.input_audio_embeds = None + + if self.audio_embed_sizes is not None: + assert audio_embed_sizes is None + audio_embed_sizes = self.audio_embed_sizes.clone() + + if input_image_embeds is not None: + # convert to bf16 + input_image_embeds = input_image_embeds.to(torch.bfloat16) + + if self.image_attention_mask is not None: + assert image_attention_mask is None + image_attention_mask = self.image_attention_mask.clone() + self.image_attention_mask = None + + if input_audio_embeds is not None: + # convert to bf16 + input_audio_embeds = input_audio_embeds.to(torch.bfloat16) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + # backward compatibility + with torch.no_grad(): + new_input_ids = input_ids.clone() + 
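+            # Map ids in the legacy "compatible" image/audio special-token ranges onto the canonical
+            # _IMAGE_SPECIAL_TOKEN_ID / _AUDIO_SPECIAL_TOKEN_ID that the position masks below are built from.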
new_input_ids[(input_ids >= _COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE[0]) & + (input_ids <= _COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE[1])] = _IMAGE_SPECIAL_TOKEN_ID + new_input_ids[(input_ids >= _COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE[0]) & + (input_ids <= _COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE[1])] = _AUDIO_SPECIAL_TOKEN_ID + input_ids = new_input_ids + + with torch.no_grad(): + image_position_mask = input_ids == _IMAGE_SPECIAL_TOKEN_ID + non_image_position_mask = ~image_position_mask + + assert input_embeds is None + if self.training: + assert input_image_embeds is not None and input_audio_embeds is not None + + # copy the input ids since they will be modified in place in image_embed and audio_embed + image_hidden_states = self.image_embed( + input_ids=input_ids, + input_embeds=input_image_embeds, + image_sizes=image_sizes, + wte=wte, + image_attention_mask=image_attention_mask + ) + audio_hidden_states = self.audio_embed( + input_ids=input_ids, + input_embeds=input_audio_embeds, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + wte=wte, + audio_projection_mode=audio_projection_mode, + ) + + # merge image and audio hidden states + # NOTE weijian: for non-image-audio tokens, here we use audio hidden states + # actually, in the debug code above, the non-image-audio tokens from image_hidden_states and audio_hidden_states should be the same + hidden_states = image_hidden_states * image_position_mask.to(torch.bfloat16).unsqueeze(-1) + audio_hidden_states * non_image_position_mask.to(torch.bfloat16).unsqueeze(-1) + + return hidden_states + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3 +class PhiORMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + PhiORMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 +class PhiORotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ 
position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class PhiOSuScaledRotaryEmbedding(PhiORotaryEmbedding): + def __init__(self, dim, config, device=None): + warnings.warn( + "The class PhiOSuScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers. Please" + " use PhiOLongRoPEScaledRotaryEmbedding instead.", + FutureWarning, + ) + super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) + + self.short_factor = config.rope_scaling["short_factor"] + self.long_factor = config.rope_scaling["long_factor"] + self.original_max_position_embeddings = config.original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class PhiOYarnScaledRotaryEmbedding(PhiORotaryEmbedding): + def __init__(self, dim, config, device=None): + warnings.warn( + "The class PhiOYarnScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers", + FutureWarning, + ) + super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) + + self.short_factor = config.rope_scaling["short_factor"] + self.long_factor = config.rope_scaling["long_factor"] + self.original_max_position_embeddings = config.original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + + inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 since bfloat16 loses precision on long contexts + # See 
https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = 0.1 * math.log(scale) + 1.0 + + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class PhiOLongRoPEScaledRotaryEmbedding(PhiORotaryEmbedding): + def __init__(self, dim, config, device=None): + super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) + + self.short_factor = config.rope_scaling["short_factor"] + self.long_factor = config.rope_scaling["long_factor"] + self.original_max_position_embeddings = config.original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + seq_len = seq_len or torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + + inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) + k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) + return q_embed, k_embed + + +class PhiOMLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class PhiOAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: PhiOConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.original_max_position_embeddings = config.original_max_position_embeddings + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False) + self._init_rope() + + def _init_rope(self): + if self.rope_scaling is None: + self.rotary_emb = PhiORotaryEmbedding( + self.rotary_ndims, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + if scaling_type == "longrope": + self.rotary_emb = PhiOLongRoPEScaledRotaryEmbedding(self.rotary_ndims, self.config) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.") + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights += causal_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class PhiOFlashAttention2(PhiOAttention):
+    """
+    Phi-O flash attention module. This module inherits from `PhiOAttention`, as the weights of the module stay
+    untouched. The only required change is in the forward pass, which needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
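+        # True when the installed flash_attn is older than 2.1.0 and therefore still produces top-left aligned causal masks.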
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # PhiOFlashAttention2 attention does not support output_attentions + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + rotary_seq_len = ( + max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len + ) + + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_dropout = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
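+        # NOTE: the flash-attention kernels only support fp16/bf16 inputs, so any fp32 activations are cast back
+        # to the compute dtype right before the kernel call below.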
+
+        if query_states.dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.qkv_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to have been silently cast to float32; this might be related to"
+                f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        # Reshape to the expected shape for Flash Attention
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            position_ids=position_ids,
+            dropout=attn_dropout,
+            sliding_window=getattr(self.config, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi
+# TODO @Arthur no longer copied from Llama after static cache
+class PhiOSdpaAttention(PhiOAttention):
+    """
+    PhiO attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `PhiOAttention`, as the weights of the module stay untouched. The only changes are on the forward pass to adapt
+    to the SDPA API.
+    """
+
+    # Adapted from PhiOAttention.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "PhiOModel is using PhiOSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
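+        # With q_len == 1 (single-token decoding) the query attends to every cached position anyway, so no causal mask is needed.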
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +PHIO_ATTENTION_CLASSES = { + "eager": PhiOAttention, + "flash_attention_2": PhiOFlashAttention2, + "sdpa": PhiOSdpaAttention, +} + + +class PhiODecoderLayer(nn.Module): + def __init__(self, config: PhiOConfig, layer_idx: int): + super().__init__() + + self.config = config + self.self_attn = PHIO_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + + self.mlp = PhiOMLP(config) + self.input_layernorm = PhiORMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + self.post_attention_layernorm = PhiORMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_outputs, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = residual + self.resid_attn_dropout(attn_outputs) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +PHIO_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`PhiOConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Phi-O model outputting raw hidden-states without any specific head on top.", + PHIO_START_DOCSTRING, +) +class PhiOPreTrainedModel(PreTrainedModel): + config_class = PhiOConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["PhiODecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + _version = "0.0.5" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +PHIO_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. 
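+        input_image_embeds (`torch.FloatTensor`, *optional*):
+            Pre-processed image features returned by the processor; consumed, together with `image_sizes` and
+            `image_attention_mask`, by the image branch of `embed_tokens_extend`.
+        input_audio_embeds (`torch.FloatTensor`, *optional*):
+            Pre-processed audio features returned by the processor; consumed, together with `audio_embed_sizes` and
+            `audio_attention_mask`, by the audio branch of `embed_tokens_extend`.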
+""" + + +@add_start_docstrings( + "The bare Phi-O model outputting raw hidden-states without any specific head on top.", + PHIO_START_DOCSTRING, +) +class PhiOModel(PhiOPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiODecoderLayer`] + + Args: + config: PhiOConfig + """ + + def __init__(self, config: PhiOConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_dropout = nn.Dropout(config.embd_pdrop) + self.combined_embed = PhiOEmbedding(self.embed_tokens) + + self.embed_tokens_extend = None + if isinstance(config.embd_layer, dict): + embedding_config = { + 'embedding_cls': config.embd_layer['embedding_cls'], + **config.embd_layer + } + self.embed_tokens_extend = PhiOImageAudioEmbedding(config, **embedding_config) + + self.layers = nn.ModuleList( + [PhiODecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = PhiORMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHIO_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + input_image_embeds: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + image_attention_mask=None, + input_audio_embeds: Optional[torch.FloatTensor] = None, + audio_embed_sizes=None, + audio_attention_mask=None, + audio_projection_mode=None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. 
This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens_extend( + input_ids=input_ids, + input_embeds=inputs_embeds, + input_image_embeds=input_image_embeds, + input_audio_embeds=input_audio_embeds, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + audio_projection_mode=audio_projection_mode, + wte=self.embed_tokens, + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
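+        # Work out how many tokens are already cached and which cache class is in use; this determines whether the
+        # mask can be dropped for SDPA and how long the 4D causal mask has to be.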
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_static_cache = isinstance(past_key_values, StaticCache)
+        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if (
+            self.config._attn_implementation == "sdpa"
+            and not (using_static_cache or using_sliding_window_cache)
+            and not output_attentions
+        ):
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                sliding_window=self.config.sliding_window,
+                is_training=self.training,
+            ):
+                return None
+
+        dtype, device = input_tensor.dtype, input_tensor.device
+        min_dtype = torch.finfo(dtype).min
+        sequence_length = input_tensor.shape[1]
+        # SlidingWindowCache or StaticCache
+        if using_sliding_window_cache or using_static_cache:
+            target_length = past_key_values.get_max_cache_shape()
+        # DynamicCache or no cache
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
+
+        # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            device=device,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+            config=self.config,
+            past_key_values=past_key_values,
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type == "cuda"
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+    @staticmethod
+    # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Phi3
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        config: PhiOConfig,
+        past_key_values: Cache,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding (the part of the cache that is not filled yet).
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`): + Batch size. + config (`PhiOConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class PhiOForCausalLM(PhiOPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi + def __init__(self, config): + super().__init__(config) + self.model = PhiOModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # LoRA related settings + assert getattr(config, "vision_lora", None) is not None + from peft import LoraConfig, get_peft_model + vision_lora_config = LoraConfig( + r=config.vision_lora['r'], + lora_alpha=config.vision_lora['lora_alpha'], + target_modules=config.vision_lora['layer'], + lora_dropout=config.vision_lora['dp'], + task_type="CAUSAL_LM", + ) + peft_model = get_peft_model(self.model, vision_lora_config, adapter_name="vision") + self.config.vision_lora['r'] = config.vision_lora['r'] + self.config.vision_lora['lora_alpha'] = config.vision_lora['lora_alpha'] + self.config.vision_lora['layer'] = config.vision_lora['layer'] + self.config.vision_lora['dp'] = config.vision_lora['dp'] + + assert getattr(config, "speech_lora", None) is not None + speech_lora_config = LoraConfig( + r=config.speech_lora['r'], + lora_alpha=config.speech_lora['lora_alpha'], + target_modules=config.speech_lora['layer'], + lora_dropout=config.speech_lora['dp'], + task_type="CAUSAL_LM", + ) + peft_model.base_model.active_adapter.append("speech") + peft_model.add_adapter("speech", speech_lora_config) + self.config.speech_lora['r'] = config.speech_lora['r'] + self.config.speech_lora['lora_alpha'] = config.speech_lora['lora_alpha'] + 
self.config.speech_lora['layer'] = config.speech_lora['layer'] + self.config.speech_lora['dp'] = config.speech_lora['dp'] + + def set_lora_adapter(self, adapter_name) -> None: + from peft.tuners.lora.layer import LoraLayer + for module in self.modules(): + if isinstance(module, LoraLayer): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + module._disable_adapters = False + + def unset_lora_adapter(self) -> None: + # Ref: peft/tuners/tuners_utils.py - enable_adapters() + # Ref: peft/tuners/lora/layer.py + from peft.tuners.lora.layer import LoraLayer + for module in self.modules(): + if isinstance(module, LoraLayer): + # disable grads on all adapter layers + # TODO weijian: may use enable_adapters() instead + for layer_name in module.adapter_layer_names: + layer = getattr(module, layer_name) + layer.requires_grad_(False) + module._disable_adapters = True + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings + def get_input_embeddings(self): + return self.model.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder + def set_decoder(self, decoder): + self.model = decoder + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder + def get_decoder(self): + return self.model + + # Ignore copy + @add_start_docstrings_to_model_forward(PHIO_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + input_image_embeds: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + image_attention_mask=None, + input_audio_embeds: Optional[torch.FloatTensor] = None, + audio_embed_sizes=None, + audio_attention_mask=None, + input_mode=None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. 
If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, PhiOForCausalLM + + >>> model = PhiOForCausalLM.from_pretrained("TBA") + >>> tokenizer = AutoTokenizer.from_pretrained("TBA") + + >>> prompt = "This is an example script ." + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum' + ```""" + if ( + use_cache + and self.config.rope_scaling + and cache_position is not None + and cache_position[0] == self.config.original_max_position_embeddings + ): + logger.warning( + f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed." + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if isinstance(input_mode, torch.Tensor): + assert len(input_mode) == 1 + input_mode = input_mode[0].item() + input_mode = InputMode(input_mode) + + if input_mode in [InputMode.VISION_SPEECH, InputMode.VISION]: + self.set_lora_adapter('vision') + audio_projection_mode = InputMode.VISION # 'vision' + elif input_mode == InputMode.SPEECH: + self.set_lora_adapter('speech') + audio_projection_mode = InputMode.SPEECH # 'speech' + elif input_mode == InputMode.LANGUAGE: + self.unset_lora_adapter() + audio_projection_mode = InputMode.SPEECH # 'speech' + else: + raise ValueError(f"Invalid input_mode: {input_mode}") + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + input_image_embeds=input_image_embeds, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + input_audio_embeds=input_audio_embeds, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + audio_projection_mode=audio_projection_mode, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + 
input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + input_image_embeds=None, + image_sizes=None, + image_attention_mask=None, + input_audio_embeds=None, + audio_embed_sizes=None, + audio_attention_mask=None, + input_mode=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs + ): + # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the + # process + + # When the first time input length reached long and short factor switching point, enforce re-compute cache + # It will cause downside of slower at this single token position, however, better than current failure. + if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = cache_position[0] + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + model_inputs = super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + input_image_embeds=input_image_embeds, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + input_audio_embeds=input_audio_embeds, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + input_mode=input_mode, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + num_logits_to_keep=num_logits_to_keep, + **kwargs, + ) + return model_inputs + + +@add_start_docstrings( + """ + The [`PhiOModel`] with a sequence classification head on top (linear layer). + + [`PhiOForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + PHIO_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi, LLAMA->PHI, self.transformer->self.model, transformer_outputs->model_outputs +class PhiOForSequenceClassification(PhiOPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = PhiOModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHIO_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = model_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + model_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=model_outputs.past_key_values, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + ) + + +@add_start_docstrings( + """ + [`PhiOModel`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. 
for + Named-Entity-Recognition (NER) tasks. + """, + PHIO_START_DOCSTRING, +) +# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi,MPT->PHI,self.transformer->self.model,transformer_outputs->model_outputs +class PhiOForTokenClassification(PhiOPreTrainedModel): + def __init__(self, config: PhiOConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = PhiOModel(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PHIO_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = model_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + batch_size, seq_length = labels.shape + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) + ) + + if not return_dict: + output = (logits,) + model_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + ) diff --git a/onnx/processing_phio.py b/onnx/processing_phio.py new file mode 100644 index 0000000000000000000000000000000000000000..bf1521ad107c6c87ca39ad6873969f6f406ba4f3 --- /dev/null +++ b/onnx/processing_phio.py @@ -0,0 +1,732 @@ +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Processor class for PhiO +""" +import re +from typing import List, Optional, Tuple, Union +import math +from enum import Enum + +import numpy as np +import scipy +import torch +import torchvision + +from transformers import AutoFeatureExtractor, AutoImageProcessor +from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature +from transformers.image_utils import ( + ImageInput, + make_list_of_images, + valid_images, +) +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy +from transformers.utils import TensorType, logging +from torch.nn.utils.rnn import pad_sequence + + +logger = logging.get_logger(__name__) + +# Special tokens +_COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN = r'<\|image_\d+\|>' # For backward compatibility +_COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN = r'<\|audio_\d+\|>' # For backward compatibility +_IMAGE_SPECIAL_TOKEN = '<|endoftext10|>' +_AUDIO_SPECIAL_TOKEN = '<|endoftext11|>' +_IMAGE_SPECIAL_TOKEN_ID = 200010 # '<|endoftext10|>', or we can better name it (in `tokenizer_config.json`) +_AUDIO_SPECIAL_TOKEN_ID = 200011 # '<|endoftext11|>' + + +class InputMode(Enum): + LANGUAGE = 0 + VISION = 1 + SPEECH = 2 + VISION_SPEECH = 3 + + +class PhiOImageProcessor(BaseImageProcessor): + r""" + Constructs a PhiO image processor. 
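+
+    Example (a minimal sketch, not taken from a released checkpoint; the `dynamic_hd=36`
+    value and the local image path are assumptions for illustration only):
+
+    ```python
+    >>> from PIL import Image
+    >>> processor = PhiOImageProcessor(dynamic_hd=36)
+    >>> image = Image.open("example.jpg")
+    >>> features = processor.preprocess([image], return_tensors="pt")
+    >>> features["input_image_embeds"].shape  # (num_images, max_crops, 3, 448, 448)
+    ```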
+ """ + model_input_names = ["input_image_embeds", "image_sizes", "image_attention_mask"] + + def __init__( + self, + dynamic_hd, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dynamic_hd = dynamic_hd + + def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=384, mask_size=27, use_thumbnail=True): + orig_width, orig_height = image.size + + w_crop_num = math.ceil(orig_width/float(image_size)) + h_crop_num = math.ceil(orig_height/float(image_size)) + if w_crop_num * h_crop_num > max_num: + + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if + i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = self.find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + print(target_aspect_ratio) + else: + target_width = image_size * w_crop_num + target_height = image_size * h_crop_num + target_aspect_ratio = (w_crop_num, h_crop_num) + + # Calculate the ratio + ratio_width = target_width / orig_width + ratio_height = target_height / orig_height + if ratio_width < ratio_height: + new_size = (target_width, int(orig_height * ratio_width)) + padding_width = 0 + padding_height = target_height - int(orig_height * ratio_width) + else: + new_size = (int(orig_width * ratio_height), target_height) + padding_width = target_width - int(orig_width * ratio_height) + padding_height = 0 + + attention_mask = torch.ones((int(mask_size*target_aspect_ratio[1]), int(mask_size*target_aspect_ratio[0]))) + if padding_width >= 14: + attention_mask[:, -math.floor(padding_width/14):] = 0 + if padding_height >= 14: + attention_mask[-math.floor(padding_height/14):,:] = 0 + assert attention_mask.sum() > 0 + + if min(new_size[1], target_height) < 10 or min(new_size[0], target_width) < 10: + raise ValueError(f'the aspect ratio is very extreme {new_size}') + + image = torchvision.transforms.functional.resize(image, [new_size[1], new_size[0]],) + + resized_img = torchvision.transforms.functional.pad(image, [0, 0, padding_width, padding_height], fill=[255,255,255]) + + return resized_img, attention_mask + + def pad_to_max_num_crops(self, images, max_crops=5): + """ + images: B x 3 x H x W, B<=max_crops + """ + B, _, H, W = images.shape + if B < max_crops: + pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device) + images = torch.cat([images, pad], dim=0) + return images + + def pad_mask_to_max_num_crops(self, masks, max_crops=5): + B, H, W = masks.shape + if B < max_crops: + pad = torch.ones(max_crops - B, H, W, dtype=masks.dtype, device=masks.device) + masks = torch.cat([masks, pad], 
dim=0) + return masks + + def preprocess( + self, + images: ImageInput, + return_tensors: Optional[Union[str, TensorType]] = None, + ): + """ + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + """ + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + # Basic settings. + img_processor = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (0.5, 0.5, 0.5), + (0.5, 0.5, 0.5) + ), + ]) + dyhd_base_resolution = 448 + + # Dynamic HD + base_resolution = dyhd_base_resolution + images = [image.convert('RGB') for image in images] + # cover 384 and 448 resolution + mask_resolution = base_resolution // 14 + elems, image_attention_masks = [], [] + for im in images: + elem, attention_mask = self.dynamic_preprocess(im, max_num=self.dynamic_hd, image_size=base_resolution, mask_size=mask_resolution) + elems.append(elem) + image_attention_masks.append(attention_mask) + hd_images = [img_processor(im) for im in elems] + global_image = [torch.nn.functional.interpolate(im.unsqueeze(0).float(), size=(base_resolution, base_resolution), mode='bicubic',).to(im.dtype) for im in hd_images] + shapes = [[im.size(1), im.size(2)] for im in hd_images] + mask_shapes = [[mask.size(0), mask.size(1)] for mask in image_attention_masks] + global_attention_mask = [torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images] + hd_images_reshape = [im.reshape(1, 3, + h//base_resolution, + base_resolution, + w//base_resolution, + base_resolution + ).permute(0,2,4,1,3,5).reshape(-1, 3, base_resolution, base_resolution).contiguous() for im, (h, w) in zip(hd_images, shapes)] + attention_masks_reshape = [mask.reshape(1, + h//mask_resolution, + mask_resolution, + w//mask_resolution, + mask_resolution + ).permute(0,1,3,2,4).reshape(-1, mask_resolution, mask_resolution).contiguous() for mask, (h, w) in zip(image_attention_masks, mask_shapes)] + downsample_attention_masks = [mask[:,0::2,0::2].reshape(1, + h//mask_resolution, + w//mask_resolution, + mask_resolution//2+mask_resolution%2, + mask_resolution//2+mask_resolution%2 + ).permute(0,1,3,2,4) for mask, (h,w) in zip(attention_masks_reshape, mask_shapes)] + downsample_attention_masks = [mask.reshape(mask.size(1)*mask.size(2), mask.size(3)*mask.size(4))for mask in downsample_attention_masks] + num_img_tokens = [256 + 1 + int(mask.sum().item()) + int(mask[:,0].sum().item()) + 16 for mask in downsample_attention_masks] + + hd_images_reshape = [torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape)] + hd_masks_reshape = [torch.cat([_global_mask] + [_mask], dim=0) for _global_mask, _mask in zip(global_attention_mask, attention_masks_reshape)] + max_crops = max([img.size(0) for img in hd_images_reshape]) + image_transformed = 
[self.pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape] + image_transformed = torch.stack(image_transformed, dim=0) + mask_transformed = [self.pad_mask_to_max_num_crops(mask, max_crops) for mask in hd_masks_reshape] + mask_transformed = torch.stack(mask_transformed, dim=0) + + returned_input_image_embeds = image_transformed + returned_image_sizes = torch.tensor(shapes, dtype=torch.long) + returned_image_attention_mask = mask_transformed + returned_num_img_tokens = num_img_tokens + + data = { + "input_image_embeds": returned_input_image_embeds, + "image_sizes": returned_image_sizes, + "image_attention_mask": returned_image_attention_mask, + "num_img_tokens": returned_num_img_tokens, + } + + return BatchFeature(data=data, tensor_type=return_tensors) + + +AudioInput = Tuple[Union[np.ndarray, torch.Tensor], int] +AudioInputs = List[AudioInput] + + +def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None): + """Create a Mel filter-bank the same as SpeechLib FbankFC. + + Args: + sample_rate (int): Sample rate in Hz. number > 0 [scalar] + n_fft (int): FFT size. int > 0 [scalar] + n_mel (int): Mel filter size. int > 0 [scalar] + fmin (float): lowest frequency (in Hz). If None use 0.0. + float >= 0 [scalar] + fmax: highest frequency (in Hz). If None use sample_rate / 2. + float >= 0 [scalar] + + Returns + out (numpy.ndarray): Mel transform matrix + [shape=(n_mels, 1 + n_fft/2)] + """ + + bank_width = int(n_fft // 2 + 1) + if fmax is None: + fmax = sample_rate / 2 + if fmin is None: + fmin = 0 + assert fmin >= 0, "fmin cannot be negtive" + assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]" + + def mel(f): + return 1127.0 * np.log(1.0 + f / 700.0) + + def bin2mel(fft_bin): + return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0)) + + def f2bin(f): + return int((f * n_fft / sample_rate) + 0.5) + + # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1] + klo = f2bin(fmin) + 1 + khi = f2bin(fmax) + + khi = max(khi, klo) + + # Spec 2: SpeechLib uses trianges in Mel space + mlo = mel(fmin) + mhi = mel(fmax) + m_centers = np.linspace(mlo, mhi, n_mels + 2) + ms = (mhi - mlo) / (n_mels + 1) + + matrix = np.zeros((n_mels, bank_width), dtype=np.float32) + for m in range(0, n_mels): + left = m_centers[m] + center = m_centers[m + 1] + right = m_centers[m + 2] + for fft_bin in range(klo, khi): + mbin = bin2mel(fft_bin) + if left < mbin < right: + matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms + + return matrix + + +class PhiOAudioFeatureExtractor(SequenceFeatureExtractor): + model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"] + + def __init__(self, audio_compression_rate, audio_downsample_rate, audio_feat_stride, **kwargs): + feature_size = 80 + sampling_rate = 16000 + padding_value = 0.0 + super().__init__(feature_size, sampling_rate, padding_value, **kwargs) + + self.compression_rate = audio_compression_rate + self.qformer_compression_rate = audio_downsample_rate + self.feat_stride = audio_feat_stride + + self._eightk_method = "fillzero" + self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T + + self._hamming400 = np.hamming(400) # for 16k audio + self._hamming200 = np.hamming(200) # for 8k audio + + def duration_to_frames(self, duration): + """duration in s, estimated frames""" + frame_rate = 10 + + num_frames = duration * 1000 // frame_rate + return num_frames + + def __call__( + self, + audios: List[AudioInput], + return_tensors: Optional[Union[str, TensorType]] = None, + ): + # 
Ref: https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py#L161 + returned_input_audio_embeds = [] + returned_audio_embed_sizes = [] + audio_frames_list = [] + # import pdb; pdb.set_trace() + + for audio_data, sample_rate in audios: + audio_embeds = self._extract_features(audio_data, sample_rate) + audio_frames = len(audio_embeds) * self.feat_stride + audio_embed_size = self._compute_audio_embed_size(audio_frames) + + returned_input_audio_embeds.append(torch.tensor(audio_embeds)) + returned_audio_embed_sizes.append(torch.tensor(audio_embed_size).long()) + audio_frames_list.append(audio_frames) + + returned_input_audio_embeds = pad_sequence( + returned_input_audio_embeds, batch_first=True + ) + returned_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0) + audio_frames = torch.tensor(audio_frames_list) + returned_audio_attention_mask = torch.arange(0, audio_frames.max()).unsqueeze(0) < audio_frames.unsqueeze(1) if len(audios) > 1 else None + + data = { + "input_audio_embeds": returned_input_audio_embeds, + "audio_embed_sizes": returned_audio_embed_sizes, + } + if returned_audio_attention_mask is not None: + data["audio_attention_mask"] = returned_audio_attention_mask + + return BatchFeature(data=data, tensor_type=return_tensors) + + def _extract_spectrogram(self, wav, fs): + """Extract spectrogram features from waveform. + Args: + wav (1D array): waveform of the input + fs (int): sampling rate of the waveform, 16000 or 8000. + If fs=8000, the waveform will be resampled to 16000Hz. + Output: + log_fbank (2D array): a TxD matrix of log Mel filterbank features. + D=80, and T is the number of frames. + """ + if wav.ndim > 1: + wav = np.squeeze(wav) + + # by default, we extract the mean if stereo + if len(wav.shape) == 2: + wav = wav.mean(1) + + # Resample to 16000 or 8000 if needed + if fs > 16000: + wav = scipy.signal.resample_poly(wav, 1, fs // 16000) + fs = 16000 + elif 8000 < fs < 16000: + wav = scipy.signal.resample_poly(wav, 1, fs // 8000) + fs = 8000 + elif fs < 8000: + raise RuntimeError(f"Unsupported sample rate {fs}") + + if fs == 8000: + if self._eightk_method == "resample": + # Input audio is 8 kHz. Convert to 16 kHz before feature + # extraction + wav = scipy.signal.resample_poly(wav, 2, 1) + fs = 16000 + # Do nothing here for fillzero method + elif fs != 16000: + # Input audio is not a supported sample rate. 
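+            # Note (defensive branch): after the resampling above, `fs` should already be
+            # 8000 or 16000, so this error only fires on unexpected values. With the default
+            # "fillzero" method, 8 kHz audio stays at 8 kHz here and its missing 4-8 kHz
+            # spectrum bins are zero-padded after the FFT below.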
+ raise RuntimeError(f"Input data using an unsupported sample rate: {fs}") + + preemphasis = 0.97 + + if fs == 8000: + n_fft = 256 + win_length = 200 + hop_length = 80 + fft_window = self._hamming200 + elif fs == 16000: + n_fft = 512 + win_length = 400 + hop_length = 160 + fft_window = self._hamming400 + + # Spec 1: SpeechLib cut remaining sample insufficient for a hop + n_batch = (wav.shape[0] - win_length) // hop_length + 1 + # Here we don't use stride_tricks since the input array may not satisfy + # memory layout requirement and we need writeable output + # Here we only use list of views before copy to desination + # so it is more efficient than broadcasting + y_frames = np.array( + [wav[_stride : _stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length)], + dtype=np.float32, + ) + + # Spec 2: SpeechLib applies preemphasis within each batch + y_frames_prev = np.roll(y_frames, 1, axis=1) + y_frames_prev[:, 0] = y_frames_prev[:, 1] + y_frames = (y_frames - preemphasis * y_frames_prev) * 32768 + + S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype(np.complex64) + + if fs == 8000: + # Need to pad the output to look like 16 kHz data but with zeros in + # the 4 to 8 kHz bins. + frames, bins = S.shape + padarray = np.zeros((frames, bins)) + S = np.concatenate((S[:, 0:-1], padarray), axis=1) # Nyquist bin gets set to zero + + spec = np.abs(S).astype(np.float32) + return spec + + def _extract_features(self, wav, fs): + """Extract log filterbank features from waveform. + Args: + wav (1D array): waveform of the input + fs (int): sampling rate of the waveform, 16000 or 8000. + If fs=8000, the waveform will be resampled to 16000Hz. + Output: + log_fbank (2D array): a TxD matrix of log Mel filterbank features. + D=80, and T is the number of frames. + """ + spec = self._extract_spectrogram(wav, fs) + spec_power = spec**2 + + fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None) + log_fbank = np.log(fbank_power).astype(np.float32) + + return log_fbank + + def _compute_audio_embed_size(self, audio_frames): + integer = audio_frames // self.compression_rate + remainder = audio_frames % self.compression_rate + + result = integer if remainder == 0 else integer + 1 + + integer = result // self.qformer_compression_rate + remainder = result % self.qformer_compression_rate + result = integer if remainder == 0 else integer + 1 # qformer compression + + return result + + +class PhiOProcessor(ProcessorMixin): + r""" + Constructs a PhiO processor which raps an image processor, a audio processor, and a GPT tokenizer into a single processor. + + [`PhiOProcessor`] offers all the functionalities of [`PhiOImageProcessor`] and [`GPT2Tokenizer`]. See the + [`~PhiOProcessor.__call__`] and [`~PhiOProcessor.decode`] for more information. + + Args: + image_processor ([`PhiOImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`GPT2Tokenizer`], *optional*): + The tokenizer is a required input. 
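+
+    Example (a minimal sketch; `image_processor`, `audio_processor`, and `tokenizer` are
+    assumed to be already-constructed instances of the classes registered in this file,
+    and the image path is illustrative):
+
+    ```python
+    >>> from PIL import Image
+    >>> processor = PhiOProcessor(image_processor, audio_processor, tokenizer)
+    >>> image = Image.open("example.jpg")
+    >>> inputs = processor("<|image_1|>\nDescribe this image.", images=[image], return_tensors="pt")
+    >>> inputs["input_mode"]  # tensor([1]) -> InputMode.VISION
+    ```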
+ """ + + attributes = ["image_processor", "audio_processor", "tokenizer"] + tokenizer_class = "GPT2TokenizerFast" + image_processor_class = "AutoImageProcessor" # PhiOImageProcessor will be registered later + audio_processor_class = "AutoFeatureExtractor" # PhiOAudioFeatureExtractor will be registered later + + def __init__(self, image_processor, audio_processor, tokenizer): + self.image_processor = image_processor + self.audio_processor = audio_processor + self.tokenizer = tokenizer + + def __call__( + self, + text: Union[TextInput, List[TextInput]], + images: Optional[ImageInput] = None, + audios: Optional[AudioInputs] = None, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Optional[Union[bool, str, TruncationStrategy]] = None, + max_length=None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forards the `text` + and `kwargs` arguments to GPT2Tokenizer's [`~GPT2Tokenizer.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + PhiOImageProcessor's [`~PhiOImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + - **input_image_embeds** -- Pixel values to be fed to a model. + - **image_sizes** -- List of tuples specifying the size of each image in `input_image_embeds`. 
+ - **image_attention_mask** -- List of attention masks for each image in `input_image_embeds`. + - **input_audio_embeds** -- Audio embeddings to be fed to a model. + - **audio_embed_sizes** -- List of integers specifying the size of each audio in `input_audio_embeds`. + - **audio_attention_mask** -- List of attention masks for each audio in `input_audio_embeds`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + """ + image_inputs = self.image_processor(images, return_tensors=return_tensors) if images is not None else {} + audio_inputs = self.audio_processor(audios, return_tensors=return_tensors) if audios is not None else {} + inputs = self._convert_images_audios_text_to_inputs( + image_inputs, + audio_inputs, + text, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + ) + + # idenfity the input mode + if len(image_inputs) > 0 and len(audio_inputs) > 0: + input_mode = InputMode.VISION_SPEECH + elif len(image_inputs) > 0: + input_mode = InputMode.VISION + elif len(audio_inputs) > 0: + input_mode = InputMode.SPEECH + else: + input_mode = InputMode.LANGUAGE + inputs["input_mode"] = torch.tensor([input_mode.value], dtype=torch.long) + + return inputs + + @property + def special_image_token_id(self): + return self.tokenizer.convert_tokens_to_ids(self.special_image_token) + + def get_special_image_token_id(self): + return self.tokenizer.convert_tokens_to_ids(self.special_image_token) + + def _convert_images_audios_text_to_inputs( + self, images, audios, text, padding=False, truncation=None, max_length=None, return_tensors=None + ): + # prepare image id to image input ids + if len(images) > 0: + input_image_embeds = images["input_image_embeds"] + image_sizes = images["image_sizes"] + image_attention_mask = images["image_attention_mask"] + num_img_tokens = images['num_img_tokens'] + else: + input_image_embeds = torch.tensor([]) + image_sizes = torch.tensor([]) + image_attention_mask = torch.tensor([]) + num_img_tokens = [] + + # prepare audio id to audio input ids + if len(audios) > 0: + input_audio_embeds = audios["input_audio_embeds"] + audio_embed_sizes = audios["audio_embed_sizes"] + audio_attention_mask = audios.get("audio_attention_mask", torch.tensor([])) + else: + input_audio_embeds = torch.tensor([]) + audio_embed_sizes = torch.tensor([]) + audio_attention_mask = torch.tensor([]) + + # Replace certain special tokens for compatibility + # Ref: https://stackoverflow.com/questions/11475885/python-replace-regex + if isinstance(text, str): + text = [text] + assert isinstance(text, list) + processed_text = [re.sub(_COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN, _IMAGE_SPECIAL_TOKEN, t) for t in text] + processed_text = [re.sub(_COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN, _AUDIO_SPECIAL_TOKEN, t) for t in processed_text] + + input_ids_list = [self.tokenizer(t).input_ids for t in processed_text] + + img_cnt, audio_cnt = 0, 0 # only needed for later assertion + image_token_count_iter = iter(num_img_tokens) + audio_embed_size_iter = iter(audio_embed_sizes.tolist()) + new_input_ids_list = [] + for input_ids in input_ids_list: + i = 0 + while i < len(input_ids): + token_id = input_ids[i] + if token_id == _AUDIO_SPECIAL_TOKEN_ID: + token_count = next(audio_embed_size_iter) + audio_cnt += 1 + elif token_id == _IMAGE_SPECIAL_TOKEN_ID: + token_count = next(image_token_count_iter) + img_cnt += 1 + else: + i += 1 + continue + tokens = [token_id] * token_count + input_ids = input_ids[:i] + tokens + input_ids[i + 
1:] + i += token_count + input_ids = torch.tensor(input_ids, dtype=torch.long) + new_input_ids_list.append(input_ids) + lengths = torch.tensor([len(input_ids) for input_ids in new_input_ids_list]) + max_len = lengths.max() + input_ids = input_ids.new_full((len(new_input_ids_list), max_len), self.tokenizer.pad_token_id) + # batched inference requires left padding + for i in range(len(new_input_ids_list)): + input_ids[i, max_len - len(new_input_ids_list[i]):] = new_input_ids_list[i] + + # If the below assertion fails, it might be that input pure-text + # messages contain image/audio special tokens literally + # (<|endoftext10|>, <|endoftext11|>). + assert ( + img_cnt == len(num_img_tokens) + ), ( + f"Number of image tokens in prompt_token_ids ({img_cnt}) " + f"does not match number of images ({len(num_img_tokens)})" + ) + assert ( + audio_cnt == len(audio_embed_sizes) + ), ( + f"Number of audio tokens in prompt_token_ids ({audio_cnt}) " + f"does not match number of audios ({len(audio_embed_sizes)})" + ) + + # prepare attention mask + seq_range = torch.arange(max_len - 1, -1, -1) + attention_mask = seq_range.unsqueeze(0) < lengths.unsqueeze(1) + + # prepare batch feature + data = { + "input_ids": input_ids, + "input_image_embeds": input_image_embeds, + "image_sizes": image_sizes, + "image_attention_mask": image_attention_mask, + "input_audio_embeds": input_audio_embeds, + "audio_embed_sizes": audio_embed_sizes, + "audio_attention_mask": audio_attention_mask, + "attention_mask": attention_mask, + } + + return BatchFeature( + data=data + ) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + audio_processor_input_names = self.audio_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + audio_processor_input_names)) + + +AutoImageProcessor.register("PhiOImageProcessor", PhiOImageProcessor) +AutoFeatureExtractor.register("PhiOAudioFeatureExtractor", PhiOAudioFeatureExtractor) diff --git a/onnx/speech_conformer_encoder.py b/onnx/speech_conformer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0576686ea062313855906b92575b62a9f4e852cb --- /dev/null +++ b/onnx/speech_conformer_encoder.py @@ -0,0 +1,2954 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +#!/usr/bin/env python3 + +# activation_checkpointing.py +"""helper function for activation checkpointing""" + +from typing import Union, Dict, Callable +from functools import partial +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, + offload_wrapper, + CheckpointImpl, +) + + +# utils.py +"""cascade basic blocks""" + +import math +import backoff +import random +import numpy as np +from typing import Optional, Tuple, Union +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F + + +# conformer_encoder.py +"""ConformerEncoder Module""" + +from typing import Optional, Tuple, List, Literal +import abc +import math +import numpy as np + +import torch +from torch import nn, Tensor + +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper +from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel + + +# activation_checkpointing.py +def validate_checkpointing_config(activation_checkpointing): + """validate activation checkpointing configuration""" + if isinstance(activation_checkpointing, str): + assert activation_checkpointing in ( + "", + "checkpoint", + "offload", + ), "activation_checkpointing has to be a dict or a str in ('', 'checkpoint', 'offload')." + elif isinstance(activation_checkpointing, dict): + assert activation_checkpointing.get("module", "transformer") in ( + "transformer", + "attention", + ), "module in activation_checkpointing has to be in ('transformer', 'attention')." + else: + raise ValueError("activation_checkpointing has to be a str or dict.") + + +def embedding_checkpoint_wrapper( + activation_checkpointing: Union[str, Dict], +) -> Callable: + """return encoder embedding activation checkpoint wrapper""" + validate_checkpointing_config(activation_checkpointing) + + if isinstance(activation_checkpointing, str): + if activation_checkpointing: + if activation_checkpointing == "offload": + return offload_wrapper + return partial(checkpoint_wrapper) + return lambda x: x + + if isinstance(activation_checkpointing, dict): + enabled = activation_checkpointing.get("embed", False) + if enabled: + offloading = activation_checkpointing.get("offload", False) + if offloading: + return offload_wrapper + impl = ( + CheckpointImpl.REENTRANT + if activation_checkpointing.get("reentrant", False) + else CheckpointImpl.NO_REENTRANT + ) + return partial(checkpoint_wrapper, checkpoint_impl=impl) + return lambda x: x + raise ValueError("Invalid activation_checkpointing config") + + +def encoder_checkpoint_wrapper( + activation_checkpointing: Union[str, Dict], + layer_cls: type, + idx: int = 0, +) -> Callable: + """return encoder activation checkpoint wrapper""" + validate_checkpointing_config(activation_checkpointing) + + if isinstance(activation_checkpointing, str): + if activation_checkpointing: + if activation_checkpointing == "offload": + return offload_wrapper + return partial(checkpoint_wrapper) + return lambda x: x + + if isinstance(activation_checkpointing, dict): + target_layer_cls = activation_checkpointing.get("module", "transformer") + if target_layer_cls.lower() == "transformer": + target_layer_cls = ( + "EncoderLayer", + "ConformerEncoderLayer", + ) + elif target_layer_cls.lower() == "attention": + target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention") + checkpointing_interval = activation_checkpointing.get("interval", 1) + offloading = activation_checkpointing.get("offload", False) + impl = ( + CheckpointImpl.REENTRANT 
+ if activation_checkpointing.get("reentrant", True) + else CheckpointImpl.NO_REENTRANT + ) + + if idx % checkpointing_interval == 0 and layer_cls.__name__ in target_layer_cls: + if offloading: + return offload_wrapper + return partial(checkpoint_wrapper, checkpoint_impl=impl) + return lambda x: x + + raise ValueError("Invalid activation_checkpointing config") + + +def attn_checkpointing(activation_checkpointing: Union[str, Dict], i) -> Union[str, Dict]: + """return activation checkpointing config for attention layer""" + if isinstance(activation_checkpointing, str): + return "" + + if isinstance(activation_checkpointing, dict): + target_layer_cls = activation_checkpointing.get("module", "transformer") + checkpointing_interval = activation_checkpointing.get("interval", 1) + if target_layer_cls == "attention" and i % checkpointing_interval == 0: + return activation_checkpointing + return "" + + raise ValueError("Invalid activation_checkpointing config") + + +# utils.py +class Block(nn.Module): + """Block abstract module""" + + def __init__(self, input_size, output_size): + super().__init__() + self.input_size = input_size + self.output_size = output_size + +def get_activation(name="relu"): + """Select an activation function by name + + Args: + name: str + activation function name, + one of ["relu", "gelu", "swish", "sigmoid"], + default "relu". + """ + name = name.lower() + if name == "relu": + return nn.ReLU(inplace=True) + if name == "gelu": + return nn.GELU() + if name == "swish": + return Swish() + if name == "sigmoid": + return torch.nn.Sigmoid() + return nn.Identity() + +def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): + """ + The function is very important for Transformer Transducer Streaming mode + Args: + xs_len (int): sequence length + chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. It also supports adaptive chunk size [0,10,15,45] + left_window (int): how many left chunks can be seen + right_window (int): how many right chunks can be seen. It is used for chunk overlap model. + Returns: + mask (torch.Tensor): a mask tensor for streaming model + Torch 1.0.1 + tensor([[1., 1., 0., 0.], + [0., 1., 1., 0.], + [0., 0., 1., 1.]]) + Torch 1.4.1 + tensor([[True., True., False., False.], + [False., True., True., False.], + [False., False., True., True.]]) + """ + chunk_start_idx = torch.Tensor( + chunk_start_idx + ).long() # first idx of each chunk, such as [0,18,36,48]. 
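+    # Worked example (illustrative): with x_len=6 and chunk_start_idx=[0, 2, 4],
+    # start_pad becomes [0, 0, 2, 4] and end_pad becomes [0, 2, 4, 6]; with
+    # left_window=0 and right_window=0 every frame may only attend within its own
+    # chunk ({0,1}, {2,3}, {4,5}), while left_window=1 additionally exposes the
+    # previous chunk to each frame.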
+ # start_pad = torch.nn.functional.pad( + # chunk_start_idx, (1, 0) + # ) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] + # end_pad = torch.nn.functional.pad( + # chunk_start_idx, (0, 1), value=x_len + # ) # append x_len to the end, so it becomes [0,18,36,48, x_len] + start_pad = torch.cat((torch.tensor([0], dtype=torch.int64), chunk_start_idx), dim=0) + end_pad = torch.cat((chunk_start_idx, torch.tensor([x_len], dtype=torch.int64)), dim=0) + seq_range = torch.arange(0, x_len).unsqueeze(-1) # seq_range size: [x_len, 1] + idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] # idx size: [x_len] + boundary = end_pad[idx] # boundary size: [x_len] + seq_range_expand = ( + torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + ) # seq_range_expand size [x_len, x_len] + idx_left = idx - left_window + idx_left[idx_left < 0] = 0 + boundary_left = start_pad[idx_left] + mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) + idx_right = idx + right_window + idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) + boundary_right = end_pad[idx_right] + mask_right = seq_range_expand < boundary_right.unsqueeze(-1) + return mask_left & mask_right + +class Swish(nn.Module): + """Implement Swish activation module. + From https://arxiv.org/pdf/2005.03191.pdf + + """ + + def __init__(self) -> None: + super().__init__() + self.act_fn = nn.Sigmoid() + + def forward(self, x: Tensor) -> Tensor: + """Apply Swish function + + Args: + x: torch.Tensor + Input. + """ + return x * self.act_fn(x) + +class GLU(nn.Module): + """Implement Gated Linear Unit (GLU) module""" + + def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None: + super().__init__() + self.dim = dim + self.act_name = act_name.lower() + + if self.act_name == "relu": + self.act_fn = nn.ReLU(inplace=True) + elif self.act_name == "gelu": + self.act_fn = nn.GELU() + elif self.act_name == "swish": + self.act_fn = Swish() + elif self.act_name == "sigmoid": + self.act_fn = nn.Sigmoid() + else: + self.act_fn = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """GLU forward + Apply Swish function on the first half of input matrices + with sigmoid of the second half. + + Args: + x: torch.Tensor + Input. + + """ + half_x, gate = x.chunk(2, dim=self.dim) + return half_x * self.act_fn(gate) + +# TODO: Abdel, this can be improved using GLU module +class GLUPointWiseConv(nn.Module): + """GLUPointWiseConv module + used for conformer architecture, + for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + output_dim: int + output channel size. + kernel_size: int + kernel size + glu_type: str, optional + activation function one of + ["sigmoid", "relu", "gelu"] + default "sigmoid". + bias_in_glu: bool, optional + use addtive bias in glu + causal: bool, optional + if set to True, padding is set to the half of + kernel size, ie, convolution can't see future frames. + default False. 
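+
+    Example (a shape-only sketch with assumed sizes):
+
+    ```python
+    >>> conv = GLUPointWiseConv(input_dim=512, output_dim=512, kernel_size=3)
+    >>> x = torch.randn(2, 100, 512)   # (batch, time, channels)
+    >>> conv(x).shape
+    torch.Size([2, 100, 512])
+    ```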
+ + """ + + def __init__( + self, input_dim, output_dim, kernel_size, glu_type="sigmoid", bias_in_glu=True, causal=False + ): + super().__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + self.bias_in_glu = bias_in_glu + if causal: + self.ext_pw_conv_1d = nn.Conv1d( + input_dim, output_dim * 2, kernel_size, 1, padding=(kernel_size - 1) + ) + else: + self.ext_pw_conv_1d = nn.Conv1d( + input_dim, output_dim * 2, kernel_size, 1, padding=(kernel_size - 1) // 2 + ) + + if glu_type == "sigmoid": + self.glu_act = nn.Sigmoid() + elif glu_type == "relu": + self.glu_act = nn.ReLU() + elif glu_type == "gelu": + self.glu_act = nn.GELU() + elif glu_type == "swish": + self.glu_act = Swish() + else: + raise ValueError(f"Unsupported activation type {self.glu_act}") + + if bias_in_glu: + self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1)) + self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1)) + + def forward(self, x): + """ + Args: + x: torch.Tensor + input tensor + """ + # to be consistent with GLULinear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = x.permute([0, 2, 1]) + x = self.ext_pw_conv_1d(x) + if self.glu_type == "bilinear": + if self.bias_in_glu: + x = (x[:, 0 : self.output_dim, :] + self.b1) * ( + x[:, self.output_dim : self.output_dim * 2, :] + self.b2 + ) + else: + x = (x[:, 0 : self.output_dim, :]) * ( + x[:, self.output_dim : self.output_dim * 2, :] + ) + else: + if self.bias_in_glu: + x = (x[:, 0 : self.output_dim, :] + self.b1) * self.glu_act( + x[:, self.output_dim : self.output_dim * 2, :] + self.b2 + ) + else: + x = (x[:, 0 : self.output_dim, :]) * self.glu_act( + x[:, self.output_dim : self.output_dim * 2, :] + ) + + x = x.permute([0, 2, 1]) + return x + + +class DepthWiseSeperableConv1d(nn.Module): + """DepthWiseSeperableConv1d module used in Convnet module + for the conformer, for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + depthwise_seperable_out_channel: int + if set different to 0, the number of depthwise_seperable_out_channel + will be used as a channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + kernel_size: int + kernel_size + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + padding: int, optional + padding for the conv1d, + default: 0. + + """ + + def __init__( + self, + input_dim, + depthwise_seperable_out_channel, + kernel_size, + depthwise_multiplier, + padding=0, + ): + super().__init__() + + self.dw_conv = nn.Conv1d( + input_dim, + input_dim * depthwise_multiplier, + kernel_size, + 1, + padding=padding, + groups=input_dim, + ) + + if depthwise_seperable_out_channel != 0: + self.pw_conv = nn.Conv1d( + input_dim * depthwise_multiplier, depthwise_seperable_out_channel, 1, 1, 0 + ) + else: + self.pw_conv = nn.Identity() + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + + def forward(self, x): + """ + + Args: + x: torch.Tensor + input tensor + """ + x = self.dw_conv(x) + if self.depthwise_seperable_out_channel != 0: + x = self.pw_conv(x) + return x + + +class ConvModule(nn.Module): + """ConvModule Module for the conformer block. + for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. 
+ ext_pw_out_channel: int + if > 0, ext_pw_out_channel is a dim channel size + for the last pointwise conv after swish activation. + depthwise_seperable_out_channel: int + if set different to 0, the number of depthwise_seperable_out_channel + will be used as a channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + ext_pw_kernel_size: int + kernel size of the conv pointwise of the conformer. + kernel_size: int + kernel size. + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + dropout_rate: float + dropout rate. + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation. + default False + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + chunk_size: int, optional + chunk size for cnn. default 18 + activation: str, optional + activation function used in ConvModule, + default: "relu". + glu_type: str, optional + activation function used for the glu, + default: "sigmoid". + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + export: bool, optional, + if set to True, padding is equal to 0. This is for inference, + or onnx export. Typically this is set by the export program or + the decoder program, and it isn't present in your config file. + default False + """ + + def __init__( + self, + input_dim, + ext_pw_out_channel, + depthwise_seperable_out_channel, + ext_pw_kernel_size, + kernel_size, + depthwise_multiplier, + dropout_rate, + causal=False, + batch_norm=False, + chunk_se=0, + chunk_size=18, + activation="relu", + glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + export=False, + ): + super().__init__() + self.layer_norm = nn.LayerNorm(input_dim) + self.input_dim = input_dim + self.ext_pw_out_channel = ext_pw_out_channel + self.ext_pw_kernel_size = ext_pw_kernel_size + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + self.glu_type = glu_type + self.bias_in_glu = bias_in_glu + self.linear_glu_in_convm = linear_glu_in_convm + self.causal = causal + + self._add_ext_pw_layer() + + self.batch_norm = batch_norm + self.kernel_size = kernel_size + + if batch_norm: + self.bn_layer = nn.BatchNorm1d(input_dim) + + self.act = get_activation(activation) + self.dropout = nn.Dropout(dropout_rate) + self.export = export + + if causal: + if export: # Inference only. + padding = 0 # A cache is concatenated to the left. No padding in the kernel. + else: + # Training only. Padding will be added symmetrically on both sides. + # After convolution, clip off kernel_size-1 points on the right. 
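+                # e.g. kernel_size=3 -> padding=2 on both sides here; forward()
+                # later trims the trailing kernel_size-1 frames so the output
+                # never depends on future frames.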
+ padding = kernel_size - 1 + else: + padding = (kernel_size - 1) // 2 + + self.dw_sep_conv_1d = DepthWiseSeperableConv1d( + input_dim, + depthwise_seperable_out_channel, + kernel_size, + depthwise_multiplier, + padding=padding, + ) + + if depthwise_seperable_out_channel != 0: + if input_dim != depthwise_seperable_out_channel: + self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim) + else: + if depthwise_multiplier != 1: + self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim) + + def _add_ext_pw_layer(self): + """ + This function is an extension of __init__ function + and dedicated to the convolution module creation + of the conformer. + """ + self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = nn.Identity() # jit hacks. + self.squeeze_excitation = nn.Identity() # jit. + self.apply_ln1 = self.fix_len1 = False # jit. + + if self.ext_pw_out_channel != 0: + if self.causal: + self.ext_pw_conv_1d = nn.Conv1d( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + 1, + padding=(self.ext_pw_kernel_size - 1), + ) + if self.ext_pw_kernel_size > 1: + self.fix_len1 = True + else: + self.fix_len1 = False + else: + self.ext_pw_conv_1d = nn.Conv1d( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + 1, + padding=(self.ext_pw_kernel_size - 1) // 2, + ) + self.fix_len1 = False + + if self.linear_glu_in_convm: + self.glu = GLULinear( + self.input_dim, self.ext_pw_out_channel, self.glu_type, self.bias_in_glu + ) + else: + self.glu = GLUPointWiseConv( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + self.glu_type, + self.bias_in_glu, + self.causal, + ) + + if self.input_dim != self.ext_pw_out_channel: + self.apply_ln1 = True + self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim) + else: + self.apply_ln1 = False + else: + self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3)) + self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3)) + + def forward(self, x): + """ConvModule Forward. + + Args: + x: torch.Tensor + input tensor. + """ + x = self.layer_norm(x) + + if self.ext_pw_out_channel != 0: + x = self.glu(x) + if self.causal and self.ext_pw_kernel_size > 1: + x = x[:, : -(self.ext_pw_kernel_size - 1), :] + if self.apply_ln1: + x = self.ln1(x) + else: + x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0] + x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1] + x = x_0 + x_1 + + x = x.permute([0, 2, 1]) + + x = self.dw_sep_conv_1d(x) + if self.causal and self.kernel_size > 1: + x = x[:, :, : -(self.kernel_size - 1)] + if hasattr(self, "ln2"): + x = x.permute([0, 2, 1]) + x = self.ln2(x) + x = x.permute([0, 2, 1]) + if self.batch_norm: + x = self.bn_layer(x) + x = self.act(x) + + if self.ext_pw_out_channel != 0: + x = self.ext_pw_conv_1d(x) + if self.fix_len1: + x = x[:, :, : -(self.ext_pw_kernel_size - 1)] + + if self.apply_ln1: + x = x.permute([0, 2, 1]) + x = self.ln1(x) + x = x.permute([0, 2, 1]) + + x = x.permute([0, 2, 1]) + else: + x = x.unsqueeze(1).permute([0, 1, 3, 2]) + x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2] + x = x.squeeze(1) + + x = self.dropout(x) + return x + +class GLULinear(nn.Module): + """Linear + GLU module + + Args: + input_dim: int + input size + output_dim: int + output size. + glu_type: + activation function name used in glu module. + default "sigmoid" (swish function). + bias_in_glu: bool, optional + If True, the addtive bias is added. Default False. 
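+
+    Example (illustrative sketch, not part of the original source):
+
+        >>> import torch
+        >>> glu_linear = GLULinear(input_dim=64, output_dim=128)
+        >>> glu_linear(torch.randn(2, 10, 64)).shape   # linear to 2*128, then gated to 128
+        torch.Size([2, 10, 128])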
+ """ + + def __init__( + self, + input_dim, + output_dim, + glu_type="sigmoid", + bias_in_glu=True, + ): + super().__init__() + self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu) + self.glu_act = GLU(-1, glu_type) + + def forward(self, x): + """GLULinear forward + + Args: + x: torch.Tensor + inpute tensor. + """ + x = self.linear(x) + return self.glu_act(x) + +class FeedForward(nn.Module): + """FeedForward Module. + For more details see Conformer paper: + https://arxiv.org/pdf/2005.08100.pdf + + Args: + d_model: int + input size. + d_inner: int + output size. + dropout_rate: float, + dropout rate. + activation: str, + activation function name, + one of ["relu", "swish", "sigmoid"], + sigmoid activation is only used with "glu_in_fnn=True", + default "sigmoid". + bias_in_glu: bool, optional + """ + + def __init__( + self, + d_model, + d_inner, + dropout_rate, + activation="sigmoid", + bias_in_glu=True, + ): + super().__init__() + self.d_model = d_model + self.d_inner = d_inner + + self.layer_norm = nn.LayerNorm(d_model) + module = GLULinear(d_model, d_inner, activation, bias_in_glu) + self.net = nn.Sequential( + module, + nn.Dropout(dropout_rate), + nn.Linear(d_inner, d_model), + nn.Dropout(dropout_rate), + ) + + def forward(self, x): + """FeedForward forward function. + + Args: + x: torch.Tensor + input tensor. + """ + out = self.net(self.layer_norm(x)) + + return out + +#### positional encoding starts here +def _pre_hook( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs +): + """Perform pre-hook in load_state_dict for backward compatibility. + + Note: + We saved self.pe until v.0.5.2 but we have omitted it later. + Therefore, we remove the item "pe" from `state_dict` for backward compatibility. + + """ + k = prefix + "pe" + if k in state_dict: + state_dict.pop(k) + +class T5RelativeAttentionLogitBias(nn.Module): + """ + This module implements the relative position bias described in Section 2.1 of + the T5 paper: https://arxiv.org/pdf/1910.10683.pdf + + The Huggingface implementation is used as a reference + https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/t5/modeling_t5.py#L435 + + Modifies attention as Q*K^T + B, where B is a learned scalar bias based on relative position + of the query and key. It is HxNxN, where H is the number of heads, N is the sequence length. + + I've made these modifications to the original T5 bias: + - Skipping of the bucketing step. Original T5 bias converted rel position distances into + logarithmically increasing buckets. This is supposed to help with length generalization. + - I just directly use rel position index as bias values, as we don't need length + generalization (40s max is good enough for ASR encoder), and it keeps ONNX export simple. + - I've also extended it so that biases can be asymmetric, the default implementation treats + L->R and R->L the same. Asymmetric was found to yield better results in my experiments. + + Args: + num_heads: int + Number of attention heads + num_buckets: int + Number of buckets to use for relative attention bias. This is the size of the learnable + bias parameter. Bucketing is not yet supported, so this defaults to -1 which means + no bucketing is used (max_distance determines size of bias param). + max_distance: int + Maximum distance to use for relative attention bias. With num_buckets=-1, this directly + controls the max size of the bias parameter. 
When num_buckets > 0 is supported, this + will control the maximum distance for logarithmic bucketing after which all positions + are in the same bucket. + symmetric: bool + Whether to use symmetric or asymmetric biases. symmetric=False uses 2x number of bias + params to distinguish L->R from R->L. This was found to be better for the encoder. + """ + + def __init__(self, num_heads, num_buckets=-1, max_distance=1000, symmetric=False): + super().__init__() + self.num_heads = num_heads + self.num_buckets = num_buckets + self.max_distance = max_distance + self.symmetric = symmetric + self._skip_bucketing = self.num_buckets < 0 + if self._skip_bucketing: + self.num_buckets = max_distance + else: + raise NotImplementedError("T5 attention bias with bucketed positions is not yet tested") + if not self.symmetric: + self.num_buckets *= 2 + self.bias_values = nn.Embedding(self.num_buckets, self.num_heads) + + def forward(self, x): + # instantiate bias compatible with shape of x + maxpos = x.size(1) + context_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[:, None] + memory_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + # clipping to a maximum distance using ops that play well with ONNX export + relative_position = relative_position.masked_fill( + relative_position < -self.max_distance, -self.max_distance + ) + relative_position = relative_position.masked_fill( + relative_position > self.max_distance - 1, self.max_distance - 1 + ) + + # mapping from relative position to index in the bias parameter + if self._skip_bucketing: + bias_idx = relative_position + else: + bias_idx = self._bucket_relative_position(relative_position) + if self.symmetric: + bias_idx = bias_idx.abs() + else: + bias_idx += self.num_buckets // 2 + + t5_rel_att_bias = self.bias_values(bias_idx) # [L, L, H] + t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(0) # [1, H, L, L] + + return t5_rel_att_bias + + def _bucket_relative_position(self, relative_position): + # this is a placeholder (isn't tested, likely buggy) using HuggingFace implem as a reference + # this also needs to be extended to support asymmetric +/- ve positions + relative_buckets = 0 + if not self.causal: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(self.max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + +class AbsolutePositionalEncoding(nn.Module): + """Absolute Positional encoding module. + This module implement Absolute sinusoidal positional encoding + from: https://arxiv.org/pdf/1706.03762.pdf + + Args: + d_model: int + Input embedding size. 
+ dropout_rate: float + dropout rate + max_len: int, optional + Maximum input length sequence, Default 5000 + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Construct an PositionalEncoding object.""" + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self._register_load_state_dict_pre_hook(_pre_hook) + + def extend_pe(self, x): + """Reset the positional encodings. + + Args: + x: torch.Tensor + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + + Args: + x: torch.Tensor + Input tensor. shape is (batch, time, ...) + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + +#### forward embedding layers starts here + +@backoff.on_exception(backoff.expo, Exception, max_tries=10) +def np_loadtxt_with_retry(filepath): + """np.loadtxt with retry + + Args: + filepath: str + file path to the numpy array. + """ + result = np.loadtxt(filepath, dtype="f") + return result + +class MeanVarianceNormLayer(nn.Module): + """Mean/variance normalization layer. + + Will substract mean and multiply input by inverted standard deviation. + Typically used as a very first layer in a model. + + Args: + input_size: int + layer input size. + """ + + def __init__(self, input_size): + super().__init__() + self.input_size = input_size + self.register_buffer("global_mean", torch.zeros(input_size)) + self.register_buffer("global_invstd", torch.ones(input_size)) + self.global_mean: Optional[Tensor] + self.global_invstd: Optional[Tensor] + + def forward(self, input_: Tensor) -> Tensor: + """MeanVarianceNormLayer Forward + + Args: + input_: torch.Tensor + input tensor. + """ + return (input_ - self.global_mean) * self.global_invstd + + def load_mean_invstd(self, mean_file, invstd_file, cuside_features=False): + """Load feature mean and variance used for normalization. + + Args: + mean_file: str + path to the feature mean statistics file. + invstd_file: str + path to the features inverted standard deviation + statistics file. + cuside_features: bool + Boolean that indicates CUSIDE is being used. 
+ The statistics of CUSIDE features are copied + from the normal features + """ + self.global_mean.data = torch.from_numpy(np_loadtxt_with_retry(mean_file)) + self.global_invstd.data = torch.from_numpy(np_loadtxt_with_retry(invstd_file)) + + if cuside_features: + self.global_mean.data = torch.cat((self.global_mean.data, self.global_mean.data), 0) + self.global_invstd.data = torch.cat( + (self.global_invstd.data, self.global_invstd.data), 0 + ) + +class CausalConv1D(nn.Conv1d): + """ + A causal version of nn.Conv1d where each step would have limited access to locations on its right or left + All arguments are the same as nn.Conv1d except padding. + + If padding is set None, then paddings are set automatically to make it a causal convolution where each location would not see any steps on its right. + + If padding is set as a list (size of 2), then padding[0] would be used as left padding and padding[1] as right padding. + It would make it possible to control the number of steps to be accessible on the right and left. + This mode is not supported when stride > 1. padding[0]+padding[1] should be equal to (kernel_size - 1). + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: Union[str, int] = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + self.cache_drop_size = None + if padding is None: + self._left_padding = kernel_size - 1 + self._right_padding = stride - 1 + else: + if stride != 1 and padding != kernel_size - 1: + raise ValueError("No striding allowed for non-symmetric convolutions!") + if isinstance(padding, int): + self._left_padding = padding + self._right_padding = padding + elif ( + isinstance(padding, list) + and len(padding) == 2 + and padding[0] + padding[1] == kernel_size - 1 + ): + self._left_padding = padding[0] + self._right_padding = padding[1] + else: + raise ValueError(f"Invalid padding param: {padding}!") + + self._max_cache_len = self._left_padding + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def update_cache(self, x, cache=None): + if cache is None: + new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) + next_cache = cache + else: + new_x = F.pad(x, pad=(0, self._right_padding)) + new_x = torch.cat([cache, new_x], dim=-1) + if self.cache_drop_size > 0: + next_cache = new_x[:, :, : -self.cache_drop_size] + else: + next_cache = new_x + next_cache = next_cache[:, :, -cache.size(-1) :] + return new_x, next_cache + + def forward(self, x, cache=None): + x, cache = self.update_cache(x, cache=cache) + x = super().forward(x) + if cache is None: + return x + else: + return x, cache + + +class CausalConv2D(nn.Conv2d): + """ + A causal version of nn.Conv2d where each location in the 2D matrix would have no access to locations on its right or down + All arguments are the same as nn.Conv2d except padding which should be set as None + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: Union[str, int] = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + if padding is not None: + raise ValueError("Argument padding should be set to None for CausalConv2D.") + 
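+        # Note: padding is applied manually in forward() (kernel_size-1 before,
+        # stride-1 after on the padded dims), not by nn.Conv2d itself.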
self._left_padding = kernel_size - 1 + self._right_padding = stride - 1 + + padding = 0 + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + + def forward( + self, + x, + ): + if self.training: + x = F.pad( + x, + pad=( + self._left_padding, + self._right_padding, + self._left_padding, + self._right_padding, + ), + ) + else: + x = F.pad( + x, + pad=(self._left_padding, self._right_padding, 0, 0), + ) + x = super().forward(x) + return x + + +class NemoConvSubsampling(torch.nn.Module): + """Convlutional subsampling module, taken from NeMo ASR + (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a34501479cf/nemo/collections/asr/parts/submodules/subsampling.py) + + Striding Subsampling: "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for + Speech Recognition" by Linhao Dong et al. (https://ieeexplore.ieee.org/document/8462506) + + + Compared with the EncoderConv2D (`input_layer: custom`), this is a much simplified approach, + and uses no LayerNorm and far fewer Conv2Ds. Moreover, depthwise convolutions are used to reduce + FLOPs, but the first layer is kept as a regular convolution so as not to degrade accuracy. + + `Striding` and `dw_striding` are the same except that the latter uses depthwise convolutions + after the first layer, whereas the former does not. + + Args: + subsampling_factor (int): Time reduction factor + feat_in (int): size of the input features + feat_out (int): size of the output features + subsampling (str): The subsampling technique, choose from + {"striding", "dw-striding", "striding_conv1d", "dw_striding_conv1d"} + conv_channels (int): Number of channels for the convolution layers, default is 256. + subsampling_conv_chunking_factor (int): Input chunking factor which can be -1 (no chunking) + 1 (auto) or a power of 2. 
Default is 1 + activation (Module): activation function, default is nn.ReLU() + is_causal (bool): whether to use causal Conv1/2D, where each step will have limited access + to locations on its right or left + """ + + def __init__( + self, + feat_in, + feat_out, + subsampling_factor=4, + subsampling="dw_striding", + conv_channels=256, + subsampling_conv_chunking_factor=1, + activation=nn.ReLU(), + is_causal=False, + ): + super().__init__() + self._subsampling = subsampling + self._conv_channels = conv_channels + self._feat_in = feat_in + self._feat_out = feat_out + + if subsampling_factor % 2 != 0: + raise ValueError("Sampling factor should be a multiply of 2!") + self._sampling_num = int(math.log(subsampling_factor, 2)) + self.subsampling_factor = subsampling_factor + self.is_causal = is_causal + self.subsampling_causal_cond = subsampling in ("dw_striding", "striding", "striding_conv1d") + + if ( + subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0 + ): + raise ValueError("subsampling_conv_chunking_factor should be -1, 1, or a power of 2") + self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor + + in_channels = 1 + layers = [] + + if subsampling == "dw_striding": + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + # Layer 1 + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + ) + ) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + ) + ) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + groups=in_channels, + ) + ) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ) + ) + + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1, + ) + ) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "striding": + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + for i in range(self._sampling_num): + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + ) + ) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, 
+ stride=self._stride, + padding=self._left_padding, + ) + ) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "striding_conv1d": + in_channels = feat_in + + self._stride = 2 + self._kernel_size = 5 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + for i in range(self._sampling_num): + if self.is_causal: + layers.append( + CausalConv1D( + in_channels=in_channels, + out_channels=feat_out if self._sampling_num == i + 1 else conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + ) + ) + else: + layers.append( + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=feat_out if self._sampling_num == i + 1 else conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + ) + ) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "dw_striding_conv1d": + in_channels = feat_in + + self._stride = 2 + self._kernel_size = 5 + self._ceil_mode = False + + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + + # Layer 1 + layers.extend( + [ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=feat_out if self._sampling_num == 1 else conv_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ] + ) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + layers.extend( + [ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=feat_out if self._sampling_num == i + 2 else conv_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ] + ) + layers.append(activation) + in_channels = conv_channels + + else: + raise ValueError(f"Not valid sub-sampling: {subsampling}!") + + if subsampling in ["dw_striding", "striding"]: + in_length = torch.tensor(feat_in, dtype=torch.float) + out_length = calc_length( + lengths=in_length, + all_paddings=self._left_padding + self._right_padding, + kernel_size=self._kernel_size, + stride=self._stride, + ceil_mode=self._ceil_mode, + repeat_num=self._sampling_num, + ) + self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out) + self.conv2d_subsampling = True + elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]: + self.out = None + self.conv2d_subsampling = False + else: + raise ValueError(f"Not valid sub-sampling: {subsampling}!") + + self.conv = torch.nn.Sequential(*layers) + + def get_sampling_frames(self): + return [1, self.subsampling_factor] + + def get_streaming_cache_size(self): + return [0, self.subsampling_factor + 1] + + def forward(self, x, mask): + """ + Forward method for NeMo subsampling. 
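+
+        Example (illustrative sketch, not part of the original source; shapes
+        assume the default dw_striding settings of this file):
+
+            >>> import torch
+            >>> sub = NemoConvSubsampling(feat_in=80, feat_out=512, subsampling_factor=4)
+            >>> x = torch.randn(2, 100, 80)                 # (batch, time, feat_in)
+            >>> mask = torch.ones(2, 100, dtype=torch.bool)
+            >>> out, pad_mask = sub(x, mask)
+            >>> out.shape, pad_mask.shape
+            (torch.Size([2, 25, 512]), torch.Size([2, 1, 25]))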
+ + Args: + x[Batch, Time, Filters]: torch.Tensor + input tensor + x_mask: torch.Tensor + input mask + + Returns: + x: torch.Tensor + Resulting tensor from subsampling (B, T // time_reduction_factor, feat_out) + pad_mask: torch.Tensor + tensor of padded hidden state sequences (B, 1, T // time_reduction_factor) + """ + batch_size = x.shape[0] + # Unsqueeze Channel Axis + if self.conv2d_subsampling: + x = x.unsqueeze(1) + # Transpose to Channel First mode + else: + x = x.transpose(1, 2) + + # split inputs if chunking_factor is set + if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling: + if self.subsampling_conv_chunking_factor == 1: + # if subsampling_conv_chunking_factor is 1, we split only if needed + # avoiding a bug / feature limiting indexing of tensors to 2**31 + # see https://github.com/pytorch/pytorch/issues/80020 + x_ceil = 2**31 / self._conv_channels * self._stride * self._stride + if torch.numel(x) > x_ceil: + need_to_split = True + else: + need_to_split = False + else: + # if subsampling_conv_chunking_factor > 1 we always split + need_to_split = True + + if need_to_split: + x, success = self.conv_split_by_batch(x) + if not success: # if unable to split by batch, try by channel + if self._subsampling == "dw_striding": + x = self.conv_split_by_channel(x) + else: + x = self.conv(x) # try anyway + else: + x = self.conv(x) + else: + x = self.conv(x) + + # Flatten Channel and Frequency Axes + if self.conv2d_subsampling: + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).reshape(b, t, -1)) + # Transpose to Channel Last mode + else: + x = x.transpose(1, 2) + + max_audio_length = x.shape[1] + feature_lens = mask.sum(1) + padding_length = torch.ceil(feature_lens.to(torch.float32) / float(self.subsampling_factor)).to(torch.int64) + if self.is_causal and self.subsampling_causal_cond: + feature_lens_remainder = feature_lens % self.subsampling_factor + padding_length[feature_lens_remainder != 1] += 1 + pad_mask = ( + torch.arange(0, max_audio_length, device=x.device).expand(padding_length.size(0), -1) + < padding_length.unsqueeze(1) + ) + + condition = torch.full_like(pad_mask, batch_size != 1).bool() + pad_mask = pad_mask * condition + return x, pad_mask.unsqueeze(1) + + def reset_parameters(self): + # initialize weights + if self._subsampling == "dw_striding": + with torch.no_grad(): + # init conv + scale = 1.0 / self._kernel_size + dw_max = (self._kernel_size**2) ** -0.5 + pw_max = self._conv_channels**-0.5 + + torch.nn.init.uniform_(self.conv[0].weight, -scale, scale) + torch.nn.init.uniform_(self.conv[0].bias, -scale, scale) + + for idx in range(2, len(self.conv), 3): + torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max) + torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max) + torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max) + torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max) + + # init fc (80 * 64 = 5120 from https://github.com/kssteven418/Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/src/models/conformer_encoder.py#L487 + fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5 + torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale) + torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) + + def conv_split_by_batch(self, x): + """Tries to split input by batch, run conv and concat results""" + b, _, _, _ = x.size() + if b == 1: # can't split if batch size is 1 + return x, False + + if self.subsampling_conv_chunking_factor > 1: + cf = 
self.subsampling_conv_chunking_factor + else: + # avoiding a bug / feature limiting indexing of tensors to 2**31 + # see https://github.com/pytorch/pytorch/issues/80020 + x_ceil = 2**31 / self._conv_channels * self._stride * self._stride + p = math.ceil(math.log(torch.numel(x) / x_ceil, 2)) + cf = 2**p + + new_batch_size = b // cf + if new_batch_size == 0: # input is too big + return x, False + + return torch.cat([self.conv(chunk) for chunk in torch.split(x, new_batch_size, 0)]), True + + def conv_split_by_channel(self, x): + """For dw convs, tries to split input by time, run conv and concat results""" + x = self.conv[0](x) # full conv2D + x = self.conv[1](x) # activation + + for i in range(self._sampling_num - 1): + _, c, t, _ = x.size() + + if self.subsampling_conv_chunking_factor > 1: + cf = self.subsampling_conv_chunking_factor + else: + # avoiding a bug / feature limiting indexing of tensors to 2**31 + # see https://github.com/pytorch/pytorch/issues/80020 + p = math.ceil(math.log(torch.numel(x) / 2**31, 2)) + cf = 2**p + + new_c = int(c // cf) + if new_c == 0: + new_c = 1 + + new_t = int(t // cf) + if new_t == 0: + new_t = 1 + + x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c, x) # conv2D, depthwise + + # splitting pointwise convs by time + x = torch.cat( + [self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)], 2 + ) # conv2D, pointwise + x = self.conv[i * 3 + 4](x) # activation + return x + + def channel_chunked_conv(self, conv, chunk_size, x): + """Performs channel chunked convolution""" + + ind = 0 + out_chunks = [] + for chunk in torch.split(x, chunk_size, 1): + step = chunk.size()[1] + + if self.is_causal: + chunk = nn.functional.pad( + chunk, + pad=( + self._kernel_size - 1, + self._stride - 1, + self._kernel_size - 1, + self._stride - 1, + ), + ) + ch_out = nn.functional.conv2d( + chunk, + conv.weight[ind : ind + step, :, :, :], + bias=conv.bias[ind : ind + step], + stride=self._stride, + padding=0, + groups=step, + ) + else: + ch_out = nn.functional.conv2d( + chunk, + conv.weight[ind : ind + step, :, :, :], + bias=conv.bias[ind : ind + step], + stride=self._stride, + padding=self._left_padding, + groups=step, + ) + out_chunks.append(ch_out) + ind += step + + return torch.cat(out_chunks, 1) + + def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int): + if ( + subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0 + ): + raise ValueError("subsampling_conv_chunking_factor should be -1, 1, or a power of 2") + self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor + + +def calc_length(lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1): + """Calculates the output length of a Tensor passed through a convolution or max pooling layer""" + add_pad: float = all_paddings - kernel_size + one: float = 1.0 + for i in range(repeat_num): + lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one + if ceil_mode: + lengths = torch.ceil(lengths) + else: + lengths = torch.floor(lengths) + return lengths.to(dtype=torch.int) + +#### multihead attention starts here +class AttModule(nn.Module): + """Attention abstraction module""" + + def __init__(self): + super().__init__() + self.export_mode = False + + def set_export(self, mode=True): + """set the export mode""" + self.export_mode = mode + + def forward( + self, + x: Tensor, + memory: Optional[Tensor] = None, + pos_emb: Optional[Tensor] = None, + att_mask: 
Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + """AttModule forward + + Args: + x: torch.Tensor + input tensor. + memory: torch.Tensor, optional + memory tensor. + pos_emb: torch.Tensor, optional + positional encoder embedding. + att_mask: torch.Tensor, optional + attention mask tensor. + """ + return x, memory, pos_emb, att_mask + + +class AttBlock(Block, AttModule): + """Attention Block module to support both Attention and Block module.""" + + def memory_dims(self, max_len=False): + """memory dimensions""" + return (1, self.input_size) + +def masked_softmax( + scores, + mask: Optional[Tensor], +): + if mask is not None: + # mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) + scores = scores.masked_fill(mask, -torch.inf) + attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + return attn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer with optional relative position embedding and GLU. + + Args: + n_head: int + the number of heads. + n_feat: int + input size features. + dropout_rate: float + dropout rate. + use_LN: bool + apply layer norm or not + dropout_at_output: bool + whether to apply dropout at output + attention_inner_dim: int, optional + the attention dimension used in the class, + it can be different from the input dimension n_feat. + default: -1 (equal to n_feat). + use_pt_scaled_dot_product_attention: bool, optional + if set True, use pytorch scaled dot product attention in training. NOTE: this will NOT + be used in ONNX decoding due to a lack of support. In that case, we use the original + attention implementation, which shows no regression. + default: False. + n_value: int, optional + if set to values other than -1, use a different dimension for value. With the default value (i.e. -1), it is backward compatible. + group_size: int, optional. must divide `n_head` + if group_size > 1: GQA + if group_size = 1: MHA + if group_size = n_head: MQA + """ + + inv_sqrt_d_k: torch.jit.Final[float] + h: torch.jit.Final[int] + h_k: torch.jit.Final[int] + g: torch.jit.Final[int] + + def __init__( + self, + n_head, + n_feat, + dropout_rate, + attention_inner_dim=-1, + glu_type="swish", + bias_in_glu=True, + use_pt_scaled_dot_product_attention=False, + n_value=-1, + group_size: int = 1, + ): + super().__init__() + if n_value == -1: + n_value = n_feat + if attention_inner_dim == -1: + attention_inner_dim = n_feat + assert attention_inner_dim % n_head == 0 + + # We assume d_v always equals d_k + self.d_k = attention_inner_dim // n_head + self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k) + self.h = n_head + assert n_head % group_size == 0, "group_size must divide n_head" + self.g = group_size + self.h_k = n_head // group_size + + self.linear_q = nn.Linear(n_feat, attention_inner_dim) + self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size) + self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size) + self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value) + + self.attn = torch.jit.Attribute(None, Optional[Tensor]) + self.dropout = nn.Dropout(p=dropout_rate) + self.dropout_rate = dropout_rate + self.use_pt_scaled_dot_product_attention = use_pt_scaled_dot_product_attention + + if use_pt_scaled_dot_product_attention and group_size > 1: + raise ValueError("Cannot use PT Scaled Attention with GQA") + + # Torchscript eager quantization. 
Note that these functions below are + # NOOPs and have very little impact on performance unless quantization is + # enabled. + self.quant_q = torch.ao.quantization.QuantStub() + self.quant_x = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + self.ffunc = torch.ao.nn.quantized.FloatFunctional() + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_k: Tensor, + pos_v: Tensor, + mask: Optional[Tensor], + relative_attention_bias: Optional[Tensor] = None, + ): + """Compute 'Scaled Dot Product Attention'. + + Args: + query: torch.Tensor + query tensor (batch, time1, size) + key: torch.Tensor + key tensor (batch, time2, size) + value: torch.Tensor + value tensor (batch, time1, size) + pos_k: torch.Tensor + key tensor used for relative positional embedding. + pos_v: torch.Tensor + value tensor used for relative positional embedding. + mask: torch.Tensor + mask tensor (batch, time1, time2) + relative_attention_bias: torch.Tensor + bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2) + """ + # self.h = num_heads + # self.h_k = num_kv_heads + # self.d_k = head_size + n_batch, seq_len, _ = query.size() + + q = self.linear_q(query).view(n_batch, seq_len, self.h, self.d_k) # (b, t, d) + k = self.linear_k(key).view(n_batch, seq_len, self.h_k, self.d_k) # (b, t, d) + v = self.linear_v(value).view(n_batch, seq_len, self.h_k, self.d_k) + q = ( + q.transpose(1, 2) + if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting() + else q.transpose(1, 2) * self.inv_sqrt_d_k + ) + k = k.transpose(1, 2) # (batch, head_k, time2, d_k) + v = v.transpose(1, 2) # (batch, head_k, time2, d_k) + + if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting(): + attn_mask = None + if mask is not None: + mask = mask.unsqueeze(1) + if relative_attention_bias is not None: + attn_mask = mask + relative_attention_bias + else: + attn_mask = mask + if mask.dtype != q.dtype: + attn_mask = attn_mask.to(q.dtype) + + with torch.backends.cuda.sdp_kernel( + enable_flash=True, enable_math=True, enable_mem_efficient=True + ): + x = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=self.dropout_rate, + ) + else: + if self.h != self.h_k: + q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k) + A = torch.einsum("b g h t d, b h s d -> b h t s", q, k) + else: + A = torch.matmul(q, k.transpose(-2, -1)) + if pos_k is not None: + if self.h != self.h_k: + B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k) + else: + reshape_q = ( + q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0, 1) + ) # (t1,nh,dk) + B = torch.matmul(reshape_q, pos_k.transpose(-2, -1)) # pos_k: (t1,dk,t2) + B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1)) + scores = A + B + else: + scores = A + + if relative_attention_bias is not None: + scores = scores + relative_attention_bias + + attn = masked_softmax(scores, mask) # (batch, head, time1, time2) + + self.attn = attn + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn.to(v.dtype), v) # (batch, head, time1, d_k) + if pos_v is not None: + reshape_attn = ( + p_attn.contiguous() + .view(n_batch * self.h, pos_v.size(0), pos_v.size(1)) + .transpose(0, 1) + ) # (t1, bh, t2) + + attn_v = ( + torch.matmul(reshape_attn, pos_v) + .transpose(0, 1) + .contiguous() + .view(n_batch, self.h, pos_v.size(0), self.d_k) + ) + x = x + attn_v + x = ( + x.transpose(1, 2).contiguous().view(n_batch, seq_len, self.h_k * 
self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + +def unfold_tensor(xs_pad, max_seq_len): + """ + For a given tensor with shape of (N, T, D), if sequence length T is longer than max_seq_len, + this function unfold it to a (NT', max_seq_len, D) where T' is T // max_seq_len. + Args: + xs_pad: N, T, D + """ + _, _, D = xs_pad.shape + xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T + # N x D x 1 x T => N x (D x max_seq_len) x T' + xs_pad = F.unfold( + xs_pad[..., None, :], + kernel_size=(1, max_seq_len), + stride=(1, max_seq_len), + ) + + new_bsz, _, slen = xs_pad.shape + # N x D x max_seq_len x T' + xs_pad = xs_pad.view(new_bsz, -1, max_seq_len, slen) + # N x T' x max_seq_len x D + xs_pad = xs_pad.permute(0, 3, 2, 1).contiguous() + # NT' x max_seq_len x D + xs_pad = xs_pad.view(-1, max_seq_len, D) + return xs_pad + +# conformer_encoder.py +class MultiSequential(torch.nn.Sequential): + """Multi-input multi-output torch.nn.Sequential""" + + # @torch.jit.ignore + def forward(self, *args): + """Forward method implementation.""" + for m in self: + args = m(*args) + return args + +def repeat(repeat_num, module_gen_fn): + """repeat module N times + + :param int repeat_num: repeat time + :param function module_gen_fn: function to generate module + :return: repeated modules + :rtype: MultiSequential + """ + return MultiSequential(*[module_gen_fn(i) for i in range(repeat_num)]) + +class ConformerEncoderLayer(nn.Module): + """ConformerEncoder Layer module. + for more details see conformer paper: + https://arxiv.org/abs/2005.08100 + This module implement the Conformer block layer. + + Args: + d_model: int + attention dim. + ext_pw_out_channel: int + if > 0, ext_pw_out_channel is a dim channel size + for the last pointwise conv after swish activation. + depthwise_seperable_out_channel: int + if set different to 0, the number of depthwise_seperable_out_channel + will be used as a channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + n_head: int + the number of heads for multihead attention module. + d_ffn: int + output size of the feed_forward blocks. + ext_pw_kernel_size: int + kernel size of the conv pointwise of the conformer. + kernel_size: int + kernel size. + dropout_rate: float + dropout rate. + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation + in ConvModule layer of the conformer. + default False + activation: str, optional + activation function name, + one of ["relu", "swish", "sigmoid"], + sigmoid activation is only used with "glu_in_fnn=True", + default "relu". + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + default 0. + chunk_size: int, optional + chunk_size for cnn. default 18 + conv_activation: str, optional + activation function used in ConvModule part + of the conformer, default "relu". + conv_glu_type: str, optional + activation function used for the glu inside + the ConvModule part of the conformer. + default: "sigmoid". + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. 
+ linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + attention_innner_dim: int, otional + if equal to -1, attention dim for linears k/q/v is + equal to d_model. otherwise attention_innner_dim is used. + default -1. + attention_glu_type: str, optional + activation function for glu used in the multihead attention, + default "swish". + activation_checkpointing: str, optional + a dictionarry of {"module","interval","offload"}, where + "module": str + accept ["transformer", "attention"] to select + which module should do activation checkpointing. + "interval": int, default 1, + interval of applying activation checkpointing, + interval = 1 means that we apply checkpointing + on every layer (if activation), otherwise, + we apply it every x interval. + "offload": bool, default False, + if set to True, we offload activation to cpu and + reload it during backward, otherwise, + we recalculate activation in backward. + default "". + export: bool, optional + if set to True, it remove the padding from convolutional layers + and allow the onnx conversion for inference. + default False. + use_pt_scaled_dot_product_attention: bool, optional + if set to True, use pytorch's scaled dot product attention implementation in training. + attn_group_sizes: int, optional + the number of groups to use for attention, default 1 (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attn_group_sizes < attention_heads = Grouped-Query Attention + attn_group_sizes = attenion_heads = Multi-Query Attention + """ + + def __init__( + self, + d_model=512, + ext_pw_out_channel=0, + depthwise_seperable_out_channel=256, + depthwise_multiplier=1, + n_head=4, + d_ffn=2048, + ext_pw_kernel_size=1, + kernel_size=3, + dropout_rate=0.1, + causal=False, + batch_norm=False, + activation="relu", + chunk_se=0, + chunk_size=18, + conv_activation="relu", + conv_glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + attention_innner_dim=-1, + attention_glu_type="swish", + activation_checkpointing="", + export=False, + use_pt_scaled_dot_product_attention=False, + attn_group_sizes: int = 1, + ): + super().__init__() + + self.feed_forward_in = FeedForward( + d_model=d_model, + d_inner=d_ffn, + dropout_rate=dropout_rate, + activation=activation, + bias_in_glu=bias_in_glu, + ) + + self.self_attn = encoder_checkpoint_wrapper( + activation_checkpointing, + MultiHeadedAttention, + )( + MultiHeadedAttention( + n_head, + d_model, + dropout_rate, + attention_innner_dim, + attention_glu_type, + bias_in_glu, + use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, + group_size=attn_group_sizes, + ) + ) + self.conv = ConvModule( + d_model, + ext_pw_out_channel, + depthwise_seperable_out_channel, + ext_pw_kernel_size, + kernel_size, + depthwise_multiplier, + dropout_rate, + causal, + batch_norm, + chunk_se, + chunk_size, + conv_activation, + conv_glu_type, + bias_in_glu, + linear_glu_in_convm, + export=export, + ) + + self.feed_forward_out = FeedForward( + d_model=d_model, + d_inner=d_ffn, + dropout_rate=dropout_rate, + activation=activation, + bias_in_glu=bias_in_glu, + ) + + self.layer_norm_att = nn.LayerNorm(d_model) + self.layer_norm = nn.LayerNorm(d_model) + + def forward( + self, + x, + pos_k, + pos_v, + mask, + relative_attention_bias: Optional[Tensor] = None, + ): + """ConformerEncoder forward. 
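+
+        Note: this layer follows the conformer "macaron" pattern, as in the
+        code below: x + 0.5 * feed_forward_in(x), self-attention over the
+        layer-normed result, the convolution module, x + 0.5 *
+        feed_forward_out(x), and a final LayerNorm.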
+ + Args: + x: torch.Tensor + input feature of shape (batch, max_time_in, size) + pos_k: torch.Tensor + positional key embedding. + mask: torch.Tensor + mask for x (batch, max_time_in) + relative_attention_bias: Optional[torch.Tensor] + bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2) + """ + x = x + 0.5 * self.feed_forward_in(x) + norm_x = self.layer_norm_att(x) + + x = x + self.self_attn( + norm_x, + norm_x, + norm_x, + pos_k, + pos_v, + mask, + relative_attention_bias=relative_attention_bias, + ) + x = x + self.conv(x) + x = x + 0.5 * self.feed_forward_out(x) + + out = self.layer_norm(x) + + return out, pos_k, pos_v + +class TransformerEncoderBase(abc.ABC, nn.Module): + """The Base class for Transformer based encoders + + Please set causal = True in streaming model + Args: + input_size: int + input feature dimension. + chunk_size: int, list(int) + Number of frames for each chunk + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training + Some examples for the 2 cases: + chunk_size = 12 + chunk_size = [6, 8, 12, 24] + left_chunk: int, list(int) + Number of chunks used for masking in streaming mode. + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training. When + chunk_size is a list, left_chunk must be a list with same length. + Some examples for the 2 cases: + left_chunk = 6 + left_chunk = [12, 9, 6, 3] + attention_dim: int, optional + attention dimension. default 256. + attention_heads: int, optional + the number of heads. default 4 + input_layer: str, optional + input layer type before Conformer, + one of ["linear", "conv2d", "custom", "vgg2l", "embed"], + default "conv2d" + cnn_out: int, optional + the number of CNN channels before Conformer. + default -1. + cnn_layer_norm: bool, optional + layer norm between Conformer and the first CNN. + default False. + time_reduction: int, optional + time reduction factor + default 4 + dropout_rate: float, optional + dropout rate. default 0.1 + padding_idx: int, optional + padding index for input_layer=embed + default -1 + relative_attention_bias_args: dict, optional + use more efficient scalar bias-based relative multihead attention (Q*K^T + B) + implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias + usage: relative_attention_bias_args={"type": t5/alibi} + additional method-specific arguments can be provided (see transformer_base.py) + positional_dropout_rate: float, optional + dropout rate after positional encoding. default 0.0 + nemo_conv_settings: dict, optional + A dictionary of settings for NeMo Subsampling. + default None + conv2d_extra_padding: str, optional + Add extra padding in conv2d subsampling layers. Choices are + (feat, feat_time, none, True). + if True or feat_time, the extra padding is added into non full + supraframe utts in batch. 
+ Default: none + attention_group_size: int, optional + the number of groups to use for attention, default 1 (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attention_group_size < attention_heads = Grouped-Query Attention + attention_group_size = attenion_heads = Multi-Query Attention + """ + + def __init__( + self, + input_size, + chunk_size, + left_chunk, + attention_dim=256, + attention_heads=4, + input_layer="nemo_conv", + cnn_out=-1, + cnn_layer_norm=False, + time_reduction=4, + dropout_rate=0.0, + padding_idx=-1, + relative_attention_bias_args=None, + positional_dropout_rate=0.0, + nemo_conv_settings=None, + conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", + attention_group_size=1, + encoder_embedding_config=None, + ): + super().__init__() + self.input_size = input_size + self.input_layer = input_layer + self.chunk_size = chunk_size + self.left_chunk = left_chunk + self.attention_dim = attention_dim + self.num_heads = attention_heads + self.attention_group_size = attention_group_size + self.time_reduction = time_reduction + self.nemo_conv_settings = nemo_conv_settings + self.encoder_embedding_config = encoder_embedding_config + + if self.input_layer == "nemo_conv": + default_nemo_conv_settings = { + "subsampling": "dw_striding", + "subsampling_factor": self.time_reduction, + "feat_in": input_size, + "feat_out": attention_dim, + "conv_channels": 256, + "subsampling_conv_chunking_factor": 1, + "activation": nn.ReLU(), + "is_causal": False, + } + # Override any of the defaults with the incoming, user settings + if nemo_conv_settings: + default_nemo_conv_settings.update(nemo_conv_settings) + for i in ["subsampling_factor", "feat_in", "feat_out"]: + assert ( + i not in nemo_conv_settings + ), "{i} should be specified outside of the NeMo dictionary" + + self.embed = NemoConvSubsampling( + **default_nemo_conv_settings, + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + + self.pos_emb = AbsolutePositionalEncoding(attention_dim, positional_dropout_rate) + + self.relative_attention_bias_type = ( + relative_attention_bias_args.get("type") if relative_attention_bias_args else None + ) + if self.relative_attention_bias_type == "t5": + assert ( + self.num_heads % self.attention_group_size == 0 + ), "attention_group_size must divide n_head" + self.relative_attention_bias_layer = T5RelativeAttentionLogitBias( + self.num_heads // self.attention_group_size, + max_distance=relative_attention_bias_args.get("t5_bias_max_distance", 1000), + symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False), + ) + else: + raise NotImplementedError + + + def post_init(self, init_model_config): + + pretrained_speech_encoder_path = init_model_config.get('pretrained_speech_encoder_path', None) + if pretrained_speech_encoder_path: + model_state = torch.load(pretrained_speech_encoder_path, map_location="cpu") + encoder_state_dict = {} + for k, v in model_state.items(): + if "encoder." 
in k: + tmp_k = k.replace("encoder.", "") + encoder_state_dict[tmp_k] = v + + if hasattr(self, "encoder_embedding"): + del self.encoder_embedding + self.load_state_dict(encoder_state_dict) + + if not hasattr(self, "encoder_embedding"): + self.encoder_embedding = MeanVarianceNormLayer(self.encoder_embedding_config["input_size"]) + + mean_file = init_model_config.get('mean_file', None) + invstd_file = init_model_config.get('invstd_file', None) + if mean_file is not None and invstd_file is not None: + self.encoder_embedding.load_mean_invstd(mean_file, invstd_file) + + def compute_lens_change(self, feature_lens): + """feature_lens: int + return updated feature lens. + + This used to return a different lambda function for each case that computed + the right thing. That does not work within Torchscript. If you really + need this to be faster, create nn.Module()-s for all the cases and return + one of them. Torchscript does support that. + """ + if self.input_layer == "nemo_conv": + # Handle the special causal case + subsampling_causal_cond = self.nemo_conv_settings.get("subsampling", "dw_striding") in [ + "dw_striding", + "striding", + "striding_conv1d", + ] + is_causal = self.nemo_conv_settings.get("is_causal", False) + if is_causal and subsampling_causal_cond: + lens_change = ( + torch.ceil(feature_lens / self.time_reduction).long() + if isinstance(feature_lens, Tensor) + else math.ceil(feature_lens / self.time_reduction) + ) + feature_lens_remainder = feature_lens % self.time_reduction + if isinstance(feature_lens, Tensor): + lens_change[feature_lens_remainder != 1] += 1 + elif feature_lens_remainder != 1: + lens_change += 1 + return lens_change + # ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil + # return ceil_func(feature_lens / self.time_reduction) + return math.ceil(feature_lens / self.time_reduction) + + @abc.abstractmethod + def forward(self): + """Abstract forward method implementation.""" + + def _chunk_size_selection(self, chunk_size=None, left_chunk=None): + """If chunk size is a list, we will randomly select a chunk size.""" + + if chunk_size is None: + chunk_size = self.chunk_size + if left_chunk is None: + left_chunk = self.left_chunk + if isinstance(chunk_size, list): + # Variable chunk size during training + chunk_size_index = int(torch.randint(low=0, high=len(chunk_size), size=(1,))) + chunk_size_train_eff = chunk_size[chunk_size_index] + if not isinstance(left_chunk, list): + raise ValueError("Since chunk_size is a list, left_chunk must be a list") + if len(left_chunk) != len(chunk_size): + raise ValueError( + "The length of left_chunk must be the same as length of chunk_size." 
+ ) + left_chunk_train_eff = left_chunk[chunk_size_index] + else: + chunk_size_train_eff = chunk_size + left_chunk_train_eff = left_chunk + + return chunk_size_train_eff, left_chunk_train_eff + + def _get_embed_class(self, embed): + # pylint: disable=protected-access + is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) + is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) + embed_class = embed + if is_embed_using_act_chkpt: + embed_class = embed._checkpoint_wrapped_module + if is_embed_fsdp_wrapped: + embed_class = embed.module + return embed_class + + def _forward_embeddings_core(self, input_tensor, masks): + embed_class = self._get_embed_class(self.embed) + assert isinstance(embed_class, NemoConvSubsampling) + input_tensor, masks = self.embed(input_tensor, masks) + return input_tensor, masks + + def _position_embedding(self, input_tensor): + pos_k = None + pos_v = None + if self.relative_attention_bias_layer is None: + input_tensor = self.pos_emb(input_tensor) # default to add abs sinusoid embedding + return pos_k, pos_v + + def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): + chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection( + chunk_size, left_chunk + ) + + # Create mask matrix for streaming + # S stores start index. if chunksize is 18, s is [0,18,36,....] + # chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff) + # avoid randomness when run evaluation or decoding + if self.training and np.random.rand() > 0.5: + # Either first or last chunk is not complete. + # If only the last one is not complete, EOS is not effective + chunk_start_idx = seq_len - chunk_start_idx + chunk_start_idx = chunk_start_idx[::-1] + chunk_start_idx = chunk_start_idx[:-1] + chunk_start_idx = np.insert(chunk_start_idx, 0, 0) + else: + chunk_start_idx = torch.tensor([], dtype=torch.int64) + + enc_streaming_mask = ( + adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk_train_eff) + .unsqueeze(0) + .expand([batch_size, -1, -1]) + ) + return enc_streaming_mask + + def expand_mask(self, mask): + # Convert `mask = torch.tensor([])` to `mask = torch.tensor(0)` + # and leave `mask` unmodified otherwise + orig_num_elements = mask.numel() + new_num_elements = max(1, orig_num_elements) + mask_shape = list(mask.shape) + mask_shape[0] = max(1, mask_shape[0]) + expanded_mask = torch.zeros(new_num_elements, dtype=mask.dtype) + expanded_mask[ : orig_num_elements] = mask.flatten() + expanded_mask = expanded_mask.view(mask_shape) + return expanded_mask + + def forward_embeddings(self, xs_pad, masks, chunk_size_nc=None, left_chunk_nc=None): + """Forwarding the inputs through the top embedding layers + + Args: + xs_pad: torch.Tensor + input tensor + masks: torch.Tensor + input mask + chunk_size_nc: (optional, default is None) chunk size for non-causal layers + left_chunk_nc: (optional, default is None) # of left chunks for non-causal layers + """ + # pylint: disable=R0915 + # get new lens. + seq_len = self.compute_lens_change(xs_pad.shape[1]) + if seq_len <= 0: + raise ValueError( + f"""The sequence length after time reduction is invalid: {seq_len}. + Your input feature is too short. 
Consider filtering out the very + short sentence from data loader""", + ) + + batch_size = xs_pad.shape[0] + + enc_streaming_mask = self._streaming_mask( + seq_len, batch_size, self.chunk_size, self.left_chunk + ) + + if xs_pad.is_cuda: + enc_streaming_mask = enc_streaming_mask.cuda() + xs_pad = xs_pad.cuda() + + input_tensor = xs_pad + input_tensor, masks = self._forward_embeddings_core(input_tensor, masks) + + # Select correct `hs_mask` + streaming_mask = enc_streaming_mask + expanded_masks = self.expand_mask(masks) + expanded_streaming_mask = self.expand_mask(streaming_mask) + + if_condition = torch.full_like(expanded_streaming_mask, streaming_mask.numel() > 0 and batch_size > 1).bool() + elif_condition = torch.full_like(expanded_masks, streaming_mask.numel() == 0 and batch_size > 1).bool() + else_condition = ~elif_condition + hs_mask = (expanded_masks & expanded_streaming_mask) * if_condition + expanded_masks * elif_condition + expanded_streaming_mask * else_condition + + if chunk_size_nc is not None: + enc_streaming_mask_nc = self._streaming_mask( + seq_len, batch_size, chunk_size_nc, left_chunk_nc + ) + if xs_pad.is_cuda: + enc_streaming_mask_nc = enc_streaming_mask_nc.cuda() + if masks is not None: + hs_mask_nc = masks & enc_streaming_mask_nc + else: + hs_mask_nc = enc_streaming_mask_nc + else: + hs_mask_nc = None + + pos_k, pos_v = self._position_embedding(input_tensor) + + if chunk_size_nc is None: + return input_tensor, pos_k, pos_v, hs_mask, masks + return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc + + def get_offset(self): + """Returns offset used when retaining inputs for decoding. + + This is essentially, how many additional frames have to be added to + the front-end CNN input to ensure it can produce a single output. + So if the "padding" parameter is 0, typically offset will be > 0. + """ + return get_offset(self.input_layer, self.time_reduction) + + +def get_offset(input_layer: str, time_reduction: int): + """Get an offset. We will use the offset for determining #frames of a subsampled feature. + + Args: + input_layer (str): Type of an input layer + time_reduction (int): time reduction factor for downsampling a feature + Returns: + int: offset + """ + if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4: + return 3 + if input_layer in ("conv2d",) and time_reduction == 6: + return 1 + if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8: + return 7 + return 0 + + +class ConformerEncoder(TransformerEncoderBase): + """ConformerEncoder module. + see original paper for more details: + https://arxiv.org/abs/2005.08100 + + Please set causal = True in streaming model + Args: + input_size: int + input feature dimension. + chunk_size: int, list(int) + Number of frames for each chunk + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training + Some examples for the 2 cases: + chunk_size = 12 + chunk_size = [6, 8, 12, 24] + left_chunk: int, list(int) + Number of chunks used for masking in streaming mode. + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training. When + chunk_size is a list, left_chunk must be a list with same length. + Some examples for the 2 cases: + left_chunk = 6 + left_chunk = [12, 9, 6, 3] + left_chunk: int + number of chunks used for masking in streaming mode. 
+        num_lang: int
+            This parameter is used to store the number of languages in the lang_dict,
+            only used for multiseed/multilingual models. default None.
+        attention_dim: int, optional
+            attention dimension. default 256.
+        attention_heads: int, optional
+            the number of heads. default 4
+        linear_units:
+            the number of units of position-wise feed forward.
+            default 2048
+        num_blocks:
+            number of Transformer layers. default 6
+        dropout_rate: float, optional
+            dropout rate. default 0.1
+        input_layer: str, optional
+            input layer type before Conformer,
+            one of ["linear", "conv2d", "custom", "vgg2l", "embed"],
+            default "conv2d"
+        causal: bool, optional
+            if set to True, convolutions have no access
+            to future frames. default False.
+        batch_norm: bool, optional
+            if set to True, apply batchnorm before activation
+            in the ConvModule layer of the conformer.
+            default False
+        cnn_out: int, optional
+            the number of CNN channels before Conformer.
+            default -1.
+        cnn_layer_norm: bool, optional
+            layer norm between Conformer and the first CNN.
+            default False.
+        ext_pw_out_channel: int, optional
+            the number of channels for the CNN
+            before depthwise_seperable_CNN.
+            If 0 then use linear. default 0.
+        ext_pw_kernel_size: int, optional
+            kernel size of the CNN before depthwise_seperable_CNN.
+            only works for ext_pw_out_channel > 0.
+            default 1
+        depthwise_seperable_out_channel: int, optional
+            the number of channels for
+            depthwise_seperable_CNN.
+            default 256.
+        depthwise_multiplier: int, optional
+            the channel multiplier for
+            depthwise_seperable_CNN.
+            default 1.
+        chunk_se: int, optional
+            0 for offline SE.
+            1 for streaming SE, where the mean is computed
+            from the accumulated history up to the current chunk.
+            2 for streaming SE, where the mean is computed
+            from only the current chunk.
+            default 0.
+        kernel_size: int, optional
+            the kernel size for depthwise_seperable_CNN.
+            default 3.
+        activation: str, optional
+            FeedForward block activation.
+            one of ["relu", "swish", "sigmoid"].
+            default "relu".
+        conv_activation: str, optional
+            activation function used in the ConvModule part
+            of the conformer, default "relu".
+        conv_glu_type: str, optional
+            activation used for the GLU in depthwise_seperable_CNN,
+            default "sigmoid"
+        bias_in_glu: bool, optional
+            if set to True, use an additive bias in the weight module
+            before GLU. default True
+        linear_glu_in_convm: bool, optional
+            if set to True, use the GLULinear module,
+            otherwise, use the GLUPointWiseConv module.
+            default False.
+        attention_glu_type: str
+            only works for glu_in_attention != 0.
+            default "swish".
+        export: bool, optional
+            if set to True, removes the padding from convolutional layers
+            and allows the ONNX conversion for inference.
+            default False.
+        activation_checkpointing: str, optional
+            a dictionary of {"module", "interval", "offload"}, where
+            "module": str
+                accepts ["transformer", "attention"] to select
+                which module should do activation checkpointing.
+            "interval": int, default 1,
+                interval of applying activation checkpointing,
+                interval = 1 means that we apply checkpointing
+                on every layer (if activated); otherwise,
+                we apply it once every `interval` layers.
+            "offload": bool, default False,
+                if set to True, we offload activations to cpu and
+                reload them during backward, otherwise,
+                we recalculate activations in backward.
+            default "".
+        extra_layer_output_idx: int
+            the layer index to be exposed.
+ relative_attention_bias_args: dict, optional + use more efficient scalar bias-based relative multihead attention (Q*K^T + B) + implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias + usage: relative_attention_bias_args={"type": t5/alibi} + additional method-specific arguments can be provided (see transformer_base.py) + time_reduction: int optional + time reduction factor + default 4 + use_pt_scaled_dot_product_attention: whether to use pytorch scaled dot product attention + in training. + Default: False + nemo_conv_settings: dict, optional + A dictionary of settings for NeMo Subsampling. + default: None + usage: nemo_conv_settings= + { + "subsampling": + dw_striding/striding/dw_striding_conv1d/striding_conv1d, + "conv_channels": int, + "subsampling_conv_chunking_factor": int, + "is_causal": True/False + } + conv2d_extra_padding: str, optional + Add extra padding in conv2d subsampling layers. Choices are + (feat, feat_time, none, True) + Default: none + replication_pad_for_subsample_embedding: For batched-streaming decoding, use + "replication" padding for the cache at start of utterance. + Default: False + attention_group_size: int, optional + the number of groups to use for attention, default 1 (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attention_group_size < attention_heads = Grouped-Query Attention + attention_group_size = attenion_heads = Multi-Query Attention + """ + + extra_multi_layer_output_idxs: List[int] + + def __init__( # pylint: disable-all + self, + input_size, + chunk_size, + left_chunk, + num_lang=None, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + input_layer="nemo_conv", + causal=True, + batch_norm=False, + cnn_out=-1, + cnn_layer_norm=False, + ext_pw_out_channel=0, + ext_pw_kernel_size=1, + depthwise_seperable_out_channel=256, + depthwise_multiplier=1, + chunk_se=0, + kernel_size=3, + activation="relu", + conv_activation="relu", + conv_glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + attention_glu_type="swish", + export=False, + extra_layer_output_idx=-1, + extra_multi_layer_output_idxs=[], + activation_checkpointing="", + relative_attention_bias_args=None, + time_reduction=4, + use_pt_scaled_dot_product_attention=False, + nemo_conv_settings=None, + conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", + replication_pad_for_subsample_embedding=False, + attention_group_size=1, + encoder_embedding_config=None, + ): + super().__init__( + input_size, + chunk_size, + left_chunk, + attention_dim, + attention_heads, + input_layer, + cnn_out, + cnn_layer_norm, + time_reduction, + dropout_rate=dropout_rate, + relative_attention_bias_args=relative_attention_bias_args, + positional_dropout_rate=0.0, + nemo_conv_settings=nemo_conv_settings, + conv2d_extra_padding=conv2d_extra_padding, + attention_group_size=attention_group_size, + encoder_embedding_config=encoder_embedding_config, + ) + self.num_blocks = num_blocks + self.num_lang = num_lang + self.kernel_size = kernel_size + self.embed = embedding_checkpoint_wrapper(activation_checkpointing)(self.embed) + self.replication_pad_for_subsample_embedding: bool = replication_pad_for_subsample_embedding + assert self.num_heads % attention_group_size == 0, "attention_group_size must divide n_head" + self.num_heads_k = self.num_heads // attention_group_size + + self.encoders = repeat( + num_blocks, + lambda i: encoder_checkpoint_wrapper( + activation_checkpointing, ConformerEncoderLayer, i + )( + 
ConformerEncoderLayer( + d_model=attention_dim, + ext_pw_out_channel=ext_pw_out_channel, + depthwise_seperable_out_channel=depthwise_seperable_out_channel, + depthwise_multiplier=depthwise_multiplier, + n_head=attention_heads, + d_ffn=linear_units, + ext_pw_kernel_size=ext_pw_kernel_size, + kernel_size=kernel_size, + dropout_rate=dropout_rate, + causal=causal, + batch_norm=batch_norm, + activation=activation, + chunk_se=chunk_se, + chunk_size=chunk_size, + conv_activation=conv_activation, + conv_glu_type=conv_glu_type, + bias_in_glu=bias_in_glu, + linear_glu_in_convm=linear_glu_in_convm, + attention_glu_type=attention_glu_type, + activation_checkpointing=attn_checkpointing(activation_checkpointing, i), + export=export, + use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, + attn_group_sizes=attention_group_size, + ) + ), + ) + self.extra_layer_output_idx = extra_layer_output_idx + self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs + # Make a zeros scalar we can use in get_initial_state to determine + # the device and the needed dtype: + self.register_buffer("dev_type", torch.zeros(()), persistent=False) + + for i in range(len(self.encoders)): + self.encoders[i] = self.encoders[i]._checkpoint_wrapped_module + + def init_relative_attention_bias(self, input_tensor): + if self.relative_attention_bias_layer: + return self.relative_attention_bias_layer(input_tensor) + + def calculate_hs_mask(self, xs_pad, device, mask): + max_audio_length = xs_pad.shape[1] + batch_size = xs_pad.shape[0] + enc_streaming_mask = self._streaming_mask( + max_audio_length, batch_size, self.chunk_size, self.left_chunk + ) + enc_streaming_mask = enc_streaming_mask.to(device) + + feature_lens = mask.sum(1) + padding_length = feature_lens + pad_mask = ( + torch.arange(0, max_audio_length, device=device).expand(padding_length.size(0), -1) + < padding_length.unsqueeze(1) + ) + pad_mask = pad_mask.unsqueeze(1) + pad_mask = pad_mask & enc_streaming_mask + + if_condition = torch.full_like(enc_streaming_mask, batch_size == 1).bool() + else_condition = ~if_condition + return enc_streaming_mask * if_condition + pad_mask * else_condition + + # @torch.jit.ignore + def forward(self, xs_pad, masks): + """Conformer Forward function + + Args: + xs_pad: torch.Tensor + input tensor + masks: torch.Tensor + post-embedding input lengths + """ + xs_pad = self.encoder_embedding(xs_pad) + input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(xs_pad, masks) + + # Rewritten unfold logic + ori_bz, seq_len, D = input_tensor.shape + max_seq_len = 500 # maximum position for absolute positional encoding + + chunk_pad_size = (max_seq_len - (seq_len % max_seq_len)) % max_seq_len + # the unfold op will drop residual frames, pad it to the multiple of max_seq_len + if_padding_tensor = torch.zeros((ori_bz, chunk_pad_size, D), device=input_tensor.device) + if_padded_input_tensor = torch.cat((input_tensor, if_padding_tensor), dim=1).to(input_tensor.device) + + new_bz = ori_bz * ((seq_len + chunk_pad_size) // max_seq_len) + if_unfolded_tensor = if_padded_input_tensor.reshape(new_bz, max_seq_len, D) + + # revise hs_mask here because the previous calculated hs_mask did not consider extra pad + subsampled_pad_mask = masks.squeeze(1) # [bz, subsampled_unmask_seq_len] + extra_padded_subsampled_padding_tensor = torch.zeros((ori_bz, chunk_pad_size), device=subsampled_pad_mask.device) + extra_padded_subsampled_pad_mask = torch.cat((subsampled_pad_mask, extra_padded_subsampled_padding_tensor), 
dim=1).to(subsampled_pad_mask.device) # extra padding to the pad mask + if_masks_unfold = extra_padded_subsampled_pad_mask.reshape(new_bz, max_seq_len) # unfold the pad mask like we did to the input tensor + + # Calculate values in masks_unfold to use + # If condition is true, all values in masks_unfold are left as is. + # If condition is false, all values in masks_unfold are set to 0. + condition = torch.full_like(if_masks_unfold, ori_bz != 1).bool() + masks_unfold = if_masks_unfold * condition + + if_hs_mask = self.calculate_hs_mask(if_unfolded_tensor, if_unfolded_tensor.device, masks_unfold) + + # Pad original hs_mask to be the same shape as if_hs_mask + hs_mask_dim0, hs_mask_dim1, hs_mask_dim2 = hs_mask.shape + if_hs_mask_dim0, if_hs_mask_dim1, if_hs_mask_dim2 = if_hs_mask.shape + padded_dim0 = max(hs_mask_dim0, if_hs_mask_dim0) + padded_dim1 = max(hs_mask_dim1, if_hs_mask_dim1) + padded_dim2 = max(hs_mask_dim2, if_hs_mask_dim2) + + padded_hs_mask = torch.zeros(padded_dim0, padded_dim1, padded_dim2, device=input_tensor.device) + padded_hs_mask[:hs_mask_dim0, :hs_mask_dim1, :hs_mask_dim2] = hs_mask + padded_if_hs_mask = torch.zeros(padded_dim0, padded_dim1, padded_dim2, device=input_tensor.device) + padded_if_hs_mask[:if_hs_mask_dim0, :if_hs_mask_dim1, :if_hs_mask_dim2] = if_hs_mask + + if_condition = torch.full_like(padded_if_hs_mask, seq_len > max_seq_len).bool() + else_condition = ~if_condition + chosen_padded_hs_mask = padded_if_hs_mask * if_condition + padded_hs_mask * else_condition + + # Remove any padding from chosen_padded_hs_mask + shape_dim1, shape_dim2 = min(hs_mask_dim1, if_hs_mask_dim1), min(hs_mask_dim2, if_hs_mask_dim2) + hs_mask = chosen_padded_hs_mask[:, :shape_dim1, :shape_dim2] + + # layer_emb = None + hs_mask = hs_mask.to(torch.int32).unsqueeze(1).eq(0) # (batch, 1, time1, time2) + + relative_attention_bias = self.init_relative_attention_bias(input_tensor) + + _simplified_path = ( + self.extra_layer_output_idx == -1 + and relative_attention_bias is None + ) + + if _simplified_path: + input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask) + else: + for i, layer in enumerate(self.encoders): + input_tensor, _, _ = layer( + input_tensor, + pos_k, + pos_v, + hs_mask, + relative_attention_bias=relative_attention_bias, + ) + + # if i == self.extra_layer_output_idx: + # layer_emb = input_tensor + + embed_dim = input_tensor.shape[-1] + input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim) + # if we ever padded before unfolding, we need to remove the padding + input_tensor = input_tensor[:, :seq_len, :] # original code does :-chunk_pad_size to remove padding, which makes the tensor the same shape as before any padding + + return input_tensor, masks #, layer_emb + + def gradient_checkpointing_enable(self): + pass diff --git a/onnx/vision_siglip_navit.py b/onnx/vision_siglip_navit.py new file mode 100644 index 0000000000000000000000000000000000000000..bdde84b7a23e55dce73331efaa6d898ee3a9980e --- /dev/null +++ b/onnx/vision_siglip_navit.py @@ -0,0 +1,1721 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Siglip model configuration""" + +import os +from typing import Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json", +} + + +class SiglipTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a + Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`SiglipModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 64): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + pad_token_id (`int`, *optional*, defaults to 1): + The id of the padding token in the vocabulary. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the vocabulary. 
+ Example: + ```python + >>> from transformers import SiglipTextConfig, SiglipTextModel + >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipTextConfig() + >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipTextModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_text_model" + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=64, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + # This differs from `CLIPTokenizer`'s default and from openai/siglip + # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + _flash_attn_2_enabled=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self._flash_attn_2_enabled = _flash_attn_2_enabled + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["text_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class SiglipVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a + Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. 
+ num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + Example: + ```python + >>> from transformers import SiglipVisionConfig, SiglipVisionModel + >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipVisionConfig() + >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipVisionModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + _flash_attn_2_enabled=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self._flash_attn_2_enabled = _flash_attn_2_enabled + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class SiglipConfig(PretrainedConfig): + r""" + [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to + instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. 
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SiglipTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SiglipVisionConfig`]. + kwargs (*optional*): + Dictionary of keyword arguments. + Example: + ```python + >>> from transformers import SiglipConfig, SiglipModel + >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipConfig() + >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig + >>> from transformers import SiglipTextConfig, SiglipVisionConfig + >>> # Initializing a SiglipText and SiglipVision configuration + >>> config_text = SiglipTextConfig() + >>> config_vision = SiglipVisionConfig() + >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision) + ```""" + + model_type = "siglip" + + def __init__(self, text_config=None, vision_config=None, **kwargs): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.") + + self.text_config = SiglipTextConfig(**text_config) + self.vision_config = SiglipVisionConfig(**vision_config) + + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs): + r""" + Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision + model configuration. + Returns: + [`SiglipConfig`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + +# coding=utf-8 +# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
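The configuration classes above follow the standard Hugging Face pattern: build the two sub-configs, then wrap them in a `SiglipConfig`; the sub-config `from_pretrained` overrides shown earlier pull the matching `text_config` / `vision_config` sub-dictionary out of a full `"siglip"` config dict before calling `from_dict`. A minimal usage sketch, assuming this file is importable as `vision_siglip_navit` (the import path and the dimension values are illustrative, not values used elsewhere in this patch):

    from vision_siglip_navit import SiglipConfig, SiglipTextConfig, SiglipVisionConfig

    # Build the two sub-configs explicitly, then compose them into one SiglipConfig.
    text_config = SiglipTextConfig(vocab_size=32000, hidden_size=768, num_hidden_layers=12)
    vision_config = SiglipVisionConfig(hidden_size=768, image_size=224, patch_size=16)
    config = SiglipConfig.from_text_vision_configs(text_config, vision_config)

    # The composed config round-trips through plain dicts, so the sub-configs come back
    # as SiglipTextConfig / SiglipVisionConfig instances with the values set above.
    assert config.vision_config.patch_size == 16
    assert config.text_config.vocab_size == 32000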
+""" PyTorch Siglip model.""" + + +import math +import warnings +from dataclasses import dataclass +from typing import Any, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn.init import _calculate_fan_in_and_fan_out + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + logging, + replace_return_docstrings, +) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" + +SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "google/siglip-base-patch16-224", + # See all SigLIP models at https://huggingface.co/models?filter=siglip +] + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + if tensor.dtype in [torch.float16, torch.bfloat16]: + # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu + og_dtype = tensor.dtype + tensor = tensor.to(torch.float32) + tensor.erfinv_() + tensor = tensor.to(og_dtype) + else: + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + if tensor.dtype == torch.float16: + # The `clamp_` op is not (yet?) 
defined in float16+cpu + tensor = tensor.to(torch.float32) + tensor.clamp_(min=a, max=b) + tensor = tensor.to(torch.float16) + else: + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsquently scaled and shifted by the mean and std args. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip +class SiglipVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip +class SiglipTextModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip +class SiglipOutput(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. + image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. + text_model_output(`BaseModelOutputWithPooling`): + The output of the [`SiglipTextModel`]. 
+ vision_model_output(`BaseModelOutputWithPooling`): + The output of the [`SiglipVisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: torch.FloatTensor = None + logits_per_text: torch.FloatTensor = None + text_embeds: torch.FloatTensor = None + image_embeds: torch.FloatTensor = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +@torch.jit.script_if_tracing +def filter_position_ids(patch_attention_mask: torch.Tensor, position_ids: torch.Tensor, boundaries: torch.Tensor, num_patches_per_side: int): + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids + return position_ids + + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: + batch_size = pixel_values.size(0) + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + position_ids = torch.full( + size=( + batch_size, + max_nb_patches_h * max_nb_patches_w, + ), + fill_value=0, + ) + + position_ids = filter_position_ids(patch_attention_mask, position_ids, boundaries, self.num_patches_per_side) + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip +class SiglipTextEmbeddings(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", 
torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + 
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class SiglipFlashAttention2(SiglipAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False # Hack to make sure we don't use a causal mask + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # if past_key_value is not None: + # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to the fact" + " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
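+        # SigLIP attention is bidirectional: `self.is_causal` is forced to False in __init__,
+        # so `causal` evaluates to False below.
+        # With an attention mask, the padded (batch, seq_len, num_heads, head_dim) tensors are
+        # unpadded into packed (total_tokens, num_heads, head_dim) tensors, flash_attn_varlen_func
+        # is called with the cumulative sequence lengths, and the output is re-padded back to
+        # (batch, query_length, num_heads, head_dim). Without a mask, flash_attn_func is called
+        # directly on the padded tensors.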
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip +class SiglipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip +class SiglipEncoderLayer(nn.Module): + def __init__(self, config: SiglipConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = ( + SiglipAttention(config) + if not getattr(config, "_flash_attn_2_enabled", False) + else SiglipFlashAttention2(config) + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SiglipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = SiglipConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + + if isinstance(module, SiglipVisionEmbeddings): + width = ( + self.config.vision_config.hidden_size + if isinstance(self.config, SiglipConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.normal_(module.q_proj.weight) + nn.init.normal_(module.k_proj.weight) + nn.init.normal_(module.v_proj.weight) + nn.init.normal_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.normal_(module.fc1.weight) + nn.init.normal_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, SiglipMultiheadAttentionPoolingHead): + nn.init.normal_(module.probe.data) + nn.init.normal_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, SiglipModel): + logit_scale_init = torch.tensor(0.0) + module.logit_scale.data.fill_(logit_scale_init) + module.logit_bias.data.zero_() + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +SIGLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`SiglipConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SIGLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. 
+ [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +SIGLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +SIGLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip +class SiglipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SiglipEncoderLayer`]. 
+ Args: + config: SiglipConfig + """ + + def __init__(self, config: SiglipConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class SiglipTextTransformer(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = SiglipTextEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.head = nn.Linear(embed_dim, embed_dim) + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + 
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. + # expand attention_mask + if attention_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # Assuming "sticky" EOS tokenization, last token is always EOS. + pooled_output = last_hidden_state[:, -1, :] + pooled_output = self.head(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The text model from SigLIP without any head or projection on top.""", + SIGLIP_START_DOCSTRING, +) +class SiglipTextModel(SiglipPreTrainedModel): + config_class = SiglipTextConfig + + _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"] + + def __init__(self, config: SiglipTextConfig): + super().__init__(config) + self.text_model = SiglipTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + Examples: + ```python + >>> from transformers import AutoTokenizer, SiglipTextModel + >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = 
AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SiglipVisionTransformer(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.head = SiglipMultiheadAttentionPoolingHead(config) + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_attention_mask = torch.ones( + size=( + batch_size, + pixel_values.size(2) // self.config.patch_size, + pixel_values.size(3) // self.config.patch_size, + ), + dtype=torch.bool, + device=pixel_values.device, + ) + + hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. 
attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + attention_mask=None + else: + attention_mask = ( + _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) + if not self.config._flash_attn_2_enabled + else patch_attention_mask + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head( + hidden_state=last_hidden_state, + attention_mask=patch_attention_mask, + ) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + def forward(self, hidden_state, attention_mask): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention( + query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask + )[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +@add_start_docstrings( + """The vision model from SigLIP without any head or projection on top.""", + SIGLIP_START_DOCSTRING, +) +class SiglipVisionModel(SiglipPreTrainedModel): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: SiglipVisionConfig): + super().__init__(config) + + self.vision_model = SiglipVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, SiglipVisionModel + >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> last_hidden_state = 
outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings(SIGLIP_START_DOCSTRING) +class SiglipModel(SiglipPreTrainedModel): + config_class = SiglipConfig + + def __init__(self, config: SiglipConfig): + super().__init__(config) + + if not isinstance(config.text_config, SiglipTextConfig): + raise ValueError( + "config.text_config is expected to be of type SiglipTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, SiglipVisionConfig): + raise ValueError( + "config.vision_config is expected to be of type SiglipVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + self.text_model = SiglipTextTransformer(text_config) + self.vision_model = SiglipVisionTransformer(vision_config) + + self.logit_scale = nn.Parameter(torch.randn(1)) + self.logit_bias = nn.Parameter(torch.randn(1)) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`SiglipTextModel`]. + Examples: + ```python + >>> from transformers import AutoTokenizer, AutoModel + >>> import torch + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... text_features = model.get_text_features(**inputs) + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. 
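+ # Note: the pooled output returned below (text_outputs[1]) is the EOS-token
+ # hidden state passed through the text transformer's linear `head`; unlike CLIP,
+ # SigLIP applies no additional text projection here.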
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`SiglipVisionModel`]. + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt") + >>> with torch.no_grad(): + ... image_features = model.get_image_features(**inputs) + ```""" + # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. 
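+ # Note: the pooled output returned below (vision_outputs[1]) comes from
+ # SiglipMultiheadAttentionPoolingHead, a learned probe attending over all patch
+ # tokens; unlike CLIP, SigLIP applies no additional visual projection here.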
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SiglipOutput]: + r""" + Returns: + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] + >>> # important: we pass `padding=max_length` since the model was trained with this + >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image + >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 31.9% that image 0 is 'a photo of 2 cats' + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. 
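+ # Note: the pairwise logits computed below use a learned logit_scale and
+ # logit_bias and are meant to be passed through a sigmoid (see the docstring
+ # example), rather than a batch-wise softmax as in CLIP.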
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + text_embeds = text_outputs[1] + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + raise NotImplementedError("SigLIP loss to be implemented") + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return SiglipOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +def get_siglip_vision_model(_flash_attn_2_enabled=True, **kwargs): + siglip_vision_config = { + "hidden_size": 1152, + "image_size": 448, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14, + } + + model_config = SiglipVisionConfig(**siglip_vision_config, _flash_attn_2_enabled=_flash_attn_2_enabled, **kwargs) + + vision_model = SiglipVisionModel(model_config).vision_model + + return vision_model
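+
+
+# Illustrative only: a minimal smoke test for get_siglip_vision_model, sketched as an
+# assumption about how the vision tower could be exercised before ONNX export; it is
+# not part of the builder flow. Flash attention is disabled so it can run on CPU, and
+# the 448x448 input matches the config above (14-pixel patches -> a 32x32 grid).
+if __name__ == "__main__":
+    vision_model = get_siglip_vision_model(_flash_attn_2_enabled=False)
+    vision_model.eval()
+    dummy_pixel_values = torch.randn(1, 3, 448, 448)
+    with torch.no_grad():
+        outputs = vision_model(pixel_values=dummy_pixel_values)
+    # Expected shapes: (1, 1024, 1152) patch features, (1, 1152) pooled vector.
+    print(outputs.last_hidden_state.shape, outputs.pooler_output.shape)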