diff options
Diffstat (limited to 'config-model/src')
10 files changed, 416 insertions, 138 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 6de7c985326..65443117c0a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -47,12 +47,15 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement private final SortedSet<Reference> queryFeaturesNotDeclared; private boolean tensorsAreUsed; + private final MapEvaluationTypeContext parent; + MapEvaluationTypeContext(Collection<ExpressionFunction> functions, Map<Reference, TensorType> featureTypes) { super(functions); this.featureTypes.putAll(featureTypes); this.currentResolutionCallStack = new ArrayDeque<>(); this.queryFeaturesNotDeclared = new TreeSet<>(); tensorsAreUsed = false; + parent = null; } private MapEvaluationTypeContext(Map<String, ExpressionFunction> functions, @@ -60,12 +63,14 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement Map<Reference, TensorType> featureTypes, Deque<Reference> currentResolutionCallStack, SortedSet<Reference> queryFeaturesNotDeclared, - boolean tensorsAreUsed) { + boolean tensorsAreUsed, + MapEvaluationTypeContext parent) { super(functions, bindings); this.featureTypes.putAll(featureTypes); this.currentResolutionCallStack = currentResolutionCallStack; this.queryFeaturesNotDeclared = queryFeaturesNotDeclared; this.tensorsAreUsed = tensorsAreUsed; + this.parent = parent; } public void setType(Reference reference, TensorType type) { @@ -82,16 +87,45 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement resolvedTypes.clear(); } + private TensorType resolvedType(Reference reference, int depth) { +// System.out.println(indent + "In resolvedtype - resolving type for " + reference.toString()); + TensorType resolvedType = resolvedTypes.get(reference); + if (resolvedType != null) { +// System.out.println("Found previously resolved type for " + reference + " at depth " + depth + ": (" + resolvedType + ")"); + return resolvedType; + } + if (parent != null) return parent.resolvedType(reference, depth + 1); // what about argument types? Careful with this! +// System.out.println("Could NOT find type for " + reference + " - down to depth " + depth); + return null; + } + + private MapEvaluationTypeContext findOriginalParent() { + if (parent != null) + return parent.findOriginalParent(); + return this; + } + @Override public TensorType getType(Reference reference) { // computeIfAbsent without concurrent modification due to resolve adding more resolved entries: - TensorType resolvedType = resolvedTypes.get(reference); + // TensorType resolvedType = resolvedTypes.get(reference); + TensorType resolvedType = resolvedType(reference, 0); if (resolvedType != null) return resolvedType; resolvedType = resolveType(reference); if (resolvedType == null) return defaultTypeOf(reference); // Don't store fallback to default as we may know more later - resolvedTypes.put(reference, resolvedType); + +// System.out.println("Resolved type of " + reference + ": (" + resolvedType + ")"); + + // Må inn her med et konsept av global eller lokal. + // For globale - legg i lavest parent! + MapEvaluationTypeContext originalParent = findOriginalParent(); + if (originalParent == null) { + originalParent = this; + } + originalParent.resolvedTypes.put(reference, resolvedType); + if (resolvedType.rank() > 0) tensorsAreUsed = true; return resolvedType; @@ -103,6 +137,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) + " -> " + reference); + // Bound to a function argument, and not to a same-named identifier (which would lead to a loop)? Optional<String> binding = boundIdentifier(reference); if (binding.isPresent() && ! binding.get().equals(reference.toString())) { @@ -254,7 +289,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement featureTypes, currentResolutionCallStack, queryFeaturesNotDeclared, - tensorsAreUsed); + tensorsAreUsed, this); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 23eb814de81..ea126123a25 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -680,11 +680,12 @@ public class RankProfile implements Cloneable { Map<String, RankingExpressionFunction> inlineFunctions = compileFunctions(this::getInlineFunctions, queryProfiles, featureTypes, importedModels, Collections.emptyMap(), expressionTransforms); + firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); + secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); + // Function compiling second pass: compile all functions and insert previously compiled inline functions functions = compileFunctions(this::getFunctions, queryProfiles, featureTypes, importedModels, inlineFunctions, expressionTransforms); - firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); - secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); } private void checkNameCollisions(Map<String, RankingExpressionFunction> functions, Map<String, Value> constants) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 1a22b98fd9f..3578cc786ed 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -222,7 +222,9 @@ public class RawRankProfile implements RankProfilesConfig.Producer { List<ExpressionFunction> functionExpressions) { SerializationContext context = new SerializationContext(functionExpressions); for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) { + System.out.println("Deriving: " + e.getKey()); String expressionString = e.getValue().function().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString(); + System.out.println("-> Done deriving: " + e.getKey() + ": " + expressionString); context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString); for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet()) diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java index 89b8889b4ae..8d9098a10f1 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java @@ -79,6 +79,7 @@ public class RankingExpressionTypeResolver extends Processor { } context.forgetResolvedTypes(); + System.out.println("Resolving type for " + function.getKey()); TensorType type = resolveType(expressionFunction.getBody(), "function '" + function.getKey() + "'", context); function.getValue().setReturnType(type); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java index b6f7ab4ff62..d3e029b8de5 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java @@ -102,7 +102,8 @@ public class RankSetupValidator extends Validator { } private void deleteTempDir(File dir) { - IOUtils.recursiveDeleteDir(dir); + System.out.println("Here we were supposed to delete tmpdir: " + dir.getAbsolutePath()); +// IOUtils.recursiveDeleteDir(dir); } private void writeConfigs(String dir, AbstractConfigProducer<?> producer) throws IOException { @@ -133,7 +134,13 @@ public class RankSetupValidator extends Validator { } private static void writeConfig(String dir, String configName, ConfigInstance config) throws IOException { - IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(ConfigInstance.serialize(config)), false); + + String output = StringUtilities.implodeMultiline(ConfigInstance.serialize(config)); + System.out.println("Writing config for in " + dir + " for configName '" + configName + "' "); + System.out.println(output); + IOUtils.writeFile(dir + configName, output, false); + +// IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(ConfigInstance.serialize(config)), false); } private boolean execValidate(String configId, SearchCluster sc, String sdName, DeployLogger deployLogger) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index c3d6f457ce8..9f649bc820a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -386,138 +386,138 @@ public class ConvertedModel { */ private static void reduceBatchDimensions(RankingExpression expression, ImportedMlModel model, RankProfile profile, QueryProfileRegistry queryProfiles) { - MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles); - - // Add any missing inputs for type resolution - Set<String> functionNames = new HashSet<>(); - addFunctionNamesIn(expression.getRoot(), functionNames, model); - for (String functionName : functionNames) { - Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec); - if (requiredType.isPresent()) { - Reference ref = Reference.fromIdentifier(functionName); - if (typeContext.getType(ref).equals(TensorType.empty)) { - typeContext.setType(ref, requiredType.get()); - } - } - } - typeContext.forgetResolvedTypes(); - - TensorType typeBeforeReducing = expression.getRoot().type(typeContext); - - // Check generated functions for inputs to reduce - for (String functionName : functionNames) { - if ( ! model.functions().containsKey(functionName)) continue; - - RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); - if (rankingExpressionFunction == null) { - throw new IllegalArgumentException("Model refers to generated function '" + functionName + - "but this function is not present in " + profile); - } - RankingExpression functionExpression = rankingExpressionFunction.function().getBody(); - functionExpression.setRoot(reduceBatchDimensionsAtInput(functionExpression.getRoot(), model, typeContext)); - } - - // Check expression for inputs to reduce - ExpressionNode root = expression.getRoot(); - root = reduceBatchDimensionsAtInput(root, model, typeContext); - TensorType typeAfterReducing = root.type(typeContext); - root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); - expression.setRoot(root); - } - - private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model, - MapEvaluationTypeContext typeContext) { - if (node instanceof TensorFunctionNode) { - TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); - if (tensorFunction instanceof Rename) { - List<ExpressionNode> children = ((TensorFunctionNode)node).children(); - if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { - return reduceBatchDimensionExpression(tensorFunction, typeContext); - } - } - // Modify any renames in expression to disregard batch dimension - else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) { - TensorFunction<Reference> childFunction = (((TensorFunctionNode) children.get(0)).function()); - TensorType childType = childFunction.type(typeContext); - Rename rename = (Rename) tensorFunction; - List<String> from = new ArrayList<>(); - List<String> to = new ArrayList<>(); - for (TensorType.Dimension dimension : childType.dimensions()) { - int i = rename.fromDimensions().indexOf(dimension.name()); - if (i < 0) { - throw new IllegalArgumentException("Rename does not contain dimension '" + - dimension + "' in child expression type: " + childType); - } - from.add((String)rename.fromDimensions().get(i)); - to.add((String)rename.toDimensions().get(i)); - } - return new TensorFunctionNode(new Rename<>(childFunction, from, to)); - } - } - } - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) node; - if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { - return reduceBatchDimensionExpression(TensorFunctionNode.wrap(node), typeContext); - } - } - if (node instanceof CompositeNode) { - List<ExpressionNode> children = ((CompositeNode)node).children(); - List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); - for (ExpressionNode child : children) { - transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); - } - return ((CompositeNode)node).setChildren(transformedChildren); - } - return node; - } - - private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, MapEvaluationTypeContext context) { - TensorFunction result = function; - TensorType type = function.type(context); - if (type.dimensions().size() > 1) { - List<String> reduceDimensions = new ArrayList<>(); - for (TensorType.Dimension dimension : type.dimensions()) { - if (dimension.size().orElse(-1L) == 1) { - reduceDimensions.add(dimension.name()); - } - } - if (reduceDimensions.size() > 0) { - result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); - context.forgetResolvedTypes(); // We changed types - } - } - return new TensorFunctionNode(result); +// MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles); +// +// // Add any missing inputs for type resolution +// Set<String> functionNames = new HashSet<>(); +// addFunctionNamesIn(expression.getRoot(), functionNames, model); +// for (String functionName : functionNames) { +// Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec); +// if (requiredType.isPresent()) { +// Reference ref = Reference.fromIdentifier(functionName); +// if (typeContext.getType(ref).equals(TensorType.empty)) { +// typeContext.setType(ref, requiredType.get()); +// } +// } +// } +// typeContext.forgetResolvedTypes(); +// +// TensorType typeBeforeReducing = expression.getRoot().type(typeContext); +// +// // Check generated functions for inputs to reduce +// for (String functionName : functionNames) { +// if ( ! model.functions().containsKey(functionName)) continue; +// +// RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); +// if (rankingExpressionFunction == null) { +// throw new IllegalArgumentException("Model refers to generated function '" + functionName + +// "but this function is not present in " + profile); +// } +// RankingExpression functionExpression = rankingExpressionFunction.function().getBody(); +// functionExpression.setRoot(reduceBatchDimensionsAtInput(functionExpression.getRoot(), model, typeContext)); +// } +// +// // Check expression for inputs to reduce +// ExpressionNode root = expression.getRoot(); +// root = reduceBatchDimensionsAtInput(root, model, typeContext); +// TensorType typeAfterReducing = root.type(typeContext); +// root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); +// expression.setRoot(root); } - /** - * If batch dimensions have been reduced away above, bring them back here - * for any following computation of the tensor. - */ - // TODO: determine when this is not necessary! - private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { - if (after.equals(before)) return node; - - TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType()); - for (TensorType.Dimension dimension : before.dimensions()) { - if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { - typeBuilder.indexed(dimension.name(), 1); - } - } - TensorType expandDimensionsType = typeBuilder.build(); - if (expandDimensionsType.dimensions().size() > 0) { - ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); - Generate generatedFunction = new Generate(expandDimensionsType, - new GeneratorLambdaFunctionNode(expandDimensionsType, - generatedExpression) - .asLongListToDoubleOperator()); - Join expand = new Join(TensorFunctionNode.wrap(node), generatedFunction, ScalarFunctions.multiply()); - return new TensorFunctionNode(expand); - } - return node; - } +// private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model, +// MapEvaluationTypeContext typeContext) { +// if (node instanceof TensorFunctionNode) { +// TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); +// if (tensorFunction instanceof Rename) { +// List<ExpressionNode> children = ((TensorFunctionNode)node).children(); +// if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { +// ReferenceNode referenceNode = (ReferenceNode) children.get(0); +// if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { +// return reduceBatchDimensionExpression(tensorFunction, typeContext); +// } +// } +// // Modify any renames in expression to disregard batch dimension +// else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) { +// TensorFunction<Reference> childFunction = (((TensorFunctionNode) children.get(0)).function()); +// TensorType childType = childFunction.type(typeContext); +// Rename rename = (Rename) tensorFunction; +// List<String> from = new ArrayList<>(); +// List<String> to = new ArrayList<>(); +// for (TensorType.Dimension dimension : childType.dimensions()) { +// int i = rename.fromDimensions().indexOf(dimension.name()); +// if (i < 0) { +// throw new IllegalArgumentException("Rename does not contain dimension '" + +// dimension + "' in child expression type: " + childType); +// } +// from.add((String)rename.fromDimensions().get(i)); +// to.add((String)rename.toDimensions().get(i)); +// } +// return new TensorFunctionNode(new Rename<>(childFunction, from, to)); +// } +// } +// } +// if (node instanceof ReferenceNode) { +// ReferenceNode referenceNode = (ReferenceNode) node; +// if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { +// return reduceBatchDimensionExpression(TensorFunctionNode.wrap(node), typeContext); +// } +// } +// if (node instanceof CompositeNode) { +// List<ExpressionNode> children = ((CompositeNode)node).children(); +// List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); +// for (ExpressionNode child : children) { +// transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); +// } +// return ((CompositeNode)node).setChildren(transformedChildren); +// } +// return node; +// } +// +// private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, MapEvaluationTypeContext context) { +// TensorFunction result = function; +// TensorType type = function.type(context); +// if (type.dimensions().size() > 1) { +// List<String> reduceDimensions = new ArrayList<>(); +// for (TensorType.Dimension dimension : type.dimensions()) { +// if (dimension.size().orElse(-1L) == 1) { +// reduceDimensions.add(dimension.name()); +// } +// } +// if (reduceDimensions.size() > 0) { +// result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); +// context.forgetResolvedTypes(); // We changed types +// } +// } +// return new TensorFunctionNode(result); +// } +// +// /** +// * If batch dimensions have been reduced away above, bring them back here +// * for any following computation of the tensor. +// */ +// // TODO: determine when this is not necessary! +// private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { +// if (after.equals(before)) return node; +// +// TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType()); +// for (TensorType.Dimension dimension : before.dimensions()) { +// if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { +// typeBuilder.indexed(dimension.name(), 1); +// } +// } +// TensorType expandDimensionsType = typeBuilder.build(); +// if (expandDimensionsType.dimensions().size() > 0) { +// ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); +// Generate generatedFunction = new Generate(expandDimensionsType, +// new GeneratorLambdaFunctionNode(expandDimensionsType, +// generatedExpression) +// .asLongListToDoubleOperator()); +// Join expand = new Join(TensorFunctionNode.wrap(node), generatedFunction, ScalarFunctions.multiply()); +// return new TensorFunctionNode(expand); +// } +// return node; +// } /** * If a constant c is overridden by a function, we need to replace instances of "constant(c)" by "c" in expressions. diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java index d84d967a184..8bc9040577b 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java @@ -2,8 +2,11 @@ package com.yahoo.searchdefinition; import com.yahoo.collections.Pair; +import com.yahoo.config.model.application.provider.BaseDeployLogger; +import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.search.query.profile.QueryProfileRegistry; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; +import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.yolean.Exceptions; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; @@ -82,6 +85,18 @@ public class RankingExpressionConstantsTestCase extends SchemaTestCase { new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(foo).rankingScript,14.0)", rankProperties.get(0).toString()); assertEquals("(rankingExpression(firstphase).rankingScript,16.6)", rankProperties.get(2).toString()); + + try { + DerivedConfiguration config = new DerivedConfiguration(s, + new BaseDeployLogger(), + new TestProperties(), + rankProfileRegistry, + queryProfileRegistry, + new ImportedMlModels()); + config.export("/Users/lesters/temp/bert/idea/"); + } catch (Exception e) { + throw new IllegalArgumentException(e); + } } @Test diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 0cd6674751e..d5638da224c 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -41,6 +41,14 @@ class RankProfileSearchFixture { private Search search; private Map<String, RankProfile> compiledRankProfiles = new HashMap<>(); + // TEMP + public RankProfileRegistry getRankProfileRegistry() { + return rankProfileRegistry; + } + public QueryProfileRegistry getQueryProfileRegistry() { + return queryProfileRegistry; + } + RankProfileSearchFixture(String rankProfiles) throws ParseException { this(MockApplicationPackage.createEmpty(), new QueryProfileRegistry(), rankProfiles); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java new file mode 100644 index 00000000000..2c0620a0c52 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java @@ -0,0 +1,196 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.processing; + +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; +import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter; +import ai.vespa.rankingexpression.importer.lightgbm.LightGBMImporter; +import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; +import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; +import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; +import com.google.common.collect.ImmutableList; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.application.provider.BaseDeployLogger; +import com.yahoo.config.model.deploy.TestProperties; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchdefinition.SearchBuilder; +import com.yahoo.searchdefinition.derived.DerivedConfiguration; +import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.ml.ImportedModelTester; +import com.yahoo.yolean.Exceptions; +import org.junit.After; +import org.junit.Ignore; +import org.junit.Test; + +import java.io.IOException; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; + +public class RankingExpressionWithBertTestCase { + + private final Path applicationDir = Path.fromString("src/test/integration/bert/"); + + /** The model name */ + private final static String name = "bertsquad8"; + + private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_Variable), f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))"; + + @After + public void removeGeneratedModelFiles() { + IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + } + + + @Ignore + @Test + public void testGlobalBertModel() throws IOException { + ImportedModelTester tester = new ImportedModelTester(name, applicationDir); + VespaModel model = tester.createVespaModel(); +// tester.assertLargeConstant(name + "_Variable_1", model, Optional.of(10L)); +// tester.assertLargeConstant(name + "_Variable", model, Optional.of(7840L)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedAppDir = applicationDir.append("copy"); + try { + storedAppDir.toFile().mkdirs(); + IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString()); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir); + VespaModel storedModel = storedTester.createVespaModel(); +// tester.assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L)); +// tester.assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L)); + } + finally { + IOUtils.recursiveDeleteDir(storedAppDir.toFile()); + } + } + + @Ignore + @Test + public void testBertRankProfile() throws Exception { + StoringApplicationPackage application = new StoringApplicationPackage((applicationDir)); + + ImmutableList<MlModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), + new OnnxImporter(), + new LightGBMImporter(), + new XGBoostImporter()); + + String rankProfiles = " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: onnx('bertsquad8.onnx', 'default', 'unstack')" + + " }\n" + + " }"; + + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + QueryProfileRegistry queryProfileRegistry = application.getQueryProfiles(); + + SearchBuilder builder = new SearchBuilder(application, rankProfileRegistry, queryProfileRegistry); + String sdContent = "search test {\n" + + " document test {\n" + + " field unique_ids type tensor(d0[1]) {\n" + + " indexing: summary | attribute\n" + + " }\n" + + " field input_ids type tensor(d0[1],d1[256]) {\n" + + " indexing: summary | attribute\n" + + " }\n" + + " field input_mask type tensor(d0[1],d1[256]) {\n" + + " indexing: summary | attribute\n" + + " }\n" + + " field segment_ids type tensor(d0[1],d1[256]) {\n" + + " indexing: summary | attribute\n" + + " }" + + " }\n" + + " rank-profile my_profile inherits default {\n" + + " function inline unique_ids_raw_output___9() {\n" + + " expression: attribute(unique_ids)\n" + + " }\n" + + " function inline input_ids() {\n" + + " expression: attribute(input_ids)\n" + + " }\n" + + " function inline input_mask() {\n" + + " expression: attribute(input_mask)\n" + + " }\n" + + " function inline segment_ids() {\n" + + " expression: attribute(segment_ids)\n" + + " }\n" + + " first-phase {\n" + + " expression: onnx(\"bertsquad8.onnx\", \"default\", \"unstack\") \n" + + " }\n" + + " }" + + "}"; + builder.importString(sdContent); + builder.build(); + Search search = builder.getSearch(); + + RankProfile compiled = rankProfileRegistry.get(search, "my_profile") + .compile(queryProfileRegistry, + new ImportedMlModels(applicationDir.toFile(), importers)); + + DerivedConfiguration config = new DerivedConfiguration(search, + new BaseDeployLogger(), + new TestProperties(), + rankProfileRegistry, + queryProfileRegistry, + new ImportedMlModels()); + + config.export("/Users/lesters/temp/bert/idea/"); + +// fixture.assertFirstPhaseExpression(vespaExpression, "my_profile"); + System.out.println("Joda"); + } + + private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression, + String constant, String field) { + return fixtureWith(placeholderExpression, firstPhaseExpression, constant, field, "Placeholder", + new StoringApplicationPackage(applicationDir)); + } + + private RankProfileSearchFixture uncompiledFixtureWith(String rankProfile, StoringApplicationPackage application) { + try { + return new RankProfileSearchFixture(application, application.getQueryProfiles(), + rankProfile, null, null); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + + private RankProfileSearchFixture fixtureWith(String functionExpression, + String firstPhaseExpression, + String constant, + String field, + String functionName, + StoringApplicationPackage application) { + try { + RankProfileSearchFixture fixture = new RankProfileSearchFixture( + application, + application.getQueryProfiles(), + " rank-profile my_profile {\n" + + " function " + functionName + "() {\n" + + " expression: " + functionExpression + + " }\n" + + " first-phase {\n" + + " expression: " + firstPhaseExpression + + " }\n" + + " }", + constant, + field); + fixture.compileRankProfile("my_profile", applicationDir.append("models")); + return fixture; + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index cba931e81f0..c444bf8d7dc 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -1,8 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.application.provider.BaseDeployLogger; +import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; @@ -10,6 +13,7 @@ import com.yahoo.io.reader.NamedReader; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.Tensor; @@ -385,6 +389,15 @@ public class RankingExpressionWithTensorFlowTestCase { finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); } + + DerivedConfiguration config = new DerivedConfiguration(search.search(), + new BaseDeployLogger(), + new TestProperties(), + search.getRankProfileRegistry(), + search.getQueryProfileRegistry(), + new ImportedMlModels()); + config.export("/Users/lesters/temp/bert/idea/"); + } private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) { |