aboutsummaryrefslogtreecommitdiffstats
path: root/document
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-02-08 11:23:37 +0100
committerLester Solbakken <lesters@oath.com>2019-02-08 11:23:37 +0100
commit3425c3bbbc522e3da2c3ab221227c2bff36770c3 (patch)
tree91f6aaf39f21b1ee90982431afa86312eaf74148 /document
parentfb333d2f8d92c2661591bd0a1114a0152708728e (diff)
Add bound check for dense tensor update modify
Diffstat (limited to 'document')
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java42
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java20
2 files changed, 55 insertions, 7 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
index 41748454ae6..a9bbba519bd 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
@@ -8,8 +8,11 @@ import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.json.TokenBuffer;
import com.yahoo.document.update.TensorModifyUpdate;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import java.util.Iterator;
+
import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStart;
import static com.yahoo.document.json.readers.TensorReader.TENSOR_CELLS;
import static com.yahoo.document.json.readers.TensorReader.readTensorCells;
@@ -84,8 +87,6 @@ public class TensorModifyUpdateReader {
private static ModifyUpdateResult createModifyUpdateResult(TokenBuffer buffer, Field field) {
ModifyUpdateResult result = new ModifyUpdateResult();
- TensorDataType tensorDataType = (TensorDataType)field.getDataType();
- TensorType convertedType = TensorModifyUpdate.convertToCompatibleType(tensorDataType.getTensorType());
buffer.next();
int localNesting = buffer.nesting();
while (localNesting <= buffer.nesting()) {
@@ -94,7 +95,7 @@ public class TensorModifyUpdateReader {
result.operation = createOperation(buffer, field.getName());
break;
case TENSOR_CELLS:
- result.tensor = createTensor(buffer, convertedType);
+ result.tensor = createTensor(buffer, field);
break;
default:
throw new IllegalArgumentException("Unknown JSON string '" + buffer.currentName() + "' in modify update for field '" + field.getName() + "'");
@@ -117,12 +118,39 @@ public class TensorModifyUpdateReader {
}
}
- private static TensorFieldValue createTensor(TokenBuffer buffer, TensorType tensorType) {
- Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType);
+ private static TensorFieldValue createTensor(TokenBuffer buffer, Field field) {
+ TensorDataType tensorDataType = (TensorDataType)field.getDataType();
+ TensorType originalType = tensorDataType.getTensorType();
+ TensorType convertedType = TensorModifyUpdate.convertToCompatibleType(originalType);
+
+ Tensor.Builder tensorBuilder = Tensor.Builder.of(convertedType);
readTensorCells(buffer, tensorBuilder);
- TensorFieldValue result = new TensorFieldValue(tensorType);
- result.assign(tensorBuilder.build());
+ Tensor tensor = tensorBuilder.build();
+
+ validateBounds(tensor, originalType);
+
+ TensorFieldValue result = new TensorFieldValue(convertedType);
+ result.assign(tensor);
return result;
}
+ /** Only validate if original type is indexed bound */
+ private static void validateBounds(Tensor convertedTensor, TensorType originalType) {
+ if ( ! originalType.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) {
+ return;
+ }
+ for (Iterator<Tensor.Cell> iter = convertedTensor.cellIterator(); iter.hasNext(); ) {
+ Tensor.Cell cell = iter.next();
+ TensorAddress address = cell.getKey();
+ for (int i = 0; i < address.size(); ++i) {
+ long label = address.numericLabel(i);
+ long bound = originalType.dimensions().get(i).size().get(); // size is non-optional for indexed bound
+ if (label >= bound) {
+ throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() +
+ "' has label '" + label + "' but type is " + originalType.toString());
+ }
+ }
+ }
+ }
+
}
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index ec37ebc8295..376fac3fd84 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -1353,6 +1353,16 @@ public class JsonReaderTestCase {
}
@Test
+ public void tensor_modify_update_with_multiply_operation_dense() {
+ assertTensorModifyUpdate("{{x:a,y:b}:2.0}", TensorModifyUpdate.Operation.MULTIPLY, "sparse_tensor",
+ inputJson("{",
+ " 'operation': 'multiply',",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': 'b' }, 'value': 2.0 } ]}"));
+ }
+
+
+ @Test
public void tensor_modify_update_treats_the_input_tensor_as_sparse() {
// Note that the type of the tensor in the modify update is sparse (it only has mapped dimensions).
assertTensorModifyUpdate("tensor(x{},y{}):{{x:0,y:0}:2.0, {x:1,y:2}:3.0}",
@@ -1395,6 +1405,16 @@ public class JsonReaderTestCase {
}
@Test
+ public void tensor_modify_update_with_out_of_bound_cells_throws() {
+ exception.expect(IndexOutOfBoundsException.class);
+ exception.expectMessage("Dimension 'y' has label '3' but type is tensor(x[2],y[3])");
+ createTensorModifyUpdate(inputJson("{",
+ " 'operation': 'replace',",
+ " 'cells': [",
+ " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "dense_tensor");
+ }
+
+ @Test
public void tensor_modify_update_with_unknown_operation_throws() {
exception.expect(IllegalArgumentException.class);
exception.expectMessage("Unknown operation 'unknown' in modify update for field 'sparse_tensor'");