diff options
79 files changed, 898 insertions, 460 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index bacff94d776..cd65c6ef761 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -2,15 +2,26 @@ package com.yahoo.searchdefinition; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.io.reader.NamedReader; +import com.yahoo.processing.request.CompoundName; +import com.yahoo.search.query.profile.QueryProfile; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.search.query.profile.config.QueryProfileXMLReader; +import com.yahoo.search.query.profile.types.FieldDescription; +import com.yahoo.search.query.profile.types.TensorFieldType; import com.yahoo.search.query.ranking.Diversity; +import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.FeatureList; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TypeMapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.io.File; import java.io.IOException; @@ -276,12 +287,10 @@ public class RankProfile implements Serializable, Cloneable { addConstant(name, value); } - /** - * Returns an unmodifiable view of the constants to use in this. - */ + /** Returns an unmodifiable view of the constants available in this */ public Map<String, Value> getConstants() { if (constants.isEmpty()) - return getInherited() != null ? getInherited().getConstants() : Collections.<String,Value>emptyMap(); + return getInherited() != null ? getInherited().getConstants() : Collections.emptyMap(); if (getInherited() == null || getInherited().getConstants().isEmpty()) return Collections.unmodifiableMap(constants); @@ -433,8 +442,9 @@ public class RankProfile implements Serializable, Cloneable { properties.add(rankProperty); } + @Override public String toString() { - return "rank profile " + getName(); + return "rank profile '" + getName() + "'"; } public int getRerankCount() { @@ -508,6 +518,7 @@ public class RankProfile implements Serializable, Cloneable { /** * Returns the string form of the first phase ranking expression. + * * @return string form of first phase ranking expression */ public String getFirstPhaseRankingString() { @@ -697,7 +708,8 @@ public class RankProfile implements Serializable, Cloneable { private void checkNameCollisions(Map<String, Macro> macros, Map<String, Value> constants) { for (Map.Entry<String, Macro> macroEntry : macros.entrySet()) { if (constants.get(macroEntry.getKey()) != null) - throw new IllegalArgumentException("Cannot have both a constant and macro named '" + macroEntry.getKey() + "'"); + throw new IllegalArgumentException("Cannot have both a constant and macro named '" + + macroEntry.getKey() + "'"); } } @@ -727,6 +739,64 @@ public class RankProfile implements Serializable, Cloneable { } /** + * Creates a context containing the type information of all constants, attributes and query profiles + * referable from this rank profile. + */ + public TypeContext typeContext() { + TypeMapContext context = new TypeMapContext(); + + // Add constants + getConstants().forEach((k, v) -> context.setType(asConstantFeature(k), v.type())); + + // Add attributes + for (SDField field : getSearch().allConcreteFields()) + field.getAttributes().forEach((k, a) -> context.setType(asAttributeFeature(k), a.tensorType().orElse(TensorType.empty))); + + // Add query features from rank profile types reached from the "default" profile + QueryProfile profile = queryProfilesOf(getSearch().sourceApplication()).getComponent("default"); + if (profile != null && profile.getType() != null) { + profile.listTypes(CompoundName.empty, Collections.emptyMap()).forEach((prefix, queryProfileType) -> { + for (FieldDescription field : queryProfileType.declaredFields().values()) { + TensorType type = TensorType.empty; // assume the empty (aka double) type by default + if (field.getType() instanceof TensorFieldType) + type = ((TensorFieldType)field.getType()).type().get(); + + String feature = asQueryFeature(prefix.append(field.getName()).toString()); + context.setType(feature, type); + } + }); + } + + return context; + } + + private QueryProfileRegistry queryProfilesOf(ApplicationPackage applicationPackage) { + List<NamedReader> queryProfileFiles = null; + List<NamedReader> queryProfileTypeFiles = null; + try { + queryProfileFiles = applicationPackage.getQueryProfileFiles(); + queryProfileTypeFiles = applicationPackage.getQueryProfileTypeFiles(); + return new QueryProfileXMLReader().read(queryProfileFiles, queryProfileTypeFiles); + } + finally { + NamedReader.closeAll(queryProfileFiles); + NamedReader.closeAll(queryProfileTypeFiles); + } + } + + private String asConstantFeature(String constantName) { + return "constant(\"" + constantName + "\")"; + } + + private String asAttributeFeature(String constantName) { + return "attribute(\"" + constantName + "\")"; + } + + private String asQueryFeature(String constantName) { + return "query(\"" + constantName + "\")"; + } + + /** * A rank setting. The identity of a rank setting is its field name and type (not value). * A rank setting is immutable. */ diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java index df5697de0d5..f4a0365e36e 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java @@ -247,7 +247,7 @@ public class Search implements Serializable, ImmutableSearch { /** * Returns a list of all the fields of this search definition, that is all fields in all documents, in the documents * they inherit, and all extra fields. The caller receives ownership to the list - subsequent changes to it will not - * impact this Search + * impact this * * @return the list of fields in this searchdefinition */ @@ -546,7 +546,7 @@ public class Search implements Serializable, ImmutableSearch { } /** - * Returns the first occurence of an attribute having this name, or null if none + * Returns the first occurrence of an attribute having this name, or null if none * * @param name Name of attribute * @return The Attribute with given name. diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java index ce3ff7cc447..72ba6de7022 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java @@ -31,24 +31,19 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce private Map<String, Attribute> attributes = new java.util.LinkedHashMap<>(); private Map<String, Attribute> importedAttributes = new java.util.LinkedHashMap<>(); - /** - * Flag indicating if a position-attribute has been found - */ + /** Whether this has any position attribute */ private boolean hasPosition = false; public AttributeFields(Search search) { derive(search); } - /** - * Derives everything from a field - */ + /** Derives everything from a field */ @Override protected void derive(ImmutableSDField field, Search search) { if (field.usesStructOrMap() && !field.getDataType().equals(PositionDataType.INSTANCE) && - !field.getDataType().equals(DataType.getArray(PositionDataType.INSTANCE))) - { + !field.getDataType().equals(DataType.getArray(PositionDataType.INSTANCE))) { return; // Ignore struct fields for indexed search (only implemented for streaming search) } if (field.isImportedField()) { @@ -58,9 +53,7 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce } } - /** - * Return an attribute by name, or null if it doesn't exist - */ + /** Returns an attribute by name, or null if it doesn't exist */ public Attribute getAttribute(String attributeName) { return attributes.get(attributeName); } @@ -69,9 +62,7 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce return getAttribute(attributeName) != null; } - /** - * Derives one attribute. TODO: Support non-default named attributes - */ + /** Derives one attribute. TODO: Support non-default named attributes */ private void deriveAttributes(ImmutableSDField field) { for (Attribute fieldAttribute : field.getAttributes().values()) { deriveAttribute(field, fieldAttribute); @@ -107,9 +98,7 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce } } - /** - * Returns a read only attribute iterator - */ + /** Returns a read only attribute iterator */ public Iterator attributeIterator() { return attributes().iterator(); } @@ -201,4 +190,5 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce } } } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 01d3449573c..9c7fd7d9f0a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -41,7 +41,6 @@ import java.util.Optional; * * @author bratseth */ -// TODO: Verify types of macros // TODO: Avoid name conflicts across models for constants public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { @@ -84,6 +83,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil Signature signature = chooseSignature(model, store.arguments().signature()); String output = chooseOutput(signature, store.arguments().output()); RankingExpression expression = model.expressions().get(output); + verifyRequiredMacros(expression, model.requiredMacros(), profile); store.writeConverted(expression); model.constants().forEach((k, v) -> transformConstant(store, profile, k, v)); @@ -168,6 +168,44 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } /** + * Verify that the macros referred in the given expression exists in the given rank profile, + * and return tensors of the types specified in requiredMacros. + */ + private void verifyRequiredMacros(RankingExpression expression, Map<String, TensorType> requiredMacros, + RankProfile profile) { + List<String> macroNames = new ArrayList<>(); + addMacroNamesIn(expression.getRoot(), macroNames); + for (String macroName : macroNames) { + TensorType requiredType = requiredMacros.get(macroName); + if (requiredType == null) continue; // Not a required macro + + RankProfile.Macro macro = profile.getMacros().get(macroName); + if (macro == null) + throw new IllegalArgumentException("Model refers Placeholder '" + macroName + + "' of type " + requiredType + " but this macro is not present in " + + profile); + TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext()); + if ( ! actualType.isAssignableTo(requiredType)) + throw new IllegalArgumentException("Model refers Placeholder '" + macroName + + "' of type " + requiredType + + " which must be produced by a macro in the rank profile, but " + + "this macro produces type " + actualType); + } + } + + private void addMacroNamesIn(ExpressionNode node, List<String> names) { + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode)node; + if (referenceNode.getOutput() == null) // macro references cannot specify outputs + names.add(referenceNode.getName()); + } + else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode)node).children()) + addMacroNamesIn(child, names); + } + } + + /** * Provides read/write access to the correct directories of the application package given by the feature arguments */ private static class ModelStore { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java index d08092003f1..b85cb88bf2e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java @@ -5,6 +5,7 @@ import com.yahoo.io.reader.NamedReader; import com.yahoo.search.query.profile.config.QueryProfileXMLReader; import com.yahoo.config.application.api.ApplicationPackage; +import java.util.Collections; import java.util.List; /** @@ -13,17 +14,16 @@ import java.util.List; * * @author bratseth */ -// TODO: Move into QueryProfiles public class QueryProfilesBuilder { /** Build the set of query profiles for an application package */ public QueryProfiles build(ApplicationPackage applicationPackage) { - List<NamedReader> queryProfileTypeFiles=null; - List<NamedReader> queryProfileFiles=null; + List<NamedReader> queryProfileTypeFiles = null; + List<NamedReader> queryProfileFiles = null; try { - queryProfileTypeFiles=applicationPackage.getQueryProfileTypeFiles(); - queryProfileFiles=applicationPackage.getQueryProfileFiles(); - return new QueryProfiles(new QueryProfileXMLReader().read(queryProfileTypeFiles,queryProfileFiles)); + queryProfileTypeFiles = applicationPackage.getQueryProfileTypeFiles(); + queryProfileFiles = applicationPackage.getQueryProfileFiles(); + return new QueryProfiles(new QueryProfileXMLReader().read(queryProfileTypeFiles, queryProfileFiles)); } finally { NamedReader.closeAll(queryProfileTypeFiles); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 89f1a9f785c..82d0d66a82a 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -49,14 +49,8 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testMinimalTensorFlowReference() throws ParseException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); assertConstant("Variable_1", search, Optional.of(10L)); assertConstant("Variable", search, Optional.of(7840L)); @@ -64,14 +58,8 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testNestedTensorFlowReference() throws ParseException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: 5 + sum(tensorflow('mnist_softmax/saved'))" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "5 + sum(tensorflow('mnist_softmax/saved'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); assertConstant("Variable_1", search, Optional.of(10L)); assertConstant("Variable", search, Optional.of(7840L)); @@ -79,39 +67,26 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingSignature() throws ParseException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_default')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved', 'serving_default')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test public void testTensorFlowReferenceSpecifyingSignatureAndOutput() throws ParseException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_default', 'y')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved', 'serving_default', 'y')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test - public void testTensorFlowReferenceSpecifyingNonExistingSignature() throws ParseException { + public void testTensorFlowReferenceMissingMacro() throws ParseException { try { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( - application, + new StoringApplicationPackage(applicationDir), " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_defaultz')" + + " expression: tensorflow('mnist_softmax/saved')" + " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -119,6 +94,40 @@ public class RankingExpressionWithTensorFlowTestCase { } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + + "tensorflow('mnist_softmax/saved'): " + + "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + + "not present in rank profile 'my_profile'", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void testTensorFlowReferenceWithWrongMacroType() throws ParseException { + try { + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d5[10])(0.0)", + "tensorflow('mnist_softmax/saved')"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + fail("Expecting exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + + "tensorflow('mnist_softmax/saved'): " + + "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) which must be produced " + + "by a macro in the rank profile, but this macro produces type tensor(d0[2],d5[10])", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void testTensorFlowReferenceSpecifyingNonExistingSignature() throws ParseException { + try { + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved', 'serving_defaultz')"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + fail("Expecting exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved','serving_defaultz'): " + "Model does not have the specified signature 'serving_defaultz'", Exceptions.toMessageString(expected)); @@ -128,14 +137,8 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingNonExistingOutput() throws ParseException { try { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_default', 'x')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved', 'serving_default', 'x')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } @@ -149,14 +152,8 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testImportingFromStoredExpressions() throws ParseException, IOException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_default')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); assertConstant("Variable_1", search, Optional.of(10L)); assertConstant("Variable", search, Optional.of(7840L)); @@ -168,13 +165,9 @@ public class RankingExpressionWithTensorFlowTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - RankProfileSearchFixture searchFromStored = new RankProfileSearchFixture( - storedApplication, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_default')" + - " }\n" + - " }"); + RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved')", + storedApplication); searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); // Verify that the constants exists, but don't verify the content as we are not // simulating file distribution in this test @@ -212,6 +205,30 @@ public class RankingExpressionWithTensorFlowTestCase { } } + private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) { + return fixtureWith(placeholderExpression, firstPhaseExpression, new StoringApplicationPackage(applicationDir)); + } + + private RankProfileSearchFixture fixtureWith(String placeholderExpression, + String firstPhaseExpression, + StoringApplicationPackage application) { + try { + return new RankProfileSearchFixture( + application, + " rank-profile my_profile {\n" + + " macro Placeholder() {\n" + + " expression: " + placeholderExpression + + " }\n" + + " first-phase {\n" + + " expression: " + firstPhaseExpression + + " }\n" + + " }"); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + private static class StoringApplicationPackage extends MockApplicationPackage { private final File root; diff --git a/container-core/src/main/java/com/yahoo/container/jdisc/ExtendedResponse.java b/container-core/src/main/java/com/yahoo/container/jdisc/ExtendedResponse.java index 9b6837a9dcb..b2d32c8e745 100644 --- a/container-core/src/main/java/com/yahoo/container/jdisc/ExtendedResponse.java +++ b/container-core/src/main/java/com/yahoo/container/jdisc/ExtendedResponse.java @@ -1,18 +1,17 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.container.jdisc; -import java.io.IOException; -import java.io.OutputStream; - import com.yahoo.container.handler.Coverage; import com.yahoo.container.handler.Timing; import com.yahoo.container.logging.HitCounts; import com.yahoo.jdisc.handler.CompletionHandler; import com.yahoo.jdisc.handler.ContentChannel; +import java.io.IOException; +import java.io.OutputStream; + /** - * An HTTP response supporting async rendering and extended information for - * logging. + * An HTTP response supporting async rendering and extended information for logging. * * @author Steinar Knutsen */ diff --git a/container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java b/container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java index 933751fd9ad..1c2d28c754f 100644 --- a/container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java +++ b/container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java @@ -84,16 +84,18 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler { } @Override - protected LoggingCompletionHandler createLoggingCompletionHandler( - long startTime, long renderStartTime, HttpResponse response, - HttpRequest httpRequest, ContentChannelOutputStream rendererWiring) { + protected LoggingCompletionHandler createLoggingCompletionHandler(long startTime, + long renderStartTime, + HttpResponse response, + HttpRequest httpRequest, + ContentChannelOutputStream rendererWiring) { return new LoggingHandler(startTime, renderStartTime, httpRequest, response, rendererWiring); } private static String getClientIP(com.yahoo.jdisc.http.HttpRequest httpRequest) { SocketAddress clientAddress = httpRequest.getRemoteAddress(); - if (clientAddress == null) - return "0.0.0.0"; + if (clientAddress == null) return "0.0.0.0"; + return clientAddress.toString(); } @@ -105,24 +107,21 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler { } } - private static String remoteHostAddress( - com.yahoo.jdisc.http.HttpRequest httpRequest) { + private static String remoteHostAddress(com.yahoo.jdisc.http.HttpRequest httpRequest) { SocketAddress remoteAddress = httpRequest.getRemoteAddress(); - if (remoteAddress == null) - return "0.0.0.0"; + if (remoteAddress == null) return "0.0.0.0"; + if (remoteAddress instanceof InetSocketAddress) { - return ((InetSocketAddress) remoteAddress).getAddress() - .getHostAddress(); + return ((InetSocketAddress) remoteAddress).getAddress().getHostAddress(); } else { - throw new RuntimeException( - "Expected remote address of type InetSocketAddress, got " - + remoteAddress.getClass().getName()); + throw new RuntimeException("Expected remote address of type InetSocketAddress, got " + + remoteAddress.getClass().getName()); } } private void logTimes(long startTime, String sourceIP, - long renderStartTime, long commitStartTime, long endTime, - String req, String normalizedQuery, Timing t) { + long renderStartTime, long commitStartTime, long endTime, + String req, String normalizedQuery, Timing t) { // note: intentionally only taking time since request was received long totalTime = endTime - startTime; @@ -140,33 +139,33 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler { return; } - StringBuilder msgbuf = new StringBuilder(); - msgbuf.append(normalizedQuery); - msgbuf.append(" from ").append(sourceIP).append(". "); + StringBuilder b = new StringBuilder(); + b.append(normalizedQuery); + b.append(" from ").append(sourceIP).append(". "); if (requestOverhead > 0) { - msgbuf.append("Time from HTTP connection open to request reception "); - msgbuf.append(requestOverhead).append(" ms. "); + b.append("Time from HTTP connection open to request reception "); + b.append(requestOverhead).append(" ms. "); } if (summaryStartTime != 0) { - msgbuf.append("Request time: "); - msgbuf.append(summaryStartTime - startTime).append(" ms. "); - msgbuf.append("Summary fetch time: "); - msgbuf.append(renderStartTime - summaryStartTime).append(" ms. "); + b.append("Request time: "); + b.append(summaryStartTime - startTime).append(" ms. "); + b.append("Summary fetch time: "); + b.append(renderStartTime - summaryStartTime).append(" ms. "); } else { long spentSearching = renderStartTime - startTime; - msgbuf.append("Processing time: ").append(spentSearching).append(" ms. "); + b.append("Processing time: ").append(spentSearching).append(" ms. "); } - msgbuf.append("Result rendering/transfer: "); - msgbuf.append(commitStartTime - renderStartTime).append(" ms. "); - msgbuf.append("End transaction: "); - msgbuf.append(endTime - commitStartTime).append(" ms. "); - msgbuf.append("Total: ").append(totalTime).append(" ms. "); - msgbuf.append("Timeout: ").append(timeoutInterval).append(" ms. "); - msgbuf.append("Request string: ").append(req); + b.append("Result rendering/transfer: "); + b.append(commitStartTime - renderStartTime).append(" ms. "); + b.append("End transaction: "); + b.append(endTime - commitStartTime).append(" ms. "); + b.append("Total: ").append(totalTime).append(" ms. "); + b.append("Timeout: ").append(timeoutInterval).append(" ms. "); + b.append("Request string: ").append(req); - log.log(LogLevel.WARNING, "Slow execution. " + msgbuf); + log.log(LogLevel.WARNING, "Slow execution. " + b); } private static class NullResponse extends ExtendedResponse { @@ -175,10 +174,11 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler { } @Override - public void render(OutputStream output, ContentChannel networkChannel, - CompletionHandler handler) throws IOException { + public void render(OutputStream output, ContentChannel networkChannel, CompletionHandler handler) + throws IOException { // NOP } + } private class LoggingHandler implements LoggingCompletionHandler { @@ -191,9 +191,8 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler { private final ContentChannelOutputStream rendererWiring; private final ExtendedResponse extendedResponse; - LoggingHandler(long startTime, long renderStartTime, - HttpRequest httpRequest, HttpResponse httpResponse, - ContentChannelOutputStream rendererWiring) { + LoggingHandler(long startTime, long renderStartTime, HttpRequest httpRequest, HttpResponse httpResponse, + ContentChannelOutputStream rendererWiring) { this.startTime = startTime; this.renderStartTime = renderStartTime; this.commitStartTime = renderStartTime; @@ -233,20 +232,19 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler { } private void writeToLogs(long endTime) { - final com.yahoo.jdisc.http.HttpRequest jdiscRequest = httpRequest.getJDiscRequest(); + com.yahoo.jdisc.http.HttpRequest jdiscRequest = httpRequest.getJDiscRequest(); - logTimes( - startTime, - getClientIP(jdiscRequest), - renderStartTime, - commitStartTime, - endTime, - jdiscRequest.getUri().toString(), - extendedResponse.getParsedQuery(), - extendedResponse.getTiming()); + logTimes(startTime, + getClientIP(jdiscRequest), + renderStartTime, + commitStartTime, + endTime, + jdiscRequest.getUri().toString(), + extendedResponse.getParsedQuery(), + extendedResponse.getTiming()); - final Optional<AccessLogEntry> jdiscRequestAccessLogEntry - = AccessLoggingRequestHandler.getAccessLogEntry(jdiscRequest); + Optional<AccessLogEntry> jdiscRequestAccessLogEntry = + AccessLoggingRequestHandler.getAccessLogEntry(jdiscRequest); if (jdiscRequestAccessLogEntry.isPresent()) { // This means we are running with Jetty, not Netty. @@ -275,18 +273,17 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler { } } - private void populateAccessLogEntryNotCreatedByHttpServer( - final AccessLogEntry logEntry, - final com.yahoo.jdisc.http.HttpRequest httpRequest, - final Timing timing, - final String fullRequest, - final long commitStartTime, - final long startTime, - final long written, - final int status) { + private void populateAccessLogEntryNotCreatedByHttpServer(AccessLogEntry logEntry, + com.yahoo.jdisc.http.HttpRequest httpRequest, + Timing timing, + String fullRequest, + long commitStartTime, + long startTime, + long written, + int status) { try { - final InetSocketAddress remoteAddress = AccessLogUtil.getRemoteAddress(httpRequest); - final long evalStartTime = getEvalStart(timing, startTime); + InetSocketAddress remoteAddress = AccessLogUtil.getRemoteAddress(httpRequest); + long evalStartTime = getEvalStart(timing, startTime); logEntry.setIpV4Address(remoteHostAddress(httpRequest)); logEntry.setTimeStamp(evalStartTime); logEntry.setDurationBetweenRequestResponse(commitStartTime - evalStartTime); @@ -302,8 +299,8 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler { logEntry.setHttpMethod(AccessLogUtil.getHttpMethod(httpRequest)); logEntry.setHttpVersion(AccessLogUtil.getHttpVersion(httpRequest)); } catch (Exception e) { - log.log(LogLevel.WARNING, "Could not populate the access log [" - + fullRequest + "]", e); + log.log(LogLevel.WARNING, "Could not populate the access log [" + fullRequest + "]", e); } } + } diff --git a/container-search/src/main/java/com/yahoo/search/Result.java b/container-search/src/main/java/com/yahoo/search/Result.java index ded8992fa65..7978798f53c 100644 --- a/container-search/src/main/java/com/yahoo/search/Result.java +++ b/container-search/src/main/java/com/yahoo/search/Result.java @@ -46,7 +46,7 @@ public final class Result extends com.yahoo.processing.Response implements Clone * Headers containing "envelope" meta information to be returned with this result. * Used for HTTP getHeaders when the return protocol is HTTP. */ - private ListMap<String,String> headers=null; + private ListMap<String,String> headers = null; /** * Result rendering infrastructure. diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java index 04dd3ee9005..a347fbfb3ab 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java @@ -43,28 +43,28 @@ import java.util.regex.Pattern; public class QueryProfile extends FreezableSimpleComponent implements Cloneable { /** Defines the permissible content of this, or null if any content is permissible */ - private QueryProfileType type=null; + private QueryProfileType type = null; /** The value at this query profile - allows non-fields to have values, e.g a=value1, a.b=value2 */ - private Object value=null; + private Object value = null; /** The variants of this, or null if none */ - private QueryProfileVariants variants=null; + private QueryProfileVariants variants = null; /** The resolved variant dimensions of this, or null if none or not resolved yet (is resolved at freeze) */ - private List<String> resolvedDimensions=null; + private List<String> resolvedDimensions = null; /** The query profiles inherited by this, or null if none */ - private List<QueryProfile> inherited=null; + private List<QueryProfile> inherited = null; /** The content of this profile. The values may be primitives, substitutable strings or other query profiles */ - private CopyOnWriteContent content=new CopyOnWriteContent(); + private CopyOnWriteContent content = new CopyOnWriteContent(); /** * Field override settings: fieldName→OverrideValue. These overrides the override * setting in the type (if any) of this field). If there are no query profile level settings, this is null. */ - private Map<String,Boolean> overridable=null; + private Map<String,Boolean> overridable = null; /** * Creates a new query profile from an id. @@ -108,26 +108,26 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable /** Adds a profile to the end of the inherited list of this. Throws an exception if this is frozen. */ public void addInherited(QueryProfile profile) { - addInherited(profile,(DimensionValues)null); + addInherited(profile, (DimensionValues)null); } public final void addInherited(QueryProfile profile,String[] dimensionValues) { - addInherited(profile,DimensionValues.createFrom(dimensionValues)); + addInherited(profile, DimensionValues.createFrom(dimensionValues)); } /** Adds a profile to the end of the inherited list of this for the given variant. Throws an exception if this is frozen. */ public void addInherited(QueryProfile profile, DimensionValues dimensionValues) { ensureNotFrozen(); - DimensionBinding dimensionBinding=DimensionBinding.createFrom(getDimensions(),dimensionValues); + DimensionBinding dimensionBinding=DimensionBinding.createFrom(getDimensions(), dimensionValues); if (dimensionBinding.isNull()) { - if (inherited==null) - inherited=new ArrayList<>(); + if (inherited == null) + inherited = new ArrayList<>(); inherited.add(profile); } else { - if (variants==null) - variants=new QueryProfileVariants(dimensionBinding.getDimensions(), this); + if (variants == null) + variants = new QueryProfileVariants(dimensionBinding.getDimensions(), this); variants.inherit(profile,dimensionBinding.getValues()); } } @@ -152,13 +152,13 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable * @throws IllegalStateException if this is frozen */ public Boolean isDeclaredOverridable(String name, Map<String,String> context) { - return isDeclaredOverridable(new CompoundName(name),DimensionBinding.createFrom(getDimensions(),context)); + return isDeclaredOverridable(new CompoundName(name), DimensionBinding.createFrom(getDimensions(), context)); } /** Sets the dimensions over which this may vary. Note: This will erase any currently defined variants */ public void setDimensions(String[] dimensions) { ensureNotFrozen(); - variants=new QueryProfileVariants(dimensions, this); + variants = new QueryProfileVariants(dimensions, this); } /** Returns the value set at this node, to allow non-leafs to have values. Returns null if none. */ @@ -166,17 +166,17 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable public void setValue(Object value) { ensureNotFrozen(); - this.value=value; + this.value = value; } /** Returns the variant dimensions to be used in this - an unmodifiable list of dimension names */ public List<String> getDimensions() { if (isFrozen()) return resolvedDimensions; - if (variants!=null) return variants.getDimensions(); - if (inherited==null) return null; + if (variants != null) return variants.getDimensions(); + if (inherited == null) return null; for (QueryProfile inheritedProfile : inherited) { - List<String> inheritedDimensions=inheritedProfile.getDimensions(); - if (inheritedDimensions!=null) return inheritedDimensions; + List<String> inheritedDimensions = inheritedProfile.getDimensions(); + if (inheritedDimensions != null) return inheritedDimensions; } return null; } @@ -187,8 +187,8 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable * Sets the overridability of a field in this profile, * this overrides the corresponding setting in the type (if any) */ - public final void setOverridable(String fieldName, boolean overridable, Map<String,String> context) { - setOverridable(new CompoundName(fieldName), overridable,DimensionBinding.createFrom(getDimensions(), context)); + public final void setOverridable(String fieldName, boolean overridable, Map<String, String> context) { + setOverridable(new CompoundName(fieldName), overridable, DimensionBinding.createFrom(getDimensions(), context)); } /** @@ -241,8 +241,8 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable Map<String,Object> values=visitor.getResult(); if (substitution==null) return values; - for (Map.Entry<String,Object> entry : values.entrySet()) { - if (entry.getValue().getClass()==String.class) continue; // Shortcut + for (Map.Entry<String, Object> entry : values.entrySet()) { + if (entry.getValue().getClass() == String.class) continue; // Shortcut if (entry.getValue() instanceof SubstituteString) entry.setValue(((SubstituteString)entry.getValue()).substitute(context,substitution)); } @@ -253,7 +253,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable * Lists types reachable from this, indexed by the prefix having that type. * If this is itself typed, this' type will be included with an empty prefix */ - Map<CompoundName, QueryProfileType> listTypes(CompoundName prefix, Map<String, String> context) { + public Map<CompoundName, QueryProfileType> listTypes(CompoundName prefix, Map<String, String> context) { DimensionBinding dimensionBinding = DimensionBinding.createFrom(getDimensions(), context); AllTypesQueryProfileVisitor visitor = new AllTypesQueryProfileVisitor(prefix); accept(visitor, dimensionBinding, null); @@ -264,8 +264,8 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable * Lists references reachable from this. */ Set<CompoundName> listReferences(CompoundName prefix, Map<String, String> context) { - DimensionBinding dimensionBinding=DimensionBinding.createFrom(getDimensions(),context); - AllReferencesQueryProfileVisitor visitor=new AllReferencesQueryProfileVisitor(prefix); + DimensionBinding dimensionBinding = DimensionBinding.createFrom(getDimensions(),context); + AllReferencesQueryProfileVisitor visitor = new AllReferencesQueryProfileVisitor(prefix); accept(visitor,dimensionBinding,null); return visitor.getResult(); } @@ -292,40 +292,40 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable public final Object get(String name) { return get(name,(Map<String,String>)null); } /** Returns a value from this using the given property context for resolution and using this for substitution */ - public final Object get(String name, Map<String,String> context) { - return get(name,context,null); + public final Object get(String name, Map<String, String> context) { + return get(name, context, null); } /** Returns a value from this using the given dimensions for resolution */ public final Object get(String name, String[] dimensionValues) { - return get(name,dimensionValues,null); + return get(name, dimensionValues,null); } public final Object get(String name, String[] dimensionValues, Properties substitution) { - return get(name,DimensionValues.createFrom(dimensionValues),substitution); + return get(name, DimensionValues.createFrom(dimensionValues), substitution); } /** Returns a value from this using the given dimensions for resolution */ public final Object get(String name, DimensionValues dimensionValues, Properties substitution) { - return get(name,DimensionBinding.createFrom(getDimensions(),dimensionValues),substitution); + return get(name, DimensionBinding.createFrom(getDimensions(), dimensionValues), substitution); } public final Object get(String name, Map<String,String> context, Properties substitution) { - return get(name,DimensionBinding.createFrom(getDimensions(),context),substitution); + return get(name, DimensionBinding.createFrom(getDimensions(), context), substitution); } public final Object get(CompoundName name, Map<String,String> context, Properties substitution) { - return get(name,DimensionBinding.createFrom(getDimensions(),context),substitution); + return get(name, DimensionBinding.createFrom(getDimensions(), context), substitution); } final Object get(String name, DimensionBinding binding,Properties substitution) { - return get(new CompoundName(name),binding,substitution); + return get(new CompoundName(name), binding, substitution); } final Object get(CompoundName name, DimensionBinding binding, Properties substitution) { - Object node=get(name,binding); - if (node!=null && node.getClass()==String.class) return node; // Shortcut - if (node instanceof SubstituteString) return ((SubstituteString)node).substitute(binding.getContext(),substitution); + Object node = get(name, binding); + if (node != null && node.getClass() == String.class) return node; // Shortcut + if (node instanceof SubstituteString) return ((SubstituteString)node).substitute(binding.getContext(), substitution); return node; } @@ -431,14 +431,14 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable @Override public QueryProfile clone() { if (isFrozen()) return this; - QueryProfile clone=(QueryProfile)super.clone(); - if (variants !=null) + QueryProfile clone = (QueryProfile)super.clone(); + if (variants != null) clone.variants = variants.clone(); - if (inherited!=null) - clone.inherited=new ArrayList<>(inherited); + if (inherited != null) + clone.inherited = new ArrayList<>(inherited); - if (this.content!=null) - clone.content=content.clone(); + if (this.content != null) + clone.content = content.clone(); return clone; } @@ -454,7 +454,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable /** Throws IllegalArgumentException if the given string is not a valid query profile name */ public static void validateName(String name) { - Matcher nameMatcher=namePattern.matcher(name); + Matcher nameMatcher = namePattern.matcher(name); if ( ! nameMatcher.matches()) throw new IllegalArgumentException("Illegal name '" + name + "'"); } @@ -467,7 +467,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable setNode(name, value, null, binding, registry); } catch (IllegalArgumentException e) { - throw new IllegalArgumentException("Could not set '" + name + "' to '" + value + "'",e); + throw new IllegalArgumentException("Could not set '" + name + "' to '" + value + "'", e); } } @@ -708,19 +708,19 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable /** Do a variant-aware content lookup in this */ protected Object localLookup(String name, DimensionBinding dimensionBinding) { - Object node=null; - if ( variants!=null && !dimensionBinding.isNull()) - node=variants.get(name,type,true,dimensionBinding); - if (node==null) - node=content==null ? null : content.get(name); + Object node = null; + if ( variants != null && !dimensionBinding.isNull()) + node = variants.get(name,type,true,dimensionBinding); + if (node == null) + node = content == null ? null : content.get(name); return node; } // ----------------- Private ---------------------------------------------------------------------------------- private Boolean isDeclaredOverridable(CompoundName name,DimensionBinding dimensionBinding) { - QueryProfile parent= lookupParentExact(name, true, dimensionBinding); - if (parent.overridable==null) return null; + QueryProfile parent = lookupParentExact(name, true, dimensionBinding); + if (parent.overridable == null) return null; return parent.overridable.get(name.last()); } @@ -729,10 +729,10 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable * this overrides the corresponding setting in the type (if any) */ private void setOverridable(CompoundName fieldName,boolean overridable,DimensionBinding dimensionBinding) { - QueryProfile parent= lookupParentExact(fieldName, true, dimensionBinding); - if (parent.overridable==null) - parent.overridable=new HashMap<>(); - parent.overridable.put(fieldName.last(),overridable); + QueryProfile parent = lookupParentExact(fieldName, true, dimensionBinding); + if (parent.overridable == null) + parent.overridable = new HashMap<>(); + parent.overridable.put(fieldName.last(), overridable); } /** Sets a value to a (possibly non-local) node. The parent query profile holding the value is returned */ @@ -740,7 +740,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable DimensionBinding dimensionBinding, QueryProfileRegistry registry) { ensureNotFrozen(); if (name.isCompound()) { - QueryProfile parent= getQueryProfileExact(name.first(), true, dimensionBinding); + QueryProfile parent = getQueryProfileExact(name.first(), true, dimensionBinding); parent.setNode(name.rest(), value,parentType, dimensionBinding.createFor(parent.getDimensions()), registry); } else { @@ -773,8 +773,8 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable * @return the created profile, or null if not present, and create is false */ private QueryProfile getQueryProfileExact(String localName, boolean create, DimensionBinding dimensionBinding) { - Object node=localExactLookup(localName, dimensionBinding); - if (node!=null && node instanceof QueryProfile) { + Object node = localExactLookup(localName, dimensionBinding); + if (node != null && node instanceof QueryProfile) { return (QueryProfile)node; } if (!create) return null; @@ -826,7 +826,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable } } - private static final Pattern namePattern=Pattern.compile("[$a-zA-Z_/][-$a-zA-Z0-9_/()]*"); + private static final Pattern namePattern = Pattern.compile("[$a-zA-Z_/][-$a-zA-Z0-9_/()]*"); /** * Returns a compiled version of this which produces faster lookup times diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java b/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java index 46aa33174f7..095d83d2887 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java @@ -39,8 +39,8 @@ public class QueryProfileXMLReader { * @throws RuntimeException if <code>directory</code> is not a readable directory, or if there is some error in the XML */ public QueryProfileRegistry read(String directory) { - List<NamedReader> queryProfileReaders=new ArrayList<>(); - List<NamedReader> queryProfileTypeReaders=new ArrayList<>(); + List<NamedReader> queryProfileReaders = new ArrayList<>(); + List<NamedReader> queryProfileTypeReaders = new ArrayList<>(); try { File dir=new File(directory); if ( !dir.isDirectory() ) throw new IllegalArgumentException("Could not read query profiles: '" + @@ -86,16 +86,16 @@ public class QueryProfileXMLReader { * Read the XML file readers into a registry. This does not close the readers. * This method is used directly from the admin system. */ - public QueryProfileRegistry read(List<NamedReader> queryProfileTypeReaders,List<NamedReader> queryProfileReaders) { - QueryProfileRegistry registry=new QueryProfileRegistry(); + public QueryProfileRegistry read(List<NamedReader> queryProfileTypeReaders, List<NamedReader> queryProfileReaders) { + QueryProfileRegistry registry = new QueryProfileRegistry(); // Phase 1 - List<Element> queryProfileTypeElements=createQueryProfileTypes(queryProfileTypeReaders,registry.getTypeRegistry()); - List<Element> queryProfileElements=createQueryProfiles(queryProfileReaders,registry); + List<Element> queryProfileTypeElements = createQueryProfileTypes(queryProfileTypeReaders, registry.getTypeRegistry()); + List<Element> queryProfileElements = createQueryProfiles(queryProfileReaders, registry); // Phase 2 - fillQueryProfileTypes(queryProfileTypeElements,registry.getTypeRegistry()); - fillQueryProfiles(queryProfileElements,registry); + fillQueryProfileTypes(queryProfileTypeElements, registry.getTypeRegistry()); + fillQueryProfiles(queryProfileElements, registry); return registry; } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java index 305de2a3c70..c826d834d47 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java @@ -102,13 +102,14 @@ public class QueryProfileType extends FreezableSimpleComponent { * * @throws IllegalStateException if this is frozen */ - public Map<String,FieldDescription> declaredFields() { + public Map<String, FieldDescription> declaredFields() { ensureNotFrozen(); return Collections.unmodifiableMap(fields); } /** * Returns true if <i>this</i> is declared strict. + * * @throws IllegalStateException if this is frozen */ public boolean isDeclaredStrict() { @@ -118,6 +119,7 @@ public class QueryProfileType extends FreezableSimpleComponent { /** * Returns true if <i>this</i> is declared as match as path. + * * @throws IllegalStateException if this is frozen */ public boolean getDeclaredMatchAsPath() { diff --git a/container-search/src/main/java/com/yahoo/search/statistics/ElapsedTime.java b/container-search/src/main/java/com/yahoo/search/statistics/ElapsedTime.java index 82438bda34e..804ce1b496b 100644 --- a/container-search/src/main/java/com/yahoo/search/statistics/ElapsedTime.java +++ b/container-search/src/main/java/com/yahoo/search/statistics/ElapsedTime.java @@ -12,13 +12,11 @@ import java.util.Set; import static com.yahoo.search.statistics.TimeTracker.Activity.*; /** - * Basically a collection of TimeTracker instances. + * A collection of TimeTracker instances. * - * <p>This class may need a lot of restructuring as actual - * needs are mapped out. - * - * @author <a href="steinar@yahoo-inc.com">Steinar Knutsen</a> + * @author Steinar Knutsen */ +// This class may need a lot of restructuring as actual needs are mapped out. public class ElapsedTime { // An identity set is used to make it safe to do multiple merges. This may happen if @@ -35,16 +33,11 @@ public class ElapsedTime { private long fetcher(Activity toFetch, TimeTracker fetchFrom) { switch (toFetch) { - case SEARCH: - return fetchFrom.searchTime(); - case FILL: - return fetchFrom.fillTime(); - case PING: - return fetchFrom.pingTime(); - default: - return 0L; + case SEARCH: return fetchFrom.searchTime(); + case FILL: return fetchFrom.fillTime(); + case PING: return fetchFrom.pingTime(); + default: return 0L; } - } /** @@ -232,4 +225,5 @@ public class ElapsedTime { report.append("."); return report.toString(); } + } diff --git a/container-search/src/main/java/com/yahoo/search/statistics/TimeTracker.java b/container-search/src/main/java/com/yahoo/search/statistics/TimeTracker.java index 6112bc504d3..d2461dffc7a 100644 --- a/container-search/src/main/java/com/yahoo/search/statistics/TimeTracker.java +++ b/container-search/src/main/java/com/yahoo/search/statistics/TimeTracker.java @@ -1,25 +1,24 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.statistics; -import java.util.ArrayList; -import java.util.EnumMap; -import java.util.List; -import java.util.Map; - import com.yahoo.component.chain.Chain; import com.yahoo.prelude.Pong; import com.yahoo.processing.Processor; import com.yahoo.search.Result; import com.yahoo.search.Searcher; +import java.util.ArrayList; +import java.util.EnumMap; +import java.util.List; +import java.util.Map; + /** - * A container for storing time stamps throughout the - * lifetime of an Execution instance. + * A container for storing time stamps throughout the lifetime of an Execution instance. * * <p>Check state both when entering and exiting, to allow for arbitrary * new queries anywhere inside a search chain. * - * @author <a href="mailto:steinar@yahoo-inc.com">Steinar Knutsen</a> + * @author Steinar Knutsen */ public final class TimeTracker { @@ -214,8 +213,6 @@ public final class TimeTracker { } concludeState(now); initNewState(now, activity); - } else { - return; } } @@ -314,7 +311,7 @@ public final class TimeTracker { return typedSum(Activity.PING); } - private long returnfromState(int searcherIndex, boolean detailed) { + private long returnFromState(int searcherIndex, boolean detailed) { if (detailed) { return detailedMeasurements(searcherIndex, false); } else { @@ -350,7 +347,7 @@ public final class TimeTracker { } private void sampleReturn(int searcherIndex, boolean detailed, ElapsedTime elapsed) { - long now = returnfromState(searcherIndex, detailed); + long now = returnFromState(searcherIndex, detailed); if (searcherIndex == entryIndex) { concludeStateOnExit(now); if (elapsed != null) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java index 23dd841b0ef..5f8daa69ecf 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.TensorType; import java.util.Arrays; @@ -48,12 +49,11 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { * * @throws IllegalArgumentException if the name is not present in the ranking expression this was created with, and * ignoredUnknownValues is false - * @since 5.1.5 */ @Override public final void put(String name, Value value) { Integer index = nameToIndex().get(name); - if (index==null) { + if (index == null) { if (ignoreUnknownValues()) return; else @@ -70,24 +70,29 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { /** * Puts a value by index. * The value will be frozen if it isn't already. - * - * @since 5.1.5 */ public final void put(int index, Value value) { - values[index]=value.freeze(); + values[index] = value.freeze(); try { - doubleValues()[index]=value.asDouble(); + doubleValues()[index] = value.asDouble(); } catch (UnsupportedOperationException e) { - doubleValues()[index]=Double.NaN; // see getDouble below + doubleValues()[index] = Double.NaN; // see getDouble below } } + @Override + public TensorType getType(String name) { + Integer index = nameToIndex().get(name); + if (index == null) return null; + return values[index].type(); + } + /** Perform a slow lookup by name */ @Override public Value get(String name) { - Integer index=nameToIndex().get(name); - if (index==null) return DoubleValue.zero; + Integer index = nameToIndex().get(name); + if (index == null) return DoubleValue.zero; return values[index]; } @@ -100,8 +105,8 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { /** Perform a fast lookup directly of the value as a double. This is faster than get(index).asDouble() */ @Override public final double getDouble(int index) { - double value=doubleValues()[index]; - if (value==Double.NaN) + double value = doubleValues()[index]; + if (value == Double.NaN) throw new UnsupportedOperationException("Value at " + index + " has no double representation"); return value; } @@ -111,7 +116,7 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { * in a different thread (i.e, name name to index map, different value set. */ public ArrayContext clone() { - ArrayContext clone=(ArrayContext)super.clone(); + ArrayContext clone = (ArrayContext)super.clone(); clone.values = new Value[nameToIndex().size()]; Arrays.fill(values,constantZero); return clone; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 0eeb0a9e630..861f9565d66 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -17,7 +17,7 @@ import java.util.stream.Collectors; public abstract class Context implements EvaluationContext { /** - * <p>Returns the value of a simple variable name.</p> + * Returns the value of a simple variable name. * * @param name the name of the variable whose value to return. * @return the value of the named variable. diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index 2ef4a2ede2f..3ac11cff0cb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -4,18 +4,19 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; /** * A value which acts as a double in numerical context. * * @author bratseth - * @since 5.1.21 */ public abstract class DoubleCompatibleValue extends Value { @Override + public TensorType type() { return TensorType.empty; } + + @Override public boolean hasDouble() { return true; } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java index ceec9358b3c..0625e8506cc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.TensorType; /** * A variant of an array context variant which supports faster binding of variables but slower lookup @@ -38,7 +39,6 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { * * @throws IllegalArgumentException if the name is not present in the ranking expression this was created with, and * ignoredUnknownValues is false - * @since 5.1.5 */ @Override public final void put(String name, Value value) { @@ -57,11 +57,7 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { doubleValues()[index] = value; } - /** - * Puts a value by index. - * - * @since 5.1.5 - */ + /** Puts a value by index. */ public final void put(int index, Value value) { try { put(index, value.asDouble()); @@ -71,6 +67,9 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { } } + @Override + public TensorType getType(String name) { return TensorType.empty; } + /** Perform a slow lookup by name */ @Override public Value get(String name) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index 09895a0c2f6..39efe641f26 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.evaluation; +import com.yahoo.tensor.TensorType; + import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -13,7 +15,7 @@ import java.util.Set; */ public class MapContext extends Context { - private Map<String,Value> bindings=new HashMap<>(); + private Map<String, Value> bindings = new HashMap<>(); private boolean frozen = false; @@ -21,16 +23,6 @@ public class MapContext extends Context { } /** - * Freezes this. - * Returns this for convenience. - */ - public MapContext freeze() { - if ( ! frozen) - bindings = Collections.unmodifiableMap(bindings); - return this; - } - - /** * Creates a map context from a map. * The ownership of the map is transferred to this - it cannot be further modified by the caller. * All the Values of the map will be frozen. @@ -42,20 +34,31 @@ public class MapContext extends Context { } /** - * Returns the value of a key. 0 is returned if the given key is not bound in this. + * Freezes this. + * Returns this for convenience. */ + public MapContext freeze() { + if ( ! frozen) + bindings = Collections.unmodifiableMap(bindings); + return this; + } + + /** Returns the type of the given value key, or null if it is not bound. */ @Override - public Value get(String key) { + public TensorType getType(String key) { Value value = bindings.get(key); - if (value == null) return DoubleValue.zero; - return value; + if (value == null) return null; + return value.type(); + } + + /** Returns the value of a key. 0 is returned if the given key is not bound in this. */ + @Override + public Value get(String key) { + return bindings.getOrDefault(key, DoubleValue.zero); } /** - * Sets the value of a key. - * The value is frozen by this. - * - * @since 5.1.5 + * Sets the value of a key.The value is frozen by this. */ @Override public void put(String key,Value value) { @@ -67,7 +70,7 @@ public class MapContext extends Context { if (frozen) return bindings; return Collections.unmodifiableMap(bindings); } - + /** Returns a new, modifiable context containing all the bindings of this */ public MapContext thawedCopy() { return new MapContext(new HashMap<>(bindings)); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index dad69b31181..c60507310f1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -5,7 +5,6 @@ import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; /** @@ -30,6 +29,9 @@ public class StringValue extends Value { this.value = value; } + @Override + public TensorType type() { return TensorType.empty; } + /** Returns the hashcode of this, to enable strings to be encoded (with reasonable safely) as doubles for optimization */ @Override public double asDouble() { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 26c30fe5ed2..c6e456f285d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -6,6 +6,7 @@ import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; /** * A Value containing a tensor. @@ -25,6 +26,9 @@ public class TensorValue extends Value { } @Override + public TensorType type() { return TensorType.empty; } + + @Override public double asDouble() { if (hasDouble()) return value.get(TensorAddress.of()); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java new file mode 100644 index 00000000000..a018aae0c3e --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java @@ -0,0 +1,28 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.HashMap; +import java.util.Map; + +/** + * A context which only contains type information. + * + * @author bratseth + */ +public class TypeMapContext implements TypeContext { + + private final Map<String, TensorType> featureTypes = new HashMap<>(); + + public void setType(String name, TensorType type) { + featureTypes.put(name, type); + } + + @Override + public TensorType getType(String name) { + return featureTypes.get(name); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 40d70e0022c..59d2d95b879 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -13,12 +13,15 @@ import com.yahoo.tensor.TensorType; * Concrete subclasses of this provides implementations of these methods or throws * UnsupportedOperationException if the operation is not supported. * - * @author bratseth + * @author bratseth */ public abstract class Value { private boolean frozen=false; + /** Returns the type of this value */ + public abstract TensorType type(); + /** Returns this value as a double, or throws UnsupportedOperationException if it cannot be represented as a double */ public abstract double asDouble(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java index 372fb00431b..8ee4cdbf297 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java @@ -7,6 +7,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Deque; @@ -24,6 +26,9 @@ public class GBDTForestNode extends ExpressionNode { } @Override + public final TensorType type(TypeContext context) { return TensorType.empty; } + + @Override public final Value evaluate(Context context) { int pc = 0; double treeSum = 0; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java index 4d7b4835892..aac635b2545 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java @@ -7,6 +7,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Deque; @@ -49,6 +51,9 @@ public final class GBDTNode extends ExpressionNode { public final double[] values() { return values; } @Override + public final TensorType type(TypeContext context) { return TensorType.empty; } + + @Override public final Value evaluate(Context context) { return new DoubleValue(evaluate(values,0,context)); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index 816ef38e128..cdcb4df0360 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -147,14 +147,12 @@ class OperationMapper { return operation.get().map(params); } params.signature().importWarning("TensorFlow operation '" + params.node().getOp() + - "' in node '" + params.node().getName() + "' is not supported."); + "' in node '" + params.node().getName() + "' is not supported."); return Optional.empty(); } - /* - * Operations - */ + // Operations --------------------------------- private static Optional<TypedTensorFunction> constant(TensorFlowImporter.Parameters params) { Tensor value = AttrValueConverter.toVespaTensor(params.node(), "value"); @@ -209,10 +207,11 @@ class OperationMapper { TensorType type = params.result().arguments().get(name); if (type == null) { throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name + - "', but there is no such placeholder"); + "', but there is no such placeholder"); } + params.result().requiredMacro(name, type); // Included literally in the expression and so must be produced by a separate macro in the rank profile - TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(name)); + TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(name, type)); return Optional.of(output); } @@ -227,7 +226,7 @@ class OperationMapper { } private static Optional<TypedTensorFunction> reshape(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 2)) { + if ( ! checkInputs(params, 2)) { return Optional.empty(); } List<Optional<TypedTensorFunction>> inputs = params.inputs(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 3a6b3f23a1d..6d78b501fdc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -14,7 +14,6 @@ import org.tensorflow.framework.TensorInfo; import java.io.File; import java.io.IOException; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -109,10 +108,10 @@ public class TensorFlowImporter { } Optional<TypedTensorFunction> function = OperationMapper.map(params); - if (!function.isPresent()) { + if ( ! function.isPresent()) { return Optional.empty(); } - if (!controlDependenciesArePresent(params)) { + if ( ! controlDependenciesArePresent(params)) { return Optional.empty(); } params.imported().put(nodeName, function.get()); @@ -185,6 +184,7 @@ public class TensorFlowImporter { /** Parameter object to hold important data while importing */ static final class Parameters { + private final TensorFlowImporter owner; private final GraphDef graph; private final SavedModelBundle model; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java index 60aaf8ddce1..fe725e50a3f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java @@ -27,11 +27,13 @@ public class TensorFlowModel { private final Map<String, Tensor> constants = new HashMap<>(); private final Map<String, RankingExpression> expressions = new HashMap<>(); private final Map<String, RankingExpression> macros = new HashMap<>(); + private final Map<String, TensorType> requiredMacros = new HashMap<>(); void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } void constant(String name, Tensor constant) { constants.put(name, constant); } void expression(String name, RankingExpression expression) { expressions.put(name, expression); } void macro(String name, RankingExpression expression) { macros.put(name, expression); } + void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } /** Returns the given signature. If it does not already exist it is added to this. */ Signature signature(String name) { @@ -51,11 +53,12 @@ public class TensorFlowModel { */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } - /** - * Returns an immutable map of expressions that can be overridden - such as PlaceholderWithDefault/ - */ + /** Returns an immutable map of macros that are part of this model */ public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); } + /** Returns an immutable map of the macros that must be provided by the environment running this model */ + public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); } + /** Returns an immutable map of the signatures of this */ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java index 518a15bcc87..fc6428a4c33 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java @@ -4,8 +4,15 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.functions.Join; -import java.util.*; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; /** * A binary mathematical operation @@ -73,14 +80,26 @@ public final class ArithmeticNode extends CompositeNode { } @Override + public TensorType type(TypeContext context) { + // Compute type using tensor types as arithmetic operators are supported on tensors + // and is correct also in the special case of doubles. + // As all our functions are type-commutative, we don't need to take operator precedence into account + TensorType type = children.get(0).type(context); + for (int i = 1; i < children.size(); i++) + type = Join.outputType(type, children.get(i).type(context)); + return type; + } + + @Override public Value evaluate(Context context) { Iterator<ExpressionNode> child = children.iterator(); + // Apply in precedence order: Deque<ValueItem> stack = new ArrayDeque<>(); stack.push(new ValueItem(ArithmeticOperator.OR, child.next().evaluate(context))); for (Iterator<ArithmeticOperator> it = operators.iterator(); it.hasNext() && child.hasNext();) { ArithmeticOperator op = it.next(); - if (!stack.isEmpty()) { + if ( ! stack.isEmpty()) { while (stack.peek().op.hasPrecedenceOver(op)) { popStack(stack); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java index 9484f789169..7601c0e6180 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java @@ -1,11 +1,14 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; -import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; -import java.util.*; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; /** * A node which returns the outcome of a comparison. @@ -46,6 +49,11 @@ public class ComparisonNode extends BooleanNode { } @Override + public TensorType type(TypeContext context) { + return TensorType.empty; // by definition + } + + @Override public Value evaluate(Context context) { Value leftValue = leftCondition.evaluate(context); Value rightValue = rightCondition.evaluate(context); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java index cd473ae6a6f..1ea8d03f0eb 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java @@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Deque; @@ -47,6 +49,9 @@ public final class ConstantNode extends ExpressionNode { } @Override + public TensorType type(TypeContext context) { return value.type(); } + + @Override public Value evaluate(Context context) { return value; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java index b5d7c41d698..fd9fab99db8 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java @@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; @@ -48,6 +50,11 @@ public final class EmbracedNode extends CompositeNode { } @Override + public TensorType type(TypeContext context) { + return value.type(context); + } + + @Override public CompositeNode setChildren(List<ExpressionNode> newChildren) { if (newChildren.size() != 1) throw new IllegalArgumentException("Expected 1 child but got " + newChildren.size()); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java index eb303fc6446..477f4db4981 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java @@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.io.Serializable; import java.util.Deque; @@ -41,6 +43,14 @@ public abstract class ExpressionNode implements Serializable { public abstract String toString(SerializationContext context, Deque<String> path, CompositeNode parent); /** + * Returns the type this will return if evaluated with the given context. + * + * @param context the variable type bindings to use for this evaluation + * @throws IllegalArgumentException if there are variables which are not bound in the given map + */ + public abstract TensorType type(TypeContext context); + + /** * Returns the value of evaluating this expression over the given context. * * @param context the variable bindings to use for this evaluation diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java index 142e282e5c6..79515229019 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java @@ -4,6 +4,9 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.functions.Join; import java.util.ArrayList; import java.util.Collections; @@ -64,16 +67,29 @@ public final class FunctionNode extends CompositeNode { } @Override + public TensorType type(TypeContext context) { + if (arguments.expressions().size() == 0) + return TensorType.empty; + + TensorType argument1Type = arguments.expressions().get(0).type(context); + if (arguments.expressions().size() == 1) + return argument1Type; + + TensorType argument2Type = arguments.expressions().get(1).type(context); + return Join.outputType(argument1Type, argument2Type); + } + + @Override public Value evaluate(Context context) { if (arguments.expressions().size() == 0) - return DoubleValue.zero.function(function,DoubleValue.zero); + return DoubleValue.zero.function(function ,DoubleValue.zero); Value argument1 = arguments.expressions().get(0).evaluate(context); if (arguments.expressions().size() == 1) return argument1.function(function, DoubleValue.zero); Value argument2 = arguments.expressions().get(1).evaluate(context); - return argument1.function(function,argument2); + return argument1.function(function, argument2); } /** Returns a new function node with the children replaced by the given children */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java index 9da1ba40144..e42884ecc05 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java @@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; @@ -46,6 +47,9 @@ public class GeneratorLambdaFunctionNode extends CompositeNode { return generator.toString(context, path, this); } + @Override + public TensorType type(TypeContext context) { return type; } + /** Evaluate this in a context which must have the arguments bound */ @Override public Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java index 1b429de0be5..076df327044 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java @@ -3,13 +3,18 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.List; /** * A conditional branch of a ranking expression. * - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen * @author bratseth */ public final class IfNode extends CompositeNode { @@ -70,6 +75,17 @@ public final class IfNode extends CompositeNode { } @Override + public TensorType type(TypeContext context) { + TensorType trueType = trueExpression.type(context); + TensorType falseType = falseExpression.type(context); + if ( ! trueType.equals(falseType)) + throw new IllegalArgumentException("An if expression must produce a value of the same type in both " + + "alternatives, but the 'true' type is " + trueType + " while the " + + "'false' type is " + falseType); + return trueType; + } + + @Override public Value evaluate(Context context) { if (condition.evaluate(context).asBoolean()) return trueExpression.evaluate(context); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index 78206d75d0d..da946228291 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -5,6 +5,8 @@ import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; @@ -14,20 +16,20 @@ import java.util.function.DoubleUnaryOperator; /** * A free, parametrized function - * + * * @author bratseth */ public class LambdaFunctionNode extends CompositeNode { private final ImmutableList<String> arguments; private final ExpressionNode functionExpression; - + public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) { // TODO: Verify that the function only accesses the given arguments this.arguments = ImmutableList.copyOf(arguments); this.functionExpression = functionExpression; } - + @Override public List<ExpressionNode> children() { return Collections.singletonList(functionExpression); @@ -54,19 +56,24 @@ public class LambdaFunctionNode extends CompositeNode { return b.toString(); } + @Override + public TensorType type(TypeContext context) { + return TensorType.empty; // by definition - no nested lambdas + } + /** Evaluate this in a context which must have the arguments bound */ @Override public Value evaluate(Context context) { return functionExpression.evaluate(context); } - - /** + + /** * Returns this as a double unary operator - * - * @throws IllegalStateException if this has more than one argument + * + * @throws IllegalStateException if this has more than one argument */ public DoubleUnaryOperator asDoubleUnaryOperator() { - if (arguments.size() > 1) + if (arguments.size() > 1) throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " + "Must have at most one argument " + " but has " + arguments); return new DoubleUnaryLambda(); @@ -93,7 +100,7 @@ public class LambdaFunctionNode extends CompositeNode { context.put(arguments.get(0), operand); return evaluate(context).asDouble(); } - + @Override public String toString() { return LambdaFunctionNode.this.toString(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java index 69df572272a..f55ed59b65c 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java @@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Deque; @@ -30,6 +32,9 @@ public final class NameNode extends ExpressionNode { } @Override + public TensorType type(TypeContext context) { throw new RuntimeException("Named nodes can not have a type"); } + + @Override public Value evaluate(Context context) { throw new RuntimeException("Name nodes should never be evaluated"); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java index 61c20a97b64..9cbe5f98c72 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java @@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; @@ -36,6 +38,11 @@ public class NegativeNode extends CompositeNode { } @Override + public TensorType type(TypeContext context) { + return value.type(context); + } + + @Override public Value evaluate(Context context) { return value.evaluate(context).negate(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java index 8c459a032bd..e7041600635 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java @@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; @@ -36,6 +38,11 @@ public class NotNode extends BooleanNode { } @Override + public TensorType type(TypeContext context) { + return value.type(context); + } + + @Override public Value evaluate(Context context) { return value.evaluate(context).not(); } @@ -45,6 +52,6 @@ public class NotNode extends BooleanNode { if (children.size() != 1) throw new IllegalArgumentException("Expected 1 children but got " + children.size()); return new NotNode(children.get(0)); } - + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index 139709998b4..f79297f7773 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -5,6 +5,8 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayDeque; import java.util.Deque; @@ -44,9 +46,8 @@ public final class ReferenceNode extends CompositeNode { return new ReferenceNode(name, arguments, output); } - public String getOutput() { - return output; - } + /** Returns the specific output this references, or null if none specified */ + public String getOutput() { return output; } /** Returns a copy of this node with a modified output */ public ReferenceNode setOutput(String output) { @@ -105,8 +106,14 @@ public final class ReferenceNode extends CompositeNode { } @Override + public TensorType type(TypeContext context) { + // Don't support outputs of different type, for simplicity + return context.getType(name); + } + + @Override public Value evaluate(Context context) { - if (arguments.expressions().size()==0 && output==null) + if (arguments.expressions().isEmpty() && output == null) return context.get(name); return context.get(name, arguments, output); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java index f6b1a1a8979..a7b82f4753f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java @@ -7,6 +7,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayList; import java.util.Deque; @@ -58,6 +60,11 @@ public class SetMembershipNode extends BooleanNode { } @Override + public TensorType type(TypeContext context) { + return TensorType.empty; + } + + @Override public Value evaluate(Context context) { Value value = testValue.evaluate(context); if (value instanceof TensorValue) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index b42570d3aea..e4c381972e9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -6,7 +6,9 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.functions.PrimitiveTensorFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; @@ -35,7 +37,7 @@ public class TensorFunctionNode extends CompositeNode { @Override public List<ExpressionNode> children() { - return function.functionArguments().stream() + return function.arguments().stream() .map(this::toExpressionNode) .collect(Collectors.toList()); } @@ -52,7 +54,7 @@ public class TensorFunctionNode extends CompositeNode { List<TensorFunction> wrappedChildren = children.stream() .map(TensorFunctionExpressionNode::new) .collect(Collectors.toList()); - return new TensorFunctionNode(function.replaceArguments(wrappedChildren)); + return new TensorFunctionNode(function.withArguments(wrappedChildren)); } @Override @@ -62,6 +64,9 @@ public class TensorFunctionNode extends CompositeNode { } @Override + public TensorType type(TypeContext context) { return function.type(context); } + + @Override public Value evaluate(Context context) { return new TensorValue(function.evaluate(context)); } @@ -84,7 +89,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public List<TensorFunction> functionArguments() { + public List<TensorFunction> arguments() { if (expression instanceof CompositeNode) return ((CompositeNode)expression).children().stream() .map(TensorFunctionExpressionNode::new) @@ -94,7 +99,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if (arguments.size() == 0) return this; List<ExpressionNode> unwrappedChildren = arguments.stream() .map(arg -> ((TensorFunctionExpressionNode)arg).expression) @@ -106,12 +111,17 @@ public class TensorFunctionNode extends CompositeNode { public PrimitiveTensorFunction toPrimitive() { return this; } @Override + public TensorType type(TypeContext context) { + return expression.type(context); + } + + @Override public Tensor evaluate(EvaluationContext context) { Value result = expression.evaluate((Context)context); if ( ! ( result instanceof TensorValue)) throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " + "but this returns " + result + ", not a tensor"); - return ((TensorValue)result).asTensor(); + return result.asTensor(); } @Override diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java index b59b4750911..445ccf231a7 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -2,10 +2,12 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.TensorType; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; /** * @author lesters @@ -15,6 +17,18 @@ public class DropoutImportTestCase { @Test public void testDropoutImport() { TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/dropout/saved"); + + // Check (provided) macros + assertEquals(1, model.get().macros().size()); + assertTrue(model.get().macros().containsKey("training/input")); + assertEquals("constant(\"training/input\")", model.get().macros().get("training/input").getRoot().toString()); + + // Check required macros + assertEquals(1, model.get().requiredMacros().size()); + assertTrue(model.get().requiredMacros().containsKey("X")); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), + model.get().requiredMacros().get("X")); + TensorFlowModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index f12b9a2c628..01dd15d5fa0 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -8,6 +8,7 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; /** * @author bratseth @@ -33,6 +34,15 @@ public class MnistSoftmaxImportTestCase { constant1.type()); assertEquals(10, constant1.size()); + // Check (provided) macros + assertEquals(0, model.get().macros().size()); + + // Check required macros + assertEquals(1, model.get().requiredMacros().size()); + assertTrue(model.get().requiredMacros().containsKey("Placeholder")); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), + model.get().requiredMacros().get("Placeholder")); + // Check signatures assertEquals(1, model.get().signatures().size()); TensorFlowModel.Signature signature = model.get().signatures().get("serving_default"); diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java index cfa9ce670f6..6095134b7a2 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java @@ -15,14 +15,14 @@ import static com.yahoo.vespa.http.client.SessionFactory.createTimeoutExecutor; public class FeedClientFactory { /** - * Creates FeedClient. + * Creates a FeedClient. + * * @param sessionParams parameters for connection, hosts, cluster configurations and more. * @param resultCallback on each result, this callback is called. * @return newly created FeedClient API object. */ - public static FeedClient create( - SessionParams sessionParams, - FeedClient.ResultCallback resultCallback) { + public static FeedClient create(SessionParams sessionParams, FeedClient.ResultCallback resultCallback) { return new FeedClientImpl(sessionParams, resultCallback, createTimeoutExecutor()); } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java index 138be61de80..65f56f72a58 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java @@ -20,6 +20,7 @@ import java.util.List; */ // This should be an interface, but in order to be binary compatible during refactoring we made it abstract. public class Result { + public enum ResultType { OPERATION_EXECUTED, TRANSITIVE_ERROR, @@ -106,12 +107,10 @@ public class Result { /** * Information in a Result for a single operation sent to a single endpoint. - * - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> - * @since 5.1.20 */ @Immutable public static final class Detail { + private final ResultType resultType; private final Endpoint endpoint; private final Exception exception; @@ -212,4 +211,5 @@ public class Result { } return b.toString(); } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/EndpointResult.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/EndpointResult.java index 5aec46a8fc7..b04248f98a5 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/EndpointResult.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/EndpointResult.java @@ -5,9 +5,11 @@ import com.yahoo.vespa.http.client.Result; /** * Result from a single endpoint. + * * @author dybis */ public class EndpointResult { + private final String operationId; private final Result.Detail detail; @@ -23,4 +25,5 @@ public class EndpointResult { public Result.Detail getDetail() { return detail; } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ErrorCode.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ErrorCode.java index 96afc537c59..445ad5295c1 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ErrorCode.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ErrorCode.java @@ -6,12 +6,12 @@ import com.google.common.annotations.Beta; /** * Return types for the server. * - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> - * @author <a href="mailto:steinar@yahoo-inc.com">Steinar Knutsen</a> - * @since 5.1.20 + * @author Einar M R Rosenvinge + * @author Steinar Knutsen */ @Beta public enum ErrorCode { + OK(true, true), ERROR(false, false), TRANSIENT_ERROR(false, true), @@ -20,7 +20,7 @@ public enum ErrorCode { private boolean success; private boolean _transient; - private ErrorCode(boolean success, boolean _transient) { + ErrorCode(boolean success, boolean _transient) { this.success = success; this._transient = _transient; } @@ -32,4 +32,5 @@ public enum ErrorCode { public boolean isTransient() { return _transient; } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/OperationStatus.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/OperationStatus.java index 7ea4a214cbd..7aec207e0ab 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/OperationStatus.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/OperationStatus.java @@ -10,11 +10,11 @@ import java.util.Iterator; /** * Serialization/deserialization class for the result of a single document operation against Vespa. * - * @author <a href="mailto:steinar@yahoo-inc.com">Steinar Knutsen</a> - * @since 5.1 + * @author Steinar Knutsen */ @Beta public final class OperationStatus { + public static final String IS_CONDITION_NOT_MET = "IS-CONDITION-NOT-MET"; public final String message; public final String operationId; @@ -81,9 +81,7 @@ public final class OperationStatus { return new OperationStatus(message, operationId, errorCode, isConditionNotMet, traceMessage); } - /** - * @return a string representing the status. - */ + /** Returns a string representing the status. */ public String render() { StringBuilder s = new StringBuilder(); Encoder.encode(operationId, s).append(SEPARATOR); @@ -92,4 +90,5 @@ public final class OperationStatus { Encoder.encode(traceMessage, s).append(EOL); return s.toString(); } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ServerResponseException.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ServerResponseException.java index 4c291935916..1800864cd90 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ServerResponseException.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/ServerResponseException.java @@ -5,12 +5,13 @@ import com.google.common.annotations.Beta; /** * The request was not processed properly on the server. - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> - * @since 5.1.20 + * + * @author Einar M R Rosenvinge */ @SuppressWarnings("serial") @Beta public class ServerResponseException extends Exception { + private final int responseCode; private final String responseString; @@ -33,5 +34,6 @@ public class ServerResponseException extends Exception { } return responseString; } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java index a16d992324d..903c1ad4842 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java @@ -17,7 +17,7 @@ import java.util.concurrent.TimeUnit; /** * Implementation of FeedClient. It is a thin layer on top of multiClusterHandler and multiClusterResultAggregator. - * + * * @author dybis */ public class FeedClientImpl implements FeedClient { @@ -92,4 +92,5 @@ public class FeedClientImpl implements FeedClient { } return true; } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java index edbe5542bc4..c3b5d9912de 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java @@ -6,7 +6,6 @@ import com.yahoo.vespa.http.client.Result; import com.yahoo.vespa.http.client.Session; import com.yahoo.vespa.http.client.config.SessionParams; import com.yahoo.vespa.http.client.core.ThrottlePolicy; -import com.yahoo.vespa.http.client.core.api.MultiClusterSessionOutputStream; import com.yahoo.vespa.http.client.core.operationProcessor.IncompleteResultsThrottler; import com.yahoo.vespa.http.client.core.operationProcessor.OperationProcessor; @@ -65,4 +64,5 @@ public class SessionImpl implements Session { public int getIncompleteResultQueueSize() { return operationProcessor.getIncompleteResultQueueSize(); } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java index 420f64d4bf3..dd724e90a42 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java @@ -44,8 +44,6 @@ import java.util.zip.GZIPOutputStream; /** * @author Einar M R Rosenvinge - * - * @since 5.1.20 */ @Beta class ApacheGatewayConnection implements GatewayConnection { diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java index 671e6f07dbe..d963ae79227 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java @@ -2,8 +2,6 @@ package com.yahoo.vespa.http.client.core.communication; import com.yahoo.vespa.http.client.core.Document; -import com.yahoo.vespa.http.client.core.EndpointResult; -import com.yahoo.vespa.http.client.core.operationProcessor.EndPointResultFactory; import java.util.ArrayDeque; import java.util.ArrayList; @@ -13,7 +11,7 @@ import java.util.Optional; import java.util.concurrent.TimeUnit; /** - * Document queue that only gives you document operations on documents for which there are no + * Document queue that only gives you document operations on documents for which there are no * already in flight operations for. * * @author dybis @@ -54,8 +52,6 @@ class DocumentQueue { } } - - Document poll(long timeout, TimeUnit unit) throws InterruptedException { synchronized (queue) { long remainingToWait = unit.toMillis(timeout); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index b396f831de0..5b98a1b4fb5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -19,7 +19,6 @@ import java.util.stream.Collectors; * A tensor type with its dimensions. This is immutable. * <p> * A dimension can be indexed (bound or unbound) or mapped. - * Currently, we only support tensor types where all dimensions have the same type. * * @author geirst * @author bratseth @@ -27,6 +26,7 @@ import java.util.stream.Collectors; @Beta public class TensorType { + /** The empty tensor type - which is the same as a double */ public static final TensorType empty = new TensorType(Collections.emptyList()); /** Sorted list of the dimensions of this */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java index 3db661f8a23..3fb94f1251b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java @@ -10,7 +10,7 @@ import com.yahoo.tensor.Tensor; * @author bratseth */ @Beta -public interface EvaluationContext { +public interface EvaluationContext extends TypeContext { /** Returns the tensor bound to this name, or null if none */ Tensor getTensor(String name); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index db8a66a5fa2..9fe6b7d053f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.evaluation; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import java.util.HashMap; @@ -19,6 +20,13 @@ public class MapEvaluationContext implements EvaluationContext { public void put(String name, Tensor tensor) { bindings.put(name, tensor); } @Override + public TensorType getType(String name) { + Tensor tensor = bindings.get(name); + if (tensor == null) return null; + return tensor.type(); + } + + @Override public Tensor getTensor(String name) { return bindings.get(name); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java new file mode 100644 index 00000000000..4d3bd04c3d4 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java @@ -0,0 +1,21 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.evaluation; + +import com.yahoo.tensor.TensorType; + +/** + * Provides type information about a context (set of variable bindings). + * + * @author bratseth + */ +public interface TypeContext { + + /** + * Returns tye type of the tensor with this name. + * + * @return returns the type of the tensor which will be returned by calling getTensor(name) + * or null if getTensor will return null. + */ + TensorType getType(String name); + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index 1f6ad050368..34beb465d4c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -3,12 +3,14 @@ package com.yahoo.tensor.evaluation; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.PrimitiveTensorFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; import java.util.Collections; import java.util.List; +import java.util.Optional; /** * A tensor variable name which resolves to a tensor in the context at evaluation time @@ -19,23 +21,42 @@ import java.util.List; public class VariableTensor extends PrimitiveTensorFunction { private final String name; + private final Optional<TensorType> requiredType; public VariableTensor(String name) { this.name = name; + this.requiredType = Optional.empty(); + } + + /** A variable tensor which must be compatible with the given type */ + public VariableTensor(String name, TensorType requiredType) { + this.name = name; + this.requiredType = Optional.of(requiredType); } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { return this; } + public TensorFunction withArguments(List<TensorFunction> arguments) { return this; } @Override public PrimitiveTensorFunction toPrimitive() { return this; } @Override + public TensorType type(TypeContext context) { + TensorType givenType = context.getType(name); + if (givenType == null) return null; + verifyType(givenType); + return givenType; + } + + @Override public Tensor evaluate(EvaluationContext context) { - return context.getTensor(name); + Tensor tensor = context.getTensor(name); + if (tensor == null) return null; + verifyType(tensor.type()); + return tensor; } @Override @@ -43,4 +64,9 @@ public class VariableTensor extends PrimitiveTensorFunction { return name; } + private void verifyType(TensorType givenType) { + if (requiredType.isPresent() && ! givenType.isAssignableTo(requiredType.get())) + throw new IllegalArgumentException("Variable '" + name + "' must be compatible with " + + requiredType.get() + " but was " + givenType); + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java index 10f53670826..93365d20966 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java @@ -14,17 +14,17 @@ public class Argmax extends CompositeTensorFunction { private final TensorFunction argument; private final String dimension; - + public Argmax(TensorFunction argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Argmax must have 1 argument, got " + arguments.size()); return new Argmax(arguments.get(0), dimension); @@ -37,7 +37,7 @@ public class Argmax extends CompositeTensorFunction { new Reduce(primitiveArgument, Reduce.Aggregator.max, dimension), ScalarFunctions.equal()); } - + @Override public String toString(ToStringContext context) { return "argmax(" + argument.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java index d324aec53e9..e598cdf8a98 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java @@ -14,17 +14,17 @@ public class Argmin extends CompositeTensorFunction { private final TensorFunction argument; private final String dimension; - + public Argmin(TensorFunction argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Argmin must have 1 argument, got " + arguments.size()); return new Argmin(arguments.get(0), dimension); @@ -37,7 +37,7 @@ public class Argmin extends CompositeTensorFunction { new Reduce(primitiveArgument, Reduce.Aggregator.min, dimension), ScalarFunctions.equal()); } - + @Override public String toString(ToStringContext context) { return "argmin(" + argument.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 191c7988443..2109b730e1a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -3,7 +3,9 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; /** * A composite tensor function is a tensor function which can be expressed (less tersely) @@ -14,6 +16,10 @@ import com.yahoo.tensor.evaluation.EvaluationContext; @Beta public abstract class CompositeTensorFunction extends TensorFunction { + /** Finds the type this produces by first converting it to a primitive function */ + @Override + public final TensorType type(TypeContext context) { return toPrimitive().type(context); } + /** Evaluates this by first converting it to a primitive function */ @Override public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index d4affe0ef9b..c77ed1c0526 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -9,8 +9,14 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; - -import java.util.*; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; /** @@ -34,10 +40,10 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); } + public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if (arguments.size() != 2) throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size()); return new Concat(arguments.get(0), arguments.get(1), dimension); @@ -54,6 +60,20 @@ public class Concat extends PrimitiveTensorFunction { } @Override + public TensorType type(TypeContext context) { + return type(argumentA.type(context), argumentB.type(context)); + } + + /** Returns the type resulting from concatenating a and b */ + private TensorType type(TensorType a, TensorType b) { + TensorType.Builder builder = new TensorType.Builder(a, b); + if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size + builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() + + b.dimension(dimension).get().size().get())); + return builder.build(); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); @@ -63,7 +83,7 @@ public class Concat extends PrimitiveTensorFunction { IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor IndexedTensor bIndexed = (IndexedTensor) b; - TensorType concatType = concatType(a, b); + TensorType concatType = type(a.type(), b.type()); DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); @@ -115,15 +135,6 @@ public class Concat extends PrimitiveTensorFunction { } - /** Returns the type resulting from concatenating a and b */ - private TensorType concatType(Tensor a, Tensor b) { - TensorType.Builder builder = new TensorType.Builder(a.type(), b.type()); - if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size - builder.set(TensorType.Dimension.indexed(dimension, a.type().dimension(dimension).get().size().get() + - b.type().dimension(dimension).get().size().get())); - return builder.build(); - } - /** Returns the concrete (not type) dimension sizes resulting from combining a and b */ private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) { DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 14ed38718ce..50b479da168 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -3,7 +3,9 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.List; @@ -27,10 +29,10 @@ public class ConstantTensor extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("ConstantTensor must have 0 arguments, got " + arguments.size()); return this; @@ -40,6 +42,9 @@ public class ConstantTensor extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override + public TensorType type(TypeContext context) { return constant.type(); } + + @Override public Tensor evaluate(EvaluationContext context) { return constant; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java index 653be8dacf0..e302f6606e7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -25,10 +25,10 @@ public class Diag extends CompositeTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Diag must have 0 arguments, got " + arguments.size()); return this; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index ef2770c04f5..e70d1de3db7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -7,6 +7,7 @@ import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.List; @@ -47,10 +48,10 @@ public class Generate extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size()); return this; @@ -60,6 +61,9 @@ public class Generate extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override + public TensorType type(TypeContext context) { return type; } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor.Builder builder = Tensor.Builder.of(type); IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type)); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 174a8e4c435..7812c985091 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -11,6 +11,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayList; import java.util.Collections; @@ -47,6 +48,7 @@ public class Join extends PrimitiveTensorFunction { } /** Returns the type resulting from applying Join to the two given types */ + // TODO: Replace implementation by new TensorType.Builder(a.type(), b.type()).build(); public static TensorType outputType(TensorType a, TensorType b) { TensorType.Builder typeBuilder = new TensorType.Builder(); for (int i = 0; i < a.dimensions().size(); ++i) { @@ -70,15 +72,13 @@ public class Join extends PrimitiveTensorFunction { return typeBuilder.build(); } - public TensorFunction argumentA() { return argumentA; } - public TensorFunction argumentB() { return argumentB; } public DoubleBinaryOperator combinator() { return combinator; } @Override - public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); } + public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 2) throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size()); return new Join(arguments.get(0), arguments.get(1), combinator); @@ -95,6 +95,11 @@ public class Join extends PrimitiveTensorFunction { } @Override + public TensorType type(TypeContext context) { + return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build(); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java index 91a9c6d1b27..d7f7ae59d62 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java @@ -14,17 +14,17 @@ public class L1Normalize extends CompositeTensorFunction { private final TensorFunction argument; private final String dimension; - + public L1Normalize(TensorFunction argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("L1Normalize must have 1 argument, got " + arguments.size()); return new L1Normalize(arguments.get(0), dimension); @@ -38,7 +38,7 @@ public class L1Normalize extends CompositeTensorFunction { new Reduce(primitiveArgument, Reduce.Aggregator.sum, dimension), ScalarFunctions.divide()); } - + @Override public String toString(ToStringContext context) { return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java index bdf8921f81d..e2c526760bd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java @@ -14,17 +14,17 @@ public class L2Normalize extends CompositeTensorFunction { private final TensorFunction argument; private final String dimension; - + public L2Normalize(TensorFunction argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("L2Normalize must have 1 argument, got " + arguments.size()); return new L2Normalize(arguments.get(0), dimension); @@ -40,7 +40,7 @@ public class L2Normalize extends CompositeTensorFunction { ScalarFunctions.sqrt()), ScalarFunctions.divide()); } - + @Override public String toString(ToStringContext context) { return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index a5e1a016a41..53504868ff2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -2,12 +2,11 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; -import com.google.common.collect.ImmutableMap; -import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Iterator; @@ -39,10 +38,10 @@ public class Map extends PrimitiveTensorFunction { public DoubleUnaryOperator mapper() { return mapper; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Map must have 1 argument, got " + arguments.size()); return new Map(arguments.get(0), mapper); @@ -54,6 +53,11 @@ public class Map extends PrimitiveTensorFunction { } @Override + public TensorType type(TypeContext context) { + return argument.type(context); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor argument = argument().evaluate(context); Tensor.Builder builder = Tensor.Builder.of(argument.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index 4071917c2b5..935e4761cfe 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -27,10 +27,10 @@ public class Matmul extends CompositeTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); } + public List<TensorFunction> arguments() { return ImmutableList.of(argument1, argument2); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 2) throw new IllegalArgumentException("Matmul must have 2 arguments, got " + arguments.size()); return new Matmul(arguments.get(0), arguments.get(1), dimension); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java index 958ef85d1dc..1475f7f4ac1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java @@ -28,10 +28,10 @@ public class Random extends CompositeTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Random must have 0 arguments, got " + arguments.size()); return this; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java index 8e7f4e4c773..d951ec9ccbd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -26,10 +26,10 @@ public class Range extends CompositeTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Range must have 0 arguments, got " + arguments.size()); return this; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index de9f90a5804..76a938b9fe2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -8,6 +8,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.HashMap; @@ -73,10 +74,10 @@ public class Reduce extends PrimitiveTensorFunction { public TensorFunction argument() { return argument; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size()); return new Reduce(arguments.get(0), aggregator, dimensions); @@ -100,6 +101,19 @@ public class Reduce extends PrimitiveTensorFunction { } @Override + public TensorType type(TypeContext context) { + return type(argument.type(context)); + } + + private TensorType type(TensorType argumentType) { + TensorType.Builder builder = new TensorType.Builder(); + for (TensorType.Dimension dimension : argumentType.dimensions()) + if ( ! dimensions.contains(dimension.name())) // keep + builder.dimension(dimension); + return builder.build(); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor argument = this.argument.evaluate(context); if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) @@ -113,12 +127,7 @@ public class Reduce extends PrimitiveTensorFunction { else return reduceAllGeneral(argument); - // Reduce type - TensorType.Builder builder = new TensorType.Builder(); - for (TensorType.Dimension dimension : argument.type().dimensions()) - if ( ! dimensions.contains(dimension.name())) // keep - builder.dimension(dimension); - TensorType reducedType = builder.build(); + TensorType reducedType = type(argument.type()); // Reduce cells Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index ec9b762a41c..de3d2be265a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -7,6 +7,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.HashMap; @@ -26,6 +27,7 @@ public class Rename extends PrimitiveTensorFunction { private final TensorFunction argument; private final List<String> fromDimensions; private final List<String> toDimensions; + private final Map<String, String> fromToMap; public Rename(TensorFunction argument, String fromDimension, String toDimension) { this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension)); @@ -43,13 +45,24 @@ public class Rename extends PrimitiveTensorFunction { this.argument = argument; this.fromDimensions = ImmutableList.copyOf(fromDimensions); this.toDimensions = ImmutableList.copyOf(toDimensions); + this.fromToMap = fromToMap(fromDimensions, toDimensions); + } + + public List<String> fromDimensions() { return fromDimensions; } + public List<String> toDimensions() { return toDimensions; } + + private static Map<String, String> fromToMap(List<String> fromDimensions, List<String> toDimensions) { + Map<String, String> map = new HashMap<>(); + for (int i = 0; i < fromDimensions.size(); i++) + map.put(fromDimensions.get(i), toDimensions.get(i)); + return map; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Rename must have 1 argument, got " + arguments.size()); return new Rename(arguments.get(0), fromDimensions, toDimensions); @@ -59,11 +72,22 @@ public class Rename extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override + public TensorType type(TypeContext context) { + return type(argument.type(context)); + } + + private TensorType type(TensorType type) { + TensorType.Builder builder = new TensorType.Builder(); + for (TensorType.Dimension dimension : type.dimensions()) + builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); + return builder.build(); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor tensor = argument.evaluate(context); - Map<String, String> fromToMap = fromToMap(); - TensorType renamedType = rename(tensor.type(), fromToMap); + TensorType renamedType = type(tensor.type()); // an array which lists the index of each label in the renamed type int[] toIndexes = new int[tensor.type().dimensions().size()]; @@ -82,13 +106,6 @@ public class Rename extends PrimitiveTensorFunction { return builder.build(); } - private TensorType rename(TensorType type, Map<String, String> fromToMap) { - TensorType.Builder builder = new TensorType.Builder(); - for (TensorType.Dimension dimension : type.dimensions()) - builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); - return builder.build(); - } - private TensorAddress rename(TensorAddress address, int[] toIndexes) { String[] reorderedLabels = new String[toIndexes.length]; for (int i = 0; i < toIndexes.length; i++) @@ -102,13 +119,6 @@ public class Rename extends PrimitiveTensorFunction { toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; } - private Map<String, String> fromToMap() { - Map<String, String> map = new HashMap<>(); - for (int i = 0; i < fromDimensions.size(); i++) - map.put(fromDimensions.get(i), toDimensions.get(i)); - return map; - } - private String toVectorString(List<String> elements) { if (elements.size() == 1) return elements.get(0); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index c856b548180..32cff5ac84a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -16,21 +16,21 @@ public class Softmax extends CompositeTensorFunction { private final TensorFunction argument; private final String dimension; - + public Softmax(TensorFunction argument, String dimension) { this.argument = argument; this.dimension = dimension; } - + public static TensorType outputType(TensorType inputType, String dimension) { return Reduce.outputType(inputType, ImmutableList.of(dimension)); } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Softmax must have 1 argument, got " + arguments.size()); return new Softmax(arguments.get(0), dimension); @@ -45,7 +45,7 @@ public class Softmax extends CompositeTensorFunction { dimension), ScalarFunctions.divide()); } - + @Override public String toString(ToStringContext context) { return "softmax(" + argument.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 533a46f87fe..78ab09c7820 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -3,8 +3,10 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.List; @@ -19,14 +21,14 @@ import java.util.List; public abstract class TensorFunction { /** Returns the function arguments of this node in the order they are applied */ - public abstract List<TensorFunction> functionArguments(); + public abstract List<TensorFunction> arguments(); /** * Returns a copy of this tensor function with the arguments replaced by the given list of arguments. * * @throws IllegalArgumentException if the argument list has the wrong size for this function */ - public abstract TensorFunction replaceArguments(List<TensorFunction> arguments); + public abstract TensorFunction withArguments(List<TensorFunction> arguments); /** * Translate this function - and all of its arguments recursively - @@ -43,6 +45,13 @@ public abstract class TensorFunction { */ public abstract Tensor evaluate(EvaluationContext context); + /** + * Returns the type of the tensor this produces given the input types in the context + * + * @param context a context which must be passed to all nexted functions when evaluating + */ + public abstract TensorType type(TypeContext context); + /** Evaluate with no context */ public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java index 2464be981f5..78ff0731566 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java @@ -14,7 +14,7 @@ public class XwPlusB extends CompositeTensorFunction { private final TensorFunction x, w, b; private final String dimension; - + public XwPlusB(TensorFunction x, TensorFunction w, TensorFunction b, String dimension) { this.x = x; this.w = w; @@ -23,10 +23,10 @@ public class XwPlusB extends CompositeTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return ImmutableList.of(x, w, b); } + public List<TensorFunction> arguments() { return ImmutableList.of(x, w, b); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 3) throw new IllegalArgumentException("XwPlusB must have 3 arguments, got " + arguments.size()); return new XwPlusB(arguments.get(0), arguments.get(1), arguments.get(2), dimension); @@ -43,7 +43,7 @@ public class XwPlusB extends CompositeTensorFunction { primitiveB, ScalarFunctions.add()); } - + @Override public String toString(ToStringContext context) { return "xw_plus_b(" + x.toString(context) + ", " + |