aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-06-27 10:49:03 +0200
committerLester Solbakken <lesters@oath.com>2020-06-27 10:49:03 +0200
commitfea65d8cef74e124b340408f094ff277e000abe8 (patch)
tree6c48ff0ee2803fc868f03e9826dec84f5bcbb673 /model-integration
parent0a8b5894dfc442d661836fce4ddb6c870bcc0ec0 (diff)
Gather: simplify expression for constant scalar indices
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java19
1 files changed, 12 insertions, 7 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
index 91ff5d9cdd8..3cb51b87104 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
@@ -87,19 +87,24 @@ public class Gather extends IntermediateOperation {
addSliceDimension(dataSliceDimensions, dataType.dimensions().get(i).name(), i);
}
- List<Slice.DimensionValue<Reference>> indicesSliceDimensions = new ArrayList<>();
- for (int i = 0; i < indicesType.rank(); ++i) {
- addSliceDimension(indicesSliceDimensions, indicesType.dimensions().get(i).name(), axis + i);
+ if (indicesType.rank() == 0 && indices.isConstant()) {
+ ExpressionNode indexExpression = new ConstantNode(new DoubleValue(indices.getConstantValue().get().asDouble()));
+ addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression);
+ } else {
+ List<Slice.DimensionValue<Reference>> indicesSliceDimensions = new ArrayList<>();
+ for (int i = 0; i < indicesType.rank(); ++i) {
+ addSliceDimension(indicesSliceDimensions, indicesType.dimensions().get(i).name(), axis + i);
+ }
+ ExpressionNode sliceExpression = createSliceExpression(indicesSliceDimensions, indicesFunctionName);
+ ExpressionNode indexExpression = createIndexExpression(dataType, sliceExpression);
+ addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression);
}
- ExpressionNode sliceExpression = createSliceExpression(indicesSliceDimensions, indicesFunctionName);
- ExpressionNode indexExpression = createIndexExpression(dataType, sliceExpression);
- addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression);
for (int i = axis + 1; i < dataType.rank(); ++i) {
addSliceDimension(dataSliceDimensions, dataType.dimensions().get(i).name(), i + indicesType.rank() - 1);
}
- sliceExpression = createSliceExpression(dataSliceDimensions, dataFunctionName);
+ ExpressionNode sliceExpression = createSliceExpression(dataSliceDimensions, dataFunctionName);
return Generate.bound(type.type(), wrapScalar(sliceExpression));
}