diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-06-02 13:25:45 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-06-02 13:25:45 +0200 |
commit | c25c8a52e2328bcff2f5a35496e7568ee5a7c752 (patch) | |
tree | cd624363ad22b7a2b6a76e41bd27c0cd7f5169d7 /model-integration/src | |
parent | e9e5a422c0aa6364c3c5f7b9da53e9fcf9a5f0f8 (diff) |
Vespa global model import
Diffstat (limited to 'model-integration/src')
7 files changed, 86 insertions, 43 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index 36cb8c4f1cf..90529ccdca0 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -204,7 +204,7 @@ public class ImportedModel implements ImportedMlModel { ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>(); // Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to* // in the model, as these are the names which must actually be bound, if we are to avoid creating an - // "input mapping" to accommodate this complexity in + // "input mapping" to accommodate this complexity for (Map.Entry<String, String> inputEntry : inputs().entrySet()) inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue())); return inputs.build(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java index ac462cc39eb..686cf6cd2df 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java @@ -1,11 +1,13 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer.xgboost; +import com.yahoo.io.IOUtils; import com.yahoo.searchlib.rankingexpression.RankingExpression; import ai.vespa.rankingexpression.importer.ImportedModel; import ai.vespa.rankingexpression.importer.ModelImporter; import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import java.io.BufferedReader; import java.io.File; import java.io.IOException; @@ -22,7 +24,27 @@ public class XGBoostImporter extends ModelImporter { File modelFile = new File(modelPath); if ( ! modelFile.isFile()) return false; - return modelFile.toString().endsWith(".json"); // No other models ends by json yet + return modelFile.toString().endsWith(".json") && probe(modelFile); + } + + /** + * Returns true if the give file looks like an XGBoost json file. + * Currently, we just check if the file has an array on the top level. + */ + private boolean probe(File modelFile) { + try { + BufferedReader reader = IOUtils.createReader(modelFile.getAbsolutePath()); + String line; + while ((line = reader.readLine()) != null) { + line = line.trim(); + if (line.startsWith("[")) return true; + if ( ! line.isEmpty()) return false; + } + return false; + } + catch (IOException e) { + throw new IllegalArgumentException("Could not read '" + modelFile + "'", e); + } } @Override diff --git a/model-integration/src/main/javacc/ModelParser.jj b/model-integration/src/main/javacc/ModelParser.jj index 18dfb4c68ed..a5510dd89f3 100644 --- a/model-integration/src/main/javacc/ModelParser.jj +++ b/model-integration/src/main/javacc/ModelParser.jj @@ -23,13 +23,16 @@ PARSER_BEGIN(ModelParser) package ai.vespa.rankingexpression.importer.vespa.parser; +import java.io.File; import java.io.Reader; import java.io.StringReader; import java.util.List; import java.util.ArrayList; import ai.vespa.rankingexpression.importer.ImportedModel; +import com.yahoo.io.IOUtils; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.serialization.JsonFormat; import com.yahoo.searchlib.rankingexpression.RankingExpression; /** @@ -99,6 +102,7 @@ TOKEN : | < #BRACE_ML_LEVEL_3: <BRACE_ML_CONTENT> "}" > | < #BRACE_ML_CONTENT: (~["{","}"])* > | < #SEARCHLIB_SKIP: ([" ","\f","\n","\r","\t"])+ > +| < CONSTANT: "constant" > | < CONSTANTS: "constants" > | < FILE: "file" > | < URI: "uri" > @@ -147,7 +151,7 @@ void modelContent() : { } { - ( <NL> | input() | constants() | function() )* + ( <NL> | input() | constants() | largeConstant() | function() )* } /** Declared input variables (aka features). All non-scalar inputs must be declared. */ @@ -233,15 +237,6 @@ String tensorValue() : } } -TensorType tensorTypeWithPrefix(String errorMessage) : -{ - TensorType type; -} -{ - <TYPE> <COLON> type=tensorType(errorMessage) - { return type; } -} - TensorType tensorType(String errorMessage) : { String tensorTypeString; @@ -259,47 +254,38 @@ TensorType tensorType(String errorMessage) : } } -//---------------------------------------- -/** Consumes a constant block of model. */ -/* +/** Consumes a large constant. */ void largeConstant() : { String name; - RankingConstant constant; + Tensor value; } { - ( <CONSTANT> name = identifier() - { -// constant = new RankingConstant(name); - } - lbrace() (rankingConstantItem(constant) (<NL>)*)+ <RBRACE> ) - { - } + ( <CONSTANT> name = identifier() lbrace() value = largeConstantBody(name) <RBRACE> ) + { model.largeConstant(name, value); } } -*/ -/** Consumes a constant block. */ -/* -void rankingConstantItem(RankingConstant constant) : +// TODO: Add support in ImportedModel for passing a large tensor through as a file/Uri pointer instead of reading it here +Tensor largeConstantBody(String name) : { String path = null; TensorType type = null; } { - ( (<FILE> <COLON> path = filePath() { } (<NL>)*) { constant.setFileName(path); } - | (<URI> <COLON> path = uriPath() { } (<NL>)*) { constant.setUri(path); } - | type = tensorTypeWithPrefix(rankingConstantErrorMessage(constant.getName())) (<NL>)* { constant.setType(type); } - ) + ( <FILE> <COLON> path = filePath() +// | (<URI> <COLON> path = uriPath() TODO + | <TYPE> <COLON> type = tensorType("Constant '" + name + "'") + | <NL> + )+ { - return null; + try { + return JsonFormat.decode(type, IOUtils.readFileBytes(new File(new File(model.source()).getParent(), path))); + } + catch (Exception e) { + throw new IllegalArgumentException("Could not read constant '" + name + "'", e); + } } } -*/ - -String rankingConstantErrorMessage(String name) : {} -{ - { return "For ranking constant ' " + name + "'"; } -} String filePath() : { } { @@ -312,7 +298,6 @@ String uriPath() : { } ( <URI_PATH> ) { return token.image; } } -//---------------------------------------- /** Consumes an expression token and returns its image. */ String expression() : diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java index 1be2b7a4183..4c8890f6476 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java @@ -2,8 +2,11 @@ package ai.vespa.rankingexpression.importer.vespa; import ai.vespa.rankingexpression.importer.ImportedModel; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; import org.junit.Test; +import java.util.List; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -25,8 +28,24 @@ public class VespaImportTestCase { assertEquals("tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}", model.smallConstants().get("constant1")); assertEquals("tensor():{3.0}", model.smallConstants().get("constant2")); - assertEquals("max(reduce(input1 * input2, sum, name),x) * constant2", - model.expressions().get("foo").getRoot().toString()); + assertEquals(1, model.largeConstants().size()); + assertEquals("tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}", model.largeConstants().get("constant1asLarge")); + + assertEquals(2, model.expressions().size()); + assertEquals("max(reduce(input1 * input2, sum, name) * constant1,x) * constant2", + model.expressions().get("foo1").getRoot().toString()); + assertEquals("max(reduce(input1 * input2, sum, name) * constant1asLarge,x) * constant2", + model.expressions().get("foo2").getRoot().toString()); + + List<ImportedMlFunction> functions = model.outputExpressions(); + assertEquals(2, functions.size()); + ImportedMlFunction foo1Function = functions.get(0); + assertEquals(2, foo1Function.arguments().size()); + assertTrue(foo1Function.arguments().contains("input1")); + assertTrue(foo1Function.arguments().contains("input2")); + assertEquals(2, foo1Function.argumentTypes().size()); + assertEquals("tensor(name{},x[3])", foo1Function.argumentTypes().get("input1")); + assertEquals("tensor(x[3])", foo1Function.argumentTypes().get("input2")); } @Test diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java index 67a3b17255c..6d54b63db4b 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java @@ -22,6 +22,7 @@ public class XGBoostImportTestCase { assertNotNull(expression); assertEquals("if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)", expression.getRoot().toString()); + assertEquals(1, model.outputExpressions().size()); } } diff --git a/model-integration/src/test/models/vespa/constant1asLarge.json b/model-integration/src/test/models/vespa/constant1asLarge.json new file mode 100644 index 00000000000..d2944d255af --- /dev/null +++ b/model-integration/src/test/models/vespa/constant1asLarge.json @@ -0,0 +1,7 @@ +{ + "cells": [ + { "address": { "x": "0" }, "value": 0.5 }, + { "address": { "x": "1" }, "value": 1.5 }, + { "address": { "x": "2" }, "value": 2.5 } + ] +}
\ No newline at end of file diff --git a/model-integration/src/test/models/vespa/example.model b/model-integration/src/test/models/vespa/example.model index c0ea461db09..9579be4e44c 100644 --- a/model-integration/src/test/models/vespa/example.model +++ b/model-integration/src/test/models/vespa/example.model @@ -9,8 +9,17 @@ model example { constant2: 3.0 } - function foo() { - expression: max(sum(input1 * input2, name), x) * constant2 + constant constant1asLarge { + type: tensor(x[3]) + file: constant1asLarge.json + } + + function foo1() { + expression: max(sum(input1 * input2, name) * constant1, x) * constant2 + } + + function foo2() { + expression: max(sum(input1 * input2, name) * constant1asLarge, x) * constant2 } }
\ No newline at end of file |