summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-11-16 13:43:01 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-11-16 13:43:01 +0100
commit9d8296953e573fc23fe4e346219d4155e6f4e81c (patch)
tree62a770a165002b5e096ce75c03aad0c48358cd54 /vespajlib/src/main
parent4ad513c134bf980431d14f1c2c1d4775086047ec (diff)
More functions
Diffstat (limited to 'vespajlib/src/main')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java53
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java32
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java40
8 files changed, 134 insertions, 16 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 41882738e89..0f67b4ce5fb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -2,6 +2,11 @@
package com.yahoo.tensor;
import com.google.common.annotations.Beta;
+import com.yahoo.tensor.functions.ConstantTensor;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.L1Normalize;
+import com.yahoo.tensor.functions.L2Normalize;
+import com.yahoo.tensor.functions.Reduce;
import java.util.ArrayList;
import java.util.Collections;
@@ -13,6 +18,7 @@ import java.util.Set;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleFunction;
import java.util.function.DoubleUnaryOperator;
+import java.util.function.Function;
import java.util.function.UnaryOperator;
/**
@@ -49,20 +55,53 @@ public interface Tensor {
/** Returns the value of a cell, or NaN if this cell does not exist/have no value */
double get(TensorAddress address);
- // ----------------- Level 0 functions
+ // ----------------- Primitive tensor functions
- default Tensor map(Tensor tensor, DoubleUnaryOperator mapper) {
+ default Tensor map(DoubleUnaryOperator mapper) {
throw new UnsupportedOperationException("Not implemented");
}
- default Tensor reduce(Tensor tensor, String dimension,
- DoubleBinaryOperator reductor, Optional<DoubleBinaryOperator> postTransformation) {
+ default Tensor reduce(Reduce.Aggregator aggregator, List<String> dimensions) {
throw new UnsupportedOperationException("Not implemented");
}
- default Tensor join(Tensor tensorA, Tensor tensorB, DoubleBinaryOperator combinator) {
+ default Tensor join(Tensor tensor, DoubleBinaryOperator combinator) {
throw new UnsupportedOperationException("Not implemented");
}
+
+ default Tensor rename(List<String> fromDimensions, List<String> toDimensions) {
+ throw new UnsupportedOperationException("Not implemented");
+ }
+
+ static Tensor from(TensorType type, Function<List<Integer>, Double> valueSupplier) {
+ throw new UnsupportedOperationException("Not implemented");
+ }
+
+ // ----------------- Composite tensor functions
+
+ default Tensor l1Normalize(String dimension) {
+ return new L1Normalize(new ConstantTensor(this), dimension).toPrimitive().execute();
+ }
+
+ default Tensor l2Normalize(String dimension) {
+ return new L2Normalize(new ConstantTensor(this), dimension).toPrimitive().execute();
+ }
+
+ default Tensor multiply(Tensor argument) {
+ return new Join(new ConstantTensor(this), new ConstantTensor(argument), (a, b) -> ( a * b )).execute();
+ }
+
+ default Tensor sum(Tensor argument) {
+ return new Join(new ConstantTensor(this), new ConstantTensor(argument), (a, b) -> ( a + b )).execute();
+ }
+
+ default Tensor divide(Tensor argument) {
+ return new Join(new ConstantTensor(this), new ConstantTensor(argument), (a, b) -> ( a / b )).execute();
+ }
+
+ default Tensor subtract(Tensor argument) {
+ return new Join(new ConstantTensor(this), new ConstantTensor(argument), (a, b) -> ( a - b )).execute();
+ }
// ----------------- Old stuff
/**
@@ -80,7 +119,7 @@ public interface Tensor {
* @param argument the tensor to multiply by this
* @return the resulting tensor.
*/
- default Tensor multiply(Tensor argument) {
+ default Tensor oldMultiply(Tensor argument) {
return new TensorProduct(this, argument).result();
}
@@ -140,7 +179,7 @@ public interface Tensor {
* Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
* and have the value undefined for any non-shared dimension.
*/
- default Tensor subtract(Tensor argument) {
+ default Tensor oldSubtract(Tensor argument) {
return new TensorDifference(this, argument).result();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index dae3f43fc7f..f8db0f2b601 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -1,6 +1,7 @@
package com.yahoo.tensor.functions;
import com.yahoo.tensor.MapTensor;
+import com.yahoo.tensor.Tensor;
/**
* A function which returns a constant tensor.
@@ -9,12 +10,16 @@ import com.yahoo.tensor.MapTensor;
*/
public class ConstantTensor extends PrimitiveTensorFunction {
- private final MapTensor constant;
+ private final Tensor constant;
public ConstantTensor(String tensorString) {
this.constant = MapTensor.from(tensorString);
}
+ public ConstantTensor(Tensor tensor) {
+ this.constant = tensor;
+ }
+
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 5a4c0c8b2a8..2f7b6802e89 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -35,7 +35,7 @@ public class Join extends PrimitiveTensorFunction {
@Override
public String toString() {
- return "join(" + argumentA.toString() + ", " + argumentB.toString() + ", lambda(a, b) (" + combinator + "))";
+ return "join(" + argumentA.toString() + ", " + argumentB.toString() + ", f(a, b) (" + combinator + "))";
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
index d571875796b..d3fc707b65d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
@@ -6,14 +6,19 @@ package com.yahoo.tensor.functions;
public class L1Normalize extends CompositeTensorFunction {
private final TensorFunction argument;
+ private final String dimension;
- public L1Normalize(TensorFunction argument) {
+ public L1Normalize(TensorFunction argument, String dimension) {
this.argument = argument;
+ this.dimension = dimension;
}
@Override
public PrimitiveTensorFunction toPrimitive() {
- return new Join(argument.toPrimitive(), new Reduce(argument.toPrimitive(), Reduce.Aggregator.avg, "dimension"), ScalarFunctions.multiply());
+ TensorFunction primitiveArgument = argument.toPrimitive();
+ return new Join(primitiveArgument,
+ new Reduce(primitiveArgument, Reduce.Aggregator.avg, dimension),
+ ScalarFunctions.multiply());
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
new file mode 100644
index 00000000000..eb632ee679a
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
@@ -0,0 +1,32 @@
+package com.yahoo.tensor.functions;
+
+/**
+ * @author bratseth
+ */
+public class L2Normalize extends CompositeTensorFunction {
+
+ private final TensorFunction argument;
+ private final String dimension;
+
+ public L2Normalize(TensorFunction argument, String dimension) {
+ this.argument = argument;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ TensorFunction primitiveArgument = argument.toPrimitive();
+ return new Join(primitiveArgument,
+ new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.square()),
+ Reduce.Aggregator.sum,
+ dimension),
+ ScalarFunctions.square()),
+ ScalarFunctions.divide());
+ }
+
+ @Override
+ public String toString() {
+ return "l2_normalize(" + argument + ")";
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index 0a7dd16d6ec..901cdd7d16a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -30,7 +30,7 @@ public class Map extends PrimitiveTensorFunction {
@Override
public String toString() {
- return "map(" + argument.toString() + ", lambda(a) (" + mapper + "))";
+ return "map(" + argument.toString() + ", f(a) (" + mapper + "))";
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
index 9c0c9abaeb7..215e52f1809 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
@@ -1,5 +1,7 @@
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.Tensor;
+
/**
* A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions.
* All tensor implementations must implement all primitive tensor functions.
@@ -8,4 +10,7 @@ package com.yahoo.tensor.functions;
* @author bratseth
*/
public abstract class PrimitiveTensorFunction extends TensorFunction {
+
+ public abstract Tensor execute();
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index 517a55339dd..05f196da06d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -1,16 +1,48 @@
package com.yahoo.tensor.functions;
import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleUnaryOperator;
/**
- * Factory of scalar Java functions which have a type which can be inspected.
+ * Factory of scalar Java functions.
+ * The purpose of this is to embellish anonymous functions with a runtime type
+ * such that they can be inspected and return a usable toString.
*
* @author bratseth
*/
public class ScalarFunctions {
- public static DoubleBinaryOperator multiply() {
- return
- }
+ public static DoubleBinaryOperator multiply() { return new Multiplication(); }
+ public static DoubleBinaryOperator divide() { return new Division(); }
+ public static DoubleUnaryOperator square() { return new Square(); }
+ public static class Multiplication implements DoubleBinaryOperator {
+
+ @Override
+ public double applyAsDouble(double left, double right) { return left * right; }
+
+ @Override
+ public String toString() { return "a * b"; }
+
+ }
+
+ public static class Division implements DoubleBinaryOperator {
+
+ @Override
+ public double applyAsDouble(double left, double right) { return left / right; }
+
+ @Override
+ public String toString() { return "a / b"; }
+ }
+
+ public static class Square implements DoubleUnaryOperator {
+
+ @Override
+ public double applyAsDouble(double operand) { return operand * operand; }
+
+ @Override
+ public String toString() { return "a * a"; }
+
+ }
+
}