From a7678cec85365b1d4625f2220751158575637ac6 Mon Sep 17 00:00:00 2001 From: Samuel Balco Date: Tue, 25 Jul 2023 02:37:17 +0100 Subject: [PATCH] Reshuffle the EquationT transformer stack (#240) Fixes https://github.com/runtimeverification/hs-backend-booster/issues/236 by reshuffling the `EquationT` transformer, currently defined as ```Haskell newtype EquationT io a = EquationT (StateT EquationState (ReaderT EquationConfig (ExceptT EquationFailure io)) a) -- ~ EquationState -> EquationConfig -> io (Either EquationFailure (a, EquationState)) ``` into ```Haskell newtype EquationT io a = EquationT (ReaderT EquationConfig (ExceptT EquationFailure (StateT EquationState io)) a) -- ~ EquationConfig -> EquationState -> io (Either EquationFailure a, EquationState) ``` This ensures that we always return the state even after a failed computation, hence we don't need to include the traces as data within the `EquationFailure` --- library/Booster/JsonRpc.hs | 28 ++++---- library/Booster/Pattern/ApplyEquations.hs | 56 +++++++-------- library/Booster/Pattern/Rewrite.hs | 71 +++++++++---------- .../Test/Booster/Pattern/ApplyEquations.hs | 11 +-- 4 files changed, 82 insertions(+), 84 deletions(-) diff --git a/library/Booster/JsonRpc.hs b/library/Booster/JsonRpc.hs index 5d893fc22..59caa3388 100644 --- a/library/Booster/JsonRpc.hs +++ b/library/Booster/JsonRpc.hs @@ -144,7 +144,7 @@ respond stateVar = Right (TermAndPredicate pat) -> do Log.logInfoNS "booster" "Simplifying a pattern" ApplyEquations.evaluatePattern doTracing def mLlvmLibrary pat >>= \case - Right (newPattern, patternTraces) -> do + (Right newPattern, patternTraces) -> do let (t, p) = externalisePattern newPattern tSort = externaliseSort (sortOfPattern newPattern) result = maybe t (KoreJson.KJAnd tSort t) p @@ -153,22 +153,22 @@ respond stateVar = { state = addHeader result , logs = mkTraces patternTraces } - Left (ApplyEquations.SideConditionsFalse _ traces) -> do + (Left ApplyEquations.SideConditionsFalse{}, patternTraces) -> do let tSort = fromMaybe (error "unknown sort") $ sortOfJson req.state.term pure . Right . Simplify $ SimplifyResult { state = addHeader $ KoreJson.KJBottom tSort - , logs = mkTraces traces + , logs = mkTraces patternTraces } - Left (ApplyEquations.EquationLoop _traces terms) -> + (Left (ApplyEquations.EquationLoop terms), _traces) -> pure . Left . backendError RpcError.Aborted $ map externaliseTerm terms -- FIXME - Left other -> + (Left other, _traces) -> pure . Left . backendError RpcError.Aborted $ show other -- FIXME -- predicate only Right (APredicate predicate) -> do Log.logInfoNS "booster" "Simplifying a predicate" ApplyEquations.simplifyConstraint doTracing def mLlvmLibrary predicate >>= \case - Right (newPred, traces) -> do + (Right newPred, traces) -> do let predicateSort = fromMaybe (error "not a predicate") $ sortOfJson req.state.term @@ -178,7 +178,7 @@ respond stateVar = { state = addHeader result , logs = mkTraces traces } - Left something -> + (Left something, _traces) -> pure . Left . backendError RpcError.Aborted $ show something -- FIXME -- this case is only reachable if the cancel appeared as part of a batch request @@ -481,13 +481,12 @@ mkLogRewriteTrace } , origin = Booster } - RewriteSimplified (Right equationTraces) + RewriteSimplified equationTraces Nothing | fromMaybe False logSuccessfulSimplifications || fromMaybe False logFailedSimplifications -> mapM (mkLogEquationTrace equationLogOpts) equationTraces - RewriteSimplified (Left failure) - | fromMaybe False logFailedSimplifications -> Just $ - singleton $ - case failure of + RewriteSimplified equationTraces (Just failure) + | fromMaybe False logFailedSimplifications -> do + let final = singleton $ case failure of ApplyEquations.IndexIsNone trm -> Simplification { originalTerm = Just $ execStateToKoreJson $ toExecState $ Pattern trm [] @@ -502,7 +501,7 @@ mkLogRewriteTrace , origin = Booster , result = Failure{reason = "Reached iteration depth limit " <> pack (show i), _ruleId = Nothing} } - ApplyEquations.EquationLoop _ _ -> + ApplyEquations.EquationLoop _ -> Simplification { originalTerm = Nothing , originalTermIndex = Nothing @@ -516,11 +515,12 @@ mkLogRewriteTrace , origin = Booster , result = Failure{reason = "Internal error: " <> err, _ruleId = Nothing} } - ApplyEquations.SideConditionsFalse _predicates _traces -> + ApplyEquations.SideConditionsFalse _predicates -> Simplification { originalTerm = Nothing , originalTermIndex = Nothing , origin = Booster , result = Failure{reason = "Side conditions false", _ruleId = Nothing} } + (<> final) <$> mapM (mkLogEquationTrace equationLogOpts) equationTraces _ -> Nothing diff --git a/library/Booster/Pattern/ApplyEquations.hs b/library/Booster/Pattern/ApplyEquations.hs index b19b0df4f..3966e14b1 100644 --- a/library/Booster/Pattern/ApplyEquations.hs +++ b/library/Booster/Pattern/ApplyEquations.hs @@ -56,17 +56,18 @@ import Booster.Pattern.Util import Booster.Prettyprinter (renderDefault) newtype EquationT io a - = EquationT (StateT EquationState (ReaderT EquationConfig (ExceptT EquationFailure io)) a) + = EquationT (ReaderT EquationConfig (ExceptT EquationFailure (StateT EquationState io)) a) + -- ~ EquationConfig -> EquationState -> io (Either EquationFailure a, EquationState) deriving newtype (Functor, Applicative, Monad, MonadIO, MonadLogger, MonadLoggerIO) throw :: MonadLoggerIO io => EquationFailure -> EquationT io a -throw = EquationT . lift . lift . throwE +throw = EquationT . lift . throwE data EquationFailure = IndexIsNone Term | TooManyIterations Int Term Term - | EquationLoop [EquationTrace] [Term] - | SideConditionsFalse [Predicate] [EquationTrace] + | EquationLoop [Term] + | SideConditionsFalse [Predicate] | InternalError Text deriving stock (Eq, Show) @@ -160,32 +161,32 @@ startState = EquationState{termStack = [], changed = False, predicates = mempty, trace = mempty} getState :: MonadLoggerIO io => EquationT io EquationState -getState = EquationT get +getState = EquationT (lift $ lift get) getConfig :: MonadLoggerIO io => EquationT io EquationConfig -getConfig = EquationT (lift ask) +getConfig = EquationT ask countSteps :: MonadLoggerIO io => EquationT io Int countSteps = length . (.termStack) <$> getState pushTerm :: MonadLoggerIO io => Term -> EquationT io () -pushTerm t = EquationT . modify $ \s -> s{termStack = t : s.termStack} +pushTerm t = EquationT . lift . lift . modify $ \s -> s{termStack = t : s.termStack} pushConstraints :: MonadLoggerIO io => [Predicate] -> EquationT io () -pushConstraints ps = EquationT . modify $ \s -> s{predicates = s.predicates <> Set.fromList ps} +pushConstraints ps = EquationT . lift . lift . modify $ \s -> s{predicates = s.predicates <> Set.fromList ps} setChanged, resetChanged :: MonadLoggerIO io => EquationT io () -setChanged = EquationT . modify $ \s -> s{changed = True} -resetChanged = EquationT . modify $ \s -> s{changed = False} +setChanged = EquationT . lift . lift . modify $ \s -> s{changed = True} +resetChanged = EquationT . lift . lift . modify $ \s -> s{changed = False} getChanged :: MonadLoggerIO io => EquationT io Bool -getChanged = EquationT $ gets (.changed) +getChanged = EquationT $ lift $ lift $ gets (.changed) checkForLoop :: MonadLoggerIO io => Term -> EquationT io () checkForLoop t = do - EquationState{termStack, trace} <- getState + EquationState{termStack} <- getState whenJust (elemIndex t termStack) $ \i -> do - throw (EquationLoop (toList trace) . reverse $ t : take (i + 1) termStack) + throw (EquationLoop $ reverse $ t : take (i + 1) termStack) data Direction = TopDown | BottomUp deriving stock (Eq, Show) @@ -199,14 +200,13 @@ runEquationT :: KoreDefinition -> Maybe LLVM.API -> EquationT io a -> - io (Either EquationFailure (a, [EquationTrace])) + io (Either EquationFailure a, [EquationTrace]) runEquationT doTracing definition llvmApi (EquationT m) = do - endState <- - runExceptT - . flip runReaderT EquationConfig{definition, llvmApi, doTracing} - . runStateT m - $ startState - pure (fmap (toList . trace) <$> endState) + (res, endState) <- + flip runStateT startState $ + runExceptT $ + runReaderT m EquationConfig{definition, llvmApi, doTracing} + pure (res, toList $ trace endState) iterateEquations :: MonadLoggerIO io => @@ -245,7 +245,7 @@ evaluateTerm :: KoreDefinition -> Maybe LLVM.API -> Term -> - io (Either EquationFailure (Term, [EquationTrace])) + io (Either EquationFailure Term, [EquationTrace]) evaluateTerm doTracing direction def llvmApi = runEquationT doTracing def llvmApi . evaluateTerm' direction @@ -267,7 +267,7 @@ evaluatePattern :: KoreDefinition -> Maybe LLVM.API -> Pattern -> - io (Either EquationFailure (Pattern, [EquationTrace])) + io (Either EquationFailure Pattern, [EquationTrace]) evaluatePattern doTracing def mLlvmLibrary = runEquationT doTracing def mLlvmLibrary . evaluatePattern' @@ -290,7 +290,7 @@ evaluatePattern' Pattern{term, constraints} = do simplifyAssumedPredicate p = do allPs <- predicates <$> getState let otherPs = Set.delete p allPs - EquationT $ modify $ \s -> s{predicates = otherPs} + EquationT $ lift $ lift $ modify $ \s -> s{predicates = otherPs} newP <- simplifyConstraint' p pushConstraints [newP] @@ -418,7 +418,7 @@ handleFunctionEquation success continue abort = \case IndeterminateMatch -> abort IndeterminateCondition -> abort ConditionFalse -> continue - EnsuresFalse ps -> getState >>= throw . SideConditionsFalse ps . toList . trace + EnsuresFalse ps -> throw $ SideConditionsFalse ps RuleNotPreservingDefinedness -> abort MatchConstraintViolated{} -> continue @@ -429,7 +429,7 @@ handleSimplificationEquation success continue _abort = \case IndeterminateMatch -> continue IndeterminateCondition -> continue ConditionFalse -> continue - EnsuresFalse ps -> getState >>= throw . SideConditionsFalse ps . toList . trace + EnsuresFalse ps -> throw $ SideConditionsFalse ps RuleNotPreservingDefinedness -> continue MatchConstraintViolated{} -> continue @@ -488,7 +488,7 @@ traceRuleApplication t loc lbl uid res = do logOther (LevelOther "Simplify") (pack . renderDefault . pretty $ newTraceItem) config <- getConfig when (config.doTracing) $ - EquationT . modify $ + EquationT . lift . lift . modify $ \s -> s{trace = s.trace :|> newTraceItem} applyEquation :: @@ -628,7 +628,7 @@ simplifyConstraint :: KoreDefinition -> Maybe LLVM.API -> Predicate -> - io (Either EquationFailure (Predicate, [EquationTrace])) + io (Either EquationFailure Predicate, [EquationTrace]) simplifyConstraint doTracing def mbApi p = runEquationT doTracing def mbApi $ simplifyConstraint' p @@ -666,5 +666,5 @@ simplifyConstraint' = \case evalBool t = do prior <- getState -- save prior state so we can revert result <- iterateEquations 100 TopDown PreferFunctions t - EquationT $ put prior + EquationT $ lift $ lift $ put prior pure result diff --git a/library/Booster/Pattern/Rewrite.hs b/library/Booster/Pattern/Rewrite.hs index a519f5319..a03c06031 100644 --- a/library/Booster/Pattern/Rewrite.hs +++ b/library/Booster/Pattern/Rewrite.hs @@ -41,8 +41,6 @@ import Booster.Pattern.ApplyEquations ( EquationFailure (..), EquationTrace, evaluatePattern, - isMatchFailure, - isSuccess, simplifyConstraint, ) import Booster.Pattern.Base @@ -228,11 +226,11 @@ applyRule pat rule = runMaybeT $ do MaybeT (RewriteT io (RewriteFailed k)) (Maybe a) checkConstraint onUnclear p = do RewriteConfig{definition, llvmApi, doTracing} <- lift $ RewriteT ask - simplified <- simplifyConstraint doTracing definition llvmApi p + (simplified, _traces) <- simplifyConstraint doTracing definition llvmApi p case simplified of - Right (Bottom, _) -> fail "Rule condition was False" - Right (Top, _) -> pure Nothing - Right (other, _) -> pure $ Just $ onUnclear other + Right Bottom -> fail "Rule condition was False" + Right Top -> pure Nothing + Right other -> pure $ Just $ onUnclear other Left _ -> pure $ Just $ onUnclear p {- | Reason why a rewrite did not produce a result. Contains additional @@ -337,7 +335,7 @@ data RewriteTrace pat | -- | attempted rewrite failed RewriteStepFailed (RewriteFailed "Rewrite") | -- | Applied simplification to the pattern - RewriteSimplified (Either EquationFailure [EquationTrace]) + RewriteSimplified [EquationTrace] (Maybe EquationFailure) deriving stock (Eq, Show) deriving (Functor, Foldable, Traversable) @@ -440,36 +438,35 @@ performRewrite doTracing def mLlvmLibrary mbMaxDepth cutLabels terminalLabels pa simplifyP :: Pattern -> StateT (Natural, Seq (RewriteTrace Pattern)) io (Maybe Pattern) simplifyP p = - evaluatePattern doTracing def mLlvmLibrary p >>= \case - Right (newPattern, traces) -> do - logTraces $ filter (not . isMatchFailure) traces - rewriteTrace $ RewriteSimplified $ Right traces - pure $ Just newPattern - Left r@(SideConditionsFalse _ps traces) -> do - logSimplify "Side conditions were found to be false, pruning" - logTraces $ filter (not . isMatchFailure) traces - rewriteTrace $ RewriteSimplified $ Left r - pure Nothing - -- NB any errors here might be caused by simplifying one - -- of the constraints, so we cannot use partial results - -- and have to return the original on errors. - Left r@(TooManyIterations n _start _result) -> do - logWarn $ "Simplification unable to finish in " <> prettyText n <> " steps." - -- could output term before and after at debug or custom log level - rewriteTrace $ RewriteSimplified $ Left r - pure $ Just p - Left r@(EquationLoop traces (t : ts)) -> do - let termDiffs = zipWith (curry mkDiffTerms) (t : ts) ts - logError "Equation evaluation loop" - logTraces $ filter isSuccess traces - logSimplify $ - "produced the evaluation loop: " <> Text.unlines (map (prettyText . fst) termDiffs) - rewriteTrace $ RewriteSimplified $ Left r - pure $ Just p - Left other -> do - logError . pack $ "Simplification error during rewrite: " <> show other - rewriteTrace $ RewriteSimplified $ Left other - pure $ Just p + evaluatePattern doTracing def mLlvmLibrary p >>= \(res, traces) -> do + logTraces traces + case res of + Right newPattern -> do + rewriteTrace $ RewriteSimplified traces Nothing + pure $ Just newPattern + Left r@(SideConditionsFalse _ps) -> do + logSimplify "Side conditions were found to be false, pruning" + rewriteTrace $ RewriteSimplified traces (Just r) + pure Nothing + -- NB any errors here might be caused by simplifying one + -- of the constraints, so we cannot use partial results + -- and have to return the original on errors. + Left r@(TooManyIterations n _start _result) -> do + logWarn $ "Simplification unable to finish in " <> prettyText n <> " steps." + -- could output term before and after at debug or custom log level + rewriteTrace $ RewriteSimplified traces (Just r) + pure $ Just p + Left r@(EquationLoop (t : ts)) -> do + let termDiffs = zipWith (curry mkDiffTerms) (t : ts) ts + logError "Equation evaluation loop" + logSimplify $ + "produced the evaluation loop: " <> Text.unlines (map (prettyText . fst) termDiffs) + rewriteTrace $ RewriteSimplified traces (Just r) + pure $ Just p + Left other -> do + logError . pack $ "Simplification error during rewrite: " <> show other + rewriteTrace $ RewriteSimplified traces (Just other) + pure $ Just p -- Results may change when simplification prunes a false side -- condition, otherwise this would mainly be fmap simplifyP diff --git a/unit-tests/Test/Booster/Pattern/ApplyEquations.hs b/unit-tests/Test/Booster/Pattern/ApplyEquations.hs index 2879226c3..d2b366d5f 100644 --- a/unit-tests/Test/Booster/Pattern/ApplyEquations.hs +++ b/unit-tests/Test/Booster/Pattern/ApplyEquations.hs @@ -87,7 +87,7 @@ test_evaluateFunction = eval BottomUp subj @?= Right result ] where - eval direction = fmap fst . unsafePerformIO . runNoLoggingT . evaluateTerm False direction funDef Nothing + eval direction = unsafePerformIO . runNoLoggingT . (fst <$>) . evaluateTerm False direction funDef Nothing isTooManyIterations (Left (TooManyIterations _n _ _)) = pure () isTooManyIterations (Left err) = assertFailure $ "Unexpected error " <> show err @@ -114,7 +114,7 @@ test_simplify = simpl BottomUp subj @?= Right result ] where - simpl direction = fmap fst . unsafePerformIO . runNoLoggingT . evaluateTerm False direction simplDef Nothing + simpl direction = unsafePerformIO . runNoLoggingT . (fst <$>) . evaluateTerm False direction simplDef Nothing a = var "A" someSort test_simplifyPattern :: TestTree @@ -143,7 +143,7 @@ test_simplifyPattern = simpl subj @?= Right result ] where - simpl = fmap fst . unsafePerformIO . runNoLoggingT . evaluatePattern False simplDef Nothing + simpl = unsafePerformIO . runNoLoggingT . (fst <$>) . evaluatePattern False simplDef Nothing a = var "A" someSort test_errors :: TestTree @@ -156,10 +156,11 @@ test_errors = subj = f $ app con1 [a] loopTerms = [f $ app con1 [a], f $ app con2 [a], f $ app con3 [a, a], f $ app con1 [a]] - isLoop loopTerms . unsafePerformIO . runNoLoggingT $ evaluateTerm False TopDown loopDef Nothing subj + isLoop loopTerms . unsafePerformIO . runNoLoggingT $ + fst <$> evaluateTerm False TopDown loopDef Nothing subj ] where - isLoop ts (Left (EquationLoop _ ts')) = ts @?= ts' + isLoop ts (Left (EquationLoop ts')) = ts @?= ts' isLoop _ (Left err) = assertFailure $ "Unexpected error " <> show err isLoop _ (Right r) = assertFailure $ "Unexpected result " <> show r