Commit
Fix rounding issue in interval calculation (#3109)
dakinggg authored Mar 13, 2024
1 parent f616650 commit f65bb27
Showing 5 changed files with 54 additions and 10 deletions.
composer/callbacks/eval_output_logging_callback.py (4 changes: 2 additions & 2 deletions)
@@ -18,9 +18,9 @@ class EvalOutputLogging(Callback):
     """Logs eval outputs for each sample of each ICL evaluation dataset.
 
     ICL metrics are required to support caching the model's responses including information on whether model was correct.
-    Metrics are responsible for returning the results of individual datapoints in a dictionary of lists.
+    Metrics are responsible for returning the results of individual data points in a dictionary of lists.
     The callback will log the metric name, the depadded and detokenized input, any data stored in state.metric_outputs, and
-    any keys from the batch pased into `batch_keys_to_log`. It will do so after every eval batch.
+    any keys from the batch passed into `batch_keys_to_log`. It will do so after every eval batch.
     """
 
     def __init__(self, log_tokens=False, *args, **kwargs):
composer/models/huggingface.py (8 changes: 4 additions & 4 deletions)
@@ -531,8 +531,8 @@ def eval_forward(self, batch, outputs: Optional[Any] = None):
 
             # don't remove prefix space to sentencepiece models
             if len(
-                self.tokenizer(' a', add_special_tokens=False)['input_ids'],
-            ) == 1:  # pyright: ignore[reportGeneralTypeIssues]
+                self.tokenizer(' a', add_special_tokens=False)['input_ids'],  # pyright: ignore[reportGeneralTypeIssues]
+            ) == 1:
                 return self.tokenizer.batch_decode(
                     generation[:, batch['input_ids'].shape[1]:],
                     skip_special_tokens=True,
@@ -658,8 +658,8 @@ def get_metadata(self):
                     conda_package='sentencepiece',
                 ) from e
                 s = spm.SentencePieceProcessor(
-                    model_file=str(tokenizer_file_path),
-                )  # pyright: ignore[reportGeneralTypeIssues]
+                    model_file=str(tokenizer_file_path),  # pyright: ignore[reportGeneralTypeIssues]
+                )
                 tokenizer_file_content = s.serialized_model_proto()
             else:
                 raise ValueError(
composer/utils/misc.py (2 changes: 1 addition & 1 deletion)
@@ -111,7 +111,7 @@ def check_interval(state: State, event: Event):
 
        if event == interval_event:
            if state.max_duration.unit == TimeUnit.EPOCH and int(state.timestamp.batch) % math.ceil(
-               state.max_duration.value * float(time_interval) * state.dataloader_len,
+               state.max_duration.value * float(time_interval) * state.dataloader_len.value,
            ) == 0:
                last_batch_seen = state.timestamp.batch
                return True
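Note on the one-line fix above: a worked example of the corrected arithmetic, using the same numbers as the '1ep' / '5ba' cases in the new test below. The pre-fix behavior shown in the second comment is an assumption (the intermediate product appears to have been rounded down to an integer when the Time object, rather than its .value, entered the multiplication).

import math

max_duration_value = 1    # max_duration = '1ep'
time_interval = 0.25      # interval = '0.25dur'
dataloader_len_value = 5  # dataloader_len = '5ba'

# Fixed expression: plain numbers only, so the fractional 1.25 reaches math.ceil
# and the checkpoint interval resolves to every 2 batches.
assert math.ceil(max_duration_value * time_interval * dataloader_len_value) == 2

# Assumed pre-fix effect: truncating the product to an int before math.ceil
# collapses the interval to every single batch.
assert math.ceil(int(max_duration_value * time_interval * dataloader_len_value)) == 1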
composer/utils/object_store/oci_object_store.py (4 changes: 2 additions & 2 deletions)
@@ -164,8 +164,8 @@ def download_object(
        try:
            head_object_response = self.client.head_object(self.namespace, self.bucket, object_name)
            object_size = int(
-               head_object_response.headers['content-length'],
-           )  # pyright: ignore[reportOptionalMemberAccess]
+               head_object_response.headers['content-length'],  # pyright: ignore[reportOptionalMemberAccess]
+           )
        except Exception as e:
            _reraise_oci_errors(self.get_uri(object_name), e)

tests/utils/test_misc.py (46 changes: 45 additions & 1 deletion)
@@ -1,7 +1,22 @@
 # Copyright 2022 MosaicML Composer authors
 # SPDX-License-Identifier: Apache-2.0
 
-from composer.utils.misc import partial_format
+import pytest
+
+from composer.core import Event, Time, Timestamp
+from composer.utils.misc import create_interval_scheduler, partial_format
+
+
+class DummyState:
+
+    def __init__(self, current_batches: int, max_duration: str, dataloader_len: str):
+        self.previous_timestamp = Timestamp(batch=current_batches - 1)
+        self.timestamp = Timestamp(batch=current_batches)
+        self.max_duration = Time.from_timestring(max_duration)
+        self.dataloader_len = Time.from_timestring(dataloader_len)
+
+    def get_elapsed_duration(self):
+        return 0
 
 
 def test_partial_format():
@@ -20,3 +35,32 @@ def test_partial_format():
     assert partial_format('{foo} {}', 'World') == '{foo} World'
     assert partial_format('{foo} {}', foo='Hello') == 'Hello {}'
     assert partial_format('{foo} {}', 'World', foo='Hello') == 'Hello World'
+
+
+@pytest.mark.parametrize(
+    'interval,current_batches,max_duration,dataloader_len,expected',
+    [
+        ('0.25dur', 1, '1ep', '1ba', True),
+        ('0.25dur', 1, '1ep', '4ba', True),
+        ('0.25dur', 2, '1ep', '5ba', True),
+        ('0.25dur', 1, '1ep', '5ba', False),
+        ('0.25dur', 1, '1ba', '1ba', True),
+        ('0.25dur', 1, '4ba', '4ba', True),
+        ('0.25dur', 2, '5ba', '5ba', True),
+        ('0.25dur', 1, '5ba', '5ba', False),
+    ],
+)
+def test_interval_scheduler(
+    interval: str,
+    current_batches: int,
+    max_duration: str,
+    dataloader_len: str,
+    expected: bool,
+):
+    interval_scheduler = create_interval_scheduler(interval)
+    dummy_state = DummyState(current_batches, max_duration, dataloader_len)
+
+    event = Event.BATCH_CHECKPOINT
+
+    actual = interval_scheduler(dummy_state, event)  # type: ignore (intentional)
+    assert actual == expected
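A usage note separate from the diff: create_interval_scheduler returns a callable that takes the current State and Event and reports whether the interval has elapsed, which is the same shape as the callable form of interval-style arguments such as the Trainer's save_interval. A minimal sketch, assuming that callable form is accepted:

from composer.utils.misc import create_interval_scheduler

# Fire every quarter of the total training duration; with 1 epoch of 5 batches
# this resolves to every math.ceil(1 * 0.25 * 5) = 2 batches after the fix.
save_interval = create_interval_scheduler('0.25dur')

# Hypothetical wiring: pass the callable wherever Composer accepts an interval,
# e.g. Trainer(..., save_interval=save_interval).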
