From cd08e5b9e96556e6041b9d3ff23bd83cbdc1d4b2 Mon Sep 17 00:00:00 2001
From: Ivan Zhou
Date: Tue, 27 Aug 2024 18:29:23 -0700
Subject: [PATCH] run train lm ray

---
Review notes (this region is commentary; it is not part of the commit message):

* Changed after export: main() now wraps the task submission in ray.get(), so
  the driver blocks until train_lm finishes and re-raises any task failure.
  Previously main() returned as soon as the task had been submitted. The
  commit id on the first line and the new-side blob id on the
  train_lm_ray.py "index" line predate this edit.
* Whitespace (blank context lines, indentation of the shell continuation
  lines and of the YAML keys) was restored from a flattened copy so that it
  agrees with the hunk headers and the diffstat; verify against the target
  files before applying.
* NOTE(review): WANDB_API_KEY (execute_submit_ray_job.sh) and HF_TOKEN
  (ray_runtime_env.yaml) are plaintext in the touched files. Rotate them and
  read them from the environment or a secret store instead.
* NOTE(review): the second COMMAND= assignment (echo 'Hello, World!')
  overwrites the pip install command assigned on the line before it, so the
  install steps never reach the submitted job. Presumably a debugging
  leftover -- confirm.
* NOTE(review): jax.device_count() runs in the driver at import time. This may
  initialise the accelerator backend in the driver process on a TPU host --
  confirm it does not contend with the TPU-pinned task.

 execute_submit_ray_job.sh         |  6 +++++-
 ray_runtime_env.yaml              |  1 +
 src/levanter/main/train_lm_ray.py | 12 +++++++++---
 3 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/execute_submit_ray_job.sh b/execute_submit_ray_job.sh
index 6a53e826c..5cb81c305 100644
--- a/execute_submit_ray_job.sh
+++ b/execute_submit_ray_job.sh
@@ -1,6 +1,10 @@
-COMMAND="pip install git+https://github.com/stanford-crfm/levanter.git@ivan-ray-jobs fire && WANDB_API_KEY=be441272f3bd2812a2eb009739e26a202f14d7ba \
+# pip install git+https://github.com/stanford-crfm/levanter.git@ivan-ray-jobs fire &&
+COMMAND="pip install fire && pip uninstall -y levanter && pip install -U git+https://github.com/stanford-crfm/levanter.git@ivan-ray-jobs && pip install jax[tpu]"
+COMMAND="echo 'Hello, World!'"
+LAUNCH_COMMAND="WANDB_API_KEY=be441272f3bd2812a2eb009739e26a202f14d7ba \
     WANDB_PROJECT=marin \
     python src/levanter/main/train_lm_ray.py --config_path config/gpt2_nano.yaml"
+COMMAND="$COMMAND && $LAUNCH_COMMAND"
 echo $COMMAND
 
 ray job submit --address http://127.0.0.1:8265 --working-dir . \
diff --git a/ray_runtime_env.yaml b/ray_runtime_env.yaml
index 0b954e14d..ec1630b01 100644
--- a/ray_runtime_env.yaml
+++ b/ray_runtime_env.yaml
@@ -5,6 +5,7 @@ runtime_env:
     HF_TOKEN: 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
     TPU_STDERR_LOG_LEVEL: 0
     TPU_MIN_LOG_LEVEL: 0
+    RAY_ENABLE_RECORD_ACTOR_TASK_LOGGING: 1
   pip:
     # - git+https://github.com/stanford-crfm/levanter.git
     - fire
diff --git a/src/levanter/main/train_lm_ray.py b/src/levanter/main/train_lm_ray.py
index 4fcb185b3..e0e11126f 100644
--- a/src/levanter/main/train_lm_ray.py
+++ b/src/levanter/main/train_lm_ray.py
@@ -1,16 +1,22 @@
 import fire
+import jax
 import ray
+
 from levanter.config import main as config_main
 from levanter.main.train_lm import main as train_lm_main
 
 
 ray.init()
+print(f"ray.available_resources(): {ray.available_resources()}")
+print(f"jax.device_count(): {jax.device_count()}")
 
-
-@ray.remote
+@ray.remote(resources={"TPU": 8})
 def train_lm(config_path: str):
     config_main(train_lm_main)(args=[config_path])
 
 
+def main(config_path: str):
+    ray.get(train_lm.remote(config_path))
+
 if __name__ == "__main__":
-    fire.Fire(train_lm)
+    fire.Fire(main)