summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java23
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java7
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java11
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java9
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java10
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java9
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java (renamed from config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java)107
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java61
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java54
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java13
-rw-r--r--config-model/src/test/integration/onnx/services.xml5
-rw-r--r--config-model/src/test/integration/tensorflow/services.xml5
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java77
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java66
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java71
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java (renamed from config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java)15
16 files changed, 332 insertions, 211 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java
index a2bdc6834c9..7b7265e02ae 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java
@@ -1,14 +1,21 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition;
+import com.yahoo.config.FileReference;
import com.yahoo.tensor.TensorType;
+import com.yahoo.vespa.model.AbstractService;
+import com.yahoo.vespa.model.utils.FileSender;
+import java.util.Collection;
import java.util.Objects;
/**
- * Represents a global ranking constant
+ * A global ranking constant distributed using file distribution.
+ * Ranking constants must be sent to some services to be useful - this is done
+ * by calling the sentTo method during the prepare phase of building models.
*
* @author arnej
+ * @author bratseth
*/
public class RankingConstant {
@@ -49,14 +56,16 @@ public class RankingConstant {
this.pathType = PathType.URI;
}
- /**
- * Set the internally generated reference to this file used to identify this instance of the file for
- * file distribution.
- */
- public void setFileReference(String fileReference) { this.fileReference = fileReference; }
-
public void setType(TensorType tensorType) { this.tensorType = tensorType; }
+ /** Initiate sending of this constant to some services over file distribution */
+ public void sendTo(Collection<? extends AbstractService> services) {
+ FileReference reference = (pathType == RankingConstant.PathType.FILE)
+ ? FileSender.sendFileToServices(path, services)
+ : FileSender.sendUriToServices(path, services);
+ this.fileReference = reference.value();
+ }
+
public String getName() { return name; }
public String getFileName() { return path; }
public String getUri() { return path; }
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java
index 5ac1418c0c7..e354c52092f 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java
@@ -40,12 +40,7 @@ public class RankingConstants {
/** Initiate sending of these constants to some services over file distribution */
public void sendTo(Collection<? extends AbstractService> services) {
- for (RankingConstant constant : constants.values()) {
- FileReference reference = (constant.getPathType() == RankingConstant.PathType.FILE)
- ? FileSender.sendFileToServices(constant.getFileName(), services)
- : FileSender.sendUriToServices(constant.getUri(), services);
- constant.setFileReference(reference.value());
- }
+ constants.values().forEach(constant -> constant.sendTo(services));
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index 229ae0ebaaf..4cd8c6ac92b 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -1,5 +1,4 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
package com.yahoo.searchdefinition.expressiontransforms;
import com.yahoo.path.Path;
@@ -8,12 +7,12 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.vespa.model.ml.ConvertedModel;
+import com.yahoo.vespa.model.ml.FeatureArguments;
-import java.io.File;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
-import java.util.Optional;
/**
* Replaces instances of the onnx(model-path, output)
@@ -43,7 +42,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
try {
// TODO: Put modelPath in FeatureArguments instead
- Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
+ Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
convertedOnnxModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context));
return convertedModel.expression(asFeatureArguments(feature.getArguments()), context);
@@ -53,14 +52,14 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
}
}
- private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) {
+ private FeatureArguments asFeatureArguments(Arguments arguments) {
if (arguments.isEmpty())
throw new IllegalArgumentException("An onnx node must take an argument pointing to " +
"the onnx model directory under [application]/models");
if (arguments.expressions().size() > 3)
throw new IllegalArgumentException("An onnx feature can have at most 2 arguments");
- return new ConvertedModel.FeatureArguments(arguments);
+ return new FeatureArguments(arguments);
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index bcb8ef1521d..72cfde0a566 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -7,8 +7,9 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.vespa.model.ml.ConvertedModel;
+import com.yahoo.vespa.model.ml.FeatureArguments;
-import java.io.File;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
@@ -39,7 +40,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
if ( ! feature.getName().equals("tensorflow")) return feature;
try {
- Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
+ Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context));
return convertedModel.expression(asFeatureArguments(feature.getArguments()), context);
@@ -49,14 +50,14 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
}
- private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) {
+ private FeatureArguments asFeatureArguments(Arguments arguments) {
if (arguments.isEmpty())
throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
"the tensorflow model directory under [application]/models");
if (arguments.expressions().size() > 3)
throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments");
- return new ConvertedModel.FeatureArguments(arguments);
+ return new FeatureArguments(arguments);
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
index b4a5069b9d6..8591bf16d07 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
@@ -2,13 +2,13 @@
package com.yahoo.searchdefinition.expressiontransforms;
import com.yahoo.path.Path;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.integration.ml.XGBoostImporter;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.vespa.model.ml.ConvertedModel;
+import com.yahoo.vespa.model.ml.FeatureArguments;
import java.io.UncheckedIOException;
import java.util.HashMap;
@@ -41,7 +41,7 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr
if ( ! feature.getName().equals("xgboost")) return feature;
try {
- Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
+ Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
convertedXGBoostModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context));
return convertedModel.expression(asFeatureArguments(feature.getArguments()), context);
@@ -50,11 +50,11 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr
}
}
- private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) {
+ private FeatureArguments asFeatureArguments(Arguments arguments) {
if (arguments.size() != 1)
throw new IllegalArgumentException("An xgboost node must take a single argument pointing to " +
"the xgboost model directory under [application]/models");
- return new ConvertedModel.FeatureArguments(arguments);
+ return new FeatureArguments(arguments);
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index 1b15233fead..282e5a29962 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -2,7 +2,6 @@
package com.yahoo.vespa.model;
import com.google.common.collect.ImmutableList;
-import com.yahoo.collections.Pair;
import com.yahoo.config.ConfigBuilder;
import com.yahoo.config.ConfigInstance;
import com.yahoo.config.ConfigInstance.Builder;
@@ -33,7 +32,7 @@ import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.RankingConstants;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RankProfileList;
-import com.yahoo.searchdefinition.expressiontransforms.ConvertedModel;
+import com.yahoo.vespa.model.ml.ConvertedModel;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
@@ -54,6 +53,7 @@ import com.yahoo.vespa.model.content.cluster.ContentCluster;
import com.yahoo.vespa.model.filedistribution.FileDistributionConfigProducer;
import com.yahoo.vespa.model.filedistribution.FileDistributor;
import com.yahoo.vespa.model.generic.service.ServiceCluster;
+import com.yahoo.vespa.model.ml.ModelName;
import com.yahoo.vespa.model.routing.Routing;
import com.yahoo.vespa.model.search.AbstractSearchCluster;
import com.yahoo.vespa.model.utils.internal.ReflectionUtil;
@@ -233,7 +233,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
for (ImportedModel model : importedModels.all()) {
RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry);
rankProfileRegistry.add(profile);
- ConvertedModel convertedModel = ConvertedModel.fromSource(model.name(), model.name(), profile, queryProfiles, model);
+ ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()),
+ model.name(), profile, queryProfiles, model);
for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) {
profile.addMacro(entry.getKey(), false).setRankingExpression(entry.getValue());
}
@@ -245,7 +246,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
String modelName = generatedModelDir.getPath().last();
RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry);
rankProfileRegistry.add(profile);
- ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, modelName, profile);
+ ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile);
for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) {
profile.addMacro(entry.getKey(), false).setRankingExpression(entry.getValue());
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 2f93bcc2e12..1f27b9843cd 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -1,4 +1,5 @@
-package com.yahoo.searchdefinition.expressiontransforms;
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.ml;
import com.google.common.collect.ImmutableMap;
import com.yahoo.collections.Pair;
@@ -11,6 +12,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.FeatureNames;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
@@ -18,7 +20,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -63,14 +64,14 @@ import java.util.stream.Collectors;
*/
public class ConvertedModel {
- private final String modelName;
+ private final ModelName modelName;
private final String modelDescription;
private final ImmutableMap<String, RankingExpression> expressions;
/** The source importedModel, or empty if this was created from a stored converted model */
private final Optional<ImportedModel> sourceModel;
- private ConvertedModel(String modelName,
+ private ConvertedModel(ModelName modelName,
String modelDescription,
Map<String, RankingExpression> expressions,
Optional<ImportedModel> sourceModel) {
@@ -86,7 +87,7 @@ public class ConvertedModel {
*/
public static ConvertedModel fromSourceOrStore(Path modelPath, RankProfileTransformContext context) {
File sourceModel = sourceModelFile(context.rankProfile().applicationPackage(), modelPath);
- String modelName = context.rankProfile().getName() + "." + toModelName(modelPath); // must be unique to each profile
+ ModelName modelName = new ModelName(context.rankProfile().getName(), modelPath);
if (sourceModel.exists())
return fromSource(modelName,
modelPath.toString(),
@@ -99,7 +100,7 @@ public class ConvertedModel {
context.rankProfile());
}
- public static ConvertedModel fromSource(String modelName,
+ public static ConvertedModel fromSource(ModelName modelName,
String modelDescription,
RankProfile rankProfile,
QueryProfileRegistry queryProfileRegistry,
@@ -111,7 +112,7 @@ public class ConvertedModel {
Optional.of(importedModel));
}
- public static ConvertedModel fromStore(String modelName,
+ public static ConvertedModel fromStore(ModelName modelName,
String modelDescription,
RankProfile rankProfile) {
ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName);
@@ -240,9 +241,12 @@ public class ConvertedModel {
profile.addConstant(constantName, asValue(constantValue));
}
- private static void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
- Set<String> constantsReplacedByMacros,
- String constantName, Tensor constantValue) {
+ private static void transformLargeConstant(ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ Set<String> constantsReplacedByMacros,
+ String constantName,
+ Tensor constantValue) {
RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
if (macroOverridingConstant != null) {
TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles));
@@ -255,7 +259,7 @@ public class ConvertedModel {
Path constantPath = store.writeLargeConstant(constantName, constantValue);
if ( ! profile.rankingConstants().asMap().containsKey(constantName)) {
profile.rankingConstants().add(new RankingConstant(constantName, constantValue.type(),
- constantPath.toString()));
+ constantPath.toString()));
}
}
}
@@ -491,10 +495,6 @@ public class ConvertedModel {
return new TensorValue(tensor);
}
- private static String toModelName(Path modelPath) {
- return modelPath.toString().replace("/", "_");
- }
-
@Override
public String toString() { return "model '" + modelName + "'"; }
@@ -513,7 +513,7 @@ public class ConvertedModel {
private final ApplicationPackage application;
private final ModelFiles modelFiles;
- ModelStore(ApplicationPackage application, String modelName) {
+ ModelStore(ApplicationPackage application, ModelName modelName) {
this.application = application;
this.modelFiles = new ModelFiles(modelName);
}
@@ -616,8 +616,12 @@ public class ConvertedModel {
.writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
// Write content explicitly as a file on the file system as this is distributed using file distribution
- createIfNeeded(constantsPath);
- IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
+ // - but only if this is a global model to avoid writing the same constants for each rank profile
+ // where they are used
+ if (modelFiles.modelName.isGlobal()) {
+ createIfNeeded(constantsPath);
+ IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
+ }
return correct(constantPath);
}
@@ -676,20 +680,24 @@ public class ConvertedModel {
static class ModelFiles {
- String modelName;
+ ModelName modelName;
- public ModelFiles(String modelName) {
+ public ModelFiles(ModelName modelName) {
this.modelName = modelName;
}
/** Files stored below this path will be replicated in zookeeper */
public Path storedModelReplicatedPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelName);
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelName.fullName());
}
- /** Files stored below this path will not be replicated in zookeeper */
- public Path storedModelPath() {
- return ApplicationPackage.MODELS_GENERATED_DIR.append(modelName);
+ /**
+ * Files stored below this path will not be replicated in zookeeper.
+ * Large constants are only stored under the global (not rank-profile-specific)
+ * path to avoid storing the same large constant multiple times.
+ */
+ public Path storedGlobalModelPath() {
+ return ApplicationPackage.MODELS_GENERATED_DIR.append(modelName.localName());
}
public Path expressionPath(String name) {
@@ -706,7 +714,7 @@ public class ConvertedModel {
/** Path to the large (ranking) constants directory */
public Path largeConstantsContentPath() {
- return storedModelPath().append("constants");
+ return storedGlobalModelPath().append("constants");
}
/** Path to the large (ranking) constants directory */
@@ -721,53 +729,4 @@ public class ConvertedModel {
}
- /** Encapsulates the arguments of a specific model output */
- static class FeatureArguments {
-
- /** Optional arguments */
- private final Optional<String> signature, output;
-
- public FeatureArguments(Arguments arguments) {
- this(optionalArgument(1, arguments),
- optionalArgument(2, arguments));
- }
-
- public FeatureArguments(Optional<String> signature, Optional<String> output) {
- this.signature = signature;
- this.output = output;
- }
-
- public Optional<String> signature() { return signature; }
- public Optional<String> output() { return output; }
-
- public String toName() {
- return (signature.isPresent() ? signature.get() : "") +
- (output.isPresent() ? "." + output.get() : "");
- }
-
- private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
- if (argumentIndex >= arguments.expressions().size())
- return Optional.empty();
- return Optional.of(asString(arguments.expressions().get(argumentIndex)));
- }
-
- public static String asString(ExpressionNode node) {
- if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
- }
-
- private static String stripQuotes(String s) {
- if ( ! isQuoteSign(s.codePointAt(0))) return s;
- if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
- throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
- return s.substring(1, s.length()-1);
- }
-
- private static boolean isQuoteSign(int c) {
- return c == '\'' || c == '"';
- }
-
- }
-
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java
new file mode 100644
index 00000000000..fda49af6178
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java
@@ -0,0 +1,61 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.ml;
+
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+
+import java.util.Optional;
+
+/**
+ * Encapsulates the arguments of a specific model output
+ *
+ * @author bratseth
+ */
+public class FeatureArguments {
+
+ /** Optional arguments */
+ private final Optional<String> signature, output;
+
+ public FeatureArguments(Arguments arguments) {
+ this(optionalArgument(1, arguments),
+ optionalArgument(2, arguments));
+ }
+
+ public FeatureArguments(Optional<String> signature, Optional<String> output) {
+ this.signature = signature;
+ this.output = output;
+ }
+
+ public Optional<String> signature() { return signature; }
+ public Optional<String> output() { return output; }
+
+ public String toName() {
+ return (signature.isPresent() ? signature.get() : "") +
+ (output.isPresent() ? "." + output.get() : "");
+ }
+
+ private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
+ if (argumentIndex >= arguments.expressions().size())
+ return Optional.empty();
+ return Optional.of(asString(arguments.expressions().get(argumentIndex)));
+ }
+
+ public static String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+ throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
+ return stripQuotes(((ConstantNode)node).sourceString());
+ }
+
+ private static String stripQuotes(String s) {
+ if ( ! isQuoteSign(s.codePointAt(0))) return s;
+ if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
+ throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
+ return s.substring(1, s.length()-1);
+ }
+
+ private static boolean isQuoteSign(int c) {
+ return c == '\'' || c == '"';
+ }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java
new file mode 100644
index 00000000000..5e22fefd093
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java
@@ -0,0 +1,54 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.ml;
+
+import com.yahoo.path.Path;
+
+/**
+ * Models used in a rank profile has the rank profile name as name space while gGlobal model names have no namespace
+ *
+ * @author bratseth
+ */
+public class ModelName {
+
+ /** The namespace, or null if none */
+ private String namespace;
+ private String name;
+ private String fullName;
+
+ public ModelName(String name) {
+ this(null, name);
+ }
+
+ public ModelName(String namespace, Path modelPath) {
+ this(namespace, modelPath.toString().replace("/", "_"));
+ }
+
+ private ModelName(String namespace, String name) {
+ this.namespace = namespace;
+ this.name = name;
+ this.fullName = (namespace != null ? namespace + "." : "") + name;
+ }
+
+ /** Returns true if the local name of this is not in a namespace */
+ public boolean isGlobal() { return namespace == null; }
+
+ /** Returns the namespace, or null if this is global */
+ public String namespace() { return namespace; }
+ public String localName() { return name; }
+ public String fullName() { return fullName; }
+
+
+ @Override
+ public boolean equals(Object o) {
+ if (o == this) return true;
+ if ( ! (o instanceof ModelName)) return false;
+ return ((ModelName)o).fullName.equals(this.fullName);
+ }
+
+ @Override
+ public int hashCode() { return fullName.hashCode(); }
+
+ @Override
+ public String toString() { return fullName; }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java
index 83da5d96418..fbbf029d5f1 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java
@@ -1,16 +1,13 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.search;
-import com.yahoo.config.FileReference;
import com.yahoo.config.model.producer.AbstractConfigProducer;
import com.yahoo.config.model.producer.UserConfigRepo;
import com.yahoo.prelude.fastsearch.DocumentdbInfoConfig;
import com.yahoo.search.config.IndexInfoConfig;
-import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.vespa.config.search.AttributesConfig;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.configdefinition.IlscriptsConfig;
-import com.yahoo.vespa.model.utils.FileSender;
import java.util.ArrayList;
import java.util.LinkedList;
@@ -36,14 +33,8 @@ public abstract class AbstractSearchCluster extends AbstractConfigProducer
protected List<SearchDefinitionSpec> localSDS = new LinkedList<>();
public void prepareToDistributeFiles(List<SearchNode> backends) {
- for (SearchDefinitionSpec sds : localSDS) {
- for (RankingConstant constant : sds.getSearchDefinition().getSearch().rankingConstants().asMap().values()) {
- FileReference reference = (constant.getPathType() == RankingConstant.PathType.FILE)
- ? FileSender.sendFileToServices(constant.getFileName(), backends)
- : FileSender.sendUriToServices(constant.getUri(), backends);
- constant.setFileReference(reference.value());
- }
- }
+ for (SearchDefinitionSpec sds : localSDS)
+ sds.getSearchDefinition().getSearch().rankingConstants().sendTo(backends);
}
public static final class IndexingMode {
diff --git a/config-model/src/test/integration/onnx/services.xml b/config-model/src/test/integration/onnx/services.xml
new file mode 100644
index 00000000000..f623b2464fc
--- /dev/null
+++ b/config-model/src/test/integration/onnx/services.xml
@@ -0,0 +1,5 @@
+<services>
+ <container version="1.0">
+
+ </container>
+</services> \ No newline at end of file
diff --git a/config-model/src/test/integration/tensorflow/services.xml b/config-model/src/test/integration/tensorflow/services.xml
new file mode 100644
index 00000000000..f623b2464fc
--- /dev/null
+++ b/config-model/src/test/integration/tensorflow/services.xml
@@ -0,0 +1,5 @@
+<services>
+ <container version="1.0">
+
+ </container>
+</services> \ No newline at end of file
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 414a77e9164..b046d60f948 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -1,27 +1,22 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
package com.yahoo.searchdefinition.processing;
import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.serialization.TypedBinaryFormat;
+import com.yahoo.vespa.model.VespaModel;
+import com.yahoo.vespa.model.ml.ImportedModelTester;
import com.yahoo.yolean.Exceptions;
import org.junit.After;
import org.junit.Test;
import java.io.IOException;
-import java.io.UncheckedIOException;
import java.util.Optional;
import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage;
-import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
@@ -41,14 +36,36 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
+ public void testGlobalOnnxModel() throws IOException {
+ ImportedModelTester tester = new ImportedModelTester(name, applicationDir);
+ VespaModel model = tester.createVespaModel();
+ tester.assertLargeConstant(name + "_Variable_1", model, Optional.of(10L));
+ tester.assertLargeConstant(name + "_Variable", model, Optional.of(7840L));
+
+ // At this point the expression is stored - copy application to another location which do not have a models dir
+ Path storedAppDir = applicationDir.append("copy");
+ try {
+ storedAppDir.toFile().mkdirs();
+ IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString());
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir);
+ VespaModel storedModel = storedTester.createVespaModel();
+ tester.assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L));
+ tester.assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L));
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedAppDir.toFile());
+ }
+ }
+
+ @Test
public void testOnnxReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
"onnx('mnist_softmax.onnx')",
"constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }",
null);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant(name + "_Variable_1", search, Optional.of(10L));
- assertLargeConstant(name + "_Variable", search, Optional.of(7840L));
}
@Test
@@ -68,8 +85,6 @@ public class RankingExpressionWithOnnxTestCase {
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant(name + "_Variable_1", search, Optional.of(10L));
- assertLargeConstant(name + "_Variable", search, Optional.of(7840L));
}
@Test
@@ -82,8 +97,6 @@ public class RankingExpressionWithOnnxTestCase {
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant( name + "_Variable_1", search, Optional.of(10L));
- assertLargeConstant( name + "_Variable", search, Optional.of(7840L));
}
@@ -104,8 +117,6 @@ public class RankingExpressionWithOnnxTestCase {
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant( name + "_Variable_1", search, Optional.of(10L));
- assertLargeConstant( name + "_Variable", search, Optional.of(7840L));
}
@@ -114,8 +125,6 @@ public class RankingExpressionWithOnnxTestCase {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"5 + sum(onnx('mnist_softmax.onnx'))");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
- assertLargeConstant( name + "_Variable_1", search, Optional.of(10L));
- assertLargeConstant( name + "_Variable", search, Optional.of(7840L));
}
@Test
@@ -181,9 +190,6 @@ public class RankingExpressionWithOnnxTestCase {
"onnx('mnist_softmax.onnx')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant( name + "_Variable_1", search, Optional.of(10L));
- assertLargeConstant( name + "_Variable", search, Optional.of(7840L));
-
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
try {
@@ -200,8 +206,6 @@ public class RankingExpressionWithOnnxTestCase {
searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile");
// Verify that the constants exists, but don't verify the content as we are not
// simulating file distribution in this test
- assertLargeConstant( name + "_Variable_1", searchFromStored, Optional.empty());
- assertLargeConstant( name + "_Variable", searchFromStored, Optional.empty());
}
finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
@@ -232,7 +236,6 @@ public class RankingExpressionWithOnnxTestCase {
assertNull("Constant overridden by macro is not added",
search.search().rankingConstants().get( name + "_Variable"));
- assertLargeConstant( name + "_Variable_1", search, Optional.of(10L));
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
@@ -245,38 +248,12 @@ public class RankingExpressionWithOnnxTestCase {
searchFromStored.compileRankProfile("my_profile", applicationDir.append("models"));
searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
assertNull("Constant overridden by macro is not added",
- searchFromStored.search().rankingConstants().get( name + "_Variable"));
- assertLargeConstant( name + "_Variable_1", searchFromStored, Optional.of(10L));
+ searchFromStored.search().rankingConstants().get( name + "_Variable"));
} finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
}
}
- /**
- * Verifies that the constant with the given name exists, and - only if an expected size is given -
- * that the content of the constant is available and has the expected size.
- */
- private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) {
- try {
- Path constantApplicationPackagePath = Path.fromString("models.generated/my_profile.mnist_softmax.onnx/constants").append(name + ".tbf");
- RankingConstant rankingConstant = search.search().rankingConstants().get(name);
- assertEquals(name, rankingConstant.getName());
- assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString()));
-
- if (expectedSize.isPresent()) {
- Path constantPath = applicationDir.append(constantApplicationPackagePath);
- assertTrue("Constant file '" + constantPath + "' has been written",
- constantPath.toFile().exists());
- Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(),
- GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile())));
- assertEquals(expectedSize.get().longValue(), deserializedConstant.size());
- }
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) {
return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder",
new StoringApplicationPackage(applicationDir));
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 450c66e04ef..14632a568ea 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -15,27 +15,22 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
+import com.yahoo.vespa.model.VespaModel;
+import com.yahoo.vespa.model.ml.ImportedModelTester;
import com.yahoo.yolean.Exceptions;
import org.junit.After;
import org.junit.Test;
-import java.io.BufferedInputStream;
import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
-import java.io.InputStream;
-import java.io.Reader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.Collections;
-import java.util.Iterator;
import java.util.List;
import java.util.Optional;
-import java.util.stream.Collectors;
+import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.*;
/**
@@ -56,12 +51,34 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
+ public void testGlobalTensorFlowModel() throws IOException {
+ ImportedModelTester tester = new ImportedModelTester(name, applicationDir);
+ VespaModel model = tester.createVespaModel();
+ assertLargeConstant(name + "_layer_Variable_1_read", model, Optional.of(10L));
+ assertLargeConstant(name + "_layer_Variable_read", model, Optional.of(7840L));
+
+ // At this point the expression is stored - copy application to another location which do not have a models dir
+ Path storedAppDir = applicationDir.append("copy");
+ try {
+ storedAppDir.toFile().mkdirs();
+ IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString());
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir);
+ VespaModel storedModel = storedTester.createVespaModel();
+ tester.assertLargeConstant(name + "_layer_Variable_1_read", storedModel, Optional.of(10L));
+ tester.assertLargeConstant(name + "_layer_Variable_read", storedModel, Optional.of(7840L));
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedAppDir.toFile());
+ }
+ }
+
+ @Test
public void testTensorFlowReference() {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"tensorflow('mnist_softmax/saved')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L));
- assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L));
}
@Test
@@ -71,8 +88,6 @@ public class RankingExpressionWithTensorFlowTestCase {
"constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }",
null);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L));
- assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L));
}
@Test
@@ -91,8 +106,6 @@ public class RankingExpressionWithTensorFlowTestCase {
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L));
- assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L));
}
@Test
@@ -105,8 +118,6 @@ public class RankingExpressionWithTensorFlowTestCase {
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L));
- assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L));
}
@Test
@@ -125,8 +136,6 @@ public class RankingExpressionWithTensorFlowTestCase {
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L));
- assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L));
}
@Test
@@ -134,8 +143,6 @@ public class RankingExpressionWithTensorFlowTestCase {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"5 + sum(tensorflow('mnist_softmax/saved'))");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
- assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L));
- assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L));
}
@Test
@@ -233,9 +240,6 @@ public class RankingExpressionWithTensorFlowTestCase {
"tensorflow('mnist_softmax/saved')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L));
- assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L));
-
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
try {
@@ -250,10 +254,6 @@ public class RankingExpressionWithTensorFlowTestCase {
"Placeholder",
storedApplication);
searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile");
- // Verify that the constants exists, but don't verify the content as we are not
- // simulating file distribution in this test
- assertLargeConstant(name + "_layer_Variable_1_read", searchFromStored, Optional.empty());
- assertLargeConstant(name + "_layer_Variable_read", searchFromStored, Optional.empty());
}
finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
@@ -287,7 +287,6 @@ public class RankingExpressionWithTensorFlowTestCase {
assertNull("Constant overridden by macro is not added",
search.search().rankingConstants().get("mnist_softmax_saved_layer_Variable_read"));
- assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L));
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
@@ -303,7 +302,6 @@ public class RankingExpressionWithTensorFlowTestCase {
searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child");
assertNull("Constant overridden by macro is not added",
searchFromStored.search().rankingConstants().get("mnist_softmax_saved_layer_Variable_read"));
- assertLargeConstant(name + "_layer_Variable_1_read", searchFromStored, Optional.of(10L));
}
finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
@@ -316,8 +314,6 @@ public class RankingExpressionWithTensorFlowTestCase {
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
"tensorflow('mnist_softmax/saved')");
search.assertFirstPhaseExpression(expression, "my_profile");
- assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L));
- assertLargeConstant(name + "_layer_Variable_read", search, Optional.of(7840L));
}
@Test
@@ -401,11 +397,11 @@ public class RankingExpressionWithTensorFlowTestCase {
* Verifies that the constant with the given name exists, and - only if an expected size is given -
* that the content of the constant is available and has the expected size.
*/
- private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) {
+ private void assertLargeConstant(String constantName, VespaModel model, Optional<Long> expectedSize) {
try {
- Path constantApplicationPackagePath = Path.fromString("models.generated/my_profile.mnist_softmax_saved/constants").append(name + ".tbf");
- RankingConstant rankingConstant = search.search().rankingConstants().get(name);
- assertEquals(name, rankingConstant.getName());
+ Path constantApplicationPackagePath = Path.fromString("models.generated/" + name + "/constants").append(constantName + ".tbf");
+ RankingConstant rankingConstant = model.rankingConstants().get(constantName);
+ assertEquals(constantName, rankingConstant.getName());
assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString()));
if (expectedSize.isPresent()) {
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
new file mode 100644
index 00000000000..2ae629562d0
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
@@ -0,0 +1,71 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.ml;
+
+import com.yahoo.config.model.ApplicationPackageTester;
+import com.yahoo.io.GrowableByteBuffer;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.serialization.TypedBinaryFormat;
+import com.yahoo.vespa.model.VespaModel;
+import org.xml.sax.SAXException;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.Optional;
+
+import static junit.framework.TestCase.assertTrue;
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Helper for testing of imported models.
+ * More duplicated functionality across tests on imported models should be moved here
+ *
+ * @author bratseth
+ */
+public class ImportedModelTester {
+
+ private final String modelName;
+ private final Path applicationDir;
+
+ public ImportedModelTester(String modelName, Path applicationDir) {
+ this.modelName = modelName;
+ this.applicationDir = applicationDir;
+ }
+
+ public VespaModel createVespaModel() {
+ try {
+ return new VespaModel(ApplicationPackageTester.create(applicationDir.toString()).app());
+ }
+ catch (SAXException | IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ /**
+ * Verifies that the constant with the given name exists, and - only if an expected size is given -
+ * that the content of the constant is available and has the expected size.
+ */
+ public void assertLargeConstant(String constantName, VespaModel model, Optional<Long> expectedSize) {
+ try {
+ Path constantApplicationPackagePath = Path.fromString("models.generated/" + modelName + "/constants").append(constantName + ".tbf");
+ RankingConstant rankingConstant = model.rankingConstants().get(constantName);
+ assertEquals(constantName, rankingConstant.getName());
+ assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString()));
+
+ if (expectedSize.isPresent()) {
+ Path constantPath = applicationDir.append(constantApplicationPackagePath);
+ assertTrue("Constant file '" + constantPath + "' has been written",
+ constantPath.toFile().exists());
+ Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(),
+ GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile())));
+ assertEquals(expectedSize.get().longValue(), deserializedConstant.size());
+ }
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+}
diff --git a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index ad2f62b7dc3..35e6642d7cb 100644
--- a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.config.model;
+package com.yahoo.vespa.model.ml;
import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.ModelsEvaluator;
@@ -18,7 +18,6 @@ import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.container.ContainerCluster;
import org.junit.After;
import org.junit.Test;
-import org.xml.sax.SAXException;
import java.io.IOException;
import java.util.Set;
@@ -41,10 +40,9 @@ public class ModelEvaluationTest {
}
@Test
- public void testMl_ServingApplication() throws SAXException, IOException {
- ApplicationPackageTester tester = ApplicationPackageTester.create(appDir.toString());
- VespaModel model = new VespaModel(tester.app());
- assertHasMlModels(model);
+ public void testMl_ServingApplication() throws IOException {
+ ImportedModelTester tester = new ImportedModelTester("ml_serving", appDir);
+ assertHasMlModels(tester.createVespaModel());
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedAppDir = appDir.append("copy");
@@ -53,9 +51,8 @@ public class ModelEvaluationTest {
IOUtils.copy(appDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString());
IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
- ApplicationPackageTester storedTester = ApplicationPackageTester.create(storedAppDir.toString());
- VespaModel storedModel = new VespaModel(storedTester.app());
- assertHasMlModels(storedModel);
+ ImportedModelTester storedTester = new ImportedModelTester("ml_serving", storedAppDir);
+ assertHasMlModels(storedTester.createVespaModel());
}
finally {
IOUtils.recursiveDeleteDir(storedAppDir.toFile());