diff options
author | Arne H Juul <arnej@yahooinc.com> | 2022-01-06 18:30:08 +0000 |
---|---|---|
committer | Arne H Juul <arnej@yahooinc.com> | 2022-01-07 07:17:26 +0000 |
commit | 696e624b9cc9e1f4033c7bfc05f17e2cf33430d1 (patch) | |
tree | 04607404bbd59cf3e114ee7968272868df9527f7 /model-integration/src/test/java/ai | |
parent | 0867ac297c706bf962c2154ba2425f3a2ba2fa88 (diff) |
specialize TensorFunction etc on Reference
Diffstat (limited to 'model-integration/src/test/java/ai')
-rw-r--r-- | model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java index dfc4e98d409..3ef96cdf166 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.operations.Constant; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; @@ -703,7 +704,7 @@ public class OnnxOperationsTestCase { return builder.build(); } - private TensorFunction optimizeAndRename(String opName, IntermediateOperation op) { + private TensorFunction<Reference> optimizeAndRename(String opName, IntermediateOperation op) { IntermediateGraph graph = new IntermediateGraph(modelName); graph.put(opName, op); graph.outputs(graph.defaultSignature()).put(opName, opName); @@ -717,7 +718,7 @@ public class OnnxOperationsTestCase { if ( ! operationType.equals(standardNamingType)) { List<String> renameFrom = operationType.dimensionNames(); List<String> renameTo = standardNamingType.dimensionNames(); - TensorFunction func = new Rename(new ConstantTensor(tensor), renameFrom, renameTo); + TensorFunction<Reference> func = new Rename<>(new ConstantTensor<Reference>(tensor), renameFrom, renameTo); return func.evaluate(); } return tensor; |