aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2022-12-01 22:45:55 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2022-12-01 22:45:55 +0100
commit6a9681d7f3e42f29bd1d9de9fe9c271489b0c886 (patch)
tree227e5c3f2e1fcae248e0daf85735358b06971e33 /config-model/src
parent1eb22cc4a24973f52b344c3033cff394c724cbe4 (diff)
Use well defined order where we output text and generate config. Makes config stable and simple tests predictable.
Diffstat (limited to 'config-model/src')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java7
-rw-r--r--config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java2
-rw-r--r--config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java7
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java6
4 files changed, 11 insertions, 11 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java
index 59f4035f34f..14ee60bb9a6 100644
--- a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java
@@ -20,7 +20,6 @@ import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
-import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import java.nio.ByteBuffer;
@@ -196,7 +195,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
rankProperties = new ArrayList<>(compiled.getRankProperties());
Map<String, RankProfile.RankingExpressionFunction> functions = compiled.getFunctions();
- List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList());
+ List<ExpressionFunction> functionExpressions = functions.values().stream().map(RankProfile.RankingExpressionFunction::function).collect(Collectors.toList());
Map<String, String> functionProperties = new LinkedHashMap<>();
SerializationContext functionSerializationContext = new SerializationContext(functionExpressions,
Map.of(),
@@ -248,8 +247,8 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
String expressionString = e.getValue().function().getBody().getRoot().toString(context).toString();
context.addFunctionSerialization(propertyName, expressionString);
- for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet())
- context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue());
+ e.getValue().function().argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey())
+ .forEach(argumentType -> context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue()));
if (e.getValue().function().returnType().isPresent())
context.addFunctionTypeSerialization(e.getKey(), e.getValue().function().returnType().get());
// else if (e.getValue().function().arguments().isEmpty()) TODO: Enable this check when we resolve all types
diff --git a/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java b/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java
index 3e7a1f7613b..871b79a7737 100644
--- a/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java
+++ b/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java
@@ -73,7 +73,7 @@ public class RankingExpressionTypeResolver extends Processor {
for (String argument : expressionFunction.arguments()) {
Reference ref = Reference.fromIdentifier(argument);
if (context.getType(ref).equals(TensorType.empty)) {
- context.setType(ref, expressionFunction.argumentTypes().get(argument));
+ context.setType(ref, expressionFunction.getArgumentType(argument));
}
}
context.forgetResolvedTypes();
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java
index dc72df9fc78..01e80e0f47a 100644
--- a/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java
@@ -28,11 +28,12 @@ public class VespaMlModelTestCase {
"constant(constant1).type : tensor(x[3])\n" +
"constant(constant1).value : tensor(x[3]):[0.5, 1.5, 2.5]\n" +
"rankingExpression(foo1).rankingScript : reduce(reduce(input1 * input2, sum, name) * constant(constant1), max, x) * 3.0\n" +
- "rankingExpression(foo1).input2.type : tensor(x[3])\n" +
"rankingExpression(foo1).input1.type : tensor(name{},x[3])\n" +
+ "rankingExpression(foo1).input2.type : tensor(x[3])\n" +
"rankingExpression(foo2).rankingScript : reduce(reduce(input1 * input2, sum, name) * constant(constant1asLarge), max, x) * 3.0\n" +
- "rankingExpression(foo2).input2.type : tensor(x[3])\n" +
- "rankingExpression(foo2).input1.type : tensor(name{},x[3])\n";
+ "rankingExpression(foo2).input1.type : tensor(name{},x[3])\n" +
+ "rankingExpression(foo2).input2.type : tensor(x[3])\n";
+
/** The model name */
private final String name = "example";
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index f4d37cc4b35..caf0d22d44e 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -153,8 +153,8 @@ public class ModelEvaluationTest {
assertNotNull(evaluator.evaluatorOf("add_mul", "default.output2"));
assertNotNull(evaluator.evaluatorOf("add_mul", "default", "output1"));
assertNotNull(evaluator.evaluatorOf("add_mul", "default", "output2"));
- assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input1"));
- assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input2"));
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).getArgumentType("input1"));
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).getArgumentType("input2"));
Model sqrt = evaluator.models().get("sqrt");
assertNotNull(sqrt);
@@ -163,7 +163,7 @@ public class ModelEvaluationTest {
assertNotNull(sqrt.evaluatorOf("out_layer_1_1")); // converted from "out/layer/1:1"
assertNotNull(evaluator.evaluatorOf("sqrt"));
assertNotNull(evaluator.evaluatorOf("sqrt", "out_layer_1_1"));
- assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), sqrt.functions().get(0).argumentTypes().get("input"));
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), sqrt.functions().get(0).getArgumentType("input"));
}
private final String profile =