summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-06-02 13:25:45 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-06-02 13:25:45 +0200
commitc25c8a52e2328bcff2f5a35496e7568ee5a7c752 (patch)
treecd624363ad22b7a2b6a76e41bd27c0cd7f5169d7 /model-integration
parente9e5a422c0aa6364c3c5f7b9da53e9fcf9a5f0f8 (diff)
Vespa global model import
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java24
-rw-r--r--model-integration/src/main/javacc/ModelParser.jj59
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java23
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java1
-rw-r--r--model-integration/src/test/models/vespa/constant1asLarge.json7
-rw-r--r--model-integration/src/test/models/vespa/example.model13
7 files changed, 86 insertions, 43 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index 36cb8c4f1cf..90529ccdca0 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -204,7 +204,7 @@ public class ImportedModel implements ImportedMlModel {
ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>();
// Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to*
// in the model, as these are the names which must actually be bound, if we are to avoid creating an
- // "input mapping" to accommodate this complexity in
+ // "input mapping" to accommodate this complexity
for (Map.Entry<String, String> inputEntry : inputs().entrySet())
inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue()));
return inputs.build();
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
index ac462cc39eb..686cf6cd2df 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
@@ -1,11 +1,13 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.xgboost;
+import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.ModelImporter;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
@@ -22,7 +24,27 @@ public class XGBoostImporter extends ModelImporter {
File modelFile = new File(modelPath);
if ( ! modelFile.isFile()) return false;
- return modelFile.toString().endsWith(".json"); // No other models ends by json yet
+ return modelFile.toString().endsWith(".json") && probe(modelFile);
+ }
+
+    /**
+     * Returns true if the given file looks like an XGBoost json file.
+     * Currently, we just check that the first non-blank line starts a top-level array ('[').
+     */
+    private boolean probe(File modelFile) {
+        // try-with-resources: the reader is closed on every return path and on failure
+        try (BufferedReader reader = IOUtils.createReader(modelFile.getAbsolutePath())) {
+            String line;
+            while ((line = reader.readLine()) != null) {
+                line = line.trim();
+                if (line.startsWith("[")) return true;
+                if ( ! line.isEmpty()) return false;
+            }
+            return false;
+        }
+        catch (IOException e) {
+            throw new IllegalArgumentException("Could not read '" + modelFile + "'", e);
+        }
     }
@Override
diff --git a/model-integration/src/main/javacc/ModelParser.jj b/model-integration/src/main/javacc/ModelParser.jj
index 18dfb4c68ed..a5510dd89f3 100644
--- a/model-integration/src/main/javacc/ModelParser.jj
+++ b/model-integration/src/main/javacc/ModelParser.jj
@@ -23,13 +23,16 @@ PARSER_BEGIN(ModelParser)
package ai.vespa.rankingexpression.importer.vespa.parser;
+import java.io.File;
import java.io.Reader;
import java.io.StringReader;
import java.util.List;
import java.util.ArrayList;
import ai.vespa.rankingexpression.importer.ImportedModel;
+import com.yahoo.io.IOUtils;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.serialization.JsonFormat;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
/**
@@ -99,6 +102,7 @@ TOKEN :
| < #BRACE_ML_LEVEL_3: <BRACE_ML_CONTENT> "}" >
| < #BRACE_ML_CONTENT: (~["{","}"])* >
| < #SEARCHLIB_SKIP: ([" ","\f","\n","\r","\t"])+ >
+| < CONSTANT: "constant" >
| < CONSTANTS: "constants" >
| < FILE: "file" >
| < URI: "uri" >
@@ -147,7 +151,7 @@ void modelContent() :
{
}
{
- ( <NL> | input() | constants() | function() )*
+ ( <NL> | input() | constants() | largeConstant() | function() )*
}
/** Declared input variables (aka features). All non-scalar inputs must be declared. */
@@ -233,15 +237,6 @@ String tensorValue() :
}
}
-TensorType tensorTypeWithPrefix(String errorMessage) :
-{
- TensorType type;
-}
-{
- <TYPE> <COLON> type=tensorType(errorMessage)
- { return type; }
-}
-
TensorType tensorType(String errorMessage) :
{
String tensorTypeString;
@@ -259,47 +254,38 @@ TensorType tensorType(String errorMessage) :
}
}
-//----------------------------------------
-/** Consumes a constant block of model. */
-/*
+/** Consumes a large constant. */
void largeConstant() :
{
String name;
- RankingConstant constant;
+ Tensor value;
}
{
- ( <CONSTANT> name = identifier()
- {
-// constant = new RankingConstant(name);
- }
- lbrace() (rankingConstantItem(constant) (<NL>)*)+ <RBRACE> )
- {
- }
+ ( <CONSTANT> name = identifier() lbrace() value = largeConstantBody(name) <RBRACE> )
+ { model.largeConstant(name, value); }
}
-*/
-/** Consumes a constant block. */
-/*
-void rankingConstantItem(RankingConstant constant) :
+// TODO: Add support in ImportedModel for passing a large tensor through as a file/Uri pointer instead of reading it here
+Tensor largeConstantBody(String name) :
{
String path = null;
TensorType type = null;
}
{
- ( (<FILE> <COLON> path = filePath() { } (<NL>)*) { constant.setFileName(path); }
- | (<URI> <COLON> path = uriPath() { } (<NL>)*) { constant.setUri(path); }
- | type = tensorTypeWithPrefix(rankingConstantErrorMessage(constant.getName())) (<NL>)* { constant.setType(type); }
- )
+ ( <FILE> <COLON> path = filePath()
+// | (<URI> <COLON> path = uriPath() TODO
+ | <TYPE> <COLON> type = tensorType("Constant '" + name + "'")
+ | <NL>
+ )+
{
- return null;
+ try {
+ return JsonFormat.decode(type, IOUtils.readFileBytes(new File(new File(model.source()).getParent(), path)));
+ }
+ catch (Exception e) {
+ throw new IllegalArgumentException("Could not read constant '" + name + "'", e);
+ }
}
}
-*/
-
-String rankingConstantErrorMessage(String name) : {}
-{
- { return "For ranking constant ' " + name + "'"; }
-}
String filePath() : { }
{
@@ -312,7 +298,6 @@ String uriPath() : { }
( <URI_PATH> )
{ return token.image; }
}
-//----------------------------------------
/** Consumes an expression token and returns its image. */
String expression() :
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
index 1be2b7a4183..4c8890f6476 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
@@ -2,8 +2,11 @@
package ai.vespa.rankingexpression.importer.vespa;
import ai.vespa.rankingexpression.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
import org.junit.Test;
+import java.util.List;
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -25,8 +28,24 @@ public class VespaImportTestCase {
assertEquals("tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}", model.smallConstants().get("constant1"));
assertEquals("tensor():{3.0}", model.smallConstants().get("constant2"));
- assertEquals("max(reduce(input1 * input2, sum, name),x) * constant2",
- model.expressions().get("foo").getRoot().toString());
+ assertEquals(1, model.largeConstants().size());
+ assertEquals("tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}", model.largeConstants().get("constant1asLarge"));
+
+ assertEquals(2, model.expressions().size());
+ assertEquals("max(reduce(input1 * input2, sum, name) * constant1,x) * constant2",
+ model.expressions().get("foo1").getRoot().toString());
+ assertEquals("max(reduce(input1 * input2, sum, name) * constant1asLarge,x) * constant2",
+ model.expressions().get("foo2").getRoot().toString());
+
+ List<ImportedMlFunction> functions = model.outputExpressions();
+ assertEquals(2, functions.size());
+ ImportedMlFunction foo1Function = functions.get(0);
+ assertEquals(2, foo1Function.arguments().size());
+ assertTrue(foo1Function.arguments().contains("input1"));
+ assertTrue(foo1Function.arguments().contains("input2"));
+ assertEquals(2, foo1Function.argumentTypes().size());
+ assertEquals("tensor(name{},x[3])", foo1Function.argumentTypes().get("input1"));
+ assertEquals("tensor(x[3])", foo1Function.argumentTypes().get("input2"));
}
@Test
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
index 67a3b17255c..6d54b63db4b 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
@@ -22,6 +22,7 @@ public class XGBoostImportTestCase {
assertNotNull(expression);
assertEquals("if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)",
expression.getRoot().toString());
+ assertEquals(1, model.outputExpressions().size());
}
}
diff --git a/model-integration/src/test/models/vespa/constant1asLarge.json b/model-integration/src/test/models/vespa/constant1asLarge.json
new file mode 100644
index 00000000000..d2944d255af
--- /dev/null
+++ b/model-integration/src/test/models/vespa/constant1asLarge.json
@@ -0,0 +1,7 @@
+{
+ "cells": [
+ { "address": { "x": "0" }, "value": 0.5 },
+ { "address": { "x": "1" }, "value": 1.5 },
+ { "address": { "x": "2" }, "value": 2.5 }
+ ]
+} \ No newline at end of file
diff --git a/model-integration/src/test/models/vespa/example.model b/model-integration/src/test/models/vespa/example.model
index c0ea461db09..9579be4e44c 100644
--- a/model-integration/src/test/models/vespa/example.model
+++ b/model-integration/src/test/models/vespa/example.model
@@ -9,8 +9,17 @@ model example {
constant2: 3.0
}
- function foo() {
- expression: max(sum(input1 * input2, name), x) * constant2
+ constant constant1asLarge {
+ type: tensor(x[3])
+ file: constant1asLarge.json
+ }
+
+ function foo1() {
+ expression: max(sum(input1 * input2, name) * constant1, x) * constant2
+ }
+
+ function foo2() {
+ expression: max(sum(input1 * input2, name) * constant1asLarge, x) * constant2
}
} \ No newline at end of file