Skip to content

Commit

Permalink
graph base
Browse files Browse the repository at this point in the history
  • Loading branch information
ShawnXuan committed Aug 21, 2024
1 parent c8f64a9 commit 0fb057d
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion libai/models/utils/graph_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ def __init__(
is_train=True,
auto_parallel_conf=None,
global_mode=None,
device="cuda",
):
super().__init__()

self.model = model
self.is_train = is_train
self.global_mode = global_mode
self.device = device

if is_train:
self.add_optimizer(optimizer, lr_sch=lr_scheduler)
Expand Down Expand Up @@ -103,7 +105,7 @@ def build(self, **kwargs):
if self.is_train:
placement_sbp_dict = (
dict(
placement=flow.env.all_device_placement("cuda"),
placement=flow.env.all_device_placement(self.device),
sbp=flow.sbp.split(0),
)
if self.global_mode.enabled
Expand Down

0 comments on commit 0fb057d

Please sign in to comment.