diff options
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Tf2OnnxImportTestCase.java')
-rw-r--r-- | model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Tf2OnnxImportTestCase.java | 61 |
1 files changed, 0 insertions, 61 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Tf2OnnxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Tf2OnnxImportTestCase.java deleted file mode 100644 index 0510a433dd9..00000000000 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Tf2OnnxImportTestCase.java +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.rankingexpression.importer.tensorflow; - -import ai.vespa.rankingexpression.importer.ImportedModel; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import org.junit.Ignore; -import org.junit.Test; - -import static org.junit.Assert.assertEquals; - -/** - * @author lesters - */ -public class Tf2OnnxImportTestCase { - - @Ignore // Ignored because conversion requires python tf2onnx dependencies - tested in system test - @Test - public void testConversionFromTensorFlowToOnnx() { - String modelPath = "src/test/models/tensorflow/mnist_softmax/saved"; - String modelPathToConvert = "src/test/models/tensorflow/mnist_softmax/tf_2_onnx"; - - Tensor argument = placeholderArgument(); - Tensor tensorFlowResult = evaluateTensorFlowModel(modelPath, argument, "Placeholder", "add"); - Tensor tf2OnnxResult = evaluateTensorFlowModel(modelPathToConvert, argument, "Placeholder", "add"); - - assertEquals("Operation 'add' produces equal results", tensorFlowResult, tf2OnnxResult); - } - - private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) { - ImportedModel model = new TensorFlowImporter().importModel("test", path); - String outputExpr = model.signatures().values().iterator().next().outputs().values().iterator().next(); - return evaluateExpression(model.expressions().get(outputExpr), contextFrom(model), argument, input); - } - - private Tensor evaluateExpression(RankingExpression expression, Context context, Tensor argument, String input) { - context.put(input, new TensorValue(argument)); - return expression.evaluate(context).asTensor(); - } - - private Context contextFrom(ImportedModel result) { - MapContext context = new MapContext(); - result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); - result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); - return context; - } - - private Tensor placeholderArgument() { - Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", 784).build()); - for (int d0 = 0; d0 < 1; d0++) - for (int d1 = 0; d1 < 784; d1++) - b.cell(d1 * 1.0 / 784, d0, d1); - return b.build(); - } - - -} |