diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-05-19 12:03:06 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-05-19 12:03:06 +0200 |
commit | 5c24dc5c9642a8d9ed70aee4c950fd0678a1ebec (patch) | |
tree | bd9b74bf00c832456f0b83c1b2cd7010be387d68 /config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java | |
parent | f17c4fe7de4c55f5c4ee61897eab8c2f588d8405 (diff) |
Rename the 'searchdefinition' package to 'schema'
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java | 361 |
1 files changed, 361 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java new file mode 100644 index 00000000000..c6c807f2dbb --- /dev/null +++ b/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java @@ -0,0 +1,361 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.schema; + +import com.google.common.collect.ImmutableMap; +import com.yahoo.schema.expressiontransforms.OnnxModelTransformer; +import com.yahoo.schema.expressiontransforms.TokenTransformer; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext; +import com.yahoo.searchlib.rankingexpression.rule.NameNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.ArrayDeque; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.SortedSet; +import java.util.TreeSet; +import java.util.stream.Collectors; + +/** + * A context which only contains type information. + * This returns empty tensor types (double) for unknown features which are not + * query, attribute or constant features, as we do not have information about which such + * features exist (but we know those that exist are doubles). + * + * This is not multithread safe. + * + * @author bratseth + */ +public class MapEvaluationTypeContext extends FunctionReferenceContext implements TypeContext<Reference> { + + private final Optional<MapEvaluationTypeContext> parent; + + private final Map<Reference, TensorType> featureTypes = new HashMap<>(); + + private final Map<Reference, TensorType> resolvedTypes = new HashMap<>(); + + /** To avoid re-resolving diamond-shaped dependencies */ + private final Map<Reference, TensorType> globallyResolvedTypes; + + /** For invocation loop detection */ + private final Deque<Reference> currentResolutionCallStack; + + private final SortedSet<Reference> queryFeaturesNotDeclared; + private boolean tensorsAreUsed; + + MapEvaluationTypeContext(ImmutableMap<String, ExpressionFunction> functions, Map<Reference, TensorType> featureTypes) { + super(functions); + this.parent = Optional.empty(); + this.featureTypes.putAll(featureTypes); + this.currentResolutionCallStack = new ArrayDeque<>(); + this.queryFeaturesNotDeclared = new TreeSet<>(); + tensorsAreUsed = false; + globallyResolvedTypes = new HashMap<>(); + } + + private MapEvaluationTypeContext(Map<String, ExpressionFunction> functions, + Map<String, String> bindings, + Optional<MapEvaluationTypeContext> parent, + Map<Reference, TensorType> featureTypes, + Deque<Reference> currentResolutionCallStack, + SortedSet<Reference> queryFeaturesNotDeclared, + boolean tensorsAreUsed, + Map<Reference, TensorType> globallyResolvedTypes) { + super(functions, bindings); + this.parent = parent; + this.featureTypes.putAll(featureTypes); + this.currentResolutionCallStack = currentResolutionCallStack; + this.queryFeaturesNotDeclared = queryFeaturesNotDeclared; + this.tensorsAreUsed = tensorsAreUsed; + this.globallyResolvedTypes = globallyResolvedTypes; + } + + public void setType(Reference reference, TensorType type) { + featureTypes.put(reference, type); + queryFeaturesNotDeclared.remove(reference); + } + + public Map<Reference, TensorType> featureTypes() { return Collections.unmodifiableMap(featureTypes); } + + @Override + public TensorType getType(String reference) { + throw new UnsupportedOperationException("Not able to parse general references from string form"); + } + + public void forgetResolvedTypes() { + resolvedTypes.clear(); + } + + private boolean referenceCanBeResolvedGlobally(Reference reference) { + Optional<ExpressionFunction> function = functionInvocation(reference); + return function.isPresent() && function.get().arguments().size() == 0; + // are there other cases we would like to resolve globally? + } + + @Override + public TensorType getType(Reference reference) { + // computeIfAbsent without concurrent modification due to resolve adding more resolved entries: + boolean canBeResolvedGlobally = referenceCanBeResolvedGlobally(reference); + + TensorType resolvedType = resolvedTypes.get(reference); + if (resolvedType == null && canBeResolvedGlobally) { + resolvedType = globallyResolvedTypes.get(reference); + } + if (resolvedType != null) { + return resolvedType; + } + + resolvedType = resolveType(reference); + if (resolvedType == null) + return defaultTypeOf(reference); // Don't store fallback to default as we may know more later + resolvedTypes.put(reference, resolvedType); + if (resolvedType.rank() > 0) + tensorsAreUsed = true; + + if (canBeResolvedGlobally) { + globallyResolvedTypes.put(reference, resolvedType); + } + + return resolvedType; + } + + MapEvaluationTypeContext getParent(String forArgument, String boundTo) { + return parent.orElseThrow( + () -> new IllegalArgumentException("argument "+forArgument+" is bound to "+boundTo+" but there is no parent context")); + } + + String resolveBinding(String argument) { + String bound = getBinding(argument); + if (bound == null) { + return argument; + } + return getParent(argument, bound).resolveBinding(bound); + } + + private TensorType resolveType(Reference reference) { + if (currentResolutionCallStack.contains(reference)) + throw new IllegalArgumentException("Invocation loop: " + + currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) + + " -> " + reference); + + // Bound to a function argument? + Optional<String> binding = boundIdentifier(reference); + if (binding.isPresent()) { + try { + // This is not pretty, but changing to bind expressions rather + // than their string values requires deeper changes + var expr = new RankingExpression(binding.get()); + var type = expr.type(getParent(reference.name(), binding.get())); + return type; + } catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + + try { + currentResolutionCallStack.addLast(reference); + + // A reference to an attribute, query or constant feature? + if (FeatureNames.isSimpleFeature(reference)) { + // The argument may be a local identifier bound to the actual value + String argument = reference.simpleArgument().get(); + String argumentBinding = resolveBinding(argument); + reference = Reference.simple(reference.name(), argumentBinding); + return featureTypes.get(reference); + } + + // A reference to a function? + Optional<ExpressionFunction> function = functionInvocation(reference); + if (function.isPresent()) { + var body = function.get().getBody(); + var child = this.withBindings(bind(function.get().arguments(), reference.arguments())); + var type = body.type(child); + return type; + } + + // A reference to an ONNX model? + Optional<TensorType> onnxFeatureType = onnxFeatureType(reference); + if (onnxFeatureType.isPresent()) { + return onnxFeatureType.get(); + } + + // A reference to a feature for transformer token input? + Optional<TensorType> transformerTokensFeatureType = transformerTokensFeatureType(reference); + if (transformerTokensFeatureType.isPresent()) { + return transformerTokensFeatureType.get(); + } + + // A reference to a feature which returns a tensor? + Optional<TensorType> featureTensorType = tensorFeatureType(reference); + if (featureTensorType.isPresent()) { + return featureTensorType.get(); + } + + // A directly injected identifier? (Useful for stateless model evaluation) + if (reference.isIdentifier() && featureTypes.containsKey(reference)) { + return featureTypes.get(reference); + } + + // the name of a constant feature? + if (reference.isIdentifier()) { + Reference asConst = FeatureNames.asConstantFeature(reference.name()); + if (featureTypes.containsKey(asConst)) { + return featureTypes.get(asConst); + } + } + + // We do not know what this is - since we do not have complete knowledge about the match features + // in Java we must assume this is a match feature and return the double type - which is the type of + // all match features + return TensorType.empty; + } + finally { + currentResolutionCallStack.removeLast(); + } + } + + /** + * Returns the default type for this simple feature, or null if it does not have a default + */ + public TensorType defaultTypeOf(Reference reference) { + if ( ! FeatureNames.isSimpleFeature(reference)) + throw new IllegalArgumentException("This can only be called for simple references, not " + reference); + if (reference.name().equals("query")) { // we do not require all query features to be declared, only non-doubles + queryFeaturesNotDeclared.add(reference); + return TensorType.empty; + } + return null; + } + + /** + * Returns the binding if this reference is a simple identifier which is bound in this context. + * Returns empty otherwise. + */ + private Optional<String> boundIdentifier(Reference reference) { + if ( ! reference.arguments().isEmpty()) return Optional.empty(); + if ( reference.output() != null) return Optional.empty(); + return Optional.ofNullable(getBinding(reference.name())); + } + + private Optional<ExpressionFunction> functionInvocation(Reference reference) { + if (reference.output() != null) return Optional.empty(); + ExpressionFunction function = getFunctions().get(reference.name()); + if (function == null) return Optional.empty(); + if (function.arguments().size() != reference.arguments().size()) return Optional.empty(); + return Optional.of(function); + } + + private Optional<TensorType> onnxFeatureType(Reference reference) { + if ( ! reference.name().equals("onnxModel") && ! reference.name().equals("onnx")) + return Optional.empty(); + + if ( ! featureTypes.containsKey(reference)) { + String configOrFileName = reference.arguments().expressions().get(0).toString(); + + // Look up standardized format as added in RankProfile + String modelConfigName = OnnxModelTransformer.getModelConfigName(reference); + String modelOutput = OnnxModelTransformer.getModelOutput(reference, null); + + reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput); + if ( ! featureTypes.containsKey(reference)) { + throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'"); + } + } + + return Optional.of(featureTypes.get(reference)); + } + + private Optional<TensorType> transformerTokensFeatureType(Reference reference) { + if ( ! reference.name().equals("tokenTypeIds") && + ! reference.name().equals("tokenInputIds") && + ! reference.name().equals("tokenAttentionMask")) + return Optional.empty(); + + if ( ! (reference.arguments().size() > 1)) + throw new IllegalArgumentException(reference.name() + " must have at least 2 arguments"); + + ExpressionNode size = reference.arguments().expressions().get(0); + return Optional.of(TokenTransformer.createTensorType(reference.name(), size)); + } + + /** + * There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet. + * This returns the type of those features if this is a reference to either of them, or empty otherwise. + */ + private Optional<TensorType> tensorFeatureType(Reference reference) { + if ( ! reference.name().equals("tensorFromLabels") && ! reference.name().equals("tensorFromWeightedSet")) + return Optional.empty(); + + if (reference.arguments().size() != 1 && reference.arguments().size() != 2) + throw new IllegalArgumentException(reference.name() + " must have one or two arguments"); + + ExpressionNode arg0 = reference.arguments().expressions().get(0); + if ( ! ( arg0 instanceof ReferenceNode) || ! FeatureNames.isSimpleFeature(((ReferenceNode)arg0).reference())) + throw new IllegalArgumentException("The first argument of " + reference.name() + + " must be a simple feature, not " + arg0); + + String dimension; + if (reference.arguments().size() > 1) { + ExpressionNode arg1 = reference.arguments().expressions().get(1); + if ( ( ! (arg1 instanceof ReferenceNode) || ! (((ReferenceNode)arg1).reference().isIdentifier())) + && + ( ! (arg1 instanceof NameNode))) + throw new IllegalArgumentException("The second argument of " + reference.name() + + " must be a dimension name, not " + arg1); + dimension = reference.arguments().expressions().get(1).toString(); + } + else { // default + dimension = ((ReferenceNode)arg0).reference().arguments().expressions().get(0).toString(); + } + + // TODO: Determine the type of the weighted set/vector and use that as value type + return Optional.of(new TensorType.Builder().mapped(dimension).build()); + } + + /** Binds the given list of formal arguments to their actual values */ + private Map<String, String> bind(List<String> formalArguments, + Arguments invocationArguments) { + Map<String, String> bindings = new HashMap<>(formalArguments.size()); + for (int i = 0; i < formalArguments.size(); i++) { + String identifier = invocationArguments.expressions().get(i).toString(); + bindings.put(formalArguments.get(i), identifier); + } + return bindings; + } + + /** + * Returns an unmodifiable view of the query features which was requested but for which we have no type info + * (such that they default to TensorType.empty), shared between all instances of this + * involved in resolving a particular rank profile. + */ + public SortedSet<Reference> queryFeaturesNotDeclared() { + return Collections.unmodifiableSortedSet(queryFeaturesNotDeclared); + } + + /** Returns true if any feature across all instances involved in resolving this rank profile resolves to a tensor */ + public boolean tensorsAreUsed() { return tensorsAreUsed; } + + @Override + public MapEvaluationTypeContext withBindings(Map<String, String> bindings) { + return new MapEvaluationTypeContext(getFunctions(), + bindings, + Optional.of(this), + featureTypes, + currentResolutionCallStack, + queryFeaturesNotDeclared, + tensorsAreUsed, + globallyResolvedTypes); + } + +} |