aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-06-06 09:36:51 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-06-06 09:36:51 +0200
commit3344c999c2721b75666e2e49f7fd6f15c9fe1353 (patch)
tree1d0084ff7c5037774d8740ce9dd5a7777c9367f7 /model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
parent3e1ef49b358ef027311d1d44d846695ea46125b8 (diff)
Expression file references in Vespa models
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java41
1 files changed, 41 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index 90529ccdca0..58962d1a5ff 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -4,10 +4,16 @@ package ai.vespa.rankingexpression.importer;
import com.google.common.collect.ImmutableMap;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.io.File;
+import java.io.IOException;
+import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
@@ -115,6 +121,41 @@ public class ImportedModel implements ImportedMlModel {
public void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
public void function(String name, RankingExpression expression) { functions.put(name, expression); }
+ public void expression(String name, String expression) {
+ try {
+ expression = expression.trim();
+ if ( expression.startsWith("file:")) {
+ String filePath = expression.substring("file:".length()).trim();
+ if ( ! filePath.endsWith(ApplicationPackage.RANKEXPRESSION_NAME_SUFFIX))
+ filePath = filePath + ApplicationPackage.RANKEXPRESSION_NAME_SUFFIX;
+ expression = IOUtils.readFile(relativeFile(filePath, "function '" + name + "'"));
+ }
+ expression(name, new RankingExpression(expression));
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not read file referenced in '" + name + "'");
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException("Could not parse function '" + name + "'", e);
+ }
+ }
+
+ /**
+ * Returns a reference to the File at a path given relative to the source root of this model
+ *
+ * @throws IllegalArgumentException if the path is illegal or non-existent
+ */
+ public File relativeFile(String relativePath, String descriptionOfPath) {
+ File file = new File(new File(source()).getParent(), relativePath);
+ if (file.isAbsolute())
+ throw new IllegalArgumentException(descriptionOfPath + " uses the absolute file path '" + relativePath +
+ "'. File paths must be relative to the directory referencing them");
+ if ( ! file.exists())
+ throw new IllegalArgumentException(descriptionOfPath + " references '" + relativePath +
+ "', but this file does not exist");
+ return file;
+ }
+
/**
* Returns all the output expressions of this indexed by name. The names consist of one or two parts
* separated by dot, where the first part is the signature name