diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java | 20 |
1 files changed, 9 insertions, 11 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index 7b675fa79af..bc94fc6aa76 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -6,13 +6,11 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import ai.vespa.rankingexpression.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.Function; -import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.Tensor; @@ -159,13 +157,13 @@ public class Reshape extends IntermediateOperation { inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero)); } else if (dim == (inputType.rank() - 1)) { ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); - ExpressionNode div = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, size); + ExpressionNode div = new OperationNode(unrolled, Operator.modulo, size); inputDimensionExpression = new EmbracedNode(div); } else { ExpressionNode size = new ConstantNode(new DoubleValue(innerSize)); ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize)); - ExpressionNode mod = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, previousSize); - ExpressionNode div = new ArithmeticNode(new EmbracedNode(mod), ArithmeticOperator.DIVIDE, size); + ExpressionNode mod = new OperationNode(unrolled, Operator.modulo, previousSize); + ExpressionNode div = new OperationNode(new EmbracedNode(mod), Operator.divide, size); inputDimensionExpression = new EmbracedNode(div); } dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression))); @@ -183,21 +181,21 @@ public class Reshape extends IntermediateOperation { return new ConstantNode(DoubleValue.zero); List<ExpressionNode> children = new ArrayList<>(); - List<ArithmeticOperator> operators = new ArrayList<>(); + List<Operator> operators = new ArrayList<>(); int size = 1; for (int i = type.dimensions().size() - 1; i >= 0; --i) { TensorType.Dimension dimension = type.dimensions().get(i); children.add(0, new ReferenceNode(dimension.name())); if (size > 1) { - operators.add(0, ArithmeticOperator.MULTIPLY); + operators.add(0, Operator.multiply); children.add(0, new ConstantNode(new DoubleValue(size))); } size *= OrderedTensorType.dimensionSize(dimension); if (i > 0) { - operators.add(0, ArithmeticOperator.PLUS); + operators.add(0, Operator.plus); } } - return new ArithmeticNode(children, operators); + return new OperationNode(children, operators); } @Override |