diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2022-12-01 22:45:55 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2022-12-01 22:45:55 +0100 |
commit | 6a9681d7f3e42f29bd1d9de9fe9c271489b0c886 (patch) | |
tree | 227e5c3f2e1fcae248e0daf85735358b06971e33 /config-model/src | |
parent | 1eb22cc4a24973f52b344c3033cff394c724cbe4 (diff) |
Use well defined order where we output text and generate config. Makes config stable and simple tests predictable.
Diffstat (limited to 'config-model/src')
4 files changed, 11 insertions, 11 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java index 59f4035f34f..14ee60bb9a6 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java @@ -20,7 +20,6 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; -import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.RankProfilesConfig; import java.nio.ByteBuffer; @@ -196,7 +195,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { rankProperties = new ArrayList<>(compiled.getRankProperties()); Map<String, RankProfile.RankingExpressionFunction> functions = compiled.getFunctions(); - List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList()); + List<ExpressionFunction> functionExpressions = functions.values().stream().map(RankProfile.RankingExpressionFunction::function).collect(Collectors.toList()); Map<String, String> functionProperties = new LinkedHashMap<>(); SerializationContext functionSerializationContext = new SerializationContext(functionExpressions, Map.of(), @@ -248,8 +247,8 @@ public class RawRankProfile implements RankProfilesConfig.Producer { String expressionString = e.getValue().function().getBody().getRoot().toString(context).toString(); context.addFunctionSerialization(propertyName, expressionString); - for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet()) - context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue()); + e.getValue().function().argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey()) + .forEach(argumentType -> context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue())); if (e.getValue().function().returnType().isPresent()) context.addFunctionTypeSerialization(e.getKey(), e.getValue().function().returnType().get()); // else if (e.getValue().function().arguments().isEmpty()) TODO: Enable this check when we resolve all types diff --git a/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java b/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java index 3e7a1f7613b..871b79a7737 100644 --- a/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java +++ b/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java @@ -73,7 +73,7 @@ public class RankingExpressionTypeResolver extends Processor { for (String argument : expressionFunction.arguments()) { Reference ref = Reference.fromIdentifier(argument); if (context.getType(ref).equals(TensorType.empty)) { - context.setType(ref, expressionFunction.argumentTypes().get(argument)); + context.setType(ref, expressionFunction.getArgumentType(argument)); } } context.forgetResolvedTypes(); diff --git a/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java index dc72df9fc78..01e80e0f47a 100644 --- a/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java @@ -28,11 +28,12 @@ public class VespaMlModelTestCase { "constant(constant1).type : tensor(x[3])\n" + "constant(constant1).value : tensor(x[3]):[0.5, 1.5, 2.5]\n" + "rankingExpression(foo1).rankingScript : reduce(reduce(input1 * input2, sum, name) * constant(constant1), max, x) * 3.0\n" + - "rankingExpression(foo1).input2.type : tensor(x[3])\n" + "rankingExpression(foo1).input1.type : tensor(name{},x[3])\n" + + "rankingExpression(foo1).input2.type : tensor(x[3])\n" + "rankingExpression(foo2).rankingScript : reduce(reduce(input1 * input2, sum, name) * constant(constant1asLarge), max, x) * 3.0\n" + - "rankingExpression(foo2).input2.type : tensor(x[3])\n" + - "rankingExpression(foo2).input1.type : tensor(name{},x[3])\n"; + "rankingExpression(foo2).input1.type : tensor(name{},x[3])\n" + + "rankingExpression(foo2).input2.type : tensor(x[3])\n"; + /** The model name */ private final String name = "example"; diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index f4d37cc4b35..caf0d22d44e 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -153,8 +153,8 @@ public class ModelEvaluationTest { assertNotNull(evaluator.evaluatorOf("add_mul", "default.output2")); assertNotNull(evaluator.evaluatorOf("add_mul", "default", "output1")); assertNotNull(evaluator.evaluatorOf("add_mul", "default", "output2")); - assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input1")); - assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input2")); + assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).getArgumentType("input1")); + assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).getArgumentType("input2")); Model sqrt = evaluator.models().get("sqrt"); assertNotNull(sqrt); @@ -163,7 +163,7 @@ public class ModelEvaluationTest { assertNotNull(sqrt.evaluatorOf("out_layer_1_1")); // converted from "out/layer/1:1" assertNotNull(evaluator.evaluatorOf("sqrt")); assertNotNull(evaluator.evaluatorOf("sqrt", "out_layer_1_1")); - assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), sqrt.functions().get(0).argumentTypes().get("input")); + assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), sqrt.functions().get(0).getArgumentType("input")); } private final String profile = |