diff --git a/Assets/Scripts/Agents/AgentBasic.cs b/Assets/Scripts/Agents/AgentBasic.cs
index 967bd5f..1af165b 100644
--- a/Assets/Scripts/Agents/AgentBasic.cs
+++ b/Assets/Scripts/Agents/AgentBasic.cs
@@ -79,9 +79,8 @@ public class AgentBasic : Agent, IAgent
 
     // protected int Unfrozen = 1;
     internal int Collision = 0;
-
-
-
+
+    public Vector3 startPosition;
 
     private float _originalHeight;
     private float _originalGoalHeight;
@@ -212,6 +211,8 @@ public override void OnEpisodeBegin()
         PreviousPosition = transform.localPosition;
         PreviousVelocity = Vector3.zero;
 
+        startPosition = transform.localPosition;
+
         PreviousPositionPhysics = transform.localPosition;
         PreviouserPositionPhysics = transform.localPosition;
         // PreviousVelocityPhysics = Vector3.zero;
@@ -246,7 +247,9 @@ public override void OnEpisodeBegin()
             ["r_speedmatch"] = 0f,
             ["r_speeding"] = 0f,
             ["r_velocity"] = 0f,
-            ["r_expVelocity"] = 0f
+            ["r_expVelocity"] = 0f,
+            ["r_final"] = 0f,
+            ["r_avgFinal"] = 0f,
         };
 
         UpdateParams();
diff --git a/Assets/Scripts/MLUtils.cs b/Assets/Scripts/MLUtils.cs
index 3c82119..c212c0d 100644
--- a/Assets/Scripts/MLUtils.cs
+++ b/Assets/Scripts/MLUtils.cs
@@ -2,6 +2,7 @@
 using System.Collections.Generic;
 using System.Diagnostics.Contracts;
 using System.Linq;
+using Managers;
 using UnityEngine;
 using Random = UnityEngine.Random;
 
@@ -249,4 +250,20 @@ public static float EnergyHeuristic(Vector3 position, Vector3 target, float e_s,
 
         // return e_s * time + e_w * speed * speed * time;
     }
+
+    public static float AverageEnergyHeuristic(Vector3 position, Vector3 target, Vector3 startPosition, float e_s, float e_w)
+    {
+        var finalDistance = FlatDistance(position, target); // d'
+        var totalDistance = FlatDistance(startPosition, target); // d
+        var timeLimit = Manager.Instance.maxStep * Manager.Instance.DecisionDeltaTime; // T0
+
+
+        var avgSpeed = (totalDistance - finalDistance) / timeLimit; // v'
+
+        var remainingTime = finalDistance / avgSpeed; // T'
+
+        var energy = e_s * remainingTime + e_w * avgSpeed * avgSpeed * remainingTime;
+
+        return energy;
+    }
 }
diff --git a/Assets/Scripts/Managers/Manager.cs b/Assets/Scripts/Managers/Manager.cs
index dc834b0..6522f35 100644
--- a/Assets/Scripts/Managers/Manager.cs
+++ b/Assets/Scripts/Managers/Manager.cs
@@ -408,6 +408,8 @@ private Dictionary<string, float> GetEpisodeStats()
         var energiesComplex = new List<float>();
         var energiesPlus = new List<float>();
         var energiesComplexPlus = new List<float>();
+        var energiesPlusAvg = new List<float>();
+        var energiesComplexPlusAvg = new List<float>();
         var distances = new List<float>();
         var successes = new List<float>();
         var numAgents = 0;
@@ -421,23 +423,33 @@ private Dictionary<string, float> GetEpisodeStats()
 
             // var finalEnergy = 2 * Mathf.Sqrt(agent.e_s * agent.e_w * finalDistance);
-            var finalEnergy = MLUtils.EnergyHeuristic(agent.transform.localPosition, agent.Goal.localPosition,
+            var localPosition = agent.transform.localPosition;
+            var goalPosition = agent.Goal.localPosition;
+            var finalEnergy = MLUtils.EnergyHeuristic(localPosition, goalPosition,
+                agent.e_s, agent.e_w);
+
+            var finalEnergyAvg = MLUtils.AverageEnergyHeuristic(localPosition, goalPosition, agent.startPosition,
                 agent.e_s, agent.e_w);
 
             energiesPlus.Add(agent.energySpent + finalEnergy);
             energiesComplexPlus.Add(agent.energySpentComplex + finalEnergy);
+            energiesPlusAvg.Add(agent.energySpent + finalEnergyAvg);
+            energiesComplexPlusAvg.Add(agent.energySpentComplex + finalEnergyAvg);
+
 
             distances.Add(agent.distanceTraversed);
             successes.Add(agent.CollectedGoal ? 1f : 0f);
             numAgents++;
         }
-        Debug.Log($"NumAgents detected in EpisodeStats: {numAgents}");
+        // Debug.Log($"NumAgents detected in EpisodeStats: {numAgents}");
 
         var stats = new Dictionary<string, float>
         {
             ["e_energy"] = energies.Average(),
             ["e_energy_complex"] = energiesComplex.Average(),
             ["e_energy_plus"] = energiesPlus.Average(),
             ["e_energy_complex_plus"] = energiesComplexPlus.Average(),
+            ["e_energy_plus_avg"] = energiesPlusAvg.Average(),
+            ["e_energy_complex_plus_avg"] = energiesComplexPlusAvg.Average(),
             ["e_distance"] = distances.Average(),
             ["e_success"] = successes.Average(),
         };
diff --git a/Assets/Scripts/Params.cs b/Assets/Scripts/Params.cs
index 87decb1..532becd 100644
--- a/Assets/Scripts/Params.cs
+++ b/Assets/Scripts/Params.cs
@@ -132,6 +132,9 @@ private void Awake()
     public float rewardFinal = 1f;
     public static float RewFinal => Get("r_final", Instance.rewardFinal);
 
+    public float rewardAvgFinal = 1f;
+    public static float RewAvgFinal => Get("r_avg_final", Instance.rewardAvgFinal);
+
 
 
 
diff --git a/Assets/Scripts/Rewards/DecisionRewarder.cs b/Assets/Scripts/Rewards/DecisionRewarder.cs
index d6aad90..7a28292 100644
--- a/Assets/Scripts/Rewards/DecisionRewarder.cs
+++ b/Assets/Scripts/Rewards/DecisionRewarder.cs
@@ -158,8 +158,14 @@ public float FinishReward(Transform transform, bool success)
 
             // var finalReward = -2 * Mathf.Sqrt(agent.e_s * agent.e_w) * finalDistance;
 
             reward += Params.RewFinal * penalty;
-            agent.AddRewardPart(penalty, "final");
+            agent.AddRewardPart(penalty, "r_final");
+
+            var avgPenalty = -MLUtils.AverageEnergyHeuristic(transform.localPosition, agent.Goal.localPosition, agent.startPosition, agent.e_s, agent.e_w);
+
+
+            reward += Params.RewAvgFinal * avgPenalty;
+            agent.AddRewardPart(avgPenalty, "r_avgFinal");
 
             // TODO: Instead of assuming the optimal velocity, use the average velocity across the trajectory so far
             // TODO: Track both of them as a metric, but add a switch to choose which one to use for the reward
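Note (not part of the patch): a minimal standalone sketch of the arithmetic behind the new MLUtils.AverageEnergyHeuristic, with plain floats standing in for FlatDistance and the Manager singleton; the distances, time limit, and energy coefficients below are made-up example values. It also exercises the unguarded edge case where the agent has made no progress (d' == d), so v' is zero and the float division yields a non-finite result.

using System;

internal static class AverageEnergySketch
{
    // Mirror of the heuristic: assume the agent keeps moving at the average
    // speed it has sustained so far, v' = (d - d') / T0, and charge e_s per
    // second plus e_w * v'^2 per second over the remaining time T' = d' / v'.
    private static float Estimate(float finalDistance, float totalDistance,
                                  float timeLimit, float e_s, float e_w)
    {
        var avgSpeed = (totalDistance - finalDistance) / timeLimit; // v'
        var remainingTime = finalDistance / avgSpeed;               // T'
        return e_s * remainingTime + e_w * avgSpeed * avgSpeed * remainingTime;
    }

    private static void Main()
    {
        // Started 10 m from the goal, now 4 m away, 20 s episode budget:
        // v' = 0.3 m/s, T' = 13.33 s, energy = 13.33 * (1 + 0.5 * 0.09) = 13.93.
        Console.WriteLine(Estimate(4f, 10f, 20f, e_s: 1f, e_w: 0.5f));

        // No progress: v' = 0, so T' = d' / 0 and the estimate comes out NaN.
        // An agent that ends farther from the goal than it started gets a
        // negative estimate, which flips the sign of avgPenalty in FinishReward.
        Console.WriteLine(Estimate(10f, 10f, 20f, e_s: 1f, e_w: 0.5f));
    }
}

Guarding avgSpeed against zero or negative values (for instance, falling back to EnergyHeuristic in those cases) may be worth considering before Params.RewAvgFinal is turned on.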