diff options
author | gjoranv <gjoranv@gmail.com> | 2017-12-17 21:44:49 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-12-17 21:44:49 +0100 |
commit | 03bce1fe1a494f2ac9d4268d4c90b08011b3f600 (patch) | |
tree | 180f294d2ac97d641f0266216ffdc328db9bfef8 /vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | |
parent | b72e55b87eecae006ed92976151137a80d75be0f (diff) |
Revert "Bratseth/tensorflow models"
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | 41 |
1 files changed, 16 insertions, 25 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index de9f90a5804..cfc78be7e0c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -19,7 +19,7 @@ import java.util.Objects; import java.util.Set; /** - * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions + * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions * are collapsed to a single value using an aggregator function. * * @author bratseth @@ -45,7 +45,7 @@ public class Reduce extends PrimitiveTensorFunction { /** * Creates a reduce function. - * + * * @param argument the tensor to reduce * @param aggregator the aggregator function to use * @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced, @@ -61,15 +61,6 @@ public class Reduce extends PrimitiveTensorFunction { this.dimensions = ImmutableList.copyOf(dimensions); } - public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) { - TensorType.Builder b = new TensorType.Builder(); - for (TensorType.Dimension dimension : inputType.dimensions()) { - if ( ! reduceDimensions.contains(dimension.name())) - b.dimension(dimension); - } - return b.build(); - } - public TensorFunction argument() { return argument; } @Override @@ -91,7 +82,7 @@ public class Reduce extends PrimitiveTensorFunction { public String toString(ToStringContext context) { return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")"; } - + private String commaSeparated(List<String> list) { StringBuilder b = new StringBuilder(); for (String element : list) @@ -103,7 +94,7 @@ public class Reduce extends PrimitiveTensorFunction { public Tensor evaluate(EvaluationContext context) { Tensor argument = this.argument.evaluate(context); if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) - throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + + throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + dimensions + ": Not all those dimensions are present in this tensor"); // Special case: Reduce all @@ -112,14 +103,14 @@ public class Reduce extends PrimitiveTensorFunction { return reduceIndexedVector((IndexedTensor)argument); else return reduceAllGeneral(argument); - + // Reduce type TensorType.Builder builder = new TensorType.Builder(); for (TensorType.Dimension dimension : argument.type().dimensions()) if ( ! dimensions.contains(dimension.name())) // keep builder.dimension(dimension); TensorType reducedType = builder.build(); - + // Reduce cells Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { @@ -131,10 +122,10 @@ public class Reduce extends PrimitiveTensorFunction { Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType); for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet()) reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue()); - + return reducedBuilder.build(); } - + private TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType) { Set<Integer> indexesToRemove = new HashSet<>(); for (String dimensionToRemove : this.dimensions) @@ -147,7 +138,7 @@ public class Reduce extends PrimitiveTensorFunction { reducedLabels[reducedLabelIndex++] = address.label(i); return TensorAddress.of(reducedLabels); } - + private Tensor reduceAllGeneral(Tensor argument) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); ) @@ -163,7 +154,7 @@ public class Reduce extends PrimitiveTensorFunction { } private static abstract class ValueAggregator { - + private static ValueAggregator ofType(Aggregator aggregator) { switch (aggregator) { case avg : return new AvgAggregator(); @@ -174,22 +165,22 @@ public class Reduce extends PrimitiveTensorFunction { case min : return new MinAggregator(); default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented"); } - + } /** Add a new value to those aggregated by this */ public abstract void aggregate(double value); - + /** Returns the value aggregated by this */ public abstract double aggregatedValue(); - + } - + private static class AvgAggregator extends ValueAggregator { private int valueCount = 0; private double valueSum = 0.0; - + @Override public void aggregate(double value) { valueCount++; @@ -197,7 +188,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public double aggregatedValue() { + public double aggregatedValue() { return valueSum / valueCount; } |