summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java82
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/Search.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java24
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java40
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java12
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java135
-rw-r--r--container-core/src/main/java/com/yahoo/container/jdisc/ExtendedResponse.java9
-rw-r--r--container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java125
-rw-r--r--container-search/src/main/java/com/yahoo/search/Result.java2
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java124
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java16
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java4
-rw-r--r--container-search/src/main/java/com/yahoo/search/statistics/ElapsedTime.java22
-rw-r--r--container-search/src/main/java/com/yahoo/search/statistics/TimeTracker.java21
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java29
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java11
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java43
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java28
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java13
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java9
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java23
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java12
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java5
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java7
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java10
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java20
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java4
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java20
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java25
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java7
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java9
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java7
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java20
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java14
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java10
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java8
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java6
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/EndpointResult.java3
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ErrorCode.java9
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/OperationStatus.java9
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ServerResponseException.java6
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java3
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java2
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java2
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java21
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java32
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java39
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java13
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java25
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java46
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java13
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java8
79 files changed, 898 insertions, 460 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index bacff94d776..cd65c6ef761 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -2,15 +2,26 @@
package com.yahoo.searchdefinition;
import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.io.reader.NamedReader;
+import com.yahoo.processing.request.CompoundName;
+import com.yahoo.search.query.profile.QueryProfile;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.search.query.profile.config.QueryProfileXMLReader;
+import com.yahoo.search.query.profile.types.FieldDescription;
+import com.yahoo.search.query.profile.types.TensorFieldType;
import com.yahoo.search.query.ranking.Diversity;
+import com.yahoo.searchdefinition.document.SDField;
import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext;
import com.yahoo.searchdefinition.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.FeatureList;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TypeMapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.io.File;
import java.io.IOException;
@@ -276,12 +287,10 @@ public class RankProfile implements Serializable, Cloneable {
addConstant(name, value);
}
- /**
- * Returns an unmodifiable view of the constants to use in this.
- */
+ /** Returns an unmodifiable view of the constants available in this */
public Map<String, Value> getConstants() {
if (constants.isEmpty())
- return getInherited() != null ? getInherited().getConstants() : Collections.<String,Value>emptyMap();
+ return getInherited() != null ? getInherited().getConstants() : Collections.emptyMap();
if (getInherited() == null || getInherited().getConstants().isEmpty())
return Collections.unmodifiableMap(constants);
@@ -433,8 +442,9 @@ public class RankProfile implements Serializable, Cloneable {
properties.add(rankProperty);
}
+ @Override
public String toString() {
- return "rank profile " + getName();
+ return "rank profile '" + getName() + "'";
}
public int getRerankCount() {
@@ -508,6 +518,7 @@ public class RankProfile implements Serializable, Cloneable {
/**
* Returns the string form of the first phase ranking expression.
+ *
* @return string form of first phase ranking expression
*/
public String getFirstPhaseRankingString() {
@@ -697,7 +708,8 @@ public class RankProfile implements Serializable, Cloneable {
private void checkNameCollisions(Map<String, Macro> macros, Map<String, Value> constants) {
for (Map.Entry<String, Macro> macroEntry : macros.entrySet()) {
if (constants.get(macroEntry.getKey()) != null)
- throw new IllegalArgumentException("Cannot have both a constant and macro named '" + macroEntry.getKey() + "'");
+ throw new IllegalArgumentException("Cannot have both a constant and macro named '" +
+ macroEntry.getKey() + "'");
}
}
@@ -727,6 +739,64 @@ public class RankProfile implements Serializable, Cloneable {
}
/**
+ * Creates a context containing the type information of all constants, attributes and query profiles
+ * referable from this rank profile.
+ */
+ public TypeContext typeContext() {
+ TypeMapContext context = new TypeMapContext();
+
+ // Add constants
+ getConstants().forEach((k, v) -> context.setType(asConstantFeature(k), v.type()));
+
+ // Add attributes
+ for (SDField field : getSearch().allConcreteFields())
+ field.getAttributes().forEach((k, a) -> context.setType(asAttributeFeature(k), a.tensorType().orElse(TensorType.empty)));
+
+ // Add query features from rank profile types reached from the "default" profile
+ QueryProfile profile = queryProfilesOf(getSearch().sourceApplication()).getComponent("default");
+ if (profile != null && profile.getType() != null) {
+ profile.listTypes(CompoundName.empty, Collections.emptyMap()).forEach((prefix, queryProfileType) -> {
+ for (FieldDescription field : queryProfileType.declaredFields().values()) {
+ TensorType type = TensorType.empty; // assume the empty (aka double) type by default
+ if (field.getType() instanceof TensorFieldType)
+ type = ((TensorFieldType)field.getType()).type().get();
+
+ String feature = asQueryFeature(prefix.append(field.getName()).toString());
+ context.setType(feature, type);
+ }
+ });
+ }
+
+ return context;
+ }
+
+ private QueryProfileRegistry queryProfilesOf(ApplicationPackage applicationPackage) {
+ List<NamedReader> queryProfileFiles = null;
+ List<NamedReader> queryProfileTypeFiles = null;
+ try {
+ queryProfileFiles = applicationPackage.getQueryProfileFiles();
+ queryProfileTypeFiles = applicationPackage.getQueryProfileTypeFiles();
+ return new QueryProfileXMLReader().read(queryProfileFiles, queryProfileTypeFiles);
+ }
+ finally {
+ NamedReader.closeAll(queryProfileFiles);
+ NamedReader.closeAll(queryProfileTypeFiles);
+ }
+ }
+
+ private String asConstantFeature(String constantName) {
+ return "constant(\"" + constantName + "\")";
+ }
+
+ private String asAttributeFeature(String constantName) {
+ return "attribute(\"" + constantName + "\")";
+ }
+
+ private String asQueryFeature(String constantName) {
+ return "query(\"" + constantName + "\")";
+ }
+
+ /**
* A rank setting. The identity of a rank setting is its field name and type (not value).
* A rank setting is immutable.
*/
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
index df5697de0d5..f4a0365e36e 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
@@ -247,7 +247,7 @@ public class Search implements Serializable, ImmutableSearch {
/**
* Returns a list of all the fields of this search definition, that is all fields in all documents, in the documents
* they inherit, and all extra fields. The caller receives ownership to the list - subsequent changes to it will not
- * impact this Search
+ * impact this
*
* @return the list of fields in this searchdefinition
*/
@@ -546,7 +546,7 @@ public class Search implements Serializable, ImmutableSearch {
}
/**
- * Returns the first occurence of an attribute having this name, or null if none
+ * Returns the first occurrence of an attribute having this name, or null if none
*
* @param name Name of attribute
* @return The Attribute with given name.
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java
index ce3ff7cc447..72ba6de7022 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java
@@ -31,24 +31,19 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce
private Map<String, Attribute> attributes = new java.util.LinkedHashMap<>();
private Map<String, Attribute> importedAttributes = new java.util.LinkedHashMap<>();
- /**
- * Flag indicating if a position-attribute has been found
- */
+ /** Whether this has any position attribute */
private boolean hasPosition = false;
public AttributeFields(Search search) {
derive(search);
}
- /**
- * Derives everything from a field
- */
+ /** Derives everything from a field */
@Override
protected void derive(ImmutableSDField field, Search search) {
if (field.usesStructOrMap() &&
!field.getDataType().equals(PositionDataType.INSTANCE) &&
- !field.getDataType().equals(DataType.getArray(PositionDataType.INSTANCE)))
- {
+ !field.getDataType().equals(DataType.getArray(PositionDataType.INSTANCE))) {
return; // Ignore struct fields for indexed search (only implemented for streaming search)
}
if (field.isImportedField()) {
@@ -58,9 +53,7 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce
}
}
- /**
- * Return an attribute by name, or null if it doesn't exist
- */
+ /** Returns an attribute by name, or null if it doesn't exist */
public Attribute getAttribute(String attributeName) {
return attributes.get(attributeName);
}
@@ -69,9 +62,7 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce
return getAttribute(attributeName) != null;
}
- /**
- * Derives one attribute. TODO: Support non-default named attributes
- */
+ /** Derives one attribute. TODO: Support non-default named attributes */
private void deriveAttributes(ImmutableSDField field) {
for (Attribute fieldAttribute : field.getAttributes().values()) {
deriveAttribute(field, fieldAttribute);
@@ -107,9 +98,7 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce
}
}
- /**
- * Returns a read only attribute iterator
- */
+ /** Returns a read only attribute iterator */
public Iterator attributeIterator() {
return attributes().iterator();
}
@@ -201,4 +190,5 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce
}
}
}
+
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 01d3449573c..9c7fd7d9f0a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -41,7 +41,6 @@ import java.util.Optional;
*
* @author bratseth
*/
-// TODO: Verify types of macros
// TODO: Avoid name conflicts across models for constants
public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
@@ -84,6 +83,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
Signature signature = chooseSignature(model, store.arguments().signature());
String output = chooseOutput(signature, store.arguments().output());
RankingExpression expression = model.expressions().get(output);
+ verifyRequiredMacros(expression, model.requiredMacros(), profile);
store.writeConverted(expression);
model.constants().forEach((k, v) -> transformConstant(store, profile, k, v));
@@ -168,6 +168,44 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
/**
+ * Verify that the macros referred in the given expression exists in the given rank profile,
+ * and return tensors of the types specified in requiredMacros.
+ */
+ private void verifyRequiredMacros(RankingExpression expression, Map<String, TensorType> requiredMacros,
+ RankProfile profile) {
+ List<String> macroNames = new ArrayList<>();
+ addMacroNamesIn(expression.getRoot(), macroNames);
+ for (String macroName : macroNames) {
+ TensorType requiredType = requiredMacros.get(macroName);
+ if (requiredType == null) continue; // Not a required macro
+
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ if (macro == null)
+ throw new IllegalArgumentException("Model refers Placeholder '" + macroName +
+ "' of type " + requiredType + " but this macro is not present in " +
+ profile);
+ TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext());
+ if ( ! actualType.isAssignableTo(requiredType))
+ throw new IllegalArgumentException("Model refers Placeholder '" + macroName +
+ "' of type " + requiredType +
+ " which must be produced by a macro in the rank profile, but " +
+ "this macro produces type " + actualType);
+ }
+ }
+
+ private void addMacroNamesIn(ExpressionNode node, List<String> names) {
+ if (node instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode)node;
+ if (referenceNode.getOutput() == null) // macro references cannot specify outputs
+ names.add(referenceNode.getName());
+ }
+ else if (node instanceof CompositeNode) {
+ for (ExpressionNode child : ((CompositeNode)node).children())
+ addMacroNamesIn(child, names);
+ }
+ }
+
+ /**
* Provides read/write access to the correct directories of the application package given by the feature arguments
*/
private static class ModelStore {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java
index d08092003f1..b85cb88bf2e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java
@@ -5,6 +5,7 @@ import com.yahoo.io.reader.NamedReader;
import com.yahoo.search.query.profile.config.QueryProfileXMLReader;
import com.yahoo.config.application.api.ApplicationPackage;
+import java.util.Collections;
import java.util.List;
/**
@@ -13,17 +14,16 @@ import java.util.List;
*
* @author bratseth
*/
-// TODO: Move into QueryProfiles
public class QueryProfilesBuilder {
/** Build the set of query profiles for an application package */
public QueryProfiles build(ApplicationPackage applicationPackage) {
- List<NamedReader> queryProfileTypeFiles=null;
- List<NamedReader> queryProfileFiles=null;
+ List<NamedReader> queryProfileTypeFiles = null;
+ List<NamedReader> queryProfileFiles = null;
try {
- queryProfileTypeFiles=applicationPackage.getQueryProfileTypeFiles();
- queryProfileFiles=applicationPackage.getQueryProfileFiles();
- return new QueryProfiles(new QueryProfileXMLReader().read(queryProfileTypeFiles,queryProfileFiles));
+ queryProfileTypeFiles = applicationPackage.getQueryProfileTypeFiles();
+ queryProfileFiles = applicationPackage.getQueryProfileFiles();
+ return new QueryProfiles(new QueryProfileXMLReader().read(queryProfileTypeFiles, queryProfileFiles));
}
finally {
NamedReader.closeAll(queryProfileTypeFiles);
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 89f1a9f785c..82d0d66a82a 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -49,14 +49,8 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testMinimalTensorFlowReference() throws ParseException {
- StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
- RankProfileSearchFixture search = new RankProfileSearchFixture(
- application,
- " rank-profile my_profile {\n" +
- " first-phase {\n" +
- " expression: tensorflow('mnist_softmax/saved')" +
- " }\n" +
- " }");
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "tensorflow('mnist_softmax/saved')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
assertConstant("Variable_1", search, Optional.of(10L));
assertConstant("Variable", search, Optional.of(7840L));
@@ -64,14 +58,8 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testNestedTensorFlowReference() throws ParseException {
- StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
- RankProfileSearchFixture search = new RankProfileSearchFixture(
- application,
- " rank-profile my_profile {\n" +
- " first-phase {\n" +
- " expression: 5 + sum(tensorflow('mnist_softmax/saved'))" +
- " }\n" +
- " }");
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "5 + sum(tensorflow('mnist_softmax/saved'))");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
assertConstant("Variable_1", search, Optional.of(10L));
assertConstant("Variable", search, Optional.of(7840L));
@@ -79,39 +67,26 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testTensorFlowReferenceSpecifyingSignature() throws ParseException {
- StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
- RankProfileSearchFixture search = new RankProfileSearchFixture(
- application,
- " rank-profile my_profile {\n" +
- " first-phase {\n" +
- " expression: tensorflow('mnist_softmax/saved', 'serving_default')" +
- " }\n" +
- " }");
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "tensorflow('mnist_softmax/saved', 'serving_default')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@Test
public void testTensorFlowReferenceSpecifyingSignatureAndOutput() throws ParseException {
- StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
- RankProfileSearchFixture search = new RankProfileSearchFixture(
- application,
- " rank-profile my_profile {\n" +
- " first-phase {\n" +
- " expression: tensorflow('mnist_softmax/saved', 'serving_default', 'y')" +
- " }\n" +
- " }");
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "tensorflow('mnist_softmax/saved', 'serving_default', 'y')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@Test
- public void testTensorFlowReferenceSpecifyingNonExistingSignature() throws ParseException {
+ public void testTensorFlowReferenceMissingMacro() throws ParseException {
try {
- StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = new RankProfileSearchFixture(
- application,
+ new StoringApplicationPackage(applicationDir),
" rank-profile my_profile {\n" +
" first-phase {\n" +
- " expression: tensorflow('mnist_softmax/saved', 'serving_defaultz')" +
+ " expression: tensorflow('mnist_softmax/saved')" +
" }\n" +
" }");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
@@ -119,6 +94,40 @@ public class RankingExpressionWithTensorFlowTestCase {
}
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
+ "tensorflow('mnist_softmax/saved'): " +
+ "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
+ "not present in rank profile 'my_profile'",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void testTensorFlowReferenceWithWrongMacroType() throws ParseException {
+ try {
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d5[10])(0.0)",
+ "tensorflow('mnist_softmax/saved')");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ fail("Expecting exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
+ "tensorflow('mnist_softmax/saved'): " +
+ "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) which must be produced " +
+ "by a macro in the rank profile, but this macro produces type tensor(d0[2],d5[10])",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void testTensorFlowReferenceSpecifyingNonExistingSignature() throws ParseException {
+ try {
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "tensorflow('mnist_softmax/saved', 'serving_defaultz')");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ fail("Expecting exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved','serving_defaultz'): " +
"Model does not have the specified signature 'serving_defaultz'",
Exceptions.toMessageString(expected));
@@ -128,14 +137,8 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testTensorFlowReferenceSpecifyingNonExistingOutput() throws ParseException {
try {
- StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
- RankProfileSearchFixture search = new RankProfileSearchFixture(
- application,
- " rank-profile my_profile {\n" +
- " first-phase {\n" +
- " expression: tensorflow('mnist_softmax/saved', 'serving_default', 'x')" +
- " }\n" +
- " }");
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "tensorflow('mnist_softmax/saved', 'serving_default', 'x')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
}
@@ -149,14 +152,8 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testImportingFromStoredExpressions() throws ParseException, IOException {
- StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
- RankProfileSearchFixture search = new RankProfileSearchFixture(
- application,
- " rank-profile my_profile {\n" +
- " first-phase {\n" +
- " expression: tensorflow('mnist_softmax/saved', 'serving_default')" +
- " }\n" +
- " }");
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "tensorflow('mnist_softmax/saved')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
assertConstant("Variable_1", search, Optional.of(10L));
assertConstant("Variable", search, Optional.of(7840L));
@@ -168,13 +165,9 @@ public class RankingExpressionWithTensorFlowTestCase {
IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
- RankProfileSearchFixture searchFromStored = new RankProfileSearchFixture(
- storedApplication,
- " rank-profile my_profile {\n" +
- " first-phase {\n" +
- " expression: tensorflow('mnist_softmax/saved', 'serving_default')" +
- " }\n" +
- " }");
+ RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "tensorflow('mnist_softmax/saved')",
+ storedApplication);
searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile");
// Verify that the constants exists, but don't verify the content as we are not
// simulating file distribution in this test
@@ -212,6 +205,30 @@ public class RankingExpressionWithTensorFlowTestCase {
}
}
+ private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) {
+ return fixtureWith(placeholderExpression, firstPhaseExpression, new StoringApplicationPackage(applicationDir));
+ }
+
+ private RankProfileSearchFixture fixtureWith(String placeholderExpression,
+ String firstPhaseExpression,
+ StoringApplicationPackage application) {
+ try {
+ return new RankProfileSearchFixture(
+ application,
+ " rank-profile my_profile {\n" +
+ " macro Placeholder() {\n" +
+ " expression: " + placeholderExpression +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: " + firstPhaseExpression +
+ " }\n" +
+ " }");
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
private static class StoringApplicationPackage extends MockApplicationPackage {
private final File root;
diff --git a/container-core/src/main/java/com/yahoo/container/jdisc/ExtendedResponse.java b/container-core/src/main/java/com/yahoo/container/jdisc/ExtendedResponse.java
index 9b6837a9dcb..b2d32c8e745 100644
--- a/container-core/src/main/java/com/yahoo/container/jdisc/ExtendedResponse.java
+++ b/container-core/src/main/java/com/yahoo/container/jdisc/ExtendedResponse.java
@@ -1,18 +1,17 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.container.jdisc;
-import java.io.IOException;
-import java.io.OutputStream;
-
import com.yahoo.container.handler.Coverage;
import com.yahoo.container.handler.Timing;
import com.yahoo.container.logging.HitCounts;
import com.yahoo.jdisc.handler.CompletionHandler;
import com.yahoo.jdisc.handler.ContentChannel;
+import java.io.IOException;
+import java.io.OutputStream;
+
/**
- * An HTTP response supporting async rendering and extended information for
- * logging.
+ * An HTTP response supporting async rendering and extended information for logging.
*
* @author Steinar Knutsen
*/
diff --git a/container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java b/container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java
index 933751fd9ad..1c2d28c754f 100644
--- a/container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java
+++ b/container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java
@@ -84,16 +84,18 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler {
}
@Override
- protected LoggingCompletionHandler createLoggingCompletionHandler(
- long startTime, long renderStartTime, HttpResponse response,
- HttpRequest httpRequest, ContentChannelOutputStream rendererWiring) {
+ protected LoggingCompletionHandler createLoggingCompletionHandler(long startTime,
+ long renderStartTime,
+ HttpResponse response,
+ HttpRequest httpRequest,
+ ContentChannelOutputStream rendererWiring) {
return new LoggingHandler(startTime, renderStartTime, httpRequest, response, rendererWiring);
}
private static String getClientIP(com.yahoo.jdisc.http.HttpRequest httpRequest) {
SocketAddress clientAddress = httpRequest.getRemoteAddress();
- if (clientAddress == null)
- return "0.0.0.0";
+ if (clientAddress == null) return "0.0.0.0";
+
return clientAddress.toString();
}
@@ -105,24 +107,21 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler {
}
}
- private static String remoteHostAddress(
- com.yahoo.jdisc.http.HttpRequest httpRequest) {
+ private static String remoteHostAddress(com.yahoo.jdisc.http.HttpRequest httpRequest) {
SocketAddress remoteAddress = httpRequest.getRemoteAddress();
- if (remoteAddress == null)
- return "0.0.0.0";
+ if (remoteAddress == null) return "0.0.0.0";
+
if (remoteAddress instanceof InetSocketAddress) {
- return ((InetSocketAddress) remoteAddress).getAddress()
- .getHostAddress();
+ return ((InetSocketAddress) remoteAddress).getAddress().getHostAddress();
} else {
- throw new RuntimeException(
- "Expected remote address of type InetSocketAddress, got "
- + remoteAddress.getClass().getName());
+ throw new RuntimeException("Expected remote address of type InetSocketAddress, got " +
+ remoteAddress.getClass().getName());
}
}
private void logTimes(long startTime, String sourceIP,
- long renderStartTime, long commitStartTime, long endTime,
- String req, String normalizedQuery, Timing t) {
+ long renderStartTime, long commitStartTime, long endTime,
+ String req, String normalizedQuery, Timing t) {
// note: intentionally only taking time since request was received
long totalTime = endTime - startTime;
@@ -140,33 +139,33 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler {
return;
}
- StringBuilder msgbuf = new StringBuilder();
- msgbuf.append(normalizedQuery);
- msgbuf.append(" from ").append(sourceIP).append(". ");
+ StringBuilder b = new StringBuilder();
+ b.append(normalizedQuery);
+ b.append(" from ").append(sourceIP).append(". ");
if (requestOverhead > 0) {
- msgbuf.append("Time from HTTP connection open to request reception ");
- msgbuf.append(requestOverhead).append(" ms. ");
+ b.append("Time from HTTP connection open to request reception ");
+ b.append(requestOverhead).append(" ms. ");
}
if (summaryStartTime != 0) {
- msgbuf.append("Request time: ");
- msgbuf.append(summaryStartTime - startTime).append(" ms. ");
- msgbuf.append("Summary fetch time: ");
- msgbuf.append(renderStartTime - summaryStartTime).append(" ms. ");
+ b.append("Request time: ");
+ b.append(summaryStartTime - startTime).append(" ms. ");
+ b.append("Summary fetch time: ");
+ b.append(renderStartTime - summaryStartTime).append(" ms. ");
} else {
long spentSearching = renderStartTime - startTime;
- msgbuf.append("Processing time: ").append(spentSearching).append(" ms. ");
+ b.append("Processing time: ").append(spentSearching).append(" ms. ");
}
- msgbuf.append("Result rendering/transfer: ");
- msgbuf.append(commitStartTime - renderStartTime).append(" ms. ");
- msgbuf.append("End transaction: ");
- msgbuf.append(endTime - commitStartTime).append(" ms. ");
- msgbuf.append("Total: ").append(totalTime).append(" ms. ");
- msgbuf.append("Timeout: ").append(timeoutInterval).append(" ms. ");
- msgbuf.append("Request string: ").append(req);
+ b.append("Result rendering/transfer: ");
+ b.append(commitStartTime - renderStartTime).append(" ms. ");
+ b.append("End transaction: ");
+ b.append(endTime - commitStartTime).append(" ms. ");
+ b.append("Total: ").append(totalTime).append(" ms. ");
+ b.append("Timeout: ").append(timeoutInterval).append(" ms. ");
+ b.append("Request string: ").append(req);
- log.log(LogLevel.WARNING, "Slow execution. " + msgbuf);
+ log.log(LogLevel.WARNING, "Slow execution. " + b);
}
private static class NullResponse extends ExtendedResponse {
@@ -175,10 +174,11 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler {
}
@Override
- public void render(OutputStream output, ContentChannel networkChannel,
- CompletionHandler handler) throws IOException {
+ public void render(OutputStream output, ContentChannel networkChannel, CompletionHandler handler)
+ throws IOException {
// NOP
}
+
}
private class LoggingHandler implements LoggingCompletionHandler {
@@ -191,9 +191,8 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler {
private final ContentChannelOutputStream rendererWiring;
private final ExtendedResponse extendedResponse;
- LoggingHandler(long startTime, long renderStartTime,
- HttpRequest httpRequest, HttpResponse httpResponse,
- ContentChannelOutputStream rendererWiring) {
+ LoggingHandler(long startTime, long renderStartTime, HttpRequest httpRequest, HttpResponse httpResponse,
+ ContentChannelOutputStream rendererWiring) {
this.startTime = startTime;
this.renderStartTime = renderStartTime;
this.commitStartTime = renderStartTime;
@@ -233,20 +232,19 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler {
}
private void writeToLogs(long endTime) {
- final com.yahoo.jdisc.http.HttpRequest jdiscRequest = httpRequest.getJDiscRequest();
+ com.yahoo.jdisc.http.HttpRequest jdiscRequest = httpRequest.getJDiscRequest();
- logTimes(
- startTime,
- getClientIP(jdiscRequest),
- renderStartTime,
- commitStartTime,
- endTime,
- jdiscRequest.getUri().toString(),
- extendedResponse.getParsedQuery(),
- extendedResponse.getTiming());
+ logTimes(startTime,
+ getClientIP(jdiscRequest),
+ renderStartTime,
+ commitStartTime,
+ endTime,
+ jdiscRequest.getUri().toString(),
+ extendedResponse.getParsedQuery(),
+ extendedResponse.getTiming());
- final Optional<AccessLogEntry> jdiscRequestAccessLogEntry
- = AccessLoggingRequestHandler.getAccessLogEntry(jdiscRequest);
+ Optional<AccessLogEntry> jdiscRequestAccessLogEntry =
+ AccessLoggingRequestHandler.getAccessLogEntry(jdiscRequest);
if (jdiscRequestAccessLogEntry.isPresent()) {
// This means we are running with Jetty, not Netty.
@@ -275,18 +273,17 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler {
}
}
- private void populateAccessLogEntryNotCreatedByHttpServer(
- final AccessLogEntry logEntry,
- final com.yahoo.jdisc.http.HttpRequest httpRequest,
- final Timing timing,
- final String fullRequest,
- final long commitStartTime,
- final long startTime,
- final long written,
- final int status) {
+ private void populateAccessLogEntryNotCreatedByHttpServer(AccessLogEntry logEntry,
+ com.yahoo.jdisc.http.HttpRequest httpRequest,
+ Timing timing,
+ String fullRequest,
+ long commitStartTime,
+ long startTime,
+ long written,
+ int status) {
try {
- final InetSocketAddress remoteAddress = AccessLogUtil.getRemoteAddress(httpRequest);
- final long evalStartTime = getEvalStart(timing, startTime);
+ InetSocketAddress remoteAddress = AccessLogUtil.getRemoteAddress(httpRequest);
+ long evalStartTime = getEvalStart(timing, startTime);
logEntry.setIpV4Address(remoteHostAddress(httpRequest));
logEntry.setTimeStamp(evalStartTime);
logEntry.setDurationBetweenRequestResponse(commitStartTime - evalStartTime);
@@ -302,8 +299,8 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler {
logEntry.setHttpMethod(AccessLogUtil.getHttpMethod(httpRequest));
logEntry.setHttpVersion(AccessLogUtil.getHttpVersion(httpRequest));
} catch (Exception e) {
- log.log(LogLevel.WARNING, "Could not populate the access log ["
- + fullRequest + "]", e);
+ log.log(LogLevel.WARNING, "Could not populate the access log [" + fullRequest + "]", e);
}
}
+
}
diff --git a/container-search/src/main/java/com/yahoo/search/Result.java b/container-search/src/main/java/com/yahoo/search/Result.java
index ded8992fa65..7978798f53c 100644
--- a/container-search/src/main/java/com/yahoo/search/Result.java
+++ b/container-search/src/main/java/com/yahoo/search/Result.java
@@ -46,7 +46,7 @@ public final class Result extends com.yahoo.processing.Response implements Clone
* Headers containing "envelope" meta information to be returned with this result.
* Used for HTTP getHeaders when the return protocol is HTTP.
*/
- private ListMap<String,String> headers=null;
+ private ListMap<String,String> headers = null;
/**
* Result rendering infrastructure.
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java
index 04dd3ee9005..a347fbfb3ab 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java
@@ -43,28 +43,28 @@ import java.util.regex.Pattern;
public class QueryProfile extends FreezableSimpleComponent implements Cloneable {
/** Defines the permissible content of this, or null if any content is permissible */
- private QueryProfileType type=null;
+ private QueryProfileType type = null;
/** The value at this query profile - allows non-fields to have values, e.g a=value1, a.b=value2 */
- private Object value=null;
+ private Object value = null;
/** The variants of this, or null if none */
- private QueryProfileVariants variants=null;
+ private QueryProfileVariants variants = null;
/** The resolved variant dimensions of this, or null if none or not resolved yet (is resolved at freeze) */
- private List<String> resolvedDimensions=null;
+ private List<String> resolvedDimensions = null;
/** The query profiles inherited by this, or null if none */
- private List<QueryProfile> inherited=null;
+ private List<QueryProfile> inherited = null;
/** The content of this profile. The values may be primitives, substitutable strings or other query profiles */
- private CopyOnWriteContent content=new CopyOnWriteContent();
+ private CopyOnWriteContent content = new CopyOnWriteContent();
/**
* Field override settings: fieldName→OverrideValue. These overrides the override
* setting in the type (if any) of this field). If there are no query profile level settings, this is null.
*/
- private Map<String,Boolean> overridable=null;
+ private Map<String,Boolean> overridable = null;
/**
* Creates a new query profile from an id.
@@ -108,26 +108,26 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
/** Adds a profile to the end of the inherited list of this. Throws an exception if this is frozen. */
public void addInherited(QueryProfile profile) {
- addInherited(profile,(DimensionValues)null);
+ addInherited(profile, (DimensionValues)null);
}
public final void addInherited(QueryProfile profile,String[] dimensionValues) {
- addInherited(profile,DimensionValues.createFrom(dimensionValues));
+ addInherited(profile, DimensionValues.createFrom(dimensionValues));
}
/** Adds a profile to the end of the inherited list of this for the given variant. Throws an exception if this is frozen. */
public void addInherited(QueryProfile profile, DimensionValues dimensionValues) {
ensureNotFrozen();
- DimensionBinding dimensionBinding=DimensionBinding.createFrom(getDimensions(),dimensionValues);
+ DimensionBinding dimensionBinding=DimensionBinding.createFrom(getDimensions(), dimensionValues);
if (dimensionBinding.isNull()) {
- if (inherited==null)
- inherited=new ArrayList<>();
+ if (inherited == null)
+ inherited = new ArrayList<>();
inherited.add(profile);
}
else {
- if (variants==null)
- variants=new QueryProfileVariants(dimensionBinding.getDimensions(), this);
+ if (variants == null)
+ variants = new QueryProfileVariants(dimensionBinding.getDimensions(), this);
variants.inherit(profile,dimensionBinding.getValues());
}
}
@@ -152,13 +152,13 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
* @throws IllegalStateException if this is frozen
*/
public Boolean isDeclaredOverridable(String name, Map<String,String> context) {
- return isDeclaredOverridable(new CompoundName(name),DimensionBinding.createFrom(getDimensions(),context));
+ return isDeclaredOverridable(new CompoundName(name), DimensionBinding.createFrom(getDimensions(), context));
}
/** Sets the dimensions over which this may vary. Note: This will erase any currently defined variants */
public void setDimensions(String[] dimensions) {
ensureNotFrozen();
- variants=new QueryProfileVariants(dimensions, this);
+ variants = new QueryProfileVariants(dimensions, this);
}
/** Returns the value set at this node, to allow non-leafs to have values. Returns null if none. */
@@ -166,17 +166,17 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
public void setValue(Object value) {
ensureNotFrozen();
- this.value=value;
+ this.value = value;
}
/** Returns the variant dimensions to be used in this - an unmodifiable list of dimension names */
public List<String> getDimensions() {
if (isFrozen()) return resolvedDimensions;
- if (variants!=null) return variants.getDimensions();
- if (inherited==null) return null;
+ if (variants != null) return variants.getDimensions();
+ if (inherited == null) return null;
for (QueryProfile inheritedProfile : inherited) {
- List<String> inheritedDimensions=inheritedProfile.getDimensions();
- if (inheritedDimensions!=null) return inheritedDimensions;
+ List<String> inheritedDimensions = inheritedProfile.getDimensions();
+ if (inheritedDimensions != null) return inheritedDimensions;
}
return null;
}
@@ -187,8 +187,8 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
* Sets the overridability of a field in this profile,
* this overrides the corresponding setting in the type (if any)
*/
- public final void setOverridable(String fieldName, boolean overridable, Map<String,String> context) {
- setOverridable(new CompoundName(fieldName), overridable,DimensionBinding.createFrom(getDimensions(), context));
+ public final void setOverridable(String fieldName, boolean overridable, Map<String, String> context) {
+ setOverridable(new CompoundName(fieldName), overridable, DimensionBinding.createFrom(getDimensions(), context));
}
/**
@@ -241,8 +241,8 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
Map<String,Object> values=visitor.getResult();
if (substitution==null) return values;
- for (Map.Entry<String,Object> entry : values.entrySet()) {
- if (entry.getValue().getClass()==String.class) continue; // Shortcut
+ for (Map.Entry<String, Object> entry : values.entrySet()) {
+ if (entry.getValue().getClass() == String.class) continue; // Shortcut
if (entry.getValue() instanceof SubstituteString)
entry.setValue(((SubstituteString)entry.getValue()).substitute(context,substitution));
}
@@ -253,7 +253,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
* Lists types reachable from this, indexed by the prefix having that type.
* If this is itself typed, this' type will be included with an empty prefix
*/
- Map<CompoundName, QueryProfileType> listTypes(CompoundName prefix, Map<String, String> context) {
+ public Map<CompoundName, QueryProfileType> listTypes(CompoundName prefix, Map<String, String> context) {
DimensionBinding dimensionBinding = DimensionBinding.createFrom(getDimensions(), context);
AllTypesQueryProfileVisitor visitor = new AllTypesQueryProfileVisitor(prefix);
accept(visitor, dimensionBinding, null);
@@ -264,8 +264,8 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
* Lists references reachable from this.
*/
Set<CompoundName> listReferences(CompoundName prefix, Map<String, String> context) {
- DimensionBinding dimensionBinding=DimensionBinding.createFrom(getDimensions(),context);
- AllReferencesQueryProfileVisitor visitor=new AllReferencesQueryProfileVisitor(prefix);
+ DimensionBinding dimensionBinding = DimensionBinding.createFrom(getDimensions(),context);
+ AllReferencesQueryProfileVisitor visitor = new AllReferencesQueryProfileVisitor(prefix);
accept(visitor,dimensionBinding,null);
return visitor.getResult();
}
@@ -292,40 +292,40 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
public final Object get(String name) { return get(name,(Map<String,String>)null); }
/** Returns a value from this using the given property context for resolution and using this for substitution */
- public final Object get(String name, Map<String,String> context) {
- return get(name,context,null);
+ public final Object get(String name, Map<String, String> context) {
+ return get(name, context, null);
}
/** Returns a value from this using the given dimensions for resolution */
public final Object get(String name, String[] dimensionValues) {
- return get(name,dimensionValues,null);
+ return get(name, dimensionValues,null);
}
public final Object get(String name, String[] dimensionValues, Properties substitution) {
- return get(name,DimensionValues.createFrom(dimensionValues),substitution);
+ return get(name, DimensionValues.createFrom(dimensionValues), substitution);
}
/** Returns a value from this using the given dimensions for resolution */
public final Object get(String name, DimensionValues dimensionValues, Properties substitution) {
- return get(name,DimensionBinding.createFrom(getDimensions(),dimensionValues),substitution);
+ return get(name, DimensionBinding.createFrom(getDimensions(), dimensionValues), substitution);
}
public final Object get(String name, Map<String,String> context, Properties substitution) {
- return get(name,DimensionBinding.createFrom(getDimensions(),context),substitution);
+ return get(name, DimensionBinding.createFrom(getDimensions(), context), substitution);
}
public final Object get(CompoundName name, Map<String,String> context, Properties substitution) {
- return get(name,DimensionBinding.createFrom(getDimensions(),context),substitution);
+ return get(name, DimensionBinding.createFrom(getDimensions(), context), substitution);
}
final Object get(String name, DimensionBinding binding,Properties substitution) {
- return get(new CompoundName(name),binding,substitution);
+ return get(new CompoundName(name), binding, substitution);
}
final Object get(CompoundName name, DimensionBinding binding, Properties substitution) {
- Object node=get(name,binding);
- if (node!=null && node.getClass()==String.class) return node; // Shortcut
- if (node instanceof SubstituteString) return ((SubstituteString)node).substitute(binding.getContext(),substitution);
+ Object node = get(name, binding);
+ if (node != null && node.getClass() == String.class) return node; // Shortcut
+ if (node instanceof SubstituteString) return ((SubstituteString)node).substitute(binding.getContext(), substitution);
return node;
}
@@ -431,14 +431,14 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
@Override
public QueryProfile clone() {
if (isFrozen()) return this;
- QueryProfile clone=(QueryProfile)super.clone();
- if (variants !=null)
+ QueryProfile clone = (QueryProfile)super.clone();
+ if (variants != null)
clone.variants = variants.clone();
- if (inherited!=null)
- clone.inherited=new ArrayList<>(inherited);
+ if (inherited != null)
+ clone.inherited = new ArrayList<>(inherited);
- if (this.content!=null)
- clone.content=content.clone();
+ if (this.content != null)
+ clone.content = content.clone();
return clone;
}
@@ -454,7 +454,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
/** Throws IllegalArgumentException if the given string is not a valid query profile name */
public static void validateName(String name) {
- Matcher nameMatcher=namePattern.matcher(name);
+ Matcher nameMatcher = namePattern.matcher(name);
if ( ! nameMatcher.matches())
throw new IllegalArgumentException("Illegal name '" + name + "'");
}
@@ -467,7 +467,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
setNode(name, value, null, binding, registry);
}
catch (IllegalArgumentException e) {
- throw new IllegalArgumentException("Could not set '" + name + "' to '" + value + "'",e);
+ throw new IllegalArgumentException("Could not set '" + name + "' to '" + value + "'", e);
}
}
@@ -708,19 +708,19 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
/** Do a variant-aware content lookup in this */
protected Object localLookup(String name, DimensionBinding dimensionBinding) {
- Object node=null;
- if ( variants!=null && !dimensionBinding.isNull())
- node=variants.get(name,type,true,dimensionBinding);
- if (node==null)
- node=content==null ? null : content.get(name);
+ Object node = null;
+ if ( variants != null && !dimensionBinding.isNull())
+ node = variants.get(name,type,true,dimensionBinding);
+ if (node == null)
+ node = content == null ? null : content.get(name);
return node;
}
// ----------------- Private ----------------------------------------------------------------------------------
private Boolean isDeclaredOverridable(CompoundName name,DimensionBinding dimensionBinding) {
- QueryProfile parent= lookupParentExact(name, true, dimensionBinding);
- if (parent.overridable==null) return null;
+ QueryProfile parent = lookupParentExact(name, true, dimensionBinding);
+ if (parent.overridable == null) return null;
return parent.overridable.get(name.last());
}
@@ -729,10 +729,10 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
* this overrides the corresponding setting in the type (if any)
*/
private void setOverridable(CompoundName fieldName,boolean overridable,DimensionBinding dimensionBinding) {
- QueryProfile parent= lookupParentExact(fieldName, true, dimensionBinding);
- if (parent.overridable==null)
- parent.overridable=new HashMap<>();
- parent.overridable.put(fieldName.last(),overridable);
+ QueryProfile parent = lookupParentExact(fieldName, true, dimensionBinding);
+ if (parent.overridable == null)
+ parent.overridable = new HashMap<>();
+ parent.overridable.put(fieldName.last(), overridable);
}
/** Sets a value to a (possibly non-local) node. The parent query profile holding the value is returned */
@@ -740,7 +740,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
DimensionBinding dimensionBinding, QueryProfileRegistry registry) {
ensureNotFrozen();
if (name.isCompound()) {
- QueryProfile parent= getQueryProfileExact(name.first(), true, dimensionBinding);
+ QueryProfile parent = getQueryProfileExact(name.first(), true, dimensionBinding);
parent.setNode(name.rest(), value,parentType, dimensionBinding.createFor(parent.getDimensions()), registry);
}
else {
@@ -773,8 +773,8 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
* @return the created profile, or null if not present, and create is false
*/
private QueryProfile getQueryProfileExact(String localName, boolean create, DimensionBinding dimensionBinding) {
- Object node=localExactLookup(localName, dimensionBinding);
- if (node!=null && node instanceof QueryProfile) {
+ Object node = localExactLookup(localName, dimensionBinding);
+ if (node != null && node instanceof QueryProfile) {
return (QueryProfile)node;
}
if (!create) return null;
@@ -826,7 +826,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
}
}
- private static final Pattern namePattern=Pattern.compile("[$a-zA-Z_/][-$a-zA-Z0-9_/()]*");
+ private static final Pattern namePattern = Pattern.compile("[$a-zA-Z_/][-$a-zA-Z0-9_/()]*");
/**
* Returns a compiled version of this which produces faster lookup times
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java b/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java
index 46aa33174f7..095d83d2887 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java
@@ -39,8 +39,8 @@ public class QueryProfileXMLReader {
* @throws RuntimeException if <code>directory</code> is not a readable directory, or if there is some error in the XML
*/
public QueryProfileRegistry read(String directory) {
- List<NamedReader> queryProfileReaders=new ArrayList<>();
- List<NamedReader> queryProfileTypeReaders=new ArrayList<>();
+ List<NamedReader> queryProfileReaders = new ArrayList<>();
+ List<NamedReader> queryProfileTypeReaders = new ArrayList<>();
try {
File dir=new File(directory);
if ( !dir.isDirectory() ) throw new IllegalArgumentException("Could not read query profiles: '" +
@@ -86,16 +86,16 @@ public class QueryProfileXMLReader {
* Read the XML file readers into a registry. This does not close the readers.
* This method is used directly from the admin system.
*/
- public QueryProfileRegistry read(List<NamedReader> queryProfileTypeReaders,List<NamedReader> queryProfileReaders) {
- QueryProfileRegistry registry=new QueryProfileRegistry();
+ public QueryProfileRegistry read(List<NamedReader> queryProfileTypeReaders, List<NamedReader> queryProfileReaders) {
+ QueryProfileRegistry registry = new QueryProfileRegistry();
// Phase 1
- List<Element> queryProfileTypeElements=createQueryProfileTypes(queryProfileTypeReaders,registry.getTypeRegistry());
- List<Element> queryProfileElements=createQueryProfiles(queryProfileReaders,registry);
+ List<Element> queryProfileTypeElements = createQueryProfileTypes(queryProfileTypeReaders, registry.getTypeRegistry());
+ List<Element> queryProfileElements = createQueryProfiles(queryProfileReaders, registry);
// Phase 2
- fillQueryProfileTypes(queryProfileTypeElements,registry.getTypeRegistry());
- fillQueryProfiles(queryProfileElements,registry);
+ fillQueryProfileTypes(queryProfileTypeElements, registry.getTypeRegistry());
+ fillQueryProfiles(queryProfileElements, registry);
return registry;
}
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java
index 305de2a3c70..c826d834d47 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java
@@ -102,13 +102,14 @@ public class QueryProfileType extends FreezableSimpleComponent {
*
* @throws IllegalStateException if this is frozen
*/
- public Map<String,FieldDescription> declaredFields() {
+ public Map<String, FieldDescription> declaredFields() {
ensureNotFrozen();
return Collections.unmodifiableMap(fields);
}
/**
* Returns true if <i>this</i> is declared strict.
+ *
* @throws IllegalStateException if this is frozen
*/
public boolean isDeclaredStrict() {
@@ -118,6 +119,7 @@ public class QueryProfileType extends FreezableSimpleComponent {
/**
* Returns true if <i>this</i> is declared as match as path.
+ *
* @throws IllegalStateException if this is frozen
*/
public boolean getDeclaredMatchAsPath() {
diff --git a/container-search/src/main/java/com/yahoo/search/statistics/ElapsedTime.java b/container-search/src/main/java/com/yahoo/search/statistics/ElapsedTime.java
index 82438bda34e..804ce1b496b 100644
--- a/container-search/src/main/java/com/yahoo/search/statistics/ElapsedTime.java
+++ b/container-search/src/main/java/com/yahoo/search/statistics/ElapsedTime.java
@@ -12,13 +12,11 @@ import java.util.Set;
import static com.yahoo.search.statistics.TimeTracker.Activity.*;
/**
- * Basically a collection of TimeTracker instances.
+ * A collection of TimeTracker instances.
*
- * <p>This class may need a lot of restructuring as actual
- * needs are mapped out.
- *
- * @author <a href="steinar@yahoo-inc.com">Steinar Knutsen</a>
+ * @author Steinar Knutsen
*/
+// This class may need a lot of restructuring as actual needs are mapped out.
public class ElapsedTime {
// An identity set is used to make it safe to do multiple merges. This may happen if
@@ -35,16 +33,11 @@ public class ElapsedTime {
private long fetcher(Activity toFetch, TimeTracker fetchFrom) {
switch (toFetch) {
- case SEARCH:
- return fetchFrom.searchTime();
- case FILL:
- return fetchFrom.fillTime();
- case PING:
- return fetchFrom.pingTime();
- default:
- return 0L;
+ case SEARCH: return fetchFrom.searchTime();
+ case FILL: return fetchFrom.fillTime();
+ case PING: return fetchFrom.pingTime();
+ default: return 0L;
}
-
}
/**
@@ -232,4 +225,5 @@ public class ElapsedTime {
report.append(".");
return report.toString();
}
+
}
diff --git a/container-search/src/main/java/com/yahoo/search/statistics/TimeTracker.java b/container-search/src/main/java/com/yahoo/search/statistics/TimeTracker.java
index 6112bc504d3..d2461dffc7a 100644
--- a/container-search/src/main/java/com/yahoo/search/statistics/TimeTracker.java
+++ b/container-search/src/main/java/com/yahoo/search/statistics/TimeTracker.java
@@ -1,25 +1,24 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.statistics;
-import java.util.ArrayList;
-import java.util.EnumMap;
-import java.util.List;
-import java.util.Map;
-
import com.yahoo.component.chain.Chain;
import com.yahoo.prelude.Pong;
import com.yahoo.processing.Processor;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
+import java.util.ArrayList;
+import java.util.EnumMap;
+import java.util.List;
+import java.util.Map;
+
/**
- * A container for storing time stamps throughout the
- * lifetime of an Execution instance.
+ * A container for storing time stamps throughout the lifetime of an Execution instance.
*
* <p>Check state both when entering and exiting, to allow for arbitrary
* new queries anywhere inside a search chain.
*
- * @author <a href="mailto:steinar@yahoo-inc.com">Steinar Knutsen</a>
+ * @author Steinar Knutsen
*/
public final class TimeTracker {
@@ -214,8 +213,6 @@ public final class TimeTracker {
}
concludeState(now);
initNewState(now, activity);
- } else {
- return;
}
}
@@ -314,7 +311,7 @@ public final class TimeTracker {
return typedSum(Activity.PING);
}
- private long returnfromState(int searcherIndex, boolean detailed) {
+ private long returnFromState(int searcherIndex, boolean detailed) {
if (detailed) {
return detailedMeasurements(searcherIndex, false);
} else {
@@ -350,7 +347,7 @@ public final class TimeTracker {
}
private void sampleReturn(int searcherIndex, boolean detailed, ElapsedTime elapsed) {
- long now = returnfromState(searcherIndex, detailed);
+ long now = returnFromState(searcherIndex, detailed);
if (searcherIndex == entryIndex) {
concludeStateOnExit(now);
if (elapsed != null) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
index 23dd841b0ef..5f8daa69ecf 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.tensor.TensorType;
import java.util.Arrays;
@@ -48,12 +49,11 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
*
* @throws IllegalArgumentException if the name is not present in the ranking expression this was created with, and
* ignoredUnknownValues is false
- * @since 5.1.5
*/
@Override
public final void put(String name, Value value) {
Integer index = nameToIndex().get(name);
- if (index==null) {
+ if (index == null) {
if (ignoreUnknownValues())
return;
else
@@ -70,24 +70,29 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
/**
* Puts a value by index.
* The value will be frozen if it isn't already.
- *
- * @since 5.1.5
*/
public final void put(int index, Value value) {
- values[index]=value.freeze();
+ values[index] = value.freeze();
try {
- doubleValues()[index]=value.asDouble();
+ doubleValues()[index] = value.asDouble();
}
catch (UnsupportedOperationException e) {
- doubleValues()[index]=Double.NaN; // see getDouble below
+ doubleValues()[index] = Double.NaN; // see getDouble below
}
}
+ @Override
+ public TensorType getType(String name) {
+ Integer index = nameToIndex().get(name);
+ if (index == null) return null;
+ return values[index].type();
+ }
+
/** Perform a slow lookup by name */
@Override
public Value get(String name) {
- Integer index=nameToIndex().get(name);
- if (index==null) return DoubleValue.zero;
+ Integer index = nameToIndex().get(name);
+ if (index == null) return DoubleValue.zero;
return values[index];
}
@@ -100,8 +105,8 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
/** Perform a fast lookup directly of the value as a double. This is faster than get(index).asDouble() */
@Override
public final double getDouble(int index) {
- double value=doubleValues()[index];
- if (value==Double.NaN)
+ double value = doubleValues()[index];
+ if (value == Double.NaN)
throw new UnsupportedOperationException("Value at " + index + " has no double representation");
return value;
}
@@ -111,7 +116,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
* in a different thread (i.e, name name to index map, different value set.
*/
public ArrayContext clone() {
- ArrayContext clone=(ArrayContext)super.clone();
+ ArrayContext clone = (ArrayContext)super.clone();
clone.values = new Value[nameToIndex().size()];
Arrays.fill(values,constantZero);
return clone;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index 0eeb0a9e630..861f9565d66 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -17,7 +17,7 @@ import java.util.stream.Collectors;
public abstract class Context implements EvaluationContext {
/**
- * <p>Returns the value of a simple variable name.</p>
+ * Returns the value of a simple variable name.
*
* @param name the name of the variable whose value to return.
* @return the value of the named variable.
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index 2ef4a2ede2f..3ac11cff0cb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -4,18 +4,19 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
/**
* A value which acts as a double in numerical context.
*
* @author bratseth
- * @since 5.1.21
*/
public abstract class DoubleCompatibleValue extends Value {
@Override
+ public TensorType type() { return TensorType.empty; }
+
+ @Override
public boolean hasDouble() { return true; }
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
index ceec9358b3c..0625e8506cc 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.tensor.TensorType;
/**
* A variant of an array context variant which supports faster binding of variables but slower lookup
@@ -38,7 +39,6 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext {
*
* @throws IllegalArgumentException if the name is not present in the ranking expression this was created with, and
* ignoredUnknownValues is false
- * @since 5.1.5
*/
@Override
public final void put(String name, Value value) {
@@ -57,11 +57,7 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext {
doubleValues()[index] = value;
}
- /**
- * Puts a value by index.
- *
- * @since 5.1.5
- */
+ /** Puts a value by index. */
public final void put(int index, Value value) {
try {
put(index, value.asDouble());
@@ -71,6 +67,9 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext {
}
}
+ @Override
+ public TensorType getType(String name) { return TensorType.empty; }
+
/** Perform a slow lookup by name */
@Override
public Value get(String name) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
index 09895a0c2f6..39efe641f26 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.evaluation;
+import com.yahoo.tensor.TensorType;
+
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@@ -13,7 +15,7 @@ import java.util.Set;
*/
public class MapContext extends Context {
- private Map<String,Value> bindings=new HashMap<>();
+ private Map<String, Value> bindings = new HashMap<>();
private boolean frozen = false;
@@ -21,16 +23,6 @@ public class MapContext extends Context {
}
/**
- * Freezes this.
- * Returns this for convenience.
- */
- public MapContext freeze() {
- if ( ! frozen)
- bindings = Collections.unmodifiableMap(bindings);
- return this;
- }
-
- /**
* Creates a map context from a map.
* The ownership of the map is transferred to this - it cannot be further modified by the caller.
* All the Values of the map will be frozen.
@@ -42,20 +34,31 @@ public class MapContext extends Context {
}
/**
- * Returns the value of a key. 0 is returned if the given key is not bound in this.
+ * Freezes this.
+ * Returns this for convenience.
*/
+ public MapContext freeze() {
+ if ( ! frozen)
+ bindings = Collections.unmodifiableMap(bindings);
+ return this;
+ }
+
+ /** Returns the type of the given value key, or null if it is not bound. */
@Override
- public Value get(String key) {
+ public TensorType getType(String key) {
Value value = bindings.get(key);
- if (value == null) return DoubleValue.zero;
- return value;
+ if (value == null) return null;
+ return value.type();
+ }
+
+ /** Returns the value of a key. 0 is returned if the given key is not bound in this. */
+ @Override
+ public Value get(String key) {
+ return bindings.getOrDefault(key, DoubleValue.zero);
}
/**
- * Sets the value of a key.
- * The value is frozen by this.
- *
- * @since 5.1.5
+ * Sets the value of a key.The value is frozen by this.
*/
@Override
public void put(String key,Value value) {
@@ -67,7 +70,7 @@ public class MapContext extends Context {
if (frozen) return bindings;
return Collections.unmodifiableMap(bindings);
}
-
+
/** Returns a new, modifiable context containing all the bindings of this */
public MapContext thawedCopy() { return new MapContext(new HashMap<>(bindings)); }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index dad69b31181..c60507310f1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -5,7 +5,6 @@ import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
/**
@@ -30,6 +29,9 @@ public class StringValue extends Value {
this.value = value;
}
+ @Override
+ public TensorType type() { return TensorType.empty; }
+
/** Returns the hashcode of this, to enable strings to be encoded (with reasonable safely) as doubles for optimization */
@Override
public double asDouble() {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index 26c30fe5ed2..c6e456f285d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -6,6 +6,7 @@ import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
/**
* A Value containing a tensor.
@@ -25,6 +26,9 @@ public class TensorValue extends Value {
}
@Override
+ public TensorType type() { return TensorType.empty; }
+
+ @Override
public double asDouble() {
if (hasDouble())
return value.get(TensorAddress.of());
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java
new file mode 100644
index 00000000000..a018aae0c3e
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java
@@ -0,0 +1,28 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A context which only contains type information.
+ *
+ * @author bratseth
+ */
+public class TypeMapContext implements TypeContext {
+
+ private final Map<String, TensorType> featureTypes = new HashMap<>();
+
+ public void setType(String name, TensorType type) {
+ featureTypes.put(name, type);
+ }
+
+ @Override
+ public TensorType getType(String name) {
+ return featureTypes.get(name);
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 40d70e0022c..59d2d95b879 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -13,12 +13,15 @@ import com.yahoo.tensor.TensorType;
* Concrete subclasses of this provides implementations of these methods or throws
* UnsupportedOperationException if the operation is not supported.
*
- * @author bratseth
+ * @author bratseth
*/
public abstract class Value {
private boolean frozen=false;
+ /** Returns the type of this value */
+ public abstract TensorType type();
+
/** Returns this value as a double, or throws UnsupportedOperationException if it cannot be represented as a double */
public abstract double asDouble();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
index 372fb00431b..8ee4cdbf297 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
@@ -7,6 +7,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Deque;
@@ -24,6 +26,9 @@ public class GBDTForestNode extends ExpressionNode {
}
@Override
+ public final TensorType type(TypeContext context) { return TensorType.empty; }
+
+ @Override
public final Value evaluate(Context context) {
int pc = 0;
double treeSum = 0;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
index 4d7b4835892..aac635b2545 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
@@ -7,6 +7,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Deque;
@@ -49,6 +51,9 @@ public final class GBDTNode extends ExpressionNode {
public final double[] values() { return values; }
@Override
+ public final TensorType type(TypeContext context) { return TensorType.empty; }
+
+ @Override
public final Value evaluate(Context context) {
return new DoubleValue(evaluate(values,0,context));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index 816ef38e128..cdcb4df0360 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -147,14 +147,12 @@ class OperationMapper {
return operation.get().map(params);
}
params.signature().importWarning("TensorFlow operation '" + params.node().getOp() +
- "' in node '" + params.node().getName() + "' is not supported.");
+ "' in node '" + params.node().getName() + "' is not supported.");
return Optional.empty();
}
- /*
- * Operations
- */
+ // Operations ---------------------------------
private static Optional<TypedTensorFunction> constant(TensorFlowImporter.Parameters params) {
Tensor value = AttrValueConverter.toVespaTensor(params.node(), "value");
@@ -209,10 +207,11 @@ class OperationMapper {
TensorType type = params.result().arguments().get(name);
if (type == null) {
throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name +
- "', but there is no such placeholder");
+ "', but there is no such placeholder");
}
+ params.result().requiredMacro(name, type);
// Included literally in the expression and so must be produced by a separate macro in the rank profile
- TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(name));
+ TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(name, type));
return Optional.of(output);
}
@@ -227,7 +226,7 @@ class OperationMapper {
}
private static Optional<TypedTensorFunction> reshape(TensorFlowImporter.Parameters params) {
- if (!checkInputs(params, 2)) {
+ if ( ! checkInputs(params, 2)) {
return Optional.empty();
}
List<Optional<TypedTensorFunction>> inputs = params.inputs();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 3a6b3f23a1d..6d78b501fdc 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -14,7 +14,6 @@ import org.tensorflow.framework.TensorInfo;
import java.io.File;
import java.io.IOException;
-import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -109,10 +108,10 @@ public class TensorFlowImporter {
}
Optional<TypedTensorFunction> function = OperationMapper.map(params);
- if (!function.isPresent()) {
+ if ( ! function.isPresent()) {
return Optional.empty();
}
- if (!controlDependenciesArePresent(params)) {
+ if ( ! controlDependenciesArePresent(params)) {
return Optional.empty();
}
params.imported().put(nodeName, function.get());
@@ -185,6 +184,7 @@ public class TensorFlowImporter {
/** Parameter object to hold important data while importing */
static final class Parameters {
+
private final TensorFlowImporter owner;
private final GraphDef graph;
private final SavedModelBundle model;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
index 60aaf8ddce1..fe725e50a3f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
@@ -27,11 +27,13 @@ public class TensorFlowModel {
private final Map<String, Tensor> constants = new HashMap<>();
private final Map<String, RankingExpression> expressions = new HashMap<>();
private final Map<String, RankingExpression> macros = new HashMap<>();
+ private final Map<String, TensorType> requiredMacros = new HashMap<>();
void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
void constant(String name, Tensor constant) { constants.put(name, constant); }
void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
void macro(String name, RankingExpression expression) { macros.put(name, expression); }
+ void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
/** Returns the given signature. If it does not already exist it is added to this. */
Signature signature(String name) {
@@ -51,11 +53,12 @@ public class TensorFlowModel {
*/
public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
- /**
- * Returns an immutable map of expressions that can be overridden - such as PlaceholderWithDefault/
- */
+ /** Returns an immutable map of macros that are part of this model */
public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); }
+ /** Returns an immutable map of the macros that must be provided by the environment running this model */
+ public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); }
+
/** Returns an immutable map of the signatures of this */
public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
index 518a15bcc87..fc6428a4c33 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
@@ -4,8 +4,15 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.functions.Join;
-import java.util.*;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.Iterator;
+import java.util.List;
/**
* A binary mathematical operation
@@ -73,14 +80,26 @@ public final class ArithmeticNode extends CompositeNode {
}
@Override
+ public TensorType type(TypeContext context) {
+ // Compute type using tensor types as arithmetic operators are supported on tensors
+ // and is correct also in the special case of doubles.
+ // As all our functions are type-commutative, we don't need to take operator precedence into account
+ TensorType type = children.get(0).type(context);
+ for (int i = 1; i < children.size(); i++)
+ type = Join.outputType(type, children.get(i).type(context));
+ return type;
+ }
+
+ @Override
public Value evaluate(Context context) {
Iterator<ExpressionNode> child = children.iterator();
+ // Apply in precedence order:
Deque<ValueItem> stack = new ArrayDeque<>();
stack.push(new ValueItem(ArithmeticOperator.OR, child.next().evaluate(context)));
for (Iterator<ArithmeticOperator> it = operators.iterator(); it.hasNext() && child.hasNext();) {
ArithmeticOperator op = it.next();
- if (!stack.isEmpty()) {
+ if ( ! stack.isEmpty()) {
while (stack.peek().op.hasPrecedenceOver(op)) {
popStack(stack);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
index 9484f789169..7601c0e6180 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
@@ -1,11 +1,14 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
-import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.List;
/**
* A node which returns the outcome of a comparison.
@@ -46,6 +49,11 @@ public class ComparisonNode extends BooleanNode {
}
@Override
+ public TensorType type(TypeContext context) {
+ return TensorType.empty; // by definition
+ }
+
+ @Override
public Value evaluate(Context context) {
Value leftValue = leftCondition.evaluate(context);
Value rightValue = rightCondition.evaluate(context);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
index cd473ae6a6f..1ea8d03f0eb 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
@@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Deque;
@@ -47,6 +49,9 @@ public final class ConstantNode extends ExpressionNode {
}
@Override
+ public TensorType type(TypeContext context) { return value.type(); }
+
+ @Override
public Value evaluate(Context context) {
return value;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
index b5d7c41d698..fd9fab99db8 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
@@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
@@ -48,6 +50,11 @@ public final class EmbracedNode extends CompositeNode {
}
@Override
+ public TensorType type(TypeContext context) {
+ return value.type(context);
+ }
+
+ @Override
public CompositeNode setChildren(List<ExpressionNode> newChildren) {
if (newChildren.size() != 1)
throw new IllegalArgumentException("Expected 1 child but got " + newChildren.size());
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
index eb303fc6446..477f4db4981 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
@@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.io.Serializable;
import java.util.Deque;
@@ -41,6 +43,14 @@ public abstract class ExpressionNode implements Serializable {
public abstract String toString(SerializationContext context, Deque<String> path, CompositeNode parent);
/**
+ * Returns the type this will return if evaluated with the given context.
+ *
+ * @param context the variable type bindings to use for this evaluation
+ * @throws IllegalArgumentException if there are variables which are not bound in the given map
+ */
+ public abstract TensorType type(TypeContext context);
+
+ /**
* Returns the value of evaluating this expression over the given context.
*
* @param context the variable bindings to use for this evaluation
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
index 142e282e5c6..79515229019 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
@@ -4,6 +4,9 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.functions.Join;
import java.util.ArrayList;
import java.util.Collections;
@@ -64,16 +67,29 @@ public final class FunctionNode extends CompositeNode {
}
@Override
+ public TensorType type(TypeContext context) {
+ if (arguments.expressions().size() == 0)
+ return TensorType.empty;
+
+ TensorType argument1Type = arguments.expressions().get(0).type(context);
+ if (arguments.expressions().size() == 1)
+ return argument1Type;
+
+ TensorType argument2Type = arguments.expressions().get(1).type(context);
+ return Join.outputType(argument1Type, argument2Type);
+ }
+
+ @Override
public Value evaluate(Context context) {
if (arguments.expressions().size() == 0)
- return DoubleValue.zero.function(function,DoubleValue.zero);
+ return DoubleValue.zero.function(function ,DoubleValue.zero);
Value argument1 = arguments.expressions().get(0).evaluate(context);
if (arguments.expressions().size() == 1)
return argument1.function(function, DoubleValue.zero);
Value argument2 = arguments.expressions().get(1).evaluate(context);
- return argument1.function(function,argument2);
+ return argument1.function(function, argument2);
}
/** Returns a new function node with the children replaced by the given children */
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
index 9da1ba40144..e42884ecc05 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
@@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
@@ -46,6 +47,9 @@ public class GeneratorLambdaFunctionNode extends CompositeNode {
return generator.toString(context, path, this);
}
+ @Override
+ public TensorType type(TypeContext context) { return type; }
+
/** Evaluate this in a context which must have the arguments bound */
@Override
public Value evaluate(Context context) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
index 1b429de0be5..076df327044 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
@@ -3,13 +3,18 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
/**
* A conditional branch of a ranking expression.
*
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
* @author bratseth
*/
public final class IfNode extends CompositeNode {
@@ -70,6 +75,17 @@ public final class IfNode extends CompositeNode {
}
@Override
+ public TensorType type(TypeContext context) {
+ TensorType trueType = trueExpression.type(context);
+ TensorType falseType = falseExpression.type(context);
+ if ( ! trueType.equals(falseType))
+ throw new IllegalArgumentException("An if expression must produce a value of the same type in both " +
+ "alternatives, but the 'true' type is " + trueType + " while the " +
+ "'false' type is " + falseType);
+ return trueType;
+ }
+
+ @Override
public Value evaluate(Context context) {
if (condition.evaluate(context).asBoolean())
return trueExpression.evaluate(context);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index 78206d75d0d..da946228291 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -5,6 +5,8 @@ import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
@@ -14,20 +16,20 @@ import java.util.function.DoubleUnaryOperator;
/**
* A free, parametrized function
- *
+ *
* @author bratseth
*/
public class LambdaFunctionNode extends CompositeNode {
private final ImmutableList<String> arguments;
private final ExpressionNode functionExpression;
-
+
public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) {
// TODO: Verify that the function only accesses the given arguments
this.arguments = ImmutableList.copyOf(arguments);
this.functionExpression = functionExpression;
}
-
+
@Override
public List<ExpressionNode> children() {
return Collections.singletonList(functionExpression);
@@ -54,19 +56,24 @@ public class LambdaFunctionNode extends CompositeNode {
return b.toString();
}
+ @Override
+ public TensorType type(TypeContext context) {
+ return TensorType.empty; // by definition - no nested lambdas
+ }
+
/** Evaluate this in a context which must have the arguments bound */
@Override
public Value evaluate(Context context) {
return functionExpression.evaluate(context);
}
-
- /**
+
+ /**
* Returns this as a double unary operator
- *
- * @throws IllegalStateException if this has more than one argument
+ *
+ * @throws IllegalStateException if this has more than one argument
*/
public DoubleUnaryOperator asDoubleUnaryOperator() {
- if (arguments.size() > 1)
+ if (arguments.size() > 1)
throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " +
"Must have at most one argument " + " but has " + arguments);
return new DoubleUnaryLambda();
@@ -93,7 +100,7 @@ public class LambdaFunctionNode extends CompositeNode {
context.put(arguments.get(0), operand);
return evaluate(context).asDouble();
}
-
+
@Override
public String toString() {
return LambdaFunctionNode.this.toString();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
index 69df572272a..f55ed59b65c 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
@@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Deque;
@@ -30,6 +32,9 @@ public final class NameNode extends ExpressionNode {
}
@Override
+ public TensorType type(TypeContext context) { throw new RuntimeException("Named nodes can not have a type"); }
+
+ @Override
public Value evaluate(Context context) {
throw new RuntimeException("Name nodes should never be evaluated");
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
index 61c20a97b64..9cbe5f98c72 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
@@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
@@ -36,6 +38,11 @@ public class NegativeNode extends CompositeNode {
}
@Override
+ public TensorType type(TypeContext context) {
+ return value.type(context);
+ }
+
+ @Override
public Value evaluate(Context context) {
return value.evaluate(context).negate();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
index 8c459a032bd..e7041600635 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
@@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
@@ -36,6 +38,11 @@ public class NotNode extends BooleanNode {
}
@Override
+ public TensorType type(TypeContext context) {
+ return value.type(context);
+ }
+
+ @Override
public Value evaluate(Context context) {
return value.evaluate(context).not();
}
@@ -45,6 +52,6 @@ public class NotNode extends BooleanNode {
if (children.size() != 1) throw new IllegalArgumentException("Expected 1 children but got " + children.size());
return new NotNode(children.get(0));
}
-
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
index 139709998b4..f79297f7773 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
@@ -5,6 +5,8 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayDeque;
import java.util.Deque;
@@ -44,9 +46,8 @@ public final class ReferenceNode extends CompositeNode {
return new ReferenceNode(name, arguments, output);
}
- public String getOutput() {
- return output;
- }
+ /** Returns the specific output this references, or null if none specified */
+ public String getOutput() { return output; }
/** Returns a copy of this node with a modified output */
public ReferenceNode setOutput(String output) {
@@ -105,8 +106,14 @@ public final class ReferenceNode extends CompositeNode {
}
@Override
+ public TensorType type(TypeContext context) {
+ // Don't support outputs of different type, for simplicity
+ return context.getType(name);
+ }
+
+ @Override
public Value evaluate(Context context) {
- if (arguments.expressions().size()==0 && output==null)
+ if (arguments.expressions().isEmpty() && output == null)
return context.get(name);
return context.get(name, arguments, output);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
index f6b1a1a8979..a7b82f4753f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
@@ -7,6 +7,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayList;
import java.util.Deque;
@@ -58,6 +60,11 @@ public class SetMembershipNode extends BooleanNode {
}
@Override
+ public TensorType type(TypeContext context) {
+ return TensorType.empty;
+ }
+
+ @Override
public Value evaluate(Context context) {
Value value = testValue.evaluate(context);
if (value instanceof TensorValue) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index b42570d3aea..e4c381972e9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -6,7 +6,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
@@ -35,7 +37,7 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public List<ExpressionNode> children() {
- return function.functionArguments().stream()
+ return function.arguments().stream()
.map(this::toExpressionNode)
.collect(Collectors.toList());
}
@@ -52,7 +54,7 @@ public class TensorFunctionNode extends CompositeNode {
List<TensorFunction> wrappedChildren = children.stream()
.map(TensorFunctionExpressionNode::new)
.collect(Collectors.toList());
- return new TensorFunctionNode(function.replaceArguments(wrappedChildren));
+ return new TensorFunctionNode(function.withArguments(wrappedChildren));
}
@Override
@@ -62,6 +64,9 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
+ public TensorType type(TypeContext context) { return function.type(context); }
+
+ @Override
public Value evaluate(Context context) {
return new TensorValue(function.evaluate(context));
}
@@ -84,7 +89,7 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public List<TensorFunction> functionArguments() {
+ public List<TensorFunction> arguments() {
if (expression instanceof CompositeNode)
return ((CompositeNode)expression).children().stream()
.map(TensorFunctionExpressionNode::new)
@@ -94,7 +99,7 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if (arguments.size() == 0) return this;
List<ExpressionNode> unwrappedChildren = arguments.stream()
.map(arg -> ((TensorFunctionExpressionNode)arg).expression)
@@ -106,12 +111,17 @@ public class TensorFunctionNode extends CompositeNode {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
+ public TensorType type(TypeContext context) {
+ return expression.type(context);
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Value result = expression.evaluate((Context)context);
if ( ! ( result instanceof TensorValue))
throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " +
"but this returns " + result + ", not a tensor");
- return ((TensorValue)result).asTensor();
+ return result.asTensor();
}
@Override
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
index b59b4750911..445ccf231a7 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
@@ -2,10 +2,12 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.tensor.TensorType;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
/**
* @author lesters
@@ -15,6 +17,18 @@ public class DropoutImportTestCase {
@Test
public void testDropoutImport() {
TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/dropout/saved");
+
+ // Check (provided) macros
+ assertEquals(1, model.get().macros().size());
+ assertTrue(model.get().macros().containsKey("training/input"));
+ assertEquals("constant(\"training/input\")", model.get().macros().get("training/input").getRoot().toString());
+
+ // Check required macros
+ assertEquals(1, model.get().requiredMacros().size());
+ assertTrue(model.get().requiredMacros().containsKey("X"));
+ assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
+ model.get().requiredMacros().get("X"));
+
TensorFlowModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
index f12b9a2c628..01dd15d5fa0 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -8,6 +8,7 @@ import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
/**
* @author bratseth
@@ -33,6 +34,15 @@ public class MnistSoftmaxImportTestCase {
constant1.type());
assertEquals(10, constant1.size());
+ // Check (provided) macros
+ assertEquals(0, model.get().macros().size());
+
+ // Check required macros
+ assertEquals(1, model.get().requiredMacros().size());
+ assertTrue(model.get().requiredMacros().containsKey("Placeholder"));
+ assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
+ model.get().requiredMacros().get("Placeholder"));
+
// Check signatures
assertEquals(1, model.get().signatures().size());
TensorFlowModel.Signature signature = model.get().signatures().get("serving_default");
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java
index cfa9ce670f6..6095134b7a2 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java
@@ -15,14 +15,14 @@ import static com.yahoo.vespa.http.client.SessionFactory.createTimeoutExecutor;
public class FeedClientFactory {
/**
- * Creates FeedClient.
+ * Creates a FeedClient.
+ *
* @param sessionParams parameters for connection, hosts, cluster configurations and more.
* @param resultCallback on each result, this callback is called.
* @return newly created FeedClient API object.
*/
- public static FeedClient create(
- SessionParams sessionParams,
- FeedClient.ResultCallback resultCallback) {
+ public static FeedClient create(SessionParams sessionParams, FeedClient.ResultCallback resultCallback) {
return new FeedClientImpl(sessionParams, resultCallback, createTimeoutExecutor());
}
+
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java
index 138be61de80..65f56f72a58 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java
@@ -20,6 +20,7 @@ import java.util.List;
*/
// This should be an interface, but in order to be binary compatible during refactoring we made it abstract.
public class Result {
+
public enum ResultType {
OPERATION_EXECUTED,
TRANSITIVE_ERROR,
@@ -106,12 +107,10 @@ public class Result {
/**
* Information in a Result for a single operation sent to a single endpoint.
- *
- * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a>
- * @since 5.1.20
*/
@Immutable
public static final class Detail {
+
private final ResultType resultType;
private final Endpoint endpoint;
private final Exception exception;
@@ -212,4 +211,5 @@ public class Result {
}
return b.toString();
}
+
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/EndpointResult.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/EndpointResult.java
index 5aec46a8fc7..b04248f98a5 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/EndpointResult.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/EndpointResult.java
@@ -5,9 +5,11 @@ import com.yahoo.vespa.http.client.Result;
/**
* Result from a single endpoint.
+ *
* @author dybis
*/
public class EndpointResult {
+
private final String operationId;
private final Result.Detail detail;
@@ -23,4 +25,5 @@ public class EndpointResult {
public Result.Detail getDetail() {
return detail;
}
+
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ErrorCode.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ErrorCode.java
index 96afc537c59..445ad5295c1 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ErrorCode.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ErrorCode.java
@@ -6,12 +6,12 @@ import com.google.common.annotations.Beta;
/**
* Return types for the server.
*
- * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a>
- * @author <a href="mailto:steinar@yahoo-inc.com">Steinar Knutsen</a>
- * @since 5.1.20
+ * @author Einar M R Rosenvinge
+ * @author Steinar Knutsen
*/
@Beta
public enum ErrorCode {
+
OK(true, true),
ERROR(false, false),
TRANSIENT_ERROR(false, true),
@@ -20,7 +20,7 @@ public enum ErrorCode {
private boolean success;
private boolean _transient;
- private ErrorCode(boolean success, boolean _transient) {
+ ErrorCode(boolean success, boolean _transient) {
this.success = success;
this._transient = _transient;
}
@@ -32,4 +32,5 @@ public enum ErrorCode {
public boolean isTransient() {
return _transient;
}
+
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/OperationStatus.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/OperationStatus.java
index 7ea4a214cbd..7aec207e0ab 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/OperationStatus.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/OperationStatus.java
@@ -10,11 +10,11 @@ import java.util.Iterator;
/**
* Serialization/deserialization class for the result of a single document operation against Vespa.
*
- * @author <a href="mailto:steinar@yahoo-inc.com">Steinar Knutsen</a>
- * @since 5.1
+ * @author Steinar Knutsen
*/
@Beta
public final class OperationStatus {
+
public static final String IS_CONDITION_NOT_MET = "IS-CONDITION-NOT-MET";
public final String message;
public final String operationId;
@@ -81,9 +81,7 @@ public final class OperationStatus {
return new OperationStatus(message, operationId, errorCode, isConditionNotMet, traceMessage);
}
- /**
- * @return a string representing the status.
- */
+ /** Returns a string representing the status. */
public String render() {
StringBuilder s = new StringBuilder();
Encoder.encode(operationId, s).append(SEPARATOR);
@@ -92,4 +90,5 @@ public final class OperationStatus {
Encoder.encode(traceMessage, s).append(EOL);
return s.toString();
}
+
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ServerResponseException.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ServerResponseException.java
index 4c291935916..1800864cd90 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ServerResponseException.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ServerResponseException.java
@@ -5,12 +5,13 @@ import com.google.common.annotations.Beta;
/**
* The request was not processed properly on the server.
- * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a>
- * @since 5.1.20
+ *
+ * @author Einar M R Rosenvinge
*/
@SuppressWarnings("serial")
@Beta
public class ServerResponseException extends Exception {
+
private final int responseCode;
private final String responseString;
@@ -33,5 +34,6 @@ public class ServerResponseException extends Exception {
}
return responseString;
}
+
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java
index a16d992324d..903c1ad4842 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java
@@ -17,7 +17,7 @@ import java.util.concurrent.TimeUnit;
/**
* Implementation of FeedClient. It is a thin layer on top of multiClusterHandler and multiClusterResultAggregator.
- *
+ *
* @author dybis
*/
public class FeedClientImpl implements FeedClient {
@@ -92,4 +92,5 @@ public class FeedClientImpl implements FeedClient {
}
return true;
}
+
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java
index edbe5542bc4..c3b5d9912de 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java
@@ -6,7 +6,6 @@ import com.yahoo.vespa.http.client.Result;
import com.yahoo.vespa.http.client.Session;
import com.yahoo.vespa.http.client.config.SessionParams;
import com.yahoo.vespa.http.client.core.ThrottlePolicy;
-import com.yahoo.vespa.http.client.core.api.MultiClusterSessionOutputStream;
import com.yahoo.vespa.http.client.core.operationProcessor.IncompleteResultsThrottler;
import com.yahoo.vespa.http.client.core.operationProcessor.OperationProcessor;
@@ -65,4 +64,5 @@ public class SessionImpl implements Session {
public int getIncompleteResultQueueSize() {
return operationProcessor.getIncompleteResultQueueSize();
}
+
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java
index 420f64d4bf3..dd724e90a42 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java
@@ -44,8 +44,6 @@ import java.util.zip.GZIPOutputStream;
/**
* @author Einar M R Rosenvinge
- *
- * @since 5.1.20
*/
@Beta
class ApacheGatewayConnection implements GatewayConnection {
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java
index 671e6f07dbe..d963ae79227 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java
@@ -2,8 +2,6 @@
package com.yahoo.vespa.http.client.core.communication;
import com.yahoo.vespa.http.client.core.Document;
-import com.yahoo.vespa.http.client.core.EndpointResult;
-import com.yahoo.vespa.http.client.core.operationProcessor.EndPointResultFactory;
import java.util.ArrayDeque;
import java.util.ArrayList;
@@ -13,7 +11,7 @@ import java.util.Optional;
import java.util.concurrent.TimeUnit;
/**
- * Document queue that only gives you document operations on documents for which there are no
+ * Document queue that only gives you document operations on documents for which there are no
* already in flight operations for.
*
* @author dybis
@@ -54,8 +52,6 @@ class DocumentQueue {
}
}
-
-
Document poll(long timeout, TimeUnit unit) throws InterruptedException {
synchronized (queue) {
long remainingToWait = unit.toMillis(timeout);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index b396f831de0..5b98a1b4fb5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -19,7 +19,6 @@ import java.util.stream.Collectors;
* A tensor type with its dimensions. This is immutable.
* <p>
* A dimension can be indexed (bound or unbound) or mapped.
- * Currently, we only support tensor types where all dimensions have the same type.
*
* @author geirst
* @author bratseth
@@ -27,6 +26,7 @@ import java.util.stream.Collectors;
@Beta
public class TensorType {
+ /** The empty tensor type - which is the same as a double */
public static final TensorType empty = new TensorType(Collections.emptyList());
/** Sorted list of the dimensions of this */
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
index 3db661f8a23..3fb94f1251b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
@@ -10,7 +10,7 @@ import com.yahoo.tensor.Tensor;
* @author bratseth
*/
@Beta
-public interface EvaluationContext {
+public interface EvaluationContext extends TypeContext {
/** Returns the tensor bound to this name, or null if none */
Tensor getTensor(String name);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
index db8a66a5fa2..9fe6b7d053f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.evaluation;
import com.google.common.annotations.Beta;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import java.util.HashMap;
@@ -19,6 +20,13 @@ public class MapEvaluationContext implements EvaluationContext {
public void put(String name, Tensor tensor) { bindings.put(name, tensor); }
@Override
+ public TensorType getType(String name) {
+ Tensor tensor = bindings.get(name);
+ if (tensor == null) return null;
+ return tensor.type();
+ }
+
+ @Override
public Tensor getTensor(String name) { return bindings.get(name); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
new file mode 100644
index 00000000000..4d3bd04c3d4
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
@@ -0,0 +1,21 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.evaluation;
+
+import com.yahoo.tensor.TensorType;
+
+/**
+ * Provides type information about a context (set of variable bindings).
+ *
+ * @author bratseth
+ */
+public interface TypeContext {
+
+ /**
+ * Returns tye type of the tensor with this name.
+ *
+ * @return returns the type of the tensor which will be returned by calling getTensor(name)
+ * or null if getTensor will return null.
+ */
+ TensorType getType(String name);
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
index 1f6ad050368..34beb465d4c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
@@ -3,12 +3,14 @@ package com.yahoo.tensor.evaluation;
import com.google.common.annotations.Beta;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.Collections;
import java.util.List;
+import java.util.Optional;
/**
* A tensor variable name which resolves to a tensor in the context at evaluation time
@@ -19,23 +21,42 @@ import java.util.List;
public class VariableTensor extends PrimitiveTensorFunction {
private final String name;
+ private final Optional<TensorType> requiredType;
public VariableTensor(String name) {
this.name = name;
+ this.requiredType = Optional.empty();
+ }
+
+ /** A variable tensor which must be compatible with the given type */
+ public VariableTensor(String name, TensorType requiredType) {
+ this.name = name;
+ this.requiredType = Optional.of(requiredType);
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+ public List<TensorFunction> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) { return this; }
+ public TensorFunction withArguments(List<TensorFunction> arguments) { return this; }
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
+ public TensorType type(TypeContext context) {
+ TensorType givenType = context.getType(name);
+ if (givenType == null) return null;
+ verifyType(givenType);
+ return givenType;
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
- return context.getTensor(name);
+ Tensor tensor = context.getTensor(name);
+ if (tensor == null) return null;
+ verifyType(tensor.type());
+ return tensor;
}
@Override
@@ -43,4 +64,9 @@ public class VariableTensor extends PrimitiveTensorFunction {
return name;
}
+ private void verifyType(TensorType givenType) {
+ if (requiredType.isPresent() && ! givenType.isAssignableTo(requiredType.get()))
+ throw new IllegalArgumentException("Variable '" + name + "' must be compatible with " +
+ requiredType.get() + " but was " + givenType);
+ }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
index 10f53670826..93365d20966 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
@@ -14,17 +14,17 @@ public class Argmax extends CompositeTensorFunction {
private final TensorFunction argument;
private final String dimension;
-
+
public Argmax(TensorFunction argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Argmax must have 1 argument, got " + arguments.size());
return new Argmax(arguments.get(0), dimension);
@@ -37,7 +37,7 @@ public class Argmax extends CompositeTensorFunction {
new Reduce(primitiveArgument, Reduce.Aggregator.max, dimension),
ScalarFunctions.equal());
}
-
+
@Override
public String toString(ToStringContext context) {
return "argmax(" + argument.toString(context) + ", " + dimension + ")";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
index d324aec53e9..e598cdf8a98 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
@@ -14,17 +14,17 @@ public class Argmin extends CompositeTensorFunction {
private final TensorFunction argument;
private final String dimension;
-
+
public Argmin(TensorFunction argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Argmin must have 1 argument, got " + arguments.size());
return new Argmin(arguments.get(0), dimension);
@@ -37,7 +37,7 @@ public class Argmin extends CompositeTensorFunction {
new Reduce(primitiveArgument, Reduce.Aggregator.min, dimension),
ScalarFunctions.equal());
}
-
+
@Override
public String toString(ToStringContext context) {
return "argmin(" + argument.toString(context) + ", " + dimension + ")";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
index 191c7988443..2109b730e1a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
@@ -3,7 +3,9 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
/**
* A composite tensor function is a tensor function which can be expressed (less tersely)
@@ -14,6 +16,10 @@ import com.yahoo.tensor.evaluation.EvaluationContext;
@Beta
public abstract class CompositeTensorFunction extends TensorFunction {
+ /** Finds the type this produces by first converting it to a primitive function */
+ @Override
+ public final TensorType type(TypeContext context) { return toPrimitive().type(context); }
+
/** Evaluates this by first converting it to a primitive function */
@Override
public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index d4affe0ef9b..c77ed1c0526 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -9,8 +9,14 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
-
-import java.util.*;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
import java.util.stream.Collectors;
/**
@@ -34,10 +40,10 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); }
+ public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if (arguments.size() != 2)
throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
return new Concat(arguments.get(0), arguments.get(1), dimension);
@@ -54,6 +60,20 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
+ public TensorType type(TypeContext context) {
+ return type(argumentA.type(context), argumentB.type(context));
+ }
+
+ /** Returns the type resulting from concatenating a and b */
+ private TensorType type(TensorType a, TensorType b) {
+ TensorType.Builder builder = new TensorType.Builder(a, b);
+ if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size
+ builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() +
+ b.dimension(dimension).get().size().get()));
+ return builder.build();
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
@@ -63,7 +83,7 @@ public class Concat extends PrimitiveTensorFunction {
IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
IndexedTensor bIndexed = (IndexedTensor) b;
- TensorType concatType = concatType(a, b);
+ TensorType concatType = type(a.type(), b.type());
DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
@@ -115,15 +135,6 @@ public class Concat extends PrimitiveTensorFunction {
}
- /** Returns the type resulting from concatenating a and b */
- private TensorType concatType(Tensor a, Tensor b) {
- TensorType.Builder builder = new TensorType.Builder(a.type(), b.type());
- if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size
- builder.set(TensorType.Dimension.indexed(dimension, a.type().dimension(dimension).get().size().get() +
- b.type().dimension(dimension).get().size().get()));
- return builder.build();
- }
-
/** Returns the concrete (not type) dimension sizes resulting from combining a and b */
private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index 14ed38718ce..50b479da168 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -3,7 +3,9 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.List;
@@ -27,10 +29,10 @@ public class ConstantTensor extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+ public List<TensorFunction> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("ConstantTensor must have 0 arguments, got " + arguments.size());
return this;
@@ -40,6 +42,9 @@ public class ConstantTensor extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
+ public TensorType type(TypeContext context) { return constant.type(); }
+
+ @Override
public Tensor evaluate(EvaluationContext context) { return constant; }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
index 653be8dacf0..e302f6606e7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -25,10 +25,10 @@ public class Diag extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+ public List<TensorFunction> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Diag must have 0 arguments, got " + arguments.size());
return this;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index ef2770c04f5..e70d1de3db7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -7,6 +7,7 @@ import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.List;
@@ -47,10 +48,10 @@ public class Generate extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+ public List<TensorFunction> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size());
return this;
@@ -60,6 +61,9 @@ public class Generate extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
+ public TensorType type(TypeContext context) { return type; }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Tensor.Builder builder = Tensor.Builder.of(type);
IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type));
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 174a8e4c435..7812c985091 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -11,6 +11,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayList;
import java.util.Collections;
@@ -47,6 +48,7 @@ public class Join extends PrimitiveTensorFunction {
}
/** Returns the type resulting from applying Join to the two given types */
+ // TODO: Replace implementation by new TensorType.Builder(a.type(), b.type()).build();
public static TensorType outputType(TensorType a, TensorType b) {
TensorType.Builder typeBuilder = new TensorType.Builder();
for (int i = 0; i < a.dimensions().size(); ++i) {
@@ -70,15 +72,13 @@ public class Join extends PrimitiveTensorFunction {
return typeBuilder.build();
}
- public TensorFunction argumentA() { return argumentA; }
- public TensorFunction argumentB() { return argumentB; }
public DoubleBinaryOperator combinator() { return combinator; }
@Override
- public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); }
+ public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 2)
throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size());
return new Join(arguments.get(0), arguments.get(1), combinator);
@@ -95,6 +95,11 @@ public class Join extends PrimitiveTensorFunction {
}
@Override
+ public TensorType type(TypeContext context) {
+ return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build();
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
index 91a9c6d1b27..d7f7ae59d62 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
@@ -14,17 +14,17 @@ public class L1Normalize extends CompositeTensorFunction {
private final TensorFunction argument;
private final String dimension;
-
+
public L1Normalize(TensorFunction argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("L1Normalize must have 1 argument, got " + arguments.size());
return new L1Normalize(arguments.get(0), dimension);
@@ -38,7 +38,7 @@ public class L1Normalize extends CompositeTensorFunction {
new Reduce(primitiveArgument, Reduce.Aggregator.sum, dimension),
ScalarFunctions.divide());
}
-
+
@Override
public String toString(ToStringContext context) {
return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
index bdf8921f81d..e2c526760bd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
@@ -14,17 +14,17 @@ public class L2Normalize extends CompositeTensorFunction {
private final TensorFunction argument;
private final String dimension;
-
+
public L2Normalize(TensorFunction argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("L2Normalize must have 1 argument, got " + arguments.size());
return new L2Normalize(arguments.get(0), dimension);
@@ -40,7 +40,7 @@ public class L2Normalize extends CompositeTensorFunction {
ScalarFunctions.sqrt()),
ScalarFunctions.divide());
}
-
+
@Override
public String toString(ToStringContext context) {
return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index a5e1a016a41..53504868ff2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -2,12 +2,11 @@
package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
-import com.google.common.collect.ImmutableMap;
-import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Iterator;
@@ -39,10 +38,10 @@ public class Map extends PrimitiveTensorFunction {
public DoubleUnaryOperator mapper() { return mapper; }
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Map must have 1 argument, got " + arguments.size());
return new Map(arguments.get(0), mapper);
@@ -54,6 +53,11 @@ public class Map extends PrimitiveTensorFunction {
}
@Override
+ public TensorType type(TypeContext context) {
+ return argument.type(context);
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Tensor argument = argument().evaluate(context);
Tensor.Builder builder = Tensor.Builder.of(argument.type());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index 4071917c2b5..935e4761cfe 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -27,10 +27,10 @@ public class Matmul extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); }
+ public List<TensorFunction> arguments() { return ImmutableList.of(argument1, argument2); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 2)
throw new IllegalArgumentException("Matmul must have 2 arguments, got " + arguments.size());
return new Matmul(arguments.get(0), arguments.get(1), dimension);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
index 958ef85d1dc..1475f7f4ac1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
@@ -28,10 +28,10 @@ public class Random extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+ public List<TensorFunction> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Random must have 0 arguments, got " + arguments.size());
return this;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
index 8e7f4e4c773..d951ec9ccbd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -26,10 +26,10 @@ public class Range extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+ public List<TensorFunction> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Range must have 0 arguments, got " + arguments.size());
return this;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index de9f90a5804..76a938b9fe2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -8,6 +8,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.HashMap;
@@ -73,10 +74,10 @@ public class Reduce extends PrimitiveTensorFunction {
public TensorFunction argument() { return argument; }
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size());
return new Reduce(arguments.get(0), aggregator, dimensions);
@@ -100,6 +101,19 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
+ public TensorType type(TypeContext context) {
+ return type(argument.type(context));
+ }
+
+ private TensorType type(TensorType argumentType) {
+ TensorType.Builder builder = new TensorType.Builder();
+ for (TensorType.Dimension dimension : argumentType.dimensions())
+ if ( ! dimensions.contains(dimension.name())) // keep
+ builder.dimension(dimension);
+ return builder.build();
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Tensor argument = this.argument.evaluate(context);
if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
@@ -113,12 +127,7 @@ public class Reduce extends PrimitiveTensorFunction {
else
return reduceAllGeneral(argument);
- // Reduce type
- TensorType.Builder builder = new TensorType.Builder();
- for (TensorType.Dimension dimension : argument.type().dimensions())
- if ( ! dimensions.contains(dimension.name())) // keep
- builder.dimension(dimension);
- TensorType reducedType = builder.build();
+ TensorType reducedType = type(argument.type());
// Reduce cells
Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index ec9b762a41c..de3d2be265a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -7,6 +7,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.HashMap;
@@ -26,6 +27,7 @@ public class Rename extends PrimitiveTensorFunction {
private final TensorFunction argument;
private final List<String> fromDimensions;
private final List<String> toDimensions;
+ private final Map<String, String> fromToMap;
public Rename(TensorFunction argument, String fromDimension, String toDimension) {
this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension));
@@ -43,13 +45,24 @@ public class Rename extends PrimitiveTensorFunction {
this.argument = argument;
this.fromDimensions = ImmutableList.copyOf(fromDimensions);
this.toDimensions = ImmutableList.copyOf(toDimensions);
+ this.fromToMap = fromToMap(fromDimensions, toDimensions);
+ }
+
+ public List<String> fromDimensions() { return fromDimensions; }
+ public List<String> toDimensions() { return toDimensions; }
+
+ private static Map<String, String> fromToMap(List<String> fromDimensions, List<String> toDimensions) {
+ Map<String, String> map = new HashMap<>();
+ for (int i = 0; i < fromDimensions.size(); i++)
+ map.put(fromDimensions.get(i), toDimensions.get(i));
+ return map;
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Rename must have 1 argument, got " + arguments.size());
return new Rename(arguments.get(0), fromDimensions, toDimensions);
@@ -59,11 +72,22 @@ public class Rename extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
+ public TensorType type(TypeContext context) {
+ return type(argument.type(context));
+ }
+
+ private TensorType type(TensorType type) {
+ TensorType.Builder builder = new TensorType.Builder();
+ for (TensorType.Dimension dimension : type.dimensions())
+ builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
+ return builder.build();
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Tensor tensor = argument.evaluate(context);
- Map<String, String> fromToMap = fromToMap();
- TensorType renamedType = rename(tensor.type(), fromToMap);
+ TensorType renamedType = type(tensor.type());
// an array which lists the index of each label in the renamed type
int[] toIndexes = new int[tensor.type().dimensions().size()];
@@ -82,13 +106,6 @@ public class Rename extends PrimitiveTensorFunction {
return builder.build();
}
- private TensorType rename(TensorType type, Map<String, String> fromToMap) {
- TensorType.Builder builder = new TensorType.Builder();
- for (TensorType.Dimension dimension : type.dimensions())
- builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
- return builder.build();
- }
-
private TensorAddress rename(TensorAddress address, int[] toIndexes) {
String[] reorderedLabels = new String[toIndexes.length];
for (int i = 0; i < toIndexes.length; i++)
@@ -102,13 +119,6 @@ public class Rename extends PrimitiveTensorFunction {
toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
}
- private Map<String, String> fromToMap() {
- Map<String, String> map = new HashMap<>();
- for (int i = 0; i < fromDimensions.size(); i++)
- map.put(fromDimensions.get(i), toDimensions.get(i));
- return map;
- }
-
private String toVectorString(List<String> elements) {
if (elements.size() == 1)
return elements.get(0);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index c856b548180..32cff5ac84a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -16,21 +16,21 @@ public class Softmax extends CompositeTensorFunction {
private final TensorFunction argument;
private final String dimension;
-
+
public Softmax(TensorFunction argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
-
+
public static TensorType outputType(TensorType inputType, String dimension) {
return Reduce.outputType(inputType, ImmutableList.of(dimension));
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Softmax must have 1 argument, got " + arguments.size());
return new Softmax(arguments.get(0), dimension);
@@ -45,7 +45,7 @@ public class Softmax extends CompositeTensorFunction {
dimension),
ScalarFunctions.divide());
}
-
+
@Override
public String toString(ToStringContext context) {
return "softmax(" + argument.toString(context) + ", " + dimension + ")";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 533a46f87fe..78ab09c7820 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -3,8 +3,10 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.List;
@@ -19,14 +21,14 @@ import java.util.List;
public abstract class TensorFunction {
/** Returns the function arguments of this node in the order they are applied */
- public abstract List<TensorFunction> functionArguments();
+ public abstract List<TensorFunction> arguments();
/**
* Returns a copy of this tensor function with the arguments replaced by the given list of arguments.
*
* @throws IllegalArgumentException if the argument list has the wrong size for this function
*/
- public abstract TensorFunction replaceArguments(List<TensorFunction> arguments);
+ public abstract TensorFunction withArguments(List<TensorFunction> arguments);
/**
* Translate this function - and all of its arguments recursively -
@@ -43,6 +45,13 @@ public abstract class TensorFunction {
*/
public abstract Tensor evaluate(EvaluationContext context);
+ /**
+ * Returns the type of the tensor this produces given the input types in the context
+ *
+ * @param context a context which must be passed to all nexted functions when evaluating
+ */
+ public abstract TensorType type(TypeContext context);
+
/** Evaluate with no context */
public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
index 2464be981f5..78ff0731566 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
@@ -14,7 +14,7 @@ public class XwPlusB extends CompositeTensorFunction {
private final TensorFunction x, w, b;
private final String dimension;
-
+
public XwPlusB(TensorFunction x, TensorFunction w, TensorFunction b, String dimension) {
this.x = x;
this.w = w;
@@ -23,10 +23,10 @@ public class XwPlusB extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> functionArguments() { return ImmutableList.of(x, w, b); }
+ public List<TensorFunction> arguments() { return ImmutableList.of(x, w, b); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 3)
throw new IllegalArgumentException("XwPlusB must have 3 arguments, got " + arguments.size());
return new XwPlusB(arguments.get(0), arguments.get(1), arguments.get(2), dimension);
@@ -43,7 +43,7 @@ public class XwPlusB extends CompositeTensorFunction {
primitiveB,
ScalarFunctions.add());
}
-
+
@Override
public String toString(ToStringContext context) {
return "xw_plus_b(" + x.toString(context) + ", " +