aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-06-29 09:18:55 +0200
committerLester Solbakken <lesters@oath.com>2020-06-29 09:18:55 +0200
commit78254a58c2e905de2fbdab4088cbf2b7c3ec1b32 (patch)
tree178311e5559c85cb20783fd4f36ec40ea62dbb05 /model-integration
parentfea65d8cef74e124b340408f094ff277e000abe8 (diff)
Gather: Support negative indexing with a scalar constant
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java7
1 files changed, 6 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
index 3cb51b87104..d67f064916b 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
@@ -88,7 +88,12 @@ public class Gather extends IntermediateOperation {
}
if (indicesType.rank() == 0 && indices.isConstant()) {
- ExpressionNode indexExpression = new ConstantNode(new DoubleValue(indices.getConstantValue().get().asDouble()));
+ double constantValue = indices.getConstantValue().get().asDouble();
+ ExpressionNode indexExpression = new ConstantNode(new DoubleValue(constantValue));
+ if (constantValue < 0) {
+ ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get()));
+ indexExpression = new EmbracedNode(new ArithmeticNode(indexExpression, ArithmeticOperator.PLUS, axisSize));
+ }
addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression);
} else {
List<Slice.DimensionValue<Reference>> indicesSliceDimensions = new ArrayList<>();