diff options
Diffstat (limited to 'document/src/main/java/com/yahoo/document')
4 files changed, 63 insertions, 13 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java index 21fa51d5b88..92ede0fbe99 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java @@ -2,6 +2,7 @@ package com.yahoo.document.json.readers; +import com.fasterxml.jackson.core.JsonToken; import com.yahoo.document.Field; import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.TensorFieldValue; @@ -29,6 +30,7 @@ public class TensorModifyUpdateReader { private static final String MODIFY_REPLACE = "replace"; private static final String MODIFY_ADD = "add"; private static final String MODIFY_MULTIPLY = "multiply"; + private static final String MODIFY_CREATE = "create"; public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) { expectFieldIsOfTypeTensor(field); @@ -39,7 +41,7 @@ public class TensorModifyUpdateReader { expectOperationSpecified(result.operation, field.getName()); expectTensorSpecified(result.tensor, field.getName()); - return new TensorModifyUpdate(result.operation, result.tensor); + return new TensorModifyUpdate(result.operation, result.tensor, result.createNonExistingCells); } private static void expectFieldIsOfTypeTensor(Field field) { @@ -73,6 +75,7 @@ public class TensorModifyUpdateReader { private static class ModifyUpdateResult { TensorModifyUpdate.Operation operation = null; + boolean createNonExistingCells = false; TensorFieldValue tensor = null; } @@ -85,6 +88,9 @@ public class TensorModifyUpdateReader { case MODIFY_OPERATION: result.operation = createOperation(buffer, field.getName()); break; + case MODIFY_CREATE: + result.createNonExistingCells = createNonExistingCells(buffer, field.getName()); + break; case TENSOR_CELLS: result.tensor = createTensorFromCells(buffer, field); break; @@ -112,6 +118,16 @@ public class TensorModifyUpdateReader { } } + private static Boolean createNonExistingCells(TokenBuffer buffer, String fieldName) { + if (buffer.current() == JsonToken.VALUE_TRUE) { + return true; + } else if (buffer.current() == JsonToken.VALUE_FALSE) { + return false; + } else { + throw new IllegalArgumentException("Unknown value '" + buffer.currentText() + "' for '" + MODIFY_CREATE + "' in modify update for field '" + fieldName + "'"); + } + } + private static TensorFieldValue createTensorFromCells(TokenBuffer buffer, Field field) { TensorDataType tensorDataType = (TensorDataType)field.getDataType(); TensorType originalType = tensorDataType.getTensorType(); diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java index c6fdc915401..765d999dbc9 100644 --- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java +++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java @@ -26,20 +26,35 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 { @Override protected ValueUpdate readTensorModifyUpdate(DataType type) { byte operationId = getByte(null); - TensorModifyUpdate.Operation operation = TensorModifyUpdate.Operation.getOperation(operationId); + var operation = decodeOperation(operationId); if (operation == null) { throw new DeserializationException("Unknown operation id " + operationId + " for tensor modify update"); } if (!(type instanceof TensorDataType)) { throw new DeserializationException("Expected tensor data type, got " + type); } + var createNonExistingCells = decodeCreateNonExistingCells(operationId); + if (createNonExistingCells) { + // Read the default cell value (but it is not used by TensorModifyUpdate). + getDouble(null); + } TensorDataType tensorDataType = (TensorDataType)type; TensorType tensorType = tensorDataType.getTensorType(); TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(tensorType); TensorFieldValue tensor = new TensorFieldValue(convertedType); tensor.deserialize(this); - return new TensorModifyUpdate(operation, tensor); + return new TensorModifyUpdate(operation, tensor, createNonExistingCells); + } + + private TensorModifyUpdate.Operation decodeOperation(byte operationId) { + byte OP_MASK = 0b01111111; + return TensorModifyUpdate.Operation.getOperation(operationId & OP_MASK); + } + + private boolean decodeCreateNonExistingCells(byte operationId) { + byte CREATE_FLAG = -0b10000000; + return (operationId & CREATE_FLAG) != 0; } @Override diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java index 66bc8cbb4d5..b2c3cdc09de 100644 --- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java +++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java @@ -19,10 +19,22 @@ public class VespaDocumentSerializerHead extends VespaDocumentSerializer6 { @Override public void write(TensorModifyUpdate update) { - putByte(null, (byte) update.getOperation().id); + putByte(null, (byte) encodeOperationId(update)); + if (update.getCreateNonExistingCells()) { + putDouble(null, update.getDefaultCellValue()); + } update.getValue().serialize(this); } + private int encodeOperationId(TensorModifyUpdate update) { + int operationId = update.getOperation().id; + byte CREATE_FLAG = -0b10000000; + if (update.getCreateNonExistingCells()) { + operationId |= CREATE_FLAG; + } + return operationId; + } + @Override public void write(TensorAddUpdate update) { update.getValue().serialize(this); diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java index 835c056868a..8a14bd21443 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java @@ -26,12 +26,17 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { protected Operation operation; protected TensorFieldValue tensor; - protected Optional<Double> defaultCellValue = Optional.empty(); + protected boolean createNonExistingCells; public TensorModifyUpdate(Operation operation, TensorFieldValue tensor) { + this(operation, tensor, false); + } + + public TensorModifyUpdate(Operation operation, TensorFieldValue tensor, boolean createNonExistingCells) { super(ValueUpdateClassID.TENSORMODIFY); this.operation = operation; this.tensor = tensor; + this.createNonExistingCells = createNonExistingCells; verifyCompatibleType(tensor.getDataType().getTensorType()); } @@ -51,10 +56,12 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { } public Operation getOperation() { return operation; } - public TensorFieldValue getValue() { return tensor; } + public boolean getCreateNonExistingCells() { return createNonExistingCells; } + public double getDefaultCellValue() { + return (operation == Operation.MULTIPLY) ? 1.0 : 0.0; + } public void setValue(TensorFieldValue value) { tensor = value; } - public void setDefaultCellValue(double value) { defaultCellValue = Optional.of(value); } @Override public FieldValue applyTo(FieldValue oldValue) { @@ -70,10 +77,10 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { default: throw new UnsupportedOperationException("Unknown operation: " + operation); } - if (defaultCellValue.isPresent() && hasMappedSubtype(oldTensor.type())) { + if (createNonExistingCells && hasMappedSubtype(oldTensor.type())) { var subspaces = findSubspacesNotInInput(oldTensor, tensor.getTensor().get()); if (!subspaces.isEmpty()) { - oldTensor = insertSubspaces(oldTensor, subspaces, defaultCellValue.get()); + oldTensor = insertSubspaces(oldTensor, subspaces, getDefaultCellValue()); } } Tensor modified = oldTensor.modify(modifier, tensor.getTensor().get().cells()); @@ -142,7 +149,6 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { return builder.build(); } - @Override protected void checkCompatibility(DataType fieldType) { if (!(fieldType instanceof TensorDataType)) { @@ -162,17 +168,18 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { if (!super.equals(o)) return false; TensorModifyUpdate that = (TensorModifyUpdate) o; return operation == that.operation && - tensor.equals(that.tensor); + tensor.equals(that.tensor) && + createNonExistingCells == that.createNonExistingCells; } @Override public int hashCode() { - return Objects.hash(super.hashCode(), operation, tensor); + return Objects.hash(super.hashCode(), operation, tensor, createNonExistingCells); } @Override public String toString() { - return super.toString() + " " + operation.name + " " + tensor; + return super.toString() + " " + operation.name + " " + tensor + " " + createNonExistingCells; } /** |