diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-16 16:15:30 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-16 16:15:30 +0100 |
commit | 53966483ec9541d02e272572261ffd04fc6ed570 (patch) | |
tree | d5d3528af993255334c89a82e82242661f2ca31b /vespajlib/src/main/java/com/yahoo | |
parent | 7eec9171277f9e153cc2e0dc9be3e79ac8ab0512 (diff) |
Move to iterator access where possible
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo')
8 files changed, 130 insertions, 35 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 78e4b93b32e..263c41a6e13 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -8,8 +8,10 @@ import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; /** @@ -33,7 +35,17 @@ public class IndexedTensor implements Tensor { this.dimensionSizes = dimensionSizes; this.values = values; } - + + @Override + public int size() { + return values.length; + } + + @Override + public Iterator<Map.Entry<TensorAddress, Double>> cellIterator() { + return new ValueIterator(); + } + /** * Returns the value at the given indexes * @@ -58,7 +70,8 @@ public class IndexedTensor implements Tensor { } private static int toValueIndex(int[] indexes, int[] dimensionSizes) { - if (indexes.length == 0) return 0; + if (indexes.length == 1) return indexes[0]; // for speed + if (indexes.length == 0) return 0; // for speed int valueIndex = 0; for (int i = 0; i < indexes.length; i++) @@ -95,9 +108,8 @@ public class IndexedTensor implements Tensor { } @Override - // TODO: Replace this with iterator public Map<TensorAddress, Double> cells() { - if (dimensionSizes.length == 0) + if (dimensionSizes.length == 0) return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]); ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>(); @@ -105,15 +117,20 @@ public class IndexedTensor implements Tensor { for (int i = 0; i < values.length; i++) { builder.put(new TensorAddress(tensorIndexes), values[i]); if (i < values.length -1) - next(tensorIndexes.length - 1, tensorIndexes, dimensionSizes); + next(tensorIndexes); } return builder.build(); } - private void next(int dimension, int[] tensorIndexes, int[] dimensionSizes) { + private void next(int[] tensorIndexes) { + nextRecursively(tensorIndexes.length - 1, tensorIndexes); + } + + // TODO: Tail recursion -> loop + private void nextRecursively(int dimension, int[] tensorIndexes) { if (tensorIndexes[dimension] + 1 == dimensionSizes[dimension]) { tensorIndexes[dimension] = 0; - next(dimension - 1, tensorIndexes, dimensionSizes); + nextRecursively(dimension - 1, tensorIndexes); } else { tensorIndexes[dimension]++; @@ -365,5 +382,53 @@ public class IndexedTensor implements Tensor { } } + + private class ValueIterator implements Iterator<Map.Entry<TensorAddress, Double>> { + + private int cursor = 0; + private final int[] tensorIndexes = new int[dimensionSizes.length]; + + @Override + public boolean hasNext() { + return cursor < values.length; + } + + @Override + public Map.Entry<TensorAddress, Double> next() { + if ( ! hasNext()) throw new NoSuchElementException(); + + Map.Entry<TensorAddress, Double> current = new Cell(new TensorAddress(tensorIndexes), values[cursor]); + + cursor++; + if (hasNext()) + IndexedTensor.this.next(tensorIndexes); + + return current; + } + + private class Cell implements Map.Entry<TensorAddress, Double> { + + private final TensorAddress address; + private final Double value; + + private Cell(TensorAddress address, Double value) { + this.address = address; + this.value = value; + } + + @Override + public TensorAddress getKey() { return address; } + + @Override + public Double getValue() { return value; } + + @Override + public Double setValue(Double value) { + throw new UnsupportedOperationException("A tensor cannot be modified"); + } + + } + + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index 44451dc3f51..c3dce27c651 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -4,6 +4,7 @@ package com.yahoo.tensor; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableMap; +import java.util.Iterator; import java.util.Map; /** @@ -26,14 +27,20 @@ public class MappedTensor implements Tensor { @Override public TensorType type() { return type; } - + @Override - public Map<TensorAddress, Double> cells() { return cells; } + public int size() { return cells.size(); } @Override public double get(TensorAddress address) { return cells.getOrDefault(address, Double.NaN); } @Override + public Iterator<Map.Entry<TensorAddress, Double>> cellIterator() { return cells.entrySet().iterator(); } + + @Override + public Map<TensorAddress, Double> cells() { return cells; } + + @Override public int hashCode() { return cells.hashCode(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index eea9f61c9df..cc40f84ccd3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -15,6 +15,7 @@ import com.yahoo.tensor.functions.Softmax; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -43,14 +44,24 @@ import java.util.function.Function; @Beta public interface Tensor { + // ----------------- Accessors + TensorType type(); - /** Returns an immutable map of the cells of this */ - Map<TensorAddress, Double> cells(); + /** Returns whether this have any cells */ + default boolean isEmpty() { return size() == 0; } + + /** Returns the number of cells in this */ + int size(); /** Returns the value of a cell, or NaN if this cell does not exist/have no value */ double get(TensorAddress address); + Iterator<Map.Entry<TensorAddress, Double>> cellIterator(); + + /** Returns an immutable map of the cells of this. This may be expensive for some implementations - avoid when possible */ + Map<TensorAddress, Double> cells(); + /** * Returns the value of this as a double if it has no dimensions and one value * @@ -59,11 +70,10 @@ public interface Tensor { default double asDouble() { if (type().dimensions().size() > 0) throw new IllegalStateException("This tensor is not dimensionless. Dimensions: " + type().dimensions().size()); - Map<TensorAddress, Double> cells = cells(); - if (cells.size() == 0) return Double.NaN; - if (cells.size() > 1) - throw new IllegalStateException("This tensor does not have a single value, it has " + cells().size()); - return cells.values().iterator().next(); + if (size() == 0) return Double.NaN; + if (size() > 1) + throw new IllegalStateException("This tensor does not have a single value, it has " + size()); + return cellIterator().next().getValue(); } // ----------------- Primitive tensor functions @@ -163,7 +173,7 @@ public interface Tensor { * @return the tensor on the standard string format */ static String toStandardString(Tensor tensor) { - if (tensor.cells().isEmpty() && ! tensor.type().dimensions().isEmpty()) // explicitly output type TODO: Never do that? + if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty()) // explicitly output type TODO: Never do that? return tensor.type() + ":" + contentToString(tensor); else return contentToString(tensor); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index dbbcde0ad7b..a48b2b5ae4f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -81,7 +82,6 @@ public class Join extends PrimitiveTensorFunction { } private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) { - // TODO: Consider special-case builder for 1 dimension int joinedLength = Math.min(a.length(0), b.length(0)); IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new int[] { joinedLength}); for (int i = 0; i < joinedLength; i++) @@ -92,9 +92,10 @@ public class Join extends PrimitiveTensorFunction { /** When both tensors have the same dimensions, at most one cell matches a cell in the other tensor */ private Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType) { Tensor.Builder builder = Tensor.Builder.of(joinedType); - for (Map.Entry<TensorAddress, Double> aCell : a.cells().entrySet()) { - Double bCellValue = b.cells().get(aCell.getKey()); - if (bCellValue == null) continue; // no match + for (Iterator<Map.Entry<TensorAddress, Double>> i = a.cellIterator(); i.hasNext(); ) { + Map.Entry<TensorAddress, Double> aCell = i.next(); + double bCellValue = b.get(aCell.getKey()); + if (Double.isNaN(bCellValue)) continue; // no match builder.cell(aCell.getKey(), combinator.applyAsDouble(aCell.getValue(), bCellValue)); } return builder.build(); @@ -104,7 +105,8 @@ public class Join extends PrimitiveTensorFunction { private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type()); Tensor.Builder builder = Tensor.Builder.of(joinedType); - for (Map.Entry<TensorAddress, Double> supercell : superspace.cells().entrySet()) { + for (Iterator<Map.Entry<TensorAddress, Double>> i = superspace.cellIterator(); i.hasNext(); ) { + Map.Entry<TensorAddress, Double> supercell = i.next(); TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes); double subspaceValue = subspace.get(subaddress); if ( ! Double.isNaN(subspaceValue)) @@ -136,8 +138,10 @@ public class Join extends PrimitiveTensorFunction { int[] aToIndexes = mapIndexes(a.type(), joinedType); int[] bToIndexes = mapIndexes(b.type(), joinedType); Tensor.Builder builder = Tensor.Builder.of(joinedType); - for (Map.Entry<TensorAddress, Double> aCell : a.cells().entrySet()) { - for (Map.Entry<TensorAddress, Double> bCell : b.cells().entrySet()) { + for (Iterator<Map.Entry<TensorAddress, Double>> aIterator = a.cellIterator(); aIterator.hasNext(); ) { + Map.Entry<TensorAddress, Double> aCell = aIterator.next(); + for (Iterator<Map.Entry<TensorAddress, Double>> bIterator = b.cellIterator(); bIterator.hasNext(); ) { + Map.Entry<TensorAddress, Double> bCell = bIterator.next(); TensorAddress combinedAddress = combineAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType); if (combinedAddress == null) continue; // not combinable diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index 1f28b6a14c2..b5ca6d3ccfb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -8,6 +8,7 @@ import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.Objects; import java.util.function.DoubleUnaryOperator; @@ -52,8 +53,10 @@ public class Map extends PrimitiveTensorFunction { public Tensor evaluate(EvaluationContext context) { Tensor argument = argument().evaluate(context); Tensor.Builder builder = Tensor.Builder.of(argument.type()); - for (java.util.Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet()) + for (Iterator<java.util.Map.Entry<TensorAddress, Double>> i = argument.cellIterator(); i.hasNext(); ) { + java.util.Map.Entry<TensorAddress, Double> cell = i.next(); builder.cell(cell.getKey(), mapper.applyAsDouble(cell.getValue())); + } return builder.build(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 57b862534a1..96ab8bc6f60 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -13,6 +13,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -97,6 +98,7 @@ public class Reduce extends PrimitiveTensorFunction { throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + dimensions + ": Not all those dimensions are present in this tensor"); + // Special case: Reduce all if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size()) if (argument.type().dimensions().size() == 1 && argument instanceof IndexedTensor) return reduceIndexedVector((IndexedTensor)argument); @@ -112,7 +114,8 @@ public class Reduce extends PrimitiveTensorFunction { // Reduce cells Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(); - for (Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet()) { + for (Iterator<Map.Entry<TensorAddress, Double>> i = argument.cellIterator(); i.hasNext(); ) { + Map.Entry<TensorAddress, Double> cell = i.next(); TensorAddress reducedAddress = reduceDimensions(cell.getKey(), argument.type(), reducedType); aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator)); aggregatingCells.get(reducedAddress).aggregate(cell.getValue()); @@ -139,8 +142,8 @@ public class Reduce extends PrimitiveTensorFunction { private Tensor reduceAllGeneral(Tensor argument) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); - for (Double cellValue : argument.cells().values()) - valueAggregator.aggregate(cellValue); + for (Iterator<Map.Entry<TensorAddress, Double>> i = argument.cellIterator(); i.hasNext(); ) + valueAggregator.aggregate(i.next().getValue()); return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 39d555d1632..aabe38d1824 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -11,6 +11,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Collections; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -70,7 +71,8 @@ public class Rename extends PrimitiveTensorFunction { } Tensor.Builder builder = Tensor.Builder.of(renamedType); - for (Map.Entry<TensorAddress, Double> cell : tensor.cells().entrySet()) { + for (Iterator<Map.Entry<TensorAddress, Double>> i = tensor.cellIterator(); i.hasNext(); ) { + Map.Entry<TensorAddress, Double> cell = i.next(); TensorAddress renamedAddress = rename(cell.getKey(), toIndexes); builder.cell(renamedAddress, cell.getValue()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index 1dc35f20057..cb9a93fd233 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -28,7 +28,7 @@ class SparseBinaryFormat implements BinaryFormat { @Override public void encode(GrowableByteBuffer buffer, Tensor tensor) { encodeDimensions(buffer, tensor.type().dimensions()); - encodeCells(buffer, tensor.cells()); + encodeCells(buffer, tensor); } private static void encodeDimensions(GrowableByteBuffer buffer, List<TensorType.Dimension> sortedDimensions) { @@ -38,11 +38,12 @@ class SparseBinaryFormat implements BinaryFormat { } } - private static void encodeCells(GrowableByteBuffer buffer, Map<TensorAddress, Double> cells) { - buffer.putInt1_4Bytes(cells.size()); - for (Map.Entry<TensorAddress, Double> cellEntry : cells.entrySet()) { - encodeAddress(buffer, cellEntry.getKey()); - buffer.putDouble(cellEntry.getValue()); + private static void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { + buffer.putInt1_4Bytes(tensor.size()); + for (Iterator<Map.Entry<TensorAddress, Double>> i = tensor.cellIterator(); i.hasNext(); ) { + Map.Entry<TensorAddress, Double> cell = i.next(); + encodeAddress(buffer, cell.getKey()); + buffer.putDouble(cell.getValue()); } } |