Skip to content

Commit

Permalink
Merge pull request #76 from juglab/issue-73
Browse files Browse the repository at this point in the history
Fix for bug # 73
  • Loading branch information
turekg authored Apr 30, 2020
2 parents 04f9760 + bec93d8 commit f0832e4
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 19 deletions.
13 changes: 6 additions & 7 deletions n2v/internals/N2V_DataGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ def __extract_patches__(self, data, num_patches=None, shape=(256, 256), n_dims=2
patches = []
if n_dims == 2:
if data.shape[1] > shape[0] and data.shape[2] > shape[1]:
for y in range(0, data.shape[1] - shape[0], shape[0]):
for x in range(0, data.shape[2] - shape[1], shape[1]):
for y in range(0, data.shape[1] - shape[0] + 1, shape[0]):
for x in range(0, data.shape[2] - shape[1] + 1, shape[1]):
patches.append(data[:, y:y + shape[0], x:x + shape[1]])

return np.concatenate(patches)
Expand All @@ -194,14 +194,13 @@ def __extract_patches__(self, data, num_patches=None, shape=(256, 256), n_dims=2
print("'shape' is too big.")
elif n_dims == 3:
if data.shape[1] > shape[0] and data.shape[2] > shape[1] and data.shape[3] > shape[2]:
for z in range(0, data.shape[1] - shape[0], shape[0]):
for y in range(0, data.shape[2] - shape[1], shape[1]):
for x in range(0, data.shape[3] - shape[2], shape[2]):
for z in range(0, data.shape[1] - shape[0] + 1, shape[0]):
for y in range(0, data.shape[2] - shape[1] + 1, shape[1]):
for x in range(0, data.shape[3] - shape[2] + 1, shape[2]):
patches.append(data[:, z:z + shape[0], y:y + shape[1], x:x + shape[2]])

return np.concatenate(patches)
elif data.shape[1] == shape[0] and data.shape[2] == shape[1] and data.shape[3] == shape[
2]:
elif data.shape[1] == shape[0] and data.shape[2] == shape[1] and data.shape[3] == shape[2]:
return data
else:
print("'shape' is too big.")
Expand Down
2 changes: 1 addition & 1 deletion n2v/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.10'
__version__ = '0.1.11'
18 changes: 9 additions & 9 deletions tests/functional/test_training2D_RGB.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@
from n2v.utils.n2v_utils import manipulate_val_data
from n2v.internals.N2V_DataGenerator import N2V_DataGenerator
from matplotlib import pyplot as plt
import urllib
import urllib.request
import os
import zipfile

# create a folder for our data
if not os.path.isdir('./data'):
os.mkdir('data')
# check if data has been downloaded already
zipPath="data/RGB.zip"
if not os.path.exists(zipPath):
# download and unzip data
data = urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/Frru2hsjjAljpfW/download', zipPath)
with zipfile.ZipFile(zipPath, 'r') as zip_ref:
zip_ref.extractall("data")
# check if data has been downloaded already
zipPath = "data/RGB.zip"
if not os.path.exists(zipPath):
# download and unzip data
data = urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/Frru2hsjjAljpfW/download', zipPath)
with zipfile.ZipFile(zipPath, 'r') as zip_ref:
zip_ref.extractall("data")

# For training, we will load __one__ low-SNR RGB image and use the <code>N2V_DataGenerator</code> to extract non-overlapping patches
datagen = N2V_DataGenerator()
Expand All @@ -29,7 +29,7 @@
# The function will return a list of images (numpy arrays).
# In the 'dims' parameter we specify the order of dimensions in the image files we are reading:
# 'C' stands for channels (color)
imgs = datagen.load_imgs_from_directory(directory="data/", filter='*.png', dims='YXC')
imgs = datagen.load_imgs_from_directory(directory="./data", filter='*.png', dims='YXC')

print('shape of loaded images: ',imgs[0].shape)
# Remove alpha channel
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/test_training2D_SEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# create a folder for our data
if not os.path.isdir('./data'):
os.mkdir('./data')
zipPath="data/SEM.zip"
zipPath = "data/SEM.zip"
if not os.path.exists(zipPath):
# download and unzip data
data = urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/pXgfbobntrw06lC/download', zipPath)
Expand Down
47 changes: 47 additions & 0 deletions tests/test_Noise2VoidDataGenerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from n2v.internals.N2V_DataGenerator import N2V_DataGenerator
import urllib.request
import os
import zipfile


def test_generate_patches_2D():

if not os.path.isdir('data'):
os.mkdir('data')
zip_path = "data/RGB.zip"
if not os.path.exists(zip_path):
data = urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/Frru2hsjjAljpfW/download', zip_path)
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall('data')

datagen = N2V_DataGenerator()

imgs = datagen.load_imgs_from_directory(directory="data", filter='*.png', dims='YXC')
imgs[0] = imgs[0][..., :3]
patches = datagen.generate_patches_from_list(imgs, shape=(1100, 2800))
assert len(patches) == 1
patches = datagen.generate_patches_from_list(imgs, shape=(550, 1400))
assert len(patches) == 4
patches = datagen.generate_patches_from_list(imgs, shape=(110, 280))
assert len(patches) == 100

def test_generate_patches_3D():

if not os.path.isdir('data'):
os.mkdir('data')
zip_path = 'data/flywing-data.zip'
if not os.path.exists(zip_path):
# download and unzip data
data = urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/RKStdwKo4FlFrxE/download', zip_path)
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall('data')

datagen = N2V_DataGenerator()

imgs = datagen.load_imgs_from_directory(directory="data", filter='*.tif', dims='ZYX')
print(imgs[0].shape)
patches = datagen.generate_patches_from_list(imgs[:1], shape=(35, 520, 692))
assert len(patches) == 1
patches = datagen.generate_patches_from_list(imgs[:1], shape=(5, 52, 174))
assert len(patches) == 210

2 changes: 1 addition & 1 deletion tests/test_Noise2VoidDataWrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from n2v.internals.N2V_DataWrapper import N2V_DataWrapper
from n2v.internals.N2V_DataWrapper import N2V_DataWrapper

import numpy as np

Expand Down

0 comments on commit f0832e4

Please sign in to comment.