diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-10-08 15:09:40 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-10-08 15:09:40 +0200 |
commit | f9fff4feb28350dafc400daaf6049ea7d1527f47 (patch) | |
tree | 433f11dcad7861ed4470018c07c252bd6b60203a /document/src/main | |
parent | 8fb3b2c9d19c29909a84eb9cef883031a6b13000 (diff) |
Single sparse dimension short form
Diffstat (limited to 'document/src/main')
-rw-r--r-- | document/src/main/java/com/yahoo/document/json/readers/TensorReader.java | 49 |
1 files changed, 37 insertions, 12 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java index 7b5fcfed0db..497c717a6ad 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.document.json.readers; +import com.fasterxml.jackson.core.JsonToken; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.json.TokenBuffer; import com.yahoo.tensor.IndexedTensor; @@ -47,10 +48,19 @@ public class TensorReader { } static void readTensorCells(TokenBuffer buffer, Tensor.Builder builder) { - expectArrayStart(buffer.currentToken()); - int initNesting = buffer.nesting(); - for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) - readTensorCell(buffer, builder); + if (buffer.currentToken() == JsonToken.START_ARRAY) { + int initNesting = buffer.nesting(); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + readTensorCell(buffer, builder); + } + else if (buffer.currentToken() == JsonToken.START_OBJECT) { // single dimension short form + int initNesting = buffer.nesting(); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + builder.cell(asAddress(buffer.currentName(), builder.type()), readDouble(buffer)); + } + else { + throw new IllegalArgumentException("Expected 'cells' to contain an array or an object, but got " + buffer.currentToken()); + } expectCompositeEnd(buffer.currentToken()); } @@ -80,8 +90,6 @@ public class TensorReader { if ( ! (builder instanceof IndexedTensor.BoundBuilder)) throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + "Use 'cells' or 'blocks' instead"); - expectArrayStart(buffer.currentToken()); - IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; int index = 0; int initNesting = buffer.nesting(); @@ -94,12 +102,23 @@ public class TensorReader { if ( ! (builder instanceof MixedTensor.BoundBuilder)) throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " + "Use 'cells' or 'values' instead"); - expectArrayStart(buffer.currentToken()); MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder) builder; - int initNesting = buffer.nesting(); - for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) - readTensorBlock(buffer, mixedBuilder); + if (buffer.currentToken() == JsonToken.START_ARRAY) { + int initNesting = buffer.nesting(); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + readTensorBlock(buffer, mixedBuilder); + } + else if (buffer.currentToken() == JsonToken.START_OBJECT) { + int initNesting = buffer.nesting(); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + mixedBuilder.block(asAddress(buffer.currentName(), builder.type().mappedSubtype()), + readValues(buffer, (int)mixedBuilder.denseSubspaceSize())); + } + else { + throw new IllegalArgumentException("Expected 'blocks' to contain an array or an object, but got " + buffer.currentToken()); + } + expectCompositeEnd(buffer.currentToken()); } @@ -127,8 +146,8 @@ public class TensorReader { private static TensorAddress readAddress(TokenBuffer buffer, TensorType type) { expectObjectStart(buffer.currentToken()); - int initNesting = buffer.nesting(); TensorAddress.Builder builder = new TensorAddress.Builder(type); + int initNesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) builder.add(buffer.currentName(), buffer.currentText()); expectObjectEnd(buffer.currentToken()); @@ -149,11 +168,17 @@ public class TensorReader { private static double readDouble(TokenBuffer buffer) { try { - return Double.valueOf(buffer.currentText()); + return Double.parseDouble(buffer.currentText()); } catch (NumberFormatException e) { throw new IllegalArgumentException("Expected a number but got '" + buffer.currentText()); } } + private static TensorAddress asAddress(String label, TensorType type) { + if (type.dimensions().size() != 1) + throw new IllegalArgumentException("Expected a tensor with a single dimension but got " + type); + return new TensorAddress.Builder(type).add(type.dimensions().get(0).name(), label).build(); + } + } |