Skip to content

Commit

Permalink
Save steps with sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
Parzival-05 committed Jan 7, 2025
1 parent f732fb4 commit 02c7e74
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 61 deletions.
161 changes: 102 additions & 59 deletions VSharp.Explorer/AISearcher.fs
Original file line number Diff line number Diff line change
Expand Up @@ -2,64 +2,25 @@ namespace VSharp.Explorer

open System.Collections.Generic
open Microsoft.ML.OnnxRuntime
open System.IO
open System
open System.Net
open System.Net.Sockets
open System.Text
open System.Text.Json
open VSharp
open VSharp.IL.Serializer
open VSharp.ML.GameServer.Messages
open System.IO

/// Operating mode of the AI searcher. The searcher matches on this value when
/// deciding whether to hand the oracle the full game state or only a delta.
type AIMode =
/// NOTE(review): presumably plain inference without a training session — confirm.
| Runner
/// NOTE(review): presumably paired with the `SendModel` training mode — confirm.
| TrainingSendModel
/// NOTE(review): presumably paired with the `SendEachStep` training mode — confirm.
| TrainingSendEachStep

type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrainingMode>) =
let stepsToSwitchToAI =
match aiAgentTrainingMode with
| None -> 0u<step>
| Some (SendModel options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI
| Some (SendEachStep options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI

let stepsToPlay =
match aiAgentTrainingMode with
| None -> 0u<step>
| Some (SendModel options) -> options.aiAgentTrainingOptions.stepsToPlay
| Some (SendEachStep options) -> options.aiAgentTrainingOptions.stepsToPlay

let mutable lastCollectedStatistics =
Statistics ()
let mutable defaultSearcherSteps = 0u<step>
let mutable (gameState: Option<GameState>) =
None
let mutable useDefaultSearcher =
stepsToSwitchToAI > 0u<step>
let mutable afterFirstAIPeek = false
let mutable incorrectPredictedStateId =
false

let defaultSearcher =
let pickSearcher =
function
| BFSMode -> BFSSearcher () :> IForwardSearcher
| DFSMode -> DFSSearcher () :> IForwardSearcher
| x -> failwithf $"Unexpected default searcher {x}. DFS and BFS supported for now."

match aiAgentTrainingMode with
| None -> BFSSearcher () :> IForwardSearcher
| Some (SendModel options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy
| Some (SendEachStep options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy

let mutable stepsPlayed = 0u<step>

let isInAIMode () =
(not useDefaultSearcher) && afterFirstAIPeek

let q = ResizeArray<_> ()
let availableStates = HashSet<_> ()

let updateGameState (delta: GameState) =
module GameUtils =
let updateGameState (delta: GameState) (gameState: Option<GameState>) =
match gameState with
| None -> gameState <- Some delta
| None -> Some delta
| Some s ->
let updatedBasicBlocks =
delta.GraphVertices |> Array.map (fun b -> b.Id) |> HashSet
Expand Down Expand Up @@ -106,7 +67,52 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
)
)

gameState <- Some <| GameState (vertices.ToArray (), states, edges.ToArray ())
Some <| GameState (vertices.ToArray (), states, edges.ToArray ())
type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrainingMode>) =
let stepsToSwitchToAI =
match aiAgentTrainingMode with
| None -> 0u<step>
| Some (SendModel options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI
| Some (SendEachStep options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI

let stepsToPlay =
match aiAgentTrainingMode with
| None -> 0u<step>
| Some (SendModel options) -> options.aiAgentTrainingOptions.stepsToPlay
| Some (SendEachStep options) -> options.aiAgentTrainingOptions.stepsToPlay

let mutable lastCollectedStatistics =
Statistics ()
let mutable defaultSearcherSteps = 0u<step>
let mutable (gameState: Option<GameState>) =
None
let mutable useDefaultSearcher =
stepsToSwitchToAI > 0u<step>
let mutable afterFirstAIPeek = false
let mutable incorrectPredictedStateId =
false

let defaultSearcher =
let pickSearcher =
function
| BFSMode -> BFSSearcher () :> IForwardSearcher
| DFSMode -> DFSSearcher () :> IForwardSearcher
| x -> failwithf $"Unexpected default searcher {x}. DFS and BFS supported for now."

match aiAgentTrainingMode with
| None -> BFSSearcher () :> IForwardSearcher
| Some (SendModel options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy
| Some (SendEachStep options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy

let mutable stepsPlayed = 0u<step>

let isInAIMode () =
(not useDefaultSearcher) && afterFirstAIPeek

let q = ResizeArray<_> ()
let availableStates = HashSet<_> ()



let init states =
q.AddRange states
Expand Down Expand Up @@ -153,7 +159,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
if Seq.length availableStates > 0 then
let gameStateDelta =
collectGameStateDelta ()
updateGameState gameStateDelta
gameState <- GameUtils.updateGameState gameStateDelta gameState
let statistics =
computeStatistics gameState.Value
Application.applicationGraphDelta.Clear ()
Expand All @@ -168,7 +174,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
else
let gameStateDelta =
collectGameStateDelta ()
updateGameState gameStateDelta
gameState <- GameUtils.updateGameState gameStateDelta gameState
let statistics =
computeStatistics gameState.Value

Expand All @@ -184,12 +190,12 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
else
let toPredict =
match aiMode with
| TrainingSendEachStep ->
| TrainingSendEachStep
| TrainingSendModel ->
if stepsPlayed > 0u<step> then
gameStateDelta
else
gameState.Value
| TrainingSendModel
| Runner -> gameState.Value

let stateId = oracle.Predict toPredict
Expand Down Expand Up @@ -225,15 +231,45 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
let arrayOutputJson =
JsonSerializer.Serialize arrayOutput
arrayOutputJson

let writeStep (gameState: GameState) (output: IDisposableReadOnlyCollection<OrtValue>) (filePath: string) =
let stepToString (gameState: GameState) (output: IDisposableReadOnlyCollection<OrtValue>) =
let gameStateJson =
JsonSerializer.Serialize gameState
let stateJson = serializeOutput output
File.WriteAllText (filePath + "_gameState", gameStateJson)
File.WriteAllText (filePath + "_nn_output", stateJson)
let outputJson = serializeOutput output
let DELIM = Environment.NewLine
let strToSaveAsList =
[
gameStateJson
DELIM
outputJson
DELIM
]
String.concat " " strToSaveAsList

let createOracleRunner (pathToONNX: string, aiAgentTrainingModelOptions: Option<AIAgentTrainingModelOptions>) =
let host = "localhost"
// Port of the server that receives saved steps; taken from the training
// options when they are present.
// NOTE(review): without options this falls back to 0, yet the TcpClient
// created right after this binding connects unconditionally — confirm that
// running without training options is really expected to reach a server.
let port =
match aiAgentTrainingModelOptions with
| Some options -> options.port
| None -> 0

let client = new TcpClient ()
client.Connect (host, port)
client.SendBufferSize <- 2048
let stream = client.GetStream ()
/// Blocks until the server acknowledges the last message with "ACK\n".
/// Raises (via failwith) if the stream ends early or the reply differs.
let waitForAck () =
    let expectedAck = "ACK\n"
    let buffer =
        Array.zeroCreate<byte> (Encoding.UTF8.GetByteCount expectedAck)
    // A single Read on a network stream may return fewer bytes than requested,
    // so keep reading until the buffer is full or the peer closes the stream
    // (Read returns 0).
    let rec fill received =
        if received = buffer.Length then
            received
        else
            match stream.Read (buffer, received, buffer.Length - received) with
            | 0 -> received
            | n -> fill (received + n)
    let bytesRead = fill 0
    if Encoding.UTF8.GetString (buffer, 0, bytesRead) <> expectedAck then
        failwith "Did not receive ACK from server"
/// Sends one step (game state plus network output) to the server as UTF-8
/// text and waits for the server's acknowledgement before returning.
let saveStep (gameState: GameState) (output: IDisposableReadOnlyCollection<OrtValue>) =
    let message = stepToString gameState output
    let payload = Encoding.UTF8.GetBytes message
    stream.Write (payload, 0, payload.Length)
    // Push the bytes out before blocking on the acknowledgement.
    stream.Flush ()
    waitForAck ()

let sessionOptions =
if useGPU then
Expand All @@ -254,8 +290,15 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
let feedback (x: Feedback) = ()

let mutable stepsPlayed = 0
let mutable currentGameState = None

let predict (gameState: GameState) =
let predict (gameStateOrDelta: GameState) =
let _ =
match aiAgentTrainingModelOptions with
| Some _ when not (stepsPlayed = 0) ->
currentGameState <- GameUtils.updateGameState gameStateOrDelta currentGameState
| _ -> currentGameState <- Some gameStateOrDelta
let gameState = currentGameState.Value
let stateIds =
Dictionary<uint<stateId>, int> ()
let verticesIds =
Expand Down Expand Up @@ -441,7 +484,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai

let _ =
match aiAgentTrainingModelOptions with
| Some options -> writeStep gameState output (options.outputDirectory + ($"/{stepsPlayed}"))
| Some _ -> saveStep gameStateOrDelta output
| None -> ()

stepsPlayed <- stepsPlayed + 1
Expand Down
1 change: 1 addition & 0 deletions VSharp.Explorer/Options.fs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ type AIAgentTrainingModelOptions =
{
/// Training settings carried by this mode.
/// NOTE(review): presumably shared with the other training modes — confirm.
aiAgentTrainingOptions: AIAgentTrainingOptions
/// NOTE(review): previously the directory where per-step files were written;
/// confirm whether it is still read now that steps are sent over a socket.
outputDirectory: string
/// TCP port on localhost that the oracle runner connects to in order to
/// send each step.
port: int
}


Expand Down
5 changes: 3 additions & 2 deletions VSharp.ML.GameServer.Runner/Main.fs
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ let generateDataForPretraining outputDirectory datasetBasePath (maps: ResizeArra
API.Reset ()
HashMap.hashMap.Clear ()

let runTrainingSendModelMode outputDirectory (gameMap: GameMap) (pathToModel: string) (useGPU: bool) (optimize: bool) =
let runTrainingSendModelMode outputDirectory (gameMap: GameMap) (pathToModel: string) (useGPU: bool) (optimize: bool) (port: int) =
printfn $"Run infer on {gameMap.MapName} have started."

let aiTrainingOptions =
Expand All @@ -360,6 +360,7 @@ let runTrainingSendModelMode outputDirectory (gameMap: GameMap) (pathToModel: st
{
aiAgentTrainingOptions = aiTrainingOptions
outputDirectory = outputDirectory
port = port
}
)

Expand Down Expand Up @@ -473,7 +474,7 @@ let main args =
let optimize =
(args.TryGetResult <@ Optimize @>).IsSome

runTrainingSendModelMode outputDirectory gameMap model useGPU optimize
runTrainingSendModelMode outputDirectory gameMap model useGPU optimize port
| Mode.Generator ->
let datasetDescription =
args.GetResult <@ DatasetDescription @>
Expand Down

0 comments on commit 02c7e74

Please sign in to comment.