diff options
author | Lester Solbakken <lesters@oath.com> | 2019-02-20 12:46:24 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-02-20 12:46:24 +0100 |
commit | 085b6922c07f4626c61e2ed2e6dde6beec0855de (patch) | |
tree | 597fc14c08199339c9ab9286c365af6e8d4cdcdb /document | |
parent | 85e394563c8b711a1a0307c8ac5953c1817f5629 (diff) |
TensorAddUpdate support for mixed tensors
Diffstat (limited to 'document')
5 files changed, 94 insertions, 54 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java index ffbfe49347c..da8bcc13397 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java @@ -6,10 +6,15 @@ import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.json.TokenBuffer; import com.yahoo.document.update.TensorAddUpdate; +import com.yahoo.document.update.TensorModifyUpdate; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import java.util.Iterator; + import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStart; +import static com.yahoo.document.json.readers.TensorModifyUpdateReader.validateBounds; import static com.yahoo.document.json.readers.TensorReader.fillTensor; /** @@ -23,22 +28,27 @@ public class TensorAddUpdateReader { public static TensorAddUpdate createTensorAddUpdate(TokenBuffer buffer, Field field) { expectObjectStart(buffer.currentToken()); - expectTensorTypeIsSparse(field); + expectTensorTypeHasSparseDimensions(field); + // Convert update type to sparse TensorDataType tensorDataType = (TensorDataType)field.getDataType(); - TensorType tensorType = tensorDataType.getTensorType(); - TensorFieldValue tensorFieldValue = new TensorFieldValue(tensorType); + TensorType originalType = tensorDataType.getTensorType(); + TensorType convertedType = TensorModifyUpdate.convertToCompatibleType(originalType); + + TensorFieldValue tensorFieldValue = new TensorFieldValue(convertedType); fillTensor(buffer, tensorFieldValue); expectTensorIsNonEmpty(field, tensorFieldValue.getTensor().get()); + validateBounds(tensorFieldValue.getTensor().get(), originalType); + return new TensorAddUpdate(tensorFieldValue); } - private static void expectTensorTypeIsSparse(Field field) { + private static void expectTensorTypeHasSparseDimensions(Field field) { TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType(); - if (tensorType.dimensions().stream() - .anyMatch(dim -> dim.isIndexed())) { - throw new IllegalArgumentException("An add update can only be applied to sparse tensors. " - + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'"); + if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) { + throw new IllegalArgumentException("An add update can only be applied to tensors " + + "with at least one sparse dimension. Field '" + field.getName() + + "' has unsupported tensor type '" + tensorType + "'"); } } @@ -48,5 +58,4 @@ public class TensorAddUpdateReader { } } - } diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java index a9bbba519bd..5022185e03f 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java @@ -129,25 +129,26 @@ public class TensorModifyUpdateReader { validateBounds(tensor, originalType); - TensorFieldValue result = new TensorFieldValue(convertedType); - result.assign(tensor); - return result; + return new TensorFieldValue(tensor); } - /** Only validate if original type is indexed bound */ - private static void validateBounds(Tensor convertedTensor, TensorType originalType) { - if ( ! originalType.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) { + /** Only validate if original type has indexed bound dimensions */ + static void validateBounds(Tensor convertedTensor, TensorType originalType) { + if (originalType.dimensions().stream().noneMatch(d -> d instanceof TensorType.IndexedBoundDimension)) { return; } for (Iterator<Tensor.Cell> iter = convertedTensor.cellIterator(); iter.hasNext(); ) { Tensor.Cell cell = iter.next(); TensorAddress address = cell.getKey(); for (int i = 0; i < address.size(); ++i) { - long label = address.numericLabel(i); - long bound = originalType.dimensions().get(i).size().get(); // size is non-optional for indexed bound - if (label >= bound) { - throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() + - "' has label '" + label + "' but type is " + originalType.toString()); + TensorType.Dimension dim = originalType.dimensions().get(i); + if (dim instanceof TensorType.IndexedBoundDimension) { + long label = address.numericLabel(i); + long bound = dim.size().get(); // size is non-optional for indexed bound + if (label >= bound) { + throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() + + "' has label '" + label + "' but type is " + originalType.toString()); + } } } } diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java index cfc3ee0c742..7059edbca7f 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java @@ -13,9 +13,9 @@ import java.util.Map; import java.util.Objects; /** - * An update used to add cells to a sparse tensor (has only mapped dimensions). + * An update used to add cells to a sparse or mixed tensor (has at least one mapped dimension). * - * The cells to add are contained in a sparse tensor as well. + * The cells to add are contained in a sparse tensor. */ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> { @@ -50,22 +50,10 @@ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> { return oldValue; } - Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get(); - Map<TensorAddress, Double> oldCells = oldTensor.cells(); - Map<TensorAddress, Double> addCells = tensor.getTensor().get().cells(); - - // currently, underlying implementation disallows multiple entries with the same key - - Tensor.Builder builder = Tensor.Builder.of(oldTensor.type()); - for (Map.Entry<TensorAddress, Double> oldCell : oldCells.entrySet()) { - builder.cell(oldCell.getKey(), addCells.getOrDefault(oldCell.getKey(), oldCell.getValue())); - } - for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { - if ( ! oldCells.containsKey(addCell.getKey())) { - builder.cell(addCell.getKey(), addCell.getValue()); - } - } - return new TensorFieldValue(builder.build()); + Tensor old = ((TensorFieldValue) oldValue).getTensor().get(); + Tensor update = tensor.getTensor().get(); + Tensor result = old.merge((left, right) -> right, update.cells()); + return new TensorFieldValue(result); } @Override diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index e58b26d500d..a20276e5c65 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -56,6 +56,7 @@ import com.yahoo.text.Utf8; import org.apache.commons.codec.binary.Base64; import org.junit.After; import org.junit.Before; +import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -1449,11 +1450,29 @@ public class JsonReaderTestCase { } @Test - public void tensor_add_update_on_non_sparse_tensor_throws() { + public void tensor_add_update_on_mixed_tensor() { + assertTensorAddUpdate("{{x:a,y:0}:2.0, {x:a,y:1}:3.0}", "mixed_tensor", + inputJson("{", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 },", + " { 'address': { 'x': 'a', 'y': '1' }, 'value': 3.0 } ]}")); + } + + @Test + public void tensor_add_update_with_out_of_bound_dense_cells_throws() { + exception.expect(IndexOutOfBoundsException.class); + exception.expectMessage("Dimension 'y' has label '3' but type is tensor(x{},y[3])"); + createTensorAddUpdate(inputJson("{", + " 'cells': [", + " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "mixed_tensor"); + } + + @Test + public void tensor_add_update_on_dense_tensor_throws() { exception.expect(IllegalArgumentException.class); - exception.expectMessage("An add update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'"); + exception.expectMessage("An add update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'"); createTensorAddUpdate(inputJson("{", - " 'cells': [] }"), "mixed_tensor"); + " 'cells': [] }"), "dense_tensor"); } @Test @@ -1481,12 +1500,22 @@ public class JsonReaderTestCase { " { 'x': 'c', 'y': 'd' } ]}")); } + @Ignore + @Test + public void tensor_remove_update_on_mixed_tensor() { + assertTensorRemoveUpdate("{{x:1}:1.0,{x:2}:1.0}", "mixed_tensor", + inputJson("{", + " 'addresses': [", + " { 'x': '1' },", + " { 'x': '2' } ]}")); + } + @Test - public void tensor_remove_update_on_non_sparse_tensor_throws() { + public void tensor_remove_update_on_dense_tensor_throws() { exception.expect(IllegalArgumentException.class); - exception.expectMessage("A remove update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'"); + exception.expectMessage("A remove update can only be applied to sparse tensors. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'"); createTensorRemoveUpdate(inputJson("{", - " 'addresses': [] }"), "mixed_tensor"); + " 'addresses': [] }"), "dense_tensor"); } @Test diff --git a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java index eb4001e6415..c6b21380e4b 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java @@ -3,27 +3,40 @@ package com.yahoo.document.update; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.tensor.Tensor; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import static org.junit.Assert.assertEquals; public class TensorAddUpdateTest { + @Rule + public ExpectedException exception = ExpectedException.none(); + @Test public void apply_add_update_operations() { - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}"); - assertApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}"); + assertApplyTo("tensor(x{},y{})", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"); + assertApplyTo("tensor(x{},y{})", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}"); + assertApplyTo("tensor(x{},y{})", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}"); + assertApplyTo("tensor(x{},y{})", "{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}"); + assertApplyTo("tensor(x{},y{})", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}"); + + assertApplyTo("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"); + assertApplyTo("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:0}"); + assertApplyTo("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}"); + assertApplyTo("tensor(x{},y[3])", "{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5,{x:0,y:1}:0,{x:0,y:2}:0}"); + assertApplyTo("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:0}"); } - private void assertApplyTo(String init, String update, String expected) { - String spec = "tensor(x{},y{})"; + private Tensor updateField(String spec, String init, String update) { TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init)); - TensorAddUpdate addUpdate = new TensorAddUpdate(new TensorFieldValue(Tensor.from(spec, update))); - TensorFieldValue updatedFieldValue = (TensorFieldValue) addUpdate.applyTo(initialFieldValue); - assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get()); + TensorAddUpdate addUpdate = new TensorAddUpdate(new TensorFieldValue(Tensor.from("tensor(x{},y{})", update))); + return ((TensorFieldValue) addUpdate.applyTo(initialFieldValue)).getTensor().get(); + } + + private void assertApplyTo(String spec, String init, String update, String expected) { + assertEquals(Tensor.from(spec, expected), updateField(spec, init, update)); } } |