Skip to content

Commit

Permalink
Simplify the code and use steps instead of seconds as the default uni…
Browse files Browse the repository at this point in the history
…t of measurement.
  • Loading branch information
fegin committed Feb 1, 2024
1 parent 000bfe5 commit a0257bc
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 17 deletions.
7 changes: 4 additions & 3 deletions run_llama_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ MODEL="debugmodel"
NGPU=8
MP=4
# Change this string to a meaningful one to enable checkpoint
CHECKPOINT_FOLDER=""
# Please adjust this to a longer interval period. The unit of measurement is in seconds.
CHECKPOINT_INTERVAL=10
CHECKPOINT_FOLDER="/tmp/chienchin"
# Please adjust this to a longer interval. The unit of measurement (steps or seconds) is set by CHECKPOINT_INTERVAL_TYPE below.
CHECKPOINT_INTERVAL=2
CHECKPOINT_INTERVAL_TYPE="seconds"

torchrun --nproc_per_node=${NGPU} \
train.py --steps 10 \
Expand Down
24 changes: 12 additions & 12 deletions torchtrain/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,21 +70,20 @@ def __init__(
self.pg = dist.new_group(backend="gloo")
self.doit = None

def reset(self, step: int = 0) -> None:
self.begin = (
time.monotonic() if self.interval_type == IntervalType.SECONDS else step
)
def reset(self) -> None:
self.begin = time.monotonic()

def create_checkpoint_id(self, step: int) -> str:
return os.path.join(self.folder, f"step-{step}")

def save(self, curr_step: int, force: bool = False) -> None:
if not self.folder:
return

if not force:
if not self.folder:
return
if (
self.interval_type == IntervalType.STEPS
and (curr_step - self.begin) < self.interval
and not (curr_step % self.interval == 0)
):
return
if self.interval_type == IntervalType.SECONDS:
Expand All @@ -102,15 +101,16 @@ def save(self, curr_step: int, force: bool = False) -> None:
return
else:
return
else:
if self.work:
self.work.wait()
self.doit = None

if self.work:
self.work.wait()
self.work = None
self.doit = None

logging.warning(f"Saving a checkpoint in step {curr_step}.")
begin = time.monotonic()
dcp.save(self.states, checkpoint_id=self.create_checkpoint_id(curr_step))
self.reset(curr_step)
self.reset()
logging.warning(
f"Finish saving the checkpoint in step {curr_step}. "
f"{time.monotonic() - begin} seconds"
Expand Down
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,10 @@ def main(args):
)
parser.add_argument(
"--checkpoint-interval-type",
type=str, default="seconds",
type=str, default="steps",
help=(
"The checkpointing interval unit of measurement."
"The default value is seconds."
"The default value is step."
)
)
parser.add_argument(
Expand Down

0 comments on commit a0257bc

Please sign in to comment.