diff options
author | Lester Solbakken <lesters@oath.com> | 2020-06-19 15:06:04 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-06-19 15:06:04 +0200 |
commit | a9ede0792556b4ccb54af6a6367f4c7395e80f75 (patch) | |
tree | bd2b7ea95972ac04b65362e715656ca76a2dd731 /model-integration/src/test/java/ai | |
parent | 92b73034306ca58f6841f158149bd048bddb374f (diff) |
Replace joins with same input with a map
Diffstat (limited to 'model-integration/src/test/java/ai')
-rw-r--r-- | model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java index 7b9868d71f5..7a0488362a9 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -579,6 +579,23 @@ public class OnnxOperationsTestCase { assertEval("expand", input, shape, evaluate("tensor(d0[2],d1[3],d2[4]):[1,1,1,1,2,2,2,2,3,3,3,3,1,1,1,1,2,2,2,2,3,3,3,3]")); } + @Test + public void testJoinWithSameInput() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); + String opName = "mul"; + + Context context = new MapContext(DoubleValue.NaN); + List<IntermediateOperation> inputs = new ArrayList<>(); + inputs.add(addInput(inputs, context, x, "x")); + IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, createAttributes().build(), 0); + optimizeAndRename(opName, op); + + Tensor result = evaluate(op); + Tensor expected = evaluate("tensor(d0[2],d1[3]):[1,4,9,16,25,36]"); + assertEquals(expected, result); + assertEquals(expected.type(), result.type()); + } + private Tensor evaluate(String expr) throws ParseException { return evaluate(expr, null, null, null); } @@ -668,12 +685,13 @@ public class OnnxOperationsTestCase { return inputs; } - private void addInput(List<IntermediateOperation> inputs, Context context, Tensor x, String name) { - if (x == null) return; + private IntermediateOperation addInput(List<IntermediateOperation> inputs, Context context, Tensor x, String name) { + if (x == null) return null; context.put(name, new TensorValue(x)); IntermediateOperation op = new Constant(modelName, name, OrderedTensorType.fromSpec(x.type().toString())); op.setConstantValueFunction(type -> new TensorValue(convertTypeAfterRename(x, type))); inputs.add(op); + return op; } Tensor convertTypeAfterRename(Tensor tensor, OrderedTensorType type) { |