summaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java')
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java189
1 files changed, 161 insertions, 28 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 1fe1ebf2bb3..132cf936054 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -1,23 +1,33 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.processing;
+import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.io.IOUtils;
+import com.yahoo.io.reader.NamedReader;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.ml.ImportedModelTester;
import com.yahoo.yolean.Exceptions;
import org.junit.After;
import org.junit.Test;
+import java.io.File;
+import java.io.FileReader;
import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
import java.util.Optional;
-import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage;
-
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
@@ -28,8 +38,7 @@ public class RankingExpressionWithOnnxTestCase {
/** The model name */
private final static String name = "mnist_softmax";
- private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_Variable), f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))";
- private final static String vespaExpressionWithBatchReduce = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))";
+ private final static String vespaExpression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(" + name + "_layer_Variable), f(a,b)(a * b)), sum, d2) * 1.0, constant(" + name + "_layer_Variable_1) * 1.0, f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))";
@After
public void removeGeneratedModelFiles() {
@@ -40,8 +49,8 @@ public class RankingExpressionWithOnnxTestCase {
public void testGlobalOnnxModel() throws IOException {
ImportedModelTester tester = new ImportedModelTester(name, applicationDir);
VespaModel model = tester.createVespaModel();
- tester.assertLargeConstant(name + "_Variable_1", model, Optional.of(10L));
- tester.assertLargeConstant(name + "_Variable", model, Optional.of(7840L));
+ tester.assertLargeConstant(name + "_layer_Variable_1", model, Optional.of(10L));
+ tester.assertLargeConstant(name + "_layer_Variable", model, Optional.of(7840L));
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedAppDir = applicationDir.append("copy");
@@ -52,8 +61,8 @@ public class RankingExpressionWithOnnxTestCase {
storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir);
VespaModel storedModel = storedTester.createVespaModel();
- tester.assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L));
- tester.assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L));
+ tester.assertLargeConstant(name + "_layer_Variable_1", storedModel, Optional.of(10L));
+ tester.assertLargeConstant(name + "_layer_Variable", storedModel, Optional.of(7840L));
}
finally {
IOUtils.recursiveDeleteDir(storedAppDir.toFile());
@@ -64,7 +73,7 @@ public class RankingExpressionWithOnnxTestCase {
public void testOnnxReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
"onnx('mnist_softmax.onnx')",
- "constant mytensor { file: ignored\ntype: tensor<float>(d0[7],d1[784]) }",
+ "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }",
null);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@@ -74,7 +83,7 @@ public class RankingExpressionWithOnnxTestCase {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType =
"<query-profile-type id='root'>" +
- " <field name='query(mytensor)' type='tensor&lt;float&gt;(d0[3],d1[784])'/>" +
+ " <field name='query(mytensor)' type='tensor&lt;float&gt;(d0[1],d1[784])'/>" +
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
queryProfile,
@@ -97,7 +106,7 @@ public class RankingExpressionWithOnnxTestCase {
"field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }",
"Placeholder",
application);
- search.assertFirstPhaseExpression(vespaExpressionWithBatchReduce, "my_profile");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@@ -115,28 +124,28 @@ public class RankingExpressionWithOnnxTestCase {
"field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }",
"Placeholder",
application);
- search.assertFirstPhaseExpression(vespaExpressionWithBatchReduce, "my_profile");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@Test
public void testNestedOnnxReference() {
- RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
+ RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
"5 + sum(onnx('mnist_softmax.onnx'))");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
}
@Test
public void testOnnxReferenceWithSpecifiedOutput() {
- RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx', 'add')");
+ RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx', 'layer_add')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@Test
public void testOnnxReferenceWithSpecifiedOutputAndSignature() {
- RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx', 'default.add')");
+ RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx', 'default.layer_add')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@@ -158,7 +167,7 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx'): " +
- "Model refers input 'Placeholder' of type tensor<float>(d0[],d1[784]) but this function is " +
+ "Model refers input 'Placeholder' of type tensor<float>(d0[1],d1[784]) but this function is " +
"not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
}
@@ -167,7 +176,7 @@ public class RankingExpressionWithOnnxTestCase {
@Test
public void testOnnxReferenceWithWrongFunctionType() {
try {
- RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d5[10])(0.0)",
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d5[10])(0.0)",
"onnx('mnist_softmax.onnx')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
@@ -175,8 +184,8 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx'): " +
- "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[],d1[784]), " +
- "but this function returns tensor(d0[2],d5[10])",
+ "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[1],d1[784]), " +
+ "but this function returns tensor(d0[1],d5[10])",
Exceptions.toMessageString(expected));
}
}
@@ -192,14 +201,14 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx','y'): " +
- "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.add",
+ "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.layer_add",
Exceptions.toMessageString(expected));
}
}
@Test
public void testImportingFromStoredExpressions() throws IOException {
- RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
+ RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
"onnx('mnist_softmax.onnx')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
@@ -226,26 +235,29 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
- public void testImportingFromStoredExpressionsWithFunctionOverridingConstant() throws IOException {
+ public void testImportingFromStoredExpressionsWithFunctionOverridingConstantAndInheritance() throws IOException {
String rankProfile =
" rank-profile my_profile {\n" +
" function Placeholder() {\n" +
- " expression: tensor<float>(d0[2],d1[784])(0.0)\n" +
+ " expression: tensor<float>(d0[1],d1[784])(0.0)\n" +
" }\n" +
- " function " + name + "_Variable() {\n" +
+ " function " + name + "_layer_Variable() {\n" +
" expression: tensor<float>(d1[10],d2[784])(0.0)\n" +
" }\n" +
" first-phase {\n" +
" expression: onnx('mnist_softmax.onnx')" +
" }\n" +
+ " }" +
+ " rank-profile my_profile_child inherits my_profile {\n" +
" }";
-
String vespaExpressionWithoutConstant =
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), " + name + "_Variable, f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))";
+ "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), " + name + "_layer_Variable, f(a,b)(a * b)), sum, d2) * 1.0, constant(" + name + "_layer_Variable_1) * 1.0, f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))";
RankProfileSearchFixture search = uncompiledFixtureWith(rankProfile, new StoringApplicationPackage(applicationDir));
search.compileRankProfile("my_profile", applicationDir.append("models"));
+ search.compileRankProfile("my_profile_child", applicationDir.append("models"));
search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
+ search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child");
assertNull("Constant overridden by function is not added",
search.search().rankingConstants().get( name + "_Variable"));
@@ -259,7 +271,9 @@ public class RankingExpressionWithOnnxTestCase {
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfile, storedApplication);
searchFromStored.compileRankProfile("my_profile", applicationDir.append("models"));
+ searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models"));
searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
+ searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child");
assertNull("Constant overridden by function is not added",
searchFromStored.search().rankingConstants().get( name + "_Variable"));
} finally {
@@ -267,6 +281,90 @@ public class RankingExpressionWithOnnxTestCase {
}
}
+ @Test
+ public void testReduceBatchDimension() {
+ final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(" + name + "_layer_Variable), f(a,b)(a * b)), sum, d2) * 1.0, constant(" + name + "_layer_Variable_1) * 1.0, f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))";
+ RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx')");
+ search.assertFirstPhaseExpression(expression, "my_profile");
+ }
+
+ @Test
+ public void testFunctionGeneration() {
+ final String name = "small_constants_and_functions";
+ final String rankProfiles =
+ " rank-profile my_profile {\n" +
+ " function input() {\n" +
+ " expression: tensor<float>(d0[3])(0.0)\n" +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: onnx('" + name + ".onnx')" +
+ " }\n" +
+ " }";
+ final String functionName = "imported_ml_function_" + name + "_exp_output";
+ final String expression = "join(" + functionName + ", reduce(join(join(reduce(" + functionName + ", sum, d0), tensor<float>(d0[1])(1.0), f(a,b)(a * b)), constant(" + name + "_epsilon), f(a,b)(a + b)), sum, d0), f(a,b)(a / b))";
+ final String functionExpression = "map(input, f(a)(exp(a)))";
+
+ RankProfileSearchFixture search = uncompiledFixtureWith(rankProfiles, new StoringApplicationPackage(applicationDir));
+ search.compileRankProfile("my_profile", applicationDir.append("models"));
+ search.assertFirstPhaseExpression(expression, "my_profile");
+ search.assertFunction(functionExpression, functionName, "my_profile");
+ }
+
+ @Test
+ public void testImportingFromStoredExpressionsWithSmallConstantsAndInheritance() throws IOException {
+ final String name = "small_constants_and_functions";
+ final String rankProfiles =
+ " rank-profile my_profile {\n" +
+ " function input() {\n" +
+ " expression: tensor<float>(d0[3])(0.0)\n" +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: onnx('" + name + ".onnx')" +
+ " }\n" +
+ " }" +
+ " rank-profile my_profile_child inherits my_profile {\n" +
+ " }";
+ final String functionName = "imported_ml_function_" + name + "_exp_output";
+ final String expression = "join(" + functionName + ", reduce(join(join(reduce(" + functionName + ", sum, d0), tensor<float>(d0[1])(1.0), f(a,b)(a * b)), constant(" + name + "_epsilon), f(a,b)(a + b)), sum, d0), f(a,b)(a / b))";
+ final String functionExpression = "map(input, f(a)(exp(a)))";
+
+ RankProfileSearchFixture search = uncompiledFixtureWith(rankProfiles, new StoringApplicationPackage(applicationDir));
+ search.compileRankProfile("my_profile", applicationDir.append("models"));
+ search.compileRankProfile("my_profile_child", applicationDir.append("models"));
+ search.assertFirstPhaseExpression(expression, "my_profile");
+ search.assertFirstPhaseExpression(expression, "my_profile_child");
+ assertSmallConstant(name + "_epsilon", TensorType.fromSpec("tensor()"), search);
+ search.assertFunction(functionExpression, functionName, "my_profile");
+ search.assertFunction(functionExpression, functionName, "my_profile_child");
+
+ // At this point the expression is stored - copy application to another location which do not have a models dir
+ Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
+ try {
+ storedApplicationDirectory.toFile().mkdirs();
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
+ RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfiles, storedApplication);
+ searchFromStored.compileRankProfile("my_profile", applicationDir.append("models"));
+ searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models"));
+ searchFromStored.assertFirstPhaseExpression(expression, "my_profile");
+ searchFromStored.assertFirstPhaseExpression(expression, "my_profile_child");
+ assertSmallConstant(name + "_epsilon", TensorType.fromSpec("tensor()"), search);
+ searchFromStored.assertFunction(functionExpression, functionName, "my_profile");
+ searchFromStored.assertFunction(functionExpression, functionName, "my_profile_child");
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
+ }
+ }
+
+ private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) {
+ Value value = search.compiledRankProfile("my_profile").getConstants().get(name);
+ assertNotNull(value);
+ assertEquals(type, value.type());
+ }
+
private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) {
return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder",
new StoringApplicationPackage(applicationDir));
@@ -316,4 +414,39 @@ public class RankingExpressionWithOnnxTestCase {
}
}
+ static class StoringApplicationPackage extends MockApplicationPackage {
+
+ StoringApplicationPackage(Path applicationPackageWritableRoot) {
+ this(applicationPackageWritableRoot, null, null);
+ }
+
+ StoringApplicationPackage(Path applicationPackageWritableRoot, String queryProfile, String queryProfileType) {
+ super(new File(applicationPackageWritableRoot.toString()),
+ null, null, Collections.emptyList(), null,
+ null, null, false, queryProfile, queryProfileType);
+ }
+
+ @Override
+ public ApplicationFile getFile(Path file) {
+ return new MockApplicationFile(file, Path.fromString(root().toString()));
+ }
+
+ @Override
+ public List<NamedReader> getFiles(Path path, String suffix) {
+ List<NamedReader> readers = new ArrayList<>();
+ for (File file : getFileReference(path).listFiles()) {
+ if ( ! file.getName().endsWith(suffix)) continue;
+ try {
+ readers.add(new NamedReader(file.getName(), new FileReader(file)));
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+ return readers;
+ }
+
+ }
+
+
}