Spaces:
Runtime error
Runtime error
✨ add ability to force CPU
Browse files

Signed-off-by: peter szemraj <peterszemraj@gmail.com>
- aggregate.py +9 -2
aggregate.py
CHANGED
|
@@ -54,15 +54,22 @@ class BatchAggregator:
|
|
| 54 |
DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:"
|
| 55 |
|
| 56 |
def __init__(
|
| 57 |
-
self,
|
|
|
|
|
|
|
|
|
|
| 58 |
):
|
| 59 |
"""
|
| 60 |
__init__ initializes the BatchAggregator class.
|
| 61 |
|
| 62 |
:param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
|
|
|
|
| 63 |
"""
|
| 64 |
self.device = None
|
| 65 |
self.is_compiled = False
|
|
|
|
|
|
|
|
|
|
| 66 |
self.logger = logging.getLogger(__name__)
|
| 67 |
self.init_model(model_name)
|
| 68 |
|
|
@@ -105,7 +112,7 @@ class BatchAggregator:
|
|
| 105 |
|
| 106 |
:raises Exception: if the pipeline cannot be created
|
| 107 |
"""
|
| 108 |
-
self.device = 0 if torch.cuda.is_available() else -1
|
| 109 |
try:
|
| 110 |
self.logger.info(
|
| 111 |
f"Creating pipeline with model {model_name} on device {self.device}"
|
|
|
|
| 54 |
DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:"
|
| 55 |
|
| 56 |
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1",
|
| 59 |
+
force_cpu: bool = False,
|
| 60 |
+
**kwargs,
|
| 61 |
):
|
| 62 |
"""
|
| 63 |
__init__ initializes the BatchAggregator class.
|
| 64 |
|
| 65 |
:param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
|
| 66 |
+
:param bool force_cpu: force the model to run on CPU, default: False
|
| 67 |
"""
|
| 68 |
self.device = None
|
| 69 |
self.is_compiled = False
|
| 70 |
+
self.model_name = None
|
| 71 |
+
self.aggregator = None
|
| 72 |
+
self.force_cpu = force_cpu
|
| 73 |
self.logger = logging.getLogger(__name__)
|
| 74 |
self.init_model(model_name)
|
| 75 |
|
|
|
|
| 112 |
|
| 113 |
:raises Exception: if the pipeline cannot be created
|
| 114 |
"""
|
| 115 |
+
self.device = 0 if torch.cuda.is_available() and not self.force_cpu else -1
|
| 116 |
try:
|
| 117 |
self.logger.info(
|
| 118 |
f"Creating pipeline with model {model_name} on device {self.device}"
|