diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2018-01-25 14:17:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-01-25 14:17:08 +0100 |
commit | ce0178998956ae2fea340d5e23e9f17c0e5c3db6 (patch) | |
tree | 46ef9c7f4a0ebfe7360cb4f771ac0774bea22d36 | |
parent | 819533f55f0f137d30c6828b7851a7e0d3010ed7 (diff) | |
parent | 1fe988b927663cd39a8f4189ecf05f202bb5c7c9 (diff) |
Merge pull request #4779 from vespa-engine/bratseth/tensorflow-cleanup
Bratseth/tensorflow cleanup
7 files changed, 169 insertions, 136 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index c95601f6bbf..5343d4622c7 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -33,7 +33,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.logging.Logger; /** * Replaces instances of the tensorflow(model-path, signature, output) @@ -46,8 +45,6 @@ import java.util.logging.Logger; // TODO: Avoid name conflicts across models for constants public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - private static final Logger log = Logger.getLogger(TensorFlowFeatureConverter.class.getName()); - private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ @@ -68,8 +65,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil try { ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), - feature.getArguments()); - if (store.hasTensorFlowModels()) + feature.getArguments()); + if (store.hasTensorFlowModels()) // TODO: Check if we have created a converted model already instead return transformFromTensorFlowModel(store, context.rankProfile()); else // is should have previously stored model information instead return transformFromStoredModel(store, context.rankProfile()); @@ -206,7 +203,6 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil * Adds this expression to the application package, such that it can be read later. */ public void writeConverted(RankingExpression expression) { - log.info("Writing converted TensorFlow expression to " + arguments.expressionPath()); application.getFile(arguments.expressionPath()) .writeFile(new StringReader(expression.getRoot().toString())); } @@ -214,7 +210,6 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil /** Reads the previously stored ranking expression for these arguments */ public RankingExpression readConverted() { try { - log.info("Reading converted TensorFlow expression from " + arguments.expressionPath()); return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); } catch (IOException e) { @@ -261,12 +256,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } // Remember the constant in a file we replicate in ZooKeeper - log.info("Writing converted TensorFlow constant information to " + arguments.rankingConstantsPath().append(name + ".constant")); application.getFile(arguments.rankingConstantsPath().append(name + ".constant")) .writeFile(new StringReader(name + ":" + constant.type() + ":" + constantPathCorrected)); // Write content explicitly as a file on the file system as this is distributed using file distribution - log.info("Writing converted TensorFlow constant to " + application.getFileReference(constantPath).getAbsolutePath()); createIfNeeded(constantsPath); IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); return constantPathCorrected; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java index ca880e6f310..8edb9b9b7a1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java @@ -13,6 +13,8 @@ import java.nio.LongBuffer; /** + * Converts TensorFlow tensors into Vespa tensors. + * * @author bratseth */ public class TensorConverter { @@ -149,4 +151,5 @@ public class TensorConverter { } } + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index b9e244a3e08..4780c39d21d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -159,7 +159,8 @@ public class TensorFlowImporter { // not supported default : - throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported (" + tfNode.getName() + ")"); + throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + + "' is not supported (" + tfNode.getName() + ")"); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java new file mode 100644 index 00000000000..c6ee586a78c --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java @@ -0,0 +1,29 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * @author lesters + */ +public class BatchNormImportTestCase { + + @Test + public void testBatchNormImport() { + TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/batch_norm/saved"); + TensorFlowModel.Signature signature = model.get().signature("serving_default"); + + assertEquals("Has skipped outputs", + 0, model.get().signature("serving_default").skippedOutputs().size()); + + RankingExpression output = signature.outputExpression("y"); + assertNotNull(output); + assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName()); + model.assertEqualResult("X", output.getName()); + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java new file mode 100644 index 00000000000..f12b9a2c628 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -0,0 +1,62 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * @author bratseth + */ +public class MnistSoftmaxImportTestCase { + + @Test + public void testMnistSoftmaxImport() { + TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/mnist_softmax/saved"); + + // Check constants + assertEquals(2, model.get().constants().size()); + + Tensor constant0 = model.get().constants().get("Variable"); + assertNotNull(constant0); + assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), + constant0.type()); + assertEquals(7840, constant0.size()); + + Tensor constant1 = model.get().constants().get("Variable_1"); + assertNotNull(constant1); + assertEquals(new TensorType.Builder().indexed("d0", 10).build(), + constant1.type()); + assertEquals(10, constant1.size()); + + // Check signatures + assertEquals(1, model.get().signatures().size()); + TensorFlowModel.Signature signature = model.get().signatures().get("serving_default"); + assertNotNull(signature); + + // ... signature inputs + assertEquals(1, signature.inputs().size()); + TensorType argument0 = signature.inputArgument("x"); + assertNotNull(argument0); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); + + // ... signature outputs + assertEquals(1, signature.outputs().size()); + RankingExpression output = signature.outputExpression("y"); + assertNotNull(output); + assertEquals("add", output.getName()); + assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))", + output.getRoot().toString()); + + // Test execution + model.assertEqualResult("Placeholder", "Variable/read"); + model.assertEqualResult("Placeholder", "Variable_1/read"); + model.assertEqualResult("Placeholder", "MatMul"); + model.assertEqualResult("Placeholder", "add"); + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java deleted file mode 100644 index 13d042ee5dd..00000000000 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import org.junit.Test; -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Session; - -import java.nio.FloatBuffer; -import java.util.List; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - -/** - * @author bratseth - */ -public class TensorflowImportTestCase { - - @Test - public void testMnistSoftmaxImport() { - String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved"; - SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); - TensorFlowModel result = new TensorFlowImporter().importModel(model); - - // Check constants - assertEquals(2, result.constants().size()); - - Tensor constant0 = result.constants().get("Variable"); - assertNotNull(constant0); - assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), - constant0.type()); - assertEquals(7840, constant0.size()); - - Tensor constant1 = result.constants().get("Variable_1"); - assertNotNull(constant1); - assertEquals(new TensorType.Builder().indexed("d0", 10).build(), - constant1.type()); - assertEquals(10, constant1.size()); - - // Check signatures - assertEquals(1, result.signatures().size()); - TensorFlowModel.Signature signature = result.signatures().get("serving_default"); - assertNotNull(signature); - - // ... signature inputs - assertEquals(1, signature.inputs().size()); - TensorType argument0 = signature.inputArgument("x"); - assertNotNull(argument0); - assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); - - // ... signature outputs - assertEquals(1, signature.outputs().size()); - RankingExpression output = signature.outputExpression("y"); - assertNotNull(output); - assertEquals("add", output.getName()); - assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))", - toNonPrimitiveString(output)); - - // Test execution - assertEqualResult(model, result, "Placeholder", "Variable/read"); - assertEqualResult(model, result, "Placeholder", "Variable_1/read"); - assertEqualResult(model, result, "Placeholder", "MatMul"); - assertEqualResult(model, result, "Placeholder", "add"); - } - - @Test - public void testBatchNormImport() { - String modelDir = "src/test/files/integration/tensorflow/batch_norm/saved"; - SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); - TensorFlowModel result = new TensorFlowImporter().importModel(model); - TensorFlowModel.Signature signature = result.signature("serving_default"); - - assertEquals("Has skipped outputs", 0, result.signature("serving_default").skippedOutputs().size()); - - RankingExpression output = signature.outputExpression("y"); - assertNotNull(output); - assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName()); - assertEqualResult(model, result, "X", output.getName()); - - } - - private void assertEqualResult(SavedModelBundle model, TensorFlowModel result, String inputName, String operationName) { - Tensor tfResult = tensorFlowExecute(model, inputName, operationName); - Context context = contextFrom(result); - Tensor placeholder = placeholderArgument(); - context.put(inputName, new TensorValue(placeholder)); - Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor(); - assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); - } - - private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) { - Session.Runner runner = model.session().runner(); - org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784)); - runner.feed(inputName, placeholder); - List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run(); - assertEquals(1, results.size()); - return new TensorConverter().toVespaTensor(results.get(0)); - } - - private Context contextFrom(TensorFlowModel result) { - MapContext context = new MapContext(); - result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); - return context; - } - - private String toNonPrimitiveString(RankingExpression expression) { - // toString on the wrapping expression will map to primitives, which is harder to read - return ((TensorFunctionNode)expression.getRoot()).function().toString(); - } - - private Tensor placeholderArgument() { - int size = 784; - Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", size).build()); - for (int i = 0; i < size; i++) - b.cell(0, 0, i); - return b.build(); - } - -} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java new file mode 100644 index 00000000000..186717d24cd --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -0,0 +1,71 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; + +import java.nio.FloatBuffer; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * Helper for TensorFlow import tests: Imports a model and provides asserts on it. + * This currently assumes the TensorFlow model takes a single input of type tensor(d0[1],d1[784]) + * + * @author bratseth + */ +public class TestableTensorFlowModel { + + private SavedModelBundle tensorFlowModel; + private TensorFlowModel model; + + // Sizes of the input vector + private final int d0Size = 1; + private final int d1Size = 784; + + public TestableTensorFlowModel(String modelDir) { + tensorFlowModel = SavedModelBundle.load(modelDir, "serve"); + model = new TensorFlowImporter().importModel(tensorFlowModel); + } + + public TensorFlowModel get() { return model; } + + public void assertEqualResult(String inputName, String operationName) { + Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName); + Context context = contextFrom(model); + Tensor placeholder = placeholderArgument(); + context.put(inputName, new TensorValue(placeholder)); + Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor(); + assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); + } + + private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) { + Session.Runner runner = model.session().runner(); + org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, + FloatBuffer.allocate(d0Size * d1Size)); + runner.feed(inputName, placeholder); + List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run(); + assertEquals(1, results.size()); + return new TensorConverter().toVespaTensor(results.get(0)); + } + + private Context contextFrom(TensorFlowModel result) { + MapContext context = new MapContext(); + result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); + return context; + } + + private Tensor placeholderArgument() { + Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build()); + for (int d0 = 0; d0 < d0Size; d0++) + for (int d1 = 0; d1 < d1Size; d1++) + b.cell(0, d0, d1); + return b.build(); + } + +} |