aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-10-01 10:42:16 +0200
committerJon Bratseth <bratseth@oath.com>2018-10-01 10:42:16 +0200
commit50bc3b3c198d29374448cc3eac73fbb26e42cab0 (patch)
tree668c2fdcf18b25fda38e1faa10bd479b76e1ecb6 /model-evaluation/src/main
parent0ff988ecf9704faac33f6201cb59349e48846457 (diff)
Fill in missing types
Diffstat (limited to 'model-evaluation/src/main')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java30
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java24
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java7
3 files changed, 46 insertions, 15 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index c7d0cbd8f30..d144411127e 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -1,7 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
@@ -15,7 +17,9 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
@@ -110,6 +114,9 @@ public final class LazyArrayContext extends Context implements ContextIndex {
@Override
public Set<String> names() { return indexedBindings.names(); }
+ /** Returns the (immutable) subset of names in this which must be bound when invoking */
+ public Set<String> arguments() { return indexedBindings.arguments(); }
+
private Integer requireIndexOf(String name) {
Integer index = indexedBindings.indexOf(name);
if (index == null)
@@ -130,12 +137,18 @@ public final class LazyArrayContext extends Context implements ContextIndex {
/** The mapping from variable name to index */
private final ImmutableMap<String, Integer> nameToIndex;
+ /** The names which neeeds to be bound externally when envoking this (i.e not constant or invocation */
+ private final ImmutableSet<String> arguments;
+
/** The current values set, pre-converted to doubles */
private final Value[] values;
- private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, Value[] values) {
+ private IndexedBindings(ImmutableMap<String, Integer> nameToIndex,
+ Value[] values,
+ ImmutableSet<String> arguments) {
this.nameToIndex = nameToIndex;
this.values = values;
+ this.arguments = arguments;
}
/**
@@ -149,8 +162,10 @@ public final class LazyArrayContext extends Context implements ContextIndex {
Model model) {
// 1. Determine and prepare bind targets
Set<String> bindTargets = new LinkedHashSet<>();
- extractBindTargets(expression.getRoot(), functions, bindTargets);
+ Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation
+ extractBindTargets(expression.getRoot(), functions, bindTargets, arguments);
+ this.arguments = ImmutableSet.copyOf(arguments);
values = new Value[bindTargets.size()];
Arrays.fill(values, DoubleValue.zero);
@@ -178,23 +193,25 @@ public final class LazyArrayContext extends Context implements ContextIndex {
private void extractBindTargets(ExpressionNode node,
Map<FunctionReference, ExpressionFunction> functions,
- Set<String> bindTargets) {
+ Set<String> bindTargets,
+ Set<String> arguments) {
if (isFunctionReference(node)) {
FunctionReference reference = FunctionReference.fromSerial(node.toString()).get();
bindTargets.add(reference.serialForm());
- extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets);
+ extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets, arguments);
}
else if (isConstant(node)) {
bindTargets.add(node.toString());
}
else if (node instanceof ReferenceNode) {
bindTargets.add(node.toString());
+ arguments.add(node.toString());
}
else if (node instanceof CompositeNode) {
CompositeNode cNode = (CompositeNode)node;
for (ExpressionNode child : cNode.children())
- extractBindTargets(child, functions, bindTargets);
+ extractBindTargets(child, functions, bindTargets, arguments);
}
}
@@ -215,13 +232,14 @@ public final class LazyArrayContext extends Context implements ContextIndex {
Value get(int index) { return values[index]; }
void set(int index, Value value) { values[index] = value; }
Set<String> names() { return nameToIndex.keySet(); }
+ Set<String> arguments() { return arguments; }
Integer indexOf(String name) { return nameToIndex.get(name); }
IndexedBindings copy(Context context) {
Value[] valueCopy = new Value[values.length];
for (int i = 0; i < values.length; i++)
valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue)values[i]).copyFor(context) : values[i];
- return new IndexedBindings(nameToIndex, valueCopy);
+ return new IndexedBindings(nameToIndex, valueCopy, arguments);
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index ac8f28677a4..fda1ae935ca 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.Collection;
@@ -38,28 +39,39 @@ public class Model {
/** Programmatically create a model containing functions without constant of function references only */
public Model(String name, Collection<ExpressionFunction> functions) {
- this(name, functions, Collections.emptyMap(), Collections.emptyList());
+ this(name,
+ functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)),
+ Collections.emptyMap(),
+ Collections.emptyList());
}
Model(String name,
- Collection<ExpressionFunction> functions,
+ Map<FunctionReference, ExpressionFunction> functions,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
List<Constant> constants) {
this.name = name;
- this.functions = ImmutableList.copyOf(functions);
+ // Build context and add missing function arguments (missing because it is legal to omit scalar type arguments)
ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>();
- for (ExpressionFunction function : functions) {
+ for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
try {
- contextBuilder.put(function.getName(),
- new LazyArrayContext(function.getBody(), referencedFunctions, constants, this));
+ LazyArrayContext context = new LazyArrayContext(function.getValue().getBody(), referencedFunctions, constants, this);
+ contextBuilder.put(function.getValue().getName(), context);
+ for (String argument : context.arguments()) {
+ if (function.getValue().argumentTypes().get(argument) == null)
+ functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty));
+ }
+ if ( ! function.getValue().returnType().isPresent())
+ functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty));
}
catch (RuntimeException e) {
throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e);
}
}
this.contextPrototypes = contextBuilder.build();
+ this.functions = ImmutableList.copyOf(functions.values());
+ // Optimize functions
ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>();
for (Map.Entry<FunctionReference, ExpressionFunction> function : referencedFunctions.entrySet()) {
ExpressionFunction optimizedFunction = optimize(function.getValue(),
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 648c6d931a9..fb424439592 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -78,7 +78,9 @@ public class RankProfilesConfigImporter {
Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name());
if ( reference.isPresent()) {
RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value());
- ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), Collections.emptyList(), expression);
+ ExpressionFunction function = new ExpressionFunction(reference.get().functionName(),
+ Collections.emptyList(),
+ expression);
if (reference.get().isFree()) // make available in model under configured name
functions.put(reference.get(), function);
@@ -120,14 +122,13 @@ public class RankProfilesConfigImporter {
constants.addAll(smallConstantsInfo.asConstants());
try {
- return new Model(profile.name(), functions.values(), referencedFunctions, constants);
+ return new Model(profile.name(), functions, referencedFunctions, constants);
}
catch (RuntimeException e) {
throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e);
}
}
- // TODO: Replace by lookup in map
private ExpressionFunction functionByName(String name, Collection<ExpressionFunction> functions) {
for (ExpressionFunction function : functions)
if (function.getName().equals(name))