summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-04-23 12:31:21 +0000
committerArne Juul <arnej@verizonmedia.com>2021-04-26 06:45:23 +0000
commitfb98651ec8775d2347a3e42310e14d1b59c38a42 (patch)
tree52549d0bed084386fbbb015cb0278cda5db019ed /vespajlib
parentc0b637dd81754665a014eba3794f31f7fd432d52 (diff)
Reapply "Lesters/resolve cell types for rename and slice"
This reverts commit 49d9016a9e50fd9032b10520c71dc5b05ab8b215.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java40
2 files changed, 17 insertions, 29 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index 275b546c0aa..fc1e7737d83 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -76,10 +77,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
private TensorType type(TensorType type) {
- TensorType.Builder builder = new TensorType.Builder(type.valueType());
- for (TensorType.Dimension dimension : type.dimensions())
- builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
- return builder.build();
+ return TypeResolver.rename(type, fromDimensions, toDimensions);
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
index bccd66acd31..607c9a0ab44 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -6,6 +6,7 @@ import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -14,7 +15,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
-import java.util.Set;
+import java.util.function.Predicate;
import java.util.stream.Collectors;
/**
@@ -113,44 +114,33 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
return resultType(argument.type(context));
}
+ private List<String> findDimensions(List<TensorType.Dimension> dims, Predicate<TensorType.Dimension> pred) {
+ return dims.stream().filter(pred).map(TensorType.Dimension::name).collect(Collectors.toList());
+ }
+
private TensorType resultType(TensorType argumentType) {
- TensorType.Builder b = new TensorType.Builder();
+ List<String> peekDimensions;
// Special case where a single indexed or mapped dimension is sliced
if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) {
if (subspaceAddress.get(0).index().isPresent()) {
- if (argumentType.dimensions().stream().filter(d -> d.isIndexed()).count() > 1)
+ peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isIndexed);
+ if (peekDimensions.size() > 1) {
throw new IllegalArgumentException(this + " slices a single indexed dimension, cannot be applied " +
- " to " + argumentType + ", which have multiple");
- for (TensorType.Dimension dimension : argumentType.dimensions()) {
- if ( ! dimension.isIndexed())
- b.dimension(dimension);
+ "to " + argumentType + ", which has multiple");
}
}
else {
- if (argumentType.dimensions().stream().filter(d -> ! d.isIndexed()).count() > 1)
+ peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isMapped);
+ if (peekDimensions.size() > 1)
throw new IllegalArgumentException(this + " slices a single mapped dimension, cannot be applied " +
- " to " + argumentType + ", which have multiple");
- for (TensorType.Dimension dimension : argumentType.dimensions()) {
- if (dimension.isIndexed())
- b.dimension(dimension);
- }
-
+ "to " + argumentType + ", which has multiple");
}
}
else { // general slicing
- Set<String> slicedDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toSet());
- for (TensorType.Dimension dimension : argumentType.dimensions()) {
- if (slicedDimensions.contains(dimension.name()))
- slicedDimensions.remove(dimension.name());
- else
- b.dimension(dimension);
- }
- if ( ! slicedDimensions.isEmpty())
- throw new IllegalArgumentException(this + " slices " + slicedDimensions + " which are not present in " +
- argumentType);
+ peekDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toList());
}
- return b.build();
+ return TypeResolver.peek(argumentType, peekDimensions);
}
@Override