Skip to content

Commit

Permalink
Update pytype to Py3.10 way (Part 2)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 559240196
  • Loading branch information
laurentes authored and pax authors committed Aug 22, 2023
1 parent 81ac508 commit c36b8e4
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 128 deletions.
2 changes: 1 addition & 1 deletion praxis/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,6 @@ py_strict_test(
":pax_fiddle",
# Implicit absl.testing.absltest dependency.
# Implicit fiddle dependency.
# Implicit fiddle.experimental.serialization dependency.
# Implicit flax.core dependency.
# Implicit jax dependency.
# Implicit ml_collections config_dict dependency.
Expand Down Expand Up @@ -488,6 +487,7 @@ py_strict_test(
":base_model",
":pax_fiddle",
":py_utils",
":pytypes",
":sample_decode",
":test_utils",
# Implicit absl.testing.absltest dependency.
Expand Down
16 changes: 7 additions & 9 deletions praxis/base_hyperparams_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import inspect
import pickle
import textwrap
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple
from typing import Any, Callable, NamedTuple

from absl.testing import absltest
import fiddle as fdl
Expand Down Expand Up @@ -66,7 +66,7 @@ class NestedTestClass(base_hyperparams.BaseParameterizable):
class HParams(base_hyperparams.BaseHyperParams):
# Note: This is now no longer recommended; only Params should be fields of
# Params.
d: Optional[SimpleTestChild] = None
d: SimpleTestChild | None = None
e: float = 3.0


Expand All @@ -86,7 +86,7 @@ class HParams(base_hyperparams.BaseHyperParams):

class NestedNestedOverrideTestClass(NestedNestedTestClass):
class HParams(NestedNestedTestClass.HParams):
_attribute_overrides: Tuple[str, ...] = ('tpl',)
_attribute_overrides: tuple[str, ...] = ('tpl',)
tpl: base_hyperparams.HParams = base_hyperparams.sub_config_field(
NestedTestBehaveClass.HParams)

Expand Down Expand Up @@ -120,7 +120,7 @@ class NestedStructToTextTestClass(base_hyperparams.BaseParameterizable):

class HParams(base_hyperparams.BaseHyperParams):
tpl: Any = base_hyperparams.sub_config_field(None)
a: Optional[frozen_dict.FrozenDict] = None
a: frozen_dict.FrozenDict | None = None


class FiddlifiedTestClass(base_hyperparams.FiddleBaseParameterizable):
Expand Down Expand Up @@ -403,7 +403,7 @@ class DefaultFactoryTestClass(base_hyperparams.BaseParameterizable):
_USE_DEPRECATED_HPARAMS_BASE_PARAMETERIZABLE = True

class HParams(base_hyperparams.BaseHyperParams):
a: List[str] = dataclasses.field(default_factory=lambda: [1, 2, 3])
a: list[str] = dataclasses.field(default_factory=lambda: [1, 2, 3])

instance_1 = DefaultFactoryTestClass.make()
instance_2 = DefaultFactoryTestClass.make()
Expand Down Expand Up @@ -533,7 +533,7 @@ class CheckpointLoadingRules(NamedTuple):


class CheckPointRuleTest(base_hyperparams.FiddleBaseParameterizable):
init_from_checkpoint_rules: Dict[str, CheckpointLoadingRules] = (
init_from_checkpoint_rules: dict[str, CheckpointLoadingRules] = (
pax_fiddle.instance_field(default_factory=dict)
)

Expand All @@ -542,9 +542,7 @@ class CheckPointRuleDataclassTest(base_hyperparams.FiddleBaseParameterizable):

@dataclasses.dataclass
class Train(base_hyperparams.FiddleBaseParameterizable):
init_from_checkpoint_rules: Optional[Dict[str, CheckpointLoadingRules]] = (
None
)
init_from_checkpoint_rules: dict[str, CheckpointLoadingRules] | None = None


class NestedStructToTextTestCase(absltest.TestCase):
Expand Down
16 changes: 8 additions & 8 deletions praxis/base_layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import dataclasses
import sys
import typing
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -87,7 +87,7 @@ def quantize_weight(self) -> base_layer.NestedJTensor:


class MultipleLinearLayer(base_layer.BaseLayer):
linear1: Optional[AddBias] = pax_fiddle.instance_field(Linear)
linear1: AddBias | None = pax_fiddle.instance_field(Linear)
linear2_tpl: pax_fiddle.Config[Linear] = pax_fiddle.template_field(Linear)

def setup(self):
Expand Down Expand Up @@ -340,10 +340,10 @@ def __call__(self):
class ParentLayer(base_layer.BaseLayer):
# instance fields:
a: base_layer.BaseLayer = base_layer.instance_field(ChildLayer)
bs: List[base_layer.BaseLayer] = base_layer.instance_field(list)
bs: list[base_layer.BaseLayer] = base_layer.instance_field(list)
# template fields:
x_tpl: LayerTpl = base_layer.template_field(ChildLayer)
y_tpls: List[LayerTpl] = base_layer.template_field(list)
y_tpls: list[LayerTpl] = base_layer.template_field(list)

def setup(self):
self.create_child('x', self.x_tpl)
Expand Down Expand Up @@ -479,12 +479,12 @@ def __call__(self):
class FiddleParent(base_layer.BaseLayer):

child_tpl: pax_fiddle.Config = base_layer.template_field(FiddleChild)
child_tpl_list: List[pax_fiddle.Config] = base_layer.template_field(None)
child_tpl_dict: Dict[str, pax_fiddle.Config] = base_layer.template_field(
child_tpl_list: list[pax_fiddle.Config] = base_layer.template_field(None)
child_tpl_dict: dict[str, pax_fiddle.Config] = base_layer.template_field(
None
)
child_instance_list: Optional[List[base_layer.BaseLayer]] = None
child_instance_dict: Optional[List[base_layer.BaseLayer]] = None
child_instance_list: list[base_layer.BaseLayer] | None = None
child_instance_dict: list[base_layer.BaseLayer] | None = None

def setup(self):
child_tpl = self.child_tpl.clone()
Expand Down
9 changes: 5 additions & 4 deletions praxis/decoder_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import copy
import dataclasses
from typing import Optional, Sequence, TypeVar
from typing import Sequence, TypeVar

from praxis import decoder_utils
from praxis import pax_fiddle
Expand Down Expand Up @@ -92,7 +92,7 @@ class BeamSearchHParams(DecoderHParams):
early_exit: A bool, whether or not to allow early exit.
"""
beam_size: int = 1
tokens_per_beam: Optional[int] = None
tokens_per_beam: int | None = None
length_norm_alpha: float = 0.8
early_exit: bool = False
use_matmul_beam_shuffle: bool = False
Expand Down Expand Up @@ -125,6 +125,7 @@ class SampleDecoderHParams(DecoderHParams):
global_normalize: Normalize the logits over top-k logits or globally in the
whole vocabulary. It is used if k is nonzero and p is also not None.
cf_guidance_scale: If not None, apply classifier-free guidance.
controlled_decoding: Parameters for controlled decoding if used.
sort_samples: Whether to sort the samples by logprobs.
override_next_token_sampler_params: Whether to override, the next token
sampler params from the decoder ones. Ideally, this should not be
Expand All @@ -145,8 +146,8 @@ class SampleDecoderHParams(DecoderHParams):
pax_fiddle.template_field(sample_decode.DefaultNextTokenSampler))
global_normalize: bool = False
cf_guidance_scale: list[float] | float | None = None
controlled_decoding: Optional[decoder_utils.ControlledDecodingHParams] = None
sort_samples: Optional[bool] = True
controlled_decoding: decoder_utils.ControlledDecodingHParams | None = None
sort_samples: bool | None = True
override_next_token_sampler_params: bool = True
optimize_eos: bool = False
vanilla_sample_decode: bool = False
37 changes: 19 additions & 18 deletions praxis/pax_fiddle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import dataclasses
import functools
import types
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Union
from typing import Any, Callable, NamedTuple, Sequence

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -53,7 +53,7 @@ class ColoredWheel(Wheel):

@dataclasses.dataclass
class Person:
name: Optional[str] = None
name: str | None = None

def setup(self):
return self
Expand All @@ -64,7 +64,7 @@ class Vehicle:
wheel_tpl: pax_fiddle.Config[Wheel] = pax_fiddle.template_field(Wheel)
num_wheels: int = 4
owner: Person = pax_fiddle.instance_field(Person)
wheels: Optional[List[Wheel]] = None # Initialized by setup.
wheels: list[Wheel] | None = None # Initialized by setup.

def setup(self):
assert self.wheels is None
Expand All @@ -87,7 +87,7 @@ class Fleet:
vehicle_tpl: pax_fiddle.Config[Vehicle] = pax_fiddle.template_field(Vehicle)
num_vehicles: int = 1
manager: Person = pax_fiddle.instance_field(Person)
vehicles: Optional[List[Vehicle]] = None # Initialized by setup.
vehicles: list[Vehicle] | None = None # Initialized by setup.

def setup(self):
assert self.vehicles is None
Expand All @@ -101,29 +101,29 @@ def setup(self):
@dataclasses.dataclass
class BusStop:
location: str # required arg.
times: List[int] = dataclasses.field(default_factory=list)
times: list[int] = dataclasses.field(default_factory=list)


@dataclasses.dataclass
class HourlyBusStop(BusStop):
times: List[int] = dataclasses.field(default_factory=lambda: list(range(24)))
times: list[int] = dataclasses.field(default_factory=lambda: list(range(24)))


@dataclasses.dataclass
class WheelFactory:
wheel_tpl: List[pax_fiddle.Config[Wheel]] = dataclasses.field(
wheel_tpl: list[pax_fiddle.Config[Wheel]] = dataclasses.field(
default_factory=list
)


class NonDataclassWheelFactory:

def __init__(self, wheel_tpl: List[pax_fiddle.Config]):
def __init__(self, wheel_tpl: list[pax_fiddle.Config]):
self.wheel_tpl = wheel_tpl


class NamedTupleWheelFactory(NamedTuple):
wheel_tpl: List[pax_fiddle.Config[Wheel]]
wheel_tpl: list[pax_fiddle.Config[Wheel]]


class SubFieldAndTemplateFieldTest(testing.TestCase):
Expand Down Expand Up @@ -336,8 +336,8 @@ def test_build_fleet_directly(self):
def test_instance_field_empty_container_default_factory(self):
@dataclasses.dataclass
class TestCls:
items: List[Any] = pax_fiddle.instance_field(list)
tags: Dict[str, Any] = pax_fiddle.instance_field(dict)
items: list[Any] = pax_fiddle.instance_field(list)
tags: dict[str, Any] = pax_fiddle.instance_field(dict)

cfg = pax_fiddle.Config(TestCls)
self.assertDagEqual(cfg, pax_fiddle.Config(TestCls, items=[], tags={}))
Expand Down Expand Up @@ -394,7 +394,7 @@ class AnAutoconfigType:
tagged_type: ATaggedType = pax_fiddle.field(
default_factory=ATaggedType.default
)
another_default: Dict[str, Any] = pax_fiddle.field(
another_default: dict[str, Any] = pax_fiddle.field(
default_factory=nested_structure
)

Expand Down Expand Up @@ -1017,10 +1017,10 @@ def f1(x: pax_fiddle.Config):
def f2(x: pax_fiddle.Config[Wheel]):
return x

def f3(x: Optional[pax_fiddle.Config[Wheel]]):
def f3(x: pax_fiddle.Config[Wheel] | None):
return x

def f4(x: Union[pax_fiddle.Config, Sequence[pax_fiddle.Config]]):
def f4(x: pax_fiddle.Config | Sequence[pax_fiddle.Config] | None):
return x

for fn in [f1, f2, f3, f4]:
Expand All @@ -1029,16 +1029,17 @@ def f4(x: Union[pax_fiddle.Config, Sequence[pax_fiddle.Config]]):
self.assertDagEqual(result, pax_fiddle.Config(Wheel))

def test_do_not_build_function_args_if_arg_is_pax_config_container(self):
def f1(x: List[pax_fiddle.Config]):

def f1(x: list[pax_fiddle.Config]):
return x

def f2(x: List[pax_fiddle.Config[Wheel]]):
def f2(x: list[pax_fiddle.Config[Wheel]]):
return x

def f3(x: Optional[List[pax_fiddle.Config[Wheel]]]):
def f3(x: list[pax_fiddle.Config[Wheel]] | None):
return x

def f4(x: Union[Sequence[pax_fiddle.Config[Wheel]], pax_fiddle.Config]):
def f4(x: Sequence[pax_fiddle.Config[Wheel]] | pax_fiddle.Config):
return x

for fn in [f1, f2, f3, f4]:
Expand Down
Loading

0 comments on commit c36b8e4

Please sign in to comment.