diff options
16 files changed, 332 insertions, 211 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java index a2bdc6834c9..7b7265e02ae 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java @@ -1,14 +1,21 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; +import com.yahoo.config.FileReference; import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.model.AbstractService; +import com.yahoo.vespa.model.utils.FileSender; +import java.util.Collection; import java.util.Objects; /** - * Represents a global ranking constant + * A global ranking constant distributed using file distribution. + * Ranking constants must be sent to some services to be useful - this is done + * by calling the sentTo method during the prepare phase of building models. * * @author arnej + * @author bratseth */ public class RankingConstant { @@ -49,14 +56,16 @@ public class RankingConstant { this.pathType = PathType.URI; } - /** - * Set the internally generated reference to this file used to identify this instance of the file for - * file distribution. - */ - public void setFileReference(String fileReference) { this.fileReference = fileReference; } - public void setType(TensorType tensorType) { this.tensorType = tensorType; } + /** Initiate sending of this constant to some services over file distribution */ + public void sendTo(Collection<? extends AbstractService> services) { + FileReference reference = (pathType == RankingConstant.PathType.FILE) + ? FileSender.sendFileToServices(path, services) + : FileSender.sendUriToServices(path, services); + this.fileReference = reference.value(); + } + public String getName() { return name; } public String getFileName() { return path; } public String getUri() { return path; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java index 5ac1418c0c7..e354c52092f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java @@ -40,12 +40,7 @@ public class RankingConstants { /** Initiate sending of these constants to some services over file distribution */ public void sendTo(Collection<? extends AbstractService> services) { - for (RankingConstant constant : constants.values()) { - FileReference reference = (constant.getPathType() == RankingConstant.PathType.FILE) - ? FileSender.sendFileToServices(constant.getFileName(), services) - : FileSender.sendUriToServices(constant.getUri(), services); - constant.setFileReference(reference.value()); - } + constants.values().forEach(constant -> constant.sendTo(services)); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index 229ae0ebaaf..4cd8c6ac92b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -1,5 +1,4 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.path.Path; @@ -8,12 +7,12 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.vespa.model.ml.ConvertedModel; +import com.yahoo.vespa.model.ml.FeatureArguments; -import java.io.File; import java.io.UncheckedIOException; import java.util.HashMap; import java.util.Map; -import java.util.Optional; /** * Replaces instances of the onnx(model-path, output) @@ -43,7 +42,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans try { // TODO: Put modelPath in FeatureArguments instead - Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); + Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = convertedOnnxModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context)); return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); @@ -53,14 +52,14 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans } } - private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) { + private FeatureArguments asFeatureArguments(Arguments arguments) { if (arguments.isEmpty()) throw new IllegalArgumentException("An onnx node must take an argument pointing to " + "the onnx model directory under [application]/models"); if (arguments.expressions().size() > 3) throw new IllegalArgumentException("An onnx feature can have at most 2 arguments"); - return new ConvertedModel.FeatureArguments(arguments); + return new FeatureArguments(arguments); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index bcb8ef1521d..72cfde0a566 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -7,8 +7,9 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.vespa.model.ml.ConvertedModel; +import com.yahoo.vespa.model.ml.FeatureArguments; -import java.io.File; import java.io.UncheckedIOException; import java.util.HashMap; import java.util.Map; @@ -39,7 +40,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil if ( ! feature.getName().equals("tensorflow")) return feature; try { - Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); + Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context)); return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); @@ -49,14 +50,14 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } - private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) { + private FeatureArguments asFeatureArguments(Arguments arguments) { if (arguments.isEmpty()) throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + "the tensorflow model directory under [application]/models"); if (arguments.expressions().size() > 3) throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments"); - return new ConvertedModel.FeatureArguments(arguments); + return new FeatureArguments(arguments); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java index b4a5069b9d6..8591bf16d07 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java @@ -2,13 +2,13 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.path.Path; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.integration.ml.XGBoostImporter; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.vespa.model.ml.ConvertedModel; +import com.yahoo.vespa.model.ml.FeatureArguments; import java.io.UncheckedIOException; import java.util.HashMap; @@ -41,7 +41,7 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr if ( ! feature.getName().equals("xgboost")) return feature; try { - Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); + Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = convertedXGBoostModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context)); return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); @@ -50,11 +50,11 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr } } - private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) { + private FeatureArguments asFeatureArguments(Arguments arguments) { if (arguments.size() != 1) throw new IllegalArgumentException("An xgboost node must take a single argument pointing to " + "the xgboost model directory under [application]/models"); - return new ConvertedModel.FeatureArguments(arguments); + return new FeatureArguments(arguments); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index 1b15233fead..282e5a29962 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -2,7 +2,6 @@ package com.yahoo.vespa.model; import com.google.common.collect.ImmutableList; -import com.yahoo.collections.Pair; import com.yahoo.config.ConfigBuilder; import com.yahoo.config.ConfigInstance; import com.yahoo.config.ConfigInstance.Builder; @@ -33,7 +32,7 @@ import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.RankingConstants; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RankProfileList; -import com.yahoo.searchdefinition.expressiontransforms.ConvertedModel; +import com.yahoo.vespa.model.ml.ConvertedModel; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; @@ -54,6 +53,7 @@ import com.yahoo.vespa.model.content.cluster.ContentCluster; import com.yahoo.vespa.model.filedistribution.FileDistributionConfigProducer; import com.yahoo.vespa.model.filedistribution.FileDistributor; import com.yahoo.vespa.model.generic.service.ServiceCluster; +import com.yahoo.vespa.model.ml.ModelName; import com.yahoo.vespa.model.routing.Routing; import com.yahoo.vespa.model.search.AbstractSearchCluster; import com.yahoo.vespa.model.utils.internal.ReflectionUtil; @@ -233,7 +233,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri for (ImportedModel model : importedModels.all()) { RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry); rankProfileRegistry.add(profile); - ConvertedModel convertedModel = ConvertedModel.fromSource(model.name(), model.name(), profile, queryProfiles, model); + ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()), + model.name(), profile, queryProfiles, model); for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) { profile.addMacro(entry.getKey(), false).setRankingExpression(entry.getValue()); } @@ -245,7 +246,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri String modelName = generatedModelDir.getPath().last(); RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry); rankProfileRegistry.add(profile); - ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, modelName, profile); + ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile); for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) { profile.addMacro(entry.getKey(), false).setRankingExpression(entry.getValue()); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 2f93bcc2e12..1f27b9843cd 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -1,4 +1,5 @@ -package com.yahoo.searchdefinition.expressiontransforms; +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.ml; import com.google.common.collect.ImmutableMap; import com.yahoo.collections.Pair; @@ -11,6 +12,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.FeatureNames; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; @@ -18,7 +20,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -63,14 +64,14 @@ import java.util.stream.Collectors; */ public class ConvertedModel { - private final String modelName; + private final ModelName modelName; private final String modelDescription; private final ImmutableMap<String, RankingExpression> expressions; /** The source importedModel, or empty if this was created from a stored converted model */ private final Optional<ImportedModel> sourceModel; - private ConvertedModel(String modelName, + private ConvertedModel(ModelName modelName, String modelDescription, Map<String, RankingExpression> expressions, Optional<ImportedModel> sourceModel) { @@ -86,7 +87,7 @@ public class ConvertedModel { */ public static ConvertedModel fromSourceOrStore(Path modelPath, RankProfileTransformContext context) { File sourceModel = sourceModelFile(context.rankProfile().applicationPackage(), modelPath); - String modelName = context.rankProfile().getName() + "." + toModelName(modelPath); // must be unique to each profile + ModelName modelName = new ModelName(context.rankProfile().getName(), modelPath); if (sourceModel.exists()) return fromSource(modelName, modelPath.toString(), @@ -99,7 +100,7 @@ public class ConvertedModel { context.rankProfile()); } - public static ConvertedModel fromSource(String modelName, + public static ConvertedModel fromSource(ModelName modelName, String modelDescription, RankProfile rankProfile, QueryProfileRegistry queryProfileRegistry, @@ -111,7 +112,7 @@ public class ConvertedModel { Optional.of(importedModel)); } - public static ConvertedModel fromStore(String modelName, + public static ConvertedModel fromStore(ModelName modelName, String modelDescription, RankProfile rankProfile) { ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName); @@ -240,9 +241,12 @@ public class ConvertedModel { profile.addConstant(constantName, asValue(constantValue)); } - private static void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, - Set<String> constantsReplacedByMacros, - String constantName, Tensor constantValue) { + private static void transformLargeConstant(ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles, + Set<String> constantsReplacedByMacros, + String constantName, + Tensor constantValue) { RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); if (macroOverridingConstant != null) { TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); @@ -255,7 +259,7 @@ public class ConvertedModel { Path constantPath = store.writeLargeConstant(constantName, constantValue); if ( ! profile.rankingConstants().asMap().containsKey(constantName)) { profile.rankingConstants().add(new RankingConstant(constantName, constantValue.type(), - constantPath.toString())); + constantPath.toString())); } } } @@ -491,10 +495,6 @@ public class ConvertedModel { return new TensorValue(tensor); } - private static String toModelName(Path modelPath) { - return modelPath.toString().replace("/", "_"); - } - @Override public String toString() { return "model '" + modelName + "'"; } @@ -513,7 +513,7 @@ public class ConvertedModel { private final ApplicationPackage application; private final ModelFiles modelFiles; - ModelStore(ApplicationPackage application, String modelName) { + ModelStore(ApplicationPackage application, ModelName modelName) { this.application = application; this.modelFiles = new ModelFiles(modelName); } @@ -616,8 +616,12 @@ public class ConvertedModel { .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); // Write content explicitly as a file on the file system as this is distributed using file distribution - createIfNeeded(constantsPath); - IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); + // - but only if this is a global model to avoid writing the same constants for each rank profile + // where they are used + if (modelFiles.modelName.isGlobal()) { + createIfNeeded(constantsPath); + IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); + } return correct(constantPath); } @@ -676,20 +680,24 @@ public class ConvertedModel { static class ModelFiles { - String modelName; + ModelName modelName; - public ModelFiles(String modelName) { + public ModelFiles(ModelName modelName) { this.modelName = modelName; } /** Files stored below this path will be replicated in zookeeper */ public Path storedModelReplicatedPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelName); + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelName.fullName()); } - /** Files stored below this path will not be replicated in zookeeper */ - public Path storedModelPath() { - return ApplicationPackage.MODELS_GENERATED_DIR.append(modelName); + /** + * Files stored below this path will not be replicated in zookeeper. + * Large constants are only stored under the global (not rank-profile-specific) + * path to avoid storing the same large constant multiple times. + */ + public Path storedGlobalModelPath() { + return ApplicationPackage.MODELS_GENERATED_DIR.append(modelName.localName()); } public Path expressionPath(String name) { @@ -706,7 +714,7 @@ public class ConvertedModel { /** Path to the large (ranking) constants directory */ public Path largeConstantsContentPath() { - return storedModelPath().append("constants"); + return storedGlobalModelPath().append("constants"); } /** Path to the large (ranking) constants directory */ @@ -721,53 +729,4 @@ public class ConvertedModel { } - /** Encapsulates the arguments of a specific model output */ - static class FeatureArguments { - - /** Optional arguments */ - private final Optional<String> signature, output; - - public FeatureArguments(Arguments arguments) { - this(optionalArgument(1, arguments), - optionalArgument(2, arguments)); - } - - public FeatureArguments(Optional<String> signature, Optional<String> output) { - this.signature = signature; - this.output = output; - } - - public Optional<String> signature() { return signature; } - public Optional<String> output() { return output; } - - public String toName() { - return (signature.isPresent() ? signature.get() : "") + - (output.isPresent() ? "." + output.get() : ""); - } - - private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { - if (argumentIndex >= arguments.expressions().size()) - return Optional.empty(); - return Optional.of(asString(arguments.expressions().get(argumentIndex))); - } - - public static String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); - } - - private static String stripQuotes(String s) { - if ( ! isQuoteSign(s.codePointAt(0))) return s; - if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) - throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); - return s.substring(1, s.length()-1); - } - - private static boolean isQuoteSign(int c) { - return c == '\'' || c == '"'; - } - - } - } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java new file mode 100644 index 00000000000..fda49af6178 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java @@ -0,0 +1,61 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.ml; + +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; + +import java.util.Optional; + +/** + * Encapsulates the arguments of a specific model output + * + * @author bratseth + */ +public class FeatureArguments { + + /** Optional arguments */ + private final Optional<String> signature, output; + + public FeatureArguments(Arguments arguments) { + this(optionalArgument(1, arguments), + optionalArgument(2, arguments)); + } + + public FeatureArguments(Optional<String> signature, Optional<String> output) { + this.signature = signature; + this.output = output; + } + + public Optional<String> signature() { return signature; } + public Optional<String> output() { return output; } + + public String toName() { + return (signature.isPresent() ? signature.get() : "") + + (output.isPresent() ? "." + output.get() : ""); + } + + private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { + if (argumentIndex >= arguments.expressions().size()) + return Optional.empty(); + return Optional.of(asString(arguments.expressions().get(argumentIndex))); + } + + public static String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); + } + + private static String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); + } + + private static boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; + } + +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java new file mode 100644 index 00000000000..5e22fefd093 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java @@ -0,0 +1,54 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.ml; + +import com.yahoo.path.Path; + +/** + * Models used in a rank profile has the rank profile name as name space while gGlobal model names have no namespace + * + * @author bratseth + */ +public class ModelName { + + /** The namespace, or null if none */ + private String namespace; + private String name; + private String fullName; + + public ModelName(String name) { + this(null, name); + } + + public ModelName(String namespace, Path modelPath) { + this(namespace, modelPath.toString().replace("/", "_")); + } + + private ModelName(String namespace, String name) { + this.namespace = namespace; + this.name = name; + this.fullName = (namespace != null ? namespace + "." : "") + name; + } + + /** Returns true if the local name of this is not in a namespace */ + public boolean isGlobal() { return namespace == null; } + + /** Returns the namespace, or null if this is global */ + public String namespace() { return namespace; } + public String localName() { return name; } + public String fullName() { return fullName; } + + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if ( ! (o instanceof ModelName)) return false; + return ((ModelName)o).fullName.equals(this.fullName); + } + + @Override + public int hashCode() { return fullName.hashCode(); } + + @Override + public String toString() { return fullName; } + +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java index 83da5d96418..fbbf029d5f1 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java @@ -1,16 +1,13 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.search; -import com.yahoo.config.FileReference; import com.yahoo.config.model.producer.AbstractConfigProducer; import com.yahoo.config.model.producer.UserConfigRepo; import com.yahoo.prelude.fastsearch.DocumentdbInfoConfig; import com.yahoo.search.config.IndexInfoConfig; -import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.vespa.config.search.AttributesConfig; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.configdefinition.IlscriptsConfig; -import com.yahoo.vespa.model.utils.FileSender; import java.util.ArrayList; import java.util.LinkedList; @@ -36,14 +33,8 @@ public abstract class AbstractSearchCluster extends AbstractConfigProducer protected List<SearchDefinitionSpec> localSDS = new LinkedList<>(); public void prepareToDistributeFiles(List<SearchNode> backends) { - for (SearchDefinitionSpec sds : localSDS) { - for (RankingConstant constant : sds.getSearchDefinition().getSearch().rankingConstants().asMap().values()) { - FileReference reference = (constant.getPathType() == RankingConstant.PathType.FILE) - ? FileSender.sendFileToServices(constant.getFileName(), backends) - : FileSender.sendUriToServices(constant.getUri(), backends); - constant.setFileReference(reference.value()); - } - } + for (SearchDefinitionSpec sds : localSDS) + sds.getSearchDefinition().getSearch().rankingConstants().sendTo(backends); } public static final class IndexingMode { diff --git a/config-model/src/test/integration/onnx/services.xml b/config-model/src/test/integration/onnx/services.xml new file mode 100644 index 00000000000..f623b2464fc --- /dev/null +++ b/config-model/src/test/integration/onnx/services.xml @@ -0,0 +1,5 @@ +<services> + <container version="1.0"> + + </container> +</services>
\ No newline at end of file diff --git a/config-model/src/test/integration/tensorflow/services.xml b/config-model/src/test/integration/tensorflow/services.xml new file mode 100644 index 00000000000..f623b2464fc --- /dev/null +++ b/config-model/src/test/integration/tensorflow/services.xml @@ -0,0 +1,5 @@ +<services> + <container version="1.0"> + + </container> +</services>
\ No newline at end of file diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 414a77e9164..b046d60f948 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -1,27 +1,22 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.searchdefinition.processing; import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.ml.ImportedModelTester; import com.yahoo.yolean.Exceptions; import org.junit.After; import org.junit.Test; import java.io.IOException; -import java.io.UncheckedIOException; import java.util.Optional; import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage; -import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; @@ -41,14 +36,36 @@ public class RankingExpressionWithOnnxTestCase { } @Test + public void testGlobalOnnxModel() throws IOException { + ImportedModelTester tester = new ImportedModelTester(name, applicationDir); + VespaModel model = tester.createVespaModel(); + tester.assertLargeConstant(name + "_Variable_1", model, Optional.of(10L)); + tester.assertLargeConstant(name + "_Variable", model, Optional.of(7840L)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedAppDir = applicationDir.append("copy"); + try { + storedAppDir.toFile().mkdirs(); + IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString()); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir); + VespaModel storedModel = storedTester.createVespaModel(); + tester.assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L)); + tester.assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L)); + } + finally { + IOUtils.recursiveDeleteDir(storedAppDir.toFile()); + } + } + + @Test public void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "onnx('mnist_softmax.onnx')", "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant(name + "_Variable", search, Optional.of(7840L)); } @Test @@ -68,8 +85,6 @@ public class RankingExpressionWithOnnxTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant(name + "_Variable", search, Optional.of(7840L)); } @Test @@ -82,8 +97,6 @@ public class RankingExpressionWithOnnxTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant( name + "_Variable", search, Optional.of(7840L)); } @@ -104,8 +117,6 @@ public class RankingExpressionWithOnnxTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant( name + "_Variable", search, Optional.of(7840L)); } @@ -114,8 +125,6 @@ public class RankingExpressionWithOnnxTestCase { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "5 + sum(onnx('mnist_softmax.onnx'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant( name + "_Variable", search, Optional.of(7840L)); } @Test @@ -181,9 +190,6 @@ public class RankingExpressionWithOnnxTestCase { "onnx('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); - assertLargeConstant( name + "_Variable", search, Optional.of(7840L)); - // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); try { @@ -200,8 +206,6 @@ public class RankingExpressionWithOnnxTestCase { searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); // Verify that the constants exists, but don't verify the content as we are not // simulating file distribution in this test - assertLargeConstant( name + "_Variable_1", searchFromStored, Optional.empty()); - assertLargeConstant( name + "_Variable", searchFromStored, Optional.empty()); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); @@ -232,7 +236,6 @@ public class RankingExpressionWithOnnxTestCase { assertNull("Constant overridden by macro is not added", search.search().rankingConstants().get( name + "_Variable")); - assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); @@ -245,38 +248,12 @@ public class RankingExpressionWithOnnxTestCase { searchFromStored.compileRankProfile("my_profile", applicationDir.append("models")); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); assertNull("Constant overridden by macro is not added", - searchFromStored.search().rankingConstants().get( name + "_Variable")); - assertLargeConstant( name + "_Variable_1", searchFromStored, Optional.of(10L)); + searchFromStored.search().rankingConstants().get( name + "_Variable")); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); } } - /** - * Verifies that the constant with the given name exists, and - only if an expected size is given - - * that the content of the constant is available and has the expected size. - */ - private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) { - try { - Path constantApplicationPackagePath = Path.fromString("models.generated/my_profile.mnist_softmax.onnx/constants").append(name + ".tbf"); - RankingConstant rankingConstant = search.search().rankingConstants().get(name); - assertEquals(name, rankingConstant.getName()); - assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString())); - - if (expectedSize.isPresent()) { - Path constantPath = applicationDir.append(constantApplicationPackagePath); - assertTrue("Constant file '" + constantPath + "' has been written", - constantPath.toFile().exists()); - Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), - GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile()))); - assertEquals(expectedSize.get().longValue(), deserializedConstant.size()); - } - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) { return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder", new StoringApplicationPackage(applicationDir)); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 450c66e04ef..14632a568ea 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -15,27 +15,22 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.ml.ImportedModelTester; import com.yahoo.yolean.Exceptions; import org.junit.After; import org.junit.Test; -import java.io.BufferedInputStream; import java.io.File; -import java.io.FileInputStream; -import java.io.FileNotFoundException; import java.io.FileReader; import java.io.IOException; -import java.io.InputStream; -import java.io.Reader; import java.io.UncheckedIOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; -import java.util.Iterator; import java.util.List; import java.util.Optional; -import java.util.stream.Collectors; +import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.*; /** @@ -56,12 +51,34 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test + public void testGlobalTensorFlowModel() throws IOException { + ImportedModelTester tester = new ImportedModelTester(name, applicationDir); + VespaModel model = tester.createVespaModel(); + assertLargeConstant(name + "_layer_Variable_1_read", model, Optional.of(10L)); + assertLargeConstant(name + "_layer_Variable_read", model, Optional.of(7840L)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedAppDir = applicationDir.append("copy"); + try { + storedAppDir.toFile().mkdirs(); + IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString()); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir); + VespaModel storedModel = storedTester.createVespaModel(); + tester.assertLargeConstant(name + "_layer_Variable_1_read", storedModel, Optional.of(10L)); + tester.assertLargeConstant(name + "_layer_Variable_read", storedModel, Optional.of(7840L)); + } + finally { + IOUtils.recursiveDeleteDir(storedAppDir.toFile()); + } + } + + @Test public void testTensorFlowReference() { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -71,8 +88,6 @@ public class RankingExpressionWithTensorFlowTestCase { "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -91,8 +106,6 @@ public class RankingExpressionWithTensorFlowTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -105,8 +118,6 @@ public class RankingExpressionWithTensorFlowTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -125,8 +136,6 @@ public class RankingExpressionWithTensorFlowTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -134,8 +143,6 @@ public class RankingExpressionWithTensorFlowTestCase { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "5 + sum(tensorflow('mnist_softmax/saved'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -233,9 +240,6 @@ public class RankingExpressionWithTensorFlowTestCase { "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L)); - // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); try { @@ -250,10 +254,6 @@ public class RankingExpressionWithTensorFlowTestCase { "Placeholder", storedApplication); searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); - // Verify that the constants exists, but don't verify the content as we are not - // simulating file distribution in this test - assertLargeConstant(name + "_layer_Variable_1_read", searchFromStored, Optional.empty()); - assertLargeConstant(name + "_layer_Variable_read", searchFromStored, Optional.empty()); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); @@ -287,7 +287,6 @@ public class RankingExpressionWithTensorFlowTestCase { assertNull("Constant overridden by macro is not added", search.search().rankingConstants().get("mnist_softmax_saved_layer_Variable_read")); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); @@ -303,7 +302,6 @@ public class RankingExpressionWithTensorFlowTestCase { searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child"); assertNull("Constant overridden by macro is not added", searchFromStored.search().rankingConstants().get("mnist_softmax_saved_layer_Variable_read")); - assertLargeConstant(name + "_layer_Variable_1_read", searchFromStored, Optional.of(10L)); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); @@ -316,8 +314,6 @@ public class RankingExpressionWithTensorFlowTestCase { RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(expression, "my_profile"); - assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); - assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -401,11 +397,11 @@ public class RankingExpressionWithTensorFlowTestCase { * Verifies that the constant with the given name exists, and - only if an expected size is given - * that the content of the constant is available and has the expected size. */ - private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) { + private void assertLargeConstant(String constantName, VespaModel model, Optional<Long> expectedSize) { try { - Path constantApplicationPackagePath = Path.fromString("models.generated/my_profile.mnist_softmax_saved/constants").append(name + ".tbf"); - RankingConstant rankingConstant = search.search().rankingConstants().get(name); - assertEquals(name, rankingConstant.getName()); + Path constantApplicationPackagePath = Path.fromString("models.generated/" + name + "/constants").append(constantName + ".tbf"); + RankingConstant rankingConstant = model.rankingConstants().get(constantName); + assertEquals(constantName, rankingConstant.getName()); assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString())); if (expectedSize.isPresent()) { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java new file mode 100644 index 00000000000..2ae629562d0 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java @@ -0,0 +1,71 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.ml; + +import com.yahoo.config.model.ApplicationPackageTester; +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.vespa.model.VespaModel; +import org.xml.sax.SAXException; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Optional; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertEquals; + +/** + * Helper for testing of imported models. + * More duplicated functionality across tests on imported models should be moved here + * + * @author bratseth + */ +public class ImportedModelTester { + + private final String modelName; + private final Path applicationDir; + + public ImportedModelTester(String modelName, Path applicationDir) { + this.modelName = modelName; + this.applicationDir = applicationDir; + } + + public VespaModel createVespaModel() { + try { + return new VespaModel(ApplicationPackageTester.create(applicationDir.toString()).app()); + } + catch (SAXException | IOException e) { + throw new RuntimeException(e); + } + } + + /** + * Verifies that the constant with the given name exists, and - only if an expected size is given - + * that the content of the constant is available and has the expected size. + */ + public void assertLargeConstant(String constantName, VespaModel model, Optional<Long> expectedSize) { + try { + Path constantApplicationPackagePath = Path.fromString("models.generated/" + modelName + "/constants").append(constantName + ".tbf"); + RankingConstant rankingConstant = model.rankingConstants().get(constantName); + assertEquals(constantName, rankingConstant.getName()); + assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString())); + + if (expectedSize.isPresent()) { + Path constantPath = applicationDir.append(constantApplicationPackagePath); + assertTrue("Constant file '" + constantPath + "' has been written", + constantPath.toFile().exists()); + Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile()))); + assertEquals(expectedSize.get().longValue(), deserializedConstant.size()); + } + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + +} diff --git a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index ad2f62b7dc3..35e6642d7cb 100644 --- a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.config.model; +package com.yahoo.vespa.model.ml; import ai.vespa.models.evaluation.Model; import ai.vespa.models.evaluation.ModelsEvaluator; @@ -18,7 +18,6 @@ import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.container.ContainerCluster; import org.junit.After; import org.junit.Test; -import org.xml.sax.SAXException; import java.io.IOException; import java.util.Set; @@ -41,10 +40,9 @@ public class ModelEvaluationTest { } @Test - public void testMl_ServingApplication() throws SAXException, IOException { - ApplicationPackageTester tester = ApplicationPackageTester.create(appDir.toString()); - VespaModel model = new VespaModel(tester.app()); - assertHasMlModels(model); + public void testMl_ServingApplication() throws IOException { + ImportedModelTester tester = new ImportedModelTester("ml_serving", appDir); + assertHasMlModels(tester.createVespaModel()); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedAppDir = appDir.append("copy"); @@ -53,9 +51,8 @@ public class ModelEvaluationTest { IOUtils.copy(appDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString()); IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); - ApplicationPackageTester storedTester = ApplicationPackageTester.create(storedAppDir.toString()); - VespaModel storedModel = new VespaModel(storedTester.app()); - assertHasMlModels(storedModel); + ImportedModelTester storedTester = new ImportedModelTester("ml_serving", storedAppDir); + assertHasMlModels(storedTester.createVespaModel()); } finally { IOUtils.recursiveDeleteDir(storedAppDir.toFile()); |