diff options
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 241 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java | 122 |
2 files changed, 363 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 59a452588ca..959f76d40da 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -13,8 +13,12 @@ import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; +import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; @@ -27,6 +31,233 @@ import java.util.stream.Collectors; */ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { + static class CellInfo { + ArrayList<Double> values = new ArrayList<>(); + void valueAt(int ccDimIndex, double value) { + while (values.size() <= ccDimIndex) { + values.add(0.0); + } + values.set(ccDimIndex, value); + } + } + + static class CellsInfo { + Map<TensorAddress, CellInfo> map = new HashMap<>(); + + CellInfo expand(TensorAddress addr) { + if (map.containsKey(addr)) { + return map.get(addr); + } else { + CellInfo result = new CellInfo(); + map.put(addr, result); + return result; + } + } + } + + static class SideInfo { + Map<TensorAddress, CellsInfo> map = new HashMap<>(); + + CellsInfo match(TensorAddress addr) { + if (map.containsKey(addr)) { + return map.get(addr); + } else { + CellsInfo result = new CellsInfo(); + map.put(addr, result); + return result; + } + } + + int concatDimensionSize() { + Set<Integer> sizes = new HashSet<>(); + map.forEach((m, cells) -> + cells.map.forEach((e, cell) -> + sizes.add(cell.values.size()))); + if (sizes.isEmpty()) { + return 1; + } + if (sizes.size() == 1) { + return sizes.iterator().next(); + } + throw new IllegalArgumentException("inconsistent size of concat dimension, had "+sizes.size()+" different values"); + } + } + + static class Helper { + final TensorType resultType; + final String concatDimension; + + SideHow aHow = new SideHow(); + SideHow bHow = new SideHow(); + + enum CombineHow { left, right, both, concat } + + List<CombineHow> combineHow = new ArrayList<>(); + + TensorAddress combine(TensorAddress match, TensorAddress leftOnly, TensorAddress rightOnly, int concatDimIdx) { + String[] labels = new String[resultType.rank()]; + int out = 0; + int m = 0; + int a = 0; + int b = 0; + for (var how : combineHow) { + switch (how) { + case left: + labels[out++] = leftOnly.label(a++); + break; + case right: + labels[out++] = rightOnly.label(b++); + break; + case both: + labels[out++] = match.label(m++); + break; + case concat: + labels[out++] = String.valueOf(concatDimIdx); + break; + default: + throw new IllegalArgumentException("cannot handle: "+how); + } + } + return TensorAddress.of(labels); + } + + void aOnly(String dimName) { + if (dimName.equals(concatDimension)) { + aHow.handleDims.add(DimHow.concat); + combineHow.add(CombineHow.concat); + } else { + aHow.handleDims.add(DimHow.expand); + aHow.numExpand++; + combineHow.add(CombineHow.left); + } + } + + void bOnly(String dimName) { + if (dimName.equals(concatDimension)) { + bHow.handleDims.add(DimHow.concat); + combineHow.add(CombineHow.concat); + } else { + bHow.handleDims.add(DimHow.expand); + bHow.numExpand++; + combineHow.add(CombineHow.right); + } + } + + void bothAandB(String dimName) { + if (dimName.equals(concatDimension)) { + aHow.handleDims.add(DimHow.concat); + bHow.handleDims.add(DimHow.concat); + combineHow.add(CombineHow.concat); + } else { + aHow.handleDims.add(DimHow.match); + bHow.handleDims.add(DimHow.match); + aHow.numMatch++; + bHow.numMatch++; + combineHow.add(CombineHow.both); + } + } + + Helper(TensorType aType, TensorType bType, String concatDimension) { + this.resultType = TypeResolver.concat(aType, bType, concatDimension); + this.concatDimension = concatDimension; + var aDims = aType.dimensions(); + var bDims = bType.dimensions(); + int i = 0; + int j = 0; + while (i < aDims.size() && j < bDims.size()) { + String aName = aDims.get(i).name(); + String bName = bDims.get(j).name(); + int cmp = aName.compareTo(bName); + if (cmp == 0) { + bothAandB(aName); + ++i; + ++j; + } else if (cmp < 0) { + aOnly(aName); + ++i; + } else { + bOnly(bName); + ++j; + } + } + while (i < aDims.size()) { + aOnly(aDims.get(i++).name()); + } + while (j < bDims.size()) { + bOnly(bDims.get(j++).name()); + } + if (combineHow.size() < resultType.rank()) { + var idx = resultType.indexOfDimension(concatDimension); + combineHow.add(idx.get(), CombineHow.concat); + } + } + + Tensor mergeSides(SideInfo a, SideInfo b) { + var builder = Tensor.Builder.of(resultType); + int aConcatSize = a.concatDimensionSize(); + for (var entry : a.map.entrySet()) { + TensorAddress match = entry.getKey(); + if (b.map.containsKey(match)) { + var lhs = entry.getValue(); + var rhs = b.map.get(match); + lhs.map.forEach((leftExpand, leftCells) -> { + rhs.map.forEach((rightExpand, rightCells) -> { + for (int i = 0; i < leftCells.values.size(); i++) { + TensorAddress addr = combine(match, leftExpand, rightExpand, i); + builder.cell(addr, leftCells.values.get(i)); + } + for (int i = 0; i < rightCells.values.size(); i++) { + TensorAddress addr = combine(match, leftExpand, rightExpand, i + aConcatSize); + builder.cell(addr, rightCells.values.get(i)); + } + }); + }); + } + } + return builder.build(); + } + } + + enum DimHow { match, expand, concat } + + static class SideHow { + List<DimHow> handleDims = new ArrayList<>(); + int numMatch = 0; + int numExpand = 0; + } + + SideInfo analyse(Tensor side, SideHow how) { + var iter = side.cellIterator(); + String[] matchLabels = new String[how.numMatch]; + String[] expandLabels = new String[how.numExpand]; + SideInfo result = new SideInfo(); + while (iter.hasNext()) { + var cell = iter.next(); + var addr = cell.getKey(); + long ccDimIndex = 0; + int matchIdx = 0; + int expandIdx = 0; + for (int i = 0; i < how.handleDims.size(); i++) { + switch (how.handleDims.get(i)) { + case match: + matchLabels[matchIdx++] = addr.label(i); + break; + case expand: + expandLabels[expandIdx++] = addr.label(i); + break; + case concat: + ccDimIndex = addr.numericLabel(i); + break; + default: throw new IllegalArgumentException("cannot handle: "+how.handleDims.get(i)); + } + } + TensorAddress matchAddr = TensorAddress.of(matchLabels); + TensorAddress expandAddr = TensorAddress.of(expandLabels); + result.match(matchAddr).expand(expandAddr).valueAt((int)ccDimIndex, cell.getValue()); + } + return result; + } + private final TensorFunction<NAMETYPE> argumentA, argumentB; private final String dimension; @@ -68,6 +299,16 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); + // if (a instanceof IndexedTensor && b instanceof IndexedTensor) { + // return oldEvaluate(a, b); + // } + Helper helper = new Helper(a.type(), b.type(), dimension); + SideInfo aInfo = analyse(a, helper.aHow); + SideInfo bInfo = analyse(b, helper.bHow); + return helper.mergeSides(aInfo, bInfo); + } + + private Tensor oldEvaluate(Tensor a, Tensor b) { TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension); a = ensureIndexedDimension(dimension, a, concatType.valueType()); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java index 0476fe1c757..fe7d3872d23 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java @@ -95,6 +95,128 @@ public class ConcatTestCase { assertConcat("tensor(x[],y[])", "tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y"); } + @Test + public void testAdvancedMixed() { + Tensor a = Tensor.from("tensor(a[2],b[2],c{},d[2],e{}):{"+ + "{a:0,b:0,c:17,d:0,e:42}:1.0,"+ + "{a:0,b:0,c:17,d:1,e:42}:2.0,"+ + "{a:0,b:1,c:17,d:0,e:42}:3.0,"+ + "{a:0,b:1,c:17,d:1,e:42}:4.0,"+ + "{a:1,b:0,c:17,d:0,e:42}:5.0,"+ + "{a:1,b:0,c:17,d:1,e:42}:6.0,"+ + "{a:1,b:1,c:17,d:0,e:42}:7.0,"+ + "{a:1,b:1,c:17,d:1,e:42}:8.0}"); + Tensor b = Tensor.from("tensor(a[2],b[2],c{},f[2],g{}):{"+ + "{a:0,b:0,c:17,f:0,g:666}:51.0,"+ + "{a:0,b:0,c:17,f:1,g:666}:52.0,"+ + "{a:0,b:1,c:17,f:0,g:666}:53.0,"+ + "{a:0,b:1,c:17,f:1,g:666}:54.0,"+ + "{a:1,b:0,c:17,f:0,g:666}:55.0,"+ + "{a:1,b:0,c:17,f:1,g:666}:56.0,"+ + "{a:1,b:1,c:17,f:0,g:666}:57.0,"+ + "{a:1,b:1,c:17,f:1,g:666}:58.0}"); + + assertConcat("tensor(a[4],b[2],c{},d[2],e{},f[2],g{})", + "tensor(a[4],b[2],c{},d[2],e{},f[2],g{}):{"+ + "{a:0,b:0,c:17,d:0,e:42,f:0,g:666}:1.0,"+ + "{a:0,b:0,c:17,d:0,e:42,f:1,g:666}:1.0,"+ + "{a:0,b:0,c:17,d:1,e:42,f:0,g:666}:2.0,"+ + "{a:0,b:0,c:17,d:1,e:42,f:1,g:666}:2.0,"+ + "{a:0,b:1,c:17,d:0,e:42,f:0,g:666}:3.0,"+ + "{a:0,b:1,c:17,d:0,e:42,f:1,g:666}:3.0,"+ + "{a:0,b:1,c:17,d:1,e:42,f:0,g:666}:4.0,"+ + "{a:0,b:1,c:17,d:1,e:42,f:1,g:666}:4.0,"+ + "{a:1,b:0,c:17,d:0,e:42,f:0,g:666}:5.0,"+ + "{a:1,b:0,c:17,d:0,e:42,f:1,g:666}:5.0,"+ + "{a:1,b:0,c:17,d:1,e:42,f:0,g:666}:6.0,"+ + "{a:1,b:0,c:17,d:1,e:42,f:1,g:666}:6.0,"+ + "{a:1,b:1,c:17,d:0,e:42,f:0,g:666}:7.0,"+ + "{a:1,b:1,c:17,d:0,e:42,f:1,g:666}:7.0,"+ + "{a:1,b:1,c:17,d:1,e:42,f:0,g:666}:8.0,"+ + "{a:1,b:1,c:17,d:1,e:42,f:1,g:666}:8.0,"+ + "{a:2,b:0,c:17,d:0,e:42,f:0,g:666}:51.0,"+ + "{a:2,b:0,c:17,d:0,e:42,f:1,g:666}:52.0,"+ + "{a:2,b:0,c:17,d:1,e:42,f:0,g:666}:51.0,"+ + "{a:2,b:0,c:17,d:1,e:42,f:1,g:666}:52.0,"+ + "{a:2,b:1,c:17,d:0,e:42,f:0,g:666}:53.0,"+ + "{a:2,b:1,c:17,d:0,e:42,f:1,g:666}:54.0,"+ + "{a:2,b:1,c:17,d:1,e:42,f:0,g:666}:53.0,"+ + "{a:2,b:1,c:17,d:1,e:42,f:1,g:666}:54.0,"+ + "{a:3,b:0,c:17,d:0,e:42,f:0,g:666}:55.0,"+ + "{a:3,b:0,c:17,d:0,e:42,f:1,g:666}:56.0,"+ + "{a:3,b:0,c:17,d:1,e:42,f:0,g:666}:55.0,"+ + "{a:3,b:0,c:17,d:1,e:42,f:1,g:666}:56.0,"+ + "{a:3,b:1,c:17,d:0,e:42,f:0,g:666}:57.0,"+ + "{a:3,b:1,c:17,d:0,e:42,f:1,g:666}:58.0,"+ + "{a:3,b:1,c:17,d:1,e:42,f:0,g:666}:57.0,"+ + "{a:3,b:1,c:17,d:1,e:42,f:1,g:666}:58.0}", + a, b, "a"); + + assertConcat("tensor(a[2],b[2],c{},d[2],e{},f[2],g{},x[2])", + "tensor(a[2],b[2],c{},d[2],e{},f[2],g{},x[2]):{"+ + "{a:0,b:0,c:17,d:0,e:42,f:0,g:666,x:0}:1.0,"+ + "{a:0,b:0,c:17,d:0,e:42,f:1,g:666,x:0}:1.0,"+ + "{a:0,b:0,c:17,d:1,e:42,f:0,g:666,x:0}:2.0,"+ + "{a:0,b:0,c:17,d:1,e:42,f:1,g:666,x:0}:2.0,"+ + "{a:0,b:1,c:17,d:0,e:42,f:0,g:666,x:0}:3.0,"+ + "{a:0,b:1,c:17,d:0,e:42,f:1,g:666,x:0}:3.0,"+ + "{a:0,b:1,c:17,d:1,e:42,f:0,g:666,x:0}:4.0,"+ + "{a:0,b:1,c:17,d:1,e:42,f:1,g:666,x:0}:4.0,"+ + "{a:1,b:0,c:17,d:0,e:42,f:0,g:666,x:0}:5.0,"+ + "{a:1,b:0,c:17,d:0,e:42,f:1,g:666,x:0}:5.0,"+ + "{a:1,b:0,c:17,d:1,e:42,f:0,g:666,x:0}:6.0,"+ + "{a:1,b:0,c:17,d:1,e:42,f:1,g:666,x:0}:6.0,"+ + "{a:1,b:1,c:17,d:0,e:42,f:0,g:666,x:0}:7.0,"+ + "{a:1,b:1,c:17,d:0,e:42,f:1,g:666,x:0}:7.0,"+ + "{a:1,b:1,c:17,d:1,e:42,f:0,g:666,x:0}:8.0,"+ + "{a:1,b:1,c:17,d:1,e:42,f:1,g:666,x:0}:8.0,"+ + "{a:0,b:0,c:17,d:0,e:42,f:0,g:666,x:1}:51.0,"+ + "{a:0,b:0,c:17,d:0,e:42,f:1,g:666,x:1}:52.0,"+ + "{a:0,b:0,c:17,d:1,e:42,f:0,g:666,x:1}:51.0,"+ + "{a:0,b:0,c:17,d:1,e:42,f:1,g:666,x:1}:52.0,"+ + "{a:0,b:1,c:17,d:0,e:42,f:0,g:666,x:1}:53.0,"+ + "{a:0,b:1,c:17,d:0,e:42,f:1,g:666,x:1}:54.0,"+ + "{a:0,b:1,c:17,d:1,e:42,f:0,g:666,x:1}:53.0,"+ + "{a:0,b:1,c:17,d:1,e:42,f:1,g:666,x:1}:54.0,"+ + "{a:1,b:0,c:17,d:0,e:42,f:0,g:666,x:1}:55.0,"+ + "{a:1,b:0,c:17,d:0,e:42,f:1,g:666,x:1}:56.0,"+ + "{a:1,b:0,c:17,d:1,e:42,f:0,g:666,x:1}:55.0,"+ + "{a:1,b:0,c:17,d:1,e:42,f:1,g:666,x:1}:56.0,"+ + "{a:1,b:1,c:17,d:0,e:42,f:0,g:666,x:1}:57.0,"+ + "{a:1,b:1,c:17,d:0,e:42,f:1,g:666,x:1}:58.0,"+ + "{a:1,b:1,c:17,d:1,e:42,f:0,g:666,x:1}:57.0,"+ + "{a:1,b:1,c:17,d:1,e:42,f:1,g:666,x:1}:58.0}", + a, b, "x"); + + assertConcat("tensor(a[2],b[2],c{},d[3],e{},f[2],g{})", + "tensor(a[2],b[2],c{},d[3],e{},f[2],g{}):{"+ + "{a:0,b:0,c:17,d:0,e:42,f:0,g:666}:1.0,"+ + "{a:0,b:0,c:17,d:0,e:42,f:1,g:666}:1.0,"+ + "{a:0,b:0,c:17,d:1,e:42,f:0,g:666}:2.0,"+ + "{a:0,b:0,c:17,d:1,e:42,f:1,g:666}:2.0,"+ + "{a:0,b:1,c:17,d:0,e:42,f:0,g:666}:3.0,"+ + "{a:0,b:1,c:17,d:0,e:42,f:1,g:666}:3.0,"+ + "{a:0,b:1,c:17,d:1,e:42,f:0,g:666}:4.0,"+ + "{a:0,b:1,c:17,d:1,e:42,f:1,g:666}:4.0,"+ + "{a:1,b:0,c:17,d:0,e:42,f:0,g:666}:5.0,"+ + "{a:1,b:0,c:17,d:0,e:42,f:1,g:666}:5.0,"+ + "{a:1,b:0,c:17,d:1,e:42,f:0,g:666}:6.0,"+ + "{a:1,b:0,c:17,d:1,e:42,f:1,g:666}:6.0,"+ + "{a:1,b:1,c:17,d:0,e:42,f:0,g:666}:7.0,"+ + "{a:1,b:1,c:17,d:0,e:42,f:1,g:666}:7.0,"+ + "{a:1,b:1,c:17,d:1,e:42,f:0,g:666}:8.0,"+ + "{a:1,b:1,c:17,d:1,e:42,f:1,g:666}:8.0,"+ + "{a:0,b:0,c:17,d:2,e:42,f:0,g:666}:51.0,"+ + "{a:0,b:0,c:17,d:2,e:42,f:1,g:666}:52.0,"+ + "{a:0,b:1,c:17,d:2,e:42,f:0,g:666}:53.0,"+ + "{a:0,b:1,c:17,d:2,e:42,f:1,g:666}:54.0,"+ + "{a:1,b:0,c:17,d:2,e:42,f:0,g:666}:55.0,"+ + "{a:1,b:0,c:17,d:2,e:42,f:1,g:666}:56.0,"+ + "{a:1,b:1,c:17,d:2,e:42,f:0,g:666}:57.0,"+ + "{a:1,b:1,c:17,d:2,e:42,f:1,g:666}:58.0}", + a, b, "d"); + } + private void assertConcat(String expected, Tensor a, Tensor b, String dimension) { assertConcat(null, expected, a, b, dimension); } |