summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-04-26 13:59:27 +0000
committerArne Juul <arnej@verizonmedia.com>2021-04-26 13:59:27 +0000
commit3fbb4234ed3397188869d1446b1e6d6aaf93eb33 (patch)
treef92ecb5ed26fa4ea6b9805ab432666453278d4e4 /vespajlib
parenta01159e186f7b37b1ca16be3daa80fa1567969d2 (diff)
move code around and rename more concepts
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java208
1 files changed, 104 insertions, 104 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 1f33c53dd8e..ad798b5c675 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -43,15 +43,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
static class CellVectorMap {
Map<TensorAddress, CellVector> map = new HashMap<>();
-
CellVector lookupCreate(TensorAddress addr) {
- if (map.containsKey(addr)) {
- return map.get(addr);
- } else {
- CellVector result = new CellVector();
- map.put(addr, result);
- return result;
- }
+ return map.computeIfAbsent(addr, k -> new CellVector());
}
}
@@ -59,28 +52,9 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
Map<TensorAddress, CellVectorMap> map = new HashMap<>();
CellVectorMap lookupCreate(TensorAddress addr) {
- if (map.containsKey(addr)) {
- return map.get(addr);
- } else {
- CellVectorMap result = new CellVectorMap();
- map.put(addr, result);
- return result;
- }
+ return map.computeIfAbsent(addr, k -> new CellVectorMap());
}
- int concatDimensionSize() {
- Set<Integer> sizes = new HashSet<>();
- map.forEach((m, cells) ->
- cells.map.forEach((e, cell) ->
- sizes.add(cell.values.size())));
- if (sizes.isEmpty()) {
- return 1;
- }
- if (sizes.size() == 1) {
- return sizes.iterator().next();
- }
- throw new IllegalArgumentException("inconsistent size of concat dimension, had "+sizes.size()+" different values");
- }
}
enum DimType { common, separate, concat }
@@ -92,70 +66,45 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
static class ConcatPlan {
+
final TensorType resultType;
final String concatDimension;
- SplitHow aHow = new SplitHow();
- SplitHow bHow = new SplitHow();
+ SplitHow splitInfoA = new SplitHow();
+ SplitHow splitInfoB = new SplitHow();
enum CombineHow { left, right, both, concat }
List<CombineHow> combineHow = new ArrayList<>();
- TensorAddress combine(TensorAddress match, TensorAddress leftOnly, TensorAddress rightOnly, int concatDimIdx) {
- String[] labels = new String[resultType.rank()];
- int out = 0;
- int m = 0;
- int a = 0;
- int b = 0;
- for (var how : combineHow) {
- switch (how) {
- case left:
- labels[out++] = leftOnly.label(a++);
- break;
- case right:
- labels[out++] = rightOnly.label(b++);
- break;
- case both:
- labels[out++] = match.label(m++);
- break;
- case concat:
- labels[out++] = String.valueOf(concatDimIdx);
- break;
- //default: throw new IllegalArgumentException("cannot handle: "+how);
- }
- }
- return TensorAddress.of(labels);
- }
-
void aOnly(String dimName) {
if (dimName.equals(concatDimension)) {
- aHow.handleDims.add(DimType.concat);
+ splitInfoA.handleDims.add(DimType.concat);
combineHow.add(CombineHow.concat);
} else {
- aHow.handleDims.add(DimType.separate);
+ splitInfoA.handleDims.add(DimType.separate);
combineHow.add(CombineHow.left);
}
}
void bOnly(String dimName) {
if (dimName.equals(concatDimension)) {
- bHow.handleDims.add(DimType.concat);
+ splitInfoB.handleDims.add(DimType.concat);
combineHow.add(CombineHow.concat);
} else {
- bHow.handleDims.add(DimType.separate);
+ splitInfoB.handleDims.add(DimType.separate);
combineHow.add(CombineHow.right);
}
}
void bothAandB(String dimName) {
if (dimName.equals(concatDimension)) {
- aHow.handleDims.add(DimType.concat);
- bHow.handleDims.add(DimType.concat);
+ splitInfoA.handleDims.add(DimType.concat);
+ splitInfoB.handleDims.add(DimType.concat);
combineHow.add(CombineHow.concat);
} else {
- aHow.handleDims.add(DimType.common);
- bHow.handleDims.add(DimType.common);
+ splitInfoA.handleDims.add(DimType.common);
+ splitInfoB.handleDims.add(DimType.common);
combineHow.add(CombineHow.both);
}
}
@@ -195,22 +144,75 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
}
- Tensor mergeSides(CellVectorMapMap a, CellVectorMapMap b) {
- var builder = Tensor.Builder.of(resultType);
- int aConcatSize = a.concatDimensionSize();
+ }
+
+ static class Helper {
+ ConcatPlan plan;
+ Tensor result;
+
+ Helper(Tensor a, Tensor b, String dimension) {
+ this.plan = new ConcatPlan(a.type(), b.type(), dimension);
+ CellVectorMapMap aData = decompose(a, plan.splitInfoA);
+ CellVectorMapMap bData = decompose(b, plan.splitInfoB);
+ this.result = merge(aData, bData);
+ }
+
+ static int concatDimensionSize(CellVectorMapMap data) {
+ Set<Integer> sizes = new HashSet<>();
+ data.map.forEach((m, cells) ->
+ cells.map.forEach((e, vector) ->
+ sizes.add(vector.values.size())));
+ if (sizes.isEmpty()) {
+ return 1;
+ }
+ if (sizes.size() == 1) {
+ return sizes.iterator().next();
+ }
+ throw new IllegalArgumentException("inconsistent size of concat dimension, had "+sizes.size()+" different values");
+ }
+
+ TensorAddress combine(TensorAddress match, TensorAddress leftOnly, TensorAddress rightOnly, int concatDimIdx) {
+ String[] labels = new String[plan.resultType.rank()];
+ int out = 0;
+ int m = 0;
+ int a = 0;
+ int b = 0;
+ for (var how : plan.combineHow) {
+ switch (how) {
+ case left:
+ labels[out++] = leftOnly.label(a++);
+ break;
+ case right:
+ labels[out++] = rightOnly.label(b++);
+ break;
+ case both:
+ labels[out++] = match.label(m++);
+ break;
+ case concat:
+ labels[out++] = String.valueOf(concatDimIdx);
+ break;
+ //default: throw new IllegalArgumentException("cannot handle: "+how);
+ }
+ }
+ return TensorAddress.of(labels);
+ }
+
+ Tensor merge(CellVectorMapMap a, CellVectorMapMap b) {
+ var builder = Tensor.Builder.of(plan.resultType);
+ int aConcatSize = concatDimensionSize(a);
for (var entry : a.map.entrySet()) {
- TensorAddress match = entry.getKey();
- if (b.map.containsKey(match)) {
+ TensorAddress common = entry.getKey();
+ if (b.map.containsKey(common)) {
var lhs = entry.getValue();
- var rhs = b.map.get(match);
+ var rhs = b.map.get(common);
lhs.map.forEach((leftOnly, leftCells) -> {
rhs.map.forEach((rightOnly, rightCells) -> {
for (int i = 0; i < leftCells.values.size(); i++) {
- TensorAddress addr = combine(match, leftOnly, rightOnly, i);
+ TensorAddress addr = combine(common, leftOnly, rightOnly, i);
builder.cell(addr, leftCells.values.get(i));
}
for (int i = 0; i < rightCells.values.size(); i++) {
- TensorAddress addr = combine(match, leftOnly, rightOnly, i + aConcatSize);
+ TensorAddress addr = combine(common, leftOnly, rightOnly, i + aConcatSize);
builder.cell(addr, rightCells.values.get(i));
}
});
@@ -219,38 +221,38 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
return builder.build();
}
- }
- CellVectorMapMap analyse(Tensor side, SplitHow how) {
- var iter = side.cellIterator();
- String[] commonLabels = new String[(int)how.numCommon()];
- String[] separateLabels = new String[(int)how.numSeparate()];
- CellVectorMapMap result = new CellVectorMapMap();
- while (iter.hasNext()) {
- var cell = iter.next();
- var addr = cell.getKey();
- long ccDimIndex = 0;
- int matchIdx = 0;
- int separateIdx = 0;
- for (int i = 0; i < how.handleDims.size(); i++) {
- switch (how.handleDims.get(i)) {
- case common:
- commonLabels[matchIdx++] = addr.label(i);
- break;
- case separate:
- separateLabels[separateIdx++] = addr.label(i);
- break;
- case concat:
- ccDimIndex = addr.numericLabel(i);
- break;
- // default: throw new IllegalArgumentException("cannot handle: "+how.handleDims.get(i));
+ CellVectorMapMap decompose(Tensor input, SplitHow how) {
+ var iter = input.cellIterator();
+ String[] commonLabels = new String[(int)how.numCommon()];
+ String[] separateLabels = new String[(int)how.numSeparate()];
+ CellVectorMapMap result = new CellVectorMapMap();
+ while (iter.hasNext()) {
+ var cell = iter.next();
+ var addr = cell.getKey();
+ long ccDimIndex = 0;
+ int commonIdx = 0;
+ int separateIdx = 0;
+ for (int i = 0; i < how.handleDims.size(); i++) {
+ switch (how.handleDims.get(i)) {
+ case common:
+ commonLabels[commonIdx++] = addr.label(i);
+ break;
+ case separate:
+ separateLabels[separateIdx++] = addr.label(i);
+ break;
+ case concat:
+ ccDimIndex = addr.numericLabel(i);
+ break;
+ // default: throw new IllegalArgumentException("cannot handle: "+how.handleDims.get(i));
+ }
}
+ TensorAddress commonAddr = TensorAddress.of(commonLabels);
+ TensorAddress separateAddr = TensorAddress.of(separateLabels);
+ result.lookupCreate(commonAddr).lookupCreate(separateAddr).setValue((int)ccDimIndex, cell.getValue());
}
- TensorAddress commonAddr = TensorAddress.of(commonLabels);
- TensorAddress separateAddr = TensorAddress.of(separateLabels);
- result.lookupCreate(commonAddr).lookupCreate(separateAddr).setValue((int)ccDimIndex, cell.getValue());
+ return result;
}
- return result;
}
private final TensorFunction<NAMETYPE> argumentA, argumentB;
@@ -297,10 +299,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
// if (a instanceof IndexedTensor && b instanceof IndexedTensor) {
// return oldEvaluate(a, b);
// }
- ConcatPlan plan = new ConcatPlan(a.type(), b.type(), dimension);
- CellVectorMapMap aInfo = analyse(a, plan.aHow);
- CellVectorMapMap bInfo = analyse(b, plan.bHow);
- return plan.mergeSides(aInfo, bInfo);
+ var helper = new Helper(a, b, dimension);
+ return helper.result;
}
private Tensor oldEvaluate(Tensor a, Tensor b) {