diff options
author | Lester Solbakken <lesters@oath.com> | 2018-06-01 11:38:07 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-06-01 11:38:07 +0200 |
commit | b10e7ef3eabff36c751b1518d895c0a7595f7630 (patch) | |
tree | 3b59abe02c0b1b733e1089d5729b3a306dfdde89 /config-model | |
parent | 07b3d8babae871ec17c18c83c98109a6e98e9f53 (diff) |
Fix ONNX ranking feature signature
Diffstat (limited to 'config-model')
5 files changed, 53 insertions, 48 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java index f4d944313ac..8c976a5bb0f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java @@ -67,7 +67,7 @@ abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfil ImportedModel.Signature signature = chooseSignature(model, store.arguments().signature()); String output = chooseOutput(signature, store.arguments().output()); if (signature.skippedOutputs().containsKey(output)) { - String message = "Could not import TensorFlow model output '" + output + "'"; + String message = "Could not import model output '" + output + "'"; if (!signature.skippedOutputs().get(output).isEmpty()) { message += ": " + signature.skippedOutputs().get(output); } @@ -193,7 +193,7 @@ abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfil private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { if (profile.getMacros().containsKey(macroName)) { - throw new IllegalArgumentException("Generated TensorFlow macro '" + macroName + "' already exists."); + throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists."); } profile.addMacro(macroName, false); // todo: inline if only used once RankProfile.Macro macro = profile.getMacros().get(macroName); @@ -425,9 +425,9 @@ abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfil private final ApplicationPackage application; private final FeatureArguments arguments; - public ModelStore(ApplicationPackage application, Arguments arguments) { + public ModelStore(ApplicationPackage application, FeatureArguments arguments) { this.application = application; - this.arguments = new FeatureArguments(arguments); + this.arguments = arguments; } public FeatureArguments arguments() { return arguments; } @@ -595,25 +595,13 @@ abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfil } - /** Encapsulates the 1, 2 or 3 arguments to a tensorflow feature */ - static class FeatureArguments { + /** Encapsulates the arguments to the import feature */ + static abstract class FeatureArguments { - private final Path modelPath; + Path modelPath; /** Optional arguments */ - private final Optional<String> signature, output; - - public FeatureArguments(Arguments arguments) { - if (arguments.isEmpty()) - throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + - "the tensorflow model directory under [application]/models"); - if (arguments.expressions().size() > 3) - throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments"); - - modelPath = Path.fromString(asString(arguments.expressions().get(0))); - signature = optionalArgument(1, arguments); - output = optionalArgument(2, arguments); - } + Optional<String> signature, output; /** Returns modelPath with slashes replaced by underscores */ public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); } @@ -653,22 +641,22 @@ abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfil return fileName.toString(); } - private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { + Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { if (argumentIndex >= arguments.expressions().size()) return Optional.empty(); return Optional.of(asString(arguments.expressions().get(argumentIndex))); } - private String asString(ExpressionNode node) { + String asString(ExpressionNode node) { if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node); + throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); return stripQuotes(((ConstantNode)node).sourceString()); } private String stripQuotes(String s) { if ( ! isQuoteSign(s.codePointAt(0))) return s; if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) - throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote"); + throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); return s.substring(1, s.length()-1); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index f0cb0516908..44eeb364603 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -7,6 +7,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; @@ -14,6 +15,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import java.io.UncheckedIOException; import java.util.HashMap; import java.util.Map; +import java.util.Optional; /** * Replaces instances of the onnx(model-path, output) @@ -44,7 +46,8 @@ public class OnnxFeatureConverter extends MLImportFeatureConverter { if ( ! feature.getName().equals("onnx")) return feature; try { - ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments()); + FeatureArguments arguments = new OnnxFeatureArguments(feature.getArguments()); + ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments); if ( ! store.hasStoredModel()) // not converted yet - access Onnx model files return transformFromOnnxModel(store, context.rankProfile(), context.queryProfiles()); else @@ -64,4 +67,18 @@ public class OnnxFeatureConverter extends MLImportFeatureConverter { return transformFromImportedModel(model, store, profile, queryProfiles); } + static class OnnxFeatureArguments extends FeatureArguments { + public OnnxFeatureArguments(Arguments arguments) { + if (arguments.isEmpty()) + throw new IllegalArgumentException("An onnx node must take an argument pointing to " + + "the tensorflow model directory under [application]/models"); + if (arguments.expressions().size() > 3) + throw new IllegalArgumentException("An onnx feature can have at most 2 arguments"); + + modelPath = Path.fromString(asString(arguments.expressions().get(0))); + output = optionalArgument(1, arguments); + signature = Optional.of("default"); + } + } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 75d72111e9a..27e1ad51b33 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -6,6 +6,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; @@ -42,7 +43,8 @@ public class TensorFlowFeatureConverter extends MLImportFeatureConverter { if ( ! feature.getName().equals("tensorflow")) return feature; try { - ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments()); + FeatureArguments arguments = new TensorFlowFeatureArguments(feature.getArguments()); + ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments); if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles()); else @@ -62,4 +64,18 @@ public class TensorFlowFeatureConverter extends MLImportFeatureConverter { return transformFromImportedModel(model, store, profile, queryProfiles); } + static class TensorFlowFeatureArguments extends FeatureArguments { + public TensorFlowFeatureArguments(Arguments arguments) { + if (arguments.isEmpty()) + throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + + "the tensorflow model directory under [application]/models"); + if (arguments.expressions().size() > 3) + throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments"); + + modelPath = Path.fromString(asString(arguments.expressions().get(0))); + signature = optionalArgument(1, arguments); + output = optionalArgument(2, arguments); + } + } + } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index d6d6b952909..d9beab6e2f2 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -37,15 +37,6 @@ public class RankingExpressionWithOnnxTestCase { } @Test - public void testOnnxReference() throws ParseException { - RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", - "onnx('mnist_softmax.onnx')"); - search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); - assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); - } - - @Test public void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "onnx('mnist_softmax.onnx')", @@ -122,13 +113,6 @@ public class RankingExpressionWithOnnxTestCase { } @Test - public void testOnnxReferenceSpecifyingOutput() { - RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", - "onnx('mnist_softmax.onnx', 'add')"); - search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - } - - @Test public void testOnnxReferenceMissingMacro() throws ParseException { try { RankProfileSearchFixture search = new RankProfileSearchFixture( @@ -180,7 +164,7 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx','y'): " + - "Model does not have the output 'y'", + "Model does not have the specified output 'y'", Exceptions.toMessageString(expected)); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index c9115342965..594f869cd3f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -162,7 +162,7 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved'): " + - "Model refers placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + + "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + "not present in rank profile 'my_profile'", Exceptions.toMessageString(expected)); } @@ -179,7 +179,7 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved'): " + - "Model refers placeholder 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " + + "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " + "but this macro returns tensor(d0[2],d5[10])", Exceptions.toMessageString(expected)); } @@ -334,9 +334,9 @@ public class RankingExpressionWithTensorFlowTestCase { "input", application); search.assertFirstPhaseExpression(expression, "my_profile"); - assertSmallConstant("dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); - search.assertMacro(macroExpression1, "imported_macro__dnn_hidden1_add", "my_profile"); - search.assertMacro(macroExpression2, "imported_macro__dnn_hidden2_add", "my_profile"); + assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); + search.assertMacro(macroExpression1, "imported_ml_macro__dnn_hidden1_add", "my_profile"); + search.assertMacro(macroExpression2, "imported_ml_macro__dnn_hidden2_add", "my_profile"); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); |