summaryrefslogtreecommitdiffstats
path: root/document
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2019-02-22 09:10:34 +0100
committerGitHub <noreply@github.com>2019-02-22 09:10:34 +0100
commit21a0951fb8906d60a6c2f565f6aac40087e986fe (patch)
tree30a5a80fa5a5d8046c12117be222ad91cbe10b10 /document
parentf3e121c715d2cb60102b88494a4daccf1ec2ebc4 (diff)
parent21651a8420530f069d42f37ca4dd0381f043501a (diff)
Merge pull request #8558 from vespa-engine/lesters/tensor-partial-update-mixed-tensors-java
Tensor partial update for mixed tensors - Java
Diffstat (limited to 'document')
-rw-r--r--document/abi-spec.json3
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java14
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java37
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java27
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java14
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java26
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java2
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java38
-rw-r--r--document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java63
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java92
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java8
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java17
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java3
13 files changed, 224 insertions, 120 deletions
diff --git a/document/abi-spec.json b/document/abi-spec.json
index 61390af3523..d4db3026b27 100644
--- a/document/abi-spec.json
+++ b/document/abi-spec.json
@@ -5244,7 +5244,7 @@
],
"methods": [
"public void <init>(com.yahoo.document.update.TensorModifyUpdate$Operation, com.yahoo.document.datatypes.TensorFieldValue)",
- "public static com.yahoo.tensor.TensorType convertToCompatibleType(com.yahoo.tensor.TensorType)",
+ "public static com.yahoo.tensor.TensorType convertDimensionsToMapped(com.yahoo.tensor.TensorType)",
"public com.yahoo.document.update.TensorModifyUpdate$Operation getOperation()",
"public com.yahoo.document.datatypes.TensorFieldValue getValue()",
"public void setValue(com.yahoo.document.datatypes.TensorFieldValue)",
@@ -5278,6 +5278,7 @@
"public boolean equals(java.lang.Object)",
"public int hashCode()",
"public java.lang.String toString()",
+ "public static com.yahoo.tensor.TensorType extractSparseDimensions(com.yahoo.tensor.TensorType)",
"public bridge synthetic void setValue(com.yahoo.document.datatypes.FieldValue)",
"public bridge synthetic com.yahoo.document.datatypes.FieldValue getValue()"
],
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
index ffbfe49347c..6310fa62d15 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
@@ -23,22 +23,23 @@ public class TensorAddUpdateReader {
public static TensorAddUpdate createTensorAddUpdate(TokenBuffer buffer, Field field) {
expectObjectStart(buffer.currentToken());
- expectTensorTypeIsSparse(field);
+ expectTensorTypeHasSparseDimensions(field);
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType tensorType = tensorDataType.getTensorType();
TensorFieldValue tensorFieldValue = new TensorFieldValue(tensorType);
fillTensor(buffer, tensorFieldValue);
+
expectTensorIsNonEmpty(field, tensorFieldValue.getTensor().get());
return new TensorAddUpdate(tensorFieldValue);
}
- private static void expectTensorTypeIsSparse(Field field) {
+ private static void expectTensorTypeHasSparseDimensions(Field field) {
TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
- if (tensorType.dimensions().stream()
- .anyMatch(dim -> dim.isIndexed())) {
- throw new IllegalArgumentException("An add update can only be applied to sparse tensors. "
- + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'");
+ if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) {
+ throw new IllegalArgumentException("An add update can only be applied to tensors " +
+ "with at least one sparse dimension. Field '" + field.getName() +
+ "' has unsupported tensor type '" + tensorType + "'");
}
}
@@ -48,5 +49,4 @@ public class TensorAddUpdateReader {
}
}
-
}
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
index a9bbba519bd..66588debbca 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
@@ -29,10 +29,8 @@ public class TensorModifyUpdateReader {
private static final String MODIFY_MULTIPLY = "multiply";
public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) {
-
expectFieldIsOfTypeTensor(field);
expectTensorTypeHasNoneIndexedUnboundDimensions(field);
- expectTensorTypeIsNotMixed(field);
expectObjectStart(buffer.currentToken());
ModifyUpdateResult result = createModifyUpdateResult(buffer, field);
@@ -58,16 +56,6 @@ public class TensorModifyUpdateReader {
}
}
- private static void expectTensorTypeIsNotMixed(Field field) {
- TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
- long numMappedDimensions = tensorType.dimensions().stream().filter(dim -> dim.type().equals(TensorType.Dimension.Type.mapped)).count();
- long numIndexedDimensions = tensorType.dimensions().stream().filter(dim -> dim.isIndexed()).count();
- if (numMappedDimensions > 0 && numIndexedDimensions > 0) {
- throw new IllegalArgumentException("A modify update cannot be applied to tensor types with mixed dimensions. "
- + "Field '" + field.getName() + "' has mixed tensor type '" + tensorType + "'");
- }
- }
-
private static void expectOperationSpecified(TensorModifyUpdate.Operation operation, String fieldName) {
if (operation == null) {
throw new IllegalArgumentException("Modify update for field '" + fieldName + "' does not contain an operation");
@@ -121,7 +109,7 @@ public class TensorModifyUpdateReader {
private static TensorFieldValue createTensor(TokenBuffer buffer, Field field) {
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType originalType = tensorDataType.getTensorType();
- TensorType convertedType = TensorModifyUpdate.convertToCompatibleType(originalType);
+ TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(originalType);
Tensor.Builder tensorBuilder = Tensor.Builder.of(convertedType);
readTensorCells(buffer, tensorBuilder);
@@ -129,25 +117,26 @@ public class TensorModifyUpdateReader {
validateBounds(tensor, originalType);
- TensorFieldValue result = new TensorFieldValue(convertedType);
- result.assign(tensor);
- return result;
+ return new TensorFieldValue(tensor);
}
- /** Only validate if original type is indexed bound */
- private static void validateBounds(Tensor convertedTensor, TensorType originalType) {
- if ( ! originalType.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) {
+ /** Only validate if original type has indexed bound dimensions */
+ static void validateBounds(Tensor convertedTensor, TensorType originalType) {
+ if (originalType.dimensions().stream().noneMatch(d -> d instanceof TensorType.IndexedBoundDimension)) {
return;
}
for (Iterator<Tensor.Cell> iter = convertedTensor.cellIterator(); iter.hasNext(); ) {
Tensor.Cell cell = iter.next();
TensorAddress address = cell.getKey();
for (int i = 0; i < address.size(); ++i) {
- long label = address.numericLabel(i);
- long bound = originalType.dimensions().get(i).size().get(); // size is non-optional for indexed bound
- if (label >= bound) {
- throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() +
- "' has label '" + label + "' but type is " + originalType.toString());
+ TensorType.Dimension dim = originalType.dimensions().get(i);
+ if (dim instanceof TensorType.IndexedBoundDimension) {
+ long label = address.numericLabel(i);
+ long bound = dim.size().get(); // size is non-optional for indexed bound
+ if (label >= bound) {
+ throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() +
+ "' has label '" + label + "' but type is " + originalType.toString());
+ }
}
}
}
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
index 210a6a80ee5..3bb4b2e262f 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
@@ -24,23 +24,23 @@ public class TensorRemoveUpdateReader {
static TensorRemoveUpdate createTensorRemoveUpdate(TokenBuffer buffer, Field field) {
expectObjectStart(buffer.currentToken());
- expectTensorTypeIsSparse(field);
+ expectTensorTypeHasSparseDimensions(field);
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
- TensorType tensorType = tensorDataType.getTensorType();
+ TensorType originalType = tensorDataType.getTensorType();
+ TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(originalType);
+ Tensor tensor = readRemoveUpdateTensor(buffer, convertedType, originalType);
- // TODO: for mixed case extract a new tensor type based only on mapped dimensions
-
- Tensor tensor = readRemoveUpdateTensor(buffer, tensorType);
expectAddressesAreNonEmpty(field, tensor);
return new TensorRemoveUpdate(new TensorFieldValue(tensor));
}
- private static void expectTensorTypeIsSparse(Field field) {
+ private static void expectTensorTypeHasSparseDimensions(Field field) {
TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
- if (tensorType.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed)) {
- throw new IllegalArgumentException("A remove update can only be applied to sparse tensors. "
- + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'");
+ if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) {
+ throw new IllegalArgumentException("A remove update can only be applied to tensors " +
+ "with at least one sparse dimension. Field '" + field.getName() +
+ "' has unsupported tensor type '" + tensorType + "'");
}
}
@@ -53,7 +53,7 @@ public class TensorRemoveUpdateReader {
/**
* Reads all addresses in buffer and returns a tensor where addresses have cell value 1.0
*/
- private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type) {
+ private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type, TensorType originalType) {
Tensor.Builder builder = Tensor.Builder.of(type);
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
@@ -62,7 +62,7 @@ public class TensorRemoveUpdateReader {
expectArrayStart(buffer.currentToken());
int nesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= nesting; buffer.next()) {
- builder.cell(readTensorAddress(buffer, type), 1.0);
+ builder.cell(readTensorAddress(buffer, type, originalType), 1.0);
}
expectCompositeEnd(buffer.currentToken());
}
@@ -71,12 +71,15 @@ public class TensorRemoveUpdateReader {
return builder.build();
}
- private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type) {
+ private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type, TensorType originalType) {
TensorAddress.Builder builder = new TensorAddress.Builder(type);
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
String dimension = buffer.currentName();
+ if ( ! type.dimension(dimension).isPresent() && originalType.dimension(dimension).isPresent()) {
+ throw new IllegalArgumentException("Indexed dimension address '" + dimension + "' should not be specified in remove update");
+ }
String label = buffer.currentText();
builder.add(dimension, label);
}
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
index 2f22def9aa1..a763db33e7a 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
@@ -5,6 +5,7 @@ import com.yahoo.document.DataType;
import com.yahoo.document.DocumentTypeManager;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
+import com.yahoo.document.json.readers.TensorRemoveUpdateReader;
import com.yahoo.document.update.TensorAddUpdate;
import com.yahoo.document.update.TensorModifyUpdate;
import com.yahoo.document.update.TensorRemoveUpdate;
@@ -35,7 +36,10 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
throw new DeserializationException("Expected tensor data type, got " + type);
}
TensorDataType tensorDataType = (TensorDataType)type;
- TensorFieldValue tensor = new TensorFieldValue(TensorModifyUpdate.convertToCompatibleType(tensorDataType.getTensorType()));
+ TensorType tensorType = tensorDataType.getTensorType();
+ TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(tensorType);
+
+ TensorFieldValue tensor = new TensorFieldValue(convertedType);
tensor.deserialize(this);
return new TensorModifyUpdate(operation, tensor);
}
@@ -46,7 +50,8 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
throw new DeserializationException("Expected tensor data type, got " + type);
}
TensorDataType tensorDataType = (TensorDataType)type;
- TensorFieldValue tensor = new TensorFieldValue(tensorDataType.getTensorType());
+ TensorType tensorType = tensorDataType.getTensorType();
+ TensorFieldValue tensor = new TensorFieldValue(tensorType);
tensor.deserialize(this);
return new TensorAddUpdate(tensor);
}
@@ -58,10 +63,9 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
}
TensorDataType tensorDataType = (TensorDataType)type;
TensorType tensorType = tensorDataType.getTensorType();
+ TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(tensorType);
- // TODO: for mixed case extract a new tensor type based only on mapped dimensions
-
- TensorFieldValue tensor = new TensorFieldValue(tensorType);
+ TensorFieldValue tensor = new TensorFieldValue(convertedType);
tensor.deserialize(this);
return new TensorRemoveUpdate(tensor);
}
diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
index cfc3ee0c742..f8d2464deb7 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
@@ -7,15 +7,11 @@ import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.serialization.DocumentUpdateWriter;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
-import java.util.Map;
import java.util.Objects;
/**
- * An update used to add cells to a sparse tensor (has only mapped dimensions).
- *
- * The cells to add are contained in a sparse tensor as well.
+ * An update used to add cells to a sparse or mixed tensor (has at least one mapped dimension).
*/
public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> {
@@ -50,22 +46,10 @@ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> {
return oldValue;
}
- Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get();
- Map<TensorAddress, Double> oldCells = oldTensor.cells();
- Map<TensorAddress, Double> addCells = tensor.getTensor().get().cells();
-
- // currently, underlying implementation disallows multiple entries with the same key
-
- Tensor.Builder builder = Tensor.Builder.of(oldTensor.type());
- for (Map.Entry<TensorAddress, Double> oldCell : oldCells.entrySet()) {
- builder.cell(oldCell.getKey(), addCells.getOrDefault(oldCell.getKey(), oldCell.getValue()));
- }
- for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) {
- if ( ! oldCells.containsKey(addCell.getKey())) {
- builder.cell(addCell.getKey(), addCell.getValue());
- }
- }
- return new TensorFieldValue(builder.build());
+ Tensor old = ((TensorFieldValue) oldValue).getTensor().get();
+ Tensor update = tensor.getTensor().get();
+ Tensor result = old.merge((left, right) -> right, update.cells()); // note this might be slow for large mixed tensor updates
+ return new TensorFieldValue(result);
}
@Override
diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
index 6111b51ca4e..2773f9d31da 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
@@ -37,7 +37,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
/**
* Converts the given tensor type to a type that is compatible for being used in this update (has only mapped dimensions).
*/
- public static TensorType convertToCompatibleType(TensorType type) {
+ public static TensorType convertDimensionsToMapped(TensorType type) {
TensorType.Builder builder = new TensorType.Builder();
type.dimensions().stream().forEach(dim -> builder.mapped(dim.name()));
return builder.build();
diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
index e9fb1e3efd5..335cda8e133 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
@@ -7,10 +7,8 @@ import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.serialization.DocumentUpdateWriter;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
-import java.util.Iterator;
-import java.util.Map;
import java.util.Objects;
/**
@@ -25,6 +23,18 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
public TensorRemoveUpdate(TensorFieldValue value) {
super(ValueUpdateClassID.TENSORREMOVE);
this.tensor = value;
+ verifyCompatibleType();
+ }
+
+ private void verifyCompatibleType() {
+ if ( ! tensor.getTensor().isPresent()) {
+ throw new IllegalArgumentException("Tensor must be present in remove update");
+ }
+ TensorType tensorType = tensor.getTensor().get().type();
+ TensorType expectedType = extractSparseDimensions(tensor.getDataType().getTensorType());
+ if ( ! tensorType.equals(expectedType)) {
+ throw new IllegalArgumentException("Unexpected type '" + tensorType + "' in remove update. Expected is '" + expectedType + "'");
+ }
}
@Override
@@ -51,17 +61,10 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
return oldValue;
}
- Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get();
- Map<TensorAddress, Double> cellsToRemove = tensor.getTensor().get().cells();
- Tensor.Builder builder = Tensor.Builder.of(oldTensor.type());
- for (Iterator<Tensor.Cell> i = oldTensor.cellIterator(); i.hasNext(); ) {
- Tensor.Cell cell = i.next();
- TensorAddress address = cell.getKey();
- if ( ! cellsToRemove.containsKey(address)) {
- builder.cell(address, cell.getValue());
- }
- }
- return new TensorFieldValue(builder.build());
+ Tensor old = ((TensorFieldValue) oldValue).getTensor().get();
+ Tensor update = tensor.getTensor().get();
+ Tensor result = old.remove(update.cells().keySet());
+ return new TensorFieldValue(result);
}
@Override
@@ -93,4 +96,11 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
return super.toString() + " " + tensor;
}
+ public static TensorType extractSparseDimensions(TensorType type) {
+ TensorType.Builder builder = new TensorType.Builder();
+ type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name()));
+ return builder.build();
+ }
+
+
}
diff --git a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
index e2736dabd2b..454ad72f344 100644
--- a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
+++ b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
@@ -40,6 +40,7 @@ public class DocumentUpdateJsonSerializerTest {
final static TensorType sparseTensorType = new TensorType.Builder().mapped("x").mapped("y").build();
final static TensorType denseTensorType = new TensorType.Builder().indexed("x", 2).indexed("y", 3).build();
+ final static TensorType mixedTensorType = new TensorType.Builder().mapped("x").indexed("y", 3).build();
final static DocumentTypeManager types = new DocumentTypeManager();
final static JsonFactory parserFactory = new JsonFactory();
final static DocumentType docType = new DocumentType("doctype");
@@ -60,6 +61,7 @@ public class DocumentUpdateJsonSerializerTest {
docType.addField(new Field("byte_field", DataType.BYTE));
docType.addField(new Field("sparse_tensor", new TensorDataType(sparseTensorType)));
docType.addField(new Field("dense_tensor", new TensorDataType(denseTensorType)));
+ docType.addField(new Field("mixed_tensor", new TensorDataType(mixedTensorType)));
docType.addField(new Field("reference_field", new ReferenceDataType(refTargetDocType, 777)));
docType.addField(new Field("predicate_field", DataType.PREDICATE));
docType.addField(new Field("raw_field", DataType.RAW));
@@ -336,6 +338,26 @@ public class DocumentUpdateJsonSerializerTest {
}
@Test
+ public void test_tensor_modify_update_on_mixed_tensor() {
+ roundtripSerializeJsonAndMatch(inputJson(
+ "{",
+ " 'update': 'DOCUMENT_ID',",
+ " 'fields': {",
+ " 'mixed_tensor': {",
+ " 'modify': {",
+ " 'operation': 'multiply',",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 },",
+ " { 'address': { 'x': 'c', 'y': '1' }, 'value': 3.0 }",
+ " ]",
+ " }",
+ " }",
+ " }",
+ "}"
+ ));
+ }
+
+ @Test
public void test_tensor_add_update() {
roundtripSerializeJsonAndMatch(inputJson(
"{",
@@ -355,6 +377,29 @@ public class DocumentUpdateJsonSerializerTest {
}
@Test
+ public void test_tensor_add_update_mixed() {
+ roundtripSerializeJsonAndMatch(inputJson(
+ "{",
+ " 'update': 'DOCUMENT_ID',",
+ " 'fields': {",
+ " 'mixed_tensor': {",
+ " 'add': {",
+ " 'cells': [",
+ " { 'address': { 'x': '1', 'y': '0' }, 'value': 2.0 },",
+ " { 'address': { 'x': '1', 'y': '1' }, 'value': 0.0 },",
+ " { 'address': { 'x': '1', 'y': '2' }, 'value': 0.0 },",
+ " { 'address': { 'x': '0', 'y': '0' }, 'value': 0.0 },",
+ " { 'address': { 'x': '0', 'y': '1' }, 'value': 0.0 },",
+ " { 'address': { 'x': '0', 'y': '2' }, 'value': 3.0 }",
+ " ]",
+ " }",
+ " }",
+ " }",
+ "}"
+ ));
+ }
+
+ @Test
public void test_tensor_remove_update() {
roundtripSerializeJsonAndMatch(inputJson(
"{",
@@ -374,6 +419,24 @@ public class DocumentUpdateJsonSerializerTest {
}
@Test
+ public void test_tensor_remove_update_mixed() {
+ roundtripSerializeJsonAndMatch(inputJson(
+ "{",
+ " 'update': 'DOCUMENT_ID',",
+ " 'fields': {",
+ " 'mixed_tensor': {",
+ " 'remove': {",
+ " 'addresses': [",
+ " {'x':'0' }",
+ " ]",
+ " }",
+ " }",
+ " }",
+ "}"
+ ));
+ }
+
+ @Test
public void reference_field_id_can_be_update_assigned_non_empty_id() {
roundtripSerializeJsonAndMatch(inputJson(
"{",
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index e58b26d500d..15d1e859f73 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -1387,12 +1387,30 @@ public class JsonReaderTestCase {
}
@Test
- public void tensor_modify_update_on_mixed_tensor_throws() {
- exception.expect(IllegalArgumentException.class);
- exception.expectMessage("A modify update cannot be applied to tensor types with mixed dimensions. Field 'mixed_tensor' has mixed tensor type 'tensor(x{},y[3])'");
- createTensorModifyUpdate(inputJson("{",
- " 'operation': 'replace',",
- " 'cells': [] }"), "mixed_tensor");
+ public void tensor_modify_update_with_replace_operation_mixed() {
+ assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor",
+ inputJson("{",
+ " 'operation': 'replace',",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}"));
+ }
+
+ @Test
+ public void tensor_modify_update_with_add_operation_mixed() {
+ assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.ADD, "mixed_tensor",
+ inputJson("{",
+ " 'operation': 'add',",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}"));
+ }
+
+ @Test
+ public void tensor_modify_update_with_multiply_operation_mixed() {
+ assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.MULTIPLY, "mixed_tensor",
+ inputJson("{",
+ " 'operation': 'multiply',",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}"));
}
@Test
@@ -1406,6 +1424,17 @@ public class JsonReaderTestCase {
}
@Test
+ public void tensor_modify_update_with_out_of_bound_cells_throws_mixed() {
+ exception.expect(IndexOutOfBoundsException.class);
+ exception.expectMessage("Dimension 'y' has label '3' but type is tensor(x{},y[3])");
+ createTensorModifyUpdate(inputJson("{",
+ " 'operation': 'replace',",
+ " 'cells': [",
+ " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "mixed_tensor");
+ }
+
+
+ @Test
public void tensor_modify_update_with_unknown_operation_throws() {
exception.expect(IllegalArgumentException.class);
exception.expectMessage("Unknown operation 'unknown' in modify update for field 'sparse_tensor'");
@@ -1449,11 +1478,29 @@ public class JsonReaderTestCase {
}
@Test
- public void tensor_add_update_on_non_sparse_tensor_throws() {
+ public void tensor_add_update_on_mixed_tensor() {
+ assertTensorAddUpdate("{{x:a,y:0}:2.0, {x:a,y:1}:3.0, {x:a,y:2}:0.0}", "mixed_tensor",
+ inputJson("{",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 },",
+ " { 'address': { 'x': 'a', 'y': '1' }, 'value': 3.0 } ]}"));
+ }
+
+ @Test
+ public void tensor_add_update_on_mixed_with_out_of_bound_dense_cells_throws() {
+ exception.expect(IndexOutOfBoundsException.class);
+ exception.expectMessage("Index 3 out of bounds for length 3");
+ createTensorAddUpdate(inputJson("{",
+ " 'cells': [",
+ " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "mixed_tensor");
+ }
+
+ @Test
+ public void tensor_add_update_on_dense_tensor_throws() {
exception.expect(IllegalArgumentException.class);
- exception.expectMessage("An add update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'");
+ exception.expectMessage("An add update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'");
createTensorAddUpdate(inputJson("{",
- " 'cells': [] }"), "mixed_tensor");
+ " 'cells': [] }"), "dense_tensor");
}
@Test
@@ -1470,6 +1517,7 @@ public class JsonReaderTestCase {
exception.expect(IllegalArgumentException.class);
exception.expectMessage("Add update for field 'sparse_tensor' does not contain tensor cells");
createTensorAddUpdate(inputJson("{}"), "sparse_tensor");
+ createTensorAddUpdate(inputJson("{}"), "mixed_tensor");
}
@Test
@@ -1482,11 +1530,30 @@ public class JsonReaderTestCase {
}
@Test
- public void tensor_remove_update_on_non_sparse_tensor_throws() {
+ public void tensor_remove_update_on_mixed_tensor() {
+ assertTensorRemoveUpdate("{{x:1}:1.0,{x:2}:1.0}", "mixed_tensor",
+ inputJson("{",
+ " 'addresses': [",
+ " { 'x': '1' },",
+ " { 'x': '2' } ]}"));
+ }
+
+ @Test
+ public void tensor_remove_update_on_mixed_tensor_with_dense_addresses_throws() {
+ exception.expect(IllegalArgumentException.class);
+ exception.expectMessage("Indexed dimension address 'y' should not be specified in remove update");
+ createTensorRemoveUpdate(inputJson("{",
+ " 'addresses': [",
+ " { 'x': '1', 'y': '0' },",
+ " { 'x': '2', 'y': '0' } ]}"), "mixed_tensor");
+ }
+
+ @Test
+ public void tensor_remove_update_on_dense_tensor_throws() {
exception.expect(IllegalArgumentException.class);
- exception.expectMessage("A remove update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'");
+ exception.expectMessage("A remove update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'");
createTensorRemoveUpdate(inputJson("{",
- " 'addresses': [] }"), "mixed_tensor");
+ " 'addresses': [] }"), "dense_tensor");
}
@Test
@@ -1503,6 +1570,7 @@ public class JsonReaderTestCase {
exception.expect(IllegalArgumentException.class);
exception.expectMessage("Remove update for field 'sparse_tensor' does not contain tensor addresses");
createTensorRemoveUpdate(inputJson("{'addresses': [] }"), "sparse_tensor");
+ createTensorRemoveUpdate(inputJson("{'addresses': [] }"), "mixed_tensor");
}
@Test
diff --git a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
index eb4001e6415..6935c54ba2a 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
@@ -12,18 +12,14 @@ public class TensorAddUpdateTest {
@Test
public void apply_add_update_operations() {
assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}");
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}");
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}");
- assertApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}");
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}");
}
private void assertApplyTo(String init, String update, String expected) {
String spec = "tensor(x{},y{})";
TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init));
TensorAddUpdate addUpdate = new TensorAddUpdate(new TensorFieldValue(Tensor.from(spec, update)));
- TensorFieldValue updatedFieldValue = (TensorFieldValue) addUpdate.applyTo(initialFieldValue);
- assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get());
+ Tensor updated = ((TensorFieldValue) addUpdate.applyTo(initialFieldValue)).getTensor().get();
+ assertEquals(Tensor.from(spec, expected), updated);
}
}
diff --git a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
index 6e9444de2be..b885e6ddca0 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
@@ -1,12 +1,6 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.document.update;
-import com.yahoo.document.Document;
-import com.yahoo.document.DocumentId;
-import com.yahoo.document.DocumentType;
-import com.yahoo.document.DocumentTypeManager;
-import com.yahoo.document.Field;
-import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.update.TensorModifyUpdate.Operation;
import com.yahoo.tensor.Tensor;
@@ -28,10 +22,11 @@ public class TensorModifyUpdateTest {
assertConvertToCompatible("tensor(x{})", "tensor(x[10])");
assertConvertToCompatible("tensor(x{})", "tensor(x{})");
assertConvertToCompatible("tensor(x{},y{},z{})", "tensor(x[],y[10],z{})");
+ assertConvertToCompatible("tensor(x{},y{})", "tensor(x{},y[3])");
}
private static void assertConvertToCompatible(String expectedType, String inputType) {
- assertEquals(expectedType, TensorModifyUpdate.convertToCompatibleType(TensorType.fromSpec(inputType)).toString());
+ assertEquals(expectedType, TensorModifyUpdate.convertDimensionsToMapped(TensorType.fromSpec(inputType)).toString());
}
@Test
@@ -46,15 +41,9 @@ public class TensorModifyUpdateTest {
public void apply_modify_update_operations() {
assertApplyTo("tensor(x{},y{})", Operation.REPLACE,
"{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}");
- assertApplyTo("tensor(x{},y{})", Operation.ADD,
- "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}");
- assertApplyTo("tensor(x{},y{})", Operation.MULTIPLY,
- "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}");
- assertApplyTo("tensor(x[1],y[2])", Operation.REPLACE,
- "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}");
assertApplyTo("tensor(x[1],y[2])", Operation.ADD,
"{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}");
- assertApplyTo("tensor(x[1],y[2])", Operation.MULTIPLY,
+ assertApplyTo("tensor(x{},y[2])", Operation.MULTIPLY,
"{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}");
}
diff --git a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
index 40ab00facdb..3a005e858c8 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
@@ -12,9 +12,6 @@ public class TensorRemoveUpdateTest {
@Test
public void apply_remove_update_operations() {
assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0,y:1}:1}", "{{x:0,y:0}:2}");
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:0}:1,{x:0,y:1}:1}", "{}");
- assertApplyTo("{}", "{{x:0,y:0}:1}", "{}");
- assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}");
}
private void assertApplyTo(String init, String update, String expected) {