summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorArnstein Ressem <aressem@gmail.com>2021-04-22 23:57:46 +0200
committerGitHub <noreply@github.com>2021-04-22 23:57:46 +0200
commit49d9016a9e50fd9032b10520c71dc5b05ab8b215 (patch)
tree02e8d98dadd3a6cfad1f1e3b10d854ae02cccc03 /vespajlib
parentf40237d5bdfe32a7b33df13e426b67b88fd0288a (diff)
Revert "Lesters/resolve cell types for rename and slice"
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java40
2 files changed, 29 insertions, 17 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index fc1e7737d83..275b546c0aa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -5,7 +5,6 @@ import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -77,7 +76,10 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
private TensorType type(TensorType type) {
- return TypeResolver.rename(type, fromDimensions, toDimensions);
+ TensorType.Builder builder = new TensorType.Builder(type.valueType());
+ for (TensorType.Dimension dimension : type.dimensions())
+ builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
+ return builder.build();
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
index 607c9a0ab44..bccd66acd31 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -6,7 +6,6 @@ import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -15,7 +14,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
-import java.util.function.Predicate;
+import java.util.Set;
import java.util.stream.Collectors;
/**
@@ -114,33 +113,44 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
return resultType(argument.type(context));
}
- private List<String> findDimensions(List<TensorType.Dimension> dims, Predicate<TensorType.Dimension> pred) {
- return dims.stream().filter(pred).map(TensorType.Dimension::name).collect(Collectors.toList());
- }
-
private TensorType resultType(TensorType argumentType) {
- List<String> peekDimensions;
+ TensorType.Builder b = new TensorType.Builder();
// Special case where a single indexed or mapped dimension is sliced
if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) {
if (subspaceAddress.get(0).index().isPresent()) {
- peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isIndexed);
- if (peekDimensions.size() > 1) {
+ if (argumentType.dimensions().stream().filter(d -> d.isIndexed()).count() > 1)
throw new IllegalArgumentException(this + " slices a single indexed dimension, cannot be applied " +
- "to " + argumentType + ", which has multiple");
+ " to " + argumentType + ", which have multiple");
+ for (TensorType.Dimension dimension : argumentType.dimensions()) {
+ if ( ! dimension.isIndexed())
+ b.dimension(dimension);
}
}
else {
- peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isMapped);
- if (peekDimensions.size() > 1)
+ if (argumentType.dimensions().stream().filter(d -> ! d.isIndexed()).count() > 1)
throw new IllegalArgumentException(this + " slices a single mapped dimension, cannot be applied " +
- "to " + argumentType + ", which has multiple");
+ " to " + argumentType + ", which have multiple");
+ for (TensorType.Dimension dimension : argumentType.dimensions()) {
+ if (dimension.isIndexed())
+ b.dimension(dimension);
+ }
+
}
}
else { // general slicing
- peekDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toList());
+ Set<String> slicedDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toSet());
+ for (TensorType.Dimension dimension : argumentType.dimensions()) {
+ if (slicedDimensions.contains(dimension.name()))
+ slicedDimensions.remove(dimension.name());
+ else
+ b.dimension(dimension);
+ }
+ if ( ! slicedDimensions.isEmpty())
+ throw new IllegalArgumentException(this + " slices " + slicedDimensions + " which are not present in " +
+ argumentType);
}
- return TypeResolver.peek(argumentType, peekDimensions);
+ return b.build();
}
@Override