diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java | 92 |
1 files changed, 82 insertions, 10 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 2aa8b2a0d48..83e15a4081a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; @@ -13,6 +14,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.TensorFunction; @@ -47,6 +49,8 @@ public abstract class IntermediateOperation { protected TensorFunction rankingExpressionFunction = null; protected boolean exportAsRankingFunction = false; + private boolean hasRenamedDimensions = false; + private final List<String> importWarnings = new ArrayList<>(); private Value constantValue = null; private List<IntermediateOperation> controlInputs = Collections.emptyList(); @@ -121,7 +125,10 @@ public abstract class IntermediateOperation { } /** Performs dimension rename for this operation */ - public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); } + public void renameDimensions(DimensionRenamer renamer) { + type = type.rename(renamer); + hasRenamedDimensions = true; + } /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */ public boolean isInput() { return false; } @@ -144,7 +151,11 @@ public abstract class IntermediateOperation { } /** Set the constant value function */ - public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { this.constantValueFunction = func; } + public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { + this.constantValueFunction = func; + } + + public boolean hasConstantValueFunction() { return constantValueFunction != null; } /** Sets the external control inputs */ public void setControlInputs(List<IntermediateOperation> inputs) { this.controlInputs = inputs; } @@ -153,12 +164,23 @@ public abstract class IntermediateOperation { public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); } /** Retrieve the valid Vespa name of this node */ - public String vespaName() { return vespaName(name); } - public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null; } + public String vespaName() { + if (isConstant()) + return modelName + "_" + vespaName(name); + return vespaName(name); + } + + public String vespaName(String name) { + return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null; + } /** Retrieve the valid Vespa name of this node if it is a ranking expression function */ public String rankingExpressionFunctionName() { - return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null; + String vespaName = vespaName(); + if (vespaName == null) { + return null; + } + return isConstant() ? "constant(" + vespaName + ")" : FUNCTION_PREFIX + modelName + "_" + vespaName; } /** Retrieve the list of warnings produced during its lifetime */ @@ -185,30 +207,80 @@ public abstract class IntermediateOperation { /** Recursively evaluates this operation's constant value to avoid doing it run-time. */ public Value evaluateAsConstant(OrderedTensorType type) { +// System.out.println("Starting constant evaluation for " + name); if ( ! isConstant() ) { throw new IllegalArgumentException("Attempted to evaluate non-constant operation as a constant."); } - Value val = evaluateAsConstant(new MapContext(DoubleValue.NaN)); - if (type != null && ! val.asTensor().type().equals(type.type()) ) { + if (type == null) { + System.out.println("Evaluating as constant for " + name + " with type null! Probably an error."); + } + + IntermediateOperation evaluateOn = this; + if ( ! hasRenamedDimensions) { + // make a copy of the tree, perform renaming and evaluate + IntermediateOperation copy = copyTree(0); + optimizeAndRename(copy); + evaluateOn = copy; + } + Value val = evaluateOn.evaluateAsConstant(new MapContext(DoubleValue.NaN), 0); + + if (type == null) { + return val; + } + Tensor tensor = val.asTensor(); //.withType(type.type()); + if ( ! tensor.type().isRenamableTo(type.type()) ) { throw new IllegalArgumentException("Constant evaluation in " + name + " resulted in wrong type. " + "Expected: " + type.type() + " Got: " + val.asTensor().type()); } - return val; + // set constant value so we don't have to re-evaluate + setConstantValueFunction(t -> new TensorValue(tensor.withType(t.type()))); +// System.out.println("Returning constant evaluation for " + name); + return new TensorValue(tensor.withType(type.type())); + } + + private IntermediateOperation copyTree(int indent) { + String indentString = ""; for (int i = 0; i < indent; ++i) indentString += " "; +// System.out.println(indentString + "Copying " + name); + List<IntermediateOperation> in = new ArrayList<>(); + if (constantValue != null) { +// System.out.println(indentString + name + " has a constant value"); + IntermediateOperation constant = new Constant(modelName, name, type); + constant.setConstantValueFunction(t -> new TensorValue(constantValue.asTensor().withType(t.type()))); + return constant; + } + inputs.forEach(i -> in.add(i.copyTree(indent + 1))); + IntermediateOperation copy = withInputs(in); + if (constantValueFunction != null) { + copy.constantValueFunction = constantValueFunction; // works? + } + return copy; + } + + private TensorFunction optimizeAndRename(IntermediateOperation op) { + IntermediateGraph graph = new IntermediateGraph(modelName); + graph.put(name, op); + graph.outputs(graph.defaultSignature()).put(name, name); + graph.optimize(); + return op.function().get(); } - private Value evaluateAsConstant(Context context) { + private Value evaluateAsConstant(Context context, int indent) { + String in = ""; for (int i = 0; i < indent; ++i) in += " "; +// System.out.println(in + "Constant evaluating for " + name); String constantName = "constant(" + vespaName() + ")"; Value result = context.get(constantName); if (result == DoubleValue.NaN) { if (constantValue != null) { +// System.out.println(in + name + " has constant value."); result = constantValue; } else if (inputs.size() == 0) { +// System.out.println(in + name + " has no inputs."); if (getConstantValue().isEmpty()) { throw new IllegalArgumentException("Error in evaluating constant for " + name); } result = getConstantValue().get(); } else { - inputs.forEach(i -> i.evaluateAsConstant(context)); + inputs.forEach(i -> i.evaluateAsConstant(context, indent+1)); result = new TensorValue(lazyGetFunction().evaluate(context)); } context.put(constantName, result); |