aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-12-13 15:21:44 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-12-13 15:21:44 +0100
commit3783a9b21f8ab7ca3700903d9780a9f7374cf0c5 (patch)
treeec003528946a37b9f0aeb49e1b314fdc6601c26e /vespajlib/src/main/java/com/yahoo/tensor/functions
parent5b67e6f8f641141f848ad3989156151f9f182441 (diff)
Check agreement between TF and Vespa execution
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java32
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java34
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java22
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java6
14 files changed, 79 insertions, 81 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
index 8f4dbf014a7..191c7988443 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
@@ -8,7 +8,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext;
/**
* A composite tensor function is a tensor function which can be expressed (less tersely)
* as a tree of primitive tensor functions.
- *
+ *
* @author bratseth
*/
@Beta
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 1dbb94fdb20..faa0ca36cb6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -15,7 +15,7 @@ import java.util.stream.Collectors;
/**
* Concatenation of two tensors along an (indexed) dimension
- *
+ *
* @author bratseth
*/
@Beta
@@ -74,7 +74,7 @@ public class Concat extends PrimitiveTensorFunction {
concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);
return builder.build();
}
-
+
private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType,
int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
@@ -112,7 +112,7 @@ public class Concat extends PrimitiveTensorFunction {
Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build();
return tensor.multiply(unitTensor);
}
-
+
}
/** Returns the type resulting from concatenating a and b */
@@ -144,7 +144,7 @@ public class Concat extends PrimitiveTensorFunction {
/**
* Combine two addresses, adding the offset to the concat dimension
*
- * @return the combined address or null if the addresses are incompatible
+ * @return the combined address or null if the addresses are incompatible
* (in some other dimension than the concat dimension)
*/
private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
@@ -161,7 +161,7 @@ public class Concat extends PrimitiveTensorFunction {
/**
* Returns the an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
- * That is, if the returned array contains n at index i then
+ * That is, if the returned array contains n at index i then
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
* If some dimension in fromType is not present in toType, the corresponding index will be -1
*/
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index 4ac7b21ba90..14ed38718ce 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -10,18 +10,18 @@ import java.util.List;
/**
* A function which returns a constant tensor.
- *
+ *
* @author bratseth
*/
@Beta
public class ConstantTensor extends PrimitiveTensorFunction {
private final Tensor constant;
-
+
public ConstantTensor(String tensorString) {
this.constant = Tensor.from(tensorString);
}
-
+
public ConstantTensor(Tensor tensor) {
this.constant = tensor;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
index bbdbd5c3df1..c75d8ee4753 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -11,19 +11,19 @@ import java.util.stream.Stream;
/**
* A tensor generator which returns a tensor of any dimension filled with 1 in the diagonal and 0 elsewhere.
- *
+ *
* @author bratseth
*/
public class Diag extends CompositeTensorFunction {
private final TensorType type;
private final Function<List<Integer>, Double> diagFunction;
-
+
public Diag(TensorType type) {
this.type = type;
this.diagFunction = ScalarFunctions.equal(dimensionNames().collect(Collectors.toList()));
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -43,7 +43,7 @@ public class Diag extends CompositeTensorFunction {
public String toString(ToStringContext context) {
return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction;
}
-
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::name);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 6ea73b7f310..e42d25197e2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -15,7 +15,7 @@ import java.util.function.Function;
/**
* An indexed tensor whose values are generated by a function
- *
+ *
* @author bratseth
*/
@Beta
@@ -26,7 +26,7 @@ public class Generate extends PrimitiveTensorFunction {
/**
* Creates a generated tensor
- *
+ *
* @param type the type of the tensor
* @param generator the function generating values from a list of ints specifying the indexes of the
* tensor cell which will receive the value
@@ -39,7 +39,7 @@ public class Generate extends PrimitiveTensorFunction {
this.type = type;
this.generator = generator;
}
-
+
private void validateType(TensorType type) {
for (TensorType.Dimension dimension : type.dimensions())
if (dimension.type() != TensorType.Dimension.Type.indexedBound)
@@ -58,7 +58,7 @@ public class Generate extends PrimitiveTensorFunction {
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }
-
+
@Override
public Tensor evaluate(EvaluationContext context) {
Tensor.Builder builder = Tensor.Builder.of(type);
@@ -69,14 +69,14 @@ public class Generate extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
private DimensionSizes dimensionSizes(TensorType type) {
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
for (int i = 0; i < b.dimensions(); i++)
b.set(i, type.dimensions().get(i).size().get());
return b.build();
}
-
+
@Override
public String toString(ToStringContext context) { return type + "(" + generator + ")"; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 9a37127e1f0..ff887e3e9a6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -28,12 +28,12 @@ import java.util.function.DoubleBinaryOperator;
* The <i>join</i> tensor operation produces a tensor from the argument tensors containing the set of cells
* given by the cross product of the cells of the given tensors, having as values the value produced by
* applying the given combinator function on the values from the two source cells.
- *
+ *
* @author bratseth
*/
@Beta
public class Join extends PrimitiveTensorFunction {
-
+
private final TensorFunction argumentA, argumentB;
private final DoubleBinaryOperator combinator;
@@ -56,7 +56,7 @@ public class Join extends PrimitiveTensorFunction {
if (aDim.name().equals(bDim.name())) { // include
if (aDim.isIndexed() && bDim.isIndexed()) {
if (aDim.size().isPresent() || bDim.size().isPresent())
- typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE),
+ typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE),
bDim.size().orElse(Integer.MAX_VALUE)));
else
typeBuilder.indexed(aDim.name());
@@ -112,11 +112,11 @@ public class Join extends PrimitiveTensorFunction {
else
return generalJoin(a, b, joinedType);
}
-
+
private boolean hasSingleIndexedDimension(Tensor tensor) {
return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
}
-
+
private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Iterator<Double> aIterator = a.valueIterator();
@@ -138,7 +138,7 @@ public class Join extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
/** Join a tensor into a superspace */
private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor)
@@ -150,7 +150,7 @@ public class Join extends PrimitiveTensorFunction {
private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes
return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build();
-
+
DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace);
IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes);
@@ -158,14 +158,14 @@ public class Join extends PrimitiveTensorFunction {
// Find dimensions which are only in the supertype
Set<String> superDimensionNames = new HashSet<>(superspace.type().dimensionNames());
superDimensionNames.removeAll(subspace.type().dimensionNames());
-
+
for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) {
IndexedTensor.SubspaceIterator subspaceInSuper = i.next();
joinSubspaces(subspace.valueIterator(), subspace.size(),
subspaceInSuper, subspaceInSuper.size(),
reversedArgumentOrder, builder);
}
-
+
return builder.build();
}
@@ -224,7 +224,7 @@ public class Join extends PrimitiveTensorFunction {
subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get();
return subspaceIndexes;
}
-
+
private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
String[] subspaceLabels = new String[subspaceIndexes.length];
for (int i = 0; i < subspaceIndexes.length; i++)
@@ -259,7 +259,7 @@ public class Join extends PrimitiveTensorFunction {
DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);
// for each combination of dimensions only in a
- for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
+ for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
IndexedTensor.SubspaceIterator aSubspace = ia.next();
// for each combination of dimensions in a which is also in b
while (aSubspace.hasNext()) {
@@ -276,7 +276,7 @@ public class Join extends PrimitiveTensorFunction {
}
}
}
-
+
private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) {
PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
for (int i = 0; i < addressType.dimensions().size(); i++)
@@ -284,7 +284,7 @@ public class Join extends PrimitiveTensorFunction {
builder.add(addressType.dimensions().get(i).name(), address.intLabel(i));
return builder.build();
}
-
+
/** Returns the sizes from the joined sizes which are present in the type argument */
private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) {
DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
@@ -295,7 +295,7 @@ public class Join extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) {
int[] aToIndexes = mapIndexes(a.type(), joinedType);
int[] bToIndexes = mapIndexes(b.type(), joinedType);
@@ -364,7 +364,7 @@ public class Join extends PrimitiveTensorFunction {
/**
* Returns the an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
- * That is, if the returned array contains n at index i then
+ * That is, if the returned array contains n at index i then
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
* If some dimension in fromType is not present in toType, the corresponding index will be -1
*/
@@ -384,7 +384,7 @@ public class Join extends PrimitiveTensorFunction {
return TensorAddress.of(joinedLabels);
}
- /**
+ /**
* Maps the content in the given list to the given array, using the given index map.
*
* @return true if the mapping was successful, false if one of the destination positions was
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index d322a6ab497..a5e1a016a41 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -32,7 +32,7 @@ public class Map extends PrimitiveTensorFunction {
this.argument = argument;
this.mapper = mapper;
}
-
+
public static TensorType outputType(TensorType inputType) { return inputType; }
public TensorFunction argument() { return argument; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index 5e102454487..4071917c2b5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -15,15 +15,15 @@ public class Matmul extends CompositeTensorFunction {
private final TensorFunction argument1, argument2;
private final String dimension;
-
+
public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) {
this.argument1 = argument1;
this.argument2 = argument2;
this.dimension = dimension;
}
-
+
public static TensorType outputType(TensorType a, TensorType b, String dimension) {
- return Reduce.outputType(Join.outputType(a, b), ImmutableList.of(dimension));
+ return Join.outputType(a, b);
}
@Override
@@ -44,7 +44,7 @@ public class Matmul extends CompositeTensorFunction {
Reduce.Aggregator.sum,
dimension);
}
-
+
@Override
public String toString(ToStringContext context) {
return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
index efb7b9e500c..b7c9a5d2342 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
@@ -8,10 +8,10 @@ import com.yahoo.tensor.Tensor;
* A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions.
* All tensor implementations must implement all primitive tensor functions.
* Primitive tensor functions are fully inspectable.
- *
+ *
* @author bratseth
*/
@Beta
public abstract class PrimitiveTensorFunction extends TensorFunction {
-
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
index 457763e97ba..958ef85d1dc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
@@ -22,11 +22,11 @@ import java.util.stream.Stream;
public class Random extends CompositeTensorFunction {
private final TensorType type;
-
+
public Random(TensorType type) {
this.type = type;
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -46,7 +46,7 @@ public class Random extends CompositeTensorFunction {
public String toString(ToStringContext context) {
return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")";
}
-
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::toString);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
index e2b39a2048d..a56f82b026a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -12,19 +12,19 @@ import java.util.stream.Stream;
/**
* A tensor generator which returns a tensor of any dimension filled with the sum of the tensor
* indexes of each position.
- *
+ *
* @author bratseth
*/
public class Range extends CompositeTensorFunction {
private final TensorType type;
private final Function<List<Integer>, Double> rangeFunction;
-
+
public Range(TensorType type) {
this.type = type;
this.rangeFunction = ScalarFunctions.sum(dimensionNames().collect(Collectors.toList()));
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -44,7 +44,7 @@ public class Range extends CompositeTensorFunction {
public String toString(ToStringContext context) {
return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction;
}
-
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::toString);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index a51df12e522..de9f90a5804 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -19,7 +19,7 @@ import java.util.Objects;
import java.util.Set;
/**
- * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions
+ * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions
* are collapsed to a single value using an aggregator function.
*
* @author bratseth
@@ -45,7 +45,7 @@ public class Reduce extends PrimitiveTensorFunction {
/**
* Creates a reduce function.
- *
+ *
* @param argument the tensor to reduce
* @param aggregator the aggregator function to use
* @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced,
@@ -69,7 +69,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
return b.build();
}
-
+
public TensorFunction argument() { return argument; }
@Override
@@ -91,7 +91,7 @@ public class Reduce extends PrimitiveTensorFunction {
public String toString(ToStringContext context) {
return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
}
-
+
private String commaSeparated(List<String> list) {
StringBuilder b = new StringBuilder();
for (String element : list)
@@ -103,7 +103,7 @@ public class Reduce extends PrimitiveTensorFunction {
public Tensor evaluate(EvaluationContext context) {
Tensor argument = this.argument.evaluate(context);
if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
- throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
+ throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
dimensions + ": Not all those dimensions are present in this tensor");
// Special case: Reduce all
@@ -112,14 +112,14 @@ public class Reduce extends PrimitiveTensorFunction {
return reduceIndexedVector((IndexedTensor)argument);
else
return reduceAllGeneral(argument);
-
+
// Reduce type
TensorType.Builder builder = new TensorType.Builder();
for (TensorType.Dimension dimension : argument.type().dimensions())
if ( ! dimensions.contains(dimension.name())) // keep
builder.dimension(dimension);
TensorType reducedType = builder.build();
-
+
// Reduce cells
Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
@@ -131,10 +131,10 @@ public class Reduce extends PrimitiveTensorFunction {
Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType);
for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet())
reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
-
+
return reducedBuilder.build();
}
-
+
private TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType) {
Set<Integer> indexesToRemove = new HashSet<>();
for (String dimensionToRemove : this.dimensions)
@@ -147,7 +147,7 @@ public class Reduce extends PrimitiveTensorFunction {
reducedLabels[reducedLabelIndex++] = address.label(i);
return TensorAddress.of(reducedLabels);
}
-
+
private Tensor reduceAllGeneral(Tensor argument) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); )
@@ -163,7 +163,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
private static abstract class ValueAggregator {
-
+
private static ValueAggregator ofType(Aggregator aggregator) {
switch (aggregator) {
case avg : return new AvgAggregator();
@@ -174,22 +174,22 @@ public class Reduce extends PrimitiveTensorFunction {
case min : return new MinAggregator();
default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
}
-
+
}
/** Add a new value to those aggregated by this */
public abstract void aggregate(double value);
-
+
/** Returns the value aggregated by this */
public abstract double aggregatedValue();
-
+
}
-
+
private static class AvgAggregator extends ValueAggregator {
private int valueCount = 0;
private double valueSum = 0.0;
-
+
@Override
public void aggregate(double value) {
valueCount++;
@@ -197,7 +197,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
- public double aggregatedValue() {
+ public double aggregatedValue() {
return valueSum / valueCount;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index 6e52760424e..ec9b762a41c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -3,8 +3,6 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -19,7 +17,7 @@ import java.util.Objects;
/**
* The <i>rename</i> tensor function returns a tensor where some dimensions are assigned new names.
- *
+ *
* @author bratseth
*/
@Beta
@@ -28,7 +26,7 @@ public class Rename extends PrimitiveTensorFunction {
private final TensorFunction argument;
private final List<String> fromDimensions;
private final List<String> toDimensions;
-
+
public Rename(TensorFunction argument, String fromDimension, String toDimension) {
this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension));
}
@@ -46,7 +44,7 @@ public class Rename extends PrimitiveTensorFunction {
this.fromDimensions = ImmutableList.copyOf(fromDimensions);
this.toDimensions = ImmutableList.copyOf(toDimensions);
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
@@ -66,7 +64,7 @@ public class Rename extends PrimitiveTensorFunction {
Map<String, String> fromToMap = fromToMap();
TensorType renamedType = rename(tensor.type(), fromToMap);
-
+
// an array which lists the index of each label in the renamed type
int[] toIndexes = new int[tensor.type().dimensions().size()];
for (int i = 0; i < tensor.type().dimensions().size(); i++) {
@@ -74,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction {
String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName);
toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get();
}
-
+
Tensor.Builder builder = Tensor.Builder.of(renamedType);
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
@@ -90,7 +88,7 @@ public class Rename extends PrimitiveTensorFunction {
builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
return builder.build();
}
-
+
private TensorAddress rename(TensorAddress address, int[] toIndexes) {
String[] reorderedLabels = new String[toIndexes.length];
for (int i = 0; i < toIndexes.length; i++)
@@ -99,18 +97,18 @@ public class Rename extends PrimitiveTensorFunction {
}
@Override
- public String toString(ToStringContext context) {
- return "rename(" + argument.toString(context) + ", " +
+ public String toString(ToStringContext context) {
+ return "rename(" + argument.toString(context) + ", " +
toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
}
-
+
private Map<String, String> fromToMap() {
Map<String, String> map = new HashMap<>();
for (int i = 0; i < fromDimensions.size(); i++)
map.put(fromDimensions.get(i), toDimensions.get(i));
return map;
}
-
+
private String toVectorString(List<String> elements) {
if (elements.size() == 1)
return elements.get(0);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index cabcce198d1..533a46f87fe 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -12,7 +12,7 @@ import java.util.List;
* A representation of a tensor function which is able to be translated to a set of primitive
* tensor functions if necessary.
* All tensor functions are immutable.
- *
+ *
* @author bratseth
*/
@Beta
@@ -48,11 +48,11 @@ public abstract class TensorFunction {
/**
* Return a string representation of this context.
- *
+ *
* @param context a context which must be passed to all nexted functions when requesting the string value
*/
public abstract String toString(ToStringContext context);
-
+
@Override
public String toString() { return toString(ToStringContext.empty()); }