diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-11-25 20:07:56 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-11-25 20:07:56 +0100 |
commit | 1d88554bd513783715425120e76fc5f2a86f439f (patch) | |
tree | 166c86107d3620014cc7e26d85118c311e1b8cf0 /model-integration/src/test | |
parent | a01bc21d9bcbc417a9fb2591079561f59f76865e (diff) |
Java type only interface between imported-models and config models
This avoids class incompatibility problems when passing an
imported model across bundle boundaries to a config model.
Tensor string parsing has been sped up as this relies on it more.
Diffstat (limited to 'model-integration/src/test')
3 files changed, 9 insertions, 9 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index d3996da9b58..315456c2613 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -28,13 +28,13 @@ public class OnnxMnistSoftmaxImportTestCase { // Check constants assertEquals(2, model.largeConstants().size()); - Tensor constant0 = model.largeConstants().get("test_Variable"); + Tensor constant0 = Tensor.from(model.largeConstants().get("test_Variable")); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.largeConstants().get("test_Variable_1"); + Tensor constant1 = Tensor.from(model.largeConstants().get("test_Variable_1")); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); assertEquals(10, constant1.size()); @@ -84,8 +84,8 @@ public class OnnxMnistSoftmaxImportTestCase { private Context contextFrom(ImportedModel result) { MapContext context = new MapContext(); - result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); - result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); return context; } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java index 6215997d8f9..be676186017 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java @@ -24,13 +24,13 @@ public class TensorFlowMnistSoftmaxImportTestCase { // Check constants Assert.assertEquals(2, model.get().largeConstants().size()); - Tensor constant0 = model.get().largeConstants().get("test_Variable_read"); + Tensor constant0 = Tensor.from(model.get().largeConstants().get("test_Variable_read")); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.get().largeConstants().get("test_Variable_1_read"); + Tensor constant1 = Tensor.from(model.get().largeConstants().get("test_Variable_1_read")); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java index c3b82cccb46..4ff0c96d369 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java @@ -93,8 +93,8 @@ public class TestableTensorFlowModel { static Context contextFrom(ImportedModel result) { TestableModelContext context = new TestableModelContext(); - result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); - result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); return context; } @@ -108,7 +108,7 @@ public class TestableTensorFlowModel { private void evaluateFunction(Context context, ImportedModel model, String functionName) { if (!context.names().contains(functionName)) { - RankingExpression e = model.functions().get(functionName); + RankingExpression e = RankingExpression.from(model.functions().get(functionName)); evaluateFunctionDependencies(context, model, e.getRoot()); context.put(functionName, new TensorValue(e.evaluate(context).asTensor())); } |