aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java29
1 files changed, 17 insertions, 12 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index d4affe0ef9b..cc8067224c7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -34,10 +34,10 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); }
+ public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if (arguments.size() != 2)
throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
return new Concat(arguments.get(0), arguments.get(1), dimension);
@@ -54,6 +54,20 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
+ public TensorType type(EvaluationContext context) {
+ return type(argumentA.type(context), argumentB.type(context));
+ }
+
+ /** Returns the type resulting from concatenating a and b */
+ private TensorType type(TensorType a, TensorType b) {
+ TensorType.Builder builder = new TensorType.Builder(a, b);
+ if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size
+ builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() +
+ b.dimension(dimension).get().size().get()));
+ return builder.build();
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
@@ -63,7 +77,7 @@ public class Concat extends PrimitiveTensorFunction {
IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
IndexedTensor bIndexed = (IndexedTensor) b;
- TensorType concatType = concatType(a, b);
+ TensorType concatType = type(a.type(), b.type());
DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
@@ -115,15 +129,6 @@ public class Concat extends PrimitiveTensorFunction {
}
- /** Returns the type resulting from concatenating a and b */
- private TensorType concatType(Tensor a, Tensor b) {
- TensorType.Builder builder = new TensorType.Builder(a.type(), b.type());
- if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size
- builder.set(TensorType.Dimension.indexed(dimension, a.type().dimension(dimension).get().size().get() +
- b.type().dimension(dimension).get().size().get()));
- return builder.build();
- }
-
/** Returns the concrete (not type) dimension sizes resulting from combining a and b */
private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());