From 6a9681d7f3e42f29bd1d9de9fe9c271489b0c886 Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Thu, 1 Dec 2022 22:45:55 +0100 Subject: Use well defined order where we output text and generate config. Makes config stable and simple tests predictable. --- .../java/com/yahoo/schema/derived/RawRankProfile.java | 7 +++---- .../schema/processing/RankingExpressionTypeResolver.java | 2 +- .../yahoo/schema/processing/VespaMlModelTestCase.java | 7 ++++--- .../com/yahoo/vespa/model/ml/ModelEvaluationTest.java | 6 +++--- .../ai/vespa/models/evaluation/FunctionEvaluator.java | 16 +++++++--------- .../src/main/java/ai/vespa/models/evaluation/Model.java | 2 +- .../java/ai/vespa/models/evaluation/ModelsEvaluator.java | 3 ++- .../models/evaluation/RankProfilesConfigImporter.java | 3 ++- .../vespa/models/evaluation/MlModelsImportingTest.java | 12 ++++++------ .../ai/vespa/models/evaluation/ModelsEvaluatorTest.java | 2 +- searchlib/abi-spec.json | 1 + .../searchlib/rankingexpression/ExpressionFunction.java | 1 + 12 files changed, 32 insertions(+), 30 deletions(-) diff --git a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java index 59f4035f34f..14ee60bb9a6 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java @@ -20,7 +20,6 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; -import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.RankProfilesConfig; import java.nio.ByteBuffer; @@ -196,7 +195,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { rankProperties = new ArrayList<>(compiled.getRankProperties()); Map functions = compiled.getFunctions(); - List functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList()); + List functionExpressions = functions.values().stream().map(RankProfile.RankingExpressionFunction::function).collect(Collectors.toList()); Map functionProperties = new LinkedHashMap<>(); SerializationContext functionSerializationContext = new SerializationContext(functionExpressions, Map.of(), @@ -248,8 +247,8 @@ public class RawRankProfile implements RankProfilesConfig.Producer { String expressionString = e.getValue().function().getBody().getRoot().toString(context).toString(); context.addFunctionSerialization(propertyName, expressionString); - for (Map.Entry argumentType : e.getValue().function().argumentTypes().entrySet()) - context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue()); + e.getValue().function().argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey()) + .forEach(argumentType -> context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue())); if (e.getValue().function().returnType().isPresent()) context.addFunctionTypeSerialization(e.getKey(), e.getValue().function().returnType().get()); // else if (e.getValue().function().arguments().isEmpty()) TODO: Enable this check when we resolve all types diff --git a/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java b/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java index 3e7a1f7613b..871b79a7737 100644 --- a/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java +++ b/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java @@ -73,7 +73,7 @@ public class RankingExpressionTypeResolver extends Processor { for (String argument : expressionFunction.arguments()) { Reference ref = Reference.fromIdentifier(argument); if (context.getType(ref).equals(TensorType.empty)) { - context.setType(ref, expressionFunction.argumentTypes().get(argument)); + context.setType(ref, expressionFunction.getArgumentType(argument)); } } context.forgetResolvedTypes(); diff --git a/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java index dc72df9fc78..01e80e0f47a 100644 --- a/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java @@ -28,11 +28,12 @@ public class VespaMlModelTestCase { "constant(constant1).type : tensor(x[3])\n" + "constant(constant1).value : tensor(x[3]):[0.5, 1.5, 2.5]\n" + "rankingExpression(foo1).rankingScript : reduce(reduce(input1 * input2, sum, name) * constant(constant1), max, x) * 3.0\n" + - "rankingExpression(foo1).input2.type : tensor(x[3])\n" + "rankingExpression(foo1).input1.type : tensor(name{},x[3])\n" + + "rankingExpression(foo1).input2.type : tensor(x[3])\n" + "rankingExpression(foo2).rankingScript : reduce(reduce(input1 * input2, sum, name) * constant(constant1asLarge), max, x) * 3.0\n" + - "rankingExpression(foo2).input2.type : tensor(x[3])\n" + - "rankingExpression(foo2).input1.type : tensor(name{},x[3])\n"; + "rankingExpression(foo2).input1.type : tensor(name{},x[3])\n" + + "rankingExpression(foo2).input2.type : tensor(x[3])\n"; + /** The model name */ private final String name = "example"; diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index f4d37cc4b35..caf0d22d44e 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -153,8 +153,8 @@ public class ModelEvaluationTest { assertNotNull(evaluator.evaluatorOf("add_mul", "default.output2")); assertNotNull(evaluator.evaluatorOf("add_mul", "default", "output1")); assertNotNull(evaluator.evaluatorOf("add_mul", "default", "output2")); - assertEquals(TensorType.fromSpec("tensor(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input1")); - assertEquals(TensorType.fromSpec("tensor(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input2")); + assertEquals(TensorType.fromSpec("tensor(d0[1])"), add_mul.functions().get(0).getArgumentType("input1")); + assertEquals(TensorType.fromSpec("tensor(d0[1])"), add_mul.functions().get(0).getArgumentType("input2")); Model sqrt = evaluator.models().get("sqrt"); assertNotNull(sqrt); @@ -163,7 +163,7 @@ public class ModelEvaluationTest { assertNotNull(sqrt.evaluatorOf("out_layer_1_1")); // converted from "out/layer/1:1" assertNotNull(evaluator.evaluatorOf("sqrt")); assertNotNull(evaluator.evaluatorOf("sqrt", "out_layer_1_1")); - assertEquals(TensorType.fromSpec("tensor(d0[1])"), sqrt.functions().get(0).argumentTypes().get("input")); + assertEquals(TensorType.fromSpec("tensor(d0[1])"), sqrt.functions().get(0).getArgumentType("input")); } private final String profile = diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 1d3da73a509..3fada4c8b6d 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -38,12 +38,13 @@ public class FunctionEvaluator { public FunctionEvaluator bind(String name, Tensor value) { if (evaluated) throw new IllegalStateException("Cannot bind a new value in a used evaluator"); - TensorType requiredType = function.argumentTypes().get(name); + TensorType requiredType = function.getArgumentType(name); if (requiredType == null) throw new IllegalArgumentException("'" + name + "' is not a valid argument in " + function + - ". Expected arguments: " + function.argumentTypes().entrySet().stream() - .map(e -> e.getKey() + ": " + e.getValue()) - .collect(Collectors.joining(", "))); + ". Expected arguments: " + + function.argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey()) + .map(e -> e.getKey() + ": " + e.getValue()) + .collect(Collectors.joining(", "))); if ( ! value.type().isAssignableTo(requiredType)) throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type()); context.put(name, new TensorValue(value)); @@ -101,11 +102,8 @@ public class FunctionEvaluator { } public Tensor evaluate() { - function.argumentTypes().keySet().stream().sorted() - .forEach(name -> { - var type = function.argumentTypes().get(name); - checkArgument(name, type); - }); + function.argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey()) + .forEach(argument -> checkArgument(argument.getKey(), argument.getValue())); evaluated = true; evaluateOnnxModels(); return function.getBody().evaluate(context).asTensor(); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index ffcfb5e9379..d66d0330ea6 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -84,7 +84,7 @@ public class Model { } else { // External functions have type info (when not scalar) - add argument types - if (function.getValue().argumentTypes().get(argument) == null) + if (function.getValue().getArgumentType(argument) == null) functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty)); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index 40a503e0212..28b613ca281 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -10,6 +10,7 @@ import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.config.search.core.RankingExpressionsConfig; +import java.util.Collections; import java.util.Map; /** @@ -42,7 +43,7 @@ public class ModelsEvaluator extends AbstractComponent { } public ModelsEvaluator(Map models) { - this.models = Map.copyOf(models); + this.models = Collections.unmodifiableMap(models); } /** Returns the models of this as an immutable map */ diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index c2cb1993fc0..83674d6789e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -32,6 +32,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.TreeMap; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -59,7 +60,7 @@ public class RankProfilesConfigImporter { RankingExpressionsConfig expressionsConfig, OnnxModelsConfig onnxModelsConfig) { try { - Map models = new HashMap<>(); + Map models = new TreeMap<>(); for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) { Model model = importProfile(profile, constantsConfig, expressionsConfig, onnxModelsConfig); models.put(model.name(), model); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java index 6a66202609b..d76bade6c1a 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -38,7 +38,7 @@ public class MlModelsImportingTest { xgboost); assertEquals("tensor()", function.returnType().get().toString()); assertEquals("f109, f29, f56, f60", commaSeparated(function.arguments())); - function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg))); + function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.getArgumentType(arg))); // Evaluator FunctionEvaluator evaluator = xgboost.evaluatorOf(); @@ -56,7 +56,7 @@ public class MlModelsImportingTest { lightgbm); assertEquals("tensor()", function.returnType().get().toString()); assertEquals("categorical_1, categorical_2, numerical_1, numerical_2", commaSeparated(function.arguments())); - function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg))); + function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.getArgumentType(arg))); // Evaluator FunctionEvaluator evaluator = lightgbm.evaluatorOf(); @@ -76,7 +76,7 @@ public class MlModelsImportingTest { assertEquals("tensor(d1[10])", function.returnType().get().toString()); assertEquals(1, function.arguments().size()); assertEquals("Placeholder", function.arguments().get(0)); - assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString()); + assertEquals("tensor(d0[],d1[784])", function.getArgumentType("Placeholder").toString()); // Evaluator assertEquals("tensor(d1[10],d2[784])", @@ -98,7 +98,7 @@ public class MlModelsImportingTest { assertEquals("tensor(d1[10])", function.returnType().get().toString()); assertEquals(1, function.arguments().size()); assertEquals("Placeholder", function.arguments().get(0)); - assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString()); + assertEquals("tensor(d0[],d1[784])", function.getArgumentType("Placeholder").toString()); // Evaluator FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available @@ -116,7 +116,7 @@ public class MlModelsImportingTest { assertEquals("tensor(d3[300])", generatedFunction.returnType().get().toString()); assertEquals(1, generatedFunction.arguments().size()); assertEquals("input", generatedFunction.arguments().get(0)); - assertNull(null, generatedFunction.argumentTypes().get("input")); // TODO: Not available until we resolve all argument types + assertNull(null, generatedFunction.getArgumentType("input")); // TODO: Not available until we resolve all argument types // Function assertEquals(1, tfMnist.functions().size()); @@ -127,7 +127,7 @@ public class MlModelsImportingTest { assertEquals("tensor(d1[10])", function.returnType().get().toString()); assertEquals(1, function.arguments().size()); assertEquals("input", function.arguments().get(0)); - assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("input").toString()); + assertEquals("tensor(d0[],d1[784])", function.getArgumentType("input").toString()); // Evaluator FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index f09bac63085..3cd04db8edd 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -104,7 +104,7 @@ public class ModelsEvaluatorTest { evaluator.evaluate(); } catch (IllegalArgumentException e) { - assertEquals("'argNone' is not a valid argument in function 'test'. Expected arguments: arg2: tensor(d1{}), arg1: tensor(d0[1])", + assertEquals("'argNone' is not a valid argument in function 'test'. Expected arguments: arg1: tensor(d0[1]), arg2: tensor(d1{})", Exceptions.toMessageString(e)); } diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 16c1b2c0e7d..5413907e967 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -277,6 +277,7 @@ "public java.util.List arguments()", "public com.yahoo.searchlib.rankingexpression.RankingExpression getBody()", "public java.util.Map argumentTypes()", + "public com.yahoo.tensor.TensorType getArgumentType(java.lang.String)", "public java.util.Optional returnType()", "public com.yahoo.searchlib.rankingexpression.ExpressionFunction withName(java.lang.String)", "public com.yahoo.searchlib.rankingexpression.ExpressionFunction withBody(com.yahoo.searchlib.rankingexpression.RankingExpression)", diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index 241a53fb458..171151bfdf4 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -83,6 +83,7 @@ public class ExpressionFunction { /** Returns the types of the arguments of this, if specified. The keys of this may be any subset of the arguments */ public Map argumentTypes() { return argumentTypes; } + public TensorType getArgumentType(String argumentName) { return argumentTypes.get(argumentName); } /** Returns the return type of this, or empty if not specified */ public Optional returnType() { return returnType; } -- cgit v1.2.3