diff options
Diffstat (limited to 'searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java')
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java | 26 |
1 files changed, 14 insertions, 12 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java index e20ac16a691..b6e83404ab1 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java @@ -1,5 +1,6 @@ package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; @@ -27,27 +28,28 @@ public class OnnxMnistSoftmaxImportTestCase { Tensor constant0 = model.largeConstants().get("test_Variable"); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), - constant0.type()); + constant0.type()); assertEquals(7840, constant0.size()); Tensor constant1 = model.largeConstants().get("test_Variable_1"); assertNotNull(constant1); - assertEquals(new TensorType.Builder().indexed("d1", 10).build(), - constant1.type()); + assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); assertEquals(10, constant1.size()); - // Check required functions (inputs) - assertEquals(1, model.requiredFunctions().size()); - assertTrue(model.requiredFunctions().containsKey("Placeholder")); - assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.requiredFunctions().get("Placeholder")); + // Check inputs + assertEquals(1, model.inputs().size()); + assertTrue(model.inputs().containsKey("Placeholder")); + assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder")); - // Check outputs - RankingExpression output = model.defaultSignature().outputExpression("add"); + // Check signature + ExpressionFunction output = model.defaultSignature().outputExpression("add"); assertNotNull(output); - assertEquals("add", output.getName()); + assertEquals("add", output.getBody().getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", - output.getRoot().toString()); + output.getBody().getRoot().toString()); + assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), + model.inputs().get(model.defaultSignature().inputs().get("Placeholder"))); + assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString()); } @Test |