diff options
18 files changed, 524 insertions, 31 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml index c1300d3be12..536d3578f8c 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -77,8 +77,12 @@ <artifactId>maven-compiler-plugin</artifactId> <configuration> <compilerArgs> - <arg>-Xlint:rawtypes</arg> - <arg>-Xlint:unchecked</arg> + <arg>-Xlint:all</arg> + <arg>-Xlint:-rawtypes</arg> + <arg>-Xlint:-unchecked</arg> + <arg>-Xlint:-serial</arg> + <arg>-Xlint:-cast</arg> + <arg>-Xlint:-overloads</arg> <arg>-Werror</arg> </compilerArgs> </configuration> @@ -91,6 +95,10 @@ <groupId>com.yahoo.vespa</groupId> <artifactId>abi-check-plugin</artifactId> </plugin> + <plugin> + <groupId>com.helger.maven</groupId> + <artifactId>ph-javacc-maven-plugin</artifactId> + </plugin> </plugins> </build> diff --git a/model-integration/src/main/config/model-integration.xml b/model-integration/src/main/config/model-integration.xml index da45ce23575..90ec7d7275e 100644 --- a/model-integration/src/main/config/model-integration.xml +++ b/model-integration/src/main/config/model-integration.xml @@ -8,3 +8,4 @@ <component id="ai.vespa.rankingexpression.importer.onnx.OnnxImporter" bundle="model-integration" /> <component id="ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter" bundle="model-integration" /> <component id="ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter" bundle="model-integration" /> +<component id="ai.vespa.rankingexpression.importer.vespa.VespaImporter" bundle="model-integration" /> diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index d7ac8bc90b2..36cb8c4f1cf 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -4,7 +4,6 @@ package ai.vespa.rankingexpression.importer; import com.google.common.collect.ImmutableMap; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -184,7 +183,6 @@ public class ImportedModel implements ImportedMlModel { private final Map<String, String> inputs = new LinkedHashMap<>(); private final Map<String, String> outputs = new LinkedHashMap<>(); private final Map<String, String> skippedOutputs = new HashMap<>(); - private final List<String> importWarnings = new ArrayList<>(); Signature(String name) { this.name = name; @@ -206,7 +204,7 @@ public class ImportedModel implements ImportedMlModel { ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>(); // Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to* // in the model, as these are the names which must actually be bound, if we are to avoid creating an - // "input mapping" to accomodate this complexity in + // "input mapping" to accommodate this complexity in for (Map.Entry<String, String> inputEntry : inputs().entrySet()) inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue())); return inputs.build(); @@ -224,9 +222,6 @@ public class ImportedModel implements ImportedMlModel { */ public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); } - /** Returns an immutable list of possibly non-fatal warnings encountered during import. */ - public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } - /** Returns the expression this output references as an imported function */ public ImportedMlFunction outputFunction(String outputName, String functionName) { return new ImportedMlFunction(functionName, @@ -242,7 +237,6 @@ public class ImportedModel implements ImportedMlModel { void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } void output(String name, String expressionName) { outputs.put(name, expressionName); } void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } - void importWarning(String warning) { importWarnings.add(warning); } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index 54c19211277..99bfa08db43 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -187,8 +187,7 @@ public abstract class ModelImporter implements MlModelImporter { TensorFunction function = operation.rankingExpressionFunction().get(); try { model.function(operation.rankingExpressionFunctionName(), - new RankingExpression(operation.rankingExpressionFunctionName(), - function.toString())); + new RankingExpression(operation.rankingExpressionFunctionName(), function.toString())); } catch (ParseException e) { throw new RuntimeException("Model function " + function + @@ -210,7 +209,7 @@ public abstract class ModelImporter implements MlModelImporter { private static void reportWarnings(IntermediateOperation operation, ImportedModel model) { for (String warning : operation.warnings()) { - model.defaultSignature().importWarning(warning); + // If we want to report warnings, that code goes here } for (IntermediateOperation input : operation.inputs()) { reportWarnings(input, model); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java index 9c8f6238731..9115dc99b82 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java @@ -110,24 +110,25 @@ public class OrderedTensorType { } @Override - public boolean equals(Object obj) { - if (obj == null || !(obj instanceof OrderedTensorType)) { - return false; - } - OrderedTensorType other = (OrderedTensorType) obj; - if (dimensions.size() != dimensions.size()) { - return false; - } + public boolean equals(Object other) { + if (other == this) return true; + if ( ! (other instanceof OrderedTensorType)) return false; + List<TensorType.Dimension> thisDimensions = this.dimensions(); - List<TensorType.Dimension> otherDimensions = other.dimensions(); + List<TensorType.Dimension> otherDimensions = ((OrderedTensorType)other).dimensions(); + if (thisDimensions.size() != otherDimensions.size()) return false; + for (int i = 0; i < thisDimensions.size(); ++i) { - if (!thisDimensions.get(i).equals(otherDimensions.get(i))) { - return false; - } + if ( ! thisDimensions.get(i).equals(otherDimensions.get(i))) return false; } return true; } + @Override + public int hashCode() { + return type.hashCode(); + } + public OrderedTensorType rename(DimensionRenamer renamer) { List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size()); for (TensorType.Dimension dimension : dimensions) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/package-info.java index 5a844bb5773..3258426dac4 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/package-info.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/package-info.java @@ -1,5 +1,5 @@ /** - * The config models view of imported models. This API cannot be changed withoug taking earlier config models + * The config models view of imported models. This API cannot be changed without taking earlier config models * into account, not even on major versions. */ @ExportPackage diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java index 27b80157d74..45ac2b16e97 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java @@ -1,5 +1 @@ -// TODO: Don't export after November 2018 -@ExportPackage package ai.vespa.rankingexpression.importer; - -import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java new file mode 100644 index 00000000000..021fa1f7e51 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java @@ -0,0 +1,40 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.vespa; + +import ai.vespa.rankingexpression.importer.ImportedModel; +import ai.vespa.rankingexpression.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.vespa.parser.ModelParser; + +import ai.vespa.rankingexpression.importer.vespa.parser.ParseException; +import ai.vespa.rankingexpression.importer.vespa.parser.SimpleCharStream; +import com.yahoo.io.IOUtils; + +import java.io.File; +import java.io.IOException; + +/** + * Imports a model from a Vespa native ranking expression "model" file + */ +public class VespaImporter extends ModelImporter { + + @Override + public boolean canImport(String modelPath) { + File modelFile = new File(modelPath); + if ( ! modelFile.isFile()) return false; + + return modelFile.toString().endsWith(".model"); + } + + @Override + public ImportedModel importModel(String modelName, String modelPath) { + try { + ImportedModel model = new ImportedModel(modelName, modelPath); + new ModelParser(new SimpleCharStream(IOUtils.readFile(new File(modelPath))), model).model(); + return model; + } + catch (IOException | ParseException e) { + throw new IllegalArgumentException("Could not import a Vespa model from '" + modelPath + "'", e); + } + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/package-info.java new file mode 100644 index 00000000000..76c7ad6a134 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/package-info.java @@ -0,0 +1,10 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +/** + * A model imported from Vespa ranking expressions + */ +@ExportPackage +package ai.vespa.rankingexpression.importer.vespa; + +import com.yahoo.osgi.annotation.ExportPackage; + diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/parser/SimpleCharStream.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/parser/SimpleCharStream.java new file mode 100644 index 00000000000..8db9577a66c --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/parser/SimpleCharStream.java @@ -0,0 +1,12 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.vespa.parser; + +import com.yahoo.javacc.FastCharStream; + +public class SimpleCharStream extends FastCharStream implements ai.vespa.rankingexpression.importer.vespa.parser.CharStream { + + public SimpleCharStream(String input) { + super(input); + } + +} diff --git a/model-integration/src/main/javacc/ModelParser.jj b/model-integration/src/main/javacc/ModelParser.jj new file mode 100644 index 00000000000..7604259e850 --- /dev/null +++ b/model-integration/src/main/javacc/ModelParser.jj @@ -0,0 +1,352 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// -------------------------------------------------------------------------------- +// +// JavaCC options. When this file is changed, run "mvn generate-sources" to rebuild +// the parser classes. +// +// -------------------------------------------------------------------------------- +options { + UNICODE_INPUT = true; + CACHE_TOKENS = false; + DEBUG_PARSER = false; + ERROR_REPORTING = true; + FORCE_LA_CHECK = true; + USER_CHAR_STREAM = true; +} + +// -------------------------------------------------------------------------------- +// +// Parser body. +// +// -------------------------------------------------------------------------------- +PARSER_BEGIN(ModelParser) + +package ai.vespa.rankingexpression.importer.vespa.parser; + +import java.io.Reader; +import java.io.StringReader; +import java.util.List; +import java.util.ArrayList; +import ai.vespa.rankingexpression.importer.ImportedModel; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.Tensor; +import com.yahoo.searchlib.rankingexpression.RankingExpression; + +/** + * Parser of Vespa ML model files: Ranking expression functions enclosed in brackets. + * + * @author bratseth + */ +public class ModelParser { + + /** The model we are importing into */ + private ImportedModel model; + + /** Creates a parser of a string */ + public ModelParser(String input, ImportedModel model) { + this(new SimpleCharStream(input), model); + } + + /** Creates a parser */ + public ModelParser(SimpleCharStream input, ImportedModel model) { + this(input); + this.model = model; + } + +} + +PARSER_END(ModelParser) + + +// -------------------------------------------------------------------------------- +// +// Token declarations. +// +// -------------------------------------------------------------------------------- + +// Declare white space characters. These do not include newline because it has +// special meaning in several of the production rules. +SKIP : +{ + " " | "\t" | "\r" | "\f" +} + +// Declare all tokens to be recognized. When a word token is added it MUST be +// added to the identifier() production rule. +TOKEN : +{ + < NL: "\n" > +| < FUNCTION: "function" > +| < TENSOR_TYPE: "tensor(" (~["(",")"])+ ")" > +| < TENSOR_VALUE_SL: "value" (" ")* ":" (" ")* ("{"<BRACE_SL_LEVEL_1>) ("\n")? > +| < TENSOR_VALUE_ML: "value" (<SEARCHLIB_SKIP>)? "{" (["\n"," "])* ("{"<BRACE_ML_LEVEL_1>) (["\n"," "])* "}" ("\n")? > +| < LBRACE: "{" > +| < RBRACE: "}" > +| < COLON: ":" > +| < DOT: "." > +| < COMMA: "," > +| < MODEL: "model" > +| < TYPE: "type" > +| < EXPRESSION_SL: "expression" (" ")* ":" (("{"<BRACE_SL_LEVEL_1>)|<BRACE_SL_CONTENT>)* ("\n")? > +| < EXPRESSION_ML: "expression" (<SEARCHLIB_SKIP>)? "{" (("{"<BRACE_ML_LEVEL_1>)|<BRACE_ML_CONTENT>)* "}" > +| < #BRACE_SL_LEVEL_1: (("{"<BRACE_SL_LEVEL_2>)|<BRACE_SL_CONTENT>)* "}" > +| < #BRACE_SL_LEVEL_2: (("{"<BRACE_SL_LEVEL_3>)|<BRACE_SL_CONTENT>)* "}" > +| < #BRACE_SL_LEVEL_3: <BRACE_SL_CONTENT> "}" > +| < #BRACE_SL_CONTENT: (~["{","}","\n"])* > +| < #BRACE_ML_LEVEL_1: (("{"<BRACE_ML_LEVEL_2>)|<BRACE_ML_CONTENT>)* "}" > +| < #BRACE_ML_LEVEL_2: (("{"<BRACE_ML_LEVEL_3>)|<BRACE_ML_CONTENT>)* "}" > +| < #BRACE_ML_LEVEL_3: <BRACE_ML_CONTENT> "}" > +| < #BRACE_ML_CONTENT: (~["{","}"])* > +| < #SEARCHLIB_SKIP: ([" ","\f","\n","\r","\t"])+ > +| < CONSTANTS: "constants" > +| < FILE: "file" > +| < URI: "uri" > +| < IDENTIFIER: ["a"-"z","A"-"Z", "_"] (["a"-"z","A"-"Z","0"-"9","_"])* > +| < CONTEXT: ["a"-"z","A"-"Z"] (["a"-"z", "A"-"Z", "0"-"9"])* > +| < DOUBLE: ("-")? (["0"-"9"])+ "." (["0"-"9"])+ > +| < STRING: (["a"-"z","A"-"Z","_","0"-"9","."])+ > +| < FILE_PATH: ["a"-"z","A"-"Z", "_"] (["a"-"z","A"-"Z","0"-"9","_","-", "/", "."])+ > +| < HTTP: ["h","H"] ["t","T"] ["t","T"] ["p","P"] (["s","S"])? > +| < URI_PATH: <HTTP> <COLON> ("//")? (["a"-"z","A"-"Z","0"-"9","_","-", "/", ".",":"])+ > +} + +// Declare a special skip token for comments. +SPECIAL_TOKEN : +{ + <SINGLE_LINE_COMMENT: "#" (~["\n","\r"])* > +} + + +// -------------------------------------------------------------------------------- +// +// Production rules. +// +// -------------------------------------------------------------------------------- + +void model() : +{ + String name; +} +{ + (<NL>)* + <MODEL> + (<NL>)* + name = identifier() + (<NL>)* + <LBRACE> modelContent() <RBRACE> + (<NL>)* + <EOF> + { + if ( ! name.equals(model.name())) + throw new IllegalArgumentException("Model '" + name + "' must be saved in a file named '" + name + ".model'"); + } +} + +void modelContent() : +{ +} +{ + ( <NL> | input() | function() )* +} + +/** Declared input variables (aka features). All non-scalar inputs must be declared. */ +void input() : +{ + String name; + TensorType type; +} +{ + name = identifier() <COLON> type = tensorType("Input parameter '" + name + "'") + { model.input(name, type); } +} + +/** A function */ +void function() : +{ + String name, expression, parameter; + List parameters = new ArrayList(); +} +{ + ( <FUNCTION> name = identifier() + "(" + [ parameter = identifier() { parameters.add(parameter); } + ( <COMMA> parameter = identifier() { parameters.add(parameter); } )* ] + ")" + lbrace() expression = expression() (<NL>)* <RBRACE> ) + { + try { + model.expression(name, new RankingExpression(expression)); + } + catch (com.yahoo.searchlib.rankingexpression.parser.ParseException e) { + throw new IllegalArgumentException("Could not parse function '" + name + "'", e); + } + } +} + +/** Consumes a constant block of model. */ +/* +void rankingConstant() : +{ + String name; + RankingConstant constant; +} +{ + ( <CONSTANT> name = identifier() + { +// constant = new RankingConstant(name); + } + lbrace() (rankingConstantItem(constant) (<NL>)*)+ <RBRACE> ) + { + } +} +*/ + +/** Consumes a constant block. */ +/* +void rankingConstantItem(RankingConstant constant) : +{ + String path = null; + TensorType type = null; +} +{ + ( (<FILE> <COLON> path = filePath() { } (<NL>)*) { constant.setFileName(path); } + | (<URI> <COLON> path = uriPath() { } (<NL>)*) { constant.setUri(path); } + | type = tensorTypeWithPrefix(rankingConstantErrorMessage(constant.getName())) (<NL>)* { constant.setType(type); } + ) + { + return null; + } +} +*/ + +String rankingConstantErrorMessage(String name) : {} +{ + { return "For ranking constant ' " + name + "'"; } +} + +String filePath() : { } +{ + ( <FILE_PATH> | <STRING> | <IDENTIFIER>) + { return token.image; } +} + +String uriPath() : { } +{ + ( <URI_PATH> ) + { return token.image; } +} + +/** Consumes the constants of this model. */ +void constants(ImportedModel model) : +{ + String name; +} +{ + <CONSTANTS> <LBRACE> (<NL>)* + ( name = identifier() ( constantDouble(name) | + constantTensor(name) ) (<NL>)* )* + <RBRACE> +} + +void constantDouble(String name) : +{ + Token value; +} +{ + <COLON> value = <DOUBLE> { model.smallConstant(name, Tensor.from(Double.parseDouble(token.image))); } +} + +void constantTensor(String name) : +{ + String tensorString = ""; + TensorType tensorType = null; +} +{ + <LBRACE> (<NL>)* + (( tensorString = tensorValue() | + tensorType = tensorTypeWithPrefix(constantTensorErrorMessage(model.name(), name)) ) (<NL>)* )* <RBRACE> + { + if (tensorType != null) { + model.smallConstant(name, Tensor.from(tensorType, tensorString)); + } else { + model.smallConstant(name, Tensor.from(tensorString)); + } + } +} + +String constantTensorErrorMessage(String model, String constantTensorName) : {} +{ + { return "For constant tensor '" + constantTensorName + "' in model '" + model + "'"; } +} + +String tensorValue() : +{ + String tensor; +} +{ + ( <TENSOR_VALUE_SL> { tensor = token.image.substring(token.image.indexOf(":") + 1); } | + <TENSOR_VALUE_ML> { tensor = token.image.substring(token.image.indexOf("{") + 1, + token.image.lastIndexOf("}")); } ) + { + return tensor; + } +} + +TensorType tensorTypeWithPrefix(String errorMessage) : +{ + TensorType type; +} +{ + <TYPE> <COLON> type= tensorType(errorMessage) + { return type; } +} + +TensorType tensorType(String errorMessage) : +{ + String tensorTypeString; +} +{ + <TENSOR_TYPE> { tensorTypeString = token.image; } + { + TensorType tensorType; + try { + tensorType = TensorType.fromSpec(tensorTypeString); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(errorMessage + ": Illegal tensor type spec: " + e.getMessage()); + } + return tensorType; + } +} + +/** Consumes an expression token and returns its image. */ +String expression() : +{ + String exp; +} +{ + ( <EXPRESSION_SL> { exp = token.image.substring(token.image.indexOf(":") + 1); } | + <EXPRESSION_ML> { exp = token.image.substring(token.image.indexOf("{") + 1, + token.image.lastIndexOf("}")); } ) + { return exp; } +} + +/** Consumes an identifier. This must be kept in sync with all word tokens that should be parseable as identifiers. */ +String identifier() : { } +{ + ( + <IDENTIFIER> + | <DOUBLE> + | <FILE> + | <URI> + | <MODEL> + | <TYPE> + | <CONSTANTS> + ) + { return token.image; } +} + +/** Consumes an opening brace with leading and trailing newline tokens. */ +void lbrace() : { } +{ + (<NL>)* <LBRACE> (<NL>)* +} diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java new file mode 100644 index 00000000000..4f9fb9c070a --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java @@ -0,0 +1,58 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.vespa; + +import ai.vespa.rankingexpression.importer.ImportedModel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * @author bratseth + */ +public class VespaImportTestCase { + + @Test + public void testExample() { + ImportedModel model = importModel("example"); + + assertEquals(1, model.inputs().size()); + assertEquals("tensor(name{},x[10])", model.inputs().get("input1").toString()); + + assertEquals("var1 * var2", model.expressions().get("foo").getRoot().toString()); + } + + @Test + public void testEmpty() { + ImportedModel model = importModel("empty"); + assertTrue(model.expressions().isEmpty()); + assertTrue(model.functions().isEmpty()); + assertTrue(model.inputs().isEmpty()); + assertTrue(model.largeConstants().isEmpty()); + assertTrue(model.smallConstants().isEmpty()); + } + + @Test + public void testWrongName() { + try { + importModel("misnamed"); + fail("Expected exception"); + } + catch (IllegalArgumentException e) { + assertEquals("Model 'expectedname' must be saved in a file named 'expectedname.model'", e.getMessage()); + } + } + + private ImportedModel importModel(String name) { + String modelPath = "src/test/models/vespa/" + name + ".model"; + + VespaImporter importer = new VespaImporter(); + assertTrue(importer.canImport(modelPath)); + ImportedModel model = new VespaImporter().importModel(name, modelPath); + assertEquals(name, model.name()); + assertEquals(modelPath, model.source()); + return model; + } + +} diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java index 965d5eb8577..67a3b17255c 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java @@ -18,7 +18,6 @@ public class XGBoostImportTestCase { ImportedModel model = new XGBoostImporter().importModel("test", "src/test/models/xgboost/xgboost.2.2.json"); assertTrue("All inputs are scalar", model.inputs().isEmpty()); assertEquals(1, model.expressions().size()); - System.out.println(model.expressions().keySet()); RankingExpression expression = model.expressions().get("test"); assertNotNull(expression); assertEquals("if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)", diff --git a/model-integration/src/test/models/vespa/empty.model b/model-integration/src/test/models/vespa/empty.model new file mode 100644 index 00000000000..f5381b2ba93 --- /dev/null +++ b/model-integration/src/test/models/vespa/empty.model @@ -0,0 +1,2 @@ +model empty { +}
\ No newline at end of file diff --git a/model-integration/src/test/models/vespa/example.model b/model-integration/src/test/models/vespa/example.model new file mode 100644 index 00000000000..19598690aad --- /dev/null +++ b/model-integration/src/test/models/vespa/example.model @@ -0,0 +1,10 @@ +model example { + + input1: tensor(name{}, x[10]) + + + function foo() { + expression: var1 * var2 + } + +}
\ No newline at end of file diff --git a/model-integration/src/test/models/vespa/misnamed.model b/model-integration/src/test/models/vespa/misnamed.model new file mode 100644 index 00000000000..44bfa5e380d --- /dev/null +++ b/model-integration/src/test/models/vespa/misnamed.model @@ -0,0 +1,3 @@ +model expectedname { + +}
\ No newline at end of file diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 4f81f3baea8..c31eed32830 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1176,7 +1176,8 @@ "public static boolean approxEquals(double, double)", "public static com.yahoo.tensor.Tensor from(com.yahoo.tensor.TensorType, java.lang.String)", "public static com.yahoo.tensor.Tensor from(java.lang.String, java.lang.String)", - "public static com.yahoo.tensor.Tensor from(java.lang.String)" + "public static com.yahoo.tensor.Tensor from(java.lang.String)", + "public static com.yahoo.tensor.Tensor from(double)" ], "fields": [] }, diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index ebb341147cf..22ff793e6fa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -367,6 +367,13 @@ public interface Tensor { return TensorParser.tensorFrom(tensorString, Optional.empty()); } + /** + * Returns a double as a tensor: A dimensionless tensor containing the value as its cell + */ + static Tensor from(double value) { + return Tensor.Builder.of(TensorType.empty).cell(value).build(); + } + class Cell implements Map.Entry<TensorAddress, Double> { private final TensorAddress address; |