summaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main/java
diff options
context:
space:
mode:
Diffstat (limited to 'model-evaluation/src/main/java')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java27
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java26
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java11
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java5
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java63
6 files changed, 119 insertions, 15 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java
new file mode 100644
index 00000000000..e664693ab38
--- /dev/null
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java
@@ -0,0 +1,27 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.tensor.Tensor;
+
+/**
+ * A named constant loaded from a file.
+ *
+ * This is immutable.
+ *
+ * @author bratseth
+ */
+class Constant {
+
+ private final String name;
+ private final Tensor value;
+
+ Constant(String name, Tensor value) {
+ this.name = name;
+ this.value = value;
+ }
+
+ public String name() { return name; }
+
+ public Tensor value() { return value; }
+
+}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index 520986ffb77..e08b9f77d15 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -56,4 +56,6 @@ public class FunctionEvaluator {
return function.getBody().evaluate(context).asTensor();
}
+ LazyArrayContext context() { return context; }
+
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index 2dcfd204077..beaa36b898f 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -16,6 +17,7 @@ import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.LinkedHashSet;
+import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -37,8 +39,11 @@ final class LazyArrayContext extends Context implements ContextIndex {
*
* @param expression the expression to create a context for
*/
- LazyArrayContext(RankingExpression expression, Map<FunctionReference, ExpressionFunction> functions, Model model) {
- this.indexedBindings = new IndexedBindings(expression, functions, this, model);
+ LazyArrayContext(RankingExpression expression,
+ Map<FunctionReference, ExpressionFunction> functions,
+ List<Constant> constants,
+ Model model) {
+ this.indexedBindings = new IndexedBindings(expression, functions, constants, this, model);
}
/**
@@ -139,8 +144,10 @@ final class LazyArrayContext extends Context implements ContextIndex {
*/
IndexedBindings(RankingExpression expression,
Map<FunctionReference, ExpressionFunction> functions,
+ List<Constant> constants,
LazyArrayContext owner,
Model model) {
+ // 1. Determine and prepare bind targets
Set<String> bindTargets = new LinkedHashSet<>();
extractBindTargets(expression.getRoot(), functions, bindTargets);
@@ -150,9 +157,18 @@ final class LazyArrayContext extends Context implements ContextIndex {
int i = 0;
ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>();
for (String variable : bindTargets)
- nameToIndexBuilder.put(variable,i++);
+ nameToIndexBuilder.put(variable, i++);
nameToIndex = nameToIndexBuilder.build();
+
+ // 2. Bind the bind targets
+ for (Constant constant : constants) {
+ String constantReference = "constant(" + constant.name() + ")";
+ Integer index = nameToIndex.get(constantReference);
+ if (index != null)
+ values[index] = new TensorValue(constant.value());
+ }
+
for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
Integer index = nameToIndex.get(function.getKey().serialForm());
if (index != null) // Referenced in this, so bind it
@@ -170,7 +186,7 @@ final class LazyArrayContext extends Context implements ContextIndex {
extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets);
}
else if (isConstant(node)) {
- // Ignore
+ bindTargets.add(node.toString());
}
else if (node instanceof ReferenceNode) {
bindTargets.add(node.toString());
@@ -193,7 +209,7 @@ final class LazyArrayContext extends Context implements ContextIndex {
if ( ! (node instanceof ReferenceNode)) return false;
ReferenceNode reference = (ReferenceNode)node;
- return reference.getName().equals("value") && reference.getArguments().size() == 1;
+ return reference.getName().equals("constant") && reference.getArguments().size() == 1;
}
Value get(int index) { return values[index]; }
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 95eb923786d..3fb43d73187 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -36,11 +36,15 @@ public class Model {
private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer();
+ /** Programmatically create a model containing functions without constant of function references only */
public Model(String name, Collection<ExpressionFunction> functions) {
- this(name, functions, Collections.emptyMap());
+ this(name, functions, Collections.emptyMap(), Collections.emptyList());
}
- Model(String name, Collection<ExpressionFunction> functions, Map<FunctionReference, ExpressionFunction> referencedFunctions) {
+ Model(String name,
+ Collection<ExpressionFunction> functions,
+ Map<FunctionReference, ExpressionFunction> referencedFunctions,
+ List<Constant> constants) {
// TODO: Optimize functions
this.name = name;
this.functions = ImmutableList.copyOf(functions);
@@ -48,7 +52,8 @@ public class Model {
ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>();
for (ExpressionFunction function : functions) {
try {
- contextBuilder.put(function.getName(), new LazyArrayContext(function.getBody(), referencedFunctions, this));
+ contextBuilder.put(function.getName(),
+ new LazyArrayContext(function.getBody(), referencedFunctions, constants, this));
}
catch (RuntimeException e) {
throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e);
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index dacf20b7ef2..48c71b5a04a 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -5,6 +5,7 @@ import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableMap;
import com.yahoo.component.AbstractComponent;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import java.util.Map;
import java.util.stream.Collectors;
@@ -21,8 +22,8 @@ public class ModelsEvaluator extends AbstractComponent {
private final ImmutableMap<String, Model> models;
- public ModelsEvaluator(RankProfilesConfig config) {
- models = ImmutableMap.copyOf(new RankProfilesConfigImporter().importFrom(config));
+ public ModelsEvaluator(RankProfilesConfig config, RankingConstantsConfig constantsConfig) {
+ models = ImmutableMap.copyOf(new RankProfilesConfigImporter().importFrom(config, constantsConfig));
}
/** Returns the models of this as an immutable map */
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index bfd6342218a..b9e7a27c013 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -1,33 +1,57 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
+import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
+import com.yahoo.io.GrowableByteBuffer;
+import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import java.io.File;
+import java.io.IOException;
+import java.io.UncheckedIOException;
import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
/**
- * Converts RankProfilesConfig instances to RankingExpressions for evaluation
+ * Converts RankProfilesConfig instances to RankingExpressions for evaluation.
+ * This class can be used by a single thread only.
*
* @author bratseth
*/
class RankProfilesConfigImporter {
/**
+ * Constants already imported in this while reading some expression.
+ * This is to avoid re-reading constants referenced
+ * multiple places, as that is potentially costly.
+ */
+ private Map<String, Constant> globalImportedConstants = new HashMap<>();
+
+ /**
* Returns a map of the models contained in this config, indexed on name.
* The map is modifiable and owned by the caller.
*/
- Map<String, Model> importFrom(RankProfilesConfig config) {
+ Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig) {
+ globalImportedConstants.clear();
try {
Map<String, Model> models = new HashMap<>();
for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) {
- Model model = importProfile(profile);
+ Model model = importProfile(profile, constantsConfig);
models.put(model.name(), model);
}
return models;
@@ -37,11 +61,14 @@ class RankProfilesConfigImporter {
}
}
- private Model importProfile(RankProfilesConfig.Rankprofile profile) throws ParseException {
+ private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) throws ParseException {
List<ExpressionFunction> functions = new ArrayList<>();
Map<FunctionReference, ExpressionFunction> referencedFunctions = new HashMap<>();
ExpressionFunction firstPhase = null;
ExpressionFunction secondPhase = null;
+
+ List<Constant> constants = readConstants(constantsConfig);
+
for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) {
Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name());
if ( reference.isPresent()) {
@@ -69,7 +96,7 @@ class RankProfilesConfigImporter {
functions.add(secondPhase);
try {
- return new Model(profile.name(), functions, referencedFunctions);
+ return new Model(profile.name(), functions, referencedFunctions, constants);
}
catch (RuntimeException e) {
throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e);
@@ -83,4 +110,30 @@ class RankProfilesConfigImporter {
return null;
}
+ private List<Constant> readConstants(RankingConstantsConfig constantsConfig) {
+ List<Constant> constants = new ArrayList<>();
+ for (RankingConstantsConfig.Constant constantConfig : constantsConfig.constant()) {
+ constants.add(new Constant(constantConfig.name(),
+ readTensorFromFile(TensorType.fromSpec(constantConfig.type()),
+ constantConfig.fileref().value())));
+ }
+ return constants;
+ }
+
+ private Tensor readTensorFromFile(TensorType type, String fileName) {
+ try {
+ if (fileName.endsWith(".tbf"))
+ return TypedBinaryFormat.decode(Optional.of(type),
+ GrowableByteBuffer.wrap(IOUtils.readFileBytes(new File(fileName))));
+ // TODO: Support json and json.lz4
+
+ if (fileName.isEmpty()) // this is the case in unit tests
+ return Tensor.from(type, "{}");
+ throw new IllegalArgumentException("Unknown tensor file format (determined by file ending): " + fileName);
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
}