author    Lester Solbakken <lesters@oath.com>  2020-05-06 11:18:57 +0200
committer Lester Solbakken <lesters@oath.com>  2020-05-06 11:18:57 +0200
commit    e0d728bc344dcc764455fa05aa1b6b67286a1e95 (patch)
tree      d5be394be0d81d49f5d2dd76a5f2f18a8141dfa5 /model-integration
parent    ecfb20f08d508f3711645e1497ef4379b0f71e28 (diff)
Avoid double calculation in softmax
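For context: the importer generates softmax in the numerically stable form exp(x - max(x)) / sum(exp(x - max(x))). Serialized as a single expression, the exp(x - max(x)) subtree appears twice, once as the numerator and once inside the sum reduction, and is therefore computed twice. This commit splits that subtree into a separate operation exported as a named ranking function, so both references share one evaluation. A minimal plain-Java sketch of the same computation (a hypothetical illustration, not code from the Vespa sources):

import java.util.Arrays;

public class StableSoftmax {

    // Numerically stable softmax: subtracting max(x) before exp avoids
    // overflow without changing the result. The array e plays the role of
    // the "partial" stage in this patch: computed once, then used both for
    // the sum reduction and for the final division.
    static double[] softmax(double[] x) {
        double max = Arrays.stream(x).max().orElse(0.0);
        double[] e = new double[x.length];
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            e[i] = Math.exp(x[i] - max);
            sum += e[i];
        }
        for (int i = 0; i < x.length; i++) {
            e[i] /= sum;
        }
        return e;
    }

    public static void main(String[] args) {
        // Prints approximately [0.0900, 0.2447, 0.6652]
        System.out.println(Arrays.toString(softmax(new double[] {1.0, 2.0, 3.0})));
    }
}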
Diffstat (limited to 'model-integration')
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java | 81
1 file changed, 62 insertions(+), 19 deletions(-)
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
index d03827f4c72..5d9484ff0c8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
@@ -9,6 +9,7 @@ import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.List;
/**
@@ -23,24 +24,36 @@ public class Softmax extends IntermediateOperation {
public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
super(modelName, nodeName, inputs);
this.attributeMap = attributeMap;
+ insert(new SoftmaxPartialOperation(modelName, nodeName, null), 0); // inputs are fixed in insert
}
@Override
protected OrderedTensorType lazyGetType() {
if ( ! allInputTypesPresent(1)) return null;
-
- // input is referenced twice due to overflow avoidance, so make this it's own function.
- inputs.get(0).exportAsRankingFunction = true;
-
return inputs.get(0).type().get();
}
@Override
protected TensorFunction lazyGetFunction() {
if ( ! allInputFunctionsPresent(1)) return null;
+ List<String> reduceDimensions = reduceDimensions();
+ TensorFunction input = inputs.get(0).function().get();
+ TensorFunction sum = new Reduce(input, Reduce.Aggregator.sum, reduceDimensions);
+ TensorFunction div = new Join(input, sum, ScalarFunctions.divide());
+ return div;
+ }
- OrderedTensorType inputType = inputs.get(0).type().get();
+ @Override
+ public Softmax withInputs(List<IntermediateOperation> inputs) {
+ return new Softmax(modelName(), name(), inputs, attributeMap);
+ }
+ @Override
+ public String operationName() { return "SoftMax"; }
+
+ private List<String> reduceDimensions() {
+ OrderedTensorType inputType = inputs.get(0).type().get();
int axis = inputType.rank() == 1 ? 0 : 1; // assumption: first dimension is batch dimension
if (attributeMap.get("axis").isPresent()) {
axis = (int)attributeMap.get("axis").get().asDouble();
@@ -52,23 +65,53 @@ public class Softmax extends IntermediateOperation {
for (int i = axis; i < inputType.rank(); ++i) {
reduceDimensions.add(inputType.dimensions().get(i).name()); // Do softmax over all dimensions except batch dimension
}
+ return reduceDimensions;
+ }
- TensorFunction input = inputs.get(0).function().get();
- TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions);
- TensorFunction cap = new Join(input, max, ScalarFunctions.subtract()); // to avoid overflow
- TensorFunction exp = new Map(cap, ScalarFunctions.exp());
- TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions);
- TensorFunction div = new Join(exp, sum, ScalarFunctions.divide());
+ /*
+ * Operation to insert between the input and this softmax to avoid double calculation.
+ * Note that this partial operation should be removed when we have a specific
+ * softmax optimization in the backend, as this way of splitting the calculation
+ * makes the full softmax expression impossible to recognize.
+ */
+ private class SoftmaxPartialOperation extends IntermediateOperation {
- return div;
- }
+ private SoftmaxPartialOperation(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName + "_partial", inputs != null ? inputs : Collections.emptyList());
+ }
- @Override
- public Softmax withInputs(List<IntermediateOperation> inputs) {
- return new Softmax(modelName(), name(), inputs, attributeMap);
- }
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! allInputTypesPresent(1)) return null;
- @Override
- public String operationName() { return "SoftMax"; }
+ // input is referenced twice due to overflow avoidance, so make sure it is exported as a ranking function
+ inputs.get(0).exportAsRankingFunction = true;
+
+ // this should also be its own function since we use it twice
+ exportAsRankingFunction = true;
+
+ return inputs.get(0).type().get();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! allInputFunctionsPresent(1)) return null;
+ List<String> reduceDimensions = reduceDimensions();
+ TensorFunction input = inputs.get(0).function().get();
+ TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions);
+ TensorFunction cap = new Join(input, max, ScalarFunctions.subtract()); // to avoid overflow
+ TensorFunction exp = new Map(cap, ScalarFunctions.exp());
+ return exp;
+ }
+
+ @Override
+ public SoftmaxPartialOperation withInputs(List<IntermediateOperation> inputs) {
+ return new SoftmaxPartialOperation(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "SoftMaxPartial"; }
+
+ }
}
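The effect of the split on evaluation count can be modeled in a few lines. The sketch below is hypothetical and uses none of the Vespa classes: a Supplier stands in for the exp(x - max(x)) stage, and a counter shows that inlining the stage evaluates it twice, while binding it to a name (as the exported ranking function does) evaluates it once.

import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

public class SharedSubexpressionDemo {
    public static void main(String[] args) {
        AtomicInteger evals = new AtomicInteger();

        // Stand-in for the exp(x - max(x)) stage; counts how often it runs.
        Supplier<double[]> expStage = () -> {
            evals.incrementAndGet();
            return new double[] {2.718, 7.389, 20.086}; // exp(1), exp(2), exp(3)
        };

        // Inlined (old behaviour): the stage is evaluated once for the sum
        // reduction and evaluated again for the numerator.
        double sum = 0;
        for (double e : expStage.get()) sum += e;
        double[] inlined = expStage.get();
        for (int i = 0; i < inlined.length; i++) inlined[i] /= sum;
        System.out.println("inlined: " + evals.get() + " evaluations"); // 2

        // Named intermediate (new behaviour): evaluated once, referenced twice.
        evals.set(0);
        double[] shared = expStage.get();
        sum = 0;
        for (double e : shared) sum += e;
        for (int i = 0; i < shared.length; i++) shared[i] /= sum;
        System.out.println("named:   " + evals.get() + " evaluation"); // 1
    }
}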