summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-11-16 13:43:01 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-11-16 13:43:01 +0100
commit9d8296953e573fc23fe4e346219d4155e6f4e81c (patch)
tree62a770a165002b5e096ce75c03aad0c48358cd54 /vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
parent4ad513c134bf980431d14f1c2c1d4775086047ec (diff)
More functions
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java32
1 files changed, 32 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
new file mode 100644
index 00000000000..eb632ee679a
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
@@ -0,0 +1,32 @@
+package com.yahoo.tensor.functions;
+
+/**
+ * @author bratseth
+ */
+public class L2Normalize extends CompositeTensorFunction {
+
+ private final TensorFunction argument;
+ private final String dimension;
+
+ public L2Normalize(TensorFunction argument, String dimension) {
+ this.argument = argument;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ TensorFunction primitiveArgument = argument.toPrimitive();
+ return new Join(primitiveArgument,
+ new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.square()),
+ Reduce.Aggregator.sum,
+ dimension),
+ ScalarFunctions.square()),
+ ScalarFunctions.divide());
+ }
+
+ @Override
+ public String toString() {
+ return "l2_normalize(" + argument + ")";
+ }
+
+}