Debias #23
base: main
recstudio/model/mf/dice.py
Outdated
def _get_query_encoder(self, train_data):
    int = torch.nn.Embedding(train_data.num_users, self.embed_dim, padding_idx=0)
    pop = torch.nn.Embedding(train_data.num_users, self.embed_dim, padding_idx=0)
    class DICEQueryEncoder(torch.nn.Module):
Could the int and pop embeddings in the query encoder be defined as a single Embedding with dimension 2*self.embed_dim? @pepsi2222
- class DICEQueryEncoder(torch.nn.Module):
+ torch.nn.Embedding(train_data.num_users, 2*self.embed_dim, padding_idx=0)
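A minimal sketch of what this suggestion could look like, assuming the two halves are only ever used concatenated (as in the original forward). The sizes and the batch below are made up for illustration, and whether a bare Embedding fits the rest of the model's encoder interface is an assumption.

```python
import torch

num_users, embed_dim = 5, 4  # hypothetical sizes for illustration

# One table holds both halves: the first embed_dim columns play the role of
# the interest embedding, the last embed_dim columns the popularity embedding.
query_encoder = torch.nn.Embedding(num_users, 2 * embed_dim, padding_idx=0)

batch = torch.tensor([1, 3, 0])                # user ids; 0 is the padding id
query = query_encoder(batch)                   # shape: (3, 2 * embed_dim)
query_int, query_pop = query.chunk(2, dim=-1)  # each: (3, embed_dim)
```

Splitting once with chunk gives the same layout as concatenating two separate embeddings along the last dimension, so downstream code that expects the concatenated vector would not need to change.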
recstudio/model/mf/dice.py
Outdated
        self.pop = pop
    def forward(self, batch):
        return torch.cat((self.int(batch), self.pop(batch)), dim=-1)
return DICEItemEncoder(int, pop)
Same as the comment on the query encoder above.
Nice job! But some changes should be made according to the comments.
recstudio/model/mf/dice.py
Outdated
    return output

def _get_sampler(self, train_data):
    class PopularSamplerWithMargin(Sampler):
To be discussed. I'm confused about what pool means, and why negative items are sampled only from pop or unpop items when their size is smaller than pool. @pepsi2222
recstudio/model/mf/dice.py
Outdated
output['mask'] = mask
output['score'] = {'pos_int_score': pos_int_score, 'pos_pop_score': pos_pop_score, 'pos_click_score': pos_click_score,
                   'neg_int_score': neg_int_score, 'neg_pop_score': neg_pop_score, 'neg_click_score': neg_click_score}
output['query'] = {'query_int': query.chunk(2, -1)[0], 'query_pop': query.chunk(2, -1)[1]}
The chunk() operations are duplicated here: query_int and query_pop are already defined in line 117.
- output['query'] = {'query_int': query.chunk(2, -1)[0], 'query_pop': query.chunk(2, -1)[1]}
+ output['query'] = {'query_int': query_int, 'query_pop': query_pop}
recstudio/model/mf/dice.py
Outdated
from recstudio.model.mf.bpr import BPR
from recstudio.model import basemodel, loss_func
import time
You'd better add a comment here with the title and URL of the paper corresponding to the model. @pepsi2222
recstudio/data/advance_dataset.py
Outdated
@@ -12,7 +12,8 @@ class ALSDataset(MFDataset):
     So the data provided should be ``<u, Iu>`` and ``<i, Ui>`` alternatively.
     """

-    def build(self, split_ratio, shuffle=True, split_mode='user_entry', **kwargs):
+    def build(self, split_ratio, shuffle=True, split_mode='user_entry', excluding_hist=False, **kwargs):
+        self.excluding_hist = excluding_hist
I think excluding_hist is not a good name here. How about return_hist? @pepsi2222 @Xiuchen519
recstudio/model/mf/dice.py
Outdated

if num_pop_items < self.pool:
    for cnt in range(num_neg):
        idx = torch.randint(num_unpop_items, (1,))
Why not use torch.randint(num_unpop_items, (num_neg,)) instead of the for loop?
- idx = torch.randint(num_unpop_items, (1,))
+ idx = torch.randint(num_unpop_items, (num_neg,))
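To illustrate the suggestion, here is a small sketch comparing the loop with the single call. The sizes are hypothetical, and the loop body is reduced to the index draw only, since the rest of the sampler is not shown in this thread.

```python
import torch

num_unpop_items, num_neg = 100, 8  # hypothetical sizes for illustration

# Loop version: one torch.randint call per negative sample.
loop_idx = torch.cat([torch.randint(num_unpop_items, (1,)) for _ in range(num_neg)])

# Vectorized version: a single call draws all num_neg indices at once.
vec_idx = torch.randint(num_unpop_items, (num_neg,))
```

Both versions draw uniformly with replacement from [0, num_unpop_items), so the sampling distribution is unchanged; only the number of calls differs.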
recstudio/model/mf/expomf.py
Outdated
# data to device
batch = self._to_device(batch, self.device)
# update latent user/item factors
a = self._expectation(batch)
Maybe the method does not need to be overridden. You can add a conditional statement in the training_step method to implement the EM algorithm as below:
if batch_idx % 2 == 0:
    do expectation
else:
    do maximization
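The idea can be sketched with a toy class. ToyEMModel, _maximization, and the training_step(batch, batch_idx) signature are hypothetical stand-ins for illustration, not RecStudio's actual API; only _expectation appears in the diff above.

```python
class ToyEMModel:
    """Hypothetical stand-in for a model; records which step ran."""

    def __init__(self):
        self.calls = []

    def _expectation(self, batch):
        # Placeholder for the E-step.
        self.calls.append("E")
        return batch

    def _maximization(self, batch):
        # Placeholder for the M-step (hypothetical name).
        self.calls.append("M")
        return batch

    def training_step(self, batch, batch_idx):
        # Even batch indices run the E-step, odd ones run the M-step.
        if batch_idx % 2 == 0:
            return self._expectation(batch)
        return self._maximization(batch)


model = ToyEMModel()
for batch_idx, batch in enumerate([[1, 2], [3, 4], [5, 6], [7, 8]]):
    model.training_step(batch, batch_idx)
```

Note that with this alternation the E-step and the M-step see different batches, which may or may not match the behaviour of the overridden loop.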
recstudio/model/mf/pda.py
Outdated
excluding_hist=self.config.get('excluding_hist', False),
method=self.config.get('sampling_method', 'none'), return_query=True)
pos_score = self.score_func(query, pos_item_vec)
pos_score = pos_item_vec.split([pos_item_vec.shape[-1]-1, 1], dim=-1)[1] ** self.config['gamma'] * pos_score #
I wonder whether this line is duplicated, because the operation has already been done in the scorer in line 53. @pepsi2222
If so, the method doesn't need to be overridden.
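A rough sketch of why the duplication would matter, assuming the item vector carries popularity in its last column and the scorer is an inner product on the remaining columns. The function name, the tensor values, and gamma below are all hypothetical.

```python
import torch

gamma = 0.5  # hypothetical value of config['gamma']

# Assumed layout: item vector = [embedding (3 dims), popularity (1 dim)].
pos_item_vec = torch.tensor([[1.0, 0.0, 2.0, 4.0],
                             [0.5, 1.0, 0.0, 9.0]])
query = torch.tensor([[1.0, 1.0, 1.0],
                      [2.0, 0.0, 1.0]])


def pop_weighted_score(query, item_vec, gamma):
    # Split off the trailing popularity column, then scale the inner-product
    # score by popularity ** gamma.
    emb, pop = item_vec.split([item_vec.shape[-1] - 1, 1], dim=-1)
    score = (query * emb).sum(-1)
    return pop.squeeze(-1) ** gamma * score


score = pop_weighted_score(query, pos_item_vec, gamma)

# Applying the popularity factor a second time, as in the duplicated line,
# would weight the score by popularity ** (2 * gamma) instead.
pop = pos_item_vec[:, -1]
double_weighted = pop ** gamma * score
```

If the scorer already applies the popularity factor, doing it again in the overridden method changes the effective exponent, so removing the override would also avoid that.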
recstudio/model/mf/pmf.py
Outdated
def _get_query_encoder(self, train_data):
    return torch.nn.Embedding(train_data.num_users, self.embed_dim, padding_idx=0)

def _init_parameter(self):
You can define init_method: normal and init_range: 0.1 in pmf.yaml without overriding the method.
[feat&fix] add ExpoMF, PDA, DICE