summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-01-02 15:44:58 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-01-02 15:44:58 +0100
commitded9e870509772e87e7fe42d888d20246e3c7d03 (patch)
tree7b5391edce2301b9018cb9d74d73822040271b00 /vespajlib/src/main/java/com/yahoo/tensor
parent07b29b192fa5e373a90fe0c7e6661f9e8024577e (diff)
Add concat function
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java29
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java21
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java122
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java147
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java24
6 files changed, 278 insertions, 70 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index e99e7da7415..f19097da6bd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -76,6 +76,11 @@ public class IndexedTensor implements Tensor {
return new SuperspaceIterator(dimensions, dimensionSizes);
}
+ /** Returns a subspace iterator having the sizes of the dimensions of this tensor */
+ public Iterator<SubspaceIterator> subspaceIterator(Set<String> dimensions) {
+ return subspaceIterator(dimensions, dimensionSizes);
+ }
+
/**
* Returns the value at the given indexes
*
@@ -526,7 +531,11 @@ public class IndexedTensor implements Tensor {
*/
public final class SubspaceIterator implements Iterator<Map.Entry<TensorAddress, Double>> {
- private final Indexes indexes;
+ private final boolean[] dimensionIndexes;
+ private final int[] address;
+ private final int[] dimensionSizes;
+
+ private Indexes indexes;
private int count = 0;
/**
@@ -537,6 +546,9 @@ public class IndexedTensor implements Tensor {
* @param address the address of the first cell of this subspace.
*/
private SubspaceIterator(boolean[] dimensionIndexes, int[] address, int[] dimensionSizes) {
+ this.dimensionIndexes = dimensionIndexes;
+ this.address = address;
+ this.dimensionSizes = dimensionSizes;
this.indexes = Indexes.of(dimensionSizes, dimensionIndexes, address);
}
@@ -545,9 +557,20 @@ public class IndexedTensor implements Tensor {
return indexes.size();
}
+ /** Returns the address of the cell this currently points to (which may be an invalid position) */
+ public TensorAddress address() { return indexes.toAddress(); }
+
+ /** Rewind this iterator to the first element */
+ public void reset() {
+ this.count = 0;
+ this.indexes = Indexes.of(dimensionSizes, dimensionIndexes, address);
+ }
+
@Override
- public boolean hasNext() { return count < indexes.size(); }
-
+ public boolean hasNext() {
+ return count < indexes.size();
+ }
+
@Override
public Map.Entry<TensorAddress, Double> next() {
if ( ! hasNext()) throw new NoSuchElementException("No cell at " + indexes);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index d81ba5ae128..bbe6cf7d017 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -4,6 +4,7 @@ package com.yahoo.tensor;
import com.google.common.annotations.Beta;
import com.yahoo.tensor.functions.Argmax;
import com.yahoo.tensor.functions.Argmin;
+import com.yahoo.tensor.functions.Concat;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Diag;
import com.yahoo.tensor.functions.Generate;
@@ -111,6 +112,10 @@ public interface Tensor {
Collections.singletonList(toDimension)).evaluate();
}
+ default Tensor concat(Tensor argument, String dimension) {
+ return new Concat(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate();
+ }
+
default Tensor rename(List<String> fromDimensions, List<String> toDimensions) {
return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 5224d9632ec..1c5eec01834 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -15,7 +15,8 @@ import java.util.Set;
/**
* An immutable address to a tensor cell. This simply supplies a value to each dimension
- * in a particular tensor type.
+ * in a particular tensor type. As it is just a list of cell labels, it has no independenty meaning without
+ * its accompanying type.
*
* @author bratseth
*/
@@ -50,8 +51,10 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
*/
public abstract int intLabel(int i);
- public final boolean isEmpty() { return size() == 0; }
+ public abstract TensorAddress withLabel(int labelIndex, int label);
+ public final boolean isEmpty() { return size() == 0; }
+
@Override
public int compareTo(TensorAddress other) {
// TODO: Formal issue (only): Ordering with different address sizes
@@ -118,6 +121,13 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
throw new IllegalArgumentException("Expected an int label in " + this + " at position " + i);
}
}
+
+ @Override
+ public TensorAddress withLabel(int index, int label) {
+ String[] labels = Arrays.copyOf(this.labels, this.labels.length);
+ labels[index] = String.valueOf(label);
+ return new StringTensorAddress(labels);
+ }
@Override
public String toString() {
@@ -144,6 +154,13 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
public int intLabel(int i) { return labels[i]; }
@Override
+ public TensorAddress withLabel(int index, int label) {
+ int[] labels = Arrays.copyOf(this.labels, this.labels.length);
+ labels[index] = label;
+ return new IntTensorAddress(labels);
+ }
+
+ @Override
public String toString() {
return Arrays.toString(labels);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 5645ba6eb8e..82f36972a47 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -56,51 +56,6 @@ public class TensorType {
/** Returns true if all dimensions of this are indexed */
public boolean isIndexed() { return dimensions().stream().allMatch(d -> d.isIndexed()); }
- private static final boolean supportsMixedTypes = false;
-
- /**
- * Returns a new tensor type which is the combination of the dimensions of both arguments.
- * If the same dimension is indexed with different size restrictions the largest size will be used.
- * If it is size restricted in one argument but not the other it will not be size restricted.
- * If it is indexed in one and mapped in the other it will become mapped.
- */
- public TensorType combineWith(TensorType other) {
- if ( ! supportsMixedTypes) return combineWithAndDisallowMixedTypes(other); // TODO: Support it
-
- if (this.equals(other)) return this;
-
- TensorType.Builder b = new TensorType.Builder();
- for (Dimension thisDimension : dimensions)
- b.add(thisDimension);
- for (Dimension otherDimension : other.dimensions) {
- Dimension thisDimension = b.dimensions.get(otherDimension.name());
- b.set(otherDimension.combineWith(Optional.ofNullable(thisDimension)));
- }
- return b.build();
- }
-
- private TensorType combineWithAndDisallowMixedTypes(TensorType other) {
- if (this.equals(other)) return this;
-
- boolean containsMapped = dimensions().stream().anyMatch(d -> ! d.isIndexed());
- containsMapped = containsMapped || other.dimensions().stream().anyMatch(d -> ! d.isIndexed());
-
- TensorType.Builder b = new TensorType.Builder();
- for (Dimension thisDimension : dimensions) {
- if (containsMapped)
- thisDimension = new MappedDimension(thisDimension.name());
- b.add(thisDimension);
- }
- for (Dimension otherDimension : other.dimensions) {
- if (containsMapped)
- otherDimension = new MappedDimension(otherDimension.name());
- Dimension thisDimension = b.dimensions.get(otherDimension.name());
- b.set(otherDimension.combineWith(Optional.ofNullable(thisDimension)));
- }
- return b.build();
- }
-
-
/** Returns an immutable list of the dimensions of this */
public List<Dimension> dimensions() { return dimensions; }
@@ -215,6 +170,10 @@ public class TensorType {
return this.name.compareTo(other.name);
}
+ public static Dimension indexed(String name, int size) {
+ return new IndexedBoundDimension(name, size);
+ }
+
}
public static class IndexedBoundDimension extends TensorType.Dimension {
@@ -310,7 +269,51 @@ public class TensorType {
private final Map<String, Dimension> dimensions = new LinkedHashMap<>();
- /** Add a new dimension */
+ /** Creates an empty builder */
+ public Builder() {
+ }
+
+ /**
+ * Creates a builder containing a combination of the dimensions of the given types
+ *
+ * If the same dimension is indexed with different size restrictions the largest size will be used.
+ * If it is size restricted in one argument but not the other it will not be size restricted.
+ * If it is indexed in one and mapped in the other it will become mapped.
+ */
+ public Builder(TensorType ... types) {
+ for (TensorType type : types)
+ addDimensionsOf(type);
+ }
+
+ private static final boolean supportsMixedTypes = false;
+
+ private void addDimensionsOf(TensorType type) {
+ if ( ! supportsMixedTypes) { // TODO: Support it
+ addDimensionsOfAndDisallowMixedDimensions(type);
+ }
+ else {
+ for (Dimension dimension : type.dimensions)
+ set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name()))));
+ }
+ }
+
+ private void addDimensionsOfAndDisallowMixedDimensions(TensorType type) {
+ boolean containsMapped = dimensions.values().stream().anyMatch(d -> ! d.isIndexed());
+ containsMapped = containsMapped || type.dimensions().stream().anyMatch(d -> ! d.isIndexed());
+
+ for (Dimension dimension : type.dimensions) {
+ if (containsMapped)
+ dimension = new MappedDimension(dimension.name());
+ Dimension existing = dimensions.get(dimension.name());
+ set(dimension.combineWith(Optional.ofNullable(existing)));
+ }
+ }
+
+ /**
+ * Adds a new dimension to this
+ *
+ * @throws IllegalArgumentException if the dimension is already present
+ */
private Builder add(Dimension dimension) {
Objects.requireNonNull(dimension, "A dimension cannot be null");
if (dimensions.containsKey(dimension.name()))
@@ -320,28 +323,47 @@ public class TensorType {
return this;
}
- /** Add or replace a dimension */
- private Builder set(Dimension dimension) {
+ /** Adds or replaces a dimension in this */
+ public Builder set(Dimension dimension) {
Objects.requireNonNull(dimension, "A dimension cannot be null");
dimensions.put(dimension.name(), dimension);
return this;
}
- /** Create a bound indexed dimension */
+ /**
+ * Adds a bound indexed dimension to this
+ *
+ * @throws IllegalArgumentException if the dimension is already present
+ */
public Builder indexed(String name, int size) { return add(new IndexedBoundDimension(name, size)); }
- /** Create an unbound indexed dimension */
+ /**
+ * Adds an unbound indexed dimension to this
+ *
+ * @throws IllegalArgumentException if the dimension is already present
+ */
public Builder indexed(String name) {
return add(new IndexedUnboundDimension(name));
}
+ /**
+ * Adds a mapped dimension to this
+ *
+ * @throws IllegalArgumentException if the dimension is already present
+ */
public Builder mapped(String name) {
return add(new MappedDimension(name));
}
+ /** Adds the give dimension */
public Builder dimension(Dimension dimension) {
return add(dimension);
}
+
+ /** Returns the given dimension, or empty if none is present */
+ public Optional<Dimension> getDimension(String dimension) {
+ return Optional.ofNullable(dimensions.get(dimension));
+ }
public Builder dimension(String name, Dimension.Type type) {
switch (type) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index a39f46e5a73..a875b392de7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -2,12 +2,14 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
-import java.util.List;
-import java.util.Optional;
+import java.util.*;
+import java.util.stream.Collectors;
/**
* Concatenation of two tensors along an (indexed) dimension
@@ -21,6 +23,9 @@ public class Concat extends PrimitiveTensorFunction {
private final String dimension;
public Concat(TensorFunction argumentA, TensorFunction argumentB, String dimension) {
+ Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
+ Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
+ Objects.requireNonNull(dimension, "The dimension cannot be null");
this.argumentA = argumentA;
this.argumentB = argumentB;
this.dimension = dimension;
@@ -50,9 +55,141 @@ public class Concat extends PrimitiveTensorFunction {
public Tensor evaluate(EvaluationContext context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
- Optional<TensorType.Dimension> aDimension = a.type().dimension(dimension);
- Optional<TensorType.Dimension> bDimension = a.type().dimension(dimension);
- throw new UnsupportedOperationException("Not implemented"); // TODO
+ a = ensureIndexedDimension(dimension, a);
+ b = ensureIndexedDimension(dimension, b);
+
+ IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
+ IndexedTensor bIndexed = (IndexedTensor) b;
+
+ TensorType concatType = concatType(a, b);
+ int[] concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
+
+ Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
+ int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(aIndexed::size).orElseThrow(RuntimeException::new);
+ int[] aToIndexes = mapIndexes(a.type(), concatType);
+ int[] bToIndexes = mapIndexes(b.type(), concatType);
+ System.out.println("Concatenating " + a + " to " + b);
+ concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
+ System.out.println("Concatenating " + b + " to " + a);
+ concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);
+ return builder.build();
+ }
+
+ private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType,
+ int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
+ Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
+ for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) {
+ IndexedTensor.SubspaceIterator iaSubspace = ia.next();
+ TensorAddress aAddress = iaSubspace.address();
+ for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) {
+ IndexedTensor.SubspaceIterator ibSubspace = ib.next();
+ System.out.println(" Producing concatenation along '" + dimension + " starting at b address" + ibSubspace.address());
+ while (ibSubspace.hasNext()) {
+ java.util.Map.Entry<TensorAddress, Double> bCell = ibSubspace.next(); // TODO: Create Cell convenience subclass for Map.Entry
+ TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes,
+ concatType, offset, dimension);
+ if (combinedAddress == null) continue; // incompatible
+
+ System.out.println(" Setting " + combinedAddress + " = " + bCell.getValue());
+ builder.cell(combinedAddress, bCell.getValue());
+ }
+ iaSubspace.reset();
+ }
+ }
+ }
+
+ private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor) {
+ Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName);
+ if ( dimension.isPresent() ) {
+ if ( ! dimension.get().isIndexed())
+ throw new IllegalArgumentException("Concat in dimension '" + dimensionName +
+ "' requires that dimension to be indexed or absent, " +
+ "but got a tensor with type " + tensor.type());
+ return tensor;
+ }
+ else { // extend tensor with this dimension
+ if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
+ throw new IllegalArgumentException("Concat requires an indexed tensor, " +
+ "but got a tensor with type " + tensor.type());
+ Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build();
+ return tensor.multiply(unitTensor);
+ }
+
+ }
+
+ /** Returns the type resulting from concatenating a and b */
+ private TensorType concatType(Tensor a, Tensor b) {
+ TensorType.Builder builder = new TensorType.Builder(a.type(), b.type());
+ if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size
+ builder.set(TensorType.Dimension.indexed(dimension, a.type().dimension(dimension).get().size().get() +
+ b.type().dimension(dimension).get().size().get()));
+ return builder.build();
+ }
+
+ /** Returns the concrete (not type) dimension sizes resulting from combining a and b */
+ private int[] concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
+ int[] joinedSizes = new int[concatType.dimensions().size()];
+ for (int i = 0; i < joinedSizes.length; i++) {
+ String currentDimension = concatType.dimensions().get(i).name();
+ int aSize = a.type().indexOfDimension(currentDimension).map(a::size).orElse(0);
+ int bSize = b.type().indexOfDimension(currentDimension).map(b::size).orElse(0);
+ if (currentDimension.equals(concatDimension))
+ joinedSizes[i] = aSize + bSize;
+ else
+ joinedSizes[i] = Math.max(aSize, bSize);
+ }
+ return joinedSizes;
+ }
+
+ /**
+ * Combine two addresses, adding the offset to the concat dimension
+ *
+ * @return the combined address or null if the addresses are incompatible
+ * (in some other dimension than the concat dimension)
+ */
+ private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
+ TensorType concatType, int concatOffset, String concatDimension) {
+ String[] joinedLabels = new String[concatType.dimensions().size()];
+ int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
+ mapContent(a, joinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
+ boolean compatible = mapContent(b, joinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here
+ if ( ! compatible) return null;
+ return TensorAddress.of(joinedLabels);
+ }
+
+ /**
+ * Returns the an array having one entry in order for each dimension of fromType
+ * containing the index at which toType contains the same dimension name.
+ * That is, if the returned array contains n at index i then
+ * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
+ * If some dimension in fromType is not present in toType, the corresponding index will be -1
+ */
+ // TODO: Stolen from join - put on TensorType?
+ private int[] mapIndexes(TensorType fromType, TensorType toType) {
+ int[] toIndexes = new int[fromType.dimensions().size()];
+ for (int i = 0; i < fromType.dimensions().size(); i++)
+ toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
+ return toIndexes;
+ }
+
+ /**
+ * Maps the content in the given list to the given array, using the given index map.
+ *
+ * @return true if the mapping was successful, false if one of the destination positions was
+ * occupied by a different value
+ */
+ private boolean mapContent(TensorAddress from, String[] to, int[] indexMap, int concatDimension, int concatOffset) {
+ for (int i = 0; i < from.size(); i++) {
+ int toIndex = indexMap[i];
+ if (concatDimension == toIndex) {
+ to[toIndex] = String.valueOf(from.intLabel(i) + concatOffset);
+ }
+ else {
+ if (to[toIndex] != null && !to[toIndex].equals(from.label(i))) return false;
+ to[toIndex] = from.label(i);
+ }
+ }
+ return true;
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index ebec5efa436..3e747819b7b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -68,7 +68,7 @@ public class Join extends PrimitiveTensorFunction {
public Tensor evaluate(EvaluationContext context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
- TensorType joinedType = a.type().combineWith(b.type());
+ TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
// Choose join algorithm
if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name()))
@@ -121,15 +121,7 @@ public class Join extends PrimitiveTensorFunction {
if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes
return Tensor.Builder.of(joinedType, new int[joinedType.dimensions().size()]).build();
- // Find size of joined tensor
- int[] joinedSizes = new int[joinedType.dimensions().size()];
- for (int i = 0; i < joinedSizes.length; i++) {
- Optional<Integer> subspaceIndex = subspace.type().indexOfDimension(joinedType.dimensions().get(i).name());
- if (subspaceIndex.isPresent())
- joinedSizes[i] = Math.min(superspace.size(i), subspace.size(subspaceIndex.get()));
- else
- joinedSizes[i] = superspace.size(i);
- }
+ int[] joinedSizes = joinedSize(joinedType, subspace, superspace);
Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSizes);
@@ -146,6 +138,18 @@ public class Join extends PrimitiveTensorFunction {
return builder.build();
}
+
+ private int[] joinedSize(TensorType joinedType, IndexedTensor subspace, IndexedTensor superspace) {
+ int[] joinedSizes = new int[joinedType.dimensions().size()];
+ for (int i = 0; i < joinedSizes.length; i++) {
+ Optional<Integer> subspaceIndex = subspace.type().indexOfDimension(joinedType.dimensions().get(i).name());
+ if (subspaceIndex.isPresent())
+ joinedSizes[i] = Math.min(superspace.size(i), subspace.size(subspaceIndex.get()));
+ else
+ joinedSizes[i] = superspace.size(i);
+ }
+ return joinedSizes;
+ }
private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type());