summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <jonbratseth@yahoo.com>2018-02-05 22:56:23 +0100
committerGitHub <noreply@github.com>2018-02-05 22:56:23 +0100
commit6e6e9c71e11268a7badd2297341a0937cbad2d1f (patch)
tree4fc17d3e36f507efea78adc856228eec5f144019
parent3632387ab3bf56688d54c0714bcefe6f0f6d999f (diff)
parent30a2d3e88529bc5a86ad6c53c8de35e4a71fbac3 (diff)
Merge pull request #4924 from vespa-engine/bratseth/support-small-constants
Support small constants
-rw-r--r--config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationFile.java17
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationFile.java7
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java101
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java11
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationFile.java16
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java21
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java6
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java3
11 files changed, 170 insertions, 33 deletions
diff --git a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationFile.java b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationFile.java
index 60524fbca8d..a8e1256e032 100644
--- a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationFile.java
+++ b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationFile.java
@@ -111,8 +111,8 @@ public class FilesApplicationFile extends ApplicationFile {
file.getParentFile().mkdirs();
}
try {
- String data = com.yahoo.io.IOUtils.readAll(input);
String status = file.exists() ? ApplicationFile.ContentStatusChanged : ApplicationFile.ContentStatusNew;
+ String data = com.yahoo.io.IOUtils.readAll(input);
IOUtils.writeFile(file, data, false);
writeMetaFile(data, status);
} catch (IOException e) {
@@ -122,6 +122,21 @@ public class FilesApplicationFile extends ApplicationFile {
}
@Override
+ public ApplicationFile appendFile(String value) {
+ if (file.getParentFile() != null) {
+ file.getParentFile().mkdirs();
+ }
+ try {
+ String status = file.exists() ? ApplicationFile.ContentStatusChanged : ApplicationFile.ContentStatusNew;
+ IOUtils.writeFile(file, value, true);
+ writeMetaFile(value, status);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ return this;
+ }
+
+ @Override
public List<ApplicationFile> listFiles(final PathFilter filter) {
List<ApplicationFile> files = new ArrayList<>();
if (!file.isDirectory()) {
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationFile.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationFile.java
index 0384a5c7a1c..33b7807aac5 100644
--- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationFile.java
+++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationFile.java
@@ -75,6 +75,13 @@ public abstract class ApplicationFile implements Comparable<ApplicationFile> {
public abstract ApplicationFile writeFile(Reader input);
/**
+ * Appends the given string to this text file.
+ *
+ * @return this
+ */
+ public abstract ApplicationFile appendFile(String value);
+
+ /**
* List the files under this directory. If this is file, an empty list is returned.
* Only immediate files/subdirectories are returned.
*
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 2f28d9adb8b..bcbc7cc99e2 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -750,7 +750,9 @@ public class RankProfile implements Serializable, Cloneable {
public TypeContext typeContext(QueryProfileRegistry queryProfiles) {
TypeMapContext context = new TypeMapContext();
- // Add constants
+ // Add small constants
+ getConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.type()));
+ // Add large constants
getSearch().getRankingConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.getTensorType()));
// Add attributes
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 495ca7dd14a..2b997aa25f2 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -2,6 +2,7 @@
package com.yahoo.searchdefinition.expressiontransforms;
import com.google.common.base.Joiner;
+import com.yahoo.collections.Pair;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.application.provider.FilesApplicationPackage;
@@ -11,6 +12,9 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel.Signature;
@@ -25,11 +29,13 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
+import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -92,16 +98,21 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
verifyRequiredMacros(expression, model.requiredMacros(), profile, queryProfiles);
store.writeConverted(expression);
- model.constants().forEach((k, v) -> transformConstant(store, profile, k, v));
+ model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
+ model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, k, v));
return expression.getRoot();
}
private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
- for (RankingConstant constant : store.readRankingConstants()) {
+ for (Pair<String, Tensor> constant : store.readSmallConstants())
+ profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
+
+ for (RankingConstant constant : store.readLargeConstants()) {
if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName()))
profile.getSearch().addRankingConstant(constant);
}
+
return store.readConverted().getRoot();
}
@@ -158,8 +169,13 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
}
- private void transformConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
- Path constantPath = store.writeConstant(constantName, constantValue);
+ private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
+ store.writeSmallConstant(constantName, constantValue);
+ profile.addConstant(constantName, asValue(constantValue));
+ }
+
+ private void transformLargeConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
+ Path constantPath = store.writeLargeConstant(constantName, constantValue);
if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) {
log.info("Adding constant '" + constantName + "' of type " + constantValue.type());
@@ -218,6 +234,13 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
}
+ private Value asValue(Tensor tensor) {
+ if (tensor.type().rank() == 0)
+ return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
+ else
+ return new TensorValue(tensor);
+ }
+
/**
* Provides read/write access to the correct directories of the application package given by the feature arguments
*/
@@ -272,13 +295,13 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
/**
- * Reads the information about all the constants stored in the application package
+ * Reads the information about all the large (aka ranking) constants stored in the application package
* (the constant value itself is replicated with file distribution).
*/
- public List<RankingConstant> readRankingConstants() {
+ public List<RankingConstant> readLargeConstants() {
try {
List<RankingConstant> constants = new ArrayList<>();
- for (ApplicationFile constantFile : application.getFile(arguments.rankingConstantsPath()).listFiles()) {
+ for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
}
@@ -295,25 +318,63 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
*
* @return the path to the stored constant, relative to the application package root
*/
- public Path writeConstant(String name, Tensor constant) {
+ public Path writeLargeConstant(String name, Tensor constant) {
Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
// "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
Path constantPath = constantsPath.append(name + ".tbf");
- Path constantPathCorrected = constantPath;
- if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed)
- && ! constantPath.elements().contains(FilesApplicationPackage.preprocessed)) {
- constantPathCorrected = Path.fromString(FilesApplicationPackage.preprocessed).append(constantPath);
- }
// Remember the constant in a file we replicate in ZooKeeper
- application.getFile(arguments.rankingConstantsPath().append(name + ".constant"))
- .writeFile(new StringReader(name + ":" + constant.type() + ":" + constantPathCorrected));
+ application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
+ .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
// Write content explicitly as a file on the file system as this is distributed using file distribution
createIfNeeded(constantsPath);
IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
- return constantPathCorrected;
+ return correct(constantPath);
+ }
+
+ private List<Pair<String, Tensor>> readSmallConstants() {
+ try {
+ ApplicationFile file = application.getFile(arguments.smallConstantsPath());
+ if (!file.exists()) return Collections.emptyList();
+
+ List<Pair<String, Tensor>> constants = new ArrayList<>();
+ BufferedReader reader = new BufferedReader(file.createReader());
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ String name = parts[0];
+ TensorType type = TensorType.fromSpec(parts[1]);
+ Tensor tensor = Tensor.from(type, parts[2]);
+ constants.add(new Pair<>(name, tensor));
+ }
+ return constants;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Append this constant to the single file used for small constants distributed as config
+ */
+ public void writeSmallConstant(String name, Tensor constant) {
+ // Secret file format for remembering constants:
+ application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
+ constant.type().toString() + "\t" +
+ constant.toString() + "\n");
+ }
+
+ /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
+ private Path correct(Path path) {
+ if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed)
+ && ! path.elements().contains(FilesApplicationPackage.preprocessed)) {
+ return Path.fromString(FilesApplicationPackage.preprocessed).append(path);
+ }
+ else {
+ return path;
+ }
}
private void createIfNeeded(Path path) {
@@ -351,7 +412,13 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
public Optional<String> signature() { return signature; }
public Optional<String> output() { return output; }
- public Path rankingConstantsPath() {
+ /** Path to the small constants file */
+ public Path smallConstantsPath() {
+ return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
+ }
+
+ /** Path to the large (ranking) constants directory */
+ public Path largeConstantsPath() {
return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java
index 5255cdaeba1..0334012e8d9 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java
@@ -183,7 +183,7 @@ public class TensorTransformer extends ExpressionTransformer<RankProfileTransfor
}
private void addIfConstant(ReferenceNode node, Context context, RankProfile profile) {
- if (!node.getName().equals(ConstantTensorTransformer.CONSTANT)) {
+ if ( ! node.getName().equals(ConstantTensorTransformer.CONSTANT)) {
return;
}
if (node.children().size() != 1) {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 5203e686681..7246b22b0f8 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -399,6 +399,17 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Override
+ public ApplicationFile appendFile(String value) {
+ try {
+ IOUtils.writeFile(file, value, true);
+ return this;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
public List<ApplicationFile> listFiles(PathFilter filter) {
if ( ! isDirectory()) return Collections.emptyList();
return Arrays.stream(file.listFiles()).filter(f -> filter.accept(Path.fromString(f.toString())))
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationFile.java b/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationFile.java
index 717fb88e5dc..affc2e03e2b 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationFile.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationFile.java
@@ -95,7 +95,6 @@ class ZKApplicationFile extends ApplicationFile {
@Override
public ApplicationFile writeFile(Reader input) {
- // foo/bar/baz.txt
String zkPath = getZKPath(path);
try {
String data = IOUtils.readAll(input);
@@ -112,6 +111,21 @@ class ZKApplicationFile extends ApplicationFile {
}
@Override
+ public ApplicationFile appendFile(String value) {
+ String zkPath = getZKPath(path);
+ String status = ContentStatusNew;
+ if (zkApp.exists(zkPath)) {
+ status = ContentStatusChanged;
+ }
+ String existingData = zkApp.getData(zkPath);
+ if (existingData == null)
+ existingData = "";
+ zkApp.putData(zkPath, existingData + value);
+ writeMetaFile(value, status);
+ return this;
+ }
+
+ @Override
public List<ApplicationFile> listFiles(PathFilter filter) {
String userPath = getZKPath(path);
List<ApplicationFile> ret = new ArrayList<>();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index b8f8e288257..55782c36d18 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -219,7 +219,7 @@ class OperationMapper {
private static Optional<TypedTensorFunction> placeholderWithDefault(TensorFlowImporter.Parameters params) {
String name = toVespaName(params.node().getInput(0));
Tensor defaultValue = getConstantTensor(params, params.node().getInput(0));
- params.result().constant(name, defaultValue);
+ params.result().largeConstant(name, defaultValue);
params.result().macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")")));
// The default value will be provided by the macro. Users can override macro to change value.
TypedTensorFunction output = new TypedTensorFunction(defaultValue.type(), new VariableTensor(name));
@@ -544,7 +544,11 @@ class OperationMapper {
private static Optional<TypedTensorFunction> createConstant(TensorFlowImporter.Parameters params, Tensor constant) {
String name = toVespaName(params.node().getName());
- params.result().constant(name, constant);
+ if (constant.type().rank() == 0 || constant.size() <= 1) {
+ params.result().smallConstant(name, constant);
+ } else {
+ params.result().largeConstant(name, constant);
+ }
TypedTensorFunction output = new TypedTensorFunction(constant.type(),
new TensorFunctionNode.TensorFunctionExpressionNode(
new ReferenceNode("constant(\"" + name + "\")")));
@@ -553,8 +557,11 @@ class OperationMapper {
private static Tensor getConstantTensor(TensorFlowImporter.Parameters params, String name) {
String vespaName = toVespaName(name);
- if (params.result().constants().containsKey(vespaName)) {
- return params.result().constants().get(vespaName);
+ if (params.result().smallConstants().containsKey(vespaName)) {
+ return params.result().smallConstants().get(vespaName);
+ }
+ if (params.result().largeConstants().containsKey(vespaName)) {
+ return params.result().largeConstants().get(vespaName);
}
Session.Runner fetched = params.model().session().runner().fetch(name);
List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
index 530f4793b62..351aa417f9c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
@@ -24,13 +24,15 @@ public class TensorFlowModel {
private final Map<String, Signature> signatures = new HashMap<>();
private final Map<String, TensorType> arguments = new HashMap<>();
- private final Map<String, Tensor> constants = new HashMap<>();
+ private final Map<String, Tensor> smallConstants = new HashMap<>();
+ private final Map<String, Tensor> largeConstants = new HashMap<>();
private final Map<String, RankingExpression> expressions = new HashMap<>();
private final Map<String, RankingExpression> macros = new HashMap<>();
private final Map<String, TensorType> requiredMacros = new HashMap<>();
void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
- void constant(String name, Tensor constant) { constants.put(name, constant); }
+ void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
+ void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
void macro(String name, RankingExpression expression) { macros.put(name, expression); }
void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
@@ -43,8 +45,19 @@ public class TensorFlowModel {
/** Returns an immutable map of the arguments ("Placeholders") of this */
public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
- /** Returns an immutable map of the constants of this */
- public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); }
+ /**
+ * Returns an immutable map of the small constants of this.
+ * These should have sizes up to a few kb at most, and correspond to constant
+ * values given in the TensorFlow source.
+ */
+ public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
+
+ /**
+ * Returns an immutable map of the large constants of this.
+ * These can have sizes in gigabytes and must be distributed to nodes separately from configuration,
+ * and correspond to Variable files stored separately in TensorFlow.
+ */
+ public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); }
/**
* Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
index 01dd15d5fa0..ad5abd4c03d 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -20,15 +20,15 @@ public class MnistSoftmaxImportTestCase {
TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/mnist_softmax/saved");
// Check constants
- assertEquals(2, model.get().constants().size());
+ assertEquals(2, model.get().largeConstants().size());
- Tensor constant0 = model.get().constants().get("Variable");
+ Tensor constant0 = model.get().largeConstants().get("Variable");
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = model.get().constants().get("Variable_1");
+ Tensor constant1 = model.get().largeConstants().get("Variable_1");
assertNotNull(constant1);
assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
constant1.type());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
index 2c621fd2e92..ae7714b271a 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
@@ -57,7 +57,8 @@ public class TestableTensorFlowModel {
private Context contextFrom(TensorFlowModel result) {
MapContext context = new MapContext();
- result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor)));
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor)));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor)));
return context;
}