Spaces:
Runtime error
Runtime error
π¨ π
Browse files
Signed-off-by: peter szemraj <peterszemraj@gmail.com>
- aggregate.py +31 -18
aggregate.py
CHANGED
|
@@ -1,3 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import pprint as pp
|
| 2 |
import logging
|
| 3 |
import time
|
|
@@ -14,10 +23,15 @@ logging.basicConfig(
|
|
| 14 |
|
| 15 |
|
| 16 |
class BatchAggregator:
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
GENERIC_CONFIG = GenerationConfig(
|
| 22 |
num_beams=8,
|
| 23 |
early_stopping=True,
|
|
@@ -29,10 +43,23 @@ class BatchAggregator:
|
|
| 29 |
no_repeat_ngram_size=4,
|
| 30 |
encoder_no_repeat_ngram_size=5,
|
| 31 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
def __init__(
|
| 34 |
self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
|
| 35 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
self.device = None
|
| 37 |
self.is_compiled = False
|
| 38 |
self.logger = logging.getLogger(__name__)
|
|
@@ -125,20 +152,6 @@ class BatchAggregator:
|
|
| 125 |
"""
|
| 126 |
self.aggregator.model.generation_config = self.GENERIC_CONFIG
|
| 127 |
|
| 128 |
-
if "bart" in self.model_name.lower():
|
| 129 |
-
self.logger.info("Using BART model, updating generation config")
|
| 130 |
-
upd = {
|
| 131 |
-
"num_beams": 8,
|
| 132 |
-
"repetition_penalty": 1.3,
|
| 133 |
-
"length_penalty": 1.0,
|
| 134 |
-
"_from_model_config": False,
|
| 135 |
-
"max_new_tokens": 256,
|
| 136 |
-
"min_new_tokens": 32,
|
| 137 |
-
"no_repeat_ngram_size": 3,
|
| 138 |
-
"encoder_no_repeat_ngram_size": 6,
|
| 139 |
-
} # TODO: clean up
|
| 140 |
-
self.aggregator.model.generation_config.update(**upd)
|
| 141 |
-
|
| 142 |
if (
|
| 143 |
"large"
|
| 144 |
or "xl" in self.model_name.lower()
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
aggregate.py is a module for aggregating text from multiple sources, or multiple parts of a single source.
|
| 3 |
+
Primary usage is through the BatchAggregator class.
|
| 4 |
+
|
| 5 |
+
How it works:
|
| 6 |
+
1. We tell the language model to do it.
|
| 7 |
+
2. The language model does it.
|
| 8 |
+
3. Yaay!
|
| 9 |
+
"""
|
| 10 |
import pprint as pp
|
| 11 |
import logging
|
| 12 |
import time
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
class BatchAggregator:
|
| 26 |
+
"""
|
| 27 |
+
BatchAggregator is a class for aggregating text from multiple sources.
|
| 28 |
+
|
| 29 |
+
Usage:
|
| 30 |
+
>>> from aggregate import BatchAggregator
|
| 31 |
+
>>> aggregator = BatchAggregator()
|
| 32 |
+
>>> aggregator.aggregate(["This is a test", "This is another test"])
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
GENERIC_CONFIG = GenerationConfig(
|
| 36 |
num_beams=8,
|
| 37 |
early_stopping=True,
|
|
|
|
| 43 |
no_repeat_ngram_size=4,
|
| 44 |
encoder_no_repeat_ngram_size=5,
|
| 45 |
)
|
| 46 |
+
CONFIGURED_MODELS = [
|
| 47 |
+
"pszemraj/bart-large-mnli-dolly_hhrlhf-v1",
|
| 48 |
+
"pszemraj/bart-base-instruct-dolly_hhrlhf",
|
| 49 |
+
"pszemraj/flan-t5-large-instruct-dolly_hhrlhf",
|
| 50 |
+
"pszemraj/flan-t5-base-instruct-dolly_hhrlhf",
|
| 51 |
+
] # these have generation configs defined for this task in their model repos
|
| 52 |
+
|
| 53 |
+
DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:"
|
| 54 |
|
| 55 |
def __init__(
|
| 56 |
self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
|
| 57 |
):
|
| 58 |
+
"""
|
| 59 |
+
__init__ initializes the BatchAggregator class.
|
| 60 |
+
|
| 61 |
+
:param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
|
| 62 |
+
"""
|
| 63 |
self.device = None
|
| 64 |
self.is_compiled = False
|
| 65 |
self.logger = logging.getLogger(__name__)
|
|
|
|
| 152 |
"""
|
| 153 |
self.aggregator.model.generation_config = self.GENERIC_CONFIG
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
if (
|
| 156 |
"large"
|
| 157 |
or "xl" in self.model_name.lower()
|