diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-16 13:43:01 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-16 13:43:01 +0100 |
commit | 9d8296953e573fc23fe4e346219d4155e6f4e81c (patch) | |
tree | 62a770a165002b5e096ce75c03aad0c48358cd54 /vespajlib/src/main | |
parent | 4ad513c134bf980431d14f1c2c1d4775086047ec (diff) |
More functions
Diffstat (limited to 'vespajlib/src/main')
8 files changed, 134 insertions, 16 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 41882738e89..0f67b4ce5fb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -2,6 +2,11 @@ package com.yahoo.tensor; import com.google.common.annotations.Beta; +import com.yahoo.tensor.functions.ConstantTensor; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.L1Normalize; +import com.yahoo.tensor.functions.L2Normalize; +import com.yahoo.tensor.functions.Reduce; import java.util.ArrayList; import java.util.Collections; @@ -13,6 +18,7 @@ import java.util.Set; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleFunction; import java.util.function.DoubleUnaryOperator; +import java.util.function.Function; import java.util.function.UnaryOperator; /** @@ -49,20 +55,53 @@ public interface Tensor { /** Returns the value of a cell, or NaN if this cell does not exist/have no value */ double get(TensorAddress address); - // ----------------- Level 0 functions + // ----------------- Primitive tensor functions - default Tensor map(Tensor tensor, DoubleUnaryOperator mapper) { + default Tensor map(DoubleUnaryOperator mapper) { throw new UnsupportedOperationException("Not implemented"); } - default Tensor reduce(Tensor tensor, String dimension, - DoubleBinaryOperator reductor, Optional<DoubleBinaryOperator> postTransformation) { + default Tensor reduce(Reduce.Aggregator aggregator, List<String> dimensions) { throw new UnsupportedOperationException("Not implemented"); } - default Tensor join(Tensor tensorA, Tensor tensorB, DoubleBinaryOperator combinator) { + default Tensor join(Tensor tensor, DoubleBinaryOperator combinator) { throw new UnsupportedOperationException("Not implemented"); } + + default Tensor rename(List<String> fromDimensions, List<String> toDimensions) { + throw new UnsupportedOperationException("Not implemented"); + } + + static Tensor from(TensorType type, Function<List<Integer>, Double> valueSupplier) { + throw new UnsupportedOperationException("Not implemented"); + } + + // ----------------- Composite tensor functions + + default Tensor l1Normalize(String dimension) { + return new L1Normalize(new ConstantTensor(this), dimension).toPrimitive().execute(); + } + + default Tensor l2Normalize(String dimension) { + return new L2Normalize(new ConstantTensor(this), dimension).toPrimitive().execute(); + } + + default Tensor multiply(Tensor argument) { + return new Join(new ConstantTensor(this), new ConstantTensor(argument), (a, b) -> ( a * b )).execute(); + } + + default Tensor sum(Tensor argument) { + return new Join(new ConstantTensor(this), new ConstantTensor(argument), (a, b) -> ( a + b )).execute(); + } + + default Tensor divide(Tensor argument) { + return new Join(new ConstantTensor(this), new ConstantTensor(argument), (a, b) -> ( a / b )).execute(); + } + + default Tensor subtract(Tensor argument) { + return new Join(new ConstantTensor(this), new ConstantTensor(argument), (a, b) -> ( a - b )).execute(); + } // ----------------- Old stuff /** @@ -80,7 +119,7 @@ public interface Tensor { * @param argument the tensor to multiply by this * @return the resulting tensor. */ - default Tensor multiply(Tensor argument) { + default Tensor oldMultiply(Tensor argument) { return new TensorProduct(this, argument).result(); } @@ -140,7 +179,7 @@ public interface Tensor { * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors, * and have the value undefined for any non-shared dimension. */ - default Tensor subtract(Tensor argument) { + default Tensor oldSubtract(Tensor argument) { return new TensorDifference(this, argument).result(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index dae3f43fc7f..f8db0f2b601 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -1,6 +1,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.MapTensor; +import com.yahoo.tensor.Tensor; /** * A function which returns a constant tensor. @@ -9,12 +10,16 @@ import com.yahoo.tensor.MapTensor; */ public class ConstantTensor extends PrimitiveTensorFunction { - private final MapTensor constant; + private final Tensor constant; public ConstantTensor(String tensorString) { this.constant = MapTensor.from(tensorString); } + public ConstantTensor(Tensor tensor) { + this.constant = tensor; + } + @Override public PrimitiveTensorFunction toPrimitive() { return this; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 5a4c0c8b2a8..2f7b6802e89 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -35,7 +35,7 @@ public class Join extends PrimitiveTensorFunction { @Override public String toString() { - return "join(" + argumentA.toString() + ", " + argumentB.toString() + ", lambda(a, b) (" + combinator + "))"; + return "join(" + argumentA.toString() + ", " + argumentB.toString() + ", f(a, b) (" + combinator + "))"; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java index d571875796b..d3fc707b65d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java @@ -6,14 +6,19 @@ package com.yahoo.tensor.functions; public class L1Normalize extends CompositeTensorFunction { private final TensorFunction argument; + private final String dimension; - public L1Normalize(TensorFunction argument) { + public L1Normalize(TensorFunction argument, String dimension) { this.argument = argument; + this.dimension = dimension; } @Override public PrimitiveTensorFunction toPrimitive() { - return new Join(argument.toPrimitive(), new Reduce(argument.toPrimitive(), Reduce.Aggregator.avg, "dimension"), ScalarFunctions.multiply()); + TensorFunction primitiveArgument = argument.toPrimitive(); + return new Join(primitiveArgument, + new Reduce(primitiveArgument, Reduce.Aggregator.avg, dimension), + ScalarFunctions.multiply()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java new file mode 100644 index 00000000000..eb632ee679a --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java @@ -0,0 +1,32 @@ +package com.yahoo.tensor.functions; + +/** + * @author bratseth + */ +public class L2Normalize extends CompositeTensorFunction { + + private final TensorFunction argument; + private final String dimension; + + public L2Normalize(TensorFunction argument, String dimension) { + this.argument = argument; + this.dimension = dimension; + } + + @Override + public PrimitiveTensorFunction toPrimitive() { + TensorFunction primitiveArgument = argument.toPrimitive(); + return new Join(primitiveArgument, + new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.square()), + Reduce.Aggregator.sum, + dimension), + ScalarFunctions.square()), + ScalarFunctions.divide()); + } + + @Override + public String toString() { + return "l2_normalize(" + argument + ")"; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index 0a7dd16d6ec..901cdd7d16a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -30,7 +30,7 @@ public class Map extends PrimitiveTensorFunction { @Override public String toString() { - return "map(" + argument.toString() + ", lambda(a) (" + mapper + "))"; + return "map(" + argument.toString() + ", f(a) (" + mapper + "))"; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java index 9c0c9abaeb7..215e52f1809 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java @@ -1,5 +1,7 @@ package com.yahoo.tensor.functions; +import com.yahoo.tensor.Tensor; + /** * A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions. * All tensor implementations must implement all primitive tensor functions. @@ -8,4 +10,7 @@ package com.yahoo.tensor.functions; * @author bratseth */ public abstract class PrimitiveTensorFunction extends TensorFunction { + + public abstract Tensor execute(); + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index 517a55339dd..05f196da06d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -1,16 +1,48 @@ package com.yahoo.tensor.functions; import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleUnaryOperator; /** - * Factory of scalar Java functions which have a type which can be inspected. + * Factory of scalar Java functions. + * The purpose of this is to embellish anonymous functions with a runtime type + * such that they can be inspected and return a usable toString. * * @author bratseth */ public class ScalarFunctions { - public static DoubleBinaryOperator multiply() { - return - } + public static DoubleBinaryOperator multiply() { return new Multiplication(); } + public static DoubleBinaryOperator divide() { return new Division(); } + public static DoubleUnaryOperator square() { return new Square(); } + public static class Multiplication implements DoubleBinaryOperator { + + @Override + public double applyAsDouble(double left, double right) { return left * right; } + + @Override + public String toString() { return "a * b"; } + + } + + public static class Division implements DoubleBinaryOperator { + + @Override + public double applyAsDouble(double left, double right) { return left / right; } + + @Override + public String toString() { return "a / b"; } + } + + public static class Square implements DoubleUnaryOperator { + + @Override + public double applyAsDouble(double operand) { return operand * operand; } + + @Override + public String toString() { return "a * a"; } + + } + } |