aboutsummaryrefslogtreecommitdiffstats
path: root/document/src/main/java/com
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2020-11-11 17:06:23 +0000
committerGeir Storli <geirst@verizonmedia.com>2020-11-17 12:57:49 +0000
commit4cd2c6a1d4d2ab7337678931271a815b535ce518 (patch)
treeede014280efb581bc841c7ab7702c0a7b9e028a6 /document/src/main/java/com
parente4c14623ad4ecbe6337a49d2176621c528bf7c22 (diff)
Extend tensor remove update to support not fully specified addresses and update JSON parser.
Previously, all the sparse dimensions of the sparse or mixed tensor type (to remove from) had to be specified in the addresses to remove.
Diffstat (limited to 'document/src/main/java/com')
-rw-r--r--document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java44
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java66
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java2
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java8
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java28
5 files changed, 110 insertions, 38 deletions
diff --git a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
index 2c6a556c652..8e7dbd3512a 100644
--- a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
+++ b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
@@ -20,17 +20,27 @@ public class TensorFieldValue extends FieldValue {
private Optional<Tensor> tensor;
- private final TensorDataType dataType;
+ private Optional<TensorDataType> dataType;
+
+ /**
+ * Create an empty tensor field value where the tensor type is not yet known.
+ *
+ * The tensor (and tensor type) can later be assigned with assignTensor().
+ */
+ public TensorFieldValue() {
+ this.dataType = Optional.empty();
+ this.tensor = Optional.empty();
+ }
- /** Create an empty tensor field value */
+ /** Create an empty tensor field value for the given tensor type */
public TensorFieldValue(TensorType type) {
- this.dataType = new TensorDataType(type);
+ this.dataType = Optional.of(new TensorDataType(type));
this.tensor = Optional.empty();
}
/** Create a tensor field value containing the given tensor */
public TensorFieldValue(Tensor tensor) {
- this.dataType = new TensorDataType(tensor.type());
+ this.dataType = Optional.of(new TensorDataType(tensor.type()));
this.tensor = Optional.of(tensor);
}
@@ -38,9 +48,13 @@ public class TensorFieldValue extends FieldValue {
return tensor;
}
+ public Optional<TensorType> getTensorType() {
+ return dataType.isPresent() ? Optional.of(dataType.get().getTensorType()) : Optional.empty();
+ }
+
@Override
public TensorDataType getDataType() {
- return dataType;
+ return dataType.get();
}
@Override
@@ -76,10 +90,22 @@ public class TensorFieldValue extends FieldValue {
}
}
+ /**
+ * Assigns the given tensor to this field value.
+ *
+ * The tensor type is also set from the given tensor if it was not set before.
+ */
public void assignTensor(Optional<Tensor> tensor) {
- if (tensor.isPresent() && ! tensor.get().type().isAssignableTo(dataType.getTensorType()))
- throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() +
- " to field of type " + dataType.getTensorType());
+ if (tensor.isPresent()) {
+ if (getTensorType().isPresent() &&
+ !tensor.get().type().isAssignableTo(getTensorType().get())) {
+ throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() +
+ " to field of type " + getTensorType().get());
+ }
+ if (getTensorType().isEmpty()) {
+ this.dataType = Optional.of(new TensorDataType(tensor.get().type()));
+ }
+ }
this.tensor = tensor;
}
@@ -99,7 +125,7 @@ public class TensorFieldValue extends FieldValue {
if ( ! (o instanceof TensorFieldValue)) return false;
TensorFieldValue other = (TensorFieldValue)o;
- if ( ! dataType.getTensorType().equals(other.dataType.getTensorType())) return false;
+ if ( ! getTensorType().equals(other.getTensorType())) return false;
if ( ! tensor.equals(other.tensor)) return false;
return true;
}
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
index 91c275b6da0..cffc85777dc 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
@@ -1,6 +1,7 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.document.json.readers;
+import com.yahoo.collections.Pair;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
@@ -10,6 +11,8 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import java.util.HashMap;
+
import static com.yahoo.document.json.readers.JsonParserHelpers.expectArrayStart;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectCompositeEnd;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectEnd;
@@ -28,8 +31,8 @@ public class TensorRemoveUpdateReader {
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType originalType = tensorDataType.getTensorType();
- TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(originalType);
- Tensor tensor = readRemoveUpdateTensor(buffer, convertedType, originalType);
+ TensorType sparseType = TensorRemoveUpdate.extractSparseDimensions(originalType);
+ Tensor tensor = readRemoveUpdateTensor(buffer, sparseType, originalType);
expectAddressesAreNonEmpty(field, tensor);
return new TensorRemoveUpdate(new TensorFieldValue(tensor));
@@ -54,8 +57,8 @@ public class TensorRemoveUpdateReader {
/**
* Reads all addresses in buffer and returns a tensor where addresses have cell value 1.0
*/
- private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type, TensorType originalType) {
- Tensor.Builder builder = Tensor.Builder.of(type);
+ private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType sparseType, TensorType originalType) {
+ Tensor.Builder builder = null;
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
@@ -63,13 +66,55 @@ public class TensorRemoveUpdateReader {
expectArrayStart(buffer.currentToken());
int nesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= nesting; buffer.next()) {
- builder.cell(readTensorAddress(buffer, type, originalType), 1.0);
+ if (builder == null) {
+ var typeAndAddress = readFirstTensorAddress(buffer, sparseType, originalType);
+ builder = Tensor.Builder.of(typeAndAddress.getFirst());
+ builder.cell(typeAndAddress.getSecond(), 1.0);
+ } else {
+ builder.cell(readTensorAddress(buffer, builder.type(), originalType), 1.0);
+ }
}
expectCompositeEnd(buffer.currentToken());
}
}
expectObjectEnd(buffer.currentToken());
- return builder.build();
+ return (builder != null) ? builder.build() : Tensor.Builder.of(sparseType).build();
+ }
+
+ /**
+ * Reads the first raw tensor address from the given buffer and resolves and returns the tensor type and tensor address based on this.
+ * The resulting tensor type contains a subset or all of the dimensions from the given sparseType.
+ */
+ private static Pair<TensorType, TensorAddress> readFirstTensorAddress(TokenBuffer buffer, TensorType sparseType, TensorType originalType) {
+ var typeBuilder = new TensorType.Builder(sparseType.valueType());
+ var rawAddress = new HashMap<String, String>();
+ expectObjectStart(buffer.currentToken());
+ int initNesting = buffer.nesting();
+ for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
+ var elem = readRawElement(buffer, sparseType, originalType);
+ var dimension = sparseType.dimension(elem.getFirst());
+ if (dimension.isPresent()) {
+ typeBuilder.dimension(dimension.get());
+ rawAddress.put(elem.getFirst(), elem.getSecond());
+ } else {
+ throw new IllegalArgumentException(originalType + " does not contain dimension '" + elem.getFirst() + "'");
+ }
+ }
+ expectObjectEnd(buffer.currentToken());
+ var type = typeBuilder.build();
+ var builder = new TensorAddress.Builder(type);
+ rawAddress.forEach((dimension, label) -> builder.add(dimension, label));
+ return new Pair<>(type, builder.build());
+ }
+
+ private static Pair<String, String> readRawElement(TokenBuffer buffer, TensorType type, TensorType originalType) {
+ String dimension = buffer.currentName();
+ if (type.dimension(dimension).isEmpty() && originalType.dimension(dimension).isPresent()) {
+ throw new IllegalArgumentException("Indexed dimension address '" + dimension +
+ "' should not be specified in remove update");
+ }
+ String label = buffer.currentText();
+ return new Pair<>(dimension, label);
}
private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type, TensorType originalType) {
@@ -77,13 +122,8 @@ public class TensorRemoveUpdateReader {
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
- String dimension = buffer.currentName();
- if ( type.dimension(dimension).isEmpty() && originalType.dimension(dimension).isPresent()) {
- throw new IllegalArgumentException("Indexed dimension address '" + dimension +
- "' should not be specified in remove update");
- }
- String label = buffer.currentText();
- builder.add(dimension, label);
+ var elem = readRawElement(buffer, type, originalType);
+ builder.add(elem.getFirst(), elem.getSecond());
}
expectObjectEnd(buffer.currentToken());
return builder.build();
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java
index cac05fb7879..92b3b566b85 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer6.java
@@ -243,7 +243,7 @@ public class VespaDocumentDeserializer6 extends BufferSerializer implements Docu
int encodedTensorLength = buf.getInt1_4Bytes();
if (encodedTensorLength > 0) {
byte[] encodedTensor = getBytes(null, encodedTensorLength);
- value.assign(TypedBinaryFormat.decode(Optional.of(value.getDataType().getTensorType()),
+ value.assign(TypedBinaryFormat.decode(value.getTensorType(),
GrowableByteBuffer.wrap(encodedTensor)));
} else {
value.clear();
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
index 58c50f047f9..e7f1525ff81 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
@@ -61,10 +61,12 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
}
TensorDataType tensorDataType = (TensorDataType)type;
TensorType tensorType = tensorDataType.getTensorType();
- TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(tensorType);
- TensorFieldValue tensor = new TensorFieldValue(convertedType);
+ TensorFieldValue tensor = new TensorFieldValue();
tensor.deserialize(this);
- return new TensorRemoveUpdate(tensor);
+ var result = new TensorRemoveUpdate(tensor);
+ result.verifyCompatibleType(tensorType);
+ return result;
}
+
}
diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
index 981120af145..a300565391f 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
@@ -12,9 +12,10 @@ import com.yahoo.tensor.TensorType;
import java.util.Objects;
/**
- * An update used to remove cells from a sparse tensor (has only mapped dimensions).
+ * An update used to remove cells from a sparse tensor or dense sub-spaces from a mixed tensor.
*
- * The cells to remove are contained in a sparse tensor where cell values are set to 1.0
+ * The specification of which cells to remove contains addresses using a subset or all of the sparse dimensions of the tensor type.
+ * This is represented as a sparse tensor where cell values are set to 1.0.
*/
public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
@@ -23,17 +24,20 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
public TensorRemoveUpdate(TensorFieldValue value) {
super(ValueUpdateClassID.TENSORREMOVE);
this.tensor = value;
- verifyCompatibleType();
- }
-
- private void verifyCompatibleType() {
- if ( ! tensor.getTensor().isPresent()) {
+ if (!tensor.getTensor().isPresent()) {
throw new IllegalArgumentException("Tensor must be present in remove update");
}
- TensorType tensorType = tensor.getTensor().get().type();
- TensorType expectedType = extractSparseDimensions(tensor.getDataType().getTensorType());
- if ( ! tensorType.equals(expectedType)) {
- throw new IllegalArgumentException("Unexpected type '" + tensorType + "' in remove update. Expected is '" + expectedType + "'");
+ verifyCompatibleType(tensor.getTensorType().get());
+ }
+
+ public void verifyCompatibleType(TensorType originalType) {
+ TensorType sparseType = extractSparseDimensions(originalType);
+ TensorType thisType = tensor.getTensorType().get();
+ for (var dim : thisType.dimensions()) {
+ if (sparseType.dimension(dim.name()).isEmpty()) {
+ throw new IllegalArgumentException("Unexpected type '" + thisType + "' in remove update. "
+ + "Expected dimensions to be a subset of '" + sparseType + "'");
+ }
}
}
@@ -63,6 +67,7 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
Tensor old = ((TensorFieldValue) oldValue).getTensor().get();
Tensor update = tensor.getTensor().get();
+ // TODO: handle the case where this tensor only contains a subset of the sparse dimensions of the input tensor.
Tensor result = old.remove(update.cells().keySet());
return new TensorFieldValue(result);
}
@@ -102,5 +107,4 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
return builder.build();
}
-
}