summaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java')
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java93
1 files changed, 68 insertions, 25 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 7228af2b0de..9a96555bb78 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -6,6 +6,7 @@ import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
+import com.yahoo.io.reader.NamedReader;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankingConstant;
@@ -22,10 +23,12 @@ import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
+import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.io.UncheckedIOException;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
@@ -156,6 +159,7 @@ public class RankingExpressionWithTensorFlowTestCase {
" expression: tensorflow('mnist_softmax/saved')" +
" }\n" +
" }");
+ search.compileRankProfile("my_profile");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
}
@@ -196,7 +200,9 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved','serving_defaultz'): " +
- "Model does not have the specified signature 'serving_defaultz'",
+ "No expressions available in model 'mnist_softmax_saved'",
+// "No expressions named 'serving_defaultz' in model 'mnist_softmax/saved'. "+
+// "Available expressions: mnist_softmax_saved.serving_default.y",
Exceptions.toMessageString(expected));
}
}
@@ -212,7 +218,9 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved','serving_default','x'): " +
- "Model does not have the specified output 'x'",
+ "No expressions available in model 'mnist_softmax_saved'",
+// "No expression 'mnist_softmax_saved.serving_default.x' in model 'mnist_softmax/saved'. " +
+// "Available expressions: mnist_softmax_saved.serving_default.y",
Exceptions.toMessageString(expected));
}
}
@@ -251,8 +259,8 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
- public void testImportingFromStoredExpressionsWithMacroOverridingConstant() throws IOException {
- String rankProfile =
+ public void testImportingFromStoredExpressionsWithMacroOverridingConstantAndInheritance() throws IOException {
+ String rankProfiles =
" rank-profile my_profile {\n" +
" macro Placeholder() {\n" +
" expression: tensor(d0[2],d1[784])(0.0)\n" +
@@ -263,13 +271,17 @@ public class RankingExpressionWithTensorFlowTestCase {
" first-phase {\n" +
" expression: tensorflow('mnist_softmax/saved')" +
" }\n" +
+ " }" +
+ " rank-profile my_profile_child inherits my_profile {\n" +
" }";
-
String vespaExpressionWithoutConstant =
"join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), mnist_softmax_saved_layer_Variable_read, f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))";
- RankProfileSearchFixture search = fixtureWith(rankProfile, new StoringApplicationPackage(applicationDir));
+ RankProfileSearchFixture search = fixtureWithUncompiled(rankProfiles, new StoringApplicationPackage(applicationDir));
+ search.compileRankProfile("my_profile");
+ search.compileRankProfile("my_profile_child");
search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
+ search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child");
assertNull("Constant overridden by macro is not added",
search.search().getRankingConstants().get("mnist_softmax_saved_layer_Variable_read"));
@@ -282,8 +294,11 @@ public class RankingExpressionWithTensorFlowTestCase {
IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
- RankProfileSearchFixture searchFromStored = fixtureWith(rankProfile, storedApplication);
+ RankProfileSearchFixture searchFromStored = fixtureWithUncompiled(rankProfiles, storedApplication);
+ searchFromStored.compileRankProfile("my_profile");
+ searchFromStored.compileRankProfile("my_profile_child");
searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
+ searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child");
assertNull("Constant overridden by macro is not added",
searchFromStored.search().getRankingConstants().get("mnist_softmax_saved_layer_Variable_read"));
assertLargeConstant("mnist_softmax_saved_layer_Variable_1_read", searchFromStored, Optional.of(10L));
@@ -297,7 +312,7 @@ public class RankingExpressionWithTensorFlowTestCase {
public void testTensorFlowReduceBatchDimension() {
final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
- "tensorflow('mnist_softmax/saved')");
+ "tensorflow('mnist_softmax/saved')");
search.assertFirstPhaseExpression(expression, "my_profile");
assertLargeConstant("mnist_softmax_saved_layer_Variable_1_read", search, Optional.of(10L));
assertLargeConstant("mnist_softmax_saved_layer_Variable_read", search, Optional.of(7840L));
@@ -321,22 +336,33 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
- public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
+ public void testImportingFromStoredExpressionsWithSmallConstantsAndInheritance() throws IOException {
+ final String rankProfiles =
+ " rank-profile my_profile {\n" +
+ " macro input() {\n" +
+ " expression: tensor(d0[1],d1[784])(0.0)\n" +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: tensorflow('mnist/saved')" +
+ " }\n" +
+ " }" +
+ " rank-profile my_profile_child inherits my_profile {\n" +
+ " }";
+
final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))";
final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
- StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
- RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
- "tensorflow('mnist/saved')",
- null,
- null,
- "input",
- application);
+ RankProfileSearchFixture search = fixtureWithUncompiled(rankProfiles, new StoringApplicationPackage(applicationDir));
+ search.compileRankProfile("my_profile");
+ search.compileRankProfile("my_profile_child");
search.assertFirstPhaseExpression(expression, "my_profile");
+ search.assertFirstPhaseExpression(expression, "my_profile_child");
assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile_child");
search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile");
+ search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile_child");
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
@@ -345,16 +371,16 @@ public class RankingExpressionWithTensorFlowTestCase {
IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
- RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[1],d1[784])(0.0)",
- "tensorflow('mnist/saved')",
- null,
- null,
- "input",
- storedApplication);
+ RankProfileSearchFixture searchFromStored = fixtureWithUncompiled(rankProfiles, storedApplication);
+ searchFromStored.compileRankProfile("my_profile");
+ searchFromStored.compileRankProfile("my_profile_child");
searchFromStored.assertFirstPhaseExpression(expression, "my_profile");
+ searchFromStored.assertFirstPhaseExpression(expression, "my_profile_child");
assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
searchFromStored.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ searchFromStored.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile_child");
searchFromStored.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile");
+ searchFromStored.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile_child");
}
finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
@@ -362,7 +388,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) {
- Value value = search.rankProfile("my_profile").getConstants().get(name);
+ Value value = search.compiledRankProfile("my_profile").getConstants().get(name);
assertNotNull(value);
assertEquals(type, value.type());
}
@@ -410,7 +436,7 @@ public class RankingExpressionWithTensorFlowTestCase {
String macroName,
StoringApplicationPackage application) {
try {
- return new RankProfileSearchFixture(
+ RankProfileSearchFixture fixture = new RankProfileSearchFixture(
application,
application.getQueryProfiles(),
" rank-profile my_profile {\n" +
@@ -423,13 +449,15 @@ public class RankingExpressionWithTensorFlowTestCase {
" }",
constant,
field);
+ fixture.compileRankProfile("my_profile");
+ return fixture;
}
catch (ParseException e) {
throw new IllegalArgumentException(e);
}
}
- private RankProfileSearchFixture fixtureWith(String rankProfile, StoringApplicationPackage application) {
+ private RankProfileSearchFixture fixtureWithUncompiled(String rankProfile, StoringApplicationPackage application) {
try {
return new RankProfileSearchFixture(application, application.getQueryProfiles(),
rankProfile, null, null);
@@ -463,6 +491,21 @@ public class RankingExpressionWithTensorFlowTestCase {
return new StoringApplicationPackageFile(file, Path.fromString(root.toString()));
}
+ @Override
+ public List<NamedReader> getFiles(Path path, String suffix) {
+ List<NamedReader> readers = new ArrayList<>();
+ for (File file : getFileReference(path).listFiles()) {
+ if ( ! file.getName().endsWith(suffix)) continue;
+ try {
+ readers.add(new NamedReader(file.getName(), new FileReader(file)));
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+ return readers;
+ }
+
}
static class StoringApplicationPackageFile extends ApplicationFile {