diff options
-rw-r--r-- | document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java | 48 |
1 files changed, 24 insertions, 24 deletions
diff --git a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java index 6c560d9207d..6dbe1c05646 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java +++ b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java @@ -53,18 +53,10 @@ public class TensorFieldValue extends FieldValue { private void lazyDeserialize() { if (tensor.isEmpty() && serializedTensor.isPresent()) { - Tensor t = TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(serializedTensor.get())); - if (dataType.isEmpty()) { - this.dataType = Optional.of(new TensorDataType(t.type())); - this.tensor = Optional.of(t); - } else { - if (t.type().isAssignableTo(dataType.get().getTensorType())) { - this.tensor = Optional.of(t); - } else { - throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + t.type() + - " to field of type " + dataType.get()); - } - } + var t = TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(serializedTensor.get())); + Optional<Tensor> newTensor = Optional.of(t); + assignTypeFrom(newTensor); + this.tensor = newTensor; } } @@ -74,7 +66,9 @@ public class TensorFieldValue extends FieldValue { } public Optional<TensorType> getTensorType() { - lazyDeserialize(); + if (! dataType.isPresent()) { + lazyDeserialize(); + } return dataType.isPresent() ? Optional.of(dataType.get().getTensorType()) : Optional.empty(); } @@ -104,6 +98,18 @@ public class TensorFieldValue extends FieldValue { serializedTensor = Optional.empty(); } + private void assignTypeFrom(Optional<Tensor> newTensor) { + if (newTensor.isEmpty()) return; + TensorType newType = newTensor.get().type(); + if (dataType.isEmpty()) { + this.dataType = Optional.of(new TensorDataType(newType)); + } + TensorType curType = dataType.get().getTensorType(); + if (! newType.isAssignableTo(curType)) { + throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + newType + " to field of type " + curType); + } + } + @Override public void assign(Object o) { if (o == null) { @@ -111,7 +117,10 @@ public class TensorFieldValue extends FieldValue { } else if (o instanceof Tensor) { assignTensor(Optional.of((Tensor)o)); } else if (o instanceof TensorFieldValue) { - assignTensor(((TensorFieldValue)o).getTensor()); + var tfv = (TensorFieldValue)o; + assignTypeFrom(tfv.tensor); + this.serializedTensor = tfv.serializedTensor; + this.tensor = tfv.tensor; } else { throw new IllegalArgumentException("Expected class '" + getClass().getName() + "', got '" + o.getClass().getName() + "'."); @@ -139,17 +148,8 @@ public class TensorFieldValue extends FieldValue { * The tensor type is also set from the given tensor if it was not set before. */ public void assignTensor(Optional<Tensor> tensor) { + assignTypeFrom(tensor); this.serializedTensor = Optional.empty(); - if (tensor.isPresent()) { - if (getTensorType().isPresent() && - !tensor.get().type().isAssignableTo(getTensorType().get())) { - throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() + - " to field of type " + getTensorType().get()); - } - if (getTensorType().isEmpty()) { - this.dataType = Optional.of(new TensorDataType(tensor.get().type())); - } - } this.tensor = tensor; } |