Skip to content

Commit

Permalink
Merge branch 'dev' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
martineberlein authored Mar 15, 2024
2 parents 8318484 + 62ed920 commit e16529d
Show file tree
Hide file tree
Showing 19 changed files with 1,171 additions and 252 deletions.
10 changes: 1 addition & 9 deletions evaluation/evaluate_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,7 @@
from avicenna.feature_extractor import DecisionTreeRelevanceLearner
from avicenna_formalizations.calculator import grammar, oracle, initial_inputs
from avicenna.evaluation_setup import EvaluationSubject


def eval_config() -> Dict[str, Any]:
return {
"grammar": grammar,
"oracle": oracle,
"initial_inputs": initial_inputs,
"feature_learner": DecisionTreeRelevanceLearner(grammar),
}
from avicenna.generator import ISLaSolverGenerator


class CalculatorSubject(EvaluationSubject):
Expand Down
53 changes: 42 additions & 11 deletions evaluation/evaluate_debugging_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,63 @@
from isla.language import ISLaUnparser

from debugging_benchmark.calculator.calculator import CalculatorBenchmarkRepository
from debugging_benchmark.student_assignments import MiddleAssignmentBenchmarkRepository
from debugging_benchmark.student_assignments import (
MiddleAssignmentBenchmarkRepository,
GCDStudentAssignmentBenchmarkRepository,
NPrStudentAssignmentBenchmarkRepository
)
from debugging_framework.benchmark import BenchmarkRepository, BenchmarkProgram

from avicenna import Avicenna
import avicenna.pattern_learner as pattern_learner


patterns = [
"""exists <?NONTERMINAL> elem_1 in start:
exists <?NONTERMINAL> elem_2 in start:
(= (str.to.int elem_1) (str.to.int elem_2))""",
"""
exists <?NONTERMINAL> elem in start:
(>= (str.to.int elem) (str.to.int <?STRING>))
"""
]

def main():

repos: List[BenchmarkRepository] = [
CalculatorBenchmarkRepository()
]
def main():
repos: List[BenchmarkRepository] = [MiddleAssignmentBenchmarkRepository()]

subjects: List[BenchmarkProgram] = []
for repo in repos:
subjects_ = repo.build()
subjects += subjects_

print(f"Number of subjects: {len(subjects)}")

for subject in subjects:
param = subject.to_dict()

avicenna = Avicenna(**param)
diagnosis = avicenna.explain()
print(f"Final Diagnosis for {subject}")
print(ISLaUnparser(diagnosis[0]).unparse())
param.update(
{
"top_n_relevant_features": 3,
"max_iterations": 10,
"log": True,
#"pattern_learner": pattern_learner.AviIslearn,
#"patterns": patterns,
}
)
try:
avicenna = Avicenna(**param)
diagnosis = avicenna.explain()
if diagnosis:
print(f"Final Diagnosis for {subject}")
print(ISLaUnparser(diagnosis[0]).unparse())
print(
f"Avicenna calculated a precision: {diagnosis[1] * 100:.2f}% and recall {diagnosis[2] * 100:.2f}%"
)
else:
print(f"No diagnosis has been learned for {subject}!")
except Exception as e:
print(f"Could not learn diagnosis for {subject}")


if __name__ == "__main__":
main()
main()
2 changes: 1 addition & 1 deletion evaluation/evaluate_middle.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def get_evaluation_config(self):
"grammar": self.grammar,
"oracle": self.oracle,
"initial_inputs": self.initial_inputs,
"top_n_relevant_features": 4,
"top_n_relevant_features": 3,
}
)
return param
Expand Down
84 changes: 38 additions & 46 deletions notebooks/calculator.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,7 @@
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n"
]
}
],
"outputs": [],
"source": [
"from avicenna.avicenna import OracleResult"
]
Expand Down Expand Up @@ -495,15 +487,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
"sqrt(-79.40) FAILING\n",
"sqrt(-9) FAILING\n",
"sqrt(-3815.376) FAILING\n",
"sqrt(-52.11) FAILING\n",
"sqrt(-40.0) FAILING\n",
"sqrt(-12.4) FAILING\n",
"sqrt(-2.0) FAILING\n",
"sqrt(-6) FAILING\n",
"sqrt(-93.92) FAILING\n",
"sqrt(-5) FAILING\n",
"sqrt(-268057.4) FAILING\n",
"sqrt(-616) FAILING\n",
"sqrt(-12) FAILING\n",
"sqrt(-70.0) FAILING\n",
"sqrt(-32.4) FAILING\n",
"sqrt(-4.0) FAILING\n",
"sqrt(-8) FAILING\n",
"sqrt(-6) FAILING\n",
"sqrt(-94.0) FAILING\n",
"sqrt(-61.7) FAILING\n",
"sqrt(-3.2) FAILING\n",
Expand All @@ -514,15 +507,15 @@
"sqrt(-43.37) FAILING\n",
"sqrt(-3.1819496) FAILING\n",
"sqrt(-887) FAILING\n",
"sqrt(-94) FAILING\n",
"sqrt(-516.8) FAILING\n",
"sqrt(-7) FAILING\n",
"sqrt(-69.02) FAILING\n",
"sqrt(-94537.4) FAILING\n",
"sqrt(-266) FAILING\n",
"sqrt(-439.71361) FAILING\n",
"sqrt(-81.3) FAILING\n",
"sqrt(-1419.902) FAILING\n",
"sqrt(-7.0) FAILING\n",
"sqrt(-2943) FAILING\n",
"sqrt(-951.6) FAILING\n",
"sqrt(-6827) FAILING\n",
"sqrt(-84540) FAILING\n",
"sqrt(-171361.1) FAILING\n",
"sqrt(-4) FAILING\n",
"sqrt(-3) FAILING\n",
"sqrt(-5419.902) FAILING\n",
"sqrt(-3) FAILING\n",
"sqrt(-87427) FAILING\n",
"sqrt(-101) FAILING\n",
Expand All @@ -534,27 +527,26 @@
"sqrt(-4) FAILING\n",
"sqrt(-798) FAILING\n",
"sqrt(-4.0) FAILING\n",
"sqrt(-3) FAILING\n",
"sqrt(-730.72) FAILING\n",
"sqrt(-132.7) FAILING\n",
"sqrt(-7) FAILING\n",
"sqrt(-41.09) FAILING\n",
"sqrt(-35.8) FAILING\n",
"sqrt(-8641359.4) FAILING\n",
"sqrt(-6) FAILING\n",
"sqrt(-9.88) FAILING\n",
"sqrt(-2) FAILING\n",
"sqrt(-503.73909) FAILING\n",
"sqrt(-7.26) FAILING\n",
"sqrt(-8.7) FAILING\n",
"sqrt(-6.7) FAILING\n",
"sqrt(-46) FAILING\n",
"sqrt(-8) FAILING\n",
"sqrt(-256.8) FAILING\n",
"sqrt(-14.1) FAILING\n",
"sqrt(-69) FAILING\n",
"sqrt(-5) FAILING\n",
"sqrt(-4.7544) FAILING\n",
"sqrt(-9.4810313) FAILING\n",
"sqrt(-36) FAILING\n",
"sqrt(-7.2) FAILING\n",
"sqrt(-355) FAILING\n",
"sqrt(-83.6) FAILING\n",
"sqrt(-10.36) FAILING\n",
"sqrt(-61) FAILING\n",
"sqrt(-93) FAILING\n",
"sqrt(-5.7) FAILING\n",
"sqrt(-13.6) FAILING\n",
"sqrt(-866) FAILING\n",
"sqrt(-1.60) FAILING\n",
"sqrt(-8) FAILING\n"
"sqrt(-9.1315) FAILING\n",
"sqrt(-6.894) FAILING\n",
"sqrt(-8.2) FAILING\n",
"sqrt(-86.12) FAILING\n",
"sqrt(-149) FAILING\n",
"sqrt(-1.67) FAILING\n"
]
}
],
Expand Down
Loading

0 comments on commit e16529d

Please sign in to comment.