summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-03-26 15:39:52 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-03-26 15:39:52 +0200
commitb34ab2a5656401edd97ea70137a4be31406fb719 (patch)
treeedb7441250486d01e3656ada858b98ddcf60eadc /config-model
parentdc46e712efefb2324869a1abf7baac198b33778e (diff)
Optimize type resolving
- Cache reference hash code - Cache resolved types
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java20
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java8
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java4
4 files changed, 27 insertions, 7 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index 0d9ea00bf73..a0f35dbefe6 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -37,6 +37,8 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
private final Map<Reference, TensorType> featureTypes = new HashMap<>();
+ private final Map<Reference, TensorType> resolvedTypes = new HashMap<>();
+
/** For invocation loop detection */
private final Deque<Reference> currentResolutionCallStack;
@@ -63,8 +65,24 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
throw new UnsupportedOperationException("Not able to parse gereral references from string form");
}
+ public void forgetResolvedTypes() {
+ resolvedTypes.clear();
+ }
+
@Override
public TensorType getType(Reference reference) {
+ // computeIfAbsent without concurrent modification due to resolve adding more resolved entries:
+ TensorType resolvedType = resolvedTypes.get(reference);
+ if (resolvedType != null) return resolvedType;
+
+ resolvedType = resolveType(reference);
+ if (resolvedType == null)
+ return defaultTypeOf(reference); // Don't store fallback to default as we may know more later
+ resolvedTypes.put(reference, resolvedType);
+ return resolvedType;
+ }
+
+ private TensorType resolveType(Reference reference) {
if (currentResolutionCallStack.contains(reference))
throw new IllegalArgumentException("Invocation loop: " +
currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) +
@@ -90,7 +108,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
// The argument may be a local identifier bound to the actual value
String argument = reference.simpleArgument().get();
reference = Reference.simple(reference.name(), bindings.getOrDefault(argument, argument));
- return featureTypes.getOrDefault(reference, defaultTypeOf(reference));
+ return featureTypes.get(reference);
}
// A reference to a function?
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index bc49c40e4e1..b3853b36aa5 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -738,7 +738,7 @@ public class RankProfile implements Serializable, Cloneable {
* Creates a context containing the type information of all constants, attributes and query profiles
* referable from this rank profile.
*/
- public TypeContext<Reference> typeContext(QueryProfileRegistry queryProfiles) {
+ public MapEvaluationTypeContext typeContext(QueryProfileRegistry queryProfiles) {
MapEvaluationTypeContext context = new MapEvaluationTypeContext(getFunctions().values().stream()
.map(RankingExpressionFunction::function)
.collect(Collectors.toList()));
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 93848c067e0..f197e2dfe6d 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -12,6 +12,7 @@ import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.FeatureNames;
+import com.yahoo.searchdefinition.MapEvaluationTypeContext;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext;
@@ -371,7 +372,7 @@ public class ConvertedModel {
*/
private static void reduceBatchDimensions(RankingExpression expression, ImportedMlModel model,
RankProfile profile, QueryProfileRegistry queryProfiles) {
- TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
+ MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles);
TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
// Check generated functions for inputs to reduce
@@ -398,7 +399,7 @@ public class ConvertedModel {
}
private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model,
- TypeContext<Reference> typeContext) {
+ MapEvaluationTypeContext typeContext) {
if (node instanceof TensorFunctionNode) {
TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
if (tensorFunction instanceof Rename) {
@@ -428,7 +429,7 @@ public class ConvertedModel {
return node;
}
- private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
+ private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, MapEvaluationTypeContext context) {
TensorFunction result = function;
TensorType type = function.type(context);
if (type.dimensions().size() > 1) {
@@ -440,6 +441,7 @@ public class ConvertedModel {
}
if (reduceDimensions.size() > 0) {
result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
+ context.forgetResolvedTypes(); // We changed types
}
}
return new TensorFunctionNode(result);
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java
index 4bc61f20d95..992e52a9e5b 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java
@@ -23,7 +23,7 @@ public class GeminiTestCase extends AbstractExportingTestCase {
Map<String, String> ranking = removePartKeySuffixes(asMap(p.configProperties()));
assertEquals("attribute(right)", resolve(lookup("toplevel", ranking), ranking));
}
-
+
private Map<String, String> asMap(List<Pair<String, String>> properties) {
Map<String, String> map = new HashMap<>();
for (Pair<String, String> property : properties)
@@ -45,7 +45,7 @@ public class GeminiTestCase extends AbstractExportingTestCase {
}
/**
- * Recurively resolves references to other ranking expressions - rankingExpression(name) -
+ * Recursively resolves references to other ranking expressions - rankingExpression(name) -
* and replaces the reference by the expression
*/
private String resolve(String expression, Map<String, String> ranking) {