diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-11-22 14:27:58 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-11-22 14:27:58 +0100 |
commit | b288e61f7af7331656a1850fbdc58cc95fd1bbad (patch) | |
tree | 9d41fa770d2890585a902f41a89c41040ed764be /model-integration/src/test/java/ai | |
parent | 3c4020645b13be560c14e60969e50e3ad41e3d3c (diff) |
Move all importing to model-integration
Diffstat (limited to 'model-integration/src/test/java/ai')
10 files changed, 77 insertions, 8 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java new file mode 100644 index 00000000000..cf8dd6e8e71 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java @@ -0,0 +1,48 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer; + +import org.junit.Test; + +import static org.junit.Assert.assertTrue; + +public class DimensionRenamerTest { + + @Test + public void testMnistRenaming() { + DimensionRenamer renamer = new DimensionRenamer(); + + renamer.addDimension("first_dimension_of_x"); + renamer.addDimension("second_dimension_of_x"); + renamer.addDimension("first_dimension_of_w"); + renamer.addDimension("second_dimension_of_w"); + renamer.addDimension("first_dimension_of_b"); + + // which dimension to join on matmul + renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null); + + // other dimensions in matmul can't be equal + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null); + + // for efficiency, put dimension to join on innermost + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null); + renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null); + + // bias + renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null); + + renamer.solve(); + + String firstDimensionOfXName = renamer.dimensionNameOf("first_dimension_of_x").get(); + String secondDimensionOfXName = renamer.dimensionNameOf("second_dimension_of_x").get(); + String firstDimensionOfWName = renamer.dimensionNameOf("first_dimension_of_w").get(); + String secondDimensionOfWName = renamer.dimensionNameOf("second_dimension_of_w").get(); + String firstDimensionOfBName = renamer.dimensionNameOf("first_dimension_of_b").get(); + + assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfXName) < 0); + assertTrue(firstDimensionOfWName.compareTo(secondDimensionOfWName) > 0); + assertTrue(secondDimensionOfXName.compareTo(firstDimensionOfWName) == 0); + assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfWName) < 0); + assertTrue(secondDimensionOfWName.compareTo(firstDimensionOfBName) == 0); + } + +} diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java new file mode 100644 index 00000000000..afe699d6e05 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java @@ -0,0 +1,21 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class OrderedTensorTypeTestCase { + + @Test + public void testToFromSpec() { + String spec = "tensor(b[],c{},a[3])"; + OrderedTensorType type = OrderedTensorType.fromSpec(spec); + assertEquals(spec, type.toString()); + assertEquals("tensor(a[3],b[],c{})", type.type().toString()); + } + +} diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index d86e7d6dd8e..d3996da9b58 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -6,7 +6,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import ai.vespa.rankingexpression.importer.ImportedModel; import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java index d112a3fa9f2..1a072f54c89 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java @@ -2,7 +2,7 @@ package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import ai.vespa.rankingexpression.importer.ImportedModel; import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java index fa89e060006..37104ab43db 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java @@ -7,7 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import ai.vespa.rankingexpression.importer.ImportedModel; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java index b3559a0a5f6..5e20be051ea 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java @@ -2,7 +2,7 @@ package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import ai.vespa.rankingexpression.importer.ImportedModel; import com.yahoo.tensor.TensorType; import org.junit.Assert; import org.junit.Test; diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java index 7e717c204f8..28b91b3797a 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java @@ -2,7 +2,7 @@ package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import ai.vespa.rankingexpression.importer.ImportedModel; import org.junit.Assert; import org.junit.Test; diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java index f98b37b7e55..6215997d8f9 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java @@ -2,7 +2,7 @@ package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import ai.vespa.rankingexpression.importer.ImportedModel; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Assert; diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java index faa2c7acc18..c3b82cccb46 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java @@ -7,7 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import ai.vespa.rankingexpression.importer.ImportedModel; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java index 30b50c025d0..965d5eb8577 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java @@ -1,7 +1,7 @@ package ai.vespa.rankingexpression.importer.xgboost; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import ai.vespa.rankingexpression.importer.ImportedModel; import org.junit.Test; import static org.junit.Assert.assertEquals; |