summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-10-03 06:27:03 +0200
committerJon Bratseth <bratseth@gmail.com>2022-10-03 06:27:03 +0200
commit0c855ec9883c8f49cca892bed80358647c7cd9c0 (patch)
tree6b06c7c9d4c7ed536b2830024ad1bbea6a9bb837
parent539f2871e4812a463aa108639aac66c4903f3c33 (diff)
Revert "Merge pull request #24279 from vespa-engine/jonmv/revert-GC-heaven-commits"
This reverts commit 539f2871e4812a463aa108639aac66c4903f3c33, reversing changes made to aeaa3c2da60259a8ba80345591657922c90c1993.
-rw-r--r--config-model/src/main/java/com/yahoo/schema/derived/DerivedConfiguration.java41
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java32
-rw-r--r--config-model/src/test/derived/neuralnet/neuralnet.sd4
-rw-r--r--config-model/src/test/derived/neuralnet_noqueryprofile/neuralnet.sd4
-rw-r--r--config-model/src/test/derived/neuralnet_noqueryprofile/schema-info.cfg3
-rw-r--r--config-model/src/test/derived/rankingexpression/rankexpression.sd8
-rw-r--r--config-model/src/test/derived/rankingexpression/summary.cfg8
-rw-r--r--config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java68
-rw-r--r--config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java51
9 files changed, 149 insertions, 70 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/schema/derived/DerivedConfiguration.java
index 11bd14cbe46..8b07aa48a24 100644
--- a/config-model/src/main/java/com/yahoo/schema/derived/DerivedConfiguration.java
+++ b/config-model/src/main/java/com/yahoo/schema/derived/DerivedConfiguration.java
@@ -66,26 +66,31 @@ public class DerivedConfiguration implements AttributesConfig.Producer {
* schema is later modified.
*/
public DerivedConfiguration(Schema schema, DeployState deployState) {
- Validator.ensureNotNull("Schema", schema);
- this.schema = schema;
- this.queryProfiles = deployState.getQueryProfiles().getRegistry();
- this.maxUncommittedMemory = deployState.getProperties().featureFlags().maxUnCommittedMemory();
- if ( ! schema.isDocumentsOnly()) {
- streamingFields = new VsmFields(schema);
- streamingSummary = new VsmSummary(schema);
+ try {
+ Validator.ensureNotNull("Schema", schema);
+ this.schema = schema;
+ this.queryProfiles = deployState.getQueryProfiles().getRegistry();
+ this.maxUncommittedMemory = deployState.getProperties().featureFlags().maxUnCommittedMemory();
+ if (!schema.isDocumentsOnly()) {
+ streamingFields = new VsmFields(schema);
+ streamingSummary = new VsmSummary(schema);
+ }
+ if (!schema.isDocumentsOnly()) {
+ attributeFields = new AttributeFields(schema);
+ summaries = new Summaries(schema, deployState.getDeployLogger(), deployState.getProperties().featureFlags());
+ juniperrc = new Juniperrc(schema);
+ rankProfileList = new RankProfileList(schema, schema.rankExpressionFiles(), attributeFields, deployState);
+ indexingScript = new IndexingScript(schema);
+ indexInfo = new IndexInfo(schema);
+ schemaInfo = new SchemaInfo(schema, deployState.rankProfileRegistry(), summaries);
+ indexSchema = new IndexSchema(schema);
+ importedFields = new ImportedFields(schema);
+ }
+ Validation.validate(this, schema);
}
- if ( ! schema.isDocumentsOnly()) {
- attributeFields = new AttributeFields(schema);
- summaries = new Summaries(schema, deployState.getDeployLogger(), deployState.getProperties().featureFlags());
- juniperrc = new Juniperrc(schema);
- rankProfileList = new RankProfileList(schema, schema.rankExpressionFiles(), attributeFields, deployState);
- indexingScript = new IndexingScript(schema);
- indexInfo = new IndexInfo(schema);
- schemaInfo = new SchemaInfo(schema, deployState.rankProfileRegistry(), summaries);
- indexSchema = new IndexSchema(schema);
- importedFields = new ImportedFields(schema);
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("Invalid " + schema, e);
}
- Validation.validate(this, schema);
}
/**
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
index ad050d4ca63..49fb48225e7 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
@@ -36,12 +36,12 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
node = transformChildren(composite, context);
if (node instanceof OperationNode arithmetic)
- node = transformBooleanArithmetics(arithmetic);
+ node = transformBooleanArithmetics(arithmetic, context);
return node;
}
- private ExpressionNode transformBooleanArithmetics(OperationNode node) {
+ private ExpressionNode transformBooleanArithmetics(OperationNode node, TransformContext context) {
Iterator<ExpressionNode> child = node.children().iterator();
// Transform in precedence order:
@@ -51,24 +51,25 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
Operator op = it.next();
if ( ! stack.isEmpty()) {
while (stack.size() > 1 && ! op.hasPrecedenceOver(stack.peek().op)) {
- popStack(stack);
+ popStack(stack, context);
}
}
stack.push(new ChildNode(op, child.next()));
}
while (stack.size() > 1)
- popStack(stack);
+ popStack(stack, context);
return stack.getFirst().child;
}
- private void popStack(Deque<ChildNode> stack) {
+ private void popStack(Deque<ChildNode> stack, TransformContext context) {
ChildNode rhs = stack.pop();
ChildNode lhs = stack.peek();
+ boolean primitives = isDefinitelyPrimitive(lhs.child, context) && isDefinitelyPrimitive(rhs.child, context);
ExpressionNode combination;
- if (rhs.op == Operator.and)
+ if (primitives && rhs.op == Operator.and)
combination = andByIfNode(lhs.child, rhs.child);
- else if (rhs.op == Operator.or)
+ else if (primitives && rhs.op == Operator.or)
combination = orByIfNode(lhs.child, rhs.child);
else {
combination = resolve(lhs, rhs);
@@ -77,6 +78,22 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
lhs.child = combination;
}
+ private boolean isDefinitelyPrimitive(ExpressionNode node, TransformContext context) {
+ try {
+ return node.type(context.types()).rank() == 0;
+ }
+ catch (IllegalArgumentException e) {
+ // Types can only be reliably resolved top down, which has not done here.
+ // E.g
+ // function(nameArg) {
+ // attribute(nameArg)
+ // }
+ // is supported.
+ // So, we return false when something cannot be resolved.
+ return false;
+ }
+ }
+
private static OperationNode resolve(ChildNode left, ChildNode right) {
if (! (left.child instanceof OperationNode) && ! (right.child instanceof OperationNode))
return new OperationNode(left.child, right.op, right.child);
@@ -103,7 +120,6 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
joinedChildren.add(node.child);
}
-
private IfNode andByIfNode(ExpressionNode a, ExpressionNode b) {
return new IfNode(a, b, new ConstantNode(new BooleanValue(false)));
}
diff --git a/config-model/src/test/derived/neuralnet/neuralnet.sd b/config-model/src/test/derived/neuralnet/neuralnet.sd
index 54f6cefc6f4..95b7341a42f 100644
--- a/config-model/src/test/derived/neuralnet/neuralnet.sd
+++ b/config-model/src/test/derived/neuralnet/neuralnet.sd
@@ -3,6 +3,10 @@ schema neuralnet {
document neuralnet {
+ field uniqueRCount type double {
+ indexing: attribute
+ }
+
field pinned type int {
indexing: attribute
}
diff --git a/config-model/src/test/derived/neuralnet_noqueryprofile/neuralnet.sd b/config-model/src/test/derived/neuralnet_noqueryprofile/neuralnet.sd
index 073813d2198..e083b152aba 100644
--- a/config-model/src/test/derived/neuralnet_noqueryprofile/neuralnet.sd
+++ b/config-model/src/test/derived/neuralnet_noqueryprofile/neuralnet.sd
@@ -3,6 +3,10 @@ schema neuralnet {
document neuralnet {
+ field uniqueRCount type double {
+ indexing: attribute
+ }
+
field pinned type int {
indexing: attribute
}
diff --git a/config-model/src/test/derived/neuralnet_noqueryprofile/schema-info.cfg b/config-model/src/test/derived/neuralnet_noqueryprofile/schema-info.cfg
index 524a1253480..82bba81f0d5 100644
--- a/config-model/src/test/derived/neuralnet_noqueryprofile/schema-info.cfg
+++ b/config-model/src/test/derived/neuralnet_noqueryprofile/schema-info.cfg
@@ -10,6 +10,9 @@ schema[].summaryclass[].fields[].name "documentid"
schema[].summaryclass[].fields[].type "longstring"
schema[].summaryclass[].fields[].dynamic false
schema[].summaryclass[].name "attributeprefetch"
+schema[].summaryclass[].fields[].name "uniqueRCount"
+schema[].summaryclass[].fields[].type "double"
+schema[].summaryclass[].fields[].dynamic false
schema[].summaryclass[].fields[].name "pinned"
schema[].summaryclass[].fields[].type "integer"
schema[].summaryclass[].fields[].dynamic false
diff --git a/config-model/src/test/derived/rankingexpression/rankexpression.sd b/config-model/src/test/derived/rankingexpression/rankexpression.sd
index a5e7f07f6ac..7d8c79da5fb 100644
--- a/config-model/src/test/derived/rankingexpression/rankexpression.sd
+++ b/config-model/src/test/derived/rankingexpression/rankexpression.sd
@@ -3,6 +3,14 @@ schema rankexpression {
document rankexpression {
+ field nrtgmp type double {
+ indexing: attribute
+ }
+
+ field glmpfw type double {
+ indexing: attribute
+ }
+
field artist type string {
indexing: summary | index
}
diff --git a/config-model/src/test/derived/rankingexpression/summary.cfg b/config-model/src/test/derived/rankingexpression/summary.cfg
index 1c1453a8a89..b52cb055164 100644
--- a/config-model/src/test/derived/rankingexpression/summary.cfg
+++ b/config-model/src/test/derived/rankingexpression/summary.cfg
@@ -24,9 +24,15 @@ classes[].fields[].source ""
classes[].fields[].name "documentid"
classes[].fields[].command "documentid"
classes[].fields[].source ""
-classes[].id 1736696699
+classes[].id 399614584
classes[].name "attributeprefetch"
classes[].omitsummaryfeatures false
+classes[].fields[].name "nrtgmp"
+classes[].fields[].command "attribute"
+classes[].fields[].source "nrtgmp"
+classes[].fields[].name "glmpfw"
+classes[].fields[].command "attribute"
+classes[].fields[].source "glmpfw"
classes[].fields[].name "year"
classes[].fields[].command "attribute"
classes[].fields[].source "year"
diff --git a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
index d692b69d3c8..d06573f7bae 100644
--- a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
@@ -2,10 +2,13 @@
package com.yahoo.schema.expressiontransforms;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.MapTypeContext;
import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import org.junit.jupiter.api.Test;
import java.util.Map;
@@ -20,7 +23,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class BooleanExpressionTransformerTestCase {
@Test
- public void testTransformer() throws Exception {
+ public void booleanTransformation() throws Exception {
assertTransformed("if (a, b, false)", "a && b");
assertTransformed("if (a, true, b)", "a || b");
assertTransformed("if (a, true, b + c)", "a || b + c");
@@ -33,16 +36,17 @@ public class BooleanExpressionTransformerTestCase {
}
@Test
- public void testIt() throws Exception {
- assertTransformed("if(1 - 1, true, 1 - 1)", "1 - 1 || 1 - 1");
+ public void noTransformationOnTensorTypes() throws Exception {
+ var typeContext = new MapTypeContext();
+ typeContext.setType(Reference.fromIdentifier("tensorA"), TensorType.fromSpec("tensor(x{})"));
+ typeContext.setType(Reference.fromIdentifier("tensorB"), TensorType.fromSpec("tensor(x{})"));
+ assertUntransformed("tensorA && tensorB", typeContext);
+ assertTransformed("a && (tensorA * tensorB)","a && ( tensorA * tensorB)", typeContext);
}
@Test
public void testNotSkewingNonBoolean() throws Exception {
- assertTransformed("a + b + c * d + e + f", "a + b + c * d + e + f");
- var expr = new BooleanExpressionTransformer()
- .transform(new RankingExpression("a + b + c * d + e + f"),
- new TransformContext(Map.of(), new MapTypeContext()));
+ var expr = assertTransformed("a + b + c * d + e + f", "a + b + c * d + e + f");
assertTrue(expr.getRoot() instanceof OperationNode);
OperationNode root = (OperationNode) expr.getRoot();
assertEquals(5, root.operators().size());
@@ -51,41 +55,53 @@ public class BooleanExpressionTransformerTestCase {
@Test
public void testTransformPreservesPrecedence() throws Exception {
- assertUnTransformed("a");
- assertUnTransformed("a + b");
- assertUnTransformed("a + b + c");
- assertUnTransformed("a * b");
- assertUnTransformed("a + b * c + d");
- assertUnTransformed("a + b + c * d + e + f");
- assertUnTransformed("a * b + c + d + e * f");
- assertUnTransformed("(a * b) + c + d + e * f");
- assertUnTransformed("(a * b + c) + d + e * f");
- assertUnTransformed("a * (b + c) + d + e * f");
- assertUnTransformed("(a * b) + (c + (d + e)) * f");
+ assertUntransformed("a");
+ assertUntransformed("a + b");
+ assertUntransformed("a + b + c");
+ assertUntransformed("a * b");
+ assertUntransformed("a + b * c + d");
+ assertUntransformed("a + b + c * d + e + f");
+ assertUntransformed("a * b + c + d + e * f");
+ assertUntransformed("(a * b) + c + d + e * f");
+ assertUntransformed("(a * b + c) + d + e * f");
+ assertUntransformed("a * (b + c) + d + e * f");
+ assertUntransformed("(a * b) + (c + (d + e)) * f");
+ }
+
+ private void assertUntransformed(String input) throws Exception {
+ assertUntransformed(input, new MapTypeContext());
+ }
+
+ private void assertUntransformed(String input, MapTypeContext typeContext) throws Exception {
+ assertTransformed(input, input, typeContext);
}
- private void assertUnTransformed(String input) throws Exception {
- assertTransformed(input, input);
+ private RankingExpression assertTransformed(String expected, String input) throws Exception {
+ return assertTransformed(expected, input, new MapTypeContext());
}
- private void assertTransformed(String expected, String input) throws Exception {
+ private RankingExpression assertTransformed(String expected, String input, MapTypeContext typeContext) throws Exception {
+ MapContext context = contextWithSingleLetterVariables(typeContext);
var transformedExpression = new BooleanExpressionTransformer()
.transform(new RankingExpression(input),
- new TransformContext(Map.of(), new MapTypeContext()));
+ new TransformContext(Map.of(), typeContext));
assertEquals(new RankingExpression(expected), transformedExpression, "Transformed as expected");
- MapContext context = contextWithSingleLetterVariables();
var inputExpression = new RankingExpression(input);
assertEquals(inputExpression.evaluate(context).asBoolean(),
transformedExpression.evaluate(context).asBoolean(),
"Transform and original input are equivalent");
+ return transformedExpression;
}
- private MapContext contextWithSingleLetterVariables() {
+ private MapContext contextWithSingleLetterVariables(MapTypeContext typeContext) {
var context = new MapContext();
- for (int i = 0; i < 26; i++)
- context.put(Character.toString(i + 97), Math.floorMod(i, 2));
+ for (int i = 0; i < 26; i++) {
+ String name = Character.toString(i + 97);
+ typeContext.setType(Reference.fromIdentifier(name), TensorType.empty);
+ context.put(name, Math.floorMod(i, 2));
+ }
return context;
}
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java
index 83d19b010bb..2f53dba7bb4 100644
--- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java
@@ -31,6 +31,8 @@ public class RankingExpressionWithOnnxTestCase {
private final static String name = "mnist_softmax";
private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_layer_Variable), f(a,b)(a * b)), sum, d2) * 1.0, constant(mnist_softmax_layer_Variable_1) * 1.0, f(a,b)(a + b))";
+ private final static String vespaExpressionConstants = "constant mnist_softmax_layer_Variable { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }\n" +
+ "constant mnist_softmax_layer_Variable_1 { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }\n";
@AfterEach
public void removeGeneratedModelFiles() {
@@ -41,7 +43,7 @@ public class RankingExpressionWithOnnxTestCase {
void testOnnxReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
"onnx_vespa('mnist_softmax.onnx')",
- "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }",
+ vespaExpressionConstants + "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }",
null);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@@ -58,7 +60,7 @@ public class RankingExpressionWithOnnxTestCase {
queryProfileType);
RankProfileSearchFixture search = fixtureWith("query(mytensor)",
"onnx_vespa('mnist_softmax.onnx')",
- null,
+ vespaExpressionConstants,
null,
"Placeholder",
application);
@@ -70,7 +72,7 @@ public class RankingExpressionWithOnnxTestCase {
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("attribute(mytensor)",
"onnx_vespa('mnist_softmax.onnx')",
- null,
+ vespaExpressionConstants,
"field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }",
"Placeholder",
application);
@@ -88,7 +90,7 @@ public class RankingExpressionWithOnnxTestCase {
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, queryProfileType);
RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)",
"onnx_vespa('mnist_softmax.onnx')",
- "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }",
+ vespaExpressionConstants + "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }",
"field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }",
"Placeholder",
application);
@@ -99,21 +101,24 @@ public class RankingExpressionWithOnnxTestCase {
@Test
void testNestedOnnxReference() {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "5 + sum(onnx_vespa('mnist_softmax.onnx'))");
+ "5 + sum(onnx_vespa('mnist_softmax.onnx'))",
+ vespaExpressionConstants);
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
}
@Test
void testOnnxReferenceWithSpecifiedOutput() {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "onnx_vespa('mnist_softmax.onnx', 'layer_add')");
+ "onnx_vespa('mnist_softmax.onnx', 'layer_add')",
+ vespaExpressionConstants);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@Test
void testOnnxReferenceWithSpecifiedOutputAndSignature() {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "onnx_vespa('mnist_softmax.onnx', 'default.layer_add')");
+ "onnx_vespa('mnist_softmax.onnx', 'default.layer_add')",
+ vespaExpressionConstants);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@@ -177,7 +182,8 @@ public class RankingExpressionWithOnnxTestCase {
@Test
void testImportingFromStoredExpressions() throws IOException {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "onnx_vespa(\"mnist_softmax.onnx\")");
+ "onnx_vespa(\"mnist_softmax.onnx\")",
+ vespaExpressionConstants);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
// At this point the expression is stored - copy application to another location which do not have a models dir
@@ -187,12 +193,14 @@ public class RankingExpressionWithOnnxTestCase {
IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
+ String constants = "constant mnist_softmax_layer_Variable { file: ignored\ntype: tensor<float>(d0[2],d1[784]) }\n" +
+ "constant mnist_softmax_layer_Variable_1 { file: ignored\ntype: tensor<float>(d0[2],d1[784]) }\n";
RankProfileSearchFixture searchFromStored = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
- "onnx_vespa('mnist_softmax.onnx')",
- null,
- null,
- "Placeholder",
- storedApplication);
+ "onnx_vespa('mnist_softmax.onnx')",
+ constants,
+ null,
+ "Placeholder",
+ storedApplication);
searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile");
// Verify that the constants exists, but don't verify the content as we are not
// simulating file distribution in this test
@@ -221,7 +229,8 @@ public class RankingExpressionWithOnnxTestCase {
String vespaExpressionWithoutConstant =
"join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), " + name + "_layer_Variable, f(a,b)(a * b)), sum, d2) * 1.0, constant(" + name + "_layer_Variable_1) * 1.0, f(a,b)(a + b))";
- RankProfileSearchFixture search = uncompiledFixtureWith(rankProfile, new StoringApplicationPackage(applicationDir));
+ String constant = "constant mnist_softmax_layer_Variable_1 { file: ignored\ntype: tensor<float>(d0[1],d1[10]) }\n";
+ RankProfileSearchFixture search = uncompiledFixtureWith(rankProfile, new StoringApplicationPackage(applicationDir), constant);
search.compileRankProfile("my_profile", applicationDir.append("models"));
search.compileRankProfile("my_profile_child", applicationDir.append("models"));
search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
@@ -237,7 +246,7 @@ public class RankingExpressionWithOnnxTestCase {
IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
- RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfile, storedApplication);
+ RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfile, storedApplication, constant);
searchFromStored.compileRankProfile("my_profile", applicationDir.append("models"));
searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models"));
searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
@@ -326,7 +335,11 @@ public class RankingExpressionWithOnnxTestCase {
}
private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) {
- return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder",
+ return fixtureWith(placeholderExpression, firstPhaseExpression, null);
+ }
+
+ private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression, String constant) {
+ return fixtureWith(placeholderExpression, firstPhaseExpression, constant, null, "Placeholder",
new StoringApplicationPackage(applicationDir));
}
@@ -337,9 +350,13 @@ public class RankingExpressionWithOnnxTestCase {
}
private RankProfileSearchFixture uncompiledFixtureWith(String rankProfile, StoringApplicationPackage application) {
+ return uncompiledFixtureWith(rankProfile, application, null);
+ }
+
+ private RankProfileSearchFixture uncompiledFixtureWith(String rankProfile, StoringApplicationPackage application, String constant) {
try {
return new RankProfileSearchFixture(application, application.getQueryProfiles(),
- rankProfile, null, null);
+ rankProfile, constant, null);
}
catch (ParseException e) {
throw new IllegalArgumentException(e);