summaryrefslogtreecommitdiffstats
path: root/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
diff options
context:
space:
mode:
Diffstat (limited to 'document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java')
-rw-r--r--document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java55
1 files changed, 49 insertions, 6 deletions
diff --git a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
index 8e7dbd3512a..3177ec54465 100644
--- a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
+++ b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
@@ -6,8 +6,10 @@ import com.yahoo.document.TensorDataType;
import com.yahoo.document.serialization.FieldReader;
import com.yahoo.document.serialization.FieldWriter;
import com.yahoo.document.serialization.XmlStream;
+import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.serialization.TypedBinaryFormat;
import java.util.Optional;
@@ -20,6 +22,8 @@ public class TensorFieldValue extends FieldValue {
private Optional<Tensor> tensor;
+ private Optional<byte[]> serializedTensor;
+
private Optional<TensorDataType> dataType;
/**
@@ -28,27 +32,49 @@ public class TensorFieldValue extends FieldValue {
* The tensor (and tensor type) can later be assigned with assignTensor().
*/
public TensorFieldValue() {
- this.dataType = Optional.empty();
- this.tensor = Optional.empty();
+ this.dataType = Optional.empty();
+ this.serializedTensor = Optional.empty();
+ this.tensor = Optional.empty();
}
/** Create an empty tensor field value for the given tensor type */
public TensorFieldValue(TensorType type) {
this.dataType = Optional.of(new TensorDataType(type));
+ this.serializedTensor = Optional.empty();
this.tensor = Optional.empty();
}
/** Create a tensor field value containing the given tensor */
public TensorFieldValue(Tensor tensor) {
this.dataType = Optional.of(new TensorDataType(tensor.type()));
+ this.serializedTensor = Optional.empty();
this.tensor = Optional.of(tensor);
}
+ private void lazyDeserialize() {
+ if (tensor.isEmpty() && serializedTensor.isPresent()) {
+ Tensor t = TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(serializedTensor.get()));
+ if (dataType.isEmpty()) {
+ this.dataType = Optional.of(new TensorDataType(t.type()));
+ this.tensor = Optional.of(t);
+ } else {
+ if (t.type().isAssignableTo(dataType.get().getTensorType())) {
+ this.tensor = Optional.of(t);
+ } else {
+ throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + t.type() +
+ " to field of type " + dataType.get());
+ }
+ }
+ }
+ }
+
public Optional<Tensor> getTensor() {
+ lazyDeserialize();
return tensor;
}
public Optional<TensorType> getTensorType() {
+ lazyDeserialize();
return dataType.isPresent() ? Optional.of(dataType.get().getTensorType()) : Optional.empty();
}
@@ -59,8 +85,9 @@ public class TensorFieldValue extends FieldValue {
@Override
public String toString() {
- if (tensor.isPresent()) {
- return tensor.get().toString();
+ var t = getTensor();
+ if (t.isPresent()) {
+ return t.get().toString();
} else {
return "null";
}
@@ -74,6 +101,7 @@ public class TensorFieldValue extends FieldValue {
@Override
public void clear() {
tensor = Optional.empty();
+ serializedTensor = Optional.empty();
}
@Override
@@ -90,12 +118,28 @@ public class TensorFieldValue extends FieldValue {
}
}
+ public void assignSerializedTensor(byte[] data) {
+ serializedTensor = Optional.of(data);
+ tensor = Optional.empty();
+ }
+
+ public Optional<byte[]> getSerializedTensor() {
+ if (serializedTensor.isPresent()) {
+ return serializedTensor;
+ } else if (tensor.isPresent()) {
+ serializedTensor = Optional.of(TypedBinaryFormat.encode(tensor.get()));
+ assert(serializedTensor.isPresent());
+ }
+ return serializedTensor;
+ }
+
/**
* Assigns the given tensor to this field value.
*
* The tensor type is also set from the given tensor if it was not set before.
*/
public void assignTensor(Optional<Tensor> tensor) {
+ this.serializedTensor = Optional.empty();
if (tensor.isPresent()) {
if (getTensorType().isPresent() &&
!tensor.get().type().isAssignableTo(getTensorType().get())) {
@@ -126,7 +170,7 @@ public class TensorFieldValue extends FieldValue {
TensorFieldValue other = (TensorFieldValue)o;
if ( ! getTensorType().equals(other.getTensorType())) return false;
- if ( ! tensor.equals(other.tensor)) return false;
+ if ( ! getTensor().equals(other.getTensor())) return false;
return true;
}
@@ -136,4 +180,3 @@ public class TensorFieldValue extends FieldValue {
}
}
-