forked from tensorflow/tfjs-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cart_pole.js
126 lines (113 loc) · 3.95 KB
/
cart_pole.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Implementation based on: http://incompleteideas.net/book/code/pole.c
*/
import * as tf from '@tensorflow/tfjs';
/**
* Cart-pole system simulator.
*
* In the control-theory sense, there are four state variables in this system:
*
* - x: The 1D location of the cart.
* - xDot: The velocity of the cart.
* - theta: The angle of the pole (in radians). A value of 0 corresponds to
* a vertical position.
* - thetaDot: The angular velocity of the pole.
*
* The system is controlled through a single action:
*
* - leftward or rightward force.
*/
export class CartPole {
  /**
   * Constructor of CartPole.
   *
   * Sets the physical constants and the failure thresholds, then places the
   * system in a random initial state via `setRandomState()`.
   */
  constructor() {
    // Constants that characterize the system.
    this.gravity = 9.8;
    this.massCart = 1.0;
    this.massPole = 0.1;
    this.totalMass = this.massCart + this.massPole;
    // NOTE(review): cartWidth/cartHeight are not read by any method of this
    // class; presumably they are consumed by rendering code elsewhere.
    this.cartWidth = 0.2;
    this.cartHeight = 0.1;
    this.length = 0.5;
    this.poleMoment = this.massPole * this.length;
    this.forceMag = 10.0;
    this.tau = 0.02;  // Seconds between state updates.

    // Threshold values, beyond which a simulation will be marked as failed.
    this.xThreshold = 2.4;
    this.thetaThreshold = 12 / 360 * 2 * Math.PI;  // 12 degrees, in radians.

    this.setRandomState();
  }

  /**
   * Set the state of the cart-pole system randomly.
   *
   * The generator is invoked exactly four times, in the order
   * `x`, `xDot`, `theta`, `thetaDot`.
   *
   * @param {function(): number} [random=Math.random] Optional source of
   *   randomness, expected to return values in [0, 1) like `Math.random`.
   *   Supplying a seeded generator makes the initial state reproducible.
   *   Any non-function argument is ignored and `Math.random` is used, so
   *   existing zero-argument callers behave exactly as before.
   */
  setRandomState(random = Math.random) {
    const rng = typeof random === 'function' ? random : Math.random;
    // Largest initial pole deflection: 6 degrees, expressed in radians.
    const maxInitialTheta = 6 / 360 * 2 * Math.PI;

    // The control-theory state variables of the cart-pole system.
    // Cart position, meters: uniform in [-0.5, 0.5).
    this.x = rng() - 0.5;
    // Cart velocity: uniform in [-0.5, 0.5).
    this.xDot = rng() - 0.5;
    // Pole angle, radians: uniform in [-maxInitialTheta, maxInitialTheta).
    this.theta = (rng() - 0.5) * 2 * maxInitialTheta;
    // Pole angle velocity: uniform in [-0.25, 0.25).
    this.thetaDot = (rng() - 0.5) * 0.5;
  }

  /**
   * Get current state as a tf.Tensor of shape [1, 4].
   *
   * @returns {tf.Tensor} A newly created tensor holding
   *   `[[x, xDot, theta, thetaDot]]`.
   */
  getStateTensor() {
    return tf.tensor2d([[this.x, this.xDot, this.theta, this.thetaDot]]);
  }

  /**
   * Update the cart-pole system using an action.
   *
   * Advances the simulation by one time step of `tau` seconds.
   *
   * @param {number} action Only the sign of `action` matters.
   *   A value > 0 leads to a rightward force of a fixed magnitude.
   *   A value <= 0 leads to a leftward force of the same fixed magnitude.
   * @returns {boolean} Whether the simulation is done after this step
   *   (see `isDone()`).
   */
  update(action) {
    const force = action > 0 ? this.forceMag : -this.forceMag;

    const cosTheta = Math.cos(this.theta);
    const sinTheta = Math.sin(this.theta);

    // Term shared by both acceleration formulas below.
    const temp =
        (force + this.poleMoment * this.thetaDot * this.thetaDot * sinTheta) /
        this.totalMass;
    const thetaAcc = (this.gravity * sinTheta - cosTheta * temp) /
        (this.length *
         (4 / 3 - this.massPole * cosTheta * cosTheta / this.totalMass));
    const xAcc = temp - this.poleMoment * thetaAcc * cosTheta / this.totalMass;

    // Update the four state variables, using Euler's method. Positions are
    // advanced with the velocities from *before* this step, so each position
    // must be updated ahead of its corresponding velocity.
    this.x += this.tau * this.xDot;
    this.xDot += this.tau * xAcc;
    this.theta += this.tau * this.thetaDot;
    this.thetaDot += this.tau * thetaAcc;

    return this.isDone();
  }

  /**
   * Determine whether this simulation is done.
   *
   * A simulation is done when `x` (position of the cart) goes out of bound
   * or when `theta` (angle of the pole) goes out of bound. Values exactly on
   * a threshold are still considered in bounds.
   *
   * @returns {boolean} Whether the simulation is done.
   */
  isDone() {
    return this.x < -this.xThreshold || this.x > this.xThreshold ||
        this.theta < -this.thetaThreshold || this.theta > this.thetaThreshold;
  }
}