aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java25
1 files changed, 13 insertions, 12 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index 7c65afc98f9..d32e84f1ca0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -3,18 +3,19 @@ package com.yahoo.tensor.functions;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.List;
/**
* @author bratseth
*/
-public class Matmul extends CompositeTensorFunction {
+public class Matmul<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction argument1, argument2;
+ private final TensorFunction<NAMETYPE> argument1, argument2;
private final String dimension;
- public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) {
+ public Matmul(TensorFunction<NAMETYPE> argument1, TensorFunction<NAMETYPE> argument2, String dimension) {
this.argument1 = argument1;
this.argument2 = argument2;
this.dimension = dimension;
@@ -25,22 +26,22 @@ public class Matmul extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return ImmutableList.of(argument1, argument2); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argument1, argument2); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 2)
throw new IllegalArgumentException("Matmul must have 2 arguments, got " + arguments.size());
- return new Matmul(arguments.get(0), arguments.get(1), dimension);
+ return new Matmul<>(arguments.get(0), arguments.get(1), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument1 = argument1.toPrimitive();
- TensorFunction primitiveArgument2 = argument2.toPrimitive();
- return new Reduce(new Join(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()),
- Reduce.Aggregator.sum,
- dimension);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> primitiveArgument1 = argument1.toPrimitive();
+ TensorFunction<NAMETYPE> primitiveArgument2 = argument2.toPrimitive();
+ return new Reduce<>(new Join<>(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ dimension);
}
@Override