summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-10-01 06:08:08 +0200
committerJon Bratseth <bratseth@oath.com>2018-10-01 06:08:08 +0200
commit0ff988ecf9704faac33f6201cb59349e48846457 (patch)
tree0cceb9c6961836a7b6149798d041e341bedcf903
parent9c80048457caab3881f3319aadd0990f65c04937 (diff)
Resoløve return types whenever possible
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java7
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java5
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java15
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java12
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java7
-rw-r--r--model-evaluation/src/test/resources/config/models/rank-profiles.cfg6
6 files changed, 43 insertions, 9 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java
index 4c8b5910b78..3d1ef48c9dd 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java
@@ -6,6 +6,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -60,7 +61,7 @@ public class RankingExpressionTypeResolver extends Processor {
private void resolveTypesIn(RankProfile profile, boolean validate) {
TypeContext<Reference> context = profile.typeContext(queryProfiles);
for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) {
- if ( ! function.getValue().function().arguments().isEmpty()) continue;
+ if (hasUntypedArguments(function.getValue().function())) continue;
TensorType type = resolveType(function.getValue().function().getBody(),
"function '" + function.getKey() + "'",
context);
@@ -74,6 +75,10 @@ public class RankingExpressionTypeResolver extends Processor {
}
}
+ private boolean hasUntypedArguments(ExpressionFunction function) {
+ return function.arguments().size() > function.argumentTypes().size();
+ }
+
private TensorType resolveType(RankingExpression expression, String expressionDescription, TypeContext context) {
if (expression == null) return null;
return resolveType(expression.getRoot(), expressionDescription, context);
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index 7d4db9daeff..daae2dbc496 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -90,7 +90,7 @@ public class ModelEvaluationTest {
RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
cluster.getConfig(b);
RankProfilesConfig config = new RankProfilesConfig(b);
- System.out.println(config);
+ // System.out.println(config);
RankingConstantsConfig.Builder cb = new RankingConstantsConfig.Builder();
cluster.getConfig(cb);
@@ -147,7 +147,8 @@ public class ModelEvaluationTest {
"rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript: join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))\n" +
"rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).type: tensor(d3[300])\n" +
"rankingExpression(serving_default.y).rankingScript: join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))\n" +
- "rankingExpression(serving_default.y).x.type: tensor(d0[],d1[784])\n";
+ "rankingExpression(serving_default.y).x.type: tensor(d0[],d1[784])\n" +
+ "rankingExpression(serving_default.y).type: tensor(d1[10])\n";
private RankProfilesConfig.Rankprofile.Fef findProfile(String name, RankProfilesConfig config) {
for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
index 5bb22b23345..fa45920f3c8 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
@@ -28,6 +28,8 @@ class FunctionReference {
Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?");
private static final Pattern argumentTypePattern =
Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.([a-zA-Z0-9_]+)\\.type?");
+ private static final Pattern returnTypePattern =
+ Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.type?");
/** The name of the function referenced */
private final String name;
@@ -92,6 +94,19 @@ class FunctionReference {
return Optional.of(new Pair<>(new FunctionReference(name, instance), argument));
}
+ /**
+ * Returns a function reference from the given return type serial form,
+ * or empty if the string is not a valid function return typoe serial form
+ */
+ static Optional<FunctionReference> fromReturnTypeSerial(String serialForm) {
+ Matcher expressionMatcher = returnTypePattern.matcher(serialForm);
+ if ( ! expressionMatcher.matches()) return Optional.empty();
+
+ String name = expressionMatcher.group(1);
+ String instance = expressionMatcher.group(2);
+ return Optional.of(new FunctionReference(name, instance));
+ }
+
public static FunctionReference fromName(String name) {
return new FunctionReference(name, null);
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index f48d76e86f3..648c6d931a9 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -20,6 +20,7 @@ import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collection;
+import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
@@ -74,10 +75,10 @@ public class RankProfilesConfigImporter {
for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) {
Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name());
Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name());
+ Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name());
if ( reference.isPresent()) {
- List<String> arguments = new ArrayList<>(); // TODO: Arguments?
RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value());
- ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), arguments, expression);
+ ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), Collections.emptyList(), expression);
if (reference.get().isFree()) // make available in model under configured name
functions.put(reference.get(), function);
@@ -92,6 +93,13 @@ public class RankProfilesConfigImporter {
functions.put(argReference, function);
referencedFunctions.put(argReference, function);
}
+ else if (returnType.isPresent()) { // Return type always follows the function in properties
+ ExpressionFunction function = referencedFunctions.get(returnType.get());
+ function = function.withReturnType(TensorType.fromSpec(property.value()));
+ if (returnType.get().isFree())
+ functions.put(returnType.get(), function);
+ referencedFunctions.put(returnType.get(), function);
+ }
else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to functions
firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(),
new RankingExpression("first-phase", property.value()));
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index 40ef2c65aaa..287a2387b34 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -33,7 +33,6 @@ public class MlModelsImportingTest {
"(optimized sum of condition trees of size 192 bytes)",
xgboost);
-
// Function
assertEquals(1, xgboost.functions().size());
ExpressionFunction function = xgboost.functions().get(0);
@@ -58,7 +57,7 @@ public class MlModelsImportingTest {
// Function
assertEquals(1, onnxMnistSoftmax.functions().size());
ExpressionFunction function = onnxMnistSoftmax.functions().get(0);
- // assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); TODO
+ assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
assertEquals(1, function.arguments().size());
assertEquals("Placeholder", function.arguments().get(0));
assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder"));
@@ -78,7 +77,7 @@ public class MlModelsImportingTest {
// Function
assertEquals(1, tfMnistSoftmax.functions().size());
ExpressionFunction function = tfMnistSoftmax.functions().get(0);
- // assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); TODO
+ assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
assertEquals(1, function.arguments().size());
assertEquals("x", function.arguments().get(0));
assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("x"));
@@ -103,7 +102,7 @@ public class MlModelsImportingTest {
// Function
assertEquals(2, tfMnist.functions().size()); // TODO: Filter out generated function
ExpressionFunction function = tfMnist.functions().get(1);
- // assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); TODO
+ assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
assertEquals(1, function.arguments().size());
assertEquals("x", function.arguments().get(0));
assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("x"));
diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
index 7980d157193..9175b60315b 100644
--- a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
+++ b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
@@ -3,6 +3,8 @@ rankprofile[0].fef.property[0].name "rankingExpression(default.add).rankingScrip
rankprofile[0].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))"
rankprofile[0].fef.property[1].name "rankingExpression(default.add).Placeholder.type"
rankprofile[0].fef.property[1].value "tensor(d0[],d1[784])"
+rankprofile[0].fef.property[2].name "rankingExpression(default.add).type"
+rankprofile[0].fef.property[2].value "tensor(d1[10])"
rankprofile[1].name "xgboost_2_2"
rankprofile[1].fef.property[0].name "rankingExpression(xgboost_2_2).rankingScript"
rankprofile[1].fef.property[0].value "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)"
@@ -11,6 +13,8 @@ rankprofile[2].fef.property[0].name "rankingExpression(serving_default.y).rankin
rankprofile[2].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))"
rankprofile[2].fef.property[1].name "rankingExpression(serving_default.y).x.type"
rankprofile[2].fef.property[1].value "tensor(d0[],d1[784])"
+rankprofile[2].fef.property[2].name "rankingExpression(serving_default.y).type"
+rankprofile[2].fef.property[2].value "tensor(d1[10])"
rankprofile[3].name "mnist_saved"
rankprofile[3].fef.property[0].name "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript"
rankprofile[3].fef.property[0].value "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))"
@@ -20,3 +24,5 @@ rankprofile[3].fef.property[2].name "rankingExpression(serving_default.y).rankin
rankprofile[3].fef.property[2].value "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))"
rankprofile[3].fef.property[3].name "rankingExpression(serving_default.y).x.type"
rankprofile[3].fef.property[3].value "tensor(d0[],d1[784])"
+rankprofile[3].fef.property[4].name "rankingExpression(serving_default.y).type"
+rankprofile[3].fef.property[4].value "tensor(d1[10])" \ No newline at end of file