diff options
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java')
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java | 189 |
1 files changed, 161 insertions, 28 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 1fe1ebf2bb3..132cf936054 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -1,23 +1,33 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; +import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.io.IOUtils; +import com.yahoo.io.reader.NamedReader; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.ml.ImportedModelTester; import com.yahoo.yolean.Exceptions; import org.junit.After; import org.junit.Test; +import java.io.File; +import java.io.FileReader; import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import java.util.Optional; -import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage; - import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; @@ -28,8 +38,7 @@ public class RankingExpressionWithOnnxTestCase { /** The model name */ private final static String name = "mnist_softmax"; - private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_Variable), f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))"; - private final static String vespaExpressionWithBatchReduce = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))"; + private final static String vespaExpression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(" + name + "_layer_Variable), f(a,b)(a * b)), sum, d2) * 1.0, constant(" + name + "_layer_Variable_1) * 1.0, f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))"; @After public void removeGeneratedModelFiles() { @@ -40,8 +49,8 @@ public class RankingExpressionWithOnnxTestCase { public void testGlobalOnnxModel() throws IOException { ImportedModelTester tester = new ImportedModelTester(name, applicationDir); VespaModel model = tester.createVespaModel(); - tester.assertLargeConstant(name + "_Variable_1", model, Optional.of(10L)); - tester.assertLargeConstant(name + "_Variable", model, Optional.of(7840L)); + tester.assertLargeConstant(name + "_layer_Variable_1", model, Optional.of(10L)); + tester.assertLargeConstant(name + "_layer_Variable", model, Optional.of(7840L)); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedAppDir = applicationDir.append("copy"); @@ -52,8 +61,8 @@ public class RankingExpressionWithOnnxTestCase { storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir); VespaModel storedModel = storedTester.createVespaModel(); - tester.assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L)); - tester.assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L)); + tester.assertLargeConstant(name + "_layer_Variable_1", storedModel, Optional.of(10L)); + tester.assertLargeConstant(name + "_layer_Variable", storedModel, Optional.of(7840L)); } finally { IOUtils.recursiveDeleteDir(storedAppDir.toFile()); @@ -64,7 +73,7 @@ public class RankingExpressionWithOnnxTestCase { public void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "onnx('mnist_softmax.onnx')", - "constant mytensor { file: ignored\ntype: tensor<float>(d0[7],d1[784]) }", + "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @@ -74,7 +83,7 @@ public class RankingExpressionWithOnnxTestCase { String queryProfile = "<query-profile id='default' type='root'/>"; String queryProfileType = "<query-profile-type id='root'>" + - " <field name='query(mytensor)' type='tensor<float>(d0[3],d1[784])'/>" + + " <field name='query(mytensor)' type='tensor<float>(d0[1],d1[784])'/>" + "</query-profile-type>"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, @@ -97,7 +106,7 @@ public class RankingExpressionWithOnnxTestCase { "field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }", "Placeholder", application); - search.assertFirstPhaseExpression(vespaExpressionWithBatchReduce, "my_profile"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @@ -115,28 +124,28 @@ public class RankingExpressionWithOnnxTestCase { "field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }", "Placeholder", application); - search.assertFirstPhaseExpression(vespaExpressionWithBatchReduce, "my_profile"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test public void testNestedOnnxReference() { - RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", + RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", "5 + sum(onnx('mnist_softmax.onnx'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); } @Test public void testOnnxReferenceWithSpecifiedOutput() { - RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", - "onnx('mnist_softmax.onnx', 'add')"); + RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", + "onnx('mnist_softmax.onnx', 'layer_add')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test public void testOnnxReferenceWithSpecifiedOutputAndSignature() { - RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", - "onnx('mnist_softmax.onnx', 'default.add')"); + RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", + "onnx('mnist_softmax.onnx', 'default.layer_add')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @@ -158,7 +167,7 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx'): " + - "Model refers input 'Placeholder' of type tensor<float>(d0[],d1[784]) but this function is " + + "Model refers input 'Placeholder' of type tensor<float>(d0[1],d1[784]) but this function is " + "not present in rank profile 'my_profile'", Exceptions.toMessageString(expected)); } @@ -167,7 +176,7 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testOnnxReferenceWithWrongFunctionType() { try { - RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d5[10])(0.0)", + RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d5[10])(0.0)", "onnx('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); @@ -175,8 +184,8 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx'): " + - "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[],d1[784]), " + - "but this function returns tensor(d0[2],d5[10])", + "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[1],d1[784]), " + + "but this function returns tensor(d0[1],d5[10])", Exceptions.toMessageString(expected)); } } @@ -192,14 +201,14 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx','y'): " + - "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.add", + "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.layer_add", Exceptions.toMessageString(expected)); } } @Test public void testImportingFromStoredExpressions() throws IOException { - RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", + RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", "onnx('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -226,26 +235,29 @@ public class RankingExpressionWithOnnxTestCase { } @Test - public void testImportingFromStoredExpressionsWithFunctionOverridingConstant() throws IOException { + public void testImportingFromStoredExpressionsWithFunctionOverridingConstantAndInheritance() throws IOException { String rankProfile = " rank-profile my_profile {\n" + " function Placeholder() {\n" + - " expression: tensor<float>(d0[2],d1[784])(0.0)\n" + + " expression: tensor<float>(d0[1],d1[784])(0.0)\n" + " }\n" + - " function " + name + "_Variable() {\n" + + " function " + name + "_layer_Variable() {\n" + " expression: tensor<float>(d1[10],d2[784])(0.0)\n" + " }\n" + " first-phase {\n" + " expression: onnx('mnist_softmax.onnx')" + " }\n" + + " }" + + " rank-profile my_profile_child inherits my_profile {\n" + " }"; - String vespaExpressionWithoutConstant = - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), " + name + "_Variable, f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))"; + "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), " + name + "_layer_Variable, f(a,b)(a * b)), sum, d2) * 1.0, constant(" + name + "_layer_Variable_1) * 1.0, f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))"; RankProfileSearchFixture search = uncompiledFixtureWith(rankProfile, new StoringApplicationPackage(applicationDir)); search.compileRankProfile("my_profile", applicationDir.append("models")); + search.compileRankProfile("my_profile_child", applicationDir.append("models")); search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); + search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child"); assertNull("Constant overridden by function is not added", search.search().rankingConstants().get( name + "_Variable")); @@ -259,7 +271,9 @@ public class RankingExpressionWithOnnxTestCase { StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfile, storedApplication); searchFromStored.compileRankProfile("my_profile", applicationDir.append("models")); + searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models")); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); + searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child"); assertNull("Constant overridden by function is not added", searchFromStored.search().rankingConstants().get( name + "_Variable")); } finally { @@ -267,6 +281,90 @@ public class RankingExpressionWithOnnxTestCase { } } + @Test + public void testReduceBatchDimension() { + final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(" + name + "_layer_Variable), f(a,b)(a * b)), sum, d2) * 1.0, constant(" + name + "_layer_Variable_1) * 1.0, f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))"; + RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", + "onnx('mnist_softmax.onnx')"); + search.assertFirstPhaseExpression(expression, "my_profile"); + } + + @Test + public void testFunctionGeneration() { + final String name = "small_constants_and_functions"; + final String rankProfiles = + " rank-profile my_profile {\n" + + " function input() {\n" + + " expression: tensor<float>(d0[3])(0.0)\n" + + " }\n" + + " first-phase {\n" + + " expression: onnx('" + name + ".onnx')" + + " }\n" + + " }"; + final String functionName = "imported_ml_function_" + name + "_exp_output"; + final String expression = "join(" + functionName + ", reduce(join(join(reduce(" + functionName + ", sum, d0), tensor<float>(d0[1])(1.0), f(a,b)(a * b)), constant(" + name + "_epsilon), f(a,b)(a + b)), sum, d0), f(a,b)(a / b))"; + final String functionExpression = "map(input, f(a)(exp(a)))"; + + RankProfileSearchFixture search = uncompiledFixtureWith(rankProfiles, new StoringApplicationPackage(applicationDir)); + search.compileRankProfile("my_profile", applicationDir.append("models")); + search.assertFirstPhaseExpression(expression, "my_profile"); + search.assertFunction(functionExpression, functionName, "my_profile"); + } + + @Test + public void testImportingFromStoredExpressionsWithSmallConstantsAndInheritance() throws IOException { + final String name = "small_constants_and_functions"; + final String rankProfiles = + " rank-profile my_profile {\n" + + " function input() {\n" + + " expression: tensor<float>(d0[3])(0.0)\n" + + " }\n" + + " first-phase {\n" + + " expression: onnx('" + name + ".onnx')" + + " }\n" + + " }" + + " rank-profile my_profile_child inherits my_profile {\n" + + " }"; + final String functionName = "imported_ml_function_" + name + "_exp_output"; + final String expression = "join(" + functionName + ", reduce(join(join(reduce(" + functionName + ", sum, d0), tensor<float>(d0[1])(1.0), f(a,b)(a * b)), constant(" + name + "_epsilon), f(a,b)(a + b)), sum, d0), f(a,b)(a / b))"; + final String functionExpression = "map(input, f(a)(exp(a)))"; + + RankProfileSearchFixture search = uncompiledFixtureWith(rankProfiles, new StoringApplicationPackage(applicationDir)); + search.compileRankProfile("my_profile", applicationDir.append("models")); + search.compileRankProfile("my_profile_child", applicationDir.append("models")); + search.assertFirstPhaseExpression(expression, "my_profile"); + search.assertFirstPhaseExpression(expression, "my_profile_child"); + assertSmallConstant(name + "_epsilon", TensorType.fromSpec("tensor()"), search); + search.assertFunction(functionExpression, functionName, "my_profile"); + search.assertFunction(functionExpression, functionName, "my_profile_child"); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); + try { + storedApplicationDirectory.toFile().mkdirs(); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); + RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfiles, storedApplication); + searchFromStored.compileRankProfile("my_profile", applicationDir.append("models")); + searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models")); + searchFromStored.assertFirstPhaseExpression(expression, "my_profile"); + searchFromStored.assertFirstPhaseExpression(expression, "my_profile_child"); + assertSmallConstant(name + "_epsilon", TensorType.fromSpec("tensor()"), search); + searchFromStored.assertFunction(functionExpression, functionName, "my_profile"); + searchFromStored.assertFunction(functionExpression, functionName, "my_profile_child"); + } + finally { + IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); + } + } + + private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) { + Value value = search.compiledRankProfile("my_profile").getConstants().get(name); + assertNotNull(value); + assertEquals(type, value.type()); + } + private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) { return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder", new StoringApplicationPackage(applicationDir)); @@ -316,4 +414,39 @@ public class RankingExpressionWithOnnxTestCase { } } + static class StoringApplicationPackage extends MockApplicationPackage { + + StoringApplicationPackage(Path applicationPackageWritableRoot) { + this(applicationPackageWritableRoot, null, null); + } + + StoringApplicationPackage(Path applicationPackageWritableRoot, String queryProfile, String queryProfileType) { + super(new File(applicationPackageWritableRoot.toString()), + null, null, Collections.emptyList(), null, + null, null, false, queryProfile, queryProfileType); + } + + @Override + public ApplicationFile getFile(Path file) { + return new MockApplicationFile(file, Path.fromString(root().toString())); + } + + @Override + public List<NamedReader> getFiles(Path path, String suffix) { + List<NamedReader> readers = new ArrayList<>(); + for (File file : getFileReference(path).listFiles()) { + if ( ! file.getName().endsWith(suffix)) continue; + try { + readers.add(new NamedReader(file.getName(), new FileReader(file))); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + return readers; + } + + } + + } |