diff options
author | Lester Solbakken <lesters@oath.com> | 2019-02-08 11:23:37 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-02-08 11:23:37 +0100 |
commit | 3425c3bbbc522e3da2c3ab221227c2bff36770c3 (patch) | |
tree | 91f6aaf39f21b1ee90982431afa86312eaf74148 /document | |
parent | fb333d2f8d92c2661591bd0a1114a0152708728e (diff) |
Add bound check for dense tensor update modify
Diffstat (limited to 'document')
-rw-r--r-- | document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java | 42 | ||||
-rw-r--r-- | document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java | 20 |
2 files changed, 55 insertions, 7 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java index 41748454ae6..a9bbba519bd 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java @@ -8,8 +8,11 @@ import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.json.TokenBuffer; import com.yahoo.document.update.TensorModifyUpdate; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import java.util.Iterator; + import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStart; import static com.yahoo.document.json.readers.TensorReader.TENSOR_CELLS; import static com.yahoo.document.json.readers.TensorReader.readTensorCells; @@ -84,8 +87,6 @@ public class TensorModifyUpdateReader { private static ModifyUpdateResult createModifyUpdateResult(TokenBuffer buffer, Field field) { ModifyUpdateResult result = new ModifyUpdateResult(); - TensorDataType tensorDataType = (TensorDataType)field.getDataType(); - TensorType convertedType = TensorModifyUpdate.convertToCompatibleType(tensorDataType.getTensorType()); buffer.next(); int localNesting = buffer.nesting(); while (localNesting <= buffer.nesting()) { @@ -94,7 +95,7 @@ public class TensorModifyUpdateReader { result.operation = createOperation(buffer, field.getName()); break; case TENSOR_CELLS: - result.tensor = createTensor(buffer, convertedType); + result.tensor = createTensor(buffer, field); break; default: throw new IllegalArgumentException("Unknown JSON string '" + buffer.currentName() + "' in modify update for field '" + field.getName() + "'"); @@ -117,12 +118,39 @@ public class TensorModifyUpdateReader { } } - private static TensorFieldValue createTensor(TokenBuffer buffer, TensorType tensorType) { - Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType); + private static TensorFieldValue createTensor(TokenBuffer buffer, Field field) { + TensorDataType tensorDataType = (TensorDataType)field.getDataType(); + TensorType originalType = tensorDataType.getTensorType(); + TensorType convertedType = TensorModifyUpdate.convertToCompatibleType(originalType); + + Tensor.Builder tensorBuilder = Tensor.Builder.of(convertedType); readTensorCells(buffer, tensorBuilder); - TensorFieldValue result = new TensorFieldValue(tensorType); - result.assign(tensorBuilder.build()); + Tensor tensor = tensorBuilder.build(); + + validateBounds(tensor, originalType); + + TensorFieldValue result = new TensorFieldValue(convertedType); + result.assign(tensor); return result; } + /** Only validate if original type is indexed bound */ + private static void validateBounds(Tensor convertedTensor, TensorType originalType) { + if ( ! originalType.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) { + return; + } + for (Iterator<Tensor.Cell> iter = convertedTensor.cellIterator(); iter.hasNext(); ) { + Tensor.Cell cell = iter.next(); + TensorAddress address = cell.getKey(); + for (int i = 0; i < address.size(); ++i) { + long label = address.numericLabel(i); + long bound = originalType.dimensions().get(i).size().get(); // size is non-optional for indexed bound + if (label >= bound) { + throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() + + "' has label '" + label + "' but type is " + originalType.toString()); + } + } + } + } + } diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index ec37ebc8295..376fac3fd84 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -1353,6 +1353,16 @@ public class JsonReaderTestCase { } @Test + public void tensor_modify_update_with_multiply_operation_dense() { + assertTensorModifyUpdate("{{x:a,y:b}:2.0}", TensorModifyUpdate.Operation.MULTIPLY, "sparse_tensor", + inputJson("{", + " 'operation': 'multiply',", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': 'b' }, 'value': 2.0 } ]}")); + } + + + @Test public void tensor_modify_update_treats_the_input_tensor_as_sparse() { // Note that the type of the tensor in the modify update is sparse (it only has mapped dimensions). assertTensorModifyUpdate("tensor(x{},y{}):{{x:0,y:0}:2.0, {x:1,y:2}:3.0}", @@ -1395,6 +1405,16 @@ public class JsonReaderTestCase { } @Test + public void tensor_modify_update_with_out_of_bound_cells_throws() { + exception.expect(IndexOutOfBoundsException.class); + exception.expectMessage("Dimension 'y' has label '3' but type is tensor(x[2],y[3])"); + createTensorModifyUpdate(inputJson("{", + " 'operation': 'replace',", + " 'cells': [", + " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "dense_tensor"); + } + + @Test public void tensor_modify_update_with_unknown_operation_throws() { exception.expect(IllegalArgumentException.class); exception.expectMessage("Unknown operation 'unknown' in modify update for field 'sparse_tensor'"); |