Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into georgy/simplify-impli…
Browse files Browse the repository at this point in the history
…cation
  • Loading branch information
geo2a committed Jul 25, 2023
2 parents 42c9c44 + a7678ce commit 1d2e66d
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 84 deletions.
28 changes: 14 additions & 14 deletions library/Booster/JsonRpc.hs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ respond stateVar =
Right (TermAndPredicate pat) -> do
Log.logInfoNS "booster" "Simplifying a pattern"
ApplyEquations.evaluatePattern doTracing def mLlvmLibrary pat >>= \case
Right (newPattern, patternTraces) -> do
(Right newPattern, patternTraces) -> do
let (t, p) = externalisePattern newPattern
tSort = externaliseSort (sortOfPattern newPattern)
result = maybe t (KoreJson.KJAnd tSort t) p
Expand All @@ -155,22 +155,22 @@ respond stateVar =
{ state = addHeader result
, logs = mkTraces patternTraces
}
Left (ApplyEquations.SideConditionsFalse _ traces) -> do
(Left ApplyEquations.SideConditionsFalse{}, patternTraces) -> do
let tSort = fromMaybe (error "unknown sort") $ sortOfJson req.state.term
pure . Right . Simplify $
SimplifyResult
{ state = addHeader $ KoreJson.KJBottom tSort
, logs = mkTraces traces
, logs = mkTraces patternTraces
}
Left (ApplyEquations.EquationLoop _traces terms) ->
(Left (ApplyEquations.EquationLoop terms), _traces) ->
pure . Left . backendError RpcError.Aborted $ map externaliseTerm terms -- FIXME
Left other ->
(Left other, _traces) ->
pure . Left . backendError RpcError.Aborted $ show other -- FIXME
-- predicate only
Right (APredicate predicate) -> do
Log.logInfoNS "booster" "Simplifying a predicate"
ApplyEquations.simplifyConstraint doTracing def mLlvmLibrary predicate >>= \case
Right (newPred, traces) -> do
(Right newPred, traces) -> do
let predicateSort =
fromMaybe (error "not a predicate") $
sortOfJson req.state.term
Expand All @@ -180,7 +180,7 @@ respond stateVar =
{ state = addHeader result
, logs = mkTraces traces
}
Left something ->
(Left something, _traces) ->
pure . Left . backendError RpcError.Aborted $ show something -- FIXME
SimplifyImplies req -> withContext req._module $ \(def, mLlvmLibrary) -> do
let logTraces =
Expand Down Expand Up @@ -525,13 +525,12 @@ mkLogRewriteTrace
}
, origin = Booster
}
RewriteSimplified (Right equationTraces)
RewriteSimplified equationTraces Nothing
| fromMaybe False logSuccessfulSimplifications || fromMaybe False logFailedSimplifications ->
mapM (mkLogEquationTrace equationLogOpts) equationTraces
RewriteSimplified (Left failure)
| fromMaybe False logFailedSimplifications -> Just $
singleton $
case failure of
RewriteSimplified equationTraces (Just failure)
| fromMaybe False logFailedSimplifications -> do
let final = singleton $ case failure of
ApplyEquations.IndexIsNone trm ->
Simplification
{ originalTerm = Just $ execStateToKoreJson $ toExecState $ Pattern trm []
Expand All @@ -546,7 +545,7 @@ mkLogRewriteTrace
, origin = Booster
, result = Failure{reason = "Reached iteration depth limit " <> pack (show i), _ruleId = Nothing}
}
ApplyEquations.EquationLoop _ _ ->
ApplyEquations.EquationLoop _ ->
Simplification
{ originalTerm = Nothing
, originalTermIndex = Nothing
Expand All @@ -560,11 +559,12 @@ mkLogRewriteTrace
, origin = Booster
, result = Failure{reason = "Internal error: " <> err, _ruleId = Nothing}
}
ApplyEquations.SideConditionsFalse _predicates _traces ->
ApplyEquations.SideConditionsFalse _predicates ->
Simplification
{ originalTerm = Nothing
, originalTermIndex = Nothing
, origin = Booster
, result = Failure{reason = "Side conditions false", _ruleId = Nothing}
}
(<> final) <$> mapM (mkLogEquationTrace equationLogOpts) equationTraces
_ -> Nothing
56 changes: 28 additions & 28 deletions library/Booster/Pattern/ApplyEquations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,18 @@ import Booster.Pattern.Util
import Booster.Prettyprinter (renderDefault)

newtype EquationT io a
= EquationT (StateT EquationState (ReaderT EquationConfig (ExceptT EquationFailure io)) a)
= EquationT (ReaderT EquationConfig (ExceptT EquationFailure (StateT EquationState io)) a)
-- ~ EquationConfig -> EquationState -> io (Either EquationFailure a, EquationState)
deriving newtype (Functor, Applicative, Monad, MonadIO, MonadLogger, MonadLoggerIO)

throw :: MonadLoggerIO io => EquationFailure -> EquationT io a
throw = EquationT . lift . lift . throwE
throw = EquationT . lift . throwE

data EquationFailure
= IndexIsNone Term
| TooManyIterations Int Term Term
| EquationLoop [EquationTrace] [Term]
| SideConditionsFalse [Predicate] [EquationTrace]
| EquationLoop [Term]
| SideConditionsFalse [Predicate]
| InternalError Text
deriving stock (Eq, Show)

Expand Down Expand Up @@ -161,32 +162,32 @@ startState =
EquationState{termStack = [], changed = False, predicates = mempty, trace = mempty}

getState :: MonadLoggerIO io => EquationT io EquationState
getState = EquationT get
getState = EquationT (lift $ lift get)

getConfig :: MonadLoggerIO io => EquationT io EquationConfig
getConfig = EquationT (lift ask)
getConfig = EquationT ask

countSteps :: MonadLoggerIO io => EquationT io Int
countSteps = length . (.termStack) <$> getState

pushTerm :: MonadLoggerIO io => Term -> EquationT io ()
pushTerm t = EquationT . modify $ \s -> s{termStack = t : s.termStack}
pushTerm t = EquationT . lift . lift . modify $ \s -> s{termStack = t : s.termStack}

pushConstraints :: MonadLoggerIO io => [Predicate] -> EquationT io ()
pushConstraints ps = EquationT . modify $ \s -> s{predicates = s.predicates <> Set.fromList ps}
pushConstraints ps = EquationT . lift . lift . modify $ \s -> s{predicates = s.predicates <> Set.fromList ps}

setChanged, resetChanged :: MonadLoggerIO io => EquationT io ()
setChanged = EquationT . modify $ \s -> s{changed = True}
resetChanged = EquationT . modify $ \s -> s{changed = False}
setChanged = EquationT . lift . lift . modify $ \s -> s{changed = True}
resetChanged = EquationT . lift . lift . modify $ \s -> s{changed = False}

getChanged :: MonadLoggerIO io => EquationT io Bool
getChanged = EquationT $ gets (.changed)
getChanged = EquationT $ lift $ lift $ gets (.changed)

checkForLoop :: MonadLoggerIO io => Term -> EquationT io ()
checkForLoop t = do
EquationState{termStack, trace} <- getState
EquationState{termStack} <- getState
whenJust (elemIndex t termStack) $ \i -> do
throw (EquationLoop (toList trace) . reverse $ t : take (i + 1) termStack)
throw (EquationLoop $ reverse $ t : take (i + 1) termStack)

data Direction = TopDown | BottomUp
deriving stock (Eq, Show)
Expand All @@ -200,14 +201,13 @@ runEquationT ::
KoreDefinition ->
Maybe LLVM.API ->
EquationT io a ->
io (Either EquationFailure (a, [EquationTrace]))
io (Either EquationFailure a, [EquationTrace])
runEquationT doTracing definition llvmApi (EquationT m) = do
endState <-
runExceptT
. flip runReaderT EquationConfig{definition, llvmApi, doTracing}
. runStateT m
$ startState
pure (fmap (toList . trace) <$> endState)
(res, endState) <-
flip runStateT startState $
runExceptT $
runReaderT m EquationConfig{definition, llvmApi, doTracing}
pure (res, toList $ trace endState)

iterateEquations ::
MonadLoggerIO io =>
Expand Down Expand Up @@ -246,7 +246,7 @@ evaluateTerm ::
KoreDefinition ->
Maybe LLVM.API ->
Term ->
io (Either EquationFailure (Term, [EquationTrace]))
io (Either EquationFailure Term, [EquationTrace])
evaluateTerm doTracing direction def llvmApi =
runEquationT doTracing def llvmApi
. evaluateTerm' direction
Expand All @@ -268,7 +268,7 @@ evaluatePattern ::
KoreDefinition ->
Maybe LLVM.API ->
Pattern ->
io (Either EquationFailure (Pattern, [EquationTrace]))
io (Either EquationFailure Pattern, [EquationTrace])
evaluatePattern doTracing def mLlvmLibrary =
runEquationT doTracing def mLlvmLibrary . evaluatePattern'

Expand All @@ -291,7 +291,7 @@ evaluatePattern' Pattern{term, constraints} = do
simplifyAssumedPredicate p = do
allPs <- predicates <$> getState
let otherPs = Set.delete p allPs
EquationT $ modify $ \s -> s{predicates = otherPs}
EquationT $ lift $ lift $ modify $ \s -> s{predicates = otherPs}
newP <- simplifyConstraint' p
pushConstraints [newP]

Expand Down Expand Up @@ -419,7 +419,7 @@ handleFunctionEquation success continue abort = \case
IndeterminateMatch -> abort
IndeterminateCondition -> abort
ConditionFalse -> continue
EnsuresFalse ps -> getState >>= throw . SideConditionsFalse ps . toList . trace
EnsuresFalse ps -> throw $ SideConditionsFalse ps
RuleNotPreservingDefinedness -> abort
MatchConstraintViolated{} -> continue

Expand All @@ -430,7 +430,7 @@ handleSimplificationEquation success continue _abort = \case
IndeterminateMatch -> continue
IndeterminateCondition -> continue
ConditionFalse -> continue
EnsuresFalse ps -> getState >>= throw . SideConditionsFalse ps . toList . trace
EnsuresFalse ps -> throw $ SideConditionsFalse ps
RuleNotPreservingDefinedness -> continue
MatchConstraintViolated{} -> continue

Expand Down Expand Up @@ -489,7 +489,7 @@ traceRuleApplication t loc lbl uid res = do
logOther (LevelOther "Simplify") (pack . renderDefault . pretty $ newTraceItem)
config <- getConfig
when (config.doTracing) $
EquationT . modify $
EquationT . lift . lift . modify $
\s -> s{trace = s.trace :|> newTraceItem}

applyEquation ::
Expand Down Expand Up @@ -629,7 +629,7 @@ simplifyConstraint ::
KoreDefinition ->
Maybe LLVM.API ->
Predicate ->
io (Either EquationFailure (Predicate, [EquationTrace]))
io (Either EquationFailure Predicate, [EquationTrace])
simplifyConstraint doTracing def mbApi p =
runEquationT doTracing def mbApi $ simplifyConstraint' p

Expand Down Expand Up @@ -667,7 +667,7 @@ simplifyConstraint' = \case
evalBool t = do
prior <- getState -- save prior state so we can revert
result <- iterateEquations 100 TopDown PreferFunctions t
EquationT $ put prior
EquationT $ lift $ lift $ put prior
pure result

--------------------------------------------------------------------
Expand Down
71 changes: 34 additions & 37 deletions library/Booster/Pattern/Rewrite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ import Booster.Pattern.ApplyEquations (
EquationFailure (..),
EquationTrace,
evaluatePattern,
isMatchFailure,
isSuccess,
simplifyConstraint,
)
import Booster.Pattern.Base
Expand Down Expand Up @@ -228,11 +226,11 @@ applyRule pat rule = runMaybeT $ do
MaybeT (RewriteT io (RewriteFailed k)) (Maybe a)
checkConstraint onUnclear p = do
RewriteConfig{definition, llvmApi, doTracing} <- lift $ RewriteT ask
simplified <- simplifyConstraint doTracing definition llvmApi p
(simplified, _traces) <- simplifyConstraint doTracing definition llvmApi p
case simplified of
Right (Bottom, _) -> fail "Rule condition was False"
Right (Top, _) -> pure Nothing
Right (other, _) -> pure $ Just $ onUnclear other
Right Bottom -> fail "Rule condition was False"
Right Top -> pure Nothing
Right other -> pure $ Just $ onUnclear other
Left _ -> pure $ Just $ onUnclear p

{- | Reason why a rewrite did not produce a result. Contains additional
Expand Down Expand Up @@ -337,7 +335,7 @@ data RewriteTrace pat
| -- | attempted rewrite failed
RewriteStepFailed (RewriteFailed "Rewrite")
| -- | Applied simplification to the pattern
RewriteSimplified (Either EquationFailure [EquationTrace])
RewriteSimplified [EquationTrace] (Maybe EquationFailure)
deriving stock (Eq, Show)
deriving (Functor, Foldable, Traversable)

Expand Down Expand Up @@ -440,36 +438,35 @@ performRewrite doTracing def mLlvmLibrary mbMaxDepth cutLabels terminalLabels pa

simplifyP :: Pattern -> StateT (Natural, Seq (RewriteTrace Pattern)) io (Maybe Pattern)
simplifyP p =
evaluatePattern doTracing def mLlvmLibrary p >>= \case
Right (newPattern, traces) -> do
logTraces $ filter (not . isMatchFailure) traces
rewriteTrace $ RewriteSimplified $ Right traces
pure $ Just newPattern
Left r@(SideConditionsFalse _ps traces) -> do
logSimplify "Side conditions were found to be false, pruning"
logTraces $ filter (not . isMatchFailure) traces
rewriteTrace $ RewriteSimplified $ Left r
pure Nothing
-- NB any errors here might be caused by simplifying one
-- of the constraints, so we cannot use partial results
-- and have to return the original on errors.
Left r@(TooManyIterations n _start _result) -> do
logWarn $ "Simplification unable to finish in " <> prettyText n <> " steps."
-- could output term before and after at debug or custom log level
rewriteTrace $ RewriteSimplified $ Left r
pure $ Just p
Left r@(EquationLoop traces (t : ts)) -> do
let termDiffs = zipWith (curry mkDiffTerms) (t : ts) ts
logError "Equation evaluation loop"
logTraces $ filter isSuccess traces
logSimplify $
"produced the evaluation loop: " <> Text.unlines (map (prettyText . fst) termDiffs)
rewriteTrace $ RewriteSimplified $ Left r
pure $ Just p
Left other -> do
logError . pack $ "Simplification error during rewrite: " <> show other
rewriteTrace $ RewriteSimplified $ Left other
pure $ Just p
evaluatePattern doTracing def mLlvmLibrary p >>= \(res, traces) -> do
logTraces traces
case res of
Right newPattern -> do
rewriteTrace $ RewriteSimplified traces Nothing
pure $ Just newPattern
Left r@(SideConditionsFalse _ps) -> do
logSimplify "Side conditions were found to be false, pruning"
rewriteTrace $ RewriteSimplified traces (Just r)
pure Nothing
-- NB any errors here might be caused by simplifying one
-- of the constraints, so we cannot use partial results
-- and have to return the original on errors.
Left r@(TooManyIterations n _start _result) -> do
logWarn $ "Simplification unable to finish in " <> prettyText n <> " steps."
-- could output term before and after at debug or custom log level
rewriteTrace $ RewriteSimplified traces (Just r)
pure $ Just p
Left r@(EquationLoop (t : ts)) -> do
let termDiffs = zipWith (curry mkDiffTerms) (t : ts) ts
logError "Equation evaluation loop"
logSimplify $
"produced the evaluation loop: " <> Text.unlines (map (prettyText . fst) termDiffs)
rewriteTrace $ RewriteSimplified traces (Just r)
pure $ Just p
Left other -> do
logError . pack $ "Simplification error during rewrite: " <> show other
rewriteTrace $ RewriteSimplified traces (Just other)
pure $ Just p

-- Results may change when simplification prunes a false side
-- condition, otherwise this would mainly be fmap simplifyP
Expand Down
11 changes: 6 additions & 5 deletions unit-tests/Test/Booster/Pattern/ApplyEquations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ test_evaluateFunction =
eval BottomUp subj @?= Right result
]
where
eval direction = fmap fst . unsafePerformIO . runNoLoggingT . evaluateTerm False direction funDef Nothing
eval direction = unsafePerformIO . runNoLoggingT . (fst <$>) . evaluateTerm False direction funDef Nothing

isTooManyIterations (Left (TooManyIterations _n _ _)) = pure ()
isTooManyIterations (Left err) = assertFailure $ "Unexpected error " <> show err
Expand All @@ -114,7 +114,7 @@ test_simplify =
simpl BottomUp subj @?= Right result
]
where
simpl direction = fmap fst . unsafePerformIO . runNoLoggingT . evaluateTerm False direction simplDef Nothing
simpl direction = unsafePerformIO . runNoLoggingT . (fst <$>) . evaluateTerm False direction simplDef Nothing
a = var "A" someSort

test_simplifyPattern :: TestTree
Expand Down Expand Up @@ -143,7 +143,7 @@ test_simplifyPattern =
simpl subj @?= Right result
]
where
simpl = fmap fst . unsafePerformIO . runNoLoggingT . evaluatePattern False simplDef Nothing
simpl = unsafePerformIO . runNoLoggingT . (fst <$>) . evaluatePattern False simplDef Nothing
a = var "A" someSort

test_errors :: TestTree
Expand All @@ -156,10 +156,11 @@ test_errors =
subj = f $ app con1 [a]
loopTerms =
[f $ app con1 [a], f $ app con2 [a], f $ app con3 [a, a], f $ app con1 [a]]
isLoop loopTerms . unsafePerformIO . runNoLoggingT $ evaluateTerm False TopDown loopDef Nothing subj
isLoop loopTerms . unsafePerformIO . runNoLoggingT $
fst <$> evaluateTerm False TopDown loopDef Nothing subj
]
where
isLoop ts (Left (EquationLoop _ ts')) = ts @?= ts'
isLoop ts (Left (EquationLoop ts')) = ts @?= ts'
isLoop _ (Left err) = assertFailure $ "Unexpected error " <> show err
isLoop _ (Right r) = assertFailure $ "Unexpected result " <> show r

Expand Down

0 comments on commit 1d2e66d

Please sign in to comment.