diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-04-26 13:59:27 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-04-26 13:59:27 +0000 |
commit | 3fbb4234ed3397188869d1446b1e6d6aaf93eb33 (patch) | |
tree | f92ecb5ed26fa4ea6b9805ab432666453278d4e4 /vespajlib/src | |
parent | a01159e186f7b37b1ca16be3daa80fa1567969d2 (diff) |
move code around and rename more concepts
Diffstat (limited to 'vespajlib/src')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 208 |
1 files changed, 104 insertions, 104 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 1f33c53dd8e..ad798b5c675 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -43,15 +43,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET static class CellVectorMap { Map<TensorAddress, CellVector> map = new HashMap<>(); - CellVector lookupCreate(TensorAddress addr) { - if (map.containsKey(addr)) { - return map.get(addr); - } else { - CellVector result = new CellVector(); - map.put(addr, result); - return result; - } + return map.computeIfAbsent(addr, k -> new CellVector()); } } @@ -59,28 +52,9 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET Map<TensorAddress, CellVectorMap> map = new HashMap<>(); CellVectorMap lookupCreate(TensorAddress addr) { - if (map.containsKey(addr)) { - return map.get(addr); - } else { - CellVectorMap result = new CellVectorMap(); - map.put(addr, result); - return result; - } + return map.computeIfAbsent(addr, k -> new CellVectorMap()); } - int concatDimensionSize() { - Set<Integer> sizes = new HashSet<>(); - map.forEach((m, cells) -> - cells.map.forEach((e, cell) -> - sizes.add(cell.values.size()))); - if (sizes.isEmpty()) { - return 1; - } - if (sizes.size() == 1) { - return sizes.iterator().next(); - } - throw new IllegalArgumentException("inconsistent size of concat dimension, had "+sizes.size()+" different values"); - } } enum DimType { common, separate, concat } @@ -92,70 +66,45 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } static class ConcatPlan { + final TensorType resultType; final String concatDimension; - SplitHow aHow = new SplitHow(); - SplitHow bHow = new SplitHow(); + SplitHow splitInfoA = new SplitHow(); + SplitHow splitInfoB = new SplitHow(); enum CombineHow { left, right, both, concat } List<CombineHow> combineHow = new ArrayList<>(); - TensorAddress combine(TensorAddress match, TensorAddress leftOnly, TensorAddress rightOnly, int concatDimIdx) { - String[] labels = new String[resultType.rank()]; - int out = 0; - int m = 0; - int a = 0; - int b = 0; - for (var how : combineHow) { - switch (how) { - case left: - labels[out++] = leftOnly.label(a++); - break; - case right: - labels[out++] = rightOnly.label(b++); - break; - case both: - labels[out++] = match.label(m++); - break; - case concat: - labels[out++] = String.valueOf(concatDimIdx); - break; - //default: throw new IllegalArgumentException("cannot handle: "+how); - } - } - return TensorAddress.of(labels); - } - void aOnly(String dimName) { if (dimName.equals(concatDimension)) { - aHow.handleDims.add(DimType.concat); + splitInfoA.handleDims.add(DimType.concat); combineHow.add(CombineHow.concat); } else { - aHow.handleDims.add(DimType.separate); + splitInfoA.handleDims.add(DimType.separate); combineHow.add(CombineHow.left); } } void bOnly(String dimName) { if (dimName.equals(concatDimension)) { - bHow.handleDims.add(DimType.concat); + splitInfoB.handleDims.add(DimType.concat); combineHow.add(CombineHow.concat); } else { - bHow.handleDims.add(DimType.separate); + splitInfoB.handleDims.add(DimType.separate); combineHow.add(CombineHow.right); } } void bothAandB(String dimName) { if (dimName.equals(concatDimension)) { - aHow.handleDims.add(DimType.concat); - bHow.handleDims.add(DimType.concat); + splitInfoA.handleDims.add(DimType.concat); + splitInfoB.handleDims.add(DimType.concat); combineHow.add(CombineHow.concat); } else { - aHow.handleDims.add(DimType.common); - bHow.handleDims.add(DimType.common); + splitInfoA.handleDims.add(DimType.common); + splitInfoB.handleDims.add(DimType.common); combineHow.add(CombineHow.both); } } @@ -195,22 +144,75 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } } - Tensor mergeSides(CellVectorMapMap a, CellVectorMapMap b) { - var builder = Tensor.Builder.of(resultType); - int aConcatSize = a.concatDimensionSize(); + } + + static class Helper { + ConcatPlan plan; + Tensor result; + + Helper(Tensor a, Tensor b, String dimension) { + this.plan = new ConcatPlan(a.type(), b.type(), dimension); + CellVectorMapMap aData = decompose(a, plan.splitInfoA); + CellVectorMapMap bData = decompose(b, plan.splitInfoB); + this.result = merge(aData, bData); + } + + static int concatDimensionSize(CellVectorMapMap data) { + Set<Integer> sizes = new HashSet<>(); + data.map.forEach((m, cells) -> + cells.map.forEach((e, vector) -> + sizes.add(vector.values.size()))); + if (sizes.isEmpty()) { + return 1; + } + if (sizes.size() == 1) { + return sizes.iterator().next(); + } + throw new IllegalArgumentException("inconsistent size of concat dimension, had "+sizes.size()+" different values"); + } + + TensorAddress combine(TensorAddress match, TensorAddress leftOnly, TensorAddress rightOnly, int concatDimIdx) { + String[] labels = new String[plan.resultType.rank()]; + int out = 0; + int m = 0; + int a = 0; + int b = 0; + for (var how : plan.combineHow) { + switch (how) { + case left: + labels[out++] = leftOnly.label(a++); + break; + case right: + labels[out++] = rightOnly.label(b++); + break; + case both: + labels[out++] = match.label(m++); + break; + case concat: + labels[out++] = String.valueOf(concatDimIdx); + break; + //default: throw new IllegalArgumentException("cannot handle: "+how); + } + } + return TensorAddress.of(labels); + } + + Tensor merge(CellVectorMapMap a, CellVectorMapMap b) { + var builder = Tensor.Builder.of(plan.resultType); + int aConcatSize = concatDimensionSize(a); for (var entry : a.map.entrySet()) { - TensorAddress match = entry.getKey(); - if (b.map.containsKey(match)) { + TensorAddress common = entry.getKey(); + if (b.map.containsKey(common)) { var lhs = entry.getValue(); - var rhs = b.map.get(match); + var rhs = b.map.get(common); lhs.map.forEach((leftOnly, leftCells) -> { rhs.map.forEach((rightOnly, rightCells) -> { for (int i = 0; i < leftCells.values.size(); i++) { - TensorAddress addr = combine(match, leftOnly, rightOnly, i); + TensorAddress addr = combine(common, leftOnly, rightOnly, i); builder.cell(addr, leftCells.values.get(i)); } for (int i = 0; i < rightCells.values.size(); i++) { - TensorAddress addr = combine(match, leftOnly, rightOnly, i + aConcatSize); + TensorAddress addr = combine(common, leftOnly, rightOnly, i + aConcatSize); builder.cell(addr, rightCells.values.get(i)); } }); @@ -219,38 +221,38 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } return builder.build(); } - } - CellVectorMapMap analyse(Tensor side, SplitHow how) { - var iter = side.cellIterator(); - String[] commonLabels = new String[(int)how.numCommon()]; - String[] separateLabels = new String[(int)how.numSeparate()]; - CellVectorMapMap result = new CellVectorMapMap(); - while (iter.hasNext()) { - var cell = iter.next(); - var addr = cell.getKey(); - long ccDimIndex = 0; - int matchIdx = 0; - int separateIdx = 0; - for (int i = 0; i < how.handleDims.size(); i++) { - switch (how.handleDims.get(i)) { - case common: - commonLabels[matchIdx++] = addr.label(i); - break; - case separate: - separateLabels[separateIdx++] = addr.label(i); - break; - case concat: - ccDimIndex = addr.numericLabel(i); - break; - // default: throw new IllegalArgumentException("cannot handle: "+how.handleDims.get(i)); + CellVectorMapMap decompose(Tensor input, SplitHow how) { + var iter = input.cellIterator(); + String[] commonLabels = new String[(int)how.numCommon()]; + String[] separateLabels = new String[(int)how.numSeparate()]; + CellVectorMapMap result = new CellVectorMapMap(); + while (iter.hasNext()) { + var cell = iter.next(); + var addr = cell.getKey(); + long ccDimIndex = 0; + int commonIdx = 0; + int separateIdx = 0; + for (int i = 0; i < how.handleDims.size(); i++) { + switch (how.handleDims.get(i)) { + case common: + commonLabels[commonIdx++] = addr.label(i); + break; + case separate: + separateLabels[separateIdx++] = addr.label(i); + break; + case concat: + ccDimIndex = addr.numericLabel(i); + break; + // default: throw new IllegalArgumentException("cannot handle: "+how.handleDims.get(i)); + } } + TensorAddress commonAddr = TensorAddress.of(commonLabels); + TensorAddress separateAddr = TensorAddress.of(separateLabels); + result.lookupCreate(commonAddr).lookupCreate(separateAddr).setValue((int)ccDimIndex, cell.getValue()); } - TensorAddress commonAddr = TensorAddress.of(commonLabels); - TensorAddress separateAddr = TensorAddress.of(separateLabels); - result.lookupCreate(commonAddr).lookupCreate(separateAddr).setValue((int)ccDimIndex, cell.getValue()); + return result; } - return result; } private final TensorFunction<NAMETYPE> argumentA, argumentB; @@ -297,10 +299,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET // if (a instanceof IndexedTensor && b instanceof IndexedTensor) { // return oldEvaluate(a, b); // } - ConcatPlan plan = new ConcatPlan(a.type(), b.type(), dimension); - CellVectorMapMap aInfo = analyse(a, plan.aHow); - CellVectorMapMap bInfo = analyse(b, plan.bHow); - return plan.mergeSides(aInfo, bInfo); + var helper = new Helper(a, b, dimension); + return helper.result; } private Tensor oldEvaluate(Tensor a, Tensor b) { |