diff --git a/composer/callbacks/eval_output_logging_callback.py b/composer/callbacks/eval_output_logging_callback.py
index 6334572410..cb0469e5fc 100644
--- a/composer/callbacks/eval_output_logging_callback.py
+++ b/composer/callbacks/eval_output_logging_callback.py
@@ -18,9 +18,9 @@ class EvalOutputLogging(Callback):
     """Logs eval outputs for each sample of each ICL evaluation dataset.
 
     ICL metrics are required to support caching the model's responses including information on whether model was correct.
-    Metrics are responsible for returning the results of individual datapoints in a dictionary of lists.
+    Metrics are responsible for returning the results of individual data points in a dictionary of lists.
     The callback will log the metric name, the depadded and detokenized input, any data stored in state.metric_outputs, and
-    any keys from the batch pased into `batch_keys_to_log`. It will do so after every eval batch.
+    any keys from the batch passed into `batch_keys_to_log`. It will do so after every eval batch.
     """
 
     def __init__(self, log_tokens=False, *args, **kwargs):
diff --git a/composer/models/huggingface.py b/composer/models/huggingface.py
index 32d5d50902..2637f8cc9e 100644
--- a/composer/models/huggingface.py
+++ b/composer/models/huggingface.py
@@ -531,8 +531,8 @@ def eval_forward(self, batch, outputs: Optional[Any] = None):
 
             # don't remove prefix space to sentencepiece models
             if len(
-                self.tokenizer(' a', add_special_tokens=False)['input_ids'],
-            ) == 1:  # pyright: ignore[reportGeneralTypeIssues]
+                self.tokenizer(' a', add_special_tokens=False)['input_ids'],  # pyright: ignore[reportGeneralTypeIssues]
+            ) == 1:
                 return self.tokenizer.batch_decode(
                     generation[:, batch['input_ids'].shape[1]:],
                     skip_special_tokens=True,
@@ -658,8 +658,8 @@ def get_metadata(self):
                                 conda_package='sentencepiece',
                             ) from e
                         s = spm.SentencePieceProcessor(
-                            model_file=str(tokenizer_file_path),
-                        )  # pyright: ignore[reportGeneralTypeIssues]
+                            model_file=str(tokenizer_file_path),  # pyright: ignore[reportGeneralTypeIssues]
+                        )
                         tokenizer_file_content = s.serialized_model_proto()
                     else:
                         raise ValueError(
diff --git a/composer/utils/misc.py b/composer/utils/misc.py
index 09fc2b404d..c5679b7eb1 100644
--- a/composer/utils/misc.py
+++ b/composer/utils/misc.py
@@ -111,7 +111,7 @@ def check_interval(state: State, event: Event):
             if event == interval_event:
                 if state.max_duration.unit == TimeUnit.EPOCH and int(state.timestamp.batch) % math.ceil(
-                    state.max_duration.value * float(time_interval) * state.dataloader_len,
+                    state.max_duration.value * float(time_interval) * state.dataloader_len.value,
                 ) == 0:
                     last_batch_seen = state.timestamp.batch
                     return True
diff --git a/composer/utils/object_store/oci_object_store.py b/composer/utils/object_store/oci_object_store.py
index 2df4d43cd9..ab2baee1bf 100644
--- a/composer/utils/object_store/oci_object_store.py
+++ b/composer/utils/object_store/oci_object_store.py
@@ -164,8 +164,8 @@ def download_object(
         try:
             head_object_response = self.client.head_object(self.namespace, self.bucket, object_name)
             object_size = int(
-                head_object_response.headers['content-length'],
-            )  # pyright: ignore[reportOptionalMemberAccess]
+                head_object_response.headers['content-length'],  # pyright: ignore[reportOptionalMemberAccess]
+            )
         except Exception as e:
             _reraise_oci_errors(self.get_uri(object_name), e)
 
diff --git a/tests/utils/test_misc.py b/tests/utils/test_misc.py
index 333262795d..8359f8ce62 100644
--- a/tests/utils/test_misc.py
+++ b/tests/utils/test_misc.py
@@ -1,7 +1,22 @@
 # Copyright 2022 MosaicML Composer authors
 # SPDX-License-Identifier: Apache-2.0
 
-from composer.utils.misc import partial_format
+import pytest
+
+from composer.core import Event, Time, Timestamp
+from composer.utils.misc import create_interval_scheduler, partial_format
+
+
+class DummyState:
+
+    def __init__(self, current_batches: int, max_duration: str, dataloader_len: str):
+        self.previous_timestamp = Timestamp(batch=current_batches - 1)
+        self.timestamp = Timestamp(batch=current_batches)
+        self.max_duration = Time.from_timestring(max_duration)
+        self.dataloader_len = Time.from_timestring(dataloader_len)
+
+    def get_elapsed_duration(self):
+        return 0
 
 
 def test_partial_format():
@@ -20,3 +35,32 @@ def test_partial_format():
     assert partial_format('{foo} {}', 'World') == '{foo} World'
     assert partial_format('{foo} {}', foo='Hello') == 'Hello {}'
     assert partial_format('{foo} {}', 'World', foo='Hello') == 'Hello World'
+
+
+@pytest.mark.parametrize(
+    'interval,current_batches,max_duration,dataloader_len,expected',
+    [
+        ('0.25dur', 1, '1ep', '1ba', True),
+        ('0.25dur', 1, '1ep', '4ba', True),
+        ('0.25dur', 2, '1ep', '5ba', True),
+        ('0.25dur', 1, '1ep', '5ba', False),
+        ('0.25dur', 1, '1ba', '1ba', True),
+        ('0.25dur', 1, '4ba', '4ba', True),
+        ('0.25dur', 2, '5ba', '5ba', True),
+        ('0.25dur', 1, '5ba', '5ba', False),
+    ],
+)
+def test_interval_scheduler(
+    interval: str,
+    current_batches: int,
+    max_duration: str,
+    dataloader_len: str,
+    expected: bool,
+):
+    interval_scheduler = create_interval_scheduler(interval)
+    dummy_state = DummyState(current_batches, max_duration, dataloader_len)
+
+    event = Event.BATCH_CHECKPOINT
+
+    actual = interval_scheduler(dummy_state, event)  # type: ignore (intentional)
+    assert actual == expected
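
Note (not part of the diff): the composer/utils/misc.py change above makes the epoch-based duration check use the plain numeric value of the sized dataloader length. Below is a minimal sketch of that arithmetic, assuming composer is installed; the helper name batches_per_interval is invented for illustration and does not exist in the repo.

import math

from composer.core import Time


def batches_per_interval(max_duration: str, interval_fraction: float, dataloader_len: str) -> int:
    # Mirrors the fixed expression in check_interval: multiply plain numbers
    # (note .value on both Time objects) and round up to a whole number of
    # batches between checkpoint events.
    max_dur = Time.from_timestring(max_duration)    # e.g. '1ep'
    dl_len = Time.from_timestring(dataloader_len)   # e.g. '4ba'
    return math.ceil(max_dur.value * interval_fraction * dl_len.value)


# check_interval fires when int(state.timestamp.batch) is a multiple of this
# count, which is why ('0.25dur', 1, '1ep', '4ba') expects True and
# ('0.25dur', 1, '1ep', '5ba') expects False in the parametrized test above.
assert batches_per_interval('1ep', 0.25, '4ba') == 1
assert batches_per_interval('1ep', 0.25, '5ba') == 2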