diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-10-08 15:09:40 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-10-08 15:09:40 +0200 |
commit | f9fff4feb28350dafc400daaf6049ea7d1527f47 (patch) | |
tree | 433f11dcad7861ed4470018c07c252bd6b60203a /document/src | |
parent | 8fb3b2c9d19c29909a84eb9cef883031a6b13000 (diff) |
Single sparse dimension short form
Diffstat (limited to 'document/src')
-rw-r--r-- | document/src/main/java/com/yahoo/document/json/readers/TensorReader.java | 49 | ||||
-rw-r--r-- | document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java | 31 |
2 files changed, 68 insertions, 12 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java index 7b5fcfed0db..497c717a6ad 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.document.json.readers; +import com.fasterxml.jackson.core.JsonToken; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.json.TokenBuffer; import com.yahoo.tensor.IndexedTensor; @@ -47,10 +48,19 @@ public class TensorReader { } static void readTensorCells(TokenBuffer buffer, Tensor.Builder builder) { - expectArrayStart(buffer.currentToken()); - int initNesting = buffer.nesting(); - for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) - readTensorCell(buffer, builder); + if (buffer.currentToken() == JsonToken.START_ARRAY) { + int initNesting = buffer.nesting(); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + readTensorCell(buffer, builder); + } + else if (buffer.currentToken() == JsonToken.START_OBJECT) { // single dimension short form + int initNesting = buffer.nesting(); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + builder.cell(asAddress(buffer.currentName(), builder.type()), readDouble(buffer)); + } + else { + throw new IllegalArgumentException("Expected 'cells' to contain an array or an object, but got " + buffer.currentToken()); + } expectCompositeEnd(buffer.currentToken()); } @@ -80,8 +90,6 @@ public class TensorReader { if ( ! (builder instanceof IndexedTensor.BoundBuilder)) throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + "Use 'cells' or 'blocks' instead"); - expectArrayStart(buffer.currentToken()); - IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; int index = 0; int initNesting = buffer.nesting(); @@ -94,12 +102,23 @@ public class TensorReader { if ( ! (builder instanceof MixedTensor.BoundBuilder)) throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " + "Use 'cells' or 'values' instead"); - expectArrayStart(buffer.currentToken()); MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder) builder; - int initNesting = buffer.nesting(); - for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) - readTensorBlock(buffer, mixedBuilder); + if (buffer.currentToken() == JsonToken.START_ARRAY) { + int initNesting = buffer.nesting(); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + readTensorBlock(buffer, mixedBuilder); + } + else if (buffer.currentToken() == JsonToken.START_OBJECT) { + int initNesting = buffer.nesting(); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + mixedBuilder.block(asAddress(buffer.currentName(), builder.type().mappedSubtype()), + readValues(buffer, (int)mixedBuilder.denseSubspaceSize())); + } + else { + throw new IllegalArgumentException("Expected 'blocks' to contain an array or an object, but got " + buffer.currentToken()); + } + expectCompositeEnd(buffer.currentToken()); } @@ -127,8 +146,8 @@ public class TensorReader { private static TensorAddress readAddress(TokenBuffer buffer, TensorType type) { expectObjectStart(buffer.currentToken()); - int initNesting = buffer.nesting(); TensorAddress.Builder builder = new TensorAddress.Builder(type); + int initNesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) builder.add(buffer.currentName(), buffer.currentText()); expectObjectEnd(buffer.currentToken()); @@ -149,11 +168,17 @@ public class TensorReader { private static double readDouble(TokenBuffer buffer) { try { - return Double.valueOf(buffer.currentText()); + return Double.parseDouble(buffer.currentText()); } catch (NumberFormatException e) { throw new IllegalArgumentException("Expected a number but got '" + buffer.currentText()); } } + private static TensorAddress asAddress(String label, TensorType type) { + if (type.dimensions().size() != 1) + throw new IllegalArgumentException("Expected a tensor with a single dimension but got " + type); + return new TensorAddress.Builder(type).add(type.dimensions().get(0).name(), label).build(); + } + } diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index 91998dedbb8..2af740147ed 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -153,6 +153,8 @@ public class JsonReaderTestCase { } { DocumentType x = new DocumentType("testtensor"); + x.addField(new Field("sparse_single_dimension_tensor", + new TensorDataType(new TensorType.Builder().mapped("x").build()))); x.addField(new Field("sparse_tensor", new TensorDataType(new TensorType.Builder().mapped("x").mapped("y").build()))); x.addField(new Field("dense_tensor", @@ -1335,6 +1337,26 @@ public class JsonReaderTestCase { } @Test + public void testMixedTensorInMixedFormWithSingleSparseDimensionShortForm() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[3])")); + builder.cell().label("x", 0).label("y", 0).value(2.0); + builder.cell().label("x", 0).label("y", 1).value(3.0); + builder.cell().label("x", 0).label("y", 2).value(4.0); + builder.cell().label("x", 1).label("y", 0).value(5.0); + builder.cell().label("x", 1).label("y", 1).value(6.0); + builder.cell().label("x", 1).label("y", 2).value(7.0); + Tensor expected = builder.build(); + + String mixedJson = "{\"blocks\":{" + + "\"0\":[2.0,3.0,4.0]," + + "\"1\":[5.0,6.0,7.0]" + + "}}"; + Tensor tensor = assertTensorField(expected, + createPutWithTensor(inputJson(mixedJson), "mixed_tensor"), "mixed_tensor"); + assertTrue(tensor instanceof MixedTensor); // this matters for performance + } + + @Test public void testParsingOfTensorWithSingleCellInDifferentJsonOrder() { assertSparseTensorField("{{x:a,y:b}:2.0}", createPutWithSparseTensor(inputJson("{", @@ -1539,6 +1561,15 @@ public class JsonReaderTestCase { } @Test + public void tensor_add_update_on_sparse_tensor_with_single_dimension_short_form() { + assertTensorAddUpdate("{{x:a}:2.0, {x:c}: 3.0}", "sparse_single_dimension_tensor", + inputJson("{", + " 'cells': {", + " 'a': 2.0,", + " 'c': 3.0 }}")); + } + + @Test public void tensor_add_update_on_mixed_tensor() { assertTensorAddUpdate("{{x:a,y:0}:2.0, {x:a,y:1}:3.0, {x:a,y:2}:0.0}", "mixed_tensor", inputJson("{", |