Skip to content

Commit

Permalink
Compress .trks by using gzip compression when writing the tarfile. (#…
Browse files Browse the repository at this point in the history
…64)

* Write tarfiles with gzip compression.

* Use tmpdir test fixture for temporary directories.
  • Loading branch information
willgraf authored Jul 15, 2021
1 parent 35ef710 commit f2bb550
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 70 deletions.
2 changes: 1 addition & 1 deletion deepcell_tracking/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ def dump(self, filename, track_review_dict=None):

filename = str(filename)

with tarfile.open(filename, 'w') as trks:
with tarfile.open(filename, 'w:gz') as trks:
# disable auto deletion and close/delete manually
# to resolve double-opening issue on Windows.
with tempfile.NamedTemporaryFile('w', delete=False) as lineage:
Expand Down
74 changes: 32 additions & 42 deletions deepcell_tracking/tracking_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@
from __future__ import division
from __future__ import print_function

import errno
import os
import shutil
import tempfile

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -123,7 +120,7 @@ def test_simple(self):
neighborhood_encoder=encoder,
data_format='invalid')

def test_track_cells(self):
def test_track_cells(self, tmpdir):
frames = 10
track_length = 3
labels_per_frame = 3
Expand Down Expand Up @@ -172,41 +169,34 @@ def test_track_cells(self):
with pytest.raises(ValueError):
tracker.dataframe(bad_value=-1)

try:
# test tracker.postprocess
tempdir = tempfile.mkdtemp() # create dir
path = os.path.join(tempdir, 'postprocess.xyz')
tracker.postprocess(filename=path)
post_saved_path = os.path.join(tempdir, 'postprocess.trk')
assert os.path.isfile(post_saved_path)

# test tracker.dump
path = os.path.join(tempdir, 'test.xyz')
tracker.dump(path)
dump_saved_path = os.path.join(tempdir, 'test.trk')
assert os.path.isfile(dump_saved_path)

# utility tests for loading trk files
# TODO: move utility tests into utils_test.py

# test trk_folder_to_trks
utils.trk_folder_to_trks(tempdir, os.path.join(tempdir, 'all.trks'))
assert os.path.isfile(os.path.join(tempdir, 'all.trks'))

# test load_trks
data = utils.load_trks(post_saved_path)
assert isinstance(data['lineages'], list)
assert all(isinstance(d, dict) for d in data['lineages'])
np.testing.assert_equal(data['X'], tracker.X)
np.testing.assert_equal(data['y'], tracker.y_tracked)
# load trks instead of trk
data = utils.load_trks(os.path.join(tempdir, 'all.trks'))

# test trks_stats
utils.trks_stats(os.path.join(tempdir, 'test.trk'))
finally:
try:
shutil.rmtree(tempdir) # delete directory
except OSError as exc:
if exc.errno != errno.ENOENT: # no such file or directory
raise # re-raise exception
# test tracker.postprocess
tempdir = str(tmpdir)
path = os.path.join(tempdir, 'postprocess.xyz')
tracker.postprocess(filename=path)
post_saved_path = os.path.join(tempdir, 'postprocess.trk')
assert os.path.isfile(post_saved_path)

# test tracker.dump
path = os.path.join(tempdir, 'test.xyz')
tracker.dump(path)
dump_saved_path = os.path.join(tempdir, 'test.trk')
assert os.path.isfile(dump_saved_path)

# utility tests for loading trk files
# TODO: move utility tests into utils_test.py

# test trk_folder_to_trks
utils.trk_folder_to_trks(tempdir, os.path.join(tempdir, 'all.trks'))
assert os.path.isfile(os.path.join(tempdir, 'all.trks'))

# test load_trks
data = utils.load_trks(post_saved_path)
assert isinstance(data['lineages'], list)
assert all(isinstance(d, dict) for d in data['lineages'])
np.testing.assert_equal(data['X'], tracker.X)
np.testing.assert_equal(data['y'], tracker.y_tracked)
# load trks instead of trk
data = utils.load_trks(os.path.join(tempdir, 'all.trks'))

# test trks_stats
utils.trks_stats(os.path.join(tempdir, 'test.trk'))
2 changes: 1 addition & 1 deletion deepcell_tracking/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def save_trks(filename, lineages, raw, tracked):
if not str(filename).lower().endswith('.trks'):
raise ValueError('filename must end with `.trks`. Found %s' % filename)

with tarfile.open(filename, 'w') as trks:
with tarfile.open(filename, 'w:gz') as trks:
with tempfile.NamedTemporaryFile('w', delete=False) as lineages_file:
json.dump(lineages, lineages_file, indent=4)
lineages_file.flush()
Expand Down
41 changes: 15 additions & 26 deletions deepcell_tracking/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@
from __future__ import print_function

import copy
import errno
import os
import shutil
import tempfile

import numpy as np
import skimage as sk
Expand Down Expand Up @@ -144,33 +141,25 @@ def test_count_pairs(self):
y, same_probability=prob, data_format='channels_first')
assert pairs == expected

def test_save_trks(self):
def test_save_trks(self, tmpdir):
X = get_image(30, 30)
y = np.random.randint(low=0, high=10, size=X.shape)
lineage = [dict()]

try:
tempdir = tempfile.mkdtemp() # create dir
with pytest.raises(ValueError):
badfilename = os.path.join(tempdir, 'x.trk')
utils.save_trks(badfilename, lineage, X, y)

filename = os.path.join(tempdir, 'x.trks')
utils.save_trks(filename, lineage, X, y)
assert os.path.isfile(filename)

# test saved tracks can be loaded
loaded = utils.load_trks(filename)
assert loaded['lineages'] == lineage
np.testing.assert_array_equal(X, loaded['X'])
np.testing.assert_array_equal(y, loaded['y'])

finally:
try:
shutil.rmtree(tempdir) # delete directory
except OSError as exc:
if exc.errno != errno.ENOENT: # no such file or directory
raise # re-raise exception
tempdir = str(tmpdir)
with pytest.raises(ValueError):
badfilename = os.path.join(tempdir, 'x.trk')
utils.save_trks(badfilename, lineage, X, y)

filename = os.path.join(tempdir, 'x.trks')
utils.save_trks(filename, lineage, X, y)
assert os.path.isfile(filename)

# test saved tracks can be loaded
loaded = utils.load_trks(filename)
assert loaded['lineages'] == lineage
np.testing.assert_array_equal(X, loaded['X'])
np.testing.assert_array_equal(y, loaded['y'])

def test_normalize_adj_matrix(self):
frames = 3
Expand Down

0 comments on commit f2bb550

Please sign in to comment.