summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorHarald Musum <musum@verizonmedia.com>2022-12-02 01:11:20 +0100
committerGitHub <noreply@github.com>2022-12-02 01:11:20 +0100
commit4fa41e13f4baa0d8927e516c6db594b8f4ec8a3e (patch)
treeda04fd3f1ed4275a341c6a1bda092e27e9f3c9d6 /config-model
parentffdaafffd90a2a8cb1522c7e131f13fc718be3f7 (diff)
parent6a9681d7f3e42f29bd1d9de9fe9c271489b0c886 (diff)
Merge pull request #25065 from vespa-engine/revert-25064-revert-25062-balder/gc-even-more-guava-usage
Revert "Revert "- Reduce usage of guava.""
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java7
-rw-r--r--config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java2
-rw-r--r--config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java7
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java6
4 files changed, 11 insertions, 11 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java
index 59f4035f34f..14ee60bb9a6 100644
--- a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java
@@ -20,7 +20,6 @@ import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
-import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import java.nio.ByteBuffer;
@@ -196,7 +195,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
rankProperties = new ArrayList<>(compiled.getRankProperties());
Map<String, RankProfile.RankingExpressionFunction> functions = compiled.getFunctions();
- List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList());
+ List<ExpressionFunction> functionExpressions = functions.values().stream().map(RankProfile.RankingExpressionFunction::function).collect(Collectors.toList());
Map<String, String> functionProperties = new LinkedHashMap<>();
SerializationContext functionSerializationContext = new SerializationContext(functionExpressions,
Map.of(),
@@ -248,8 +247,8 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
String expressionString = e.getValue().function().getBody().getRoot().toString(context).toString();
context.addFunctionSerialization(propertyName, expressionString);
- for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet())
- context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue());
+ e.getValue().function().argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey())
+ .forEach(argumentType -> context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue()));
if (e.getValue().function().returnType().isPresent())
context.addFunctionTypeSerialization(e.getKey(), e.getValue().function().returnType().get());
// else if (e.getValue().function().arguments().isEmpty()) TODO: Enable this check when we resolve all types
diff --git a/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java b/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java
index 3e7a1f7613b..871b79a7737 100644
--- a/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java
+++ b/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java
@@ -73,7 +73,7 @@ public class RankingExpressionTypeResolver extends Processor {
for (String argument : expressionFunction.arguments()) {
Reference ref = Reference.fromIdentifier(argument);
if (context.getType(ref).equals(TensorType.empty)) {
- context.setType(ref, expressionFunction.argumentTypes().get(argument));
+ context.setType(ref, expressionFunction.getArgumentType(argument));
}
}
context.forgetResolvedTypes();
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java
index dc72df9fc78..01e80e0f47a 100644
--- a/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/processing/VespaMlModelTestCase.java
@@ -28,11 +28,12 @@ public class VespaMlModelTestCase {
"constant(constant1).type : tensor(x[3])\n" +
"constant(constant1).value : tensor(x[3]):[0.5, 1.5, 2.5]\n" +
"rankingExpression(foo1).rankingScript : reduce(reduce(input1 * input2, sum, name) * constant(constant1), max, x) * 3.0\n" +
- "rankingExpression(foo1).input2.type : tensor(x[3])\n" +
"rankingExpression(foo1).input1.type : tensor(name{},x[3])\n" +
+ "rankingExpression(foo1).input2.type : tensor(x[3])\n" +
"rankingExpression(foo2).rankingScript : reduce(reduce(input1 * input2, sum, name) * constant(constant1asLarge), max, x) * 3.0\n" +
- "rankingExpression(foo2).input2.type : tensor(x[3])\n" +
- "rankingExpression(foo2).input1.type : tensor(name{},x[3])\n";
+ "rankingExpression(foo2).input1.type : tensor(name{},x[3])\n" +
+ "rankingExpression(foo2).input2.type : tensor(x[3])\n";
+
/** The model name */
private final String name = "example";
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index f4d37cc4b35..caf0d22d44e 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -153,8 +153,8 @@ public class ModelEvaluationTest {
assertNotNull(evaluator.evaluatorOf("add_mul", "default.output2"));
assertNotNull(evaluator.evaluatorOf("add_mul", "default", "output1"));
assertNotNull(evaluator.evaluatorOf("add_mul", "default", "output2"));
- assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input1"));
- assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input2"));
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).getArgumentType("input1"));
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).getArgumentType("input2"));
Model sqrt = evaluator.models().get("sqrt");
assertNotNull(sqrt);
@@ -163,7 +163,7 @@ public class ModelEvaluationTest {
assertNotNull(sqrt.evaluatorOf("out_layer_1_1")); // converted from "out/layer/1:1"
assertNotNull(evaluator.evaluatorOf("sqrt"));
assertNotNull(evaluator.evaluatorOf("sqrt", "out_layer_1_1"));
- assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), sqrt.functions().get(0).argumentTypes().get("input"));
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), sqrt.functions().get(0).getArgumentType("input"));
}
private final String profile =