diff options
Diffstat (limited to 'vespajlib/src/main')
3 files changed, 47 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 95f64cec0c1..1f3c373c1e8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -212,7 +212,6 @@ public class MixedTensor implements Tensor { } - /** * Builder for mixed tensors with bound indexed dimensions. */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index bafec70be59..95cc70804e2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -80,11 +80,20 @@ public class TensorType { /** Sorted list of the dimensions of this */ private final ImmutableList<Dimension> dimensions; + private final TensorType mappedSubtype; + private TensorType(Value valueType, Collection<Dimension> dimensions) { this.valueType = valueType; List<Dimension> dimensionList = new ArrayList<>(dimensions); Collections.sort(dimensionList); this.dimensions = ImmutableList.copyOf(dimensionList); + + if (dimensionList.stream().allMatch(d -> d.isIndexed())) + mappedSubtype = empty; + else if (dimensionList.stream().noneMatch(d -> d.isIndexed())) + mappedSubtype = this; + else + mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> ! d.isIndexed()).collect(Collectors.toList())); } static public Value combinedValueType(TensorType ... types) { @@ -116,6 +125,9 @@ public class TensorType { /** Returns the numeric type of the cell values of this */ public Value valueType() { return valueType; } + /** The type representing the mapped subset of dimensions of this. */ + public TensorType mappedSubtype() { return mappedSubtype; } + /** Returns the number of dimensions of this: dimensions().size() */ public int rank() { return dimensions.size(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index fa022e2bdd1..2233622db3e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -10,6 +10,7 @@ import com.yahoo.slime.ObjectTraverser; import com.yahoo.slime.Slime; import com.yahoo.slime.Type; import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -75,21 +76,48 @@ public class JsonFormat { private static void decodeCells(Inspector cells, Tensor.Builder builder) { if ( cells.type() != Type.ARRAY) throw new IllegalArgumentException("Excepted 'cells' to contain an array, not " + cells.type()); - cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder.cell())); + cells.traverse((ArrayTraverser) (__, cell) -> decodeCellOrCells(cell, builder)); } - private static void decodeCell(Inspector cell, Tensor.Builder.CellBuilder cellBuilder) { - Inspector address = cell.field("address"); + private static void decodeCellOrCells(Inspector cell, Tensor.Builder builder) { + Inspector value = cell.field("value"); + if (value.type() == Type.LONG || value.type() == Type.DOUBLE) { + decodeCell(cell.field("address"), value, builder.cell()); + } + else { + Inspector values = cell.field("values"); + if (values.type() == Type.ARRAY) + decodeValueBlock(cell.field("address"), values, builder); + else + throw new IllegalArgumentException("Expected a cell to contain a numeric 'value' or an array 'values'"); + } + } + + private static void decodeCell(Inspector address, Inspector value, Tensor.Builder.CellBuilder cellBuilder) { if ( address.type() != Type.OBJECT) throw new IllegalArgumentException("Excepted a cell to contain an object called 'address'"); address.traverse((ObjectTraverser) (dimension, label) -> cellBuilder.label(dimension, label.asString())); - - Inspector value = cell.field("value"); - if (value.type() != Type.LONG && value.type() != Type.DOUBLE) - throw new IllegalArgumentException("Excepted a cell to contain a numeric value called 'value'"); cellBuilder.value(value.asDouble()); } + private static void decodeValueBlock(Inspector address, Inspector valuesBlock, Tensor.Builder builder) { + if ( ! (builder instanceof MixedTensor.BoundBuilder)) + throw new IllegalArgumentException("Sending 'values' in 'cells' is only permissible with a mixed tensor " + + "type with bound indexed dimensions, but the type is " + + builder.type()); + MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder)builder; + + if (address.type() != Type.OBJECT) + throw new IllegalArgumentException("Expected a cell to contain an object called 'address'"); + TensorAddress.Builder sparseAddress = new TensorAddress.Builder(mixedBuilder.type().mappedSubtype()); + address.traverse((ObjectTraverser) (dimension, label) -> sparseAddress.add(dimension, label.asString())); + + double[] values = new double[(int)mixedBuilder.denseSubspaceSize()]; + valuesBlock.traverse((ArrayTraverser) (index, value) -> values[index] = value.asDouble()); + + mixedBuilder.block(sparseAddress.build(), values); + } + private static void decodeValues(Inspector values, Tensor.Builder builder) { if ( ! (builder instanceof IndexedTensor.BoundBuilder)) throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + |