summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-04-22 06:24:30 +0000
committerArne Juul <arnej@verizonmedia.com>2021-04-22 06:24:30 +0000
commit7179bd7dae6aca37ae1aa061161da3d998c2644e (patch)
tree923d31bb3a4fe4550c5f04897493bcdb0ea91a01
parent4c46d92474745467a53eb53336fd4c5c162b2375 (diff)
Reapply "Arnej/evaluate bindings in parent context"
This reverts commit f1598d54afa672ec895330dba43a9f0fb5687587.
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java23
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java8
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolverTestCase.java68
3 files changed, 90 insertions, 9 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index bc98f0ab8c5..fc7362a43e1 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -39,6 +39,8 @@ import java.util.stream.Collectors;
*/
public class MapEvaluationTypeContext extends FunctionReferenceContext implements TypeContext<Reference> {
+ private final Optional<MapEvaluationTypeContext> parent;
+
private final Map<Reference, TensorType> featureTypes = new HashMap<>();
private final Map<Reference, TensorType> resolvedTypes = new HashMap<>();
@@ -54,6 +56,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
MapEvaluationTypeContext(Collection<ExpressionFunction> functions, Map<Reference, TensorType> featureTypes) {
super(functions);
+ this.parent = Optional.empty();
this.featureTypes.putAll(featureTypes);
this.currentResolutionCallStack = new ArrayDeque<>();
this.queryFeaturesNotDeclared = new TreeSet<>();
@@ -63,12 +66,14 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
private MapEvaluationTypeContext(Map<String, ExpressionFunction> functions,
Map<String, String> bindings,
+ Optional<MapEvaluationTypeContext> parent,
Map<Reference, TensorType> featureTypes,
Deque<Reference> currentResolutionCallStack,
SortedSet<Reference> queryFeaturesNotDeclared,
boolean tensorsAreUsed,
Map<Reference, TensorType> globallyResolvedTypes) {
super(functions, bindings);
+ this.parent = parent;
this.featureTypes.putAll(featureTypes);
this.currentResolutionCallStack = currentResolutionCallStack;
this.queryFeaturesNotDeclared = queryFeaturesNotDeclared;
@@ -130,13 +135,16 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) +
" -> " + reference);
- // Bound to a function argument, and not to a same-named identifier (which would lead to a loop)?
+ // Bound to a function argument?
Optional<String> binding = boundIdentifier(reference);
- if (binding.isPresent() && ! binding.get().equals(reference.toString())) {
+ if (binding.isPresent()) {
try {
// This is not pretty, but changing to bind expressions rather
// than their string values requires deeper changes
- return new RankingExpression(binding.get()).type(this);
+ var expr = new RankingExpression(binding.get());
+ var type = expr.type(parent.orElseThrow(
+ () -> new IllegalArgumentException("when a binding is present we must have a parent context")));
+ return type;
} catch (ParseException e) {
throw new IllegalArgumentException(e);
}
@@ -157,7 +165,10 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
// A reference to a function?
Optional<ExpressionFunction> function = functionInvocation(reference);
if (function.isPresent()) {
- return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments())));
+ var body = function.get().getBody();
+ var child = this.withBindings(bind(function.get().arguments(), reference.arguments()));
+ var type = body.type(child);
+ return type;
}
// A reference to an ONNX model?
@@ -297,8 +308,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
Map<String, String> bindings = new HashMap<>(formalArguments.size());
for (int i = 0; i < formalArguments.size(); i++) {
String identifier = invocationArguments.expressions().get(i).toString();
- String identifierBinding = super.getBinding(identifier);
- bindings.put(formalArguments.get(i), identifierBinding != null ? identifierBinding : identifier);
+ bindings.put(formalArguments.get(i), identifier);
}
return bindings;
}
@@ -323,6 +333,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
public MapEvaluationTypeContext withBindings(Map<String, String> bindings) {
return new MapEvaluationTypeContext(functions(),
bindings,
+ Optional.of(this),
featureTypes,
currentResolutionCallStack,
queryFeaturesNotDeclared,
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 5ee6ed02e61..b757259102b 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -238,7 +238,7 @@ public class ConvertedModel {
function.returnType().map(TensorType::fromSpec));
}
catch (ParseException e) {
- throw new IllegalArgumentException("Gor an illegal argument from importing " + function.name(), e);
+ throw new IllegalArgumentException("Got an illegal argument from importing " + function.name(), e);
}
}
@@ -260,8 +260,9 @@ public class ConvertedModel {
profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
for (RankingConstant constant : store.readLargeConstants()) {
- if ( ! profile.rankingConstants().asMap().containsKey(constant.getName()))
+ if ( ! profile.rankingConstants().asMap().containsKey(constant.getName())) {
profile.rankingConstants().add(constant);
+ }
}
for (Pair<String, RankingExpression> function : store.readFunctions()) {
@@ -320,7 +321,8 @@ public class ConvertedModel {
"\nwant to add " + expression + "\n");
return;
}
- profile.addFunction(new ExpressionFunction(functionName, expression), false); // TODO: Inline if only used once
+ var fun = new ExpressionFunction(functionName, expression);
+ profile.addFunction(fun, false); // TODO: Inline if only used once
}
/**
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolverTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolverTestCase.java
index 96f12a47a2f..71bfddf1419 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolverTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolverTestCase.java
@@ -195,6 +195,74 @@ public class RankingExpressionTypeResolverTestCase {
summaryFeatures(profile).get("return_b").type(profile.typeContext(builder.getQueryProfileRegistry())));
}
+
+ @Test
+ public void testTensorFunctionInvocationTypes_NestedSameName() throws Exception {
+ SearchBuilder builder = new SearchBuilder();
+ builder.importString(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[10],y[1]) {",
+ " indexing: attribute",
+ " }",
+ " field b type tensor(z[10]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " function return_a() {",
+ " expression: return_first(attribute(a), attribute(b))",
+ " }",
+ " function return_b() {",
+ " expression: return_second(attribute(a), attribute(b))",
+ " }",
+ " function return_first(e1, e2) {",
+ " expression: just_return(e1)",
+ " }",
+ " function just_return(e1) {",
+ " expression: e1",
+ " }",
+ " function return_second(e1, e2) {",
+ " expression: return_first(e2+0, e1)",
+ " }",
+ " summary-features {",
+ " return_a",
+ " return_b",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build();
+ RankProfile profile =
+ builder.getRankProfileRegistry().get(builder.getSearch(), "my_rank_profile");
+ assertEquals(TensorType.fromSpec("tensor(x[10],y[1])"),
+ summaryFeatures(profile).get("return_a").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ assertEquals(TensorType.fromSpec("tensor(z[10])"),
+ summaryFeatures(profile).get("return_b").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ }
+
+ @Test
+ public void testTensorFunctionInvocationTypes_viaFuncWithExpr() throws Exception {
+ SearchBuilder builder = new SearchBuilder();
+ builder.importString(joinLines(
+ "search test {",
+ " document test {",
+ " field t1 type tensor<float>(y{}) { indexing: attribute | summary }",
+ " field t2 type tensor<float>(x{}) { indexing: attribute | summary }",
+ " }",
+ " rank-profile test {",
+ " function my_func(t) { expression: sum(t, x) + 1 }",
+ " function test_func_via_func_with_expr() { expression: call_func_with_expr( attribute(t1), attribute(t2) ) }",
+ " function call_func_with_expr(a, b) { expression: my_func( a * b ) }",
+ " summary-features { test_func_via_func_with_expr }",
+ " }",
+ "}"));
+ builder.build();
+ RankProfile profile = builder.getRankProfileRegistry().get(builder.getSearch(), "test");
+ assertEquals(TensorType.fromSpec("tensor<float>(y{})"),
+ summaryFeatures(profile).get("test_func_via_func_with_expr").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ }
+
@Test
public void importedFieldsAreAvailable() throws Exception {
SearchBuilder builder = new SearchBuilder();