aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2018-01-25 14:17:08 +0100
committerGitHub <noreply@github.com>2018-01-25 14:17:08 +0100
commitce0178998956ae2fea340d5e23e9f17c0e5c3db6 (patch)
tree46ef9c7f4a0ebfe7360cb4f771ac0774bea22d36
parent819533f55f0f137d30c6828b7851a7e0d3010ed7 (diff)
parent1fe988b927663cd39a8f4189ecf05f202bb5c7c9 (diff)
Merge pull request #4779 from vespa-engine/bratseth/tensorflow-cleanup
Bratseth/tensorflow cleanup
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java11
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java3
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java29
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java62
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java126
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java71
7 files changed, 169 insertions, 136 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index c95601f6bbf..5343d4622c7 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -33,7 +33,6 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
-import java.util.logging.Logger;
/**
* Replaces instances of the tensorflow(model-path, signature, output)
@@ -46,8 +45,6 @@ import java.util.logging.Logger;
// TODO: Avoid name conflicts across models for constants
public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
- private static final Logger log = Logger.getLogger(TensorFlowFeatureConverter.class.getName());
-
private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
@@ -68,8 +65,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
try {
ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(),
- feature.getArguments());
- if (store.hasTensorFlowModels())
+ feature.getArguments());
+ if (store.hasTensorFlowModels()) // TODO: Check if we have created a converted model already instead
return transformFromTensorFlowModel(store, context.rankProfile());
else // is should have previously stored model information instead
return transformFromStoredModel(store, context.rankProfile());
@@ -206,7 +203,6 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
* Adds this expression to the application package, such that it can be read later.
*/
public void writeConverted(RankingExpression expression) {
- log.info("Writing converted TensorFlow expression to " + arguments.expressionPath());
application.getFile(arguments.expressionPath())
.writeFile(new StringReader(expression.getRoot().toString()));
}
@@ -214,7 +210,6 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
/** Reads the previously stored ranking expression for these arguments */
public RankingExpression readConverted() {
try {
- log.info("Reading converted TensorFlow expression from " + arguments.expressionPath());
return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
}
catch (IOException e) {
@@ -261,12 +256,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
// Remember the constant in a file we replicate in ZooKeeper
- log.info("Writing converted TensorFlow constant information to " + arguments.rankingConstantsPath().append(name + ".constant"));
application.getFile(arguments.rankingConstantsPath().append(name + ".constant"))
.writeFile(new StringReader(name + ":" + constant.type() + ":" + constantPathCorrected));
// Write content explicitly as a file on the file system as this is distributed using file distribution
- log.info("Writing converted TensorFlow constant to " + application.getFileReference(constantPath).getAbsolutePath());
createIfNeeded(constantsPath);
IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
return constantPathCorrected;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
index ca880e6f310..8edb9b9b7a1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
@@ -13,6 +13,8 @@ import java.nio.LongBuffer;
/**
+ * Converts TensorFlow tensors into Vespa tensors.
+ *
* @author bratseth
*/
public class TensorConverter {
@@ -149,4 +151,5 @@ public class TensorConverter {
}
}
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index b9e244a3e08..4780c39d21d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -159,7 +159,8 @@ public class TensorFlowImporter {
// not supported
default :
- throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported (" + tfNode.getName() + ")");
+ throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() +
+ "' is not supported (" + tfNode.getName() + ")");
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
new file mode 100644
index 00000000000..c6ee586a78c
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
@@ -0,0 +1,29 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * @author lesters
+ */
+public class BatchNormImportTestCase {
+
+ @Test
+ public void testBatchNormImport() {
+ TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/batch_norm/saved");
+ TensorFlowModel.Signature signature = model.get().signature("serving_default");
+
+ assertEquals("Has skipped outputs",
+ 0, model.get().signature("serving_default").skippedOutputs().size());
+
+ RankingExpression output = signature.outputExpression("y");
+ assertNotNull(output);
+ assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName());
+ model.assertEqualResult("X", output.getName());
+ }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
new file mode 100644
index 00000000000..f12b9a2c628
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -0,0 +1,62 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * @author bratseth
+ */
+public class MnistSoftmaxImportTestCase {
+
+ @Test
+ public void testMnistSoftmaxImport() {
+ TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/mnist_softmax/saved");
+
+ // Check constants
+ assertEquals(2, model.get().constants().size());
+
+ Tensor constant0 = model.get().constants().get("Variable");
+ assertNotNull(constant0);
+ assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
+ constant0.type());
+ assertEquals(7840, constant0.size());
+
+ Tensor constant1 = model.get().constants().get("Variable_1");
+ assertNotNull(constant1);
+ assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
+ constant1.type());
+ assertEquals(10, constant1.size());
+
+ // Check signatures
+ assertEquals(1, model.get().signatures().size());
+ TensorFlowModel.Signature signature = model.get().signatures().get("serving_default");
+ assertNotNull(signature);
+
+ // ... signature inputs
+ assertEquals(1, signature.inputs().size());
+ TensorType argument0 = signature.inputArgument("x");
+ assertNotNull(argument0);
+ assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);
+
+ // ... signature outputs
+ assertEquals(1, signature.outputs().size());
+ RankingExpression output = signature.outputExpression("y");
+ assertNotNull(output);
+ assertEquals("add", output.getName());
+ assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))",
+ output.getRoot().toString());
+
+ // Test execution
+ model.assertEqualResult("Placeholder", "Variable/read");
+ model.assertEqualResult("Placeholder", "Variable_1/read");
+ model.assertEqualResult("Placeholder", "MatMul");
+ model.assertEqualResult("Placeholder", "add");
+ }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
deleted file mode 100644
index 13d042ee5dd..00000000000
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import org.junit.Test;
-import org.tensorflow.SavedModelBundle;
-import org.tensorflow.Session;
-
-import java.nio.FloatBuffer;
-import java.util.List;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-
-/**
- * @author bratseth
- */
-public class TensorflowImportTestCase {
-
- @Test
- public void testMnistSoftmaxImport() {
- String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved";
- SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
- TensorFlowModel result = new TensorFlowImporter().importModel(model);
-
- // Check constants
- assertEquals(2, result.constants().size());
-
- Tensor constant0 = result.constants().get("Variable");
- assertNotNull(constant0);
- assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
- constant0.type());
- assertEquals(7840, constant0.size());
-
- Tensor constant1 = result.constants().get("Variable_1");
- assertNotNull(constant1);
- assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
- constant1.type());
- assertEquals(10, constant1.size());
-
- // Check signatures
- assertEquals(1, result.signatures().size());
- TensorFlowModel.Signature signature = result.signatures().get("serving_default");
- assertNotNull(signature);
-
- // ... signature inputs
- assertEquals(1, signature.inputs().size());
- TensorType argument0 = signature.inputArgument("x");
- assertNotNull(argument0);
- assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);
-
- // ... signature outputs
- assertEquals(1, signature.outputs().size());
- RankingExpression output = signature.outputExpression("y");
- assertNotNull(output);
- assertEquals("add", output.getName());
- assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))",
- toNonPrimitiveString(output));
-
- // Test execution
- assertEqualResult(model, result, "Placeholder", "Variable/read");
- assertEqualResult(model, result, "Placeholder", "Variable_1/read");
- assertEqualResult(model, result, "Placeholder", "MatMul");
- assertEqualResult(model, result, "Placeholder", "add");
- }
-
- @Test
- public void testBatchNormImport() {
- String modelDir = "src/test/files/integration/tensorflow/batch_norm/saved";
- SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
- TensorFlowModel result = new TensorFlowImporter().importModel(model);
- TensorFlowModel.Signature signature = result.signature("serving_default");
-
- assertEquals("Has skipped outputs", 0, result.signature("serving_default").skippedOutputs().size());
-
- RankingExpression output = signature.outputExpression("y");
- assertNotNull(output);
- assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName());
- assertEqualResult(model, result, "X", output.getName());
-
- }
-
- private void assertEqualResult(SavedModelBundle model, TensorFlowModel result, String inputName, String operationName) {
- Tensor tfResult = tensorFlowExecute(model, inputName, operationName);
- Context context = contextFrom(result);
- Tensor placeholder = placeholderArgument();
- context.put(inputName, new TensorValue(placeholder));
- Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor();
- assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
- }
-
- private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
- Session.Runner runner = model.session().runner();
- org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784));
- runner.feed(inputName, placeholder);
- List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
- assertEquals(1, results.size());
- return new TensorConverter().toVespaTensor(results.get(0));
- }
-
- private Context contextFrom(TensorFlowModel result) {
- MapContext context = new MapContext();
- result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor)));
- return context;
- }
-
- private String toNonPrimitiveString(RankingExpression expression) {
- // toString on the wrapping expression will map to primitives, which is harder to read
- return ((TensorFunctionNode)expression.getRoot()).function().toString();
- }
-
- private Tensor placeholderArgument() {
- int size = 784;
- Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", size).build());
- for (int i = 0; i < size; i++)
- b.cell(0, 0, i);
- return b.build();
- }
-
-}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
new file mode 100644
index 00000000000..186717d24cd
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
@@ -0,0 +1,71 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+
+import java.nio.FloatBuffer;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Helper for TensorFlow import tests: Imports a model and provides asserts on it.
+ * This currently assumes the TensorFlow model takes a single input of type tensor(d0[1],d1[784])
+ *
+ * @author bratseth
+ */
+public class TestableTensorFlowModel {
+
+ private SavedModelBundle tensorFlowModel;
+ private TensorFlowModel model;
+
+ // Sizes of the input vector
+ private final int d0Size = 1;
+ private final int d1Size = 784;
+
+ public TestableTensorFlowModel(String modelDir) {
+ tensorFlowModel = SavedModelBundle.load(modelDir, "serve");
+ model = new TensorFlowImporter().importModel(tensorFlowModel);
+ }
+
+ public TensorFlowModel get() { return model; }
+
+ public void assertEqualResult(String inputName, String operationName) {
+ Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName);
+ Context context = contextFrom(model);
+ Tensor placeholder = placeholderArgument();
+ context.put(inputName, new TensorValue(placeholder));
+ Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor();
+ assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
+ }
+
+ private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
+ Session.Runner runner = model.session().runner();
+ org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size },
+ FloatBuffer.allocate(d0Size * d1Size));
+ runner.feed(inputName, placeholder);
+ List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
+ assertEquals(1, results.size());
+ return new TensorConverter().toVespaTensor(results.get(0));
+ }
+
+ private Context contextFrom(TensorFlowModel result) {
+ MapContext context = new MapContext();
+ result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor)));
+ return context;
+ }
+
+ private Tensor placeholderArgument() {
+ Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build());
+ for (int d0 = 0; d0 < d0Size; d0++)
+ for (int d1 = 0; d1 < d1Size; d1++)
+ b.cell(0, d0, d1);
+ return b.build();
+ }
+
+}