summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java147
1 files changed, 142 insertions, 5 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index a39f46e5a73..a875b392de7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -2,12 +2,14 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
-import java.util.List;
-import java.util.Optional;
+import java.util.*;
+import java.util.stream.Collectors;
/**
* Concatenation of two tensors along an (indexed) dimension
@@ -21,6 +23,9 @@ public class Concat extends PrimitiveTensorFunction {
private final String dimension;
public Concat(TensorFunction argumentA, TensorFunction argumentB, String dimension) {
+ Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
+ Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
+ Objects.requireNonNull(dimension, "The dimension cannot be null");
this.argumentA = argumentA;
this.argumentB = argumentB;
this.dimension = dimension;
@@ -50,9 +55,141 @@ public class Concat extends PrimitiveTensorFunction {
public Tensor evaluate(EvaluationContext context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
- Optional<TensorType.Dimension> aDimension = a.type().dimension(dimension);
- Optional<TensorType.Dimension> bDimension = a.type().dimension(dimension);
- throw new UnsupportedOperationException("Not implemented"); // TODO
+ a = ensureIndexedDimension(dimension, a);
+ b = ensureIndexedDimension(dimension, b);
+
+ IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
+ IndexedTensor bIndexed = (IndexedTensor) b;
+
+ TensorType concatType = concatType(a, b);
+ int[] concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
+
+ Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
+ int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(aIndexed::size).orElseThrow(RuntimeException::new);
+ int[] aToIndexes = mapIndexes(a.type(), concatType);
+ int[] bToIndexes = mapIndexes(b.type(), concatType);
+ System.out.println("Concatenating " + a + " to " + b);
+ concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
+ System.out.println("Concatenating " + b + " to " + a);
+ concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);
+ return builder.build();
+ }
+
+ private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType,
+ int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
+ Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
+ for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) {
+ IndexedTensor.SubspaceIterator iaSubspace = ia.next();
+ TensorAddress aAddress = iaSubspace.address();
+ for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) {
+ IndexedTensor.SubspaceIterator ibSubspace = ib.next();
+ System.out.println(" Producing concatenation along '" + dimension + " starting at b address" + ibSubspace.address());
+ while (ibSubspace.hasNext()) {
+ java.util.Map.Entry<TensorAddress, Double> bCell = ibSubspace.next(); // TODO: Create Cell convenience subclass for Map.Entry
+ TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes,
+ concatType, offset, dimension);
+ if (combinedAddress == null) continue; // incompatible
+
+ System.out.println(" Setting " + combinedAddress + " = " + bCell.getValue());
+ builder.cell(combinedAddress, bCell.getValue());
+ }
+ iaSubspace.reset();
+ }
+ }
+ }
+
+ private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor) {
+ Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName);
+ if ( dimension.isPresent() ) {
+ if ( ! dimension.get().isIndexed())
+ throw new IllegalArgumentException("Concat in dimension '" + dimensionName +
+ "' requires that dimension to be indexed or absent, " +
+ "but got a tensor with type " + tensor.type());
+ return tensor;
+ }
+ else { // extend tensor with this dimension
+ if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
+ throw new IllegalArgumentException("Concat requires an indexed tensor, " +
+ "but got a tensor with type " + tensor.type());
+ Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build();
+ return tensor.multiply(unitTensor);
+ }
+
+ }
+
+ /** Returns the type resulting from concatenating a and b */
+ private TensorType concatType(Tensor a, Tensor b) {
+ TensorType.Builder builder = new TensorType.Builder(a.type(), b.type());
+ if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size
+ builder.set(TensorType.Dimension.indexed(dimension, a.type().dimension(dimension).get().size().get() +
+ b.type().dimension(dimension).get().size().get()));
+ return builder.build();
+ }
+
+ /** Returns the concrete (not type) dimension sizes resulting from combining a and b */
+ private int[] concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
+ int[] joinedSizes = new int[concatType.dimensions().size()];
+ for (int i = 0; i < joinedSizes.length; i++) {
+ String currentDimension = concatType.dimensions().get(i).name();
+ int aSize = a.type().indexOfDimension(currentDimension).map(a::size).orElse(0);
+ int bSize = b.type().indexOfDimension(currentDimension).map(b::size).orElse(0);
+ if (currentDimension.equals(concatDimension))
+ joinedSizes[i] = aSize + bSize;
+ else
+ joinedSizes[i] = Math.max(aSize, bSize);
+ }
+ return joinedSizes;
+ }
+
+ /**
+ * Combine two addresses, adding the offset to the concat dimension
+ *
+ * @return the combined address or null if the addresses are incompatible
+ * (in some other dimension than the concat dimension)
+ */
+ private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
+ TensorType concatType, int concatOffset, String concatDimension) {
+ String[] joinedLabels = new String[concatType.dimensions().size()];
+ int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
+ mapContent(a, joinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
+ boolean compatible = mapContent(b, joinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here
+ if ( ! compatible) return null;
+ return TensorAddress.of(joinedLabels);
+ }
+
+ /**
+ * Returns the an array having one entry in order for each dimension of fromType
+ * containing the index at which toType contains the same dimension name.
+ * That is, if the returned array contains n at index i then
+ * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
+ * If some dimension in fromType is not present in toType, the corresponding index will be -1
+ */
+ // TODO: Stolen from join - put on TensorType?
+ private int[] mapIndexes(TensorType fromType, TensorType toType) {
+ int[] toIndexes = new int[fromType.dimensions().size()];
+ for (int i = 0; i < fromType.dimensions().size(); i++)
+ toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
+ return toIndexes;
+ }
+
+ /**
+ * Maps the content in the given list to the given array, using the given index map.
+ *
+ * @return true if the mapping was successful, false if one of the destination positions was
+ * occupied by a different value
+ */
+ private boolean mapContent(TensorAddress from, String[] to, int[] indexMap, int concatDimension, int concatOffset) {
+ for (int i = 0; i < from.size(); i++) {
+ int toIndex = indexMap[i];
+ if (concatDimension == toIndex) {
+ to[toIndex] = String.valueOf(from.intLabel(i) + concatOffset);
+ }
+ else {
+ if (to[toIndex] != null && !to[toIndex].equals(from.label(i))) return false;
+ to[toIndex] = from.label(i);
+ }
+ }
+ return true;
}
}