summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java20
1 files changed, 9 insertions, 11 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index 7b675fa79af..bc94fc6aa76 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -6,13 +6,11 @@ import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
@@ -159,13 +157,13 @@ public class Reshape extends IntermediateOperation {
inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero));
} else if (dim == (inputType.rank() - 1)) {
ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize));
- ExpressionNode div = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, size);
+ ExpressionNode div = new OperationNode(unrolled, Operator.modulo, size);
inputDimensionExpression = new EmbracedNode(div);
} else {
ExpressionNode size = new ConstantNode(new DoubleValue(innerSize));
ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize));
- ExpressionNode mod = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, previousSize);
- ExpressionNode div = new ArithmeticNode(new EmbracedNode(mod), ArithmeticOperator.DIVIDE, size);
+ ExpressionNode mod = new OperationNode(unrolled, Operator.modulo, previousSize);
+ ExpressionNode div = new OperationNode(new EmbracedNode(mod), Operator.divide, size);
inputDimensionExpression = new EmbracedNode(div);
}
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression)));
@@ -183,21 +181,21 @@ public class Reshape extends IntermediateOperation {
return new ConstantNode(DoubleValue.zero);
List<ExpressionNode> children = new ArrayList<>();
- List<ArithmeticOperator> operators = new ArrayList<>();
+ List<Operator> operators = new ArrayList<>();
int size = 1;
for (int i = type.dimensions().size() - 1; i >= 0; --i) {
TensorType.Dimension dimension = type.dimensions().get(i);
children.add(0, new ReferenceNode(dimension.name()));
if (size > 1) {
- operators.add(0, ArithmeticOperator.MULTIPLY);
+ operators.add(0, Operator.multiply);
children.add(0, new ConstantNode(new DoubleValue(size)));
}
size *= OrderedTensorType.dimensionSize(dimension);
if (i > 0) {
- operators.add(0, ArithmeticOperator.PLUS);
+ operators.add(0, Operator.plus);
}
}
- return new ArithmeticNode(children, operators);
+ return new OperationNode(children, operators);
}
@Override