| import sys | |
| from importlib import import_module | |
| from datasets import load_dataset | |
| import argparse | |
| def main(): | |
| if len(sys.argv) < 3: | |
| raise Exception( | |
| 'args len < 3, example: fengshen_pipeline text_classification predict xxxxx') | |
| pipeline_name = sys.argv[1] | |
| method = sys.argv[2] | |
| pipeline_class = getattr(import_module('fengshen.pipelines.' + pipeline_name), 'Pipeline') | |
| total_parser = argparse.ArgumentParser("FengShen Pipeline") | |
| total_parser.add_argument('--model', default='', type=str) | |
| total_parser.add_argument('--datasets', default='', type=str) | |
| total_parser.add_argument('--text', default='', type=str) | |
| total_parser = pipeline_class.add_pipeline_specific_args(total_parser) | |
| args = total_parser.parse_args(args=sys.argv[3:]) | |
| pipeline = pipeline_class(args=args, model=args.model) | |
| if method == 'predict': | |
| print(pipeline(args.text)) | |
| elif method == 'train': | |
| datasets = load_dataset(args.datasets) | |
| pipeline.train(datasets) | |
| else: | |
| raise Exception( | |
| 'cmd not support, now only support {predict, train}') | |
| if __name__ == '__main__': | |
| main() | |