summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-05-31 19:31:45 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-05-31 19:31:45 +0200
commite9e5a422c0aa6364c3c5f7b9da53e9fcf9a5f0f8 (patch)
tree073747aaeb79aea6ac7f1f9193513e064ca3b006 /model-integration
parent986c2da2986a2fc0de4895a8107c85e4d0f37fd3 (diff)
Support small constants
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/javacc/ModelParser.jj139
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java12
-rw-r--r--model-integration/src/test/models/vespa/example.model10
3 files changed, 84 insertions, 77 deletions
diff --git a/model-integration/src/main/javacc/ModelParser.jj b/model-integration/src/main/javacc/ModelParser.jj
index 7604259e850..18dfb4c68ed 100644
--- a/model-integration/src/main/javacc/ModelParser.jj
+++ b/model-integration/src/main/javacc/ModelParser.jj
@@ -78,6 +78,7 @@ TOKEN :
< NL: "\n" >
| < FUNCTION: "function" >
| < TENSOR_TYPE: "tensor(" (~["(",")"])+ ")" >
+| < TENSORVALUE: (" ")* ":" (" ")* ("{"<BRACE_SL_LEVEL_1>) ("\n")? >
| < TENSOR_VALUE_SL: "value" (" ")* ":" (" ")* ("{"<BRACE_SL_LEVEL_1>) ("\n")? >
| < TENSOR_VALUE_ML: "value" (<SEARCHLIB_SKIP>)? "{" (["\n"," "])* ("{"<BRACE_ML_LEVEL_1>) (["\n"," "])* "}" ("\n")? >
| < LBRACE: "{" >
@@ -146,7 +147,7 @@ void modelContent() :
{
}
{
- ( <NL> | input() | function() )*
+ ( <NL> | input() | constants() | function() )*
}
/** Declared input variables (aka features). All non-scalar inputs must be declared. */
@@ -183,68 +184,14 @@ void function() :
}
}
-/** Consumes a constant block of model. */
-/*
-void rankingConstant() :
-{
- String name;
- RankingConstant constant;
-}
-{
- ( <CONSTANT> name = identifier()
- {
-// constant = new RankingConstant(name);
- }
- lbrace() (rankingConstantItem(constant) (<NL>)*)+ <RBRACE> )
- {
- }
-}
-*/
-
-/** Consumes a constant block. */
-/*
-void rankingConstantItem(RankingConstant constant) :
-{
- String path = null;
- TensorType type = null;
-}
-{
- ( (<FILE> <COLON> path = filePath() { } (<NL>)*) { constant.setFileName(path); }
- | (<URI> <COLON> path = uriPath() { } (<NL>)*) { constant.setUri(path); }
- | type = tensorTypeWithPrefix(rankingConstantErrorMessage(constant.getName())) (<NL>)* { constant.setType(type); }
- )
- {
- return null;
- }
-}
-*/
-
-String rankingConstantErrorMessage(String name) : {}
-{
- { return "For ranking constant ' " + name + "'"; }
-}
-
-String filePath() : { }
-{
- ( <FILE_PATH> | <STRING> | <IDENTIFIER>)
- { return token.image; }
-}
-
-String uriPath() : { }
-{
- ( <URI_PATH> )
- { return token.image; }
-}
-
/** Consumes the constants of this model. */
-void constants(ImportedModel model) :
+void constants() :
{
String name;
}
{
<CONSTANTS> <LBRACE> (<NL>)*
- ( name = identifier() ( constantDouble(name) |
- constantTensor(name) ) (<NL>)* )*
+ ( name = identifier() <COLON> ( constantDouble(name) | constantTensor(name) ) (<NL>)* )*
<RBRACE>
}
@@ -253,25 +200,19 @@ void constantDouble(String name) :
Token value;
}
{
- <COLON> value = <DOUBLE> { model.smallConstant(name, Tensor.from(Double.parseDouble(token.image))); }
+ value = <DOUBLE> { model.smallConstant(name, Tensor.from(Double.parseDouble(value.image))); }
}
void constantTensor(String name) :
{
- String tensorString = "";
- TensorType tensorType = null;
+ TensorType type;
+ Token value;
}
{
- <LBRACE> (<NL>)*
- (( tensorString = tensorValue() |
- tensorType = tensorTypeWithPrefix(constantTensorErrorMessage(model.name(), name)) ) (<NL>)* )* <RBRACE>
- {
- if (tensorType != null) {
- model.smallConstant(name, Tensor.from(tensorType, tensorString));
- } else {
- model.smallConstant(name, Tensor.from(tensorString));
- }
- }
+ type = tensorType("constant '" + name + "'") value = <TENSORVALUE>
+ {
+ model.smallConstant(name, Tensor.from(type, value.image.substring(1)));
+ }
}
String constantTensorErrorMessage(String model, String constantTensorName) : {}
@@ -297,7 +238,7 @@ TensorType tensorTypeWithPrefix(String errorMessage) :
TensorType type;
}
{
- <TYPE> <COLON> type= tensorType(errorMessage)
+ <TYPE> <COLON> type=tensorType(errorMessage)
{ return type; }
}
@@ -318,6 +259,61 @@ TensorType tensorType(String errorMessage) :
}
}
+//----------------------------------------
+/** Consumes a constant block of model. */
+/*
+void largeConstant() :
+{
+ String name;
+ RankingConstant constant;
+}
+{
+ ( <CONSTANT> name = identifier()
+ {
+// constant = new RankingConstant(name);
+ }
+ lbrace() (rankingConstantItem(constant) (<NL>)*)+ <RBRACE> )
+ {
+ }
+}
+*/
+
+/** Consumes a constant block. */
+/*
+void rankingConstantItem(RankingConstant constant) :
+{
+ String path = null;
+ TensorType type = null;
+}
+{
+ ( (<FILE> <COLON> path = filePath() { } (<NL>)*) { constant.setFileName(path); }
+ | (<URI> <COLON> path = uriPath() { } (<NL>)*) { constant.setUri(path); }
+ | type = tensorTypeWithPrefix(rankingConstantErrorMessage(constant.getName())) (<NL>)* { constant.setType(type); }
+ )
+ {
+ return null;
+ }
+}
+*/
+
+String rankingConstantErrorMessage(String name) : {}
+{
+ { return "For ranking constant ' " + name + "'"; }
+}
+
+String filePath() : { }
+{
+ ( <FILE_PATH> | <STRING> | <IDENTIFIER>)
+ { return token.image; }
+}
+
+String uriPath() : { }
+{
+ ( <URI_PATH> )
+ { return token.image; }
+}
+//----------------------------------------
+
/** Consumes an expression token and returns its image. */
String expression() :
{
@@ -340,7 +336,6 @@ String identifier() : { }
| <URI>
| <MODEL>
| <TYPE>
- | <CONSTANTS>
)
{ return token.image; }
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
index 4f9fb9c070a..1be2b7a4183 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
@@ -17,10 +17,16 @@ public class VespaImportTestCase {
public void testExample() {
ImportedModel model = importModel("example");
- assertEquals(1, model.inputs().size());
- assertEquals("tensor(name{},x[10])", model.inputs().get("input1").toString());
+ assertEquals(2, model.inputs().size());
+ assertEquals("tensor(name{},x[3])", model.inputs().get("input1").toString());
+ assertEquals("tensor(x[3])", model.inputs().get("input2").toString());
- assertEquals("var1 * var2", model.expressions().get("foo").getRoot().toString());
+ assertEquals(2, model.smallConstants().size());
+ assertEquals("tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}", model.smallConstants().get("constant1"));
+ assertEquals("tensor():{3.0}", model.smallConstants().get("constant2"));
+
+ assertEquals("max(reduce(input1 * input2, sum, name),x) * constant2",
+ model.expressions().get("foo").getRoot().toString());
}
@Test
diff --git a/model-integration/src/test/models/vespa/example.model b/model-integration/src/test/models/vespa/example.model
index 19598690aad..c0ea461db09 100644
--- a/model-integration/src/test/models/vespa/example.model
+++ b/model-integration/src/test/models/vespa/example.model
@@ -1,10 +1,16 @@
model example {
- input1: tensor(name{}, x[10])
+ # All inputs that are not scalar (aka 0-dimensional tensor) must be declared
+ input1: tensor(name{}, x[3])
+ input2: tensor(x[3])
+ constants {
+ constant1: tensor(x[3]):{{x:0}:0.5, {x:1}:1.5, {x:2}:2.5}
+ constant2: 3.0
+ }
function foo() {
- expression: var1 * var2
+ expression: max(sum(input1 * input2, name), x) * constant2
}
} \ No newline at end of file