diff --git a/VSharp.Explorer/AISearcher.fs b/VSharp.Explorer/AISearcher.fs index 30b844b4a..8ca789acf 100644 --- a/VSharp.Explorer/AISearcher.fs +++ b/VSharp.Explorer/AISearcher.fs @@ -5,13 +5,21 @@ open VSharp open VSharp.IL.Serializer open VSharp.ML.GameServer.Messages -type internal AISearcher(coverageToSwitchToAI: uint, oracle:Oracle, serialize:bool) = +type internal AISearcher(coverageToSwitchToAI: uint, oracle:Oracle, serialize:bool, stepsToPlay:uint, pathToSerialize:string) = + let folderToStoreSerializationResult = getFolderToStoreSerializationResult pathToSerialize + let fileForExpectedResults = getFileForExpectedResults folderToStoreSerializationResult + do + if serialize + then + System.IO.File.AppendAllLines(fileForExpectedResults, ["GraphID ExpectedStateNumber ExpectedRewardForCoveredInStep ExpectedRewardForVisitedInstructionsInStep TotalReachableRewardFromCurrentState"]) let mutable lastCollectedStatistics = Statistics() let mutable (gameState:Option) = None let mutable useDefaultSearcher = coverageToSwitchToAI > 0u let mutable afterFirstAIPeek = false let mutable incorrectPredictedStateId = false - let defaultSearcher = BFSSearcher() :> IForwardSearcher + let defaultSearcher = BFSSearcher() :> IForwardSearcher + let mutable stepsPlayed = 0u + let isInAIMode () = (not useDefaultSearcher) && afterFirstAIPeek let q = ResizeArray<_>() let availableStates = HashSet<_>() let updateGameState (delta:GameState) = @@ -88,34 +96,36 @@ type internal AISearcher(coverageToSwitchToAI: uint, oracle:Oracle, serialize:bo defaultSearcher.Pick() elif Seq.length availableStates = 0 then None - else + else let gameStateDelta,_ = collectGameStateDelta serialize updateGameState gameStateDelta let statistics = computeStatistics gameState.Value + if isInAIMode() + then + let reward = computeReward lastCollectedStatistics statistics + oracle.Feedback (Feedback.MoveReward reward) Application.applicationGraphDelta.Clear() - lastCollectedStatistics <- statistics - let stateId, _ = - let x,y = oracle.Predict gameState.Value - x * 1u, y - afterFirstAIPeek <- true - let state = availableStates |> Seq.tryFind (fun s -> s.internalId = stateId) - match state with - | Some state -> - Some state - | None -> - incorrectPredictedStateId <- true - oracle.Feedback (Feedback.IncorrectPredictedStateId stateId) - None - member this.LastCollectedStatistics - with get () = lastCollectedStatistics - and set v = lastCollectedStatistics <- v - member this.LastGameState with set v = gameState <- Some v - member this.ProvideOracleFeedback feedback = - if not incorrectPredictedStateId - then - oracle.Feedback feedback - incorrectPredictedStateId <- false - member this.InAIMode with get () = (not useDefaultSearcher) && afterFirstAIPeek + if stepsToPlay = stepsPlayed + then None + else + let stateId, _ = + let x,y = oracle.Predict gameState.Value + x * 1u, y + afterFirstAIPeek <- true + let state = availableStates |> Seq.tryFind (fun s -> s.internalId = stateId) + if serialize + then + dumpGameState (System.IO.Path.Combine(folderToStoreSerializationResult , string firstFreeEpisodeNumber)) serialize + saveExpectedResult fileForExpectedResults stateId lastCollectedStatistics statistics + lastCollectedStatistics <- statistics + stepsPlayed <- stepsPlayed + 1u + match state with + | Some state -> + Some state + | None -> + incorrectPredictedStateId <- true + oracle.Feedback (Feedback.IncorrectPredictedStateId stateId) + None interface IForwardSearcher with override x.Init states = init states diff --git a/VSharp.Explorer/Explorer.fs b/VSharp.Explorer/Explorer.fs index 2e99a2516..cfb4d9e7f 100644 --- a/VSharp.Explorer/Explorer.fs +++ b/VSharp.Explorer/Explorer.fs @@ -79,7 +79,7 @@ type private SVMExplorer(explorationOptions: ExplorationOptions, statistics: SVM let rec mkForwardSearcher mode = let getRandomSeedOption() = if options.randomSeed < 0 then None else Some options.randomSeed match mode with - | AIMode -> AISearcher(options.coverageToSwitchToAI, options.oracle.Value, options.serialize) :> IForwardSearcher + | AIMode -> AISearcher(options.coverageToSwitchToAI, options.oracle.Value, options.serialize, options.stepsToPlay, options.pathToSerialize) :> IForwardSearcher | BFSMode -> BFSSearcher() :> IForwardSearcher | DFSMode -> DFSSearcher() :> IForwardSearcher | ShortestDistanceBasedMode -> ShortestDistanceBasedSearcher statistics :> IForwardSearcher @@ -328,13 +328,9 @@ type private SVMExplorer(explorationOptions: ExplorationOptions, statistics: SVM Logger.trace "UNSAT for pob = %O and s'.PC = %s" p' (API.Print.PrintPC s'.state.pc) member private x.BidirectionalSymbolicExecution() = - let folderToStoreSerializationResult = getFolderToStoreSerializationResult options.pathToSerialize - let fileForExpectedResults = getFileForExpectedResults folderToStoreSerializationResult - if options.serialize - then - System.IO.File.AppendAllLines(fileForExpectedResults, ["GraphID ExpectedStateNumber ExpectedRewardForCoveredInStep ExpectedRewardForVisitedInstructionsInStep TotalReachableRewardFromCurrentState"]) + let mutable action = Stop - let mutable stepsPlayed = 0u + let pick() = match searcher.Pick() with | Stop -> false @@ -342,55 +338,15 @@ type private SVMExplorer(explorationOptions: ExplorationOptions, statistics: SVM (* TODO: checking for timeout here is not fine-grained enough (that is, we can work significantly beyond the timeout, but we'll live with it for now. *) while not isStopped && not <| isStepsLimitReached() && not <| isTimeoutReached() && pick() do - stepsCount <- stepsCount + 1 - if searcher :? BidirectionalSearcher && (searcher :?> BidirectionalSearcher).ForwardSearcher :? AISearcher && ((searcher :?> BidirectionalSearcher).ForwardSearcher :?> AISearcher).InAIMode - then stepsPlayed <- stepsPlayed + 1u + stepsCount <- stepsCount + 1 if shouldReleaseBranches() then releaseBranches() match action with | GoFront s -> try - let statisticsBeforeStep = - match searcher with - | :? BidirectionalSearcher as s -> - match s.ForwardSearcher with - | :? AISearcher as s -> Some s.LastCollectedStatistics - | _ -> None - | _ -> None - let statistics1 = - if options.serialize - then Some(dumpGameState (System.IO.Path.Combine(folderToStoreSerializationResult , string firstFreeEpisodeNumber)) options.serialize) - else None - x.Forward(s) - match searcher with - | :? BidirectionalSearcher as searcher -> - match searcher.ForwardSearcher with - | :? AISearcher as searcher -> - ///TODO !!! Do not use collectFullGameState - let gameState,_ = collectFullGameState options.serialize - let statisticsAfterStep = computeStatistics gameState - //searcher.LastGameState <- gameState - searcher.LastCollectedStatistics <- statisticsAfterStep - let reward = computeReward statisticsBeforeStep.Value statisticsAfterStep - if searcher.InAIMode - then searcher.ProvideOracleFeedback (Feedback.MoveReward reward) - | _ -> () - | _ -> () - if options.serialize - then - let gameState,_ = collectFullGameState options.serialize - let statistics2 = computeStatistics gameState - saveExpectedResult fileForExpectedResults s.internalId statistics1.Value statistics2 + x.Forward(s) with | e -> - match searcher with - | :? BidirectionalSearcher as searcher -> - match searcher.ForwardSearcher with - | :? AISearcher as searcher -> - if searcher.InAIMode - then searcher.ProvideOracleFeedback (Feedback.MoveReward (Reward(0u,0u<_>,0u<_>))) - | _ -> () - | _ -> () reportStateInternalFail s e | GoBack(s, p) -> try @@ -398,9 +354,7 @@ type private SVMExplorer(explorationOptions: ExplorationOptions, statistics: SVM with | e -> reportStateInternalFail s e | Stop -> __unreachable__() - if searcher :? BidirectionalSearcher && (searcher :?> BidirectionalSearcher).ForwardSearcher :? AISearcher && (options.stepsToPlay = stepsPlayed) - then x.Stop() - + member private x.AnswerPobs initialStates = statistics.ExplorationStarted() diff --git a/VSharp.IL/Serializer.fs b/VSharp.IL/Serializer.fs index 65493c3a6..7f0e061f8 100644 --- a/VSharp.IL/Serializer.fs +++ b/VSharp.IL/Serializer.fs @@ -216,8 +216,6 @@ let computeStatistics (gameState:GameState) = Statistics(coveredVerticesInZone,coveredVerticesOutOfZone,visitedVerticesInZone,visitedVerticesOutOfZone,visitedInstructionsInZone,touchedVerticesInZone,touchedVerticesOutOfZone, totalVisibleVerticesInZone) - - let collectGameState (basicBlocks:ResizeArray) (serialize: bool) = let vertices = ResizeArray<_>() @@ -341,13 +339,10 @@ let collectGameStateDelta serialize = let dumpGameState fileForResultWithoutExtension serialize = let gameState, statesInfoToDump = collectFullGameState serialize - if serialize - then - let gameStateJson = JsonSerializer.Serialize gameState - let statesInfoJson = JsonSerializer.Serialize statesInfoToDump.Value - System.IO.File.WriteAllText(fileForResultWithoutExtension + "_gameState",gameStateJson) - System.IO.File.WriteAllText(fileForResultWithoutExtension + "_statesInfo",statesInfoJson) - computeStatistics gameState + let gameStateJson = JsonSerializer.Serialize gameState + let statesInfoJson = JsonSerializer.Serialize statesInfoToDump.Value + System.IO.File.WriteAllText(fileForResultWithoutExtension + "_gameState",gameStateJson) + System.IO.File.WriteAllText(fileForResultWithoutExtension + "_statesInfo",statesInfoJson) let computeReward (statisticsBeforeStep:Statistics) (statisticsAfterStep:Statistics) = let rewardForCoverage =