diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2020-01-02 11:35:37 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2020-01-02 11:35:37 +0100 |
commit | fe102598a18b21a859d5b802883ccb2f462962f9 (patch) | |
tree | 70a8d6d239797c18a8634665e2a65bfaabebabba /vespajlib/src/main/java/com/yahoo/tensor | |
parent | 6d7909e022817be11b5f088cbd1e537d9b71919d (diff) |
Add merge
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
7 files changed, 173 insertions, 64 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index 202817ece42..632501c7d08 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -38,7 +38,7 @@ public final class DimensionSizes { * @throws IllegalArgumentException if the index is larger than the number of dimensions in this tensor minus one */ public long size(int dimensionIndex) { - if (dimensionIndex <0 || dimensionIndex >= sizes.length) + if (dimensionIndex < 0 || dimensionIndex >= sizes.length) throw new IllegalArgumentException("Illegal dimension index " + dimensionIndex + ": This has " + sizes.length + " dimensions"); return sizes[dimensionIndex]; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index ba3a35e8eda..b255f18cdd4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -197,11 +197,6 @@ public abstract class IndexedTensor implements Tensor { } @Override - public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells) { - throw new IllegalArgumentException("Merge is not supported for indexed tensors"); - } - - @Override public Tensor remove(Set<TensorAddress> addresses) { throw new IllegalArgumentException("Remove is not supported for indexed tensors"); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index 693c4b5f2b0..33f904efd42 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -53,25 +53,6 @@ public class MappedTensor implements Tensor { } @Override - public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) { - - // currently, underlying implementation disallows multiple entries with the same key - - Tensor.Builder builder = Tensor.Builder.of(type()); - for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) { - TensorAddress address = cell.getKey(); - double value = cell.getValue(); - builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value); - } - for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { - if ( ! cells.containsKey(addCell.getKey())) { - builder.cell(addCell.getKey(), addCell.getValue()); - } - } - return builder.build(); - } - - @Override public Tensor remove(Set<TensorAddress> addresses) { Tensor.Builder builder = Tensor.Builder.of(type()); for (Iterator<Tensor.Cell> i = cellIterator(); i.hasNext(); ) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 0c4efe78113..ad4f0fd0dfb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -53,9 +53,11 @@ public class MixedTensor implements Tensor { @Override public double get(TensorAddress address) { long cellIndex = index.indexOf(address); + if (cellIndex < 0) + return Double.NaN; Cell cell = cells.get((int)cellIndex); if ( ! address.equals(cell.getKey())) - throw new IllegalStateException("Unable to find correct cell in " + this + " by direct index " + address); + return Double.NaN; return cell.getValue(); } @@ -71,10 +73,6 @@ public class MixedTensor implements Tensor { return cells.iterator(); } - private Iterable<Cell> cellIterable() { - return this::cellIterator; - } - /** * Returns an iterator over the values of this tensor. * The iteration order is the same as for cellIterator. @@ -113,20 +111,6 @@ public class MixedTensor implements Tensor { } @Override - public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) { - Tensor.Builder builder = Tensor.Builder.of(type()); - for (Cell cell : cellIterable()) { - TensorAddress address = cell.getKey(); - double value = cell.getValue(); - builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value); - } - for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { - builder.cell(addCell.getKey(), addCell.getValue()); - } - return builder.build(); - } - - @Override public Tensor remove(Set<TensorAddress> addresses) { Tensor.Builder builder = Tensor.Builder.of(type()); @@ -380,10 +364,11 @@ public class MixedTensor implements Tensor { this.denseType = createPartialType(type.valueType(), indexedDimensions); } + /** Returns the index of the given address, or -1 if it is not present */ public long indexOf(TensorAddress address) { TensorAddress sparsePart = sparsePartialAddress(address); if ( ! sparseMap.containsKey(sparsePart)) - throw new IllegalArgumentException("Address subspace " + sparsePart + " not found in " + this); + return -1; long base = sparseMap.get(sparsePart); long offset = denseOffset(address); return base + offset; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index cffd41905a1..6245c26b9f4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -12,6 +12,7 @@ import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.L1Normalize; import com.yahoo.tensor.functions.L2Normalize; import com.yahoo.tensor.functions.Matmul; +import com.yahoo.tensor.functions.Merge; import com.yahoo.tensor.functions.Random; import com.yahoo.tensor.functions.Range; import com.yahoo.tensor.functions.Reduce; @@ -124,18 +125,6 @@ public interface Tensor { /** * Returns a new tensor where existing cells in this tensor have been - * modified according to the given operation and cells in the given map. - * In contrast to {@link #modify}, previously non-existing cells are added - * to this tensor. Only valid for sparse or mixed tensors. - * - * @param op how to update overlapping cells - * @param cells cells to merge with this tensor - * @return a new tensor where this tensor is merged with the other - */ - Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells); - - /** - * Returns a new tensor where existing cells in this tensor have been * removed according to the given set of addresses. Only valid for sparse * or mixed tensors. For mixed tensors, addresses are assumed to only * contain the sparse dimensions, as the entire dense subspace is removed. @@ -164,6 +153,10 @@ public interface Tensor { return new Join<>(new ConstantTensor<>(this), new ConstantTensor<>(argument), combinator).evaluate(); } + default Tensor merge(Tensor argument, DoubleBinaryOperator combinator) { + return new Merge<>(new ConstantTensor<>(this), new ConstantTensor<>(argument), combinator).evaluate(); + } + default Tensor rename(String fromDimension, String toDimension) { return new Rename<>(new ConstantTensor<>(this), Collections.singletonList(fromDimension), Collections.singletonList(toDimension)).evaluate(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 58cb151875e..32398c5a1e9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -314,12 +314,13 @@ public class TensorType { /** * Returns the dimension resulting from combining two dimensions having the same name but possibly different - * types. This works by degrading to the type making the fewer promises. - * [N] + [M] = [min(N, M)] + * types: + * + * [N] + [M] = [ minimal ? min(N, M) : max(N, M) ] * [N] + [] = [] * [] + {} = {} */ - Dimension combineWith(Optional<Dimension> other) { + public Dimension combineWith(Optional<Dimension> other, boolean minimal) { if ( ! other.isPresent()) return this; if (this instanceof MappedDimension) return this; if (other.get() instanceof MappedDimension) return other.get(); @@ -329,7 +330,10 @@ public class TensorType { // both are indexed bound IndexedBoundDimension thisIb = (IndexedBoundDimension)this; IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get(); - return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb; + if (minimal) + return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb; + else + return thisIb.size().get() < otherIb.size().get() ? otherIb : thisIb; } @Override @@ -483,7 +487,7 @@ public class TensorType { /** * Creates a builder containing a combination of the dimensions of the given types * - * If the same dimension is indexed with different size restrictions the largest size will be used. + * If the same dimension is indexed with different size restrictions the smallest size will be used. * If it is size restricted in one argument but not the other it will not be size restricted. * If it is indexed in one and mapped in the other it will become mapped. * @@ -516,7 +520,7 @@ public class TensorType { } else { for (Dimension dimension : type.dimensions) - set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name())))); + set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name())), true)); } } @@ -528,7 +532,7 @@ public class TensorType { if (containsMapped) dimension = new MappedDimension(dimension.name()); Dimension existing = dimensions.get(dimension.name()); - set(dimension.combineWith(Optional.ofNullable(existing))); + set(dimension.combineWith(Optional.ofNullable(existing), true)); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java new file mode 100644 index 00000000000..350eaaa16f6 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java @@ -0,0 +1,151 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; +import com.yahoo.tensor.DimensionSizes; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.PartialAddress; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.function.DoubleBinaryOperator; + +/** + * The <i>merge</i> tensor operation produces from two argument tensors having equal dimension names + * a tensor having the same dimension names, with each dimension the smallest (see below) which can encompass all the + * values of both tensors, and where the values are the union of the values of both tensors. In the cases where both + * tensors contain a value for a given cell, and only then, the lambda scalar expression is evaluated to produce + * the resulting cell value. If none of the argument tensors have a value (but the cell exists due to merging + * indexed dimensions of uneven size in multidimensional tensors) the resulting cell is 0. + * <p> + * The type of each dimension of the result tensor is determined as follows: + * <ul> + * <li>If at least one of the two argument dimensions are mapped, the resulting dimension is mapped. + * <li>Otherwise, the size of the resulting (indexed) dimension is the max size of the argument dimensions. + * </ul> + * + * @author bratseth + */ +public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { + + private final TensorFunction<NAMETYPE> argumentA, argumentB; + private final DoubleBinaryOperator merger; + + public Merge(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator merger) { + Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); + Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); + Objects.requireNonNull(merger, "The merger function cannot be null"); + this.argumentA = argumentA; + this.argumentB = argumentB; + this.merger = merger; + } + + /** Returns the type resulting from applying Merge to the two given types */ + public static TensorType outputType(TensorType a, TensorType b) { + if ( ! a.dimensionNames().equals(b.dimensionNames())) + throw new IllegalArgumentException("Cannot merge " + a + " and " + b + + ": Both arguments must have the same dimension names"); + + TensorType.Builder builder = new TensorType.Builder(TensorType.combinedValueType(a, b)); + for (TensorType.Dimension dimension : a.dimensions()) + builder.set(dimension.combineWith(b.dimension(dimension.name()), false)); + return builder.build(); + } + + public DoubleBinaryOperator merger() { return merger; } + + @Override + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); } + + @Override + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { + if ( arguments.size() != 2) + throw new IllegalArgumentException("Merge must have 2 arguments, got " + arguments.size()); + return new Merge<>(arguments.get(0), arguments.get(1), merger); + } + + @Override + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Merge<>(argumentA.toPrimitive(), argumentB.toPrimitive(), merger); + } + + @Override + public String toString(ToStringContext context) { + return "merge(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + merger + ")"; + } + + @Override + public TensorType type(TypeContext<NAMETYPE> context) { + return outputType(argumentA.type(context), argumentB.type(context)); + } + + @Override + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { + Tensor a = argumentA.evaluate(context); + Tensor b = argumentB.evaluate(context); + TensorType mergedType = outputType(a.type(), b.type()); + return evaluate(a, b, mergedType, merger); + } + + static Tensor evaluate(Tensor a, Tensor b, TensorType mergedType, DoubleBinaryOperator combinator) { + // Choose merge algorithm + if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name())) + return indexedVectorMerge((IndexedTensor)a, (IndexedTensor)b, mergedType, combinator); + else + return generalMerge(a, b, mergedType, combinator); + } + + private static boolean hasSingleIndexedDimension(Tensor tensor) { + return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed(); + } + + private static Tensor indexedVectorMerge(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) { + long aSize = a.dimensionSizes().size(0); + long bSize = b.dimensionSizes().size(0); + long mergedSize = Math.max(aSize, bSize); + long sharedSize = Math.min(aSize, bSize); + Iterator<Double> aIterator = a.valueIterator(); + Iterator<Double> bIterator = b.valueIterator(); + IndexedTensor.Builder builder = IndexedTensor.Builder.of(type); + for (long i = 0; i < sharedSize; i++) + builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i); + Iterator<Double> largestIterator = aSize > bSize ? aIterator : bIterator; + for (long i = sharedSize; i < mergedSize; i++) + builder.cell(largestIterator.next(), i); + return builder.build(); + } + + private static Tensor generalMerge(Tensor a, Tensor b, TensorType mergedType, DoubleBinaryOperator combinator) { + Tensor.Builder builder = Tensor.Builder.of(mergedType); + addCellsOf(a, b, builder, combinator); + addCellsOf(b, a, builder, null); + return builder.build(); + } + + private static void addCellsOf(Tensor a, Tensor b, Tensor.Builder builder, DoubleBinaryOperator combinator) { + for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) { + Map.Entry<TensorAddress, Double> aCell = i.next(); + double bCellValue = b.get(aCell.getKey()); + if (Double.isNaN(bCellValue)) + builder.cell(aCell.getKey(), aCell.getValue()); + else if (combinator != null) + builder.cell(aCell.getKey(), combinator.applyAsDouble(aCell.getValue(), bCellValue)); + } + } + +} + |