summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-04-22 14:05:50 +0200
committerLester Solbakken <lesters@oath.com>2021-04-22 14:05:50 +0200
commit20af8c0439bcd3a98cd17144b330e71997be09a4 (patch)
tree9edf78e162be450a7e7d01c7e84e99a9d19ebb6c /vespajlib
parenta242591a9328fa21959ca76c08e616b1f1c682d7 (diff)
Wire in tensor cell type resolving for slice/peek in Java
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java40
1 files changed, 15 insertions, 25 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
index bccd66acd31..607c9a0ab44 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -6,6 +6,7 @@ import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -14,7 +15,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
-import java.util.Set;
+import java.util.function.Predicate;
import java.util.stream.Collectors;
/**
@@ -113,44 +114,33 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
return resultType(argument.type(context));
}
+ private List<String> findDimensions(List<TensorType.Dimension> dims, Predicate<TensorType.Dimension> pred) {
+ return dims.stream().filter(pred).map(TensorType.Dimension::name).collect(Collectors.toList());
+ }
+
private TensorType resultType(TensorType argumentType) {
- TensorType.Builder b = new TensorType.Builder();
+ List<String> peekDimensions;
// Special case where a single indexed or mapped dimension is sliced
if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) {
if (subspaceAddress.get(0).index().isPresent()) {
- if (argumentType.dimensions().stream().filter(d -> d.isIndexed()).count() > 1)
+ peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isIndexed);
+ if (peekDimensions.size() > 1) {
throw new IllegalArgumentException(this + " slices a single indexed dimension, cannot be applied " +
- " to " + argumentType + ", which have multiple");
- for (TensorType.Dimension dimension : argumentType.dimensions()) {
- if ( ! dimension.isIndexed())
- b.dimension(dimension);
+ "to " + argumentType + ", which has multiple");
}
}
else {
- if (argumentType.dimensions().stream().filter(d -> ! d.isIndexed()).count() > 1)
+ peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isMapped);
+ if (peekDimensions.size() > 1)
throw new IllegalArgumentException(this + " slices a single mapped dimension, cannot be applied " +
- " to " + argumentType + ", which have multiple");
- for (TensorType.Dimension dimension : argumentType.dimensions()) {
- if (dimension.isIndexed())
- b.dimension(dimension);
- }
-
+ "to " + argumentType + ", which has multiple");
}
}
else { // general slicing
- Set<String> slicedDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toSet());
- for (TensorType.Dimension dimension : argumentType.dimensions()) {
- if (slicedDimensions.contains(dimension.name()))
- slicedDimensions.remove(dimension.name());
- else
- b.dimension(dimension);
- }
- if ( ! slicedDimensions.isEmpty())
- throw new IllegalArgumentException(this + " slices " + slicedDimensions + " which are not present in " +
- argumentType);
+ peekDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toList());
}
- return b.build();
+ return TypeResolver.peek(argumentType, peekDimensions);
}
@Override