aboutsummaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-01-31 19:45:03 +0100
committerJon Bratseth <bratseth@oath.com>2018-01-31 19:45:03 +0100
commitc99fe08dd762542b774f8d59e82da5f6ab076aa7 (patch)
treec119ff04aa8295b6b94040b94d07500145e04e4c /config-model
parent6aa637558ae3dafc6112f2ac8fb192ede83744de (diff)
Add constant feature test
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java6
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java25
3 files changed, 26 insertions, 7 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 492f4e56465..3268522517d 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -747,7 +747,7 @@ public class RankProfile implements Serializable, Cloneable {
TypeMapContext context = new TypeMapContext();
// Add constants
- getConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.type()));
+ getSearch().getRankingConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.getTensorType()));
// Add attributes
for (SDField field : getSearch().allConcreteFields()) {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index 79bdbddbdd6..e54d348f904 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -29,12 +29,14 @@ class RankProfileSearchFixture {
}
RankProfileSearchFixture(ApplicationPackage applicationpackage, String rankProfiles) throws ParseException {
- this(applicationpackage, rankProfiles, null);
+ this(applicationpackage, rankProfiles, null, null);
}
- RankProfileSearchFixture(ApplicationPackage applicationpackage, String rankProfiles, String field) throws ParseException {
+ RankProfileSearchFixture(ApplicationPackage applicationpackage, String rankProfiles, String field, String constant)
+ throws ParseException {
SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry);
String sdContent = "search test {\n" +
+ " " + (constant != null ? constant : "") +
" document test {\n" +
(field != null ? field : "") +
" }\n" +
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 60a2c5326a0..5637806f908 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -59,6 +59,17 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
+ public void testTensorFlowReferenceWithConstantFeature() throws ParseException {
+ RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
+ "tensorflow('mnist_softmax/saved')",
+ "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }",
+ null);
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ assertConstant("Variable_1", search, Optional.of(10L));
+ assertConstant("Variable", search, Optional.of(7840L));
+ }
+
+ @Test
public void testTensorFlowReferenceWithQueryFeature() throws ParseException {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType = "<query-profile-type id='root'>" +
@@ -70,6 +81,7 @@ public class RankingExpressionWithTensorFlowTestCase {
RankProfileSearchFixture search = fixtureWith("query(mytensor)",
"tensorflow('mnist_softmax/saved')",
null,
+ null,
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
assertConstant("Variable_1", search, Optional.of(10L));
@@ -82,6 +94,7 @@ public class RankingExpressionWithTensorFlowTestCase {
RankProfileSearchFixture search = fixtureWith("attribute(mytensor)",
"tensorflow('mnist_softmax/saved')",
"field mytensor type tensor(d0[],d1[784]) { indexing: attribute }",
+ null,
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
assertConstant("Variable_1", search, Optional.of(10L));
@@ -200,6 +213,7 @@ public class RankingExpressionWithTensorFlowTestCase {
RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"tensorflow('mnist_softmax/saved')",
null,
+ null,
storedApplication);
searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile");
// Verify that the constants exists, but don't verify the content as we are not
@@ -239,17 +253,19 @@ public class RankingExpressionWithTensorFlowTestCase {
}
private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) {
- return fixtureWith(placeholderExpression, firstPhaseExpression, null,
+ return fixtureWith(placeholderExpression, firstPhaseExpression, null, null,
new StoringApplicationPackage(applicationDir));
}
- private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression, String field) {
- return fixtureWith(placeholderExpression, firstPhaseExpression, field,
+ private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression,
+ String constant, String field) {
+ return fixtureWith(placeholderExpression, firstPhaseExpression, field, constant,
new StoringApplicationPackage(applicationDir));
}
private RankProfileSearchFixture fixtureWith(String placeholderExpression,
String firstPhaseExpression,
+ String constant,
String field,
StoringApplicationPackage application) {
try {
@@ -257,12 +273,13 @@ public class RankingExpressionWithTensorFlowTestCase {
application,
" rank-profile my_profile {\n" +
" macro Placeholder() {\n" +
- " expression: " + placeholderExpression +
+ " expression: " + placeholderExpression +
" }\n" +
" first-phase {\n" +
" expression: " + firstPhaseExpression +
" }\n" +
" }",
+ constant,
field);
}
catch (ParseException e) {