aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-11-22 14:27:58 +0100
committerJon Bratseth <bratseth@oath.com>2018-11-22 14:27:58 +0100
commitb288e61f7af7331656a1850fbdc58cc95fd1bbad (patch)
tree9d41fa770d2890585a902f41a89c41040ed764be /model-integration/src/test/java/ai
parent3c4020645b13be560c14e60969e50e3ad41e3d3c (diff)
Move all importing to model-integration
Diffstat (limited to 'model-integration/src/test/java/ai')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java48
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java21
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java2
10 files changed, 77 insertions, 8 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
new file mode 100644
index 00000000000..cf8dd6e8e71
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
@@ -0,0 +1,48 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+
+public class DimensionRenamerTest {
+
+ @Test
+ public void testMnistRenaming() {
+ DimensionRenamer renamer = new DimensionRenamer();
+
+ renamer.addDimension("first_dimension_of_x");
+ renamer.addDimension("second_dimension_of_x");
+ renamer.addDimension("first_dimension_of_w");
+ renamer.addDimension("second_dimension_of_w");
+ renamer.addDimension("first_dimension_of_b");
+
+ // which dimension to join on matmul
+ renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null);
+
+ // other dimensions in matmul can't be equal
+ renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null);
+
+ // for efficiency, put dimension to join on innermost
+ renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null);
+ renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null);
+
+ // bias
+ renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null);
+
+ renamer.solve();
+
+ String firstDimensionOfXName = renamer.dimensionNameOf("first_dimension_of_x").get();
+ String secondDimensionOfXName = renamer.dimensionNameOf("second_dimension_of_x").get();
+ String firstDimensionOfWName = renamer.dimensionNameOf("first_dimension_of_w").get();
+ String secondDimensionOfWName = renamer.dimensionNameOf("second_dimension_of_w").get();
+ String firstDimensionOfBName = renamer.dimensionNameOf("first_dimension_of_b").get();
+
+ assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfXName) < 0);
+ assertTrue(firstDimensionOfWName.compareTo(secondDimensionOfWName) > 0);
+ assertTrue(secondDimensionOfXName.compareTo(firstDimensionOfWName) == 0);
+ assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfWName) < 0);
+ assertTrue(secondDimensionOfWName.compareTo(firstDimensionOfBName) == 0);
+ }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java
new file mode 100644
index 00000000000..afe699d6e05
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java
@@ -0,0 +1,21 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bratseth
+ */
+public class OrderedTensorTypeTestCase {
+
+ @Test
+ public void testToFromSpec() {
+ String spec = "tensor(b[],c{},a[3])";
+ OrderedTensorType type = OrderedTensorType.fromSpec(spec);
+ assertEquals(spec, type.toString());
+ assertEquals("tensor(a[3],b[],c{})", type.type().toString());
+ }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index d86e7d6dd8e..d3996da9b58 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -6,7 +6,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java
index d112a3fa9f2..1a072f54c89 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java
@@ -2,7 +2,7 @@
package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import ai.vespa.rankingexpression.importer.ImportedModel;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java
index fa89e060006..37104ab43db 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java
@@ -7,7 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import ai.vespa.rankingexpression.importer.ImportedModel;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
index b3559a0a5f6..5e20be051ea 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
@@ -2,7 +2,7 @@
package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import ai.vespa.rankingexpression.importer.ImportedModel;
import com.yahoo.tensor.TensorType;
import org.junit.Assert;
import org.junit.Test;
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java
index 7e717c204f8..28b91b3797a 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java
@@ -2,7 +2,7 @@
package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import ai.vespa.rankingexpression.importer.ImportedModel;
import org.junit.Assert;
import org.junit.Test;
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
index f98b37b7e55..6215997d8f9 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
@@ -2,7 +2,7 @@
package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import ai.vespa.rankingexpression.importer.ImportedModel;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Assert;
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
index faa2c7acc18..c3b82cccb46 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
@@ -7,7 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import ai.vespa.rankingexpression.importer.ImportedModel;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
index 30b50c025d0..965d5eb8577 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
@@ -1,7 +1,7 @@
package ai.vespa.rankingexpression.importer.xgboost;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import ai.vespa.rankingexpression.importer.ImportedModel;
import org.junit.Test;
import static org.junit.Assert.assertEquals;