Handle too many batch steps more gracefully
When so many batch steps are requested that the final batch size would exceed the
dataset length, truncate the batch steps instead of throwing an error.
This change enables experimenting with more aggressive batch steps, and also
comes in handy when working with long-read data.
jakobnissen committed Jun 2, 2023
1 parent 744ebda commit a5bf301
Showing 2 changed files with 40 additions and 16 deletions.
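
In essence, the change caps the number of batch-size doublings at what the dataset can support and warns when requested steps have to be dropped. A minimal sketch of that behaviour, written as a hypothetical stand-alone helper rather than the actual vamb API (the names mirror those in the patch below; it assumes n_contigs >= starting_batch_size, which the patch checks separately):

    import warnings

    def truncate_batch_steps(sorted_batch_steps, n_contigs, starting_batch_size):
        # Each batch step doubles the batch size, so at most
        # floor(log2(n_contigs / starting_batch_size)) doublings fit the dataset.
        maximum_batch_steps = (n_contigs // starting_batch_size).bit_length() - 1
        if maximum_batch_steps < len(sorted_batch_steps):
            warnings.warn(
                f"Requested {len(sorted_batch_steps)} batch steps, but only the "
                f"first {maximum_batch_steps} can be used with {n_contigs} contigs "
                f"and a starting batch size of {starting_batch_size}."
            )
            sorted_batch_steps = sorted_batch_steps[:maximum_batch_steps]
        return sorted_batch_steps

    # For example, 100 contigs with batch size 16 give 100 // 16 = 6, allowing
    # floor(log2(6)) = 2 doublings (16 -> 32 -> 64), so [1, 2, 3] becomes [1, 2].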
test/test_encode.py: 12 additions & 2 deletions
@@ -174,7 +174,7 @@ def test_loss_falls(self):
vae = vamb.encode.VAE(self.rpkm.shape[1])
rpkm_copy = self.rpkm.copy()
tnfs_copy = self.tnfs.copy()
dl, mask = vamb.encode.make_dataloader(
dl, _ = vamb.encode.make_dataloader(
rpkm_copy, tnfs_copy, self.lens, batchsize=16, destroy=True
)
di = torch.Tensor(rpkm_copy)
@@ -202,10 +202,20 @@ def test_loss_falls(self):
after_encoding = vae_2.encode(dl)
self.assertTrue(np.all(np.abs(before_encoding - after_encoding) < 1e-6))

def test_warn_too_many_batch_steps(self):
vae = vamb.encode.VAE(self.rpkm.shape[1])
rpkm_copy = self.rpkm.copy()
tnfs_copy = self.tnfs.copy()
dl, _ = vamb.encode.make_dataloader(
rpkm_copy, tnfs_copy, self.lens, batchsize=16, destroy=True
)
with self.assertWarns(Warning):
vae.trainmodel(dl, nepochs=4, batchsteps=[1, 2, 3])

def test_encoding(self):
nlatent = 15
vae = vamb.encode.VAE(self.rpkm.shape[1], nlatent=nlatent)
dl, mask = vamb.encode.make_dataloader(
dl, _ = vamb.encode.make_dataloader(
self.rpkm, self.tnfs, self.lens, batchsize=32
)
encoding = vae.encode(dl)
vamb/encode.py: 28 additions & 14 deletions
@@ -9,6 +9,7 @@
from torch import nn as _nn
from math import log as _log
from time import time
import warnings

__doc__ = """Encode a depths matrix and a tnf matrix to latent representation.
@@ -379,7 +380,7 @@ def trainepoch(
epoch_celoss = 0.0

if epoch in batchsteps:
data_loader = set_batchsize(data_loader, data_loader.batch_size * 2)
data_loader = set_batchsize(data_loader, data_loader.batch_size * 2) # type: ignore

for depths_in, tnf_in, weights in data_loader:
depths_in.requires_grad = True
@@ -450,7 +451,7 @@ def encode(self, data_loader) -> _np.ndarray:

row = 0
with _torch.no_grad():
for depths, tnf, weights in new_data_loader:
for depths, tnf, _ in new_data_loader:
# Move input to GPU if requested
if self.usecuda:
depths = depths.cuda()
@@ -551,28 +552,41 @@ def trainmodel(
if nepochs < 1:
raise ValueError("Minimum 1 epoch, not {nepochs}")

if batchsteps is None:
batchsteps_set: set[int] = set()
if batchsteps is None or len(batchsteps) == 0:
sorted_batch_steps: list[int] = []
else:
# First collect to list in order to allow all element types, then check that
# they are integers
batchsteps = list(batchsteps)
if not all(isinstance(i, int) for i in batchsteps):
raise ValueError("All elements of batchsteps must be integers")
if max(batchsteps, default=0) >= nepochs:
sorted_batch_steps = sorted(set(batchsteps))
if sorted_batch_steps[0] < 1:
raise ValueError(
f"Minimum of batchsteps must be 1, not {sorted_batch_steps[0]}"
)
if sorted_batch_steps[-1] >= nepochs:
raise ValueError("Max batchsteps must not equal or exceed nepochs")
last_batchsize = dataloader.batch_size * 2 ** len(batchsteps)
if len(dataloader.dataset) < last_batchsize: # type: ignore

n_contigs = len(dataloader.dataset) # type: ignore
starting_batch_size: int = dataloader.batch_size # type: ignore
if n_contigs < starting_batch_size:
raise ValueError(
f"Last batch size of {last_batchsize} exceeds dataset length "
f"of {len(dataloader.dataset)}. " # type: ignore
f"Starting batch size of {starting_batch_size} exceeds dataset length "
f"of {n_contigs}. "
"This means you have too few contigs left after filtering to train. "
"It is not adviced to run Vamb with fewer than 10,000 sequences "
"after filtering. "
"Please check the Vamb log file to see where the sequences were "
"filtered away, and verify BAM files has sensible content."
)
batchsteps_set = set(batchsteps)
maximum_batch_steps = (n_contigs // starting_batch_size).bit_length() - 1
if maximum_batch_steps < len(sorted_batch_steps):
warnings.warn(
f"Requested {len(sorted_batch_steps)} batch steps, but with a starting "
f"batch size of {starting_batch_size} and {n_contigs} contigs, "
f"only the first {maximum_batch_steps} batch steps can be used."
)
sorted_batch_steps = sorted_batch_steps[:maximum_batch_steps]

# Get number of features
# Following line is un-inferrable due to typing problems with DataLoader
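
The cap computed above works because (n_contigs // starting_batch_size).bit_length() - 1 equals floor(log2(n_contigs / starting_batch_size)) whenever the quotient is at least one, which is exactly the number of times the starting batch size can double without exceeding the dataset. An illustrative sanity check of that identity, not part of the commit:

    def max_doublings(n_contigs: int, batch_size: int) -> int:
        # Brute-force count of how many times the batch size can double
        # while the doubled size still fits within the dataset.
        steps = 0
        while batch_size * 2 ** (steps + 1) <= n_contigs:
            steps += 1
        return steps

    for n_contigs, batch_size in [(100, 16), (256, 256), (257, 256), (10_000, 256)]:
        assert (n_contigs // batch_size).bit_length() - 1 == max_doublings(
            n_contigs, batch_size
        )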
@@ -591,8 +605,8 @@
print("\tN epochs:", nepochs, file=logfile)
print("\tStarting batch size:", dataloader.batch_size, file=logfile)
batchsteps_string = (
", ".join(map(str, sorted(batchsteps_set)))
if batchsteps_set
", ".join(map(str, sorted_batch_steps))
if len(sorted_batch_steps) > 0
else "None"
)
print("\tBatchsteps:", batchsteps_string, file=logfile)
@@ -603,7 +617,7 @@
# Train
for epoch in range(nepochs):
dataloader = self.trainepoch(
dataloader, epoch, optimizer, sorted(batchsteps_set), time(), logfile
dataloader, epoch, optimizer, sorted_batch_steps, time(), logfile
)

# Save weights - Lord forgive me, for I have sinned when catching all exceptions
