aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-06-02 13:25:45 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-06-02 13:25:45 +0200
commitc25c8a52e2328bcff2f5a35496e7568ee5a7c752 (patch)
treecd624363ad22b7a2b6a76e41bd27c0cd7f5169d7
parente9e5a422c0aa6364c3c5f7b9da53e9fcf9a5f0f8 (diff)
Vespa global model import
-rw-r--r--config-model/src/test/integration/vespa/models/constant1asLarge.json7
-rw-r--r--config-model/src/test/integration/vespa/models/example.model25
-rw-r--r--config-model/src/test/integration/vespa/services.xml6
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java77
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java6
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorReader.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java24
-rw-r--r--model-integration/src/main/javacc/ModelParser.jj59
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java23
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java1
-rw-r--r--model-integration/src/test/models/vespa/constant1asLarge.json7
-rw-r--r--model-integration/src/test/models/vespa/example.model13
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java23
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java4
16 files changed, 235 insertions, 48 deletions
diff --git a/config-model/src/test/integration/vespa/models/constant1asLarge.json b/config-model/src/test/integration/vespa/models/constant1asLarge.json
new file mode 100644
index 00000000000..d2944d255af
--- /dev/null
+++ b/config-model/src/test/integration/vespa/models/constant1asLarge.json
@@ -0,0 +1,7 @@
+{
+ "cells": [
+ { "address": { "x": "0" }, "value": 0.5 },
+ { "address": { "x": "1" }, "value": 1.5 },
+ { "address": { "x": "2" }, "value": 2.5 }
+ ]
+} \ No newline at end of file
diff --git a/config-model/src/test/integration/vespa/models/example.model b/config-model/src/test/integration/vespa/models/example.model
new file mode 100644
index 00000000000..9579be4e44c
--- /dev/null
+++ b/config-model/src/test/integration/vespa/models/example.model
@@ -0,0 +1,25 @@
+model example {
+
+ # All inputs that are not scalar (aka 0-dimensional tensor) must be declared
+ input1: tensor(name{}, x[3])
+ input2: tensor(x[3])
+
+ constants {
+ constant1: tensor(x[3]):{{x:0}:0.5, {x:1}:1.5, {x:2}:2.5}
+ constant2: 3.0
+ }
+
+ constant constant1asLarge {
+ type: tensor(x[3])
+ file: constant1asLarge.json
+ }
+
+ function foo1() {
+ expression: max(sum(input1 * input2, name) * constant1, x) * constant2
+ }
+
+ function foo2() {
+ expression: max(sum(input1 * input2, name) * constant1asLarge, x) * constant2
+ }
+
+} \ No newline at end of file
diff --git a/config-model/src/test/integration/vespa/services.xml b/config-model/src/test/integration/vespa/services.xml
new file mode 100644
index 00000000000..aa1c0223bdf
--- /dev/null
+++ b/config-model/src/test/integration/vespa/services.xml
@@ -0,0 +1,6 @@
+<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
+<services>
+ <container version="1.0">
+
+ </container>
+</services>
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java
new file mode 100644
index 00000000000..a75699d2a1d
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java
@@ -0,0 +1,77 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.searchdefinition.derived.RawRankProfile;
+import com.yahoo.vespa.model.VespaModel;
+import com.yahoo.vespa.model.ml.ImportedModelTester;
+import org.junit.After;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Optional;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests adding Vespa ranking expression based models in the models/ dir
+ *
+ * @author bratseth
+ */
+public class VespaMlModelTestCase {
+
+ private final Path applicationDir = Path.fromString("src/test/integration/vespa/");
+
+ private final String expectedRankConfig =
+ "constant(constant1).type : tensor(x[3])\n" +
+ "constant(constant1).value : tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}\n" +
+ "rankingExpression(foo1).rankingScript : reduce(reduce(input1 * input2, sum, name) * constant(constant1), max, x) * 3.0\n" +
+ "rankingExpression(foo1).input2.type : tensor(x[3])\n" +
+ "rankingExpression(foo1).input1.type : tensor(name{},x[3])\n" +
+ "rankingExpression(foo2).rankingScript : max(reduce(input1 * input2, sum, name) * constant1asLarge,x) * 3.0\n" +
+ "rankingExpression(foo2).input2.type : tensor(x[3])\n" +
+ "rankingExpression(foo2).input1.type : tensor(name{},x[3])\n";
+
+ /** The model name */
+ private final String name = "example";
+
+ @After
+ public void removeGeneratedModelFiles() {
+ IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ }
+
+ @Test
+ public void testGlobalVespaModel() throws IOException {
+ ImportedModelTester tester = new ImportedModelTester(name, applicationDir);
+ VespaModel model = tester.createVespaModel();
+ tester.assertLargeConstant("constant1asLarge", model, Optional.of(3L));
+ assertEquals(expectedRankConfig, rankConfigOf("example", model));
+
+ // At this point the expression is stored - copy application to another location which do not have a models dir
+ Path storedAppDir = applicationDir.append("copy");
+ try {
+ storedAppDir.toFile().mkdirs();
+ IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString());
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir);
+ VespaModel storedModel = storedTester.createVespaModel();
+ storedTester.assertLargeConstant("constant1asLarge", model, Optional.of(3L));
+ assertEquals(expectedRankConfig, rankConfigOf("example", storedModel));
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedAppDir.toFile());
+ }
+ }
+
+ private String rankConfigOf(String rankProfileName, VespaModel model) {
+ StringBuilder b = new StringBuilder();
+ RawRankProfile profile = model.rankProfileList().getRankProfile(rankProfileName);
+ for (var property : profile.configProperties())
+ b.append(property.getFirst()).append(" : ").append(property.getSecond()).append("\n");
+ return b.toString();
+ }
+
+}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
index 563572b4af6..41811738ea4 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.ml;
+import ai.vespa.rankingexpression.importer.vespa.VespaImporter;
import com.google.common.collect.ImmutableList;
import com.yahoo.config.model.ApplicationPackageTester;
import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
@@ -8,10 +9,12 @@ import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
+import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
+import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.model.VespaModel;
@@ -34,7 +37,8 @@ public class ImportedModelTester {
private final ImmutableList<MlModelImporter> importers = ImmutableList.of(new TensorFlowImporter(),
new OnnxImporter(),
- new XGBoostImporter());
+ new XGBoostImporter(),
+ new VespaImporter());
private final String modelName;
private final Path applicationDir;
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
index 9a1a37caade..a3d2a157073 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
@@ -10,7 +10,7 @@ import static com.yahoo.document.json.readers.JsonParserHelpers.*;
/**
* Reads the tensor format described at
- * http://docs.vespa.ai/documentation/reference/document-json-put-format.html#tensor
+ * http://docs.vespa.ai/documentation/reference/document-json-format.html#tensor
*/
public class TensorReader {
@@ -20,6 +20,7 @@ public class TensorReader {
public static final String TENSOR_VALUE = "value";
public static void fillTensor(TokenBuffer buffer, TensorFieldValue tensorFieldValue) {
+ // TODO: Switch implementation to om.yahoo.tensor.serialization.JsonFormat.decode
Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType());
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index 36cb8c4f1cf..90529ccdca0 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -204,7 +204,7 @@ public class ImportedModel implements ImportedMlModel {
ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>();
// Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to*
// in the model, as these are the names which must actually be bound, if we are to avoid creating an
- // "input mapping" to accommodate this complexity in
+ // "input mapping" to accommodate this complexity
for (Map.Entry<String, String> inputEntry : inputs().entrySet())
inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue()));
return inputs.build();
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
index ac462cc39eb..686cf6cd2df 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
@@ -1,11 +1,13 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.xgboost;
+import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.ModelImporter;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
@@ -22,7 +24,27 @@ public class XGBoostImporter extends ModelImporter {
File modelFile = new File(modelPath);
if ( ! modelFile.isFile()) return false;
- return modelFile.toString().endsWith(".json"); // No other models ends by json yet
+ return modelFile.toString().endsWith(".json") && probe(modelFile);
+ }
+
+ /**
+ * Returns true if the give file looks like an XGBoost json file.
+ * Currently, we just check if the file has an array on the top level.
+ */
+ private boolean probe(File modelFile) {
+ try {
+ BufferedReader reader = IOUtils.createReader(modelFile.getAbsolutePath());
+ String line;
+ while ((line = reader.readLine()) != null) {
+ line = line.trim();
+ if (line.startsWith("[")) return true;
+ if ( ! line.isEmpty()) return false;
+ }
+ return false;
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not read '" + modelFile + "'", e);
+ }
}
@Override
diff --git a/model-integration/src/main/javacc/ModelParser.jj b/model-integration/src/main/javacc/ModelParser.jj
index 18dfb4c68ed..a5510dd89f3 100644
--- a/model-integration/src/main/javacc/ModelParser.jj
+++ b/model-integration/src/main/javacc/ModelParser.jj
@@ -23,13 +23,16 @@ PARSER_BEGIN(ModelParser)
package ai.vespa.rankingexpression.importer.vespa.parser;
+import java.io.File;
import java.io.Reader;
import java.io.StringReader;
import java.util.List;
import java.util.ArrayList;
import ai.vespa.rankingexpression.importer.ImportedModel;
+import com.yahoo.io.IOUtils;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.serialization.JsonFormat;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
/**
@@ -99,6 +102,7 @@ TOKEN :
| < #BRACE_ML_LEVEL_3: <BRACE_ML_CONTENT> "}" >
| < #BRACE_ML_CONTENT: (~["{","}"])* >
| < #SEARCHLIB_SKIP: ([" ","\f","\n","\r","\t"])+ >
+| < CONSTANT: "constant" >
| < CONSTANTS: "constants" >
| < FILE: "file" >
| < URI: "uri" >
@@ -147,7 +151,7 @@ void modelContent() :
{
}
{
- ( <NL> | input() | constants() | function() )*
+ ( <NL> | input() | constants() | largeConstant() | function() )*
}
/** Declared input variables (aka features). All non-scalar inputs must be declared. */
@@ -233,15 +237,6 @@ String tensorValue() :
}
}
-TensorType tensorTypeWithPrefix(String errorMessage) :
-{
- TensorType type;
-}
-{
- <TYPE> <COLON> type=tensorType(errorMessage)
- { return type; }
-}
-
TensorType tensorType(String errorMessage) :
{
String tensorTypeString;
@@ -259,47 +254,38 @@ TensorType tensorType(String errorMessage) :
}
}
-//----------------------------------------
-/** Consumes a constant block of model. */
-/*
+/** Consumes a large constant. */
void largeConstant() :
{
String name;
- RankingConstant constant;
+ Tensor value;
}
{
- ( <CONSTANT> name = identifier()
- {
-// constant = new RankingConstant(name);
- }
- lbrace() (rankingConstantItem(constant) (<NL>)*)+ <RBRACE> )
- {
- }
+ ( <CONSTANT> name = identifier() lbrace() value = largeConstantBody(name) <RBRACE> )
+ { model.largeConstant(name, value); }
}
-*/
-/** Consumes a constant block. */
-/*
-void rankingConstantItem(RankingConstant constant) :
+// TODO: Add support in ImportedModel for passing a large tensor through as a file/Uri pointer instead of reading it here
+Tensor largeConstantBody(String name) :
{
String path = null;
TensorType type = null;
}
{
- ( (<FILE> <COLON> path = filePath() { } (<NL>)*) { constant.setFileName(path); }
- | (<URI> <COLON> path = uriPath() { } (<NL>)*) { constant.setUri(path); }
- | type = tensorTypeWithPrefix(rankingConstantErrorMessage(constant.getName())) (<NL>)* { constant.setType(type); }
- )
+ ( <FILE> <COLON> path = filePath()
+// | (<URI> <COLON> path = uriPath() TODO
+ | <TYPE> <COLON> type = tensorType("Constant '" + name + "'")
+ | <NL>
+ )+
{
- return null;
+ try {
+ return JsonFormat.decode(type, IOUtils.readFileBytes(new File(new File(model.source()).getParent(), path)));
+ }
+ catch (Exception e) {
+ throw new IllegalArgumentException("Could not read constant '" + name + "'", e);
+ }
}
}
-*/
-
-String rankingConstantErrorMessage(String name) : {}
-{
- { return "For ranking constant ' " + name + "'"; }
-}
String filePath() : { }
{
@@ -312,7 +298,6 @@ String uriPath() : { }
( <URI_PATH> )
{ return token.image; }
}
-//----------------------------------------
/** Consumes an expression token and returns its image. */
String expression() :
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
index 1be2b7a4183..4c8890f6476 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
@@ -2,8 +2,11 @@
package ai.vespa.rankingexpression.importer.vespa;
import ai.vespa.rankingexpression.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
import org.junit.Test;
+import java.util.List;
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -25,8 +28,24 @@ public class VespaImportTestCase {
assertEquals("tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}", model.smallConstants().get("constant1"));
assertEquals("tensor():{3.0}", model.smallConstants().get("constant2"));
- assertEquals("max(reduce(input1 * input2, sum, name),x) * constant2",
- model.expressions().get("foo").getRoot().toString());
+ assertEquals(1, model.largeConstants().size());
+ assertEquals("tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}", model.largeConstants().get("constant1asLarge"));
+
+ assertEquals(2, model.expressions().size());
+ assertEquals("max(reduce(input1 * input2, sum, name) * constant1,x) * constant2",
+ model.expressions().get("foo1").getRoot().toString());
+ assertEquals("max(reduce(input1 * input2, sum, name) * constant1asLarge,x) * constant2",
+ model.expressions().get("foo2").getRoot().toString());
+
+ List<ImportedMlFunction> functions = model.outputExpressions();
+ assertEquals(2, functions.size());
+ ImportedMlFunction foo1Function = functions.get(0);
+ assertEquals(2, foo1Function.arguments().size());
+ assertTrue(foo1Function.arguments().contains("input1"));
+ assertTrue(foo1Function.arguments().contains("input2"));
+ assertEquals(2, foo1Function.argumentTypes().size());
+ assertEquals("tensor(name{},x[3])", foo1Function.argumentTypes().get("input1"));
+ assertEquals("tensor(x[3])", foo1Function.argumentTypes().get("input2"));
}
@Test
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
index 67a3b17255c..6d54b63db4b 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
@@ -22,6 +22,7 @@ public class XGBoostImportTestCase {
assertNotNull(expression);
assertEquals("if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)",
expression.getRoot().toString());
+ assertEquals(1, model.outputExpressions().size());
}
}
diff --git a/model-integration/src/test/models/vespa/constant1asLarge.json b/model-integration/src/test/models/vespa/constant1asLarge.json
new file mode 100644
index 00000000000..d2944d255af
--- /dev/null
+++ b/model-integration/src/test/models/vespa/constant1asLarge.json
@@ -0,0 +1,7 @@
+{
+ "cells": [
+ { "address": { "x": "0" }, "value": 0.5 },
+ { "address": { "x": "1" }, "value": 1.5 },
+ { "address": { "x": "2" }, "value": 2.5 }
+ ]
+} \ No newline at end of file
diff --git a/model-integration/src/test/models/vespa/example.model b/model-integration/src/test/models/vespa/example.model
index c0ea461db09..9579be4e44c 100644
--- a/model-integration/src/test/models/vespa/example.model
+++ b/model-integration/src/test/models/vespa/example.model
@@ -9,8 +9,17 @@ model example {
constant2: 3.0
}
- function foo() {
- expression: max(sum(input1 * input2, name), x) * constant2
+ constant constant1asLarge {
+ type: tensor(x[3])
+ file: constant1asLarge.json
+ }
+
+ function foo1() {
+ expression: max(sum(input1 * input2, name) * constant1, x) * constant2
+ }
+
+ function foo2() {
+ expression: max(sum(input1 * input2, name) * constant1asLarge, x) * constant2
}
} \ No newline at end of file
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
index e8e2fdf2454..1181dafad3f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
@@ -11,6 +11,8 @@ import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;
@@ -120,6 +122,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
private boolean isConstant(ExpressionNode node) {
if (node instanceof ConstantNode) return true;
if (node instanceof ReferenceNode) return false;
+ if (node instanceof TensorFunctionNode) return false; // TODO: We could support asking it if it is constant
if ( ! (node instanceof CompositeNode)) return false;
for (ExpressionNode child : ((CompositeNode)node).children()) {
if ( ! isConstant(child)) return false;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 3213982355b..6382361f187 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -1,7 +1,11 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.serialization;
+import com.yahoo.slime.ArrayTraverser;
import com.yahoo.slime.Cursor;
+import com.yahoo.slime.Inspector;
+import com.yahoo.slime.JsonDecoder;
+import com.yahoo.slime.ObjectTraverser;
import com.yahoo.slime.Slime;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
@@ -17,9 +21,7 @@ import java.util.Iterator;
// TODO: We should probably move reading of this format from the document module to here
public class JsonFormat {
- /**
- * Serialize the given tensor into JSON format
- */
+ /** Serializes the given tensor into JSON format */
public static byte[] encode(Tensor tensor) {
Slime slime = new Slime();
Cursor root = slime.setObject();
@@ -38,4 +40,19 @@ public class JsonFormat {
addressObject.setString(type.dimensions().get(i).name(), address.label(i));
}
+ /** Deserializes the given tensor from JSON format */
+ // TODO: Add explicit validation (valid() checks) below
+ public static Tensor decode(TensorType type, byte[] jsonTensorValue) {
+ Tensor.Builder tensorBuilder = Tensor.Builder.of(type);
+ Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get();
+ Inspector cells = root.field("cells");
+ cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, tensorBuilder.cell()));
+ return tensorBuilder.build();
+ }
+
+ private static void decodeCell(Inspector cell, Tensor.Builder.CellBuilder cellBuilder) {
+ cell.field("address").traverse((ObjectTraverser) (dimension, label) -> cellBuilder.label(dimension, label.asString()));
+ cellBuilder.value(cell.field("value").asDouble());
+ }
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 16af413f2f0..5a025b6eb96 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -26,6 +26,8 @@ public class JsonFormatTestCase {
"{\"address\":{\"x\":\"c\",\"y\":\"d\"},\"value\":3.0}" +
"]}",
new String(json, StandardCharsets.UTF_8));
+ Tensor decoded = JsonFormat.decode(tensor.type(), json);
+ assertEquals(tensor, decoded);
}
@Test
@@ -44,6 +46,8 @@ public class JsonFormatTestCase {
"{\"address\":{\"x\":\"1\",\"y\":\"1\"},\"value\":7.0}" +
"]}",
new String(json, StandardCharsets.UTF_8));
+ Tensor decoded = JsonFormat.decode(tensor.type(), json);
+ assertEquals(tensor, decoded);
}
}