summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-11-24 14:18:01 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-11-24 14:18:01 +0100
commitcb2dc3460fa31dffb51e54847283038e8a0ae93c (patch)
treee96497fe6b167f8867ad9cb225ea979a6e09dab8 /vespajlib/src/main/java/com
parent437a2dc519cc991302c01acb8cd1df1e96b1283d (diff)
Implement composite functions
Diffstat (limited to 'vespajlib/src/main/java/com')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java56
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java14
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/GeneratedTensor.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java (renamed from vespajlib/src/main/java/com/yahoo/tensor/functions/JoinFunction.java)22
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java (renamed from vespajlib/src/main/java/com/yahoo/tensor/functions/MapFunction.java)22
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java38
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java (renamed from vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceFunction.java)21
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java (renamed from vespajlib/src/main/java/com/yahoo/tensor/functions/RenameFunction.java)16
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java34
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java38
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java25
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java14
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java45
17 files changed, 327 insertions, 86 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 06a01fad52e..d11225a5fe2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -2,15 +2,16 @@
package com.yahoo.tensor;
import com.google.common.annotations.Beta;
-import com.google.common.collect.ImmutableMap;
import com.yahoo.tensor.functions.ConstantTensor;
+import com.yahoo.tensor.functions.EvaluationContext;
import com.yahoo.tensor.functions.GeneratedTensor;
-import com.yahoo.tensor.functions.JoinFunction;
+import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.L1Normalize;
import com.yahoo.tensor.functions.L2Normalize;
-import com.yahoo.tensor.functions.MapFunction;
-import com.yahoo.tensor.functions.ReduceFunction;
-import com.yahoo.tensor.functions.RenameFunction;
+import com.yahoo.tensor.functions.Matmul;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.Softmax;
import java.util.ArrayList;
import java.util.Collections;
@@ -21,7 +22,6 @@ import java.util.Set;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
-import java.util.function.UnaryOperator;
/**
* A multidimensional array which can be used in computations.
@@ -60,34 +60,42 @@ public interface Tensor {
// ----------------- Primitive tensor functions
default Tensor map(DoubleUnaryOperator mapper) {
- return new MapFunction(new ConstantTensor(this), mapper).execute();
+ return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate();
}
/** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */
- default Tensor reduce(ReduceFunction.Aggregator aggregator, List<String> dimensions) {
- return new ReduceFunction(new ConstantTensor(this), aggregator, dimensions).execute();
+ default Tensor reduce(Reduce.Aggregator aggregator, List<String> dimensions) {
+ return new Reduce(new ConstantTensor(this), aggregator, dimensions).evaluate();
}
default Tensor join(Tensor argument, DoubleBinaryOperator combinator) {
- return new JoinFunction(new ConstantTensor(this), new ConstantTensor(argument), combinator).execute();
+ return new Join(new ConstantTensor(this), new ConstantTensor(argument), combinator).evaluate();
}
default Tensor rename(List<String> fromDimensions, List<String> toDimensions) {
- return new RenameFunction(new ConstantTensor(this), fromDimensions, toDimensions).execute();
+ return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate();
}
static Tensor from(TensorType type, Function<List<Integer>, Double> valueSupplier) {
- return new GeneratedTensor(type, valueSupplier).execute();
+ return new GeneratedTensor(type, valueSupplier).evaluate();
}
// ----------------- Composite tensor functions which have a defined primitive mapping
default Tensor l1Normalize(String dimension) {
- return new L1Normalize(new ConstantTensor(this), dimension).execute();
+ return new L1Normalize(new ConstantTensor(this), dimension).evaluate();
}
default Tensor l2Normalize(String dimension) {
- return new L2Normalize(new ConstantTensor(this), dimension).execute();
+ return new L2Normalize(new ConstantTensor(this), dimension).evaluate();
+ }
+
+ default Tensor matmul(Tensor argument, String dimension) {
+ return new Matmul(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate();
+ }
+
+ default Tensor softmax(String dimension) {
+ return new Softmax(new ConstantTensor(this), dimension).evaluate();
}
// ----------------- Composite tensor functions mapped to primitives here on the fly
@@ -98,13 +106,15 @@ public interface Tensor {
default Tensor subtract(Tensor argument) { return join(argument, (a, b) -> (a - b )); }
default Tensor max(Tensor argument) { return join(argument, (a, b) -> (a > b ? a : b )); }
default Tensor min(Tensor argument) { return join(argument, (a, b) -> (a < b ? a : b )); }
+ default Tensor atan2(Tensor argument) { return join(argument, Math::atan2); }
+ default Tensor equal(Tensor argument) { return join(argument, (a, b) -> ( a == b ? 1.0 : 0.0)); }
- default Tensor avg(List<String> dimensions) { return reduce(ReduceFunction.Aggregator.avg, dimensions); }
- default Tensor count(List<String> dimensions) { return reduce(ReduceFunction.Aggregator.count, dimensions); }
- default Tensor max(List<String> dimensions) { return reduce(ReduceFunction.Aggregator.max, dimensions); }
- default Tensor min(List<String> dimensions) { return reduce(ReduceFunction.Aggregator.min, dimensions); }
- default Tensor prod(List<String> dimensions) { return reduce(ReduceFunction.Aggregator.prod, dimensions); }
- default Tensor sum(List<String> dimensions) { return reduce(ReduceFunction.Aggregator.sum, dimensions); }
+ default Tensor avg(List<String> dimensions) { return reduce(Reduce.Aggregator.avg, dimensions); }
+ default Tensor count(List<String> dimensions) { return reduce(Reduce.Aggregator.count, dimensions); }
+ default Tensor max(List<String> dimensions) { return reduce(Reduce.Aggregator.max, dimensions); }
+ default Tensor min(List<String> dimensions) { return reduce(Reduce.Aggregator.min, dimensions); }
+ default Tensor prod(List<String> dimensions) { return reduce(Reduce.Aggregator.prod, dimensions); }
+ default Tensor sum(List<String> dimensions) { return reduce(Reduce.Aggregator.sum, dimensions); }
/**
* Returns true if the given tensor is mathematically equal to this:
@@ -177,11 +187,11 @@ public interface Tensor {
}
static String contentToString(Tensor tensor) {
- List<Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet());
- Collections.sort(cellEntries, Map.Entry.<TensorAddress, Double>comparingByKey());
+ List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet());
+ Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey());
StringBuilder b = new StringBuilder("{");
- for (Map.Entry<TensorAddress, Double> cell : cellEntries) {
+ for (java.util.Map.Entry<TensorAddress, Double> cell : cellEntries) {
b.append(cell.getKey()).append(":").append(cell.getValue());
b.append(",");
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
index e564e1d6f25..31454e28baf 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
@@ -10,8 +10,8 @@ import com.yahoo.tensor.Tensor;
*/
public abstract class CompositeTensorFunction extends TensorFunction {
- /** Executes this by first converting it to a primitive function */
+ /** Evaluates this by first converting it to a primitive function */
@Override
- public final Tensor execute() { return toPrimitive().execute(); }
+ public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index c86ac7b137b..0727579a331 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -3,6 +3,9 @@ package com.yahoo.tensor.functions;
import com.yahoo.tensor.MapTensor;
import com.yahoo.tensor.Tensor;
+import java.util.Collections;
+import java.util.List;
+
/**
* A function which returns a constant tensor.
*
@@ -19,14 +22,17 @@ public class ConstantTensor extends PrimitiveTensorFunction {
public ConstantTensor(Tensor tensor) {
this.constant = tensor;
}
-
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public Tensor execute() { return constant; }
+ public Tensor evaluate(EvaluationContext context) { return constant; }
@Override
- public String toString() { return constant.toString(); }
+ public String toString(ToStringContext context) { return constant.toString(); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java
new file mode 100644
index 00000000000..24a4c61a58c
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EvaluationContext.java
@@ -0,0 +1,14 @@
+package com.yahoo.tensor.functions;
+
+/**
+ * An evaluation context which is passed down to all nested functions during evaluation.
+ * The default implementation is empty as this library does not in itself have any need for a
+ * context.
+ *
+ * @author bratseth
+ */
+public interface EvaluationContext {
+
+ static EvaluationContext empty() { return new EvaluationContext() {}; }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/GeneratedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/GeneratedTensor.java
index 998ebbb3f2f..81346b535ff 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/GeneratedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/GeneratedTensor.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.functions;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
@@ -38,16 +39,19 @@ public class GeneratedTensor extends PrimitiveTensorFunction {
if (dimension.type() != TensorType.Dimension.Type.indexedBound)
throw new IllegalArgumentException("A generated tensor can only have indexed bound dimensions");
}
-
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public Tensor execute() {
+ public Tensor evaluate(EvaluationContext context) {
throw new UnsupportedOperationException("Not implemented"); // TODO
}
@Override
- public String toString() { return type + "(" + generator + ")"; }
+ public String toString(ToStringContext context) { return type + "(" + generator + ")"; }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/JoinFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 9104307e866..323da5906c3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/JoinFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -1,5 +1,6 @@
package com.yahoo.tensor.functions;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.yahoo.tensor.MapTensor;
@@ -21,12 +22,12 @@ import java.util.function.DoubleBinaryOperator;
*
* @author bratseth
*/
-public class JoinFunction extends PrimitiveTensorFunction {
+public class Join extends PrimitiveTensorFunction {
private final TensorFunction argumentA, argumentB;
private final DoubleBinaryOperator combinator;
- public JoinFunction(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) {
+ public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) {
Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
Objects.requireNonNull(combinator, "The combinator function cannot be null");
@@ -38,23 +39,26 @@ public class JoinFunction extends PrimitiveTensorFunction {
public TensorFunction argumentA() { return argumentA; }
public TensorFunction argumentB() { return argumentB; }
public DoubleBinaryOperator combinator() { return combinator; }
-
+
+ @Override
+ public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); }
+
@Override
public PrimitiveTensorFunction toPrimitive() {
- return new JoinFunction(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
+ return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
}
@Override
- public String toString() {
- return "join(" + argumentA.toString() + ", " + argumentB.toString() + ", f(a, b) (" + combinator + "))";
+ public String toString(ToStringContext context) {
+ return "join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ")";
}
private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>();
@Override
- public Tensor execute() {
- Tensor a = argumentA.execute();
- Tensor b = argumentB.execute();
+ public Tensor evaluate(EvaluationContext context) {
+ Tensor a = argumentA.evaluate(context);
+ Tensor b = argumentB.evaluate(context);
// Dimension product
Set<String> dimensions = combineDimensions(a, b);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
index ec2070d0231..0eeb1762888 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
@@ -1,5 +1,8 @@
package com.yahoo.tensor.functions;
+import java.util.Collections;
+import java.util.List;
+
/**
* @author bratseth
*/
@@ -12,18 +15,21 @@ public class L1Normalize extends CompositeTensorFunction {
this.argument = argument;
this.dimension = dimension;
}
-
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+
@Override
public PrimitiveTensorFunction toPrimitive() {
TensorFunction primitiveArgument = argument.toPrimitive();
- return new JoinFunction(primitiveArgument,
- new ReduceFunction(primitiveArgument, ReduceFunction.Aggregator.avg, dimension),
- ScalarFunctions.multiply());
+ return new Join(primitiveArgument,
+ new Reduce(primitiveArgument, Reduce.Aggregator.avg, dimension),
+ ScalarFunctions.multiply());
}
@Override
- public String toString() {
- return "l1_normalize(" + argument + ")";
+ public String toString(ToStringContext context) {
+ return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")";
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
index 4abac10c1d7..fe041b38e44 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
@@ -1,5 +1,8 @@
package com.yahoo.tensor.functions;
+import java.util.Collections;
+import java.util.List;
+
/**
* @author bratseth
*/
@@ -12,21 +15,24 @@ public class L2Normalize extends CompositeTensorFunction {
this.argument = argument;
this.dimension = dimension;
}
-
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+
@Override
public PrimitiveTensorFunction toPrimitive() {
TensorFunction primitiveArgument = argument.toPrimitive();
- return new JoinFunction(primitiveArgument,
- new MapFunction(new ReduceFunction(new MapFunction(primitiveArgument, ScalarFunctions.square()),
- ReduceFunction.Aggregator.sum,
- dimension),
- ScalarFunctions.square()),
- ScalarFunctions.divide());
+ return new Join(primitiveArgument,
+ new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.square()),
+ Reduce.Aggregator.sum,
+ dimension),
+ ScalarFunctions.square()),
+ ScalarFunctions.divide());
}
@Override
- public String toString() {
- return "l2_normalize(" + argument + ")";
+ public String toString(ToStringContext context) {
+ return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")";
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index 21878e30fce..5db88953c64 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -5,7 +5,8 @@ import com.yahoo.tensor.MapTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
-import java.util.Map;
+import java.util.Collections;
+import java.util.List;
import java.util.Objects;
import java.util.function.DoubleUnaryOperator;
@@ -14,12 +15,12 @@ import java.util.function.DoubleUnaryOperator;
*
* @author bratseth
*/
-public class MapFunction extends PrimitiveTensorFunction {
+public class Map extends PrimitiveTensorFunction {
private final TensorFunction argument;
private final DoubleUnaryOperator mapper;
- public MapFunction(TensorFunction argument, DoubleUnaryOperator mapper) {
+ public Map(TensorFunction argument, DoubleUnaryOperator mapper) {
Objects.requireNonNull(argument, "The argument tensor cannot be null");
Objects.requireNonNull(mapper, "The argument function cannot be null");
this.argument = argument;
@@ -30,22 +31,25 @@ public class MapFunction extends PrimitiveTensorFunction {
public DoubleUnaryOperator mapper() { return mapper; }
@Override
+ public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+
+ @Override
public PrimitiveTensorFunction toPrimitive() {
- return new MapFunction(argument.toPrimitive(), mapper);
+ return new Map(argument.toPrimitive(), mapper);
}
@Override
- public Tensor execute() {
- Tensor argument = argument().execute();
+ public Tensor evaluate(EvaluationContext context) {
+ Tensor argument = argument().evaluate(context);
ImmutableMap.Builder<TensorAddress, Double> mappedCells = new ImmutableMap.Builder<>();
- for (Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet())
+ for (java.util.Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet())
mappedCells.put(cell.getKey(), mapper.applyAsDouble(cell.getValue()));
return new MapTensor(argument.dimensions(), mappedCells.build());
}
@Override
- public String toString() {
- return "map(" + argument.toString() + ", f(a) (" + mapper + "))";
+ public String toString(ToStringContext context) {
+ return "map(" + argument.toString(context) + ", " + mapper + ")";
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
new file mode 100644
index 00000000000..4492ab083d4
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -0,0 +1,38 @@
+package com.yahoo.tensor.functions;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * @author bratseth
+ */
+public class Matmul extends CompositeTensorFunction {
+
+ private final TensorFunction argument1, argument2;
+ private final String dimension;
+
+ public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) {
+ this.argument1 = argument1;
+ this.argument2 = argument2;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ TensorFunction primitiveArgument1 = argument1.toPrimitive();
+ TensorFunction primitiveArgument2 = argument2.toPrimitive();
+ return new Reduce(new Join(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ dimension);
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")";
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 6e339f91497..ef18cb61b17 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -21,7 +21,7 @@ import java.util.stream.Collectors;
*
* @author bratseth
*/
-public class ReduceFunction extends PrimitiveTensorFunction {
+public class Reduce extends PrimitiveTensorFunction {
public enum Aggregator { avg, count, prod, sum, max, min; }
@@ -30,12 +30,12 @@ public class ReduceFunction extends PrimitiveTensorFunction {
private final Aggregator aggregator;
/** Creates a reduce function reducing aLL dimensions */
- public ReduceFunction(TensorFunction argument, Aggregator aggregator) {
+ public Reduce(TensorFunction argument, Aggregator aggregator) {
this(argument, aggregator, Collections.emptyList());
}
/** Creates a reduce function reducing a single dimension */
- public ReduceFunction(TensorFunction argument, Aggregator aggregator, String dimension) {
+ public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) {
this(argument, aggregator, Collections.singletonList(dimension));
}
@@ -48,7 +48,7 @@ public class ReduceFunction extends PrimitiveTensorFunction {
* producing a dimensionless tensor (a scalar).
* @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor
*/
- public ReduceFunction(TensorFunction argument, Aggregator aggregator, List<String> dimensions) {
+ public Reduce(TensorFunction argument, Aggregator aggregator, List<String> dimensions) {
Objects.requireNonNull(argument, "The argument tensor cannot be null");
Objects.requireNonNull(aggregator, "The aggregator cannot be null");
Objects.requireNonNull(dimensions, "The dimensions cannot be null");
@@ -60,13 +60,16 @@ public class ReduceFunction extends PrimitiveTensorFunction {
public TensorFunction argument() { return argument; }
@Override
+ public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+
+ @Override
public PrimitiveTensorFunction toPrimitive() {
- return new ReduceFunction(argument.toPrimitive(), aggregator, dimensions);
+ return new Reduce(argument.toPrimitive(), aggregator, dimensions);
}
@Override
- public String toString() {
- return "reduce(" + argument.toString() + ", " + aggregator + commaSeparated(dimensions) + ")";
+ public String toString(ToStringContext context) {
+ return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
}
private String commaSeparated(List<String> list) {
@@ -77,8 +80,8 @@ public class ReduceFunction extends PrimitiveTensorFunction {
}
@Override
- public Tensor execute() {
- Tensor argument = this.argument.execute();
+ public Tensor evaluate(EvaluationContext context) {
+ Tensor argument = this.argument.evaluate(context);
if ( ! dimensions.isEmpty() && ! argument.dimensions().containsAll(dimensions))
throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/RenameFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index 9098243c259..05af86c33e8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/RenameFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -7,6 +7,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -19,13 +20,13 @@ import java.util.stream.Collectors;
*
* @author bratseth
*/
-public class RenameFunction extends PrimitiveTensorFunction {
+public class Rename extends PrimitiveTensorFunction {
private final TensorFunction argument;
private final List<String> fromDimensions;
private final List<String> toDimensions;
- public RenameFunction(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) {
+ public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) {
Objects.requireNonNull(argument, "The argument tensor cannot be null");
Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null");
Objects.requireNonNull(toDimensions, "The 'to' dimensions cannot be null");
@@ -38,13 +39,16 @@ public class RenameFunction extends PrimitiveTensorFunction {
this.fromDimensions = ImmutableList.copyOf(fromDimensions);
this.toDimensions = ImmutableList.copyOf(toDimensions);
}
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public Tensor execute() {
- Tensor tensor = argument.execute();
+ public Tensor evaluate(EvaluationContext context) {
+ Tensor tensor = argument.evaluate(context);
Map<String, String> fromToMap = fromToMap();
Set<String> renamedDimensions = tensor.dimensions().stream()
.map((d) -> fromToMap.getOrDefault(d, d))
@@ -71,8 +75,8 @@ public class RenameFunction extends PrimitiveTensorFunction {
}
@Override
- public String toString() {
- return "rename(" + argument + ", " +
+ public String toString(ToStringContext context) {
+ return "rename(" + argument.toString(context) + ", " +
toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index 05f196da06d..f1ca4e0525c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -6,23 +6,35 @@ import java.util.function.DoubleUnaryOperator;
/**
* Factory of scalar Java functions.
* The purpose of this is to embellish anonymous functions with a runtime type
- * such that they can be inspected and return a usable toString.
+ * such that they can be inspected and will return a parseable toString.
*
* @author bratseth
*/
public class ScalarFunctions {
-
+
+ public static DoubleBinaryOperator add() { return new Addition(); }
public static DoubleBinaryOperator multiply() { return new Multiplication(); }
public static DoubleBinaryOperator divide() { return new Division(); }
public static DoubleUnaryOperator square() { return new Square(); }
-
+ public static DoubleUnaryOperator exp() { return new Exponent(); }
+
+ public static class Addition implements DoubleBinaryOperator {
+
+ @Override
+ public double applyAsDouble(double left, double right) { return left + right; }
+
+ @Override
+ public String toString() { return "f(a,b)(a+b)"; }
+
+ }
+
public static class Multiplication implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left * right; }
@Override
- public String toString() { return "a * b"; }
+ public String toString() { return "f(a,b)(a*b)"; }
}
@@ -32,7 +44,7 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return left / right; }
@Override
- public String toString() { return "a / b"; }
+ public String toString() { return "f(a,b)(a/b)"; }
}
public static class Square implements DoubleUnaryOperator {
@@ -41,7 +53,17 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return operand * operand; }
@Override
- public String toString() { return "a * a"; }
+ public String toString() { return "f(a)(a*a)"; }
+
+ }
+
+ public static class Exponent implements DoubleUnaryOperator {
+
+ @Override
+ public double applyAsDouble(double operand) { return Math.exp(operand); }
+
+ @Override
+ public String toString() { return "f(a)(exp(a))"; }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
new file mode 100644
index 00000000000..aee8cedee17
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -0,0 +1,38 @@
+package com.yahoo.tensor.functions;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * @author bratseth
+ */
+public class Softmax extends CompositeTensorFunction {
+
+ private final TensorFunction argument;
+ private final String dimension;
+
+ public Softmax(TensorFunction argument, String dimension) {
+ this.argument = argument;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ TensorFunction primitiveArgument = argument.toPrimitive();
+ // join(map(t, f(x)(exp(x))), reduce(map(t, f(x)(exp(x))), "sum", "dimension"), f(x,y)(x / y))
+ return new Join(new Map(primitiveArgument, ScalarFunctions.exp()),
+ new Reduce(new Map(primitiveArgument, ScalarFunctions.exp()),
+ Reduce.Aggregator.sum,
+ dimension),
+ ScalarFunctions.divide());
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "softmax(" + argument.toString(context) + ", " + dimension + ")";
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index d8e22a2088e..a717292632e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -2,6 +2,8 @@ package com.yahoo.tensor.functions;
import com.yahoo.tensor.Tensor;
+import java.util.List;
+
/**
* A representation of a tensor function which is able to be translated to a set of primitive
* tensor functions if necessary.
@@ -11,6 +13,9 @@ import com.yahoo.tensor.Tensor;
*/
public abstract class TensorFunction {
+ /** Returns the function arguments of this node in the order they are applied */
+ public abstract List<TensorFunction> functionArguments();
+
/**
* Translate this function - and all of its arguments recursively -
* to a tree of primitive functions only.
@@ -19,6 +24,24 @@ public abstract class TensorFunction {
*/
public abstract PrimitiveTensorFunction toPrimitive();
- public abstract Tensor execute();
+ /**
+ * Evaluates this tensor.
+ *
+ * @param context a context which must be passed to all nexted functions when evaluating
+ */
+ public abstract Tensor evaluate(EvaluationContext context);
+
+ /** Evaluate with no context */
+ public final Tensor evaluate() { return evaluate(EvaluationContext.empty()); }
+
+ /**
+ * Return a string representation of this context.
+ *
+ * @param context a context which must be passed to all nexted functions when requesting the string value
+ */
+ public abstract String toString(ToStringContext context);
+
+ @Override
+ public final String toString() { return toString(ToStringContext.empty()); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java
new file mode 100644
index 00000000000..b71229703d2
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java
@@ -0,0 +1,14 @@
+package com.yahoo.tensor.functions;
+
+/**
+ * A context which is passed down to all nested functions when returning a string representation.
+ * The default implementation is empty as this library does not in itself have any need for a
+ * context.
+ *
+ * @author bratseth
+ */
+public interface ToStringContext {
+
+ static ToStringContext empty() { return new ToStringContext() {}; }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
new file mode 100644
index 00000000000..1988c1d2390
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
@@ -0,0 +1,45 @@
+package com.yahoo.tensor.functions;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * @author bratseth
+ */
+public class XwPlusB extends CompositeTensorFunction {
+
+ private final TensorFunction x, w, b;
+ private final String dimension;
+
+ public XwPlusB(TensorFunction x, TensorFunction w, TensorFunction b, String dimension) {
+ this.x = x;
+ this.w = w;
+ this.b = b;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return ImmutableList.of(x, w, b); }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ TensorFunction primitiveX = x.toPrimitive();
+ TensorFunction primitiveW = w.toPrimitive();
+ TensorFunction primitiveB = b.toPrimitive();
+ return new Join(new Reduce(new Join(primitiveX, primitiveW, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ dimension),
+ primitiveB,
+ ScalarFunctions.add());
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "xw_plus_b(" + x.toString(context) + ", " +
+ w.toString(context) + ", " +
+ b.toString(context) + ", " +
+ dimension + ")";
+ }
+
+}