diff options
author | Lester Solbakken <lesters@oath.com> | 2019-11-22 09:30:37 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-11-22 09:30:37 +0100 |
commit | 8b3c453b66f891a59ca80bfc47abe63be1b9bace (patch) | |
tree | aca088e543f040f836dd2cc67c4c7f5ded16a65b /model-integration | |
parent | bf04cdc3471570c4cfd1ffa57a66eaad1f4263ae (diff) |
Propagate constant values for ONNX import
Diffstat (limited to 'model-integration')
2 files changed, 40 insertions, 0 deletions
```diff
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index 714953fbd45..c60a9b85d10 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -125,6 +125,11 @@ class GraphImporter {
             List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph);
             operation = mapOperation(node, inputs, intermediateGraph);
 
+            // propagate constant values if all inputs are constant
+            if (operation.isConstant()) {
+                operation.setConstantValueFunction(operation::evaluateAsConstant);
+            }
+
             if (isOutputNode(name, onnxGraph)) {
                 intermediateGraph.outputs(intermediateGraph.defaultSignature())
                         .put(IntermediateOperation.namePartOf(name), operation.vespaName());
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 87a3f1a8e66..6d0cdfc5021 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -5,6 +5,10 @@ package ai.vespa.rankingexpression.importer.operations;
 import ai.vespa.rankingexpression.importer.DimensionRenamer;
 import ai.vespa.rankingexpression.importer.OrderedTensorType;
 import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
 import com.yahoo.searchlib.rankingexpression.evaluation.Value;
 import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
 import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
@@ -178,6 +182,37 @@ public abstract class IntermediateOperation {
         return verifyInputs(expected, IntermediateOperation::function);
     }
 
+    /** Recursively evaluates this operation's constant value to avoid doing it run-time. */
+    public Value evaluateAsConstant(OrderedTensorType type) {
+        if ( ! isConstant() ) {
+            throw new IllegalArgumentException("Attempted to evaluate non-constant operation as a constant.");
+        }
+        Value val = evaluateAsConstant(new MapContext(DoubleValue.NaN));
+        if ( ! val.asTensor().type().equals(type.type()) ) {
+            throw new IllegalArgumentException("Constant evaluation in " + name + " resulted in wrong type. " +
+                                               "Expected: " + type.type() + " Got: " + val.asTensor().type());
+        }
+        return val;
+    }
+
+    private Value evaluateAsConstant(Context context) {
+        String constantName = "constant(" + vespaName() + ")";
+        Value result = context.get(constantName);
+        if (result == DoubleValue.NaN) {
+            if (inputs.size() == 0) {
+                if (getConstantValue().isEmpty()) {
+                    throw new IllegalArgumentException("Error in evaluating constant for " + name);
+                }
+                result = getConstantValue().get();
+            } else {
+                inputs.forEach(i -> i.evaluateAsConstant(context));
+                result = new TensorValue(lazyGetFunction().evaluate(context));
+            }
+            context.put(constantName, result);
+        }
+        return result;
+    }
+
     /**
      * Returns the largest value type among the input value types.
      * This should only be called after it has been verified that input types are available.
```