From 79e56003aca7148209637fa6f7286c8e57473914 Mon Sep 17 00:00:00 2001 From: ken77921 Date: Mon, 13 Nov 2023 01:00:47 +0000 Subject: [PATCH] fix some bugs. update readme --- README.md | 3 ++- recbole/model/sequential_recommender/gru4rec_ours.py | 2 +- recbole/model/sequential_recommender/sasrec.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a3f4878d..7c8c93dd 100644 --- a/README.md +++ b/README.md @@ -48,4 +48,5 @@ Haw-Shiuan Chang, Nikhil Agarwal, and Andrew McCallum. 2024. To Copy, or not to ``` ## License -The code from Recbole is under [MIT License](https://github.com/RUCAIBox/RecBole/blob/master/LICENSE) but our modification (e.g., Softmax-CPR implementation) is under Apache-2.0 license. +All RecBole data and code are under [MIT License](https://github.com/RUCAIBox/RecBole/blob/master/LICENSE) and can only be used for academic purposes (as indicated in https://github.com/RUCAIBox/RecBole#license on 10/25/2023). +Our modification (e.g., Softmax-CPR implementation) is under Apache-2.0 license and could be used for commercial purpose. diff --git a/recbole/model/sequential_recommender/gru4rec_ours.py b/recbole/model/sequential_recommender/gru4rec_ours.py index 0c1e85b4..9959c47d 100644 --- a/recbole/model/sequential_recommender/gru4rec_ours.py +++ b/recbole/model/sequential_recommender/gru4rec_ours.py @@ -459,7 +459,7 @@ def calculate_loss_prob(self, interaction): #logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) #loss = self.loss_fct(logits, pos_items) #return loss - return loss, prediction_prob.squeeze(dim=0) + return loss, prediction_prob.squeeze(dim=1) def calculate_loss(self, interaction): loss, prediction_prob = self.calculate_loss_prob(interaction) diff --git a/recbole/model/sequential_recommender/sasrec.py b/recbole/model/sequential_recommender/sasrec.py index 41c5bffc..272d1cfb 100644 --- a/recbole/model/sequential_recommender/sasrec.py +++ b/recbole/model/sequential_recommender/sasrec.py @@ -471,7 +471,7 @@ def calculate_loss_prob(self, interaction): raise Exception("Labels can not be None") #logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) #loss = self.loss_fct(logits, pos_items) - return loss, prediction_prob.squeeze(dim=0) + return loss, prediction_prob.squeeze(dim=1) def calculate_loss(self, interaction): loss, prediction_prob = self.calculate_loss_prob(interaction)