diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java | 40 |
1 files changed, 25 insertions, 15 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java index 607c9a0ab44..bccd66acd31 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -6,7 +6,6 @@ import com.yahoo.tensor.PartialAddress; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; @@ -15,7 +14,7 @@ import java.util.Iterator; import java.util.List; import java.util.Objects; import java.util.Optional; -import java.util.function.Predicate; +import java.util.Set; import java.util.stream.Collectors; /** @@ -114,33 +113,44 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY return resultType(argument.type(context)); } - private List<String> findDimensions(List<TensorType.Dimension> dims, Predicate<TensorType.Dimension> pred) { - return dims.stream().filter(pred).map(TensorType.Dimension::name).collect(Collectors.toList()); - } - private TensorType resultType(TensorType argumentType) { - List<String> peekDimensions; + TensorType.Builder b = new TensorType.Builder(); // Special case where a single indexed or mapped dimension is sliced if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { if (subspaceAddress.get(0).index().isPresent()) { - peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isIndexed); - if (peekDimensions.size() > 1) { + if (argumentType.dimensions().stream().filter(d -> d.isIndexed()).count() > 1) throw new IllegalArgumentException(this + " slices a single indexed dimension, cannot be applied " + - "to " + argumentType + ", which has multiple"); + " to " + argumentType + ", which have multiple"); + for (TensorType.Dimension dimension : argumentType.dimensions()) { + if ( ! dimension.isIndexed()) + b.dimension(dimension); } } else { - peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isMapped); - if (peekDimensions.size() > 1) + if (argumentType.dimensions().stream().filter(d -> ! d.isIndexed()).count() > 1) throw new IllegalArgumentException(this + " slices a single mapped dimension, cannot be applied " + - "to " + argumentType + ", which has multiple"); + " to " + argumentType + ", which have multiple"); + for (TensorType.Dimension dimension : argumentType.dimensions()) { + if (dimension.isIndexed()) + b.dimension(dimension); + } + } } else { // general slicing - peekDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toList()); + Set<String> slicedDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toSet()); + for (TensorType.Dimension dimension : argumentType.dimensions()) { + if (slicedDimensions.contains(dimension.name())) + slicedDimensions.remove(dimension.name()); + else + b.dimension(dimension); + } + if ( ! slicedDimensions.isEmpty()) + throw new IllegalArgumentException(this + " slices " + slicedDimensions + " which are not present in " + + argumentType); } - return TypeResolver.peek(argumentType, peekDimensions); + return b.build(); } @Override |