| import json | |
| import logging | |
| from qa_generator_pipeline import QAGeneratorPipeline | |
# Logger named after this module; the handler functions report their progress through it.
logger = logging.getLogger(name=__name__)

# The single media type this handler accepts for request bodies and
# produces for responses.
JSON_CONTENT_TYPE = 'application/json'
def model_fn(model_dir):
    """Build the ``QAGeneratorPipeline`` used to serve predictions.

    Args:
        model_dir: Directory containing the model artifacts; forwarded
            unchanged to ``QAGeneratorPipeline(model_dir=...)``.

    Returns:
        The constructed ``QAGeneratorPipeline`` instance.
    """
    # Use the module logger (as predict_fn already partly does) rather than
    # the root logger, and let logging format the message lazily.
    logger.info('[### model_fn ###] Loading model from %s', model_dir)
    # NOTE(review): use_cuda is hard-coded to True; confirm how
    # QAGeneratorPipeline behaves on CPU-only hosts.
    return QAGeneratorPipeline(model_dir=model_dir, use_cuda=True)
def predict_fn(input_data, model):
    """Run the loaded model on a deserialized request payload.

    Args:
        input_data: Deserialized request payload (in this handler, the value
            produced by ``input_fn``).
        model: Callable loaded by ``model_fn``; invoked as ``model(input_data)``.

    Returns:
        Whatever ``model(input_data)`` returns, unchanged.
    """
    logger.info('[### predict_fn ###] Entering predict_fn() method')
    logger.info('input text: %s', input_data)
    prediction = model(input_data)
    # Log the model's output here, not the request, so the two log lines
    # actually show input and result side by side.
    logger.info('prediction: %s', prediction)
    return prediction
def input_fn(serialized_input_data, content_type=JSON_CONTENT_TYPE):
    """Deserialize a JSON request body.

    Args:
        serialized_input_data: Raw request body (``str``, ``bytes`` or
            ``bytearray``), passed straight to ``json.loads``.
        content_type: Media type of the request. Parameters such as
            ``; charset=utf-8`` and letter case are ignored when matching
            against ``application/json``.

    Returns:
        The Python object decoded from the JSON body.

    Raises:
        ValueError: If ``content_type`` is not a JSON media type, or if the
            body is not valid JSON (``json.JSONDecodeError`` subclasses
            ``ValueError``).
    """
    logger.info('[### input_fn ###] Entering input_fn() method')
    logger.info('[### input_fn ###] request_content_type: %s', content_type)
    logger.info('[### input_fn ###] request_body: %s', type(serialized_input_data))

    # Compare only the bare media type: clients commonly append parameters
    # (e.g. "application/json; charset=utf-8") and media types are
    # case-insensitive. `or ''` keeps a None content type from crashing here.
    media_type = (content_type or '').split(';', 1)[0].strip().lower()
    if media_type != JSON_CONTENT_TYPE:
        # Fail loudly, matching output_fn, instead of returning None and
        # handing None to the model.
        raise ValueError('Unsupported Content Type: {}'.format(content_type))
    return json.loads(serialized_input_data)
def output_fn(prediction_output, accept=JSON_CONTENT_TYPE):
    """Serialize a prediction as JSON for the response.

    Args:
        prediction_output: The prediction to serialize; must be
            JSON-serializable.
        accept: Media type requested by the client; only
            ``application/json`` is supported.

    Returns:
        A ``(body, content_type)`` tuple: the JSON-encoded prediction and
        the ``accept`` value echoed back.

    Raises:
        ValueError: If ``accept`` is not ``application/json``. This is a
            subclass of ``Exception``, so handlers written for the former
            bare ``Exception`` still catch it.
        TypeError: If ``prediction_output`` is not JSON-serializable
            (raised by ``json.dumps``).
    """
    logger.info('[### output_fn ###] Entering output_fn() method')
    logger.info('[### output_fn ###] prediction: %s', prediction_output)
    if accept == JSON_CONTENT_TYPE:
        return json.dumps(prediction_output), accept
    raise ValueError('Unsupported Content Type: {}'.format(accept))