diff options
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java | 34 |
1 files changed, 16 insertions, 18 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index 78b30f0c873..4d1b5a97583 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -4,7 +4,6 @@ package ai.vespa.models.evaluation; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; @@ -31,22 +30,22 @@ public final class LazyArrayContext extends Context implements ContextIndex { public final static Value defaultContextValue = DoubleValue.zero; + private final ExpressionFunction function; + private final IndexedBindings indexedBindings; - private LazyArrayContext(IndexedBindings indexedBindings) { + private LazyArrayContext(ExpressionFunction function, IndexedBindings indexedBindings) { + this.function = function; this.indexedBindings = indexedBindings.copy(this); } - /** - * Create a fast lookup, lazy context for an expression. - * - * @param expression the expression to create a context for - */ - LazyArrayContext(RankingExpression expression, - Map<FunctionReference, ExpressionFunction> functions, + /** Create a fast lookup, lazy context for a function */ + LazyArrayContext(ExpressionFunction function, + Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants, Model model) { - this.indexedBindings = new IndexedBindings(expression, functions, constants, this, model); + this.function = function; + this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, this, model); } /** @@ -76,7 +75,6 @@ public final class LazyArrayContext extends Context implements ContextIndex { @Override public TensorType getType(Reference reference) { - // TODO: Add type information so we do not need to evaluate to get this return get(requireIndexOf(reference.toString())).type(); } @@ -128,7 +126,7 @@ public final class LazyArrayContext extends Context implements ContextIndex { * in a different thread or for re-binding free variables. */ LazyArrayContext copy() { - return new LazyArrayContext(indexedBindings); + return new LazyArrayContext(function, indexedBindings); } private static class IndexedBindings { @@ -154,15 +152,15 @@ public final class LazyArrayContext extends Context implements ContextIndex { * Creates indexed bindings for the given expressions. * The given expression and functions may be inspected but cannot be stored. */ - IndexedBindings(RankingExpression expression, - Map<FunctionReference, ExpressionFunction> functions, + IndexedBindings(ExpressionFunction function, + Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants, LazyArrayContext owner, Model model) { // 1. Determine and prepare bind targets Set<String> bindTargets = new LinkedHashSet<>(); Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation - extractBindTargets(expression.getRoot(), functions, bindTargets, arguments); + extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments); this.arguments = ImmutableSet.copyOf(arguments); values = new Value[bindTargets.size()]; @@ -183,10 +181,10 @@ public final class LazyArrayContext extends Context implements ContextIndex { values[index] = new TensorValue(constant.value()); } - for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { - Integer index = nameToIndex.get(function.getKey().serialForm()); + for (Map.Entry<FunctionReference, ExpressionFunction> referencedFunction : referencedFunctions.entrySet()) { + Integer index = nameToIndex.get(referencedFunction.getKey().serialForm()); if (index != null) // Referenced in this, so bind it - values[index] = new LazyValue(function.getKey(), owner, model); + values[index] = new LazyValue(referencedFunction.getKey(), owner, model); } } |