diff --git a/tensorflow-examples/pom.xml b/tensorflow-examples/pom.xml index bbc0adb..5feefd3 100644 --- a/tensorflow-examples/pom.xml +++ b/tensorflow-examples/pom.xml @@ -12,18 +12,19 @@ 1.8 1.8 + 0.4.0 org.tensorflow tensorflow-core-platform - 0.3.1 + ${tensorflow.version} org.tensorflow tensorflow-framework - 0.3.1 + ${tensorflow.version} diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java index c1d3728..8395969 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java @@ -101,11 +101,15 @@ The given SavedModel SignatureDef contains the following output(s): */ +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.TreeMap; import org.tensorflow.Graph; +import org.tensorflow.Operand; import org.tensorflow.SavedModelBundle; import org.tensorflow.Session; import org.tensorflow.Tensor; -import org.tensorflow.Operand; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; @@ -121,12 +125,6 @@ The given SavedModel SignatureDef contains the following output(s): import org.tensorflow.types.TUint8; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Map; -import java.util.TreeMap; - - /** * Loads an image using ReadFile and DecodeJpeg and then uses the saved model * faster_rcnn/inception_resnet_v2_1024x1024/1 to detect objects with a detection score greater than 0.3 @@ -254,7 +252,6 @@ public static void main(String[] params) { Constant fileName = tf.constant(imagePath); ReadFile readFile = tf.io.readFile(fileName); Session.Runner runner = s.runner(); - s.run(tf.init()); DecodeJpeg.Options options = DecodeJpeg.channels(3L); DecodeJpeg decodeImage = tf.image.decodeJpeg(readFile.contents(), options); //fetch image from file diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java index fedd86d..8ce4ff3 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java @@ -22,9 +22,20 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.Tensor; +import org.tensorflow.framework.optimizers.AdaDelta; +import org.tensorflow.framework.optimizers.AdaGrad; +import org.tensorflow.framework.optimizers.AdaGradDA; +import org.tensorflow.framework.optimizers.Adam; +import org.tensorflow.framework.optimizers.GradientDescent; +import org.tensorflow.framework.optimizers.Momentum; +import org.tensorflow.framework.optimizers.Optimizer; +import org.tensorflow.framework.optimizers.RMSProp; import org.tensorflow.model.examples.datasets.ImageBatch; import org.tensorflow.model.examples.datasets.mnist.MnistDataset; +import org.tensorflow.ndarray.ByteNdArray; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Indices; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Constant; @@ -38,20 +49,8 @@ import org.tensorflow.op.nn.MaxPool; import org.tensorflow.op.nn.Relu; import org.tensorflow.op.nn.Softmax; -import org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits; +import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits; import org.tensorflow.op.random.TruncatedNormal; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.ByteNdArray; -import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.index.Indices; -import org.tensorflow.framework.optimizers.AdaDelta; -import org.tensorflow.framework.optimizers.AdaGrad; -import org.tensorflow.framework.optimizers.AdaGradDA; -import org.tensorflow.framework.optimizers.Adam; -import org.tensorflow.framework.optimizers.GradientDescent; -import org.tensorflow.framework.optimizers.Momentum; -import org.tensorflow.framework.optimizers.Optimizer; -import org.tensorflow.framework.optimizers.RMSProp; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TUint8; @@ -75,7 +74,6 @@ public class CnnMnist { public static final String TARGET = "target"; public static final String TRAIN = "train"; public static final String TRAINING_LOSS = "training_loss"; - public static final String INIT = "init"; private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz"; private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz"; @@ -160,8 +158,7 @@ public static Graph build(String optimizerName) { // Loss function & regularization OneHot oneHot = tf .oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f)); - SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.raw - .softmaxCrossEntropyWithLogits(logits, oneHot); + SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.softmaxCrossEntropyWithLogits(logits, oneHot); Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0)); Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math .add(tf.nn.l2Loss(fc1Biases), @@ -197,19 +194,13 @@ public static Graph build(String optimizerName) { default: throw new IllegalArgumentException("Unknown optimizer " + optimizerName); } - logger.info("Optimizer = " + optimizer.toString()); + logger.info("Optimizer = " + optimizer); Op minimize = optimizer.minimize(loss, TRAIN); - tf.init(); - return graph; } public static void train(Session session, int epochs, int minibatchSize, MnistDataset dataset) { - // Initialises the parameters. - session.runner().addTarget(INIT).run(); - logger.info("Initialised the model parameters"); - int interval = 0; // Train the model for (int i = 0; i < epochs; i++) { @@ -274,7 +265,7 @@ public static void test(Session session, int minibatchSize, MnistDataset dataset sb.append("\n"); } - System.out.println(sb.toString()); + System.out.println(sb); } /** diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMNIST.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMNIST.java index b8c5c26..2a62270 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMNIST.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMNIST.java @@ -16,9 +16,8 @@ */ package org.tensorflow.model.examples.cnn.vgg; -import org.tensorflow.model.examples.datasets.mnist.MnistDataset; - import java.util.logging.Logger; +import org.tensorflow.model.examples.datasets.mnist.MnistDataset; /** * Trains and evaluates VGG'11 model on FashionMNIST dataset. diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java index 9e9725d..d128c43 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java @@ -16,14 +16,20 @@ */ package org.tensorflow.model.examples.cnn.vgg; +import java.util.Arrays; +import java.util.logging.Level; +import java.util.logging.Logger; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.Tensor; import org.tensorflow.framework.optimizers.Adam; import org.tensorflow.framework.optimizers.Optimizer; import org.tensorflow.model.examples.datasets.ImageBatch; import org.tensorflow.model.examples.datasets.mnist.MnistDataset; +import org.tensorflow.ndarray.ByteNdArray; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Indices; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.OneHot; @@ -35,19 +41,11 @@ import org.tensorflow.op.nn.Conv2d; import org.tensorflow.op.nn.MaxPool; import org.tensorflow.op.nn.Relu; -import org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits; +import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits; import org.tensorflow.op.random.TruncatedNormal; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.ByteNdArray; -import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.index.Indices; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TUint8; -import java.util.Arrays; -import java.util.logging.Level; -import java.util.logging.Logger; - /** * Describes the VGGModel. */ @@ -64,7 +62,6 @@ public class VGGModel implements AutoCloseable { public static final String TARGET = "target"; public static final String TRAIN = "train"; public static final String TRAINING_LOSS = "training_loss"; - public static final String INIT = "init"; private static final Logger logger = Logger.getLogger(VGGModel.class.getName()); @@ -127,8 +124,6 @@ public static Graph compile() { optimizer.minimize(loss, TRAIN); - tf.init(); - return graph; } @@ -159,8 +154,7 @@ public static Add buildFCLayersAndRegularization(Ops tf, Placeholder oneHot = tf .oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f)); - SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.raw - .softmaxCrossEntropyWithLogits(logits, oneHot); + SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.softmaxCrossEntropyWithLogits(logits, oneHot); Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0)); Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math .add(tf.nn.l2Loss(fc1Biases), @@ -193,10 +187,6 @@ public static Relu vggConv2DLayer(String layerName, Ops tf, Operand weights = tf.variable(weightShape, TFloat32.class); - tf.initAdd(tf.assign(weights, tf.zerosLike(weights))); + Variable weights = tf.variable(tf.zeros(tf.constant(weightShape), TFloat32.class)); // Create biases with an initial value of 0 Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES); - Variable biases = tf.variable(biasShape, TFloat32.class); - tf.initAdd(tf.assign(biases, tf.zerosLike(biases))); - - // Register all variable initializers for single execution - tf.init(); + Variable biases = tf.variable(tf.zeros(tf.constant(biasShape), TFloat32.class)); // Predict the class of each image in the batch and compute the loss Softmax softmax = @@ -103,9 +97,6 @@ public void run() { // Run the graph try (Session session = new Session(graph)) { - // Initialize variables - session.run(tf.init()); - // Train the model for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) { try (TFloat32 batchImages = preprocessImages(trainingBatch.images()); @@ -163,8 +154,8 @@ private static TFloat32 preprocessLabels(ByteNdArray rawLabels) { ).asTensor(); } - private Graph graph; - private MnistDataset dataset; + private final Graph graph; + private final MnistDataset dataset; private SimpleMnist(Graph graph, MnistDataset dataset) { this.graph = graph; diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java index ccd0a76..4e8fbd5 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java @@ -16,11 +16,13 @@ */ package org.tensorflow.model.examples.regression.linear; +import java.util.List; +import java.util.Random; import org.tensorflow.Graph; import org.tensorflow.Session; -import org.tensorflow.Tensor; import org.tensorflow.framework.optimizers.GradientDescent; import org.tensorflow.framework.optimizers.Optimizer; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Placeholder; @@ -29,12 +31,8 @@ import org.tensorflow.op.math.Div; import org.tensorflow.op.math.Mul; import org.tensorflow.op.math.Pow; -import org.tensorflow.ndarray.Shape; import org.tensorflow.types.TFloat32; -import java.util.List; -import java.util.Random; - /** * In this example TensorFlow finds the weight and bias of the linear regression during 1 epoch, * training on observations one by one. @@ -89,8 +87,6 @@ public static void main(String[] args) { Op minimize = optimizer.minimize(mse); try (Session session = new Session(graph)) { - // Initialize graph variables - session.run(tf.init()); // Train the model on data for (int i = 0; i < xValues.length; i++) {