From 6a9681d7f3e42f29bd1d9de9fe9c271489b0c886 Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Thu, 1 Dec 2022 22:45:55 +0100 Subject: Use well defined order where we output text and generate config. Makes config stable and simple tests predictable. --- .../ai/vespa/models/evaluation/FunctionEvaluator.java | 16 +++++++--------- .../src/main/java/ai/vespa/models/evaluation/Model.java | 2 +- .../java/ai/vespa/models/evaluation/ModelsEvaluator.java | 3 ++- .../models/evaluation/RankProfilesConfigImporter.java | 3 ++- .../vespa/models/evaluation/MlModelsImportingTest.java | 12 ++++++------ .../ai/vespa/models/evaluation/ModelsEvaluatorTest.java | 2 +- 6 files changed, 19 insertions(+), 19 deletions(-) (limited to 'model-evaluation') diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 1d3da73a509..3fada4c8b6d 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -38,12 +38,13 @@ public class FunctionEvaluator { public FunctionEvaluator bind(String name, Tensor value) { if (evaluated) throw new IllegalStateException("Cannot bind a new value in a used evaluator"); - TensorType requiredType = function.argumentTypes().get(name); + TensorType requiredType = function.getArgumentType(name); if (requiredType == null) throw new IllegalArgumentException("'" + name + "' is not a valid argument in " + function + - ". Expected arguments: " + function.argumentTypes().entrySet().stream() - .map(e -> e.getKey() + ": " + e.getValue()) - .collect(Collectors.joining(", "))); + ". Expected arguments: " + + function.argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey()) + .map(e -> e.getKey() + ": " + e.getValue()) + .collect(Collectors.joining(", "))); if ( ! value.type().isAssignableTo(requiredType)) throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type()); context.put(name, new TensorValue(value)); @@ -101,11 +102,8 @@ public class FunctionEvaluator { } public Tensor evaluate() { - function.argumentTypes().keySet().stream().sorted() - .forEach(name -> { - var type = function.argumentTypes().get(name); - checkArgument(name, type); - }); + function.argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey()) + .forEach(argument -> checkArgument(argument.getKey(), argument.getValue())); evaluated = true; evaluateOnnxModels(); return function.getBody().evaluate(context).asTensor(); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index ffcfb5e9379..d66d0330ea6 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -84,7 +84,7 @@ public class Model { } else { // External functions have type info (when not scalar) - add argument types - if (function.getValue().argumentTypes().get(argument) == null) + if (function.getValue().getArgumentType(argument) == null) functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty)); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index 40a503e0212..28b613ca281 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -10,6 +10,7 @@ import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.config.search.core.RankingExpressionsConfig; +import java.util.Collections; import java.util.Map; /** @@ -42,7 +43,7 @@ public class ModelsEvaluator extends AbstractComponent { } public ModelsEvaluator(Map models) { - this.models = Map.copyOf(models); + this.models = Collections.unmodifiableMap(models); } /** Returns the models of this as an immutable map */ diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index c2cb1993fc0..83674d6789e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -32,6 +32,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.TreeMap; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -59,7 +60,7 @@ public class RankProfilesConfigImporter { RankingExpressionsConfig expressionsConfig, OnnxModelsConfig onnxModelsConfig) { try { - Map models = new HashMap<>(); + Map models = new TreeMap<>(); for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) { Model model = importProfile(profile, constantsConfig, expressionsConfig, onnxModelsConfig); models.put(model.name(), model); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java index 6a66202609b..d76bade6c1a 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -38,7 +38,7 @@ public class MlModelsImportingTest { xgboost); assertEquals("tensor()", function.returnType().get().toString()); assertEquals("f109, f29, f56, f60", commaSeparated(function.arguments())); - function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg))); + function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.getArgumentType(arg))); // Evaluator FunctionEvaluator evaluator = xgboost.evaluatorOf(); @@ -56,7 +56,7 @@ public class MlModelsImportingTest { lightgbm); assertEquals("tensor()", function.returnType().get().toString()); assertEquals("categorical_1, categorical_2, numerical_1, numerical_2", commaSeparated(function.arguments())); - function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg))); + function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.getArgumentType(arg))); // Evaluator FunctionEvaluator evaluator = lightgbm.evaluatorOf(); @@ -76,7 +76,7 @@ public class MlModelsImportingTest { assertEquals("tensor(d1[10])", function.returnType().get().toString()); assertEquals(1, function.arguments().size()); assertEquals("Placeholder", function.arguments().get(0)); - assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString()); + assertEquals("tensor(d0[],d1[784])", function.getArgumentType("Placeholder").toString()); // Evaluator assertEquals("tensor(d1[10],d2[784])", @@ -98,7 +98,7 @@ public class MlModelsImportingTest { assertEquals("tensor(d1[10])", function.returnType().get().toString()); assertEquals(1, function.arguments().size()); assertEquals("Placeholder", function.arguments().get(0)); - assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString()); + assertEquals("tensor(d0[],d1[784])", function.getArgumentType("Placeholder").toString()); // Evaluator FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available @@ -116,7 +116,7 @@ public class MlModelsImportingTest { assertEquals("tensor(d3[300])", generatedFunction.returnType().get().toString()); assertEquals(1, generatedFunction.arguments().size()); assertEquals("input", generatedFunction.arguments().get(0)); - assertNull(null, generatedFunction.argumentTypes().get("input")); // TODO: Not available until we resolve all argument types + assertNull(null, generatedFunction.getArgumentType("input")); // TODO: Not available until we resolve all argument types // Function assertEquals(1, tfMnist.functions().size()); @@ -127,7 +127,7 @@ public class MlModelsImportingTest { assertEquals("tensor(d1[10])", function.returnType().get().toString()); assertEquals(1, function.arguments().size()); assertEquals("input", function.arguments().get(0)); - assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("input").toString()); + assertEquals("tensor(d0[],d1[784])", function.getArgumentType("input").toString()); // Evaluator FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index f09bac63085..3cd04db8edd 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -104,7 +104,7 @@ public class ModelsEvaluatorTest { evaluator.evaluate(); } catch (IllegalArgumentException e) { - assertEquals("'argNone' is not a valid argument in function 'test'. Expected arguments: arg2: tensor(d1{}), arg1: tensor(d0[1])", + assertEquals("'argNone' is not a valid argument in function 'test'. Expected arguments: arg1: tensor(d0[1]), arg2: tensor(d1{})", Exceptions.toMessageString(e)); } -- cgit v1.2.3