summaryrefslogtreecommitdiffstats
path: root/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
diff options
context:
space:
mode:
Diffstat (limited to 'document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java')
-rw-r--r--document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java44
1 files changed, 35 insertions, 9 deletions
diff --git a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
index 2c6a556c652..8e7dbd3512a 100644
--- a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
+++ b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java
@@ -20,17 +20,27 @@ public class TensorFieldValue extends FieldValue {
private Optional<Tensor> tensor;
- private final TensorDataType dataType;
+ private Optional<TensorDataType> dataType;
+
+ /**
+ * Create an empty tensor field value where the tensor type is not yet known.
+ *
+ * The tensor (and tensor type) can later be assigned with assignTensor().
+ */
+ public TensorFieldValue() {
+ this.dataType = Optional.empty();
+ this.tensor = Optional.empty();
+ }
- /** Create an empty tensor field value */
+ /** Create an empty tensor field value for the given tensor type */
public TensorFieldValue(TensorType type) {
- this.dataType = new TensorDataType(type);
+ this.dataType = Optional.of(new TensorDataType(type));
this.tensor = Optional.empty();
}
/** Create a tensor field value containing the given tensor */
public TensorFieldValue(Tensor tensor) {
- this.dataType = new TensorDataType(tensor.type());
+ this.dataType = Optional.of(new TensorDataType(tensor.type()));
this.tensor = Optional.of(tensor);
}
@@ -38,9 +48,13 @@ public class TensorFieldValue extends FieldValue {
return tensor;
}
+ public Optional<TensorType> getTensorType() {
+ return dataType.isPresent() ? Optional.of(dataType.get().getTensorType()) : Optional.empty();
+ }
+
@Override
public TensorDataType getDataType() {
- return dataType;
+ return dataType.get();
}
@Override
@@ -76,10 +90,22 @@ public class TensorFieldValue extends FieldValue {
}
}
+ /**
+ * Assigns the given tensor to this field value.
+ *
+ * The tensor type is also set from the given tensor if it was not set before.
+ */
public void assignTensor(Optional<Tensor> tensor) {
- if (tensor.isPresent() && ! tensor.get().type().isAssignableTo(dataType.getTensorType()))
- throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() +
- " to field of type " + dataType.getTensorType());
+ if (tensor.isPresent()) {
+ if (getTensorType().isPresent() &&
+ !tensor.get().type().isAssignableTo(getTensorType().get())) {
+ throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() +
+ " to field of type " + getTensorType().get());
+ }
+ if (getTensorType().isEmpty()) {
+ this.dataType = Optional.of(new TensorDataType(tensor.get().type()));
+ }
+ }
this.tensor = tensor;
}
@@ -99,7 +125,7 @@ public class TensorFieldValue extends FieldValue {
if ( ! (o instanceof TensorFieldValue)) return false;
TensorFieldValue other = (TensorFieldValue)o;
- if ( ! dataType.getTensorType().equals(other.dataType.getTensorType())) return false;
+ if ( ! getTensorType().equals(other.getTensorType())) return false;
if ( ! tensor.equals(other.tensor)) return false;
return true;
}