// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; import com.yahoo.api.annotations.Beta; import com.yahoo.tensor.PartialAddress; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; /** * Returns a subspace of a tensor * * @author bratseth */ @Beta public class Slice extends PrimitiveTensorFunction { private final TensorFunction argument; private final List> subspaceAddress; /** * Creates a value function * * @param argument the tensor to return a cell value from * @param subspaceAddress a description of the address of the cell to return the value of. This is not a TensorAddress * because those require a type, but a type is not resolved until this is evaluated */ public Slice(TensorFunction argument, List> subspaceAddress) { this.argument = Objects.requireNonNull(argument, "Argument cannot be null"); if (subspaceAddress.size() > 1 && subspaceAddress.stream().anyMatch(c -> c.dimension().isEmpty())) throw new IllegalArgumentException("Short form of subspace addresses is only supported with a single dimension: " + "Specify dimension names explicitly instead"); this.subspaceAddress = subspaceAddress; } @Override public List> arguments() { return List.of(argument); } public List> selectorFunctions() { var result = new ArrayList>(); for (var dimVal : subspaceAddress) { dimVal.index().ifPresent(fun -> fun.asTensorFunction().ifPresent(tf -> result.add(tf))); } return result; } public TensorFunction withTransformedFunctions( Function, ScalarFunction> transformer) { List> transformedAddress = new ArrayList<>(); for (var orig : subspaceAddress) { var idxFun = orig.index(); if (idxFun.isPresent()) { var transformed = transformer.apply(idxFun.get()); transformedAddress.add(new DimensionValue(orig.dimension(), transformed)); } else { transformedAddress.add(orig); } } return new Slice<>(argument, transformedAddress); } @Override public Slice withArguments(List> arguments) { if (arguments.size() != 1) throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size()); return new Slice<>(arguments.get(0), subspaceAddress); } @Override public PrimitiveTensorFunction toPrimitive() { return this; } @Override public Tensor evaluate(EvaluationContext context) { Tensor tensor = argument.evaluate(context); TensorType resultType = resultType(tensor.type()); PartialAddress subspaceAddress = subspaceToAddress(tensor.type(), context); if (resultType.rank() == 0) { // shortcut common case return Tensor.from(tensor.get(subspaceAddress.asAddress(tensor.type()))); } Tensor.Builder b = Tensor.Builder.of(resultType); for (Iterator i = tensor.cellIterator(); i.hasNext(); ) { Tensor.Cell cell = i.next(); if (matches(subspaceAddress, cell.getKey(), tensor.type())) b.cell(remaining(resultType, cell.getKey(), tensor.type()), cell.getValue()); } return b.build(); } private PartialAddress subspaceToAddress(TensorType type, EvaluationContext context) { PartialAddress.Builder b = new PartialAddress.Builder(subspaceAddress.size()); for (int i = 0; i < subspaceAddress.size(); i++) { if (subspaceAddress.get(i).label().isPresent()) b.add(subspaceAddress.get(i).dimension().orElse(type.dimensions().get(i).name()), subspaceAddress.get(i).label().get()); else b.add(subspaceAddress.get(i).dimension().orElse(type.dimensions().get(i).name()), subspaceAddress.get(i).index().get().apply(context).intValue()); } return b.build(); } private boolean matches(PartialAddress subspaceAddress, TensorAddress address, TensorType type) { for (int i = 0; i < subspaceAddress.size(); i++) { String label = address.label(type.indexOfDimension(subspaceAddress.dimension(i)).get()); if ( ! label.equals(subspaceAddress.label(i))) return false; } return true; } /** Returns the subset of the given address which is present in the subspace type */ private TensorAddress remaining(TensorType subspaceType, TensorAddress address, TensorType type) { TensorAddress.Builder b = new TensorAddress.Builder(subspaceType); for (int i = 0; i < address.size(); i++) { String dimension = type.dimensions().get(i).name(); if (subspaceType.dimension(type.dimensions().get(i).name()).isPresent()) b.add(dimension, address.label(i)); } return b.build(); } @Override public TensorType type(TypeContext context) { return resultType(argument.type(context)); } private List findDimensions(List dims, Predicate pred) { return dims.stream().filter(pred).map(TensorType.Dimension::name).toList(); } private TensorType resultType(TensorType argumentType) { List peekDimensions; if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { // Special case where a single indexed or mapped dimension is sliced if (subspaceAddress.get(0).index().isPresent()) { peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isIndexed); if (peekDimensions.size() > 1) { throw new IllegalArgumentException(this + " slices a single indexed dimension, cannot be applied " + "to " + argumentType + ", which has multiple"); } } else { peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isMapped); if (peekDimensions.size() > 1) throw new IllegalArgumentException(this + " slices a single mapped dimension, cannot be applied " + "to " + argumentType + ", which has multiple"); } } else { // general slicing peekDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).toList(); } try { return TypeResolver.peek(argumentType, peekDimensions); } catch (IllegalArgumentException e) { throw new IllegalArgumentException(this + " cannot slice type " + argumentType, e); } } @Override public String toString(ToStringContext context) { StringBuilder b = new StringBuilder(argument.toString(context)); if (context.typeContext().isEmpty() && subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { // use short forms if (subspaceAddress.get(0).index().isPresent()) b.append("[").append(subspaceAddress.get(0).index().get().toString(context)).append("]"); else b.append("{").append(subspaceAddress.get(0).label().get()).append("}"); } else { // general form b.append("{").append(subspaceAddress.stream() .map(i -> i.toString(context, this)) .collect(Collectors.joining(", "))).append("}"); } return b.toString(); } @Override public int hashCode() { return Objects.hash("slice", argument, subspaceAddress); } public static class DimensionValue { private final Optional dimension; /** The label of this, or null if index is set */ private final String label; /** The function returning the index of this, or null if label is set */ private final ScalarFunction index; public DimensionValue(String dimension, String label) { this(Optional.of(dimension), label, null); } public DimensionValue(String dimension, int index) { this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index)); } public DimensionValue(int index) { this(Optional.empty(), null, new ConstantIntegerFunction<>(index)); } public DimensionValue(String label) { this(Optional.empty(), label, null); } public DimensionValue(ScalarFunction index) { this(Optional.empty(), null, index); } public DimensionValue(Optional dimension, String label) { this(dimension, label, null); } public DimensionValue(Optional dimension, ScalarFunction index) { this(dimension, null, index); } public DimensionValue(String dimension, ScalarFunction index) { this(Optional.of(dimension), null, index); } private DimensionValue(Optional dimension, String label, ScalarFunction index) { this.dimension = dimension; this.label = label; this.index = index; } /** * Returns the given name of the dimension, or null if dense form is used, such that name * must be inferred from order */ public Optional dimension() { return dimension; } /** Returns the label for this dimension or empty if it is provided by an index function */ public Optional label() { return Optional.ofNullable(label); } /** Returns the index expression for this dimension, or empty if it is not a number */ public Optional> index() { return Optional.ofNullable(index); } @Override public String toString() { return toString(null, null); } String toString(ToStringContext context, Slice owner) { StringBuilder b = new StringBuilder(); Optional dimensionName = dimension; if (context != null && dimensionName.isEmpty()) { // This isn't just toString(): Output canonical form or fail TensorType type = context.typeContext().isPresent() ? owner.argument.type(context.typeContext().get()) : null; if (type == null || type.dimensions().size() != 1) throw new IllegalArgumentException("The tensor dimension name being sliced by " + owner + " cannot be uniquely resolved. Use the full form: " + "'slice{myDimensionName:" + valueToString(context) + "}'"); else dimensionName = Optional.of(type.dimensions().get(0).name()); } dimensionName.ifPresent(d -> b.append(d).append(":")); b.append(valueToString(context)); return b.toString(); } private String valueToString(ToStringContext context) { if (label != null) { return TensorAddress.labelToString(label); } else { return index.toString(context); } } @Override public int hashCode() { return Objects.hash(dimension, label, index); } } private static class ConstantIntegerFunction implements ScalarFunction { private final int value; public ConstantIntegerFunction(int value) { this.value = value; } @Override public Double apply(EvaluationContext context) { return (double)value; } @Override public String toString() { return String.valueOf(value); } @Override public int hashCode() { return Objects.hash("constantIntegerFunction", value); } } }