// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; import com.yahoo.tensor.functions.Argmax; import com.yahoo.tensor.functions.Argmin; import com.yahoo.tensor.functions.Concat; import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Diag; import com.yahoo.tensor.functions.Generate; import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.L1Normalize; import com.yahoo.tensor.functions.L2Normalize; import com.yahoo.tensor.functions.Matmul; import com.yahoo.tensor.functions.Random; import com.yahoo.tensor.functions.Range; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.Softmax; import com.yahoo.tensor.functions.XwPlusB; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; import java.util.function.Function; /** * A multidimensional array which can be used in computations. *

 * A tensor consists of a set of dimension names and a set of cells containing scalar values.
 * Each cell is identified by its address, which consists of a set of dimension-label pairs which define
 * the location of that cell. Both dimensions and labels are strings in the form of an identifier or an integer.
 *

* The size of the set of dimensions of a tensor is called its order. *

 * In contrast to regular mathematical formulations of tensors, this definition of a tensor allows sparseness
 * as there is no built-in notion of a contiguous space, and even in cases where a space is implied (such as when
 * address labels are integers), there is no requirement that every implied cell has a defined value.
 * Undefined values have no defined representation as they are never observed.
 *

* Tensors can be read and serialized to and from a string form documented in the {@link #toString} method.
 *
 * @author bratseth
 */
public interface Tensor {

    // ----------------- Accessors

    /** Returns the type of this tensor */
    TensorType type();

    /** Returns whether this has no cells */
    default boolean isEmpty() { return size() == 0; }

    /** Returns the number of cells in this */
    long size();

    /** Returns the value of a cell, or NaN if this cell does not exist/have no value */
    double get(TensorAddress address);

    /** Returns the cells of this in some undefined order */
    Iterator<Cell> cellIterator();

    /** Returns the values of this in some undefined order */
    Iterator<Double> valueIterator();

    /**
     * Returns an immutable map of the cells of this in no particular order.
     * This may be expensive for some implementations - avoid when possible
     */
    Map<TensorAddress, Double> cells();

    /**
     * Returns the value of this as a double if it has no dimensions.
     * If this tensor has no dimensions and no cells, NaN is returned.
     *
     * @throws IllegalStateException if this tensor has one or more dimensions
     */
    default double asDouble() {
        if (type().dimensions().size() > 0)
            throw new IllegalStateException("This tensor is not dimensionless. Dimensions: " + type().dimensions().size());
        if (size() == 0) return Double.NaN;
        return valueIterator().next();
    }

    /**
     * Returns this tensor with the given type if types are compatible
     *
     * @throws IllegalArgumentException if types are not compatible
     */
    Tensor withType(TensorType type);

    /**
     * Returns a new tensor where existing cells in this tensor have been
     * modified according to the given operation and cells in the given map.
     * Cells in the map outside of existing cells are thus ignored.
     *
     * @param op the modifying function, called as op(existingValue, mapValue)
     * @param cells the cells to modify
     * @return a new tensor with modified cells
     */
    default Tensor modify(DoubleBinaryOperator op, Map<TensorAddress, Double> cells) {
        Tensor.Builder builder = Tensor.Builder.of(type());
        for (Iterator<Cell> i = cellIterator(); i.hasNext(); ) {
            Cell cell = i.next();
            TensorAddress address = cell.getKey();
            double value = cell.getValue();
            Double update = cells.get(address); // single lookup instead of containsKey + get
            builder.cell(address, update != null ? op.applyAsDouble(value, update) : value);
        }
        return builder.build();
    }

    /**
     * Returns a new tensor where existing cells in this tensor have been
     * modified according to the given operation and cells in the given map.
     * In contrast to {@link #modify}, previously non-existing cells are added
     * to this tensor. Only valid for sparse or mixed tensors.
     *
     * @param op how to update overlapping cells
     * @param cells cells to merge with this tensor
     * @return a new tensor where this tensor is merged with the other
     */
    Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells);

    /**
     * Returns a new tensor where existing cells in this tensor have been
     * removed according to the given set of addresses. Only valid for sparse
     * or mixed tensors. For mixed tensors, addresses are assumed to only
     * contain the sparse dimensions, as the entire dense subspace is removed.
     *
     * @param addresses list of addresses to remove
     * @return a new tensor where cells have been removed
     */
    Tensor remove(Set<TensorAddress> addresses);

    // ----------------- Primitive tensor functions

    /** Returns a tensor where the given function has been applied to every cell value of this */
    default Tensor map(DoubleUnaryOperator mapper) {
        return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate();
    }

    /** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */
    default Tensor reduce(Reduce.Aggregator aggregator, String ... dimensions) {
        return new Reduce(new ConstantTensor(this), aggregator, Arrays.asList(dimensions)).evaluate();
    }

    /** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */
    default Tensor reduce(Reduce.Aggregator aggregator, List<String> dimensions) {
        return new Reduce(new ConstantTensor(this), aggregator, dimensions).evaluate();
    }

    /** Returns the join of this and the argument, combining matching cell values with the given combinator */
    default Tensor join(Tensor argument, DoubleBinaryOperator combinator) {
        return new Join(new ConstantTensor(this), new ConstantTensor(argument), combinator).evaluate();
    }

    /** Returns this with a single dimension renamed */
    default Tensor rename(String fromDimension, String toDimension) {
        return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension),
                          Collections.singletonList(toDimension)).evaluate();
    }

    /** Concatenates a dimensionless tensor holding the given value to this along the given dimension */
    default Tensor concat(double argument, String dimension) {
        return concat(Tensor.Builder.of(TensorType.empty).cell(argument).build(), dimension);
    }

    /** Concatenates the argument to this along the given dimension */
    default Tensor concat(Tensor argument, String dimension) {
        return new Concat(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate();
    }

    /** Returns this with the given dimensions renamed, pairwise by position in the two lists */
    default Tensor rename(List<String> fromDimensions, List<String> toDimensions) {
        return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate();
    }

    /** Returns a tensor of the given type whose cell values are produced by the given supplier */
    static Tensor generate(TensorType type, Function<List<Long>, Double> valueSupplier) {
        return new Generate(type, valueSupplier).evaluate();
    }

    // ----------------- Composite tensor functions which have a defined primitive mapping

    default Tensor l1Normalize(String dimension) {
        return new L1Normalize(new ConstantTensor(this), dimension).evaluate();
    }

    default Tensor l2Normalize(String dimension) {
        return new L2Normalize(new ConstantTensor(this), dimension).evaluate();
    }

    default Tensor matmul(Tensor argument, String dimension) {
        return new Matmul(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate();
    }

    default Tensor softmax(String dimension) {
        return new Softmax(new ConstantTensor(this), dimension).evaluate();
    }

    default Tensor xwPlusB(Tensor w, Tensor b, String dimension) {
        return new XwPlusB(new ConstantTensor(this), new ConstantTensor(w), new ConstantTensor(b), dimension).evaluate();
    }

    default Tensor argmax(String dimension) {
        return new Argmax(new ConstantTensor(this), dimension).evaluate();
    }

    default Tensor argmin(String dimension) {
        return new Argmin(new ConstantTensor(this), dimension).evaluate();
    }

    static Tensor diag(TensorType type) { return new Diag(type).evaluate(); }

    static Tensor random(TensorType type) { return new Random(type).evaluate(); }

    static Tensor range(TensorType type) { return new Range(type).evaluate(); }

    // ----------------- Composite tensor functions mapped to primitives here on the fly

    default Tensor multiply(Tensor argument) { return join(argument, (a, b) -> (a * b )); }
    default Tensor add(Tensor argument) { return join(argument, (a, b) -> (a + b )); }
    default Tensor divide(Tensor argument) { return join(argument, (a, b) -> (a / b )); }
    default Tensor subtract(Tensor argument) { return join(argument, (a, b) -> (a - b )); }

    /** Cell-wise maximum; note that the argument's value is chosen whenever either value is NaN */
    default Tensor max(Tensor argument) { return join(argument, (a, b) -> (a > b ? a : b )); }

    /** Cell-wise minimum; note that the argument's value is chosen whenever either value is NaN */
    default Tensor min(Tensor argument) { return join(argument, (a, b) -> (a < b ? a : b )); }

    default Tensor atan2(Tensor argument) { return join(argument, Math::atan2); }
    default Tensor pow(Tensor argument) { return join(argument, Math::pow); }
    default Tensor fmod(Tensor argument) { return join(argument, (a, b) -> ( a % b )); }

    /**
     * Cell-wise a * 2^(int)b. Uses {@link Math#scalb} rather than multiplying by Math.pow(2, b),
     * so that an intermediate power of two which overflows to infinity or underflows to zero
     * (e.g. b &gt; 1023) does not corrupt a result which is itself representable.
     */
    default Tensor ldexp(Tensor argument) { return join(argument, (a, b) -> Math.scalb(a, (int) b)); }

    // Comparisons produce 1.0 for true and 0.0 for false
    default Tensor larger(Tensor argument) { return join(argument, (a, b) -> ( a > b ? 1.0 : 0.0)); }
    default Tensor largerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a >= b ? 1.0 : 0.0)); }
    default Tensor smaller(Tensor argument) { return join(argument, (a, b) -> ( a < b ? 1.0 : 0.0)); }
    default Tensor smallerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a <= b ? 1.0 : 0.0)); }
    default Tensor equal(Tensor argument) { return join(argument, (a, b) -> ( a == b ? 1.0 : 0.0)); }
    default Tensor notEqual(Tensor argument) { return join(argument, (a, b) -> ( a != b ? 1.0 : 0.0)); }
    default Tensor approxEqual(Tensor argument) { return join(argument, (a, b) -> ( approxEquals(a,b) ? 1.0 : 0.0)); }

    // Reductions: the no-argument forms reduce over all dimensions
    default Tensor avg() { return avg(Collections.emptyList()); }
    default Tensor avg(String dimension) { return avg(Collections.singletonList(dimension)); }
    default Tensor avg(List<String> dimensions) { return reduce(Reduce.Aggregator.avg, dimensions); }
    default Tensor count() { return count(Collections.emptyList()); }
    default Tensor count(String dimension) { return count(Collections.singletonList(dimension)); }
    default Tensor count(List<String> dimensions) { return reduce(Reduce.Aggregator.count, dimensions); }
    default Tensor max() { return max(Collections.emptyList()); }
    default Tensor max(String dimension) { return max(Collections.singletonList(dimension)); }
    default Tensor max(List<String> dimensions) { return reduce(Reduce.Aggregator.max, dimensions); }
    default Tensor min() { return min(Collections.emptyList()); }
    default Tensor min(String dimension) { return min(Collections.singletonList(dimension)); }
    default Tensor min(List<String> dimensions) { return reduce(Reduce.Aggregator.min, dimensions); }
    default Tensor prod() { return prod(Collections.emptyList()); }
    default Tensor prod(String dimension) { return prod(Collections.singletonList(dimension)); }
    default Tensor prod(List<String> dimensions) { return reduce(Reduce.Aggregator.prod, dimensions); }
    default Tensor sum() { return sum(Collections.emptyList()); }
    default Tensor sum(String dimension) { return sum(Collections.singletonList(dimension)); }
    default Tensor sum(List<String> dimensions) { return reduce(Reduce.Aggregator.sum, dimensions); }

    // ----------------- serialization

    /**
     * Returns this tensor on the standard string format (see {@link #toStandardString}):
     * the tensor type, a colon, and then the content on the form
     * {address1:value1,address2:value2,...}
     * where each address is on the form {dimension1:label1,dimension2:label2,...},
     * and values are numbers.
     * <p>
     * Cells are listed in the natural order of tensor addresses: Increasing size primarily
     * and by element lexical order secondarily.
     * <p>
     * Note that while this is suggestive of JSON, it is not JSON.
     */
    @Override
    String toString();

    /**
     * Call this from toString in implementations to return the standard string format.
     * (toString cannot be a default method because default methods cannot override super methods).
     *
     * @param tensor the tensor to return the standard string format of
     * @return the tensor on the standard string format
     */
    static String toStandardString(Tensor tensor) {
        return tensor.type() + ":" + contentToString(tensor);
    }

    /**
     * Returns the cell content of the given tensor as {address:value,...}, with cells sorted by address.
     * For a dimensionless tensor this is "{value}", or "{}" if it has no cell.
     */
    static String contentToString(Tensor tensor) {
        List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet());
        if (tensor.type().dimensions().isEmpty()) {
            if (cellEntries.isEmpty()) return "{}";
            return "{" + cellEntries.get(0).getValue() + "}";
        }

        cellEntries.sort(java.util.Map.Entry.comparingByKey());
        StringBuilder b = new StringBuilder("{");
        for (java.util.Map.Entry<TensorAddress, Double> cell : cellEntries) {
            if (b.length() > 1) b.append(","); // separator before every cell but the first
            b.append(cell.getKey().toString(tensor.type())).append(":").append(cell.getValue());
        }
        b.append("}");
        return b.toString();
    }

    // ----------------- equality

    /**
     * Returns whether this tensor and the given tensor are mathematically equal:
     * That they have the same dimension *names* and the same content.
     */
    boolean equals(Object o);

    /**
     * Implement here to make this work across implementations.
     * Implementations must override equals and call this because this is an interface and cannot override equals.
     * Cell values are compared with an absolute tolerance of 1e-5.
     */
    static boolean equals(Tensor a, Tensor b) {
        if (a == b) return true;
        if ( ! a.type().mathematicallyEquals(b.type())) return false;
        if ( a.size() != b.size()) return false;
        for (Iterator<Cell> aIterator = a.cellIterator(); aIterator.hasNext(); ) {
            Cell aCell = aIterator.next();
            double aValue = aCell.getValue();
            double bValue = b.get(aCell.getKey()); // NaN if b lacks this cell, which never compares equal
            if ( ! approxEquals(aValue, bValue, 1e-5)) return false;
        }
        return true;
    }

    /** Returns whether the absolute difference between x and y is strictly less than the given tolerance */
    static boolean approxEquals(double x, double y, double tolerance) {
        return Math.abs(x-y) < tolerance;
    }

    /**
     * Returns whether x equals y or is the adjacent double towards y.
     * When |y| &gt; 1 the comparison is instead made between x/y and 1.0, making it relative to the magnitude of y.
     */
    static boolean approxEquals(double x, double y) {
        if (y < -1.0 || y > 1.0) {
            x = Math.nextAfter(x/y, 1.0);
            y = 1.0;
        }
        else {
            x = Math.nextAfter(x, y);
        }
        return x == y;
    }

    // ----------------- Factories

    /**
     * Returns a tensor instance containing the given data on the standard string format returned by toString
     *
     * @param type the type of the tensor to return
     * @param tensorString the tensor on the standard tensor string format
     */
    static Tensor from(TensorType type, String tensorString) {
        return TensorParser.tensorFrom(tensorString, Optional.of(type));
    }

    /**
     * Returns a tensor instance containing the given data on the standard string format returned by toString
     *
     * @param tensorType the type of the tensor to return, as a string on the tensor type format, given in
     *        {@link TensorType#fromSpec}
     * @param tensorString the tensor on the standard tensor string format
     */
    static Tensor from(String tensorType, String tensorString) {
        return TensorParser.tensorFrom(tensorString, Optional.of(TensorType.fromSpec(tensorType)));
    }

    /**
     * Returns a tensor instance containing the given data on the standard string format returned by toString.
     * If a type is not specified it is derived from the first cell of the tensor
     */
    static Tensor from(String tensorString) {
        return TensorParser.tensorFrom(tensorString, Optional.empty());
    }

    /** Returns a double as a tensor: A dimensionless tensor containing the value as its cell */
    static Tensor from(double value) {
        return Tensor.Builder.of(TensorType.empty).cell(value).build();
    }

    /** A cell of a tensor: an address and the value stored at it. Cells are immutable. */
    class Cell implements Map.Entry<TensorAddress, Double> {

        private final TensorAddress address;
        private final Number value;

        Cell(TensorAddress address, Number value) {
            this.address = address;
            this.value = value;
        }

        @Override
        public TensorAddress getKey() { return address; }

        /**
         * Returns the direct index which can be used to locate this cell, or -1 if not available.
         * This is for optimizations mapping between tensors where this is possible without creating a
         * TensorAddress.
         */
        long getDirectIndex() { return -1; }

        /** Returns the value as a double */
        @Override
        public Double getValue() { return value.doubleValue(); }

        /** Returns the value as a float */
        public float getFloatValue() { return value.floatValue(); }

        /** Returns the value as a double */
        public double getDoubleValue() { return value.doubleValue(); }

        /** Always throws, as tensor cells are immutable */
        @Override
        public Double setValue(Double value) {
            throw new UnsupportedOperationException("A tensor cannot be modified");
        }

        /** Equal to any Map.Entry with an equal key and an equal (Double) value, as required by Map.Entry */
        @Override
        public boolean equals(Object o) {
            if (o == this) return true;
            if ( ! ( o instanceof Map.Entry)) return false;
            Map.Entry<?, ?> other = (Map.Entry<?, ?>) o;
            if ( ! this.getValue().equals(other.getValue())) return false;
            if ( ! this.getKey().equals(other.getKey())) return false;
            return true;
        }

        @Override
        public int hashCode() {
            return getKey().hashCode() ^ getValue().hashCode(); // by Map.Entry spec
        }

    }

    interface Builder {

        /** Creates a suitable builder for the given type */
        static Builder of(TensorType type) {
            boolean containsIndexed = type.dimensions().stream().anyMatch(d -> d.isIndexed());
            boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed());
            if (containsIndexed && containsMapped)
                return MixedTensor.Builder.of(type);
            if (containsMapped)
                return MappedTensor.Builder.of(type);
            else // indexed or empty
                return IndexedTensor.Builder.of(type);
        }

        /**
         * Creates a suitable builder for the given type.
         * The dimension sizes are only used when the type has no mapped dimensions (indexed or empty);
         * they are ignored for mapped and mixed types.
         */
        static Builder of(TensorType type, DimensionSizes dimensionSizes) {
            boolean containsIndexed = type.dimensions().stream().anyMatch(d -> d.isIndexed());
            boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed());
            if (containsIndexed && containsMapped)
                return MixedTensor.Builder.of(type);
            if (containsMapped)
                return MappedTensor.Builder.of(type);
            else // indexed or empty
                return IndexedTensor.Builder.of(type, dimensionSizes);
        }

        /** Returns the type this is building */
        TensorType type();

        /** Return a cell builder */
        CellBuilder cell();

        /** Add a cell */
        Builder cell(TensorAddress address, double value);
        Builder cell(TensorAddress address, float value);

        /** Add a cell */
        Builder cell(double value, long ... labels);
        Builder cell(float value, long ... labels);

        /**
         * Add a cell
         *
         * @param cell a cell providing the location at which to add this cell
         * @param value the value to assign to the cell
         */
        default Builder cell(Cell cell, double value) {
            return cell(cell.getKey(), value);
        }

        default Builder cell(Cell cell, float value) {
            return cell(cell.getKey(), value);
        }

        /** Returns the tensor built by this */
        Tensor build();

        /** Builds one cell address label by label, then adds the cell with a value to the owning tensor builder */
        class CellBuilder {

            private final TensorAddress.Builder addressBuilder;
            private final Tensor.Builder tensorBuilder;

            CellBuilder(TensorType type, Tensor.Builder tensorBuilder) {
                addressBuilder = new TensorAddress.Builder(type);
                this.tensorBuilder = tensorBuilder;
            }

            public CellBuilder label(String dimension, String label) {
                addressBuilder.add(dimension, label);
                return this;
            }

            public CellBuilder label(String dimension, long label) {
                return label(dimension, String.valueOf(label));
            }

            public Builder value(double cellValue) {
                return tensorBuilder.cell(addressBuilder.build(), cellValue);
            }

            public Builder value(float cellValue) {
                return tensorBuilder.cell(addressBuilder.build(), cellValue);
            }

        }

    }

}