diff options
author | Harald Musum <musum@oath.com> | 2018-06-06 14:23:55 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-06 14:23:55 +0200 |
commit | 681963959794b47102d1a1cf72f215c72b0e2b51 (patch) | |
tree | 12ea280192a44f26b9718018c7cfb39b0c4c4735 /searchlib/src/test/java/com | |
parent | 240176d60c44507f4e6733c7512620e80554c8de (diff) |
Revert "Refactor ONNX and TF import to use same code base"
Diffstat (limited to 'searchlib/src/test/java/com')
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java) | 22 | ||||
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java) | 4 | ||||
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java) | 4 | ||||
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java) | 6 | ||||
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java) | 4 | ||||
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java) | 4 | ||||
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java) | 14 | ||||
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java) | 2 |
8 files changed, 35 insertions, 25 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java index a7926cd2e02..4b68cd40a08 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -1,9 +1,11 @@ -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.onnx; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -22,7 +24,7 @@ public class OnnxMnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() throws IOException { - ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"); + OnnxModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"); // Check constants assertEquals(2, model.largeConstants().size()); @@ -46,7 +48,7 @@ public class OnnxMnistSoftmaxImportTestCase { model.requiredMacros().get("Placeholder")); // Check outputs - RankingExpression output = model.defaultSignature().outputExpression("add"); + RankingExpression output = model.outputExpression("add"); assertNotNull(output); assertEquals("add", output.getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", @@ -66,12 +68,13 @@ public class OnnxMnistSoftmaxImportTestCase { } private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) { - ImportedModel model = new TensorFlowImporter().importModel("test", path); + SavedModelBundle tensorFlowModel = SavedModelBundle.load(path, "serve"); + TensorFlowModel model = new TensorFlowImporter().importModel("test", tensorFlowModel); return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input); } private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) { - ImportedModel model = new OnnxImporter().importModel("test", path); + OnnxModel model = new OnnxImporter().importModel("test", path); return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input); } @@ -80,7 +83,14 @@ public class OnnxMnistSoftmaxImportTestCase { return expression.evaluate(context).asTensor(); } - private Context contextFrom(ImportedModel result) { + private Context contextFrom(TensorFlowModel result) { + MapContext context = new MapContext(); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + return context; + } + + private Context contextFrom(OnnxModel result) { MapContext context = new MapContext(); result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java index bf9684082f4..0f5eec93feb 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import org.junit.Test; @@ -15,7 +15,7 @@ public class BatchNormImportTestCase { @Test public void testBatchNormImport() { TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved"); - ImportedModel.Signature signature = model.get().signature("serving_default"); + TensorFlowModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java index c8c7ec798bb..74b0d11f1d6 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java @@ -1,6 +1,6 @@ -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; import org.junit.Test; import static org.junit.Assert.assertTrue; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java index a63c7346335..50a467ec581 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.TensorType; @@ -24,7 +24,7 @@ public class DropoutImportTestCase { assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), model.get().requiredMacros().get("X")); - ImportedModel.Signature signature = model.get().signature("serving_default"); + TensorFlowModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); @@ -32,7 +32,7 @@ public class DropoutImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("outputs/Maximum", output.getName()); - assertEquals("join(join(imported_ml_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))", + assertEquals("join(join(tf_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))", output.getRoot().toString()); model.assertEqualResult("X", output.getName()); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index bd7644be23b..9f919c452d6 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; @@ -45,7 +45,7 @@ public class MnistSoftmaxImportTestCase { // Check signatures assertEquals(1, model.get().signatures().size()); - ImportedModel.Signature signature = model.get().signatures().get("serving_default"); + TensorFlowModel.Signature signature = model.get().signatures().get("serving_default"); assertNotNull(signature); // ... signature inputs diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java index b2443082ab1..beec2ab1ead 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java @@ -1,6 +1,6 @@ -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java index 723c5f27914..7ca16939477 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -1,11 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; @@ -28,7 +28,7 @@ import static org.junit.Assert.assertEquals; public class TestableTensorFlowModel { private SavedModelBundle tensorFlowModel; - private ImportedModel model; + private TensorFlowModel model; // Sizes of the input vector private final int d0Size = 1; @@ -39,7 +39,7 @@ public class TestableTensorFlowModel { model = new TensorFlowImporter().importModel(modelName, tensorFlowModel); } - public ImportedModel get() { return model; } + public TensorFlowModel get() { return model; } public void assertEqualResult(String inputName, String operationName) { Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName); @@ -66,7 +66,7 @@ public class TestableTensorFlowModel { return TensorConverter.toVespaTensor(results.get(0)); } - private Context contextFrom(ImportedModel result) { + private Context contextFrom(TensorFlowModel result) { MapContext context = new MapContext(); result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); @@ -81,7 +81,7 @@ public class TestableTensorFlowModel { return b.build(); } - private void evaluateMacro(Context context, ImportedModel model, String macroName) { + private void evaluateMacro(Context context, TensorFlowModel model, String macroName) { if (!context.names().contains(macroName)) { RankingExpression e = model.macros().get(macroName); evaluateMacroDependencies(context, model, e.getRoot()); @@ -89,7 +89,7 @@ public class TestableTensorFlowModel { } } - private void evaluateMacroDependencies(Context context, ImportedModel model, ExpressionNode node) { + private void evaluateMacroDependencies(Context context, TensorFlowModel model, ExpressionNode node) { if (node instanceof ReferenceNode) { String name = node.toString(); if (model.macros().containsKey(name)) { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java index f94098e6255..051c2c60c95 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java @@ -1,4 +1,4 @@ -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import org.junit.Test; |