aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne H Juul <arnej@yahooinc.com>2022-01-06 18:30:08 +0000
committerArne H Juul <arnej@yahooinc.com>2022-01-07 07:17:26 +0000
commit696e624b9cc9e1f4033c7bfc05f17e2cf33430d1 (patch)
tree04607404bbd59cf3e114ee7968272868df9527f7
parent0867ac297c706bf962c2154ba2425f3a2ba2fa88 (diff)
specialize TensorFunction etc on Reference
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java13
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java1
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java7
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java11
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java9
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java7
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java17
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java17
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java7
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java19
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java11
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java9
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java9
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java15
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java15
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java19
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java7
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java11
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java7
-rw-r--r--model-integration/src/main/javacc/ModelParser.jj2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java5
41 files changed, 165 insertions, 133 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index b4b21d388b5..5627327d429 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -100,7 +101,7 @@ public abstract class ModelImporter implements MlModelImporter {
for (ImportedModel.Signature signature : model.signatures().values()) {
for (String outputName : signature.outputs().values()) {
try {
- Optional<TensorFunction> function = importExpression(graph.get(outputName), model);
+ Optional<TensorFunction<Reference>> function = importExpression(graph.get(outputName), model);
if (function.isEmpty()) {
signature.skippedOutput(outputName, "No valid output function could be found.");
}
@@ -112,7 +113,7 @@ public abstract class ModelImporter implements MlModelImporter {
}
}
- private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) {
+ private static Optional<TensorFunction<Reference>> importExpression(IntermediateOperation operation, ImportedModel model) {
if (model.expressions().containsKey(operation.name())) {
return operation.function();
}
@@ -134,7 +135,7 @@ public abstract class ModelImporter implements MlModelImporter {
operation.inputs().forEach(input -> importExpression(input, model));
}
- private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) {
+ private static Optional<TensorFunction<Reference>> importConstant(IntermediateOperation operation, ImportedModel model) {
String name = operation.vespaName();
if (model.hasLargeConstant(name) || model.hasSmallConstant(name)) {
return operation.function();
@@ -160,7 +161,7 @@ public abstract class ModelImporter implements MlModelImporter {
if (operation.function().isPresent()) {
String name = operation.name();
if ( ! model.expressions().containsKey(name)) {
- TensorFunction function = operation.function().get();
+ TensorFunction<Reference> function = operation.function().get();
if (isSignatureOutput(model, operation)) {
OrderedTensorType operationType = operation.type().get();
@@ -168,7 +169,7 @@ public abstract class ModelImporter implements MlModelImporter {
if ( ! operationType.equals(standardNamingType)) {
List<String> renameFrom = operationType.dimensionNames();
List<String> renameTo = standardNamingType.dimensionNames();
- function = new Rename(function, renameFrom, renameTo);
+ function = new Rename<Reference>(function, renameFrom, renameTo);
}
}
@@ -196,7 +197,7 @@ public abstract class ModelImporter implements MlModelImporter {
private static void importFunctionExpression(IntermediateOperation operation, ImportedModel model) {
if (operation.rankingExpressionFunction().isPresent()) {
- TensorFunction function = operation.rankingExpressionFunction().get();
+ TensorFunction<Reference> function = operation.rankingExpressionFunction().get();
try {
model.function(operation.rankingExpressionFunctionName(),
new RankingExpression(operation.rankingExpressionFunctionName(), function.toString()));
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
index 37f5ae9dd29..b77960ff3fb 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
@@ -2,6 +2,7 @@
package ai.vespa.rankingexpression.importer;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TensorTypeParser;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
index e58b5341e6b..bda2f16f9e2 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.operations;
+import com.yahoo.searchlib.rankingexpression.Reference;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import com.yahoo.tensor.evaluation.VariableTensor;
@@ -26,12 +27,12 @@ public class Argument extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
- TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type());
+ protected TensorFunction<Reference> lazyGetFunction() {
+ TensorFunction<Reference> output = new VariableTensor<Reference>(vespaName(), standardNamingType.type());
if ( ! standardNamingType.equals(type)) {
List<String> renameFrom = standardNamingType.dimensionNames();
List<String> renameTo = type.dimensionNames();
- output = new Rename(output, renameFrom, renameTo);
+ output = new Rename<Reference>(output, renameFrom, renameTo);
}
return output;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java
index bf10eb2457b..9484545c9c1 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
@@ -25,15 +26,15 @@ public class ConcatReduce extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputFunctionsPresent(inputs.size())) return null;
- TensorFunction result = inputs.get(0).function().get();
+ TensorFunction<Reference> result = inputs.get(0).function().get();
for (int i = 1; i < inputs.size(); ++i) {
- TensorFunction b = inputs.get(i).function().get();
- result = new com.yahoo.tensor.functions.Concat(result, b, tmpDimensionName);
+ TensorFunction<Reference> b = inputs.get(i).function().get();
+ result = new com.yahoo.tensor.functions.Concat<>(result, b, tmpDimensionName);
}
- return new com.yahoo.tensor.functions.Reduce(result, aggregator, tmpDimensionName);
+ return new com.yahoo.tensor.functions.Reduce<>(result, aggregator, tmpDimensionName);
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
index 9f3b15cddbd..6cb810aff94 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
@@ -68,14 +69,14 @@ public class ConcatV2 extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) {
return null;
}
- TensorFunction result = inputs.get(0).function().get();
+ TensorFunction<Reference> result = inputs.get(0).function().get();
for (int i = 1; i < inputs.size() - 1; ++i) {
- TensorFunction b = inputs.get(i).function().get();
- result = new com.yahoo.tensor.functions.Concat(result, b, concatDimensionName);
+ TensorFunction<Reference> b = inputs.get(i).function().get();
+ result = new com.yahoo.tensor.functions.Concat<>(result, b, concatDimensionName);
}
return result;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
index 859702dec40..d68b632bf61 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
@@ -35,7 +35,7 @@ public class Const extends IntermediateOperation {
}
@Override
- public Optional<TensorFunction> function() {
+ public Optional<TensorFunction<Reference>> function() {
if (function == null) {
function = lazyGetFunction();
}
@@ -43,7 +43,7 @@ public class Const extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
ExpressionNode expressionNode;
if (type.type().rank() == 0 && getConstantValue().isPresent()) {
expressionNode = new ConstantNode(getConstantValue().get().asDoubleValue());
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
index a381b2cb8a0..cdc408b3e70 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.functions.TensorFunction;
@@ -23,7 +24,7 @@ public class Constant extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
return null; // will be added by function() since this is constant.
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java
index c48e5592a56..d88fc34725e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
@@ -60,10 +61,10 @@ public class ConstantOfShape extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputTypesPresent(1)) return null;
ExpressionNode valueExpr = new ConstantNode(new DoubleValue(valueToFillWith));
- TensorFunction function = Generate.bound(type.type(), wrapScalar(valueExpr));
+ TensorFunction<Reference> function = Generate.bound(type.type(), wrapScalar(valueExpr));
return function;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java
index eda188b339f..6d57adbd888 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java
@@ -74,7 +74,7 @@ public class Expand extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if (!allInputFunctionsPresent(2)) return null;
IntermediateOperation input = inputs.get(0);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
index 027532cd02d..83132b0669c 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
@@ -4,6 +4,7 @@ package ai.vespa.rankingexpression.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
@@ -65,7 +66,7 @@ public class ExpandDims extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputFunctionsPresent(2)) return null;
// multiply with a generated tensor created from the reduced dimensions
@@ -75,9 +76,9 @@ public class ExpandDims extends IntermediateOperation {
}
TensorType generatedType = typeBuilder.build();
ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
- Generate generatedFunction = new Generate(generatedType,
+ Generate<Reference> generatedFunction = new Generate<>(generatedType,
new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
- return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply());
+ return new com.yahoo.tensor.functions.Join<>(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply());
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
index bab9c47ca9a..cd0c4da6d0f 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
@@ -71,7 +71,7 @@ public class Gather extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputFunctionsPresent(2)) return null;
IntermediateOperation data = inputs.get(0);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
index 4b3208fdeb0..1f447f2a575 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
@@ -78,7 +79,7 @@ public class Gemm extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! check2or3InputsPresent()) return null;
OrderedTensorType aType = inputs.get(0).type().get();
@@ -86,29 +87,29 @@ public class Gemm extends IntermediateOperation {
if (aType.type().rank() != 2 || bType.type().rank() != 2)
throw new IllegalArgumentException("Tensors in Gemm must have rank of exactly 2");
- Optional<TensorFunction> aFunction = inputs.get(0).function();
- Optional<TensorFunction> bFunction = inputs.get(1).function();
+ Optional<TensorFunction<Reference>> aFunction = inputs.get(0).function();
+ Optional<TensorFunction<Reference>> bFunction = inputs.get(1).function();
if (aFunction.isEmpty() || bFunction.isEmpty()) {
return null;
}
String joinDimension = aType.dimensions().get(1 - transposeA).name();
- TensorFunction AxB = new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), joinDimension);
- TensorFunction alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction(
+ TensorFunction<Reference> AxB = new com.yahoo.tensor.functions.Matmul<>(aFunction.get(), bFunction.get(), joinDimension);
+ TensorFunction<Reference> alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction(
new ArithmeticNode(
new TensorFunctionNode(AxB),
ArithmeticOperator.MULTIPLY,
new ConstantNode(new DoubleValue(alpha))));
if (inputs.size() == 3) {
- Optional<TensorFunction> cFunction = inputs.get(2).function();
- TensorFunction betaxC = new TensorFunctionNode.ExpressionTensorFunction(
+ Optional<TensorFunction<Reference>> cFunction = inputs.get(2).function();
+ TensorFunction<Reference> betaxC = new TensorFunctionNode.ExpressionTensorFunction(
new ArithmeticNode(
new TensorFunctionNode(cFunction.get()),
ArithmeticOperator.MULTIPLY,
new ConstantNode(new DoubleValue(beta))));
- return new com.yahoo.tensor.functions.Join(alphaxAxB, betaxC, ScalarFunctions.add());
+ return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add());
}
return alphaxAxB;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
index f096cb1e54f..ab840e708a7 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
@@ -2,6 +2,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
@@ -20,7 +21,7 @@ public class Identity extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if (!allInputFunctionsPresent(1))
return null;
return inputs.get(0).function().orElse(null);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 6ebb478715a..6378442c6d0 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -45,8 +45,8 @@ public abstract class IntermediateOperation {
protected final List<IntermediateOperation> outputs = new ArrayList<>();
protected OrderedTensorType type;
- protected TensorFunction function;
- protected TensorFunction rankingExpressionFunction = null;
+ protected TensorFunction<Reference> function;
+ protected TensorFunction<Reference> rankingExpressionFunction = null;
protected boolean exportAsRankingFunction = false;
private boolean hasRenamedDimensions = false;
@@ -65,7 +65,7 @@ public abstract class IntermediateOperation {
}
protected abstract OrderedTensorType lazyGetType();
- protected abstract TensorFunction lazyGetFunction();
+ protected abstract TensorFunction<Reference> lazyGetFunction();
public String modelName() { return modelName; }
@@ -78,14 +78,14 @@ public abstract class IntermediateOperation {
}
/** Returns the Vespa tensor function implementing all operations from this node with inputs */
- public Optional<TensorFunction> function() {
+ public Optional<TensorFunction<Reference>> function() {
if (function == null) {
if (isConstant()) {
ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName()));
function = new TensorFunctionNode.ExpressionTensorFunction(constant);
} else if (outputs.size() > 1 || exportAsRankingFunction) {
rankingExpressionFunction = lazyGetFunction();
- function = new VariableTensor(rankingExpressionFunctionName(), type.type());
+ function = new VariableTensor<Reference>(rankingExpressionFunctionName(), type.type());
} else {
function = lazyGetFunction();
}
@@ -103,7 +103,7 @@ public abstract class IntermediateOperation {
public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); }
/** Returns a function that should be added as a ranking expression function */
- public Optional<TensorFunction> rankingExpressionFunction() {
+ public Optional<TensorFunction<Reference>> rankingExpressionFunction() {
return Optional.ofNullable(rankingExpressionFunction);
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
index 92b5f2e743b..667641dc33a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
@@ -53,7 +54,7 @@ public class Join extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputTypesPresent(2)) return null;
if ( ! allInputFunctionsPresent(2)) return null;
@@ -63,7 +64,7 @@ public class Join extends IntermediateOperation {
if (mapOperator.isPresent()) {
IntermediateOperation input = inputs.get(0);
input.removeDuplicateOutputsTo(this); // avoids unnecessary function export
- return new com.yahoo.tensor.functions.Map(input.function().get(), mapOperator.get());
+ return new com.yahoo.tensor.functions.Map<Reference>(input.function().get(), mapOperator.get());
}
}
@@ -86,23 +87,23 @@ public class Join extends IntermediateOperation {
}
}
- TensorFunction aReducedFunction = a.function().get();
+ TensorFunction<Reference> aReducedFunction = a.function().get();
if (aDimensionsToReduce.size() > 0) {
- aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce);
+ aReducedFunction = new Reduce<Reference>(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce);
}
- TensorFunction bReducedFunction = b.function().get();
+ TensorFunction<Reference> bReducedFunction = b.function().get();
if (bDimensionsToReduce.size() > 0) {
- bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce);
+ bReducedFunction = new Reduce<Reference>(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce);
}
// retain order of inputs
if (a == inputs.get(1)) {
- TensorFunction temp = bReducedFunction;
+ TensorFunction<Reference> temp = bReducedFunction;
bReducedFunction = aReducedFunction;
aReducedFunction = temp;
}
- return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator);
+ return new com.yahoo.tensor.functions.Join<Reference>(aReducedFunction, bReducedFunction, operator);
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
index 1fd0f72f416..c9b03ba9b85 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
@@ -2,6 +2,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
@@ -26,12 +27,12 @@ public class Map extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if (!allInputFunctionsPresent(1)) {
return null;
}
- Optional<TensorFunction> input = inputs.get(0).function();
- return new com.yahoo.tensor.functions.Map(input.get(), operator);
+ Optional<TensorFunction<Reference>> input = inputs.get(0).function();
+ return new com.yahoo.tensor.functions.Map<>(input.get(), operator);
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 673df9be36b..7d64a023e27 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
@@ -58,24 +59,24 @@ public class MatMul extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputTypesPresent(2)) return null;
if ( ! allInputFunctionsPresent(2)) return null;
OrderedTensorType typeA = inputs.get(0).type().get();
OrderedTensorType typeB = inputs.get(1).type().get();
- TensorFunction functionA = handleBroadcasting(inputs.get(0).function().get(), typeA, typeB);
- TensorFunction functionB = handleBroadcasting(inputs.get(1).function().get(), typeB, typeA);
+ TensorFunction<Reference> functionA = handleBroadcasting(inputs.get(0).function().get(), typeA, typeB);
+ TensorFunction<Reference> functionB = handleBroadcasting(inputs.get(1).function().get(), typeB, typeA);
- return new com.yahoo.tensor.functions.Reduce(
- new Join(functionA, functionB, ScalarFunctions.multiply()),
+ return new com.yahoo.tensor.functions.Reduce<Reference>(
+ new Join<Reference>(functionA, functionB, ScalarFunctions.multiply()),
Reduce.Aggregator.sum,
typeA.dimensions().get(typeA.rank() - 1).name());
}
- private TensorFunction handleBroadcasting(TensorFunction tensorFunction, OrderedTensorType typeA, OrderedTensorType typeB) {
- List<Slice.DimensionValue> slices = new ArrayList<>();
+ private TensorFunction<Reference> handleBroadcasting(TensorFunction<Reference> tensorFunction, OrderedTensorType typeA, OrderedTensorType typeB) {
+ List<Slice.DimensionValue<Reference>> slices = new ArrayList<>();
for (int i = 0; i < typeA.rank() - 2; ++i) {
long dimSizeA = typeA.dimensions().get(i).size().get();
String dimNameA = typeA.dimensionNames().get(i);
@@ -84,11 +85,11 @@ public class MatMul extends IntermediateOperation {
long dimSizeB = typeB.dimensions().get(j).size().get();
if (dimSizeB > dimSizeA && dimSizeA == 1) {
ExpressionNode dimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero));
- slices.add(new Slice.DimensionValue(Optional.of(dimNameA), wrapScalar(dimensionExpression)));
+ slices.add(new Slice.DimensionValue<>(Optional.of(dimNameA), wrapScalar(dimensionExpression)));
}
}
}
- return slices.size() == 0 ? tensorFunction : new Slice(tensorFunction, slices);
+ return slices.size() == 0 ? tensorFunction : new Slice<>(tensorFunction, slices);
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
index a4a47ca8ce7..fd262b2892c 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
@@ -2,6 +2,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
@@ -56,12 +57,12 @@ public class Mean extends IntermediateOperation {
// optimization: if keepDims and one reduce dimension that has size 1: same as identity.
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputTypesPresent(2)) return null;
- TensorFunction inputFunction = inputs.get(0).function().get();
- TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions);
+ TensorFunction<Reference> inputFunction = inputs.get(0).function().get();
+ TensorFunction<Reference> output = new Reduce<>(inputFunction, Reduce.Aggregator.avg, reduceDimensions);
if (shouldKeepDimensions()) {
// multiply with a generated tensor created from the reduced dimensions
TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
@@ -70,9 +71,9 @@ public class Mean extends IntermediateOperation {
}
TensorType generatedType = typeBuilder.build();
ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
- Generate generatedFunction = new Generate(generatedType,
+ Generate<Reference> generatedFunction = new Generate<>(generatedType,
new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
- output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
+ output = new com.yahoo.tensor.functions.Join<>(output, generatedFunction, ScalarFunctions.multiply());
}
return output;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java
index f208cc97d4f..e2b5930f114 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java
@@ -2,6 +2,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
@@ -23,7 +24,7 @@ public class Merge extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
for (IntermediateOperation operation : inputs) {
if (operation.function().isPresent()) {
return operation.function().get();
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java
index 1d76fa3f0a7..d8055d548ad 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java
@@ -2,6 +2,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.Collections;
@@ -19,7 +20,7 @@ public class NoOp extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
return null;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java
index 7b0547be7d2..164e3dc5e11 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java
@@ -2,6 +2,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.functions.TensorFunction;
import onnx.Onnx.TensorProto.DataType;
@@ -30,13 +31,13 @@ public class OnnxCast extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputFunctionsPresent(1))
return null;
- TensorFunction input = inputs.get(0).function().get();
+ TensorFunction<Reference> input = inputs.get(0).function().get();
switch (toType) {
case BOOL:
- return new com.yahoo.tensor.functions.Map(input, new AsBool());
+ return new com.yahoo.tensor.functions.Map<>(input, new AsBool());
case INT8:
case INT16:
case INT32:
@@ -45,7 +46,7 @@ public class OnnxCast extends IntermediateOperation {
case UINT16:
case UINT32:
case UINT64:
- return new com.yahoo.tensor.functions.Map(input, new AsInt());
+ return new com.yahoo.tensor.functions.Map<>(input, new AsInt());
case FLOAT:
case DOUBLE:
case FLOAT16:
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java
index 2be8fc0dc4e..97818f4c27d 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
@@ -65,14 +66,14 @@ public class OnnxConcat extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) {
return null;
}
- TensorFunction result = inputs.get(0).function().get();
+ TensorFunction<Reference> result = inputs.get(0).function().get();
for (int i = 1; i < inputs.size(); ++i) {
- TensorFunction b = inputs.get(i).function().get();
- result = new com.yahoo.tensor.functions.Concat(result, b, concatDimensionName);
+ TensorFunction<Reference> b = inputs.get(i).function().get();
+ result = new com.yahoo.tensor.functions.Concat<>(result, b, concatDimensionName);
}
return result;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java
index 79123cb0380..675e18da637 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -36,7 +37,7 @@ public class OnnxConstant extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
return null; // will be added by function() since this is constant.
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java
index 3456a24f5dd..c0f825f9092 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java
@@ -2,6 +2,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
@@ -22,7 +23,7 @@ public class PlaceholderWithDefault extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if (!allInputFunctionsPresent(1)) {
return null;
}
@@ -32,7 +33,7 @@ public class PlaceholderWithDefault extends IntermediateOperation {
}
@Override
- public Optional<TensorFunction> rankingExpressionFunction() {
+ public Optional<TensorFunction<Reference>> rankingExpressionFunction() {
// For now, it is much more efficient to assume we always will return
// the default value, as we can prune away large parts of the expression
// tree by having it calculated as a constant. If a case arises where
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
index 81a9e4996b4..5c4e8cd6cd0 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
@@ -58,7 +59,7 @@ public class Range extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputTypesPresent(3)) return null;
String dimensionName = type().get().dimensionNames().get(0);
ExpressionNode startExpr = new ConstantNode(new DoubleValue(start));
@@ -66,7 +67,7 @@ public class Range extends IntermediateOperation {
ExpressionNode dimExpr = new EmbracedNode(new ReferenceNode(dimensionName));
ExpressionNode stepExpr = new ArithmeticNode(deltaExpr, ArithmeticOperator.MULTIPLY, dimExpr);
ExpressionNode addExpr = new ArithmeticNode(startExpr, ArithmeticOperator.PLUS, stepExpr);
- TensorFunction function = Generate.bound(type.type(), wrapScalar(addExpr));
+ TensorFunction<Reference> function = Generate.bound(type.type(), wrapScalar(addExpr));
return function;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
index 8e49ce15265..b7a8a4a4e43 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
@@ -72,14 +73,14 @@ public class Reduce extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputTypesPresent(1)) return null;
- TensorFunction inputFunction = inputs.get(0).function().get();
+ TensorFunction<Reference> inputFunction = inputs.get(0).function().get();
if (preOperator != null) {
- inputFunction = new com.yahoo.tensor.functions.Map(inputFunction, preOperator);
+ inputFunction = new com.yahoo.tensor.functions.Map<>(inputFunction, preOperator);
}
- TensorFunction output = new com.yahoo.tensor.functions.Reduce(inputFunction, aggregator, reduceDimensions);
+ TensorFunction<Reference> output = new com.yahoo.tensor.functions.Reduce<>(inputFunction, aggregator, reduceDimensions);
if (shouldKeepDimensions()) {
// multiply with a generated tensor created from the reduced dimensions
TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
@@ -88,12 +89,12 @@ public class Reduce extends IntermediateOperation {
}
TensorType generatedType = typeBuilder.build();
ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
- Generate generatedFunction = new Generate(generatedType,
+ Generate<Reference> generatedFunction = new Generate<>(generatedType,
new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
- output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
+ output = new com.yahoo.tensor.functions.Join<>(output, generatedFunction, ScalarFunctions.multiply());
}
if (postOperator != null) {
- output = new com.yahoo.tensor.functions.Map(output, postOperator);
+ output = new com.yahoo.tensor.functions.Map<>(output, postOperator);
}
return output;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
index 724e49084ee..d80058dfa07 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
@@ -43,9 +44,9 @@ public class Rename extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputFunctionsPresent(1)) return null;
- return new com.yahoo.tensor.functions.Rename(inputs.get(0).function().orElse(null), from, to);
+ return new com.yahoo.tensor.functions.Rename<>(inputs.get(0).function().orElse(null), from, to);
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index 57a43158c0d..7b675fa79af 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -110,12 +110,12 @@ public class Reshape extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent) ) return null;
if ( ! inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent) ) return null;
OrderedTensorType inputType = inputs.get(0).type().get();
- TensorFunction inputFunction = inputs.get(0).function().get();
+ TensorFunction<Reference> inputFunction = inputs.get(0).function().get();
return reshape(inputFunction, inputType, type);
}
@@ -129,7 +129,7 @@ public class Reshape extends IntermediateOperation {
return new Reshape(modelName(), name(), inputs, attributeMap);
}
- public TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) {
+ public TensorFunction<Reference> reshape(TensorFunction<Reference> inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) {
if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type())))
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
index a189ff9c07c..9836217866b 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
@@ -34,13 +35,13 @@ public class Select extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if (!allInputFunctionsPresent(3)) {
return null;
}
IntermediateOperation conditionOperation = inputs().get(0);
- TensorFunction a = inputs().get(1).function().get();
- TensorFunction b = inputs().get(2).function().get();
+ TensorFunction<Reference> a = inputs().get(1).function().get();
+ TensorFunction<Reference> b = inputs().get(2).function().get();
// Shortcut: if we know during import which tensor to select, do that directly here.
if (conditionOperation.getConstantValue().isPresent()) {
@@ -61,13 +62,13 @@ public class Select extends IntermediateOperation {
// from 'x'. We do this by individually joining 'x' and 'y' with
// 'condition', and then joining the resulting two tensors.
- TensorFunction conditionFunction = conditionOperation.function().get();
- TensorFunction aCond = new com.yahoo.tensor.functions.Join(a, conditionFunction, ScalarFunctions.multiply());
- TensorFunction bCond = new com.yahoo.tensor.functions.Join(b, conditionFunction, new DoubleBinaryOperator() {
+ TensorFunction<Reference> conditionFunction = conditionOperation.function().get();
+ TensorFunction<Reference> aCond = new com.yahoo.tensor.functions.Join<>(a, conditionFunction, ScalarFunctions.multiply());
+ TensorFunction<Reference> bCond = new com.yahoo.tensor.functions.Join<>(b, conditionFunction, new DoubleBinaryOperator() {
@Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); }
@Override public String toString() { return "f(a,b)(a * (1-b))"; }
});
- return new com.yahoo.tensor.functions.Join(aCond, bCond, ScalarFunctions.add());
+ return new com.yahoo.tensor.functions.Join<>(aCond, bCond, ScalarFunctions.add());
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
index 28e0115810a..c1cffd4243e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
@@ -2,6 +2,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
@@ -28,7 +29,7 @@ public class Shape extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
return null; // will be added by function() since this is constant.
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
index ac5d66e22c1..91b7064b19c 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
@@ -143,7 +143,7 @@ public class Slice extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if (inputs.size() < 1 || inputs.get(0).function().isEmpty()) {
return null;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
index 6001bef87ed..d7060b9d440 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
@@ -2,6 +2,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Map;
import com.yahoo.tensor.functions.Reduce;
@@ -34,12 +35,12 @@ public class Softmax extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputFunctionsPresent(1)) return null;
List<String> reduceDimensions = reduceDimensions();
- TensorFunction input = inputs.get(0).function().get();
- TensorFunction sum = new Reduce(input, Reduce.Aggregator.sum, reduceDimensions);
- TensorFunction div = new Join(input, sum, ScalarFunctions.divide());
+ TensorFunction<Reference> input = inputs.get(0).function().get();
+ TensorFunction<Reference> sum = new Reduce<>(input, Reduce.Aggregator.sum, reduceDimensions);
+ TensorFunction<Reference> div = new Join<>(input, sum, ScalarFunctions.divide());
return div;
}
@@ -93,13 +94,13 @@ public class Softmax extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputFunctionsPresent(1)) return null;
List<String> reduceDimensions = reduceDimensions();
- TensorFunction input = inputs.get(0).function().get();
- TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions);
- TensorFunction cap = new Join(input, max, ScalarFunctions.subtract()); // to avoid overflow
- TensorFunction exp = new Map(cap, ScalarFunctions.exp());
+ TensorFunction<Reference> input = inputs.get(0).function().get();
+ TensorFunction<Reference> max = new Reduce<>(input, Reduce.Aggregator.max, reduceDimensions);
+ TensorFunction<Reference> cap = new Join<>(input, max, ScalarFunctions.subtract()); // to avoid overflow
+ TensorFunction<Reference> exp = new Map<>(cap, ScalarFunctions.exp());
return exp;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
index 2e586b38c71..6f720716adb 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
@@ -84,7 +84,7 @@ public class Split extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if (!allInputFunctionsPresent(1)) return null;
IntermediateOperation input = inputs.get(0);
@@ -104,7 +104,7 @@ public class Split extends IntermediateOperation {
com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues);
ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
- TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression));
+ TensorFunction<Reference> generate = Generate.bound(type.type(), wrapScalar(sliceExpression));
return generate;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
index 07110b9b966..9229d6af254 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
@@ -52,11 +53,11 @@ public class Squeeze extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputFunctionsPresent(1)) return null;
- TensorFunction inputFunction = inputs.get(0).function().get();
- return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions);
+ TensorFunction<Reference> inputFunction = inputs.get(0).function().get();
+ return new Reduce<>(inputFunction, Reduce.Aggregator.sum, squeezeDimensions);
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
index b8ca114343d..902144cfea2 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
@@ -56,11 +57,11 @@ public class Sum extends IntermediateOperation {
// optimization: if keepDims and one reduce dimension that has size 1: same as identity.
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputTypesPresent(2)) return null;
- TensorFunction inputFunction = inputs.get(0).function().get();
- TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.sum, reduceDimensions);
+ TensorFunction<Reference> inputFunction = inputs.get(0).function().get();
+ TensorFunction<Reference> output = new Reduce<>(inputFunction, Reduce.Aggregator.sum, reduceDimensions);
if (shouldKeepDimensions()) {
// multiply with a generated tensor created from the reduced dimensions
TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
@@ -69,9 +70,9 @@ public class Sum extends IntermediateOperation {
}
TensorType generatedType = typeBuilder.build();
ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
- Generate generatedFunction = new Generate(generatedType,
+ Generate<Reference> generatedFunction = new Generate<>(generatedType,
new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
- output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
+ output = new com.yahoo.tensor.functions.Join<>(output, generatedFunction, ScalarFunctions.multiply());
}
return output;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java
index f41140075d1..502f0769350 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java
@@ -2,6 +2,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
@@ -29,7 +30,7 @@ public class Switch extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
IntermediateOperation predicateOperation = inputs().get(1);
if (!predicateOperation.getConstantValue().isPresent()) {
throw new IllegalArgumentException("Switch in " + name + ": predicate must be a constant");
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
index 7fe5e831391..4bfab284cc2 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
@@ -62,7 +62,7 @@ public class Tile extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if (!allInputFunctionsPresent(2)) return null;
IntermediateOperation input = inputs.get(0);
@@ -85,7 +85,7 @@ public class Tile extends IntermediateOperation {
com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues);
ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
- TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression));
+ TensorFunction<Reference> generate = Generate.bound(type.type(), wrapScalar(sliceExpression));
return generate;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java
index add24e665e6..ef51b11884a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java
@@ -2,6 +2,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
@@ -36,7 +37,7 @@ public class Transpose extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if (!allInputFunctionsPresent(1))
return null;
return inputs.get(0).function().orElse(null);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java
index bd3130a7cd1..a73b5a4c6ef 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -64,7 +65,7 @@ public class Unsqueeze extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputFunctionsPresent(1)) return null;
// multiply with a generated tensor created from the expanded dimensions
@@ -74,9 +75,9 @@ public class Unsqueeze extends IntermediateOperation {
}
TensorType generatedType = typeBuilder.build();
ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
- Generate generatedFunction = new Generate(generatedType,
+ Generate<Reference> generatedFunction = new Generate<>(generatedType,
new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
- return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply());
+ return new com.yahoo.tensor.functions.Join<>(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply());
}
@Override
diff --git a/model-integration/src/main/javacc/ModelParser.jj b/model-integration/src/main/javacc/ModelParser.jj
index 9944b88a745..6f6f3508beb 100644
--- a/model-integration/src/main/javacc/ModelParser.jj
+++ b/model-integration/src/main/javacc/ModelParser.jj
@@ -170,7 +170,7 @@ void input() :
void function() :
{
String name, expression, parameter;
- List parameters = new ArrayList();
+ List< String > parameters = new ArrayList< String >();
}
{
( <FUNCTION> name = identifier()
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
index dfc4e98d409..3ef96cdf166 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
@@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.Constant;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
@@ -703,7 +704,7 @@ public class OnnxOperationsTestCase {
return builder.build();
}
- private TensorFunction optimizeAndRename(String opName, IntermediateOperation op) {
+ private TensorFunction<Reference> optimizeAndRename(String opName, IntermediateOperation op) {
IntermediateGraph graph = new IntermediateGraph(modelName);
graph.put(opName, op);
graph.outputs(graph.defaultSignature()).put(opName, opName);
@@ -717,7 +718,7 @@ public class OnnxOperationsTestCase {
if ( ! operationType.equals(standardNamingType)) {
List<String> renameFrom = operationType.dimensionNames();
List<String> renameTo = standardNamingType.dimensionNames();
- TensorFunction func = new Rename(new ConstantTensor(tensor), renameFrom, renameTo);
+ TensorFunction<Reference> func = new Rename<>(new ConstantTensor<Reference>(tensor), renameFrom, renameTo);
return func.evaluate();
}
return tensor;