diff --git a/tensorflow-examples/pom.xml b/tensorflow-examples/pom.xml
index bbc0adb..5feefd3 100644
--- a/tensorflow-examples/pom.xml
+++ b/tensorflow-examples/pom.xml
@@ -12,18 +12,19 @@
1.8
1.8
+ 0.4.0
org.tensorflow
tensorflow-core-platform
- 0.3.1
+ ${tensorflow.version}
org.tensorflow
tensorflow-framework
- 0.3.1
+ ${tensorflow.version}
diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java
index c1d3728..8395969 100644
--- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java
+++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java
@@ -101,11 +101,15 @@ The given SavedModel SignatureDef contains the following output(s):
*/
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.TreeMap;
import org.tensorflow.Graph;
+import org.tensorflow.Operand;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
-import org.tensorflow.Operand;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
@@ -121,12 +125,6 @@ The given SavedModel SignatureDef contains the following output(s):
import org.tensorflow.types.TUint8;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.TreeMap;
-
-
/**
* Loads an image using ReadFile and DecodeJpeg and then uses the saved model
* faster_rcnn/inception_resnet_v2_1024x1024/1 to detect objects with a detection score greater than 0.3
@@ -254,7 +252,6 @@ public static void main(String[] params) {
Constant fileName = tf.constant(imagePath);
ReadFile readFile = tf.io.readFile(fileName);
Session.Runner runner = s.runner();
- s.run(tf.init());
DecodeJpeg.Options options = DecodeJpeg.channels(3L);
DecodeJpeg decodeImage = tf.image.decodeJpeg(readFile.contents(), options);
//fetch image from file
diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java
index fedd86d..8ce4ff3 100644
--- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java
+++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java
@@ -22,9 +22,20 @@
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
-import org.tensorflow.Tensor;
+import org.tensorflow.framework.optimizers.AdaDelta;
+import org.tensorflow.framework.optimizers.AdaGrad;
+import org.tensorflow.framework.optimizers.AdaGradDA;
+import org.tensorflow.framework.optimizers.Adam;
+import org.tensorflow.framework.optimizers.GradientDescent;
+import org.tensorflow.framework.optimizers.Momentum;
+import org.tensorflow.framework.optimizers.Optimizer;
+import org.tensorflow.framework.optimizers.RMSProp;
import org.tensorflow.model.examples.datasets.ImageBatch;
import org.tensorflow.model.examples.datasets.mnist.MnistDataset;
+import org.tensorflow.ndarray.ByteNdArray;
+import org.tensorflow.ndarray.FloatNdArray;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.ndarray.index.Indices;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
@@ -38,20 +49,8 @@
import org.tensorflow.op.nn.MaxPool;
import org.tensorflow.op.nn.Relu;
import org.tensorflow.op.nn.Softmax;
-import org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits;
+import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits;
import org.tensorflow.op.random.TruncatedNormal;
-import org.tensorflow.ndarray.Shape;
-import org.tensorflow.ndarray.ByteNdArray;
-import org.tensorflow.ndarray.FloatNdArray;
-import org.tensorflow.ndarray.index.Indices;
-import org.tensorflow.framework.optimizers.AdaDelta;
-import org.tensorflow.framework.optimizers.AdaGrad;
-import org.tensorflow.framework.optimizers.AdaGradDA;
-import org.tensorflow.framework.optimizers.Adam;
-import org.tensorflow.framework.optimizers.GradientDescent;
-import org.tensorflow.framework.optimizers.Momentum;
-import org.tensorflow.framework.optimizers.Optimizer;
-import org.tensorflow.framework.optimizers.RMSProp;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TUint8;
@@ -75,7 +74,6 @@ public class CnnMnist {
public static final String TARGET = "target";
public static final String TRAIN = "train";
public static final String TRAINING_LOSS = "training_loss";
- public static final String INIT = "init";
private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz";
private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz";
@@ -160,8 +158,7 @@ public static Graph build(String optimizerName) {
// Loss function & regularization
OneHot oneHot = tf
.oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f));
- SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.raw
- .softmaxCrossEntropyWithLogits(logits, oneHot);
+ SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.softmaxCrossEntropyWithLogits(logits, oneHot);
Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0));
Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math
.add(tf.nn.l2Loss(fc1Biases),
@@ -197,19 +194,13 @@ public static Graph build(String optimizerName) {
default:
throw new IllegalArgumentException("Unknown optimizer " + optimizerName);
}
- logger.info("Optimizer = " + optimizer.toString());
+ logger.info("Optimizer = " + optimizer);
Op minimize = optimizer.minimize(loss, TRAIN);
- tf.init();
-
return graph;
}
public static void train(Session session, int epochs, int minibatchSize, MnistDataset dataset) {
- // Initialises the parameters.
- session.runner().addTarget(INIT).run();
- logger.info("Initialised the model parameters");
-
int interval = 0;
// Train the model
for (int i = 0; i < epochs; i++) {
@@ -274,7 +265,7 @@ public static void test(Session session, int minibatchSize, MnistDataset dataset
sb.append("\n");
}
- System.out.println(sb.toString());
+ System.out.println(sb);
}
/**
diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMNIST.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMNIST.java
index b8c5c26..2a62270 100644
--- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMNIST.java
+++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMNIST.java
@@ -16,9 +16,8 @@
*/
package org.tensorflow.model.examples.cnn.vgg;
-import org.tensorflow.model.examples.datasets.mnist.MnistDataset;
-
import java.util.logging.Logger;
+import org.tensorflow.model.examples.datasets.mnist.MnistDataset;
/**
* Trains and evaluates VGG'11 model on FashionMNIST dataset.
diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java
index 9e9725d..d128c43 100644
--- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java
+++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java
@@ -16,14 +16,20 @@
*/
package org.tensorflow.model.examples.cnn.vgg;
+import java.util.Arrays;
+import java.util.logging.Level;
+import java.util.logging.Logger;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
-import org.tensorflow.Tensor;
import org.tensorflow.framework.optimizers.Adam;
import org.tensorflow.framework.optimizers.Optimizer;
import org.tensorflow.model.examples.datasets.ImageBatch;
import org.tensorflow.model.examples.datasets.mnist.MnistDataset;
+import org.tensorflow.ndarray.ByteNdArray;
+import org.tensorflow.ndarray.FloatNdArray;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.ndarray.index.Indices;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.OneHot;
@@ -35,19 +41,11 @@
import org.tensorflow.op.nn.Conv2d;
import org.tensorflow.op.nn.MaxPool;
import org.tensorflow.op.nn.Relu;
-import org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits;
+import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits;
import org.tensorflow.op.random.TruncatedNormal;
-import org.tensorflow.ndarray.Shape;
-import org.tensorflow.ndarray.ByteNdArray;
-import org.tensorflow.ndarray.FloatNdArray;
-import org.tensorflow.ndarray.index.Indices;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TUint8;
-import java.util.Arrays;
-import java.util.logging.Level;
-import java.util.logging.Logger;
-
/**
* Describes the VGGModel.
*/
@@ -64,7 +62,6 @@ public class VGGModel implements AutoCloseable {
public static final String TARGET = "target";
public static final String TRAIN = "train";
public static final String TRAINING_LOSS = "training_loss";
- public static final String INIT = "init";
private static final Logger logger = Logger.getLogger(VGGModel.class.getName());
@@ -127,8 +124,6 @@ public static Graph compile() {
optimizer.minimize(loss, TRAIN);
- tf.init();
-
return graph;
}
@@ -159,8 +154,7 @@ public static Add buildFCLayersAndRegularization(Ops tf, Placeholder oneHot = tf
.oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f));
- SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.raw
- .softmaxCrossEntropyWithLogits(logits, oneHot);
+ SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.softmaxCrossEntropyWithLogits(logits, oneHot);
Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0));
Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math
.add(tf.nn.l2Loss(fc1Biases),
@@ -193,10 +187,6 @@ public static Relu vggConv2DLayer(String layerName, Ops tf, Operand weights = tf.variable(weightShape, TFloat32.class);
- tf.initAdd(tf.assign(weights, tf.zerosLike(weights)));
+ Variable weights = tf.variable(tf.zeros(tf.constant(weightShape), TFloat32.class));
// Create biases with an initial value of 0
Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES);
- Variable biases = tf.variable(biasShape, TFloat32.class);
- tf.initAdd(tf.assign(biases, tf.zerosLike(biases)));
-
- // Register all variable initializers for single execution
- tf.init();
+ Variable biases = tf.variable(tf.zeros(tf.constant(biasShape), TFloat32.class));
// Predict the class of each image in the batch and compute the loss
Softmax softmax =
@@ -103,9 +97,6 @@ public void run() {
// Run the graph
try (Session session = new Session(graph)) {
- // Initialize variables
- session.run(tf.init());
-
// Train the model
for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
try (TFloat32 batchImages = preprocessImages(trainingBatch.images());
@@ -163,8 +154,8 @@ private static TFloat32 preprocessLabels(ByteNdArray rawLabels) {
).asTensor();
}
- private Graph graph;
- private MnistDataset dataset;
+ private final Graph graph;
+ private final MnistDataset dataset;
private SimpleMnist(Graph graph, MnistDataset dataset) {
this.graph = graph;
diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java
index ccd0a76..4e8fbd5 100644
--- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java
+++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java
@@ -16,11 +16,13 @@
*/
package org.tensorflow.model.examples.regression.linear;
+import java.util.List;
+import java.util.Random;
import org.tensorflow.Graph;
import org.tensorflow.Session;
-import org.tensorflow.Tensor;
import org.tensorflow.framework.optimizers.GradientDescent;
import org.tensorflow.framework.optimizers.Optimizer;
+import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
@@ -29,12 +31,8 @@
import org.tensorflow.op.math.Div;
import org.tensorflow.op.math.Mul;
import org.tensorflow.op.math.Pow;
-import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.TFloat32;
-import java.util.List;
-import java.util.Random;
-
/**
* In this example TensorFlow finds the weight and bias of the linear regression during 1 epoch,
* training on observations one by one.
@@ -89,8 +87,6 @@ public static void main(String[] args) {
Op minimize = optimizer.minimize(mse);
try (Session session = new Session(graph)) {
- // Initialize graph variables
- session.run(tf.init());
// Train the model on data
for (int i = 0; i < xValues.length; i++) {