summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-20 09:17:01 +0200
committerJon Bratseth <bratseth@oath.com>2018-08-20 09:17:01 +0200
commita86919f38df73174c204f4705e3cd9d316a42553 (patch)
treeaffd290ab13c37f5ea1173c421c771c317700e3b /config-model
parent03e88f27bc2278fb711a8c4a8a85763b23067348 (diff)
Add deactivated capability to import all models
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java23
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java10
5 files changed, 28 insertions, 12 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
index 3bd96c9db26..d85d0983509 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
@@ -77,12 +77,13 @@ public class ConvertedModel {
public ConvertedModel(Path modelPath,
RankProfileTransformContext context,
- ModelImporter modelImporter) {
+ ModelImporter modelImporter,
+ FeatureArguments arguments) { // TODO: Remove
this.modelPath = modelPath;
this.modelName = toModelName(modelPath);
ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), modelPath);
if ( ! store.hasStoredModel()) // not converted yet - access from models/ directory
- expressions = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter);
+ expressions = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter, arguments);
else
expressions = transformFromStoredModel(store, context.rankProfile());
}
@@ -90,9 +91,10 @@ public class ConvertedModel {
private Map<String, RankingExpression> importModel(ModelStore store,
RankProfile profile,
QueryProfileRegistry queryProfiles,
- ModelImporter modelImporter) {
+ ModelImporter modelImporter,
+ FeatureArguments arguments) {
ImportedModel model = modelImporter.importModel(store.modelFiles.modelName(), store.modelDir());
- return transformFromImportedModel(model, store, profile, queryProfiles);
+ return transformFromImportedModel(model, store, profile, queryProfiles, arguments);
}
/** Returns the expression matching the given arguments */
@@ -132,7 +134,8 @@ public class ConvertedModel {
private Map<String, RankingExpression> transformFromImportedModel(ImportedModel model,
ModelStore store,
RankProfile profile,
- QueryProfileRegistry queryProfiles) {
+ QueryProfileRegistry queryProfiles,
+ FeatureArguments arguments) {
// Add constants
Set<String> constantsReplacedByMacros = new HashSet<>();
model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
@@ -145,7 +148,10 @@ public class ConvertedModel {
// Add expressions
Map<String, RankingExpression> expressions = new HashMap<>();
for (Map.Entry<String, ImportedModel.Signature> signatureEntry : model.signatures().entrySet()) {
+ if ( ! matches(signatureEntry.getValue(), arguments, Optional.empty())) continue;
+
for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) {
+ if ( ! matches(signatureEntry.getValue(), arguments, Optional.of(outputEntry.getKey()))) continue;
addExpression(model.expressions().get(outputEntry.getValue()),
modelName + "." + signatureEntry.getKey() + "." + outputEntry.getKey(),
constantsReplacedByMacros,
@@ -185,6 +191,13 @@ public class ConvertedModel {
return expressions;
}
+ private boolean matches(ImportedModel.Signature signature, FeatureArguments arguments, Optional<String> output) {
+ if ( ! modelName.equals(arguments.modelName)) return false;
+ if ( arguments.signature.isPresent() && ! signature.name().equals(arguments.signature().get())) return false;
+ if (output.isPresent() && arguments.output().isPresent() && ! output.get().matches(arguments.output().get())) return false;
+ return true;
+ }
+
private void addExpression(RankingExpression expression,
String expressionName,
Set<String> constantsReplacedByMacros,
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index 0dec12c4749..97395c1aad3 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -48,7 +48,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
try {
Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
- ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, onnxImporter));
+ ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, onnxImporter, new ConvertedModel.FeatureArguments(feature.getArguments())));
return convertedModel.expression(asFeatureArguments(feature.getArguments()));
}
catch (IllegalArgumentException | UncheckedIOException e) {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 585adc0c0d4..b3778e2af84 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -45,7 +45,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
try {
Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
- ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, tensorFlowImporter));
+ ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, tensorFlowImporter, new ConvertedModel.FeatureArguments(feature.getArguments())));
return convertedModel.expression(asFeatureArguments(feature.getArguments()));
}
catch (IllegalArgumentException | UncheckedIOException e) {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 80629172fe3..a7465fa9695 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -165,7 +165,8 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx','y'): " +
- "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: mnist_softmax.onnx.default.add",
+ "No expressions available in model 'mnist_softmax.onnx'",
+// "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: mnist_softmax.onnx.default.add",
Exceptions.toMessageString(expected));
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 2d33cee5820..29859817736 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -200,8 +200,9 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved','serving_defaultz'): " +
- "No expressions named 'serving_defaultz' in model 'mnist_softmax/saved'. "+
- "Available expressions: mnist_softmax_saved.serving_default.y",
+ "No expressions available in model 'mnist_softmax_saved'",
+// "No expressions named 'serving_defaultz' in model 'mnist_softmax/saved'. "+
+// "Available expressions: mnist_softmax_saved.serving_default.y",
Exceptions.toMessageString(expected));
}
}
@@ -217,8 +218,9 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved','serving_default','x'): " +
- "No expression 'mnist_softmax_saved.serving_default.x' in model 'mnist_softmax/saved'. " +
- "Available expressions: mnist_softmax_saved.serving_default.y",
+ "No expressions available in model 'mnist_softmax_saved'",
+// "No expression 'mnist_softmax_saved.serving_default.x' in model 'mnist_softmax/saved'. " +
+// "Available expressions: mnist_softmax_saved.serving_default.y",
Exceptions.toMessageString(expected));
}
}