diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-05-31 19:31:45 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-05-31 19:31:45 +0200 |
commit | e9e5a422c0aa6364c3c5f7b9da53e9fcf9a5f0f8 (patch) | |
tree | 073747aaeb79aea6ac7f1f9193513e064ca3b006 /model-integration | |
parent | 986c2da2986a2fc0de4895a8107c85e4d0f37fd3 (diff) |
Support small constants
Diffstat (limited to 'model-integration')
3 files changed, 84 insertions, 77 deletions
diff --git a/model-integration/src/main/javacc/ModelParser.jj b/model-integration/src/main/javacc/ModelParser.jj index 7604259e850..18dfb4c68ed 100644 --- a/model-integration/src/main/javacc/ModelParser.jj +++ b/model-integration/src/main/javacc/ModelParser.jj @@ -78,6 +78,7 @@ TOKEN : < NL: "\n" > | < FUNCTION: "function" > | < TENSOR_TYPE: "tensor(" (~["(",")"])+ ")" > +| < TENSORVALUE: (" ")* ":" (" ")* ("{"<BRACE_SL_LEVEL_1>) ("\n")? > | < TENSOR_VALUE_SL: "value" (" ")* ":" (" ")* ("{"<BRACE_SL_LEVEL_1>) ("\n")? > | < TENSOR_VALUE_ML: "value" (<SEARCHLIB_SKIP>)? "{" (["\n"," "])* ("{"<BRACE_ML_LEVEL_1>) (["\n"," "])* "}" ("\n")? > | < LBRACE: "{" > @@ -146,7 +147,7 @@ void modelContent() : { } { - ( <NL> | input() | function() )* + ( <NL> | input() | constants() | function() )* } /** Declared input variables (aka features). All non-scalar inputs must be declared. */ @@ -183,68 +184,14 @@ void function() : } } -/** Consumes a constant block of model. */ -/* -void rankingConstant() : -{ - String name; - RankingConstant constant; -} -{ - ( <CONSTANT> name = identifier() - { -// constant = new RankingConstant(name); - } - lbrace() (rankingConstantItem(constant) (<NL>)*)+ <RBRACE> ) - { - } -} -*/ - -/** Consumes a constant block. */ -/* -void rankingConstantItem(RankingConstant constant) : -{ - String path = null; - TensorType type = null; -} -{ - ( (<FILE> <COLON> path = filePath() { } (<NL>)*) { constant.setFileName(path); } - | (<URI> <COLON> path = uriPath() { } (<NL>)*) { constant.setUri(path); } - | type = tensorTypeWithPrefix(rankingConstantErrorMessage(constant.getName())) (<NL>)* { constant.setType(type); } - ) - { - return null; - } -} -*/ - -String rankingConstantErrorMessage(String name) : {} -{ - { return "For ranking constant ' " + name + "'"; } -} - -String filePath() : { } -{ - ( <FILE_PATH> | <STRING> | <IDENTIFIER>) - { return token.image; } -} - -String uriPath() : { } -{ - ( <URI_PATH> ) - { return token.image; } -} - /** Consumes the constants of this model. */ -void constants(ImportedModel model) : +void constants() : { String name; } { <CONSTANTS> <LBRACE> (<NL>)* - ( name = identifier() ( constantDouble(name) | - constantTensor(name) ) (<NL>)* )* + ( name = identifier() <COLON> ( constantDouble(name) | constantTensor(name) ) (<NL>)* )* <RBRACE> } @@ -253,25 +200,19 @@ void constantDouble(String name) : Token value; } { - <COLON> value = <DOUBLE> { model.smallConstant(name, Tensor.from(Double.parseDouble(token.image))); } + value = <DOUBLE> { model.smallConstant(name, Tensor.from(Double.parseDouble(value.image))); } } void constantTensor(String name) : { - String tensorString = ""; - TensorType tensorType = null; + TensorType type; + Token value; } { - <LBRACE> (<NL>)* - (( tensorString = tensorValue() | - tensorType = tensorTypeWithPrefix(constantTensorErrorMessage(model.name(), name)) ) (<NL>)* )* <RBRACE> - { - if (tensorType != null) { - model.smallConstant(name, Tensor.from(tensorType, tensorString)); - } else { - model.smallConstant(name, Tensor.from(tensorString)); - } - } + type = tensorType("constant '" + name + "'") value = <TENSORVALUE> + { + model.smallConstant(name, Tensor.from(type, value.image.substring(1))); + } } String constantTensorErrorMessage(String model, String constantTensorName) : {} @@ -297,7 +238,7 @@ TensorType tensorTypeWithPrefix(String errorMessage) : TensorType type; } { - <TYPE> <COLON> type= tensorType(errorMessage) + <TYPE> <COLON> type=tensorType(errorMessage) { return type; } } @@ -318,6 +259,61 @@ TensorType tensorType(String errorMessage) : } } +//---------------------------------------- +/** Consumes a constant block of model. */ +/* +void largeConstant() : +{ + String name; + RankingConstant constant; +} +{ + ( <CONSTANT> name = identifier() + { +// constant = new RankingConstant(name); + } + lbrace() (rankingConstantItem(constant) (<NL>)*)+ <RBRACE> ) + { + } +} +*/ + +/** Consumes a constant block. */ +/* +void rankingConstantItem(RankingConstant constant) : +{ + String path = null; + TensorType type = null; +} +{ + ( (<FILE> <COLON> path = filePath() { } (<NL>)*) { constant.setFileName(path); } + | (<URI> <COLON> path = uriPath() { } (<NL>)*) { constant.setUri(path); } + | type = tensorTypeWithPrefix(rankingConstantErrorMessage(constant.getName())) (<NL>)* { constant.setType(type); } + ) + { + return null; + } +} +*/ + +String rankingConstantErrorMessage(String name) : {} +{ + { return "For ranking constant ' " + name + "'"; } +} + +String filePath() : { } +{ + ( <FILE_PATH> | <STRING> | <IDENTIFIER>) + { return token.image; } +} + +String uriPath() : { } +{ + ( <URI_PATH> ) + { return token.image; } +} +//---------------------------------------- + /** Consumes an expression token and returns its image. */ String expression() : { @@ -340,7 +336,6 @@ String identifier() : { } | <URI> | <MODEL> | <TYPE> - | <CONSTANTS> ) { return token.image; } } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java index 4f9fb9c070a..1be2b7a4183 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java @@ -17,10 +17,16 @@ public class VespaImportTestCase { public void testExample() { ImportedModel model = importModel("example"); - assertEquals(1, model.inputs().size()); - assertEquals("tensor(name{},x[10])", model.inputs().get("input1").toString()); + assertEquals(2, model.inputs().size()); + assertEquals("tensor(name{},x[3])", model.inputs().get("input1").toString()); + assertEquals("tensor(x[3])", model.inputs().get("input2").toString()); - assertEquals("var1 * var2", model.expressions().get("foo").getRoot().toString()); + assertEquals(2, model.smallConstants().size()); + assertEquals("tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}", model.smallConstants().get("constant1")); + assertEquals("tensor():{3.0}", model.smallConstants().get("constant2")); + + assertEquals("max(reduce(input1 * input2, sum, name),x) * constant2", + model.expressions().get("foo").getRoot().toString()); } @Test diff --git a/model-integration/src/test/models/vespa/example.model b/model-integration/src/test/models/vespa/example.model index 19598690aad..c0ea461db09 100644 --- a/model-integration/src/test/models/vespa/example.model +++ b/model-integration/src/test/models/vespa/example.model @@ -1,10 +1,16 @@ model example { - input1: tensor(name{}, x[10]) + # All inputs that are not scalar (aka 0-dimensional tensor) must be declared + input1: tensor(name{}, x[3]) + input2: tensor(x[3]) + constants { + constant1: tensor(x[3]):{{x:0}:0.5, {x:1}:1.5, {x:2}:2.5} + constant2: 3.0 + } function foo() { - expression: var1 * var2 + expression: max(sum(input1 * input2, name), x) * constant2 } }
\ No newline at end of file |