diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-01-31 11:13:51 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-01-31 11:13:51 +0100 |
commit | a44edeba9f38c38c431d7b9b6e1ac454e2a0e610 (patch) | |
tree | 21600936cfe396492965764911652b49b4c22731 | |
parent | 9c4ba9bf5b96b8c62a9b8c5a6c20a9175c698b70 (diff) |
Verify macros
65 files changed, 501 insertions, 352 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index bacff94d776..3b68e0199a9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -2,15 +2,26 @@ package com.yahoo.searchdefinition; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.io.reader.NamedReader; +import com.yahoo.processing.request.CompoundName; +import com.yahoo.search.query.profile.QueryProfile; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.search.query.profile.config.QueryProfileXMLReader; +import com.yahoo.search.query.profile.types.FieldDescription; +import com.yahoo.search.query.profile.types.TensorFieldType; import com.yahoo.search.query.ranking.Diversity; +import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.FeatureList; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TypeMapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.io.File; import java.io.IOException; @@ -276,12 +287,10 @@ public class RankProfile implements Serializable, Cloneable { addConstant(name, value); } - /** - * Returns an unmodifiable view of the constants to use in this. - */ + /** Returns an unmodifiable view of the constants available in this */ public Map<String, Value> getConstants() { if (constants.isEmpty()) - return getInherited() != null ? getInherited().getConstants() : Collections.<String,Value>emptyMap(); + return getInherited() != null ? getInherited().getConstants() : Collections.emptyMap(); if (getInherited() == null || getInherited().getConstants().isEmpty()) return Collections.unmodifiableMap(constants); @@ -433,8 +442,9 @@ public class RankProfile implements Serializable, Cloneable { properties.add(rankProperty); } + @Override public String toString() { - return "rank profile " + getName(); + return "rank profile '" + getName() + "'"; } public int getRerankCount() { @@ -508,6 +518,7 @@ public class RankProfile implements Serializable, Cloneable { /** * Returns the string form of the first phase ranking expression. + * * @return string form of first phase ranking expression */ public String getFirstPhaseRankingString() { @@ -727,6 +738,64 @@ public class RankProfile implements Serializable, Cloneable { } /** + * Creates a context containing the type information of all constants, attributes and query profiles + * referable from this rank profile. + */ + public TypeContext typeContext() { + TypeMapContext context = new TypeMapContext(); + + // Add constants + getConstants().forEach((k, v) -> context.setType(asConstantFeature(k), v.type())); + + // Add attributes + for (SDField field : getSearch().allConcreteFields()) + field.getAttributes().forEach((k, a) -> context.setType(asAttributeFeature(k), a.tensorType().orElse(TensorType.empty))); + + // Add query features from rank profile types reached from the "default" profile + QueryProfile profile = queryProfilesOf(getSearch().sourceApplication()).getComponent("default"); + if (profile != null && profile.getType() != null) { + profile.listTypes(CompoundName.empty, Collections.emptyMap()).forEach((prefix, queryProfileType) -> { + for (FieldDescription field : queryProfileType.declaredFields().values()) { + TensorType type = TensorType.empty; // assume the empty (aka double) type by default + if (field.getType() instanceof TensorFieldType) + type = ((TensorFieldType)field.getType()).type().get(); + + String feature = asQueryFeature(prefix.append(field.getName()).toString()); + context.setType(feature, type); + } + }); + } + + return context; + } + + private QueryProfileRegistry queryProfilesOf(ApplicationPackage applicationPackage) { + List<NamedReader> queryProfileFiles = null; + List<NamedReader> queryProfileTypeFiles = null; + try { + queryProfileFiles = applicationPackage.getQueryProfileFiles(); + queryProfileTypeFiles = applicationPackage.getQueryProfileTypeFiles(); + return new QueryProfileXMLReader().read(queryProfileFiles, queryProfileTypeFiles); + } + finally { + NamedReader.closeAll(queryProfileFiles); + NamedReader.closeAll(queryProfileTypeFiles); + } + } + + private String asConstantFeature(String constantName) { + return "constant(\"" + constantName + "\")"; + } + + private String asAttributeFeature(String constantName) { + return "attribute(\"" + constantName + "\")"; + } + + private String asQueryFeature(String constantName) { + return "query(\"" + constantName + "\")"; + } + + /** * A rank setting. The identity of a rank setting is its field name and type (not value). * A rank setting is immutable. */ diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java index df5697de0d5..f4a0365e36e 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java @@ -247,7 +247,7 @@ public class Search implements Serializable, ImmutableSearch { /** * Returns a list of all the fields of this search definition, that is all fields in all documents, in the documents * they inherit, and all extra fields. The caller receives ownership to the list - subsequent changes to it will not - * impact this Search + * impact this * * @return the list of fields in this searchdefinition */ @@ -546,7 +546,7 @@ public class Search implements Serializable, ImmutableSearch { } /** - * Returns the first occurence of an attribute having this name, or null if none + * Returns the first occurrence of an attribute having this name, or null if none * * @param name Name of attribute * @return The Attribute with given name. diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java index ce3ff7cc447..72ba6de7022 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java @@ -31,24 +31,19 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce private Map<String, Attribute> attributes = new java.util.LinkedHashMap<>(); private Map<String, Attribute> importedAttributes = new java.util.LinkedHashMap<>(); - /** - * Flag indicating if a position-attribute has been found - */ + /** Whether this has any position attribute */ private boolean hasPosition = false; public AttributeFields(Search search) { derive(search); } - /** - * Derives everything from a field - */ + /** Derives everything from a field */ @Override protected void derive(ImmutableSDField field, Search search) { if (field.usesStructOrMap() && !field.getDataType().equals(PositionDataType.INSTANCE) && - !field.getDataType().equals(DataType.getArray(PositionDataType.INSTANCE))) - { + !field.getDataType().equals(DataType.getArray(PositionDataType.INSTANCE))) { return; // Ignore struct fields for indexed search (only implemented for streaming search) } if (field.isImportedField()) { @@ -58,9 +53,7 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce } } - /** - * Return an attribute by name, or null if it doesn't exist - */ + /** Returns an attribute by name, or null if it doesn't exist */ public Attribute getAttribute(String attributeName) { return attributes.get(attributeName); } @@ -69,9 +62,7 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce return getAttribute(attributeName) != null; } - /** - * Derives one attribute. TODO: Support non-default named attributes - */ + /** Derives one attribute. TODO: Support non-default named attributes */ private void deriveAttributes(ImmutableSDField field) { for (Attribute fieldAttribute : field.getAttributes().values()) { deriveAttribute(field, fieldAttribute); @@ -107,9 +98,7 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce } } - /** - * Returns a read only attribute iterator - */ + /** Returns a read only attribute iterator */ public Iterator attributeIterator() { return attributes().iterator(); } @@ -201,4 +190,5 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce } } } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 01d3449573c..d0f705d4c9a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -41,7 +41,6 @@ import java.util.Optional; * * @author bratseth */ -// TODO: Verify types of macros // TODO: Avoid name conflicts across models for constants public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { @@ -84,6 +83,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil Signature signature = chooseSignature(model, store.arguments().signature()); String output = chooseOutput(signature, store.arguments().output()); RankingExpression expression = model.expressions().get(output); + verifyRequiredMacros(expression, model.requiredMacros(), profile); store.writeConverted(expression); model.constants().forEach((k, v) -> transformConstant(store, profile, k, v)); @@ -168,6 +168,44 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } /** + * Verify that the macros referred in the given expression exists in the given rank profile, + * and return tensors of the types specified in requiredMacros. + */ + private void verifyRequiredMacros(RankingExpression expression, Map<String, TensorType> requiredMacros, + RankProfile profile) { + List<String> macroNames = new ArrayList<>(); + addMacroNamesIn(expression.getRoot(), macroNames); + for (String macroName : macroNames) { + TensorType requiredType = requiredMacros.get(macroName); + if (requiredType == null) continue; // Not a required macro + + RankProfile.Macro macro = profile.getMacros().get(macroName); + if (macro == null) + throw new IllegalArgumentException("Model refers Placeholder '" + macroName + + "' of type " + requiredType + " but this macro is not present in " + + profile); + TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext()); + if ( ! actualType.isAssignableTo(requiredType)) + throw new IllegalArgumentException("Model refers Placeholder '" + macroName + + "' of type " + requiredType + + " which must be produced by a macro in the rank profile, but " + + "this macro produces type " + actualType + " in " + profile); + } + } + + private void addMacroNamesIn(ExpressionNode node, List<String> names) { + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode)node; + if (referenceNode.getOutput() == null) // macro references cannot specify outputs + names.add(referenceNode.getName()); + } + else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode)node).children()) + addMacroNamesIn(child, names); + } + } + + /** * Provides read/write access to the correct directories of the application package given by the feature arguments */ private static class ModelStore { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java index d08092003f1..b85cb88bf2e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java @@ -5,6 +5,7 @@ import com.yahoo.io.reader.NamedReader; import com.yahoo.search.query.profile.config.QueryProfileXMLReader; import com.yahoo.config.application.api.ApplicationPackage; +import java.util.Collections; import java.util.List; /** @@ -13,17 +14,16 @@ import java.util.List; * * @author bratseth */ -// TODO: Move into QueryProfiles public class QueryProfilesBuilder { /** Build the set of query profiles for an application package */ public QueryProfiles build(ApplicationPackage applicationPackage) { - List<NamedReader> queryProfileTypeFiles=null; - List<NamedReader> queryProfileFiles=null; + List<NamedReader> queryProfileTypeFiles = null; + List<NamedReader> queryProfileFiles = null; try { - queryProfileTypeFiles=applicationPackage.getQueryProfileTypeFiles(); - queryProfileFiles=applicationPackage.getQueryProfileFiles(); - return new QueryProfiles(new QueryProfileXMLReader().read(queryProfileTypeFiles,queryProfileFiles)); + queryProfileTypeFiles = applicationPackage.getQueryProfileTypeFiles(); + queryProfileFiles = applicationPackage.getQueryProfileFiles(); + return new QueryProfiles(new QueryProfileXMLReader().read(queryProfileTypeFiles, queryProfileFiles)); } finally { NamedReader.closeAll(queryProfileTypeFiles); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 89f1a9f785c..c24e886c83d 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -49,14 +49,8 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testMinimalTensorFlowReference() throws ParseException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); assertConstant("Variable_1", search, Optional.of(10L)); assertConstant("Variable", search, Optional.of(7840L)); @@ -64,14 +58,8 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testNestedTensorFlowReference() throws ParseException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: 5 + sum(tensorflow('mnist_softmax/saved'))" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "5 + sum(tensorflow('mnist_softmax/saved'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); assertConstant("Variable_1", search, Optional.of(10L)); assertConstant("Variable", search, Optional.of(7840L)); @@ -79,41 +67,23 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingSignature() throws ParseException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_default')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved', 'serving_default')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test public void testTensorFlowReferenceSpecifyingSignatureAndOutput() throws ParseException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_default', 'y')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved', 'serving_default', 'y')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test public void testTensorFlowReferenceSpecifyingNonExistingSignature() throws ParseException { try { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_defaultz')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved', 'serving_defaultz')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } @@ -128,14 +98,8 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingNonExistingOutput() throws ParseException { try { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_default', 'x')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved', 'serving_default', 'x')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } @@ -149,14 +113,8 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testImportingFromStoredExpressions() throws ParseException, IOException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_default')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); assertConstant("Variable_1", search, Optional.of(10L)); assertConstant("Variable", search, Optional.of(7840L)); @@ -168,13 +126,9 @@ public class RankingExpressionWithTensorFlowTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - RankProfileSearchFixture searchFromStored = new RankProfileSearchFixture( - storedApplication, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_default')" + - " }\n" + - " }"); + RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved')", + storedApplication); searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); // Verify that the constants exists, but don't verify the content as we are not // simulating file distribution in this test @@ -212,6 +166,30 @@ public class RankingExpressionWithTensorFlowTestCase { } } + private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) { + return fixtureWith(placeholderExpression, firstPhaseExpression, new StoringApplicationPackage(applicationDir)); + } + + private RankProfileSearchFixture fixtureWith(String placeholderExpression, + String firstPhaseExpression, + StoringApplicationPackage application) { + try { + return new RankProfileSearchFixture( + application, + " rank-profile my_profile {\n" + + " macro Placeholder() {\n" + + " expression: " + placeholderExpression + + " }\n" + + " first-phase {\n" + + " expression: " + firstPhaseExpression + + " }\n" + + " }"); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + private static class StoringApplicationPackage extends MockApplicationPackage { private final File root; diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java index 04dd3ee9005..a347fbfb3ab 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java @@ -43,28 +43,28 @@ import java.util.regex.Pattern; public class QueryProfile extends FreezableSimpleComponent implements Cloneable { /** Defines the permissible content of this, or null if any content is permissible */ - private QueryProfileType type=null; + private QueryProfileType type = null; /** The value at this query profile - allows non-fields to have values, e.g a=value1, a.b=value2 */ - private Object value=null; + private Object value = null; /** The variants of this, or null if none */ - private QueryProfileVariants variants=null; + private QueryProfileVariants variants = null; /** The resolved variant dimensions of this, or null if none or not resolved yet (is resolved at freeze) */ - private List<String> resolvedDimensions=null; + private List<String> resolvedDimensions = null; /** The query profiles inherited by this, or null if none */ - private List<QueryProfile> inherited=null; + private List<QueryProfile> inherited = null; /** The content of this profile. The values may be primitives, substitutable strings or other query profiles */ - private CopyOnWriteContent content=new CopyOnWriteContent(); + private CopyOnWriteContent content = new CopyOnWriteContent(); /** * Field override settings: fieldName→OverrideValue. These overrides the override * setting in the type (if any) of this field). If there are no query profile level settings, this is null. */ - private Map<String,Boolean> overridable=null; + private Map<String,Boolean> overridable = null; /** * Creates a new query profile from an id. @@ -108,26 +108,26 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable /** Adds a profile to the end of the inherited list of this. Throws an exception if this is frozen. */ public void addInherited(QueryProfile profile) { - addInherited(profile,(DimensionValues)null); + addInherited(profile, (DimensionValues)null); } public final void addInherited(QueryProfile profile,String[] dimensionValues) { - addInherited(profile,DimensionValues.createFrom(dimensionValues)); + addInherited(profile, DimensionValues.createFrom(dimensionValues)); } /** Adds a profile to the end of the inherited list of this for the given variant. Throws an exception if this is frozen. */ public void addInherited(QueryProfile profile, DimensionValues dimensionValues) { ensureNotFrozen(); - DimensionBinding dimensionBinding=DimensionBinding.createFrom(getDimensions(),dimensionValues); + DimensionBinding dimensionBinding=DimensionBinding.createFrom(getDimensions(), dimensionValues); if (dimensionBinding.isNull()) { - if (inherited==null) - inherited=new ArrayList<>(); + if (inherited == null) + inherited = new ArrayList<>(); inherited.add(profile); } else { - if (variants==null) - variants=new QueryProfileVariants(dimensionBinding.getDimensions(), this); + if (variants == null) + variants = new QueryProfileVariants(dimensionBinding.getDimensions(), this); variants.inherit(profile,dimensionBinding.getValues()); } } @@ -152,13 +152,13 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable * @throws IllegalStateException if this is frozen */ public Boolean isDeclaredOverridable(String name, Map<String,String> context) { - return isDeclaredOverridable(new CompoundName(name),DimensionBinding.createFrom(getDimensions(),context)); + return isDeclaredOverridable(new CompoundName(name), DimensionBinding.createFrom(getDimensions(), context)); } /** Sets the dimensions over which this may vary. Note: This will erase any currently defined variants */ public void setDimensions(String[] dimensions) { ensureNotFrozen(); - variants=new QueryProfileVariants(dimensions, this); + variants = new QueryProfileVariants(dimensions, this); } /** Returns the value set at this node, to allow non-leafs to have values. Returns null if none. */ @@ -166,17 +166,17 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable public void setValue(Object value) { ensureNotFrozen(); - this.value=value; + this.value = value; } /** Returns the variant dimensions to be used in this - an unmodifiable list of dimension names */ public List<String> getDimensions() { if (isFrozen()) return resolvedDimensions; - if (variants!=null) return variants.getDimensions(); - if (inherited==null) return null; + if (variants != null) return variants.getDimensions(); + if (inherited == null) return null; for (QueryProfile inheritedProfile : inherited) { - List<String> inheritedDimensions=inheritedProfile.getDimensions(); - if (inheritedDimensions!=null) return inheritedDimensions; + List<String> inheritedDimensions = inheritedProfile.getDimensions(); + if (inheritedDimensions != null) return inheritedDimensions; } return null; } @@ -187,8 +187,8 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable * Sets the overridability of a field in this profile, * this overrides the corresponding setting in the type (if any) */ - public final void setOverridable(String fieldName, boolean overridable, Map<String,String> context) { - setOverridable(new CompoundName(fieldName), overridable,DimensionBinding.createFrom(getDimensions(), context)); + public final void setOverridable(String fieldName, boolean overridable, Map<String, String> context) { + setOverridable(new CompoundName(fieldName), overridable, DimensionBinding.createFrom(getDimensions(), context)); } /** @@ -241,8 +241,8 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable Map<String,Object> values=visitor.getResult(); if (substitution==null) return values; - for (Map.Entry<String,Object> entry : values.entrySet()) { - if (entry.getValue().getClass()==String.class) continue; // Shortcut + for (Map.Entry<String, Object> entry : values.entrySet()) { + if (entry.getValue().getClass() == String.class) continue; // Shortcut if (entry.getValue() instanceof SubstituteString) entry.setValue(((SubstituteString)entry.getValue()).substitute(context,substitution)); } @@ -253,7 +253,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable * Lists types reachable from this, indexed by the prefix having that type. * If this is itself typed, this' type will be included with an empty prefix */ - Map<CompoundName, QueryProfileType> listTypes(CompoundName prefix, Map<String, String> context) { + public Map<CompoundName, QueryProfileType> listTypes(CompoundName prefix, Map<String, String> context) { DimensionBinding dimensionBinding = DimensionBinding.createFrom(getDimensions(), context); AllTypesQueryProfileVisitor visitor = new AllTypesQueryProfileVisitor(prefix); accept(visitor, dimensionBinding, null); @@ -264,8 +264,8 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable * Lists references reachable from this. */ Set<CompoundName> listReferences(CompoundName prefix, Map<String, String> context) { - DimensionBinding dimensionBinding=DimensionBinding.createFrom(getDimensions(),context); - AllReferencesQueryProfileVisitor visitor=new AllReferencesQueryProfileVisitor(prefix); + DimensionBinding dimensionBinding = DimensionBinding.createFrom(getDimensions(),context); + AllReferencesQueryProfileVisitor visitor = new AllReferencesQueryProfileVisitor(prefix); accept(visitor,dimensionBinding,null); return visitor.getResult(); } @@ -292,40 +292,40 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable public final Object get(String name) { return get(name,(Map<String,String>)null); } /** Returns a value from this using the given property context for resolution and using this for substitution */ - public final Object get(String name, Map<String,String> context) { - return get(name,context,null); + public final Object get(String name, Map<String, String> context) { + return get(name, context, null); } /** Returns a value from this using the given dimensions for resolution */ public final Object get(String name, String[] dimensionValues) { - return get(name,dimensionValues,null); + return get(name, dimensionValues,null); } public final Object get(String name, String[] dimensionValues, Properties substitution) { - return get(name,DimensionValues.createFrom(dimensionValues),substitution); + return get(name, DimensionValues.createFrom(dimensionValues), substitution); } /** Returns a value from this using the given dimensions for resolution */ public final Object get(String name, DimensionValues dimensionValues, Properties substitution) { - return get(name,DimensionBinding.createFrom(getDimensions(),dimensionValues),substitution); + return get(name, DimensionBinding.createFrom(getDimensions(), dimensionValues), substitution); } public final Object get(String name, Map<String,String> context, Properties substitution) { - return get(name,DimensionBinding.createFrom(getDimensions(),context),substitution); + return get(name, DimensionBinding.createFrom(getDimensions(), context), substitution); } public final Object get(CompoundName name, Map<String,String> context, Properties substitution) { - return get(name,DimensionBinding.createFrom(getDimensions(),context),substitution); + return get(name, DimensionBinding.createFrom(getDimensions(), context), substitution); } final Object get(String name, DimensionBinding binding,Properties substitution) { - return get(new CompoundName(name),binding,substitution); + return get(new CompoundName(name), binding, substitution); } final Object get(CompoundName name, DimensionBinding binding, Properties substitution) { - Object node=get(name,binding); - if (node!=null && node.getClass()==String.class) return node; // Shortcut - if (node instanceof SubstituteString) return ((SubstituteString)node).substitute(binding.getContext(),substitution); + Object node = get(name, binding); + if (node != null && node.getClass() == String.class) return node; // Shortcut + if (node instanceof SubstituteString) return ((SubstituteString)node).substitute(binding.getContext(), substitution); return node; } @@ -431,14 +431,14 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable @Override public QueryProfile clone() { if (isFrozen()) return this; - QueryProfile clone=(QueryProfile)super.clone(); - if (variants !=null) + QueryProfile clone = (QueryProfile)super.clone(); + if (variants != null) clone.variants = variants.clone(); - if (inherited!=null) - clone.inherited=new ArrayList<>(inherited); + if (inherited != null) + clone.inherited = new ArrayList<>(inherited); - if (this.content!=null) - clone.content=content.clone(); + if (this.content != null) + clone.content = content.clone(); return clone; } @@ -454,7 +454,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable /** Throws IllegalArgumentException if the given string is not a valid query profile name */ public static void validateName(String name) { - Matcher nameMatcher=namePattern.matcher(name); + Matcher nameMatcher = namePattern.matcher(name); if ( ! nameMatcher.matches()) throw new IllegalArgumentException("Illegal name '" + name + "'"); } @@ -467,7 +467,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable setNode(name, value, null, binding, registry); } catch (IllegalArgumentException e) { - throw new IllegalArgumentException("Could not set '" + name + "' to '" + value + "'",e); + throw new IllegalArgumentException("Could not set '" + name + "' to '" + value + "'", e); } } @@ -708,19 +708,19 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable /** Do a variant-aware content lookup in this */ protected Object localLookup(String name, DimensionBinding dimensionBinding) { - Object node=null; - if ( variants!=null && !dimensionBinding.isNull()) - node=variants.get(name,type,true,dimensionBinding); - if (node==null) - node=content==null ? null : content.get(name); + Object node = null; + if ( variants != null && !dimensionBinding.isNull()) + node = variants.get(name,type,true,dimensionBinding); + if (node == null) + node = content == null ? null : content.get(name); return node; } // ----------------- Private ---------------------------------------------------------------------------------- private Boolean isDeclaredOverridable(CompoundName name,DimensionBinding dimensionBinding) { - QueryProfile parent= lookupParentExact(name, true, dimensionBinding); - if (parent.overridable==null) return null; + QueryProfile parent = lookupParentExact(name, true, dimensionBinding); + if (parent.overridable == null) return null; return parent.overridable.get(name.last()); } @@ -729,10 +729,10 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable * this overrides the corresponding setting in the type (if any) */ private void setOverridable(CompoundName fieldName,boolean overridable,DimensionBinding dimensionBinding) { - QueryProfile parent= lookupParentExact(fieldName, true, dimensionBinding); - if (parent.overridable==null) - parent.overridable=new HashMap<>(); - parent.overridable.put(fieldName.last(),overridable); + QueryProfile parent = lookupParentExact(fieldName, true, dimensionBinding); + if (parent.overridable == null) + parent.overridable = new HashMap<>(); + parent.overridable.put(fieldName.last(), overridable); } /** Sets a value to a (possibly non-local) node. The parent query profile holding the value is returned */ @@ -740,7 +740,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable DimensionBinding dimensionBinding, QueryProfileRegistry registry) { ensureNotFrozen(); if (name.isCompound()) { - QueryProfile parent= getQueryProfileExact(name.first(), true, dimensionBinding); + QueryProfile parent = getQueryProfileExact(name.first(), true, dimensionBinding); parent.setNode(name.rest(), value,parentType, dimensionBinding.createFor(parent.getDimensions()), registry); } else { @@ -773,8 +773,8 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable * @return the created profile, or null if not present, and create is false */ private QueryProfile getQueryProfileExact(String localName, boolean create, DimensionBinding dimensionBinding) { - Object node=localExactLookup(localName, dimensionBinding); - if (node!=null && node instanceof QueryProfile) { + Object node = localExactLookup(localName, dimensionBinding); + if (node != null && node instanceof QueryProfile) { return (QueryProfile)node; } if (!create) return null; @@ -826,7 +826,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable } } - private static final Pattern namePattern=Pattern.compile("[$a-zA-Z_/][-$a-zA-Z0-9_/()]*"); + private static final Pattern namePattern = Pattern.compile("[$a-zA-Z_/][-$a-zA-Z0-9_/()]*"); /** * Returns a compiled version of this which produces faster lookup times diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java b/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java index 46aa33174f7..095d83d2887 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java @@ -39,8 +39,8 @@ public class QueryProfileXMLReader { * @throws RuntimeException if <code>directory</code> is not a readable directory, or if there is some error in the XML */ public QueryProfileRegistry read(String directory) { - List<NamedReader> queryProfileReaders=new ArrayList<>(); - List<NamedReader> queryProfileTypeReaders=new ArrayList<>(); + List<NamedReader> queryProfileReaders = new ArrayList<>(); + List<NamedReader> queryProfileTypeReaders = new ArrayList<>(); try { File dir=new File(directory); if ( !dir.isDirectory() ) throw new IllegalArgumentException("Could not read query profiles: '" + @@ -86,16 +86,16 @@ public class QueryProfileXMLReader { * Read the XML file readers into a registry. This does not close the readers. * This method is used directly from the admin system. */ - public QueryProfileRegistry read(List<NamedReader> queryProfileTypeReaders,List<NamedReader> queryProfileReaders) { - QueryProfileRegistry registry=new QueryProfileRegistry(); + public QueryProfileRegistry read(List<NamedReader> queryProfileTypeReaders, List<NamedReader> queryProfileReaders) { + QueryProfileRegistry registry = new QueryProfileRegistry(); // Phase 1 - List<Element> queryProfileTypeElements=createQueryProfileTypes(queryProfileTypeReaders,registry.getTypeRegistry()); - List<Element> queryProfileElements=createQueryProfiles(queryProfileReaders,registry); + List<Element> queryProfileTypeElements = createQueryProfileTypes(queryProfileTypeReaders, registry.getTypeRegistry()); + List<Element> queryProfileElements = createQueryProfiles(queryProfileReaders, registry); // Phase 2 - fillQueryProfileTypes(queryProfileTypeElements,registry.getTypeRegistry()); - fillQueryProfiles(queryProfileElements,registry); + fillQueryProfileTypes(queryProfileTypeElements, registry.getTypeRegistry()); + fillQueryProfiles(queryProfileElements, registry); return registry; } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java index 305de2a3c70..c826d834d47 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java @@ -102,13 +102,14 @@ public class QueryProfileType extends FreezableSimpleComponent { * * @throws IllegalStateException if this is frozen */ - public Map<String,FieldDescription> declaredFields() { + public Map<String, FieldDescription> declaredFields() { ensureNotFrozen(); return Collections.unmodifiableMap(fields); } /** * Returns true if <i>this</i> is declared strict. + * * @throws IllegalStateException if this is frozen */ public boolean isDeclaredStrict() { @@ -118,6 +119,7 @@ public class QueryProfileType extends FreezableSimpleComponent { /** * Returns true if <i>this</i> is declared as match as path. + * * @throws IllegalStateException if this is frozen */ public boolean getDeclaredMatchAsPath() { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java index a4d3c111356..5f8daa69ecf 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.TensorType; import java.util.Arrays; @@ -81,7 +82,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { } @Override - public ValueType getType(String name) { + public TensorType getType(String name) { Integer index = nameToIndex().get(name); if (index == null) return null; return values[index].type(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index a1e79df95e3..861f9565d66 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -4,7 +4,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Set; @@ -25,21 +24,10 @@ public abstract class Context implements EvaluationContext { */ public abstract Value get(String name); - /** Returns the type of the value of the given variable as a tensor type, or null if there is no such variable */ - @Override - public TensorType getTensorType(String name) { - ValueType type = getType(name); - if (type == null) return null; - return type.tensorType(); - } - /** Returns a variable as a tensor */ @Override public Tensor getTensor(String name) { return get(name).asTensor(); } - /** Returns the type of the value of the given variable, or null if there is no such variable */ - public abstract ValueType getType(String name); - /** * <p>Returns the value of a <i>structured variable</i> on the form * <code>name(argument*)(.output)?</code>, where <i>argument</i> is any diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index c85a8f1c7e1..3ac11cff0cb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; /** * A value which acts as a double in numerical context. @@ -13,7 +14,7 @@ import com.yahoo.tensor.Tensor; public abstract class DoubleCompatibleValue extends Value { @Override - public ValueType type() { return ValueType.doubleType(); } + public TensorType type() { return TensorType.empty; } @Override public boolean hasDouble() { return true; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java index 34cd75df9cb..0625e8506cc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.TensorType; /** * A variant of an array context variant which supports faster binding of variables but slower lookup @@ -67,7 +68,7 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { } @Override - public ValueType getType(String name) { return ValueType.doubleType(); } + public TensorType getType(String name) { return TensorType.empty; } /** Perform a slow lookup by name */ @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index 2672fe6cd8e..39efe641f26 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.evaluation; +import com.yahoo.tensor.TensorType; + import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -13,7 +15,7 @@ import java.util.Set; */ public class MapContext extends Context { - private Map<String,Value> bindings=new HashMap<>(); + private Map<String, Value> bindings = new HashMap<>(); private boolean frozen = false; @@ -21,16 +23,6 @@ public class MapContext extends Context { } /** - * Freezes this. - * Returns this for convenience. - */ - public MapContext freeze() { - if ( ! frozen) - bindings = Collections.unmodifiableMap(bindings); - return this; - } - - /** * Creates a map context from a map. * The ownership of the map is transferred to this - it cannot be further modified by the caller. * All the Values of the map will be frozen. @@ -41,27 +33,32 @@ public class MapContext extends Context { boundValue.freeze(); } + /** + * Freezes this. + * Returns this for convenience. + */ + public MapContext freeze() { + if ( ! frozen) + bindings = Collections.unmodifiableMap(bindings); + return this; + } + /** Returns the type of the given value key, or null if it is not bound. */ @Override - public ValueType getType(String key) { + public TensorType getType(String key) { Value value = bindings.get(key); if (value == null) return null; return value.type(); } - /** - * Returns the value of a key. 0 is returned if the given key is not bound in this. - */ + /** Returns the value of a key. 0 is returned if the given key is not bound in this. */ @Override public Value get(String key) { return bindings.getOrDefault(key, DoubleValue.zero); } /** - * Sets the value of a key. - * The value is frozen by this. - * - * @since 5.1.5 + * Sets the value of a key.The value is frozen by this. */ @Override public void put(String key,Value value) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index 874b41ec3e1..c60507310f1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -5,6 +5,7 @@ import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; /** * A string value. @@ -29,7 +30,7 @@ public class StringValue extends Value { } @Override - public ValueType type() { return ValueType.doubleType(); } + public TensorType type() { return TensorType.empty; } /** Returns the hashcode of this, to enable strings to be encoded (with reasonable safely) as doubles for optimization */ @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index f1c65dc79d3..c6e456f285d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -6,6 +6,7 @@ import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; /** * A Value containing a tensor. @@ -25,7 +26,7 @@ public class TensorValue extends Value { } @Override - public ValueType type() { return ValueType.doubleType(); } + public TensorType type() { return TensorType.empty; } @Override public double asDouble() { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java new file mode 100644 index 00000000000..f2c4ca58f6d --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java @@ -0,0 +1,27 @@ +package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.HashMap; +import java.util.Map; + +/** + * A context which only contains type information. + * + * @author bratseth + */ +public class TypeMapContext implements TypeContext { + + private final Map<String, TensorType> featureTypes = new HashMap<>(); + + public void setType(String name, TensorType type) { + featureTypes.put(name, type); + } + + @Override + public TensorType getType(String name) { + return featureTypes.get(name); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 856bfb3638d..59d2d95b879 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -13,14 +13,14 @@ import com.yahoo.tensor.TensorType; * Concrete subclasses of this provides implementations of these methods or throws * UnsupportedOperationException if the operation is not supported. * - * @author bratseth + * @author bratseth */ public abstract class Value { private boolean frozen=false; /** Returns the type of this value */ - public abstract ValueType type(); + public abstract TensorType type(); /** Returns this value as a double, or throws UnsupportedOperationException if it cannot be represented as a double */ public abstract double asDouble(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java deleted file mode 100644 index 046ad7861ef..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ValueType.java +++ /dev/null @@ -1,32 +0,0 @@ -package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -import com.yahoo.tensor.TensorType; - -/** - * The type of a ranking expression value - either a double or a tensor. - * - * @author bratseth - */ -public class ValueType { - - private static final ValueType doubleValueType = new ValueType(TensorType.empty); - - private final TensorType tensorType; - - private ValueType(TensorType tensorType) { - this.tensorType = tensorType; - } - - /** Returns true if this is the double type */ - public boolean isDouble() { return tensorType.rank() == 0; } - - /** The type of this as a tensor type. The double type is the empty tensor type (rank 0) */ - public TensorType tensorType() { return tensorType; } - - /** Returns the type representing a double */ - public static ValueType doubleType() { return doubleValueType; } - - /** Returns a type representing the given tensor type */ - public static ValueType of(TensorType type) { return new ValueType(type); } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java index b4e126f69e0..8ee4cdbf297 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java @@ -4,10 +4,11 @@ package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Deque; @@ -25,7 +26,7 @@ public class GBDTForestNode extends ExpressionNode { } @Override - public final ValueType type(Context context) { return ValueType.doubleType(); } + public final TensorType type(TypeContext context) { return TensorType.empty; } @Override public final Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java index f085194a7df..aac635b2545 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java @@ -4,10 +4,11 @@ package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Deque; @@ -50,7 +51,7 @@ public final class GBDTNode extends ExpressionNode { public final double[] values() { return values; } @Override - public final ValueType type(Context context) { return ValueType.doubleType(); } + public final TensorType type(TypeContext context) { return TensorType.empty; } @Override public final Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index 816ef38e128..cdcb4df0360 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -147,14 +147,12 @@ class OperationMapper { return operation.get().map(params); } params.signature().importWarning("TensorFlow operation '" + params.node().getOp() + - "' in node '" + params.node().getName() + "' is not supported."); + "' in node '" + params.node().getName() + "' is not supported."); return Optional.empty(); } - /* - * Operations - */ + // Operations --------------------------------- private static Optional<TypedTensorFunction> constant(TensorFlowImporter.Parameters params) { Tensor value = AttrValueConverter.toVespaTensor(params.node(), "value"); @@ -209,10 +207,11 @@ class OperationMapper { TensorType type = params.result().arguments().get(name); if (type == null) { throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name + - "', but there is no such placeholder"); + "', but there is no such placeholder"); } + params.result().requiredMacro(name, type); // Included literally in the expression and so must be produced by a separate macro in the rank profile - TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(name)); + TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(name, type)); return Optional.of(output); } @@ -227,7 +226,7 @@ class OperationMapper { } private static Optional<TypedTensorFunction> reshape(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 2)) { + if ( ! checkInputs(params, 2)) { return Optional.empty(); } List<Optional<TypedTensorFunction>> inputs = params.inputs(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 3a6b3f23a1d..6d78b501fdc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -14,7 +14,6 @@ import org.tensorflow.framework.TensorInfo; import java.io.File; import java.io.IOException; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -109,10 +108,10 @@ public class TensorFlowImporter { } Optional<TypedTensorFunction> function = OperationMapper.map(params); - if (!function.isPresent()) { + if ( ! function.isPresent()) { return Optional.empty(); } - if (!controlDependenciesArePresent(params)) { + if ( ! controlDependenciesArePresent(params)) { return Optional.empty(); } params.imported().put(nodeName, function.get()); @@ -185,6 +184,7 @@ public class TensorFlowImporter { /** Parameter object to hold important data while importing */ static final class Parameters { + private final TensorFlowImporter owner; private final GraphDef graph; private final SavedModelBundle model; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java index 60aaf8ddce1..fe725e50a3f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java @@ -27,11 +27,13 @@ public class TensorFlowModel { private final Map<String, Tensor> constants = new HashMap<>(); private final Map<String, RankingExpression> expressions = new HashMap<>(); private final Map<String, RankingExpression> macros = new HashMap<>(); + private final Map<String, TensorType> requiredMacros = new HashMap<>(); void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } void constant(String name, Tensor constant) { constants.put(name, constant); } void expression(String name, RankingExpression expression) { expressions.put(name, expression); } void macro(String name, RankingExpression expression) { macros.put(name, expression); } + void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } /** Returns the given signature. If it does not already exist it is added to this. */ Signature signature(String name) { @@ -51,11 +53,12 @@ public class TensorFlowModel { */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } - /** - * Returns an immutable map of expressions that can be overridden - such as PlaceholderWithDefault/ - */ + /** Returns an immutable map of macros that are part of this model */ public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); } + /** Returns an immutable map of the macros that must be provided by the environment running this model */ + public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); } + /** Returns an immutable map of the signatures of this */ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java index d45037b6044..fc6428a4c33 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java @@ -4,8 +4,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.functions.Join; import java.util.ArrayDeque; @@ -80,14 +80,14 @@ public final class ArithmeticNode extends CompositeNode { } @Override - public ValueType type(Context context) { + public TensorType type(TypeContext context) { // Compute type using tensor types as arithmetic operators are supported on tensors // and is correct also in the special case of doubles. // As all our functions are type-commutative, we don't need to take operator precedence into account - TensorType type = children.get(0).type(context).tensorType(); + TensorType type = children.get(0).type(context); for (int i = 1; i < children.size(); i++) - type = Join.outputType(type, children.get(i).type(context).tensorType()); - return ValueType.of(type); + type = Join.outputType(type, children.get(i).type(context)); + return type; } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java index fdbb22093ea..7601c0e6180 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java @@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayList; import java.util.Deque; @@ -48,8 +49,8 @@ public class ComparisonNode extends BooleanNode { } @Override - public ValueType type(Context context) { - return ValueType.doubleType(); // by definition + public TensorType type(TypeContext context) { + return TensorType.empty; // by definition } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java index e6074a5f745..1ea8d03f0eb 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java @@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Deque; @@ -48,7 +49,7 @@ public final class ConstantNode extends ExpressionNode { } @Override - public ValueType type(Context context) { return value.type(); } + public TensorType type(TypeContext context) { return value.type(); } @Override public Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java index 8404226c33b..fd9fab99db8 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java @@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; @@ -49,7 +50,7 @@ public final class EmbracedNode extends CompositeNode { } @Override - public ValueType type(Context context) { + public TensorType type(TypeContext context) { return value.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java index 5d06a562b5d..477f4db4981 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java @@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.io.Serializable; import java.util.Deque; @@ -47,7 +48,7 @@ public abstract class ExpressionNode implements Serializable { * @param context the variable type bindings to use for this evaluation * @throws IllegalArgumentException if there are variables which are not bound in the given map */ - public abstract ValueType type(Context context); + public abstract TensorType type(TypeContext context); /** * Returns the value of evaluating this expression over the given context. diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java index b187b8f029c..79515229019 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java @@ -4,7 +4,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.functions.Join; import java.util.ArrayList; @@ -66,16 +67,16 @@ public final class FunctionNode extends CompositeNode { } @Override - public ValueType type(Context context) { + public TensorType type(TypeContext context) { if (arguments.expressions().size() == 0) - return ValueType.doubleType(); + return TensorType.empty; - ValueType argument1Type = arguments.expressions().get(0).type(context); + TensorType argument1Type = arguments.expressions().get(0).type(context); if (arguments.expressions().size() == 1) return argument1Type; - ValueType argument2Type = arguments.expressions().get(1).type(context); - return ValueType.of(Join.outputType(argument1Type.tensorType(), argument2Type.tensorType())); + TensorType argument2Type = arguments.expressions().get(1).type(context); + return Join.outputType(argument1Type, argument2Type); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java index fcd40bed4d0..e42884ecc05 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java @@ -4,8 +4,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; @@ -48,7 +48,7 @@ public class GeneratorLambdaFunctionNode extends CompositeNode { } @Override - public ValueType type(Context context) { return ValueType.of(type); } + public TensorType type(TypeContext context) { return type; } /** Evaluate this in a context which must have the arguments bound */ @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java index b9866bec027..076df327044 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java @@ -3,9 +3,13 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.List; /** * A conditional branch of a ranking expression. @@ -71,9 +75,9 @@ public final class IfNode extends CompositeNode { } @Override - public ValueType type(Context context) { - ValueType trueType = trueExpression.type(context); - ValueType falseType = falseExpression.type(context); + public TensorType type(TypeContext context) { + TensorType trueType = trueExpression.type(context); + TensorType falseType = falseExpression.type(context); if ( ! trueType.equals(falseType)) throw new IllegalArgumentException("An if expression must produce a value of the same type in both " + "alternatives, but the 'true' type is " + trueType + " while the " + diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index b898529c4b9..da946228291 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -5,7 +5,8 @@ import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; @@ -56,8 +57,8 @@ public class LambdaFunctionNode extends CompositeNode { } @Override - public ValueType type(Context context) { - return ValueType.doubleType(); // by definition - no nested lambdas + public TensorType type(TypeContext context) { + return TensorType.empty; // by definition - no nested lambdas } /** Evaluate this in a context which must have the arguments bound */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java index cf6475238c4..f55ed59b65c 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java @@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Deque; @@ -31,7 +32,7 @@ public final class NameNode extends ExpressionNode { } @Override - public ValueType type(Context context) { throw new RuntimeException("Named nodes can not have a type"); } + public TensorType type(TypeContext context) { throw new RuntimeException("Named nodes can not have a type"); } @Override public Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java index 2e685a6c8ab..9cbe5f98c72 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java @@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; @@ -37,7 +38,7 @@ public class NegativeNode extends CompositeNode { } @Override - public ValueType type(Context context) { + public TensorType type(TypeContext context) { return value.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java index c4b940f1bd6..e7041600635 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java @@ -3,7 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; @@ -37,7 +38,7 @@ public class NotNode extends BooleanNode { } @Override - public ValueType type(Context context) { + public TensorType type(TypeContext context) { return value.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index e5176f9966d..f79297f7773 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -5,7 +5,8 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayDeque; import java.util.Deque; @@ -45,9 +46,8 @@ public final class ReferenceNode extends CompositeNode { return new ReferenceNode(name, arguments, output); } - public String getOutput() { - return output; - } + /** Returns the specific output this references, or null if none specified */ + public String getOutput() { return output; } /** Returns a copy of this node with a modified output */ public ReferenceNode setOutput(String output) { @@ -106,7 +106,7 @@ public final class ReferenceNode extends CompositeNode { } @Override - public ValueType type(Context context) { + public TensorType type(TypeContext context) { // Don't support outputs of different type, for simplicity return context.getType(name); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java index a8b82c560f7..a7b82f4753f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java @@ -6,8 +6,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayList; import java.util.Deque; @@ -59,8 +60,8 @@ public class SetMembershipNode extends BooleanNode { } @Override - public ValueType type(Context context) { - return ValueType.doubleType(); + public TensorType type(TypeContext context) { + return TensorType.empty; } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 97cfa2a5350..e4c381972e9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -5,10 +5,10 @@ import com.google.common.annotations.Beta; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.evaluation.ValueType; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.functions.PrimitiveTensorFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; @@ -64,7 +64,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public ValueType type(Context context) { return ValueType.of(function.type(context)); } + public TensorType type(TypeContext context) { return function.type(context); } @Override public Value evaluate(Context context) { @@ -111,8 +111,8 @@ public class TensorFunctionNode extends CompositeNode { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(EvaluationContext context) { - return expression.type((Context)context).tensorType(); + public TensorType type(TypeContext context) { + return expression.type(context); } @Override diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java index b59b4750911..445ccf231a7 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -2,10 +2,12 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.TensorType; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; /** * @author lesters @@ -15,6 +17,18 @@ public class DropoutImportTestCase { @Test public void testDropoutImport() { TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/dropout/saved"); + + // Check (provided) macros + assertEquals(1, model.get().macros().size()); + assertTrue(model.get().macros().containsKey("training/input")); + assertEquals("constant(\"training/input\")", model.get().macros().get("training/input").getRoot().toString()); + + // Check required macros + assertEquals(1, model.get().requiredMacros().size()); + assertTrue(model.get().requiredMacros().containsKey("X")); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), + model.get().requiredMacros().get("X")); + TensorFlowModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index f12b9a2c628..01dd15d5fa0 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -8,6 +8,7 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; /** * @author bratseth @@ -33,6 +34,15 @@ public class MnistSoftmaxImportTestCase { constant1.type()); assertEquals(10, constant1.size()); + // Check (provided) macros + assertEquals(0, model.get().macros().size()); + + // Check required macros + assertEquals(1, model.get().requiredMacros().size()); + assertTrue(model.get().requiredMacros().containsKey("Placeholder")); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), + model.get().requiredMacros().get("Placeholder")); + // Check signatures assertEquals(1, model.get().signatures().size()); TensorFlowModel.Signature signature = model.get().signatures().get("serving_default"); diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java index cfa9ce670f6..6095134b7a2 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java @@ -15,14 +15,14 @@ import static com.yahoo.vespa.http.client.SessionFactory.createTimeoutExecutor; public class FeedClientFactory { /** - * Creates FeedClient. + * Creates a FeedClient. + * * @param sessionParams parameters for connection, hosts, cluster configurations and more. * @param resultCallback on each result, this callback is called. * @return newly created FeedClient API object. */ - public static FeedClient create( - SessionParams sessionParams, - FeedClient.ResultCallback resultCallback) { + public static FeedClient create(SessionParams sessionParams, FeedClient.ResultCallback resultCallback) { return new FeedClientImpl(sessionParams, resultCallback, createTimeoutExecutor()); } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java index 138be61de80..65f56f72a58 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java @@ -20,6 +20,7 @@ import java.util.List; */ // This should be an interface, but in order to be binary compatible during refactoring we made it abstract. public class Result { + public enum ResultType { OPERATION_EXECUTED, TRANSITIVE_ERROR, @@ -106,12 +107,10 @@ public class Result { /** * Information in a Result for a single operation sent to a single endpoint. - * - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> - * @since 5.1.20 */ @Immutable public static final class Detail { + private final ResultType resultType; private final Endpoint endpoint; private final Exception exception; @@ -212,4 +211,5 @@ public class Result { } return b.toString(); } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/EndpointResult.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/EndpointResult.java index 5aec46a8fc7..b04248f98a5 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/EndpointResult.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/EndpointResult.java @@ -5,9 +5,11 @@ import com.yahoo.vespa.http.client.Result; /** * Result from a single endpoint. + * * @author dybis */ public class EndpointResult { + private final String operationId; private final Result.Detail detail; @@ -23,4 +25,5 @@ public class EndpointResult { public Result.Detail getDetail() { return detail; } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ErrorCode.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ErrorCode.java index 96afc537c59..445ad5295c1 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ErrorCode.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ErrorCode.java @@ -6,12 +6,12 @@ import com.google.common.annotations.Beta; /** * Return types for the server. * - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> - * @author <a href="mailto:steinar@yahoo-inc.com">Steinar Knutsen</a> - * @since 5.1.20 + * @author Einar M R Rosenvinge + * @author Steinar Knutsen */ @Beta public enum ErrorCode { + OK(true, true), ERROR(false, false), TRANSIENT_ERROR(false, true), @@ -20,7 +20,7 @@ public enum ErrorCode { private boolean success; private boolean _transient; - private ErrorCode(boolean success, boolean _transient) { + ErrorCode(boolean success, boolean _transient) { this.success = success; this._transient = _transient; } @@ -32,4 +32,5 @@ public enum ErrorCode { public boolean isTransient() { return _transient; } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/OperationStatus.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/OperationStatus.java index 7ea4a214cbd..7aec207e0ab 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/OperationStatus.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/OperationStatus.java @@ -10,11 +10,11 @@ import java.util.Iterator; /** * Serialization/deserialization class for the result of a single document operation against Vespa. * - * @author <a href="mailto:steinar@yahoo-inc.com">Steinar Knutsen</a> - * @since 5.1 + * @author Steinar Knutsen */ @Beta public final class OperationStatus { + public static final String IS_CONDITION_NOT_MET = "IS-CONDITION-NOT-MET"; public final String message; public final String operationId; @@ -81,9 +81,7 @@ public final class OperationStatus { return new OperationStatus(message, operationId, errorCode, isConditionNotMet, traceMessage); } - /** - * @return a string representing the status. - */ + /** Returns a string representing the status. */ public String render() { StringBuilder s = new StringBuilder(); Encoder.encode(operationId, s).append(SEPARATOR); @@ -92,4 +90,5 @@ public final class OperationStatus { Encoder.encode(traceMessage, s).append(EOL); return s.toString(); } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ServerResponseException.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ServerResponseException.java index 4c291935916..1800864cd90 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ServerResponseException.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ServerResponseException.java @@ -5,12 +5,13 @@ import com.google.common.annotations.Beta; /** * The request was not processed properly on the server. - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> - * @since 5.1.20 + * + * @author Einar M R Rosenvinge */ @SuppressWarnings("serial") @Beta public class ServerResponseException extends Exception { + private final int responseCode; private final String responseString; @@ -33,5 +34,6 @@ public class ServerResponseException extends Exception { } return responseString; } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java index a16d992324d..903c1ad4842 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java @@ -17,7 +17,7 @@ import java.util.concurrent.TimeUnit; /** * Implementation of FeedClient. It is a thin layer on top of multiClusterHandler and multiClusterResultAggregator. - * + * * @author dybis */ public class FeedClientImpl implements FeedClient { @@ -92,4 +92,5 @@ public class FeedClientImpl implements FeedClient { } return true; } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java index edbe5542bc4..c3b5d9912de 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java @@ -6,7 +6,6 @@ import com.yahoo.vespa.http.client.Result; import com.yahoo.vespa.http.client.Session; import com.yahoo.vespa.http.client.config.SessionParams; import com.yahoo.vespa.http.client.core.ThrottlePolicy; -import com.yahoo.vespa.http.client.core.api.MultiClusterSessionOutputStream; import com.yahoo.vespa.http.client.core.operationProcessor.IncompleteResultsThrottler; import com.yahoo.vespa.http.client.core.operationProcessor.OperationProcessor; @@ -65,4 +64,5 @@ public class SessionImpl implements Session { public int getIncompleteResultQueueSize() { return operationProcessor.getIncompleteResultQueueSize(); } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java index 420f64d4bf3..dd724e90a42 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java @@ -44,8 +44,6 @@ import java.util.zip.GZIPOutputStream; /** * @author Einar M R Rosenvinge - * - * @since 5.1.20 */ @Beta class ApacheGatewayConnection implements GatewayConnection { diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java index 671e6f07dbe..d963ae79227 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java @@ -2,8 +2,6 @@ package com.yahoo.vespa.http.client.core.communication; import com.yahoo.vespa.http.client.core.Document; -import com.yahoo.vespa.http.client.core.EndpointResult; -import com.yahoo.vespa.http.client.core.operationProcessor.EndPointResultFactory; import java.util.ArrayDeque; import java.util.ArrayList; @@ -13,7 +11,7 @@ import java.util.Optional; import java.util.concurrent.TimeUnit; /** - * Document queue that only gives you document operations on documents for which there are no + * Document queue that only gives you document operations on documents for which there are no * already in flight operations for. * * @author dybis @@ -54,8 +52,6 @@ class DocumentQueue { } } - - Document poll(long timeout, TimeUnit unit) throws InterruptedException { synchronized (queue) { long remainingToWait = unit.toMillis(timeout); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index b396f831de0..5b98a1b4fb5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -19,7 +19,6 @@ import java.util.stream.Collectors; * A tensor type with its dimensions. This is immutable. * <p> * A dimension can be indexed (bound or unbound) or mapped. - * Currently, we only support tensor types where all dimensions have the same type. * * @author geirst * @author bratseth @@ -27,6 +26,7 @@ import java.util.stream.Collectors; @Beta public class TensorType { + /** The empty tensor type - which is the same as a double */ public static final TensorType empty = new TensorType(Collections.emptyList()); /** Sorted list of the dimensions of this */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java index e18b77a0434..3fb94f1251b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java @@ -3,7 +3,6 @@ package com.yahoo.tensor.evaluation; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; /** * An evaluation context which is passed down to all nested functions during evaluation. @@ -11,15 +10,7 @@ import com.yahoo.tensor.TensorType; * @author bratseth */ @Beta -public interface EvaluationContext { - - /** - * Returns tye type of the tensor with this name. - * - * @return returns the type of the tensor which will be returned by calling getTensor(name) - * or null if getTensor will return null. - */ - TensorType getTensorType(String name); +public interface EvaluationContext extends TypeContext { /** Returns the tensor bound to this name, or null if none */ Tensor getTensor(String name); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index 6bdfe8f19b6..9fe6b7d053f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -20,7 +20,7 @@ public class MapEvaluationContext implements EvaluationContext { public void put(String name, Tensor tensor) { bindings.put(name, tensor); } @Override - public TensorType getTensorType(String name) { + public TensorType getType(String name) { Tensor tensor = bindings.get(name); if (tensor == null) return null; return tensor.type(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java new file mode 100644 index 00000000000..9b2e81f0b6d --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java @@ -0,0 +1,18 @@ +package com.yahoo.tensor.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import com.yahoo.tensor.TensorType; + +/** + * @author bratseth + */ +public interface TypeContext { + + /** + * Returns tye type of the tensor with this name. + * + * @return returns the type of the tensor which will be returned by calling getTensor(name) + * or null if getTensor will return null. + */ + TensorType getType(String name); + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index 6c149724aca..34beb465d4c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.functions.ToStringContext; import java.util.Collections; import java.util.List; +import java.util.Optional; /** * A tensor variable name which resolves to a tensor in the context at evaluation time @@ -20,9 +21,17 @@ import java.util.List; public class VariableTensor extends PrimitiveTensorFunction { private final String name; + private final Optional<TensorType> requiredType; public VariableTensor(String name) { this.name = name; + this.requiredType = Optional.empty(); + } + + /** A variable tensor which must be compatible with the given type */ + public VariableTensor(String name, TensorType requiredType) { + this.name = name; + this.requiredType = Optional.of(requiredType); } @Override @@ -35,11 +44,19 @@ public class VariableTensor extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(EvaluationContext context) { return context.getTensorType(name); } + public TensorType type(TypeContext context) { + TensorType givenType = context.getType(name); + if (givenType == null) return null; + verifyType(givenType); + return givenType; + } @Override public Tensor evaluate(EvaluationContext context) { - return context.getTensor(name); + Tensor tensor = context.getTensor(name); + if (tensor == null) return null; + verifyType(tensor.type()); + return tensor; } @Override @@ -47,4 +64,9 @@ public class VariableTensor extends PrimitiveTensorFunction { return name; } + private void verifyType(TensorType givenType) { + if (requiredType.isPresent() && ! givenType.isAssignableTo(requiredType.get())) + throw new IllegalArgumentException("Variable '" + name + "' must be compatible with " + + requiredType.get() + " but was " + givenType); + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 0c43caef05c..2109b730e1a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -5,6 +5,7 @@ import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; /** * A composite tensor function is a tensor function which can be expressed (less tersely) @@ -17,7 +18,7 @@ public abstract class CompositeTensorFunction extends TensorFunction { /** Finds the type this produces by first converting it to a primitive function */ @Override - public final TensorType type(EvaluationContext context) { return toPrimitive().type(context); } + public final TensorType type(TypeContext context) { return toPrimitive().type(context); } /** Evaluates this by first converting it to a primitive function */ @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index cc8067224c7..c77ed1c0526 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -9,8 +9,14 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; - -import java.util.*; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; /** @@ -54,7 +60,7 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public TensorType type(EvaluationContext context) { + public TensorType type(TypeContext context) { return type(argumentA.type(context), argumentB.type(context)); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 4a6d656142f..50b479da168 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -5,6 +5,7 @@ import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.List; @@ -41,7 +42,7 @@ public class ConstantTensor extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(EvaluationContext context) { return constant.type(); } + public TensorType type(TypeContext context) { return constant.type(); } @Override public Tensor evaluate(EvaluationContext context) { return constant; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index ff9589bd6ae..e70d1de3db7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -7,6 +7,7 @@ import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.List; @@ -60,7 +61,7 @@ public class Generate extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(EvaluationContext context) { return type; } + public TensorType type(TypeContext context) { return type; } @Override public Tensor evaluate(EvaluationContext context) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 835a2a82a2c..7812c985091 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -11,6 +11,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayList; import java.util.Collections; @@ -94,7 +95,7 @@ public class Join extends PrimitiveTensorFunction { } @Override - public TensorType type(EvaluationContext context) { + public TensorType type(TypeContext context) { return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index e5440b56c54..53504868ff2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -6,6 +6,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Iterator; @@ -52,7 +53,7 @@ public class Map extends PrimitiveTensorFunction { } @Override - public TensorType type(EvaluationContext context) { + public TensorType type(TypeContext context) { return argument.type(context); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 591a6e4649e..76a938b9fe2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -8,6 +8,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.HashMap; @@ -100,7 +101,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public TensorType type(EvaluationContext context) { + public TensorType type(TypeContext context) { return type(argument.type(context)); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 6a9b8d68b38..de3d2be265a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -7,6 +7,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.HashMap; @@ -71,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(EvaluationContext context) { + public TensorType type(TypeContext context) { return type(argument.type(context)); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 3f6dfae6222..78ab09c7820 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -6,6 +6,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.List; @@ -49,7 +50,7 @@ public abstract class TensorFunction { * * @param context a context which must be passed to all nexted functions when evaluating */ - public abstract TensorType type(EvaluationContext context); + public abstract TensorType type(TypeContext context); /** Evaluate with no context */ public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); } |