Commit
·
1e1a20a
1
Parent(s):
b6d9ff2
Update README.md for flash attn
Browse files
README.md
CHANGED
|
@@ -105,10 +105,16 @@ triton==2.0.0.dev20221202
|
|
| 105 |
|
| 106 |
Then, move the model to `bfloat16` and use it as follows:
|
| 107 |
```python
|
| 108 |
- from transformers import AutoModelForCausalLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
# load model
|
| 111 |
- model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b',
|
| 112 |
model.to(device='cuda:0', dtype=torch.bfloat16)
|
| 113 |
|
| 114 |
# forward pass
|
|
|
|
| 105 |
|
| 106 |
Then, move the model to `bfloat16` and use it as follows:
|
| 107 |
```python
|
| 108 |
+ from transformers import AutoModelForCausalLM, AutoConfig
|
| 109 |
+
|
| 110 |
+ config = AutoConfig.from_pretrained(
|
| 111 |
+     "replit/replit-code-v1-3b",
|
| 112 |
+     trust_remote_code=True
|
| 113 |
+ )
|
| 114 |
+ config.attn_config['attn_impl'] = 'triton'
|
| 115 |
|
| 116 |
# load model
|
| 117 |
+ model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', config=config, trust_remote_code=True)
|
| 118 |
model.to(device='cuda:0', dtype=torch.bfloat16)
|
| 119 |
|
| 120 |
# forward pass
|