diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-06 08:57:09 -0800 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-06 08:57:09 -0800 |
commit | 7ef64a61b4f04a400428fe58ed2475aa37c43d39 (patch) | |
tree | 590627375d361e3d879285abb4210e70b84a29b0 /vespajlib/src/main/java/com/yahoo/tensor/functions | |
parent | e4b328f4ee05b55131420df7f6b5a3685d5dffa5 (diff) |
Generalized Slice tensor function
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java | 266 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java | 189 |
2 files changed, 266 insertions, 189 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java new file mode 100644 index 00000000000..4d3989b8782 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -0,0 +1,266 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.google.common.annotations.Beta; +import com.yahoo.tensor.PartialAddress; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Iterator; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Returns a subspace of a tensor + * + * @author bratseth + */ +@Beta +public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { + + private final TensorFunction<NAMETYPE> argument; + private final List<DimensionValue<NAMETYPE>> subspaceAddress; + + /** + * Creates a value function + * + * @param argument the tensor to return a cell value from + * @param subspaceAddress a description of the address of the cell to return the value of. This is not a TensorAddress + * because those require a type, but a type is not resolved until this is evaluated + */ + public Slice(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> subspaceAddress) { + this.argument = Objects.requireNonNull(argument, "Argument cannot be null"); + if (subspaceAddress.size() > 1 && subspaceAddress.stream().anyMatch(c -> c.dimension().isEmpty())) + throw new IllegalArgumentException("Short form of subspace addresses is only supported with a single dimension: " + + "Specify dimension names explicitly instead"); + this.subspaceAddress = subspaceAddress; + } + + @Override + public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); } + + @Override + public Slice<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { + if (arguments.size() != 1) + throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size()); + return new Slice<>(arguments.get(0), subspaceAddress); + } + + @Override + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } + + @Override + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { + Tensor tensor = argument.evaluate(context); + TensorType resultType = resultType(tensor.type()); + + PartialAddress subspaceAddress = subspaceToAddress(tensor.type(), context); + if (resultType.rank() == 0) // shortcut common case + return Tensor.from(tensor.get(subspaceAddress.asAddress(tensor.type()))); + + Tensor.Builder b = Tensor.Builder.of(resultType); + for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { + Tensor.Cell cell = i.next(); + if (matches(subspaceAddress, cell.getKey(), tensor.type())) + b.cell(remaining(resultType, cell.getKey(), tensor.type()), cell.getValue()); + } + return b.build(); + } + + private PartialAddress subspaceToAddress(TensorType type, EvaluationContext<NAMETYPE> context) { + PartialAddress.Builder b = new PartialAddress.Builder(subspaceAddress.size()); + for (int i = 0; i < subspaceAddress.size(); i++) { + if (subspaceAddress.get(i).label().isPresent()) + b.add(subspaceAddress.get(i).dimension().orElse(type.dimensions().get(i).name()), + subspaceAddress.get(i).label().get()); + else + b.add(subspaceAddress.get(i).dimension().orElse(type.dimensions().get(i).name()), + subspaceAddress.get(i).index().get().apply(context).intValue()); + } + return b.build(); + } + + private boolean matches(PartialAddress subspaceAddress, + TensorAddress address, TensorType type) { + for (int i = 0; i < subspaceAddress.size(); i++) { + String label = address.label(type.indexOfDimension(subspaceAddress.dimension(i)).get()); + if ( ! label.equals(subspaceAddress.label(i))) + return false; + } + return true; + } + + /** Returns the subset of the given address which is present in the subspace type */ + private TensorAddress remaining(TensorType subspaceType, TensorAddress address, TensorType type) { + TensorAddress.Builder b = new TensorAddress.Builder(subspaceType); + for (int i = 0; i < address.size(); i++) { + String dimension = type.dimensions().get(i).name(); + if (subspaceType.dimension(type.dimensions().get(i).name()).isPresent()) + b.add(dimension, address.label(i)); + } + return b.build(); + } + + @Override + public TensorType type(TypeContext<NAMETYPE> context) { + return resultType(argument.type(context)); + } + + private TensorType resultType(TensorType argumentType) { + TensorType.Builder b = new TensorType.Builder(); + + // Special case where a single indexed or mapped dimension is sliced + if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { + if (subspaceAddress.get(0).index().isPresent()) { + if (argumentType.dimensions().stream().filter(d -> d.isIndexed()).count() > 1) + throw new IllegalArgumentException(this + " slices a single indexed dimension, cannot be applied " + + " to " + argumentType + ", which have multiple"); + for (TensorType.Dimension dimension : argumentType.dimensions()) { + if ( ! dimension.isIndexed()) + b.dimension(dimension); + } + } + else { + if (argumentType.dimensions().stream().filter(d -> ! d.isIndexed()).count() > 1) + throw new IllegalArgumentException(this + " slices a single mapped dimension, cannot be applied " + + " to " + argumentType + ", which have multiple"); + for (TensorType.Dimension dimension : argumentType.dimensions()) { + if (dimension.isIndexed()) + b.dimension(dimension); + } + + } + } + else { // general slicing + Set<String> slicedDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toSet()); + for (TensorType.Dimension dimension : argumentType.dimensions()) { + if (slicedDimensions.contains(dimension.name())) + slicedDimensions.remove(dimension.name()); + else + b.dimension(dimension); + } + if ( ! slicedDimensions.isEmpty()) + throw new IllegalArgumentException(this + " slices " + slicedDimensions + " which are not present in " + + argumentType); + } + return b.build(); + } + + @Override + public String toString(ToStringContext context) { + StringBuilder b = new StringBuilder(argument.toString(context)); + if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { + if (subspaceAddress.get(0).index().isPresent()) + b.append("[").append(subspaceAddress.get(0).index().get().toString(context)).append("]"); + else + b.append("{").append(subspaceAddress.get(0).label().get()).append("}"); + } + else { + b.append("{").append(subspaceAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}"); + } + return b.toString(); + } + + public static class DimensionValue<NAMETYPE extends Name> { + + private final Optional<String> dimension; + + /** The label of this, or null if index is set */ + private final String label; + + /** The function returning the index of this, or null if label is set */ + private final ScalarFunction<NAMETYPE> index; + + public DimensionValue(String dimension, String label) { + this(Optional.of(dimension), label, null); + } + + public DimensionValue(String dimension, int index) { + this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index)); + } + + public DimensionValue(int index) { + this(Optional.empty(), null, new ConstantIntegerFunction<>(index)); + } + + public DimensionValue(String label) { + this(Optional.empty(), label, null); + } + + public DimensionValue(ScalarFunction<NAMETYPE> index) { + this(Optional.empty(), null, index); + } + + public DimensionValue(Optional<String> dimension, String label) { + this(dimension, label, null); + } + + public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) { + this(dimension, null, index); + } + + public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) { + this(Optional.of(dimension), null, index); + } + + private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) { + this.dimension = dimension; + this.label = label; + this.index = index; + } + + /** + * Returns the given name of the dimension, or null if dense form is used, such that name + * must be inferred from order + */ + public Optional<String> dimension() { return dimension; } + + /** Returns the label for this dimension or empty if it is provided by an index function */ + public Optional<String> label() { return Optional.ofNullable(label); } + + /** Returns the index expression for this dimension, or empty if it is not a number */ + public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); } + + @Override + public String toString() { + return toString(null); + } + + public String toString(ToStringContext context) { + StringBuilder b = new StringBuilder(); + dimension.ifPresent(d -> b.append(d).append(":")); + if (label != null) + b.append(label); + else + b.append(index.toString(context)); + return b.toString(); + } + + } + + private static class ConstantIntegerFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> { + + private final int value; + + public ConstantIntegerFunction(int value) { + this.value = value; + } + + @Override + public Double apply(EvaluationContext<NAMETYPE> context) { + return (double)value; + } + + @Override + public String toString() { return String.valueOf(value); } + + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java deleted file mode 100644 index 37a54807673..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.tensor.functions; - -import com.google.common.annotations.Beta; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.EvaluationContext; -import com.yahoo.tensor.evaluation.Name; -import com.yahoo.tensor.evaluation.TypeContext; - -import java.util.List; -import java.util.Objects; -import java.util.Optional; -import java.util.stream.Collectors; - -/** - * Returns the value of a cell of a tensor (as a rank 0 tensor). - * - * @author bratseth - */ -@Beta -public class Value<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { - - private final TensorFunction<NAMETYPE> argument; - private final List<DimensionValue<NAMETYPE>> cellAddress; - - /** - * Creates a value function - * - * @param argument the tensor to return a cell value from - * @param cellAddress a description of the address of the cell to return the value of. This is not a TensorAddress - * because those require a type, but a type is not resolved until this is evaluated - */ - public Value(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> cellAddress) { - this.argument = Objects.requireNonNull(argument, "Argument cannot be null"); - if (cellAddress.size() > 1 && cellAddress.stream().anyMatch(c -> c.dimension().isEmpty())) - throw new IllegalArgumentException("Short form of cell addresses is only supported with a single dimension: " + - "Specify dimension names explicitly"); - this.cellAddress = cellAddress; - } - - @Override - public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); } - - @Override - public Value<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { - if (arguments.size() != 1) - throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size()); - return new Value<>(arguments.get(0), cellAddress); - } - - @Override - public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } - - @Override - public Tensor evaluate(EvaluationContext<NAMETYPE> context) { - Tensor tensor = argument.evaluate(context); - if (tensor.type().rank() != cellAddress.size()) - throw new IllegalArgumentException("Type/address size mismatch: Cannot address a value with " + toString() + - " to a tensor of type " + tensor.type()); - TensorAddress.Builder b = new TensorAddress.Builder(tensor.type()); - for (int i = 0; i < cellAddress.size(); i++) { - if (cellAddress.get(i).label().isPresent()) - b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()), - cellAddress.get(i).label().get()); - else - b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()), - String.valueOf(cellAddress.get(i).index().get().apply(context).intValue())); - } - return Tensor.from(tensor.get(b.build())); - } - - @Override - public TensorType type(TypeContext<NAMETYPE> context) { - return new TensorType.Builder(argument.type(context).valueType()).build(); - } - - @Override - public String toString(ToStringContext context) { - StringBuilder b = new StringBuilder(argument.toString(context)); - if (cellAddress.size() == 1 && cellAddress.get(0).dimension().isEmpty()) { - if (cellAddress.get(0).index().isPresent()) - b.append("[").append(cellAddress.get(0).index().get().toString(context)).append("]"); - else - b.append("{").append(cellAddress.get(0).label().get()).append("}"); - } - else { - b.append("{").append(cellAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}"); - } - return b.toString(); - } - - public static class DimensionValue<NAMETYPE extends Name> { - - private final Optional<String> dimension; - - /** The label of this, or null if index is set */ - private final String label; - - /** The function returning the index of this, or null if label is set */ - private final ScalarFunction<NAMETYPE> index; - - public DimensionValue(String dimension, String label) { - this(Optional.of(dimension), label, null); - } - - public DimensionValue(String dimension, int index) { - this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index)); - } - - public DimensionValue(int index) { - this(Optional.empty(), null, new ConstantIntegerFunction<>(index)); - } - - public DimensionValue(String label) { - this(Optional.empty(), label, null); - } - - public DimensionValue(ScalarFunction<NAMETYPE> index) { - this(Optional.empty(), null, index); - } - - public DimensionValue(Optional<String> dimension, String label) { - this(dimension, label, null); - } - - public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) { - this(dimension, null, index); - } - - public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) { - this(Optional.of(dimension), null, index); - } - - private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) { - this.dimension = dimension; - this.label = label; - this.index = index; - } - - /** - * Returns the given name of the dimension, or null if dense form is used, such that name - * must be inferred from order - */ - public Optional<String> dimension() { return dimension; } - - /** Returns the label for this dimension or empty if it is provided by an index function */ - public Optional<String> label() { return Optional.ofNullable(label); } - - /** Returns the index expression for this dimension, or empty if it is not a number */ - public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); } - - @Override - public String toString() { - return toString(null); - } - - public String toString(ToStringContext context) { - StringBuilder b = new StringBuilder(); - dimension.ifPresent(d -> b.append(d).append(":")); - if (label != null) - b.append(label); - else - b.append(index.toString(context)); - return b.toString(); - } - - } - - private static class ConstantIntegerFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> { - - private final int value; - - public ConstantIntegerFunction(int value) { - this.value = value; - } - - @Override - public Double apply(EvaluationContext<NAMETYPE> context) { - return (double)value; - } - - @Override - public String toString() { return String.valueOf(value); } - - } - -} |