diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-09-11 14:26:16 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-09-11 14:26:16 +0200 |
commit | 4ca726bd303ff8ddd59e823bb6ecc2016615791c (patch) | |
tree | cd0a2f02f5b53c7ae29ab41d971e380e57f73966 | |
parent | bd90229a30ee1c8206d758ac38b7dbea6215032b (diff) |
Refactor
10 files changed, 140 insertions, 124 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index 229ae0ebaaf..4cd8c6ac92b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -1,5 +1,4 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.path.Path; @@ -8,12 +7,12 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.vespa.model.ml.ConvertedModel; +import com.yahoo.vespa.model.ml.FeatureArguments; -import java.io.File; import java.io.UncheckedIOException; import java.util.HashMap; import java.util.Map; -import java.util.Optional; /** * Replaces instances of the onnx(model-path, output) @@ -43,7 +42,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans try { // TODO: Put modelPath in FeatureArguments instead - Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); + Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = convertedOnnxModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context)); return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); @@ -53,14 +52,14 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans } } - private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) { + private FeatureArguments asFeatureArguments(Arguments arguments) { if (arguments.isEmpty()) throw new IllegalArgumentException("An onnx node must take an argument pointing to " + "the onnx model directory under [application]/models"); if (arguments.expressions().size() > 3) throw new IllegalArgumentException("An onnx feature can have at most 2 arguments"); - return new ConvertedModel.FeatureArguments(arguments); + return new FeatureArguments(arguments); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index bcb8ef1521d..72cfde0a566 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -7,8 +7,9 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.vespa.model.ml.ConvertedModel; +import com.yahoo.vespa.model.ml.FeatureArguments; -import java.io.File; import java.io.UncheckedIOException; import java.util.HashMap; import java.util.Map; @@ -39,7 +40,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil if ( ! feature.getName().equals("tensorflow")) return feature; try { - Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); + Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context)); return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); @@ -49,14 +50,14 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } - private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) { + private FeatureArguments asFeatureArguments(Arguments arguments) { if (arguments.isEmpty()) throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + "the tensorflow model directory under [application]/models"); if (arguments.expressions().size() > 3) throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments"); - return new ConvertedModel.FeatureArguments(arguments); + return new FeatureArguments(arguments); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java index b4a5069b9d6..8591bf16d07 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java @@ -2,13 +2,13 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.path.Path; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.integration.ml.XGBoostImporter; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.vespa.model.ml.ConvertedModel; +import com.yahoo.vespa.model.ml.FeatureArguments; import java.io.UncheckedIOException; import java.util.HashMap; @@ -41,7 +41,7 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr if ( ! feature.getName().equals("xgboost")) return feature; try { - Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); + Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = convertedXGBoostModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context)); return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); @@ -50,11 +50,11 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr } } - private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) { + private FeatureArguments asFeatureArguments(Arguments arguments) { if (arguments.size() != 1) throw new IllegalArgumentException("An xgboost node must take a single argument pointing to " + "the xgboost model directory under [application]/models"); - return new ConvertedModel.FeatureArguments(arguments); + return new FeatureArguments(arguments); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index 1779fa57e8d..282e5a29962 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -2,7 +2,6 @@ package com.yahoo.vespa.model; import com.google.common.collect.ImmutableList; -import com.yahoo.collections.Pair; import com.yahoo.config.ConfigBuilder; import com.yahoo.config.ConfigInstance; import com.yahoo.config.ConfigInstance.Builder; @@ -33,7 +32,7 @@ import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.RankingConstants; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RankProfileList; -import com.yahoo.searchdefinition.expressiontransforms.ConvertedModel; +import com.yahoo.vespa.model.ml.ConvertedModel; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; @@ -54,6 +53,7 @@ import com.yahoo.vespa.model.content.cluster.ContentCluster; import com.yahoo.vespa.model.filedistribution.FileDistributionConfigProducer; import com.yahoo.vespa.model.filedistribution.FileDistributor; import com.yahoo.vespa.model.generic.service.ServiceCluster; +import com.yahoo.vespa.model.ml.ModelName; import com.yahoo.vespa.model.routing.Routing; import com.yahoo.vespa.model.search.AbstractSearchCluster; import com.yahoo.vespa.model.utils.internal.ReflectionUtil; @@ -233,7 +233,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri for (ImportedModel model : importedModels.all()) { RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry); rankProfileRegistry.add(profile); - ConvertedModel convertedModel = ConvertedModel.fromSource(new ConvertedModel.ModelName(model.name()), + ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()), model.name(), profile, queryProfiles, model); for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) { profile.addMacro(entry.getKey(), false).setRankingExpression(entry.getValue()); @@ -246,7 +246,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri String modelName = generatedModelDir.getPath().last(); RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry); rankProfileRegistry.add(profile); - ConvertedModel convertedModel = ConvertedModel.fromStore(new ConvertedModel.ModelName(modelName), modelName, profile); + ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile); for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) { profile.addMacro(entry.getKey(), false).setRankingExpression(entry.getValue()); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 050a5226324..1f27b9843cd 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -1,4 +1,5 @@ -package com.yahoo.searchdefinition.expressiontransforms; +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.ml; import com.google.common.collect.ImmutableMap; import com.yahoo.collections.Pair; @@ -11,6 +12,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.FeatureNames; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; @@ -18,7 +20,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -728,101 +729,4 @@ public class ConvertedModel { } - /** Encapsulates the arguments of a specific model output */ - static class FeatureArguments { - - /** Optional arguments */ - private final Optional<String> signature, output; - - public FeatureArguments(Arguments arguments) { - this(optionalArgument(1, arguments), - optionalArgument(2, arguments)); - } - - public FeatureArguments(Optional<String> signature, Optional<String> output) { - this.signature = signature; - this.output = output; - } - - public Optional<String> signature() { return signature; } - public Optional<String> output() { return output; } - - public String toName() { - return (signature.isPresent() ? signature.get() : "") + - (output.isPresent() ? "." + output.get() : ""); - } - - private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { - if (argumentIndex >= arguments.expressions().size()) - return Optional.empty(); - return Optional.of(asString(arguments.expressions().get(argumentIndex))); - } - - public static String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); - } - - private static String stripQuotes(String s) { - if ( ! isQuoteSign(s.codePointAt(0))) return s; - if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) - throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); - return s.substring(1, s.length()-1); - } - - private static boolean isQuoteSign(int c) { - return c == '\'' || c == '"'; - } - - } - - /** - * Models used in a rank profile has the rank profile name as name space while gGlobal model names have no namespace - */ - public static class ModelName { - - /** The namespace, or null if none */ - private String namespace; - private String name; - private String fullName; - - public ModelName(String name) { - this(null, name); - } - - public ModelName(String namespace, Path modelPath) { - this(namespace, modelPath.toString().replace("/", "_")); - } - - private ModelName(String namespace, String name) { - this.namespace = namespace; - this.name = name; - this.fullName = (namespace != null ? namespace + "." : "") + name; - } - - /** Returns true if the local name of this is not in a namespace */ - public boolean isGlobal() { return namespace == null; } - - /** Returns the namespace, or null if this is global */ - public String namespace() { return namespace; } - public String localName() { return name; } - public String fullName() { return fullName; } - - - @Override - public boolean equals(Object o) { - if (o == this) return true; - if ( ! (o instanceof ModelName)) return false; - return ((ModelName)o).fullName.equals(this.fullName); - } - - @Override - public int hashCode() { return fullName.hashCode(); } - - @Override - public String toString() { return fullName; } - - } - } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java new file mode 100644 index 00000000000..fda49af6178 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java @@ -0,0 +1,61 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.ml; + +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; + +import java.util.Optional; + +/** + * Encapsulates the arguments of a specific model output + * + * @author bratseth + */ +public class FeatureArguments { + + /** Optional arguments */ + private final Optional<String> signature, output; + + public FeatureArguments(Arguments arguments) { + this(optionalArgument(1, arguments), + optionalArgument(2, arguments)); + } + + public FeatureArguments(Optional<String> signature, Optional<String> output) { + this.signature = signature; + this.output = output; + } + + public Optional<String> signature() { return signature; } + public Optional<String> output() { return output; } + + public String toName() { + return (signature.isPresent() ? signature.get() : "") + + (output.isPresent() ? "." + output.get() : ""); + } + + private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { + if (argumentIndex >= arguments.expressions().size()) + return Optional.empty(); + return Optional.of(asString(arguments.expressions().get(argumentIndex))); + } + + public static String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); + } + + private static String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); + } + + private static boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; + } + +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java new file mode 100644 index 00000000000..8de9a08aa88 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java @@ -0,0 +1,53 @@ +package com.yahoo.vespa.model.ml; + +import com.yahoo.path.Path; + +/** + * Models used in a rank profile has the rank profile name as name space while gGlobal model names have no namespace + * + * @author bratseth + */ +public class ModelName { + + /** The namespace, or null if none */ + private String namespace; + private String name; + private String fullName; + + public ModelName(String name) { + this(null, name); + } + + public ModelName(String namespace, Path modelPath) { + this(namespace, modelPath.toString().replace("/", "_")); + } + + private ModelName(String namespace, String name) { + this.namespace = namespace; + this.name = name; + this.fullName = (namespace != null ? namespace + "." : "") + name; + } + + /** Returns true if the local name of this is not in a namespace */ + public boolean isGlobal() { return namespace == null; } + + /** Returns the namespace, or null if this is global */ + public String namespace() { return namespace; } + public String localName() { return name; } + public String fullName() { return fullName; } + + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if ( ! (o instanceof ModelName)) return false; + return ((ModelName)o).fullName.equals(this.fullName); + } + + @Override + public int hashCode() { return fullName.hashCode(); } + + @Override + public String toString() { return fullName; } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 42010d4d523..b046d60f948 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -1,5 +1,4 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.searchdefinition.processing; import com.yahoo.config.application.api.ApplicationPackage; @@ -8,6 +7,7 @@ import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.ml.ImportedModelTester; import com.yahoo.yolean.Exceptions; import org.junit.After; import org.junit.Test; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 4d6bfc98266..14632a568ea 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -16,6 +16,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.ml.ImportedModelTester; import com.yahoo.yolean.Exceptions; import org.junit.After; import org.junit.Test; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java index 85c77475ad3..2ae629562d0 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImportedModelTester.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition.processing; +package com.yahoo.vespa.model.ml; import com.yahoo.config.model.ApplicationPackageTester; import com.yahoo.io.GrowableByteBuffer; @@ -38,10 +38,7 @@ public class ImportedModelTester { try { return new VespaModel(ApplicationPackageTester.create(applicationDir.toString()).app()); } - catch (SAXException e) { - throw new RuntimeException(e); - } - catch (IOException e) { + catch (SAXException | IOException e) { throw new RuntimeException(e); } } |