From b3823458d679ba0e79ac9544204f0151eabae111 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 9 Feb 2024 13:54:32 -0800 Subject: [PATCH] Move to cuda unconditionally so pp-only run works --- torchtrain/parallelisms/parallelize_llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py index d6db313fd..1ef4dd99b 100644 --- a/torchtrain/parallelisms/parallelize_llama.py +++ b/torchtrain/parallelisms/parallelize_llama.py @@ -185,4 +185,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, args): rank0_log("Applied FSDP to the model...") + # redundant if FSDP is used, but ensure the model is on device consistently regardless of which parallelisms were used + model.cuda() return model