diff options
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java | 40 |
1 files changed, 15 insertions, 25 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java index bccd66acd31..607c9a0ab44 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -6,6 +6,7 @@ import com.yahoo.tensor.PartialAddress; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; @@ -14,7 +15,7 @@ import java.util.Iterator; import java.util.List; import java.util.Objects; import java.util.Optional; -import java.util.Set; +import java.util.function.Predicate; import java.util.stream.Collectors; /** @@ -113,44 +114,33 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY return resultType(argument.type(context)); } + private List<String> findDimensions(List<TensorType.Dimension> dims, Predicate<TensorType.Dimension> pred) { + return dims.stream().filter(pred).map(TensorType.Dimension::name).collect(Collectors.toList()); + } + private TensorType resultType(TensorType argumentType) { - TensorType.Builder b = new TensorType.Builder(); + List<String> peekDimensions; // Special case where a single indexed or mapped dimension is sliced if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { if (subspaceAddress.get(0).index().isPresent()) { - if (argumentType.dimensions().stream().filter(d -> d.isIndexed()).count() > 1) + peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isIndexed); + if (peekDimensions.size() > 1) { throw new IllegalArgumentException(this + " slices a single indexed dimension, cannot be applied " + - " to " + argumentType + ", which have multiple"); - for (TensorType.Dimension dimension : argumentType.dimensions()) { - if ( ! dimension.isIndexed()) - b.dimension(dimension); + "to " + argumentType + ", which has multiple"); } } else { - if (argumentType.dimensions().stream().filter(d -> ! d.isIndexed()).count() > 1) + peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isMapped); + if (peekDimensions.size() > 1) throw new IllegalArgumentException(this + " slices a single mapped dimension, cannot be applied " + - " to " + argumentType + ", which have multiple"); - for (TensorType.Dimension dimension : argumentType.dimensions()) { - if (dimension.isIndexed()) - b.dimension(dimension); - } - + "to " + argumentType + ", which has multiple"); } } else { // general slicing - Set<String> slicedDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toSet()); - for (TensorType.Dimension dimension : argumentType.dimensions()) { - if (slicedDimensions.contains(dimension.name())) - slicedDimensions.remove(dimension.name()); - else - b.dimension(dimension); - } - if ( ! slicedDimensions.isEmpty()) - throw new IllegalArgumentException(this + " slices " + slicedDimensions + " which are not present in " + - argumentType); + peekDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toList()); } - return b.build(); + return TypeResolver.peek(argumentType, peekDimensions); } @Override |