summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-06-01 11:38:07 +0200
committerLester Solbakken <lesters@oath.com>2018-06-01 11:38:07 +0200
commitb10e7ef3eabff36c751b1518d895c0a7595f7630 (patch)
tree3b59abe02c0b1b733e1089d5729b3a306dfdde89 /config-model
parent07b3d8babae871ec17c18c83c98109a6e98e9f53 (diff)
Fix ONNX ranking feature signature
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java36
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java19
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java18
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java18
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java10
5 files changed, 53 insertions, 48 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java
index f4d944313ac..8c976a5bb0f 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java
@@ -67,7 +67,7 @@ abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfil
ImportedModel.Signature signature = chooseSignature(model, store.arguments().signature());
String output = chooseOutput(signature, store.arguments().output());
if (signature.skippedOutputs().containsKey(output)) {
- String message = "Could not import TensorFlow model output '" + output + "'";
+ String message = "Could not import model output '" + output + "'";
if (!signature.skippedOutputs().get(output).isEmpty()) {
message += ": " + signature.skippedOutputs().get(output);
}
@@ -193,7 +193,7 @@ abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfil
private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
if (profile.getMacros().containsKey(macroName)) {
- throw new IllegalArgumentException("Generated TensorFlow macro '" + macroName + "' already exists.");
+ throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists.");
}
profile.addMacro(macroName, false); // todo: inline if only used once
RankProfile.Macro macro = profile.getMacros().get(macroName);
@@ -425,9 +425,9 @@ abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfil
private final ApplicationPackage application;
private final FeatureArguments arguments;
- public ModelStore(ApplicationPackage application, Arguments arguments) {
+ public ModelStore(ApplicationPackage application, FeatureArguments arguments) {
this.application = application;
- this.arguments = new FeatureArguments(arguments);
+ this.arguments = arguments;
}
public FeatureArguments arguments() { return arguments; }
@@ -595,25 +595,13 @@ abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfil
}
- /** Encapsulates the 1, 2 or 3 arguments to a tensorflow feature */
- static class FeatureArguments {
+ /** Encapsulates the arguments to the import feature */
+ static abstract class FeatureArguments {
- private final Path modelPath;
+ Path modelPath;
/** Optional arguments */
- private final Optional<String> signature, output;
-
- public FeatureArguments(Arguments arguments) {
- if (arguments.isEmpty())
- throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
- "the tensorflow model directory under [application]/models");
- if (arguments.expressions().size() > 3)
- throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments");
-
- modelPath = Path.fromString(asString(arguments.expressions().get(0)));
- signature = optionalArgument(1, arguments);
- output = optionalArgument(2, arguments);
- }
+ Optional<String> signature, output;
/** Returns modelPath with slashes replaced by underscores */
public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); }
@@ -653,22 +641,22 @@ abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfil
return fileName.toString();
}
- private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
+ Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
if (argumentIndex >= arguments.expressions().size())
return Optional.empty();
return Optional.of(asString(arguments.expressions().get(argumentIndex)));
}
- private String asString(ExpressionNode node) {
+ String asString(ExpressionNode node) {
if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node);
+ throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
return stripQuotes(((ConstantNode)node).sourceString());
}
private String stripQuotes(String s) {
if ( ! isQuoteSign(s.codePointAt(0))) return s;
if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
- throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote");
+ throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
return s.substring(1, s.length()-1);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index f0cb0516908..44eeb364603 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -7,6 +7,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
@@ -14,6 +15,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
+import java.util.Optional;
/**
* Replaces instances of the onnx(model-path, output)
@@ -44,7 +46,8 @@ public class OnnxFeatureConverter extends MLImportFeatureConverter {
if ( ! feature.getName().equals("onnx")) return feature;
try {
- ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments());
+ FeatureArguments arguments = new OnnxFeatureArguments(feature.getArguments());
+ ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
if ( ! store.hasStoredModel()) // not converted yet - access Onnx model files
return transformFromOnnxModel(store, context.rankProfile(), context.queryProfiles());
else
@@ -64,4 +67,18 @@ public class OnnxFeatureConverter extends MLImportFeatureConverter {
return transformFromImportedModel(model, store, profile, queryProfiles);
}
+ static class OnnxFeatureArguments extends FeatureArguments {
+ public OnnxFeatureArguments(Arguments arguments) {
+ if (arguments.isEmpty())
+ throw new IllegalArgumentException("An onnx node must take an argument pointing to " +
+ "the tensorflow model directory under [application]/models");
+ if (arguments.expressions().size() > 3)
+ throw new IllegalArgumentException("An onnx feature can have at most 2 arguments");
+
+ modelPath = Path.fromString(asString(arguments.expressions().get(0)));
+ output = optionalArgument(1, arguments);
+ signature = Optional.of("default");
+ }
+ }
+
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 75d72111e9a..27e1ad51b33 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -6,6 +6,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
@@ -42,7 +43,8 @@ public class TensorFlowFeatureConverter extends MLImportFeatureConverter {
if ( ! feature.getName().equals("tensorflow")) return feature;
try {
- ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments());
+ FeatureArguments arguments = new TensorFlowFeatureArguments(feature.getArguments());
+ ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files
return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles());
else
@@ -62,4 +64,18 @@ public class TensorFlowFeatureConverter extends MLImportFeatureConverter {
return transformFromImportedModel(model, store, profile, queryProfiles);
}
+ static class TensorFlowFeatureArguments extends FeatureArguments {
+ public TensorFlowFeatureArguments(Arguments arguments) {
+ if (arguments.isEmpty())
+ throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
+ "the tensorflow model directory under [application]/models");
+ if (arguments.expressions().size() > 3)
+ throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments");
+
+ modelPath = Path.fromString(asString(arguments.expressions().get(0)));
+ signature = optionalArgument(1, arguments);
+ output = optionalArgument(2, arguments);
+ }
+ }
+
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index d6d6b952909..d9beab6e2f2 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -37,15 +37,6 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
- public void testOnnxReference() throws ParseException {
- RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx')");
- search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
- assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
- }
-
- @Test
public void testOnnxReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
"onnx('mnist_softmax.onnx')",
@@ -122,13 +113,6 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
- public void testOnnxReferenceSpecifyingOutput() {
- RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx', 'add')");
- search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- }
-
- @Test
public void testOnnxReferenceMissingMacro() throws ParseException {
try {
RankProfileSearchFixture search = new RankProfileSearchFixture(
@@ -180,7 +164,7 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx','y'): " +
- "Model does not have the output 'y'",
+ "Model does not have the specified output 'y'",
Exceptions.toMessageString(expected));
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index c9115342965..594f869cd3f 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -162,7 +162,7 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved'): " +
- "Model refers placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
+ "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
"not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
}
@@ -179,7 +179,7 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved'): " +
- "Model refers placeholder 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " +
+ "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " +
"but this macro returns tensor(d0[2],d5[10])",
Exceptions.toMessageString(expected));
}
@@ -334,9 +334,9 @@ public class RankingExpressionWithTensorFlowTestCase {
"input",
application);
search.assertFirstPhaseExpression(expression, "my_profile");
- assertSmallConstant("dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
- search.assertMacro(macroExpression1, "imported_macro__dnn_hidden1_add", "my_profile");
- search.assertMacro(macroExpression2, "imported_macro__dnn_hidden2_add", "my_profile");
+ assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
+ search.assertMacro(macroExpression1, "imported_ml_macro__dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression2, "imported_ml_macro__dnn_hidden2_add", "my_profile");
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");