From 0fb057d3868a93af502dfbf3ea6acf9dd8aaebe7 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Wed, 21 Aug 2024 14:22:26 +0800 Subject: [PATCH] graph base --- libai/models/utils/graph_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/libai/models/utils/graph_base.py b/libai/models/utils/graph_base.py index 651209ccd..316a70ed7 100644 --- a/libai/models/utils/graph_base.py +++ b/libai/models/utils/graph_base.py @@ -39,12 +39,14 @@ def __init__( is_train=True, auto_parallel_conf=None, global_mode=None, + device="cuda", ): super().__init__() self.model = model self.is_train = is_train self.global_mode = global_mode + self.device = device if is_train: self.add_optimizer(optimizer, lr_sch=lr_scheduler) @@ -103,7 +105,7 @@ def build(self, **kwargs): if self.is_train: placement_sbp_dict = ( dict( - placement=flow.env.all_device_placement("cuda"), + placement=flow.env.all_device_placement(self.device), sbp=flow.sbp.split(0), ) if self.global_mode.enabled