summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-06-06 09:36:51 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-06-06 09:36:51 +0200
commit3344c999c2721b75666e2e49f7fd6f15c9fe1353 (patch)
tree1d0084ff7c5037774d8740ce9dd5a7777c9367f7 /model-integration
parent3e1ef49b358ef027311d1d44d846695ea46125b8 (diff)
Expression file references in Vespa models
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java41
-rw-r--r--model-integration/src/main/javacc/ModelParser.jj9
-rw-r--r--model-integration/src/test/models/vespa/example.model2
-rw-r--r--model-integration/src/test/models/vespa/test.expression1
4 files changed, 45 insertions, 8 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index 90529ccdca0..58962d1a5ff 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -4,10 +4,16 @@ package ai.vespa.rankingexpression.importer;
import com.google.common.collect.ImmutableMap;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.io.File;
+import java.io.IOException;
+import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
@@ -115,6 +121,41 @@ public class ImportedModel implements ImportedMlModel {
public void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
public void function(String name, RankingExpression expression) { functions.put(name, expression); }
+ public void expression(String name, String expression) {
+ try {
+ expression = expression.trim();
+ if ( expression.startsWith("file:")) {
+ String filePath = expression.substring("file:".length()).trim();
+ if ( ! filePath.endsWith(ApplicationPackage.RANKEXPRESSION_NAME_SUFFIX))
+ filePath = filePath + ApplicationPackage.RANKEXPRESSION_NAME_SUFFIX;
+ expression = IOUtils.readFile(relativeFile(filePath, "function '" + name + "'"));
+ }
+ expression(name, new RankingExpression(expression));
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not read file referenced in '" + name + "'");
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException("Could not parse function '" + name + "'", e);
+ }
+ }
+
+ /**
+ * Returns a reference to the File at a path given relative to the source root of this model
+ *
+ * @throws IllegalArgumentException if the path is illegal or non-existent
+ */
+ public File relativeFile(String relativePath, String descriptionOfPath) {
+ File file = new File(new File(source()).getParent(), relativePath);
+ if (file.isAbsolute())
+ throw new IllegalArgumentException(descriptionOfPath + " uses the absolute file path '" + relativePath +
+ "'. File paths must be relative to the directory referencing them");
+ if ( ! file.exists())
+ throw new IllegalArgumentException(descriptionOfPath + " references '" + relativePath +
+ "', but this file does not exist");
+ return file;
+ }
+
/**
* Returns all the output expressions of this indexed by name. The names consist of one or two parts
* separated by dot, where the first part is the signature name
diff --git a/model-integration/src/main/javacc/ModelParser.jj b/model-integration/src/main/javacc/ModelParser.jj
index 5dde54e88e2..a7822cd1a00 100644
--- a/model-integration/src/main/javacc/ModelParser.jj
+++ b/model-integration/src/main/javacc/ModelParser.jj
@@ -180,12 +180,7 @@ void function() :
")"
lbrace() expression = expression() (<NL>)* <RBRACE> )
{
- try {
- model.expression(name, new RankingExpression(expression));
- }
- catch (com.yahoo.searchlib.rankingexpression.parser.ParseException e) {
- throw new IllegalArgumentException("Could not parse function '" + name + "'", e);
- }
+ model.expression(name, expression);
}
}
@@ -280,7 +275,7 @@ Tensor largeConstantBody(String name) :
)+
{
try {
- return JsonFormat.decode(type, IOUtils.readFileBytes(new File(new File(model.source()).getParent(), path)));
+ return JsonFormat.decode(type, IOUtils.readFileBytes(model.relativeFile(path, "constant '" + name + "'")));
}
catch (Exception e) {
throw new IllegalArgumentException("Could not read constant '" + name + "'", e);
diff --git a/model-integration/src/test/models/vespa/example.model b/model-integration/src/test/models/vespa/example.model
index 66d21cfc53f..6d660732db9 100644
--- a/model-integration/src/test/models/vespa/example.model
+++ b/model-integration/src/test/models/vespa/example.model
@@ -15,7 +15,7 @@ model example {
}
function foo1() {
- expression: reduce(sum(input1 * input2, name) * constant1, max, x) * constant2
+ expression: file:test.expression
}
function foo2() {
diff --git a/model-integration/src/test/models/vespa/test.expression b/model-integration/src/test/models/vespa/test.expression
new file mode 100644
index 00000000000..5db8a720498
--- /dev/null
+++ b/model-integration/src/test/models/vespa/test.expression
@@ -0,0 +1 @@
+reduce(sum(input1 * input2, name) * constant1, max, x) * constant2 \ No newline at end of file