Skip to content

Commit

Permalink
fix failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mbway committed Mar 19, 2024
1 parent 7003f3c commit 5468954
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
11 changes: 9 additions & 2 deletions tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io
import os
import pathlib
import re
import shutil
import tarfile
import tempfile
Expand Down Expand Up @@ -226,7 +227,10 @@ def test_checkpoint_saver_folder_filename_path(folder: Union[str, pathlib.Path],


def test_checkpoint_invalid_compressor(monkeypatch: pytest.MonkeyPatch):
with pytest.raises(CompressorNotFound, match='could not find compressor for "foo.pt.unknown_compressor"'):
with pytest.raises(
CompressorNotFound,
match=re.escape('Could not find compressor for "foo.pt.unknown_compressor".'),
):
CheckpointSaver(filename='foo.pt.unknown_compressor')

import composer.utils.compression
Expand All @@ -236,7 +240,10 @@ def test_checkpoint_invalid_compressor(monkeypatch: pytest.MonkeyPatch):
[CliCompressor('unknown_compressor', 'unknown_compressor_cmd')],
)

with pytest.raises(CompressorNotFound, match='could not find command "unknown_compressor_cmd" in the PATH'):
with pytest.raises(
CompressorNotFound,
match=re.escape('Could not find command "unknown_compressor_cmd" in the PATH'),
):
CheckpointSaver(filename='foo.pt.unknown_compressor')


Expand Down
3 changes: 2 additions & 1 deletion tests/utils/test_compression.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import re
from pathlib import Path

import pytest
Expand Down Expand Up @@ -28,7 +29,7 @@ def test_is_compressed_pt() -> None:


def test_get_invalid_compressor() -> None:
with pytest.raises(CompressorNotFound, match='could not find compressor for "foo.pt.unknown"'):
with pytest.raises(CompressorNotFound, match=re.escape('Could not find compressor for "foo.pt.unknown".')):
get_compressor('foo.pt.unknown')


Expand Down

0 comments on commit 5468954

Please sign in to comment.