summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java20
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java6
7 files changed, 34 insertions, 36 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
index cd0c4da6d0f..7ac2f4bff84 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
@@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -92,7 +92,7 @@ public class Gather extends IntermediateOperation {
ExpressionNode indexExpression = new ConstantNode(new DoubleValue(constantValue));
if (constantValue < 0) {
ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get()));
- indexExpression = new EmbracedNode(new ArithmeticNode(indexExpression, ArithmeticOperator.PLUS, axisSize));
+ indexExpression = new EmbracedNode(new OperationNode(indexExpression, Operator.plus, axisSize));
}
addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression);
} else {
@@ -125,8 +125,8 @@ public class Gather extends IntermediateOperation {
/** to support negative indexing */
private ExpressionNode createIndexExpression(OrderedTensorType dataType, ExpressionNode slice) {
ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get()));
- ExpressionNode plus = new EmbracedNode(new ArithmeticNode(slice, ArithmeticOperator.PLUS, axisSize));
- ExpressionNode mod = new ArithmeticNode(plus, ArithmeticOperator.MODULO, axisSize);
+ ExpressionNode plus = new EmbracedNode(new OperationNode(slice, Operator.plus, axisSize));
+ ExpressionNode mod = new OperationNode(plus, Operator.modulo, axisSize);
return mod;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
index 1f447f2a575..97bfdda385e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
@@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
@@ -97,17 +97,17 @@ public class Gemm extends IntermediateOperation {
TensorFunction<Reference> AxB = new com.yahoo.tensor.functions.Matmul<>(aFunction.get(), bFunction.get(), joinDimension);
TensorFunction<Reference> alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction(
- new ArithmeticNode(
+ new OperationNode(
new TensorFunctionNode(AxB),
- ArithmeticOperator.MULTIPLY,
+ Operator.multiply,
new ConstantNode(new DoubleValue(alpha))));
if (inputs.size() == 3) {
Optional<TensorFunction<Reference>> cFunction = inputs.get(2).function();
TensorFunction<Reference> betaxC = new TensorFunctionNode.ExpressionTensorFunction(
- new ArithmeticNode(
+ new OperationNode(
new TensorFunctionNode(cFunction.get()),
- ArithmeticOperator.MULTIPLY,
+ Operator.multiply,
new ConstantNode(new DoubleValue(beta))));
return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add());
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
index 5c4e8cd6cd0..9a38ab9dfde 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
@@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -65,8 +65,8 @@ public class Range extends IntermediateOperation {
ExpressionNode startExpr = new ConstantNode(new DoubleValue(start));
ExpressionNode deltaExpr = new ConstantNode(new DoubleValue(delta));
ExpressionNode dimExpr = new EmbracedNode(new ReferenceNode(dimensionName));
- ExpressionNode stepExpr = new ArithmeticNode(deltaExpr, ArithmeticOperator.MULTIPLY, dimExpr);
- ExpressionNode addExpr = new ArithmeticNode(startExpr, ArithmeticOperator.PLUS, stepExpr);
+ ExpressionNode stepExpr = new OperationNode(deltaExpr, Operator.multiply, dimExpr);
+ ExpressionNode addExpr = new OperationNode(startExpr, Operator.plus, stepExpr);
TensorFunction<Reference> function = Generate.bound(type.type(), wrapScalar(addExpr));
return function;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index 7b675fa79af..bc94fc6aa76 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -6,13 +6,11 @@ import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
@@ -159,13 +157,13 @@ public class Reshape extends IntermediateOperation {
inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero));
} else if (dim == (inputType.rank() - 1)) {
ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize));
- ExpressionNode div = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, size);
+ ExpressionNode div = new OperationNode(unrolled, Operator.modulo, size);
inputDimensionExpression = new EmbracedNode(div);
} else {
ExpressionNode size = new ConstantNode(new DoubleValue(innerSize));
ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize));
- ExpressionNode mod = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, previousSize);
- ExpressionNode div = new ArithmeticNode(new EmbracedNode(mod), ArithmeticOperator.DIVIDE, size);
+ ExpressionNode mod = new OperationNode(unrolled, Operator.modulo, previousSize);
+ ExpressionNode div = new OperationNode(new EmbracedNode(mod), Operator.divide, size);
inputDimensionExpression = new EmbracedNode(div);
}
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression)));
@@ -183,21 +181,21 @@ public class Reshape extends IntermediateOperation {
return new ConstantNode(DoubleValue.zero);
List<ExpressionNode> children = new ArrayList<>();
- List<ArithmeticOperator> operators = new ArrayList<>();
+ List<Operator> operators = new ArrayList<>();
int size = 1;
for (int i = type.dimensions().size() - 1; i >= 0; --i) {
TensorType.Dimension dimension = type.dimensions().get(i);
children.add(0, new ReferenceNode(dimension.name()));
if (size > 1) {
- operators.add(0, ArithmeticOperator.MULTIPLY);
+ operators.add(0, Operator.multiply);
children.add(0, new ConstantNode(new DoubleValue(size)));
}
size *= OrderedTensorType.dimensionSize(dimension);
if (i > 0) {
- operators.add(0, ArithmeticOperator.PLUS);
+ operators.add(0, Operator.plus);
}
}
- return new ArithmeticNode(children, operators);
+ return new OperationNode(children, operators);
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
index 91b7064b19c..916e9980131 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
@@ -6,8 +6,8 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -166,8 +166,8 @@ public class Slice extends IntermediateOperation {
// step * (d0 + start)
ExpressionNode reference = new ReferenceNode(outputDimensionName);
- ExpressionNode plus = new EmbracedNode(new ArithmeticNode(reference, ArithmeticOperator.PLUS, startIndex));
- ExpressionNode mul = new ArithmeticNode(stepSize, ArithmeticOperator.MULTIPLY, plus);
+ ExpressionNode plus = new EmbracedNode(new OperationNode(reference, Operator.plus, startIndex));
+ ExpressionNode mul = new OperationNode(stepSize, Operator.multiply, plus);
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mul))));
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
index 6f720716adb..ef46d222941 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
@@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -96,7 +96,7 @@ public class Split extends IntermediateOperation {
for (int i = 0; i < inputType.rank(); ++i) {
String inputDimensionName = inputType.dimensions().get(i).name();
ExpressionNode reference = new ReferenceNode(inputDimensionName);
- ExpressionNode offset = new ArithmeticNode(reference, ArithmeticOperator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0)));
+ ExpressionNode offset = new OperationNode(reference, Operator.plus, new ConstantNode(new DoubleValue(i == axis ? start : 0)));
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(offset))));
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
index 4bfab284cc2..a880bff87be 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
@@ -4,8 +4,8 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -77,7 +77,7 @@ public class Tile extends IntermediateOperation {
ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize));
ExpressionNode reference = new ReferenceNode(inputDimensionName);
- ExpressionNode mod = new ArithmeticNode(reference, ArithmeticOperator.MODULO, size);
+ ExpressionNode mod = new OperationNode(reference, Operator.modulo, size);
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mod))));
}