Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Oct 22, 2024
1 parent f16ab86 commit 6c71b70
Showing 1 changed file with 36 additions and 28 deletions.
64 changes: 36 additions & 28 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,57 +7,47 @@
import functools
import itertools
import operator

import sys
import warnings
from copy import deepcopy
from dataclasses import asdict, dataclass

from packaging import version as pack_version
import numpy as np
import pytest
import torch
from _utils_internal import ( # noqa
dtype_fixture,
get_available_devices,
get_default_devices,
)
from mocking_classes import ContinuousActionConvMockEnv

from packaging import version, version as pack_version

from tensordict import assert_allclose_td, TensorDict, TensorDictBase
from tensordict._C import unravel_keys
from tensordict.nn import (
CompositeDistribution,
InteractionType,
NormalParamExtractor,
ProbabilisticTensorDictModule,
ProbabilisticTensorDictModule as ProbMod,
ProbabilisticTensorDictSequential,
ProbabilisticTensorDictSequential as ProbSeq,
TensorDictModule,
TensorDictModule as Mod,
TensorDictSequential,
TensorDictSequential as Seq,
)
from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
from torchrl.modules.models import QMixer

_has_functorch = True
try:
import functorch as ft # noqa

make_functional_with_buffers = ft.make_functional_with_buffers
FUNCTORCH_ERR = ""
except ImportError as err:
_has_functorch = False
FUNCTORCH_ERR = str(err)

import numpy as np
import pytest
import torch
from _utils_internal import ( # noqa
dtype_fixture,
get_available_devices,
get_default_devices,
)
from mocking_classes import ContinuousActionConvMockEnv
from packaging import version

# from torchrl.data.postprocs.utils import expand_as_right
from tensordict import assert_allclose_td, TensorDict, TensorDictBase
from tensordict.nn import NormalParamExtractor, TensorDictModule
from tensordict.nn.utils import Buffer
from tensordict.utils import unravel_key
from torch import autograd, nn
from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded
from torchrl.data.postprocs.postprocs import MultiStep
from torchrl.envs.model_based.dreamer import DreamerEnv
from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv
from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
from torchrl.modules import (
DistributionalQValueActor,
OneHotCategorical,
Expand All @@ -66,6 +56,7 @@
WorldModelWrapper,
)
from torchrl.modules.distributions.continuous import TanhDelta, TanhNormal
from torchrl.modules.models import QMixer
from torchrl.modules.models.model_based import (
DreamerActor,
ObsDecoder,
Expand Down Expand Up @@ -147,7 +138,18 @@
_split_and_pad_sequence,
)

_has_functorch = True
try:
import functorch as ft # noqa

make_functional_with_buffers = ft.make_functional_with_buffers
FUNCTORCH_ERR = ""
except ImportError as err:
_has_functorch = False
FUNCTORCH_ERR = str(err)

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
IS_WINDOWS = sys.platform == "win32"

# Capture all warnings
pytestmark = [
Expand Down Expand Up @@ -15735,7 +15737,13 @@ def __init__(self):
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
@pytest.mark.skipif(IS_WINDOWS, reason="windows tests do not support compile")
def test_exploration_compile():
try:
torch._dynamo.reset_code_caches()
except Exception:
# older versions of PT don't have that function
pass
m = ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
out_keys=["sample"],
Expand Down

0 comments on commit 6c71b70

Please sign in to comment.