diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java | 27 |
1 files changed, 14 insertions, 13 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index bd732cdc11e..4636871e19c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.functions; import com.google.common.collect.ImmutableList; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.List; @@ -10,12 +11,12 @@ import java.util.List; /** * @author bratseth */ -public class Softmax extends CompositeTensorFunction { +public class Softmax<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final String dimension; - public Softmax(TensorFunction argument, String dimension) { + public Softmax(TensorFunction<NAMETYPE> argument, String dimension) { this.argument = argument; this.dimension = dimension; } @@ -25,23 +26,23 @@ public class Softmax extends CompositeTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Softmax must have 1 argument, got " + arguments.size()); - return new Softmax(arguments.get(0), dimension); + return new Softmax<>(arguments.get(0), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument = argument.toPrimitive(); - return new Join(new Map(primitiveArgument, ScalarFunctions.exp()), - new Reduce(new Map(primitiveArgument, ScalarFunctions.exp()), - Reduce.Aggregator.sum, - dimension), - ScalarFunctions.divide()); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive(); + return new Join<>(new Map<>(primitiveArgument, ScalarFunctions.exp()), + new Reduce<>(new Map<>(primitiveArgument, ScalarFunctions.exp()), + Reduce.Aggregator.sum, + dimension), + ScalarFunctions.divide()); } @Override |