aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java42
3 files changed, 47 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 95f64cec0c1..1f3c373c1e8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -212,7 +212,6 @@ public class MixedTensor implements Tensor {
}
-
/**
* Builder for mixed tensors with bound indexed dimensions.
*/
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index bafec70be59..95cc70804e2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -80,11 +80,20 @@ public class TensorType {
/** Sorted list of the dimensions of this */
private final ImmutableList<Dimension> dimensions;
+ private final TensorType mappedSubtype;
+
private TensorType(Value valueType, Collection<Dimension> dimensions) {
this.valueType = valueType;
List<Dimension> dimensionList = new ArrayList<>(dimensions);
Collections.sort(dimensionList);
this.dimensions = ImmutableList.copyOf(dimensionList);
+
+ if (dimensionList.stream().allMatch(d -> d.isIndexed()))
+ mappedSubtype = empty;
+ else if (dimensionList.stream().noneMatch(d -> d.isIndexed()))
+ mappedSubtype = this;
+ else
+ mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> ! d.isIndexed()).collect(Collectors.toList()));
}
static public Value combinedValueType(TensorType ... types) {
@@ -116,6 +125,9 @@ public class TensorType {
/** Returns the numeric type of the cell values of this */
public Value valueType() { return valueType; }
+ /** The type representing the mapped subset of dimensions of this. */
+ public TensorType mappedSubtype() { return mappedSubtype; }
+
/** Returns the number of dimensions of this: dimensions().size() */
public int rank() { return dimensions.size(); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index fa022e2bdd1..2233622db3e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -10,6 +10,7 @@ import com.yahoo.slime.ObjectTraverser;
import com.yahoo.slime.Slime;
import com.yahoo.slime.Type;
import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -75,21 +76,48 @@ public class JsonFormat {
private static void decodeCells(Inspector cells, Tensor.Builder builder) {
if ( cells.type() != Type.ARRAY)
throw new IllegalArgumentException("Excepted 'cells' to contain an array, not " + cells.type());
- cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder.cell()));
+ cells.traverse((ArrayTraverser) (__, cell) -> decodeCellOrCells(cell, builder));
}
- private static void decodeCell(Inspector cell, Tensor.Builder.CellBuilder cellBuilder) {
- Inspector address = cell.field("address");
+ private static void decodeCellOrCells(Inspector cell, Tensor.Builder builder) {
+ Inspector value = cell.field("value");
+ if (value.type() == Type.LONG || value.type() == Type.DOUBLE) {
+ decodeCell(cell.field("address"), value, builder.cell());
+ }
+ else {
+ Inspector values = cell.field("values");
+ if (values.type() == Type.ARRAY)
+ decodeValueBlock(cell.field("address"), values, builder);
+ else
+ throw new IllegalArgumentException("Expected a cell to contain a numeric 'value' or an array 'values'");
+ }
+ }
+
+ private static void decodeCell(Inspector address, Inspector value, Tensor.Builder.CellBuilder cellBuilder) {
if ( address.type() != Type.OBJECT)
throw new IllegalArgumentException("Excepted a cell to contain an object called 'address'");
address.traverse((ObjectTraverser) (dimension, label) -> cellBuilder.label(dimension, label.asString()));
-
- Inspector value = cell.field("value");
- if (value.type() != Type.LONG && value.type() != Type.DOUBLE)
- throw new IllegalArgumentException("Excepted a cell to contain a numeric value called 'value'");
cellBuilder.value(value.asDouble());
}
+ private static void decodeValueBlock(Inspector address, Inspector valuesBlock, Tensor.Builder builder) {
+ if ( ! (builder instanceof MixedTensor.BoundBuilder))
+ throw new IllegalArgumentException("Sending 'values' in 'cells' is only permissible with a mixed tensor " +
+ "type with bound indexed dimensions, but the type is " +
+ builder.type());
+ MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder)builder;
+
+ if (address.type() != Type.OBJECT)
+ throw new IllegalArgumentException("Expected a cell to contain an object called 'address'");
+ TensorAddress.Builder sparseAddress = new TensorAddress.Builder(mixedBuilder.type().mappedSubtype());
+ address.traverse((ObjectTraverser) (dimension, label) -> sparseAddress.add(dimension, label.asString()));
+
+ double[] values = new double[(int)mixedBuilder.denseSubspaceSize()];
+ valuesBlock.traverse((ArrayTraverser) (index, value) -> values[index] = value.asDouble());
+
+ mixedBuilder.block(sparseAddress.build(), values);
+ }
+
private static void decodeValues(Inspector values, Tensor.Builder builder) {
if ( ! (builder instanceof IndexedTensor.BoundBuilder))
throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " +