diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java | 46 |
1 files changed, 28 insertions, 18 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index ec9b762a41c..de3d2be265a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -7,6 +7,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.HashMap; @@ -26,6 +27,7 @@ public class Rename extends PrimitiveTensorFunction { private final TensorFunction argument; private final List<String> fromDimensions; private final List<String> toDimensions; + private final Map<String, String> fromToMap; public Rename(TensorFunction argument, String fromDimension, String toDimension) { this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension)); @@ -43,13 +45,24 @@ public class Rename extends PrimitiveTensorFunction { this.argument = argument; this.fromDimensions = ImmutableList.copyOf(fromDimensions); this.toDimensions = ImmutableList.copyOf(toDimensions); + this.fromToMap = fromToMap(fromDimensions, toDimensions); + } + + public List<String> fromDimensions() { return fromDimensions; } + public List<String> toDimensions() { return toDimensions; } + + private static Map<String, String> fromToMap(List<String> fromDimensions, List<String> toDimensions) { + Map<String, String> map = new HashMap<>(); + for (int i = 0; i < fromDimensions.size(); i++) + map.put(fromDimensions.get(i), toDimensions.get(i)); + return map; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Rename must have 1 argument, got " + arguments.size()); return new Rename(arguments.get(0), fromDimensions, toDimensions); @@ -59,11 +72,22 @@ public class Rename extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override + public TensorType type(TypeContext context) { + return type(argument.type(context)); + } + + private TensorType type(TensorType type) { + TensorType.Builder builder = new TensorType.Builder(); + for (TensorType.Dimension dimension : type.dimensions()) + builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); + return builder.build(); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor tensor = argument.evaluate(context); - Map<String, String> fromToMap = fromToMap(); - TensorType renamedType = rename(tensor.type(), fromToMap); + TensorType renamedType = type(tensor.type()); // an array which lists the index of each label in the renamed type int[] toIndexes = new int[tensor.type().dimensions().size()]; @@ -82,13 +106,6 @@ public class Rename extends PrimitiveTensorFunction { return builder.build(); } - private TensorType rename(TensorType type, Map<String, String> fromToMap) { - TensorType.Builder builder = new TensorType.Builder(); - for (TensorType.Dimension dimension : type.dimensions()) - builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); - return builder.build(); - } - private TensorAddress rename(TensorAddress address, int[] toIndexes) { String[] reorderedLabels = new String[toIndexes.length]; for (int i = 0; i < toIndexes.length; i++) @@ -102,13 +119,6 @@ public class Rename extends PrimitiveTensorFunction { toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; } - private Map<String, String> fromToMap() { - Map<String, String> map = new HashMap<>(); - for (int i = 0; i < fromDimensions.size(); i++) - map.put(fromDimensions.get(i), toDimensions.get(i)); - return map; - } - private String toVectorString(List<String> elements) { if (elements.size() == 1) return elements.get(0); |