Commit
Showing 6 changed files with 333 additions and 0 deletions.
@@ -0,0 +1,10 @@
"""Module which runs the desired QAgent."""

import numpy as np

from src.qlearning import small_example_qagent

if __name__ == "__main__":
    qagent = small_example_qagent.SmallExampleQAgent(0.1, 0.9, np.zeros((9,)))
    route = qagent.training(0, 8, 100000)
    print(route)
@@ -0,0 +1,116 @@
"""Module representing a situation where Q-Learning can be used."""

import math
from abc import ABC, abstractmethod

import numpy as np


class QAgent(ABC):
    """Class representing a situation where Q-Learning can be used."""

    @abstractmethod
    def get_playable_actions(self, current_state, differentials, timestep):
        """Returns the states reachable from [current_state] after time
        [timestep] has elapsed. [current_state] is a list of 4 numbers:
        x coordinate, y coordinate, speed, and angle. [differentials] is also
        4 numbers: the differences between adjacent cells of the state matrix
        in SI units."""
        print(current_state, differentials, timestep)
    @abstractmethod
    def get_state_matrix(self):
        """Returns a tuple: the first element is a matrix with dimensions
        x coordinate, y coordinate, speed, and angle; the second element is the
        differences between adjacent cells of the matrix in SI units. The matrix
        shape should be fixed ahead of time based on the occupancy grid
        resolution and other physical factors.
        """

    @abstractmethod
    def set_up_rewards(self, end_state):
        """Creates the reward matrix for the QAgent."""
        print(end_state)

    def get_random_state(self):
        """Selects a random state from all states in the state space."""
        result = []
        for dim in self.q.shape:
            result.append(np.random.randint(0, high=dim))
        return tuple(result)
    def get_rewards(self, occupancy_grid, distance):
        """Returns the reward matrix given an occupancy grid and the current
        distance to the goal."""
        constant_a = 1
        constant_b = 4
        constant_c = 2

        # The reward should increase as we approach the goal and
        # decrease as the probability of encountering an object increases.
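        # A worked example of the formula below (my reading of it): at distance
        # 0 in a free cell (occupancy 0), the reward is 1 / sqrt(0 + 4) = 0.5;
        # it decays toward 0 as the distance grows, and it turns negative once
        # the occupancy probability exceeds 1 / constant_c = 0.5.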
        rewards = (constant_a / math.sqrt((distance**2) + constant_b)) * (
            1 - (constant_c * occupancy_grid)
        )

        return rewards
    def qlearning(self, rewards_new, iterations, end_state):
        """Fill in the Q-matrix."""
        for _ in range(iterations):
            current_state = self.get_random_state()
            if current_state == end_state:
                continue
            playable_actions = self.get_playable_actions(
                current_state, self.differentials, self.dt
            )
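            # Temporal-difference update: nudge Q(s) toward the immediate
            # reward plus the discounted best Q-value among reachable states.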
            temporal_difference = (
                rewards_new[current_state]
                + self.gamma * np.amax(self.q[playable_actions])
                - self.q[current_state]
            )
            self.q[current_state] += self.alpha * temporal_difference
    def reset_matrix(self, rewards_new, iterations, end_state, dimensions):
        """Reset the Q-matrix, and rerun the Q-learning algorithm."""
        shape = tuple([len(self.q)] * dimensions)
        self.q: np.ndarray[np.float_, np.dtype[np.float_]] = np.zeros(shape)
        self.qlearning(rewards_new, iterations, end_state)

    def alter_matrix(self, rewards_new, iterations, end_state, scale):
        """Scale the new reward matrix by [scale], keeping the previous
        Q-matrix as a starting place, and rerun the Q-learning algorithm."""
        rewards_new = rewards_new * scale
        self.qlearning(rewards_new, iterations, end_state)
    def get_optimal_route(self, start_state, end_state):
        """Given a Q-matrix, greedily select the highest-valued next state
        until the end state is reached. This assumes that greedily choosing the
        next state will eventually reach the finish, which is not always true.
        This is used as a visualization, not as a step in the navigation
        algorithm."""
        route = [start_state]
        next_state = start_state
        while next_state != end_state:
            playable_actions = self.get_playable_actions(
                next_state, self.differentials, self.dt
            )
            next_state = playable_actions[0][np.argmax(self.q[playable_actions])]
            if next_state in route:
                # Revisiting a state means the greedy walk is in a cycle; stop.
                route.append(next_state)
                break
            route.append(next_state)

        return route
    def training(self, start_state, end_state, iterations):
        """Run all the steps of the Q-learning algorithm."""
        rewards_new = self.set_up_rewards(end_state)
        self.qlearning(rewards_new, iterations, end_state)
        route = self.get_optimal_route(start_state, end_state)
        return route

    def __init__(self, alpha, gamma, rewards, dt):
        self.gamma = gamma
        self.alpha = alpha
        self.rewards = rewards
        self.q, self.differentials = self.get_state_matrix()
        self.dt = dt
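
The `SmallExampleQAgent` subclass imported by the entry point is not shown in this diff. As a rough illustration of how the abstract methods fit together, here is a minimal sketch of a nine-state line world matching the `SmallExampleQAgent(0.1, 0.9, np.zeros((9,)))` constructor call and the `training(0, 8, 100000)` run; every name and reward value in it is an assumption, not the repository's actual implementation.

"""Hypothetical sketch of a minimal QAgent subclass (assumed, not this commit's file)."""

import numpy as np

from src.qlearning.qagent import QAgent  # assumed module path


class SmallExampleQAgent(QAgent):
    """Nine states in a line; the agent can stay put or step to a neighbor."""

    def __init__(self, alpha, gamma, rewards):
        # dt is unused in this toy world, so any placeholder value works.
        super().__init__(alpha, gamma, rewards, dt=1)

    def get_state_matrix(self):
        # One Q-value per state, with unit spacing between adjacent states.
        return np.zeros((9,)), (1,)

    def get_playable_actions(self, current_state, differentials, timestep):
        # get_random_state passes an index tuple; get_optimal_route passes ints.
        state = current_state[0] if isinstance(current_state, tuple) else int(current_state)
        neighbors = [s for s in (state - 1, state, state + 1) if 0 <= s < 9]
        return (np.array(neighbors),)

    def set_up_rewards(self, end_state):
        # Reward only the goal state.
        rewards = np.zeros((9,))
        rewards[end_state] = 1.0
        return rewards

Under that reading, `training(0, 8, 100000)` would return a route walking greedily from state 0 toward state 8.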
@@ -0,0 +1,33 @@
"""This module contains simulation constants/settings."""

import numpy as np

# PHYSICAL PROPERTIES

DIST_REAR_AXEL = 0.85  # m
DIST_FRONT_AXEL = 0.85  # m
WHEEL_BASE = DIST_REAR_AXEL + DIST_FRONT_AXEL

MASS = 15  # kg

CORNERING_STIFF_FRONT = 16.0 * 2.0  # N/rad
CORNERING_STIFF_REAR = 17.0 * 2.0  # N/rad

YAW_INERTIA = 22  # kg*(m^2)

# INPUT CONSTRAINTS

SPEED_MIN = 2  # m/s
SPEED_MAX = 10  # m/s

STEER_ANGLE_MAX = np.radians(37)  # rad
STEER_ANGLE_DELTA = np.radians(1)  # rad

ACCELERATION_MIN = -5  # m/(s^2)
ACCELERATION_MAX = 10  # m/(s^2)

STEER_ACC_MIN = -5
STEER_ACC_MAX = 5

DT = 0.1  # s
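
To make the intent of the INPUT CONSTRAINTS block concrete, here is a small illustrative helper (my own sketch, not code from this commit) that clips a proposed control input to these bounds:

"""Illustrative use of the constraint constants (assumed helper, not in this commit)."""

import numpy as np

from src.state_pred import constants as cst


def clamp_input(speed, steer_angle):
    """Clip a proposed (speed, steering angle) pair to the bike's physical limits."""
    speed = float(np.clip(speed, cst.SPEED_MIN, cst.SPEED_MAX))
    steer_angle = float(np.clip(steer_angle, -cst.STEER_ANGLE_MAX, cst.STEER_ANGLE_MAX))
    return speed, steer_angle


# 12 m/s and 45 degrees exceed the limits, so both are clipped:
print(clamp_input(12.0, np.radians(45)))  # (10.0, 0.6457...) i.e. 37 degrees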
@@ -0,0 +1,153 @@
"""
State prediction visualization
"""

# For visualization
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import patches
from scipy.stats import gaussian_kde

from src.state_pred import constants as cst

# mpl.use('TkAgg')


# Colors
DARK = '#0B0B0B'
LIGHT = '#C9C9C9'
RED = '#ff3333'
PURPLE = '#9933ff'
ORANGE = '#ff9933'
BLUE = '#33ccff'

fig, ax = plt.subplots(1, 1)

# Configure figure appearance
ax.set_facecolor(DARK)
fig.patch.set_facecolor(DARK)
for spine in ['bottom', 'top', 'left', 'right']:
    ax.spines[spine].set_color(LIGHT)
ax.xaxis.label.set_color(LIGHT)
ax.yaxis.label.set_color(LIGHT)
ax.tick_params(axis='x', colors=LIGHT)
ax.tick_params(axis='y', colors=LIGHT)
ax.grid(linestyle='dashed')
ax.set_aspect('equal', adjustable='box')

# major_spacing = 0.5
# ax.set_xticks(np.arange(-2,5,major_spacing))
# ax.set_xticks(np.arange(-2,5,major_spacing))

# minor_spacing = major_spacing/DIFFERENTIALS[0]
# minor_locator = minor(minor_spacing)
# ax.xaxis.set_minor_locator(minor_locator)

# plt.grid(which="both")
def plot_bike(state):
    """
    Plot a representation of the bike
    """

    print("Plotting bike...")

    x, y, vel_x, vel_y, yaw_angle, steer_angle = state

    wheel_width = 0.1
    wheel_length = 0.6

    rear_axel_x = x - cst.DIST_REAR_AXEL * np.cos(yaw_angle)
    rear_axel_y = y - cst.DIST_REAR_AXEL * np.sin(yaw_angle)

    front_axel_x = x + cst.DIST_FRONT_AXEL * np.cos(yaw_angle)
    front_axel_y = y + cst.DIST_FRONT_AXEL * np.sin(yaw_angle)

    # Plot center line
    ax.plot(
        [rear_axel_x, front_axel_x],
        [rear_axel_y, front_axel_y],
        marker='o',
        color="white",
    )

    # Why neg steering?
    rear_wheel = draw_wheel(
        rear_axel_x, rear_axel_y, wheel_width, wheel_length, RED, yaw_angle
    )
    front_wheel = draw_wheel(
        front_axel_x,
        front_axel_y,
        wheel_width,
        wheel_length,
        BLUE,
        yaw_angle + steer_angle,
    )

    ax.add_patch(rear_wheel)
    ax.add_patch(front_wheel)

    ax.plot([x, x + vel_x], [y, y + vel_y], marker='>', color=ORANGE)

    ax.set_xlim(-2, 5)
    ax.set_ylim(-2, 5)
# pylint: disable=too-many-arguments
def draw_wheel(center_x, center_y, width, length, color, rotation):
    """
    Draw a wheel helper method
    """

    rect = patches.Rectangle(
        (center_x - length / 2, center_y - width / 2),
        length,
        width,
        color=color,
        alpha=1,
    )
    # Rotate the wheel rectangle about its center in data coordinates.
    transform = (
        mpl.transforms.Affine2D().rotate_around(center_x, center_y, rotation)
        + ax.transData
    )
    rect.set_transform(transform)

    return rect
def plot_states(states):
    """
    Plot a list of states
    """

    print("Plotting states...")

    x = list(state[0] for state in states)
    y = list(state[1] for state in states)

    # Color each point by an estimate of the local point density.
    xy = np.vstack([x, y])
    z = gaussian_kde(xy)(xy)

    ax.scatter(x, y, c=z, s=100, alpha=1)
def plot_invalid(states):
    """
    Plot a list of invalid states in red
    """

    x = list(state[0] for state in states)
    y = list(state[1] for state in states)

    ax.scatter(x, y, c=RED, s=100, alpha=0.05)
def show_plot():
    """
    Show the plot. Saving it locally is a workaround until window forwarding
    with Docker is figured out.
    """
    # plt.savefig("./state_pred/state_prediction.png", format="png")
    plt.show()
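
For reference, a state here is the 6-tuple (x, y, vel_x, vel_y, yaw_angle, steer_angle) unpacked in plot_bike. A tiny driver along these lines would exercise the plotting functions; this is an assumed usage sketch, and the visualization module's name is a guess since this diff does not show the file path:

"""Assumed driver for the visualization module (illustrative only)."""

import numpy as np

# Module name is a guess; the commit does not show this file's path.
from src.state_pred import visualization as viz


def main():
    # A bike at the origin, heading along +x at 1 m/s with a slight left steer.
    viz.plot_bike((0.0, 0.0, 1.0, 0.0, 0.0, np.radians(10)))

    # A cloud of predicted states near (1, 1); plot_states colors them by density.
    rng = np.random.default_rng(0)
    states = [
        (1 + rng.normal(0, 0.2), 1 + rng.normal(0, 0.2), 0, 0, 0, 0)
        for _ in range(200)
    ]
    viz.plot_states(states)

    viz.show_plot()


if __name__ == "__main__":
    main()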
File renamed without changes.