diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 395 |
1 files changed, 199 insertions, 196 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 6d4b15be991..abf0d89c2b7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -31,6 +31,191 @@ import java.util.stream.Collectors; */ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { + enum DimType { common, separate, concat } + + private final TensorFunction<NAMETYPE> argumentA, argumentB; + private final String dimension; + + public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) { + Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); + Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); + Objects.requireNonNull(dimension, "The dimension cannot be null"); + this.argumentA = argumentA; + this.argumentB = argumentB; + this.dimension = dimension; + } + + @Override + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); } + + @Override + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { + if (arguments.size() != 2) + throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size()); + return new Concat<>(arguments.get(0), arguments.get(1), dimension); + } + + @Override + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension); + } + + @Override + public String toString(ToStringContext<NAMETYPE> context) { + return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")"; + } + + @Override + public int hashCode() { return Objects.hash("concat", argumentA, argumentB, dimension); } + + @Override + public TensorType type(TypeContext<NAMETYPE> context) { + return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension); + } + + @Override + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { + Tensor a = argumentA.evaluate(context); + Tensor b = argumentB.evaluate(context); + if (a instanceof IndexedTensor && b instanceof IndexedTensor) { + return oldEvaluate(a, b); + } + var helper = new Helper(a, b, dimension); + return helper.result; + } + + private Tensor oldEvaluate(Tensor a, Tensor b) { + TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension); + + a = ensureIndexedDimension(dimension, a, concatType.valueType()); + b = ensureIndexedDimension(dimension, b, concatType.valueType()); + + IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor + IndexedTensor bIndexed = (IndexedTensor) b; + DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); + + Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); + long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new); + int[] aToIndexes = mapIndexes(a.type(), concatType); + int[] bToIndexes = mapIndexes(b.type(), concatType); + concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder); + concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder); + return builder.build(); + } + + private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType, + int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) { + Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet()); + for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) { + IndexedTensor.SubspaceIterator iaSubspace = ia.next(); + TensorAddress aAddress = iaSubspace.address(); + for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) { + IndexedTensor.SubspaceIterator ibSubspace = ib.next(); + while (ibSubspace.hasNext()) { + Tensor.Cell bCell = ibSubspace.next(); + TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes, + concatType, offset, dimension); + if (combinedAddress == null) continue; // incompatible + + builder.cell(combinedAddress, bCell.getValue()); + } + iaSubspace.reset(); + } + } + } + + private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) { + Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName); + if ( dimension.isPresent() ) { + if ( ! dimension.get().isIndexed()) + throw new IllegalArgumentException("Concat in dimension '" + dimensionName + + "' requires that dimension to be indexed or absent, " + + "but got a tensor with type " + tensor.type()); + return tensor; + } + else { // extend tensor with this dimension + if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) + throw new IllegalArgumentException("Concat requires an indexed tensor, " + + "but got a tensor with type " + tensor.type()); + Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType) + .indexed(dimensionName, 1) + .build()) + .cell(1,0) + .build(); + return tensor.multiply(unitTensor); + } + + } + + /** Returns the concrete (not type) dimension sizes resulting from combining a and b */ + private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) { + DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size()); + for (int i = 0; i < concatSizes.dimensions(); i++) { + String currentDimension = concatType.dimensions().get(i).name(); + long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L); + long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L); + if (currentDimension.equals(concatDimension)) + concatSizes.set(i, aSize + bSize); + else if (aSize != 0 && bSize != 0 && aSize!=bSize ) + concatSizes.set(i, Math.min(aSize, bSize)); + else + concatSizes.set(i, Math.max(aSize, bSize)); + } + return concatSizes.build(); + } + + /** + * Combine two addresses, adding the offset to the concat dimension + * + * @return the combined address or null if the addresses are incompatible + * (in some other dimension than the concat dimension) + */ + private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, + TensorType concatType, long concatOffset, String concatDimension) { + long[] combinedLabels = new long[concatType.dimensions().size()]; + Arrays.fill(combinedLabels, -1); + int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get(); + mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension + boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here + if ( ! compatible) return null; + return TensorAddress.of(combinedLabels); + } + + /** + * Returns the an array having one entry in order for each dimension of fromType + * containing the index at which toType contains the same dimension name. + * That is, if the returned array contains n at index i then + * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) + * If some dimension in fromType is not present in toType, the corresponding index will be -1 + */ + // TODO: Stolen from join + private int[] mapIndexes(TensorType fromType, TensorType toType) { + int[] toIndexes = new int[fromType.dimensions().size()]; + for (int i = 0; i < fromType.dimensions().size(); i++) + toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1); + return toIndexes; + } + + /** + * Maps the content in the given list to the given array, using the given index map. + * + * @return true if the mapping was successful, false if one of the destination positions was + * occupied by a different value + */ + private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) { + for (int i = 0; i < from.size(); i++) { + int toIndex = indexMap[i]; + if (concatDimension == toIndex) { + to[toIndex] = from.numericLabel(i) + concatOffset; + } + else { + if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false; + to[toIndex] = from.numericLabel(i); + } + } + return true; + } + static class CellVector { ArrayList<Double> values = new ArrayList<>(); void setValue(int ccDimIndex, double value) { @@ -57,8 +242,6 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } - enum DimType { common, separate, concat } - static class SplitHow { List<DimType> handleDims = new ArrayList<>(); long numCommon() { return handleDims.stream().filter(t -> (t == DimType.common)).count(); } @@ -76,7 +259,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET enum CombineHow { left, right, both, concat } List<CombineHow> combineHow = new ArrayList<>(); - + void aOnly(String dimName) { if (dimName.equals(concatDimension)) { splitInfoA.handleDims.add(DimType.concat); @@ -160,8 +343,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET static int concatDimensionSize(CellVectorMapMap data) { Set<Integer> sizes = new HashSet<>(); data.map.forEach((m, cvmap) -> - cvmap.map.forEach((e, vector) -> - sizes.add(vector.values.size()))); + cvmap.map.forEach((e, vector) -> + sizes.add(vector.values.size()))); if (sizes.isEmpty()) { return 1; } @@ -207,17 +390,17 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET var lhs = entry.getValue(); var rhs = b.map.get(common); lhs.map.forEach((leftOnly, leftCells) -> { - rhs.map.forEach((rightOnly, rightCells) -> { - for (int i = 0; i < leftCells.values.size(); i++) { - TensorAddress addr = combine(common, leftOnly, rightOnly, i); - builder.cell(addr, leftCells.values.get(i)); - } - for (int i = 0; i < rightCells.values.size(); i++) { - TensorAddress addr = combine(common, leftOnly, rightOnly, i + aConcatSize); - builder.cell(addr, rightCells.values.get(i)); - } - }); + rhs.map.forEach((rightOnly, rightCells) -> { + for (int i = 0; i < leftCells.values.size(); i++) { + TensorAddress addr = combine(common, leftOnly, rightOnly, i); + builder.cell(addr, leftCells.values.get(i)); + } + for (int i = 0; i < rightCells.values.size(); i++) { + TensorAddress addr = combine(common, leftOnly, rightOnly, i + aConcatSize); + builder.cell(addr, rightCells.values.get(i)); + } }); + }); } } return builder.build(); @@ -240,7 +423,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET commonLabels[commonIdx++] = addr.label(i); break; case separate: - separateLabels[separateIdx++] = addr.label(i); + separateLabels[separateIdx++] = addr.label(i); break; case concat: ccDimIndex = addr.numericLabel(i); @@ -257,184 +440,4 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } } - private final TensorFunction<NAMETYPE> argumentA, argumentB; - private final String dimension; - - public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) { - Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); - Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); - Objects.requireNonNull(dimension, "The dimension cannot be null"); - this.argumentA = argumentA; - this.argumentB = argumentB; - this.dimension = dimension; - } - - @Override - public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); } - - @Override - public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { - if (arguments.size() != 2) - throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size()); - return new Concat<>(arguments.get(0), arguments.get(1), dimension); - } - - @Override - public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { - return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension); - } - - @Override - public String toString(ToStringContext<NAMETYPE> context) { - return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")"; - } - - @Override - public TensorType type(TypeContext<NAMETYPE> context) { - return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension); - } - - @Override - public Tensor evaluate(EvaluationContext<NAMETYPE> context) { - Tensor a = argumentA.evaluate(context); - Tensor b = argumentB.evaluate(context); - if (a instanceof IndexedTensor && b instanceof IndexedTensor) { - return oldEvaluate(a, b); - } - var helper = new Helper(a, b, dimension); - return helper.result; - } - - private Tensor oldEvaluate(Tensor a, Tensor b) { - TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension); - - a = ensureIndexedDimension(dimension, a, concatType.valueType()); - b = ensureIndexedDimension(dimension, b, concatType.valueType()); - - IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor - IndexedTensor bIndexed = (IndexedTensor) b; - DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); - - Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); - long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new); - int[] aToIndexes = mapIndexes(a.type(), concatType); - int[] bToIndexes = mapIndexes(b.type(), concatType); - concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder); - concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder); - return builder.build(); - } - - private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType, - int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) { - Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet()); - for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) { - IndexedTensor.SubspaceIterator iaSubspace = ia.next(); - TensorAddress aAddress = iaSubspace.address(); - for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) { - IndexedTensor.SubspaceIterator ibSubspace = ib.next(); - while (ibSubspace.hasNext()) { - Tensor.Cell bCell = ibSubspace.next(); - TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes, - concatType, offset, dimension); - if (combinedAddress == null) continue; // incompatible - - builder.cell(combinedAddress, bCell.getValue()); - } - iaSubspace.reset(); - } - } - } - - private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) { - Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName); - if ( dimension.isPresent() ) { - if ( ! dimension.get().isIndexed()) - throw new IllegalArgumentException("Concat in dimension '" + dimensionName + - "' requires that dimension to be indexed or absent, " + - "but got a tensor with type " + tensor.type()); - return tensor; - } - else { // extend tensor with this dimension - if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) - throw new IllegalArgumentException("Concat requires an indexed tensor, " + - "but got a tensor with type " + tensor.type()); - Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType) - .indexed(dimensionName, 1) - .build()) - .cell(1,0) - .build(); - return tensor.multiply(unitTensor); - } - - } - - /** Returns the concrete (not type) dimension sizes resulting from combining a and b */ - private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) { - DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size()); - for (int i = 0; i < concatSizes.dimensions(); i++) { - String currentDimension = concatType.dimensions().get(i).name(); - long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L); - long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L); - if (currentDimension.equals(concatDimension)) - concatSizes.set(i, aSize + bSize); - else if (aSize != 0 && bSize != 0 && aSize!=bSize ) - concatSizes.set(i, Math.min(aSize, bSize)); - else - concatSizes.set(i, Math.max(aSize, bSize)); - } - return concatSizes.build(); - } - - /** - * Combine two addresses, adding the offset to the concat dimension - * - * @return the combined address or null if the addresses are incompatible - * (in some other dimension than the concat dimension) - */ - private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, - TensorType concatType, long concatOffset, String concatDimension) { - long[] combinedLabels = new long[concatType.dimensions().size()]; - Arrays.fill(combinedLabels, -1); - int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get(); - mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension - boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here - if ( ! compatible) return null; - return TensorAddress.of(combinedLabels); - } - - /** - * Returns the an array having one entry in order for each dimension of fromType - * containing the index at which toType contains the same dimension name. - * That is, if the returned array contains n at index i then - * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) - * If some dimension in fromType is not present in toType, the corresponding index will be -1 - */ - // TODO: Stolen from join - private int[] mapIndexes(TensorType fromType, TensorType toType) { - int[] toIndexes = new int[fromType.dimensions().size()]; - for (int i = 0; i < fromType.dimensions().size(); i++) - toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1); - return toIndexes; - } - - /** - * Maps the content in the given list to the given array, using the given index map. - * - * @return true if the mapping was successful, false if one of the destination positions was - * occupied by a different value - */ - private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) { - for (int i = 0; i < from.size(); i++) { - int toIndex = indexMap[i]; - if (concatDimension == toIndex) { - to[toIndex] = from.numericLabel(i) + concatOffset; - } - else { - if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false; - to[toIndex] = from.numericLabel(i); - } - } - return true; - } - } |