| """ | |
| scikit-learn Model Wrapper | |
| -------------------------- | |
| """ | |
| import pandas as pd | |
| from .model_wrapper import ModelWrapper | |
class SklearnModelWrapper(ModelWrapper):
    """Wrap a scikit-learn model and its tokenizer behind a callable.

    The tokenizer must implement ``transform`` and expose its vocabulary via
    ``get_feature_names_out`` (or the legacy ``get_feature_names``); the model
    must implement ``predict_proba``.

    May need to be extended and modified for different types of tokenizers.
    """

    def __init__(self, model, tokenizer):
        """Store the model and the tokenizer used to encode its input."""
        self.model = model
        self.tokenizer = tokenizer

    def __call__(self, text_input_list, batch_size=None):
        """Return ``model.predict_proba`` for the given texts.

        Args:
            text_input_list: Texts passed as-is to ``tokenizer.transform``.
            batch_size: Accepted but not used by this implementation; the
                whole input is encoded and scored in a single call.

        Returns:
            Whatever ``self.model.predict_proba`` returns for the encoded
            input.
        """
        encoded_text_matrix = self.tokenizer.transform(text_input_list)
        # Only densify when the tokenizer produced an object exposing
        # ``toarray`` (e.g. a sparse matrix); a dense result is used as-is
        # instead of failing with AttributeError.
        if hasattr(encoded_text_matrix, "toarray"):
            encoded_text_matrix = encoded_text_matrix.toarray()
        tokenized_text_df = pd.DataFrame(
            encoded_text_matrix, columns=self._feature_names()
        )
        return self.model.predict_proba(tokenized_text_df)

    def _feature_names(self):
        """Return the tokenizer's feature names across scikit-learn versions."""
        # ``get_feature_names`` was deprecated in scikit-learn 1.0 and removed
        # in 1.2, so prefer its replacement and keep the old name only as a
        # fallback for older releases and custom tokenizers.
        get_names = getattr(self.tokenizer, "get_feature_names_out", None)
        if get_names is None:
            get_names = self.tokenizer.get_feature_names
        return get_names()

    def get_grad(self, text_input):
        """Not implemented for this wrapper; always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()