diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-10-01 10:42:16 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-10-01 10:42:16 +0200 |
commit | 50bc3b3c198d29374448cc3eac73fbb26e42cab0 (patch) | |
tree | 668c2fdcf18b25fda38e1faa10bd479b76e1ecb6 /model-evaluation/src/main | |
parent | 0ff988ecf9704faac33f6201cb59349e48846457 (diff) |
Fill in missing types
Diffstat (limited to 'model-evaluation/src/main')
3 files changed, 46 insertions, 15 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index c7d0cbd8f30..d144411127e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -1,7 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; @@ -15,7 +17,9 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -110,6 +114,9 @@ public final class LazyArrayContext extends Context implements ContextIndex { @Override public Set<String> names() { return indexedBindings.names(); } + /** Returns the (immutable) subset of names in this which must be bound when invoking */ + public Set<String> arguments() { return indexedBindings.arguments(); } + private Integer requireIndexOf(String name) { Integer index = indexedBindings.indexOf(name); if (index == null) @@ -130,12 +137,18 @@ public final class LazyArrayContext extends Context implements ContextIndex { /** The mapping from variable name to index */ private final ImmutableMap<String, Integer> nameToIndex; + /** The names which neeeds to be bound externally when envoking this (i.e not constant or invocation */ + private final ImmutableSet<String> arguments; + /** The current values set, pre-converted to doubles */ private final Value[] values; - private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, Value[] values) { + private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, + Value[] values, + ImmutableSet<String> arguments) { this.nameToIndex = nameToIndex; this.values = values; + this.arguments = arguments; } /** @@ -149,8 +162,10 @@ public final class LazyArrayContext extends Context implements ContextIndex { Model model) { // 1. Determine and prepare bind targets Set<String> bindTargets = new LinkedHashSet<>(); - extractBindTargets(expression.getRoot(), functions, bindTargets); + Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation + extractBindTargets(expression.getRoot(), functions, bindTargets, arguments); + this.arguments = ImmutableSet.copyOf(arguments); values = new Value[bindTargets.size()]; Arrays.fill(values, DoubleValue.zero); @@ -178,23 +193,25 @@ public final class LazyArrayContext extends Context implements ContextIndex { private void extractBindTargets(ExpressionNode node, Map<FunctionReference, ExpressionFunction> functions, - Set<String> bindTargets) { + Set<String> bindTargets, + Set<String> arguments) { if (isFunctionReference(node)) { FunctionReference reference = FunctionReference.fromSerial(node.toString()).get(); bindTargets.add(reference.serialForm()); - extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets); + extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets, arguments); } else if (isConstant(node)) { bindTargets.add(node.toString()); } else if (node instanceof ReferenceNode) { bindTargets.add(node.toString()); + arguments.add(node.toString()); } else if (node instanceof CompositeNode) { CompositeNode cNode = (CompositeNode)node; for (ExpressionNode child : cNode.children()) - extractBindTargets(child, functions, bindTargets); + extractBindTargets(child, functions, bindTargets, arguments); } } @@ -215,13 +232,14 @@ public final class LazyArrayContext extends Context implements ContextIndex { Value get(int index) { return values[index]; } void set(int index, Value value) { values[index] = value; } Set<String> names() { return nameToIndex.keySet(); } + Set<String> arguments() { return arguments; } Integer indexOf(String name) { return nameToIndex.get(name); } IndexedBindings copy(Context context) { Value[] valueCopy = new Value[values.length]; for (int i = 0; i < values.length; i++) valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue)values[i]).copyFor(context) : values[i]; - return new IndexedBindings(nameToIndex, valueCopy); + return new IndexedBindings(nameToIndex, valueCopy, arguments); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index ac8f28677a4..fda1ae935ca 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; +import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.Collection; @@ -38,28 +39,39 @@ public class Model { /** Programmatically create a model containing functions without constant of function references only */ public Model(String name, Collection<ExpressionFunction> functions) { - this(name, functions, Collections.emptyMap(), Collections.emptyList()); + this(name, + functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)), + Collections.emptyMap(), + Collections.emptyList()); } Model(String name, - Collection<ExpressionFunction> functions, + Map<FunctionReference, ExpressionFunction> functions, Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants) { this.name = name; - this.functions = ImmutableList.copyOf(functions); + // Build context and add missing function arguments (missing because it is legal to omit scalar type arguments) ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>(); - for (ExpressionFunction function : functions) { + for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { try { - contextBuilder.put(function.getName(), - new LazyArrayContext(function.getBody(), referencedFunctions, constants, this)); + LazyArrayContext context = new LazyArrayContext(function.getValue().getBody(), referencedFunctions, constants, this); + contextBuilder.put(function.getValue().getName(), context); + for (String argument : context.arguments()) { + if (function.getValue().argumentTypes().get(argument) == null) + functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty)); + } + if ( ! function.getValue().returnType().isPresent()) + functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty)); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e); } } this.contextPrototypes = contextBuilder.build(); + this.functions = ImmutableList.copyOf(functions.values()); + // Optimize functions ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>(); for (Map.Entry<FunctionReference, ExpressionFunction> function : referencedFunctions.entrySet()) { ExpressionFunction optimizedFunction = optimize(function.getValue(), diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 648c6d931a9..fb424439592 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -78,7 +78,9 @@ public class RankProfilesConfigImporter { Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name()); if ( reference.isPresent()) { RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value()); - ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), Collections.emptyList(), expression); + ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), + Collections.emptyList(), + expression); if (reference.get().isFree()) // make available in model under configured name functions.put(reference.get(), function); @@ -120,14 +122,13 @@ public class RankProfilesConfigImporter { constants.addAll(smallConstantsInfo.asConstants()); try { - return new Model(profile.name(), functions.values(), referencedFunctions, constants); + return new Model(profile.name(), functions, referencedFunctions, constants); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e); } } - // TODO: Replace by lookup in map private ExpressionFunction functionByName(String name, Collection<ExpressionFunction> functions) { for (ExpressionFunction function : functions) if (function.getName().equals(name)) |