summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java46
1 files changed, 28 insertions, 18 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index ec9b762a41c..de3d2be265a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -7,6 +7,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.HashMap;
@@ -26,6 +27,7 @@ public class Rename extends PrimitiveTensorFunction {
private final TensorFunction argument;
private final List<String> fromDimensions;
private final List<String> toDimensions;
+ private final Map<String, String> fromToMap;
public Rename(TensorFunction argument, String fromDimension, String toDimension) {
this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension));
@@ -43,13 +45,24 @@ public class Rename extends PrimitiveTensorFunction {
this.argument = argument;
this.fromDimensions = ImmutableList.copyOf(fromDimensions);
this.toDimensions = ImmutableList.copyOf(toDimensions);
+ this.fromToMap = fromToMap(fromDimensions, toDimensions);
+ }
+
+ public List<String> fromDimensions() { return fromDimensions; }
+ public List<String> toDimensions() { return toDimensions; }
+
+ private static Map<String, String> fromToMap(List<String> fromDimensions, List<String> toDimensions) {
+ Map<String, String> map = new HashMap<>();
+ for (int i = 0; i < fromDimensions.size(); i++)
+ map.put(fromDimensions.get(i), toDimensions.get(i));
+ return map;
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Rename must have 1 argument, got " + arguments.size());
return new Rename(arguments.get(0), fromDimensions, toDimensions);
@@ -59,11 +72,22 @@ public class Rename extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
+ public TensorType type(TypeContext context) {
+ return type(argument.type(context));
+ }
+
+ private TensorType type(TensorType type) {
+ TensorType.Builder builder = new TensorType.Builder();
+ for (TensorType.Dimension dimension : type.dimensions())
+ builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
+ return builder.build();
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Tensor tensor = argument.evaluate(context);
- Map<String, String> fromToMap = fromToMap();
- TensorType renamedType = rename(tensor.type(), fromToMap);
+ TensorType renamedType = type(tensor.type());
// an array which lists the index of each label in the renamed type
int[] toIndexes = new int[tensor.type().dimensions().size()];
@@ -82,13 +106,6 @@ public class Rename extends PrimitiveTensorFunction {
return builder.build();
}
- private TensorType rename(TensorType type, Map<String, String> fromToMap) {
- TensorType.Builder builder = new TensorType.Builder();
- for (TensorType.Dimension dimension : type.dimensions())
- builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
- return builder.build();
- }
-
private TensorAddress rename(TensorAddress address, int[] toIndexes) {
String[] reorderedLabels = new String[toIndexes.length];
for (int i = 0; i < toIndexes.length; i++)
@@ -102,13 +119,6 @@ public class Rename extends PrimitiveTensorFunction {
toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
}
- private Map<String, String> fromToMap() {
- Map<String, String> map = new HashMap<>();
- for (int i = 0; i < fromDimensions.size(); i++)
- map.put(fromDimensions.get(i), toDimensions.get(i));
- return map;
- }
-
private String toVectorString(List<String> elements) {
if (elements.size() == 1)
return elements.get(0);