aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-10-08 11:25:43 +0200
committerLester Solbakken <lesters@oath.com>2018-10-08 11:25:43 +0200
commit6007a6fb18699de1bc1ffa7005825d487482b83e (patch)
tree0bcb89dcd8f5da0ea48f581374f10f30ef8e1691 /vespajlib
parentde3a914eb138ae8b6892e5aa7e0008c10cf667e7 (diff)
Add faster tensor rename if dimension after rename are in the same order
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java17
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java17
6 files changed, 64 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 2d127eb86cf..fb55b2d5014 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -162,6 +162,15 @@ public class IndexedTensor implements Tensor {
@Override
public TensorType type() { return type; }
+ @Override
+ public IndexedTensor withType(TensorType type) {
+ if (!this.type.isRenamableTo(type)) {
+ throw new IllegalArgumentException("IndexedTensor.withType: types are not compatible. Current type: '" +
+ this.type.toString() + "', requested type: '" + type.toString() + "'");
+ }
+ return new IndexedTensor(type, dimensionSizes, values);
+ }
+
public DimensionSizes dimensionSizes() {
return dimensionSizes;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index ef19ef2e96c..ec3020a1a4e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -42,6 +42,15 @@ public class MappedTensor implements Tensor {
public Map<TensorAddress, Double> cells() { return cells; }
@Override
+ public Tensor withType(TensorType other) {
+ if (!this.type.isRenamableTo(type)) {
+ throw new IllegalArgumentException("MappedTensor.withType: types are not compatible. Current type: '" +
+ this.type.toString() + "', requested type: '" + type.toString() + "'");
+ }
+ return new MappedTensor(other, cells);
+ }
+
+ @Override
public int hashCode() { return cells.hashCode(); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 5ff33aa340b..17e33c58a13 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -99,6 +99,15 @@ public class MixedTensor implements Tensor {
}
@Override
+ public Tensor withType(TensorType other) {
+ if (!this.type.isRenamableTo(type)) {
+ throw new IllegalArgumentException("MixedTensor.withType: types are not compatible. Current type: '" +
+ this.type.toString() + "', requested type: '" + type.toString() + "'");
+ }
+ return new MixedTensor(other, cells, index);
+ }
+
+ @Override
public int hashCode() { return cells.hashCode(); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 59d5ee72372..483ccd330e0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -86,6 +86,13 @@ public interface Tensor {
return valueIterator().next();
}
+ /**
+ * Returns this tensor with the given type if types are compatible
+ *
+ * @throws IllegalArgumentException if types are not compatible
+ */
+ Tensor withType(TensorType type);
+
// ----------------- Primitive tensor functions
default Tensor map(DoubleUnaryOperator mapper) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 1d447ed3eed..acba9eafd71 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -87,7 +87,7 @@ public class TensorType {
* i.e if the given type is a generalization of this type.
*/
public boolean isAssignableTo(TensorType generalization) {
- return isConvertibleOrAssignableTo(generalization, false);
+ return isConvertibleOrAssignableTo(generalization, false, true);
}
/**
@@ -98,16 +98,25 @@ public class TensorType {
* converted to the given type by zero padding.
*/
public boolean isConvertibleTo(TensorType generalization) {
- return isConvertibleOrAssignableTo(generalization, true);
+ return isConvertibleOrAssignableTo(generalization, true, true);
}
- private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible) {
+ /**
+ * Returns whether or not this type can simply be renamed to
+ * the given type. This is the same as being assignable, but disregarding
+ * dimension names.
+ */
+ public boolean isRenamableTo(TensorType other) {
+ return isConvertibleOrAssignableTo(other, false, false);
+ }
+
+ private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible, boolean considerName) {
if (generalization.dimensions().size() != this.dimensions().size()) return false;
for (int i = 0; i < generalization.dimensions().size(); i++) {
Dimension thisDimension = this.dimensions().get(i);
Dimension generalizationDimension = generalization.dimensions().get(i);
if (thisDimension.isIndexed() != generalizationDimension.isIndexed()) return false;
- if ( ! thisDimension.name().equals(generalizationDimension.name())) return false;
+ if (considerName && ! thisDimension.name().equals(generalizationDimension.name())) return false;
if (generalizationDimension.size().isPresent()) {
if ( ! thisDimension.size().isPresent()) return false;
if (convertible) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index 53d774de329..e18af235d59 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -95,6 +95,11 @@ public class Rename extends PrimitiveTensorFunction {
toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get();
}
+ // avoid building a new tensor if dimensions can simply be renamed
+ if (simpleRenameIsPossible(toIndexes)) {
+ return tensor.withType(renamedType);
+ }
+
Tensor.Builder builder = Tensor.Builder.of(renamedType);
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
@@ -104,6 +109,18 @@ public class Rename extends PrimitiveTensorFunction {
return builder.build();
}
+ /**
+ * If none of the dimensions change order after rename we can do a simple rename.
+ */
+ private boolean simpleRenameIsPossible(int[] toIndexes) {
+ for (int i = 0; i < toIndexes.length; ++i) {
+ if (toIndexes[i] != i) {
+ return false;
+ }
+ }
+ return true;
+ }
+
private TensorAddress rename(TensorAddress address, int[] toIndexes) {
String[] reorderedLabels = new String[toIndexes.length];
for (int i = 0; i < toIndexes.length; i++)