diff options
author | Lester Solbakken <lesters@oath.com> | 2019-02-20 12:46:24 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-02-20 12:46:24 +0100 |
commit | 085b6922c07f4626c61e2ed2e6dde6beec0855de (patch) | |
tree | 597fc14c08199339c9ab9286c365af6e8d4cdcdb | |
parent | 85e394563c8b711a1a0307c8ac5953c1817f5629 (diff) |
TensorAddUpdate support for mixed tensors
10 files changed, 158 insertions, 55 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java index ffbfe49347c..da8bcc13397 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java @@ -6,10 +6,15 @@ import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.json.TokenBuffer; import com.yahoo.document.update.TensorAddUpdate; +import com.yahoo.document.update.TensorModifyUpdate; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import java.util.Iterator; + import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStart; +import static com.yahoo.document.json.readers.TensorModifyUpdateReader.validateBounds; import static com.yahoo.document.json.readers.TensorReader.fillTensor; /** @@ -23,22 +28,27 @@ public class TensorAddUpdateReader { public static TensorAddUpdate createTensorAddUpdate(TokenBuffer buffer, Field field) { expectObjectStart(buffer.currentToken()); - expectTensorTypeIsSparse(field); + expectTensorTypeHasSparseDimensions(field); + // Convert update type to sparse TensorDataType tensorDataType = (TensorDataType)field.getDataType(); - TensorType tensorType = tensorDataType.getTensorType(); - TensorFieldValue tensorFieldValue = new TensorFieldValue(tensorType); + TensorType originalType = tensorDataType.getTensorType(); + TensorType convertedType = TensorModifyUpdate.convertToCompatibleType(originalType); + + TensorFieldValue tensorFieldValue = new TensorFieldValue(convertedType); fillTensor(buffer, tensorFieldValue); expectTensorIsNonEmpty(field, tensorFieldValue.getTensor().get()); + validateBounds(tensorFieldValue.getTensor().get(), originalType); + return new TensorAddUpdate(tensorFieldValue); } - private static void expectTensorTypeIsSparse(Field field) { + private static void expectTensorTypeHasSparseDimensions(Field field) { TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType(); - if (tensorType.dimensions().stream() - .anyMatch(dim -> dim.isIndexed())) { - throw new IllegalArgumentException("An add update can only be applied to sparse tensors. " - + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'"); + if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) { + throw new IllegalArgumentException("An add update can only be applied to tensors " + + "with at least one sparse dimension. Field '" + field.getName() + + "' has unsupported tensor type '" + tensorType + "'"); } } @@ -48,5 +58,4 @@ public class TensorAddUpdateReader { } } - } diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java index a9bbba519bd..5022185e03f 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java @@ -129,25 +129,26 @@ public class TensorModifyUpdateReader { validateBounds(tensor, originalType); - TensorFieldValue result = new TensorFieldValue(convertedType); - result.assign(tensor); - return result; + return new TensorFieldValue(tensor); } - /** Only validate if original type is indexed bound */ - private static void validateBounds(Tensor convertedTensor, TensorType originalType) { - if ( ! originalType.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) { + /** Only validate if original type has indexed bound dimensions */ + static void validateBounds(Tensor convertedTensor, TensorType originalType) { + if (originalType.dimensions().stream().noneMatch(d -> d instanceof TensorType.IndexedBoundDimension)) { return; } for (Iterator<Tensor.Cell> iter = convertedTensor.cellIterator(); iter.hasNext(); ) { Tensor.Cell cell = iter.next(); TensorAddress address = cell.getKey(); for (int i = 0; i < address.size(); ++i) { - long label = address.numericLabel(i); - long bound = originalType.dimensions().get(i).size().get(); // size is non-optional for indexed bound - if (label >= bound) { - throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() + - "' has label '" + label + "' but type is " + originalType.toString()); + TensorType.Dimension dim = originalType.dimensions().get(i); + if (dim instanceof TensorType.IndexedBoundDimension) { + long label = address.numericLabel(i); + long bound = dim.size().get(); // size is non-optional for indexed bound + if (label >= bound) { + throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() + + "' has label '" + label + "' but type is " + originalType.toString()); + } } } } diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java index cfc3ee0c742..7059edbca7f 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java @@ -13,9 +13,9 @@ import java.util.Map; import java.util.Objects; /** - * An update used to add cells to a sparse tensor (has only mapped dimensions). + * An update used to add cells to a sparse or mixed tensor (has at least one mapped dimension). * - * The cells to add are contained in a sparse tensor as well. + * The cells to add are contained in a sparse tensor. */ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> { @@ -50,22 +50,10 @@ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> { return oldValue; } - Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get(); - Map<TensorAddress, Double> oldCells = oldTensor.cells(); - Map<TensorAddress, Double> addCells = tensor.getTensor().get().cells(); - - // currently, underlying implementation disallows multiple entries with the same key - - Tensor.Builder builder = Tensor.Builder.of(oldTensor.type()); - for (Map.Entry<TensorAddress, Double> oldCell : oldCells.entrySet()) { - builder.cell(oldCell.getKey(), addCells.getOrDefault(oldCell.getKey(), oldCell.getValue())); - } - for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { - if ( ! oldCells.containsKey(addCell.getKey())) { - builder.cell(addCell.getKey(), addCell.getValue()); - } - } - return new TensorFieldValue(builder.build()); + Tensor old = ((TensorFieldValue) oldValue).getTensor().get(); + Tensor update = tensor.getTensor().get(); + Tensor result = old.merge((left, right) -> right, update.cells()); + return new TensorFieldValue(result); } @Override diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index e58b26d500d..a20276e5c65 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -56,6 +56,7 @@ import com.yahoo.text.Utf8; import org.apache.commons.codec.binary.Base64; import org.junit.After; import org.junit.Before; +import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -1449,11 +1450,29 @@ public class JsonReaderTestCase { } @Test - public void tensor_add_update_on_non_sparse_tensor_throws() { + public void tensor_add_update_on_mixed_tensor() { + assertTensorAddUpdate("{{x:a,y:0}:2.0, {x:a,y:1}:3.0}", "mixed_tensor", + inputJson("{", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 },", + " { 'address': { 'x': 'a', 'y': '1' }, 'value': 3.0 } ]}")); + } + + @Test + public void tensor_add_update_with_out_of_bound_dense_cells_throws() { + exception.expect(IndexOutOfBoundsException.class); + exception.expectMessage("Dimension 'y' has label '3' but type is tensor(x{},y[3])"); + createTensorAddUpdate(inputJson("{", + " 'cells': [", + " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "mixed_tensor"); + } + + @Test + public void tensor_add_update_on_dense_tensor_throws() { exception.expect(IllegalArgumentException.class); - exception.expectMessage("An add update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'"); + exception.expectMessage("An add update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'"); createTensorAddUpdate(inputJson("{", - " 'cells': [] }"), "mixed_tensor"); + " 'cells': [] }"), "dense_tensor"); } @Test @@ -1481,12 +1500,22 @@ public class JsonReaderTestCase { " { 'x': 'c', 'y': 'd' } ]}")); } + @Ignore + @Test + public void tensor_remove_update_on_mixed_tensor() { + assertTensorRemoveUpdate("{{x:1}:1.0,{x:2}:1.0}", "mixed_tensor", + inputJson("{", + " 'addresses': [", + " { 'x': '1' },", + " { 'x': '2' } ]}")); + } + @Test - public void tensor_remove_update_on_non_sparse_tensor_throws() { + public void tensor_remove_update_on_dense_tensor_throws() { exception.expect(IllegalArgumentException.class); - exception.expectMessage("A remove update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'"); + exception.expectMessage("A remove update can only be applied to sparse tensors. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'"); createTensorRemoveUpdate(inputJson("{", - " 'addresses': [] }"), "mixed_tensor"); + " 'addresses': [] }"), "dense_tensor"); } @Test diff --git a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java index eb4001e6415..c6b21380e4b 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java @@ -3,27 +3,40 @@ package com.yahoo.document.update; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.tensor.Tensor; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import static org.junit.Assert.assertEquals; public class TensorAddUpdateTest { + @Rule + public ExpectedException exception = ExpectedException.none(); + @Test public void apply_add_update_operations() { - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}"); - assertApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}"); + assertApplyTo("tensor(x{},y{})", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"); + assertApplyTo("tensor(x{},y{})", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}"); + assertApplyTo("tensor(x{},y{})", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}"); + assertApplyTo("tensor(x{},y{})", "{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}"); + assertApplyTo("tensor(x{},y{})", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}"); + + assertApplyTo("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"); + assertApplyTo("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:0}"); + assertApplyTo("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}"); + assertApplyTo("tensor(x{},y[3])", "{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5,{x:0,y:1}:0,{x:0,y:2}:0}"); + assertApplyTo("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:0}"); } - private void assertApplyTo(String init, String update, String expected) { - String spec = "tensor(x{},y{})"; + private Tensor updateField(String spec, String init, String update) { TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init)); - TensorAddUpdate addUpdate = new TensorAddUpdate(new TensorFieldValue(Tensor.from(spec, update))); - TensorFieldValue updatedFieldValue = (TensorFieldValue) addUpdate.applyTo(initialFieldValue); - assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get()); + TensorAddUpdate addUpdate = new TensorAddUpdate(new TensorFieldValue(Tensor.from("tensor(x{},y{})", update))); + return ((TensorFieldValue) addUpdate.applyTo(initialFieldValue)).getTensor().get(); + } + + private void assertApplyTo(String spec, String init, String update, String expected) { + assertEquals(Tensor.from(spec, expected), updateField(spec, init, update)); } } diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 932513f8a57..480523982fa 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -808,6 +808,7 @@ "public com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.DimensionSizes dimensionSizes()", "public java.util.Map cells()", + "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)", @@ -852,6 +853,7 @@ "public java.util.Iterator valueIterator()", "public java.util.Map cells()", "public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", + "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)" @@ -937,6 +939,7 @@ "public java.util.Iterator valueIterator()", "public java.util.Map cells()", "public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", + "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)", @@ -1039,6 +1042,7 @@ "public double asDouble()", "public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public abstract com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", "public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)", "public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])", "public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index fb55b2d5014..704cead7c01 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -13,6 +13,7 @@ import java.util.Map; import java.util.NoSuchElementException; import java.util.Optional; import java.util.Set; +import java.util.function.DoubleBinaryOperator; /** * An indexed (dense) tensor backed by a double array. @@ -190,6 +191,11 @@ public class IndexedTensor implements Tensor { } @Override + public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells) { + throw new IllegalArgumentException("Merge is not supported for indexed tensors"); + } + + @Override public int hashCode() { return Arrays.hashCode(values); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index ec3020a1a4e..f44b3ce13b7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableMap; import java.util.Iterator; import java.util.Map; +import java.util.function.DoubleBinaryOperator; /** * A sparse implementation of a tensor backed by a Map of cells to values. @@ -51,6 +52,25 @@ public class MappedTensor implements Tensor { } @Override + public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) { + + // currently, underlying implementation disallows multiple entries with the same key + + Tensor.Builder builder = Tensor.Builder.of(type()); + for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) { + TensorAddress address = cell.getKey(); + double value = cell.getValue(); + builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value); + } + for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { + if ( ! cells.containsKey(addCell.getKey())) { + builder.cell(addCell.getKey(), addCell.getValue()); + } + } + return builder.build(); + } + + @Override public int hashCode() { return cells.hashCode(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 17e33c58a13..3630a016691 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -9,6 +9,7 @@ import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.function.DoubleBinaryOperator; import java.util.stream.Collectors; /** @@ -70,13 +71,17 @@ public class MixedTensor implements Tensor { return cells.iterator(); } + private Iterable<Cell> cellIterable() { + return this::cellIterator; + } + /** * Returns an iterator over the values of this tensor. * The iteration order is the same as for cellIterator. */ @Override public Iterator<Double> valueIterator() { - return new Iterator<Double>() { + return new Iterator<>() { Iterator<Cell> cellIterator = cellIterator(); @Override public boolean hasNext() { @@ -108,6 +113,20 @@ public class MixedTensor implements Tensor { } @Override + public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) { + Tensor.Builder builder = Tensor.Builder.of(type()); + for (Cell cell : cellIterable()) { + TensorAddress address = cell.getKey(); + double value = cell.getValue(); + builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value); + } + for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { + builder.cell(addCell.getKey(), addCell.getValue()); + } + return builder.build(); + } + + @Override public int hashCode() { return cells.hashCode(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 8002990e5c6..175e6b41daa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -113,6 +113,20 @@ public interface Tensor { return builder.build(); } + /** + * Returns a new tensor where existing cells in this tensor have been + * modified according to the given operation and cells in the given map. + * In contrast to {@link #modify}, previously non-existing cells are added + * to this tensor. Only valid for sparse or mixed tensors. + * + * @param op how to update overlapping cells + * @param cells cells to merge with this tensor + * @return a new tensor where this tensor is merged with the other + */ + Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells); + +// Tensor remove(Tensor other); + // ----------------- Primitive tensor functions default Tensor map(DoubleUnaryOperator mapper) { |