aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-02-14 14:42:14 +0100
committerLester Solbakken <lesters@oath.com>2019-02-14 14:42:14 +0100
commitaaae66f61dcbb1acb246922903e8f573f7059c88 (patch)
tree290a4c5b38d79cdcb057e5f05d8accee317f1cb6
parente6e410fdba83d7a06eb45b0ff6071dd152caa1af (diff)
Use a tensor as representation for partial tensor remove updates
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java30
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java49
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java13
3 files changed, 42 insertions, 50 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
index 8fe3730c6e1..6638320699c 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
@@ -3,14 +3,13 @@ package com.yahoo.document.json.readers;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
+import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.json.TokenBuffer;
import com.yahoo.document.update.TensorRemoveUpdate;
+import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import java.util.HashSet;
-import java.util.Set;
-
import static com.yahoo.document.json.readers.JsonParserHelpers.expectArrayStart;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectCompositeEnd;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectEnd;
@@ -26,11 +25,15 @@ public class TensorRemoveUpdateReader {
static TensorRemoveUpdate createTensorRemoveUpdate(TokenBuffer buffer, Field field) {
expectObjectStart(buffer.currentToken());
expectTensorTypeIsSparse(field);
+
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType tensorType = tensorDataType.getTensorType();
- Set<TensorAddress> addresses = readTensorAddresses(buffer, tensorType);
- expectAddressesIsNonEmpty(field, addresses);
- return new TensorRemoveUpdate(tensorType, addresses);
+
+ // TODO: for mixed case extract a new tensor type based only on mapped dimensions
+
+ Tensor tensor = readRemoveUpdateTensor(buffer, tensorType);
+ expectAddressesAreNonEmpty(field, tensor);
+ return new TensorRemoveUpdate(new TensorFieldValue(tensor));
}
private static void expectTensorTypeIsSparse(Field field) {
@@ -41,14 +44,17 @@ public class TensorRemoveUpdateReader {
}
}
- private static void expectAddressesIsNonEmpty(Field field, Set<TensorAddress> addresses) {
- if (addresses.isEmpty()) {
+ private static void expectAddressesAreNonEmpty(Field field, Tensor tensor) {
+ if (tensor.isEmpty()) {
throw new IllegalArgumentException("Remove update for field '" + field.getName() + "' does not contain tensor addresses");
}
}
- private static Set<TensorAddress> readTensorAddresses(TokenBuffer buffer, TensorType type) {
- Set<TensorAddress> addresses = new HashSet<>();
+ /**
+ * Reads all addresses in buffer and returns a tensor where addresses have cell value 1.0
+ */
+ private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type) {
+ Tensor.Builder builder = Tensor.Builder.of(type);
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
@@ -56,13 +62,13 @@ public class TensorRemoveUpdateReader {
expectArrayStart(buffer.currentToken());
int nesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= nesting; buffer.next()) {
- addresses.add(readTensorAddress(buffer, type));
+ builder.cell(readTensorAddress(buffer, type), 1.0);
}
expectCompositeEnd(buffer.currentToken());
}
}
expectObjectEnd(buffer.currentToken());
- return addresses;
+ return builder.build();
}
private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type) {
diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
index c624b69d522..d9ef84199fa 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
@@ -1,32 +1,26 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.document.update;
-import com.google.common.collect.ImmutableSet;
import com.yahoo.document.DataType;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.FieldValue;
+import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.serialization.DocumentUpdateWriter;
-import com.yahoo.tensor.TensorAddress;
-import com.yahoo.tensor.TensorType;
import java.util.Objects;
-import java.util.Set;
-import java.util.StringJoiner;
/**
* An update used to remove cells from a sparse tensor (has only mapped dimensions).
*
* The cells to remove are contained in a set of addresses.
*/
-public class TensorRemoveUpdate extends ValueUpdate {
+public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
- private TensorType tensorType;
- private ImmutableSet<TensorAddress> addresses;
+ private TensorFieldValue tensor;
- public TensorRemoveUpdate(TensorType tensorType, Set<TensorAddress> addresses) {
+ public TensorRemoveUpdate(TensorFieldValue value) {
super(ValueUpdateClassID.TENSORREMOVE);
- this.tensorType = tensorType;
- this.addresses = ImmutableSet.copyOf(addresses);
+ this.tensor = value;
}
@Override
@@ -49,39 +43,32 @@ public class TensorRemoveUpdate extends ValueUpdate {
}
@Override
- public FieldValue getValue() {
- return null;
+ public TensorFieldValue getValue() {
+ return tensor;
}
- public TensorType getTensorType() {
- return tensorType;
- }
-
- public Set<TensorAddress> getAddresses() {
- return addresses;
+ @Override
+ public void setValue(TensorFieldValue value) {
+ tensor = value;
}
@Override
- public void setValue(FieldValue value) {
- // Ignore
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ if (!super.equals(o)) return false;
+ TensorRemoveUpdate that = (TensorRemoveUpdate) o;
+ return tensor.equals(that.tensor);
}
@Override
public int hashCode() {
- return Objects.hash(super.hashCode(), addresses);
+ return Objects.hash(super.hashCode(), tensor);
}
@Override
public String toString() {
- return super.toString() + " " + toStringWithType();
- }
-
- public String toStringWithType() {
- StringJoiner sj = new StringJoiner(",", "[", "]");
- for (TensorAddress address : addresses) {
- sj.add(address.toString(tensorType));
- }
- return sj.toString();
+ return super.toString() + " " + tensor;
}
}
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index 9088b5ced6a..e58b26d500d 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -1474,7 +1474,7 @@ public class JsonReaderTestCase {
@Test
public void tensor_remove_update_on_sparse_tensor() {
- assertTensorRemoveUpdate("[{x:a,y:b},{x:c,y:d}]", "sparse_tensor",
+ assertTensorRemoveUpdate("{{x:a,y:b}:1.0,{x:c,y:d}:1.0}", "sparse_tensor",
inputJson("{",
" 'addresses': [",
" { 'x': 'a', 'y': 'b' },",
@@ -1620,8 +1620,7 @@ public class JsonReaderTestCase {
assertEquals("testtensor", update.getId().getDocType());
assertEquals(TENSOR_DOC_ID, update.getId().toString());
assertEquals(1, update.fieldUpdates().size());
- FieldUpdate fieldUpdate = update.getFieldUpdate(tensorFieldName);
- assertEquals(1, fieldUpdate.size());
+ assertEquals(1, update.getFieldUpdate(tensorFieldName).size());
}
private static void assertTensorAssignUpdate(String expectedTensor, DocumentUpdate update) {
@@ -1678,14 +1677,14 @@ public class JsonReaderTestCase {
assertEquals(Tensor.from(expectedTensor), addUpdate.getValue().getTensor().get());
}
- private void assertTensorRemoveUpdate(String expected, String tensorFieldName, String tensorJson) {
- assertTensorRemoveUpdate(expected, tensorFieldName, createTensorRemoveUpdate(tensorJson, tensorFieldName));
+ private void assertTensorRemoveUpdate(String expectedTensor, String tensorFieldName, String tensorJson) {
+ assertTensorRemoveUpdate(expectedTensor, tensorFieldName, createTensorRemoveUpdate(tensorJson, tensorFieldName));
}
- private static void assertTensorRemoveUpdate(String expected, String tensorFieldName, DocumentUpdate update) {
+ private static void assertTensorRemoveUpdate(String expectedTensor, String tensorFieldName, DocumentUpdate update) {
assertTensorFieldUpdate(update, tensorFieldName);
TensorRemoveUpdate removeUpdate = (TensorRemoveUpdate) update.getFieldUpdate(tensorFieldName).getValueUpdate(0);
- assertEquals(expected, removeUpdate.toStringWithType());
+ assertEquals(Tensor.from(expectedTensor), removeUpdate.getValue().getTensor().get());
}
private static FieldUpdate getTensorField(DocumentUpdate update) {