diff options
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java')
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java | 189 |
1 files changed, 28 insertions, 161 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 132cf936054..1fe1ebf2bb3 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -1,33 +1,23 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; -import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.io.IOUtils; -import com.yahoo.io.reader.NamedReader; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.ml.ImportedModelTester; import com.yahoo.yolean.Exceptions; import org.junit.After; import org.junit.Test; -import java.io.File; -import java.io.FileReader; import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; import java.util.Optional; +import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage; + import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; @@ -38,7 +28,8 @@ public class RankingExpressionWithOnnxTestCase { /** The model name */ private final static String name = "mnist_softmax"; - private final static String vespaExpression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(" + name + "_layer_Variable), f(a,b)(a * b)), sum, d2) * 1.0, constant(" + name + "_layer_Variable_1) * 1.0, f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))"; + private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_Variable), f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))"; + private final static String vespaExpressionWithBatchReduce = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))"; @After public void removeGeneratedModelFiles() { @@ -49,8 +40,8 @@ public class RankingExpressionWithOnnxTestCase { public void testGlobalOnnxModel() throws IOException { ImportedModelTester tester = new ImportedModelTester(name, applicationDir); VespaModel model = tester.createVespaModel(); - tester.assertLargeConstant(name + "_layer_Variable_1", model, Optional.of(10L)); - tester.assertLargeConstant(name + "_layer_Variable", model, Optional.of(7840L)); + tester.assertLargeConstant(name + "_Variable_1", model, Optional.of(10L)); + tester.assertLargeConstant(name + "_Variable", model, Optional.of(7840L)); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedAppDir = applicationDir.append("copy"); @@ -61,8 +52,8 @@ public class RankingExpressionWithOnnxTestCase { storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir); VespaModel storedModel = storedTester.createVespaModel(); - tester.assertLargeConstant(name + "_layer_Variable_1", storedModel, Optional.of(10L)); - tester.assertLargeConstant(name + "_layer_Variable", storedModel, Optional.of(7840L)); + tester.assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L)); + tester.assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L)); } finally { IOUtils.recursiveDeleteDir(storedAppDir.toFile()); @@ -73,7 +64,7 @@ public class RankingExpressionWithOnnxTestCase { public void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "onnx('mnist_softmax.onnx')", - "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }", + "constant mytensor { file: ignored\ntype: tensor<float>(d0[7],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @@ -83,7 +74,7 @@ public class RankingExpressionWithOnnxTestCase { String queryProfile = "<query-profile id='default' type='root'/>"; String queryProfileType = "<query-profile-type id='root'>" + - " <field name='query(mytensor)' type='tensor<float>(d0[1],d1[784])'/>" + + " <field name='query(mytensor)' type='tensor<float>(d0[3],d1[784])'/>" + "</query-profile-type>"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, @@ -106,7 +97,7 @@ public class RankingExpressionWithOnnxTestCase { "field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }", "Placeholder", application); - search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + search.assertFirstPhaseExpression(vespaExpressionWithBatchReduce, "my_profile"); } @@ -124,28 +115,28 @@ public class RankingExpressionWithOnnxTestCase { "field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }", "Placeholder", application); - search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + search.assertFirstPhaseExpression(vespaExpressionWithBatchReduce, "my_profile"); } @Test public void testNestedOnnxReference() { - RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", + RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", "5 + sum(onnx('mnist_softmax.onnx'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); } @Test public void testOnnxReferenceWithSpecifiedOutput() { - RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "onnx('mnist_softmax.onnx', 'layer_add')"); + RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", + "onnx('mnist_softmax.onnx', 'add')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test public void testOnnxReferenceWithSpecifiedOutputAndSignature() { - RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "onnx('mnist_softmax.onnx', 'default.layer_add')"); + RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", + "onnx('mnist_softmax.onnx', 'default.add')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @@ -167,7 +158,7 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx'): " + - "Model refers input 'Placeholder' of type tensor<float>(d0[1],d1[784]) but this function is " + + "Model refers input 'Placeholder' of type tensor<float>(d0[],d1[784]) but this function is " + "not present in rank profile 'my_profile'", Exceptions.toMessageString(expected)); } @@ -176,7 +167,7 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testOnnxReferenceWithWrongFunctionType() { try { - RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d5[10])(0.0)", + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d5[10])(0.0)", "onnx('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); @@ -184,8 +175,8 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx'): " + - "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[1],d1[784]), " + - "but this function returns tensor(d0[1],d5[10])", + "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[],d1[784]), " + + "but this function returns tensor(d0[2],d5[10])", Exceptions.toMessageString(expected)); } } @@ -201,14 +192,14 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx','y'): " + - "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.layer_add", + "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.add", Exceptions.toMessageString(expected)); } } @Test public void testImportingFromStoredExpressions() throws IOException { - RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", + RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", "onnx('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -235,29 +226,26 @@ public class RankingExpressionWithOnnxTestCase { } @Test - public void testImportingFromStoredExpressionsWithFunctionOverridingConstantAndInheritance() throws IOException { + public void testImportingFromStoredExpressionsWithFunctionOverridingConstant() throws IOException { String rankProfile = " rank-profile my_profile {\n" + " function Placeholder() {\n" + - " expression: tensor<float>(d0[1],d1[784])(0.0)\n" + + " expression: tensor<float>(d0[2],d1[784])(0.0)\n" + " }\n" + - " function " + name + "_layer_Variable() {\n" + + " function " + name + "_Variable() {\n" + " expression: tensor<float>(d1[10],d2[784])(0.0)\n" + " }\n" + " first-phase {\n" + " expression: onnx('mnist_softmax.onnx')" + " }\n" + - " }" + - " rank-profile my_profile_child inherits my_profile {\n" + " }"; + String vespaExpressionWithoutConstant = - "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), " + name + "_layer_Variable, f(a,b)(a * b)), sum, d2) * 1.0, constant(" + name + "_layer_Variable_1) * 1.0, f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))"; + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), " + name + "_Variable, f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))"; RankProfileSearchFixture search = uncompiledFixtureWith(rankProfile, new StoringApplicationPackage(applicationDir)); search.compileRankProfile("my_profile", applicationDir.append("models")); - search.compileRankProfile("my_profile_child", applicationDir.append("models")); search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); - search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child"); assertNull("Constant overridden by function is not added", search.search().rankingConstants().get( name + "_Variable")); @@ -271,9 +259,7 @@ public class RankingExpressionWithOnnxTestCase { StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfile, storedApplication); searchFromStored.compileRankProfile("my_profile", applicationDir.append("models")); - searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models")); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); - searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child"); assertNull("Constant overridden by function is not added", searchFromStored.search().rankingConstants().get( name + "_Variable")); } finally { @@ -281,90 +267,6 @@ public class RankingExpressionWithOnnxTestCase { } } - @Test - public void testReduceBatchDimension() { - final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(" + name + "_layer_Variable), f(a,b)(a * b)), sum, d2) * 1.0, constant(" + name + "_layer_Variable_1) * 1.0, f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))"; - RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "onnx('mnist_softmax.onnx')"); - search.assertFirstPhaseExpression(expression, "my_profile"); - } - - @Test - public void testFunctionGeneration() { - final String name = "small_constants_and_functions"; - final String rankProfiles = - " rank-profile my_profile {\n" + - " function input() {\n" + - " expression: tensor<float>(d0[3])(0.0)\n" + - " }\n" + - " first-phase {\n" + - " expression: onnx('" + name + ".onnx')" + - " }\n" + - " }"; - final String functionName = "imported_ml_function_" + name + "_exp_output"; - final String expression = "join(" + functionName + ", reduce(join(join(reduce(" + functionName + ", sum, d0), tensor<float>(d0[1])(1.0), f(a,b)(a * b)), constant(" + name + "_epsilon), f(a,b)(a + b)), sum, d0), f(a,b)(a / b))"; - final String functionExpression = "map(input, f(a)(exp(a)))"; - - RankProfileSearchFixture search = uncompiledFixtureWith(rankProfiles, new StoringApplicationPackage(applicationDir)); - search.compileRankProfile("my_profile", applicationDir.append("models")); - search.assertFirstPhaseExpression(expression, "my_profile"); - search.assertFunction(functionExpression, functionName, "my_profile"); - } - - @Test - public void testImportingFromStoredExpressionsWithSmallConstantsAndInheritance() throws IOException { - final String name = "small_constants_and_functions"; - final String rankProfiles = - " rank-profile my_profile {\n" + - " function input() {\n" + - " expression: tensor<float>(d0[3])(0.0)\n" + - " }\n" + - " first-phase {\n" + - " expression: onnx('" + name + ".onnx')" + - " }\n" + - " }" + - " rank-profile my_profile_child inherits my_profile {\n" + - " }"; - final String functionName = "imported_ml_function_" + name + "_exp_output"; - final String expression = "join(" + functionName + ", reduce(join(join(reduce(" + functionName + ", sum, d0), tensor<float>(d0[1])(1.0), f(a,b)(a * b)), constant(" + name + "_epsilon), f(a,b)(a + b)), sum, d0), f(a,b)(a / b))"; - final String functionExpression = "map(input, f(a)(exp(a)))"; - - RankProfileSearchFixture search = uncompiledFixtureWith(rankProfiles, new StoringApplicationPackage(applicationDir)); - search.compileRankProfile("my_profile", applicationDir.append("models")); - search.compileRankProfile("my_profile_child", applicationDir.append("models")); - search.assertFirstPhaseExpression(expression, "my_profile"); - search.assertFirstPhaseExpression(expression, "my_profile_child"); - assertSmallConstant(name + "_epsilon", TensorType.fromSpec("tensor()"), search); - search.assertFunction(functionExpression, functionName, "my_profile"); - search.assertFunction(functionExpression, functionName, "my_profile_child"); - - // At this point the expression is stored - copy application to another location which do not have a models dir - Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); - try { - storedApplicationDirectory.toFile().mkdirs(); - IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), - storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); - StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfiles, storedApplication); - searchFromStored.compileRankProfile("my_profile", applicationDir.append("models")); - searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models")); - searchFromStored.assertFirstPhaseExpression(expression, "my_profile"); - searchFromStored.assertFirstPhaseExpression(expression, "my_profile_child"); - assertSmallConstant(name + "_epsilon", TensorType.fromSpec("tensor()"), search); - searchFromStored.assertFunction(functionExpression, functionName, "my_profile"); - searchFromStored.assertFunction(functionExpression, functionName, "my_profile_child"); - } - finally { - IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); - } - } - - private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) { - Value value = search.compiledRankProfile("my_profile").getConstants().get(name); - assertNotNull(value); - assertEquals(type, value.type()); - } - private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) { return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder", new StoringApplicationPackage(applicationDir)); @@ -414,39 +316,4 @@ public class RankingExpressionWithOnnxTestCase { } } - static class StoringApplicationPackage extends MockApplicationPackage { - - StoringApplicationPackage(Path applicationPackageWritableRoot) { - this(applicationPackageWritableRoot, null, null); - } - - StoringApplicationPackage(Path applicationPackageWritableRoot, String queryProfile, String queryProfileType) { - super(new File(applicationPackageWritableRoot.toString()), - null, null, Collections.emptyList(), null, - null, null, false, queryProfile, queryProfileType); - } - - @Override - public ApplicationFile getFile(Path file) { - return new MockApplicationFile(file, Path.fromString(root().toString())); - } - - @Override - public List<NamedReader> getFiles(Path path, String suffix) { - List<NamedReader> readers = new ArrayList<>(); - for (File file : getFileReference(path).listFiles()) { - if ( ! file.getName().endsWith(suffix)) continue; - try { - readers.add(new NamedReader(file.getName(), new FileReader(file))); - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - return readers; - } - - } - - } |