summaryrefslogtreecommitdiffstats
path: root/document/src/main
diff options
context:
space:
mode:
authorGeir Storli <geirst@yahooinc.com>2023-08-25 10:07:35 +0000
committerGeir Storli <geirst@yahooinc.com>2023-08-25 10:07:35 +0000
commit6fd175bcb3e1b61323214380d6f23324a3056043 (patch)
tree6644e278e75daff81491a265ced5ccdc3cd4413a /document/src/main
parent39420e6f2331825568605cfeb2975844de99de3a (diff)
Add "create non-existing cells" flag to TensorModifyUpdate.
When this is true, non-existing cells in the input tensor is created before applying the modify update. The default cell value is 0.0 for REPLACE and ADD operations, and 1.0 for MULTIPLY operation.
Diffstat (limited to 'document/src/main')
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java18
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java19
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java14
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java25
4 files changed, 63 insertions, 13 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
index 21fa51d5b88..92ede0fbe99 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
@@ -2,6 +2,7 @@
package com.yahoo.document.json.readers;
+import com.fasterxml.jackson.core.JsonToken;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
@@ -29,6 +30,7 @@ public class TensorModifyUpdateReader {
private static final String MODIFY_REPLACE = "replace";
private static final String MODIFY_ADD = "add";
private static final String MODIFY_MULTIPLY = "multiply";
+ private static final String MODIFY_CREATE = "create";
public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) {
expectFieldIsOfTypeTensor(field);
@@ -39,7 +41,7 @@ public class TensorModifyUpdateReader {
expectOperationSpecified(result.operation, field.getName());
expectTensorSpecified(result.tensor, field.getName());
- return new TensorModifyUpdate(result.operation, result.tensor);
+ return new TensorModifyUpdate(result.operation, result.tensor, result.createNonExistingCells);
}
private static void expectFieldIsOfTypeTensor(Field field) {
@@ -73,6 +75,7 @@ public class TensorModifyUpdateReader {
private static class ModifyUpdateResult {
TensorModifyUpdate.Operation operation = null;
+ boolean createNonExistingCells = false;
TensorFieldValue tensor = null;
}
@@ -85,6 +88,9 @@ public class TensorModifyUpdateReader {
case MODIFY_OPERATION:
result.operation = createOperation(buffer, field.getName());
break;
+ case MODIFY_CREATE:
+ result.createNonExistingCells = createNonExistingCells(buffer, field.getName());
+ break;
case TENSOR_CELLS:
result.tensor = createTensorFromCells(buffer, field);
break;
@@ -112,6 +118,16 @@ public class TensorModifyUpdateReader {
}
}
+ private static Boolean createNonExistingCells(TokenBuffer buffer, String fieldName) {
+ if (buffer.current() == JsonToken.VALUE_TRUE) {
+ return true;
+ } else if (buffer.current() == JsonToken.VALUE_FALSE) {
+ return false;
+ } else {
+ throw new IllegalArgumentException("Unknown value '" + buffer.currentText() + "' for '" + MODIFY_CREATE + "' in modify update for field '" + fieldName + "'");
+ }
+ }
+
private static TensorFieldValue createTensorFromCells(TokenBuffer buffer, Field field) {
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType originalType = tensorDataType.getTensorType();
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
index c6fdc915401..765d999dbc9 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
@@ -26,20 +26,35 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
@Override
protected ValueUpdate readTensorModifyUpdate(DataType type) {
byte operationId = getByte(null);
- TensorModifyUpdate.Operation operation = TensorModifyUpdate.Operation.getOperation(operationId);
+ var operation = decodeOperation(operationId);
if (operation == null) {
throw new DeserializationException("Unknown operation id " + operationId + " for tensor modify update");
}
if (!(type instanceof TensorDataType)) {
throw new DeserializationException("Expected tensor data type, got " + type);
}
+ var createNonExistingCells = decodeCreateNonExistingCells(operationId);
+ if (createNonExistingCells) {
+ // Read the default cell value (but it is not used by TensorModifyUpdate).
+ getDouble(null);
+ }
TensorDataType tensorDataType = (TensorDataType)type;
TensorType tensorType = tensorDataType.getTensorType();
TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(tensorType);
TensorFieldValue tensor = new TensorFieldValue(convertedType);
tensor.deserialize(this);
- return new TensorModifyUpdate(operation, tensor);
+ return new TensorModifyUpdate(operation, tensor, createNonExistingCells);
+ }
+
+ private TensorModifyUpdate.Operation decodeOperation(byte operationId) {
+ byte OP_MASK = 0b01111111;
+ return TensorModifyUpdate.Operation.getOperation(operationId & OP_MASK);
+ }
+
+ private boolean decodeCreateNonExistingCells(byte operationId) {
+ byte CREATE_FLAG = -0b10000000;
+ return (operationId & CREATE_FLAG) != 0;
}
@Override
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java
index 66bc8cbb4d5..b2c3cdc09de 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java
@@ -19,10 +19,22 @@ public class VespaDocumentSerializerHead extends VespaDocumentSerializer6 {
@Override
public void write(TensorModifyUpdate update) {
- putByte(null, (byte) update.getOperation().id);
+ putByte(null, (byte) encodeOperationId(update));
+ if (update.getCreateNonExistingCells()) {
+ putDouble(null, update.getDefaultCellValue());
+ }
update.getValue().serialize(this);
}
+ private int encodeOperationId(TensorModifyUpdate update) {
+ int operationId = update.getOperation().id;
+ byte CREATE_FLAG = -0b10000000;
+ if (update.getCreateNonExistingCells()) {
+ operationId |= CREATE_FLAG;
+ }
+ return operationId;
+ }
+
@Override
public void write(TensorAddUpdate update) {
update.getValue().serialize(this);
diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
index 835c056868a..8a14bd21443 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
@@ -26,12 +26,17 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
protected Operation operation;
protected TensorFieldValue tensor;
- protected Optional<Double> defaultCellValue = Optional.empty();
+ protected boolean createNonExistingCells;
public TensorModifyUpdate(Operation operation, TensorFieldValue tensor) {
+ this(operation, tensor, false);
+ }
+
+ public TensorModifyUpdate(Operation operation, TensorFieldValue tensor, boolean createNonExistingCells) {
super(ValueUpdateClassID.TENSORMODIFY);
this.operation = operation;
this.tensor = tensor;
+ this.createNonExistingCells = createNonExistingCells;
verifyCompatibleType(tensor.getDataType().getTensorType());
}
@@ -51,10 +56,12 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
}
public Operation getOperation() { return operation; }
-
public TensorFieldValue getValue() { return tensor; }
+ public boolean getCreateNonExistingCells() { return createNonExistingCells; }
+ public double getDefaultCellValue() {
+ return (operation == Operation.MULTIPLY) ? 1.0 : 0.0;
+ }
public void setValue(TensorFieldValue value) { tensor = value; }
- public void setDefaultCellValue(double value) { defaultCellValue = Optional.of(value); }
@Override
public FieldValue applyTo(FieldValue oldValue) {
@@ -70,10 +77,10 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
default:
throw new UnsupportedOperationException("Unknown operation: " + operation);
}
- if (defaultCellValue.isPresent() && hasMappedSubtype(oldTensor.type())) {
+ if (createNonExistingCells && hasMappedSubtype(oldTensor.type())) {
var subspaces = findSubspacesNotInInput(oldTensor, tensor.getTensor().get());
if (!subspaces.isEmpty()) {
- oldTensor = insertSubspaces(oldTensor, subspaces, defaultCellValue.get());
+ oldTensor = insertSubspaces(oldTensor, subspaces, getDefaultCellValue());
}
}
Tensor modified = oldTensor.modify(modifier, tensor.getTensor().get().cells());
@@ -142,7 +149,6 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
return builder.build();
}
-
@Override
protected void checkCompatibility(DataType fieldType) {
if (!(fieldType instanceof TensorDataType)) {
@@ -162,17 +168,18 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
if (!super.equals(o)) return false;
TensorModifyUpdate that = (TensorModifyUpdate) o;
return operation == that.operation &&
- tensor.equals(that.tensor);
+ tensor.equals(that.tensor) &&
+ createNonExistingCells == that.createNonExistingCells;
}
@Override
public int hashCode() {
- return Objects.hash(super.hashCode(), operation, tensor);
+ return Objects.hash(super.hashCode(), operation, tensor, createNonExistingCells);
}
@Override
public String toString() {
- return super.toString() + " " + operation.name + " " + tensor;
+ return super.toString() + " " + operation.name + " " + tensor + " " + createNonExistingCells;
}
/**