diff options
author | Lester Solbakken <lesters@oath.com> | 2019-02-14 14:42:14 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-02-14 14:42:14 +0100 |
commit | aaae66f61dcbb1acb246922903e8f573f7059c88 (patch) | |
tree | 290a4c5b38d79cdcb057e5f05d8accee317f1cb6 | |
parent | e6e410fdba83d7a06eb45b0ff6071dd152caa1af (diff) |
Use a tensor as representation for partial tensor remove updates
3 files changed, 42 insertions, 50 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java index 8fe3730c6e1..6638320699c 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java @@ -3,14 +3,13 @@ package com.yahoo.document.json.readers; import com.yahoo.document.Field; import com.yahoo.document.TensorDataType; +import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.json.TokenBuffer; import com.yahoo.document.update.TensorRemoveUpdate; +import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; -import java.util.HashSet; -import java.util.Set; - import static com.yahoo.document.json.readers.JsonParserHelpers.expectArrayStart; import static com.yahoo.document.json.readers.JsonParserHelpers.expectCompositeEnd; import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectEnd; @@ -26,11 +25,15 @@ public class TensorRemoveUpdateReader { static TensorRemoveUpdate createTensorRemoveUpdate(TokenBuffer buffer, Field field) { expectObjectStart(buffer.currentToken()); expectTensorTypeIsSparse(field); + TensorDataType tensorDataType = (TensorDataType)field.getDataType(); TensorType tensorType = tensorDataType.getTensorType(); - Set<TensorAddress> addresses = readTensorAddresses(buffer, tensorType); - expectAddressesIsNonEmpty(field, addresses); - return new TensorRemoveUpdate(tensorType, addresses); + + // TODO: for mixed case extract a new tensor type based only on mapped dimensions + + Tensor tensor = readRemoveUpdateTensor(buffer, tensorType); + expectAddressesAreNonEmpty(field, tensor); + return new TensorRemoveUpdate(new TensorFieldValue(tensor)); } private static void expectTensorTypeIsSparse(Field field) { @@ -41,14 +44,17 @@ public class TensorRemoveUpdateReader { } } - private static void expectAddressesIsNonEmpty(Field field, Set<TensorAddress> addresses) { - if (addresses.isEmpty()) { + private static void expectAddressesAreNonEmpty(Field field, Tensor tensor) { + if (tensor.isEmpty()) { throw new IllegalArgumentException("Remove update for field '" + field.getName() + "' does not contain tensor addresses"); } } - private static Set<TensorAddress> readTensorAddresses(TokenBuffer buffer, TensorType type) { - Set<TensorAddress> addresses = new HashSet<>(); + /** + * Reads all addresses in buffer and returns a tensor where addresses have cell value 1.0 + */ + private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type) { + Tensor.Builder builder = Tensor.Builder.of(type); expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { @@ -56,13 +62,13 @@ public class TensorRemoveUpdateReader { expectArrayStart(buffer.currentToken()); int nesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= nesting; buffer.next()) { - addresses.add(readTensorAddress(buffer, type)); + builder.cell(readTensorAddress(buffer, type), 1.0); } expectCompositeEnd(buffer.currentToken()); } } expectObjectEnd(buffer.currentToken()); - return addresses; + return builder.build(); } private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type) { diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java index c624b69d522..d9ef84199fa 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java @@ -1,32 +1,26 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.document.update; -import com.google.common.collect.ImmutableSet; import com.yahoo.document.DataType; import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.FieldValue; +import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.serialization.DocumentUpdateWriter; -import com.yahoo.tensor.TensorAddress; -import com.yahoo.tensor.TensorType; import java.util.Objects; -import java.util.Set; -import java.util.StringJoiner; /** * An update used to remove cells from a sparse tensor (has only mapped dimensions). * * The cells to remove are contained in a set of addresses. */ -public class TensorRemoveUpdate extends ValueUpdate { +public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { - private TensorType tensorType; - private ImmutableSet<TensorAddress> addresses; + private TensorFieldValue tensor; - public TensorRemoveUpdate(TensorType tensorType, Set<TensorAddress> addresses) { + public TensorRemoveUpdate(TensorFieldValue value) { super(ValueUpdateClassID.TENSORREMOVE); - this.tensorType = tensorType; - this.addresses = ImmutableSet.copyOf(addresses); + this.tensor = value; } @Override @@ -49,39 +43,32 @@ public class TensorRemoveUpdate extends ValueUpdate { } @Override - public FieldValue getValue() { - return null; + public TensorFieldValue getValue() { + return tensor; } - public TensorType getTensorType() { - return tensorType; - } - - public Set<TensorAddress> getAddresses() { - return addresses; + @Override + public void setValue(TensorFieldValue value) { + tensor = value; } @Override - public void setValue(FieldValue value) { - // Ignore + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (!super.equals(o)) return false; + TensorRemoveUpdate that = (TensorRemoveUpdate) o; + return tensor.equals(that.tensor); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), addresses); + return Objects.hash(super.hashCode(), tensor); } @Override public String toString() { - return super.toString() + " " + toStringWithType(); - } - - public String toStringWithType() { - StringJoiner sj = new StringJoiner(",", "[", "]"); - for (TensorAddress address : addresses) { - sj.add(address.toString(tensorType)); - } - return sj.toString(); + return super.toString() + " " + tensor; } } diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index 9088b5ced6a..e58b26d500d 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -1474,7 +1474,7 @@ public class JsonReaderTestCase { @Test public void tensor_remove_update_on_sparse_tensor() { - assertTensorRemoveUpdate("[{x:a,y:b},{x:c,y:d}]", "sparse_tensor", + assertTensorRemoveUpdate("{{x:a,y:b}:1.0,{x:c,y:d}:1.0}", "sparse_tensor", inputJson("{", " 'addresses': [", " { 'x': 'a', 'y': 'b' },", @@ -1620,8 +1620,7 @@ public class JsonReaderTestCase { assertEquals("testtensor", update.getId().getDocType()); assertEquals(TENSOR_DOC_ID, update.getId().toString()); assertEquals(1, update.fieldUpdates().size()); - FieldUpdate fieldUpdate = update.getFieldUpdate(tensorFieldName); - assertEquals(1, fieldUpdate.size()); + assertEquals(1, update.getFieldUpdate(tensorFieldName).size()); } private static void assertTensorAssignUpdate(String expectedTensor, DocumentUpdate update) { @@ -1678,14 +1677,14 @@ public class JsonReaderTestCase { assertEquals(Tensor.from(expectedTensor), addUpdate.getValue().getTensor().get()); } - private void assertTensorRemoveUpdate(String expected, String tensorFieldName, String tensorJson) { - assertTensorRemoveUpdate(expected, tensorFieldName, createTensorRemoveUpdate(tensorJson, tensorFieldName)); + private void assertTensorRemoveUpdate(String expectedTensor, String tensorFieldName, String tensorJson) { + assertTensorRemoveUpdate(expectedTensor, tensorFieldName, createTensorRemoveUpdate(tensorJson, tensorFieldName)); } - private static void assertTensorRemoveUpdate(String expected, String tensorFieldName, DocumentUpdate update) { + private static void assertTensorRemoveUpdate(String expectedTensor, String tensorFieldName, DocumentUpdate update) { assertTensorFieldUpdate(update, tensorFieldName); TensorRemoveUpdate removeUpdate = (TensorRemoveUpdate) update.getFieldUpdate(tensorFieldName).getValueUpdate(0); - assertEquals(expected, removeUpdate.toStringWithType()); + assertEquals(Tensor.from(expectedTensor), removeUpdate.getValue().getTensor().get()); } private static FieldUpdate getTensorField(DocumentUpdate update) { |