aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/schema/parser/ParsedRankProfile.java16
-rw-r--r--config-model/src/main/javacc/SchemaParser.jj33
-rw-r--r--config-model/src/test/java/com/yahoo/schema/parser/SchemaParserTestCase.java37
3 files changed, 84 insertions, 2 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/parser/ParsedRankProfile.java b/config-model/src/main/java/com/yahoo/schema/parser/ParsedRankProfile.java
index 64dd8dd0ad4..2809ee0c633 100644
--- a/config-model/src/main/java/com/yahoo/schema/parser/ParsedRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/schema/parser/ParsedRankProfile.java
@@ -53,6 +53,8 @@ class ParsedRankProfile extends ParsedBlock {
private final Map<Reference, RankProfile.Constant> constants = new LinkedHashMap<>();
private final Map<Reference, RankProfile.Input> inputs = new LinkedHashMap<>();
private final List<OnnxModel> onnxModels = new ArrayList<>();
+ private Integer globalPhaseRerankCount = null;
+ private String globalPhaseExpression = null;
ParsedRankProfile(String name) {
super(name, "rank-profile");
@@ -77,6 +79,8 @@ class ParsedRankProfile extends ParsedBlock {
List<ParsedRankFunction> getFunctions() { return List.copyOf(functions.values()); }
List<MutateOperation> getMutateOperations() { return List.copyOf(mutateOperations); }
List<String> getInherited() { return List.copyOf(inherited); }
+ Optional<Integer> getGlobalPhaseRerankCount() { return Optional.ofNullable(this.globalPhaseRerankCount); }
+ Optional<String> getGlobalPhaseExpression() { return Optional.ofNullable(this.globalPhaseExpression); }
Map<String, Boolean> getFieldsWithRankFilter() { return Collections.unmodifiableMap(fieldsRankFilter); }
Map<String, Integer> getFieldsWithRankWeight() { return Collections.unmodifiableMap(fieldsRankWeight); }
@@ -197,11 +201,21 @@ class ParsedRankProfile extends ParsedBlock {
this.secondPhaseExpression = expression;
}
+ void setGlobalPhaseExpression(String expression) {
+ verifyThat(globalPhaseExpression == null, "already has global-phase expression");
+ this.globalPhaseExpression = expression;
+ }
+
+ void setGlobalPhaseRerankCount(int count) {
+ verifyThat(globalPhaseRerankCount == null, "already has global-phase rerank-count");
+ this.globalPhaseRerankCount = count;
+ }
+
void setStrict(boolean strict) {
verifyThat(this.strict == null, "already has strict");
this.strict = strict;
}
-
+
void setTermwiseLimit(double limit) {
verifyThat(termwiseLimit == null, "already has termwise-limit");
this.termwiseLimit = limit;
diff --git a/config-model/src/main/javacc/SchemaParser.jj b/config-model/src/main/javacc/SchemaParser.jj
index a9c67a0bb60..fa9d34139ea 100644
--- a/config-model/src/main/javacc/SchemaParser.jj
+++ b/config-model/src/main/javacc/SchemaParser.jj
@@ -279,6 +279,7 @@ TOKEN :
| < MAXHITS: "max-hits" >
| < FIRSTPHASE: "first-phase" >
| < SECONDPHASE: "second-phase" >
+| < GLOBALPHASE: "global-phase" >
| < MACRO: "macro" >
| < INLINE: "inline" >
| < ARITY: "arity" >
@@ -1706,6 +1707,7 @@ void rankProfileItem(ParsedSchema schema, ParsedRankProfile profile) : { }
| rankFeatures(profile)
| rankProperties(profile)
| secondPhase(profile)
+ | globalPhase(profile)
| inputs(profile)
| constants(schema, profile)
| matchFeatures(profile)
@@ -1923,6 +1925,34 @@ void secondPhaseItem(ParsedRankProfile profile) :
)
}
+/**
+ * Consumes the global-phase block of a rank profile.
+ *
+ * @param profile The rank profile to modify.
+ */
+void globalPhase(ParsedRankProfile profile) : { }
+{
+ <GLOBALPHASE> lbrace() (globalPhaseItem(profile) (<NL>)*)* <RBRACE>
+}
+
+/**
+ * Consumes a statement for a global-phase block.
+ *
+ * @param profile The rank profile to modify.
+ */
+void globalPhaseItem(ParsedRankProfile profile) :
+{
+ String expression;
+ int rerankCount;
+}
+{
+ ( expression = expression() { profile.setGlobalPhaseExpression(expression); }
+ | (<RERANKCOUNT> <COLON> rerankCount = integer()) { profile.setGlobalPhaseRerankCount(rerankCount); }
+ )
+}
+
+
+
/** Consumes an inputs block of a rank profile. */
void inputs(ParsedRankProfile profile) :
{
@@ -2519,7 +2549,7 @@ String expression() :
( <EXPRESSION_SL> { exp = token.image.substring(token.image.indexOf(":") + 1); } |
<EXPRESSION_ML> { exp = token.image.substring(token.image.indexOf("{") + 1,
token.image.lastIndexOf("}")); } )
- { return exp; }
+ { return exp.trim(); }
}
String identifierWithDash() :
@@ -2555,6 +2585,7 @@ String identifier() : { }
| <CONSTANT>
| <CONSTANTS>
| <CONTEXT>
+ | <GLOBALPHASE>
| <CREATEIFNONEXISTENT>
| <DENSEPOSTINGLISTTHRESHOLD>
| <DESCENDING>
diff --git a/config-model/src/test/java/com/yahoo/schema/parser/SchemaParserTestCase.java b/config-model/src/test/java/com/yahoo/schema/parser/SchemaParserTestCase.java
index 150c237bbba..e69f26a31c9 100644
--- a/config-model/src/test/java/com/yahoo/schema/parser/SchemaParserTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/parser/SchemaParserTestCase.java
@@ -83,6 +83,43 @@ public class SchemaParserTestCase {
"}\n")).getMessage());
}
+ @Test
+ void global_phase_can_be_parsed() throws Exception {
+ String input = """
+ schema foo {
+ rank-profile normal {
+ first-phase {
+ expression {
+ rankingExpression(1.0)
+ }
+ }
+ }
+ rank-profile bar {
+ global-phase {
+ expression: onnx(mymodel)
+ rerank-count: 79
+ }
+ }
+ }
+ """;
+ ParsedSchema schema = parseString(input);
+ assertEquals("foo", schema.name());
+ var rplist = schema.getRankProfiles();
+ assertEquals(2, rplist.size());
+ var rp0 = rplist.get(0);
+ assertEquals("normal", rp0.name());
+ assertFalse(rp0.getGlobalPhaseRerankCount().isPresent());
+ assertFalse(rp0.getGlobalPhaseExpression().isPresent());
+ assertTrue(rp0.getFirstPhaseExpression().isPresent());
+ assertEquals("rankingExpression(1.0)", rp0.getFirstPhaseExpression().get());
+ var rp1 = rplist.get(1);
+ assertEquals("bar", rp1.name());
+ assertTrue(rp1.getGlobalPhaseRerankCount().isPresent());
+ assertTrue(rp1.getGlobalPhaseExpression().isPresent());
+ assertEquals(79, rp1.getGlobalPhaseRerankCount().get());
+ assertEquals("onnx(mymodel)", rp1.getGlobalPhaseExpression().get());
+ }
+
void checkFileParses(String fileName) throws Exception {
var schema = parseFile(fileName);
assertNotNull(schema);