aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-09-27 13:45:10 +0200
committerLester Solbakken <lesters@oath.com>2021-09-27 13:45:10 +0200
commite02be90cd8ea302cb23444a7dd321c9ef774913a (patch)
tree609821fc539839867fb652c55709b424fa1127a5 /vespajlib
parent9377da84086392e118d69b467006e73fe9ae3f70 (diff)
Stateless REST API: short forms for sparse and mixed tensors
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java61
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java30
3 files changed, 88 insertions, 5 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 71ed347219e..33dcd458980 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -91,7 +91,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return b.toString();
}
- /** Returns a label as a string with approriate quoting/escaping when necessary */
+ /** Returns a label as a string with appropriate quoting/escaping when necessary */
public static String labelToString(String label) {
if (TensorType.labelMatcher.matches(label)) return label; // no quoting
if (label.contains("'")) return "\"" + label + "\"";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index cb7539d8565..bebd706f815 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -11,12 +11,19 @@ import com.yahoo.slime.Slime;
import com.yahoo.slime.Type;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.ConstantTensor;
+import com.yahoo.tensor.functions.Slice;
+import java.util.HashSet;
import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
/**
* Writes tensors on the JSON format used in Vespa tensor document fields:
@@ -46,12 +53,33 @@ public class JsonFormat {
}
/** Serializes the given tensor type and value into a short-form JSON format */
- public static byte[] encodeShortForm(IndexedTensor tensor) {
+ public static byte[] encodeShortForm(Tensor tensor) {
Slime slime = new Slime();
Cursor root = slime.setObject();
root.setString("type", tensor.type().toString());
- Cursor value = root.setArray("value");
- encodeList(tensor, value, new long[tensor.dimensionSizes().dimensions()], 0);
+
+ // Encode as nested lists if indexed tensor
+ if (tensor instanceof IndexedTensor) {
+ IndexedTensor denseTensor = (IndexedTensor) tensor;
+ encodeList(denseTensor, root.setArray("value"), new long[denseTensor.dimensionSizes().dimensions()], 0);
+ }
+
+ // Short form for a single mapped dimension
+ else if (tensor instanceof MappedTensor && tensor.type().dimensions().size() == 1) {
+ encodeMap((MappedTensor) tensor, root.setObject("value"));
+ }
+
+ // Short form for a mixed tensor with a single mapped dimension
+ else if (tensor instanceof MixedTensor &&
+ tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() == 1) {
+ encodeMapBlocks((MixedTensor) tensor, root.setObject("value"));
+ }
+
+ // No other short forms exist: default to standard cell address output
+ else {
+ encodeCells(tensor, root.setObject("value"));
+ }
+
return com.yahoo.slime.JsonFormat.toJsonBytes(slime);
}
@@ -81,6 +109,33 @@ public class JsonFormat {
}
}
+ private static void encodeMap(MappedTensor tensor, Cursor cursor) {
+ if (tensor.type().dimensions().size() > 1)
+ throw new IllegalStateException("JSON encode of mapped tensor can only contain a single dimension");
+ tensor.cells().forEach((k,v) -> cursor.setDouble(k.label(0), v));
+ }
+
+ private static void encodeMapBlocks(MixedTensor tensor, Cursor cursor) {
+ var mappedDimensions = tensor.type().dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
+ if (mappedDimensions.size() != 1) {
+ throw new IllegalArgumentException("Should be ensured by caller");
+ }
+ String mappedDimensionName = mappedDimensions.get(0).name();
+ int mappedDimensionIndex = tensor.type().indexOfDimension(mappedDimensionName).
+ orElseThrow(() -> new IllegalStateException("Could not find mapped dimension index"));
+
+ // Find all unique indices for the mapped dimension
+ Set<String> mappedIndices = new HashSet<>();
+ tensor.cellIterator().forEachRemaining((cell) -> mappedIndices.add(cell.getKey().label(mappedDimensionIndex)));
+
+ // Slice out dense subspace of each and encode dense subspace as a list
+ for (String mappedIndex : mappedIndices) {
+ IndexedTensor denseSubspace = (IndexedTensor) new Slice<>(new ConstantTensor<>(tensor),
+ List.of(new Slice.DimensionValue<>(mappedDimensionName, mappedIndex))).evaluate();
+ encodeList(denseSubspace, cursor.setArray(mappedIndex), new long[denseSubspace.dimensionSizes().dimensions()], 0);
+ }
+ }
+
/** Deserializes the given tensor from JSON format */
// NOTE: This must be kept in sync with com.yahoo.document.json.readers.TensorReader in the document module
public static Tensor decode(TensorType type, byte[] jsonTensorValue) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 87796501917..15017dc95ca 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -122,6 +122,34 @@ public class JsonFormatTestCase {
}
@Test
+ public void testSingleDimensionSparseTensorShortForm() {
+ assertEncodeShortForm("tensor(x{}):{a:1, b:2}",
+ "{\"type\":\"tensor(x{})\",\"value\":{\"a\":1.0,\"b\":2.0}}");
+
+ // Multiple mapped dimensions: no short form available
+ assertEncodeShortForm("tensor(x{},y{}):{{x:a,y:b}:1,{x:c,y:d}:2}",
+ "{\"type\":\"tensor(x{},y{})\",\"value\":{\"cells\":[{\"address\":{\"x\":\"a\",\"y\":\"b\"},\"value\":1.0},{\"address\":{\"x\":\"c\",\"y\":\"d\"},\"value\":2.0}]}}");
+ }
+
+ @Test
+ public void testSingleMappedDimensionMixedTensorShortForm() {
+ assertEncodeShortForm("tensor(x{},y[2]):{a:[1,2], b:[3,4] }",
+ "{\"type\":\"tensor(x{},y[2])\",\"value\":{\"a\":[1.0,2.0],\"b\":[3.0,4.0]}}");
+ assertEncodeShortForm("tensor(x[2],y{}):{a:[1,2], b:[3,4] }",
+ "{\"type\":\"tensor(x[2],y{})\",\"value\":{\"a\":[1.0,2.0],\"b\":[3.0,4.0]}}");
+ assertEncodeShortForm("tensor(x{},y[2],z[2]):{a:[[1,2],[3,4]], b:[[5,6],[7,8]] }",
+ "{\"type\":\"tensor(x{},y[2],z[2])\",\"value\":{\"a\":[[1.0,2.0],[3.0,4.0]],\"b\":[[5.0,6.0],[7.0,8.0]]}}");
+ assertEncodeShortForm("tensor(x[1],y{},z[4]):{a:[[1,2,3,4]], b:[[5,6,7,8]] }",
+ "{\"type\":\"tensor(x[1],y{},z[4])\",\"value\":{\"a\":[[1.0,2.0,3.0,4.0]],\"b\":[[5.0,6.0,7.0,8.0]]}}");
+ assertEncodeShortForm("tensor(x[4],y[1],z{}):{a:[[1],[2],[3],[4]], b:[[5],[6],[7],[8]] }",
+ "{\"type\":\"tensor(x[4],y[1],z{})\",\"value\":{\"a\":[[1.0],[2.0],[3.0],[4.0]],\"b\":[[5.0],[6.0],[7.0],[8.0]]}}");
+ assertEncodeShortForm("tensor(a[2],b[2],c{},d[2]):{a:[[[1,2], [3,4]], [[5,6], [7,8]]], b:[[[1,2], [3,4]], [[5,6], [7,8]]] }",
+ "{\"type\":\"tensor(a[2],b[2],c{},d[2])\",\"value\":{" +
+ "\"a\":[[[1.0,2.0],[3.0,4.0]],[[5.0,6.0],[7.0,8.0]]]," +
+ "\"b\":[[[1.0,2.0],[3.0,4.0]],[[5.0,6.0],[7.0,8.0]]]}}");
+ }
+
+ @Test
public void testInt8VectorInHexForm() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[2],y[3])"));
builder.cell().label("x", 0).label("y", 0).value(2.0);
@@ -315,7 +343,7 @@ public class JsonFormatTestCase {
}
private void assertEncodeShortForm(String tensor, String expected) {
- byte[] json = JsonFormat.encodeShortForm((IndexedTensor) Tensor.from(tensor));
+ byte[] json = JsonFormat.encodeShortForm(Tensor.from(tensor));
assertEquals(expected, new String(json, StandardCharsets.UTF_8));
}