
Doubt about the attention_mask #6

Closed
caojiangxia opened this issue Oct 9, 2020 · 7 comments

Comments

@caojiangxia

First:
timeline_mask = torch.BoolTensor(log_seqs == 0).to(self.dev) -> timeline_mask = torch.FloatTensor(log_seqs > 0).to(self.dev)

Second:
attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool, device=self.dev)) -> attention_mask = torch.tril(torch.ones((tl, tl), dtype=torch.float, device=self.dev))

Why is there a ~ sign?

@pmixer
Owner

pmixer commented Oct 9, 2020

Hi @caojiangxia, thx for your question! If you mean the negation in attention_mask = ~torch.tril(...), that follows from how attention masks are used in PyTorch's multi-head attention. Pls check https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html, which notes that positions with True are not allowed to attend, while False values are left unchanged (True means "mask it"). So I generate the lower-triangular matrix and negate it, which blacklists the upper-right entries of the attention weight matrix (the ones corresponding to attending to future positions). There are surely other ways to do it.
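To illustrate (a minimal sketch of my own, with tl standing for the sequence length, not the exact repo code):

```python
import torch

tl = 5  # sequence (timeline) length
# Lower-triangular matrix: True where a position IS allowed to attend
# (itself and earlier positions).
tril = torch.tril(torch.ones((tl, tl), dtype=torch.bool))
# nn.MultiheadAttention's attn_mask convention: True = "not allowed to attend",
# so negating tril masks out the upper-right (future) entries.
attention_mask = ~tril
print(attention_mask[0])  # tensor([False,  True,  True,  True,  True])
```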

As for your other concern: I use timeline_mask and ~timeline_mask alternately in the implementation, depending on context. Since item ids are non-negative integers, == 0 and > 0 are both okay as long as the generated mask is used in the right way.
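A tiny example of that equivalence (my own sketch, assuming left-padded sequences with 0 as the padding id):

```python
import torch

log_seqs = torch.tensor([0, 0, 3, 7, 2])  # left-padded item ids, 0 = padding
pad_mask = (log_seqs == 0)                 # True at padding positions
valid_mask = (log_seqs > 0)                # True at real items; equals ~pad_mask

# Two equivalent ways to zero out padded timesteps of an embedding tensor:
emb = torch.randn(5, 8)
out1 = emb * valid_mask.unsqueeze(-1)
out2 = emb.masked_fill(pad_mask.unsqueeze(-1), 0.0)
assert torch.equal(out1, out2)
```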

Moreover, there are two ways of masking attention weights (see the sketch after this list):

  • Additive: add a negative number with a large absolute value to the attention-logit entries that need to be masked, before the softmax, so the final attention weights at those entries come out close to 0.
  • Multiplicative (which is why bool tensors rather than float tensors are used here, as PyTorch requires): this should enforce stricter causality, as implemented in PyTorch recently.
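Roughly, the two options look like this (my own illustration, not PyTorch's actual internals):

```python
import torch

scores = torch.randn(4, 4)                               # raw attention logits
mask = ~torch.tril(torch.ones(4, 4, dtype=torch.bool))   # True above the diagonal

# Additive: push masked logits far negative before softmax. With a very large
# constant the masked weights underflow to zero, but with a moderate one they
# can stay tiny-but-nonzero -> the "leak" discussed below.
additive = torch.softmax(scores.masked_fill(mask, -2**15), dim=-1)

# Multiplicative-style: zero the masked weights after softmax and renormalize,
# so future positions contribute exactly nothing.
w = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
multiplicative = w / w.sum(dim=-1, keepdim=True)
```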

I do not like leaky attention weights that may introduce a bit of future information and break causality, so I chose the second option. The details of how PyTorch implements the multiplicative attention mask are still unclear to me; for my own MHA (https://github.com/pmixer/TiSASRec.pytorch/blob/9ddc21e400254bc352bb2174fd68bc2cf0585c5b/model.py#L67) I still use the additive approach, although I do not feel quite comfortable about its leakage issue.

For more details related to your question, pls check #1

Stay healthy~
Zan

@caojiangxia
Author

Thank you very much! (True means "mask it"; I had misunderstood that.)

@caojiangxia
Author

Thanks for the full explanation!

@caojiangxia
Author

caojiangxia commented Oct 12, 2020

Honestly, in the original author's implementation, the evaluation process doesn't rank all remaining items for each user, which makes the metrics look much higher than the practical performance.

@pmixer
Owner

pmixer commented Oct 12, 2020

@caojiangxia yes, the performance overestimation is a big problem; NDCG & HIT would drop to really low scores if you rank all items. Pls check https://www.kdd.org/kdd2020/accepted-papers/view/on-sampled-metrics-for-item-recommendation for a more detailed study of this issue.
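To make the gap concrete, here's a toy sketch (random scores and a hypothetical helper, just to contrast the 100-sampled-negatives protocol with ranking the full catalog):

```python
import numpy as np

def hr_ndcg_at_10(scores, target_idx):
    """HR@10 and NDCG@10 from the target's rank among the scored candidates."""
    rank = int((scores > scores[target_idx]).sum())  # 0-based rank of the target
    return float(rank < 10), (1.0 / np.log2(rank + 2) if rank < 10 else 0.0)

rng = np.random.default_rng(0)
num_items = 10000
all_scores = rng.normal(size=num_items)  # stand-in for a model's item scores
target = 42

# SASRec-style sampled evaluation: the target plus 100 random negatives.
candidates = np.delete(np.arange(num_items), target)
negatives = rng.choice(candidates, size=100, replace=False)
sampled = np.concatenate(([all_scores[target]], all_scores[negatives]))
print(hr_ndcg_at_10(sampled, 0))          # typically looks much better ...

# Full ranking against the whole catalog: far lower for the same scores.
print(hr_ndcg_at_10(all_scores, target))
```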

@caojiangxia
Author

Does that mean the experimental results reported in current sequential-recommendation papers, such as SASRec and TiSASRec, are unreliable? Convincing experiments on this task are needed.

@pmixer
Owner

pmixer commented Oct 13, 2020

Does that mean the experimental results reported in current sequential-recommendation papers, such as SASRec and TiSASRec, are unreliable? Convincing experiments on this task are needed.

It depends on your definition of reliability. Under the given experimental setting, SASRec and TiSASRec do outperform many other sequential recommenders. But under different experimental settings, like ranking all items all the time, or considering other metrics like mAP rather than NDCG & HIT (as in Caser), many works, including SASRec and TiSASRec, would suffer an abrupt performance drop. It's a well-known issue now. I recommend checking one of the RecSys 2019 best papers: https://dl.acm.org/doi/10.1145/3298689.3347058, plus the KDD 2020 best paper listed before; you may also be interested in ACL 2020's best paper for knowing recent trends: https://github.com/marcotcr/checklist. There's still a long way to go for making reliable systems :)
