summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-04-26 10:05:29 +0000
committerArne Juul <arnej@verizonmedia.com>2021-04-26 11:06:53 +0000
commitda3d03db65344e927b9682b30013a524cf914522 (patch)
treee6b25ab344dca5d2c70d3fa172da1c1f78bae10a /vespajlib
parentef6f81bb8db68b97972770fede264a97a5d5140d (diff)
add very generic concat algorithm
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java241
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java122
2 files changed, 363 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 59a452588ca..959f76d40da 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -13,8 +13,12 @@ import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
import java.util.Iterator;
+import java.util.ArrayList;
import java.util.List;
+import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
@@ -27,6 +31,233 @@ import java.util.stream.Collectors;
*/
public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
+ static class CellInfo {
+ ArrayList<Double> values = new ArrayList<>();
+ void valueAt(int ccDimIndex, double value) {
+ while (values.size() <= ccDimIndex) {
+ values.add(0.0);
+ }
+ values.set(ccDimIndex, value);
+ }
+ }
+
+ static class CellsInfo {
+ Map<TensorAddress, CellInfo> map = new HashMap<>();
+
+ CellInfo expand(TensorAddress addr) {
+ if (map.containsKey(addr)) {
+ return map.get(addr);
+ } else {
+ CellInfo result = new CellInfo();
+ map.put(addr, result);
+ return result;
+ }
+ }
+ }
+
+ static class SideInfo {
+ Map<TensorAddress, CellsInfo> map = new HashMap<>();
+
+ CellsInfo match(TensorAddress addr) {
+ if (map.containsKey(addr)) {
+ return map.get(addr);
+ } else {
+ CellsInfo result = new CellsInfo();
+ map.put(addr, result);
+ return result;
+ }
+ }
+
+ int concatDimensionSize() {
+ Set<Integer> sizes = new HashSet<>();
+ map.forEach((m, cells) ->
+ cells.map.forEach((e, cell) ->
+ sizes.add(cell.values.size())));
+ if (sizes.isEmpty()) {
+ return 1;
+ }
+ if (sizes.size() == 1) {
+ return sizes.iterator().next();
+ }
+ throw new IllegalArgumentException("inconsistent size of concat dimension, had "+sizes.size()+" different values");
+ }
+ }
+
+ static class Helper {
+ final TensorType resultType;
+ final String concatDimension;
+
+ SideHow aHow = new SideHow();
+ SideHow bHow = new SideHow();
+
+ enum CombineHow { left, right, both, concat }
+
+ List<CombineHow> combineHow = new ArrayList<>();
+
+ TensorAddress combine(TensorAddress match, TensorAddress leftOnly, TensorAddress rightOnly, int concatDimIdx) {
+ String[] labels = new String[resultType.rank()];
+ int out = 0;
+ int m = 0;
+ int a = 0;
+ int b = 0;
+ for (var how : combineHow) {
+ switch (how) {
+ case left:
+ labels[out++] = leftOnly.label(a++);
+ break;
+ case right:
+ labels[out++] = rightOnly.label(b++);
+ break;
+ case both:
+ labels[out++] = match.label(m++);
+ break;
+ case concat:
+ labels[out++] = String.valueOf(concatDimIdx);
+ break;
+ default:
+ throw new IllegalArgumentException("cannot handle: "+how);
+ }
+ }
+ return TensorAddress.of(labels);
+ }
+
+ void aOnly(String dimName) {
+ if (dimName.equals(concatDimension)) {
+ aHow.handleDims.add(DimHow.concat);
+ combineHow.add(CombineHow.concat);
+ } else {
+ aHow.handleDims.add(DimHow.expand);
+ aHow.numExpand++;
+ combineHow.add(CombineHow.left);
+ }
+ }
+
+ void bOnly(String dimName) {
+ if (dimName.equals(concatDimension)) {
+ bHow.handleDims.add(DimHow.concat);
+ combineHow.add(CombineHow.concat);
+ } else {
+ bHow.handleDims.add(DimHow.expand);
+ bHow.numExpand++;
+ combineHow.add(CombineHow.right);
+ }
+ }
+
+ void bothAandB(String dimName) {
+ if (dimName.equals(concatDimension)) {
+ aHow.handleDims.add(DimHow.concat);
+ bHow.handleDims.add(DimHow.concat);
+ combineHow.add(CombineHow.concat);
+ } else {
+ aHow.handleDims.add(DimHow.match);
+ bHow.handleDims.add(DimHow.match);
+ aHow.numMatch++;
+ bHow.numMatch++;
+ combineHow.add(CombineHow.both);
+ }
+ }
+
+ Helper(TensorType aType, TensorType bType, String concatDimension) {
+ this.resultType = TypeResolver.concat(aType, bType, concatDimension);
+ this.concatDimension = concatDimension;
+ var aDims = aType.dimensions();
+ var bDims = bType.dimensions();
+ int i = 0;
+ int j = 0;
+ while (i < aDims.size() && j < bDims.size()) {
+ String aName = aDims.get(i).name();
+ String bName = bDims.get(j).name();
+ int cmp = aName.compareTo(bName);
+ if (cmp == 0) {
+ bothAandB(aName);
+ ++i;
+ ++j;
+ } else if (cmp < 0) {
+ aOnly(aName);
+ ++i;
+ } else {
+ bOnly(bName);
+ ++j;
+ }
+ }
+ while (i < aDims.size()) {
+ aOnly(aDims.get(i++).name());
+ }
+ while (j < bDims.size()) {
+ bOnly(bDims.get(j++).name());
+ }
+ if (combineHow.size() < resultType.rank()) {
+ var idx = resultType.indexOfDimension(concatDimension);
+ combineHow.add(idx.get(), CombineHow.concat);
+ }
+ }
+
+ Tensor mergeSides(SideInfo a, SideInfo b) {
+ var builder = Tensor.Builder.of(resultType);
+ int aConcatSize = a.concatDimensionSize();
+ for (var entry : a.map.entrySet()) {
+ TensorAddress match = entry.getKey();
+ if (b.map.containsKey(match)) {
+ var lhs = entry.getValue();
+ var rhs = b.map.get(match);
+ lhs.map.forEach((leftExpand, leftCells) -> {
+ rhs.map.forEach((rightExpand, rightCells) -> {
+ for (int i = 0; i < leftCells.values.size(); i++) {
+ TensorAddress addr = combine(match, leftExpand, rightExpand, i);
+ builder.cell(addr, leftCells.values.get(i));
+ }
+ for (int i = 0; i < rightCells.values.size(); i++) {
+ TensorAddress addr = combine(match, leftExpand, rightExpand, i + aConcatSize);
+ builder.cell(addr, rightCells.values.get(i));
+ }
+ });
+ });
+ }
+ }
+ return builder.build();
+ }
+ }
+
+ enum DimHow { match, expand, concat }
+
+ static class SideHow {
+ List<DimHow> handleDims = new ArrayList<>();
+ int numMatch = 0;
+ int numExpand = 0;
+ }
+
+ SideInfo analyse(Tensor side, SideHow how) {
+ var iter = side.cellIterator();
+ String[] matchLabels = new String[how.numMatch];
+ String[] expandLabels = new String[how.numExpand];
+ SideInfo result = new SideInfo();
+ while (iter.hasNext()) {
+ var cell = iter.next();
+ var addr = cell.getKey();
+ long ccDimIndex = 0;
+ int matchIdx = 0;
+ int expandIdx = 0;
+ for (int i = 0; i < how.handleDims.size(); i++) {
+ switch (how.handleDims.get(i)) {
+ case match:
+ matchLabels[matchIdx++] = addr.label(i);
+ break;
+ case expand:
+ expandLabels[expandIdx++] = addr.label(i);
+ break;
+ case concat:
+ ccDimIndex = addr.numericLabel(i);
+ break;
+ default: throw new IllegalArgumentException("cannot handle: "+how.handleDims.get(i));
+ }
+ }
+ TensorAddress matchAddr = TensorAddress.of(matchLabels);
+ TensorAddress expandAddr = TensorAddress.of(expandLabels);
+ result.match(matchAddr).expand(expandAddr).valueAt((int)ccDimIndex, cell.getValue());
+ }
+ return result;
+ }
+
private final TensorFunction<NAMETYPE> argumentA, argumentB;
private final String dimension;
@@ -68,6 +299,16 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
+ // if (a instanceof IndexedTensor && b instanceof IndexedTensor) {
+ // return oldEvaluate(a, b);
+ // }
+ Helper helper = new Helper(a.type(), b.type(), dimension);
+ SideInfo aInfo = analyse(a, helper.aHow);
+ SideInfo bInfo = analyse(b, helper.bHow);
+ return helper.mergeSides(aInfo, bInfo);
+ }
+
+ private Tensor oldEvaluate(Tensor a, Tensor b) {
TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension);
a = ensureIndexedDimension(dimension, a, concatType.valueType());
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
index 0476fe1c757..fe7d3872d23 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
@@ -95,6 +95,128 @@ public class ConcatTestCase {
assertConcat("tensor(x[],y[])", "tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y");
}
+ @Test
+ public void testAdvancedMixed() {
+ Tensor a = Tensor.from("tensor(a[2],b[2],c{},d[2],e{}):{"+
+ "{a:0,b:0,c:17,d:0,e:42}:1.0,"+
+ "{a:0,b:0,c:17,d:1,e:42}:2.0,"+
+ "{a:0,b:1,c:17,d:0,e:42}:3.0,"+
+ "{a:0,b:1,c:17,d:1,e:42}:4.0,"+
+ "{a:1,b:0,c:17,d:0,e:42}:5.0,"+
+ "{a:1,b:0,c:17,d:1,e:42}:6.0,"+
+ "{a:1,b:1,c:17,d:0,e:42}:7.0,"+
+ "{a:1,b:1,c:17,d:1,e:42}:8.0}");
+ Tensor b = Tensor.from("tensor(a[2],b[2],c{},f[2],g{}):{"+
+ "{a:0,b:0,c:17,f:0,g:666}:51.0,"+
+ "{a:0,b:0,c:17,f:1,g:666}:52.0,"+
+ "{a:0,b:1,c:17,f:0,g:666}:53.0,"+
+ "{a:0,b:1,c:17,f:1,g:666}:54.0,"+
+ "{a:1,b:0,c:17,f:0,g:666}:55.0,"+
+ "{a:1,b:0,c:17,f:1,g:666}:56.0,"+
+ "{a:1,b:1,c:17,f:0,g:666}:57.0,"+
+ "{a:1,b:1,c:17,f:1,g:666}:58.0}");
+
+ assertConcat("tensor(a[4],b[2],c{},d[2],e{},f[2],g{})",
+ "tensor(a[4],b[2],c{},d[2],e{},f[2],g{}):{"+
+ "{a:0,b:0,c:17,d:0,e:42,f:0,g:666}:1.0,"+
+ "{a:0,b:0,c:17,d:0,e:42,f:1,g:666}:1.0,"+
+ "{a:0,b:0,c:17,d:1,e:42,f:0,g:666}:2.0,"+
+ "{a:0,b:0,c:17,d:1,e:42,f:1,g:666}:2.0,"+
+ "{a:0,b:1,c:17,d:0,e:42,f:0,g:666}:3.0,"+
+ "{a:0,b:1,c:17,d:0,e:42,f:1,g:666}:3.0,"+
+ "{a:0,b:1,c:17,d:1,e:42,f:0,g:666}:4.0,"+
+ "{a:0,b:1,c:17,d:1,e:42,f:1,g:666}:4.0,"+
+ "{a:1,b:0,c:17,d:0,e:42,f:0,g:666}:5.0,"+
+ "{a:1,b:0,c:17,d:0,e:42,f:1,g:666}:5.0,"+
+ "{a:1,b:0,c:17,d:1,e:42,f:0,g:666}:6.0,"+
+ "{a:1,b:0,c:17,d:1,e:42,f:1,g:666}:6.0,"+
+ "{a:1,b:1,c:17,d:0,e:42,f:0,g:666}:7.0,"+
+ "{a:1,b:1,c:17,d:0,e:42,f:1,g:666}:7.0,"+
+ "{a:1,b:1,c:17,d:1,e:42,f:0,g:666}:8.0,"+
+ "{a:1,b:1,c:17,d:1,e:42,f:1,g:666}:8.0,"+
+ "{a:2,b:0,c:17,d:0,e:42,f:0,g:666}:51.0,"+
+ "{a:2,b:0,c:17,d:0,e:42,f:1,g:666}:52.0,"+
+ "{a:2,b:0,c:17,d:1,e:42,f:0,g:666}:51.0,"+
+ "{a:2,b:0,c:17,d:1,e:42,f:1,g:666}:52.0,"+
+ "{a:2,b:1,c:17,d:0,e:42,f:0,g:666}:53.0,"+
+ "{a:2,b:1,c:17,d:0,e:42,f:1,g:666}:54.0,"+
+ "{a:2,b:1,c:17,d:1,e:42,f:0,g:666}:53.0,"+
+ "{a:2,b:1,c:17,d:1,e:42,f:1,g:666}:54.0,"+
+ "{a:3,b:0,c:17,d:0,e:42,f:0,g:666}:55.0,"+
+ "{a:3,b:0,c:17,d:0,e:42,f:1,g:666}:56.0,"+
+ "{a:3,b:0,c:17,d:1,e:42,f:0,g:666}:55.0,"+
+ "{a:3,b:0,c:17,d:1,e:42,f:1,g:666}:56.0,"+
+ "{a:3,b:1,c:17,d:0,e:42,f:0,g:666}:57.0,"+
+ "{a:3,b:1,c:17,d:0,e:42,f:1,g:666}:58.0,"+
+ "{a:3,b:1,c:17,d:1,e:42,f:0,g:666}:57.0,"+
+ "{a:3,b:1,c:17,d:1,e:42,f:1,g:666}:58.0}",
+ a, b, "a");
+
+ assertConcat("tensor(a[2],b[2],c{},d[2],e{},f[2],g{},x[2])",
+ "tensor(a[2],b[2],c{},d[2],e{},f[2],g{},x[2]):{"+
+ "{a:0,b:0,c:17,d:0,e:42,f:0,g:666,x:0}:1.0,"+
+ "{a:0,b:0,c:17,d:0,e:42,f:1,g:666,x:0}:1.0,"+
+ "{a:0,b:0,c:17,d:1,e:42,f:0,g:666,x:0}:2.0,"+
+ "{a:0,b:0,c:17,d:1,e:42,f:1,g:666,x:0}:2.0,"+
+ "{a:0,b:1,c:17,d:0,e:42,f:0,g:666,x:0}:3.0,"+
+ "{a:0,b:1,c:17,d:0,e:42,f:1,g:666,x:0}:3.0,"+
+ "{a:0,b:1,c:17,d:1,e:42,f:0,g:666,x:0}:4.0,"+
+ "{a:0,b:1,c:17,d:1,e:42,f:1,g:666,x:0}:4.0,"+
+ "{a:1,b:0,c:17,d:0,e:42,f:0,g:666,x:0}:5.0,"+
+ "{a:1,b:0,c:17,d:0,e:42,f:1,g:666,x:0}:5.0,"+
+ "{a:1,b:0,c:17,d:1,e:42,f:0,g:666,x:0}:6.0,"+
+ "{a:1,b:0,c:17,d:1,e:42,f:1,g:666,x:0}:6.0,"+
+ "{a:1,b:1,c:17,d:0,e:42,f:0,g:666,x:0}:7.0,"+
+ "{a:1,b:1,c:17,d:0,e:42,f:1,g:666,x:0}:7.0,"+
+ "{a:1,b:1,c:17,d:1,e:42,f:0,g:666,x:0}:8.0,"+
+ "{a:1,b:1,c:17,d:1,e:42,f:1,g:666,x:0}:8.0,"+
+ "{a:0,b:0,c:17,d:0,e:42,f:0,g:666,x:1}:51.0,"+
+ "{a:0,b:0,c:17,d:0,e:42,f:1,g:666,x:1}:52.0,"+
+ "{a:0,b:0,c:17,d:1,e:42,f:0,g:666,x:1}:51.0,"+
+ "{a:0,b:0,c:17,d:1,e:42,f:1,g:666,x:1}:52.0,"+
+ "{a:0,b:1,c:17,d:0,e:42,f:0,g:666,x:1}:53.0,"+
+ "{a:0,b:1,c:17,d:0,e:42,f:1,g:666,x:1}:54.0,"+
+ "{a:0,b:1,c:17,d:1,e:42,f:0,g:666,x:1}:53.0,"+
+ "{a:0,b:1,c:17,d:1,e:42,f:1,g:666,x:1}:54.0,"+
+ "{a:1,b:0,c:17,d:0,e:42,f:0,g:666,x:1}:55.0,"+
+ "{a:1,b:0,c:17,d:0,e:42,f:1,g:666,x:1}:56.0,"+
+ "{a:1,b:0,c:17,d:1,e:42,f:0,g:666,x:1}:55.0,"+
+ "{a:1,b:0,c:17,d:1,e:42,f:1,g:666,x:1}:56.0,"+
+ "{a:1,b:1,c:17,d:0,e:42,f:0,g:666,x:1}:57.0,"+
+ "{a:1,b:1,c:17,d:0,e:42,f:1,g:666,x:1}:58.0,"+
+ "{a:1,b:1,c:17,d:1,e:42,f:0,g:666,x:1}:57.0,"+
+ "{a:1,b:1,c:17,d:1,e:42,f:1,g:666,x:1}:58.0}",
+ a, b, "x");
+
+ assertConcat("tensor(a[2],b[2],c{},d[3],e{},f[2],g{})",
+ "tensor(a[2],b[2],c{},d[3],e{},f[2],g{}):{"+
+ "{a:0,b:0,c:17,d:0,e:42,f:0,g:666}:1.0,"+
+ "{a:0,b:0,c:17,d:0,e:42,f:1,g:666}:1.0,"+
+ "{a:0,b:0,c:17,d:1,e:42,f:0,g:666}:2.0,"+
+ "{a:0,b:0,c:17,d:1,e:42,f:1,g:666}:2.0,"+
+ "{a:0,b:1,c:17,d:0,e:42,f:0,g:666}:3.0,"+
+ "{a:0,b:1,c:17,d:0,e:42,f:1,g:666}:3.0,"+
+ "{a:0,b:1,c:17,d:1,e:42,f:0,g:666}:4.0,"+
+ "{a:0,b:1,c:17,d:1,e:42,f:1,g:666}:4.0,"+
+ "{a:1,b:0,c:17,d:0,e:42,f:0,g:666}:5.0,"+
+ "{a:1,b:0,c:17,d:0,e:42,f:1,g:666}:5.0,"+
+ "{a:1,b:0,c:17,d:1,e:42,f:0,g:666}:6.0,"+
+ "{a:1,b:0,c:17,d:1,e:42,f:1,g:666}:6.0,"+
+ "{a:1,b:1,c:17,d:0,e:42,f:0,g:666}:7.0,"+
+ "{a:1,b:1,c:17,d:0,e:42,f:1,g:666}:7.0,"+
+ "{a:1,b:1,c:17,d:1,e:42,f:0,g:666}:8.0,"+
+ "{a:1,b:1,c:17,d:1,e:42,f:1,g:666}:8.0,"+
+ "{a:0,b:0,c:17,d:2,e:42,f:0,g:666}:51.0,"+
+ "{a:0,b:0,c:17,d:2,e:42,f:1,g:666}:52.0,"+
+ "{a:0,b:1,c:17,d:2,e:42,f:0,g:666}:53.0,"+
+ "{a:0,b:1,c:17,d:2,e:42,f:1,g:666}:54.0,"+
+ "{a:1,b:0,c:17,d:2,e:42,f:0,g:666}:55.0,"+
+ "{a:1,b:0,c:17,d:2,e:42,f:1,g:666}:56.0,"+
+ "{a:1,b:1,c:17,d:2,e:42,f:0,g:666}:57.0,"+
+ "{a:1,b:1,c:17,d:2,e:42,f:1,g:666}:58.0}",
+ a, b, "d");
+ }
+
private void assertConcat(String expected, Tensor a, Tensor b, String dimension) {
assertConcat(null, expected, a, b, dimension);
}