diff options
author | Lester Solbakken <lesters@oath.com> | 2018-10-08 11:25:43 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-10-08 11:25:43 +0200 |
commit | 6007a6fb18699de1bc1ffa7005825d487482b83e (patch) | |
tree | 0bcb89dcd8f5da0ea48f581374f10f30ef8e1691 /vespajlib | |
parent | de3a914eb138ae8b6892e5aa7e0008c10cf667e7 (diff) |
Add faster tensor rename if dimension after rename are in the same order
Diffstat (limited to 'vespajlib')
6 files changed, 64 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 2d127eb86cf..fb55b2d5014 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -162,6 +162,15 @@ public class IndexedTensor implements Tensor { @Override public TensorType type() { return type; } + @Override + public IndexedTensor withType(TensorType type) { + if (!this.type.isRenamableTo(type)) { + throw new IllegalArgumentException("IndexedTensor.withType: types are not compatible. Current type: '" + + this.type.toString() + "', requested type: '" + type.toString() + "'"); + } + return new IndexedTensor(type, dimensionSizes, values); + } + public DimensionSizes dimensionSizes() { return dimensionSizes; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index ef19ef2e96c..ec3020a1a4e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -42,6 +42,15 @@ public class MappedTensor implements Tensor { public Map<TensorAddress, Double> cells() { return cells; } @Override + public Tensor withType(TensorType other) { + if (!this.type.isRenamableTo(type)) { + throw new IllegalArgumentException("MappedTensor.withType: types are not compatible. Current type: '" + + this.type.toString() + "', requested type: '" + type.toString() + "'"); + } + return new MappedTensor(other, cells); + } + + @Override public int hashCode() { return cells.hashCode(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 5ff33aa340b..17e33c58a13 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -99,6 +99,15 @@ public class MixedTensor implements Tensor { } @Override + public Tensor withType(TensorType other) { + if (!this.type.isRenamableTo(type)) { + throw new IllegalArgumentException("MixedTensor.withType: types are not compatible. Current type: '" + + this.type.toString() + "', requested type: '" + type.toString() + "'"); + } + return new MixedTensor(other, cells, index); + } + + @Override public int hashCode() { return cells.hashCode(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 59d5ee72372..483ccd330e0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -86,6 +86,13 @@ public interface Tensor { return valueIterator().next(); } + /** + * Returns this tensor with the given type if types are compatible + * + * @throws IllegalArgumentException if types are not compatible + */ + Tensor withType(TensorType type); + // ----------------- Primitive tensor functions default Tensor map(DoubleUnaryOperator mapper) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 1d447ed3eed..acba9eafd71 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -87,7 +87,7 @@ public class TensorType { * i.e if the given type is a generalization of this type. */ public boolean isAssignableTo(TensorType generalization) { - return isConvertibleOrAssignableTo(generalization, false); + return isConvertibleOrAssignableTo(generalization, false, true); } /** @@ -98,16 +98,25 @@ public class TensorType { * converted to the given type by zero padding. */ public boolean isConvertibleTo(TensorType generalization) { - return isConvertibleOrAssignableTo(generalization, true); + return isConvertibleOrAssignableTo(generalization, true, true); } - private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible) { + /** + * Returns whether or not this type can simply be renamed to + * the given type. This is the same as being assignable, but disregarding + * dimension names. + */ + public boolean isRenamableTo(TensorType other) { + return isConvertibleOrAssignableTo(other, false, false); + } + + private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible, boolean considerName) { if (generalization.dimensions().size() != this.dimensions().size()) return false; for (int i = 0; i < generalization.dimensions().size(); i++) { Dimension thisDimension = this.dimensions().get(i); Dimension generalizationDimension = generalization.dimensions().get(i); if (thisDimension.isIndexed() != generalizationDimension.isIndexed()) return false; - if ( ! thisDimension.name().equals(generalizationDimension.name())) return false; + if (considerName && ! thisDimension.name().equals(generalizationDimension.name())) return false; if (generalizationDimension.size().isPresent()) { if ( ! thisDimension.size().isPresent()) return false; if (convertible) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 53d774de329..e18af235d59 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -95,6 +95,11 @@ public class Rename extends PrimitiveTensorFunction { toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get(); } + // avoid building a new tensor if dimensions can simply be renamed + if (simpleRenameIsPossible(toIndexes)) { + return tensor.withType(renamedType); + } + Tensor.Builder builder = Tensor.Builder.of(renamedType); for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); @@ -104,6 +109,18 @@ public class Rename extends PrimitiveTensorFunction { return builder.build(); } + /** + * If none of the dimensions change order after rename we can do a simple rename. + */ + private boolean simpleRenameIsPossible(int[] toIndexes) { + for (int i = 0; i < toIndexes.length; ++i) { + if (toIndexes[i] != i) { + return false; + } + } + return true; + } + private TensorAddress rename(TensorAddress address, int[] toIndexes) { String[] reorderedLabels = new String[toIndexes.length]; for (int i = 0; i < toIndexes.length; i++) |