aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-09-11 11:20:39 +0200
committerJon Bratseth <bratseth@oath.com>2018-09-11 11:20:39 +0200
commit5a3519d9d26df9b18e680f2bb7dbe9e3f25bcb0b (patch)
treea76f2af048b1b4c79fd2ee850ada9fee90530908 /model-evaluation/src/main
parentbb4a0112dcb2f708ad64c550668a0842face9559 (diff)
Import small constants
Diffstat (limited to 'model-evaluation/src/main')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java50
1 files changed, 48 insertions, 2 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 87ac53488db..d2fca309a19 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -23,6 +23,8 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
/**
* Converts RankProfilesConfig instances to RankingExpressions for evaluation.
@@ -60,10 +62,11 @@ public class RankProfilesConfigImporter {
throws ParseException {
List<ExpressionFunction> functions = new ArrayList<>();
Map<FunctionReference, ExpressionFunction> referencedFunctions = new HashMap<>();
+ SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo();
ExpressionFunction firstPhase = null;
ExpressionFunction secondPhase = null;
- List<Constant> constants = readConstants(constantsConfig);
+ List<Constant> constants = readLargeConstants(constantsConfig);
for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) {
Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name());
@@ -86,12 +89,17 @@ public class RankProfilesConfigImporter {
secondPhase = new ExpressionFunction("secondphase", new ArrayList<>(),
new RankingExpression("second-phase", property.value()));
}
+ else {
+ smallConstantsInfo.addIfSmallConstantInfo(property.name(), property.value());
+ }
}
if (functionByName("firstphase", functions) == null && firstPhase != null) // may be already included, depending on body
functions.add(firstPhase);
if (functionByName("secondphase", functions) == null && secondPhase != null) // may be already included, depending on body
functions.add(secondPhase);
+ constants.addAll(smallConstantsInfo.asConstants());
+
try {
return new Model(profile.name(), functions, referencedFunctions, constants);
}
@@ -107,7 +115,7 @@ public class RankProfilesConfigImporter {
return null;
}
- private List<Constant> readConstants(RankingConstantsConfig constantsConfig) {
+ private List<Constant> readLargeConstants(RankingConstantsConfig constantsConfig) {
List<Constant> constants = new ArrayList<>();
for (RankingConstantsConfig.Constant constantConfig : constantsConfig.constant()) {
@@ -138,4 +146,42 @@ public class RankProfilesConfigImporter {
}
}
+ /** Collected information about small constants */
+ private static class SmallConstantsInfo {
+
+ private static final Pattern valuePattern = Pattern.compile("constant\\(([a-zA-Z0-9_.]+)\\)\\.value");
+ private static final Pattern typePattern = Pattern.compile("constant\\(([a-zA-Z0-9_.]+)\\)\\.type");
+
+ private Map<String, TensorType> types = new HashMap<>();
+ private Map<String, String> values = new HashMap<>();
+
+ void addIfSmallConstantInfo(String key, String value) {
+ tryValue(key, value);
+ tryType(key, value);
+ }
+
+ private void tryValue(String key, String value) {
+ Matcher matcher = valuePattern.matcher(key);
+ if (matcher.matches())
+ values.put(matcher.group(1), value);
+ }
+
+ private void tryType(String key, String value) {
+ Matcher matcher = typePattern.matcher(key);
+ if (matcher.matches())
+ types.put(matcher.group(1), TensorType.fromSpec(value));
+ }
+
+ List<Constant> asConstants() {
+ List<Constant> constants = new ArrayList<>();
+ for (Map.Entry<String, String> entry : values.entrySet()) {
+ TensorType type = types.get(entry.getKey());
+ if (type == null) throw new IllegalStateException("Missing type of '" + entry.getKey() + "'"); // Won't happen
+ constants.add(new Constant(entry.getKey(), Tensor.from(type, entry.getValue())));
+ }
+ return constants;
+ }
+
+ }
+
}