aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java27
1 files changed, 14 insertions, 13 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index bd732cdc11e..4636871e19c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.functions;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.List;
@@ -10,12 +11,12 @@ import java.util.List;
/**
* @author bratseth
*/
-public class Softmax extends CompositeTensorFunction {
+public class Softmax<NAMETYPE extends TypeContext.Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final String dimension;
- public Softmax(TensorFunction argument, String dimension) {
+ public Softmax(TensorFunction<NAMETYPE> argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@@ -25,23 +26,23 @@ public class Softmax extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Softmax must have 1 argument, got " + arguments.size());
- return new Softmax(arguments.get(0), dimension);
+ return new Softmax<>(arguments.get(0), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument = argument.toPrimitive();
- return new Join(new Map(primitiveArgument, ScalarFunctions.exp()),
- new Reduce(new Map(primitiveArgument, ScalarFunctions.exp()),
- Reduce.Aggregator.sum,
- dimension),
- ScalarFunctions.divide());
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive();
+ return new Join<>(new Map<>(primitiveArgument, ScalarFunctions.exp()),
+ new Reduce<>(new Map<>(primitiveArgument, ScalarFunctions.exp()),
+ Reduce.Aggregator.sum,
+ dimension),
+ ScalarFunctions.divide());
}
@Override