diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-02 15:44:58 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-02 15:44:58 +0100 |
commit | ded9e870509772e87e7fe42d888d20246e3c7d03 (patch) | |
tree | 7b5391edce2301b9018cb9d74d73822040271b00 | |
parent | 07b29b192fa5e373a90fe0c7e6661f9e8024577e (diff) |
Add concat function
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 29 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 5 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java | 21 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | 122 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 147 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | 24 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java | 39 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java (renamed from vespajlib/src/test/java/com/yahoo/tensor/JoinTestCase.java) | 3 |
8 files changed, 319 insertions, 71 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index e99e7da7415..f19097da6bd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -76,6 +76,11 @@ public class IndexedTensor implements Tensor { return new SuperspaceIterator(dimensions, dimensionSizes); } + /** Returns a subspace iterator having the sizes of the dimensions of this tensor */ + public Iterator<SubspaceIterator> subspaceIterator(Set<String> dimensions) { + return subspaceIterator(dimensions, dimensionSizes); + } + /** * Returns the value at the given indexes * @@ -526,7 +531,11 @@ public class IndexedTensor implements Tensor { */ public final class SubspaceIterator implements Iterator<Map.Entry<TensorAddress, Double>> { - private final Indexes indexes; + private final boolean[] dimensionIndexes; + private final int[] address; + private final int[] dimensionSizes; + + private Indexes indexes; private int count = 0; /** @@ -537,6 +546,9 @@ public class IndexedTensor implements Tensor { * @param address the address of the first cell of this subspace. */ private SubspaceIterator(boolean[] dimensionIndexes, int[] address, int[] dimensionSizes) { + this.dimensionIndexes = dimensionIndexes; + this.address = address; + this.dimensionSizes = dimensionSizes; this.indexes = Indexes.of(dimensionSizes, dimensionIndexes, address); } @@ -545,9 +557,20 @@ public class IndexedTensor implements Tensor { return indexes.size(); } + /** Returns the address of the cell this currently points to (which may be an invalid position) */ + public TensorAddress address() { return indexes.toAddress(); } + + /** Rewind this iterator to the first element */ + public void reset() { + this.count = 0; + this.indexes = Indexes.of(dimensionSizes, dimensionIndexes, address); + } + @Override - public boolean hasNext() { return count < indexes.size(); } - + public boolean hasNext() { + return count < indexes.size(); + } + @Override public Map.Entry<TensorAddress, Double> next() { if ( ! hasNext()) throw new NoSuchElementException("No cell at " + indexes); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index d81ba5ae128..bbe6cf7d017 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -4,6 +4,7 @@ package com.yahoo.tensor; import com.google.common.annotations.Beta; import com.yahoo.tensor.functions.Argmax; import com.yahoo.tensor.functions.Argmin; +import com.yahoo.tensor.functions.Concat; import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Diag; import com.yahoo.tensor.functions.Generate; @@ -111,6 +112,10 @@ public interface Tensor { Collections.singletonList(toDimension)).evaluate(); } + default Tensor concat(Tensor argument, String dimension) { + return new Concat(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate(); + } + default Tensor rename(List<String> fromDimensions, List<String> toDimensions) { return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 5224d9632ec..1c5eec01834 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -15,7 +15,8 @@ import java.util.Set; /** * An immutable address to a tensor cell. This simply supplies a value to each dimension - * in a particular tensor type. + * in a particular tensor type. As it is just a list of cell labels, it has no independenty meaning without + * its accompanying type. * * @author bratseth */ @@ -50,8 +51,10 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { */ public abstract int intLabel(int i); - public final boolean isEmpty() { return size() == 0; } + public abstract TensorAddress withLabel(int labelIndex, int label); + public final boolean isEmpty() { return size() == 0; } + @Override public int compareTo(TensorAddress other) { // TODO: Formal issue (only): Ordering with different address sizes @@ -118,6 +121,13 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { throw new IllegalArgumentException("Expected an int label in " + this + " at position " + i); } } + + @Override + public TensorAddress withLabel(int index, int label) { + String[] labels = Arrays.copyOf(this.labels, this.labels.length); + labels[index] = String.valueOf(label); + return new StringTensorAddress(labels); + } @Override public String toString() { @@ -144,6 +154,13 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { public int intLabel(int i) { return labels[i]; } @Override + public TensorAddress withLabel(int index, int label) { + int[] labels = Arrays.copyOf(this.labels, this.labels.length); + labels[index] = label; + return new IntTensorAddress(labels); + } + + @Override public String toString() { return Arrays.toString(labels); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 5645ba6eb8e..82f36972a47 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -56,51 +56,6 @@ public class TensorType { /** Returns true if all dimensions of this are indexed */ public boolean isIndexed() { return dimensions().stream().allMatch(d -> d.isIndexed()); } - private static final boolean supportsMixedTypes = false; - - /** - * Returns a new tensor type which is the combination of the dimensions of both arguments. - * If the same dimension is indexed with different size restrictions the largest size will be used. - * If it is size restricted in one argument but not the other it will not be size restricted. - * If it is indexed in one and mapped in the other it will become mapped. - */ - public TensorType combineWith(TensorType other) { - if ( ! supportsMixedTypes) return combineWithAndDisallowMixedTypes(other); // TODO: Support it - - if (this.equals(other)) return this; - - TensorType.Builder b = new TensorType.Builder(); - for (Dimension thisDimension : dimensions) - b.add(thisDimension); - for (Dimension otherDimension : other.dimensions) { - Dimension thisDimension = b.dimensions.get(otherDimension.name()); - b.set(otherDimension.combineWith(Optional.ofNullable(thisDimension))); - } - return b.build(); - } - - private TensorType combineWithAndDisallowMixedTypes(TensorType other) { - if (this.equals(other)) return this; - - boolean containsMapped = dimensions().stream().anyMatch(d -> ! d.isIndexed()); - containsMapped = containsMapped || other.dimensions().stream().anyMatch(d -> ! d.isIndexed()); - - TensorType.Builder b = new TensorType.Builder(); - for (Dimension thisDimension : dimensions) { - if (containsMapped) - thisDimension = new MappedDimension(thisDimension.name()); - b.add(thisDimension); - } - for (Dimension otherDimension : other.dimensions) { - if (containsMapped) - otherDimension = new MappedDimension(otherDimension.name()); - Dimension thisDimension = b.dimensions.get(otherDimension.name()); - b.set(otherDimension.combineWith(Optional.ofNullable(thisDimension))); - } - return b.build(); - } - - /** Returns an immutable list of the dimensions of this */ public List<Dimension> dimensions() { return dimensions; } @@ -215,6 +170,10 @@ public class TensorType { return this.name.compareTo(other.name); } + public static Dimension indexed(String name, int size) { + return new IndexedBoundDimension(name, size); + } + } public static class IndexedBoundDimension extends TensorType.Dimension { @@ -310,7 +269,51 @@ public class TensorType { private final Map<String, Dimension> dimensions = new LinkedHashMap<>(); - /** Add a new dimension */ + /** Creates an empty builder */ + public Builder() { + } + + /** + * Creates a builder containing a combination of the dimensions of the given types + * + * If the same dimension is indexed with different size restrictions the largest size will be used. + * If it is size restricted in one argument but not the other it will not be size restricted. + * If it is indexed in one and mapped in the other it will become mapped. + */ + public Builder(TensorType ... types) { + for (TensorType type : types) + addDimensionsOf(type); + } + + private static final boolean supportsMixedTypes = false; + + private void addDimensionsOf(TensorType type) { + if ( ! supportsMixedTypes) { // TODO: Support it + addDimensionsOfAndDisallowMixedDimensions(type); + } + else { + for (Dimension dimension : type.dimensions) + set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name())))); + } + } + + private void addDimensionsOfAndDisallowMixedDimensions(TensorType type) { + boolean containsMapped = dimensions.values().stream().anyMatch(d -> ! d.isIndexed()); + containsMapped = containsMapped || type.dimensions().stream().anyMatch(d -> ! d.isIndexed()); + + for (Dimension dimension : type.dimensions) { + if (containsMapped) + dimension = new MappedDimension(dimension.name()); + Dimension existing = dimensions.get(dimension.name()); + set(dimension.combineWith(Optional.ofNullable(existing))); + } + } + + /** + * Adds a new dimension to this + * + * @throws IllegalArgumentException if the dimension is already present + */ private Builder add(Dimension dimension) { Objects.requireNonNull(dimension, "A dimension cannot be null"); if (dimensions.containsKey(dimension.name())) @@ -320,28 +323,47 @@ public class TensorType { return this; } - /** Add or replace a dimension */ - private Builder set(Dimension dimension) { + /** Adds or replaces a dimension in this */ + public Builder set(Dimension dimension) { Objects.requireNonNull(dimension, "A dimension cannot be null"); dimensions.put(dimension.name(), dimension); return this; } - /** Create a bound indexed dimension */ + /** + * Adds a bound indexed dimension to this + * + * @throws IllegalArgumentException if the dimension is already present + */ public Builder indexed(String name, int size) { return add(new IndexedBoundDimension(name, size)); } - /** Create an unbound indexed dimension */ + /** + * Adds an unbound indexed dimension to this + * + * @throws IllegalArgumentException if the dimension is already present + */ public Builder indexed(String name) { return add(new IndexedUnboundDimension(name)); } + /** + * Adds a mapped dimension to this + * + * @throws IllegalArgumentException if the dimension is already present + */ public Builder mapped(String name) { return add(new MappedDimension(name)); } + /** Adds the give dimension */ public Builder dimension(Dimension dimension) { return add(dimension); } + + /** Returns the given dimension, or empty if none is present */ + public Optional<Dimension> getDimension(String dimension) { + return Optional.ofNullable(dimensions.get(dimension)); + } public Builder dimension(String name, Dimension.Type type) { switch (type) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index a39f46e5a73..a875b392de7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -2,12 +2,14 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; -import java.util.List; -import java.util.Optional; +import java.util.*; +import java.util.stream.Collectors; /** * Concatenation of two tensors along an (indexed) dimension @@ -21,6 +23,9 @@ public class Concat extends PrimitiveTensorFunction { private final String dimension; public Concat(TensorFunction argumentA, TensorFunction argumentB, String dimension) { + Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); + Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); + Objects.requireNonNull(dimension, "The dimension cannot be null"); this.argumentA = argumentA; this.argumentB = argumentB; this.dimension = dimension; @@ -50,9 +55,141 @@ public class Concat extends PrimitiveTensorFunction { public Tensor evaluate(EvaluationContext context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); - Optional<TensorType.Dimension> aDimension = a.type().dimension(dimension); - Optional<TensorType.Dimension> bDimension = a.type().dimension(dimension); - throw new UnsupportedOperationException("Not implemented"); // TODO + a = ensureIndexedDimension(dimension, a); + b = ensureIndexedDimension(dimension, b); + + IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor + IndexedTensor bIndexed = (IndexedTensor) b; + + TensorType concatType = concatType(a, b); + int[] concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); + + Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); + int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(aIndexed::size).orElseThrow(RuntimeException::new); + int[] aToIndexes = mapIndexes(a.type(), concatType); + int[] bToIndexes = mapIndexes(b.type(), concatType); + System.out.println("Concatenating " + a + " to " + b); + concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder); + System.out.println("Concatenating " + b + " to " + a); + concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder); + return builder.build(); + } + + private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType, + int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) { + Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet()); + for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) { + IndexedTensor.SubspaceIterator iaSubspace = ia.next(); + TensorAddress aAddress = iaSubspace.address(); + for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) { + IndexedTensor.SubspaceIterator ibSubspace = ib.next(); + System.out.println(" Producing concatenation along '" + dimension + " starting at b address" + ibSubspace.address()); + while (ibSubspace.hasNext()) { + java.util.Map.Entry<TensorAddress, Double> bCell = ibSubspace.next(); // TODO: Create Cell convenience subclass for Map.Entry + TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes, + concatType, offset, dimension); + if (combinedAddress == null) continue; // incompatible + + System.out.println(" Setting " + combinedAddress + " = " + bCell.getValue()); + builder.cell(combinedAddress, bCell.getValue()); + } + iaSubspace.reset(); + } + } + } + + private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor) { + Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName); + if ( dimension.isPresent() ) { + if ( ! dimension.get().isIndexed()) + throw new IllegalArgumentException("Concat in dimension '" + dimensionName + + "' requires that dimension to be indexed or absent, " + + "but got a tensor with type " + tensor.type()); + return tensor; + } + else { // extend tensor with this dimension + if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) + throw new IllegalArgumentException("Concat requires an indexed tensor, " + + "but got a tensor with type " + tensor.type()); + Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build(); + return tensor.multiply(unitTensor); + } + + } + + /** Returns the type resulting from concatenating a and b */ + private TensorType concatType(Tensor a, Tensor b) { + TensorType.Builder builder = new TensorType.Builder(a.type(), b.type()); + if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size + builder.set(TensorType.Dimension.indexed(dimension, a.type().dimension(dimension).get().size().get() + + b.type().dimension(dimension).get().size().get())); + return builder.build(); + } + + /** Returns the concrete (not type) dimension sizes resulting from combining a and b */ + private int[] concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) { + int[] joinedSizes = new int[concatType.dimensions().size()]; + for (int i = 0; i < joinedSizes.length; i++) { + String currentDimension = concatType.dimensions().get(i).name(); + int aSize = a.type().indexOfDimension(currentDimension).map(a::size).orElse(0); + int bSize = b.type().indexOfDimension(currentDimension).map(b::size).orElse(0); + if (currentDimension.equals(concatDimension)) + joinedSizes[i] = aSize + bSize; + else + joinedSizes[i] = Math.max(aSize, bSize); + } + return joinedSizes; + } + + /** + * Combine two addresses, adding the offset to the concat dimension + * + * @return the combined address or null if the addresses are incompatible + * (in some other dimension than the concat dimension) + */ + private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, + TensorType concatType, int concatOffset, String concatDimension) { + String[] joinedLabels = new String[concatType.dimensions().size()]; + int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get(); + mapContent(a, joinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension + boolean compatible = mapContent(b, joinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here + if ( ! compatible) return null; + return TensorAddress.of(joinedLabels); + } + + /** + * Returns the an array having one entry in order for each dimension of fromType + * containing the index at which toType contains the same dimension name. + * That is, if the returned array contains n at index i then + * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) + * If some dimension in fromType is not present in toType, the corresponding index will be -1 + */ + // TODO: Stolen from join - put on TensorType? + private int[] mapIndexes(TensorType fromType, TensorType toType) { + int[] toIndexes = new int[fromType.dimensions().size()]; + for (int i = 0; i < fromType.dimensions().size(); i++) + toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1); + return toIndexes; + } + + /** + * Maps the content in the given list to the given array, using the given index map. + * + * @return true if the mapping was successful, false if one of the destination positions was + * occupied by a different value + */ + private boolean mapContent(TensorAddress from, String[] to, int[] indexMap, int concatDimension, int concatOffset) { + for (int i = 0; i < from.size(); i++) { + int toIndex = indexMap[i]; + if (concatDimension == toIndex) { + to[toIndex] = String.valueOf(from.intLabel(i) + concatOffset); + } + else { + if (to[toIndex] != null && !to[toIndex].equals(from.label(i))) return false; + to[toIndex] = from.label(i); + } + } + return true; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index ebec5efa436..3e747819b7b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -68,7 +68,7 @@ public class Join extends PrimitiveTensorFunction { public Tensor evaluate(EvaluationContext context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); - TensorType joinedType = a.type().combineWith(b.type()); + TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); // Choose join algorithm if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name())) @@ -121,15 +121,7 @@ public class Join extends PrimitiveTensorFunction { if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes return Tensor.Builder.of(joinedType, new int[joinedType.dimensions().size()]).build(); - // Find size of joined tensor - int[] joinedSizes = new int[joinedType.dimensions().size()]; - for (int i = 0; i < joinedSizes.length; i++) { - Optional<Integer> subspaceIndex = subspace.type().indexOfDimension(joinedType.dimensions().get(i).name()); - if (subspaceIndex.isPresent()) - joinedSizes[i] = Math.min(superspace.size(i), subspace.size(subspaceIndex.get())); - else - joinedSizes[i] = superspace.size(i); - } + int[] joinedSizes = joinedSize(joinedType, subspace, superspace); Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSizes); @@ -146,6 +138,18 @@ public class Join extends PrimitiveTensorFunction { return builder.build(); } + + private int[] joinedSize(TensorType joinedType, IndexedTensor subspace, IndexedTensor superspace) { + int[] joinedSizes = new int[joinedType.dimensions().size()]; + for (int i = 0; i < joinedSizes.length; i++) { + Optional<Integer> subspaceIndex = subspace.type().indexOfDimension(joinedType.dimensions().get(i).name()); + if (subspaceIndex.isPresent()) + joinedSizes[i] = Math.min(superspace.size(i), subspace.size(subspaceIndex.get())); + else + joinedSizes[i] = superspace.size(i); + } + return joinedSizes; + } private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type()); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java new file mode 100644 index 00000000000..69f2c710d7a --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java @@ -0,0 +1,39 @@ +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class ConcatTestCase { + + @Test + public void testConcat() { + { + Tensor a = Tensor.from("{1}"); + Tensor b = Tensor.from("{2}"); + assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"), a.concat(b, "x")); + assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:2, {x:1}:1 }"), b.concat(a, "x")); + } + + { + Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2, {x:2}:3 }"); + Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); + assertEquals(Tensor.from("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }"), a.concat(b, "x")); + assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " + + "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }"), a.concat(b, "y")); + } + + { + Tensor a = Tensor.from("{1}"); + Tensor b = Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:3, {x:2}:4 }"); + assertEquals(Tensor.from("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }"), a.concat(b, "x")); + assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " + + "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }"), a.concat(b, "y")); + } + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/JoinTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java index 63dd4a4a644..f2b55c74066 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/JoinTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java @@ -1,5 +1,6 @@ -package com.yahoo.tensor; +package com.yahoo.tensor.functions; +import com.yahoo.tensor.Tensor; import org.junit.Test; import static org.junit.Assert.assertEquals; |