summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-10 16:41:51 +0200
committerJon Bratseth <bratseth@oath.com>2018-08-10 16:41:51 +0200
commit7ccae75e7f026df580debe58156aa0ee4585714e (patch)
tree4dc1ed0e09aef3474f402814413972e9bfc43e77 /config-model
parente1eebf1d17ae7257b6d73a737eaebb21b3b6eaaa (diff)
Remove duplication and inheritance
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java32
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java22
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java24
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java2
4 files changed, 41 insertions, 39 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
index 828d8228db1..867740c7912 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
@@ -69,7 +69,7 @@ public class ConvertedModel {
ModelImporter modelImporter,
Map<Path, ImportedModel> importedModels) {
ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
- if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files
+ if ( ! store.hasStoredModel()) // not converted yet - access from models/ directory
convertedExpression = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter, importedModels);
else
convertedExpression = transformFromStoredModel(store, context.rankProfile());
@@ -88,10 +88,10 @@ public class ConvertedModel {
public ExpressionNode expression() { return convertedExpression; }
- ExpressionNode transformFromImportedModel(ImportedModel model,
- ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles) {
+ private ExpressionNode transformFromImportedModel(ImportedModel model,
+ ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles) {
// Add constants
Set<String> constantsReplacedByMacros = new HashSet<>();
model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
@@ -631,13 +631,25 @@ public class ConvertedModel {
}
/** Encapsulates the arguments to the import feature */
- static abstract class FeatureArguments {
+ static class FeatureArguments {
Path modelPath;
/** Optional arguments */
Optional<String> signature, output;
+ public FeatureArguments(Arguments arguments) {
+ this(Path.fromString(asString(arguments.expressions().get(0))),
+ optionalArgument(1, arguments),
+ optionalArgument(2, arguments));
+ }
+
+ public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) {
+ this.modelPath = modelPath;
+ this.signature = signature;
+ this.output = output;
+ }
+
/** Returns modelPath with slashes replaced by underscores */
public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); }
@@ -676,26 +688,26 @@ public class ConvertedModel {
return fileName.toString();
}
- Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
+ private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
if (argumentIndex >= arguments.expressions().size())
return Optional.empty();
return Optional.of(asString(arguments.expressions().get(argumentIndex)));
}
- String asString(ExpressionNode node) {
+ private static String asString(ExpressionNode node) {
if ( ! (node instanceof ConstantNode))
throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
return stripQuotes(((ConstantNode)node).sourceString());
}
- private String stripQuotes(String s) {
+ private static String stripQuotes(String s) {
if ( ! isQuoteSign(s.codePointAt(0))) return s;
if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
return s.substring(1, s.length()-1);
}
- private boolean isQuoteSign(int c) {
+ private static boolean isQuoteSign(int c) {
return c == '\'' || c == '"';
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index ce656b74b54..d31ffefde65 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -47,8 +47,8 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
if ( ! feature.getName().equals("onnx")) return feature;
try {
- ConvertedModel.FeatureArguments arguments = new OnnxFeatureArguments(feature.getArguments());
- ConvertedModel convertedModel = new ConvertedModel(arguments, context, onnxImporter, importedModels);
+ ConvertedModel convertedModel = new ConvertedModel(asFeatureArguments(feature.getArguments()),
+ context, onnxImporter, importedModels);
return convertedModel.expression();
}
catch (IllegalArgumentException | UncheckedIOException e) {
@@ -56,18 +56,14 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
}
}
- static class OnnxFeatureArguments extends ConvertedModel.FeatureArguments {
- public OnnxFeatureArguments(Arguments arguments) {
- if (arguments.isEmpty())
- throw new IllegalArgumentException("An onnx node must take an argument pointing to " +
- "the tensorflow model directory under [application]/models");
- if (arguments.expressions().size() > 3)
- throw new IllegalArgumentException("An onnx feature can have at most 2 arguments");
+ private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) {
+ if (arguments.isEmpty())
+ throw new IllegalArgumentException("An onnx node must take an argument pointing to " +
+ "the onnx model directory under [application]/models");
+ if (arguments.expressions().size() > 3)
+ throw new IllegalArgumentException("An onnx feature can have at most 2 arguments");
- modelPath = Path.fromString(asString(arguments.expressions().get(0)));
- output = optionalArgument(1, arguments);
- signature = Optional.of("default");
- }
+ return new ConvertedModel.FeatureArguments(arguments);
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index b2c096d4e95..d28299b1d30 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -44,8 +44,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
if ( ! feature.getName().equals("tensorflow")) return feature;
try {
- ConvertedModel.FeatureArguments arguments = new TensorFlowFeatureArguments(feature.getArguments());
- ConvertedModel convertedModel = new ConvertedModel(arguments, context, tensorFlowImporter, importedModels);
+ ConvertedModel convertedModel = new ConvertedModel(asFeatureArguments(feature.getArguments()),
+ context, tensorFlowImporter, importedModels);
return convertedModel.expression();
}
catch (IllegalArgumentException | UncheckedIOException e) {
@@ -53,20 +53,14 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
}
- static class TensorFlowFeatureArguments extends ConvertedModel.FeatureArguments {
-
- public TensorFlowFeatureArguments(Arguments arguments) {
- if (arguments.isEmpty())
- throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
- "the tensorflow model directory under [application]/models");
- if (arguments.expressions().size() > 3)
- throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments");
-
- modelPath = Path.fromString(asString(arguments.expressions().get(0)));
- signature = optionalArgument(1, arguments);
- output = optionalArgument(2, arguments);
- }
+ private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) {
+ if (arguments.isEmpty())
+ throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
+ "the tensorflow model directory under [application]/models");
+ if (arguments.expressions().size() > 3)
+ throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments");
+ return new ConvertedModel.FeatureArguments(arguments);
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index d9beab6e2f2..b2ef08dcc36 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -164,7 +164,7 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx','y'): " +
- "Model does not have the specified output 'y'",
+ "Model does not have the specified signature 'y'",
Exceptions.toMessageString(expected));
}
}