prepare_dataset.py
# coding: utf-8
import os

from absl import flags
from absl import app

import numpy as np
from sklearn.model_selection import train_test_split

from retchat.dataset import DataFinder
from retchat.dataset import Annotation
from retchat.dataset import Stack, MaxBlockStack, PercentileNormalizedMaxBlockStack, Max2R10PreprocessedStack
from retchat.dataset import label_to_image_fname
from retchat.dataset import RegressionRecordParser
from retchat.sampler import PlaneSampler

from dlutils.dataset.tfrec_utils import tfrecord_from_sampler

# Application parameters
ARGS = flags.FLAGS

flags.DEFINE_multi_string('data_folder', None, 'Folder containing stacks.')
flags.DEFINE_string('label_folder', None, 'Folder containing annotations.')
flags.DEFINE_string('output_folder', None, 'Folder to write tfrecords')
flags.DEFINE_integer('samples_per_stack',
                     500,
                     'Number of frames to sample from each stack.',
                     lower_bound=1)
flags.DEFINE_integer('validation',
                     2,
                     'Number of stacks to use for validation',
                     lower_bound=0)
flags.DEFINE_integer('delta',
                     5,
                     'Number of planes around central plane to project over.',
                     lower_bound=0)
flags.DEFINE_integer(
    'patch_size', 224,
    'Patch size. Will be applied in both x and y axis leading to square patches.'
)
flags.DEFINE_string('stack_suffix', '.stk', 'Suffix of videos')
flags.DEFINE_string('projector', None, 'Projector to use for preprocessing')
flags.DEFINE_string('labels_suffix', '.mat', 'Suffix of annotations')

for required_flag in ['data_folder', 'label_folder', 'output_folder']:
    flags.mark_flag_as_required(required_flag)
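
# Example invocation (a sketch based on the flags defined above; the paths are
# placeholders and not part of the original script):
#
#   python prepare_dataset.py \
#       --data_folder /path/to/stacks \
#       --label_folder /path/to/annotations \
#       --output_folder /path/to/tfrecords \
#       --projector max \
#       --delta 5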


def _create_folder_if_needed(folder):
    '''Creates the given folder (including parents) if it does not exist yet.

    '''
    if not os.path.exists(folder):
        print('Creating folders: {}'.format(folder))
        os.makedirs(folder)


def process(image_path, label_path, outdir):
    '''Samples planes from a single annotated stack and writes them as a tfrecord.

    '''
    output = os.path.join(
        outdir,
        os.path.splitext(os.path.basename(image_path))[0] + '.tfrecord')

    parser = RegressionRecordParser()

    if ARGS.projector == 'percentile_max':
        stack = PercentileNormalizedMaxBlockStack(image_path,
                                                  delta_x=ARGS.delta,
                                                  percentiles=(5, 95))
    elif ARGS.projector == 'max':
        stack = MaxBlockStack(image_path, delta_x=ARGS.delta)
    elif ARGS.projector == 'max2rank':
        stack = Max2R10PreprocessedStack(image_path, delta_x=ARGS.delta)
    else:
        stack = Stack(image_path)

    sampler = PlaneSampler(stack=stack,
                           annotation=Annotation(label_path),
                           n_samples=ARGS.samples_per_stack,
                           patch_size=ARGS.patch_size)

    tfrecord_from_sampler(output, sampler, parser.serialize)


def split_dataset(vals, validation_samples, test_samples):
    '''splits videos into three categories: training, validation and test.

    '''
    splits = {}

    if validation_samples <= 0 and test_samples <= 0:
        splits['training'] = vals
        return splits

    splits['training'], remaining = train_test_split(
        sorted(vals), test_size=validation_samples + test_samples)

    if test_samples <= 0:
        splits['validation'] = remaining
        return splits

    splits['validation'], splits['test'] = train_test_split(
        remaining, test_size=test_samples)
    return splits
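

# Illustrative behaviour of split_dataset (not part of the original script):
# with 10 paths, validation_samples=2 and test_samples=0, it returns a dict
# holding 8 paths under 'training' and 2 under 'validation'; a 'test' split is
# only created when test_samples > 0.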


def main(*args):
    '''Collects (label, image) pairs, splits them and writes one tfrecord per stack.

    '''
    outdir = os.path.join(
        ARGS.output_folder,
        'N{}-{}'.format(ARGS.samples_per_stack,
                        ('single' if ARGS.projector is None else
                         '{}-block-{}'.format(ARGS.projector, ARGS.delta))))
    _create_folder_if_needed(outdir)

    print('Creating training data in {}'.format(outdir))

    # initialize seed
    np.random.seed(13)

    # prepare samples from individual videos.
    paths = list(
        DataFinder(data_dirs=ARGS.data_folder,
                   label_dir=ARGS.label_folder,
                   label_pattern='*' + ARGS.labels_suffix,
                   label_to_img_fn=label_to_image_fname))

    if not paths:
        print(
            'No matching files found for data_folder: {} and label_folder: {}'.
            format(ARGS.data_folder, ARGS.label_folder))
        return

    splits = split_dataset(paths, ARGS.validation, 0)

    for split in splits:
        print(split, splits[split])

    for split in splits.keys():
        if split == 'training':
            split_outdir = outdir
        else:
            split_outdir = os.path.join(outdir, split)
        _create_folder_if_needed(split_outdir)

        for label_path, image_path in splits[split]:
            process(image_path, label_path, split_outdir)


if __name__ == '__main__':
    app.run(main)