Update qg_pipeline.py
Browse files- qg_pipeline.py +2 -2
qg_pipeline.py
CHANGED
|
@@ -12,7 +12,7 @@ class AEHandler:
|
|
| 12 |
def __init__(self, model, tokenizer):
|
| 13 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(model)
|
| 14 |
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
| 15 |
-
self.device = torch.device('
|
| 16 |
self.model.to(self.device)
|
| 17 |
|
| 18 |
def __call__(self, context):
|
|
@@ -62,7 +62,7 @@ class QGHandler:
|
|
| 62 |
def __init__(self, model, tokenizer):
|
| 63 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(model)
|
| 64 |
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
| 65 |
-
self.device = torch.device('
|
| 66 |
self.model.to(self.device)
|
| 67 |
|
| 68 |
def __call__(self, answers, context):
|
|
|
|
| 12 |
def __init__(self, model, tokenizer):
|
| 13 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(model)
|
| 14 |
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
| 15 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 16 |
self.model.to(self.device)
|
| 17 |
|
| 18 |
def __call__(self, context):
|
|
|
|
| 62 |
def __init__(self, model, tokenizer):
|
| 63 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(model)
|
| 64 |
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
| 65 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 66 |
self.model.to(self.device)
|
| 67 |
|
| 68 |
def __call__(self, answers, context):
|