From a86919f38df73174c204f4705e3cd9d316a42553 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Mon, 20 Aug 2018 09:17:01 +0200 Subject: Add deactivated capability to import all models --- .../expressiontransforms/ConvertedModel.java | 23 +++++++++++++++++----- .../expressiontransforms/OnnxFeatureConverter.java | 2 +- .../TensorFlowFeatureConverter.java | 2 +- .../RankingExpressionWithOnnxTestCase.java | 3 ++- .../RankingExpressionWithTensorFlowTestCase.java | 10 ++++++---- 5 files changed, 28 insertions(+), 12 deletions(-) (limited to 'config-model') diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java index 3bd96c9db26..d85d0983509 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java @@ -77,12 +77,13 @@ public class ConvertedModel { public ConvertedModel(Path modelPath, RankProfileTransformContext context, - ModelImporter modelImporter) { + ModelImporter modelImporter, + FeatureArguments arguments) { // TODO: Remove this.modelPath = modelPath; this.modelName = toModelName(modelPath); ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), modelPath); if ( ! store.hasStoredModel()) // not converted yet - access from models/ directory - expressions = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter); + expressions = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter, arguments); else expressions = transformFromStoredModel(store, context.rankProfile()); } @@ -90,9 +91,10 @@ public class ConvertedModel { private Map importModel(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, - ModelImporter modelImporter) { + ModelImporter modelImporter, + FeatureArguments arguments) { ImportedModel model = modelImporter.importModel(store.modelFiles.modelName(), store.modelDir()); - return transformFromImportedModel(model, store, profile, queryProfiles); + return transformFromImportedModel(model, store, profile, queryProfiles, arguments); } /** Returns the expression matching the given arguments */ @@ -132,7 +134,8 @@ public class ConvertedModel { private Map transformFromImportedModel(ImportedModel model, ModelStore store, RankProfile profile, - QueryProfileRegistry queryProfiles) { + QueryProfileRegistry queryProfiles, + FeatureArguments arguments) { // Add constants Set constantsReplacedByMacros = new HashSet<>(); model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); @@ -145,7 +148,10 @@ public class ConvertedModel { // Add expressions Map expressions = new HashMap<>(); for (Map.Entry signatureEntry : model.signatures().entrySet()) { + if ( ! matches(signatureEntry.getValue(), arguments, Optional.empty())) continue; + for (Map.Entry outputEntry : signatureEntry.getValue().outputs().entrySet()) { + if ( ! matches(signatureEntry.getValue(), arguments, Optional.of(outputEntry.getKey()))) continue; addExpression(model.expressions().get(outputEntry.getValue()), modelName + "." + signatureEntry.getKey() + "." + outputEntry.getKey(), constantsReplacedByMacros, @@ -185,6 +191,13 @@ public class ConvertedModel { return expressions; } + private boolean matches(ImportedModel.Signature signature, FeatureArguments arguments, Optional output) { + if ( ! modelName.equals(arguments.modelName)) return false; + if ( arguments.signature.isPresent() && ! signature.name().equals(arguments.signature().get())) return false; + if (output.isPresent() && arguments.output().isPresent() && ! output.get().matches(arguments.output().get())) return false; + return true; + } + private void addExpression(RankingExpression expression, String expressionName, Set constantsReplacedByMacros, diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index 0dec12c4749..97395c1aad3 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -48,7 +48,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer new ConvertedModel(path, context, onnxImporter)); + ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, onnxImporter, new ConvertedModel.FeatureArguments(feature.getArguments()))); return convertedModel.expression(asFeatureArguments(feature.getArguments())); } catch (IllegalArgumentException | UncheckedIOException e) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 585adc0c0d4..b3778e2af84 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -45,7 +45,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer new ConvertedModel(path, context, tensorFlowImporter)); + ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, tensorFlowImporter, new ConvertedModel.FeatureArguments(feature.getArguments()))); return convertedModel.expression(asFeatureArguments(feature.getArguments())); } catch (IllegalArgumentException | UncheckedIOException e) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 80629172fe3..a7465fa9695 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -165,7 +165,8 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx','y'): " + - "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: mnist_softmax.onnx.default.add", + "No expressions available in model 'mnist_softmax.onnx'", +// "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: mnist_softmax.onnx.default.add", Exceptions.toMessageString(expected)); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 2d33cee5820..29859817736 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -200,8 +200,9 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved','serving_defaultz'): " + - "No expressions named 'serving_defaultz' in model 'mnist_softmax/saved'. "+ - "Available expressions: mnist_softmax_saved.serving_default.y", + "No expressions available in model 'mnist_softmax_saved'", +// "No expressions named 'serving_defaultz' in model 'mnist_softmax/saved'. "+ +// "Available expressions: mnist_softmax_saved.serving_default.y", Exceptions.toMessageString(expected)); } } @@ -217,8 +218,9 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved','serving_default','x'): " + - "No expression 'mnist_softmax_saved.serving_default.x' in model 'mnist_softmax/saved'. " + - "Available expressions: mnist_softmax_saved.serving_default.y", + "No expressions available in model 'mnist_softmax_saved'", +// "No expression 'mnist_softmax_saved.serving_default.x' in model 'mnist_softmax/saved'. " + +// "Available expressions: mnist_softmax_saved.serving_default.y", Exceptions.toMessageString(expected)); } } -- cgit v1.2.3