summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/AbstractService.java4
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ConfigProxy.java5
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ConfigSentinel.java7
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/HostResource.java60
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/HostSystem.java20
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/Logd.java5
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/NetworkPortRequestor.java53
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/Service.java35
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java4
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/admin/Configserver.java5
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/admin/LogForwarder.java5
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/admin/Logserver.java5
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/admin/Slobrok.java5
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/Container.java30
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/content/ContentNode.java4
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/generic/service/Service.java5
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/search/Dispatch.java6
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/search/SearchNode.java5
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/search/TransactionLogServer.java5
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/HostResourceTest.java15
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ConfigValueChangeValidatorTest.java3
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/StartupCommandChangeValidatorTest.java3
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/test/ApiService.java2
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/test/ModelAmendingTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/test/ParentService.java3
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/test/SimpleService.java5
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/utils/FileSenderTest.java3
-rw-r--r--config-provisioning/abi-spec.json46
-rw-r--r--config-provisioning/src/main/java/com/yahoo/config/provision/AllocatedHosts.java7
-rw-r--r--config-provisioning/src/main/java/com/yahoo/config/provision/HostSpec.java14
-rw-r--r--config-provisioning/src/main/java/com/yahoo/config/provision/NetworkPorts.java55
-rw-r--r--config-provisioning/src/main/java/com/yahoo/config/provision/NetworkPortsSerializer.java56
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/provision/ProvisionerAdapter.java1
-rw-r--r--configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionTest.java14
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/JobType.java4
-rw-r--r--document/abi-spec.json3
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java14
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java37
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java27
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java14
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java26
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java2
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java38
-rw-r--r--document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java63
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java92
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java8
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java17
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java3
-rw-r--r--document/src/tests/documentupdatetestcase.cpp11
-rw-r--r--document/src/vespa/document/update/tensor_remove_update.cpp30
-rw-r--r--document/src/vespa/document/update/valueupdate.h3
-rw-r--r--documentapi/CMakeLists.txt1
-rw-r--r--documentapi/src/vespa/binref/.gitignore3
-rw-r--r--documentapi/src/vespa/binref/CMakeLists.txt1
-rw-r--r--jrt/src/com/yahoo/jrt/TlsCryptoEngine.java1
l---------jrt_test/src/binref/testrun.sh1
l---------lowercasing_test/src/binref/testrun.sh1
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java1
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintenance.java3
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Allocation.java29
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/NodeSerializer.java9
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java26
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java14
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/NodesResponse.java18
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/ContainerConfig.java2
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/SerializationTest.java29
-rw-r--r--parent/pom.xml6
-rw-r--r--searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp67
-rw-r--r--searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp3
-rw-r--r--security-utils/pom.xml5
-rw-r--r--security-utils/src/main/java/com/yahoo/security/KeyStoreType.java2
-rw-r--r--security-utils/src/main/java/com/yahoo/security/SslContextBuilder.java66
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/AutoReloadingX509KeyManager.java150
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/DefaultTlsContext.java90
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/KeyManagerUtils.java49
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/MutableX509KeyManager.java106
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/MutableX509TrustManager.java70
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/ReloadingTlsContext.java98
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/TlsContext.java7
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/TrustManagerUtils.java50
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManager.java22
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManagersFactory.java8
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClient.java101
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClientBuilder.java97
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpRequest.java103
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/https/package-info.java8
-rw-r--r--security-utils/src/test/java/com/yahoo/security/tls/AutoReloadingX509KeyManagerTest.java84
-rw-r--r--security-utils/src/test/java/com/yahoo/security/tls/MutableX509KeyManagerTest.java65
-rw-r--r--security-utils/src/test/java/com/yahoo/security/tls/MutableX509TrustManagerTest.java59
-rw-r--r--service-monitor/src/main/java/com/yahoo/vespa/service/model/ServiceModelCache.java10
-rw-r--r--vespa-athenz/pom.xml19
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/aws/AwsCredentialsProvider.java79
-rw-r--r--vespa-hadoop/abi-spec.json8
-rw-r--r--vespajlib/abi-spec.json8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java34
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java40
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java24
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java94
100 files changed, 2319 insertions, 354 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/AbstractService.java b/config-model/src/main/java/com/yahoo/vespa/model/AbstractService.java
index 60a49598c42..daa237d90c1 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/AbstractService.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/AbstractService.java
@@ -482,8 +482,8 @@ public abstract class AbstractService extends AbstractConfigProducer<AbstractCon
* Must be done this way since the system test framework
* currently uses the first port as container http port.
*/
- public void reservePortPrepended(int port) {
- hostResource.reservePort(this, port);
+ public void reservePortPrepended(int port, String suffix) {
+ hostResource.reservePort(this, port, suffix);
ports.add(0, port);
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ConfigProxy.java b/config-model/src/main/java/com/yahoo/vespa/model/ConfigProxy.java
index c540a5f62d2..7ab28faa434 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ConfigProxy.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ConfigProxy.java
@@ -47,6 +47,11 @@ public class ConfigProxy extends AbstractService {
*/
public int getPortCount() { return 1; }
+ @Override
+ public String[] getPortSuffixes() {
+ return new String[]{"rpc"};
+ }
+
/**
* The config proxy is not started by the config system!
*/
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ConfigSentinel.java b/config-model/src/main/java/com/yahoo/vespa/model/ConfigSentinel.java
index cd92f27cc50..1b5c5e4a579 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ConfigSentinel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ConfigSentinel.java
@@ -26,7 +26,7 @@ public class ConfigSentinel extends AbstractService implements SentinelConfig.Pr
super(host, "sentinel");
this.applicationId = applicationId;
this.zone = zone;
- portsMeta.on(0).tag("rpc").tag("notyet");
+ portsMeta.on(0).tag("rpc").tag("admin");
portsMeta.on(1).tag("telnet").tag("interactive").tag("http").tag("state");
setProp("clustertype", "hosts");
setProp("clustername", "admin");
@@ -48,6 +48,11 @@ public class ConfigSentinel extends AbstractService implements SentinelConfig.Pr
public int getPortCount() { return 2; }
@Override
+ public String[] getPortSuffixes() {
+ return new String[]{ "rpc", "http" };
+ }
+
+ @Override
public int getHealthPort() {return getRelativePort(1); }
/**
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/HostResource.java b/config-model/src/main/java/com/yahoo/vespa/model/HostResource.java
index a27b33173ee..3a8cc5c2e4c 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/HostResource.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/HostResource.java
@@ -6,6 +6,7 @@ import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.api.HostInfo;
import com.yahoo.config.provision.ClusterMembership;
import com.yahoo.config.provision.Flavor;
+import com.yahoo.config.provision.NetworkPorts;
import javax.annotation.Nullable;
import java.util.ArrayList;
@@ -36,10 +37,23 @@ public class HostResource implements Comparable<HostResource> {
/** Map from "sentinel name" to service */
private final Map<String, Service> services = new LinkedHashMap<>();
- private final Map<Integer, Service> portDB = new LinkedHashMap<>();
+ private final Map<Integer, NetworkPortRequestor> portDB = new LinkedHashMap<>();
private int allocatedPorts = 0;
+ static class PortReservation {
+ int gotPort;
+ NetworkPortRequestor service;
+ String suffix;
+ PortReservation(int port, NetworkPortRequestor svc, String suf) {
+ this.gotPort = port;
+ this.service = svc;
+ this.suffix = suf;
+ }
+ }
+
+ private List<PortReservation> portReservations = new ArrayList<>();
+
private Set<ClusterMembership> clusterMemberships = new LinkedHashSet<>();
// Empty for self-hosted Vespa.
@@ -48,6 +62,12 @@ public class HostResource implements Comparable<HostResource> {
/** The current Vespa version running on this node, or empty if not known */
private final Optional<Version> version;
+ private Optional<NetworkPorts> networkPortsList = Optional.empty();
+
+ public Optional<NetworkPorts> networkPorts() { return networkPortsList; }
+
+ public void addNetworkPorts(NetworkPorts ports) { this.networkPortsList = Optional.of(ports); }
+
/**
* Create a new {@link HostResource} bound to a specific {@link com.yahoo.vespa.model.Host}.
*
@@ -108,22 +128,22 @@ public class HostResource implements Comparable<HostResource> {
return ports;
}
- private List<Integer> allocatePorts(DeployLogger deployLogger, AbstractService service, int wantedPort) {
+ private List<Integer> allocatePorts(DeployLogger deployLogger, NetworkPortRequestor service, int wantedPort) {
List<Integer> ports = new ArrayList<>();
if (service.getPortCount() < 1)
return ports;
int serviceBasePort = BASE_PORT + allocatedPorts;
if (wantedPort > 0) {
- if (service.getPortCount() < 1) {
- throw new RuntimeException(service + " wants baseport " + wantedPort +
- ", but it has not reserved any ports, so it cannot name a desired baseport.");
- }
if (service.requiresWantedPort() || canUseWantedPort(deployLogger, service, wantedPort, serviceBasePort))
serviceBasePort = wantedPort;
}
+ String[] suffixes = service.getPortSuffixes();
+ if (suffixes.length != service.getPortCount()) {
+ throw new IllegalArgumentException("service "+service+" had "+suffixes.length+" port suffixes, but port count "+service.getPortCount()+", mismatch");
+ }
- reservePort(service, serviceBasePort);
+ reservePort(service, serviceBasePort, suffixes[0]);
ports.add(serviceBasePort);
int remainingPortsStart = service.requiresConsecutivePorts() ?
@@ -131,17 +151,30 @@ public class HostResource implements Comparable<HostResource> {
BASE_PORT + allocatedPorts;
for (int i = 0; i < service.getPortCount() - 1; i++) {
int port = remainingPortsStart + i;
- reservePort(service, port);
+ reservePort(service, port, suffixes[i+1]);
ports.add(port);
}
+ if (suffixes.length != service.getPortCount()) {
+ throw new IllegalArgumentException("service "+service+" had "+suffixes.length+" port suffixes, but port count "+service.getPortCount()+", mismatch");
+ }
return ports;
}
- private boolean canUseWantedPort(DeployLogger deployLogger, AbstractService service, int wantedPort, int serviceBasePort) {
+ public void flushPortReservations() {
+ List<NetworkPorts.Allocation> list = new ArrayList<>();
+ for (PortReservation pr : portReservations) {
+ String servType = pr.service.getServiceType();
+ String configId = pr.service.getConfigId();
+ list.add(new NetworkPorts.Allocation(pr.gotPort, servType, configId, pr.suffix));
+ }
+ this.networkPortsList = Optional.of(new NetworkPorts(list));
+ }
+
+ private boolean canUseWantedPort(DeployLogger deployLogger, NetworkPortRequestor service, int wantedPort, int serviceBasePort) {
for (int i = 0; i < service.getPortCount(); i++) {
int port = wantedPort + i;
if (portDB.containsKey(port)) {
- AbstractService s = (AbstractService)portDB.get(port);
+ NetworkPortRequestor s = portDB.get(port);
deployLogger.log(Level.WARNING, service.getServiceName() +" cannot reserve port " + port + " on " +
this + ": Already reserved for " + s.getServiceName() +
". Using default port range from " + serviceBasePort);
@@ -159,7 +192,7 @@ public class HostResource implements Comparable<HostResource> {
* @param service the service that wishes to reserve the port.
* @param port the port to be reserved.
*/
- void reservePort(AbstractService service, int port) {
+ void reservePort(NetworkPortRequestor service, int port, String suffix) {
if (portDB.containsKey(port)) {
portAlreadyReserved(service, port);
} else {
@@ -170,6 +203,7 @@ public class HostResource implements Comparable<HostResource> {
}
}
portDB.put(port, service);
+ portReservations.add(new PortReservation(port, service, suffix));
}
}
@@ -178,8 +212,8 @@ public class HostResource implements Comparable<HostResource> {
port < BASE_PORT + MAX_PORTS;
}
- private void portAlreadyReserved(AbstractService service, int port) {
- AbstractService otherService = (AbstractService)portDB.get(port);
+ private void portAlreadyReserved(NetworkPortRequestor service, int port) {
+ NetworkPortRequestor otherService = portDB.get(port);
int nextAvailablePort = nextAvailableBaseport(service.getPortCount());
if (nextAvailablePort == 0) {
noMoreAvailablePorts();
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/HostSystem.java b/config-model/src/main/java/com/yahoo/vespa/model/HostSystem.java
index fdfe4f01790..a1b030ffc61 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/HostSystem.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/HostSystem.java
@@ -8,6 +8,7 @@ import com.yahoo.config.provision.Capacity;
import com.yahoo.config.provision.ClusterMembership;
import com.yahoo.config.provision.ClusterSpec;
import com.yahoo.config.provision.HostSpec;
+import com.yahoo.config.provision.NetworkPorts;
import com.yahoo.config.provision.ProvisionLogger;
import java.net.UnknownHostException;
@@ -126,9 +127,11 @@ public class HostSystem extends AbstractConfigProducer<Host> {
private HostResource addNewHost(HostSpec hostSpec) {
Host host = Host.createHost(this, hostSpec.hostname());
- HostResource hostResource = new HostResource(host, hostSpec.version());
+ HostResource hostResource = new HostResource(host,
+ hostSpec.version());
hostResource.setFlavor(hostSpec.flavor());
hostSpec.membership().ifPresent(hostResource::addClusterMembership);
+ hostSpec.networkPorts().ifPresent(hostResource::addNetworkPorts);
hostname2host.put(host.getHostname(), hostResource);
log.log(DEBUG, () -> "Added new host resource for " + host.getHostname() + " with flavor " + hostResource.getFlavor());
return hostResource;
@@ -141,6 +144,19 @@ public class HostSystem extends AbstractConfigProducer<Host> {
.collect(Collectors.toList());
}
+ public void dumpPortAllocations() {
+ for (HostResource hr : getHosts()) {
+ hr.flushPortReservations();
+/*
+ System.out.println("port allocations for: "+hr.getHostname());
+ NetworkPorts ports = hr.networkPorts().get();
+ for (NetworkPorts.Allocation allocation: ports.allocations()) {
+ System.out.println("port="+allocation.port+" [type="+allocation.serviceType+", cfgId="+allocation.configId+", suffix="+allocation.portSuffix+"]");
+ }
+*/
+ }
+ }
+
public Map<HostResource, ClusterMembership> allocateHosts(ClusterSpec cluster, Capacity capacity, int groups, DeployLogger logger) {
List<HostSpec> allocatedHosts = provisioner.prepare(cluster, capacity, groups, new ProvisionDeployLogger(logger));
// TODO: Even if HostResource owns a set of memberships, we need to return a map because the caller needs the current membership.
@@ -177,7 +193,7 @@ public class HostSystem extends AbstractConfigProducer<Host> {
Set<HostSpec> getHostSpecs() {
return getHosts().stream()
.map(host -> new HostSpec(host.getHostname(), Collections.emptyList(),
- host.getFlavor(), host.primaryClusterMembership(), host.version()))
+ host.getFlavor(), host.primaryClusterMembership(), host.version(), host.networkPorts()))
.collect(Collectors.toCollection(LinkedHashSet::new));
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/Logd.java b/config-model/src/main/java/com/yahoo/vespa/model/Logd.java
index 0f7418582a3..3c7f1ba6cfa 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/Logd.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/Logd.java
@@ -32,6 +32,11 @@ public class Logd
*/
public int getPortCount() { return 1; }
+ @Override
+ public String[] getPortSuffixes() {
+ return new String[]{"http"};
+ }
+
/** Returns the desired base port for this service. */
public int getWantedPort() { return 19089; }
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/NetworkPortRequestor.java b/config-model/src/main/java/com/yahoo/vespa/model/NetworkPortRequestor.java
new file mode 100644
index 00000000000..52319f71810
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/NetworkPortRequestor.java
@@ -0,0 +1,53 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.vespa.model;
+
+/**
+ * Interface implemented by services using network ports, identifying its requirements.
+ * @author arnej
+ */
+public interface NetworkPortRequestor {
+
+ /** Returns the type of service */
+ String getServiceType();
+
+ /** Returns the name that identifies this service for the config-sentinel */
+ String getServiceName();
+
+ /** Returns the config id */
+ String getConfigId();
+
+ /**
+ * Returns the desired base port for this service, or '0' if this
+ * service should use the default port allocation mechanism.
+ *
+ * @return The desired base port for this service.
+ */
+ default int getWantedPort() { return 0; }
+
+ /** Returns the number of ports needed by this service. */
+ int getPortCount();
+
+ /**
+ * Returns true if the desired base port (returned by
+ * getWantedPort()) for this service is the only allowed base
+ * port.
+ *
+ * @return true if this Service requires the wanted base port.
+ */
+ default boolean requiresWantedPort() { return false; }
+
+ /**
+ * Override if the services does not require consecutive port numbers. I.e. if any ports
+ * in addition to the baseport should be allocated from Vespa's default port range.
+ *
+ * @return true by default
+ */
+ default boolean requiresConsecutivePorts() { return true; }
+
+ /**
+ * Return names for each port requested.
+ * The size of the returned array must be equal to getPortCount().
+ **/
+ String[] getPortSuffixes();
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/Service.java b/config-model/src/main/java/com/yahoo/vespa/model/Service.java
index d5d33a08b5d..0af4355764c 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/Service.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/Service.java
@@ -11,7 +11,7 @@ import java.util.Optional;
*
* @author gjoranv
*/
-public interface Service extends ConfigProducer {
+public interface Service extends ConfigProducer, NetworkPortRequestor {
/**
* Services that should be started by config-sentinel must return
@@ -44,39 +44,6 @@ public interface Service extends ConfigProducer {
boolean getAutorestartFlag();
/**
- * Returns the type of service. E.g. the class-name without the
- * package prefix.
- */
- String getServiceType();
-
- /**
- * Returns the name that identifies this service for the config-sentinel.
- */
- String getServiceName();
-
- /**
- * Returns the desired base port for this service, or '0' if this
- * service should use the default port allocation mechanism.
- *
- * @return The desired base port for this service.
- */
- int getWantedPort();
-
- /**
- * Returns true if the desired base port (returned by
- * getWantedPort()) for this service is the only allowed base
- * port.
- *
- * @return true if this Service requires the wanted base port.
- */
- boolean requiresWantedPort();
-
- /**
- * Returns the number of ports needed by this service.
- */
- int getPortCount();
-
- /**
* Returns a PortsMeta object, giving access to more information
* about the different ports of this service.
*/
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index 20089dc3980..d954c69d144 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -172,7 +172,6 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
if (complete) { // create a a completed, frozen model
configModelRepo.readConfigModels(deployState, this, builder, root, configModelRegistry);
addServiceClusters(deployState, builder);
- this.allocatedHosts = AllocatedHosts.withHosts(hostSystem.getHostSpecs()); // must happen after the two lines above
setupRouting(deployState);
this.fileDistributor = root.getFileDistributionConfigProducer().getFileDistributor();
getAdmin().addPerHostServices(hostSystem.getHosts(), deployState);
@@ -180,6 +179,9 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
root.prepare(configModelRepo);
configModelRepo.prepareConfigModels(deployState);
validateWrapExceptions();
+ hostSystem.dumpPortAllocations();
+ // must happen after stuff above
+ this.allocatedHosts = AllocatedHosts.withHosts(hostSystem.getHostSpecs());
}
else { // create a model with no services instantiated and the given file distributor
this.allocatedHosts = AllocatedHosts.withHosts(hostSystem.getHostSpecs());
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/Configserver.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/Configserver.java
index 2a32549b6bf..a2839ec0fb6 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/admin/Configserver.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/Configserver.java
@@ -51,6 +51,11 @@ public class Configserver extends AbstractService {
*/
public int getPortCount() { return 2; }
+ @Override
+ public String[] getPortSuffixes() {
+ return new String[]{ "rpc", "http" };
+ }
+
/**
* The configserver is not started by the config system!
*/
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/LogForwarder.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/LogForwarder.java
index 2693a4c7409..d766507c75f 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/admin/LogForwarder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/LogForwarder.java
@@ -55,6 +55,11 @@ public class LogForwarder extends AbstractService implements LogforwarderConfig.
*/
public int getPortCount() { return 0; }
+ @Override
+ public String[] getPortSuffixes() {
+ return null;
+ }
+
/**
* @return The command used to start LogForwarder
*/
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/Logserver.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/Logserver.java
index c354445b690..4dcbfb5b3c3 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/admin/Logserver.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/Logserver.java
@@ -70,4 +70,9 @@ public class Logserver extends AbstractService {
return 4;
}
+ @Override
+ public String[] getPortSuffixes() {
+ return new String[]{ "unused", "logtp", "last.errors", "replicator" };
+ }
+
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/Slobrok.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/Slobrok.java
index 12a0d35de5e..99738c13d4a 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/admin/Slobrok.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/Slobrok.java
@@ -56,6 +56,11 @@ public class Slobrok extends AbstractService implements StateserverConfig.Produc
return 2;
}
+ @Override
+ public String[] getPortSuffixes() {
+ return new String[] { "rpc", "http" };
+ }
+
/**
* @return The port on which this slobrok should respond, as a String.
*/
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java
index 88d66f0c608..c7cc04faf95 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java
@@ -93,7 +93,6 @@ public class DomAdminV4Builder extends DomAdminBuilderBase {
}
private NodesSpecification createNodesSpecificationForLogserver() {
- // TODO: Enable for main system as well
DeployState deployState = context.getDeployState();
if (deployState.getProperties().useDedicatedNodeForLogserver() &&
context.getApplicationType() == ConfigModelContext.ApplicationType.DEFAULT &&
@@ -124,6 +123,7 @@ public class DomAdminV4Builder extends DomAdminBuilderBase {
logServerCluster.addContainer(container);
admin.addAndInitializeService(deployState.getDeployLogger(), hostResource, container);
admin.setLogserverContainerCluster(logServerCluster);
+ context.getConfigModelRepoAdder().add(logserverClusterModel);
}
private void addLogHandler(ContainerCluster cluster) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/Container.java b/config-model/src/main/java/com/yahoo/vespa/model/container/Container.java
index dc962ed5931..f61fc3d4df8 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/Container.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/Container.java
@@ -174,7 +174,7 @@ public class Container extends AbstractService implements
private void reserveHttpPortsPrepended() {
if (getHttp().getHttpServer() != null) {
for (ConnectorFactory connectorFactory : getHttp().getHttpServer().getConnectorFactories()) {
- reservePortPrepended(getPort(connectorFactory));
+ reservePortPrepended(getPort(connectorFactory), "http/" + connectorFactory.getName());
}
}
}
@@ -240,6 +240,34 @@ public class Container extends AbstractService implements
return httpPorts + rpcPorts;
}
+ @Override
+ public String[] getPortSuffixes() {
+ // TODO clean up this mess
+ int n = getPortCount();
+ String[] suffixes = new String[n];
+ int off = 0;
+ int httpPorts = (getHttp() != null) ? 0 : numHttpServerPorts;
+ if (httpPorts > 0) {
+ suffixes[off++] = "http";
+ }
+ for (int i = 1; i < httpPorts; i++) {
+ suffixes[off++] = "http/" + i;
+ }
+ int rpcPorts = (rpcServerEnabled()) ? numRpcServerPorts : 0;
+ if (rpcPorts > 0) {
+ suffixes[off++] = "messaging";
+ }
+ if (rpcPorts > 1) {
+ suffixes[off++] = "rpc";
+ }
+ while (off < n) {
+ suffixes[off] = "unused/" + off;
+ ++off;
+ }
+ assert (off == n);
+ return suffixes;
+ }
+
/**
* @return the actual search port
* TODO: Remove. Use {@link #getPortsMeta()} and check tags in conjunction with {@link #getRelativePort(int)}.
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/ContentNode.java b/config-model/src/main/java/com/yahoo/vespa/model/content/ContentNode.java
index c75421c9636..dc9372c463b 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/content/ContentNode.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/content/ContentNode.java
@@ -57,6 +57,10 @@ public abstract class ContentNode extends AbstractService
@Override
public int getPortCount() { return 3; }
+ @Override
+ public String[] getPortSuffixes() {
+ return new String[] { "messaging", "rpc", "http" };
+ }
@Override
public void getConfig(StorCommunicationmanagerConfig.Builder builder) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/generic/service/Service.java b/config-model/src/main/java/com/yahoo/vespa/model/generic/service/Service.java
index 9ccf5103175..0d30bade53c 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/generic/service/Service.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/generic/service/Service.java
@@ -24,6 +24,11 @@ public class Service extends AbstractService {
}
@Override
+ public String[] getPortSuffixes() {
+ return null;
+ }
+
+ @Override
public String getStartupCommand() {
return ((ServiceCluster) getParent()).getCommand();
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/Dispatch.java b/config-model/src/main/java/com/yahoo/vespa/model/search/Dispatch.java
index b9c937b4a4c..9b4fe93d6ea 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/search/Dispatch.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/search/Dispatch.java
@@ -234,4 +234,10 @@ public class Dispatch extends AbstractService implements SearchInterface,
* @return the number of ports needed
*/
public int getPortCount() { return 3; }
+
+ @Override
+ public String[] getPortSuffixes() {
+ return new String[]{ "rpc", "fs4", "health" };
+ }
+
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/SearchNode.java b/config-model/src/main/java/com/yahoo/vespa/model/search/SearchNode.java
index 5be51310504..934184d5972 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/search/SearchNode.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/search/SearchNode.java
@@ -160,6 +160,11 @@ public class SearchNode extends AbstractService implements
return 5;
}
+ @Override
+ public String[] getPortSuffixes() {
+ return new String[] { "rpc", "fs4", "future/4", "unused/3", "health" };
+ }
+
/**
* Returns the RPC port used by this searchnode.
*
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/TransactionLogServer.java b/config-model/src/main/java/com/yahoo/vespa/model/search/TransactionLogServer.java
index 61cac8afb91..c42579085a5 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/search/TransactionLogServer.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/search/TransactionLogServer.java
@@ -38,6 +38,11 @@ public class TransactionLogServer extends AbstractService {
return 1;
}
+ @Override
+ public String[] getPortSuffixes() {
+ return new String[]{"tls"};
+ }
+
/**
* Returns the port used by the TLS.
*
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/HostResourceTest.java b/config-model/src/test/java/com/yahoo/vespa/model/HostResourceTest.java
index abf4ec02a3e..d16bbe72a95 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/HostResourceTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/HostResourceTest.java
@@ -37,7 +37,7 @@ public class HostResourceTest {
public void next_available_baseport_is_BASE_PORT_plus_one_when_one_port_has_been_reserved() {
MockRoot root = new MockRoot();
HostResource host = mockHostResource(root);
- host.reservePort(new TestService(1), HostResource.BASE_PORT);
+ host.reservePort(new TestService(1), HostResource.BASE_PORT, "foo");
assertThat(host.nextAvailableBaseport(1), is(HostResource.BASE_PORT + 1));
}
@@ -47,12 +47,12 @@ public class HostResourceTest {
HostResource host = mockHostResource(root);
for (int p = HostResource.BASE_PORT; p < HostResource.BASE_PORT + HostResource.MAX_PORTS; p += 2) {
- host.reservePort(new TestService(1), p);
+ host.reservePort(new TestService(1), p, "foo");
}
assertThat(host.nextAvailableBaseport(2), is(0));
try {
- host.reservePort(new TestService(2), HostResource.BASE_PORT);
+ host.reservePort(new TestService(2), HostResource.BASE_PORT, "bar");
} catch (RuntimeException e) {
assertThat(e.getMessage(), containsString("Too many ports are reserved"));
}
@@ -181,5 +181,14 @@ public class HostResourceTest {
@Override
public int getPortCount() { return portCount; }
+
+ @Override
+ public String[] getPortSuffixes() {
+ String[] suffixes = new String[portCount];
+ for (int i = 0; i < portCount; i++) {
+ suffixes[i] = "generic." + i;
+ }
+ return suffixes;
+ }
}
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ConfigValueChangeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ConfigValueChangeValidatorTest.java
index 2456113f40d..f13f53e8648 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ConfigValueChangeValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ConfigValueChangeValidatorTest.java
@@ -246,6 +246,9 @@ public class ConfigValueChangeValidatorTest {
public int getPortCount() {
return 0;
}
+
+ @Override
+ public String[] getPortSuffixes() { return null; }
}
private static class SimpleConfigProducer extends AbstractConfigProducer<AbstractConfigProducer<?>>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/StartupCommandChangeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/StartupCommandChangeValidatorTest.java
index 4f6a1ddf7b3..2b04b026ee7 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/StartupCommandChangeValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/StartupCommandChangeValidatorTest.java
@@ -75,5 +75,8 @@ public class StartupCommandChangeValidatorTest {
public int getPortCount() {
return 0;
}
+
+ @Override
+ public String[] getPortSuffixes() { return null; }
}
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/test/ApiService.java b/config-model/src/test/java/com/yahoo/vespa/model/test/ApiService.java
index 28c56e1e45f..6f06eb3e482 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/test/ApiService.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/test/ApiService.java
@@ -36,4 +36,6 @@ public class ApiService extends AbstractService implements com.yahoo.test.Standa
public int getPortCount() { return 0; }
+ @Override
+ public String[] getPortSuffixes() { return null; }
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/test/ModelAmendingTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/test/ModelAmendingTestCase.java
index 57b0606457d..82da14f0d29 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/test/ModelAmendingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/test/ModelAmendingTestCase.java
@@ -129,6 +129,8 @@ public class ModelAmendingTestCase {
return 0;
}
+ @Override
+ public String[] getPortSuffixes() { return null; }
}
public static class AdminModelAmender extends ConfigModel {
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/test/ParentService.java b/config-model/src/test/java/com/yahoo/vespa/model/test/ParentService.java
index 325cc78a361..c7559c68592 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/test/ParentService.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/test/ParentService.java
@@ -58,4 +58,7 @@ public class ParentService extends AbstractService implements com.yahoo.test.Sta
}
public int getPortCount() { return 0; }
+
+ @Override
+ public String[] getPortSuffixes() { return null; }
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/test/SimpleService.java b/config-model/src/test/java/com/yahoo/vespa/model/test/SimpleService.java
index 0037e40a20f..a38916463c4 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/test/SimpleService.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/test/SimpleService.java
@@ -38,6 +38,11 @@ public class SimpleService extends AbstractService implements com.yahoo.test.Sta
public int getWantedPort(){ return 10000; }
public int getPortCount() { return 5; }
+ @Override
+ public String[] getPortSuffixes() {
+ return new String[]{ "a", "b", "c", "d", "e" };
+ }
+
// Make sure this service is listed in the sentinel config
public String getStartupCommand() { return "sleep 0"; }
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/utils/FileSenderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/utils/FileSenderTest.java
index 9b5bd71274a..79646dacaa9 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/utils/FileSenderTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/utils/FileSenderTest.java
@@ -172,5 +172,8 @@ public class FileSenderTest {
public int getPortCount() {
return 0;
}
+
+ @Override
+ public String[] getPortSuffixes() { return null; }
}
}
diff --git a/config-provisioning/abi-spec.json b/config-provisioning/abi-spec.json
index 04e4a7276e2..af61fb46e50 100644
--- a/config-provisioning/abi-spec.json
+++ b/config-provisioning/abi-spec.json
@@ -462,11 +462,13 @@
"public void <init>(java.lang.String, java.util.List, com.yahoo.config.provision.ClusterMembership)",
"public void <init>(java.lang.String, java.util.List, java.util.Optional, java.util.Optional)",
"public void <init>(java.lang.String, java.util.List, java.util.Optional, java.util.Optional, java.util.Optional)",
+ "public void <init>(java.lang.String, java.util.List, java.util.Optional, java.util.Optional, java.util.Optional, java.util.Optional)",
"public java.lang.String hostname()",
"public java.util.List aliases()",
"public java.util.Optional flavor()",
"public java.util.Optional version()",
"public java.util.Optional membership()",
+ "public java.util.Optional networkPorts()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
"public int hashCode()",
@@ -497,6 +499,50 @@
],
"fields": []
},
+ "com.yahoo.config.provision.NetworkPorts$Allocation": {
+ "superClass": "java.lang.Object",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>(int, java.lang.String, java.lang.String, java.lang.String)",
+ "public java.lang.String key()",
+ "public java.lang.String toString()"
+ ],
+ "fields": [
+ "public final int port",
+ "public final java.lang.String serviceType",
+ "public final java.lang.String configId",
+ "public final java.lang.String portSuffix"
+ ]
+ },
+ "com.yahoo.config.provision.NetworkPorts": {
+ "superClass": "java.lang.Object",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>(java.util.Collection)",
+ "public java.util.Collection allocations()",
+ "public int size()"
+ ],
+ "fields": []
+ },
+ "com.yahoo.config.provision.NetworkPortsSerializer": {
+ "superClass": "java.lang.Object",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>()",
+ "public static void toSlime(com.yahoo.config.provision.NetworkPorts, com.yahoo.slime.Cursor)",
+ "public static java.util.Optional fromSlime(com.yahoo.slime.Inspector)"
+ ],
+ "fields": []
+ },
"com.yahoo.config.provision.NodeFlavors": {
"superClass": "java.lang.Object",
"interfaces": [],
diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/AllocatedHosts.java b/config-provisioning/src/main/java/com/yahoo/config/provision/AllocatedHosts.java
index 4c1798c549f..28c5d475e19 100644
--- a/config-provisioning/src/main/java/com/yahoo/config/provision/AllocatedHosts.java
+++ b/config-provisioning/src/main/java/com/yahoo/config/provision/AllocatedHosts.java
@@ -35,6 +35,7 @@ public class AllocatedHosts {
/** Current version */
private static final String hostSpecCurrentVespaVersion = "currentVespaVersion";
+ private static final String hostSpecNetworkPorts = "ports";
private final ImmutableSet<HostSpec> hosts;
@@ -60,6 +61,7 @@ public class AllocatedHosts {
});
host.flavor().ifPresent(flavor -> cursor.setString(hostSpecFlavor, flavor.name()));
host.version().ifPresent(version -> cursor.setString(hostSpecCurrentVespaVersion, version.toFullString()));
+ host.networkPorts().ifPresent(ports -> NetworkPortsSerializer.toSlime(ports, cursor.setArray(hostSpecNetworkPorts)));
}
/** Returns the hosts of this allocation */
@@ -84,8 +86,9 @@ public class AllocatedHosts {
object.field(hostSpecFlavor).valid() ? flavorFromSlime(object, nodeFlavors) : Optional.empty();
Optional<com.yahoo.component.Version> version =
optionalString(object.field(hostSpecCurrentVespaVersion)).map(com.yahoo.component.Version::new);
-
- return new HostSpec(object.field(hostSpecHostName).asString(), Collections.emptyList(), flavor, membership, version);
+ Optional<NetworkPorts> networkPorts =
+ NetworkPortsSerializer.fromSlime(object.field(hostSpecNetworkPorts));
+ return new HostSpec(object.field(hostSpecHostName).asString(), Collections.emptyList(), flavor, membership, version, networkPorts);
}
private static ClusterMembership membershipFromSlime(Inspector object) {
diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/HostSpec.java b/config-provisioning/src/main/java/com/yahoo/config/provision/HostSpec.java
index f0e8774759d..e5d4aadb988 100644
--- a/config-provisioning/src/main/java/com/yahoo/config/provision/HostSpec.java
+++ b/config-provisioning/src/main/java/com/yahoo/config/provision/HostSpec.java
@@ -29,6 +29,8 @@ public class HostSpec implements Comparable<HostSpec> {
private final Optional<com.yahoo.component.Version> version;
+ private final Optional<NetworkPorts> networkPorts;
+
public HostSpec(String hostname, Optional<ClusterMembership> membership) {
this(hostname, new ArrayList<>(), Optional.empty(), membership);
}
@@ -40,6 +42,7 @@ public class HostSpec implements Comparable<HostSpec> {
public HostSpec(String hostname, List<String> aliases) {
this(hostname, aliases, Optional.empty(), Optional.empty());
}
+
public HostSpec(String hostname, List<String> aliases, Flavor flavor) {
this(hostname, aliases, Optional.of(flavor), Optional.empty());
}
@@ -54,13 +57,21 @@ public class HostSpec implements Comparable<HostSpec> {
public HostSpec(String hostname, List<String> aliases, Optional<Flavor> flavor,
Optional<ClusterMembership> membership, Optional<com.yahoo.component.Version> version) {
+ this(hostname, aliases, flavor, membership, version, Optional.empty());
+ }
+
+ public HostSpec(String hostname, List<String> aliases, Optional<Flavor> flavor,
+ Optional<ClusterMembership> membership, Optional<com.yahoo.component.Version> version,
+ Optional<NetworkPorts> networkPorts) {
if (hostname == null || hostname.isEmpty()) throw new IllegalArgumentException("Hostname must be specified");
Objects.requireNonNull(version, "Version cannot be null but can be empty");
+ Objects.requireNonNull(networkPorts, "Network ports cannot be null but can be empty");
this.hostname = hostname;
this.aliases = ImmutableList.copyOf(aliases);
this.flavor = flavor;
this.membership = membership;
this.version = version;
+ this.networkPorts = networkPorts;
}
/** Returns the name identifying this host */
@@ -77,6 +88,9 @@ public class HostSpec implements Comparable<HostSpec> {
/** Returns the membership of this host, or an empty value if not present */
public Optional<ClusterMembership> membership() { return membership; }
+ /** Returns the network port allocations on this host, or empty if not present */
+ public Optional<NetworkPorts> networkPorts() { return networkPorts; }
+
@Override
public String toString() {
return hostname +
diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/NetworkPorts.java b/config-provisioning/src/main/java/com/yahoo/config/provision/NetworkPorts.java
new file mode 100644
index 00000000000..90ac3651bb2
--- /dev/null
+++ b/config-provisioning/src/main/java/com/yahoo/config/provision/NetworkPorts.java
@@ -0,0 +1,55 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.config.provision;
+
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Models an immutable list of network port allocations
+ * @author arnej
+ */
+public class NetworkPorts {
+
+ public static class Allocation {
+ public final int port;
+ public final String serviceType;
+ public final String configId;
+ public final String portSuffix;
+
+ public Allocation(int port, String serviceType, String configId, String portSuffix) {
+ this.port = port;
+ this.serviceType = serviceType;
+ this.configId = configId;
+ this.portSuffix = portSuffix;
+ }
+ public String key() {
+ StringBuilder buf = new StringBuilder();
+ buf.append("t=").append(serviceType);
+ buf.append(" cfg=").append(configId);
+ buf.append(" suf=").append(portSuffix);
+ return buf.toString();
+ }
+ public String toString() {
+ StringBuilder buf = new StringBuilder();
+ buf.append("[port=").append(port);
+ buf.append(" serviceType=").append(serviceType);
+ buf.append(" configId=").append(configId);
+ buf.append(" suffix=").append(portSuffix);
+ buf.append("]");
+ return buf.toString();
+ }
+ }
+
+ private final List<Allocation> allocations;
+
+ public NetworkPorts(Collection<Allocation> allocations) {
+ this.allocations = List.copyOf(allocations);
+ }
+
+ public Collection<Allocation> allocations() {
+ return this.allocations;
+ }
+
+ public int size() { return allocations.size(); }
+}
diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/NetworkPortsSerializer.java b/config-provisioning/src/main/java/com/yahoo/config/provision/NetworkPortsSerializer.java
new file mode 100644
index 00000000000..d3af337e9be
--- /dev/null
+++ b/config-provisioning/src/main/java/com/yahoo/config/provision/NetworkPortsSerializer.java
@@ -0,0 +1,56 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.config.provision;
+
+import com.yahoo.slime.ArrayTraverser;
+import com.yahoo.slime.Cursor;
+import com.yahoo.slime.Inspector;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Serializes network port allocations to/from JSON.
+ *
+ * @author arnej
+ */
+public class NetworkPortsSerializer {
+
+ // Network port fields
+ private static final String portNumberKey = "port";
+ private static final String serviceTypeKey = "type";
+ private static final String configIdKey = "cfg";
+ private static final String portSuffixKey = "suf";
+
+ // ---------------- Serialization ----------------------------------------------------
+
+ public static void toSlime(NetworkPorts networkPorts, Cursor array) {
+ for (NetworkPorts.Allocation allocation : networkPorts.allocations()) {
+ Cursor object = array.addObject();
+ object.setLong(portNumberKey, allocation.port);
+ object.setString(serviceTypeKey, allocation.serviceType);
+ object.setString(configIdKey, allocation.configId);
+ object.setString(portSuffixKey, allocation.portSuffix);
+ }
+ }
+
+ // ---------------- Deserialization --------------------------------------------------
+
+ public static Optional<NetworkPorts> fromSlime(Inspector array) {
+ List<NetworkPorts.Allocation> list = new ArrayList<>(array.entries());
+ array.traverse((ArrayTraverser) (int i, Inspector item) -> {
+ list.add(new NetworkPorts.Allocation((int)item.field(portNumberKey).asLong(),
+ item.field(serviceTypeKey).asString(),
+ item.field(configIdKey).asString(),
+ item.field(portSuffixKey).asString()));
+ }
+ );
+ if (list.size() > 0) {
+ NetworkPorts allocator = new NetworkPorts(list);
+ return Optional.of(allocator);
+ }
+ return Optional.empty();
+ }
+
+}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/provision/ProvisionerAdapter.java b/configserver/src/main/java/com/yahoo/vespa/config/server/provision/ProvisionerAdapter.java
index 32380b296dd..ee4cc4a3043 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/provision/ProvisionerAdapter.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/provision/ProvisionerAdapter.java
@@ -8,6 +8,7 @@ import com.yahoo.config.provision.ClusterSpec;
import com.yahoo.config.provision.HostSpec;
import com.yahoo.config.provision.ProvisionLogger;
import com.yahoo.config.provision.Provisioner;
+import com.yahoo.config.provision.NetworkPorts;
import java.util.*;
diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionTest.java
index c4b3e5f24dc..10ee1a22bab 100644
--- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionTest.java
+++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionTest.java
@@ -7,6 +7,7 @@ import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.provision.AllocatedHosts;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.HostSpec;
+import com.yahoo.config.provision.NetworkPorts;
import com.yahoo.config.provision.TenantName;
import com.yahoo.path.Path;
import com.yahoo.config.model.application.provider.*;
@@ -136,7 +137,15 @@ public class LocalSessionTest {
@Test
public void require_that_provision_info_can_be_read() throws Exception {
- AllocatedHosts input = AllocatedHosts.withHosts(Collections.singleton(new HostSpec("myhost", Collections.<String>emptyList())));
+ List<NetworkPorts.Allocation> list = new ArrayList<>();
+ list.add(new NetworkPorts.Allocation(8080, "container", "default/0", "http"));
+ list.add(new NetworkPorts.Allocation(19101, "searchnode", "other/1", "rpc"));
+ NetworkPorts ports = new NetworkPorts(list);
+
+ AllocatedHosts input = AllocatedHosts.withHosts(Collections.singleton(
+ new HostSpec("myhost", Collections.<String>emptyList(),
+ Optional.empty(), Optional.empty(), Optional.empty(),
+ Optional.of(ports))));
LocalSession session = createSession(TenantName.defaultName(), 3, new SessionTest.MockSessionPreparer(), Optional.of(input));
ApplicationId origId = new ApplicationId.Builder()
@@ -147,6 +156,9 @@ public class LocalSessionTest {
assertNotNull(info);
assertThat(info.getHosts().size(), is(1));
assertTrue(info.getHosts().contains(new HostSpec("myhost", Collections.emptyList())));
+ Optional<NetworkPorts> portsCopy = info.getHosts().iterator().next().networkPorts();
+ assertTrue(portsCopy.isPresent());
+ assertThat(portsCopy.get().allocations(), is(list));
}
@Test
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/JobType.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/JobType.java
index cee8d3ddfd9..1dd9da6dc32 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/JobType.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/deployment/JobType.java
@@ -28,7 +28,9 @@ public enum JobType {
productionAwsUsEast1b ("production-aws-us-east-1b" , ZoneId.from("prod" , "aws-us-east-1b") , null ),
productionCdAwsUsEast1a("production-cd-aws-us-east-1a", null , ZoneId.from("prod" , "cd-aws-us-east-1a")),
productionCdUsCentral1 ("production-cd-us-central-1" , null , ZoneId.from("prod" , "cd-us-central-1") ),
- productionCdUsCentral2 ("production-cd-us-central-2" , null , ZoneId.from("prod" , "cd-us-central-2") );
+ // TODO: Cannot remove production-cd-us-central-2 until we know there are no serialized data in controller referencing it
+ productionCdUsCentral2 ("production-cd-us-central-2" , null , ZoneId.from("prod" , "cd-us-central-2") ),
+ productionCdUsWest1 ("production-cd-us-west-1" , null , ZoneId.from("prod" , "cd-us-west-1") );
private final String jobName;
private final ImmutableMap<SystemName, ZoneId> zones;
diff --git a/document/abi-spec.json b/document/abi-spec.json
index 61390af3523..d4db3026b27 100644
--- a/document/abi-spec.json
+++ b/document/abi-spec.json
@@ -5244,7 +5244,7 @@
],
"methods": [
"public void <init>(com.yahoo.document.update.TensorModifyUpdate$Operation, com.yahoo.document.datatypes.TensorFieldValue)",
- "public static com.yahoo.tensor.TensorType convertToCompatibleType(com.yahoo.tensor.TensorType)",
+ "public static com.yahoo.tensor.TensorType convertDimensionsToMapped(com.yahoo.tensor.TensorType)",
"public com.yahoo.document.update.TensorModifyUpdate$Operation getOperation()",
"public com.yahoo.document.datatypes.TensorFieldValue getValue()",
"public void setValue(com.yahoo.document.datatypes.TensorFieldValue)",
@@ -5278,6 +5278,7 @@
"public boolean equals(java.lang.Object)",
"public int hashCode()",
"public java.lang.String toString()",
+ "public static com.yahoo.tensor.TensorType extractSparseDimensions(com.yahoo.tensor.TensorType)",
"public bridge synthetic void setValue(com.yahoo.document.datatypes.FieldValue)",
"public bridge synthetic com.yahoo.document.datatypes.FieldValue getValue()"
],
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
index ffbfe49347c..6310fa62d15 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
@@ -23,22 +23,23 @@ public class TensorAddUpdateReader {
public static TensorAddUpdate createTensorAddUpdate(TokenBuffer buffer, Field field) {
expectObjectStart(buffer.currentToken());
- expectTensorTypeIsSparse(field);
+ expectTensorTypeHasSparseDimensions(field);
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType tensorType = tensorDataType.getTensorType();
TensorFieldValue tensorFieldValue = new TensorFieldValue(tensorType);
fillTensor(buffer, tensorFieldValue);
+
expectTensorIsNonEmpty(field, tensorFieldValue.getTensor().get());
return new TensorAddUpdate(tensorFieldValue);
}
- private static void expectTensorTypeIsSparse(Field field) {
+ private static void expectTensorTypeHasSparseDimensions(Field field) {
TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
- if (tensorType.dimensions().stream()
- .anyMatch(dim -> dim.isIndexed())) {
- throw new IllegalArgumentException("An add update can only be applied to sparse tensors. "
- + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'");
+ if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) {
+ throw new IllegalArgumentException("An add update can only be applied to tensors " +
+ "with at least one sparse dimension. Field '" + field.getName() +
+ "' has unsupported tensor type '" + tensorType + "'");
}
}
@@ -48,5 +49,4 @@ public class TensorAddUpdateReader {
}
}
-
}
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
index a9bbba519bd..66588debbca 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
@@ -29,10 +29,8 @@ public class TensorModifyUpdateReader {
private static final String MODIFY_MULTIPLY = "multiply";
public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) {
-
expectFieldIsOfTypeTensor(field);
expectTensorTypeHasNoneIndexedUnboundDimensions(field);
- expectTensorTypeIsNotMixed(field);
expectObjectStart(buffer.currentToken());
ModifyUpdateResult result = createModifyUpdateResult(buffer, field);
@@ -58,16 +56,6 @@ public class TensorModifyUpdateReader {
}
}
- private static void expectTensorTypeIsNotMixed(Field field) {
- TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
- long numMappedDimensions = tensorType.dimensions().stream().filter(dim -> dim.type().equals(TensorType.Dimension.Type.mapped)).count();
- long numIndexedDimensions = tensorType.dimensions().stream().filter(dim -> dim.isIndexed()).count();
- if (numMappedDimensions > 0 && numIndexedDimensions > 0) {
- throw new IllegalArgumentException("A modify update cannot be applied to tensor types with mixed dimensions. "
- + "Field '" + field.getName() + "' has mixed tensor type '" + tensorType + "'");
- }
- }
-
private static void expectOperationSpecified(TensorModifyUpdate.Operation operation, String fieldName) {
if (operation == null) {
throw new IllegalArgumentException("Modify update for field '" + fieldName + "' does not contain an operation");
@@ -121,7 +109,7 @@ public class TensorModifyUpdateReader {
private static TensorFieldValue createTensor(TokenBuffer buffer, Field field) {
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType originalType = tensorDataType.getTensorType();
- TensorType convertedType = TensorModifyUpdate.convertToCompatibleType(originalType);
+ TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(originalType);
Tensor.Builder tensorBuilder = Tensor.Builder.of(convertedType);
readTensorCells(buffer, tensorBuilder);
@@ -129,25 +117,26 @@ public class TensorModifyUpdateReader {
validateBounds(tensor, originalType);
- TensorFieldValue result = new TensorFieldValue(convertedType);
- result.assign(tensor);
- return result;
+ return new TensorFieldValue(tensor);
}
- /** Only validate if original type is indexed bound */
- private static void validateBounds(Tensor convertedTensor, TensorType originalType) {
- if ( ! originalType.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) {
+ /** Only validate if original type has indexed bound dimensions */
+ static void validateBounds(Tensor convertedTensor, TensorType originalType) {
+ if (originalType.dimensions().stream().noneMatch(d -> d instanceof TensorType.IndexedBoundDimension)) {
return;
}
for (Iterator<Tensor.Cell> iter = convertedTensor.cellIterator(); iter.hasNext(); ) {
Tensor.Cell cell = iter.next();
TensorAddress address = cell.getKey();
for (int i = 0; i < address.size(); ++i) {
- long label = address.numericLabel(i);
- long bound = originalType.dimensions().get(i).size().get(); // size is non-optional for indexed bound
- if (label >= bound) {
- throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() +
- "' has label '" + label + "' but type is " + originalType.toString());
+ TensorType.Dimension dim = originalType.dimensions().get(i);
+ if (dim instanceof TensorType.IndexedBoundDimension) {
+ long label = address.numericLabel(i);
+ long bound = dim.size().get(); // size is non-optional for indexed bound
+ if (label >= bound) {
+ throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() +
+ "' has label '" + label + "' but type is " + originalType.toString());
+ }
}
}
}
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
index 210a6a80ee5..3bb4b2e262f 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
@@ -24,23 +24,23 @@ public class TensorRemoveUpdateReader {
static TensorRemoveUpdate createTensorRemoveUpdate(TokenBuffer buffer, Field field) {
expectObjectStart(buffer.currentToken());
- expectTensorTypeIsSparse(field);
+ expectTensorTypeHasSparseDimensions(field);
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
- TensorType tensorType = tensorDataType.getTensorType();
+ TensorType originalType = tensorDataType.getTensorType();
+ TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(originalType);
+ Tensor tensor = readRemoveUpdateTensor(buffer, convertedType, originalType);
- // TODO: for mixed case extract a new tensor type based only on mapped dimensions
-
- Tensor tensor = readRemoveUpdateTensor(buffer, tensorType);
expectAddressesAreNonEmpty(field, tensor);
return new TensorRemoveUpdate(new TensorFieldValue(tensor));
}
- private static void expectTensorTypeIsSparse(Field field) {
+ private static void expectTensorTypeHasSparseDimensions(Field field) {
TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
- if (tensorType.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed)) {
- throw new IllegalArgumentException("A remove update can only be applied to sparse tensors. "
- + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'");
+ if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) {
+ throw new IllegalArgumentException("A remove update can only be applied to tensors " +
+ "with at least one sparse dimension. Field '" + field.getName() +
+ "' has unsupported tensor type '" + tensorType + "'");
}
}
@@ -53,7 +53,7 @@ public class TensorRemoveUpdateReader {
/**
* Reads all addresses in buffer and returns a tensor where addresses have cell value 1.0
*/
- private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type) {
+ private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type, TensorType originalType) {
Tensor.Builder builder = Tensor.Builder.of(type);
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
@@ -62,7 +62,7 @@ public class TensorRemoveUpdateReader {
expectArrayStart(buffer.currentToken());
int nesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= nesting; buffer.next()) {
- builder.cell(readTensorAddress(buffer, type), 1.0);
+ builder.cell(readTensorAddress(buffer, type, originalType), 1.0);
}
expectCompositeEnd(buffer.currentToken());
}
@@ -71,12 +71,15 @@ public class TensorRemoveUpdateReader {
return builder.build();
}
- private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type) {
+ private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type, TensorType originalType) {
TensorAddress.Builder builder = new TensorAddress.Builder(type);
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
String dimension = buffer.currentName();
+ if ( ! type.dimension(dimension).isPresent() && originalType.dimension(dimension).isPresent()) {
+ throw new IllegalArgumentException("Indexed dimension address '" + dimension + "' should not be specified in remove update");
+ }
String label = buffer.currentText();
builder.add(dimension, label);
}
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
index 2f22def9aa1..a763db33e7a 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
@@ -5,6 +5,7 @@ import com.yahoo.document.DataType;
import com.yahoo.document.DocumentTypeManager;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
+import com.yahoo.document.json.readers.TensorRemoveUpdateReader;
import com.yahoo.document.update.TensorAddUpdate;
import com.yahoo.document.update.TensorModifyUpdate;
import com.yahoo.document.update.TensorRemoveUpdate;
@@ -35,7 +36,10 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
throw new DeserializationException("Expected tensor data type, got " + type);
}
TensorDataType tensorDataType = (TensorDataType)type;
- TensorFieldValue tensor = new TensorFieldValue(TensorModifyUpdate.convertToCompatibleType(tensorDataType.getTensorType()));
+ TensorType tensorType = tensorDataType.getTensorType();
+ TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(tensorType);
+
+ TensorFieldValue tensor = new TensorFieldValue(convertedType);
tensor.deserialize(this);
return new TensorModifyUpdate(operation, tensor);
}
@@ -46,7 +50,8 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
throw new DeserializationException("Expected tensor data type, got " + type);
}
TensorDataType tensorDataType = (TensorDataType)type;
- TensorFieldValue tensor = new TensorFieldValue(tensorDataType.getTensorType());
+ TensorType tensorType = tensorDataType.getTensorType();
+ TensorFieldValue tensor = new TensorFieldValue(tensorType);
tensor.deserialize(this);
return new TensorAddUpdate(tensor);
}
@@ -58,10 +63,9 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
}
TensorDataType tensorDataType = (TensorDataType)type;
TensorType tensorType = tensorDataType.getTensorType();
+ TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(tensorType);
- // TODO: for mixed case extract a new tensor type based only on mapped dimensions
-
- TensorFieldValue tensor = new TensorFieldValue(tensorType);
+ TensorFieldValue tensor = new TensorFieldValue(convertedType);
tensor.deserialize(this);
return new TensorRemoveUpdate(tensor);
}
diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
index cfc3ee0c742..f8d2464deb7 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
@@ -7,15 +7,11 @@ import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.serialization.DocumentUpdateWriter;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
-import java.util.Map;
import java.util.Objects;
/**
- * An update used to add cells to a sparse tensor (has only mapped dimensions).
- *
- * The cells to add are contained in a sparse tensor as well.
+ * An update used to add cells to a sparse or mixed tensor (has at least one mapped dimension).
*/
public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> {
@@ -50,22 +46,10 @@ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> {
return oldValue;
}
- Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get();
- Map<TensorAddress, Double> oldCells = oldTensor.cells();
- Map<TensorAddress, Double> addCells = tensor.getTensor().get().cells();
-
- // currently, underlying implementation disallows multiple entries with the same key
-
- Tensor.Builder builder = Tensor.Builder.of(oldTensor.type());
- for (Map.Entry<TensorAddress, Double> oldCell : oldCells.entrySet()) {
- builder.cell(oldCell.getKey(), addCells.getOrDefault(oldCell.getKey(), oldCell.getValue()));
- }
- for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) {
- if ( ! oldCells.containsKey(addCell.getKey())) {
- builder.cell(addCell.getKey(), addCell.getValue());
- }
- }
- return new TensorFieldValue(builder.build());
+ Tensor old = ((TensorFieldValue) oldValue).getTensor().get();
+ Tensor update = tensor.getTensor().get();
+ Tensor result = old.merge((left, right) -> right, update.cells()); // note this might be slow for large mixed tensor updates
+ return new TensorFieldValue(result);
}
@Override
diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
index 6111b51ca4e..2773f9d31da 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
@@ -37,7 +37,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
/**
* Converts the given tensor type to a type that is compatible for being used in this update (has only mapped dimensions).
*/
- public static TensorType convertToCompatibleType(TensorType type) {
+ public static TensorType convertDimensionsToMapped(TensorType type) {
TensorType.Builder builder = new TensorType.Builder();
type.dimensions().stream().forEach(dim -> builder.mapped(dim.name()));
return builder.build();
diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
index e9fb1e3efd5..335cda8e133 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
@@ -7,10 +7,8 @@ import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.serialization.DocumentUpdateWriter;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
-import java.util.Iterator;
-import java.util.Map;
import java.util.Objects;
/**
@@ -25,6 +23,18 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
public TensorRemoveUpdate(TensorFieldValue value) {
super(ValueUpdateClassID.TENSORREMOVE);
this.tensor = value;
+ verifyCompatibleType();
+ }
+
+ private void verifyCompatibleType() {
+ if ( ! tensor.getTensor().isPresent()) {
+ throw new IllegalArgumentException("Tensor must be present in remove update");
+ }
+ TensorType tensorType = tensor.getTensor().get().type();
+ TensorType expectedType = extractSparseDimensions(tensor.getDataType().getTensorType());
+ if ( ! tensorType.equals(expectedType)) {
+ throw new IllegalArgumentException("Unexpected type '" + tensorType + "' in remove update. Expected is '" + expectedType + "'");
+ }
}
@Override
@@ -51,17 +61,10 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
return oldValue;
}
- Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get();
- Map<TensorAddress, Double> cellsToRemove = tensor.getTensor().get().cells();
- Tensor.Builder builder = Tensor.Builder.of(oldTensor.type());
- for (Iterator<Tensor.Cell> i = oldTensor.cellIterator(); i.hasNext(); ) {
- Tensor.Cell cell = i.next();
- TensorAddress address = cell.getKey();
- if ( ! cellsToRemove.containsKey(address)) {
- builder.cell(address, cell.getValue());
- }
- }
- return new TensorFieldValue(builder.build());
+ Tensor old = ((TensorFieldValue) oldValue).getTensor().get();
+ Tensor update = tensor.getTensor().get();
+ Tensor result = old.remove(update.cells().keySet());
+ return new TensorFieldValue(result);
}
@Override
@@ -93,4 +96,11 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
return super.toString() + " " + tensor;
}
+ public static TensorType extractSparseDimensions(TensorType type) {
+ TensorType.Builder builder = new TensorType.Builder();
+ type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name()));
+ return builder.build();
+ }
+
+
}
diff --git a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
index e2736dabd2b..454ad72f344 100644
--- a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
+++ b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
@@ -40,6 +40,7 @@ public class DocumentUpdateJsonSerializerTest {
final static TensorType sparseTensorType = new TensorType.Builder().mapped("x").mapped("y").build();
final static TensorType denseTensorType = new TensorType.Builder().indexed("x", 2).indexed("y", 3).build();
+ final static TensorType mixedTensorType = new TensorType.Builder().mapped("x").indexed("y", 3).build();
final static DocumentTypeManager types = new DocumentTypeManager();
final static JsonFactory parserFactory = new JsonFactory();
final static DocumentType docType = new DocumentType("doctype");
@@ -60,6 +61,7 @@ public class DocumentUpdateJsonSerializerTest {
docType.addField(new Field("byte_field", DataType.BYTE));
docType.addField(new Field("sparse_tensor", new TensorDataType(sparseTensorType)));
docType.addField(new Field("dense_tensor", new TensorDataType(denseTensorType)));
+ docType.addField(new Field("mixed_tensor", new TensorDataType(mixedTensorType)));
docType.addField(new Field("reference_field", new ReferenceDataType(refTargetDocType, 777)));
docType.addField(new Field("predicate_field", DataType.PREDICATE));
docType.addField(new Field("raw_field", DataType.RAW));
@@ -336,6 +338,26 @@ public class DocumentUpdateJsonSerializerTest {
}
@Test
+ public void test_tensor_modify_update_on_mixed_tensor() {
+ roundtripSerializeJsonAndMatch(inputJson(
+ "{",
+ " 'update': 'DOCUMENT_ID',",
+ " 'fields': {",
+ " 'mixed_tensor': {",
+ " 'modify': {",
+ " 'operation': 'multiply',",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 },",
+ " { 'address': { 'x': 'c', 'y': '1' }, 'value': 3.0 }",
+ " ]",
+ " }",
+ " }",
+ " }",
+ "}"
+ ));
+ }
+
+ @Test
public void test_tensor_add_update() {
roundtripSerializeJsonAndMatch(inputJson(
"{",
@@ -355,6 +377,29 @@ public class DocumentUpdateJsonSerializerTest {
}
@Test
+ public void test_tensor_add_update_mixed() {
+ roundtripSerializeJsonAndMatch(inputJson(
+ "{",
+ " 'update': 'DOCUMENT_ID',",
+ " 'fields': {",
+ " 'mixed_tensor': {",
+ " 'add': {",
+ " 'cells': [",
+ " { 'address': { 'x': '1', 'y': '0' }, 'value': 2.0 },",
+ " { 'address': { 'x': '1', 'y': '1' }, 'value': 0.0 },",
+ " { 'address': { 'x': '1', 'y': '2' }, 'value': 0.0 },",
+ " { 'address': { 'x': '0', 'y': '0' }, 'value': 0.0 },",
+ " { 'address': { 'x': '0', 'y': '1' }, 'value': 0.0 },",
+ " { 'address': { 'x': '0', 'y': '2' }, 'value': 3.0 }",
+ " ]",
+ " }",
+ " }",
+ " }",
+ "}"
+ ));
+ }
+
+ @Test
public void test_tensor_remove_update() {
roundtripSerializeJsonAndMatch(inputJson(
"{",
@@ -374,6 +419,24 @@ public class DocumentUpdateJsonSerializerTest {
}
@Test
+ public void test_tensor_remove_update_mixed() {
+ roundtripSerializeJsonAndMatch(inputJson(
+ "{",
+ " 'update': 'DOCUMENT_ID',",
+ " 'fields': {",
+ " 'mixed_tensor': {",
+ " 'remove': {",
+ " 'addresses': [",
+ " {'x':'0' }",
+ " ]",
+ " }",
+ " }",
+ " }",
+ "}"
+ ));
+ }
+
+ @Test
public void reference_field_id_can_be_update_assigned_non_empty_id() {
roundtripSerializeJsonAndMatch(inputJson(
"{",
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index e58b26d500d..15d1e859f73 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -1387,12 +1387,30 @@ public class JsonReaderTestCase {
}
@Test
- public void tensor_modify_update_on_mixed_tensor_throws() {
- exception.expect(IllegalArgumentException.class);
- exception.expectMessage("A modify update cannot be applied to tensor types with mixed dimensions. Field 'mixed_tensor' has mixed tensor type 'tensor(x{},y[3])'");
- createTensorModifyUpdate(inputJson("{",
- " 'operation': 'replace',",
- " 'cells': [] }"), "mixed_tensor");
+ public void tensor_modify_update_with_replace_operation_mixed() {
+ assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor",
+ inputJson("{",
+ " 'operation': 'replace',",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}"));
+ }
+
+ @Test
+ public void tensor_modify_update_with_add_operation_mixed() {
+ assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.ADD, "mixed_tensor",
+ inputJson("{",
+ " 'operation': 'add',",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}"));
+ }
+
+ @Test
+ public void tensor_modify_update_with_multiply_operation_mixed() {
+ assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.MULTIPLY, "mixed_tensor",
+ inputJson("{",
+ " 'operation': 'multiply',",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}"));
}
@Test
@@ -1406,6 +1424,17 @@ public class JsonReaderTestCase {
}
@Test
+ public void tensor_modify_update_with_out_of_bound_cells_throws_mixed() {
+ exception.expect(IndexOutOfBoundsException.class);
+ exception.expectMessage("Dimension 'y' has label '3' but type is tensor(x{},y[3])");
+ createTensorModifyUpdate(inputJson("{",
+ " 'operation': 'replace',",
+ " 'cells': [",
+ " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "mixed_tensor");
+ }
+
+
+ @Test
public void tensor_modify_update_with_unknown_operation_throws() {
exception.expect(IllegalArgumentException.class);
exception.expectMessage("Unknown operation 'unknown' in modify update for field 'sparse_tensor'");
@@ -1449,11 +1478,29 @@ public class JsonReaderTestCase {
}
@Test
- public void tensor_add_update_on_non_sparse_tensor_throws() {
+ public void tensor_add_update_on_mixed_tensor() {
+ assertTensorAddUpdate("{{x:a,y:0}:2.0, {x:a,y:1}:3.0, {x:a,y:2}:0.0}", "mixed_tensor",
+ inputJson("{",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 },",
+ " { 'address': { 'x': 'a', 'y': '1' }, 'value': 3.0 } ]}"));
+ }
+
+ @Test
+ public void tensor_add_update_on_mixed_with_out_of_bound_dense_cells_throws() {
+ exception.expect(IndexOutOfBoundsException.class);
+ exception.expectMessage("Index 3 out of bounds for length 3");
+ createTensorAddUpdate(inputJson("{",
+ " 'cells': [",
+ " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "mixed_tensor");
+ }
+
+ @Test
+ public void tensor_add_update_on_dense_tensor_throws() {
exception.expect(IllegalArgumentException.class);
- exception.expectMessage("An add update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'");
+ exception.expectMessage("An add update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'");
createTensorAddUpdate(inputJson("{",
- " 'cells': [] }"), "mixed_tensor");
+ " 'cells': [] }"), "dense_tensor");
}
@Test
@@ -1470,6 +1517,7 @@ public class JsonReaderTestCase {
exception.expect(IllegalArgumentException.class);
exception.expectMessage("Add update for field 'sparse_tensor' does not contain tensor cells");
createTensorAddUpdate(inputJson("{}"), "sparse_tensor");
+ createTensorAddUpdate(inputJson("{}"), "mixed_tensor");
}
@Test
@@ -1482,11 +1530,30 @@ public class JsonReaderTestCase {
}
@Test
- public void tensor_remove_update_on_non_sparse_tensor_throws() {
+ public void tensor_remove_update_on_mixed_tensor() {
+ assertTensorRemoveUpdate("{{x:1}:1.0,{x:2}:1.0}", "mixed_tensor",
+ inputJson("{",
+ " 'addresses': [",
+ " { 'x': '1' },",
+ " { 'x': '2' } ]}"));
+ }
+
+ @Test
+ public void tensor_remove_update_on_mixed_tensor_with_dense_addresses_throws() {
+ exception.expect(IllegalArgumentException.class);
+ exception.expectMessage("Indexed dimension address 'y' should not be specified in remove update");
+ createTensorRemoveUpdate(inputJson("{",
+ " 'addresses': [",
+ " { 'x': '1', 'y': '0' },",
+ " { 'x': '2', 'y': '0' } ]}"), "mixed_tensor");
+ }
+
+ @Test
+ public void tensor_remove_update_on_dense_tensor_throws() {
exception.expect(IllegalArgumentException.class);
- exception.expectMessage("A remove update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'");
+ exception.expectMessage("A remove update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'");
createTensorRemoveUpdate(inputJson("{",
- " 'addresses': [] }"), "mixed_tensor");
+ " 'addresses': [] }"), "dense_tensor");
}
@Test
@@ -1503,6 +1570,7 @@ public class JsonReaderTestCase {
exception.expect(IllegalArgumentException.class);
exception.expectMessage("Remove update for field 'sparse_tensor' does not contain tensor addresses");
createTensorRemoveUpdate(inputJson("{'addresses': [] }"), "sparse_tensor");
+ createTensorRemoveUpdate(inputJson("{'addresses': [] }"), "mixed_tensor");
}
@Test
diff --git a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
index eb4001e6415..6935c54ba2a 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
@@ -12,18 +12,14 @@ public class TensorAddUpdateTest {
@Test
public void apply_add_update_operations() {
assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}");
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}");
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}");
- assertApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}");
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}");
}
private void assertApplyTo(String init, String update, String expected) {
String spec = "tensor(x{},y{})";
TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init));
TensorAddUpdate addUpdate = new TensorAddUpdate(new TensorFieldValue(Tensor.from(spec, update)));
- TensorFieldValue updatedFieldValue = (TensorFieldValue) addUpdate.applyTo(initialFieldValue);
- assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get());
+ Tensor updated = ((TensorFieldValue) addUpdate.applyTo(initialFieldValue)).getTensor().get();
+ assertEquals(Tensor.from(spec, expected), updated);
}
}
diff --git a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
index 6e9444de2be..b885e6ddca0 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
@@ -1,12 +1,6 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.document.update;
-import com.yahoo.document.Document;
-import com.yahoo.document.DocumentId;
-import com.yahoo.document.DocumentType;
-import com.yahoo.document.DocumentTypeManager;
-import com.yahoo.document.Field;
-import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.update.TensorModifyUpdate.Operation;
import com.yahoo.tensor.Tensor;
@@ -28,10 +22,11 @@ public class TensorModifyUpdateTest {
assertConvertToCompatible("tensor(x{})", "tensor(x[10])");
assertConvertToCompatible("tensor(x{})", "tensor(x{})");
assertConvertToCompatible("tensor(x{},y{},z{})", "tensor(x[],y[10],z{})");
+ assertConvertToCompatible("tensor(x{},y{})", "tensor(x{},y[3])");
}
private static void assertConvertToCompatible(String expectedType, String inputType) {
- assertEquals(expectedType, TensorModifyUpdate.convertToCompatibleType(TensorType.fromSpec(inputType)).toString());
+ assertEquals(expectedType, TensorModifyUpdate.convertDimensionsToMapped(TensorType.fromSpec(inputType)).toString());
}
@Test
@@ -46,15 +41,9 @@ public class TensorModifyUpdateTest {
public void apply_modify_update_operations() {
assertApplyTo("tensor(x{},y{})", Operation.REPLACE,
"{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}");
- assertApplyTo("tensor(x{},y{})", Operation.ADD,
- "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}");
- assertApplyTo("tensor(x{},y{})", Operation.MULTIPLY,
- "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}");
- assertApplyTo("tensor(x[1],y[2])", Operation.REPLACE,
- "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}");
assertApplyTo("tensor(x[1],y[2])", Operation.ADD,
"{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}");
- assertApplyTo("tensor(x[1],y[2])", Operation.MULTIPLY,
+ assertApplyTo("tensor(x{},y[2])", Operation.MULTIPLY,
"{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}");
}
diff --git a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
index 40ab00facdb..3a005e858c8 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
@@ -12,9 +12,6 @@ public class TensorRemoveUpdateTest {
@Test
public void apply_remove_update_operations() {
assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0,y:1}:1}", "{{x:0,y:0}:2}");
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:0}:1,{x:0,y:1}:1}", "{}");
- assertApplyTo("{}", "{{x:0,y:0}:1}", "{}");
- assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}");
}
private void assertApplyTo(String init, String update, String expected) {
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index c74c211756f..017d83893f0 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -922,6 +922,17 @@ TEST(DocumentUpdateTest, tensor_add_update_can_be_applied)
.add({{"x", "c"}}, 7));
}
+TEST(DocumentUpdateTest, tensor_remove_update_can_be_applied)
+{
+ TensorUpdateFixture f;
+ f.assertApplyUpdate(f.spec().add({{"x", "a"}}, 2)
+ .add({{"x", "b"}}, 3),
+
+ TensorRemoveUpdate(f.makeTensor(f.spec().add({{"x", "b"}}, 1))),
+
+ f.spec().add({{"x", "a"}}, 2));
+}
+
TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied)
{
TensorUpdateFixture f;
diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp
index 3e2bb86c66b..671bf260629 100644
--- a/document/src/vespa/document/update/tensor_remove_update.cpp
+++ b/document/src/vespa/document/update/tensor_remove_update.cpp
@@ -6,6 +6,8 @@
#include <vespa/document/fieldvalue/document.h>
#include <vespa/document/fieldvalue/tensorfieldvalue.h>
#include <vespa/document/serialization/vespadocumentdeserializer.h>
+#include <vespa/eval/tensor/cell_values.h>
+#include <vespa/eval/tensor/sparse/sparse_tensor.h>
#include <vespa/eval/tensor/tensor.h>
#include <vespa/vespalib/objects/nbostream.h>
#include <vespa/vespalib/util/xmlstream.h>
@@ -77,17 +79,35 @@ TensorRemoveUpdate::checkCompatibility(const Field &field) const
std::unique_ptr<Tensor>
TensorRemoveUpdate::applyTo(const Tensor &tensor) const
{
- // TODO: implement
- (void) tensor;
+ auto &addressTensor = _tensor->getAsTensorPtr();
+ if (addressTensor) {
+ if (const auto *sparseTensor = dynamic_cast<const vespalib::tensor::SparseTensor *>(addressTensor.get())) {
+ vespalib::tensor::CellValues cellAddresses(*sparseTensor);
+ return tensor.remove(cellAddresses);
+ } else {
+ throw IllegalArgumentException(make_string("Expected address tensor to be sparse, but has type '%s'",
+ addressTensor->type().to_spec().c_str()));
+ }
+ }
return std::unique_ptr<Tensor>();
}
bool
TensorRemoveUpdate::applyTo(FieldValue &value) const
{
- // TODO: implement
- (void) value;
- return false;
+ if (value.inherits(TensorFieldValue::classId)) {
+ TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value);
+ auto &oldTensor = tensorFieldValue.getAsTensorPtr();
+ auto newTensor = applyTo(*oldTensor);
+ if (newTensor) {
+ tensorFieldValue = std::move(newTensor);
+ }
+ } else {
+ std::string err = make_string("Unable to perform a tensor remove update on a '%s' field value.",
+ value.getClass().name());
+ throw IllegalStateException(err, VESPA_STRLOC);
+ }
+ return true;
}
void
diff --git a/document/src/vespa/document/update/valueupdate.h b/document/src/vespa/document/update/valueupdate.h
index 0e15943f8e4..6939d10ce2c 100644
--- a/document/src/vespa/document/update/valueupdate.h
+++ b/document/src/vespa/document/update/valueupdate.h
@@ -55,7 +55,8 @@ public:
Map = IDENTIFIABLE_CLASSID(MapValueUpdate),
Remove = IDENTIFIABLE_CLASSID(RemoveValueUpdate),
TensorModifyUpdate = IDENTIFIABLE_CLASSID(TensorModifyUpdate),
- TensorAddUpdate = IDENTIFIABLE_CLASSID(TensorAddUpdate)
+ TensorAddUpdate = IDENTIFIABLE_CLASSID(TensorAddUpdate),
+ TensorRemoveUpdate = IDENTIFIABLE_CLASSID(TensorRemoveUpdate)
};
ValueUpdate()
diff --git a/documentapi/CMakeLists.txt b/documentapi/CMakeLists.txt
index b03dd66c817..86d29732399 100644
--- a/documentapi/CMakeLists.txt
+++ b/documentapi/CMakeLists.txt
@@ -14,7 +14,6 @@ vespa_define_module(
vdslib
LIBS
- src/vespa/binref
src/vespa/documentapi
src/vespa/documentapi/loadtypes
src/vespa/documentapi/messagebus
diff --git a/documentapi/src/vespa/binref/.gitignore b/documentapi/src/vespa/binref/.gitignore
deleted file mode 100644
index cfb0e619824..00000000000
--- a/documentapi/src/vespa/binref/.gitignore
+++ /dev/null
@@ -1,3 +0,0 @@
-.depend
-Makefile
-testrun.sh
diff --git a/documentapi/src/vespa/binref/CMakeLists.txt b/documentapi/src/vespa/binref/CMakeLists.txt
deleted file mode 100644
index adece6dd711..00000000000
--- a/documentapi/src/vespa/binref/CMakeLists.txt
+++ /dev/null
@@ -1 +0,0 @@
-# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
diff --git a/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java b/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java
index 41302a4c725..84fbb7d4f01 100644
--- a/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java
+++ b/jrt/src/com/yahoo/jrt/TlsCryptoEngine.java
@@ -22,7 +22,6 @@ public class TlsCryptoEngine implements CryptoEngine {
@Override
public TlsCryptoSocket createCryptoSocket(SocketChannel channel, boolean isServer) {
SSLEngine sslEngine = tlsContext.createSslEngine();
- sslEngine.setNeedClientAuth(true);
sslEngine.setUseClientMode(!isServer);
return new TlsCryptoSocket(channel, sslEngine);
}
diff --git a/jrt_test/src/binref/testrun.sh b/jrt_test/src/binref/testrun.sh
deleted file mode 120000
index 56c3c1186d8..00000000000
--- a/jrt_test/src/binref/testrun.sh
+++ /dev/null
@@ -1 +0,0 @@
-../../../vespalib/src/vespa/vespalib/testkit/testrun.sh \ No newline at end of file
diff --git a/lowercasing_test/src/binref/testrun.sh b/lowercasing_test/src/binref/testrun.sh
deleted file mode 120000
index 56c3c1186d8..00000000000
--- a/lowercasing_test/src/binref/testrun.sh
+++ /dev/null
@@ -1 +0,0 @@
-../../../vespalib/src/vespa/vespalib/testkit/testrun.sh \ No newline at end of file
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java
index 668795f362b..4fef3d8ebf7 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java
@@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableSet;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.ClusterMembership;
import com.yahoo.config.provision.Flavor;
+import com.yahoo.config.provision.NetworkPorts;
import com.yahoo.config.provision.NodeType;
import com.yahoo.vespa.hosted.provision.node.Agent;
import com.yahoo.vespa.hosted.provision.node.Allocation;
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintenance.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintenance.java
index 2496d9ba8c9..7fdd9a168c8 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintenance.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintenance.java
@@ -182,7 +182,8 @@ public class NodeRepositoryMaintenance extends AbstractComponent {
DefaultTimes(Zone zone) {
failGrace = Duration.ofMinutes(60);
periodicRedeployInterval = Duration.ofMinutes(30);
- redeployMaintainerInterval = Duration.ofMinutes(1);
+ // Don't redeploy in test environments
+ redeployMaintainerInterval = zone.environment().isTest() ? Duration.ofDays(1) : Duration.ofMinutes(1);
operatorChangeRedeployInterval = Duration.ofMinutes(1);
failedExpirerInterval = Duration.ofMinutes(10);
provisionedExpiry = Duration.ofHours(4);
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Allocation.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Allocation.java
index 8a331209efc..53e1ae3721e 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Allocation.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Allocation.java
@@ -3,6 +3,9 @@ package com.yahoo.vespa.hosted.provision.node;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.ClusterMembership;
+import com.yahoo.config.provision.NetworkPorts;
+
+import java.util.Optional;
/**
* The allocation of a node
@@ -24,12 +27,21 @@ public class Allocation {
/** This node can (and should) be removed from the cluster on the next deployment */
private final boolean removable;
+ private final Optional<NetworkPorts> networkPorts;
+
+
public Allocation(ApplicationId owner, ClusterMembership clusterMembership,
Generation restartGeneration, boolean removable) {
+ this(owner, clusterMembership, restartGeneration, removable, Optional.empty());
+ }
+
+ public Allocation(ApplicationId owner, ClusterMembership clusterMembership,
+ Generation restartGeneration, boolean removable, Optional<NetworkPorts> networkPorts) {
this.owner = owner;
this.clusterMembership = clusterMembership;
this.restartGeneration = restartGeneration;
this.removable = removable;
+ this.networkPorts = networkPorts;
}
/** Returns the id of the application this is allocated to */
@@ -41,14 +53,17 @@ public class Allocation {
/** Returns the restart generation (wanted and current) of this */
public Generation restartGeneration() { return restartGeneration; }
+ /** Returns network ports allocations (or empty if not recorded) */
+ public Optional<NetworkPorts> networkPorts() { return networkPorts; }
+
/** Returns a copy of this which is retired */
public Allocation retire() {
- return new Allocation(owner, clusterMembership.retire(), restartGeneration, removable);
+ return new Allocation(owner, clusterMembership.retire(), restartGeneration, removable, networkPorts);
}
/** Returns a copy of this which is not retired */
public Allocation unretire() {
- return new Allocation(owner, clusterMembership.unretire(), restartGeneration, removable);
+ return new Allocation(owner, clusterMembership.unretire(), restartGeneration, removable, networkPorts);
}
/** Return whether this node is ready to be removed from the application */
@@ -56,16 +71,20 @@ public class Allocation {
/** Returns a copy of this with the current restart generation set to generation */
public Allocation withRestart(Generation generation) {
- return new Allocation(owner, clusterMembership, generation, removable);
+ return new Allocation(owner, clusterMembership, generation, removable, networkPorts);
}
/** Returns a copy of this allocation where removable is set to true */
public Allocation removable() {
- return new Allocation(owner, clusterMembership, restartGeneration, true);
+ return new Allocation(owner, clusterMembership, restartGeneration, true, networkPorts);
}
public Allocation with(ClusterMembership newMembership) {
- return new Allocation(owner, newMembership, restartGeneration, removable);
+ return new Allocation(owner, newMembership, restartGeneration, removable, networkPorts);
+ }
+
+ public Allocation withNetworkPorts(NetworkPorts ports) {
+ return new Allocation(owner, clusterMembership, restartGeneration, removable, Optional.of(ports));
}
@Override
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/NodeSerializer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/NodeSerializer.java
index 54668c4eda1..bb4dab3b97b 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/NodeSerializer.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/NodeSerializer.java
@@ -10,6 +10,8 @@ import com.yahoo.config.provision.Flavor;
import com.yahoo.config.provision.InstanceName;
import com.yahoo.config.provision.NodeFlavors;
import com.yahoo.config.provision.NodeType;
+import com.yahoo.config.provision.NetworkPorts;
+import com.yahoo.config.provision.NetworkPortsSerializer;
import com.yahoo.config.provision.TenantName;
import com.yahoo.slime.ArrayTraverser;
import com.yahoo.slime.Cursor;
@@ -85,6 +87,9 @@ public class NodeSerializer {
private static final String atKey = "at";
private static final String agentKey = "agent"; // retired events only
+ // Network port fields
+ private static final String networkPortsKey = "networkPorts";
+
// ---------------- Serialization ----------------------------------------------------
public NodeSerializer(NodeFlavors flavors) {
@@ -136,6 +141,7 @@ public class NodeSerializer {
object.setLong(currentRestartGenerationKey, allocation.restartGeneration().current());
object.setBool(removableKey, allocation.isRemovable());
object.setString(wantedVespaVersionKey, allocation.membership().cluster().vespaVersion().toString());
+ allocation.networkPorts().ifPresent(ports -> NetworkPortsSerializer.toSlime(ports, object.setArray(networkPortsKey)));
}
private void toSlime(History history, Cursor array) {
@@ -197,7 +203,8 @@ public class NodeSerializer {
return Optional.of(new Allocation(applicationIdFromSlime(object),
clusterMembershipFromSlime(object),
generationFromSlime(object, restartGenerationKey, currentRestartGenerationKey),
- object.field(removableKey).asBool()));
+ object.field(removableKey).asBool(),
+ NetworkPortsSerializer.fromSlime(object.field(networkPortsKey))));
}
private ApplicationId applicationIdFromSlime(Inspector object) {
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java
index f48f0c1bdce..4626a600d2c 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java
@@ -9,6 +9,7 @@ import com.yahoo.transaction.NestedTransaction;
import com.yahoo.vespa.hosted.provision.Node;
import com.yahoo.vespa.hosted.provision.NodeList;
import com.yahoo.vespa.hosted.provision.NodeRepository;
+import com.yahoo.vespa.hosted.provision.node.Allocation;
import java.util.ArrayList;
import java.util.Collection;
@@ -73,7 +74,7 @@ class Activator {
activeToRemove = activeToRemove.stream().map(Node::unretire).collect(Collectors.toList()); // only active nodes can be retired
nodeRepository.deactivate(activeToRemove, transaction);
nodeRepository.activate(updateFrom(hosts, continuedActive), transaction); // update active with any changes
- nodeRepository.activate(reservedToActivate, transaction);
+ nodeRepository.activate(updatePortsFrom(hosts, reservedToActivate), transaction);
}
}
@@ -133,7 +134,11 @@ class Activator {
for (Node node : nodes) {
HostSpec hostSpec = getHost(node.hostname(), hosts);
node = hostSpec.membership().get().retired() ? node.retire(nodeRepository.clock().instant()) : node.unretire();
- node = node.with(node.allocation().get().with(hostSpec.membership().get()));
+ Allocation allocation = node.allocation().get().with(hostSpec.membership().get());
+ if (hostSpec.networkPorts().isPresent()) {
+ allocation = allocation.withNetworkPorts(hostSpec.networkPorts().get());
+ }
+ node = node.with(allocation);
if (hostSpec.flavor().isPresent()) // Docker nodes may change flavor
node = node.with(hostSpec.flavor().get());
updated.add(node);
@@ -141,6 +146,23 @@ class Activator {
return updated;
}
+ /**
+ * Returns the input nodes with any port allocations from the hosts
+ */
+ private List<Node> updatePortsFrom(Collection<HostSpec> hosts, List<Node> nodes) {
+ List<Node> updated = new ArrayList<>();
+ for (Node node : nodes) {
+ HostSpec hostSpec = getHost(node.hostname(), hosts);
+ Allocation allocation = node.allocation().get();
+ if (hostSpec.networkPorts().isPresent()) {
+ allocation = allocation.withNetworkPorts(hostSpec.networkPorts().get());
+ node = node.with(allocation);
+ }
+ updated.add(node);
+ }
+ return updated;
+ }
+
private HostSpec getHost(String hostname, Collection<HostSpec> fromHosts) {
for (HostSpec host : fromHosts)
if (host.hostname().equals(hostname))
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java
index a0d76241533..246c56ee28b 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java
@@ -22,11 +22,13 @@ import com.yahoo.vespa.flags.Flags;
import com.yahoo.vespa.hosted.provision.Node;
import com.yahoo.vespa.hosted.provision.NodeRepository;
import com.yahoo.vespa.hosted.provision.flag.FlagId;
+import com.yahoo.vespa.hosted.provision.node.Allocation;
import com.yahoo.vespa.hosted.provision.node.filter.ApplicationFilter;
import com.yahoo.vespa.hosted.provision.node.filter.NodeHostFilter;
import java.util.ArrayList;
import java.util.Collection;
+import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
@@ -141,10 +143,16 @@ public class NodeRepositoryProvisioner implements Provisioner {
List<HostSpec> hosts = new ArrayList<>(nodes.size());
for (Node node : nodes) {
log.log(LogLevel.DEBUG, () -> "Prepared node " + node.hostname() + " - " + node.flavor());
+ Allocation nodeAllocation = node.allocation().orElseThrow(IllegalStateException::new);
hosts.add(new HostSpec(node.hostname(),
- node.allocation().orElseThrow(IllegalStateException::new).membership(),
- node.flavor(),
- node.status().vespaVersion()));
+ Collections.emptyList(),
+ Optional.of(node.flavor()),
+ Optional.of(nodeAllocation.membership()),
+ node.status().vespaVersion(),
+ nodeAllocation.networkPorts()));
+ if (nodeAllocation.networkPorts().isPresent()) {
+ log.log(LogLevel.DEBUG, () -> "Prepared node " + node.hostname() + " has port allocations");
+ }
}
return hosts;
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/NodesResponse.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/NodesResponse.java
index ba513db5342..1254664eb78 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/NodesResponse.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/v2/NodesResponse.java
@@ -3,6 +3,7 @@ package com.yahoo.vespa.hosted.provision.restapi.v2;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.ClusterMembership;
+import com.yahoo.config.provision.NetworkPortsSerializer;
import com.yahoo.config.provision.NodeType;
import com.yahoo.container.jdisc.HttpRequest;
import com.yahoo.container.jdisc.HttpResponse;
@@ -153,17 +154,18 @@ class NodesResponse extends HttpResponse {
object.setBool("fastDisk", node.flavor().hasFastDisk());
object.setDouble("bandwidth", node.flavor().getBandwidth());
object.setString("environment", node.flavor().getType().name());
- if (node.allocation().isPresent()) {
- toSlime(node.allocation().get().owner(), object.setObject("owner"));
- toSlime(node.allocation().get().membership(), object.setObject("membership"));
- object.setLong("restartGeneration", node.allocation().get().restartGeneration().wanted());
- object.setLong("currentRestartGeneration", node.allocation().get().restartGeneration().current());
- object.setString("wantedDockerImage", nodeRepository.dockerImage().withTag(node.allocation().get().membership().cluster().vespaVersion()).asString());
- object.setString("wantedVespaVersion", node.allocation().get().membership().cluster().vespaVersion().toFullString());
+ node.allocation().ifPresent(allocation -> {
+ toSlime(allocation.owner(), object.setObject("owner"));
+ toSlime(allocation.membership(), object.setObject("membership"));
+ object.setLong("restartGeneration", allocation.restartGeneration().wanted());
+ object.setLong("currentRestartGeneration", allocation.restartGeneration().current());
+ object.setString("wantedDockerImage", nodeRepository.dockerImage().withTag(allocation.membership().cluster().vespaVersion()).asString());
+ object.setString("wantedVespaVersion", allocation.membership().cluster().vespaVersion().toFullString());
+ allocation.networkPorts().ifPresent(ports -> NetworkPortsSerializer.toSlime(ports, object.setArray("networkPorts")));
orchestrator.apply(new HostName(node.hostname()))
.map(status -> status == HostStatus.ALLOWED_TO_BE_DOWN)
.ifPresent(allowedToBeDown -> object.setBool("allowedToBeDown", allowedToBeDown));
- }
+ });
object.setLong("rebootGeneration", node.status().reboot().wanted());
object.setLong("currentRebootGeneration", node.status().reboot().current());
node.status().osVersion().ifPresent(version -> object.setString("currentOsVersion", version.toFullString()));
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/ContainerConfig.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/ContainerConfig.java
index e17e1871555..2a07cadc6ad 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/ContainerConfig.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/ContainerConfig.java
@@ -12,7 +12,7 @@ public class ContainerConfig {
public static String servicesXmlV2(int port) {
return "<jdisc version='1.0'>\n" +
" <config name=\"container.handler.threadpool\">\n" +
- " <maxthreads>10</maxthreads>\n" +
+ " <maxthreads>20</maxthreads>\n" +
" </config> \n" +
" <component id='com.yahoo.test.ManualClock'/>\n" +
" <component id='com.yahoo.vespa.curator.mock.MockCurator'/>\n" +
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/SerializationTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/SerializationTest.java
index 29229efc662..53f6b745da1 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/SerializationTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/SerializationTest.java
@@ -8,6 +8,7 @@ import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.ApplicationName;
import com.yahoo.config.provision.ClusterMembership;
import com.yahoo.config.provision.InstanceName;
+import com.yahoo.config.provision.NetworkPorts;
import com.yahoo.config.provision.NodeFlavors;
import com.yahoo.config.provision.NodeType;
import com.yahoo.config.provision.TenantName;
@@ -29,8 +30,11 @@ import org.junit.Test;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collection;
import java.util.Collections;
+import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
@@ -371,6 +375,31 @@ public class SerializationTest {
assertEquals("some model", node.modelName().get());
}
+ @Test
+ public void testNodeWithNetworkPorts() {
+ Node node = createNode();
+ List<NetworkPorts.Allocation> list = new ArrayList<>();
+ list.add(new NetworkPorts.Allocation(8080, "container", "default/0", "http"));
+ list.add(new NetworkPorts.Allocation(19101, "searchnode", "other/1", "rpc"));
+ NetworkPorts ports = new NetworkPorts(list);
+ node = node.allocate(ApplicationId.from(TenantName.from("myTenant"),
+ ApplicationName.from("myApplication"),
+ InstanceName.from("myInstance")),
+ ClusterMembership.from("content/myId/0/0", Vtag.currentVersion),
+ clock.instant());
+ assertTrue(node.allocation().isPresent());
+ node = node.with(node.allocation().get().withNetworkPorts(ports));
+ assertTrue(node.allocation().isPresent());
+ assertTrue(node.allocation().get().networkPorts().isPresent());
+ Node copy = nodeSerializer.fromJson(Node.State.provisioned, nodeSerializer.toJson(node));
+ assertTrue(copy.allocation().isPresent());
+ assertTrue(copy.allocation().get().networkPorts().isPresent());
+ NetworkPorts portsCopy = node.allocation().get().networkPorts().get();
+ Collection<NetworkPorts.Allocation> listCopy = portsCopy.allocations();
+ assertEquals(list, listCopy);
+ }
+
+
private byte[] createNodeJson(String hostname, String... ipAddress) {
String ipAddressJsonPart = "";
if (ipAddress.length > 0) {
diff --git a/parent/pom.xml b/parent/pom.xml
index 7d7dfc1dc08..1cececbd8e8 100644
--- a/parent/pom.xml
+++ b/parent/pom.xml
@@ -708,6 +708,11 @@
<artifactId>assertj-core</artifactId>
<version>3.11.1</version>
</dependency>
+ <dependency>
+ <groupId>com.amazonaws</groupId>
+ <artifactId>aws-java-sdk-core</artifactId>
+ <version>${aws.sdk.version}</version>
+ </dependency>
</dependencies>
</dependencyManagement>
@@ -717,6 +722,7 @@
<apache.httpclient.version>4.4.1</apache.httpclient.version>
<apache.httpcore.version>4.4.1</apache.httpcore.version>
<asm.version>7.0</asm.version>
+ <aws.sdk.version>1.11.357</aws.sdk.version>
<jna.version>4.5.2</jna.version>
<tensorflow.version>1.12.0</tensorflow.version>
<!-- Athenz dependencies. Make sure these dependencies matches those in Vespa's internal repositories -->
diff --git a/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp b/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp
index afbb1c30f17..78cd9ce44b9 100644
--- a/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp
+++ b/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp
@@ -20,6 +20,7 @@
#include <vespa/document/update/removevalueupdate.h>
#include <vespa/document/update/tensor_add_update.h>
#include <vespa/document/update/tensor_modify_update.h>
+#include <vespa/document/update/tensor_remove_update.h>
#include <vespa/eval/tensor/default_tensor_engine.h>
#include <vespa/eval/tensor/tensor.h>
#include <vespa/searchcore/proton/common/attribute_updater.h>
@@ -28,8 +29,8 @@
#include <vespa/searchlib/attribute/reference_attribute.h>
#include <vespa/searchlib/tensor/dense_tensor_attribute.h>
#include <vespa/searchlib/tensor/generic_tensor_attribute.h>
-#include <vespa/vespalib/testkit/testapp.h>
#include <vespa/vespalib/stllike/hash_map.hpp>
+#include <vespa/vespalib/testkit/testapp.h>
#include <vespa/log/log.h>
LOG_SETUP("attribute_updater_test");
@@ -76,7 +77,8 @@ makeDocumentTypeRepo()
.addField("wsfloat", Wset(DataType::T_FLOAT))
.addField("wsstring", Wset(DataType::T_STRING))
.addField("ref", 333)
- .addField("dense_tensor", DataType::T_TENSOR),
+ .addField("dense_tensor", DataType::T_TENSOR)
+ .addField("sparse_tensor", DataType::T_TENSOR),
Struct("testdoc.body"))
.referenceType(333, 222);
return std::make_unique<DocumentTypeRepo>(builder.config());
@@ -416,35 +418,54 @@ makeTensorFieldValue(const TensorSpec &spec)
return result;
}
-void
-setTensor(TensorAttribute &attribute, uint32_t lid, const TensorSpec &spec)
-{
- auto tensor = makeTensor(spec);
- attribute.setTensor(lid, *tensor);
- attribute.commit();
-}
+template <typename TensorAttributeType>
+struct TensorFixture : public Fixture {
+ vespalib::string type;
+ std::unique_ptr<TensorAttributeType> attribute;
-TEST_F("require that tensor modify update is applied", Fixture)
-{
- vespalib::string type = "tensor(x[2])";
- auto attribute = makeTensorAttribute<DenseTensorAttribute>("dense_tensor", type);
- setTensor(*attribute, 1, TensorSpec(type).add({{"x", 0}}, 3).add({{"x", 1}}, 5));
+ TensorFixture(const vespalib::string &type_, const vespalib::string &name)
+ : type(type_),
+ attribute(makeTensorAttribute<TensorAttributeType>(name, type))
+ {
+ }
- f.applyValueUpdate(*attribute, 1,
+ void setTensor(const TensorSpec &spec) {
+ auto tensor = makeTensor(spec);
+ attribute->setTensor(1, *tensor);
+ attribute->commit();
+ }
+
+ void assertTensor(const TensorSpec &expSpec) {
+ EXPECT_EQUAL(expSpec, attribute->getTensor(1)->toSpec());
+ }
+};
+
+TEST_F("require that tensor modify update is applied",
+ TensorFixture<DenseTensorAttribute>("tensor(x[2])", "dense_tensor"))
+{
+ f.setTensor(TensorSpec(f.type).add({{"x", 0}}, 3).add({{"x", 1}}, 5));
+ f.applyValueUpdate(*f.attribute, 1,
TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE,
makeTensorFieldValue(TensorSpec("tensor(x{})").add({{"x", 0}}, 7))));
- EXPECT_EQUAL(TensorSpec(type).add({{"x", 0}}, 7).add({{"x", 1}}, 5), attribute->getTensor(1)->toSpec());
+ f.assertTensor(TensorSpec(f.type).add({{"x", 0}}, 7).add({{"x", 1}}, 5));
}
-TEST_F("require that tensor add update is applied", Fixture)
+TEST_F("require that tensor add update is applied",
+ TensorFixture<GenericTensorAttribute>("tensor(x{})", "sparse_tensor"))
{
- vespalib::string type = "tensor(x{})";
- auto attribute = makeTensorAttribute<GenericTensorAttribute>("dense_tensor", type);
- setTensor(*attribute, 1, TensorSpec(type).add({{"x", "a"}}, 2));
+ f.setTensor(TensorSpec(f.type).add({{"x", "a"}}, 2));
+ f.applyValueUpdate(*f.attribute, 1,
+ TensorAddUpdate(makeTensorFieldValue(TensorSpec(f.type).add({{"x", "a"}}, 3))));
+ f.assertTensor(TensorSpec(f.type).add({{"x", "a"}}, 3));
+}
- f.applyValueUpdate(*attribute, 1,
- TensorAddUpdate(makeTensorFieldValue(TensorSpec(type).add({{"x", "a"}}, 3))));
- EXPECT_EQUAL(TensorSpec(type).add({{"x", "a"}}, 3), attribute->getTensor(1)->toSpec());
+TEST_F("require that tensor remove update is applied",
+ TensorFixture<GenericTensorAttribute>("tensor(x{})", "sparse_tensor"))
+{
+ f.setTensor(TensorSpec(f.type).add({{"x", "a"}}, 2).add({{"x", "b"}}, 3));
+ f.applyValueUpdate(*f.attribute, 1,
+ TensorRemoveUpdate(makeTensorFieldValue(TensorSpec(f.type).add({{"x", "b"}}, 1))));
+ f.assertTensor(TensorSpec(f.type).add({{"x", "a"}}, 2));
}
}
diff --git a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp
index 933857cffed..fcca1c2a737 100644
--- a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp
+++ b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp
@@ -16,6 +16,7 @@
#include <vespa/document/update/removevalueupdate.h>
#include <vespa/document/update/tensor_add_update.h>
#include <vespa/document/update/tensor_modify_update.h>
+#include <vespa/document/update/tensor_remove_update.h>
#include <vespa/eval/tensor/tensor.h>
#include <vespa/searchlib/attribute/attributevector.hpp>
#include <vespa/searchlib/attribute/changevector.hpp>
@@ -238,6 +239,8 @@ AttributeUpdater::handleUpdate(TensorAttribute &vec, uint32_t lid, const ValueUp
applyTensorUpdate(vec, lid, static_cast<const TensorModifyUpdate &>(upd));
} else if (op == ValueUpdate::TensorAddUpdate) {
applyTensorUpdate(vec, lid, static_cast<const TensorAddUpdate &>(upd));
+ } else if (op == ValueUpdate::TensorRemoveUpdate) {
+ applyTensorUpdate(vec, lid, static_cast<const TensorRemoveUpdate &>(upd));
} else if (op == ValueUpdate::Clear) {
vec.clearDoc(lid);
} else {
diff --git a/security-utils/pom.xml b/security-utils/pom.xml
index 0a26c73cf70..10dec598915 100644
--- a/security-utils/pom.xml
+++ b/security-utils/pom.xml
@@ -54,6 +54,11 @@
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-core</artifactId>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
<plugins>
diff --git a/security-utils/src/main/java/com/yahoo/security/KeyStoreType.java b/security-utils/src/main/java/com/yahoo/security/KeyStoreType.java
index 7fb8df35286..d72bd45865d 100644
--- a/security-utils/src/main/java/com/yahoo/security/KeyStoreType.java
+++ b/security-utils/src/main/java/com/yahoo/security/KeyStoreType.java
@@ -16,7 +16,7 @@ public enum KeyStoreType {
},
PKCS12 {
KeyStore createKeystore() throws KeyStoreException {
- return KeyStore.getInstance("PKCS12", BouncyCastleProviderHolder.getInstance());
+ return KeyStore.getInstance("PKCS12");
}
};
abstract KeyStore createKeystore() throws GeneralSecurityException;
diff --git a/security-utils/src/main/java/com/yahoo/security/SslContextBuilder.java b/security-utils/src/main/java/com/yahoo/security/SslContextBuilder.java
index 09a5a87138f..1ef4df9c7bc 100644
--- a/security-utils/src/main/java/com/yahoo/security/SslContextBuilder.java
+++ b/security-utils/src/main/java/com/yahoo/security/SslContextBuilder.java
@@ -1,11 +1,14 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.security;
+import com.yahoo.security.tls.KeyManagerUtils;
+import com.yahoo.security.tls.TrustManagerUtils;
+
import javax.net.ssl.KeyManager;
-import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
-import javax.net.ssl.TrustManagerFactory;
+import javax.net.ssl.X509ExtendedKeyManager;
+import javax.net.ssl.X509ExtendedTrustManager;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
@@ -19,14 +22,17 @@ import java.util.List;
import static java.util.Collections.singletonList;
/**
+ * A builder for {@link SSLContext}.
+ *
* @author bjorncs
*/
public class SslContextBuilder {
- private KeyStoreSupplier trustStoreSupplier;
- private KeyStoreSupplier keyStoreSupplier;
+ private KeyStoreSupplier trustStoreSupplier = () -> null;
+ private KeyStoreSupplier keyStoreSupplier = () -> null;
private char[] keyStorePassword;
- private TrustManagersFactory trustManagersFactory = SslContextBuilder::createDefaultTrustManagers;
+ private TrustManagerFactory trustManagerFactory = TrustManagerUtils::createDefaultX509TrustManager;
+ private KeyManagerFactory keyManagerFactory = KeyManagerUtils::createDefaultX509KeyManager;
public SslContextBuilder() {}
@@ -94,18 +100,21 @@ public class SslContextBuilder {
return this;
}
- public SslContextBuilder withTrustManagerFactory(TrustManagersFactory trustManagersFactory) {
- this.trustManagersFactory = trustManagersFactory;
+ public SslContextBuilder withTrustManagerFactory(TrustManagerFactory trustManagersFactory) {
+ this.trustManagerFactory = trustManagersFactory;
+ return this;
+ }
+
+ public SslContextBuilder withKeyManagerFactory(KeyManagerFactory keyManagerFactory) {
+ this.keyManagerFactory = keyManagerFactory;
return this;
}
public SSLContext build() {
try {
SSLContext sslContext = SSLContext.getInstance("TLSv1.2");
- TrustManager[] trustManagers =
- trustStoreSupplier != null ? createTrustManagers(trustManagersFactory, trustStoreSupplier) : null;
- KeyManager[] keyManagers =
- keyStoreSupplier != null ? createKeyManagers(keyStoreSupplier, keyStorePassword) : null;
+ TrustManager[] trustManagers = new TrustManager[] { trustManagerFactory.createTrustManager(trustStoreSupplier.get()) };
+ KeyManager[] keyManagers = new KeyManager[] { keyManagerFactory.createKeyManager(keyStoreSupplier.get(), keyStorePassword) };
sslContext.init(keyManagers, trustManagers, null);
return sslContext;
} catch (GeneralSecurityException e) {
@@ -115,27 +124,6 @@ public class SslContextBuilder {
}
}
- private static TrustManager[] createTrustManagers(TrustManagersFactory trustManagersFactory, KeyStoreSupplier trustStoreSupplier)
- throws GeneralSecurityException, IOException {
- KeyStore truststore = trustStoreSupplier.get();
- return trustManagersFactory.createTrustManagers(truststore);
- }
-
- private static TrustManager[] createDefaultTrustManagers(KeyStore truststore) throws GeneralSecurityException {
- TrustManagerFactory trustManagerFactory =
- TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
- trustManagerFactory.init(truststore);
- return trustManagerFactory.getTrustManagers();
- }
-
- private static KeyManager[] createKeyManagers(KeyStoreSupplier keyStoreSupplier, char[] password)
- throws GeneralSecurityException, IOException {
- KeyManagerFactory keyManagerFactory =
- KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
- keyManagerFactory.init(keyStoreSupplier.get(), password);
- return keyManagerFactory.getKeyManagers();
- }
-
private static KeyStore createTrustStore(List<X509Certificate> caCertificates) {
KeyStoreBuilder trustStoreBuilder = KeyStoreBuilder.withType(KeyStoreType.JKS);
for (int i = 0; i < caCertificates.size(); i++) {
@@ -149,11 +137,19 @@ public class SslContextBuilder {
}
/**
- * A factory interface that is similar to {@link TrustManagerFactory}, but is an interface instead of a class.
+ * A factory interface for creating {@link X509ExtendedTrustManager}.
+ */
+ @FunctionalInterface
+ public interface TrustManagerFactory {
+ X509ExtendedTrustManager createTrustManager(KeyStore truststore) throws GeneralSecurityException;
+ }
+
+ /**
+ * A factory interface for creating {@link X509ExtendedKeyManager}.
*/
@FunctionalInterface
- public interface TrustManagersFactory {
- TrustManager[] createTrustManagers(KeyStore truststore) throws GeneralSecurityException;
+ public interface KeyManagerFactory {
+ X509ExtendedKeyManager createKeyManager(KeyStore truststore, char[] password) throws GeneralSecurityException;
}
}
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/AutoReloadingX509KeyManager.java b/security-utils/src/main/java/com/yahoo/security/tls/AutoReloadingX509KeyManager.java
new file mode 100644
index 00000000000..0dae185995c
--- /dev/null
+++ b/security-utils/src/main/java/com/yahoo/security/tls/AutoReloadingX509KeyManager.java
@@ -0,0 +1,150 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.security.tls;
+
+import com.yahoo.security.KeyStoreBuilder;
+import com.yahoo.security.KeyStoreType;
+import com.yahoo.security.KeyUtils;
+import com.yahoo.security.X509CertificateUtils;
+
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.X509ExtendedKeyManager;
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.net.Socket;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.security.KeyStore;
+import java.security.Principal;
+import java.security.PrivateKey;
+import java.security.cert.X509Certificate;
+import java.time.Duration;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * A {@link X509ExtendedKeyManager} that reloads the certificate and private key from file regularly.
+ *
+ * @author bjorncs
+ */
+public class AutoReloadingX509KeyManager extends X509ExtendedKeyManager implements AutoCloseable {
+
+ private static final Duration UPDATE_PERIOD = Duration.ofHours(1);
+
+ private static final Logger log = Logger.getLogger(AutoReloadingX509KeyManager.class.getName());
+
+ private final MutableX509KeyManager mutableX509KeyManager;
+ private final ScheduledExecutorService scheduler;
+ private final Path privateKeyFile;
+ private final Path certificatesFile;
+
+ private AutoReloadingX509KeyManager(Path privateKeyFile, Path certificatesFile) {
+ this(privateKeyFile, certificatesFile, createDefaultScheduler());
+ }
+
+ AutoReloadingX509KeyManager(Path privateKeyFile, Path certificatesFile, ScheduledExecutorService scheduler) {
+ this.privateKeyFile = privateKeyFile;
+ this.certificatesFile = certificatesFile;
+ this.scheduler = scheduler;
+ this.mutableX509KeyManager = new MutableX509KeyManager(createKeystore(privateKeyFile, certificatesFile), new char[0]);
+ scheduler.scheduleAtFixedRate(
+ new KeyManagerReloader(), UPDATE_PERIOD.getSeconds()/*initial delay*/, UPDATE_PERIOD.getSeconds(), TimeUnit.SECONDS);
+ }
+
+ public static AutoReloadingX509KeyManager fromPemFiles(Path privateKeyFile, Path certificatesFile) {
+ return new AutoReloadingX509KeyManager(privateKeyFile, certificatesFile);
+ }
+
+ private static KeyStore createKeystore(Path privateKey, Path certificateChain) {
+ try {
+ return KeyStoreBuilder.withType(KeyStoreType.PKCS12)
+ .withKeyEntry(
+ "default",
+ KeyUtils.fromPemEncodedPrivateKey(Files.readString(privateKey)),
+ X509CertificateUtils.certificateListFromPem(Files.readString(certificateChain)))
+ .build();
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ private static ScheduledExecutorService createDefaultScheduler() {
+ return Executors.newSingleThreadScheduledExecutor(runnable -> {
+ Thread thread = new Thread(runnable, "auto-reloading-x509-key-manager");
+ thread.setDaemon(true);
+ return thread;
+ });
+ }
+
+ private class KeyManagerReloader implements Runnable {
+ @Override
+ public void run() {
+ try {
+ log.log(Level.FINE, () -> String.format("Reloading key and certificate chain (private-key='%s', certificates='%s')", privateKeyFile, certificatesFile));
+ mutableX509KeyManager.updateKeystore(createKeystore(privateKeyFile, certificatesFile), new char[0]);
+ } catch (Throwable t) {
+ log.log(Level.SEVERE,
+ String.format("Failed to load X509 key manager (private-key='%s', certificates='%s'): %s",
+ privateKeyFile, certificatesFile, t.getMessage()),
+ t);
+ }
+ }
+ }
+
+ @Override
+ public void close() {
+ try {
+ scheduler.shutdownNow();
+ scheduler.awaitTermination(5, TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ //
+ // Methods from X509ExtendedKeyManager
+ //
+
+ @Override
+ public String[] getServerAliases(String keyType, Principal[] issuers) {
+ return mutableX509KeyManager.getServerAliases(keyType, issuers);
+ }
+
+ @Override
+ public String[] getClientAliases(String keyType, Principal[] issuers) {
+ return mutableX509KeyManager.getClientAliases(keyType, issuers);
+ }
+
+ @Override
+ public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) {
+ return mutableX509KeyManager.chooseServerAlias(keyType, issuers, socket);
+ }
+
+ @Override
+ public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) {
+ return mutableX509KeyManager.chooseClientAlias(keyType, issuers, socket);
+ }
+
+ @Override
+ public String chooseEngineServerAlias(String keyType, Principal[] issuers, SSLEngine engine) {
+ return mutableX509KeyManager.chooseEngineServerAlias(keyType, issuers, engine);
+ }
+
+ @Override
+ public String chooseEngineClientAlias(String[] keyType, Principal[] issuers, SSLEngine engine) {
+ return mutableX509KeyManager.chooseEngineClientAlias(keyType, issuers, engine);
+ }
+
+ @Override
+ public X509Certificate[] getCertificateChain(String alias) {
+ return mutableX509KeyManager.getCertificateChain(alias);
+ }
+
+ @Override
+ public PrivateKey getPrivateKey(String alias) {
+ return mutableX509KeyManager.getPrivateKey(alias);
+ }
+
+}
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/DefaultTlsContext.java b/security-utils/src/main/java/com/yahoo/security/tls/DefaultTlsContext.java
index 2befd50332a..c9c326df9ed 100644
--- a/security-utils/src/main/java/com/yahoo/security/tls/DefaultTlsContext.java
+++ b/security-utils/src/main/java/com/yahoo/security/tls/DefaultTlsContext.java
@@ -7,7 +7,7 @@ import com.yahoo.security.tls.policy.AuthorizedPeers;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
-import java.nio.file.Path;
+import javax.net.ssl.SSLParameters;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.Arrays;
@@ -38,7 +38,8 @@ public class DefaultTlsContext implements TlsContext {
private static final Logger log = Logger.getLogger(DefaultTlsContext.class.getName());
private final SSLContext sslContext;
- private final List<String> acceptedCiphers;
+ private final String[] validCiphers;
+ private final String[] validProtocols;
public DefaultTlsContext(List<X509Certificate> certificates,
PrivateKey privateKey,
@@ -46,49 +47,77 @@ public class DefaultTlsContext implements TlsContext {
AuthorizedPeers authorizedPeers,
AuthorizationMode mode,
List<String> acceptedCiphers) {
- this.sslContext = createSslContext(certificates, privateKey, caCertificates, authorizedPeers, mode);
- this.acceptedCiphers = acceptedCiphers;
+ this(createSslContext(certificates, privateKey, caCertificates, authorizedPeers, mode),
+ acceptedCiphers);
}
- public DefaultTlsContext(Path tlsOptionsConfigFile, AuthorizationMode mode) {
- TransportSecurityOptions options = TransportSecurityOptions.fromJsonFile(tlsOptionsConfigFile);
- this.sslContext = createSslContext(options, mode);
- this.acceptedCiphers = options.getAcceptedCiphers();
- }
- @Override
- public SSLEngine createSslEngine() {
- SSLEngine sslEngine = sslContext.createSSLEngine();
- restrictSetOfEnabledCiphers(sslEngine, acceptedCiphers);
- restrictTlsProtocols(sslEngine);
- return sslEngine;
+ public DefaultTlsContext(SSLContext sslContext, List<String> acceptedCiphers) {
+ this.sslContext = sslContext;
+ this.validCiphers = getAllowedCiphers(sslContext, acceptedCiphers);
+ this.validProtocols = getAllowedProtocols(sslContext);
}
- private static void restrictSetOfEnabledCiphers(SSLEngine sslEngine, List<String> acceptedCiphers) {
- String[] validCipherSuites = Arrays.stream(sslEngine.getSupportedCipherSuites())
+
+ private static String[] getAllowedCiphers(SSLContext sslContext, List<String> acceptedCiphers) {
+ String[] supportedCipherSuites = sslContext.getSupportedSSLParameters().getCipherSuites();
+ String[] validCipherSuites = Arrays.stream(supportedCipherSuites)
.filter(suite -> ALLOWED_CIPHER_SUITES.contains(suite) && (acceptedCiphers.isEmpty() || acceptedCiphers.contains(suite)))
.toArray(String[]::new);
if (validCipherSuites.length == 0) {
throw new IllegalStateException(
String.format("None of the allowed cipher suites are supported " +
"(allowed-cipher-suites=%s, supported-cipher-suites=%s, accepted-cipher-suites=%s)",
- ALLOWED_CIPHER_SUITES, List.of(sslEngine.getSupportedCipherSuites()), acceptedCiphers));
+ ALLOWED_CIPHER_SUITES, List.of(supportedCipherSuites), acceptedCiphers));
}
- log.log(Level.FINE, () -> String.format("Allowed cipher suites that are supported: %s", Arrays.toString(validCipherSuites)));
- sslEngine.setEnabledCipherSuites(validCipherSuites);
+ log.log(Level.FINE, () -> String.format("Allowed cipher suites that are supported: %s", List.of(validCipherSuites)));
+ return validCipherSuites;
}
- private static void restrictTlsProtocols(SSLEngine sslEngine) {
- String[] validProtocols = Arrays.stream(sslEngine.getSupportedProtocols())
+ private static String[] getAllowedProtocols(SSLContext sslContext) {
+ String[] supportedProtocols = sslContext.getSupportedSSLParameters().getProtocols();
+ String[] validProtocols = Arrays.stream(supportedProtocols)
.filter(ALLOWED_PROTOCOLS::contains)
.toArray(String[]::new);
if (validProtocols.length == 0) {
throw new IllegalArgumentException(
String.format("None of the allowed protocols are supported (allowed-protocols=%s, supported-protocols=%s)",
- ALLOWED_PROTOCOLS, Arrays.toString(sslEngine.getSupportedProtocols())));
+ ALLOWED_PROTOCOLS, List.of(supportedProtocols)));
}
- log.log(Level.FINE, () -> String.format("Allowed protocols that are supported: %s", Arrays.toString(validProtocols)));
- sslEngine.setEnabledProtocols(validProtocols);
+ log.log(Level.FINE, () -> String.format("Allowed protocols that are supported: %s", List.of(validProtocols)));
+ return validProtocols;
+ }
+
+ @Override
+ public SSLContext context() {
+ return sslContext;
+ }
+
+ @Override
+ public SSLParameters parameters() {
+ return createSslParameters();
+ }
+
+ @Override
+ public SSLEngine createSslEngine() {
+ SSLEngine sslEngine = sslContext.createSSLEngine();
+ sslEngine.setSSLParameters(createSslParameters());
+ return sslEngine;
+ }
+
+ @Override
+ public SSLEngine createSslEngine(String peerHost, int peerPort) {
+ SSLEngine sslEngine = sslContext.createSSLEngine(peerHost, peerPort);
+ sslEngine.setSSLParameters(createSslParameters());
+ return sslEngine;
+ }
+
+ private SSLParameters createSslParameters() {
+ SSLParameters newParameters = sslContext.getDefaultSSLParameters();
+ newParameters.setCipherSuites(validCiphers);
+ newParameters.setProtocols(validProtocols);
+ newParameters.setNeedClientAuth(true);
+ return newParameters;
}
private static SSLContext createSslContext(List<X509Certificate> certificates,
@@ -109,16 +138,5 @@ public class DefaultTlsContext implements TlsContext {
return builder.build();
}
- private static SSLContext createSslContext(TransportSecurityOptions options, AuthorizationMode mode) {
- SslContextBuilder builder = new SslContextBuilder();
- options.getCertificatesFile()
- .ifPresent(certificates -> builder.withKeyStore(options.getPrivateKeyFile().get(), certificates));
- options.getCaCertificatesFile().ifPresent(builder::withTrustStore);
- if (mode != AuthorizationMode.DISABLE) {
- options.getAuthorizedPeers().ifPresent(
- authorizedPeers -> builder.withTrustManagerFactory(new PeerAuthorizerTrustManagersFactory(authorizedPeers, mode)));
- }
- return builder.build();
- }
}
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/KeyManagerUtils.java b/security-utils/src/main/java/com/yahoo/security/tls/KeyManagerUtils.java
new file mode 100644
index 00000000000..2e48de3c01f
--- /dev/null
+++ b/security-utils/src/main/java/com/yahoo/security/tls/KeyManagerUtils.java
@@ -0,0 +1,49 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.security.tls;
+
+import com.yahoo.security.KeyStoreBuilder;
+import com.yahoo.security.KeyStoreType;
+
+import javax.net.ssl.KeyManager;
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.X509ExtendedKeyManager;
+import java.security.GeneralSecurityException;
+import java.security.KeyStore;
+import java.security.PrivateKey;
+import java.security.cert.X509Certificate;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Utility methods for constructing {@link X509ExtendedKeyManager}.
+ *
+ * @author bjorncs
+ */
+public class KeyManagerUtils {
+
+ public static X509ExtendedKeyManager createDefaultX509KeyManager(KeyStore keystore, char[] password) {
+ try {
+ KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
+ keyManagerFactory.init(keystore, password);
+ KeyManager[] keyManagers = keyManagerFactory.getKeyManagers();
+ return Arrays.stream(keyManagers)
+ .filter(manager -> manager instanceof X509ExtendedKeyManager)
+ .map(X509ExtendedKeyManager.class::cast)
+ .findFirst()
+ .orElseThrow(() -> new RuntimeException("No X509ExtendedKeyManager in " + List.of(keyManagers)));
+ } catch (GeneralSecurityException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public static X509ExtendedKeyManager createDefaultX509KeyManager(PrivateKey privateKey, List<X509Certificate> certificateChain) {
+ KeyStore keystore = KeyStoreBuilder.withType(KeyStoreType.PKCS12)
+ .withKeyEntry("default", privateKey, certificateChain)
+ .build();
+ return createDefaultX509KeyManager(keystore, new char[0]);
+ }
+
+ public static X509ExtendedKeyManager createDefaultX509KeyManager() {
+ return createDefaultX509KeyManager(null, new char[0]);
+ }
+}
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/MutableX509KeyManager.java b/security-utils/src/main/java/com/yahoo/security/tls/MutableX509KeyManager.java
new file mode 100644
index 00000000000..e5e56f7a181
--- /dev/null
+++ b/security-utils/src/main/java/com/yahoo/security/tls/MutableX509KeyManager.java
@@ -0,0 +1,106 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.security.tls;
+
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.X509ExtendedKeyManager;
+import java.net.Socket;
+import java.security.KeyStore;
+import java.security.Principal;
+import java.security.PrivateKey;
+import java.security.cert.X509Certificate;
+import java.util.WeakHashMap;
+
+/**
+ * A {@link X509ExtendedKeyManager} which can be updated with new certificate chain and private key while in use.
+ *
+ * The implementations assumes that aliases are retrieved from the same thread as the certificate chain and private key.
+ * This is case for OpenJDK 11.
+ *
+ * @author bjorncs
+ */
+public class MutableX509KeyManager extends X509ExtendedKeyManager {
+
+ // Not using ThreadLocal as we want the x509 key manager instances to be collected
+ // when either the thread dies or the MutableX509KeyManager instance is collected (latter not the case for ThreadLocal).
+ private final WeakHashMap<Thread, X509ExtendedKeyManager> threadLocalManager = new WeakHashMap<>();
+ private volatile X509ExtendedKeyManager currentManager;
+
+ public MutableX509KeyManager(KeyStore keystore, char[] password) {
+ this.currentManager = KeyManagerUtils.createDefaultX509KeyManager(keystore, password);
+ }
+
+ public MutableX509KeyManager() {
+ this.currentManager = KeyManagerUtils.createDefaultX509KeyManager();
+ }
+
+ public void updateKeystore(KeyStore keystore, char[] password) {
+ this.currentManager = KeyManagerUtils.createDefaultX509KeyManager(keystore, password);
+ }
+
+ public void useDefaultKeystore() {
+ this.currentManager = KeyManagerUtils.createDefaultX509KeyManager();
+ }
+
+ @Override
+ public String[] getServerAliases(String keyType, Principal[] issuers) {
+ return updateAndGetThreadLocalManager()
+ .getServerAliases(keyType, issuers);
+ }
+
+ @Override
+ public String[] getClientAliases(String keyType, Principal[] issuers) {
+ return updateAndGetThreadLocalManager()
+ .getClientAliases(keyType, issuers);
+ }
+
+ @Override
+ public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) {
+ return updateAndGetThreadLocalManager()
+ .chooseServerAlias(keyType, issuers, socket);
+ }
+
+ @Override
+ public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) {
+ return updateAndGetThreadLocalManager()
+ .chooseClientAlias(keyType, issuers, socket);
+ }
+
+ @Override
+ public String chooseEngineServerAlias(String keyType, Principal[] issuers, SSLEngine engine) {
+ return updateAndGetThreadLocalManager()
+ .chooseEngineServerAlias(keyType, issuers, engine);
+ }
+
+ @Override
+ public String chooseEngineClientAlias(String[] keyType, Principal[] issuers, SSLEngine engine) {
+ return updateAndGetThreadLocalManager()
+ .chooseEngineClientAlias(keyType, issuers, engine);
+ }
+
+ private X509ExtendedKeyManager updateAndGetThreadLocalManager() {
+ X509ExtendedKeyManager currentManager = this.currentManager;
+ threadLocalManager.put(Thread.currentThread(), currentManager);
+ return currentManager;
+ }
+
+ @Override
+ public X509Certificate[] getCertificateChain(String alias) {
+ return getThreadLocalManager()
+ .getCertificateChain(alias);
+ }
+
+ @Override
+ public PrivateKey getPrivateKey(String alias) {
+ return getThreadLocalManager()
+ .getPrivateKey(alias);
+ }
+
+ private X509ExtendedKeyManager getThreadLocalManager() {
+ X509ExtendedKeyManager manager = threadLocalManager.get(Thread.currentThread());
+ if (manager == null) {
+ throw new IllegalStateException("Methods to retrieve valid aliases has not been called previously from this thread");
+ }
+ return manager;
+ }
+
+}
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/MutableX509TrustManager.java b/security-utils/src/main/java/com/yahoo/security/tls/MutableX509TrustManager.java
new file mode 100644
index 00000000000..ed424480d26
--- /dev/null
+++ b/security-utils/src/main/java/com/yahoo/security/tls/MutableX509TrustManager.java
@@ -0,0 +1,70 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.security.tls;
+
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.X509ExtendedTrustManager;
+import java.net.Socket;
+import java.security.KeyStore;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
+
+/**
+ * A {@link X509ExtendedTrustManager} which can be updated with new CA certificates while in use.
+ *
+ * @author bjorncs
+ */
+public class MutableX509TrustManager extends X509ExtendedTrustManager {
+
+ private volatile X509ExtendedTrustManager currentManager;
+
+ public MutableX509TrustManager(KeyStore truststore) {
+ this.currentManager = TrustManagerUtils.createDefaultX509TrustManager(truststore);
+ }
+
+ public MutableX509TrustManager() {
+ this.currentManager = TrustManagerUtils.createDefaultX509TrustManager();
+ }
+
+ public void updateTruststore(KeyStore truststore) {
+ this.currentManager = TrustManagerUtils.createDefaultX509TrustManager(truststore);
+ }
+
+ public void useDefaultTruststore() {
+ this.currentManager = TrustManagerUtils.createDefaultX509TrustManager();
+ }
+
+ @Override
+ public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
+ currentManager.checkClientTrusted(chain, authType);
+ }
+
+ @Override
+ public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
+ currentManager.checkServerTrusted(chain, authType);
+ }
+
+ @Override
+ public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) throws CertificateException {
+ currentManager.checkClientTrusted(chain, authType, socket);
+ }
+
+ @Override
+ public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) throws CertificateException {
+ currentManager.checkServerTrusted(chain, authType, socket);
+ }
+
+ @Override
+ public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine sslEngine) throws CertificateException {
+ currentManager.checkClientTrusted(chain, authType, sslEngine);
+ }
+
+ @Override
+ public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine sslEngine) throws CertificateException {
+ currentManager.checkServerTrusted(chain, authType, sslEngine);
+ }
+
+ @Override
+ public X509Certificate[] getAcceptedIssuers() {
+ return currentManager.getAcceptedIssuers();
+ }
+}
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/ReloadingTlsContext.java b/security-utils/src/main/java/com/yahoo/security/tls/ReloadingTlsContext.java
index 5add13e067d..b57105f54f9 100644
--- a/security-utils/src/main/java/com/yahoo/security/tls/ReloadingTlsContext.java
+++ b/security-utils/src/main/java/com/yahoo/security/tls/ReloadingTlsContext.java
@@ -1,13 +1,28 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.security.tls;
+import com.yahoo.security.KeyStoreBuilder;
+import com.yahoo.security.KeyStoreType;
+import com.yahoo.security.KeyUtils;
+import com.yahoo.security.SslContextBuilder;
+import com.yahoo.security.X509CertificateUtils;
+import com.yahoo.security.tls.authz.PeerAuthorizerTrustManager;
+
+import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLParameters;
+import javax.net.ssl.X509ExtendedTrustManager;
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.nio.file.Files;
import java.nio.file.Path;
+import java.security.KeyStore;
+import java.security.cert.X509Certificate;
import java.time.Duration;
+import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -23,8 +38,9 @@ public class ReloadingTlsContext implements TlsContext {
private static final Logger log = Logger.getLogger(ReloadingTlsContext.class.getName());
private final Path tlsOptionsConfigFile;
- private final AuthorizationMode mode;
- private final AtomicReference<TlsContext> currentTlsContext;
+ private final TlsContext tlsContext;
+ private final MutableX509TrustManager trustManager = new MutableX509TrustManager();
+ private final MutableX509KeyManager keyManager = new MutableX509KeyManager();
private final ScheduledExecutorService scheduler =
Executors.newSingleThreadScheduledExecutor(runnable -> {
Thread thread = new Thread(runnable, "tls-context-reloader");
@@ -34,19 +50,77 @@ public class ReloadingTlsContext implements TlsContext {
public ReloadingTlsContext(Path tlsOptionsConfigFile, AuthorizationMode mode) {
this.tlsOptionsConfigFile = tlsOptionsConfigFile;
- this.mode = mode;
- this.currentTlsContext = new AtomicReference<>(new DefaultTlsContext(tlsOptionsConfigFile, mode));
- this.scheduler.scheduleAtFixedRate(new SslContextReloader(),
+ TransportSecurityOptions options = TransportSecurityOptions.fromJsonFile(tlsOptionsConfigFile);
+ reloadCryptoMaterial(options, trustManager, keyManager);
+ this.tlsContext = createDefaultTlsContext(options, mode, trustManager, keyManager);
+ this.scheduler.scheduleAtFixedRate(new CryptoMaterialReloader(),
UPDATE_PERIOD.getSeconds()/*initial delay*/,
UPDATE_PERIOD.getSeconds(),
TimeUnit.SECONDS);
}
- @Override
- public SSLEngine createSslEngine() {
- return currentTlsContext.get().createSslEngine();
+ private static void reloadCryptoMaterial(TransportSecurityOptions options,
+ MutableX509TrustManager trustManager,
+ MutableX509KeyManager keyManager) {
+ if (options.getCaCertificatesFile().isPresent()) {
+ trustManager.updateTruststore(loadTruststore(options.getCaCertificatesFile().get()));
+ } else {
+ trustManager.useDefaultTruststore();
+ }
+
+ if (options.getPrivateKeyFile().isPresent() && options.getCertificatesFile().isPresent()) {
+ keyManager.updateKeystore(loadKeystore(options.getPrivateKeyFile().get(), options.getCertificatesFile().get()), new char[0]);
+ } else {
+ keyManager.useDefaultKeystore();
+ }
}
+ private static KeyStore loadTruststore(Path caCertificateFile) {
+ try {
+ List<X509Certificate> caCertificates = X509CertificateUtils.certificateListFromPem(Files.readString(caCertificateFile));
+ KeyStoreBuilder trustStoreBuilder = KeyStoreBuilder.withType(KeyStoreType.PKCS12);
+ for (int i = 0; i < caCertificates.size(); i++) {
+ trustStoreBuilder.withCertificateEntry("cert-" + i, caCertificates.get(i));
+ }
+ return trustStoreBuilder.build();
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ private static KeyStore loadKeystore(Path privateKeyFile, Path certificatesFile) {
+ try {
+ return KeyStoreBuilder.withType(KeyStoreType.PKCS12)
+ .withKeyEntry(
+ "default",
+ KeyUtils.fromPemEncodedPrivateKey(Files.readString(privateKeyFile)),
+ X509CertificateUtils.certificateListFromPem(Files.readString(certificatesFile)))
+ .build();
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ private static DefaultTlsContext createDefaultTlsContext(TransportSecurityOptions options,
+ AuthorizationMode mode,
+ MutableX509TrustManager mutableTrustManager,
+ MutableX509KeyManager mutableKeyManager) {
+ SSLContext sslContext = new SslContextBuilder()
+ .withKeyManagerFactory((ignoredKeystore, ignoredPassword) -> mutableKeyManager)
+ .withTrustManagerFactory(
+ ignoredTruststore -> options.getAuthorizedPeers()
+ .map(authorizedPeers -> (X509ExtendedTrustManager) new PeerAuthorizerTrustManager(authorizedPeers, mode, mutableTrustManager))
+ .orElse(mutableTrustManager))
+ .build();
+ return new DefaultTlsContext(sslContext, options.getAcceptedCiphers());
+ }
+
+ // Wrapped methods from TlsContext
+ @Override public SSLContext context() { return tlsContext.context(); }
+ @Override public SSLParameters parameters() { return tlsContext.parameters(); }
+ @Override public SSLEngine createSslEngine() { return tlsContext.createSslEngine(); }
+ @Override public SSLEngine createSslEngine(String peerHost, int peerPort) { return tlsContext.createSslEngine(peerHost, peerPort); }
+
@Override
public void close() {
try {
@@ -57,13 +131,13 @@ public class ReloadingTlsContext implements TlsContext {
}
}
- private class SslContextReloader implements Runnable {
+ private class CryptoMaterialReloader implements Runnable {
@Override
public void run() {
try {
- currentTlsContext.set(new DefaultTlsContext(tlsOptionsConfigFile, mode));
+ reloadCryptoMaterial(TransportSecurityOptions.fromJsonFile(tlsOptionsConfigFile), trustManager, keyManager);
} catch (Throwable t) {
- log.log(Level.SEVERE, String.format("Failed to load SSLContext (path='%s'): %s", tlsOptionsConfigFile, t.getMessage()), t);
+ log.log(Level.SEVERE, String.format("Failed to reload crypto material (path='%s'): %s", tlsOptionsConfigFile, t.getMessage()), t);
}
}
}
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/TlsContext.java b/security-utils/src/main/java/com/yahoo/security/tls/TlsContext.java
index 58687a0ba8f..b315dd00b31 100644
--- a/security-utils/src/main/java/com/yahoo/security/tls/TlsContext.java
+++ b/security-utils/src/main/java/com/yahoo/security/tls/TlsContext.java
@@ -3,6 +3,7 @@ package com.yahoo.security.tls;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLParameters;
/**
* A simplified version of {@link SSLContext} modelled as an interface.
@@ -11,8 +12,14 @@ import javax.net.ssl.SSLEngine;
*/
public interface TlsContext extends AutoCloseable {
+ SSLContext context();
+
+ SSLParameters parameters();
+
SSLEngine createSslEngine();
+ SSLEngine createSslEngine(String peerHost, int peerPort);
+
@Override default void close() {}
}
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/TrustManagerUtils.java b/security-utils/src/main/java/com/yahoo/security/tls/TrustManagerUtils.java
new file mode 100644
index 00000000000..f114b672ed8
--- /dev/null
+++ b/security-utils/src/main/java/com/yahoo/security/tls/TrustManagerUtils.java
@@ -0,0 +1,50 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.security.tls;
+
+import com.yahoo.security.KeyStoreBuilder;
+import com.yahoo.security.KeyStoreType;
+
+import javax.net.ssl.TrustManager;
+import javax.net.ssl.TrustManagerFactory;
+import javax.net.ssl.X509ExtendedTrustManager;
+import java.security.GeneralSecurityException;
+import java.security.KeyStore;
+import java.security.cert.X509Certificate;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Utility methods for constructing {@link X509ExtendedTrustManager}.
+ *
+ * @author bjorncs
+ */
+public class TrustManagerUtils {
+
+ public static X509ExtendedTrustManager createDefaultX509TrustManager(KeyStore truststore) {
+ try {
+ TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+ trustManagerFactory.init(truststore);
+ TrustManager[] trustManagers = trustManagerFactory.getTrustManagers();
+ return Arrays.stream(trustManagers)
+ .filter(manager -> manager instanceof X509ExtendedTrustManager)
+ .map(X509ExtendedTrustManager.class::cast)
+ .findFirst()
+ .orElseThrow(() -> new RuntimeException("No X509ExtendedTrustManager in " + List.of(trustManagers)));
+ } catch (GeneralSecurityException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public static X509ExtendedTrustManager createDefaultX509TrustManager(List<X509Certificate> certificates) {
+ KeyStoreBuilder truststoreBuilder = KeyStoreBuilder.withType(KeyStoreType.PKCS12);
+ for (int i = 0; i < certificates.size(); i++) {
+ truststoreBuilder.withCertificateEntry("cert-" + i, certificates.get(i));
+ }
+ KeyStore truststore = truststoreBuilder.build();
+ return createDefaultX509TrustManager(truststore);
+ }
+
+ public static X509ExtendedTrustManager createDefaultX509TrustManager() {
+ return createDefaultX509TrustManager((KeyStore) null);
+ }
+}
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManager.java b/security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManager.java
index 80acc940a99..eee2e502183 100644
--- a/security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManager.java
+++ b/security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManager.java
@@ -3,14 +3,12 @@ package com.yahoo.security.tls.authz;
import com.yahoo.security.X509CertificateUtils;
import com.yahoo.security.tls.AuthorizationMode;
+import com.yahoo.security.tls.TrustManagerUtils;
import com.yahoo.security.tls.policy.AuthorizedPeers;
import javax.net.ssl.SSLEngine;
-import javax.net.ssl.TrustManager;
-import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509ExtendedTrustManager;
import java.net.Socket;
-import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
@@ -39,22 +37,8 @@ public class PeerAuthorizerTrustManager extends X509ExtendedTrustManager {
this.defaultTrustManager = defaultTrustManager;
}
- public static TrustManager[] wrapTrustManagersFromKeystore(AuthorizedPeers authorizedPeers, AuthorizationMode mode, KeyStore keystore) throws GeneralSecurityException {
- TrustManagerFactory factory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
- factory.init(keystore);
- return wrapTrustManagers(authorizedPeers, mode, factory.getTrustManagers());
- }
-
- public static TrustManager[] wrapTrustManagers(AuthorizedPeers authorizedPeers, AuthorizationMode mode, TrustManager[] managers) {
- TrustManager[] wrappedManagers = new TrustManager[managers.length];
- for (int i = 0; i < managers.length; i++) {
- if (managers[i] instanceof X509ExtendedTrustManager) {
- wrappedManagers[i] = new PeerAuthorizerTrustManager(authorizedPeers, mode, (X509ExtendedTrustManager) managers[i]);
- } else {
- wrappedManagers[i] = managers[i];
- }
- }
- return wrappedManagers;
+ public PeerAuthorizerTrustManager(AuthorizedPeers authorizedPeers, AuthorizationMode mode, KeyStore truststore) {
+ this(authorizedPeers, mode, TrustManagerUtils.createDefaultX509TrustManager(truststore));
}
@Override
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManagersFactory.java b/security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManagersFactory.java
index c0a3b4e41a5..6ec8450c035 100644
--- a/security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManagersFactory.java
+++ b/security-utils/src/main/java/com/yahoo/security/tls/authz/PeerAuthorizerTrustManagersFactory.java
@@ -5,14 +5,12 @@ import com.yahoo.security.SslContextBuilder;
import com.yahoo.security.tls.AuthorizationMode;
import com.yahoo.security.tls.policy.AuthorizedPeers;
-import javax.net.ssl.TrustManager;
-import java.security.GeneralSecurityException;
import java.security.KeyStore;
/**
* @author bjorncs
*/
-public class PeerAuthorizerTrustManagersFactory implements SslContextBuilder.TrustManagersFactory {
+public class PeerAuthorizerTrustManagersFactory implements SslContextBuilder.TrustManagerFactory {
private final AuthorizedPeers authorizedPeers;
private AuthorizationMode mode;
@@ -22,7 +20,7 @@ public class PeerAuthorizerTrustManagersFactory implements SslContextBuilder.Tru
}
@Override
- public TrustManager[] createTrustManagers(KeyStore truststore) throws GeneralSecurityException {
- return PeerAuthorizerTrustManager.wrapTrustManagersFromKeystore(authorizedPeers, mode, truststore);
+ public PeerAuthorizerTrustManager createTrustManager(KeyStore truststore) {
+ return new PeerAuthorizerTrustManager(authorizedPeers, mode, truststore);
}
}
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClient.java b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClient.java
new file mode 100644
index 00000000000..2911b77707a
--- /dev/null
+++ b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClient.java
@@ -0,0 +1,101 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.security.tls.https;
+
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLParameters;
+import java.io.IOException;
+import java.net.Authenticator;
+import java.net.CookieHandler;
+import java.net.ProxySelector;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.time.Duration;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executor;
+
+/**
+ * A {@link HttpClient} that uses either http or https based on the global Vespa TLS configuration.
+ *
+ * @author bjorncs
+ */
+class TlsAwareHttpClient extends HttpClient {
+
+ private final HttpClient wrappedClient;
+ private final String userAgent;
+
+ TlsAwareHttpClient(HttpClient wrappedClient, String userAgent) {
+ this.wrappedClient = wrappedClient;
+ this.userAgent = userAgent;
+ }
+
+ @Override
+ public Optional<CookieHandler> cookieHandler() {
+ return wrappedClient.cookieHandler();
+ }
+
+ @Override
+ public Optional<Duration> connectTimeout() {
+ return wrappedClient.connectTimeout();
+ }
+
+ @Override
+ public Redirect followRedirects() {
+ return wrappedClient.followRedirects();
+ }
+
+ @Override
+ public Optional<ProxySelector> proxy() {
+ return wrappedClient.proxy();
+ }
+
+ @Override
+ public SSLContext sslContext() {
+ return wrappedClient.sslContext();
+ }
+
+ @Override
+ public SSLParameters sslParameters() {
+ return wrappedClient.sslParameters();
+ }
+
+ @Override
+ public Optional<Authenticator> authenticator() {
+ return wrappedClient.authenticator();
+ }
+
+ @Override
+ public Version version() {
+ return wrappedClient.version();
+ }
+
+ @Override
+ public Optional<Executor> executor() {
+ return wrappedClient.executor();
+ }
+
+ @Override
+ public <T> HttpResponse<T> send(HttpRequest request, HttpResponse.BodyHandler<T> responseBodyHandler) throws IOException, InterruptedException {
+ return wrappedClient.send(wrapRequest(request), responseBodyHandler);
+ }
+
+ @Override
+ public <T> CompletableFuture<HttpResponse<T>> sendAsync(HttpRequest request, HttpResponse.BodyHandler<T> responseBodyHandler) {
+ return wrappedClient.sendAsync(wrapRequest(request), responseBodyHandler);
+ }
+
+ @Override
+ public <T> CompletableFuture<HttpResponse<T>> sendAsync(HttpRequest request, HttpResponse.BodyHandler<T> responseBodyHandler, HttpResponse.PushPromiseHandler<T> pushPromiseHandler) {
+ return wrappedClient.sendAsync(wrapRequest(request), responseBodyHandler, pushPromiseHandler);
+ }
+
+ @Override
+ public String toString() {
+ return wrappedClient.toString();
+ }
+
+ private HttpRequest wrapRequest(HttpRequest request) {
+ return new TlsAwareHttpRequest(request, userAgent);
+ }
+}
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClientBuilder.java b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClientBuilder.java
new file mode 100644
index 00000000000..7eca2463ba7
--- /dev/null
+++ b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClientBuilder.java
@@ -0,0 +1,97 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.security.tls.https;
+
+import com.yahoo.security.tls.TlsContext;
+
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLParameters;
+import java.net.Authenticator;
+import java.net.CookieHandler;
+import java.net.ProxySelector;
+import java.net.http.HttpClient;
+import java.time.Duration;
+import java.util.concurrent.Executor;
+
+/**
+ * A client builder for {@link HttpClient} which uses {@link TlsContext} for TLS configuration.
+ * Intended for internal Vespa communication only.
+ *
+ * @author bjorncs
+ */
+public class TlsAwareHttpClientBuilder implements HttpClient.Builder {
+
+ private final HttpClient.Builder wrappedBuilder;
+ private final String userAgent;
+
+ public TlsAwareHttpClientBuilder(TlsContext tlsContext) {
+ this(tlsContext, "vespa-tls-aware-client");
+ }
+
+ public TlsAwareHttpClientBuilder(TlsContext tlsContext, String userAgent) {
+ this.wrappedBuilder = HttpClient.newBuilder()
+ .sslContext(tlsContext.context())
+ .sslParameters(tlsContext.parameters());
+ this.userAgent = userAgent;
+ }
+
+ @Override
+ public HttpClient.Builder cookieHandler(CookieHandler cookieHandler) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public HttpClient.Builder connectTimeout(Duration duration) {
+ wrappedBuilder.connectTimeout(duration);
+ return this;
+ }
+
+ @Override
+ public HttpClient.Builder sslContext(SSLContext sslContext) {
+ throw new UnsupportedOperationException("SSLContext is given from tls context");
+ }
+
+ @Override
+ public HttpClient.Builder sslParameters(SSLParameters sslParameters) {
+ throw new UnsupportedOperationException("SSLParameters is given from tls context");
+ }
+
+ @Override
+ public HttpClient.Builder executor(Executor executor) {
+ wrappedBuilder.executor(executor);
+ return this;
+ }
+
+ @Override
+ public HttpClient.Builder followRedirects(HttpClient.Redirect policy) {
+ wrappedBuilder.followRedirects(policy);
+ return this;
+ }
+
+ @Override
+ public HttpClient.Builder version(HttpClient.Version version) {
+ wrappedBuilder.version(version);
+ return this;
+ }
+
+ @Override
+ public HttpClient.Builder priority(int priority) {
+ wrappedBuilder.priority(priority);
+ return this;
+ }
+
+ @Override
+ public HttpClient.Builder proxy(ProxySelector proxySelector) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public HttpClient.Builder authenticator(Authenticator authenticator) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public HttpClient build() {
+ // TODO Stop wrapping the client once TLS is mandatory
+ return new TlsAwareHttpClient(wrappedBuilder.build(), userAgent);
+ }
+}
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpRequest.java b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpRequest.java
new file mode 100644
index 00000000000..bbdd8af791f
--- /dev/null
+++ b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpRequest.java
@@ -0,0 +1,103 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.security.tls.https;
+
+import com.yahoo.security.tls.MixedMode;
+import com.yahoo.security.tls.TransportSecurityUtils;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.net.http.HttpClient;
+import java.net.http.HttpHeaders;
+import java.net.http.HttpRequest;
+import java.time.Duration;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * A {@link HttpRequest} where the scheme is either http or https based on the global Vespa TLS configuration.
+ *
+ * @author bjorncs
+ */
+class TlsAwareHttpRequest extends HttpRequest {
+
+ private final URI rewrittenUri;
+ private final HttpRequest wrappedRequest;
+ private final HttpHeaders rewrittenHeaders;
+
+ TlsAwareHttpRequest(HttpRequest wrappedRequest, String userAgent) {
+ this.wrappedRequest = wrappedRequest;
+ this.rewrittenUri = rewriteUri(wrappedRequest.uri());
+ this.rewrittenHeaders = rewriteHeaders(wrappedRequest, userAgent);
+ }
+
+ @Override
+ public Optional<BodyPublisher> bodyPublisher() {
+ return wrappedRequest.bodyPublisher();
+ }
+
+ @Override
+ public String method() {
+ return wrappedRequest.method();
+ }
+
+ @Override
+ public Optional<Duration> timeout() {
+ return wrappedRequest.timeout();
+ }
+
+ @Override
+ public boolean expectContinue() {
+ return wrappedRequest.expectContinue();
+ }
+
+ @Override
+ public URI uri() {
+ return rewrittenUri;
+ }
+
+ @Override
+ public Optional<HttpClient.Version> version() {
+ return wrappedRequest.version();
+ }
+
+ @Override
+ public HttpHeaders headers() {
+ return rewrittenHeaders;
+ }
+
+ private static URI rewriteUri(URI uri) {
+ if (!uri.getScheme().equals("http")) {
+ return uri;
+ }
+ String rewrittenScheme =
+ TransportSecurityUtils.getConfigFile().isPresent() && TransportSecurityUtils.getInsecureMixedMode() != MixedMode.PLAINTEXT_CLIENT_MIXED_SERVER ?
+ "https" :
+ "http";
+ int port = uri.getPort();
+ int rewrittenPort = port != -1 ? port : (rewrittenScheme.equals("http") ? 80 : 443);
+ try {
+ return new URI(rewrittenScheme, uri.getUserInfo(), uri.getHost(), rewrittenPort, uri.getPath(), uri.getQuery(), uri.getFragment());
+ } catch (URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private static HttpHeaders rewriteHeaders(HttpRequest request, String userAgent) {
+ HttpHeaders headers = request.headers();
+ if (headers.firstValue("User-Agent").isPresent()) {
+ return headers;
+ }
+ HashMap<String, List<String>> rewrittenHeaders = new HashMap<>(headers.map());
+ rewrittenHeaders.put("User-Agent", List.of(userAgent));
+ return HttpHeaders.of(rewrittenHeaders, (ignored1, ignored2) -> true);
+ }
+
+ @Override
+ public String toString() {
+ return "TlsAwareHttpRequest{" +
+ "rewrittenUri=" + rewrittenUri +
+ ", wrappedRequest=" + wrappedRequest +
+ '}';
+ }
+}
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/package-info.java b/security-utils/src/main/java/com/yahoo/security/tls/https/package-info.java
new file mode 100644
index 00000000000..43067705fa3
--- /dev/null
+++ b/security-utils/src/main/java/com/yahoo/security/tls/https/package-info.java
@@ -0,0 +1,8 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+/**
+ * @author bjorncs
+ */
+@ExportPackage
+package com.yahoo.security.tls.https;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/security-utils/src/test/java/com/yahoo/security/tls/AutoReloadingX509KeyManagerTest.java b/security-utils/src/test/java/com/yahoo/security/tls/AutoReloadingX509KeyManagerTest.java
new file mode 100644
index 00000000000..139d5313074
--- /dev/null
+++ b/security-utils/src/test/java/com/yahoo/security/tls/AutoReloadingX509KeyManagerTest.java
@@ -0,0 +1,84 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.security.tls;
+
+import com.yahoo.security.KeyAlgorithm;
+import com.yahoo.security.KeyUtils;
+import com.yahoo.security.SignatureAlgorithm;
+import com.yahoo.security.X509CertificateBuilder;
+import com.yahoo.security.X509CertificateUtils;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mockito;
+
+import javax.security.auth.x500.X500Principal;
+import java.io.IOException;
+import java.math.BigInteger;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.security.KeyPair;
+import java.security.Principal;
+import java.security.cert.X509Certificate;
+import java.time.Instant;
+import java.util.concurrent.ScheduledExecutorService;
+
+import static java.time.temporal.ChronoUnit.DAYS;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyLong;
+import static org.mockito.Mockito.verify;
+
+/**
+ * @author bjorncs
+ */
+public class AutoReloadingX509KeyManagerTest {
+ private static final X500Principal SUBJECT = new X500Principal("CN=dummy");
+
+ @Rule
+ public TemporaryFolder tempDirectory = new TemporaryFolder();
+
+ @Test
+ public void crypto_material_is_reloaded_when_scheduler_task_is_executed() throws IOException {
+ KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC);
+ Path privateKeyFile = tempDirectory.newFile().toPath();
+ Files.writeString(privateKeyFile, KeyUtils.toPem(keyPair.getPrivate()));
+
+ Path certificateFile = tempDirectory.newFile().toPath();
+ BigInteger serialNumberInitialCertificate = BigInteger.ONE;
+ X509Certificate initialCertificate = generateCertificate(keyPair, serialNumberInitialCertificate);
+ Files.writeString(certificateFile, X509CertificateUtils.toPem(initialCertificate));
+
+ ScheduledExecutorService scheduler = Mockito.mock(ScheduledExecutorService.class);
+ ArgumentCaptor<Runnable> updaterTaskCaptor = ArgumentCaptor.forClass(Runnable.class);
+
+ AutoReloadingX509KeyManager keyManager = new AutoReloadingX509KeyManager(privateKeyFile, certificateFile, scheduler);
+ verify(scheduler).scheduleAtFixedRate(updaterTaskCaptor.capture(), anyLong(), anyLong(), any());
+
+ String[] initialAliases = keyManager.getClientAliases(keyPair.getPublic().getAlgorithm(), new Principal[]{SUBJECT});
+ X509Certificate[] certChain = keyManager.getCertificateChain(initialAliases[0]);
+ assertThat(certChain).hasSize(1);
+ assertThat(certChain[0].getSerialNumber()).isEqualTo(serialNumberInitialCertificate);
+
+ BigInteger serialNumberUpdatedCertificate = BigInteger.TWO;
+ X509Certificate updatedCertificate = generateCertificate(keyPair, serialNumberUpdatedCertificate);
+ Files.writeString(certificateFile, X509CertificateUtils.toPem(updatedCertificate));
+
+ updaterTaskCaptor.getValue().run(); // run update task in ReloadingX509KeyManager
+
+ String[] updatedAliases = keyManager.getClientAliases(keyPair.getPublic().getAlgorithm(), new Principal[]{SUBJECT});
+ X509Certificate[] updatedCertChain = keyManager.getCertificateChain(updatedAliases[0]);
+ assertThat(updatedCertChain).hasSize(1);
+ assertThat(updatedCertChain[0].getSerialNumber()).isEqualTo(serialNumberUpdatedCertificate);
+ }
+
+ private static X509Certificate generateCertificate(KeyPair keyPair, BigInteger serialNumber) {
+ return X509CertificateBuilder.fromKeypair(keyPair,
+ SUBJECT,
+ Instant.EPOCH,
+ Instant.EPOCH.plus(1, DAYS),
+ SignatureAlgorithm.SHA256_WITH_ECDSA,
+ serialNumber)
+ .build();
+ }
+} \ No newline at end of file
diff --git a/security-utils/src/test/java/com/yahoo/security/tls/MutableX509KeyManagerTest.java b/security-utils/src/test/java/com/yahoo/security/tls/MutableX509KeyManagerTest.java
new file mode 100644
index 00000000000..30e54d3c09d
--- /dev/null
+++ b/security-utils/src/test/java/com/yahoo/security/tls/MutableX509KeyManagerTest.java
@@ -0,0 +1,65 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.security.tls;
+
+import com.yahoo.security.KeyAlgorithm;
+import com.yahoo.security.KeyStoreBuilder;
+import com.yahoo.security.KeyStoreType;
+import com.yahoo.security.KeyUtils;
+import com.yahoo.security.SignatureAlgorithm;
+import com.yahoo.security.X509CertificateBuilder;
+import org.junit.Test;
+
+import javax.security.auth.x500.X500Principal;
+import java.math.BigInteger;
+import java.security.KeyPair;
+import java.security.KeyStore;
+import java.security.Principal;
+import java.security.cert.X509Certificate;
+import java.time.Instant;
+
+import static java.time.temporal.ChronoUnit.DAYS;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * @author bjorncs
+ */
+public class MutableX509KeyManagerTest {
+
+ private static final X500Principal SUBJECT = new X500Principal("CN=dummy");
+
+ @Test
+ public void key_manager_can_be_updated_with_new_certificate() {
+ KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC);
+
+ BigInteger serialNumberInitialCertificate = BigInteger.ONE;
+ KeyStore initialKeystore = generateKeystore(keyPair, serialNumberInitialCertificate);
+
+ MutableX509KeyManager keyManager = new MutableX509KeyManager(initialKeystore, new char[0]);
+
+ String[] initialAliases = keyManager.getClientAliases(keyPair.getPublic().getAlgorithm(), new Principal[]{SUBJECT});
+ assertThat(initialAliases).hasSize(1);
+ X509Certificate[] certChain = keyManager.getCertificateChain(initialAliases[0]);
+ assertThat(certChain).hasSize(1);
+ assertThat(certChain[0].getSerialNumber()).isEqualTo(serialNumberInitialCertificate);
+
+ BigInteger serialNumberUpdatedCertificate = BigInteger.TWO;
+ KeyStore updatedKeystore = generateKeystore(keyPair, serialNumberUpdatedCertificate);
+ keyManager.updateKeystore(updatedKeystore, new char[0]);
+
+ String[] updatedAliases = keyManager.getClientAliases(keyPair.getPublic().getAlgorithm(), new Principal[]{SUBJECT});
+ assertThat(updatedAliases).hasSize(1);
+ X509Certificate[] updatedCertChain = keyManager.getCertificateChain(updatedAliases[0]);
+ assertThat(updatedCertChain).hasSize(1);
+ assertThat(updatedCertChain[0].getSerialNumber()).isEqualTo(serialNumberUpdatedCertificate);
+ }
+
+ private static KeyStore generateKeystore(KeyPair keyPair, BigInteger serialNumber) {
+ X509Certificate certificate = X509CertificateBuilder.fromKeypair(
+ keyPair, SUBJECT, Instant.EPOCH, Instant.EPOCH.plus(1, DAYS), SignatureAlgorithm.SHA256_WITH_ECDSA, serialNumber)
+ .build();
+ return KeyStoreBuilder.withType(KeyStoreType.PKCS12)
+ .withKeyEntry("default", keyPair.getPrivate(), certificate)
+ .build();
+ }
+
+} \ No newline at end of file
diff --git a/security-utils/src/test/java/com/yahoo/security/tls/MutableX509TrustManagerTest.java b/security-utils/src/test/java/com/yahoo/security/tls/MutableX509TrustManagerTest.java
new file mode 100644
index 00000000000..4c4ea332818
--- /dev/null
+++ b/security-utils/src/test/java/com/yahoo/security/tls/MutableX509TrustManagerTest.java
@@ -0,0 +1,59 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.security.tls;
+
+import com.yahoo.security.KeyAlgorithm;
+import com.yahoo.security.KeyStoreBuilder;
+import com.yahoo.security.KeyStoreType;
+import com.yahoo.security.KeyUtils;
+import com.yahoo.security.SignatureAlgorithm;
+import com.yahoo.security.X509CertificateBuilder;
+import org.junit.Test;
+
+import javax.security.auth.x500.X500Principal;
+import java.math.BigInteger;
+import java.security.KeyPair;
+import java.security.KeyStore;
+import java.security.cert.X509Certificate;
+import java.time.Instant;
+
+import static java.time.temporal.ChronoUnit.DAYS;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * @author bjorncs
+ */
+public class MutableX509TrustManagerTest {
+
+ @Test
+ public void key_manager_can_be_updated_with_new_certificate() {
+ KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC);
+
+ X509Certificate initialCertificate = generateCertificate(new X500Principal("CN=issuer1"), keyPair);
+ KeyStore initialTruststore = generateTruststore(initialCertificate);
+
+ MutableX509TrustManager trustManager = new MutableX509TrustManager(initialTruststore);
+
+ X509Certificate[] initialAcceptedIssuers = trustManager.getAcceptedIssuers();
+ assertThat(initialAcceptedIssuers).containsExactly(initialCertificate);
+
+ X509Certificate updatedCertificate = generateCertificate(new X500Principal("CN=issuer2"), keyPair);
+ KeyStore updatedTruststore = generateTruststore(updatedCertificate);
+ trustManager.updateTruststore(updatedTruststore);
+
+ X509Certificate[] updatedAcceptedIssuers = trustManager.getAcceptedIssuers();
+ assertThat(updatedAcceptedIssuers).containsExactly(updatedCertificate);
+ }
+
+ private static X509Certificate generateCertificate(X500Principal issuer, KeyPair keyPair) {
+ return X509CertificateBuilder.fromKeypair(
+ keyPair, issuer, Instant.EPOCH, Instant.EPOCH.plus(1, DAYS), SignatureAlgorithm.SHA256_WITH_ECDSA, BigInteger.ONE)
+ .build();
+ }
+
+ private static KeyStore generateTruststore(X509Certificate certificate) {
+ return KeyStoreBuilder.withType(KeyStoreType.PKCS12)
+ .withCertificateEntry("default", certificate)
+ .build();
+ }
+
+} \ No newline at end of file
diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/model/ServiceModelCache.java b/service-monitor/src/main/java/com/yahoo/vespa/service/model/ServiceModelCache.java
index 7a6f37b2c94..c50f5e6c2d5 100644
--- a/service-monitor/src/main/java/com/yahoo/vespa/service/model/ServiceModelCache.java
+++ b/service-monitor/src/main/java/com/yahoo/vespa/service/model/ServiceModelCache.java
@@ -46,10 +46,12 @@ public class ServiceModelCache implements Supplier<ServiceModel> {
updatePossiblyInProgress = true;
}
- takeSnapshot();
-
- synchronized (updateMonitor) {
- updatePossiblyInProgress = false;
+ try {
+ takeSnapshot();
+ } finally {
+ synchronized (updateMonitor) {
+ updatePossiblyInProgress = false;
+ }
}
}
diff --git a/vespa-athenz/pom.xml b/vespa-athenz/pom.xml
index 27b68fbf360..0f23eaed964 100644
--- a/vespa-athenz/pom.xml
+++ b/vespa-athenz/pom.xml
@@ -114,7 +114,24 @@
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
</dependency>
-
+ <dependency>
+ <groupId>com.amazonaws</groupId>
+ <artifactId>aws-java-sdk-core</artifactId>
+ <exclusions>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-core</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-annotations</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
</dependencies>
<build>
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/aws/AwsCredentialsProvider.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/aws/AwsCredentialsProvider.java
new file mode 100644
index 00000000000..28f028832b4
--- /dev/null
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/client/aws/AwsCredentialsProvider.java
@@ -0,0 +1,79 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.athenz.client.aws;
+
+import com.amazonaws.auth.AWSCredentials;
+import com.amazonaws.auth.AWSCredentialsProvider;
+import com.amazonaws.auth.BasicSessionCredentials;
+import com.yahoo.vespa.athenz.api.AthenzDomain;
+import com.yahoo.vespa.athenz.api.AwsRole;
+import com.yahoo.vespa.athenz.api.AwsTemporaryCredentials;
+import com.yahoo.vespa.athenz.client.zts.DefaultZtsClient;
+import com.yahoo.vespa.athenz.client.zts.ZtsClient;
+import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider;
+
+import javax.net.ssl.SSLContext;
+import java.net.URI;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Objects;
+import java.util.logging.Logger;
+
+/**
+ * Implementation of AWSCredentialsProvider using com.yahoo.vespa.athenz.client.zts.ZtsClient
+ *
+ * @author mortent
+ */
+public class AwsCredentialsProvider implements AWSCredentialsProvider {
+
+ private static final Logger logger = Logger.getLogger(AwsCredentialsProvider.class.getName());
+
+ private final static Duration MIN_EXPIRY = Duration.ofMinutes(5);
+ private final AthenzDomain athenzDomain;
+ private final AwsRole awsRole;
+ private final ZtsClient ztsClient;
+ private volatile AwsTemporaryCredentials credentials;
+
+ public AwsCredentialsProvider(ZtsClient ztsClient, AthenzDomain athenzDomain, AwsRole awsRole) {
+ this.ztsClient = ztsClient;
+ this.athenzDomain = athenzDomain;
+ this.awsRole = awsRole;
+ this.credentials = getAthenzTempCredentials();
+ }
+
+ public AwsCredentialsProvider(URI ztsUrl, ServiceIdentityProvider identityProvider, AthenzDomain athenzDomain, AwsRole awsRole) {
+ this(new DefaultZtsClient(ztsUrl, identityProvider), athenzDomain, awsRole);
+ }
+
+ public AwsCredentialsProvider(URI ztsUrl, SSLContext sslContext, AthenzDomain athenzDomain, AwsRole awsRole) {
+ this(new DefaultZtsClient(ztsUrl, null, sslContext), athenzDomain, awsRole);
+ }
+
+ /**
+ * Requests temporary credentials from ZTS or return cached credentials
+ */
+ private AwsTemporaryCredentials getAthenzTempCredentials() {
+ if(shouldRefresh(credentials)) {
+ this.credentials = ztsClient.getAwsTemporaryCredentials(athenzDomain, awsRole);
+ }
+ return credentials;
+ }
+
+ @Override
+ public AWSCredentials getCredentials() {
+ AwsTemporaryCredentials creds = getAthenzTempCredentials();
+ return new BasicSessionCredentials(creds.accessKeyId(), creds.secretAccessKey(), creds.sessionToken());
+ }
+
+ @Override
+ public void refresh() {
+ getAthenzTempCredentials();
+ }
+
+ /*
+ * Checks credential expiration, returns true if it will expipre in the next MIN_EXPIRY minutes
+ */
+ private static boolean shouldRefresh(AwsTemporaryCredentials credentials) {
+ Instant expiration = credentials.expiration();
+ return Objects.isNull(expiration) || expiration.minus(MIN_EXPIRY).isAfter(Instant.now());
+ }
+}
diff --git a/vespa-hadoop/abi-spec.json b/vespa-hadoop/abi-spec.json
index 5bbac15f0e5..e3f4dcf272a 100644
--- a/vespa-hadoop/abi-spec.json
+++ b/vespa-hadoop/abi-spec.json
@@ -1201,6 +1201,8 @@
"public com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.DimensionSizes dimensionSizes()",
"public java.util.Map cells()",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -1245,6 +1247,8 @@
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)"
@@ -1330,6 +1334,8 @@
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -1432,6 +1438,8 @@
"public double asDouble()",
"public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public abstract com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public abstract com.yahoo.tensor.Tensor remove(java.util.Set)",
"public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)",
"public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])",
"public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)",
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 932513f8a57..c3fe8c5c7ad 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -808,6 +808,8 @@
"public com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.DimensionSizes dimensionSizes()",
"public java.util.Map cells()",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -852,6 +854,8 @@
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)"
@@ -937,6 +941,8 @@
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -1039,6 +1045,8 @@
"public double asDouble()",
"public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public abstract com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public abstract com.yahoo.tensor.Tensor remove(java.util.Set)",
"public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)",
"public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])",
"public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index fb55b2d5014..38d832d01c2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -13,6 +13,7 @@ import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.Set;
+import java.util.function.DoubleBinaryOperator;
/**
* An indexed (dense) tensor backed by a double array.
@@ -190,6 +191,16 @@ public class IndexedTensor implements Tensor {
}
@Override
+ public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells) {
+ throw new IllegalArgumentException("Merge is not supported for indexed tensors");
+ }
+
+ @Override
+ public Tensor remove(Set<TensorAddress> addresses) {
+ throw new IllegalArgumentException("Remove is not supported for indexed tensors");
+ }
+
+ @Override
public int hashCode() { return Arrays.hashCode(values); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index ec3020a1a4e..22ceed22d3e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -5,6 +5,8 @@ import com.google.common.collect.ImmutableMap;
import java.util.Iterator;
import java.util.Map;
+import java.util.Set;
+import java.util.function.DoubleBinaryOperator;
/**
* A sparse implementation of a tensor backed by a Map of cells to values.
@@ -51,6 +53,38 @@ public class MappedTensor implements Tensor {
}
@Override
+ public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) {
+
+ // currently, underlying implementation disallows multiple entries with the same key
+
+ Tensor.Builder builder = Tensor.Builder.of(type());
+ for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) {
+ TensorAddress address = cell.getKey();
+ double value = cell.getValue();
+ builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value);
+ }
+ for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) {
+ if ( ! cells.containsKey(addCell.getKey())) {
+ builder.cell(addCell.getKey(), addCell.getValue());
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
+ public Tensor remove(Set<TensorAddress> addresses) {
+ Tensor.Builder builder = Tensor.Builder.of(type());
+ for (Iterator<Tensor.Cell> i = cellIterator(); i.hasNext(); ) {
+ Tensor.Cell cell = i.next();
+ TensorAddress address = cell.getKey();
+ if ( ! addresses.contains(address)) {
+ builder.cell(address, cell.getValue());
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
public int hashCode() { return cells.hashCode(); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 17e33c58a13..08878edeb83 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -9,6 +9,8 @@ import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.Set;
+import java.util.function.DoubleBinaryOperator;
import java.util.stream.Collectors;
/**
@@ -70,13 +72,17 @@ public class MixedTensor implements Tensor {
return cells.iterator();
}
+ private Iterable<Cell> cellIterable() {
+ return this::cellIterator;
+ }
+
/**
* Returns an iterator over the values of this tensor.
* The iteration order is the same as for cellIterator.
*/
@Override
public Iterator<Double> valueIterator() {
- return new Iterator<Double>() {
+ return new Iterator<>() {
Iterator<Cell> cellIterator = cellIterator();
@Override
public boolean hasNext() {
@@ -108,6 +114,38 @@ public class MixedTensor implements Tensor {
}
@Override
+ public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) {
+ Tensor.Builder builder = Tensor.Builder.of(type());
+ for (Cell cell : cellIterable()) {
+ TensorAddress address = cell.getKey();
+ double value = cell.getValue();
+ builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value);
+ }
+ for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) {
+ builder.cell(addCell.getKey(), addCell.getValue());
+ }
+ return builder.build();
+ }
+
+ @Override
+ public Tensor remove(Set<TensorAddress> addresses) {
+ Tensor.Builder builder = Tensor.Builder.of(type());
+
+ // iterate through all sparse addresses referencing a dense subspace
+ for (Map.Entry<TensorAddress, Long> entry : index.sparseMap.entrySet()) {
+ TensorAddress sparsePartialAddress = entry.getKey();
+ if ( ! addresses.contains(sparsePartialAddress)) { // assumption: addresses only contain the sparse part
+ long offset = entry.getValue();
+ for (int i = 0; i < index.denseSubspaceSize; ++i) {
+ Cell cell = cells.get((int)offset + i);
+ builder.cell(cell.getKey(), cell.getValue());
+ }
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
public int hashCode() { return cells.hashCode(); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 8002990e5c6..eb16801c306 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -25,6 +25,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
@@ -113,6 +114,29 @@ public interface Tensor {
return builder.build();
}
+ /**
+ * Returns a new tensor where existing cells in this tensor have been
+ * modified according to the given operation and cells in the given map.
+ * In contrast to {@link #modify}, previously non-existing cells are added
+ * to this tensor. Only valid for sparse or mixed tensors.
+ *
+ * @param op how to update overlapping cells
+ * @param cells cells to merge with this tensor
+ * @return a new tensor where this tensor is merged with the other
+ */
+ Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells);
+
+ /**
+ * Returns a new tensor where existing cells in this tensor have been
+ * removed according to the given set of addresses. Only valid for sparse
+ * or mixed tensors. For mixed tensors, addresses are assumed to only
+ * contain the sparse dimensions, as the entire dense subspace is removed.
+ *
+ * @param addresses list of addresses to remove
+ * @return a new tensor where cells have been removed
+ */
+ Tensor remove(Set<TensorAddress> addresses);
+
// ----------------- Primitive tensor functions
default Tensor map(DoubleUnaryOperator mapper) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 2c9eefbd130..02d16e6f3e4 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -151,12 +151,106 @@ public class TensorTestCase {
Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:1, {x:0,y:1}:2}"),
Tensor.from("tensor(x[1],y[3])", "{}"),
Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:0,{x:0,y:1}:0}"));
+ assertTensorModify((left, right) -> left * right,
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:6}"));
+ }
+
+ @Test
+ public void testTensorMerge() {
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:2}:3}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:3}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3,{x:0,y:2}:4}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y{})", "{}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:5}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:5}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:2}:3}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:0,{x:0,y:2}:3}")); // notice difference with sparse case - y is dense dimension here with default value 0.0
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:0}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:1}:3,{x:0,y:2}:4}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:4}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y[3])", "{}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:5}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:5}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y[3])", "{}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"));
+ }
+
+ @Test
+ public void testTensorRemove() {
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:1}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:1}"),
+ Tensor.from("tensor(x{},y{})", "{}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y{})", "{}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1}"),
+ Tensor.from("tensor(x{},y{})", "{}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y{})", "{}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2, {x:0,y:1}:3}"),
+ Tensor.from("tensor(x{})", "{{x:0}:1}"), // notice update is without dense dimension
+ Tensor.from("tensor(x{},y[3])", "{}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:1,y:0}:2}"),
+ Tensor.from("tensor(x{})", "{{x:0}:1}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:1,y:0}:2,{x:1,y:1}:0,{x:1,y:2}:0}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y[3])", "{}"),
+ Tensor.from("tensor(x{})", "{{x:0}:1}"),
+ Tensor.from("tensor(x{},y[3])", "{}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{})", "{}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}"));
}
private void assertTensorModify(DoubleBinaryOperator op, Tensor init, Tensor update, Tensor expected) {
assertEquals(expected, init.modify(op, update.cells()));
}
+ private void assertTensorMerge(Tensor init, Tensor update, Tensor expected) {
+ DoubleBinaryOperator op = (left, right) -> right;
+ assertEquals(expected, init.merge(op, update.cells()));
+ }
+
+ private void assertTensorRemove(Tensor init, Tensor update, Tensor expected) {
+ assertEquals(expected, init.remove(update.cells().keySet()));
+ }
+
+
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double sum = 0;
TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),