about summary refs log tree commit diff stats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java 395
1 files changed, 199 insertions, 196 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 6d4b15be991..abf0d89c2b7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -31,6 +31,191 @@ import java.util.stream.Collectors;
*/
public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
+ enum DimType { common, separate, concat }
+
+ private final TensorFunction<NAMETYPE> argumentA, argumentB;
+ private final String dimension;
+
+ public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) {
+ Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
+ Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
+ Objects.requireNonNull(dimension, "The dimension cannot be null");
+ this.argumentA = argumentA;
+ this.argumentB = argumentB;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); }
+
+ @Override
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
+ if (arguments.size() != 2)
+ throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
+ return new Concat<>(arguments.get(0), arguments.get(1), dimension);
+ }
+
+ @Override
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
+ }
+
+ @Override
+ public String toString(ToStringContext<NAMETYPE> context) {
+ return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")";
+ }
+
+ @Override
+ public int hashCode() { return Objects.hash("concat", argumentA, argumentB, dimension); }
+
+ @Override
+ public TensorType type(TypeContext<NAMETYPE> context) {
+ return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension);
+ }
+
+ @Override
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ Tensor a = argumentA.evaluate(context);
+ Tensor b = argumentB.evaluate(context);
+ if (a instanceof IndexedTensor && b instanceof IndexedTensor) {
+ return oldEvaluate(a, b);
+ }
+ var helper = new Helper(a, b, dimension);
+ return helper.result;
+ }
+
+ private Tensor oldEvaluate(Tensor a, Tensor b) {
+ TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension);
+
+ a = ensureIndexedDimension(dimension, a, concatType.valueType());
+ b = ensureIndexedDimension(dimension, b, concatType.valueType());
+
+ IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
+ IndexedTensor bIndexed = (IndexedTensor) b;
+ DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
+
+ Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
+ long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
+ int[] aToIndexes = mapIndexes(a.type(), concatType);
+ int[] bToIndexes = mapIndexes(b.type(), concatType);
+ concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
+ concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);
+ return builder.build();
+ }
+
+ private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType,
+ int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
+ Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
+ for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) {
+ IndexedTensor.SubspaceIterator iaSubspace = ia.next();
+ TensorAddress aAddress = iaSubspace.address();
+ for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) {
+ IndexedTensor.SubspaceIterator ibSubspace = ib.next();
+ while (ibSubspace.hasNext()) {
+ Tensor.Cell bCell = ibSubspace.next();
+ TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes,
+ concatType, offset, dimension);
+ if (combinedAddress == null) continue; // incompatible
+
+ builder.cell(combinedAddress, bCell.getValue());
+ }
+ iaSubspace.reset();
+ }
+ }
+ }
+
+ private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) {
+ Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName);
+ if ( dimension.isPresent() ) {
+ if ( ! dimension.get().isIndexed())
+ throw new IllegalArgumentException("Concat in dimension '" + dimensionName +
+ "' requires that dimension to be indexed or absent, " +
+ "but got a tensor with type " + tensor.type());
+ return tensor;
+ }
+ else { // extend tensor with this dimension
+ if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
+ throw new IllegalArgumentException("Concat requires an indexed tensor, " +
+ "but got a tensor with type " + tensor.type());
+ Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType)
+ .indexed(dimensionName, 1)
+ .build())
+ .cell(1,0)
+ .build();
+ return tensor.multiply(unitTensor);
+ }
+
+ }
+
+ /** Returns the concrete (not type) dimension sizes resulting from combining a and b */
+ private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
+ DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());
+ for (int i = 0; i < concatSizes.dimensions(); i++) {
+ String currentDimension = concatType.dimensions().get(i).name();
+ long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L);
+ long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L);
+ if (currentDimension.equals(concatDimension))
+ concatSizes.set(i, aSize + bSize);
+ else if (aSize != 0 && bSize != 0 && aSize!=bSize )
+ concatSizes.set(i, Math.min(aSize, bSize));
+ else
+ concatSizes.set(i, Math.max(aSize, bSize));
+ }
+ return concatSizes.build();
+ }
+
+ /**
+ * Combine two addresses, adding the offset to the concat dimension
+ *
+ * @return the combined address or null if the addresses are incompatible
+ * (in some other dimension than the concat dimension)
+ */
+ private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
+ TensorType concatType, long concatOffset, String concatDimension) {
+ long[] combinedLabels = new long[concatType.dimensions().size()];
+ Arrays.fill(combinedLabels, -1);
+ int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
+ mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
+ boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here
+ if ( ! compatible) return null;
+ return TensorAddress.of(combinedLabels);
+ }
+
+ /**
+ * Returns an array having one entry, in order, for each dimension of fromType,
+ * containing the index at which toType contains the same dimension name.
+ * That is, if the returned array contains n at index i then
+ * fromType.dimensions().get(i).name().equals(toType.dimensions().get(n).name())
+ * If some dimension in fromType is not present in toType, the corresponding index will be -1
+ */
+ // TODO: Stolen from join
+ private int[] mapIndexes(TensorType fromType, TensorType toType) {
+ int[] toIndexes = new int[fromType.dimensions().size()];
+ for (int i = 0; i < fromType.dimensions().size(); i++)
+ toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
+ return toIndexes;
+ }
+
+ /**
+ * Maps the numeric labels of the given address into the given array, using the given index map.
+ *
+ * @return true if the mapping was successful, false if one of the destination positions was
+ * occupied by a different value
+ */
+ private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) {
+ for (int i = 0; i < from.size(); i++) {
+ int toIndex = indexMap[i];
+ if (concatDimension == toIndex) {
+ to[toIndex] = from.numericLabel(i) + concatOffset;
+ }
+ else {
+ if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false;
+ to[toIndex] = from.numericLabel(i);
+ }
+ }
+ return true;
+ }
+
static class CellVector {
ArrayList<Double> values = new ArrayList<>();
void setValue(int ccDimIndex, double value) {
@@ -57,8 +242,6 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
- enum DimType { common, separate, concat }
-
static class SplitHow {
List<DimType> handleDims = new ArrayList<>();
long numCommon() { return handleDims.stream().filter(t -> (t == DimType.common)).count(); }
@@ -76,7 +259,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
enum CombineHow { left, right, both, concat }
List<CombineHow> combineHow = new ArrayList<>();
-
+
void aOnly(String dimName) {
if (dimName.equals(concatDimension)) {
splitInfoA.handleDims.add(DimType.concat);
@@ -160,8 +343,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
static int concatDimensionSize(CellVectorMapMap data) {
Set<Integer> sizes = new HashSet<>();
data.map.forEach((m, cvmap) ->
- cvmap.map.forEach((e, vector) ->
- sizes.add(vector.values.size())));
+ cvmap.map.forEach((e, vector) ->
+ sizes.add(vector.values.size())));
if (sizes.isEmpty()) {
return 1;
}
@@ -207,17 +390,17 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
var lhs = entry.getValue();
var rhs = b.map.get(common);
lhs.map.forEach((leftOnly, leftCells) -> {
- rhs.map.forEach((rightOnly, rightCells) -> {
- for (int i = 0; i < leftCells.values.size(); i++) {
- TensorAddress addr = combine(common, leftOnly, rightOnly, i);
- builder.cell(addr, leftCells.values.get(i));
- }
- for (int i = 0; i < rightCells.values.size(); i++) {
- TensorAddress addr = combine(common, leftOnly, rightOnly, i + aConcatSize);
- builder.cell(addr, rightCells.values.get(i));
- }
- });
+ rhs.map.forEach((rightOnly, rightCells) -> {
+ for (int i = 0; i < leftCells.values.size(); i++) {
+ TensorAddress addr = combine(common, leftOnly, rightOnly, i);
+ builder.cell(addr, leftCells.values.get(i));
+ }
+ for (int i = 0; i < rightCells.values.size(); i++) {
+ TensorAddress addr = combine(common, leftOnly, rightOnly, i + aConcatSize);
+ builder.cell(addr, rightCells.values.get(i));
+ }
});
+ });
}
}
return builder.build();
@@ -240,7 +423,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
commonLabels[commonIdx++] = addr.label(i);
break;
case separate:
- separateLabels[separateIdx++] = addr.label(i);
+ separateLabels[separateIdx++] = addr.label(i);
break;
case concat:
ccDimIndex = addr.numericLabel(i);
@@ -257,184 +440,4 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
}
- private final TensorFunction<NAMETYPE> argumentA, argumentB;
- private final String dimension;
-
- public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) {
- Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
- Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
- Objects.requireNonNull(dimension, "The dimension cannot be null");
- this.argumentA = argumentA;
- this.argumentB = argumentB;
- this.dimension = dimension;
- }
-
- @Override
- public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); }
-
- @Override
- public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
- if (arguments.size() != 2)
- throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
- return new Concat<>(arguments.get(0), arguments.get(1), dimension);
- }
-
- @Override
- public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
- return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
- }
-
- @Override
- public String toString(ToStringContext<NAMETYPE> context) {
- return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")";
- }
-
- @Override
- public TensorType type(TypeContext<NAMETYPE> context) {
- return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension);
- }
-
- @Override
- public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
- Tensor a = argumentA.evaluate(context);
- Tensor b = argumentB.evaluate(context);
- if (a instanceof IndexedTensor && b instanceof IndexedTensor) {
- return oldEvaluate(a, b);
- }
- var helper = new Helper(a, b, dimension);
- return helper.result;
- }
-
- private Tensor oldEvaluate(Tensor a, Tensor b) {
- TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension);
-
- a = ensureIndexedDimension(dimension, a, concatType.valueType());
- b = ensureIndexedDimension(dimension, b, concatType.valueType());
-
- IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
- IndexedTensor bIndexed = (IndexedTensor) b;
- DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
-
- Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
- long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
- int[] aToIndexes = mapIndexes(a.type(), concatType);
- int[] bToIndexes = mapIndexes(b.type(), concatType);
- concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
- concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);
- return builder.build();
- }
-
- private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType,
- int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
- Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
- for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) {
- IndexedTensor.SubspaceIterator iaSubspace = ia.next();
- TensorAddress aAddress = iaSubspace.address();
- for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) {
- IndexedTensor.SubspaceIterator ibSubspace = ib.next();
- while (ibSubspace.hasNext()) {
- Tensor.Cell bCell = ibSubspace.next();
- TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes,
- concatType, offset, dimension);
- if (combinedAddress == null) continue; // incompatible
-
- builder.cell(combinedAddress, bCell.getValue());
- }
- iaSubspace.reset();
- }
- }
- }
-
- private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) {
- Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName);
- if ( dimension.isPresent() ) {
- if ( ! dimension.get().isIndexed())
- throw new IllegalArgumentException("Concat in dimension '" + dimensionName +
- "' requires that dimension to be indexed or absent, " +
- "but got a tensor with type " + tensor.type());
- return tensor;
- }
- else { // extend tensor with this dimension
- if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
- throw new IllegalArgumentException("Concat requires an indexed tensor, " +
- "but got a tensor with type " + tensor.type());
- Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType)
- .indexed(dimensionName, 1)
- .build())
- .cell(1,0)
- .build();
- return tensor.multiply(unitTensor);
- }
-
- }
-
- /** Returns the concrete (not type) dimension sizes resulting from combining a and b */
- private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
- DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());
- for (int i = 0; i < concatSizes.dimensions(); i++) {
- String currentDimension = concatType.dimensions().get(i).name();
- long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L);
- long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L);
- if (currentDimension.equals(concatDimension))
- concatSizes.set(i, aSize + bSize);
- else if (aSize != 0 && bSize != 0 && aSize!=bSize )
- concatSizes.set(i, Math.min(aSize, bSize));
- else
- concatSizes.set(i, Math.max(aSize, bSize));
- }
- return concatSizes.build();
- }
-
- /**
- * Combine two addresses, adding the offset to the concat dimension
- *
- * @return the combined address or null if the addresses are incompatible
- * (in some other dimension than the concat dimension)
- */
- private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
- TensorType concatType, long concatOffset, String concatDimension) {
- long[] combinedLabels = new long[concatType.dimensions().size()];
- Arrays.fill(combinedLabels, -1);
- int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
- mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
- boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here
- if ( ! compatible) return null;
- return TensorAddress.of(combinedLabels);
- }
-
- /**
- * Returns the an array having one entry in order for each dimension of fromType
- * containing the index at which toType contains the same dimension name.
- * That is, if the returned array contains n at index i then
- * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
- * If some dimension in fromType is not present in toType, the corresponding index will be -1
- */
- // TODO: Stolen from join
- private int[] mapIndexes(TensorType fromType, TensorType toType) {
- int[] toIndexes = new int[fromType.dimensions().size()];
- for (int i = 0; i < fromType.dimensions().size(); i++)
- toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
- return toIndexes;
- }
-
- /**
- * Maps the content in the given list to the given array, using the given index map.
- *
- * @return true if the mapping was successful, false if one of the destination positions was
- * occupied by a different value
- */
- private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) {
- for (int i = 0; i < from.size(); i++) {
- int toIndex = indexMap[i];
- if (concatDimension == toIndex) {
- to[toIndex] = from.numericLabel(i) + concatOffset;
- }
- else {
- if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false;
- to[toIndex] = from.numericLabel(i);
- }
- }
- return true;
- }
-
}