diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-24 14:18:01 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-24 14:18:01 +0100 |
commit | cb2dc3460fa31dffb51e54847283038e8a0ae93c (patch) | |
tree | e96497fe6b167f8867ad9cb225ea979a6e09dab8 /vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java | |
parent | 437a2dc519cc991302c01acb8cd1df1e96b1283d (diff) |
Implement composite functions
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java new file mode 100644 index 00000000000..4492ab083d4 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -0,0 +1,38 @@ +package com.yahoo.tensor.functions; + +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * @author bratseth + */ +public class Matmul extends CompositeTensorFunction { + + private final TensorFunction argument1, argument2; + private final String dimension; + + public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) { + this.argument1 = argument1; + this.argument2 = argument2; + this.dimension = dimension; + } + + @Override + public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); } + + @Override + public PrimitiveTensorFunction toPrimitive() { + TensorFunction primitiveArgument1 = argument1.toPrimitive(); + TensorFunction primitiveArgument2 = argument2.toPrimitive(); + return new Reduce(new Join(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + dimension); + } + + @Override + public String toString(ToStringContext context) { + return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")"; + } + +} |