Hi 👋 ,
Thanks for the repo!
I am currently testing out your implementation of S5. Sadly, I am not very familiar with the S5 architecture.
When I run your code, I get this warning:
~/.local/lib/python3.10/site-packages/jax/_src/lax/lax.py:2652: ComplexWarning: Casting complex values to real discards the imaginary part
  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
The warning originates in the PPO loss computation and is related to the complex parameters of the S5 model.
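One way to pin down exactly which call inside the PPO loss performs the cast is to temporarily turn the warning into an exception, so the traceback shows the call site instead of a single line inside `lax.py`. A minimal sketch, assuming JAX emits NumPy's `ComplexWarning` class (the printed name suggests so); `run_with_complex_warnings_as_errors` and `cast_to_float32` are hypothetical helpers, and the NumPy cast only stands in for the real training step:

```python
import warnings

import numpy as np

try:  # NumPy >= 1.25 exposes the class in numpy.exceptions
    from numpy.exceptions import ComplexWarning
except ImportError:  # older NumPy releases expose it at the top level
    from numpy import ComplexWarning


def run_with_complex_warnings_as_errors(fn, *args, **kwargs):
    """Call fn, turning any ComplexWarning into an exception.

    The resulting traceback shows which call performed the complex -> real
    cast, rather than a one-line warning pointing inside the library.
    """
    with warnings.catch_warnings():
        # Only this warning category is escalated; other warnings are untouched.
        warnings.simplefilter("error", ComplexWarning)
        return fn(*args, **kwargs)


def cast_to_float32(values):
    """Stand-in for the real code: warns when `values` is complex."""
    return np.asarray(values).astype(np.float32)
```

In the actual run, the wrapped callable would be whatever computes the PPO loss and its gradients. If that function is under `jax.jit`, the cast most likely happens while tracing, so the first (tracing) call is the one that needs to be wrapped.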
The command I am running is below.
Is this behavior intended? I tried reading up on the S4 and S5 literature, but the answer was not immediately obvious to me, so I have little intuition for what it means to cast complex parameters to float.
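For some intuition about what the cast does, here is a toy sketch (not the repo's S5 code; `C`, `toy_loss`, `complex_cotangent` and `central_difference` are hypothetical names). A real input passes through a complex parameter before a real-valued loss. The sketch compares the complex cotangent JAX propagates for the complexified input with `jax.grad` taken with respect to the real input, and with a finite difference. In this toy case, only the real part matches the finite difference, because a real variable can only move along the real axis. This does not by itself show that the warning in the S5/PPO code is harmless.

```python
import jax
import jax.numpy as jnp

# Hypothetical complex "parameter" that a real input is routed through.
C = jnp.asarray(0.6 + 0.8j, dtype=jnp.complex64)


def toy_loss(x):
    """Real-valued loss of a real input passed through a complex intermediate."""
    y = C * x.astype(jnp.complex64)  # complex intermediate value
    return jnp.real(y) ** 2          # project back to the reals before the loss


def complex_cotangent(x):
    """Cotangent JAX propagates to the complex copy of x, before any cast to real."""
    def loss_of_complex(z):
        return jnp.real(C * z) ** 2

    _, pullback = jax.vjp(loss_of_complex, x.astype(jnp.complex64))
    (ct,) = pullback(jnp.ones((), dtype=jnp.float32))  # unit output cotangent
    return ct


def central_difference(f, x, h=1e-2):
    """Derivative along the real axis (exact for a quadratic, up to rounding)."""
    return (f(x + h) - f(x - h)) / (2 * h)


x0 = jnp.float32(1.5)
grad_real = jax.grad(toy_loss)(x0)  # gradient w.r.t. the real input
ct = complex_cotangent(x0)          # full complex cotangent, roughly 1.08 + 1.44j
fd = central_difference(toy_loss, x0)
```

Here `grad_real` equals the real part of `ct`, and both agree with `fd`; the imaginary part of `ct` would only matter if `x` were allowed to take complex values.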
Feedback is appreciated! Thanks!
python3 -m minimax.train \
  --seed=1 \
  --agent_rl_algo=ppo \
  --n_total_updates=30000 \
  --train_runner=plr \
  --n_devices=1 \
  --student_model_name=default_student_cnn \
  --env_name=Maze \
  --verbose=False \
  --log_dir=~/logs/minimax \
  --log_interval=10 \
  --from_last_checkpoint=True \
  --checkpoint_interval=1000 \
  --archive_interval=0 \
  --archive_init_checkpoint=False \
  --test_interval=100 \
  --n_students=1 \
  --n_parallel=32 \
  --n_eval=1 \
  --n_rollout_steps=256 \
  --lr=3e-05 \
  --lr_anneal_steps=0 \
  --max_grad_norm=0.5 \
  --adam_eps=1e-05 \
  --track_env_metrics=True \
  --discount=0.999 \
  --n_unroll_rollout=10 \
  --render=False \
  --ued_score=max_mc \
  --plr_replay_prob=0.5 \
  --plr_buffer_size=4000 \
  --plr_staleness_coef=0.3 \
  --plr_temp=0.3 \
  --plr_use_score_ranks=True \
  --plr_min_fill_ratio=0.5 \
  --plr_use_robust_plr=True \
  --plr_use_parallel_eval=False \
  --plr_force_unique=True \
  --student_gae_lambda=0.98 \
  --student_entropy_coef=0.001 \
  --student_value_loss_coef=0.5 \
  --student_n_unroll_update=5 \
  --student_ppo_n_epochs=5 \
  --student_ppo_n_minibatches=1 \
  --student_ppo_clip_eps=0.2 \
  --student_ppo_clip_value_loss=True \
  --student_recurrent_arch=s5 \
  --student_recurrent_hidden_dim=256 \
  --student_hidden_dim=32 \
  --student_n_hidden_layers=1 \
  --student_n_conv_filters=16 \
  --student_n_scalar_embeddings=4 \
  --student_scalar_embed_dim=5 \
  --student_s5_n_blocks=2 \
  --student_s5_n_layers=2 \
  --student_s5_layernorm_pos=pre \
  --student_s5_activation=half_glu1 \
  --maze_height=13 \
  --maze_width=13 \
  --maze_n_walls=60 \
  --maze_replace_wall_pos=True \
  --maze_sample_n_walls=False \
  --maze_see_agent=False \
  --maze_normalize_obs=True \
  --maze_obs_agent_pos=False \
  --maze_max_episode_steps=250 \
  --test_n_episodes=10 \
  --test_env_names=Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze \
  --maze_test_see_agent=False \
  --maze_test_normalize_obs=True \
  --xpid=plr-maze13x13w60na_f-rf_p0.5b4000t0.3s0.3m0.5r_r1s_32p_1e_256t_ae1e-05_smm-ppo_lr3e-05g0.999cv0.5ce0.001e5mb1l0.98_pc0.2_h32cf16fc1se5ba_re_lpr_ahg1_s5_h256nb2nl2_0