Update pad.py to include reflective padding #195

Closed · wants to merge 33 commits

Changes from all commits (33 commits)
ec88b0c  remove duplicated for loop (pattonw, Oct 6, 2023)
abcecf2  increment patch number (pattonw, Oct 6, 2023)
b557548  ArraySpec docs (pattonw, Oct 6, 2023)
80033bd  ArraySpec bug fix: (pattonw, Oct 6, 2023)
7f8b877  fix the deform augment test (pattonw, Nov 1, 2023)
f1fd63a  better bounds on required packages (pattonw, Nov 1, 2023)
f96aa78  ignore missing imports from packages that don't provide type hints (pattonw, Nov 2, 2023)
98bbe71  fix typehint mistakes (pattonw, Nov 2, 2023)
3ba99da  format pyproject.toml (pattonw, Nov 2, 2023)
7a10397  black format (pattonw, Nov 2, 2023)
1615a7d  move register hooks to the start method (pattonw, Nov 2, 2023)
4a8ccf4  fix typo (pattonw, Nov 2, 2023)
a001d7d  support non-spatial arrays in ArraySource (pattonw, Nov 2, 2023)
98a2e34  overhaul torch tests (pattonw, Nov 2, 2023)
96ff2f1  remove multiprocess set start method monkey patch (pattonw, Nov 2, 2023)
75ff6ff  only deploy docs on tagged commits to main (pattonw, Nov 2, 2023)
865fe53  minor black formatting and configuration changes (pattonw, Nov 2, 2023)
2078036  properly skip torch tests if torch not installed (pattonw, Nov 2, 2023)
46676dd  black formatting (pattonw, Nov 2, 2023)
84593e8  avoid testing on python 3.7, instead use 3.11 (pattonw, Nov 2, 2023)
ed3c7ce  add typed libraries to dev dependencies (pattonw, Nov 2, 2023)
d49db1f  pass torch train test (pattonw, Nov 30, 2023)
797994c  pass torch train test (pattonw, Nov 30, 2023)
b2f8c2d  remove extra error printing (pattonw, Dec 19, 2023)
625cb03  switch error printing order (pattonw, Dec 19, 2023)
35fcd43  black format docs and examples (pattonw, Dec 19, 2023)
076661f  Squashed commit of the following: (pattonw, Dec 19, 2023)
b6c425f  parameterize tests for cuda devices (pattonw, Dec 19, 2023)
a7503d7  Update pad.py to include reflective padding (lmanan, Nov 2, 2023)
443c666  Replace .ndim by len() (lmanan, Nov 3, 2023)
531d81d  update the pad tests (pattonw, Dec 19, 2023)
a7027c6  fix the test case (pattonw, Dec 19, 2023)
3782525  pass the fixed tests (pattonw, Dec 19, 2023)
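The headline change is a new reflective padding mode for the Pad node. The pad.py diff itself did not load in this capture, so as a point of reference only, here is a minimal numpy sketch of the difference between constant and reflective padding; it assumes the new mode behaves like numpy's "reflect", and gunpowder's actual implementation may differ in detail:

# Illustrative sketch only: the pad.py changes are not shown in this
# capture; this assumes Pad's new mode mirrors numpy's "reflect" behavior.
import numpy as np

a = np.array([1, 2, 3, 4, 5])

# constant padding (the existing behavior): extend with a fixed value
print(np.pad(a, 2, mode="constant", constant_values=0))
# [0 0 1 2 3 4 5 0 0]

# reflective padding (what this PR adds): mirror values across the
# boundary, so padded regions look like plausible image content
print(np.pad(a, 2, mode="reflect"))
# [3 2 1 2 3 4 5 4 3]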
3 changes: 1 addition & 2 deletions .github/workflows/publish-docs.yaml

@@ -3,8 +3,7 @@ name: Deploy Docs to GitHub Pages
 on:
   push:
     branches: [main]
-  pull_request:
-    branches: [main]
+    tags: "*"
   workflow_dispatch:
 
 # Allow this job to clone the repo and create a page deployment
2 changes: 1 addition & 1 deletion .github/workflows/test.yml

@@ -15,7 +15,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.7", "3.8", "3.9", "3.10"]
+        python-version: ["3.8", "3.9", "3.10", "3.11"]
         platform: [ubuntu-latest]
 
     steps:
2 changes: 1 addition & 1 deletion docs/source/conf.py

@@ -54,7 +54,7 @@
 ]
 
 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
 
 # -- Options for HTML output -------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
54 changes: 22 additions & 32 deletions examples/cremi/mknet.py

@@ -2,29 +2,26 @@
 import tensorflow as tf
 import json
 
-def create_network(input_shape, name):
 
+def create_network(input_shape, name):
     tf.reset_default_graph()
 
     # create a placeholder for the 3D raw input tensor
     raw = tf.placeholder(tf.float32, shape=input_shape)
 
     # create a U-Net
     raw_batched = tf.reshape(raw, (1, 1) + input_shape)
-    unet_output = unet(raw_batched, 6, 4, [[1,3,3],[1,3,3],[1,3,3]])
+    unet_output = unet(raw_batched, 6, 4, [[1, 3, 3], [1, 3, 3], [1, 3, 3]])
 
     # add a convolution layer to create 3 output maps representing affinities
     # in z, y, and x
     pred_affs_batched = conv_pass(
-        unet_output,
-        kernel_size=1,
-        num_fmaps=3,
-        num_repetitions=1,
-        activation='sigmoid')
+        unet_output, kernel_size=1, num_fmaps=3, num_repetitions=1, activation="sigmoid"
+    )
 
     # get the shape of the output
     output_shape_batched = pred_affs_batched.get_shape().as_list()
-    output_shape = output_shape_batched[1:] # strip the batch dimension
+    output_shape = output_shape_batched[1:]  # strip the batch dimension
 
     # the 4D output tensor (3, depth, height, width)
     pred_affs = tf.reshape(pred_affs_batched, output_shape)
@@ -33,46 +30,39 @@ def create_network(input_shape, name):
     gt_affs = tf.placeholder(tf.float32, shape=output_shape)
 
     # create a placeholder for per-voxel loss weights
-    loss_weights = tf.placeholder(
-        tf.float32,
-        shape=output_shape)
+    loss_weights = tf.placeholder(tf.float32, shape=output_shape)
 
     # compute the loss as the weighted mean squared error between the
     # predicted and the ground-truth affinities
-    loss = tf.losses.mean_squared_error(
-        gt_affs,
-        pred_affs,
-        loss_weights)
+    loss = tf.losses.mean_squared_error(gt_affs, pred_affs, loss_weights)
 
     # use the Adam optimizer to minimize the loss
     opt = tf.train.AdamOptimizer(
-        learning_rate=0.5e-4,
-        beta1=0.95,
-        beta2=0.999,
-        epsilon=1e-8)
+        learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8
+    )
     optimizer = opt.minimize(loss)
 
     # store the network in a meta-graph file
-    tf.train.export_meta_graph(filename=name + '.meta')
+    tf.train.export_meta_graph(filename=name + ".meta")
 
     # store network configuration for use in train and predict scripts
     config = {
-        'raw': raw.name,
-        'pred_affs': pred_affs.name,
-        'gt_affs': gt_affs.name,
-        'loss_weights': loss_weights.name,
-        'loss': loss.name,
-        'optimizer': optimizer.name,
-        'input_shape': input_shape,
-        'output_shape': output_shape[1:]
+        "raw": raw.name,
+        "pred_affs": pred_affs.name,
+        "gt_affs": gt_affs.name,
+        "loss_weights": loss_weights.name,
+        "loss": loss.name,
+        "optimizer": optimizer.name,
+        "input_shape": input_shape,
+        "output_shape": output_shape[1:],
     }
-    with open(name + '_config.json', 'w') as f:
+    with open(name + "_config.json", "w") as f:
         json.dump(config, f)
 
-if __name__ == "__main__":
 
+if __name__ == "__main__":
     # create a network for training
-    create_network((84, 268, 268), 'train_net')
+    create_network((84, 268, 268), "train_net")
 
     # create a larger network for faster prediction
-    create_network((120, 322, 322), 'test_net')
+    create_network((120, 322, 322), "test_net")
63 changes: 28 additions & 35 deletions examples/cremi/predict.py

@@ -2,29 +2,29 @@
 import gunpowder as gp
 import json
 
-def predict(iteration):
 
+def predict(iteration):
     ##################
     # DECLARE ARRAYS #
     ##################
 
     # raw intensities
-    raw = gp.ArrayKey('RAW')
+    raw = gp.ArrayKey("RAW")
 
     # the predicted affinities
-    pred_affs = gp.ArrayKey('PRED_AFFS')
+    pred_affs = gp.ArrayKey("PRED_AFFS")
 
     ####################
     # DECLARE REQUESTS #
     ####################
 
-    with open('test_net_config.json', 'r') as f:
+    with open("test_net_config.json", "r") as f:
         net_config = json.load(f)
 
     # get the input and output size in world units (nm, in this case)
     voxel_size = gp.Coordinate((40, 4, 4))
-    input_size = gp.Coordinate(net_config['input_shape'])*voxel_size
-    output_size = gp.Coordinate(net_config['output_shape'])*voxel_size
+    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size
+    output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size
     context = input_size - output_size
 
     # formulate the request for what a batch should contain
@@ -37,52 +37,44 @@ def predict(iteration):
     #############################
 
     source = gp.Hdf5Source(
-        'sample_A_padded_20160501.hdf',
-        datasets = {
-            raw: 'volumes/raw'
-        })
+        "sample_A_padded_20160501.hdf", datasets={raw: "volumes/raw"}
+    )
 
     # get the ROI provided for raw (we need it later to calculate the ROI in
     # which we can make predictions)
     with gp.build(source):
         raw_roi = source.spec[raw].roi
 
     pipeline = (
-
         # read from HDF5 file
-        source +
-
+        source
+        +
         # convert raw to float in [0, 1]
-        gp.Normalize(raw) +
-
+        gp.Normalize(raw)
+        +
         # perform one training iteration for each passing batch (here we use
         # the tensor names earlier stored in train_net.config)
         gp.tensorflow.Predict(
-            graph='test_net.meta',
-            checkpoint='train_net_checkpoint_%d'%iteration,
-            inputs={
-                net_config['raw']: raw
-            },
-            outputs={
-                net_config['pred_affs']: pred_affs
-            },
-            array_specs={
-                pred_affs: gp.ArraySpec(roi=raw_roi.grow(-context, -context))
-            }) +
-
+            graph="test_net.meta",
+            checkpoint="train_net_checkpoint_%d" % iteration,
+            inputs={net_config["raw"]: raw},
+            outputs={net_config["pred_affs"]: pred_affs},
+            array_specs={pred_affs: gp.ArraySpec(roi=raw_roi.grow(-context, -context))},
+        )
+        +
         # store all passing batches in the same HDF5 file
         gp.Hdf5Write(
             {
-                raw: '/volumes/raw',
-                pred_affs: '/volumes/pred_affs',
+                raw: "/volumes/raw",
+                pred_affs: "/volumes/pred_affs",
             },
-            output_filename='predictions_sample_A.hdf',
-            compression_type='gzip'
-        ) +
-
+            output_filename="predictions_sample_A.hdf",
+            compression_type="gzip",
+        )
+        +
         # show a summary of time spend in each node every 10 iterations
-        gp.PrintProfilingStats(every=10) +
-
+        gp.PrintProfilingStats(every=10)
+        +
         # iterate over the whole dataset in a scanning fashion, emitting
         # requests that match the size of the network
         gp.Scan(reference=request)
@@ -93,5 +85,6 @@ def predict(iteration):
     # without keeping the complete dataset in memory
     pipeline.request_batch(gp.BatchRequest())
 
+
 if __name__ == "__main__":
     predict(200000)
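A quick worked example of the context arithmetic in the predict.py diff above, since it is easy to misread: shapes are converted from voxels to world units (nm) before the subtraction. The input shape is the test network's (120, 322, 322) from mknet.py; the output shape here is hypothetical, as the real value is read from test_net_config.json, which this capture does not include:

import gunpowder as gp

voxel_size = gp.Coordinate((40, 4, 4))  # nm per voxel in (z, y, x)

# test network input in voxels, from mknet.py
input_size = gp.Coordinate((120, 322, 322)) * voxel_size   # (4800, 1288, 1288) nm

# hypothetical output shape; the real one comes from test_net_config.json
output_size = gp.Coordinate((88, 234, 234)) * voxel_size   # (3520, 936, 936) nm

# margin lost to the network's receptive field; predictions are only valid
# inside the raw ROI shrunk by this amount, hence raw_roi.grow(-context, -context)
context = input_size - output_size                         # (1280, 352, 352) nm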