summaryrefslogtreecommitdiffstats
path: root/document/src/main/java/com/yahoo/document
diff options
context:
space:
mode:
Diffstat (limited to 'document/src/main/java/com/yahoo/document')
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java18
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java19
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java14
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java25
4 files changed, 63 insertions, 13 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
index 21fa51d5b88..92ede0fbe99 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
@@ -2,6 +2,7 @@
package com.yahoo.document.json.readers;
+import com.fasterxml.jackson.core.JsonToken;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
@@ -29,6 +30,7 @@ public class TensorModifyUpdateReader {
private static final String MODIFY_REPLACE = "replace";
private static final String MODIFY_ADD = "add";
private static final String MODIFY_MULTIPLY = "multiply";
+ private static final String MODIFY_CREATE = "create";
public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) {
expectFieldIsOfTypeTensor(field);
@@ -39,7 +41,7 @@ public class TensorModifyUpdateReader {
expectOperationSpecified(result.operation, field.getName());
expectTensorSpecified(result.tensor, field.getName());
- return new TensorModifyUpdate(result.operation, result.tensor);
+ return new TensorModifyUpdate(result.operation, result.tensor, result.createNonExistingCells);
}
private static void expectFieldIsOfTypeTensor(Field field) {
@@ -73,6 +75,7 @@ public class TensorModifyUpdateReader {
private static class ModifyUpdateResult {
TensorModifyUpdate.Operation operation = null;
+ boolean createNonExistingCells = false;
TensorFieldValue tensor = null;
}
@@ -85,6 +88,9 @@ public class TensorModifyUpdateReader {
case MODIFY_OPERATION:
result.operation = createOperation(buffer, field.getName());
break;
+ case MODIFY_CREATE:
+ result.createNonExistingCells = createNonExistingCells(buffer, field.getName());
+ break;
case TENSOR_CELLS:
result.tensor = createTensorFromCells(buffer, field);
break;
@@ -112,6 +118,16 @@ public class TensorModifyUpdateReader {
}
}
+ private static Boolean createNonExistingCells(TokenBuffer buffer, String fieldName) {
+ if (buffer.current() == JsonToken.VALUE_TRUE) {
+ return true;
+ } else if (buffer.current() == JsonToken.VALUE_FALSE) {
+ return false;
+ } else {
+ throw new IllegalArgumentException("Unknown value '" + buffer.currentText() + "' for '" + MODIFY_CREATE + "' in modify update for field '" + fieldName + "'");
+ }
+ }
+
private static TensorFieldValue createTensorFromCells(TokenBuffer buffer, Field field) {
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType originalType = tensorDataType.getTensorType();
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
index c6fdc915401..765d999dbc9 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
@@ -26,20 +26,35 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
@Override
protected ValueUpdate readTensorModifyUpdate(DataType type) {
byte operationId = getByte(null);
- TensorModifyUpdate.Operation operation = TensorModifyUpdate.Operation.getOperation(operationId);
+ var operation = decodeOperation(operationId);
if (operation == null) {
throw new DeserializationException("Unknown operation id " + operationId + " for tensor modify update");
}
if (!(type instanceof TensorDataType)) {
throw new DeserializationException("Expected tensor data type, got " + type);
}
+ var createNonExistingCells = decodeCreateNonExistingCells(operationId);
+ if (createNonExistingCells) {
+ // Read the default cell value (but it is not used by TensorModifyUpdate).
+ getDouble(null);
+ }
TensorDataType tensorDataType = (TensorDataType)type;
TensorType tensorType = tensorDataType.getTensorType();
TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(tensorType);
TensorFieldValue tensor = new TensorFieldValue(convertedType);
tensor.deserialize(this);
- return new TensorModifyUpdate(operation, tensor);
+ return new TensorModifyUpdate(operation, tensor, createNonExistingCells);
+ }
+
+ private TensorModifyUpdate.Operation decodeOperation(byte operationId) {
+ byte OP_MASK = 0b01111111;
+ return TensorModifyUpdate.Operation.getOperation(operationId & OP_MASK);
+ }
+
+ private boolean decodeCreateNonExistingCells(byte operationId) {
+ byte CREATE_FLAG = -0b10000000;
+ return (operationId & CREATE_FLAG) != 0;
}
@Override
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java
index 66bc8cbb4d5..b2c3cdc09de 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java
@@ -19,10 +19,22 @@ public class VespaDocumentSerializerHead extends VespaDocumentSerializer6 {
@Override
public void write(TensorModifyUpdate update) {
- putByte(null, (byte) update.getOperation().id);
+ putByte(null, (byte) encodeOperationId(update));
+ if (update.getCreateNonExistingCells()) {
+ putDouble(null, update.getDefaultCellValue());
+ }
update.getValue().serialize(this);
}
+ private int encodeOperationId(TensorModifyUpdate update) {
+ int operationId = update.getOperation().id;
+ byte CREATE_FLAG = -0b10000000;
+ if (update.getCreateNonExistingCells()) {
+ operationId |= CREATE_FLAG;
+ }
+ return operationId;
+ }
+
@Override
public void write(TensorAddUpdate update) {
update.getValue().serialize(this);
diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
index 835c056868a..8a14bd21443 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
@@ -26,12 +26,17 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
protected Operation operation;
protected TensorFieldValue tensor;
- protected Optional<Double> defaultCellValue = Optional.empty();
+ protected boolean createNonExistingCells;
public TensorModifyUpdate(Operation operation, TensorFieldValue tensor) {
+ this(operation, tensor, false);
+ }
+
+ public TensorModifyUpdate(Operation operation, TensorFieldValue tensor, boolean createNonExistingCells) {
super(ValueUpdateClassID.TENSORMODIFY);
this.operation = operation;
this.tensor = tensor;
+ this.createNonExistingCells = createNonExistingCells;
verifyCompatibleType(tensor.getDataType().getTensorType());
}
@@ -51,10 +56,12 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
}
public Operation getOperation() { return operation; }
-
public TensorFieldValue getValue() { return tensor; }
+ public boolean getCreateNonExistingCells() { return createNonExistingCells; }
+ public double getDefaultCellValue() {
+ return (operation == Operation.MULTIPLY) ? 1.0 : 0.0;
+ }
public void setValue(TensorFieldValue value) { tensor = value; }
- public void setDefaultCellValue(double value) { defaultCellValue = Optional.of(value); }
@Override
public FieldValue applyTo(FieldValue oldValue) {
@@ -70,10 +77,10 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
default:
throw new UnsupportedOperationException("Unknown operation: " + operation);
}
- if (defaultCellValue.isPresent() && hasMappedSubtype(oldTensor.type())) {
+ if (createNonExistingCells && hasMappedSubtype(oldTensor.type())) {
var subspaces = findSubspacesNotInInput(oldTensor, tensor.getTensor().get());
if (!subspaces.isEmpty()) {
- oldTensor = insertSubspaces(oldTensor, subspaces, defaultCellValue.get());
+ oldTensor = insertSubspaces(oldTensor, subspaces, getDefaultCellValue());
}
}
Tensor modified = oldTensor.modify(modifier, tensor.getTensor().get().cells());
@@ -142,7 +149,6 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
return builder.build();
}
-
@Override
protected void checkCompatibility(DataType fieldType) {
if (!(fieldType instanceof TensorDataType)) {
@@ -162,17 +168,18 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
if (!super.equals(o)) return false;
TensorModifyUpdate that = (TensorModifyUpdate) o;
return operation == that.operation &&
- tensor.equals(that.tensor);
+ tensor.equals(that.tensor) &&
+ createNonExistingCells == that.createNonExistingCells;
}
@Override
public int hashCode() {
- return Objects.hash(super.hashCode(), operation, tensor);
+ return Objects.hash(super.hashCode(), operation, tensor, createNonExistingCells);
}
@Override
public String toString() {
- return super.toString() + " " + operation.name + " " + tensor;
+ return super.toString() + " " + operation.name + " " + tensor + " " + createNonExistingCells;
}
/**