diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-09-28 17:42:55 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-09-28 17:42:55 +0200 |
commit | bcbb2009c44380055b2670e7cdefcad232f9ece4 (patch) | |
tree | b1f1716e6dcb7cb79dcb9871af21758cf6c0a5c2 /model-integration | |
parent | 3d49f155fccfa4fc08882b01e7a6e3a982c55212 (diff) |
Drop 'arithmetic' from name
Diffstat (limited to 'model-integration')
7 files changed, 34 insertions, 36 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java index cd0c4da6d0f..c66022975c7 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java @@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -92,7 +92,7 @@ public class Gather extends IntermediateOperation { ExpressionNode indexExpression = new ConstantNode(new DoubleValue(constantValue)); if (constantValue < 0) { ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get())); - indexExpression = new EmbracedNode(new ArithmeticNode(indexExpression, ArithmeticOperator.PLUS, axisSize)); + indexExpression = new EmbracedNode(new OperationNode(indexExpression, Operator.PLUS, axisSize)); } addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression); } else { @@ -125,8 +125,8 @@ public class Gather extends IntermediateOperation { /** to support negative indexing */ private ExpressionNode createIndexExpression(OrderedTensorType dataType, ExpressionNode slice) { ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get())); - ExpressionNode plus = new EmbracedNode(new ArithmeticNode(slice, ArithmeticOperator.PLUS, axisSize)); - ExpressionNode mod = new ArithmeticNode(plus, ArithmeticOperator.MODULO, axisSize); + ExpressionNode plus = new EmbracedNode(new OperationNode(slice, Operator.PLUS, axisSize)); + ExpressionNode mod = new OperationNode(plus, Operator.MODULO, axisSize); return mod; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java index 1f447f2a575..81d633dea4b 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java @@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.TensorType; @@ -97,17 +97,17 @@ public class Gemm extends IntermediateOperation { TensorFunction<Reference> AxB = new com.yahoo.tensor.functions.Matmul<>(aFunction.get(), bFunction.get(), joinDimension); TensorFunction<Reference> alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( - new ArithmeticNode( + new OperationNode( new TensorFunctionNode(AxB), - ArithmeticOperator.MULTIPLY, + Operator.MULTIPLY, new ConstantNode(new DoubleValue(alpha)))); if (inputs.size() == 3) { Optional<TensorFunction<Reference>> cFunction = inputs.get(2).function(); TensorFunction<Reference> betaxC = new TensorFunctionNode.ExpressionTensorFunction( - new ArithmeticNode( + new OperationNode( new TensorFunctionNode(cFunction.get()), - ArithmeticOperator.MULTIPLY, + Operator.MULTIPLY, new ConstantNode(new DoubleValue(beta)))); return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add()); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java index 5c4e8cd6cd0..66e810b954e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java @@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -65,8 +65,8 @@ public class Range extends IntermediateOperation { ExpressionNode startExpr = new ConstantNode(new DoubleValue(start)); ExpressionNode deltaExpr = new ConstantNode(new DoubleValue(delta)); ExpressionNode dimExpr = new EmbracedNode(new ReferenceNode(dimensionName)); - ExpressionNode stepExpr = new ArithmeticNode(deltaExpr, ArithmeticOperator.MULTIPLY, dimExpr); - ExpressionNode addExpr = new ArithmeticNode(startExpr, ArithmeticOperator.PLUS, stepExpr); + ExpressionNode stepExpr = new OperationNode(deltaExpr, Operator.MULTIPLY, dimExpr); + ExpressionNode addExpr = new OperationNode(startExpr, Operator.PLUS, stepExpr); TensorFunction<Reference> function = Generate.bound(type.type(), wrapScalar(addExpr)); return function; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index 7b675fa79af..ce93461bff3 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -6,13 +6,11 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import ai.vespa.rankingexpression.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.Function; -import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.Tensor; @@ -159,13 +157,13 @@ public class Reshape extends IntermediateOperation { inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero)); } else if (dim == (inputType.rank() - 1)) { ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); - ExpressionNode div = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, size); + ExpressionNode div = new OperationNode(unrolled, Operator.MODULO, size); inputDimensionExpression = new EmbracedNode(div); } else { ExpressionNode size = new ConstantNode(new DoubleValue(innerSize)); ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize)); - ExpressionNode mod = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, previousSize); - ExpressionNode div = new ArithmeticNode(new EmbracedNode(mod), ArithmeticOperator.DIVIDE, size); + ExpressionNode mod = new OperationNode(unrolled, Operator.MODULO, previousSize); + ExpressionNode div = new OperationNode(new EmbracedNode(mod), Operator.DIVIDE, size); inputDimensionExpression = new EmbracedNode(div); } dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression))); @@ -183,21 +181,21 @@ public class Reshape extends IntermediateOperation { return new ConstantNode(DoubleValue.zero); List<ExpressionNode> children = new ArrayList<>(); - List<ArithmeticOperator> operators = new ArrayList<>(); + List<Operator> operators = new ArrayList<>(); int size = 1; for (int i = type.dimensions().size() - 1; i >= 0; --i) { TensorType.Dimension dimension = type.dimensions().get(i); children.add(0, new ReferenceNode(dimension.name())); if (size > 1) { - operators.add(0, ArithmeticOperator.MULTIPLY); + operators.add(0, Operator.MULTIPLY); children.add(0, new ConstantNode(new DoubleValue(size))); } size *= OrderedTensorType.dimensionSize(dimension); if (i > 0) { - operators.add(0, ArithmeticOperator.PLUS); + operators.add(0, Operator.PLUS); } } - return new ArithmeticNode(children, operators); + return new OperationNode(children, operators); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java index 91b7064b19c..617e1f00c94 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java @@ -6,8 +6,8 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -166,8 +166,8 @@ public class Slice extends IntermediateOperation { // step * (d0 + start) ExpressionNode reference = new ReferenceNode(outputDimensionName); - ExpressionNode plus = new EmbracedNode(new ArithmeticNode(reference, ArithmeticOperator.PLUS, startIndex)); - ExpressionNode mul = new ArithmeticNode(stepSize, ArithmeticOperator.MULTIPLY, plus); + ExpressionNode plus = new EmbracedNode(new OperationNode(reference, Operator.PLUS, startIndex)); + ExpressionNode mul = new OperationNode(stepSize, Operator.MULTIPLY, plus); dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mul)))); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java index 6f720716adb..42901259821 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java @@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -96,7 +96,7 @@ public class Split extends IntermediateOperation { for (int i = 0; i < inputType.rank(); ++i) { String inputDimensionName = inputType.dimensions().get(i).name(); ExpressionNode reference = new ReferenceNode(inputDimensionName); - ExpressionNode offset = new ArithmeticNode(reference, ArithmeticOperator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0))); + ExpressionNode offset = new OperationNode(reference, Operator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0))); dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(offset)))); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java index 4bfab284cc2..cd88c625d81 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java @@ -4,8 +4,8 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -77,7 +77,7 @@ public class Tile extends IntermediateOperation { ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); ExpressionNode reference = new ReferenceNode(inputDimensionName); - ExpressionNode mod = new ArithmeticNode(reference, ArithmeticOperator.MODULO, size); + ExpressionNode mod = new OperationNode(reference, Operator.MODULO, size); dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mod)))); } |