From 68c83365348fada16a913ba316b3e067f9f4a923 Mon Sep 17 00:00:00 2001 From: Harald Musum Date: Tue, 4 Sep 2018 11:42:35 +0200 Subject: Revert "Bratseth/handle large constants" --- .../derived/DerivedConfiguration.java | 2 +- .../searchdefinition/derived/RankProfileList.java | 26 +-------- .../java/com/yahoo/vespa/model/VespaModel.java | 1 - .../vespa/model/container/ContainerCluster.java | 7 +-- .../yahoo/vespa/model/search/DocumentDatabase.java | 11 +++- .../yahoo/config/model/ModelEvaluationTest.java | 9 +--- model-evaluation/pom.xml | 6 --- .../java/ai/vespa/models/evaluation/Constant.java | 27 ---------- .../vespa/models/evaluation/FunctionEvaluator.java | 2 - .../vespa/models/evaluation/LazyArrayContext.java | 26 ++------- .../java/ai/vespa/models/evaluation/Model.java | 11 ++-- .../vespa/models/evaluation/ModelsEvaluator.java | 5 +- .../evaluation/RankProfilesConfigImporter.java | 63 ++-------------------- .../models/evaluation/ModelsEvaluatorTest.java | 22 ++++---- .../evaluation/RankProfilesImporterTest.java | 51 +++--------------- .../test/resources/config/models/rank-profiles.cfg | 14 ----- .../resources/config/models/ranking-constants.cfg | 30 ----------- .../config/rankexpression/ranking-constants.cfg | 0 18 files changed, 44 insertions(+), 269 deletions(-) delete mode 100644 model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java delete mode 100644 model-evaluation/src/test/resources/config/models/rank-profiles.cfg delete mode 100644 model-evaluation/src/test/resources/config/models/ranking-constants.cfg delete mode 100644 model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java index 9a00ee5bbd0..4af26b72817 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java @@ -94,7 +94,7 @@ public class DerivedConfiguration { summaries = new Summaries(search, deployLogger); summaryMap = new SummaryMap(search, summaries); juniperrc = new Juniperrc(search); - rankProfileList = new RankProfileList(search, search.rankingConstants(), attributeFields, rankProfileRegistry, queryProfiles, importedModels); + rankProfileList = new RankProfileList(search, attributeFields, rankProfileRegistry, queryProfiles, importedModels); indexingScript = new IndexingScript(search); indexInfo = new IndexInfo(search); indexSchema = new IndexSchema(search); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index e58b3da4f72..10881ab9ce0 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -3,33 +3,24 @@ package com.yahoo.searchdefinition.derived; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfileRegistry; -import com.yahoo.searchdefinition.RankingConstant; -import com.yahoo.searchdefinition.RankingConstants; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.Search; -import com.yahoo.vespa.config.search.core.RankingConstantsConfig; - import java.util.Map; -import java.util.logging.Logger; /** * The derived rank profiles of a search definition * * @author bratseth */ -public class RankProfileList extends Derived implements RankProfilesConfig.Producer, RankingConstantsConfig.Producer { - - private static final Logger log = Logger.getLogger(RankProfileList.class.getName()); +public class RankProfileList extends Derived implements RankProfilesConfig.Producer { private final Map rankProfiles = new java.util.LinkedHashMap<>(); - private final RankingConstants rankingConstants; public static RankProfileList empty = new RankProfileList(); private RankProfileList() { - this.rankingConstants = new RankingConstants(); } /** @@ -39,13 +30,11 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ * @param attributeFields the attribute fields to create a ranking for */ public RankProfileList(Search search, - RankingConstants rankingConstants, AttributeFields attributeFields, RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles, ImportedModels importedModels) { setName(search == null ? "default" : search.getName()); - this.rankingConstants = rankingConstants; deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields); } @@ -89,17 +78,4 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ } } - @Override - public void getConfig(RankingConstantsConfig.Builder builder) { - for (RankingConstant constant : rankingConstants.asMap().values()) { - if ("".equals(constant.getFileReference())) - log.warning("Illegal file reference " + constant); // Let tests pass ... we should find a better way - else - builder.constant(new RankingConstantsConfig.Constant.Builder() - .name(constant.getName()) - .fileref(constant.getFileReference()) - .type(constant.getType())); - } - } - } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index 1b15233fead..3e9d188670e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -169,7 +169,6 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri deployState.rankProfileRegistry(), deployState.getQueryProfiles().getRegistry()); this.rankProfileList = new RankProfileList(null, // null search -> global - rankingConstants, AttributeFields.empty, deployState.rankProfileRegistry(), deployState.getQueryProfiles().getRegistry(), diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java index 606d5d5afb0..8c6c13d810f 100755 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java @@ -45,7 +45,6 @@ import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.vespa.config.search.RankProfilesConfig; -import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.configdefinition.IlscriptsConfig; import com.yahoo.vespa.model.PortsMeta; import com.yahoo.vespa.model.Service; @@ -130,8 +129,7 @@ public final class ContainerCluster RoutingProviderConfig.Producer, ConfigserverConfig.Producer, ThreadpoolConfig.Producer, - RankProfilesConfig.Producer, - RankingConstantsConfig.Producer + RankProfilesConfig.Producer { @@ -734,9 +732,6 @@ public final class ContainerCluster rankProfileList.getConfig(builder); } - @Override - public void getConfig(RankingConstantsConfig.Builder builder) { rankProfileList.getConfig(builder); } - public void setMbusParams(MbusParams mbusParams) { this.mbusParams = mbusParams; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java index b29ed0fc25b..a6bf51a2503 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java @@ -75,7 +75,16 @@ public class DocumentDatabase extends AbstractConfigProducer implements @Override public void getConfig(RankingConstantsConfig.Builder builder) { - derivedCfg.getRankProfileList().getConfig(builder); + for (RankingConstant constant : derivedCfg.getSearch().rankingConstants().asMap().values()) { + if ("".equals(constant.getFileReference())) { + System.err.println("INVALID rank constant "+constant.getName()+" [missing file reference]"); // TODO: Throw or log warning + continue; + } + builder.constant(new RankingConstantsConfig.Constant.Builder() + .name(constant.getName()) + .fileref(constant.getFileReference()) + .type(constant.getType())); + } } @Override diff --git a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java index 57d16f24f03..c5fb4f575cf 100644 --- a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java @@ -6,7 +6,6 @@ import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.vespa.config.search.RankProfilesConfig; -import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.container.ContainerCluster; import org.junit.After; @@ -57,22 +56,16 @@ public class ModelEvaluationTest { private void assertHasMlModels(VespaModel model) { ContainerCluster cluster = model.getContainerClusters().get("container"); - RankProfilesConfig.Builder b = new RankProfilesConfig.Builder(); cluster.getConfig(b); RankProfilesConfig config = new RankProfilesConfig(b); - - RankingConstantsConfig.Builder cb = new RankingConstantsConfig.Builder(); - cluster.getConfig(cb); - RankingConstantsConfig constantsConfig = new RankingConstantsConfig(cb); - assertEquals(4, config.rankprofile().size()); Set modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet()); assertTrue(modelNames.contains("xgboost_2_2")); assertTrue(modelNames.contains("mnist_softmax")); assertTrue(modelNames.contains("mnist_softmax_saved")); - ModelsEvaluator evaluator = new ModelsEvaluator(config, constantsConfig); + ModelsEvaluator evaluator = new ModelsEvaluator(config); assertEquals(4, evaluator.models().size()); Model xgboost = evaluator.models().get("xgboost_2_2"); diff --git a/model-evaluation/pom.xml b/model-evaluation/pom.xml index 0b6a5d08155..edb22c1b529 100644 --- a/model-evaluation/pom.xml +++ b/model-evaluation/pom.xml @@ -38,12 +38,6 @@ ${project.version} provided - - com.yahoo.vespa - searchcore - ${project.version} - provided - com.yahoo.vespa config diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java deleted file mode 100644 index e664693ab38..00000000000 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Constant.java +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.models.evaluation; - -import com.yahoo.tensor.Tensor; - -/** - * A named constant loaded from a file. - * - * This is immutable. - * - * @author bratseth - */ -class Constant { - - private final String name; - private final Tensor value; - - Constant(String name, Tensor value) { - this.name = name; - this.value = value; - } - - public String name() { return name; } - - public Tensor value() { return value; } - -} diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index e08b9f77d15..520986ffb77 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -56,6 +56,4 @@ public class FunctionEvaluator { return function.getBody().evaluate(context).asTensor(); } - LazyArrayContext context() { return context; } - } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index beaa36b898f..2dcfd204077 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -8,7 +8,6 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -17,7 +16,6 @@ import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.LinkedHashSet; -import java.util.List; import java.util.Map; import java.util.Set; @@ -39,11 +37,8 @@ final class LazyArrayContext extends Context implements ContextIndex { * * @param expression the expression to create a context for */ - LazyArrayContext(RankingExpression expression, - Map functions, - List constants, - Model model) { - this.indexedBindings = new IndexedBindings(expression, functions, constants, this, model); + LazyArrayContext(RankingExpression expression, Map functions, Model model) { + this.indexedBindings = new IndexedBindings(expression, functions, this, model); } /** @@ -144,10 +139,8 @@ final class LazyArrayContext extends Context implements ContextIndex { */ IndexedBindings(RankingExpression expression, Map functions, - List constants, LazyArrayContext owner, Model model) { - // 1. Determine and prepare bind targets Set bindTargets = new LinkedHashSet<>(); extractBindTargets(expression.getRoot(), functions, bindTargets); @@ -157,18 +150,9 @@ final class LazyArrayContext extends Context implements ContextIndex { int i = 0; ImmutableMap.Builder nameToIndexBuilder = new ImmutableMap.Builder<>(); for (String variable : bindTargets) - nameToIndexBuilder.put(variable, i++); + nameToIndexBuilder.put(variable,i++); nameToIndex = nameToIndexBuilder.build(); - - // 2. Bind the bind targets - for (Constant constant : constants) { - String constantReference = "constant(" + constant.name() + ")"; - Integer index = nameToIndex.get(constantReference); - if (index != null) - values[index] = new TensorValue(constant.value()); - } - for (Map.Entry function : functions.entrySet()) { Integer index = nameToIndex.get(function.getKey().serialForm()); if (index != null) // Referenced in this, so bind it @@ -186,7 +170,7 @@ final class LazyArrayContext extends Context implements ContextIndex { extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets); } else if (isConstant(node)) { - bindTargets.add(node.toString()); + // Ignore } else if (node instanceof ReferenceNode) { bindTargets.add(node.toString()); @@ -209,7 +193,7 @@ final class LazyArrayContext extends Context implements ContextIndex { if ( ! (node instanceof ReferenceNode)) return false; ReferenceNode reference = (ReferenceNode)node; - return reference.getName().equals("constant") && reference.getArguments().size() == 1; + return reference.getName().equals("value") && reference.getArguments().size() == 1; } Value get(int index) { return values[index]; } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 3fb43d73187..95eb923786d 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -36,15 +36,11 @@ public class Model { private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer(); - /** Programmatically create a model containing functions without constant of function references only */ public Model(String name, Collection functions) { - this(name, functions, Collections.emptyMap(), Collections.emptyList()); + this(name, functions, Collections.emptyMap()); } - Model(String name, - Collection functions, - Map referencedFunctions, - List constants) { + Model(String name, Collection functions, Map referencedFunctions) { // TODO: Optimize functions this.name = name; this.functions = ImmutableList.copyOf(functions); @@ -52,8 +48,7 @@ public class Model { ImmutableMap.Builder contextBuilder = new ImmutableMap.Builder<>(); for (ExpressionFunction function : functions) { try { - contextBuilder.put(function.getName(), - new LazyArrayContext(function.getBody(), referencedFunctions, constants, this)); + contextBuilder.put(function.getName(), new LazyArrayContext(function.getBody(), referencedFunctions, this)); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index 48c71b5a04a..dacf20b7ef2 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -5,7 +5,6 @@ import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableMap; import com.yahoo.component.AbstractComponent; import com.yahoo.vespa.config.search.RankProfilesConfig; -import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import java.util.Map; import java.util.stream.Collectors; @@ -22,8 +21,8 @@ public class ModelsEvaluator extends AbstractComponent { private final ImmutableMap models; - public ModelsEvaluator(RankProfilesConfig config, RankingConstantsConfig constantsConfig) { - models = ImmutableMap.copyOf(new RankProfilesConfigImporter().importFrom(config, constantsConfig)); + public ModelsEvaluator(RankProfilesConfig config) { + models = ImmutableMap.copyOf(new RankProfilesConfigImporter().importFrom(config)); } /** Returns the models of this as an immutable map */ diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index b9e7a27c013..bfd6342218a 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -1,57 +1,33 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; -import com.yahoo.filedistribution.fileacquirer.FileAcquirer; -import com.yahoo.io.GrowableByteBuffer; -import com.yahoo.io.IOUtils; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.config.search.RankProfilesConfig; -import com.yahoo.vespa.config.search.core.RankingConstantsConfig; -import java.io.File; -import java.io.IOException; -import java.io.UncheckedIOException; import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; /** - * Converts RankProfilesConfig instances to RankingExpressions for evaluation. - * This class can be used by a single thread only. + * Converts RankProfilesConfig instances to RankingExpressions for evaluation * * @author bratseth */ class RankProfilesConfigImporter { - /** - * Constants already imported in this while reading some expression. - * This is to avoid re-reading constants referenced - * multiple places, as that is potentially costly. - */ - private Map globalImportedConstants = new HashMap<>(); - /** * Returns a map of the models contained in this config, indexed on name. * The map is modifiable and owned by the caller. */ - Map importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig) { - globalImportedConstants.clear(); + Map importFrom(RankProfilesConfig config) { try { Map models = new HashMap<>(); for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) { - Model model = importProfile(profile, constantsConfig); + Model model = importProfile(profile); models.put(model.name(), model); } return models; @@ -61,14 +37,11 @@ class RankProfilesConfigImporter { } } - private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) throws ParseException { + private Model importProfile(RankProfilesConfig.Rankprofile profile) throws ParseException { List functions = new ArrayList<>(); Map referencedFunctions = new HashMap<>(); ExpressionFunction firstPhase = null; ExpressionFunction secondPhase = null; - - List constants = readConstants(constantsConfig); - for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) { Optional reference = FunctionReference.fromSerial(property.name()); if ( reference.isPresent()) { @@ -96,7 +69,7 @@ class RankProfilesConfigImporter { functions.add(secondPhase); try { - return new Model(profile.name(), functions, referencedFunctions, constants); + return new Model(profile.name(), functions, referencedFunctions); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e); @@ -110,30 +83,4 @@ class RankProfilesConfigImporter { return null; } - private List readConstants(RankingConstantsConfig constantsConfig) { - List constants = new ArrayList<>(); - for (RankingConstantsConfig.Constant constantConfig : constantsConfig.constant()) { - constants.add(new Constant(constantConfig.name(), - readTensorFromFile(TensorType.fromSpec(constantConfig.type()), - constantConfig.fileref().value()))); - } - return constants; - } - - private Tensor readTensorFromFile(TensorType type, String fileName) { - try { - if (fileName.endsWith(".tbf")) - return TypedBinaryFormat.decode(Optional.of(type), - GrowableByteBuffer.wrap(IOUtils.readFileBytes(new File(fileName)))); - // TODO: Support json and json.lz4 - - if (fileName.isEmpty()) // this is the case in unit tests - return Tensor.from(type, "{}"); - throw new IllegalArgumentException("Unknown tensor file format (determined by file ending): " + fileName); - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index d94e5b2af1b..60cf0d25ded 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -3,10 +3,8 @@ package ai.vespa.models.evaluation; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.config.subscription.FileSource; -import com.yahoo.path.Path; import com.yahoo.tensor.Tensor; import com.yahoo.vespa.config.search.RankProfilesConfig; -import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import org.junit.Test; import java.io.File; @@ -20,9 +18,15 @@ public class ModelsEvaluatorTest { private static final double delta = 0.00000000001; + private ModelsEvaluator createModels() { + String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg"; + RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); + return new ModelsEvaluator(config); + } + @Test public void testTensorEvaluation() { - ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); + ModelsEvaluator models = createModels(); FunctionEvaluator function = models.evaluatorOf("macros", "fourtimessum"); function.bind("var1", Tensor.from("{{x:0}:3,{x:1}:5}")); function.bind("var2", Tensor.from("{{x:0}:7,{x:1}:11}")); @@ -31,7 +35,7 @@ public class ModelsEvaluatorTest { @Test public void testEvaluationDependingOnMacroTakingArguments() { - ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); + ModelsEvaluator models = createModels(); FunctionEvaluator function = models.evaluatorOf("macros", "secondphase"); function.bind("match", 3); function.bind("rankBoost", 5); @@ -42,14 +46,6 @@ public class ModelsEvaluatorTest { // TODO: Test that binding nonexisting variable doesn't work // TODO: Test that rebinding doesn't work // TODO: Test with nested macros - - private ModelsEvaluator createModels(String path) { - Path configDir = Path.fromString(path); - RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), - RankProfilesConfig.class).getConfig(""); - RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), - RankingConstantsConfig.class).getConfig(""); - return new ModelsEvaluator(config, constantsConfig); - } + // TODO: Test TF/ONNX model } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java index 84e01e58280..d45372fc7da 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java @@ -3,10 +3,8 @@ package ai.vespa.models.evaluation; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.config.subscription.FileSource; -import com.yahoo.path.Path; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.vespa.config.search.RankProfilesConfig; -import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import org.junit.Test; import java.io.File; @@ -23,41 +21,14 @@ import static org.junit.Assert.assertNotNull; public class RankProfilesImporterTest { @Test - public void testImportingModels() { - Map models = createModels("src/test/resources/config/models/"); - - assertEquals(4, models.size()); - - Model xgboost = models.get("xgboost_2_2"); - assertFunction("xgboost_2_2", - "(optimized sum of condition trees of size 192 bytes)", - xgboost); - - Model onnxMnistSoftmax = models.get("mnist_softmax"); - assertFunction("default.add", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", - onnxMnistSoftmax); - assertEquals("tensor(d1[10],d2[784])", - onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); - - Model tfMnistSoftmax = models.get("mnist_softmax_saved"); - assertFunction("serving_default.y", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", - tfMnistSoftmax); - - Model tfMnist = models.get("mnist_saved"); - assertFunction("serving_default.y", - "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", - tfMnist); - } - - @Test - public void testImportingRankExpressions() { - Map models = createModels("src/test/resources/config/rankexpression/"); - + public void testImporting() { + String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg"; + RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); + Map models = new RankProfilesConfigImporter().importFrom(config); assertEquals(18, models.size()); Model macros = models.get("macros"); + assertNotNull(macros); assertEquals("macros", macros.name()); assertEquals(4, macros.functions().size()); assertFunction("fourtimessum", "4 * (var1 + var2)", macros); @@ -73,9 +44,8 @@ public class RankProfilesImporterTest { } private void assertFunction(String name, String expression, Model model) { - assertNotNull("Model is present in config", model); ExpressionFunction function = model.function(name); - assertNotNull("Function '" + name + "' is in " + model, function); + assertNotNull(function); assertEquals(name, function.getName()); assertEquals(expression, function.getBody().getRoot().toString()); } @@ -87,13 +57,4 @@ public class RankProfilesImporterTest { assertEquals(expression, function.getBody().getRoot().toString()); } - private Map createModels(String path) { - Path configDir = Path.fromString(path); - RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), - RankProfilesConfig.class).getConfig(""); - RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), - RankingConstantsConfig.class).getConfig(""); - return new RankProfilesConfigImporter().importFrom(config, constantsConfig); - } - } diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg deleted file mode 100644 index 1cc36f75158..00000000000 --- a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg +++ /dev/null @@ -1,14 +0,0 @@ -rankprofile[0].name "mnist_saved" -rankprofile[0].fef.property[0].name "rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add).rankingScript" -rankprofile[0].fef.property[0].value "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))" -rankprofile[0].fef.property[1].name "rankingExpression(serving_default.y).rankingScript" -rankprofile[0].fef.property[1].value "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))" -rankprofile[1].name "xgboost_2_2" -rankprofile[1].fef.property[0].name "rankingExpression(xgboost_2_2).rankingScript" -rankprofile[1].fef.property[0].value "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)" -rankprofile[2].name "mnist_softmax_saved" -rankprofile[2].fef.property[0].name "rankingExpression(serving_default.y).rankingScript" -rankprofile[2].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))" -rankprofile[3].name "mnist_softmax" -rankprofile[3].fef.property[0].name "rankingExpression(default.add).rankingScript" -rankprofile[3].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))" diff --git a/model-evaluation/src/test/resources/config/models/ranking-constants.cfg b/model-evaluation/src/test/resources/config/models/ranking-constants.cfg deleted file mode 100644 index 2b7495ace5e..00000000000 --- a/model-evaluation/src/test/resources/config/models/ranking-constants.cfg +++ /dev/null @@ -1,30 +0,0 @@ -constant[0].name "mnist_saved_dnn_hidden1_weights_read" -constant[0].fileref "" -constant[0].type "tensor(d3[300],d4[784])" -constant[1].name "mnist_saved_dnn_hidden2_weights_read" -constant[1].fileref "" -constant[1].type "tensor(d2[100],d3[300])" -constant[2].name "mnist_softmax_saved_layer_Variable_1_read" -constant[2].fileref "" -constant[2].type "tensor(d1[10])" -constant[3].name "mnist_saved_dnn_hidden1_bias_read" -constant[3].fileref "" -constant[3].type "tensor(d3[300])" -constant[4].name "mnist_saved_dnn_hidden2_bias_read" -constant[4].fileref "" -constant[4].type "tensor(d2[100])" -constant[5].name "mnist_softmax_Variable" -constant[5].fileref "" -constant[5].type "tensor(d1[10],d2[784])" -constant[6].name "mnist_saved_dnn_outputs_weights_read" -constant[6].fileref "" -constant[6].type "tensor(d1[10],d2[100])" -constant[7].name "mnist_softmax_saved_layer_Variable_read" -constant[7].fileref "" -constant[7].type "tensor(d1[10],d2[784])" -constant[8].name "mnist_softmax_Variable_1" -constant[8].fileref "" -constant[8].type "tensor(d1[10])" -constant[9].name "mnist_saved_dnn_outputs_bias_read" -constant[9].fileref "" -constant[9].type "tensor(d1[10])" \ No newline at end of file diff --git a/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg b/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg deleted file mode 100644 index e69de29bb2d..00000000000 -- cgit v1.2.3