diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-04-26 13:39:44 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-04-26 13:39:44 +0000 |
commit | a01159e186f7b37b1ca16be3daa80fa1567969d2 (patch) | |
tree | 825f5c75710d1d2e0df7247248dfacf438aa7678 /vespajlib | |
parent | da3d03db65344e927b9682b30013a524cf914522 (diff) |
rename some concepts
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 111 |
1 files changed, 53 insertions, 58 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 959f76d40da..1f33c53dd8e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -31,9 +31,9 @@ import java.util.stream.Collectors; */ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { - static class CellInfo { + static class CellVector { ArrayList<Double> values = new ArrayList<>(); - void valueAt(int ccDimIndex, double value) { + void setValue(int ccDimIndex, double value) { while (values.size() <= ccDimIndex) { values.add(0.0); } @@ -41,28 +41,28 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } } - static class CellsInfo { - Map<TensorAddress, CellInfo> map = new HashMap<>(); + static class CellVectorMap { + Map<TensorAddress, CellVector> map = new HashMap<>(); - CellInfo expand(TensorAddress addr) { + CellVector lookupCreate(TensorAddress addr) { if (map.containsKey(addr)) { return map.get(addr); } else { - CellInfo result = new CellInfo(); + CellVector result = new CellVector(); map.put(addr, result); return result; } } } - static class SideInfo { - Map<TensorAddress, CellsInfo> map = new HashMap<>(); + static class CellVectorMapMap { + Map<TensorAddress, CellVectorMap> map = new HashMap<>(); - CellsInfo match(TensorAddress addr) { + CellVectorMap lookupCreate(TensorAddress addr) { if (map.containsKey(addr)) { return map.get(addr); } else { - CellsInfo result = new CellsInfo(); + CellVectorMap result = new CellVectorMap(); map.put(addr, result); return result; } @@ -83,12 +83,20 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } } - static class Helper { + enum DimType { common, separate, concat } + + static class SplitHow { + List<DimType> handleDims = new ArrayList<>(); + long numCommon() { return handleDims.stream().filter(t -> (t == DimType.common)).count(); } + long numSeparate() { return handleDims.stream().filter(t -> (t == DimType.separate)).count(); } + } + + static class ConcatPlan { final TensorType resultType; final String concatDimension; - SideHow aHow = new SideHow(); - SideHow bHow = new SideHow(); + SplitHow aHow = new SplitHow(); + SplitHow bHow = new SplitHow(); enum CombineHow { left, right, both, concat } @@ -114,8 +122,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET case concat: labels[out++] = String.valueOf(concatDimIdx); break; - default: - throw new IllegalArgumentException("cannot handle: "+how); + //default: throw new IllegalArgumentException("cannot handle: "+how); } } return TensorAddress.of(labels); @@ -123,41 +130,37 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET void aOnly(String dimName) { if (dimName.equals(concatDimension)) { - aHow.handleDims.add(DimHow.concat); + aHow.handleDims.add(DimType.concat); combineHow.add(CombineHow.concat); } else { - aHow.handleDims.add(DimHow.expand); - aHow.numExpand++; + aHow.handleDims.add(DimType.separate); combineHow.add(CombineHow.left); } } void bOnly(String dimName) { if (dimName.equals(concatDimension)) { - bHow.handleDims.add(DimHow.concat); + bHow.handleDims.add(DimType.concat); combineHow.add(CombineHow.concat); } else { - bHow.handleDims.add(DimHow.expand); - bHow.numExpand++; + bHow.handleDims.add(DimType.separate); combineHow.add(CombineHow.right); } } void bothAandB(String dimName) { if (dimName.equals(concatDimension)) { - aHow.handleDims.add(DimHow.concat); - bHow.handleDims.add(DimHow.concat); + aHow.handleDims.add(DimType.concat); + bHow.handleDims.add(DimType.concat); combineHow.add(CombineHow.concat); } else { - aHow.handleDims.add(DimHow.match); - bHow.handleDims.add(DimHow.match); - aHow.numMatch++; - bHow.numMatch++; + aHow.handleDims.add(DimType.common); + bHow.handleDims.add(DimType.common); combineHow.add(CombineHow.both); } } - Helper(TensorType aType, TensorType bType, String concatDimension) { + ConcatPlan(TensorType aType, TensorType bType, String concatDimension) { this.resultType = TypeResolver.concat(aType, bType, concatDimension); this.concatDimension = concatDimension; var aDims = aType.dimensions(); @@ -192,7 +195,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } } - Tensor mergeSides(SideInfo a, SideInfo b) { + Tensor mergeSides(CellVectorMapMap a, CellVectorMapMap b) { var builder = Tensor.Builder.of(resultType); int aConcatSize = a.concatDimensionSize(); for (var entry : a.map.entrySet()) { @@ -200,14 +203,14 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET if (b.map.containsKey(match)) { var lhs = entry.getValue(); var rhs = b.map.get(match); - lhs.map.forEach((leftExpand, leftCells) -> { - rhs.map.forEach((rightExpand, rightCells) -> { + lhs.map.forEach((leftOnly, leftCells) -> { + rhs.map.forEach((rightOnly, rightCells) -> { for (int i = 0; i < leftCells.values.size(); i++) { - TensorAddress addr = combine(match, leftExpand, rightExpand, i); + TensorAddress addr = combine(match, leftOnly, rightOnly, i); builder.cell(addr, leftCells.values.get(i)); } for (int i = 0; i < rightCells.values.size(); i++) { - TensorAddress addr = combine(match, leftExpand, rightExpand, i + aConcatSize); + TensorAddress addr = combine(match, leftOnly, rightOnly, i + aConcatSize); builder.cell(addr, rightCells.values.get(i)); } }); @@ -218,42 +221,34 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } } - enum DimHow { match, expand, concat } - - static class SideHow { - List<DimHow> handleDims = new ArrayList<>(); - int numMatch = 0; - int numExpand = 0; - } - - SideInfo analyse(Tensor side, SideHow how) { + CellVectorMapMap analyse(Tensor side, SplitHow how) { var iter = side.cellIterator(); - String[] matchLabels = new String[how.numMatch]; - String[] expandLabels = new String[how.numExpand]; - SideInfo result = new SideInfo(); + String[] commonLabels = new String[(int)how.numCommon()]; + String[] separateLabels = new String[(int)how.numSeparate()]; + CellVectorMapMap result = new CellVectorMapMap(); while (iter.hasNext()) { var cell = iter.next(); var addr = cell.getKey(); long ccDimIndex = 0; int matchIdx = 0; - int expandIdx = 0; + int separateIdx = 0; for (int i = 0; i < how.handleDims.size(); i++) { switch (how.handleDims.get(i)) { - case match: - matchLabels[matchIdx++] = addr.label(i); + case common: + commonLabels[matchIdx++] = addr.label(i); break; - case expand: - expandLabels[expandIdx++] = addr.label(i); + case separate: + separateLabels[separateIdx++] = addr.label(i); break; case concat: ccDimIndex = addr.numericLabel(i); break; - default: throw new IllegalArgumentException("cannot handle: "+how.handleDims.get(i)); + // default: throw new IllegalArgumentException("cannot handle: "+how.handleDims.get(i)); } } - TensorAddress matchAddr = TensorAddress.of(matchLabels); - TensorAddress expandAddr = TensorAddress.of(expandLabels); - result.match(matchAddr).expand(expandAddr).valueAt((int)ccDimIndex, cell.getValue()); + TensorAddress commonAddr = TensorAddress.of(commonLabels); + TensorAddress separateAddr = TensorAddress.of(separateLabels); + result.lookupCreate(commonAddr).lookupCreate(separateAddr).setValue((int)ccDimIndex, cell.getValue()); } return result; } @@ -302,10 +297,10 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET // if (a instanceof IndexedTensor && b instanceof IndexedTensor) { // return oldEvaluate(a, b); // } - Helper helper = new Helper(a.type(), b.type(), dimension); - SideInfo aInfo = analyse(a, helper.aHow); - SideInfo bInfo = analyse(b, helper.bHow); - return helper.mergeSides(aInfo, bInfo); + ConcatPlan plan = new ConcatPlan(a.type(), b.type(), dimension); + CellVectorMapMap aInfo = analyse(a, plan.aHow); + CellVectorMapMap bInfo = analyse(b, plan.bHow); + return plan.mergeSides(aInfo, bInfo); } private Tensor oldEvaluate(Tensor a, Tensor b) { |