From b34ab2a5656401edd97ea70137a4be31406fb719 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Tue, 26 Mar 2019 15:39:52 +0200 Subject: Optimize type resolving - Cache reference hash code - Cache resolved types --- .../searchdefinition/MapEvaluationTypeContext.java | 20 +++++++++++++++++++- .../java/com/yahoo/searchdefinition/RankProfile.java | 2 +- .../com/yahoo/vespa/model/ml/ConvertedModel.java | 8 +++++--- .../searchdefinition/derived/GeminiTestCase.java | 4 ++-- 4 files changed, 27 insertions(+), 7 deletions(-) (limited to 'config-model') diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 0d9ea00bf73..a0f35dbefe6 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -37,6 +37,8 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement private final Map featureTypes = new HashMap<>(); + private final Map resolvedTypes = new HashMap<>(); + /** For invocation loop detection */ private final Deque currentResolutionCallStack; @@ -63,8 +65,24 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement throw new UnsupportedOperationException("Not able to parse gereral references from string form"); } + public void forgetResolvedTypes() { + resolvedTypes.clear(); + } + @Override public TensorType getType(Reference reference) { + // computeIfAbsent without concurrent modification due to resolve adding more resolved entries: + TensorType resolvedType = resolvedTypes.get(reference); + if (resolvedType != null) return resolvedType; + + resolvedType = resolveType(reference); + if (resolvedType == null) + return defaultTypeOf(reference); // Don't store fallback to default as we may know more later + resolvedTypes.put(reference, resolvedType); + return resolvedType; + } + + private TensorType resolveType(Reference reference) { if (currentResolutionCallStack.contains(reference)) throw new IllegalArgumentException("Invocation loop: " + currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) + @@ -90,7 +108,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement // The argument may be a local identifier bound to the actual value String argument = reference.simpleArgument().get(); reference = Reference.simple(reference.name(), bindings.getOrDefault(argument, argument)); - return featureTypes.getOrDefault(reference, defaultTypeOf(reference)); + return featureTypes.get(reference); } // A reference to a function? diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index bc49c40e4e1..b3853b36aa5 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -738,7 +738,7 @@ public class RankProfile implements Serializable, Cloneable { * Creates a context containing the type information of all constants, attributes and query profiles * referable from this rank profile. */ - public TypeContext typeContext(QueryProfileRegistry queryProfiles) { + public MapEvaluationTypeContext typeContext(QueryProfileRegistry queryProfiles) { MapEvaluationTypeContext context = new MapEvaluationTypeContext(getFunctions().values().stream() .map(RankingExpressionFunction::function) .collect(Collectors.toList())); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 93848c067e0..f197e2dfe6d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -12,6 +12,7 @@ import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.FeatureNames; +import com.yahoo.searchdefinition.MapEvaluationTypeContext; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext; @@ -371,7 +372,7 @@ public class ConvertedModel { */ private static void reduceBatchDimensions(RankingExpression expression, ImportedMlModel model, RankProfile profile, QueryProfileRegistry queryProfiles) { - TypeContext typeContext = profile.typeContext(queryProfiles); + MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles); TensorType typeBeforeReducing = expression.getRoot().type(typeContext); // Check generated functions for inputs to reduce @@ -398,7 +399,7 @@ public class ConvertedModel { } private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model, - TypeContext typeContext) { + MapEvaluationTypeContext typeContext) { if (node instanceof TensorFunctionNode) { TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); if (tensorFunction instanceof Rename) { @@ -428,7 +429,7 @@ public class ConvertedModel { return node; } - private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext context) { + private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, MapEvaluationTypeContext context) { TensorFunction result = function; TensorType type = function.type(context); if (type.dimensions().size() > 1) { @@ -440,6 +441,7 @@ public class ConvertedModel { } if (reduceDimensions.size() > 0) { result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); + context.forgetResolvedTypes(); // We changed types } } return new TensorFunctionNode(result); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java index 4bc61f20d95..992e52a9e5b 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java @@ -23,7 +23,7 @@ public class GeminiTestCase extends AbstractExportingTestCase { Map ranking = removePartKeySuffixes(asMap(p.configProperties())); assertEquals("attribute(right)", resolve(lookup("toplevel", ranking), ranking)); } - + private Map asMap(List> properties) { Map map = new HashMap<>(); for (Pair property : properties) @@ -45,7 +45,7 @@ public class GeminiTestCase extends AbstractExportingTestCase { } /** - * Recurively resolves references to other ranking expressions - rankingExpression(name) - + * Recursively resolves references to other ranking expressions - rankingExpression(name) - * and replaces the reference by the expression */ private String resolve(String expression, Map ranking) { -- cgit v1.2.3