diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-12-13 15:21:44 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-12-13 15:21:44 +0100 |
commit | 3783a9b21f8ab7ca3700903d9780a9f7374cf0c5 (patch) | |
tree | ec003528946a37b9f0aeb49e1b314fdc6601c26e /vespajlib/src/main/java/com/yahoo/tensor/functions | |
parent | 5b67e6f8f641141f848ad3989156151f9f182441 (diff) |
Check agreement between TF and Vespa execution
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions')
14 files changed, 79 insertions, 81 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 8f4dbf014a7..191c7988443 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -8,7 +8,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext; /** * A composite tensor function is a tensor function which can be expressed (less tersely) * as a tree of primitive tensor functions. - * + * * @author bratseth */ @Beta diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 1dbb94fdb20..faa0ca36cb6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -15,7 +15,7 @@ import java.util.stream.Collectors; /** * Concatenation of two tensors along an (indexed) dimension - * + * * @author bratseth */ @Beta @@ -74,7 +74,7 @@ public class Concat extends PrimitiveTensorFunction { concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder); return builder.build(); } - + private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType, int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) { Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet()); @@ -112,7 +112,7 @@ public class Concat extends PrimitiveTensorFunction { Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build(); return tensor.multiply(unitTensor); } - + } /** Returns the type resulting from concatenating a and b */ @@ -144,7 +144,7 @@ public class Concat extends PrimitiveTensorFunction { /** * Combine two addresses, adding the offset to the concat dimension * - * @return the combined address or null if the addresses are incompatible + * @return the combined address or null if the addresses are incompatible * (in some other dimension than the concat dimension) */ private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, @@ -161,7 +161,7 @@ public class Concat extends PrimitiveTensorFunction { /** * Returns the an array having one entry in order for each dimension of fromType * containing the index at which toType contains the same dimension name. - * That is, if the returned array contains n at index i then + * That is, if the returned array contains n at index i then * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) * If some dimension in fromType is not present in toType, the corresponding index will be -1 */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 4ac7b21ba90..14ed38718ce 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -10,18 +10,18 @@ import java.util.List; /** * A function which returns a constant tensor. - * + * * @author bratseth */ @Beta public class ConstantTensor extends PrimitiveTensorFunction { private final Tensor constant; - + public ConstantTensor(String tensorString) { this.constant = Tensor.from(tensorString); } - + public ConstantTensor(Tensor tensor) { this.constant = tensor; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java index bbdbd5c3df1..c75d8ee4753 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -11,19 +11,19 @@ import java.util.stream.Stream; /** * A tensor generator which returns a tensor of any dimension filled with 1 in the diagonal and 0 elsewhere. - * + * * @author bratseth */ public class Diag extends CompositeTensorFunction { private final TensorType type; private final Function<List<Integer>, Double> diagFunction; - + public Diag(TensorType type) { this.type = type; this.diagFunction = ScalarFunctions.equal(dimensionNames().collect(Collectors.toList())); } - + @Override public List<TensorFunction> functionArguments() { return Collections.emptyList(); } @@ -43,7 +43,7 @@ public class Diag extends CompositeTensorFunction { public String toString(ToStringContext context) { return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction; } - + private Stream<String> dimensionNames() { return type.dimensions().stream().map(TensorType.Dimension::name); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 6ea73b7f310..e42d25197e2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -15,7 +15,7 @@ import java.util.function.Function; /** * An indexed tensor whose values are generated by a function - * + * * @author bratseth */ @Beta @@ -26,7 +26,7 @@ public class Generate extends PrimitiveTensorFunction { /** * Creates a generated tensor - * + * * @param type the type of the tensor * @param generator the function generating values from a list of ints specifying the indexes of the * tensor cell which will receive the value @@ -39,7 +39,7 @@ public class Generate extends PrimitiveTensorFunction { this.type = type; this.generator = generator; } - + private void validateType(TensorType type) { for (TensorType.Dimension dimension : type.dimensions()) if (dimension.type() != TensorType.Dimension.Type.indexedBound) @@ -58,7 +58,7 @@ public class Generate extends PrimitiveTensorFunction { @Override public PrimitiveTensorFunction toPrimitive() { return this; } - + @Override public Tensor evaluate(EvaluationContext context) { Tensor.Builder builder = Tensor.Builder.of(type); @@ -69,14 +69,14 @@ public class Generate extends PrimitiveTensorFunction { } return builder.build(); } - + private DimensionSizes dimensionSizes(TensorType type) { DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); for (int i = 0; i < b.dimensions(); i++) b.set(i, type.dimensions().get(i).size().get()); return b.build(); } - + @Override public String toString(ToStringContext context) { return type + "(" + generator + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 9a37127e1f0..ff887e3e9a6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -28,12 +28,12 @@ import java.util.function.DoubleBinaryOperator; * The <i>join</i> tensor operation produces a tensor from the argument tensors containing the set of cells * given by the cross product of the cells of the given tensors, having as values the value produced by * applying the given combinator function on the values from the two source cells. - * + * * @author bratseth */ @Beta public class Join extends PrimitiveTensorFunction { - + private final TensorFunction argumentA, argumentB; private final DoubleBinaryOperator combinator; @@ -56,7 +56,7 @@ public class Join extends PrimitiveTensorFunction { if (aDim.name().equals(bDim.name())) { // include if (aDim.isIndexed() && bDim.isIndexed()) { if (aDim.size().isPresent() || bDim.size().isPresent()) - typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE), + typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE), bDim.size().orElse(Integer.MAX_VALUE))); else typeBuilder.indexed(aDim.name()); @@ -112,11 +112,11 @@ public class Join extends PrimitiveTensorFunction { else return generalJoin(a, b, joinedType); } - + private boolean hasSingleIndexedDimension(Tensor tensor) { return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed(); } - + private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) { int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); Iterator<Double> aIterator = a.valueIterator(); @@ -138,7 +138,7 @@ public class Join extends PrimitiveTensorFunction { } return builder.build(); } - + /** Join a tensor into a superspace */ private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor) @@ -150,7 +150,7 @@ public class Join extends PrimitiveTensorFunction { private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build(); - + DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace); IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes); @@ -158,14 +158,14 @@ public class Join extends PrimitiveTensorFunction { // Find dimensions which are only in the supertype Set<String> superDimensionNames = new HashSet<>(superspace.type().dimensionNames()); superDimensionNames.removeAll(subspace.type().dimensionNames()); - + for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) { IndexedTensor.SubspaceIterator subspaceInSuper = i.next(); joinSubspaces(subspace.valueIterator(), subspace.size(), subspaceInSuper, subspaceInSuper.size(), reversedArgumentOrder, builder); } - + return builder.build(); } @@ -224,7 +224,7 @@ public class Join extends PrimitiveTensorFunction { subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get(); return subspaceIndexes; } - + private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) { String[] subspaceLabels = new String[subspaceIndexes.length]; for (int i = 0; i < subspaceIndexes.length; i++) @@ -259,7 +259,7 @@ public class Join extends PrimitiveTensorFunction { DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize); // for each combination of dimensions only in a - for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) { + for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) { IndexedTensor.SubspaceIterator aSubspace = ia.next(); // for each combination of dimensions in a which is also in b while (aSubspace.hasNext()) { @@ -276,7 +276,7 @@ public class Join extends PrimitiveTensorFunction { } } } - + private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) { PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size()); for (int i = 0; i < addressType.dimensions().size(); i++) @@ -284,7 +284,7 @@ public class Join extends PrimitiveTensorFunction { builder.add(addressType.dimensions().get(i).name(), address.intLabel(i)); return builder.build(); } - + /** Returns the sizes from the joined sizes which are present in the type argument */ private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) { DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size()); @@ -295,7 +295,7 @@ public class Join extends PrimitiveTensorFunction { } return builder.build(); } - + private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) { int[] aToIndexes = mapIndexes(a.type(), joinedType); int[] bToIndexes = mapIndexes(b.type(), joinedType); @@ -364,7 +364,7 @@ public class Join extends PrimitiveTensorFunction { /** * Returns the an array having one entry in order for each dimension of fromType * containing the index at which toType contains the same dimension name. - * That is, if the returned array contains n at index i then + * That is, if the returned array contains n at index i then * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) * If some dimension in fromType is not present in toType, the corresponding index will be -1 */ @@ -384,7 +384,7 @@ public class Join extends PrimitiveTensorFunction { return TensorAddress.of(joinedLabels); } - /** + /** * Maps the content in the given list to the given array, using the given index map. * * @return true if the mapping was successful, false if one of the destination positions was diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index d322a6ab497..a5e1a016a41 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -32,7 +32,7 @@ public class Map extends PrimitiveTensorFunction { this.argument = argument; this.mapper = mapper; } - + public static TensorType outputType(TensorType inputType) { return inputType; } public TensorFunction argument() { return argument; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index 5e102454487..4071917c2b5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -15,15 +15,15 @@ public class Matmul extends CompositeTensorFunction { private final TensorFunction argument1, argument2; private final String dimension; - + public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) { this.argument1 = argument1; this.argument2 = argument2; this.dimension = dimension; } - + public static TensorType outputType(TensorType a, TensorType b, String dimension) { - return Reduce.outputType(Join.outputType(a, b), ImmutableList.of(dimension)); + return Join.outputType(a, b); } @Override @@ -44,7 +44,7 @@ public class Matmul extends CompositeTensorFunction { Reduce.Aggregator.sum, dimension); } - + @Override public String toString(ToStringContext context) { return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java index efb7b9e500c..b7c9a5d2342 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java @@ -8,10 +8,10 @@ import com.yahoo.tensor.Tensor; * A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions. * All tensor implementations must implement all primitive tensor functions. * Primitive tensor functions are fully inspectable. - * + * * @author bratseth */ @Beta public abstract class PrimitiveTensorFunction extends TensorFunction { - + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java index 457763e97ba..958ef85d1dc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java @@ -22,11 +22,11 @@ import java.util.stream.Stream; public class Random extends CompositeTensorFunction { private final TensorType type; - + public Random(TensorType type) { this.type = type; } - + @Override public List<TensorFunction> functionArguments() { return Collections.emptyList(); } @@ -46,7 +46,7 @@ public class Random extends CompositeTensorFunction { public String toString(ToStringContext context) { return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")"; } - + private Stream<String> dimensionNames() { return type.dimensions().stream().map(TensorType.Dimension::toString); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java index e2b39a2048d..a56f82b026a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -12,19 +12,19 @@ import java.util.stream.Stream; /** * A tensor generator which returns a tensor of any dimension filled with the sum of the tensor * indexes of each position. - * + * * @author bratseth */ public class Range extends CompositeTensorFunction { private final TensorType type; private final Function<List<Integer>, Double> rangeFunction; - + public Range(TensorType type) { this.type = type; this.rangeFunction = ScalarFunctions.sum(dimensionNames().collect(Collectors.toList())); } - + @Override public List<TensorFunction> functionArguments() { return Collections.emptyList(); } @@ -44,7 +44,7 @@ public class Range extends CompositeTensorFunction { public String toString(ToStringContext context) { return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction; } - + private Stream<String> dimensionNames() { return type.dimensions().stream().map(TensorType.Dimension::toString); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index a51df12e522..de9f90a5804 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -19,7 +19,7 @@ import java.util.Objects; import java.util.Set; /** - * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions + * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions * are collapsed to a single value using an aggregator function. * * @author bratseth @@ -45,7 +45,7 @@ public class Reduce extends PrimitiveTensorFunction { /** * Creates a reduce function. - * + * * @param argument the tensor to reduce * @param aggregator the aggregator function to use * @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced, @@ -69,7 +69,7 @@ public class Reduce extends PrimitiveTensorFunction { } return b.build(); } - + public TensorFunction argument() { return argument; } @Override @@ -91,7 +91,7 @@ public class Reduce extends PrimitiveTensorFunction { public String toString(ToStringContext context) { return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")"; } - + private String commaSeparated(List<String> list) { StringBuilder b = new StringBuilder(); for (String element : list) @@ -103,7 +103,7 @@ public class Reduce extends PrimitiveTensorFunction { public Tensor evaluate(EvaluationContext context) { Tensor argument = this.argument.evaluate(context); if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) - throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + + throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + dimensions + ": Not all those dimensions are present in this tensor"); // Special case: Reduce all @@ -112,14 +112,14 @@ public class Reduce extends PrimitiveTensorFunction { return reduceIndexedVector((IndexedTensor)argument); else return reduceAllGeneral(argument); - + // Reduce type TensorType.Builder builder = new TensorType.Builder(); for (TensorType.Dimension dimension : argument.type().dimensions()) if ( ! dimensions.contains(dimension.name())) // keep builder.dimension(dimension); TensorType reducedType = builder.build(); - + // Reduce cells Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { @@ -131,10 +131,10 @@ public class Reduce extends PrimitiveTensorFunction { Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType); for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet()) reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue()); - + return reducedBuilder.build(); } - + private TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType) { Set<Integer> indexesToRemove = new HashSet<>(); for (String dimensionToRemove : this.dimensions) @@ -147,7 +147,7 @@ public class Reduce extends PrimitiveTensorFunction { reducedLabels[reducedLabelIndex++] = address.label(i); return TensorAddress.of(reducedLabels); } - + private Tensor reduceAllGeneral(Tensor argument) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); ) @@ -163,7 +163,7 @@ public class Reduce extends PrimitiveTensorFunction { } private static abstract class ValueAggregator { - + private static ValueAggregator ofType(Aggregator aggregator) { switch (aggregator) { case avg : return new AvgAggregator(); @@ -174,22 +174,22 @@ public class Reduce extends PrimitiveTensorFunction { case min : return new MinAggregator(); default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented"); } - + } /** Add a new value to those aggregated by this */ public abstract void aggregate(double value); - + /** Returns the value aggregated by this */ public abstract double aggregatedValue(); - + } - + private static class AvgAggregator extends ValueAggregator { private int valueCount = 0; private double valueSum = 0.0; - + @Override public void aggregate(double value) { valueCount++; @@ -197,7 +197,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public double aggregatedValue() { + public double aggregatedValue() { return valueSum / valueCount; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 6e52760424e..ec9b762a41c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -3,8 +3,6 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -19,7 +17,7 @@ import java.util.Objects; /** * The <i>rename</i> tensor function returns a tensor where some dimensions are assigned new names. - * + * * @author bratseth */ @Beta @@ -28,7 +26,7 @@ public class Rename extends PrimitiveTensorFunction { private final TensorFunction argument; private final List<String> fromDimensions; private final List<String> toDimensions; - + public Rename(TensorFunction argument, String fromDimension, String toDimension) { this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension)); } @@ -46,7 +44,7 @@ public class Rename extends PrimitiveTensorFunction { this.fromDimensions = ImmutableList.copyOf(fromDimensions); this.toDimensions = ImmutableList.copyOf(toDimensions); } - + @Override public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } @@ -66,7 +64,7 @@ public class Rename extends PrimitiveTensorFunction { Map<String, String> fromToMap = fromToMap(); TensorType renamedType = rename(tensor.type(), fromToMap); - + // an array which lists the index of each label in the renamed type int[] toIndexes = new int[tensor.type().dimensions().size()]; for (int i = 0; i < tensor.type().dimensions().size(); i++) { @@ -74,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction { String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName); toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get(); } - + Tensor.Builder builder = Tensor.Builder.of(renamedType); for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); @@ -90,7 +88,7 @@ public class Rename extends PrimitiveTensorFunction { builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); return builder.build(); } - + private TensorAddress rename(TensorAddress address, int[] toIndexes) { String[] reorderedLabels = new String[toIndexes.length]; for (int i = 0; i < toIndexes.length; i++) @@ -99,18 +97,18 @@ public class Rename extends PrimitiveTensorFunction { } @Override - public String toString(ToStringContext context) { - return "rename(" + argument.toString(context) + ", " + + public String toString(ToStringContext context) { + return "rename(" + argument.toString(context) + ", " + toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; } - + private Map<String, String> fromToMap() { Map<String, String> map = new HashMap<>(); for (int i = 0; i < fromDimensions.size(); i++) map.put(fromDimensions.get(i), toDimensions.get(i)); return map; } - + private String toVectorString(List<String> elements) { if (elements.size() == 1) return elements.get(0); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index cabcce198d1..533a46f87fe 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -12,7 +12,7 @@ import java.util.List; * A representation of a tensor function which is able to be translated to a set of primitive * tensor functions if necessary. * All tensor functions are immutable. - * + * * @author bratseth */ @Beta @@ -48,11 +48,11 @@ public abstract class TensorFunction { /** * Return a string representation of this context. - * + * * @param context a context which must be passed to all nexted functions when requesting the string value */ public abstract String toString(ToStringContext context); - + @Override public String toString() { return toString(ToStringContext.empty()); } |