diff options
Diffstat (limited to 'vespajlib/src/main/java')
26 files changed, 357 insertions, 312 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 0523624ea9f..b8ef84cabb7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.functions.Argmax; import com.yahoo.tensor.functions.Argmin; import com.yahoo.tensor.functions.Concat; @@ -144,25 +145,25 @@ public interface Tensor { // ----------------- Primitive tensor functions default Tensor map(DoubleUnaryOperator mapper) { - return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate(); + return new com.yahoo.tensor.functions.Map<>(new ConstantTensor<>(this), mapper).evaluate(); } /** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */ default Tensor reduce(Reduce.Aggregator aggregator, String ... dimensions) { - return new Reduce(new ConstantTensor(this), aggregator, Arrays.asList(dimensions)).evaluate(); + return new Reduce<>(new ConstantTensor<>(this), aggregator, Arrays.asList(dimensions)).evaluate(); } /** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */ default Tensor reduce(Reduce.Aggregator aggregator, List<String> dimensions) { - return new Reduce(new ConstantTensor(this), aggregator, dimensions).evaluate(); + return new Reduce<>(new ConstantTensor<>(this), aggregator, dimensions).evaluate(); } default Tensor join(Tensor argument, DoubleBinaryOperator combinator) { - return new Join(new ConstantTensor(this), new ConstantTensor(argument), combinator).evaluate(); + return new Join<>(new ConstantTensor<>(this), new ConstantTensor<>(argument), combinator).evaluate(); } default Tensor rename(String fromDimension, String toDimension) { - return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension), - Collections.singletonList(toDimension)).evaluate(); + return new Rename<>(new ConstantTensor<>(this), Collections.singletonList(fromDimension), + Collections.singletonList(toDimension)).evaluate(); } default Tensor concat(double argument, String dimension) { @@ -170,50 +171,50 @@ public interface Tensor { } default Tensor concat(Tensor argument, String dimension) { - return new Concat(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate(); + return new Concat<>(new ConstantTensor<>(this), new ConstantTensor<>(argument), dimension).evaluate(); } default Tensor rename(List<String> fromDimensions, List<String> toDimensions) { - return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate(); + return new Rename<>(new ConstantTensor<>(this), fromDimensions, toDimensions).evaluate(); } static Tensor generate(TensorType type, Function<List<Long>, Double> valueSupplier) { - return new Generate(type, valueSupplier).evaluate(); + return new Generate<>(type, valueSupplier).evaluate(); } // ----------------- Composite tensor functions which have a defined primitive mapping default Tensor l1Normalize(String dimension) { - return new L1Normalize(new ConstantTensor(this), dimension).evaluate(); + return new L1Normalize<>(new ConstantTensor<>(this), dimension).evaluate(); } default Tensor l2Normalize(String dimension) { - return new L2Normalize(new ConstantTensor(this), dimension).evaluate(); + return new L2Normalize<>(new ConstantTensor<>(this), dimension).evaluate(); } default Tensor matmul(Tensor argument, String dimension) { - return new Matmul(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate(); + return new Matmul<>(new ConstantTensor<>(this), new ConstantTensor<>(argument), dimension).evaluate(); } default Tensor softmax(String dimension) { - return new Softmax(new ConstantTensor(this), dimension).evaluate(); + return new Softmax<>(new ConstantTensor<>(this), dimension).evaluate(); } default Tensor xwPlusB(Tensor w, Tensor b, String dimension) { - return new XwPlusB(new ConstantTensor(this), new ConstantTensor(w), new ConstantTensor(b), dimension).evaluate(); + return new XwPlusB<>(new ConstantTensor<>(this), new ConstantTensor<>(w), new ConstantTensor<>(b), dimension).evaluate(); } default Tensor argmax(String dimension) { - return new Argmax(new ConstantTensor(this), dimension).evaluate(); + return new Argmax<>(new ConstantTensor<>(this), dimension).evaluate(); } - default Tensor argmin(String dimension) { return new Argmin(new ConstantTensor(this), dimension).evaluate(); } + default Tensor argmin(String dimension) { return new Argmin<>(new ConstantTensor<>(this), dimension).evaluate(); } - static Tensor diag(TensorType type) { return new Diag(type).evaluate(); } + static Tensor diag(TensorType type) { return new Diag<>(type).evaluate(); } - static Tensor random(TensorType type) { return new Random(type).evaluate(); } + static Tensor random(TensorType type) { return new Random<>(type).evaluate(); } - static Tensor range(TensorType type) { return new Range(type).evaluate(); } + static Tensor range(TensorType type) { return new Range<>(type).evaluate(); } // ----------------- Composite tensor functions mapped to primitives here on the fly diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index e302e317418..076c73212d1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -9,7 +9,7 @@ import java.util.HashMap; /** * @author bratseth */ -public class MapEvaluationContext implements EvaluationContext<TypeContext.Name> { +public class MapEvaluationContext<NAMETYPE extends TypeContext.Name> implements EvaluationContext<NAMETYPE> { private final java.util.Map<String, Tensor> bindings = new HashMap<>(); @@ -17,14 +17,14 @@ public class MapEvaluationContext implements EvaluationContext<TypeContext.Name> @Override public TensorType getType(String name) { - return getType(new Name(name)); + Tensor tensor = bindings.get(name); + if (tensor == null) return null; + return tensor.type(); } @Override - public TensorType getType(Name name) { - Tensor tensor = bindings.get(name.toString()); - if (tensor == null) return null; - return tensor.type(); + public TensorType getType(NAMETYPE name) { + return getType(name.name()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index c1cfa319664..f4e025b3843 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -16,7 +16,7 @@ import java.util.Optional; * * @author bratseth */ -public class VariableTensor extends PrimitiveTensorFunction { +public class VariableTensor<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> { private final String name; private final Optional<TensorType> requiredType; @@ -33,16 +33,16 @@ public class VariableTensor extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { return this; } + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { return this; } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { TensorType givenType = context.getType(name); if (givenType == null) return null; verifyType(givenType); @@ -50,7 +50,7 @@ public class VariableTensor extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor tensor = context.getTensor(name); if (tensor == null) return null; verifyType(tensor.type()); @@ -67,4 +67,5 @@ public class VariableTensor extends PrimitiveTensorFunction { throw new IllegalArgumentException("Variable '" + name + "' must be compatible with " + requiredType.get() + " but was " + givenType); } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java index 3478061b32c..c52e566ed1e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java @@ -1,38 +1,40 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.evaluation.TypeContext; + import java.util.Collections; import java.util.List; /** * @author bratseth */ -public class Argmax extends CompositeTensorFunction { +public class Argmax<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final String dimension; - public Argmax(TensorFunction argument, String dimension) { + public Argmax(TensorFunction<NAMETYPE> argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Argmax must have 1 argument, got " + arguments.size()); - return new Argmax(arguments.get(0), dimension); + return new Argmax<>(arguments.get(0), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument = argument.toPrimitive(); - return new Join(primitiveArgument, - new Reduce(primitiveArgument, Reduce.Aggregator.max, dimension), - ScalarFunctions.equal()); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive(); + return new Join<>(primitiveArgument, + new Reduce<>(primitiveArgument, Reduce.Aggregator.max, dimension), + ScalarFunctions.equal()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java index ba5b3c3e4b2..aa0333aa421 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java @@ -1,38 +1,40 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.evaluation.TypeContext; + import java.util.Collections; import java.util.List; /** * @author bratseth */ -public class Argmin extends CompositeTensorFunction { +public class Argmin<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final String dimension; - public Argmin(TensorFunction argument, String dimension) { + public Argmin(TensorFunction<NAMETYPE> argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Argmin must have 1 argument, got " + arguments.size()); - return new Argmin(arguments.get(0), dimension); + return new Argmin<>(arguments.get(0), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument = argument.toPrimitive(); - return new Join(primitiveArgument, - new Reduce(primitiveArgument, Reduce.Aggregator.min, dimension), - ScalarFunctions.equal()); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive(); + return new Join<>(primitiveArgument, + new Reduce<>(primitiveArgument, Reduce.Aggregator.min, dimension), + ScalarFunctions.equal()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 5dd2cc442aa..59c0fae39b5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -12,17 +12,17 @@ import com.yahoo.tensor.evaluation.TypeContext; * * @author bratseth */ -public abstract class CompositeTensorFunction extends TensorFunction { +public abstract class CompositeTensorFunction<NAMETYPE extends TypeContext.Name> extends TensorFunction<NAMETYPE> { /** Finds the type this produces by first converting it to a primitive function */ @Override - public final <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public final TensorType type(TypeContext<NAMETYPE> context) { return toPrimitive().type(context); } /** Evaluates this by first converting it to a primitive function */ @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { return toPrimitive().evaluate(context); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 42c6fe2f4aa..a31a7da67e5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -23,12 +23,12 @@ import java.util.stream.Collectors; * * @author bratseth */ -public class Concat extends PrimitiveTensorFunction { +public class Concat<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> { - private final TensorFunction argumentA, argumentB; + private final TensorFunction<NAMETYPE> argumentA, argumentB; private final String dimension; - public Concat(TensorFunction argumentA, TensorFunction argumentB, String dimension) { + public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) { Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); Objects.requireNonNull(dimension, "The dimension cannot be null"); @@ -38,18 +38,18 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); } + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if (arguments.size() != 2) throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size()); - return new Concat(arguments.get(0), arguments.get(1), dimension); + return new Concat<>(arguments.get(0), arguments.get(1), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Concat(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension); } @Override @@ -58,7 +58,7 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { return type(argumentA.type(context), argumentB.type(context)); } @@ -86,7 +86,7 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 7c1ce068c90..9d6d488eb60 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -14,7 +14,7 @@ import java.util.List; * * @author bratseth */ -public class ConstantTensor extends PrimitiveTensorFunction { +public class ConstantTensor<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> { private final Tensor constant; @@ -27,23 +27,23 @@ public class ConstantTensor extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("ConstantTensor must have 0 arguments, got " + arguments.size()); return this; } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return constant.type(); } + public TensorType type(TypeContext<NAMETYPE> context) { return constant.type(); } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; } + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; } @Override public String toString(ToStringContext context) { return constant.toString(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java index e302f6606e7..638a5246378 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.List; @@ -14,7 +15,7 @@ import java.util.stream.Stream; * * @author bratseth */ -public class Diag extends CompositeTensorFunction { +public class Diag<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> { private final TensorType type; private final Function<List<Long>, Double> diagFunction; @@ -25,18 +26,18 @@ public class Diag extends CompositeTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Diag must have 0 arguments, got " + arguments.size()); return this; } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Generate(type, diagFunction); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Generate<>(type, diagFunction); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index a75e49c6402..6830ec50c5f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -20,7 +20,7 @@ import java.util.function.Function; * * @author bratseth */ -public abstract class DynamicTensor extends PrimitiveTensorFunction { +public abstract class DynamicTensor<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> { private final TensorType type; @@ -29,20 +29,20 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; } + public TensorType type(TypeContext<NAMETYPE> context) { return type; } @Override - public List<TensorFunction> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if (arguments.size() != 0) throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size()); return this; } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } TensorType type() { return type; } @@ -54,27 +54,26 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { abstract String contentToString(ToStringContext context); /** Creates a dynamic tensor function. The cell addresses must match the type. */ - public static DynamicTensor from(TensorType type, Map<TensorAddress, ScalarFunction> cells) { - return new MappedDynamicTensor(type, cells); + public static <NAMETYPE extends TypeContext.Name> DynamicTensor<NAMETYPE> from(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) { + return new MappedDynamicTensor<>(type, cells); } /** Creates a dynamic tensor function for a bound, indexed tensor */ - public static DynamicTensor from(TensorType type, List<ScalarFunction> cells) { - return new IndexedDynamicTensor(type, cells); + public static <NAMETYPE extends TypeContext.Name> DynamicTensor<NAMETYPE> from(TensorType type, List<ScalarFunction<NAMETYPE>> cells) { + return new IndexedDynamicTensor<>(type, cells); } - private static class MappedDynamicTensor extends DynamicTensor { + private static class MappedDynamicTensor<NAMETYPE extends TypeContext.Name> extends DynamicTensor<NAMETYPE> { - private final ImmutableMap<TensorAddress, ScalarFunction> cells; + private final ImmutableMap<TensorAddress, ScalarFunction<NAMETYPE>> cells; - MappedDynamicTensor(TensorType type, Map<TensorAddress, ScalarFunction> cells) { + MappedDynamicTensor(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) { super(type); this.cells = ImmutableMap.copyOf(cells); } @Override - @SuppressWarnings("unchecked") - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type()); for (var cell : cells.entrySet()) builder.cell(cell.getKey(), cell.getValue().apply(context)); @@ -102,11 +101,11 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } - private static class IndexedDynamicTensor extends DynamicTensor { + private static class IndexedDynamicTensor<NAMETYPE extends TypeContext.Name> extends DynamicTensor<NAMETYPE> { - private final List<ScalarFunction> cells; + private final List<ScalarFunction<NAMETYPE>> cells; - IndexedDynamicTensor(TensorType type, List<ScalarFunction> cells) { + IndexedDynamicTensor(TensorType type, List<ScalarFunction<NAMETYPE>> cells) { super(type); if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)) throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " + @@ -115,8 +114,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override - @SuppressWarnings("unchecked") - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type()); for (int i = 0; i < cells.size(); i++) builder.cellByDirectIndex(i, cells.get(i).apply(context)); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 52620814ecd..aaed607aaa1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -19,13 +19,13 @@ import java.util.function.Function; * * @author bratseth */ -public class Generate extends PrimitiveTensorFunction { +public class Generate<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> { private final TensorType type; // One of these are null private final Function<List<Long>, Double> freeGenerator; - private final ScalarFunction boundGenerator; + private final ScalarFunction<NAMETYPE> boundGenerator; /** The same as Generate.free */ public Generate(TensorType type, Function<List<Long>, Double> generator) { @@ -40,8 +40,8 @@ public class Generate extends PrimitiveTensorFunction { * tensor cell which will receive the value * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ - public static Generate free(TensorType type, Function<List<Long>, Double> generator) { - return new Generate(type, Objects.requireNonNull(generator), null); + public static <NAMETYPE extends TypeContext.Name> Generate<NAMETYPE> free(TensorType type, Function<List<Long>, Double> generator) { + return new Generate<>(type, Objects.requireNonNull(generator), null); } /** @@ -52,11 +52,11 @@ public class Generate extends PrimitiveTensorFunction { * tensor cell which will receive the value * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ - public static Generate bound(TensorType type, ScalarFunction generator) { - return new Generate(type, null, Objects.requireNonNull(generator)); + public static <NAMETYPE extends TypeContext.Name> Generate<NAMETYPE> bound(TensorType type, ScalarFunction<NAMETYPE> generator) { + return new Generate<>(type, null, Objects.requireNonNull(generator)); } - private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction boundGenerator) { + private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction<NAMETYPE> boundGenerator) { Objects.requireNonNull(type, "The argument tensor type cannot be null"); validateType(type); this.type = type; @@ -71,26 +71,26 @@ public class Generate extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size()); return this; } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; } + public TensorType type(TypeContext<NAMETYPE> context) { return type; } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type); IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type)); - GenerateContext<NAMETYPE> generateContext = new GenerateContext<>(type, context); + GenerateContext generateContext = new GenerateContext(type, context); for (int i = 0; i < indexes.size(); i++) { indexes.next(); builder.cell(generateContext.apply(indexes), indexes.indexesForReading()); @@ -120,7 +120,7 @@ public class Generate extends PrimitiveTensorFunction { * This returns all the current index values as variables and falls back to delivering from the given * evaluation context. */ - private class GenerateContext<NAMETYPE extends TypeContext.Name> implements EvaluationContext<NAMETYPE> { + private class GenerateContext implements EvaluationContext<NAMETYPE> { private final TensorType type; private final EvaluationContext<NAMETYPE> context; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 2939b964f04..29239957260 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -31,12 +31,12 @@ import java.util.function.DoubleBinaryOperator; * * @author bratseth */ -public class Join extends PrimitiveTensorFunction { +public class Join<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> { - private final TensorFunction argumentA, argumentB; + private final TensorFunction<NAMETYPE> argumentA, argumentB; private final DoubleBinaryOperator combinator; - public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) { + public Join(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator combinator) { Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); Objects.requireNonNull(combinator, "The combinator function cannot be null"); @@ -53,18 +53,18 @@ public class Join extends PrimitiveTensorFunction { public DoubleBinaryOperator combinator() { return combinator; } @Override - public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); } + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 2) throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size()); - return new Join(arguments.get(0), arguments.get(1), combinator); + return new Join<>(arguments.get(0), arguments.get(1), combinator); } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Join<>(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); } @Override @@ -73,12 +73,12 @@ public class Join extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build(); } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java index 7939457a101..3ec7ed4ed07 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java @@ -1,39 +1,41 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.evaluation.TypeContext; + import java.util.Collections; import java.util.List; /** * @author bratseth */ -public class L1Normalize extends CompositeTensorFunction { +public class L1Normalize<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final String dimension; - public L1Normalize(TensorFunction argument, String dimension) { + public L1Normalize(TensorFunction<NAMETYPE> argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("L1Normalize must have 1 argument, got " + arguments.size()); - return new L1Normalize(arguments.get(0), dimension); + return new L1Normalize<>(arguments.get(0), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument = argument.toPrimitive(); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive(); // join(x, reduce(x, "avg", "dimension"), f(x,y) (x / y)) - return new Join(primitiveArgument, - new Reduce(primitiveArgument, Reduce.Aggregator.sum, dimension), - ScalarFunctions.divide()); + return new Join<>(primitiveArgument, + new Reduce<>(primitiveArgument, Reduce.Aggregator.sum, dimension), + ScalarFunctions.divide()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java index 40edb8ba23f..a6b30d0b292 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java @@ -1,40 +1,42 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.evaluation.TypeContext; + import java.util.Collections; import java.util.List; /** * @author bratseth */ -public class L2Normalize extends CompositeTensorFunction { +public class L2Normalize<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final String dimension; - public L2Normalize(TensorFunction argument, String dimension) { + public L2Normalize(TensorFunction<NAMETYPE> argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("L2Normalize must have 1 argument, got " + arguments.size()); - return new L2Normalize(arguments.get(0), dimension); + return new L2Normalize<>(arguments.get(0), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument = argument.toPrimitive(); - return new Join(primitiveArgument, - new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.square()), - Reduce.Aggregator.sum, - dimension), - ScalarFunctions.sqrt()), + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive(); + return new Join<>(primitiveArgument, + new Map<>(new Reduce<>(new Map<>(primitiveArgument, ScalarFunctions.square()), + Reduce.Aggregator.sum, + dimension), + ScalarFunctions.sqrt()), ScalarFunctions.divide()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index 016c60c6897..d482c70680a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -18,12 +18,12 @@ import java.util.function.DoubleUnaryOperator; * * @author bratseth */ -public class Map extends PrimitiveTensorFunction { +public class Map<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> { - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final DoubleUnaryOperator mapper; - public Map(TensorFunction argument, DoubleUnaryOperator mapper) { + public Map(TensorFunction<NAMETYPE> argument, DoubleUnaryOperator mapper) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); Objects.requireNonNull(mapper, "The argument function cannot be null"); this.argument = argument; @@ -32,31 +32,31 @@ public class Map extends PrimitiveTensorFunction { public static TensorType outputType(TensorType inputType) { return inputType; } - public TensorFunction argument() { return argument; } + public TensorFunction<NAMETYPE> argument() { return argument; } public DoubleUnaryOperator mapper() { return mapper; } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Map must have 1 argument, got " + arguments.size()); - return new Map(arguments.get(0), mapper); + return new Map<>(arguments.get(0), mapper); } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Map(argument.toPrimitive(), mapper); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Map<>(argument.toPrimitive(), mapper); } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { return argument.type(context); } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor argument = argument().evaluate(context); Tensor.Builder builder = Tensor.Builder.of(argument.type()); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index 7c65afc98f9..d32e84f1ca0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -3,18 +3,19 @@ package com.yahoo.tensor.functions; import com.google.common.collect.ImmutableList; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.List; /** * @author bratseth */ -public class Matmul extends CompositeTensorFunction { +public class Matmul<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction argument1, argument2; + private final TensorFunction<NAMETYPE> argument1, argument2; private final String dimension; - public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) { + public Matmul(TensorFunction<NAMETYPE> argument1, TensorFunction<NAMETYPE> argument2, String dimension) { this.argument1 = argument1; this.argument2 = argument2; this.dimension = dimension; @@ -25,22 +26,22 @@ public class Matmul extends CompositeTensorFunction { } @Override - public List<TensorFunction> arguments() { return ImmutableList.of(argument1, argument2); } + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argument1, argument2); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 2) throw new IllegalArgumentException("Matmul must have 2 arguments, got " + arguments.size()); - return new Matmul(arguments.get(0), arguments.get(1), dimension); + return new Matmul<>(arguments.get(0), arguments.get(1), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument1 = argument1.toPrimitive(); - TensorFunction primitiveArgument2 = argument2.toPrimitive(); - return new Reduce(new Join(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()), - Reduce.Aggregator.sum, - dimension); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitiveArgument1 = argument1.toPrimitive(); + TensorFunction<NAMETYPE> primitiveArgument2 = argument2.toPrimitive(); + return new Reduce<>(new Join<>(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + dimension); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java index e2aae39f11f..1113da39508 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.evaluation.TypeContext; + /** * A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions. * All tensor implementations must implement all primitive tensor functions. @@ -8,6 +10,6 @@ package com.yahoo.tensor.functions; * * @author bratseth */ -public abstract class PrimitiveTensorFunction extends TensorFunction { +public abstract class PrimitiveTensorFunction<NAMETYPE extends TypeContext.Name> extends TensorFunction<NAMETYPE> { } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java index 7175c91ed33..4ccf41de0fe 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.List; @@ -13,7 +14,7 @@ import java.util.stream.Stream; * * @author bratseth */ -public class Random extends CompositeTensorFunction { +public class Random<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> { private final TensorType type; @@ -22,18 +23,18 @@ public class Random extends CompositeTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Random must have 0 arguments, got " + arguments.size()); return this; } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Generate(type, ScalarFunctions.random()); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Generate<>(type, ScalarFunctions.random()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java index d951ec9ccbd..f75d0f2cbef 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.List; @@ -15,7 +16,7 @@ import java.util.stream.Stream; * * @author bratseth */ -public class Range extends CompositeTensorFunction { +public class Range<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> { private final TensorType type; private final Function<List<Long>, Double> rangeFunction; @@ -26,18 +27,18 @@ public class Range extends CompositeTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Range must have 0 arguments, got " + arguments.size()); return this; } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Generate(type, rangeFunction); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Generate<>(type, rangeFunction); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 017dc3920e6..1d24333623b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -24,21 +24,21 @@ import java.util.Set; * * @author bratseth */ -public class Reduce extends PrimitiveTensorFunction { +public class Reduce<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> { public enum Aggregator { avg, count, prod, sum, max, min; } - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final List<String> dimensions; private final Aggregator aggregator; /** Creates a reduce function reducing all dimensions */ - public Reduce(TensorFunction argument, Aggregator aggregator) { + public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator) { this(argument, aggregator, Collections.emptyList()); } /** Creates a reduce function reducing a single dimension */ - public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) { + public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator, String dimension) { this(argument, aggregator, Collections.singletonList(dimension)); } @@ -51,7 +51,7 @@ public class Reduce extends PrimitiveTensorFunction { * producing a dimensionless tensor (a scalar). * @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor */ - public Reduce(TensorFunction argument, Aggregator aggregator, List<String> dimensions) { + public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator, List<String> dimensions) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); Objects.requireNonNull(aggregator, "The aggregator cannot be null"); Objects.requireNonNull(dimensions, "The dimensions cannot be null"); @@ -70,25 +70,25 @@ public class Reduce extends PrimitiveTensorFunction { return b.build(); } - public TensorFunction argument() { return argument; } + public TensorFunction<NAMETYPE> argument() { return argument; } Aggregator aggregator() { return aggregator; } List<String> dimensions() { return dimensions; } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size()); - return new Reduce(arguments.get(0), aggregator, dimensions); + return new Reduce<>(arguments.get(0), aggregator, dimensions); } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Reduce(argument.toPrimitive(), aggregator, dimensions); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Reduce<>(argument.toPrimitive(), aggregator, dimensions); } @Override @@ -104,7 +104,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { return type(argument.type(context), dimensions); } @@ -118,7 +118,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { return evaluate(this.argument.evaluate(context), dimensions, aggregator); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index 1134e8177ad..36c20b9e044 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -26,19 +26,19 @@ import java.util.stream.Collectors; * * @author lesters */ -public class ReduceJoin extends CompositeTensorFunction { +public class ReduceJoin<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction argumentA, argumentB; + private final TensorFunction<NAMETYPE> argumentA, argumentB; private final DoubleBinaryOperator combinator; private final Reduce.Aggregator aggregator; private final List<String> dimensions; - public ReduceJoin(Reduce reduce, Join join) { + public ReduceJoin(Reduce<NAMETYPE> reduce, Join<NAMETYPE> join) { this(join.arguments().get(0), join.arguments().get(1), join.combinator(), reduce.aggregator(), reduce.dimensions()); } - public ReduceJoin(TensorFunction argumentA, - TensorFunction argumentB, + public ReduceJoin(TensorFunction<NAMETYPE> argumentA, + TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator combinator, Reduce.Aggregator aggregator, List<String> dimensions) { @@ -50,25 +50,25 @@ public class ReduceJoin extends CompositeTensorFunction { } @Override - public List<TensorFunction> arguments() { + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 2) throw new IllegalArgumentException("ReduceJoin must have 2 arguments, got " + arguments.size()); - return new ReduceJoin(arguments.get(0), arguments.get(1), combinator, aggregator, dimensions); + return new ReduceJoin<>(arguments.get(0), arguments.get(1), combinator, aggregator, dimensions); } @Override - public PrimitiveTensorFunction toPrimitive() { - Join join = new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); - return new Reduce(join, aggregator, dimensions); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + Join<NAMETYPE> join = new Join<>(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); + return new Reduce<>(join, aggregator, dimensions); } @Override - public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public final Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 5694684956e..6731f80cbce 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -20,18 +20,18 @@ import java.util.Objects; * * @author bratseth */ -public class Rename extends PrimitiveTensorFunction { +public class Rename<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> { - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final List<String> fromDimensions; private final List<String> toDimensions; private final Map<String, String> fromToMap; - public Rename(TensorFunction argument, String fromDimension, String toDimension) { + public Rename(TensorFunction<NAMETYPE> argument, String fromDimension, String toDimension) { this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension)); } - public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) { + public Rename(TensorFunction<NAMETYPE> argument, List<String> fromDimensions, List<String> toDimensions) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null"); Objects.requireNonNull(toDimensions, "The 'to' dimensions cannot be null"); @@ -57,20 +57,20 @@ public class Rename extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Rename must have 1 argument, got " + arguments.size()); - return new Rename(arguments.get(0), fromDimensions, toDimensions); + return new Rename<>(arguments.get(0), fromDimensions, toDimensions); } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { return type(argument.type(context)); } @@ -82,7 +82,7 @@ public class Rename extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor tensor = argument.evaluate(context); TensorType renamedType = type(tensor.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index bd732cdc11e..4636871e19c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.functions; import com.google.common.collect.ImmutableList; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.List; @@ -10,12 +11,12 @@ import java.util.List; /** * @author bratseth */ -public class Softmax extends CompositeTensorFunction { +public class Softmax<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final String dimension; - public Softmax(TensorFunction argument, String dimension) { + public Softmax(TensorFunction<NAMETYPE> argument, String dimension) { this.argument = argument; this.dimension = dimension; } @@ -25,23 +26,23 @@ public class Softmax extends CompositeTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Softmax must have 1 argument, got " + arguments.size()); - return new Softmax(arguments.get(0), dimension); + return new Softmax<>(arguments.get(0), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument = argument.toPrimitive(); - return new Join(new Map(primitiveArgument, ScalarFunctions.exp()), - new Reduce(new Map(primitiveArgument, ScalarFunctions.exp()), - Reduce.Aggregator.sum, - dimension), - ScalarFunctions.divide()); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive(); + return new Join<>(new Map<>(primitiveArgument, ScalarFunctions.exp()), + new Reduce<>(new Map<>(primitiveArgument, ScalarFunctions.exp()), + Reduce.Aggregator.sum, + dimension), + ScalarFunctions.divide()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 810651bbcfb..1c52046a9be 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -16,17 +16,17 @@ import java.util.List; * * @author bratseth */ -public abstract class TensorFunction { +public abstract class TensorFunction<NAMETYPE extends TypeContext.Name> { /** Returns the function arguments of this node in the order they are applied */ - public abstract List<TensorFunction> arguments(); + public abstract List<TensorFunction<NAMETYPE>> arguments(); /** * Returns a copy of this tensor function with the arguments replaced by the given list of arguments. * * @throws IllegalArgumentException if the argument list has the wrong size for this function */ - public abstract TensorFunction withArguments(List<TensorFunction> arguments); + public abstract TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments); /** * Translate this function - and all of its arguments recursively - @@ -34,24 +34,24 @@ public abstract class TensorFunction { * * @return a tree of primitive functions implementing this */ - public abstract PrimitiveTensorFunction toPrimitive(); + public abstract PrimitiveTensorFunction<NAMETYPE> toPrimitive(); /** * Evaluates this tensor. * * @param context a context which must be passed to all nested functions when evaluating */ - public abstract <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context); + public abstract Tensor evaluate(EvaluationContext<NAMETYPE> context); /** * Returns the type of the tensor this produces given the input types in the context * * @param context a context which must be passed to all nexted functions when evaluating */ - public abstract <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context); + public abstract TensorType type(TypeContext<NAMETYPE> context); /** Evaluate with no context */ - public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); } + public final Tensor evaluate() { return evaluate(new MapEvaluationContext<>()); } /** * Return a string representation of this context. diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java index 0a881c0a290..0325753d2e0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java @@ -19,10 +19,10 @@ import java.util.stream.Collectors; * @author bratseth */ @Beta -public class Value extends PrimitiveTensorFunction { +public class Value<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> { - private final TensorFunction argument; - private final List<DimensionValue> cellAddress; + private final TensorFunction<NAMETYPE> argument; + private final List<DimensionValue<NAMETYPE>> cellAddress; /** * Creates a value function @@ -31,7 +31,7 @@ public class Value extends PrimitiveTensorFunction { * @param cellAddress a description of the address of the cell to return the value of. This is not a TensorAddress * because those require a type, but a type is not resolved until this is evaluated */ - public Value(TensorFunction argument, List<DimensionValue> cellAddress) { + public Value(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> cellAddress) { this.argument = Objects.requireNonNull(argument, "Argument cannot be null"); if (cellAddress.size() > 1 && cellAddress.stream().anyMatch(c -> c.dimension().isEmpty())) throw new IllegalArgumentException("Short form of cell addresses is only supported with a single dimension: " + @@ -40,34 +40,38 @@ public class Value extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> arguments() { return List.of(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); } @Override - public Value withArguments(List<TensorFunction> arguments) { + public Value<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if (arguments.size() != 1) throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size()); - return new Value(arguments.get(0), cellAddress); + return new Value<NAMETYPE>(arguments.get(0), cellAddress); } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor tensor = argument.evaluate(context); if (tensor.type().rank() != cellAddress.size()) throw new IllegalArgumentException("Type/address size mismatch: Cannot address a value with " + toString() + " to a tensor of type " + tensor.type()); TensorAddress.Builder b = new TensorAddress.Builder(tensor.type()); for (int i = 0; i < cellAddress.size(); i++) { - b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()), - cellAddress.get(i).label()); + if (cellAddress.get(i).label().isPresent()) + b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()), + cellAddress.get(i).label().get()); + else + b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()), + String.valueOf(cellAddress.get(i).index().get().apply(context).intValue())); } return Tensor.from(tensor.get(b.build())); } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { return new TensorType.Builder(argument.type(context).valueType()).build(); } @@ -87,69 +91,94 @@ public class Value extends PrimitiveTensorFunction { else { return "{" + cellAddress.stream().map(i -> i.toString()).collect(Collectors.joining(", ")) + "}"; } - } + } + + public static class DimensionValue<NAMETYPE extends TypeContext.Name> { + + private final Optional<String> dimension; - public static class DimensionValue { + /** The label of this, or null if index is set */ + private final String label; - private final Optional<String> dimension; + /** The function returning the index of this, or null if label is set */ + private final ScalarFunction<NAMETYPE> index; + + public DimensionValue(String dimension, String label) { + this(Optional.of(dimension), label, null); + } + + public DimensionValue(String dimension, int index) { + this(Optional.of(dimension), null, new ConstantScalarFunction<>(index)); + } - /** The label of this. Always available, whether or not index is */ - private final String label; + public DimensionValue(int index) { + this(Optional.empty(), null, new ConstantScalarFunction<>(index)); + } - /** The index of this, or empty if this is a non-integer label */ - private final Optional<Integer> index; + public DimensionValue(String label) { + this(Optional.empty(), label, null); + } - public DimensionValue(String dimension, String label) { - this(Optional.of(dimension), label, indexOrEmpty(label)); - } + public DimensionValue(ScalarFunction<NAMETYPE> index) { + this(Optional.empty(), null, index); + } - public DimensionValue(String dimension, int index) { - this(Optional.of(dimension), String.valueOf(index), Optional.of(index)); - } + public DimensionValue(Optional<String> dimension, String label) { + this(dimension, label, null); + } - public DimensionValue(int index) { - this(Optional.empty(), String.valueOf(index), Optional.of(index)); - } + public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) { + this(dimension, null, index); + } - public DimensionValue(String label) { - this(Optional.empty(), label, indexOrEmpty(label)); - } + public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) { + this(Optional.of(dimension), null, index); + } - private DimensionValue(Optional<String> dimension, String label, Optional<Integer> index) { + private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) { this.dimension = dimension; this.label = label; this.index = index; - } - - /** - * Returns the given name of the dimension, or null if dense form is used, such that name - * must be inferred from order - */ - public Optional<String> dimension() { return dimension; } - - /** Returns the label or index for this dimension as a string */ - public String label() { return label; } - - /** Returns the index for this dimension, or empty if it is not a number */ - Optional<Integer> index() { return index; } - - @Override - public String toString() { - if (dimension.isPresent()) - return dimension.get() + ":" + label; - else - return label; - } - - private static Optional<Integer> indexOrEmpty(String label) { - try { - return Optional.of(Integer.parseInt(label)); - } - catch (IllegalArgumentException e) { - return Optional.empty(); - } - } - - } + } + + /** + * Returns the given name of the dimension, or null if dense form is used, such that name + * must be inferred from order + */ + public Optional<String> dimension() { return dimension; } + + /** Returns the label for this dimension or empty if it is provided by an index function */ + public Optional<String> label() { return Optional.ofNullable(label); } + + /** Returns the index expression for this dimension, or empty if it is not a number */ + public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); } + + @Override + public String toString() { + StringBuilder b = new StringBuilder(); + dimension.ifPresent(d -> b.append(d).append(":")); + if (label != null) + b.append(label); + else + b.append(index); + return b.toString(); + } + + } + + private static class ConstantScalarFunction<NAMETYPE extends TypeContext.Name> implements ScalarFunction<NAMETYPE> { + + private final Double value; + + public ConstantScalarFunction(int value) { + this.value = (double)value; + } + + @Override + public Double apply(EvaluationContext<NAMETYPE> context) { + return value; + } + + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java index 4c0748ee39a..60b4438e909 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java @@ -2,18 +2,19 @@ package com.yahoo.tensor.functions; import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.List; /** * @author bratseth */ -public class XwPlusB extends CompositeTensorFunction { +public class XwPlusB<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction x, w, b; + private final TensorFunction<NAMETYPE> x, w, b; private final String dimension; - public XwPlusB(TensorFunction x, TensorFunction w, TensorFunction b, String dimension) { + public XwPlusB(TensorFunction<NAMETYPE> x, TensorFunction<NAMETYPE> w, TensorFunction<NAMETYPE> b, String dimension) { this.x = x; this.w = w; this.b = b; @@ -21,25 +22,25 @@ public class XwPlusB extends CompositeTensorFunction { } @Override - public List<TensorFunction> arguments() { return ImmutableList.of(x, w, b); } + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(x, w, b); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 3) throw new IllegalArgumentException("XwPlusB must have 3 arguments, got " + arguments.size()); - return new XwPlusB(arguments.get(0), arguments.get(1), arguments.get(2), dimension); + return new XwPlusB<>(arguments.get(0), arguments.get(1), arguments.get(2), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveX = x.toPrimitive(); - TensorFunction primitiveW = w.toPrimitive(); - TensorFunction primitiveB = b.toPrimitive(); - return new Join(new Reduce(new Join(primitiveX, primitiveW, ScalarFunctions.multiply()), - Reduce.Aggregator.sum, - dimension), - primitiveB, - ScalarFunctions.add()); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitiveX = x.toPrimitive(); + TensorFunction<NAMETYPE> primitiveW = w.toPrimitive(); + TensorFunction<NAMETYPE> primitiveB = b.toPrimitive(); + return new Join<>(new Reduce<>(new Join<>(primitiveX, primitiveW, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + dimension), + primitiveB, + ScalarFunctions.add()); } @Override |