diff options
100 files changed, 2319 insertions, 354 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/AbstractService.java b/config-model/src/main/java/com/yahoo/vespa/model/AbstractService.java index 60a49598c42..daa237d90c1 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/AbstractService.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/AbstractService.java @@ -482,8 +482,8 @@ public abstract class AbstractService extends AbstractConfigProducer<AbstractCon * Must be done this way since the system test framework * currently uses the first port as container http port. */ - public void reservePortPrepended(int port) { - hostResource.reservePort(this, port); + public void reservePortPrepended(int port, String suffix) { + hostResource.reservePort(this, port, suffix); ports.add(0, port); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ConfigProxy.java b/config-model/src/main/java/com/yahoo/vespa/model/ConfigProxy.java index c540a5f62d2..7ab28faa434 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ConfigProxy.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ConfigProxy.java @@ -47,6 +47,11 @@ public class ConfigProxy extends AbstractService { */ public int getPortCount() { return 1; } + @Override + public String[] getPortSuffixes() { + return new String[]{"rpc"}; + } + /** * The config proxy is not started by the config system! */ diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ConfigSentinel.java b/config-model/src/main/java/com/yahoo/vespa/model/ConfigSentinel.java index cd92f27cc50..1b5c5e4a579 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ConfigSentinel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ConfigSentinel.java @@ -26,7 +26,7 @@ public class ConfigSentinel extends AbstractService implements SentinelConfig.Pr super(host, "sentinel"); this.applicationId = applicationId; this.zone = zone; - portsMeta.on(0).tag("rpc").tag("notyet"); + portsMeta.on(0).tag("rpc").tag("admin"); portsMeta.on(1).tag("telnet").tag("interactive").tag("http").tag("state"); setProp("clustertype", "hosts"); setProp("clustername", "admin"); @@ -48,6 +48,11 @@ public class ConfigSentinel extends AbstractService implements SentinelConfig.Pr public int getPortCount() { return 2; } @Override + public String[] getPortSuffixes() { + return new String[]{ "rpc", "http" }; + } + + @Override public int getHealthPort() {return getRelativePort(1); } /** diff --git a/config-model/src/main/java/com/yahoo/vespa/model/HostResource.java b/config-model/src/main/java/com/yahoo/vespa/model/HostResource.java index a27b33173ee..3a8cc5c2e4c 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/HostResource.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/HostResource.java @@ -6,6 +6,7 @@ import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.api.HostInfo; import com.yahoo.config.provision.ClusterMembership; import com.yahoo.config.provision.Flavor; +import com.yahoo.config.provision.NetworkPorts; import javax.annotation.Nullable; import java.util.ArrayList; @@ -36,10 +37,23 @@ public class HostResource implements Comparable<HostResource> { /** Map from "sentinel name" to service */ private final Map<String, Service> services = new LinkedHashMap<>(); - private final Map<Integer, Service> portDB = new LinkedHashMap<>(); + private final Map<Integer, NetworkPortRequestor> portDB = new LinkedHashMap<>(); private int allocatedPorts = 0; + static class PortReservation { + int gotPort; + NetworkPortRequestor service; + String suffix; + PortReservation(int port, NetworkPortRequestor svc, String suf) { + this.gotPort = port; + this.service = svc; + this.suffix = suf; + } + } + + private List<PortReservation> portReservations = new ArrayList<>(); + private Set<ClusterMembership> clusterMemberships = new LinkedHashSet<>(); // Empty for self-hosted Vespa. @@ -48,6 +62,12 @@ public class HostResource implements Comparable<HostResource> { /** The current Vespa version running on this node, or empty if not known */ private final Optional<Version> version; + private Optional<NetworkPorts> networkPortsList = Optional.empty(); + + public Optional<NetworkPorts> networkPorts() { return networkPortsList; } + + public void addNetworkPorts(NetworkPorts ports) { this.networkPortsList = Optional.of(ports); } + /** * Create a new {@link HostResource} bound to a specific {@link com.yahoo.vespa.model.Host}. * @@ -108,22 +128,22 @@ public class HostResource implements Comparable<HostResource> { return ports; } - private List<Integer> allocatePorts(DeployLogger deployLogger, AbstractService service, int wantedPort) { + private List<Integer> allocatePorts(DeployLogger deployLogger, NetworkPortRequestor service, int wantedPort) { List<Integer> ports = new ArrayList<>(); if (service.getPortCount() < 1) return ports; int serviceBasePort = BASE_PORT + allocatedPorts; if (wantedPort > 0) { - if (service.getPortCount() < 1) { - throw new RuntimeException(service + " wants baseport " + wantedPort + - ", but it has not reserved any ports, so it cannot name a desired baseport."); - } if (service.requiresWantedPort() || canUseWantedPort(deployLogger, service, wantedPort, serviceBasePort)) serviceBasePort = wantedPort; } + String[] suffixes = service.getPortSuffixes(); + if (suffixes.length != service.getPortCount()) { + throw new IllegalArgumentException("service "+service+" had "+suffixes.length+" port suffixes, but port count "+service.getPortCount()+", mismatch"); + } - reservePort(service, serviceBasePort); + reservePort(service, serviceBasePort, suffixes[0]); ports.add(serviceBasePort); int remainingPortsStart = service.requiresConsecutivePorts() ? @@ -131,17 +151,30 @@ public class HostResource implements Comparable<HostResource> { BASE_PORT + allocatedPorts; for (int i = 0; i < service.getPortCount() - 1; i++) { int port = remainingPortsStart + i; - reservePort(service, port); + reservePort(service, port, suffixes[i+1]); ports.add(port); } + if (suffixes.length != service.getPortCount()) { + throw new IllegalArgumentException("service "+service+" had "+suffixes.length+" port suffixes, but port count "+service.getPortCount()+", mismatch"); + } return ports; } - private boolean canUseWantedPort(DeployLogger deployLogger, AbstractService service, int wantedPort, int serviceBasePort) { + public void flushPortReservations() { + List<NetworkPorts.Allocation> list = new ArrayList<>(); + for (PortReservation pr : portReservations) { + String servType = pr.service.getServiceType(); + String configId = pr.service.getConfigId(); + list.add(new NetworkPorts.Allocation(pr.gotPort, servType, configId, pr.suffix)); + } + this.networkPortsList = Optional.of(new NetworkPorts(list)); + } + + private boolean canUseWantedPort(DeployLogger deployLogger, NetworkPortRequestor service, int wantedPort, int serviceBasePort) { for (int i = 0; i < service.getPortCount(); i++) { int port = wantedPort + i; if (portDB.containsKey(port)) { - AbstractService s = (AbstractService)portDB.get(port); + NetworkPortRequestor s = portDB.get(port); deployLogger.log(Level.WARNING, service.getServiceName() +" cannot reserve port " + port + " on " + this + ": Already reserved for " + s.getServiceName() + ". Using default port range from " + serviceBasePort); @@ -159,7 +192,7 @@ public class HostResource implements Comparable<HostResource> { * @param service the service that wishes to reserve the port. * @param port the port to be reserved. */ - void reservePort(AbstractService service, int port) { + void reservePort(NetworkPortRequestor service, int port, String suffix) { if (portDB.containsKey(port)) { portAlreadyReserved(service, port); } else { @@ -170,6 +203,7 @@ public class HostResource implements Comparable<HostResource> { } } portDB.put(port, service); + portReservations.add(new PortReservation(port, service, suffix)); } } @@ -178,8 +212,8 @@ public class HostResource implements Comparable<HostResource> { port < BASE_PORT + MAX_PORTS; } - private void portAlreadyReserved(AbstractService service, int port) { - AbstractService otherService = (AbstractService)portDB.get(port); + private void portAlreadyReserved(NetworkPortRequestor service, int port) { + NetworkPortRequestor otherService = portDB.get(port); int nextAvailablePort = nextAvailableBaseport(service.getPortCount()); if (nextAvailablePort == 0) { noMoreAvailablePorts(); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/HostSystem.java b/config-model/src/main/java/com/yahoo/vespa/model/HostSystem.java index fdfe4f01790..a1b030ffc61 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/HostSystem.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/HostSystem.java @@ -8,6 +8,7 @@ import com.yahoo.config.provision.Capacity; import com.yahoo.config.provision.ClusterMembership; import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.HostSpec; +import com.yahoo.config.provision.NetworkPorts; import com.yahoo.config.provision.ProvisionLogger; import java.net.UnknownHostException; @@ -126,9 +127,11 @@ public class HostSystem extends AbstractConfigProducer<Host> { private HostResource addNewHost(HostSpec hostSpec) { Host host = Host.createHost(this, hostSpec.hostname()); - HostResource hostResource = new HostResource(host, hostSpec.version()); + HostResource hostResource = new HostResource(host, + hostSpec.version()); hostResource.setFlavor(hostSpec.flavor()); hostSpec.membership().ifPresent(hostResource::addClusterMembership); + hostSpec.networkPorts().ifPresent(hostResource::addNetworkPorts); hostname2host.put(host.getHostname(), hostResource); log.log(DEBUG, () -> "Added new host resource for " + host.getHostname() + " with flavor " + hostResource.getFlavor()); return hostResource; @@ -141,6 +144,19 @@ public class HostSystem extends AbstractConfigProducer<Host> { .collect(Collectors.toList()); } + public void dumpPortAllocations() { + for (HostResource hr : getHosts()) { + hr.flushPortReservations(); +/* + System.out.println("port allocations for: "+hr.getHostname()); + NetworkPorts ports = hr.networkPorts().get(); + for (NetworkPorts.Allocation allocation: ports.allocations()) { + System.out.println("port="+allocation.port+" [type="+allocation.serviceType+", cfgId="+allocation.configId+", suffix="+allocation.portSuffix+"]"); + } +*/ + } + } + public Map<HostResource, ClusterMembership> allocateHosts(ClusterSpec cluster, Capacity capacity, int groups, DeployLogger logger) { List<HostSpec> allocatedHosts = provisioner.prepare(cluster, capacity, groups, new ProvisionDeployLogger(logger)); // TODO: Even if HostResource owns a set of memberships, we need to return a map because the caller needs the current membership. @@ -177,7 +193,7 @@ public class HostSystem extends AbstractConfigProducer<Host> { Set<HostSpec> getHostSpecs() { return getHosts().stream() .map(host -> new HostSpec(host.getHostname(), Collections.emptyList(), - host.getFlavor(), host.primaryClusterMembership(), host.version())) + host.getFlavor(), host.primaryClusterMembership(), host.version(), host.networkPorts())) .collect(Collectors.toCollection(LinkedHashSet::new)); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/Logd.java b/config-model/src/main/java/com/yahoo/vespa/model/Logd.java index 0f7418582a3..3c7f1ba6cfa 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/Logd.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/Logd.java @@ -32,6 +32,11 @@ public class Logd */ public int getPortCount() { return 1; } + @Override + public String[] getPortSuffixes() { + return new String[]{"http"}; + } + /** Returns the desired base port for this service. */ public int getWantedPort() { return 19089; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/NetworkPortRequestor.java b/config-model/src/main/java/com/yahoo/vespa/model/NetworkPortRequestor.java new file mode 100644 index 00000000000..52319f71810 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/NetworkPortRequestor.java @@ -0,0 +1,53 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.model; + +/** + * Interface implemented by services using network ports, identifying its requirements. + * @author arnej + */ +public interface NetworkPortRequestor { + + /** Returns the type of service */ + String getServiceType(); + + /** Returns the name that identifies this service for the config-sentinel */ + String getServiceName(); + + /** Returns the config id */ + String getConfigId(); + + /** + * Returns the desired base port for this service, or '0' if this + * service should use the default port allocation mechanism. + * + * @return The desired base port for this service. + */ + default int getWantedPort() { return 0; } + + /** Returns the number of ports needed by this service. */ + int getPortCount(); + + /** + * Returns true if the desired base port (returned by + * getWantedPort()) for this service is the only allowed base + * port. + * + * @return true if this Service requires the wanted base port. + */ + default boolean requiresWantedPort() { return false; } + + /** + * Override if the services does not require consecutive port numbers. I.e. if any ports + * in addition to the baseport should be allocated from Vespa's default port range. + * + * @return true by default + */ + default boolean requiresConsecutivePorts() { return true; } + + /** + * Return names for each port requested. + * The size of the returned array must be equal to getPortCount(). + **/ + String[] getPortSuffixes(); +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/Service.java b/config-model/src/main/java/com/yahoo/vespa/model/Service.java index d5d33a08b5d..0af4355764c 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/Service.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/Service.java @@ -11,7 +11,7 @@ import java.util.Optional; * * @author gjoranv */ -public interface Service extends ConfigProducer { +public interface Service extends ConfigProducer, NetworkPortRequestor { /** * Services that should be started by config-sentinel must return @@ -44,39 +44,6 @@ public interface Service extends ConfigProducer { boolean getAutorestartFlag(); /** - * Returns the type of service. E.g. the class-name without the - * package prefix. - */ - String getServiceType(); - - /** - * Returns the name that identifies this service for the config-sentinel. - */ - String getServiceName(); - - /** - * Returns the desired base port for this service, or '0' if this - * service should use the default port allocation mechanism. - * - * @return The desired base port for this service. - */ - int getWantedPort(); - - /** - * Returns true if the desired base port (returned by - * getWantedPort()) for this service is the only allowed base - * port. - * - * @return true if this Service requires the wanted base port. - */ - boolean requiresWantedPort(); - - /** - * Returns the number of ports needed by this service. - */ - int getPortCount(); - - /** * Returns a PortsMeta object, giving access to more information * about the different ports of this service. */ diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index 20089dc3980..d954c69d144 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -172,7 +172,6 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri if (complete) { // create a a completed, frozen model configModelRepo.readConfigModels(deployState, this, builder, root, configModelRegistry); addServiceClusters(deployState, builder); - this.allocatedHosts = AllocatedHosts.withHosts(hostSystem.getHostSpecs()); // must happen after the two lines above setupRouting(deployState); this.fileDistributor = root.getFileDistributionConfigProducer().getFileDistributor(); getAdmin().addPerHostServices(hostSystem.getHosts(), deployState); @@ -180,6 +179,9 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri root.prepare(configModelRepo); configModelRepo.prepareConfigModels(deployState); validateWrapExceptions(); + hostSystem.dumpPortAllocations(); + // must happen after stuff above + this.allocatedHosts = AllocatedHosts.withHosts(hostSystem.getHostSpecs()); } else { // create a model with no services instantiated and the given file distributor this.allocatedHosts = AllocatedHosts.withHosts(hostSystem.getHostSpecs()); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/Configserver.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/Configserver.java index 2a32549b6bf..a2839ec0fb6 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/admin/Configserver.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/Configserver.java @@ -51,6 +51,11 @@ public class Configserver extends AbstractService { */ public int getPortCount() { return 2; } + @Override + public String[] getPortSuffixes() { + return new String[]{ "rpc", "http" }; + } + /** * The configserver is not started by the config system! */ diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/LogForwarder.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/LogForwarder.java index 2693a4c7409..d766507c75f 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/admin/LogForwarder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/LogForwarder.java @@ -55,6 +55,11 @@ public class LogForwarder extends AbstractService implements LogforwarderConfig. */ public int getPortCount() { return 0; } + @Override + public String[] getPortSuffixes() { + return null; + } + /** * @return The command used to start LogForwarder */ diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/Logserver.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/Logserver.java index c354445b690..4dcbfb5b3c3 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/admin/Logserver.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/Logserver.java @@ -70,4 +70,9 @@ public class Logserver extends AbstractService { return 4; } + @Override + public String[] getPortSuffixes() { + return new String[]{ "unused", "logtp", "last.errors", "replicator" }; + } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/Slobrok.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/Slobrok.java index 12a0d35de5e..99738c13d4a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/admin/Slobrok.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/Slobrok.java @@ -56,6 +56,11 @@ public class Slobrok extends AbstractService implements StateserverConfig.Produc return 2; } + @Override + public String[] getPortSuffixes() { + return new String[] { "rpc", "http" }; + } + /** * @return The port on which this slobrok should respond, as a String. */ diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java index 88d66f0c608..c7cc04faf95 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java @@ -93,7 +93,6 @@ public class DomAdminV4Builder extends DomAdminBuilderBase { } private NodesSpecification createNodesSpecificationForLogserver() { - // TODO: Enable for main system as well DeployState deployState = context.getDeployState(); if (deployState.getProperties().useDedicatedNodeForLogserver() && context.getApplicationType() == ConfigModelContext.ApplicationType.DEFAULT && @@ -124,6 +123,7 @@ public class DomAdminV4Builder extends DomAdminBuilderBase { logServerCluster.addContainer(container); admin.addAndInitializeService(deployState.getDeployLogger(), hostResource, container); admin.setLogserverContainerCluster(logServerCluster); + context.getConfigModelRepoAdder().add(logserverClusterModel); } private void addLogHandler(ContainerCluster cluster) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/Container.java b/config-model/src/main/java/com/yahoo/vespa/model/container/Container.java index dc962ed5931..f61fc3d4df8 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/Container.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/Container.java @@ -174,7 +174,7 @@ public class Container extends AbstractService implements private void reserveHttpPortsPrepended() { if (getHttp().getHttpServer() != null) { for (ConnectorFactory connectorFactory : getHttp().getHttpServer().getConnectorFactories()) { - reservePortPrepended(getPort(connectorFactory)); + reservePortPrepended(getPort(connectorFactory), "http/" + connectorFactory.getName()); } } } @@ -240,6 +240,34 @@ public class Container extends AbstractService implements return httpPorts + rpcPorts; } + @Override + public String[] getPortSuffixes() { + // TODO clean up this mess + int n = getPortCount(); + String[] suffixes = new String[n]; + int off = 0; + int httpPorts = (getHttp() != null) ? 0 : numHttpServerPorts; + if (httpPorts > 0) { + suffixes[off++] = "http"; + } + for (int i = 1; i < httpPorts; i++) { + suffixes[off++] = "http/" + i; + } + int rpcPorts = (rpcServerEnabled()) ? numRpcServerPorts : 0; + if (rpcPorts > 0) { + suffixes[off++] = "messaging"; + } + if (rpcPorts > 1) { + suffixes[off++] = "rpc"; + } + while (off < n) { + suffixes[off] = "unused/" + off; + ++off; + } + assert (off == n); + return suffixes; + } + /** * @return the actual search port * TODO: Remove. Use {@link #getPortsMeta()} and check tags in conjunction with {@link #getRelativePort(int)}. diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/ContentNode.java b/config-model/src/main/java/com/yahoo/vespa/model/content/ContentNode.java index c75421c9636..dc9372c463b 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/ContentNode.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/ContentNode.java @@ -57,6 +57,10 @@ public abstract class ContentNode extends AbstractService @Override public int getPortCount() { return 3; } + @Override + public String[] getPortSuffixes() { + return new String[] { "messaging", "rpc", "http" }; + } @Override public void getConfig(StorCommunicationmanagerConfig.Builder builder) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/generic/service/Service.java b/config-model/src/main/java/com/yahoo/vespa/model/generic/service/Service.java index 9ccf5103175..0d30bade53c 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/generic/service/Service.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/generic/service/Service.java @@ -24,6 +24,11 @@ public class Service extends AbstractService { } @Override + public String[] getPortSuffixes() { + return null; + } + + @Override public String getStartupCommand() { return ((ServiceCluster) getParent()).getCommand(); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/Dispatch.java b/config-model/src/main/java/com/yahoo/vespa/model/search/Dispatch.java index b9c937b4a4c..9b4fe93d6ea 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/Dispatch.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/Dispatch.java @@ -234,4 +234,10 @@ public class Dispatch extends AbstractService implements SearchInterface, * @return the number of ports needed */ public int getPortCount() { return 3; } + + @Override + public String[] getPortSuffixes() { + return new String[]{ "rpc", "fs4", "health" }; + } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/SearchNode.java b/config-model/src/main/java/com/yahoo/vespa/model/search/SearchNode.java index 5be51310504..934184d5972 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/SearchNode.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/SearchNode.java @@ -160,6 +160,11 @@ public class SearchNode extends AbstractService implements return 5; } + @Override + public String[] getPortSuffixes() { + return new String[] { "rpc", "fs4", "future/4", "unused/3", "health" }; + } + /** * Returns the RPC port used by this searchnode. * diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/TransactionLogServer.java b/config-model/src/main/java/com/yahoo/vespa/model/search/TransactionLogServer.java index 61cac8afb91..c42579085a5 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/TransactionLogServer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/TransactionLogServer.java @@ -38,6 +38,11 @@ public class TransactionLogServer extends AbstractService { return 1; } + @Override + public String[] getPortSuffixes() { + return new String[]{"tls"}; + } + /** * Returns the port used by the TLS. * diff --git a/config-model/src/test/java/com/yahoo/vespa/model/HostResourceTest.java b/config-model/src/test/java/com/yahoo/vespa/model/HostResourceTest.java index abf4ec02a3e..d16bbe72a95 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/HostResourceTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/HostResourceTest.java @@ -37,7 +37,7 @@ public class HostResourceTest { public void next_available_baseport_is_BASE_PORT_plus_one_when_one_port_has_been_reserved() { MockRoot root = new MockRoot(); HostResource host = mockHostResource(root); - host.reservePort(new TestService(1), HostResource.BASE_PORT); + host.reservePort(new TestService(1), HostResource.BASE_PORT, "foo"); assertThat(host.nextAvailableBaseport(1), is(HostResource.BASE_PORT + 1)); } @@ -47,12 +47,12 @@ public class HostResourceTest { HostResource host = mockHostResource(root); for (int p = HostResource.BASE_PORT; p < HostResource.BASE_PORT + HostResource.MAX_PORTS; p += 2) { - host.reservePort(new TestService(1), p); + host.reservePort(new TestService(1), p, "foo"); } assertThat(host.nextAvailableBaseport(2), is(0)); try { - host.reservePort(new TestService(2), HostResource.BASE_PORT); + host.reservePort(new TestService(2), HostResource.BASE_PORT, "bar"); } catch (RuntimeException e) { assertThat(e.getMessage(), containsString("Too many ports are reserved")); } @@ -181,5 +181,14 @@ public class HostResourceTest { @Override public int getPortCount() { return portCount; } + + @Override + public String[] getPortSuffixes() { + String[] suffixes = new String[portCount]; + for (int i = 0; i < portCount; i++) { + suffixes[i] = "generic." + i; + } + return suffixes; + } } } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ConfigValueChangeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ConfigValueChangeValidatorTest.java index 2456113f40d..f13f53e8648 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ConfigValueChangeValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ConfigValueChangeValidatorTest.java @@ -246,6 +246,9 @@ public class ConfigValueChangeValidatorTest { public int getPortCount() { return 0; } + + @Override + public String[] getPortSuffixes() { return null; } } private static class SimpleConfigProducer extends AbstractConfigProducer<AbstractConfigProducer<?>> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/StartupCommandChangeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/StartupCommandChangeValidatorTest.java index 4f6a1ddf7b3..2b04b026ee7 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/StartupCommandChangeValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/StartupCommandChangeValidatorTest.java @@ -75,5 +75,8 @@ public class StartupCommandChangeValidatorTest { public int getPortCount() { return 0; } + + @Override + public String[] getPortSuffixes() { return null; } } } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/test/ApiService.java b/config-model/src/test/java/com/yahoo/vespa/model/test/ApiService.java index 28c56e1e45f..6f06eb3e482 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/test/ApiService.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/test/ApiService.java @@ -36,4 +36,6 @@ public class ApiService extends AbstractService implements com.yahoo.test.Standa public int getPortCount() { return 0; } + @Override + public String[] getPortSuffixes() { return null; } } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/test/ModelAmendingTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/test/ModelAmendingTestCase.java index 57b0606457d..82da14f0d29 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/test/ModelAmendingTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/test/ModelAmendingTestCase.java @@ -129,6 +129,8 @@ public class ModelAmendingTestCase { return 0; } + @Override + public String[] getPortSuffixes() { return null; } } public static class AdminModelAmender extends ConfigModel { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/test/ParentService.java b/config-model/src/test/java/com/yahoo/vespa/model/test/ParentService.java index 325cc78a361..c7559c68592 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/test/ParentService.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/test/ParentService.java @@ -58,4 +58,7 @@ public class ParentService extends AbstractService implements com.yahoo.test.Sta } public int getPortCount() { return 0; } + + @Override + public String[] getPortSuffixes() { return null; } } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/test/SimpleService.java b/config-model/src/test/java/com/yahoo/vespa/model/test/SimpleService.java index 0037e40a20f..a38916463c4 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/test/SimpleService.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/test/SimpleService.java @@ -38,6 +38,11 @@ public class SimpleService extends AbstractService implements com.yahoo.test.Sta public int getWantedPort(){ return 10000; } public int getPortCount() { return 5; } + @Override + public String[] getPortSuffixes() { + return new String[]{ "a", "b", "c", "d", "e" }; + } + // Make sure this service is listed in the sentinel config public String getStartupCommand() { return "sleep 0"; } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/utils/FileSenderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/utils/FileSenderTest.java index 9b5bd71274a..79646dacaa9 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/utils/FileSenderTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/utils/FileSenderTest.java @@ -172,5 +172,8 @@ public class FileSenderTest { public int getPortCount() { return 0; } + + @Override + public String[] getPortSuffixes() { return null; } } } diff --git a/config-provisioning/abi-spec.json b/config-provisioning/abi-spec.json index 04e4a7276e2..af61fb46e50 100644 --- a/config-provisioning/abi-spec.json +++ b/config-provisioning/abi-spec.json @@ -462,11 +462,13 @@ "public void <init>(java.lang.String, java.util.List, com.yahoo.config.provision.ClusterMembership)", "public void <init>(java.lang.String, java.util.List, java.util.Optional, java.util.Optional)", "public void <init>(java.lang.String, java.util.List, java.util.Optional, java.util.Optional, java.util.Optional)", + "public void <init>(java.lang.String, java.util.List, java.util.Optional, java.util.Optional, java.util.Optional, java.util.Optional)", "public java.lang.String hostname()", "public java.util.List aliases()", "public java.util.Optional flavor()", "public java.util.Optional version()", "public java.util.Optional membership()", + "public java.util.Optional networkPorts()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)", "public int hashCode()", @@ -497,6 +499,50 @@ ], "fields": [] }, + "com.yahoo.config.provision.NetworkPorts$Allocation": { + "superClass": "java.lang.Object", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>(int, java.lang.String, java.lang.String, java.lang.String)", + "public java.lang.String key()", + "public java.lang.String toString()" + ], + "fields": [ + "public final int port", + "public final java.lang.String serviceType", + "public final java.lang.String configId", + "public final java.lang.String portSuffix" + ] + }, + "com.yahoo.config.provision.NetworkPorts": { + "superClass": "java.lang.Object", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>(java.util.Collection)", + "public java.util.Collection allocations()", + "public int size()" + ], + "fields": [] + }, + "com.yahoo.config.provision.NetworkPortsSerializer": { + "superClass": "java.lang.Object", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>()", + "public static void toSlime(com.yahoo.config.provision.NetworkPorts, com.yahoo.slime.Cursor)", + "public static java.util.Optional fromSlime(com.yahoo.slime.Inspector)" + ], + "fields": [] + }, "com.yahoo.config.provision.NodeFlavors": { "superClass": "java.lang.Object", "interfaces": [], diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/AllocatedHosts.java b/config-provisioning/src/main/java/com/yahoo/config/provision/AllocatedHosts.java index 4c1798c549f..28c5d475e19 100644 --- a/config-provisioning/src/main/java/com/yahoo/config/provision/AllocatedHosts.java +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/AllocatedHosts.java @@ -35,6 +35,7 @@ public class AllocatedHosts { /** Current version */ private static final String hostSpecCurrentVespaVersion = "currentVespaVersion"; + private static final String hostSpecNetworkPorts = "ports"; private final ImmutableSet<HostSpec> hosts; @@ -60,6 +61,7 @@ public class AllocatedHosts { }); host.flavor().ifPresent(flavor -> cursor.setString(hostSpecFlavor, flavor.name())); host.version().ifPresent(version -> cursor.setString(hostSpecCurrentVespaVersion, version.toFullString())); + host.networkPorts().ifPresent(ports -> NetworkPortsSerializer.toSlime(ports, cursor.setArray(hostSpecNetworkPorts))); } /** Returns the hosts of this allocation */ @@ -84,8 +86,9 @@ public class AllocatedHosts { object.field(hostSpecFlavor).valid() ? flavorFromSlime(object, nodeFlavors) : Optional.empty(); Optional<com.yahoo.component.Version> version = optionalString(object.field(hostSpecCurrentVespaVersion)).map(com.yahoo.component.Version::new); - - return new HostSpec(object.field(hostSpecHostName).asString(), Collections.emptyList(), flavor, membership, version); + Optional<NetworkPorts> networkPorts = + NetworkPortsSerializer.fromSlime(object.field(hostSpecNetworkPorts)); + return new HostSpec(object.field(hostSpecHostName).asString(), Collections.emptyList(), flavor, membership, version, networkPorts); } private static ClusterMembership membershipFromSlime(Inspector object) { diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/HostSpec.java b/config-provisioning/src/main/java/com/yahoo/config/provision/HostSpec.java index f0e8774759d..e5d4aadb988 100644 --- a/config-provisioning/src/main/java/com/yahoo/config/provision/HostSpec.java +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/HostSpec.java @@ -29,6 +29,8 @@ public class HostSpec implements Comparable<HostSpec> { private final Optional<com.yahoo.component.Version> version; + private final Optional<NetworkPorts> networkPorts; + public HostSpec(String hostname, Optional<ClusterMembership> membership) { this(hostname, new ArrayList<>(), Optional.empty(), membership); } @@ -40,6 +42,7 @@ public class HostSpec implements Comparable<HostSpec> { public HostSpec(String hostname, List<String> aliases) { this(hostname, aliases, Optional.empty(), Optional.empty()); } + public HostSpec(String hostname, List<String> aliases, Flavor flavor) { this(hostname, aliases, Optional.of(flavor), Optional.empty()); } @@ -54,13 +57,21 @@ public class HostSpec implements Comparable<HostSpec> { public HostSpec(String hostname, List<String> aliases, Optional<Flavor> flavor, Optional<ClusterMembership> membership, Optional<com.yahoo.component.Version> version) { + this(hostname, aliases, flavor, membership, version, Optional.empty()); + } + + public HostSpec(String hostname, List<String> aliases, Optional<Flavor> flavor, + Optional<ClusterMembership> membership, Optional<com.yahoo.component.Version> version, + Optional<NetworkPorts> networkPorts) { if (hostname == null || hostname.isEmpty()) throw new IllegalArgumentException("Hostname must be specified"); Objects.requireNonNull(version, "Version cannot be null but can be empty"); + Objects.requireNonNull(networkPorts, "Network ports cannot be null but can be empty"); this.hostname = hostname; this.aliases = ImmutableList.copyOf(aliases); this.flavor = flavor; this.membership = membership; this.version = version; + this.networkPorts = networkPorts; } /** Returns the name identifying this host */ @@ -77,6 +88,9 @@ public class HostSpec implements Comparable<HostSpec> { /** Returns the membership of this host, or an empty value if not present */ public Optional<ClusterMembership> membership() { return membership; } + /** Returns the network port allocations on this host, or empty if not present */ + public Optional<NetworkPorts> networkPorts() { return networkPorts; } + @Override public String toString() { return hostname + diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/NetworkPorts.java b/config-provisioning/src/main/java/com/yahoo/config/provision/NetworkPorts.java new file mode 100644 index 00000000000..90ac3651bb2 --- /dev/null +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/NetworkPorts.java @@ -0,0 +1,55 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.config.provision; + +import java.util.Collection; +import java.util.List; + +/** + * Models an immutable list of network port allocations + * @author arnej + */ +public class NetworkPorts { + + public static class Allocation { + public final int port; + public final String serviceType; + public final String configId; + public final String portSuffix; + + public Allocation(int port, String serviceType, String configId, String portSuffix) { + this.port = port; + this.serviceType = serviceType; + this.configId = configId; + this.portSuffix = portSuffix; + } + public String key() { + StringBuilder buf = new StringBuilder(); + buf.append("t=").append(serviceType); + buf.append(" cfg=").append(configId); + buf.append(" suf=").append(portSuffix); + return buf.toString(); + } + public String toString() { + StringBuilder buf = new StringBuilder(); + buf.append("[port=").append(port); + buf.append(" serviceType=").append(serviceType); + buf.append(" configId=").append(configId); + buf.append(" suffix=").append(portSuffix); + buf.append("]"); + return buf.toString(); + } + } + + private final List<Allocation> allocations; + + public NetworkPorts(Collection<Allocation> allocations) { + this.allocations = List.copyOf(allocations); + } + + public Collection<Allocation> allocations() { + return this.allocations; + } + + public int size() { return allocations.size(); } +} diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/NetworkPortsSerializer.java b/config-provisioning/src/main/java/com/yahoo/config/provision/NetworkPortsSerializer.java new file mode 100644 index 00000000000..d3af337e9be --- /dev/null +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/NetworkPortsSerializer.java @@ -0,0 +1,56 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.config.provision; + +import com.yahoo.slime.ArrayTraverser; +import com.yahoo.slime.Cursor; +import com.yahoo.slime.Inspector; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +/** + * Serializes network port allocations to/from JSON. + * + * @author arnej + */ +public class NetworkPortsSerializer { + + // Network port fields + private static final String portNumberKey = "port"; + private static final String serviceTypeKey = "type"; + private static final String configIdKey = "cfg"; + private static final String portSuffixKey = "suf"; + + // ---------------- Serialization ---------------------------------------------------- + + public static void toSlime(NetworkPorts networkPorts, Cursor array) { + for (NetworkPorts.Allocation allocation : networkPorts.allocations()) { + Cursor object = array.addObject(); + object.setLong(portNumberKey, allocation.port); + object.setString(serviceTypeKey, allocation.serviceType); + object.setString(configIdKey, allocation.configId); + object.setString(portSuffixKey, allocation.portSuffix); + } + } + + // ---------------- Deserialization -------------------------------------------------- + + public static Optional<NetworkPorts> fromSlime(Inspector array) { + List<NetworkPorts.Allocation> list = new ArrayList<>(array.entries()); + array.traverse((ArrayTraverser) (int i, Inspector item) -> { + list.add(new NetworkPorts.Allocation((int)item.field(portNumberKey).asLong(), + item.field(serviceTypeKey).asString(), + item.field(configIdKey).asString(), + item.field(portSuffixKey).asString())); + } + ); + if (list.size() > 0) { + NetworkPorts allocator = new NetworkPorts(list); + return Optional.of(allocator); + } + return Optional.empty(); + } + +} diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/provision/ProvisionerAdapter.java b/configserver/src/main/java/com/yahoo/vespa/config/server/provision/ProvisionerAdapter.java index 32380b296dd..ee4cc4a3043 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/provision/ProvisionerAdapter.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/provision/ProvisionerAdapter.java @@ -8,6 +8,7 @@ import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.HostSpec; import com.yahoo.config.provision.ProvisionLogger; import com.yahoo.config.provision.Provisioner; +import com.yahoo.config.provision.NetworkPorts; import java.util.*; diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionTest.java index c4b3e5f24dc..10ee1a22bab 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionTest.java @@ -7,6 +7,7 @@ import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.provision.AllocatedHosts; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.HostSpec; +import com.yahoo.config.provision.NetworkPorts; import com.yahoo.config.provision.TenantName; import com.yahoo.path.Path; import com.yahoo.config.model.application.provider.*; @@ -136,7 +137,15 @@ public class LocalSessionTest { @Test public void require_that_provision_info_can_be_read() throws Exception { - AllocatedHosts input = AllocatedHosts.withHosts(Collections.singleton(new HostSpec("myhost", Collections.<String>emptyList()))); + List<NetworkPorts.Allocation> list = new ArrayList<>(); + list.add(new NetworkPorts.Allocation(8080, "container", "default/0", "http")); + list.add(new NetworkPorts.Allocation(19101, "searchnode", "other/1", "rpc")); + NetworkPorts ports = new NetworkPorts(list); + + AllocatedHosts input = AllocatedHosts.withHosts(Collections.singleton( + new HostSpec("myhost", Collections.<String>emptyList(), + Optional.empty(), Optional.empty(), Optional.empty(), + Optional.of(ports)))); LocalSession session = createSession(TenantName.defaultName(), 3, new SessionTest.MockSessionPreparer(), Optional.of(input)); ApplicationId origId = new ApplicationId.Builder() @@ -147,6 +156,9 @@ public class LocalSessionTest { assertNotNull(info); assertThat(info.getHosts().size(), is(1)); assertTrue(info.getHosts().contains(new HostSpec("myhost", Collections.emptyList()))); + Optional<NetworkPorts> portsCopy = info.getHosts().iterator().next().networkPorts(); + assertTrue(portsCopy.isPresent()); + assertThat(portsCopy.get().allocations(), is(list)); } @Test diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/JobType.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/JobType.java index cee8d3ddfd9..1dd9da6dc32 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/JobType.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/JobType.java @@ -28,7 +28,9 @@ public enum JobType { productionAwsUsEast1b ("production-aws-us-east-1b" , ZoneId.from("prod" , "aws-us-east-1b") , null ), productionCdAwsUsEast1a("production-cd-aws-us-east-1a", null , ZoneId.from("prod" , "cd-aws-us-east-1a")), productionCdUsCentral1 ("production-cd-us-central-1" , null , ZoneId.from("prod" , "cd-us-central-1") ), - productionCdUsCentral2 ("production-cd-us-central-2" , null , ZoneId.from("prod" , "cd-us-central-2") ); + // TODO: Cannot remove production-cd-us-central-2 until we know there are no serialized data in controller referencing it + productionCdUsCentral2 ("production-cd-us-central-2" , null , ZoneId.from("prod" , "cd-us-central-2") ), + productionCdUsWest1 ("production-cd-us-west-1" , null , ZoneId.from("prod" , "cd-us-west-1") ); private final String jobName; private final ImmutableMap<SystemName, ZoneId> zones; diff --git a/document/abi-spec.json b/document/abi-spec.json index 61390af3523..d4db3026b27 100644 --- a/document/abi-spec.json +++ b/document/abi-spec.json @@ -5244,7 +5244,7 @@ ], "methods": [ "public void <init>(com.yahoo.document.update.TensorModifyUpdate$Operation, com.yahoo.document.datatypes.TensorFieldValue)", - "public static com.yahoo.tensor.TensorType convertToCompatibleType(com.yahoo.tensor.TensorType)", + "public static com.yahoo.tensor.TensorType convertDimensionsToMapped(com.yahoo.tensor.TensorType)", "public com.yahoo.document.update.TensorModifyUpdate$Operation getOperation()", "public com.yahoo.document.datatypes.TensorFieldValue getValue()", "public void setValue(com.yahoo.document.datatypes.TensorFieldValue)", @@ -5278,6 +5278,7 @@ "public boolean equals(java.lang.Object)", "public int hashCode()", "public java.lang.String toString()", + "public static com.yahoo.tensor.TensorType extractSparseDimensions(com.yahoo.tensor.TensorType)", "public bridge synthetic void setValue(com.yahoo.document.datatypes.FieldValue)", "public bridge synthetic com.yahoo.document.datatypes.FieldValue getValue()" ], diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java index ffbfe49347c..6310fa62d15 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java @@ -23,22 +23,23 @@ public class TensorAddUpdateReader { public static TensorAddUpdate createTensorAddUpdate(TokenBuffer buffer, Field field) { expectObjectStart(buffer.currentToken()); - expectTensorTypeIsSparse(field); + expectTensorTypeHasSparseDimensions(field); TensorDataType tensorDataType = (TensorDataType)field.getDataType(); TensorType tensorType = tensorDataType.getTensorType(); TensorFieldValue tensorFieldValue = new TensorFieldValue(tensorType); fillTensor(buffer, tensorFieldValue); + expectTensorIsNonEmpty(field, tensorFieldValue.getTensor().get()); return new TensorAddUpdate(tensorFieldValue); } - private static void expectTensorTypeIsSparse(Field field) { + private static void expectTensorTypeHasSparseDimensions(Field field) { TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType(); - if (tensorType.dimensions().stream() - .anyMatch(dim -> dim.isIndexed())) { - throw new IllegalArgumentException("An add update can only be applied to sparse tensors. " - + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'"); + if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) { + throw new IllegalArgumentException("An add update can only be applied to tensors " + + "with at least one sparse dimension. Field '" + field.getName() + + "' has unsupported tensor type '" + tensorType + "'"); } } @@ -48,5 +49,4 @@ public class TensorAddUpdateReader { } } - } diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java index a9bbba519bd..66588debbca 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java @@ -29,10 +29,8 @@ public class TensorModifyUpdateReader { private static final String MODIFY_MULTIPLY = "multiply"; public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) { - expectFieldIsOfTypeTensor(field); expectTensorTypeHasNoneIndexedUnboundDimensions(field); - expectTensorTypeIsNotMixed(field); expectObjectStart(buffer.currentToken()); ModifyUpdateResult result = createModifyUpdateResult(buffer, field); @@ -58,16 +56,6 @@ public class TensorModifyUpdateReader { } } - private static void expectTensorTypeIsNotMixed(Field field) { - TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType(); - long numMappedDimensions = tensorType.dimensions().stream().filter(dim -> dim.type().equals(TensorType.Dimension.Type.mapped)).count(); - long numIndexedDimensions = tensorType.dimensions().stream().filter(dim -> dim.isIndexed()).count(); - if (numMappedDimensions > 0 && numIndexedDimensions > 0) { - throw new IllegalArgumentException("A modify update cannot be applied to tensor types with mixed dimensions. " - + "Field '" + field.getName() + "' has mixed tensor type '" + tensorType + "'"); - } - } - private static void expectOperationSpecified(TensorModifyUpdate.Operation operation, String fieldName) { if (operation == null) { throw new IllegalArgumentException("Modify update for field '" + fieldName + "' does not contain an operation"); @@ -121,7 +109,7 @@ public class TensorModifyUpdateReader { private static TensorFieldValue createTensor(TokenBuffer buffer, Field field) { TensorDataType tensorDataType = (TensorDataType)field.getDataType(); TensorType originalType = tensorDataType.getTensorType(); - TensorType convertedType = TensorModifyUpdate.convertToCompatibleType(originalType); + TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(originalType); Tensor.Builder tensorBuilder = Tensor.Builder.of(convertedType); readTensorCells(buffer, tensorBuilder); @@ -129,25 +117,26 @@ public class TensorModifyUpdateReader { validateBounds(tensor, originalType); - TensorFieldValue result = new TensorFieldValue(convertedType); - result.assign(tensor); - return result; + return new TensorFieldValue(tensor); } - /** Only validate if original type is indexed bound */ - private static void validateBounds(Tensor convertedTensor, TensorType originalType) { - if ( ! originalType.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) { + /** Only validate if original type has indexed bound dimensions */ + static void validateBounds(Tensor convertedTensor, TensorType originalType) { + if (originalType.dimensions().stream().noneMatch(d -> d instanceof TensorType.IndexedBoundDimension)) { return; } for (Iterator<Tensor.Cell> iter = convertedTensor.cellIterator(); iter.hasNext(); ) { Tensor.Cell cell = iter.next(); TensorAddress address = cell.getKey(); for (int i = 0; i < address.size(); ++i) { - long label = address.numericLabel(i); - long bound = originalType.dimensions().get(i).size().get(); // size is non-optional for indexed bound - if (label >= bound) { - throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() + - "' has label '" + label + "' but type is " + originalType.toString()); + TensorType.Dimension dim = originalType.dimensions().get(i); + if (dim instanceof TensorType.IndexedBoundDimension) { + long label = address.numericLabel(i); + long bound = dim.size().get(); // size is non-optional for indexed bound + if (label >= bound) { + throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() + + "' has label '" + label + "' but type is " + originalType.toString()); + } } } } diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java index 210a6a80ee5..3bb4b2e262f 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java @@ -24,23 +24,23 @@ public class TensorRemoveUpdateReader { static TensorRemoveUpdate createTensorRemoveUpdate(TokenBuffer buffer, Field field) { expectObjectStart(buffer.currentToken()); - expectTensorTypeIsSparse(field); + expectTensorTypeHasSparseDimensions(field); TensorDataType tensorDataType = (TensorDataType)field.getDataType(); - TensorType tensorType = tensorDataType.getTensorType(); + TensorType originalType = tensorDataType.getTensorType(); + TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(originalType); + Tensor tensor = readRemoveUpdateTensor(buffer, convertedType, originalType); - // TODO: for mixed case extract a new tensor type based only on mapped dimensions - - Tensor tensor = readRemoveUpdateTensor(buffer, tensorType); expectAddressesAreNonEmpty(field, tensor); return new TensorRemoveUpdate(new TensorFieldValue(tensor)); } - private static void expectTensorTypeIsSparse(Field field) { + private static void expectTensorTypeHasSparseDimensions(Field field) { TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType(); - if (tensorType.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed)) { - throw new IllegalArgumentException("A remove update can only be applied to sparse tensors. " - + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'"); + if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) { + throw new IllegalArgumentException("A remove update can only be applied to tensors " + + "with at least one sparse dimension. Field '" + field.getName() + + "' has unsupported tensor type '" + tensorType + "'"); } } @@ -53,7 +53,7 @@ public class TensorRemoveUpdateReader { /** * Reads all addresses in buffer and returns a tensor where addresses have cell value 1.0 */ - private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type) { + private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type, TensorType originalType) { Tensor.Builder builder = Tensor.Builder.of(type); expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); @@ -62,7 +62,7 @@ public class TensorRemoveUpdateReader { expectArrayStart(buffer.currentToken()); int nesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= nesting; buffer.next()) { - builder.cell(readTensorAddress(buffer, type), 1.0); + builder.cell(readTensorAddress(buffer, type, originalType), 1.0); } expectCompositeEnd(buffer.currentToken()); } @@ -71,12 +71,15 @@ public class TensorRemoveUpdateReader { return builder.build(); } - private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type) { + private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type, TensorType originalType) { TensorAddress.Builder builder = new TensorAddress.Builder(type); expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { String dimension = buffer.currentName(); + if ( ! type.dimension(dimension).isPresent() && originalType.dimension(dimension).isPresent()) { + throw new IllegalArgumentException("Indexed dimension address '" + dimension + "' should not be specified in remove update"); + } String label = buffer.currentText(); builder.add(dimension, label); } diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java index 2f22def9aa1..a763db33e7a 100644 --- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java +++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java @@ -5,6 +5,7 @@ import com.yahoo.document.DataType; import com.yahoo.document.DocumentTypeManager; import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.TensorFieldValue; +import com.yahoo.document.json.readers.TensorRemoveUpdateReader; import com.yahoo.document.update.TensorAddUpdate; import com.yahoo.document.update.TensorModifyUpdate; import com.yahoo.document.update.TensorRemoveUpdate; @@ -35,7 +36,10 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 { throw new DeserializationException("Expected tensor data type, got " + type); } TensorDataType tensorDataType = (TensorDataType)type; - TensorFieldValue tensor = new TensorFieldValue(TensorModifyUpdate.convertToCompatibleType(tensorDataType.getTensorType())); + TensorType tensorType = tensorDataType.getTensorType(); + TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(tensorType); + + TensorFieldValue tensor = new TensorFieldValue(convertedType); tensor.deserialize(this); return new TensorModifyUpdate(operation, tensor); } @@ -46,7 +50,8 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 { throw new DeserializationException("Expected tensor data type, got " + type); } TensorDataType tensorDataType = (TensorDataType)type; - TensorFieldValue tensor = new TensorFieldValue(tensorDataType.getTensorType()); + TensorType tensorType = tensorDataType.getTensorType(); + TensorFieldValue tensor = new TensorFieldValue(tensorType); tensor.deserialize(this); return new TensorAddUpdate(tensor); } @@ -58,10 +63,9 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 { } TensorDataType tensorDataType = (TensorDataType)type; TensorType tensorType = tensorDataType.getTensorType(); + TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(tensorType); - // TODO: for mixed case extract a new tensor type based only on mapped dimensions - - TensorFieldValue tensor = new TensorFieldValue(tensorType); + TensorFieldValue tensor = new TensorFieldValue(convertedType); tensor.deserialize(this); return new TensorRemoveUpdate(tensor); } diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java index cfc3ee0c742..f8d2464deb7 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java @@ -7,15 +7,11 @@ import com.yahoo.document.datatypes.FieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.serialization.DocumentUpdateWriter; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import java.util.Map; import java.util.Objects; /** - * An update used to add cells to a sparse tensor (has only mapped dimensions). - * - * The cells to add are contained in a sparse tensor as well. + * An update used to add cells to a sparse or mixed tensor (has at least one mapped dimension). */ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> { @@ -50,22 +46,10 @@ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> { return oldValue; } - Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get(); - Map<TensorAddress, Double> oldCells = oldTensor.cells(); - Map<TensorAddress, Double> addCells = tensor.getTensor().get().cells(); - - // currently, underlying implementation disallows multiple entries with the same key - - Tensor.Builder builder = Tensor.Builder.of(oldTensor.type()); - for (Map.Entry<TensorAddress, Double> oldCell : oldCells.entrySet()) { - builder.cell(oldCell.getKey(), addCells.getOrDefault(oldCell.getKey(), oldCell.getValue())); - } - for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { - if ( ! oldCells.containsKey(addCell.getKey())) { - builder.cell(addCell.getKey(), addCell.getValue()); - } - } - return new TensorFieldValue(builder.build()); + Tensor old = ((TensorFieldValue) oldValue).getTensor().get(); + Tensor update = tensor.getTensor().get(); + Tensor result = old.merge((left, right) -> right, update.cells()); // note this might be slow for large mixed tensor updates + return new TensorFieldValue(result); } @Override diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java index 6111b51ca4e..2773f9d31da 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java @@ -37,7 +37,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { /** * Converts the given tensor type to a type that is compatible for being used in this update (has only mapped dimensions). */ - public static TensorType convertToCompatibleType(TensorType type) { + public static TensorType convertDimensionsToMapped(TensorType type) { TensorType.Builder builder = new TensorType.Builder(); type.dimensions().stream().forEach(dim -> builder.mapped(dim.name())); return builder.build(); diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java index e9fb1e3efd5..335cda8e133 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java @@ -7,10 +7,8 @@ import com.yahoo.document.datatypes.FieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.serialization.DocumentUpdateWriter; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; -import java.util.Iterator; -import java.util.Map; import java.util.Objects; /** @@ -25,6 +23,18 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { public TensorRemoveUpdate(TensorFieldValue value) { super(ValueUpdateClassID.TENSORREMOVE); this.tensor = value; + verifyCompatibleType(); + } + + private void verifyCompatibleType() { + if ( ! tensor.getTensor().isPresent()) { + throw new IllegalArgumentException("Tensor must be present in remove update"); + } + TensorType tensorType = tensor.getTensor().get().type(); + TensorType expectedType = extractSparseDimensions(tensor.getDataType().getTensorType()); + if ( ! tensorType.equals(expectedType)) { + throw new IllegalArgumentException("Unexpected type '" + tensorType + "' in remove update. Expected is '" + expectedType + "'"); + } } @Override @@ -51,17 +61,10 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { return oldValue; } - Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get(); - Map<TensorAddress, Double> cellsToRemove = tensor.getTensor().get().cells(); - Tensor.Builder builder = Tensor.Builder.of(oldTensor.type()); - for (Iterator<Tensor.Cell> i = oldTensor.cellIterator(); i.hasNext(); ) { - Tensor.Cell cell = i.next(); - TensorAddress address = cell.getKey(); - if ( ! cellsToRemove.containsKey(address)) { - builder.cell(address, cell.getValue()); - } - } - return new TensorFieldValue(builder.build()); + Tensor old = ((TensorFieldValue) oldValue).getTensor().get(); + Tensor update = tensor.getTensor().get(); + Tensor result = old.remove(update.cells().keySet()); + return new TensorFieldValue(result); } @Override @@ -93,4 +96,11 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { return super.toString() + " " + tensor; } + public static TensorType extractSparseDimensions(TensorType type) { + TensorType.Builder builder = new TensorType.Builder(); + type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name())); + return builder.build(); + } + + } diff --git a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java index e2736dabd2b..454ad72f344 100644 --- a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java +++ b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java @@ -40,6 +40,7 @@ public class DocumentUpdateJsonSerializerTest { final static TensorType sparseTensorType = new TensorType.Builder().mapped("x").mapped("y").build(); final static TensorType denseTensorType = new TensorType.Builder().indexed("x", 2).indexed("y", 3).build(); + final static TensorType mixedTensorType = new TensorType.Builder().mapped("x").indexed("y", 3).build(); final static DocumentTypeManager types = new DocumentTypeManager(); final static JsonFactory parserFactory = new JsonFactory(); final static DocumentType docType = new DocumentType("doctype"); @@ -60,6 +61,7 @@ public class DocumentUpdateJsonSerializerTest { docType.addField(new Field("byte_field", DataType.BYTE)); docType.addField(new Field("sparse_tensor", new TensorDataType(sparseTensorType))); docType.addField(new Field("dense_tensor", new TensorDataType(denseTensorType))); + docType.addField(new Field("mixed_tensor", new TensorDataType(mixedTensorType))); docType.addField(new Field("reference_field", new ReferenceDataType(refTargetDocType, 777))); docType.addField(new Field("predicate_field", DataType.PREDICATE)); docType.addField(new Field("raw_field", DataType.RAW)); @@ -336,6 +338,26 @@ public class DocumentUpdateJsonSerializerTest { } @Test + public void test_tensor_modify_update_on_mixed_tensor() { + roundtripSerializeJsonAndMatch(inputJson( + "{", + " 'update': 'DOCUMENT_ID',", + " 'fields': {", + " 'mixed_tensor': {", + " 'modify': {", + " 'operation': 'multiply',", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 },", + " { 'address': { 'x': 'c', 'y': '1' }, 'value': 3.0 }", + " ]", + " }", + " }", + " }", + "}" + )); + } + + @Test public void test_tensor_add_update() { roundtripSerializeJsonAndMatch(inputJson( "{", @@ -355,6 +377,29 @@ public class DocumentUpdateJsonSerializerTest { } @Test + public void test_tensor_add_update_mixed() { + roundtripSerializeJsonAndMatch(inputJson( + "{", + " 'update': 'DOCUMENT_ID',", + " 'fields': {", + " 'mixed_tensor': {", + " 'add': {", + " 'cells': [", + " { 'address': { 'x': '1', 'y': '0' }, 'value': 2.0 },", + " { 'address': { 'x': '1', 'y': '1' }, 'value': 0.0 },", + " { 'address': { 'x': '1', 'y': '2' }, 'value': 0.0 },", + " { 'address': { 'x': '0', 'y': '0' }, 'value': 0.0 },", + " { 'address': { 'x': '0', 'y': '1' }, 'value': 0.0 },", + " { 'address': { 'x': '0', 'y': '2' }, 'value': 3.0 }", + " ]", + " }", + " }", + " }", + "}" + )); + } + + @Test public void test_tensor_remove_update() { roundtripSerializeJsonAndMatch(inputJson( "{", @@ -374,6 +419,24 @@ public class DocumentUpdateJsonSerializerTest { } @Test + public void test_tensor_remove_update_mixed() { + roundtripSerializeJsonAndMatch(inputJson( + "{", + " 'update': 'DOCUMENT_ID',", + " 'fields': {", + " 'mixed_tensor': {", + " 'remove': {", + " 'addresses': [", + " {'x':'0' }", + " ]", + " }", + " }", + " }", + "}" + )); + } + + @Test public void reference_field_id_can_be_update_assigned_non_empty_id() { roundtripSerializeJsonAndMatch(inputJson( "{", diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index e58b26d500d..15d1e859f73 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -1387,12 +1387,30 @@ public class JsonReaderTestCase { } @Test - public void tensor_modify_update_on_mixed_tensor_throws() { - exception.expect(IllegalArgumentException.class); - exception.expectMessage("A modify update cannot be applied to tensor types with mixed dimensions. Field 'mixed_tensor' has mixed tensor type 'tensor(x{},y[3])'"); - createTensorModifyUpdate(inputJson("{", - " 'operation': 'replace',", - " 'cells': [] }"), "mixed_tensor"); + public void tensor_modify_update_with_replace_operation_mixed() { + assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor", + inputJson("{", + " 'operation': 'replace',", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}")); + } + + @Test + public void tensor_modify_update_with_add_operation_mixed() { + assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.ADD, "mixed_tensor", + inputJson("{", + " 'operation': 'add',", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}")); + } + + @Test + public void tensor_modify_update_with_multiply_operation_mixed() { + assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.MULTIPLY, "mixed_tensor", + inputJson("{", + " 'operation': 'multiply',", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}")); } @Test @@ -1406,6 +1424,17 @@ public class JsonReaderTestCase { } @Test + public void tensor_modify_update_with_out_of_bound_cells_throws_mixed() { + exception.expect(IndexOutOfBoundsException.class); + exception.expectMessage("Dimension 'y' has label '3' but type is tensor(x{},y[3])"); + createTensorModifyUpdate(inputJson("{", + " 'operation': 'replace',", + " 'cells': [", + " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "mixed_tensor"); + } + + + @Test public void tensor_modify_update_with_unknown_operation_throws() { exception.expect(IllegalArgumentException.class); exception.expectMessage("Unknown operation 'unknown' in modify update for field 'sparse_tensor'"); @@ -1449,11 +1478,29 @@ public class JsonReaderTestCase { } @Test - public void tensor_add_update_on_non_sparse_tensor_throws() { + public void tensor_add_update_on_mixed_tensor() { + assertTensorAddUpdate("{{x:a,y:0}:2.0, {x:a,y:1}:3.0, {x:a,y:2}:0.0}", "mixed_tensor", + inputJson("{", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 },", + " { 'address': { 'x': 'a', 'y': '1' }, 'value': 3.0 } ]}")); + } + + @Test + public void tensor_add_update_on_mixed_with_out_of_bound_dense_cells_throws() { + exception.expect(IndexOutOfBoundsException.class); + exception.expectMessage("Index 3 out of bounds for length 3"); + createTensorAddUpdate(inputJson("{", + " 'cells': [", + " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "mixed_tensor"); + } + + @Test + public void tensor_add_update_on_dense_tensor_throws() { exception.expect(IllegalArgumentException.class); - exception.expectMessage("An add update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'"); + exception.expectMessage("An add update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'"); createTensorAddUpdate(inputJson("{", - " 'cells': [] }"), "mixed_tensor"); + " 'cells': [] }"), "dense_tensor"); } @Test @@ -1470,6 +1517,7 @@ public class JsonReaderTestCase { exception.expect(IllegalArgumentException.class); exception.expectMessage("Add update for field 'sparse_tensor' does not contain tensor cells"); createTensorAddUpdate(inputJson("{}"), "sparse_tensor"); + createTensorAddUpdate(inputJson("{}"), "mixed_tensor"); } @Test @@ -1482,11 +1530,30 @@ public class JsonReaderTestCase { } @Test - public void tensor_remove_update_on_non_sparse_tensor_throws() { + public void tensor_remove_update_on_mixed_tensor() { + assertTensorRemoveUpdate("{{x:1}:1.0,{x:2}:1.0}", "mixed_tensor", + inputJson("{", + " 'addresses': [", + " { 'x': '1' },", + " { 'x': '2' } ]}")); + } + + @Test + public void tensor_remove_update_on_mixed_tensor_with_dense_addresses_throws() { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Indexed dimension address 'y' should not be specified in remove update"); + createTensorRemoveUpdate(inputJson("{", + " 'addresses': [", + " { 'x': '1', 'y': '0' },", + " { 'x': '2', 'y': '0' } ]}"), "mixed_tensor"); + } + + @Test + public void tensor_remove_update_on_dense_tensor_throws() { exception.expect(IllegalArgumentException.class); - exception.expectMessage("A remove update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'"); + exception.expectMessage("A remove update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'"); createTensorRemoveUpdate(inputJson("{", - " 'addresses': [] }"), "mixed_tensor"); + " 'addresses': [] }"), "dense_tensor"); } @Test @@ -1503,6 +1570,7 @@ public class JsonReaderTestCase { exception.expect(IllegalArgumentException.class); exception.expectMessage("Remove update for field 'sparse_tensor' does not contain tensor addresses"); createTensorRemoveUpdate(inputJson("{'addresses': [] }"), "sparse_tensor"); + createTensorRemoveUpdate(inputJson("{'addresses': [] }"), "mixed_tensor"); } @Test diff --git a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java index eb4001e6415..6935c54ba2a 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java @@ -12,18 +12,14 @@ public class TensorAddUpdateTest { @Test public void apply_add_update_operations() { assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}"); - assertApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}"); } private void assertApplyTo(String init, String update, String expected) { String spec = "tensor(x{},y{})"; TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init)); TensorAddUpdate addUpdate = new TensorAddUpdate(new TensorFieldValue(Tensor.from(spec, update))); - TensorFieldValue updatedFieldValue = (TensorFieldValue) addUpdate.applyTo(initialFieldValue); - assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get()); + Tensor updated = ((TensorFieldValue) addUpdate.applyTo(initialFieldValue)).getTensor().get(); + assertEquals(Tensor.from(spec, expected), updated); } } diff --git a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java index 6e9444de2be..b885e6ddca0 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java @@ -1,12 +1,6 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.document.update; -import com.yahoo.document.Document; -import com.yahoo.document.DocumentId; -import com.yahoo.document.DocumentType; -import com.yahoo.document.DocumentTypeManager; -import com.yahoo.document.Field; -import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.update.TensorModifyUpdate.Operation; import com.yahoo.tensor.Tensor; @@ -28,10 +22,11 @@ public class TensorModifyUpdateTest { assertConvertToCompatible("tensor(x{})", "tensor(x[10])"); assertConvertToCompatible("tensor(x{})", "tensor(x{})"); assertConvertToCompatible("tensor(x{},y{},z{})", "tensor(x[],y[10],z{})"); + assertConvertToCompatible("tensor(x{},y{})", "tensor(x{},y[3])"); } private static void assertConvertToCompatible(String expectedType, String inputType) { - assertEquals(expectedType, TensorModifyUpdate.convertToCompatibleType(TensorType.fromSpec(inputType)).toString()); + assertEquals(expectedType, TensorModifyUpdate.convertDimensionsToMapped(TensorType.fromSpec(inputType)).toString()); } @Test @@ -46,15 +41,9 @@ public class TensorModifyUpdateTest { public void apply_modify_update_operations() { assertApplyTo("tensor(x{},y{})", Operation.REPLACE, "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}"); - assertApplyTo("tensor(x{},y{})", Operation.ADD, - "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}"); - assertApplyTo("tensor(x{},y{})", Operation.MULTIPLY, - "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}"); - assertApplyTo("tensor(x[1],y[2])", Operation.REPLACE, - "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}"); assertApplyTo("tensor(x[1],y[2])", Operation.ADD, "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}"); - assertApplyTo("tensor(x[1],y[2])", Operation.MULTIPLY, + assertApplyTo("tensor(x{},y[2])", Operation.MULTIPLY, "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}"); } diff --git a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java index 40ab00facdb..3a005e858c8 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java @@ -12,9 +12,6 @@ public class TensorRemoveUpdateTest { @Test public void apply_remove_update_operations() { assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0,y:1}:1}", "{{x:0,y:0}:2}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:0}:1,{x:0,y:1}:1}", "{}"); - assertApplyTo("{}", "{{x:0,y:0}:1}", "{}"); - assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}"); } private void assertApplyTo(String init, String update, String expected) { diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index c74c211756f..017d83893f0 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -922,6 +922,17 @@ TEST(DocumentUpdateTest, tensor_add_update_can_be_applied) .add({{"x", "c"}}, 7)); } +TEST(DocumentUpdateTest, tensor_remove_update_can_be_applied) +{ + TensorUpdateFixture f; + f.assertApplyUpdate(f.spec().add({{"x", "a"}}, 2) + .add({{"x", "b"}}, 3), + + TensorRemoveUpdate(f.makeTensor(f.spec().add({{"x", "b"}}, 1))), + + f.spec().add({{"x", "a"}}, 2)); +} + TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied) { TensorUpdateFixture f; diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp index 3e2bb86c66b..671bf260629 100644 --- a/document/src/vespa/document/update/tensor_remove_update.cpp +++ b/document/src/vespa/document/update/tensor_remove_update.cpp @@ -6,6 +6,8 @@ #include <vespa/document/fieldvalue/document.h> #include <vespa/document/fieldvalue/tensorfieldvalue.h> #include <vespa/document/serialization/vespadocumentdeserializer.h> +#include <vespa/eval/tensor/cell_values.h> +#include <vespa/eval/tensor/sparse/sparse_tensor.h> #include <vespa/eval/tensor/tensor.h> #include <vespa/vespalib/objects/nbostream.h> #include <vespa/vespalib/util/xmlstream.h> @@ -77,17 +79,35 @@ TensorRemoveUpdate::checkCompatibility(const Field &field) const std::unique_ptr<Tensor> TensorRemoveUpdate::applyTo(const Tensor &tensor) const { - // TODO: implement - (void) tensor; + auto &addressTensor = _tensor->getAsTensorPtr(); + if (addressTensor) { + if (const auto *sparseTensor = dynamic_cast<const vespalib::tensor::SparseTensor *>(addressTensor.get())) { + vespalib::tensor::CellValues cellAddresses(*sparseTensor); + return tensor.remove(cellAddresses); + } else { + throw IllegalArgumentException(make_string("Expected address tensor to be sparse, but has type '%s'", + addressTensor->type().to_spec().c_str())); + } + } return std::unique_ptr<Tensor>(); } bool TensorRemoveUpdate::applyTo(FieldValue &value) const { - // TODO: implement - (void) value; - return false; + if (value.inherits(TensorFieldValue::classId)) { + TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value); + auto &oldTensor = tensorFieldValue.getAsTensorPtr(); + auto newTensor = applyTo(*oldTensor); + if (newTensor) { + tensorFieldValue = std::move(newTensor); + } + } else { + std::string err = make_string("Unable to perform a tensor remove update on a '%s' field value.", + value.getClass().name()); + throw IllegalStateException(err, VESPA_STRLOC); + } + return true; } void diff --git a/document/src/vespa/document/update/valueupdate.h b/document/src/vespa/document/update/valueupdate.h index 0e15943f8e4..6939d10ce2c 100644 --- a/document/src/vespa/document/update/valueupdate.h +++ b/document/src/vespa/document/update/valueupdate.h @@ -55,7 +55,8 @@ public: Map = IDENTIFIABLE_CLASSID(MapValueUpdate), Remove = IDENTIFIABLE_CLASSID(RemoveValueUpdate), TensorModifyUpdate = IDENTIFIABLE_CLASSID(TensorModifyUpdate), - TensorAddUpdate = IDENTIFIABLE_CLASSID(TensorAddUpdate) + TensorAddUpdate = IDENTIFIABLE_CLASSID(TensorAddUpdate), + TensorRemoveUpdate = IDENTIFIABLE_CLASSID(TensorRemoveUpdate) }; ValueUpdate() diff --git a/documentapi/CMakeLists.txt b/documentapi/CMakeLists.txt index b03dd66c817..86d29732399 100644 --- a/documentapi/CMakeLists.txt +++ b/documentapi/CMakeLists.txt @@ -14,7 +14,6 @@ vespa_define_module( vdslib LIBS - src/vespa/binref src/vespa/documentapi src/vespa/documentapi/loadtypes src/vespa/documentapi/messagebus diff --git a/documentapi/src/vespa/binref/.gitignore b/documentapi/src/vespa/binref/.gitignore deleted file mode 100644 index cfb0e619824..00000000000 --- a/documentapi/src/vespa/binref/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -.depend -Makefile -testrun.sh diff --git a/documentapi/src/vespa/binref/CMakeLists.txt b/documentapi/src/vespa/binref/CMakeLists.txt deleted file mode 100644 index adece6dd711..00000000000 --- a/documentapi/src/vespa/binref/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. diff --git a/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java b/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java index 41302a4c725..84fbb7d4f01 100644 --- a/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java +++ b/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java @@ -22,7 +22,6 @@ public class TlsCryptoEngine implements CryptoEngine { @Override public TlsCryptoSocket createCryptoSocket(SocketChannel channel, boolean isServer) { SSLEngine sslEngine = tlsContext.createSslEngine(); - sslEngine.setNeedClientAuth(true); sslEngine.setUseClientMode(!isServer); return new TlsCryptoSocket(channel, sslEngine); } diff --git a/jrt_test/src/binref/testrun.sh b/jrt_test/src/binref/testrun.sh deleted file mode 120000 index 56c3c1186d8..00000000000 --- a/jrt_test/src/binref/testrun.sh +++ /dev/null @@ -1 +0,0 @@ -../../../vespalib/src/vespa/vespalib/testkit/testrun.sh
\ No newline at end of file diff --git a/lowercasing_test/src/binref/testrun.sh b/lowercasing_test/src/binref/testrun.sh deleted file mode 120000 index 56c3c1186d8..00000000000 --- a/lowercasing_test/src/binref/testrun.sh +++ /dev/null @@ -1 +0,0 @@ -../../../vespalib/src/vespa/vespalib/testkit/testrun.sh
\ No newline at end of file diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java index 668795f362b..4fef3d8ebf7 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java @@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableSet; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ClusterMembership; import com.yahoo.config.provision.Flavor; +import com.yahoo.config.provision.NetworkPorts; import com.yahoo.config.provision.NodeType; import com.yahoo.vespa.hosted.provision.node.Agent; import com.yahoo.vespa.hosted.provision.node.Allocation; diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintenance.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintenance.java index 2496d9ba8c9..7fdd9a168c8 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintenance.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintenance.java @@ -182,7 +182,8 @@ public class NodeRepositoryMaintenance extends AbstractComponent { DefaultTimes(Zone zone) { failGrace = Duration.ofMinutes(60); periodicRedeployInterval = Duration.ofMinutes(30); - redeployMaintainerInterval = Duration.ofMinutes(1); + // Don't redeploy in test environments + redeployMaintainerInterval = zone.environment().isTest() ? Duration.ofDays(1) : Duration.ofMinutes(1); operatorChangeRedeployInterval = Duration.ofMinutes(1); failedExpirerInterval = Duration.ofMinutes(10); provisionedExpiry = Duration.ofHours(4); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Allocation.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Allocation.java index 8a331209efc..53e1ae3721e 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Allocation.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Allocation.java @@ -3,6 +3,9 @@ package com.yahoo.vespa.hosted.provision.node; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ClusterMembership; +import com.yahoo.config.provision.NetworkPorts; + +import java.util.Optional; /** * The allocation of a node @@ -24,12 +27,21 @@ public class Allocation { /** This node can (and should) be removed from the cluster on the next deployment */ private final boolean removable; + private final Optional<NetworkPorts> networkPorts; + + public Allocation(ApplicationId owner, ClusterMembership clusterMembership, Generation restartGeneration, boolean removable) { + this(owner, clusterMembership, restartGeneration, removable, Optional.empty()); + } + + public Allocation(ApplicationId owner, ClusterMembership clusterMembership, + Generation restartGeneration, boolean removable, Optional<NetworkPorts> networkPorts) { this.owner = owner; this.clusterMembership = clusterMembership; this.restartGeneration = restartGeneration; this.removable = removable; + this.networkPorts = networkPorts; } /** Returns the id of the application this is allocated to */ @@ -41,14 +53,17 @@ public class Allocation { /** Returns the restart generation (wanted and current) of this */ public Generation restartGeneration() { return restartGeneration; } + /** Returns network ports allocations (or empty if not recorded) */ + public Optional<NetworkPorts> networkPorts() { return networkPorts; } + /** Returns a copy of this which is retired */ public Allocation retire() { - return new Allocation(owner, clusterMembership.retire(), restartGeneration, removable); + return new Allocation(owner, clusterMembership.retire(), restartGeneration, removable, networkPorts); } /** Returns a copy of this which is not retired */ public Allocation unretire() { - return new Allocation(owner, clusterMembership.unretire(), restartGeneration, removable); + return new Allocation(owner, clusterMembership.unretire(), restartGeneration, removable, networkPorts); } /** Return whether this node is ready to be removed from the application */ @@ -56,16 +71,20 @@ public class Allocation { /** Returns a copy of this with the current restart generation set to generation */ public Allocation withRestart(Generation generation) { - return new Allocation(owner, clusterMembership, generation, removable); + return new Allocation(owner, clusterMembership, generation, removable, networkPorts); } /** Returns a copy of this allocation where removable is set to true */ public Allocation removable() { - return new Allocation(owner, clusterMembership, restartGeneration, true); + return new Allocation(owner, clusterMembership, restartGeneration, true, networkPorts); } public Allocation with(ClusterMembership newMembership) { - return new Allocation(owner, newMembership, restartGeneration, removable); + return new Allocation(owner, newMembership, restartGeneration, removable, networkPorts); + } + + public Allocation withNetworkPorts(NetworkPorts ports) { + return new Allocation(owner, clusterMembership, restartGeneration, removable, Optional.of(ports)); } @Override diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/NodeSerializer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/NodeSerializer.java index 54668c4eda1..bb4dab3b97b 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/NodeSerializer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/NodeSerializer.java @@ -10,6 +10,8 @@ import com.yahoo.config.provision.Flavor; import com.yahoo.config.provision.InstanceName; import com.yahoo.config.provision.NodeFlavors; import com.yahoo.config.provision.NodeType; +import com.yahoo.config.provision.NetworkPorts; +import com.yahoo.config.provision.NetworkPortsSerializer; import com.yahoo.config.provision.TenantName; import com.yahoo.slime.ArrayTraverser; import com.yahoo.slime.Cursor; @@ -85,6 +87,9 @@ public class NodeSerializer { private static final String atKey = "at"; private static final String agentKey = "agent"; // retired events only + // Network port fields + private static final String networkPortsKey = "networkPorts"; + // ---------------- Serialization ---------------------------------------------------- public NodeSerializer(NodeFlavors flavors) { @@ -136,6 +141,7 @@ public class NodeSerializer { object.setLong(currentRestartGenerationKey, allocation.restartGeneration().current()); object.setBool(removableKey, allocation.isRemovable()); object.setString(wantedVespaVersionKey, allocation.membership().cluster().vespaVersion().toString()); + allocation.networkPorts().ifPresent(ports -> NetworkPortsSerializer.toSlime(ports, object.setArray(networkPortsKey))); } private void toSlime(History history, Cursor array) { @@ -197,7 +203,8 @@ public class NodeSerializer { return Optional.of(new Allocation(applicationIdFromSlime(object), clusterMembershipFromSlime(object), generationFromSlime(object, restartGenerationKey, currentRestartGenerationKey), - object.field(removableKey).asBool())); + object.field(removableKey).asBool(), + NetworkPortsSerializer.fromSlime(object.field(networkPortsKey)))); } private ApplicationId applicationIdFromSlime(Inspector object) { diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java index f48f0c1bdce..4626a600d2c 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java @@ -9,6 +9,7 @@ import com.yahoo.transaction.NestedTransaction; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.NodeRepository; +import com.yahoo.vespa.hosted.provision.node.Allocation; import java.util.ArrayList; import java.util.Collection; @@ -73,7 +74,7 @@ class Activator { activeToRemove = activeToRemove.stream().map(Node::unretire).collect(Collectors.toList()); // only active nodes can be retired nodeRepository.deactivate(activeToRemove, transaction); nodeRepository.activate(updateFrom(hosts, continuedActive), transaction); // update active with any changes - nodeRepository.activate(reservedToActivate, transaction); + nodeRepository.activate(updatePortsFrom(hosts, reservedToActivate), transaction); } } @@ -133,7 +134,11 @@ class Activator { for (Node node : nodes) { HostSpec hostSpec = getHost(node.hostname(), hosts); node = hostSpec.membership().get().retired() ? node.retire(nodeRepository.clock().instant()) : node.unretire(); - node = node.with(node.allocation().get().with(hostSpec.membership().get())); + Allocation allocation = node.allocation().get().with(hostSpec.membership().get()); + if (hostSpec.networkPorts().isPresent()) { + allocation = allocation.withNetworkPorts(hostSpec.networkPorts().get()); + } + node = node.with(allocation); if (hostSpec.flavor().isPresent()) // Docker nodes may change flavor node = node.with(hostSpec.flavor().get()); updated.add(node); @@ -141,6 +146,23 @@ class Activator { return updated; } + /** + * Returns the input nodes with any port allocations from the hosts + */ + private List<Node> updatePortsFrom(Collection<HostSpec> hosts, List<Node> nodes) { + List<Node> updated = new ArrayList<>(); + for (Node node : nodes) { + HostSpec hostSpec = getHost(node.hostname(), hosts); + Allocation allocation = node.allocation().get(); + if (hostSpec.networkPorts().isPresent()) { + allocation = allocation.withNetworkPorts(hostSpec.networkPorts().get()); + node = node.with(allocation); + } + updated.add(node); + } + return updated; + } + private HostSpec getHost(String hostname, Collection<HostSpec> fromHosts) { for (HostSpec host : fromHosts) if (host.hostname().equals(hostname)) diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java index a0d76241533..246c56ee28b 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java @@ -22,11 +22,13 @@ import com.yahoo.vespa.flags.Flags; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.flag.FlagId; +import com.yahoo.vespa.hosted.provision.node.Allocation; import com.yahoo.vespa.hosted.provision.node.filter.ApplicationFilter; import com.yahoo.vespa.hosted.provision.node.filter.NodeHostFilter; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Optional; @@ -141,10 +143,16 @@ public class NodeRepositoryProvisioner implements Provisioner { List<HostSpec> hosts = new ArrayList<>(nodes.size()); for (Node node : nodes) { log.log(LogLevel.DEBUG, () -> "Prepared node " + node.hostname() + " - " + node.flavor()); + Allocation nodeAllocation = node.allocation().orElseThrow(IllegalStateException::new); hosts.add(new HostSpec(node.hostname(), - node.allocation().orElseThrow(IllegalStateException::new).membership(), - node.flavor(), - node.status().vespaVersion())); + Collections.emptyList(), + Optional.of(node.flavor()), + Optional.of(nodeAllocation.membership()), + node.status().vespaVersion(), + nodeAllocation.networkPorts())); + if (nodeAllocation.networkPorts().isPresent()) { + log.log(LogLevel.DEBUG, () -> "Prepared node " + node.hostname() + " has port allocations"); + } } return hosts; } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/NodesResponse.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/NodesResponse.java index ba513db5342..1254664eb78 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/NodesResponse.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/NodesResponse.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.hosted.provision.restapi.v2; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ClusterMembership; +import com.yahoo.config.provision.NetworkPortsSerializer; import com.yahoo.config.provision.NodeType; import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.container.jdisc.HttpResponse; @@ -153,17 +154,18 @@ class NodesResponse extends HttpResponse { object.setBool("fastDisk", node.flavor().hasFastDisk()); object.setDouble("bandwidth", node.flavor().getBandwidth()); object.setString("environment", node.flavor().getType().name()); - if (node.allocation().isPresent()) { - toSlime(node.allocation().get().owner(), object.setObject("owner")); - toSlime(node.allocation().get().membership(), object.setObject("membership")); - object.setLong("restartGeneration", node.allocation().get().restartGeneration().wanted()); - object.setLong("currentRestartGeneration", node.allocation().get().restartGeneration().current()); - object.setString("wantedDockerImage", nodeRepository.dockerImage().withTag(node.allocation().get().membership().cluster().vespaVersion()).asString()); - object.setString("wantedVespaVersion", node.allocation().get().membership().cluster().vespaVersion().toFullString()); + node.allocation().ifPresent(allocation -> { + toSlime(allocation.owner(), object.setObject("owner")); + toSlime(allocation.membership(), object.setObject("membership")); + object.setLong("restartGeneration", allocation.restartGeneration().wanted()); + object.setLong("currentRestartGeneration", allocation.restartGeneration().current()); + object.setString("wantedDockerImage", nodeRepository.dockerImage().withTag(allocation.membership().cluster().vespaVersion()).asString()); + object.setString("wantedVespaVersion", allocation.membership().cluster().vespaVersion().toFullString()); + allocation.networkPorts().ifPresent(ports -> NetworkPortsSerializer.toSlime(ports, object.setArray("networkPorts"))); orchestrator.apply(new HostName(node.hostname())) .map(status -> status == HostStatus.ALLOWED_TO_BE_DOWN) .ifPresent(allowedToBeDown -> object.setBool("allowedToBeDown", allowedToBeDown)); - } + }); object.setLong("rebootGeneration", node.status().reboot().wanted()); object.setLong("currentRebootGeneration", node.status().reboot().current()); node.status().osVersion().ifPresent(version -> object.setString("currentOsVersion", version.toFullString())); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/ContainerConfig.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/ContainerConfig.java index e17e1871555..2a07cadc6ad 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/ContainerConfig.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/ContainerConfig.java @@ -12,7 +12,7 @@ public class ContainerConfig { public static String servicesXmlV2(int port) { return "<jdisc version='1.0'>\n" + " <config name=\"container.handler.threadpool\">\n" + - " <maxthreads>10</maxthreads>\n" + + " <maxthreads>20</maxthreads>\n" + " </config> \n" + " <component id='com.yahoo.test.ManualClock'/>\n" + " <component id='com.yahoo.vespa.curator.mock.MockCurator'/>\n" + diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/SerializationTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/SerializationTest.java index 29229efc662..53f6b745da1 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/SerializationTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/SerializationTest.java @@ -8,6 +8,7 @@ import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ApplicationName; import com.yahoo.config.provision.ClusterMembership; import com.yahoo.config.provision.InstanceName; +import com.yahoo.config.provision.NetworkPorts; import com.yahoo.config.provision.NodeFlavors; import com.yahoo.config.provision.NodeType; import com.yahoo.config.provision.TenantName; @@ -29,8 +30,11 @@ import org.junit.Test; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; +import java.util.List; import java.util.Optional; import java.util.stream.Collectors; @@ -371,6 +375,31 @@ public class SerializationTest { assertEquals("some model", node.modelName().get()); } + @Test + public void testNodeWithNetworkPorts() { + Node node = createNode(); + List<NetworkPorts.Allocation> list = new ArrayList<>(); + list.add(new NetworkPorts.Allocation(8080, "container", "default/0", "http")); + list.add(new NetworkPorts.Allocation(19101, "searchnode", "other/1", "rpc")); + NetworkPorts ports = new NetworkPorts(list); + node = node.allocate(ApplicationId.from(TenantName.from("myTenant"), + ApplicationName.from("myApplication"), + InstanceName.from("myInstance")), + ClusterMembership.from("content/myId/0/0", Vtag.currentVersion), + clock.instant()); + assertTrue(node.allocation().isPresent()); + node = node.with(node.allocation().get().withNetworkPorts(ports)); + assertTrue(node.allocation().isPresent()); + assertTrue(node.allocation().get().networkPorts().isPresent()); + Node copy = nodeSerializer.fromJson(Node.State.provisioned, nodeSerializer.toJson(node)); + assertTrue(copy.allocation().isPresent()); + assertTrue(copy.allocation().get().networkPorts().isPresent()); + NetworkPorts portsCopy = node.allocation().get().networkPorts().get(); + Collection<NetworkPorts.Allocation> listCopy = portsCopy.allocations(); + assertEquals(list, listCopy); + } + + private byte[] createNodeJson(String hostname, String... ipAddress) { String ipAddressJsonPart = ""; if (ipAddress.length > 0) { diff --git a/parent/pom.xml b/parent/pom.xml index 7d7dfc1dc08..1cececbd8e8 100644 --- a/parent/pom.xml +++ b/parent/pom.xml @@ -708,6 +708,11 @@ <artifactId>assertj-core</artifactId> <version>3.11.1</version> </dependency> + <dependency> + <groupId>com.amazonaws</groupId> + <artifactId>aws-java-sdk-core</artifactId> + <version>${aws.sdk.version}</version> + </dependency> </dependencies> </dependencyManagement> @@ -717,6 +722,7 @@ <apache.httpclient.version>4.4.1</apache.httpclient.version> <apache.httpcore.version>4.4.1</apache.httpcore.version> <asm.version>7.0</asm.version> + <aws.sdk.version>1.11.357</aws.sdk.version> <jna.version>4.5.2</jna.version> <tensorflow.version>1.12.0</tensorflow.version> <!-- Athenz dependencies. Make sure these dependencies matches those in Vespa's internal repositories --> diff --git a/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp b/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp index afbb1c30f17..78cd9ce44b9 100644 --- a/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp +++ b/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp @@ -20,6 +20,7 @@ #include <vespa/document/update/removevalueupdate.h> #include <vespa/document/update/tensor_add_update.h> #include <vespa/document/update/tensor_modify_update.h> +#include <vespa/document/update/tensor_remove_update.h> #include <vespa/eval/tensor/default_tensor_engine.h> #include <vespa/eval/tensor/tensor.h> #include <vespa/searchcore/proton/common/attribute_updater.h> @@ -28,8 +29,8 @@ #include <vespa/searchlib/attribute/reference_attribute.h> #include <vespa/searchlib/tensor/dense_tensor_attribute.h> #include <vespa/searchlib/tensor/generic_tensor_attribute.h> -#include <vespa/vespalib/testkit/testapp.h> #include <vespa/vespalib/stllike/hash_map.hpp> +#include <vespa/vespalib/testkit/testapp.h> #include <vespa/log/log.h> LOG_SETUP("attribute_updater_test"); @@ -76,7 +77,8 @@ makeDocumentTypeRepo() .addField("wsfloat", Wset(DataType::T_FLOAT)) .addField("wsstring", Wset(DataType::T_STRING)) .addField("ref", 333) - .addField("dense_tensor", DataType::T_TENSOR), + .addField("dense_tensor", DataType::T_TENSOR) + .addField("sparse_tensor", DataType::T_TENSOR), Struct("testdoc.body")) .referenceType(333, 222); return std::make_unique<DocumentTypeRepo>(builder.config()); @@ -416,35 +418,54 @@ makeTensorFieldValue(const TensorSpec &spec) return result; } -void -setTensor(TensorAttribute &attribute, uint32_t lid, const TensorSpec &spec) -{ - auto tensor = makeTensor(spec); - attribute.setTensor(lid, *tensor); - attribute.commit(); -} +template <typename TensorAttributeType> +struct TensorFixture : public Fixture { + vespalib::string type; + std::unique_ptr<TensorAttributeType> attribute; -TEST_F("require that tensor modify update is applied", Fixture) -{ - vespalib::string type = "tensor(x[2])"; - auto attribute = makeTensorAttribute<DenseTensorAttribute>("dense_tensor", type); - setTensor(*attribute, 1, TensorSpec(type).add({{"x", 0}}, 3).add({{"x", 1}}, 5)); + TensorFixture(const vespalib::string &type_, const vespalib::string &name) + : type(type_), + attribute(makeTensorAttribute<TensorAttributeType>(name, type)) + { + } - f.applyValueUpdate(*attribute, 1, + void setTensor(const TensorSpec &spec) { + auto tensor = makeTensor(spec); + attribute->setTensor(1, *tensor); + attribute->commit(); + } + + void assertTensor(const TensorSpec &expSpec) { + EXPECT_EQUAL(expSpec, attribute->getTensor(1)->toSpec()); + } +}; + +TEST_F("require that tensor modify update is applied", + TensorFixture<DenseTensorAttribute>("tensor(x[2])", "dense_tensor")) +{ + f.setTensor(TensorSpec(f.type).add({{"x", 0}}, 3).add({{"x", 1}}, 5)); + f.applyValueUpdate(*f.attribute, 1, TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, makeTensorFieldValue(TensorSpec("tensor(x{})").add({{"x", 0}}, 7)))); - EXPECT_EQUAL(TensorSpec(type).add({{"x", 0}}, 7).add({{"x", 1}}, 5), attribute->getTensor(1)->toSpec()); + f.assertTensor(TensorSpec(f.type).add({{"x", 0}}, 7).add({{"x", 1}}, 5)); } -TEST_F("require that tensor add update is applied", Fixture) +TEST_F("require that tensor add update is applied", + TensorFixture<GenericTensorAttribute>("tensor(x{})", "sparse_tensor")) { - vespalib::string type = "tensor(x{})"; - auto attribute = makeTensorAttribute<GenericTensorAttribute>("dense_tensor", type); - setTensor(*attribute, 1, TensorSpec(type).add({{"x", "a"}}, 2)); + f.setTensor(TensorSpec(f.type).add({{"x", "a"}}, 2)); + f.applyValueUpdate(*f.attribute, 1, + TensorAddUpdate(makeTensorFieldValue(TensorSpec(f.type).add({{"x", "a"}}, 3)))); + f.assertTensor(TensorSpec(f.type).add({{"x", "a"}}, 3)); +} - f.applyValueUpdate(*attribute, 1, - TensorAddUpdate(makeTensorFieldValue(TensorSpec(type).add({{"x", "a"}}, 3)))); - EXPECT_EQUAL(TensorSpec(type).add({{"x", "a"}}, 3), attribute->getTensor(1)->toSpec()); +TEST_F("require that tensor remove update is applied", + TensorFixture<GenericTensorAttribute>("tensor(x{})", "sparse_tensor")) +{ + f.setTensor(TensorSpec(f.type).add({{"x", "a"}}, 2).add({{"x", "b"}}, 3)); + f.applyValueUpdate(*f.attribute, 1, + TensorRemoveUpdate(makeTensorFieldValue(TensorSpec(f.type).add({{"x", "b"}}, 1)))); + f.assertTensor(TensorSpec(f.type).add({{"x", "a"}}, 2)); } } diff --git a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp index 933857cffed..fcca1c2a737 100644 --- a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp +++ b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp @@ -16,6 +16,7 @@ #include <vespa/document/update/removevalueupdate.h> #include <vespa/document/update/tensor_add_update.h> #include <vespa/document/update/tensor_modify_update.h> +#include <vespa/document/update/tensor_remove_update.h> #include <vespa/eval/tensor/tensor.h> #include <vespa/searchlib/attribute/attributevector.hpp> #include <vespa/searchlib/attribute/changevector.hpp> @@ -238,6 +239,8 @@ AttributeUpdater::handleUpdate(TensorAttribute &vec, uint32_t lid, const ValueUp applyTensorUpdate(vec, lid, static_cast<const TensorModifyUpdate &>(upd)); } else if (op == ValueUpdate::TensorAddUpdate) { applyTensorUpdate(vec, lid, static_cast<const TensorAddUpdate &>(upd)); + } else if (op == ValueUpdate::TensorRemoveUpdate) { + applyTensorUpdate(vec, lid, static_cast<const TensorRemoveUpdate &>(upd)); } else if (op == ValueUpdate::Clear) { vec.clearDoc(lid); } else { diff --git a/security-utils/pom.xml b/security-utils/pom.xml index 0a26c73cf70..10dec598915 100644 --- a/security-utils/pom.xml +++ b/security-utils/pom.xml @@ -54,6 +54,11 @@ <artifactId>assertj-core</artifactId> <scope>test</scope> </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-core</artifactId> + <scope>test</scope> + </dependency> </dependencies> <build> <plugins> diff --git a/security-utils/src/main/java/com/yahoo/security/KeyStoreType.java b/security-utils/src/main/java/com/yahoo/security/KeyStoreType.java index 7fb8df35286..d72bd45865d 100644 --- a/security-utils/src/main/java/com/yahoo/security/KeyStoreType.java +++ b/security-utils/src/main/java/com/yahoo/security/KeyStoreType.java @@ -16,7 +16,7 @@ public enum KeyStoreType { }, PKCS12 { KeyStore createKeystore() throws KeyStoreException { - return KeyStore.getInstance("PKCS12", BouncyCastleProviderHolder.getInstance()); + return KeyStore.getInstance("PKCS12"); } }; abstract KeyStore createKeystore() throws GeneralSecurityException; diff --git a/security-utils/src/main/java/com/yahoo/security/SslContextBuilder.java b/security-utils/src/main/java/com/yahoo/security/SslContextBuilder.java index 09a5a87138f..1ef4df9c7bc 100644 --- a/security-utils/src/main/java/com/yahoo/security/SslContextBuilder.java +++ b/security-utils/src/main/java/com/yahoo/security/SslContextBuilder.java @@ -1,11 +1,14 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.security; +import com.yahoo.security.tls.KeyManagerUtils; +import com.yahoo.security.tls.TrustManagerUtils; + import javax.net.ssl.KeyManager; -import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.TrustManager; -import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedKeyManager; +import javax.net.ssl.X509ExtendedTrustManager; import java.io.IOException; import java.io.UncheckedIOException; import java.nio.file.Files; @@ -19,14 +22,17 @@ import java.util.List; import static java.util.Collections.singletonList; /** + * A builder for {@link SSLContext}. + * * @author bjorncs */ public class SslContextBuilder { - private KeyStoreSupplier trustStoreSupplier; - private KeyStoreSupplier keyStoreSupplier; + private KeyStoreSupplier trustStoreSupplier = () -> null; + private KeyStoreSupplier keyStoreSupplier = () -> null; private char[] keyStorePassword; - private TrustManagersFactory trustManagersFactory = SslContextBuilder::createDefaultTrustManagers; + private TrustManagerFactory trustManagerFactory = TrustManagerUtils::createDefaultX509TrustManager; + private KeyManagerFactory keyManagerFactory = KeyManagerUtils::createDefaultX509KeyManager; public SslContextBuilder() {} @@ -94,18 +100,21 @@ public class SslContextBuilder { return this; } - public SslContextBuilder withTrustManagerFactory(TrustManagersFactory trustManagersFactory) { - this.trustManagersFactory = trustManagersFactory; + public SslContextBuilder withTrustManagerFactory(TrustManagerFactory trustManagersFactory) { + this.trustManagerFactory = trustManagersFactory; + return this; + } + + public SslContextBuilder withKeyManagerFactory(KeyManagerFactory keyManagerFactory) { + this.keyManagerFactory = keyManagerFactory; return this; } public SSLContext build() { try { SSLContext sslContext = SSLContext.getInstance("TLSv1.2"); - TrustManager[] trustManagers = - trustStoreSupplier != null ? createTrustManagers(trustManagersFactory, trustStoreSupplier) : null; - KeyManager[] keyManagers = - keyStoreSupplier != null ? createKeyManagers(keyStoreSupplier, keyStorePassword) : null; + TrustManager[] trustManagers = new TrustManager[] { trustManagerFactory.createTrustManager(trustStoreSupplier.get()) }; + KeyManager[] keyManagers = new KeyManager[] { keyManagerFactory.createKeyManager(keyStoreSupplier.get(), keyStorePassword) }; sslContext.init(keyManagers, trustManagers, null); return sslContext; } catch (GeneralSecurityException e) { @@ -115,27 +124,6 @@ public class SslContextBuilder { } } - private static TrustManager[] createTrustManagers(TrustManagersFactory trustManagersFactory, KeyStoreSupplier trustStoreSupplier) - throws GeneralSecurityException, IOException { - KeyStore truststore = trustStoreSupplier.get(); - return trustManagersFactory.createTrustManagers(truststore); - } - - private static TrustManager[] createDefaultTrustManagers(KeyStore truststore) throws GeneralSecurityException { - TrustManagerFactory trustManagerFactory = - TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); - trustManagerFactory.init(truststore); - return trustManagerFactory.getTrustManagers(); - } - - private static KeyManager[] createKeyManagers(KeyStoreSupplier keyStoreSupplier, char[] password) - throws GeneralSecurityException, IOException { - KeyManagerFactory keyManagerFactory = - KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); - keyManagerFactory.init(keyStoreSupplier.get(), password); - return keyManagerFactory.getKeyManagers(); - } - private static KeyStore createTrustStore(List<X509Certificate> caCertificates) { KeyStoreBuilder trustStoreBuilder = KeyStoreBuilder.withType(KeyStoreType.JKS); for (int i = 0; i < caCertificates.size(); i++) { @@ -149,11 +137,19 @@ public class SslContextBuilder { } /** - * A factory interface that is similar to {@link TrustManagerFactory}, but is an interface instead of a class. + * A factory interface for creating {@link X509ExtendedTrustManager}. + */ + @FunctionalInterface + public interface TrustManagerFactory { + X509ExtendedTrustManager createTrustManager(KeyStore truststore) throws GeneralSecurityException; + } + + /** + * A factory interface for creating {@link X509ExtendedKeyManager}. */ @FunctionalInterface - public interface TrustManagersFactory { - TrustManager[] createTrustManagers(KeyStore truststore) throws GeneralSecurityException; + public interface KeyManagerFactory { + X509ExtendedKeyManager createKeyManager(KeyStore truststore, char[] password) throws GeneralSecurityException; } } diff --git a/security-utils/src/main/java/com/yahoo/security/tls/AutoReloadingX509KeyManager.java b/security-utils/src/main/java/com/yahoo/security/tls/AutoReloadingX509KeyManager.java new file mode 100644 index 00000000000..0dae185995c --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/AutoReloadingX509KeyManager.java @@ -0,0 +1,150 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls; + +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.X509CertificateUtils; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.X509ExtendedKeyManager; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.Socket; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.KeyStore; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.time.Duration; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * A {@link X509ExtendedKeyManager} that reloads the certificate and private key from file regularly. + * + * @author bjorncs + */ +public class AutoReloadingX509KeyManager extends X509ExtendedKeyManager implements AutoCloseable { + + private static final Duration UPDATE_PERIOD = Duration.ofHours(1); + + private static final Logger log = Logger.getLogger(AutoReloadingX509KeyManager.class.getName()); + + private final MutableX509KeyManager mutableX509KeyManager; + private final ScheduledExecutorService scheduler; + private final Path privateKeyFile; + private final Path certificatesFile; + + private AutoReloadingX509KeyManager(Path privateKeyFile, Path certificatesFile) { + this(privateKeyFile, certificatesFile, createDefaultScheduler()); + } + + AutoReloadingX509KeyManager(Path privateKeyFile, Path certificatesFile, ScheduledExecutorService scheduler) { + this.privateKeyFile = privateKeyFile; + this.certificatesFile = certificatesFile; + this.scheduler = scheduler; + this.mutableX509KeyManager = new MutableX509KeyManager(createKeystore(privateKeyFile, certificatesFile), new char[0]); + scheduler.scheduleAtFixedRate( + new KeyManagerReloader(), UPDATE_PERIOD.getSeconds()/*initial delay*/, UPDATE_PERIOD.getSeconds(), TimeUnit.SECONDS); + } + + public static AutoReloadingX509KeyManager fromPemFiles(Path privateKeyFile, Path certificatesFile) { + return new AutoReloadingX509KeyManager(privateKeyFile, certificatesFile); + } + + private static KeyStore createKeystore(Path privateKey, Path certificateChain) { + try { + return KeyStoreBuilder.withType(KeyStoreType.PKCS12) + .withKeyEntry( + "default", + KeyUtils.fromPemEncodedPrivateKey(Files.readString(privateKey)), + X509CertificateUtils.certificateListFromPem(Files.readString(certificateChain))) + .build(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private static ScheduledExecutorService createDefaultScheduler() { + return Executors.newSingleThreadScheduledExecutor(runnable -> { + Thread thread = new Thread(runnable, "auto-reloading-x509-key-manager"); + thread.setDaemon(true); + return thread; + }); + } + + private class KeyManagerReloader implements Runnable { + @Override + public void run() { + try { + log.log(Level.FINE, () -> String.format("Reloading key and certificate chain (private-key='%s', certificates='%s')", privateKeyFile, certificatesFile)); + mutableX509KeyManager.updateKeystore(createKeystore(privateKeyFile, certificatesFile), new char[0]); + } catch (Throwable t) { + log.log(Level.SEVERE, + String.format("Failed to load X509 key manager (private-key='%s', certificates='%s'): %s", + privateKeyFile, certificatesFile, t.getMessage()), + t); + } + } + } + + @Override + public void close() { + try { + scheduler.shutdownNow(); + scheduler.awaitTermination(5, TimeUnit.SECONDS); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + // + // Methods from X509ExtendedKeyManager + // + + @Override + public String[] getServerAliases(String keyType, Principal[] issuers) { + return mutableX509KeyManager.getServerAliases(keyType, issuers); + } + + @Override + public String[] getClientAliases(String keyType, Principal[] issuers) { + return mutableX509KeyManager.getClientAliases(keyType, issuers); + } + + @Override + public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) { + return mutableX509KeyManager.chooseServerAlias(keyType, issuers, socket); + } + + @Override + public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) { + return mutableX509KeyManager.chooseClientAlias(keyType, issuers, socket); + } + + @Override + public String chooseEngineServerAlias(String keyType, Principal[] issuers, SSLEngine engine) { + return mutableX509KeyManager.chooseEngineServerAlias(keyType, issuers, engine); + } + + @Override + public String chooseEngineClientAlias(String[] keyType, Principal[] issuers, SSLEngine engine) { + return mutableX509KeyManager.chooseEngineClientAlias(keyType, issuers, engine); + } + + @Override + public X509Certificate[] getCertificateChain(String alias) { + return mutableX509KeyManager.getCertificateChain(alias); + } + + @Override + public PrivateKey getPrivateKey(String alias) { + return mutableX509KeyManager.getPrivateKey(alias); + } + +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/DefaultTlsContext.java b/security-utils/src/main/java/com/yahoo/security/tls/DefaultTlsContext.java index 2befd50332a..c9c326df9ed 100644 --- a/security-utils/src/main/java/com/yahoo/security/tls/DefaultTlsContext.java +++ b/security-utils/src/main/java/com/yahoo/security/tls/DefaultTlsContext.java @@ -7,7 +7,7 @@ import com.yahoo.security.tls.policy.AuthorizedPeers; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; -import java.nio.file.Path; +import javax.net.ssl.SSLParameters; import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.util.Arrays; @@ -38,7 +38,8 @@ public class DefaultTlsContext implements TlsContext { private static final Logger log = Logger.getLogger(DefaultTlsContext.class.getName()); private final SSLContext sslContext; - private final List<String> acceptedCiphers; + private final String[] validCiphers; + private final String[] validProtocols; public DefaultTlsContext(List<X509Certificate> certificates, PrivateKey privateKey, @@ -46,49 +47,77 @@ public class DefaultTlsContext implements TlsContext { AuthorizedPeers authorizedPeers, AuthorizationMode mode, List<String> acceptedCiphers) { - this.sslContext = createSslContext(certificates, privateKey, caCertificates, authorizedPeers, mode); - this.acceptedCiphers = acceptedCiphers; + this(createSslContext(certificates, privateKey, caCertificates, authorizedPeers, mode), + acceptedCiphers); } - public DefaultTlsContext(Path tlsOptionsConfigFile, AuthorizationMode mode) { - TransportSecurityOptions options = TransportSecurityOptions.fromJsonFile(tlsOptionsConfigFile); - this.sslContext = createSslContext(options, mode); - this.acceptedCiphers = options.getAcceptedCiphers(); - } - @Override - public SSLEngine createSslEngine() { - SSLEngine sslEngine = sslContext.createSSLEngine(); - restrictSetOfEnabledCiphers(sslEngine, acceptedCiphers); - restrictTlsProtocols(sslEngine); - return sslEngine; + public DefaultTlsContext(SSLContext sslContext, List<String> acceptedCiphers) { + this.sslContext = sslContext; + this.validCiphers = getAllowedCiphers(sslContext, acceptedCiphers); + this.validProtocols = getAllowedProtocols(sslContext); } - private static void restrictSetOfEnabledCiphers(SSLEngine sslEngine, List<String> acceptedCiphers) { - String[] validCipherSuites = Arrays.stream(sslEngine.getSupportedCipherSuites()) + + private static String[] getAllowedCiphers(SSLContext sslContext, List<String> acceptedCiphers) { + String[] supportedCipherSuites = sslContext.getSupportedSSLParameters().getCipherSuites(); + String[] validCipherSuites = Arrays.stream(supportedCipherSuites) .filter(suite -> ALLOWED_CIPHER_SUITES.contains(suite) && (acceptedCiphers.isEmpty() || acceptedCiphers.contains(suite))) .toArray(String[]::new); if (validCipherSuites.length == 0) { throw new IllegalStateException( String.format("None of the allowed cipher suites are supported " + "(allowed-cipher-suites=%s, supported-cipher-suites=%s, accepted-cipher-suites=%s)", - ALLOWED_CIPHER_SUITES, List.of(sslEngine.getSupportedCipherSuites()), acceptedCiphers)); + ALLOWED_CIPHER_SUITES, List.of(supportedCipherSuites), acceptedCiphers)); } - log.log(Level.FINE, () -> String.format("Allowed cipher suites that are supported: %s", Arrays.toString(validCipherSuites))); - sslEngine.setEnabledCipherSuites(validCipherSuites); + log.log(Level.FINE, () -> String.format("Allowed cipher suites that are supported: %s", List.of(validCipherSuites))); + return validCipherSuites; } - private static void restrictTlsProtocols(SSLEngine sslEngine) { - String[] validProtocols = Arrays.stream(sslEngine.getSupportedProtocols()) + private static String[] getAllowedProtocols(SSLContext sslContext) { + String[] supportedProtocols = sslContext.getSupportedSSLParameters().getProtocols(); + String[] validProtocols = Arrays.stream(supportedProtocols) .filter(ALLOWED_PROTOCOLS::contains) .toArray(String[]::new); if (validProtocols.length == 0) { throw new IllegalArgumentException( String.format("None of the allowed protocols are supported (allowed-protocols=%s, supported-protocols=%s)", - ALLOWED_PROTOCOLS, Arrays.toString(sslEngine.getSupportedProtocols()))); + ALLOWED_PROTOCOLS, List.of(supportedProtocols))); } - log.log(Level.FINE, () -> String.format("Allowed protocols that are supported: %s", Arrays.toString(validProtocols))); - sslEngine.setEnabledProtocols(validProtocols); + log.log(Level.FINE, () -> String.format("Allowed protocols that are supported: %s", List.of(validProtocols))); + return validProtocols; + } + + @Override + public SSLContext context() { + return sslContext; + } + + @Override + public SSLParameters parameters() { + return createSslParameters(); + } + + @Override + public SSLEngine createSslEngine() { + SSLEngine sslEngine = sslContext.createSSLEngine(); + sslEngine.setSSLParameters(createSslParameters()); + return sslEngine; + } + + @Override + public SSLEngine createSslEngine(String peerHost, int peerPort) { + SSLEngine sslEngine = sslContext.createSSLEngine(peerHost, peerPort); + sslEngine.setSSLParameters(createSslParameters()); + return sslEngine; + } + + private SSLParameters createSslParameters() { + SSLParameters newParameters = sslContext.getDefaultSSLParameters(); + newParameters.setCipherSuites(validCiphers); + newParameters.setProtocols(validProtocols); + newParameters.setNeedClientAuth(true); + return newParameters; } private static SSLContext createSslContext(List<X509Certificate> certificates, @@ -109,16 +138,5 @@ public class DefaultTlsContext implements TlsContext { return builder.build(); } - private static SSLContext createSslContext(TransportSecurityOptions options, AuthorizationMode mode) { - SslContextBuilder builder = new SslContextBuilder(); - options.getCertificatesFile() - .ifPresent(certificates -> builder.withKeyStore(options.getPrivateKeyFile().get(), certificates)); - options.getCaCertificatesFile().ifPresent(builder::withTrustStore); - if (mode != AuthorizationMode.DISABLE) { - options.getAuthorizedPeers().ifPresent( - authorizedPeers -> builder.withTrustManagerFactory(new PeerAuthorizerTrustManagersFactory(authorizedPeers, mode))); - } - return builder.build(); - } } diff --git a/security-utils/src/main/java/com/yahoo/security/tls/KeyManagerUtils.java b/security-utils/src/main/java/com/yahoo/security/tls/KeyManagerUtils.java new file mode 100644 index 00000000000..2e48de3c01f --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/KeyManagerUtils.java @@ -0,0 +1,49 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls; + +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyStoreType; + +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.X509ExtendedKeyManager; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.List; + +/** + * Utility methods for constructing {@link X509ExtendedKeyManager}. + * + * @author bjorncs + */ +public class KeyManagerUtils { + + public static X509ExtendedKeyManager createDefaultX509KeyManager(KeyStore keystore, char[] password) { + try { + KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(keystore, password); + KeyManager[] keyManagers = keyManagerFactory.getKeyManagers(); + return Arrays.stream(keyManagers) + .filter(manager -> manager instanceof X509ExtendedKeyManager) + .map(X509ExtendedKeyManager.class::cast) + .findFirst() + .orElseThrow(() -> new RuntimeException("No X509ExtendedKeyManager in " + List.of(keyManagers))); + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } + + public static X509ExtendedKeyManager createDefaultX509KeyManager(PrivateKey privateKey, List<X509Certificate> certificateChain) { + KeyStore keystore = KeyStoreBuilder.withType(KeyStoreType.PKCS12) + .withKeyEntry("default", privateKey, certificateChain) + .build(); + return createDefaultX509KeyManager(keystore, new char[0]); + } + + public static X509ExtendedKeyManager createDefaultX509KeyManager() { + return createDefaultX509KeyManager(null, new char[0]); + } +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/MutableX509KeyManager.java b/security-utils/src/main/java/com/yahoo/security/tls/MutableX509KeyManager.java new file mode 100644 index 00000000000..e5e56f7a181 --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/MutableX509KeyManager.java @@ -0,0 +1,106 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.X509ExtendedKeyManager; +import java.net.Socket; +import java.security.KeyStore; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.WeakHashMap; + +/** + * A {@link X509ExtendedKeyManager} which can be updated with new certificate chain and private key while in use. + * + * The implementations assumes that aliases are retrieved from the same thread as the certificate chain and private key. + * This is case for OpenJDK 11. + * + * @author bjorncs + */ +public class MutableX509KeyManager extends X509ExtendedKeyManager { + + // Not using ThreadLocal as we want the x509 key manager instances to be collected + // when either the thread dies or the MutableX509KeyManager instance is collected (latter not the case for ThreadLocal). + private final WeakHashMap<Thread, X509ExtendedKeyManager> threadLocalManager = new WeakHashMap<>(); + private volatile X509ExtendedKeyManager currentManager; + + public MutableX509KeyManager(KeyStore keystore, char[] password) { + this.currentManager = KeyManagerUtils.createDefaultX509KeyManager(keystore, password); + } + + public MutableX509KeyManager() { + this.currentManager = KeyManagerUtils.createDefaultX509KeyManager(); + } + + public void updateKeystore(KeyStore keystore, char[] password) { + this.currentManager = KeyManagerUtils.createDefaultX509KeyManager(keystore, password); + } + + public void useDefaultKeystore() { + this.currentManager = KeyManagerUtils.createDefaultX509KeyManager(); + } + + @Override + public String[] getServerAliases(String keyType, Principal[] issuers) { + return updateAndGetThreadLocalManager() + .getServerAliases(keyType, issuers); + } + + @Override + public String[] getClientAliases(String keyType, Principal[] issuers) { + return updateAndGetThreadLocalManager() + .getClientAliases(keyType, issuers); + } + + @Override + public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) { + return updateAndGetThreadLocalManager() + .chooseServerAlias(keyType, issuers, socket); + } + + @Override + public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) { + return updateAndGetThreadLocalManager() + .chooseClientAlias(keyType, issuers, socket); + } + + @Override + public String chooseEngineServerAlias(String keyType, Principal[] issuers, SSLEngine engine) { + return updateAndGetThreadLocalManager() + .chooseEngineServerAlias(keyType, issuers, engine); + } + + @Override + public String chooseEngineClientAlias(String[] keyType, Principal[] issuers, SSLEngine engine) { + return updateAndGetThreadLocalManager() + .chooseEngineClientAlias(keyType, issuers, engine); + } + + private X509ExtendedKeyManager updateAndGetThreadLocalManager() { + X509ExtendedKeyManager currentManager = this.currentManager; + threadLocalManager.put(Thread.currentThread(), currentManager); + return currentManager; + } + + @Override + public X509Certificate[] getCertificateChain(String alias) { + return getThreadLocalManager() + .getCertificateChain(alias); + } + + @Override + public PrivateKey getPrivateKey(String alias) { + return getThreadLocalManager() + .getPrivateKey(alias); + } + + private X509ExtendedKeyManager getThreadLocalManager() { + X509ExtendedKeyManager manager = threadLocalManager.get(Thread.currentThread()); + if (manager == null) { + throw new IllegalStateException("Methods to retrieve valid aliases has not been called previously from this thread"); + } + return manager; + } + +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/MutableX509TrustManager.java b/security-utils/src/main/java/com/yahoo/security/tls/MutableX509TrustManager.java new file mode 100644 index 00000000000..ed424480d26 --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/MutableX509TrustManager.java @@ -0,0 +1,70 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.X509ExtendedTrustManager; +import java.net.Socket; +import java.security.KeyStore; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; + +/** + * A {@link X509ExtendedTrustManager} which can be updated with new CA certificates while in use. + * + * @author bjorncs + */ +public class MutableX509TrustManager extends X509ExtendedTrustManager { + + private volatile X509ExtendedTrustManager currentManager; + + public MutableX509TrustManager(KeyStore truststore) { + this.currentManager = TrustManagerUtils.createDefaultX509TrustManager(truststore); + } + + public MutableX509TrustManager() { + this.currentManager = TrustManagerUtils.createDefaultX509TrustManager(); + } + + public void updateTruststore(KeyStore truststore) { + this.currentManager = TrustManagerUtils.createDefaultX509TrustManager(truststore); + } + + public void useDefaultTruststore() { + this.currentManager = TrustManagerUtils.createDefaultX509TrustManager(); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException { + currentManager.checkClientTrusted(chain, authType); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { + currentManager.checkServerTrusted(chain, authType); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) throws CertificateException { + currentManager.checkClientTrusted(chain, authType, socket); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) throws CertificateException { + currentManager.checkServerTrusted(chain, authType, socket); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine sslEngine) throws CertificateException { + currentManager.checkClientTrusted(chain, authType, sslEngine); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine sslEngine) throws CertificateException { + currentManager.checkServerTrusted(chain, authType, sslEngine); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return currentManager.getAcceptedIssuers(); + } +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/ReloadingTlsContext.java b/security-utils/src/main/java/com/yahoo/security/tls/ReloadingTlsContext.java index 5add13e067d..b57105f54f9 100644 --- a/security-utils/src/main/java/com/yahoo/security/tls/ReloadingTlsContext.java +++ b/security-utils/src/main/java/com/yahoo/security/tls/ReloadingTlsContext.java @@ -1,13 +1,28 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.security.tls; +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.SslContextBuilder; +import com.yahoo.security.X509CertificateUtils; +import com.yahoo.security.tls.authz.PeerAuthorizerTrustManager; + +import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.X509ExtendedTrustManager; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Files; import java.nio.file.Path; +import java.security.KeyStore; +import java.security.cert.X509Certificate; import java.time.Duration; +import java.util.List; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; @@ -23,8 +38,9 @@ public class ReloadingTlsContext implements TlsContext { private static final Logger log = Logger.getLogger(ReloadingTlsContext.class.getName()); private final Path tlsOptionsConfigFile; - private final AuthorizationMode mode; - private final AtomicReference<TlsContext> currentTlsContext; + private final TlsContext tlsContext; + private final MutableX509TrustManager trustManager = new MutableX509TrustManager(); + private final MutableX509KeyManager keyManager = new MutableX509KeyManager(); private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor(runnable -> { Thread thread = new Thread(runnable, "tls-context-reloader"); @@ -34,19 +50,77 @@ public class ReloadingTlsContext implements TlsContext { public ReloadingTlsContext(Path tlsOptionsConfigFile, AuthorizationMode mode) { this.tlsOptionsConfigFile = tlsOptionsConfigFile; - this.mode = mode; - this.currentTlsContext = new AtomicReference<>(new DefaultTlsContext(tlsOptionsConfigFile, mode)); - this.scheduler.scheduleAtFixedRate(new SslContextReloader(), + TransportSecurityOptions options = TransportSecurityOptions.fromJsonFile(tlsOptionsConfigFile); + reloadCryptoMaterial(options, trustManager, keyManager); + this.tlsContext = createDefaultTlsContext(options, mode, trustManager, keyManager); + this.scheduler.scheduleAtFixedRate(new CryptoMaterialReloader(), UPDATE_PERIOD.getSeconds()/*initial delay*/, UPDATE_PERIOD.getSeconds(), TimeUnit.SECONDS); } - @Override - public SSLEngine createSslEngine() { - return currentTlsContext.get().createSslEngine(); + private static void reloadCryptoMaterial(TransportSecurityOptions options, + MutableX509TrustManager trustManager, + MutableX509KeyManager keyManager) { + if (options.getCaCertificatesFile().isPresent()) { + trustManager.updateTruststore(loadTruststore(options.getCaCertificatesFile().get())); + } else { + trustManager.useDefaultTruststore(); + } + + if (options.getPrivateKeyFile().isPresent() && options.getCertificatesFile().isPresent()) { + keyManager.updateKeystore(loadKeystore(options.getPrivateKeyFile().get(), options.getCertificatesFile().get()), new char[0]); + } else { + keyManager.useDefaultKeystore(); + } } + private static KeyStore loadTruststore(Path caCertificateFile) { + try { + List<X509Certificate> caCertificates = X509CertificateUtils.certificateListFromPem(Files.readString(caCertificateFile)); + KeyStoreBuilder trustStoreBuilder = KeyStoreBuilder.withType(KeyStoreType.PKCS12); + for (int i = 0; i < caCertificates.size(); i++) { + trustStoreBuilder.withCertificateEntry("cert-" + i, caCertificates.get(i)); + } + return trustStoreBuilder.build(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private static KeyStore loadKeystore(Path privateKeyFile, Path certificatesFile) { + try { + return KeyStoreBuilder.withType(KeyStoreType.PKCS12) + .withKeyEntry( + "default", + KeyUtils.fromPemEncodedPrivateKey(Files.readString(privateKeyFile)), + X509CertificateUtils.certificateListFromPem(Files.readString(certificatesFile))) + .build(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private static DefaultTlsContext createDefaultTlsContext(TransportSecurityOptions options, + AuthorizationMode mode, + MutableX509TrustManager mutableTrustManager, + MutableX509KeyManager mutableKeyManager) { + SSLContext sslContext = new SslContextBuilder() + .withKeyManagerFactory((ignoredKeystore, ignoredPassword) -> mutableKeyManager) + .withTrustManagerFactory( + ignoredTruststore -> options.getAuthorizedPeers() + .map(authorizedPeers -> (X509ExtendedTrustManager) new PeerAuthorizerTrustManager(authorizedPeers, mode, mutableTrustManager)) + .orElse(mutableTrustManager)) + .build(); + return new DefaultTlsContext(sslContext, options.getAcceptedCiphers()); + } + + // Wrapped methods from TlsContext + @Override public SSLContext context() { return tlsContext.context(); } + @Override public SSLParameters parameters() { return tlsContext.parameters(); } + @Override public SSLEngine createSslEngine() { return tlsContext.createSslEngine(); } + @Override public SSLEngine createSslEngine(String peerHost, int peerPort) { return tlsContext.createSslEngine(peerHost, peerPort); } + @Override public void close() { try { @@ -57,13 +131,13 @@ public class ReloadingTlsContext implements TlsContext { } } - private class SslContextReloader implements Runnable { + private class CryptoMaterialReloader implements Runnable { @Override public void run() { try { - currentTlsContext.set(new DefaultTlsContext(tlsOptionsConfigFile, mode)); + reloadCryptoMaterial(TransportSecurityOptions.fromJsonFile(tlsOptionsConfigFile), trustManager, keyManager); } catch (Throwable t) { - log.log(Level.SEVERE, String.format("Failed to load SSLContext (path='%s'): %s", tlsOptionsConfigFile, t.getMessage()), t); + log.log(Level.SEVERE, String.format("Failed to reload crypto material (path='%s'): %s", tlsOptionsConfigFile, t.getMessage()), t); } } } diff --git a/security-utils/src/main/java/com/yahoo/security/tls/TlsContext.java b/security-utils/src/main/java/com/yahoo/security/tls/TlsContext.java index 58687a0ba8f..b315dd00b31 100644 --- a/security-utils/src/main/java/com/yahoo/security/tls/TlsContext.java +++ b/security-utils/src/main/java/com/yahoo/security/tls/TlsContext.java @@ -3,6 +3,7 @@ package com.yahoo.security.tls; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; /** * A simplified version of {@link SSLContext} modelled as an interface. @@ -11,8 +12,14 @@ import javax.net.ssl.SSLEngine; */ public interface TlsContext extends AutoCloseable { + SSLContext context(); + + SSLParameters parameters(); + SSLEngine createSslEngine(); + SSLEngine createSslEngine(String peerHost, int peerPort); + @Override default void close() {} } diff --git a/security-utils/src/main/java/com/yahoo/security/tls/TrustManagerUtils.java b/security-utils/src/main/java/com/yahoo/security/tls/TrustManagerUtils.java new file mode 100644 index 00000000000..f114b672ed8 --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/TrustManagerUtils.java @@ -0,0 +1,50 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls; + +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyStoreType; + +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedTrustManager; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.List; + +/** + * Utility methods for constructing {@link X509ExtendedTrustManager}. + * + * @author bjorncs + */ +public class TrustManagerUtils { + + public static X509ExtendedTrustManager createDefaultX509TrustManager(KeyStore truststore) { + try { + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(truststore); + TrustManager[] trustManagers = trustManagerFactory.getTrustManagers(); + return Arrays.stream(trustManagers) + .filter(manager -> manager instanceof X509ExtendedTrustManager) + .map(X509ExtendedTrustManager.class::cast) + .findFirst() + .orElseThrow(() -> new RuntimeException("No X509ExtendedTrustManager in " + List.of(trustManagers))); + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } + + public static X509ExtendedTrustManager createDefaultX509TrustManager(List<X509Certificate> certificates) { + KeyStoreBuilder truststoreBuilder = KeyStoreBuilder.withType(KeyStoreType.PKCS12); + for (int i = 0; i < certificates.size(); i++) { + truststoreBuilder.withCertificateEntry("cert-" + i, certificates.get(i)); + } + KeyStore truststore = truststoreBuilder.build(); + return createDefaultX509TrustManager(truststore); + } + + public static X509ExtendedTrustManager createDefaultX509TrustManager() { + return createDefaultX509TrustManager((KeyStore) null); + } +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManager.java b/security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManager.java index 80acc940a99..eee2e502183 100644 --- a/security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManager.java +++ b/security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManager.java @@ -3,14 +3,12 @@ package com.yahoo.security.tls.authz; import com.yahoo.security.X509CertificateUtils; import com.yahoo.security.tls.AuthorizationMode; +import com.yahoo.security.tls.TrustManagerUtils; import com.yahoo.security.tls.policy.AuthorizedPeers; import javax.net.ssl.SSLEngine; -import javax.net.ssl.TrustManager; -import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.X509ExtendedTrustManager; import java.net.Socket; -import java.security.GeneralSecurityException; import java.security.KeyStore; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; @@ -39,22 +37,8 @@ public class PeerAuthorizerTrustManager extends X509ExtendedTrustManager { this.defaultTrustManager = defaultTrustManager; } - public static TrustManager[] wrapTrustManagersFromKeystore(AuthorizedPeers authorizedPeers, AuthorizationMode mode, KeyStore keystore) throws GeneralSecurityException { - TrustManagerFactory factory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); - factory.init(keystore); - return wrapTrustManagers(authorizedPeers, mode, factory.getTrustManagers()); - } - - public static TrustManager[] wrapTrustManagers(AuthorizedPeers authorizedPeers, AuthorizationMode mode, TrustManager[] managers) { - TrustManager[] wrappedManagers = new TrustManager[managers.length]; - for (int i = 0; i < managers.length; i++) { - if (managers[i] instanceof X509ExtendedTrustManager) { - wrappedManagers[i] = new PeerAuthorizerTrustManager(authorizedPeers, mode, (X509ExtendedTrustManager) managers[i]); - } else { - wrappedManagers[i] = managers[i]; - } - } - return wrappedManagers; + public PeerAuthorizerTrustManager(AuthorizedPeers authorizedPeers, AuthorizationMode mode, KeyStore truststore) { + this(authorizedPeers, mode, TrustManagerUtils.createDefaultX509TrustManager(truststore)); } @Override diff --git a/security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManagersFactory.java b/security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManagersFactory.java index c0a3b4e41a5..6ec8450c035 100644 --- a/security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManagersFactory.java +++ b/security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManagersFactory.java @@ -5,14 +5,12 @@ import com.yahoo.security.SslContextBuilder; import com.yahoo.security.tls.AuthorizationMode; import com.yahoo.security.tls.policy.AuthorizedPeers; -import javax.net.ssl.TrustManager; -import java.security.GeneralSecurityException; import java.security.KeyStore; /** * @author bjorncs */ -public class PeerAuthorizerTrustManagersFactory implements SslContextBuilder.TrustManagersFactory { +public class PeerAuthorizerTrustManagersFactory implements SslContextBuilder.TrustManagerFactory { private final AuthorizedPeers authorizedPeers; private AuthorizationMode mode; @@ -22,7 +20,7 @@ public class PeerAuthorizerTrustManagersFactory implements SslContextBuilder.Tru } @Override - public TrustManager[] createTrustManagers(KeyStore truststore) throws GeneralSecurityException { - return PeerAuthorizerTrustManager.wrapTrustManagersFromKeystore(authorizedPeers, mode, truststore); + public PeerAuthorizerTrustManager createTrustManager(KeyStore truststore) { + return new PeerAuthorizerTrustManager(authorizedPeers, mode, truststore); } } diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClient.java b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClient.java new file mode 100644 index 00000000000..2911b77707a --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClient.java @@ -0,0 +1,101 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.https; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import java.io.IOException; +import java.net.Authenticator; +import java.net.CookieHandler; +import java.net.ProxySelector; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; + +/** + * A {@link HttpClient} that uses either http or https based on the global Vespa TLS configuration. + * + * @author bjorncs + */ +class TlsAwareHttpClient extends HttpClient { + + private final HttpClient wrappedClient; + private final String userAgent; + + TlsAwareHttpClient(HttpClient wrappedClient, String userAgent) { + this.wrappedClient = wrappedClient; + this.userAgent = userAgent; + } + + @Override + public Optional<CookieHandler> cookieHandler() { + return wrappedClient.cookieHandler(); + } + + @Override + public Optional<Duration> connectTimeout() { + return wrappedClient.connectTimeout(); + } + + @Override + public Redirect followRedirects() { + return wrappedClient.followRedirects(); + } + + @Override + public Optional<ProxySelector> proxy() { + return wrappedClient.proxy(); + } + + @Override + public SSLContext sslContext() { + return wrappedClient.sslContext(); + } + + @Override + public SSLParameters sslParameters() { + return wrappedClient.sslParameters(); + } + + @Override + public Optional<Authenticator> authenticator() { + return wrappedClient.authenticator(); + } + + @Override + public Version version() { + return wrappedClient.version(); + } + + @Override + public Optional<Executor> executor() { + return wrappedClient.executor(); + } + + @Override + public <T> HttpResponse<T> send(HttpRequest request, HttpResponse.BodyHandler<T> responseBodyHandler) throws IOException, InterruptedException { + return wrappedClient.send(wrapRequest(request), responseBodyHandler); + } + + @Override + public <T> CompletableFuture<HttpResponse<T>> sendAsync(HttpRequest request, HttpResponse.BodyHandler<T> responseBodyHandler) { + return wrappedClient.sendAsync(wrapRequest(request), responseBodyHandler); + } + + @Override + public <T> CompletableFuture<HttpResponse<T>> sendAsync(HttpRequest request, HttpResponse.BodyHandler<T> responseBodyHandler, HttpResponse.PushPromiseHandler<T> pushPromiseHandler) { + return wrappedClient.sendAsync(wrapRequest(request), responseBodyHandler, pushPromiseHandler); + } + + @Override + public String toString() { + return wrappedClient.toString(); + } + + private HttpRequest wrapRequest(HttpRequest request) { + return new TlsAwareHttpRequest(request, userAgent); + } +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClientBuilder.java b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClientBuilder.java new file mode 100644 index 00000000000..7eca2463ba7 --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClientBuilder.java @@ -0,0 +1,97 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.https; + +import com.yahoo.security.tls.TlsContext; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import java.net.Authenticator; +import java.net.CookieHandler; +import java.net.ProxySelector; +import java.net.http.HttpClient; +import java.time.Duration; +import java.util.concurrent.Executor; + +/** + * A client builder for {@link HttpClient} which uses {@link TlsContext} for TLS configuration. + * Intended for internal Vespa communication only. + * + * @author bjorncs + */ +public class TlsAwareHttpClientBuilder implements HttpClient.Builder { + + private final HttpClient.Builder wrappedBuilder; + private final String userAgent; + + public TlsAwareHttpClientBuilder(TlsContext tlsContext) { + this(tlsContext, "vespa-tls-aware-client"); + } + + public TlsAwareHttpClientBuilder(TlsContext tlsContext, String userAgent) { + this.wrappedBuilder = HttpClient.newBuilder() + .sslContext(tlsContext.context()) + .sslParameters(tlsContext.parameters()); + this.userAgent = userAgent; + } + + @Override + public HttpClient.Builder cookieHandler(CookieHandler cookieHandler) { + throw new UnsupportedOperationException(); + } + + @Override + public HttpClient.Builder connectTimeout(Duration duration) { + wrappedBuilder.connectTimeout(duration); + return this; + } + + @Override + public HttpClient.Builder sslContext(SSLContext sslContext) { + throw new UnsupportedOperationException("SSLContext is given from tls context"); + } + + @Override + public HttpClient.Builder sslParameters(SSLParameters sslParameters) { + throw new UnsupportedOperationException("SSLParameters is given from tls context"); + } + + @Override + public HttpClient.Builder executor(Executor executor) { + wrappedBuilder.executor(executor); + return this; + } + + @Override + public HttpClient.Builder followRedirects(HttpClient.Redirect policy) { + wrappedBuilder.followRedirects(policy); + return this; + } + + @Override + public HttpClient.Builder version(HttpClient.Version version) { + wrappedBuilder.version(version); + return this; + } + + @Override + public HttpClient.Builder priority(int priority) { + wrappedBuilder.priority(priority); + return this; + } + + @Override + public HttpClient.Builder proxy(ProxySelector proxySelector) { + throw new UnsupportedOperationException(); + } + + @Override + public HttpClient.Builder authenticator(Authenticator authenticator) { + throw new UnsupportedOperationException(); + } + + @Override + public HttpClient build() { + // TODO Stop wrapping the client once TLS is mandatory + return new TlsAwareHttpClient(wrappedBuilder.build(), userAgent); + } +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpRequest.java b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpRequest.java new file mode 100644 index 00000000000..bbdd8af791f --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpRequest.java @@ -0,0 +1,103 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.https; + +import com.yahoo.security.tls.MixedMode; +import com.yahoo.security.tls.TransportSecurityUtils; + +import java.net.URI; +import java.net.URISyntaxException; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.time.Duration; +import java.util.HashMap; +import java.util.List; +import java.util.Optional; + +/** + * A {@link HttpRequest} where the scheme is either http or https based on the global Vespa TLS configuration. + * + * @author bjorncs + */ +class TlsAwareHttpRequest extends HttpRequest { + + private final URI rewrittenUri; + private final HttpRequest wrappedRequest; + private final HttpHeaders rewrittenHeaders; + + TlsAwareHttpRequest(HttpRequest wrappedRequest, String userAgent) { + this.wrappedRequest = wrappedRequest; + this.rewrittenUri = rewriteUri(wrappedRequest.uri()); + this.rewrittenHeaders = rewriteHeaders(wrappedRequest, userAgent); + } + + @Override + public Optional<BodyPublisher> bodyPublisher() { + return wrappedRequest.bodyPublisher(); + } + + @Override + public String method() { + return wrappedRequest.method(); + } + + @Override + public Optional<Duration> timeout() { + return wrappedRequest.timeout(); + } + + @Override + public boolean expectContinue() { + return wrappedRequest.expectContinue(); + } + + @Override + public URI uri() { + return rewrittenUri; + } + + @Override + public Optional<HttpClient.Version> version() { + return wrappedRequest.version(); + } + + @Override + public HttpHeaders headers() { + return rewrittenHeaders; + } + + private static URI rewriteUri(URI uri) { + if (!uri.getScheme().equals("http")) { + return uri; + } + String rewrittenScheme = + TransportSecurityUtils.getConfigFile().isPresent() && TransportSecurityUtils.getInsecureMixedMode() != MixedMode.PLAINTEXT_CLIENT_MIXED_SERVER ? + "https" : + "http"; + int port = uri.getPort(); + int rewrittenPort = port != -1 ? port : (rewrittenScheme.equals("http") ? 80 : 443); + try { + return new URI(rewrittenScheme, uri.getUserInfo(), uri.getHost(), rewrittenPort, uri.getPath(), uri.getQuery(), uri.getFragment()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + private static HttpHeaders rewriteHeaders(HttpRequest request, String userAgent) { + HttpHeaders headers = request.headers(); + if (headers.firstValue("User-Agent").isPresent()) { + return headers; + } + HashMap<String, List<String>> rewrittenHeaders = new HashMap<>(headers.map()); + rewrittenHeaders.put("User-Agent", List.of(userAgent)); + return HttpHeaders.of(rewrittenHeaders, (ignored1, ignored2) -> true); + } + + @Override + public String toString() { + return "TlsAwareHttpRequest{" + + "rewrittenUri=" + rewrittenUri + + ", wrappedRequest=" + wrappedRequest + + '}'; + } +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/package-info.java b/security-utils/src/main/java/com/yahoo/security/tls/https/package-info.java new file mode 100644 index 00000000000..43067705fa3 --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/https/package-info.java @@ -0,0 +1,8 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +/** + * @author bjorncs + */ +@ExportPackage +package com.yahoo.security.tls.https; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/security-utils/src/test/java/com/yahoo/security/tls/AutoReloadingX509KeyManagerTest.java b/security-utils/src/test/java/com/yahoo/security/tls/AutoReloadingX509KeyManagerTest.java new file mode 100644 index 00000000000..139d5313074 --- /dev/null +++ b/security-utils/src/test/java/com/yahoo/security/tls/AutoReloadingX509KeyManagerTest.java @@ -0,0 +1,84 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls; + +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.SignatureAlgorithm; +import com.yahoo.security.X509CertificateBuilder; +import com.yahoo.security.X509CertificateUtils; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import javax.security.auth.x500.X500Principal; +import java.io.IOException; +import java.math.BigInteger; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.KeyPair; +import java.security.Principal; +import java.security.cert.X509Certificate; +import java.time.Instant; +import java.util.concurrent.ScheduledExecutorService; + +import static java.time.temporal.ChronoUnit.DAYS; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Mockito.verify; + +/** + * @author bjorncs + */ +public class AutoReloadingX509KeyManagerTest { + private static final X500Principal SUBJECT = new X500Principal("CN=dummy"); + + @Rule + public TemporaryFolder tempDirectory = new TemporaryFolder(); + + @Test + public void crypto_material_is_reloaded_when_scheduler_task_is_executed() throws IOException { + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC); + Path privateKeyFile = tempDirectory.newFile().toPath(); + Files.writeString(privateKeyFile, KeyUtils.toPem(keyPair.getPrivate())); + + Path certificateFile = tempDirectory.newFile().toPath(); + BigInteger serialNumberInitialCertificate = BigInteger.ONE; + X509Certificate initialCertificate = generateCertificate(keyPair, serialNumberInitialCertificate); + Files.writeString(certificateFile, X509CertificateUtils.toPem(initialCertificate)); + + ScheduledExecutorService scheduler = Mockito.mock(ScheduledExecutorService.class); + ArgumentCaptor<Runnable> updaterTaskCaptor = ArgumentCaptor.forClass(Runnable.class); + + AutoReloadingX509KeyManager keyManager = new AutoReloadingX509KeyManager(privateKeyFile, certificateFile, scheduler); + verify(scheduler).scheduleAtFixedRate(updaterTaskCaptor.capture(), anyLong(), anyLong(), any()); + + String[] initialAliases = keyManager.getClientAliases(keyPair.getPublic().getAlgorithm(), new Principal[]{SUBJECT}); + X509Certificate[] certChain = keyManager.getCertificateChain(initialAliases[0]); + assertThat(certChain).hasSize(1); + assertThat(certChain[0].getSerialNumber()).isEqualTo(serialNumberInitialCertificate); + + BigInteger serialNumberUpdatedCertificate = BigInteger.TWO; + X509Certificate updatedCertificate = generateCertificate(keyPair, serialNumberUpdatedCertificate); + Files.writeString(certificateFile, X509CertificateUtils.toPem(updatedCertificate)); + + updaterTaskCaptor.getValue().run(); // run update task in ReloadingX509KeyManager + + String[] updatedAliases = keyManager.getClientAliases(keyPair.getPublic().getAlgorithm(), new Principal[]{SUBJECT}); + X509Certificate[] updatedCertChain = keyManager.getCertificateChain(updatedAliases[0]); + assertThat(updatedCertChain).hasSize(1); + assertThat(updatedCertChain[0].getSerialNumber()).isEqualTo(serialNumberUpdatedCertificate); + } + + private static X509Certificate generateCertificate(KeyPair keyPair, BigInteger serialNumber) { + return X509CertificateBuilder.fromKeypair(keyPair, + SUBJECT, + Instant.EPOCH, + Instant.EPOCH.plus(1, DAYS), + SignatureAlgorithm.SHA256_WITH_ECDSA, + serialNumber) + .build(); + } +}
\ No newline at end of file diff --git a/security-utils/src/test/java/com/yahoo/security/tls/MutableX509KeyManagerTest.java b/security-utils/src/test/java/com/yahoo/security/tls/MutableX509KeyManagerTest.java new file mode 100644 index 00000000000..30e54d3c09d --- /dev/null +++ b/security-utils/src/test/java/com/yahoo/security/tls/MutableX509KeyManagerTest.java @@ -0,0 +1,65 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls; + +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.SignatureAlgorithm; +import com.yahoo.security.X509CertificateBuilder; +import org.junit.Test; + +import javax.security.auth.x500.X500Principal; +import java.math.BigInteger; +import java.security.KeyPair; +import java.security.KeyStore; +import java.security.Principal; +import java.security.cert.X509Certificate; +import java.time.Instant; + +import static java.time.temporal.ChronoUnit.DAYS; +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author bjorncs + */ +public class MutableX509KeyManagerTest { + + private static final X500Principal SUBJECT = new X500Principal("CN=dummy"); + + @Test + public void key_manager_can_be_updated_with_new_certificate() { + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC); + + BigInteger serialNumberInitialCertificate = BigInteger.ONE; + KeyStore initialKeystore = generateKeystore(keyPair, serialNumberInitialCertificate); + + MutableX509KeyManager keyManager = new MutableX509KeyManager(initialKeystore, new char[0]); + + String[] initialAliases = keyManager.getClientAliases(keyPair.getPublic().getAlgorithm(), new Principal[]{SUBJECT}); + assertThat(initialAliases).hasSize(1); + X509Certificate[] certChain = keyManager.getCertificateChain(initialAliases[0]); + assertThat(certChain).hasSize(1); + assertThat(certChain[0].getSerialNumber()).isEqualTo(serialNumberInitialCertificate); + + BigInteger serialNumberUpdatedCertificate = BigInteger.TWO; + KeyStore updatedKeystore = generateKeystore(keyPair, serialNumberUpdatedCertificate); + keyManager.updateKeystore(updatedKeystore, new char[0]); + + String[] updatedAliases = keyManager.getClientAliases(keyPair.getPublic().getAlgorithm(), new Principal[]{SUBJECT}); + assertThat(updatedAliases).hasSize(1); + X509Certificate[] updatedCertChain = keyManager.getCertificateChain(updatedAliases[0]); + assertThat(updatedCertChain).hasSize(1); + assertThat(updatedCertChain[0].getSerialNumber()).isEqualTo(serialNumberUpdatedCertificate); + } + + private static KeyStore generateKeystore(KeyPair keyPair, BigInteger serialNumber) { + X509Certificate certificate = X509CertificateBuilder.fromKeypair( + keyPair, SUBJECT, Instant.EPOCH, Instant.EPOCH.plus(1, DAYS), SignatureAlgorithm.SHA256_WITH_ECDSA, serialNumber) + .build(); + return KeyStoreBuilder.withType(KeyStoreType.PKCS12) + .withKeyEntry("default", keyPair.getPrivate(), certificate) + .build(); + } + +}
\ No newline at end of file diff --git a/security-utils/src/test/java/com/yahoo/security/tls/MutableX509TrustManagerTest.java b/security-utils/src/test/java/com/yahoo/security/tls/MutableX509TrustManagerTest.java new file mode 100644 index 00000000000..4c4ea332818 --- /dev/null +++ b/security-utils/src/test/java/com/yahoo/security/tls/MutableX509TrustManagerTest.java @@ -0,0 +1,59 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls; + +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.SignatureAlgorithm; +import com.yahoo.security.X509CertificateBuilder; +import org.junit.Test; + +import javax.security.auth.x500.X500Principal; +import java.math.BigInteger; +import java.security.KeyPair; +import java.security.KeyStore; +import java.security.cert.X509Certificate; +import java.time.Instant; + +import static java.time.temporal.ChronoUnit.DAYS; +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author bjorncs + */ +public class MutableX509TrustManagerTest { + + @Test + public void key_manager_can_be_updated_with_new_certificate() { + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC); + + X509Certificate initialCertificate = generateCertificate(new X500Principal("CN=issuer1"), keyPair); + KeyStore initialTruststore = generateTruststore(initialCertificate); + + MutableX509TrustManager trustManager = new MutableX509TrustManager(initialTruststore); + + X509Certificate[] initialAcceptedIssuers = trustManager.getAcceptedIssuers(); + assertThat(initialAcceptedIssuers).containsExactly(initialCertificate); + + X509Certificate updatedCertificate = generateCertificate(new X500Principal("CN=issuer2"), keyPair); + KeyStore updatedTruststore = generateTruststore(updatedCertificate); + trustManager.updateTruststore(updatedTruststore); + + X509Certificate[] updatedAcceptedIssuers = trustManager.getAcceptedIssuers(); + assertThat(updatedAcceptedIssuers).containsExactly(updatedCertificate); + } + + private static X509Certificate generateCertificate(X500Principal issuer, KeyPair keyPair) { + return X509CertificateBuilder.fromKeypair( + keyPair, issuer, Instant.EPOCH, Instant.EPOCH.plus(1, DAYS), SignatureAlgorithm.SHA256_WITH_ECDSA, BigInteger.ONE) + .build(); + } + + private static KeyStore generateTruststore(X509Certificate certificate) { + return KeyStoreBuilder.withType(KeyStoreType.PKCS12) + .withCertificateEntry("default", certificate) + .build(); + } + +}
\ No newline at end of file diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/model/ServiceModelCache.java b/service-monitor/src/main/java/com/yahoo/vespa/service/model/ServiceModelCache.java index 7a6f37b2c94..c50f5e6c2d5 100644 --- a/service-monitor/src/main/java/com/yahoo/vespa/service/model/ServiceModelCache.java +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/model/ServiceModelCache.java @@ -46,10 +46,12 @@ public class ServiceModelCache implements Supplier<ServiceModel> { updatePossiblyInProgress = true; } - takeSnapshot(); - - synchronized (updateMonitor) { - updatePossiblyInProgress = false; + try { + takeSnapshot(); + } finally { + synchronized (updateMonitor) { + updatePossiblyInProgress = false; + } } } diff --git a/vespa-athenz/pom.xml b/vespa-athenz/pom.xml index 27b68fbf360..0f23eaed964 100644 --- a/vespa-athenz/pom.xml +++ b/vespa-athenz/pom.xml @@ -114,7 +114,24 @@ <groupId>org.apache.httpcomponents</groupId> <artifactId>httpclient</artifactId> </dependency> - + <dependency> + <groupId>com.amazonaws</groupId> + <artifactId>aws-java-sdk-core</artifactId> + <exclusions> + <exclusion> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-core</artifactId> + </exclusion> + <exclusion> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-databind</artifactId> + </exclusion> + <exclusion> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-annotations</artifactId> + </exclusion> + </exclusions> + </dependency> </dependencies> <build> diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/aws/AwsCredentialsProvider.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/aws/AwsCredentialsProvider.java new file mode 100644 index 00000000000..28f028832b4 --- /dev/null +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/aws/AwsCredentialsProvider.java @@ -0,0 +1,79 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.athenz.client.aws; + +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.BasicSessionCredentials; +import com.yahoo.vespa.athenz.api.AthenzDomain; +import com.yahoo.vespa.athenz.api.AwsRole; +import com.yahoo.vespa.athenz.api.AwsTemporaryCredentials; +import com.yahoo.vespa.athenz.client.zts.DefaultZtsClient; +import com.yahoo.vespa.athenz.client.zts.ZtsClient; +import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider; + +import javax.net.ssl.SSLContext; +import java.net.URI; +import java.time.Duration; +import java.time.Instant; +import java.util.Objects; +import java.util.logging.Logger; + +/** + * Implementation of AWSCredentialsProvider using com.yahoo.vespa.athenz.client.zts.ZtsClient + * + * @author mortent + */ +public class AwsCredentialsProvider implements AWSCredentialsProvider { + + private static final Logger logger = Logger.getLogger(AwsCredentialsProvider.class.getName()); + + private final static Duration MIN_EXPIRY = Duration.ofMinutes(5); + private final AthenzDomain athenzDomain; + private final AwsRole awsRole; + private final ZtsClient ztsClient; + private volatile AwsTemporaryCredentials credentials; + + public AwsCredentialsProvider(ZtsClient ztsClient, AthenzDomain athenzDomain, AwsRole awsRole) { + this.ztsClient = ztsClient; + this.athenzDomain = athenzDomain; + this.awsRole = awsRole; + this.credentials = getAthenzTempCredentials(); + } + + public AwsCredentialsProvider(URI ztsUrl, ServiceIdentityProvider identityProvider, AthenzDomain athenzDomain, AwsRole awsRole) { + this(new DefaultZtsClient(ztsUrl, identityProvider), athenzDomain, awsRole); + } + + public AwsCredentialsProvider(URI ztsUrl, SSLContext sslContext, AthenzDomain athenzDomain, AwsRole awsRole) { + this(new DefaultZtsClient(ztsUrl, null, sslContext), athenzDomain, awsRole); + } + + /** + * Requests temporary credentials from ZTS or return cached credentials + */ + private AwsTemporaryCredentials getAthenzTempCredentials() { + if(shouldRefresh(credentials)) { + this.credentials = ztsClient.getAwsTemporaryCredentials(athenzDomain, awsRole); + } + return credentials; + } + + @Override + public AWSCredentials getCredentials() { + AwsTemporaryCredentials creds = getAthenzTempCredentials(); + return new BasicSessionCredentials(creds.accessKeyId(), creds.secretAccessKey(), creds.sessionToken()); + } + + @Override + public void refresh() { + getAthenzTempCredentials(); + } + + /* + * Checks credential expiration, returns true if it will expipre in the next MIN_EXPIRY minutes + */ + private static boolean shouldRefresh(AwsTemporaryCredentials credentials) { + Instant expiration = credentials.expiration(); + return Objects.isNull(expiration) || expiration.minus(MIN_EXPIRY).isAfter(Instant.now()); + } +} diff --git a/vespa-hadoop/abi-spec.json b/vespa-hadoop/abi-spec.json index 5bbac15f0e5..e3f4dcf272a 100644 --- a/vespa-hadoop/abi-spec.json +++ b/vespa-hadoop/abi-spec.json @@ -1201,6 +1201,8 @@ "public com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.DimensionSizes dimensionSizes()", "public java.util.Map cells()", + "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)", @@ -1245,6 +1247,8 @@ "public java.util.Iterator valueIterator()", "public java.util.Map cells()", "public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", + "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)" @@ -1330,6 +1334,8 @@ "public java.util.Iterator valueIterator()", "public java.util.Map cells()", "public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", + "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)", @@ -1432,6 +1438,8 @@ "public double asDouble()", "public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public abstract com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public abstract com.yahoo.tensor.Tensor remove(java.util.Set)", "public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)", "public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])", "public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)", diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 932513f8a57..c3fe8c5c7ad 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -808,6 +808,8 @@ "public com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.DimensionSizes dimensionSizes()", "public java.util.Map cells()", + "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)", @@ -852,6 +854,8 @@ "public java.util.Iterator valueIterator()", "public java.util.Map cells()", "public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", + "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)" @@ -937,6 +941,8 @@ "public java.util.Iterator valueIterator()", "public java.util.Map cells()", "public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", + "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)", @@ -1039,6 +1045,8 @@ "public double asDouble()", "public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public abstract com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public abstract com.yahoo.tensor.Tensor remove(java.util.Set)", "public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)", "public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])", "public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index fb55b2d5014..38d832d01c2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -13,6 +13,7 @@ import java.util.Map; import java.util.NoSuchElementException; import java.util.Optional; import java.util.Set; +import java.util.function.DoubleBinaryOperator; /** * An indexed (dense) tensor backed by a double array. @@ -190,6 +191,16 @@ public class IndexedTensor implements Tensor { } @Override + public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells) { + throw new IllegalArgumentException("Merge is not supported for indexed tensors"); + } + + @Override + public Tensor remove(Set<TensorAddress> addresses) { + throw new IllegalArgumentException("Remove is not supported for indexed tensors"); + } + + @Override public int hashCode() { return Arrays.hashCode(values); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index ec3020a1a4e..22ceed22d3e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -5,6 +5,8 @@ import com.google.common.collect.ImmutableMap; import java.util.Iterator; import java.util.Map; +import java.util.Set; +import java.util.function.DoubleBinaryOperator; /** * A sparse implementation of a tensor backed by a Map of cells to values. @@ -51,6 +53,38 @@ public class MappedTensor implements Tensor { } @Override + public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) { + + // currently, underlying implementation disallows multiple entries with the same key + + Tensor.Builder builder = Tensor.Builder.of(type()); + for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) { + TensorAddress address = cell.getKey(); + double value = cell.getValue(); + builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value); + } + for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { + if ( ! cells.containsKey(addCell.getKey())) { + builder.cell(addCell.getKey(), addCell.getValue()); + } + } + return builder.build(); + } + + @Override + public Tensor remove(Set<TensorAddress> addresses) { + Tensor.Builder builder = Tensor.Builder.of(type()); + for (Iterator<Tensor.Cell> i = cellIterator(); i.hasNext(); ) { + Tensor.Cell cell = i.next(); + TensorAddress address = cell.getKey(); + if ( ! addresses.contains(address)) { + builder.cell(address, cell.getValue()); + } + } + return builder.build(); + } + + @Override public int hashCode() { return cells.hashCode(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 17e33c58a13..08878edeb83 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -9,6 +9,8 @@ import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.function.DoubleBinaryOperator; import java.util.stream.Collectors; /** @@ -70,13 +72,17 @@ public class MixedTensor implements Tensor { return cells.iterator(); } + private Iterable<Cell> cellIterable() { + return this::cellIterator; + } + /** * Returns an iterator over the values of this tensor. * The iteration order is the same as for cellIterator. */ @Override public Iterator<Double> valueIterator() { - return new Iterator<Double>() { + return new Iterator<>() { Iterator<Cell> cellIterator = cellIterator(); @Override public boolean hasNext() { @@ -108,6 +114,38 @@ public class MixedTensor implements Tensor { } @Override + public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) { + Tensor.Builder builder = Tensor.Builder.of(type()); + for (Cell cell : cellIterable()) { + TensorAddress address = cell.getKey(); + double value = cell.getValue(); + builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value); + } + for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { + builder.cell(addCell.getKey(), addCell.getValue()); + } + return builder.build(); + } + + @Override + public Tensor remove(Set<TensorAddress> addresses) { + Tensor.Builder builder = Tensor.Builder.of(type()); + + // iterate through all sparse addresses referencing a dense subspace + for (Map.Entry<TensorAddress, Long> entry : index.sparseMap.entrySet()) { + TensorAddress sparsePartialAddress = entry.getKey(); + if ( ! addresses.contains(sparsePartialAddress)) { // assumption: addresses only contain the sparse part + long offset = entry.getValue(); + for (int i = 0; i < index.denseSubspaceSize; ++i) { + Cell cell = cells.get((int)offset + i); + builder.cell(cell.getKey(), cell.getValue()); + } + } + } + return builder.build(); + } + + @Override public int hashCode() { return cells.hashCode(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 8002990e5c6..eb16801c306 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -25,6 +25,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; import java.util.function.Function; @@ -113,6 +114,29 @@ public interface Tensor { return builder.build(); } + /** + * Returns a new tensor where existing cells in this tensor have been + * modified according to the given operation and cells in the given map. + * In contrast to {@link #modify}, previously non-existing cells are added + * to this tensor. Only valid for sparse or mixed tensors. + * + * @param op how to update overlapping cells + * @param cells cells to merge with this tensor + * @return a new tensor where this tensor is merged with the other + */ + Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells); + + /** + * Returns a new tensor where existing cells in this tensor have been + * removed according to the given set of addresses. Only valid for sparse + * or mixed tensors. For mixed tensors, addresses are assumed to only + * contain the sparse dimensions, as the entire dense subspace is removed. + * + * @param addresses list of addresses to remove + * @return a new tensor where cells have been removed + */ + Tensor remove(Set<TensorAddress> addresses); + // ----------------- Primitive tensor functions default Tensor map(DoubleUnaryOperator mapper) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 2c9eefbd130..02d16e6f3e4 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -151,12 +151,106 @@ public class TensorTestCase { Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:1, {x:0,y:1}:2}"), Tensor.from("tensor(x[1],y[3])", "{}"), Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:0,{x:0,y:1}:0}")); + assertTensorModify((left, right) -> left * right, + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:6}")); + } + + @Test + public void testTensorMerge() { + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:2}:3}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}")); + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:3}")); + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3,{x:0,y:2}:4}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}")); + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:5}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:5}")); + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}")); + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:2}:3}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:0,{x:0,y:2}:3}")); // notice difference with sparse case - y is dense dimension here with default value 0.0 + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:0}")); + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:1}:3,{x:0,y:2}:4}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:4}")); + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:5}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:5}")); + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y[3])", "{}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}")); + } + + @Test + public void testTensorRemove() { + assertTensorRemove( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:1}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2}")); + assertTensorRemove( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:1}"), + Tensor.from("tensor(x{},y{})", "{}")); + assertTensorRemove( + Tensor.from("tensor(x{},y{})", "{}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1}"), + Tensor.from("tensor(x{},y{})", "{}")); + assertTensorRemove( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y{})", "{}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}")); + assertTensorRemove( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2, {x:0,y:1}:3}"), + Tensor.from("tensor(x{})", "{{x:0}:1}"), // notice update is without dense dimension + Tensor.from("tensor(x{},y[3])", "{}")); + assertTensorRemove( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:1,y:0}:2}"), + Tensor.from("tensor(x{})", "{{x:0}:1}"), + Tensor.from("tensor(x{},y[3])", "{{x:1,y:0}:2,{x:1,y:1}:0,{x:1,y:2}:0}")); + assertTensorRemove( + Tensor.from("tensor(x{},y[3])", "{}"), + Tensor.from("tensor(x{})", "{{x:0}:1}"), + Tensor.from("tensor(x{},y[3])", "{}")); + assertTensorRemove( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}"), + Tensor.from("tensor(x{})", "{}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}")); } private void assertTensorModify(DoubleBinaryOperator op, Tensor init, Tensor update, Tensor expected) { assertEquals(expected, init.modify(op, update.cells())); } + private void assertTensorMerge(Tensor init, Tensor update, Tensor expected) { + DoubleBinaryOperator op = (left, right) -> right; + assertEquals(expected, init.merge(op, update.cells())); + } + + private void assertTensorRemove(Tensor init, Tensor update, Tensor expected) { + assertEquals(expected, init.remove(update.cells().keySet())); + } + + private double dotProduct(Tensor tensor, List<Tensor> tensors) { double sum = 0; TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), |