aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java26
1 files changed, 23 insertions, 3 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index d05027dda39..a36384ce6f2 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -1,17 +1,22 @@
package com.yahoo.searchdefinition.expressiontransforms;
import com.google.common.base.Joiner;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.tensor.Tensor;
+import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
@@ -48,7 +53,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
"the tensorflow model directory under [application]/models");
- String modelPath = asString(feature.getArguments().expressions().get(0));
+ String modelPath = ApplicationPackage.MODELS_DIR + "/" + asString(feature.getArguments().expressions().get(0));
TensorFlowModel result = importedModels.computeIfAbsent(modelPath, k -> tensorFlowImporter.importModel(modelPath));
// Find the specified expression
@@ -58,7 +63,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
optionalArgument(2, feature.getArguments()));
// Add all constants (after finding outputs to fail faster when the output is not found)
- result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v)));
+ if (1==1)
+ result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v)));
+ else // correct way, disabled for now
+ result.constants().forEach((k, v) -> transformConstant(modelPath, context.rankProfile(), k, v));
return expression.getRoot();
}
@@ -120,6 +128,18 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
}
+ private void transformConstant(String modelPath, RankProfile profile, String constantName, Tensor constantValue) {
+ File constantFilePath = new File(modelPath, "converted_variables");
+ if ( ! constantFilePath.exists() ) {
+ if ( ! constantFilePath.mkdir() )
+ throw new IllegalStateException("Could not create directory " + constantFilePath);
+ }
+
+ File constantFile = new File(constantFilePath, constantName + ".json");
+ // writeAsVespaTensor(constantValue, constantFile);
+ profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), constantFilePath.getPath()));
+ }
+
private String skippedOutputsDescription(TensorFlowModel.Signature signature) {
if (signature.skippedOutputs().isEmpty()) return "";
StringBuilder b = new StringBuilder(": ");