diff options
author | Jon Marius Venstad <jonmv@users.noreply.github.com> | 2018-01-02 12:12:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-01-02 12:12:08 +0100 |
commit | 95ae8c562a8826f03bd2faad82b0ffb754133342 (patch) | |
tree | 35545093828437f6d7704b4b6c3646a39ff50a00 | |
parent | 21dd7e03056a77e1de75e6e95413c3b00e6615ec (diff) | |
parent | b73a4d2c8b5e7ae83743b10b8f21836811e5dff4 (diff) |
Merge branch 'master' into jvenstad/zone-cleanup-4
316 files changed, 9650 insertions, 2023 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java index c1a786194a2..7506c884715 100644 --- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java +++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java @@ -7,6 +7,7 @@ import com.yahoo.config.provision.Zone; import com.yahoo.path.Path; import com.yahoo.io.IOUtils; import com.yahoo.io.reader.NamedReader; +import com.yahoo.path.Path; import com.yahoo.text.XML; import com.yahoo.vespa.config.ConfigDefinitionKey; import org.w3c.dom.Element; @@ -14,8 +15,17 @@ import org.xml.sax.SAXException; import javax.xml.parsers.ParserConfigurationException; import javax.xml.transform.TransformerException; -import java.io.*; -import java.util.*; +import java.io.File; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.Reader; +import java.util.Collection; +import java.util.Collections; +import java.util.Enumeration; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.jar.JarEntry; import java.util.jar.JarFile; @@ -228,9 +238,9 @@ public interface ApplicationPackage { throw new UnsupportedOperationException("This application package cannot write its metadata"); } - /** - * Returns the single host allocation info of this, or an empty map if no allocation is available - * + /** + * Returns the single host allocation info of this, or an empty map if no allocation is available + * * @deprecated please use #getAllocatedHosts */ // TODO: Remove on Vespa 7 @@ -261,7 +271,8 @@ public interface ApplicationPackage { * * @return A new application package instance pointing to a new location */ - default ApplicationPackage preprocess(Zone zone, RuleConfigDeriver ruleConfigDeriver, DeployLogger logger) throws IOException, TransformerException, ParserConfigurationException, SAXException { + default ApplicationPackage preprocess(Zone zone, RuleConfigDeriver ruleConfigDeriver, DeployLogger logger) + throws IOException, TransformerException, ParserConfigurationException, SAXException { throw new UnsupportedOperationException("This application package does not support preprocessing"); } diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java b/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java index cee501841b4..61cab2f6ce7 100644 --- a/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java +++ b/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java @@ -4,10 +4,9 @@ package com.yahoo.config.application.api; import java.util.logging.Level; /** - * Used during application deployment to persist and propagate messages to end user + * Used during application deployment to propagate messages to the end user * - * @author lulf - * @since 5.1 + * @author Ulf Lillengen */ public interface DeployLogger { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java index 69e353ceb35..65176006a2a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java @@ -118,7 +118,7 @@ public class TensorTransformer extends ExpressionTransformer { private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) { ExpressionNode arg1 = node.children().get(0); ExpressionNode arg2 = node.children().get(1); - + TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1); Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name()); String dimension = ((ReferenceNode) arg2).getName(); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java index c52a5dc465d..f932265cb93 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java @@ -259,7 +259,7 @@ public final class Attribute implements Cloneable, Serializable { throw new IllegalArgumentException("Field " + fieldType + " not supported in convertCollectionType"); } } - + private static Optional<TensorType> convertTensorType(DataType fieldType) { if ( ! ( fieldType instanceof TensorDataType)) return Optional.empty(); return Optional.of(((TensorDataType)fieldType).getTensorType()); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java index c8918f39834..8b6df1a87db 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java @@ -29,7 +29,7 @@ public class ImmutableImportedSDField implements ImmutableSDField { @Override public <T extends Expression> boolean containsExpression(Class<T> searchFor) { - throw createUnsupportedException(); + throw createUnsupportedException(searchFor.getSimpleName()); } @Override @@ -79,9 +79,9 @@ public class ImmutableImportedSDField implements ImmutableSDField { @Override public Index getIndex(String name) { - if (!importedField.fieldName().equals(name)) { + if ( ! importedField.fieldName().equals(name)) { throw new IllegalArgumentException("Getting an index (" + name + ") with different name than the imported field (" - + importedField.fieldName() + ") is not supported"); + + importedField.fieldName() + ") is not supported"); } String targetIndexName = importedField.targetField().getName(); return importedField.targetField().getIndex(targetIndexName); @@ -104,7 +104,7 @@ public class ImmutableImportedSDField implements ImmutableSDField { @Override public ScriptExpression getIndexingScript() { - throw createUnsupportedException(); + throw createUnsupportedException("indexing"); } @Override @@ -119,12 +119,12 @@ public class ImmutableImportedSDField implements ImmutableSDField { @Override public ImmutableSDField getStructField(String name) { - throw createUnsupportedException(); + throw createUnsupportedException("struct"); } @Override public Collection<? extends ImmutableSDField> getStructFields() { - throw createUnsupportedException(); + throw createUnsupportedException("struct"); } @Override @@ -134,12 +134,12 @@ public class ImmutableImportedSDField implements ImmutableSDField { @Override public Stemming getStemming(Search search) { - throw createUnsupportedException(); + throw createUnsupportedException("stemming"); } @Override public Ranking getRanking() { - throw createUnsupportedException(); + throw createUnsupportedException("ranking"); } @Override @@ -158,8 +158,8 @@ public class ImmutableImportedSDField implements ImmutableSDField { importedField.targetField().getDataType()); } - private static UnsupportedOperationException createUnsupportedException() { - return new UnsupportedOperationException("This aspect is not meaningful or relevant for an imported field."); + private static UnsupportedOperationException createUnsupportedException(String aspect) { + return new UnsupportedOperationException("'" + aspect + "' is not meaningful or relevant for an imported field."); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java b/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java index 96a9448739a..9368d6aaa39 100644 --- a/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java +++ b/config-model/src/main/java/com/yahoo/vespa/configmodel/producers/DocumentManager.java @@ -20,7 +20,7 @@ import java.util.Set; */ public class DocumentManager { - public DocumentmanagerConfig.Builder produce(DocumentModel model, + public DocumentmanagerConfig.Builder produce(DocumentModel model, DocumentmanagerConfig.Builder documentConfigBuilder) { documentConfigBuilder.enablecompression(false); Set<DataType> handled = new HashSet<>(); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java index 6eeb12ffdd9..ce3c04f41f7 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java @@ -45,7 +45,7 @@ public class ConstantTensorJsonValidator { throw new IllegalArgumentException("Ranking constant file names must end with either '.json' or '.json.lz4'"); } } - + private void validateTensor(TensorType type, Reader tensorData) { wrapIOException(() -> { this.parser = jsonFactory.createParser(tensorData); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java index 4a9310799aa..c686f023d5b 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java @@ -64,7 +64,7 @@ public class RankingConstantsValidator extends Validator { private void validateRankingConstant(RankingConstant rankingConstant, ApplicationPackage applicationPackage) throws FileNotFoundException { ApplicationFile tensorApplicationFile = applicationPackage.getFile(Path.fromString(rankingConstant.getFileName())); - new ConstantTensorJsonValidator().validate(rankingConstant.getFileName(), + new ConstantTensorJsonValidator().validate(rankingConstant.getFileName(), rankingConstant.getTensorType(), tensorApplicationFile.createReader()); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java index 4383e55e45d..28a54771c21 100755 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java @@ -220,6 +220,7 @@ public final class ContainerCluster addSimpleComponent("com.yahoo.container.jdisc.metric.MetricConsumerProviderProvider"); addSimpleComponent("com.yahoo.container.jdisc.metric.MetricProvider"); addSimpleComponent("com.yahoo.container.jdisc.metric.MetricUpdater"); + addSimpleComponent(com.yahoo.container.jdisc.LoggingRequestHandler.Context.class); addSimpleComponent(com.yahoo.metrics.simple.MetricManager.class.getName(), null, MetricProperties.BUNDLE_SYMBOLIC_NAME); addSimpleComponent(com.yahoo.metrics.simple.jdisc.JdiscMetricsFactory.class.getName(), null, MetricProperties.BUNDLE_SYMBOLIC_NAME); addSimpleComponent("com.yahoo.container.jdisc.state.StateMonitor"); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java index a5b7d67e377..e1675007bbc 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java @@ -8,6 +8,7 @@ import com.yahoo.vespa.model.content.Redundancy; * Builds redundancy config for a content cluster. */ public class RedundancyBuilder { + Redundancy build(ModelElement clusterXml) { Integer initialRedundancy = 2; Integer finalRedundancy = 3; @@ -37,4 +38,5 @@ public class RedundancyBuilder { return new Redundancy(initialRedundancy, finalRedundancy, readyCopies); } + } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java index 9407c21fee8..960a3b7d6db 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java @@ -173,7 +173,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { assertFalse(findProperty(rawProfile.configProperties(), "vespa.type.query.tensor3").isPresent()); assertFalse(findProperty(rawProfile.configProperties(), "vespa.type.query.numeric").isPresent()); } - + private static Optional<String> findProperty(List<Pair<String, String>> properties, String key) { for (Pair<String, String> property : properties) if (property.getFirst().equals(key)) diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java index 7cd00e155bb..4600f6ae4c6 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java @@ -123,7 +123,7 @@ public class ExportingTestCase extends AbstractExportingTestCase { public void testIndexinfoFieldsets() throws IOException, ParseException { assertCorrectDeriving("indexinfo_fieldsets"); } - + @Test public void testStreamingJuniper() throws IOException, ParseException { assertCorrectDeriving("streamingjuniper"); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index 12bdd8d2b5c..e5693d24f0f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -202,5 +202,5 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { } return b.toString(); } - + } diff --git a/configd/src/apps/sentinel/service.cpp b/configd/src/apps/sentinel/service.cpp index 3c762a957ec..5633c356bc7 100644 --- a/configd/src/apps/sentinel/service.cpp +++ b/configd/src/apps/sentinel/service.cpp @@ -113,8 +113,7 @@ Service::terminate(bool catchable, bool dumpState) ret == 0 ? "OK" : strerror(errno)); return ret; } else { - setState(KILLING); - if (dumpState) { + if (dumpState && _state != KILLING) { vespalib::string pstackCmd = make_string("pstack %d > %s/%s.pstack.%d", _pid, getVespaTempDir().c_str(), name().c_str(), _pid); LOG(info, "%s:%d failed to stop. Stack dumped at %s", name().c_str(), _pid, pstackCmd.c_str()); @@ -123,6 +122,7 @@ Service::terminate(bool catchable, bool dumpState) LOG(warning, "'%s' failed with return value %d", pstackCmd.c_str(), pstackRet); } } + setState(KILLING); kill(_pid, SIGCONT); // if it was stopped for some reason int ret = kill(_pid, SIGKILL); LOG(debug, "%s: kill -SIGKILL %d: %s", name().c_str(), (int)_pid, diff --git a/configdefinitions/src/vespa/configserver.def b/configdefinitions/src/vespa/configserver.def index cbc2317da2d..9072a20c006 100644 --- a/configdefinitions/src/vespa/configserver.def +++ b/configdefinitions/src/vespa/configserver.def @@ -27,6 +27,7 @@ payloadCompressionType enum { UNCOMPRESSED, LZ4 } default=LZ4 serverId string default="localhost" hostedVespa bool default=false numParallelTenantLoaders int default=4 +zookeeperLocalhostAffinity bool default=false # Zone information environment string default="prod" diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceChecker.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceChecker.java index 39cd4629ff0..925f8324b30 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceChecker.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceChecker.java @@ -35,7 +35,6 @@ public class ApplicationConvergenceChecker extends AbstractComponent { private final static Set<String> serviceTypesToCheck = new HashSet<>(Arrays.asList( "container", - "container-clustercontroller", "qrserver", "docprocservice", "searchnode", diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/CombinedLegacyDistribution.java b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/CombinedLegacyDistribution.java index 1046ed93491..819f9a9d5d6 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/CombinedLegacyDistribution.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/CombinedLegacyDistribution.java @@ -21,11 +21,12 @@ import java.util.logging.Logger; public class CombinedLegacyDistribution implements FileDistribution { private final static Logger log = Logger.getLogger(CombinedLegacyDistribution.class.getName()); - private final Supervisor supervisor = new Supervisor(new Transport()); + private final Supervisor supervisor; private final FileDistribution legacy; private final boolean disableFileDistributor; - CombinedLegacyDistribution(FileDBHandler legacy, boolean disableFileDistributor) { + CombinedLegacyDistribution(Supervisor supervisor, FileDBHandler legacy, boolean disableFileDistributor) { + this.supervisor = supervisor; this.legacy = legacy; this.disableFileDistributor = disableFileDistributor; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionProvider.java b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionProvider.java index 38fa3087f88..cd3f0f7f167 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionProvider.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionProvider.java @@ -4,6 +4,7 @@ package com.yahoo.vespa.config.server.filedistribution; import com.yahoo.config.FileReference; import com.yahoo.config.model.api.FileDistribution; import com.yahoo.config.application.api.FileRegistry; +import com.yahoo.jrt.Supervisor; import com.yahoo.vespa.filedistribution.FileDistributionManager; import java.io.File; @@ -35,16 +36,17 @@ public class FileDistributionProvider { } } - public FileDistributionProvider(File applicationDir, String zooKeepersSpec, + public FileDistributionProvider(Supervisor supervisor, File applicationDir, String zooKeepersSpec, String applicationId, Lock fileDistributionLock, boolean disableFileDistributor) { ensureDirExists(FileDistribution.getDefaultFileDBPath()); final FileDistributionManager manager = new FileDistributionManager( FileDistribution.getDefaultFileDBPath(), applicationDir, zooKeepersSpec, applicationId, fileDistributionLock); - this.fileDistribution = new CombinedLegacyDistribution(new FileDBHandler(manager), disableFileDistributor); + this.fileDistribution = new CombinedLegacyDistribution(supervisor, new FileDBHandler(manager), disableFileDistributor); this.fileRegistry = new CombinedLegacyRegistry(new FileDBRegistry(new ManagerWrapper(manager)), new FileDBRegistry(new ApplicationFileManager(applicationDir, new FileDirectory()))); + } // For testing only diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandler.java index 3ec4d1b6e46..94707635950 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandler.java @@ -23,14 +23,14 @@ import java.util.concurrent.Executor; public class HttpGetConfigHandler extends HttpHandler { private final RequestHandler requestHandler; - public HttpGetConfigHandler(Executor executor, RequestHandler requestHandler, AccessLog accessLog) { - super(executor, accessLog); + public HttpGetConfigHandler(HttpHandler.Context ctx, RequestHandler requestHandler) { + super(ctx); this.requestHandler = requestHandler; } @Inject - public HttpGetConfigHandler(Executor executor, Tenants tenants, AccessLog accesslog) { - this(executor, tenants.defaultTenant().getRequestHandler(), accesslog); + public HttpGetConfigHandler(HttpHandler.Context ctx, Tenants tenants) { + this(ctx, tenants.defaultTenant().getRequestHandler()); } @Override diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpHandler.java index cc78c2715e2..e8db448b245 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpHandler.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.config.server.http; +import com.google.inject.Inject; + import com.yahoo.config.provision.ApplicationLockException; import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.container.jdisc.HttpResponse; @@ -25,8 +27,8 @@ import java.util.concurrent.Executor; */ public class HttpHandler extends LoggingRequestHandler { - public HttpHandler(Executor executor, AccessLog accessLog) { - super(executor, accessLog); + public HttpHandler(HttpHandler.Context ctx) { + super(ctx); } @Override diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandler.java index 5ea0b38c110..64361c0771c 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandler.java @@ -32,12 +32,12 @@ public class HttpListConfigsHandler extends HttpHandler { private final RequestHandler requestHandler; @Inject - public HttpListConfigsHandler(Executor executor, AccessLog accessLog, Tenants tenants) { - this(executor, accessLog, tenants.defaultTenant().getRequestHandler()); + public HttpListConfigsHandler(HttpHandler.Context ctx, Tenants tenants) { + this(ctx, tenants.defaultTenant().getRequestHandler()); } - public HttpListConfigsHandler(Executor executor, AccessLog accessLog, RequestHandler requestHandler) { - super(executor, accessLog); + public HttpListConfigsHandler(HttpHandler.Context ctx, RequestHandler requestHandler) { + super(ctx); this.requestHandler = requestHandler; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListNamedConfigsHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListNamedConfigsHandler.java index 7c51fd131ff..81163d79341 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListNamedConfigsHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/HttpListNamedConfigsHandler.java @@ -25,14 +25,16 @@ import java.util.concurrent.Executor; public class HttpListNamedConfigsHandler extends HttpHandler { private final RequestHandler requestHandler; - public HttpListNamedConfigsHandler(Executor executor, RequestHandler requestHandler, AccessLog accessLog) { - super(executor, accessLog); + public HttpListNamedConfigsHandler(HttpHandler.Context ctx, + RequestHandler requestHandler) { + super(ctx); this.requestHandler = requestHandler; } @Inject - public HttpListNamedConfigsHandler(Executor executor, Tenants tenants, AccessLog accessLog) { - this(executor, tenants.defaultTenant().getRequestHandler(), accessLog); + public HttpListNamedConfigsHandler(HttpHandler.Context ctx, + Tenants tenants) { + this(ctx, tenants.defaultTenant().getRequestHandler()); } @Override diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/SessionHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/SessionHandler.java index 40ffc8e9da3..5acb6e81a83 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/SessionHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/SessionHandler.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.config.server.http; +import com.google.inject.Inject; + import com.yahoo.config.provision.ApplicationId; import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.container.logging.AccessLog; @@ -27,8 +29,9 @@ public class SessionHandler extends HttpHandler { protected final ApplicationRepository applicationRepository; - public SessionHandler(Executor executor, AccessLog accessLog, ApplicationRepository applicationRepository) { - super(executor, accessLog); + public SessionHandler(HttpHandler.Context ctx, ApplicationRepository applicationRepository) + { + super(ctx); this.applicationRepository = applicationRepository; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java index ef122147d79..819f1a35cf3 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.config.server.http.v2; +import com.google.inject.Inject; + import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ApplicationName; @@ -37,11 +39,11 @@ public class ApplicationHandler extends HttpHandler { private final Zone zone; private final ApplicationRepository applicationRepository; - public ApplicationHandler(Executor executor, - AccessLog accessLog, + @Inject + public ApplicationHandler(HttpHandler.Context ctx, Zone zone, ApplicationRepository applicationRepository) { - super(executor, accessLog); + super(ctx); this.zone = zone; this.applicationRepository = applicationRepository; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HostHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HostHandler.java index 2acaa67baef..13933544ad1 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HostHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HostHandler.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.config.server.http.v2; +import com.google.inject.Inject; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.TenantName; import com.yahoo.config.provision.Zone; @@ -28,8 +29,10 @@ public class HostHandler extends HttpHandler { final HostRegistries hostRegistries; private final Zone zone; - public HostHandler(Executor executor, AccessLog accessLog, GlobalComponentRegistry globalComponentRegistry) { - super(executor, accessLog); + @Inject + public HostHandler(HttpHandler.Context ctx, + GlobalComponentRegistry globalComponentRegistry) { + super(ctx); this.hostRegistries = globalComponentRegistry.getHostRegistries(); this.zone = globalComponentRegistry.getZone(); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandler.java index 1b566fbb9c5..0ca720c9710 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandler.java @@ -27,8 +27,10 @@ public class HttpGetConfigHandler extends HttpHandler { private final Tenants tenants; @Inject - public HttpGetConfigHandler(Executor executor, AccessLog accesslog, Tenants tenants) { - super(executor, accesslog); + public HttpGetConfigHandler(HttpHandler.Context ctx, + Tenants tenants) + { + super(ctx); this.tenants = tenants; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandler.java index ea3a1a2c9f4..2a9e2b1ecf4 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandler.java @@ -34,8 +34,10 @@ public class HttpListConfigsHandler extends HttpHandler { private final Zone zone; @Inject - public HttpListConfigsHandler(Executor executor, AccessLog accesslog, Tenants tenants, Zone zone) { - super(executor, accesslog); + public HttpListConfigsHandler(HttpHandler.Context ctx, + Tenants tenants, Zone zone) + { + super(ctx); this.tenants = tenants; this.zone = zone; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListNamedConfigsHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListNamedConfigsHandler.java index 2262b8bc722..0a55d3585e0 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListNamedConfigsHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/HttpListNamedConfigsHandler.java @@ -29,8 +29,10 @@ public class HttpListNamedConfigsHandler extends HttpHandler { private final Zone zone; @Inject - public HttpListNamedConfigsHandler(Executor executor, AccessLog accesslog, Tenants tenants, Zone zone) { - super(executor, accesslog); + public HttpListNamedConfigsHandler(HttpHandler.Context ctx, + Tenants tenants, Zone zone) + { + super(ctx); this.tenants = tenants; this.zone = zone; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandler.java index 79f551c270b..42872881088 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandler.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.config.server.http.v2; +import com.google.inject.Inject; import com.google.common.base.Function; import com.google.common.collect.Collections2; import com.yahoo.config.provision.TenantName; @@ -29,8 +30,11 @@ import java.util.concurrent.Executor; public class ListApplicationsHandler extends HttpHandler { private final Tenants tenants; private final Zone zone; - public ListApplicationsHandler(Executor executor, AccessLog accessLog, Tenants tenants, Zone zone) { - super(executor, accessLog); + + @Inject + public ListApplicationsHandler(HttpHandler.Context ctx, + Tenants tenants, Zone zone) { + super(ctx); this.tenants = tenants; this.zone = zone; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandler.java index f1c75ff0a01..b2330ebd97f 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandler.java @@ -33,12 +33,11 @@ public class SessionActiveHandler extends SessionHandler { private final Zone zone; @Inject - public SessionActiveHandler(Executor executor, - AccessLog accessLog, + public SessionActiveHandler(SessionHandler.Context ctx, + ApplicationRepository applicationRepository, Tenants tenants, - Zone zone, - ApplicationRepository applicationRepository) { - super(executor, accessLog, applicationRepository); + Zone zone) { + super(ctx, applicationRepository); this.tenants = tenants; this.zone = zone; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandler.java index c9d5407e0e3..524eb01e625 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandler.java @@ -28,11 +28,11 @@ public class SessionContentHandler extends SessionHandler { private final ContentHandler contentHandler = new ContentHandler(); @Inject - public SessionContentHandler(Executor executor, - AccessLog accessLog, - Tenants tenants, - ApplicationRepository applicationRepository) { - super(executor, accessLog, applicationRepository); + public SessionContentHandler(SessionHandler.Context ctx, + ApplicationRepository applicationRepository, + Tenants tenants) + { + super(ctx, applicationRepository); this.tenants = tenants; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandler.java index 5908851e399..b0c251f477c 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandler.java @@ -49,12 +49,11 @@ public class SessionCreateHandler extends SessionHandler { private final Duration zookeeperBarrierTimeout; @Inject - public SessionCreateHandler(Executor executor, - AccessLog accessLog, + public SessionCreateHandler(SessionHandler.Context ctx, + ApplicationRepository applicationRepository, Tenants tenants, - ConfigserverConfig configserverConfig, - ApplicationRepository applicationRepository) { - super(executor, accessLog, applicationRepository); + ConfigserverConfig configserverConfig) { + super(ctx, applicationRepository); this.tenants = tenants; this.zookeeperBarrierTimeout = Duration.ofSeconds(configserverConfig.zookeeper().barrierTimeout()); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandler.java index 03a3f3556e4..2b432a50ee1 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandler.java @@ -41,12 +41,11 @@ public class SessionPrepareHandler extends SessionHandler { private final Duration zookeeperBarrierTimeout; @Inject - public SessionPrepareHandler(Executor executor, - AccessLog accessLog, + public SessionPrepareHandler(SessionHandler.Context ctx, + ApplicationRepository applicationRepository, Tenants tenants, - ConfigserverConfig configserverConfig, - ApplicationRepository applicationRepository) { - super(executor, accessLog, applicationRepository); + ConfigserverConfig configserverConfig) { + super(ctx, applicationRepository); this.tenants = tenants; this.zookeeperBarrierTimeout = Duration.ofSeconds(configserverConfig.zookeeper().barrierTimeout()); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/TenantHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/TenantHandler.java index 5c1d8a36f6a..955bba5f5b4 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/TenantHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/TenantHandler.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.config.server.http.v2; import java.util.List; import java.util.concurrent.Executor; +import com.google.inject.Inject; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.TenantName; @@ -29,8 +30,10 @@ public class TenantHandler extends HttpHandler { private static final String TENANT_NAME_REGEXP = "[\\w-]+"; private final Tenants tenants; - public TenantHandler(Executor executor, AccessLog accessLog, Tenants tenants) { - super(executor, accessLog); + @Inject + public TenantHandler(HttpHandler.Context ctx, + Tenants tenants) { + super(ctx); this.tenants = tenants; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java index 99a34a45a2f..243c47ba3d7 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java @@ -3,6 +3,8 @@ package com.yahoo.vespa.config.server.session; import com.google.inject.Inject; import com.yahoo.config.provision.ApplicationId; +import com.yahoo.jrt.Supervisor; +import com.yahoo.jrt.Transport; import com.yahoo.vespa.config.server.filedistribution.FileDistributionLock; import com.yahoo.vespa.config.server.filedistribution.FileDistributionProvider; import com.yahoo.vespa.curator.Curator; @@ -21,6 +23,7 @@ public class FileDistributionFactory { private static final String lockPath = "/vespa/filedistribution/lock"; private final String zkSpec; private final Lock lock; + private final Supervisor supervisor = new Supervisor(new Transport()); @Inject public FileDistributionFactory(Curator curator) { @@ -33,7 +36,12 @@ public class FileDistributionFactory { } public FileDistributionProvider createProvider(File applicationPackage, ApplicationId applicationId, boolean disableFileDistributor) { - return new FileDistributionProvider(applicationPackage, zkSpec, applicationId.serializedForm(), lock, disableFileDistributor); + return new FileDistributionProvider(supervisor, applicationPackage, zkSpec, applicationId.serializedForm(), lock, disableFileDistributor); } + @Override + protected void finalize() throws Throwable { + super.finalize(); + supervisor.transport().shutdown().join(); + } } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandlerTest.java index 71f4e4add50..b19d6e2e257 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpGetConfigHandlerTest.java @@ -44,13 +44,9 @@ public class HttpGetConfigHandlerTest { mockRequestHandler.setAllConfigs(new HashSet<ConfigKey<?>>() {{ add(new ConfigKey<>("bar", "myid", "foo")); }} ); - handler = new HttpGetConfigHandler(new Executor() { - @SuppressWarnings("NullableProblems") - @Override - public void execute(Runnable command) { - command.run(); - } - }, mockRequestHandler, AccessLog.voidAccessLog()); + handler = new HttpGetConfigHandler( + HttpGetConfigHandler.testOnlyContext(), + mockRequestHandler); } @Test diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpHandlerTest.java index 76844bb7c21..bf881e7a546 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpHandlerTest.java @@ -25,7 +25,7 @@ public class HttpHandlerTest { @Test public void testResponse() throws IOException { final String message = "failed"; - HttpHandler httpHandler = new HttpTestHandler(Executors.newSingleThreadExecutor(), AccessLog.voidAccessLog(), new InvalidApplicationException(message)); + HttpHandler httpHandler = new HttpTestHandler(new InvalidApplicationException(message)); HttpResponse response = httpHandler.handle(HttpRequest.createTestRequest("foo", com.yahoo.jdisc.http.HttpRequest.Method.GET)); assertThat(response.getStatus(), is(Response.Status.BAD_REQUEST)); ByteArrayOutputStream baos = new ByteArrayOutputStream(); @@ -38,8 +38,8 @@ public class HttpHandlerTest { private static class HttpTestHandler extends HttpHandler { private RuntimeException exception; - public HttpTestHandler(Executor executor, AccessLog accessLog, RuntimeException exception) { - super(executor, accessLog); + public HttpTestHandler(RuntimeException exception) { + super(HttpHandler.testOnlyContext()); this.exception = exception; } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandlerTest.java index db8526150bf..01618e5a85f 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/HttpListConfigsHandlerTest.java @@ -37,18 +37,9 @@ public class HttpListConfigsHandlerTest { mockRequestHandler.setAllConfigs(new HashSet<ConfigKey<?>>() {{ add(new ConfigKey<>("bar", "conf/id/", "foo")); }} ); - handler = new HttpListConfigsHandler(new Executor() { - @Override - public void execute(Runnable command) { - command.run(); - } - }, AccessLog.voidAccessLog(), mockRequestHandler); - namedHandler = new HttpListNamedConfigsHandler(new Executor() { - @Override - public void execute(Runnable command) { - command.run(); - } - }, mockRequestHandler, AccessLog.voidAccessLog()); + HttpListConfigsHandler.Context ctx = HttpListConfigsHandler.testOnlyContext(); + handler = new HttpListConfigsHandler(ctx, mockRequestHandler); + namedHandler = new HttpListNamedConfigsHandler(ctx, mockRequestHandler); } @Test diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/SessionExampleHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/SessionExampleHandlerTest.java index 7aff8f9410b..b6d9ab5d618 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/SessionExampleHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/SessionExampleHandlerTest.java @@ -54,7 +54,7 @@ public class SessionExampleHandlerTest { public static class SessionExampleHandler extends ThreadedHttpRequestHandler { public SessionExampleHandler(Executor executor) { - super(executor); + super(executor, null); } @Override diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java index a17d485a425..c34dbe76a43 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java @@ -52,8 +52,7 @@ public class ApplicationContentHandlerTest extends ContentHandlerTestBase { testTenantBuilder.tenants().get(tenant2).getLocalSessionRepo().addSession(new MockSession(3l, FilesApplicationPackage.fromFile(new File("src/test/apps/content2")))); testTenantBuilder.tenants().get(tenant1).getApplicationRepo().createPutApplicationTransaction(idTenant1, 2l).commit(); testTenantBuilder.tenants().get(tenant2).getApplicationRepo().createPutApplicationTransaction(idTenant2, 3l).commit(); - handler = new ApplicationHandler(Runnable::run, - AccessLog.voidAccessLog(), + handler = new ApplicationHandler(ApplicationHandler.testOnlyContext(), Zone.defaultZone(), new ApplicationRepository(testTenantBuilder.createTenants(), new MockProvisioner(), diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java index 5552758a0a6..8ac64e5b28a 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java @@ -96,7 +96,8 @@ public class ApplicationHandlerTest { mockHttpProxy, new MockLogServerLogGrabber()); listApplicationsHandler = new ListApplicationsHandler( - Runnable::run, AccessLog.voidAccessLog(), tenants, Zone.defaultZone()); + ListApplicationsHandler.testOnlyContext(), + tenants, Zone.defaultZone()); } private ApplicationHandler createMockApplicationHandler( @@ -105,8 +106,7 @@ public class ApplicationHandlerTest { HttpProxy httpProxy, LogServerLogGrabber logServerLogGrabber) { return new ApplicationHandler( - Runnable::run, - AccessLog.voidAccessLog(), + ApplicationHandler.testOnlyContext(), Zone.defaultZone(), new ApplicationRepository(tenants, HostProvisionerProvider.withProvisioner(provisioner), @@ -118,8 +118,7 @@ public class ApplicationHandlerTest { private ApplicationHandler createApplicationHandler(Tenants tenants) { return new ApplicationHandler( - Runnable::run, - AccessLog.voidAccessLog(), + ApplicationHandler.testOnlyContext(), Zone.defaultZone(), new ApplicationRepository(tenants, HostProvisionerProvider.withProvisioner(provisioner), diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HostHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HostHandlerTest.java index e439f424c45..fe25170d8ba 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HostHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HostHandlerTest.java @@ -52,9 +52,9 @@ public class HostHandlerTest { hostRegistries = testComponentRegistry.getHostRegistries(); hostRegistries.createApplicationHostRegistry(mytenant).update(ApplicationId.from(mytenant, ApplicationName.defaultName(), InstanceName.defaultName()), Collections.singletonList(hostname)); hostRegistries.getTenantHostRegistry().update(mytenant, Collections.singletonList(hostname)); - hostHandler = new HostHandler(command -> { - command.run(); - }, AccessLog.voidAccessLog(), testComponentRegistry); + hostHandler = new HostHandler( + HostHandler.testOnlyContext(), + testComponentRegistry); return hostHandler; } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandlerTest.java index cc18e279002..11bacc30b27 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpGetConfigHandlerTest.java @@ -49,9 +49,9 @@ public class HttpGetConfigHandlerTest { TestTenantBuilder tb = new TestTenantBuilder(); tb.createTenant(tenant).withRequestHandler(mockRequestHandler).build(); Tenants tenants = tb.createTenants(); - handler = new HttpGetConfigHandler(command -> { - command.run(); - }, AccessLog.voidAccessLog(), tenants); + handler = new HttpGetConfigHandler( + HttpGetConfigHandler.testOnlyContext(), + tenants); } @Test diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandlerTest.java index a66e9542a5f..e7ccd9f957e 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/HttpListConfigsHandlerTest.java @@ -45,12 +45,12 @@ public class HttpListConfigsHandlerTest { TestTenantBuilder tb = new TestTenantBuilder(); tb.createTenant(TenantName.from("mytenant")).withRequestHandler(mockRequestHandler).build(); Tenants tenants = tb.createTenants(); - handler = new HttpListConfigsHandler(command -> { - command.run(); - }, AccessLog.voidAccessLog(), tenants, Zone.defaultZone()); - namedHandler = new HttpListNamedConfigsHandler(command -> { - command.run(); - }, AccessLog.voidAccessLog(), tenants, Zone.defaultZone()); + handler = new HttpListConfigsHandler( + HttpListConfigsHandler.testOnlyContext(), + tenants, Zone.defaultZone()); + namedHandler = new HttpListNamedConfigsHandler( + HttpListConfigsHandler.testOnlyContext(), + tenants, Zone.defaultZone()); } @Test diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java index 9e7853a8fdf..3233d9598d1 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java @@ -39,10 +39,10 @@ public class ListApplicationsHandlerTest { applicationRepo = testBuilder.tenants().get(mytenant).getApplicationRepo(); applicationRepo2 = testBuilder.tenants().get(foobar).getApplicationRepo(); Tenants tenants = testBuilder.createTenants(); - handler = new ListApplicationsHandler(Runnable::run, - AccessLog.voidAccessLog(), - tenants, - new Zone(Environment.dev, RegionName.from("us-east"))); + handler = new ListApplicationsHandler( + ListApplicationsHandler.testOnlyContext(), + tenants, + new Zone(Environment.dev, RegionName.from("us-east"))); } @Test diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java index 6542c865d56..04bc8d7b49a 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java @@ -373,13 +373,12 @@ public class SessionActiveHandlerTest extends SessionHandlerTest { .withApplicationRepo(applicationRepo) .build(); return new SessionActiveHandler( - Runnable::run, - AccessLog.voidAccessLog(), - testTenantBuilder.createTenants(), - Zone.defaultZone(), + SessionActiveHandler.testOnlyContext(), new ApplicationRepository(testTenantBuilder.createTenants(), hostProvisioner, - Clock.systemUTC())); + Clock.systemUTC()), + testTenantBuilder.createTenants(), + Zone.defaultZone()); } } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java index 1d831032416..e4841930cc8 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java @@ -161,15 +161,11 @@ public class SessionContentHandlerTest extends ContentHandlerTestBase { private SessionContentHandler createHandler() throws Exception { TestTenantBuilder testTenantBuilder = new TestTenantBuilder(); testTenantBuilder.createTenant(tenant).getLocalSessionRepo().addSession(new MockSession(1l, FilesApplicationPackage.fromFile(createTestApp()))); - return new SessionContentHandler(new Executor() { - @SuppressWarnings("NullableProblems") - @Override - public void execute(Runnable command) { - command.run(); - } - }, AccessLog.voidAccessLog(), testTenantBuilder.createTenants(), - new ApplicationRepository(testTenantBuilder.createTenants(), - new SessionHandlerTest.MockProvisioner(), - Clock.systemUTC())); + return new SessionContentHandler( + SessionContentHandler.testOnlyContext(), + new ApplicationRepository(testTenantBuilder.createTenants(), + new SessionHandlerTest.MockProvisioner(), + Clock.systemUTC()), + testTenantBuilder.createTenants()); } } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java index 65b12490b17..fc9264a6ef5 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java @@ -243,10 +243,13 @@ public class SessionCreateHandlerTest extends SessionHandlerTest { private SessionCreateHandler createHandler(Tenants tenants) throws Exception { TestTenantBuilder testTenantBuilder = new TestTenantBuilder(); final ConfigserverConfig configserverConfig = new ConfigserverConfig(new ConfigserverConfig.Builder()); - return new SessionCreateHandler(Runnable::run, AccessLog.voidAccessLog(), tenants, configserverConfig, - new ApplicationRepository(testTenantBuilder.createTenants(), - new SessionHandlerTest.MockProvisioner(), - Clock.systemUTC())); + return new SessionCreateHandler( + SessionCreateHandler.testOnlyContext(), + new ApplicationRepository(testTenantBuilder.createTenants(), + new SessionHandlerTest.MockProvisioner(), + Clock.systemUTC()), + tenants, configserverConfig); + } private HttpRequest post() throws FileNotFoundException { diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java index 74a2dcf8054..1759cd68062 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java @@ -383,10 +383,13 @@ public class SessionPrepareHandlerTest extends SessionHandlerTest { private SessionHandler createHandler(TestTenantBuilder builder) { final ConfigserverConfig configserverConfig = new ConfigserverConfig(new ConfigserverConfig.Builder()); - return new SessionPrepareHandler(Runnable::run, AccessLog.voidAccessLog(), builder.createTenants(), configserverConfig, - new ApplicationRepository(builder.createTenants(), - new MockProvisioner(), - Clock.systemUTC())); + return new SessionPrepareHandler( + SessionPrepareHandler.testOnlyContext(), + new ApplicationRepository(builder.createTenants(), + new MockProvisioner(), + Clock.systemUTC()), + builder.createTenants(), configserverConfig); + } private TestTenantBuilder addTenant(TenantName tenantName, diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/TenantHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/TenantHandlerTest.java index ce4b25fe529..e948bf68970 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/TenantHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/TenantHandlerTest.java @@ -27,7 +27,9 @@ public class TenantHandlerTest extends TenantTest { @Before public void setup() throws Exception { - handler = new TenantHandler(testExecutor(), null, tenants); + handler = new TenantHandler( + TenantHandler.testOnlyContext(), + tenants); } @Test diff --git a/container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java b/container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java index 85026296363..1bbc08aa0a7 100644 --- a/container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java +++ b/container-accesslogging/src/main/java/com/yahoo/container/logging/JSONFormatter.java @@ -12,7 +12,6 @@ import java.math.BigDecimal; import java.math.RoundingMode; import java.net.URI; import java.security.Principal; -import java.time.Instant; import java.util.List; import java.util.Map; import java.util.Objects; @@ -49,8 +48,6 @@ public class JSONFormatter { generator.writeStartObject(); generator.writeStringField("ip", accessLogEntry.getIpV4Address()); generator.writeNumberField("time", toTimestampInSeconds(accessLogEntry.getTimeStampMillis())); - generator.writeStringField("time-iso8601", - Instant.ofEpochMilli(accessLogEntry.getTimeStampMillis()).toString()); generator.writeNumberField("duration", durationAsSeconds(accessLogEntry.getDurationBetweenRequestResponseMillis())); generator.writeNumberField("responsesize", accessLogEntry.getReturnedContentSize()); diff --git a/container-accesslogging/src/test/java/com/yahoo/container/logging/JSONLogTestCase.java b/container-accesslogging/src/test/java/com/yahoo/container/logging/JSONLogTestCase.java index 7f81a3568dd..ae27d7b1814 100644 --- a/container-accesslogging/src/test/java/com/yahoo/container/logging/JSONLogTestCase.java +++ b/container-accesslogging/src/test/java/com/yahoo/container/logging/JSONLogTestCase.java @@ -40,7 +40,6 @@ public class JSONLogTestCase extends junit.framework.TestCase { String expectedOutput = "{\"ip\":\"152.200.54.243\"," + "\"time\":920880005.023," + - "\"time-iso8601\":\"1999-03-08T08:00:05.023Z\"," + "\"duration\":0.122," + "\"responsesize\":9875," + "\"code\":200," + @@ -69,7 +68,6 @@ public class JSONLogTestCase extends junit.framework.TestCase { String expectedOutput = "{\"ip\":\"152.200.54.243\"," + "\"time\":920880005.023," + - "\"time-iso8601\":\"1999-03-08T08:00:05.023Z\"," + "\"duration\":0.122," + "\"responsesize\":9875," + "\"code\":200," + @@ -102,7 +100,6 @@ public class JSONLogTestCase extends junit.framework.TestCase { String expectedOutput = "{\"ip\":\"152.200.54.243\"," + "\"time\":920880005.023," + - "\"time-iso8601\":\"1999-03-08T08:00:05.023Z\"," + "\"duration\":0.122," + "\"responsesize\":9875," + "\"code\":200," + @@ -128,7 +125,6 @@ public class JSONLogTestCase extends junit.framework.TestCase { expectedOutput = "{\"ip\":\"152.200.54.243\"," + "\"time\":920880005.023," + - "\"time-iso8601\":\"1999-03-08T08:00:05.023Z\"," + "\"duration\":0.122," + "\"responsesize\":9875," + "\"code\":200," + @@ -175,7 +171,6 @@ public class JSONLogTestCase extends junit.framework.TestCase { String expectedOutput = "{\"ip\":\"152.200.54.243\"," + "\"time\":920880005.023," + - "\"time-iso8601\":\"1999-03-08T08:00:05.023Z\"," + "\"duration\":0.122," + "\"responsesize\":9875," + "\"code\":200," + diff --git a/container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java b/container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java index 0095fcece4f..4f365ebbab3 100644 --- a/container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java +++ b/container-core/src/main/java/com/yahoo/container/jdisc/LoggingRequestHandler.java @@ -35,7 +35,42 @@ public abstract class LoggingRequestHandler extends ThreadedHttpRequestHandler { this(executor, accessLog, null); } + public static class Context { + final Executor executor; + final AccessLog accessLog; + final Metric metric; + @Inject + public Context(Executor executor, AccessLog accessLog, Metric metric) { + this.executor = executor; + this.accessLog = accessLog; + this.metric = metric; + } + public Context(Context other) { + this.executor = other.executor; + this.accessLog = other.accessLog; + this.metric = other.metric; + } + } + public static Context testOnlyContext() { + return new Context(new Executor() { + @Override + public void execute(Runnable command) { + command.run(); + } + }, + AccessLog.voidAccessLog(), + null); + } + @Inject + public LoggingRequestHandler(Context ctx) { + this(ctx.executor, ctx.accessLog, ctx.metric); + } + + public LoggingRequestHandler(Context ctx, boolean allowAsyncResponse) { + this(ctx.executor, ctx.accessLog, ctx.metric, allowAsyncResponse); + } + public LoggingRequestHandler(Executor executor, AccessLog accessLog, Metric metric) { this(executor, accessLog, metric, false); } diff --git a/container-dev/pom.xml b/container-dev/pom.xml index f62bbd22690..16006452e61 100644 --- a/container-dev/pom.xml +++ b/container-dev/pom.xml @@ -121,6 +121,18 @@ <groupId>org.bouncycastle</groupId> <artifactId>bcprov-jdk15on</artifactId> </exclusion> + <exclusion> + <groupId>org.tensorflow</groupId> + <artifactId>proto</artifactId> + </exclusion> + <exclusion> + <groupId>org.tensorflow</groupId> + <artifactId>tensorflow</artifactId> + </exclusion> + <exclusion> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + </exclusion> </exclusions> </dependency> <dependency> @@ -189,6 +201,18 @@ <groupId>xerces</groupId> <artifactId>xercesImpl</artifactId> </exclusion> + <exclusion> + <groupId>org.tensorflow</groupId> + <artifactId>proto</artifactId> + </exclusion> + <exclusion> + <groupId>org.tensorflow</groupId> + <artifactId>tensorflow</artifactId> + </exclusion> + <exclusion> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + </exclusion> </exclusions> </dependency> <dependency> diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/ConfiguredApplication.java b/container-disc/src/main/java/com/yahoo/container/jdisc/ConfiguredApplication.java index fa2ee8e89a9..bf696771b20 100644 --- a/container-disc/src/main/java/com/yahoo/container/jdisc/ConfiguredApplication.java +++ b/container-disc/src/main/java/com/yahoo/container/jdisc/ConfiguredApplication.java @@ -29,6 +29,7 @@ import com.yahoo.jdisc.handler.RequestHandler; import com.yahoo.jdisc.service.ClientProvider; import com.yahoo.jdisc.service.ServerProvider; import com.yahoo.jrt.ListenFailedException; +import com.yahoo.log.LogLevel; import com.yahoo.log.LogSetup; import com.yahoo.osgi.OsgiImpl; import com.yahoo.vespa.config.ConfigKey; @@ -88,6 +89,7 @@ public final class ConfiguredApplication implements Application { static { LogSetup.initVespaLogging("Container"); + log.log(LogLevel.INFO, "Starting container"); } /** diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java index 033b396bc9b..c4c57f4bc47 100644 --- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java +++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.container.jdisc.athenz; +import javax.net.ssl.SSLContext; + /** * @author mortent */ @@ -8,4 +10,5 @@ public interface AthenzIdentityProvider { String getNToken() throws AthenzIdentityProviderException; String getDomain(); String getService(); + SSLContext getSslContext(); } diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java index 356780a0900..3d6b32744c6 100644 --- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java +++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java @@ -8,6 +8,20 @@ import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider; import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException; import com.yahoo.log.LogLevel; +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import java.io.FileInputStream; +import java.io.IOException; +import java.security.KeyManagementException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.UnrecoverableKeyException; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; import java.time.Clock; import java.time.Duration; import java.time.Instant; @@ -106,6 +120,52 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen } @Override + public SSLContext getSslContext() { + try { + SSLContext sslContext = SSLContext.getInstance("TLSv1.2"); + sslContext.init(createKeyManagersWithServiceCertificate(), + createTrustManagersWithAthenzCa(), + null); + return sslContext; + } catch (NoSuchAlgorithmException | KeyManagementException e) { + throw new RuntimeException(e); + } + } + + private KeyManager[] createKeyManagersWithServiceCertificate() { + try { + credentialsRetrievedSignal.await(); + KeyStore keyStore = KeyStore.getInstance("JKS"); + keyStore.load(null); + keyStore.setKeyEntry("instance-key", + credentials.get().getKeyPair().getPrivate(), + new char[0], + new Certificate[]{credentials.get().getCertificate()}); + KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(keyStore, new char[0]); + return keyManagerFactory.getKeyManagers(); + } catch (KeyStoreException | NoSuchAlgorithmException | UnrecoverableKeyException | CertificateException | IOException e) { + throw new RuntimeException(e); + } catch (InterruptedException e) { + throw new AthenzIdentityProviderException("Failed to register instance credentials", lastThrowable.get()); + } + } + + private static TrustManager[] createTrustManagersWithAthenzCa() { + try { + KeyStore trustStore = KeyStore.getInstance("JKS"); + try (FileInputStream in = new FileInputStream("/home/y/share/ssl/certs/yahoo_certificate_bundle.jks")) { + trustStore.load(in, null); + } + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(trustStore); + return trustManagerFactory.getTrustManagers(); + } catch (CertificateException | IOException | KeyStoreException | NoSuchAlgorithmException e) { + throw new RuntimeException(e); + } + } + + @Override public void deconstruct() { scheduler.shutdown(AWAIT_TERMINTATION_TIMEOUT); } diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/component/Deconstructor.java b/container-disc/src/main/java/com/yahoo/container/jdisc/component/Deconstructor.java index a4eb2449064..b83dd6175e1 100644 --- a/container-disc/src/main/java/com/yahoo/container/jdisc/component/Deconstructor.java +++ b/container-disc/src/main/java/com/yahoo/container/jdisc/component/Deconstructor.java @@ -37,7 +37,6 @@ public class Deconstructor implements ComponentDeconstructor { if (component instanceof AbstractComponent) { AbstractComponent abstractComponent = (AbstractComponent) component; if (abstractComponent.isDeconstructable()) { - log.info("Scheduling deconstruction of " + abstractComponent); executor.schedule(new DestructComponentTask(abstractComponent), delay, TimeUnit.SECONDS); } } else if (component instanceof Provider) { diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java index fc1bbace092..1e44a8fa64d 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumField.java @@ -1,16 +1,16 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.prelude.fastsearch; +import com.yahoo.container.search.LegacyEmulationConfig; +import com.yahoo.data.access.Inspector; +import com.yahoo.log.LogLevel; + import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.nio.ByteBuffer; import java.util.HashMap; import java.util.Map; import java.util.logging.Logger; -import com.yahoo.data.access.Inspector; -import com.yahoo.container.search.LegacyEmulationConfig; - -import com.yahoo.log.LogLevel; /** * @author Bjørn Borud @@ -25,7 +25,7 @@ public abstract class DocsumField { Map<String, Constructor<? extends DocsumField>> constructors = new HashMap<>(); - void put(String typename, Class<? extends DocsumField> fieldClass) + void put(String typename, Class<? extends DocsumField> fieldClass) throws NoSuchMethodException, SecurityException { Constructor<? extends DocsumField> constructor = fieldClass.getConstructor(String.class); constructors.put(typename, constructor); @@ -106,7 +106,7 @@ public abstract class DocsumField { public abstract Object decode(ByteBuffer b); /** - * Get the number of bytes this field occupies in the given buffer + * Get the number of bytes this field occupies in the given buffer * AND SET(!) the position to the first byte after this field. */ public abstract int getLength(ByteBuffer b); diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java index 1524a4da426..692e93bed7e 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastHit.java @@ -109,7 +109,7 @@ public class FastHit extends Hit { /** * Returns the explicitly set uri if available, returns "index:[source]/[partid]/[id]" otherwise - * + * * @return uri of hit */ public URI getUri() { @@ -128,9 +128,9 @@ public class FastHit extends Hit { } /** - * The uri of the index location of this hit ("index:[source]/[partid]/[id]"). + * The uri of the index location of this hit ("index:[source]/[partid]/[id]"). * This is the uri if no other uri is assigned - * + * * @return uri to the index. */ public URI getIndexUri() { @@ -215,7 +215,7 @@ public class FastHit extends Hit { * The empty string ("") if no value is assigned in the document. * * <li><b>Dynamic summary string fields</b>: A Java String before JuniperSearcher and a HitField after.</li> - * + * * <li><b>Numerics</b>: The corresponding numeric Java type.<br> * If the field has <i>no value</i> assigned in the document, * the special numeric {@link com.yahoo.search.result.NanNumber#NaN} is returned. diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java index e0ca7fbe6e1..d8b38667224 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/TensorField.java @@ -13,7 +13,7 @@ import java.util.Optional; /** * A tensor field. Tensors are encoded as a data field where the data (following the length) * is encoded in a tensor binary format defined by com.yahoo.tensor.serialization.TypedBinaryFormat - * + * * @author bratseth */ public class TensorField extends DocsumField implements VariableLengthField { diff --git a/container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java b/container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java index 15a8a670a2e..8091397237d 100644 --- a/container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java @@ -25,7 +25,7 @@ import static com.yahoo.prelude.querytransform.NormalizingSearcher.ACCENT_REMOVA * Check sorting specification makes sense to the search cluster before * passing it on to the backend. * - * @author <a href="mailto:steinar@yahoo-inc.com">Steinar Knutsen</a> + * @author Steinar Knutsen */ @Before(PhaseNames.BACKEND) @After(ACCENT_REMOVAL) @@ -118,6 +118,7 @@ public class ValidateSortingSearcher extends Searcher { for (Sorting.FieldOrder f : l) { String name = f.getFieldName(); if ("[rank]".equals(name) || "[docid]".equals(name)) { + // built-in constants - ok } else if (names.containsKey(name)) { AttributesConfig.Attribute attrConfig = names.get(name); if (attrConfig != null) { @@ -166,18 +167,13 @@ public class ValidateSortingSearcher extends Searcher { locale = "en_US"; } - // getLogger().info("locale = " + locale + " attrConfig.sortlocale.value() = " + attrConfig.sortlocale.value() + " query.getLanguage() = " + query.getModel().getLanguage()); - // getLogger().info("locale = " + locale); - Sorting.UcaSorter.Strength strength = sorter.getStrength(); if (sorter.getStrength() == Sorting.UcaSorter.Strength.UNDEFINED) { strength = config2Strength(attrConfig.sortstrength()); } if ((sorter.getStrength() == Sorting.UcaSorter.Strength.UNDEFINED) || (sorter.getLocale() == null) || sorter.getLocale().isEmpty()) { - // getLogger().info("locale = " + locale + " strength = " + strength.toString()); sorter.setLocale(locale, strength); } - //getLogger().info("locale = " + locale + " strength = " + strength.toString() + "decompose = " + sorter.getDecomposition()); } } else { return ErrorMessage.createInvalidQueryParameter("The cluster " + getClusterName() + " has attribute config for field: " + name); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index 0ec15b95b0d..0fd529bf262 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -16,7 +16,7 @@ import java.util.Optional; public class TensorFieldType extends FieldType { // TODO: Require tensor type - + private final Optional<TensorType> type; /** Creates a tensor field type with optional information about the kind of tensor this will hold */ diff --git a/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java b/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java index 15a0fd60511..5494d1965f8 100644 --- a/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java +++ b/container-search/src/test/java/com/yahoo/prelude/fastsearch/SlimeSummaryTestCase.java @@ -102,7 +102,7 @@ public class SlimeSummaryTestCase { public void testDecoding() { Tensor tensor1 = Tensor.from("tensor(x{},y{}):{{x:foo,y:bar}:0.1}"); Tensor tensor2 = Tensor.from("tensor(x[],y[1]):{{x:0,y:0}:-0.3}"); - + String summary_cf = "file:src/test/java/com/yahoo/prelude/fastsearch/summary.cfg"; DocsumDefinitionSet set = createDocsumDefinitionSet(summary_cf); byte[] docsum = makeDocsum(tensor1, tensor2); diff --git a/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java b/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java index e59c03b33c3..62eacaa0afe 100644 --- a/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/test/QueryTestCase.java @@ -2,8 +2,6 @@ package com.yahoo.search.test; import com.yahoo.component.chain.Chain; -import com.yahoo.language.Language; -import com.yahoo.language.Linguistics; import com.yahoo.language.detect.Detection; import com.yahoo.language.detect.Detector; import com.yahoo.language.detect.Hint; @@ -28,7 +26,6 @@ import com.yahoo.search.query.profile.QueryProfile; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.search.result.Hit; import com.yahoo.search.searchchain.Execution; - import com.yahoo.yolean.Exceptions; import org.junit.Ignore; import org.junit.Test; @@ -45,14 +42,14 @@ import java.util.stream.Collectors; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.is; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertThat; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; /** @@ -69,7 +66,7 @@ public class QueryTestCase { assertEquals("", q.properties().get("aParameter")); assertNull(q.properties().get("notSetParameter")); } - + // TODO: YQL work in progress (jon) @Ignore @Test @@ -693,7 +690,7 @@ public class QueryTestCase { List<IndexedItem> l = QueryTree.getPositiveTerms(i); assertEquals(3, l.size()); } - + @Test public void testHeuristicLanguageDetectionTextExtraction() { assertDetectionText("b ", "a:b", "text:a", "text:default"); @@ -720,27 +717,27 @@ public class QueryTestCase { q.getModel().getQueryTree(); // cause parsing assertEquals(expectedDetectionText, mockLinguistics.detector.lastDetectionText); } - + /** A linguistics instance which records the last language detection text passed to it */ private static class MockLinguistics extends SimpleLinguistics { final MockDetector detector = new MockDetector(); - + @Override public Detector getDetector() { return detector; } - + } - + private static class MockDetector extends SimpleDetector { String lastDetectionText = null; - + @Override public Detection detect(String input, Hint hint) { lastDetectionText = input; return super.detect(input, hint); } - + } protected boolean contains(String lineSubstring,String[] lines) { diff --git a/container/pom.xml b/container/pom.xml index 3793a3508a4..d252a5eee4a 100644 --- a/container/pom.xml +++ b/container/pom.xml @@ -47,6 +47,18 @@ <groupId>org.apache.commons</groupId> <artifactId>commons-lang3</artifactId> </exclusion> + <exclusion> + <groupId>org.tensorflow</groupId> + <artifactId>proto</artifactId> + </exclusion> + <exclusion> + <groupId>org.tensorflow</groupId> + <artifactId>tensorflow</artifactId> + </exclusion> + <exclusion> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + </exclusion> </exclusions> </dependency> </dependencies> diff --git a/controller-api/pom.xml b/controller-api/pom.xml index 5ef130a22ba..543ab24999d 100644 --- a/controller-api/pom.xml +++ b/controller-api/pom.xml @@ -18,24 +18,9 @@ <dependencies> <!-- provided --> - - <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>component</artifactId> - <scope>provided</scope> - <version>${project.version}</version> - </dependency> - - <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>annotations</artifactId> - <scope>provided</scope> - <version>${project.version}</version> - </dependency> - <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>vespajlib</artifactId> + <artifactId>container-dev</artifactId> <scope>provided</scope> <version>${project.version}</version> </dependency> @@ -54,56 +39,6 @@ <version>${project.version}</version> </dependency> - <dependency> - <groupId>com.fasterxml.jackson.core</groupId> - <artifactId>jackson-annotations</artifactId> - <scope>provided</scope> - </dependency> - - <dependency> - <groupId>com.fasterxml.jackson.core</groupId> - <artifactId>jackson-databind</artifactId> - <scope>provided</scope> - </dependency> - - <dependency> - <groupId>com.fasterxml.jackson.datatype</groupId> - <artifactId>jackson-datatype-jdk8</artifactId> - <scope>provided</scope> - </dependency> - - <dependency> - <groupId>org.glassfish.jersey.media</groupId> - <artifactId>jersey-media-multipart</artifactId> - <scope>provided</scope> - </dependency> - - <dependency> - <groupId>javax.servlet</groupId> - <artifactId>javax.servlet-api</artifactId> - <scope>provided</scope> - </dependency> - - <dependency> - <groupId>javax.ws.rs</groupId> - <artifactId>javax.ws.rs-api</artifactId> - <scope>provided</scope> - </dependency> - - <dependency> - <groupId>org.glassfish.jersey.core</groupId> - <artifactId>jersey-server</artifactId> - <version>${jersey2.version}</version> - <scope>provided</scope> - </dependency> - - <dependency> - <groupId>com.google.inject</groupId> - <artifactId>guice</artifactId> - <classifier>no_aop</classifier> - <scope>provided</scope> - </dependency> - <!-- compile --> <dependency> @@ -128,6 +63,19 @@ <scope>test</scope> </dependency> + <!-- Required for AthenzIdentityVerifierTest --> + <dependency> + <groupId>org.bouncycastle</groupId> + <artifactId>bcpkix-jdk15on</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-core</artifactId> + <scope>test</scope> + </dependency> + + </dependencies> <build> diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzIdentityVerifier.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzIdentityVerifier.java new file mode 100644 index 00000000000..bfaa6c2acda --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzIdentityVerifier.java @@ -0,0 +1,41 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.api.integration.athenz; + +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import java.security.cert.X509Certificate; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * A {@link HostnameVerifier} that validates Athenz x509 certificates using the identity in the Common Name attribute. + * + * @author bjorncs + */ +// TODO Move to dedicated Athenz bundle +public class AthenzIdentityVerifier implements HostnameVerifier { + + private static final Logger log = Logger.getLogger(AthenzIdentityVerifier.class.getName()); + + private final Set<AthenzIdentity> allowedIdentities; + + public AthenzIdentityVerifier(Set<AthenzIdentity> allowedIdentities) { + this.allowedIdentities = allowedIdentities; + } + + @Override + public boolean verify(String hostname, SSLSession session) { + try { + X509Certificate cert = (X509Certificate) session.getPeerCertificates()[0]; + AthenzIdentity certificateIdentity = AthenzUtils.createAthenzIdentity(cert); + return allowedIdentities.contains(certificateIdentity); + } catch (SSLPeerUnverifiedException e) { + log.log(Level.WARNING, "Unverified client: " + hostname); + return false; + } + } + +} + diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzPrincipal.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzPrincipal.java index 8279edcd8e6..b31cb4a26bb 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzPrincipal.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzPrincipal.java @@ -5,6 +5,7 @@ import com.yahoo.vespa.hosted.controller.api.identifiers.AthenzDomain; import java.security.Principal; import java.util.Objects; +import java.util.Optional; /** * @author bjorncs @@ -14,6 +15,10 @@ public class AthenzPrincipal implements Principal { private final AthenzIdentity athenzIdentity; private final NToken nToken; + public AthenzPrincipal(AthenzIdentity athenzIdentity) { + this(athenzIdentity, null); + } + public AthenzPrincipal(AthenzIdentity athenzIdentity, NToken nToken) { this.athenzIdentity = athenzIdentity; @@ -33,8 +38,8 @@ public class AthenzPrincipal implements Principal { return athenzIdentity.getDomain(); } - public NToken getNToken() { - return nToken; + public Optional<NToken> getNToken() { + return Optional.ofNullable(nToken); } @Override diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzUtils.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzUtils.java index 0ed5d86dd7e..04ec0b61614 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzUtils.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzUtils.java @@ -4,6 +4,10 @@ package com.yahoo.vespa.hosted.controller.api.integration.athenz; import com.yahoo.vespa.hosted.controller.api.identifiers.AthenzDomain; import com.yahoo.vespa.hosted.controller.api.identifiers.UserId; +import javax.naming.NamingException; +import javax.naming.ldap.LdapName; +import java.security.cert.X509Certificate; + /** * @author bjorncs */ @@ -23,4 +27,40 @@ public class AthenzUtils { } } + public static AthenzIdentity createAthenzIdentity(String fullName) { + int domainIdentityNameSeparatorIndex = fullName.lastIndexOf('.'); + if (domainIdentityNameSeparatorIndex == -1 + || domainIdentityNameSeparatorIndex == 0 + || domainIdentityNameSeparatorIndex == fullName.length() - 1) { + throw new IllegalArgumentException("Invalid Athenz identity: " + fullName); + } + AthenzDomain domain = new AthenzDomain(fullName.substring(0, domainIdentityNameSeparatorIndex)); + String identityName = fullName.substring(domainIdentityNameSeparatorIndex + 1, fullName.length()); + return createAthenzIdentity(domain, identityName); + } + + public static AthenzIdentity createAthenzIdentity(X509Certificate certificate) { + String commonName = getCommonName(certificate); + if (isAthenzRoleIdentity(commonName)) { + throw new IllegalArgumentException("Athenz role certificate not supported"); + } + return createAthenzIdentity(commonName); + } + + private static boolean isAthenzRoleIdentity(String commonName) { + return commonName.contains(":role."); + } + + private static String getCommonName(X509Certificate certificate) { + try { + String subjectPrincipal = certificate.getSubjectX500Principal().getName(); + return new LdapName(subjectPrincipal).getRdns().stream() + .filter(rdn -> rdn.getType().equalsIgnoreCase("cn")) + .map(rdn -> rdn.getValue().toString()) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("Could not find CN in certificate: " + subjectPrincipal)); + } catch (NamingException e) { + throw new IllegalArgumentException("Invalid CN: " + e, e); + } + } } diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/InvalidTokenException.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/InvalidTokenException.java index 1df1746b02e..967af1c553f 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/InvalidTokenException.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/InvalidTokenException.java @@ -4,7 +4,7 @@ package com.yahoo.vespa.hosted.controller.api.integration.athenz; /** * @author bjorncs */ -public class InvalidTokenException extends Exception { +public class InvalidTokenException extends RuntimeException { public InvalidTokenException(String message) { super(message); } diff --git a/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzIdentityVerifierTest.java b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzIdentityVerifierTest.java new file mode 100644 index 00000000000..88da28fb273 --- /dev/null +++ b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzIdentityVerifierTest.java @@ -0,0 +1,82 @@ +package com.yahoo.vespa.hosted.controller.api.integration.athenz; + +import org.bouncycastle.asn1.x500.X500Name; +import org.bouncycastle.asn1.x509.BasicConstraints; +import org.bouncycastle.asn1.x509.Extension; +import org.bouncycastle.cert.CertIOException; +import org.bouncycastle.cert.X509v3CertificateBuilder; +import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; +import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.bouncycastle.operator.ContentSigner; +import org.bouncycastle.operator.OperatorCreationException; +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; +import org.junit.Test; + +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import java.math.BigInteger; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.time.Duration; +import java.time.Instant; +import java.util.Date; + +import static java.util.Collections.singleton; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * @author bjorncs + */ +public class AthenzIdentityVerifierTest { + + @Test + public void verifies_certificate_with_athenz_service_as_common_name() throws Exception { + AthenzIdentity trustedIdentity = new AthenzService("mydomain", "alice"); + AthenzIdentity unknownIdentity = new AthenzService("mydomain", "mallory"); + KeyPair keyPair = createKeyPair(); + AthenzIdentityVerifier verifier = new AthenzIdentityVerifier(singleton(trustedIdentity)); + assertTrue(verifier.verify("hostname", createSslSessionMock(createSelfSignedCertificate(keyPair, trustedIdentity)))); + assertFalse(verifier.verify("hostname", createSslSessionMock(createSelfSignedCertificate(keyPair, unknownIdentity)))); + } + + private static KeyPair createKeyPair() throws NoSuchAlgorithmException { + KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA"); + keyGen.initialize(512); + return keyGen.generateKeyPair(); + } + + private static X509Certificate createSelfSignedCertificate(KeyPair keyPair, AthenzIdentity identity) + throws OperatorCreationException, CertIOException, CertificateException { + ContentSigner contentSigner = new JcaContentSignerBuilder("SHA256WithRSA").build(keyPair.getPrivate()); + X500Name x500Name = new X500Name("CN="+ identity.getFullName()); + Instant now = Instant.now(); + Date notBefore = Date.from(now); + Date notAfter = Date.from(now.plus(Duration.ofDays(30))); + + X509v3CertificateBuilder certificateBuilder = + new JcaX509v3CertificateBuilder( + x500Name, BigInteger.valueOf(now.toEpochMilli()), notBefore, notAfter, x500Name, keyPair.getPublic() + ) + .addExtension(Extension.basicConstraints, true, new BasicConstraints(true)); + + return new JcaX509CertificateConverter() + .setProvider(new BouncyCastleProvider()) + .getCertificate(certificateBuilder.build(contentSigner)); + + } + + private static SSLSession createSslSessionMock(X509Certificate certificate) throws SSLPeerUnverifiedException { + SSLSession sslSession = mock(SSLSession.class); + when(sslSession.getPeerCertificates()).thenReturn(new Certificate[]{certificate}); + return sslSession; + } + +}
\ No newline at end of file diff --git a/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzUtilsTest.java b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzUtilsTest.java new file mode 100644 index 00000000000..f2db74a4c3d --- /dev/null +++ b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/athenz/AthenzUtilsTest.java @@ -0,0 +1,21 @@ +package com.yahoo.vespa.hosted.controller.api.integration.athenz; + +import com.yahoo.vespa.hosted.controller.api.identifiers.AthenzDomain; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bjorncs + */ +public class AthenzUtilsTest { + + @Test + public void athenz_identity_is_parsed_from_dot_separated_string() { + AthenzIdentity expectedIdentity = new AthenzService(new AthenzDomain("my.subdomain"), "myservicename"); + String fullName = expectedIdentity.getFullName(); + AthenzIdentity actualIdentity = AthenzUtils.createAthenzIdentity(fullName); + assertEquals(expectedIdentity, actualIdentity); + } + +}
\ No newline at end of file diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java index 1b2ad9f938a..fb675862320 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ClusterCost.java @@ -24,6 +24,7 @@ import java.util.Objects; * @author smorgrav */ public class ClusterCost { + private final double tco; private final double waste; private final ClusterInfo clusterInfo; @@ -32,8 +33,8 @@ public class ClusterCost { private final ClusterUtilization resultUtilization; /** - * @param clusterInfo Value object with cluster info e.g. the TCO for the hardware used - * @param systemUtilization Utilization of system resources (as ratios) + * @param clusterInfo value object with cluster info e.g. the TCO for the hardware used + * @param systemUtilization utilization of system resources (as ratios) */ public ClusterCost(ClusterInfo clusterInfo, ClusterUtilization systemUtilization) { @@ -79,10 +80,10 @@ public class ClusterCost { } static ClusterUtilization calculateResultUtilization(ClusterUtilization system, ClusterUtilization target) { - double cpu = ratio(system.getCpu(),target.getCpu()); - double mem = ratio(system.getMemory(),target.getMemory()); - double disk = ratio(system.getDisk(),target.getDisk()); - double diskbusy = ratio(system.getDiskBusy(),target.getDiskBusy()); + double cpu = ratio(system.getCpu(), target.getCpu()); + double mem = ratio(system.getMemory(), target.getMemory()); + double disk = ratio(system.getDisk(), target.getDisk()); + double diskbusy = ratio(system.getDiskBusy(), target.getDiskBusy()); return new ClusterUtilization(mem, cpu, disk, diskbusy); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java index 585690793bb..371e1c41e32 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentCost.java @@ -44,17 +44,17 @@ public class DeploymentCost { return clusters; } - /** @return Total cost of ownership for the deployment (sum of all clusters) */ + /** Returns the total monthly cost of ownership for the deployment (sum of all clusters) */ public double getTco() { return tco; } - /** @return The utilization of clusters that wastes most money in this deployment */ + /** Returns the utilization of clusters that wastes most money in this deployment */ public double getUtilization() { return utilization; } - /** @return The amount of dollars spent and not utilized */ + /** Returns the amount of dollars spent and not utilized */ public double getWaste() { return waste; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobStatus.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobStatus.java index ceb04d88026..a7940076277 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobStatus.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/JobStatus.java @@ -11,14 +11,14 @@ import java.util.Optional; /** * The last known build status of a particular deployment job for a particular application. * This is immutable. - * + * * @author bratseth * @author mpolden */ public class JobStatus { - + private final DeploymentJobs.JobType type; - + private final Optional<JobRun> lastTriggered; private final Optional<JobRun> lastCompleted; private final Optional<JobRun> firstFailing; @@ -42,7 +42,7 @@ public class JobStatus { this.type = type; this.jobError = jobError; - + // Never say we triggered component because we don't: this.lastTriggered = type == DeploymentJobs.JobType.component ? Optional.empty() : lastTriggered; this.lastCompleted = lastCompleted; @@ -52,7 +52,7 @@ public class JobStatus { /** Returns an empty job status */ public static JobStatus initial(DeploymentJobs.JobType type) { - return new JobStatus(type, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + return new JobStatus(type, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); } public JobStatus withTriggering(Version version, Optional<ApplicationRevision> revision, @@ -89,13 +89,13 @@ public class JobStatus { Optional<JobRun> firstFailing = this.firstFailing; if (jobError.isPresent() && ! this.firstFailing.isPresent()) firstFailing = Optional.of(thisCompletion); - + Optional<JobRun> lastSuccess = this.lastSuccess; if ( ! jobError.isPresent()) { lastSuccess = Optional.of(thisCompletion); firstFailing = Optional.empty(); } - + return new JobStatus(type, jobError, lastTriggered, Optional.of(thisCompletion), firstFailing, lastSuccess); } @@ -105,7 +105,7 @@ public class JobStatus { public boolean isSuccess() { return lastCompleted().isPresent() && ! jobError.isPresent(); } - + /** Returns true if last triggered is newer than last completed and was started after timeoutLimit */ public boolean isRunning(Instant timeoutLimit) { if ( ! lastTriggered.isPresent()) return false; @@ -114,6 +114,11 @@ public class JobStatus { return ! lastTriggered.get().at().isBefore(lastCompleted.get().at()); } + /** Returns true if this is running and has been so since before the given limit */ + public boolean isHanging(Instant timeoutLimit) { + return isRunning(Instant.MIN) && lastTriggered.get().at().isBefore(timeoutLimit.plusMillis(1)); + } + /** The error of the last completion, or empty if the last run succeeded */ public Optional<DeploymentJobs.JobError> jobError() { return jobError; } @@ -140,10 +145,10 @@ public class JobStatus { ", first failing: " + firstFailing.map(JobRun::toString).orElse("(not failing)") + ", lastSuccess: " + lastSuccess.map(JobRun::toString).orElse("(never)") + "]"; } - + @Override public int hashCode() { return Objects.hash(type, jobError, lastTriggered, lastCompleted, firstFailing, lastSuccess); } - + @Override public boolean equals(Object o) { if (o == this) return true; @@ -159,15 +164,15 @@ public class JobStatus { /** Information about a particular triggering or completion of a run of a job. This is immutable. */ public static class JobRun { - + private final long id; private final Version version; private final Optional<ApplicationRevision> revision; private final boolean upgrade; private final String reason; private final Instant at; - - public JobRun(long id, Version version, Optional<ApplicationRevision> revision, + + public JobRun(long id, Version version, Optional<ApplicationRevision> revision, boolean upgrade, String reason, Instant at) { Objects.requireNonNull(version, "version cannot be null"); Objects.requireNonNull(revision, "revision cannot be null"); @@ -188,16 +193,16 @@ public class JobStatus { // TODO: Fix how this is set, and add an applicationChange() method as well, in the same vein. /** Returns whether this job run was a Vespa upgrade */ public boolean upgrade() { return upgrade; } - + /** Returns the Vespa version used on this run */ public Version version() { return version; } - + /** Returns the application revision used for this run, or empty when not known */ public Optional<ApplicationRevision> revision() { return revision; } - + /** Returns a human-readable reason for this particular job run */ public String reason() { return reason; } - + /** Returns the time if this triggering or completion */ public Instant at() { return at; } @@ -218,7 +223,7 @@ public class JobStatus { public int hashCode() { return Objects.hash(version, revision, upgrade, at); } - + @Override public boolean equals(Object o) { if (this == o) return true; @@ -234,7 +239,7 @@ public class JobStatus { @Override public String toString() { return "job run " + id + " of version " + (upgrade() ? "upgrade " : "") + version + " " + revision + " at " + at; } - + } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilter.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilter.java index 328461355db..7aaaad534db 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilter.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilter.java @@ -7,17 +7,24 @@ import com.yahoo.jdisc.handler.ResponseHandler; import com.yahoo.jdisc.http.filter.DiscFilterRequest; import com.yahoo.jdisc.http.filter.SecurityRequestFilter; import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzPrincipal; -import com.yahoo.vespa.hosted.controller.api.integration.athenz.InvalidTokenException; +import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzUtils; import com.yahoo.vespa.hosted.controller.api.integration.athenz.NToken; import com.yahoo.vespa.hosted.controller.api.integration.athenz.ZmsKeystore; import com.yahoo.vespa.hosted.controller.athenz.config.AthenzConfig; +import java.security.cert.X509Certificate; +import java.util.Optional; import java.util.concurrent.Executor; import static com.yahoo.vespa.hosted.controller.athenz.filter.SecurityFilterUtils.sendErrorResponse; /** - * Performs authentication by validating the principal token (NToken) header. + * Authenticates Athenz principal, either through: + * 1. TLS client authentication (based on Athenz x509 identity certficiate). + * 2. The principal token (NToken) header. + * The TLS authentication is based on the following assumptions: + * - The underlying connector is configured with 'clientAuth' set to either WANT_AUTH or NEED_AUTH. + * - The trust store is configured with the Athenz CA certificates only. * * @author bjorncs */ @@ -43,18 +50,45 @@ public class AthenzPrincipalFilter implements SecurityRequestFilter { @Override public void filter(DiscFilterRequest request, ResponseHandler responseHandler) { - String rawToken = request.getHeader(principalTokenHeader); - if (rawToken == null || rawToken.isEmpty()) { - sendErrorResponse(responseHandler, Response.Status.UNAUTHORIZED, "NToken is missing"); - return; - } try { - AthenzPrincipal principal = validator.validate(new NToken(rawToken)); + Optional<AthenzPrincipal> certificatePrincipal = getClientCertificate(request) + .map(AthenzUtils::createAthenzIdentity) + .map(AthenzPrincipal::new); + Optional<AthenzPrincipal> nTokenPrincipal = getPrincipalToken(request, principalTokenHeader) + .map(validator::validate); + + if (!certificatePrincipal.isPresent() && !nTokenPrincipal.isPresent()) { + String errorMessage = "Unable to authenticate Athenz identity. " + + "Both client certificate missing and principal token header are missing."; + sendErrorResponse(responseHandler, Response.Status.UNAUTHORIZED, errorMessage); + return; + } + if (certificatePrincipal.isPresent() && nTokenPrincipal.isPresent() + && !certificatePrincipal.get().getIdentity().equals(nTokenPrincipal.get().getIdentity())) { + String errorMessage = String.format( + "Identity in principal token does not match x509 CN: token-identity=%s, cert-identity=%s", + nTokenPrincipal.get().getIdentity().getFullName(), + certificatePrincipal.get().getIdentity().getFullName()); + sendErrorResponse(responseHandler, Response.Status.UNAUTHORIZED, errorMessage); + return; + } + AthenzPrincipal principal = nTokenPrincipal.orElseGet(certificatePrincipal::get); request.setUserPrincipal(principal); request.setRemoteUser(principal.getName()); - } catch (InvalidTokenException e) { + } catch (Exception e) { sendErrorResponse(responseHandler,Response.Status.UNAUTHORIZED, e.getMessage()); } } + private static Optional<X509Certificate> getClientCertificate(DiscFilterRequest request) { + return Optional.ofNullable((X509Certificate[]) request.getAttribute("jdisc.request.X509Certificate")) + .map(chain -> chain[0]); + } + + private static Optional<NToken> getPrincipalToken(DiscFilterRequest request, String principalTokenHeaderName) { + return Optional.ofNullable(request.getHeader(principalTokenHeaderName)) + .filter(token -> !token.isEmpty()) + .map(NToken::new); + } + } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java index 487cbc02acc..6a9db3ae917 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java @@ -32,9 +32,9 @@ import java.util.logging.Logger; /** * Responsible for scheduling deployment jobs in a build system and keeping * Application.deploying() in sync with what is scheduled. - * + * * This class is multithread safe. - * + * * @author bratseth * @author mpolden */ @@ -60,7 +60,7 @@ public class DeploymentTrigger { this.order = new DeploymentOrder(controller); this.jobTimeout = controller.system().equals(SystemName.main) ? Duration.ofHours(12) : Duration.ofHours(1); } - + /** Returns the time in the past before which jobs are at this moment considered unresponsive */ public Instant jobTimeoutLimit() { return clock.instant().minus(jobTimeout); } @@ -70,10 +70,10 @@ public class DeploymentTrigger { //--- Start of methods which triggers deployment jobs ------------------------- - /** + /** * Called each time a job completes (successfully or not) to cause triggering of one or more follow-up jobs * (which may possibly the same job once over). - * + * * @param report information about the job that just completed */ public void triggerFromCompletion(JobReport report) { @@ -143,10 +143,11 @@ public class DeploymentTrigger { JobStatus systemTestStatus = application.deploymentJobs().jobStatus().get(JobType.systemTest); if (application.deploying().get() instanceof Change.VersionChange) { Version target = ((Change.VersionChange) application.deploying().get()).version(); - if (systemTestStatus == null + if (systemTestStatus == null || ! systemTestStatus.lastTriggered().isPresent() || ! systemTestStatus.isSuccess() - || ! systemTestStatus.lastTriggered().get().version().equals(target)) { + || ! systemTestStatus.lastTriggered().get().version().equals(target) + || systemTestStatus.isHanging(jobTimeoutLimit())) { application = trigger(JobType.systemTest, application, false, "Upgrade to " + target); controller.applications().store(application); } @@ -170,7 +171,7 @@ public class DeploymentTrigger { List<JobType> nextToTrigger = new ArrayList<>(); for (JobType nextJobType : order.nextAfter(jobType, application)) { JobStatus nextStatus = application.deploymentJobs().jobStatus().get(nextJobType); - if (changesAvailable(application, jobStatus, nextStatus)) + if (changesAvailable(application, jobStatus, nextStatus) || nextStatus.isHanging(jobTimeoutLimit())) nextToTrigger.add(nextJobType); } // Trigger them in parallel @@ -209,10 +210,10 @@ public class DeploymentTrigger { return true; return false; } - + /** * Triggers a change of this application - * + * * @param applicationId the application to trigger * @throws IllegalArgumentException if this application already have an ongoing change */ @@ -267,7 +268,7 @@ public class DeploymentTrigger { } /** - * Trigger a job for an application + * Trigger a job for an application * * @param jobType the type of the job to trigger, or null to trigger nothing * @param application the application to trigger the job for @@ -289,7 +290,7 @@ public class DeploymentTrigger { /** * Trigger a job for an application, if allowed - * + * * @param jobType the type of the job to trigger, or null to trigger nothing * @param application the application to trigger the job for * @param first whether to trigger the job before other jobs @@ -323,7 +324,7 @@ public class DeploymentTrigger { /** Returns true if the given proposed job triggering should be effected */ private boolean allowedTriggering(JobType jobType, LockedApplication application) { - // Note: We could make a more fine-grained and more correct determination about whether to block + // Note: We could make a more fine-grained and more correct determination about whether to block // by instead basing the decision on what is currently deployed in the zone. However, // this leads to some additional corner cases, and the possibility of blocking an application // fix to a version upgrade, so not doing it now @@ -341,7 +342,7 @@ public class DeploymentTrigger { return true; } - + private boolean isRunningProductionJob(Application application) { return JobList.from(application) .production() @@ -364,7 +365,7 @@ public class DeploymentTrigger { if (existingDeployment == null) return false; return existingDeployment.version().isAfter(version); } - + private boolean acceptNewRevisionNow(LockedApplication application) { if ( ! application.deploying().isPresent()) return true; @@ -377,5 +378,5 @@ public class DeploymentTrigger { // Otherwise, the application is currently upgrading, without failures, and we should wait with the revision. return false; } - + } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/Authorizer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/Authorizer.java index b7080a763f0..77ce49eaf47 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/Authorizer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/Authorizer.java @@ -9,14 +9,13 @@ import com.yahoo.vespa.hosted.controller.api.identifiers.AthenzDomain; import com.yahoo.vespa.hosted.controller.api.identifiers.TenantId; import com.yahoo.vespa.hosted.controller.api.identifiers.UserGroup; import com.yahoo.vespa.hosted.controller.api.identifiers.UserId; -import com.yahoo.vespa.hosted.controller.api.integration.entity.EntityService; import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzClientFactory; import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzIdentity; import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzPrincipal; import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzUser; import com.yahoo.vespa.hosted.controller.api.integration.athenz.NToken; +import com.yahoo.vespa.hosted.controller.api.integration.entity.EntityService; import com.yahoo.vespa.hosted.controller.common.ContextAttributes; -import com.yahoo.vespa.hosted.controller.restapi.filter.NTokenRequestFilter; import javax.ws.rs.ForbiddenException; import javax.ws.rs.HttpMethod; @@ -78,8 +77,7 @@ public class Authorizer { } public Optional<NToken> getNToken(HttpRequest request) { - String nTokenHeader = (String)request.getJDiscRequest().context().get(NTokenRequestFilter.NTOKEN_HEADER); - return Optional.ofNullable(nTokenHeader).map(NToken::new); + return getPrincipalIfAny(request).flatMap(AthenzPrincipal::getNToken); } public boolean isSuperUser(HttpRequest request) { diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java index ffb78b7342a..c887fbfc1a8 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java @@ -7,10 +7,19 @@ import com.yahoo.jdisc.handler.ReadableContentChannel; import com.yahoo.jdisc.handler.ResponseHandler; import com.yahoo.jdisc.http.filter.DiscFilterRequest; import com.yahoo.vespa.hosted.controller.api.identifiers.UserId; +import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzIdentity; import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzPrincipal; import com.yahoo.vespa.hosted.controller.api.integration.athenz.AthenzUser; import com.yahoo.vespa.hosted.controller.api.integration.athenz.InvalidTokenException; import com.yahoo.vespa.hosted.controller.api.integration.athenz.NToken; +import org.bouncycastle.asn1.x500.X500Name; +import org.bouncycastle.cert.X509v3CertificateBuilder; +import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; +import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.bouncycastle.operator.ContentSigner; +import org.bouncycastle.operator.OperatorCreationException; +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; import org.junit.Before; import org.junit.Test; @@ -18,6 +27,15 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.io.UncheckedIOException; +import java.math.BigInteger; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.time.Duration; +import java.time.Instant; +import java.util.Date; import java.util.Objects; import static com.yahoo.jdisc.Response.Status.UNAUTHORIZED; @@ -37,21 +55,21 @@ public class AthenzPrincipalFilterTest { private static final NToken NTOKEN = new NToken("dummy"); private static final String ATHENZ_PRINCIPAL_HEADER = "Athenz-Principal-Auth"; + private static final AthenzIdentity IDENTITY = AthenzUser.fromUserId(new UserId("bob")); + private static final X509Certificate CERTIFICATE = createSelfSignedCertificate(IDENTITY); private NTokenValidator validator; - private AthenzPrincipal principal; @Before public void before() { validator = mock(NTokenValidator.class); - principal = new AthenzPrincipal(AthenzUser.fromUserId(new UserId("bob")), NTOKEN); } @Test - public void valid_ntoken_is_accepted() throws Exception { + public void valid_ntoken_is_accepted() { DiscFilterRequest request = mock(DiscFilterRequest.class); + AthenzPrincipal principal = new AthenzPrincipal(IDENTITY, NTOKEN); when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken()); - when(validator.validate(NTOKEN)).thenReturn(principal); AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER); @@ -61,7 +79,7 @@ public class AthenzPrincipalFilterTest { } @Test - public void missing_token_is_unauthorized() throws Exception { + public void missing_token_and_certificate_is_unauthorized() { DiscFilterRequest request = mock(DiscFilterRequest.class); when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(null); @@ -70,26 +88,76 @@ public class AthenzPrincipalFilterTest { AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER); filter.filter(request, responseHandler); - assertThat(responseHandler.response, notNullValue()); - assertThat(responseHandler.response.getStatus(), equalTo(UNAUTHORIZED)); - assertThat(responseHandler.getResponseContent(), containsString("NToken is missing")); + assertUnauthorized(responseHandler, "Unable to authenticate Athenz identity"); + } + + @Test + public void invalid_token_is_unauthorized() { + DiscFilterRequest request = mock(DiscFilterRequest.class); + String errorMessage = "Invalid token"; + when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken()); + when(validator.validate(NTOKEN)).thenThrow(new InvalidTokenException(errorMessage)); + + ResponseHandlerMock responseHandler = new ResponseHandlerMock(); + + AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER); + filter.filter(request, responseHandler); + + assertUnauthorized(responseHandler, errorMessage); + } + + @Test + public void certificate_is_accepted() { + DiscFilterRequest request = mock(DiscFilterRequest.class); + when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(null); + when(request.getAttribute("jdisc.request.X509Certificate")).thenReturn(new X509Certificate[]{CERTIFICATE}); + + ResponseHandlerMock responseHandler = new ResponseHandlerMock(); + + AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER); + filter.filter(request, responseHandler); + + AthenzPrincipal expectedPrincipal = new AthenzPrincipal(IDENTITY); + verify(request).setUserPrincipal(expectedPrincipal); } @Test - public void invalid_token_is_unauthorized() throws Exception { + public void both_ntoken_and_certificate_is_accepted() { DiscFilterRequest request = mock(DiscFilterRequest.class); + AthenzPrincipal principalWithToken = new AthenzPrincipal(IDENTITY, NTOKEN); when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken()); + when(request.getAttribute("jdisc.request.X509Certificate")).thenReturn(new X509Certificate[]{CERTIFICATE}); + when(validator.validate(NTOKEN)).thenReturn(principalWithToken); + + ResponseHandlerMock responseHandler = new ResponseHandlerMock(); + + AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER); + filter.filter(request, responseHandler); - when(validator.validate(NTOKEN)).thenThrow(new InvalidTokenException("Invalid token")); + verify(request).setUserPrincipal(principalWithToken); + } + + @Test + public void conflicting_ntoken_and_certificate_is_unauthorized() { + DiscFilterRequest request = mock(DiscFilterRequest.class); + AthenzUser conflictingIdentity = AthenzUser.fromUserId(new UserId("mallory")); + when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken()); + when(request.getAttribute("jdisc.request.X509Certificate")) + .thenReturn(new X509Certificate[]{createSelfSignedCertificate(conflictingIdentity)}); + when(validator.validate(NTOKEN)).thenReturn(new AthenzPrincipal(IDENTITY)); ResponseHandlerMock responseHandler = new ResponseHandlerMock(); AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER); filter.filter(request, responseHandler); + assertUnauthorized(responseHandler, "Identity in principal token does not match x509 CN"); + } + + private static void assertUnauthorized(ResponseHandlerMock responseHandler, String expectedMessageSubstring) { assertThat(responseHandler.response, notNullValue()); assertThat(responseHandler.response.getStatus(), equalTo(UNAUTHORIZED)); - assertThat(responseHandler.getResponseContent(), containsString("Invalid token")); + assertThat(responseHandler.getResponseContent(), containsString(expectedMessageSubstring)); } private static class ResponseHandlerMock implements ResponseHandler { @@ -114,4 +182,24 @@ public class AthenzPrincipalFilterTest { } + // TODO Move this to separate athenz module/bundle + private static X509Certificate createSelfSignedCertificate(AthenzIdentity identity) { + try { + KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA"); + keyGen.initialize(512); + KeyPair keyPair = keyGen.genKeyPair(); + ContentSigner contentSigner = new JcaContentSignerBuilder("SHA256WithRSA").build(keyPair.getPrivate()); + X500Name x500Name = new X500Name("CN="+ identity.getFullName()); + X509v3CertificateBuilder certificateBuilder = + new JcaX509v3CertificateBuilder( + x500Name, BigInteger.ONE, new Date(), Date.from(Instant.now().plus(Duration.ofDays(30))), + x500Name, keyPair.getPublic()); + return new JcaX509CertificateConverter() + .setProvider(new BouncyCastleProvider()) + .getCertificate(certificateBuilder.build(contentSigner)); + } catch (CertificateException | NoSuchAlgorithmException | OperatorCreationException e) { + throw new RuntimeException(e); + } + } + } diff --git a/document/src/main/java/com/yahoo/document/DataType.java b/document/src/main/java/com/yahoo/document/DataType.java index c8a04866aa9..abdbf394591 100644 --- a/document/src/main/java/com/yahoo/document/DataType.java +++ b/document/src/main/java/com/yahoo/document/DataType.java @@ -51,7 +51,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com public final static PrimitiveDataType URI = new PrimitiveDataType("uri", 10, UriFieldValue.class, new UriFieldValue.Factory()); public final static NumericDataType BYTE = new NumericDataType("byte", 16, ByteFieldValue.class, ByteFieldValue.getFactory()); public final static PrimitiveDataType PREDICATE = new PrimitiveDataType("predicate", 20, PredicateFieldValue.class, PredicateFieldValue.getFactory()); - public final static int tensorDataTypeCode = 21; // All TensorDataType instances have id=21 but carries additional type information serialized separately + public final static int tensorDataTypeCode = 21; // All TensorDataType instances have id=21 but carries additional type information serialized separately // ADDITIONAL parametrized types added at runtime: map, struct, array, weighted set, annotation reference, tensor // Tags are converted to weightedset<string> when reading the search definition TODO: Remove it @@ -99,7 +99,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com /** * Creates a field value by reflection - * + * * @param arg the value of the newly created field value * @return a fully constructed value */ @@ -201,7 +201,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com public static TensorDataType getTensor(TensorType type) { return new TensorDataType(type); } - + public String getName() { return name; } @@ -267,7 +267,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com */ public FieldPath buildFieldPath(String fieldPathString) { if (fieldPathString.length() > 0) { - throw new IllegalArgumentException("Datatype " + toString() + + throw new IllegalArgumentException("Datatype " + toString() + " does not support further recursive structure: " + fieldPathString); } return new FieldPath(); diff --git a/document/src/main/java/com/yahoo/document/DocumentTypeManager.java b/document/src/main/java/com/yahoo/document/DocumentTypeManager.java index 8c9318199d8..5fad35a2287 100644 --- a/document/src/main/java/com/yahoo/document/DocumentTypeManager.java +++ b/document/src/main/java/com/yahoo/document/DocumentTypeManager.java @@ -38,7 +38,7 @@ public class DocumentTypeManager { // *Configured data types* (not built-in/primitive) indexed by their id // // *Primitive* data types are always available and have a single id. - // + // // *Built-in dynamic* types: The tensor type. // Any tensor type has the same id and is always available just like primitive types. // However, unlike primitive types, each tensor type is a separate DataType instance @@ -112,7 +112,7 @@ public class DocumentTypeManager { public DataType getDataType(String name) { if (name.startsWith("tensor(")) // built-in dynamic return new TensorDataType(TensorType.fromSpec(name)); - + List<DataType> foundTypes = new ArrayList<>(); for (DataType type : dataTypes.values()) { if (type.getName().equalsIgnoreCase(name)) { @@ -141,10 +141,10 @@ public class DocumentTypeManager { } public DataType getDataType(int code) { return getDataType(code, ""); } - + /** * Return a data type instance - * + * * @param code the code of the data type to return, which must be either built in or present in this manager * @param detailedType detailed type information, or the empty string if none * @return the appropriate DataType instance @@ -183,7 +183,7 @@ public class DocumentTypeManager { /** * Register a single datatype. Re-registering an existing, but equal, datatype is ok. - * + * * @param type The datatype to register */ void registerSingleType(DataType type) { @@ -280,7 +280,7 @@ public class DocumentTypeManager { /** * Returns a read only view of the registered data types - * + * * @return collection of types */ public Collection<DataType> getDataTypes() { diff --git a/document/src/main/java/com/yahoo/document/TensorDataType.java b/document/src/main/java/com/yahoo/document/TensorDataType.java index aefdc030a12..50e9cf0f60f 100644 --- a/document/src/main/java/com/yahoo/document/TensorDataType.java +++ b/document/src/main/java/com/yahoo/document/TensorDataType.java @@ -8,13 +8,13 @@ import com.yahoo.vespa.objects.Ids; /** * A DataType containing a tensor type - * + * * @author bratseth */ public class TensorDataType extends DataType { private final TensorType tensorType; - + // The global class identifier shared with C++. public static int classId = registerClass(Ids.document + 59, TensorDataType.class); @@ -47,5 +47,5 @@ public class TensorDataType extends DataType { /** Returns the type of the tensor this field can hold */ public TensorType getTensorType() { return tensorType; } - + } diff --git a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java index ae8d5cf596a..1808396986e 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java +++ b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java @@ -19,7 +19,7 @@ import java.util.Optional; public class TensorFieldValue extends FieldValue { private Optional<Tensor> tensor; - + private final TensorDataType dataType; /** Create an empty tensor field value */ @@ -66,7 +66,7 @@ public class TensorFieldValue extends FieldValue { o.getClass().getName() + "'."); } } - + public void assignTensor(Optional<Tensor> tensor) { if (tensor.isPresent() && ! tensor.get().type().isAssignableTo(dataType.getTensorType())) throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() + diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index f37fa5ea675..29ba244a9f1 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -146,9 +146,9 @@ public class JsonReaderTestCase { } { DocumentType x = new DocumentType("testtensor"); - x.addField(new Field("mappedtensorfield", + x.addField(new Field("mappedtensorfield", new TensorDataType(new TensorType.Builder().mapped("x").mapped("y").build()))); - x.addField(new Field("indexedtensorfield", + x.addField(new Field("indexedtensorfield", new TensorDataType(new TensorType.Builder().indexed("x").indexed("y").build()))); types.registerDocumentType(x); } @@ -1280,8 +1280,8 @@ public class JsonReaderTestCase { return (DocumentPut) reader.next(); } - private DocumentPut createPutWithMappedTensor(String inputTensor) { - return createPutWithTensor(inputTensor, "mappedtensorfield"); + private DocumentPut createPutWithMappedTensor(String inputTensor) { + return createPutWithTensor(inputTensor, "mappedtensorfield"); } private DocumentPut createPutWithTensor(String inputTensor, String tensorFieldName) { InputStream rawDoc = new ByteArrayInputStream( diff --git a/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java b/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java index 7104c1686f8..5c65b11a0c4 100644 --- a/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java +++ b/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java @@ -24,7 +24,7 @@ public class TensorFieldValueSerializationTestCase { private final static TensorType tensorType = new TensorType.Builder().mapped("dimX").mapped("dimY").build(); private final static String TENSOR_FIELD = "my_tensor"; private final static String TENSOR_FILES = "src/test/resources/tensor/"; - private final static TestDocumentFactory docFactory = new TestDocumentFactory(createDocType(), + private final static TestDocumentFactory docFactory = new TestDocumentFactory(createDocType(), "id:test:my_type::foo"); private static DocumentType createDocType() { diff --git a/document/src/vespa/document/bucket/bucketspace.h b/document/src/vespa/document/bucket/bucketspace.h index 1198b173a4b..99b510f7aff 100644 --- a/document/src/vespa/document/bucket/bucketspace.h +++ b/document/src/vespa/document/bucket/bucketspace.h @@ -16,15 +16,16 @@ class BucketSpace { public: using Type = uint64_t; - BucketSpace(const BucketSpace&) noexcept = default; - BucketSpace& operator=(const BucketSpace&) noexcept = default; - explicit BucketSpace(Type id) noexcept : _id(id) {} + constexpr BucketSpace(const BucketSpace&) noexcept = default; + constexpr BucketSpace& operator=(const BucketSpace&) noexcept = default; + constexpr explicit BucketSpace(Type id) noexcept : _id(id) {} - bool operator <(const BucketSpace& bucket) const noexcept { return _id < bucket._id; } - bool operator==(const BucketSpace& bucket) const noexcept { return _id == bucket._id; } - bool operator!=(const BucketSpace& bucket) const noexcept { return _id != bucket._id; } + constexpr bool operator <(const BucketSpace& bucket) const noexcept { return _id < bucket._id; } + constexpr bool operator==(const BucketSpace& bucket) const noexcept { return _id == bucket._id; } + constexpr bool operator!=(const BucketSpace& bucket) const noexcept { return _id != bucket._id; } - Type getId() const noexcept { return _id; } + constexpr Type getId() const noexcept { return _id; } + constexpr bool valid() const noexcept { return (_id != 0); } vespalib::string toString() const; struct hash { @@ -36,7 +37,8 @@ public: /* * Temporary placeholder value while wiring in use of BucketSpace in APIs. */ - static BucketSpace placeHolder() { return BucketSpace(0); } + static constexpr BucketSpace placeHolder() noexcept { return BucketSpace(1); } + static constexpr BucketSpace invalid() noexcept { return BucketSpace(0); } private: Type _id; }; diff --git a/document/src/vespa/document/select/CMakeLists.txt b/document/src/vespa/document/select/CMakeLists.txt index 6dadd35e98a..bc73498622d 100644 --- a/document/src/vespa/document/select/CMakeLists.txt +++ b/document/src/vespa/document/select/CMakeLists.txt @@ -1,10 +1,14 @@ # Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -find_package(BISON REQUIRED) -find_package(FLEX REQUIRED) +find_package(BISON REQUIRED 3.0) +find_package(FLEX REQUIRED 2.5) -BISON_TARGET(DocSelParser grammar/parser.yy ${CMAKE_CURRENT_BINARY_DIR}/parser.cxx) -FLEX_TARGET(DocSelLexer grammar/lexer.ll ${CMAKE_CURRENT_BINARY_DIR}/lexer.cxx) +BISON_TARGET(DocSelParser grammar/parser.yy + ${CMAKE_CURRENT_BINARY_DIR}/parser.cxx + DEFINES_FILE ${CMAKE_CURRENT_BINARY_DIR}/parser.hxx) +FLEX_TARGET(DocSelLexer grammar/lexer.ll + ${CMAKE_CURRENT_BINARY_DIR}/lexer.cxx + DEFINES_FILE ${CMAKE_CURRENT_BINARY_DIR}/lexer.hxx) ADD_FLEX_BISON_DEPENDENCY(DocSelLexer DocSelParser) include_directories(${CMAKE_CURRENT_BINARY_DIR}) diff --git a/document/src/vespa/document/select/context.cpp b/document/src/vespa/document/select/context.cpp index 6d9e0df157b..3a728db33f8 100644 --- a/document/src/vespa/document/select/context.cpp +++ b/document/src/vespa/document/select/context.cpp @@ -38,10 +38,14 @@ Context::~Context() { } std::unique_ptr<Value> Context::getValue(const vespalib::string & value) const { - VariableMap::const_iterator iter = _variables->find(value); - - if (iter != _variables->end()) { - return std::make_unique<FloatValue>(iter->second); + if (_variables) { + VariableMap::const_iterator iter = _variables->find(value); + + if (iter != _variables->end()) { + return std::make_unique<FloatValue>(iter->second); + } else { + return std::make_unique<FloatValue>(0.0); + } } else { return std::make_unique<FloatValue>(0.0); } diff --git a/document/src/vespa/document/select/grammar/lexer.ll b/document/src/vespa/document/select/grammar/lexer.ll index 8cd5638c122..6483b5e8534 100644 --- a/document/src/vespa/document/select/grammar/lexer.ll +++ b/document/src/vespa/document/select/grammar/lexer.ll @@ -1,9 +1,5 @@ /* Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. */ - /* We use the .*xx-suffix to denote a build-time generated file */ -%option outfile="lexer.cxx" -%option header-file="lexer.hxx" - %option c++ /* Uncomment to enable debug tracing of parsing */ /* %option debug */ diff --git a/document/src/vespa/document/select/grammar/parser.yy b/document/src/vespa/document/select/grammar/parser.yy index baf987355c9..f96bd50378f 100644 --- a/document/src/vespa/document/select/grammar/parser.yy +++ b/document/src/vespa/document/select/grammar/parser.yy @@ -1,8 +1,5 @@ /* Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. */ -%output "parser.cxx" -%defines "parser.hxx" - /* Skeleton implementation included as part of the generated source. Note: _not_ covered by the GPL. */ %skeleton "lalr1.cc" diff --git a/document/src/vespa/document/test/make_bucket_space.cpp b/document/src/vespa/document/test/make_bucket_space.cpp index be8292fcf71..dae0399e75d 100644 --- a/document/src/vespa/document/test/make_bucket_space.cpp +++ b/document/src/vespa/document/test/make_bucket_space.cpp @@ -11,12 +11,12 @@ BucketSpace makeBucketSpace() BucketSpace makeBucketSpace(const vespalib::string &docTypeName) { - // Used by persistence conformance test to map fron document type name + // Used by persistence conformance test to map from document type name // to bucket space. See document::TestDocRepo for known document types. if (docTypeName == "no") { - return BucketSpace(2); + return BucketSpace(3); } else if (docTypeName == "testdoctype2") { - return BucketSpace(1); + return BucketSpace(2); } else { return makeBucketSpace(); } diff --git a/documentapi/src/main/java/com/yahoo/documentapi/SyncParameters.java b/documentapi/src/main/java/com/yahoo/documentapi/SyncParameters.java index 66af8061f7c..cbe322aef71 100755 --- a/documentapi/src/main/java/com/yahoo/documentapi/SyncParameters.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/SyncParameters.java @@ -1,17 +1,17 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.documentapi; -import java.time.temporal.TemporalAmount; +import java.time.Duration; import java.util.Optional; /** * Parameters for creating a synchronous session * * @author bjorncs - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen */ public class SyncParameters extends Parameters { - private final TemporalAmount defaultTimeout; + private final Duration defaultTimeout; /** * @deprecated Use {@link Builder} instead. @@ -22,21 +22,21 @@ public class SyncParameters extends Parameters { this(null); } - private SyncParameters(TemporalAmount defaultTimeout) { + private SyncParameters(Duration defaultTimeout) { this.defaultTimeout = defaultTimeout; } - public Optional<TemporalAmount> defaultTimeout() { + public Optional<Duration> defaultTimeout() { return Optional.ofNullable(defaultTimeout); } public static class Builder { - private TemporalAmount defaultTimeout; + private Duration defaultTimeout; /** * Set default timeout for all messagebus operations. */ - public void setDefaultTimeout(TemporalAmount defaultTimeout) { + public void setDefaultTimeout(Duration defaultTimeout) { this.defaultTimeout = defaultTimeout; } diff --git a/documentapi/src/main/java/com/yahoo/documentapi/SyncSession.java b/documentapi/src/main/java/com/yahoo/documentapi/SyncSession.java index ee9b1760012..ca55933e302 100755 --- a/documentapi/src/main/java/com/yahoo/documentapi/SyncSession.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/SyncSession.java @@ -8,13 +8,13 @@ import com.yahoo.document.DocumentRemove; import com.yahoo.document.DocumentUpdate; import com.yahoo.documentapi.messagebus.protocol.DocumentProtocol; -import java.time.temporal.TemporalAmount; +import java.time.Duration; /** * <p>A session for synchronous access to a document repository. This class * provides simple document access where throughput is not a concern.</p> * - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen * @author bjorncs */ public interface SyncSession extends Session { @@ -71,7 +71,7 @@ public interface SyncSession extends Session { * @throws DocumentAccessException on any messagebus error, including timeout ({@link com.yahoo.messagebus.ErrorCode#TIMEOUT}). */ // TODO Vespa 7: Remove default implementation. Consider removing get() overloads without timeout. - default Document get(DocumentId id, TemporalAmount timeout) { + default Document get(DocumentId id, Duration timeout) { return get(id); } @@ -88,8 +88,7 @@ public interface SyncSession extends Session { * @throws DocumentAccessException on any messagebus error, including timeout ({@link com.yahoo.messagebus.ErrorCode#TIMEOUT}). */ // TODO Vespa 7: Remove default implementation. Consider removing get() overloads without timeout. - default Document get(DocumentId id, String fieldSet, DocumentProtocol.Priority priority, - TemporalAmount timeout) { + default Document get(DocumentId id, String fieldSet, DocumentProtocol.Priority priority, Duration timeout) { return get(id, fieldSet, priority); } diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusSyncSession.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusSyncSession.java index f2b1816a410..e02b6029dcf 100755 --- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusSyncSession.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusSyncSession.java @@ -25,19 +25,18 @@ import com.yahoo.messagebus.MessageBus; import com.yahoo.messagebus.Reply; import com.yahoo.messagebus.ReplyHandler; -import java.time.temporal.ChronoUnit; -import java.time.temporal.TemporalAmount; +import java.time.Duration; /** * An implementation of the SyncSession interface running over message bus. * - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen * @author bjorncs */ public class MessageBusSyncSession implements MessageBusSession, SyncSession, ReplyHandler { private final MessageBusAsyncSession session; - private final TemporalAmount defaultTimeout; + private final Duration defaultTimeout; /** * Creates a new sync session running on message bus logic. @@ -87,9 +86,9 @@ public class MessageBusSyncSession implements MessageBusSession, SyncSession, Re return syncSend(msg, defaultTimeout); } - private Reply syncSend(Message msg, TemporalAmount timeout) { + private Reply syncSend(Message msg, Duration timeout) { if (timeout != null) { - msg.setTimeRemaining(timeout.get(ChronoUnit.MILLIS)); + msg.setTimeRemaining(timeout.toMillis()); } try { RequestMonitor monitor = new RequestMonitor(); @@ -135,13 +134,12 @@ public class MessageBusSyncSession implements MessageBusSession, SyncSession, Re } @Override - public Document get(DocumentId id, TemporalAmount timeout) { + public Document get(DocumentId id, Duration timeout) { return get(id, "[all]", DocumentProtocol.Priority.NORMAL_1, timeout); } @Override - public Document get(DocumentId id, String fieldSet, DocumentProtocol.Priority pri, - TemporalAmount timeout) { + public Document get(DocumentId id, String fieldSet, DocumentProtocol.Priority pri, Duration timeout) { GetDocumentMessage msg = new GetDocumentMessage(id, fieldSet); msg.setPriority(pri); diff --git a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp index ca77997bac7..0b8b98fc617 100644 --- a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp +++ b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp @@ -1,8 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <vespa/log/log.h> -LOG_SETUP("dense_dot_product_function_test"); - #include <vespa/vespalib/testkit/test_kit.h> #include <vespa/eval/eval/tensor_function.h> #include <vespa/eval/tensor/dense/dense_dot_product_function.h> @@ -12,16 +9,13 @@ LOG_SETUP("dense_dot_product_function_test"); #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/util/stash.h> +#include <vespa/log/log.h> +LOG_SETUP("dense_dot_product_function_test"); + using namespace vespalib; using namespace vespalib::eval; using namespace vespalib::tensor; -ValueType -makeType(size_t numCells) -{ - return ValueType::tensor_type({{"x", numCells}}); -} - tensor::Tensor::UP makeTensor(size_t numCells, double cellBias) { diff --git a/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp b/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp index 6f3cdd5f93f..61efdbe6d22 100644 --- a/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp +++ b/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp @@ -147,7 +147,7 @@ TEST_F("require that builder can be re-used", Fixture) } void -assertTensorCell(const std::vector<size_t> &expAddress, +assertTensorCell(const DenseTensor::Address &expAddress, double expCell, const DenseTensor::CellsIterator &itr) { diff --git a/eval/src/tests/tensor/sparse_tensor_builder/sparse_tensor_builder_test.cpp b/eval/src/tests/tensor/sparse_tensor_builder/sparse_tensor_builder_test.cpp index bd3d1ada017..708c2f761f7 100644 --- a/eval/src/tests/tensor/sparse_tensor_builder/sparse_tensor_builder_test.cpp +++ b/eval/src/tests/tensor/sparse_tensor_builder/sparse_tensor_builder_test.cpp @@ -2,9 +2,11 @@ #include <vespa/vespalib/testkit/test_kit.h> #include <vespa/eval/tensor/sparse/sparse_tensor_builder.h> +#include <vespa/eval/tensor/sparse/sparse_tensor_address_combiner.h> #include <vespa/vespalib/test/insertion_operators.h> using namespace vespalib::tensor; +using namespace vespalib::tensor::sparse; using vespalib::eval::TensorSpec; using vespalib::eval::ValueType; @@ -57,10 +59,8 @@ TEST("require that tensor can be constructed") const ValueType &type = sparseTensor.type(); const SparseTensor::Cells &cells = sparseTensor.cells(); EXPECT_EQUAL(2u, cells.size()); - assertCellValue(10, TensorAddress({{"a","1"},{"b","2"}}), - type, cells); - assertCellValue(20, TensorAddress({{"c","3"},{"d","4"}}), - type, cells); + assertCellValue(10, TensorAddress({{"a","1"},{"b","2"}}), type, cells); + assertCellValue(20, TensorAddress({{"c","3"},{"d","4"}}), type, cells); } TEST("require that tensor can be converted to tensor spec") @@ -94,4 +94,22 @@ TEST("require that dimensions are extracted") EXPECT_EQUAL("tensor(a{},b{},c{})", sparseTensor.type().to_spec()); } +void verifyAddressCombiner(const ValueType & a, const ValueType & b, size_t numDim, size_t numOverlapping) { + TensorAddressCombiner combiner(a, b); + EXPECT_EQUAL(numDim, combiner.numDimensions()); + EXPECT_EQUAL(numOverlapping, combiner.numOverlappingDimensions()); +} +TEST("Test sparse tensor address combiner") { + verifyAddressCombiner(ValueType::tensor_type({{"a"}}), ValueType::tensor_type({{"b"}}), 2, 0); + verifyAddressCombiner(ValueType::tensor_type({{"a"}, {"b"}}), ValueType::tensor_type({{"b"}}), 2, 1); + verifyAddressCombiner(ValueType::tensor_type({{"a"}, {"b"}}), ValueType::tensor_type({{"b"}, {"c"}}), 3, 1); + +} + +TEST("Test essential object sizes") { + EXPECT_EQUAL(16u, sizeof(SparseTensorAddressRef)); + EXPECT_EQUAL(24u, sizeof(std::pair<SparseTensorAddressRef, double>)); + EXPECT_EQUAL(32u, sizeof(vespalib::hash_node<std::pair<SparseTensorAddressRef, double>>)); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/vespa/eval/eval/operation.h b/eval/src/vespa/eval/eval/operation.h index 52a0fbabd22..05c974bd3ff 100644 --- a/eval/src/vespa/eval/eval/operation.h +++ b/eval/src/vespa/eval/eval/operation.h @@ -7,10 +7,8 @@ #include <vespa/vespalib/util/approx.h> #include <vespa/vespalib/util/stash.h> -namespace vespalib { -namespace eval { +namespace vespalib::eval::operation { -namespace operation { struct Neg { static double f(double a); }; struct Not { static double f(double a); }; struct Add { static double f(double a, double b); }; @@ -52,7 +50,5 @@ struct IsNan { static double f(double a); }; struct Relu { static double f(double a); }; struct Sigmoid { static double f(double a); }; struct Elu { static double f(double a); }; -} // namespace vespalib::eval::operation -} // namespace vespalib::eval -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/eval/simple_tensor.cpp b/eval/src/vespa/eval/eval/simple_tensor.cpp index 0e58d292334..1836f2088f3 100644 --- a/eval/src/vespa/eval/eval/simple_tensor.cpp +++ b/eval/src/vespa/eval/eval/simple_tensor.cpp @@ -57,14 +57,14 @@ Address select(const Address &a, const Address &b, const IndexList &selector) { return result; } -size_t get_dimension_size(const ValueType &type, size_t dim_idx) { +size_t get_dimension_size(const ValueType &type, ValueType::Dimension::size_type dim_idx) { if (dim_idx == ValueType::Dimension::npos) { return 1; } return type.dimensions()[dim_idx].size; } -size_t get_dimension_index(const Address &addr, size_t dim_idx) { +size_t get_dimension_index(const Address &addr, ValueType::Dimension::size_type dim_idx) { if (dim_idx == ValueType::Dimension::npos) { return 0; } diff --git a/eval/src/vespa/eval/eval/value_type.cpp b/eval/src/vespa/eval/eval/value_type.cpp index 03e6d2bbcdf..1c4973a78ca 100644 --- a/eval/src/vespa/eval/eval/value_type.cpp +++ b/eval/src/vespa/eval/eval/value_type.cpp @@ -101,9 +101,9 @@ struct Renamer { } // namespace vespalib::tensor::<unnamed> -constexpr size_t ValueType::Dimension::npos; +constexpr ValueType::Dimension::size_type ValueType::Dimension::npos; -ValueType::~ValueType() { } +ValueType::~ValueType() = default; bool ValueType::is_sparse() const { diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h index 2988cc5204e..a4762acd4c0 100644 --- a/eval/src/vespa/eval/eval/value_type.h +++ b/eval/src/vespa/eval/eval/value_type.h @@ -6,8 +6,7 @@ #include <vector> #include <memory> -namespace vespalib { -namespace eval { +namespace vespalib::eval { /** * The type of a Value. This is used for type-resolution during @@ -19,12 +18,13 @@ class ValueType public: enum class Type { ANY, ERROR, DOUBLE, TENSOR }; struct Dimension { - static constexpr size_t npos = -1; + using size_type = uint32_t; + static constexpr size_type npos = -1; vespalib::string name; - size_t size; + size_type size; Dimension(const vespalib::string &name_in) : name(name_in), size(npos) {} - Dimension(const vespalib::string &name_in, size_t size_in) + Dimension(const vespalib::string &name_in, size_type size_in) : name(name_in), size(size_in) {} bool operator==(const Dimension &rhs) const { return ((name == rhs.name) && (size == rhs.size)); @@ -91,5 +91,4 @@ public: std::ostream &operator<<(std::ostream &os, const ValueType &type); -} // namespace vespalib::eval -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/cell_function.h b/eval/src/vespa/eval/tensor/cell_function.h index d758cf60634..a268c9a34b1 100644 --- a/eval/src/vespa/eval/tensor/cell_function.h +++ b/eval/src/vespa/eval/tensor/cell_function.h @@ -4,8 +4,7 @@ #include <functional> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * Interface for a function to be applied on cells in a tensor. @@ -17,5 +16,4 @@ struct CellFunction virtual double apply(double value) const = 0; }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp index 992f2eae750..fdd0cd6638f 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp @@ -6,8 +6,7 @@ #include <vespa/eval/eval/value.h> #include <vespa/eval/tensor/tensor.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { using CellsRef = DenseTensorView::CellsRef; @@ -39,5 +38,5 @@ DenseDotProductFunction::eval(ConstArrayRef<eval::Value::CREF> params, Stash &st return stash.create<eval::DoubleValue>(result); } -} // namespace tensor -} // namespace vespalib +} + diff --git a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h index 8ad57d69524..288f2afd084 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h +++ b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h @@ -5,8 +5,7 @@ #include <vespa/eval/eval/tensor_function.h> #include <vespa/vespalib/hwaccelrated/iaccelrated.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * Tensor function for a dot product between two 1-dimensional dense tensors. @@ -27,5 +26,5 @@ public: const eval::Value &eval(ConstArrayRef<eval::Value::CREF> params, Stash &stash) const override; }; -} // namespace tensor -} // namespace vespalib +} + diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor.cpp index 5d7e0c83267..9693e89bb75 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor.cpp @@ -4,12 +4,10 @@ #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/util/exceptions.h> #include <vespa/eval/eval/operation.h> -#include <sstream> using vespalib::eval::TensorSpec; -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { namespace { @@ -84,5 +82,5 @@ DenseTensor::operator==(const DenseTensor &rhs) const (_cells == rhs._cells); } -} // namespace vespalib::tensor -} // namespace vespalib +} + diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor.h b/eval/src/vespa/eval/tensor/dense/dense_tensor.h index 1b97438272e..c45d3c7ccb6 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor.h @@ -8,8 +8,7 @@ #include "dense_tensor_cells_iterator.h" #include "dense_tensor_view.h" -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * A dense tensor where all dimensions are indexed. @@ -29,16 +28,13 @@ private: public: DenseTensor(); ~DenseTensor() {} - DenseTensor(const eval::ValueType &type_in, - const Cells &cells_in); - DenseTensor(const eval::ValueType &type_in, - Cells &&cells_in); - DenseTensor(eval::ValueType &&type_in, - Cells &&cells_in); + DenseTensor(const eval::ValueType &type_in, const Cells &cells_in); + DenseTensor(const eval::ValueType &type_in, Cells &&cells_in); + DenseTensor(eval::ValueType &&type_in, Cells &&cells_in); bool operator==(const DenseTensor &rhs) const; const Cells &cells() const { return _cells; } }; -} // namespace vespalib::tensor -} // namespace vespalib +} + diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp index 3e9f4f619f0..ef2a56d4582 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp @@ -4,33 +4,7 @@ #include <vespa/vespalib/util/exceptions.h> #include <cassert> -namespace vespalib { -namespace tensor { - -using Address = DenseTensorAddressCombiner::Address; - -namespace { - -class AddressReader -{ -private: - const Address &_address; - size_t _idx; - -public: - AddressReader(const Address &address) - : _address(address), - _idx(0) - {} - size_t nextLabel() { - return _address[_idx++]; - } - bool valid() { - return _idx < _address.size(); - } -}; - -} +namespace vespalib::tensor { DenseTensorAddressCombiner::~DenseTensorAddressCombiner() { } @@ -57,35 +31,7 @@ DenseTensorAddressCombiner::DenseTensorAddressCombiner(const eval::ValueType &lh _ops.push_back(AddressOp::RHS); ++rhsItr; } -} - -bool -DenseTensorAddressCombiner::combine(const CellsIterator &lhsItr, - const CellsIterator &rhsItr) -{ - _combinedAddress.clear(); - AddressReader lhsReader(lhsItr.address()); - AddressReader rhsReader(rhsItr.address()); - for (const auto &op : _ops) { - switch (op) { - case AddressOp::LHS: - _combinedAddress.emplace_back(lhsReader.nextLabel()); - break; - case AddressOp::RHS: - _combinedAddress.emplace_back(rhsReader.nextLabel()); - break; - case AddressOp::BOTH: - size_t lhsLabel = lhsReader.nextLabel(); - size_t rhsLabel = rhsReader.nextLabel(); - if (lhsLabel != rhsLabel) { - return false; - } - _combinedAddress.emplace_back(lhsLabel); - } - } - assert(!lhsReader.valid()); - assert(!rhsReader.valid()); - return true; + _combinedAddress.resize(_ops.size()); } eval::ValueType @@ -120,5 +66,4 @@ DenseTensorAddressCombiner::combineDimensions(const eval::ValueType &lhs, eval::ValueType::tensor_type(std::move(result))); } -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h index 30bfd740fdd..37fad083dc1 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h @@ -7,8 +7,7 @@ #include <vespa/eval/tensor/types.h> #include <vespa/eval/eval/value_type.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** @@ -19,32 +18,57 @@ namespace tensor { class DenseTensorAddressCombiner { public: - using Address = std::vector<size_t>; + using Address = DenseTensorCellsIterator::Address; private: - enum class AddressOp { - LHS, - RHS, - BOTH - }; + enum class AddressOp { LHS, RHS, BOTH }; using CellsIterator = DenseTensorCellsIterator; std::vector<AddressOp> _ops; Address _combinedAddress; + class AddressReader + { + private: + const Address &_address; + uint32_t _idx; + public: + AddressReader(const Address &address) : _address(address), _idx(0) {} + Address::value_type nextLabel() { return _address[_idx++]; } + }; public: - DenseTensorAddressCombiner(const eval::ValueType &lhs, - const eval::ValueType &rhs); + DenseTensorAddressCombiner(const eval::ValueType &lhs, const eval::ValueType &rhs); ~DenseTensorAddressCombiner(); - bool combine(const CellsIterator &lhsItr, - const CellsIterator &rhsItr); const Address &address() const { return _combinedAddress; } - static eval::ValueType combineDimensions(const eval::ValueType &lhs, const eval::ValueType &rhs); + bool combine(const CellsIterator &lhsItr, const CellsIterator &rhsItr) { + uint32_t index(0); + AddressReader lhsReader(lhsItr.address()); + AddressReader rhsReader(rhsItr.address()); + for (const auto &op : _ops) { + switch (op) { + case AddressOp::LHS: + _combinedAddress[index] = lhsReader.nextLabel(); + break; + case AddressOp::RHS: + _combinedAddress[index] = rhsReader.nextLabel(); + break; + case AddressOp::BOTH: + Address::value_type lhsLabel = lhsReader.nextLabel(); + Address::value_type rhsLabel = rhsReader.nextLabel(); + if (lhsLabel != rhsLabel) { + return false; + } + _combinedAddress[index] = lhsLabel; + } + index++; + } + return true; + } + static eval::ValueType combineDimensions(const eval::ValueType &lhs, const eval::ValueType &rhs); }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.h index 36432c420f5..49e075f6999 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.h @@ -2,13 +2,12 @@ #pragma once -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { + class Tensor; + class DenseTensor; +} -class Tensor; -class DenseTensor; - -namespace dense { +namespace vespalib::tensor::dense { /** * Creates a new tensor using all combinations of input tensor cells with matching @@ -22,7 +21,4 @@ template <typename Function> std::unique_ptr<Tensor> apply(const DenseTensorView &lhs, const DenseTensorView &rhs, Function &&func); -} // namespace vespalib::tensor::dense -} // namespace vespalib::tensor -} // namespace vespalib - +} diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp index 65fee767690..dc47d02d47c 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp @@ -6,9 +6,7 @@ #include "dense_tensor_address_combiner.h" #include "direct_dense_tensor_builder.h" -namespace vespalib { -namespace tensor { -namespace dense { +namespace vespalib::tensor::dense { template <typename Function> std::unique_ptr<Tensor> @@ -42,6 +40,4 @@ apply(const DenseTensorView &lhs, const Tensor &rhs, Function &&func) return Tensor::UP(); } -} // namespace vespalib::tensor::dense -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp index 0b66dd51206..5d52e5f6e0e 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp @@ -83,7 +83,7 @@ DenseTensorBuilder::calculateCellAddress() const auto &dim = _dimensions[i]; if (label == UNDEFINED_LABEL) { throw IllegalArgumentException(make_string("Label for dimension '%s' is undefined. " - "Expected a value in the range [0, %zu>", + "Expected a value in the range [0, %u>", dim.name.c_str(), dim.size)); } result += (label * multiplier); diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h index 765ed57393a..3969a9335b8 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h @@ -6,8 +6,7 @@ #include <vespa/vespalib/stllike/hash_map.h> #include <vespa/eval/tensor/tensor_builder.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * A builder of for dense tensors. @@ -38,5 +37,5 @@ public: Tensor::UP build(); }; -} // namespace vespalib::tensor -} // namespace vespalib +} + diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.cpp index 59b4646a22b..d20c5124330 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.cpp @@ -2,23 +2,14 @@ #include "dense_tensor_cells_iterator.h" -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { -void -DenseTensorCellsIterator::next() -{ - ++_cellIdx; - if (valid()) { - for (int64_t i = (_address.size() - 1); i >= 0; --i) { - _address[i] = (_address[i] + 1) % _type.dimensions()[i].size; - if (_address[i] != 0) { - // Outer dimension labels can only be increased when this label wraps around. - break; - } - } - } -} +DenseTensorCellsIterator::DenseTensorCellsIterator(const eval::ValueType &type_in, CellsRef cells) + : _type(type_in), + _cells(cells), + _cellIdx(0), + _address(type_in.dimensions().size(), 0) +{} +DenseTensorCellsIterator::~DenseTensorCellsIterator() = default; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h index f77517bfdc5..fcffecef764 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h @@ -8,34 +8,41 @@ #include <vespa/eval/tensor/tensor.h> #include <vespa/vespalib/util/arrayref.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * Utility class to iterate over cells in a dense tensor. */ class DenseTensorCellsIterator { +public: + using size_type = eval::ValueType::Dimension::size_type; + using Address = std::vector<size_type>; private: using CellsRef = vespalib::ConstArrayRef<double>; const eval::ValueType &_type; CellsRef _cells; - size_t _cellIdx; - std::vector<size_t> _address; - + size_t _cellIdx; + Address _address; public: - DenseTensorCellsIterator(const eval::ValueType &type_in, CellsRef cells) - : _type(type_in), - _cells(cells), - _cellIdx(0), - _address(type_in.dimensions().size(), 0) - {} + DenseTensorCellsIterator(const eval::ValueType &type_in, CellsRef cells); + ~DenseTensorCellsIterator(); + void next() { + ++_cellIdx; + for (int64_t i = (_address.size() - 1); i >= 0; --i) { + _address[i]++; + if (__builtin_expect((_address[i] != _type.dimensions()[i].size), true)) { + // Outer dimension labels can only be increased when this label wraps around. + break; + } else { + _address[i] = 0; + } + } + } bool valid() const { return _cellIdx < _cells.size(); } - void next(); double cell() const { return _cells[_cellIdx]; } - const std::vector<size_t> &address() const { return _address; } + const Address &address() const { return _address; } const eval::ValueType &fast_type() const { return _type; } }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.cpp index 1268a46b8e5..22e2a3fb78c 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.cpp @@ -11,8 +11,7 @@ using namespace vespalib::eval; using namespace vespalib::eval::tensor_function; using namespace vespalib::eval::operation; -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { namespace { @@ -89,5 +88,5 @@ DenseTensorFunctionCompiler::compile(const eval::tensor_function::Node &expr, St return InnerProductFunctionCompiler::compile(expr, stash); } -} // namespace tensor -} // namespace vespalib +} + diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.h index d5ba4e4f7a7..61c3af079e3 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.h @@ -4,11 +4,9 @@ #include <vespa/eval/eval/tensor_function.h> -namespace vespalib { +namespace vespalib { class Stash; } -class Stash; - -namespace tensor { +namespace vespalib::tensor { /** * Class that recognizes calculations over dense tensors (in tensor function intermediate representation) @@ -19,5 +17,5 @@ struct DenseTensorFunctionCompiler static const eval::TensorFunction &compile(const eval::tensor_function::Node &expr, Stash &stash); }; -} // namespace tensor -} // namespace vespalib +} + diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.h index d8f47d2234c..fb054318985 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.h @@ -4,9 +4,7 @@ #include "dense_tensor.h" -namespace vespalib { -namespace tensor { -namespace dense { +namespace vespalib::tensor::dense { /** * Returns a tensor with the given dimension(s) removed and the cell values in that dimension(s) @@ -16,6 +14,5 @@ template<typename Function> std::unique_ptr<Tensor> reduce(const DenseTensorView &tensor, const std::vector<vespalib::string> &dimensions, Function &&func); -} // namespace dense -} // namespace tensor -} // namespace vespalib +} + diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp index 30c9f17348e..74c8981168d 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp @@ -13,8 +13,7 @@ using vespalib::eval::TensorSpec; -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { namespace { @@ -228,7 +227,7 @@ DenseTensorView::accept(TensorVisitor &visitor) const addressBuilder.clear(); auto rawIndex = iterator.address().begin(); for (const auto &dimension : _typeRef.dimensions()) { - label = vespalib::make_string("%zu", *rawIndex); + label = vespalib::make_string("%u", *rawIndex); addressBuilder.add(dimension.name, label); ++rawIndex; } @@ -264,5 +263,4 @@ DenseTensorView::reduce(join_fun_t op, op); } -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h index 5a59594667d..fd95c8555f4 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h @@ -7,8 +7,7 @@ #include <vespa/eval/eval/value_type.h> #include "dense_tensor_cells_iterator.h" -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { class DenseTensor; @@ -22,6 +21,7 @@ public: using Cells = std::vector<double>; using CellsRef = ConstArrayRef<double>; using CellsIterator = DenseTensorCellsIterator; + using Address = std::vector<eval::ValueType::Dimension::size_type>; private: const eval::ValueType &_typeRef; @@ -61,5 +61,5 @@ public: virtual void accept(TensorVisitor &visitor) const override; }; -} // namespace vespalib::tensor -} // namespace vespalib +} + diff --git a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp index 45de00dc7fe..1ab78b8ee30 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp @@ -8,8 +8,7 @@ #include <vespa/vespalib/util/exceptions.h> #include <assert.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { DenseXWProductFunction::DenseXWProductFunction(const eval::ValueType &resultType, size_t vectorId, @@ -87,5 +86,5 @@ DenseXWProductFunction::eval(ConstArrayRef<eval::Value::CREF> params, Stash &sta return stash.create<DenseTensorView>(_resultType, outputCells); } -} // namespace tensor -} // namespace vespalib +} + diff --git a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h index db006100e5a..151f1f13800 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h +++ b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h @@ -6,8 +6,7 @@ #include "dense_tensor_view.h" #include <vespa/vespalib/hwaccelrated/iaccelrated.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { using XWInput = DenseTensorView::CellsRef; using XWOutput = ArrayRef<double>; @@ -49,5 +48,5 @@ public: const eval::Value &eval(ConstArrayRef<eval::Value::CREF> params, Stash &stash) const override; }; -} // namespace tensor -} // namespace vespalib +} + diff --git a/eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.cpp b/eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.cpp index f73d123d4bd..27d72e18f96 100644 --- a/eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.cpp +++ b/eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.cpp @@ -3,8 +3,7 @@ #include "direct_dense_tensor_builder.h" #include <cassert> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { using Address = DirectDenseTensorBuilder::Address; using eval::ValueType; @@ -35,7 +34,7 @@ calculateCellAddress(const Address &address, const ValueType &type) } -DirectDenseTensorBuilder::~DirectDenseTensorBuilder() { } +DirectDenseTensorBuilder::~DirectDenseTensorBuilder() = default; DirectDenseTensorBuilder::DirectDenseTensorBuilder(const ValueType &type_in) : _type(type_in), @@ -57,5 +56,5 @@ DirectDenseTensorBuilder::build() return std::make_unique<DenseTensor>(std::move(_type), std::move(_cells)); } -} // namespace tensor -} // namesapce vespalib +} + diff --git a/eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.h b/eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.h index 5e0368e8e69..865decd9fb8 100644 --- a/eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.h +++ b/eval/src/vespa/eval/tensor/dense/direct_dense_tensor_builder.h @@ -4,8 +4,7 @@ #include "dense_tensor.h" -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * Class for building a dense tensor by inserting cell values directly into underlying array of cells. @@ -14,7 +13,7 @@ class DirectDenseTensorBuilder { public: using Cells = DenseTensor::Cells; - using Address = std::vector<size_t>; + using Address = DenseTensor::Address; private: eval::ValueType _type; @@ -27,5 +26,5 @@ public: Tensor::UP build(); }; -} // namespace tensor -} // namespace vespalib +} + diff --git a/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.cpp b/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.cpp index 71b7824ee5d..e3b4c8dee42 100644 --- a/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.cpp +++ b/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.cpp @@ -4,8 +4,7 @@ using vespalib::eval::ValueType; -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { MutableDenseTensorView::MutableValueType::MutableValueType(ValueType type_in) : _type(type_in) @@ -19,7 +18,7 @@ MutableDenseTensorView::MutableValueType::MutableValueType(ValueType type_in) } } -MutableDenseTensorView::MutableValueType::~MutableValueType() {} +MutableDenseTensorView::MutableValueType::~MutableValueType() = default; MutableDenseTensorView::MutableDenseTensorView(ValueType type_in) : DenseTensorView(_concreteType.fast_type(), CellsRef()), @@ -33,5 +32,5 @@ MutableDenseTensorView::MutableDenseTensorView(ValueType type_in, CellsRef cells { } -} // namespace tensor -} // namespace vespalib +} + diff --git a/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h index 7eee3a9483c..b68a1594905 100644 --- a/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h +++ b/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h @@ -5,8 +5,7 @@ #include "dense_tensor_view.h" #include <cassert> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * A mutable view to a dense tensor where all dimensions are indexed. @@ -18,7 +17,7 @@ private: { private: eval::ValueType _type; - std::vector<size_t *> _unboundDimSizes; + std::vector<eval::ValueType::Dimension::size_type *> _unboundDimSizes; public: MutableValueType(eval::ValueType type_in); @@ -55,5 +54,5 @@ public: } }; -} // namespace vespalib::tensor -} // namespace vespalib +} + diff --git a/eval/src/vespa/eval/tensor/direct_tensor_builder.h b/eval/src/vespa/eval/tensor/direct_tensor_builder.h index 667cec7c7a9..1eb171eef6e 100644 --- a/eval/src/vespa/eval/tensor/direct_tensor_builder.h +++ b/eval/src/vespa/eval/tensor/direct_tensor_builder.h @@ -2,8 +2,7 @@ #pragma once -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * Forward declaration of utility class to build tensor of type TensorT, @@ -11,5 +10,4 @@ namespace tensor { */ template <typename TensorT> class DirectTensorBuilder; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/sparse/direct_sparse_tensor_builder.h b/eval/src/vespa/eval/tensor/sparse/direct_sparse_tensor_builder.h index 33000d4889d..c977131fcd3 100644 --- a/eval/src/vespa/eval/tensor/sparse/direct_sparse_tensor_builder.h +++ b/eval/src/vespa/eval/tensor/sparse/direct_sparse_tensor_builder.h @@ -7,8 +7,7 @@ #include "sparse_tensor_address_builder.h" #include "sparse_tensor_address_padder.h" -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * Utility class to build tensors of type SparseTensor, to be used by @@ -91,9 +90,7 @@ public: ~DirectTensorBuilder() {} Tensor::UP build() { - return std::make_unique<SparseTensor>(std::move(_type), - std::move(_cells), - std::move(_stash)); + return std::make_unique<SparseTensor>(std::move(_type), std::move(_cells), std::move(_stash)); } template <class Function> @@ -129,7 +126,7 @@ public: eval::ValueType &fast_type() { return _type; } Cells &cells() { return _cells; } + void reserve(uint32_t estimatedCells) { _cells.resize(estimatedCells*2); } }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp index 4762f1eceb4..1aa05bf4f61 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp @@ -12,8 +12,6 @@ #include <vespa/vespalib/stllike/hash_map.hpp> #include <vespa/vespalib/stllike/hash_map_equal.hpp> #include <vespa/vespalib/util/array_equal.hpp> -#include <sstream> -#include <algorithm> using vespalib::eval::TensorSpec; @@ -35,8 +33,7 @@ copyCells(Cells &cells, const Cells &cells_in, Stash &stash) } -SparseTensor::SparseTensor(const eval::ValueType &type_in, - const Cells &cells_in) +SparseTensor::SparseTensor(const eval::ValueType &type_in, const Cells &cells_in) : _type(type_in), _cells(), _stash(STASH_CHUNK_SIZE) @@ -45,14 +42,13 @@ SparseTensor::SparseTensor(const eval::ValueType &type_in, } -SparseTensor::SparseTensor(eval::ValueType &&type_in, - Cells &&cells_in, Stash &&stash_in) +SparseTensor::SparseTensor(eval::ValueType &&type_in, Cells &&cells_in, Stash &&stash_in) : _type(std::move(type_in)), _cells(std::move(cells_in)), _stash(std::move(stash_in)) -{ -} +{ } +SparseTensor::~SparseTensor() = default; bool SparseTensor::operator==(const SparseTensor &rhs) const diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h index c7c38f0a182..2715e606729 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h @@ -11,8 +11,7 @@ #include <vespa/vespalib/stllike/string.h> #include <vespa/vespalib/util/stash.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * A tensor implementation using serialized tensor addresses to @@ -22,7 +21,7 @@ namespace tensor { class SparseTensor : public Tensor { public: - using Cells = vespalib::hash_map<SparseTensorAddressRef, double>; + using Cells = hash_map<SparseTensorAddressRef, double>; static constexpr size_t STASH_CHUNK_SIZE = 16384u; @@ -32,28 +31,23 @@ private: Stash _stash; public: - explicit SparseTensor(const eval::ValueType &type_in, - const Cells &cells_in); - SparseTensor(eval::ValueType &&type_in, - Cells &&cells_in, Stash &&stash_in); + explicit SparseTensor(const eval::ValueType &type_in, const Cells &cells_in); + SparseTensor(eval::ValueType &&type_in, Cells &&cells_in, Stash &&stash_in); + ~SparseTensor() override; const Cells &cells() const { return _cells; } const eval::ValueType &fast_type() const { return _type; } bool operator==(const SparseTensor &rhs) const; eval::ValueType combineDimensionsWith(const SparseTensor &rhs) const; - virtual const eval::ValueType &type() const override; - virtual double as_double() const override; - virtual Tensor::UP apply(const CellFunction &func) const override; - virtual Tensor::UP join(join_fun_t function, - const Tensor &arg) const override; - virtual Tensor::UP reduce(join_fun_t op, - const std::vector<vespalib::string> &dimensions) - const override; - virtual bool equals(const Tensor &arg) const override; - virtual Tensor::UP clone() const override; - virtual eval::TensorSpec toSpec() const override; - virtual void accept(TensorVisitor &visitor) const override; + const eval::ValueType &type() const override; + double as_double() const override; + Tensor::UP apply(const CellFunction &func) const override; + Tensor::UP join(join_fun_t function, const Tensor &arg) const override; + Tensor::UP reduce(join_fun_t op, const std::vector<vespalib::string> &dimensions) const override; + bool equals(const Tensor &arg) const override; + Tensor::UP clone() const override; + eval::TensorSpec toSpec() const override; + void accept(TensorVisitor &visitor) const override; }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_builder.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_builder.h index e9a66eb4539..f74ce257b31 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_builder.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_builder.h @@ -2,12 +2,10 @@ #pragma once -#include <vespa/vespalib/stllike/string.h> -#include <vector> #include "sparse_tensor_address_ref.h" +#include <vespa/vespalib/stllike/string.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** @@ -20,21 +18,26 @@ namespace tensor { class SparseTensorAddressBuilder { private: - std::vector<char> _address; + vespalib::Array<char> _address; - void - append(vespalib::stringref str) - { - const char *cstr = str.c_str(); - _address.insert(_address.end(), cstr, cstr + str.size() + 1); +protected: + void append(vespalib::stringref str) { + for (size_t i(0); i < str.size() + 1; i++) { + _address.push_back_fast(str[i]); + } + } + void ensure_room(size_t additional) { + if (_address.capacity() < (_address.size() + additional)) { + _address.reserve(_address.size() + additional); + } } public: - SparseTensorAddressBuilder() - : _address() - { + SparseTensorAddressBuilder() : _address() {} + void add(vespalib::stringref label) { + ensure_room(label.size()+1); + append(label); } - void add(vespalib::stringref label) { append(label); } - void addUndefined() { _address.emplace_back('\0'); } + void addUndefined() { _address.push_back('\0'); } void clear() { _address.clear(); } SparseTensorAddressRef getAddressRef() const { return SparseTensorAddressRef(&_address[0], _address.size()); @@ -42,6 +45,4 @@ public: bool empty() const { return _address.empty(); } }; - -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.cpp index b386ec82528..e0de63b90d2 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.cpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.cpp @@ -5,12 +5,9 @@ #include <vespa/eval/eval/value_type.h> #include <cassert> -namespace vespalib { -namespace tensor { -namespace sparse { +namespace vespalib::tensor::sparse { -TensorAddressCombiner::TensorAddressCombiner(const eval::ValueType &lhs, - const eval::ValueType &rhs) +TensorAddressCombiner::TensorAddressCombiner(const eval::ValueType &lhs, const eval::ValueType &rhs) { auto rhsItr = rhs.dimensions().cbegin(); auto rhsItrEnd = rhs.dimensions().cend(); @@ -32,8 +29,17 @@ TensorAddressCombiner::TensorAddressCombiner(const eval::ValueType &lhs, } } -TensorAddressCombiner::~TensorAddressCombiner() -{ +TensorAddressCombiner::~TensorAddressCombiner() = default; + +size_t +TensorAddressCombiner::numOverlappingDimensions() const { + size_t count = 0; + for (AddressOp op : _ops) { + if (op == AddressOp::BOTH) { + count++; + } + } + return count; } bool @@ -41,15 +47,16 @@ TensorAddressCombiner::combine(SparseTensorAddressRef lhsRef, SparseTensorAddressRef rhsRef) { clear(); + ensure_room(lhsRef.size() + rhsRef.size()); SparseTensorAddressDecoder lhs(lhsRef); SparseTensorAddressDecoder rhs(rhsRef); for (auto op : _ops) { switch (op) { case AddressOp::LHS: - add(lhs.decodeLabel()); + append(lhs.decodeLabel()); break; case AddressOp::RHS: - add(rhs.decodeLabel()); + append(rhs.decodeLabel()); break; case AddressOp::BOTH: auto lhsLabel(lhs.decodeLabel()); @@ -57,14 +64,10 @@ TensorAddressCombiner::combine(SparseTensorAddressRef lhsRef, if (lhsLabel != rhsLabel) { return false; } - add(lhsLabel); + append(lhsLabel); } } - assert(!lhs.valid()); - assert(!rhs.valid()); return true; } -} // namespace vespalib::tensor::sparse -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.h index 402b4bc598a..1a7f2fd8d3c 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.h @@ -3,12 +3,11 @@ #pragma once #include "sparse_tensor_address_builder.h" -#include <vespa/eval/tensor/types.h> -namespace vespalib { -namespace eval { class ValueType; } -namespace tensor { -namespace sparse { +#define VESPA_DLL_LOCAL __attribute__ ((visibility("hidden"))) + +namespace vespalib::eval { class ValueType; } +namespace vespalib::tensor::sparse { /** * Combine two tensor addresses to a new tensor address. Common dimensions @@ -16,25 +15,17 @@ namespace sparse { */ class TensorAddressCombiner : public SparseTensorAddressBuilder { - enum class AddressOp - { - LHS, - RHS, - BOTH - }; + enum class AddressOp { LHS, RHS, BOTH }; std::vector<AddressOp> _ops; - public: - TensorAddressCombiner(const eval::ValueType &lhs, - const eval::ValueType &rhs); - + TensorAddressCombiner(const eval::ValueType &lhs, const eval::ValueType &rhs); ~TensorAddressCombiner(); - bool combine(SparseTensorAddressRef lhsRef, SparseTensorAddressRef rhsRef); + VESPA_DLL_LOCAL bool combine(SparseTensorAddressRef lhsRef, SparseTensorAddressRef rhsRef); + size_t numOverlappingDimensions() const; + size_t numDimensions() const { return _ops.size(); } }; +} -} // namespace vespalib::tensor::sparse -} // namespace vespalib::tensor -} // namespace vespalib diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_decoder.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_decoder.h index 3a0502aee5b..2fbd9932009 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_decoder.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_decoder.h @@ -5,10 +5,7 @@ #include <vespa/vespalib/stllike/string.h> #include "sparse_tensor_address_ref.h" -namespace vespalib { - - -namespace tensor { +namespace vespalib::tensor { /** * A decoder for a serialized tensor address, with only labels present. @@ -40,5 +37,5 @@ public: }; -} // namespace vespalib::tensor -} // namespace vespalib +} + diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_padder.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_padder.h index 506f8b29593..29e10c778ba 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_padder.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_padder.h @@ -6,8 +6,7 @@ #include "sparse_tensor_address_decoder.h" #include <cassert> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** @@ -16,11 +15,7 @@ namespace tensor { */ class SparseTensorAddressPadder : public SparseTensorAddressBuilder { - enum class PadOp - { - PAD, - COPY - }; + enum class PadOp { PAD, COPY }; std::vector<PadOp> _padOps; @@ -67,6 +62,5 @@ public: } }; +} -} // namespace vespalib::tensor -} // namespace vespalib diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.cpp index 7da5bd8d61a..fbd0034bc14 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.cpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.cpp @@ -4,20 +4,16 @@ #include <vespa/eval/eval/value_type.h> #include <vespa/vespalib/stllike/hash_set.hpp> -namespace vespalib { -namespace tensor { -namespace sparse { +namespace vespalib::tensor::sparse { TensorAddressReducer::TensorAddressReducer(const eval::ValueType &type, - const std::vector<vespalib::string> & - removeDimensions) + const std::vector<vespalib::string> & removeDimensions) : SparseTensorAddressBuilder(), _ops() { - TensorDimensionsSet removeSet(removeDimensions.cbegin(), - removeDimensions.cend()); + TensorDimensionsSet removeSet(removeDimensions.cbegin(), removeDimensions.cend()); _ops.reserve(type.dimensions().size()); - for (auto &dim : type.dimensions()) { + for (const auto &dim : type.dimensions()) { if (removeSet.find(dim.name) != removeSet.end()) { _ops.push_back(AddressOp::REMOVE); } else { @@ -26,10 +22,7 @@ TensorAddressReducer::TensorAddressReducer(const eval::ValueType &type, } } -TensorAddressReducer::~TensorAddressReducer() -{ +TensorAddressReducer::~TensorAddressReducer() = default; + } -} // namespace vespalib::tensor::sparse -} // namespace vespalib::tensor -} // namespace vespalib diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.h index c40d34d9a53..a2034d3be49 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_reducer.h @@ -7,21 +7,15 @@ #include "sparse_tensor_address_decoder.h" #include <cassert> -namespace vespalib { -namespace eval { class ValueType; } -namespace tensor { -namespace sparse { +namespace vespalib::eval { class ValueType; } +namespace vespalib::tensor::sparse { /** * Reduce sparse tensor address by removing one or more dimensions. */ class TensorAddressReducer : public SparseTensorAddressBuilder { - enum AddressOp - { - REMOVE, - COPY - }; + enum AddressOp { REMOVE, COPY }; using AddressOps = std::vector<AddressOp>; @@ -50,7 +44,5 @@ public: } }; +} -} // namespace vespalib::tensor::sparse -} // namespace vespalib::tensor -} // namespace vespalib diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_ref.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_ref.h index 788bf1b8ddc..321690085be 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_ref.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_ref.h @@ -2,9 +2,8 @@ #pragma once -#include <vespa/vespalib/stllike/string.h> -#include <vector> #include <vespa/vespalib/util/stash.h> +#include <cstring> namespace vespalib { @@ -19,15 +18,15 @@ namespace tensor { class SparseTensorAddressRef { const void *_start; - size_t _size; - size_t _hash; + uint32_t _size; + uint32_t _hash; public: SparseTensorAddressRef() : _start(nullptr), _size(0u), _hash(0u) { } - SparseTensorAddressRef(const void *start_in, size_t size_in) + SparseTensorAddressRef(const void *start_in, uint32_t size_in) : _start(start_in), _size(size_in), _hash(calcHash()) { @@ -43,9 +42,9 @@ public: _start = res; } - size_t hash() const { return _hash; } + uint32_t hash() const { return _hash; } - size_t calcHash() const { return hashValue(_start, _size); } + uint32_t calcHash() const { return hashValue(_start, _size); } bool operator<(const SparseTensorAddressRef &rhs) const { size_t minSize = std::min(_size, rhs._size); @@ -65,7 +64,7 @@ public: } const void *start() const { return _start; } - size_t size() const { return _size; } + uint32_t size() const { return _size; } }; } // namespace vespalib::tensor diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.h index 65d05bd4ba2..ec6edf2d847 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.h @@ -2,11 +2,12 @@ #pragma once -namespace vespalib { -namespace tensor { -class Tensor; -class SparseTensor; -namespace sparse { +namespace vespalib::tensor { + class Tensor; + class SparseTensor; +} + +namespace vespalib::tensor::sparse { /** * Create new tensor using all combinations of input tensor cells with matching @@ -17,7 +18,5 @@ template <typename Function> std::unique_ptr<Tensor> apply(const SparseTensor &lhs, const SparseTensor &rhs, Function &&func); +} -} // namespace vespalib::tensor::sparse -} // namespace vespalib::tensor -} // namespace vespalib diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.hpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.hpp index 4528c8ef1df..2027e0afc9d 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.hpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.hpp @@ -7,9 +7,7 @@ #include <vespa/eval/tensor/direct_tensor_builder.h> #include "direct_sparse_tensor_builder.h" -namespace vespalib { -namespace tensor { -namespace sparse { +namespace vespalib::tensor::sparse { template <typename Function> std::unique_ptr<Tensor> @@ -17,10 +15,14 @@ apply(const SparseTensor &lhs, const SparseTensor &rhs, Function &&func) { DirectTensorBuilder<SparseTensor> builder(lhs.combineDimensionsWith(rhs)); TensorAddressCombiner addressCombiner(lhs.fast_type(), rhs.fast_type()); + size_t estimatedCells = (lhs.cells().size() * rhs.cells().size()); + if (addressCombiner.numOverlappingDimensions() != 0) { + estimatedCells = std::min(lhs.cells().size(), rhs.cells().size()); + } + builder.reserve(estimatedCells*2); for (const auto &lhsCell : lhs.cells()) { for (const auto &rhsCell : rhs.cells()) { - bool combineSuccess = addressCombiner.combine(lhsCell.first, - rhsCell.first); + bool combineSuccess = addressCombiner.combine(lhsCell.first, rhsCell.first); if (combineSuccess) { builder.insertCell(addressCombiner.getAddressRef(), func(lhsCell.second, rhsCell.second)); @@ -30,6 +32,4 @@ apply(const SparseTensor &lhs, const SparseTensor &rhs, Function &&func) return builder.build(); } -} // namespace vespalib::tensor::sparse -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.cpp index dacf0c27593..9c3b13f6260 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.cpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.cpp @@ -3,8 +3,7 @@ #include "sparse_tensor_builder.h" #include <cassert> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { SparseTensorBuilder::SparseTensorBuilder() : TensorBuilder(), @@ -19,10 +18,7 @@ SparseTensorBuilder::SparseTensorBuilder() { } -SparseTensorBuilder::~SparseTensorBuilder() -{ -} - +SparseTensorBuilder::~SparseTensorBuilder() = default; void SparseTensorBuilder::makeType() @@ -103,6 +99,5 @@ SparseTensorBuilder::build() return ret; } +} -} // namespace vespalib::tensor -} // namespace vespalib diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.h index af1566d46c5..ea5f607ff7e 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_builder.h @@ -10,8 +10,7 @@ #include <vespa/vespalib/stllike/hash_map.h> #include <vespa/vespalib/util/stash.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * A builder of sparse tensors. @@ -30,17 +29,13 @@ class SparseTensorBuilder : public TensorBuilder void makeType(); public: SparseTensorBuilder(); - virtual ~SparseTensorBuilder(); + ~SparseTensorBuilder() override; - virtual Dimension - define_dimension(const vespalib::string &dimension) override; - virtual TensorBuilder & - add_label(Dimension dimension, - const vespalib::string &label) override; - virtual TensorBuilder &add_cell(double value) override; - - virtual Tensor::UP build() override; + Dimension define_dimension(const vespalib::string &dimension) override; + TensorBuilder & add_label(Dimension dimension, const vespalib::string &label) override; + TensorBuilder &add_cell(double value) override; + Tensor::UP build() override; }; -} // namespace vespalib::tensor -} // namespace vespalib +} + diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.cpp index b4c9d511d09..cd5715e7379 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.cpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.cpp @@ -1,9 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "sparse_tensor_match.h" +#include <vespa/vespalib/stllike/hash_map.hpp> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { namespace { @@ -73,9 +73,9 @@ transformAddress(SparseTensorAddressBuilder &builder, void -SparseTensorMatch::fastMatch(const TensorImplType &lhs, - const TensorImplType &rhs) +SparseTensorMatch::fastMatch(const TensorImplType &lhs, const TensorImplType &rhs) { + _builder.reserve(lhs.cells().size()); for (const auto &lhsCell : lhs.cells()) { auto rhsItr = rhs.cells().find(lhsCell.first); if (rhsItr != rhs.cells().end()) { @@ -85,13 +85,11 @@ SparseTensorMatch::fastMatch(const TensorImplType &lhs, } void -SparseTensorMatch::slowMatch(const TensorImplType &lhs, - const TensorImplType &rhs) +SparseTensorMatch::slowMatch(const TensorImplType &lhs, const TensorImplType &rhs) { std::vector<AddressOp> ops; SparseTensorAddressBuilder addressBuilder; - SparseTensorAddressPadder addressPadder(_builder.fast_type(), - lhs.fast_type()); + SparseTensorAddressPadder addressPadder(_builder.fast_type(), lhs.fast_type()); buildTransformOps(ops, lhs.fast_type(), rhs.fast_type()); for (const auto &lhsCell : lhs.cells()) { if (!transformAddress(addressBuilder, lhsCell.first, ops)) { @@ -106,8 +104,7 @@ SparseTensorMatch::slowMatch(const TensorImplType &lhs, } } -SparseTensorMatch::SparseTensorMatch(const TensorImplType &lhs, - const TensorImplType &rhs) +SparseTensorMatch::SparseTensorMatch(const TensorImplType &lhs, const TensorImplType &rhs) : Parent(lhs.combineDimensionsWith(rhs)) { if ((lhs.fast_type().dimensions().size() == rhs.fast_type().dimensions().size()) && @@ -123,6 +120,4 @@ SparseTensorMatch::SparseTensorMatch(const TensorImplType &lhs, } } - -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.h index d88386ec508..bb2c82a6d00 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_match.h @@ -4,8 +4,7 @@ #include <vespa/eval/tensor/tensor_operation.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * Returns the match product of two tensors. @@ -27,5 +26,4 @@ public: SparseTensorMatch(const TensorImplType &lhs, const TensorImplType &rhs); }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_reduce.hpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_reduce.hpp index 53ab8116255..8a43c6b52bd 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_reduce.hpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_reduce.hpp @@ -6,9 +6,7 @@ #include <vespa/eval/tensor/direct_tensor_builder.h> #include "direct_sparse_tensor_builder.h" -namespace vespalib { -namespace tensor { -namespace sparse { +namespace vespalib::tensor::sparse { template <typename Function> std::unique_ptr<Tensor> @@ -50,6 +48,7 @@ reduce(const SparseTensor &tensor, return reduceAll(tensor, builder, func); } TensorAddressReducer addressReducer(tensor.fast_type(), dimensions); + builder.reserve(tensor.cells().size()*2); for (const auto &cell : tensor.cells()) { addressReducer.reduce(cell.first); builder.insertCell(addressReducer.getAddressRef(), cell.second, func); @@ -57,6 +56,4 @@ reduce(const SparseTensor &tensor, return builder.build(); } -} // namespace vespalib::tensor::sparse -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.cpp index 1e112cbaa6e..866956dd23e 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.cpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.cpp @@ -14,12 +14,10 @@ SparseTensorUnsortedAddressBuilder::SparseTensorUnsortedAddressBuilder() { } -SparseTensorUnsortedAddressBuilder::~SparseTensorUnsortedAddressBuilder() { -} +SparseTensorUnsortedAddressBuilder::~SparseTensorUnsortedAddressBuilder() = default; void -SparseTensorUnsortedAddressBuilder::buildTo(SparseTensorAddressBuilder & - builder, +SparseTensorUnsortedAddressBuilder::buildTo(SparseTensorAddressBuilder & builder, const eval::ValueType &type) { const char *base = &_elementStrings[0]; @@ -47,3 +45,4 @@ SparseTensorUnsortedAddressBuilder::buildTo(SparseTensorAddressBuilder & } } + diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.h index 24519e924d9..681bdabc5eb 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_unsorted_address_builder.h @@ -6,9 +6,8 @@ #include <vector> #include <vespa/eval/tensor/types.h> -namespace vespalib { -namespace eval { class ValueType; } -namespace tensor { +namespace vespalib::eval { class ValueType; } +namespace vespalib::tensor { class SparseTensorAddressBuilder; @@ -73,11 +72,9 @@ public: * Sort the stored tensor address and pass it over to a strict * tensor address builder in sorted order. */ - void buildTo(SparseTensorAddressBuilder &builder, - const eval::ValueType &type); + void buildTo(SparseTensorAddressBuilder &builder, const eval::ValueType &type); void clear() { _elementStrings.clear(); _elements.clear(); } }; +} -} // namespace vespalib::tensor -} // namespace vespalib diff --git a/eval/src/vespa/eval/tensor/tensor_address.h b/eval/src/vespa/eval/tensor/tensor_address.h index 74b2aff5561..c8c60ef6fa6 100644 --- a/eval/src/vespa/eval/tensor/tensor_address.h +++ b/eval/src/vespa/eval/tensor/tensor_address.h @@ -8,8 +8,7 @@ #include <map> #include <vector> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * A sparse immutable address to a tensor cell. @@ -87,5 +86,4 @@ public: std::ostream &operator<<(std::ostream &out, const TensorAddress::Elements &elements); std::ostream &operator<<(std::ostream &out, const TensorAddress &value); -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/tensor_address_builder.h b/eval/src/vespa/eval/tensor/tensor_address_builder.h index 40b784e051a..47ea79fd985 100644 --- a/eval/src/vespa/eval/tensor/tensor_address_builder.h +++ b/eval/src/vespa/eval/tensor/tensor_address_builder.h @@ -4,8 +4,7 @@ #include "tensor_address.h" -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** @@ -27,5 +26,4 @@ public: }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/tensor_apply.cpp b/eval/src/vespa/eval/tensor/tensor_apply.cpp index 7c518d0516f..8f0610fed65 100644 --- a/eval/src/vespa/eval/tensor/tensor_apply.cpp +++ b/eval/src/vespa/eval/tensor/tensor_apply.cpp @@ -1,9 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "tensor_apply.h" +#include <vespa/vespalib/stllike/hash_map.hpp> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { template <class TensorT> TensorApply<TensorT>::TensorApply(const TensorImplType &tensor, @@ -17,5 +17,4 @@ TensorApply<TensorT>::TensorApply(const TensorImplType &tensor, template class TensorApply<SparseTensor>; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/tensor_apply.h b/eval/src/vespa/eval/tensor/tensor_apply.h index bd675e7ec58..bb5ffdd1885 100644 --- a/eval/src/vespa/eval/tensor/tensor_apply.h +++ b/eval/src/vespa/eval/tensor/tensor_apply.h @@ -5,8 +5,7 @@ #include "cell_function.h" #include "tensor_operation.h" -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * Returns a tensor with the given function applied to all cells in the input tensor. @@ -23,5 +22,4 @@ public: extern template class TensorApply<SparseTensor>; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/tensor_mapper.cpp b/eval/src/vespa/eval/tensor/tensor_mapper.cpp index 25b369c246d..f1039b08816 100644 --- a/eval/src/vespa/eval/tensor/tensor_mapper.cpp +++ b/eval/src/vespa/eval/tensor/tensor_mapper.cpp @@ -8,6 +8,7 @@ #include "wrapped_simple_tensor.h" #include <vespa/eval/tensor/sparse/direct_sparse_tensor_builder.h> #include <vespa/eval/tensor/dense/dense_tensor.h> +#include <vespa/vespalib/stllike/hash_map.hpp> #include <limits> using vespalib::eval::ValueType; diff --git a/eval/src/vespa/eval/tensor/tensor_operation.h b/eval/src/vespa/eval/tensor/tensor_operation.h index 6975c21c448..827c16573d5 100644 --- a/eval/src/vespa/eval/tensor/tensor_operation.h +++ b/eval/src/vespa/eval/tensor/tensor_operation.h @@ -5,8 +5,7 @@ #include "direct_tensor_builder.h" #include <vespa/eval/tensor/sparse/direct_sparse_tensor_builder.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * Base class for an operation over tensors. @@ -46,5 +45,4 @@ public: } }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/tensor_visitor.h b/eval/src/vespa/eval/tensor/tensor_visitor.h index 4002aab6e7e..4cd9792afbd 100644 --- a/eval/src/vespa/eval/tensor/tensor_visitor.h +++ b/eval/src/vespa/eval/tensor/tensor_visitor.h @@ -6,8 +6,7 @@ #include <vespa/vespalib/stllike/string.h> #include "types.h" -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * Class for visiting a tensor. First visit must specify dimensions, @@ -20,5 +19,4 @@ public: virtual void visit(const TensorAddress &address, double value) = 0; }; -} // namespace vespalib::tensor -} // namespace vespalib +}
\ No newline at end of file diff --git a/eval/src/vespa/eval/tensor/types.h b/eval/src/vespa/eval/tensor/types.h index aa5d8c89707..d969bc0a2fb 100644 --- a/eval/src/vespa/eval/tensor/types.h +++ b/eval/src/vespa/eval/tensor/types.h @@ -7,13 +7,11 @@ #include <vector> #include <map> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { using TensorCells = std::map<std::map<vespalib::string, vespalib::string>, double>; using TensorDimensions = std::vector<vespalib::string>; using TensorDimensionsSet = vespalib::hash_set<vespalib::string>; using DenseTensorCells = std::map<std::map<vespalib::string, size_t>, double>; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDistributionRpcServer.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDistributionRpcServer.java index 2e58455bc39..b2d1af15867 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDistributionRpcServer.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDistributionRpcServer.java @@ -80,13 +80,13 @@ public class FileDistributionRpcServer { try { if (pathToFile.isPresent()) { req.returnValues().add(new StringValue(pathToFile.get().getAbsolutePath())); - log.log(LogLevel.INFO, "File reference '" + fileReference.value() + "' available at " + pathToFile.get()); + log.log(LogLevel.DEBUG, "File reference '" + fileReference.value() + "' available at " + pathToFile.get()); } else { log.log(LogLevel.INFO, "File reference '" + fileReference.value() + "' not found, returning error"); req.setError(fileReferenceDoesNotExists, "File reference '" + fileReference.value() + "' not found"); } } catch (Throwable e) { - log.log(LogLevel.WARNING, "File reference '" + fileReference.value() + "' got exeption: " + e.getMessage()); + log.log(LogLevel.WARNING, "File reference '" + fileReference.value() + "' got exception: " + e.getMessage()); req.setError(fileReferenceInternalError, "File reference '" + fileReference.value() + "' removed"); } req.returnRequest(); @@ -123,5 +123,4 @@ public class FileDistributionRpcServer { req.returnValues().add(new Int32Value(0)); } - } diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java index 727786cdc78..5de006cd17c 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java @@ -107,7 +107,7 @@ public class FileDownloader { } else if (!file.canRead()) { throw new RuntimeException("File with reference '" + fileReference.value() + "'exists, but unable to read it"); } else { - fileReferenceDownloader.setDownloadStatus(fileReference.value(), 100.0); + fileReferenceDownloader.setDownloadStatus(fileReference, 100.0); return Optional.of(file); } } diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReceiver.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReceiver.java index d57ce4ca5de..d9d1b4984eb 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReceiver.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReceiver.java @@ -85,7 +85,7 @@ public class FileReceiver { try { inprogressFile = Files.createTempFile(tmpDirectory.toPath(), fileName, ".inprogress").toFile(); } catch (IOException e) { - String msg = "Failed creating tempfile for inprogress file for(" + fileName + ") in '" + fileReferenceDir.toPath() + "': "; + String msg = "Failed creating temp file for inprogress file for(" + fileName + ") in '" + fileReferenceDir.toPath() + "': "; log.log(LogLevel.ERROR, msg + e.getMessage(), e); throw new RuntimeException(msg, e); } @@ -103,6 +103,7 @@ public class FileReceiver { Files.write(inprogressFile.toPath(), part, StandardOpenOption.WRITE, StandardOpenOption.APPEND); } catch (IOException e) { log.log(LogLevel.ERROR, "Failed writing to file(" + inprogressFile.toPath() + "): " + e.getMessage(), e); + inprogressFile.delete(); throw new RuntimeException("Failed writing to file(" + inprogressFile.toPath() + "): ", e); } currentFileSize += part.length; @@ -247,8 +248,11 @@ public class FileReceiver { log.log(LogLevel.DEBUG, "File moved from " + tempFile.getAbsolutePath()+ " to " + destination.getAbsolutePath()); } catch (FileAlreadyExistsException e) { // Don't fail if it already exists (we might get the file from several config servers when retrying, servers are down etc. - // so it might be written already) + // so it might be written already). Delete temp file in that case, to avoid filling the disk. log.log(LogLevel.DEBUG, "File '" + destination.getAbsolutePath() + "' already exists, continuing: " + e.getMessage()); + try { + Files.delete(tempFile.toPath()); + } catch (IOException ioe) { /* ignore failure */} } catch (IOException e) { String message = "Failed moving file '" + tempFile.getAbsolutePath() + "' to '" + destination.getAbsolutePath() + "'"; log.log(LogLevel.ERROR, message, e); @@ -295,7 +299,7 @@ public class FileReceiver { try { session.addPart(partId, part); } catch (Exception e) { - log.severe("Got exception + " + e); + log.severe("Got exception " + e); retval = 1; } double completeness = (double) session.currentFileSize / (double) session.fileSize; diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java index 509231ba7ff..031506487a8 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java @@ -65,7 +65,7 @@ public class FileReferenceDownloader { Thread.sleep(10); } } - catch (InterruptedException e) {} + catch (InterruptedException e) { /* ignored */} } if ( !downloadStarted) { @@ -107,7 +107,7 @@ public class FileReferenceDownloader { if (validateResponse(request)) { log.log(LogLevel.DEBUG, "Request callback, OK. Req: " + request + "\nSpec: " + connection); if (request.returnValues().get(0).asInt32() == 0) { - log.log(LogLevel.INFO, "Found file reference '" + fileReference.value() + "' available at " + connection.getAddress()); + log.log(LogLevel.DEBUG, "Found file reference '" + fileReference.value() + "' available at " + connection.getAddress()); return true; } else { log.log(LogLevel.INFO, "File reference '" + fileReference.value() + "' not found for " + connection.getAddress()); @@ -169,10 +169,6 @@ public class FileReferenceDownloader { return status; } - void setDownloadStatus(String file, double completeness) { - setDownloadStatus(new FileReference(file), completeness); - } - void setDownloadStatus(FileReference fileReference, double completeness) { synchronized (downloads) { downloadStatus.put(fileReference, completeness); diff --git a/jaxrs_client_utils/src/main/java/com/yahoo/vespa/jaxrs/client/JerseyJaxRsClientFactory.java b/jaxrs_client_utils/src/main/java/com/yahoo/vespa/jaxrs/client/JerseyJaxRsClientFactory.java index 0b3c7ad8d3b..5a3ccfed490 100644 --- a/jaxrs_client_utils/src/main/java/com/yahoo/vespa/jaxrs/client/JerseyJaxRsClientFactory.java +++ b/jaxrs_client_utils/src/main/java/com/yahoo/vespa/jaxrs/client/JerseyJaxRsClientFactory.java @@ -6,6 +6,7 @@ import org.glassfish.jersey.client.ClientProperties; import org.glassfish.jersey.client.HttpUrlConnectorProvider; import org.glassfish.jersey.client.proxy.WebResourceFactory; +import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLContext; import javax.ws.rs.client.ClientBuilder; import javax.ws.rs.client.ClientRequestFilter; @@ -25,23 +26,26 @@ public class JerseyJaxRsClientFactory implements JaxRsClientFactory { private final int readTimeoutMs; private final SSLContext sslContext; private final String userAgent; + private final HostnameVerifier hostnameVerifier; public JerseyJaxRsClientFactory() { this(DEFAULT_CONNECT_TIMEOUT_MS, DEFAULT_READ_TIMEOUT_MS); } - public JerseyJaxRsClientFactory(SSLContext sslContext, String userAgent) { - this(DEFAULT_CONNECT_TIMEOUT_MS, DEFAULT_READ_TIMEOUT_MS, sslContext, userAgent); + public JerseyJaxRsClientFactory(SSLContext sslContext, HostnameVerifier hostnameVerifier, String userAgent) { + this(DEFAULT_CONNECT_TIMEOUT_MS, DEFAULT_READ_TIMEOUT_MS, sslContext, hostnameVerifier, userAgent); } public JerseyJaxRsClientFactory(final int connectTimeoutMs, final int readTimeoutMs) { - this(connectTimeoutMs, readTimeoutMs, null, null); + this(connectTimeoutMs, readTimeoutMs, null, null, null); } - public JerseyJaxRsClientFactory(int connectTimeoutMs, int readTimeoutMs, SSLContext sslContext, String userAgent) { + public JerseyJaxRsClientFactory(int connectTimeoutMs, int readTimeoutMs, SSLContext sslContext, + HostnameVerifier hostnameVerifier, String userAgent) { this.connectTimeoutMs = connectTimeoutMs; this.readTimeoutMs = readTimeoutMs; this.sslContext = sslContext; + this.hostnameVerifier = hostnameVerifier; this.userAgent = userAgent; } @@ -61,7 +65,9 @@ public class JerseyJaxRsClientFactory implements JaxRsClientFactory { .property(ClientProperties.FOLLOW_REDIRECTS, true); if (sslContext != null) { builder.sslContext(sslContext); - builder.hostnameVerifier((s, sslSession) -> true); + } + if (hostnameVerifier != null) { + builder.hostnameVerifier(hostnameVerifier); } if (userAgent != null) { builder.register((ClientRequestFilter) context -> diff --git a/jdisc_core/src/main/java/com/yahoo/jdisc/core/StandaloneMain.java b/jdisc_core/src/main/java/com/yahoo/jdisc/core/StandaloneMain.java index a78e4f1af40..1291418083b 100644 --- a/jdisc_core/src/main/java/com/yahoo/jdisc/core/StandaloneMain.java +++ b/jdisc_core/src/main/java/com/yahoo/jdisc/core/StandaloneMain.java @@ -36,12 +36,15 @@ public class StandaloneMain { void run(String bundleLocation) { try { + // We're not logging at this point since the application is responsible + // for setting up logging. System.out.println("debug\tInitializing application without privileges."); loader.init(bundleLocation, false); loader.start(); setupSigTermHandler(); waitForShutdown(); System.out.println("debug\tTrying to shutdown in a controlled manner."); + log.log(Level.INFO, "JDisc shutting down"); loader.stop(); System.out.println("debug\tTrying to clean up in a controlled manner."); loader.destroy(); @@ -50,7 +53,7 @@ public class StandaloneMain { } catch (Throwable e) { System.out.print("debug\tUnexpected: "); e.printStackTrace(); - log.log(Level.SEVERE, "Unexpected: ", e); + log.log(Level.SEVERE, "JDisc exiting: Throwable caught: ", e); System.exit(6); } } diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java index 5cabe8acd27..31268c823ba 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java @@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -34,6 +35,7 @@ import static com.yahoo.jdisc.http.server.jetty.Exceptions.throwUnchecked; /** * @author Simon Thoresen Hult + * @author bjorncs */ class HttpRequestDispatch { @@ -123,10 +125,10 @@ class HttpRequestDispatch { boolean reportedError = false; if (error != null) { - if (error instanceof EofException) { + if (error instanceof CompletionException && error.getCause() instanceof EofException) { log.log(Level.FINE, - "Network connection was unexpectedly terminated: " + parent.servletRequest.getRequestURI(), - error); + error, + () -> "Network connection was unexpectedly terminated: " + parent.servletRequest.getRequestURI()); } else if (!(error instanceof OverloadException || error instanceof BindingNotFoundException)) { log.log(Level.WARNING, "Request failed: " + parent.servletRequest.getRequestURI(), error); } diff --git a/metrics/src/vespa/metrics/metrictimer.h b/metrics/src/vespa/metrics/metrictimer.h index 096ba3e27af..0282c0f17ad 100644 --- a/metrics/src/vespa/metrics/metrictimer.h +++ b/metrics/src/vespa/metrics/metrictimer.h @@ -8,7 +8,7 @@ #pragma once -#include <vespa/metrics/valuemetric.h> +#include "valuemetric.h" #include <chrono> namespace metrics { diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java index 2d22f4c4ccf..868ebf39f70 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java @@ -179,9 +179,10 @@ public class NodeAdminImpl implements NodeAdmin { } }, 0, 55, TimeUnit.SECONDS); + int delay = 120; // WARNING: Reducing this will increase the load on config servers. aclScheduler.scheduleWithFixedDelay(() -> { if (!isFrozen()) aclMaintainer.run(); - }, 30, 60, TimeUnit.SECONDS); + }, 30, delay, TimeUnit.SECONDS); } @Override diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/NodeAdminProvider.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/NodeAdminProvider.java index a343e431b5a..3777d7e20d1 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/NodeAdminProvider.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/provider/NodeAdminProvider.java @@ -36,7 +36,8 @@ import java.util.function.Function; */ public class NodeAdminProvider implements Provider<NodeAdminStateUpdater> { - private static final Duration NODE_AGENT_SCAN_INTERVAL = Duration.ofSeconds(30); + // WARNING: reducing the node agent interval will increase the load on the config servers + private static final Duration NODE_AGENT_SCAN_INTERVAL = Duration.ofSeconds(60); private static final Duration NODE_ADMIN_CONVERGE_STATE_INTERVAL = Duration.ofSeconds(30); private final NodeAdminStateUpdater nodeAdminStateUpdater; diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutor.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutor.java index 3576f37eb9a..94ad94d9a65 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutor.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutor.java @@ -43,7 +43,6 @@ import java.util.Optional; */ public class ConfigServerHttpRequestExecutor { private static final PrefixLogger NODE_ADMIN_LOGGER = PrefixLogger.getNodeAdminLogger(ConfigServerHttpRequestExecutor.class); - private static final int MAX_LOOPS = 2; private final ObjectMapper mapper = new ObjectMapper(); private final CloseableHttpClient client; @@ -108,43 +107,41 @@ public class ConfigServerHttpRequestExecutor { private <T> T tryAllConfigServers(CreateRequest requestFactory, Class<T> wantedReturnType) { Exception lastException = null; - for (int loopRetry = 0; loopRetry < MAX_LOOPS; loopRetry++) { - for (URI configServer : configServerHosts) { - final CloseableHttpResponse response; - try { - response = client.execute(requestFactory.createRequest(configServer)); - } catch (Exception e) { - // Failure to communicate with a config server is not abnormal, as they are - // upgraded at the same time as Docker hosts. - if (e.getMessage().indexOf("(Connection refused)") > 0) { - NODE_ADMIN_LOGGER.info("Connection refused to " + configServer + " (upgrading?), will try next"); - } else { - NODE_ADMIN_LOGGER.warning("Failed to communicate with " + configServer + ", will try next: " + e.getMessage()); - } - lastException = e; + for (URI configServer : configServerHosts) { + final CloseableHttpResponse response; + try { + response = client.execute(requestFactory.createRequest(configServer)); + } catch (Exception e) { + // Failure to communicate with a config server is not abnormal, as they are + // upgraded at the same time as Docker hosts. + if (e.getMessage().indexOf("(Connection refused)") > 0) { + NODE_ADMIN_LOGGER.info("Connection refused to " + configServer + " (upgrading?), will try next"); + } else { + NODE_ADMIN_LOGGER.warning("Failed to communicate with " + configServer + ", will try next: " + e.getMessage()); + } + lastException = e; + continue; + } + + try { + Optional<HttpException> retryableException = HttpException.handleStatusCode( + response.getStatusLine().getStatusCode(), + "Config server " + configServer); + if (retryableException.isPresent()) { + lastException = retryableException.get(); continue; } try { - Optional<HttpException> retryableException = HttpException.handleStatusCode( - response.getStatusLine().getStatusCode(), - "Config server " + configServer); - if (retryableException.isPresent()) { - lastException = retryableException.get(); - continue; - } - - try { - return mapper.readValue(response.getEntity().getContent(), wantedReturnType); - } catch (IOException e) { - throw new RuntimeException("Response didn't contain nodes element, failed parsing?", e); - } - } finally { - try { - response.close(); - } catch (IOException e) { - NODE_ADMIN_LOGGER.warning("Ignoring exception from closing response", e); - } + return mapper.readValue(response.getEntity().getContent(), wantedReturnType); + } catch (IOException e) { + throw new RuntimeException("Response didn't contain nodes element, failed parsing?", e); + } + } finally { + try { + response.close(); + } catch (IOException e) { + NODE_ADMIN_LOGGER.warning("Ignoring exception from closing response", e); } } } diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutorTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutorTest.java index 67cd2c79034..799f8a72fd9 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutorTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/util/ConfigServerHttpRequestExecutorTest.java @@ -117,8 +117,7 @@ public class ConfigServerHttpRequestExecutorTest { } String[] log = mockLog.toString().split(" "); - assertThat(log, arrayContainingInAnyOrder("GET http://host1:666/path", "GET http://host2:666/path", - "GET http://host1:666/path", "GET http://host2:666/path")); + assertThat(log, arrayContainingInAnyOrder("GET http://host1:666/path", "GET http://host2:666/path")); } @Test @@ -134,7 +133,6 @@ public class ConfigServerHttpRequestExecutorTest { String[] log = mockLog.toString().split(" "); assertThat(log, arrayContainingInAnyOrder( - "GET http://host1:666/path", "GET http://host2:666/path", "GET http://host1:666/path", "GET http://host2:666/path")); } diff --git a/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoreCollector.java b/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoreCollector.java index 4681010940c..de08bdbe107 100644 --- a/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoreCollector.java +++ b/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoreCollector.java @@ -2,7 +2,6 @@ package com.yahoo.vespa.hosted.node.maintainer; import com.yahoo.collections.Pair; -import static com.yahoo.vespa.defaults.Defaults.getDefaults; import com.yahoo.system.ProcessExecuter; import java.io.IOException; @@ -19,6 +18,8 @@ import java.util.logging.Logger; import java.util.regex.Matcher; import java.util.regex.Pattern; +import static com.yahoo.vespa.defaults.Defaults.getDefaults; + /** * Takes in a compressed (lz4) or uncompressed core dump and collects relevant metadata. * @@ -166,7 +167,7 @@ public class CoreCollector { Path decompressedPath = Paths.get(coredumpPath.toString().replaceFirst("\\.lz4$", "")); Pair<Integer, String> result = processExecuter.exec( - new String[]{LZ4_PATH, "-d", coredumpPath.toString(), decompressedPath.toString()}); + new String[]{LZ4_PATH, "-f", "-d", coredumpPath.toString(), decompressedPath.toString()}); if (result.getFirst() != 0) { throw new RuntimeException("Failed to decompress file " + coredumpPath + ": " + result); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java index 7ef609d6311..62b00f914a3 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java @@ -50,6 +50,10 @@ public class NodeRepositoryProvisioner implements Provisioner { private final Activator activator; private final BiConsumer<List<Node>, String> debugRecorder; + int getSpareCapacityProd() { + return SPARE_CAPACITY_PROD; + } + @Inject public NodeRepositoryProvisioner(NodeRepository nodeRepository, NodeFlavors flavors, Zone zone) { this(nodeRepository, flavors, zone, Clock.systemUTC(), (x, y) -> {}); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeSpec.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeSpec.java index 4c6f1b022a4..23a6e3a8b9a 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeSpec.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeSpec.java @@ -11,14 +11,14 @@ import java.util.Objects; * A specification of a set of nodes. * This reflects that nodes can be requested either by count and flavor or by type, * and encapsulates the differences in logic between these two cases. - * + * * @author bratseth */ public interface NodeSpec { /** The node type this requests */ NodeType type(); - + /** Returns whether the given flavor is compatible with this spec */ boolean isCompatible(Flavor flavor); @@ -33,15 +33,15 @@ public interface NodeSpec { /** Returns whether the given node count is sufficient to fulfill this spec */ boolean fulfilledBy(int count); - + /** Returns the amount the given count is above the minimum amount needed to fulfill this request */ int surplusGiven(int count); - + /** Returns a specification of a fraction of all the nodes of this. It is assumed the argument is a valid divisor. */ NodeSpec fraction(int divisor); - /** - * Assigns the flavor requested by this to the given node and returns it, + /** + * Assigns the flavor requested by this to the given node and returns it, * if one is requested and it is allowed to change. * Otherwise, the node is returned unchanged. */ @@ -50,17 +50,17 @@ public interface NodeSpec { static NodeSpec from(int nodeCount, Flavor flavor) { return new CountNodeSpec(nodeCount, flavor); } - + static NodeSpec from(NodeType type) { return new TypeNodeSpec(type); } - + /** A node spec specifying a node count and a flavor */ class CountNodeSpec implements NodeSpec { - + private final int count; private final Flavor requestedFlavor; - + public CountNodeSpec(int count, Flavor flavor) { Objects.requireNonNull(flavor, "A flavor must be specified"); this.count = count; @@ -79,7 +79,7 @@ public interface NodeSpec { public NodeType type() { return NodeType.tenant; } @Override - public boolean isCompatible(Flavor flavor) { + public boolean isCompatible(Flavor flavor) { if (flavor.satisfies(requestedFlavor)) return true; return requestedFlavorCanBeAchievedByResizing(flavor); } @@ -91,7 +91,7 @@ public interface NodeSpec { public boolean specifiesNonStockFlavor() { return ! requestedFlavor.isStock(); } @Override - public boolean fulfilledBy(int count) { return count >= this.count; } + public boolean fulfilledBy(int count) { return count >= this.count; } @Override public boolean saturatedBy(int count) { return fulfilledBy(count); } // min=max for count specs @@ -101,12 +101,13 @@ public interface NodeSpec { @Override public NodeSpec fraction(int divisor) { return new CountNodeSpec(count/divisor, requestedFlavor); } - + @Override public Node assignRequestedFlavor(Node node) { // Docker nodes can change flavor in place if (requestedFlavorCanBeAchievedByResizing(node.flavor())) return node.with(requestedFlavor); + return node; } @@ -115,16 +116,19 @@ public interface NodeSpec { /** Docker nodes can be downsized in place */ private boolean requestedFlavorCanBeAchievedByResizing(Flavor flavor) { - return flavor.isDocker() && requestedFlavor.isDocker() && flavor.isLargerThan(requestedFlavor); + // TODO: Enable this when we can do it safely + // Then also re-enable ProvisioningTest.application_deployment_with_inplace_downsize() + // return flavor.isDocker() && requestedFlavor.isDocker() && flavor.isLargerThan(requestedFlavor); + return false; } - + } /** A node spec specifying a node type. This will accept all nodes of this type. */ class TypeNodeSpec implements NodeSpec { - + private final NodeType type; - + public TypeNodeSpec(NodeType type) { this.type = type; } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisioningTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisioningTest.java index 75d5862f010..1dce5830540 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisioningTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisioningTest.java @@ -97,11 +97,11 @@ public class DynamicDockerProvisioningTest { * Test relocation of nodes from spare hosts. * <p> * Setup 4 docker hosts and allocate one container on each (from two different applications) - * No headroom defined - only 2 spares. + * No headroom defined - only getSpareCapacityProd() spares. * <p> - * Check that it relocates containers away from the 2 spares + * Check that it relocates containers away from the getSpareCapacityProd() spares * <p> - * Initial allocation of app 1 and 2 --> final allocation: + * Initial allocation of app 1 and 2 --> final allocation (example using 2 spares): * <p> * | | | | | | | | | | * | | | | | --> | 2a | 2b | | | @@ -139,7 +139,8 @@ public class DynamicDockerProvisioningTest { hostsWithChildren.add(node.parentHostname().get()); } } - Assert.assertEquals(2, hostsWithChildren.size()); + Assert.assertEquals(4 - tester.provisioner().getSpareCapacityProd(), hostsWithChildren.size()); + } /** @@ -389,8 +390,14 @@ public class DynamicDockerProvisioningTest { // Verify that there is still capacity (available spare) // Fail one node and redeploy, Verify that one less node is empty. - // Setup test + ProvisioningTester tester = new ProvisioningTester(new Zone(Environment.prod, RegionName.from("us-east")), flavorsConfig()); + // Only run test if there _is_ spare capacity + if (tester.provisioner().getSpareCapacityProd() == 0) { + return; + } + + // Setup test enableDynamicAllocation(tester); ApplicationId application1 = tester.makeApplicationId(); tester.makeReadyNodes(5, "host-small", NodeType.host, 32); diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java index afdce0d25cc..a7ea77618bb 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java @@ -181,7 +181,7 @@ public class ProvisioningTest { SystemState state5 = prepare(application1, 2, 2, 3, 3, "default", tester); tester.activate(application1, state5.allHosts); assertEquals("Superfluous container nodes are also deactivated", - 4-2 + 5-2 + 1, tester.getNodes(application1, Node.State.inactive).size()); // + 4-2 + 5-2 + 1, tester.getNodes(application1, Node.State.inactive).size()); // assertEquals("Superfluous content nodes are retired", 5-3 + 6-3 - 1, tester.getNodes(application1, Node.State.active).retired().size()); @@ -231,6 +231,8 @@ public class ProvisioningTest { 0, tester.getNodes(application1, Node.State.active).retired().flavor("large").size()); } + // TODO: Enable when this feature is re-enabled + @Ignore @Test public void application_deployment_with_inplace_downsize() { ProvisioningTester tester = new ProvisioningTester(new Zone(Environment.prod, RegionName.from("us-east"))); @@ -761,7 +763,7 @@ public class ProvisioningTest { if (nodeCount == 0) return Collections.emptySet(); // this is a shady practice return new HashSet<>(tester.prepare(application, cluster, nodeCount, groups, flavor)); } - + private static class SystemState { private Set<HostSpec> allHosts; @@ -781,7 +783,7 @@ public class ProvisioningTest { this.content0 = content0; this.content1 = content1; } - + /** Returns a host by cluster name and index, or null if there is no host with the given values in this */ public HostSpec hostByMembership(String clusterId, int group, int index) { for (HostSpec host : allHosts) { @@ -794,7 +796,7 @@ public class ProvisioningTest { } return null; } - + private boolean groupMatches(Optional<ClusterSpec.Group> clusterGroup, int group) { if ( ! clusterGroup.isPresent()) return group==0; return clusterGroup.get().index() == group; diff --git a/persistence/src/tests/spi/CMakeLists.txt b/persistence/src/tests/spi/CMakeLists.txt index a130573e028..c51270a420c 100644 --- a/persistence/src/tests/spi/CMakeLists.txt +++ b/persistence/src/tests/spi/CMakeLists.txt @@ -2,6 +2,7 @@ vespa_add_library(persistence_testspi SOURCES clusterstatetest.cpp + fixed_bucket_spaces_test.cpp DEPENDS persistence_persistence_conformancetest persistence diff --git a/persistence/src/tests/spi/fixed_bucket_spaces_test.cpp b/persistence/src/tests/spi/fixed_bucket_spaces_test.cpp new file mode 100644 index 00000000000..7e36d80248a --- /dev/null +++ b/persistence/src/tests/spi/fixed_bucket_spaces_test.cpp @@ -0,0 +1,64 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/persistence/spi/fixed_bucket_spaces.h> +#include <cppunit/extensions/HelperMacros.h> + +namespace storage::spi { + +struct FixedBucketSpacesTest : CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(FixedBucketSpacesTest); + CPPUNIT_TEST(bucket_space_from_name_is_defined_for_default_space); + CPPUNIT_TEST(bucket_space_from_name_is_defined_for_global_space); + CPPUNIT_TEST(bucket_space_from_name_throws_exception_for_unknown_space); + CPPUNIT_TEST(name_from_bucket_space_is_defined_for_default_space); + CPPUNIT_TEST(name_from_bucket_space_is_defined_for_global_space); + CPPUNIT_TEST(name_from_bucket_space_throws_exception_for_unknown_space); + CPPUNIT_TEST_SUITE_END(); + + void bucket_space_from_name_is_defined_for_default_space(); + void bucket_space_from_name_is_defined_for_global_space(); + void bucket_space_from_name_throws_exception_for_unknown_space(); + void name_from_bucket_space_is_defined_for_default_space(); + void name_from_bucket_space_is_defined_for_global_space(); + void name_from_bucket_space_throws_exception_for_unknown_space(); +}; + +CPPUNIT_TEST_SUITE_REGISTRATION(FixedBucketSpacesTest); + +using document::BucketSpace; + +void FixedBucketSpacesTest::bucket_space_from_name_is_defined_for_default_space() { + CPPUNIT_ASSERT_EQUAL(FixedBucketSpaces::default_space(), FixedBucketSpaces::from_string("default")); +} + +void FixedBucketSpacesTest::bucket_space_from_name_is_defined_for_global_space() { + CPPUNIT_ASSERT_EQUAL(FixedBucketSpaces::global_space(), FixedBucketSpaces::from_string("global")); +} + +void FixedBucketSpacesTest::bucket_space_from_name_throws_exception_for_unknown_space() { + try { + FixedBucketSpaces::from_string("banana"); + CPPUNIT_FAIL("Expected exception on unknown bucket space name"); + } catch (spi::UnknownBucketSpaceException& e) { + } +} + +void FixedBucketSpacesTest::name_from_bucket_space_is_defined_for_default_space() { + CPPUNIT_ASSERT_EQUAL(vespalib::stringref("default"), + FixedBucketSpaces::to_string(FixedBucketSpaces::default_space())); +} + +void FixedBucketSpacesTest::name_from_bucket_space_is_defined_for_global_space() { + CPPUNIT_ASSERT_EQUAL(vespalib::stringref("global"), + FixedBucketSpaces::to_string(FixedBucketSpaces::global_space())); +} + +void FixedBucketSpacesTest::name_from_bucket_space_throws_exception_for_unknown_space() { + try { + FixedBucketSpaces::to_string(BucketSpace(4567)); + CPPUNIT_FAIL("Expected exception on unknown bucket space value"); + } catch (spi::UnknownBucketSpaceException& e) { + } +} + +}
\ No newline at end of file diff --git a/persistence/src/vespa/persistence/conformancetest/conformancetest.cpp b/persistence/src/vespa/persistence/conformancetest/conformancetest.cpp index 885d3e9aad7..7f4ea9dcc2e 100644 --- a/persistence/src/vespa/persistence/conformancetest/conformancetest.cpp +++ b/persistence/src/vespa/persistence/conformancetest/conformancetest.cpp @@ -8,6 +8,7 @@ #include <vespa/document/update/documentupdate.h> #include <vespa/document/update/assignvalueupdate.h> #include <vespa/document/test/make_bucket_space.h> +#include <vespa/metrics/loadmetric.h> #include <vespa/vdslib/state/state.h> #include <vespa/vdslib/state/node.h> #include <vespa/vdslib/state/nodestate.h> diff --git a/persistence/src/vespa/persistence/spi/CMakeLists.txt b/persistence/src/vespa/persistence/spi/CMakeLists.txt index a8b1faadcd3..a2b8fa7a79c 100644 --- a/persistence/src/vespa/persistence/spi/CMakeLists.txt +++ b/persistence/src/vespa/persistence/spi/CMakeLists.txt @@ -1,19 +1,20 @@ # Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. vespa_add_library(persistence_spi OBJECT SOURCES + abstractpersistenceprovider.cpp bucket.cpp bucketinfo.cpp - exceptions.cpp - persistenceprovider.cpp - partitionstate.cpp - abstractpersistenceprovider.cpp clusterstate.cpp context.cpp + docentry.cpp + exceptions.cpp + fixed_bucket_spaces.cpp metricpersistenceprovider.cpp + partitionstate.cpp + persistenceprovider.cpp read_consistency.cpp - result + result.cpp selection.cpp test.cpp - docentry DEPENDS ) diff --git a/persistence/src/vespa/persistence/spi/context.h b/persistence/src/vespa/persistence/spi/context.h index 75d3eac4538..ca4c79e3005 100644 --- a/persistence/src/vespa/persistence/spi/context.h +++ b/persistence/src/vespa/persistence/spi/context.h @@ -29,7 +29,6 @@ #pragma once -#include <vespa/metrics/loadmetric.h> #include <persistence/spi/types.h> #include <vespa/persistence/spi/read_consistency.h> #include <vespa/vespalib/trace/trace.h> @@ -38,8 +37,7 @@ namespace metrics { class LoadType; } -namespace storage { -namespace spi { +namespace storage::spi { using LoadType = metrics::LoadType; @@ -93,6 +91,4 @@ public: { _trace.trace(level, msg, addTime); } }; -} // spi -} // storage - +} diff --git a/persistence/src/vespa/persistence/spi/fixed_bucket_spaces.cpp b/persistence/src/vespa/persistence/spi/fixed_bucket_spaces.cpp new file mode 100644 index 00000000000..6a8ec0f18f7 --- /dev/null +++ b/persistence/src/vespa/persistence/spi/fixed_bucket_spaces.cpp @@ -0,0 +1,33 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "fixed_bucket_spaces.h" + +namespace storage::spi { + +VESPA_IMPLEMENT_EXCEPTION(UnknownBucketSpaceException, vespalib::IllegalArgumentException) + +// Some sanity checks to ensure we don't mess up any legacy mappings. +static_assert(document::BucketSpace::placeHolder() != document::BucketSpace::invalid()); +static_assert(FixedBucketSpaces::default_space() == document::BucketSpace::placeHolder()); +static_assert(FixedBucketSpaces::global_space() != FixedBucketSpaces::default_space()); + +document::BucketSpace FixedBucketSpaces::from_string(vespalib::stringref name) { + if (name == "default") { + return default_space(); + } else if (name == "global") { + return global_space(); + } else { + throw UnknownBucketSpaceException("Unknown bucket space name: " + vespalib::string(name), VESPA_STRLOC); + } +} + +vespalib::stringref FixedBucketSpaces::to_string(document::BucketSpace space) { + if (space == default_space()) { + return "default"; + } else if (space == global_space()) { + return "global"; + } else { + throw UnknownBucketSpaceException("Unknown bucket space: " + space.toString(), VESPA_STRLOC); + } +} + +} diff --git a/persistence/src/vespa/persistence/spi/fixed_bucket_spaces.h b/persistence/src/vespa/persistence/spi/fixed_bucket_spaces.h new file mode 100644 index 00000000000..c2e97407797 --- /dev/null +++ b/persistence/src/vespa/persistence/spi/fixed_bucket_spaces.h @@ -0,0 +1,30 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include <vespa/document/bucket/bucketspace.h> +#include <vespa/vespalib/util/exceptions.h> +#include <vespa/vespalib/stllike/string.h> + +namespace storage::spi { + +VESPA_DEFINE_EXCEPTION(UnknownBucketSpaceException, vespalib::IllegalArgumentException); + +/** + * Minimal repository/factory of bucket spaces hard coded for default and global + * distributions. + */ +struct FixedBucketSpaces { + static constexpr document::BucketSpace default_space() { return document::BucketSpace(1); }; + static constexpr document::BucketSpace global_space() { return document::BucketSpace(2); } + + // Post-condition: returned space has valid() == true iff name + // is either "default" or "global". + // Throws UnknownBucketSpaceException if name does not map to a known bucket space. + static document::BucketSpace from_string(vespalib::stringref name); + // Post-condition: returned string can be losslessly passed to from_string() + // iff space is equal to default_space() or global_space(). + // Throws UnknownBucketSpaceException if space does not map to a known name. + static vespalib::stringref to_string(document::BucketSpace space); +}; + +} diff --git a/persistence/src/vespa/persistence/spi/metricpersistenceprovider.cpp b/persistence/src/vespa/persistence/spi/metricpersistenceprovider.cpp index 76b0a3c4686..58e662a2b1d 100644 --- a/persistence/src/vespa/persistence/spi/metricpersistenceprovider.cpp +++ b/persistence/src/vespa/persistence/spi/metricpersistenceprovider.cpp @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "metricpersistenceprovider.h" +#include <vespa/metrics/valuemetric.h> +#include <vespa/metrics/metrictimer.h> #include <cassert> #include <vespa/log/log.h> diff --git a/persistence/src/vespa/persistence/spi/metricpersistenceprovider.h b/persistence/src/vespa/persistence/spi/metricpersistenceprovider.h index e169ad098c7..b804fd21550 100644 --- a/persistence/src/vespa/persistence/spi/metricpersistenceprovider.h +++ b/persistence/src/vespa/persistence/spi/metricpersistenceprovider.h @@ -6,10 +6,10 @@ #pragma once #include "persistenceprovider.h" -#include <vespa/metrics/metrics.h> +#include <vespa/metrics/metricset.h> +#include <vespa/metrics/valuemetric.h> -namespace storage { -namespace spi { +namespace storage::spi { class MetricPersistenceProvider : public PersistenceProvider, public metrics::MetricSet @@ -61,5 +61,5 @@ private: void defineResultMetrics(int index, const char* name); }; -} // spi -} // storage +} + @@ -23,6 +23,941 @@ </developer> </developers> + <distributionManagement> + <repository> + <id>bintray-vespa-repo</id> + <url>https://api.bintray.com/maven/yahoo/maven/vespa;publish=1</url> + </repository> + </distributionManagement> + + <repositories> + <!-- Required for Athenz libraries --> + <repository> + <snapshots> + <enabled>false</enabled> + </snapshots> + <id>bintray-yahoo-maven</id> + <name>bintray</name> + <url>https://yahoo.bintray.com/maven</url> + </repository> + </repositories> + + <scm> + <connection>scm:git:git@github.com:vespa-engine/vespa.git</connection> + <developerConnection>scm:git:git@github.com:vespa-engine/vespa.git</developerConnection> + <url>git@github.com:vespa-engine/vespa.git</url> + </scm> + + <build> + <finalName>${project.artifactId}</finalName> + <extensions> + <extension> + <groupId>org.apache.maven.wagon</groupId> + <artifactId>wagon-ssh-external</artifactId> + <version>2.7</version> + </extension> + <extension> + <groupId>org.apache.maven.archetype</groupId> + <artifactId>archetype-packaging</artifactId> + <version>2.0</version> + </extension> + </extensions> + <pluginManagement> + <plugins> + <plugin> + <groupId>org.antlr</groupId> + <artifactId>antlr3-maven-plugin</artifactId> + <version>${antlr.version}</version> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-antrun-plugin</artifactId> + <version>1.7</version> + </plugin> + <plugin> + <groupId>org.apache.felix</groupId> + <artifactId>maven-bundle-plugin</artifactId> + <version>2.4.0</version> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-assembly-plugin</artifactId> + <version>2.4</version> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-compiler-plugin</artifactId> + <version>3.6.1</version> + <configuration> + <source>1.8</source> + <target>1.8</target> + <showWarnings>true</showWarnings> + <optimize>true</optimize> + <showDeprecation>false</showDeprecation> + <compilerArgs> + <arg>-Xlint:all</arg> + <arg>-Xlint:-serial</arg> + <arg>-Xlint:-try</arg> + <arg>-Xlint:-processing</arg> + <arg>-Xlint:-varargs</arg> + <arg>-Werror</arg> + </compilerArgs> + </configuration> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-dependency-plugin</artifactId> + <version>2.10</version> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-deploy-plugin</artifactId> + <version>2.5</version> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-install-plugin</artifactId> + <version>2.5.2</version> + <configuration> + <updateReleaseInfo>true</updateReleaseInfo> + </configuration> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <version>3.0.2</version> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-javadoc-plugin</artifactId> + <configuration> + <additionalparam>-Xdoclint:${doclint} -Xdoclint:-missing</additionalparam> + </configuration> + <version>2.10.4</version> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-plugin-plugin</artifactId> + <version>3.5</version> + <configuration> + <!-- see http://jira.codehaus.org/browse/MNG-5346 --> + <skipErrorNoDescriptorsFound>true</skipErrorNoDescriptorsFound> + </configuration> + <executions> + <execution> + <id>mojo-descriptor</id> + <goals> + <goal>descriptor</goal> + </goals> + </execution> + </executions> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-resources-plugin</artifactId> + <version>2.7</version> + <configuration> + <escapeString>\</escapeString> + </configuration> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-site-plugin</artifactId> + <version>3.3</version> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-source-plugin</artifactId> + <version>2.1.2</version> + <configuration> + <includePom>true</includePom> + </configuration> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-surefire-plugin</artifactId> + <version>${surefire.version}</version> + <configuration> + <redirectTestOutputToFile>${test.hide}</redirectTestOutputToFile> + <systemPropertyVariables> + <java.io.tmpdir>${project.build.directory}</java.io.tmpdir> + </systemPropertyVariables> + </configuration> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-surefire-report-plugin</artifactId> + <version>${surefire.version}</version> + <configuration> + <alwaysGenerateSurefireReport>false</alwaysGenerateSurefireReport> + <showSuccess>false</showSuccess> + </configuration> + </plugin> + <plugin> + <groupId>org.codehaus.mojo</groupId> + <artifactId>build-helper-maven-plugin</artifactId> + <version>1.9.1</version> + </plugin> + <plugin> + <groupId>org.codehaus.mojo</groupId> + <artifactId>exec-maven-plugin</artifactId> + <version>1.6.0</version> + </plugin> + <plugin> + <groupId>org.codehaus.mojo</groupId> + <artifactId>javacc-maven-plugin</artifactId> + <version>2.6</version> + </plugin> + <plugin> + <groupId>org.codehaus.mojo</groupId> + <artifactId>properties-maven-plugin</artifactId> + <version>1.0.0</version> + </plugin> + <plugin> + <groupId>net.alchim31.maven</groupId> + <artifactId>scala-maven-plugin</artifactId> + <version>3.2.2</version> + <configuration> + <args> + <arg>-unchecked</arg> + <arg>-deprecation</arg> + <arg>-feature</arg> + <arg>-Xfatal-warnings</arg> + </args> + </configuration> + </plugin> + <plugin> + <groupId>com.yahoo.vespa</groupId> + <artifactId>bundle-plugin</artifactId> + <version>${project.version}</version> + <configuration> + <configGenVersion>${project.version}</configGenVersion> + <useCommonAssemblyIds>true</useCommonAssemblyIds> + </configuration> + </plugin> + </plugins> + </pluginManagement> + </build> + <profiles> + <profile> + <id>attach-sources</id> + <activation> + <property> + <name>!skipSources</name> + </property> + </activation> + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-source-plugin</artifactId> + <executions> + <execution> + <id>attach-sources</id> + <goals> + <goal>jar-no-fork</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </build> + </profile> + <profile> + <id>generate-javadoc</id> + <activation> + <property> + <name>!skipJavadoc</name> + </property> + </activation> + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-javadoc-plugin</artifactId> + <executions> + <execution> + <id>generate-javadoc</id> + <phase>package</phase> + <goals> + <goal>javadoc</goal> + </goals> + </execution> + </executions> + <configuration> + <additionalparam>-Xdoclint:${doclint} -Xdoclint:-missing</additionalparam> + <failOnError>${javadoc.failOnError}</failOnError> + <quiet>true</quiet> + <show>private</show> + </configuration> + </plugin> + </plugins> + </build> + </profile> + <profile> + <id>coverage</id> + <build> + <plugins> + <plugin> + <groupId>org.codehaus.mojo</groupId> + <artifactId>exec-maven-plugin</artifactId> + <configuration> + <includePluginDependencies>true</includePluginDependencies> + </configuration> + </plugin> + <plugin> + <groupId>org.codehaus.mojo</groupId> + <artifactId>build-helper-maven-plugin</artifactId> + <executions> + <execution> + <phase>generate-sources</phase> + <goals> + <goal>add-source</goal> + </goals> + <configuration> + <sources> + <source>src/main/scala</source> + </sources> + </configuration> + </execution> + <execution> + <id>add-test-source</id> + <phase>generate-test-sources</phase> + <goals> + <goal>add-test-source</goal> + </goals> + <configuration> + <sources> + <source>src/test/scala</source> + </sources> + </configuration> + </execution> + </executions> + </plugin> + </plugins> + </build> + </profile> + <profile> + <id>sign-artifacts</id> + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-gpg-plugin</artifactId> + <version>1.6</version> + <executions> + <execution> + <id>sign-artifacts</id> + <phase>verify</phase> + <goals> + <goal>sign</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </build> + </profile> + </profiles> + <dependencyManagement> + <dependencies> + <dependency> + <groupId>org.apache.maven.wagon</groupId> + <artifactId>wagon-ssh-external</artifactId> + <version>2.7</version> + </dependency> + <dependency> + <groupId>com.github.cverges.expect4j</groupId> + <artifactId>expect4j</artifactId> + <version>1.6</version> + </dependency> + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-compress</artifactId> + <version>1.11</version> + </dependency> + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-exec</artifactId> + <version>1.3</version> + </dependency> + <dependency> + <groupId>io.airlift</groupId> + <artifactId>airline</artifactId> + <version>0.7</version> + </dependency> + <dependency> + <groupId>aopalliance</groupId> + <artifactId>aopalliance</artifactId> + <version>1.0</version> + </dependency> + <dependency> + <groupId>org.ow2.asm</groupId> + <artifactId>asm</artifactId> + <version>5.2</version> + </dependency> + <dependency> + <groupId>com.google.code.findbugs</groupId> + <artifactId>annotations</artifactId> + <version>1.3.9</version> + </dependency> + <dependency> + <groupId>com.google.code.findbugs</groupId> + <artifactId>jsr305</artifactId> + <version>1.3.9</version> + </dependency> + <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + <version>18.0</version> + </dependency> + <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava-testlib</artifactId> + <version>18.0</version> + </dependency> + <dependency> + <groupId>com.google.inject</groupId> + <artifactId>guice</artifactId> + <version>3.0</version> + </dependency> + <dependency> + <groupId>com.google.inject</groupId> + <artifactId>guice</artifactId> + <version>3.0</version> + <classifier>no_aop</classifier> + </dependency> + <dependency> + <groupId>com.google.inject.extensions</groupId> + <artifactId>guice-assistedinject</artifactId> + <version>3.0</version> + </dependency> + <dependency> + <groupId>com.google.inject.extensions</groupId> + <artifactId>guice-multibindings</artifactId> + <version>3.0</version> + </dependency> + <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + <version>3.4.0</version> + </dependency> + <dependency> + <groupId>com.googlecode.jmockit</groupId> + <artifactId>jmockit</artifactId> + <version>1.2</version> + </dependency> + <dependency> + <groupId>com.goldmansachs</groupId> + <artifactId>gs-collections</artifactId> + <version>6.1.0</version> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-core</artifactId> + <version>${jackson2.version}</version> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-databind</artifactId> + <version>${jackson2.version}</version> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-annotations</artifactId> + <version>${jackson2.version}</version> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.jaxrs</groupId> + <artifactId>jackson-jaxrs-json-provider</artifactId> + <version>${jackson2.version}</version> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.module</groupId> + <artifactId>jackson-module-jaxb-annotations</artifactId> + <version>${jackson2.version}</version> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.jaxrs</groupId> + <artifactId>jackson-jaxrs-base</artifactId> + <version>${jackson2.version}</version> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.jaxrs</groupId> + <artifactId>jackson-jaxrs-xml-provider</artifactId> + <version>${jackson2.version}</version> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.dataformat</groupId> + <artifactId>jackson-dataformat-xml</artifactId> + <version>${jackson2.version}</version> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.datatype</groupId> + <artifactId>jackson-datatype-jdk8</artifactId> + <version>${jackson2.version}</version> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.datatype</groupId> + <artifactId>jackson-datatype-jsr310</artifactId> + <version>${jackson2.version}</version> + </dependency> + <dependency> + <groupId>com.infradna.tool</groupId> + <artifactId>bridge-method-annotation</artifactId> + <version>1.4</version> + </dependency> + <dependency> + <groupId>commons-cli</groupId> + <artifactId>commons-cli</artifactId> + <version>1.3.1</version> + </dependency> + <dependency> + <groupId>commons-codec</groupId> + <artifactId>commons-codec</artifactId> + <version>1.4</version> + </dependency> + <dependency> + <groupId>commons-collections</groupId> + <artifactId>commons-collections</artifactId> + <version>3.2.1</version> + </dependency> + <dependency> + <groupId>commons-configuration</groupId> + <artifactId>commons-configuration</artifactId> + <version>1.6</version> + </dependency> + <dependency> + <groupId>commons-daemon</groupId> + <artifactId>commons-daemon</artifactId> + <version>1.0.3</version> + </dependency> + <dependency> + <groupId>commons-io</groupId> + <artifactId>commons-io</artifactId> + <version>2.4</version> + </dependency> + <dependency> + <groupId>commons-lang</groupId> + <artifactId>commons-lang</artifactId> + <version>${commons-lang.version}</version> + </dependency> + <dependency> + <!-- This version is exported by jdisc via jcl-over-slf4j. --> + <groupId>commons-logging</groupId> + <artifactId>commons-logging</artifactId> + <version>1.1.1</version> + </dependency> + <dependency> + <groupId>commons-net</groupId> + <artifactId>commons-net</artifactId> + <version>2.0</version> + </dependency> + <dependency> + <groupId>commons-pool</groupId> + <artifactId>commons-pool</artifactId> + <version>1.5.6</version> + </dependency> + <!-- Explicitly included to get Zookeeper version 3.4.10, + can be excluded if you want the Zookeeper version + used by curator by default + --> + <dependency> + <groupId>org.apache.zookeeper</groupId> + <artifactId>zookeeper</artifactId> + <version>3.4.10</version> + </dependency> + <dependency> + <groupId>org.apache.curator</groupId> + <artifactId>curator-recipes</artifactId> + <version>${curator.version}</version> + </dependency> + <dependency> + <groupId>org.apache.curator</groupId> + <artifactId>curator-test</artifactId> + <version>${curator.version}</version> + </dependency> + <dependency> + <groupId>javax.servlet</groupId> + <artifactId>javax.servlet-api</artifactId> + <version>3.1.0</version> + </dependency> + <dependency> + <groupId>junit</groupId> + <artifactId>junit</artifactId> + <version>4.12</version> + </dependency> + <dependency> + <groupId>org.antlr</groupId> + <artifactId>antlr-runtime</artifactId> + <version>${antlr.version}</version> + </dependency> + <dependency> + <groupId>org.antlr</groupId> + <artifactId>antlr4-runtime</artifactId> + <version>${antlr4.version}</version> + </dependency> + <dependency> + <groupId>org.apache.aries.spifly</groupId> + <artifactId>org.apache.aries.spifly.dynamic.bundle</artifactId> + <version>${aries.spifly.version}</version> + </dependency> + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-lang3</artifactId> + <version>3.1</version> + </dependency> + <dependency> + <groupId>org.apache.felix</groupId> + <artifactId>org.apache.felix.framework</artifactId> + <version>4.2.1</version> + </dependency> + <dependency> + <groupId>org.apache.felix</groupId> + <artifactId>org.apache.felix.log</artifactId> + <version>1.0.1</version> + </dependency> + <dependency> + <groupId>org.apache.felix</groupId> + <artifactId>org.apache.felix.main</artifactId> + <version>4.2.1</version> + </dependency> + <dependency> + <groupId>org.apache.httpcomponents</groupId> + <artifactId>fluent-hc</artifactId> + <version>4.3.6</version> + </dependency> + <dependency> + <groupId>org.apache.httpcomponents</groupId> + <artifactId>httpclient</artifactId> + <version>4.3.6</version> + </dependency> + <dependency> + <groupId>org.apache.httpcomponents</groupId> + <artifactId>httpcore</artifactId> + <version>4.3.3</version> + </dependency> + <dependency> + <groupId>org.apache.httpcomponents</groupId> + <artifactId>httpmime</artifactId> + <version>4.3.6</version> + </dependency> + <dependency> + <groupId>org.apache.maven</groupId> + <artifactId>maven-artifact</artifactId> + <version>3.5.0</version> + </dependency> + <dependency> + <groupId>org.apache.maven</groupId> + <artifactId>maven-core</artifactId> + <version>3.5.0</version> + </dependency> + <dependency> + <groupId>org.apache.maven</groupId> + <artifactId>maven-model</artifactId> + <version>3.5.0</version> + </dependency> + <dependency> + <groupId>org.apache.maven.plugin-tools</groupId> + <artifactId>maven-plugin-annotations</artifactId> + <version>3.5</version> + </dependency> + <dependency> + <groupId>org.apache.maven</groupId> + <artifactId>maven-plugin-api</artifactId> + <version>3.5.0</version> + </dependency> + <dependency> + <groupId>org.apache.maven</groupId> + <artifactId>maven-project</artifactId> + <version>2.2.1</version> + </dependency> + <dependency> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <version>3.0.2</version> + </dependency> + <dependency> + <groupId>org.apache.maven.surefire</groupId> + <artifactId>surefire-junit4</artifactId> + <version>${surefire.version}</version> + </dependency> + <dependency> + <groupId>org.apache.maven.surefire</groupId> + <artifactId>surefire-providers</artifactId> + <version>${surefire.version}</version> + <type>pom</type> + </dependency> + <dependency> + <groupId>org.codehaus.jettison</groupId> + <artifactId>jettison</artifactId> + <version>1.3.1</version> + </dependency> + <dependency> + <groupId>org.cthul</groupId> + <artifactId>cthul-matchers</artifactId> + <version>1.0</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.eclipse.jetty</groupId> + <artifactId>jetty-continuation</artifactId> + <version>${jetty.version}</version> + </dependency> + <dependency> + <groupId>org.eclipse.jetty</groupId> + <artifactId>jetty-server</artifactId> + <version>${jetty.version}</version> + </dependency> + <dependency> + <groupId>org.eclipse.jetty</groupId> + <artifactId>jetty-servlet</artifactId> + <version>${jetty.version}</version> + </dependency> + <dependency> + <groupId>org.eclipse.jetty</groupId> + <artifactId>jetty-servlets</artifactId> + <version>${jetty.version}</version> + </dependency> + <dependency> + <groupId>org.eclipse.jetty</groupId> + <artifactId>jetty-util</artifactId> + <version>${jetty.version}</version> + </dependency> + <dependency> + <groupId>org.eclipse.jetty</groupId> + <artifactId>jetty-http</artifactId> + <version>${jetty.version}</version> + </dependency> + <dependency> + <groupId>org.eclipse.jetty</groupId> + <artifactId>jetty-jmx</artifactId> + <version>${jetty.version}</version> + </dependency> + <dependency> + <groupId>org.hamcrest</groupId> + <artifactId>hamcrest-all</artifactId> + <version>1.3</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.hamcrest</groupId> + <artifactId>hamcrest-core</artifactId> + <version>1.3</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.hamcrest</groupId> + <artifactId>hamcrest-library</artifactId> + <version>1.3</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>uk.co.datumedge</groupId> + <artifactId>hamcrest-json</artifactId> + <version>0.2</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.hdrhistogram</groupId> + <artifactId>HdrHistogram</artifactId> + <version>2.1.8</version> + </dependency> + <dependency> + <groupId>org.json</groupId> + <artifactId>json</artifactId> + <version>20090211</version> + </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-all</artifactId> + <version>1.9.5</version> + </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-core</artifactId> + <version>1.9.5</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.osgi</groupId> + <artifactId>org.osgi.compendium</artifactId> + <version>4.3.0</version> + </dependency> + <dependency> + <groupId>org.osgi</groupId> + <artifactId>org.osgi.core</artifactId> + <version>4.3.0</version> + </dependency> + <dependency> + <groupId>org.scala-lang</groupId> + <artifactId>scala-library</artifactId> + <version>${scala.version}</version> + </dependency> + <dependency> + <groupId>org.scala-lang.modules</groupId> + <artifactId>scala-parser-combinators_${scala.major-version}</artifactId> + <version>1.0.1</version> + </dependency> + <dependency> + <groupId>org.scala-lang.modules</groupId> + <artifactId>scala-xml_${scala.major-version}</artifactId> + <version>1.0.2</version> + </dependency> + <dependency> + <groupId>org.scalatest</groupId> + <artifactId>scalatest_${scala.major-version}</artifactId> + <version>2.2.2</version> + </dependency> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>jcl-over-slf4j</artifactId> + <version>1.7.5</version> + </dependency> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>log4j-over-slf4j</artifactId> + <version>1.7.5</version> + </dependency> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-api</artifactId> + <version>1.7.5</version> + </dependency> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-jdk14</artifactId> + <version>1.7.5</version> + </dependency> + <dependency> + <groupId>org.springframework</groupId> + <artifactId>spring-test</artifactId> + <version>4.0.6.RELEASE</version> + </dependency> + <dependency> + <groupId>org.testng</groupId> + <artifactId>testng</artifactId> + <version>6.10</version> + </dependency> + <dependency> + <groupId>org.twdata.maven</groupId> + <artifactId>mojo-executor</artifactId> + <version>2.3.0</version> + </dependency> + <dependency> + <groupId>net.jcip</groupId> + <artifactId>jcip-annotations</artifactId> + <version>1.0</version> + </dependency> + <dependency> + <groupId>net.jpountz.lz4</groupId> + <artifactId>lz4</artifactId> + <version>1.3.0</version> + </dependency> + <dependency> + <groupId>net.spy</groupId> + <artifactId>spymemcached</artifactId> + <version>2.10.1</version> + </dependency> + <dependency> + <groupId>xerces</groupId> + <artifactId>xercesImpl</artifactId> + <version>2.11.0</version> + </dependency> + <dependency> + <groupId>org.bouncycastle</groupId> + <artifactId>bcpkix-jdk15on</artifactId> + <version>${bouncycastle.version}</version> + </dependency> + <dependency> + <groupId>org.bouncycastle</groupId> + <artifactId>bcprov-jdk15on</artifactId> + <version>${bouncycastle.version}</version> + </dependency> + <!-- jersey 2 support --> + <dependency> + <groupId>javax.ws.rs</groupId> + <artifactId>javax.ws.rs-api</artifactId> + <version>${javax.ws.rs-api.version}</version> + </dependency> + <dependency> + <groupId>org.glassfish.jersey.containers</groupId> + <artifactId>jersey-container-servlet-core</artifactId> + <version>${jersey2.version}</version> + </dependency> + <dependency> + <groupId>org.glassfish.jersey.containers</groupId> + <artifactId>jersey-container-servlet</artifactId> + <version>${jersey2.version}</version> + </dependency> + <dependency> + <groupId>org.glassfish.jersey.media</groupId> + <artifactId>jersey-media-json-jackson</artifactId> + <version>${jersey2.version}</version> + </dependency> + <dependency> + <groupId>org.glassfish.jersey.media</groupId> + <artifactId>jersey-media-multipart</artifactId> + <version>${jersey2.version}</version> + </dependency> + <dependency> + <groupId>org.glassfish.jersey.ext</groupId> + <artifactId>jersey-proxy-client</artifactId> + <version>${jersey2.version}</version> + </dependency> + <dependency> + <groupId>org.glassfish.jersey.core</groupId> + <artifactId>jersey-client</artifactId> + <version>${jersey2.version}</version> + </dependency> + <dependency> + <groupId>com.ibm.icu</groupId> + <artifactId>icu4j</artifactId> + <version>57.1</version> + </dependency> + <dependency> + <groupId>com.yahoo.athenz</groupId> + <artifactId>athenz-zms-java-client</artifactId> + <version>${athenz.version}</version> + </dependency> + <dependency> + <groupId>com.yahoo.athenz</groupId> + <artifactId>athenz-zts-java-client</artifactId> + <version>${athenz.version}</version> + </dependency> + </dependencies> + </dependencyManagement> + + <properties> + <javax.ws.rs-api.version>2.0.1</javax.ws.rs-api.version> <!-- must be kept in sync with version used by current jersey2.version --> + <antlr.version>3.5.2</antlr.version> + <antlr4.version>4.5</antlr4.version> + <aries.spifly.version>1.0.8</aries.spifly.version> + <aries.util.version>1.0.0</aries.util.version> + <asm-debug-all.version>5.0.3</asm-debug-all.version> + <!-- Athenz dependencies. Make sure these dependencies matches those in Vespa's internal repositories --> + <athenz.version>1.7.28</athenz.version> + <bouncycastle.version>1.58</bouncycastle.version> + <commons-lang.version>2.6</commons-lang.version> + <!-- WARNING: If you change curator version, you also need to update + zkfacade/src/main/java/org/apache/curator/**/package-info.java + using something like + find zkfacade/src/main/java/org/apache/curator -name package-info.java | \ + xargs perl -pi -e 's/major = [0-9]+, minor = [0-9]+, micro = [0-9]+/major = 2, minor = 9, micro = 1/g' + --> + <curator.version>2.9.1</curator.version> + <jackson2.version>2.8.3</jackson2.version> + <jersey2.version>2.23.2</jersey2.version> + <jetty.version>9.4.6.v20170531</jetty.version> + <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> + <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding> + <test.hide>true</test.hide> + <doclint>all</doclint> + <scala.major-version>2.11</scala.major-version> + <scala.version>${scala.major-version}.4</scala.version> + <surefire.version>2.19.1</surefire.version> <!-- NOTE bjorncs 15.06.2017: Version 2.20 has OoM issues --> + </properties> + <modules> <module>application</module> <module>application-deploy-plugin</module> diff --git a/searchcore/src/tests/proton/persistenceengine/persistenceengine_test.cpp b/searchcore/src/tests/proton/persistenceengine/persistenceengine_test.cpp index 4a195514db1..4b73e4ca115 100644 --- a/searchcore/src/tests/proton/persistenceengine/persistenceengine_test.cpp +++ b/searchcore/src/tests/proton/persistenceengine/persistenceengine_test.cpp @@ -13,6 +13,7 @@ #include <vespa/searchcore/proton/persistenceengine/persistenceengine.h> #include <vespa/vdslib/distribution/distribution.h> #include <vespa/vdslib/state/clusterstate.h> +#include <vespa/metrics/loadmetric.h> #include <vespa/vespalib/testkit/testapp.h> #include <algorithm> #include <set> @@ -369,7 +370,7 @@ Timestamp tstamp2(2); Timestamp tstamp3(3); DocumentSelection doc_sel(""); Selection selection(doc_sel); -BucketSpace altBucketSpace(1); +BucketSpace altBucketSpace(2); void diff --git a/searchcore/src/vespa/searchcore/proton/persistenceengine/persistenceengine.cpp b/searchcore/src/vespa/searchcore/proton/persistenceengine/persistenceengine.cpp index c7b01d209ee..e2b389fb898 100644 --- a/searchcore/src/vespa/searchcore/proton/persistenceengine/persistenceengine.cpp +++ b/searchcore/src/vespa/searchcore/proton/persistenceengine/persistenceengine.cpp @@ -3,8 +3,8 @@ #include "persistenceengine.h" #include "ipersistenceengineowner.h" #include "transport_latch.h" +#include <vespa/metrics/loadmetric.h> #include <vespa/vespalib/stllike/hash_set.h> -#include <vespa/fastos/thread.h> #include <vespa/log/log.h> LOG_SETUP(".proton.persistenceengine.persistenceengine"); @@ -23,6 +23,8 @@ using vespalib::IllegalStateException; using vespalib::Sequence; using vespalib::make_string; +using namespace std::chrono_literals; + namespace proton { namespace { @@ -623,7 +625,7 @@ PersistenceEngine::destroyIterators() Result res(destroyIterator(id, context)); if (res.hasError()) { LOG(debug, "%ld iterator left. Can not destroy iterator '%ld'. Reason='%s'", _iterators.size(), id.getValue(), res.toString().c_str()); - FastOS_Thread::Sleep(100); // Sleep 0.1 seconds + std::this_thread::sleep_for(100ms); } } } diff --git a/searchlib/pom.xml b/searchlib/pom.xml index 5f6717d9516..09ccf9928b7 100644 --- a/searchlib/pom.xml +++ b/searchlib/pom.xml @@ -36,6 +36,21 @@ <version>${project.version}</version> </dependency> <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + <version>3.4.0</version> + </dependency> + <dependency> + <groupId>org.tensorflow</groupId> + <artifactId>proto</artifactId> + <version>1.4.0</version> + </dependency> + <dependency> + <groupId>org.tensorflow</groupId> + <artifactId>tensorflow</artifactId> + <version>1.4.0</version> + </dependency> + <dependency> <groupId>com.fasterxml.jackson.core</groupId> <artifactId>jackson-core</artifactId> <scope>test</scope> diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 785ed78492e..0eeb0a9e630 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.tensor.Tensor; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Set; @@ -18,26 +19,30 @@ public abstract class Context implements EvaluationContext { /** * <p>Returns the value of a simple variable name.</p> * - * @param name The name of the variable whose value to return. - * @return The value of the named variable. + * @param name the name of the variable whose value to return. + * @return the value of the named variable. */ public abstract Value get(String name); + /** Returns a variable as a tensor */ + @Override + public Tensor getTensor(String name) { return get(name).asTensor(); } + /** * <p>Returns the value of a <i>structured variable</i> on the form * <code>name(argument*)(.output)?</code>, where <i>argument</i> is any * string. This may be used to implement more advanced variables whose * values are calculated at runtime from arguments. Supporting this in a - * context is optional. - * + * context is optional. + * * <p>This default implementation generates a name on the form * <code>name(argument1, argument2, ...argumentN).output</code>. * If there are no arguments the parenthesis are omitted. * If there is no output, the dot is omitted.</p> * - * @param name The name of this variable. - * @param arguments The parsed arguments as given in the textual expression. - * @param output The name of the value to output (to enable one named + * @param name the name of this variable. + * @param arguments the parsed arguments as given in the textual expression. + * @param output the name of the value to output (to enable one named * calculation to output several), or null to output the * "main" (or only) value. */ @@ -54,20 +59,20 @@ public abstract class Context implements EvaluationContext { * context subclasses. This default implementation throws * UnsupportedOperationException.</p> * - * @param index The index of the variable whose value to return. - * @return The value of the indexed variable. + * @param index the index of the variable whose value to return. + * @return the value of the indexed variable. */ public Value get(int index) { throw new UnsupportedOperationException(this + " does not support variable lookup by index"); } /** - * <p>Lookup by index rather than name directly to a double. This is supported by some optimized + * Lookup by index rather than name directly to a double. This is supported by some optimized * context subclasses. This default implementation throws - * UnsupportedOperationException.</p> + * UnsupportedOperationException. * - * @param index The index of the variable whose value to return. - * @return The value of the indexed variable. + * @param index the index of the variable whose value to return. + * @return the value of the indexed variable. */ public double getDouble(int index) { throw new UnsupportedOperationException(this + " does not support variable lookup by index"); @@ -81,24 +86,23 @@ public abstract class Context implements EvaluationContext { } /** - * <p>Sets a value to this, or throws an UnsupportedOperationException if - * this is not supported. This default implementation does the latter.</p> * + * Sets a value to this, or throws an UnsupportedOperationException if + * this is not supported. This default implementation does the latter. * - * @param name The name of the variable to set. + * @param name the name of the variable to set. * @param value the value to set. Ownership of this value is transferred to this - if it is mutable * (not frozen) it may be modified during execution - * @since 5.1.5 */ public void put(String name, Value value) { throw new UnsupportedOperationException(this + " does not support variable assignment"); } /** - * <p>Returns all the names available in this, or throws an + * Returns all the names available in this, or throws an * UnsupportedOperationException if this operation is not supported. This - * default implementation does the latter.</p> + * default implementation does the latter. * - * @return The set of all variable names. + * @return the set of all variable names. */ public Set<String> names() { throw new UnsupportedOperationException(this + " does not support return a list of its names"); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index ea750295423..2ef4a2ede2f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -3,6 +3,9 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; /** * A value which acts as a double in numerical context. @@ -16,6 +19,11 @@ public abstract class DoubleCompatibleValue extends Value { public boolean hasDouble() { return true; } @Override + public Tensor asTensor() { + return doubleAsTensor(asDouble()); + } + + @Override public Value negate() { return new DoubleValue(-asDouble()); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index ac8aba6a617..dad69b31181 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -4,12 +4,14 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; /** * A string value. * * @author bratseth - * @since 5.1.21 */ public class StringValue extends Value { @@ -35,6 +37,11 @@ public class StringValue extends Value { } @Override + public Tensor asTensor() { + return doubleAsTensor(asDouble()); + } + + @Override public boolean hasDouble() { return true; } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 49c3ccb7b01..26c30fe5ed2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -2,14 +2,10 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.google.common.annotations.Beta; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; -import com.yahoo.tensor.TensorType; - -import java.util.Collections; -import java.util.Optional; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; /** * A Value containing a tensor. @@ -23,7 +19,7 @@ public class TensorValue extends Value { /** The tensor value of this */ private final Tensor value; - + public TensorValue(Tensor value) { this.value = value; } @@ -131,7 +127,7 @@ public class TensorValue extends Value { public Value compare(TruthOperator operator, Value argument) { return new TensorValue(compareTensor(operator, asTensor(argument, operator.toString()))); } - + private Tensor compareTensor(TruthOperator operator, Tensor argument) { switch (operator) { case LARGER: return value.larger(argument); @@ -152,7 +148,7 @@ public class TensorValue extends Value { else return new TensorValue(value.map((value) -> function.evaluate(value, arg.asDouble()))); } - + private Tensor functionOnTensor(Function function, Tensor argument) { switch (function) { case min: return value.min(argument); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index b2ccbe572d0..40d70e0022c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -5,6 +5,8 @@ import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; /** * The result of a ranking expression evaluation. @@ -25,6 +27,14 @@ public abstract class Value { return new DoubleValue(asDouble()); } + /** Returns this as a tensor value */ + public abstract Tensor asTensor(); + + /** A utility method for wrapping a sdouble in a rank 0 tensor */ + protected Tensor doubleAsTensor(double value) { + return Tensor.Builder.of(TensorType.empty).cell(TensorAddress.of(), value).build(); + } + /** Returns true if this value can return itself as a double, i.e asDoubleValue will return a value and not throw */ public abstract boolean hasDouble(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java new file mode 100644 index 00000000000..947e6d7a5e1 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java @@ -0,0 +1,102 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * The result of importing a TensorFlow model into Vespa. + * - A set of signatures which are named collections of inputs and outputs. + * - A set of named constant tensors represented by Variable nodes in TensorFlow. + * - A list of warning messages. + * + * @author bratseth + */ +// This object can be built incrementally within this package, but is immutable when observed from outside the package +public class ImportResult { + + private final Map<String, Signature> signatures = new HashMap<>(); + private final Map<String, TensorType> arguments = new HashMap<>(); + private final Map<String, Tensor> constants = new HashMap<>(); + private final Map<String, RankingExpression> expressions = new HashMap<>(); + private final List<String> warnings = new ArrayList<>(); + + void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } + void constant(String name, Tensor constant) { constants.put(name, constant); } + void expression(String name, RankingExpression expression) { expressions.put(name, expression); } + void warn(String warning) { warnings.add(warning); } + + /** Returns the given signature. If it does not already exist it is added to this. */ + Signature signature(String name) { + return signatures.computeIfAbsent(name, n -> new Signature(n)); + } + + /** Returns an immutable map of the arguments ("Placeholders") of this */ + public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } + + /** Returns an immutable map of the constants of this */ + public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); } + + /** + * Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes + * which are not Placeholders or Variables (which instead become respectively arguments and constants). + * Note that only nodes recursively referenced by a placeholder are added. + */ + public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } + + /** Returns an immutable list, in natural sort order of the warnings generated while importing this */ + public List<String> warnings() { + return warnings.stream().sorted().collect(Collectors.toList()); + } + + /** Returns an immutable map of the signatures of this */ + public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } + + /** + * A signature is a set of named inputs and outputs, where the inputs maps to argument ("placeholder") names+types, + * and outputs maps to expressions nodes. + */ + public class Signature { + + private final String name; + private final Map<String, String> inputs = new HashMap<>(); + private final Map<String, String> outputs = new HashMap<>(); + + Signature(String name) { + this.name = name; + } + + void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } + void output(String name, String expressionName) { outputs.put(name, expressionName); } + + /** Returns the result this is part of */ + ImportResult owner() { return ImportResult.this; } + + /** + * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name + * to argument (Placeholder) name in the owner of this + */ + public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); } + + /** Returns owner().arguments().get(inputs.get(name)), e.g the type of the argument this input references */ + public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); } + + /** Returns an immutable list of the expression names of this */ + public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); } + + /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */ + public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); } + + @Override + public String toString() { return "signature '" + name + "'"; } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java new file mode 100644 index 00000000000..bac141644c6 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -0,0 +1,160 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.VariableTensor; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Matmul; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.Softmax; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleUnaryOperator; + +/** + * Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions. + * + * @author bratseth + */ +class OperationMapper { + + /* + A note on conversion from implicitly numbered to explicitly named dimensions: + Vespa tensor dimensions are explicitly named and thus have an explicit notion of being + 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation + comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation + around dimension renaming operations which mirrors those built into the TF operation definitions. + + To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost' + dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation + and the result is then renamed again (if necessary) to recover this convention across a full nested + computation. + + This requires us to track tensor types throughout the conversion. + */ + + private TensorConverter tensorConverter = new TensorConverter(); + + TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) { + ensureArguments(2, arguments, "join"); + TypedTensorFunction a = arguments.get(0); + TypedTensorFunction b = arguments.get(1); + if (a.type().rank() < b.type().rank()) + throw new IllegalArgumentException("Attempt to join " + a.type() + " and " + b.type() + ", " + + "but this is not supported when the second argument has a higher rank"); + + TensorFunction bFunction = b.function(); + + if (a.type().rank() > b.type().rank()) { + // Well now we have entered the wonderful world of "broadcasting" + // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html + // I'm not able to extract from that any unambiguous specification of which dimensions + // should be "stretched" when the tensor do not have the same number of dimensions. + // From trying this with TensorFlow it appears that the second tensor is matched to the + // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true. + // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first). + List<String> renameFrom = new ArrayList<>(); + List<String> renameTo = new ArrayList<>(); + int sizeDifference = a.type().rank() - b.type().rank(); + for (int i = 0; i < b.type().rank(); i++) { + renameFrom.add(b.type().dimensions().get(i).name()); + renameTo.add("d" + (sizeDifference + i)); + } + bFunction = new Rename(bFunction, renameFrom, renameTo); + } + + Join function = new Join(a.function(), bFunction, doubleFunction); + return new TypedTensorFunction(a.type(), function); // output type is a type by TF definition and a.rank>=b.rank + } + + TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) { + ensureArguments(1, arguments, "apply"); + TypedTensorFunction a = arguments.get(0); + + TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type()); + com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction); + return new TypedTensorFunction(resultType, function); + } + + TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) { + String name = tfNode.getName(); + TensorType type = result.arguments().get(name); + if (type == null) + throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name + + "', but there is no such placeholder"); + // Included literally in the expression and so must be produced by a separate macro in the rank profile + return new TypedTensorFunction(type, new VariableTensor(name)); + } + + TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) { + if ( ! tfNode.getName().endsWith("/read")) + throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " + + "nodes are only supported when reading variables"); + if (tfNode.getInputList().size() != 1) + throw new IllegalArgumentException("A Variable/read node must have one input but has " + + tfNode.getInputList().size()); + + String name = tfNode.getInput(0); + AttrValue shapes = tfNode.getAttrMap().get("_output_shapes"); + if (shapes == null) + throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape"); + Session.Runner fetched = model.session().runner().fetch(name); + List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); + if ( importedTensors.size() != 1) + throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " + + importedTensors.size()); + Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0)); + result.constant(name, constant); + return new TypedTensorFunction(constant.type(), + new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")"))); + } + + TypedTensorFunction matmul(List<TypedTensorFunction> arguments) { + ensureArguments(2, arguments, "matmul"); + TypedTensorFunction a = arguments.get(0); + TypedTensorFunction b = arguments.get(1); + if (a.type().rank() < 2 || b.type().rank() < 2) + throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); + if (a.type().rank() != b.type().rank()) + throw new IllegalArgumentException("Tensors in matmul must have the same rank"); + + String afterLastDim = "d" + (a.type().rank() + 1); + // Let the first dimension of the second tensor be the same as the second dimension of the first + // and the second dimension of the second argument be not present in the first argument, while leaving the + // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication. + + // TODO: Check if transpose_a or transpose_b is set true and rename differently accordingly + + Rename renamedB = new Rename(b.function(), ImmutableList.of("d0", "d1"), + ImmutableList.of("d1", afterLastDim)); + Matmul matmul = new Matmul(a.function(), renamedB, "d1"); + return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"), + new Rename(matmul, afterLastDim, "d1")); + } + + TypedTensorFunction softmax(List<TypedTensorFunction> arguments) { + ensureArguments(1, arguments, "softmax"); + TypedTensorFunction a = arguments.get(0); + // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1 + String dimension = "d" + (a.type().rank() - 1); + Softmax softmax = new Softmax(a.function(), dimension); + return new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax); + } + + private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) { + if ( arguments.size() != count) + throw new IllegalArgumentException("Expected " + count + " arguments to " + operationName + + ", but got " + arguments.size()); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java new file mode 100644 index 00000000000..1960cf94591 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java @@ -0,0 +1,94 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; + +/** + * @author bratseth + */ +public class TensorConverter { + + public Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) { + TensorType type = toVespaTensorType(tfTensor.shape()); + Values values = readValuesOf(tfTensor); + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); + for (int i = 0; i < values.size(); i++) + builder.cellByDirectIndex(i, values.get(i)); + return builder.build(); + } + + private TensorType toVespaTensorType(long[] shape) { + TensorType.Builder b = new TensorType.Builder(); + int dimensionIndex = 0; + for (long dimensionSize : shape) { + if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... + b.indexed("d" + (dimensionIndex++), dimensionSize); + } + return b.build(); + } + + private Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) { + switch (tfTensor.dataType()) { + case DOUBLE: return new DoubleValues(tfTensor); + case FLOAT: return new FloatValues(tfTensor); + // TODO: The rest + default: + throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + + tfTensor.dataType() + " to a Vespa tensor"); + } + } + + /** Allows reading values from buffers of various numeric types as bytes */ + private static abstract class Values { + + private final int size; + + protected Values(int size) { + this.size = size; + } + + abstract double get(int i); + + int size() { return size; } + + } + + private static class DoubleValues extends Values { + + private final DoubleBuffer values; + + DoubleValues(org.tensorflow.Tensor<?> tfTensor) { + super(tfTensor.numElements()); + values = DoubleBuffer.allocate(tfTensor.numElements()); + tfTensor.writeTo(values); + } + + @Override + double get(int i) { + return values.get(i); + } + + } + + private static class FloatValues extends Values { + + private final FloatBuffer values; + + FloatValues(org.tensorflow.Tensor<?> tfTensor) { + super(tfTensor.numElements()); + values = FloatBuffer.allocate(tfTensor.numElements()); + tfTensor.writeTo(values); + } + + @Override + double get(int i) { + return values.get(i); + } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java new file mode 100644 index 00000000000..4a6551adca7 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -0,0 +1,145 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.yolean.Exceptions; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.MetaGraphDef; +import org.tensorflow.framework.NodeDef; +import org.tensorflow.framework.SignatureDef; +import org.tensorflow.framework.TensorInfo; +import org.tensorflow.framework.TensorShapeProto; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Converts a saved TensorFlow model into a ranking expression and set of constants. + * + * @author bratseth + */ +public class TensorFlowImporter { + + private final OperationMapper operationMapper = new OperationMapper(); + + /** + * Imports a saved TensorFlow model from a directory. + * The model should be saved as a pbtxt file. + * The name of the model is taken as the db/pbtxt file name (not including the file ending). + * + * @param modelDir the directory containing the TensorFlow model files to import + */ + public ImportResult importModel(String modelDir) { + try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { + return importModel(model); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); + } + } + + /** Imports a TensorFlow model */ + public ImportResult importModel(SavedModelBundle model) { + try { + return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model); + } + catch (IOException e) { + throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e); + } + } + + private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) { + ImportResult result = new ImportResult(); + for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { + ImportResult.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName" + + importInputs(signatureEntry.getValue().getInputsMap(), signature); + for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) { + String outputName = output.getKey(); + try { + NodeDef node = getNode(nameOf(output.getValue().getName()), graph.getGraphDef()); + importNode(node, graph.getGraphDef(), model, result); + signature.output(outputName, nameOf(output.getValue().getName())); + } + catch (IllegalArgumentException e) { + result.warn("Skipping output '" + outputName + "' of " + signature + + ": " + Exceptions.toMessageString(e)); + } + } + } + return result; + } + + private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult.Signature signature) { + inputInfoMap.forEach((key, value) -> { + String argumentName = nameOf(value.getName()); + TensorType argumentType = importTensorType(value.getTensorShape()); + // Arguments are (Placeholder) nodes, so not local to the signature: + signature.owner().argument(argumentName, argumentType); + signature.input(key, argumentName); + }); + } + + private TensorType importTensorType(TensorShapeProto tensorShape) { + TensorType.Builder b = new TensorType.Builder(); + for (TensorShapeProto.Dim dimension : tensorShape.getDimList()) { + int dimensionSize = (int)dimension.getSize(); + if (dimensionSize >= 0) + b.indexed("d" + b.rank(), dimensionSize); + else + b.indexed("d" + b.rank()); // unbound size + } + return b.build(); + } + + /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ + private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { + TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result); + // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output + // will be used + result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), new TensorFunctionNode(function.function()))); + return function; + } + + private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { + // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops + // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/ + switch (tfNode.getOp().toLowerCase()) { + case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add()); + case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos()); + case "placeholder" : return operationMapper.placeholder(tfNode, result); + case "identity" : return operationMapper.identity(tfNode, model, result); + case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result)); + case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result)); + default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported"); + } + } + + private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, + ImportResult result) { + return tfNode.getInputList().stream() + .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result)) + .collect(Collectors.toList()); + } + + private NodeDef getNode(String name, GraphDef graph) { + return graph.getNodeList().stream() + .filter(node -> node.getName().equals(name)) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("Could not find node '" + name + "'")); + } + + /** + * A method signature input and output has the form name:index. + * This returns the name part without the index. + */ + private String nameOf(String name) { + return name.split(":")[0]; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java new file mode 100644 index 00000000000..5712da77700 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java @@ -0,0 +1,24 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; + +/** + * A tensor function returning a specific tensor type + * + * @author bratseth + */ +final class TypedTensorFunction { + + private final TensorType type; + private final TensorFunction function; + + public TypedTensorFunction(TensorType type, TensorFunction function) { + this.type = type; + this.function = function; + } + + public TensorType type() { return type; } + public TensorFunction function() { return function; } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java index 71699b379b2..9da1ba40144 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; -import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -10,27 +9,26 @@ import com.yahoo.tensor.TensorType; import java.util.Collections; import java.util.Deque; import java.util.List; -import java.util.function.*; /** * A tensor generating function, whose arguments are determined by a tensor type - * + * * @author bratseth */ public class GeneratorLambdaFunctionNode extends CompositeNode { private final TensorType type; private final ExpressionNode generator; - + public GeneratorLambdaFunctionNode(TensorType type, ExpressionNode generator) { if ( ! type.dimensions().stream().allMatch(d -> d.size().isPresent())) - throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " + + throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " + "dimensions, but tried to generate " + type); // TODO: Verify that the function only accesses the given arguments this.type = type; this.generator = generator; } - + @Override public List<ExpressionNode> children() { return Collections.singletonList(generator); @@ -53,24 +51,24 @@ public class GeneratorLambdaFunctionNode extends CompositeNode { public Value evaluate(Context context) { return generator.evaluate(context); } - - /** + + /** * Returns this as an operator which converts a list of integers into a double */ - public IntegerListToDoubleLambda asIntegerListToDoubleOperator() { - return new IntegerListToDoubleLambda(); + public LongListToDoubleLambda asLongListToDoubleOperator() { + return new LongListToDoubleLambda(); } - private class IntegerListToDoubleLambda implements java.util.function.Function<List<Integer>, Double> { + private class LongListToDoubleLambda implements java.util.function.Function<List<Long>, Double> { @Override - public Double apply(List<Integer> arguments) { + public Double apply(List<Long> arguments) { MapContext context = new MapContext(); for (int i = 0; i < type.dimensions().size(); i++) context.put(type.dimensions().get(i).name(), arguments.get(i)); return evaluate(context).asDouble(); } - + @Override public String toString() { return GeneratorLambdaFunctionNode.this.toString(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index 1f8db6e036c..ba765d07094 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -17,7 +17,7 @@ import java.util.Map; * @author bratseth */ public class SerializationContext { - + /** Expression functions indexed by name */ private final ImmutableMap<String, ExpressionFunction> functions; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index ce21e132980..8af3448ca6f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -21,22 +21,32 @@ import java.util.stream.Collectors; * * @author bratseth */ - @Beta +@Beta public class TensorFunctionNode extends CompositeNode { private final TensorFunction function; - + public TensorFunctionNode(TensorFunction function) { this.function = function; } + /** Returns the tensor function wrapped by this */ + public TensorFunction function() { return function; } + @Override public List<ExpressionNode> children() { return function.functionArguments().stream() - .map(f -> ((TensorFunctionExpressionNode)f).expression) + .map(this::toExpressionNode) .collect(Collectors.toList()); } + private ExpressionNode toExpressionNode(TensorFunction f) { + if (f instanceof TensorFunctionExpressionNode) + return ((TensorFunctionExpressionNode)f).expression; + else + return new TensorFunctionNode(f); + } + @Override public CompositeNode setChildren(List<ExpressionNode> children) { List<TensorFunction> wrappedChildren = children.stream() @@ -50,7 +60,7 @@ public class TensorFunctionNode extends CompositeNode { // Serialize as primitive return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this)); } - + @Override public Value evaluate(Context context) { return new TensorValue(function.evaluate(context)); @@ -59,8 +69,8 @@ public class TensorFunctionNode extends CompositeNode { public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) { return new TensorFunctionExpressionNode(node); } - - /** + + /** * A tensor function implemented by an expression. * This allows us to pass expressions as tensor function arguments. */ @@ -68,13 +78,13 @@ public class TensorFunctionNode extends CompositeNode { /** An expression which produces a tensor */ private final ExpressionNode expression; - + public TensorFunctionExpressionNode(ExpressionNode expression) { this.expression = expression; } - + @Override - public List<TensorFunction> functionArguments() { + public List<TensorFunction> functionArguments() { if (expression instanceof CompositeNode) return ((CompositeNode)expression).children().stream() .map(TensorFunctionExpressionNode::new) @@ -108,7 +118,7 @@ public class TensorFunctionNode extends CompositeNode { public String toString() { return toString(ExpressionNodeToStringContext.empty); } - + @Override public String toString(ToStringContext c) { if (c instanceof ExpressionNodeToStringContext) { @@ -121,14 +131,14 @@ public class TensorFunctionNode extends CompositeNode { } } - + /** Allows passing serialization context arguments through TensorFunctions */ private static class ExpressionNodeToStringContext implements ToStringContext { - + final SerializationContext context; final Deque<String> path; final CompositeNode parent; - + public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(null, null, null); public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 7821ab88b86..541738db8e0 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -467,7 +467,7 @@ ExpressionNode tensorGenerate() : } { <TENSOR> type = tensorTypeArgument() <LBRACE> generator = expression() <RBRACE> - { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asIntegerListToDoubleOperator())); } + { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asLongListToDoubleOperator())); } } ExpressionNode tensorRange() : diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py new file mode 100644 index 00000000000..a1861a1c981 --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py @@ -0,0 +1,89 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""A very simple MNIST classifier. + +See extensive documentation at +https://www.tensorflow.org/get_started/mnist/beginners +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +from tensorflow.examples.tutorials.mnist import input_data + +import tensorflow as tf + +FLAGS = None + + +def main(_): + # Import data + mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) + + # Create the model + x = tf.placeholder(tf.float32, [None, 784]) + W = tf.Variable(tf.zeros([784, 10])) + b = tf.Variable(tf.zeros([10])) + y = tf.matmul(x, W) + b + + # Define loss and optimizer + y_ = tf.placeholder(tf.float32, [None, 10]) + + # The raw formulation of cross-entropy, + # + # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)), + # reduction_indices=[1])) + # + # can be numerically unstable. + # + # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw + # outputs of 'y', and then average across the batch. + cross_entropy = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) + train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) + + sess = tf.InteractiveSession() + tf.global_variables_initializer().run() + # Train + for _ in range(1000): + batch_xs, batch_ys = mnist.train.next_batch(100) + sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) + + # Test trained model + correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + print(sess.run(accuracy, feed_dict={x: mnist.test.images, + y_: mnist.test.labels})) + + # Save the model + export_path = "saved" + print('Exporting trained model to ', export_path) + builder = tf.saved_model.builder.SavedModelBuilder(export_path) + signature = tf.saved_model.signature_def_utils.predict_signature_def(inputs = {'x':x}, outputs = {'y':y}) + builder.add_meta_graph_and_variables(sess, + [tf.saved_model.tag_constants.SERVING], + signature_def_map={'serving_default':signature}) + builder.save(as_text=True) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data', + help='Directory for storing input data') + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt new file mode 100644 index 00000000000..8100dfd594d --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt @@ -0,0 +1,5039 @@ +saved_model_schema_version: 1 +meta_graphs { + meta_info_def { + stripped_op_list { + op { + name: "Add" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_STRING + } + } + } + } + op { + name: "ApplyGradientDescent" + input_arg { + name: "var" + type_attr: "T" + is_ref: true + } + input_arg { + name: "alpha" + type_attr: "T" + } + input_arg { + name: "delta" + type_attr: "T" + } + output_arg { + name: "out" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: false + } + } + } + op { + name: "ArgMax" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "dimension" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "output_type" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "output_type" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Assign" + input_arg { + name: "ref" + type_attr: "T" + is_ref: true + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output_ref" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + } + attr { + name: "validate_shape" + type: "bool" + default_value { + b: true + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: true + } + } + allows_uninitialized_input: true + } + op { + name: "BroadcastGradientArgs" + input_arg { + name: "s0" + type_attr: "T" + } + input_arg { + name: "s1" + type_attr: "T" + } + output_arg { + name: "r0" + type_attr: "T" + } + output_arg { + name: "r1" + type_attr: "T" + } + attr { + name: "T" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Cast" + input_arg { + name: "x" + type_attr: "SrcT" + } + output_arg { + name: "y" + type_attr: "DstT" + } + attr { + name: "SrcT" + type: "type" + } + attr { + name: "DstT" + type: "type" + } + } + op { + name: "ConcatV2" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + input_arg { + name: "axis" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 2 + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Const" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "value" + type: "tensor" + } + attr { + name: "dtype" + type: "type" + } + } + op { + name: "Equal" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type: DT_BOOL + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_QUINT8 + type: DT_QINT8 + type: DT_QINT32 + type: DT_STRING + type: DT_BOOL + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "ExpandDims" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "dim" + type_attr: "Tdim" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tdim" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Fill" + input_arg { + name: "dims" + type: DT_INT32 + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "FloorDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Identity" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "MatMul" + input_arg { + name: "a" + type_attr: "T" + } + input_arg { + name: "b" + type_attr: "T" + } + output_arg { + name: "product" + type_attr: "T" + } + attr { + name: "transpose_a" + type: "bool" + default_value { + b: false + } + } + attr { + name: "transpose_b" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Maximum" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + is_commutative: true + } + op { + name: "Mean" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "MergeV2Checkpoints" + input_arg { + name: "checkpoint_prefixes" + type: DT_STRING + } + input_arg { + name: "destination_prefix" + type: DT_STRING + } + attr { + name: "delete_old_dirs" + type: "bool" + default_value { + b: true + } + } + is_stateful: true + } + op { + name: "Mul" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "NoOp" + } + op { + name: "Pack" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + } + attr { + name: "axis" + type: "int" + default_value { + i: 0 + } + } + } + op { + name: "Placeholder" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + } + op { + name: "Prod" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RealDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Reshape" + input_arg { + name: "tensor" + type_attr: "T" + } + input_arg { + name: "shape" + type_attr: "Tshape" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tshape" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RestoreV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + output_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "SaveV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + input_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "Shape" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "out_type" + } + attr { + name: "T" + type: "type" + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "ShardedFilename" + input_arg { + name: "basename" + type: DT_STRING + } + input_arg { + name: "shard" + type: DT_INT32 + } + input_arg { + name: "num_shards" + type: DT_INT32 + } + output_arg { + name: "filename" + type: DT_STRING + } + } + op { + name: "Slice" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "begin" + type_attr: "Index" + } + input_arg { + name: "size" + type_attr: "Index" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Index" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "SoftmaxCrossEntropyWithLogits" + input_arg { + name: "features" + type_attr: "T" + } + input_arg { + name: "labels" + type_attr: "T" + } + output_arg { + name: "loss" + type_attr: "T" + } + output_arg { + name: "backprop" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + } + op { + name: "StringJoin" + input_arg { + name: "inputs" + type: DT_STRING + number_attr: "N" + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "separator" + type: "string" + default_value { + s: "" + } + } + } + op { + name: "Sub" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Sum" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Tile" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "multiples" + type_attr: "Tmultiples" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tmultiples" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "VariableV2" + output_arg { + name: "ref" + type_attr: "dtype" + is_ref: true + } + attr { + name: "shape" + type: "shape" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true + } + op { + name: "ZerosLike" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + } + tags: "serve" + tensorflow_version: "1.4.1" + tensorflow_git_version: "v1.4.0-19-ga52c8d9b01" + } + graph_def { + node { + name: "Placeholder" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + node { + name: "zeros" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "Variable" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "Variable/Assign" + op: "Assign" + input: "Variable" + input: "zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "Variable/read" + op: "Identity" + input: "Variable" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "zeros_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 10 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "Variable_1" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "Variable_1/Assign" + op: "Assign" + input: "Variable_1" + input: "zeros_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "Variable_1/read" + op: "Identity" + input: "Variable_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "MatMul" + op: "MatMul" + input: "Placeholder" + input: "Variable/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "add" + op: "Add" + input: "MatMul" + input: "Variable_1/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "Placeholder_1" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + node { + name: "Rank" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape" + op: "Shape" + input: "add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Rank_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape_1" + op: "Shape" + input: "add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Sub/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub" + op: "Sub" + input: "Rank_1" + input: "Sub/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice/begin" + op: "Pack" + input: "Sub" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice/size" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "Slice" + op: "Slice" + input: "Shape_1" + input: "Slice/begin" + input: "Slice/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "concat/values_0" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } + } + node { + name: "concat/axis" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "concat" + op: "ConcatV2" + input: "concat/values_0" + input: "Slice" + input: "concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + } + node { + name: "Reshape" + op: "Reshape" + input: "add" + input: "concat" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Rank_2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape_2" + op: "Shape" + input: "Placeholder_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Sub_1/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub_1" + op: "Sub" + input: "Rank_2" + input: "Sub_1/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice_1/begin" + op: "Pack" + input: "Sub_1" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice_1/size" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "Slice_1" + op: "Slice" + input: "Shape_2" + input: "Slice_1/begin" + input: "Slice_1/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "concat_1/values_0" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } + } + node { + name: "concat_1/axis" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "concat_1" + op: "ConcatV2" + input: "concat_1/values_0" + input: "Slice_1" + input: "concat_1/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + } + node { + name: "Reshape_1" + op: "Reshape" + input: "Placeholder_1" + input: "concat_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "SoftmaxCrossEntropyWithLogits" + op: "SoftmaxCrossEntropyWithLogits" + input: "Reshape" + input: "Reshape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Sub_2/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub_2" + op: "Sub" + input: "Rank" + input: "Sub_2/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice_2/begin" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Slice_2/size" + op: "Pack" + input: "Sub_2" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice_2" + op: "Slice" + input: "Shape" + input: "Slice_2/begin" + input: "Slice_2/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Reshape_2" + op: "Reshape" + input: "SoftmaxCrossEntropyWithLogits" + input: "Slice_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Mean" + op: "Mean" + input: "Reshape_2" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "gradients/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } + } + node { + name: "gradients/Fill" + op: "Fill" + input: "gradients/Shape" + input: "gradients/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/Reshape/shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "gradients/Mean_grad/Reshape" + op: "Reshape" + input: "gradients/Fill" + input: "gradients/Mean_grad/Reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Shape" + op: "Shape" + input: "Reshape_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Mean_grad/Tile" + op: "Tile" + input: "gradients/Mean_grad/Reshape" + input: "gradients/Mean_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tmultiples" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Shape_1" + op: "Shape" + input: "Reshape_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Mean_grad/Shape_2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "gradients/Mean_grad/Prod" + op: "Prod" + input: "gradients/Mean_grad/Shape_1" + input: "gradients/Mean_grad/Const" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Mean_grad/Const_1" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "gradients/Mean_grad/Prod_1" + op: "Prod" + input: "gradients/Mean_grad/Shape_2" + input: "gradients/Mean_grad/Const_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Mean_grad/Maximum/y" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "gradients/Mean_grad/Maximum" + op: "Maximum" + input: "gradients/Mean_grad/Prod_1" + input: "gradients/Mean_grad/Maximum/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/floordiv" + op: "FloorDiv" + input: "gradients/Mean_grad/Prod" + input: "gradients/Mean_grad/Maximum" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/Cast" + op: "Cast" + input: "gradients/Mean_grad/floordiv" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/truediv" + op: "RealDiv" + input: "gradients/Mean_grad/Tile" + input: "gradients/Mean_grad/Cast" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Reshape_2_grad/Shape" + op: "Shape" + input: "SoftmaxCrossEntropyWithLogits" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Reshape_2_grad/Reshape" + op: "Reshape" + input: "gradients/Mean_grad/truediv" + input: "gradients/Reshape_2_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/zeros_like" + op: "ZerosLike" + input: "SoftmaxCrossEntropyWithLogits:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: -1 + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + op: "ExpandDims" + input: "gradients/Reshape_2_grad/Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + op: "Mul" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + input: "SoftmaxCrossEntropyWithLogits:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Reshape_grad/Shape" + op: "Shape" + input: "add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Reshape_grad/Reshape" + op: "Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + input: "gradients/Reshape_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/add_grad/Shape" + op: "Shape" + input: "MatMul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/add_grad/Shape_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 10 + } + } + } + } + node { + name: "gradients/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/add_grad/Shape" + input: "gradients/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/add_grad/Sum" + op: "Sum" + input: "gradients/Reshape_grad/Reshape" + input: "gradients/add_grad/BroadcastGradientArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/add_grad/Reshape" + op: "Reshape" + input: "gradients/add_grad/Sum" + input: "gradients/add_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/add_grad/Sum_1" + op: "Sum" + input: "gradients/Reshape_grad/Reshape" + input: "gradients/add_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/add_grad/Reshape_1" + op: "Reshape" + input: "gradients/add_grad/Sum_1" + input: "gradients/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/add_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/add_grad/Reshape" + input: "^gradients/add_grad/Reshape_1" + } + node { + name: "gradients/add_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/add_grad/Reshape" + input: "^gradients/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/add_grad/Reshape_1" + input: "^gradients/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/MatMul_grad/MatMul" + op: "MatMul" + input: "gradients/add_grad/tuple/control_dependency" + input: "Variable/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } + } + node { + name: "gradients/MatMul_grad/MatMul_1" + op: "MatMul" + input: "Placeholder" + input: "gradients/add_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "gradients/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/MatMul_grad/MatMul" + input: "^gradients/MatMul_grad/MatMul_1" + } + node { + name: "gradients/MatMul_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/MatMul_grad/MatMul" + input: "^gradients/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "gradients/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/MatMul_grad/MatMul_1" + input: "^gradients/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "GradientDescent/learning_rate" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } + } + } + } + node { + name: "GradientDescent/update_Variable/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "Variable" + input: "GradientDescent/learning_rate" + input: "gradients/MatMul_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "GradientDescent/update_Variable_1/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "Variable_1" + input: "GradientDescent/learning_rate" + input: "gradients/add_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "GradientDescent" + op: "NoOp" + input: "^GradientDescent/update_Variable/ApplyGradientDescent" + input: "^GradientDescent/update_Variable_1/ApplyGradientDescent" + } + node { + name: "init" + op: "NoOp" + input: "^Variable/Assign" + input: "^Variable_1/Assign" + } + node { + name: "ArgMax/dimension" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "ArgMax" + op: "ArgMax" + input: "add" + input: "ArgMax/dimension" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_type" + value { + type: DT_INT64 + } + } + } + node { + name: "ArgMax_1/dimension" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "ArgMax_1" + op: "ArgMax" + input: "Placeholder_1" + input: "ArgMax_1/dimension" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_type" + value { + type: DT_INT64 + } + } + } + node { + name: "Equal" + op: "Equal" + input: "ArgMax" + input: "ArgMax_1" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Cast_1" + op: "Cast" + input: "Equal" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_BOOL + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Const_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Mean_1" + op: "Mean" + input: "Cast_1" + input: "Const_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "save/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "model" + } + } + } + } + node { + name: "save/StringJoin/inputs_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "_temp_6ca9fa5171ed4237a2fbcc27277e2864/part" + } + } + } + } + node { + name: "save/StringJoin" + op: "StringJoin" + input: "save/Const" + input: "save/StringJoin/inputs_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "separator" + value { + s: "" + } + } + } + node { + name: "save/num_shards" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "save/ShardedFilename/shard" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "save/ShardedFilename" + op: "ShardedFilename" + input: "save/StringJoin" + input: "save/ShardedFilename/shard" + input: "save/num_shards" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/SaveV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "Variable" + string_val: "Variable_1" + } + } + } + } + node { + name: "save/SaveV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "" + string_val: "" + } + } + } + } + node { + name: "save/SaveV2" + op: "SaveV2" + input: "save/ShardedFilename" + input: "save/SaveV2/tensor_names" + input: "save/SaveV2/shape_and_slices" + input: "Variable" + input: "Variable_1" + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + } + node { + name: "save/control_dependency" + op: "Identity" + input: "save/ShardedFilename" + input: "^save/SaveV2" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@save/ShardedFilename" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/MergeV2Checkpoints/checkpoint_prefixes" + op: "Pack" + input: "save/ShardedFilename" + input: "^save/control_dependency" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "save/MergeV2Checkpoints" + op: "MergeV2Checkpoints" + input: "save/MergeV2Checkpoints/checkpoint_prefixes" + input: "save/Const" + attr { + key: "delete_old_dirs" + value { + b: true + } + } + } + node { + name: "save/Identity" + op: "Identity" + input: "save/Const" + input: "^save/control_dependency" + input: "^save/MergeV2Checkpoints" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/RestoreV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "Variable" + } + } + } + } + node { + name: "save/RestoreV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2/tensor_names" + input: "save/RestoreV2/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign" + op: "Assign" + input: "Variable" + input: "save/RestoreV2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_1/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "Variable_1" + } + } + } + } + node { + name: "save/RestoreV2_1/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_1" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_1/tensor_names" + input: "save/RestoreV2_1/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_1" + op: "Assign" + input: "Variable_1" + input: "save/RestoreV2_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/restore_shard" + op: "NoOp" + input: "^save/Assign" + input: "^save/Assign_1" + } + node { + name: "save/restore_all" + op: "NoOp" + input: "^save/restore_shard" + } + versions { + producer: 24 + } + } + saver_def { + filename_tensor_name: "save/Const:0" + save_tensor_name: "save/Identity:0" + restore_op_name: "save/restore_all" + max_to_keep: 5 + sharded: true + keep_checkpoint_every_n_hours: 10000.0 + version: V2 + } + collection_def { + key: "train_op" + value { + node_list { + value: "GradientDescent" + } + } + } + collection_def { + key: "trainable_variables" + value { + bytes_list { + value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0" + value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0" + } + } + } + collection_def { + key: "variables" + value { + bytes_list { + value: "\n\nVariable:0\022\017Variable/Assign\032\017Variable/read:02\007zeros:0" + value: "\n\014Variable_1:0\022\021Variable_1/Assign\032\021Variable_1/read:02\tzeros_1:0" + } + } + } + signature_def { + key: "serving_default" + value { + inputs { + key: "x" + value { + name: "Placeholder:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + outputs { + key: "y" + value { + name: "add:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + method_name: "tensorflow/serving/predict" + } + } +} diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 Binary files differnew file mode 100644 index 00000000000..8474aa0a04c --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index Binary files differnew file mode 100644 index 00000000000..cfcdac20409 --- /dev/null +++ b/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 82e5d0cfe5b..3aa2d144f1f 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -4,9 +4,14 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.rule.*; -import com.yahoo.tensor.Tensor; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.IfNode; import org.junit.Test; + import static org.junit.Assert.assertEquals; /** @@ -83,7 +88,7 @@ public class EvaluationTestCase { tester.assertEvaluates(0, "sin(0)"); tester.assertEvaluates(1, "cos(0)"); tester.assertEvaluates(8, "pow(4/2,min(cos(0)*3,5))"); - + // Random feature (which is also a tensor function) (We expect to be able to parse it and look up a zero) tester.assertEvaluates(0, "random(1)"); tester.assertEvaluates(0, "random(foo)"); @@ -152,7 +157,7 @@ public class EvaluationTestCase { "tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", "!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }"); - + // -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence) tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "abs(tensor0)", "{ {x:0}:1, {x:1}:-2 }"); tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "acos(tensor0)", "{ {x:0}:1, {x:1}:1 }"); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java index ee2b1c147e3..ba0db4de5e1 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java @@ -34,7 +34,7 @@ public class EvaluationTester { } // TODO: Test both bound and unbound indexed - public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors, + public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors, String ... tensorArgumentStrings) { MapContext context = defaultContext.thawedCopy(); int argumentIndex = 0; @@ -46,7 +46,7 @@ public class EvaluationTester { argument = Tensor.from(typeFrom(argumentString, mappedTensors), argumentString); context.put("tensor" + (argumentIndex++), new TensorValue(argument)); } - return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context, + return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context, mappedTensors ? "Mapped tensors" : "Indexed tensors"); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java new file mode 100644 index 00000000000..0370fc7fc94 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java @@ -0,0 +1,118 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Ignore; +import org.junit.Test; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; + +import java.nio.FloatBuffer; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * @author bratseth + */ +public class Mnist_SoftmaxTestCase { + + @Ignore + @Test + public void testImporting() { + String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved"; + SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); + ImportResult result = new TensorFlowImporter().importModel(model); + + // Check logged messages + result.warnings().forEach(System.err::println); + assertEquals(0, result.warnings().size()); + + // Check constants + assertEquals(2, result.constants().size()); + + Tensor constant0 = result.constants().get("Variable"); + assertNotNull(constant0); + assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), + constant0.type()); + assertEquals(7840, constant0.size()); + + Tensor constant1 = result.constants().get("Variable_1"); + assertNotNull(constant1); + assertEquals(new TensorType.Builder().indexed("d0", 10).build(), + constant1.type()); + assertEquals(10, constant1.size()); + + // Check signatures + assertEquals(1, result.signatures().size()); + ImportResult.Signature signature = result.signatures().get("serving_default"); + assertNotNull(signature); + + // ... signature inputs + assertEquals(1, signature.inputs().size()); + TensorType argument0 = signature.inputArgument("x"); + assertNotNull(argument0); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); + + // ... signature outputs + assertEquals(1, signature.outputs().size()); + RankingExpression output = signature.outputExpression("y"); + assertNotNull(output); + assertEquals("add", output.getName()); + assertEquals("" + + "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " + + "rename(constant(Variable_1), d0, d1), " + + "f(a,b)(a + b))", + toNonPrimitiveString(output)); + + // Test execution + assertEqualResult(model, result, "Variable/read"); + assertEqualResult(model, result, "Variable_1/read"); + assertEqualResult(model, result, "MatMul"); + assertEqualResult(model, result, "add"); + } + + private void assertEqualResult(SavedModelBundle model, ImportResult result, String operationName) { + Tensor tfResult = tensorFlowExecute(model, operationName); + Context context = contextFrom(result); + Tensor placeholder = placeholderArgument(); + context.put("Placeholder", new TensorValue(placeholder)); + Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor(); + assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult); + } + + private Tensor tensorFlowExecute(SavedModelBundle model, String operationName) { + Session.Runner runner = model.session().runner(); + org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784)); + runner.feed("Placeholder", placeholder); + List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run(); + assertEquals(1, results.size()); + return new TensorConverter().toVespaTensor(results.get(0)); + } + + private Context contextFrom(ImportResult result) { + MapContext context = new MapContext(); + result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + return context; + } + + private String toNonPrimitiveString(RankingExpression expression) { + // toString on the wrapping expression will map to primitives, which is harder to read + return ((TensorFunctionNode)expression.getRoot()).function().toString(); + } + + private Tensor placeholderArgument() { + int size = 784; + Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", size).build()); + for (int i = 0; i < size; i++) + b.cell(0, 0, i); + return b.build(); + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java index dde9d4bf21e..1960c1fe876 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java @@ -59,7 +59,7 @@ public class TensorConformanceTest { try { ObjectMapper mapper = new ObjectMapper(); JsonNode node = mapper.readTree(test); - + if (node.has("num_tests")) { Assert.assertEquals(node.get("num_tests").asInt(), count); return true; @@ -67,7 +67,7 @@ public class TensorConformanceTest { if (!node.has("expression")) { return true; // ignore } - + String expression = node.get("expression").asText(); MapContext context = getInput(node.get("inputs")); Tensor expect = getTensor(node.get("result").get("expect").asText()); diff --git a/searchlib/src/vespa/searchlib/attribute/postingchange.cpp b/searchlib/src/vespa/searchlib/attribute/postingchange.cpp index 9957d162d9d..702ff0fc5cf 100644 --- a/searchlib/src/vespa/searchlib/attribute/postingchange.cpp +++ b/searchlib/src/vespa/searchlib/attribute/postingchange.cpp @@ -6,6 +6,7 @@ #include "postinglistattribute.h" #include <vespa/searchlib/common/growablebitvector.h> #include <vespa/vespalib/util/array.hpp> +#include <vespa/vespalib/stllike/hash_map.hpp> namespace search { diff --git a/searchlib/src/vespa/searchlib/common/foregroundtaskexecutor.cpp b/searchlib/src/vespa/searchlib/common/foregroundtaskexecutor.cpp index 990117a71ce..1ab3c6b8b51 100644 --- a/searchlib/src/vespa/searchlib/common/foregroundtaskexecutor.cpp +++ b/searchlib/src/vespa/searchlib/common/foregroundtaskexecutor.cpp @@ -2,6 +2,7 @@ #include "foregroundtaskexecutor.h" #include <vespa/vespalib/util/threadstackexecutor.h> +#include <vespa/vespalib/stllike/hash_map.hpp> using vespalib::ThreadStackExecutor; diff --git a/searchlib/src/vespa/searchlib/common/sequencedtaskexecutor.cpp b/searchlib/src/vespa/searchlib/common/sequencedtaskexecutor.cpp index 45004db2615..446c9ec39ec 100644 --- a/searchlib/src/vespa/searchlib/common/sequencedtaskexecutor.cpp +++ b/searchlib/src/vespa/searchlib/common/sequencedtaskexecutor.cpp @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "sequencedtaskexecutor.h" +#include <vespa/vespalib/stllike/hash_map.hpp> using vespalib::BlockingThreadStackExecutor; diff --git a/searchlib/src/vespa/searchlib/docstore/visitcache.cpp b/searchlib/src/vespa/searchlib/docstore/visitcache.cpp index 8f73c9862ae..7e881d8de76 100644 --- a/searchlib/src/vespa/searchlib/docstore/visitcache.cpp +++ b/searchlib/src/vespa/searchlib/docstore/visitcache.cpp @@ -209,8 +209,7 @@ VisitCache::Cache::locateAndInvalidateOtherSubsets(const LockGuard & cacheGuard, CompressedBlobSet VisitCache::read(const IDocumentStore::LidVector & lids) const { - KeySet key(lids); - return _cache->readSet(lids); + return _cache->readSet(KeySet(lids)); } void diff --git a/searchlib/src/vespa/searchlib/docstore/visitcache.h b/searchlib/src/vespa/searchlib/docstore/visitcache.h index 1bf867c5580..effc6c19a21 100644 --- a/searchlib/src/vespa/searchlib/docstore/visitcache.h +++ b/searchlib/src/vespa/searchlib/docstore/visitcache.h @@ -20,7 +20,7 @@ class KeySet { public: KeySet() : _keys() { } KeySet(uint32_t key); - KeySet(const IDocumentStore::LidVector &keys); + explicit KeySet(const IDocumentStore::LidVector &keys); uint32_t hash() const { return _keys.empty() ? 0 : _keys[0]; } bool operator==(const KeySet &rhs) const { return _keys == rhs._keys; } bool operator<(const KeySet &rhs) const { return _keys < rhs._keys; } diff --git a/staging_vespalib/src/vespa/vespalib/stllike/lrucache_map.hpp b/staging_vespalib/src/vespa/vespalib/stllike/lrucache_map.hpp index fe57de093dd..61147229497 100644 --- a/staging_vespalib/src/vespa/vespalib/stllike/lrucache_map.hpp +++ b/staging_vespalib/src/vespa/vespalib/stllike/lrucache_map.hpp @@ -110,6 +110,7 @@ void lrucache_map<P>::erase(const K & key) { internal_iterator it = HashTable::find(key); if (it != HashTable::end()) { + next_t h = HashTable::hash(key); onRemove(key); LV & v = it->second; if (v._prev != LinkedValueBase::npos) { @@ -122,7 +123,7 @@ lrucache_map<P>::erase(const K & key) { } else { _tail = v._prev; } - HashTable::erase(*this, it); + HashTable::erase(*this, h, it); } } @@ -202,7 +203,7 @@ lrucache_map<P>::removeOld() { { _tail = last->second._prev; HashTable::getByInternalIndex(_tail).second._next = LinkedValueBase::npos; - HashTable::erase(*this, HashTable::find(last->first)); + HashTable::erase(*this, HashTable::hash(last->first), HashTable::find(last->first)); } } } diff --git a/storage/src/tests/distributor/blockingoperationstartertest.cpp b/storage/src/tests/distributor/blockingoperationstartertest.cpp index c2fdc25cebf..0160f5c9e51 100644 --- a/storage/src/tests/distributor/blockingoperationstartertest.cpp +++ b/storage/src/tests/distributor/blockingoperationstartertest.cpp @@ -60,7 +60,7 @@ BlockingOperationStarterTest::testOperationNotBlockedWhenNoMessagesPending() { CPPUNIT_ASSERT(_operationStarter->start(createMockOperation(), OperationStarter::Priority(0))); - CPPUNIT_ASSERT_EQUAL(std::string("Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri 0\n"), + CPPUNIT_ASSERT_EQUAL(std::string("Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri 0\n"), _starterImpl->toString()); } diff --git a/storage/src/tests/distributor/bucketdbupdatertest.cpp b/storage/src/tests/distributor/bucketdbupdatertest.cpp index b9e33ea8d26..ff442114c4c 100644 --- a/storage/src/tests/distributor/bucketdbupdatertest.cpp +++ b/storage/src/tests/distributor/bucketdbupdatertest.cpp @@ -4,6 +4,7 @@ #include <iomanip> #include <vespa/storageapi/message/persistence.h> #include <vespa/storage/distributor/bucketdbupdater.h> +#include <vespa/storage/distributor/distributormetricsset.h> #include <vespa/storage/distributor/pending_bucket_space_db_transition.h> #include <vespa/storage/distributor/outdated_nodes_map.h> #include <vespa/vespalib/io/fileutil.h> @@ -22,8 +23,7 @@ using namespace storage::lib; using document::test::makeDocumentBucket; using document::test::makeBucketSpace; -namespace storage { -namespace distributor { +namespace storage::distributor { class BucketDBUpdaterTest : public CppUnit::TestFixture, public DistributorTestUtil @@ -2499,5 +2499,4 @@ void BucketDBUpdaterTest::batch_update_from_distributor_change_does_not_mark_div "0:5/1/2/3|1:5/7/8/9", true)); } -} // distributor -} // storage +} diff --git a/storage/src/tests/distributor/externaloperationhandlertest.cpp b/storage/src/tests/distributor/externaloperationhandlertest.cpp index 683352e6b09..a0b8cd424ac 100644 --- a/storage/src/tests/distributor/externaloperationhandlertest.cpp +++ b/storage/src/tests/distributor/externaloperationhandlertest.cpp @@ -2,15 +2,14 @@ #include <tests/distributor/distributortestutil.h> #include <vespa/storage/distributor/externaloperationhandler.h> -#include <vespa/storage/distributor/operation_sequencer.h> -#include <vespa/storageapi/message/persistence.h> #include <vespa/storage/distributor/distributor.h> +#include <vespa/storage/distributor/distributormetricsset.h> +#include <vespa/storageapi/message/persistence.h> #include <vespa/document/test/make_document_bucket.h> using document::test::makeDocumentBucket; -namespace storage { -namespace distributor { +namespace storage::distributor { class ExternalOperationHandlerTest : public CppUnit::TestFixture, public DistributorTestUtil @@ -471,5 +470,4 @@ void ExternalOperationHandlerTest::sequencing_can_be_explicitly_config_disabled( // pseudo-locks in the sequencer. I.e. if we get a RemoveLocation with id.user==123456, this // prevents any handles from being acquired to any GID under location BucketId(32, 123456). -} // distributor -} // storage +} diff --git a/storage/src/tests/distributor/getoperationtest.cpp b/storage/src/tests/distributor/getoperationtest.cpp index 8bb8e24c17a..80c093dea87 100644 --- a/storage/src/tests/distributor/getoperationtest.cpp +++ b/storage/src/tests/distributor/getoperationtest.cpp @@ -5,16 +5,13 @@ #include <vespa/document/repo/documenttyperepo.h> #include <vespa/storage/distributor/externaloperationhandler.h> #include <vespa/storage/distributor/distributor.h> +#include <vespa/storage/distributor/distributormetricsset.h> #include <tests/distributor/distributortestutil.h> #include <vespa/storageapi/message/persistence.h> -#include <tests/common/dummystoragelink.h> #include <vespa/document/test/make_document_bucket.h> -#include <vespa/vdstestlib/cppunit/macros.h> #include <vespa/vespalib/testkit/testapp.h> #include <vespa/config/helper/configgetter.hpp> #include <iomanip> -#include <iostream> -#include <memory> #include <vespa/storage/distributor/operations/external/getoperation.h> using std::shared_ptr; @@ -23,8 +20,7 @@ using document::DocumenttypesConfig; using config::FileSpec; using document::test::makeDocumentBucket; -namespace storage { -namespace distributor { +namespace storage::distributor { class GetOperationTest : public CppUnit::TestFixture, public DistributorTestUtil { CPPUNIT_TEST_SUITE(GetOperationTest); @@ -568,5 +564,4 @@ GetOperationTest::canGetDocumentsWhenAllReplicaNodesRetired() _sender.getCommands(true)); } -} // distributor -} // storage +} diff --git a/storage/src/tests/distributor/idealstatemanagertest.cpp b/storage/src/tests/distributor/idealstatemanagertest.cpp index 0c695f9a3d4..bca15d702f5 100644 --- a/storage/src/tests/distributor/idealstatemanagertest.cpp +++ b/storage/src/tests/distributor/idealstatemanagertest.cpp @@ -143,9 +143,9 @@ IdealStateManagerTest::testClearActiveOnNodeDown() } CPPUNIT_ASSERT_EQUAL( - std::string("setbucketstate to [0] Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)) (pri 100)\n" - "setbucketstate to [0] Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000002)) (pri 100)\n" - "setbucketstate to [0] Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000003)) (pri 100)\n"), + std::string("setbucketstate to [0] Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)) (pri 100)\n" + "setbucketstate to [0] Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000002)) (pri 100)\n" + "setbucketstate to [0] Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000003)) (pri 100)\n"), _distributor->getActiveIdealStateOperations()); setSystemState(lib::ClusterState("distributor:1 storage:3 .0.s:d")); @@ -169,19 +169,19 @@ IdealStateManagerTest::testRecheckWhenActive() tick(); CPPUNIT_ASSERT_EQUAL( - std::string("setbucketstate to [0] Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)) (pri 100)\n"), + std::string("setbucketstate to [0] Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)) (pri 100)\n"), _distributor->getActiveIdealStateOperations()); tick(); CPPUNIT_ASSERT_EQUAL( - std::string("setbucketstate to [0] Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)) (pri 100)\n"), + std::string("setbucketstate to [0] Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)) (pri 100)\n"), _distributor->getActiveIdealStateOperations()); tick(); CPPUNIT_ASSERT_EQUAL( - std::string("setbucketstate to [0] Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)) (pri 100)\n"), + std::string("setbucketstate to [0] Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)) (pri 100)\n"), _distributor->getActiveIdealStateOperations()); } diff --git a/storage/src/tests/distributor/maintenanceschedulertest.cpp b/storage/src/tests/distributor/maintenanceschedulertest.cpp index 7e3d92053f8..db0347617f0 100644 --- a/storage/src/tests/distributor/maintenanceschedulertest.cpp +++ b/storage/src/tests/distributor/maintenanceschedulertest.cpp @@ -70,7 +70,7 @@ MaintenanceSchedulerTest::testOperationIsScheduled() { _priorityDb->setPriority(PrioritizedBucket(makeDocumentBucket(BucketId(16, 1)), Priority::MEDIUM)); _scheduler->tick(MaintenanceScheduler::NORMAL_SCHEDULING_MODE); - CPPUNIT_ASSERT_EQUAL(std::string("Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri 100\n"), + CPPUNIT_ASSERT_EQUAL(std::string("Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri 100\n"), _operationStarter->toString()); } @@ -89,9 +89,9 @@ MaintenanceSchedulerTest::testSuppressLowPrioritiesInEmergencyMode() _priorityDb->setPriority(PrioritizedBucket(makeDocumentBucket(BucketId(16, 2)), Priority::VERY_HIGH)); CPPUNIT_ASSERT_EQUAL(WaitTimeMs(0), _scheduler->tick(MaintenanceScheduler::RECOVERY_SCHEDULING_MODE)); CPPUNIT_ASSERT_EQUAL(WaitTimeMs(1), _scheduler->tick(MaintenanceScheduler::RECOVERY_SCHEDULING_MODE)); - CPPUNIT_ASSERT_EQUAL(std::string("Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000002)), pri 0\n"), + CPPUNIT_ASSERT_EQUAL(std::string("Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000002)), pri 0\n"), _operationStarter->toString()); - CPPUNIT_ASSERT_EQUAL(std::string("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri HIGH)\n"), + CPPUNIT_ASSERT_EQUAL(std::string("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri HIGH)\n"), _priorityDb->toString()); } @@ -102,7 +102,7 @@ MaintenanceSchedulerTest::testPriorityNotClearedIfOperationNotStarted() _operationStarter->setShouldStartOperations(false); WaitTimeMs waitMs(_scheduler->tick(MaintenanceScheduler::NORMAL_SCHEDULING_MODE)); CPPUNIT_ASSERT_EQUAL(WaitTimeMs(1), waitMs); - CPPUNIT_ASSERT_EQUAL(std::string("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri HIGH)\n"), + CPPUNIT_ASSERT_EQUAL(std::string("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri HIGH)\n"), _priorityDb->toString()); } diff --git a/storage/src/tests/distributor/messagesenderstub.h b/storage/src/tests/distributor/messagesenderstub.h index e5bae9c6702..b86863890a1 100644 --- a/storage/src/tests/distributor/messagesenderstub.h +++ b/storage/src/tests/distributor/messagesenderstub.h @@ -3,6 +3,7 @@ #include <vespa/storage/distributor/distributormessagesender.h> #include <cassert> +#include <vector> namespace storage { diff --git a/storage/src/tests/distributor/pendingmessagetrackertest.cpp b/storage/src/tests/distributor/pendingmessagetrackertest.cpp index cca55a11b38..7adadd226d7 100644 --- a/storage/src/tests/distributor/pendingmessagetrackertest.cpp +++ b/storage/src/tests/distributor/pendingmessagetrackertest.cpp @@ -254,7 +254,7 @@ PendingMessageTrackerTest::testSimple() CPPUNIT_ASSERT_CONTAIN( std::string( - "<b>Bucket(BucketSpace(0x0000000000000000), BucketId(0x40000000000004d2))</b>\n" + "<b>Bucket(BucketSpace(0x0000000000000001), BucketId(0x40000000000004d2))</b>\n" "<ul>\n" "<li><i>Node 0</i>: <b>1970-01-01 00:00:01</b> " "Remove(BucketId(0x40000000000004d2), " @@ -341,14 +341,14 @@ PendingMessageTrackerTest::testMultipleMessages() CPPUNIT_ASSERT_CONTAIN( std::string( - "<b>Bucket(BucketSpace(0x0000000000000000), BucketId(0x40000000000004d2))</b>\n" + "<b>Bucket(BucketSpace(0x0000000000000001), BucketId(0x40000000000004d2))</b>\n" "<ul>\n" "<li><i>Node 0</i>: <b>1970-01-01 00:00:01</b> Remove(BucketId(0x40000000000004d2), userdoc:footype:1234:0, timestamp 1000)</li>\n" "<li><i>Node 0</i>: <b>1970-01-01 00:00:01</b> Remove(BucketId(0x40000000000004d2), userdoc:footype:1234:2, timestamp 1002)</li>\n" "<li><i>Node 1</i>: <b>1970-01-01 00:00:01</b> Remove(BucketId(0x40000000000004d2), userdoc:footype:1234:1, timestamp 1001)</li>\n" "<li><i>Node 1</i>: <b>1970-01-01 00:00:01</b> Remove(BucketId(0x40000000000004d2), userdoc:footype:1234:3, timestamp 1003)</li>\n" "</ul>\n" - "<b>Bucket(BucketSpace(0x0000000000000000), BucketId(0x40000000000011d7))</b>\n" + "<b>Bucket(BucketSpace(0x0000000000000001), BucketId(0x40000000000011d7))</b>\n" "<ul>\n" "<li><i>Node 0</i>: <b>1970-01-01 00:00:01</b> Remove(BucketId(0x40000000000011d7), userdoc:footype:4567:0, timestamp 2000)</li>\n" "<li><i>Node 0</i>: <b>1970-01-01 00:00:01</b> Remove(BucketId(0x40000000000011d7), userdoc:footype:4567:2, timestamp 2002)</li>\n" diff --git a/storage/src/tests/distributor/putoperationtest.cpp b/storage/src/tests/distributor/putoperationtest.cpp index 7f54e163006..e621ef8645c 100644 --- a/storage/src/tests/distributor/putoperationtest.cpp +++ b/storage/src/tests/distributor/putoperationtest.cpp @@ -282,7 +282,7 @@ PutOperationTest::testNodeRemovedOnReply() CPPUNIT_ASSERT_EQUAL(std::string( "PutReply(doc:test:test, BucketId(0x0000000000000000), " "timestamp 100) ReturnCode(BUCKET_DELETED, " - "Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000008b13)) was deleted from nodes [0] " + "Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000008b13)) was deleted from nodes [0] " "after message was sent but before it was done. " "Sent to [1,0])"), _sender.getLastReply()); diff --git a/storage/src/tests/distributor/simplemaintenancescannertest.cpp b/storage/src/tests/distributor/simplemaintenancescannertest.cpp index 66a2d3efa6c..394df6024fd 100644 --- a/storage/src/tests/distributor/simplemaintenancescannertest.cpp +++ b/storage/src/tests/distributor/simplemaintenancescannertest.cpp @@ -92,7 +92,7 @@ void SimpleMaintenanceScannerTest::testPrioritizeSingleBucket() { addBucketToDb(1); - std::string expected("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri VERY_HIGH)\n"); + std::string expected("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri VERY_HIGH)\n"); auto scanResult = _scanner->scanNext(); CPPUNIT_ASSERT(!scanResult.isDone()); @@ -141,9 +141,9 @@ SimpleMaintenanceScannerTest::testPrioritizeMultipleBuckets() addBucketToDb(1); addBucketToDb(2); addBucketToDb(3); - std::string expected("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri VERY_HIGH)\n" - "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000002)), pri VERY_HIGH)\n" - "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000003)), pri VERY_HIGH)\n"); + std::string expected("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri VERY_HIGH)\n" + "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000002)), pri VERY_HIGH)\n" + "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000003)), pri VERY_HIGH)\n"); CPPUNIT_ASSERT(scanEntireDatabase(3)); CPPUNIT_ASSERT_EQUAL(sortLines(expected), @@ -168,8 +168,8 @@ SimpleMaintenanceScannerTest::testReset() addBucketToDb(3); CPPUNIT_ASSERT(scanEntireDatabase(2)); - std::string expected("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri VERY_HIGH)\n" - "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000003)), pri VERY_HIGH)\n"); + std::string expected("PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri VERY_HIGH)\n" + "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000003)), pri VERY_HIGH)\n"); CPPUNIT_ASSERT_EQUAL(expected, _priorityDb->toString()); addBucketToDb(2); @@ -179,9 +179,9 @@ SimpleMaintenanceScannerTest::testReset() _scanner->reset(); CPPUNIT_ASSERT(scanEntireDatabase(3)); - expected = "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri VERY_HIGH)\n" - "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000002)), pri VERY_HIGH)\n" - "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000003)), pri VERY_HIGH)\n"; + expected = "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri VERY_HIGH)\n" + "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000002)), pri VERY_HIGH)\n" + "PrioritizedBucket(Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000003)), pri VERY_HIGH)\n"; CPPUNIT_ASSERT_EQUAL(sortLines(expected), sortLines(_priorityDb->toString())); } diff --git a/storage/src/tests/distributor/throttlingoperationstartertest.cpp b/storage/src/tests/distributor/throttlingoperationstartertest.cpp index c3aebcafe06..c3290a8c0f6 100644 --- a/storage/src/tests/distributor/throttlingoperationstartertest.cpp +++ b/storage/src/tests/distributor/throttlingoperationstartertest.cpp @@ -70,7 +70,7 @@ ThrottlingOperationStarterTest::testOperationStartingIsForwardedToImplementation { CPPUNIT_ASSERT(_operationStarter->start(createMockOperation(), OperationStarter::Priority(0))); - CPPUNIT_ASSERT_EQUAL(std::string("Bucket(BucketSpace(0x0000000000000000), BucketId(0x4000000000000001)), pri 0\n"), + CPPUNIT_ASSERT_EQUAL(std::string("Bucket(BucketSpace(0x0000000000000001), BucketId(0x4000000000000001)), pri 0\n"), _starterImpl->toString()); } diff --git a/storage/src/tests/distributor/visitoroperationtest.cpp b/storage/src/tests/distributor/visitoroperationtest.cpp index 972ccf41bfe..17d1bc288ca 100644 --- a/storage/src/tests/distributor/visitoroperationtest.cpp +++ b/storage/src/tests/distributor/visitoroperationtest.cpp @@ -9,6 +9,7 @@ #include <vespa/storageapi/message/state.h> #include <vespa/storage/distributor/operations/external/visitoroperation.h> #include <vespa/storage/distributor/operations/external/visitororder.h> +#include <vespa/storage/distributor/distributormetricsset.h> #include <tests/distributor/distributortestutil.h> #include <vespa/storage/distributor/distributor.h> #include <tests/common/dummystoragelink.h> @@ -21,8 +22,7 @@ using namespace storage::lib; using namespace std::string_literals; using document::test::makeBucketSpace; -namespace storage { -namespace distributor { +namespace storage::distributor { class VisitorOperationTest : public CppUnit::TestFixture, public DistributorTestUtil { @@ -1674,5 +1674,4 @@ VisitorOperationTest::statistical_metrics_not_updated_on_wrong_distribution() CPPUNIT_ASSERT_EQUAL(0.0, defaultVisitorMetrics().latency.getCount()); } -} // distributor -} // storage +} diff --git a/storage/src/tests/persistence/splitbitdetectortest.cpp b/storage/src/tests/persistence/splitbitdetectortest.cpp index c20aae373ec..01baa8f4e98 100644 --- a/storage/src/tests/persistence/splitbitdetectortest.cpp +++ b/storage/src/tests/persistence/splitbitdetectortest.cpp @@ -8,6 +8,7 @@ #include <vespa/persistence/spi/test.h> #include <vespa/document/base/testdocman.h> #include <vespa/document/bucket/bucketidfactory.h> +#include <vespa/metrics/loadmetric.h> #include <algorithm> using storage::spi::test::makeSpiBucket; diff --git a/storage/src/tests/storageserver/CMakeLists.txt b/storage/src/tests/storageserver/CMakeLists.txt index 38fb0f6235a..95faf7e433e 100644 --- a/storage/src/tests/storageserver/CMakeLists.txt +++ b/storage/src/tests/storageserver/CMakeLists.txt @@ -5,6 +5,7 @@ vespa_add_library(storage_teststorageserver TEST bucketintegritycheckertest.cpp changedbucketownershiphandlertest.cpp communicationmanagertest.cpp + configurable_bucket_resolver_test.cpp documentapiconvertertest.cpp mergethrottlertest.cpp priorityconvertertest.cpp diff --git a/storage/src/tests/storageserver/configurable_bucket_resolver_test.cpp b/storage/src/tests/storageserver/configurable_bucket_resolver_test.cpp new file mode 100644 index 00000000000..3f121240065 --- /dev/null +++ b/storage/src/tests/storageserver/configurable_bucket_resolver_test.cpp @@ -0,0 +1,137 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/storage/storageserver/configurable_bucket_resolver.h> +#include <vespa/document/base/documentid.h> +#include <vespa/persistence/spi/fixed_bucket_spaces.h> +#include <cppunit/extensions/HelperMacros.h> + +namespace storage { + +using document::DocumentId; + +struct ConfigurableBucketResolverTest : CppUnit::TestFixture { + CPPUNIT_TEST_SUITE(ConfigurableBucketResolverTest); + CPPUNIT_TEST(bucket_space_from_name_is_defined_for_default_space); + CPPUNIT_TEST(bucket_space_from_name_is_defined_for_global_space); + CPPUNIT_TEST(bucket_space_from_name_throws_exception_for_unknown_space); + CPPUNIT_TEST(name_from_bucket_space_is_defined_for_default_space); + CPPUNIT_TEST(name_from_bucket_space_is_defined_for_global_space); + CPPUNIT_TEST(name_from_bucket_space_throws_exception_for_unknown_space); + CPPUNIT_TEST(known_bucket_space_is_resolved_from_document_id); + CPPUNIT_TEST(unknown_bucket_space_in_id_throws_exception); + CPPUNIT_TEST(can_create_resolver_from_bucket_space_config); + CPPUNIT_TEST_SUITE_END(); + + using BucketSpaceMapping = ConfigurableBucketResolver::BucketSpaceMapping; + + BucketSpaceMapping create_simple_mapping() { + return {{"foo", spi::FixedBucketSpaces::default_space()}, + {"bar", spi::FixedBucketSpaces::default_space()}, + {"baz", spi::FixedBucketSpaces::global_space()}}; + } + + ConfigurableBucketResolver create_empty_resolver() { + return ConfigurableBucketResolver({}); + } + + ConfigurableBucketResolver create_simple_resolver() { + return ConfigurableBucketResolver(create_simple_mapping()); + } + + void bucket_space_from_name_is_defined_for_default_space(); + void bucket_space_from_name_is_defined_for_global_space(); + void bucket_space_from_name_throws_exception_for_unknown_space(); + void name_from_bucket_space_is_defined_for_default_space(); + void name_from_bucket_space_is_defined_for_global_space(); + void name_from_bucket_space_throws_exception_for_unknown_space(); + void known_bucket_space_is_resolved_from_document_id(); + void unknown_bucket_space_in_id_throws_exception(); + void can_create_resolver_from_bucket_space_config(); +}; + +CPPUNIT_TEST_SUITE_REGISTRATION(ConfigurableBucketResolverTest); + +// TODO reduce overlap with FixedBucketSpacesTest +void ConfigurableBucketResolverTest::bucket_space_from_name_is_defined_for_default_space() { + auto space = create_empty_resolver().bucketSpaceFromName("default"); + CPPUNIT_ASSERT_EQUAL(spi::FixedBucketSpaces::default_space(), space); +} + +void ConfigurableBucketResolverTest::bucket_space_from_name_is_defined_for_global_space() { + auto space = create_empty_resolver().bucketSpaceFromName("global"); + CPPUNIT_ASSERT_EQUAL(spi::FixedBucketSpaces::global_space(), space); +} + +void ConfigurableBucketResolverTest::bucket_space_from_name_throws_exception_for_unknown_space() { + try { + create_empty_resolver().bucketSpaceFromName("bjarne"); + CPPUNIT_FAIL("Expected exception on unknown bucket space name"); + } catch (spi::UnknownBucketSpaceException& e) { + } +} + +void ConfigurableBucketResolverTest::name_from_bucket_space_is_defined_for_default_space() { + CPPUNIT_ASSERT_EQUAL(vespalib::string("default"), + create_empty_resolver().nameFromBucketSpace(spi::FixedBucketSpaces::default_space())); +} + +void ConfigurableBucketResolverTest::name_from_bucket_space_is_defined_for_global_space() { + CPPUNIT_ASSERT_EQUAL(vespalib::string("global"), + create_empty_resolver().nameFromBucketSpace(spi::FixedBucketSpaces::global_space())); +} + +void ConfigurableBucketResolverTest::name_from_bucket_space_throws_exception_for_unknown_space() { + try { + create_empty_resolver().nameFromBucketSpace(document::BucketSpace(1234)); + CPPUNIT_FAIL("Expected exception on unknown bucket space value"); + } catch (spi::UnknownBucketSpaceException& e) { + } +} + +void ConfigurableBucketResolverTest::known_bucket_space_is_resolved_from_document_id() { + auto resolver = create_simple_resolver(); + CPPUNIT_ASSERT_EQUAL(spi::FixedBucketSpaces::default_space(), + resolver.bucketFromId(DocumentId("id::foo::xyz")).getBucketSpace()); + CPPUNIT_ASSERT_EQUAL(spi::FixedBucketSpaces::default_space(), + resolver.bucketFromId(DocumentId("id::bar::xyz")).getBucketSpace()); + CPPUNIT_ASSERT_EQUAL(spi::FixedBucketSpaces::global_space(), + resolver.bucketFromId(DocumentId("id::baz::xyz")).getBucketSpace()); +} + +void ConfigurableBucketResolverTest::unknown_bucket_space_in_id_throws_exception() { + try { + create_simple_resolver().bucketFromId(DocumentId("id::bjarne::xyz")); + CPPUNIT_FAIL("Expected exception on unknown document type -> bucket space mapping"); + } catch (spi::UnknownBucketSpaceException& e) { + } +} + +using BucketSpacesConfigBuilder = vespa::config::content::core::BucketspacesConfigBuilder; + +namespace { + +BucketSpacesConfigBuilder::Documenttype make_doc_type(vespalib::stringref name, vespalib::stringref space) { + BucketSpacesConfigBuilder::Documenttype doc_type; + doc_type.name = name; + doc_type.bucketspace = space; + return doc_type; +} + +} + +void ConfigurableBucketResolverTest::can_create_resolver_from_bucket_space_config() { + BucketSpacesConfigBuilder builder; + builder.documenttype.emplace_back(make_doc_type("foo", "default")); + builder.documenttype.emplace_back(make_doc_type("bar", "global")); + builder.documenttype.emplace_back(make_doc_type("baz", "global")); + auto resolver = ConfigurableBucketResolver::from_config(builder); + CPPUNIT_ASSERT_EQUAL(spi::FixedBucketSpaces::default_space(), + resolver->bucketFromId(DocumentId("id::foo::xyz")).getBucketSpace()); + CPPUNIT_ASSERT_EQUAL(spi::FixedBucketSpaces::global_space(), + resolver->bucketFromId(DocumentId("id::bar::xyz")).getBucketSpace()); + CPPUNIT_ASSERT_EQUAL(spi::FixedBucketSpaces::global_space(), + resolver->bucketFromId(DocumentId("id::baz::xyz")).getBucketSpace()); +} + +} + diff --git a/storage/src/tests/storageserver/documentapiconvertertest.cpp b/storage/src/tests/storageserver/documentapiconvertertest.cpp index 386be60d88c..b878d5f6719 100644 --- a/storage/src/tests/storageserver/documentapiconvertertest.cpp +++ b/storage/src/tests/storageserver/documentapiconvertertest.cpp @@ -60,13 +60,13 @@ struct MockBucketResolver : public BucketResolver { struct DocumentApiConverterTest : public CppUnit::TestFixture { - MockBucketResolver _bucketResolver; + std::shared_ptr<MockBucketResolver> _bucketResolver; std::unique_ptr<DocumentApiConverter> _converter; const DocumentTypeRepo::SP _repo; const DataType& _html_type; DocumentApiConverterTest() - : _bucketResolver(), + : _bucketResolver(std::make_shared<MockBucketResolver>()), _repo(std::make_shared<DocumentTypeRepo>(readDocumenttypesConfig( TEST_PATH("config-doctypes.cfg")))), _html_type(*_repo->getDocumentType("text/html")) @@ -120,6 +120,7 @@ struct DocumentApiConverterTest : public CppUnit::TestFixture void testStatBucket(); void testGetBucketList(); void testRemoveLocation(); + void can_replace_bucket_resolver_after_construction(); CPPUNIT_TEST_SUITE(DocumentApiConverterTest); CPPUNIT_TEST(testPut); @@ -138,6 +139,7 @@ struct DocumentApiConverterTest : public CppUnit::TestFixture CPPUNIT_TEST(testStatBucket); CPPUNIT_TEST(testGetBucketList); CPPUNIT_TEST(testRemoveLocation); + CPPUNIT_TEST(can_replace_bucket_resolver_after_construction); CPPUNIT_TEST_SUITE_END(); }; @@ -463,4 +465,29 @@ DocumentApiConverterTest::testRemoveLocation() CPPUNIT_ASSERT_EQUAL(defaultBucket, cmd->getBucket()); } +namespace { + +struct ReplacementMockBucketResolver : public MockBucketResolver { + Bucket bucketFromId(const DocumentId& id) const override { + if (id.getDocType() == "testdoctype1") { + return defaultBucket; + } + return Bucket(BucketSpace(0), BucketId(0)); + } +}; + +} + +void DocumentApiConverterTest::can_replace_bucket_resolver_after_construction() { + documentapi::GetDocumentMessage get_msg(DocumentId("id::testdoctype1::baz"), "foo bar"); + auto cmd = toStorageAPI<api::GetCommand>(get_msg); + + CPPUNIT_ASSERT_EQUAL(BucketSpace(0), cmd->getBucket().getBucketSpace()); + + _converter->setBucketResolver(std::make_shared<ReplacementMockBucketResolver>()); + + cmd = toStorageAPI<api::GetCommand>(get_msg); + CPPUNIT_ASSERT_EQUAL(defaultBucketSpace, cmd->getBucket().getBucketSpace()); +} + } diff --git a/storage/src/vespa/storage/common/bucketmessages.cpp b/storage/src/vespa/storage/common/bucketmessages.cpp index 3157bad49e5..e92e2d4c3bf 100644 --- a/storage/src/vespa/storage/common/bucketmessages.cpp +++ b/storage/src/vespa/storage/common/bucketmessages.cpp @@ -2,6 +2,7 @@ #include "bucketmessages.h" #include <vespa/vespalib/stllike/asciistream.h> +#include <ostream> using document::BucketSpace; diff --git a/storage/src/vespa/storage/common/messagesender.h b/storage/src/vespa/storage/common/messagesender.h index 8c45995c42f..659fccad412 100644 --- a/storage/src/vespa/storage/common/messagesender.h +++ b/storage/src/vespa/storage/common/messagesender.h @@ -18,13 +18,14 @@ #include <memory> -namespace storage { -namespace api { +namespace storage::api { class StorageCommand; class StorageReply; class StorageMessage; } +namespace storage { + struct MessageSender { virtual ~MessageSender() {} diff --git a/storage/src/vespa/storage/common/storagecomponent.cpp b/storage/src/vespa/storage/common/storagecomponent.cpp index bf387240dc5..1d6b563f6eb 100644 --- a/storage/src/vespa/storage/common/storagecomponent.cpp +++ b/storage/src/vespa/storage/common/storagecomponent.cpp @@ -28,14 +28,14 @@ StorageComponent::setNodeInfo(vespalib::stringref clusterName, void StorageComponent::setDocumentTypeRepo(DocumentTypeRepoSP repo) { - std::lock_guard<std::mutex> guard(_lock); + std::lock_guard guard(_lock); _docTypeRepo = repo; } void StorageComponent::setLoadTypes(LoadTypeSetSP loadTypes) { - std::lock_guard<std::mutex> guard(_lock); + std::lock_guard guard(_lock); _loadTypes = loadTypes; } @@ -57,14 +57,21 @@ StorageComponent::setBucketIdFactory(const document::BucketIdFactory& factory) void StorageComponent::setDistribution(DistributionSP distribution) { - std::lock_guard<std::mutex> guard(_lock); + std::lock_guard guard(_lock); _distribution = distribution; } void +StorageComponent::enableMultipleBucketSpaces(bool value) +{ + std::lock_guard guard(_lock); + _enableMultipleBucketSpaces = value; +} + +void StorageComponent::setNodeStateUpdater(NodeStateUpdater& updater) { - std::lock_guard<std::mutex> guard(_lock); + std::lock_guard guard(_lock); if (_nodeStateUpdater != 0) { throw vespalib::IllegalStateException( "Node state updater is already set", VESPA_STRLOC); @@ -76,10 +83,16 @@ StorageComponent::StorageComponent(StorageComponentRegister& compReg, vespalib::stringref name) : Component(compReg, name), _clusterName(), - _nodeType(0), + _nodeType(nullptr), _index(0), + _docTypeRepo(), + _loadTypes(), _priorityMapper(new PriorityMapper), - _nodeStateUpdater(0) + _bucketIdFactory(), + _distribution(), + _nodeStateUpdater(nullptr), + _lock(), + _enableMultipleBucketSpaces(false) { compReg.registerStorageComponent(*this); } @@ -87,7 +100,7 @@ StorageComponent::StorageComponent(StorageComponentRegister& compReg, NodeStateUpdater& StorageComponent::getStateUpdater() const { - std::lock_guard<std::mutex> guard(_lock); + std::lock_guard guard(_lock); if (_nodeStateUpdater == 0) { throw vespalib::IllegalStateException( "Component need node state updater at this time, but it has " @@ -114,22 +127,29 @@ StorageComponent::getPriority(const documentapi::LoadType& lt) const StorageComponent::DocumentTypeRepoSP StorageComponent::getTypeRepo() const { - std::lock_guard<std::mutex> guard(_lock); + std::lock_guard guard(_lock); return _docTypeRepo; } StorageComponent::LoadTypeSetSP StorageComponent::getLoadTypes() const { - std::lock_guard<std::mutex> guard(_lock); + std::lock_guard guard(_lock); return _loadTypes; } StorageComponent::DistributionSP StorageComponent::getDistribution() const { - std::lock_guard<std::mutex> guard(_lock); + std::lock_guard guard(_lock); return _distribution; } +bool +StorageComponent::enableMultipleBucketSpaces() const +{ + std::lock_guard guard(_lock); + return _enableMultipleBucketSpaces; +} + } // storage diff --git a/storage/src/vespa/storage/common/storagecomponent.h b/storage/src/vespa/storage/common/storagecomponent.h index d469540b55f..e136d991ac5 100644 --- a/storage/src/vespa/storage/common/storagecomponent.h +++ b/storage/src/vespa/storage/common/storagecomponent.h @@ -37,10 +37,9 @@ #include <vespa/vdslib/state/node.h> #include <mutex> -namespace vespa { namespace config { namespace content { namespace core { -namespace internal { +namespace vespa::config::content::core::internal { class InternalStorPrioritymappingType; -} } } } } +} namespace document { class DocumentTypeRepo; } @@ -59,11 +58,11 @@ class StorageComponentRegister; class StorageComponent : public framework::Component { public: - typedef std::unique_ptr<StorageComponent> UP; - typedef vespa::config::content::core::internal::InternalStorPrioritymappingType PriorityConfig; - typedef std::shared_ptr<document::DocumentTypeRepo> DocumentTypeRepoSP; - typedef std::shared_ptr<documentapi::LoadTypeSet> LoadTypeSetSP; - typedef std::shared_ptr<lib::Distribution> DistributionSP; + using UP = std::unique_ptr<StorageComponent>; + using PriorityConfig = vespa::config::content::core::internal::InternalStorPrioritymappingType; + using DocumentTypeRepoSP = std::shared_ptr<document::DocumentTypeRepo>; + using LoadTypeSetSP = std::shared_ptr<documentapi::LoadTypeSet>; + using DistributionSP = std::shared_ptr<lib::Distribution>; /** * Node type is supposed to be set immediately, and never be updated. @@ -84,6 +83,7 @@ public: void setPriorityConfig(const PriorityConfig&); void setBucketIdFactory(const document::BucketIdFactory&); void setDistribution(DistributionSP); + void enableMultipleBucketSpaces(bool value); StorageComponent(StorageComponentRegister&, vespalib::stringref name); virtual ~StorageComponent(); @@ -102,6 +102,7 @@ public: uint8_t getPriority(const documentapi::LoadType&) const; DistributionSP getDistribution() const; NodeStateUpdater& getStateUpdater() const; + bool enableMultipleBucketSpaces() const; private: vespalib::string _clusterName; @@ -114,6 +115,7 @@ private: DistributionSP _distribution; NodeStateUpdater* _nodeStateUpdater; mutable std::mutex _lock; + bool _enableMultipleBucketSpaces; }; struct StorageComponentRegister : public virtual framework::ComponentRegister diff --git a/storage/src/vespa/storage/distributor/bucketdbupdater.cpp b/storage/src/vespa/storage/distributor/bucketdbupdater.cpp index 46fa0f72d76..cc1181e0d58 100644 --- a/storage/src/vespa/storage/distributor/bucketdbupdater.cpp +++ b/storage/src/vespa/storage/distributor/bucketdbupdater.cpp @@ -2,9 +2,9 @@ #include "bucketdbupdater.h" #include "distributor.h" -#include "distributor_bucket_space_repo.h" #include "distributor_bucket_space.h" #include "simpleclusterinformation.h" +#include "distributormetricsset.h" #include <vespa/storage/common/bucketoperationlogger.h> #include <vespa/storageapi/message/persistence.h> #include <vespa/storageapi/message/removelocation.h> diff --git a/storage/src/vespa/storage/distributor/bucketdbupdater.h b/storage/src/vespa/storage/distributor/bucketdbupdater.h index 29e8d3f6221..19e2e259778 100644 --- a/storage/src/vespa/storage/distributor/bucketdbupdater.h +++ b/storage/src/vespa/storage/distributor/bucketdbupdater.h @@ -13,9 +13,8 @@ #include <vespa/vdslib/state/clusterstate.h> #include <vespa/storage/common/storagelink.h> #include <vespa/storageframework/generic/clock/timer.h> +#include <vespa/storageframework/generic/status/statusreporter.h> #include <vespa/storageapi/messageapi/messagehandler.h> -#include <set> -#include <deque> #include <list> namespace storage::distributor { diff --git a/storage/src/vespa/storage/distributor/bucketgctimecalculator.h b/storage/src/vespa/storage/distributor/bucketgctimecalculator.h index e2b232a6cf5..4ff85e568c8 100644 --- a/storage/src/vespa/storage/distributor/bucketgctimecalculator.h +++ b/storage/src/vespa/storage/distributor/bucketgctimecalculator.h @@ -4,8 +4,7 @@ #include <chrono> #include <vespa/document/bucket/bucketid.h> -namespace storage { -namespace distributor { +namespace storage::distributor { /** * Semantics are basically as follows: @@ -51,6 +50,4 @@ private: std::chrono::seconds _checkInterval; }; -} // distributor -} // storage - +} diff --git a/storage/src/vespa/storage/distributor/bucketownership.h b/storage/src/vespa/storage/distributor/bucketownership.h index c7a7773686f..bfe63c9799d 100644 --- a/storage/src/vespa/storage/distributor/bucketownership.h +++ b/storage/src/vespa/storage/distributor/bucketownership.h @@ -2,9 +2,9 @@ #pragma once #include <vespa/vdslib/state/clusterstate.h> +#include <cassert> -namespace storage { -namespace distributor { +namespace storage::distributor { class BucketOwnership { @@ -14,8 +14,7 @@ class BucketOwnership BucketOwnership(const lib::ClusterState& checkedState) : _checkedState(&checkedState), _owned(false) - { - } + { } BucketOwnership() : _checkedState(nullptr), _owned(true) {} @@ -44,6 +43,4 @@ public: } }; -} // distributor -} // storage - +} diff --git a/storage/src/vespa/storage/distributor/distributor.cpp b/storage/src/vespa/storage/distributor/distributor.cpp index 1edcbe75dd6..988d39e571d 100644 --- a/storage/src/vespa/storage/distributor/distributor.cpp +++ b/storage/src/vespa/storage/distributor/distributor.cpp @@ -5,21 +5,17 @@ #include "throttlingoperationstarter.h" #include "idealstatemetricsset.h" #include "ownership_transfer_safe_time_point_calculator.h" -#include "distributor_bucket_space_repo.h" #include "distributor_bucket_space.h" -#include <vespa/storage/bucketdb/mapbucketdatabase.h> -#include <vespa/storage/distributor/maintenance/simplemaintenancescanner.h> +#include "distributormetricsset.h" #include <vespa/storage/distributor/maintenance/simplebucketprioritydatabase.h> #include <vespa/storage/common/nodestateupdater.h> #include <vespa/storage/common/hostreporter/hostinfo.h> #include <vespa/storageframework/generic/status/xmlstatusreporter.h> - #include <vespa/log/log.h> LOG_SETUP(".distributor-main"); -namespace storage { -namespace distributor { +namespace storage::distributor { class Distributor::Status { const DelegatedStatusRequest& _request; @@ -68,34 +64,25 @@ Distributor::Distributor(DistributorComponentRegister& compReg, _compReg(compReg), _component(compReg, "distributor"), _bucketSpaceRepo(std::make_unique<DistributorBucketSpaceRepo>()), - _metrics(new DistributorMetricSet( - _component.getLoadTypes()->getMetricLoadTypes())), + _metrics(new DistributorMetricSet(_component.getLoadTypes()->getMetricLoadTypes())), _operationOwner(*this, _component.getClock()), _maintenanceOperationOwner(*this, _component.getClock()), _pendingMessageTracker(compReg), _bucketDBUpdater(*this, *_bucketSpaceRepo, *this, compReg), _distributorStatusDelegate(compReg, *this, *this), _bucketDBStatusDelegate(compReg, *this, _bucketDBUpdater), - _idealStateManager(*this, *_bucketSpaceRepo, compReg, - manageActiveBucketCopies), - _externalOperationHandler(*this, *_bucketSpaceRepo, - _idealStateManager, compReg), + _idealStateManager(*this, *_bucketSpaceRepo, compReg, manageActiveBucketCopies), + _externalOperationHandler(*this, *_bucketSpaceRepo, _idealStateManager, compReg), _threadPool(threadPool), _initializingIsUp(true), _doneInitializeHandler(doneInitHandler), _doneInitializing(false), _messageSender(messageSender), _bucketPriorityDb(new SimpleBucketPriorityDatabase()), - _scanner(new SimpleMaintenanceScanner( - *_bucketPriorityDb, _idealStateManager, - *_bucketSpaceRepo)), - _throttlingStarter(new ThrottlingOperationStarter( - _maintenanceOperationOwner)), - _blockingStarter(new BlockingOperationStarter(_pendingMessageTracker, - *_throttlingStarter)), - _scheduler(new MaintenanceScheduler(_idealStateManager, - *_bucketPriorityDb, - *_blockingStarter)), + _scanner(new SimpleMaintenanceScanner(*_bucketPriorityDb, _idealStateManager, *_bucketSpaceRepo)), + _throttlingStarter(new ThrottlingOperationStarter(_maintenanceOperationOwner)), + _blockingStarter(new BlockingOperationStarter(_pendingMessageTracker, *_throttlingStarter)), + _scheduler(new MaintenanceScheduler(_idealStateManager, *_bucketPriorityDb, *_blockingStarter)), _schedulingMode(MaintenanceScheduler::NORMAL_SCHEDULING_MODE), _recoveryTimeStarted(_component.getClock()), _tickResult(framework::ThreadWaitInfo::NO_MORE_CRITICAL_WORK_KNOWN), @@ -105,8 +92,7 @@ Distributor::Distributor(DistributorComponentRegister& compReg, _metricLock(), _maintenanceStats(), _bucketDbStats(), - _hostInfoReporter(_pendingMessageTracker.getLatencyStatisticsProvider(), - *this), + _hostInfoReporter(_pendingMessageTracker.getLatencyStatisticsProvider(), *this), _ownershipSafeTimeCalc( std::make_unique<OwnershipTransferSafeTimePointCalculator>( std::chrono::seconds(0))) // Set by config later @@ -162,10 +148,8 @@ void Distributor::sendCommand(const std::shared_ptr<api::StorageCommand>& cmd) { if (cmd->getType() == api::MessageType::MERGEBUCKET) { - api::MergeBucketCommand& merge( - static_cast<api::MergeBucketCommand&>(*cmd)); - _idealStateManager.getMetrics().nodesPerMerge.addValue( - merge.getNodes().size()); + api::MergeBucketCommand& merge(static_cast<api::MergeBucketCommand&>(*cmd)); + _idealStateManager.getMetrics().nodesPerMerge.addValue(merge.getNodes().size()); } sendUp(cmd); } @@ -179,10 +163,8 @@ Distributor::sendReply(const std::shared_ptr<api::StorageReply>& reply) void Distributor::setNodeStateUp() { - NodeStateUpdater::Lock::SP lock( - _component.getStateUpdater().grabStateChangeLock()); - lib::NodeState ns( - *_component.getStateUpdater().getReportedNodeState()); + NodeStateUpdater::Lock::SP lock(_component.getStateUpdater().grabStateChangeLock()); + lib::NodeState ns(*_component.getStateUpdater().getReportedNodeState()); ns.setState(lib::State::UP); _component.getStateUpdater().setReportedNodeState(ns); } @@ -832,5 +814,4 @@ Distributor::handleStatusRequest(const DelegatedStatusRequest& request) const return true; } -} // distributor -} // storage +} diff --git a/storage/src/vespa/storage/distributor/distributorinterface.h b/storage/src/vespa/storage/distributor/distributorinterface.h index bf27dc432b6..3445397c17d 100644 --- a/storage/src/vespa/storage/distributor/distributorinterface.h +++ b/storage/src/vespa/storage/distributor/distributorinterface.h @@ -1,32 +1,28 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once -#include <vespa/storage/common/distributorcomponent.h> -#include <vespa/storage/common/messagesender.h> -#include <vespa/storage/distributor/pendingmessagetracker.h> -#include <vespa/storageapi/message/state.h> +#include "bucketgctimecalculator.h" +#include "distributormessagesender.h" +#include "bucketownership.h" #include <vespa/storage/bucketdb/bucketdatabase.h> -#include <vespa/storage/distributor/bucketgctimecalculator.h> -#include <vespa/storage/distributor/distributormetricsset.h> -#include <vespa/storage/config/distributorconfiguration.h> -#include <vespa/storage/distributor/distributormessagesender.h> -#include <vespa/storage/distributor/bucketownership.h> +#include <vespa/document/bucket/bucket.h> +namespace storage::api { class MergeBucketReply; } namespace storage { + class DistributorConfiguration; + class DistributorMetricSet; +} +namespace storage::distributor { -namespace distributor { +class PendingMessageTracker; class DistributorInterface : public DistributorMessageSender { public: virtual PendingMessageTracker& getPendingMessageTracker() = 0; - virtual DistributorMetricSet& getMetrics() = 0; - virtual void enableClusterState(const lib::ClusterState& state) = 0; - virtual BucketOwnership checkOwnershipInPendingState(const document::Bucket &bucket) const = 0; - virtual void notifyDistributionChangeEnabled() = 0; /** @@ -55,19 +51,11 @@ public: * Returns true if the node is currently initializing. */ virtual bool initializing() const = 0; - virtual void handleCompletedMerge(const std::shared_ptr<api::MergeBucketReply>&) = 0; - virtual const char* getStorageNodeUpStates() const = 0; - virtual const DistributorConfiguration& getConfig() const = 0; - virtual ChainedMessageSender& getMessageSender() = 0; - virtual const BucketGcTimeCalculator::BucketIdHasher& getBucketIdHasher() const = 0; }; } - -} - diff --git a/storage/src/vespa/storage/distributor/distributormessagesender.h b/storage/src/vespa/storage/distributor/distributormessagesender.h index 0fccaad87e3..078762dd05c 100644 --- a/storage/src/vespa/storage/distributor/distributormessagesender.h +++ b/storage/src/vespa/storage/distributor/distributormessagesender.h @@ -2,11 +2,9 @@ #pragma once #include <vespa/storage/common/messagesender.h> -#include <vespa/vdslib/distribution/distribution.h> -namespace storage { - -namespace distributor { +namespace storage::lib { class NodeType; } +namespace storage::distributor { class PendingMessageTracker; @@ -16,21 +14,12 @@ public: Sends the storage command to the given node, returns message id. */ - virtual uint64_t sendToNode(const lib::NodeType& nodeType, - uint16_t node, - const std::shared_ptr<api::StorageCommand>& cmd, - bool useDocumentAPI = false); + virtual uint64_t sendToNode(const lib::NodeType& nodeType, uint16_t node, + const std::shared_ptr<api::StorageCommand>& cmd, bool useDocumentAPI = false); virtual int getDistributorIndex() const = 0; - virtual const std::string& getClusterName() const = 0; - virtual const PendingMessageTracker& getPendingMessageTracker() const = 0; }; -} // distributor - -} // storage - - - +} diff --git a/storage/src/vespa/storage/distributor/idealstatemanager.cpp b/storage/src/vespa/storage/distributor/idealstatemanager.cpp index 4ceeb387341..031e9946178 100644 --- a/storage/src/vespa/storage/distributor/idealstatemanager.cpp +++ b/storage/src/vespa/storage/distributor/idealstatemanager.cpp @@ -95,7 +95,7 @@ IdealStateManager::getEntryForPrimaryBucket(StateChecker::Context& c) const { for (uint32_t j = 0; j < c.entries.size(); ++j) { BucketDatabase::Entry& e = c.entries[j]; - if (e.getBucketId() == c.getBucketId()) { + if (e.getBucketId() == c.getBucketId() && ! e->getNodes().empty()) { return &e; } } diff --git a/storage/src/vespa/storage/distributor/idealstatemanager.h b/storage/src/vespa/storage/distributor/idealstatemanager.h index b9607b35d28..028c9cbb0b6 100644 --- a/storage/src/vespa/storage/distributor/idealstatemanager.h +++ b/storage/src/vespa/storage/distributor/idealstatemanager.h @@ -1,18 +1,14 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once -#include <deque> -#include <map> -#include <set> -#include <vespa/storage/distributor/distributorcomponent.h> -#include <vespa/storage/distributor/statechecker.h> +#include "distributorcomponent.h" +#include "statechecker.h" #include <vespa/storage/distributor/maintenance/maintenanceprioritygenerator.h> #include <vespa/storage/distributor/maintenance/maintenanceoperationgenerator.h> +#include <vespa/storageframework/generic/status/htmlstatusreporter.h> #include <vespa/vdslib/state/clusterstate.h> -#include <vector> -namespace storage { -namespace distributor { +namespace storage::distributor { class IdealStateMetricSet; class IdealStateOperation; @@ -116,8 +112,7 @@ private: DistributorComponent _distributorComponent; DistributorBucketSpaceRepo &_bucketSpaceRepo; - std::vector<IdealStateOperation::SP> generateOperationsForBucket( - StateChecker::Context& c) const; + std::vector<IdealStateOperation::SP> generateOperationsForBucket(StateChecker::Context& c) const; bool iAmUp() const; @@ -125,9 +120,9 @@ private: // Stats tracker to use for all generateAll() calls to avoid having // to create a new hash map for each single bucket processed. NodeMaintenanceStatsTracker _statsTracker; - const IdealStateManager& _ism; - document::BucketSpace _bucketSpace; - std::ostream& _out; + const IdealStateManager & _ism; + document::BucketSpace _bucketSpace; + std::ostream & _out; public: StatusBucketVisitor(const IdealStateManager& ism, document::BucketSpace bucketSpace, std::ostream& out) : _statsTracker(), _ism(ism), _bucketSpace(bucketSpace), _out(out) {} @@ -139,11 +134,8 @@ private: }; friend class StatusBucketVisitor; - void getBucketStatus(document::BucketSpace bucketSpace, - const BucketDatabase::Entry& entry, - NodeMaintenanceStatsTracker& statsTracker, - std::ostream& out) const; + void getBucketStatus(document::BucketSpace bucketSpace, const BucketDatabase::Entry& entry, + NodeMaintenanceStatsTracker& statsTracker, std::ostream& out) const; }; -} // distributor -} // storage +} diff --git a/storage/src/vespa/storage/distributor/messagetracker.cpp b/storage/src/vespa/storage/distributor/messagetracker.cpp index b844987e978..6568cec9a80 100644 --- a/storage/src/vespa/storage/distributor/messagetracker.cpp +++ b/storage/src/vespa/storage/distributor/messagetracker.cpp @@ -1,19 +1,19 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "messagetracker.h" +#include <vespa/storageapi/messageapi/bucketcommand.h> +#include <vespa/storageapi/messageapi/bucketreply.h> #include <vespa/log/log.h> LOG_SETUP(".messagetracker"); -namespace storage { - -namespace distributor { +namespace storage::distributor { MessageTracker::MessageTracker(const std::string& clusterName) : _clusterName(clusterName) {} -MessageTracker::~MessageTracker() {} +MessageTracker::~MessageTracker() = default; void MessageTracker::flushQueue(MessageSender& sender) @@ -48,7 +48,4 @@ MessageTracker::finished() return _sentMessages.empty(); } - -} - } diff --git a/storage/src/vespa/storage/distributor/messagetracker.h b/storage/src/vespa/storage/distributor/messagetracker.h index 63c0be1ca93..017979c16c0 100644 --- a/storage/src/vespa/storage/distributor/messagetracker.h +++ b/storage/src/vespa/storage/distributor/messagetracker.h @@ -1,10 +1,14 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once -#include "distributormetricsset.h" #include <vespa/storage/common/messagesender.h> -#include <vespa/storageapi/messageapi/bucketcommand.h> -#include <vespa/storageapi/messageapi/bucketreply.h> +#include <vector> +#include <map> + +namespace storage::api { + class BucketCommand; + class BucketReply; +} namespace storage::distributor { diff --git a/storage/src/vespa/storage/distributor/operations/external/putoperation.h b/storage/src/vespa/storage/distributor/operations/external/putoperation.h index 8beffe8b2c3..c27f2ee2266 100644 --- a/storage/src/vespa/storage/distributor/operations/external/putoperation.h +++ b/storage/src/vespa/storage/distributor/operations/external/putoperation.h @@ -5,22 +5,21 @@ #include <vespa/storage/distributor/operations/sequenced_operation.h> #include <vespa/storageapi/messageapi/returncode.h> #include <vespa/storage/distributor/persistencemessagetracker.h> -#include <vespa/storage/distributor/operationtargetresolver.h> namespace document { class Document; } -namespace storage { -namespace lib { +namespace storage::lib { class Distribution; } -namespace api { +namespace storage::api { class CreateBucketReply; class PutCommand; } -namespace distributor { +namespace storage::distributor { class DistributorBucketSpace; +class OperationTargetList; class PutOperation : public SequencedOperation { @@ -78,5 +77,4 @@ private: DistributorBucketSpace &_bucketSpace; }; -} // distributor -} // storage +} diff --git a/storage/src/vespa/storage/distributor/operations/external/statbucketoperation.h b/storage/src/vespa/storage/distributor/operations/external/statbucketoperation.h index af448c2dd55..d40924e23f0 100644 --- a/storage/src/vespa/storage/distributor/operations/external/statbucketoperation.h +++ b/storage/src/vespa/storage/distributor/operations/external/statbucketoperation.h @@ -8,12 +8,11 @@ #pragma once #include <vespa/storage/distributor/operations/operation.h> +#include <map> -namespace storage { +namespace storage::api { class StatBucketCommand; } -namespace api { class StatBucketCommand; } - -namespace distributor { +namespace storage::distributor { class DistributorComponent; class DistributorBucketSpace; @@ -21,9 +20,8 @@ class DistributorBucketSpace; class StatBucketOperation : public Operation { public: - StatBucketOperation(DistributorComponent& manager, - DistributorBucketSpace &bucketSpace, - const std::shared_ptr<api::StatBucketCommand> & cmd); + StatBucketOperation(DistributorComponent& manager, DistributorBucketSpace &bucketSpace, + const std::shared_ptr<api::StatBucketCommand> & cmd); ~StatBucketOperation(); const char* getName() const override { return "statBucket"; } @@ -37,10 +35,8 @@ private: std::shared_ptr<api::StatBucketCommand> _command; - std::map<uint64_t, uint16_t> _sent; + std::map<uint64_t, uint16_t> _sent; std::map<uint16_t, std::string> _results; }; -} // distributor -} // storage - +} diff --git a/storage/src/vespa/storage/distributor/operations/external/twophaseupdateoperation.cpp b/storage/src/vespa/storage/distributor/operations/external/twophaseupdateoperation.cpp index 79ffee7430c..db120880267 100644 --- a/storage/src/vespa/storage/distributor/operations/external/twophaseupdateoperation.cpp +++ b/storage/src/vespa/storage/distributor/operations/external/twophaseupdateoperation.cpp @@ -4,13 +4,12 @@ #include "getoperation.h" #include "putoperation.h" #include "updateoperation.h" -#include <vespa/document/fieldvalue/document.h> -#include <vespa/document/datatype/documenttype.h> -#include <vespa/document/select/parser.h> +#include <vespa/storage/distributor/distributor_bucket_space.h> #include <vespa/storageapi/message/persistence.h> #include <vespa/storageapi/message/batch.h> +#include <vespa/document/datatype/documenttype.h> +#include <vespa/document/select/parser.h> #include <vespa/vespalib/stllike/hash_map.hpp> -#include <vespa/storage/distributor/distributor_bucket_space.h> #include <vespa/log/log.h> LOG_SETUP(".distributor.callback.twophaseupdate"); @@ -18,8 +17,7 @@ LOG_SETUP(".distributor.callback.twophaseupdate"); using namespace std::literals::string_literals; using document::BucketSpace; -namespace storage { -namespace distributor { +namespace storage::distributor { TwoPhaseUpdateOperation::TwoPhaseUpdateOperation( DistributorComponent& manager, @@ -570,5 +568,4 @@ TwoPhaseUpdateOperation::onClose(DistributorMessageSender& sender) { } } -} // distributor -} // storage +} diff --git a/storage/src/vespa/storage/distributor/operations/external/visitoroperation.h b/storage/src/vespa/storage/distributor/operations/external/visitoroperation.h index f35a9dcb3ec..b4f84d76649 100644 --- a/storage/src/vespa/storage/distributor/operations/external/visitoroperation.h +++ b/storage/src/vespa/storage/distributor/operations/external/visitoroperation.h @@ -11,11 +11,10 @@ namespace document { class Document; } -namespace storage { +namespace storage { class VisitorMetricSet; } +namespace storage::lib { class ClusterState; } -class VisitorMetricSet; - -namespace distributor { +namespace storage::distributor { class DistributorComponent; class DistributorBucketSpace; @@ -181,5 +180,3 @@ private: }; } - -} diff --git a/storage/src/vespa/storage/distributor/operations/idealstate/idealstateoperation.cpp b/storage/src/vespa/storage/distributor/operations/idealstate/idealstateoperation.cpp index 2337129e375..52c8344b820 100644 --- a/storage/src/vespa/storage/distributor/operations/idealstate/idealstateoperation.cpp +++ b/storage/src/vespa/storage/distributor/operations/idealstate/idealstateoperation.cpp @@ -3,9 +3,8 @@ #include <vespa/storage/distributor/idealstatemanager.h> #include <vespa/storage/distributor/pendingmessagetracker.h> #include <vespa/storage/distributor/idealstatemetricsset.h> -#include <vespa/storage/distributor/pendingmessagetracker.h> #include <vespa/storage/distributor/distributor_bucket_space_repo.h> -#include <vespa/storageapi/messageapi/maintenancecommand.h> +#include <vespa/documentapi/loadtypes/loadtypeset.h> #include <vespa/log/log.h> LOG_SETUP(".distributor.operation"); @@ -26,17 +25,15 @@ const uint32_t IdealStateOperation::MAINTENANCE_MESSAGE_TYPES[] = }; IdealStateOperation::IdealStateOperation(const BucketAndNodes& bucketAndNodes) - : _manager(nullptr), - _bucketSpace(nullptr), - _bucketAndNodes(bucketAndNodes), - _ok(true), - _priority(255) + : _manager(nullptr), + _bucketSpace(nullptr), + _bucketAndNodes(bucketAndNodes), + _ok(true), + _priority(255) { } -IdealStateOperation::~IdealStateOperation() -{ -} +IdealStateOperation::~IdealStateOperation() = default; BucketAndNodes::BucketAndNodes(const document::Bucket &bucket, uint16_t node) : _bucket(bucket) @@ -108,8 +105,7 @@ IdealStateOperation::setCommandMeta(api::MaintenanceCommand& cmd) const { cmd.setPriority(_priority); cmd.setReason(_detailedReason); - cmd.setLoadType( - (*_manager->getLoadTypes())["maintenance"]); + cmd.setLoadType((*_manager->getLoadTypes())["maintenance"]); } std::string diff --git a/storage/src/vespa/storage/distributor/operations/idealstate/mergeoperation.cpp b/storage/src/vespa/storage/distributor/operations/idealstate/mergeoperation.cpp index 271ac35968e..32ea695bd94 100644 --- a/storage/src/vespa/storage/distributor/operations/idealstate/mergeoperation.cpp +++ b/storage/src/vespa/storage/distributor/operations/idealstate/mergeoperation.cpp @@ -2,7 +2,7 @@ #include "mergeoperation.h" #include <vespa/storage/distributor/idealstatemanager.h> #include <vespa/storage/distributor/distributor_bucket_space.h> -#include <array> +#include <vespa/storage/distributor/pendingmessagetracker.h> #include <vespa/log/bufferedlogger.h> LOG_SETUP(".distributor.operation.idealstate.merge"); diff --git a/storage/src/vespa/storage/distributor/operations/idealstate/setbucketstateoperation.cpp b/storage/src/vespa/storage/distributor/operations/idealstate/setbucketstateoperation.cpp index 1acb2dcc64b..6a87688c295 100644 --- a/storage/src/vespa/storage/distributor/operations/idealstate/setbucketstateoperation.cpp +++ b/storage/src/vespa/storage/distributor/operations/idealstate/setbucketstateoperation.cpp @@ -3,6 +3,7 @@ #include "setbucketstateoperation.h" #include <vespa/storage/distributor/idealstatemanager.h> #include <vespa/storage/distributor/distributor_bucket_space.h> +#include <vespa/storageapi/message/bucket.h> #include <vespa/log/log.h> LOG_SETUP(".distributor.operation.idealstate.setactive"); diff --git a/storage/src/vespa/storage/distributor/operationtargetresolver.h b/storage/src/vespa/storage/distributor/operationtargetresolver.h index 23e0fbbcba4..20666ea254c 100644 --- a/storage/src/vespa/storage/distributor/operationtargetresolver.h +++ b/storage/src/vespa/storage/distributor/operationtargetresolver.h @@ -10,8 +10,7 @@ #include <vespa/vdslib/state/node.h> #include <vespa/vespalib/util/printable.h> -namespace storage { -namespace distributor { +namespace storage::distributor { class OperationTarget : public vespalib::AsciiPrintable { @@ -68,5 +67,4 @@ public: const document::BucketId& id) = 0; }; -} // distributor -} // storage +} diff --git a/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.cpp b/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.cpp index d2cb8fa4380..2f9430c49bb 100644 --- a/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.cpp +++ b/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.cpp @@ -10,11 +10,20 @@ LOG_SETUP(".storage.component.register"); namespace storage { StorageComponentRegisterImpl::StorageComponentRegisterImpl() - : _nodeType(0), + : _componentLock(), + _components(), + _clusterName(), + _nodeType(nullptr), _index(0xffff), + _docTypeRepo(), _loadTypes(new documentapi::LoadTypeSet), - _nodeStateUpdater(0) -{ } + _priorityConfig(), + _bucketIdFactory(), + _distribution(), + _nodeStateUpdater(nullptr), + _bucketSpacesConfig() +{ +} StorageComponentRegisterImpl::~StorageComponentRegisterImpl() { } @@ -33,6 +42,7 @@ StorageComponentRegisterImpl::registerStorageComponent(StorageComponent& smc) smc.setPriorityConfig(_priorityConfig); smc.setBucketIdFactory(_bucketIdFactory); smc.setDistribution(_distribution); + smc.enableMultipleBucketSpaces(_bucketSpacesConfig.enableMultipleBucketSpaces); } void @@ -115,4 +125,14 @@ StorageComponentRegisterImpl::setDistribution(lib::Distribution::SP distribution } } +void +StorageComponentRegisterImpl::setBucketSpacesConfig(const BucketspacesConfig& config) +{ + vespalib::LockGuard lock(_componentLock); + _bucketSpacesConfig = config; + for (size_t i = 0; i < _components.size(); ++i) { + _components[i]->enableMultipleBucketSpaces(config.enableMultipleBucketSpaces); + } +} + } // storage diff --git a/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.h b/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.h index 49387e2c2b5..afd9f11a88b 100644 --- a/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.h +++ b/storage/src/vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.h @@ -11,6 +11,7 @@ #include <vespa/document/repo/documenttyperepo.h> #include <vespa/documentapi/loadtypes/loadtypeset.h> #include <vespa/storage/common/storagecomponent.h> +#include <vespa/storage/config/config-bucketspaces.h> #include <vespa/storage/config/config-stor-prioritymapping.h> #include <vespa/storageframework/defaultimplementation/component/componentregisterimpl.h> #include <vespa/vdslib/distribution/distribution.h> @@ -21,9 +22,9 @@ class StorageComponentRegisterImpl : public virtual StorageComponentRegister, public virtual framework::defaultimplementation::ComponentRegisterImpl { - typedef framework::defaultimplementation::ComponentRegisterImpl CompRegImpl; - typedef StorageComponent::PriorityConfig PriorityConfig; - //CompRegImpl _compReg; + using PriorityConfig = StorageComponent::PriorityConfig; + using BucketspacesConfig = vespa::config::content::core::internal::InternalBucketspacesType; + vespalib::Lock _componentLock; std::vector<StorageComponent*> _components; vespalib::string _clusterName; @@ -35,6 +36,7 @@ class StorageComponentRegisterImpl document::BucketIdFactory _bucketIdFactory; lib::Distribution::SP _distribution; NodeStateUpdater* _nodeStateUpdater; + BucketspacesConfig _bucketSpacesConfig; public: typedef std::unique_ptr<StorageComponentRegisterImpl> UP; @@ -64,6 +66,7 @@ public: virtual void setPriorityConfig(const PriorityConfig&); virtual void setBucketIdFactory(const document::BucketIdFactory&); virtual void setDistribution(lib::Distribution::SP); + virtual void setBucketSpacesConfig(const BucketspacesConfig&); }; diff --git a/storage/src/vespa/storage/persistence/splitbitdetector.h b/storage/src/vespa/storage/persistence/splitbitdetector.h index b3fc5bea566..6f1af6c5970 100644 --- a/storage/src/vespa/storage/persistence/splitbitdetector.h +++ b/storage/src/vespa/storage/persistence/splitbitdetector.h @@ -18,6 +18,7 @@ #pragma once #include <vespa/persistence/spi/persistenceprovider.h> +#include <vespa/vespalib/util/printable.h> namespace storage { diff --git a/storage/src/vespa/storage/storageserver/CMakeLists.txt b/storage/src/vespa/storage/storageserver/CMakeLists.txt index c0238922a91..4fb3a5a0b99 100644 --- a/storage/src/vespa/storage/storageserver/CMakeLists.txt +++ b/storage/src/vespa/storage/storageserver/CMakeLists.txt @@ -6,6 +6,7 @@ vespa_add_library(storage_storageserver changedbucketownershiphandler.cpp communicationmanager.cpp communicationmanagermetrics.cpp + configurable_bucket_resolver.cpp distributornode.cpp distributornodecontext.cpp documentapiconverter.cpp diff --git a/storage/src/vespa/storage/storageserver/communicationmanager.cpp b/storage/src/vespa/storage/storageserver/communicationmanager.cpp index c19dc7cfd27..a2c923b93db 100644 --- a/storage/src/vespa/storage/storageserver/communicationmanager.cpp +++ b/storage/src/vespa/storage/storageserver/communicationmanager.cpp @@ -288,8 +288,7 @@ CommunicationManager::CommunicationManager(StorageComponentRegister& compReg, co _count(0), _configUri(configUri), _closed(false), - _bucketResolver(std::make_unique<PlaceHolderBucketResolver>()), - _docApiConverter(configUri, *_bucketResolver) + _docApiConverter(configUri, std::make_shared<PlaceHolderBucketResolver>()) { _component.registerMetricUpdateHook(*this, framework::SecondTime(5)); _component.registerMetric(_metrics); diff --git a/storage/src/vespa/storage/storageserver/communicationmanager.h b/storage/src/vespa/storage/storageserver/communicationmanager.h index f4f4aa5a236..b4508fbc9f9 100644 --- a/storage/src/vespa/storage/storageserver/communicationmanager.h +++ b/storage/src/vespa/storage/storageserver/communicationmanager.h @@ -170,7 +170,6 @@ private: config::ConfigUri _configUri; std::atomic<bool> _closed; - std::unique_ptr<BucketResolver> _bucketResolver; DocumentApiConverter _docApiConverter; framework::Thread::UP _thread; diff --git a/storage/src/vespa/storage/storageserver/configurable_bucket_resolver.cpp b/storage/src/vespa/storage/storageserver/configurable_bucket_resolver.cpp new file mode 100644 index 00000000000..86c802a65cf --- /dev/null +++ b/storage/src/vespa/storage/storageserver/configurable_bucket_resolver.cpp @@ -0,0 +1,36 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/document/base/documentid.h> +#include <vespa/persistence/spi/fixed_bucket_spaces.h> +#include <vespa/vespalib/util/exceptions.h> +#include "configurable_bucket_resolver.h" + +namespace storage { + +document::Bucket ConfigurableBucketResolver::bucketFromId(const document::DocumentId& id) const { + auto iter = _type_to_space.find(id.getDocType()); + if (iter != _type_to_space.end()) { + return document::Bucket(iter->second, document::BucketId(0)); + } + throw spi::UnknownBucketSpaceException("Unknown bucket space mapping for document type '" + + id.getDocType() + "' in id: " + id.toString(), VESPA_STRLOC); +} + +document::BucketSpace ConfigurableBucketResolver::bucketSpaceFromName(const vespalib::string& name) const { + return spi::FixedBucketSpaces::from_string(name); +} + +vespalib::string ConfigurableBucketResolver::nameFromBucketSpace(const document::BucketSpace& space) const { + return spi::FixedBucketSpaces::to_string(space); +} + +std::shared_ptr<ConfigurableBucketResolver> ConfigurableBucketResolver::from_config( + const vespa::config::content::core::BucketspacesConfig& config) { + ConfigurableBucketResolver::BucketSpaceMapping type_to_space; + for (auto& mapping : config.documenttype) { + type_to_space.emplace(mapping.name, spi::FixedBucketSpaces::from_string(mapping.bucketspace)); + } + return std::make_shared<ConfigurableBucketResolver>(std::move(type_to_space)); +} + +} diff --git a/storage/src/vespa/storage/storageserver/configurable_bucket_resolver.h b/storage/src/vespa/storage/storageserver/configurable_bucket_resolver.h new file mode 100644 index 00000000000..acebd9777fb --- /dev/null +++ b/storage/src/vespa/storage/storageserver/configurable_bucket_resolver.h @@ -0,0 +1,36 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include <vespa/storage/config/config-bucketspaces.h> +#include <vespa/storage/common/bucket_resolver.h> +#include <vespa/vespalib/stllike/hash_fun.h> +#include <memory> +#include <unordered_map> + +namespace storage { + +/** + * Immutable implementation of BucketResolver which maintains an explicit + * mapping from document type to bucket space. + * + * If an unknown document type or bucket space is given as an argument, + * an spi::UnknownBucketSpaceException is thrown. + */ +class ConfigurableBucketResolver : public BucketResolver { +public: + using BucketSpaceMapping = std::unordered_map<vespalib::string, document::BucketSpace, vespalib::hash<vespalib::string>>; + const BucketSpaceMapping _type_to_space; +public: + explicit ConfigurableBucketResolver(BucketSpaceMapping type_to_space) + : _type_to_space(std::move(type_to_space)) + {} + + document::Bucket bucketFromId(const document::DocumentId&) const override; + document::BucketSpace bucketSpaceFromName(const vespalib::string& name) const override; + vespalib::string nameFromBucketSpace(const document::BucketSpace& space) const override; + + static std::shared_ptr<ConfigurableBucketResolver> from_config( + const vespa::config::content::core::BucketspacesConfig& config); +}; + +}
\ No newline at end of file diff --git a/storage/src/vespa/storage/storageserver/documentapiconverter.cpp b/storage/src/vespa/storage/storageserver/documentapiconverter.cpp index c2761b3d832..09ca9924891 100644 --- a/storage/src/vespa/storage/storageserver/documentapiconverter.cpp +++ b/storage/src/vespa/storage/storageserver/documentapiconverter.cpp @@ -24,9 +24,9 @@ using document::BucketSpace; namespace storage { DocumentApiConverter::DocumentApiConverter(const config::ConfigUri &configUri, - const BucketResolver &bucketResolver) + std::shared_ptr<const BucketResolver> bucketResolver) : _priConverter(std::make_unique<PriorityConverter>(configUri)), - _bucketResolver(bucketResolver) + _bucketResolver(std::move(bucketResolver)) {} DocumentApiConverter::~DocumentApiConverter() {} @@ -42,7 +42,7 @@ DocumentApiConverter::toStorageAPI(documentapi::DocumentMessage& fromMsg, case DocumentProtocol::MESSAGE_PUTDOCUMENT: { documentapi::PutDocumentMessage& from(static_cast<documentapi::PutDocumentMessage&>(fromMsg)); - document::Bucket bucket = _bucketResolver.bucketFromId(from.getDocument().getId()); + document::Bucket bucket = bucketResolver()->bucketFromId(from.getDocument().getId()); auto to = std::make_unique<api::PutCommand>(bucket, from.stealDocument(), from.getTimestamp()); to->setCondition(from.getCondition()); toMsg = std::move(to); @@ -51,7 +51,7 @@ DocumentApiConverter::toStorageAPI(documentapi::DocumentMessage& fromMsg, case DocumentProtocol::MESSAGE_UPDATEDOCUMENT: { documentapi::UpdateDocumentMessage& from(static_cast<documentapi::UpdateDocumentMessage&>(fromMsg)); - document::Bucket bucket = _bucketResolver.bucketFromId(from.getDocumentUpdate().getId()); + document::Bucket bucket = bucketResolver()->bucketFromId(from.getDocumentUpdate().getId()); auto to = std::make_unique<api::UpdateCommand>(bucket, from.stealDocumentUpdate(), from.getNewTimestamp()); to->setOldTimestamp(from.getOldTimestamp()); to->setCondition(from.getCondition()); @@ -61,7 +61,7 @@ DocumentApiConverter::toStorageAPI(documentapi::DocumentMessage& fromMsg, case DocumentProtocol::MESSAGE_REMOVEDOCUMENT: { documentapi::RemoveDocumentMessage& from(static_cast<documentapi::RemoveDocumentMessage&>(fromMsg)); - auto to = std::make_unique<api::RemoveCommand>(_bucketResolver.bucketFromId(from.getDocumentId()), from.getDocumentId(), 0); + auto to = std::make_unique<api::RemoveCommand>(bucketResolver()->bucketFromId(from.getDocumentId()), from.getDocumentId(), 0); to->setCondition(from.getCondition()); toMsg = std::move(to); break; @@ -69,14 +69,14 @@ DocumentApiConverter::toStorageAPI(documentapi::DocumentMessage& fromMsg, case DocumentProtocol::MESSAGE_GETDOCUMENT: { documentapi::GetDocumentMessage& from(static_cast<documentapi::GetDocumentMessage&>(fromMsg)); - auto to = std::make_unique<api::GetCommand>(_bucketResolver.bucketFromId(from.getDocumentId()), from.getDocumentId(), from.getFieldSet()); + auto to = std::make_unique<api::GetCommand>(bucketResolver()->bucketFromId(from.getDocumentId()), from.getDocumentId(), from.getFieldSet()); toMsg.reset(to.release()); break; } case DocumentProtocol::MESSAGE_CREATEVISITOR: { documentapi::CreateVisitorMessage& from(static_cast<documentapi::CreateVisitorMessage&>(fromMsg)); - auto to = std::make_unique<api::CreateVisitorCommand>(_bucketResolver.bucketSpaceFromName(from.getBucketSpace()), + auto to = std::make_unique<api::CreateVisitorCommand>(bucketResolver()->bucketSpaceFromName(from.getBucketSpace()), from.getLibraryName(), from.getInstanceId(), from.getDocumentSelection()); @@ -118,14 +118,14 @@ DocumentApiConverter::toStorageAPI(documentapi::DocumentMessage& fromMsg, case DocumentProtocol::MESSAGE_STATBUCKET: { documentapi::StatBucketMessage& from(static_cast<documentapi::StatBucketMessage&>(fromMsg)); - document::Bucket bucket(_bucketResolver.bucketSpaceFromName(from.getBucketSpace()), from.getBucketId()); + document::Bucket bucket(bucketResolver()->bucketSpaceFromName(from.getBucketSpace()), from.getBucketId()); toMsg = std::make_unique<api::StatBucketCommand>(bucket, from.getDocumentSelection()); break; } case DocumentProtocol::MESSAGE_GETBUCKETLIST: { documentapi::GetBucketListMessage& from(static_cast<documentapi::GetBucketListMessage&>(fromMsg)); - document::Bucket bucket(_bucketResolver.bucketSpaceFromName(from.getBucketSpace()), from.getBucketId()); + document::Bucket bucket(bucketResolver()->bucketSpaceFromName(from.getBucketSpace()), from.getBucketId()); toMsg = std::make_unique<api::GetBucketListCommand>(bucket); break; } @@ -145,7 +145,7 @@ DocumentApiConverter::toStorageAPI(documentapi::DocumentMessage& fromMsg, case DocumentProtocol::MESSAGE_REMOVELOCATION: { documentapi::RemoveLocationMessage& from(static_cast<documentapi::RemoveLocationMessage&>(fromMsg)); - document::Bucket bucket(_bucketResolver.bucketSpaceFromName(from.getBucketSpace()), document::BucketId(0)); + document::Bucket bucket(bucketResolver()->bucketSpaceFromName(from.getBucketSpace()), document::BucketId(0)); api::RemoveLocationCommand::UP to(new api::RemoveLocationCommand(from.getDocumentSelection(), bucket)); toMsg.reset(to.release()); break; @@ -298,7 +298,7 @@ DocumentApiConverter::toDocumentAPI(api::StorageCommand& fromMsg, const document documentapi::CreateVisitorMessage::UP to( new documentapi::CreateVisitorMessage(from.getLibraryName(), from.getInstanceId(), from.getControlDestination(), from.getDataDestination())); - to->setBucketSpace(_bucketResolver.nameFromBucketSpace(from.getBucketSpace())); + to->setBucketSpace(bucketResolver()->nameFromBucketSpace(from.getBucketSpace())); to->setDocumentSelection(from.getDocumentSelection()); to->setMaximumPendingReplyCount(from.getMaximumPendingReplyCount()); to->setParameters(from.getParameters()); @@ -325,7 +325,7 @@ DocumentApiConverter::toDocumentAPI(api::StorageCommand& fromMsg, const document { api::StatBucketCommand& from(static_cast<api::StatBucketCommand&>(fromMsg)); auto statMsg = std::make_unique<documentapi::StatBucketMessage>(from.getBucket().getBucketId(), from.getDocumentSelection()); - statMsg->setBucketSpace(_bucketResolver.nameFromBucketSpace(from.getBucket().getBucketSpace())); + statMsg->setBucketSpace(bucketResolver()->nameFromBucketSpace(from.getBucket().getBucketSpace())); toMsg = std::move(statMsg); break; } @@ -404,4 +404,14 @@ DocumentApiConverter::transferReplyState(api::StorageReply& fromMsg, mbus::Reply } } +std::shared_ptr<const BucketResolver> DocumentApiConverter::bucketResolver() const { + std::lock_guard lock(_mutex); + return _bucketResolver; +} + +void DocumentApiConverter::setBucketResolver(std::shared_ptr<const BucketResolver> resolver) { + std::lock_guard lock(_mutex); + _bucketResolver = std::move(resolver); +} + } // storage diff --git a/storage/src/vespa/storage/storageserver/documentapiconverter.h b/storage/src/vespa/storage/storageserver/documentapiconverter.h index 5310bcd0127..546bc86a007 100644 --- a/storage/src/vespa/storage/storageserver/documentapiconverter.h +++ b/storage/src/vespa/storage/storageserver/documentapiconverter.h @@ -4,6 +4,7 @@ #include <vespa/documentapi/messagebus/messages/documentmessage.h> #include <vespa/documentapi/messagebus/messages/documentreply.h> #include <vespa/document/repo/documenttyperepo.h> +#include <mutex> namespace config { class ConfigUri; } namespace storage { @@ -23,7 +24,7 @@ class DocumentApiConverter { public: DocumentApiConverter(const config::ConfigUri &configUri, - const BucketResolver &bucketResolver); + std::shared_ptr<const BucketResolver> bucketResolver); ~DocumentApiConverter(); std::unique_ptr<api::StorageCommand> toStorageAPI(documentapi::DocumentMessage& msg, const document::DocumentTypeRepo::SP &repo); @@ -31,9 +32,14 @@ public: void transferReplyState(storage::api::StorageReply& from, mbus::Reply& to); std::unique_ptr<mbus::Message> toDocumentAPI(api::StorageCommand& cmd, const document::DocumentTypeRepo::SP &repo); const PriorityConverter& getPriorityConverter() const { return *_priConverter; } + + // BucketResolver getter and setter are both thread safe. + std::shared_ptr<const BucketResolver> bucketResolver() const; + void setBucketResolver(std::shared_ptr<const BucketResolver> resolver); private: + mutable std::mutex _mutex; std::unique_ptr<PriorityConverter> _priConverter; - const BucketResolver &_bucketResolver; + std::shared_ptr<const BucketResolver> _bucketResolver; }; } // namespace storage diff --git a/storage/src/vespa/storage/storageserver/storagenode.cpp b/storage/src/vespa/storage/storageserver/storagenode.cpp index ba1556bd3b9..d60f46e5a07 100644 --- a/storage/src/vespa/storage/storageserver/storagenode.cpp +++ b/storage/src/vespa/storage/storageserver/storagenode.cpp @@ -76,12 +76,38 @@ StorageNode::StorageNode( std::unique_ptr<HostInfo> hostInfo, RunMode mode) : _singleThreadedDebugMode(mode == SINGLE_THREADED_TEST_MODE), + _configFetcher(), _hostInfo(std::move(hostInfo)), _context(context), _generationFetcher(generationFetcher), + _rootFolder(), _attemptedStopped(false), + _pidFile(), + _statusWebServer(), + _metrics(), + _metricManager(), + _deadLockDetector(), + _statusMetrics(), + _stateReporter(), + _stateManager(), + _chain(), + _configLock(), + _initial_config_mutex(), + _serverConfig(), + _clusterConfig(), + _distributionConfig(), + _priorityConfig(), + _doctypesConfig(), + _bucketSpacesConfig(), + _newServerConfig(), + _newClusterConfig(), + _newDistributionConfig(), + _newPriorityConfig(), + _newDoctypesConfig(), + _newBucketSpacesConfig(), + _component(), _configUri(configUri), - _communicationManager(0) + _communicationManager(nullptr) { } @@ -93,6 +119,7 @@ StorageNode::subscribeToConfigs() _configFetcher->subscribe<UpgradingConfig>(_configUri.getConfigId(), this); _configFetcher->subscribe<StorServerConfig>(_configUri.getConfigId(), this); _configFetcher->subscribe<StorPrioritymappingConfig>(_configUri.getConfigId(), this); + _configFetcher->subscribe<BucketspacesConfig>(_configUri.getConfigId(), this); _configFetcher->start(); @@ -101,6 +128,7 @@ StorageNode::subscribeToConfigs() _clusterConfig = std::move(_newClusterConfig); _distributionConfig = std::move(_newDistributionConfig); _priorityConfig = std::move(_newPriorityConfig); + _bucketSpacesConfig = std::move(_newBucketSpacesConfig); } void @@ -127,6 +155,7 @@ StorageNode::initialize() _context.getComponentRegister().setBucketIdFactory(document::BucketIdFactory()); _context.getComponentRegister().setDistribution(make_shared<lib::Distribution>(*_distributionConfig)); _context.getComponentRegister().setPriorityConfig(*_priorityConfig); + _context.getComponentRegister().setBucketSpacesConfig(*_bucketSpacesConfig); _metrics.reset(new StorageMetricSet); _component.reset(new StorageComponent(_context.getComponentRegister(), "storagenode")); @@ -315,6 +344,11 @@ StorageNode::handleLiveConfigUpdate(const InitialGuard & initGuard) _priorityConfig = std::move(_newPriorityConfig); _context.getComponentRegister().setPriorityConfig(*_priorityConfig); } + if (_newBucketSpacesConfig) { + _bucketSpacesConfig = std::move(_newBucketSpacesConfig); + _context.getComponentRegister().setBucketSpacesConfig(*_bucketSpacesConfig); + // TODO: Add new bucket space resolver to document api converter + } } void @@ -430,7 +464,7 @@ void StorageNode::configure(std::unique_ptr<StorServerConfig> config) // updates { vespalib::LockGuard configLockGuard(_configLock); - _newServerConfig.reset(config.release()); + _newServerConfig = std::move(config); } if (_serverConfig) { InitialGuard concurrent_config_guard(_initial_config_mutex); @@ -447,7 +481,7 @@ StorageNode::configure(std::unique_ptr<UpgradingConfig> config) // updates { vespalib::LockGuard configLockGuard(_configLock); - _newClusterConfig.reset(config.release()); + _newClusterConfig = std::move(config); } if (_clusterConfig) { InitialGuard concurrent_config_guard(_initial_config_mutex); @@ -464,7 +498,7 @@ StorageNode::configure(std::unique_ptr<StorDistributionConfig> config) // updates { vespalib::LockGuard configLockGuard(_configLock); - _newDistributionConfig.reset(config.release()); + _newDistributionConfig = std::move(config); } if (_distributionConfig) { InitialGuard concurrent_config_guard(_initial_config_mutex); @@ -477,7 +511,7 @@ StorageNode::configure(std::unique_ptr<StorPrioritymappingConfig> config) { { vespalib::LockGuard configLockGuard(_configLock); - _newPriorityConfig.reset(config.release()); + _newPriorityConfig = std::move(config); } if (_priorityConfig) { InitialGuard concurrent_config_guard(_initial_config_mutex); @@ -485,15 +519,16 @@ StorageNode::configure(std::unique_ptr<StorPrioritymappingConfig> config) } } -void StorageNode::configure(std::unique_ptr<document::DocumenttypesConfig> config, - bool hasChanged, int64_t generation) +void +StorageNode::configure(std::unique_ptr<document::DocumenttypesConfig> config, + bool hasChanged, int64_t generation) { (void) generation; if (!hasChanged) return; { vespalib::LockGuard configLockGuard(_configLock); - _newDoctypesConfig.reset(config.release()); + _newDoctypesConfig = std::move(config); } if (_doctypesConfig) { InitialGuard concurrent_config_guard(_initial_config_mutex); @@ -501,6 +536,19 @@ void StorageNode::configure(std::unique_ptr<document::DocumenttypesConfig> confi } } +void +StorageNode::configure(std::unique_ptr<BucketspacesConfig> config) +{ + { + vespalib::LockGuard configLockGuard(_configLock); + _newBucketSpacesConfig = std::move(config); + } + if (_bucketSpacesConfig) { + InitialGuard concurrent_config_guard(_initial_config_mutex); + handleLiveConfigUpdate(concurrent_config_guard); + } +} + bool StorageNode::attemptedStopped() const { diff --git a/storage/src/vespa/storage/storageserver/storagenode.h b/storage/src/vespa/storage/storageserver/storagenode.h index e9d3004be68..a07d1c0c534 100644 --- a/storage/src/vespa/storage/storageserver/storagenode.h +++ b/storage/src/vespa/storage/storageserver/storagenode.h @@ -12,20 +12,19 @@ #pragma once -#include <vespa/storage/storageutil/resumeguard.h> -#include <vespa/storage/common/doneinitializehandler.h> -#include <vespa/storageframework/generic/metric/metricupdatehook.h> -#include <vespa/storageframework/defaultimplementation/component/componentregisterimpl.h> - -#include <vespa/config/subscription/configuri.h> -#include <vespa/config/helper/ifetchercallback.h> +#include <vespa/config-stor-distribution.h> +#include <vespa/config-upgrading.h> #include <vespa/config/helper/configfetcher.h> - +#include <vespa/config/helper/ifetchercallback.h> +#include <vespa/config/subscription/configuri.h> +#include <vespa/document/config/config-documenttypes.h> +#include <vespa/storage/common/doneinitializehandler.h> +#include <vespa/storage/config/config-bucketspaces.h> #include <vespa/storage/config/config-stor-prioritymapping.h> #include <vespa/storage/config/config-stor-server.h> -#include <vespa/document/config/config-documenttypes.h> -#include <vespa/config-upgrading.h> -#include <vespa/config-stor-distribution.h> +#include <vespa/storage/storageutil/resumeguard.h> +#include <vespa/storageframework/defaultimplementation/component/componentregisterimpl.h> +#include <vespa/storageframework/generic/metric/metricupdatehook.h> #include <mutex> namespace document { class DocumentTypeRepo; } @@ -54,6 +53,7 @@ class StorageNode : private config::IFetcherCallback<vespa::config::content::cor private config::IFetcherCallback<vespa::config::content::StorDistributionConfig>, private config::IFetcherCallback<vespa::config::content::UpgradingConfig>, private config::IFetcherCallback<vespa::config::content::core::StorPrioritymappingConfig>, + private config::IFetcherCallback<vespa::config::content::core::BucketspacesConfig>, private framework::MetricUpdateHook, private DoneInitializeHandler, private framework::defaultimplementation::ShutdownListener @@ -101,6 +101,7 @@ protected: using UpgradingConfig = vespa::config::content::UpgradingConfig; using StorDistributionConfig = vespa::config::content::StorDistributionConfig; using StorPrioritymappingConfig = vespa::config::content::core::StorPrioritymappingConfig; + using BucketspacesConfig = vespa::config::content::core::BucketspacesConfig; private: bool _singleThreadedDebugMode; // Subscriptions to config @@ -137,6 +138,7 @@ private: void configure(std::unique_ptr<StorPrioritymappingConfig>) override; virtual void configure(std::unique_ptr<document::DocumenttypesConfig> config, bool hasChanged, int64_t generation); + void configure(std::unique_ptr<BucketspacesConfig>) override; void updateUpgradeFlag(const UpgradingConfig&); protected: @@ -151,12 +153,14 @@ protected: std::unique_ptr<StorDistributionConfig> _distributionConfig; std::unique_ptr<StorPrioritymappingConfig> _priorityConfig; std::unique_ptr<document::DocumenttypesConfig> _doctypesConfig; + std::unique_ptr<BucketspacesConfig> _bucketSpacesConfig; // New configs gotten that has yet to have been handled std::unique_ptr<StorServerConfig> _newServerConfig; std::unique_ptr<UpgradingConfig> _newClusterConfig; std::unique_ptr<StorDistributionConfig> _newDistributionConfig; std::unique_ptr<StorPrioritymappingConfig> _newPriorityConfig; std::unique_ptr<document::DocumenttypesConfig> _newDoctypesConfig; + std::unique_ptr<BucketspacesConfig> _newBucketSpacesConfig; std::unique_ptr<StorageComponent> _component; config::ConfigUri _configUri; CommunicationManager* _communicationManager; diff --git a/storageapi/src/vespa/storageapi/message/state.h b/storageapi/src/vespa/storageapi/message/state.h index e8062c71d22..746d92fce6b 100644 --- a/storageapi/src/vespa/storageapi/message/state.h +++ b/storageapi/src/vespa/storageapi/message/state.h @@ -6,8 +6,7 @@ #include <vespa/storageapi/messageapi/storagereply.h> #include <vespa/vdslib/state/clusterstate.h> -namespace storage { -namespace api { +namespace storage::api { /** * @class GetNodeStateCommand @@ -90,5 +89,4 @@ public: DECLARE_STORAGEREPLY(SetSystemStateReply, onSetSystemStateReply) }; -} // api -} // storage +} diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java index 7874dcb24ab..16a541f939c 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java @@ -5,6 +5,7 @@ import com.google.common.annotations.Beta; import com.yahoo.vespa.http.client.Result; import com.yahoo.vespa.http.client.config.Endpoint; import com.yahoo.vespa.http.client.core.Document; +import com.yahoo.vespa.http.client.core.Exceptions; import com.yahoo.vespa.http.client.core.operationProcessor.EndPointResultFactory; import com.yahoo.vespa.http.client.core.EndpointResult; import com.yahoo.vespa.http.client.core.ServerResponseException; @@ -318,29 +319,28 @@ class IOThread implements Runnable, AutoCloseable { successfullHandshakes.getAndIncrement(); } catch (ServerResponseException ser) { executeProblemsCounter.incrementAndGet(); - log.log(Level.INFO, "Handshake did not work out " + endpoint, ser.getMessage()); + log.log(Level.INFO, "Handshake did not work out " + endpoint, Exceptions.toMessageString(ser)); drainFirstDocumentsInQueueIfOld(); return ThreadState.CONNECTED; } catch (Throwable throwable) { // This cover IOException as well executeProblemsCounter.incrementAndGet(); - log.log(Level.INFO, "Problem with Handshake " + endpoint, throwable.getMessage()); + log.log(Level.INFO, "Problem with Handshake " + endpoint, Exceptions.toMessageString(throwable)); drainFirstDocumentsInQueueIfOld(); client.close(); return ThreadState.DISCONNECTED; } return ThreadState.SESSION_SYNCED; case SESSION_SYNCED: - final int maxWaitTimeMilliSecs = 100; try { - ProcessResponse processResponse = pullAndProcessData(maxWaitTimeMilliSecs); + ProcessResponse processResponse = pullAndProcessData(100); gatewayThrottler.handleCall(processResponse.transitiveErrorCount); } catch (ServerResponseException ser) { - log.info("Problems while handing data over to gateway " + endpoint + " " + ser.getMessage()); + log.info("Problems while handing data over to gateway " + endpoint + ": " + Exceptions.toMessageString(ser)); return ThreadState.CONNECTED; } catch (Throwable e) { // Covers IOException as well - log.info("Problems while handing data over to gateway " + endpoint + " " + e.getMessage()); + log.info("Problems while handing data over to gateway " + endpoint + ": " + Exceptions.toMessageString(e)); client.close(); return ThreadState.DISCONNECTED; } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java index 0a9fe72552c..5907694f55a 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java @@ -177,10 +177,14 @@ public class OperationProcessor { docSendInfoByOperationId.remove(endpointResult.getOperationId()); String documentId = documentSendInfo.getDocument().getDocumentId(); - inflightDocumentIds.remove(documentId); - + /** + * If we got a pending operation against this document + * dont't remove it from inflightDocuments and send blocked document operation + */ List<Document> blockedDocuments = blockedDocumentsByDocumentId.get(documentId); - if (! blockedDocuments.isEmpty()) { + if (blockedDocuments.isEmpty()) { + inflightDocumentIds.remove(documentId); + } else { sendToClusters(blockedDocuments.remove(0)); } return result; diff --git a/vespajlib/src/main/java/com/yahoo/net/HostName.java b/vespajlib/src/main/java/com/yahoo/net/HostName.java index 37f7fe80246..157239e456f 100644 --- a/vespajlib/src/main/java/com/yahoo/net/HostName.java +++ b/vespajlib/src/main/java/com/yahoo/net/HostName.java @@ -27,7 +27,7 @@ public class HostName { private static final Logger logger = Logger.getLogger(HostName.class.getName()); - private static String cachedHostName = null; + private static String preferredHostName = null; /** * Return a public and fully qualified hostname for localhost that resolves to an IP address on @@ -38,14 +38,14 @@ public class HostName { * @throws RuntimeException if accessing the network or the 'hostname' command fails */ public static synchronized String getLocalhost() { - if (cachedHostName == null) { + if (preferredHostName == null) { try { - cachedHostName = getPreferredHostName(); + preferredHostName = getPreferredHostName(); } catch (Exception e) { throw new RuntimeException("Failed to find a preferred hostname", e); } } - return cachedHostName; + return preferredHostName; } private static String getPreferredHostName() throws Exception { @@ -178,4 +178,7 @@ public class HostName { } } + public static void setHostNameForTestingOnly(String hostName) { + preferredHostName = hostName; + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index 00e106dd035..01bf082d32f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -7,13 +7,13 @@ import java.util.Arrays; /** * The sizes of a set of dimensions. - * + * * @author bratseth */ @Beta public final class DimensionSizes { - private final int[] sizes; + private final long[] sizes; private DimensionSizes(Builder builder) { this.sizes = builder.sizes; @@ -25,15 +25,15 @@ public final class DimensionSizes { * * @throws IndexOutOfBoundsException if the index is larger than the number of dimensions in this tensor minus one */ - public int size(int dimensionIndex) { return sizes[dimensionIndex]; } + public long size(int dimensionIndex) { return sizes[dimensionIndex]; } /** Returns the number of dimensions this provides the size of */ public int dimensions() { return sizes.length; } /** Returns the product of the sizes of this */ - public int totalSize() { - int productSize = 1; - for (int dimensionSize : sizes ) + public long totalSize() { + long productSize = 1; + for (long dimensionSize : sizes ) productSize *= dimensionSize; return productSize; } @@ -48,19 +48,19 @@ public final class DimensionSizes { @Override public int hashCode() { return Arrays.hashCode(sizes); } - /** + /** * Builder of a set of dimension sizes. * Dimensions whose size is not set before building will get size 0. */ public final static class Builder { - private int[] sizes; + private long[] sizes; public Builder(int dimensions) { - this.sizes = new int[dimensions]; + this.sizes = new long[dimensions]; } - public Builder set(int dimensionIndex, int size) { + public Builder set(int dimensionIndex, long size) { sizes[dimensionIndex] = size; return this; } @@ -70,7 +70,7 @@ public final class DimensionSizes { * * @throws IndexOutOfBoundsException if the index is larger than the number of dimensions in this tensor minus one */ - public int size(int dimensionIndex) { return sizes[dimensionIndex]; } + public long size(int dimensionIndex) { return sizes[dimensionIndex]; } /** Returns the number of dimensions this provides the size of */ public int dimensions() { return sizes.length; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index c207dabca3a..7130c053e9f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -25,12 +25,12 @@ public class IndexedTensor implements Tensor { /** The prescribed and possibly abstract type this is an instance of */ private final TensorType type; - + /** The sizes of the dimensions of this in the order of the dimensions of the type */ private final DimensionSizes dimensionSizes; - + private final double[] values; - + private IndexedTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) { this.type = type; this.dimensionSizes = dimensionSizes; @@ -38,13 +38,13 @@ public class IndexedTensor implements Tensor { } @Override - public int size() { + public long size() { return values.length; } /** - * Returns an iterator over the cells of this. - * Cells are returned in order of increasing indexes in each dimension, increasing + * Returns an iterator over the cells of this. + * Cells are returned in order of increasing indexes in each dimension, increasing * indexes of later dimensions in the dimension type before earlier. */ @Override @@ -55,10 +55,10 @@ public class IndexedTensor implements Tensor { /** Returns an iterator over all the cells in this tensor which matches the given partial address */ // TODO: Move up to Tensor and create a mixed tensor which can implement it (and subspace iterators) efficiently public SubspaceIterator cellIterator(PartialAddress partialAddress, DimensionSizes iterationSizes) { - int[] startAddress = new int[type().dimensions().size()]; + long[] startAddress = new long[type().dimensions().size()]; List<Integer> iterateDimensions = new ArrayList<>(); for (int i = 0; i < type().dimensions().size(); i++) { - int partialAddressLabel = partialAddress.intLabel(type.dimensions().get(i).name()); + long partialAddressLabel = partialAddress.numericLabel(type.dimensions().get(i).name()); if (partialAddressLabel >= 0) // iterate at this label startAddress[i] = partialAddressLabel; else // iterate over this dimension @@ -69,7 +69,7 @@ public class IndexedTensor implements Tensor { /** * Returns an iterator over the values of this. - * Values are returned in order of increasing indexes in each dimension, increasing + * Values are returned in order of increasing indexes in each dimension, increasing * indexes of later dimensions in the dimension type before earlier. */ @Override @@ -81,7 +81,7 @@ public class IndexedTensor implements Tensor { * Returns an iterator over value iterators where the outer iterator is over each unique value of the dimensions * given and the inner iterator is over each unique value of the rest of the dimensions, in the same order as * other iterator. - * + * * @param dimensions the names of the dimensions of the superspace * @param sizes the size of each dimension in the space we are returning values for, containing * one value per dimension of this tensor (in order). Each size may be the same or smaller @@ -96,14 +96,14 @@ public class IndexedTensor implements Tensor { return subspaceIterator(dimensions, dimensionSizes); } - /** + /** * Returns the value at the given indexes - * + * * @param indexes the indexes into the dimensions of this. Must be one number per dimension of this * @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given */ - public double get(int ... indexes) { - return values[toValueIndex(indexes, dimensionSizes)]; + public double get(long ... indexes) { + return values[(int)toValueIndex(indexes, dimensionSizes)]; } /** Returns the value at this address, or NaN if there is no value at this address */ @@ -111,20 +111,20 @@ public class IndexedTensor implements Tensor { public double get(TensorAddress address) { // optimize for fast lookup within bounds: try { - return values[toValueIndex(address, dimensionSizes)]; + return values[(int)toValueIndex(address, dimensionSizes)]; } catch (IndexOutOfBoundsException e) { return Double.NaN; } } - private double get(int valueIndex) { return values[valueIndex]; } - - private static int toValueIndex(int[] indexes, DimensionSizes sizes) { + private double get(long valueIndex) { return values[(int)valueIndex]; } + + private static long toValueIndex(long[] indexes, DimensionSizes sizes) { if (indexes.length == 1) return indexes[0]; // for speed if (indexes.length == 0) return 0; // for speed - int valueIndex = 0; + long valueIndex = 0; for (int i = 0; i < indexes.length; i++) { if (indexes[i] >= sizes.size(i)) { throw new IndexOutOfBoundsException(); @@ -134,21 +134,21 @@ public class IndexedTensor implements Tensor { return valueIndex; } - private static int toValueIndex(TensorAddress address, DimensionSizes sizes) { + private static long toValueIndex(TensorAddress address, DimensionSizes sizes) { if (address.isEmpty()) return 0; - int valueIndex = 0; + long valueIndex = 0; for (int i = 0; i < address.size(); i++) { - if (address.intLabel(i) >= sizes.size(i)) { + if (address.numericLabel(i) >= sizes.size(i)) { throw new IndexOutOfBoundsException(); } - valueIndex += productOfDimensionsAfter(i, sizes) * address.intLabel(i); + valueIndex += productOfDimensionsAfter(i, sizes) * address.numericLabel(i); } return valueIndex; } - private static int productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) { - int product = 1; + private static long productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) { + long product = 1; for (int i = afterIndex + 1; i < sizes.dimensions(); i++) product *= sizes.size(i); return product; @@ -165,22 +165,22 @@ public class IndexedTensor implements Tensor { public Map<TensorAddress, Double> cells() { if (dimensionSizes.dimensions() == 0) return Collections.singletonMap(TensorAddress.of(), values[0]); - + ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>(); Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); - for (int i = 0; i < values.length; i++) { + for (long i = 0; i < values.length; i++) { indexes.next(); - builder.put(indexes.toAddress(), values[i]); + builder.put(indexes.toAddress(), values[(int)i]); } return builder.build(); } - + @Override public int hashCode() { return Arrays.hashCode(values); } @Override public String toString() { return Tensor.toStandardString(this); } - + @Override public boolean equals(Object other) { if ( ! ( other instanceof Tensor)) return false; @@ -188,9 +188,9 @@ public class IndexedTensor implements Tensor { } public abstract static class Builder implements Tensor.Builder { - + final TensorType type; - + private Builder(TensorType type) { this.type = type; } @@ -202,7 +202,7 @@ public class IndexedTensor implements Tensor { return new UnboundBuilder(type); } - /** + /** * Create a builder with dimension size information for this instance. Must be one size entry per dimension, * and, agree with the type size information when specified in the type. * If sizes are completely specified in the type this size information is redundant. @@ -210,20 +210,20 @@ public class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes) { // validate if (sizes.dimensions() != type.dimensions().size()) - throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " + + throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " + "for " + type); for (int i = 0; i < sizes.dimensions(); i++ ) { - Optional<Integer> size = type.dimensions().get(i).size(); + Optional<Long> size = type.dimensions().get(i).size(); if (size.isPresent() && size.get() < sizes.size(i)) - throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " + + throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " + sizes.size(i) + " but cannot be larger than " + size.get() + " in " + type); } - + return new BoundBuilder(type, sizes); } - public abstract Builder cell(double value, int ... indexes); + public abstract Builder cell(double value, long ... indexes); @Override public TensorType type() { return type; } @@ -232,7 +232,7 @@ public class IndexedTensor implements Tensor { public abstract IndexedTensor build(); } - + /** A bound builder can create the double array directly */ public static class BoundBuilder extends Builder { @@ -255,15 +255,15 @@ public class IndexedTensor implements Tensor { if ( sizes.dimensions() != type.dimensions().size()) throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type); this.sizes = sizes; - values = new double[sizes.totalSize()]; + values = new double[(int)sizes.totalSize()]; } - + @Override - public BoundBuilder cell(double value, int ... indexes) { - values[toValueIndex(indexes, sizes)] = value; + public BoundBuilder cell(double value, long ... indexes) { + values[(int)toValueIndex(indexes, sizes)] = value; return this; } - + @Override public CellBuilder cell() { return new CellBuilder(type, this); @@ -271,7 +271,7 @@ public class IndexedTensor implements Tensor { @Override public Builder cell(TensorAddress address, double value) { - values[toValueIndex(address, sizes)] = value; + values[(int)toValueIndex(address, sizes)] = value; return this; } @@ -286,21 +286,21 @@ public class IndexedTensor implements Tensor { @Override public Builder cell(Cell cell, double value) { - int directIndex = cell.getDirectIndex(); + long directIndex = cell.getDirectIndex(); if (directIndex >= 0) // optimization - values[directIndex] = value; + values[(int)directIndex] = value; else super.cell(cell, value); return this; } - /** - * Set a cell value by the index in the internal layout of this cell. + /** + * Set a cell value by the index in the internal layout of this cell. * This requires knowledge of the internal layout of cells in this implementation, and should therefore * probably not be used (but when it can be used it is fast). */ - public void cellByDirectIndex(int index, double value) { - values[index] = value; + public void cellByDirectIndex(long index, double value) { + values[(int)index] = value; } } @@ -326,13 +326,13 @@ public class IndexedTensor implements Tensor { return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) }); DimensionSizes dimensionSizes = findDimensionSizes(firstDimension); - double[] values = new double[dimensionSizes.totalSize()]; + double[] values = new double[(int)dimensionSizes.totalSize()]; fillValues(0, 0, firstDimension, dimensionSizes, values); return new IndexedTensor(type, dimensionSizes, values); } - + private DimensionSizes findDimensionSizes(List<Object> firstDimension) { - List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size()); + List<Long> dimensionSizeList = new ArrayList<>(type.dimensions().size()); findDimensionSizes(0, dimensionSizeList, firstDimension); DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); // may be longer than the list but that's correct for (int i = 0; i < b.dimensions(); i++) { @@ -343,33 +343,33 @@ public class IndexedTensor implements Tensor { } @SuppressWarnings("unchecked") - private void findDimensionSizes(int currentDimensionIndex, List<Integer> dimensionSizes, List<Object> currentDimension) { + private void findDimensionSizes(int currentDimensionIndex, List<Long> dimensionSizes, List<Object> currentDimension) { if (currentDimensionIndex == dimensionSizes.size()) - dimensionSizes.add(currentDimension.size()); + dimensionSizes.add((long)currentDimension.size()); else if (dimensionSizes.get(currentDimensionIndex) != currentDimension.size()) - throw new IllegalArgumentException("Missing values in dimension " + + throw new IllegalArgumentException("Missing values in dimension " + type.dimensions().get(currentDimensionIndex) + " in " + type); - + for (Object value : currentDimension) if (value instanceof List) findDimensionSizes(currentDimensionIndex + 1, dimensionSizes, (List<Object>)value); } @SuppressWarnings("unchecked") - private void fillValues(int currentDimensionIndex, int offset, List<Object> currentDimension, + private void fillValues(int currentDimensionIndex, long offset, List<Object> currentDimension, DimensionSizes sizes, double[] values) { if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension - for (int i = 0; i < currentDimension.size(); i++) + for (long i = 0; i < currentDimension.size(); i++) fillValues(currentDimensionIndex + 1, offset + productOfDimensionsAfter(currentDimensionIndex, sizes) * i, - (List<Object>) currentDimension.get(i), sizes, values); + (List<Object>) currentDimension.get((int)i), sizes, values); } else { // last dimension - fill values - for (int i = 0; i < currentDimension.size(); i++) { - values[offset + i] = nullAsZero((Double)currentDimension.get(i)); // fill missing values as zero + for (long i = 0; i < currentDimension.size(); i++) { + values[(int)(offset + i)] = nullAsZero((Double)currentDimension.get((int)i)); // fill missing values as zero } } } - + private double nullAsZero(Double value) { if (value == null) return 0; return value; @@ -382,9 +382,9 @@ public class IndexedTensor implements Tensor { @Override public Builder cell(TensorAddress address, double value) { - int[] indexes = new int[address.size()]; + long[] indexes = new long[address.size()]; for (int i = 0; i < address.size(); i++) { - indexes[i] = address.intLabel(i); + indexes[i] = address.numericLabel(i); } cell(value, indexes); return this; @@ -399,7 +399,7 @@ public class IndexedTensor implements Tensor { */ @SuppressWarnings("unchecked") @Override - public Builder cell(double value, int... indexes) { + public Builder cell(double value, long... indexes) { if (indexes.length != type.dimensions().size()) throw new IllegalArgumentException("Wrong number of indexes (" + indexes.length + ") for " + type); @@ -414,27 +414,27 @@ public class IndexedTensor implements Tensor { for (int dimensionIndex = 0; dimensionIndex < indexes.length; dimensionIndex++) { ensureCapacity(indexes[dimensionIndex], currentValues); if (dimensionIndex == indexes.length - 1) { // last dimension - currentValues.set(indexes[dimensionIndex], value); + currentValues.set((int)indexes[dimensionIndex], value); } else { - if (currentValues.get(indexes[dimensionIndex]) == null) - currentValues.set(indexes[dimensionIndex], new ArrayList<>()); - currentValues = (List<Object>) currentValues.get(indexes[dimensionIndex]); + if (currentValues.get((int)indexes[dimensionIndex]) == null) + currentValues.set((int)indexes[dimensionIndex], new ArrayList<>()); + currentValues = (List<Object>) currentValues.get((int)indexes[dimensionIndex]); } } return this; } /** Fill the given list with nulls if necessary to make sure it has a (possibly null) value at the given index */ - private void ensureCapacity(int index, List<Object> list) { + private void ensureCapacity(long index, List<Object> list) { while (list.size() <= index) list.add(list.size(), null); } } - + private final class CellIterator implements Iterator<Cell> { - private int count = 0; + private long count = 0; private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); private final LazyCell reusedCell = new LazyCell(indexes, Double.NaN); @@ -451,12 +451,12 @@ public class IndexedTensor implements Tensor { reusedCell.value = get(indexes.toSourceValueIndex()); return reusedCell; } - + } private final class ValueIterator implements Iterator<Double> { - private int count = 0; + private long count = 0; @Override public boolean hasNext() { @@ -466,7 +466,7 @@ public class IndexedTensor implements Tensor { @Override public Double next() { try { - return values[count++]; + return values[(int)count++]; } catch (IndexOutOfBoundsException e) { throw new NoSuchElementException("No element at position " + count); @@ -474,25 +474,25 @@ public class IndexedTensor implements Tensor { } } - + private final class SuperspaceIterator implements Iterator<SubspaceIterator> { private final Indexes superindexes; - /** Those indexes this should iterate over */ + /** The indexes this should iterate over */ private final List<Integer> subdimensionIndexes; - - /** + + /** * The sizes of the space we'll return values of, one value for each dimension of this tensor, - * which may be equal to or smaller than the sizes of this tensor + * which may be equal to or smaller than the sizes of this tensor */ private final DimensionSizes iterateSizes; - private int count = 0; - + private long count = 0; + private SuperspaceIterator(Set<String> superdimensionNames, DimensionSizes iterateSizes) { this.iterateSizes = iterateSizes; - + List<Integer> superdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for outer iterator subdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for inner iterator (max length) for (int i = type.dimensions().size() - 1; i >= 0; i-- ) { // iterate inner dimensions first @@ -501,10 +501,10 @@ public class IndexedTensor implements Tensor { else subdimensionIndexes.add(i); } - + superindexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, superdimensionIndexes); } - + @Override public boolean hasNext() { return count < superindexes.size(); @@ -527,60 +527,60 @@ public class IndexedTensor implements Tensor { */ public final class SubspaceIterator implements Iterator<Tensor.Cell> { - /** + /** * This iterator will iterate over the given dimensions, in the order given * (the first dimension index given is incremented to exhaustion first (i.e is etc.). * This may be any subset of the dimensions given by address and dimensionSizes. */ private final List<Integer> iterateDimensions; - private final int[] address; + private final long[] address; private final DimensionSizes iterateSizes; private Indexes indexes; - private int count = 0; - + private long count = 0; + /** A lazy cell for reuse */ private final LazyCell reusedCell; - - /** + + /** * Creates a new subspace iterator - * + * * @param iterateDimensions the dimensions to iterate over, given as indexes in the dimension order of the * type of the tensor this iterates over. This iterator will iterate over these - * dimensions to exhaustion in the order given (the first dimension index given is + * dimensions to exhaustion in the order given (the first dimension index given is * incremented to exhaustion first (i.e is etc.), while other dimensions will be held * at a constant position. * This may be any subset of the dimensions given by address and dimensionSizes. * This is treated as immutable. - * @param address the address of the first cell of this subspace. + * @param address the address of the first cell of this subspace. */ - private SubspaceIterator(List<Integer> iterateDimensions, int[] address, DimensionSizes iterateSizes) { + private SubspaceIterator(List<Integer> iterateDimensions, long[] address, DimensionSizes iterateSizes) { this.iterateDimensions = iterateDimensions; this.address = address; this.iterateSizes = iterateSizes; this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address); reusedCell = new LazyCell(indexes, Double.NaN); } - + /** Returns the total number of cells in this subspace */ - public int size() { + public long size() { return indexes.size(); } - + /** Returns the address of the cell this currently points to (which may be an invalid position) */ public TensorAddress address() { return indexes.toAddress(); } - + /** Rewind this iterator to the first element */ - public void reset() { + public void reset() { this.count = 0; - this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address); + this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address); } - + @Override public boolean hasNext() { - return count < indexes.size(); + return count < indexes.size(); } - + /** Returns the next cell, which is valid until next() is called again */ @Override public Cell next() { @@ -605,21 +605,21 @@ public class IndexedTensor implements Tensor { } @Override - int getDirectIndex() { return indexes.toIterationValueIndex(); } + long getDirectIndex() { return indexes.toIterationValueIndex(); } @Override public TensorAddress getKey() { return indexes.toAddress(); } - + @Override public Double getValue() { return value; } } // TODO: Make dimensionSizes a class - - /** + + /** * An array of indexes into this tensor which are able to find the next index in the value order. * next() can be called once per element in the dimensions we iterate over. It must be called once * before accessing the first position. @@ -630,8 +630,8 @@ public class IndexedTensor implements Tensor { private final DimensionSizes iterationSizes; - protected final int[] indexes; - + protected final long[] indexes; + public static Indexes of(DimensionSizes sizes) { return of(sizes, sizes); } @@ -640,7 +640,7 @@ public class IndexedTensor implements Tensor { return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions())); } - private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int size) { + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long size) { return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()), size); } @@ -648,15 +648,15 @@ public class IndexedTensor implements Tensor { return of(sourceSizes, iterateSizes, iterateDimensions, computeSize(iterateSizes, iterateDimensions)); } - private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int size) { - return of(sourceSizes, iterateSizes, iterateDimensions, new int[iterateSizes.dimensions()], size); + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, long size) { + return of(sourceSizes, iterateSizes, iterateDimensions, new long[iterateSizes.dimensions()], size); } - private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes) { + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, long[] initialIndexes) { return of(sourceSizes, iterateSizes, iterateDimensions, initialIndexes, computeSize(iterateSizes, iterateDimensions)); } - private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) { if (size == 0) { return new EmptyIndexes(sourceSizes, iterateSizes, initialIndexes); // we're told explicitly there are truly no values available } @@ -676,22 +676,22 @@ public class IndexedTensor implements Tensor { return new MultiDimensionIndexes(sourceSizes, iterateSizes, iterateDimensions, initialIndexes, size); } } - + private static List<Integer> completeIterationOrder(int length) { List<Integer> iterationDimensions = new ArrayList<>(length); for (int i = 0; i < length; i++) iterationDimensions.add(length - 1 - i); return iterationDimensions; } - - private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, int[] indexes) { + + private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, long[] indexes) { this.sourceSizes = sourceSizes; this.iterationSizes = iterationSizes; this.indexes = indexes; } - private static int computeSize(DimensionSizes sizes, List<Integer> iterateDimensions) { - int size = 1; + private static long computeSize(DimensionSizes sizes, List<Integer> iterateDimensions) { + long size = 1; for (int iterateDimension : iterateDimensions) size *= sizes.size(iterateDimension); return size; @@ -702,25 +702,25 @@ public class IndexedTensor implements Tensor { return TensorAddress.of(indexes); } - public int[] indexesCopy() { + public long[] indexesCopy() { return Arrays.copyOf(indexes, indexes.length); } /** Returns a copy of the indexes of this which must not be modified */ - public int[] indexesForReading() { return indexes; } - - int toSourceValueIndex() { - return IndexedTensor.toValueIndex(indexes, sourceSizes); + public long[] indexesForReading() { return indexes; } + + long toSourceValueIndex() { + return IndexedTensor.toValueIndex(indexes, sourceSizes); } - int toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); } + long toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); } DimensionSizes dimensionSizes() { return iterationSizes; } /** Returns an immutable list containing a copy of the indexes in this */ - public List<Integer> toList() { - ImmutableList.Builder<Integer> builder = new ImmutableList.Builder<>(); - for (int index : indexes) + public List<Long> toList() { + ImmutableList.Builder<Long> builder = new ImmutableList.Builder<>(); + for (long index : indexes) builder.add(index); return builder.build(); } @@ -729,21 +729,21 @@ public class IndexedTensor implements Tensor { public String toString() { return "indexes " + Arrays.toString(indexes); } - - public abstract int size(); - + + public abstract long size(); + public abstract void next(); } private final static class EmptyIndexes extends Indexes { - private EmptyIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int[] indexes) { + private EmptyIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long[] indexes) { super(sourceSizes, iterateSizes, indexes); } @Override - public int size() { return 0; } + public long size() { return 0; } @Override public void next() {} @@ -752,43 +752,43 @@ public class IndexedTensor implements Tensor { private final static class SingleValueIndexes extends Indexes { - private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int[] indexes) { + private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long[] indexes) { super(sourceSizes, iterateSizes, indexes); } @Override - public int size() { return 1; } + public long size() { return 1; } @Override public void next() {} } - + private static class MultiDimensionIndexes extends Indexes { - private final int size; + private final long size; private final List<Integer> iterateDimensions; - - private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { + + private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) { super(sourceSizes, iterateSizes, initialIndexes); this.iterateDimensions = iterateDimensions; this.size = size; - + // Initialize to the (virtual) position before the first cell indexes[iterateDimensions.get(0)]--; } /** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */ @Override - public int size() { + public long size() { return size; } - /** - * Advances this to the next cell in the standard indexed tensor cell order. - * The first call to this will put it at the first position. - * + /** + * Advances this to the next cell in the standard indexed tensor cell order. + * The first call to this will put it at the first position. + * * @throws RuntimeException if this is called more times than its size */ @Override @@ -802,40 +802,42 @@ public class IndexedTensor implements Tensor { } } - + /** In this case we can reuse the source index computation for the iteration index */ private final static class EqualSizeMultiDimensionIndexes extends MultiDimensionIndexes { - private int lastComputedSourceValueIndex = -1; - - private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { + private long lastComputedSourceValueIndex = -1; + + private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) { super(sizes, sizes, iterateDimensions, initialIndexes, size); } - int toSourceValueIndex() { + @Override + long toSourceValueIndex() { return lastComputedSourceValueIndex = super.toSourceValueIndex(); } // NOTE: We assume the source index always gets computed first. Otherwise using this will produce a runtime exception - int toIterationValueIndex() { return lastComputedSourceValueIndex; } + @Override + long toIterationValueIndex() { return lastComputedSourceValueIndex; } } /** In this case we can keep track of indexes using a step instead of using the more elaborate computation */ private final static class SingleDimensionIndexes extends Indexes { - private final int size; + private final long size; private final int iterateDimension; - + /** Maintain this directly as an optimization for 1-d iteration */ - private int currentSourceValueIndex, currentIterationValueIndex; + private long currentSourceValueIndex, currentIterationValueIndex; /** The iteration step in the value index space */ - private final int sourceStep, iterationStep; + private final long sourceStep, iterationStep; private SingleDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, - int iterateDimension, int[] initialIndexes, int size) { + int iterateDimension, long[] initialIndexes, long size) { super(sourceSizes, iterateSizes, initialIndexes); this.iterateDimension = iterateDimension; this.size = size; @@ -847,16 +849,16 @@ public class IndexedTensor implements Tensor { currentSourceValueIndex = IndexedTensor.toValueIndex(indexes, sourceSizes); currentIterationValueIndex = IndexedTensor.toValueIndex(indexes, iterateSizes); } - + /** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */ @Override - public int size() { + public long size() { return size; } /** - * Advances this to the next cell in the standard indexed tensor cell order. - * The first call to this will put it at the first position. + * Advances this to the next cell in the standard indexed tensor cell order. + * The first call to this will put it at the first position. * * @throws RuntimeException if this is called more times than its size */ @@ -868,28 +870,28 @@ public class IndexedTensor implements Tensor { } @Override - int toSourceValueIndex() { return currentSourceValueIndex; } + long toSourceValueIndex() { return currentSourceValueIndex; } @Override - int toIterationValueIndex() { return currentIterationValueIndex; } + long toIterationValueIndex() { return currentIterationValueIndex; } } /** In this case we only need to keep track of one index */ private final static class EqualSizeSingleDimensionIndexes extends Indexes { - private final int size; + private final long size; private final int iterateDimension; /** Maintain this directly as an optimization for 1-d iteration */ - private int currentValueIndex; + private long currentValueIndex; /** The iteration step in the value index space */ - private final int step; + private final long step; - private EqualSizeSingleDimensionIndexes(DimensionSizes sizes, - int iterateDimension, int[] initialIndexes, int size) { + private EqualSizeSingleDimensionIndexes(DimensionSizes sizes, + int iterateDimension, long[] initialIndexes, long size) { super(sizes, sizes, initialIndexes); this.iterateDimension = iterateDimension; this.size = size; @@ -902,13 +904,13 @@ public class IndexedTensor implements Tensor { /** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */ @Override - public int size() { + public long size() { return size; } /** - * Advances this to the next cell in the standard indexed tensor cell order. - * The first call to this will put it at the first position. + * Advances this to the next cell in the standard indexed tensor cell order. + * The first call to this will put it at the first position. * * @throws RuntimeException if this is called more times than its size */ @@ -919,10 +921,10 @@ public class IndexedTensor implements Tensor { } @Override - int toSourceValueIndex() { return currentValueIndex; } + long toSourceValueIndex() { return currentValueIndex; } @Override - int toIterationValueIndex() { return currentValueIndex; } + long toIterationValueIndex() { return currentValueIndex; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index 618bff0caae..15993072c37 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -27,9 +27,9 @@ public class MappedTensor implements Tensor { @Override public TensorType type() { return type; } - + @Override - public int size() { return cells.size(); } + public long size() { return cells.size(); } @Override public double get(TensorAddress address) { return cells.getOrDefault(address, Double.NaN); } @@ -56,16 +56,16 @@ public class MappedTensor implements Tensor { } public static class Builder implements Tensor.Builder { - + private final TensorType type; private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>(); - + public static Builder of(TensorType type) { return new Builder(type); } private Builder(TensorType type) { this.type = type; } - + public CellBuilder cell() { return new CellBuilder(type, this); } @@ -80,7 +80,7 @@ public class MappedTensor implements Tensor { } @Override - public Builder cell(double value, int... labels) { + public Builder cell(double value, long... labels) { cells.put(TensorAddress.of(labels), value); return this; } @@ -89,24 +89,24 @@ public class MappedTensor implements Tensor { public MappedTensor build() { return new MappedTensor(type, cells.build()); } - + } private static class CellIteratorAdaptor implements Iterator<Cell> { private final Iterator<Map.Entry<TensorAddress, Double>> adaptedIterator; - + private CellIteratorAdaptor(Iterator<Map.Entry<TensorAddress, Double>> adaptedIterator) { this.adaptedIterator = adaptedIterator; } - + @Override public boolean hasNext() { return adaptedIterator.hasNext(); } @Override public Cell next() { Map.Entry<TensorAddress, Double> entry = adaptedIterator.next(); - return new Cell(entry.getKey(), entry.getValue()); + return new Cell(entry.getKey(), entry.getValue()); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 79bb27fcd1b..0c9ed769c0d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -47,13 +47,13 @@ public class MixedTensor implements Tensor { /** Returns the size of the tensor measured in number of cells */ @Override - public int size() { return cells.size(); } + public long size() { return cells.size(); } /** Returns the value at the given address */ @Override public double get(TensorAddress address) { - int cellIndex = index.indexOf(address); - Cell cell = cells.get(cellIndex); + long cellIndex = index.indexOf(address); + Cell cell = cells.get((int)cellIndex); if (!address.equals(cell.getKey())) { throw new IllegalStateException("Unable to find correct cell by direct index."); } @@ -113,11 +113,11 @@ public class MixedTensor implements Tensor { } /** Returns the size of dense subspaces */ - public int denseSubspaceSize() { + public long denseSubspaceSize() { return index.denseSubspaceSize(); } - + /** * Base class for building mixed tensors. */ @@ -148,7 +148,7 @@ public class MixedTensor implements Tensor { } @Override - public Tensor.Builder cell(double value, int... labels) { + public Tensor.Builder cell(double value, long... labels) { throw new UnsupportedOperationException("Not implemented."); } @@ -179,13 +179,13 @@ public class MixedTensor implements Tensor { index = indexBuilder.index(); } - public int denseSubspaceSize() { + public long denseSubspaceSize() { return index.denseSubspaceSize(); } private double[] denseSubspace(TensorAddress sparsePartial) { if (!denseSubspaceMap.containsKey(sparsePartial)) { - denseSubspaceMap.put(sparsePartial, new double[denseSubspaceSize()]); + denseSubspaceMap.put(sparsePartial, new double[(int)denseSubspaceSize()]); } return denseSubspaceMap.get(sparsePartial); } @@ -193,21 +193,21 @@ public class MixedTensor implements Tensor { @Override public Tensor.Builder cell(TensorAddress address, double value) { TensorAddress sparsePart = index.sparsePartialAddress(address); - int denseOffset = index.denseOffset(address); + long denseOffset = index.denseOffset(address); double[] denseSubspace = denseSubspace(sparsePart); - denseSubspace[denseOffset] = value; + denseSubspace[(int)denseOffset] = value; return this; } public Tensor.Builder block(TensorAddress sparsePart, double[] values) { double[] denseSubspace = denseSubspace(sparsePart); - System.arraycopy(values, 0, denseSubspace, 0, denseSubspaceSize()); + System.arraycopy(values, 0, denseSubspace, 0, (int)denseSubspaceSize()); return this; } @Override public MixedTensor build() { - int count = 0; + long count = 0; ImmutableList.Builder<Cell> builder = new ImmutableList.Builder<>(); for (Map.Entry<TensorAddress, double[]> entry : denseSubspaceMap.entrySet()) { @@ -215,9 +215,9 @@ public class MixedTensor implements Tensor { indexBuilder.put(sparsePart, count); double[] denseSubspace = entry.getValue(); - for (int offset = 0; offset < denseSubspace.length; ++offset) { + for (long offset = 0; offset < denseSubspace.length; ++offset) { TensorAddress cellAddress = index.addressOf(sparsePart, offset); - double value = denseSubspace[offset]; + double value = denseSubspace[(int)offset]; builder.add(new Cell(cellAddress, value)); count++; } @@ -239,12 +239,12 @@ public class MixedTensor implements Tensor { public static class UnboundBuilder extends Builder { private Map<TensorAddress, Double> cells; - private final int[] dimensionBounds; + private final long[] dimensionBounds; private UnboundBuilder(TensorType type) { super(type); cells = new HashMap<>(); - dimensionBounds = new int[type.dimensions().size()]; + dimensionBounds = new long[type.dimensions().size()]; } @Override @@ -268,7 +268,7 @@ public class MixedTensor implements Tensor { for (int i = 0; i < type.dimensions().size(); ++i) { TensorType.Dimension dimension = type.dimensions().get(i); if (dimension.isIndexed()) { - dimensionBounds[i] = Math.max(address.intLabel(i), dimensionBounds[i]); + dimensionBounds[i] = Math.max(address.numericLabel(i), dimensionBounds[i]); } } } @@ -280,13 +280,13 @@ public class MixedTensor implements Tensor { if (!dimension.isIndexed()) { typeBuilder.mapped(dimension.name()); } else { - int size = dimension.size().orElse(dimensionBounds[i] + 1); + long size = dimension.size().orElse(dimensionBounds[i] + 1); typeBuilder.indexed(dimension.name(), size); } } return typeBuilder.build(); } - + } /** @@ -303,8 +303,8 @@ public class MixedTensor implements Tensor { private final List<TensorType.Dimension> mappedDimensions; private final List<TensorType.Dimension> indexedDimensions; - private ImmutableMap<TensorAddress, Integer> sparseMap; - private int denseSubspaceSize = -1; + private ImmutableMap<TensorAddress, Long> sparseMap; + private long denseSubspaceSize = -1; private Index(TensorType type) { this.type = type; @@ -314,26 +314,27 @@ public class MixedTensor implements Tensor { this.denseType = createPartialType(indexedDimensions); } - public int indexOf(TensorAddress address) { + public long indexOf(TensorAddress address) { TensorAddress sparsePart = sparsePartialAddress(address); - if (!sparseMap.containsKey(sparsePart)) { + if ( ! sparseMap.containsKey(sparsePart)) { throw new IllegalArgumentException("Address not found"); } - int base = sparseMap.get(sparsePart); - int offset = denseOffset(address); + long base = sparseMap.get(sparsePart); + long offset = denseOffset(address); return base + offset; } public static class Builder { + private final Index index; - private final ImmutableMap.Builder<TensorAddress, Integer> builder; + private final ImmutableMap.Builder<TensorAddress, Long> builder; public Builder(TensorType type) { index = new Index(type); builder = new ImmutableMap.Builder<>(); } - public void put(TensorAddress address, int index) { + public void put(TensorAddress address, long index) { builder.put(address, index); } @@ -347,7 +348,7 @@ public class MixedTensor implements Tensor { } } - public int denseSubspaceSize() { + public long denseSubspaceSize() { if (denseSubspaceSize == -1) { denseSubspaceSize = 1; for (int i = 0; i < type.dimensions().size(); ++i) { @@ -360,7 +361,7 @@ public class MixedTensor implements Tensor { } return denseSubspaceSize; } - + private TensorAddress sparsePartialAddress(TensorAddress address) { if (type.dimensions().size() != address.size()) { throw new IllegalArgumentException("Tensor type and address are not of same size."); @@ -375,13 +376,13 @@ public class MixedTensor implements Tensor { return builder.build(); } - private int denseOffset(TensorAddress address) { - int innerSize = 1; - int offset = 0; + private long denseOffset(TensorAddress address) { + long innerSize = 1; + long offset = 0; for (int i = type.dimensions().size(); --i >= 0; ) { TensorType.Dimension dimension = type.dimensions().get(i); if (dimension.isIndexed()) { - int label = address.intLabel(i); + long label = address.numericLabel(i); offset += label * innerSize; innerSize *= dimension.size().orElseThrow(() -> new IllegalArgumentException("Unknown size of indexed dimension.")); @@ -390,18 +391,18 @@ public class MixedTensor implements Tensor { return offset; } - private TensorAddress denseOffsetToAddress(int denseOffset) { + private TensorAddress denseOffsetToAddress(long denseOffset) { if (denseOffset < 0 || denseOffset > denseSubspaceSize) { throw new IllegalArgumentException("Offset out of bounds"); } - int restSize = denseOffset; - int innerSize = denseSubspaceSize; - int[] labels = new int[indexedDimensions.size()]; + long restSize = denseOffset; + long innerSize = denseSubspaceSize; + long[] labels = new long[indexedDimensions.size()]; for (int i = 0; i < labels.length; ++i) { TensorType.Dimension dimension = indexedDimensions.get(i); - int dimensionSize = dimension.size().orElseThrow(() -> + long dimensionSize = dimension.size().orElseThrow(() -> new IllegalArgumentException("Unknown size of indexed dimension.")); innerSize /= dimensionSize; @@ -411,7 +412,7 @@ public class MixedTensor implements Tensor { return TensorAddress.of(labels); } - private TensorAddress addressOf(TensorAddress sparsePart, int denseOffset) { + private TensorAddress addressOf(TensorAddress sparsePart, long denseOffset) { TensorAddress densePart = denseOffsetToAddress(denseOffset); String[] labels = new String[type.dimensions().size()]; int mappedIndex = 0; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java index e3398850373..23ef0772aea 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java @@ -6,11 +6,11 @@ import com.google.common.annotations.Beta; /** * An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors * dimensions. - * + * * @author bratseth */ -// Implementation notes: -// - These are created in inner (though not inner-most) loops so they are implemented with minimal allocation. +// Implementation notes: +// - These are created in inner (though not inner-most) loops so they are implemented with minimal allocation. // We also avoid non-essential error checking. // - We can add support for string labels later without breaking the API @Beta @@ -19,7 +19,7 @@ public class PartialAddress { // Two arrays which contains corresponding dimension=label pairs. // The sizes of these are always equal. private final String[] dimensionNames; - private final int[] labels; + private final long[] labels; private PartialAddress(Builder builder) { this.dimensionNames = builder.dimensionNames; @@ -27,36 +27,36 @@ public class PartialAddress { builder.dimensionNames = null; // invalidate builder to safely take over array ownership builder.labels = null; } - + /** Returns the int label of this dimension, or -1 if no label is specified for it */ - int intLabel(String dimensionName) { + long numericLabel(String dimensionName) { for (int i = 0; i < dimensionNames.length; i++) if (dimensionNames[i].equals(dimensionName)) return labels[i]; return -1; } - + public static class Builder { private String[] dimensionNames; - private int[] labels; + private long[] labels; private int index = 0; - + public Builder(int size) { dimensionNames = new String[size]; - labels = new int[size]; + labels = new long[size]; } - - public void add(String dimensionName, int label) { + + public void add(String dimensionName, long label) { dimensionNames[index] = dimensionName; labels[index] = label; index++; } - + public PartialAddress build() { return new PartialAddress(this); } - + } - + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 2ed211539d8..0c948f1fbee 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -52,14 +52,14 @@ import java.util.function.Function; public interface Tensor { // ----------------- Accessors - + TensorType type(); /** Returns whether this have any cells */ default boolean isEmpty() { return size() == 0; } /** Returns the number of cells in this */ - int size(); + long size(); /** Returns the value of a cell, or NaN if this cell does not exist/have no value */ double get(TensorAddress address); @@ -70,13 +70,13 @@ public interface Tensor { /** Returns the values of this in some undefined order */ Iterator<Double> valueIterator(); - /** + /** * Returns an immutable map of the cells of this in no particular order. - * This may be expensive for some implementations - avoid when possible + * This may be expensive for some implementations - avoid when possible */ Map<TensorAddress, Double> cells(); - /** + /** * Returns the value of this as a double if it has no dimensions and one value * * @throws IllegalStateException if this does not have zero dimensions and one value @@ -87,9 +87,9 @@ public interface Tensor { if (size() == 0) return Double.NaN; return valueIterator().next(); } - + // ----------------- Primitive tensor functions - + default Tensor map(DoubleUnaryOperator mapper) { return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate(); } @@ -108,7 +108,7 @@ public interface Tensor { } default Tensor rename(String fromDimension, String toDimension) { - return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension), + return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension), Collections.singletonList(toDimension)).evaluate(); } @@ -123,13 +123,13 @@ public interface Tensor { default Tensor rename(List<String> fromDimensions, List<String> toDimensions) { return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate(); } - - static Tensor generate(TensorType type, Function<List<Integer>, Double> valueSupplier) { + + static Tensor generate(TensorType type, Function<List<Long>, Double> valueSupplier) { return new Generate(type, valueSupplier).evaluate(); } - + // ----------------- Composite tensor functions which have a defined primitive mapping - + default Tensor l1Normalize(String dimension) { return new L1Normalize(new ConstantTensor(this), dimension).evaluate(); } @@ -231,7 +231,7 @@ public interface Tensor { if (cellEntries.isEmpty()) return "{}"; return "{" + cellEntries.get(0).getValue() +"}"; } - + Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey()); StringBuilder b = new StringBuilder("{"); @@ -253,7 +253,7 @@ public interface Tensor { */ boolean equals(Object o); - /** + /** * Implement here to make this work across implementations. * Implementations must override equals and call this because this is an interface and cannot override equals. */ @@ -328,13 +328,13 @@ public interface Tensor { @Override public TensorAddress getKey() { return address; } - /** + /** * Returns the direct index which can be used to locate this cell, or -1 if not available. * This is for optimizations mapping between tensors where this is possible without creating a * TensorAddress. */ - int getDirectIndex() { return -1; } - + long getDirectIndex() { return -1; } + @Override public Double getValue() { return value; } @@ -388,20 +388,20 @@ public interface Tensor { /** Returns the type this is building */ TensorType type(); - + /** Return a cell builder */ CellBuilder cell(); /** Add a cell */ Builder cell(TensorAddress address, double value); - + /** Add a cell */ - Builder cell(double value, int ... labels); + Builder cell(double value, long ... labels); - /** - * Add a cell - * - * @param cell a cell providing the location at which to add this cell + /** + * Add a cell + * + * @param cell a cell providing the location at which to add this cell * @param value the value to assign to the cell */ default Builder cell(Cell cell, double value) { @@ -409,12 +409,12 @@ public interface Tensor { } Tensor build(); - + class CellBuilder { private final TensorAddress.Builder addressBuilder; private final Tensor.Builder tensorBuilder; - + CellBuilder(TensorType type, Tensor.Builder tensorBuilder) { addressBuilder = new TensorAddress.Builder(type); this.tensorBuilder = tensorBuilder; @@ -425,7 +425,7 @@ public interface Tensor { return this; } - public CellBuilder label(String dimension, int label) { + public CellBuilder label(String dimension, long label) { return label(dimension, String.valueOf(label)); } @@ -436,5 +436,5 @@ public interface Tensor { } } - + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 7161450d5d5..38553497478 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -2,16 +2,10 @@ package com.yahoo.tensor; import com.google.common.annotations.Beta; -import com.google.common.collect.ImmutableList; -import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; import java.util.Objects; import java.util.Optional; -import java.util.Set; /** * An immutable address to a tensor cell. This simply supplies a value to each dimension @@ -26,29 +20,29 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return new StringTensorAddress(labels); } - public static TensorAddress of(int ... labels) { - return new IntTensorAddress(labels); + public static TensorAddress of(long ... labels) { + return new NumericTensorAddress(labels); } /** Returns the number of labels in this */ public abstract int size(); - + /** - * Returns the i'th label in this - * + * Returns the i'th label in this + * * @throws IllegalArgumentException if there is no label at this index */ public abstract String label(int i); /** - * Returns the i'th label in this as an int. - * Prefer this if you know that this is an integer address, but not otherwise. + * Returns the i'th label in this as a long. + * Prefer this if you know that this is a numeric address, but not otherwise. * * @throws IllegalArgumentException if there is no label at this index */ - public abstract int intLabel(int i); + public abstract long numericLabel(int i); - public abstract TensorAddress withLabel(int labelIndex, int label); + public abstract TensorAddress withLabel(int labelIndex, long label); public final boolean isEmpty() { return size() == 0; } @@ -102,25 +96,25 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { private StringTensorAddress(String ... labels) { this.labels = Arrays.copyOf(labels, labels.length); } - + @Override public int size() { return labels.length; } - + @Override public String label(int i) { return labels[i]; } - + @Override - public int intLabel(int i) { + public long numericLabel(int i) { try { - return Integer.parseInt(labels[i]); - } + return Long.parseLong(labels[i]); + } catch (NumberFormatException e) { - throw new IllegalArgumentException("Expected an int label in " + this + " at position " + i); + throw new IllegalArgumentException("Expected a long label in " + this + " at position " + i); } } - + @Override - public TensorAddress withLabel(int index, int label) { + public TensorAddress withLabel(int index, long label) { String[] labels = Arrays.copyOf(this.labels, this.labels.length); labels[index] = String.valueOf(label); return new StringTensorAddress(labels); @@ -133,11 +127,11 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { } - private static final class IntTensorAddress extends TensorAddress { + private static final class NumericTensorAddress extends TensorAddress { - private final int[] labels; + private final long[] labels; - private IntTensorAddress(int[] labels) { + private NumericTensorAddress(long[] labels) { this.labels = Arrays.copyOf(labels, labels.length); } @@ -148,13 +142,13 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { public String label(int i) { return String.valueOf(labels[i]); } @Override - public int intLabel(int i) { return labels[i]; } + public long numericLabel(int i) { return labels[i]; } @Override - public TensorAddress withLabel(int index, int label) { - int[] labels = Arrays.copyOf(this.labels, this.labels.length); + public TensorAddress withLabel(int index, long label) { + long[] labels = Arrays.copyOf(this.labels, this.labels.length); labels[index] = label; - return new IntTensorAddress(labels); + return new NumericTensorAddress(labels); } @Override @@ -169,7 +163,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { private final TensorType type; private final String[] labels; - + public Builder(TensorType type) { this(type, new String[type.dimensions().size()]); } @@ -193,7 +187,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { labels[labelIndex.get()] = label; return this; } - + /** Creates a copy of this which can be modified separately */ public Builder copy() { return new Builder(type, Arrays.copyOf(labels, labels.length)); @@ -202,7 +196,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { public TensorAddress build() { for (int i = 0; i < labels.length; i++) if (labels[i] == null) - throw new IllegalArgumentException("Missing a value for dimension " + + throw new IllegalArgumentException("Missing a value for dimension " + type.dimensions().get(i).name() + " for " + type); return TensorAddress.of(labels); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index da8ab3bb0ec..9b3a9328f07 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -96,7 +96,7 @@ class TensorParser { if (valueEnd < 0) throw new IllegalArgumentException("A tensor string must end by '}'"); } - + TensorAddress address = addressBuilder.build(); Double value = asDouble(address, s.substring(0, valueEnd).trim()); builder.cell(address, value); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index c05c35d6df3..b396f831de0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -53,14 +53,17 @@ public class TensorType { return TensorTypeParser.fromSpec(specString); } + /** Returns the number of dimensions of this: dimensions().size() */ + public int rank() { return dimensions.size(); } + /** Returns an immutable list of the dimensions of this */ public List<Dimension> dimensions() { return dimensions; } - + /** Returns an immutable set of the names of the dimensions of this */ public Set<String> dimensionNames() { return dimensions.stream().map(Dimension::name).collect(Collectors.toSet()); } - + /** Returns the dimension with this name, or empty if not present */ public Optional<Dimension> dimension(String name) { return indexOfDimension(name).map(i -> dimensions.get(i)); @@ -74,7 +77,7 @@ public class TensorType { return Optional.empty(); } - /** + /** * Returns whether this type can be assigned to the given type, * i.e if the given type is a generalization of this type. */ @@ -128,15 +131,15 @@ public class TensorType { private final String name; - private Dimension(String name) { + private Dimension(String name) { Objects.requireNonNull(name, "A tensor name cannot be null"); - this.name = name; + this.name = name; } public final String name() { return name; } /** Returns the size of this dimension if it is bound, empty otherwise */ - public abstract Optional<Integer> size(); + public abstract Optional<Long> size(); public abstract Type type(); @@ -146,7 +149,7 @@ public class TensorType { /** Returns true if this is an indexed bound or unboun type */ public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; } - /** + /** * Returns the dimension resulting from combining two dimensions having the same name but possibly different * types. This works by degrading to the type making the fewer promises. * [N] + [M] = [min(N, M)] @@ -165,7 +168,7 @@ public class TensorType { IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get(); return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb; } - + @Override public abstract String toString(); @@ -175,36 +178,38 @@ public class TensorType { if (other == null || getClass() != other.getClass()) return false; return name.equals(((Dimension)other).name); } - + @Override public int hashCode() { return name.hashCode(); } - + @Override public int compareTo(Dimension other) { return this.name.compareTo(other.name); } - - public static Dimension indexed(String name, int size) { + + public static Dimension indexed(String name, long size) { return new IndexedBoundDimension(name, size); } - + } public static class IndexedBoundDimension extends TensorType.Dimension { - private final Integer size; + private final Long size; - private IndexedBoundDimension(String name, int size) { + private IndexedBoundDimension(String name, long size) { super(name); if (size < 1) throw new IllegalArgumentException("Size of bound dimension '" + name + "' must be at least 1"); + if (size > Integer.MAX_VALUE) + throw new IllegalArgumentException("Size of bound dimension '" + name + "' cannot be larger than " + Integer.MAX_VALUE); this.size = size; } @Override - public Optional<Integer> size() { return Optional.of(size); } + public Optional<Long> size() { return Optional.of(size); } @Override public Type type() { return Type.indexedBound; } @@ -245,7 +250,7 @@ public class TensorType { } @Override - public Optional<Integer> size() { return Optional.empty(); } + public Optional<Long> size() { return Optional.empty(); } @Override public Type type() { return Type.indexedUnbound; } @@ -266,7 +271,7 @@ public class TensorType { } @Override - public Optional<Integer> size() { return Optional.empty(); } + public Optional<Long> size() { return Optional.empty(); } @Override public Type type() { return Type.mapped; } @@ -289,9 +294,9 @@ public class TensorType { public Builder() { } - /** - * Creates a builder containing a combination of the dimensions of the given types - * + /** + * Creates a builder containing a combination of the dimensions of the given types + * * If the same dimension is indexed with different size restrictions the largest size will be used. * If it is size restricted in one argument but not the other it will not be size restricted. * If it is indexed in one and mapped in the other it will become mapped. @@ -325,9 +330,12 @@ public class TensorType { } } - /** + /** Returns the current number of dimensions in this */ + public int rank() { return dimensions.size(); } + + /** * Adds a new dimension to this - * + * * @throws IllegalArgumentException if the dimension is already present */ private Builder add(Dimension dimension) { @@ -346,16 +354,16 @@ public class TensorType { return this; } - /** + /** * Adds a bound indexed dimension to this * * @throws IllegalArgumentException if the dimension is already present */ - public Builder indexed(String name, int size) { return add(new IndexedBoundDimension(name, size)); } + public Builder indexed(String name, long size) { return add(new IndexedBoundDimension(name, size)); } /** * Adds an unbound indexed dimension to this - * + * * @throws IllegalArgumentException if the dimension is already present */ public Builder indexed(String name) { @@ -375,7 +383,7 @@ public class TensorType { public Builder dimension(Dimension dimension) { return add(dimension); } - + /** Returns the given dimension, or empty if none is present */ public Optional<Dimension> getDimension(String dimension) { return Optional.ofNullable(dimensions.get(dimension)); @@ -393,7 +401,7 @@ public class TensorType { public TensorType build() { return new TensorType(dimensions.values()); } - + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java index 84caca78fb2..3db661f8a23 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java @@ -2,16 +2,17 @@ package com.yahoo.tensor.evaluation; import com.google.common.annotations.Beta; - -import java.util.HashMap; +import com.yahoo.tensor.Tensor; /** * An evaluation context which is passed down to all nested functions during evaluation. - * The default context is empty to allow various evaluation frameworks to support their own implementation. - * + * * @author bratseth */ @Beta public interface EvaluationContext { + /** Returns the tensor bound to this name, or null if none */ + Tensor getTensor(String name); + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index cf704c15f4f..db8a66a5fa2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -18,7 +18,7 @@ public class MapEvaluationContext implements EvaluationContext { public void put(String name, Tensor tensor) { bindings.put(name, tensor); } - /** Returns the tensor bound to this name, or null if none */ - public Tensor get(String name) { return bindings.get(name); } + @Override + public Tensor getTensor(String name) { return bindings.get(name); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index 8ade181bdb7..1f6ad050368 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -12,18 +12,18 @@ import java.util.List; /** * A tensor variable name which resolves to a tensor in the context at evaluation time - * + * * @author bratseth */ @Beta public class VariableTensor extends PrimitiveTensorFunction { private final String name; - + public VariableTensor(String name) { this.name = name; } - + @Override public List<TensorFunction> functionArguments() { return Collections.emptyList(); } @@ -35,7 +35,7 @@ public class VariableTensor extends PrimitiveTensorFunction { @Override public Tensor evaluate(EvaluationContext context) { - return ((MapEvaluationContext)context).get(name); + return context.getTensor(name); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 8f4dbf014a7..191c7988443 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -8,7 +8,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext; /** * A composite tensor function is a tensor function which can be expressed (less tersely) * as a tree of primitive tensor functions. - * + * * @author bratseth */ @Beta diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 1dbb94fdb20..d4affe0ef9b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -15,7 +15,7 @@ import java.util.stream.Collectors; /** * Concatenation of two tensors along an (indexed) dimension - * + * * @author bratseth */ @Beta @@ -67,15 +67,15 @@ public class Concat extends PrimitiveTensorFunction { DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); - int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new); + long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new); int[] aToIndexes = mapIndexes(a.type(), concatType); int[] bToIndexes = mapIndexes(b.type(), concatType); concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder); concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder); return builder.build(); } - - private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType, + + private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType, int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) { Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet()); for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) { @@ -112,7 +112,7 @@ public class Concat extends PrimitiveTensorFunction { Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build(); return tensor.multiply(unitTensor); } - + } /** Returns the type resulting from concatenating a and b */ @@ -129,8 +129,8 @@ public class Concat extends PrimitiveTensorFunction { DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size()); for (int i = 0; i < concatSizes.dimensions(); i++) { String currentDimension = concatType.dimensions().get(i).name(); - int aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0); - int bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0); + long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L); + long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L); if (currentDimension.equals(concatDimension)) concatSizes.set(i, aSize + bSize); else if (aSize != 0 && bSize != 0 && aSize!=bSize ) @@ -144,12 +144,12 @@ public class Concat extends PrimitiveTensorFunction { /** * Combine two addresses, adding the offset to the concat dimension * - * @return the combined address or null if the addresses are incompatible + * @return the combined address or null if the addresses are incompatible * (in some other dimension than the concat dimension) */ private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, - TensorType concatType, int concatOffset, String concatDimension) { - int[] combinedLabels = new int[concatType.dimensions().size()]; + TensorType concatType, long concatOffset, String concatDimension) { + long[] combinedLabels = new long[concatType.dimensions().size()]; Arrays.fill(combinedLabels, -1); int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get(); mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension @@ -161,7 +161,7 @@ public class Concat extends PrimitiveTensorFunction { /** * Returns the an array having one entry in order for each dimension of fromType * containing the index at which toType contains the same dimension name. - * That is, if the returned array contains n at index i then + * That is, if the returned array contains n at index i then * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) * If some dimension in fromType is not present in toType, the corresponding index will be -1 */ @@ -179,15 +179,15 @@ public class Concat extends PrimitiveTensorFunction { * @return true if the mapping was successful, false if one of the destination positions was * occupied by a different value */ - private boolean mapContent(TensorAddress from, int[] to, int[] indexMap, int concatDimension, int concatOffset) { + private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) { for (int i = 0; i < from.size(); i++) { int toIndex = indexMap[i]; if (concatDimension == toIndex) { - to[toIndex] = from.intLabel(i) + concatOffset; + to[toIndex] = from.numericLabel(i) + concatOffset; } else { - if (to[toIndex] != -1 && to[toIndex] != from.intLabel(i)) return false; - to[toIndex] = from.intLabel(i); + if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false; + to[toIndex] = from.numericLabel(i); } } return true; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 4ac7b21ba90..14ed38718ce 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -10,18 +10,18 @@ import java.util.List; /** * A function which returns a constant tensor. - * + * * @author bratseth */ @Beta public class ConstantTensor extends PrimitiveTensorFunction { private final Tensor constant; - + public ConstantTensor(String tensorString) { this.constant = Tensor.from(tensorString); } - + public ConstantTensor(Tensor tensor) { this.constant = tensor; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java index bbdbd5c3df1..653be8dacf0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -11,19 +11,19 @@ import java.util.stream.Stream; /** * A tensor generator which returns a tensor of any dimension filled with 1 in the diagonal and 0 elsewhere. - * + * * @author bratseth */ public class Diag extends CompositeTensorFunction { private final TensorType type; - private final Function<List<Integer>, Double> diagFunction; - + private final Function<List<Long>, Double> diagFunction; + public Diag(TensorType type) { this.type = type; this.diagFunction = ScalarFunctions.equal(dimensionNames().collect(Collectors.toList())); } - + @Override public List<TensorFunction> functionArguments() { return Collections.emptyList(); } @@ -43,7 +43,7 @@ public class Diag extends CompositeTensorFunction { public String toString(ToStringContext context) { return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction; } - + private Stream<String> dimensionNames() { return type.dimensions().stream().map(TensorType.Dimension::name); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 6ea73b7f310..ef2770c04f5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -15,31 +15,31 @@ import java.util.function.Function; /** * An indexed tensor whose values are generated by a function - * + * * @author bratseth */ @Beta public class Generate extends PrimitiveTensorFunction { private final TensorType type; - private final Function<List<Integer>, Double> generator; + private final Function<List<Long>, Double> generator; /** * Creates a generated tensor - * + * * @param type the type of the tensor - * @param generator the function generating values from a list of ints specifying the indexes of the + * @param generator the function generating values from a list of numbers specifying the indexes of the * tensor cell which will receive the value * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ - public Generate(TensorType type, Function<List<Integer>, Double> generator) { + public Generate(TensorType type, Function<List<Long>, Double> generator) { Objects.requireNonNull(type, "The argument tensor type cannot be null"); Objects.requireNonNull(generator, "The argument function cannot be null"); validateType(type); this.type = type; this.generator = generator; } - + private void validateType(TensorType type) { for (TensorType.Dimension dimension : type.dimensions()) if (dimension.type() != TensorType.Dimension.Type.indexedBound) @@ -58,7 +58,7 @@ public class Generate extends PrimitiveTensorFunction { @Override public PrimitiveTensorFunction toPrimitive() { return this; } - + @Override public Tensor evaluate(EvaluationContext context) { Tensor.Builder builder = Tensor.Builder.of(type); @@ -69,14 +69,14 @@ public class Generate extends PrimitiveTensorFunction { } return builder.build(); } - + private DimensionSizes dimensionSizes(TensorType type) { DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); for (int i = 0; i < b.dimensions(); i++) b.set(i, type.dimensions().get(i).size().get()); return b.build(); } - + @Override public String toString(ToStringContext context) { return type + "(" + generator + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 8c4dbfb0acb..174a8e4c435 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -28,12 +28,12 @@ import java.util.function.DoubleBinaryOperator; * The <i>join</i> tensor operation produces a tensor from the argument tensors containing the set of cells * given by the cross product of the cells of the given tensors, having as values the value produced by * applying the given combinator function on the values from the two source cells. - * + * * @author bratseth */ @Beta public class Join extends PrimitiveTensorFunction { - + private final TensorFunction argumentA, argumentB; private final DoubleBinaryOperator combinator; @@ -46,6 +46,30 @@ public class Join extends PrimitiveTensorFunction { this.combinator = combinator; } + /** Returns the type resulting from applying Join to the two given types */ + public static TensorType outputType(TensorType a, TensorType b) { + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (int i = 0; i < a.dimensions().size(); ++i) { + TensorType.Dimension aDim = a.dimensions().get(i); + for (int j = 0; j < b.dimensions().size(); ++j) { + TensorType.Dimension bDim = b.dimensions().get(j); + if (aDim.name().equals(bDim.name())) { // include + if (aDim.isIndexed() && bDim.isIndexed()) { + if (aDim.size().isPresent() || bDim.size().isPresent()) + typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Long.MAX_VALUE), + bDim.size().orElse(Long.MAX_VALUE))); + else + typeBuilder.indexed(aDim.name()); + } + else { + typeBuilder.mapped(aDim.name()); + } + } + } + } + return typeBuilder.build(); + } + public TensorFunction argumentA() { return argumentA; } public TensorFunction argumentB() { return argumentB; } public DoubleBinaryOperator combinator() { return combinator; } @@ -88,17 +112,17 @@ public class Join extends PrimitiveTensorFunction { else return generalJoin(a, b, joinedType); } - + private boolean hasSingleIndexedDimension(Tensor tensor) { return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed(); } - + private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) { - int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); + long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); Iterator<Double> aIterator = a.valueIterator(); Iterator<Double> bIterator = b.valueIterator(); - IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedLength).build()); - for (int i = 0; i < joinedLength; i++) + IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build()); + for (int i = 0; i < joinedRank; i++) builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i); return builder.build(); } @@ -114,7 +138,7 @@ public class Join extends PrimitiveTensorFunction { } return builder.build(); } - + /** Join a tensor into a superspace */ private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor) @@ -126,7 +150,7 @@ public class Join extends PrimitiveTensorFunction { private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build(); - + DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace); IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes); @@ -134,21 +158,21 @@ public class Join extends PrimitiveTensorFunction { // Find dimensions which are only in the supertype Set<String> superDimensionNames = new HashSet<>(superspace.type().dimensionNames()); superDimensionNames.removeAll(subspace.type().dimensionNames()); - + for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) { IndexedTensor.SubspaceIterator subspaceInSuper = i.next(); joinSubspaces(subspace.valueIterator(), subspace.size(), subspaceInSuper, subspaceInSuper.size(), reversedArgumentOrder, builder); } - + return builder.build(); } - private void joinSubspaces(Iterator<Double> subspace, int subspaceSize, - Iterator<Tensor.Cell> superspace, int superspaceSize, + private void joinSubspaces(Iterator<Double> subspace, long subspaceSize, + Iterator<Tensor.Cell> superspace, long superspaceSize, boolean reversedArgumentOrder, IndexedTensor.Builder builder) { - int joinedLength = Math.min(subspaceSize, superspaceSize); + long joinedLength = Math.min(subspaceSize, superspaceSize); if (reversedArgumentOrder) { for (int i = 0; i < joinedLength; i++) { Tensor.Cell supercell = superspace.next(); @@ -200,7 +224,7 @@ public class Join extends PrimitiveTensorFunction { subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get(); return subspaceIndexes; } - + private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) { String[] subspaceLabels = new String[subspaceIndexes.length]; for (int i = 0; i < subspaceIndexes.length; i++) @@ -235,7 +259,7 @@ public class Join extends PrimitiveTensorFunction { DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize); // for each combination of dimensions only in a - for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) { + for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) { IndexedTensor.SubspaceIterator aSubspace = ia.next(); // for each combination of dimensions in a which is also in b while (aSubspace.hasNext()) { @@ -252,15 +276,15 @@ public class Join extends PrimitiveTensorFunction { } } } - + private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) { PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size()); for (int i = 0; i < addressType.dimensions().size(); i++) if (retainDimensions.contains(addressType.dimensions().get(i).name())) - builder.add(addressType.dimensions().get(i).name(), address.intLabel(i)); + builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i)); return builder.build(); } - + /** Returns the sizes from the joined sizes which are present in the type argument */ private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) { DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size()); @@ -271,7 +295,7 @@ public class Join extends PrimitiveTensorFunction { } return builder.build(); } - + private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) { int[] aToIndexes = mapIndexes(a.type(), joinedType); int[] bToIndexes = mapIndexes(b.type(), joinedType); @@ -340,7 +364,7 @@ public class Join extends PrimitiveTensorFunction { /** * Returns the an array having one entry in order for each dimension of fromType * containing the index at which toType contains the same dimension name. - * That is, if the returned array contains n at index i then + * That is, if the returned array contains n at index i then * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) * If some dimension in fromType is not present in toType, the corresponding index will be -1 */ @@ -360,7 +384,7 @@ public class Join extends PrimitiveTensorFunction { return TensorAddress.of(joinedLabels); } - /** + /** * Maps the content in the given list to the given array, using the given index map. * * @return true if the mapping was successful, false if one of the destination positions was diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index a9872bb42d8..a5e1a016a41 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -6,6 +6,7 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Collections; @@ -32,6 +33,8 @@ public class Map extends PrimitiveTensorFunction { this.mapper = mapper; } + public static TensorType outputType(TensorType inputType) { return inputType; } + public TensorFunction argument() { return argument; } public DoubleUnaryOperator mapper() { return mapper; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index bb27e937699..4071917c2b5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.TensorType; import java.util.List; @@ -14,13 +15,17 @@ public class Matmul extends CompositeTensorFunction { private final TensorFunction argument1, argument2; private final String dimension; - + public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) { this.argument1 = argument1; this.argument2 = argument2; this.dimension = dimension; } + public static TensorType outputType(TensorType a, TensorType b, String dimension) { + return Join.outputType(a, b); + } + @Override public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); } @@ -39,7 +44,7 @@ public class Matmul extends CompositeTensorFunction { Reduce.Aggregator.sum, dimension); } - + @Override public String toString(ToStringContext context) { return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java index efb7b9e500c..b7c9a5d2342 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java @@ -8,10 +8,10 @@ import com.yahoo.tensor.Tensor; * A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions. * All tensor implementations must implement all primitive tensor functions. * Primitive tensor functions are fully inspectable. - * + * * @author bratseth */ @Beta public abstract class PrimitiveTensorFunction extends TensorFunction { - + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java index 457763e97ba..958ef85d1dc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java @@ -22,11 +22,11 @@ import java.util.stream.Stream; public class Random extends CompositeTensorFunction { private final TensorType type; - + public Random(TensorType type) { this.type = type; } - + @Override public List<TensorFunction> functionArguments() { return Collections.emptyList(); } @@ -46,7 +46,7 @@ public class Random extends CompositeTensorFunction { public String toString(ToStringContext context) { return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")"; } - + private Stream<String> dimensionNames() { return type.dimensions().stream().map(TensorType.Dimension::toString); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java index e2b39a2048d..8e7f4e4c773 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -12,19 +12,19 @@ import java.util.stream.Stream; /** * A tensor generator which returns a tensor of any dimension filled with the sum of the tensor * indexes of each position. - * + * * @author bratseth */ public class Range extends CompositeTensorFunction { private final TensorType type; - private final Function<List<Integer>, Double> rangeFunction; - + private final Function<List<Long>, Double> rangeFunction; + public Range(TensorType type) { this.type = type; this.rangeFunction = ScalarFunctions.sum(dimensionNames().collect(Collectors.toList())); } - + @Override public List<TensorFunction> functionArguments() { return Collections.emptyList(); } @@ -44,7 +44,7 @@ public class Range extends CompositeTensorFunction { public String toString(ToStringContext context) { return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction; } - + private Stream<String> dimensionNames() { return type.dimensions().stream().map(TensorType.Dimension::toString); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index cfc78be7e0c..de9f90a5804 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -19,7 +19,7 @@ import java.util.Objects; import java.util.Set; /** - * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions + * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions * are collapsed to a single value using an aggregator function. * * @author bratseth @@ -45,7 +45,7 @@ public class Reduce extends PrimitiveTensorFunction { /** * Creates a reduce function. - * + * * @param argument the tensor to reduce * @param aggregator the aggregator function to use * @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced, @@ -61,6 +61,15 @@ public class Reduce extends PrimitiveTensorFunction { this.dimensions = ImmutableList.copyOf(dimensions); } + public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) { + TensorType.Builder b = new TensorType.Builder(); + for (TensorType.Dimension dimension : inputType.dimensions()) { + if ( ! reduceDimensions.contains(dimension.name())) + b.dimension(dimension); + } + return b.build(); + } + public TensorFunction argument() { return argument; } @Override @@ -82,7 +91,7 @@ public class Reduce extends PrimitiveTensorFunction { public String toString(ToStringContext context) { return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")"; } - + private String commaSeparated(List<String> list) { StringBuilder b = new StringBuilder(); for (String element : list) @@ -94,7 +103,7 @@ public class Reduce extends PrimitiveTensorFunction { public Tensor evaluate(EvaluationContext context) { Tensor argument = this.argument.evaluate(context); if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) - throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + + throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + dimensions + ": Not all those dimensions are present in this tensor"); // Special case: Reduce all @@ -103,14 +112,14 @@ public class Reduce extends PrimitiveTensorFunction { return reduceIndexedVector((IndexedTensor)argument); else return reduceAllGeneral(argument); - + // Reduce type TensorType.Builder builder = new TensorType.Builder(); for (TensorType.Dimension dimension : argument.type().dimensions()) if ( ! dimensions.contains(dimension.name())) // keep builder.dimension(dimension); TensorType reducedType = builder.build(); - + // Reduce cells Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { @@ -122,10 +131,10 @@ public class Reduce extends PrimitiveTensorFunction { Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType); for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet()) reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue()); - + return reducedBuilder.build(); } - + private TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType) { Set<Integer> indexesToRemove = new HashSet<>(); for (String dimensionToRemove : this.dimensions) @@ -138,7 +147,7 @@ public class Reduce extends PrimitiveTensorFunction { reducedLabels[reducedLabelIndex++] = address.label(i); return TensorAddress.of(reducedLabels); } - + private Tensor reduceAllGeneral(Tensor argument) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); ) @@ -154,7 +163,7 @@ public class Reduce extends PrimitiveTensorFunction { } private static abstract class ValueAggregator { - + private static ValueAggregator ofType(Aggregator aggregator) { switch (aggregator) { case avg : return new AvgAggregator(); @@ -165,22 +174,22 @@ public class Reduce extends PrimitiveTensorFunction { case min : return new MinAggregator(); default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented"); } - + } /** Add a new value to those aggregated by this */ public abstract void aggregate(double value); - + /** Returns the value aggregated by this */ public abstract double aggregatedValue(); - + } - + private static class AvgAggregator extends ValueAggregator { private int valueCount = 0; private double valueSum = 0.0; - + @Override public void aggregate(double value) { valueCount++; @@ -188,7 +197,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public double aggregatedValue() { + public double aggregatedValue() { return valueSum / valueCount; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 6b0daf1b49a..ec9b762a41c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -3,8 +3,6 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -19,7 +17,7 @@ import java.util.Objects; /** * The <i>rename</i> tensor function returns a tensor where some dimensions are assigned new names. - * + * * @author bratseth */ @Beta @@ -29,6 +27,10 @@ public class Rename extends PrimitiveTensorFunction { private final List<String> fromDimensions; private final List<String> toDimensions; + public Rename(TensorFunction argument, String fromDimension, String toDimension) { + this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension)); + } + public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null"); @@ -42,7 +44,7 @@ public class Rename extends PrimitiveTensorFunction { this.fromDimensions = ImmutableList.copyOf(fromDimensions); this.toDimensions = ImmutableList.copyOf(toDimensions); } - + @Override public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } @@ -62,7 +64,7 @@ public class Rename extends PrimitiveTensorFunction { Map<String, String> fromToMap = fromToMap(); TensorType renamedType = rename(tensor.type(), fromToMap); - + // an array which lists the index of each label in the renamed type int[] toIndexes = new int[tensor.type().dimensions().size()]; for (int i = 0; i < tensor.type().dimensions().size(); i++) { @@ -70,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction { String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName); toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get(); } - + Tensor.Builder builder = Tensor.Builder.of(renamedType); for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); @@ -86,7 +88,7 @@ public class Rename extends PrimitiveTensorFunction { builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); return builder.build(); } - + private TensorAddress rename(TensorAddress address, int[] toIndexes) { String[] reorderedLabels = new String[toIndexes.length]; for (int i = 0; i < toIndexes.length; i++) @@ -95,18 +97,18 @@ public class Rename extends PrimitiveTensorFunction { } @Override - public String toString(ToStringContext context) { - return "rename(" + argument.toString(context) + ", " + + public String toString(ToStringContext context) { + return "rename(" + argument.toString(context) + ", " + toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; } - + private Map<String, String> fromToMap() { Map<String, String> map = new HashMap<>(); for (int i = 0; i < fromDimensions.size(); i++) map.put(fromDimensions.get(i), toDimensions.get(i)); return map; } - + private String toVectorString(List<String> elements) { if (elements.size() == 1) return elements.get(0); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index 99f79cb735a..f1dadba2a29 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -14,127 +14,112 @@ import java.util.stream.Collectors; /** * Factory of scalar Java functions. * The purpose of this is to embellish anonymous functions with a runtime type - * such that they can be inspected and will return a parseable toString. - * + * such that they can be inspected and will return a parsable toString. + * * @author bratseth */ @Beta public class ScalarFunctions { - public static DoubleBinaryOperator add() { return new Addition(); } - public static DoubleBinaryOperator multiply() { return new Multiplication(); } - public static DoubleBinaryOperator divide() { return new Division(); } + public static DoubleBinaryOperator add() { return new Add(); } + public static DoubleBinaryOperator divide() { return new Divide(); } public static DoubleBinaryOperator equal() { return new Equal(); } - public static DoubleUnaryOperator square() { return new Square(); } + public static DoubleBinaryOperator multiply() { return new Multiply(); } + + public static DoubleUnaryOperator acos() { return new Acos(); } + public static DoubleUnaryOperator exp() { return new Exp(); } public static DoubleUnaryOperator sqrt() { return new Sqrt(); } - public static DoubleUnaryOperator exp() { return new Exponent(); } - public static Function<List<Integer>, Double> random() { return new Random(); } - public static Function<List<Integer>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); } - public static Function<List<Integer>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); } + public static DoubleUnaryOperator square() { return new Square(); } + + public static Function<List<Long>, Double> random() { return new Random(); } + public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); } + public static Function<List<Long>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); } - public static class Addition implements DoubleBinaryOperator { + // Binary operators ----------------------------------------------------------------------------- + public static class Add implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return left + right; } - @Override public String toString() { return "f(a,b)(a + b)"; } + } + public static class Equal implements DoubleBinaryOperator { + @Override + public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; } + @Override + public String toString() { return "f(a,b)(a==b)"; } } - public static class Multiplication implements DoubleBinaryOperator { + public static class Exp implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.exp(operand); } + @Override + public String toString() { return "f(a)(exp(a))"; } + } + public static class Multiply implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return left * right; } - @Override public String toString() { return "f(a,b)(a * b)"; } - } - public static class Division implements DoubleBinaryOperator { - + public static class Divide implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return left / right; } - @Override public String toString() { return "f(a,b)(a / b)"; } } - public static class Equal implements DoubleBinaryOperator { + // Unary operators ------------------------------------------------------------------------------ + public static class Acos implements DoubleUnaryOperator { @Override - public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; } - + public double applyAsDouble(double operand) { return Math.acos(operand); } @Override - public String toString() { return "f(a,b)(a==b)"; } - } - - public static class Square implements DoubleUnaryOperator { - - @Override - public double applyAsDouble(double operand) { return operand * operand; } - - @Override - public String toString() { return "f(a)(a * a)"; } - + public String toString() { return "f(a)(acos(a))"; } } public static class Sqrt implements DoubleUnaryOperator { - @Override public double applyAsDouble(double operand) { return Math.sqrt(operand); } - @Override public String toString() { return "f(a)(sqrt(a))"; } - } - public static class Exponent implements DoubleUnaryOperator { + public static class Square implements DoubleUnaryOperator { @Override - public double applyAsDouble(double operand) { return Math.exp(operand); } + public double applyAsDouble(double operand) { return operand * operand; } @Override - public String toString() { return "f(a)(exp(a))"; } + public String toString() { return "f(a)(a * a)"; } } - public static class Random implements Function<List<Integer>, Double> { - - @Override - public Double apply(List<Integer> values) { - return ThreadLocalRandom.current().nextDouble(); - } - - @Override - public String toString() { return "random"; } + // Variable-length operators ----------------------------------------------------------------------------- - } - - public static class EqualElements implements Function<List<Integer>, Double> { - + public static class EqualElements implements Function<List<Long>, Double> { private final ImmutableList<String> argumentNames; - private EqualElements(List<String> argumentNames) { this.argumentNames = ImmutableList.copyOf(argumentNames); } @Override - public Double apply(List<Integer> values) { + public Double apply(List<Long> values) { if (values.isEmpty()) return 1.0; - for (Integer value : values) + for (Long value : values) if ( ! value.equals(values.get(0))) return 0.0; return 1.0; } - @Override - public String toString() { + public String toString() { if (argumentNames.size() == 0) return "1"; if (argumentNames.size() == 1) return "1"; if (argumentNames.size() == 2) return argumentNames.get(0) + "==" + argumentNames.get(1); - + StringBuilder b = new StringBuilder(); for (int i = 0; i < argumentNames.size() -1; i++) { b.append("(").append(argumentNames.get(i)).append("==").append(argumentNames.get(i+1)).append(")"); @@ -143,30 +128,34 @@ public class ScalarFunctions { } return b.toString(); } - } - public static class SumElements implements Function<List<Integer>, Double> { + public static class Random implements Function<List<Long>, Double> { + @Override + public Double apply(List<Long> values) { + return ThreadLocalRandom.current().nextDouble(); + } + @Override + public String toString() { return "random"; } + } + public static class SumElements implements Function<List<Long>, Double> { private final ImmutableList<String> argumentNames; - private SumElements(List<String> argumentNames) { this.argumentNames = ImmutableList.copyOf(argumentNames); } @Override - public Double apply(List<Integer> values) { - int sum = 0; - for (Integer value : values) + public Double apply(List<Long> values) { + long sum = 0; + for (Long value : values) sum += value; return (double)sum; } - @Override public String toString() { return argumentNames.stream().collect(Collectors.joining("+")); } - } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index bf279eb24d8..c856b548180 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -2,6 +2,8 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.TensorType; import java.util.Collections; import java.util.List; @@ -19,6 +21,10 @@ public class Softmax extends CompositeTensorFunction { this.argument = argument; this.dimension = dimension; } + + public static TensorType outputType(TensorType inputType, String dimension) { + return Reduce.outputType(inputType, ImmutableList.of(dimension)); + } @Override public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index cabcce198d1..533a46f87fe 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -12,7 +12,7 @@ import java.util.List; * A representation of a tensor function which is able to be translated to a set of primitive * tensor functions if necessary. * All tensor functions are immutable. - * + * * @author bratseth */ @Beta @@ -48,11 +48,11 @@ public abstract class TensorFunction { /** * Return a string representation of this context. - * + * * @param context a context which must be passed to all nexted functions when requesting the string value */ public abstract String toString(ToStringContext context); - + @Override public String toString() { return toString(ToStringContext.empty()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java index e8c425d49e0..416b28afa22 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java @@ -24,7 +24,7 @@ interface BinaryFormat { /** * Deserialize the given binary data into a Tensor object. - * + * * @param type the expected abstract type of the tensor to serialize, or empty to use type information from the data * @param buffer the buffer containing the tensor binary data */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index 8b7325ec211..1e830bac461 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -16,9 +16,9 @@ import java.util.Optional; * * Sorted dimensions = num_dimensions [dimension_str_len dimension_str_bytes dimension_size_int]* * Cell_values = [double, double, double, ...]* - * where values are encoded in order of increasing indexes in each dimension, increasing + * where values are encoded in order of increasing indexes in each dimension, increasing * indexes of later dimensions in the dimension type before earlier. - * + * * @author bratseth */ @Beta @@ -36,7 +36,7 @@ public class DenseBinaryFormat implements BinaryFormat { buffer.putInt1_4Bytes(tensor.type().dimensions().size()); for (int i = 0; i < tensor.type().dimensions().size(); i++) { buffer.putUtf8String(tensor.type().dimensions().get(i).name()); - buffer.putInt1_4Bytes(tensor.dimensionSizes().size(i)); + buffer.putInt1_4Bytes((int)tensor.dimensionSizes().size(i)); // XXX: Size truncation } } @@ -54,7 +54,7 @@ public class DenseBinaryFormat implements BinaryFormat { type = optionalType.get(); TensorType serializedType = decodeType(buffer); if ( ! serializedType.isAssignableTo(type)) - throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + + throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + " cannot be assigned to type " + type); sizes = sizesFromType(serializedType); } @@ -71,7 +71,7 @@ public class DenseBinaryFormat implements BinaryFormat { int dimensionCount = buffer.getInt1_4Bytes(); TensorType.Builder builder = new TensorType.Builder(); for (int i = 0; i < dimensionCount; i++) - builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); + builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); // XXX: Size truncation return builder.build(); } @@ -84,7 +84,7 @@ public class DenseBinaryFormat implements BinaryFormat { } private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { - for (int i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < sizes.totalSize(); i++) builder.cellByDirectIndex(i, buffer.getDouble()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java index 61dfa888567..34e6cccf0f0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java @@ -46,16 +46,16 @@ class MixedBinaryFormat implements BinaryFormat { buffer.putInt1_4Bytes(denseDimensions.size()); for (TensorType.Dimension dimension : denseDimensions) { buffer.putUtf8String(dimension.name()); - buffer.putInt1_4Bytes(dimension.size().orElseThrow(() -> - new IllegalArgumentException("Unknown size of indexed dimension."))); + buffer.putInt1_4Bytes((int)dimension.size().orElseThrow(() -> + new IllegalArgumentException("Unknown size of indexed dimension.")).longValue()); // XXX: Size truncation } } private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor) { List<TensorType.Dimension> sparseDimensions = tensor.type().dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList()); - int denseSubspaceSize = tensor.denseSubspaceSize(); + long denseSubspaceSize = tensor.denseSubspaceSize(); if (sparseDimensions.size() > 0) { - buffer.putInt1_4Bytes(tensor.size() / denseSubspaceSize); + buffer.putInt1_4Bytes((int)(tensor.size() / denseSubspaceSize)); // XXX: Size truncation } Iterator<Tensor.Cell> cellIterator = tensor.cellIterator(); while (cellIterator.hasNext()) { @@ -98,7 +98,7 @@ class MixedBinaryFormat implements BinaryFormat { } int numIndexedDimensions = buffer.getInt1_4Bytes(); for (int i = 0; i < numIndexedDimensions; ++i) { - builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); + builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); // XXX: Size truncation } return builder.build(); } @@ -106,21 +106,21 @@ class MixedBinaryFormat implements BinaryFormat { private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type) { List<TensorType.Dimension> sparseDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList()); TensorType sparseType = MixedTensor.createPartialType(sparseDimensions); - int denseSubspaceSize = builder.denseSubspaceSize(); + long denseSubspaceSize = builder.denseSubspaceSize(); int numBlocks = 1; if (sparseDimensions.size() > 0) { numBlocks = buffer.getInt1_4Bytes(); } - double[] denseSubspace = new double[denseSubspaceSize]; + double[] denseSubspace = new double[(int)denseSubspaceSize]; for (int i = 0; i < numBlocks; ++i) { TensorAddress.Builder sparseAddress = new TensorAddress.Builder(sparseType); for (TensorType.Dimension sparseDimension : sparseDimensions) { sparseAddress.add(sparseDimension.name(), buffer.getUtf8String()); } - for (int denseOffset = 0; denseOffset < denseSubspaceSize; denseOffset++) { - denseSubspace[denseOffset] = buffer.getDouble(); + for (long denseOffset = 0; denseOffset < denseSubspaceSize; denseOffset++) { + denseSubspace[(int)denseOffset] = buffer.getDouble(); } builder.block(sparseAddress.build(), denseSubspace); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index 19969506eca..0cd3ff77aca 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -3,13 +3,14 @@ package com.yahoo.tensor.serialization; import com.google.common.annotations.Beta; import com.yahoo.io.GrowableByteBuffer; -import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; -import com.yahoo.text.Utf8; -import java.util.*; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; /** * Implementation of a sparse binary format for a tensor on the form: @@ -39,7 +40,7 @@ class SparseBinaryFormat implements BinaryFormat { } private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { - buffer.putInt1_4Bytes(tensor.size()); + buffer.putInt1_4Bytes((int)tensor.size()); // XXX: Size truncation for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); encodeAddress(buffer, cell.getKey()); @@ -79,8 +80,8 @@ class SparseBinaryFormat implements BinaryFormat { } private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) { - int numCells = buffer.getInt1_4Bytes(); - for (int i = 0; i < numCells; ++i) { + long numCells = buffer.getInt1_4Bytes(); // XXX: Size truncation + for (long i = 0; i < numCells; ++i) { Tensor.Builder.CellBuilder cellBuilder = builder.cell(); decodeAddress(buffer, cellBuilder, type); cellBuilder.value(buffer.getDouble()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index 7467554790a..01a1d023f2b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -46,9 +46,9 @@ public class TypedBinaryFormat { return result; } - /** - * Decode some data to a tensor - * + /** + * Decode some data to a tensor + * * @param type the type to decode and validate to, or empty to use the type given in the data * @param buffer the buffer containing the data, use GrowableByteByffer.wrap(byte[]) if you have a byte array * @return the resulting tensor diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java index d199dd3a876..abdb3071bf7 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java @@ -13,14 +13,14 @@ import java.util.stream.Collectors; /** * Microbenchmark of tensor operations. - * + * * @author bratseth */ public class TensorFunctionBenchmark { private final static Random random = new Random(); - - public double benchmark(int iterations, List<Tensor> modelVectors, TensorType.Dimension.Type dimensionType, + + public double benchmark(int iterations, List<Tensor> modelVectors, TensorType.Dimension.Type dimensionType, boolean extraSpace) { Tensor queryVector = vectors(1, 300, dimensionType).get(0); if (extraSpace) { @@ -34,7 +34,7 @@ public class TensorFunctionBenchmark { long totalTime = System.currentTimeMillis() - startTime; return (double)totalTime / (double)iterations; } - + private Tensor unitVector(String dimension) { return Tensor.Builder.of(new TensorType.Builder().indexed(dimension, 1).build()) .cell().label(dimension, 0).value(1).build(); @@ -49,11 +49,11 @@ public class TensorFunctionBenchmark { private double dotProduct(Tensor tensor, List<Tensor> tensors) { double largest = Double.MIN_VALUE; - TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), - new VariableTensor("argument"), (a, b) -> a * b), + TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), + new VariableTensor("argument"), (a, b) -> a * b), Reduce.Aggregator.sum).toPrimitive(); MapEvaluationContext context = new MapEvaluationContext(); - + for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor context.put("argument", tensorElement); double dotProduct = dotProductFunction.evaluate(context).asDouble(); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 30078b4a826..38a8329bff1 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -4,7 +4,6 @@ package com.yahoo.tensor; import com.google.common.collect.ImmutableList; import com.yahoo.tensor.evaluation.MapEvaluationContext; import com.yahoo.tensor.evaluation.VariableTensor; -import com.yahoo.tensor.functions.Argmax; import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Reduce; @@ -12,20 +11,18 @@ import com.yahoo.tensor.functions.TensorFunction; import org.junit.Test; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Set; -import java.util.stream.Collectors; -import static org.junit.Assert.assertEquals; import static com.yahoo.tensor.TensorType.Dimension.Type; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; /** * Tests tensor functionality - * + * * @author bratseth */ public class TensorTestCase { @@ -99,7 +96,7 @@ public class TensorTestCase { ImmutableList.of("y", "x"))); assertEquals(Tensor.from("{ {x:0,y:0}:0, {x:0,y:1}:0, {x:1,y:0}:0, {x:1,y:1}:1, {x:2,y:0}:0, {x:2,y:1}:2, }"), Tensor.generate(new TensorType.Builder().indexed("x", 3).indexed("y", 2).build(), - (List<Integer> indexes) -> (double)indexes.get(0)*indexes.get(1))); + (List<Long> indexes) -> (double)indexes.get(0)*indexes.get(1))); assertEquals(Tensor.from("{ {x:0,y:0,z:0}:0, {x:0,y:1,z:0}:1, {x:1,y:0,z:0}:1, {x:1,y:1,z:0}:2, {x:2,y:0,z:0}:2, {x:2,y:1,z:0}:3, "+ " {x:0,y:0,z:1}:1, {x:0,y:1,z:1}:2, {x:1,y:0,z:1}:2, {x:1,y:1,z:1}:3, {x:2,y:0,z:1}:3, {x:2,y:1,z:1}:4 }"), Tensor.range(new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build())); @@ -108,7 +105,7 @@ public class TensorTestCase { Tensor.diag(new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build())); assertEquals(Tensor.from("{ {x:1}:0, {x:3}:1, {x:9}:0 }"), Tensor.from("{ {x:1}:1, {x:3}:5, {x:9}:3 }").argmax("x")); } - + /** Test the same computation made in various ways which are implemented with special-case optimizations */ @Test public void testOptimizedComputation() { @@ -130,7 +127,7 @@ public class TensorTestCase { assertEquals("Mixed vector", 42, (int)dotProduct(vector(Type.indexedUnbound), vectors(Type.mapped, 2))); assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.indexedUnbound), matrix(Type.mapped, 2))); assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.indexedUnbound), matrix(Type.mapped, 2))); - + // Test the unoptimized path by joining in another dimension Tensor unitJ = Tensor.Builder.of(new TensorType.Builder().mapped("j").build()).cell().label("j", 0).value(1).build(); Tensor unitK = Tensor.Builder.of(new TensorType.Builder().mapped("k").build()).cell().label("k", 0).value(1).build(); @@ -138,7 +135,7 @@ public class TensorTestCase { Tensor matrixInKSpace = matrix(Type.mapped, 2).get(0).multiply(unitK); assertEquals("Generic computation implementation", 42, (int)dotProduct(vectorInJSpace, Collections.singletonList(matrixInKSpace))); } - + private double dotProduct(Tensor tensor, List<Tensor> tensors) { double sum = 0; TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), @@ -161,7 +158,7 @@ public class TensorTestCase { private Tensor vector(int vectorSize, TensorType.Dimension.Type dimensionType) { return vectors(vectorSize, dimensionType, 1).get(0); } - + /** Create a list of vectors having a single dimension x */ private List<Tensor> vectors(TensorType.Dimension.Type dimensionType, int vectorCount) { return vectors(3, dimensionType, vectorCount); @@ -179,8 +176,8 @@ public class TensorTestCase { } return tensors; } - - /** + + /** * Create a matrix of vectors (in dimension i) where each vector has the dimension x. * This matrix contains the same vectors as returned by createVectors, in a single list element for convenience. */ diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java index fab53218b2c..f11c068bd74 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java @@ -10,12 +10,12 @@ import static org.junit.Assert.assertEquals; * @author bratseth */ public class JoinTestCase { - + /** Test the indexed subspace join optimization */ @Test public void testJoinIndexedSubspace() { Tensor t1, t2; - + t1 = Tensor.from("tensor(x[]):{{x:0}:1.0,{x:1}:2.0}"); t2 = Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:10,{x:1,y:1,z:0}:0.0}"); assertEquals(Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:20.0,{x:1,y:1,z:0}:0.0}"), @@ -34,10 +34,10 @@ public class JoinTestCase { assertEquals(Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:10.0,{x:1,y:1,z:0}:0.0}"), t2.divide(t1)); } - + @Test public void testGeneralJoin() { - assertEquals(Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3 }"), + assertEquals(Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3 }"), Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:4, {x:2}:6 }") .divide(Tensor.from("tensor(y[]):{{y:0}:2}"))); @@ -45,5 +45,5 @@ public class JoinTestCase { Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:6, {x:1,y:0}:8, {x:0,y:1}:20, {x:1,y:1}:24 }") .divide(Tensor.from("tensor(y[],z[]):{ {y:0,z:0}:2, {y:1,z:0}:4, {y:2,z:0}:6 }"))); } - + } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java new file mode 100644 index 00000000000..9643c0a56e7 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java @@ -0,0 +1,97 @@ +package com.yahoo.tensor.functions; + +import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class MatmulTestCase { + + @Test + public void testMatmul2d() { + // d0 is the 'outermost' dimension, etc. + Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3])")); + ab.cell( 1,0, 0); + ab.cell( 2,0, 1); + ab.cell( 3,0, 2); + ab.cell( 4,1, 0); + ab.cell( 5,1, 1); + ab.cell( 6,1, 2); + Tensor a = ab.build(); + + Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[3],d1[2])")); + bb.cell( 7,0, 0); + bb.cell( 8,0, 1); + bb.cell( 9,1, 0); + bb.cell(10,1, 1); + bb.cell(11,2, 0); + bb.cell(12,2, 1); + Tensor b = bb.build(); + + Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2])")); + rb.cell( 58,0, 0); + rb.cell( 64,0, 1); + rb.cell(139,1, 0); + rb.cell(154,1, 1); + Tensor r = rb.build(); + + Tensor result = a.matmul(b.rename(ImmutableList.of("d0","d1"), ImmutableList.of("d1","d2")), "d1") + .rename("d2","d1"); + assertEquals(r, result); + } + + @Test + public void testMatmul3d() { + // Convention: a is the 'outermost' dimension, etc. + Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[3])")); + ab.cell( 1,0, 0, 0); + ab.cell( 2,0, 0, 1); + ab.cell( 3,0, 0, 2); + ab.cell( 4,0, 1, 0); + ab.cell( 5,0, 1, 1); + ab.cell( 6,0, 1, 2); + ab.cell( 7,1, 0, 0); + ab.cell( 8,1, 0, 1); + ab.cell( 9,1, 0, 2); + ab.cell(10,1, 1, 0); + ab.cell(11,1, 1, 1); + ab.cell(12,1, 1, 2); + Tensor a = ab.build(); + + Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3],d2[2])")); + bb.cell(13,0, 0, 0); + bb.cell(14,0, 0, 1); + bb.cell(15,0, 1, 0); + bb.cell(16,0, 1, 1); + bb.cell(17,0, 2, 0); + bb.cell(18,0, 2, 1); + bb.cell(19,1, 0, 0); + bb.cell(20,1, 0, 1); + bb.cell(21,1, 1, 0); + bb.cell(22,1, 1, 1); + bb.cell(23,1, 2, 0); + bb.cell(24,1, 2, 1); + Tensor b = bb.build(); + + Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[2])")); + rb.cell( 94,0, 0, 0); + rb.cell(100,0, 0, 1); + rb.cell(229,0, 1, 0); + rb.cell(244,0, 1, 1); + rb.cell(508,1, 0, 0); + rb.cell(532,1, 0, 1); + rb.cell(697,1, 1, 0); + rb.cell(730,1, 1, 1); + Tensor r = rb.build(); + + Tensor result = a.matmul(b.rename(ImmutableList.of("d1","d2"), ImmutableList.of("d2","d3")), "d2") + .rename("d3","d2"); + assertEquals(r, result); + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java index 8a58cb0bbed..55069eaced7 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java @@ -7,7 +7,7 @@ import static org.junit.Assert.assertEquals; /** * Tests translation of composite to primitive tensor function translation. - * + * * @author bratseth */ public class TensorFunctionTestCase { @@ -16,12 +16,12 @@ public class TensorFunctionTestCase { public void testTranslation() { assertTranslated("join({{x:1}:1.0}, reduce({{x:1}:1.0}, sum, x), f(a,b)(a / b))", new L1Normalize(new ConstantTensor("{{x:1}:1.0}"), "x")); - assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))", + assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))", new Diag(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build())); assertTranslated("join({{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, reduce({{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, max, x), f(a,b)(a==b))", new Argmax(new ConstantTensor("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x")); } - + private void assertTranslated(String expectedTranslation, TensorFunction inputFunction) { assertEquals(expectedTranslation, inputFunction.toPrimitive().toString()); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java index 349309a5052..15a872e439f 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java @@ -30,7 +30,7 @@ public class DenseBinaryFormatTestCase { assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor(x[1],y[2],z[3]):{{y:0,x:0,z:0}:2.0}"); } - + @Test public void testSerializationToSeparateType() { assertSerialization(Tensor.from("tensor(x[1],y[1]):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x[],y[])")); @@ -64,7 +64,7 @@ public class DenseBinaryFormatTestCase { private void assertSerialization(Tensor tensor) { assertSerialization(tensor, tensor.type()); } - + private void assertSerialization(Tensor tensor, TensorType expectedType) { byte[] encodedTensor = TypedBinaryFormat.encode(tensor); Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor)); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java index b1d7d797b3e..33dfca017f4 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java @@ -84,7 +84,7 @@ public class MixedBinaryFormatTestCase { private void assertSerialization(Tensor tensor) { assertSerialization(tensor, tensor.type()); } - + private void assertSerialization(Tensor tensor, TensorType expectedType) { byte[] encodedTensor = TypedBinaryFormat.encode(tensor); Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor)); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java index 68bf59e3ed9..f002637847b 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java @@ -50,7 +50,7 @@ public class SerializationTestCase { JsonNode node = mapper.readTree(test); if (node.has("tensor") && node.has("binary")) { System.out.println("Running test: " + test); - + Tensor tensor = buildTensor(node.get("tensor")); String spec = getSpec(node.get("tensor")); byte[] encodedTensor = TypedBinaryFormat.encode(tensor); @@ -123,7 +123,7 @@ public class SerializationTestCase { private byte[] getBytes(String binaryRepresentation) { return parseHexValue(binaryRepresentation.substring(2)); } - + private byte[] parseHexValue(String s) { final int len = s.length(); byte[] bytes = new byte[len/2]; diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index d17148cf8dc..f895b64379b 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -65,7 +65,7 @@ public class SparseBinaryFormatTestCase { private void assertSerialization(Tensor tensor, TensorType expectedType) { byte[] encodedTensor = TypedBinaryFormat.encode(tensor); - Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), + Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor)); assertEquals(tensor, decodedTensor); } diff --git a/vespalib/src/vespa/vespalib/stllike/hash_map.h b/vespalib/src/vespa/vespalib/stllike/hash_map.h index 31185a9ff7c..023594d3018 100644 --- a/vespalib/src/vespa/vespalib/stllike/hash_map.h +++ b/vespalib/src/vespa/vespalib/stllike/hash_map.h @@ -35,7 +35,7 @@ public: size_t capacity() const { return _ht.capacity(); } size_t size() const { return _ht.size(); } bool empty() const { return _ht.empty(); } - insert_result insert(const value_type & value); + insert_result insert(const value_type & value) { return _ht.insert(value); } template <typename InputIt> void insert(InputIt first, InputIt last); const V & operator [] (const K & key) const { return _ht.find(key)->second; } diff --git a/vespalib/src/vespa/vespalib/stllike/hash_map.hpp b/vespalib/src/vespa/vespalib/stllike/hash_map.hpp index 359ba235a36..b526188b8b2 100644 --- a/vespalib/src/vespa/vespalib/stllike/hash_map.hpp +++ b/vespalib/src/vespa/vespalib/stllike/hash_map.hpp @@ -17,13 +17,7 @@ hash_map<K, V, H, EQ, M>::hash_map(size_t reserveSize, H hasher, EQ equality) : { } template <typename K, typename V, typename H, typename EQ, typename M> -hash_map<K, V, H, EQ, M>::~hash_map() { } - -template <typename K, typename V, typename H, typename EQ, typename M> -typename hash_map<K, V, H, EQ, M>::insert_result -hash_map<K, V, H, EQ, M>::insert(const value_type & value) { - return _ht.insert(value); -} +hash_map<K, V, H, EQ, M>::~hash_map() = default; template <typename K, typename V, typename H, typename EQ, typename M> void @@ -64,12 +58,20 @@ hash_map<K, V, H, EQ, M>::getMemoryUsed() const } -#define VESPALIB_HASH_MAP_INSTANTIATE_H(K, V, H) \ - template class vespalib::hash_map<K, V, H>; \ - template class vespalib::hashtable<K, std::pair<K,V>, H, std::equal_to<K>, std::_Select1st<std::pair<K,V>>>; \ - template vespalib::hashtable<K, std::pair<K,V>, H, std::equal_to<K>, std::_Select1st<std::pair<K,V>>>::insert_result \ - vespalib::hashtable<K, std::pair<K,V>, H, std::equal_to<K>, std::_Select1st<std::pair<K,V>>>::insert(std::pair<K,V> &&); \ + +#define VESPALIB_HASH_MAP_INSTANTIATE_H_E_M(K, V, H, E, M) \ + template class vespalib::hash_map<K, V, H, E, M>; \ + template class vespalib::hashtable<K, std::pair<K,V>, H, E, std::_Select1st<std::pair<K,V>>, M>; \ + template vespalib::hashtable<K, std::pair<K,V>, H, E, std::_Select1st<std::pair<K,V>>, M>::insert_result \ + vespalib::hashtable<K, std::pair<K,V>, H, E, std::_Select1st<std::pair<K,V>>, M>::insert(std::pair<K,V> &&); \ + template vespalib::hashtable<K, std::pair<K,V>, H, E, std::_Select1st<std::pair<K,V>>, M>::insert_result \ + vespalib::hashtable<K, std::pair<K,V>, H, E, std::_Select1st<std::pair<K,V>>, M>::insertInternal(std::pair<K,V> &&); \ template class vespalib::Array<vespalib::hash_node<std::pair<K,V>>>; +#define VESPALIB_HASH_MAP_INSTANTIATE_H_E(K, V, H, E) \ + VESPALIB_HASH_MAP_INSTANTIATE_H_E_M(K, V, H, E, vespalib::hashtable_base::prime_modulator) + +#define VESPALIB_HASH_MAP_INSTANTIATE_H(K, V, H) VESPALIB_HASH_MAP_INSTANTIATE_H_E(K, V, H, std::equal_to<K>) + #define VESPALIB_HASH_MAP_INSTANTIATE(K, V) VESPALIB_HASH_MAP_INSTANTIATE_H(K, V, vespalib::hash<K>) diff --git a/vespalib/src/vespa/vespalib/stllike/hashtable.h b/vespalib/src/vespa/vespalib/stllike/hashtable.h index 263ee952c2e..15949067a60 100644 --- a/vespalib/src/vespa/vespalib/stllike/hashtable.h +++ b/vespalib/src/vespa/vespalib/stllike/hashtable.h @@ -141,19 +141,18 @@ public: typedef Value* pointer; typedef std::forward_iterator_tag iterator_category; - iterator(hashtable * hash, next_t start) : _hash(start), _subNode(start), _hashTable(hash) { - advanceToNextValidHash(); - } - iterator(hashtable * hash, next_t start, next_t subNode) : _hash(start), _subNode(subNode), _hashTable(hash) { } - Value & operator * () const { return _hashTable->get(_subNode); } - Value * operator -> () const { return & _hashTable->get(_subNode); } - iterator & operator ++ () { - if (_hashTable->_nodes[_subNode].hasNext()) { - _subNode = _hashTable->_nodes[_subNode].getNext(); - } else { - _hash++; + iterator(hashtable * hash) : _current(0), _hashTable(hash) { + if ((_current < _hashTable->initializedSize()) && ! _hashTable->_nodes[_current].valid()) { advanceToNextValidHash(); } + } + iterator(hashtable * hash, next_t pos) : _current(pos), _hashTable(hash) { } + static iterator end(hashtable *hash) { return iterator(hash, Node::npos); } + + Value & operator * () const { return _hashTable->get(_current); } + Value * operator -> () const { return & _hashTable->get(_current); } + iterator & operator ++ () { + advanceToNextValidHash(); return *this; } iterator operator ++ (int) { @@ -161,19 +160,19 @@ public: ++(*this); return prev; } - bool operator==(const iterator& rhs) const { return (_subNode == rhs._subNode); } - bool operator!=(const iterator& rhs) const { return (_subNode != rhs._subNode); } + bool operator==(const iterator& rhs) const { return (_current == rhs._current); } + bool operator!=(const iterator& rhs) const { return (_current != rhs._current); } /// Carefull about this one. Only used by lrucache. - next_t getInternalIndex() const { return _subNode; } - void setInternalIndex(next_t n) { _subNode = n; } - next_t getHash() const { return _hash; } + next_t getInternalIndex() const { return _current; } + void setInternalIndex(next_t n) { _current = n; } private: void advanceToNextValidHash() { - for (;(_hash < _hashTable->getTableSize()) && ! _hashTable->_nodes[_hash].valid(); _hash++) { } - _subNode = (_hash < _hashTable->getTableSize()) ? _hash : Node::npos; + for (_current++;(_current < _hashTable->initializedSize()) && ! _hashTable->_nodes[_current].valid(); _current++) { } + if (_current >= _hashTable->initializedSize()) { + _current = Node::npos; + } } - next_t _hash; - next_t _subNode; + next_t _current; hashtable * _hashTable; friend class hashtable::const_iterator; @@ -186,21 +185,19 @@ public: typedef const Value* pointer; typedef std::forward_iterator_tag iterator_category; - const_iterator(const hashtable * hash, next_t start) : _hash(start), _subNode(start), _hashTable(hash) { - advanceToNextValidHash(); - } - const_iterator(const hashtable * hash, next_t start, next_t subNode) : _hash(start), _subNode(subNode), _hashTable(hash) { } - const_iterator(const iterator &i) - : _hash(i._hash), _subNode(i._subNode), _hashTable(i._hashTable) {} - const Value & operator * () const { return _hashTable->get(_subNode); } - const Value * operator -> () const { return & _hashTable->get(_subNode); } - const_iterator & operator ++ () { - if (_hashTable->_nodes[_subNode].hasNext()) { - _subNode = _hashTable->_nodes[_subNode].getNext(); - } else { - _hash++; + const_iterator(const hashtable * hash) : _current(0), _hashTable(hash) { + if ((_current < _hashTable->initializedSize()) && ! _hashTable->_nodes[_current].valid()) { advanceToNextValidHash(); } + } + const_iterator(const hashtable * hash, next_t pos) : _current(pos), _hashTable(hash) { } + const_iterator(const iterator &i) : _current(i._current), _hashTable(i._hashTable) {} + static const_iterator end(const hashtable *hash) { return const_iterator(hash, Node::npos); } + + const Value & operator * () const { return _hashTable->get(_current); } + const Value * operator -> () const { return & _hashTable->get(_current); } + const_iterator & operator ++ () { + advanceToNextValidHash(); return *this; } const_iterator operator ++ (int) { @@ -208,17 +205,17 @@ public: ++(*this); return prev; } - bool operator==(const const_iterator& rhs) const { return (_subNode == rhs._subNode); } - bool operator!=(const const_iterator& rhs) const { return (_subNode != rhs._subNode); } - next_t getInternalIndex() const { return _subNode; } - next_t getHash() const { return _hash; } + bool operator==(const const_iterator& rhs) const { return (_current == rhs._current); } + bool operator!=(const const_iterator& rhs) const { return (_current != rhs._current); } + next_t getInternalIndex() const { return _current; } private: void advanceToNextValidHash() { - for (;(_hash < _hashTable->getTableSize()) && ! _hashTable->_nodes[_hash].valid(); _hash++) { } - _subNode = (_hash < _hashTable->getTableSize()) ? _hash : Node::npos; + for (_current++;(_current < _hashTable->initializedSize()) && ! _hashTable->_nodes[_current].valid(); _current++) { } + if (_current >= _hashTable->initializedSize()) { + _current = Node::npos; + } } - next_t _hash; - next_t _subNode; + next_t _current; const hashtable * _hashTable; }; typedef std::pair<iterator, bool> insert_result; @@ -231,10 +228,10 @@ public: hashtable(size_t reservedSpace); hashtable(size_t reservedSpace, const Hash & hasher, const Equal & equal); virtual ~hashtable(); - iterator begin() { return iterator(this, 0); } - iterator end() { return iterator(this, Node::npos); } - const_iterator begin() const { return const_iterator(this, 0); } - const_iterator end() const { return const_iterator(this, Node::npos); } + iterator begin() { return iterator(this); } + iterator end() { return iterator::end(this); } + const_iterator begin() const { return const_iterator(this); } + const_iterator end() const { return const_iterator::end(this); } size_t capacity() const { return _nodes.capacity(); } size_t size() const { return _count; } bool empty() const { return _count == 0; } @@ -249,7 +246,9 @@ public: const_iterator find(const AltKey & key) const { return find<AltKey, AltExtract, AltHash, AltEqual>(key, AltExtract()); } const_iterator find(const Key & key) const; template <typename V> - insert_result insert(V && node); + insert_result insert(V && node) { + return insertInternal(std::forward<V>(node)); + } void erase(const Key & key); void reserve(size_t sz) { if (sz > _nodes.capacity()) { @@ -280,7 +279,8 @@ protected: Value & getByInternalIndex(size_t index) { return _nodes[index].getValue(); } const Value & getByInternalIndex(size_t index) const { return _nodes[index].getValue(); } template <typename MoveHandler> - void erase(MoveHandler & moveHandler, const const_iterator & key); + void erase(MoveHandler & moveHandler, next_t h, const const_iterator & key); + next_t hash(const Key & key) const { return modulator(_hasher(key)); } private: Modulator _modulator; size_t _count; @@ -292,7 +292,7 @@ private: const Value & get(size_t index) const { return _nodes[index].getValue(); } next_t modulator(next_t key) const { return _modulator.modulo(key); } next_t getTableSize() const { return _modulator.getTableSize(); } - next_t hash(const Key & key) const { return modulator(_hasher(key)); } + size_t initializedSize() const { return _nodes.size(); } template <typename MoveHandler> void move(MoveHandler & moveHandler, next_t from, next_t to) { _nodes[to] = std::move(_nodes[from]); diff --git a/vespalib/src/vespa/vespalib/stllike/hashtable.hpp b/vespalib/src/vespa/vespalib/stllike/hashtable.hpp index 359e71aa0d2..f499ba35f3f 100644 --- a/vespalib/src/vespa/vespalib/stllike/hashtable.hpp +++ b/vespalib/src/vespa/vespalib/stllike/hashtable.hpp @@ -67,11 +67,10 @@ typename hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::iterator hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::find(const Key & key) { next_t h = hash(key); - if (_nodes[h].valid()) { - next_t start(h); + if (__builtin_expect(_nodes[h].valid(), true)) { do { - if (_equal(_keyExtractor(_nodes[h].getValue()), key)) { - return iterator(this, start, h); + if (__builtin_expect(_equal(_keyExtractor(_nodes[h].getValue()), key), true)) { + return iterator(this, h); } h = _nodes[h].getNext(); } while (h != Node::npos); @@ -84,11 +83,10 @@ typename hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::const_iterat hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::find(const Key & key) const { next_t h = hash(key); - if (_nodes[h].valid()) { - next_t start(h); + if (__builtin_expect(_nodes[h].valid(), true)) { do { - if (_equal(_keyExtractor(_nodes[h].getValue()), key)) { - return const_iterator(this, start, h); + if (__builtin_expect(_equal(_keyExtractor(_nodes[h].getValue()), key), true)) { + return const_iterator(this, h); } h = _nodes[h].getNext(); } while (h != Node::npos); @@ -104,11 +102,10 @@ hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::find(const AltKey & k AltHash altHasher; next_t h = modulator(altHasher(key)); if (_nodes[h].valid()) { - next_t start(h); AltEqual altEqual; do { if (altEqual(altExtract(_keyExtractor(_nodes[h].getValue())), key)) { - return const_iterator(this, start, h); + return const_iterator(this, h); } h = _nodes[h].getNext(); } while (h != Node::npos); @@ -124,11 +121,10 @@ hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::find(const AltKey & k AltHash altHasher; next_t h = modulator(altHasher(key)); if (_nodes[h].valid()) { - next_t start(h); AltEqual altEqual; do { if (altEqual(altExtract(_keyExtractor(_nodes[h].getValue())), key)) { - return iterator(this, start, h); + return iterator(this, h); } h = _nodes[h].getNext(); } while (h != Node::npos); @@ -137,19 +133,12 @@ hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::find(const AltKey & k } template< typename Key, typename Value, typename Hash, typename Equal, typename KeyExtract, typename Modulator > -template<typename V> -typename hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::insert_result -hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::insert(V && node) { - return insertInternal(std::forward<V>(node)); -} - -template< typename Key, typename Value, typename Hash, typename Equal, typename KeyExtract, typename Modulator > void hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::erase(const Key & key) { const_iterator found(find(key)); if (found != end()) { DefaultMoveHandler moveHandler; - erase(moveHandler, found); + erase(moveHandler, hash(key), found); } } @@ -169,11 +158,11 @@ hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::insertInternal(V && n if ( ! _nodes[h].valid() ) { _nodes[h] = std::forward<V>(node); _count++; - return insert_result(iterator(this, h, h), true); + return insert_result(iterator(this, h), true); } else if (_nodes.size() <= _nodes.capacity()) { for (next_t c(h); c != Node::npos; c = _nodes[c].getNext()) { if (_equal(_keyExtractor(_nodes[c].getValue()), _keyExtractor(node))) { - return insert_result(iterator(this, h, c), false); + return insert_result(iterator(this, c), false); } } if (_nodes.size() < _nodes.capacity()) { @@ -182,7 +171,7 @@ hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::insertInternal(V && n _nodes[h].setNext(newIdx); new (_nodes.push_back_fast()) Node(std::forward<V>(node), p); _count++; - return insert_result(iterator(this, h, newIdx), true); + return insert_result(iterator(this, newIdx), true); } else { resize(_nodes.capacity()*2); return insertInternal(std::forward<V>(node)); @@ -214,9 +203,8 @@ void hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::reclaim(MoveHand template< typename Key, typename Value, typename Hash, typename Equal, typename KeyExtract, typename Modulator > template <typename MoveHandler> void -hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::erase(MoveHandler & moveHandler, const const_iterator & it) +hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::erase(MoveHandler & moveHandler, next_t h, const const_iterator & it) { - next_t h = it.getHash(); next_t prev = Node::npos; do { if (h == it.getInternalIndex()) { diff --git a/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java b/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java index 15257e11cbe..4c932969460 100644 --- a/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java +++ b/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.curator; import com.google.inject.Inject; import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.net.HostName; import com.yahoo.path.Path; import com.yahoo.vespa.curator.recipes.CuratorCounter; import com.yahoo.vespa.zookeeper.ZooKeeperServer; @@ -21,7 +22,6 @@ import org.apache.curator.framework.state.ConnectionState; import org.apache.curator.framework.state.ConnectionStateListener; import org.apache.curator.retry.ExponentialBackoffRetry; -import java.io.Closeable; import java.time.Duration; import java.util.Arrays; import java.util.Collections; @@ -68,16 +68,26 @@ public class Curator implements AutoCloseable { this(createConnectionSpec(configserverConfig)); } - private static String createConnectionSpec(ConfigserverConfig config) { + static String createConnectionSpec(ConfigserverConfig config) { + String thisServer = HostName.getLocalhost(); + StringBuilder sb = new StringBuilder(); for (int i = 0; i < config.zookeeperserver().size(); i++) { ConfigserverConfig.Zookeeperserver server = config.zookeeperserver(i); - sb.append(server.hostname()); - sb.append(":"); - sb.append(server.port()); - if (i < config.zookeeperserver().size() - 1) { - sb.append(","); + + String spec = String.format("%s:%d", server.hostname(), server.port()); + + if (config.zookeeperLocalhostAffinity() && server.hostname().equals(thisServer)) { + // Only connect to localhost server if possible, to save network traffic + // and balance load. + return spec; } + + if (sb.length() > 0) { + sb.append(','); + } + + sb.append(spec); } return sb.toString(); } diff --git a/zkfacade/src/test/java/com/yahoo/vespa/zookeeper/CuratorTest.java b/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java index 36205bdaca3..1899dcfe7cd 100644 --- a/zkfacade/src/test/java/com/yahoo/vespa/zookeeper/CuratorTest.java +++ b/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java @@ -1,8 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.zookeeper; +package com.yahoo.vespa.curator; import com.yahoo.cloud.config.ConfigserverConfig; -import com.yahoo.vespa.curator.Curator; +import com.yahoo.net.HostName; import org.apache.curator.test.TestingServer; import org.junit.After; import org.junit.Before; @@ -11,7 +11,6 @@ import org.junit.Test; import java.io.IOException; import static org.hamcrest.core.Is.is; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThat; /** @@ -74,6 +73,23 @@ public class CuratorTest { } } + @Test + public void localhost_affinity() { + String localhostHostName = "myhost"; + int localhostPort = 123; + String localhostSpec = localhostHostName + ":" + localhostPort; + + ConfigserverConfig.Builder builder = new ConfigserverConfig.Builder(); + builder.zookeeperserver(createZKBuilder(localhostHostName, localhostPort)); + builder.zookeeperserver(createZKBuilder("otherhost", 345)); + builder.zookeeperLocalhostAffinity(true); + ConfigserverConfig config = new ConfigserverConfig(builder); + + HostName.setHostNameForTestingOnly(localhostHostName); + + assertThat(Curator.createConnectionSpec(config), is(localhostSpec)); + } + private ConfigserverConfig createTestConfig() { ConfigserverConfig.Builder builder = new ConfigserverConfig.Builder(); builder.zookeeperserver(createZKBuilder("localhost", port1)); |