summaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java')
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java189
1 files changed, 28 insertions, 161 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 132cf936054..1fe1ebf2bb3 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -1,33 +1,23 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.processing;
-import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.io.IOUtils;
-import com.yahoo.io.reader.NamedReader;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.ml.ImportedModelTester;
import com.yahoo.yolean.Exceptions;
import org.junit.After;
import org.junit.Test;
-import java.io.File;
-import java.io.FileReader;
import java.io.IOException;
-import java.io.UncheckedIOException;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
import java.util.Optional;
+import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage;
+
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
@@ -38,7 +28,8 @@ public class RankingExpressionWithOnnxTestCase {
/** The model name */
private final static String name = "mnist_softmax";
- private final static String vespaExpression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(" + name + "_layer_Variable), f(a,b)(a * b)), sum, d2) * 1.0, constant(" + name + "_layer_Variable_1) * 1.0, f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))";
+ private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_Variable), f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))";
+ private final static String vespaExpressionWithBatchReduce = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))";
@After
public void removeGeneratedModelFiles() {
@@ -49,8 +40,8 @@ public class RankingExpressionWithOnnxTestCase {
public void testGlobalOnnxModel() throws IOException {
ImportedModelTester tester = new ImportedModelTester(name, applicationDir);
VespaModel model = tester.createVespaModel();
- tester.assertLargeConstant(name + "_layer_Variable_1", model, Optional.of(10L));
- tester.assertLargeConstant(name + "_layer_Variable", model, Optional.of(7840L));
+ tester.assertLargeConstant(name + "_Variable_1", model, Optional.of(10L));
+ tester.assertLargeConstant(name + "_Variable", model, Optional.of(7840L));
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedAppDir = applicationDir.append("copy");
@@ -61,8 +52,8 @@ public class RankingExpressionWithOnnxTestCase {
storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir);
VespaModel storedModel = storedTester.createVespaModel();
- tester.assertLargeConstant(name + "_layer_Variable_1", storedModel, Optional.of(10L));
- tester.assertLargeConstant(name + "_layer_Variable", storedModel, Optional.of(7840L));
+ tester.assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L));
+ tester.assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L));
}
finally {
IOUtils.recursiveDeleteDir(storedAppDir.toFile());
@@ -73,7 +64,7 @@ public class RankingExpressionWithOnnxTestCase {
public void testOnnxReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
"onnx('mnist_softmax.onnx')",
- "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }",
+ "constant mytensor { file: ignored\ntype: tensor<float>(d0[7],d1[784]) }",
null);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@@ -83,7 +74,7 @@ public class RankingExpressionWithOnnxTestCase {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType =
"<query-profile-type id='root'>" +
- " <field name='query(mytensor)' type='tensor&lt;float&gt;(d0[1],d1[784])'/>" +
+ " <field name='query(mytensor)' type='tensor&lt;float&gt;(d0[3],d1[784])'/>" +
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
queryProfile,
@@ -106,7 +97,7 @@ public class RankingExpressionWithOnnxTestCase {
"field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }",
"Placeholder",
application);
- search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ search.assertFirstPhaseExpression(vespaExpressionWithBatchReduce, "my_profile");
}
@@ -124,28 +115,28 @@ public class RankingExpressionWithOnnxTestCase {
"field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }",
"Placeholder",
application);
- search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ search.assertFirstPhaseExpression(vespaExpressionWithBatchReduce, "my_profile");
}
@Test
public void testNestedOnnxReference() {
- RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
+ RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
"5 + sum(onnx('mnist_softmax.onnx'))");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
}
@Test
public void testOnnxReferenceWithSpecifiedOutput() {
- RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx', 'layer_add')");
+ RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx', 'add')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@Test
public void testOnnxReferenceWithSpecifiedOutputAndSignature() {
- RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx', 'default.layer_add')");
+ RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx', 'default.add')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@@ -167,7 +158,7 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx'): " +
- "Model refers input 'Placeholder' of type tensor<float>(d0[1],d1[784]) but this function is " +
+ "Model refers input 'Placeholder' of type tensor<float>(d0[],d1[784]) but this function is " +
"not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
}
@@ -176,7 +167,7 @@ public class RankingExpressionWithOnnxTestCase {
@Test
public void testOnnxReferenceWithWrongFunctionType() {
try {
- RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d5[10])(0.0)",
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d5[10])(0.0)",
"onnx('mnist_softmax.onnx')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
@@ -184,8 +175,8 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx'): " +
- "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[1],d1[784]), " +
- "but this function returns tensor(d0[1],d5[10])",
+ "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[],d1[784]), " +
+ "but this function returns tensor(d0[2],d5[10])",
Exceptions.toMessageString(expected));
}
}
@@ -201,14 +192,14 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx','y'): " +
- "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.layer_add",
+ "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.add",
Exceptions.toMessageString(expected));
}
}
@Test
public void testImportingFromStoredExpressions() throws IOException {
- RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
+ RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
"onnx('mnist_softmax.onnx')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
@@ -235,29 +226,26 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
- public void testImportingFromStoredExpressionsWithFunctionOverridingConstantAndInheritance() throws IOException {
+ public void testImportingFromStoredExpressionsWithFunctionOverridingConstant() throws IOException {
String rankProfile =
" rank-profile my_profile {\n" +
" function Placeholder() {\n" +
- " expression: tensor<float>(d0[1],d1[784])(0.0)\n" +
+ " expression: tensor<float>(d0[2],d1[784])(0.0)\n" +
" }\n" +
- " function " + name + "_layer_Variable() {\n" +
+ " function " + name + "_Variable() {\n" +
" expression: tensor<float>(d1[10],d2[784])(0.0)\n" +
" }\n" +
" first-phase {\n" +
" expression: onnx('mnist_softmax.onnx')" +
" }\n" +
- " }" +
- " rank-profile my_profile_child inherits my_profile {\n" +
" }";
+
String vespaExpressionWithoutConstant =
- "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), " + name + "_layer_Variable, f(a,b)(a * b)), sum, d2) * 1.0, constant(" + name + "_layer_Variable_1) * 1.0, f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))";
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), " + name + "_Variable, f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))";
RankProfileSearchFixture search = uncompiledFixtureWith(rankProfile, new StoringApplicationPackage(applicationDir));
search.compileRankProfile("my_profile", applicationDir.append("models"));
- search.compileRankProfile("my_profile_child", applicationDir.append("models"));
search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
- search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child");
assertNull("Constant overridden by function is not added",
search.search().rankingConstants().get( name + "_Variable"));
@@ -271,9 +259,7 @@ public class RankingExpressionWithOnnxTestCase {
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfile, storedApplication);
searchFromStored.compileRankProfile("my_profile", applicationDir.append("models"));
- searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models"));
searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
- searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child");
assertNull("Constant overridden by function is not added",
searchFromStored.search().rankingConstants().get( name + "_Variable"));
} finally {
@@ -281,90 +267,6 @@ public class RankingExpressionWithOnnxTestCase {
}
}
- @Test
- public void testReduceBatchDimension() {
- final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(" + name + "_layer_Variable), f(a,b)(a * b)), sum, d2) * 1.0, constant(" + name + "_layer_Variable_1) * 1.0, f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))";
- RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx')");
- search.assertFirstPhaseExpression(expression, "my_profile");
- }
-
- @Test
- public void testFunctionGeneration() {
- final String name = "small_constants_and_functions";
- final String rankProfiles =
- " rank-profile my_profile {\n" +
- " function input() {\n" +
- " expression: tensor<float>(d0[3])(0.0)\n" +
- " }\n" +
- " first-phase {\n" +
- " expression: onnx('" + name + ".onnx')" +
- " }\n" +
- " }";
- final String functionName = "imported_ml_function_" + name + "_exp_output";
- final String expression = "join(" + functionName + ", reduce(join(join(reduce(" + functionName + ", sum, d0), tensor<float>(d0[1])(1.0), f(a,b)(a * b)), constant(" + name + "_epsilon), f(a,b)(a + b)), sum, d0), f(a,b)(a / b))";
- final String functionExpression = "map(input, f(a)(exp(a)))";
-
- RankProfileSearchFixture search = uncompiledFixtureWith(rankProfiles, new StoringApplicationPackage(applicationDir));
- search.compileRankProfile("my_profile", applicationDir.append("models"));
- search.assertFirstPhaseExpression(expression, "my_profile");
- search.assertFunction(functionExpression, functionName, "my_profile");
- }
-
- @Test
- public void testImportingFromStoredExpressionsWithSmallConstantsAndInheritance() throws IOException {
- final String name = "small_constants_and_functions";
- final String rankProfiles =
- " rank-profile my_profile {\n" +
- " function input() {\n" +
- " expression: tensor<float>(d0[3])(0.0)\n" +
- " }\n" +
- " first-phase {\n" +
- " expression: onnx('" + name + ".onnx')" +
- " }\n" +
- " }" +
- " rank-profile my_profile_child inherits my_profile {\n" +
- " }";
- final String functionName = "imported_ml_function_" + name + "_exp_output";
- final String expression = "join(" + functionName + ", reduce(join(join(reduce(" + functionName + ", sum, d0), tensor<float>(d0[1])(1.0), f(a,b)(a * b)), constant(" + name + "_epsilon), f(a,b)(a + b)), sum, d0), f(a,b)(a / b))";
- final String functionExpression = "map(input, f(a)(exp(a)))";
-
- RankProfileSearchFixture search = uncompiledFixtureWith(rankProfiles, new StoringApplicationPackage(applicationDir));
- search.compileRankProfile("my_profile", applicationDir.append("models"));
- search.compileRankProfile("my_profile_child", applicationDir.append("models"));
- search.assertFirstPhaseExpression(expression, "my_profile");
- search.assertFirstPhaseExpression(expression, "my_profile_child");
- assertSmallConstant(name + "_epsilon", TensorType.fromSpec("tensor()"), search);
- search.assertFunction(functionExpression, functionName, "my_profile");
- search.assertFunction(functionExpression, functionName, "my_profile_child");
-
- // At this point the expression is stored - copy application to another location which do not have a models dir
- Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
- try {
- storedApplicationDirectory.toFile().mkdirs();
- IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
- storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
- StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
- RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfiles, storedApplication);
- searchFromStored.compileRankProfile("my_profile", applicationDir.append("models"));
- searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models"));
- searchFromStored.assertFirstPhaseExpression(expression, "my_profile");
- searchFromStored.assertFirstPhaseExpression(expression, "my_profile_child");
- assertSmallConstant(name + "_epsilon", TensorType.fromSpec("tensor()"), search);
- searchFromStored.assertFunction(functionExpression, functionName, "my_profile");
- searchFromStored.assertFunction(functionExpression, functionName, "my_profile_child");
- }
- finally {
- IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
- }
- }
-
- private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) {
- Value value = search.compiledRankProfile("my_profile").getConstants().get(name);
- assertNotNull(value);
- assertEquals(type, value.type());
- }
-
private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) {
return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder",
new StoringApplicationPackage(applicationDir));
@@ -414,39 +316,4 @@ public class RankingExpressionWithOnnxTestCase {
}
}
- static class StoringApplicationPackage extends MockApplicationPackage {
-
- StoringApplicationPackage(Path applicationPackageWritableRoot) {
- this(applicationPackageWritableRoot, null, null);
- }
-
- StoringApplicationPackage(Path applicationPackageWritableRoot, String queryProfile, String queryProfileType) {
- super(new File(applicationPackageWritableRoot.toString()),
- null, null, Collections.emptyList(), null,
- null, null, false, queryProfile, queryProfileType);
- }
-
- @Override
- public ApplicationFile getFile(Path file) {
- return new MockApplicationFile(file, Path.fromString(root().toString()));
- }
-
- @Override
- public List<NamedReader> getFiles(Path path, String suffix) {
- List<NamedReader> readers = new ArrayList<>();
- for (File file : getFileReference(path).listFiles()) {
- if ( ! file.getName().endsWith(suffix)) continue;
- try {
- readers.add(new NamedReader(file.getName(), new FileReader(file)));
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
- return readers;
- }
-
- }
-
-
}