diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2020-01-02 11:35:37 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2020-01-02 11:35:37 +0100 |
commit | fe102598a18b21a859d5b802883ccb2f462962f9 (patch) | |
tree | 70a8d6d239797c18a8634665e2a65bfaabebabba /vespajlib | |
parent | 6d7909e022817be11b5f088cbd1e537d9b71919d (diff) |
Add merge
Diffstat (limited to 'vespajlib')
9 files changed, 211 insertions, 69 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index a4a9a1e1b24..623c965e603 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -896,7 +896,6 @@ "public abstract com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.DimensionSizes dimensionSizes()", "public java.util.Map cells()", - "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public java.lang.String toString()", "public boolean equals(java.lang.Object)", @@ -945,7 +944,6 @@ "public java.util.Iterator valueIterator()", "public java.util.Map cells()", "public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", - "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", @@ -1036,7 +1034,6 @@ "public java.util.Iterator valueIterator()", "public java.util.Map cells()", "public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", - "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", @@ -1157,12 +1154,12 @@ "public double asDouble()", "public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)", - "public abstract com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", "public abstract com.yahoo.tensor.Tensor remove(java.util.Set)", "public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)", "public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])", "public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)", "public com.yahoo.tensor.Tensor join(com.yahoo.tensor.Tensor, java.util.function.DoubleBinaryOperator)", + "public com.yahoo.tensor.Tensor merge(com.yahoo.tensor.Tensor, java.util.function.DoubleBinaryOperator)", "public com.yahoo.tensor.Tensor rename(java.lang.String, java.lang.String)", "public com.yahoo.tensor.Tensor concat(double, java.lang.String)", "public com.yahoo.tensor.Tensor concat(com.yahoo.tensor.Tensor, java.lang.String)", @@ -1327,6 +1324,7 @@ "public abstract com.yahoo.tensor.TensorType$Dimension$Type type()", "public abstract com.yahoo.tensor.TensorType$Dimension withName(java.lang.String)", "public boolean isIndexed()", + "public com.yahoo.tensor.TensorType$Dimension combineWith(java.util.Optional, boolean)", "public abstract java.lang.String toString()", "public boolean equals(java.lang.Object)", "public int hashCode()", @@ -1746,6 +1744,25 @@ ], "fields": [] }, + "com.yahoo.tensor.functions.Merge": { + "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>(com.yahoo.tensor.functions.TensorFunction, com.yahoo.tensor.functions.TensorFunction, java.util.function.DoubleBinaryOperator)", + "public static com.yahoo.tensor.TensorType outputType(com.yahoo.tensor.TensorType, com.yahoo.tensor.TensorType)", + "public java.util.function.DoubleBinaryOperator merger()", + "public java.util.List arguments()", + "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", + "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", + "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", + "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)" + ], + "fields": [] + }, "com.yahoo.tensor.functions.PrimitiveTensorFunction": { "superClass": "com.yahoo.tensor.functions.TensorFunction", "interfaces": [], diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index 202817ece42..632501c7d08 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -38,7 +38,7 @@ public final class DimensionSizes { * @throws IllegalArgumentException if the index is larger than the number of dimensions in this tensor minus one */ public long size(int dimensionIndex) { - if (dimensionIndex <0 || dimensionIndex >= sizes.length) + if (dimensionIndex < 0 || dimensionIndex >= sizes.length) throw new IllegalArgumentException("Illegal dimension index " + dimensionIndex + ": This has " + sizes.length + " dimensions"); return sizes[dimensionIndex]; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index ba3a35e8eda..b255f18cdd4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -197,11 +197,6 @@ public abstract class IndexedTensor implements Tensor { } @Override - public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells) { - throw new IllegalArgumentException("Merge is not supported for indexed tensors"); - } - - @Override public Tensor remove(Set<TensorAddress> addresses) { throw new IllegalArgumentException("Remove is not supported for indexed tensors"); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index 693c4b5f2b0..33f904efd42 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -53,25 +53,6 @@ public class MappedTensor implements Tensor { } @Override - public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) { - - // currently, underlying implementation disallows multiple entries with the same key - - Tensor.Builder builder = Tensor.Builder.of(type()); - for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) { - TensorAddress address = cell.getKey(); - double value = cell.getValue(); - builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value); - } - for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { - if ( ! cells.containsKey(addCell.getKey())) { - builder.cell(addCell.getKey(), addCell.getValue()); - } - } - return builder.build(); - } - - @Override public Tensor remove(Set<TensorAddress> addresses) { Tensor.Builder builder = Tensor.Builder.of(type()); for (Iterator<Tensor.Cell> i = cellIterator(); i.hasNext(); ) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 0c4efe78113..ad4f0fd0dfb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -53,9 +53,11 @@ public class MixedTensor implements Tensor { @Override public double get(TensorAddress address) { long cellIndex = index.indexOf(address); + if (cellIndex < 0) + return Double.NaN; Cell cell = cells.get((int)cellIndex); if ( ! address.equals(cell.getKey())) - throw new IllegalStateException("Unable to find correct cell in " + this + " by direct index " + address); + return Double.NaN; return cell.getValue(); } @@ -71,10 +73,6 @@ public class MixedTensor implements Tensor { return cells.iterator(); } - private Iterable<Cell> cellIterable() { - return this::cellIterator; - } - /** * Returns an iterator over the values of this tensor. * The iteration order is the same as for cellIterator. @@ -113,20 +111,6 @@ public class MixedTensor implements Tensor { } @Override - public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) { - Tensor.Builder builder = Tensor.Builder.of(type()); - for (Cell cell : cellIterable()) { - TensorAddress address = cell.getKey(); - double value = cell.getValue(); - builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value); - } - for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { - builder.cell(addCell.getKey(), addCell.getValue()); - } - return builder.build(); - } - - @Override public Tensor remove(Set<TensorAddress> addresses) { Tensor.Builder builder = Tensor.Builder.of(type()); @@ -380,10 +364,11 @@ public class MixedTensor implements Tensor { this.denseType = createPartialType(type.valueType(), indexedDimensions); } + /** Returns the index of the given address, or -1 if it is not present */ public long indexOf(TensorAddress address) { TensorAddress sparsePart = sparsePartialAddress(address); if ( ! sparseMap.containsKey(sparsePart)) - throw new IllegalArgumentException("Address subspace " + sparsePart + " not found in " + this); + return -1; long base = sparseMap.get(sparsePart); long offset = denseOffset(address); return base + offset; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index cffd41905a1..6245c26b9f4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -12,6 +12,7 @@ import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.L1Normalize; import com.yahoo.tensor.functions.L2Normalize; import com.yahoo.tensor.functions.Matmul; +import com.yahoo.tensor.functions.Merge; import com.yahoo.tensor.functions.Random; import com.yahoo.tensor.functions.Range; import com.yahoo.tensor.functions.Reduce; @@ -124,18 +125,6 @@ public interface Tensor { /** * Returns a new tensor where existing cells in this tensor have been - * modified according to the given operation and cells in the given map. - * In contrast to {@link #modify}, previously non-existing cells are added - * to this tensor. Only valid for sparse or mixed tensors. - * - * @param op how to update overlapping cells - * @param cells cells to merge with this tensor - * @return a new tensor where this tensor is merged with the other - */ - Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells); - - /** - * Returns a new tensor where existing cells in this tensor have been * removed according to the given set of addresses. Only valid for sparse * or mixed tensors. For mixed tensors, addresses are assumed to only * contain the sparse dimensions, as the entire dense subspace is removed. @@ -164,6 +153,10 @@ public interface Tensor { return new Join<>(new ConstantTensor<>(this), new ConstantTensor<>(argument), combinator).evaluate(); } + default Tensor merge(Tensor argument, DoubleBinaryOperator combinator) { + return new Merge<>(new ConstantTensor<>(this), new ConstantTensor<>(argument), combinator).evaluate(); + } + default Tensor rename(String fromDimension, String toDimension) { return new Rename<>(new ConstantTensor<>(this), Collections.singletonList(fromDimension), Collections.singletonList(toDimension)).evaluate(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 58cb151875e..32398c5a1e9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -314,12 +314,13 @@ public class TensorType { /** * Returns the dimension resulting from combining two dimensions having the same name but possibly different - * types. This works by degrading to the type making the fewer promises. - * [N] + [M] = [min(N, M)] + * types: + * + * [N] + [M] = [ minimal ? min(N, M) : max(N, M) ] * [N] + [] = [] * [] + {} = {} */ - Dimension combineWith(Optional<Dimension> other) { + public Dimension combineWith(Optional<Dimension> other, boolean minimal) { if ( ! other.isPresent()) return this; if (this instanceof MappedDimension) return this; if (other.get() instanceof MappedDimension) return other.get(); @@ -329,7 +330,10 @@ public class TensorType { // both are indexed bound IndexedBoundDimension thisIb = (IndexedBoundDimension)this; IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get(); - return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb; + if (minimal) + return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb; + else + return thisIb.size().get() < otherIb.size().get() ? otherIb : thisIb; } @Override @@ -483,7 +487,7 @@ public class TensorType { /** * Creates a builder containing a combination of the dimensions of the given types * - * If the same dimension is indexed with different size restrictions the largest size will be used. + * If the same dimension is indexed with different size restrictions the smallest size will be used. * If it is size restricted in one argument but not the other it will not be size restricted. * If it is indexed in one and mapped in the other it will become mapped. * @@ -516,7 +520,7 @@ public class TensorType { } else { for (Dimension dimension : type.dimensions) - set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name())))); + set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name())), true)); } } @@ -528,7 +532,7 @@ public class TensorType { if (containsMapped) dimension = new MappedDimension(dimension.name()); Dimension existing = dimensions.get(dimension.name()); - set(dimension.combineWith(Optional.ofNullable(existing))); + set(dimension.combineWith(Optional.ofNullable(existing), true)); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java new file mode 100644 index 00000000000..350eaaa16f6 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java @@ -0,0 +1,151 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; +import com.yahoo.tensor.DimensionSizes; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.PartialAddress; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.function.DoubleBinaryOperator; + +/** + * The <i>merge</i> tensor operation produces from two argument tensors having equal dimension names + * a tensor having the same dimension names, with each dimension the smallest (see below) which can encompass all the + * values of both tensors, and where the values are the union of the values of both tensors. In the cases where both + * tensors contain a value for a given cell, and only then, the lambda scalar expression is evaluated to produce + * the resulting cell value. If none of the argument tensors have a value (but the cell exists due to merging + * indexed dimensions of uneven size in multidimensional tensors) the resulting cell is 0. + * <p> + * The type of each dimension of the result tensor is determined as follows: + * <ul> + * <li>If at least one of the two argument dimensions are mapped, the resulting dimension is mapped. + * <li>Otherwise, the size of the resulting (indexed) dimension is the max size of the argument dimensions. + * </ul> + * + * @author bratseth + */ +public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { + + private final TensorFunction<NAMETYPE> argumentA, argumentB; + private final DoubleBinaryOperator merger; + + public Merge(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator merger) { + Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); + Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); + Objects.requireNonNull(merger, "The merger function cannot be null"); + this.argumentA = argumentA; + this.argumentB = argumentB; + this.merger = merger; + } + + /** Returns the type resulting from applying Merge to the two given types */ + public static TensorType outputType(TensorType a, TensorType b) { + if ( ! a.dimensionNames().equals(b.dimensionNames())) + throw new IllegalArgumentException("Cannot merge " + a + " and " + b + + ": Both arguments must have the same dimension names"); + + TensorType.Builder builder = new TensorType.Builder(TensorType.combinedValueType(a, b)); + for (TensorType.Dimension dimension : a.dimensions()) + builder.set(dimension.combineWith(b.dimension(dimension.name()), false)); + return builder.build(); + } + + public DoubleBinaryOperator merger() { return merger; } + + @Override + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); } + + @Override + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { + if ( arguments.size() != 2) + throw new IllegalArgumentException("Merge must have 2 arguments, got " + arguments.size()); + return new Merge<>(arguments.get(0), arguments.get(1), merger); + } + + @Override + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Merge<>(argumentA.toPrimitive(), argumentB.toPrimitive(), merger); + } + + @Override + public String toString(ToStringContext context) { + return "merge(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + merger + ")"; + } + + @Override + public TensorType type(TypeContext<NAMETYPE> context) { + return outputType(argumentA.type(context), argumentB.type(context)); + } + + @Override + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { + Tensor a = argumentA.evaluate(context); + Tensor b = argumentB.evaluate(context); + TensorType mergedType = outputType(a.type(), b.type()); + return evaluate(a, b, mergedType, merger); + } + + static Tensor evaluate(Tensor a, Tensor b, TensorType mergedType, DoubleBinaryOperator combinator) { + // Choose merge algorithm + if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name())) + return indexedVectorMerge((IndexedTensor)a, (IndexedTensor)b, mergedType, combinator); + else + return generalMerge(a, b, mergedType, combinator); + } + + private static boolean hasSingleIndexedDimension(Tensor tensor) { + return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed(); + } + + private static Tensor indexedVectorMerge(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) { + long aSize = a.dimensionSizes().size(0); + long bSize = b.dimensionSizes().size(0); + long mergedSize = Math.max(aSize, bSize); + long sharedSize = Math.min(aSize, bSize); + Iterator<Double> aIterator = a.valueIterator(); + Iterator<Double> bIterator = b.valueIterator(); + IndexedTensor.Builder builder = IndexedTensor.Builder.of(type); + for (long i = 0; i < sharedSize; i++) + builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i); + Iterator<Double> largestIterator = aSize > bSize ? aIterator : bIterator; + for (long i = sharedSize; i < mergedSize; i++) + builder.cell(largestIterator.next(), i); + return builder.build(); + } + + private static Tensor generalMerge(Tensor a, Tensor b, TensorType mergedType, DoubleBinaryOperator combinator) { + Tensor.Builder builder = Tensor.Builder.of(mergedType); + addCellsOf(a, b, builder, combinator); + addCellsOf(b, a, builder, null); + return builder.build(); + } + + private static void addCellsOf(Tensor a, Tensor b, Tensor.Builder builder, DoubleBinaryOperator combinator) { + for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) { + Map.Entry<TensorAddress, Double> aCell = i.next(); + double bCellValue = b.get(aCell.getKey()); + if (Double.isNaN(bCellValue)) + builder.cell(aCell.getKey(), aCell.getValue()); + else if (combinator != null) + builder.cell(aCell.getKey(), combinator.applyAsDouble(aCell.getValue(), bCellValue)); + } + } + +} + diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 7932f90d797..43f9b976978 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -216,6 +216,22 @@ public class TensorTestCase { Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), Tensor.from("tensor(x{},y[3])", "{}"), Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}")); + assertTensorMerge( + Tensor.from("tensor(x[2]):[5,6]"), + Tensor.from("tensor(x[4]):[1,2,3,4]"), + Tensor.from("tensor(x[4]):[1,2,3,4]")); + assertTensorMerge( + Tensor.from("tensor(x[4]):[1,2,3,4]"), + Tensor.from("tensor(x[2]):[5,6]"), + Tensor.from("tensor(x[4]):[5,6,3,4]")); + assertTensorMerge( + Tensor.from("tensor(x[4],y[2]):[[1,2],[3,4],[5,6],[7,8]]"), + Tensor.from("tensor(x[2],y[3]):[[9,10,11],[12,13,14]]"), + Tensor.from("tensor(x[4],y[3]):[[9,10,11],[12,13,14],[5,6,0],[7,8,0]]")); + assertTensorMerge( + Tensor.from("tensor(key{},x[4]):{a:[1,2,3,4],c:[5,6,7,8]}"), + Tensor.from("tensor(key{},x[2]):{a:[9,10],b:[11,12]}"), + Tensor.from("tensor(key{},x[4]):{a:[9,10,3,4],b:[11,12,0,0],c:[5,6,7,8]}")); } @Test @@ -302,7 +318,7 @@ public class TensorTestCase { private void assertTensorMerge(Tensor init, Tensor update, Tensor expected) { DoubleBinaryOperator op = (left, right) -> right; - assertEquals(expected, init.merge(op, update.cells())); + assertEquals(expected, init.merge(update, op)); } private void assertTensorRemove(Tensor init, Tensor update, Tensor expected) { |