diff options
6 files changed, 27 insertions, 26 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index 0c5866b87fa..f7fbbbe8b10 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -97,7 +97,6 @@ public class ImportedModel implements ImportedMlModel { /** * Returns an immutable map of the functions that are part of this model. - * Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification. */ @Override public Map<String, String> functions() { return asExpressionStrings(functions); } @@ -245,15 +244,6 @@ public class ImportedModel implements ImportedMlModel { */ public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } - /** Returns the expression this output references */ - public ExpressionFunction outputExpression(String outputName) { - return new ExpressionFunction(outputName, - new ArrayList<>(inputs.values()), - owner().expressions().get(outputs.get(outputName)), - inputMap(), - Optional.empty()); - } - /** Returns the expression this output references as an imported function */ public ImportedMlFunction outputFunction(String outputName, String functionName) { return new ImportedMlFunction(functionName, diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index 315456c2613..424e4d6c57c 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer.onnx; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; @@ -45,11 +46,10 @@ public class OnnxMnistSoftmaxImportTestCase { assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder")); // Check signature - ExpressionFunction output = model.defaultSignature().outputExpression("add"); + ImportedMlFunction output = model.defaultSignature().outputFunction("add", "add"); assertNotNull(output); - assertEquals("add", output.getBody().getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", - output.getBody().getRoot().toString()); + output.expression()); assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get(model.defaultSignature().inputs().get("Placeholder"))); assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString()); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java index 1a072f54c89..1b8d06bf964 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java @@ -1,10 +1,13 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer.tensorflow; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import ai.vespa.rankingexpression.importer.ImportedModel; import org.junit.Test; +import java.util.List; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -22,11 +25,19 @@ public class BatchNormImportTestCase { assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ExpressionFunction function = signature.outputExpression("y"); + + // Test signature + ImportedMlFunction function = signature.outputFunction("y", "y"); assertNotNull(function); - assertEquals("dnn/batch_normalization_3/batchnorm/add_1", function.getBody().getName()); - model.assertEqualResult("X", function.getBody().getName()); assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString()); + + // Test outputs + List<ImportedMlFunction> outputs = model.get().outputExpressions(); + assertEquals(1, outputs.size()); + assertEquals("serving_default.y", outputs.get(0).name()); + assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString()); + model.assertEqualResult("X", "dnn/batch_normalization_3/batchnorm/add_1"); } + } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java index 5e20be051ea..5e5c81ddcf1 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer.tensorflow; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import ai.vespa.rankingexpression.importer.ImportedModel; import com.yahoo.tensor.TensorType; @@ -31,12 +32,11 @@ public class DropoutImportTestCase { Assert.assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ExpressionFunction function = signature.outputExpression("y"); + ImportedMlFunction function = signature.outputFunction("y", "y"); assertNotNull(function); - assertEquals("outputs/Maximum", function.getBody().getName()); assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))", - function.getBody().getRoot().toString()); - model.assertEqualResult("X", function.getBody().getName()); + function.expression()); + model.assertEqualResult("X", "outputs/Maximum"); assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString()); } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java index 28b91b3797a..42cc60608bf 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer.tensorflow; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import ai.vespa.rankingexpression.importer.ImportedModel; import org.junit.Assert; @@ -22,10 +23,9 @@ public class MnistImportTestCase { Assert.assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ExpressionFunction output = signature.outputExpression("y"); + ImportedMlFunction output = signature.outputFunction("y", "y"); assertNotNull(output); - assertEquals("dnn/outputs/add", output.getBody().getName()); - model.assertEqualResultSum("input", output.getBody().getName(), 0.00001); + model.assertEqualResultSum("input", "dnn/outputs/add", 0.00001); } } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java index be676186017..50e24f20972 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer.tensorflow; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import ai.vespa.rankingexpression.importer.ImportedModel; import com.yahoo.tensor.Tensor; @@ -58,11 +59,10 @@ public class TensorFlowMnistSoftmaxImportTestCase { // ... signature outputs assertEquals(1, signature.outputs().size()); - ExpressionFunction output = signature.outputExpression("y"); + ImportedMlFunction output = signature.outputFunction("y", "y"); assertNotNull(output); - assertEquals("add", output.getBody().getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))", - output.getBody().getRoot().toString()); + output.expression()); assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString()); // Test execution |