summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-06-19 15:06:04 +0200
committerLester Solbakken <lesters@oath.com>2020-06-19 15:06:04 +0200
commita9ede0792556b4ccb54af6a6367f4c7395e80f75 (patch)
treebd2b7ea95972ac04b65362e715656ca76a2dd731 /model-integration/src/test/java/ai
parent92b73034306ca58f6841f158149bd048bddb374f (diff)
Replace joins with same input with a map
Diffstat (limited to 'model-integration/src/test/java/ai')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java22
1 files changed, 20 insertions, 2 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
index 7b9868d71f5..7a0488362a9 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
@@ -579,6 +579,23 @@ public class OnnxOperationsTestCase {
assertEval("expand", input, shape, evaluate("tensor(d0[2],d1[3],d2[4]):[1,1,1,1,2,2,2,2,3,3,3,3,1,1,1,1,2,2,2,2,3,3,3,3]"));
}
+ @Test
+ public void testJoinWithSameInput() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+ String opName = "mul";
+
+ Context context = new MapContext(DoubleValue.NaN);
+ List<IntermediateOperation> inputs = new ArrayList<>();
+ inputs.add(addInput(inputs, context, x, "x"));
+ IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, createAttributes().build(), 0);
+ optimizeAndRename(opName, op);
+
+ Tensor result = evaluate(op);
+ Tensor expected = evaluate("tensor(d0[2],d1[3]):[1,4,9,16,25,36]");
+ assertEquals(expected, result);
+ assertEquals(expected.type(), result.type());
+ }
+
private Tensor evaluate(String expr) throws ParseException {
return evaluate(expr, null, null, null);
}
@@ -668,12 +685,13 @@ public class OnnxOperationsTestCase {
return inputs;
}
- private void addInput(List<IntermediateOperation> inputs, Context context, Tensor x, String name) {
- if (x == null) return;
+ private IntermediateOperation addInput(List<IntermediateOperation> inputs, Context context, Tensor x, String name) {
+ if (x == null) return null;
context.put(name, new TensorValue(x));
IntermediateOperation op = new Constant(modelName, name, OrderedTensorType.fromSpec(x.type().toString()));
op.setConstantValueFunction(type -> new TensorValue(convertTypeAfterRename(x, type)));
inputs.add(op);
+ return op;
}
Tensor convertTypeAfterRename(Tensor tensor, OrderedTensorType type) {