diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-11 13:32:09 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-11 13:32:09 +0100 |
commit | 7e2e38d17a51d0ca93dc74b8e7e0d34c5eeb19af (patch) | |
tree | 4815e7f8df0e9620045289385236ea0106ae2ea2 /config-model/src/main/java/com/yahoo/searchdefinition | |
parent | 58c0e6c1115950aa479217b9c97c74f0ebd0ec01 (diff) |
Use constant tensor files WIP
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition')
4 files changed, 49 insertions, 17 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index ec3100bc6b9..bacff94d776 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -39,7 +39,7 @@ public class RankProfile implements Serializable, Cloneable { private final String name; /** The search definition owning this profile, or null if none */ - private Search search=null; + private Search search = null; /** The name of the rank profile inherited by this */ private String inheritedName = null; @@ -51,7 +51,7 @@ public class RankProfile implements Serializable, Cloneable { protected Set<RankSetting> rankSettings = new java.util.LinkedHashSet<>(); /** The ranking expression to be used for first phase */ - private RankingExpression firstPhaseRanking= null; + private RankingExpression firstPhaseRanking = null; /** The ranking expression to be used for second phase */ private RankingExpression secondPhaseRanking = null; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java index 16e57ee913d..c65e0fad1c7 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java @@ -6,7 +6,7 @@ import com.yahoo.tensor.TensorType; import java.util.Objects; /** - * Represents a global ranking constant (declared in a .sd file) + * Represents a global ranking constant * * @author arnej */ @@ -16,23 +16,35 @@ public class RankingConstant { private final String name; private TensorType tensorType = null; private String fileName = null; - private String fileRef = ""; + private String fileReference = ""; public RankingConstant(String name) { this.name = name; } - public void setFileName(String fileName) { + public RankingConstant(String name, TensorType type, String fileName) { + this(name); + this.tensorType = type; + this.fileName = fileName; + validate(); + } + + public void setFileName(String fileName) { Objects.requireNonNull(fileName, "Filename cannot be null"); - this.fileName = fileName; + this.fileName = fileName; } - public void setFileReference(String fileRef) { this.fileRef = fileRef; } + /** + * Set the internally generated reference to this file used to identify this instance of the file for + * file distribution. + */ + public void setFileReference(String fileReference) { this.fileReference = fileReference; } + public void setType(TensorType tensorType) { this.tensorType = tensorType; } public String getName() { return name; } public String getFileName() { return fileName; } - public String getFileReference() { return fileRef; } + public String getFileReference() { return fileReference; } public TensorType getTensorType() { return tensorType; } public String getType() { return tensorType.toString(); } @@ -47,7 +59,7 @@ public class RankingConstant { StringBuilder b = new StringBuilder(); b.append("constant '").append(name) .append("' from file '").append(fileName) - .append("' with ref '").append(fileRef) + .append("' with ref '").append(fileReference) .append("' of type '").append(tensorType) .append("'"); return b.toString(); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java index 7baa3eb170f..bd7b8ce6e15 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java @@ -58,7 +58,7 @@ public class Search implements Serializable, ImmutableSearch { // Field sets private FieldSets fieldSets = new FieldSets(); - + // Whether or not this object has been processed. private boolean processed; @@ -162,13 +162,13 @@ public class Search implements Serializable, ImmutableSearch { docType = document; } - public void addRankingConstant(RankingConstant rConstant) { - rConstant.validate(); - String name = rConstant.getName(); + public void addRankingConstant(RankingConstant constant) { + constant.validate(); + String name = constant.getName(); if (rankingConstants.get(name) != null) { throw new IllegalArgumentException("Ranking constant '"+name+"' defined twice"); } - rankingConstants.put(name, rConstant); + rankingConstants.put(name, constant); } public Iterable<RankingConstant> getRankingConstants() { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index d05027dda39..a36384ce6f2 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -1,17 +1,22 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.google.common.base.Joiner; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.Tensor; +import java.io.File; import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -48,7 +53,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + "the tensorflow model directory under [application]/models"); - String modelPath = asString(feature.getArguments().expressions().get(0)); + String modelPath = ApplicationPackage.MODELS_DIR + "/" + asString(feature.getArguments().expressions().get(0)); TensorFlowModel result = importedModels.computeIfAbsent(modelPath, k -> tensorFlowImporter.importModel(modelPath)); // Find the specified expression @@ -58,7 +63,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil optionalArgument(2, feature.getArguments())); // Add all constants (after finding outputs to fail faster when the output is not found) - result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v))); + if (1==1) + result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v))); + else // correct way, disabled for now + result.constants().forEach((k, v) -> transformConstant(modelPath, context.rankProfile(), k, v)); return expression.getRoot(); } @@ -120,6 +128,18 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } + private void transformConstant(String modelPath, RankProfile profile, String constantName, Tensor constantValue) { + File constantFilePath = new File(modelPath, "converted_variables"); + if ( ! constantFilePath.exists() ) { + if ( ! constantFilePath.mkdir() ) + throw new IllegalStateException("Could not create directory " + constantFilePath); + } + + File constantFile = new File(constantFilePath, constantName + ".json"); + // writeAsVespaTensor(constantValue, constantFile); + profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), constantFilePath.getPath())); + } + private String skippedOutputsDescription(TensorFlowModel.Signature signature) { if (signature.skippedOutputs().isEmpty()) return ""; StringBuilder b = new StringBuilder(": "); |