summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-11-24 14:18:01 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-11-24 14:18:01 +0100
commitcb2dc3460fa31dffb51e54847283038e8a0ae93c (patch)
treee96497fe6b167f8867ad9cb225ea979a6e09dab8 /vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
parent437a2dc519cc991302c01acb8cd1df1e96b1283d (diff)
Implement composite functions
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java38
1 files changed, 38 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
new file mode 100644
index 00000000000..4492ab083d4
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -0,0 +1,38 @@
+package com.yahoo.tensor.functions;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * @author bratseth
+ */
+public class Matmul extends CompositeTensorFunction {
+
+ private final TensorFunction argument1, argument2;
+ private final String dimension;
+
+ public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) {
+ this.argument1 = argument1;
+ this.argument2 = argument2;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ TensorFunction primitiveArgument1 = argument1.toPrimitive();
+ TensorFunction primitiveArgument2 = argument2.toPrimitive();
+ return new Reduce(new Join(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ dimension);
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")";
+ }
+
+}