Commit 
							
							·
						
						6069b37
	
1
								Parent(s):
							
							d7f24a6
								
add `apply_rotary_pos_emb` API to make it adapt to transformers
Browse files- rotary-xpu/rotary_xpu.cpp +38 -0
- rotary/rotary_cuda.cu +38 -0
    	
        rotary-xpu/rotary_xpu.cpp
    CHANGED
    
    | @@ -38,3 +38,41 @@ void _apply_rotary(torch::Tensor const &x1, torch::Tensor const &x2, | |
| 38 | 
             
                    });
         | 
| 39 | 
             
                }
         | 
| 40 | 
             
            }
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 38 | 
             
                    });
         | 
| 39 | 
             
                }
         | 
| 40 | 
             
            }
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            std::tuple<torch::Tensor, torch::Tensor> apply_rotary_pos_emb(
         | 
| 43 | 
            +
                torch::Tensor const &q, torch::Tensor const &k,
         | 
| 44 | 
            +
                torch::Tensor const &cos, torch::Tensor const &sin,
         | 
| 45 | 
            +
                torch::Tensor const &position_ids, int64_t unsqueeze_dim) {
         | 
| 46 | 
            +
                
         | 
| 47 | 
            +
                // Handle unsqueeze_dim parameter
         | 
| 48 | 
            +
                auto cos_unsqueezed = cos.unsqueeze(unsqueeze_dim);
         | 
| 49 | 
            +
                auto sin_unsqueezed = sin.unsqueeze(unsqueeze_dim);
         | 
| 50 | 
            +
                
         | 
| 51 | 
            +
                // Clone inputs since we'll modify them
         | 
| 52 | 
            +
                auto q_rotated = q.clone();
         | 
| 53 | 
            +
                auto k_rotated = k.clone();
         | 
| 54 | 
            +
                
         | 
| 55 | 
            +
                // Get half dimension for rotation
         | 
| 56 | 
            +
                int64_t half_dim = q.size(-1) / 2;
         | 
| 57 | 
            +
                
         | 
| 58 | 
            +
                // Split Q and K for rotation
         | 
| 59 | 
            +
                auto q1 = q_rotated.slice(-1, 0, half_dim);
         | 
| 60 | 
            +
                auto q2 = q_rotated.slice(-1, half_dim, q.size(-1));
         | 
| 61 | 
            +
                auto k1 = k_rotated.slice(-1, 0, half_dim);
         | 
| 62 | 
            +
                auto k2 = k_rotated.slice(-1, half_dim, k.size(-1));
         | 
| 63 | 
            +
                
         | 
| 64 | 
            +
                // Make sure cos/sin match the half dimension
         | 
| 65 | 
            +
                auto cos_final = cos_unsqueezed;
         | 
| 66 | 
            +
                auto sin_final = sin_unsqueezed;
         | 
| 67 | 
            +
                if (cos_unsqueezed.size(-1) != half_dim) {
         | 
| 68 | 
            +
                    // Trim cos/sin to match half_dim
         | 
| 69 | 
            +
                    cos_final = cos_unsqueezed.slice(-1, 0, half_dim);
         | 
| 70 | 
            +
                    sin_final = sin_unsqueezed.slice(-1, 0, half_dim);
         | 
| 71 | 
            +
                }
         | 
| 72 | 
            +
                
         | 
| 73 | 
            +
                // Apply rotary embedding using our kernel
         | 
| 74 | 
            +
                _apply_rotary(q1, q2, cos_final, sin_final, q1, q2, false);
         | 
| 75 | 
            +
                _apply_rotary(k1, k2, cos_final, sin_final, k1, k2, false);
         | 
| 76 | 
            +
                
         | 
| 77 | 
            +
                return std::make_tuple(q_rotated, k_rotated);
         | 
| 78 | 
            +
            }
         | 
    	
        rotary/rotary_cuda.cu
    CHANGED
    
    | @@ -43,3 +43,41 @@ void _apply_rotary(torch::Tensor const &x1, torch::Tensor const &x2, | |
| 43 | 
             
                    });
         | 
| 44 | 
             
                }
         | 
| 45 | 
             
            }
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 43 | 
             
                    });
         | 
| 44 | 
             
                }
         | 
| 45 | 
             
            }
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            std::tuple<torch::Tensor, torch::Tensor> apply_rotary_pos_emb(
         | 
| 48 | 
            +
                torch::Tensor const &q, torch::Tensor const &k,
         | 
| 49 | 
            +
                torch::Tensor const &cos, torch::Tensor const &sin,
         | 
| 50 | 
            +
                torch::Tensor const &position_ids, int64_t unsqueeze_dim) {
         | 
| 51 | 
            +
                
         | 
| 52 | 
            +
                // Handle unsqueeze_dim parameter
         | 
| 53 | 
            +
                auto cos_unsqueezed = cos.unsqueeze(unsqueeze_dim);
         | 
| 54 | 
            +
                auto sin_unsqueezed = sin.unsqueeze(unsqueeze_dim);
         | 
| 55 | 
            +
                
         | 
| 56 | 
            +
                // Clone inputs since we'll modify them
         | 
| 57 | 
            +
                auto q_rotated = q.clone();
         | 
| 58 | 
            +
                auto k_rotated = k.clone();
         | 
| 59 | 
            +
                
         | 
| 60 | 
            +
                // Get half dimension for rotation
         | 
| 61 | 
            +
                int64_t half_dim = q.size(-1) / 2;
         | 
| 62 | 
            +
                
         | 
| 63 | 
            +
                // Split Q and K for rotation
         | 
| 64 | 
            +
                auto q1 = q_rotated.slice(-1, 0, half_dim);
         | 
| 65 | 
            +
                auto q2 = q_rotated.slice(-1, half_dim, q.size(-1));
         | 
| 66 | 
            +
                auto k1 = k_rotated.slice(-1, 0, half_dim);
         | 
| 67 | 
            +
                auto k2 = k_rotated.slice(-1, half_dim, k.size(-1));
         | 
| 68 | 
            +
                
         | 
| 69 | 
            +
                // Make sure cos/sin match the half dimension
         | 
| 70 | 
            +
                auto cos_final = cos_unsqueezed;
         | 
| 71 | 
            +
                auto sin_final = sin_unsqueezed;
         | 
| 72 | 
            +
                if (cos_unsqueezed.size(-1) != half_dim) {
         | 
| 73 | 
            +
                    // Trim cos/sin to match half_dim
         | 
| 74 | 
            +
                    cos_final = cos_unsqueezed.slice(-1, 0, half_dim);
         | 
| 75 | 
            +
                    sin_final = sin_unsqueezed.slice(-1, 0, half_dim);
         | 
| 76 | 
            +
                }
         | 
| 77 | 
            +
                
         | 
| 78 | 
            +
                // Apply rotary embedding using our kernel
         | 
| 79 | 
            +
                _apply_rotary(q1, q2, cos_final, sin_final, q1, q2, false);
         | 
| 80 | 
            +
                _apply_rotary(k1, k2, cos_final, sin_final, k1, k2, false);
         | 
| 81 | 
            +
                
         | 
| 82 | 
            +
                return std::make_tuple(q_rotated, k_rotated);
         | 
| 83 | 
            +
            }
         | 
