From 986c2da2986a2fc0de4895a8107c85e4d0f37fd3 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Fri, 31 May 2019 17:55:21 +0200 Subject: Support native Vespa standalone models --- vespajlib/abi-spec.json | 3 ++- vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) (limited to 'vespajlib') diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 4f81f3baea8..c31eed32830 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1176,7 +1176,8 @@ "public static boolean approxEquals(double, double)", "public static com.yahoo.tensor.Tensor from(com.yahoo.tensor.TensorType, java.lang.String)", "public static com.yahoo.tensor.Tensor from(java.lang.String, java.lang.String)", - "public static com.yahoo.tensor.Tensor from(java.lang.String)" + "public static com.yahoo.tensor.Tensor from(java.lang.String)", + "public static com.yahoo.tensor.Tensor from(double)" ], "fields": [] }, diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index ebb341147cf..22ff793e6fa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -367,6 +367,13 @@ public interface Tensor { return TensorParser.tensorFrom(tensorString, Optional.empty()); } + /** + * Returns a double as a tensor: A dimensionless tensor containing the value as its cell + */ + static Tensor from(double value) { + return Tensor.Builder.of(TensorType.empty).cell(value).build(); + } + class Cell implements Map.Entry { private final TensorAddress address; -- cgit v1.2.3 From c25c8a52e2328bcff2f5a35496e7568ee5a7c752 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Sun, 2 Jun 2019 13:25:45 +0200 Subject: Vespa global model import --- .../integration/vespa/models/constant1asLarge.json | 7 ++ .../test/integration/vespa/models/example.model | 25 +++++++ .../src/test/integration/vespa/services.xml | 6 ++ .../processing/VespaMlModelTestCase.java | 77 ++++++++++++++++++++++ .../yahoo/vespa/model/ml/ImportedModelTester.java | 6 +- .../yahoo/document/json/readers/TensorReader.java | 3 +- .../rankingexpression/importer/ImportedModel.java | 2 +- .../importer/xgboost/XGBoostImporter.java | 24 ++++++- model-integration/src/main/javacc/ModelParser.jj | 59 +++++++---------- .../importer/vespa/VespaImportTestCase.java | 23 ++++++- .../importer/xgboost/XGBoostImportTestCase.java | 1 + .../src/test/models/vespa/constant1asLarge.json | 7 ++ .../src/test/models/vespa/example.model | 13 +++- .../rankingexpression/transform/Simplifier.java | 3 + .../com/yahoo/tensor/serialization/JsonFormat.java | 23 ++++++- .../tensor/serialization/JsonFormatTestCase.java | 4 ++ 16 files changed, 235 insertions(+), 48 deletions(-) create mode 100644 config-model/src/test/integration/vespa/models/constant1asLarge.json create mode 100644 config-model/src/test/integration/vespa/models/example.model create mode 100644 config-model/src/test/integration/vespa/services.xml create mode 100644 config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java create mode 100644 model-integration/src/test/models/vespa/constant1asLarge.json (limited to 'vespajlib') diff --git a/config-model/src/test/integration/vespa/models/constant1asLarge.json b/config-model/src/test/integration/vespa/models/constant1asLarge.json new file mode 100644 index 00000000000..d2944d255af --- /dev/null +++ b/config-model/src/test/integration/vespa/models/constant1asLarge.json @@ -0,0 +1,7 @@ +{ + "cells": [ + { "address": { "x": "0" }, "value": 0.5 }, + { "address": { "x": "1" }, "value": 1.5 }, + { "address": { "x": "2" }, "value": 2.5 } + ] +} \ No newline at end of file diff --git a/config-model/src/test/integration/vespa/models/example.model b/config-model/src/test/integration/vespa/models/example.model new file mode 100644 index 00000000000..9579be4e44c --- /dev/null +++ b/config-model/src/test/integration/vespa/models/example.model @@ -0,0 +1,25 @@ +model example { + + # All inputs that are not scalar (aka 0-dimensional tensor) must be declared + input1: tensor(name{}, x[3]) + input2: tensor(x[3]) + + constants { + constant1: tensor(x[3]):{{x:0}:0.5, {x:1}:1.5, {x:2}:2.5} + constant2: 3.0 + } + + constant constant1asLarge { + type: tensor(x[3]) + file: constant1asLarge.json + } + + function foo1() { + expression: max(sum(input1 * input2, name) * constant1, x) * constant2 + } + + function foo2() { + expression: max(sum(input1 * input2, name) * constant1asLarge, x) * constant2 + } + +} \ No newline at end of file diff --git a/config-model/src/test/integration/vespa/services.xml b/config-model/src/test/integration/vespa/services.xml new file mode 100644 index 00000000000..aa1c0223bdf --- /dev/null +++ b/config-model/src/test/integration/vespa/services.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java new file mode 100644 index 00000000000..a75699d2a1d --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java @@ -0,0 +1,77 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.processing; + +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.searchdefinition.derived.RawRankProfile; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.ml.ImportedModelTester; +import org.junit.After; +import org.junit.Test; + +import java.io.IOException; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; + +/** + * Tests adding Vespa ranking expression based models in the models/ dir + * + * @author bratseth + */ +public class VespaMlModelTestCase { + + private final Path applicationDir = Path.fromString("src/test/integration/vespa/"); + + private final String expectedRankConfig = + "constant(constant1).type : tensor(x[3])\n" + + "constant(constant1).value : tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}\n" + + "rankingExpression(foo1).rankingScript : reduce(reduce(input1 * input2, sum, name) * constant(constant1), max, x) * 3.0\n" + + "rankingExpression(foo1).input2.type : tensor(x[3])\n" + + "rankingExpression(foo1).input1.type : tensor(name{},x[3])\n" + + "rankingExpression(foo2).rankingScript : max(reduce(input1 * input2, sum, name) * constant1asLarge,x) * 3.0\n" + + "rankingExpression(foo2).input2.type : tensor(x[3])\n" + + "rankingExpression(foo2).input1.type : tensor(name{},x[3])\n"; + + /** The model name */ + private final String name = "example"; + + @After + public void removeGeneratedModelFiles() { + IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + } + + @Test + public void testGlobalVespaModel() throws IOException { + ImportedModelTester tester = new ImportedModelTester(name, applicationDir); + VespaModel model = tester.createVespaModel(); + tester.assertLargeConstant("constant1asLarge", model, Optional.of(3L)); + assertEquals(expectedRankConfig, rankConfigOf("example", model)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedAppDir = applicationDir.append("copy"); + try { + storedAppDir.toFile().mkdirs(); + IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString()); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir); + VespaModel storedModel = storedTester.createVespaModel(); + storedTester.assertLargeConstant("constant1asLarge", model, Optional.of(3L)); + assertEquals(expectedRankConfig, rankConfigOf("example", storedModel)); + } + finally { + IOUtils.recursiveDeleteDir(storedAppDir.toFile()); + } + } + + private String rankConfigOf(String rankProfileName, VespaModel model) { + StringBuilder b = new StringBuilder(); + RawRankProfile profile = model.rankProfileList().getRankProfile(rankProfileName); + for (var property : profile.configProperties()) + b.append(property.getFirst()).append(" : ").append(property.getSecond()).append("\n"); + return b.toString(); + } + +} diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java index 563572b4af6..41811738ea4 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; +import ai.vespa.rankingexpression.importer.vespa.VespaImporter; import com.google.common.collect.ImmutableList; import com.yahoo.config.model.ApplicationPackageTester; import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter; @@ -8,10 +9,12 @@ import com.yahoo.config.model.deploy.DeployState; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; +import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; +import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.model.VespaModel; @@ -34,7 +37,8 @@ public class ImportedModelTester { private final ImmutableList importers = ImmutableList.of(new TensorFlowImporter(), new OnnxImporter(), - new XGBoostImporter()); + new XGBoostImporter(), + new VespaImporter()); private final String modelName; private final Path applicationDir; diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java index 9a1a37caade..a3d2a157073 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java @@ -10,7 +10,7 @@ import static com.yahoo.document.json.readers.JsonParserHelpers.*; /** * Reads the tensor format described at - * http://docs.vespa.ai/documentation/reference/document-json-put-format.html#tensor + * http://docs.vespa.ai/documentation/reference/document-json-format.html#tensor */ public class TensorReader { @@ -20,6 +20,7 @@ public class TensorReader { public static final String TENSOR_VALUE = "value"; public static void fillTensor(TokenBuffer buffer, TensorFieldValue tensorFieldValue) { + // TODO: Switch implementation to om.yahoo.tensor.serialization.JsonFormat.decode Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType()); expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index 36cb8c4f1cf..90529ccdca0 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -204,7 +204,7 @@ public class ImportedModel implements ImportedMlModel { ImmutableMap.Builder inputs = new ImmutableMap.Builder<>(); // Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to* // in the model, as these are the names which must actually be bound, if we are to avoid creating an - // "input mapping" to accommodate this complexity in + // "input mapping" to accommodate this complexity for (Map.Entry inputEntry : inputs().entrySet()) inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue())); return inputs.build(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java index ac462cc39eb..686cf6cd2df 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java @@ -1,11 +1,13 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer.xgboost; +import com.yahoo.io.IOUtils; import com.yahoo.searchlib.rankingexpression.RankingExpression; import ai.vespa.rankingexpression.importer.ImportedModel; import ai.vespa.rankingexpression.importer.ModelImporter; import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import java.io.BufferedReader; import java.io.File; import java.io.IOException; @@ -22,7 +24,27 @@ public class XGBoostImporter extends ModelImporter { File modelFile = new File(modelPath); if ( ! modelFile.isFile()) return false; - return modelFile.toString().endsWith(".json"); // No other models ends by json yet + return modelFile.toString().endsWith(".json") && probe(modelFile); + } + + /** + * Returns true if the give file looks like an XGBoost json file. + * Currently, we just check if the file has an array on the top level. + */ + private boolean probe(File modelFile) { + try { + BufferedReader reader = IOUtils.createReader(modelFile.getAbsolutePath()); + String line; + while ((line = reader.readLine()) != null) { + line = line.trim(); + if (line.startsWith("[")) return true; + if ( ! line.isEmpty()) return false; + } + return false; + } + catch (IOException e) { + throw new IllegalArgumentException("Could not read '" + modelFile + "'", e); + } } @Override diff --git a/model-integration/src/main/javacc/ModelParser.jj b/model-integration/src/main/javacc/ModelParser.jj index 18dfb4c68ed..a5510dd89f3 100644 --- a/model-integration/src/main/javacc/ModelParser.jj +++ b/model-integration/src/main/javacc/ModelParser.jj @@ -23,13 +23,16 @@ PARSER_BEGIN(ModelParser) package ai.vespa.rankingexpression.importer.vespa.parser; +import java.io.File; import java.io.Reader; import java.io.StringReader; import java.util.List; import java.util.ArrayList; import ai.vespa.rankingexpression.importer.ImportedModel; +import com.yahoo.io.IOUtils; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.serialization.JsonFormat; import com.yahoo.searchlib.rankingexpression.RankingExpression; /** @@ -99,6 +102,7 @@ TOKEN : | < #BRACE_ML_LEVEL_3: "}" > | < #BRACE_ML_CONTENT: (~["{","}"])* > | < #SEARCHLIB_SKIP: ([" ","\f","\n","\r","\t"])+ > +| < CONSTANT: "constant" > | < CONSTANTS: "constants" > | < FILE: "file" > | < URI: "uri" > @@ -147,7 +151,7 @@ void modelContent() : { } { - ( | input() | constants() | function() )* + ( | input() | constants() | largeConstant() | function() )* } /** Declared input variables (aka features). All non-scalar inputs must be declared. */ @@ -233,15 +237,6 @@ String tensorValue() : } } -TensorType tensorTypeWithPrefix(String errorMessage) : -{ - TensorType type; -} -{ - type=tensorType(errorMessage) - { return type; } -} - TensorType tensorType(String errorMessage) : { String tensorTypeString; @@ -259,47 +254,38 @@ TensorType tensorType(String errorMessage) : } } -//---------------------------------------- -/** Consumes a constant block of model. */ -/* +/** Consumes a large constant. */ void largeConstant() : { String name; - RankingConstant constant; + Tensor value; } { - ( name = identifier() - { -// constant = new RankingConstant(name); - } - lbrace() (rankingConstantItem(constant) ()*)+ ) - { - } + ( name = identifier() lbrace() value = largeConstantBody(name) ) + { model.largeConstant(name, value); } } -*/ -/** Consumes a constant block. */ -/* -void rankingConstantItem(RankingConstant constant) : +// TODO: Add support in ImportedModel for passing a large tensor through as a file/Uri pointer instead of reading it here +Tensor largeConstantBody(String name) : { String path = null; TensorType type = null; } { - ( ( path = filePath() { } ()*) { constant.setFileName(path); } - | ( path = uriPath() { } ()*) { constant.setUri(path); } - | type = tensorTypeWithPrefix(rankingConstantErrorMessage(constant.getName())) ()* { constant.setType(type); } - ) + ( path = filePath() +// | ( path = uriPath() TODO + | type = tensorType("Constant '" + name + "'") + | + )+ { - return null; + try { + return JsonFormat.decode(type, IOUtils.readFileBytes(new File(new File(model.source()).getParent(), path))); + } + catch (Exception e) { + throw new IllegalArgumentException("Could not read constant '" + name + "'", e); + } } } -*/ - -String rankingConstantErrorMessage(String name) : {} -{ - { return "For ranking constant ' " + name + "'"; } -} String filePath() : { } { @@ -312,7 +298,6 @@ String uriPath() : { } ( ) { return token.image; } } -//---------------------------------------- /** Consumes an expression token and returns its image. */ String expression() : diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java index 1be2b7a4183..4c8890f6476 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java @@ -2,8 +2,11 @@ package ai.vespa.rankingexpression.importer.vespa; import ai.vespa.rankingexpression.importer.ImportedModel; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; import org.junit.Test; +import java.util.List; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -25,8 +28,24 @@ public class VespaImportTestCase { assertEquals("tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}", model.smallConstants().get("constant1")); assertEquals("tensor():{3.0}", model.smallConstants().get("constant2")); - assertEquals("max(reduce(input1 * input2, sum, name),x) * constant2", - model.expressions().get("foo").getRoot().toString()); + assertEquals(1, model.largeConstants().size()); + assertEquals("tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}", model.largeConstants().get("constant1asLarge")); + + assertEquals(2, model.expressions().size()); + assertEquals("max(reduce(input1 * input2, sum, name) * constant1,x) * constant2", + model.expressions().get("foo1").getRoot().toString()); + assertEquals("max(reduce(input1 * input2, sum, name) * constant1asLarge,x) * constant2", + model.expressions().get("foo2").getRoot().toString()); + + List functions = model.outputExpressions(); + assertEquals(2, functions.size()); + ImportedMlFunction foo1Function = functions.get(0); + assertEquals(2, foo1Function.arguments().size()); + assertTrue(foo1Function.arguments().contains("input1")); + assertTrue(foo1Function.arguments().contains("input2")); + assertEquals(2, foo1Function.argumentTypes().size()); + assertEquals("tensor(name{},x[3])", foo1Function.argumentTypes().get("input1")); + assertEquals("tensor(x[3])", foo1Function.argumentTypes().get("input2")); } @Test diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java index 67a3b17255c..6d54b63db4b 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java @@ -22,6 +22,7 @@ public class XGBoostImportTestCase { assertNotNull(expression); assertEquals("if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)", expression.getRoot().toString()); + assertEquals(1, model.outputExpressions().size()); } } diff --git a/model-integration/src/test/models/vespa/constant1asLarge.json b/model-integration/src/test/models/vespa/constant1asLarge.json new file mode 100644 index 00000000000..d2944d255af --- /dev/null +++ b/model-integration/src/test/models/vespa/constant1asLarge.json @@ -0,0 +1,7 @@ +{ + "cells": [ + { "address": { "x": "0" }, "value": 0.5 }, + { "address": { "x": "1" }, "value": 1.5 }, + { "address": { "x": "2" }, "value": 2.5 } + ] +} \ No newline at end of file diff --git a/model-integration/src/test/models/vespa/example.model b/model-integration/src/test/models/vespa/example.model index c0ea461db09..9579be4e44c 100644 --- a/model-integration/src/test/models/vespa/example.model +++ b/model-integration/src/test/models/vespa/example.model @@ -9,8 +9,17 @@ model example { constant2: 3.0 } - function foo() { - expression: max(sum(input1 * input2, name), x) * constant2 + constant constant1asLarge { + type: tensor(x[3]) + file: constant1asLarge.json + } + + function foo1() { + expression: max(sum(input1 * input2, name) * constant1, x) * constant2 + } + + function foo2() { + expression: max(sum(input1 * input2, name) * constant1asLarge, x) * constant2 } } \ No newline at end of file diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java index e8e2fdf2454..1181dafad3f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java @@ -11,6 +11,8 @@ import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.functions.TensorFunction; import java.util.ArrayList; import java.util.List; @@ -120,6 +122,7 @@ public class Simplifier extends ExpressionTransformer { private boolean isConstant(ExpressionNode node) { if (node instanceof ConstantNode) return true; if (node instanceof ReferenceNode) return false; + if (node instanceof TensorFunctionNode) return false; // TODO: We could support asking it if it is constant if ( ! (node instanceof CompositeNode)) return false; for (ExpressionNode child : ((CompositeNode)node).children()) { if ( ! isConstant(child)) return false; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 3213982355b..6382361f187 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -1,7 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.serialization; +import com.yahoo.slime.ArrayTraverser; import com.yahoo.slime.Cursor; +import com.yahoo.slime.Inspector; +import com.yahoo.slime.JsonDecoder; +import com.yahoo.slime.ObjectTraverser; import com.yahoo.slime.Slime; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; @@ -17,9 +21,7 @@ import java.util.Iterator; // TODO: We should probably move reading of this format from the document module to here public class JsonFormat { - /** - * Serialize the given tensor into JSON format - */ + /** Serializes the given tensor into JSON format */ public static byte[] encode(Tensor tensor) { Slime slime = new Slime(); Cursor root = slime.setObject(); @@ -38,4 +40,19 @@ public class JsonFormat { addressObject.setString(type.dimensions().get(i).name(), address.label(i)); } + /** Deserializes the given tensor from JSON format */ + // TODO: Add explicit validation (valid() checks) below + public static Tensor decode(TensorType type, byte[] jsonTensorValue) { + Tensor.Builder tensorBuilder = Tensor.Builder.of(type); + Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get(); + Inspector cells = root.field("cells"); + cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, tensorBuilder.cell())); + return tensorBuilder.build(); + } + + private static void decodeCell(Inspector cell, Tensor.Builder.CellBuilder cellBuilder) { + cell.field("address").traverse((ObjectTraverser) (dimension, label) -> cellBuilder.label(dimension, label.asString())); + cellBuilder.value(cell.field("value").asDouble()); + } + } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 16af413f2f0..5a025b6eb96 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -26,6 +26,8 @@ public class JsonFormatTestCase { "{\"address\":{\"x\":\"c\",\"y\":\"d\"},\"value\":3.0}" + "]}", new String(json, StandardCharsets.UTF_8)); + Tensor decoded = JsonFormat.decode(tensor.type(), json); + assertEquals(tensor, decoded); } @Test @@ -44,6 +46,8 @@ public class JsonFormatTestCase { "{\"address\":{\"x\":\"1\",\"y\":\"1\"},\"value\":7.0}" + "]}", new String(json, StandardCharsets.UTF_8)); + Tensor decoded = JsonFormat.decode(tensor.type(), json); + assertEquals(tensor, decoded); } } -- cgit v1.2.3