Commit

fix some bugs
ken77921 committed Nov 13, 2023
1 parent bea4da4 commit 04a0ead
Showing 4 changed files with 20 additions and 19 deletions.
29 changes: 15 additions & 14 deletions recbole/model/sequential_recommender/gru4rec_ours.py
@@ -67,9 +67,21 @@ def __init__(self, config, dataset):
         self.use_att = config['use_att'] #added for mfs
         self.only_compute_loss = True #added for mfs
         if self.use_att:
-            self.n_embd = 2* self.hidden_size #added for mfs
+            assert self.use_out_emb
+            self.dropout = nn.Dropout(self.dropout_prob)
+            self.We = nn.Linear(self.hidden_size, self.hidden_size)
+            self.Ue = nn.Linear(self.hidden_size, self.hidden_size)
+            self.tanh = nn.Tanh()
+            self.Ve = nn.Linear(self.hidden_size, 1)
+            out_size = 2*self.hidden_size
         else:
-            self.n_embd = self.hidden_size #added for mfs
+            self.dense = nn.Linear(self.hidden_size, self.embedding_size)
+            out_size = self.embedding_size
+        self.n_embd = out_size
+        #if self.use_att:
+        #    self.n_embd = 2* self.hidden_size #added for mfs
+        #else:
+        #    self.n_embd = self.hidden_size #added for mfs
 
         self.use_proj_bias = config['use_proj_bias'] #added for mfs
         self.weight_mode = config['weight_mode'] #added for mfs
@@ -129,17 +141,6 @@ def __init__(self, config, dataset):
         )
         self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
 
-        if self.use_att:
-            assert self.use_out_emb
-            self.dropout = nn.Dropout(self.dropout_prob)
-            self.We = nn.Linear(self.hidden_size, self.hidden_size)
-            self.Ue = nn.Linear(self.hidden_size, self.hidden_size)
-            self.tanh = nn.Tanh()
-            self.Ve = nn.Linear(self.hidden_size, 1)
-            out_size = 2*self.hidden_size
-        else:
-            self.dense = nn.Linear(self.hidden_size, self.embedding_size)
-            out_size = self.hidden_size
 
         if self.use_out_emb:
             self.out_item_embedding = nn.Linear(out_size, self.n_items, bias = False)
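
Taken together, the two hunks above consolidate the model setup: the attention layers (We, Ue, Ve) and the dense projection are now created in one place earlier in __init__, the no-attention branch's out_size becomes self.embedding_size (matching the output of self.dense instead of the old self.hidden_size), and self.n_embd simply tracks out_size. As an illustration only, assuming NARM-style additive attention (which this repository's forward pass may or may not use), here is how We, Ue, and Ve are typically combined, and why concatenating the attended context with the last hidden state yields the 2*hidden_size output:

# A minimal sketch, not the repository's exact forward pass.
import torch
import torch.nn as nn

hidden_size, batch, seq_len = 64, 8, 20
We = nn.Linear(hidden_size, hidden_size)
Ue = nn.Linear(hidden_size, hidden_size)
Ve = nn.Linear(hidden_size, 1)

gru_output = torch.randn(batch, seq_len, hidden_size)   # all GRU hidden states
h_last = gru_output[:, -1, :]                            # last hidden state as the query

scores = Ve(torch.tanh(We(h_last).unsqueeze(1) + Ue(gru_output)))  # (batch, seq_len, 1)
alpha = torch.softmax(scores, dim=1)                                # attention weights
context = (alpha * gru_output).sum(dim=1)                           # (batch, hidden_size)

seq_output = torch.cat([context, h_last], dim=-1)        # (batch, 2 * hidden_size) == out_size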
@@ -458,7 +459,7 @@ def calculate_loss_prob(self, interaction):
         #logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
         #loss = self.loss_fct(logits, pos_items)
         #return loss
-        return loss, prediction_prob.squeeze()
+        return loss, prediction_prob.squeeze(dim=0)
 
     def calculate_loss(self, interaction):
         loss, prediction_prob = self.calculate_loss_prob(interaction)
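
The functional change in this hunk (and in the matching hunk in sasrec.py below) is squeeze() → squeeze(dim=0). A bare squeeze() drops every dimension of size 1, so if any other dimension of prediction_prob happens to be 1, for example a final batch containing a single interaction, it is collapsed as well and downstream shape assumptions break; squeeze(dim=0) only ever removes the leading dimension. The exact shape of prediction_prob here is an assumption, but the behaviour is easy to see:

import torch

prediction_prob = torch.rand(1, 1, 50)        # hypothetical: leading singleton dim, batch of 1, 50 items

print(prediction_prob.squeeze().shape)        # torch.Size([50])     -- the batch dimension is lost too
print(prediction_prob.squeeze(dim=0).shape)   # torch.Size([1, 50])  -- only the leading dimension is removed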
2 changes: 1 addition & 1 deletion recbole/model/sequential_recommender/sasrec.py
@@ -471,7 +471,7 @@ def calculate_loss_prob(self, interaction):
             raise Exception("Labels can not be None")
         #logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
         #loss = self.loss_fct(logits, pos_items)
-        return loss, prediction_prob.squeeze()
+        return loss, prediction_prob.squeeze(dim=0)
 
     def calculate_loss(self, interaction):
         loss, prediction_prob = self.calculate_loss_prob(interaction)
4 changes: 2 additions & 2 deletions run_hyper_loop.sh
@@ -70,9 +70,9 @@ elif [[ "$softmax_mode" == "softmax_C" ]]; then
elif [[ "$softmax_mode" == "softmax_CP" ]]; then
model_config="--n_facet=1+--n_facet_all=4+--n_facet_context=1+--n_facet_emb=2+--n_facet_hidden=1+--n_facet_window=0+--n_facet_MLP=0+--context_norm=1"
elif [[ "$softmax_mode" == "softmax_CPR:100" ]]; then
model_config="--n_facet=1+--n_facet_context=1+--n_facet_reranker=1+--n_facet_emb=2+--n_facet_all=7+--n_facet_hidden=1+--n_facet_window=0+--n_facet_MLP=0+--context_norm=1+--reranker_CAN_NUM=100+--reranker_merging_mode=replace"
model_config="--n_facet=1+--n_facet_context=1+--n_facet_reranker=1+--n_facet_emb=2+--n_facet_all=5+--n_facet_hidden=1+--n_facet_window=0+--n_facet_MLP=0+--context_norm=1+--reranker_CAN_NUM=100+--reranker_merging_mode=replace"
elif [[ "$softmax_mode" == "softmax_CPR:100_Mi" ]]; then
model_config="--n_facet=1+--n_facet_context=1+--n_facet_reranker=1+--n_facet_emb=2+--n_facet_all=7+--n_facet_hidden=2+--n_facet_window=-2+--n_facet_MLP=-1+--context_norm=1+--reranker_CAN_NUM=100+--reranker_merging_mode=replace"
model_config="--n_facet=1+--n_facet_context=1+--n_facet_reranker=1+--n_facet_emb=2+--n_facet_all=5+--n_facet_hidden=2+--n_facet_window=-2+--n_facet_MLP=-1+--context_norm=1+--reranker_CAN_NUM=100+--reranker_merging_mode=replace"
elif [[ "$softmax_mode" == "softmax_CPR:20,100,500_Mi" ]]; then
model_config="--n_facet=1+--n_facet_context=1+--n_facet_reranker=1+--n_facet_emb=2+--n_facet_all=7+--n_facet_hidden=2+--n_facet_window=-2+--n_facet_MLP=-1+--context_norm=1+--reranker_merging_mode=replace"
elif [[ "$softmax_mode" == "MoS" ]]; then
4 changes: 2 additions & 2 deletions run_hyper_slurm.sh
@@ -70,9 +70,9 @@ elif [[ "$softmax_mode" == "softmax_C" ]]; then
elif [[ "$softmax_mode" == "softmax_CP" ]]; then
model_config="--n_facet=1+--n_facet_all=4+--n_facet_context=1+--n_facet_emb=2+--n_facet_hidden=1+--n_facet_window=0+--n_facet_MLP=0+--context_norm=1"
elif [[ "$softmax_mode" == "softmax_CPR:100" ]]; then
model_config="--n_facet=1+--n_facet_context=1+--n_facet_reranker=1+--n_facet_emb=2+--n_facet_all=7+--n_facet_hidden=1+--n_facet_window=0+--n_facet_MLP=0+--context_norm=1+--reranker_CAN_NUM=100+--reranker_merging_mode=replace"
model_config="--n_facet=1+--n_facet_context=1+--n_facet_reranker=1+--n_facet_emb=2+--n_facet_all=5+--n_facet_hidden=1+--n_facet_window=0+--n_facet_MLP=0+--context_norm=1+--reranker_CAN_NUM=100+--reranker_merging_mode=replace"
elif [[ "$softmax_mode" == "softmax_CPR:100_Mi" ]]; then
model_config="--n_facet=1+--n_facet_context=1+--n_facet_reranker=1+--n_facet_emb=2+--n_facet_all=7+--n_facet_hidden=2+--n_facet_window=-2+--n_facet_MLP=-1+--context_norm=1+--reranker_CAN_NUM=100+--reranker_merging_mode=replace"
model_config="--n_facet=1+--n_facet_context=1+--n_facet_reranker=1+--n_facet_emb=2+--n_facet_all=5+--n_facet_hidden=2+--n_facet_window=-2+--n_facet_MLP=-1+--context_norm=1+--reranker_CAN_NUM=100+--reranker_merging_mode=replace"
elif [[ "$softmax_mode" == "softmax_CPR:20,100,500_Mi" ]]; then
model_config="--n_facet=1+--n_facet_context=1+--n_facet_reranker=1+--n_facet_emb=2+--n_facet_all=7+--n_facet_hidden=2+--n_facet_window=-2+--n_facet_MLP=-1+--context_norm=1+--reranker_merging_mode=replace"
elif [[ "$softmax_mode" == "MoS" ]]; then
