summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2018-01-11 13:32:09 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2018-01-11 13:32:09 +0100
commit7e2e38d17a51d0ca93dc74b8e7e0d34c5eeb19af (patch)
tree4815e7f8df0e9620045289385236ea0106ae2ea2 /config-model/src/main/java/com/yahoo/searchdefinition
parent58c0e6c1115950aa479217b9c97c74f0ebd0ec01 (diff)
Use constant tensor files WIP
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java26
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/Search.java10
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java26
4 files changed, 49 insertions, 17 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index ec3100bc6b9..bacff94d776 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -39,7 +39,7 @@ public class RankProfile implements Serializable, Cloneable {
private final String name;
/** The search definition owning this profile, or null if none */
- private Search search=null;
+ private Search search = null;
/** The name of the rank profile inherited by this */
private String inheritedName = null;
@@ -51,7 +51,7 @@ public class RankProfile implements Serializable, Cloneable {
protected Set<RankSetting> rankSettings = new java.util.LinkedHashSet<>();
/** The ranking expression to be used for first phase */
- private RankingExpression firstPhaseRanking= null;
+ private RankingExpression firstPhaseRanking = null;
/** The ranking expression to be used for second phase */
private RankingExpression secondPhaseRanking = null;
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java
index 16e57ee913d..c65e0fad1c7 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java
@@ -6,7 +6,7 @@ import com.yahoo.tensor.TensorType;
import java.util.Objects;
/**
- * Represents a global ranking constant (declared in a .sd file)
+ * Represents a global ranking constant
*
* @author arnej
*/
@@ -16,23 +16,35 @@ public class RankingConstant {
private final String name;
private TensorType tensorType = null;
private String fileName = null;
- private String fileRef = "";
+ private String fileReference = "";
public RankingConstant(String name) {
this.name = name;
}
- public void setFileName(String fileName) {
+ public RankingConstant(String name, TensorType type, String fileName) {
+ this(name);
+ this.tensorType = type;
+ this.fileName = fileName;
+ validate();
+ }
+
+ public void setFileName(String fileName) {
Objects.requireNonNull(fileName, "Filename cannot be null");
- this.fileName = fileName;
+ this.fileName = fileName;
}
- public void setFileReference(String fileRef) { this.fileRef = fileRef; }
+ /**
+ * Set the internally generated reference to this file used to identify this instance of the file for
+ * file distribution.
+ */
+ public void setFileReference(String fileReference) { this.fileReference = fileReference; }
+
public void setType(TensorType tensorType) { this.tensorType = tensorType; }
public String getName() { return name; }
public String getFileName() { return fileName; }
- public String getFileReference() { return fileRef; }
+ public String getFileReference() { return fileReference; }
public TensorType getTensorType() { return tensorType; }
public String getType() { return tensorType.toString(); }
@@ -47,7 +59,7 @@ public class RankingConstant {
StringBuilder b = new StringBuilder();
b.append("constant '").append(name)
.append("' from file '").append(fileName)
- .append("' with ref '").append(fileRef)
+ .append("' with ref '").append(fileReference)
.append("' of type '").append(tensorType)
.append("'");
return b.toString();
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
index 7baa3eb170f..bd7b8ce6e15 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
@@ -58,7 +58,7 @@ public class Search implements Serializable, ImmutableSearch {
// Field sets
private FieldSets fieldSets = new FieldSets();
-
+
// Whether or not this object has been processed.
private boolean processed;
@@ -162,13 +162,13 @@ public class Search implements Serializable, ImmutableSearch {
docType = document;
}
- public void addRankingConstant(RankingConstant rConstant) {
- rConstant.validate();
- String name = rConstant.getName();
+ public void addRankingConstant(RankingConstant constant) {
+ constant.validate();
+ String name = constant.getName();
if (rankingConstants.get(name) != null) {
throw new IllegalArgumentException("Ranking constant '"+name+"' defined twice");
}
- rankingConstants.put(name, rConstant);
+ rankingConstants.put(name, constant);
}
public Iterable<RankingConstant> getRankingConstants() {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index d05027dda39..a36384ce6f2 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -1,17 +1,22 @@
package com.yahoo.searchdefinition.expressiontransforms;
import com.google.common.base.Joiner;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.tensor.Tensor;
+import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
@@ -48,7 +53,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
"the tensorflow model directory under [application]/models");
- String modelPath = asString(feature.getArguments().expressions().get(0));
+ String modelPath = ApplicationPackage.MODELS_DIR + "/" + asString(feature.getArguments().expressions().get(0));
TensorFlowModel result = importedModels.computeIfAbsent(modelPath, k -> tensorFlowImporter.importModel(modelPath));
// Find the specified expression
@@ -58,7 +63,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
optionalArgument(2, feature.getArguments()));
// Add all constants (after finding outputs to fail faster when the output is not found)
- result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v)));
+ if (1==1)
+ result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v)));
+ else // correct way, disabled for now
+ result.constants().forEach((k, v) -> transformConstant(modelPath, context.rankProfile(), k, v));
return expression.getRoot();
}
@@ -120,6 +128,18 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
}
+ private void transformConstant(String modelPath, RankProfile profile, String constantName, Tensor constantValue) {
+ File constantFilePath = new File(modelPath, "converted_variables");
+ if ( ! constantFilePath.exists() ) {
+ if ( ! constantFilePath.mkdir() )
+ throw new IllegalStateException("Could not create directory " + constantFilePath);
+ }
+
+ File constantFile = new File(constantFilePath, constantName + ".json");
+ // writeAsVespaTensor(constantValue, constantFile);
+ profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), constantFilePath.getPath()));
+ }
+
private String skippedOutputsDescription(TensorFlowModel.Signature signature) {
if (signature.skippedOutputs().isEmpty()) return "";
StringBuilder b = new StringBuilder(": ");