aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java20
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java36
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java28
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java20
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java28
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java22
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java25
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java26
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java20
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java27
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java14
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java157
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java31
23 files changed, 324 insertions, 281 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
index 3478061b32c..c52e566ed1e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
@@ -1,38 +1,40 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.evaluation.TypeContext;
+
import java.util.Collections;
import java.util.List;
/**
* @author bratseth
*/
-public class Argmax extends CompositeTensorFunction {
+public class Argmax<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final String dimension;
- public Argmax(TensorFunction argument, String dimension) {
+ public Argmax(TensorFunction<NAMETYPE> argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Argmax must have 1 argument, got " + arguments.size());
- return new Argmax(arguments.get(0), dimension);
+ return new Argmax<>(arguments.get(0), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument = argument.toPrimitive();
- return new Join(primitiveArgument,
- new Reduce(primitiveArgument, Reduce.Aggregator.max, dimension),
- ScalarFunctions.equal());
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive();
+ return new Join<>(primitiveArgument,
+ new Reduce<>(primitiveArgument, Reduce.Aggregator.max, dimension),
+ ScalarFunctions.equal());
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
index ba5b3c3e4b2..aa0333aa421 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
@@ -1,38 +1,40 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.evaluation.TypeContext;
+
import java.util.Collections;
import java.util.List;
/**
* @author bratseth
*/
-public class Argmin extends CompositeTensorFunction {
+public class Argmin<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final String dimension;
- public Argmin(TensorFunction argument, String dimension) {
+ public Argmin(TensorFunction<NAMETYPE> argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Argmin must have 1 argument, got " + arguments.size());
- return new Argmin(arguments.get(0), dimension);
+ return new Argmin<>(arguments.get(0), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument = argument.toPrimitive();
- return new Join(primitiveArgument,
- new Reduce(primitiveArgument, Reduce.Aggregator.min, dimension),
- ScalarFunctions.equal());
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive();
+ return new Join<>(primitiveArgument,
+ new Reduce<>(primitiveArgument, Reduce.Aggregator.min, dimension),
+ ScalarFunctions.equal());
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
index 5dd2cc442aa..59c0fae39b5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
@@ -12,17 +12,17 @@ import com.yahoo.tensor.evaluation.TypeContext;
*
* @author bratseth
*/
-public abstract class CompositeTensorFunction extends TensorFunction {
+public abstract class CompositeTensorFunction<NAMETYPE extends TypeContext.Name> extends TensorFunction<NAMETYPE> {
/** Finds the type this produces by first converting it to a primitive function */
@Override
- public final <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public final TensorType type(TypeContext<NAMETYPE> context) {
return toPrimitive().type(context);
}
/** Evaluates this by first converting it to a primitive function */
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
return toPrimitive().evaluate(context);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 42c6fe2f4aa..a31a7da67e5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -23,12 +23,12 @@ import java.util.stream.Collectors;
*
* @author bratseth
*/
-public class Concat extends PrimitiveTensorFunction {
+public class Concat<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> {
- private final TensorFunction argumentA, argumentB;
+ private final TensorFunction<NAMETYPE> argumentA, argumentB;
private final String dimension;
- public Concat(TensorFunction argumentA, TensorFunction argumentB, String dimension) {
+ public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) {
Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
Objects.requireNonNull(dimension, "The dimension cannot be null");
@@ -38,18 +38,18 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if (arguments.size() != 2)
throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
- return new Concat(arguments.get(0), arguments.get(1), dimension);
+ return new Concat<>(arguments.get(0), arguments.get(1), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Concat(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
}
@Override
@@ -58,7 +58,7 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext<NAMETYPE> context) {
return type(argumentA.type(context), argumentB.type(context));
}
@@ -86,7 +86,7 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index 7c1ce068c90..9d6d488eb60 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -14,7 +14,7 @@ import java.util.List;
*
* @author bratseth
*/
-public class ConstantTensor extends PrimitiveTensorFunction {
+public class ConstantTensor<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> {
private final Tensor constant;
@@ -27,23 +27,23 @@ public class ConstantTensor extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.emptyList(); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("ConstantTensor must have 0 arguments, got " + arguments.size());
return this;
}
@Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return constant.type(); }
+ public TensorType type(TypeContext<NAMETYPE> context) { return constant.type(); }
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; }
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; }
@Override
public String toString(ToStringContext context) { return constant.toString(); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
index e302f6606e7..638a5246378 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor.functions;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.List;
@@ -14,7 +15,7 @@ import java.util.stream.Stream;
*
* @author bratseth
*/
-public class Diag extends CompositeTensorFunction {
+public class Diag<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> {
private final TensorType type;
private final Function<List<Long>, Double> diagFunction;
@@ -25,18 +26,18 @@ public class Diag extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.emptyList(); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Diag must have 0 arguments, got " + arguments.size());
return this;
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Generate(type, diagFunction);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Generate<>(type, diagFunction);
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index a75e49c6402..6830ec50c5f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -20,7 +20,7 @@ import java.util.function.Function;
*
* @author bratseth
*/
-public abstract class DynamicTensor extends PrimitiveTensorFunction {
+public abstract class DynamicTensor<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> {
private final TensorType type;
@@ -29,20 +29,20 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; }
+ public TensorType type(TypeContext<NAMETYPE> context) { return type; }
@Override
- public List<TensorFunction> arguments() { return Collections.emptyList(); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if (arguments.size() != 0)
throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size());
return this;
}
@Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
TensorType type() { return type; }
@@ -54,27 +54,26 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
abstract String contentToString(ToStringContext context);
/** Creates a dynamic tensor function. The cell addresses must match the type. */
- public static DynamicTensor from(TensorType type, Map<TensorAddress, ScalarFunction> cells) {
- return new MappedDynamicTensor(type, cells);
+ public static <NAMETYPE extends TypeContext.Name> DynamicTensor<NAMETYPE> from(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) {
+ return new MappedDynamicTensor<>(type, cells);
}
/** Creates a dynamic tensor function for a bound, indexed tensor */
- public static DynamicTensor from(TensorType type, List<ScalarFunction> cells) {
- return new IndexedDynamicTensor(type, cells);
+ public static <NAMETYPE extends TypeContext.Name> DynamicTensor<NAMETYPE> from(TensorType type, List<ScalarFunction<NAMETYPE>> cells) {
+ return new IndexedDynamicTensor<>(type, cells);
}
- private static class MappedDynamicTensor extends DynamicTensor {
+ private static class MappedDynamicTensor<NAMETYPE extends TypeContext.Name> extends DynamicTensor<NAMETYPE> {
- private final ImmutableMap<TensorAddress, ScalarFunction> cells;
+ private final ImmutableMap<TensorAddress, ScalarFunction<NAMETYPE>> cells;
- MappedDynamicTensor(TensorType type, Map<TensorAddress, ScalarFunction> cells) {
+ MappedDynamicTensor(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) {
super(type);
this.cells = ImmutableMap.copyOf(cells);
}
@Override
- @SuppressWarnings("unchecked")
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor.Builder builder = Tensor.Builder.of(type());
for (var cell : cells.entrySet())
builder.cell(cell.getKey(), cell.getValue().apply(context));
@@ -102,11 +101,11 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
}
- private static class IndexedDynamicTensor extends DynamicTensor {
+ private static class IndexedDynamicTensor<NAMETYPE extends TypeContext.Name> extends DynamicTensor<NAMETYPE> {
- private final List<ScalarFunction> cells;
+ private final List<ScalarFunction<NAMETYPE>> cells;
- IndexedDynamicTensor(TensorType type, List<ScalarFunction> cells) {
+ IndexedDynamicTensor(TensorType type, List<ScalarFunction<NAMETYPE>> cells) {
super(type);
if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))
throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " +
@@ -115,8 +114,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
}
@Override
- @SuppressWarnings("unchecked")
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type());
for (int i = 0; i < cells.size(); i++)
builder.cellByDirectIndex(i, cells.get(i).apply(context));
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 52620814ecd..aaed607aaa1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -19,13 +19,13 @@ import java.util.function.Function;
*
* @author bratseth
*/
-public class Generate extends PrimitiveTensorFunction {
+public class Generate<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> {
private final TensorType type;
// One of these are null
private final Function<List<Long>, Double> freeGenerator;
- private final ScalarFunction boundGenerator;
+ private final ScalarFunction<NAMETYPE> boundGenerator;
/** The same as Generate.free */
public Generate(TensorType type, Function<List<Long>, Double> generator) {
@@ -40,8 +40,8 @@ public class Generate extends PrimitiveTensorFunction {
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
- public static Generate free(TensorType type, Function<List<Long>, Double> generator) {
- return new Generate(type, Objects.requireNonNull(generator), null);
+ public static <NAMETYPE extends TypeContext.Name> Generate<NAMETYPE> free(TensorType type, Function<List<Long>, Double> generator) {
+ return new Generate<>(type, Objects.requireNonNull(generator), null);
}
/**
@@ -52,11 +52,11 @@ public class Generate extends PrimitiveTensorFunction {
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
- public static Generate bound(TensorType type, ScalarFunction generator) {
- return new Generate(type, null, Objects.requireNonNull(generator));
+ public static <NAMETYPE extends TypeContext.Name> Generate<NAMETYPE> bound(TensorType type, ScalarFunction<NAMETYPE> generator) {
+ return new Generate<>(type, null, Objects.requireNonNull(generator));
}
- private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction boundGenerator) {
+ private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction<NAMETYPE> boundGenerator) {
Objects.requireNonNull(type, "The argument tensor type cannot be null");
validateType(type);
this.type = type;
@@ -71,26 +71,26 @@ public class Generate extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.emptyList(); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size());
return this;
}
@Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; }
+ public TensorType type(TypeContext<NAMETYPE> context) { return type; }
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor.Builder builder = Tensor.Builder.of(type);
IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type));
- GenerateContext<NAMETYPE> generateContext = new GenerateContext<>(type, context);
+ GenerateContext generateContext = new GenerateContext(type, context);
for (int i = 0; i < indexes.size(); i++) {
indexes.next();
builder.cell(generateContext.apply(indexes), indexes.indexesForReading());
@@ -120,7 +120,7 @@ public class Generate extends PrimitiveTensorFunction {
* This returns all the current index values as variables and falls back to delivering from the given
* evaluation context.
*/
- private class GenerateContext<NAMETYPE extends TypeContext.Name> implements EvaluationContext<NAMETYPE> {
+ private class GenerateContext implements EvaluationContext<NAMETYPE> {
private final TensorType type;
private final EvaluationContext<NAMETYPE> context;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 2939b964f04..29239957260 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -31,12 +31,12 @@ import java.util.function.DoubleBinaryOperator;
*
* @author bratseth
*/
-public class Join extends PrimitiveTensorFunction {
+public class Join<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> {
- private final TensorFunction argumentA, argumentB;
+ private final TensorFunction<NAMETYPE> argumentA, argumentB;
private final DoubleBinaryOperator combinator;
- public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) {
+ public Join(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator combinator) {
Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
Objects.requireNonNull(combinator, "The combinator function cannot be null");
@@ -53,18 +53,18 @@ public class Join extends PrimitiveTensorFunction {
public DoubleBinaryOperator combinator() { return combinator; }
@Override
- public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 2)
throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size());
- return new Join(arguments.get(0), arguments.get(1), combinator);
+ return new Join<>(arguments.get(0), arguments.get(1), combinator);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Join<>(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
}
@Override
@@ -73,12 +73,12 @@ public class Join extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext<NAMETYPE> context) {
return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build();
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
index 7939457a101..3ec7ed4ed07 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
@@ -1,39 +1,41 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.evaluation.TypeContext;
+
import java.util.Collections;
import java.util.List;
/**
* @author bratseth
*/
-public class L1Normalize extends CompositeTensorFunction {
+public class L1Normalize<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final String dimension;
- public L1Normalize(TensorFunction argument, String dimension) {
+ public L1Normalize(TensorFunction<NAMETYPE> argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("L1Normalize must have 1 argument, got " + arguments.size());
- return new L1Normalize(arguments.get(0), dimension);
+ return new L1Normalize<>(arguments.get(0), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument = argument.toPrimitive();
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive();
// join(x, reduce(x, "avg", "dimension"), f(x,y) (x / y))
- return new Join(primitiveArgument,
- new Reduce(primitiveArgument, Reduce.Aggregator.sum, dimension),
- ScalarFunctions.divide());
+ return new Join<>(primitiveArgument,
+ new Reduce<>(primitiveArgument, Reduce.Aggregator.sum, dimension),
+ ScalarFunctions.divide());
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
index 40edb8ba23f..a6b30d0b292 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
@@ -1,40 +1,42 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.evaluation.TypeContext;
+
import java.util.Collections;
import java.util.List;
/**
* @author bratseth
*/
-public class L2Normalize extends CompositeTensorFunction {
+public class L2Normalize<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final String dimension;
- public L2Normalize(TensorFunction argument, String dimension) {
+ public L2Normalize(TensorFunction<NAMETYPE> argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("L2Normalize must have 1 argument, got " + arguments.size());
- return new L2Normalize(arguments.get(0), dimension);
+ return new L2Normalize<>(arguments.get(0), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument = argument.toPrimitive();
- return new Join(primitiveArgument,
- new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.square()),
- Reduce.Aggregator.sum,
- dimension),
- ScalarFunctions.sqrt()),
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive();
+ return new Join<>(primitiveArgument,
+ new Map<>(new Reduce<>(new Map<>(primitiveArgument, ScalarFunctions.square()),
+ Reduce.Aggregator.sum,
+ dimension),
+ ScalarFunctions.sqrt()),
ScalarFunctions.divide());
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index 016c60c6897..d482c70680a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -18,12 +18,12 @@ import java.util.function.DoubleUnaryOperator;
*
* @author bratseth
*/
-public class Map extends PrimitiveTensorFunction {
+public class Map<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final DoubleUnaryOperator mapper;
- public Map(TensorFunction argument, DoubleUnaryOperator mapper) {
+ public Map(TensorFunction<NAMETYPE> argument, DoubleUnaryOperator mapper) {
Objects.requireNonNull(argument, "The argument tensor cannot be null");
Objects.requireNonNull(mapper, "The argument function cannot be null");
this.argument = argument;
@@ -32,31 +32,31 @@ public class Map extends PrimitiveTensorFunction {
public static TensorType outputType(TensorType inputType) { return inputType; }
- public TensorFunction argument() { return argument; }
+ public TensorFunction<NAMETYPE> argument() { return argument; }
public DoubleUnaryOperator mapper() { return mapper; }
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Map must have 1 argument, got " + arguments.size());
- return new Map(arguments.get(0), mapper);
+ return new Map<>(arguments.get(0), mapper);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Map(argument.toPrimitive(), mapper);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Map<>(argument.toPrimitive(), mapper);
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext<NAMETYPE> context) {
return argument.type(context);
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor argument = argument().evaluate(context);
Tensor.Builder builder = Tensor.Builder.of(argument.type());
for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index 7c65afc98f9..d32e84f1ca0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -3,18 +3,19 @@ package com.yahoo.tensor.functions;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.List;
/**
* @author bratseth
*/
-public class Matmul extends CompositeTensorFunction {
+public class Matmul<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction argument1, argument2;
+ private final TensorFunction<NAMETYPE> argument1, argument2;
private final String dimension;
- public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) {
+ public Matmul(TensorFunction<NAMETYPE> argument1, TensorFunction<NAMETYPE> argument2, String dimension) {
this.argument1 = argument1;
this.argument2 = argument2;
this.dimension = dimension;
@@ -25,22 +26,22 @@ public class Matmul extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return ImmutableList.of(argument1, argument2); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argument1, argument2); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 2)
throw new IllegalArgumentException("Matmul must have 2 arguments, got " + arguments.size());
- return new Matmul(arguments.get(0), arguments.get(1), dimension);
+ return new Matmul<>(arguments.get(0), arguments.get(1), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument1 = argument1.toPrimitive();
- TensorFunction primitiveArgument2 = argument2.toPrimitive();
- return new Reduce(new Join(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()),
- Reduce.Aggregator.sum,
- dimension);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> primitiveArgument1 = argument1.toPrimitive();
+ TensorFunction<NAMETYPE> primitiveArgument2 = argument2.toPrimitive();
+ return new Reduce<>(new Join<>(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ dimension);
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
index e2aae39f11f..1113da39508 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.evaluation.TypeContext;
+
/**
* A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions.
* All tensor implementations must implement all primitive tensor functions.
@@ -8,6 +10,6 @@ package com.yahoo.tensor.functions;
*
* @author bratseth
*/
-public abstract class PrimitiveTensorFunction extends TensorFunction {
+public abstract class PrimitiveTensorFunction<NAMETYPE extends TypeContext.Name> extends TensorFunction<NAMETYPE> {
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
index 7175c91ed33..4ccf41de0fe 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor.functions;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.List;
@@ -13,7 +14,7 @@ import java.util.stream.Stream;
*
* @author bratseth
*/
-public class Random extends CompositeTensorFunction {
+public class Random<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> {
private final TensorType type;
@@ -22,18 +23,18 @@ public class Random extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.emptyList(); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Random must have 0 arguments, got " + arguments.size());
return this;
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Generate(type, ScalarFunctions.random());
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Generate<>(type, ScalarFunctions.random());
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
index d951ec9ccbd..f75d0f2cbef 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor.functions;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.List;
@@ -15,7 +16,7 @@ import java.util.stream.Stream;
*
* @author bratseth
*/
-public class Range extends CompositeTensorFunction {
+public class Range<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> {
private final TensorType type;
private final Function<List<Long>, Double> rangeFunction;
@@ -26,18 +27,18 @@ public class Range extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.emptyList(); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Range must have 0 arguments, got " + arguments.size());
return this;
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Generate(type, rangeFunction);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Generate<>(type, rangeFunction);
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 017dc3920e6..1d24333623b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -24,21 +24,21 @@ import java.util.Set;
*
* @author bratseth
*/
-public class Reduce extends PrimitiveTensorFunction {
+public class Reduce<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> {
public enum Aggregator { avg, count, prod, sum, max, min; }
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final List<String> dimensions;
private final Aggregator aggregator;
/** Creates a reduce function reducing all dimensions */
- public Reduce(TensorFunction argument, Aggregator aggregator) {
+ public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator) {
this(argument, aggregator, Collections.emptyList());
}
/** Creates a reduce function reducing a single dimension */
- public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) {
+ public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator, String dimension) {
this(argument, aggregator, Collections.singletonList(dimension));
}
@@ -51,7 +51,7 @@ public class Reduce extends PrimitiveTensorFunction {
* producing a dimensionless tensor (a scalar).
* @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor
*/
- public Reduce(TensorFunction argument, Aggregator aggregator, List<String> dimensions) {
+ public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator, List<String> dimensions) {
Objects.requireNonNull(argument, "The argument tensor cannot be null");
Objects.requireNonNull(aggregator, "The aggregator cannot be null");
Objects.requireNonNull(dimensions, "The dimensions cannot be null");
@@ -70,25 +70,25 @@ public class Reduce extends PrimitiveTensorFunction {
return b.build();
}
- public TensorFunction argument() { return argument; }
+ public TensorFunction<NAMETYPE> argument() { return argument; }
Aggregator aggregator() { return aggregator; }
List<String> dimensions() { return dimensions; }
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size());
- return new Reduce(arguments.get(0), aggregator, dimensions);
+ return new Reduce<>(arguments.get(0), aggregator, dimensions);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Reduce(argument.toPrimitive(), aggregator, dimensions);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Reduce<>(argument.toPrimitive(), aggregator, dimensions);
}
@Override
@@ -104,7 +104,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext<NAMETYPE> context) {
return type(argument.type(context), dimensions);
}
@@ -118,7 +118,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
return evaluate(this.argument.evaluate(context), dimensions, aggregator);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index 1134e8177ad..36c20b9e044 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -26,19 +26,19 @@ import java.util.stream.Collectors;
*
* @author lesters
*/
-public class ReduceJoin extends CompositeTensorFunction {
+public class ReduceJoin<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction argumentA, argumentB;
+ private final TensorFunction<NAMETYPE> argumentA, argumentB;
private final DoubleBinaryOperator combinator;
private final Reduce.Aggregator aggregator;
private final List<String> dimensions;
- public ReduceJoin(Reduce reduce, Join join) {
+ public ReduceJoin(Reduce<NAMETYPE> reduce, Join<NAMETYPE> join) {
this(join.arguments().get(0), join.arguments().get(1), join.combinator(), reduce.aggregator(), reduce.dimensions());
}
- public ReduceJoin(TensorFunction argumentA,
- TensorFunction argumentB,
+ public ReduceJoin(TensorFunction<NAMETYPE> argumentA,
+ TensorFunction<NAMETYPE> argumentB,
DoubleBinaryOperator combinator,
Reduce.Aggregator aggregator,
List<String> dimensions) {
@@ -50,25 +50,25 @@ public class ReduceJoin extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> arguments() {
+ public List<TensorFunction<NAMETYPE>> arguments() {
return ImmutableList.of(argumentA, argumentB);
}
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 2)
throw new IllegalArgumentException("ReduceJoin must have 2 arguments, got " + arguments.size());
- return new ReduceJoin(arguments.get(0), arguments.get(1), combinator, aggregator, dimensions);
+ return new ReduceJoin<>(arguments.get(0), arguments.get(1), combinator, aggregator, dimensions);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- Join join = new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
- return new Reduce(join, aggregator, dimensions);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ Join<NAMETYPE> join = new Join<>(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
+ return new Reduce<>(join, aggregator, dimensions);
}
@Override
- public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public final Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index 5694684956e..6731f80cbce 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -20,18 +20,18 @@ import java.util.Objects;
*
* @author bratseth
*/
-public class Rename extends PrimitiveTensorFunction {
+public class Rename<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final List<String> fromDimensions;
private final List<String> toDimensions;
private final Map<String, String> fromToMap;
- public Rename(TensorFunction argument, String fromDimension, String toDimension) {
+ public Rename(TensorFunction<NAMETYPE> argument, String fromDimension, String toDimension) {
this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension));
}
- public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) {
+ public Rename(TensorFunction<NAMETYPE> argument, List<String> fromDimensions, List<String> toDimensions) {
Objects.requireNonNull(argument, "The argument tensor cannot be null");
Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null");
Objects.requireNonNull(toDimensions, "The 'to' dimensions cannot be null");
@@ -57,20 +57,20 @@ public class Rename extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Rename must have 1 argument, got " + arguments.size());
- return new Rename(arguments.get(0), fromDimensions, toDimensions);
+ return new Rename<>(arguments.get(0), fromDimensions, toDimensions);
}
@Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext<NAMETYPE> context) {
return type(argument.type(context));
}
@@ -82,7 +82,7 @@ public class Rename extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor tensor = argument.evaluate(context);
TensorType renamedType = type(tensor.type());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index bd732cdc11e..4636871e19c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.functions;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.List;
@@ -10,12 +11,12 @@ import java.util.List;
/**
* @author bratseth
*/
-public class Softmax extends CompositeTensorFunction {
+public class Softmax<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final String dimension;
- public Softmax(TensorFunction argument, String dimension) {
+ public Softmax(TensorFunction<NAMETYPE> argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@@ -25,23 +26,23 @@ public class Softmax extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Softmax must have 1 argument, got " + arguments.size());
- return new Softmax(arguments.get(0), dimension);
+ return new Softmax<>(arguments.get(0), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument = argument.toPrimitive();
- return new Join(new Map(primitiveArgument, ScalarFunctions.exp()),
- new Reduce(new Map(primitiveArgument, ScalarFunctions.exp()),
- Reduce.Aggregator.sum,
- dimension),
- ScalarFunctions.divide());
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive();
+ return new Join<>(new Map<>(primitiveArgument, ScalarFunctions.exp()),
+ new Reduce<>(new Map<>(primitiveArgument, ScalarFunctions.exp()),
+ Reduce.Aggregator.sum,
+ dimension),
+ ScalarFunctions.divide());
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 810651bbcfb..1c52046a9be 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -16,17 +16,17 @@ import java.util.List;
*
* @author bratseth
*/
-public abstract class TensorFunction {
+public abstract class TensorFunction<NAMETYPE extends TypeContext.Name> {
/** Returns the function arguments of this node in the order they are applied */
- public abstract List<TensorFunction> arguments();
+ public abstract List<TensorFunction<NAMETYPE>> arguments();
/**
* Returns a copy of this tensor function with the arguments replaced by the given list of arguments.
*
* @throws IllegalArgumentException if the argument list has the wrong size for this function
*/
- public abstract TensorFunction withArguments(List<TensorFunction> arguments);
+ public abstract TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments);
/**
* Translate this function - and all of its arguments recursively -
@@ -34,24 +34,24 @@ public abstract class TensorFunction {
*
* @return a tree of primitive functions implementing this
*/
- public abstract PrimitiveTensorFunction toPrimitive();
+ public abstract PrimitiveTensorFunction<NAMETYPE> toPrimitive();
/**
* Evaluates this tensor.
*
* @param context a context which must be passed to all nested functions when evaluating
*/
- public abstract <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context);
+ public abstract Tensor evaluate(EvaluationContext<NAMETYPE> context);
/**
* Returns the type of the tensor this produces given the input types in the context
*
* @param context a context which must be passed to all nexted functions when evaluating
*/
- public abstract <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context);
+ public abstract TensorType type(TypeContext<NAMETYPE> context);
/** Evaluate with no context */
- public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); }
+ public final Tensor evaluate() { return evaluate(new MapEvaluationContext<>()); }
/**
* Return a string representation of this context.
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java
index 0a881c0a290..0325753d2e0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java
@@ -19,10 +19,10 @@ import java.util.stream.Collectors;
* @author bratseth
*/
@Beta
-public class Value extends PrimitiveTensorFunction {
+public class Value<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
- private final List<DimensionValue> cellAddress;
+ private final TensorFunction<NAMETYPE> argument;
+ private final List<DimensionValue<NAMETYPE>> cellAddress;
/**
* Creates a value function
@@ -31,7 +31,7 @@ public class Value extends PrimitiveTensorFunction {
* @param cellAddress a description of the address of the cell to return the value of. This is not a TensorAddress
* because those require a type, but a type is not resolved until this is evaluated
*/
- public Value(TensorFunction argument, List<DimensionValue> cellAddress) {
+ public Value(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> cellAddress) {
this.argument = Objects.requireNonNull(argument, "Argument cannot be null");
if (cellAddress.size() > 1 && cellAddress.stream().anyMatch(c -> c.dimension().isEmpty()))
throw new IllegalArgumentException("Short form of cell addresses is only supported with a single dimension: " +
@@ -40,34 +40,38 @@ public class Value extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return List.of(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); }
@Override
- public Value withArguments(List<TensorFunction> arguments) {
+ public Value<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if (arguments.size() != 1)
throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size());
- return new Value(arguments.get(0), cellAddress);
+ return new Value<NAMETYPE>(arguments.get(0), cellAddress);
}
@Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor tensor = argument.evaluate(context);
if (tensor.type().rank() != cellAddress.size())
throw new IllegalArgumentException("Type/address size mismatch: Cannot address a value with " + toString() +
" to a tensor of type " + tensor.type());
TensorAddress.Builder b = new TensorAddress.Builder(tensor.type());
for (int i = 0; i < cellAddress.size(); i++) {
- b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()),
- cellAddress.get(i).label());
+ if (cellAddress.get(i).label().isPresent())
+ b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()),
+ cellAddress.get(i).label().get());
+ else
+ b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()),
+ String.valueOf(cellAddress.get(i).index().get().apply(context).intValue()));
}
return Tensor.from(tensor.get(b.build()));
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext<NAMETYPE> context) {
return new TensorType.Builder(argument.type(context).valueType()).build();
}
@@ -87,69 +91,94 @@ public class Value extends PrimitiveTensorFunction {
else {
return "{" + cellAddress.stream().map(i -> i.toString()).collect(Collectors.joining(", ")) + "}";
}
- }
+ }
+
+ public static class DimensionValue<NAMETYPE extends TypeContext.Name> {
+
+ private final Optional<String> dimension;
- public static class DimensionValue {
+ /** The label of this, or null if index is set */
+ private final String label;
- private final Optional<String> dimension;
+ /** The function returning the index of this, or null if label is set */
+ private final ScalarFunction<NAMETYPE> index;
+
+ public DimensionValue(String dimension, String label) {
+ this(Optional.of(dimension), label, null);
+ }
+
+ public DimensionValue(String dimension, int index) {
+ this(Optional.of(dimension), null, new ConstantScalarFunction<>(index));
+ }
- /** The label of this. Always available, whether or not index is */
- private final String label;
+ public DimensionValue(int index) {
+ this(Optional.empty(), null, new ConstantScalarFunction<>(index));
+ }
- /** The index of this, or empty if this is a non-integer label */
- private final Optional<Integer> index;
+ public DimensionValue(String label) {
+ this(Optional.empty(), label, null);
+ }
- public DimensionValue(String dimension, String label) {
- this(Optional.of(dimension), label, indexOrEmpty(label));
- }
+ public DimensionValue(ScalarFunction<NAMETYPE> index) {
+ this(Optional.empty(), null, index);
+ }
- public DimensionValue(String dimension, int index) {
- this(Optional.of(dimension), String.valueOf(index), Optional.of(index));
- }
+ public DimensionValue(Optional<String> dimension, String label) {
+ this(dimension, label, null);
+ }
- public DimensionValue(int index) {
- this(Optional.empty(), String.valueOf(index), Optional.of(index));
- }
+ public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) {
+ this(dimension, null, index);
+ }
- public DimensionValue(String label) {
- this(Optional.empty(), label, indexOrEmpty(label));
- }
+ public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) {
+ this(Optional.of(dimension), null, index);
+ }
- private DimensionValue(Optional<String> dimension, String label, Optional<Integer> index) {
+ private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) {
this.dimension = dimension;
this.label = label;
this.index = index;
- }
-
- /**
- * Returns the given name of the dimension, or null if dense form is used, such that name
- * must be inferred from order
- */
- public Optional<String> dimension() { return dimension; }
-
- /** Returns the label or index for this dimension as a string */
- public String label() { return label; }
-
- /** Returns the index for this dimension, or empty if it is not a number */
- Optional<Integer> index() { return index; }
-
- @Override
- public String toString() {
- if (dimension.isPresent())
- return dimension.get() + ":" + label;
- else
- return label;
- }
-
- private static Optional<Integer> indexOrEmpty(String label) {
- try {
- return Optional.of(Integer.parseInt(label));
- }
- catch (IllegalArgumentException e) {
- return Optional.empty();
- }
- }
-
- }
+ }
+
+ /**
+ * Returns the given name of the dimension, or null if dense form is used, such that name
+ * must be inferred from order
+ */
+ public Optional<String> dimension() { return dimension; }
+
+ /** Returns the label for this dimension or empty if it is provided by an index function */
+ public Optional<String> label() { return Optional.ofNullable(label); }
+
+ /** Returns the index expression for this dimension, or empty if it is not a number */
+ public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); }
+
+ @Override
+ public String toString() {
+ StringBuilder b = new StringBuilder();
+ dimension.ifPresent(d -> b.append(d).append(":"));
+ if (label != null)
+ b.append(label);
+ else
+ b.append(index);
+ return b.toString();
+ }
+
+ }
+
+ private static class ConstantScalarFunction<NAMETYPE extends TypeContext.Name> implements ScalarFunction<NAMETYPE> {
+
+ private final Double value;
+
+ public ConstantScalarFunction(int value) {
+ this.value = (double)value;
+ }
+
+ @Override
+ public Double apply(EvaluationContext<NAMETYPE> context) {
+ return value;
+ }
+
+ }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
index 4c0748ee39a..60b4438e909 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
@@ -2,18 +2,19 @@
package com.yahoo.tensor.functions;
import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.List;
/**
* @author bratseth
*/
-public class XwPlusB extends CompositeTensorFunction {
+public class XwPlusB<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction x, w, b;
+ private final TensorFunction<NAMETYPE> x, w, b;
private final String dimension;
- public XwPlusB(TensorFunction x, TensorFunction w, TensorFunction b, String dimension) {
+ public XwPlusB(TensorFunction<NAMETYPE> x, TensorFunction<NAMETYPE> w, TensorFunction<NAMETYPE> b, String dimension) {
this.x = x;
this.w = w;
this.b = b;
@@ -21,25 +22,25 @@ public class XwPlusB extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return ImmutableList.of(x, w, b); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(x, w, b); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 3)
throw new IllegalArgumentException("XwPlusB must have 3 arguments, got " + arguments.size());
- return new XwPlusB(arguments.get(0), arguments.get(1), arguments.get(2), dimension);
+ return new XwPlusB<>(arguments.get(0), arguments.get(1), arguments.get(2), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveX = x.toPrimitive();
- TensorFunction primitiveW = w.toPrimitive();
- TensorFunction primitiveB = b.toPrimitive();
- return new Join(new Reduce(new Join(primitiveX, primitiveW, ScalarFunctions.multiply()),
- Reduce.Aggregator.sum,
- dimension),
- primitiveB,
- ScalarFunctions.add());
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> primitiveX = x.toPrimitive();
+ TensorFunction<NAMETYPE> primitiveW = w.toPrimitive();
+ TensorFunction<NAMETYPE> primitiveB = b.toPrimitive();
+ return new Join<>(new Reduce<>(new Join<>(primitiveX, primitiveW, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ dimension),
+ primitiveB,
+ ScalarFunctions.add());
}
@Override