summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-11-27 14:30:12 +0100
committerJon Bratseth <bratseth@oath.com>2018-11-27 14:30:12 +0100
commitec5f88c54f8836c23f82260c2ee094ba9b98fe67 (patch)
tree95aea6d2833c99749c1a6d1e08bb11086513f637 /model-integration
parente12e2d54042b2aeca632ee630f0d67695dfb2f1b (diff)
Remove method only used by tests
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java10
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java6
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java17
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java8
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java6
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java6
6 files changed, 27 insertions, 26 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index 0c5866b87fa..f7fbbbe8b10 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -97,7 +97,6 @@ public class ImportedModel implements ImportedMlModel {
/**
* Returns an immutable map of the functions that are part of this model.
- * Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification.
*/
@Override
public Map<String, String> functions() { return asExpressionStrings(functions); }
@@ -245,15 +244,6 @@ public class ImportedModel implements ImportedMlModel {
*/
public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
- /** Returns the expression this output references */
- public ExpressionFunction outputExpression(String outputName) {
- return new ExpressionFunction(outputName,
- new ArrayList<>(inputs.values()),
- owner().expressions().get(outputs.get(outputName)),
- inputMap(),
- Optional.empty());
- }
-
/** Returns the expression this output references as an imported function */
public ImportedMlFunction outputFunction(String outputName, String functionName) {
return new ImportedMlFunction(functionName,
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index 315456c2613..424e4d6c57c 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.onnx;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
@@ -45,11 +46,10 @@ public class OnnxMnistSoftmaxImportTestCase {
assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder"));
// Check signature
- ExpressionFunction output = model.defaultSignature().outputExpression("add");
+ ImportedMlFunction output = model.defaultSignature().outputFunction("add", "add");
assertNotNull(output);
- assertEquals("add", output.getBody().getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
- output.getBody().getRoot().toString());
+ output.expression());
assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"),
model.inputs().get(model.defaultSignature().inputs().get("Placeholder")));
assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString());
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java
index 1a072f54c89..1b8d06bf964 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java
@@ -1,10 +1,13 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.tensorflow;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import ai.vespa.rankingexpression.importer.ImportedModel;
import org.junit.Test;
+import java.util.List;
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
@@ -22,11 +25,19 @@ public class BatchNormImportTestCase {
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- ExpressionFunction function = signature.outputExpression("y");
+
+ // Test signature
+ ImportedMlFunction function = signature.outputFunction("y", "y");
assertNotNull(function);
- assertEquals("dnn/batch_normalization_3/batchnorm/add_1", function.getBody().getName());
- model.assertEqualResult("X", function.getBody().getName());
assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString());
+
+ // Test outputs
+ List<ImportedMlFunction> outputs = model.get().outputExpressions();
+ assertEquals(1, outputs.size());
+ assertEquals("serving_default.y", outputs.get(0).name());
+ assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString());
+ model.assertEqualResult("X", "dnn/batch_normalization_3/batchnorm/add_1");
}
+
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
index 5e20be051ea..5e5c81ddcf1 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.tensorflow;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import ai.vespa.rankingexpression.importer.ImportedModel;
import com.yahoo.tensor.TensorType;
@@ -31,12 +32,11 @@ public class DropoutImportTestCase {
Assert.assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- ExpressionFunction function = signature.outputExpression("y");
+ ImportedMlFunction function = signature.outputFunction("y", "y");
assertNotNull(function);
- assertEquals("outputs/Maximum", function.getBody().getName());
assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
- function.getBody().getRoot().toString());
- model.assertEqualResult("X", function.getBody().getName());
+ function.expression());
+ model.assertEqualResult("X", "outputs/Maximum");
assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString());
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java
index 28b91b3797a..42cc60608bf 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.tensorflow;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import ai.vespa.rankingexpression.importer.ImportedModel;
import org.junit.Assert;
@@ -22,10 +23,9 @@ public class MnistImportTestCase {
Assert.assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- ExpressionFunction output = signature.outputExpression("y");
+ ImportedMlFunction output = signature.outputFunction("y", "y");
assertNotNull(output);
- assertEquals("dnn/outputs/add", output.getBody().getName());
- model.assertEqualResultSum("input", output.getBody().getName(), 0.00001);
+ model.assertEqualResultSum("input", "dnn/outputs/add", 0.00001);
}
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
index be676186017..50e24f20972 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.tensorflow;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import ai.vespa.rankingexpression.importer.ImportedModel;
import com.yahoo.tensor.Tensor;
@@ -58,11 +59,10 @@ public class TensorFlowMnistSoftmaxImportTestCase {
// ... signature outputs
assertEquals(1, signature.outputs().size());
- ExpressionFunction output = signature.outputExpression("y");
+ ImportedMlFunction output = signature.outputFunction("y", "y");
assertNotNull(output);
- assertEquals("add", output.getBody().getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))",
- output.getBody().getRoot().toString());
+ output.expression());
assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString());
// Test execution