diff --git a/VSharp.API/VSharp.cs b/VSharp.API/VSharp.cs index 4875b59f2..590caba43 100644 --- a/VSharp.API/VSharp.cs +++ b/VSharp.API/VSharp.cs @@ -3,11 +3,9 @@ using System.Collections.Generic; using System.IO; using System.Linq; -using System.Net.WebSockets; using System.Reflection; using System.Text; using VSharp.CSharpUtils; -using VSharp.IL; using VSharp.Interpreter.IL; using VSharp.Explorer; @@ -187,8 +185,8 @@ private static Statistics StartExploration( stopOnCoverageAchieved: 100, randomSeed: options.RandomSeed, stepsLimit: options.StepsLimit, - oracle: options.Oracle, - aiAgentTrainingOptions: options.AIAgentTrainingOptions + aiAgentTrainingOptions: options.AIAgentTrainingOptions, + pathToModel: options.PathToModel ); var fuzzerOptions = diff --git a/VSharp.API/VSharpOptions.cs b/VSharp.API/VSharpOptions.cs index 2433babad..2528a2deb 100644 --- a/VSharp.API/VSharpOptions.cs +++ b/VSharp.API/VSharpOptions.cs @@ -111,8 +111,8 @@ public readonly record struct VSharpOptions public readonly bool ReleaseBranches = DefaultReleaseBranches; public readonly int RandomSeed = DefaultRandomSeed; public readonly uint StepsLimit = DefaultStepsLimit; - public readonly Oracle? Oracle = null; public readonly AIAgentTrainingOptions AIAgentTrainingOptions = null; + public readonly string PathToModel = null; /// /// Symbolic virtual machine options. @@ -129,6 +129,8 @@ public readonly record struct VSharpOptions /// If true and timeout is specified, a part of allotted time in the end is given to execute remaining states without branching. /// Fixed seed for random operations. Used if greater than or equal to zero. /// Number of symbolic machine steps to stop execution after. Zero value means no limit. + /// Settings for AI searcher training. + /// Path to ONNX file with model to use in AI searcher. public VSharpOptions( int timeout = DefaultTimeout, int solverTimeout = DefaultSolverTimeout, @@ -142,8 +144,8 @@ public VSharpOptions( bool releaseBranches = DefaultReleaseBranches, int randomSeed = DefaultRandomSeed, uint stepsLimit = DefaultStepsLimit, - Oracle? oracle = null, - AIAgentTrainingOptions aiAgentTrainingOptions = null) + AIAgentTrainingOptions aiAgentTrainingOptions = null, + string pathToModel = null) { Timeout = timeout; SolverTimeout = solverTimeout; @@ -157,8 +159,8 @@ public VSharpOptions( ReleaseBranches = releaseBranches; RandomSeed = randomSeed; StepsLimit = stepsLimit; - Oracle = oracle; AIAgentTrainingOptions = aiAgentTrainingOptions; + PathToModel = pathToModel; } /// diff --git a/VSharp.Explorer/Explorer.fs b/VSharp.Explorer/Explorer.fs index b4fe5abbb..2f48706a2 100644 --- a/VSharp.Explorer/Explorer.fs +++ b/VSharp.Explorer/Explorer.fs @@ -80,7 +80,16 @@ type private SVMExplorer(explorationOptions: ExplorationOptions, statistics: SVM let rec mkForwardSearcher mode = let getRandomSeedOption() = if options.randomSeed < 0 then None else Some options.randomSeed match mode with - | AIMode -> AISearcher(options.oracle.Value, options.aiAgentTrainingOptions) :> IForwardSearcher + | AIMode -> + match options.aiAgentTrainingOptions with + | Some aiOptions -> + match aiOptions.oracle with + | Some oracle -> AISearcher(oracle, options.aiAgentTrainingOptions) :> IForwardSearcher + | None -> failwith "Empty oracle for AI searcher." + | None -> + match options.pathToModel with + | Some s -> AISearcher s + | None -> failwith "Empty model for AI searcher." | BFSMode -> BFSSearcher() :> IForwardSearcher | DFSMode -> DFSSearcher() :> IForwardSearcher | ShortestDistanceBasedMode -> ShortestDistanceBasedSearcher statistics :> IForwardSearcher diff --git a/VSharp.Explorer/Options.fs b/VSharp.Explorer/Options.fs index 233e0be37..e78bac048 100644 --- a/VSharp.Explorer/Options.fs +++ b/VSharp.Explorer/Options.fs @@ -46,13 +46,15 @@ type Oracle = /// Default searcher that will be used to play few initial steps. /// Determine whether steps should be serialized. /// Name of map to play. +/// Name of map to play. type AIAgentTrainingOptions = { stepsToSwitchToAI: uint stepsToPlay: uint defaultSearchStrategy: searchMode serializeSteps: bool - mapName: string + mapName: string + oracle: Option } type SVMOptions = { @@ -67,8 +69,8 @@ type SVMOptions = { stopOnCoverageAchieved : int randomSeed : int stepsLimit : uint - oracle: Option aiAgentTrainingOptions: Option + pathToModel: Option } type explorationModeOptions = diff --git a/VSharp.ML.GameServer.Runner/Main.fs b/VSharp.ML.GameServer.Runner/Main.fs index 2e3f76e83..43f4ca0d8 100644 --- a/VSharp.ML.GameServer.Runner/Main.fs +++ b/VSharp.ML.GameServer.Runner/Main.fs @@ -156,8 +156,9 @@ let ws outputDirectory (webSocket : WebSocket) (context: HttpContext) = | x -> failwithf $"Unexpected searcher {x}. Use DFS or BFS for now." serializeSteps = false mapName = gameMap.MapName + oracle = Some oracle } - let options = VSharpOptions(timeout = 15 * 60, outputDirectory = outputDirectory, oracle = oracle, searchStrategy = SearchStrategy.AI, aiAgentTrainingOptions = aiTrainingOptions, solverTimeout=2) + let options = VSharpOptions(timeout = 15 * 60, outputDirectory = outputDirectory, searchStrategy = SearchStrategy.AI, aiAgentTrainingOptions = aiTrainingOptions, solverTimeout=2) let statistics = TestGenerator.Cover(method, options) let actualCoverage = try @@ -206,6 +207,7 @@ let generateDataForPretraining outputDirectory datasetBasePath (maps:Dictionary< defaultSearchStrategy = searchMode.BFSMode serializeSteps = true mapName = kvp.Value.MapName + oracle = None } let options = VSharpOptions(timeout = 5 * 60, outputDirectory = outputDirectory, searchStrategy = SearchStrategy.ExecutionTreeContributedCoverage, stepsLimit = stepsToSerialize, solverTimeout=2, aiAgentTrainingOptions = aiTrainingOptions) let statistics = TestGenerator.Cover(method, options) diff --git a/VSharp.Runner/RunnerProgram.cs b/VSharp.Runner/RunnerProgram.cs index 2729cf74d..46ab18b9d 100644 --- a/VSharp.Runner/RunnerProgram.cs +++ b/VSharp.Runner/RunnerProgram.cs @@ -134,6 +134,10 @@ public static int Main(string[] args) aliases: new[] { "--timeout", "-t" }, () => -1, "Time for test generation in seconds. Negative value means no timeout."); + var pathToModelOption = new Option( + aliases: new[] { "--model", "-m" }, + () => null, + "Path to ONNX file with model for AI searcher."); var solverTimeoutOption = new Option( aliases: new[] { "--solver-timeout", "-st" }, () => -1, @@ -178,6 +182,7 @@ public static int Main(string[] args) entryPointCommand.AddArgument(assemblyPathArgument); entryPointCommand.AddArgument(concreteArguments); entryPointCommand.AddGlobalOption(timeoutOption); + entryPointCommand.AddGlobalOption(pathToModelOption); entryPointCommand.AddGlobalOption(solverTimeoutOption); entryPointCommand.AddGlobalOption(outputOption); entryPointCommand.AddOption(unknownArgsOption); @@ -192,6 +197,7 @@ public static int Main(string[] args) rootCommand.AddCommand(allPublicMethodsCommand); allPublicMethodsCommand.AddArgument(assemblyPathArgument); allPublicMethodsCommand.AddGlobalOption(timeoutOption); + allPublicMethodsCommand.AddGlobalOption(pathToModelOption); allPublicMethodsCommand.AddGlobalOption(solverTimeoutOption); allPublicMethodsCommand.AddGlobalOption(outputOption); allPublicMethodsCommand.AddGlobalOption(renderTestsOption); @@ -208,6 +214,7 @@ public static int Main(string[] args) publicMethodsOfClassCommand.AddArgument(classArgument); publicMethodsOfClassCommand.AddArgument(assemblyPathArgument); publicMethodsOfClassCommand.AddGlobalOption(timeoutOption); + publicMethodsOfClassCommand.AddGlobalOption(pathToModelOption); publicMethodsOfClassCommand.AddGlobalOption(solverTimeoutOption); publicMethodsOfClassCommand.AddGlobalOption(outputOption); publicMethodsOfClassCommand.AddGlobalOption(renderTestsOption); @@ -223,6 +230,7 @@ public static int Main(string[] args) specificMethodCommand.AddArgument(methodArgument); specificMethodCommand.AddArgument(assemblyPathArgument); specificMethodCommand.AddGlobalOption(timeoutOption); + specificMethodCommand.AddGlobalOption(pathToModelOption); specificMethodCommand.AddGlobalOption(solverTimeoutOption); specificMethodCommand.AddGlobalOption(outputOption); specificMethodCommand.AddGlobalOption(renderTestsOption); @@ -238,6 +246,7 @@ public static int Main(string[] args) namespaceCommand.AddArgument(namespaceArgument); namespaceCommand.AddArgument(assemblyPathArgument); namespaceCommand.AddGlobalOption(timeoutOption); + namespaceCommand.AddGlobalOption(pathToModelOption); namespaceCommand.AddGlobalOption(solverTimeoutOption); namespaceCommand.AddGlobalOption(outputOption); namespaceCommand.AddGlobalOption(renderTestsOption); @@ -249,8 +258,8 @@ public static int Main(string[] args) rootCommand.Description = "Symbolic execution engine for .NET"; - entryPointCommand.Handler = CommandHandler.Create( - (assemblyPath, args, timeout, solverTimeout, output, unknownArgs, renderTests, runTests, strat, verbosity, recursionThreshold, explorationMode) => + entryPointCommand.Handler = CommandHandler.Create( + (assemblyPath, args, timeout, pathToModel, solverTimeout, output, unknownArgs, renderTests, runTests, strat, verbosity, recursionThreshold, explorationMode) => { var assembly = TryLoadAssembly(assemblyPath); var inputArgs = unknownArgs ? null : args; @@ -263,7 +272,8 @@ public static int Main(string[] args) searchStrategy: strat, verbosity: verbosity, recursionThreshold: recursionThreshold, - explorationMode: explorationMode); + explorationMode: explorationMode, + pathToModel: pathToModel); if (assembly == null) return; @@ -275,8 +285,8 @@ public static int Main(string[] args) else PostProcess(TestGenerator.Cover(assembly, inputArgs, options)); }); - allPublicMethodsCommand.Handler = CommandHandler.Create( - (assemblyPath, timeout, solverTimeout, output, renderTests, runTests, singleFile, strat, verbosity, recursionThreshold, explorationMode) => + allPublicMethodsCommand.Handler = CommandHandler.Create( + (assemblyPath, timeout, pathToModel, solverTimeout, output, renderTests, runTests, singleFile, strat, verbosity, recursionThreshold, explorationMode) => { var assembly = TryLoadAssembly(assemblyPath); var options = @@ -288,7 +298,8 @@ public static int Main(string[] args) searchStrategy: strat, verbosity: verbosity, recursionThreshold: recursionThreshold, - explorationMode: explorationMode); + explorationMode: explorationMode, + pathToModel: pathToModel); if (assembly == null) return; @@ -300,8 +311,8 @@ public static int Main(string[] args) else PostProcess(TestGenerator.Cover(assembly, options)); }); - publicMethodsOfClassCommand.Handler = CommandHandler.Create( - (className, assemblyPath, timeout, solverTimeout, output, renderTests, runTests, strat, verbosity, recursionThreshold, explorationMode) => + publicMethodsOfClassCommand.Handler = CommandHandler.Create( + (className, assemblyPath, timeout, pathToModel, solverTimeout, output, renderTests, runTests, strat, verbosity, recursionThreshold, explorationMode) => { var assembly = TryLoadAssembly(assemblyPath); if (assembly == null) return; @@ -322,7 +333,8 @@ public static int Main(string[] args) searchStrategy: strat, verbosity: verbosity, recursionThreshold: recursionThreshold, - explorationMode: explorationMode); + explorationMode: explorationMode, + pathToModel: pathToModel); if (runTests) { @@ -332,8 +344,8 @@ public static int Main(string[] args) else PostProcess(TestGenerator.Cover(type, options)); }); - specificMethodCommand.Handler = CommandHandler.Create( - (methodName, assemblyPath, timeout, solverTimeout, output, renderTests, runTests, strat, verbosity, recursionThreshold, explorationMode) => + specificMethodCommand.Handler = CommandHandler.Create( + (methodName, assemblyPath, timeout, pathToModel, solverTimeout, output, renderTests, runTests, strat, verbosity, recursionThreshold, explorationMode) => { var assembly = TryLoadAssembly(assemblyPath); if (assembly == null) return; @@ -370,7 +382,8 @@ public static int Main(string[] args) searchStrategy: strat, verbosity: verbosity, recursionThreshold: recursionThreshold, - explorationMode: explorationMode); + explorationMode: explorationMode, + pathToModel: pathToModel); if (runTests) { @@ -380,8 +393,8 @@ public static int Main(string[] args) else PostProcess(TestGenerator.Cover(method, options)); }); - namespaceCommand.Handler = CommandHandler.Create( - (namespaceName, assemblyPath, timeout, solverTimeout, output, renderTests, runTests, strat, verbosity, recursionThreshold, explorationMode) => + namespaceCommand.Handler = CommandHandler.Create( + (namespaceName, assemblyPath, timeout, pathToModel, solverTimeout, output, renderTests, runTests, strat, verbosity, recursionThreshold, explorationMode) => { var assembly = TryLoadAssembly(assemblyPath); if (assembly == null) return; @@ -402,7 +415,8 @@ public static int Main(string[] args) searchStrategy: strat, verbosity: verbosity, recursionThreshold: recursionThreshold, - explorationMode: explorationMode); + explorationMode: explorationMode, + pathToModel: pathToModel); if (runTests) { TestGenerator.CoverAndRun(namespaceTypes, out var statistics, options); diff --git a/VSharp.Test/Benchmarks/Benchmarks.cs b/VSharp.Test/Benchmarks/Benchmarks.cs index 60d015282..bb4f7f2a0 100644 --- a/VSharp.Test/Benchmarks/Benchmarks.cs +++ b/VSharp.Test/Benchmarks/Benchmarks.cs @@ -91,8 +91,8 @@ public static BenchmarkResult Run( stopOnCoverageAchieved: -1, randomSeed: randomSeed, stepsLimit: stepsLimit, - oracle:null, - aiAgentTrainingOptions: null + aiAgentTrainingOptions: null, + pathToModel: null ); var fuzzerOptions = new FuzzerOptions( diff --git a/VSharp.Test/IntegrationTests.cs b/VSharp.Test/IntegrationTests.cs index a0fa69566..8216d7aba 100644 --- a/VSharp.Test/IntegrationTests.cs +++ b/VSharp.Test/IntegrationTests.cs @@ -115,13 +115,12 @@ static TestSvmAttribute() private readonly bool _releaseBranches; private readonly bool _checkAttributes; private readonly bool _hasExternMocking; - private readonly string _pathToSerialize; - private readonly bool _serialize; private readonly OsType _supportedOs; private readonly FuzzerIsolation _fuzzerIsolation; private readonly ExplorationMode _explorationMode; private readonly int _randomSeed; private readonly uint _stepsLimit; + private readonly string _pathToModel; public TestSvmAttribute( int expectedCoverage = -1, @@ -139,7 +138,7 @@ public TestSvmAttribute( ExplorationMode explorationMode = ExplorationMode.Sili, int randomSeed = 0, uint stepsLimit = 0, - string serialize = null) + string pathToModel = null) { if (expectedCoverage < 0) _expectedCoverage = null; @@ -160,16 +159,8 @@ public TestSvmAttribute( _fuzzerIsolation = fuzzerIsolation; _explorationMode = explorationMode; _randomSeed = randomSeed; + _pathToModel = pathToModel; _stepsLimit = stepsLimit; - if (serialize == null) - { - _serialize = false; - } - else - { - _pathToSerialize = serialize; - _serialize = true; - } } public virtual TestCommand Wrap(TestCommand command) @@ -190,6 +181,7 @@ public virtual TestCommand Wrap(TestCommand command) _explorationMode, _randomSeed, _stepsLimit, + _pathToModel, _hasExternMocking ); } @@ -214,6 +206,7 @@ private class TestSvmCommand : DelegatingTestCommand private readonly ExplorationMode _explorationMode; private readonly int _randomSeed; private readonly uint _stepsLimit; + private readonly string _pathToModel; private class Reporter: IReporter { @@ -247,6 +240,7 @@ public TestSvmCommand( ExplorationMode explorationMode, int randomSeed, uint stepsLimit, + string pathToModel, bool hasExternMocking) : base(innerCommand) { _baseCoverageZone = coverageZone; @@ -299,6 +293,7 @@ public TestSvmCommand( _explorationMode = explorationMode; _randomSeed = randomSeed; _stepsLimit = stepsLimit; + _pathToModel = pathToModel; } private TestResult IgnoreTest(TestExecutionContext context) @@ -455,8 +450,8 @@ private TestResult Explore(TestExecutionContext context) stopOnCoverageAchieved: _expectedCoverage ?? -1, randomSeed: _randomSeed, stepsLimit: _stepsLimit, - oracle:null, - aiAgentTrainingOptions:null + aiAgentTrainingOptions:null, + pathToModel: _pathToModel ); var fuzzerOptions = new FuzzerOptions(