summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-04-26 13:39:44 +0000
committerArne Juul <arnej@verizonmedia.com>2021-04-26 13:39:44 +0000
commita01159e186f7b37b1ca16be3daa80fa1567969d2 (patch)
tree825f5c75710d1d2e0df7247248dfacf438aa7678 /vespajlib
parentda3d03db65344e927b9682b30013a524cf914522 (diff)
rename some concepts
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java111
1 files changed, 53 insertions, 58 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 959f76d40da..1f33c53dd8e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -31,9 +31,9 @@ import java.util.stream.Collectors;
*/
public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
- static class CellInfo {
+ static class CellVector {
ArrayList<Double> values = new ArrayList<>();
- void valueAt(int ccDimIndex, double value) {
+ void setValue(int ccDimIndex, double value) {
while (values.size() <= ccDimIndex) {
values.add(0.0);
}
@@ -41,28 +41,28 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
}
- static class CellsInfo {
- Map<TensorAddress, CellInfo> map = new HashMap<>();
+ static class CellVectorMap {
+ Map<TensorAddress, CellVector> map = new HashMap<>();
- CellInfo expand(TensorAddress addr) {
+ CellVector lookupCreate(TensorAddress addr) {
if (map.containsKey(addr)) {
return map.get(addr);
} else {
- CellInfo result = new CellInfo();
+ CellVector result = new CellVector();
map.put(addr, result);
return result;
}
}
}
- static class SideInfo {
- Map<TensorAddress, CellsInfo> map = new HashMap<>();
+ static class CellVectorMapMap {
+ Map<TensorAddress, CellVectorMap> map = new HashMap<>();
- CellsInfo match(TensorAddress addr) {
+ CellVectorMap lookupCreate(TensorAddress addr) {
if (map.containsKey(addr)) {
return map.get(addr);
} else {
- CellsInfo result = new CellsInfo();
+ CellVectorMap result = new CellVectorMap();
map.put(addr, result);
return result;
}
@@ -83,12 +83,20 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
}
- static class Helper {
+ enum DimType { common, separate, concat }
+
+ static class SplitHow {
+ List<DimType> handleDims = new ArrayList<>();
+ long numCommon() { return handleDims.stream().filter(t -> (t == DimType.common)).count(); }
+ long numSeparate() { return handleDims.stream().filter(t -> (t == DimType.separate)).count(); }
+ }
+
+ static class ConcatPlan {
final TensorType resultType;
final String concatDimension;
- SideHow aHow = new SideHow();
- SideHow bHow = new SideHow();
+ SplitHow aHow = new SplitHow();
+ SplitHow bHow = new SplitHow();
enum CombineHow { left, right, both, concat }
@@ -114,8 +122,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
case concat:
labels[out++] = String.valueOf(concatDimIdx);
break;
- default:
- throw new IllegalArgumentException("cannot handle: "+how);
+ //default: throw new IllegalArgumentException("cannot handle: "+how);
}
}
return TensorAddress.of(labels);
@@ -123,41 +130,37 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
void aOnly(String dimName) {
if (dimName.equals(concatDimension)) {
- aHow.handleDims.add(DimHow.concat);
+ aHow.handleDims.add(DimType.concat);
combineHow.add(CombineHow.concat);
} else {
- aHow.handleDims.add(DimHow.expand);
- aHow.numExpand++;
+ aHow.handleDims.add(DimType.separate);
combineHow.add(CombineHow.left);
}
}
void bOnly(String dimName) {
if (dimName.equals(concatDimension)) {
- bHow.handleDims.add(DimHow.concat);
+ bHow.handleDims.add(DimType.concat);
combineHow.add(CombineHow.concat);
} else {
- bHow.handleDims.add(DimHow.expand);
- bHow.numExpand++;
+ bHow.handleDims.add(DimType.separate);
combineHow.add(CombineHow.right);
}
}
void bothAandB(String dimName) {
if (dimName.equals(concatDimension)) {
- aHow.handleDims.add(DimHow.concat);
- bHow.handleDims.add(DimHow.concat);
+ aHow.handleDims.add(DimType.concat);
+ bHow.handleDims.add(DimType.concat);
combineHow.add(CombineHow.concat);
} else {
- aHow.handleDims.add(DimHow.match);
- bHow.handleDims.add(DimHow.match);
- aHow.numMatch++;
- bHow.numMatch++;
+ aHow.handleDims.add(DimType.common);
+ bHow.handleDims.add(DimType.common);
combineHow.add(CombineHow.both);
}
}
- Helper(TensorType aType, TensorType bType, String concatDimension) {
+ ConcatPlan(TensorType aType, TensorType bType, String concatDimension) {
this.resultType = TypeResolver.concat(aType, bType, concatDimension);
this.concatDimension = concatDimension;
var aDims = aType.dimensions();
@@ -192,7 +195,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
}
- Tensor mergeSides(SideInfo a, SideInfo b) {
+ Tensor mergeSides(CellVectorMapMap a, CellVectorMapMap b) {
var builder = Tensor.Builder.of(resultType);
int aConcatSize = a.concatDimensionSize();
for (var entry : a.map.entrySet()) {
@@ -200,14 +203,14 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
if (b.map.containsKey(match)) {
var lhs = entry.getValue();
var rhs = b.map.get(match);
- lhs.map.forEach((leftExpand, leftCells) -> {
- rhs.map.forEach((rightExpand, rightCells) -> {
+ lhs.map.forEach((leftOnly, leftCells) -> {
+ rhs.map.forEach((rightOnly, rightCells) -> {
for (int i = 0; i < leftCells.values.size(); i++) {
- TensorAddress addr = combine(match, leftExpand, rightExpand, i);
+ TensorAddress addr = combine(match, leftOnly, rightOnly, i);
builder.cell(addr, leftCells.values.get(i));
}
for (int i = 0; i < rightCells.values.size(); i++) {
- TensorAddress addr = combine(match, leftExpand, rightExpand, i + aConcatSize);
+ TensorAddress addr = combine(match, leftOnly, rightOnly, i + aConcatSize);
builder.cell(addr, rightCells.values.get(i));
}
});
@@ -218,42 +221,34 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
}
- enum DimHow { match, expand, concat }
-
- static class SideHow {
- List<DimHow> handleDims = new ArrayList<>();
- int numMatch = 0;
- int numExpand = 0;
- }
-
- SideInfo analyse(Tensor side, SideHow how) {
+ CellVectorMapMap analyse(Tensor side, SplitHow how) {
var iter = side.cellIterator();
- String[] matchLabels = new String[how.numMatch];
- String[] expandLabels = new String[how.numExpand];
- SideInfo result = new SideInfo();
+ String[] commonLabels = new String[(int)how.numCommon()];
+ String[] separateLabels = new String[(int)how.numSeparate()];
+ CellVectorMapMap result = new CellVectorMapMap();
while (iter.hasNext()) {
var cell = iter.next();
var addr = cell.getKey();
long ccDimIndex = 0;
int matchIdx = 0;
- int expandIdx = 0;
+ int separateIdx = 0;
for (int i = 0; i < how.handleDims.size(); i++) {
switch (how.handleDims.get(i)) {
- case match:
- matchLabels[matchIdx++] = addr.label(i);
+ case common:
+ commonLabels[matchIdx++] = addr.label(i);
break;
- case expand:
- expandLabels[expandIdx++] = addr.label(i);
+ case separate:
+ separateLabels[separateIdx++] = addr.label(i);
break;
case concat:
ccDimIndex = addr.numericLabel(i);
break;
- default: throw new IllegalArgumentException("cannot handle: "+how.handleDims.get(i));
+ // default: throw new IllegalArgumentException("cannot handle: "+how.handleDims.get(i));
}
}
- TensorAddress matchAddr = TensorAddress.of(matchLabels);
- TensorAddress expandAddr = TensorAddress.of(expandLabels);
- result.match(matchAddr).expand(expandAddr).valueAt((int)ccDimIndex, cell.getValue());
+ TensorAddress commonAddr = TensorAddress.of(commonLabels);
+ TensorAddress separateAddr = TensorAddress.of(separateLabels);
+ result.lookupCreate(commonAddr).lookupCreate(separateAddr).setValue((int)ccDimIndex, cell.getValue());
}
return result;
}
@@ -302,10 +297,10 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
// if (a instanceof IndexedTensor && b instanceof IndexedTensor) {
// return oldEvaluate(a, b);
// }
- Helper helper = new Helper(a.type(), b.type(), dimension);
- SideInfo aInfo = analyse(a, helper.aHow);
- SideInfo bInfo = analyse(b, helper.bHow);
- return helper.mergeSides(aInfo, bInfo);
+ ConcatPlan plan = new ConcatPlan(a.type(), b.type(), dimension);
+ CellVectorMapMap aInfo = analyse(a, plan.aHow);
+ CellVectorMapMap bInfo = analyse(b, plan.bHow);
+ return plan.mergeSides(aInfo, bInfo);
}
private Tensor oldEvaluate(Tensor a, Tensor b) {