diff options
author | Arne H Juul <arnej@yahooinc.com> | 2022-01-06 18:30:08 +0000 |
---|---|---|
committer | Arne H Juul <arnej@yahooinc.com> | 2022-01-07 07:17:26 +0000 |
commit | 696e624b9cc9e1f4033c7bfc05f17e2cf33430d1 (patch) | |
tree | 04607404bbd59cf3e114ee7968272868df9527f7 | |
parent | 0867ac297c706bf962c2154ba2425f3a2ba2fa88 (diff) |
specialize TensorFunction etc on Reference
41 files changed, 165 insertions, 133 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index b4b21d388b5..5627327d429 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -100,7 +101,7 @@ public abstract class ModelImporter implements MlModelImporter { for (ImportedModel.Signature signature : model.signatures().values()) { for (String outputName : signature.outputs().values()) { try { - Optional<TensorFunction> function = importExpression(graph.get(outputName), model); + Optional<TensorFunction<Reference>> function = importExpression(graph.get(outputName), model); if (function.isEmpty()) { signature.skippedOutput(outputName, "No valid output function could be found."); } @@ -112,7 +113,7 @@ public abstract class ModelImporter implements MlModelImporter { } } - private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) { + private static Optional<TensorFunction<Reference>> importExpression(IntermediateOperation operation, ImportedModel model) { if (model.expressions().containsKey(operation.name())) { return operation.function(); } @@ -134,7 +135,7 @@ public abstract class ModelImporter implements MlModelImporter { operation.inputs().forEach(input -> importExpression(input, model)); } - private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) { + private static Optional<TensorFunction<Reference>> importConstant(IntermediateOperation operation, ImportedModel model) { String name = operation.vespaName(); if (model.hasLargeConstant(name) || model.hasSmallConstant(name)) { return operation.function(); @@ -160,7 +161,7 @@ public abstract class ModelImporter implements MlModelImporter { if (operation.function().isPresent()) { String name = operation.name(); if ( ! model.expressions().containsKey(name)) { - TensorFunction function = operation.function().get(); + TensorFunction<Reference> function = operation.function().get(); if (isSignatureOutput(model, operation)) { OrderedTensorType operationType = operation.type().get(); @@ -168,7 +169,7 @@ public abstract class ModelImporter implements MlModelImporter { if ( ! operationType.equals(standardNamingType)) { List<String> renameFrom = operationType.dimensionNames(); List<String> renameTo = standardNamingType.dimensionNames(); - function = new Rename(function, renameFrom, renameTo); + function = new Rename<Reference>(function, renameFrom, renameTo); } } @@ -196,7 +197,7 @@ public abstract class ModelImporter implements MlModelImporter { private static void importFunctionExpression(IntermediateOperation operation, ImportedModel model) { if (operation.rankingExpressionFunction().isPresent()) { - TensorFunction function = operation.rankingExpressionFunction().get(); + TensorFunction<Reference> function = operation.rankingExpressionFunction().get(); try { model.function(operation.rankingExpressionFunctionName(), new RankingExpression(operation.rankingExpressionFunctionName(), function.toString())); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java index 37f5ae9dd29..b77960ff3fb 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.TensorTypeParser; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java index e58b5341e6b..bda2f16f9e2 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer.operations; +import com.yahoo.searchlib.rankingexpression.Reference; import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.DimensionRenamer; import com.yahoo.tensor.evaluation.VariableTensor; @@ -26,12 +27,12 @@ public class Argument extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { - TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type()); + protected TensorFunction<Reference> lazyGetFunction() { + TensorFunction<Reference> output = new VariableTensor<Reference>(vespaName(), standardNamingType.type()); if ( ! standardNamingType.equals(type)) { List<String> renameFrom = standardNamingType.dimensionNames(); List<String> renameTo = type.dimensionNames(); - output = new Rename(output, renameFrom, renameTo); + output = new Rename<Reference>(output, renameFrom, renameTo); } return output; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java index bf10eb2457b..9484545c9c1 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.TensorFunction; @@ -25,15 +26,15 @@ public class ConcatReduce extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(inputs.size())) return null; - TensorFunction result = inputs.get(0).function().get(); + TensorFunction<Reference> result = inputs.get(0).function().get(); for (int i = 1; i < inputs.size(); ++i) { - TensorFunction b = inputs.get(i).function().get(); - result = new com.yahoo.tensor.functions.Concat(result, b, tmpDimensionName); + TensorFunction<Reference> b = inputs.get(i).function().get(); + result = new com.yahoo.tensor.functions.Concat<>(result, b, tmpDimensionName); } - return new com.yahoo.tensor.functions.Reduce(result, aggregator, tmpDimensionName); + return new com.yahoo.tensor.functions.Reduce<>(result, aggregator, tmpDimensionName); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java index 9f3b15cddbd..6cb810aff94 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; @@ -68,14 +69,14 @@ public class ConcatV2 extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) { return null; } - TensorFunction result = inputs.get(0).function().get(); + TensorFunction<Reference> result = inputs.get(0).function().get(); for (int i = 1; i < inputs.size() - 1; ++i) { - TensorFunction b = inputs.get(i).function().get(); - result = new com.yahoo.tensor.functions.Concat(result, b, concatDimensionName); + TensorFunction<Reference> b = inputs.get(i).function().get(); + result = new com.yahoo.tensor.functions.Concat<>(result, b, concatDimensionName); } return result; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java index 859702dec40..d68b632bf61 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java @@ -35,7 +35,7 @@ public class Const extends IntermediateOperation { } @Override - public Optional<TensorFunction> function() { + public Optional<TensorFunction<Reference>> function() { if (function == null) { function = lazyGetFunction(); } @@ -43,7 +43,7 @@ public class Const extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { ExpressionNode expressionNode; if (type.type().rank() == 0 && getConstantValue().isPresent()) { expressionNode = new ConstantNode(getConstantValue().get().asDoubleValue()); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java index a381b2cb8a0..cdc408b3e70 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.functions.TensorFunction; @@ -23,7 +24,7 @@ public class Constant extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { return null; // will be added by function() since this is constant. } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java index c48e5592a56..d88fc34725e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; @@ -60,10 +61,10 @@ public class ConstantOfShape extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputTypesPresent(1)) return null; ExpressionNode valueExpr = new ConstantNode(new DoubleValue(valueToFillWith)); - TensorFunction function = Generate.bound(type.type(), wrapScalar(valueExpr)); + TensorFunction<Reference> function = Generate.bound(type.type(), wrapScalar(valueExpr)); return function; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java index eda188b339f..6d57adbd888 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java @@ -74,7 +74,7 @@ public class Expand extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!allInputFunctionsPresent(2)) return null; IntermediateOperation input = inputs.get(0); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index 027532cd02d..83132b0669c 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -4,6 +4,7 @@ package ai.vespa.rankingexpression.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; @@ -65,7 +66,7 @@ public class ExpandDims extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(2)) return null; // multiply with a generated tensor created from the reduced dimensions @@ -75,9 +76,9 @@ public class ExpandDims extends IntermediateOperation { } TensorType generatedType = typeBuilder.build(); ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); - Generate generatedFunction = new Generate(generatedType, + Generate<Reference> generatedFunction = new Generate<>(generatedType, new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); - return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply()); + return new com.yahoo.tensor.functions.Join<>(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply()); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java index bab9c47ca9a..cd0c4da6d0f 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java @@ -71,7 +71,7 @@ public class Gather extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(2)) return null; IntermediateOperation data = inputs.get(0); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java index 4b3208fdeb0..1f447f2a575 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; @@ -78,7 +79,7 @@ public class Gemm extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! check2or3InputsPresent()) return null; OrderedTensorType aType = inputs.get(0).type().get(); @@ -86,29 +87,29 @@ public class Gemm extends IntermediateOperation { if (aType.type().rank() != 2 || bType.type().rank() != 2) throw new IllegalArgumentException("Tensors in Gemm must have rank of exactly 2"); - Optional<TensorFunction> aFunction = inputs.get(0).function(); - Optional<TensorFunction> bFunction = inputs.get(1).function(); + Optional<TensorFunction<Reference>> aFunction = inputs.get(0).function(); + Optional<TensorFunction<Reference>> bFunction = inputs.get(1).function(); if (aFunction.isEmpty() || bFunction.isEmpty()) { return null; } String joinDimension = aType.dimensions().get(1 - transposeA).name(); - TensorFunction AxB = new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), joinDimension); - TensorFunction alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( + TensorFunction<Reference> AxB = new com.yahoo.tensor.functions.Matmul<>(aFunction.get(), bFunction.get(), joinDimension); + TensorFunction<Reference> alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( new ArithmeticNode( new TensorFunctionNode(AxB), ArithmeticOperator.MULTIPLY, new ConstantNode(new DoubleValue(alpha)))); if (inputs.size() == 3) { - Optional<TensorFunction> cFunction = inputs.get(2).function(); - TensorFunction betaxC = new TensorFunctionNode.ExpressionTensorFunction( + Optional<TensorFunction<Reference>> cFunction = inputs.get(2).function(); + TensorFunction<Reference> betaxC = new TensorFunctionNode.ExpressionTensorFunction( new ArithmeticNode( new TensorFunctionNode(cFunction.get()), ArithmeticOperator.MULTIPLY, new ConstantNode(new DoubleValue(beta)))); - return new com.yahoo.tensor.functions.Join(alphaxAxB, betaxC, ScalarFunctions.add()); + return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add()); } return alphaxAxB; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java index f096cb1e54f..ab840e708a7 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; @@ -20,7 +21,7 @@ public class Identity extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!allInputFunctionsPresent(1)) return null; return inputs.get(0).function().orElse(null); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 6ebb478715a..6378442c6d0 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -45,8 +45,8 @@ public abstract class IntermediateOperation { protected final List<IntermediateOperation> outputs = new ArrayList<>(); protected OrderedTensorType type; - protected TensorFunction function; - protected TensorFunction rankingExpressionFunction = null; + protected TensorFunction<Reference> function; + protected TensorFunction<Reference> rankingExpressionFunction = null; protected boolean exportAsRankingFunction = false; private boolean hasRenamedDimensions = false; @@ -65,7 +65,7 @@ public abstract class IntermediateOperation { } protected abstract OrderedTensorType lazyGetType(); - protected abstract TensorFunction lazyGetFunction(); + protected abstract TensorFunction<Reference> lazyGetFunction(); public String modelName() { return modelName; } @@ -78,14 +78,14 @@ public abstract class IntermediateOperation { } /** Returns the Vespa tensor function implementing all operations from this node with inputs */ - public Optional<TensorFunction> function() { + public Optional<TensorFunction<Reference>> function() { if (function == null) { if (isConstant()) { ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName())); function = new TensorFunctionNode.ExpressionTensorFunction(constant); } else if (outputs.size() > 1 || exportAsRankingFunction) { rankingExpressionFunction = lazyGetFunction(); - function = new VariableTensor(rankingExpressionFunctionName(), type.type()); + function = new VariableTensor<Reference>(rankingExpressionFunctionName(), type.type()); } else { function = lazyGetFunction(); } @@ -103,7 +103,7 @@ public abstract class IntermediateOperation { public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); } /** Returns a function that should be added as a ranking expression function */ - public Optional<TensorFunction> rankingExpressionFunction() { + public Optional<TensorFunction<Reference>> rankingExpressionFunction() { return Optional.ofNullable(rankingExpressionFunction); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index 92b5f2e743b..667641dc33a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.ScalarFunctions; @@ -53,7 +54,7 @@ public class Join extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputTypesPresent(2)) return null; if ( ! allInputFunctionsPresent(2)) return null; @@ -63,7 +64,7 @@ public class Join extends IntermediateOperation { if (mapOperator.isPresent()) { IntermediateOperation input = inputs.get(0); input.removeDuplicateOutputsTo(this); // avoids unnecessary function export - return new com.yahoo.tensor.functions.Map(input.function().get(), mapOperator.get()); + return new com.yahoo.tensor.functions.Map<Reference>(input.function().get(), mapOperator.get()); } } @@ -86,23 +87,23 @@ public class Join extends IntermediateOperation { } } - TensorFunction aReducedFunction = a.function().get(); + TensorFunction<Reference> aReducedFunction = a.function().get(); if (aDimensionsToReduce.size() > 0) { - aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce); + aReducedFunction = new Reduce<Reference>(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce); } - TensorFunction bReducedFunction = b.function().get(); + TensorFunction<Reference> bReducedFunction = b.function().get(); if (bDimensionsToReduce.size() > 0) { - bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce); + bReducedFunction = new Reduce<Reference>(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce); } // retain order of inputs if (a == inputs.get(1)) { - TensorFunction temp = bReducedFunction; + TensorFunction<Reference> temp = bReducedFunction; bReducedFunction = aReducedFunction; aReducedFunction = temp; } - return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator); + return new com.yahoo.tensor.functions.Join<Reference>(aReducedFunction, bReducedFunction, operator); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java index 1fd0f72f416..c9b03ba9b85 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; @@ -26,12 +27,12 @@ public class Map extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!allInputFunctionsPresent(1)) { return null; } - Optional<TensorFunction> input = inputs.get(0).function(); - return new com.yahoo.tensor.functions.Map(input.get(), operator); + Optional<TensorFunction<Reference>> input = inputs.get(0).function(); + return new com.yahoo.tensor.functions.Map<>(input.get(), operator); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 673df9be36b..7d64a023e27 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; @@ -58,24 +59,24 @@ public class MatMul extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputTypesPresent(2)) return null; if ( ! allInputFunctionsPresent(2)) return null; OrderedTensorType typeA = inputs.get(0).type().get(); OrderedTensorType typeB = inputs.get(1).type().get(); - TensorFunction functionA = handleBroadcasting(inputs.get(0).function().get(), typeA, typeB); - TensorFunction functionB = handleBroadcasting(inputs.get(1).function().get(), typeB, typeA); + TensorFunction<Reference> functionA = handleBroadcasting(inputs.get(0).function().get(), typeA, typeB); + TensorFunction<Reference> functionB = handleBroadcasting(inputs.get(1).function().get(), typeB, typeA); - return new com.yahoo.tensor.functions.Reduce( - new Join(functionA, functionB, ScalarFunctions.multiply()), + return new com.yahoo.tensor.functions.Reduce<Reference>( + new Join<Reference>(functionA, functionB, ScalarFunctions.multiply()), Reduce.Aggregator.sum, typeA.dimensions().get(typeA.rank() - 1).name()); } - private TensorFunction handleBroadcasting(TensorFunction tensorFunction, OrderedTensorType typeA, OrderedTensorType typeB) { - List<Slice.DimensionValue> slices = new ArrayList<>(); + private TensorFunction<Reference> handleBroadcasting(TensorFunction<Reference> tensorFunction, OrderedTensorType typeA, OrderedTensorType typeB) { + List<Slice.DimensionValue<Reference>> slices = new ArrayList<>(); for (int i = 0; i < typeA.rank() - 2; ++i) { long dimSizeA = typeA.dimensions().get(i).size().get(); String dimNameA = typeA.dimensionNames().get(i); @@ -84,11 +85,11 @@ public class MatMul extends IntermediateOperation { long dimSizeB = typeB.dimensions().get(j).size().get(); if (dimSizeB > dimSizeA && dimSizeA == 1) { ExpressionNode dimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero)); - slices.add(new Slice.DimensionValue(Optional.of(dimNameA), wrapScalar(dimensionExpression))); + slices.add(new Slice.DimensionValue<>(Optional.of(dimNameA), wrapScalar(dimensionExpression))); } } } - return slices.size() == 0 ? tensorFunction : new Slice(tensorFunction, slices); + return slices.size() == 0 ? tensorFunction : new Slice<>(tensorFunction, slices); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java index a4a47ca8ce7..fd262b2892c 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import ai.vespa.rankingexpression.importer.DimensionRenamer; @@ -56,12 +57,12 @@ public class Mean extends IntermediateOperation { // optimization: if keepDims and one reduce dimension that has size 1: same as identity. @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputTypesPresent(2)) return null; - TensorFunction inputFunction = inputs.get(0).function().get(); - TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions); + TensorFunction<Reference> inputFunction = inputs.get(0).function().get(); + TensorFunction<Reference> output = new Reduce<>(inputFunction, Reduce.Aggregator.avg, reduceDimensions); if (shouldKeepDimensions()) { // multiply with a generated tensor created from the reduced dimensions TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); @@ -70,9 +71,9 @@ public class Mean extends IntermediateOperation { } TensorType generatedType = typeBuilder.build(); ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); - Generate generatedFunction = new Generate(generatedType, + Generate<Reference> generatedFunction = new Generate<>(generatedType, new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); - output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); + output = new com.yahoo.tensor.functions.Join<>(output, generatedFunction, ScalarFunctions.multiply()); } return output; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java index f208cc97d4f..e2b5930f114 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; @@ -23,7 +24,7 @@ public class Merge extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { for (IntermediateOperation operation : inputs) { if (operation.function().isPresent()) { return operation.function().get(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java index 1d76fa3f0a7..d8055d548ad 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.TensorFunction; import java.util.Collections; @@ -19,7 +20,7 @@ public class NoOp extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { return null; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java index 7b0547be7d2..164e3dc5e11 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.TensorFunction; import onnx.Onnx.TensorProto.DataType; @@ -30,13 +31,13 @@ public class OnnxCast extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(1)) return null; - TensorFunction input = inputs.get(0).function().get(); + TensorFunction<Reference> input = inputs.get(0).function().get(); switch (toType) { case BOOL: - return new com.yahoo.tensor.functions.Map(input, new AsBool()); + return new com.yahoo.tensor.functions.Map<>(input, new AsBool()); case INT8: case INT16: case INT32: @@ -45,7 +46,7 @@ public class OnnxCast extends IntermediateOperation { case UINT16: case UINT32: case UINT64: - return new com.yahoo.tensor.functions.Map(input, new AsInt()); + return new com.yahoo.tensor.functions.Map<>(input, new AsInt()); case FLOAT: case DOUBLE: case FLOAT16: diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java index 2be8fc0dc4e..97818f4c27d 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; @@ -65,14 +66,14 @@ public class OnnxConcat extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) { return null; } - TensorFunction result = inputs.get(0).function().get(); + TensorFunction<Reference> result = inputs.get(0).function().get(); for (int i = 1; i < inputs.size(); ++i) { - TensorFunction b = inputs.get(i).function().get(); - result = new com.yahoo.tensor.functions.Concat(result, b, concatDimensionName); + TensorFunction<Reference> b = inputs.get(i).function().get(); + result = new com.yahoo.tensor.functions.Concat<>(result, b, concatDimensionName); } return result; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java index 79123cb0380..675e18da637 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -36,7 +37,7 @@ public class OnnxConstant extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { return null; // will be added by function() since this is constant. } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java index 3456a24f5dd..c0f825f9092 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; @@ -22,7 +23,7 @@ public class PlaceholderWithDefault extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!allInputFunctionsPresent(1)) { return null; } @@ -32,7 +33,7 @@ public class PlaceholderWithDefault extends IntermediateOperation { } @Override - public Optional<TensorFunction> rankingExpressionFunction() { + public Optional<TensorFunction<Reference>> rankingExpressionFunction() { // For now, it is much more efficient to assume we always will return // the default value, as we can prune away large parts of the expression // tree by having it calculated as a constant. If a case arises where diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java index 81a9e4996b4..5c4e8cd6cd0 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; @@ -58,7 +59,7 @@ public class Range extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputTypesPresent(3)) return null; String dimensionName = type().get().dimensionNames().get(0); ExpressionNode startExpr = new ConstantNode(new DoubleValue(start)); @@ -66,7 +67,7 @@ public class Range extends IntermediateOperation { ExpressionNode dimExpr = new EmbracedNode(new ReferenceNode(dimensionName)); ExpressionNode stepExpr = new ArithmeticNode(deltaExpr, ArithmeticOperator.MULTIPLY, dimExpr); ExpressionNode addExpr = new ArithmeticNode(startExpr, ArithmeticOperator.PLUS, stepExpr); - TensorFunction function = Generate.bound(type.type(), wrapScalar(addExpr)); + TensorFunction<Reference> function = Generate.bound(type.type(), wrapScalar(addExpr)); return function; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java index 8e49ce15265..b7a8a4a4e43 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; @@ -72,14 +73,14 @@ public class Reduce extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputTypesPresent(1)) return null; - TensorFunction inputFunction = inputs.get(0).function().get(); + TensorFunction<Reference> inputFunction = inputs.get(0).function().get(); if (preOperator != null) { - inputFunction = new com.yahoo.tensor.functions.Map(inputFunction, preOperator); + inputFunction = new com.yahoo.tensor.functions.Map<>(inputFunction, preOperator); } - TensorFunction output = new com.yahoo.tensor.functions.Reduce(inputFunction, aggregator, reduceDimensions); + TensorFunction<Reference> output = new com.yahoo.tensor.functions.Reduce<>(inputFunction, aggregator, reduceDimensions); if (shouldKeepDimensions()) { // multiply with a generated tensor created from the reduced dimensions TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); @@ -88,12 +89,12 @@ public class Reduce extends IntermediateOperation { } TensorType generatedType = typeBuilder.build(); ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); - Generate generatedFunction = new Generate(generatedType, + Generate<Reference> generatedFunction = new Generate<>(generatedType, new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); - output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); + output = new com.yahoo.tensor.functions.Join<>(output, generatedFunction, ScalarFunctions.multiply()); } if (postOperator != null) { - output = new com.yahoo.tensor.functions.Map(output, postOperator); + output = new com.yahoo.tensor.functions.Map<>(output, postOperator); } return output; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java index 724e49084ee..d80058dfa07 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; @@ -43,9 +44,9 @@ public class Rename extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(1)) return null; - return new com.yahoo.tensor.functions.Rename(inputs.get(0).function().orElse(null), from, to); + return new com.yahoo.tensor.functions.Rename<>(inputs.get(0).function().orElse(null), from, to); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index 57a43158c0d..7b675fa79af 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -110,12 +110,12 @@ public class Reshape extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent) ) return null; if ( ! inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent) ) return null; OrderedTensorType inputType = inputs.get(0).type().get(); - TensorFunction inputFunction = inputs.get(0).function().get(); + TensorFunction<Reference> inputFunction = inputs.get(0).function().get(); return reshape(inputFunction, inputType, type); } @@ -129,7 +129,7 @@ public class Reshape extends IntermediateOperation { return new Reshape(modelName(), name(), inputs, attributeMap); } - public TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) { + public TensorFunction<Reference> reshape(TensorFunction<Reference> inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) { if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type()))) throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java index a189ff9c07c..9836217866b 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ScalarFunctions; @@ -34,13 +35,13 @@ public class Select extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!allInputFunctionsPresent(3)) { return null; } IntermediateOperation conditionOperation = inputs().get(0); - TensorFunction a = inputs().get(1).function().get(); - TensorFunction b = inputs().get(2).function().get(); + TensorFunction<Reference> a = inputs().get(1).function().get(); + TensorFunction<Reference> b = inputs().get(2).function().get(); // Shortcut: if we know during import which tensor to select, do that directly here. if (conditionOperation.getConstantValue().isPresent()) { @@ -61,13 +62,13 @@ public class Select extends IntermediateOperation { // from 'x'. We do this by individually joining 'x' and 'y' with // 'condition', and then joining the resulting two tensors. - TensorFunction conditionFunction = conditionOperation.function().get(); - TensorFunction aCond = new com.yahoo.tensor.functions.Join(a, conditionFunction, ScalarFunctions.multiply()); - TensorFunction bCond = new com.yahoo.tensor.functions.Join(b, conditionFunction, new DoubleBinaryOperator() { + TensorFunction<Reference> conditionFunction = conditionOperation.function().get(); + TensorFunction<Reference> aCond = new com.yahoo.tensor.functions.Join<>(a, conditionFunction, ScalarFunctions.multiply()); + TensorFunction<Reference> bCond = new com.yahoo.tensor.functions.Join<>(b, conditionFunction, new DoubleBinaryOperator() { @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); } @Override public String toString() { return "f(a,b)(a * (1-b))"; } }); - return new com.yahoo.tensor.functions.Join(aCond, bCond, ScalarFunctions.add()); + return new com.yahoo.tensor.functions.Join<>(aCond, bCond, ScalarFunctions.add()); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java index 28e0115810a..c1cffd4243e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; @@ -28,7 +29,7 @@ public class Shape extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { return null; // will be added by function() since this is constant. } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java index ac5d66e22c1..91b7064b19c 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java @@ -143,7 +143,7 @@ public class Slice extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (inputs.size() < 1 || inputs.get(0).function().isEmpty()) { return null; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java index 6001bef87ed..d7060b9d440 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Map; import com.yahoo.tensor.functions.Reduce; @@ -34,12 +35,12 @@ public class Softmax extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(1)) return null; List<String> reduceDimensions = reduceDimensions(); - TensorFunction input = inputs.get(0).function().get(); - TensorFunction sum = new Reduce(input, Reduce.Aggregator.sum, reduceDimensions); - TensorFunction div = new Join(input, sum, ScalarFunctions.divide()); + TensorFunction<Reference> input = inputs.get(0).function().get(); + TensorFunction<Reference> sum = new Reduce<>(input, Reduce.Aggregator.sum, reduceDimensions); + TensorFunction<Reference> div = new Join<>(input, sum, ScalarFunctions.divide()); return div; } @@ -93,13 +94,13 @@ public class Softmax extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(1)) return null; List<String> reduceDimensions = reduceDimensions(); - TensorFunction input = inputs.get(0).function().get(); - TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions); - TensorFunction cap = new Join(input, max, ScalarFunctions.subtract()); // to avoid overflow - TensorFunction exp = new Map(cap, ScalarFunctions.exp()); + TensorFunction<Reference> input = inputs.get(0).function().get(); + TensorFunction<Reference> max = new Reduce<>(input, Reduce.Aggregator.max, reduceDimensions); + TensorFunction<Reference> cap = new Join<>(input, max, ScalarFunctions.subtract()); // to avoid overflow + TensorFunction<Reference> exp = new Map<>(cap, ScalarFunctions.exp()); return exp; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java index 2e586b38c71..6f720716adb 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java @@ -84,7 +84,7 @@ public class Split extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!allInputFunctionsPresent(1)) return null; IntermediateOperation input = inputs.get(0); @@ -104,7 +104,7 @@ public class Split extends IntermediateOperation { com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); - TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); + TensorFunction<Reference> generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); return generate; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java index 07110b9b966..9229d6af254 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; @@ -52,11 +53,11 @@ public class Squeeze extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(1)) return null; - TensorFunction inputFunction = inputs.get(0).function().get(); - return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); + TensorFunction<Reference> inputFunction = inputs.get(0).function().get(); + return new Reduce<>(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java index b8ca114343d..902144cfea2 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; @@ -56,11 +57,11 @@ public class Sum extends IntermediateOperation { // optimization: if keepDims and one reduce dimension that has size 1: same as identity. @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputTypesPresent(2)) return null; - TensorFunction inputFunction = inputs.get(0).function().get(); - TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.sum, reduceDimensions); + TensorFunction<Reference> inputFunction = inputs.get(0).function().get(); + TensorFunction<Reference> output = new Reduce<>(inputFunction, Reduce.Aggregator.sum, reduceDimensions); if (shouldKeepDimensions()) { // multiply with a generated tensor created from the reduced dimensions TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); @@ -69,9 +70,9 @@ public class Sum extends IntermediateOperation { } TensorType generatedType = typeBuilder.build(); ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); - Generate generatedFunction = new Generate(generatedType, + Generate<Reference> generatedFunction = new Generate<>(generatedType, new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); - output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); + output = new com.yahoo.tensor.functions.Join<>(output, generatedFunction, ScalarFunctions.multiply()); } return output; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java index f41140075d1..502f0769350 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; @@ -29,7 +30,7 @@ public class Switch extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { IntermediateOperation predicateOperation = inputs().get(1); if (!predicateOperation.getConstantValue().isPresent()) { throw new IllegalArgumentException("Switch in " + name + ": predicate must be a constant"); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java index 7fe5e831391..4bfab284cc2 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java @@ -62,7 +62,7 @@ public class Tile extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!allInputFunctionsPresent(2)) return null; IntermediateOperation input = inputs.get(0); @@ -85,7 +85,7 @@ public class Tile extends IntermediateOperation { com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); - TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); + TensorFunction<Reference> generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); return generate; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java index add24e665e6..ef51b11884a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; @@ -36,7 +37,7 @@ public class Transpose extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!allInputFunctionsPresent(1)) return null; return inputs.get(0).function().orElse(null); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java index bd3130a7cd1..a73b5a4c6ef 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -64,7 +65,7 @@ public class Unsqueeze extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(1)) return null; // multiply with a generated tensor created from the expanded dimensions @@ -74,9 +75,9 @@ public class Unsqueeze extends IntermediateOperation { } TensorType generatedType = typeBuilder.build(); ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); - Generate generatedFunction = new Generate(generatedType, + Generate<Reference> generatedFunction = new Generate<>(generatedType, new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); - return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply()); + return new com.yahoo.tensor.functions.Join<>(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply()); } @Override diff --git a/model-integration/src/main/javacc/ModelParser.jj b/model-integration/src/main/javacc/ModelParser.jj index 9944b88a745..6f6f3508beb 100644 --- a/model-integration/src/main/javacc/ModelParser.jj +++ b/model-integration/src/main/javacc/ModelParser.jj @@ -170,7 +170,7 @@ void input() : void function() : { String name, expression, parameter; - List parameters = new ArrayList(); + List< String > parameters = new ArrayList< String >(); } { ( <FUNCTION> name = identifier() diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java index dfc4e98d409..3ef96cdf166 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.operations.Constant; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; @@ -703,7 +704,7 @@ public class OnnxOperationsTestCase { return builder.build(); } - private TensorFunction optimizeAndRename(String opName, IntermediateOperation op) { + private TensorFunction<Reference> optimizeAndRename(String opName, IntermediateOperation op) { IntermediateGraph graph = new IntermediateGraph(modelName); graph.put(opName, op); graph.outputs(graph.defaultSignature()).put(opName, opName); @@ -717,7 +718,7 @@ public class OnnxOperationsTestCase { if ( ! operationType.equals(standardNamingType)) { List<String> renameFrom = operationType.dimensionNames(); List<String> renameTo = standardNamingType.dimensionNames(); - TensorFunction func = new Rename(new ConstantTensor(tensor), renameFrom, renameTo); + TensorFunction<Reference> func = new Rename<>(new ConstantTensor<Reference>(tensor), renameFrom, renameTo); return func.evaluate(); } return tensor; |