aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-10-08 11:25:43 +0200
committerLester Solbakken <lesters@oath.com>2018-10-08 11:25:43 +0200
commit6007a6fb18699de1bc1ffa7005825d487482b83e (patch)
tree0bcb89dcd8f5da0ea48f581374f10f30ef8e1691 /vespajlib/src/main/java/com/yahoo/tensor/functions
parentde3a914eb138ae8b6892e5aa7e0008c10cf667e7 (diff)
Add faster tensor rename if dimension after rename are in the same order
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java17
1 files changed, 17 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index 53d774de329..e18af235d59 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -95,6 +95,11 @@ public class Rename extends PrimitiveTensorFunction {
toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get();
}
+ // avoid building a new tensor if dimensions can simply be renamed
+ if (simpleRenameIsPossible(toIndexes)) {
+ return tensor.withType(renamedType);
+ }
+
Tensor.Builder builder = Tensor.Builder.of(renamedType);
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
@@ -104,6 +109,18 @@ public class Rename extends PrimitiveTensorFunction {
return builder.build();
}
+ /**
+ * If none of the dimensions change order after rename we can do a simple rename.
+ */
+ private boolean simpleRenameIsPossible(int[] toIndexes) {
+ for (int i = 0; i < toIndexes.length; ++i) {
+ if (toIndexes[i] != i) {
+ return false;
+ }
+ }
+ return true;
+ }
+
private TensorAddress rename(TensorAddress address, int[] toIndexes) {
String[] reorderedLabels = new String[toIndexes.length];
for (int i = 0; i < toIndexes.length; i++)