diff options
Diffstat (limited to 'vespajlib/src/main/java/com')
4 files changed, 370 insertions, 208 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 1f3c373c1e8..1cde1fcdbb7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -54,9 +54,8 @@ public class MixedTensor implements Tensor { public double get(TensorAddress address) { long cellIndex = index.indexOf(address); Cell cell = cells.get((int)cellIndex); - if (!address.equals(cell.getKey())) { - throw new IllegalStateException("Unable to find correct cell by direct index."); - } + if ( ! address.equals(cell.getKey())) + throw new IllegalStateException("Unable to find correct cell in " + this + " by direct index " + address); return cell.getValue(); } @@ -375,9 +374,8 @@ public class MixedTensor implements Tensor { public long indexOf(TensorAddress address) { TensorAddress sparsePart = sparsePartialAddress(address); - if ( ! sparseMap.containsKey(sparsePart)) { - throw new IllegalArgumentException("Address not found"); - } + if ( ! sparseMap.containsKey(sparsePart)) + throw new IllegalArgumentException("Address subspace " + sparsePart + " not found in " + this); long base = sparseMap.get(sparsePart); long offset = denseOffset(address); return base + offset; @@ -414,7 +412,7 @@ public class MixedTensor implements Tensor { TensorType.Dimension dimension = type.dimensions().get(i); if (dimension.isIndexed()) { denseSubspaceSize *= dimension.size().orElseThrow(() -> - new IllegalArgumentException("Unknown size of indexed dimension.")); + new IllegalArgumentException("Unknown size of indexed dimension")); } } } @@ -422,15 +420,13 @@ public class MixedTensor implements Tensor { } private TensorAddress sparsePartialAddress(TensorAddress address) { - if (type.dimensions().size() != address.size()) { - throw new IllegalArgumentException("Tensor type and address are not of same size."); - } + if (type.dimensions().size() != address.size()) + throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + address); TensorAddress.Builder builder = new TensorAddress.Builder(sparseType); for (int i = 0; i < type.dimensions().size(); ++i) { TensorType.Dimension dimension = type.dimensions().get(i); - if (!dimension.isIndexed()) { + if ( ! dimension.isIndexed()) builder.add(dimension.name(), address.label(i)); - } } return builder.build(); } @@ -488,6 +484,11 @@ public class MixedTensor implements Tensor { return TensorAddress.of(labels); } + @Override + public String toString() { + return "indexes into " + type; + } + } public static TensorType createPartialType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java index 9c41d5aad68..4eca9c47402 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import java.util.Arrays; + /** * An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors * dimensions. @@ -13,10 +15,10 @@ package com.yahoo.tensor; // - We can add support for string labels later without breaking the API public class PartialAddress { - // Two arrays which contains corresponding dimension=label pairs. + // Two arrays which contains corresponding dimension:label pairs. // The sizes of these are always equal. private final String[] dimensionNames; - private final long[] labels; + private final Object[] labels; private PartialAddress(Builder builder) { this.dimensionNames = builder.dimensionNames; @@ -25,23 +27,99 @@ public class PartialAddress { builder.labels = null; } - /** Returns the int label of this dimension, or -1 if no label is specified for it */ - long numericLabel(String dimensionName) { + public String dimension(int i) { + return dimensionNames[i]; + } + + /** Returns the numeric label of this dimension, or -1 if no label is specified for it */ + public long numericLabel(String dimensionName) { for (int i = 0; i < dimensionNames.length; i++) if (dimensionNames[i].equals(dimensionName)) - return labels[i]; + return asLong(labels[i]); return -1; } + /** Returns the label of this dimension, or null if no label is specified for it */ + public String label(String dimensionName) { + for (int i = 0; i < dimensionNames.length; i++) + if (dimensionNames[i].equals(dimensionName)) + return labels[i].toString(); + return null; + } + + /** + * Returns the label at position i + * + * @throws IllegalArgumentException if i is out of bounds + */ + public String label(int i) { + if (i >= size()) + throw new IllegalArgumentException("No label at position " + i + " in " + this); + return labels[i].toString(); + } + + public int size() { return dimensionNames.length; } + + /** Returns this as an address in the given tensor type */ + // We need the type here not just for validation but because this must map to the dimension order given by the type + public TensorAddress asAddress(TensorType type) { + if (type.rank() != size()) + throw new IllegalArgumentException(type + " has a different rank than " + this); + if (Arrays.stream(labels).allMatch(l -> l instanceof Long)) { + long[] numericLabels = new long[labels.length]; + for (int i = 0; i < type.dimensions().size(); i++) { + long label = numericLabel(type.dimensions().get(i).name()); + if (label < 0) + throw new IllegalArgumentException(type + " dimension names does not match " + this); + numericLabels[i] = label; + } + return TensorAddress.of(numericLabels); + } + else { + String[] stringLabels = new String[labels.length]; + for (int i = 0; i < type.dimensions().size(); i++) { + String label = label(type.dimensions().get(i).name()); + if (label == null) + throw new IllegalArgumentException(type + " dimension names does not match " + this); + stringLabels[i] = label; + } + return TensorAddress.of(stringLabels); + } + } + + private long asLong(Object label) { + if (label instanceof Long) { + return (Long) label; + } + else { + try { + return Long.parseLong(label.toString()); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Label '" + label + "' is not numeric"); + } + } + } + + @Override + public String toString() { + StringBuilder b = new StringBuilder("Partial address {"); + for (int i = 0; i < dimensionNames.length; i++) + b.append(dimensionNames[i]).append(":").append(label(i)).append(", "); + if (size() > 0) + b.setLength(b.length() - 2); + return b.toString(); + } + public static class Builder { private String[] dimensionNames; - private long[] labels; + private Object[] labels; private int index = 0; public Builder(int size) { dimensionNames = new String[size]; - labels = new long[size]; + labels = new Object[size]; } public void add(String dimensionName, long label) { @@ -50,6 +128,12 @@ public class PartialAddress { index++; } + public void add(String dimensionName, String label) { + dimensionNames[index] = dimensionName; + labels[index] = label; + index++; + } + public PartialAddress build() { return new PartialAddress(this); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java new file mode 100644 index 00000000000..4d3989b8782 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -0,0 +1,266 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.google.common.annotations.Beta; +import com.yahoo.tensor.PartialAddress; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Iterator; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Returns a subspace of a tensor + * + * @author bratseth + */ +@Beta +public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { + + private final TensorFunction<NAMETYPE> argument; + private final List<DimensionValue<NAMETYPE>> subspaceAddress; + + /** + * Creates a value function + * + * @param argument the tensor to return a cell value from + * @param subspaceAddress a description of the address of the cell to return the value of. This is not a TensorAddress + * because those require a type, but a type is not resolved until this is evaluated + */ + public Slice(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> subspaceAddress) { + this.argument = Objects.requireNonNull(argument, "Argument cannot be null"); + if (subspaceAddress.size() > 1 && subspaceAddress.stream().anyMatch(c -> c.dimension().isEmpty())) + throw new IllegalArgumentException("Short form of subspace addresses is only supported with a single dimension: " + + "Specify dimension names explicitly instead"); + this.subspaceAddress = subspaceAddress; + } + + @Override + public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); } + + @Override + public Slice<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { + if (arguments.size() != 1) + throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size()); + return new Slice<>(arguments.get(0), subspaceAddress); + } + + @Override + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } + + @Override + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { + Tensor tensor = argument.evaluate(context); + TensorType resultType = resultType(tensor.type()); + + PartialAddress subspaceAddress = subspaceToAddress(tensor.type(), context); + if (resultType.rank() == 0) // shortcut common case + return Tensor.from(tensor.get(subspaceAddress.asAddress(tensor.type()))); + + Tensor.Builder b = Tensor.Builder.of(resultType); + for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { + Tensor.Cell cell = i.next(); + if (matches(subspaceAddress, cell.getKey(), tensor.type())) + b.cell(remaining(resultType, cell.getKey(), tensor.type()), cell.getValue()); + } + return b.build(); + } + + private PartialAddress subspaceToAddress(TensorType type, EvaluationContext<NAMETYPE> context) { + PartialAddress.Builder b = new PartialAddress.Builder(subspaceAddress.size()); + for (int i = 0; i < subspaceAddress.size(); i++) { + if (subspaceAddress.get(i).label().isPresent()) + b.add(subspaceAddress.get(i).dimension().orElse(type.dimensions().get(i).name()), + subspaceAddress.get(i).label().get()); + else + b.add(subspaceAddress.get(i).dimension().orElse(type.dimensions().get(i).name()), + subspaceAddress.get(i).index().get().apply(context).intValue()); + } + return b.build(); + } + + private boolean matches(PartialAddress subspaceAddress, + TensorAddress address, TensorType type) { + for (int i = 0; i < subspaceAddress.size(); i++) { + String label = address.label(type.indexOfDimension(subspaceAddress.dimension(i)).get()); + if ( ! label.equals(subspaceAddress.label(i))) + return false; + } + return true; + } + + /** Returns the subset of the given address which is present in the subspace type */ + private TensorAddress remaining(TensorType subspaceType, TensorAddress address, TensorType type) { + TensorAddress.Builder b = new TensorAddress.Builder(subspaceType); + for (int i = 0; i < address.size(); i++) { + String dimension = type.dimensions().get(i).name(); + if (subspaceType.dimension(type.dimensions().get(i).name()).isPresent()) + b.add(dimension, address.label(i)); + } + return b.build(); + } + + @Override + public TensorType type(TypeContext<NAMETYPE> context) { + return resultType(argument.type(context)); + } + + private TensorType resultType(TensorType argumentType) { + TensorType.Builder b = new TensorType.Builder(); + + // Special case where a single indexed or mapped dimension is sliced + if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { + if (subspaceAddress.get(0).index().isPresent()) { + if (argumentType.dimensions().stream().filter(d -> d.isIndexed()).count() > 1) + throw new IllegalArgumentException(this + " slices a single indexed dimension, cannot be applied " + + " to " + argumentType + ", which have multiple"); + for (TensorType.Dimension dimension : argumentType.dimensions()) { + if ( ! dimension.isIndexed()) + b.dimension(dimension); + } + } + else { + if (argumentType.dimensions().stream().filter(d -> ! d.isIndexed()).count() > 1) + throw new IllegalArgumentException(this + " slices a single mapped dimension, cannot be applied " + + " to " + argumentType + ", which have multiple"); + for (TensorType.Dimension dimension : argumentType.dimensions()) { + if (dimension.isIndexed()) + b.dimension(dimension); + } + + } + } + else { // general slicing + Set<String> slicedDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toSet()); + for (TensorType.Dimension dimension : argumentType.dimensions()) { + if (slicedDimensions.contains(dimension.name())) + slicedDimensions.remove(dimension.name()); + else + b.dimension(dimension); + } + if ( ! slicedDimensions.isEmpty()) + throw new IllegalArgumentException(this + " slices " + slicedDimensions + " which are not present in " + + argumentType); + } + return b.build(); + } + + @Override + public String toString(ToStringContext context) { + StringBuilder b = new StringBuilder(argument.toString(context)); + if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { + if (subspaceAddress.get(0).index().isPresent()) + b.append("[").append(subspaceAddress.get(0).index().get().toString(context)).append("]"); + else + b.append("{").append(subspaceAddress.get(0).label().get()).append("}"); + } + else { + b.append("{").append(subspaceAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}"); + } + return b.toString(); + } + + public static class DimensionValue<NAMETYPE extends Name> { + + private final Optional<String> dimension; + + /** The label of this, or null if index is set */ + private final String label; + + /** The function returning the index of this, or null if label is set */ + private final ScalarFunction<NAMETYPE> index; + + public DimensionValue(String dimension, String label) { + this(Optional.of(dimension), label, null); + } + + public DimensionValue(String dimension, int index) { + this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index)); + } + + public DimensionValue(int index) { + this(Optional.empty(), null, new ConstantIntegerFunction<>(index)); + } + + public DimensionValue(String label) { + this(Optional.empty(), label, null); + } + + public DimensionValue(ScalarFunction<NAMETYPE> index) { + this(Optional.empty(), null, index); + } + + public DimensionValue(Optional<String> dimension, String label) { + this(dimension, label, null); + } + + public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) { + this(dimension, null, index); + } + + public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) { + this(Optional.of(dimension), null, index); + } + + private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) { + this.dimension = dimension; + this.label = label; + this.index = index; + } + + /** + * Returns the given name of the dimension, or null if dense form is used, such that name + * must be inferred from order + */ + public Optional<String> dimension() { return dimension; } + + /** Returns the label for this dimension or empty if it is provided by an index function */ + public Optional<String> label() { return Optional.ofNullable(label); } + + /** Returns the index expression for this dimension, or empty if it is not a number */ + public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); } + + @Override + public String toString() { + return toString(null); + } + + public String toString(ToStringContext context) { + StringBuilder b = new StringBuilder(); + dimension.ifPresent(d -> b.append(d).append(":")); + if (label != null) + b.append(label); + else + b.append(index.toString(context)); + return b.toString(); + } + + } + + private static class ConstantIntegerFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> { + + private final int value; + + public ConstantIntegerFunction(int value) { + this.value = value; + } + + @Override + public Double apply(EvaluationContext<NAMETYPE> context) { + return (double)value; + } + + @Override + public String toString() { return String.valueOf(value); } + + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java deleted file mode 100644 index 37a54807673..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.tensor.functions; - -import com.google.common.annotations.Beta; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.EvaluationContext; -import com.yahoo.tensor.evaluation.Name; -import com.yahoo.tensor.evaluation.TypeContext; - -import java.util.List; -import java.util.Objects; -import java.util.Optional; -import java.util.stream.Collectors; - -/** - * Returns the value of a cell of a tensor (as a rank 0 tensor). - * - * @author bratseth - */ -@Beta -public class Value<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { - - private final TensorFunction<NAMETYPE> argument; - private final List<DimensionValue<NAMETYPE>> cellAddress; - - /** - * Creates a value function - * - * @param argument the tensor to return a cell value from - * @param cellAddress a description of the address of the cell to return the value of. This is not a TensorAddress - * because those require a type, but a type is not resolved until this is evaluated - */ - public Value(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> cellAddress) { - this.argument = Objects.requireNonNull(argument, "Argument cannot be null"); - if (cellAddress.size() > 1 && cellAddress.stream().anyMatch(c -> c.dimension().isEmpty())) - throw new IllegalArgumentException("Short form of cell addresses is only supported with a single dimension: " + - "Specify dimension names explicitly"); - this.cellAddress = cellAddress; - } - - @Override - public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); } - - @Override - public Value<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { - if (arguments.size() != 1) - throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size()); - return new Value<>(arguments.get(0), cellAddress); - } - - @Override - public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } - - @Override - public Tensor evaluate(EvaluationContext<NAMETYPE> context) { - Tensor tensor = argument.evaluate(context); - if (tensor.type().rank() != cellAddress.size()) - throw new IllegalArgumentException("Type/address size mismatch: Cannot address a value with " + toString() + - " to a tensor of type " + tensor.type()); - TensorAddress.Builder b = new TensorAddress.Builder(tensor.type()); - for (int i = 0; i < cellAddress.size(); i++) { - if (cellAddress.get(i).label().isPresent()) - b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()), - cellAddress.get(i).label().get()); - else - b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()), - String.valueOf(cellAddress.get(i).index().get().apply(context).intValue())); - } - return Tensor.from(tensor.get(b.build())); - } - - @Override - public TensorType type(TypeContext<NAMETYPE> context) { - return new TensorType.Builder(argument.type(context).valueType()).build(); - } - - @Override - public String toString(ToStringContext context) { - StringBuilder b = new StringBuilder(argument.toString(context)); - if (cellAddress.size() == 1 && cellAddress.get(0).dimension().isEmpty()) { - if (cellAddress.get(0).index().isPresent()) - b.append("[").append(cellAddress.get(0).index().get().toString(context)).append("]"); - else - b.append("{").append(cellAddress.get(0).label().get()).append("}"); - } - else { - b.append("{").append(cellAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}"); - } - return b.toString(); - } - - public static class DimensionValue<NAMETYPE extends Name> { - - private final Optional<String> dimension; - - /** The label of this, or null if index is set */ - private final String label; - - /** The function returning the index of this, or null if label is set */ - private final ScalarFunction<NAMETYPE> index; - - public DimensionValue(String dimension, String label) { - this(Optional.of(dimension), label, null); - } - - public DimensionValue(String dimension, int index) { - this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index)); - } - - public DimensionValue(int index) { - this(Optional.empty(), null, new ConstantIntegerFunction<>(index)); - } - - public DimensionValue(String label) { - this(Optional.empty(), label, null); - } - - public DimensionValue(ScalarFunction<NAMETYPE> index) { - this(Optional.empty(), null, index); - } - - public DimensionValue(Optional<String> dimension, String label) { - this(dimension, label, null); - } - - public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) { - this(dimension, null, index); - } - - public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) { - this(Optional.of(dimension), null, index); - } - - private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) { - this.dimension = dimension; - this.label = label; - this.index = index; - } - - /** - * Returns the given name of the dimension, or null if dense form is used, such that name - * must be inferred from order - */ - public Optional<String> dimension() { return dimension; } - - /** Returns the label for this dimension or empty if it is provided by an index function */ - public Optional<String> label() { return Optional.ofNullable(label); } - - /** Returns the index expression for this dimension, or empty if it is not a number */ - public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); } - - @Override - public String toString() { - return toString(null); - } - - public String toString(ToStringContext context) { - StringBuilder b = new StringBuilder(); - dimension.ifPresent(d -> b.append(d).append(":")); - if (label != null) - b.append(label); - else - b.append(index.toString(context)); - return b.toString(); - } - - } - - private static class ConstantIntegerFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> { - - private final int value; - - public ConstantIntegerFunction(int value) { - this.value = value; - } - - @Override - public Double apply(EvaluationContext<NAMETYPE> context) { - return (double)value; - } - - @Override - public String toString() { return String.valueOf(value); } - - } - -} |