diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 26 |
1 files changed, 5 insertions, 21 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 13e7c136feb..c77ed1c0526 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -3,8 +3,6 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; -import com.yahoo.lang.MutableInteger; -import com.yahoo.lang.MutableLong; import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; @@ -62,35 +60,21 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext context) { return type(argumentA.type(context), argumentB.type(context)); } /** Returns the type resulting from concatenating a and b */ private TensorType type(TensorType a, TensorType b) { - // TODO: Fail if concat dimension is present but not indexed in a or b TensorType.Builder builder = new TensorType.Builder(a, b); - if ( ! unboundIn(a, dimension) && ! unboundIn(b, dimension)) { - builder.set(TensorType.Dimension.indexed(dimension, a.sizeOfDimension(dimension).orElse(1L) + - b.sizeOfDimension(dimension).orElse(1L))); - /* - MutableLong concatSize = new MutableLong(0); - a.sizeOfDimension(dimension).ifPresent(concatSize::add); - b.sizeOfDimension(dimension).ifPresent(concatSize::add); - builder.set(TensorType.Dimension.indexed(dimension, concatSize.get())); - */ - } + if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size + builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() + + b.dimension(dimension).get().size().get())); return builder.build(); } - /** Returns true if this dimension is present and unbound */ - private boolean unboundIn(TensorType type, String dimensionName) { - Optional<TensorType.Dimension> dimension = type.dimension(dimensionName); - return dimension.isPresent() && ! dimension.get().size().isPresent(); - } - @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); a = ensureIndexedDimension(dimension, a); |