diff options
26 files changed, 272 insertions, 52 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java index 4af26b72817..9a00ee5bbd0 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java @@ -94,7 +94,7 @@ public class DerivedConfiguration { summaries = new Summaries(search, deployLogger); summaryMap = new SummaryMap(search, summaries); juniperrc = new Juniperrc(search); - rankProfileList = new RankProfileList(search, attributeFields, rankProfileRegistry, queryProfiles, importedModels); + rankProfileList = new RankProfileList(search, search.rankingConstants(), attributeFields, rankProfileRegistry, queryProfiles, importedModels); indexingScript = new IndexingScript(search); indexInfo = new IndexInfo(search); indexSchema = new IndexSchema(search); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 10881ab9ce0..e58b3da4f72 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -3,24 +3,33 @@ package com.yahoo.searchdefinition.derived; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.RankingConstants; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.Search; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; + import java.util.Map; +import java.util.logging.Logger; /** * The derived rank profiles of a search definition * * @author bratseth */ -public class RankProfileList extends Derived implements RankProfilesConfig.Producer { +public class RankProfileList extends Derived implements RankProfilesConfig.Producer, RankingConstantsConfig.Producer { + + private static final Logger log = Logger.getLogger(RankProfileList.class.getName()); private final Map<String, RawRankProfile> rankProfiles = new java.util.LinkedHashMap<>(); + private final RankingConstants rankingConstants; public static RankProfileList empty = new RankProfileList(); private RankProfileList() { + this.rankingConstants = new RankingConstants(); } /** @@ -30,11 +39,13 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ * @param attributeFields the attribute fields to create a ranking for */ public RankProfileList(Search search, + RankingConstants rankingConstants, AttributeFields attributeFields, RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles, ImportedModels importedModels) { setName(search == null ? "default" : search.getName()); + this.rankingConstants = rankingConstants; deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields); } @@ -78,4 +89,17 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ } } + @Override + public void getConfig(RankingConstantsConfig.Builder builder) { + for (RankingConstant constant : rankingConstants.asMap().values()) { + if ("".equals(constant.getFileReference())) + log.warning("Illegal file reference " + constant); // Let tests pass ... we should find a better way + else + builder.constant(new RankingConstantsConfig.Constant.Builder() + .name(constant.getName()) + .fileref(constant.getFileReference()) + .type(constant.getType())); + } + } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index 3e9d188670e..1b15233fead 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -169,6 +169,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri deployState.rankProfileRegistry(), deployState.getQueryProfiles().getRegistry()); this.rankProfileList = new RankProfileList(null, // null search -> global + rankingConstants, AttributeFields.empty, deployState.rankProfileRegistry(), deployState.getQueryProfiles().getRegistry(), diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java index 5fbf1efd0ea..0b434fd0c49 100755 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java @@ -43,6 +43,7 @@ import com.yahoo.search.pagetemplates.PageTemplatesConfig; import com.yahoo.search.query.profile.config.QueryProfilesConfig; import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.configdefinition.IlscriptsConfig; import com.yahoo.vespa.model.PortsMeta; import com.yahoo.vespa.model.Service; @@ -124,7 +125,8 @@ public final class ContainerCluster RoutingProviderConfig.Producer, ConfigserverConfig.Producer, ThreadpoolConfig.Producer, - RankProfilesConfig.Producer + RankProfilesConfig.Producer, + RankingConstantsConfig.Producer { @@ -712,6 +714,9 @@ public final class ContainerCluster rankProfileList.getConfig(builder); } + @Override + public void getConfig(RankingConstantsConfig.Builder builder) { rankProfileList.getConfig(builder); } + public void setMbusParams(MbusParams mbusParams) { this.mbusParams = mbusParams; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java index a6bf51a2503..b29ed0fc25b 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java @@ -75,16 +75,7 @@ public class DocumentDatabase extends AbstractConfigProducer implements @Override public void getConfig(RankingConstantsConfig.Builder builder) { - for (RankingConstant constant : derivedCfg.getSearch().rankingConstants().asMap().values()) { - if ("".equals(constant.getFileReference())) { - System.err.println("INVALID rank constant "+constant.getName()+" [missing file reference]"); // TODO: Throw or log warning - continue; - } - builder.constant(new RankingConstantsConfig.Constant.Builder() - .name(constant.getName()) - .fileref(constant.getFileReference()) - .type(constant.getType())); - } + derivedCfg.getRankProfileList().getConfig(builder); } @Override diff --git a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java index a6022f32528..91d7fd436f3 100644 --- a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java @@ -7,6 +7,7 @@ import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.container.ContainerCluster; import org.junit.After; @@ -57,16 +58,22 @@ public class ModelEvaluationTest { private void assertHasMlModels(VespaModel model) { ContainerCluster cluster = model.getContainerClusters().get("container"); + RankProfilesConfig.Builder b = new RankProfilesConfig.Builder(); cluster.getConfig(b); RankProfilesConfig config = new RankProfilesConfig(b); + + RankingConstantsConfig.Builder cb = new RankingConstantsConfig.Builder(); + cluster.getConfig(cb); + RankingConstantsConfig constantsConfig = new RankingConstantsConfig(cb); + assertEquals(4, config.rankprofile().size()); Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet()); assertTrue(modelNames.contains("xgboost_2_2")); assertTrue(modelNames.contains("mnist_softmax")); assertTrue(modelNames.contains("mnist_softmax_saved")); - ModelsEvaluator evaluator = new ModelsEvaluator(config); + ModelsEvaluator evaluator = new ModelsEvaluator(config, constantsConfig); assertEquals(4, evaluator.models().size()); Model xgboost = evaluator.models().get("xgboost_2_2"); diff --git a/configdefinitions/src/vespa/CMakeLists.txt b/configdefinitions/src/vespa/CMakeLists.txt index 0a7d4ef4381..94239f81bbe 100644 --- a/configdefinitions/src/vespa/CMakeLists.txt +++ b/configdefinitions/src/vespa/CMakeLists.txt @@ -44,6 +44,8 @@ vespa_generate_config(configdefinitions persistence.def) install_config_definition(persistence.def vespa.config.content.persistence.def) vespa_generate_config(configdefinitions rank-profiles.def) install_config_definition(rank-profiles.def vespa.config.search.rank-profiles.def) +vespa_generate_config(configdefinitions ranking-constants.def) +install_config_definition(ranking-constants.def vespa.config.search.core.ranking-constants.def) vespa_generate_config(configdefinitions routing.def) install_config_definition(routing.def cloud.config.routing.def) vespa_generate_config(configdefinitions routing-provider.def) diff --git a/searchcore/src/vespa/searchcore/config/ranking-constants.def b/configdefinitions/src/vespa/ranking-constants.def index 3b55eda3308..3b55eda3308 100644 --- a/searchcore/src/vespa/searchcore/config/ranking-constants.def +++ b/configdefinitions/src/vespa/ranking-constants.def diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java new file mode 100644 index 00000000000..e664693ab38 --- /dev/null +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java @@ -0,0 +1,27 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import com.yahoo.tensor.Tensor; + +/** + * A named constant loaded from a file. + * + * This is immutable. + * + * @author bratseth + */ +class Constant { + + private final String name; + private final Tensor value; + + Constant(String name, Tensor value) { + this.name = name; + this.value = value; + } + + public String name() { return name; } + + public Tensor value() { return value; } + +} diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 520986ffb77..e08b9f77d15 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -56,4 +56,6 @@ public class FunctionEvaluator { return function.getBody().evaluate(context).asTensor(); } + LazyArrayContext context() { return context; } + } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index 2dcfd204077..beaa36b898f 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -16,6 +17,7 @@ import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.Set; @@ -37,8 +39,11 @@ final class LazyArrayContext extends Context implements ContextIndex { * * @param expression the expression to create a context for */ - LazyArrayContext(RankingExpression expression, Map<FunctionReference, ExpressionFunction> functions, Model model) { - this.indexedBindings = new IndexedBindings(expression, functions, this, model); + LazyArrayContext(RankingExpression expression, + Map<FunctionReference, ExpressionFunction> functions, + List<Constant> constants, + Model model) { + this.indexedBindings = new IndexedBindings(expression, functions, constants, this, model); } /** @@ -139,8 +144,10 @@ final class LazyArrayContext extends Context implements ContextIndex { */ IndexedBindings(RankingExpression expression, Map<FunctionReference, ExpressionFunction> functions, + List<Constant> constants, LazyArrayContext owner, Model model) { + // 1. Determine and prepare bind targets Set<String> bindTargets = new LinkedHashSet<>(); extractBindTargets(expression.getRoot(), functions, bindTargets); @@ -150,9 +157,18 @@ final class LazyArrayContext extends Context implements ContextIndex { int i = 0; ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>(); for (String variable : bindTargets) - nameToIndexBuilder.put(variable,i++); + nameToIndexBuilder.put(variable, i++); nameToIndex = nameToIndexBuilder.build(); + + // 2. Bind the bind targets + for (Constant constant : constants) { + String constantReference = "constant(" + constant.name() + ")"; + Integer index = nameToIndex.get(constantReference); + if (index != null) + values[index] = new TensorValue(constant.value()); + } + for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { Integer index = nameToIndex.get(function.getKey().serialForm()); if (index != null) // Referenced in this, so bind it @@ -170,7 +186,7 @@ final class LazyArrayContext extends Context implements ContextIndex { extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets); } else if (isConstant(node)) { - // Ignore + bindTargets.add(node.toString()); } else if (node instanceof ReferenceNode) { bindTargets.add(node.toString()); @@ -193,7 +209,7 @@ final class LazyArrayContext extends Context implements ContextIndex { if ( ! (node instanceof ReferenceNode)) return false; ReferenceNode reference = (ReferenceNode)node; - return reference.getName().equals("value") && reference.getArguments().size() == 1; + return reference.getName().equals("constant") && reference.getArguments().size() == 1; } Value get(int index) { return values[index]; } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 95eb923786d..3fb43d73187 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -36,11 +36,15 @@ public class Model { private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer(); + /** Programmatically create a model containing functions without constant of function references only */ public Model(String name, Collection<ExpressionFunction> functions) { - this(name, functions, Collections.emptyMap()); + this(name, functions, Collections.emptyMap(), Collections.emptyList()); } - Model(String name, Collection<ExpressionFunction> functions, Map<FunctionReference, ExpressionFunction> referencedFunctions) { + Model(String name, + Collection<ExpressionFunction> functions, + Map<FunctionReference, ExpressionFunction> referencedFunctions, + List<Constant> constants) { // TODO: Optimize functions this.name = name; this.functions = ImmutableList.copyOf(functions); @@ -48,7 +52,8 @@ public class Model { ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>(); for (ExpressionFunction function : functions) { try { - contextBuilder.put(function.getName(), new LazyArrayContext(function.getBody(), referencedFunctions, this)); + contextBuilder.put(function.getName(), + new LazyArrayContext(function.getBody(), referencedFunctions, constants, this)); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index dacf20b7ef2..48c71b5a04a 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -5,6 +5,7 @@ import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableMap; import com.yahoo.component.AbstractComponent; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import java.util.Map; import java.util.stream.Collectors; @@ -21,8 +22,8 @@ public class ModelsEvaluator extends AbstractComponent { private final ImmutableMap<String, Model> models; - public ModelsEvaluator(RankProfilesConfig config) { - models = ImmutableMap.copyOf(new RankProfilesConfigImporter().importFrom(config)); + public ModelsEvaluator(RankProfilesConfig config, RankingConstantsConfig constantsConfig) { + models = ImmutableMap.copyOf(new RankProfilesConfigImporter().importFrom(config, constantsConfig)); } /** Returns the models of this as an immutable map */ diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index bfd6342218a..b9e7a27c013 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -1,33 +1,57 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import com.yahoo.filedistribution.fileacquirer.FileAcquirer; +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.io.IOUtils; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; /** - * Converts RankProfilesConfig instances to RankingExpressions for evaluation + * Converts RankProfilesConfig instances to RankingExpressions for evaluation. + * This class can be used by a single thread only. * * @author bratseth */ class RankProfilesConfigImporter { /** + * Constants already imported in this while reading some expression. + * This is to avoid re-reading constants referenced + * multiple places, as that is potentially costly. + */ + private Map<String, Constant> globalImportedConstants = new HashMap<>(); + + /** * Returns a map of the models contained in this config, indexed on name. * The map is modifiable and owned by the caller. */ - Map<String, Model> importFrom(RankProfilesConfig config) { + Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig) { + globalImportedConstants.clear(); try { Map<String, Model> models = new HashMap<>(); for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) { - Model model = importProfile(profile); + Model model = importProfile(profile, constantsConfig); models.put(model.name(), model); } return models; @@ -37,11 +61,14 @@ class RankProfilesConfigImporter { } } - private Model importProfile(RankProfilesConfig.Rankprofile profile) throws ParseException { + private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) throws ParseException { List<ExpressionFunction> functions = new ArrayList<>(); Map<FunctionReference, ExpressionFunction> referencedFunctions = new HashMap<>(); ExpressionFunction firstPhase = null; ExpressionFunction secondPhase = null; + + List<Constant> constants = readConstants(constantsConfig); + for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) { Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name()); if ( reference.isPresent()) { @@ -69,7 +96,7 @@ class RankProfilesConfigImporter { functions.add(secondPhase); try { - return new Model(profile.name(), functions, referencedFunctions); + return new Model(profile.name(), functions, referencedFunctions, constants); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e); @@ -83,4 +110,30 @@ class RankProfilesConfigImporter { return null; } + private List<Constant> readConstants(RankingConstantsConfig constantsConfig) { + List<Constant> constants = new ArrayList<>(); + for (RankingConstantsConfig.Constant constantConfig : constantsConfig.constant()) { + constants.add(new Constant(constantConfig.name(), + readTensorFromFile(TensorType.fromSpec(constantConfig.type()), + constantConfig.fileref().value()))); + } + return constants; + } + + private Tensor readTensorFromFile(TensorType type, String fileName) { + try { + if (fileName.endsWith(".tbf")) + return TypedBinaryFormat.decode(Optional.of(type), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(new File(fileName)))); + // TODO: Support json and json.lz4 + + if (fileName.isEmpty()) // this is the case in unit tests + return Tensor.from(type, "{}"); + throw new IllegalArgumentException("Unknown tensor file format (determined by file ending): " + fileName); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index 60cf0d25ded..d94e5b2af1b 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -3,8 +3,10 @@ package ai.vespa.models.evaluation; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.config.subscription.FileSource; +import com.yahoo.path.Path; import com.yahoo.tensor.Tensor; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import org.junit.Test; import java.io.File; @@ -18,15 +20,9 @@ public class ModelsEvaluatorTest { private static final double delta = 0.00000000001; - private ModelsEvaluator createModels() { - String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg"; - RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); - return new ModelsEvaluator(config); - } - @Test public void testTensorEvaluation() { - ModelsEvaluator models = createModels(); + ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); FunctionEvaluator function = models.evaluatorOf("macros", "fourtimessum"); function.bind("var1", Tensor.from("{{x:0}:3,{x:1}:5}")); function.bind("var2", Tensor.from("{{x:0}:7,{x:1}:11}")); @@ -35,7 +31,7 @@ public class ModelsEvaluatorTest { @Test public void testEvaluationDependingOnMacroTakingArguments() { - ModelsEvaluator models = createModels(); + ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); FunctionEvaluator function = models.evaluatorOf("macros", "secondphase"); function.bind("match", 3); function.bind("rankBoost", 5); @@ -46,6 +42,14 @@ public class ModelsEvaluatorTest { // TODO: Test that binding nonexisting variable doesn't work // TODO: Test that rebinding doesn't work // TODO: Test with nested macros - // TODO: Test TF/ONNX model + + private ModelsEvaluator createModels(String path) { + Path configDir = Path.fromString(path); + RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), + RankProfilesConfig.class).getConfig(""); + RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), + RankingConstantsConfig.class).getConfig(""); + return new ModelsEvaluator(config, constantsConfig); + } } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java index d45372fc7da..84e01e58280 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java @@ -3,8 +3,10 @@ package ai.vespa.models.evaluation; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.config.subscription.FileSource; +import com.yahoo.path.Path; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import org.junit.Test; import java.io.File; @@ -21,14 +23,41 @@ import static org.junit.Assert.assertNotNull; public class RankProfilesImporterTest { @Test - public void testImporting() { - String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg"; - RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); - Map<String, Model> models = new RankProfilesConfigImporter().importFrom(config); + public void testImportingModels() { + Map<String, Model> models = createModels("src/test/resources/config/models/"); + + assertEquals(4, models.size()); + + Model xgboost = models.get("xgboost_2_2"); + assertFunction("xgboost_2_2", + "(optimized sum of condition trees of size 192 bytes)", + xgboost); + + Model onnxMnistSoftmax = models.get("mnist_softmax"); + assertFunction("default.add", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", + onnxMnistSoftmax); + assertEquals("tensor(d1[10],d2[784])", + onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); + + Model tfMnistSoftmax = models.get("mnist_softmax_saved"); + assertFunction("serving_default.y", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", + tfMnistSoftmax); + + Model tfMnist = models.get("mnist_saved"); + assertFunction("serving_default.y", + "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", + tfMnist); + } + + @Test + public void testImportingRankExpressions() { + Map<String, Model> models = createModels("src/test/resources/config/rankexpression/"); + assertEquals(18, models.size()); Model macros = models.get("macros"); - assertNotNull(macros); assertEquals("macros", macros.name()); assertEquals(4, macros.functions().size()); assertFunction("fourtimessum", "4 * (var1 + var2)", macros); @@ -44,8 +73,9 @@ public class RankProfilesImporterTest { } private void assertFunction(String name, String expression, Model model) { + assertNotNull("Model is present in config", model); ExpressionFunction function = model.function(name); - assertNotNull(function); + assertNotNull("Function '" + name + "' is in " + model, function); assertEquals(name, function.getName()); assertEquals(expression, function.getBody().getRoot().toString()); } @@ -57,4 +87,13 @@ public class RankProfilesImporterTest { assertEquals(expression, function.getBody().getRoot().toString()); } + private Map<String, Model> createModels(String path) { + Path configDir = Path.fromString(path); + RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), + RankProfilesConfig.class).getConfig(""); + RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), + RankingConstantsConfig.class).getConfig(""); + return new RankProfilesConfigImporter().importFrom(config, constantsConfig); + } + } diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg new file mode 100644 index 00000000000..1cc36f75158 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg @@ -0,0 +1,14 @@ +rankprofile[0].name "mnist_saved" +rankprofile[0].fef.property[0].name "rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add).rankingScript" +rankprofile[0].fef.property[0].value "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))" +rankprofile[0].fef.property[1].name "rankingExpression(serving_default.y).rankingScript" +rankprofile[0].fef.property[1].value "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))" +rankprofile[1].name "xgboost_2_2" +rankprofile[1].fef.property[0].name "rankingExpression(xgboost_2_2).rankingScript" +rankprofile[1].fef.property[0].value "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)" +rankprofile[2].name "mnist_softmax_saved" +rankprofile[2].fef.property[0].name "rankingExpression(serving_default.y).rankingScript" +rankprofile[2].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))" +rankprofile[3].name "mnist_softmax" +rankprofile[3].fef.property[0].name "rankingExpression(default.add).rankingScript" +rankprofile[3].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))" diff --git a/model-evaluation/src/test/resources/config/models/ranking-constants.cfg b/model-evaluation/src/test/resources/config/models/ranking-constants.cfg new file mode 100644 index 00000000000..2b7495ace5e --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/ranking-constants.cfg @@ -0,0 +1,30 @@ +constant[0].name "mnist_saved_dnn_hidden1_weights_read" +constant[0].fileref "" +constant[0].type "tensor(d3[300],d4[784])" +constant[1].name "mnist_saved_dnn_hidden2_weights_read" +constant[1].fileref "" +constant[1].type "tensor(d2[100],d3[300])" +constant[2].name "mnist_softmax_saved_layer_Variable_1_read" +constant[2].fileref "" +constant[2].type "tensor(d1[10])" +constant[3].name "mnist_saved_dnn_hidden1_bias_read" +constant[3].fileref "" +constant[3].type "tensor(d3[300])" +constant[4].name "mnist_saved_dnn_hidden2_bias_read" +constant[4].fileref "" +constant[4].type "tensor(d2[100])" +constant[5].name "mnist_softmax_Variable" +constant[5].fileref "" +constant[5].type "tensor(d1[10],d2[784])" +constant[6].name "mnist_saved_dnn_outputs_weights_read" +constant[6].fileref "" +constant[6].type "tensor(d1[10],d2[100])" +constant[7].name "mnist_softmax_saved_layer_Variable_read" +constant[7].fileref "" +constant[7].type "tensor(d1[10],d2[784])" +constant[8].name "mnist_softmax_Variable_1" +constant[8].fileref "" +constant[8].type "tensor(d1[10])" +constant[9].name "mnist_saved_dnn_outputs_bias_read" +constant[9].fileref "" +constant[9].type "tensor(d1[10])"
\ No newline at end of file diff --git a/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg b/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg diff --git a/searchcore/pom.xml b/searchcore/pom.xml index 3b43bf1205e..002ba1f508f 100644 --- a/searchcore/pom.xml +++ b/searchcore/pom.xml @@ -1,3 +1,4 @@ +<?xml version="1.0"?> <!-- Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" @@ -11,8 +12,8 @@ <relativePath>../parent/pom.xml</relativePath> </parent> <artifactId>searchcore</artifactId> - <version>6-SNAPSHOT</version> <packaging>jar</packaging> + <version>6-SNAPSHOT</version> <name>${project.artifactId}</name> <dependencies> <dependency> diff --git a/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp b/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp index f60863ef0b0..4f2720718ff 100644 --- a/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp +++ b/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp @@ -4,6 +4,7 @@ #include <vespa/config-imported-fields.h> #include <vespa/config-indexschema.h> #include <vespa/config-rank-profiles.h> +#include <vespa/config-ranking-constants.h> #include <vespa/config/config.h> #include <vespa/config/helper/legacy.h> #include <vespa/config/common/exceptions.h> @@ -11,7 +12,6 @@ #include <vespa/eval/eval/value_cache/constant_value.h> #include <vespa/eval/tensor/default_tensor_engine.h> #include <vespa/searchcommon/common/schemaconfigurer.h> -#include <vespa/searchcore/config/config-ranking-constants.h> #include <vespa/searchcore/proton/matching/error_constant_value.h> #include <vespa/searchcore/proton/matching/indexenvironment.h> #include <vespa/searchlib/features/setup.h> diff --git a/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp b/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp index 2817ddb1b85..f2e5fa7c152 100644 --- a/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp +++ b/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp @@ -5,6 +5,7 @@ #include <vespa/config-imported-fields.h> #include <vespa/config-indexschema.h> #include <vespa/config-rank-profiles.h> +#include <vespa/config-ranking-constants.h> #include <vespa/config-summary.h> #include <vespa/config-summarymap.h> #include <vespa/document/repo/documenttyperepo.h> @@ -17,7 +18,6 @@ #include <vespa/searchcore/proton/server/i_proton_configurer.h> #include <vespa/searchcore/proton/common/hw_info.h> #include <vespa/searchsummary/config/config-juniperrc.h> -#include <vespa/searchcore/config/config-ranking-constants.h> #include <vespa/vespalib/testkit/testapp.h> #include <vespa/vespalib/util/varholder.h> #include <vespa/config-bucketspaces.h> diff --git a/searchcore/src/tests/proton/proton_configurer/proton_configurer_test.cpp b/searchcore/src/tests/proton/proton_configurer/proton_configurer_test.cpp index dfb1268aaa6..e42016c9577 100644 --- a/searchcore/src/tests/proton/proton_configurer/proton_configurer_test.cpp +++ b/searchcore/src/tests/proton/proton_configurer/proton_configurer_test.cpp @@ -5,6 +5,7 @@ #include <vespa/config-imported-fields.h> #include <vespa/config-indexschema.h> #include <vespa/config-rank-profiles.h> +#include <vespa/config-ranking-constants.h> #include <vespa/config-summary.h> #include <vespa/config-summarymap.h> #include <vespa/document/repo/documenttyperepo.h> @@ -18,7 +19,6 @@ #include <vespa/searchcore/proton/server/i_proton_configurer_owner.h> #include <vespa/searchcore/proton/server/i_proton_disk_layout.h> #include <vespa/searchsummary/config/config-juniperrc.h> -#include <vespa/searchcore/config/config-ranking-constants.h> #include <vespa/vespalib/testkit/testapp.h> #include <vespa/searchcommon/common/schemaconfigurer.h> #include <vespa/vespalib/util/threadstackexecutor.h> diff --git a/searchcore/src/vespa/searchcore/config/CMakeLists.txt b/searchcore/src/vespa/searchcore/config/CMakeLists.txt index 3d62309161c..186c9b68d34 100644 --- a/searchcore/src/vespa/searchcore/config/CMakeLists.txt +++ b/searchcore/src/vespa/searchcore/config/CMakeLists.txt @@ -9,6 +9,4 @@ vespa_generate_config(searchcore_fconfig fdispatchrc.def) install_config_definition(fdispatchrc.def vespa.config.search.core.fdispatchrc.def) vespa_generate_config(searchcore_fconfig proton.def) install_config_definition(proton.def vespa.config.search.core.proton.def) -vespa_generate_config(searchcore_fconfig ranking-constants.def) -install_config_definition(ranking-constants.def vespa.config.search.core.ranking-constants.def) vespa_generate_config(searchcore_fconfig hwinfo.def) diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.cpp b/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.cpp index fd1f9f1155d..b9a3d055cd1 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.cpp @@ -5,12 +5,12 @@ #include <vespa/config-imported-fields.h> #include <vespa/config-indexschema.h> #include <vespa/config-rank-profiles.h> +#include <vespa/config-ranking-constants.h> #include <vespa/config-summary.h> #include <vespa/config-summarymap.h> #include <vespa/searchsummary/config/config-juniperrc.h> #include <vespa/document/config/config-documenttypes.h> #include <vespa/document/repo/documenttyperepo.h> -#include <vespa/searchcore/config/config-ranking-constants.h> #include <vespa/searchcore/proton/attribute/attribute_aspect_delayer.h> #include <vespa/searchcore/proton/common/document_type_inspector.h> #include <vespa/searchcore/proton/common/indexschema_inspector.h> diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp b/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp index f2230215c3d..431f7554416 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp @@ -3,9 +3,9 @@ #include "documentdbconfigmanager.h" #include "bootstrapconfig.h" #include <vespa/searchcore/proton/common/hw_info.h> -#include <vespa/searchcore/config/config-ranking-constants.h> #include <vespa/config-imported-fields.h> #include <vespa/config-rank-profiles.h> +#include <vespa/config-ranking-constants.h> #include <vespa/config-summarymap.h> #include <vespa/config/file_acquirer/file_acquirer.h> #include <vespa/config/helper/legacy.h> |