[tests] make tests device-agnostic (part 3) #10437

Merged Jan 21, 2025 (31 commits, changes shown from all commits)

Commits
8d0f387  initial comit (faaany, Jan 3, 2025)
88919c0  fix empty cache (faaany, Jan 3, 2025)
e32a9ac  fix one more (faaany, Jan 3, 2025)
cb7d9d5  fix style (faaany, Jan 3, 2025)
a393860  update device functions (faaany, Jan 6, 2025)
2f3ad32  update (faaany, Jan 6, 2025)
f3a519f  update (faaany, Jan 6, 2025)
b402629  Merge branch 'main' into xpu-enabling (faaany, Jan 6, 2025)
d1532d2  Update src/diffusers/utils/testing_utils.py (faaany, Jan 7, 2025)
16cca22  Update src/diffusers/utils/testing_utils.py (faaany, Jan 7, 2025)
3420e1f  Update src/diffusers/utils/testing_utils.py (faaany, Jan 7, 2025)
d15618b  Update tests/pipelines/controlnet/test_controlnet.py (faaany, Jan 7, 2025)
e814635  Update src/diffusers/utils/testing_utils.py (faaany, Jan 7, 2025)
e799516  Update src/diffusers/utils/testing_utils.py (faaany, Jan 7, 2025)
d3e8678  Update tests/pipelines/controlnet/test_controlnet.py (faaany, Jan 7, 2025)
fed282b  with gc.collect (faaany, Jan 7, 2025)
8577a14  update (faaany, Jan 7, 2025)
f08a849  Merge branch 'huggingface:main' into xpu-enabling (faaany, Jan 7, 2025)
35d7a7a  make style (hlky, Jan 7, 2025)
736cc7c  Merge branch 'main' into xpu-enabling (hlky, Jan 7, 2025)
c8661f0  check_torch_dependencies (hlky, Jan 7, 2025)
d4266a7  Merge branch 'main' into xpu-enabling (faaany, Jan 8, 2025)
d820f75  add mps empty cache (faaany, Jan 8, 2025)
b813f16  bug fix (faaany, Jan 9, 2025)
286fa53  Merge branch 'main' into xpu-enabling (faaany, Jan 9, 2025)
a2ee718  Merge branch 'huggingface:main' into xpu-enabling (faaany, Jan 9, 2025)
92eeb91  Merge branch 'main' into xpu-enabling (faaany, Jan 13, 2025)
c9b497e  Merge branch 'main' into xpu-enabling (faaany, Jan 14, 2025)
172cb97  Merge branch 'main' into xpu-enabling (hlky, Jan 15, 2025)
a442a21  Apply suggestions from code review (hlky, Jan 15, 2025)
1836172  Merge branch 'main' into xpu-enabling (faaany, Jan 21, 2025)
69 changes: 64 additions & 5 deletions src/diffusers/utils/testing_utils.py
@@ -86,7 +86,12 @@
) from e
logger.info(f"torch_device overrode to {torch_device}")
else:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
torch_device = "cuda"
elif torch.xpu.is_available():
torch_device = "xpu"
else:
torch_device = "cpu"
is_torch_higher_equal_than_1_12 = version.parse(
version.parse(torch.__version__).base_version
) >= version.parse("1.12")
@@ -1067,12 +1072,51 @@ def _is_torch_fp64_available(device):
# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
if is_torch_available():
# Behaviour flags
BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True}
BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}

# Function definitions
BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None}
BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0}
BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
BACKEND_EMPTY_CACHE = {
"cuda": torch.cuda.empty_cache,
"xpu": torch.xpu.empty_cache,
"cpu": None,
"mps": torch.mps.empty_cache,
"default": None,
}
BACKEND_DEVICE_COUNT = {
"cuda": torch.cuda.device_count,
"xpu": torch.xpu.device_count,
"cpu": lambda: 0,
"mps": lambda: 0,
"default": 0,
}
BACKEND_MANUAL_SEED = {
"cuda": torch.cuda.manual_seed,
"xpu": torch.xpu.manual_seed,
"cpu": torch.manual_seed,
"mps": torch.mps.manual_seed,
"default": torch.manual_seed,
}
BACKEND_RESET_PEAK_MEMORY_STATS = {
"cuda": torch.cuda.reset_peak_memory_stats,
"xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
"cpu": None,
"mps": None,
"default": None,
}
BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
"cuda": torch.cuda.reset_max_memory_allocated,
"xpu": None,
"cpu": None,
"mps": None,
"default": None,
}
BACKEND_MAX_MEMORY_ALLOCATED = {
"cuda": torch.cuda.max_memory_allocated,
"xpu": getattr(torch.xpu, "max_memory_allocated", None),
"cpu": 0,
"mps": 0,
"default": 0,
}


# This dispatches a defined function according to the accelerator from the function definitions.
@@ -1103,6 +1147,18 @@ def backend_device_count(device: str):
return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)


def backend_reset_peak_memory_stats(device: str):
return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)


def backend_reset_max_memory_allocated(device: str):
return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)


def backend_max_memory_allocated(device: str):
return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)


# These are callables which return boolean behaviour flags and can be used to specify some
# device agnostic alternative where the feature is unsupported.
def backend_supports_training(device: str):
@@ -1159,3 +1215,6 @@ def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name
update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
update_mapping_from_spec(BACKEND_RESET_PEAK_MEMORY_STATS, "RESET_PEAK_MEMORY_STATS_FN")
update_mapping_from_spec(BACKEND_RESET_MAX_MEMORY_ALLOCATED, "RESET_MAX_MEMORY_ALLOCATED_FN")
update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")
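The new BACKEND_* tables above extend the existing device-dispatch mechanism to XPU and MPS: each backend_* helper looks up the entry registered for the active torch_device and either calls it or returns it as a plain value, so tests no longer branch on torch.cuda explicitly. This hunk does not show _device_agnostic_dispatch itself; the snippet below is only a minimal standalone sketch of how such a lookup is assumed to behave, not the actual implementation in testing_utils.py.

# Illustrative sketch only; the real _device_agnostic_dispatch in
# src/diffusers/utils/testing_utils.py may differ in details.
import torch

EXAMPLE_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "default": None}
EXAMPLE_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "default": 0}


def example_dispatch(device: str, dispatch_table: dict, *args):
    # "cuda:1" and "cuda" should resolve to the same backend entry.
    backend = device.split(":")[0]
    fn = dispatch_table.get(backend, dispatch_table["default"])
    # Some table entries are plain values (0, True, None) rather than callables.
    if not callable(fn):
        return fn
    return fn(*args)


torch_device = "cuda" if torch.cuda.is_available() else "cpu"
example_dispatch(torch_device, EXAMPLE_EMPTY_CACHE)  # frees the cache on CUDA, no-op on CPU
device_count = example_dispatch(torch_device, EXAMPLE_DEVICE_COUNT)

With this pattern, the three new memory helpers let slow tests measure peak usage without backend-specific code, for example backend_reset_peak_memory_stats(torch_device) before inference and backend_max_memory_allocated(torch_device) afterwards.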
16 changes: 8 additions & 8 deletions tests/models/test_modeling_common.py
@@ -57,8 +57,8 @@
get_python_version,
is_torch_compile,
require_torch_2,
require_torch_accelerator,
require_torch_accelerator_with_training,
require_torch_gpu,
require_torch_multi_gpu,
run_test_in_subprocess,
torch_all_close,
@@ -543,7 +543,7 @@ def test_set_xformers_attn_processor_for_determinism(self):
assert torch.allclose(output, output_3, atol=self.base_precision)
assert torch.allclose(output_2, output_3, atol=self.base_precision)

@require_torch_gpu
@require_torch_accelerator
def test_set_attn_processor_for_determinism(self):
if self.uses_custom_attn_processor:
return
@@ -1068,7 +1068,7 @@ def test_wrong_adapter_name_raises_error(self):

self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))

@require_torch_gpu
@require_torch_accelerator
def test_cpu_offload(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
@@ -1098,7 +1098,7 @@ def test_cpu_offload(self):

self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

@require_torch_gpu
@require_torch_accelerator
def test_disk_offload_without_safetensors(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
@@ -1132,7 +1132,7 @@ def test_disk_offload_without_safetensors(self):

self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

@require_torch_gpu
@require_torch_accelerator
def test_disk_offload_with_safetensors(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
@@ -1191,7 +1191,7 @@ def test_model_parallelism(self):

self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

@require_torch_gpu
@require_torch_accelerator
def test_sharded_checkpoints(self):
torch.manual_seed(0)
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1223,7 +1223,7 @@ def test_sharded_checkpoints(self):

self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

@require_torch_gpu
@require_torch_accelerator
def test_sharded_checkpoints_with_variant(self):
torch.manual_seed(0)
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1261,7 +1261,7 @@ def test_sharded_checkpoints_with_variant(self):

self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

@require_torch_gpu
@require_torch_accelerator
def test_sharded_checkpoints_device_map(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
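The changes to tests/models/test_modeling_common.py above are a one-for-one swap of @require_torch_gpu for @require_torch_accelerator, so the offload and sharded-checkpoint tests also run on XPU and other supported accelerators. The decorator itself lives in diffusers.utils.testing_utils and is not part of this diff; the version below is a hypothetical sketch of the gating it is assumed to perform.

# Hypothetical sketch for illustration; the real require_torch_accelerator
# may consult additional backends or environment overrides.
import unittest

import torch


def require_torch_accelerator_sketch(test_case):
    has_accelerator = torch.cuda.is_available() or (
        hasattr(torch, "xpu") and torch.xpu.is_available()
    )
    return unittest.skipUnless(has_accelerator, "test requires a torch accelerator")(test_case)


class ExampleModelTests(unittest.TestCase):
    @require_torch_accelerator_sketch
    def test_cpu_offload(self):
        ...  # runs on CUDA or XPU, skipped on CPU-only machines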
6 changes: 3 additions & 3 deletions tests/pipelines/allegro/test_allegro.py
@@ -27,7 +27,7 @@
enable_full_determinism,
numpy_cosine_similarity_distance,
require_hf_hub_version_greater,
require_torch_gpu,
require_torch_accelerator,
require_transformers_version_greater,
slow,
torch_device,
@@ -332,7 +332,7 @@ def test_save_load_dduf(self):


@slow
@require_torch_gpu
@require_torch_accelerator
class AllegroPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."

@@ -350,7 +350,7 @@ def test_allegro(self):
generator = torch.Generator("cpu").manual_seed(0)

pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
pipe.enable_model_cpu_offload(device=torch_device)
prompt = self.prompt

videos = pipe(
11 changes: 6 additions & 5 deletions tests/pipelines/animatediff/test_animatediff.py
@@ -20,9 +20,10 @@
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import is_xformers_available, logging
from diffusers.utils.testing_utils import (
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -547,19 +548,19 @@ def test_vae_slicing(self):


@slow
@require_torch_gpu
@require_torch_accelerator
class AnimateDiffPipelineSlowTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_animatediff(self):
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
@@ -573,7 +574,7 @@ def test_animatediff(self):
clip_sample=False,
)
pipe.enable_vae_slicing()
pipe.enable_model_cpu_offload()
pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)

prompt = "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
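The same two changes recur in the slow pipeline suites that follow: setUp and tearDown call backend_empty_cache(torch_device) instead of torch.cuda.empty_cache(), and enable_model_cpu_offload receives the resolved torch_device explicitly so offloading targets XPU or MPS when CUDA is absent. A condensed sketch of that scaffolding, using the CogVideoX checkpoint from the next file purely as an example:

# Condensed sketch of the device-agnostic slow-test pattern; the checkpoint
# is illustrative and taken from the CogVideoX test below.
import gc
import unittest

import torch
from diffusers import DiffusionPipeline
from diffusers.utils.testing_utils import backend_empty_cache, torch_device


class ExamplePipelineSlowTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)  # CUDA, XPU, or MPS cache, as appropriate

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_pipeline(self):
        pipe = DiffusionPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
        # Offload onto whichever accelerator torch_device resolved to,
        # rather than implicitly assuming CUDA.
        pipe.enable_model_cpu_offload(device=torch_device)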
6 changes: 3 additions & 3 deletions tests/pipelines/cogvideo/test_cogvideox.py
@@ -24,7 +24,7 @@
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -321,7 +321,7 @@ def test_fused_qkv_projections(self):


@slow
@require_torch_gpu
@require_torch_accelerator
class CogVideoXPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."

@@ -339,7 +339,7 @@ def test_cogvideox(self):
generator = torch.Generator("cpu").manual_seed(0)

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
pipe.enable_model_cpu_offload(device=torch_device)
prompt = self.prompt

videos = pipe(
11 changes: 6 additions & 5 deletions tests/pipelines/cogvideo/test_cogvideox_image2video.py
@@ -24,9 +24,10 @@
from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -344,25 +345,25 @@ def test_fused_qkv_projections(self):


@slow
@require_torch_gpu
@require_torch_accelerator
class CogVideoXImageToVideoPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."

def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)

def test_cogvideox(self):
generator = torch.Generator("cpu").manual_seed(0)

pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
pipe.enable_model_cpu_offload(device=torch_device)

prompt = self.prompt
image = load_image(
6 changes: 3 additions & 3 deletions tests/pipelines/cogview3/test_cogview3plus.py
@@ -24,7 +24,7 @@
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -232,7 +232,7 @@ def test_attention_slicing_forward_pass(


@slow
@require_torch_gpu
@require_torch_accelerator
class CogView3PlusPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."

@@ -250,7 +250,7 @@ def test_cogview3plus(self):
generator = torch.Generator("cpu").manual_seed(0)

pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3b", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
pipe.enable_model_cpu_offload(device=torch_device)
prompt = self.prompt

images = pipe(