summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java26
1 files changed, 5 insertions, 21 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 13e7c136feb..c77ed1c0526 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -3,8 +3,6 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
-import com.yahoo.lang.MutableInteger;
-import com.yahoo.lang.MutableLong;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
@@ -62,35 +60,21 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext context) {
return type(argumentA.type(context), argumentB.type(context));
}
/** Returns the type resulting from concatenating a and b */
private TensorType type(TensorType a, TensorType b) {
- // TODO: Fail if concat dimension is present but not indexed in a or b
TensorType.Builder builder = new TensorType.Builder(a, b);
- if ( ! unboundIn(a, dimension) && ! unboundIn(b, dimension)) {
- builder.set(TensorType.Dimension.indexed(dimension, a.sizeOfDimension(dimension).orElse(1L) +
- b.sizeOfDimension(dimension).orElse(1L)));
- /*
- MutableLong concatSize = new MutableLong(0);
- a.sizeOfDimension(dimension).ifPresent(concatSize::add);
- b.sizeOfDimension(dimension).ifPresent(concatSize::add);
- builder.set(TensorType.Dimension.indexed(dimension, concatSize.get()));
- */
- }
+ if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size
+ builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() +
+ b.dimension(dimension).get().size().get()));
return builder.build();
}
- /** Returns true if this dimension is present and unbound */
- private boolean unboundIn(TensorType type, String dimensionName) {
- Optional<TensorType.Dimension> dimension = type.dimension(dimensionName);
- return dimension.isPresent() && ! dimension.get().size().isPresent();
- }
-
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
a = ensureIndexedDimension(dimension, a);