Spaces:
Runtime error
Runtime error
| import random | |
| import numpy as np | |
| from nltk import word_tokenize | |
| from collections import defaultdict | |
| from copy import deepcopy | |
| import tqdm | |
class PunktTokenizer:
    """Tokenizer that applies NLTK's Punkt-based word_tokenize to each text."""

    def __call__(self, texts):
        tokenized = []
        for text in texts:
            tokenized.append(word_tokenize(text))
        return tokenized
class WhiteSpaceTokenizer:
    """Tokenizer that splits each text on runs of whitespace."""

    def __call__(self, texts):
        tokenized = []
        for text in texts:
            tokenized.append(text.split())
        return tokenized
class SearchState:
    """Trace of one hill-climbing search over binary token masks.

    Records every (mask, summary, score) triple visited, remembers which
    step achieved the highest score, and flags when the search terminated.
    """

    def __init__(self, tokens):
        # tokens: the tokenized source sentence this state searches over
        self.tokens = tokens
        self.masks = []          # 0/1 selection mask recorded at each step
        self.mask_set = set()    # tuple-ized masks for O(1) seen-before checks
        self.summaries = []      # summary string produced at each step
        self.scores = []         # objective score obtained at each step
        self.best_step = None    # index of the highest-scoring step so far
        self.terminated = False  # set externally when no unseen neighbor remains
        self.step = 0

    def update(self, mask, summary, score):
        """Record one search step and refresh the running best index."""
        is_new_best = self.best_step is None or score > self.best_score()
        if is_new_best:
            self.best_step = self.step
        self.masks.append(mask)
        self.mask_set.add(tuple(mask))
        self.summaries.append(summary)
        self.scores.append(score)
        self.step += 1

    def best_mask(self):
        """Return the mask from the best-scoring step."""
        return self.masks[self.best_step]

    def best_score(self):
        """Return the highest score recorded so far."""
        return self.scores[self.best_step]

    def best_summary(self):
        """Return the summary from the best-scoring step."""
        return self.summaries[self.best_step]

    def to_dict(self):
        """Serialize the full search trace together with the best result."""
        return {
            "scores": self.scores,
            "masks": self.masks,
            "summaries": self.summaries,
            "best_summary": self.best_summary(),
            "best_score": self.best_score(),
        }
class DynamicRestartHCSC:
    """Dynamic-restart hill climbing with stochastic candidate selection.

    Runs a batched hill-climbing search over binary token masks (one search
    per input sentence). When a search exhausts all unseen neighbor states
    it is restarted from scratch, and the traces of every (re)start are
    aggregated into a single result per sentence.
    """

    def __init__(self, tokenizer, objective):
        # tokenizer: callable mapping list[str] -> list[list[str]] of tokens
        # objective: callable scoring (sentences, summaries) -> (scores, _)
        self.tokenizer = tokenizer
        self.objective = objective
        # Max attempts to find an unseen neighbor before declaring termination.
        self.n_trials = 100

    def _mask_to_summary(self, mask, tokens):
        """Join the tokens selected by a 0/1 mask into a summary string."""
        summary = [tok for tok, keep in zip(tokens, mask) if keep == 1]
        return " ".join(summary)

    def _sample(self, state, sent_len, target_len, from_scratch=False):
        """Propose the next mask for `state`.

        Swaps one selected word for an unselected one, starting from the
        best mask seen so far. On the first step (or when `from_scratch`
        is true) a uniformly random mask with `target_len` ones is drawn.

        Returns:
            tuple (mask, terminated): the proposed 0/1 mask and whether the
            search is exhausted (no unseen neighbor state could be found).
        """
        # Degenerate case: the target is at least the sentence length, so
        # select every token and stop searching.
        if target_len >= sent_len:
            state.terminated = True
            return [1] * sent_len, True
        if state.step == 0 or from_scratch:
            indices = list(range(sent_len))
            sampled = set(random.sample(indices, min(target_len, sent_len)))
            return [int(i in sampled) for i in indices], False
        mask = state.masks[state.best_step]
        one_indices = [i for i, bit in enumerate(mask) if bit == 1]
        zero_indices = [i for i, bit in enumerate(mask) if bit == 0]
        # BUG FIX: the original returned a bare `mask` here, which broke
        # tuple unpacking at the call site. No swap is possible when every
        # position is selected, so report termination explicitly.
        if not zero_indices:
            return mask, True
        # Try to find an unseen neighbor state, heuristically bounded by a
        # fixed number of trials.
        for _ in range(self.n_trials):
            i = random.choice(one_indices)
            j = random.choice(zero_indices)
            new_mask = mask.copy()
            new_mask[i] = 0
            new_mask[j] = 1
            if tuple(new_mask) not in state.mask_set:
                return new_mask, False
        # No unknown neighbor state found within the trial budget: terminate.
        return mask, True

    def aggregate_states(self, states):
        """Flatten the traces of several (re)started searches into one dict."""
        masks = [m for s in states for m in s.masks]
        summaries = [x for s in states for x in s.summaries]
        scores = [x for s in states for x in s.scores]
        best_step = np.argmax(scores)
        return {
            "masks": masks,
            "summaries": summaries,
            "scores": scores,
            "best_score": scores[best_step],
            "best_summary": summaries[best_step],
        }

    def __call__(
        self,
        sentences,
        target_lens,
        n_steps=100,
        verbose=False,
        return_states=False,
    ):
        """Run `n_steps` of batched hill climbing over `sentences`.

        Args:
            sentences: raw input sentences to compress.
            target_lens: desired summary length in tokens, per sentence.
            n_steps: number of search steps to run.
            verbose: print per-step restart and score information.
            return_states: kept for interface compatibility (unused).

        Returns:
            One aggregated result dict per sentence (see aggregate_states).
        """
        tok_sentences = self.tokenizer(sentences)
        batch_size = len(sentences)
        terminated_states = [[] for _ in range(batch_size)]
        states = [SearchState(s) for s in tok_sentences]
        for t in tqdm.tqdm(list(range(1, n_steps + 1))):
            masks = []
            for i in range(batch_size):
                # Archive exhausted searches and restart them from scratch.
                if states[i].terminated:
                    if verbose:
                        print(f"step {t}, restarting state {i} with score {states[i].best_score()}")
                    terminated_states[i].append(states[i])
                    states[i] = SearchState(tok_sentences[i])
                mask, terminated = self._sample(
                    states[i],
                    sent_len=len(tok_sentences[i]),
                    target_len=target_lens[i],
                )
                states[i].terminated = terminated
                masks.append(mask)
            summaries = [
                self._mask_to_summary(m, tokens)
                for m, tokens in zip(masks, tok_sentences)
            ]
            scores, _ = self.objective(sentences, summaries)
            if verbose:
                print(f"t={t}")
                for i in range(batch_size):
                    print(f"[{scores[i]:.3f}][{summaries[i]}]")
                print()
            for i in range(batch_size):
                states[i].update(masks[i], summaries[i], scores[i])
        # Archive the final (possibly unterminated) state of every search.
        for i in range(batch_size):
            terminated_states[i].append(states[i])
        output_states = [
            self.aggregate_states(i_states) for i_states in terminated_states
        ]
        return output_states