diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-06-06 09:36:51 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-06-06 09:36:51 +0200 |
commit | 3344c999c2721b75666e2e49f7fd6f15c9fe1353 (patch) | |
tree | 1d0084ff7c5037774d8740ce9dd5a7777c9367f7 /model-integration | |
parent | 3e1ef49b358ef027311d1d44d846695ea46125b8 (diff) |
Expression file references in Vespa models
Diffstat (limited to 'model-integration')
4 files changed, 45 insertions, 8 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index 90529ccdca0..58962d1a5ff 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -4,10 +4,16 @@ package ai.vespa.rankingexpression.importer; import com.google.common.collect.ImmutableMap; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.io.IOUtils; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.io.File; +import java.io.IOException; +import java.io.StringReader; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -115,6 +121,41 @@ public class ImportedModel implements ImportedMlModel { public void expression(String name, RankingExpression expression) { expressions.put(name, expression); } public void function(String name, RankingExpression expression) { functions.put(name, expression); } + public void expression(String name, String expression) { + try { + expression = expression.trim(); + if ( expression.startsWith("file:")) { + String filePath = expression.substring("file:".length()).trim(); + if ( ! filePath.endsWith(ApplicationPackage.RANKEXPRESSION_NAME_SUFFIX)) + filePath = filePath + ApplicationPackage.RANKEXPRESSION_NAME_SUFFIX; + expression = IOUtils.readFile(relativeFile(filePath, "function '" + name + "'")); + } + expression(name, new RankingExpression(expression)); + } + catch (IOException e) { + throw new IllegalArgumentException("Could not read file referenced in '" + name + "'"); + } + catch (ParseException e) { + throw new IllegalArgumentException("Could not parse function '" + name + "'", e); + } + } + + /** + * Returns a reference to the File at a path given relative to the source root of this model + * + * @throws IllegalArgumentException if the path is illegal or non-existent + */ + public File relativeFile(String relativePath, String descriptionOfPath) { + File file = new File(new File(source()).getParent(), relativePath); + if (file.isAbsolute()) + throw new IllegalArgumentException(descriptionOfPath + " uses the absolute file path '" + relativePath + + "'. File paths must be relative to the directory referencing them"); + if ( ! file.exists()) + throw new IllegalArgumentException(descriptionOfPath + " references '" + relativePath + + "', but this file does not exist"); + return file; + } + /** * Returns all the output expressions of this indexed by name. The names consist of one or two parts * separated by dot, where the first part is the signature name diff --git a/model-integration/src/main/javacc/ModelParser.jj b/model-integration/src/main/javacc/ModelParser.jj index 5dde54e88e2..a7822cd1a00 100644 --- a/model-integration/src/main/javacc/ModelParser.jj +++ b/model-integration/src/main/javacc/ModelParser.jj @@ -180,12 +180,7 @@ void function() : ")" lbrace() expression = expression() (<NL>)* <RBRACE> ) { - try { - model.expression(name, new RankingExpression(expression)); - } - catch (com.yahoo.searchlib.rankingexpression.parser.ParseException e) { - throw new IllegalArgumentException("Could not parse function '" + name + "'", e); - } + model.expression(name, expression); } } @@ -280,7 +275,7 @@ Tensor largeConstantBody(String name) : )+ { try { - return JsonFormat.decode(type, IOUtils.readFileBytes(new File(new File(model.source()).getParent(), path))); + return JsonFormat.decode(type, IOUtils.readFileBytes(model.relativeFile(path, "constant '" + name + "'"))); } catch (Exception e) { throw new IllegalArgumentException("Could not read constant '" + name + "'", e); diff --git a/model-integration/src/test/models/vespa/example.model b/model-integration/src/test/models/vespa/example.model index 66d21cfc53f..6d660732db9 100644 --- a/model-integration/src/test/models/vespa/example.model +++ b/model-integration/src/test/models/vespa/example.model @@ -15,7 +15,7 @@ model example { } function foo1() { - expression: reduce(sum(input1 * input2, name) * constant1, max, x) * constant2 + expression: file:test.expression } function foo2() { diff --git a/model-integration/src/test/models/vespa/test.expression b/model-integration/src/test/models/vespa/test.expression new file mode 100644 index 00000000000..5db8a720498 --- /dev/null +++ b/model-integration/src/test/models/vespa/test.expression @@ -0,0 +1 @@ +reduce(sum(input1 * input2, name) * constant1, max, x) * constant2
\ No newline at end of file |