def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
    """Truncate a pair of token sequences in place to fit a length budget.

    Tokens are removed one at a time from whichever sequence is currently
    longer (``tokens_b`` when the lengths are equal). Each removal drops
    either the first or the last token of that sequence, chosen with equal
    probability via ``np_rng.random()``.

    Args:
        tokens_a: First token sequence; mutated in place.
        tokens_b: Second token sequence; mutated in place.
        len_a: Current length of ``tokens_a`` as tracked by the caller
            (presumably ``len(tokens_a)`` — TODO confirm at call sites).
        len_b: Current length of ``tokens_b`` as tracked by the caller.
        max_num_tokens: Maximum allowed combined length of both sequences.
        np_rng: Random source exposing ``random()`` that returns a float in
            [0, 1) (e.g. a ``numpy.random.RandomState``).

    Returns:
        True if any tokens were removed, False if the pair already fit.

    Raises:
        ValueError: If ``len_a`` is not positive or ``max_num_tokens`` is
            negative.
    """
    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if len_a <= 0:
        raise ValueError(f"len_a must be positive, got {len_a}")
    if max_num_tokens < 0:
        # A negative budget can never be met: the loop below would keep
        # trimming past empty and fail with an IndexError after partially
        # mutating the caller's lists.
        raise ValueError(
            f"max_num_tokens must be non-negative, got {max_num_tokens}"
        )
    if len_a + len_b <= max_num_tokens:
        return False
    while len_a + len_b > max_num_tokens:
        # Trim the longer segment; on a tie, trim tokens_b.
        if len_a > len_b:
            len_a -= 1
            tokens = tokens_a
        else:
            len_b -= 1
            tokens = tokens_b
        # Fair coin flip between dropping from the front or the back, so
        # truncation does not always remove the same end of the sequence.
        if np_rng.random() < 0.5:
            del tokens[0]
        else:
            tokens.pop()
    return True