diff options
6 files changed, 43 insertions, 9 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java index 4c8b5910b78..3d1ef48c9dd 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java @@ -6,6 +6,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -60,7 +61,7 @@ public class RankingExpressionTypeResolver extends Processor { private void resolveTypesIn(RankProfile profile, boolean validate) { TypeContext<Reference> context = profile.typeContext(queryProfiles); for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) { - if ( ! function.getValue().function().arguments().isEmpty()) continue; + if (hasUntypedArguments(function.getValue().function())) continue; TensorType type = resolveType(function.getValue().function().getBody(), "function '" + function.getKey() + "'", context); @@ -74,6 +75,10 @@ public class RankingExpressionTypeResolver extends Processor { } } + private boolean hasUntypedArguments(ExpressionFunction function) { + return function.arguments().size() > function.argumentTypes().size(); + } + private TensorType resolveType(RankingExpression expression, String expressionDescription, TypeContext context) { if (expression == null) return null; return resolveType(expression.getRoot(), expressionDescription, context); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index 7d4db9daeff..daae2dbc496 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -90,7 +90,7 @@ public class ModelEvaluationTest { RankProfilesConfig.Builder b = new RankProfilesConfig.Builder(); cluster.getConfig(b); RankProfilesConfig config = new RankProfilesConfig(b); - System.out.println(config); + // System.out.println(config); RankingConstantsConfig.Builder cb = new RankingConstantsConfig.Builder(); cluster.getConfig(cb); @@ -147,7 +147,8 @@ public class ModelEvaluationTest { "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript: join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))\n" + "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).type: tensor(d3[300])\n" + "rankingExpression(serving_default.y).rankingScript: join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))\n" + - "rankingExpression(serving_default.y).x.type: tensor(d0[],d1[784])\n"; + "rankingExpression(serving_default.y).x.type: tensor(d0[],d1[784])\n" + + "rankingExpression(serving_default.y).type: tensor(d1[10])\n"; private RankProfilesConfig.Rankprofile.Fef findProfile(String name, RankProfilesConfig config) { for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) { diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java index 5bb22b23345..fa45920f3c8 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java @@ -28,6 +28,8 @@ class FunctionReference { Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?"); private static final Pattern argumentTypePattern = Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.([a-zA-Z0-9_]+)\\.type?"); + private static final Pattern returnTypePattern = + Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.type?"); /** The name of the function referenced */ private final String name; @@ -92,6 +94,19 @@ class FunctionReference { return Optional.of(new Pair<>(new FunctionReference(name, instance), argument)); } + /** + * Returns a function reference from the given return type serial form, + * or empty if the string is not a valid function return typoe serial form + */ + static Optional<FunctionReference> fromReturnTypeSerial(String serialForm) { + Matcher expressionMatcher = returnTypePattern.matcher(serialForm); + if ( ! expressionMatcher.matches()) return Optional.empty(); + + String name = expressionMatcher.group(1); + String instance = expressionMatcher.group(2); + return Optional.of(new FunctionReference(name, instance)); + } + public static FunctionReference fromName(String name) { return new FunctionReference(name, null); } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index f48d76e86f3..648c6d931a9 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; @@ -74,10 +75,10 @@ public class RankProfilesConfigImporter { for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) { Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name()); Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name()); + Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name()); if ( reference.isPresent()) { - List<String> arguments = new ArrayList<>(); // TODO: Arguments? RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value()); - ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), arguments, expression); + ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), Collections.emptyList(), expression); if (reference.get().isFree()) // make available in model under configured name functions.put(reference.get(), function); @@ -92,6 +93,13 @@ public class RankProfilesConfigImporter { functions.put(argReference, function); referencedFunctions.put(argReference, function); } + else if (returnType.isPresent()) { // Return type always follows the function in properties + ExpressionFunction function = referencedFunctions.get(returnType.get()); + function = function.withReturnType(TensorType.fromSpec(property.value())); + if (returnType.get().isFree()) + functions.put(returnType.get(), function); + referencedFunctions.put(returnType.get(), function); + } else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to functions firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(), new RankingExpression("first-phase", property.value())); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java index 40ef2c65aaa..287a2387b34 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -33,7 +33,6 @@ public class MlModelsImportingTest { "(optimized sum of condition trees of size 192 bytes)", xgboost); - // Function assertEquals(1, xgboost.functions().size()); ExpressionFunction function = xgboost.functions().get(0); @@ -58,7 +57,7 @@ public class MlModelsImportingTest { // Function assertEquals(1, onnxMnistSoftmax.functions().size()); ExpressionFunction function = onnxMnistSoftmax.functions().get(0); - // assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); TODO + assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); assertEquals(1, function.arguments().size()); assertEquals("Placeholder", function.arguments().get(0)); assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder")); @@ -78,7 +77,7 @@ public class MlModelsImportingTest { // Function assertEquals(1, tfMnistSoftmax.functions().size()); ExpressionFunction function = tfMnistSoftmax.functions().get(0); - // assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); TODO + assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); assertEquals(1, function.arguments().size()); assertEquals("x", function.arguments().get(0)); assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("x")); @@ -103,7 +102,7 @@ public class MlModelsImportingTest { // Function assertEquals(2, tfMnist.functions().size()); // TODO: Filter out generated function ExpressionFunction function = tfMnist.functions().get(1); - // assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); TODO + assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); assertEquals(1, function.arguments().size()); assertEquals("x", function.arguments().get(0)); assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("x")); diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg index 7980d157193..9175b60315b 100644 --- a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg +++ b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg @@ -3,6 +3,8 @@ rankprofile[0].fef.property[0].name "rankingExpression(default.add).rankingScrip rankprofile[0].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))" rankprofile[0].fef.property[1].name "rankingExpression(default.add).Placeholder.type" rankprofile[0].fef.property[1].value "tensor(d0[],d1[784])" +rankprofile[0].fef.property[2].name "rankingExpression(default.add).type" +rankprofile[0].fef.property[2].value "tensor(d1[10])" rankprofile[1].name "xgboost_2_2" rankprofile[1].fef.property[0].name "rankingExpression(xgboost_2_2).rankingScript" rankprofile[1].fef.property[0].value "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)" @@ -11,6 +13,8 @@ rankprofile[2].fef.property[0].name "rankingExpression(serving_default.y).rankin rankprofile[2].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))" rankprofile[2].fef.property[1].name "rankingExpression(serving_default.y).x.type" rankprofile[2].fef.property[1].value "tensor(d0[],d1[784])" +rankprofile[2].fef.property[2].name "rankingExpression(serving_default.y).type" +rankprofile[2].fef.property[2].value "tensor(d1[10])" rankprofile[3].name "mnist_saved" rankprofile[3].fef.property[0].name "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript" rankprofile[3].fef.property[0].value "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))" @@ -20,3 +24,5 @@ rankprofile[3].fef.property[2].name "rankingExpression(serving_default.y).rankin rankprofile[3].fef.property[2].value "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))" rankprofile[3].fef.property[3].name "rankingExpression(serving_default.y).x.type" rankprofile[3].fef.property[3].value "tensor(d0[],d1[784])" +rankprofile[3].fef.property[4].name "rankingExpression(serving_default.y).type" +rankprofile[3].fef.property[4].value "tensor(d1[10])"
\ No newline at end of file |