diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java | 24 |
1 files changed, 11 insertions, 13 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index ec9b762a41c..6b0daf1b49a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -3,6 +3,8 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -17,7 +19,7 @@ import java.util.Objects; /** * The <i>rename</i> tensor function returns a tensor where some dimensions are assigned new names. - * + * * @author bratseth */ @Beta @@ -27,10 +29,6 @@ public class Rename extends PrimitiveTensorFunction { private final List<String> fromDimensions; private final List<String> toDimensions; - public Rename(TensorFunction argument, String fromDimension, String toDimension) { - this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension)); - } - public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null"); @@ -44,7 +42,7 @@ public class Rename extends PrimitiveTensorFunction { this.fromDimensions = ImmutableList.copyOf(fromDimensions); this.toDimensions = ImmutableList.copyOf(toDimensions); } - + @Override public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } @@ -64,7 +62,7 @@ public class Rename extends PrimitiveTensorFunction { Map<String, String> fromToMap = fromToMap(); TensorType renamedType = rename(tensor.type(), fromToMap); - + // an array which lists the index of each label in the renamed type int[] toIndexes = new int[tensor.type().dimensions().size()]; for (int i = 0; i < tensor.type().dimensions().size(); i++) { @@ -72,7 +70,7 @@ public class Rename extends PrimitiveTensorFunction { String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName); toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get(); } - + Tensor.Builder builder = Tensor.Builder.of(renamedType); for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); @@ -88,7 +86,7 @@ public class Rename extends PrimitiveTensorFunction { builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); return builder.build(); } - + private TensorAddress rename(TensorAddress address, int[] toIndexes) { String[] reorderedLabels = new String[toIndexes.length]; for (int i = 0; i < toIndexes.length; i++) @@ -97,18 +95,18 @@ public class Rename extends PrimitiveTensorFunction { } @Override - public String toString(ToStringContext context) { - return "rename(" + argument.toString(context) + ", " + + public String toString(ToStringContext context) { + return "rename(" + argument.toString(context) + ", " + toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; } - + private Map<String, String> fromToMap() { Map<String, String> map = new HashMap<>(); for (int i = 0; i < fromDimensions.size(); i++) map.put(fromDimensions.get(i), toDimensions.get(i)); return map; } - + private String toVectorString(List<String> elements) { if (elements.size() == 1) return elements.get(0); |