aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java92
1 files changed, 82 insertions, 10 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 2aa8b2a0d48..83e15a4081a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -3,6 +3,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
@@ -13,6 +14,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
@@ -47,6 +49,8 @@ public abstract class IntermediateOperation {
protected TensorFunction rankingExpressionFunction = null;
protected boolean exportAsRankingFunction = false;
+ private boolean hasRenamedDimensions = false;
+
private final List<String> importWarnings = new ArrayList<>();
private Value constantValue = null;
private List<IntermediateOperation> controlInputs = Collections.emptyList();
@@ -121,7 +125,10 @@ public abstract class IntermediateOperation {
}
/** Performs dimension rename for this operation */
- public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); }
+ public void renameDimensions(DimensionRenamer renamer) {
+ type = type.rename(renamer);
+ hasRenamedDimensions = true;
+ }
/** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */
public boolean isInput() { return false; }
@@ -144,7 +151,11 @@ public abstract class IntermediateOperation {
}
/** Set the constant value function */
- public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { this.constantValueFunction = func; }
+ public void setConstantValueFunction(Function<OrderedTensorType, Value> func) {
+ this.constantValueFunction = func;
+ }
+
+ public boolean hasConstantValueFunction() { return constantValueFunction != null; }
/** Sets the external control inputs */
public void setControlInputs(List<IntermediateOperation> inputs) { this.controlInputs = inputs; }
@@ -153,12 +164,23 @@ public abstract class IntermediateOperation {
public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }
/** Retrieve the valid Vespa name of this node */
- public String vespaName() { return vespaName(name); }
- public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null; }
+ public String vespaName() {
+ if (isConstant())
+ return modelName + "_" + vespaName(name);
+ return vespaName(name);
+ }
+
+ public String vespaName(String name) {
+ return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null;
+ }
/** Retrieve the valid Vespa name of this node if it is a ranking expression function */
public String rankingExpressionFunctionName() {
- return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null;
+ String vespaName = vespaName();
+ if (vespaName == null) {
+ return null;
+ }
+ return isConstant() ? "constant(" + vespaName + ")" : FUNCTION_PREFIX + modelName + "_" + vespaName;
}
/** Retrieve the list of warnings produced during its lifetime */
@@ -185,30 +207,80 @@ public abstract class IntermediateOperation {
/** Recursively evaluates this operation's constant value to avoid doing it run-time. */
public Value evaluateAsConstant(OrderedTensorType type) {
+// System.out.println("Starting constant evaluation for " + name);
if ( ! isConstant() ) {
throw new IllegalArgumentException("Attempted to evaluate non-constant operation as a constant.");
}
- Value val = evaluateAsConstant(new MapContext(DoubleValue.NaN));
- if (type != null && ! val.asTensor().type().equals(type.type()) ) {
+ if (type == null) {
+ System.out.println("Evaluating as constant for " + name + " with type null! Probably an error.");
+ }
+
+ IntermediateOperation evaluateOn = this;
+ if ( ! hasRenamedDimensions) {
+ // make a copy of the tree, perform renaming and evaluate
+ IntermediateOperation copy = copyTree(0);
+ optimizeAndRename(copy);
+ evaluateOn = copy;
+ }
+ Value val = evaluateOn.evaluateAsConstant(new MapContext(DoubleValue.NaN), 0);
+
+ if (type == null) {
+ return val;
+ }
+ Tensor tensor = val.asTensor(); //.withType(type.type());
+ if ( ! tensor.type().isRenamableTo(type.type()) ) {
throw new IllegalArgumentException("Constant evaluation in " + name + " resulted in wrong type. " +
"Expected: " + type.type() + " Got: " + val.asTensor().type());
}
- return val;
+ // set constant value so we don't have to re-evaluate
+ setConstantValueFunction(t -> new TensorValue(tensor.withType(t.type())));
+// System.out.println("Returning constant evaluation for " + name);
+ return new TensorValue(tensor.withType(type.type()));
+ }
+
+ private IntermediateOperation copyTree(int indent) {
+ String indentString = ""; for (int i = 0; i < indent; ++i) indentString += " ";
+// System.out.println(indentString + "Copying " + name);
+ List<IntermediateOperation> in = new ArrayList<>();
+ if (constantValue != null) {
+// System.out.println(indentString + name + " has a constant value");
+ IntermediateOperation constant = new Constant(modelName, name, type);
+ constant.setConstantValueFunction(t -> new TensorValue(constantValue.asTensor().withType(t.type())));
+ return constant;
+ }
+ inputs.forEach(i -> in.add(i.copyTree(indent + 1)));
+ IntermediateOperation copy = withInputs(in);
+ if (constantValueFunction != null) {
+ copy.constantValueFunction = constantValueFunction; // works?
+ }
+ return copy;
+ }
+
+ private TensorFunction optimizeAndRename(IntermediateOperation op) {
+ IntermediateGraph graph = new IntermediateGraph(modelName);
+ graph.put(name, op);
+ graph.outputs(graph.defaultSignature()).put(name, name);
+ graph.optimize();
+ return op.function().get();
}
- private Value evaluateAsConstant(Context context) {
+ private Value evaluateAsConstant(Context context, int indent) {
+ String in = ""; for (int i = 0; i < indent; ++i) in += " ";
+// System.out.println(in + "Constant evaluating for " + name);
String constantName = "constant(" + vespaName() + ")";
Value result = context.get(constantName);
if (result == DoubleValue.NaN) {
if (constantValue != null) {
+// System.out.println(in + name + " has constant value.");
result = constantValue;
} else if (inputs.size() == 0) {
+// System.out.println(in + name + " has no inputs.");
if (getConstantValue().isEmpty()) {
throw new IllegalArgumentException("Error in evaluating constant for " + name);
}
result = getConstantValue().get();
} else {
- inputs.forEach(i -> i.evaluateAsConstant(context));
+ inputs.forEach(i -> i.evaluateAsConstant(context, indent+1));
result = new TensorValue(lazyGetFunction().evaluate(context));
}
context.put(constantName, result);