summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-11-22 09:30:37 +0100
committerLester Solbakken <lesters@oath.com>2019-11-22 09:30:37 +0100
commit8b3c453b66f891a59ca80bfc47abe63be1b9bace (patch)
treeaca088e543f040f836dd2cc67c4c7f5ded16a65b /model-integration
parentbf04cdc3471570c4cfd1ffa57a66eaad1f4263ae (diff)
Propagate constant values for ONNX import
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java35
2 files changed, 40 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index 714953fbd45..c60a9b85d10 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -125,6 +125,11 @@ class GraphImporter {
List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph);
operation = mapOperation(node, inputs, intermediateGraph);
+ // propagate constant values if all inputs are constant
+ if (operation.isConstant()) {
+ operation.setConstantValueFunction(operation::evaluateAsConstant);
+ }
+
if (isOutputNode(name, onnxGraph)) {
intermediateGraph.outputs(intermediateGraph.defaultSignature())
.put(IntermediateOperation.namePartOf(name), operation.vespaName());
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 87a3f1a8e66..6d0cdfc5021 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -5,6 +5,10 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
@@ -178,6 +182,37 @@ public abstract class IntermediateOperation {
return verifyInputs(expected, IntermediateOperation::function);
}
+ /** Recursively evaluates this operation's constant value to avoid doing it run-time. */
+ public Value evaluateAsConstant(OrderedTensorType type) {
+ if ( ! isConstant() ) {
+ throw new IllegalArgumentException("Attempted to evaluate non-constant operation as a constant.");
+ }
+ Value val = evaluateAsConstant(new MapContext(DoubleValue.NaN));
+ if ( ! val.asTensor().type().equals(type.type()) ) {
+ throw new IllegalArgumentException("Constant evaluation in " + name + " resulted in wrong type. " +
+ "Expected: " + type.type() + " Got: " + val.asTensor().type());
+ }
+ return val;
+ }
+
+ private Value evaluateAsConstant(Context context) {
+ String constantName = "constant(" + vespaName() + ")";
+ Value result = context.get(constantName);
+ if (result == DoubleValue.NaN) {
+ if (inputs.size() == 0) {
+ if (getConstantValue().isEmpty()) {
+ throw new IllegalArgumentException("Error in evaluating constant for " + name);
+ }
+ result = getConstantValue().get();
+ } else {
+ inputs.forEach(i -> i.evaluateAsConstant(context));
+ result = new TensorValue(lazyGetFunction().evaluate(context));
+ }
+ context.put(constantName, result);
+ }
+ return result;
+ }
+
/**
* Returns the largest value type among the input value types.
* This should only be called after it has been verified that input types are available.