Skip to content

Commit

Permalink
Update beta dist sampling to being vec based
Browse files Browse the repository at this point in the history
  • Loading branch information
hbuurmei committed Oct 24, 2024
1 parent 269a2f5 commit 3002f9f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions stack/main/scripts/control_inputs_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pandas as pd # type: ignore
import matplotlib.pyplot as plt
from itertools import product
from scipy.stats import norm


def save_to_csv(df, control_inputs_file):
Expand Down Expand Up @@ -44,8 +43,8 @@ def uniform_sampling(control_variables):
return control_inputs_df


def beta_sampling(control_variables, sample_size=150):
np.random.seed(0)
def beta_sampling(control_variables, seed, sample_size=150):
np.random.seed(seed)
tip_range, mid_range, base_range = 0.2, 0.3, 0.4

control_inputs_df = pd.DataFrame(columns=['ID'] + control_variables)
Expand All @@ -66,11 +65,19 @@ def beta_sampling(control_variables, sample_size=150):
u3 = (np.random.beta(a, b) - 0.5) * 2 * base_range
u4 = (np.random.beta(a, b) - 0.5) * 2 * base_range

# Compute control input vectors
u1_vec = u1 * np.array([-np.cos(15 * np.pi/180), np.sin(15 * np.pi/180)])
u2_vec = u2 * np.array([np.cos(45 * np.pi/180), np.sin(45 * np.pi/180)])
u3_vec = u3 * np.array([-np.cos(15 * np.pi/180), -np.sin(15 * np.pi/180)])
u4_vec = u4 * np.array([-np.cos(75 * np.pi/180), np.sin(75 * np.pi/180)])
u5_vec = u5 * np.array([np.cos(45 * np.pi/180), -np.sin(45 * np.pi/180)])
u6_vec = u6 * np.array([-np.cos(75 * np.pi/180), -np.sin(75 * np.pi/180)])

# Calculate the norm based on the constraint
vector_sum = (
0.75 * (u3 + u4) +
1.0 * (u2 + u5) +
1.25 * (u1 + u6)
0.75 * (u3_vec + u4_vec) +
1.0 * (u2_vec + u5_vec) +
1.25 * (u1_vec + u6_vec)
)
norm_value = np.linalg.norm(vector_sum)

Expand Down Expand Up @@ -102,17 +109,20 @@ def visualize_samples(control_inputs_df):
plt.show()


def main(data_type='dynamic', sampling_type='uniform'):
def main(data_type='dynamic', sampling_type='uniform', seed=None):
control_variables = ['u1', 'u2', 'u3', 'u4', 'u5', 'u6']
data_dir = os.getenv('TRUNK_DATA', '/home/trunk/Documents/trunk-stack/stack/main/data')
control_inputs_file = os.path.join(data_dir, f'trajectories/{data_type}/control_inputs_{sampling_type}.csv')
if seed is not None:
control_inputs_file = os.path.join(data_dir, f'trajectories/{data_type}/control_inputs_{sampling_type}_seed{seed}.csv')
else:
control_inputs_file = os.path.join(data_dir, f'trajectories/{data_type}/control_inputs_{sampling_type}.csv')

if sampling_type=='sinusoidal':
control_inputs_df = sinusoidal_sampling(control_variables)
elif sampling_type=='uniform':
control_inputs_df = uniform_sampling(control_variables)
elif sampling_type=='beta':
control_inputs_df = beta_sampling(control_variables)
control_inputs_df = beta_sampling(control_variables, seed)
else:
raise ValueError(f"Invalid sampling_type: {sampling_type}")

Expand All @@ -123,4 +133,5 @@ def main(data_type='dynamic', sampling_type='uniform'):
if __name__ == '__main__':
data_type = 'steady_state' # 'steady_state' or 'dynamic'
sampling_type = 'beta' # 'beta', 'uniform' or 'sinusoidal'
main(data_type, sampling_type)
seed = 1 # choose integer seed number
main(data_type, sampling_type, seed)

0 comments on commit 3002f9f

Please sign in to comment.