diff options
author | Lester Solbakken <lesters@oath.com> | 2020-06-27 10:49:03 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-06-27 10:49:03 +0200 |
commit | fea65d8cef74e124b340408f094ff277e000abe8 (patch) | |
tree | 6c48ff0ee2803fc868f03e9826dec84f5bcbb673 /model-integration | |
parent | 0a8b5894dfc442d661836fce4ddb6c870bcc0ec0 (diff) |
Gather: simplify expression for constant scalar indices
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java index 91ff5d9cdd8..3cb51b87104 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java @@ -87,19 +87,24 @@ public class Gather extends IntermediateOperation { addSliceDimension(dataSliceDimensions, dataType.dimensions().get(i).name(), i); } - List<Slice.DimensionValue<Reference>> indicesSliceDimensions = new ArrayList<>(); - for (int i = 0; i < indicesType.rank(); ++i) { - addSliceDimension(indicesSliceDimensions, indicesType.dimensions().get(i).name(), axis + i); + if (indicesType.rank() == 0 && indices.isConstant()) { + ExpressionNode indexExpression = new ConstantNode(new DoubleValue(indices.getConstantValue().get().asDouble())); + addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression); + } else { + List<Slice.DimensionValue<Reference>> indicesSliceDimensions = new ArrayList<>(); + for (int i = 0; i < indicesType.rank(); ++i) { + addSliceDimension(indicesSliceDimensions, indicesType.dimensions().get(i).name(), axis + i); + } + ExpressionNode sliceExpression = createSliceExpression(indicesSliceDimensions, indicesFunctionName); + ExpressionNode indexExpression = createIndexExpression(dataType, sliceExpression); + addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression); } - ExpressionNode sliceExpression = createSliceExpression(indicesSliceDimensions, indicesFunctionName); - ExpressionNode indexExpression = createIndexExpression(dataType, sliceExpression); - addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression); for (int i = axis + 1; i < dataType.rank(); ++i) { addSliceDimension(dataSliceDimensions, dataType.dimensions().get(i).name(), i + indicesType.rank() - 1); } - sliceExpression = createSliceExpression(dataSliceDimensions, dataFunctionName); + ExpressionNode sliceExpression = createSliceExpression(dataSliceDimensions, dataFunctionName); return Generate.bound(type.type(), wrapScalar(sliceExpression)); } |