summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-02-06 15:05:39 +0100
committerJon Bratseth <bratseth@oath.com>2018-02-06 15:05:39 +0100
commit77b6442c29dee93cf449c3b4e4178d8cf1c99617 (patch)
tree31df387a32b4efe2f159967b8b04389f3e701988 /searchlib
parenta23fc5e8d4e9ef0f737041f6d4f2ebc50b38c40b (diff)
parent384475dbec8d3a525a7ea7c0d14d65b75a529689 (diff)
Merge branch 'master' into bratseth/typecheck-all
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java21
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java6
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java3
4 files changed, 33 insertions, 12 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index b8f8e288257..55782c36d18 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -219,7 +219,7 @@ class OperationMapper {
private static Optional<TypedTensorFunction> placeholderWithDefault(TensorFlowImporter.Parameters params) {
String name = toVespaName(params.node().getInput(0));
Tensor defaultValue = getConstantTensor(params, params.node().getInput(0));
- params.result().constant(name, defaultValue);
+ params.result().largeConstant(name, defaultValue);
params.result().macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")")));
// The default value will be provided by the macro. Users can override macro to change value.
TypedTensorFunction output = new TypedTensorFunction(defaultValue.type(), new VariableTensor(name));
@@ -544,7 +544,11 @@ class OperationMapper {
private static Optional<TypedTensorFunction> createConstant(TensorFlowImporter.Parameters params, Tensor constant) {
String name = toVespaName(params.node().getName());
- params.result().constant(name, constant);
+ if (constant.type().rank() == 0 || constant.size() <= 1) {
+ params.result().smallConstant(name, constant);
+ } else {
+ params.result().largeConstant(name, constant);
+ }
TypedTensorFunction output = new TypedTensorFunction(constant.type(),
new TensorFunctionNode.TensorFunctionExpressionNode(
new ReferenceNode("constant(\"" + name + "\")")));
@@ -553,8 +557,11 @@ class OperationMapper {
private static Tensor getConstantTensor(TensorFlowImporter.Parameters params, String name) {
String vespaName = toVespaName(name);
- if (params.result().constants().containsKey(vespaName)) {
- return params.result().constants().get(vespaName);
+ if (params.result().smallConstants().containsKey(vespaName)) {
+ return params.result().smallConstants().get(vespaName);
+ }
+ if (params.result().largeConstants().containsKey(vespaName)) {
+ return params.result().largeConstants().get(vespaName);
}
Session.Runner fetched = params.model().session().runner().fetch(name);
List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
index 530f4793b62..351aa417f9c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
@@ -24,13 +24,15 @@ public class TensorFlowModel {
private final Map<String, Signature> signatures = new HashMap<>();
private final Map<String, TensorType> arguments = new HashMap<>();
- private final Map<String, Tensor> constants = new HashMap<>();
+ private final Map<String, Tensor> smallConstants = new HashMap<>();
+ private final Map<String, Tensor> largeConstants = new HashMap<>();
private final Map<String, RankingExpression> expressions = new HashMap<>();
private final Map<String, RankingExpression> macros = new HashMap<>();
private final Map<String, TensorType> requiredMacros = new HashMap<>();
void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
- void constant(String name, Tensor constant) { constants.put(name, constant); }
+ void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
+ void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
void macro(String name, RankingExpression expression) { macros.put(name, expression); }
void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
@@ -43,8 +45,19 @@ public class TensorFlowModel {
/** Returns an immutable map of the arguments ("Placeholders") of this */
public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
- /** Returns an immutable map of the constants of this */
- public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); }
+ /**
+ * Returns an immutable map of the small constants of this.
+ * These should have sizes up to a few kb at most, and correspond to constant
+ * values given in the TensorFlow source.
+ */
+ public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
+
+ /**
+ * Returns an immutable map of the large constants of this.
+ * These can have sizes in gigabytes and must be distributed to nodes separately from configuration,
+ * and correspond to Variable files stored separately in TensorFlow.
+ */
+ public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); }
/**
* Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
index 01dd15d5fa0..ad5abd4c03d 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -20,15 +20,15 @@ public class MnistSoftmaxImportTestCase {
TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/mnist_softmax/saved");
// Check constants
- assertEquals(2, model.get().constants().size());
+ assertEquals(2, model.get().largeConstants().size());
- Tensor constant0 = model.get().constants().get("Variable");
+ Tensor constant0 = model.get().largeConstants().get("Variable");
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = model.get().constants().get("Variable_1");
+ Tensor constant1 = model.get().largeConstants().get("Variable_1");
assertNotNull(constant1);
assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
constant1.type());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
index 2c621fd2e92..ae7714b271a 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
@@ -57,7 +57,8 @@ public class TestableTensorFlowModel {
private Context contextFrom(TensorFlowModel result) {
MapContext context = new MapContext();
- result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor)));
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor)));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor)));
return context;
}