From 02c7e74e7b10c0789868fcc4372b7c199001526b Mon Sep 17 00:00:00 2001 From: Parzival-05 Date: Mon, 6 Jan 2025 04:42:49 +0300 Subject: [PATCH] Save steps with sockets --- VSharp.Explorer/AISearcher.fs | 161 ++++++++++++++++++---------- VSharp.Explorer/Options.fs | 1 + VSharp.ML.GameServer.Runner/Main.fs | 5 +- 3 files changed, 106 insertions(+), 61 deletions(-) diff --git a/VSharp.Explorer/AISearcher.fs b/VSharp.Explorer/AISearcher.fs index 62cd58641..0d114a365 100644 --- a/VSharp.Explorer/AISearcher.fs +++ b/VSharp.Explorer/AISearcher.fs @@ -2,64 +2,25 @@ namespace VSharp.Explorer open System.Collections.Generic open Microsoft.ML.OnnxRuntime +open System.IO +open System +open System.Net +open System.Net.Sockets +open System.Text open System.Text.Json open VSharp open VSharp.IL.Serializer open VSharp.ML.GameServer.Messages -open System.IO type AIMode = | Runner | TrainingSendModel | TrainingSendEachStep -type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option) = - let stepsToSwitchToAI = - match aiAgentTrainingMode with - | None -> 0u - | Some (SendModel options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI - | Some (SendEachStep options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI - - let stepsToPlay = - match aiAgentTrainingMode with - | None -> 0u - | Some (SendModel options) -> options.aiAgentTrainingOptions.stepsToPlay - | Some (SendEachStep options) -> options.aiAgentTrainingOptions.stepsToPlay - - let mutable lastCollectedStatistics = - Statistics () - let mutable defaultSearcherSteps = 0u - let mutable (gameState: Option) = - None - let mutable useDefaultSearcher = - stepsToSwitchToAI > 0u - let mutable afterFirstAIPeek = false - let mutable incorrectPredictedStateId = - false - - let defaultSearcher = - let pickSearcher = - function - | BFSMode -> BFSSearcher () :> IForwardSearcher - | DFSMode -> DFSSearcher () :> IForwardSearcher - | x -> failwithf $"Unexpected default searcher {x}. DFS and BFS supported for now." - - match aiAgentTrainingMode with - | None -> BFSSearcher () :> IForwardSearcher - | Some (SendModel options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy - | Some (SendEachStep options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy - - let mutable stepsPlayed = 0u - - let isInAIMode () = - (not useDefaultSearcher) && afterFirstAIPeek - - let q = ResizeArray<_> () - let availableStates = HashSet<_> () - - let updateGameState (delta: GameState) = +module GameUtils = + let updateGameState (delta: GameState) (gameState: Option) = match gameState with - | None -> gameState <- Some delta + | None -> Some delta | Some s -> let updatedBasicBlocks = delta.GraphVertices |> Array.map (fun b -> b.Id) |> HashSet @@ -106,7 +67,52 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option) = + let stepsToSwitchToAI = + match aiAgentTrainingMode with + | None -> 0u + | Some (SendModel options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI + | Some (SendEachStep options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI + + let stepsToPlay = + match aiAgentTrainingMode with + | None -> 0u + | Some (SendModel options) -> options.aiAgentTrainingOptions.stepsToPlay + | Some (SendEachStep options) -> options.aiAgentTrainingOptions.stepsToPlay + + let mutable lastCollectedStatistics = + Statistics () + let mutable defaultSearcherSteps = 0u + let mutable (gameState: Option) = + None + let mutable useDefaultSearcher = + stepsToSwitchToAI > 0u + let mutable afterFirstAIPeek = false + let mutable incorrectPredictedStateId = + false + + let defaultSearcher = + let pickSearcher = + function + | BFSMode -> BFSSearcher () :> IForwardSearcher + | DFSMode -> DFSSearcher () :> IForwardSearcher + | x -> failwithf $"Unexpected default searcher {x}. DFS and BFS supported for now." + + match aiAgentTrainingMode with + | None -> BFSSearcher () :> IForwardSearcher + | Some (SendModel options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy + | Some (SendEachStep options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy + + let mutable stepsPlayed = 0u + + let isInAIMode () = + (not useDefaultSearcher) && afterFirstAIPeek + + let q = ResizeArray<_> () + let availableStates = HashSet<_> () + + let init states = q.AddRange states @@ -153,7 +159,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option 0 then let gameStateDelta = collectGameStateDelta () - updateGameState gameStateDelta + gameState <- GameUtils.updateGameState gameStateDelta gameState let statistics = computeStatistics gameState.Value Application.applicationGraphDelta.Clear () @@ -168,7 +174,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option + | TrainingSendEachStep + | TrainingSendModel -> if stepsPlayed > 0u then gameStateDelta else gameState.Value - | TrainingSendModel | Runner -> gameState.Value let stateId = oracle.Predict toPredict @@ -225,15 +231,45 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option) (filePath: string) = + let stepToString (gameState: GameState) (output: IDisposableReadOnlyCollection) = let gameStateJson = JsonSerializer.Serialize gameState - let stateJson = serializeOutput output - File.WriteAllText (filePath + "_gameState", gameStateJson) - File.WriteAllText (filePath + "_nn_output", stateJson) + let outputJson = serializeOutput output + let DELIM = Environment.NewLine + let strToSaveAsList = + [ + gameStateJson + DELIM + outputJson + DELIM + ] + String.concat " " strToSaveAsList let createOracleRunner (pathToONNX: string, aiAgentTrainingModelOptions: Option) = + let host = "localhost" + let port = + match aiAgentTrainingModelOptions with + | Some options -> options.port + | None -> 0 + + let client = new TcpClient () + client.Connect (host, port) + client.SendBufferSize <- 2048 + let stream = client.GetStream () + let waitForAck () = + let buffer = Array.zeroCreate 4 + let bytesRead = + stream.Read (buffer, 0, buffer.Length) + if bytesRead > 0 && Encoding.UTF8.GetString (buffer, 0, bytesRead) = "ACK\n" then + () + else + failwith "Did not receive ACK from server" + let saveStep (gameState: GameState) (output: IDisposableReadOnlyCollection) = + let bytes = + Encoding.UTF8.GetBytes (stepToString gameState output) + stream.Write (bytes, 0, bytes.Length) + stream.Flush () + waitForAck () let sessionOptions = if useGPU then @@ -254,8 +290,15 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option + currentGameState <- GameUtils.updateGameState gameStateOrDelta currentGameState + | _ -> currentGameState <- Some gameStateOrDelta + let gameState = currentGameState.Value let stateIds = Dictionary, int> () let verticesIds = @@ -441,7 +484,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option writeStep gameState output (options.outputDirectory + ($"/{stepsPlayed}")) + | Some _ -> saveStep gameStateOrDelta output | None -> () stepsPlayed <- stepsPlayed + 1 diff --git a/VSharp.Explorer/Options.fs b/VSharp.Explorer/Options.fs index 1602a7d70..62a509884 100644 --- a/VSharp.Explorer/Options.fs +++ b/VSharp.Explorer/Options.fs @@ -78,6 +78,7 @@ type AIAgentTrainingModelOptions = { aiAgentTrainingOptions: AIAgentTrainingOptions outputDirectory: string + port: int } diff --git a/VSharp.ML.GameServer.Runner/Main.fs b/VSharp.ML.GameServer.Runner/Main.fs index 77efafeda..f3946d486 100644 --- a/VSharp.ML.GameServer.Runner/Main.fs +++ b/VSharp.ML.GameServer.Runner/Main.fs @@ -334,7 +334,7 @@ let generateDataForPretraining outputDirectory datasetBasePath (maps: ResizeArra API.Reset () HashMap.hashMap.Clear () -let runTrainingSendModelMode outputDirectory (gameMap: GameMap) (pathToModel: string) (useGPU: bool) (optimize: bool) = +let runTrainingSendModelMode outputDirectory (gameMap: GameMap) (pathToModel: string) (useGPU: bool) (optimize: bool) (port: int) = printfn $"Run infer on {gameMap.MapName} have started." let aiTrainingOptions = @@ -360,6 +360,7 @@ let runTrainingSendModelMode outputDirectory (gameMap: GameMap) (pathToModel: st { aiAgentTrainingOptions = aiTrainingOptions outputDirectory = outputDirectory + port = port } ) @@ -473,7 +474,7 @@ let main args = let optimize = (args.TryGetResult <@ Optimize @>).IsSome - runTrainingSendModelMode outputDirectory gameMap model useGPU optimize + runTrainingSendModelMode outputDirectory gameMap model useGPU optimize port | Mode.Generator -> let datasetDescription = args.GetResult <@ DatasetDescription @>