freeze models #379

Draft · wants to merge 3 commits into main
131 changes: 49 additions & 82 deletions RecommenderSystems/dlrm/dlrm_train_eval.py
@@ -27,9 +27,7 @@
warnings.filterwarnings("ignore", category=FutureWarning)
from petastorm.reader import make_batch_reader

sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
)
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))


def get_args(print_args=True):
@@ -41,9 +39,7 @@ def str_list(x):

parser = argparse.ArgumentParser()

parser.add_argument(
"--disable_fusedmlp", action="store_true", help="disable fused MLP or not"
)
parser.add_argument("--disable_fusedmlp", action="store_true", help="disable fused MLP or not")
parser.add_argument("--embedding_vec_size", type=int, default=128)
parser.add_argument(
"--one_embedding_key_type",
@@ -69,19 +65,13 @@ def str_list(x):
parser.add_argument("--model_load_dir", type=str, default=None)
parser.add_argument("--model_save_dir", type=str, default=None)
parser.add_argument(
"--save_initial_model",
action="store_true",
help="save initial model parameters or not.",
"--save_initial_model", action="store_true", help="save initial model parameters or not.",
)
parser.add_argument(
"--save_model_after_each_eval",
action="store_true",
help="save model after each eval.",
"--save_model_after_each_eval", action="store_true", help="save model after each eval.",
)
parser.add_argument("--data_dir", type=str, required=True)
parser.add_argument(
"--eval_batches", type=int, default=1612, help="number of eval batches"
)
parser.add_argument("--eval_batches", type=int, default=1612, help="number of eval batches")
parser.add_argument("--eval_batch_size", type=int, default=55296)
parser.add_argument("--eval_interval", type=int, default=10000)
parser.add_argument("--train_batch_size", type=int, default=55296)
@@ -98,16 +88,15 @@ def str_list(x):
help="Embedding table size array for sparse fields",
)
parser.add_argument(
"--persistent_path",
type=str,
required=True,
help="path for persistent kv store",
"--persistent_path", type=str, required=True, help="path for persistent kv store",
)
parser.add_argument("--store_type", type=str, default="cached_host_mem")
parser.add_argument("--cache_memory_budget_mb", type=int, default=8192)
parser.add_argument("--amp", action="store_true", help="Run model with amp")
parser.add_argument("--loss_scale_policy", type=str, default="static", help="static or dynamic")
parser.add_argument("--freeze_embedding", action="store_true", help="freeze embedding model")
parser.add_argument(
"--loss_scale_policy", type=str, default="static", help="static or dynamic"
"--freeze_backbone", action="store_true", help="freeze backbone models except embedding"
)

args = parser.parse_args()
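
The two new flags are meant to be mutually exclusive; the PR enforces this with an assert inside DLRMModule further down. As a minimal sketch, the same guard could run at parse time instead. Hypothetical, not part of this diff:

# Hypothetical parse-time guard; the PR asserts inside DLRMModule instead.
# `parser` and `args` are the objects built just above.
if args.freeze_embedding and args.freeze_backbone:
    parser.error("--freeze_embedding and --freeze_backbone are mutually exclusive")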
@@ -192,9 +181,7 @@ def get_batches(self, reader, batch_size=None):
pos = batch_size - len(tail[0])
tail = list(
[
np.concatenate(
(tail[i], rglist[i][0 : (batch_size - len(tail[i]))])
)
np.concatenate((tail[i], rglist[i][0 : (batch_size - len(tail[i]))]))
for i in range(self.C_end)
]
)
@@ -209,13 +196,8 @@ def get_batches(self, reader, batch_size=None):
continue
while (pos + batch_size) <= len(rglist[0]):
label = rglist[0][pos : pos + batch_size]
dense = [
rglist[j][pos : pos + batch_size] for j in range(1, self.I_end)
]
sparse = [
rglist[j][pos : pos + batch_size]
for j in range(self.I_end, self.C_end)
]
dense = [rglist[j][pos : pos + batch_size] for j in range(1, self.I_end)]
sparse = [rglist[j][pos : pos + batch_size] for j in range(self.I_end, self.C_end)]
pos += batch_size
yield label, np.stack(dense, axis=-1), np.stack(sparse, axis=-1)
if pos != len(rglist[0]):
@@ -237,7 +219,12 @@ def forward(self, x: flow.Tensor) -> flow.Tensor:

class MLP(nn.Module):
def __init__(
self, in_features: int, hidden_units, skip_final_activation=False, fused=True
self,
in_features: int,
hidden_units,
skip_final_activation=False,
fused=True,
freeze_model=False,
) -> None:
super(MLP, self).__init__()
if fused:
@@ -251,11 +238,7 @@ def __init__(
units = [in_features] + hidden_units
num_layers = len(hidden_units)
denses = [
Dense(
units[i],
units[i + 1],
not skip_final_activation or (i + 1) < num_layers,
)
Dense(units[i], units[i + 1], not skip_final_activation or (i + 1) < num_layers,)
for i in range(num_layers)
]
self.linear_layers = nn.Sequential(*denses)
@@ -265,6 +248,9 @@ def __init__(
nn.init.normal_(param, 0.0, np.sqrt(2 / sum(param.shape)))
elif "bias" in name:
nn.init.normal_(param, 0.0, np.sqrt(1 / param.shape[0]))
if freeze_model:
print(f"Model {name} in MLP is freezed")
param.requires_grad = False

def forward(self, x: flow.Tensor) -> flow.Tensor:
return self.linear_layers(x)
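
The freeze mechanism above is nothing more than requires_grad = False applied while walking named_parameters(). A standalone sketch of the effect, using a toy module rather than this repo's MLP and assuming OneFlow's PyTorch-style autograd API (nn.Linear, flow.randn, Tensor.grad):

import oneflow as flow
import oneflow.nn as nn

# Toy stand-in for the bottom/top MLP; freezing works the same way.
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
for _, param in toy[0].named_parameters():
    param.requires_grad = False  # the same call this PR makes when freeze_model=True

loss = toy(flow.randn(2, 4)).sum()
loss.backward()
print(toy[0].weight.grad)              # None: the frozen layer gets no gradient
print(toy[2].weight.grad is not None)  # True: the trainable layer still does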
@@ -280,15 +266,9 @@ def __init__(
):
super(Interaction, self).__init__()
self.interaction_itself = interaction_itself
n_cols = (
num_embedding_fields + 2
if self.interaction_itself
else num_embedding_fields + 1
)
n_cols = num_embedding_fields + 2 if self.interaction_itself else num_embedding_fields + 1
output_size = dense_feature_size + sum(range(n_cols))
self.output_size = (
((output_size + 8 - 1) // 8 * 8) if interaction_padding else output_size
)
self.output_size = ((output_size + 8 - 1) // 8 * 8) if interaction_padding else output_size
self.output_padding = self.output_size - output_size

def forward(self, x: flow.Tensor, y: flow.Tensor) -> flow.Tensor:
@@ -310,6 +290,7 @@ def __init__(
store_type,
cache_memory_budget_mb,
key_type,
freeze_model,
):
assert table_size_array is not None
vocab_size = sum(table_size_array)
@@ -352,6 +333,10 @@ def __init__(
tables=tables,
store_options=store_options,
)
for name, param in self.one_embedding.named_parameters():
if freeze_model:
print(f"Model {name} in OneEmbedding is freezed")
param.requires_grad = False

def forward(self, ids):
return self.one_embedding.forward(ids)
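
OneEmbedding is frozen with the same requires_grad pattern. The diff does not show whether the optimizer helpers filter frozen parameters; autograd already skips them, but a common companion pattern (an assumption here, not something this PR does) is to hand the optimizer only the trainable subset:

# Hypothetical optimizer construction that excludes frozen parameters.
# `dlrm_module` comes from make_dlrm_module(args); `args.learning_rate` is assumed.
trainable_params = [p for p in dlrm_module.parameters() if p.requires_grad]
opt = flow.optim.SGD(trainable_params, lr=args.learning_rate)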
@@ -372,8 +357,11 @@ def __init__(
interaction_itself=True,
interaction_padding=True,
dense_input_padding=True,
freeze_embedding=False,
freeze_backbone=False,
):
super(DLRMModule, self).__init__()
assert not (freeze_embedding and freeze_backbone), "Freezing all models is not allowed."
assert (
embedding_vec_size == bottom_mlp[-1]
), "Embedding vector size must equle to bottom MLP output size"
@@ -386,14 +374,17 @@ def __init__(
else None
)

self.bottom_mlp = MLP(self.num_dense_fields, bottom_mlp, fused=use_fusedmlp)
self.bottom_mlp = MLP(
self.num_dense_fields, bottom_mlp, fused=use_fusedmlp, freeze_model=freeze_backbone
)
self.embedding = OneEmbedding(
embedding_vec_size,
persistent_path,
table_size_array,
one_embedding_store_type,
cache_memory_budget_mb,
one_embedding_key_type,
freeze_embedding,
)
self.interaction = Interaction(
bottom_mlp[-1],
@@ -406,6 +397,7 @@ def __init__(
top_mlp + [1],
skip_final_activation=True,
fused=use_fusedmlp,
freeze_model=freeze_backbone,
)

def forward(self, dense_fields, sparse_fields) -> flow.Tensor:
@@ -432,6 +424,8 @@ def make_dlrm_module(args):
interaction_itself=args.interaction_itself,
interaction_padding=not args.disable_interaction_padding,
dense_input_padding=not args.disable_dense_input_padding,
freeze_embedding=args.freeze_embedding,
freeze_backbone=args.freeze_backbone,
)
return model
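
A quick smoke test for the new wiring, hypothetical and assuming an args namespace from get_args plus OneEmbedding exposing its tables through named_parameters (which the freeze loop above relies on):

# Hypothetical check: build with one freeze flag set and list what got frozen.
args.freeze_embedding = True
args.freeze_backbone = False
model = make_dlrm_module(args)
frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
print(f"{len(frozen)} frozen parameter tensors: {frozen}")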

@@ -462,11 +456,7 @@ def make_lr_scheduler(args, optimizer):
optimizer, start_factor=0, total_iters=args.warmup_batches,
)
poly_decay_lr = flow.optim.lr_scheduler.PolynomialLR(
optimizer,
decay_batch=args.decay_batches,
end_learning_rate=0,
power=2.0,
cycle=False,
optimizer, decay_batch=args.decay_batches, end_learning_rate=0, power=2.0, cycle=False,
)
sequential_lr = flow.optim.lr_scheduler.SequentialLR(
optimizer=optimizer,
@@ -491,13 +481,7 @@ def build(self, dense_fields, sparse_fields):

class DLRMTrainGraph(flow.nn.Graph):
def __init__(
self,
dlrm_module,
loss,
optimizer,
lr_scheduler=None,
grad_scaler=None,
amp=False,
self, dlrm_module, loss, optimizer, lr_scheduler=None, grad_scaler=None, amp=False,
):
super(DLRMTrainGraph, self).__init__()
self.module = dlrm_module
@@ -522,9 +506,7 @@ def prefetch_eval_batches(data_dir, batch_size, num_batches):
cached_eval_batches = []
with make_criteo_dataloader(data_dir, batch_size, shuffle=False) as loader:
for _ in range(num_batches):
label, dense_fields, sparse_fields = batch_to_global(
*next(loader), is_train=False
)
label, dense_fields, sparse_fields = batch_to_global(*next(loader), is_train=False)
cached_eval_batches.append((label, dense_fields, sparse_fields))
return cached_eval_batches

@@ -560,25 +542,18 @@ def save_model(subdir):
grad_scaler = flow.amp.StaticGradScaler(1024)
else:
grad_scaler = flow.amp.GradScaler(
init_scale=1073741824,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
init_scale=1073741824, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000,
)

eval_graph = DLRMValGraph(dlrm_module, args.amp)
train_graph = DLRMTrainGraph(
dlrm_module, loss, opt, lr_scheduler, grad_scaler, args.amp
)
train_graph = DLRMTrainGraph(dlrm_module, loss, opt, lr_scheduler, grad_scaler, args.amp)

cached_eval_batches = prefetch_eval_batches(
f"{args.data_dir}/test", args.eval_batch_size, args.eval_batches
)

dlrm_module.train()
with make_criteo_dataloader(
f"{args.data_dir}/train", args.train_batch_size
) as loader:
with make_criteo_dataloader(f"{args.data_dir}/train", args.train_batch_size) as loader:
step, last_step, last_time = -1, 0, time.time()
for step in range(1, args.train_batches + 1):
labels, dense_fields, sparse_fields = batch_to_global(*next(loader))
@@ -610,15 +585,11 @@ def save_model(subdir):

def np_to_global(np):
t = flow.from_numpy(np)
return t.to_global(
placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)
)
return t.to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0))


def batch_to_global(np_label, np_dense, np_sparse, is_train=True):
labels = (
np_to_global(np_label.reshape(-1, 1)) if is_train else np_label.reshape(-1, 1)
)
labels = np_to_global(np_label.reshape(-1, 1)) if is_train else np_label.reshape(-1, 1)
dense_fields = np_to_global(np_dense)
sparse_fields = np_to_global(np_sparse)
return labels, dense_fields, sparse_fields
@@ -638,15 +609,11 @@ def eval(cached_eval_batches, eval_graph, cur_step=0):
preds.append(pred.to_local())

labels = (
np_to_global(np.concatenate(labels, axis=0))
.to_global(sbp=flow.sbp.broadcast())
.to_local()
np_to_global(np.concatenate(labels, axis=0)).to_global(sbp=flow.sbp.broadcast()).to_local()
)
preds = (
flow.cat(preds, dim=0)
.to_global(
placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)
)
.to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0))
.to_global(sbp=flow.sbp.broadcast())
.to_local()
)