summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java24
1 files changed, 13 insertions, 11 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index 6b0daf1b49a..ec9b762a41c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -3,8 +3,6 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -19,7 +17,7 @@ import java.util.Objects;
/**
* The <i>rename</i> tensor function returns a tensor where some dimensions are assigned new names.
- *
+ *
* @author bratseth
*/
@Beta
@@ -29,6 +27,10 @@ public class Rename extends PrimitiveTensorFunction {
private final List<String> fromDimensions;
private final List<String> toDimensions;
+ public Rename(TensorFunction argument, String fromDimension, String toDimension) {
+ this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension));
+ }
+
public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) {
Objects.requireNonNull(argument, "The argument tensor cannot be null");
Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null");
@@ -42,7 +44,7 @@ public class Rename extends PrimitiveTensorFunction {
this.fromDimensions = ImmutableList.copyOf(fromDimensions);
this.toDimensions = ImmutableList.copyOf(toDimensions);
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
@@ -62,7 +64,7 @@ public class Rename extends PrimitiveTensorFunction {
Map<String, String> fromToMap = fromToMap();
TensorType renamedType = rename(tensor.type(), fromToMap);
-
+
// an array which lists the index of each label in the renamed type
int[] toIndexes = new int[tensor.type().dimensions().size()];
for (int i = 0; i < tensor.type().dimensions().size(); i++) {
@@ -70,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction {
String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName);
toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get();
}
-
+
Tensor.Builder builder = Tensor.Builder.of(renamedType);
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
@@ -86,7 +88,7 @@ public class Rename extends PrimitiveTensorFunction {
builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
return builder.build();
}
-
+
private TensorAddress rename(TensorAddress address, int[] toIndexes) {
String[] reorderedLabels = new String[toIndexes.length];
for (int i = 0; i < toIndexes.length; i++)
@@ -95,18 +97,18 @@ public class Rename extends PrimitiveTensorFunction {
}
@Override
- public String toString(ToStringContext context) {
- return "rename(" + argument.toString(context) + ", " +
+ public String toString(ToStringContext context) {
+ return "rename(" + argument.toString(context) + ", " +
toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
}
-
+
private Map<String, String> fromToMap() {
Map<String, String> map = new HashMap<>();
for (int i = 0; i < fromDimensions.size(); i++)
map.put(fromDimensions.get(i), toDimensions.get(i));
return map;
}
-
+
private String toVectorString(List<String> elements) {
if (elements.size() == 1)
return elements.get(0);