aboutsummaryrefslogtreecommitdiffstats
path: root/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
diff options
context:
space:
mode:
Diffstat (limited to 'document/src/main/java/com/yahoo/document/json/readers/TensorReader.java')
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorReader.java53
1 files changed, 21 insertions, 32 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
index 1fd4029b1a5..0b7b1ae9996 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
@@ -4,15 +4,13 @@ package com.yahoo.document.json.readers;
import com.fasterxml.jackson.core.JsonToken;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.json.TokenBuffer;
-import com.yahoo.document.json.TokenBuffer.Token;
+import com.yahoo.slime.Inspector;
+import com.yahoo.slime.Type;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.TensorType.Dimension;
-
-import java.util.function.Supplier;
import static com.yahoo.document.json.readers.JsonParserHelpers.*;
import static com.yahoo.tensor.serialization.JsonFormat.decodeHexString;
@@ -39,43 +37,36 @@ public class TensorReader {
Tensor.Builder builder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType());
expectOneOf(buffer.current(), JsonToken.START_OBJECT, JsonToken.START_ARRAY);
int initNesting = buffer.nesting();
- while (true) {
- Supplier<Token> lookahead = buffer.lookahead();
- Token next = lookahead.get();
- if (TENSOR_CELLS.equals(next.name) && ! primitiveContent(next.token, lookahead.get().token)) {
- buffer.next();
+ for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
+ if (TENSOR_CELLS.equals(buffer.currentName()) && ! primitiveContent(buffer)) {
readTensorCells(buffer, builder);
}
- else if (TENSOR_VALUES.equals(next.name) && builder.type().dimensions().stream().allMatch(Dimension::isIndexed)) {
- buffer.next();
+ else if (TENSOR_VALUES.equals(buffer.currentName()) && builder.type().dimensions().stream().allMatch(d -> d.isIndexed())) {
readTensorValues(buffer, builder);
}
- else if (TENSOR_BLOCKS.equals(next.name)) {
- buffer.next();
+ else if (TENSOR_BLOCKS.equals(buffer.currentName())) {
readTensorBlocks(buffer, builder);
}
- else if (TENSOR_TYPE.equals(next.name) && next.token == JsonToken.VALUE_STRING) {
- buffer.next();
+ else if (TENSOR_TYPE.equals(buffer.currentName()) && buffer.current() == JsonToken.VALUE_STRING) {
// Ignore input tensor type
}
- else if (buffer.nesting() == initNesting && JsonToken.END_OBJECT == next.token) {
- buffer.next();
- break;
- }
else {
+ buffer.previous(); // Back up to the start of the enclosing block
readDirectTensorValue(buffer, builder);
- break;
+ buffer.previous(); // ... and back up to the end of the enclosing block
}
}
expectOneOf(buffer.current(), JsonToken.END_OBJECT, JsonToken.END_ARRAY);
tensorFieldValue.assign(builder.build());
}
- static boolean primitiveContent(JsonToken current, JsonToken next) {
- if (current.isScalarValue()) return true;
- if (current == JsonToken.START_ARRAY) {
- if (next == JsonToken.END_ARRAY) return false;
- if (next.isScalarValue()) return true;
+ static boolean primitiveContent(TokenBuffer buffer) {
+ JsonToken cellsValue = buffer.current();
+ if (cellsValue.isScalarValue()) return true;
+ if (cellsValue == JsonToken.START_ARRAY) {
+ JsonToken firstArrayValue = buffer.peek(1);
+ if (firstArrayValue == JsonToken.END_ARRAY) return false;
+ if (firstArrayValue.isScalarValue()) return true;
}
return false;
}
@@ -195,7 +186,7 @@ public class TensorReader {
boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped);
- if (isArrayOfObjects(buffer))
+ if (isArrayOfObjects(buffer, 0))
readTensorCells(buffer, builder);
else if ( ! hasMapped)
readTensorValues(buffer, builder);
@@ -205,12 +196,10 @@ public class TensorReader {
readTensorCells(buffer, builder);
}
- private static boolean isArrayOfObjects(TokenBuffer buffer) {
- if (buffer.current() != JsonToken.START_ARRAY) return false;
- Supplier<Token> lookahead = buffer.lookahead();
- Token next;
- while ((next = lookahead.get()).token == JsonToken.START_ARRAY) { }
- return next.token == JsonToken.START_OBJECT;
+ private static boolean isArrayOfObjects(TokenBuffer buffer, int ahead) {
+ if (buffer.peek(ahead++) != JsonToken.START_ARRAY) return false;
+ if (buffer.peek(ahead) == JsonToken.START_ARRAY) return isArrayOfObjects(buffer, ahead); // nested array
+ return buffer.peek(ahead) == JsonToken.START_OBJECT;
}
private static TensorAddress readAddress(TokenBuffer buffer, TensorType type) {