diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java | 25 |
1 files changed, 13 insertions, 12 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index 7c65afc98f9..d32e84f1ca0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -3,18 +3,19 @@ package com.yahoo.tensor.functions; import com.google.common.collect.ImmutableList; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.List; /** * @author bratseth */ -public class Matmul extends CompositeTensorFunction { +public class Matmul<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction argument1, argument2; + private final TensorFunction<NAMETYPE> argument1, argument2; private final String dimension; - public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) { + public Matmul(TensorFunction<NAMETYPE> argument1, TensorFunction<NAMETYPE> argument2, String dimension) { this.argument1 = argument1; this.argument2 = argument2; this.dimension = dimension; @@ -25,22 +26,22 @@ public class Matmul extends CompositeTensorFunction { } @Override - public List<TensorFunction> arguments() { return ImmutableList.of(argument1, argument2); } + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argument1, argument2); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 2) throw new IllegalArgumentException("Matmul must have 2 arguments, got " + arguments.size()); - return new Matmul(arguments.get(0), arguments.get(1), dimension); + return new Matmul<>(arguments.get(0), arguments.get(1), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument1 = argument1.toPrimitive(); - TensorFunction primitiveArgument2 = argument2.toPrimitive(); - return new Reduce(new Join(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()), - Reduce.Aggregator.sum, - dimension); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitiveArgument1 = argument1.toPrimitive(); + TensorFunction<NAMETYPE> primitiveArgument2 = argument2.toPrimitive(); + return new Reduce<>(new Join<>(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + dimension); } @Override |