summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2022-12-01 22:45:55 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2022-12-01 22:45:55 +0100
commit6a9681d7f3e42f29bd1d9de9fe9c271489b0c886 (patch)
tree227e5c3f2e1fcae248e0daf85735358b06971e33
parent1eb22cc4a24973f52b344c3033cff394c724cbe4 (diff)
Use well defined order where we output text and generate config. Makes config stable and simple tests predictable.
-rw-r--r--config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java7
-rw-r--r--config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java2
-rw-r--r--config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java7
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java6
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java16
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java3
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java3
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java12
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java2
-rw-r--r--searchlib/abi-spec.json1
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java1
12 files changed, 32 insertions, 30 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java
index 59f4035f34f..14ee60bb9a6 100644
--- a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java
@@ -20,7 +20,6 @@ import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
-import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import java.nio.ByteBuffer;
@@ -196,7 +195,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
rankProperties = new ArrayList<>(compiled.getRankProperties());
Map<String, RankProfile.RankingExpressionFunction> functions = compiled.getFunctions();
- List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList());
+ List<ExpressionFunction> functionExpressions = functions.values().stream().map(RankProfile.RankingExpressionFunction::function).collect(Collectors.toList());
Map<String, String> functionProperties = new LinkedHashMap<>();
SerializationContext functionSerializationContext = new SerializationContext(functionExpressions,
Map.of(),
@@ -248,8 +247,8 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
String expressionString = e.getValue().function().getBody().getRoot().toString(context).toString();
context.addFunctionSerialization(propertyName, expressionString);
- for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet())
- context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue());
+ e.getValue().function().argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey())
+ .forEach(argumentType -> context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue()));
if (e.getValue().function().returnType().isPresent())
context.addFunctionTypeSerialization(e.getKey(), e.getValue().function().returnType().get());
// else if (e.getValue().function().arguments().isEmpty()) TODO: Enable this check when we resolve all types
diff --git a/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java b/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java
index 3e7a1f7613b..871b79a7737 100644
--- a/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java
+++ b/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java
@@ -73,7 +73,7 @@ public class RankingExpressionTypeResolver extends Processor {
for (String argument : expressionFunction.arguments()) {
Reference ref = Reference.fromIdentifier(argument);
if (context.getType(ref).equals(TensorType.empty)) {
- context.setType(ref, expressionFunction.argumentTypes().get(argument));
+ context.setType(ref, expressionFunction.getArgumentType(argument));
}
}
context.forgetResolvedTypes();
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java
index dc72df9fc78..01e80e0f47a 100644
--- a/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java
@@ -28,11 +28,12 @@ public class VespaMlModelTestCase {
"constant(constant1).type : tensor(x[3])\n" +
"constant(constant1).value : tensor(x[3]):[0.5, 1.5, 2.5]\n" +
"rankingExpression(foo1).rankingScript : reduce(reduce(input1 * input2, sum, name) * constant(constant1), max, x) * 3.0\n" +
- "rankingExpression(foo1).input2.type : tensor(x[3])\n" +
"rankingExpression(foo1).input1.type : tensor(name{},x[3])\n" +
+ "rankingExpression(foo1).input2.type : tensor(x[3])\n" +
"rankingExpression(foo2).rankingScript : reduce(reduce(input1 * input2, sum, name) * constant(constant1asLarge), max, x) * 3.0\n" +
- "rankingExpression(foo2).input2.type : tensor(x[3])\n" +
- "rankingExpression(foo2).input1.type : tensor(name{},x[3])\n";
+ "rankingExpression(foo2).input1.type : tensor(name{},x[3])\n" +
+ "rankingExpression(foo2).input2.type : tensor(x[3])\n";
+
/** The model name */
private final String name = "example";
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index f4d37cc4b35..caf0d22d44e 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -153,8 +153,8 @@ public class ModelEvaluationTest {
assertNotNull(evaluator.evaluatorOf("add_mul", "default.output2"));
assertNotNull(evaluator.evaluatorOf("add_mul", "default", "output1"));
assertNotNull(evaluator.evaluatorOf("add_mul", "default", "output2"));
- assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input1"));
- assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input2"));
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).getArgumentType("input1"));
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).getArgumentType("input2"));
Model sqrt = evaluator.models().get("sqrt");
assertNotNull(sqrt);
@@ -163,7 +163,7 @@ public class ModelEvaluationTest {
assertNotNull(sqrt.evaluatorOf("out_layer_1_1")); // converted from "out/layer/1:1"
assertNotNull(evaluator.evaluatorOf("sqrt"));
assertNotNull(evaluator.evaluatorOf("sqrt", "out_layer_1_1"));
- assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), sqrt.functions().get(0).argumentTypes().get("input"));
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), sqrt.functions().get(0).getArgumentType("input"));
}
private final String profile =
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index 1d3da73a509..3fada4c8b6d 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -38,12 +38,13 @@ public class FunctionEvaluator {
public FunctionEvaluator bind(String name, Tensor value) {
if (evaluated)
throw new IllegalStateException("Cannot bind a new value in a used evaluator");
- TensorType requiredType = function.argumentTypes().get(name);
+ TensorType requiredType = function.getArgumentType(name);
if (requiredType == null)
throw new IllegalArgumentException("'" + name + "' is not a valid argument in " + function +
- ". Expected arguments: " + function.argumentTypes().entrySet().stream()
- .map(e -> e.getKey() + ": " + e.getValue())
- .collect(Collectors.joining(", ")));
+ ". Expected arguments: " +
+ function.argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey())
+ .map(e -> e.getKey() + ": " + e.getValue())
+ .collect(Collectors.joining(", ")));
if ( ! value.type().isAssignableTo(requiredType))
throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type());
context.put(name, new TensorValue(value));
@@ -101,11 +102,8 @@ public class FunctionEvaluator {
}
public Tensor evaluate() {
- function.argumentTypes().keySet().stream().sorted()
- .forEach(name -> {
- var type = function.argumentTypes().get(name);
- checkArgument(name, type);
- });
+ function.argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey())
+ .forEach(argument -> checkArgument(argument.getKey(), argument.getValue()));
evaluated = true;
evaluateOnnxModels();
return function.getBody().evaluate(context).asTensor();
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index ffcfb5e9379..d66d0330ea6 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -84,7 +84,7 @@ public class Model {
}
else {
// External functions have type info (when not scalar) - add argument types
- if (function.getValue().argumentTypes().get(argument) == null)
+ if (function.getValue().getArgumentType(argument) == null)
functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty));
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index 40a503e0212..28b613ca281 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -10,6 +10,7 @@ import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
+import java.util.Collections;
import java.util.Map;
/**
@@ -42,7 +43,7 @@ public class ModelsEvaluator extends AbstractComponent {
}
public ModelsEvaluator(Map<String, Model> models) {
- this.models = Map.copyOf(models);
+ this.models = Collections.unmodifiableMap(models);
}
/** Returns the models of this as an immutable map */
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index c2cb1993fc0..83674d6789e 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -32,6 +32,7 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.TreeMap;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -59,7 +60,7 @@ public class RankProfilesConfigImporter {
RankingExpressionsConfig expressionsConfig,
OnnxModelsConfig onnxModelsConfig) {
try {
- Map<String, Model> models = new HashMap<>();
+ Map<String, Model> models = new TreeMap<>();
for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) {
Model model = importProfile(profile, constantsConfig, expressionsConfig, onnxModelsConfig);
models.put(model.name(), model);
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index 6a66202609b..d76bade6c1a 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -38,7 +38,7 @@ public class MlModelsImportingTest {
xgboost);
assertEquals("tensor()", function.returnType().get().toString());
assertEquals("f109, f29, f56, f60", commaSeparated(function.arguments()));
- function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg)));
+ function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.getArgumentType(arg)));
// Evaluator
FunctionEvaluator evaluator = xgboost.evaluatorOf();
@@ -56,7 +56,7 @@ public class MlModelsImportingTest {
lightgbm);
assertEquals("tensor()", function.returnType().get().toString());
assertEquals("categorical_1, categorical_2, numerical_1, numerical_2", commaSeparated(function.arguments()));
- function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg)));
+ function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.getArgumentType(arg)));
// Evaluator
FunctionEvaluator evaluator = lightgbm.evaluatorOf();
@@ -76,7 +76,7 @@ public class MlModelsImportingTest {
assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("Placeholder", function.arguments().get(0));
- assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString());
+ assertEquals("tensor(d0[],d1[784])", function.getArgumentType("Placeholder").toString());
// Evaluator
assertEquals("tensor(d1[10],d2[784])",
@@ -98,7 +98,7 @@ public class MlModelsImportingTest {
assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("Placeholder", function.arguments().get(0));
- assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString());
+ assertEquals("tensor(d0[],d1[784])", function.getArgumentType("Placeholder").toString());
// Evaluator
FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available
@@ -116,7 +116,7 @@ public class MlModelsImportingTest {
assertEquals("tensor(d3[300])", generatedFunction.returnType().get().toString());
assertEquals(1, generatedFunction.arguments().size());
assertEquals("input", generatedFunction.arguments().get(0));
- assertNull(null, generatedFunction.argumentTypes().get("input")); // TODO: Not available until we resolve all argument types
+ assertNull(null, generatedFunction.getArgumentType("input")); // TODO: Not available until we resolve all argument types
// Function
assertEquals(1, tfMnist.functions().size());
@@ -127,7 +127,7 @@ public class MlModelsImportingTest {
assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("input", function.arguments().get(0));
- assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("input").toString());
+ assertEquals("tensor(d0[],d1[784])", function.getArgumentType("input").toString());
// Evaluator
FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default");
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
index f09bac63085..3cd04db8edd 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
@@ -104,7 +104,7 @@ public class ModelsEvaluatorTest {
evaluator.evaluate();
}
catch (IllegalArgumentException e) {
- assertEquals("'argNone' is not a valid argument in function 'test'. Expected arguments: arg2: tensor(d1{}), arg1: tensor(d0[1])",
+ assertEquals("'argNone' is not a valid argument in function 'test'. Expected arguments: arg1: tensor(d0[1]), arg2: tensor(d1{})",
Exceptions.toMessageString(e));
}
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 16c1b2c0e7d..5413907e967 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -277,6 +277,7 @@
"public java.util.List arguments()",
"public com.yahoo.searchlib.rankingexpression.RankingExpression getBody()",
"public java.util.Map argumentTypes()",
+ "public com.yahoo.tensor.TensorType getArgumentType(java.lang.String)",
"public java.util.Optional returnType()",
"public com.yahoo.searchlib.rankingexpression.ExpressionFunction withName(java.lang.String)",
"public com.yahoo.searchlib.rankingexpression.ExpressionFunction withBody(com.yahoo.searchlib.rankingexpression.RankingExpression)",
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
index 241a53fb458..171151bfdf4 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
@@ -83,6 +83,7 @@ public class ExpressionFunction {
/** Returns the types of the arguments of this, if specified. The keys of this may be any subset of the arguments */
public Map<String, TensorType> argumentTypes() { return argumentTypes; }
+ public TensorType getArgumentType(String argumentName) { return argumentTypes.get(argumentName); }
/** Returns the return type of this, or empty if not specified */
public Optional<TensorType> returnType() { return returnType; }