diff options
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java')
-rw-r--r-- | model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java index 4ff0c96d369..9d2f8cf0692 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java @@ -33,14 +33,20 @@ public class TestableTensorFlowModel { private ImportedModel model; // Sizes of the input vector - private final int d0Size = 1; - private final int d1Size = 784; + private int d0Size = 1; + private int d1Size = 784; public TestableTensorFlowModel(String modelName, String modelDir) { tensorFlowModel = SavedModelBundle.load(modelDir, "serve"); model = new TensorFlowImporter().importModel(modelName, modelDir, tensorFlowModel); } + public TestableTensorFlowModel(String modelName, String modelDir, int d0Size, int d1Size) { + this(modelName, modelDir); + this.d0Size = d0Size; + this.d1Size = d1Size; + } + public ImportedModel get() { return model; } /** Compare that summing the tensors produce the same result to within some tolerance delta */ |