summaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2022-12-01 22:45:55 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2022-12-01 22:45:55 +0100
commit6a9681d7f3e42f29bd1d9de9fe9c271489b0c886 (patch)
tree227e5c3f2e1fcae248e0daf85735358b06971e33 /model-evaluation
parent1eb22cc4a24973f52b344c3033cff394c724cbe4 (diff)
Use well defined order where we output text and generate config. Makes config stable and simple tests predictable.
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java16
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java3
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java3
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java12
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java2
6 files changed, 19 insertions, 19 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index 1d3da73a509..3fada4c8b6d 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -38,12 +38,13 @@ public class FunctionEvaluator {
public FunctionEvaluator bind(String name, Tensor value) {
if (evaluated)
throw new IllegalStateException("Cannot bind a new value in a used evaluator");
- TensorType requiredType = function.argumentTypes().get(name);
+ TensorType requiredType = function.getArgumentType(name);
if (requiredType == null)
throw new IllegalArgumentException("'" + name + "' is not a valid argument in " + function +
- ". Expected arguments: " + function.argumentTypes().entrySet().stream()
- .map(e -> e.getKey() + ": " + e.getValue())
- .collect(Collectors.joining(", ")));
+ ". Expected arguments: " +
+ function.argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey())
+ .map(e -> e.getKey() + ": " + e.getValue())
+ .collect(Collectors.joining(", ")));
if ( ! value.type().isAssignableTo(requiredType))
throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type());
context.put(name, new TensorValue(value));
@@ -101,11 +102,8 @@ public class FunctionEvaluator {
}
public Tensor evaluate() {
- function.argumentTypes().keySet().stream().sorted()
- .forEach(name -> {
- var type = function.argumentTypes().get(name);
- checkArgument(name, type);
- });
+ function.argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey())
+ .forEach(argument -> checkArgument(argument.getKey(), argument.getValue()));
evaluated = true;
evaluateOnnxModels();
return function.getBody().evaluate(context).asTensor();
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index ffcfb5e9379..d66d0330ea6 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -84,7 +84,7 @@ public class Model {
}
else {
// External functions have type info (when not scalar) - add argument types
- if (function.getValue().argumentTypes().get(argument) == null)
+ if (function.getValue().getArgumentType(argument) == null)
functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty));
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index 40a503e0212..28b613ca281 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -10,6 +10,7 @@ import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
+import java.util.Collections;
import java.util.Map;
/**
@@ -42,7 +43,7 @@ public class ModelsEvaluator extends AbstractComponent {
}
public ModelsEvaluator(Map<String, Model> models) {
- this.models = Map.copyOf(models);
+ this.models = Collections.unmodifiableMap(models);
}
/** Returns the models of this as an immutable map */
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index c2cb1993fc0..83674d6789e 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -32,6 +32,7 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.TreeMap;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -59,7 +60,7 @@ public class RankProfilesConfigImporter {
RankingExpressionsConfig expressionsConfig,
OnnxModelsConfig onnxModelsConfig) {
try {
- Map<String, Model> models = new HashMap<>();
+ Map<String, Model> models = new TreeMap<>();
for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) {
Model model = importProfile(profile, constantsConfig, expressionsConfig, onnxModelsConfig);
models.put(model.name(), model);
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index 6a66202609b..d76bade6c1a 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -38,7 +38,7 @@ public class MlModelsImportingTest {
xgboost);
assertEquals("tensor()", function.returnType().get().toString());
assertEquals("f109, f29, f56, f60", commaSeparated(function.arguments()));
- function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg)));
+ function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.getArgumentType(arg)));
// Evaluator
FunctionEvaluator evaluator = xgboost.evaluatorOf();
@@ -56,7 +56,7 @@ public class MlModelsImportingTest {
lightgbm);
assertEquals("tensor()", function.returnType().get().toString());
assertEquals("categorical_1, categorical_2, numerical_1, numerical_2", commaSeparated(function.arguments()));
- function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg)));
+ function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.getArgumentType(arg)));
// Evaluator
FunctionEvaluator evaluator = lightgbm.evaluatorOf();
@@ -76,7 +76,7 @@ public class MlModelsImportingTest {
assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("Placeholder", function.arguments().get(0));
- assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString());
+ assertEquals("tensor(d0[],d1[784])", function.getArgumentType("Placeholder").toString());
// Evaluator
assertEquals("tensor(d1[10],d2[784])",
@@ -98,7 +98,7 @@ public class MlModelsImportingTest {
assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("Placeholder", function.arguments().get(0));
- assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString());
+ assertEquals("tensor(d0[],d1[784])", function.getArgumentType("Placeholder").toString());
// Evaluator
FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available
@@ -116,7 +116,7 @@ public class MlModelsImportingTest {
assertEquals("tensor(d3[300])", generatedFunction.returnType().get().toString());
assertEquals(1, generatedFunction.arguments().size());
assertEquals("input", generatedFunction.arguments().get(0));
- assertNull(null, generatedFunction.argumentTypes().get("input")); // TODO: Not available until we resolve all argument types
+ assertNull(null, generatedFunction.getArgumentType("input")); // TODO: Not available until we resolve all argument types
// Function
assertEquals(1, tfMnist.functions().size());
@@ -127,7 +127,7 @@ public class MlModelsImportingTest {
assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("input", function.arguments().get(0));
- assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("input").toString());
+ assertEquals("tensor(d0[],d1[784])", function.getArgumentType("input").toString());
// Evaluator
FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default");
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
index f09bac63085..3cd04db8edd 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
@@ -104,7 +104,7 @@ public class ModelsEvaluatorTest {
evaluator.evaluate();
}
catch (IllegalArgumentException e) {
- assertEquals("'argNone' is not a valid argument in function 'test'. Expected arguments: arg2: tensor(d1{}), arg1: tensor(d0[1])",
+ assertEquals("'argNone' is not a valid argument in function 'test'. Expected arguments: arg1: tensor(d0[1]), arg2: tensor(d1{})",
Exceptions.toMessageString(e));
}