// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. // -------------------------------------------------------------------------------- // // JavaCC options. When this file is changed, run "mvn generate-sources" to rebuild // the parser classes. // // -------------------------------------------------------------------------------- options { UNICODE_INPUT = true; CACHE_TOKENS = false; DEBUG_PARSER = false; ERROR_REPORTING = true; FORCE_LA_CHECK = true; USER_CHAR_STREAM = true; } // -------------------------------------------------------------------------------- // // Parser body. // // -------------------------------------------------------------------------------- PARSER_BEGIN(ModelParser) package ai.vespa.rankingexpression.importer.vespa.parser; import java.io.File; import java.io.Reader; import java.io.StringReader; import java.util.List; import java.util.ArrayList; import ai.vespa.rankingexpression.importer.ImportedModel; import com.yahoo.io.IOUtils; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.serialization.JsonFormat; import com.yahoo.searchlib.rankingexpression.RankingExpression; /** * Parser of Vespa ML model files: Ranking expression functions enclosed in brackets. * * @author bratseth */ public class ModelParser { /** The model we are importing into */ private ImportedModel model; /** Creates a parser of a string */ public ModelParser(String input, ImportedModel model) { this(new SimpleCharStream(input), model); } /** Creates a parser */ public ModelParser(SimpleCharStream input, ImportedModel model) { this(input); this.model = model; } } PARSER_END(ModelParser) // -------------------------------------------------------------------------------- // // Token declarations. // // -------------------------------------------------------------------------------- // Declare white space characters. These do not include newline because it has // special meaning in several of the production rules. SKIP : { " " | "\t" | "\r" | "\f" } // Declare all tokens to be recognized. When a word token is added it MUST be // added to the identifier() production rule. TOKEN : { < NL: "\n" > | < FUNCTION: "function" > | < TENSOR_TYPE: "tensor(" (~["(",")"])+ ")" > | < TENSORVALUE: (" ")* ":" (" ")* ("{") ("\n")? > | < TENSOR_VALUE_SL: "value" (" ")* ":" (" ")* ("{") ("\n")? > | < TENSOR_VALUE_ML: "value" ()? "{" (["\n"," "])* ("{") (["\n"," "])* "}" ("\n")? > | < LBRACE: "{" > | < RBRACE: "}" > | < COLON: ":" > | < DOT: "." > | < COMMA: "," > | < MODEL: "model" > | < TYPE: "type" > | < EXPRESSION_SL: "expression" (" ")* ":" (("{")|)* ("\n")? > | < EXPRESSION_ML: "expression" ()? "{" (("{")|)* "}" > | < #BRACE_SL_LEVEL_1: (("{")|)* "}" > | < #BRACE_SL_LEVEL_2: (("{")|)* "}" > | < #BRACE_SL_LEVEL_3: "}" > | < #BRACE_SL_CONTENT: (~["{","}","\n"])* > | < #BRACE_ML_LEVEL_1: (("{")|)* "}" > | < #BRACE_ML_LEVEL_2: (("{")|)* "}" > | < #BRACE_ML_LEVEL_3: "}" > | < #BRACE_ML_CONTENT: (~["{","}"])* > | < #SEARCHLIB_SKIP: ([" ","\f","\n","\r","\t"])+ > | < CONSTANT: "constant" > | < CONSTANTS: "constants" > | < FILE: "file" > | < URI: "uri" > | < IDENTIFIER: ["a"-"z","A"-"Z", "_"] (["a"-"z","A"-"Z","0"-"9","_"])* > | < CONTEXT: ["a"-"z","A"-"Z"] (["a"-"z", "A"-"Z", "0"-"9"])* > | < DOUBLE: ("-")? (["0"-"9"])+ "." (["0"-"9"])+ > | < STRING: (["a"-"z","A"-"Z","_","0"-"9","."])+ > | < FILE_PATH: ["a"-"z","A"-"Z", "_"] (["a"-"z","A"-"Z","0"-"9","_","-", "/", "."])+ > | < HTTP: ["h","H"] ["t","T"] ["t","T"] ["p","P"] (["s","S"])? > | < URI_PATH: ("//")? (["a"-"z","A"-"Z","0"-"9","_","-", "/", ".",":"])+ > } // Declare a special skip token for comments. SPECIAL_TOKEN : { } // -------------------------------------------------------------------------------- // // Production rules. // // -------------------------------------------------------------------------------- void model() : { String name; } { ()* ()* name = identifier() ()* modelContent() ()* { if ( ! model.name().endsWith(name)) throw new IllegalArgumentException("Unexpected model name '" + model.name() + "': Model '" + name + "' must be saved in a file named '" + name + ".model'"); } } void modelContent() : { } { ( | input() | constants() | largeConstant() | function() )* } /** Declared input variables (aka features). All non-scalar inputs must be declared. */ void input() : { String name; TensorType type; } { name = identifier() type = tensorType("Input parameter '" + name + "'") { model.input(name, type); } } /** A function */ void function() : { String name, expression, parameter; List parameters = new ArrayList(); } { ( name = identifier() "(" [ parameter = identifier() { parameters.add(parameter); } ( parameter = identifier() { parameters.add(parameter); } )* ] ")" lbrace() expression = expression() ()* ) { model.expression(name, expression); } } /** Consumes the constants of this model. */ void constants() : { String name; } { ()* ( name = identifier() ( constantDouble(name) | constantTensor(name) ) ()* )* } void constantDouble(String name) : { Token value; } { value = { model.smallConstant(name, Tensor.from(Double.parseDouble(value.image))); } } void constantTensor(String name) : { TensorType type; Token value; } { type = tensorType("constant '" + name + "'") value = { model.smallConstant(name, Tensor.from(type, value.image.substring(1))); } } String constantTensorErrorMessage(String model, String constantTensorName) : {} { { return "For constant tensor '" + constantTensorName + "' in model '" + model + "'"; } } String tensorValue() : { String tensor; } { ( { tensor = token.image.substring(token.image.indexOf(":") + 1); } | { tensor = token.image.substring(token.image.indexOf("{") + 1, token.image.lastIndexOf("}")); } ) { return tensor; } } TensorType tensorType(String errorMessage) : { String tensorTypeString; } { { tensorTypeString = token.image; } { TensorType tensorType; try { tensorType = TensorType.fromSpec(tensorTypeString); } catch (IllegalArgumentException e) { throw new IllegalArgumentException(errorMessage + ": Illegal tensor type spec: " + e.getMessage()); } return tensorType; } } /** Consumes a large constant. */ void largeConstant() : { String name; Tensor value; } { ( name = identifier() lbrace() value = largeConstantBody(name) ) { model.largeConstant(name, value); } } // TODO: Add support in ImportedModel for passing a large tensor through as a file/Uri pointer instead of reading it here Tensor largeConstantBody(String name) : { String path = null; TensorType type = null; } { ( path = filePath() // | ( path = uriPath() TODO | type = tensorType("Constant '" + name + "'") | )+ { try { return JsonFormat.decode(type, IOUtils.readFileBytes(model.relativeFile(path, "constant '" + name + "'"))); } catch (Exception e) { throw new IllegalArgumentException("Could not read constant '" + name + "'", e); } } } String filePath() : { } { ( | | ) { return token.image; } } String uriPath() : { } { ( ) { return token.image; } } /** Consumes an expression token and returns its image. */ String expression() : { String exp; } { ( { exp = token.image.substring(token.image.indexOf(":") + 1); } | { exp = token.image.substring(token.image.indexOf("{") + 1, token.image.lastIndexOf("}")); } ) { return exp; } } /** Consumes an identifier. This must be kept in sync with all word tokens that should be parseable as identifiers. */ String identifier() : { } { ( | | | | | ) { return token.image; } } /** Consumes an opening brace with leading and trailing newline tokens. */ void lbrace() : { } { ()* ()* }