Skip to content

Commit

Permalink
run train lm ray
Browse files Browse the repository at this point in the history
  • Loading branch information
Ivan-Zhou committed Aug 28, 2024
1 parent cb76d3c commit cd08e5b
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
6 changes: 5 additions & 1 deletion execute_submit_ray_job.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
COMMAND="pip install git+https://github.com/stanford-crfm/levanter.git@ivan-ray-jobs fire && WANDB_API_KEY=be441272f3bd2812a2eb009739e26a202f14d7ba \
# pip install git+https://github.com/stanford-crfm/levanter.git@ivan-ray-jobs fire &&
COMMAND="pip install fire && pip uninstall -y levanter && pip install -U git+https://github.com/stanford-crfm/levanter.git@ivan-ray-jobs && pip install jax[tpu]"
COMMAND="echo 'Hello, World!'"
LAUNCH_COMMAND="WANDB_API_KEY=be441272f3bd2812a2eb009739e26a202f14d7ba \
WANDB_PROJECT=marin \
python src/levanter/main/train_lm_ray.py --config_path config/gpt2_nano.yaml"
COMMAND="$COMMAND && $LAUNCH_COMMAND"

echo $COMMAND
ray job submit --address http://127.0.0.1:8265 --working-dir . \
Expand Down
1 change: 1 addition & 0 deletions ray_runtime_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ runtime_env:
HF_TOKEN: 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
TPU_STDERR_LOG_LEVEL: 0
TPU_MIN_LOG_LEVEL: 0
RAY_ENABLE_RECORD_ACTOR_TASK_LOGGING: 1
pip:
# - git+https://github.com/stanford-crfm/levanter.git
- fire
Expand Down
12 changes: 9 additions & 3 deletions src/levanter/main/train_lm_ray.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import fire
import jax
import ray

from levanter.config import main as config_main
from levanter.main.train_lm import main as train_lm_main


# Initialise Ray as soon as this module is loaded (runs at import time,
# before any task is submitted).
ray.init()
# Diagnostic output: print the resources Ray reports and the number of
# devices JAX sees from this (driver) process.
# NOTE(review): jax.device_count() initialises the JAX backend in the driver;
# on a TPU host that may claim the devices before the remote task starts —
# confirm this does not block the worker from acquiring the TPU.
print(f"ray.available_resources(): {ray.available_resources()}")
print(f"jax.device_count(): {jax.device_count()}")


@ray.remote
@ray.remote(resources={"TPU": 8})
def train_lm(config_path: str):
    """Run the Levanter LM training entry point with the given config.

    Wraps ``train_lm_main`` with ``config_main`` and invokes the resulting
    callable with ``config_path`` as its single command-line argument.

    Args:
        config_path: Value forwarded as the only element of ``args``.
    """
    # NOTE(review): the path is passed as a bare argv entry, without a
    # ``--config_path`` flag in front of it — confirm ``config_main`` accepts
    # that form.
    entry_point = config_main(train_lm_main)
    entry_point(args=[config_path])


def main(config_path: str):
    """Submit the training task to Ray and wait for it to complete.

    Args:
        config_path: Config path forwarded unchanged to the remote
            ``train_lm`` task.

    Raises:
        Exception: Any error raised inside the remote task is re-raised here
            by ``ray.get``.
    """
    # ``.remote()`` only schedules the task and hands back an ObjectRef.
    # Block on it so the driver stays alive until training finishes and so
    # that a failure in the task surfaces as a failure of this script.
    pending = train_lm.remote(config_path)
    ray.get(pending)

if __name__ == "__main__":
fire.Fire(train_lm)
fire.Fire(main)

0 comments on commit cd08e5b

Please sign in to comment.