summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2020-01-02 11:35:37 +0100
committerJon Bratseth <bratseth@verizonmedia.com>2020-01-02 11:35:37 +0100
commitfe102598a18b21a859d5b802883ccb2f462962f9 (patch)
tree70a8d6d239797c18a8634665e2a65bfaabebabba /vespajlib
parent6d7909e022817be11b5f088cbd1e537d9b71919d (diff)
Add merge
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json25
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java19
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java25
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java17
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java151
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java18
9 files changed, 211 insertions, 69 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index a4a9a1e1b24..623c965e603 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -896,7 +896,6 @@
"public abstract com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.DimensionSizes dimensionSizes()",
"public java.util.Map cells()",
- "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
"public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -945,7 +944,6 @@
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
- "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
"public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
@@ -1036,7 +1034,6 @@
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
- "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
"public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
@@ -1157,12 +1154,12 @@
"public double asDouble()",
"public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)",
- "public abstract com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
"public abstract com.yahoo.tensor.Tensor remove(java.util.Set)",
"public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)",
"public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])",
"public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)",
"public com.yahoo.tensor.Tensor join(com.yahoo.tensor.Tensor, java.util.function.DoubleBinaryOperator)",
+ "public com.yahoo.tensor.Tensor merge(com.yahoo.tensor.Tensor, java.util.function.DoubleBinaryOperator)",
"public com.yahoo.tensor.Tensor rename(java.lang.String, java.lang.String)",
"public com.yahoo.tensor.Tensor concat(double, java.lang.String)",
"public com.yahoo.tensor.Tensor concat(com.yahoo.tensor.Tensor, java.lang.String)",
@@ -1327,6 +1324,7 @@
"public abstract com.yahoo.tensor.TensorType$Dimension$Type type()",
"public abstract com.yahoo.tensor.TensorType$Dimension withName(java.lang.String)",
"public boolean isIndexed()",
+ "public com.yahoo.tensor.TensorType$Dimension combineWith(java.util.Optional, boolean)",
"public abstract java.lang.String toString()",
"public boolean equals(java.lang.Object)",
"public int hashCode()",
@@ -1746,6 +1744,25 @@
],
"fields": []
},
+ "com.yahoo.tensor.functions.Merge": {
+ "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>(com.yahoo.tensor.functions.TensorFunction, com.yahoo.tensor.functions.TensorFunction, java.util.function.DoubleBinaryOperator)",
+ "public static com.yahoo.tensor.TensorType outputType(com.yahoo.tensor.TensorType, com.yahoo.tensor.TensorType)",
+ "public java.util.function.DoubleBinaryOperator merger()",
+ "public java.util.List arguments()",
+ "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
+ "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
+ "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.functions.PrimitiveTensorFunction": {
"superClass": "com.yahoo.tensor.functions.TensorFunction",
"interfaces": [],
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index 202817ece42..632501c7d08 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -38,7 +38,7 @@ public final class DimensionSizes {
* @throws IllegalArgumentException if the index is larger than the number of dimensions in this tensor minus one
*/
public long size(int dimensionIndex) {
- if (dimensionIndex <0 || dimensionIndex >= sizes.length)
+ if (dimensionIndex < 0 || dimensionIndex >= sizes.length)
throw new IllegalArgumentException("Illegal dimension index " + dimensionIndex +
": This has " + sizes.length + " dimensions");
return sizes[dimensionIndex];
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index ba3a35e8eda..b255f18cdd4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -197,11 +197,6 @@ public abstract class IndexedTensor implements Tensor {
}
@Override
- public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells) {
- throw new IllegalArgumentException("Merge is not supported for indexed tensors");
- }
-
- @Override
public Tensor remove(Set<TensorAddress> addresses) {
throw new IllegalArgumentException("Remove is not supported for indexed tensors");
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index 693c4b5f2b0..33f904efd42 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -53,25 +53,6 @@ public class MappedTensor implements Tensor {
}
@Override
- public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) {
-
- // currently, underlying implementation disallows multiple entries with the same key
-
- Tensor.Builder builder = Tensor.Builder.of(type());
- for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) {
- TensorAddress address = cell.getKey();
- double value = cell.getValue();
- builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value);
- }
- for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) {
- if ( ! cells.containsKey(addCell.getKey())) {
- builder.cell(addCell.getKey(), addCell.getValue());
- }
- }
- return builder.build();
- }
-
- @Override
public Tensor remove(Set<TensorAddress> addresses) {
Tensor.Builder builder = Tensor.Builder.of(type());
for (Iterator<Tensor.Cell> i = cellIterator(); i.hasNext(); ) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 0c4efe78113..ad4f0fd0dfb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -53,9 +53,11 @@ public class MixedTensor implements Tensor {
@Override
public double get(TensorAddress address) {
long cellIndex = index.indexOf(address);
+ if (cellIndex < 0)
+ return Double.NaN;
Cell cell = cells.get((int)cellIndex);
if ( ! address.equals(cell.getKey()))
- throw new IllegalStateException("Unable to find correct cell in " + this + " by direct index " + address);
+ return Double.NaN;
return cell.getValue();
}
@@ -71,10 +73,6 @@ public class MixedTensor implements Tensor {
return cells.iterator();
}
- private Iterable<Cell> cellIterable() {
- return this::cellIterator;
- }
-
/**
* Returns an iterator over the values of this tensor.
* The iteration order is the same as for cellIterator.
@@ -113,20 +111,6 @@ public class MixedTensor implements Tensor {
}
@Override
- public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) {
- Tensor.Builder builder = Tensor.Builder.of(type());
- for (Cell cell : cellIterable()) {
- TensorAddress address = cell.getKey();
- double value = cell.getValue();
- builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value);
- }
- for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) {
- builder.cell(addCell.getKey(), addCell.getValue());
- }
- return builder.build();
- }
-
- @Override
public Tensor remove(Set<TensorAddress> addresses) {
Tensor.Builder builder = Tensor.Builder.of(type());
@@ -380,10 +364,11 @@ public class MixedTensor implements Tensor {
this.denseType = createPartialType(type.valueType(), indexedDimensions);
}
+ /** Returns the index of the given address, or -1 if it is not present */
public long indexOf(TensorAddress address) {
TensorAddress sparsePart = sparsePartialAddress(address);
if ( ! sparseMap.containsKey(sparsePart))
- throw new IllegalArgumentException("Address subspace " + sparsePart + " not found in " + this);
+ return -1;
long base = sparseMap.get(sparsePart);
long offset = denseOffset(address);
return base + offset;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index cffd41905a1..6245c26b9f4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -12,6 +12,7 @@ import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.L1Normalize;
import com.yahoo.tensor.functions.L2Normalize;
import com.yahoo.tensor.functions.Matmul;
+import com.yahoo.tensor.functions.Merge;
import com.yahoo.tensor.functions.Random;
import com.yahoo.tensor.functions.Range;
import com.yahoo.tensor.functions.Reduce;
@@ -124,18 +125,6 @@ public interface Tensor {
/**
* Returns a new tensor where existing cells in this tensor have been
- * modified according to the given operation and cells in the given map.
- * In contrast to {@link #modify}, previously non-existing cells are added
- * to this tensor. Only valid for sparse or mixed tensors.
- *
- * @param op how to update overlapping cells
- * @param cells cells to merge with this tensor
- * @return a new tensor where this tensor is merged with the other
- */
- Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells);
-
- /**
- * Returns a new tensor where existing cells in this tensor have been
* removed according to the given set of addresses. Only valid for sparse
* or mixed tensors. For mixed tensors, addresses are assumed to only
* contain the sparse dimensions, as the entire dense subspace is removed.
@@ -164,6 +153,10 @@ public interface Tensor {
return new Join<>(new ConstantTensor<>(this), new ConstantTensor<>(argument), combinator).evaluate();
}
+ default Tensor merge(Tensor argument, DoubleBinaryOperator combinator) {
+ return new Merge<>(new ConstantTensor<>(this), new ConstantTensor<>(argument), combinator).evaluate();
+ }
+
default Tensor rename(String fromDimension, String toDimension) {
return new Rename<>(new ConstantTensor<>(this), Collections.singletonList(fromDimension),
Collections.singletonList(toDimension)).evaluate();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 58cb151875e..32398c5a1e9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -314,12 +314,13 @@ public class TensorType {
/**
* Returns the dimension resulting from combining two dimensions having the same name but possibly different
- * types. This works by degrading to the type making the fewer promises.
- * [N] + [M] = [min(N, M)]
+ * types:
+ *
+ * [N] + [M] = [ minimal ? min(N, M) : max(N, M) ]
* [N] + [] = []
* [] + {} = {}
*/
- Dimension combineWith(Optional<Dimension> other) {
+ public Dimension combineWith(Optional<Dimension> other, boolean minimal) {
if ( ! other.isPresent()) return this;
if (this instanceof MappedDimension) return this;
if (other.get() instanceof MappedDimension) return other.get();
@@ -329,7 +330,10 @@ public class TensorType {
// both are indexed bound
IndexedBoundDimension thisIb = (IndexedBoundDimension)this;
IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get();
- return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb;
+ if (minimal)
+ return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb;
+ else
+ return thisIb.size().get() < otherIb.size().get() ? otherIb : thisIb;
}
@Override
@@ -483,7 +487,7 @@ public class TensorType {
/**
* Creates a builder containing a combination of the dimensions of the given types
*
- * If the same dimension is indexed with different size restrictions the largest size will be used.
+ * If the same dimension is indexed with different size restrictions the smallest size will be used.
* If it is size restricted in one argument but not the other it will not be size restricted.
* If it is indexed in one and mapped in the other it will become mapped.
*
@@ -516,7 +520,7 @@ public class TensorType {
}
else {
for (Dimension dimension : type.dimensions)
- set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name()))));
+ set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name())), true));
}
}
@@ -528,7 +532,7 @@ public class TensorType {
if (containsMapped)
dimension = new MappedDimension(dimension.name());
Dimension existing = dimensions.get(dimension.name());
- set(dimension.combineWith(Optional.ofNullable(existing)));
+ set(dimension.combineWith(Optional.ofNullable(existing), true));
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
new file mode 100644
index 00000000000..350eaaa16f6
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
@@ -0,0 +1,151 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Sets;
+import com.yahoo.tensor.DimensionSizes;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.PartialAddress;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.function.DoubleBinaryOperator;
+
+/**
+ * The <i>merge</i> tensor operation produces from two argument tensors having equal dimension names
+ * a tensor having the same dimension names, with each dimension the smallest (see below) which can encompass all the
+ * values of both tensors, and where the values are the union of the values of both tensors. In the cases where both
+ * tensors contain a value for a given cell, and only then, the lambda scalar expression is evaluated to produce
+ * the resulting cell value. If none of the argument tensors have a value (but the cell exists due to merging
+ * indexed dimensions of uneven size in multidimensional tensors) the resulting cell is 0.
+ * <p>
+ * The type of each dimension of the result tensor is determined as follows:
+ * <ul>
+ * <li>If at least one of the two argument dimensions are mapped, the resulting dimension is mapped.
+ * <li>Otherwise, the size of the resulting (indexed) dimension is the max size of the argument dimensions.
+ * </ul>
+ *
+ * @author bratseth
+ */
+public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
+
+ private final TensorFunction<NAMETYPE> argumentA, argumentB;
+ private final DoubleBinaryOperator merger;
+
+ public Merge(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator merger) {
+ Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
+ Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
+ Objects.requireNonNull(merger, "The merger function cannot be null");
+ this.argumentA = argumentA;
+ this.argumentB = argumentB;
+ this.merger = merger;
+ }
+
+ /** Returns the type resulting from applying Merge to the two given types */
+ public static TensorType outputType(TensorType a, TensorType b) {
+ if ( ! a.dimensionNames().equals(b.dimensionNames()))
+ throw new IllegalArgumentException("Cannot merge " + a + " and " + b +
+ ": Both arguments must have the same dimension names");
+
+ TensorType.Builder builder = new TensorType.Builder(TensorType.combinedValueType(a, b));
+ for (TensorType.Dimension dimension : a.dimensions())
+ builder.set(dimension.combineWith(b.dimension(dimension.name()), false));
+ return builder.build();
+ }
+
+ public DoubleBinaryOperator merger() { return merger; }
+
+ @Override
+ public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); }
+
+ @Override
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
+ if ( arguments.size() != 2)
+ throw new IllegalArgumentException("Merge must have 2 arguments, got " + arguments.size());
+ return new Merge<>(arguments.get(0), arguments.get(1), merger);
+ }
+
+ @Override
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Merge<>(argumentA.toPrimitive(), argumentB.toPrimitive(), merger);
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "merge(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + merger + ")";
+ }
+
+ @Override
+ public TensorType type(TypeContext<NAMETYPE> context) {
+ return outputType(argumentA.type(context), argumentB.type(context));
+ }
+
+ @Override
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ Tensor a = argumentA.evaluate(context);
+ Tensor b = argumentB.evaluate(context);
+ TensorType mergedType = outputType(a.type(), b.type());
+ return evaluate(a, b, mergedType, merger);
+ }
+
+ static Tensor evaluate(Tensor a, Tensor b, TensorType mergedType, DoubleBinaryOperator combinator) {
+ // Choose merge algorithm
+ if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name()))
+ return indexedVectorMerge((IndexedTensor)a, (IndexedTensor)b, mergedType, combinator);
+ else
+ return generalMerge(a, b, mergedType, combinator);
+ }
+
+ private static boolean hasSingleIndexedDimension(Tensor tensor) {
+ return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
+ }
+
+ private static Tensor indexedVectorMerge(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) {
+ long aSize = a.dimensionSizes().size(0);
+ long bSize = b.dimensionSizes().size(0);
+ long mergedSize = Math.max(aSize, bSize);
+ long sharedSize = Math.min(aSize, bSize);
+ Iterator<Double> aIterator = a.valueIterator();
+ Iterator<Double> bIterator = b.valueIterator();
+ IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
+ for (long i = 0; i < sharedSize; i++)
+ builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i);
+ Iterator<Double> largestIterator = aSize > bSize ? aIterator : bIterator;
+ for (long i = sharedSize; i < mergedSize; i++)
+ builder.cell(largestIterator.next(), i);
+ return builder.build();
+ }
+
+ private static Tensor generalMerge(Tensor a, Tensor b, TensorType mergedType, DoubleBinaryOperator combinator) {
+ Tensor.Builder builder = Tensor.Builder.of(mergedType);
+ addCellsOf(a, b, builder, combinator);
+ addCellsOf(b, a, builder, null);
+ return builder.build();
+ }
+
+ private static void addCellsOf(Tensor a, Tensor b, Tensor.Builder builder, DoubleBinaryOperator combinator) {
+ for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) {
+ Map.Entry<TensorAddress, Double> aCell = i.next();
+ double bCellValue = b.get(aCell.getKey());
+ if (Double.isNaN(bCellValue))
+ builder.cell(aCell.getKey(), aCell.getValue());
+ else if (combinator != null)
+ builder.cell(aCell.getKey(), combinator.applyAsDouble(aCell.getValue(), bCellValue));
+ }
+ }
+
+}
+
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 7932f90d797..43f9b976978 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -216,6 +216,22 @@ public class TensorTestCase {
Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
Tensor.from("tensor(x{},y[3])", "{}"),
Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x[2]):[5,6]"),
+ Tensor.from("tensor(x[4]):[1,2,3,4]"),
+ Tensor.from("tensor(x[4]):[1,2,3,4]"));
+ assertTensorMerge(
+ Tensor.from("tensor(x[4]):[1,2,3,4]"),
+ Tensor.from("tensor(x[2]):[5,6]"),
+ Tensor.from("tensor(x[4]):[5,6,3,4]"));
+ assertTensorMerge(
+ Tensor.from("tensor(x[4],y[2]):[[1,2],[3,4],[5,6],[7,8]]"),
+ Tensor.from("tensor(x[2],y[3]):[[9,10,11],[12,13,14]]"),
+ Tensor.from("tensor(x[4],y[3]):[[9,10,11],[12,13,14],[5,6,0],[7,8,0]]"));
+ assertTensorMerge(
+ Tensor.from("tensor(key{},x[4]):{a:[1,2,3,4],c:[5,6,7,8]}"),
+ Tensor.from("tensor(key{},x[2]):{a:[9,10],b:[11,12]}"),
+ Tensor.from("tensor(key{},x[4]):{a:[9,10,3,4],b:[11,12,0,0],c:[5,6,7,8]}"));
}
@Test
@@ -302,7 +318,7 @@ public class TensorTestCase {
private void assertTensorMerge(Tensor init, Tensor update, Tensor expected) {
DoubleBinaryOperator op = (left, right) -> right;
- assertEquals(expected, init.merge(op, update.cells()));
+ assertEquals(expected, init.merge(update, op));
}
private void assertTensorRemove(Tensor init, Tensor update, Tensor expected) {