From 6c71b704b4991460e74453f733da8b666d98f347 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 22 Oct 2024 07:52:42 +0100 Subject: [PATCH] Update [ghstack-poisoned] --- test/test_cost.py | 64 ++++++++++++++++++++++++++--------------------- 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 4de2ba0426d..f6b0ed21936 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -7,50 +7,39 @@ import functools import itertools import operator + +import sys import warnings from copy import deepcopy from dataclasses import asdict, dataclass -from packaging import version as pack_version +import numpy as np +import pytest +import torch +from _utils_internal import ( # noqa + dtype_fixture, + get_available_devices, + get_default_devices, +) +from mocking_classes import ContinuousActionConvMockEnv + +from packaging import version, version as pack_version + +from tensordict import assert_allclose_td, TensorDict, TensorDictBase from tensordict._C import unravel_keys from tensordict.nn import ( CompositeDistribution, InteractionType, + NormalParamExtractor, ProbabilisticTensorDictModule, ProbabilisticTensorDictModule as ProbMod, ProbabilisticTensorDictSequential, ProbabilisticTensorDictSequential as ProbSeq, + TensorDictModule, TensorDictModule as Mod, TensorDictSequential, TensorDictSequential as Seq, ) -from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type -from torchrl.modules.models import QMixer - -_has_functorch = True -try: - import functorch as ft # noqa - - make_functional_with_buffers = ft.make_functional_with_buffers - FUNCTORCH_ERR = "" -except ImportError as err: - _has_functorch = False - FUNCTORCH_ERR = str(err) - -import numpy as np -import pytest -import torch -from _utils_internal import ( # noqa - dtype_fixture, - get_available_devices, - get_default_devices, -) -from mocking_classes import ContinuousActionConvMockEnv -from packaging import version - -# from torchrl.data.postprocs.utils import expand_as_right -from tensordict import assert_allclose_td, TensorDict, TensorDictBase -from tensordict.nn import NormalParamExtractor, TensorDictModule from tensordict.nn.utils import Buffer from tensordict.utils import unravel_key from torch import autograd, nn @@ -58,6 +47,7 @@ from torchrl.data.postprocs.postprocs import MultiStep from torchrl.envs.model_based.dreamer import DreamerEnv from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv +from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type from torchrl.modules import ( DistributionalQValueActor, OneHotCategorical, @@ -66,6 +56,7 @@ WorldModelWrapper, ) from torchrl.modules.distributions.continuous import TanhDelta, TanhNormal +from torchrl.modules.models import QMixer from torchrl.modules.models.model_based import ( DreamerActor, ObsDecoder, @@ -147,7 +138,18 @@ _split_and_pad_sequence, ) +_has_functorch = True +try: + import functorch as ft # noqa + + make_functional_with_buffers = ft.make_functional_with_buffers + FUNCTORCH_ERR = "" +except ImportError as err: + _has_functorch = False + FUNCTORCH_ERR = str(err) + TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) +IS_WINDOWS = sys.platform == "win32" # Capture all warnings pytestmark = [ @@ -15735,7 +15737,13 @@ def __init__(self): @pytest.mark.skipif( TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5" ) +@pytest.mark.skipif(IS_WINDOWS, reason="windows tests do not support compile") def test_exploration_compile(): + try: + torch._dynamo.reset_code_caches() + except Exception: + # older versions of PT don't have that function + pass m = ProbabilisticTensorDictModule( in_keys=["loc", "scale"], out_keys=["sample"],