diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 29 |
1 files changed, 17 insertions, 12 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index d4affe0ef9b..cc8067224c7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -34,10 +34,10 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); } + public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if (arguments.size() != 2) throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size()); return new Concat(arguments.get(0), arguments.get(1), dimension); @@ -54,6 +54,20 @@ public class Concat extends PrimitiveTensorFunction { } @Override + public TensorType type(EvaluationContext context) { + return type(argumentA.type(context), argumentB.type(context)); + } + + /** Returns the type resulting from concatenating a and b */ + private TensorType type(TensorType a, TensorType b) { + TensorType.Builder builder = new TensorType.Builder(a, b); + if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size + builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() + + b.dimension(dimension).get().size().get())); + return builder.build(); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); @@ -63,7 +77,7 @@ public class Concat extends PrimitiveTensorFunction { IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor IndexedTensor bIndexed = (IndexedTensor) b; - TensorType concatType = concatType(a, b); + TensorType concatType = type(a.type(), b.type()); DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); @@ -115,15 +129,6 @@ public class Concat extends PrimitiveTensorFunction { } - /** Returns the type resulting from concatenating a and b */ - private TensorType concatType(Tensor a, Tensor b) { - TensorType.Builder builder = new TensorType.Builder(a.type(), b.type()); - if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size - builder.set(TensorType.Dimension.indexed(dimension, a.type().dimension(dimension).get().size().get() + - b.type().dimension(dimension).get().size().get())); - return builder.build(); - } - /** Returns the concrete (not type) dimension sizes resulting from combining a and b */ private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) { DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size()); |