aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java43
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java5
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java1
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java11
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java260
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java15
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java8
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java196
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java13
10 files changed, 416 insertions, 138 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index 6de7c985326..65443117c0a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -47,12 +47,15 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
private final SortedSet<Reference> queryFeaturesNotDeclared;
private boolean tensorsAreUsed;
+ private final MapEvaluationTypeContext parent;
+
MapEvaluationTypeContext(Collection<ExpressionFunction> functions, Map<Reference, TensorType> featureTypes) {
super(functions);
this.featureTypes.putAll(featureTypes);
this.currentResolutionCallStack = new ArrayDeque<>();
this.queryFeaturesNotDeclared = new TreeSet<>();
tensorsAreUsed = false;
+ parent = null;
}
private MapEvaluationTypeContext(Map<String, ExpressionFunction> functions,
@@ -60,12 +63,14 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
Map<Reference, TensorType> featureTypes,
Deque<Reference> currentResolutionCallStack,
SortedSet<Reference> queryFeaturesNotDeclared,
- boolean tensorsAreUsed) {
+ boolean tensorsAreUsed,
+ MapEvaluationTypeContext parent) {
super(functions, bindings);
this.featureTypes.putAll(featureTypes);
this.currentResolutionCallStack = currentResolutionCallStack;
this.queryFeaturesNotDeclared = queryFeaturesNotDeclared;
this.tensorsAreUsed = tensorsAreUsed;
+ this.parent = parent;
}
public void setType(Reference reference, TensorType type) {
@@ -82,16 +87,45 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
resolvedTypes.clear();
}
+ private TensorType resolvedType(Reference reference, int depth) {
+// System.out.println(indent + "In resolvedtype - resolving type for " + reference.toString());
+ TensorType resolvedType = resolvedTypes.get(reference);
+ if (resolvedType != null) {
+// System.out.println("Found previously resolved type for " + reference + " at depth " + depth + ": (" + resolvedType + ")");
+ return resolvedType;
+ }
+ if (parent != null) return parent.resolvedType(reference, depth + 1); // what about argument types? Careful with this!
+// System.out.println("Could NOT find type for " + reference + " - down to depth " + depth);
+ return null;
+ }
+
+ private MapEvaluationTypeContext findOriginalParent() {
+ if (parent != null)
+ return parent.findOriginalParent();
+ return this;
+ }
+
@Override
public TensorType getType(Reference reference) {
// computeIfAbsent without concurrent modification due to resolve adding more resolved entries:
- TensorType resolvedType = resolvedTypes.get(reference);
+ // TensorType resolvedType = resolvedTypes.get(reference);
+ TensorType resolvedType = resolvedType(reference, 0);
if (resolvedType != null) return resolvedType;
resolvedType = resolveType(reference);
if (resolvedType == null)
return defaultTypeOf(reference); // Don't store fallback to default as we may know more later
- resolvedTypes.put(reference, resolvedType);
+
+// System.out.println("Resolved type of " + reference + ": (" + resolvedType + ")");
+
+ // Må inn her med et konsept av global eller lokal.
+ // For globale - legg i lavest parent!
+ MapEvaluationTypeContext originalParent = findOriginalParent();
+ if (originalParent == null) {
+ originalParent = this;
+ }
+ originalParent.resolvedTypes.put(reference, resolvedType);
+
if (resolvedType.rank() > 0)
tensorsAreUsed = true;
return resolvedType;
@@ -103,6 +137,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) +
" -> " + reference);
+
// Bound to a function argument, and not to a same-named identifier (which would lead to a loop)?
Optional<String> binding = boundIdentifier(reference);
if (binding.isPresent() && ! binding.get().equals(reference.toString())) {
@@ -254,7 +289,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
featureTypes,
currentResolutionCallStack,
queryFeaturesNotDeclared,
- tensorsAreUsed);
+ tensorsAreUsed, this);
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 23eb814de81..ea126123a25 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -680,11 +680,12 @@ public class RankProfile implements Cloneable {
Map<String, RankingExpressionFunction> inlineFunctions =
compileFunctions(this::getInlineFunctions, queryProfiles, featureTypes, importedModels, Collections.emptyMap(), expressionTransforms);
+ firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms);
+ secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms);
+
// Function compiling second pass: compile all functions and insert previously compiled inline functions
functions = compileFunctions(this::getFunctions, queryProfiles, featureTypes, importedModels, inlineFunctions, expressionTransforms);
- firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms);
- secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms);
}
private void checkNameCollisions(Map<String, RankingExpressionFunction> functions, Map<String, Value> constants) {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index 1a22b98fd9f..3578cc786ed 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -222,7 +222,9 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
List<ExpressionFunction> functionExpressions) {
SerializationContext context = new SerializationContext(functionExpressions);
for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) {
+ System.out.println("Deriving: " + e.getKey());
String expressionString = e.getValue().function().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString();
+ System.out.println("-> Done deriving: " + e.getKey() + ": " + expressionString);
context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString);
for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet())
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java
index 89b8889b4ae..8d9098a10f1 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java
@@ -79,6 +79,7 @@ public class RankingExpressionTypeResolver extends Processor {
}
context.forgetResolvedTypes();
+ System.out.println("Resolving type for " + function.getKey());
TensorType type = resolveType(expressionFunction.getBody(), "function '" + function.getKey() + "'", context);
function.getValue().setReturnType(type);
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
index b6f7ab4ff62..d3e029b8de5 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
@@ -102,7 +102,8 @@ public class RankSetupValidator extends Validator {
}
private void deleteTempDir(File dir) {
- IOUtils.recursiveDeleteDir(dir);
+ System.out.println("Here we were supposed to delete tmpdir: " + dir.getAbsolutePath());
+// IOUtils.recursiveDeleteDir(dir);
}
private void writeConfigs(String dir, AbstractConfigProducer<?> producer) throws IOException {
@@ -133,7 +134,13 @@ public class RankSetupValidator extends Validator {
}
private static void writeConfig(String dir, String configName, ConfigInstance config) throws IOException {
- IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(ConfigInstance.serialize(config)), false);
+
+ String output = StringUtilities.implodeMultiline(ConfigInstance.serialize(config));
+ System.out.println("Writing config for in " + dir + " for configName '" + configName + "' ");
+ System.out.println(output);
+ IOUtils.writeFile(dir + configName, output, false);
+
+// IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(ConfigInstance.serialize(config)), false);
}
private boolean execValidate(String configId, SearchCluster sc, String sdName, DeployLogger deployLogger) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index c3d6f457ce8..9f649bc820a 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -386,138 +386,138 @@ public class ConvertedModel {
*/
private static void reduceBatchDimensions(RankingExpression expression, ImportedMlModel model,
RankProfile profile, QueryProfileRegistry queryProfiles) {
- MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles);
-
- // Add any missing inputs for type resolution
- Set<String> functionNames = new HashSet<>();
- addFunctionNamesIn(expression.getRoot(), functionNames, model);
- for (String functionName : functionNames) {
- Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec);
- if (requiredType.isPresent()) {
- Reference ref = Reference.fromIdentifier(functionName);
- if (typeContext.getType(ref).equals(TensorType.empty)) {
- typeContext.setType(ref, requiredType.get());
- }
- }
- }
- typeContext.forgetResolvedTypes();
-
- TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
-
- // Check generated functions for inputs to reduce
- for (String functionName : functionNames) {
- if ( ! model.functions().containsKey(functionName)) continue;
-
- RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName);
- if (rankingExpressionFunction == null) {
- throw new IllegalArgumentException("Model refers to generated function '" + functionName +
- "but this function is not present in " + profile);
- }
- RankingExpression functionExpression = rankingExpressionFunction.function().getBody();
- functionExpression.setRoot(reduceBatchDimensionsAtInput(functionExpression.getRoot(), model, typeContext));
- }
-
- // Check expression for inputs to reduce
- ExpressionNode root = expression.getRoot();
- root = reduceBatchDimensionsAtInput(root, model, typeContext);
- TensorType typeAfterReducing = root.type(typeContext);
- root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
- expression.setRoot(root);
- }
-
- private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model,
- MapEvaluationTypeContext typeContext) {
- if (node instanceof TensorFunctionNode) {
- TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
- if (tensorFunction instanceof Rename) {
- List<ExpressionNode> children = ((TensorFunctionNode)node).children();
- if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
- return reduceBatchDimensionExpression(tensorFunction, typeContext);
- }
- }
- // Modify any renames in expression to disregard batch dimension
- else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) {
- TensorFunction<Reference> childFunction = (((TensorFunctionNode) children.get(0)).function());
- TensorType childType = childFunction.type(typeContext);
- Rename rename = (Rename) tensorFunction;
- List<String> from = new ArrayList<>();
- List<String> to = new ArrayList<>();
- for (TensorType.Dimension dimension : childType.dimensions()) {
- int i = rename.fromDimensions().indexOf(dimension.name());
- if (i < 0) {
- throw new IllegalArgumentException("Rename does not contain dimension '" +
- dimension + "' in child expression type: " + childType);
- }
- from.add((String)rename.fromDimensions().get(i));
- to.add((String)rename.toDimensions().get(i));
- }
- return new TensorFunctionNode(new Rename<>(childFunction, from, to));
- }
- }
- }
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) node;
- if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
- return reduceBatchDimensionExpression(TensorFunctionNode.wrap(node), typeContext);
- }
- }
- if (node instanceof CompositeNode) {
- List<ExpressionNode> children = ((CompositeNode)node).children();
- List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
- for (ExpressionNode child : children) {
- transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
- }
- return ((CompositeNode)node).setChildren(transformedChildren);
- }
- return node;
- }
-
- private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, MapEvaluationTypeContext context) {
- TensorFunction result = function;
- TensorType type = function.type(context);
- if (type.dimensions().size() > 1) {
- List<String> reduceDimensions = new ArrayList<>();
- for (TensorType.Dimension dimension : type.dimensions()) {
- if (dimension.size().orElse(-1L) == 1) {
- reduceDimensions.add(dimension.name());
- }
- }
- if (reduceDimensions.size() > 0) {
- result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
- context.forgetResolvedTypes(); // We changed types
- }
- }
- return new TensorFunctionNode(result);
+// MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles);
+//
+// // Add any missing inputs for type resolution
+// Set<String> functionNames = new HashSet<>();
+// addFunctionNamesIn(expression.getRoot(), functionNames, model);
+// for (String functionName : functionNames) {
+// Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec);
+// if (requiredType.isPresent()) {
+// Reference ref = Reference.fromIdentifier(functionName);
+// if (typeContext.getType(ref).equals(TensorType.empty)) {
+// typeContext.setType(ref, requiredType.get());
+// }
+// }
+// }
+// typeContext.forgetResolvedTypes();
+//
+// TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
+//
+// // Check generated functions for inputs to reduce
+// for (String functionName : functionNames) {
+// if ( ! model.functions().containsKey(functionName)) continue;
+//
+// RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName);
+// if (rankingExpressionFunction == null) {
+// throw new IllegalArgumentException("Model refers to generated function '" + functionName +
+// "but this function is not present in " + profile);
+// }
+// RankingExpression functionExpression = rankingExpressionFunction.function().getBody();
+// functionExpression.setRoot(reduceBatchDimensionsAtInput(functionExpression.getRoot(), model, typeContext));
+// }
+//
+// // Check expression for inputs to reduce
+// ExpressionNode root = expression.getRoot();
+// root = reduceBatchDimensionsAtInput(root, model, typeContext);
+// TensorType typeAfterReducing = root.type(typeContext);
+// root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
+// expression.setRoot(root);
}
- /**
- * If batch dimensions have been reduced away above, bring them back here
- * for any following computation of the tensor.
- */
- // TODO: determine when this is not necessary!
- private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
- if (after.equals(before)) return node;
-
- TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType());
- for (TensorType.Dimension dimension : before.dimensions()) {
- if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
- typeBuilder.indexed(dimension.name(), 1);
- }
- }
- TensorType expandDimensionsType = typeBuilder.build();
- if (expandDimensionsType.dimensions().size() > 0) {
- ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
- Generate generatedFunction = new Generate(expandDimensionsType,
- new GeneratorLambdaFunctionNode(expandDimensionsType,
- generatedExpression)
- .asLongListToDoubleOperator());
- Join expand = new Join(TensorFunctionNode.wrap(node), generatedFunction, ScalarFunctions.multiply());
- return new TensorFunctionNode(expand);
- }
- return node;
- }
+// private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model,
+// MapEvaluationTypeContext typeContext) {
+// if (node instanceof TensorFunctionNode) {
+// TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
+// if (tensorFunction instanceof Rename) {
+// List<ExpressionNode> children = ((TensorFunctionNode)node).children();
+// if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
+// ReferenceNode referenceNode = (ReferenceNode) children.get(0);
+// if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
+// return reduceBatchDimensionExpression(tensorFunction, typeContext);
+// }
+// }
+// // Modify any renames in expression to disregard batch dimension
+// else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) {
+// TensorFunction<Reference> childFunction = (((TensorFunctionNode) children.get(0)).function());
+// TensorType childType = childFunction.type(typeContext);
+// Rename rename = (Rename) tensorFunction;
+// List<String> from = new ArrayList<>();
+// List<String> to = new ArrayList<>();
+// for (TensorType.Dimension dimension : childType.dimensions()) {
+// int i = rename.fromDimensions().indexOf(dimension.name());
+// if (i < 0) {
+// throw new IllegalArgumentException("Rename does not contain dimension '" +
+// dimension + "' in child expression type: " + childType);
+// }
+// from.add((String)rename.fromDimensions().get(i));
+// to.add((String)rename.toDimensions().get(i));
+// }
+// return new TensorFunctionNode(new Rename<>(childFunction, from, to));
+// }
+// }
+// }
+// if (node instanceof ReferenceNode) {
+// ReferenceNode referenceNode = (ReferenceNode) node;
+// if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
+// return reduceBatchDimensionExpression(TensorFunctionNode.wrap(node), typeContext);
+// }
+// }
+// if (node instanceof CompositeNode) {
+// List<ExpressionNode> children = ((CompositeNode)node).children();
+// List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
+// for (ExpressionNode child : children) {
+// transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
+// }
+// return ((CompositeNode)node).setChildren(transformedChildren);
+// }
+// return node;
+// }
+//
+// private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, MapEvaluationTypeContext context) {
+// TensorFunction result = function;
+// TensorType type = function.type(context);
+// if (type.dimensions().size() > 1) {
+// List<String> reduceDimensions = new ArrayList<>();
+// for (TensorType.Dimension dimension : type.dimensions()) {
+// if (dimension.size().orElse(-1L) == 1) {
+// reduceDimensions.add(dimension.name());
+// }
+// }
+// if (reduceDimensions.size() > 0) {
+// result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
+// context.forgetResolvedTypes(); // We changed types
+// }
+// }
+// return new TensorFunctionNode(result);
+// }
+//
+// /**
+// * If batch dimensions have been reduced away above, bring them back here
+// * for any following computation of the tensor.
+// */
+// // TODO: determine when this is not necessary!
+// private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
+// if (after.equals(before)) return node;
+//
+// TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType());
+// for (TensorType.Dimension dimension : before.dimensions()) {
+// if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
+// typeBuilder.indexed(dimension.name(), 1);
+// }
+// }
+// TensorType expandDimensionsType = typeBuilder.build();
+// if (expandDimensionsType.dimensions().size() > 0) {
+// ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
+// Generate generatedFunction = new Generate(expandDimensionsType,
+// new GeneratorLambdaFunctionNode(expandDimensionsType,
+// generatedExpression)
+// .asLongListToDoubleOperator());
+// Join expand = new Join(TensorFunctionNode.wrap(node), generatedFunction, ScalarFunctions.multiply());
+// return new TensorFunctionNode(expand);
+// }
+// return node;
+// }
/**
* If a constant c is overridden by a function, we need to replace instances of "constant(c)" by "c" in expressions.
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
index d84d967a184..8bc9040577b 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
@@ -2,8 +2,11 @@
package com.yahoo.searchdefinition;
import com.yahoo.collections.Pair;
+import com.yahoo.config.model.application.provider.BaseDeployLogger;
+import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
+import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.yolean.Exceptions;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
@@ -82,6 +85,18 @@ public class RankingExpressionConstantsTestCase extends SchemaTestCase {
new AttributeFields(s)).configProperties();
assertEquals("(rankingExpression(foo).rankingScript,14.0)", rankProperties.get(0).toString());
assertEquals("(rankingExpression(firstphase).rankingScript,16.6)", rankProperties.get(2).toString());
+
+ try {
+ DerivedConfiguration config = new DerivedConfiguration(s,
+ new BaseDeployLogger(),
+ new TestProperties(),
+ rankProfileRegistry,
+ queryProfileRegistry,
+ new ImportedMlModels());
+ config.export("/Users/lesters/temp/bert/idea/");
+ } catch (Exception e) {
+ throw new IllegalArgumentException(e);
+ }
}
@Test
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index 0cd6674751e..d5638da224c 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -41,6 +41,14 @@ class RankProfileSearchFixture {
private Search search;
private Map<String, RankProfile> compiledRankProfiles = new HashMap<>();
+ // TEMP
+ public RankProfileRegistry getRankProfileRegistry() {
+ return rankProfileRegistry;
+ }
+ public QueryProfileRegistry getQueryProfileRegistry() {
+ return queryProfileRegistry;
+ }
+
RankProfileSearchFixture(String rankProfiles) throws ParseException {
this(MockApplicationPackage.createEmpty(), new QueryProfileRegistry(), rankProfiles);
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java
new file mode 100644
index 00000000000..2c0620a0c52
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java
@@ -0,0 +1,196 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.processing;
+
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
+import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
+import ai.vespa.rankingexpression.importer.lightgbm.LightGBMImporter;
+import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
+import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
+import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
+import com.google.common.collect.ImmutableList;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.application.provider.BaseDeployLogger;
+import com.yahoo.config.model.deploy.TestProperties;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchdefinition.SearchBuilder;
+import com.yahoo.searchdefinition.derived.DerivedConfiguration;
+import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage;
+import com.yahoo.vespa.model.VespaModel;
+import com.yahoo.vespa.model.ml.ImportedModelTester;
+import com.yahoo.yolean.Exceptions;
+import org.junit.After;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Optional;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+
+public class RankingExpressionWithBertTestCase {
+
+ private final Path applicationDir = Path.fromString("src/test/integration/bert/");
+
+ /** The model name */
+ private final static String name = "bertsquad8";
+
+ private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_Variable), f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))";
+
+ @After
+ public void removeGeneratedModelFiles() {
+ IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ }
+
+
+ @Ignore
+ @Test
+ public void testGlobalBertModel() throws IOException {
+ ImportedModelTester tester = new ImportedModelTester(name, applicationDir);
+ VespaModel model = tester.createVespaModel();
+// tester.assertLargeConstant(name + "_Variable_1", model, Optional.of(10L));
+// tester.assertLargeConstant(name + "_Variable", model, Optional.of(7840L));
+
+ // At this point the expression is stored - copy application to another location which do not have a models dir
+ Path storedAppDir = applicationDir.append("copy");
+ try {
+ storedAppDir.toFile().mkdirs();
+ IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString());
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir);
+ VespaModel storedModel = storedTester.createVespaModel();
+// tester.assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L));
+// tester.assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L));
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedAppDir.toFile());
+ }
+ }
+
+ @Ignore
+ @Test
+ public void testBertRankProfile() throws Exception {
+ StoringApplicationPackage application = new StoringApplicationPackage((applicationDir));
+
+ ImmutableList<MlModelImporter> importers = ImmutableList.of(new TensorFlowImporter(),
+ new OnnxImporter(),
+ new LightGBMImporter(),
+ new XGBoostImporter());
+
+ String rankProfiles = " rank-profile my_profile {\n" +
+ " first-phase {\n" +
+ " expression: onnx('bertsquad8.onnx', 'default', 'unstack')" +
+ " }\n" +
+ " }";
+
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ QueryProfileRegistry queryProfileRegistry = application.getQueryProfiles();
+
+ SearchBuilder builder = new SearchBuilder(application, rankProfileRegistry, queryProfileRegistry);
+ String sdContent = "search test {\n" +
+ " document test {\n" +
+ " field unique_ids type tensor(d0[1]) {\n" +
+ " indexing: summary | attribute\n" +
+ " }\n" +
+ " field input_ids type tensor(d0[1],d1[256]) {\n" +
+ " indexing: summary | attribute\n" +
+ " }\n" +
+ " field input_mask type tensor(d0[1],d1[256]) {\n" +
+ " indexing: summary | attribute\n" +
+ " }\n" +
+ " field segment_ids type tensor(d0[1],d1[256]) {\n" +
+ " indexing: summary | attribute\n" +
+ " }" +
+ " }\n" +
+ " rank-profile my_profile inherits default {\n" +
+ " function inline unique_ids_raw_output___9() {\n" +
+ " expression: attribute(unique_ids)\n" +
+ " }\n" +
+ " function inline input_ids() {\n" +
+ " expression: attribute(input_ids)\n" +
+ " }\n" +
+ " function inline input_mask() {\n" +
+ " expression: attribute(input_mask)\n" +
+ " }\n" +
+ " function inline segment_ids() {\n" +
+ " expression: attribute(segment_ids)\n" +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: onnx(\"bertsquad8.onnx\", \"default\", \"unstack\") \n" +
+ " }\n" +
+ " }" +
+ "}";
+ builder.importString(sdContent);
+ builder.build();
+ Search search = builder.getSearch();
+
+ RankProfile compiled = rankProfileRegistry.get(search, "my_profile")
+ .compile(queryProfileRegistry,
+ new ImportedMlModels(applicationDir.toFile(), importers));
+
+ DerivedConfiguration config = new DerivedConfiguration(search,
+ new BaseDeployLogger(),
+ new TestProperties(),
+ rankProfileRegistry,
+ queryProfileRegistry,
+ new ImportedMlModels());
+
+ config.export("/Users/lesters/temp/bert/idea/");
+
+// fixture.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ System.out.println("Joda");
+ }
+
+ private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression,
+ String constant, String field) {
+ return fixtureWith(placeholderExpression, firstPhaseExpression, constant, field, "Placeholder",
+ new StoringApplicationPackage(applicationDir));
+ }
+
+ private RankProfileSearchFixture uncompiledFixtureWith(String rankProfile, StoringApplicationPackage application) {
+ try {
+ return new RankProfileSearchFixture(application, application.getQueryProfiles(),
+ rankProfile, null, null);
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+ private RankProfileSearchFixture fixtureWith(String functionExpression,
+ String firstPhaseExpression,
+ String constant,
+ String field,
+ String functionName,
+ StoringApplicationPackage application) {
+ try {
+ RankProfileSearchFixture fixture = new RankProfileSearchFixture(
+ application,
+ application.getQueryProfiles(),
+ " rank-profile my_profile {\n" +
+ " function " + functionName + "() {\n" +
+ " expression: " + functionExpression +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: " + firstPhaseExpression +
+ " }\n" +
+ " }",
+ constant,
+ field);
+ fixture.compileRankProfile("my_profile", applicationDir.append("models"));
+ return fixture;
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index cba931e81f0..c444bf8d7dc 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -1,8 +1,11 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.processing;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.application.provider.BaseDeployLogger;
+import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
@@ -10,6 +13,7 @@ import com.yahoo.io.reader.NamedReader;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.searchdefinition.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
@@ -385,6 +389,15 @@ public class RankingExpressionWithTensorFlowTestCase {
finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
}
+
+ DerivedConfiguration config = new DerivedConfiguration(search.search(),
+ new BaseDeployLogger(),
+ new TestProperties(),
+ search.getRankProfileRegistry(),
+ search.getQueryProfileRegistry(),
+ new ImportedMlModels());
+ config.export("/Users/lesters/temp/bert/idea/");
+
}
private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) {