Skip to content

Commit

Permalink
Move AI-related logic form explorer to AISearcher
Browse files Browse the repository at this point in the history
  • Loading branch information
gsvgit committed Dec 19, 2023
1 parent 4b62478 commit 034f3c6
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 87 deletions.
62 changes: 36 additions & 26 deletions VSharp.Explorer/AISearcher.fs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,21 @@ open VSharp
open VSharp.IL.Serializer
open VSharp.ML.GameServer.Messages

type internal AISearcher(coverageToSwitchToAI: uint, oracle:Oracle, serialize:bool) =
type internal AISearcher(coverageToSwitchToAI: uint, oracle:Oracle, serialize:bool, stepsToPlay:uint, pathToSerialize:string) =
let folderToStoreSerializationResult = getFolderToStoreSerializationResult pathToSerialize
let fileForExpectedResults = getFileForExpectedResults folderToStoreSerializationResult
do
if serialize
then
System.IO.File.AppendAllLines(fileForExpectedResults, ["GraphID ExpectedStateNumber ExpectedRewardForCoveredInStep ExpectedRewardForVisitedInstructionsInStep TotalReachableRewardFromCurrentState"])
let mutable lastCollectedStatistics = Statistics()
let mutable (gameState:Option<GameState>) = None
let mutable useDefaultSearcher = coverageToSwitchToAI > 0u
let mutable afterFirstAIPeek = false
let mutable incorrectPredictedStateId = false
let defaultSearcher = BFSSearcher() :> IForwardSearcher
let defaultSearcher = BFSSearcher() :> IForwardSearcher
let mutable stepsPlayed = 0u
let isInAIMode () = (not useDefaultSearcher) && afterFirstAIPeek
let q = ResizeArray<_>()
let availableStates = HashSet<_>()
let updateGameState (delta:GameState) =
Expand Down Expand Up @@ -88,34 +96,36 @@ type internal AISearcher(coverageToSwitchToAI: uint, oracle:Oracle, serialize:bo
defaultSearcher.Pick()
elif Seq.length availableStates = 0
then None
else
else
let gameStateDelta,_ = collectGameStateDelta serialize
updateGameState gameStateDelta
let statistics = computeStatistics gameState.Value
if isInAIMode()
then
let reward = computeReward lastCollectedStatistics statistics
oracle.Feedback (Feedback.MoveReward reward)
Application.applicationGraphDelta.Clear()
lastCollectedStatistics <- statistics
let stateId, _ =
let x,y = oracle.Predict gameState.Value
x * 1u<stateId>, y
afterFirstAIPeek <- true
let state = availableStates |> Seq.tryFind (fun s -> s.internalId = stateId)
match state with
| Some state ->
Some state
| None ->
incorrectPredictedStateId <- true
oracle.Feedback (Feedback.IncorrectPredictedStateId stateId)
None
member this.LastCollectedStatistics
with get () = lastCollectedStatistics
and set v = lastCollectedStatistics <- v
member this.LastGameState with set v = gameState <- Some v
member this.ProvideOracleFeedback feedback =
if not incorrectPredictedStateId
then
oracle.Feedback feedback
incorrectPredictedStateId <- false
member this.InAIMode with get () = (not useDefaultSearcher) && afterFirstAIPeek
if stepsToPlay = stepsPlayed
then None
else
let stateId, _ =
let x,y = oracle.Predict gameState.Value
x * 1u<stateId>, y
afterFirstAIPeek <- true
let state = availableStates |> Seq.tryFind (fun s -> s.internalId = stateId)
if serialize
then
dumpGameState (System.IO.Path.Combine(folderToStoreSerializationResult , string firstFreeEpisodeNumber)) serialize
saveExpectedResult fileForExpectedResults stateId lastCollectedStatistics statistics
lastCollectedStatistics <- statistics
stepsPlayed <- stepsPlayed + 1u
match state with
| Some state ->
Some state
| None ->
incorrectPredictedStateId <- true
oracle.Feedback (Feedback.IncorrectPredictedStateId stateId)
None

interface IForwardSearcher with
override x.Init states = init states
Expand Down
58 changes: 6 additions & 52 deletions VSharp.Explorer/Explorer.fs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ type private SVMExplorer(explorationOptions: ExplorationOptions, statistics: SVM
let rec mkForwardSearcher mode =
let getRandomSeedOption() = if options.randomSeed < 0 then None else Some options.randomSeed
match mode with
| AIMode -> AISearcher(options.coverageToSwitchToAI, options.oracle.Value, options.serialize) :> IForwardSearcher
| AIMode -> AISearcher(options.coverageToSwitchToAI, options.oracle.Value, options.serialize, options.stepsToPlay, options.pathToSerialize) :> IForwardSearcher
| BFSMode -> BFSSearcher() :> IForwardSearcher
| DFSMode -> DFSSearcher() :> IForwardSearcher
| ShortestDistanceBasedMode -> ShortestDistanceBasedSearcher statistics :> IForwardSearcher
Expand Down Expand Up @@ -328,79 +328,33 @@ type private SVMExplorer(explorationOptions: ExplorationOptions, statistics: SVM
Logger.trace "UNSAT for pob = %O and s'.PC = %s" p' (API.Print.PrintPC s'.state.pc)

member private x.BidirectionalSymbolicExecution() =
let folderToStoreSerializationResult = getFolderToStoreSerializationResult options.pathToSerialize
let fileForExpectedResults = getFileForExpectedResults folderToStoreSerializationResult
if options.serialize
then
System.IO.File.AppendAllLines(fileForExpectedResults, ["GraphID ExpectedStateNumber ExpectedRewardForCoveredInStep ExpectedRewardForVisitedInstructionsInStep TotalReachableRewardFromCurrentState"])

let mutable action = Stop
let mutable stepsPlayed = 0u

let pick() =
match searcher.Pick() with
| Stop -> false
| a -> action <- a; true
(* TODO: checking for timeout here is not fine-grained enough (that is, we can work significantly beyond the
timeout, but we'll live with it for now. *)
while not isStopped && not <| isStepsLimitReached() && not <| isTimeoutReached() && pick() do
stepsCount <- stepsCount + 1
if searcher :? BidirectionalSearcher && (searcher :?> BidirectionalSearcher).ForwardSearcher :? AISearcher && ((searcher :?> BidirectionalSearcher).ForwardSearcher :?> AISearcher).InAIMode
then stepsPlayed <- stepsPlayed + 1u
stepsCount <- stepsCount + 1
if shouldReleaseBranches() then
releaseBranches()
match action with
| GoFront s ->
try
let statisticsBeforeStep =
match searcher with
| :? BidirectionalSearcher as s ->
match s.ForwardSearcher with
| :? AISearcher as s -> Some s.LastCollectedStatistics
| _ -> None
| _ -> None
let statistics1 =
if options.serialize
then Some(dumpGameState (System.IO.Path.Combine(folderToStoreSerializationResult , string firstFreeEpisodeNumber)) options.serialize)
else None
x.Forward(s)
match searcher with
| :? BidirectionalSearcher as searcher ->
match searcher.ForwardSearcher with
| :? AISearcher as searcher ->
///TODO !!! Do not use collectFullGameState
let gameState,_ = collectFullGameState options.serialize
let statisticsAfterStep = computeStatistics gameState
//searcher.LastGameState <- gameState
searcher.LastCollectedStatistics <- statisticsAfterStep
let reward = computeReward statisticsBeforeStep.Value statisticsAfterStep
if searcher.InAIMode
then searcher.ProvideOracleFeedback (Feedback.MoveReward reward)
| _ -> ()
| _ -> ()
if options.serialize
then
let gameState,_ = collectFullGameState options.serialize
let statistics2 = computeStatistics gameState
saveExpectedResult fileForExpectedResults s.internalId statistics1.Value statistics2
x.Forward(s)
with
| e ->
match searcher with
| :? BidirectionalSearcher as searcher ->
match searcher.ForwardSearcher with
| :? AISearcher as searcher ->
if searcher.InAIMode
then searcher.ProvideOracleFeedback (Feedback.MoveReward (Reward(0u<coverageReward>,0u<_>,0u<_>)))
| _ -> ()
| _ -> ()
reportStateInternalFail s e
| GoBack(s, p) ->
try
x.Backward p s
with
| e -> reportStateInternalFail s e
| Stop -> __unreachable__()
if searcher :? BidirectionalSearcher && (searcher :?> BidirectionalSearcher).ForwardSearcher :? AISearcher && (options.stepsToPlay = stepsPlayed)
then x.Stop()


member private x.AnswerPobs initialStates =
statistics.ExplorationStarted()

Expand Down
13 changes: 4 additions & 9 deletions VSharp.IL/Serializer.fs
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,6 @@ let computeStatistics (gameState:GameState) =

Statistics(coveredVerticesInZone,coveredVerticesOutOfZone,visitedVerticesInZone,visitedVerticesOutOfZone,visitedInstructionsInZone,touchedVerticesInZone,touchedVerticesOutOfZone, totalVisibleVerticesInZone)



let collectGameState (basicBlocks:ResizeArray<BasicBlock>) (serialize: bool) =

let vertices = ResizeArray<_>()
Expand Down Expand Up @@ -341,13 +339,10 @@ let collectGameStateDelta serialize =

let dumpGameState fileForResultWithoutExtension serialize =
let gameState, statesInfoToDump = collectFullGameState serialize
if serialize
then
let gameStateJson = JsonSerializer.Serialize gameState
let statesInfoJson = JsonSerializer.Serialize statesInfoToDump.Value
System.IO.File.WriteAllText(fileForResultWithoutExtension + "_gameState",gameStateJson)
System.IO.File.WriteAllText(fileForResultWithoutExtension + "_statesInfo",statesInfoJson)
computeStatistics gameState
let gameStateJson = JsonSerializer.Serialize gameState
let statesInfoJson = JsonSerializer.Serialize statesInfoToDump.Value
System.IO.File.WriteAllText(fileForResultWithoutExtension + "_gameState",gameStateJson)
System.IO.File.WriteAllText(fileForResultWithoutExtension + "_statesInfo",statesInfoJson)

let computeReward (statisticsBeforeStep:Statistics) (statisticsAfterStep:Statistics) =
let rewardForCoverage =
Expand Down

0 comments on commit 034f3c6

Please sign in to comment.