aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-12-06 08:57:09 -0800
committerJon Bratseth <bratseth@verizonmedia.com>2019-12-06 08:57:09 -0800
commit7ef64a61b4f04a400428fe58ed2475aa37c43d39 (patch)
tree590627375d361e3d879285abb4210e70b84a29b0 /vespajlib/src/main/java/com/yahoo/tensor/functions
parente4b328f4ee05b55131420df7f6b5a3685d5dffa5 (diff)
Generalized Slice tensor function
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java266
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java189
2 files changed, 266 insertions, 189 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
new file mode 100644
index 00000000000..4d3989b8782
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -0,0 +1,266 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.google.common.annotations.Beta;
+import com.yahoo.tensor.PartialAddress;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Returns a subspace of a tensor
+ *
+ * @author bratseth
+ */
+@Beta
+public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
+
+ private final TensorFunction<NAMETYPE> argument;
+ private final List<DimensionValue<NAMETYPE>> subspaceAddress;
+
+ /**
+ * Creates a value function
+ *
+ * @param argument the tensor to return a cell value from
+ * @param subspaceAddress a description of the address of the cell to return the value of. This is not a TensorAddress
+ * because those require a type, but a type is not resolved until this is evaluated
+ */
+ public Slice(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> subspaceAddress) {
+ this.argument = Objects.requireNonNull(argument, "Argument cannot be null");
+ if (subspaceAddress.size() > 1 && subspaceAddress.stream().anyMatch(c -> c.dimension().isEmpty()))
+ throw new IllegalArgumentException("Short form of subspace addresses is only supported with a single dimension: " +
+ "Specify dimension names explicitly instead");
+ this.subspaceAddress = subspaceAddress;
+ }
+
+ @Override
+ public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); }
+
+ @Override
+ public Slice<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
+ if (arguments.size() != 1)
+ throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size());
+ return new Slice<>(arguments.get(0), subspaceAddress);
+ }
+
+ @Override
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
+
+ @Override
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ Tensor tensor = argument.evaluate(context);
+ TensorType resultType = resultType(tensor.type());
+
+ PartialAddress subspaceAddress = subspaceToAddress(tensor.type(), context);
+ if (resultType.rank() == 0) // shortcut common case
+ return Tensor.from(tensor.get(subspaceAddress.asAddress(tensor.type())));
+
+ Tensor.Builder b = Tensor.Builder.of(resultType);
+ for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
+ Tensor.Cell cell = i.next();
+ if (matches(subspaceAddress, cell.getKey(), tensor.type()))
+ b.cell(remaining(resultType, cell.getKey(), tensor.type()), cell.getValue());
+ }
+ return b.build();
+ }
+
+ private PartialAddress subspaceToAddress(TensorType type, EvaluationContext<NAMETYPE> context) {
+ PartialAddress.Builder b = new PartialAddress.Builder(subspaceAddress.size());
+ for (int i = 0; i < subspaceAddress.size(); i++) {
+ if (subspaceAddress.get(i).label().isPresent())
+ b.add(subspaceAddress.get(i).dimension().orElse(type.dimensions().get(i).name()),
+ subspaceAddress.get(i).label().get());
+ else
+ b.add(subspaceAddress.get(i).dimension().orElse(type.dimensions().get(i).name()),
+ subspaceAddress.get(i).index().get().apply(context).intValue());
+ }
+ return b.build();
+ }
+
+ private boolean matches(PartialAddress subspaceAddress,
+ TensorAddress address, TensorType type) {
+ for (int i = 0; i < subspaceAddress.size(); i++) {
+ String label = address.label(type.indexOfDimension(subspaceAddress.dimension(i)).get());
+ if ( ! label.equals(subspaceAddress.label(i)))
+ return false;
+ }
+ return true;
+ }
+
+ /** Returns the subset of the given address which is present in the subspace type */
+ private TensorAddress remaining(TensorType subspaceType, TensorAddress address, TensorType type) {
+ TensorAddress.Builder b = new TensorAddress.Builder(subspaceType);
+ for (int i = 0; i < address.size(); i++) {
+ String dimension = type.dimensions().get(i).name();
+ if (subspaceType.dimension(type.dimensions().get(i).name()).isPresent())
+ b.add(dimension, address.label(i));
+ }
+ return b.build();
+ }
+
+ @Override
+ public TensorType type(TypeContext<NAMETYPE> context) {
+ return resultType(argument.type(context));
+ }
+
+ private TensorType resultType(TensorType argumentType) {
+ TensorType.Builder b = new TensorType.Builder();
+
+ // Special case where a single indexed or mapped dimension is sliced
+ if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) {
+ if (subspaceAddress.get(0).index().isPresent()) {
+ if (argumentType.dimensions().stream().filter(d -> d.isIndexed()).count() > 1)
+ throw new IllegalArgumentException(this + " slices a single indexed dimension, cannot be applied " +
+ " to " + argumentType + ", which have multiple");
+ for (TensorType.Dimension dimension : argumentType.dimensions()) {
+ if ( ! dimension.isIndexed())
+ b.dimension(dimension);
+ }
+ }
+ else {
+ if (argumentType.dimensions().stream().filter(d -> ! d.isIndexed()).count() > 1)
+ throw new IllegalArgumentException(this + " slices a single mapped dimension, cannot be applied " +
+ " to " + argumentType + ", which have multiple");
+ for (TensorType.Dimension dimension : argumentType.dimensions()) {
+ if (dimension.isIndexed())
+ b.dimension(dimension);
+ }
+
+ }
+ }
+ else { // general slicing
+ Set<String> slicedDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toSet());
+ for (TensorType.Dimension dimension : argumentType.dimensions()) {
+ if (slicedDimensions.contains(dimension.name()))
+ slicedDimensions.remove(dimension.name());
+ else
+ b.dimension(dimension);
+ }
+ if ( ! slicedDimensions.isEmpty())
+ throw new IllegalArgumentException(this + " slices " + slicedDimensions + " which are not present in " +
+ argumentType);
+ }
+ return b.build();
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ StringBuilder b = new StringBuilder(argument.toString(context));
+ if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) {
+ if (subspaceAddress.get(0).index().isPresent())
+ b.append("[").append(subspaceAddress.get(0).index().get().toString(context)).append("]");
+ else
+ b.append("{").append(subspaceAddress.get(0).label().get()).append("}");
+ }
+ else {
+ b.append("{").append(subspaceAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}");
+ }
+ return b.toString();
+ }
+
+ public static class DimensionValue<NAMETYPE extends Name> {
+
+ private final Optional<String> dimension;
+
+ /** The label of this, or null if index is set */
+ private final String label;
+
+ /** The function returning the index of this, or null if label is set */
+ private final ScalarFunction<NAMETYPE> index;
+
+ public DimensionValue(String dimension, String label) {
+ this(Optional.of(dimension), label, null);
+ }
+
+ public DimensionValue(String dimension, int index) {
+ this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index));
+ }
+
+ public DimensionValue(int index) {
+ this(Optional.empty(), null, new ConstantIntegerFunction<>(index));
+ }
+
+ public DimensionValue(String label) {
+ this(Optional.empty(), label, null);
+ }
+
+ public DimensionValue(ScalarFunction<NAMETYPE> index) {
+ this(Optional.empty(), null, index);
+ }
+
+ public DimensionValue(Optional<String> dimension, String label) {
+ this(dimension, label, null);
+ }
+
+ public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) {
+ this(dimension, null, index);
+ }
+
+ public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) {
+ this(Optional.of(dimension), null, index);
+ }
+
+ private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) {
+ this.dimension = dimension;
+ this.label = label;
+ this.index = index;
+ }
+
+ /**
+ * Returns the given name of the dimension, or null if dense form is used, such that name
+ * must be inferred from order
+ */
+ public Optional<String> dimension() { return dimension; }
+
+ /** Returns the label for this dimension or empty if it is provided by an index function */
+ public Optional<String> label() { return Optional.ofNullable(label); }
+
+ /** Returns the index expression for this dimension, or empty if it is not a number */
+ public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); }
+
+ @Override
+ public String toString() {
+ return toString(null);
+ }
+
+ public String toString(ToStringContext context) {
+ StringBuilder b = new StringBuilder();
+ dimension.ifPresent(d -> b.append(d).append(":"));
+ if (label != null)
+ b.append(label);
+ else
+ b.append(index.toString(context));
+ return b.toString();
+ }
+
+ }
+
+ private static class ConstantIntegerFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> {
+
+ private final int value;
+
+ public ConstantIntegerFunction(int value) {
+ this.value = value;
+ }
+
+ @Override
+ public Double apply(EvaluationContext<NAMETYPE> context) {
+ return (double)value;
+ }
+
+ @Override
+ public String toString() { return String.valueOf(value); }
+
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java
deleted file mode 100644
index 37a54807673..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java
+++ /dev/null
@@ -1,189 +0,0 @@
-// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor.functions;
-
-import com.google.common.annotations.Beta;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.EvaluationContext;
-import com.yahoo.tensor.evaluation.Name;
-import com.yahoo.tensor.evaluation.TypeContext;
-
-import java.util.List;
-import java.util.Objects;
-import java.util.Optional;
-import java.util.stream.Collectors;
-
-/**
- * Returns the value of a cell of a tensor (as a rank 0 tensor).
- *
- * @author bratseth
- */
-@Beta
-public class Value<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
-
- private final TensorFunction<NAMETYPE> argument;
- private final List<DimensionValue<NAMETYPE>> cellAddress;
-
- /**
- * Creates a value function
- *
- * @param argument the tensor to return a cell value from
- * @param cellAddress a description of the address of the cell to return the value of. This is not a TensorAddress
- * because those require a type, but a type is not resolved until this is evaluated
- */
- public Value(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> cellAddress) {
- this.argument = Objects.requireNonNull(argument, "Argument cannot be null");
- if (cellAddress.size() > 1 && cellAddress.stream().anyMatch(c -> c.dimension().isEmpty()))
- throw new IllegalArgumentException("Short form of cell addresses is only supported with a single dimension: " +
- "Specify dimension names explicitly");
- this.cellAddress = cellAddress;
- }
-
- @Override
- public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); }
-
- @Override
- public Value<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
- if (arguments.size() != 1)
- throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size());
- return new Value<>(arguments.get(0), cellAddress);
- }
-
- @Override
- public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
-
- @Override
- public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
- Tensor tensor = argument.evaluate(context);
- if (tensor.type().rank() != cellAddress.size())
- throw new IllegalArgumentException("Type/address size mismatch: Cannot address a value with " + toString() +
- " to a tensor of type " + tensor.type());
- TensorAddress.Builder b = new TensorAddress.Builder(tensor.type());
- for (int i = 0; i < cellAddress.size(); i++) {
- if (cellAddress.get(i).label().isPresent())
- b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()),
- cellAddress.get(i).label().get());
- else
- b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()),
- String.valueOf(cellAddress.get(i).index().get().apply(context).intValue()));
- }
- return Tensor.from(tensor.get(b.build()));
- }
-
- @Override
- public TensorType type(TypeContext<NAMETYPE> context) {
- return new TensorType.Builder(argument.type(context).valueType()).build();
- }
-
- @Override
- public String toString(ToStringContext context) {
- StringBuilder b = new StringBuilder(argument.toString(context));
- if (cellAddress.size() == 1 && cellAddress.get(0).dimension().isEmpty()) {
- if (cellAddress.get(0).index().isPresent())
- b.append("[").append(cellAddress.get(0).index().get().toString(context)).append("]");
- else
- b.append("{").append(cellAddress.get(0).label().get()).append("}");
- }
- else {
- b.append("{").append(cellAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}");
- }
- return b.toString();
- }
-
- public static class DimensionValue<NAMETYPE extends Name> {
-
- private final Optional<String> dimension;
-
- /** The label of this, or null if index is set */
- private final String label;
-
- /** The function returning the index of this, or null if label is set */
- private final ScalarFunction<NAMETYPE> index;
-
- public DimensionValue(String dimension, String label) {
- this(Optional.of(dimension), label, null);
- }
-
- public DimensionValue(String dimension, int index) {
- this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index));
- }
-
- public DimensionValue(int index) {
- this(Optional.empty(), null, new ConstantIntegerFunction<>(index));
- }
-
- public DimensionValue(String label) {
- this(Optional.empty(), label, null);
- }
-
- public DimensionValue(ScalarFunction<NAMETYPE> index) {
- this(Optional.empty(), null, index);
- }
-
- public DimensionValue(Optional<String> dimension, String label) {
- this(dimension, label, null);
- }
-
- public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) {
- this(dimension, null, index);
- }
-
- public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) {
- this(Optional.of(dimension), null, index);
- }
-
- private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) {
- this.dimension = dimension;
- this.label = label;
- this.index = index;
- }
-
- /**
- * Returns the given name of the dimension, or null if dense form is used, such that name
- * must be inferred from order
- */
- public Optional<String> dimension() { return dimension; }
-
- /** Returns the label for this dimension or empty if it is provided by an index function */
- public Optional<String> label() { return Optional.ofNullable(label); }
-
- /** Returns the index expression for this dimension, or empty if it is not a number */
- public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); }
-
- @Override
- public String toString() {
- return toString(null);
- }
-
- public String toString(ToStringContext context) {
- StringBuilder b = new StringBuilder();
- dimension.ifPresent(d -> b.append(d).append(":"));
- if (label != null)
- b.append(label);
- else
- b.append(index.toString(context));
- return b.toString();
- }
-
- }
-
- private static class ConstantIntegerFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> {
-
- private final int value;
-
- public ConstantIntegerFunction(int value) {
- this.value = value;
- }
-
- @Override
- public Double apply(EvaluationContext<NAMETYPE> context) {
- return (double)value;
- }
-
- @Override
- public String toString() { return String.valueOf(value); }
-
- }
-
-}