-
-
Notifications
You must be signed in to change notification settings - Fork 404
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Classification with x-transformers #264
base: main
Are you sure you want to change the base?
Conversation
x_transformers/x_transformers.py
Outdated
x = x[:, 0] | ||
|
||
if self.use_pooling: | ||
x = self.pooling(x).squeeze() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for the pooling, we need to account for masking (masked averaging)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i can take care of this if you'd like, it is all around a bit tricky
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, please do so. Thank you 👍
@RyanKim17920 do you want to try the latest changes and see if that's enough? |
@RyanKim17920 hey Ryan, sorry for hijacking your efforts, just that the project is at a size where things need to be a bit more particular your example should run now as import torch
from torch import nn
from x_transformers import (
TransformerWrapper,
Encoder
)
# CLS token test
transformer = TransformerWrapper(
num_tokens=6,
max_seq_len=10,
logits_dim=2, # num_classes
use_cls_token=True,
attn_layers = Encoder(
dim = 6,
depth = 1,
heads = 2,
)
)
x = torch.randint(0, 5, (2, 10))
y = torch.tensor([0, 1])
print(x.shape)
logits = transformer(x)
print(logits.shape)
loss = nn.CrossEntropyLoss()(logits, y)
print(loss)
# BCE cls token
transformer = TransformerWrapper(
num_tokens=6,
max_seq_len=10,
logits_dim=1, # num_classes
use_cls_token=True,
squeeze_out_last_dim = True,
attn_layers = Encoder(
dim = 6,
depth = 1,
heads = 2,
)
)
x = torch.randint(0, 5, (2, 10)).float()
y = torch.tensor([0, 1]).float()
print(x.shape)
logits = transformer(x).squeeze()
loss = nn.BCEWithLogitsLoss()(logits, y)
print(loss)
# pooling test
transformer = TransformerWrapper(
num_tokens=6,
max_seq_len=10,
logits_dim=2, # num_classes
average_pool_embed = True,
attn_layers = Encoder(
dim = 6,
depth = 1,
heads = 2,
)
)
x = torch.randint(0, 5, (2, 10))
y = torch.tensor([0, 1])
print(x.shape)
logits = transformer(x)
print(logits.shape)
loss = nn.CrossEntropyLoss()(logits, y)
print(loss)
# pooling BCE test
# pooling test
transformer = TransformerWrapper(
num_tokens=6,
max_seq_len=10,
logits_dim=1, # num_classes
average_pool_embed = True,
squeeze_out_last_dim = True,
attn_layers = Encoder(
dim = 6,
depth = 1,
heads = 2,
)
)
x = torch.randint(0, 5, (2, 10)).float()
y = torch.tensor([0, 1]).float()
print(x.shape)
logits = transformer(x).squeeze()
print(logits.shape)
loss = nn.BCEWithLogitsLoss()(logits, y)
print(loss)
# normal test
transformer = TransformerWrapper(
num_tokens=6,
max_seq_len=10,
logits_dim=2, # num_classes
average_pool_embed = True,
attn_layers = Encoder(
dim = 6,
depth = 1,
heads = 2,
)
)
x = torch.randint(0, 5, (1, 10))
y = torch.tensor([0])
print(x.shape)
logits = transformer(x)
print(logits.shape) |
Thank you for the improvements you've already made to my original additions. I noticed that the test/x_transformers are outdated, so those changes aren't needed anymore. However, I believe the example I provided could still be valuable. It demonstrates the usage of the NLP classification with a well-known dataset, which might be useful for users to understand how to implement it while getting a high 90% validation accuracy. Would it be possible to add the example to the repository? |
c129bdc
to
1ccccaa
Compare
Added cls token/pooling option for NLP based full text classification