diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition')
3 files changed, 32 insertions, 14 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 6109e5c4aae..a54e21aae68 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -42,8 +42,9 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement /** For invocation loop detection */ private final Deque<Reference> currentResolutionCallStack; - MapEvaluationTypeContext(Collection<ExpressionFunction> functions) { + MapEvaluationTypeContext(Collection<ExpressionFunction> functions, Map<Reference, TensorType> featureTypes) { super(functions); + this.featureTypes.putAll(featureTypes); this.currentResolutionCallStack = new ArrayDeque<>(); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 34277b88252..1283da20395 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -675,15 +675,16 @@ public class RankProfile implements Cloneable { checkNameCollisions(getFunctions(), getConstants()); ExpressionTransforms expressionTransforms = new ExpressionTransforms(); + Map<Reference, TensorType> featureTypes = collectFeatureTypes(); // Function compiling first pass: compile inline functions without resolving other functions Map<String, RankingExpressionFunction> inlineFunctions = - compileFunctions(this::getInlineFunctions, queryProfiles, importedModels, Collections.emptyMap(), expressionTransforms); + compileFunctions(this::getInlineFunctions, queryProfiles, featureTypes, importedModels, Collections.emptyMap(), expressionTransforms); // Function compiling second pass: compile all functions and insert previously compiled inline functions - functions = compileFunctions(this::getFunctions, queryProfiles, importedModels, inlineFunctions, expressionTransforms); + functions = compileFunctions(this::getFunctions, queryProfiles, featureTypes, importedModels, inlineFunctions, expressionTransforms); - firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, importedModels, getConstants(), inlineFunctions, expressionTransforms); - secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, importedModels, getConstants(), inlineFunctions, expressionTransforms); + firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); + secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); } private void checkNameCollisions(Map<String, RankingExpressionFunction> functions, Map<String, Value> constants) { @@ -701,6 +702,7 @@ public class RankProfile implements Cloneable { private Map<String, RankingExpressionFunction> compileFunctions(Supplier<Map<String, RankingExpressionFunction>> functions, QueryProfileRegistry queryProfiles, + Map<Reference, TensorType> featureTypes, ImportedMlModels importedModels, Map<String, RankingExpressionFunction> inlineFunctions, ExpressionTransforms expressionTransforms) { @@ -711,7 +713,7 @@ public class RankProfile implements Cloneable { // A straightforward iteration will either miss those functions, or may cause a ConcurrentModificationException while (null != (entry = findUncompiledFunction(functions.get(), compiledFunctions.keySet()))) { RankingExpressionFunction rankingExpressionFunction = entry.getValue(); - RankingExpression compiled = compile(rankingExpressionFunction.function().getBody(), queryProfiles, + RankingExpression compiled = compile(rankingExpressionFunction.function().getBody(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); compiledFunctions.put(entry.getKey(), rankingExpressionFunction.withExpression(compiled)); } @@ -729,6 +731,7 @@ public class RankProfile implements Cloneable { private RankingExpression compile(RankingExpression expression, QueryProfileRegistry queryProfiles, + Map<Reference, TensorType> featureTypes, ImportedMlModels importedModels, Map<String, Value> constants, Map<String, RankingExpressionFunction> inlineFunctions, @@ -736,6 +739,7 @@ public class RankProfile implements Cloneable { if (expression == null) return null; RankProfileTransformContext context = new RankProfileTransformContext(this, queryProfiles, + featureTypes, importedModels, constants, inlineFunctions); @@ -751,18 +755,28 @@ public class RankProfile implements Cloneable { * referable from this rank profile. */ public MapEvaluationTypeContext typeContext(QueryProfileRegistry queryProfiles) { + + return typeContext(queryProfiles, collectFeatureTypes()); + } + + private Map<Reference, TensorType> collectFeatureTypes() { + Map<Reference, TensorType> featureTypes = new HashMap<>(); + // Add attributes + allFields().forEach(field -> addAttributeFeatureTypes(field, featureTypes)); + allImportedFields().forEach(field -> addAttributeFeatureTypes(field, featureTypes)); + return featureTypes; + } + + public MapEvaluationTypeContext typeContext(QueryProfileRegistry queryProfiles, Map<Reference, TensorType> featureTypes) { MapEvaluationTypeContext context = new MapEvaluationTypeContext(getFunctions().values().stream() .map(RankingExpressionFunction::function) - .collect(Collectors.toList())); + .collect(Collectors.toList()), + featureTypes); // Add small and large constants, respectively getConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.type())); rankingConstants().asMap().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.getTensorType())); - // Add attributes - allFields().forEach(field -> addAttributeFeatureTypes(field, context)); - allImportedFields().forEach(field -> addAttributeFeatureTypes(field, context)); - // Add query features from rank profile types reached from the "default" profile for (QueryProfileType queryProfileType : queryProfiles.getTypeRegistry().allComponents()) { for (FieldDescription field : queryProfileType.declaredFields().values()) { @@ -785,13 +799,13 @@ public class RankProfile implements Cloneable { return context; } - private void addAttributeFeatureTypes(ImmutableSDField field, MapEvaluationTypeContext context) { + private void addAttributeFeatureTypes(ImmutableSDField field, Map<Reference, TensorType> featureTypes) { Attribute attribute = field.getAttribute(); field.getAttributes().forEach((k, a) -> { String name = k; if (attribute == a) // this attribute should take the fields name name = field.getName(); // switch to that - it is separate for imported fields - context.setType(FeatureNames.asAttributeFeature(name), + featureTypes.put(FeatureNames.asAttributeFeature(name), a.tensorType().orElse(TensorType.empty)); }); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java index c76b8536ea0..a12b06624cf 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java @@ -4,8 +4,10 @@ package com.yahoo.searchdefinition.expressiontransforms; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.transform.TransformContext; +import com.yahoo.tensor.TensorType; import java.util.HashMap; import java.util.Map; @@ -25,10 +27,11 @@ public class RankProfileTransformContext extends TransformContext { public RankProfileTransformContext(RankProfile rankProfile, QueryProfileRegistry queryProfiles, + Map<Reference, TensorType> featureTypes, ImportedMlModels importedModels, Map<String, Value> constants, Map<String, RankProfile.RankingExpressionFunction> inlineFunctions) { - super(constants, rankProfile.typeContext(queryProfiles)); + super(constants, rankProfile.typeContext(queryProfiles, featureTypes)); this.rankProfile = rankProfile; this.queryProfiles = queryProfiles; this.importedModels = importedModels; |