summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java10
1 files changed, 8 insertions, 2 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
index 4ff0c96d369..9d2f8cf0692 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
@@ -33,14 +33,20 @@ public class TestableTensorFlowModel {
private ImportedModel model;
// Sizes of the input vector
- private final int d0Size = 1;
- private final int d1Size = 784;
+ private int d0Size = 1;
+ private int d1Size = 784;
public TestableTensorFlowModel(String modelName, String modelDir) {
tensorFlowModel = SavedModelBundle.load(modelDir, "serve");
model = new TensorFlowImporter().importModel(modelName, modelDir, tensorFlowModel);
}
+ public TestableTensorFlowModel(String modelName, String modelDir, int d0Size, int d1Size) {
+ this(modelName, modelDir);
+ this.d0Size = d0Size;
+ this.d1Size = d1Size;
+ }
+
public ImportedModel get() { return model; }
/** Compare that summing the tensors produce the same result to within some tolerance delta */