Move parallel_state.py to the distributed folder
regisss committed Dec 23, 2024
1 parent f2a9079 commit a6ee7c2
Showing 6 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion optimum/habana/accelerate/accelerator.py
@@ -59,7 +59,7 @@
from accelerate.utils.other import is_compiled_module
from torch.optim.lr_scheduler import LRScheduler

-from .. import parallel_state
+from ..distributed import parallel_state


if is_deepspeed_available():
2 changes: 1 addition & 1 deletion optimum/habana/accelerate/data_loader.py
@@ -22,7 +22,7 @@
)
from torch.utils.data import BatchSampler, DataLoader, IterableDataset

-from .. import parallel_state
+from ..distributed import parallel_state
from .state import GaudiAcceleratorState
from .utils.operations import (
broadcast,
2 changes: 1 addition & 1 deletion optimum/habana/accelerate/state.py
@@ -21,7 +21,7 @@

from optimum.utils import logging

-from .. import parallel_state
+from ..distributed import parallel_state
from .utils import GaudiDistributedType


2 changes: 1 addition & 1 deletion optimum/habana/distributed/contextparallel.py
@@ -1,6 +1,6 @@
import torch

-from ..parallel_state import (
+from .parallel_state import (
get_sequence_parallel_group,
get_sequence_parallel_rank,
get_sequence_parallel_world_size,
optimum/habana/parallel_state.py → optimum/habana/distributed/parallel_state.py
File renamed without changes.
3 changes: 2 additions & 1 deletion optimum/habana/transformers/models/llama/modeling_llama.py
@@ -21,7 +21,8 @@
)
from transformers.utils import is_torchdynamo_compiling

-from .... import distributed, parallel_state
+from .... import distributed
+from ....distributed import parallel_state
from ....distributed.strategy import DistributedStrategy, NoOpStrategy
from ....distributed.tensorparallel import (
reduce_from_tensor_model_parallel_region,
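
For reference, a minimal, hypothetical sketch of what this move means for downstream imports. Only the three getters visible in the contextparallel.py hunk above are assumed to exist; initializing the distributed environment itself is outside the scope of this commit.

    # Old import path (before this commit):
    # from optimum.habana import parallel_state

    # New import path (after this commit):
    from optimum.habana.distributed import parallel_state

    # Assumed symbols, taken from the contextparallel.py hunk above; they
    # presuppose an initialized sequence-parallel process group.
    group = parallel_state.get_sequence_parallel_group()
    rank = parallel_state.get_sequence_parallel_rank()
    world_size = parallel_state.get_sequence_parallel_world_size()
    print(f"sequence-parallel rank {rank} of {world_size}")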
