Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 39 additions & 31 deletions popper/tester.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import os
import time
from typing import AnyStr, Any, Dict

import pkg_resources
from janus_swi import query_once, consult
from functools import cache
from contextlib import contextmanager
from . util import order_prog, prog_is_recursive, rule_is_recursive, calc_rule_size, calc_prog_size, prog_hash, format_rule, format_literal, Literal
from .util import order_prog, prog_is_recursive, rule_is_recursive, calc_rule_size, calc_prog_size, prog_hash, \
format_rule, format_literal, Literal, Settings
from bitarray import bitarray, frozenbitarray
from bitarray.util import ones
from collections import defaultdict
Expand All @@ -19,6 +22,11 @@ def bool_query(query):

class Tester():

settings: Settings
cached_pos_covered: Dict[int, frozenbitarray]
neg_fact_str: str
neg_literal_set: frozenset

def __init__(self, settings):
self.settings = settings

Expand Down Expand Up @@ -59,7 +67,7 @@ def __init__(self, settings):
self.pos_examples_ = ones(self.num_pos)

self.cached_pos_covered = {}
self.cached_inconsistent = {}
# self.cached_inconsistent = {} -- never set or referenced.

if self.settings.recursion_enabled:
query_once(f'assert(timeout({self.settings.eval_timeout})), fail')
Expand Down Expand Up @@ -374,35 +382,35 @@ def has_redundant_literal(self, prog):
# print(q, False)
return False

# # WE ASSUME THAT THERE IS A REUNDANT RULE
def find_redundant_rule_(self, prog):
prog_ = []
for i, (head, body) in enumerate(prog):
c = f"{i}-[{','.join(('not_'+ format_literal(head),) + tuple(format_literal(lit) for lit in body))}]"
prog_.append(c)
prog_ = f"[{','.join(prog_)}]"
prog_ = janus_format_rule(prog_)
q = f'find_redundant_rule({prog_}, K1, K2)'
res = query_once(q)
k1 = res['K1']
k2 = res['K2']
return prog[k1], prog[k2]

def find_redundant_rules(self, prog):
# assert(False)
# AC: if the overhead of this call becomes too high, such as when learning programs with lots of clauses, we can improve it by not comparing already compared clauses
base = []
step = []
for rule in prog:
if rule_is_recursive(rule):
step.append(rule)
else:
base.append(rule)
if len(base) > 1 and self.has_redundant_rule(base):
return self.find_redundant_rule_(base)
if len(step) > 1 and self.has_redundant_rule(step):
return self.find_redundant_rule_(step)
return None
# # WE ASSUME THAT THERE IS A REDUNDANT RULE
# def find_redundant_rule_(self, prog):
# prog_ = []
# for i, (head, body) in enumerate(prog):
# c = f"{i}-[{','.join(('not_'+ format_literal(head),) + tuple(format_literal(lit) for lit in body))}]"
# prog_.append(c)
# prog_ = f"[{','.join(prog_)}]"
# prog_ = janus_format_rule(prog_)
# q = f'find_redundant_rule({prog_}, K1, K2)'
# res = query_once(q)
# k1 = res['K1']
# k2 = res['K2']
# return prog[k1], prog[k2]
#
# def find_redundant_rules(self, prog):
# # assert(False)
# # AC: if the overhead of this call becomes too high, such as when learning programs with lots of clauses, we can improve it by not comparing already compared clauses
# base = []
# step = []
# for rule in prog:
# if rule_is_recursive(rule):
# step.append(rule)
# else:
# base.append(rule)
# if len(base) > 1 and self.has_redundant_rule(base):
# return self.find_redundant_rule_(base)
# if len(step) > 1 and self.has_redundant_rule(step):
# return self.find_redundant_rule_(step)
# return None

def find_pointless_relations(self):
settings = self.settings
Expand Down