Spaces: Running on Zero
| import os | |
| import torch | |
| import torch.nn.functional as F | |
# Attention kernel, selected once at import time. Defaults to PyTorch's
# built-in F.scaled_dot_product_attention; setting the environment variable
# CA_USE_SAGEATTN=1 swaps in SageAttention's `sageattn` instead.
scaled_dot_product_attention = F.scaled_dot_product_attention

if os.environ.get('CA_USE_SAGEATTN', '0') == '1':
    try:
        from sageattention import sageattn
    except ImportError as err:
        # Chain the original error so the underlying import failure stays
        # visible in the traceback instead of only the generic hint.
        raise ImportError(
            'Please install the package "sageattention" to use CA_USE_SAGEATTN=1.'
        ) from err
    # NOTE(review): callers visible here pass only (q, k, v) positionally;
    # this assumes sageattn's defaults (tensor layout, causality, scale) match
    # F.scaled_dot_product_attention's — confirm for the sageattention version in use.
    scaled_dot_product_attention = sageattn
class CrossAttentionProcessor:
    """Attention processor that delegates to the module-level attention kernel.

    Only ``q``, ``k`` and ``v`` are forwarded, so the kernel's own defaults
    apply to everything else (mask, dropout, causality, scale).
    """

    def __call__(self, attn, q, k, v):
        """Run attention over the given query, key and value tensors.

        Args:
            attn: Unused by this processor.
            q: Query tensor, forwarded positionally.
            k: Key tensor, forwarded positionally.
            v: Value tensor, forwarded positionally.
                NOTE(review): presumably (batch, heads, seq, head_dim) as the
                underlying kernel expects — confirm against callers.

        Returns:
            The result of ``scaled_dot_product_attention(q, k, v)``.
        """
        # Resolved from module scope at call time, so this uses whichever
        # backend was chosen when the module was imported.
        return scaled_dot_product_attention(q, k, v)