summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/test/integration/vespa/models/constant1asLarge.json7
-rw-r--r--config-model/src/test/integration/vespa/models/example.model25
-rw-r--r--config-model/src/test/integration/vespa/services.xml6
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java77
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java6
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorReader.java3
-rw-r--r--model-integration/pom.xml12
-rw-r--r--model-integration/src/main/config/model-integration.xml1
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java25
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/package-info.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java40
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/package-info.java10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/parser/SimpleCharStream.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java24
-rw-r--r--model-integration/src/main/javacc/ModelParser.jj332
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java83
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java2
-rw-r--r--model-integration/src/test/models/vespa/constant1asLarge.json7
-rw-r--r--model-integration/src/test/models/vespa/empty.model2
-rw-r--r--model-integration/src/test/models/vespa/example.model25
-rw-r--r--model-integration/src/test/models/vespa/misnamed.model3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java3
-rw-r--r--vespajlib/abi-spec.json3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java23
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java4
29 files changed, 724 insertions, 37 deletions
diff --git a/config-model/src/test/integration/vespa/models/constant1asLarge.json b/config-model/src/test/integration/vespa/models/constant1asLarge.json
new file mode 100644
index 00000000000..d2944d255af
--- /dev/null
+++ b/config-model/src/test/integration/vespa/models/constant1asLarge.json
@@ -0,0 +1,7 @@
+{
+ "cells": [
+ { "address": { "x": "0" }, "value": 0.5 },
+ { "address": { "x": "1" }, "value": 1.5 },
+ { "address": { "x": "2" }, "value": 2.5 }
+ ]
+} \ No newline at end of file
diff --git a/config-model/src/test/integration/vespa/models/example.model b/config-model/src/test/integration/vespa/models/example.model
new file mode 100644
index 00000000000..9579be4e44c
--- /dev/null
+++ b/config-model/src/test/integration/vespa/models/example.model
@@ -0,0 +1,25 @@
+model example {
+
+ # All inputs that are not scalar (aka 0-dimensional tensor) must be declared
+ input1: tensor(name{}, x[3])
+ input2: tensor(x[3])
+
+ constants {
+ constant1: tensor(x[3]):{{x:0}:0.5, {x:1}:1.5, {x:2}:2.5}
+ constant2: 3.0
+ }
+
+ constant constant1asLarge {
+ type: tensor(x[3])
+ file: constant1asLarge.json
+ }
+
+ function foo1() {
+ expression: max(sum(input1 * input2, name) * constant1, x) * constant2
+ }
+
+ function foo2() {
+ expression: max(sum(input1 * input2, name) * constant1asLarge, x) * constant2
+ }
+
+} \ No newline at end of file
diff --git a/config-model/src/test/integration/vespa/services.xml b/config-model/src/test/integration/vespa/services.xml
new file mode 100644
index 00000000000..aa1c0223bdf
--- /dev/null
+++ b/config-model/src/test/integration/vespa/services.xml
@@ -0,0 +1,6 @@
+<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
+<services>
+ <container version="1.0">
+
+ </container>
+</services>
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java
new file mode 100644
index 00000000000..a75699d2a1d
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java
@@ -0,0 +1,77 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.searchdefinition.derived.RawRankProfile;
+import com.yahoo.vespa.model.VespaModel;
+import com.yahoo.vespa.model.ml.ImportedModelTester;
+import org.junit.After;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Optional;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests adding Vespa ranking expression based models in the models/ dir
+ *
+ * @author bratseth
+ */
+public class VespaMlModelTestCase {
+
+ private final Path applicationDir = Path.fromString("src/test/integration/vespa/");
+
+ private final String expectedRankConfig =
+ "constant(constant1).type : tensor(x[3])\n" +
+ "constant(constant1).value : tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}\n" +
+ "rankingExpression(foo1).rankingScript : reduce(reduce(input1 * input2, sum, name) * constant(constant1), max, x) * 3.0\n" +
+ "rankingExpression(foo1).input2.type : tensor(x[3])\n" +
+ "rankingExpression(foo1).input1.type : tensor(name{},x[3])\n" +
+ "rankingExpression(foo2).rankingScript : max(reduce(input1 * input2, sum, name) * constant1asLarge,x) * 3.0\n" +
+ "rankingExpression(foo2).input2.type : tensor(x[3])\n" +
+ "rankingExpression(foo2).input1.type : tensor(name{},x[3])\n";
+
+ /** The model name */
+ private final String name = "example";
+
+ @After
+ public void removeGeneratedModelFiles() {
+ IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ }
+
+ @Test
+ public void testGlobalVespaModel() throws IOException {
+ ImportedModelTester tester = new ImportedModelTester(name, applicationDir);
+ VespaModel model = tester.createVespaModel();
+ tester.assertLargeConstant("constant1asLarge", model, Optional.of(3L));
+ assertEquals(expectedRankConfig, rankConfigOf("example", model));
+
+ // At this point the expression is stored - copy application to another location which do not have a models dir
+ Path storedAppDir = applicationDir.append("copy");
+ try {
+ storedAppDir.toFile().mkdirs();
+ IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString());
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir);
+ VespaModel storedModel = storedTester.createVespaModel();
+ storedTester.assertLargeConstant("constant1asLarge", model, Optional.of(3L));
+ assertEquals(expectedRankConfig, rankConfigOf("example", storedModel));
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedAppDir.toFile());
+ }
+ }
+
+ private String rankConfigOf(String rankProfileName, VespaModel model) {
+ StringBuilder b = new StringBuilder();
+ RawRankProfile profile = model.rankProfileList().getRankProfile(rankProfileName);
+ for (var property : profile.configProperties())
+ b.append(property.getFirst()).append(" : ").append(property.getSecond()).append("\n");
+ return b.toString();
+ }
+
+}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
index 563572b4af6..41811738ea4 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.ml;
+import ai.vespa.rankingexpression.importer.vespa.VespaImporter;
import com.google.common.collect.ImmutableList;
import com.yahoo.config.model.ApplicationPackageTester;
import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
@@ -8,10 +9,12 @@ import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
+import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
+import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.model.VespaModel;
@@ -34,7 +37,8 @@ public class ImportedModelTester {
private final ImmutableList<MlModelImporter> importers = ImmutableList.of(new TensorFlowImporter(),
new OnnxImporter(),
- new XGBoostImporter());
+ new XGBoostImporter(),
+ new VespaImporter());
private final String modelName;
private final Path applicationDir;
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
index 9a1a37caade..a3d2a157073 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
@@ -10,7 +10,7 @@ import static com.yahoo.document.json.readers.JsonParserHelpers.*;
/**
* Reads the tensor format described at
- * http://docs.vespa.ai/documentation/reference/document-json-put-format.html#tensor
+ * http://docs.vespa.ai/documentation/reference/document-json-format.html#tensor
*/
public class TensorReader {
@@ -20,6 +20,7 @@ public class TensorReader {
public static final String TENSOR_VALUE = "value";
public static void fillTensor(TokenBuffer buffer, TensorFieldValue tensorFieldValue) {
+ // TODO: Switch implementation to om.yahoo.tensor.serialization.JsonFormat.decode
Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType());
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index c1300d3be12..536d3578f8c 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -77,8 +77,12 @@
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<compilerArgs>
- <arg>-Xlint:rawtypes</arg>
- <arg>-Xlint:unchecked</arg>
+ <arg>-Xlint:all</arg>
+ <arg>-Xlint:-rawtypes</arg>
+ <arg>-Xlint:-unchecked</arg>
+ <arg>-Xlint:-serial</arg>
+ <arg>-Xlint:-cast</arg>
+ <arg>-Xlint:-overloads</arg>
<arg>-Werror</arg>
</compilerArgs>
</configuration>
@@ -91,6 +95,10 @@
<groupId>com.yahoo.vespa</groupId>
<artifactId>abi-check-plugin</artifactId>
</plugin>
+ <plugin>
+ <groupId>com.helger.maven</groupId>
+ <artifactId>ph-javacc-maven-plugin</artifactId>
+ </plugin>
</plugins>
</build>
diff --git a/model-integration/src/main/config/model-integration.xml b/model-integration/src/main/config/model-integration.xml
index da45ce23575..90ec7d7275e 100644
--- a/model-integration/src/main/config/model-integration.xml
+++ b/model-integration/src/main/config/model-integration.xml
@@ -8,3 +8,4 @@
<component id="ai.vespa.rankingexpression.importer.onnx.OnnxImporter" bundle="model-integration" />
<component id="ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter" bundle="model-integration" />
<component id="ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter" bundle="model-integration" />
+<component id="ai.vespa.rankingexpression.importer.vespa.VespaImporter" bundle="model-integration" />
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index d7ac8bc90b2..90529ccdca0 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -4,7 +4,6 @@ package ai.vespa.rankingexpression.importer;
import com.google.common.collect.ImmutableMap;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
-import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -184,7 +183,6 @@ public class ImportedModel implements ImportedMlModel {
private final Map<String, String> inputs = new LinkedHashMap<>();
private final Map<String, String> outputs = new LinkedHashMap<>();
private final Map<String, String> skippedOutputs = new HashMap<>();
- private final List<String> importWarnings = new ArrayList<>();
Signature(String name) {
this.name = name;
@@ -206,7 +204,7 @@ public class ImportedModel implements ImportedMlModel {
ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>();
// Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to*
// in the model, as these are the names which must actually be bound, if we are to avoid creating an
- // "input mapping" to accomodate this complexity in
+ // "input mapping" to accommodate this complexity
for (Map.Entry<String, String> inputEntry : inputs().entrySet())
inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue()));
return inputs.build();
@@ -224,9 +222,6 @@ public class ImportedModel implements ImportedMlModel {
*/
public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
- /** Returns an immutable list of possibly non-fatal warnings encountered during import. */
- public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
-
/** Returns the expression this output references as an imported function */
public ImportedMlFunction outputFunction(String outputName, String functionName) {
return new ImportedMlFunction(functionName,
@@ -242,7 +237,6 @@ public class ImportedModel implements ImportedMlModel {
void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
void output(String name, String expressionName) { outputs.put(name, expressionName); }
void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
- void importWarning(String warning) { importWarnings.add(warning); }
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index 54c19211277..99bfa08db43 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -187,8 +187,7 @@ public abstract class ModelImporter implements MlModelImporter {
TensorFunction function = operation.rankingExpressionFunction().get();
try {
model.function(operation.rankingExpressionFunctionName(),
- new RankingExpression(operation.rankingExpressionFunctionName(),
- function.toString()));
+ new RankingExpression(operation.rankingExpressionFunctionName(), function.toString()));
}
catch (ParseException e) {
throw new RuntimeException("Model function " + function +
@@ -210,7 +209,7 @@ public abstract class ModelImporter implements MlModelImporter {
private static void reportWarnings(IntermediateOperation operation, ImportedModel model) {
for (String warning : operation.warnings()) {
- model.defaultSignature().importWarning(warning);
+ // If we want to report warnings, that code goes here
}
for (IntermediateOperation input : operation.inputs()) {
reportWarnings(input, model);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
index 9c8f6238731..9115dc99b82 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
@@ -110,24 +110,25 @@ public class OrderedTensorType {
}
@Override
- public boolean equals(Object obj) {
- if (obj == null || !(obj instanceof OrderedTensorType)) {
- return false;
- }
- OrderedTensorType other = (OrderedTensorType) obj;
- if (dimensions.size() != dimensions.size()) {
- return false;
- }
+ public boolean equals(Object other) {
+ if (other == this) return true;
+ if ( ! (other instanceof OrderedTensorType)) return false;
+
List<TensorType.Dimension> thisDimensions = this.dimensions();
- List<TensorType.Dimension> otherDimensions = other.dimensions();
+ List<TensorType.Dimension> otherDimensions = ((OrderedTensorType)other).dimensions();
+ if (thisDimensions.size() != otherDimensions.size()) return false;
+
for (int i = 0; i < thisDimensions.size(); ++i) {
- if (!thisDimensions.get(i).equals(otherDimensions.get(i))) {
- return false;
- }
+ if ( ! thisDimensions.get(i).equals(otherDimensions.get(i))) return false;
}
return true;
}
+ @Override
+ public int hashCode() {
+ return type.hashCode();
+ }
+
public OrderedTensorType rename(DimensionRenamer renamer) {
List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size());
for (TensorType.Dimension dimension : dimensions) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/package-info.java
index 5a844bb5773..3258426dac4 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/package-info.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/package-info.java
@@ -1,5 +1,5 @@
/**
- * The config models view of imported models. This API cannot be changed withoug taking earlier config models
+ * The config models view of imported models. This API cannot be changed without taking earlier config models
* into account, not even on major versions.
*/
@ExportPackage
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java
index 27b80157d74..45ac2b16e97 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java
@@ -1,5 +1 @@
-// TODO: Don't export after November 2018
-@ExportPackage
package ai.vespa.rankingexpression.importer;
-
-import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java
new file mode 100644
index 00000000000..021fa1f7e51
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/VespaImporter.java
@@ -0,0 +1,40 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.vespa;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.ModelImporter;
+import ai.vespa.rankingexpression.importer.vespa.parser.ModelParser;
+
+import ai.vespa.rankingexpression.importer.vespa.parser.ParseException;
+import ai.vespa.rankingexpression.importer.vespa.parser.SimpleCharStream;
+import com.yahoo.io.IOUtils;
+
+import java.io.File;
+import java.io.IOException;
+
+/**
+ * Imports a model from a Vespa native ranking expression "model" file
+ */
+public class VespaImporter extends ModelImporter {
+
+ @Override
+ public boolean canImport(String modelPath) {
+ File modelFile = new File(modelPath);
+ if ( ! modelFile.isFile()) return false;
+
+ return modelFile.toString().endsWith(".model");
+ }
+
+ @Override
+ public ImportedModel importModel(String modelName, String modelPath) {
+ try {
+ ImportedModel model = new ImportedModel(modelName, modelPath);
+ new ModelParser(new SimpleCharStream(IOUtils.readFile(new File(modelPath))), model).model();
+ return model;
+ }
+ catch (IOException | ParseException e) {
+ throw new IllegalArgumentException("Could not import a Vespa model from '" + modelPath + "'", e);
+ }
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/package-info.java
new file mode 100644
index 00000000000..76c7ad6a134
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/package-info.java
@@ -0,0 +1,10 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+/**
+ * A model imported from Vespa ranking expressions
+ */
+@ExportPackage
+package ai.vespa.rankingexpression.importer.vespa;
+
+import com.yahoo.osgi.annotation.ExportPackage;
+
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/parser/SimpleCharStream.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/parser/SimpleCharStream.java
new file mode 100644
index 00000000000..8db9577a66c
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/vespa/parser/SimpleCharStream.java
@@ -0,0 +1,12 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.vespa.parser;
+
+import com.yahoo.javacc.FastCharStream;
+
+public class SimpleCharStream extends FastCharStream implements ai.vespa.rankingexpression.importer.vespa.parser.CharStream {
+
+ public SimpleCharStream(String input) {
+ super(input);
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
index ac462cc39eb..686cf6cd2df 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
@@ -1,11 +1,13 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.xgboost;
+import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.ModelImporter;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
@@ -22,7 +24,27 @@ public class XGBoostImporter extends ModelImporter {
File modelFile = new File(modelPath);
if ( ! modelFile.isFile()) return false;
- return modelFile.toString().endsWith(".json"); // No other models ends by json yet
+ return modelFile.toString().endsWith(".json") && probe(modelFile);
+ }
+
+ /**
+ * Returns true if the give file looks like an XGBoost json file.
+ * Currently, we just check if the file has an array on the top level.
+ */
+ private boolean probe(File modelFile) {
+ try {
+ BufferedReader reader = IOUtils.createReader(modelFile.getAbsolutePath());
+ String line;
+ while ((line = reader.readLine()) != null) {
+ line = line.trim();
+ if (line.startsWith("[")) return true;
+ if ( ! line.isEmpty()) return false;
+ }
+ return false;
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not read '" + modelFile + "'", e);
+ }
}
@Override
diff --git a/model-integration/src/main/javacc/ModelParser.jj b/model-integration/src/main/javacc/ModelParser.jj
new file mode 100644
index 00000000000..a5510dd89f3
--- /dev/null
+++ b/model-integration/src/main/javacc/ModelParser.jj
@@ -0,0 +1,332 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+// --------------------------------------------------------------------------------
+//
+// JavaCC options. When this file is changed, run "mvn generate-sources" to rebuild
+// the parser classes.
+//
+// --------------------------------------------------------------------------------
+options {
+ UNICODE_INPUT = true;
+ CACHE_TOKENS = false;
+ DEBUG_PARSER = false;
+ ERROR_REPORTING = true;
+ FORCE_LA_CHECK = true;
+ USER_CHAR_STREAM = true;
+}
+
+// --------------------------------------------------------------------------------
+//
+// Parser body.
+//
+// --------------------------------------------------------------------------------
+PARSER_BEGIN(ModelParser)
+
+package ai.vespa.rankingexpression.importer.vespa.parser;
+
+import java.io.File;
+import java.io.Reader;
+import java.io.StringReader;
+import java.util.List;
+import java.util.ArrayList;
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import com.yahoo.io.IOUtils;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.serialization.JsonFormat;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+
+/**
+ * Parser of Vespa ML model files: Ranking expression functions enclosed in brackets.
+ *
+ * @author bratseth
+ */
+public class ModelParser {
+
+ /** The model we are importing into */
+ private ImportedModel model;
+
+ /** Creates a parser of a string */
+ public ModelParser(String input, ImportedModel model) {
+ this(new SimpleCharStream(input), model);
+ }
+
+ /** Creates a parser */
+ public ModelParser(SimpleCharStream input, ImportedModel model) {
+ this(input);
+ this.model = model;
+ }
+
+}
+
+PARSER_END(ModelParser)
+
+
+// --------------------------------------------------------------------------------
+//
+// Token declarations.
+//
+// --------------------------------------------------------------------------------
+
+// Declare white space characters. These do not include newline because it has
+// special meaning in several of the production rules.
+SKIP :
+{
+ " " | "\t" | "\r" | "\f"
+}
+
+// Declare all tokens to be recognized. When a word token is added it MUST be
+// added to the identifier() production rule.
+TOKEN :
+{
+ < NL: "\n" >
+| < FUNCTION: "function" >
+| < TENSOR_TYPE: "tensor(" (~["(",")"])+ ")" >
+| < TENSORVALUE: (" ")* ":" (" ")* ("{"<BRACE_SL_LEVEL_1>) ("\n")? >
+| < TENSOR_VALUE_SL: "value" (" ")* ":" (" ")* ("{"<BRACE_SL_LEVEL_1>) ("\n")? >
+| < TENSOR_VALUE_ML: "value" (<SEARCHLIB_SKIP>)? "{" (["\n"," "])* ("{"<BRACE_ML_LEVEL_1>) (["\n"," "])* "}" ("\n")? >
+| < LBRACE: "{" >
+| < RBRACE: "}" >
+| < COLON: ":" >
+| < DOT: "." >
+| < COMMA: "," >
+| < MODEL: "model" >
+| < TYPE: "type" >
+| < EXPRESSION_SL: "expression" (" ")* ":" (("{"<BRACE_SL_LEVEL_1>)|<BRACE_SL_CONTENT>)* ("\n")? >
+| < EXPRESSION_ML: "expression" (<SEARCHLIB_SKIP>)? "{" (("{"<BRACE_ML_LEVEL_1>)|<BRACE_ML_CONTENT>)* "}" >
+| < #BRACE_SL_LEVEL_1: (("{"<BRACE_SL_LEVEL_2>)|<BRACE_SL_CONTENT>)* "}" >
+| < #BRACE_SL_LEVEL_2: (("{"<BRACE_SL_LEVEL_3>)|<BRACE_SL_CONTENT>)* "}" >
+| < #BRACE_SL_LEVEL_3: <BRACE_SL_CONTENT> "}" >
+| < #BRACE_SL_CONTENT: (~["{","}","\n"])* >
+| < #BRACE_ML_LEVEL_1: (("{"<BRACE_ML_LEVEL_2>)|<BRACE_ML_CONTENT>)* "}" >
+| < #BRACE_ML_LEVEL_2: (("{"<BRACE_ML_LEVEL_3>)|<BRACE_ML_CONTENT>)* "}" >
+| < #BRACE_ML_LEVEL_3: <BRACE_ML_CONTENT> "}" >
+| < #BRACE_ML_CONTENT: (~["{","}"])* >
+| < #SEARCHLIB_SKIP: ([" ","\f","\n","\r","\t"])+ >
+| < CONSTANT: "constant" >
+| < CONSTANTS: "constants" >
+| < FILE: "file" >
+| < URI: "uri" >
+| < IDENTIFIER: ["a"-"z","A"-"Z", "_"] (["a"-"z","A"-"Z","0"-"9","_"])* >
+| < CONTEXT: ["a"-"z","A"-"Z"] (["a"-"z", "A"-"Z", "0"-"9"])* >
+| < DOUBLE: ("-")? (["0"-"9"])+ "." (["0"-"9"])+ >
+| < STRING: (["a"-"z","A"-"Z","_","0"-"9","."])+ >
+| < FILE_PATH: ["a"-"z","A"-"Z", "_"] (["a"-"z","A"-"Z","0"-"9","_","-", "/", "."])+ >
+| < HTTP: ["h","H"] ["t","T"] ["t","T"] ["p","P"] (["s","S"])? >
+| < URI_PATH: <HTTP> <COLON> ("//")? (["a"-"z","A"-"Z","0"-"9","_","-", "/", ".",":"])+ >
+}
+
+// Declare a special skip token for comments.
+SPECIAL_TOKEN :
+{
+ <SINGLE_LINE_COMMENT: "#" (~["\n","\r"])* >
+}
+
+
+// --------------------------------------------------------------------------------
+//
+// Production rules.
+//
+// --------------------------------------------------------------------------------
+
+void model() :
+{
+ String name;
+}
+{
+ (<NL>)*
+ <MODEL>
+ (<NL>)*
+ name = identifier()
+ (<NL>)*
+ <LBRACE> modelContent() <RBRACE>
+ (<NL>)*
+ <EOF>
+ {
+ if ( ! name.equals(model.name()))
+ throw new IllegalArgumentException("Model '" + name + "' must be saved in a file named '" + name + ".model'");
+ }
+}
+
+void modelContent() :
+{
+}
+{
+ ( <NL> | input() | constants() | largeConstant() | function() )*
+}
+
+/** Declared input variables (aka features). All non-scalar inputs must be declared. */
+void input() :
+{
+ String name;
+ TensorType type;
+}
+{
+ name = identifier() <COLON> type = tensorType("Input parameter '" + name + "'")
+ { model.input(name, type); }
+}
+
+/** A function */
+void function() :
+{
+ String name, expression, parameter;
+ List parameters = new ArrayList();
+}
+{
+ ( <FUNCTION> name = identifier()
+ "("
+ [ parameter = identifier() { parameters.add(parameter); }
+ ( <COMMA> parameter = identifier() { parameters.add(parameter); } )* ]
+ ")"
+ lbrace() expression = expression() (<NL>)* <RBRACE> )
+ {
+ try {
+ model.expression(name, new RankingExpression(expression));
+ }
+ catch (com.yahoo.searchlib.rankingexpression.parser.ParseException e) {
+ throw new IllegalArgumentException("Could not parse function '" + name + "'", e);
+ }
+ }
+}
+
+/** Consumes the constants of this model. */
+void constants() :
+{
+ String name;
+}
+{
+ <CONSTANTS> <LBRACE> (<NL>)*
+ ( name = identifier() <COLON> ( constantDouble(name) | constantTensor(name) ) (<NL>)* )*
+ <RBRACE>
+}
+
+void constantDouble(String name) :
+{
+ Token value;
+}
+{
+ value = <DOUBLE> { model.smallConstant(name, Tensor.from(Double.parseDouble(value.image))); }
+}
+
+void constantTensor(String name) :
+{
+ TensorType type;
+ Token value;
+}
+{
+ type = tensorType("constant '" + name + "'") value = <TENSORVALUE>
+ {
+ model.smallConstant(name, Tensor.from(type, value.image.substring(1)));
+ }
+}
+
+String constantTensorErrorMessage(String model, String constantTensorName) : {}
+{
+ { return "For constant tensor '" + constantTensorName + "' in model '" + model + "'"; }
+}
+
+String tensorValue() :
+{
+ String tensor;
+}
+{
+ ( <TENSOR_VALUE_SL> { tensor = token.image.substring(token.image.indexOf(":") + 1); } |
+ <TENSOR_VALUE_ML> { tensor = token.image.substring(token.image.indexOf("{") + 1,
+ token.image.lastIndexOf("}")); } )
+ {
+ return tensor;
+ }
+}
+
+TensorType tensorType(String errorMessage) :
+{
+ String tensorTypeString;
+}
+{
+ <TENSOR_TYPE> { tensorTypeString = token.image; }
+ {
+ TensorType tensorType;
+ try {
+ tensorType = TensorType.fromSpec(tensorTypeString);
+ } catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException(errorMessage + ": Illegal tensor type spec: " + e.getMessage());
+ }
+ return tensorType;
+ }
+}
+
+/** Consumes a large constant. */
+void largeConstant() :
+{
+ String name;
+ Tensor value;
+}
+{
+ ( <CONSTANT> name = identifier() lbrace() value = largeConstantBody(name) <RBRACE> )
+ { model.largeConstant(name, value); }
+}
+
+// TODO: Add support in ImportedModel for passing a large tensor through as a file/Uri pointer instead of reading it here
+Tensor largeConstantBody(String name) :
+{
+ String path = null;
+ TensorType type = null;
+}
+{
+ ( <FILE> <COLON> path = filePath()
+// | (<URI> <COLON> path = uriPath() TODO
+ | <TYPE> <COLON> type = tensorType("Constant '" + name + "'")
+ | <NL>
+ )+
+ {
+ try {
+ return JsonFormat.decode(type, IOUtils.readFileBytes(new File(new File(model.source()).getParent(), path)));
+ }
+ catch (Exception e) {
+ throw new IllegalArgumentException("Could not read constant '" + name + "'", e);
+ }
+ }
+}
+
+String filePath() : { }
+{
+ ( <FILE_PATH> | <STRING> | <IDENTIFIER>)
+ { return token.image; }
+}
+
+String uriPath() : { }
+{
+ ( <URI_PATH> )
+ { return token.image; }
+}
+
+/** Consumes an expression token and returns its image. */
+String expression() :
+{
+ String exp;
+}
+{
+ ( <EXPRESSION_SL> { exp = token.image.substring(token.image.indexOf(":") + 1); } |
+ <EXPRESSION_ML> { exp = token.image.substring(token.image.indexOf("{") + 1,
+ token.image.lastIndexOf("}")); } )
+ { return exp; }
+}
+
+/** Consumes an identifier. This must be kept in sync with all word tokens that should be parseable as identifiers. */
+String identifier() : { }
+{
+ (
+ <IDENTIFIER>
+ | <DOUBLE>
+ | <FILE>
+ | <URI>
+ | <MODEL>
+ | <TYPE>
+ )
+ { return token.image; }
+}
+
+/** Consumes an opening brace with leading and trailing newline tokens. */
+void lbrace() : { }
+{
+ (<NL>)* <LBRACE> (<NL>)*
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
new file mode 100644
index 00000000000..4c8890f6476
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
@@ -0,0 +1,83 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.vespa;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
+import org.junit.Test;
+
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/**
+ * @author bratseth
+ */
+public class VespaImportTestCase {
+
+ @Test
+ public void testExample() {
+ ImportedModel model = importModel("example");
+
+ assertEquals(2, model.inputs().size());
+ assertEquals("tensor(name{},x[3])", model.inputs().get("input1").toString());
+ assertEquals("tensor(x[3])", model.inputs().get("input2").toString());
+
+ assertEquals(2, model.smallConstants().size());
+ assertEquals("tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}", model.smallConstants().get("constant1"));
+ assertEquals("tensor():{3.0}", model.smallConstants().get("constant2"));
+
+ assertEquals(1, model.largeConstants().size());
+ assertEquals("tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}", model.largeConstants().get("constant1asLarge"));
+
+ assertEquals(2, model.expressions().size());
+ assertEquals("max(reduce(input1 * input2, sum, name) * constant1,x) * constant2",
+ model.expressions().get("foo1").getRoot().toString());
+ assertEquals("max(reduce(input1 * input2, sum, name) * constant1asLarge,x) * constant2",
+ model.expressions().get("foo2").getRoot().toString());
+
+ List<ImportedMlFunction> functions = model.outputExpressions();
+ assertEquals(2, functions.size());
+ ImportedMlFunction foo1Function = functions.get(0);
+ assertEquals(2, foo1Function.arguments().size());
+ assertTrue(foo1Function.arguments().contains("input1"));
+ assertTrue(foo1Function.arguments().contains("input2"));
+ assertEquals(2, foo1Function.argumentTypes().size());
+ assertEquals("tensor(name{},x[3])", foo1Function.argumentTypes().get("input1"));
+ assertEquals("tensor(x[3])", foo1Function.argumentTypes().get("input2"));
+ }
+
+ @Test
+ public void testEmpty() {
+ ImportedModel model = importModel("empty");
+ assertTrue(model.expressions().isEmpty());
+ assertTrue(model.functions().isEmpty());
+ assertTrue(model.inputs().isEmpty());
+ assertTrue(model.largeConstants().isEmpty());
+ assertTrue(model.smallConstants().isEmpty());
+ }
+
+ @Test
+ public void testWrongName() {
+ try {
+ importModel("misnamed");
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("Model 'expectedname' must be saved in a file named 'expectedname.model'", e.getMessage());
+ }
+ }
+
+ private ImportedModel importModel(String name) {
+ String modelPath = "src/test/models/vespa/" + name + ".model";
+
+ VespaImporter importer = new VespaImporter();
+ assertTrue(importer.canImport(modelPath));
+ ImportedModel model = new VespaImporter().importModel(name, modelPath);
+ assertEquals(name, model.name());
+ assertEquals(modelPath, model.source());
+ return model;
+ }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
index 965d5eb8577..6d54b63db4b 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
@@ -18,11 +18,11 @@ public class XGBoostImportTestCase {
ImportedModel model = new XGBoostImporter().importModel("test", "src/test/models/xgboost/xgboost.2.2.json");
assertTrue("All inputs are scalar", model.inputs().isEmpty());
assertEquals(1, model.expressions().size());
- System.out.println(model.expressions().keySet());
RankingExpression expression = model.expressions().get("test");
assertNotNull(expression);
assertEquals("if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)",
expression.getRoot().toString());
+ assertEquals(1, model.outputExpressions().size());
}
}
diff --git a/model-integration/src/test/models/vespa/constant1asLarge.json b/model-integration/src/test/models/vespa/constant1asLarge.json
new file mode 100644
index 00000000000..d2944d255af
--- /dev/null
+++ b/model-integration/src/test/models/vespa/constant1asLarge.json
@@ -0,0 +1,7 @@
+{
+ "cells": [
+ { "address": { "x": "0" }, "value": 0.5 },
+ { "address": { "x": "1" }, "value": 1.5 },
+ { "address": { "x": "2" }, "value": 2.5 }
+ ]
+} \ No newline at end of file
diff --git a/model-integration/src/test/models/vespa/empty.model b/model-integration/src/test/models/vespa/empty.model
new file mode 100644
index 00000000000..f5381b2ba93
--- /dev/null
+++ b/model-integration/src/test/models/vespa/empty.model
@@ -0,0 +1,2 @@
+model empty {
+} \ No newline at end of file
diff --git a/model-integration/src/test/models/vespa/example.model b/model-integration/src/test/models/vespa/example.model
new file mode 100644
index 00000000000..9579be4e44c
--- /dev/null
+++ b/model-integration/src/test/models/vespa/example.model
@@ -0,0 +1,25 @@
+model example {
+
+ # All inputs that are not scalar (aka 0-dimensional tensor) must be declared
+ input1: tensor(name{}, x[3])
+ input2: tensor(x[3])
+
+ constants {
+ constant1: tensor(x[3]):{{x:0}:0.5, {x:1}:1.5, {x:2}:2.5}
+ constant2: 3.0
+ }
+
+ constant constant1asLarge {
+ type: tensor(x[3])
+ file: constant1asLarge.json
+ }
+
+ function foo1() {
+ expression: max(sum(input1 * input2, name) * constant1, x) * constant2
+ }
+
+ function foo2() {
+ expression: max(sum(input1 * input2, name) * constant1asLarge, x) * constant2
+ }
+
+} \ No newline at end of file
diff --git a/model-integration/src/test/models/vespa/misnamed.model b/model-integration/src/test/models/vespa/misnamed.model
new file mode 100644
index 00000000000..44bfa5e380d
--- /dev/null
+++ b/model-integration/src/test/models/vespa/misnamed.model
@@ -0,0 +1,3 @@
+model expectedname {
+
+} \ No newline at end of file
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
index e8e2fdf2454..1181dafad3f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
@@ -11,6 +11,8 @@ import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;
@@ -120,6 +122,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
private boolean isConstant(ExpressionNode node) {
if (node instanceof ConstantNode) return true;
if (node instanceof ReferenceNode) return false;
+ if (node instanceof TensorFunctionNode) return false; // TODO: We could support asking it if it is constant
if ( ! (node instanceof CompositeNode)) return false;
for (ExpressionNode child : ((CompositeNode)node).children()) {
if ( ! isConstant(child)) return false;
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 9264b0a8255..04e68e60178 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1180,7 +1180,8 @@
"public static boolean approxEquals(double, double)",
"public static com.yahoo.tensor.Tensor from(com.yahoo.tensor.TensorType, java.lang.String)",
"public static com.yahoo.tensor.Tensor from(java.lang.String, java.lang.String)",
- "public static com.yahoo.tensor.Tensor from(java.lang.String)"
+ "public static com.yahoo.tensor.Tensor from(java.lang.String)",
+ "public static com.yahoo.tensor.Tensor from(double)"
],
"fields": []
},
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index ebb341147cf..22ff793e6fa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -367,6 +367,13 @@ public interface Tensor {
return TensorParser.tensorFrom(tensorString, Optional.empty());
}
+ /**
+ * Returns a double as a tensor: A dimensionless tensor containing the value as its cell
+ */
+ static Tensor from(double value) {
+ return Tensor.Builder.of(TensorType.empty).cell(value).build();
+ }
+
class Cell implements Map.Entry<TensorAddress, Double> {
private final TensorAddress address;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 3213982355b..6382361f187 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -1,7 +1,11 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.serialization;
+import com.yahoo.slime.ArrayTraverser;
import com.yahoo.slime.Cursor;
+import com.yahoo.slime.Inspector;
+import com.yahoo.slime.JsonDecoder;
+import com.yahoo.slime.ObjectTraverser;
import com.yahoo.slime.Slime;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
@@ -17,9 +21,7 @@ import java.util.Iterator;
// TODO: We should probably move reading of this format from the document module to here
public class JsonFormat {
- /**
- * Serialize the given tensor into JSON format
- */
+ /** Serializes the given tensor into JSON format */
public static byte[] encode(Tensor tensor) {
Slime slime = new Slime();
Cursor root = slime.setObject();
@@ -38,4 +40,19 @@ public class JsonFormat {
addressObject.setString(type.dimensions().get(i).name(), address.label(i));
}
+ /** Deserializes the given tensor from JSON format */
+ // TODO: Add explicit validation (valid() checks) below
+ public static Tensor decode(TensorType type, byte[] jsonTensorValue) {
+ Tensor.Builder tensorBuilder = Tensor.Builder.of(type);
+ Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get();
+ Inspector cells = root.field("cells");
+ cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, tensorBuilder.cell()));
+ return tensorBuilder.build();
+ }
+
+ private static void decodeCell(Inspector cell, Tensor.Builder.CellBuilder cellBuilder) {
+ cell.field("address").traverse((ObjectTraverser) (dimension, label) -> cellBuilder.label(dimension, label.asString()));
+ cellBuilder.value(cell.field("value").asDouble());
+ }
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 16af413f2f0..5a025b6eb96 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -26,6 +26,8 @@ public class JsonFormatTestCase {
"{\"address\":{\"x\":\"c\",\"y\":\"d\"},\"value\":3.0}" +
"]}",
new String(json, StandardCharsets.UTF_8));
+ Tensor decoded = JsonFormat.decode(tensor.type(), json);
+ assertEquals(tensor, decoded);
}
@Test
@@ -44,6 +46,8 @@ public class JsonFormatTestCase {
"{\"address\":{\"x\":\"1\",\"y\":\"1\"},\"value\":7.0}" +
"]}",
new String(json, StandardCharsets.UTF_8));
+ Tensor decoded = JsonFormat.decode(tensor.type(), json);
+ assertEquals(tensor, decoded);
}
}