diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-24 14:18:01 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-24 14:18:01 +0100 |
commit | cb2dc3460fa31dffb51e54847283038e8a0ae93c (patch) | |
tree | e96497fe6b167f8867ad9cb225ea979a6e09dab8 /vespajlib/src/main/java/com | |
parent | 437a2dc519cc991302c01acb8cd1df1e96b1283d (diff) |
Implement composite functions
Diffstat (limited to 'vespajlib/src/main/java/com')
17 files changed, 327 insertions, 86 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 06a01fad52e..d11225a5fe2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -2,15 +2,16 @@ package com.yahoo.tensor; import com.google.common.annotations.Beta; -import com.google.common.collect.ImmutableMap; import com.yahoo.tensor.functions.ConstantTensor; +import com.yahoo.tensor.functions.EvaluationContext; import com.yahoo.tensor.functions.GeneratedTensor; -import com.yahoo.tensor.functions.JoinFunction; +import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.L1Normalize; import com.yahoo.tensor.functions.L2Normalize; -import com.yahoo.tensor.functions.MapFunction; -import com.yahoo.tensor.functions.ReduceFunction; -import com.yahoo.tensor.functions.RenameFunction; +import com.yahoo.tensor.functions.Matmul; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.Softmax; import java.util.ArrayList; import java.util.Collections; @@ -21,7 +22,6 @@ import java.util.Set; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; import java.util.function.Function; -import java.util.function.UnaryOperator; /** * A multidimensional array which can be used in computations. @@ -60,34 +60,42 @@ public interface Tensor { // ----------------- Primitive tensor functions default Tensor map(DoubleUnaryOperator mapper) { - return new MapFunction(new ConstantTensor(this), mapper).execute(); + return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate(); } /** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */ - default Tensor reduce(ReduceFunction.Aggregator aggregator, List<String> dimensions) { - return new ReduceFunction(new ConstantTensor(this), aggregator, dimensions).execute(); + default Tensor reduce(Reduce.Aggregator aggregator, List<String> dimensions) { + return new Reduce(new ConstantTensor(this), aggregator, dimensions).evaluate(); } default Tensor join(Tensor argument, DoubleBinaryOperator combinator) { - return new JoinFunction(new ConstantTensor(this), new ConstantTensor(argument), combinator).execute(); + return new Join(new ConstantTensor(this), new ConstantTensor(argument), combinator).evaluate(); } default Tensor rename(List<String> fromDimensions, List<String> toDimensions) { - return new RenameFunction(new ConstantTensor(this), fromDimensions, toDimensions).execute(); + return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate(); } static Tensor from(TensorType type, Function<List<Integer>, Double> valueSupplier) { - return new GeneratedTensor(type, valueSupplier).execute(); + return new GeneratedTensor(type, valueSupplier).evaluate(); } // ----------------- Composite tensor functions which have a defined primitive mapping default Tensor l1Normalize(String dimension) { - return new L1Normalize(new ConstantTensor(this), dimension).execute(); + return new L1Normalize(new ConstantTensor(this), dimension).evaluate(); } default Tensor l2Normalize(String dimension) { - return new L2Normalize(new ConstantTensor(this), dimension).execute(); + return new L2Normalize(new ConstantTensor(this), dimension).evaluate(); + } + + default Tensor matmul(Tensor argument, String dimension) { + return new Matmul(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate(); + } + + default Tensor softmax(String dimension) { + return new Softmax(new ConstantTensor(this), dimension).evaluate(); } // ----------------- Composite tensor functions mapped to primitives here on the fly @@ -98,13 +106,15 @@ public interface Tensor { default Tensor subtract(Tensor argument) { return join(argument, (a, b) -> (a - b )); } default Tensor max(Tensor argument) { return join(argument, (a, b) -> (a > b ? a : b )); } default Tensor min(Tensor argument) { return join(argument, (a, b) -> (a < b ? a : b )); } + default Tensor atan2(Tensor argument) { return join(argument, Math::atan2); } + default Tensor equal(Tensor argument) { return join(argument, (a, b) -> ( a == b ? 1.0 : 0.0)); } - default Tensor avg(List<String> dimensions) { return reduce(ReduceFunction.Aggregator.avg, dimensions); } - default Tensor count(List<String> dimensions) { return reduce(ReduceFunction.Aggregator.count, dimensions); } - default Tensor max(List<String> dimensions) { return reduce(ReduceFunction.Aggregator.max, dimensions); } - default Tensor min(List<String> dimensions) { return reduce(ReduceFunction.Aggregator.min, dimensions); } - default Tensor prod(List<String> dimensions) { return reduce(ReduceFunction.Aggregator.prod, dimensions); } - default Tensor sum(List<String> dimensions) { return reduce(ReduceFunction.Aggregator.sum, dimensions); } + default Tensor avg(List<String> dimensions) { return reduce(Reduce.Aggregator.avg, dimensions); } + default Tensor count(List<String> dimensions) { return reduce(Reduce.Aggregator.count, dimensions); } + default Tensor max(List<String> dimensions) { return reduce(Reduce.Aggregator.max, dimensions); } + default Tensor min(List<String> dimensions) { return reduce(Reduce.Aggregator.min, dimensions); } + default Tensor prod(List<String> dimensions) { return reduce(Reduce.Aggregator.prod, dimensions); } + default Tensor sum(List<String> dimensions) { return reduce(Reduce.Aggregator.sum, dimensions); } /** * Returns true if the given tensor is mathematically equal to this: @@ -177,11 +187,11 @@ public interface Tensor { } static String contentToString(Tensor tensor) { - List<Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet()); - Collections.sort(cellEntries, Map.Entry.<TensorAddress, Double>comparingByKey()); + List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet()); + Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey()); StringBuilder b = new StringBuilder("{"); - for (Map.Entry<TensorAddress, Double> cell : cellEntries) { + for (java.util.Map.Entry<TensorAddress, Double> cell : cellEntries) { b.append(cell.getKey()).append(":").append(cell.getValue()); b.append(","); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index e564e1d6f25..31454e28baf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -10,8 +10,8 @@ import com.yahoo.tensor.Tensor; */ public abstract class CompositeTensorFunction extends TensorFunction { - /** Executes this by first converting it to a primitive function */ + /** Evaluates this by first converting it to a primitive function */ @Override - public final Tensor execute() { return toPrimitive().execute(); } + public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index c86ac7b137b..0727579a331 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -3,6 +3,9 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.MapTensor; import com.yahoo.tensor.Tensor; +import java.util.Collections; +import java.util.List; + /** * A function which returns a constant tensor. * @@ -19,14 +22,17 @@ public class ConstantTensor extends PrimitiveTensorFunction { public ConstantTensor(Tensor tensor) { this.constant = tensor; } - + + @Override + public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + @Override public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public Tensor execute() { return constant; } + public Tensor evaluate(EvaluationContext context) { return constant; } @Override - public String toString() { return constant.toString(); } + public String toString(ToStringContext context) { return constant.toString(); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java new file mode 100644 index 00000000000..24a4c61a58c --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java @@ -0,0 +1,14 @@ +package com.yahoo.tensor.functions; + +/** + * An evaluation context which is passed down to all nested functions during evaluation. + * The default implementation is empty as this library does not in itself have any need for a + * context. + * + * @author bratseth + */ +public interface EvaluationContext { + + static EvaluationContext empty() { return new EvaluationContext() {}; } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/GeneratedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/GeneratedTensor.java index 998ebbb3f2f..81346b535ff 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/GeneratedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/GeneratedTensor.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.function.Function; @@ -38,16 +39,19 @@ public class GeneratedTensor extends PrimitiveTensorFunction { if (dimension.type() != TensorType.Dimension.Type.indexedBound) throw new IllegalArgumentException("A generated tensor can only have indexed bound dimensions"); } - + + @Override + public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + @Override public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public Tensor execute() { + public Tensor evaluate(EvaluationContext context) { throw new UnsupportedOperationException("Not implemented"); // TODO } @Override - public String toString() { return type + "(" + generator + ")"; } + public String toString(ToStringContext context) { return type + "(" + generator + ")"; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/JoinFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 9104307e866..323da5906c3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/JoinFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -1,5 +1,6 @@ package com.yahoo.tensor.functions; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.yahoo.tensor.MapTensor; @@ -21,12 +22,12 @@ import java.util.function.DoubleBinaryOperator; * * @author bratseth */ -public class JoinFunction extends PrimitiveTensorFunction { +public class Join extends PrimitiveTensorFunction { private final TensorFunction argumentA, argumentB; private final DoubleBinaryOperator combinator; - public JoinFunction(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) { + public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) { Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); Objects.requireNonNull(combinator, "The combinator function cannot be null"); @@ -38,23 +39,26 @@ public class JoinFunction extends PrimitiveTensorFunction { public TensorFunction argumentA() { return argumentA; } public TensorFunction argumentB() { return argumentB; } public DoubleBinaryOperator combinator() { return combinator; } - + + @Override + public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); } + @Override public PrimitiveTensorFunction toPrimitive() { - return new JoinFunction(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); + return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); } @Override - public String toString() { - return "join(" + argumentA.toString() + ", " + argumentB.toString() + ", f(a, b) (" + combinator + "))"; + public String toString(ToStringContext context) { + return "join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ")"; } private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>(); @Override - public Tensor execute() { - Tensor a = argumentA.execute(); - Tensor b = argumentB.execute(); + public Tensor evaluate(EvaluationContext context) { + Tensor a = argumentA.evaluate(context); + Tensor b = argumentB.evaluate(context); // Dimension product Set<String> dimensions = combineDimensions(a, b); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java index ec2070d0231..0eeb1762888 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java @@ -1,5 +1,8 @@ package com.yahoo.tensor.functions; +import java.util.Collections; +import java.util.List; + /** * @author bratseth */ @@ -12,18 +15,21 @@ public class L1Normalize extends CompositeTensorFunction { this.argument = argument; this.dimension = dimension; } - + + @Override + public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + @Override public PrimitiveTensorFunction toPrimitive() { TensorFunction primitiveArgument = argument.toPrimitive(); - return new JoinFunction(primitiveArgument, - new ReduceFunction(primitiveArgument, ReduceFunction.Aggregator.avg, dimension), - ScalarFunctions.multiply()); + return new Join(primitiveArgument, + new Reduce(primitiveArgument, Reduce.Aggregator.avg, dimension), + ScalarFunctions.multiply()); } @Override - public String toString() { - return "l1_normalize(" + argument + ")"; + public String toString(ToStringContext context) { + return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")"; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java index 4abac10c1d7..fe041b38e44 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java @@ -1,5 +1,8 @@ package com.yahoo.tensor.functions; +import java.util.Collections; +import java.util.List; + /** * @author bratseth */ @@ -12,21 +15,24 @@ public class L2Normalize extends CompositeTensorFunction { this.argument = argument; this.dimension = dimension; } - + + @Override + public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + @Override public PrimitiveTensorFunction toPrimitive() { TensorFunction primitiveArgument = argument.toPrimitive(); - return new JoinFunction(primitiveArgument, - new MapFunction(new ReduceFunction(new MapFunction(primitiveArgument, ScalarFunctions.square()), - ReduceFunction.Aggregator.sum, - dimension), - ScalarFunctions.square()), - ScalarFunctions.divide()); + return new Join(primitiveArgument, + new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.square()), + Reduce.Aggregator.sum, + dimension), + ScalarFunctions.square()), + ScalarFunctions.divide()); } @Override - public String toString() { - return "l2_normalize(" + argument + ")"; + public String toString(ToStringContext context) { + return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")"; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index 21878e30fce..5db88953c64 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -5,7 +5,8 @@ import com.yahoo.tensor.MapTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; -import java.util.Map; +import java.util.Collections; +import java.util.List; import java.util.Objects; import java.util.function.DoubleUnaryOperator; @@ -14,12 +15,12 @@ import java.util.function.DoubleUnaryOperator; * * @author bratseth */ -public class MapFunction extends PrimitiveTensorFunction { +public class Map extends PrimitiveTensorFunction { private final TensorFunction argument; private final DoubleUnaryOperator mapper; - public MapFunction(TensorFunction argument, DoubleUnaryOperator mapper) { + public Map(TensorFunction argument, DoubleUnaryOperator mapper) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); Objects.requireNonNull(mapper, "The argument function cannot be null"); this.argument = argument; @@ -30,22 +31,25 @@ public class MapFunction extends PrimitiveTensorFunction { public DoubleUnaryOperator mapper() { return mapper; } @Override + public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + + @Override public PrimitiveTensorFunction toPrimitive() { - return new MapFunction(argument.toPrimitive(), mapper); + return new Map(argument.toPrimitive(), mapper); } @Override - public Tensor execute() { - Tensor argument = argument().execute(); + public Tensor evaluate(EvaluationContext context) { + Tensor argument = argument().evaluate(context); ImmutableMap.Builder<TensorAddress, Double> mappedCells = new ImmutableMap.Builder<>(); - for (Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet()) + for (java.util.Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet()) mappedCells.put(cell.getKey(), mapper.applyAsDouble(cell.getValue())); return new MapTensor(argument.dimensions(), mappedCells.build()); } @Override - public String toString() { - return "map(" + argument.toString() + ", f(a) (" + mapper + "))"; + public String toString(ToStringContext context) { + return "map(" + argument.toString(context) + ", " + mapper + ")"; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java new file mode 100644 index 00000000000..4492ab083d4 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -0,0 +1,38 @@ +package com.yahoo.tensor.functions; + +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * @author bratseth + */ +public class Matmul extends CompositeTensorFunction { + + private final TensorFunction argument1, argument2; + private final String dimension; + + public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) { + this.argument1 = argument1; + this.argument2 = argument2; + this.dimension = dimension; + } + + @Override + public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); } + + @Override + public PrimitiveTensorFunction toPrimitive() { + TensorFunction primitiveArgument1 = argument1.toPrimitive(); + TensorFunction primitiveArgument2 = argument2.toPrimitive(); + return new Reduce(new Join(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + dimension); + } + + @Override + public String toString(ToStringContext context) { + return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")"; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 6e339f91497..ef18cb61b17 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -21,7 +21,7 @@ import java.util.stream.Collectors; * * @author bratseth */ -public class ReduceFunction extends PrimitiveTensorFunction { +public class Reduce extends PrimitiveTensorFunction { public enum Aggregator { avg, count, prod, sum, max, min; } @@ -30,12 +30,12 @@ public class ReduceFunction extends PrimitiveTensorFunction { private final Aggregator aggregator; /** Creates a reduce function reducing aLL dimensions */ - public ReduceFunction(TensorFunction argument, Aggregator aggregator) { + public Reduce(TensorFunction argument, Aggregator aggregator) { this(argument, aggregator, Collections.emptyList()); } /** Creates a reduce function reducing a single dimension */ - public ReduceFunction(TensorFunction argument, Aggregator aggregator, String dimension) { + public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) { this(argument, aggregator, Collections.singletonList(dimension)); } @@ -48,7 +48,7 @@ public class ReduceFunction extends PrimitiveTensorFunction { * producing a dimensionless tensor (a scalar). * @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor */ - public ReduceFunction(TensorFunction argument, Aggregator aggregator, List<String> dimensions) { + public Reduce(TensorFunction argument, Aggregator aggregator, List<String> dimensions) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); Objects.requireNonNull(aggregator, "The aggregator cannot be null"); Objects.requireNonNull(dimensions, "The dimensions cannot be null"); @@ -60,13 +60,16 @@ public class ReduceFunction extends PrimitiveTensorFunction { public TensorFunction argument() { return argument; } @Override + public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + + @Override public PrimitiveTensorFunction toPrimitive() { - return new ReduceFunction(argument.toPrimitive(), aggregator, dimensions); + return new Reduce(argument.toPrimitive(), aggregator, dimensions); } @Override - public String toString() { - return "reduce(" + argument.toString() + ", " + aggregator + commaSeparated(dimensions) + ")"; + public String toString(ToStringContext context) { + return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")"; } private String commaSeparated(List<String> list) { @@ -77,8 +80,8 @@ public class ReduceFunction extends PrimitiveTensorFunction { } @Override - public Tensor execute() { - Tensor argument = this.argument.execute(); + public Tensor evaluate(EvaluationContext context) { + Tensor argument = this.argument.evaluate(context); if ( ! dimensions.isEmpty() && ! argument.dimensions().containsAll(dimensions)) throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/RenameFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 9098243c259..05af86c33e8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/RenameFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -7,6 +7,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -19,13 +20,13 @@ import java.util.stream.Collectors; * * @author bratseth */ -public class RenameFunction extends PrimitiveTensorFunction { +public class Rename extends PrimitiveTensorFunction { private final TensorFunction argument; private final List<String> fromDimensions; private final List<String> toDimensions; - public RenameFunction(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) { + public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null"); Objects.requireNonNull(toDimensions, "The 'to' dimensions cannot be null"); @@ -38,13 +39,16 @@ public class RenameFunction extends PrimitiveTensorFunction { this.fromDimensions = ImmutableList.copyOf(fromDimensions); this.toDimensions = ImmutableList.copyOf(toDimensions); } + + @Override + public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } @Override public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public Tensor execute() { - Tensor tensor = argument.execute(); + public Tensor evaluate(EvaluationContext context) { + Tensor tensor = argument.evaluate(context); Map<String, String> fromToMap = fromToMap(); Set<String> renamedDimensions = tensor.dimensions().stream() .map((d) -> fromToMap.getOrDefault(d, d)) @@ -71,8 +75,8 @@ public class RenameFunction extends PrimitiveTensorFunction { } @Override - public String toString() { - return "rename(" + argument + ", " + + public String toString(ToStringContext context) { + return "rename(" + argument.toString(context) + ", " + toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index 05f196da06d..f1ca4e0525c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -6,23 +6,35 @@ import java.util.function.DoubleUnaryOperator; /** * Factory of scalar Java functions. * The purpose of this is to embellish anonymous functions with a runtime type - * such that they can be inspected and return a usable toString. + * such that they can be inspected and will return a parseable toString. * * @author bratseth */ public class ScalarFunctions { - + + public static DoubleBinaryOperator add() { return new Addition(); } public static DoubleBinaryOperator multiply() { return new Multiplication(); } public static DoubleBinaryOperator divide() { return new Division(); } public static DoubleUnaryOperator square() { return new Square(); } - + public static DoubleUnaryOperator exp() { return new Exponent(); } + + public static class Addition implements DoubleBinaryOperator { + + @Override + public double applyAsDouble(double left, double right) { return left + right; } + + @Override + public String toString() { return "f(a,b)(a+b)"; } + + } + public static class Multiplication implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return left * right; } @Override - public String toString() { return "a * b"; } + public String toString() { return "f(a,b)(a*b)"; } } @@ -32,7 +44,7 @@ public class ScalarFunctions { public double applyAsDouble(double left, double right) { return left / right; } @Override - public String toString() { return "a / b"; } + public String toString() { return "f(a,b)(a/b)"; } } public static class Square implements DoubleUnaryOperator { @@ -41,7 +53,17 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return operand * operand; } @Override - public String toString() { return "a * a"; } + public String toString() { return "f(a)(a*a)"; } + + } + + public static class Exponent implements DoubleUnaryOperator { + + @Override + public double applyAsDouble(double operand) { return Math.exp(operand); } + + @Override + public String toString() { return "f(a)(exp(a))"; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java new file mode 100644 index 00000000000..aee8cedee17 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -0,0 +1,38 @@ +package com.yahoo.tensor.functions; + +import java.util.Collections; +import java.util.List; + +/** + * @author bratseth + */ +public class Softmax extends CompositeTensorFunction { + + private final TensorFunction argument; + private final String dimension; + + public Softmax(TensorFunction argument, String dimension) { + this.argument = argument; + this.dimension = dimension; + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + + @Override + public PrimitiveTensorFunction toPrimitive() { + TensorFunction primitiveArgument = argument.toPrimitive(); + // join(map(t, f(x)(exp(x))), reduce(map(t, f(x)(exp(x))), "sum", "dimension"), f(x,y)(x / y)) + return new Join(new Map(primitiveArgument, ScalarFunctions.exp()), + new Reduce(new Map(primitiveArgument, ScalarFunctions.exp()), + Reduce.Aggregator.sum, + dimension), + ScalarFunctions.divide()); + } + + @Override + public String toString(ToStringContext context) { + return "softmax(" + argument.toString(context) + ", " + dimension + ")"; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index d8e22a2088e..a717292632e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -2,6 +2,8 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.Tensor; +import java.util.List; + /** * A representation of a tensor function which is able to be translated to a set of primitive * tensor functions if necessary. @@ -11,6 +13,9 @@ import com.yahoo.tensor.Tensor; */ public abstract class TensorFunction { + /** Returns the function arguments of this node in the order they are applied */ + public abstract List<TensorFunction> functionArguments(); + /** * Translate this function - and all of its arguments recursively - * to a tree of primitive functions only. @@ -19,6 +24,24 @@ public abstract class TensorFunction { */ public abstract PrimitiveTensorFunction toPrimitive(); - public abstract Tensor execute(); + /** + * Evaluates this tensor. + * + * @param context a context which must be passed to all nexted functions when evaluating + */ + public abstract Tensor evaluate(EvaluationContext context); + + /** Evaluate with no context */ + public final Tensor evaluate() { return evaluate(EvaluationContext.empty()); } + + /** + * Return a string representation of this context. + * + * @param context a context which must be passed to all nexted functions when requesting the string value + */ + public abstract String toString(ToStringContext context); + + @Override + public final String toString() { return toString(ToStringContext.empty()); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java new file mode 100644 index 00000000000..b71229703d2 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java @@ -0,0 +1,14 @@ +package com.yahoo.tensor.functions; + +/** + * A context which is passed down to all nested functions when returning a string representation. + * The default implementation is empty as this library does not in itself have any need for a + * context. + * + * @author bratseth + */ +public interface ToStringContext { + + static ToStringContext empty() { return new ToStringContext() {}; } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java new file mode 100644 index 00000000000..1988c1d2390 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java @@ -0,0 +1,45 @@ +package com.yahoo.tensor.functions; + +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * @author bratseth + */ +public class XwPlusB extends CompositeTensorFunction { + + private final TensorFunction x, w, b; + private final String dimension; + + public XwPlusB(TensorFunction x, TensorFunction w, TensorFunction b, String dimension) { + this.x = x; + this.w = w; + this.b = b; + this.dimension = dimension; + } + + @Override + public List<TensorFunction> functionArguments() { return ImmutableList.of(x, w, b); } + + @Override + public PrimitiveTensorFunction toPrimitive() { + TensorFunction primitiveX = x.toPrimitive(); + TensorFunction primitiveW = w.toPrimitive(); + TensorFunction primitiveB = b.toPrimitive(); + return new Join(new Reduce(new Join(primitiveX, primitiveW, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + dimension), + primitiveB, + ScalarFunctions.add()); + } + + @Override + public String toString(ToStringContext context) { + return "xw_plus_b(" + x.toString(context) + ", " + + w.toString(context) + ", " + + b.toString(context) + ", " + + dimension + ")"; + } + +} |