From ee09d8d3513273a6c5348e2ed254aea47ff8c23e Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Thu, 23 Nov 2023 08:37:30 +0000 Subject: add special handling of "closest" feature --- .../com/yahoo/schema/MapEvaluationTypeContext.java | 30 ++++++++++++++-- .../src/test/derived/tensor/rank-profiles.cfg | 41 ++++++++++++++++++++++ config-model/src/test/derived/tensor/tensor.sd | 17 +++++++++ 3 files changed, 86 insertions(+), 2 deletions(-) diff --git a/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java index f75bdec111e..2a8dd49a0c1 100644 --- a/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java @@ -290,17 +290,43 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement } /** - * There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet. + * There are three features which may return some (non-empty) tensor type: + * - tensorFromLabels + * - tensorFromWeightedSet + * - closest * This returns the type of those features if this is a reference to either of them, or empty otherwise. */ private Optional tensorFeatureType(Reference reference) { - if ( ! reference.name().equals("tensorFromLabels") && ! reference.name().equals("tensorFromWeightedSet")) + if ( ! reference.name().equals("tensorFromLabels") && + ! reference.name().equals("tensorFromWeightedSet") && + ! reference.name().equals("closest")) + { return Optional.empty(); + } if (reference.arguments().size() != 1 && reference.arguments().size() != 2) throw new IllegalArgumentException(reference.name() + " must have one or two arguments"); ExpressionNode arg0 = reference.arguments().expressions().get(0); + if (reference.name().equals("closest")) { + if (arg0 instanceof ReferenceNode argRefNode) { + var argRef = argRefNode.reference(); + if (argRef.isIdentifier()) { + var attrFeature = FeatureNames.asAttributeFeature(argRef.name()); + TensorType attrTT = featureTypes.get(attrFeature); + if (attrTT != null && attrTT.rank() > 0) { + TensorType mapped = attrTT.mappedSubtype(); + if (mapped.rank() > 0) { + return Optional.of(mapped); + } else { + throw new IllegalArgumentException("Unexpected tensor type " + attrTT + " for " + attrFeature + " used by " + reference); + } + } + } + } + throw new IllegalArgumentException("The first argument of " + reference.name() + + " must be the name of a tensor attribute, not " + arg0); + } if ( ! ( arg0 instanceof ReferenceNode) || ! FeatureNames.isSimpleFeature(((ReferenceNode)arg0).reference())) throw new IllegalArgumentException("The first argument of " + reference.name() + " must be a simple feature, not " + arg0); diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg index a72e9dc27cd..92c02cc768c 100644 --- a/config-model/src/test/derived/tensor/rank-profiles.cfg +++ b/config-model/src/test/derived/tensor/rank-profiles.cfg @@ -225,3 +225,44 @@ rankprofile[].fef.property[].name "vespa.type.query.para" rankprofile[].fef.property[].value "tensor(p{})" rankprofile[].fef.property[].name "vespa.type.query.qvec" rankprofile[].fef.property[].value "tensor(x[40])" +rankprofile[].name "with-closest-one" +rankprofile[].fef.property[].name "rankingExpression(dot_products).rankingScript" +rankprofile[].fef.property[].value "reduce(query(qvec) * attribute(f7), sum, x)" +rankprofile[].fef.property[].name "rankingExpression(dot_products).type" +rankprofile[].fef.property[].value "tensor(p{})" +rankprofile[].fef.property[].name "vespa.type.feature.closest(f7,foobarbaz)" +rankprofile[].fef.property[].value "tensor(p{})" +rankprofile[].fef.property[].name "vespa.type.feature.closest(f7)" +rankprofile[].fef.property[].value "tensor(p{})" +rankprofile[].fef.property[].name "vespa.type.feature.attribute(f7)" +rankprofile[].fef.property[].value "tensor(p{},x[5])" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "rankingExpression(firstphase)" +rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" +rankprofile[].fef.property[].value "reduce(rankingExpression(dot_products), max, p)" +rankprofile[].fef.property[].name "vespa.rank.globalphase" +rankprofile[].fef.property[].value "rankingExpression(globalphase)" +rankprofile[].fef.property[].name "rankingExpression(globalphase).rankingScript" +rankprofile[].fef.property[].value "reduce(closest(f7) * attribute(f7) * query(qvec), sum)" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "closest(f7,foobarbaz)" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "closest(f7)" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "attribute(f7)" +rankprofile[].fef.property[].name "vespa.hidden.matchfeature" +rankprofile[].fef.property[].value "closest(f7)" +rankprofile[].fef.property[].name "vespa.hidden.matchfeature" +rankprofile[].fef.property[].value "attribute(f7)" +rankprofile[].fef.property[].name "vespa.type.attribute.f7" +rankprofile[].fef.property[].value "tensor(p{},x[5])" +rankprofile[].fef.property[].name "vespa.type.attribute.f2" +rankprofile[].fef.property[].value "tensor(x[2],y[1])" +rankprofile[].fef.property[].name "vespa.type.attribute.f3" +rankprofile[].fef.property[].value "tensor(x{})" +rankprofile[].fef.property[].name "vespa.type.attribute.f4" +rankprofile[].fef.property[].value "tensor(x[10],y[10])" +rankprofile[].fef.property[].name "vespa.type.attribute.f5" +rankprofile[].fef.property[].value "tensor(x[10])" +rankprofile[].fef.property[].name "vespa.type.query.qvec" +rankprofile[].fef.property[].value "tensor(x[5])" diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd index a0f5cd92c56..7dc61eba06d 100644 --- a/config-model/src/test/derived/tensor/tensor.sd +++ b/config-model/src/test/derived/tensor/tensor.sd @@ -145,4 +145,21 @@ schema tensor { } } + rank-profile with-closest-one { + inputs { + query(qvec) tensor(x[5]) + } + function dot_products() { + expression: reduce(query(qvec)*attribute(f7), sum, x) + } + first-phase { + expression: reduce(dot_products, max, p) + } + global-phase { + expression: sum(closest(f7)*attribute(f7)*query(qvec)) + } + match-features { + closest(f7,foobarbaz) + } + } } -- cgit v1.2.3