| author | Lester Solbakken <lesters@oath.com> | 2020-05-06 11:18:57 +0200 |
|---|---|---|
| committer | Lester Solbakken <lesters@oath.com> | 2020-05-06 11:18:57 +0200 |
| commit | e0d728bc344dcc764455fa05aa1b6b67286a1e95 (patch) | |
| tree | d5be394be0d81d49f5d2dd76a5f2f18a8141dfa5 /model-integration/src/main | |
| parent | ecfb20f08d508f3711645e1497ef4379b0f71e28 (diff) | |
Avoid double calculation in softmax
Diffstat (limited to 'model-integration/src/main')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java | 81 |
1 files changed, 62 insertions, 19 deletions
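
For context on the commit message (editorial note, not part of the commit): the importer emits the numerically stable softmax, softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)). Subtracting the max before exponentiating avoids overflow, but in the generated declarative ranking expression it means the input is referenced twice (max pass and subtraction) and the exp term is referenced twice (normalizing sum and division), and a repeated reference can be re-evaluated unless it is pulled out as its own function. A minimal plain-Java sketch of that computation shape (illustrative only; the class, method name, and plain arrays are assumptions, not Vespa's tensor API):

```java
// Illustrative sketch only: numerically stable softmax over a plain array.
// The input x is used twice (max pass, then exp pass) and the exp values are
// used twice (normalizing sum, then division) -- the repeated references that
// this commit factors out of the generated ranking expression.
final class StableSoftmaxSketch {

    static double[] softmax(double[] x) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : x) max = Math.max(max, v);        // first use of x

        double[] exp = new double[x.length];
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            exp[i] = Math.exp(x[i] - max);                // second use of x
            sum += exp[i];                                // first use of exp
        }

        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++)
            out[i] = exp[i] / sum;                        // second use of exp
        return out;
    }
}
```

In imperative code the reuse is free because the intermediate array is stored; in a generated expression the same reuse has to be made explicit, which is what the change below does.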
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
index d03827f4c72..5d9484ff0c8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
@@ -9,6 +9,7 @@ import com.yahoo.tensor.functions.ScalarFunctions;
 import com.yahoo.tensor.functions.TensorFunction;

 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;

 /**
@@ -23,24 +24,36 @@ public class Softmax extends IntermediateOperation {
     public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
         super(modelName, nodeName, inputs);
         this.attributeMap = attributeMap;
+        insert(new SoftmaxPartialOperation(modelName, nodeName, null), 0);  // inputs are fixed in insert
     }

     @Override
     protected OrderedTensorType lazyGetType() {
         if ( ! allInputTypesPresent(1)) return null;
-
-        // input is referenced twice due to overflow avoidance, so make this it's own function.
-        inputs.get(0).exportAsRankingFunction = true;
-
         return inputs.get(0).type().get();
     }

     @Override
     protected TensorFunction lazyGetFunction() {
         if ( ! allInputFunctionsPresent(1)) return null;
+        List<String> reduceDimensions = reduceDimensions();
+        TensorFunction input = inputs.get(0).function().get();
+        TensorFunction sum = new Reduce(input, Reduce.Aggregator.sum, reduceDimensions);
+        TensorFunction div = new Join(input, sum, ScalarFunctions.divide());
+        System.out.println(div);
+        return div;
+    }

-        OrderedTensorType inputType = inputs.get(0).type().get();
+    @Override
+    public Softmax withInputs(List<IntermediateOperation> inputs) {
+        return new Softmax(modelName(), name(), inputs, attributeMap);
+    }
+    @Override
+    public String operationName() { return "SoftMax"; }
+
+    private List<String> reduceDimensions() {
+        OrderedTensorType inputType = inputs.get(0).type().get();
         int axis = inputType.rank() == 1 ? 0 : 1;  // assumption: first dimension is batch dimension
         if (attributeMap.get("axis").isPresent()) {
             axis = (int)attributeMap.get("axis").get().asDouble();
         }
@@ -52,23 +65,53 @@ public class Softmax extends IntermediateOperation {
         for (int i = axis; i < inputType.rank(); ++i) {
             reduceDimensions.add(inputType.dimensions().get(i).name());  // Do softmax over all dimensions except batch dimension
         }
+        return reduceDimensions;
+    }

-        TensorFunction input = inputs.get(0).function().get();
-        TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions);
-        TensorFunction cap = new Join(input, max, ScalarFunctions.subtract());  // to avoid overflow
-        TensorFunction exp = new Map(cap, ScalarFunctions.exp());
-        TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions);
-        TensorFunction div = new Join(exp, sum, ScalarFunctions.divide());
+    /*
+     * Operation to insert between input and this softmax to avoid double calculation
+     * Note that this partial operation should be removed when we have a specific
+     * softmax optimization in the backend, as this way of splitting the calculation
+     * makes the full softmax expression impossible to recognize.
+     */
+    private class SoftmaxPartialOperation extends IntermediateOperation {

-        return div;
-    }
+        private SoftmaxPartialOperation(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+            super(modelName, nodeName + "_partial" , inputs != null ? inputs : Collections.emptyList());
+        }

-    @Override
-    public Softmax withInputs(List<IntermediateOperation> inputs) {
-        return new Softmax(modelName(), name(), inputs, attributeMap);
-    }
+        @Override
+        protected OrderedTensorType lazyGetType() {
+            if ( ! allInputTypesPresent(1)) return null;

-    @Override
-    public String operationName() { return "SoftMax"; }
+            // input is referenced twice due to overflow avoidance, so make sure it is exported as a ranking function
+            inputs.get(0).exportAsRankingFunction = true;
+
+            // this should also be it's own function since we use it twice
+            exportAsRankingFunction = true;
+
+            return inputs.get(0).type().get();
+        }
+
+        @Override
+        protected TensorFunction lazyGetFunction() {
+            if ( ! allInputFunctionsPresent(1)) return null;
+            List<String> reduceDimensions = reduceDimensions();
+            TensorFunction input = inputs.get(0).function().get();
+            TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions);
+            TensorFunction cap = new Join(input, max, ScalarFunctions.subtract());  // to avoid overflow
+            TensorFunction exp = new Map(cap, ScalarFunctions.exp());
+            return exp;
+        }
+
+        @Override
+        public SoftmaxPartialOperation withInputs(List<IntermediateOperation> inputs) {
+            return new SoftmaxPartialOperation(modelName(), name(), inputs);
+        }
+
+        @Override
+        public String operationName() { return "SoftMaxPartial"; }
+
+    }

 }
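
Read informally (commentary, not part of the commit): the inserted SoftmaxPartialOperation now owns the overflow-protected part, exp(x - max(x)), and is exported as its own ranking function so it is evaluated once; the outer Softmax only reduces that result with sum and divides. A rough plain-array analogue of the split (class and method names here are illustrative, not Vespa code):

```java
// Illustrative two-stage split mirroring SoftmaxPartialOperation + Softmax.
final class SoftmaxSplitSketch {

    // Stage 1, analogous to SoftmaxPartialOperation: exp(x - max(x)),
    // materialized once so every later reference reuses it.
    static double[] partial(double[] x) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : x) max = Math.max(max, v);   // overflow avoidance
        double[] exp = new double[x.length];
        for (int i = 0; i < x.length; i++) exp[i] = Math.exp(x[i] - max);
        return exp;
    }

    // Stage 2, analogous to Softmax: normalize the partial result by its sum.
    static double[] softmax(double[] partial) {
        double sum = 0.0;
        for (double v : partial) sum += v;
        double[] out = new double[partial.length];
        for (int i = 0; i < partial.length; i++) out[i] = partial[i] / sum;
        return out;
    }
}
```

As the new class comment in the diff notes, this split is a stop-gap: once the backend can recognize and optimize a complete softmax expression directly, the partial operation obscures that pattern and should be removed.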