summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-09-25 15:49:22 -0700
committerJon Bratseth <bratseth@oath.com>2018-09-25 15:49:22 -0700
commit11884899e39c54abeb79bacbe723df0ff34ce869 (patch)
tree674025004f825c9cc12a075f992c0b2d1d45509e
parent0246064bbfb9657515f516e2fea12d593cd13016 (diff)
Revert "Merge pull request #7094 from vespa-engine/revert-7070-bratseth/rank-type-information-2"
This reverts commit 0246064bbfb9657515f516e2fea12d593cd13016, reversing changes made to f627463a8100090ec109d27c3aeb439a3395a34f.
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java17
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/Search.java14
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java7
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java10
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java35
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java28
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java (renamed from config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java)57
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/Service.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java31
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java116
-rw-r--r--config-model/src/test/derived/gemini2/gemini.sd6
-rw-r--r--config-model/src/test/derived/gemini2/rank-profiles.cfg4
-rw-r--r--config-model/src/test/derived/rankexpression/rank-profiles.cfg60
-rw-r--r--config-model/src/test/derived/tensor/rank-profiles.cfg6
-rw-r--r--config-model/src/test/derived/tensor/tensor.sd6
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java6
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java10
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolverTestCase.java (renamed from config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java)2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java17
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java4
-rw-r--r--container-search/src/main/java/com/yahoo/search/searchchain/AsyncExecution.java2
-rw-r--r--processing/src/main/java/com/yahoo/processing/execution/AsyncExecution.java29
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java48
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java79
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java1
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java13
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java8
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java16
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java8
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java26
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java14
35 files changed, 452 insertions, 241 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 16e494c2db1..937151c0d3a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -449,7 +449,7 @@ public class RankProfile implements Serializable, Cloneable {
addRankProperty(new RankProperty(name, parameter));
}
- public void addRankProperty(RankProperty rankProperty) {
+ private void addRankProperty(RankProperty rankProperty) {
// Just the usual multimap semantics here
List<RankProperty> properties = rankProperties.get(rankProperty.getName());
if (properties == null) {
@@ -541,15 +541,14 @@ public class RankProfile implements Serializable, Cloneable {
/** Returns an unmodifiable view of the functions in this */
public Map<String, RankingExpressionFunction> getFunctions() {
- if (functions.size() == 0 && getInherited() == null) return Collections.emptyMap();
- if (functions.size() == 0) return getInherited().getFunctions();
+ if (functions.isEmpty() && getInherited() == null) return Collections.emptyMap();
+ if (functions.isEmpty()) return getInherited().getFunctions();
if (getInherited() == null) return Collections.unmodifiableMap(functions);
// Neither is null
Map<String, RankingExpressionFunction> allFunctions = new LinkedHashMap<>(getInherited().getFunctions());
allFunctions.putAll(functions);
return Collections.unmodifiableMap(allFunctions);
-
}
public int getKeepRankCount() {
@@ -695,7 +694,7 @@ public class RankProfile implements Serializable, Cloneable {
for (Map.Entry<String, RankingExpressionFunction> entry : functions.entrySet()) {
RankingExpressionFunction rankingExpressionFunction = entry.getValue();
RankingExpression compiled = compile(rankingExpressionFunction.function().getBody(), queryProfiles, importedModels, getConstants(), inlineFunctions, expressionTransforms);
- compiledFunctions.put(entry.getKey(), rankingExpressionFunction.withBody(compiled));
+ compiledFunctions.put(entry.getKey(), rankingExpressionFunction.withExpression(compiled));
}
return compiledFunctions;
}
@@ -898,7 +897,7 @@ public class RankProfile implements Serializable, Cloneable {
/** A function in a rank profile */
public static class RankingExpressionFunction {
- private final ExpressionFunction function;
+ private ExpressionFunction function;
/** True if this should be inlined into calling expressions. Useful for very cheap functions. */
private final boolean inline;
@@ -908,13 +907,17 @@ public class RankProfile implements Serializable, Cloneable {
this.inline = inline;
}
+ public void setReturnType(TensorType type) {
+ this.function = function.withReturnType(type);
+ }
+
public ExpressionFunction function() { return function; }
public boolean inline() {
return inline && function.arguments().isEmpty(); // only inline no-arg functions;
}
- public RankingExpressionFunction withBody(RankingExpression expression) {
+ public RankingExpressionFunction withExpression(RankingExpression expression) {
return new RankingExpressionFunction(function.withBody(expression), inline);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
index f42d5de21e8..a988da9664e 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
@@ -59,9 +59,6 @@ public class Search implements Serializable, ImmutableSearch {
// Field sets
private FieldSets fieldSets = new FieldSets();
- // Whether or not this object has been processed.
- private boolean processed;
-
// The unique name of this search definition.
private String name;
@@ -585,17 +582,6 @@ public class Search implements Serializable, ImmutableSearch {
return false;
}
- public void process() {
- if (processed) {
- throw new IllegalStateException("Search '" + getName() + "' already processed.");
- }
- processed = true;
- }
-
- public boolean isProcessed() {
- return processed;
- }
-
/**
* The field set settings for this search
*
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java
index 3c2ebc058ac..151ad02a3fa 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java
@@ -188,10 +188,6 @@ public class SearchBuilder {
throw new IllegalArgumentException("Search has no name.");
}
String rawName = rawSearch.getName();
- if (rawSearch.isProcessed()) {
- throw new IllegalArgumentException("A search definition with a search section called '" + rawName +
- "' has already been processed.");
- }
for (Search search : searchList) {
if (rawName.equals(search.getName())) {
throw new IllegalArgumentException("A search definition with a search section called '" + rawName +
@@ -247,8 +243,7 @@ public class SearchBuilder {
DocumentModelBuilder builder = new DocumentModelBuilder(model);
for (Search search : new SearchOrderer().order(searchList)) {
- new FieldOperationApplierForSearch().process(search);
- // These two needed for a couple of old unit tests, ideally these are just read from app
+ new FieldOperationApplierForSearch().process(search); // TODO: Why is this not in the regular list?
process(search, deployLogger, new QueryProfiles(queryProfileRegistry), validate);
built.add(search);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
index 9a00ee5bbd0..7c2d9a3b0ad 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
@@ -74,21 +74,11 @@ public class DerivedConfiguration {
QueryProfileRegistry queryProfiles,
ImportedModels importedModels) {
Validator.ensureNotNull("Search definition", search);
- if ( ! search.isProcessed()) {
- throw new IllegalArgumentException("Search '" + search.getName() + "' not processed.");
- }
this.search = search;
if ( ! search.isDocumentsOnly()) {
streamingFields = new VsmFields(search);
streamingSummary = new VsmSummary(search);
}
- if (abstractSearchList != null) {
- for (Search abstractSearch : abstractSearchList) {
- if (!abstractSearch.isProcessed()) {
- throw new IllegalArgumentException("Search '" + search.getName() + "' not processed.");
- }
- }
- }
if ( ! search.isDocumentsOnly()) {
attributeFields = new AttributeFields(search);
summaries = new Summaries(search, deployLogger);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index c041d5c6a89..279b5334187 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -13,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
+import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import java.nio.charset.Charset;
@@ -23,6 +24,7 @@ import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.stream.Collectors;
/**
* A rank profile derived from a search definition, containing exactly the features available natively in the server
@@ -180,32 +182,39 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
private void derivePropertiesAndSummaryFeaturesFromFunctions(Map<String, RankProfile.RankingExpressionFunction> functions) {
if (functions.isEmpty()) return;
- Map<String, ExpressionFunction> expressionFunctions = new LinkedHashMap<>();
- for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : functions.entrySet()) {
- expressionFunctions.put(function.getKey(), function.getValue().function());
- }
+
+ List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList());
Map<String, String> functionProperties = new LinkedHashMap<>();
- functionProperties.putAll(deriveFunctionProperties(expressionFunctions));
+ functionProperties.putAll(deriveFunctionProperties(functions, functionExpressions));
+
if (firstPhaseRanking != null) {
- functionProperties.putAll(firstPhaseRanking.getRankProperties(new ArrayList<>(expressionFunctions.values())));
+ functionProperties.putAll(firstPhaseRanking.getRankProperties(functionExpressions));
}
if (secondPhaseRanking != null) {
- functionProperties.putAll(secondPhaseRanking.getRankProperties(new ArrayList<>(expressionFunctions.values())));
+ functionProperties.putAll(secondPhaseRanking.getRankProperties(functionExpressions));
}
for (Map.Entry<String, String> e : functionProperties.entrySet()) {
rankProperties.add(new RankProfile.RankProperty(e.getKey(), e.getValue()));
}
- SerializationContext context = new SerializationContext(expressionFunctions.values(), null, functionProperties);
+ SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties);
replaceFunctionSummaryFeatures(context);
}
- private Map<String, String> deriveFunctionProperties(Map<String, ExpressionFunction> functions) {
- SerializationContext context = new SerializationContext(functions);
- for (Map.Entry<String, ExpressionFunction> e : functions.entrySet()) {
- String expression = e.getValue().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString();
- context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expression);
+ private Map<String, String> deriveFunctionProperties(Map<String, RankProfile.RankingExpressionFunction> functions,
+ List<ExpressionFunction> functionExpressions) {
+ SerializationContext context = new SerializationContext(functionExpressions);
+ for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) {
+ String expressionString = e.getValue().function().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString();
+ context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString);
+
+ for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet())
+ context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue());
+ if (e.getValue().function().returnType().isPresent())
+ context.addFunctionTypeSerialization(e.getKey(), e.getValue().function().returnType().get());
+ else if (e.getValue().function().arguments().isEmpty())
+ throw new IllegalStateException("Type of function '" + e.getKey() + "' is not resolved");
}
return context.serializedFunctions();
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
index 8c8c32389e2..15d295736c1 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
@@ -7,10 +7,8 @@ import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.processing.multifieldresolver.RankProfileTypeSettingsProcessor;
import com.yahoo.vespa.model.container.search.QueryProfiles;
-import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
-import java.util.List;
/**
* Executor of processors. This defines the right order of processor execution.
@@ -75,12 +73,20 @@ public class Processing {
ReferenceFieldsProcessor::new,
FastAccessValidator::new,
ReservedFunctionNames::new,
- RankingExpressionTypeValidator::new,
+ RankingExpressionTypeResolver::new,
// These should be last:
IndexingValidation::new,
IndexingValues::new);
}
+ /** Processors of rank profiles only (those who tolerate and so something useful when the search field is null) */
+ private Collection<ProcessorFactory> rankProfileProcessors() {
+ return Arrays.asList(
+ RankProfileTypeSettingsProcessor::new,
+ ReservedFunctionNames::new,
+ RankingExpressionTypeResolver::new);
+ }
+
/**
* Runs all search processors on the given {@link Search} object. These will modify the search object, <b>possibly
* exchanging it with another</b>, as well as its document types.
@@ -93,12 +99,26 @@ public class Processing {
public void process(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry,
QueryProfiles queryProfiles, boolean validate, boolean documentsOnly) {
Collection<ProcessorFactory> factories = processors();
- search.process();
factories.stream()
.map(factory -> factory.create(search, deployLogger, rankProfileRegistry, queryProfiles))
.forEach(processor -> processor.process(validate, documentsOnly));
}
+ /**
+ * Runs rank profiles processors only.
+ *
+ * @param deployLogger The log to log messages and warnings for application deployment to
+ * @param rankProfileRegistry a {@link com.yahoo.searchdefinition.RankProfileRegistry}
+ * @param queryProfiles The query profiles contained in the application this search is part of.
+ */
+ public void processRankProfiles(DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry,
+ QueryProfiles queryProfiles, boolean validate, boolean documentsOnly) {
+ Collection<ProcessorFactory> factories = rankProfileProcessors();
+ factories.stream()
+ .map(factory -> factory.create(null, deployLogger, rankProfileRegistry, queryProfiles))
+ .forEach(processor -> processor.process(validate, documentsOnly));
+ }
+
@FunctionalInterface
public interface ProcessorFactory {
Processor create(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java
index 102d1910360..4c8b5910b78 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java
@@ -7,39 +7,44 @@ import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.Search;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.vespa.model.container.search.QueryProfiles;
+import java.util.Map;
+
/**
- * Validates the types of all ranking expressions under a search instance:
+ * Resolves and assigns types to all functions in a ranking expression, and
+ * validates the types of all ranking expressions under a search instance:
* Some operators constrain the types of inputs, and first-and second-phase expressions
- * must return scalar values. In addition, the existence of all referred attribute, query and constant
+ * must return scalar values.
+ *
+ * In addition, the existence of all referred attribute, query and constant
* features is ensured.
*
* @author bratseth
*/
-public class RankingExpressionTypeValidator extends Processor {
+public class RankingExpressionTypeResolver extends Processor {
private final QueryProfileRegistry queryProfiles;
- public RankingExpressionTypeValidator(Search search,
- DeployLogger deployLogger,
- RankProfileRegistry rankProfileRegistry,
- QueryProfiles queryProfiles) {
+ public RankingExpressionTypeResolver(Search search,
+ DeployLogger deployLogger,
+ RankProfileRegistry rankProfileRegistry,
+ QueryProfiles queryProfiles) {
super(search, deployLogger, rankProfileRegistry, queryProfiles);
this.queryProfiles = queryProfiles.getRegistry();
}
@Override
public void process(boolean validate, boolean documentsOnly) {
- if ( ! validate) return;
if (documentsOnly) return;
for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) {
try {
- validate(profile);
+ resolveTypesIn(profile, validate);
}
catch (IllegalArgumentException e) {
throw new IllegalArgumentException("In " + search + ", " + profile, e);
@@ -47,20 +52,34 @@ public class RankingExpressionTypeValidator extends Processor {
}
}
- /** Throws an IllegalArgumentException if the given rank profile does not produce valid type */
- private void validate(RankProfile profile) {
- TypeContext context = profile.typeContext(queryProfiles);
- profile.getSummaryFeatures().forEach(f -> ensureValid(f, "summary feature " + f, context));
- ensureValidDouble(profile.getFirstPhaseRanking(), "first-phase expression", context);
- ensureValidDouble(profile.getSecondPhaseRanking(), "second-phase expression", context);
+ /**
+ * Resolves the types of all functions in the given profile
+ *
+ * @throws IllegalArgumentException if validate is true and the given rank profile does not produce valid types
+ */
+ private void resolveTypesIn(RankProfile profile, boolean validate) {
+ TypeContext<Reference> context = profile.typeContext(queryProfiles);
+ for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) {
+ if ( ! function.getValue().function().arguments().isEmpty()) continue;
+ TensorType type = resolveType(function.getValue().function().getBody(),
+ "function '" + function.getKey() + "'",
+ context);
+ function.getValue().setReturnType(type);
+ }
+
+ if (validate) {
+ profile.getSummaryFeatures().forEach(f -> resolveType(f, "summary feature " + f, context));
+ ensureValidDouble(profile.getFirstPhaseRanking(), "first-phase expression", context);
+ ensureValidDouble(profile.getSecondPhaseRanking(), "second-phase expression", context);
+ }
}
- private TensorType ensureValid(RankingExpression expression, String expressionDescription, TypeContext context) {
+ private TensorType resolveType(RankingExpression expression, String expressionDescription, TypeContext context) {
if (expression == null) return null;
- return ensureValid(expression.getRoot(), expressionDescription, context);
+ return resolveType(expression.getRoot(), expressionDescription, context);
}
- private TensorType ensureValid(ExpressionNode expression, String expressionDescription, TypeContext context) {
+ private TensorType resolveType(ExpressionNode expression, String expressionDescription, TypeContext context) {
TensorType type;
try {
type = expression.type(context);
@@ -75,7 +94,7 @@ public class RankingExpressionTypeValidator extends Processor {
private void ensureValidDouble(RankingExpression expression, String expressionDescription, TypeContext context) {
if (expression == null) return;
- TensorType type = ensureValid(expression, expressionDescription, context);
+ TensorType type = resolveType(expression, expressionDescription, context);
if ( ! type.equals(TensorType.empty))
throw new IllegalArgumentException("The " + expressionDescription + " must produce a double " +
"(a tensor with no dimensions), but produces " + type);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java
index ec4cbdfe58b..3bde76c1c79 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java
@@ -44,6 +44,7 @@ public class RankProfileTypeSettingsProcessor extends Processor {
}
private void processAttributeFields() {
+ if (search == null) return; // we're processing global profiles
for (SDField field : search.allConcreteFields()) {
Attribute attribute = field.getAttributes().get(field.getName());
if (attribute != null && attribute.tensorType().isPresent()) {
@@ -53,6 +54,7 @@ public class RankProfileTypeSettingsProcessor extends Processor {
}
private void processImportedFields() {
+ if (search == null) return; // we're processing global profiles
Optional<ImportedFields> importedFields = search.importedFields();
if (importedFields.isPresent()) {
importedFields.get().fields().forEach((fieldName, field) -> processImportedField(field));
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/Service.java b/config-model/src/main/java/com/yahoo/vespa/model/Service.java
index 620e44bc11a..29ec26b06d2 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/Service.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/Service.java
@@ -7,7 +7,7 @@ import java.util.HashMap;
import java.util.Optional;
/**
- * Representation of a process which runs a service
+ * Representation of a markProcessed which runs a service
*
* @author gjoranv
*/
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index 4b70b1b5ae2..13304ea10ee 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -32,7 +32,9 @@ import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.RankingConstants;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RankProfileList;
+import com.yahoo.searchdefinition.processing.Processing;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.vespa.model.container.search.QueryProfiles;
import com.yahoo.vespa.model.ml.ConvertedModel;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
@@ -168,7 +170,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
createGlobalRankProfiles(deployState.getImportedModels(),
deployState.rankProfileRegistry(),
- deployState.getQueryProfiles().getRegistry());
+ deployState.getQueryProfiles());
this.rankProfileList = new RankProfileList(null, // null search -> global
rankingConstants,
AttributeFields.empty,
@@ -219,26 +221,23 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
/** Adds generic application specific clusters of services */
private void addServiceClusters(ApplicationPackage app, VespaModelBuilder builder) {
- for (ServiceCluster sc : builder.getClusters(app, this))
- serviceClusters.add(sc);
+ serviceClusters.addAll(builder.getClusters(app, this));
}
/**
- * Creates a rank profile not attached to any search definition, for each imported model in the application package
+ * Creates a rank profile not attached to any search definition, for each imported model in the application package,
+ * and adds it to the given rank profile registry.
*/
- private ImmutableList<RankProfile> createGlobalRankProfiles(ImportedModels importedModels,
- RankProfileRegistry rankProfileRegistry,
- QueryProfileRegistry queryProfiles) {
- List<RankProfile> profiles = new ArrayList<>();
+ private void createGlobalRankProfiles(ImportedModels importedModels,
+ RankProfileRegistry rankProfileRegistry,
+ QueryProfiles queryProfiles) {
if ( ! importedModels.all().isEmpty()) { // models/ directory is available
for (ImportedModel model : importedModels.all()) {
RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry);
rankProfileRegistry.add(profile);
ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()),
- model.name(), profile, queryProfiles, model);
- for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) {
- profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue()), false);
- }
+ model.name(), profile, queryProfiles.getRegistry(), model);
+ convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false));
}
}
else { // generated and stored model information may be available instead
@@ -248,12 +247,12 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry);
rankProfileRegistry.add(profile);
ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile);
- for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) {
- profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue()), false);
- }
+ convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false));
}
}
- return ImmutableList.copyOf(profiles);
+ new Processing().processRankProfiles(deployState.getDeployLogger(),
+ rankProfileRegistry,
+ queryProfiles, true, false);
}
/** Returns the global rank profiles as a rank profile list */
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index adf5c81283e..fb0109ed32e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -48,6 +48,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -67,14 +68,14 @@ public class ConvertedModel {
private final ModelName modelName;
private final String modelDescription;
- private final ImmutableMap<String, RankingExpression> expressions;
+ private final ImmutableMap<String, ExpressionFunction> expressions;
/** The source importedModel, or empty if this was created from a stored converted model */
private final Optional<ImportedModel> sourceModel;
private ConvertedModel(ModelName modelName,
String modelDescription,
- Map<String, RankingExpression> expressions,
+ Map<String, ExpressionFunction> expressions,
Optional<ImportedModel> sourceModel) {
this.modelName = modelName;
this.modelDescription = modelDescription;
@@ -132,23 +133,23 @@ public class ConvertedModel {
* if signatures are used, or the expression name if signatures are not used and there are multiple
* expressions, and the second is the output name if signature names are used.
*/
- public Map<String, RankingExpression> expressions() { return expressions; }
+ public Map<String, ExpressionFunction> expressions() { return expressions; }
/**
* Returns the expression matching the given arguments.
*/
public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) {
- RankingExpression expression = selectExpression(arguments);
- if (sourceModel.isPresent()) // we can verify
- verifyRequiredFunctions(expression, sourceModel.get(), context.rankProfile(), context.queryProfiles());
- return expression.getRoot();
+ ExpressionFunction expression = selectExpression(arguments);
+ if (sourceModel.isPresent()) // we should verify
+ verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles());
+ return expression.getBody().getRoot();
}
- private RankingExpression selectExpression(FeatureArguments arguments) {
+ private ExpressionFunction selectExpression(FeatureArguments arguments) {
if (expressions.isEmpty())
throw new IllegalArgumentException("No expressions available in " + this);
- RankingExpression expression = expressions.get(arguments.toName());
+ ExpressionFunction expression = expressions.get(arguments.toName());
if (expression != null) return expression;
if ( ! arguments.signature().isPresent()) {
@@ -158,7 +159,7 @@ public class ConvertedModel {
}
if ( ! arguments.output().isPresent()) {
- List<Map.Entry<String, RankingExpression>> entriesWithTheRightPrefix =
+ List<Map.Entry<String, ExpressionFunction>> entriesWithTheRightPrefix =
expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(arguments.signature().get() + ".")).collect(Collectors.toList());
if (entriesWithTheRightPrefix.size() < 1)
throw new IllegalArgumentException("No expressions named '" + arguments.signature().get() +
@@ -179,10 +180,10 @@ public class ConvertedModel {
// ----------------------- Static model conversion/storage below here
- private static Map<String, RankingExpression> convertAndStore(ImportedModel model,
- RankProfile profile,
- QueryProfileRegistry queryProfiles,
- ModelStore store) {
+ private static Map<String, ExpressionFunction> convertAndStore(ImportedModel model,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ ModelStore store) {
// Add constants
Set<String> constantsReplacedByFunctions = new HashSet<>();
model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
@@ -193,8 +194,8 @@ public class ConvertedModel {
addGeneratedFunctions(model, profile);
// Add expressions
- Map<String, RankingExpression> expressions = new HashMap<>();
- for (Pair<String, RankingExpression> output : model.outputExpressions()) {
+ Map<String, ExpressionFunction> expressions = new HashMap<>();
+ for (Pair<String, ExpressionFunction> output : model.outputExpressions()) {
addExpression(output.getSecond(), output.getFirst(),
constantsReplacedByFunctions,
model, store, profile, queryProfiles,
@@ -210,21 +211,21 @@ public class ConvertedModel {
return expressions;
}
- private static void addExpression(RankingExpression expression,
+ private static void addExpression(ExpressionFunction expression,
String expressionName,
Set<String> constantsReplacedByFunctions,
ImportedModel model,
ModelStore store,
RankProfile profile,
QueryProfileRegistry queryProfiles,
- Map<String, RankingExpression> expressions) {
- expression = replaceConstantsByFunctions(expression, constantsReplacedByFunctions);
- reduceBatchDimensions(expression, model, profile, queryProfiles);
+ Map<String, ExpressionFunction> expressions) {
+ expression = expression.withBody(replaceConstantsByFunctions(expression.getBody(), constantsReplacedByFunctions));
+ reduceBatchDimensions(expression.getBody(), model, profile, queryProfiles);
store.writeExpression(expressionName, expression);
expressions.put(expressionName, expression);
}
- private static Map<String, RankingExpression> convertStored(ModelStore store, RankProfile profile) {
+ private static Map<String, ExpressionFunction> convertStored(ModelStore store, RankProfile profile) {
for (Pair<String, Tensor> constant : store.readSmallConstants())
profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
@@ -290,15 +291,15 @@ public class ConvertedModel {
}
/**
- * Verify that the functions referred in the given expression exists in the given rank profile,
- * and return tensors of the types specified in requiredFunctions.
+ * Verify that the inputs declared in the given expression exists in the given rank profile as functions,
+ * and return tensors of the correct types.
*/
- private static void verifyRequiredFunctions(RankingExpression expression, ImportedModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
+ private static void verifyInputs(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
Set<String> functionNames = new HashSet<>();
addFunctionNamesIn(expression.getRoot(), functionNames, model);
for (String functionName : functionNames) {
- TensorType requiredType = model.requiredFunctions().get(functionName);
+ TensorType requiredType = model.inputs().get(functionName);
if (requiredType == null) continue; // Not a required function
RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName);
@@ -375,7 +376,7 @@ public class ConvertedModel {
List<ExpressionNode> children = ((TensorFunctionNode)node).children();
if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (model.requiredFunctions().containsKey(referenceNode.getName())) {
+ if (model.inputs().containsKey(referenceNode.getName())) {
return reduceBatchDimensionExpression(tensorFunction, typeContext);
}
}
@@ -383,7 +384,7 @@ public class ConvertedModel {
}
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode) node;
- if (model.requiredFunctions().containsKey(referenceNode.getName())) {
+ if (model.inputs().containsKey(referenceNode.getName())) {
return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
}
}
@@ -451,7 +452,8 @@ public class ConvertedModel {
Set<String> constantsReplacedByFunctions) {
if (constantsReplacedByFunctions.isEmpty()) return expression;
return new RankingExpression(expression.getName(),
- replaceConstantsByFunctions(expression.getRoot(), constantsReplacedByFunctions));
+ replaceConstantsByFunctions(expression.getRoot(),
+ constantsReplacedByFunctions));
}
private static ExpressionNode replaceConstantsByFunctions(ExpressionNode node, Set<String> constantsReplacedByFunctions) {
@@ -524,19 +526,21 @@ public class ConvertedModel {
* @param name the name of this ranking expression - may have 1-3 parts separated by dot where the first part
* is always the model name
*/
- void writeExpression(String name, RankingExpression expression) {
- application.getFile(modelFiles.expressionPath(name))
- .writeFile(new StringReader(expression.getRoot().toString()));
+ void writeExpression(String name, ExpressionFunction expression) {
+ StringBuilder b = new StringBuilder(expression.getBody().getRoot().toString());
+ for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet())
+ b.append('\n').append(input.getKey()).append('\t').append(input.getValue());
+ application.getFile(modelFiles.expressionPath(name)).writeFile(new StringReader(b.toString()));
}
- Map<String, RankingExpression> readExpressions() {
- Map<String, RankingExpression> expressions = new HashMap<>();
+ Map<String, ExpressionFunction> readExpressions() {
+ Map<String, ExpressionFunction> expressions = new HashMap<>();
ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath());
if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap();
for (ApplicationFile expressionFile : expressionPath.listFiles()) {
- try (Reader reader = new BufferedReader(expressionFile.createReader())){
+ try (BufferedReader reader = new BufferedReader(expressionFile.createReader())){
String name = expressionFile.getPath().getName();
- expressions.put(name, new RankingExpression(name, reader));
+ expressions.put(name, readExpression(name, reader));
}
catch (IOException e) {
throw new UncheckedIOException("Failed reading " + expressionFile.getPath(), e);
@@ -548,8 +552,22 @@ public class ConvertedModel {
return expressions;
}
+ private ExpressionFunction readExpression(String name, BufferedReader reader)
+ throws IOException, ParseException {
+ // First line is expression
+ RankingExpression expression = new RankingExpression(name, reader.readLine());
+ // Next lines are inputs on the format name\ttensorTypeSpec
+ Map<String, TensorType> inputs = new LinkedHashMap<>();
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ inputs.put(parts[0], TensorType.fromSpec(parts[1]));
+ }
+ return new ExpressionFunction(name, new ArrayList<>(inputs.keySet()), expression, inputs, Optional.empty());
+ }
+
/** Adds this function expression to the application package so it can be read later. */
- void writeFunction(String name, RankingExpression expression) {
+ public void writeFunction(String name, RankingExpression expression) {
application.getFile(modelFiles.functionsPath()).appendFile(name + "\t" +
expression.getRoot().toString() + "\n");
}
@@ -561,20 +579,20 @@ public class ConvertedModel {
if ( ! file.exists()) return Collections.emptyList();
List<Pair<String, RankingExpression>> functions = new ArrayList<>();
- BufferedReader reader = new BufferedReader(file.createReader());
- String line;
- while (null != (line = reader.readLine())) {
- String[] parts = line.split("\t");
- String name = parts[0];
- try {
- RankingExpression expression = new RankingExpression(parts[0], parts[1]);
- functions.add(new Pair<>(name, expression));
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + name, e);
+ try (BufferedReader reader = new BufferedReader(file.createReader())) {
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ String name = parts[0];
+ try {
+ RankingExpression expression = new RankingExpression(parts[0], parts[1]);
+ functions.add(new Pair<>(name, expression));
+ } catch (ParseException e) {
+ throw new IllegalStateException("Could not parse " + name, e);
+ }
}
+ return functions;
}
- return functions;
}
catch (IOException e) {
throw new UncheckedIOException(e);
diff --git a/config-model/src/test/derived/gemini2/gemini.sd b/config-model/src/test/derived/gemini2/gemini.sd
index 01e20c1b30a..8a570e58fa8 100644
--- a/config-model/src/test/derived/gemini2/gemini.sd
+++ b/config-model/src/test/derived/gemini2/gemini.sd
@@ -2,6 +2,12 @@
search gemini {
document gemini {
+ field right type string {
+ indexing: attribute
+ }
+ field wrong type string {
+ indexing: attribute
+ }
}
rank-profile test {
diff --git a/config-model/src/test/derived/gemini2/rank-profiles.cfg b/config-model/src/test/derived/gemini2/rank-profiles.cfg
index aa4f963320d..2b73e923c88 100644
--- a/config-model/src/test/derived/gemini2/rank-profiles.cfg
+++ b/config-model/src/test/derived/gemini2/rank-profiles.cfg
@@ -21,9 +21,13 @@ rankprofile[].fef.property[].name "rankingExpression(wrapper1@2d437c13405e61d6).
rankprofile[].fef.property[].value "rankingExpression(wrapper2@2d437c13405e61d6)"
rankprofile[].fef.property[].name "rankingExpression(toplevel).rankingScript"
rankprofile[].fef.property[].value "rankingExpression(wrapper1@2d437c13405e61d6)"
+rankprofile[].fef.property[].name "rankingExpression(toplevel).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(wrapper2@8fc8470e911f253f).rankingScript"
rankprofile[].fef.property[].value "attribute(wrong)"
rankprofile[].fef.property[].name "rankingExpression(wrapper1@8fc8470e911f253f).rankingScript"
rankprofile[].fef.property[].value "rankingExpression(wrapper2@8fc8470e911f253f)"
rankprofile[].fef.property[].name "rankingExpression(interfering).rankingScript"
rankprofile[].fef.property[].value "rankingExpression(wrapper1@8fc8470e911f253f)"
+rankprofile[].fef.property[].name "rankingExpression(interfering).type"
+rankprofile[].fef.property[].value "tensor()"
diff --git a/config-model/src/test/derived/rankexpression/rank-profiles.cfg b/config-model/src/test/derived/rankexpression/rank-profiles.cfg
index 9629ad863d4..d109ca4f0ec 100644
--- a/config-model/src/test/derived/rankexpression/rank-profiles.cfg
+++ b/config-model/src/test/derived/rankexpression/rank-profiles.cfg
@@ -128,6 +128,8 @@ rankprofile[].fef.property[].name "rankingExpression(fourtimessum).rankingScript
rankprofile[].fef.property[].value "4 * (var1 + var2)"
rankprofile[].fef.property[].name "rankingExpression(myfeature).rankingScript"
rankprofile[].fef.property[].value "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + 30 * pow(0 - fieldMatch(description).earliness,2)"
+rankprofile[].fef.property[].name "rankingExpression(myfeature).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86).rankingScript"
rankprofile[].fef.property[].value "4 * (match + rankBoost)"
rankprofile[].fef.property[].name "vespa.rank.firstphase"
@@ -145,10 +147,16 @@ rankprofile[].fef.property[].name "rankingExpression(fourtimessum).rankingScript
rankprofile[].fef.property[].value "4 * (var1 + var2)"
rankprofile[].fef.property[].name "rankingExpression(myfeature).rankingScript"
rankprofile[].fef.property[].value "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + 30 * pow(0 - fieldMatch(description).earliness,2)"
+rankprofile[].fef.property[].name "rankingExpression(myfeature).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature).rankingScript"
rankprofile[].fef.property[].value "70 * fieldMatch(title).completeness"
+rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature2).rankingScript"
rankprofile[].fef.property[].value "71 * fieldMatch(title).completeness"
+rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature2).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(fourtimessum@2b1138e8965e7ff5.67f1e87166cfef86).rankingScript"
rankprofile[].fef.property[].value "4 * (match + match)"
rankprofile[].fef.property[].name "vespa.rank.firstphase"
@@ -164,11 +172,15 @@ rankprofile[].fef.property[].value "rankingExpression(mysummaryfeature)"
rankprofile[].name "macros3"
rankprofile[].fef.property[].name "rankingExpression(onlyusedinsummaryfeature).rankingScript"
rankprofile[].fef.property[].value "5"
+rankprofile[].fef.property[].name "rankingExpression(onlyusedinsummaryfeature).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "vespa.summary.feature"
rankprofile[].fef.property[].value "rankingExpression(matches(title,rankingExpression(onlyusedinsummaryfeature)))"
rankprofile[].name "macros3-inherited"
rankprofile[].fef.property[].name "rankingExpression(onlyusedinsummaryfeature).rankingScript"
rankprofile[].fef.property[].value "5"
+rankprofile[].fef.property[].name "rankingExpression(onlyusedinsummaryfeature).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "vespa.summary.feature"
rankprofile[].fef.property[].value "rankingExpression(matches(title,rankingExpression(onlyusedinsummaryfeature)))"
rankprofile[].name "macros-inherited"
@@ -178,10 +190,16 @@ rankprofile[].fef.property[].name "rankingExpression(fourtimessum).rankingScript
rankprofile[].fef.property[].value "4 * (var1 + var2)"
rankprofile[].fef.property[].name "rankingExpression(myfeature).rankingScript"
rankprofile[].fef.property[].value "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + 30 * pow(0 - fieldMatch(description).earliness,2)"
+rankprofile[].fef.property[].name "rankingExpression(myfeature).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature).rankingScript"
rankprofile[].fef.property[].value "80 * fieldMatch(title).completeness"
+rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature2).rankingScript"
rankprofile[].fef.property[].value "71 * fieldMatch(title).completeness"
+rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature2).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(fourtimessum@2b1138e8965e7ff5.67f1e87166cfef86).rankingScript"
rankprofile[].fef.property[].value "4 * (match + match)"
rankprofile[].fef.property[].name "vespa.rank.firstphase"
@@ -203,10 +221,16 @@ rankprofile[].fef.property[].name "rankingExpression(fourtimessum).rankingScript
rankprofile[].fef.property[].value "4 * (var1 + var2)"
rankprofile[].fef.property[].name "rankingExpression(myfeature).rankingScript"
rankprofile[].fef.property[].value "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + 30 * pow(0 - fieldMatch(description).earliness,2)"
+rankprofile[].fef.property[].name "rankingExpression(myfeature).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature).rankingScript"
rankprofile[].fef.property[].value "80 * fieldMatch(title).completeness"
+rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature2).rankingScript"
rankprofile[].fef.property[].value "71 * fieldMatch(title).completeness"
+rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature2).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(fourtimessum@2b1138e8965e7ff5.67f1e87166cfef86).rankingScript"
rankprofile[].fef.property[].value "4 * (match + match)"
rankprofile[].fef.property[].name "vespa.rank.firstphase"
@@ -228,10 +252,16 @@ rankprofile[].fef.property[].name "rankingExpression(fourtimessum).rankingScript
rankprofile[].fef.property[].value "4 * (var1 + var2)"
rankprofile[].fef.property[].name "rankingExpression(myfeature).rankingScript"
rankprofile[].fef.property[].value "700 * fieldMatch(title).completeness"
+rankprofile[].fef.property[].name "rankingExpression(myfeature).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature).rankingScript"
rankprofile[].fef.property[].value "80 * fieldMatch(title).completeness"
+rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature2).rankingScript"
rankprofile[].fef.property[].value "71 * fieldMatch(title).completeness"
+rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature2).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "vespa.rank.firstphase"
rankprofile[].fef.property[].value "rankingExpression(firstphase)"
rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript"
@@ -249,8 +279,14 @@ rankprofile[].fef.property[].name "rankingExpression(m1).rankingScript"
rankprofile[].fef.property[].value "700 * fieldMatch(title).completeness"
rankprofile[].fef.property[].name "rankingExpression(m2).rankingScript"
rankprofile[].fef.property[].value "rankingExpression(m1) * 67"
+rankprofile[].fef.property[].name "rankingExpression(m2).type"
+rankprofile[].fef.property[].value "tensor()"
+rankprofile[].fef.property[].name "rankingExpression(m1).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(m4).rankingScript"
rankprofile[].fef.property[].value "703 * fieldMatch(fromfile).completeness"
+rankprofile[].fef.property[].name "rankingExpression(m4).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "vespa.rank.secondphase"
rankprofile[].fef.property[].value "rankingExpression(secondphase)"
rankprofile[].fef.property[].name "rankingExpression(secondphase).rankingScript"
@@ -260,10 +296,18 @@ rankprofile[].fef.property[].name "rankingExpression(m1).rankingScript"
rankprofile[].fef.property[].value "700 * fieldMatch(title).completeness"
rankprofile[].fef.property[].name "rankingExpression(m2).rankingScript"
rankprofile[].fef.property[].value "rankingExpression(m1) * 67"
+rankprofile[].fef.property[].name "rankingExpression(m2).type"
+rankprofile[].fef.property[].value "tensor()"
+rankprofile[].fef.property[].name "rankingExpression(m1).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(m4).rankingScript"
rankprofile[].fef.property[].value "701 * fieldMatch(title).completeness"
+rankprofile[].fef.property[].name "rankingExpression(m4).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(m3).rankingScript"
rankprofile[].fef.property[].value "if (isNan(attribute(nrtgmp)) == 1, 0.0, rankingExpression(m2))"
+rankprofile[].fef.property[].name "rankingExpression(m3).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "vespa.rank.secondphase"
rankprofile[].fef.property[].value "rankingExpression(secondphase)"
rankprofile[].fef.property[].name "rankingExpression(secondphase).rankingScript"
@@ -273,8 +317,14 @@ rankprofile[].fef.property[].name "rankingExpression(m1).rankingScript"
rankprofile[].fef.property[].value "700 * fieldMatch(title).completeness"
rankprofile[].fef.property[].name "rankingExpression(m2).rankingScript"
rankprofile[].fef.property[].value "rankingExpression(m1) * 67"
+rankprofile[].fef.property[].name "rankingExpression(m2).type"
+rankprofile[].fef.property[].value "tensor()"
+rankprofile[].fef.property[].name "rankingExpression(m1).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(m4).rankingScript"
rankprofile[].fef.property[].value "703 * fieldMatch(fromfile).completeness"
+rankprofile[].fef.property[].name "rankingExpression(m4).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "vespa.rank.secondphase"
rankprofile[].fef.property[].value "rankingExpression(secondphase)"
rankprofile[].fef.property[].name "rankingExpression(secondphase).rankingScript"
@@ -284,12 +334,22 @@ rankprofile[].fef.property[].name "rankingExpression(m1).rankingScript"
rankprofile[].fef.property[].value "700 * fieldMatch(title).completeness"
rankprofile[].fef.property[].name "rankingExpression(m2).rankingScript"
rankprofile[].fef.property[].value "rankingExpression(m1) * 67"
+rankprofile[].fef.property[].name "rankingExpression(m2).type"
+rankprofile[].fef.property[].value "tensor()"
+rankprofile[].fef.property[].name "rankingExpression(m1).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(m4).rankingScript"
rankprofile[].fef.property[].value "701 * fieldMatch(title).completeness"
+rankprofile[].fef.property[].name "rankingExpression(m4).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(m3).rankingScript"
rankprofile[].fef.property[].value "if (isNan(attribute(nrtgmp)) == 1, 0.0, rankingExpression(m2))"
+rankprofile[].fef.property[].name "rankingExpression(m3).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "rankingExpression(m5).rankingScript"
rankprofile[].fef.property[].value "if (isNan(attribute(glmpfw)) == 1, rankingExpression(m1), rankingExpression(m4))"
+rankprofile[].fef.property[].name "rankingExpression(m5).type"
+rankprofile[].fef.property[].value "tensor()"
rankprofile[].fef.property[].name "vespa.rank.secondphase"
rankprofile[].fef.property[].value "rankingExpression(secondphase)"
rankprofile[].fef.property[].name "rankingExpression(secondphase).rankingScript"
diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg
index cb496c06367..471343da63c 100644
--- a/config-model/src/test/derived/tensor/rank-profiles.cfg
+++ b/config-model/src/test/derived/tensor/rank-profiles.cfg
@@ -43,10 +43,14 @@ rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
rankprofile[].fef.property[].value "tensor(x[10],y[20])"
rankprofile[].name "profile3"
+rankprofile[].fef.property[].name "rankingExpression(joinedtensors).rankingScript"
+rankprofile[].fef.property[].value "tensor(i[10])(i) * attribute(f4)"
+rankprofile[].fef.property[].name "rankingExpression(joinedtensors).type"
+rankprofile[].fef.property[].value "tensor(i[10],x[10],y[20])"
rankprofile[].fef.property[].name "vespa.rank.firstphase"
rankprofile[].fef.property[].value "rankingExpression(firstphase)"
rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript"
-rankprofile[].fef.property[].value "reduce(tensor(i[10])(i) * attribute(f4), sum)"
+rankprofile[].fef.property[].value "reduce(rankingExpression(joinedtensors), sum)"
rankprofile[].fef.property[].name "vespa.type.attribute.f2"
rankprofile[].fef.property[].value "tensor(x[2],y[])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd
index 5792a7997f8..54463448e28 100644
--- a/config-model/src/test/derived/tensor/tensor.sd
+++ b/config-model/src/test/derived/tensor/tensor.sd
@@ -36,7 +36,11 @@ search tensor {
rank-profile profile3 {
first-phase {
- expression: sum(tensor(i[10])(i) * attribute(f4))
+ expression: sum(joinedtensors)
+ }
+
+ function joinedtensors() {
+ expression: tensor(i[10])(i) * attribute(f4)
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
index 150469cc928..2c1f4c8ecb6 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
@@ -82,7 +82,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase
new ImportedModels(),
new AttributeFields(s)).configProperties();
assertEquals("(rankingExpression(foo).rankingScript,14.0)", rankProperties.get(0).toString());
- assertEquals("(rankingExpression(firstphase).rankingScript,16.6)", rankProperties.get(2).toString());
+ assertEquals("(rankingExpression(firstphase).rankingScript,16.6)", rankProperties.get(3).toString());
}
@Test
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java
index 17bebcba70e..0ff8a5cc7ca 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java
@@ -40,7 +40,7 @@ public class RankingExpressionLoopDetectionTestCase {
fail("Excepted exception");
}
catch (IllegalArgumentException e) {
- assertEquals("In search definition 'test', rank profile 'test': The first-phase expression is invalid: Invocation loop: foo -> foo",
+ assertEquals("In search definition 'test', rank profile 'test': The function 'foo' is invalid: Invocation loop: foo -> foo",
Exceptions.toMessageString(e));
}
}
@@ -75,7 +75,7 @@ public class RankingExpressionLoopDetectionTestCase {
fail("Excepted exception");
}
catch (IllegalArgumentException e) {
- assertEquals("In search definition 'test', rank profile 'test': The first-phase expression is invalid: Invocation loop: foo -> arg(5) -> foo",
+ assertEquals("In search definition 'test', rank profile 'test': The function 'foo' is invalid: Invocation loop: arg(5) -> foo -> arg(5)",
Exceptions.toMessageString(e));
}
}
@@ -110,7 +110,7 @@ public class RankingExpressionLoopDetectionTestCase {
fail("Excepted exception");
}
catch (IllegalArgumentException e) {
- assertEquals("In search definition 'test', rank profile 'test': The first-phase expression is invalid: Invocation loop: foo -> arg(foo) -> foo",
+ assertEquals("In search definition 'test', rank profile 'test': The function 'foo' is invalid: Invocation loop: arg(foo) -> foo -> arg(foo)",
Exceptions.toMessageString(e));
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
index e15d4075b19..4f99922a422 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
@@ -211,12 +211,16 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
censorBindingHash(testRankProperties.get(1).toString()));
assertEquals("(rankingExpression(hidden_layer).rankingScript,rankingExpression(relu@))",
censorBindingHash(testRankProperties.get(2).toString()));
+ assertEquals("(rankingExpression(hidden_layer).type,tensor(x[]))",
+ censorBindingHash(testRankProperties.get(3).toString()));
assertEquals("(rankingExpression(final_layer).rankingScript,sigmoid(reduce(rankingExpression(hidden_layer) * constant(W_final), sum, hidden) + constant(b_final)))",
- testRankProperties.get(3).toString());
- assertEquals("(vespa.rank.secondphase,rankingExpression(secondphase))",
testRankProperties.get(4).toString());
- assertEquals("(rankingExpression(secondphase).rankingScript,reduce(rankingExpression(final_layer), sum))",
+ assertEquals("(rankingExpression(final_layer).type,tensor(x[]))",
testRankProperties.get(5).toString());
+ assertEquals("(vespa.rank.secondphase,rankingExpression(secondphase))",
+ testRankProperties.get(6).toString());
+ assertEquals("(rankingExpression(secondphase).rankingScript,reduce(rankingExpression(final_layer), sum))",
+ testRankProperties.get(7).toString());
}
private QueryProfileRegistry queryProfileWith(String field, String type) {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolverTestCase.java
index 0d8cbbf2e6a..1b917b6f3a3 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolverTestCase.java
@@ -19,7 +19,7 @@ import static org.junit.Assert.fail;
/**
* @author bratseth
*/
-public class RankingExpressionTypeValidatorTestCase {
+public class RankingExpressionTypeResolverTestCase {
@Test
public void tensorFirstPhaseMustProduceDouble() throws Exception {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
index fd048737b43..7d62bc5089d 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
@@ -41,7 +41,7 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase {
new QueryProfileRegistry(),
new ImportedModels(),
new AttributeFields(search)).configProperties();
- assertEquals(6, rankProperties.size());
+ assertEquals(7, rankProperties.size());
assertEquals("rankingExpression(titlematch$).rankingScript", rankProperties.get(0).getFirst());
assertEquals("var1 * var2 + 890", rankProperties.get(0).getSecond());
@@ -49,14 +49,17 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase {
assertEquals("rankingExpression(artistmatch).rankingScript", rankProperties.get(1).getFirst());
assertEquals("78 + closeness(distance)", rankProperties.get(1).getSecond());
- assertEquals("rankingExpression(firstphase).rankingScript", rankProperties.get(5).getFirst());
- assertEquals("0.8 + 0.2 * rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c) + 0.8 * rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6) * closeness(distance)", rankProperties.get(5).getSecond());
+ assertEquals("rankingExpression(firstphase).rankingScript", rankProperties.get(6).getFirst());
+ assertEquals("0.8 + 0.2 * rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c) + 0.8 * rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6) * closeness(distance)", rankProperties.get(6).getSecond());
- assertEquals("rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6).rankingScript", rankProperties.get(3).getFirst());
- assertEquals("7 * 8 + 890", rankProperties.get(3).getSecond());
+ assertEquals("rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6).rankingScript", rankProperties.get(4).getFirst());
+ assertEquals("7 * 8 + 890", rankProperties.get(4).getSecond());
- assertEquals("rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c).rankingScript", rankProperties.get(2).getFirst());
- assertEquals("4 * 5 + 890", rankProperties.get(2).getSecond());
+ assertEquals("rankingExpression(artistmatch).type", rankProperties.get(2).getFirst());
+ assertEquals("tensor()", rankProperties.get(2).getSecond());
+
+ assertEquals("rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c).rankingScript", rankProperties.get(3).getFirst());
+ assertEquals("4 * 5 + 890", rankProperties.get(3).getSecond());
}
@Test(expected = IllegalArgumentException.class)
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
index 8e721dbe503..6e3a227e2a9 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
@@ -58,8 +58,8 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
"max(attribute(tensor_field_1),x)");
assertTransformedExpression("1+reduce(attribute(tensor_field_1),max,x)",
"1 + max(attribute(tensor_field_1),x)");
- assertTransformedExpression("if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)",
- "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)");
+ assertTransformedExpression("if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),attribute(tensor_field_1))",
+ "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),attribute(tensor_field_1))");
assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)",
"max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)");
assertTransformedExpression("reduce(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),max,x)",
diff --git a/container-search/src/main/java/com/yahoo/search/searchchain/AsyncExecution.java b/container-search/src/main/java/com/yahoo/search/searchchain/AsyncExecution.java
index d74f29a3b77..4dbec304bde 100644
--- a/container-search/src/main/java/com/yahoo/search/searchchain/AsyncExecution.java
+++ b/container-search/src/main/java/com/yahoo/search/searchchain/AsyncExecution.java
@@ -147,7 +147,7 @@ public class AsyncExecution {
}
private static <T> Future<T> getFuture(Callable<T> callable) {
- final FutureTask<T> future = new FutureTask<>(callable);
+ FutureTask<T> future = new FutureTask<>(callable);
getExecutor().execute(future);
return future;
}
diff --git a/processing/src/main/java/com/yahoo/processing/execution/AsyncExecution.java b/processing/src/main/java/com/yahoo/processing/execution/AsyncExecution.java
index eac96e9b408..2c40165f8e5 100644
--- a/processing/src/main/java/com/yahoo/processing/execution/AsyncExecution.java
+++ b/processing/src/main/java/com/yahoo/processing/execution/AsyncExecution.java
@@ -59,14 +59,14 @@ public class AsyncExecution {
/**
* Create an async execution of a single processor
*/
- public AsyncExecution(final Processor processor, Execution parent) {
+ public AsyncExecution(Processor processor, Execution parent) {
this(new Execution(processor, parent));
}
/**
* Create an async execution of a chain
*/
- public AsyncExecution(final Chain<? extends Processor> chain, Execution parent) {
+ public AsyncExecution(Chain<? extends Processor> chain, Execution parent) {
this(new Execution(chain, parent));
}
@@ -81,7 +81,7 @@ public class AsyncExecution {
*
* @param execution the execution from which the state of this is created
*/
- public AsyncExecution(final Execution execution) {
+ public AsyncExecution(Execution execution) {
this.execution = new Execution(execution);
}
@@ -89,7 +89,7 @@ public class AsyncExecution {
* Performs an async processing. Note that the given request cannot be simultaneously
* used in multiple such processings - a clone must be created for each.
*/
- public FutureResponse process(final Request request) {
+ public FutureResponse process(Request request) {
return getFutureResponse(new Callable<Response>() {
@Override
public Response call() {
@@ -99,13 +99,13 @@ public class AsyncExecution {
}
private static <T> Future<T> getFuture(final Callable<T> callable) {
- final FutureTask<T> future = new FutureTask<>(callable);
+ FutureTask<T> future = new FutureTask<>(callable);
executorMain.execute(future);
return future;
}
- private FutureResponse getFutureResponse(final Callable<Response> callable, final Request request) {
- final FutureResponse future = new FutureResponse(callable, execution, request);
+ private FutureResponse getFutureResponse(Callable<Response> callable, Request request) {
+ FutureResponse future = new FutureResponse(callable, execution, request);
executorMain.execute(future.delegate());
return future;
}
@@ -118,15 +118,15 @@ public class AsyncExecution {
* @return the list of responses in the same order as returned from the task collection
*/
// Note that this may also be achieved using guava Futures. Not sure if this should be deprecated because of it.
- public static List<Response> waitForAll(final Collection<FutureResponse> tasks, final long timeout) {
+ public static List<Response> waitForAll(Collection<FutureResponse> tasks, long timeout) {
// Copy the list in case it is modified while we are waiting
- final List<FutureResponse> workingTasks = new ArrayList<>(tasks);
+ List<FutureResponse> workingTasks = new ArrayList<>(tasks);
@SuppressWarnings({"rawtypes", "unchecked"})
- final Future task = getFuture(new Callable() {
+ Future task = getFuture(new Callable() {
@Override
public List<Future> call() {
- for (final FutureResponse task : workingTasks) {
+ for (FutureResponse task : workingTasks) {
task.get();
}
return null;
@@ -135,12 +135,12 @@ public class AsyncExecution {
try {
task.get(timeout, TimeUnit.MILLISECONDS);
- } catch (final TimeoutException | InterruptedException | ExecutionException e) {
+ } catch (TimeoutException | InterruptedException | ExecutionException e) {
// Handle timeouts below
}
- final List<Response> responses = new ArrayList<>(tasks.size());
- for (final FutureResponse future : workingTasks)
+ List<Response> responses = new ArrayList<>(tasks.size());
+ for (FutureResponse future : workingTasks)
responses.add(getTaskResponse(future));
return responses;
}
@@ -153,5 +153,4 @@ public class AsyncExecution {
}
}
-
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
index da34ab8822d..f6502a9801d 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
@@ -2,8 +2,11 @@
package com.yahoo.searchlib.rankingexpression;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.yahoo.log.event.Collection;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
+import com.yahoo.tensor.TensorType;
import com.yahoo.text.Utf8;
import java.security.MessageDigest;
@@ -13,9 +16,14 @@ import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
/**
- * A function defined by a ranking expression
+ * A function defined by a ranking expression, optionally containing type information
+ * for inputs and outputs.
+ *
+ * Immutable, but note that ranking expressions are *not* immutable.
*
* @author Simon Thoresen Hult
* @author bratseth
@@ -24,8 +32,13 @@ public class ExpressionFunction {
private final String name;
private final ImmutableList<String> arguments;
+
+ /** Types of the inputs, if known. The keys here is any subset (including empty and identity) of the argument list */
+ private final ImmutableMap<String, TensorType> argumentTypes;
private final RankingExpression body;
+ private final Optional<TensorType> returnType;
+
/**
* Constructs a new function with no arguments
*
@@ -44,9 +57,18 @@ public class ExpressionFunction {
* @param body the ranking expression that defines this function
*/
public ExpressionFunction(String name, List<String> arguments, RankingExpression body) {
- this.name = name;
+ this(name, arguments, body, ImmutableMap.of(), Optional.empty());
+ }
+
+ public ExpressionFunction(String name, List<String> arguments, RankingExpression body,
+ Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) {
+ this.name = Objects.requireNonNull(name, "name cannot be null");
this.arguments = arguments==null ? ImmutableList.of() : ImmutableList.copyOf(arguments);
- this.body = body;
+ this.body = Objects.requireNonNull(body, "body cannot be null");
+ if ( ! this.arguments.containsAll(argumentTypes.keySet()))
+ throw new IllegalArgumentException("Argument type keys must be a subset of the argument keys");
+ this.argumentTypes = ImmutableMap.copyOf(argumentTypes);
+ this.returnType = Objects.requireNonNull(returnType, "returnType cannot be null");
}
public String getName() { return name; }
@@ -56,9 +78,27 @@ public class ExpressionFunction {
public RankingExpression getBody() { return body; }
+ /** Returns the types of the arguments of this, if specified. The keys of this may be any subset of the arguments */
+ public Map<String, TensorType> argumentTypes() { return argumentTypes; }
+
+ /** Returns the return type of this, or empty if not specified */
+ public Optional<TensorType> returnType() { return returnType; }
+
+ public ExpressionFunction withName(String name) {
+ return new ExpressionFunction(name, arguments, body, argumentTypes, returnType);
+ }
+
/** Returns a copy of this with the body changed to the given value */
public ExpressionFunction withBody(RankingExpression body) {
- return new ExpressionFunction(name, arguments, body);
+ return new ExpressionFunction(name, arguments, body, argumentTypes, returnType);
+ }
+
+ public ExpressionFunction withReturnType(TensorType returnType) {
+ return new ExpressionFunction(name, arguments, body, argumentTypes, Optional.of(returnType));
+ }
+
+ public ExpressionFunction withArgumentTypes(Map<String, TensorType> argumentTypes) {
+ return new ExpressionFunction(name, arguments, body, argumentTypes, returnType);
}
/**
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index 282a4c5e0a9..9ff391a5cfe 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -1,15 +1,22 @@
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
import com.yahoo.collections.Pair;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
import java.util.regex.Pattern;
/**
@@ -26,12 +33,11 @@ public class ImportedModel {
private final String source;
private final Map<String, Signature> signatures = new HashMap<>();
- private final Map<String, TensorType> arguments = new HashMap<>();
+ private final Map<String, TensorType> inputs = new HashMap<>();
private final Map<String, Tensor> smallConstants = new HashMap<>();
private final Map<String, Tensor> largeConstants = new HashMap<>();
private final Map<String, RankingExpression> expressions = new HashMap<>();
private final Map<String, RankingExpression> functions = new HashMap<>();
- private final Map<String, TensorType> requiredFunctions = new HashMap<>();
/**
* Creates a new imported model.
@@ -49,11 +55,11 @@ public class ImportedModel {
/** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
public String name() { return name; }
- /** Returns the source path (directiry or file) of this model */
+ /** Returns the source path (directory or file) of this model */
public String source() { return source; }
- /** Returns an immutable map of the arguments ("Placeholders") of this */
- public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
+ /** Returns an immutable map of the inputs of this */
+ public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); }
/**
* Returns an immutable map of the small constants of this.
@@ -71,7 +77,7 @@ public class ImportedModel {
/**
* Returns an immutable map of the expressions of this - corresponding to graph nodes
- * which are not Inputs/Placeholders or Variables (which instead become respectively arguments and constants).
+ * which are not Inputs/Placeholders or Variables (which instead become respectively inputs and constants).
* Note that only nodes recursively referenced by a placeholder/input are added.
*/
public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
@@ -82,9 +88,6 @@ public class ImportedModel {
*/
public Map<String, RankingExpression> functions() { return Collections.unmodifiableMap(functions); }
- /** Returns an immutable map of the functions that must be provided by the environment running this model */
- public Map<String, TensorType> requiredFunctions() { return Collections.unmodifiableMap(requiredFunctions); }
-
/** Returns an immutable map of the signatures of this */
public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
@@ -96,12 +99,11 @@ public class ImportedModel {
/** Convenience method for returning a default signature */
Signature defaultSignature() { return signature(defaultSignatureName); }
- void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
+ void input(String name, TensorType argumentType) { inputs.put(name, argumentType); }
void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
void function(String name, RankingExpression expression) { functions.put(name, expression); }
- void requiredFunction(String name, TensorType type) { requiredFunctions.put(name, type); }
/**
* Returns all the output expressions of this indexed by name. The names consist of one or two parts
@@ -109,24 +111,39 @@ public class ImportedModel {
* if signatures are used, or the expression name if signatures are not used and there are multiple
* expressions, and the second is the output name if signature names are used.
*/
- public List<Pair<String, RankingExpression>> outputExpressions() {
- List<Pair<String, RankingExpression>> expressions = new ArrayList<>();
+ public List<Pair<String, ExpressionFunction>> outputExpressions() {
+ List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>();
for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(),
- expressions().get(outputEntry.getValue())));
+ signatureEntry.getValue().outputExpression(outputEntry.getKey())
+ .withName(signatureEntry.getKey() + "." + outputEntry.getKey())));
if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
expressions.add(new Pair<>(signatureEntry.getKey(),
- expressions().get(signatureEntry.getKey())));
+ new ExpressionFunction(signatureEntry.getKey(),
+ new ArrayList<>(signatureEntry.getValue().inputs().keySet()),
+ expressions().get(signatureEntry.getKey()),
+ signatureEntry.getValue().inputMap(),
+ Optional.empty())));
}
if (signatures().isEmpty()) { // fallback for models without signatures
if (expressions().size() == 1) {
Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next();
- expressions.add(new Pair<>(singleEntry.getKey(), singleEntry.getValue()));
+ expressions.add(new Pair<>(singleEntry.getKey(),
+ new ExpressionFunction(singleEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ singleEntry.getValue(),
+ inputs,
+ Optional.empty())));
}
else {
for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) {
- expressions.add(new Pair<>(expressionEntry.getKey(), expressionEntry.getValue()));
+ expressions.add(new Pair<>(expressionEntry.getKey(),
+ new ExpressionFunction(expressionEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ expressionEntry.getValue(),
+ inputs,
+ Optional.empty())));
}
}
}
@@ -134,7 +151,7 @@ public class ImportedModel {
}
/**
- * A signature is a set of named inputs and outputs, where the inputs maps to argument
+ * A signature is a set of named inputs and outputs, where the inputs maps to input
* ("placeholder") names+types, and outputs maps to expressions nodes.
* Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit
* concept of signatures. For now, we handle ONNX models as having a single signature.
@@ -142,8 +159,8 @@ public class ImportedModel {
public class Signature {
private final String name;
- private final Map<String, String> inputs = new HashMap<>();
- private final Map<String, String> outputs = new HashMap<>();
+ private final Map<String, String> inputs = new LinkedHashMap<>();
+ private final Map<String, String> outputs = new LinkedHashMap<>();
private final Map<String, String> skippedOutputs = new HashMap<>();
private final List<String> importWarnings = new ArrayList<>();
@@ -158,12 +175,20 @@ public class ImportedModel {
/**
* Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
- * to argument (Placeholder) name in the owner of this
+ * in this signature to input name in the owning model
*/
public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
- /** Returns the type of the argument this input references */
- public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); }
+ /** Returns the name and type of all inputs in this signature as an immutable map */
+ public Map<String, TensorType> inputMap() {
+ ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>();
+ for (Map.Entry<String, String> inputEntry : inputs().entrySet())
+ inputs.put(inputEntry.getKey(), owner().inputs().get(inputEntry.getValue()));
+ return inputs.build();
+ }
+
+ /** Returns the type of the input this input references */
+ public TensorType inputArgument(String inputName) { return owner().inputs().get(inputs.get(inputName)); }
/** Returns an immutable list of the expression names of this */
public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
@@ -180,7 +205,13 @@ public class ImportedModel {
public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
/** Returns the expression this output references */
- public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); }
+ public ExpressionFunction outputExpression(String outputName) {
+ return new ExpressionFunction(outputName,
+ new ArrayList<>(inputs.keySet()),
+ owner().expressions().get(outputs.get(outputName)),
+ inputMap(),
+ Optional.empty());
+ }
@Override
public String toString() { return "signature '" + name + "'"; }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
index d25502fd149..b7138ad87e3 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
@@ -187,8 +187,7 @@ public abstract class ModelImporter {
if (operation.isInput()) {
// All inputs must have dimensions with standard naming convention: d0, d1, ...
OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
- model.argument(operation.vespaName(), standardNamingConvention.type());
- model.requiredFunction(operation.vespaName(), standardNamingConvention.type());
+ model.input(operation.vespaName(), standardNamingConvention.type());
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
index 917b0d6a389..e6bb5f40b3f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
@@ -2,7 +2,6 @@
package com.yahoo.searchlib.rankingexpression.integration.ml;
-import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter;
import onnx.Onnx;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
index 796c13a8669..94d663b4954 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
@@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.tensor.TensorType;
import java.util.Collection;
import java.util.Collections;
@@ -80,9 +82,14 @@ public class SerializationContext extends FunctionReferenceContext {
serializedFunctions.put(name, expressionString);
}
- /** Returns the existing serialization of a function, or null if none */
- public String getFunctionSerialization(String name) {
- return serializedFunctions.get(name);
+ /** Adds the serialization of the an argument type to a function */
+ public void addArgumentTypeSerialization(String functionName, String argumentName, TensorType type) {
+ serializedFunctions.put("rankingExpression(" + functionName + ")." + argumentName + ".type", type.toString());
+ }
+
+ /** Adds the serialization of the return type of a function */
+ public void addFunctionTypeSerialization(String functionName, TensorType type) {
+ serializedFunctions.put("rankingExpression(" + functionName + ").type", type.toString());
}
@Override
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java b/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java
index 118eba2cd96..969bc318391 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java
@@ -15,6 +15,7 @@ import org.junit.Test;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import static org.junit.Assert.fail;
@@ -40,7 +41,8 @@ public class GroupingSerializationTest {
t.assertMatch(new FloatResultNode(7.3));
t.assertMatch(new StringResultNode("7.3"));
t.assertMatch(new StringResultNode(
- new String(new byte[]{(byte)0xe5, (byte)0xa6, (byte)0x82, (byte)0xe6, (byte)0x9e, (byte)0x9c})));
+ new String(new byte[]{(byte)0xe5, (byte)0xa6, (byte)0x82, (byte)0xe6, (byte)0x9e, (byte)0x9c},
+ StandardCharsets.UTF_8)));
t.assertMatch(new RawResultNode(new byte[]{'7', '.', '4'}));
t.assertMatch(new IntegerBucketResultNode());
t.assertMatch(new FloatBucketResultNode());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
index bf9684082f4..593e7b54c10 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import org.junit.Test;
@@ -20,10 +21,11 @@ public class BatchNormImportTestCase {
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- RankingExpression output = signature.outputExpression("y");
+ ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName());
- model.assertEqualResult("X", output.getName());
+ assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getBody().getName());
+ model.assertEqualResult("X", output.getBody().getName());
+ assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
index a8f7542f3a4..59712c0152f 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -19,22 +20,23 @@ public class DropoutImportTestCase {
TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved");
// Check required functions
- assertEquals(1, model.get().requiredFunctions().size());
- assertTrue(model.get().requiredFunctions().containsKey("X"));
+ assertEquals(1, model.get().inputs().size());
+ assertTrue(model.get().inputs().containsKey("X"));
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.get().requiredFunctions().get("X"));
+ model.get().inputs().get("X"));
ImportedModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- RankingExpression output = signature.outputExpression("y");
+ ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("outputs/Maximum", output.getName());
+ assertEquals("outputs/Maximum", output.getBody().getName());
assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
- output.getRoot().toString());
- model.assertEqualResult("X", output.getName());
+ output.getBody().getRoot().toString());
+ model.assertEqualResult("X", output.getBody().getName());
+ assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
index add66eece1a..3d8d5d5a570 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import org.junit.Test;
@@ -20,11 +21,10 @@ public class MnistImportTestCase {
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- RankingExpression output = signature.outputExpression("y");
+ ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("dnn/outputs/add", output.getName());
- model.assertEqualResultSum("input", output.getName(), 0.00001);
+ assertEquals("dnn/outputs/add", output.getBody().getName());
+ model.assertEqualResultSum("input", output.getBody().getName(), 0.00001);
}
-
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
index e20ac16a691..b6e83404ab1 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
@@ -1,5 +1,6 @@
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
@@ -27,27 +28,28 @@ public class OnnxMnistSoftmaxImportTestCase {
Tensor constant0 = model.largeConstants().get("test_Variable");
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
- constant0.type());
+ constant0.type());
assertEquals(7840, constant0.size());
Tensor constant1 = model.largeConstants().get("test_Variable_1");
assertNotNull(constant1);
- assertEquals(new TensorType.Builder().indexed("d1", 10).build(),
- constant1.type());
+ assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type());
assertEquals(10, constant1.size());
- // Check required functions (inputs)
- assertEquals(1, model.requiredFunctions().size());
- assertTrue(model.requiredFunctions().containsKey("Placeholder"));
- assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.requiredFunctions().get("Placeholder"));
+ // Check inputs
+ assertEquals(1, model.inputs().size());
+ assertTrue(model.inputs().containsKey("Placeholder"));
+ assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder"));
- // Check outputs
- RankingExpression output = model.defaultSignature().outputExpression("add");
+ // Check signature
+ ExpressionFunction output = model.defaultSignature().outputExpression("add");
assertNotNull(output);
- assertEquals("add", output.getName());
+ assertEquals("add", output.getBody().getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
- output.getRoot().toString());
+ output.getBody().getRoot().toString());
+ assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"),
+ model.inputs().get(model.defaultSignature().inputs().get("Placeholder")));
+ assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString());
}
@Test
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
index ef28eb4678f..0a48ecfce21 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -38,10 +39,10 @@ public class TensorFlowMnistSoftmaxImportTestCase {
assertEquals(0, model.get().functions().size());
// Check required functions
- assertEquals(1, model.get().requiredFunctions().size());
- assertTrue(model.get().requiredFunctions().containsKey("Placeholder"));
+ assertEquals(1, model.get().inputs().size());
+ assertTrue(model.get().inputs().containsKey("Placeholder"));
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.get().requiredFunctions().get("Placeholder"));
+ model.get().inputs().get("Placeholder"));
// Check signatures
assertEquals(1, model.get().signatures().size());
@@ -56,11 +57,12 @@ public class TensorFlowMnistSoftmaxImportTestCase {
// ... signature outputs
assertEquals(1, signature.outputs().size());
- RankingExpression output = signature.outputExpression("y");
+ ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("add", output.getName());
+ assertEquals("add", output.getBody().getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))",
- output.getRoot().toString());
+ output.getBody().getRoot().toString());
+ assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString());
// Test execution
model.assertEqualResult("Placeholder", "MatMul");