diff --git a/densenet161.py b/densenet161.py index 936bfe2..caadc60 100644 --- a/densenet161.py +++ b/densenet161.py @@ -1,5 +1,5 @@ from keras.models import Model -from keras.layers import Input, merge, ZeroPadding2D +from keras.layers import Input, merge, ZeroPadding2D, concatenate from keras.layers.core import Dense, Dropout, Activation from keras.layers.convolutional import Convolution2D from keras.layers.pooling import AveragePooling2D, GlobalAveragePooling2D, MaxPooling2D @@ -29,8 +29,8 @@ def DenseNet(nb_dense_block=4, growth_rate=48, nb_filter=96, reduction=0.0, drop # Handle Dimension Ordering for different backends global concat_axis - if K.image_dim_ordering() == 'tf': - concat_axis = 3 + if K.image_data_format() == 'channels_last': + concat_axis = -1 img_input = Input(shape=(224, 224, 3), name='data') else: concat_axis = 1 @@ -162,7 +162,7 @@ def dense_block(x, stage, nb_layers, nb_filter, growth_rate, dropout_rate=None, for i in range(nb_layers): branch = i+1 x = conv_block(concat_feat, stage, branch, growth_rate, dropout_rate, weight_decay) - concat_feat = merge([concat_feat, x], mode='concat', concat_axis=concat_axis, name='concat_'+str(stage)+'_'+str(branch)) + concat_feat = concatenate([concat_feat, x], axis=concat_axis, name='concat_'+str(stage)+'_'+str(branch)) if grow_nb_filters: nb_filter += growth_rate