aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java')
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java26
1 files changed, 12 insertions, 14 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
index b6e83404ab1..e20ac16a691 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
@@ -1,6 +1,5 @@
package com.yahoo.searchlib.rankingexpression.integration.ml;
-import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
@@ -28,28 +27,27 @@ public class OnnxMnistSoftmaxImportTestCase {
Tensor constant0 = model.largeConstants().get("test_Variable");
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
- constant0.type());
+ constant0.type());
assertEquals(7840, constant0.size());
Tensor constant1 = model.largeConstants().get("test_Variable_1");
assertNotNull(constant1);
- assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type());
+ assertEquals(new TensorType.Builder().indexed("d1", 10).build(),
+ constant1.type());
assertEquals(10, constant1.size());
- // Check inputs
- assertEquals(1, model.inputs().size());
- assertTrue(model.inputs().containsKey("Placeholder"));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder"));
+ // Check required functions (inputs)
+ assertEquals(1, model.requiredFunctions().size());
+ assertTrue(model.requiredFunctions().containsKey("Placeholder"));
+ assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
+ model.requiredFunctions().get("Placeholder"));
- // Check signature
- ExpressionFunction output = model.defaultSignature().outputExpression("add");
+ // Check outputs
+ RankingExpression output = model.defaultSignature().outputExpression("add");
assertNotNull(output);
- assertEquals("add", output.getBody().getName());
+ assertEquals("add", output.getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
- output.getBody().getRoot().toString());
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"),
- model.inputs().get(model.defaultSignature().inputs().get("Placeholder")));
- assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString());
+ output.getRoot().toString());
}
@Test