summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2020-01-13 14:38:24 +0100
committerJon Bratseth <bratseth@verizonmedia.com>2020-01-13 14:38:24 +0100
commitfdcf0682eb4ed0471431adaf4a6be70628b9c84d (patch)
tree929006dbc7398704f1ee496c3e9df020ef23c21d
parent7fad0f3d7b5dcd171655d101c05cf51f758bfc83 (diff)
Convert tensor update to sparse
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java15
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorReader.java27
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java2
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java14
-rw-r--r--vespajlib/abi-spec.json1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java7
7 files changed, 55 insertions, 17 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
index 5fd1c7bbab7..b8937d8b739 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
@@ -132,20 +132,27 @@ public class TensorModifyUpdateReader {
Tensor.Builder tensorBuilder = Tensor.Builder.of(type);
readTensorBlocks(buffer, tensorBuilder);
- Tensor tensor = tensorBuilder.build();
-
+ Tensor tensor = convertToSparse(tensorBuilder.build());
validateBounds(tensor, type);
return new TensorFieldValue(tensor);
}
+ private static Tensor convertToSparse(Tensor tensor) {
+ if (tensor.type().dimensions().stream().noneMatch(dimension -> dimension.isIndexed())) return tensor;
+ Tensor.Builder b = Tensor.Builder.of(TensorModifyUpdate.convertDimensionsToMapped(tensor.type()));
+ for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); )
+ b.cell(i.next());
+ return b.build();
+ }
+
/** Only validate if original type has indexed bound dimensions */
static void validateBounds(Tensor convertedTensor, TensorType originalType) {
if (originalType.dimensions().stream().noneMatch(d -> d instanceof TensorType.IndexedBoundDimension)) {
return;
}
- for (Iterator<Tensor.Cell> iter = convertedTensor.cellIterator(); iter.hasNext(); ) {
- Tensor.Cell cell = iter.next();
+ for (Iterator<Tensor.Cell> cellIterator = convertedTensor.cellIterator(); cellIterator.hasNext(); ) {
+ Tensor.Cell cell = cellIterator.next();
TensorAddress address = cell.getKey();
for (int i = 0; i < address.size(); ++i) {
TensorType.Dimension dim = originalType.dimensions().get(i);
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
index e5699d0e6b1..769e31818e6 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
@@ -111,12 +111,15 @@ public class TensorReader {
}
else if (buffer.currentToken() == JsonToken.START_OBJECT) {
int initNesting = buffer.nesting();
- for (buffer.next(); buffer.nesting() >= initNesting; buffer.next())
- mixedBuilder.block(asAddress(buffer.currentName(), builder.type().mappedSubtype()),
- readValues(buffer, (int)mixedBuilder.denseSubspaceSize()));
+ for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
+ TensorAddress mappedAddress = asAddress(buffer.currentName(), builder.type().mappedSubtype());
+ mixedBuilder.block(mappedAddress,
+ readValues(buffer, (int) mixedBuilder.denseSubspaceSize(), mappedAddress, mixedBuilder.type()));
+ }
}
else {
- throw new IllegalArgumentException("Expected 'blocks' to contain an array or an object, but got " + buffer.currentToken());
+ throw new IllegalArgumentException("Expected 'blocks' to contain an array or an object, but got " +
+ buffer.currentToken());
}
expectCompositeEnd(buffer.currentToken());
@@ -134,7 +137,7 @@ public class TensorReader {
if (TensorReader.TENSOR_ADDRESS.equals(currentName))
address = readAddress(buffer, mixedBuilder.type().mappedSubtype());
else if (TensorReader.TENSOR_VALUES.equals(currentName))
- values = readValues(buffer, (int)mixedBuilder.denseSubspaceSize());
+ values = readValues(buffer, (int)mixedBuilder.denseSubspaceSize(), address, mixedBuilder.type());
}
expectObjectEnd(buffer.currentToken());
if (address == null)
@@ -154,7 +157,16 @@ public class TensorReader {
return builder.build();
}
- private static double[] readValues(TokenBuffer buffer, int size) {
+ /**
+ * Reads values for a tensor subspace block
+ *
+ * @param buffer the buffer containing the values
+ * @param size the expected number of values
+ * @param address the address for the block for error reporting, or null if not known
+ * @param type the type of the tensor we are reading
+ * @return the values read
+ */
+ private static double[] readValues(TokenBuffer buffer, int size, TensorAddress address, TensorType type) {
expectArrayStart(buffer.currentToken());
int index = 0;
@@ -162,6 +174,9 @@ public class TensorReader {
double[] values = new double[size];
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next())
values[index++] = readDouble(buffer);
+ if (index != size)
+ throw new IllegalArgumentException((address != null ? "At " + address.toString(type) + ": " : "") +
+ "Expected " + size + " values, but got " + index);
expectCompositeEnd(buffer.currentToken());
return values;
}
diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
index 0015b59e9a9..b6664464e0b 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
@@ -30,7 +30,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
}
private void verifyCompatibleType(TensorType type) {
- if (type.rank() > 0 && type.dimensions().stream().noneMatch(dim -> dim.isMapped()) ) {
+ if (type.dimensions().stream().anyMatch(dim -> dim.isIndexed()) ) {
throw new IllegalArgumentException("Tensor type '" + type + "' is not compatible as it has no mapped dimensions");
}
}
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index 5867ca5596c..54ae3d6d373 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -1498,7 +1498,7 @@ public class JsonReaderTestCase {
@Test
public void tensor_modify_update_with_replace_operation_mixed_block_short_form_array() {
- assertTensorModifyUpdate("tensor(x{},y[3]):{a:[1,2,3]}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor",
+ assertTensorModifyUpdate("{{x:a,y:0}:1,{x:a,y:1}:2,{x:a,y:2}:3}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor",
inputJson("{",
" 'operation': 'replace',",
" 'blocks': [",
@@ -1506,8 +1506,18 @@ public class JsonReaderTestCase {
}
@Test
+ public void tensor_modify_update_with_replace_operation_mixed_block_short_form_must_specify_full_subspace() {
+ illegalTensorModifyUpdate("Error in 'mixed_tensor': At {x:a}: Expected 3 values, but got 2",
+ "mixed_tensor",
+ inputJson("{",
+ " 'operation': 'replace',",
+ " 'blocks': {",
+ " 'a': [2,3] } }"));
+ }
+
+ @Test
public void tensor_modify_update_with_replace_operation_mixed_block_short_form_map() {
- assertTensorModifyUpdate("tensor(x{},y[3]):{a:[1,2,3]}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor",
+ assertTensorModifyUpdate("{{x:a,y:0}:1,{x:a,y:1}:2,{x:a,y:2}:3}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor",
inputJson("{",
" 'operation': 'replace',",
" 'blocks': {",
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index f631b3e1c58..66eb4b1f4e6 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1107,6 +1107,7 @@
"public varargs abstract com.yahoo.tensor.Tensor$Builder cell(float, long[])",
"public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell, double)",
"public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell, float)",
+ "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell)",
"public abstract com.yahoo.tensor.Tensor build()"
],
"fields": []
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 67c6930ce35..2b393d8a637 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -253,8 +253,12 @@ public class MixedTensor implements Tensor {
}
public Tensor.Builder block(TensorAddress sparsePart, double[] values) {
+ int denseSubspaceSize = (int)denseSubspaceSize();
+ if (values.length < denseSubspaceSize)
+ throw new IllegalArgumentException("Block should have " + denseSubspaceSize +
+ " values, but has only " + values.length);
double[] denseSubspace = denseSubspace(sparsePart);
- System.arraycopy(values, 0, denseSubspace, 0, (int)denseSubspaceSize());
+ System.arraycopy(values, 0, denseSubspace, 0, denseSubspaceSize);
return this;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 08d4f1c08b7..71bdee36c27 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -516,9 +516,10 @@ public interface Tensor {
default Builder cell(Cell cell, double value) {
return cell(cell.getKey(), value);
}
- default Builder cell(Cell cell, float value) {
- return cell(cell.getKey(), value);
- }
+ default Builder cell(Cell cell, float value) { return cell(cell.getKey(), value); }
+
+ /** Adds the given cell to this tensor */
+ default Builder cell(Cell cell) { return cell(cell.getKey(), cell.getValue()); }
Tensor build();