aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2020-01-13 11:19:24 +0100
committerJon Bratseth <bratseth@verizonmedia.com>2020-01-13 11:19:24 +0100
commita3ec97de26572acdbfe4b1801744061decb84d38 (patch)
treec11e4a00f7595da100ea6a03f5eebc9771164800
parent4976b922193b1071db4711328caf31bc54e1a0d1 (diff)
Support modify of mixed tensors
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java6
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java39
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorReader.java2
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java14
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java5
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java53
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java2
7 files changed, 91 insertions, 30 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
index 6310fa62d15..e98a262b661 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
@@ -13,7 +13,7 @@ import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStar
import static com.yahoo.document.json.readers.TensorReader.fillTensor;
/**
- * Class used to read an add update for a tensor field.
+ * Reader of an "add" update of a tensor field.
*/
public class TensorAddUpdateReader {
@@ -38,8 +38,8 @@ public class TensorAddUpdateReader {
TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) {
throw new IllegalArgumentException("An add update can only be applied to tensors " +
- "with at least one sparse dimension. Field '" + field.getName() +
- "' has unsupported tensor type '" + tensorType + "'");
+ "with at least one sparse dimension. Field '" + field.getName() +
+ "' has unsupported tensor type '" + tensorType + "'");
}
}
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
index 66588debbca..5fd1c7bbab7 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
@@ -14,11 +14,13 @@ import com.yahoo.tensor.TensorType;
import java.util.Iterator;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStart;
+import static com.yahoo.document.json.readers.TensorReader.TENSOR_BLOCKS;
import static com.yahoo.document.json.readers.TensorReader.TENSOR_CELLS;
+import static com.yahoo.document.json.readers.TensorReader.readTensorBlocks;
import static com.yahoo.document.json.readers.TensorReader.readTensorCells;
/**
- * Class used to read a modify update for a tensor field.
+ * Reader of a "modify" update of a tensor field.
*/
public class TensorModifyUpdateReader {
@@ -30,7 +32,7 @@ public class TensorModifyUpdateReader {
public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) {
expectFieldIsOfTypeTensor(field);
- expectTensorTypeHasNoneIndexedUnboundDimensions(field);
+ expectTensorTypeHasNoIndexedUnboundDimensions(field);
expectObjectStart(buffer.currentToken());
ModifyUpdateResult result = createModifyUpdateResult(buffer, field);
@@ -41,18 +43,19 @@ public class TensorModifyUpdateReader {
}
private static void expectFieldIsOfTypeTensor(Field field) {
- if (!(field.getDataType() instanceof TensorDataType)) {
+ if ( ! (field.getDataType() instanceof TensorDataType)) {
throw new IllegalArgumentException("A modify update can only be applied to tensor fields. " +
- "Field '" + field.getName() + "' is of type '" + field.getDataType().getName() + "'");
+ "Field '" + field.getName() + "' is of type '" +
+ field.getDataType().getName() + "'");
}
}
- private static void expectTensorTypeHasNoneIndexedUnboundDimensions(Field field) {
+ private static void expectTensorTypeHasNoIndexedUnboundDimensions(Field field) {
TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
if (tensorType.dimensions().stream()
.anyMatch(dim -> dim.type().equals(TensorType.Dimension.Type.indexedUnbound))) {
- throw new IllegalArgumentException("A modify update cannot be applied to tensor types with indexed unbound dimensions. "
- + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'");
+ throw new IllegalArgumentException("A modify update cannot be applied to tensor types with indexed unbound dimensions. " +
+ "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'");
}
}
@@ -83,7 +86,10 @@ public class TensorModifyUpdateReader {
result.operation = createOperation(buffer, field.getName());
break;
case TENSOR_CELLS:
- result.tensor = createTensor(buffer, field);
+ result.tensor = createTensorFromCells(buffer, field);
+ break;
+ case TENSOR_BLOCKS:
+ result.tensor = createTensorFromBlocks(buffer, field);
break;
default:
throw new IllegalArgumentException("Unknown JSON string '" + buffer.currentName() + "' in modify update for field '" + field.getName() + "'");
@@ -106,7 +112,7 @@ public class TensorModifyUpdateReader {
}
}
- private static TensorFieldValue createTensor(TokenBuffer buffer, Field field) {
+ private static TensorFieldValue createTensorFromCells(TokenBuffer buffer, Field field) {
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType originalType = tensorDataType.getTensorType();
TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(originalType);
@@ -120,6 +126,19 @@ public class TensorModifyUpdateReader {
return new TensorFieldValue(tensor);
}
+ private static TensorFieldValue createTensorFromBlocks(TokenBuffer buffer, Field field) {
+ TensorDataType tensorDataType = (TensorDataType)field.getDataType();
+ TensorType type = tensorDataType.getTensorType();
+
+ Tensor.Builder tensorBuilder = Tensor.Builder.of(type);
+ readTensorBlocks(buffer, tensorBuilder);
+ Tensor tensor = tensorBuilder.build();
+
+ validateBounds(tensor, type);
+
+ return new TensorFieldValue(tensor);
+ }
+
/** Only validate if original type has indexed bound dimensions */
static void validateBounds(Tensor convertedTensor, TensorType originalType) {
if (originalType.dimensions().stream().noneMatch(d -> d instanceof TensorType.IndexedBoundDimension)) {
@@ -135,7 +154,7 @@ public class TensorModifyUpdateReader {
long bound = dim.size().get(); // size is non-optional for indexed bound
if (label >= bound) {
throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() +
- "' has label '" + label + "' but type is " + originalType.toString());
+ "' has label '" + label + "' but type is " + originalType.toString());
}
}
}
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
index 5516e9523a1..e5699d0e6b1 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
@@ -98,7 +98,7 @@ public class TensorReader {
expectCompositeEnd(buffer.currentToken());
}
- private static void readTensorBlocks(TokenBuffer buffer, Tensor.Builder builder) {
+ static void readTensorBlocks(TokenBuffer buffer, Tensor.Builder builder) {
if ( ! (builder instanceof MixedTensor.BoundBuilder))
throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " +
"Use 'cells' or 'values' instead");
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
index 3bb4b2e262f..91c275b6da0 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
@@ -16,7 +16,7 @@ import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectEnd;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStart;
/**
- * Class used to read a remove update for a tensor field.
+ * Reader of a "remove" update of a tensor field.
*/
public class TensorRemoveUpdateReader {
@@ -39,14 +39,15 @@ public class TensorRemoveUpdateReader {
TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) {
throw new IllegalArgumentException("A remove update can only be applied to tensors " +
- "with at least one sparse dimension. Field '" + field.getName() +
- "' has unsupported tensor type '" + tensorType + "'");
+ "with at least one sparse dimension. Field '" + field.getName() +
+ "' has unsupported tensor type '" + tensorType + "'");
}
}
private static void expectAddressesAreNonEmpty(Field field, Tensor tensor) {
if (tensor.isEmpty()) {
- throw new IllegalArgumentException("Remove update for field '" + field.getName() + "' does not contain tensor addresses");
+ throw new IllegalArgumentException("Remove update for field '" + field.getName() +
+ "' does not contain tensor addresses");
}
}
@@ -77,8 +78,9 @@ public class TensorRemoveUpdateReader {
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
String dimension = buffer.currentName();
- if ( ! type.dimension(dimension).isPresent() && originalType.dimension(dimension).isPresent()) {
- throw new IllegalArgumentException("Indexed dimension address '" + dimension + "' should not be specified in remove update");
+ if ( type.dimension(dimension).isEmpty() && originalType.dimension(dimension).isPresent()) {
+ throw new IllegalArgumentException("Indexed dimension address '" + dimension +
+ "' should not be specified in remove update");
}
String label = buffer.currentText();
builder.add(dimension, label);
diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
index 435c8fcdc65..cc59ff65f1f 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
@@ -18,6 +18,7 @@ import java.util.function.DoubleBinaryOperator;
* The cells to update are contained in a sparse tensor (has only mapped dimensions).
*/
public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
+
protected Operation operation;
protected TensorFieldValue tensor;
@@ -29,8 +30,8 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
}
private void verifyCompatibleType(TensorType type) {
- if (type.dimensions().stream().anyMatch(dim -> dim.isIndexed()) ) {
- throw new IllegalArgumentException("Tensor type '" + type + "' is not compatible as it contains some indexed dimensions");
+ if (type.rank() > 0 && type.dimensions().stream().noneMatch(dim -> dim.isMapped()) ) {
+ throw new IllegalArgumentException("Tensor type '" + type + "' is not compatible as it has no mapped dimensions");
}
}
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index 511ad081c8c..5867ca5596c 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -1378,13 +1378,13 @@ public class JsonReaderTestCase {
@Test
public void testAssignUpdateOfEmptySparseTensor() {
- assertTensorAssignUpdate("tensor(x{},y{}):{}", createAssignUpdateWithSparseTensor("{}"));
+ assertTensorAssignUpdateSparseField("tensor(x{},y{}):{}", createAssignUpdateWithSparseTensor("{}"));
}
@Test
public void testAssignUpdateOfEmptyDenseTensor() {
try {
- assertTensorAssignUpdate("tensor(x{},y{}):{}", createAssignUpdateWithTensor("{}", "dense_unbound_tensor"));
+ assertTensorAssignUpdateSparseField("tensor(x{},y{}):{}", createAssignUpdateWithTensor("{}", "dense_unbound_tensor"));
}
catch (IllegalArgumentException e) {
assertEquals("An indexed tensor must have a value",
@@ -1402,8 +1402,8 @@ public class JsonReaderTestCase {
@Test
public void testAssignUpdateOfTensorWithCells() {
- assertTensorAssignUpdate("{{x:a,y:b}:2.0,{x:c,y:b}:3.0}}",
- createAssignUpdateWithSparseTensor(inputJson("{",
+ assertTensorAssignUpdateSparseField("{{x:a,y:b}:2.0,{x:c,y:b}:3.0}}",
+ createAssignUpdateWithSparseTensor(inputJson("{",
" 'cells': [",
" { 'address': { 'x': 'a', 'y': 'b' },",
" 'value': 2.0 },",
@@ -1414,6 +1414,15 @@ public class JsonReaderTestCase {
}
@Test
+ public void testAssignUpdateOfTensorDenseShortForm() {
+ assertTensorAssignUpdateDenseField("tensor(x[2],y[3]):[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]",
+ createAssignUpdateWithTensor(inputJson("{",
+ " 'values': [1,2,3,4,5,6]",
+ "}"),
+ "dense_tensor"));
+ }
+
+ @Test
public void tensor_modify_update_with_replace_operation() {
assertTensorModifyUpdate("{{x:a,y:b}:2.0}", TensorModifyUpdate.Operation.REPLACE, "sparse_tensor",
inputJson("{",
@@ -1488,6 +1497,24 @@ public class JsonReaderTestCase {
}
@Test
+ public void tensor_modify_update_with_replace_operation_mixed_block_short_form_array() {
+ assertTensorModifyUpdate("tensor(x{},y[3]):{a:[1,2,3]}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor",
+ inputJson("{",
+ " 'operation': 'replace',",
+ " 'blocks': [",
+ " { 'address': { 'x': 'a' }, 'values': [1,2,3] } ]}"));
+ }
+
+ @Test
+ public void tensor_modify_update_with_replace_operation_mixed_block_short_form_map() {
+ assertTensorModifyUpdate("tensor(x{},y[3]):{a:[1,2,3]}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor",
+ inputJson("{",
+ " 'operation': 'replace',",
+ " 'blocks': {",
+ " 'a': [1,2,3] } }"));
+ }
+
+ @Test
public void tensor_modify_update_with_add_operation_mixed() {
assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.ADD, "mixed_tensor",
inputJson("{",
@@ -1830,10 +1857,18 @@ public class JsonReaderTestCase {
assertEquals(1, update.getFieldUpdate(tensorFieldName).size());
}
- private static void assertTensorAssignUpdate(String expectedTensor, DocumentUpdate update) {
+ private static void assertTensorAssignUpdateSparseField(String expectedTensor, DocumentUpdate update) {
assertEquals("testtensor", update.getId().getDocType());
assertEquals(TENSOR_DOC_ID, update.getId().toString());
- AssignValueUpdate assignUpdate = (AssignValueUpdate) getTensorField(update).getValueUpdate(0);
+ AssignValueUpdate assignUpdate = (AssignValueUpdate) getTensorField(update, "sparse_tensor").getValueUpdate(0);
+ TensorFieldValue fieldValue = (TensorFieldValue) assignUpdate.getValue();
+ assertEquals(Tensor.from(expectedTensor), fieldValue.getTensor().get());
+ }
+
+ private static void assertTensorAssignUpdateDenseField(String expectedTensor, DocumentUpdate update) {
+ assertEquals("testtensor", update.getId().getDocType());
+ assertEquals(TENSOR_DOC_ID, update.getId().toString());
+ AssignValueUpdate assignUpdate = (AssignValueUpdate) getTensorField(update, "dense_tensor").getValueUpdate(0);
TensorFieldValue fieldValue = (TensorFieldValue) assignUpdate.getValue();
assertEquals(Tensor.from(expectedTensor), fieldValue.getTensor().get());
}
@@ -1895,7 +1930,11 @@ public class JsonReaderTestCase {
}
private static FieldUpdate getTensorField(DocumentUpdate update) {
- FieldUpdate fieldUpdate = update.getFieldUpdate("sparse_tensor");
+ return getTensorField(update, "sparse_tensor");
+ }
+
+ private static FieldUpdate getTensorField(DocumentUpdate update, String fieldName) {
+ FieldUpdate fieldUpdate = update.getFieldUpdate(fieldName);
assertEquals(1, fieldUpdate.size());
return fieldUpdate;
}
diff --git a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
index b885e6ddca0..4c8d2e69855 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
@@ -32,7 +32,7 @@ public class TensorModifyUpdateTest {
@Test
public void use_of_incompatible_tensor_type_throws() {
exception.expect(IllegalArgumentException.class);
- exception.expectMessage("Tensor type 'tensor(x[3])' is not compatible as it contains some indexed dimensions");
+ exception.expectMessage("Tensor type 'tensor(x[3])' is not compatible as it has no mapped dimensions");
new TensorModifyUpdate(TensorModifyUpdate.Operation.REPLACE,
new TensorFieldValue(Tensor.from("tensor(x[3])", "{{x:1}:3}")));
}