summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai
diff options
context:
space:
mode:
authorArne H Juul <arnej@yahooinc.com>2022-01-06 18:30:08 +0000
committerArne H Juul <arnej@yahooinc.com>2022-01-07 07:17:26 +0000
commit696e624b9cc9e1f4033c7bfc05f17e2cf33430d1 (patch)
tree04607404bbd59cf3e114ee7968272868df9527f7 /model-integration/src/test/java/ai
parent0867ac297c706bf962c2154ba2425f3a2ba2fa88 (diff)
specialize TensorFunction etc on Reference
Diffstat (limited to 'model-integration/src/test/java/ai')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java5
1 files changed, 3 insertions, 2 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
index dfc4e98d409..3ef96cdf166 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
@@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.Constant;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
@@ -703,7 +704,7 @@ public class OnnxOperationsTestCase {
return builder.build();
}
- private TensorFunction optimizeAndRename(String opName, IntermediateOperation op) {
+ private TensorFunction<Reference> optimizeAndRename(String opName, IntermediateOperation op) {
IntermediateGraph graph = new IntermediateGraph(modelName);
graph.put(opName, op);
graph.outputs(graph.defaultSignature()).put(opName, opName);
@@ -717,7 +718,7 @@ public class OnnxOperationsTestCase {
if ( ! operationType.equals(standardNamingType)) {
List<String> renameFrom = operationType.dimensionNames();
List<String> renameTo = standardNamingType.dimensionNames();
- TensorFunction func = new Rename(new ConstantTensor(tensor), renameFrom, renameTo);
+ TensorFunction<Reference> func = new Rename<>(new ConstantTensor<Reference>(tensor), renameFrom, renameTo);
return func.evaluate();
}
return tensor;