diff options
Diffstat (limited to 'document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java')
-rw-r--r-- | document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java | 44 |
1 files changed, 35 insertions, 9 deletions
diff --git a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java index 2c6a556c652..8e7dbd3512a 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java +++ b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java @@ -20,17 +20,27 @@ public class TensorFieldValue extends FieldValue { private Optional<Tensor> tensor; - private final TensorDataType dataType; + private Optional<TensorDataType> dataType; + + /** + * Create an empty tensor field value where the tensor type is not yet known. + * + * The tensor (and tensor type) can later be assigned with assignTensor(). + */ + public TensorFieldValue() { + this.dataType = Optional.empty(); + this.tensor = Optional.empty(); + } - /** Create an empty tensor field value */ + /** Create an empty tensor field value for the given tensor type */ public TensorFieldValue(TensorType type) { - this.dataType = new TensorDataType(type); + this.dataType = Optional.of(new TensorDataType(type)); this.tensor = Optional.empty(); } /** Create a tensor field value containing the given tensor */ public TensorFieldValue(Tensor tensor) { - this.dataType = new TensorDataType(tensor.type()); + this.dataType = Optional.of(new TensorDataType(tensor.type())); this.tensor = Optional.of(tensor); } @@ -38,9 +48,13 @@ public class TensorFieldValue extends FieldValue { return tensor; } + public Optional<TensorType> getTensorType() { + return dataType.isPresent() ? Optional.of(dataType.get().getTensorType()) : Optional.empty(); + } + @Override public TensorDataType getDataType() { - return dataType; + return dataType.get(); } @Override @@ -76,10 +90,22 @@ public class TensorFieldValue extends FieldValue { } } + /** + * Assigns the given tensor to this field value. + * + * The tensor type is also set from the given tensor if it was not set before. + */ public void assignTensor(Optional<Tensor> tensor) { - if (tensor.isPresent() && ! tensor.get().type().isAssignableTo(dataType.getTensorType())) - throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() + - " to field of type " + dataType.getTensorType()); + if (tensor.isPresent()) { + if (getTensorType().isPresent() && + !tensor.get().type().isAssignableTo(getTensorType().get())) { + throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() + + " to field of type " + getTensorType().get()); + } + if (getTensorType().isEmpty()) { + this.dataType = Optional.of(new TensorDataType(tensor.get().type())); + } + } this.tensor = tensor; } @@ -99,7 +125,7 @@ public class TensorFieldValue extends FieldValue { if ( ! (o instanceof TensorFieldValue)) return false; TensorFieldValue other = (TensorFieldValue)o; - if ( ! dataType.getTensorType().equals(other.dataType.getTensorType())) return false; + if ( ! getTensorType().equals(other.getTensorType())) return false; if ( ! tensor.equals(other.tensor)) return false; return true; } |