Spaces:
Build error
Build error
| import os | |
| import random | |
| import re | |
| import shutil | |
| from pyke import knowledge_engine | |
class PykeProgram:
    """Turn a textual logic program into Pyke fact/rule files and query it.

    The program text is expected to contain the sections ``Predicates:``,
    ``Facts:``, ``Rules:`` and ``Query:``. Each section is split into lines
    and anything after ``:::`` on a line is discarded. The parsed sections
    are stored on the instance as the attributes ``Predicates``, ``Facts``,
    ``Rules`` and ``Query`` (each a list of strings, or ``None`` when the
    section could not be extracted).

    Attributes:
        logic_program: the raw program text passed to the constructor.
        dataset_name: key used to pick the answer-mapping function.
        cache_dir: directory that receives ``facts.kfb`` and ``rules.krb``.
        flag: ``True`` when both files were written successfully.
        answer_map: maps a dataset name to its answer-mapping method.
    """

    def __init__(
        self, logic_program: str, dataset_name='ProntoQA', workspace_mount_path='./'
    ) -> None:
        """Parse ``logic_program`` and write the fact and rule files.

        Args:
            logic_program: text of the logic program.
            dataset_name: ``'ProntoQA'`` or ``'ProofWriter'``.
            workspace_mount_path: directory under which ``.cache_program``
                is placed.
        """
        self.logic_program = logic_program
        # The parse result is stored first, but the outcome of writing the
        # files below is what finally decides ``flag``.
        self.flag = self.parse_logic_program()
        self.dataset_name = dataset_name
        self.cache_dir = os.path.join(workspace_mount_path, '.cache_program')

        # prepare the files for facts and rules
        try:
            # Without this the open() calls below fail on a fresh workspace
            # and every program would be reported as unparsable.
            os.makedirs(self.cache_dir, exist_ok=True)
            self.create_fact_file(self.Facts)
            self.create_rule_file(self.Rules)
            self.flag = True
        except Exception:
            # Deliberately best-effort: callers inspect ``flag`` and fall
            # back to a random answer instead of handling an exception.
            self.flag = False

        self.answer_map = {
            'ProntoQA': self.answer_map_prontoqa,
            'ProofWriter': self.answer_map_proofwriter,
        }

    def parse_logic_program(self):
        """Split the program into its sections and validate the result.

        Sections are peeled off from the end of the text, so the keywords
        are processed in reverse order of their expected appearance.

        Returns:
            The result of :meth:`validate_program`.
        """
        keywords = ['Query:', 'Rules:', 'Facts:', 'Predicates:']
        program_str = self.logic_program
        for keyword in keywords:
            attr_name = keyword[:-1]  # drop the trailing colon
            try:
                program_str, segment_list = self._parse_segment(program_str, keyword)
                setattr(self, attr_name, segment_list)
            except Exception:
                # Keyword missing or present more than once.
                setattr(self, attr_name, None)
        return self.validate_program()

    def _parse_segment(self, program_str, key_phrase):
        """Cut the section introduced by ``key_phrase`` off ``program_str``.

        Args:
            program_str: program text still to be parsed.
            key_phrase: section keyword, e.g. ``'Facts:'``.

        Returns:
            ``(remaining_text, lines)`` where ``lines`` are the section's
            lines with everything from ``:::`` onwards removed.

        Raises:
            ValueError: if ``key_phrase`` does not occur exactly once
                (the two-way unpacking of ``split`` fails).
        """
        remain_program_str, segment = program_str.split(key_phrase)
        segment_list = [
            line.split(':::')[0].strip() for line in segment.strip().split('\n')
        ]
        return remain_program_str, segment_list

    def validate_program(self):
        """Check the parsed program and, if invalid, try to repair it.

        The program is valid when both ``Rules`` and ``Facts`` were parsed
        and neither starts with an empty line. Otherwise the statements of
        whichever section is available are re-sorted: lines containing
        ``>>>`` become rules, everything else becomes facts.

        Returns:
            ``True`` only when the program was valid as parsed; ``False``
            after a repair attempt or when nothing could be repaired.
        """
        if self.Rules is not None and self.Facts is not None:
            if not self.Rules[0] == '' and not self.Facts[0] == '':
                return True

        # try to fix the program
        statements = self.Facts if self.Facts is not None else self.Rules
        if statements is None:
            return False
        tmp_rules = []
        tmp_facts = []
        for statement in statements:
            if '>>>' in statement:  # this is a rule
                tmp_rules.append(statement)
            else:
                tmp_facts.append(statement)
        self.Rules = tmp_rules
        self.Facts = tmp_facts
        return False

    def create_fact_file(self, facts):
        """Write ``facts`` to ``facts.kfb`` in the cache directory.

        Facts that contain the variable ``$x`` are treated as invalid and
        skipped.
        """
        with open(os.path.join(self.cache_dir, 'facts.kfb'), 'w') as f:
            for fact in facts:
                # check for invalid facts
                if '$x' not in fact:
                    f.write(fact + '\n')

    def create_rule_file(self, rules):
        """Convert ``rules`` to Pyke syntax and write them to ``rules.krb``.

        Rules are numbered from 1 and separated by a blank line.
        """
        pyke_rules = [
            self.parse_forward_rule(idx, rule) for idx, rule in enumerate(rules, start=1)
        ]
        with open(os.path.join(self.cache_dir, 'rules.krb'), 'w') as f:
            f.write('\n\n'.join(pyke_rules))

    def parse_forward_rule(self, f_index, rule):
        """Translate one ``premise >>> conclusion`` rule into Pyke text.

        Example rule:
            Furry($x, True) && Quite($x, True) >>> White($x, True)

        Args:
            f_index: number used in the generated rule name ``fact<N>``.
            rule: rule text; ``&&`` separates multiple premises or
                conclusions.

        Returns:
            The rule as a ``foreach``/``assert`` block whose entries are
            prefixed with ``facts.``.

        Raises:
            ValueError: if ``rule`` does not contain exactly one ``>>>``.
        """
        premise, conclusion = rule.split('>>>')
        # split each side into multiple facts if needed
        premise_list = [p.strip() for p in premise.strip().split('&&')]
        conclusion_list = [c.strip() for c in conclusion.strip().split('&&')]

        # create the Pyke rule
        pyke_rule = f"""fact{f_index}\n\tforeach"""
        for p in premise_list:
            pyke_rule += f"""\n\t\tfacts.{p}"""
        pyke_rule += """\n\tassert"""
        for c in conclusion_list:
            pyke_rule += f"""\n\t\tfacts.{c}"""
        return pyke_rule

    def check_specific_predicate(self, subject_name, predicate_name, engine):
        """Look up the value of ``predicate_name`` for ``subject_name``.

        For example: Is Marvin from Mars?
            Query: FromMars(Marvin, $label)

        The goal is proved first against ``facts`` and then against
        ``rules``; every binding of ``$label`` is collected.

        Args:
            subject_name: first argument of the predicate.
            predicate_name: name of the predicate to prove.
            engine: object exposing ``prove_goal`` as a context manager
                that yields ``(bindings, plan)`` pairs.

        Returns:
            The single result when exactly one was found, ``first and
            second`` when exactly two were found, and ``None`` for zero or
            more than two results.
        """
        results = []
        for kb_name in ('facts', 'rules'):
            goal = f'{kb_name}.{predicate_name}({subject_name}, $label)'
            with engine.prove_goal(goal) as gen:
                for bindings, _plan in gen:
                    results.append(bindings['label'])

        if len(results) == 1:
            return results[0]
        if len(results) == 2:
            return results[0] and results[1]
        # Zero results, or more than two: no answer is reported.
        return None

    def parse_query(self, query):
        """Split a query such as ``Metallic(Wren, False)`` into its parts.

        Args:
            query: text of the form ``Predicate(subject, True|False)``.

        Returns:
            ``(predicate_name, subject, expected_value)`` where
            ``expected_value`` is ``True`` only for the literal ``True``.

        Raises:
            ValueError: if ``query`` does not match the expected form.
        """
        pattern = r'(\w+)\(([^,]+),\s*([^)]+)\)'
        match = re.match(pattern, query)
        if not match:
            raise ValueError(f'Invalid query: {query}')
        function_name = match.group(1)
        # Strip so that stray padding ("True )") is not read as False.
        arg1 = match.group(2).strip()
        arg2 = match.group(3).strip() == 'True'
        return function_name, arg1, arg2

    def execute_program(self):
        """Run the query against the generated fact and rule files.

        Returns:
            ``(answer, '')`` on success, or ``(None, exception)`` when
            anything fails while loading the engine or proving the query.
        """
        # delete the compiled_krb dir
        complied_krb_dir = './models/compiled_krb'
        if os.path.exists(complied_krb_dir):
            print('removing compiled_krb')
            shutil.rmtree(complied_krb_dir)

        try:
            engine = knowledge_engine.engine(self.cache_dir)
            engine.reset()
            engine.activate('rules')
            engine.get_kb('facts')

            # parse the logic query into pyke query
            predicate, subject, value_to_check = self.parse_query(self.Query[0])
            result = self.check_specific_predicate(subject, predicate, engine)
            answer = self.answer_map[self.dataset_name](result, value_to_check)
        except Exception as err:
            return None, err

        return answer, ''

    def answer_mapping(self, answer):
        """Return ``answer`` unchanged."""
        return answer

    def answer_map_prontoqa(self, result, value_to_check):
        """Return ``'A'`` when ``result`` equals ``value_to_check``, else ``'B'``."""
        if result == value_to_check:
            return 'A'
        return 'B'

    def answer_map_proofwriter(self, result, value_to_check):
        """Map a result to ``'A'`` (match), ``'B'`` (mismatch) or ``'C'`` (no result)."""
        if result is None:
            return 'C'
        if result == value_to_check:
            return 'A'
        return 'B'
class LogicInferenceEngine:
    """Run a logic program through ``PykeProgram`` with a random fallback.

    The dataset name is read from the ``DATASET_NAME`` environment variable
    (default ``'ProofWriter'``); the workspace path is fixed to
    ``/workspace``.
    """

    def __init__(self):
        self.dataset_name = os.environ.get('DATASET_NAME', 'ProofWriter')
        self.workspace_mount_path = '/workspace'

    def random_backup(self):
        """Pick a random answer label for the configured dataset.

        Returns:
            One of ``'A'``/``'B'`` for ProntoQA, one of ``'A'``/``'B'``/``'C'``
            for ProofWriter, and ``None`` for any other dataset name.
        """
        labels_by_dataset = {
            'ProntoQA': ['A', 'B'],
            'ProofWriter': ['A', 'B', 'C'],
        }
        candidates = labels_by_dataset.get(self.dataset_name)
        if candidates is None:
            return None
        return random.choice(candidates)

    def safe_execute_program(self, logic_program):
        """Execute ``logic_program`` and never raise on a bad program.

        Args:
            logic_program: text of the logic program to run.

        Returns:
            A triple ``(answer, status, error_message)`` where ``status`` is
            ``'parsing error'``, ``'execution error'`` or ``'success'``. On
            either error the answer comes from :meth:`random_backup`.
        """
        program = PykeProgram(
            logic_program, self.dataset_name, self.workspace_mount_path
        )

        # cannot parse the program
        if not program.flag:
            return self.random_backup(), 'parsing error', ''

        # execute the program
        answer, error_message = program.execute_program()

        # not executable
        if answer is None:
            return self.random_backup(), 'execution error', error_message

        # successfully executed
        return program.answer_mapping(answer), 'success', ''