diff options
204 files changed, 3186 insertions, 1916 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index a0f35dbefe6..6109e5c4aae 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -191,6 +191,8 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement else { // default dimension = ((ReferenceNode)arg0).reference().arguments().expressions().get(0).toString(); } + + // TODO: Determine the type of the weighted set/vector and use that as value type return Optional.of(new TensorType.Builder().mapped(dimension).build()); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index f197e2dfe6d..e12cc60b041 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -453,10 +453,9 @@ public class ConvertedModel { */ // TODO: determine when this is not necessary! private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { - if (after.equals(before)) { - return node; - } - TensorType.Builder typeBuilder = new TensorType.Builder(); + if (after.equals(before)) return node; + + TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType()); for (TensorType.Dimension dimension : before.dimensions()) { if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { typeBuilder.indexed(dimension.name(), 1); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java index 5c96635fd8f..80440ac8eb4 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java @@ -144,7 +144,7 @@ public class RankingExpressionWithTensorTestCase { @Test public void requireThatInvalidTensorTypeSpecThrowsException() throws ParseException { exception.expect(IllegalArgumentException.class); - exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: Failed parsing element 'x' in type spec 'tensor(x)'"); + exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: A tensor type spec must be on the form tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was 'tensor(x)'. Dimension 'x' is on the wrong format. Examples: tensor(x[]), tensor<float>(name{}, x[10])"); RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile my_profile {\n" + " constants {\n" + diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java index 2fcf5809ea5..f53ca15635f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java @@ -39,7 +39,7 @@ public class TensorFieldTestCase { @Test public void requireThatIllegalTensorTypeSpecThrowsException() throws ParseException { exception.expect(IllegalArgumentException.class); - exception.expectMessage("Field type: Illegal tensor type spec: Failed parsing element 'invalid' in type spec 'tensor(invalid)'"); + exception.expectMessage("Field type: Illegal tensor type spec: A tensor type spec must be on the form tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was 'tensor(invalid)'. Dimension 'invalid' is on the wrong format. Examples: tensor(x[]), tensor<float>(name{}, x[10])"); SearchBuilder.createFromString(getSd("field f1 type tensor(invalid) { indexing: attribute }")); } diff --git a/configdefinitions/src/vespa/dispatch.def b/configdefinitions/src/vespa/dispatch.def index 7d5979bcdf1..477a781ebbc 100644 --- a/configdefinitions/src/vespa/dispatch.def +++ b/configdefinitions/src/vespa/dispatch.def @@ -40,6 +40,9 @@ minWaitAfterCoverageFactor double default=0 # Maximum wait time for full coverage after minimum coverage is achieved, factored based on time left at minimum coverage maxWaitAfterCoverageFactor double default=1 +# Number of JRT connection supervisors +numJrtSupervisors int default=8 + # The unique key of a search node node[].key int diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerDB.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerDB.java index 152fc47d807..3e9783cabe9 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerDB.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerDB.java @@ -2,8 +2,8 @@ package com.yahoo.vespa.config.server; import com.yahoo.cloud.config.ConfigserverConfig; -import com.yahoo.config.model.application.provider.Bundle; import com.yahoo.config.application.ConfigDefinitionDir; +import com.yahoo.config.model.application.provider.Bundle; import com.yahoo.io.IOUtils; import com.yahoo.log.LogLevel; import com.yahoo.vespa.defaults.Defaults; diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java index 4fbda42fdc7..877b2acb86f 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java @@ -32,7 +32,6 @@ import java.util.Set; import java.util.stream.Collectors; import static com.yahoo.config.model.api.container.ContainerServiceType.CONTAINER; -import static com.yahoo.config.model.api.container.ContainerServiceType.LOGSERVER_CONTAINER; import static com.yahoo.config.model.api.container.ContainerServiceType.QRSERVER; /** @@ -44,12 +43,12 @@ import static com.yahoo.config.model.api.container.ContainerServiceType.QRSERVER public class ConfigConvergenceChecker extends AbstractComponent { private static final ApplicationId routingApplicationId = ApplicationId.from("hosted-vespa", "routing", "default"); + private static final String nodeAdminName = "node-admin"; private static final String statePath = "/state/v1/"; private static final String configSubPath = "config"; private final static Set<String> serviceTypesToCheck = new HashSet<>(Arrays.asList( CONTAINER.serviceName, QRSERVER.serviceName, - LOGSERVER_CONTAINER.serviceName, "searchnode", "storagenode", "distributor" diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java index 7731e13eac2..3705a0ec145 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.config.server.application; -import com.google.common.collect.ImmutableSet; import com.yahoo.concurrent.ThreadFactoryFactory; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.TenantName; @@ -17,12 +16,13 @@ import com.yahoo.vespa.curator.transaction.CuratorTransaction; import org.apache.curator.framework.CuratorFramework; import org.apache.curator.framework.recipes.cache.PathChildrenCacheEvent; -import java.util.ArrayList; import java.util.List; -import java.util.Optional; +import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.logging.Logger; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** * The applications of a tenant, backed by ZooKeeper. @@ -70,40 +70,39 @@ public class TenantApplications { * @return a list of {@link ApplicationId}s that are active. */ public List<ApplicationId> listApplications() { - try { - List<String> appNodes = curator.framework().getChildren().forPath(applicationsPath.getAbsolute()); - List<ApplicationId> applicationIds = new ArrayList<>(); - for (String appNode : appNodes) { - parseApplication(appNode).ifPresent(applicationIds::add); - } - return applicationIds; - } catch (Exception e) { - throw new RuntimeException(TenantRepository.logPre(tenant)+"Unable to list applications", e); - } + return curator.getChildren(applicationsPath).stream() + .flatMap(this::parseApplication) + .collect(Collectors.toUnmodifiableList()); } - private Optional<ApplicationId> parseApplication(String appNode) { + // TODO jvenstad: Remove after it has run once everywhere. + private Stream<ApplicationId> parseApplication(String appNode) { try { - ApplicationId id = ApplicationId.fromSerializedForm(appNode); - getSessionIdForApplication(id); - return Optional.of(id); - } catch (IllegalArgumentException e) { - log.log(LogLevel.INFO, TenantRepository.logPre(tenant)+"Unable to parse application with id '" + appNode + "', ignoring."); - return Optional.empty(); + return Stream.of(ApplicationId.fromSerializedForm(appNode)); + } catch (IllegalArgumentException __) { + log.log(LogLevel.INFO, TenantRepository.logPre(tenant) + "Unable to parse application id from '" + + appNode + "'; deleting it as it shouldn't be here."); + try { + curator.delete(applicationsPath.append(appNode)); + } + catch (Exception e) { + log.log(LogLevel.WARNING, TenantRepository.logPre(tenant) + "Failed to clean up stray node '" + appNode + "'!", e); + } + return Stream.empty(); } } /** - * Register active application and adds it to the repo. If it already exists it is overwritten. + * Returns a transaction which writes the given session id as the currently active for the given application. * * @param applicationId An {@link ApplicationId} that represents an active application. * @param sessionId Id of the session containing the application package for this id. */ public Transaction createPutApplicationTransaction(ApplicationId applicationId, long sessionId) { if (listApplications().contains(applicationId)) { - return new CuratorTransaction(curator).add(CuratorOperations.setData(applicationsPath.append(applicationId.serializedForm()).getAbsolute(), Utf8.toAsciiBytes(sessionId))); + return new CuratorTransaction(curator).add(CuratorOperations.setData(applicationPath(applicationId).getAbsolute(), Utf8.toAsciiBytes(sessionId))); } else { - return new CuratorTransaction(curator).add(CuratorOperations.create(applicationsPath.append(applicationId.serializedForm()).getAbsolute(), Utf8.toAsciiBytes(sessionId))); + return new CuratorTransaction(curator).add(CuratorOperations.create(applicationPath(applicationId).getAbsolute(), Utf8.toAsciiBytes(sessionId))); } } @@ -115,7 +114,7 @@ public class TenantApplications { * @throws IllegalArgumentException if the application does not exist */ public long getSessionIdForApplication(ApplicationId applicationId) { - String path = applicationsPath.append(applicationId.serializedForm()).getAbsolute(); + String path = applicationPath(applicationId).getAbsolute(); try { return Long.parseLong(Utf8.toString(curator.framework().getData().forPath(path))); } catch (Exception e) { @@ -124,18 +123,22 @@ public class TenantApplications { } /** - * Returns a transaction which deletes this application - * - * @param applicationId an {@link ApplicationId} to delete. + * Returns a transaction which deletes this application. */ public CuratorTransaction deleteApplication(ApplicationId applicationId) { - Path path = applicationsPath.append(applicationId.serializedForm()); - return CuratorTransaction.from(CuratorOperations.delete(path.getAbsolute()), curator); + return CuratorTransaction.from(CuratorOperations.delete(applicationPath(applicationId).getAbsolute()), curator); } /** - * Closes the application repo. Once a repo has been closed, it should not be used again. - */ + * Removes all applications not known to this from the config server state. + */ + public void removeUnusedApplications() { + reloadHandler.removeApplicationsExcept(Set.copyOf(listApplications())); + } + + /** + * Closes the application repo. Once a repo has been closed, it should not be used again. + */ public void close() { directoryCache.close(); } @@ -169,13 +172,8 @@ public class TenantApplications { log.log(LogLevel.DEBUG, TenantRepository.logPre(applicationId) + "Application added: " + applicationId); } - /** - * Removes unused applications - * - */ - public void removeUnusedApplications() { - ImmutableSet<ApplicationId> activeApplications = ImmutableSet.copyOf(listApplications()); - reloadHandler.removeApplicationsExcept(activeApplications); + private Path applicationPath(ApplicationId id) { + return applicationsPath.append(id.serializedForm()); } } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java index 21716730825..082be2583c2 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java @@ -89,9 +89,8 @@ public class Deployment implements com.yahoo.config.provision.Deployment { timeout, clock, true, true, session.getVespaVersion(), isBootstrap); } - public Deployment setIgnoreSessionStaleFailure(boolean ignoreSessionStaleFailure) { + public void setIgnoreSessionStaleFailure(boolean ignoreSessionStaleFailure) { this.ignoreSessionStaleFailure = ignoreSessionStaleFailure; - return this; } /** Prepares this. This does nothing if this is already prepared */ diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/LocalSession.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/LocalSession.java index 0f9f8b72de1..0cdf5ebfe95 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/LocalSession.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/LocalSession.java @@ -83,12 +83,6 @@ public class LocalSession extends Session implements Comparable<LocalSession> { setStatus(Session.Status.PREPARE); } - private Transaction setActive() { - Transaction transaction = createSetStatusTransaction(Status.ACTIVATE); - transaction.add(applicationRepo.createPutApplicationTransaction(zooKeeperClient.readApplicationId(), getSessionId()).operations()); - return transaction; - } - private Transaction createSetStatusTransaction(Status status) { return zooKeeperClient.createWriteStatusTransaction(status); } @@ -99,8 +93,10 @@ public class LocalSession extends Session implements Comparable<LocalSession> { public Transaction createActivateTransaction() { zooKeeperClient.createActiveWaiter(); - superModelGenerationCounter.increment(); - return setActive(); + superModelGenerationCounter.increment(); // TODO jvenstad: I hope this counter isn't used for serious things, as it's updated way ahead of activation. + Transaction transaction = createSetStatusTransaction(Status.ACTIVATE); + transaction.add(applicationRepo.createPutApplicationTransaction(zooKeeperClient.readApplicationId(), getSessionId()).operations()); + return transaction; } public Transaction createDeactivateTransaction() { diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactory.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactory.java index a3dea83d50c..5527d3060f7 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactory.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactory.java @@ -17,8 +17,6 @@ public interface SessionFactory { /** * Creates a new deployment session from an application package. * - * - * * @param applicationDirectory a File pointing to an application. * @param applicationId application id for this new session. * @param timeoutBudget Timeout for creating session and waiting for other servers. @@ -29,10 +27,10 @@ public interface SessionFactory { /** * Creates a new deployment session from an already existing session. * - * @param existingSession The session to use as base + * @param existingSession the session to use as base * @param logger a deploy logger where the deploy log will be written. - * @param internalRedeploy if this session is for a system internal redeploy not an application package change - * @param timeoutBudget Timeout for creating session and waiting for other servers. + * @param internalRedeploy whether this session is for a system internal redeploy — not an application package change + * @param timeoutBudget timeout for creating session and waiting for other servers. * @return a new session */ LocalSession createSessionFromExisting(LocalSession existingSession, DeployLogger logger, diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactoryImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactoryImpl.java index b79ea720aea..90eeb89dc8e 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactoryImpl.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionFactoryImpl.java @@ -194,4 +194,5 @@ public class SessionFactoryImpl implements SessionFactory, LocalSessionLoader { } return nonExistingActiveSession; } + } diff --git a/container-search/src/main/java/com/yahoo/fs4/MapEncoder.java b/container-search/src/main/java/com/yahoo/fs4/MapEncoder.java index 4245f51ace8..565a4c483c3 100644 --- a/container-search/src/main/java/com/yahoo/fs4/MapEncoder.java +++ b/container-search/src/main/java/com/yahoo/fs4/MapEncoder.java @@ -20,8 +20,13 @@ public class MapEncoder { // TODO: Time to refactor - private static final String TYPE_SUFFIX = ".type"; - private static final String TENSOR_TYPE = "tensor"; + private static byte [] getUtf8(Object value) { + if (value instanceof Tensor) { + return TypedBinaryFormat.encode((Tensor)value); + } else { + return Utf8.toBytes(value.toString()); + } + } /** * Encodes a single value as a complete binary map. @@ -39,7 +44,7 @@ public class MapEncoder { utf8 = Utf8.toBytes(key); buffer.putInt(utf8.length); buffer.put(utf8); - utf8 = Utf8.toBytes(value.toString()); + utf8 = getUtf8(value); buffer.putInt(utf8.length); buffer.put(utf8); @@ -64,7 +69,12 @@ public class MapEncoder { utf8 = Utf8.toBytes(key); buffer.putInt(utf8.length); buffer.put(utf8); - utf8 = Utf8.toBytes(property.getValue() != null ? property.getValue().toString() : ""); + Object value = property.getValue(); + if (value == null) { + utf8 = Utf8.toBytes(""); + } else { + utf8 = getUtf8(value); + } buffer.putInt(utf8.length); buffer.put(utf8); } @@ -78,53 +88,21 @@ public class MapEncoder { * * Returns the number of maps encoded - 0 or 1 */ - public static int encodeStringMultiMap(String mapName, Map<String,List<String>> map, ByteBuffer buffer) { - if (map.isEmpty()) return 0; - - byte [] utf8 = Utf8.toBytes(mapName); - buffer.putInt(utf8.length); - buffer.put(utf8); - buffer.putInt(countStringEntries(map)); - for (Map.Entry<String, List<String>> property : map.entrySet()) { - String key = property.getKey(); - for (Object value : property.getValue()) { - utf8 = Utf8.toBytes(key); - buffer.putInt(utf8.length); - buffer.put(utf8); - utf8 = Utf8.toBytes(value.toString()); - buffer.putInt(utf8.length); - buffer.put(utf8); - } - } - - return 1; - } - /** - * Encodes a multi-map as binary. - * Does nothing if the value is null. - * - * Returns the number of maps encoded - 0 or 1 - */ - public static int encodeObjectMultiMap(String mapName, Map<String,List<Object>> map, ByteBuffer buffer) { + public static <T> int encodeMultiMap(String mapName, Map<String,List<T>> map, ByteBuffer buffer) { if (map.isEmpty()) return 0; byte[] utf8 = Utf8.toBytes(mapName); buffer.putInt(utf8.length); buffer.put(utf8); - addTensorTypeInfo(map); - buffer.putInt(countObjectEntries(map)); - for (Map.Entry<String, List<Object>> property : map.entrySet()) { + buffer.putInt(countEntries(map)); + for (Map.Entry<String, List<T>> property : map.entrySet()) { String key = property.getKey(); for (Object value : property.getValue()) { utf8 = Utf8.toBytes(key); buffer.putInt(utf8.length); buffer.put(utf8); - if (value instanceof Tensor) { - utf8 = TypedBinaryFormat.encode((Tensor)value); - } else { - utf8 = Utf8.toBytes(value.toString()); - } + utf8 = getUtf8(value); buffer.putInt(utf8.length); buffer.put(utf8); } @@ -133,32 +111,9 @@ public class MapEncoder { return 1; } - private static void addTensorTypeInfo(Map<String, List<Object>> map) { - Map<String, Tensor> tensorsToTag = new HashMap<>(); - for (Map.Entry<String, List<Object>> entry : map.entrySet()) { - for (Object value : entry.getValue()) { - if (value instanceof Tensor) { - tensorsToTag.put(entry.getKey(), (Tensor)value); - } - } - } - for (Map.Entry<String, Tensor> entry : tensorsToTag.entrySet()) { - // Ensure that we only have a single tensor associated with each key - map.put(entry.getKey(), Arrays.asList(entry.getValue())); - map.put(entry.getKey() + TYPE_SUFFIX, Arrays.asList(TENSOR_TYPE)); - } - } - - private static int countStringEntries(Map<String, List<String>> value) { - int entries = 0; - for (Map.Entry<String, List<String>> property : value.entrySet()) - entries += property.getValue().size(); - return entries; - } - - private static int countObjectEntries(Map<String, List<Object>> value) { + private static <T> int countEntries(Map<String, List<T>> value) { int entries = 0; - for (Map.Entry<String, List<Object>> property : value.entrySet()) + for (Map.Entry<String, List<T>> property : value.entrySet()) entries += property.getValue().size(); return entries; } diff --git a/container-search/src/main/java/com/yahoo/search/Query.java b/container-search/src/main/java/com/yahoo/search/Query.java index b97ee87f650..a5007c9cc33 100644 --- a/container-search/src/main/java/com/yahoo/search/Query.java +++ b/container-search/src/main/java/com/yahoo/search/Query.java @@ -1055,7 +1055,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { // TODO: Push down if (presentation.getHighlight() != null) { - mapCount += MapEncoder.encodeStringMultiMap(Highlight.HIGHLIGHTTERMS, presentation.getHighlight().getHighlightTerms(), buffer); + mapCount += MapEncoder.encodeMultiMap(Highlight.HIGHLIGHTTERMS, presentation.getHighlight().getHighlightTerms(), buffer); } // TODO: Push down diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/Client.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/Client.java index cc37df04a62..e54e2187818 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/Client.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/Client.java @@ -2,8 +2,6 @@ package com.yahoo.search.dispatch.rpc; import com.yahoo.compress.CompressionType; -import com.yahoo.compress.Compressor; -import com.yahoo.prelude.Pong; import com.yahoo.prelude.fastsearch.FastHit; import java.util.List; @@ -15,14 +13,6 @@ import java.util.Optional; * @author bratseth */ interface Client { - - void getDocsums(List<FastHit> hits, NodeConnection node, CompressionType compression, - int uncompressedLength, byte[] compressedSlime, RpcFillInvoker.GetDocsumsResponseReceiver responseReceiver, - double timeoutSeconds); - - void request(String rpcMethod, NodeConnection node, CompressionType compression, int uncompressedLength, - byte[] compressedPayload, ResponseReceiver responseReceiver, double timeoutSeconds); - /** Creates a connection to a particular node in this */ NodeConnection createConnection(String hostname, int port); @@ -91,6 +81,11 @@ interface Client { } interface NodeConnection { + void getDocsums(List<FastHit> hits, CompressionType compression, int uncompressedLength, byte[] compressedSlime, + RpcFillInvoker.GetDocsumsResponseReceiver responseReceiver, double timeoutSeconds); + + void request(String rpcMethod, CompressionType compression, int uncompressedLength, byte[] compressedPayload, + ResponseReceiver responseReceiver, double timeoutSeconds); /** Closes this connection */ void close(); diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcClient.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcClient.java index 2aa01b05955..7e48733106a 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcClient.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcClient.java @@ -29,31 +29,6 @@ class RpcClient implements Client { return new RpcNodeConnection(hostname, port, supervisor); } - @Override - public void getDocsums(List<FastHit> hits, NodeConnection node, CompressionType compression, int uncompressedLength, - byte[] compressedSlime, RpcFillInvoker.GetDocsumsResponseReceiver responseReceiver, double timeoutSeconds) { - Request request = new Request("proton.getDocsums"); - request.parameters().add(new Int8Value(compression.getCode())); - request.parameters().add(new Int32Value(uncompressedLength)); - request.parameters().add(new DataValue(compressedSlime)); - - request.setContext(hits); - RpcNodeConnection rpcNode = ((RpcNodeConnection) node); - rpcNode.invokeAsync(request, timeoutSeconds, new RpcDocsumResponseWaiter(rpcNode, responseReceiver)); - } - - @Override - public void request(String rpcMethod, NodeConnection node, CompressionType compression, int uncompressedLength, byte[] compressedPayload, - ResponseReceiver responseReceiver, double timeoutSeconds) { - Request request = new Request(rpcMethod); - request.parameters().add(new Int8Value(compression.getCode())); - request.parameters().add(new Int32Value(uncompressedLength)); - request.parameters().add(new DataValue(compressedPayload)); - - RpcNodeConnection rpcNode = ((RpcNodeConnection) node); - rpcNode.invokeAsync(request, timeoutSeconds, new RpcProtobufResponseWaiter(rpcNode, responseReceiver)); - } - private static class RpcNodeConnection implements NodeConnection { // Information about the connected node @@ -73,7 +48,30 @@ class RpcClient implements Client { description = "rpc node connection to " + hostname + ":" + port; } - public void invokeAsync(Request req, double timeout, RequestWaiter waiter) { + @Override + public void getDocsums(List<FastHit> hits, CompressionType compression, int uncompressedLength, + byte[] compressedSlime, RpcFillInvoker.GetDocsumsResponseReceiver responseReceiver, double timeoutSeconds) { + Request request = new Request("proton.getDocsums"); + request.parameters().add(new Int8Value(compression.getCode())); + request.parameters().add(new Int32Value(uncompressedLength)); + request.parameters().add(new DataValue(compressedSlime)); + + request.setContext(hits); + invokeAsync(request, timeoutSeconds, new RpcDocsumResponseWaiter(this, responseReceiver)); + } + + @Override + public void request(String rpcMethod, CompressionType compression, int uncompressedLength, byte[] compressedPayload, + ResponseReceiver responseReceiver, double timeoutSeconds) { + Request request = new Request(rpcMethod); + request.parameters().add(new Int8Value(compression.getCode())); + request.parameters().add(new Int32Value(uncompressedLength)); + request.parameters().add(new DataValue(compressedPayload)); + + invokeAsync(request, timeoutSeconds, new RpcProtobufResponseWaiter(this, responseReceiver)); + } + + private void invokeAsync(Request req, double timeout, RequestWaiter waiter) { // TODO: Consider replacing this by a watcher on the target synchronized(this) { // ensure we have exactly 1 valid connection across threads if (target == null || ! target.isValid()) diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcFillInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcFillInvoker.java index 760f7486923..aa72823c809 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcFillInvoker.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcFillInvoker.java @@ -100,7 +100,7 @@ public class RpcFillInvoker extends FillInvoker { /** Send a getDocsums request to a node. Responses will be added to the given receiver. */ private void sendGetDocsumsRequest(int nodeId, List<FastHit> hits, String summaryClass, CompressionType compression, Result result, GetDocsumsResponseReceiver responseReceiver) { - Client.NodeConnection node = resourcePool.nodeConnections().get(nodeId); + Client.NodeConnection node = resourcePool.getConnection(nodeId); if (node == null) { String error = "Could not fill hits from unknown node " + nodeId; responseReceiver.receive(Client.ResponseOrError.fromError(error)); @@ -114,9 +114,8 @@ public class RpcFillInvoker extends FillInvoker { byte[] serializedSlime = BinaryFormat .encode(toSlime(rankProfile, summaryClass, query.getModel().getDocumentDb(), query.getSessionId(), hits)); double timeoutSeconds = ((double) query.getTimeLeft() - 3.0) / 1000.0; - Compressor.Compression compressionResult = resourcePool.compressor().compress(compression, serializedSlime); - resourcePool.client().getDocsums(hits, node, compressionResult.type(), serializedSlime.length, compressionResult.data(), - responseReceiver, timeoutSeconds); + Compressor.Compression compressionResult = resourcePool.compress(query, serializedSlime); + node.getDocsums(hits, compressionResult.type(), serializedSlime.length, compressionResult.data(), responseReceiver, timeoutSeconds); } static private Slime toSlime(String rankProfile, String summaryClass, String docType, SessionId sessionId, List<FastHit> hits) { diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcPing.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcPing.java index f3479e2e4a9..c001b51ef11 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcPing.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcPing.java @@ -52,12 +52,12 @@ public class RpcPing implements Callable<Pong> { } private void sendPing(LinkedBlockingQueue<ResponseOrError<ProtobufResponse>> queue) { - var connection = resourcePool.nodeConnections().get(node.key()); + var connection = resourcePool.getConnection(node.key()); var ping = SearchProtocol.MonitorRequest.newBuilder().build().toByteArray(); double timeoutSeconds = ((double) clusterMonitor.getConfiguration().getRequestTimeout()) / 1000.0; Compressor.Compression compressionResult = resourcePool.compressor().compress(PING_COMPRESSION, ping); - resourcePool.client().request(RPC_METHOD, connection, compressionResult.type(), ping.length, compressionResult.data(), - rsp -> queue.add(rsp), timeoutSeconds); + connection.request(RPC_METHOD, compressionResult.type(), ping.length, compressionResult.data(), rsp -> queue.add(rsp), + timeoutSeconds); } private Pong decodeReply(ProtobufResponse response) throws InvalidProtocolBufferException { diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcProtobufFillInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcProtobufFillInvoker.java index 3ec821beba8..cd4ba191a7d 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcProtobufFillInvoker.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcProtobufFillInvoker.java @@ -66,9 +66,6 @@ public class RpcProtobufFillInvoker extends FillInvoker { protected void sendFillRequest(Result result, String summaryClass) { ListMap<Integer, FastHit> hitsByNode = hitsByNode(result); - CompressionType compression = CompressionType - .valueOf(result.getQuery().properties().getString(RpcResourcePool.dispatchCompression, "LZ4").toUpperCase()); - result.getQuery().trace(false, 5, "Sending ", hitsByNode.size(), " summary fetch requests with jrt/protobuf"); outstandingResponses = hitsByNode.size(); @@ -77,7 +74,7 @@ public class RpcProtobufFillInvoker extends FillInvoker { var builder = ProtobufSerialization.createDocsumRequestBuilder(result.getQuery(), serverId, summaryClass, summaryNeedsQuery); for (Map.Entry<Integer, List<FastHit>> nodeHits : hitsByNode.entrySet()) { var payload = ProtobufSerialization.serializeDocsumRequest(builder, nodeHits.getValue()); - sendDocsumsRequest(nodeHits.getKey(), nodeHits.getValue(), payload, compression, result); + sendDocsumsRequest(nodeHits.getKey(), nodeHits.getValue(), payload, result); } } @@ -117,8 +114,8 @@ public class RpcProtobufFillInvoker extends FillInvoker { } /** Send a docsums request to a node. Responses will be added to the given receiver. */ - private void sendDocsumsRequest(int nodeId, List<FastHit> hits, byte[] payload, CompressionType compression, Result result) { - Client.NodeConnection node = resourcePool.nodeConnections().get(nodeId); + private void sendDocsumsRequest(int nodeId, List<FastHit> hits, byte[] payload, Result result) { + Client.NodeConnection node = resourcePool.getConnection(nodeId); if (node == null) { String error = "Could not fill hits from unknown node " + nodeId; receive(Client.ResponseOrError.fromError(error), hits); @@ -129,9 +126,9 @@ public class RpcProtobufFillInvoker extends FillInvoker { Query query = result.getQuery(); double timeoutSeconds = ((double) query.getTimeLeft() - 3.0) / 1000.0; - Compressor.Compression compressionResult = resourcePool.compressor().compress(compression, payload); - resourcePool.client().request(RPC_METHOD, node, compressionResult.type(), payload.length, compressionResult.data(), - roe -> receive(roe, hits), timeoutSeconds); + Compressor.Compression compressionResult = resourcePool.compress(query, payload); + node.request(RPC_METHOD, compressionResult.type(), payload.length, compressionResult.data(), roe -> receive(roe, hits), + timeoutSeconds); } private void processResponses(Result result, String summaryClass) throws TimeoutException { diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcResourcePool.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcResourcePool.java index 830ba45ef0f..cccf8dd3693 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcResourcePool.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcResourcePool.java @@ -2,12 +2,20 @@ package com.yahoo.search.dispatch.rpc; import com.google.common.collect.ImmutableMap; +import com.yahoo.compress.CompressionType; import com.yahoo.compress.Compressor; +import com.yahoo.compress.Compressor.Compression; import com.yahoo.processing.request.CompoundName; +import com.yahoo.search.Query; import com.yahoo.search.dispatch.FillInvoker; +import com.yahoo.search.dispatch.rpc.Client.NodeConnection; import com.yahoo.vespa.config.search.DispatchConfig; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import java.util.Map; +import java.util.Random; /** * RpcResourcePool constructs {@link FillInvoker} objects that communicate with content nodes over RPC. It also contains @@ -19,43 +27,70 @@ public class RpcResourcePool { /** The compression method which will be used with rpc dispatch. "lz4" (default) and "none" is supported. */ public final static CompoundName dispatchCompression = new CompoundName("dispatch.compression"); - private final Compressor compressor = new Compressor(); - private final Client client; + private final Compressor compressor = new Compressor(CompressionType.LZ4, 5, 0.95, 32); + private final Random random = new Random(); /** Connections to the search nodes this talks to, indexed by node id ("partid") */ - private final ImmutableMap<Integer, Client.NodeConnection> nodeConnections; + private final ImmutableMap<Integer, NodeConnectionPool> nodeConnectionPools; - public RpcResourcePool(Client client, Map<Integer, Client.NodeConnection> nodeConnections) { - this.client = client; - this.nodeConnections = ImmutableMap.copyOf(nodeConnections); + public RpcResourcePool(Map<Integer, Client.NodeConnection> nodeConnections) { + var builder = new ImmutableMap.Builder<Integer, NodeConnectionPool>(); + nodeConnections.forEach((key, connection) -> builder.put(key, new NodeConnectionPool(Collections.singletonList(connection)))); + this.nodeConnectionPools = builder.build(); } public RpcResourcePool(DispatchConfig dispatchConfig) { - this.client = new RpcClient(); + var clients = new ArrayList<RpcClient>(dispatchConfig.numJrtSupervisors()); + for (int i = 0; i < dispatchConfig.numJrtSupervisors(); i++) { + clients.add(new RpcClient()); + } - // Create node rpc connections, indexed by the node distribution key - ImmutableMap.Builder<Integer, Client.NodeConnection> nodeConnectionsBuilder = new ImmutableMap.Builder<>(); - for (DispatchConfig.Node node : dispatchConfig.node()) { - nodeConnectionsBuilder.put(node.key(), client.createConnection(node.host(), node.port())); + // Create node rpc connection pools, indexed by the node distribution key + var builder = new ImmutableMap.Builder<Integer, NodeConnectionPool>(); + for (var node : dispatchConfig.node()) { + var connections = new ArrayList<Client.NodeConnection>(clients.size()); + clients.forEach(client -> connections.add(client.createConnection(node.host(), node.port()))); + builder.put(node.key(), new NodeConnectionPool(connections)); } - this.nodeConnections = nodeConnectionsBuilder.build(); + this.nodeConnectionPools = builder.build(); } public Compressor compressor() { return compressor; } - public Client client() { - return client; + public Compression compress(Query query, byte[] payload) { + CompressionType compression = CompressionType.valueOf(query.properties().getString(dispatchCompression, "LZ4").toUpperCase()); + return compressor.compress(compression, payload); } - public ImmutableMap<Integer, Client.NodeConnection> nodeConnections() { - return nodeConnections; + public NodeConnection getConnection(int nodeId) { + var pool = nodeConnectionPools.get(nodeId); + if (pool == null) { + return null; + } else { + return pool.nextConnection(); + } } public void release() { - for (Client.NodeConnection nodeConnection : nodeConnections.values()) { - nodeConnection.close(); + nodeConnectionPools.values().forEach(NodeConnectionPool::release); + } + + private class NodeConnectionPool { + private final List<Client.NodeConnection> connections; + + NodeConnectionPool(List<Client.NodeConnection> connections) { + this.connections = connections; + } + + Client.NodeConnection nextConnection() { + int slot = random.nextInt(connections.size()); + return connections.get(slot); + } + + void release() { + connections.forEach(Client.NodeConnection::close); } } } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcSearchInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcSearchInvoker.java index d70a7d95b63..75e9b06f445 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcSearchInvoker.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcSearchInvoker.java @@ -46,10 +46,7 @@ public class RpcSearchInvoker extends SearchInvoker implements Client.ResponseRe protected void sendSearchRequest(Query query) throws IOException { this.query = query; - CompressionType compression = CompressionType - .valueOf(query.properties().getString(RpcResourcePool.dispatchCompression, "LZ4").toUpperCase()); - - Client.NodeConnection nodeConnection = resourcePool.nodeConnections().get(node.key()); + Client.NodeConnection nodeConnection = resourcePool.getConnection(node.key()); if (nodeConnection == null) { responses.add(Client.ResponseOrError.fromError("Could not send search to unknown node " + node.key())); responseAvailable(); @@ -59,9 +56,8 @@ public class RpcSearchInvoker extends SearchInvoker implements Client.ResponseRe var payload = ProtobufSerialization.serializeSearchRequest(query, searcher.getServerId()); double timeoutSeconds = ((double) query.getTimeLeft() - 3.0) / 1000.0; - Compressor.Compression compressionResult = resourcePool.compressor().compress(compression, payload); - resourcePool.client().request(RPC_METHOD, nodeConnection, compressionResult.type(), payload.length, compressionResult.data(), this, - timeoutSeconds); + Compressor.Compression compressionResult = resourcePool.compress(query, payload); + nodeConnection.request(RPC_METHOD, compressionResult.type(), payload.length, compressionResult.data(), this, timeoutSeconds); } @Override diff --git a/container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java b/container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java index 4158b0e7476..37a54a82c43 100644 --- a/container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java @@ -76,7 +76,7 @@ public class RankProperties implements Cloneable { /** Encodes this in a binary internal representation and returns the number of property maps encoded (0 or 1) */ public int encode(ByteBuffer buffer, boolean encodeQueryData) { if (encodeQueryData) { - return MapEncoder.encodeObjectMultiMap("rank", properties, buffer); + return MapEncoder.encodeMultiMap("rank", properties, buffer); } else { List<Object> sessionId = properties.get(GetDocSumsPacket.sessionIdKey); diff --git a/container-search/src/test/java/com/yahoo/fs4/test/RankFeaturesTestCase.java b/container-search/src/test/java/com/yahoo/fs4/test/RankFeaturesTestCase.java index 69ca646dbd5..e8c16e572ae 100644 --- a/container-search/src/test/java/com/yahoo/fs4/test/RankFeaturesTestCase.java +++ b/container-search/src/test/java/com/yahoo/fs4/test/RankFeaturesTestCase.java @@ -63,11 +63,10 @@ public class RankFeaturesTestCase { assertEquals(entries.size(), properties.asMap().size()); Map<String, Object> decodedProperties = decode(type, encode(properties)); - assertEquals(entries.size() * 2, properties.asMap().size()); // tensor type info has been added - assertEquals(entries.size() * 2, decodedProperties.size()); + assertEquals(entries.size(), properties.asMap().size()); + assertEquals(entries.size(), decodedProperties.size()); for (Entry entry : entries) { assertEquals(entry.tensor, decodedProperties.get(entry.normalizedKey)); - assertEquals("tensor", decodedProperties.get(entry.normalizedKey + ".type")); } } diff --git a/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/FastSearcherTestCase.java b/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/FastSearcherTestCase.java index f4be2943f5f..04b1d526c67 100644 --- a/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/FastSearcherTestCase.java +++ b/container-search/src/test/java/com/yahoo/prelude/fastsearch/test/FastSearcherTestCase.java @@ -145,7 +145,8 @@ public class FastSearcherTestCase { doFill(fastSearcher, result); ErrorMessage error = result.hits().getError(); assertEquals("Since we don't actually run summary backends we get this error when the Dispatcher is used", - "Error response from rpc node connection to host1:0: Connection error", error.getDetailedMessage()); + "Error response from rpc node connection to hostX:0: Connection error", + error.getDetailedMessage().replaceAll("host[12]", "hostX")); } { // direct.summaries due to no summary features @@ -154,7 +155,8 @@ public class FastSearcherTestCase { doFill(fastSearcher, result); ErrorMessage error = result.hits().getError(); assertEquals("Since we don't actually run summary backends we get this error when the Dispatcher is used", - "Error response from rpc node connection to host1:0: Connection error", error.getDetailedMessage()); + "Error response from rpc node connection to hostX:0: Connection error", + error.getDetailedMessage().replaceAll("host[12]", "hostX")); } } diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/FillTestCase.java b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/FillTestCase.java index e059008acac..6d1f19eeaf2 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/FillTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/FillTestCase.java @@ -22,7 +22,6 @@ import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; - /** * Tests using a dispatcher to fill a result * @@ -38,7 +37,7 @@ public class FillTestCase { nodes.put(0, client.createConnection("host0", 123)); nodes.put(1, client.createConnection("host1", 123)); nodes.put(2, client.createConnection("host2", 123)); - RpcResourcePool rpcResourcePool = new RpcResourcePool(client, nodes); + RpcResourcePool rpcResourcePool = new RpcResourcePool(nodes); RpcInvokerFactory factory = new RpcInvokerFactory(rpcResourcePool, null, true); Query query = new Query(); @@ -75,7 +74,7 @@ public class FillTestCase { nodes.put(0, client.createConnection("host0", 123)); nodes.put(1, client.createConnection("host1", 123)); nodes.put(2, client.createConnection("host2", 123)); - RpcResourcePool rpcResourcePool = new RpcResourcePool(client, nodes); + RpcResourcePool rpcResourcePool = new RpcResourcePool(nodes); RpcInvokerFactory factory = new RpcInvokerFactory(rpcResourcePool, null, true); Query query = new Query(); @@ -90,7 +89,7 @@ public class FillTestCase { client.setDocsumReponse("host2", 1, "summaryClass1", map("field1", "s.2.1", "field2", 1)); client.setDocsumReponse("host1", 2, "summaryClass1", new HashMap<>()); client.setDocsumReponse("host2", 3, "summaryClass1", map("field1", "s.2.3", "field2", 3)); - client.setDocsumReponse("host0", 4, "summaryClass1",new HashMap<>()); + client.setDocsumReponse("host0", 4, "summaryClass1", new HashMap<>()); factory.createFillInvoker(db()).fill(result, "summaryClass1"); @@ -115,7 +114,7 @@ public class FillTestCase { Map<Integer, Client.NodeConnection> nodes = new HashMap<>(); nodes.put(0, client.createConnection("host0", 123)); - RpcResourcePool rpcResourcePool = new RpcResourcePool(client, nodes); + RpcResourcePool rpcResourcePool = new RpcResourcePool(nodes); RpcInvokerFactory factory = new RpcInvokerFactory(rpcResourcePool, null, true); Query query = new Query(); @@ -133,7 +132,7 @@ public class FillTestCase { Map<Integer, Client.NodeConnection> nodes = new HashMap<>(); nodes.put(0, client.createConnection("host0", 123)); - RpcResourcePool rpcResourcePool = new RpcResourcePool(client, nodes); + RpcResourcePool rpcResourcePool = new RpcResourcePool(nodes); RpcInvokerFactory factory = new RpcInvokerFactory(rpcResourcePool, null, true); Query query = new Query(); @@ -141,7 +140,6 @@ public class FillTestCase { result.hits().add(createHit(0, 0)); result.hits().add(createHit(1, 1)); - factory.createFillInvoker(db()).fill(result, "summaryClass1"); assertEquals("Could not fill hits from unknown node 1", result.hits().getError().getDetailedMessage()); @@ -151,8 +149,7 @@ public class FillTestCase { List<DocsumField> fields = new ArrayList<>(); fields.add(DocsumField.create("field1", "string")); fields.add(DocsumField.create("field2", "int64")); - DocsumDefinitionSet docsums = new DocsumDefinitionSet(Collections.singleton(new DocsumDefinition("summaryClass1", - fields))); + DocsumDefinitionSet docsums = new DocsumDefinitionSet(Collections.singleton(new DocsumDefinition("summaryClass1", fields))); return new DocumentDatabase("default", docsums, Collections.emptySet()); } diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/MockClient.java b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/MockClient.java index 687d3e728c0..3cc3257194c 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/MockClient.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/MockClient.java @@ -36,62 +36,6 @@ public class MockClient implements Client { return new MockNodeConnection(hostname, port); } - @Override - public void getDocsums(List<FastHit> hitsContext, NodeConnection node, CompressionType compression, - int uncompressedSize, byte[] compressedSlime, RpcFillInvoker.GetDocsumsResponseReceiver responseReceiver, - double timeoutSeconds) { - if (malfunctioning) { - responseReceiver.receive(ResponseOrError.fromError("Malfunctioning")); - return; - } - - Inspector request = BinaryFormat.decode(compressor.decompress(compressedSlime, compression, uncompressedSize)).get(); - String docsumClass = request.field("class").asString(); - List<Map<String, Object>> docsumsToReturn = new ArrayList<>(); - request.field("gids").traverse((ArrayTraverser)(index, gid) -> { - GlobalId docId = new GlobalId(gid.asData()); - docsumsToReturn.add(docsums.get(new DocsumKey(node.toString(), docId, docsumClass))); - }); - Slime responseSlime = new Slime(); - Cursor root = responseSlime.setObject(); - Cursor docsums = root.setArray("docsums"); - for (Map<String, Object> docsumFields : docsumsToReturn) { - Cursor docsumItem = docsums.addObject(); - Cursor docsum = docsumItem.setObject("docsum"); - for (Map.Entry<String, Object> field : docsumFields.entrySet()) { - if (field.getValue() instanceof Integer) - docsum.setLong(field.getKey(), (Integer)field.getValue()); - else if (field.getValue() instanceof String) - docsum.setString(field.getKey(), (String)field.getValue()); - else - throw new RuntimeException(); - } - } - byte[] slimeBytes = BinaryFormat.encode(responseSlime); - Compressor.Compression compressionResult = compressor.compress(compression, slimeBytes); - GetDocsumsResponse response = new GetDocsumsResponse(compressionResult.type().getCode(), slimeBytes.length, - compressionResult.data(), hitsContext); - responseReceiver.receive(ResponseOrError.fromResponse(response)); - } - - @Override - public void request(String rpcMethod, NodeConnection node, CompressionType compression, int uncompressedLength, byte[] compressedPayload, - ResponseReceiver responseReceiver, double timeoutSeconds) { - if (malfunctioning) { - responseReceiver.receive(ResponseOrError.fromError("Malfunctioning")); - return; - } - - if(searchResult == null) { - responseReceiver.receive(ResponseOrError.fromError("No result defined")); - return; - } - var payload = ProtobufSerialization.serializeResult(searchResult); - var compressionResult = compressor.compress(compression, payload); - var response = new ProtobufResponse(compressionResult.type().getCode(), payload.length, compressionResult.data()); - responseReceiver.receive(ResponseOrError.fromResponse(response)); - } - public void setDocsumReponse(String nodeId, int docId, String docsumClass, Map<String, Object> docsumValues) { docsums.put(new DocsumKey(nodeId, globalIdFrom(docId), docsumClass), docsumValues); } @@ -100,7 +44,7 @@ public class MockClient implements Client { return new GlobalId(new IdIdString("", "test", "", String.valueOf(hitId))); } - private static class MockNodeConnection implements Client.NodeConnection { + private class MockNodeConnection implements Client.NodeConnection { private final String hostname; @@ -109,6 +53,61 @@ public class MockClient implements Client { } @Override + public void getDocsums(List<FastHit> hitsContext, CompressionType compression, int uncompressedSize, byte[] compressedSlime, + RpcFillInvoker.GetDocsumsResponseReceiver responseReceiver, double timeoutSeconds) { + if (malfunctioning) { + responseReceiver.receive(ResponseOrError.fromError("Malfunctioning")); + return; + } + + Inspector request = BinaryFormat.decode(compressor.decompress(compressedSlime, compression, uncompressedSize)).get(); + String docsumClass = request.field("class").asString(); + List<Map<String, Object>> docsumsToReturn = new ArrayList<>(); + request.field("gids").traverse((ArrayTraverser) (index, gid) -> { + GlobalId docId = new GlobalId(gid.asData()); + docsumsToReturn.add(docsums.get(new DocsumKey(toString(), docId, docsumClass))); + }); + Slime responseSlime = new Slime(); + Cursor root = responseSlime.setObject(); + Cursor docsums = root.setArray("docsums"); + for (Map<String, Object> docsumFields : docsumsToReturn) { + Cursor docsumItem = docsums.addObject(); + Cursor docsum = docsumItem.setObject("docsum"); + for (Map.Entry<String, Object> field : docsumFields.entrySet()) { + if (field.getValue() instanceof Integer) + docsum.setLong(field.getKey(), (Integer) field.getValue()); + else if (field.getValue() instanceof String) + docsum.setString(field.getKey(), (String) field.getValue()); + else + throw new RuntimeException(); + } + } + byte[] slimeBytes = BinaryFormat.encode(responseSlime); + Compressor.Compression compressionResult = compressor.compress(compression, slimeBytes); + GetDocsumsResponse response = new GetDocsumsResponse(compressionResult.type().getCode(), slimeBytes.length, + compressionResult.data(), hitsContext); + responseReceiver.receive(ResponseOrError.fromResponse(response)); + } + + @Override + public void request(String rpcMethod, CompressionType compression, int uncompressedLength, byte[] compressedPayload, + ResponseReceiver responseReceiver, double timeoutSeconds) { + if (malfunctioning) { + responseReceiver.receive(ResponseOrError.fromError("Malfunctioning")); + return; + } + + if(searchResult == null) { + responseReceiver.receive(ResponseOrError.fromError("No result defined")); + return; + } + var payload = ProtobufSerialization.serializeResult(searchResult); + var compressionResult = compressor.compress(compression, payload); + var response = new ProtobufResponse(compressionResult.type().getCode(), payload.length, compressionResult.data()); + responseReceiver.receive(ResponseOrError.fromResponse(response)); + } + + @Override public void close() { } @Override diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/RpcSearchInvokerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/RpcSearchInvokerTest.java index 64863b9a8a6..d629bd36bb1 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/RpcSearchInvokerTest.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/RpcSearchInvokerTest.java @@ -34,7 +34,7 @@ public class RpcSearchInvokerTest { var payloadHolder = new AtomicReference<byte[]>(); var lengthHolder = new AtomicInteger(); var mockClient = parameterCollectorClient(compressionTypeHolder, payloadHolder, lengthHolder); - var mockPool = new RpcResourcePool(mockClient, ImmutableMap.of(7, () -> {})); + var mockPool = new RpcResourcePool(ImmutableMap.of(7, mockClient.createConnection("foo", 123))); @SuppressWarnings("resource") var invoker = new RpcSearchInvoker(mockSearcher(), new Node(7, "seven", 77, 1), mockPool); @@ -53,23 +53,26 @@ public class RpcSearchInvokerTest { AtomicInteger lengthHolder) { return new Client() { @Override - public void request(String rpcMethod, NodeConnection node, CompressionType compression, int uncompressedLength, - byte[] compressedPayload, ResponseReceiver responseReceiver, double timeoutSeconds) { - compressionTypeHolder.set(compression); - payloadHolder.set(compressedPayload); - lengthHolder.set(uncompressedLength); - } + public NodeConnection createConnection(String hostname, int port) { + return new NodeConnection() { + @Override + public void getDocsums(List<FastHit> hits, CompressionType compression, int uncompressedLength, byte[] compressedSlime, + GetDocsumsResponseReceiver responseReceiver, double timeoutSeconds) { + fail("Unexpected call"); + } - @Override - public void getDocsums(List<FastHit> hits, NodeConnection node, CompressionType compression, int uncompressedLength, - byte[] compressedSlime, GetDocsumsResponseReceiver responseReceiver, double timeoutSeconds) { - fail("Unexpected call"); - } + @Override + public void request(String rpcMethod, CompressionType compression, int uncompressedLength, byte[] compressedPayload, + ResponseReceiver responseReceiver, double timeoutSeconds) { + compressionTypeHolder.set(compression); + payloadHolder.set(compressedPayload); + lengthHolder.set(uncompressedLength); + } - @Override - public NodeConnection createConnection(String hostname, int port) { - fail("Unexpected call"); - return null; + @Override + public void close() { + } + }; } }; } diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java index 8eaf4cc08cb..c05c3589a30 100644 --- a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java @@ -77,7 +77,7 @@ public class QueryProfileTypeTestCase { type.addField(new FieldDescription("myBoolean", FieldType.fromString("boolean", registry)), registry); type.addField(new FieldDescription("ranking.features.query(myTensor1)", FieldType.fromString("tensor(a{},b{})", registry)), registry); type.addField(new FieldDescription("ranking.features.query(myTensor2)", FieldType.fromString("tensor(x[2],y[2])", registry)), registry); - type.addField(new FieldDescription("ranking.features.query(myTensor3)", FieldType.fromString("tensor(x{})",registry)), registry); + type.addField(new FieldDescription("ranking.features.query(myTensor3)", FieldType.fromString("tensor<float>(x{})",registry)), registry); type.addField(new FieldDescription("myQuery", FieldType.fromString("query", registry)), registry); type.addField(new FieldDescription("myQueryProfile", FieldType.fromString("query-profile", registry),"qp"), registry); } @@ -136,7 +136,7 @@ public class QueryProfileTypeTestCase { assertEquals(true, properties.get("myBoolean")); assertEquals(Tensor.from(tensorString1), properties.get("ranking.features.query(myTensor1)")); assertEquals(Tensor.from("tensor(x[2],y[2])", tensorString2), properties.get("ranking.features.query(myTensor2)")); - assertEquals(Tensor.from("tensor(x{})", tensorString3), properties.get("ranking.features.query(myTensor3)")); + assertEquals(Tensor.from("tensor<float>(x{})", tensorString3), properties.get("ranking.features.query(myTensor3)")); // TODO: assertEquals(..., cprofile.get("myQuery")); assertEquals("value1", properties.get("myQueryProfile.anyString")); assertEquals("value1", properties.get("QP.anyString")); diff --git a/container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java index 3fa7f1ee47e..b5c4166e4de 100644 --- a/container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java @@ -3,6 +3,7 @@ package com.yahoo.search.yql; import static org.junit.Assert.*; +import com.yahoo.search.query.QueryTree; import org.apache.http.client.utils.URIBuilder; import org.junit.After; import org.junit.Before; @@ -29,20 +30,20 @@ public class UserInputTestCase { @Before public void setUp() throws Exception { - searchChain = new Chain<Searcher>(new MinimalQueryInserter()); + searchChain = new Chain<>(new MinimalQueryInserter()); context = Execution.Context.createContextStub(null); execution = new Execution(searchChain, context); } @After - public void tearDown() throws Exception { + public void tearDown() { searchChain = null; context = null; execution = null; } @Test - public final void testSimpleUserInput() { + public void testSimpleUserInput() { { URIBuilder builder = searchUri(); builder.setParameter("yql", @@ -70,7 +71,7 @@ public class UserInputTestCase { } @Test - public final void testRawUserInput() { + public void testRawUserInput() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"grammar\": \"raw\"}]userInput(\"nal le\");"); @@ -79,7 +80,7 @@ public class UserInputTestCase { } @Test - public final void testSegmentedUserInput() { + public void testSegmentedUserInput() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"grammar\": \"segment\"}]userInput(\"nal le\");"); @@ -88,7 +89,7 @@ public class UserInputTestCase { } @Test - public final void testSegmentedNoiseUserInput() { + public void testSegmentedNoiseUserInput() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"grammar\": \"segment\"}]userInput(\"^^^^^^^^\");"); @@ -97,7 +98,7 @@ public class UserInputTestCase { } @Test - public final void testCustomDefaultIndexUserInput() { + public void testCustomDefaultIndexUserInput() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"defaultIndex\": \"glompf\"}]userInput(\"nalle\");"); @@ -106,7 +107,7 @@ public class UserInputTestCase { } @Test - public final void testAnnotatedUserInputStemming() { + public void testAnnotatedUserInputStemming() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"stem\": false}]userInput(\"nalle\");"); @@ -117,7 +118,7 @@ public class UserInputTestCase { } @Test - public final void testAnnotatedUserInputUnrankedTerms() { + public void testAnnotatedUserInputUnrankedTerms() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"ranked\": false}]userInput(\"nalle\");"); @@ -128,7 +129,7 @@ public class UserInputTestCase { } @Test - public final void testAnnotatedUserInputFiltersTerms() { + public void testAnnotatedUserInputFiltersTerms() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"filter\": true}]userInput(\"nalle\");"); @@ -139,7 +140,7 @@ public class UserInputTestCase { } @Test - public final void testAnnotatedUserInputCaseNormalization() { + public void testAnnotatedUserInputCaseNormalization() { URIBuilder builder = searchUri(); builder.setParameter( "yql", @@ -151,7 +152,7 @@ public class UserInputTestCase { } @Test - public final void testAnnotatedUserInputAccentRemoval() { + public void testAnnotatedUserInputAccentRemoval() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"accentDrop\": false}]userInput(\"nalle\");"); @@ -162,7 +163,7 @@ public class UserInputTestCase { } @Test - public final void testAnnotatedUserInputPositionData() { + public void testAnnotatedUserInputPositionData() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"usePositionData\": false}]userInput(\"nalle\");"); @@ -173,7 +174,7 @@ public class UserInputTestCase { } @Test - public final void testQueryPropertiesAsStringArguments() { + public void testQueryPropertiesAsStringArguments() { URIBuilder builder = searchUri(); builder.setParameter("nalle", "bamse"); builder.setParameter("meta", "syntactic"); @@ -197,7 +198,7 @@ public class UserInputTestCase { } @Test - public final void testEmptyUserInput() { + public void testEmptyUserInput() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where userInput(\"\");"); @@ -205,7 +206,7 @@ public class UserInputTestCase { } @Test - public final void testEmptyUserInputFromQueryProperty() { + public void testEmptyUserInputFromQueryProperty() { URIBuilder builder = searchUri(); builder.setParameter("foo", ""); builder.setParameter("yql", @@ -214,7 +215,7 @@ public class UserInputTestCase { } @Test - public final void testEmptyQueryProperty() { + public void testEmptyQueryProperty() { URIBuilder builder = searchUri(); builder.setParameter("foo", ""); builder.setParameter("yql", "select * from sources * where bar contains \"a\" and nonEmpty(foo contains @foo);"); @@ -222,7 +223,7 @@ public class UserInputTestCase { } @Test - public final void testEmptyQueryPropertyInsideExpression() { + public void testEmptyQueryPropertyInsideExpression() { URIBuilder builder = searchUri(); builder.setParameter("foo", ""); builder.setParameter("yql", @@ -231,7 +232,7 @@ public class UserInputTestCase { } @Test - public final void testCompositeWithoutArguments() { + public void testCompositeWithoutArguments() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where bar contains \"a\" and foo contains phrase();"); searchAndAssertNoErrors(builder); @@ -241,7 +242,7 @@ public class UserInputTestCase { } @Test - public final void testAnnoyingPlacementOfNonEmpty() { + public void testAnnoyingPlacementOfNonEmpty() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where bar contains \"a\" and foo contains nonEmpty(phrase(\"a\", \"b\"));"); @@ -254,7 +255,7 @@ public class UserInputTestCase { } @Test - public final void testAllowEmptyUserInput() { + public void testAllowEmptyUserInput() { URIBuilder builder = searchUri(); builder.setParameter("foo", ""); builder.setParameter("yql", "select * from sources * where [{\"allowEmpty\": true}]userInput(@foo);"); @@ -262,7 +263,7 @@ public class UserInputTestCase { } @Test - public final void testAllowEmptyNullFromQueryParsing() { + public void testAllowEmptyNullFromQueryParsing() { URIBuilder builder = searchUri(); builder.setParameter("foo", ",,,,,,,,"); builder.setParameter("yql", "select * from sources * where [{\"allowEmpty\": true}]userInput(@foo);"); @@ -270,7 +271,7 @@ public class UserInputTestCase { } @Test - public final void testDisallowEmptyNullFromQueryParsing() { + public void testDisallowEmptyNullFromQueryParsing() { URIBuilder builder = searchUri(); builder.setParameter("foo", ",,,,,,,,"); builder.setParameter("yql", "select * from sources * where userInput(@foo);"); diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/RoleId.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/RoleId.java new file mode 100644 index 00000000000..199f233835f --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/RoleId.java @@ -0,0 +1,116 @@ +package com.yahoo.vespa.hosted.controller.api.integration.user; + +import com.yahoo.config.provision.ApplicationName; +import com.yahoo.config.provision.TenantName; +import com.yahoo.vespa.hosted.controller.api.role.ApplicationRole; +import com.yahoo.vespa.hosted.controller.api.role.RoleDefinition; +import com.yahoo.vespa.hosted.controller.api.role.Role; +import com.yahoo.vespa.hosted.controller.api.role.Roles; +import com.yahoo.vespa.hosted.controller.api.role.TenantRole; + +import java.util.Objects; + +/** + * An identifier for a role which users identified by {@link UserId}s can be members of, corresponding to a bound {@link Role}. + * + * @author jonmv + */ +public class RoleId { + + private final String value; + + private RoleId(String value) { + if (value.isBlank()) + throw new IllegalArgumentException("Id value must be non-blank."); + this.value = value; + } + + public static RoleId fromRole(TenantRole role) { + return new RoleId(valueOf(role)); + } + + public static RoleId fromRole(ApplicationRole role) { + return new RoleId(valueOf(role)); + } + + public static RoleId fromValue(String value) { + return new RoleId(value); + } + + public String value() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RoleId id = (RoleId) o; + return Objects.equals(value, id.value); + } + + @Override + public int hashCode() { + return Objects.hash(value); + } + + @Override + public String toString() { + return "role '" + value + "'"; + } + + /** Returns the {@link Role} this represent. */ + public Role toRole(Roles roles) { + String[] parts = value.split("\\."); + if (parts.length == 2) switch (parts[1]) { + case "tenantOwner": return roles.tenantOwner(TenantName.from(parts[0])); + case "tenantAdmin": return roles.tenantAdmin(TenantName.from(parts[0])); + case "tenantOperator": return roles.tenantOperator(TenantName.from(parts[0])); + } + if (parts.length == 3) switch (parts[2]) { + case "applicationOwner": return roles.applicationOwner(TenantName.from(parts[0]), ApplicationName.from(parts[1])); + case "applicationAdmin": return roles.applicationAdmin(TenantName.from(parts[0]), ApplicationName.from(parts[1])); + case "applicationOperator": return roles.applicationOperator(TenantName.from(parts[0]), ApplicationName.from(parts[1])); + case "applicationDeveloper": return roles.applicationDeveloper(TenantName.from(parts[0]), ApplicationName.from(parts[1])); + case "applicationReader": return roles.applicationReader(TenantName.from(parts[0]), ApplicationName.from(parts[1])); + } + throw new IllegalArgumentException("Malformed or illegal role value '" + value + "'."); + } + + private static String valueOf(TenantRole role) { + return valueOf(role.tenant()) + "." + valueOf(role.definition()); + } + + private static String valueOf(ApplicationRole role) { + return valueOf(role.tenant()) + "." + valueOf(role.application()) + "." + valueOf(role.definition()); + } + + private static String valueOf(TenantName tenant) { + if (tenant.value().contains(".")) + throw new IllegalArgumentException("Tenant names may not contain '.'."); + + return tenant.value(); + } + + private static String valueOf(ApplicationName application) { + if (application.value().contains(".")) + throw new IllegalArgumentException("Application names may not contain '.'."); + + return application.value(); + } + + private static String valueOf(RoleDefinition role) { + switch (role) { + case tenantOwner: return "tenantOwner"; + case tenantAdmin: return "tenantAdmin"; + case tenantOperator: return "tenantOperator"; + case applicationOwner: return "applicationOwner"; + case applicationAdmin: return "applicationAdmin"; + case applicationOperator: return "applicationOperator"; + case applicationDeveloper: return "applicationDeveloper"; + case applicationReader: return "applicationReader"; + default: throw new IllegalArgumentException("No value defined for role '" + role + "'."); + } + } + +} diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/UserId.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/UserId.java new file mode 100644 index 00000000000..3b138d0ce18 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/UserId.java @@ -0,0 +1,40 @@ +package com.yahoo.vespa.hosted.controller.api.integration.user; + +import java.util.Objects; + +/** + * An identifier for a user. + * + * @author jonmv + */ +public class UserId { + + private final String value; + + public UserId(String value) { + if (value.isBlank()) + throw new IllegalArgumentException("Id must be non-blank."); + this.value = value; + } + + public String value() { return value; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + UserId id = (UserId) o; + return Objects.equals(value, id.value); + } + + @Override + public int hashCode() { + return Objects.hash(value); + } + + @Override + public String toString() { + return "user '" + value + "'"; + } + +} diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/UserManagement.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/UserManagement.java new file mode 100644 index 00000000000..c78dcc76854 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/UserManagement.java @@ -0,0 +1,33 @@ +package com.yahoo.vespa.hosted.controller.api.integration.user; + +import com.yahoo.vespa.hosted.controller.api.role.Role; + +import java.util.Collection; +import java.util.List; + +/** + * Management of {@link UserId}s and {@link RoleId}s, used for access control with {@link Role}s. + * + * @author jonmv + */ +public interface UserManagement { + + /** Creates the given role, or throws if the role already exists. */ + void createRole(RoleId role); + + /** Deletes the given role, or throws if it doesn't already exist. */ + void deleteRole(RoleId role); + + /** Ensures the given users exist, and are part of the given role, or throws if the role does not exist. */ + void addUsers(RoleId role, Collection<UserId> users); + + /** Ensures none of the given users are part of the given role, or throws if the role does not exist. */ + void removeUsers(RoleId role, Collection<UserId> users); + + /** Returns all known roles. */ + List<RoleId> listRoles(); + + /** Returns all users in the given role, or throws if the role does not exist. */ + List<UserId> listUsers(RoleId role); + +} diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/package-info.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/package-info.java new file mode 100644 index 00000000000..ca595bab172 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/user/package-info.java @@ -0,0 +1,5 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +package com.yahoo.vespa.hosted.controller.api.integration.user; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Action.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Action.java index 533c28905a9..2d9ef25d1f5 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Action.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Action.java @@ -1,5 +1,5 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.controller.role; +package com.yahoo.vespa.hosted.controller.api.role; import com.yahoo.jdisc.http.HttpRequest; diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/ApplicationRole.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/ApplicationRole.java new file mode 100644 index 00000000000..cc1e8462580 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/ApplicationRole.java @@ -0,0 +1,29 @@ +package com.yahoo.vespa.hosted.controller.api.role; + +import com.yahoo.config.provision.ApplicationName; +import com.yahoo.config.provision.SystemName; +import com.yahoo.config.provision.TenantName; + +/** + * A {@link Role} with a {@link Context} of a {@link SystemName} a {@link TenantName} and an {@link ApplicationName}. + * + * @author jonmv + */ +public class ApplicationRole extends Role { + + ApplicationRole(RoleDefinition roleDefinition, SystemName system, TenantName tenant, ApplicationName application) { + super(roleDefinition, Context.limitedTo(tenant, application, system)); + } + + /** Returns the {@link TenantName} this is bound to. */ + public TenantName tenant() { return context.tenant().get(); } + + /** Returns the {@link ApplicationName} this is bound to. */ + public ApplicationName application() { return context.application().get(); } + + @Override + public String toString() { + return "role '" + definition() + "' of '" + application() + "' owned by '" + tenant() + "'"; + } + +} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Context.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Context.java index 71452a3ef20..3ba0367a00c 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Context.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Context.java @@ -1,5 +1,5 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.controller.role; +package com.yahoo.vespa.hosted.controller.api.role; import com.yahoo.config.provision.ApplicationName; import com.yahoo.config.provision.SystemName; diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/PathGroup.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java index ef97421119f..edf3f4e8711 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/PathGroup.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java @@ -1,5 +1,5 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.controller.role; +package com.yahoo.vespa.hosted.controller.api.role; import com.yahoo.restapi.Path; @@ -51,10 +51,10 @@ public enum PathGroup { Matcher.application, "/application/v4/tenant/{tenant}/application/{application}/deploying/{*}", "/application/v4/tenant/{tenant}/application/{application}/instance/{*}", - "/application/v4/tenant/{tenant}/application/{application}/environment/prod/region/{region}/instance/{instance}/logs", - "/application/v4/tenant/{tenant}/application/{application}/environment/prod/region/{region}/instance/{instance}/suspended", - "/application/v4/tenant/{tenant}/application/{application}/environment/prod/region/{region}/instance/{instance}/service/{*}", - "/application/v4/tenant/{tenant}/application/{application}/environment/prod/region/{region}/instance/{instance}/global-rotation/{*}"), + "/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/logs", + "/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/suspended", + "/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/service/{*}", + "/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/global-rotation/{*}"), /** Path used to restart application nodes. */ // TODO move to the above when everyone is on new pipeline. applicationRestart(Matcher.tenant, diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Policy.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Policy.java index 6ae68f598f0..970717b14a3 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Policy.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Policy.java @@ -1,5 +1,5 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.controller.role; +package com.yahoo.vespa.hosted.controller.api.role; import com.yahoo.config.provision.ApplicationName; import com.yahoo.config.provision.SystemName; @@ -39,9 +39,14 @@ public enum Policy { .in(SystemName.main, SystemName.cd, SystemName.dev)), // TODO SystemName.all() /** Full access to tenant information and settings. */ - tenantWrite(Privilege.grant(Action.write()) - .on(PathGroup.tenant) - .in(SystemName.all())), + tenantDelete(Privilege.grant(Action.delete) + .on(PathGroup.tenant) + .in(SystemName.all())), + + /** Full access to tenant information and settings. */ + tenantUpdate(Privilege.grant(Action.update) + .on(PathGroup.tenant) + .in(SystemName.all())), /** Read access to tenant information and settings. */ tenantRead(Privilege.grant(Action.read) diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Privilege.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Privilege.java index 4c5ad136f56..a53717b25d6 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Privilege.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Privilege.java @@ -1,5 +1,5 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.controller.role; +package com.yahoo.vespa.hosted.controller.api.role; import com.yahoo.config.provision.SystemName; diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Role.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Role.java new file mode 100644 index 00000000000..86d59b4bbb6 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Role.java @@ -0,0 +1,48 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.api.role; + +import java.net.URI; +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +/** + * A role is a combination of a {@link RoleDefinition} and a {@link Context}, which allows evaluation + * of access control for a given action on a resource. Create using {@link Roles}. + * + * @author jonmv + */ +public abstract class Role { + + private final RoleDefinition roleDefinition; + final Context context; + + Role(RoleDefinition roleDefinition, Context context) { + this.roleDefinition = requireNonNull(roleDefinition); + this.context = requireNonNull(context); + } + + /** Returns the role definition of this bound role. */ + public RoleDefinition definition() { return roleDefinition; } + + /** Returns whether this role is allowed to perform the given action on the given resource. */ + public boolean allows(Action action, URI uri) { + return roleDefinition.policies().stream().anyMatch(policy -> policy.evaluate(action, uri, context)); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Role role = (Role) o; + return roleDefinition == role.roleDefinition && + Objects.equals(context, role.context); + } + + @Override + public int hashCode() { + return Objects.hash(roleDefinition, context); + } + +} + diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Role.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/RoleDefinition.java index d82e4063391..e9c2f7bc643 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/Role.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/RoleDefinition.java @@ -1,21 +1,20 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.controller.role; +package com.yahoo.vespa.hosted.controller.api.role; -import java.net.URI; import java.util.EnumSet; import java.util.Set; /** * This declares all tenant roles known to the controller. A role contains one or more {@link Policy}s which decide - * what actions a member of a role can perform. + * what actions a member of a role can perform, given a {@link Context} for the action. * - * Optionally, some role definition also inherit all policies from a "lower ranking" role. Read the list of roles - * from {@code everyone} to {@code tenantAdmin}, in order, to see what policies these roles. + * Optionally, some role definitions also inherit all policies from a "lower ranking" role. + * + * See {@link Role} for roles bound to a context, where policies can be evaluated. * * @author mpolden * @author jonmv */ -public enum Role { +public enum RoleDefinition { /** Deus ex machina. */ hostedOperator(Policy.operator), @@ -50,45 +49,52 @@ public enum Role { Policy.productionDeployment, Policy.submission), - /** Tenant admin with full access to all tenant resources, including the ability to create new applications. */ - tenantAdmin(applicationAdmin, - Policy.applicationCreate, + /** Application administrator with the additional ability to delete an application. */ + applicationOwner(applicationOperator, + Policy.applicationDelete), + + /** Tenant operator with admin access to all applications under the tenant, as well as the ability to create applications. */ + tenantOperator(applicationAdmin, + Policy.applicationCreate), + + /** Tenant admin with full access to all tenant resources, except deleting the tenant. */ + tenantAdmin(tenantOperator, Policy.applicationDelete, Policy.manager, - Policy.tenantWrite), + Policy.tenantUpdate), + + /** Tenant admin with full access to all tenant resources. */ + tenantOwner(tenantAdmin, + Policy.tenantDelete), /** Build and continuous delivery service. */ // TODO replace with buildService, when everyone is on new pipeline. - tenantPipeline(Policy.submission, + tenantPipeline(everyone, + Policy.submission, Policy.deploymentPipeline, Policy.productionDeployment), /** Tenant administrator with full access to all child resources. */ - athenzTenantAdmin(Policy.tenantWrite, + athenzTenantAdmin(everyone, Policy.tenantRead, + Policy.tenantUpdate, + Policy.tenantDelete, Policy.applicationCreate, Policy.applicationUpdate, Policy.applicationDelete, Policy.applicationOperations, - Policy.developmentDeployment); // TODO remove, as it is covered by applicationAdmin. + Policy.developmentDeployment); private final Set<Policy> policies; - Role(Policy... policies) { + RoleDefinition(Policy... policies) { this.policies = EnumSet.copyOf(Set.of(policies)); } - Role(Role inherited, Policy... policies) { + RoleDefinition(RoleDefinition inherited, Policy... policies) { this.policies = EnumSet.copyOf(Set.of(policies)); this.policies.addAll(inherited.policies); } - /** - * Returns whether this role is allowed to perform action in given role context. Action is allowed if at least one - * policy evaluates to true. - */ - public boolean allows(Action action, URI uri, Context context) { - return policies.stream().anyMatch(policy -> policy.evaluate(action, uri, context)); - } + Set<Policy> policies() { return policies; } } - diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Roles.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Roles.java new file mode 100644 index 00000000000..f6149bf6e88 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Roles.java @@ -0,0 +1,104 @@ +package com.yahoo.vespa.hosted.controller.api.role; + +import com.google.inject.Inject; +import com.yahoo.config.provision.ApplicationName; +import com.yahoo.config.provision.SystemName; +import com.yahoo.config.provision.TenantName; +import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneRegistry; + +import java.util.Objects; + +/** + * Use if you need to create {@link Role}s for its system. + * + * This also defines the relationship between {@link RoleDefinition}s and their required {@link Context}s. + * + * @author jonmv + */ +public class Roles { + + private final SystemName system; + + + @Inject + public Roles(ZoneRegistry zones) { + this(zones.system()); + } + + /** Creates a Roles which can be used to create bound roles for the given system. */ + public Roles(SystemName system) { + this.system = Objects.requireNonNull(system); + } + + + // General roles. + /** Returns a {@link RoleDefinition#hostedOperator} for the current system. */ + public UnboundRole hostedOperator() { + return new UnboundRole(RoleDefinition.hostedOperator, system); + } + + /** Returns a {@link RoleDefinition#everyone} for the current system. */ + public UnboundRole everyone() { + return new UnboundRole(RoleDefinition.everyone, system); + } + + + // Athenz based roles. + /** Returns a {@link RoleDefinition#athenzTenantAdmin} for the current system and given tenant. */ + public TenantRole athenzTenantAdmin(TenantName tenant) { + return new TenantRole(RoleDefinition.athenzTenantAdmin, system, tenant); + } + + /** Returns a {@link RoleDefinition#tenantPipeline} for the current system and given tenant and application. */ + public ApplicationRole tenantPipeline(TenantName tenant, ApplicationName application) { + return new ApplicationRole(RoleDefinition.tenantPipeline, system, tenant, application); + } + + + // Other identity provider based roles. + /** Returns a {@link RoleDefinition#tenantOwner} for the current system and given tenant. */ + public TenantRole tenantOwner(TenantName tenant) { + return new TenantRole(RoleDefinition.tenantOwner, system, tenant); + } + + /** Returns a {@link RoleDefinition#tenantAdmin} for the current system and given tenant. */ + public TenantRole tenantAdmin(TenantName tenant) { + return new TenantRole(RoleDefinition.tenantAdmin, system, tenant); + } + + /** Returns a {@link RoleDefinition#tenantOperator} for the current system and given tenant. */ + public TenantRole tenantOperator(TenantName tenant) { + return new TenantRole(RoleDefinition.tenantOperator, system, tenant); + } + + /** Returns a {@link RoleDefinition#applicationOwner} for the current system and given tenant and application. */ + public ApplicationRole applicationOwner(TenantName tenant, ApplicationName application) { + return new ApplicationRole(RoleDefinition.applicationOwner, system, tenant, application); + } + + /** Returns a {@link RoleDefinition#applicationAdmin} for the current system and given tenant and application. */ + public ApplicationRole applicationAdmin(TenantName tenant, ApplicationName application) { + return new ApplicationRole(RoleDefinition.applicationAdmin, system, tenant, application); + } + + /** Returns a {@link RoleDefinition#applicationOperator} for the current system and given tenant and application. */ + public ApplicationRole applicationOperator(TenantName tenant, ApplicationName application) { + return new ApplicationRole(RoleDefinition.applicationOperator, system, tenant, application); + } + + /** Returns a {@link RoleDefinition#applicationDeveloper} for the current system and given tenant and application. */ + public ApplicationRole applicationDeveloper(TenantName tenant, ApplicationName application) { + return new ApplicationRole(RoleDefinition.applicationDeveloper, system, tenant, application); + } + + /** Returns a {@link RoleDefinition#applicationReader} for the current system and given tenant and application. */ + public ApplicationRole applicationReader(TenantName tenant, ApplicationName application) { + return new ApplicationRole(RoleDefinition.applicationReader, system, tenant, application); + } + + /** Returns a {@link RoleDefinition#buildService} for the current system and given tenant and application. */ + public ApplicationRole buildService(TenantName tenant, ApplicationName application) { + return new ApplicationRole(RoleDefinition.buildService, system, tenant, application); + } + +} diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/SecurityContext.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/SecurityContext.java new file mode 100644 index 00000000000..41444258a68 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/SecurityContext.java @@ -0,0 +1,51 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.api.role; + +import java.security.Principal; +import java.util.Objects; +import java.util.Set; + +import static java.util.Objects.requireNonNull; + +public class SecurityContext { + + public static final String ATTRIBUTE_NAME = SecurityContext.class.getName(); + + private final Principal principal; + private final Set<Role> roles; + + public SecurityContext(Principal principal, Set<Role> roles) { + this.principal = requireNonNull(principal); + this.roles = Set.copyOf(roles); + } + + public Principal principal() { + return principal; + } + + public Set<Role> roles() { + return roles; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SecurityContext that = (SecurityContext) o; + return Objects.equals(principal, that.principal) && + Objects.equals(roles, that.roles); + } + + @Override + public int hashCode() { + return Objects.hash(principal, roles); + } + + @Override + public String toString() { + return "SecurityContext{" + + "principal=" + principal + + ", roles=" + roles + + '}'; + } +} diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/TenantRole.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/TenantRole.java new file mode 100644 index 00000000000..134628ec3a3 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/TenantRole.java @@ -0,0 +1,25 @@ +package com.yahoo.vespa.hosted.controller.api.role; + +import com.yahoo.config.provision.SystemName; +import com.yahoo.config.provision.TenantName; + +/** + * A {@link Role} with a {@link Context} of a {@link SystemName} and a {@link TenantName}. + * + * @author jonmv + */ +public class TenantRole extends Role { + + TenantRole(RoleDefinition roleDefinition, SystemName system, TenantName tenant) { + super(roleDefinition, Context.limitedTo(tenant, system)); + } + + /** Returns the {@link TenantName} this is bound to. */ + public TenantName tenant() { return context.tenant().get(); } + + @Override + public String toString() { + return "role '" + definition() + "' of '" + tenant() + "'"; + } + +} diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/UnboundRole.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/UnboundRole.java new file mode 100644 index 00000000000..eb8319b2012 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/UnboundRole.java @@ -0,0 +1,21 @@ +package com.yahoo.vespa.hosted.controller.api.role; + +import com.yahoo.config.provision.SystemName; + +/** + * A {@link Role} with a {@link Context} of only a {@link SystemName}. + * + * @author jonmv + */ +public class UnboundRole extends Role { + + UnboundRole(RoleDefinition roleDefinition, SystemName system) { + super(roleDefinition, Context.unlimitedIn(system)); + } + + @Override + public String toString() { + return "role '" + definition() + "'"; + } + +} diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/package-info.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/package-info.java new file mode 100644 index 00000000000..a7f70d6fe3c --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/package-info.java @@ -0,0 +1,5 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +package com.yahoo.vespa.hosted.controller.api.role; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/user/RoleIdTest.java b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/user/RoleIdTest.java new file mode 100644 index 00000000000..609646eb672 --- /dev/null +++ b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/integration/user/RoleIdTest.java @@ -0,0 +1,74 @@ +package com.yahoo.vespa.hosted.controller.api.integration.user; + +import com.yahoo.config.provision.ApplicationName; +import com.yahoo.config.provision.SystemName; +import com.yahoo.config.provision.TenantName; +import com.yahoo.vespa.hosted.controller.api.role.ApplicationRole; +import com.yahoo.vespa.hosted.controller.api.role.Roles; +import com.yahoo.vespa.hosted.controller.api.role.TenantRole; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * @author jonmv + */ +public class RoleIdTest { + + @Test + public void testSerialization() { + Roles roles = new Roles(SystemName.main); + + TenantName tenant = TenantName.from("my-tenant"); + for (TenantRole role : List.of(roles.tenantOwner(tenant), + roles.tenantAdmin(tenant), + roles.tenantOperator(tenant))) + assertEquals(role, RoleId.fromRole(role).toRole(roles)); + + ApplicationName application = ApplicationName.from("my-application"); + for (ApplicationRole role : List.of(roles.applicationOwner(tenant, application), + roles.applicationAdmin(tenant, application), + roles.applicationOperator(tenant, application), + roles.applicationDeveloper(tenant, application), + roles.applicationReader(tenant, application))) + assertEquals(role, RoleId.fromRole(role).toRole(roles)); + + assertEquals(roles.tenantOperator(tenant), + RoleId.fromValue("my-tenant.tenantOperator").toRole(roles)); + assertEquals(roles.applicationReader(tenant, application), + RoleId.fromValue("my-tenant.my-application.applicationReader").toRole(roles)); + } + + @Test(expected = IllegalArgumentException.class) + public void illegalTenantName() { + RoleId.fromRole(new Roles(SystemName.main).tenantAdmin(TenantName.from("my.tenant"))); + } + + @Test(expected = IllegalArgumentException.class) + public void illegalApplicationName() { + RoleId.fromRole(new Roles(SystemName.main).applicationOperator(TenantName.from("my-tenant"), ApplicationName.from("my.app"))); + } + + @Test(expected = IllegalArgumentException.class) + public void illegalRole() { + RoleId.fromRole(new Roles(SystemName.main).tenantPipeline(TenantName.from("my-tenant"), ApplicationName.from("my-app"))); + } + + @Test(expected = IllegalArgumentException.class) + public void illegalRoleValue() { + RoleId.fromValue("my-tenant.awesomePerson").toRole(new Roles(SystemName.cd)); + } + + @Test(expected = IllegalArgumentException.class) + public void illegalCombination() { + RoleId.fromValue("my-tenant.my-application.tenantOwner").toRole(new Roles(SystemName.cd)); + } + + @Test(expected = IllegalArgumentException.class) + public void illegalValue() { + RoleId.fromValue("hostedOperator").toRole(new Roles(SystemName.Public)); + } + +} diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/role/PathGroupTest.java b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/role/PathGroupTest.java index b4a3e674594..9d76d055877 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/role/PathGroupTest.java +++ b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/role/PathGroupTest.java @@ -1,9 +1,8 @@ -package com.yahoo.vespa.hosted.controller.role; +package com.yahoo.vespa.hosted.controller.api.role; import org.junit.Test; import java.util.HashSet; -import java.util.LinkedHashSet; import java.util.Set; import java.util.regex.Pattern; diff --git a/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/role/RoleTest.java b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/role/RoleTest.java new file mode 100644 index 00000000000..1badd157b1b --- /dev/null +++ b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/role/RoleTest.java @@ -0,0 +1,54 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.api.role; + +import com.yahoo.config.provision.ApplicationName; +import com.yahoo.config.provision.SystemName; +import com.yahoo.config.provision.TenantName; +import org.junit.Test; + +import java.net.URI; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * @author mpolden + */ +public class RoleTest { + + @Test + public void operator_membership() { + Role role = new Roles(SystemName.main).hostedOperator(); + + // Operator actions + assertFalse(role.allows(Action.create, URI.create("/not/explicitly/defined"))); + assertTrue(role.allows(Action.create, URI.create("/controller/v1/foo"))); + assertTrue(role.allows(Action.update, URI.create("/os/v1/bar"))); + assertTrue(role.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1"))); + assertTrue(role.allows(Action.update, URI.create("/application/v4/tenant/t2/application/a2"))); + } + + @Test + public void tenant_membership() { + Role role = new Roles(SystemName.main).athenzTenantAdmin(TenantName.from("t1")); + assertFalse(role.allows(Action.create, URI.create("/not/explicitly/defined"))); + assertFalse("Deny access to operator API", role.allows(Action.create, URI.create("/controller/v1/foo"))); + assertFalse("Deny access to other tenant and app", role.allows(Action.update, URI.create("/application/v4/tenant/t2/application/a2"))); + assertTrue(role.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1"))); + + Role publicSystem = new Roles(SystemName.vaas).athenzTenantAdmin(TenantName.from("t1")); + assertFalse(publicSystem.allows(Action.read, URI.create("/controller/v1/foo"))); + assertTrue(publicSystem.allows(Action.read, URI.create("/badge/v1/badge"))); + assertTrue(publicSystem.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1"))); + } + + @Test + public void build_service_membership() { + Role role = new Roles(SystemName.vaas).tenantPipeline(TenantName.from("t1"), ApplicationName.from("a1")); + assertFalse(role.allows(Action.create, URI.create("/not/explicitly/defined"))); + assertFalse(role.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1"))); + assertTrue(role.allows(Action.create, URI.create("/application/v4/tenant/t1/application/a1/jobreport"))); + assertFalse("No global read access", role.allows(Action.read, URI.create("/controller/v1/foo"))); + } + +} diff --git a/controller-server/pom.xml b/controller-server/pom.xml index c4cb66de3ec..f22142db727 100644 --- a/controller-server/pom.xml +++ b/controller-server/pom.xml @@ -100,6 +100,13 @@ <scope>provided</scope> </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>flags</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <!-- compile --> <dependency> diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java index 1d685895914..b6993fbc421 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java @@ -13,6 +13,9 @@ import com.yahoo.vespa.athenz.api.AthenzDomain; import com.yahoo.vespa.athenz.api.AthenzPrincipal; import com.yahoo.vespa.athenz.api.AthenzUser; import com.yahoo.vespa.curator.Lock; +import com.yahoo.vespa.flags.BooleanFlag; +import com.yahoo.vespa.flags.FetchVector; +import com.yahoo.vespa.flags.Flags; import com.yahoo.vespa.hosted.controller.api.ActivateResult; import com.yahoo.vespa.hosted.controller.api.application.v4.model.DeployOptions; import com.yahoo.vespa.hosted.controller.api.application.v4.model.EndpointStatus; @@ -42,6 +45,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.application.ApplicationPackage; import com.yahoo.vespa.hosted.controller.application.Deployment; import com.yahoo.vespa.hosted.controller.application.DeploymentMetrics; +import com.yahoo.vespa.hosted.controller.application.GlobalDnsName; import com.yahoo.vespa.hosted.controller.application.JobList; import com.yahoo.vespa.hosted.controller.application.JobStatus; import com.yahoo.vespa.hosted.controller.application.JobStatus.JobRun; @@ -112,6 +116,7 @@ public class ApplicationController { private final ConfigServer configServer; private final RoutingGenerator routingGenerator; private final Clock clock; + private final BooleanFlag redirectLegacyDnsFlag; private final DeploymentTrigger deploymentTrigger; @@ -127,6 +132,7 @@ public class ApplicationController { this.configServer = configServer; this.routingGenerator = routingGenerator; this.clock = clock; + this.redirectLegacyDnsFlag = Flags.REDIRECT_LEGACY_DNS_NAMES.bindTo(controller.flagSource()); this.artifactRepository = artifactRepository; this.applicationStore = applicationStore; @@ -231,14 +237,14 @@ public class ApplicationController { com.yahoo.vespa.hosted.controller.api.identifiers.ApplicationId.validate(id.application().value()); Optional<Tenant> tenant = controller.tenants().get(id.tenant()); - if ( ! tenant.isPresent()) + if (tenant.isEmpty()) throw new IllegalArgumentException("Could not create '" + id + "': This tenant does not exist"); if (get(id).isPresent()) throw new IllegalArgumentException("Could not create '" + id + "': Application already exists"); if (get(dashToUnderscore(id)).isPresent()) // VESPA-1945 throw new IllegalArgumentException("Could not create '" + id + "': Application " + dashToUnderscore(id) + " already exists"); if (tenant.get().type() != Tenant.Type.user) { - if ( ! credentials.isPresent()) + if (credentials.isEmpty()) throw new IllegalArgumentException("Could not create '" + id + "': No credentials provided"); if (id.instance().isDefault()) // Only store the application permits for non-user applications. @@ -269,7 +275,7 @@ public class ApplicationController { throw new IllegalArgumentException("'" + applicationId + "' is a tester application!"); Tenant tenant = controller.tenants().require(applicationId.tenant()); - if (tenant.type() == Tenant.Type.user && ! get(applicationId).isPresent()) + if (tenant.type() == Tenant.Type.user && get(applicationId).isEmpty()) createApplication(applicationId, Optional.empty()); try (Lock deploymentLock = lockForDeployment(applicationId, zone)) { @@ -292,15 +298,15 @@ public class ApplicationController { () -> new IllegalArgumentException("Application package must be given when deploying to " + zone)); platformVersion = options.vespaVersion.map(Version::new).orElse(applicationPackage.deploymentSpec().majorVersion() .flatMap(this::lastCompatibleVersion) - .orElse(controller.systemVersion())); + .orElseGet(controller::systemVersion)); } else { JobType jobType = JobType.from(controller.system(), zone) .orElseThrow(() -> new IllegalArgumentException("No job is known for " + zone + ".")); Optional<JobStatus> job = Optional.ofNullable(application.get().deploymentJobs().jobStatus().get(jobType)); - if ( ! job.isPresent() - || ! job.get().lastTriggered().isPresent() - || job.get().lastCompleted().isPresent() && job.get().lastCompleted().get().at().isAfter(job.get().lastTriggered().get().at())) + if ( job.isEmpty() + || job.get().lastTriggered().isEmpty() + || job.get().lastCompleted().isPresent() && job.get().lastCompleted().get().at().isAfter(job.get().lastTriggered().get().at())) return unexpectedDeployment(applicationId, zone); JobRun triggered = job.get().lastTriggered().get(); platformVersion = preferOldestVersion ? triggered.sourcePlatform().orElse(triggered.platform()) @@ -382,7 +388,7 @@ public class ApplicationController { application = withoutUnreferencedDeploymentJobs(application); store(application); - return(application); + return application; } /** Deploy a system application to given zone */ @@ -432,20 +438,28 @@ public class ApplicationController { application = application.with(rotation.id()); store(application); // store assigned rotation even if deployment fails - registerRotationInDns(rotation, application.get().globalDnsName(controller.system()).get().dnsName()); - registerRotationInDns(rotation, application.get().globalDnsName(controller.system()).get().secureDnsName()); - registerRotationInDns(rotation, application.get().globalDnsName(controller.system()).get().oathDnsName()); + GlobalDnsName dnsName = application.get().globalDnsName(controller.system()) + .orElseThrow(() -> new IllegalStateException("Expected rotation to be assigned")); + boolean redirectLegacyDns = redirectLegacyDnsFlag.with(FetchVector.Dimension.APPLICATION_ID, application.get().id().serializedForm()) + .value(); + registerCname(dnsName.oathDnsName(), rotation.name()); + if (redirectLegacyDns) { + registerCname(dnsName.dnsName(), dnsName.oathDnsName()); + registerCname(dnsName.secureDnsName(), dnsName.oathDnsName()); + } else { + registerCname(dnsName.dnsName(), rotation.name()); + registerCname(dnsName.secureDnsName(), rotation.name()); + } } } return application; } - private ActivateResult unexpectedDeployment(ApplicationId applicationId, ZoneId zone) { - + private ActivateResult unexpectedDeployment(ApplicationId application, ZoneId zone) { Log logEntry = new Log(); logEntry.level = "WARNING"; logEntry.time = clock.instant().toEpochMilli(); - logEntry.message = "Ignoring deployment of " + require(applicationId) + " to " + zone + + logEntry.message = "Ignoring deployment of application '" + application + "' to " + zone + " as a deployment is not currently expected"; PrepareResponse prepareResponse = new PrepareResponse(); prepareResponse.log = Collections.singletonList(logEntry); @@ -495,24 +509,22 @@ public class ApplicationController { options.deployCurrentVersion); } - /** Register a DNS name for rotation */ - private void registerRotationInDns(Rotation rotation, String dnsName) { + /** Register a CNAME record in DNS */ + private void registerCname(String name, String targetName) { try { - - RecordData rotationName = RecordData.fqdn(rotation.name()); - List<Record> records = nameService.findRecords(Record.Type.CNAME, RecordName.from(dnsName)); + RecordData data = RecordData.fqdn(targetName); + List<Record> records = nameService.findRecords(Record.Type.CNAME, RecordName.from(name)); records.forEach(record -> { - // Ensure that the existing record points to the correct rotation - if ( ! record.data().equals(rotationName)) { - nameService.updateRecord(record, rotationName); - log.info("Updated mapping for record '" + record + "': '" + dnsName - + "' -> '" + rotation.name() + "'"); + // Ensure that the existing record points to the correct target + if ( ! record.data().equals(data)) { + log.info("Updating mapping for record '" + record + "': '" + name + + "' -> '" + data.asString() + "'"); + nameService.updateRecord(record, data); } }); - if (records.isEmpty()) { - Record record = nameService.createCname(RecordName.from(dnsName), rotationName); - log.info("Registered mapping as record '" + record + "'"); + Record record = nameService.createCname(RecordName.from(name), data); + log.info("Registered mapping as record '" + record + "'"); } } catch (RuntimeException e) { log.log(Level.WARNING, "Failed to register CNAME", e); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java index 6e59c384485..7754286ba9e 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Controller.java @@ -9,8 +9,7 @@ import com.yahoo.config.provision.CloudName; import com.yahoo.config.provision.HostName; import com.yahoo.config.provision.SystemName; import com.yahoo.vespa.curator.Lock; -import com.yahoo.vespa.hosted.controller.api.identifiers.Property; -import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; +import com.yahoo.vespa.flags.FlagSource; import com.yahoo.vespa.hosted.controller.api.integration.BuildService; import com.yahoo.vespa.hosted.controller.api.integration.MetricsService; import com.yahoo.vespa.hosted.controller.api.integration.RunDataStore; @@ -76,6 +75,7 @@ public class Controller extends AbstractComponent { private final Chef chef; private final Mailer mailer; private final AuditLogger auditLogger; + private final FlagSource flagSource; /** * Creates a controller @@ -88,11 +88,11 @@ public class Controller extends AbstractComponent { NameService nameService, RoutingGenerator routingGenerator, Chef chef, AccessControl accessControl, ArtifactRepository artifactRepository, ApplicationStore applicationStore, TesterCloud testerCloud, - BuildService buildService, RunDataStore runDataStore, Mailer mailer) { + BuildService buildService, RunDataStore runDataStore, Mailer mailer, FlagSource flagSource) { this(curator, rotationsConfig, gitHub, zoneRegistry, configServer, metricsService, nameService, routingGenerator, chef, Clock.systemUTC(), accessControl, artifactRepository, applicationStore, testerCloud, - buildService, runDataStore, com.yahoo.net.HostName::getLocalhost, mailer); + buildService, runDataStore, com.yahoo.net.HostName::getLocalhost, mailer, flagSource); } public Controller(CuratorDb curator, RotationsConfig rotationsConfig, GitHub gitHub, @@ -102,7 +102,7 @@ public class Controller extends AbstractComponent { AccessControl accessControl, ArtifactRepository artifactRepository, ApplicationStore applicationStore, TesterCloud testerCloud, BuildService buildService, RunDataStore runDataStore, Supplier<String> hostnameSupplier, - Mailer mailer) { + Mailer mailer, FlagSource flagSource) { this.hostnameSupplier = Objects.requireNonNull(hostnameSupplier, "HostnameSupplier cannot be null"); this.curator = Objects.requireNonNull(curator, "Curator cannot be null"); @@ -113,6 +113,7 @@ public class Controller extends AbstractComponent { this.chef = Objects.requireNonNull(chef, "Chef cannot be null"); this.clock = Objects.requireNonNull(clock, "Clock cannot be null"); this.mailer = Objects.requireNonNull(mailer, "Mailer cannot be null"); + this.flagSource = Objects.requireNonNull(flagSource, "FlagSource cannot be null"); jobController = new JobController(this, runDataStore, Objects.requireNonNull(testerCloud)); applicationController = new ApplicationController(this, curator, accessControl, @@ -123,7 +124,8 @@ public class Controller extends AbstractComponent { Objects.requireNonNull(applicationStore, "ApplicationStore cannot be null"), Objects.requireNonNull(routingGenerator, "RoutingGenerator cannot be null"), Objects.requireNonNull(buildService, "BuildService cannot be null"), - clock); + clock + ); tenantController = new TenantController(this, curator, accessControl); auditLogger = new AuditLogger(curator, clock); @@ -146,6 +148,11 @@ public class Controller extends AbstractComponent { return mailer; } + /** Provides access to the feature flags of this */ + public FlagSource flagSource() { + return flagSource; + } + public Clock clock() { return clock; } public ZoneRegistry zoneRegistry() { return zoneRegistry; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java index e8b3e334631..d1a6e39a1dd 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java @@ -61,7 +61,7 @@ public class TenantController { .collect(Collectors.toList()); } - /** Returns the lsit of tenants accessible to the given user. */ + /** Returns the list of tenants accessible to the given user. */ public List<Tenant> asList(Credentials credentials) { return accessControl.accessibleTenants(asList(), credentials); } @@ -147,10 +147,11 @@ public class TenantController { } private void requireNonExistent(TenantName name) { - if (get(name).isPresent() || + if ( "hosted-vespa".equals(name.value()) + || get(name).isPresent() // Underscores are allowed in existing tenant names, but tenants with - and _ cannot co-exist. E.g. // my-tenant cannot be created if my_tenant exists. - get(name.value().replace('-', '_')).isPresent()) { + || get(name.value().replace('-', '_')).isPresent()) { throw new IllegalArgumentException("Tenant '" + name + "' already exists"); } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/GlobalDnsName.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/GlobalDnsName.java index 0254bf2fd38..ae638beed5c 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/GlobalDnsName.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/GlobalDnsName.java @@ -9,7 +9,7 @@ import java.net.URI; import java.util.Objects; /** - * Represents an application's global rotation. + * Represents names for an application's global rotation. * * @author mpolden */ diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainer.java index 2fe6af02480..7693f224b56 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainer.java @@ -72,7 +72,7 @@ public class DnsMaintainer extends Maintainer { private Optional<Rotation> rotationToCheckOf(Collection<Rotation> rotations) { if (rotations.isEmpty()) return Optional.empty(); List<Rotation> rotationList = new ArrayList<>(rotations); - int index = rotationIndex.getAndUpdate((i)-> { + int index = rotationIndex.getAndUpdate((i) -> { if (i < rotationList.size() - 1) { return ++i; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java index fb27247c48a..8f58827d33a 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java @@ -367,6 +367,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { } private void toSlime(Cursor object, Application application, HttpRequest request) { + object.setString("tenant", application.id().tenant().value()); object.setString("application", application.id().application().value()); object.setString("instance", application.id().instance().value()); object.setString("deployments", withPath("/application/v4" + @@ -456,21 +457,22 @@ public class ApplicationApiHandler extends LoggingRequestHandler { for (Deployment deployment : deployments) { Cursor deploymentObject = instancesArray.addObject(); - deploymentObject.setString("environment", deployment.zone().environment().value()); - deploymentObject.setString("region", deployment.zone().region().value()); - deploymentObject.setString("instance", application.id().instance().value()); // pointless if (application.rotation().isPresent() && deployment.zone().environment() == Environment.prod) { toSlime(application.rotationStatus(deployment), deploymentObject); } if (recurseOverDeployments(request)) // List full deployment information when recursive. toSlime(deploymentObject, new DeploymentId(application.id(), deployment.zone()), deployment, request); - else + else { + deploymentObject.setString("environment", deployment.zone().environment().value()); + deploymentObject.setString("region", deployment.zone().region().value()); + deploymentObject.setString("instance", application.id().instance().value()); // pointless deploymentObject.setString("url", withPath(request.getUri().getPath() + "/environment/" + deployment.zone().environment().value() + "/region/" + deployment.zone().region().value() + "/instance/" + application.id().instance().value(), request.getUri()).toString()); + } } // Metrics @@ -516,6 +518,12 @@ public class ApplicationApiHandler extends LoggingRequestHandler { private void toSlime(Cursor response, DeploymentId deploymentId, Deployment deployment, HttpRequest request) { + response.setString("tenant", deploymentId.applicationId().tenant().value()); + response.setString("application", deploymentId.applicationId().application().value()); + response.setString("instance", deploymentId.applicationId().instance().value()); // pointless + response.setString("environment", deploymentId.zoneId().environment().value()); + response.setString("region", deploymentId.zoneId().region().value()); + Cursor serviceUrlArray = response.setArray("serviceUrls"); controller.applications().getDeploymentEndpoints(deploymentId) .ifPresent(endpoints -> endpoints.forEach(endpoint -> serviceUrlArray.addString(endpoint.toString()))); @@ -1154,6 +1162,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { } private void toSlime(Application application, Cursor object, HttpRequest request) { + object.setString("tenant", application.id().tenant().value()); object.setString("application", application.id().application().value()); object.setString("instance", application.id().instance().value()); object.setString("url", withPath("/application/v4/tenant/" + application.id().tenant().value() + diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleFilter.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleFilter.java new file mode 100644 index 00000000000..f25deb11a52 --- /dev/null +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleFilter.java @@ -0,0 +1,115 @@ +package com.yahoo.vespa.hosted.controller.restapi.filter; + +import com.google.inject.Inject; +import com.yahoo.config.provision.ApplicationName; +import com.yahoo.config.provision.TenantName; +import com.yahoo.jdisc.Response; +import com.yahoo.jdisc.http.filter.DiscFilterRequest; +import com.yahoo.jdisc.http.filter.security.cors.CorsFilterConfig; +import com.yahoo.jdisc.http.filter.security.cors.CorsRequestFilterBase; +import com.yahoo.log.LogLevel; +import com.yahoo.restapi.Path; +import com.yahoo.vespa.athenz.api.AthenzDomain; +import com.yahoo.vespa.athenz.api.AthenzIdentity; +import com.yahoo.vespa.athenz.api.AthenzPrincipal; +import com.yahoo.vespa.athenz.client.zms.ZmsClientException; +import com.yahoo.vespa.hosted.controller.Controller; +import com.yahoo.vespa.hosted.controller.TenantController; +import com.yahoo.vespa.hosted.controller.api.role.Role; +import com.yahoo.vespa.hosted.controller.api.role.Roles; +import com.yahoo.vespa.hosted.controller.athenz.ApplicationAction; +import com.yahoo.vespa.hosted.controller.athenz.impl.AthenzFacade; +import com.yahoo.vespa.hosted.controller.api.role.SecurityContext; +import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; +import com.yahoo.vespa.hosted.controller.tenant.Tenant; +import com.yahoo.vespa.hosted.controller.tenant.UserTenant; +import com.yahoo.yolean.Exceptions; + +import java.net.URI; +import java.util.Optional; +import java.util.Set; +import java.util.logging.Logger; + +import static com.yahoo.vespa.hosted.controller.athenz.HostedAthenzIdentities.SCREWDRIVER_DOMAIN; + +/** + * Enriches the request principal with roles from Athenz. + * + * @author jonmv + */ +public class AthenzRoleFilter extends CorsRequestFilterBase { // TODO: No need for this super anyway. + + private static final Logger logger = Logger.getLogger(AthenzRoleFilter.class.getName()); + + private final AthenzFacade athenz; + private final TenantController tenants; + private final Roles roles; + + @Inject + public AthenzRoleFilter(CorsFilterConfig config, AthenzFacade athenz, Controller controller) { + super(Set.copyOf(config.allowedUrls())); + this.athenz = athenz; + this.tenants = controller.tenants(); + this.roles = new Roles(controller.system()); + } + + @Override + protected Optional<ErrorResponse> filterRequest(DiscFilterRequest request) { + try { + AthenzPrincipal athenzPrincipal = (AthenzPrincipal) request.getUserPrincipal(); + request.setAttribute(SecurityContext.ATTRIBUTE_NAME, new SecurityContext(athenzPrincipal, + roles(athenzPrincipal, request.getUri()))); + return Optional.empty(); + } + catch (Exception e) { + logger.log(LogLevel.DEBUG, () -> "Exception mapping Athenz principal to roles: " + Exceptions.toMessageString(e)); + return Optional.of(new ErrorResponse(Response.Status.UNAUTHORIZED, "Access denied")); + } + } + + Set<Role> roles(AthenzPrincipal principal, URI uri) { + Path path = new Path(uri); + + path.matches("/application/v4/tenant/{tenant}/{*}"); + Optional<Tenant> tenant = Optional.ofNullable(path.get("tenant")).map(TenantName::from).flatMap(tenants::get); + + path.matches("/application/v4/tenant/{tenant}/application/{application}/{*}"); + Optional<ApplicationName> application = Optional.ofNullable(path.get("application")).map(ApplicationName::from); + + AthenzIdentity identity = principal.getIdentity(); + + if (athenz.hasHostedOperatorAccess(identity)) + return Set.of(roles.hostedOperator()); + + if (tenant.isPresent() && isTenantAdmin(identity, tenant.get())) + return Set.of(roles.athenzTenantAdmin(tenant.get().name())); + + if (identity.getDomain().equals(SCREWDRIVER_DOMAIN) && application.isPresent() && tenant.isPresent()) + // NOTE: Only fine-grained deploy authorization for Athenz tenants + if ( tenant.get().type() != Tenant.Type.athenz + || hasDeployerAccess(identity, ((AthenzTenant) tenant.get()).domain(), application.get())) + return Set.of(roles.tenantPipeline(tenant.get().name(), application.get())); + + return Set.of(roles.everyone()); + } + + private boolean isTenantAdmin(AthenzIdentity identity, Tenant tenant) { + switch (tenant.type()) { + case athenz: return athenz.hasTenantAdminAccess(identity, ((AthenzTenant) tenant).domain()); + case user: return ((UserTenant) tenant).is(identity.getName()) || athenz.hasHostedOperatorAccess(identity); + default: throw new IllegalArgumentException("Unexpected tenant type '" + tenant.type() + "'."); + } + } + + private boolean hasDeployerAccess(AthenzIdentity identity, AthenzDomain tenantDomain, ApplicationName application) { + try { + return athenz.hasApplicationAccess(identity, + ApplicationAction.deploy, + tenantDomain, + application); + } catch (ZmsClientException e) { + throw new RuntimeException("Failed to authorize operation: (" + e.getMessage() + ")", e); + } + } + +} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleResolver.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleResolver.java deleted file mode 100644 index a1dfdbeb245..00000000000 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleResolver.java +++ /dev/null @@ -1,118 +0,0 @@ -package com.yahoo.vespa.hosted.controller.restapi.filter; - -import com.google.inject.Inject; -import com.yahoo.config.provision.ApplicationName; -import com.yahoo.config.provision.SystemName; -import com.yahoo.config.provision.TenantName; -import com.yahoo.restapi.Path; -import com.yahoo.vespa.athenz.api.AthenzDomain; -import com.yahoo.vespa.athenz.api.AthenzIdentity; -import com.yahoo.vespa.athenz.api.AthenzPrincipal; -import com.yahoo.vespa.athenz.api.AthenzUser; -import com.yahoo.vespa.athenz.client.zms.ZmsClientException; -import com.yahoo.vespa.hosted.controller.Controller; -import com.yahoo.vespa.hosted.controller.TenantController; -import com.yahoo.vespa.hosted.controller.athenz.ApplicationAction; -import com.yahoo.vespa.hosted.controller.athenz.impl.AthenzFacade; -import com.yahoo.vespa.hosted.controller.role.Role; -import com.yahoo.vespa.hosted.controller.role.RoleMembership; -import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; -import com.yahoo.vespa.hosted.controller.tenant.Tenant; -import com.yahoo.vespa.hosted.controller.tenant.UserTenant; - -import javax.ws.rs.InternalServerErrorException; -import java.security.Principal; -import java.util.Optional; - -import static com.yahoo.vespa.hosted.controller.athenz.HostedAthenzIdentities.SCREWDRIVER_DOMAIN; - -/** - * Translates Athenz principals to role memberships for use in access control. - * - * @author tokle - * @author mpolden - */ -public class AthenzRoleResolver implements RoleMembership.Resolver { - - private final AthenzFacade athenz; - private final TenantController tenants; - private final SystemName system; - - @Inject - public AthenzRoleResolver(AthenzFacade athenz, Controller controller) { - this.athenz = athenz; - this.tenants = controller.tenants(); - this.system = controller.system(); - } - - private boolean isTenantAdmin(AthenzIdentity identity, Tenant tenant) { - if (tenant instanceof AthenzTenant) { - return athenz.hasTenantAdminAccess(identity, ((AthenzTenant) tenant).domain()); - } else if (tenant instanceof UserTenant) { - if (!(identity instanceof AthenzUser)) { - return false; - } - AthenzUser user = (AthenzUser) identity; - return ((UserTenant) tenant).is(user.getName()) || isHostedOperator(identity); - } - throw new InternalServerErrorException("Unknown tenant type: " + tenant.getClass().getSimpleName()); - } - - private boolean hasDeployerAccess(AthenzIdentity identity, AthenzDomain tenantDomain, ApplicationName application) { - try { - return athenz.hasApplicationAccess(identity, - ApplicationAction.deploy, - tenantDomain, - application); - } catch (ZmsClientException e) { - throw new InternalServerErrorException("Failed to authorize operation: (" + e.getMessage() + ")", e); - } - } - - private boolean isHostedOperator(AthenzIdentity identity) { - return athenz.hasHostedOperatorAccess(identity); - } - - @Override - public RoleMembership membership(Principal principal, Optional<String> uriPath) { - if ( ! (principal instanceof AthenzPrincipal)) - throw new IllegalStateException("Expected an AthenzPrincipal to be set on the request."); - - @SuppressWarnings("deprecation") // TODO: Use URI when refactoring this. - Path path = new Path(uriPath.orElseThrow(() -> new IllegalArgumentException("This resolver needs the request path."))); - - path.matches("/application/v4/tenant/{tenant}/{*}"); - Optional<Tenant> tenant = Optional.ofNullable(path.get("tenant")).map(TenantName::from).flatMap(tenants::get); - - path.matches("/application/v4/tenant/{tenant}/application/{application}/{*}"); - Optional<ApplicationName> application = Optional.ofNullable(path.get("application")).map(ApplicationName::from); - - AthenzIdentity identity = ((AthenzPrincipal) principal).getIdentity(); - - RoleMembership.Builder memberships = RoleMembership.in(system); - if (isHostedOperator(identity)) { - memberships.add(Role.hostedOperator); - } - if (tenant.isPresent() && isTenantAdmin(identity, tenant.get())) { - memberships.add(Role.athenzTenantAdmin).limitedTo(tenant.get().name()); - } - AthenzDomain principalDomain = identity.getDomain(); - if (principalDomain.equals(SCREWDRIVER_DOMAIN)) { - if (application.isPresent() && tenant.isPresent()) { - // NOTE: Only fine-grained deploy authorization for Athenz tenants - if (tenant.get() instanceof AthenzTenant) { - AthenzDomain tenantDomain = ((AthenzTenant) tenant.get()).domain(); - if (hasDeployerAccess(identity, tenantDomain, application.get())) { - memberships.add(Role.tenantPipeline).limitedTo(tenant.get().name(), application.get()); - } - } - else { - memberships.add(Role.tenantPipeline).limitedTo(tenant.get().name(), application.get()); - } - } - } - memberships.add(Role.everyone); - return memberships.build(); - } - -} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java index dfcc5f732f8..39736d709d0 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.hosted.controller.restapi.filter; import com.google.inject.Inject; +import com.yahoo.config.provision.SystemName; import com.yahoo.jdisc.Response; import com.yahoo.jdisc.http.HttpRequest; import com.yahoo.jdisc.http.filter.DiscFilterRequest; @@ -9,12 +10,11 @@ import com.yahoo.jdisc.http.filter.security.cors.CorsFilterConfig; import com.yahoo.jdisc.http.filter.security.cors.CorsRequestFilterBase; import com.yahoo.log.LogLevel; import com.yahoo.vespa.hosted.controller.Controller; -import com.yahoo.vespa.hosted.controller.role.Action; -import com.yahoo.vespa.hosted.controller.role.RoleMembership; -import com.yahoo.yolean.chain.After; -import com.yahoo.yolean.chain.Provides; +import com.yahoo.vespa.hosted.controller.api.role.Action; +import com.yahoo.vespa.hosted.controller.api.role.Role; +import com.yahoo.vespa.hosted.controller.api.role.Roles; +import com.yahoo.vespa.hosted.controller.api.role.SecurityContext; -import javax.ws.rs.WebApplicationException; import java.security.Principal; import java.util.Optional; import java.util.Set; @@ -25,45 +25,41 @@ import java.util.logging.Logger; * * @author bjorncs */ -@After("com.yahoo.vespa.hosted.controller.athenz.filter.UserAuthWithAthenzPrincipalFilter") -@Provides("ControllerAuthorizationFilter") public class ControllerAuthorizationFilter extends CorsRequestFilterBase { private static final Logger log = Logger.getLogger(ControllerAuthorizationFilter.class.getName()); - private final RoleMembership.Resolver roleResolver; - private final Controller controller; + private final Roles roles; @Inject - public ControllerAuthorizationFilter(RoleMembership.Resolver roleResolver, - Controller controller, + public ControllerAuthorizationFilter(Controller controller, CorsFilterConfig corsConfig) { - this(roleResolver, controller, Set.copyOf(corsConfig.allowedUrls())); + this(controller.system(), Set.copyOf(corsConfig.allowedUrls())); } - ControllerAuthorizationFilter(RoleMembership.Resolver roleResolver, - Controller controller, + ControllerAuthorizationFilter(SystemName system, Set<String> allowedUrls) { super(allowedUrls); - this.roleResolver = roleResolver; - this.controller = controller; + this.roles = new Roles(system); } @Override public Optional<ErrorResponse> filterRequest(DiscFilterRequest request) { try { Principal principal = request.getUserPrincipal(); - if (principal == null) + Optional<SecurityContext> securityContext = Optional.ofNullable((SecurityContext)request.getAttribute(SecurityContext.ATTRIBUTE_NAME)); + + if (securityContext.isEmpty()) return Optional.of(new ErrorResponse(Response.Status.FORBIDDEN, "Access denied")); Action action = Action.from(HttpRequest.Method.valueOf(request.getMethod())); - // Avoid expensive lookups when request is always legal. - if (RoleMembership.everyoneIn(controller.system()).allows(action, request.getUri())) + // Avoid expensive look-ups when request is always legal. + if (roles.everyone().allows(action, request.getUri())) return Optional.empty(); - RoleMembership roles = this.roleResolver.membership(principal, Optional.of(request.getRequestURI())); - if (roles.allows(action, request.getUri())) + Set<Role> roles = securityContext.get().roles(); + if (roles.stream().anyMatch(role -> role.allows(action, request.getUri()))) return Optional.empty(); } catch (Exception e) { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/user/UserApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/user/UserApiHandler.java index 18b124778d5..067e6095b4d 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/user/UserApiHandler.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/user/UserApiHandler.java @@ -85,4 +85,6 @@ public class UserApiHandler extends LoggingRequestHandler { return response; } + + } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/RoleMembership.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/RoleMembership.java deleted file mode 100644 index 09e66528913..00000000000 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/role/RoleMembership.java +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.controller.role; - -import com.yahoo.config.provision.ApplicationName; -import com.yahoo.config.provision.SystemName; -import com.yahoo.config.provision.TenantName; - -import java.net.URI; -import java.security.Principal; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * A list of roles and their associated contexts. This defines the role membership of a tenant, and in which contexts - * (see {@link Context}) those roles apply. - * - * @author mpolden - * @author jonmv - */ -public class RoleMembership { - - private final Map<Role, Set<Context>> roles; - - private RoleMembership(Map<Role, Set<Context>> roles) { - this.roles = roles.entrySet().stream() - .collect(Collectors.toUnmodifiableMap(entry -> entry.getKey(), - entry -> Set.copyOf(entry.getValue()))); - } - - public static RoleMembership everyoneIn(SystemName system) { - return in(system).add(Role.everyone).build(); - } - - public static Builder in(SystemName system) { return new BuilderWithRole(system); } - - /** Returns whether any role in this allows action to take place in path */ - public boolean allows(Action action, URI uri) { - return roles.entrySet().stream().anyMatch(kv -> { - Role role = kv.getKey(); - Set<Context> contexts = kv.getValue(); - return contexts.stream().anyMatch(context -> role.allows(action, uri, context)); - }); - } - - /** Returns the set of contexts for which the given role is valid. */ - public Set<Context> contextsFor(Role role) { - return roles.getOrDefault(role, Collections.emptySet()); - } - - @Override - public String toString() { - return "roles " + roles; - } - - /** - * A role resolver. Identity providers can implement this to translate their internal representation of role - * membership to a {@link RoleMembership}. - */ - public interface Resolver { - RoleMembership membership(Principal user, Optional<String> path); // TODO get rid of path. - } - - public interface Builder { - - BuilderWithRole add(Role role); - - RoleMembership build(); - - } - - public static class BuilderWithRole implements Builder { - - private final SystemName system; - private final Map<Role, Set<Context>> roles; - - private Role current; - - private BuilderWithRole(SystemName system) { - this.system = Objects.requireNonNull(system); - this.roles = new HashMap<>(); - } - - @Override - public BuilderWithRole add(Role role) { - consumeCurrent(Context.unlimitedIn(system)); - current = role; - return this; - } - - public Builder limitedTo(TenantName tenant) { - consumeCurrent(Context.limitedTo(tenant, system)); - return this; - } - - public Builder limitedTo(TenantName tenant, ApplicationName application) { - consumeCurrent(Context.limitedTo(tenant, application, system)); - return this; - } - - @Override - public RoleMembership build() { - consumeCurrent(Context.unlimitedIn(system)); - return new RoleMembership(roles); - } - - private void consumeCurrent(Context context) { - if (current != null) { - roles.putIfAbsent(current, new HashSet<>()); - roles.get(current).add(context); - } - current = null; - } - - } - -} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/rotation/RotationRepository.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/rotation/RotationRepository.java index a22e5259919..b3953c47c01 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/rotation/RotationRepository.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/rotation/RotationRepository.java @@ -63,7 +63,7 @@ public class RotationRepository { if (application.rotation().isPresent()) { return allRotations.get(application.rotation().get()); } - if (!application.deploymentSpec().globalServiceId().isPresent()) { + if (application.deploymentSpec().globalServiceId().isEmpty()) { throw new IllegalArgumentException("global-service-id is not set in deployment spec"); } long productionZones = application.deploymentSpec().zones().stream() diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java index 67d7a02a915..d1806fb5747 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java @@ -2,10 +2,17 @@ package com.yahoo.vespa.hosted.controller.security; import com.google.inject.Inject; import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ApplicationName; import com.yahoo.config.provision.TenantName; import com.yahoo.vespa.hosted.controller.Application; import com.yahoo.vespa.hosted.controller.api.integration.organization.BillingInfo; import com.yahoo.vespa.hosted.controller.api.integration.organization.Marketplace; +import com.yahoo.vespa.hosted.controller.api.integration.user.RoleId; +import com.yahoo.vespa.hosted.controller.api.integration.user.UserId; +import com.yahoo.vespa.hosted.controller.api.integration.user.UserManagement; +import com.yahoo.vespa.hosted.controller.api.role.ApplicationRole; +import com.yahoo.vespa.hosted.controller.api.role.Roles; +import com.yahoo.vespa.hosted.controller.api.role.TenantRole; import com.yahoo.vespa.hosted.controller.tenant.CloudTenant; import com.yahoo.vespa.hosted.controller.tenant.Tenant; @@ -19,21 +26,28 @@ import java.util.List; public class CloudAccessControl implements AccessControl { private final Marketplace marketplace; + private final UserManagement userManagement; + private final Roles roles; @Inject - public CloudAccessControl(Marketplace marketplace) { + public CloudAccessControl(Marketplace marketplace, UserManagement userManagement, Roles roles) { this.marketplace = marketplace; + this.userManagement = userManagement; + this.roles = roles; } @Override public CloudTenant createTenant(TenantSpec tenantSpec, Credentials credentials, List<Tenant> existing) { CloudTenantSpec spec = (CloudTenantSpec) tenantSpec; + CloudTenant tenant = new CloudTenant(spec.tenant(), new BillingInfo("customer", "Vespa")); + // CloudTenant tenant new CloudTenant(spec.tenant(), marketplace.resolveCustomer(spec.getRegistrationToken())); + // TODO Enable the above when things work. - // Do things ... + RoleId ownerRole = RoleId.fromRole(roles.tenantOwner(spec.tenant())); + userManagement.createRole(ownerRole); + userManagement.addUsers(ownerRole, List.of(new UserId(credentials.user().getName()))); - // return new CloudTenant(spec.tenant(), marketplace.resolveCustomer(spec.getRegistrationToken())); - // TODO Enable the above when things work. - return new CloudTenant(spec.tenant(), new BillingInfo("customer", "Vespa")); + return tenant; } @Override @@ -43,31 +57,48 @@ public class CloudAccessControl implements AccessControl { @Override public void deleteTenant(TenantName tenant, Credentials credentials) { - // Probably terminate customer subscription? - // Delete tenant group - + tenantRoles(tenant).stream() + .map(RoleId::fromRole) + .filter(userManagement.listRoles()::contains) + .forEach(userManagement::deleteRole); } @Override public void createApplication(ApplicationId application, Credentials credentials) { - - // Create application group? - + RoleId ownerRole = RoleId.fromRole(roles.applicationOwner(application.tenant(), application.application())); + userManagement.createRole(ownerRole); + userManagement.addUsers(ownerRole, List.of(new UserId(credentials.user().getName()))); } @Override public void deleteApplication(ApplicationId id, Credentials credentials) { - - // Delete application group? - + applicationRoles(id.tenant(), id.application()).stream() + .map(RoleId::fromRole) + .filter(userManagement.listRoles()::contains) + .forEach(userManagement::deleteRole); } @Override public List<Tenant> accessibleTenants(List<Tenant> tenants, Credentials credentials) { - // Get credential things (token with roles or something) and check what it's good for. + // TODO: Get credential things (token with roles or something) and check what it's good for. + // TODO ... or ignore this here, and compute it somewhere else. return Collections.emptyList(); } + private List<TenantRole> tenantRoles(TenantName tenant) { + return List.of(roles.tenantOperator(tenant), + roles.tenantAdmin(tenant), + roles.tenantOwner(tenant)); + } + + private List<ApplicationRole> applicationRoles(TenantName tenant, ApplicationName application) { + return List.of(roles.applicationReader(tenant, application), + roles.applicationDeveloper(tenant, application), + roles.applicationOperator(tenant, application), + roles.applicationAdmin(tenant, application), + roles.applicationOwner(tenant, application)); + } + } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControlRequests.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControlRequests.java index 631d4debe88..ea931616211 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControlRequests.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControlRequests.java @@ -20,7 +20,7 @@ public class CloudAccessControlRequests implements AccessControlRequests { @Override public Credentials credentials(TenantName tenant, Inspector requestObject, HttpRequest request) { - // TODO Pick out JWT data and return a specialised credentials thing. + // TODO Include roles, if this is to be used for displaying accessible data. return new Credentials(request.getUserPrincipal()); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/TenantSpec.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/TenantSpec.java index 2f7dd656678..358088e9b08 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/TenantSpec.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/TenantSpec.java @@ -1,6 +1,7 @@ package com.yahoo.vespa.hosted.controller.security; import com.yahoo.config.provision.TenantName; +import com.yahoo.vespa.hosted.controller.tenant.Tenant; import static java.util.Objects.requireNonNull; @@ -14,7 +15,7 @@ public abstract class TenantSpec { private final TenantName tenant; protected TenantSpec(TenantName tenant) { - this.tenant = requireNonNull(tenant); + this.tenant = Tenant.requireName(requireNonNull(tenant)); } /** The name of the tenant. */ diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Tenant.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Tenant.java index 19b7229515b..e0c750dec80 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Tenant.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Tenant.java @@ -49,7 +49,7 @@ public abstract class Tenant { return Objects.hash(name); } - static TenantName requireName(TenantName name) { + public static TenantName requireName(TenantName name) { if ( ! name.value().matches("^(?=.{1,20}$)[a-z](-?[a-z0-9]+)*$")) { throw new IllegalArgumentException("New tenant or application names must start with a letter, may " + "contain no more than 20 characters, and may only contain lowercase " + diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java index 1f00d99350a..bc42b672da4 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java @@ -12,13 +12,14 @@ import com.yahoo.config.provision.InstanceName; import com.yahoo.config.provision.RegionName; import com.yahoo.config.provision.SystemName; import com.yahoo.config.provision.TenantName; +import com.yahoo.vespa.flags.Flags; +import com.yahoo.vespa.flags.InMemoryFlagSource; import com.yahoo.vespa.hosted.controller.api.application.v4.model.DeployOptions; import com.yahoo.vespa.hosted.controller.api.application.v4.model.EndpointStatus; import com.yahoo.vespa.hosted.controller.api.identifiers.DeploymentId; import com.yahoo.vespa.hosted.controller.api.integration.deployment.ApplicationVersion; import com.yahoo.vespa.hosted.controller.api.integration.deployment.SourceRevision; import com.yahoo.vespa.hosted.controller.api.integration.dns.Record; -import com.yahoo.vespa.hosted.controller.api.integration.dns.RecordName; import com.yahoo.vespa.hosted.controller.api.integration.routing.RoutingEndpoint; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.application.ApplicationPackage; @@ -281,40 +282,60 @@ public class ControllerTest { .region("us-central-1") // Two deployments should result in each DNS alias being registered once .build(); - Function<String, Optional<Record>> findCname = (name) -> tester.controllerTester().nameService() - .findRecords(Record.Type.CNAME, - RecordName.from(name)) - .stream() - .findFirst(); - tester.deployCompletely(application, applicationPackage); assertEquals(3, tester.controllerTester().nameService().records().size()); - Optional<Record> record = findCname.apply("app1--tenant1.global.vespa.yahooapis.com"); + Optional<Record> record = tester.controllerTester().findCname("app1--tenant1.global.vespa.yahooapis.com"); assertTrue(record.isPresent()); assertEquals("app1--tenant1.global.vespa.yahooapis.com", record.get().name().asString()); assertEquals("rotation-fqdn-01.", record.get().data().asString()); - record = findCname.apply("app1--tenant1.global.vespa.oath.cloud"); + record = tester.controllerTester().findCname("app1--tenant1.global.vespa.oath.cloud"); assertTrue(record.isPresent()); assertEquals("app1--tenant1.global.vespa.oath.cloud", record.get().name().asString()); assertEquals("rotation-fqdn-01.", record.get().data().asString()); - record = findCname.apply("app1.tenant1.global.vespa.yahooapis.com"); + record = tester.controllerTester().findCname("app1.tenant1.global.vespa.yahooapis.com"); assertTrue(record.isPresent()); assertEquals("app1.tenant1.global.vespa.yahooapis.com", record.get().name().asString()); assertEquals("rotation-fqdn-01.", record.get().data().asString()); } @Test - public void testUpdatesExistingDnsAlias() { + public void testRedirectLegacyDnsNames() { // TODO: Remove together with Flags.REDIRECT_LEGACY_DNS_NAMES DeploymentTester tester = new DeploymentTester(); + Application application = tester.createApplication("app1", "tenant1", 1, 1L); + ApplicationPackage applicationPackage = new ApplicationPackageBuilder() + .environment(Environment.prod) + .globalServiceId("foo") + .region("us-west-1") + .region("us-central-1") + .build(); + + ((InMemoryFlagSource) tester.controller().flagSource()).withBooleanFlag(Flags.REDIRECT_LEGACY_DNS_NAMES.id(), true); + + tester.deployCompletely(application, applicationPackage); + assertEquals(3, tester.controllerTester().nameService().records().size()); + + Optional<Record> record = tester.controllerTester().findCname("app1--tenant1.global.vespa.yahooapis.com"); + assertTrue(record.isPresent()); + assertEquals("app1--tenant1.global.vespa.yahooapis.com", record.get().name().asString()); + assertEquals("app1--tenant1.global.vespa.oath.cloud.", record.get().data().asString()); - Function<String, Optional<Record>> findCname = (name) -> tester.controllerTester().nameService() - .findRecords(Record.Type.CNAME, - RecordName.from(name)) - .stream() - .findFirst(); + record = tester.controllerTester().findCname("app1--tenant1.global.vespa.oath.cloud"); + assertTrue(record.isPresent()); + assertEquals("app1--tenant1.global.vespa.oath.cloud", record.get().name().asString()); + assertEquals("rotation-fqdn-01.", record.get().data().asString()); + + record = tester.controllerTester().findCname("app1.tenant1.global.vespa.yahooapis.com"); + assertTrue(record.isPresent()); + assertEquals("app1.tenant1.global.vespa.yahooapis.com", record.get().name().asString()); + assertEquals("app1--tenant1.global.vespa.oath.cloud.", record.get().data().asString()); + } + + @Test + public void testUpdatesExistingDnsAlias() { + DeploymentTester tester = new DeploymentTester(); // Application 1 is deployed and deleted { @@ -329,12 +350,12 @@ public class ControllerTest { tester.deployCompletely(app1, applicationPackage); assertEquals(3, tester.controllerTester().nameService().records().size()); - Optional<Record> record = findCname.apply("app1--tenant1.global.vespa.yahooapis.com"); + Optional<Record> record = tester.controllerTester().findCname("app1--tenant1.global.vespa.yahooapis.com"); assertTrue(record.isPresent()); assertEquals("app1--tenant1.global.vespa.yahooapis.com", record.get().name().asString()); assertEquals("rotation-fqdn-01.", record.get().data().asString()); - record = findCname.apply("app1.tenant1.global.vespa.yahooapis.com"); + record = tester.controllerTester().findCname("app1.tenant1.global.vespa.yahooapis.com"); assertTrue(record.isPresent()); assertEquals("app1.tenant1.global.vespa.yahooapis.com", record.get().name().asString()); assertEquals("rotation-fqdn-01.", record.get().data().asString()); @@ -356,13 +377,13 @@ public class ControllerTest { } // Records remain - record = findCname.apply("app1--tenant1.global.vespa.yahooapis.com"); + record = tester.controllerTester().findCname("app1--tenant1.global.vespa.yahooapis.com"); assertTrue(record.isPresent()); - record = findCname.apply("app1--tenant1.global.vespa.oath.cloud"); + record = tester.controllerTester().findCname("app1--tenant1.global.vespa.oath.cloud"); assertTrue(record.isPresent()); - record = findCname.apply("app1.tenant1.global.vespa.yahooapis.com"); + record = tester.controllerTester().findCname("app1.tenant1.global.vespa.yahooapis.com"); assertTrue(record.isPresent()); } @@ -378,17 +399,17 @@ public class ControllerTest { tester.deployCompletely(app2, applicationPackage); assertEquals(6, tester.controllerTester().nameService().records().size()); - Optional<Record> record = findCname.apply("app2--tenant2.global.vespa.yahooapis.com"); + Optional<Record> record = tester.controllerTester().findCname("app2--tenant2.global.vespa.yahooapis.com"); assertTrue(record.isPresent()); assertEquals("app2--tenant2.global.vespa.yahooapis.com", record.get().name().asString()); assertEquals("rotation-fqdn-01.", record.get().data().asString()); - record = findCname.apply("app2--tenant2.global.vespa.oath.cloud"); + record = tester.controllerTester().findCname("app2--tenant2.global.vespa.oath.cloud"); assertTrue(record.isPresent()); assertEquals("app2--tenant2.global.vespa.oath.cloud", record.get().name().asString()); assertEquals("rotation-fqdn-01.", record.get().data().asString()); - record = findCname.apply("app2.tenant2.global.vespa.yahooapis.com"); + record = tester.controllerTester().findCname("app2.tenant2.global.vespa.yahooapis.com"); assertTrue(record.isPresent()); assertEquals("app2.tenant2.global.vespa.yahooapis.com", record.get().name().asString()); assertEquals("rotation-fqdn-01.", record.get().data().asString()); @@ -411,15 +432,15 @@ public class ControllerTest { // Existing DNS records are updated to point to the newly assigned rotation assertEquals(6, tester.controllerTester().nameService().records().size()); - Optional<Record> record = findCname.apply("app1--tenant1.global.vespa.yahooapis.com"); + Optional<Record> record = tester.controllerTester().findCname("app1--tenant1.global.vespa.yahooapis.com"); assertTrue(record.isPresent()); assertEquals("rotation-fqdn-02.", record.get().data().asString()); - record = findCname.apply("app1--tenant1.global.vespa.oath.cloud"); + record = tester.controllerTester().findCname("app1--tenant1.global.vespa.oath.cloud"); assertTrue(record.isPresent()); assertEquals("rotation-fqdn-02.", record.get().data().asString()); - record = findCname.apply("app1.tenant1.global.vespa.yahooapis.com"); + record = tester.controllerTester().findCname("app1.tenant1.global.vespa.yahooapis.com"); assertTrue(record.isPresent()); assertEquals("rotation-fqdn-02.", record.get().data().asString()); } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java index d7845e4bfa1..c18e9c46f07 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java @@ -13,6 +13,7 @@ import com.yahoo.vespa.athenz.api.AthenzUser; import com.yahoo.vespa.athenz.api.OktaAccessToken; import com.yahoo.vespa.curator.Lock; import com.yahoo.vespa.curator.mock.MockCurator; +import com.yahoo.vespa.flags.InMemoryFlagSource; import com.yahoo.vespa.hosted.controller.api.application.v4.model.DeployOptions; import com.yahoo.vespa.hosted.controller.api.identifiers.Property; import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; @@ -22,14 +23,15 @@ import com.yahoo.vespa.hosted.controller.api.integration.deployment.ApplicationS import com.yahoo.vespa.hosted.controller.api.integration.deployment.ArtifactRepository; import com.yahoo.vespa.hosted.controller.api.integration.deployment.JobType; import com.yahoo.vespa.hosted.controller.api.integration.dns.MemoryNameService; -import com.yahoo.vespa.hosted.controller.api.integration.entity.MemoryEntityService; +import com.yahoo.vespa.hosted.controller.api.integration.dns.Record; +import com.yahoo.vespa.hosted.controller.api.integration.dns.RecordName; import com.yahoo.vespa.hosted.controller.api.integration.github.GitHubMock; import com.yahoo.vespa.hosted.controller.api.integration.organization.Contact; import com.yahoo.vespa.hosted.controller.api.integration.organization.MockContactRetriever; import com.yahoo.vespa.hosted.controller.api.integration.organization.MockIssueHandler; -import com.yahoo.vespa.hosted.controller.api.integration.stubs.MockMailer; import com.yahoo.vespa.hosted.controller.api.integration.routing.RoutingGenerator; import com.yahoo.vespa.hosted.controller.api.integration.stubs.MockBuildService; +import com.yahoo.vespa.hosted.controller.api.integration.stubs.MockMailer; import com.yahoo.vespa.hosted.controller.api.integration.stubs.MockRunDataStore; import com.yahoo.vespa.hosted.controller.api.integration.stubs.MockTesterCloud; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; @@ -43,11 +45,11 @@ import com.yahoo.vespa.hosted.controller.integration.ConfigServerMock; import com.yahoo.vespa.hosted.controller.integration.MetricsServiceMock; import com.yahoo.vespa.hosted.controller.integration.RoutingGeneratorMock; import com.yahoo.vespa.hosted.controller.integration.ZoneRegistryMock; -import com.yahoo.vespa.hosted.controller.security.AthenzCredentials; -import com.yahoo.vespa.hosted.controller.security.AthenzTenantSpec; import com.yahoo.vespa.hosted.controller.persistence.ApplicationSerializer; import com.yahoo.vespa.hosted.controller.persistence.CuratorDb; import com.yahoo.vespa.hosted.controller.persistence.MockCuratorDb; +import com.yahoo.vespa.hosted.controller.security.AthenzCredentials; +import com.yahoo.vespa.hosted.controller.security.AthenzTenantSpec; import com.yahoo.vespa.hosted.controller.security.Credentials; import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; import com.yahoo.vespa.hosted.controller.tenant.Tenant; @@ -188,6 +190,10 @@ public final class ControllerTester { return contactRetriever; } + public Optional<Record> findCname(String name) { + return nameService.findRecords(Record.Type.CNAME, RecordName.from(name)).stream().findFirst(); + } + /** Create a new controller instance. Useful to verify that controller state is rebuilt from persistence */ public final void createNewController() { controller = createController(curator, rotationsConfig, configServer, clock, gitHub, zoneRegistry, athenzDb, @@ -345,7 +351,8 @@ public final class ControllerTester { buildService, new MockRunDataStore(), () -> "test-controller", - new MockMailer()); + new MockMailer(), + new InMemoryFlagSource()); // Calculate initial versions controller.updateVersionStatus(VersionStatus.compute(controller)); return controller; diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainerTest.java index 23c7ec537f5..2b8e4f52d23 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DnsMaintainerTest.java @@ -4,7 +4,6 @@ package com.yahoo.vespa.hosted.controller.maintenance; import com.yahoo.config.application.api.ValidationId; import com.yahoo.config.provision.Environment; import com.yahoo.config.provision.RegionName; -import com.yahoo.vespa.athenz.api.OktaAccessToken; import com.yahoo.vespa.hosted.controller.Application; import com.yahoo.vespa.hosted.controller.ControllerTester; import com.yahoo.vespa.hosted.controller.api.integration.dns.Record; diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java index 331a6ba9ac8..c21d4b4b0bf 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java @@ -63,6 +63,7 @@ public class ControllerContainerTest { " <item>http://localhost</item>\n" + " </allowedUrls>\n" + " </config>\n" + + " <component id='com.yahoo.vespa.flags.InMemoryFlagSource'/>\n" + " <component id='com.yahoo.vespa.hosted.controller.persistence.MockCuratorDb'/>\n" + " <component id='com.yahoo.vespa.hosted.controller.athenz.mock.AthenzClientFactoryMock'/>\n" + " <component id='com.yahoo.vespa.hosted.controller.api.integration.chef.ChefMock'/>\n" + @@ -94,9 +95,9 @@ public class ControllerContainerTest { " <component id='com.yahoo.vespa.hosted.controller.integration.ApplicationStoreMock'/>\n" + " <component id='com.yahoo.vespa.hosted.controller.api.integration.stubs.MockTesterCloud'/>\n" + " <component id='com.yahoo.vespa.hosted.controller.api.integration.stubs.MockMailer'/>\n" + + " <component id='com.yahoo.vespa.hosted.controller.api.role.Roles'/>\n" + " <component id='com.yahoo.vespa.hosted.controller.security.AthenzAccessControlRequests'/>\n" + " <component id='com.yahoo.vespa.hosted.controller.athenz.impl.AthenzFacade'/>\n" + - " <component id='com.yahoo.vespa.hosted.controller.restapi.filter.AthenzRoleResolver'/>\n" + " <handler id='com.yahoo.vespa.hosted.controller.restapi.application.ApplicationApiHandler'>\n" + " <binding>http://*/application/v4/*</binding>\n" + " </handler>\n" + @@ -134,6 +135,7 @@ public class ControllerContainerTest { " <filtering>\n" + " <request-chain id='default'>\n" + " <filter id='com.yahoo.vespa.hosted.controller.integration.AthenzFilterMock'/>\n" + + " <filter id='com.yahoo.vespa.hosted.controller.restapi.filter.AthenzRoleFilter'/>\n" + " <filter id='com.yahoo.vespa.hosted.controller.restapi.filter.ControllerAuthorizationFilter'/>\n" + " <binding>http://*/*</binding>\n" + " </request-chain>\n" + diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java index 40d39248cb5..bde1c037bf2 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java @@ -384,7 +384,7 @@ public class ApplicationApiTest extends ControllerContainerTest { tester.assertResponse(request("/application/v4/tenant/tenant2/application/application1/environment/prod/region/us-central-1/instance/default/logs?from=1233&to=3214", GET) .userIdentity(USER_ID), new File("logs.json")); - tester.assertResponse(request("/application/v4/tenant/tenant2/application/application1/environment/prod/region/us-central-1/instance/default/logs?from=1233&to=3214&streaming", GET) + tester.assertResponse(request("/application/v4/tenant/tenant2/application/application1/environment/dev/region/us-central-1/instance/default/logs?from=1233&to=3214&streaming", GET) .userIdentity(USER_ID), "INFO - All good"); @@ -758,31 +758,6 @@ public class ApplicationApiTest extends ControllerContainerTest { new File("deploy-no-deployment.json"), 400); } - // Tests deployment to config server when using just on API call - // For now this depends on a switch in ApplicationController that does this for by- tenants in CD only - @Test - public void testDeployDirectlyUsingOneCallForDeploy() { - // Setup - tester.computeVersionStatus(); - UserId userId = new UserId("new_user"); - createAthenzDomainWithAdmin(ATHENZ_TENANT_DOMAIN, userId); - - // Create tenant - // PUT (create) the authenticated user - byte[] data = new byte[0]; - tester.assertResponse(request("/application/v4/user?user=new_user&domain=by", PUT) - .data(data) - .userIdentity(userId), // Normalized to by-new-user by API - new File("create-user-response.json")); - - // POST (deploy) an application to a dev zone - HttpEntity entity = createApplicationDeployData(applicationPackage, true); - tester.assertResponse(request("/application/v4/tenant/by-new-user/application/application1/environment/dev/region/cd-us-central-1/instance/default", POST) - .data(entity) - .userIdentity(userId), - new File("deploy-result.json")); - } - @Test public void testSortsDeploymentsAndJobs() { tester.computeVersionStatus(); @@ -897,7 +872,7 @@ public class ApplicationApiTest extends ControllerContainerTest { "{\"error-code\":\"BAD_REQUEST\",\"message\":\"Tenant 'tenant1' already exists\"}", 400); - // POST (add) a Athenz tenant with underscore in name + // POST (add) an Athenz tenant with underscore in name tester.assertResponse(request("/application/v4/tenant/my_tenant_2", POST) .userIdentity(USER_ID) .data("{\"athensDomain\":\"domain1\", \"property\":\"property1\"}") @@ -905,7 +880,7 @@ public class ApplicationApiTest extends ControllerContainerTest { "{\"error-code\":\"BAD_REQUEST\",\"message\":\"New tenant or application names must start with a letter, may contain no more than 20 characters, and may only contain lowercase letters, digits or dashes, but no double-dashes.\"}", 400); - // POST (add) a Athenz tenant with by- prefix + // POST (add) an Athenz tenant with by- prefix tester.assertResponse(request("/application/v4/tenant/by-tenant2", POST) .userIdentity(USER_ID) .data("{\"athensDomain\":\"domain1\", \"property\":\"property1\"}") @@ -913,6 +888,14 @@ public class ApplicationApiTest extends ControllerContainerTest { "{\"error-code\":\"BAD_REQUEST\",\"message\":\"Athenz tenant name cannot have prefix 'by-'\"}", 400); + // POST (add) an Athenz tenant with a reserved name + tester.assertResponse(request("/application/v4/tenant/hosted-vespa", POST) + .userIdentity(USER_ID) + .data("{\"athensDomain\":\"domain1\", \"property\":\"property1\"}") + .oktaAccessToken(OKTA_AT), + "{\"error-code\":\"BAD_REQUEST\",\"message\":\"Tenant 'hosted-vespa' already exists\"}", + 400); + // POST (create) an (empty) application tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1", POST) .userIdentity(USER_ID) diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-cluster-global-rotation.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-cluster-global-rotation.json index c31a47cb5b2..cd531bb96da 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-cluster-global-rotation.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-cluster-global-rotation.json @@ -1,4 +1,5 @@ { + "tenant": "tenant1", "application": "application1", "instance": "default", "deployments": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/instance/default/job/", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference-2.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference-2.json index a4026d6a812..ff22b95739d 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference-2.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference-2.json @@ -1,4 +1,5 @@ { + "tenant": "tenant2", "application": "application2", "instance": "default", "url": "http://localhost:8080/application/v4/tenant/tenant2/application/application2" diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference.json index 1ec229a2b4a..1d56944f6bc 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-reference.json @@ -1,4 +1,5 @@ { + "tenant": "tenant1", "application":"application1", "instance":"default", "url":"http://localhost:8080/application/v4/tenant/tenant1/application/application1" diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json index 9f65b5952e1..f2f38f7f509 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json @@ -1,4 +1,5 @@ { + "tenant": "tenant1", "application": "application1", "instance": "default", "deployments": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/instance/default/job/", @@ -237,21 +238,21 @@ "rotationId": "rotation-id-1", "instances": [ { - "environment": "prod", - "region": "us-west-1", - "instance": "default", "bcpStatus": { "rotationStatus": "IN" }, + "environment": "prod", + "region": "us-west-1", + "instance": "default", "url": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/environment/prod/region/us-west-1/instance/default" }, { - "environment": "prod", - "region": "us-east-3", - "instance": "default", "bcpStatus": { "rotationStatus": "UNKNOWN" }, + "environment": "prod", + "region": "us-east-3", + "instance": "default", "url": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/environment/prod/region/us-east-3/instance/default" } ], diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json index 3744e44152a..22e8573b1d4 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json @@ -1,4 +1,5 @@ { + "tenant": "tenant1", "application": "application1", "instance": "default", "deployments": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/instance/default/job/", @@ -231,12 +232,12 @@ "url": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/environment/dev/region/us-west-1/instance/default" }, { - "environment": "prod", - "region": "us-central-1", - "instance": "default", "bcpStatus": { "rotationStatus": "UNKNOWN" }, + "environment": "prod", + "region": "us-central-1", + "instance": "default", "url": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/environment/prod/region/us-central-1/instance/default" } ], diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json index 822bc447d8d..662e045d169 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json @@ -1,4 +1,5 @@ { + "tenant": "tenant1", "application": "application1", "instance": "default", "deployments": "http://localhost:8080/application/v4/tenant/tenant1/application/application1/instance/default/job/", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2-with-majorVersion.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2-with-majorVersion.json index 55803074ade..1477e18b4b8 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2-with-majorVersion.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2-with-majorVersion.json @@ -1,4 +1,5 @@ { + "tenant": "tenant2", "application": "application2", "instance": "default", "deployments": "http://localhost:8080/application/v4/tenant/tenant2/application/application2/instance/default/job/", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json index 2c34e5ae712..3063bb62b7e 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application2.json @@ -1,4 +1,5 @@ { + "tenant": "tenant2", "application": "application2", "instance": "default", "deployments": "http://localhost:8080/application/v4/tenant/tenant2/application/application2/instance/default/job/", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json index ac1797986fc..af21260676c 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json @@ -1,4 +1,9 @@ { + "tenant": "tenant1", + "application": "application1", + "instance": "default", + "environment": "prod", + "region": "us-central-1", "serviceUrls": [ "http://old-endpoint.vespa.yahooapis.com:4080", "http://qrs-endpoint.vespa.yahooapis.com:4080", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json index 65ea3925d8c..54e94c4521e 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json @@ -1,7 +1,9 @@ { + "tenant": "tenant1", + "application": "application1", + "instance": "default", "environment": "dev", "region": "us-west-1", - "instance": "default", "serviceUrls": [ "http://old-endpoint.vespa.yahooapis.com:4080", "http://qrs-endpoint.vespa.yahooapis.com:4080", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-us-central-1.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-us-central-1.json index a3380d823f3..cfefe629b9a 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-us-central-1.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-us-central-1.json @@ -1,10 +1,12 @@ { - "environment": "prod", - "region": "us-central-1", - "instance": "default", "bcpStatus": { "rotationStatus": "UNKNOWN" }, + "tenant": "tenant1", + "application": "application1", + "instance": "default", + "environment": "prod", + "region": "us-central-1", "serviceUrls": [ "http://old-endpoint.vespa.yahooapis.com:4080", "http://qrs-endpoint.vespa.yahooapis.com:4080", diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-with-application.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-with-application.json index ad8e65692b4..b222c33291c 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-with-application.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-with-application.json @@ -5,6 +5,7 @@ "property": "property1", "applications": [ { + "tenant": "tenant1", "application":"application1", "instance":"default", "url":"http://localhost:8080/application/v4/tenant/tenant1/application/application1" diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleFilterTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleFilterTest.java new file mode 100644 index 00000000000..dc4235e52bf --- /dev/null +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleFilterTest.java @@ -0,0 +1,122 @@ +package com.yahoo.vespa.hosted.controller.restapi.filter; + +import com.yahoo.config.provision.ApplicationName; +import com.yahoo.config.provision.TenantName; +import com.yahoo.jdisc.http.filter.security.cors.CorsFilterConfig; +import com.yahoo.vespa.athenz.api.AthenzDomain; +import com.yahoo.vespa.athenz.api.AthenzPrincipal; +import com.yahoo.vespa.athenz.api.AthenzService; +import com.yahoo.vespa.athenz.api.AthenzUser; +import com.yahoo.vespa.hosted.controller.ControllerTester; +import com.yahoo.vespa.hosted.controller.api.identifiers.ApplicationId; +import com.yahoo.vespa.hosted.controller.api.identifiers.ScrewdriverId; +import com.yahoo.vespa.hosted.controller.api.role.Roles; +import com.yahoo.vespa.hosted.controller.athenz.ApplicationAction; +import com.yahoo.vespa.hosted.controller.athenz.HostedAthenzIdentities; +import com.yahoo.vespa.hosted.controller.athenz.impl.AthenzFacade; +import com.yahoo.vespa.hosted.controller.athenz.mock.AthenzClientFactoryMock; +import com.yahoo.vespa.hosted.controller.athenz.mock.AthenzDbMock; +import org.junit.Before; +import org.junit.Test; + +import java.net.URI; +import java.util.Set; + +import static org.junit.Assert.assertEquals; + +/** + * @author jonmv + */ +public class AthenzRoleFilterTest { + + private static final AthenzPrincipal USER = new AthenzPrincipal(new AthenzUser("john")); + private static final AthenzPrincipal HOSTED_OPERATOR = new AthenzPrincipal(new AthenzUser("hosted-operator")); + private static final AthenzDomain TENANT_DOMAIN = new AthenzDomain("tenantdomain"); + private static final AthenzDomain TENANT_DOMAIN2 = new AthenzDomain("tenantdomain2"); + private static final AthenzPrincipal TENANT_ADMIN = new AthenzPrincipal(new AthenzService(TENANT_DOMAIN, "adminservice")); + private static final AthenzPrincipal TENANT_PIPELINE = new AthenzPrincipal(HostedAthenzIdentities.from(new ScrewdriverId("12345"))); + private static final TenantName TENANT = TenantName.from("mytenant"); + private static final TenantName TENANT2 = TenantName.from("othertenant"); + private static final ApplicationName APPLICATION = ApplicationName.from("myapp"); + private static final URI NO_CONTEXT_PATH = URI.create("/application/v4/"); + private static final URI TENANT_CONTEXT_PATH = URI.create("/application/v4/tenant/mytenant/"); + private static final URI APPLICATION_CONTEXT_PATH = URI.create("/application/v4/tenant/mytenant/application/myapp/"); + private static final URI TENANT2_CONTEXT_PATH = URI.create("/application/v4/tenant/othertenant/"); + private static final URI APPLICATION2_CONTEXT_PATH = URI.create("/application/v4/tenant/othertenant/application/myapp/"); + + private ControllerTester tester; + private AthenzRoleFilter filter; + + @Before + public void setup() { + tester = new ControllerTester(); + filter = new AthenzRoleFilter(new CorsFilterConfig.Builder().build(), + new AthenzFacade(new AthenzClientFactoryMock(tester.athenzDb())), + tester.controller()); + + tester.athenzDb().hostedOperators.add(HOSTED_OPERATOR.getIdentity()); + tester.createTenant(TENANT.value(), TENANT_DOMAIN.getName(), null); + tester.createApplication(TENANT, APPLICATION.value(), "default", 12345); + AthenzDbMock.Domain tenantDomain = tester.athenzDb().domains.get(TENANT_DOMAIN); + tenantDomain.admins.add(TENANT_ADMIN.getIdentity()); + tenantDomain.applications.get(new ApplicationId(APPLICATION.value())).addRoleMember(ApplicationAction.deploy, TENANT_PIPELINE.getIdentity()); + tester.createTenant(TENANT2.value(), TENANT_DOMAIN2.getName(), null); + tester.createApplication(TENANT2, APPLICATION.value(), "default", 42); + } + + @Test + public void testTranslations() { + + Roles roles = new Roles(tester.controller().system()); + + // Hosted operators are always members of the hostedOperator role. + assertEquals(Set.of(roles.hostedOperator()), + filter.roles(HOSTED_OPERATOR, NO_CONTEXT_PATH)); + + assertEquals(Set.of(roles.hostedOperator()), + filter.roles(HOSTED_OPERATOR, TENANT_CONTEXT_PATH)); + + assertEquals(Set.of(roles.hostedOperator()), + filter.roles(HOSTED_OPERATOR, APPLICATION_CONTEXT_PATH)); + + // Tenant admins are members of the athenzTenantAdmin role within their tenant subtree. + assertEquals(Set.of(roles.everyone()), + filter.roles(TENANT_PIPELINE, NO_CONTEXT_PATH)); + + assertEquals(Set.of(roles.athenzTenantAdmin(TENANT)), + filter.roles(TENANT_ADMIN, TENANT_CONTEXT_PATH)); + + assertEquals(Set.of(roles.athenzTenantAdmin(TENANT)), + filter.roles(TENANT_ADMIN, APPLICATION_CONTEXT_PATH)); + + assertEquals(Set.of(roles.everyone()), + filter.roles(TENANT_ADMIN, TENANT2_CONTEXT_PATH)); + + assertEquals(Set.of(roles.everyone()), + filter.roles(TENANT_ADMIN, APPLICATION2_CONTEXT_PATH)); + + // Build services are members of the tenantPipeline role within their application subtree. + assertEquals(Set.of(roles.everyone()), + filter.roles(TENANT_PIPELINE, NO_CONTEXT_PATH)); + + assertEquals(Set.of(roles.everyone()), + filter.roles(TENANT_PIPELINE, TENANT_CONTEXT_PATH)); + + assertEquals(Set.of(roles.tenantPipeline(TENANT, APPLICATION)), + filter.roles(TENANT_PIPELINE, APPLICATION_CONTEXT_PATH)); + + assertEquals(Set.of(roles.everyone()), + filter.roles(TENANT_PIPELINE, APPLICATION2_CONTEXT_PATH)); + + // Unprivileged users are just members of the everyone role. + assertEquals(Set.of(roles.everyone()), + filter.roles(USER, NO_CONTEXT_PATH)); + + assertEquals(Set.of(roles.everyone()), + filter.roles(USER, TENANT_CONTEXT_PATH)); + + assertEquals(Set.of(roles.everyone()), + filter.roles(USER, APPLICATION_CONTEXT_PATH)); + } + +} diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleResolverTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleResolverTest.java deleted file mode 100644 index 4628b95ad3c..00000000000 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/AthenzRoleResolverTest.java +++ /dev/null @@ -1,120 +0,0 @@ -package com.yahoo.vespa.hosted.controller.restapi.filter; - -import com.fasterxml.jackson.databind.ObjectMapper; -import com.yahoo.config.provision.ApplicationName; -import com.yahoo.config.provision.TenantName; -import com.yahoo.vespa.athenz.api.AthenzDomain; -import com.yahoo.vespa.athenz.api.AthenzPrincipal; -import com.yahoo.vespa.athenz.api.AthenzService; -import com.yahoo.vespa.athenz.api.AthenzUser; -import com.yahoo.vespa.hosted.controller.ControllerTester; -import com.yahoo.vespa.hosted.controller.api.identifiers.ApplicationId; -import com.yahoo.vespa.hosted.controller.api.identifiers.ScrewdriverId; -import com.yahoo.vespa.hosted.controller.athenz.ApplicationAction; -import com.yahoo.vespa.hosted.controller.athenz.HostedAthenzIdentities; -import com.yahoo.vespa.hosted.controller.athenz.impl.AthenzFacade; -import com.yahoo.vespa.hosted.controller.athenz.mock.AthenzClientFactoryMock; -import com.yahoo.vespa.hosted.controller.athenz.mock.AthenzDbMock; -import com.yahoo.vespa.hosted.controller.role.Context; -import com.yahoo.vespa.hosted.controller.role.Role; -import org.junit.Before; -import org.junit.Test; - -import java.util.Optional; -import java.util.Set; - -import static java.util.Collections.emptySet; -import static org.junit.Assert.assertEquals; - -/** - * @author jonmv - */ -public class AthenzRoleResolverTest { - - private static final ObjectMapper mapper = new ObjectMapper(); - - private static final AthenzPrincipal USER = new AthenzPrincipal(new AthenzUser("john")); - private static final AthenzPrincipal HOSTED_OPERATOR = new AthenzPrincipal(new AthenzUser("hosted-operator")); - private static final AthenzDomain TENANT_DOMAIN = new AthenzDomain("tenantdomain"); - private static final AthenzDomain TENANT_DOMAIN2 = new AthenzDomain("tenantdomain2"); - private static final AthenzPrincipal TENANT_ADMIN = new AthenzPrincipal(new AthenzService(TENANT_DOMAIN, "adminservice")); - private static final AthenzPrincipal TENANT_PIPELINE = new AthenzPrincipal(HostedAthenzIdentities.from(new ScrewdriverId("12345"))); - private static final TenantName TENANT = TenantName.from("mytenant"); - private static final TenantName TENANT2 = TenantName.from("othertenant"); - private static final ApplicationName APPLICATION = ApplicationName.from("myapp"); - private static final Optional<String> NO_CONTEXT_PATH = Optional.of("/application/v4/"); - private static final Optional<String> TENANT_CONTEXT_PATH = Optional.of("/application/v4/tenant/mytenant/"); - private static final Optional<String> APPLICATION_CONTEXT_PATH = Optional.of("/application/v4/tenant/mytenant/application/myapp/"); - private static final Optional<String> TENANT2_CONTEXT_PATH = Optional.of("/application/v4/tenant/othertenant/"); - private static final Optional<String> APPLICATION2_CONTEXT_PATH = Optional.of("/application/v4/tenant/othertenant/application/myapp/"); - - private ControllerTester tester; - private AthenzRoleResolver resolver; - - @Before - public void setup() { - tester = new ControllerTester(); - resolver = new AthenzRoleResolver(new AthenzFacade(new AthenzClientFactoryMock(tester.athenzDb())), - tester.controller()); - - tester.athenzDb().hostedOperators.add(HOSTED_OPERATOR.getIdentity()); - tester.createTenant(TENANT.value(), TENANT_DOMAIN.getName(), null); - tester.createApplication(TENANT, APPLICATION.value(), "default", 12345); - AthenzDbMock.Domain tenantDomain = tester.athenzDb().domains.get(TENANT_DOMAIN); - tenantDomain.admins.add(TENANT_ADMIN.getIdentity()); - tenantDomain.applications.get(new ApplicationId(APPLICATION.value())).addRoleMember(ApplicationAction.deploy, TENANT_PIPELINE.getIdentity()); - tester.createTenant(TENANT2.value(), TENANT_DOMAIN2.getName(), null); - tester.createApplication(TENANT2, APPLICATION.value(), "default", 42); - } - - @Test - public void testTranslations() { - - // Everyone is member of the everyone role. - assertEquals(Set.of(Context.unlimitedIn(tester.controller().system())), - resolver.membership(HOSTED_OPERATOR, APPLICATION_CONTEXT_PATH).contextsFor(Role.everyone)); - assertEquals(Set.of(Context.unlimitedIn(tester.controller().system())), - resolver.membership(TENANT_ADMIN, TENANT_CONTEXT_PATH).contextsFor(Role.everyone)); - assertEquals(Set.of(Context.unlimitedIn(tester.controller().system())), - resolver.membership(TENANT_PIPELINE, NO_CONTEXT_PATH).contextsFor(Role.everyone)); - assertEquals(Set.of(Context.unlimitedIn(tester.controller().system())), - resolver.membership(USER, APPLICATION_CONTEXT_PATH).contextsFor(Role.everyone)); - - // Only operators are members of the operator role. - assertEquals(Set.of(Context.unlimitedIn(tester.controller().system())), - resolver.membership(HOSTED_OPERATOR, TENANT_CONTEXT_PATH).contextsFor(Role.hostedOperator)); - assertEquals(emptySet(), - resolver.membership(TENANT_ADMIN, NO_CONTEXT_PATH).contextsFor(Role.hostedOperator)); - assertEquals(emptySet(), - resolver.membership(TENANT_PIPELINE, APPLICATION_CONTEXT_PATH).contextsFor(Role.hostedOperator)); - assertEquals(emptySet(), - resolver.membership(USER, TENANT_CONTEXT_PATH).contextsFor(Role.hostedOperator)); - - // Operators and tenant admins are tenant admins of their tenants. - assertEquals(Set.of(Context.limitedTo(TENANT, tester.controller().system())), - resolver.membership(HOSTED_OPERATOR, APPLICATION_CONTEXT_PATH).contextsFor(Role.athenzTenantAdmin)); - assertEquals(emptySet(), // TODO this is wrong, but we can't do better until we ask ZMS for roles. - resolver.membership(TENANT_ADMIN, NO_CONTEXT_PATH).contextsFor(Role.athenzTenantAdmin)); - assertEquals(Set.of(Context.limitedTo(TENANT, tester.controller().system())), - resolver.membership(TENANT_ADMIN, TENANT_CONTEXT_PATH).contextsFor(Role.athenzTenantAdmin)); - assertEquals(emptySet(), - resolver.membership(TENANT_ADMIN, TENANT2_CONTEXT_PATH).contextsFor(Role.athenzTenantAdmin)); - assertEquals(emptySet(), - resolver.membership(TENANT_PIPELINE, APPLICATION_CONTEXT_PATH).contextsFor(Role.athenzTenantAdmin)); - assertEquals(emptySet(), - resolver.membership(USER, TENANT_CONTEXT_PATH).contextsFor(Role.athenzTenantAdmin)); - - // Only build services are pipeline operators of their applications. - assertEquals(emptySet(), - resolver.membership(HOSTED_OPERATOR, APPLICATION_CONTEXT_PATH).contextsFor(Role.tenantPipeline)); - assertEquals(emptySet(), - resolver.membership(TENANT_ADMIN, APPLICATION_CONTEXT_PATH).contextsFor(Role.tenantPipeline)); - assertEquals(Set.of(Context.limitedTo(TENANT, APPLICATION, tester.controller().system())), - resolver.membership(TENANT_PIPELINE, APPLICATION_CONTEXT_PATH).contextsFor(Role.tenantPipeline)); - assertEquals(emptySet(), - resolver.membership(TENANT_PIPELINE, APPLICATION2_CONTEXT_PATH).contextsFor(Role.tenantPipeline)); - assertEquals(emptySet(), - resolver.membership(USER, APPLICATION_CONTEXT_PATH).contextsFor(Role.tenantPipeline)); - } - -} diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilterTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilterTest.java index 39b08695986..105e10eefd2 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilterTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilterTest.java @@ -6,13 +6,10 @@ import com.yahoo.application.container.handler.Request; import com.yahoo.config.provision.SystemName; import com.yahoo.jdisc.http.HttpRequest.Method; import com.yahoo.jdisc.http.filter.DiscFilterRequest; -import com.yahoo.vespa.athenz.api.AthenzIdentity; -import com.yahoo.vespa.athenz.api.AthenzPrincipal; -import com.yahoo.vespa.athenz.api.AthenzUser; import com.yahoo.vespa.hosted.controller.ControllerTester; +import com.yahoo.vespa.hosted.controller.api.role.Roles; +import com.yahoo.vespa.hosted.controller.api.role.SecurityContext; import com.yahoo.vespa.hosted.controller.restapi.ApplicationRequestToDiscFilterRequestWrapper; -import com.yahoo.vespa.hosted.controller.role.Role; -import com.yahoo.vespa.hosted.controller.role.RoleMembership; import org.junit.Test; import java.io.IOException; @@ -33,39 +30,42 @@ import static org.junit.Assert.assertTrue; public class ControllerAuthorizationFilterTest { private static final ObjectMapper mapper = new ObjectMapper(); - private static AthenzIdentity identity = new AthenzUser("user"); @Test public void operator() { ControllerTester tester = new ControllerTester(); - RoleMembership.Resolver operatorResolver = (user, path) -> RoleMembership.in(tester.controller().system()) - .add(Role.hostedOperator) - .build(); - ControllerAuthorizationFilter filter = createFilter(tester, operatorResolver); - assertIsAllowed(invokeFilter(filter, createRequest(Method.POST, "/zone/v2/path", identity))); - assertIsAllowed(invokeFilter(filter, createRequest(Method.PUT, "/application/v4/user", identity))); - assertIsAllowed(invokeFilter(filter, createRequest(Method.GET, "/zone/v1/path", identity))); + Roles roles = new Roles(tester.controller().system()); + SecurityContext securityContext = new SecurityContext(() -> "operator", Set.of(roles.hostedOperator())); + ControllerAuthorizationFilter filter = createFilter(tester); + + assertIsAllowed(invokeFilter(filter, createRequest(Method.POST, "/zone/v2/path", securityContext))); + assertIsAllowed(invokeFilter(filter, createRequest(Method.PUT, "/application/v4/user", securityContext))); + assertIsAllowed(invokeFilter(filter, createRequest(Method.GET, "/zone/v1/path", securityContext))); } @Test public void unprivileged() { ControllerTester tester = new ControllerTester(); - RoleMembership.Resolver emptyResolver = (user, path) -> RoleMembership.in(tester.controller().system()).build(); - ControllerAuthorizationFilter filter = createFilter(tester, emptyResolver); - assertIsForbidden(invokeFilter(filter, createRequest(Method.POST, "/zone/v2/path", identity))); - assertIsAllowed(invokeFilter(filter, createRequest(Method.PUT, "/application/v4/user", identity))); - assertIsAllowed(invokeFilter(filter, createRequest(Method.GET, "/zone/v1/path", identity))); + Roles roles = new Roles(tester.controller().system()); + SecurityContext securityContext = new SecurityContext(() -> "user", Set.of(roles.everyone())); + ControllerAuthorizationFilter filter = createFilter(tester); + + assertIsForbidden(invokeFilter(filter, createRequest(Method.POST, "/zone/v2/path", securityContext))); + assertIsAllowed(invokeFilter(filter, createRequest(Method.PUT, "/application/v4/user", securityContext))); + assertIsAllowed(invokeFilter(filter, createRequest(Method.GET, "/zone/v1/path", securityContext))); } @Test public void unprivilegedInPublic() { ControllerTester tester = new ControllerTester(); tester.zoneRegistry().setSystemName(SystemName.Public); - RoleMembership.Resolver emptyResolver = (user, path) -> RoleMembership.in(tester.controller().system()).build(); - ControllerAuthorizationFilter filter = createFilter(tester, emptyResolver); - assertIsForbidden(invokeFilter(filter, createRequest(Method.POST, "/zone/v2/path", identity))); - assertIsForbidden(invokeFilter(filter, createRequest(Method.PUT, "/application/v4/user", identity))); - assertIsAllowed(invokeFilter(filter, createRequest(Method.GET, "/zone/v1/path", identity))); + Roles roles = new Roles(tester.controller().system()); + SecurityContext securityContext = new SecurityContext(() -> "user", Set.of(roles.everyone())); + + ControllerAuthorizationFilter filter = createFilter(tester); + assertIsForbidden(invokeFilter(filter, createRequest(Method.POST, "/zone/v2/path", securityContext))); + assertIsForbidden(invokeFilter(filter, createRequest(Method.PUT, "/application/v4/user", securityContext))); + assertIsAllowed(invokeFilter(filter, createRequest(Method.GET, "/zone/v1/path", securityContext))); } private static void assertIsAllowed(Optional<AuthorizationResponse> response) { @@ -79,8 +79,8 @@ public class ControllerAuthorizationFilterTest { assertEquals("Invalid status code", FORBIDDEN, response.get().statusCode); } - private static ControllerAuthorizationFilter createFilter(ControllerTester tester, RoleMembership.Resolver resolver) { - return new ControllerAuthorizationFilter(resolver, tester.controller(), Set.of("http://localhost")); + private static ControllerAuthorizationFilter createFilter(ControllerTester tester) { + return new ControllerAuthorizationFilter(tester.controller().system(), Set.of("http://localhost")); } private static Optional<AuthorizationResponse> invokeFilter(ControllerAuthorizationFilter filter, @@ -91,9 +91,9 @@ public class ControllerAuthorizationFilterTest { .map(response -> new AuthorizationResponse(response.getStatus(), getErrorMessage(responseHandlerMock))); } - private static DiscFilterRequest createRequest(Method method, String path, AthenzIdentity identity) { - Request request = new Request(path, new byte[0], Request.Method.valueOf(method.name()), - new AthenzPrincipal(identity)); + private static DiscFilterRequest createRequest(Method method, String path, SecurityContext securityContext) { + Request request = new Request(path, new byte[0], Request.Method.valueOf(method.name()), securityContext.principal()); + request.getAttributes().put(SecurityContext.ATTRIBUTE_NAME, securityContext); return new ApplicationRequestToDiscFilterRequestWrapper(request); } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/role/RoleMembershipTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/role/RoleMembershipTest.java deleted file mode 100644 index 1da5d3764f6..00000000000 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/role/RoleMembershipTest.java +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.controller.role; - -import com.yahoo.config.provision.ApplicationName; -import com.yahoo.config.provision.SystemName; -import com.yahoo.config.provision.TenantName; -import org.junit.Test; - -import java.net.URI; - -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -/** - * @author mpolden - */ -public class RoleMembershipTest { - - @Test - public void operator_membership() { - RoleMembership roles = RoleMembership.in(SystemName.main) - .add(Role.hostedOperator) - .build(); - - // Operator actions - assertFalse(roles.allows(Action.create, URI.create("/not/explicitly/defined"))); - assertTrue(roles.allows(Action.create, URI.create("/controller/v1/foo"))); - assertTrue(roles.allows(Action.update, URI.create("/os/v1/bar"))); - assertTrue(roles.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1"))); - assertTrue(roles.allows(Action.update, URI.create("/application/v4/tenant/t2/application/a2"))); - } - - @Test - public void tenant_membership() { - RoleMembership roles = RoleMembership.in(SystemName.main) - .add(Role.athenzTenantAdmin).limitedTo(TenantName.from("t1"), ApplicationName.from("a1")) - .build(); - assertFalse(roles.allows(Action.create, URI.create("/not/explicitly/defined"))); - assertFalse("Deny access to operator API", roles.allows(Action.create, URI.create("/controller/v1/foo"))); - assertFalse("Deny access to other tenant and app", roles.allows(Action.update, URI.create("/application/v4/tenant/t2/application/a2"))); - assertFalse("Deny access to other app", roles.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a2"))); - assertTrue(roles.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1"))); - - RoleMembership multiContext = RoleMembership.in(SystemName.main) - .add(Role.athenzTenantAdmin).limitedTo(TenantName.from("t1"), ApplicationName.from("a1")) - .add(Role.athenzTenantAdmin).limitedTo(TenantName.from("t2"), ApplicationName.from("a2")) - .build(); - assertFalse("Deny access to other tenant and app", multiContext.allows(Action.update, URI.create("/application/v4/tenant/t3/application/a3"))); - assertTrue(multiContext.allows(Action.update, URI.create("/application/v4/tenant/t2/application/a2"))); - assertTrue(multiContext.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1"))); - - RoleMembership publicSystem = RoleMembership.in(SystemName.vaas) - .add(Role.athenzTenantAdmin).limitedTo(TenantName.from("t1"), ApplicationName.from("a1")) - .build(); - assertFalse(publicSystem.allows(Action.read, URI.create("/controller/v1/foo"))); - assertTrue(multiContext.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1"))); - } - - @Test - public void build_service_membership() { - RoleMembership roles = RoleMembership.in(SystemName.main) - .add(Role.tenantPipeline).build(); - assertFalse(roles.allows(Action.create, URI.create("/not/explicitly/defined"))); - assertFalse(roles.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1"))); - assertTrue(roles.allows(Action.create, URI.create("/application/v4/tenant/t1/application/a1/jobreport"))); - assertFalse("No global read access", roles.allows(Action.read, URI.create("/controller/v1/foo"))); - } - - @Test - public void multi_role_membership() { - RoleMembership roles = RoleMembership.in(SystemName.main) - .add(Role.athenzTenantAdmin).limitedTo(TenantName.from("t1"), ApplicationName.from("a1")) - .add(Role.tenantPipeline) - .add(Role.everyone) - .build(); - assertFalse(roles.allows(Action.create, URI.create("/not/explicitly/defined"))); - assertFalse(roles.allows(Action.create, URI.create("/controller/v1/foo"))); - assertTrue(roles.allows(Action.create, URI.create("/application/v4/tenant/t1/application/a1/jobreport"))); - assertTrue(roles.allows(Action.update, URI.create("/application/v4/tenant/t1/application/a1"))); - assertTrue("Global read access", roles.allows(Action.read, URI.create("/controller/v1/foo"))); - assertTrue("Dashboard read access", roles.allows(Action.read, URI.create("/"))); - assertTrue("Dashboard read access", roles.allows(Action.read, URI.create("/d/nodes"))); - assertTrue("Dashboard read access", roles.allows(Action.read, URI.create("/statuspage/v1/incidents"))); - } - -} diff --git a/document/src/main/java/com/yahoo/document/CollectionDataType.java b/document/src/main/java/com/yahoo/document/CollectionDataType.java index a73588a710c..c6420b5e71f 100644 --- a/document/src/main/java/com/yahoo/document/CollectionDataType.java +++ b/document/src/main/java/com/yahoo/document/CollectionDataType.java @@ -32,7 +32,6 @@ public abstract class CollectionDataType extends DataType { return type; } - @SuppressWarnings("deprecation") public DataType getNestedType() { return nestedType; } @@ -58,11 +57,7 @@ public abstract class CollectionDataType extends DataType { return false; } CollectionFieldValue cfv = (CollectionFieldValue) value; - if (equals(cfv.getDataType())) { - //the field value if of this type: - return true; - } - return false; + return equals(cfv.getDataType()); } @Override diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java index 2773f9d31da..435c8fcdc65 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java @@ -38,7 +38,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { * Converts the given tensor type to a type that is compatible for being used in this update (has only mapped dimensions). */ public static TensorType convertDimensionsToMapped(TensorType type) { - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(type.valueType()); type.dimensions().stream().forEach(dim -> builder.mapped(dim.name())); return builder.build(); } diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java index 335cda8e133..981120af145 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java @@ -97,7 +97,7 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { } public static TensorType extractSparseDimensions(TensorType type) { - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(type.valueType()); type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name())); return builder.build(); } diff --git a/documentgen-test/etc/complex/music4.sd b/documentgen-test/etc/complex/music4.sd index c8100ba7de2..eab0018360d 100644 --- a/documentgen-test/etc/complex/music4.sd +++ b/documentgen-test/etc/complex/music4.sd @@ -4,5 +4,8 @@ search music4 { field mu4 type string { } + field pos type position { + + } } } diff --git a/documentgen-test/src/test/java/com/yahoo/vespa/config/DocumentGenPluginTest.java b/documentgen-test/src/test/java/com/yahoo/vespa/config/DocumentGenPluginTest.java index deec438a332..b6a0f165ca6 100644 --- a/documentgen-test/src/test/java/com/yahoo/vespa/config/DocumentGenPluginTest.java +++ b/documentgen-test/src/test/java/com/yahoo/vespa/config/DocumentGenPluginTest.java @@ -5,21 +5,53 @@ import com.yahoo.compress.CompressionType; import com.yahoo.docproc.DocumentProcessor; import com.yahoo.docproc.Processing; import com.yahoo.docproc.proxy.ProxyDocument; -import com.yahoo.document.*; +import com.yahoo.document.ArrayDataType; +import com.yahoo.document.DataType; +import com.yahoo.document.Document; +import com.yahoo.document.DocumentId; +import com.yahoo.document.DocumentPut; +import com.yahoo.document.DocumentType; +import com.yahoo.document.DocumentTypeManager; +import com.yahoo.document.Field; +import com.yahoo.document.Generated; +import com.yahoo.document.MapDataType; +import com.yahoo.document.ReferenceDataType; +import com.yahoo.document.StructDataType; +import com.yahoo.document.WeightedSetDataType; import com.yahoo.document.annotation.Annotation; import com.yahoo.document.annotation.AnnotationType; import com.yahoo.document.annotation.SpanTree; import com.yahoo.document.config.DocumentmanagerConfig; -import com.yahoo.document.datatypes.*; -import com.yahoo.document.serialization.*; +import com.yahoo.document.datatypes.Array; +import com.yahoo.document.datatypes.DoubleFieldValue; +import com.yahoo.document.datatypes.FieldValue; +import com.yahoo.document.datatypes.FloatFieldValue; +import com.yahoo.document.datatypes.IntegerFieldValue; +import com.yahoo.document.datatypes.LongFieldValue; +import com.yahoo.document.datatypes.MapFieldValue; +import com.yahoo.document.datatypes.Raw; +import com.yahoo.document.datatypes.ReferenceFieldValue; +import com.yahoo.document.datatypes.StringFieldValue; +import com.yahoo.document.datatypes.Struct; +import com.yahoo.document.datatypes.StructuredFieldValue; +import com.yahoo.document.datatypes.WeightedSet; +import com.yahoo.document.serialization.DocumentDeserializerFactory; +import com.yahoo.document.serialization.DocumentSerializer; +import com.yahoo.document.serialization.DocumentSerializerFactory; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.searchdefinition.derived.Deriver; import com.yahoo.tensor.Tensor; import com.yahoo.vespa.document.NodeImpl; import com.yahoo.vespa.document.dom.DocumentImpl; -import com.yahoo.vespa.documentgen.test.*; +import com.yahoo.vespa.documentgen.test.Book; import com.yahoo.vespa.documentgen.test.Book.Ss0; import com.yahoo.vespa.documentgen.test.Book.Ss1; +import com.yahoo.vespa.documentgen.test.Common; +import com.yahoo.vespa.documentgen.test.ConcreteDocumentFactory; +import com.yahoo.vespa.documentgen.test.Music; +import com.yahoo.vespa.documentgen.test.Music3; +import com.yahoo.vespa.documentgen.test.Music4; +import com.yahoo.vespa.documentgen.test.Parent; import com.yahoo.vespa.documentgen.test.annotation.Artist; import com.yahoo.vespa.documentgen.test.annotation.Date; import com.yahoo.vespa.documentgen.test.annotation.Emptyannotation; @@ -32,10 +64,24 @@ import java.lang.Class; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.nio.ByteBuffer; -import java.util.*; - +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import static junit.framework.TestCase.assertFalse; +import static junit.framework.TestCase.assertNotSame; import static org.hamcrest.CoreMatchers.instanceOf; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertThat; + /** * Testcases for vespa-documentgen-plugin @@ -675,7 +721,7 @@ public class DocumentGenPluginTest { } if (generated.getAnnotation(com.yahoo.document.Generated.class)==null) return null; Book book = new Book(d.getId()); - for (Iterator<Map.Entry<Field, FieldValue>>i=d.iterator() ; i.hasNext() ; ) { + for (Iterator<Map.Entry<Field, FieldValue>> i = d.iterator(); i.hasNext() ; ) { Map.Entry<Field, FieldValue> e = i.next(); Field f = e.getKey(); FieldValue fv = e.getValue(); @@ -928,5 +974,12 @@ public class DocumentGenPluginTest { book.setVector(Tensor.from("{{x:0}:1.0, {x:1}:2.0, {x:2}:3.0}")); assertEquals("tensor(x{}):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}", book.getVector().toString()); } + + @Test + public void testPositionType() { + Music4 book = new Music4(new DocumentId("doc:music4:0")); + book.setPos(new Music4.Position().setX(7).setY(8)); + assertEquals(new Music4.Position().setX(7).setY(8), book.getPos()); + } } diff --git a/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp b/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp index ae6166f9d24..a0aeb6b63c9 100644 --- a/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp +++ b/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp @@ -40,6 +40,7 @@ assertTensorSpec(const TensorSpec &expSpec, const Tensor &tensor) struct Fixture { Builder builder; + Fixture() : builder() {} }; Tensor::UP diff --git a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp index 7aa1d71fe9a..b7aa988775d 100644 --- a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp +++ b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp @@ -11,6 +11,7 @@ #include <vespa/vespalib/objects/nbostream.h> #include <vespa/vespalib/objects/hexdump.h> #include <ostream> +#include <vespa/eval/tensor/dense/dense_tensor_view.h> using namespace vespalib::tensor; using vespalib::nbostream; @@ -32,9 +33,7 @@ std::ostream &operator<<(std::ostream &out, const std::vector<uint8_t> &rhs) } -namespace vespalib { - -namespace tensor { +namespace vespalib::tensor { static bool operator==(const Tensor &lhs, const Tensor &rhs) { @@ -42,7 +41,6 @@ static bool operator==(const Tensor &lhs, const Tensor &rhs) } } -} template <class BuilderType> void @@ -69,7 +67,7 @@ struct Fixture Fixture() : _builder() {} Tensor::UP createTensor(const TensorCells &cells) { - return vespalib::tensor::TensorFactory::create(cells, _builder); + return TensorFactory::create(cells, _builder); } Tensor::UP createTensor(const TensorCells &cells, const TensorDimensions &dimensions) { return TensorFactory::create(cells, dimensions, _builder); @@ -84,7 +82,7 @@ struct Fixture auto formatId = wrapStream.getInt1_4Bytes(); ASSERT_EQUAL(formatId, 1u); // sparse format SparseBinaryFormat::deserialize(wrapStream, builder); - EXPECT_TRUE(wrapStream.size() == 0); + EXPECT_TRUE(wrapStream.empty()); auto ret = builder.build(); checkDeserialize<BuilderType>(stream, *ret); stream.adjustReadPos(stream.size()); @@ -162,93 +160,129 @@ struct DenseFixture return ret; } void assertSerialized(const ExpBuffer &exp, const DenseTensorCells &rhs) { + assertSerialized(exp, SerializeFormat::DOUBLE, rhs); + } + template <typename T> + void assertCellsOnly(const ExpBuffer &exp, const DenseTensorView & rhs) { + nbostream a(&exp[0], exp.size()); + std::vector<T> v; + TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(a, v); + EXPECT_EQUAL(v.size(), rhs.cellsRef().size()); + for (size_t i(0); i < v.size(); i++) { + EXPECT_EQUAL(v[i], rhs.cellsRef()[i]); + } + } + void assertSerialized(const ExpBuffer &exp, SerializeFormat cellType, const DenseTensorCells &rhs) { Tensor::UP rhsTensor(createTensor(rhs)); nbostream rhsStream; - serialize(rhsStream, *rhsTensor); + TypedBinaryFormat::serialize(rhsStream, *rhsTensor, cellType); EXPECT_EQUAL(exp, rhsStream); auto rhs2 = deserialize(rhsStream); EXPECT_EQUAL(*rhs2, *rhsTensor); + + assertCellsOnly<float>(exp, dynamic_cast<const DenseTensorView &>(*rhs2)); + assertCellsOnly<double>(exp, dynamic_cast<const DenseTensorView &>(*rhs2)); } }; -TEST_F("test tensor serialization for DenseTensor", DenseFixture) -{ - TEST_DO(f.assertSerialized({ 0x02, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00}, +TEST_F("test tensor serialization for DenseTensor", DenseFixture) { + TEST_DO(f.assertSerialized({0x02, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00}, {})); - TEST_DO(f.assertSerialized({ 0x02, 0x01, 0x01, 0x78, 0x01, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00}, - { {{{"x",0}}, 0} })); - TEST_DO(f.assertSerialized({ 0x02, 0x02, 0x01, 0x78, 0x01, - 0x01, 0x79, 0x01, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00 }, - { {{{"x",0},{"y", 0}}, 0} })); - TEST_DO(f.assertSerialized({ 0x02, 0x01, 0x01, 0x78, 0x02, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x40, 0x08, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00 }, - { {{{"x",1}}, 3} })); - TEST_DO(f.assertSerialized({ 0x02, 0x02, 0x01, 0x78, 0x01, - 0x01, 0x79, 0x01, - 0x40, 0x08, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00 }, - { {{{"x",0},{"y",0}}, 3} })); - TEST_DO(f.assertSerialized({ 0x02, 0x02, 0x01, 0x78, 0x02, - 0x01, 0x79, 0x01, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x40, 0x08, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00 }, - { {{{"x",1},{"y",0}}, 3} })); - TEST_DO(f.assertSerialized({ 0x02, 0x02, 0x01, 0x78, 0x01, - 0x01, 0x79, 0x04, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x40, 0x08, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00 }, - { {{{"x",0},{"y",3}}, 3} })); - TEST_DO(f.assertSerialized({ 0x02, 0x02, 0x01, 0x78, 0x03, - 0x01, 0x79, 0x05, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x40, 0x08, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00 }, - { {{{"x",2}, {"y",4}}, 3} })); + TEST_DO(f.assertSerialized({0x02, 0x01, 0x01, 0x78, 0x01, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00}, + {{{{"x", 0}}, 0}})); + TEST_DO(f.assertSerialized({0x02, 0x02, 0x01, 0x78, 0x01, + 0x01, 0x79, 0x01, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00}, + {{{{"x", 0}, {"y", 0}}, 0}})); + TEST_DO(f.assertSerialized({0x02, 0x01, 0x01, 0x78, 0x02, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x40, 0x08, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00}, + {{{{"x", 1}}, 3}})); + TEST_DO(f.assertSerialized({0x02, 0x02, 0x01, 0x78, 0x01, + 0x01, 0x79, 0x01, + 0x40, 0x08, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00}, + {{{{"x", 0}, {"y", 0}}, 3}})); + TEST_DO(f.assertSerialized({0x02, 0x02, 0x01, 0x78, 0x02, + 0x01, 0x79, 0x01, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x40, 0x08, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00}, + {{{{"x", 1}, {"y", 0}}, 3}})); + TEST_DO(f.assertSerialized({0x02, 0x02, 0x01, 0x78, 0x01, + 0x01, 0x79, 0x04, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x40, 0x08, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00}, + {{{{"x", 0}, {"y", 3}}, 3}})); + TEST_DO(f.assertSerialized({0x02, 0x02, 0x01, 0x78, 0x03, + 0x01, 0x79, 0x05, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x40, 0x08, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00}, + {{{{"x", 2}, {"y", 4}}, 3}})); +} + +TEST_F("test 'float' cells", DenseFixture) { + TEST_DO(f.assertSerialized({0x06, 0x01, 0x02, 0x01, 0x78, 0x03, + 0x01, 0x79, 0x05, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x40, 0x40, 0x00, 0x00 }, + SerializeFormat::FLOAT, { {{{"x",2}, {"y",4}}, 3} })); } diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h index 564d6a6b84e..9d95d91ae15 100644 --- a/eval/src/vespa/eval/eval/value_type.h +++ b/eval/src/vespa/eval/eval/value_type.h @@ -4,7 +4,6 @@ #include <vespa/vespalib/stllike/string.h> #include <vector> -#include <memory> namespace vespalib::eval { @@ -36,11 +35,12 @@ public: }; private: - Type _type; + Type _type; std::vector<Dimension> _dimensions; - explicit ValueType(Type type_in) + ValueType(Type type_in) : _type(type_in), _dimensions() {} + ValueType(Type type_in, std::vector<Dimension> &&dimensions_in) : _type(type_in), _dimensions(std::move(dimensions_in)) {} diff --git a/eval/src/vespa/eval/eval/value_type_spec.cpp b/eval/src/vespa/eval/eval/value_type_spec.cpp index cf0fb6d493a..229a9201f08 100644 --- a/eval/src/vespa/eval/eval/value_type_spec.cpp +++ b/eval/src/vespa/eval/eval/value_type_spec.cpp @@ -6,9 +6,7 @@ #include <vespa/vespalib/util/stringfmt.h> #include <algorithm> -namespace vespalib { -namespace eval { -namespace value_type { +namespace vespalib::eval::value_type { namespace { @@ -205,6 +203,4 @@ to_spec(const ValueType &type) return os.str(); } -} // namespace vespalib::eval::value_type -} // namespace vespalib::eval -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/eval/value_type_spec.h b/eval/src/vespa/eval/eval/value_type_spec.h index 76d50834ae8..f2609f59f32 100644 --- a/eval/src/vespa/eval/eval/value_type_spec.h +++ b/eval/src/vespa/eval/eval/value_type_spec.h @@ -4,15 +4,11 @@ #include "value_type.h" -namespace vespalib { -namespace eval { -namespace value_type { +namespace vespalib::eval::value_type { ValueType parse_spec(const char *pos_in, const char *end_in, const char *&pos_out); ValueType from_spec(const vespalib::string &str); vespalib::string to_spec(const ValueType &type); -} // namespace vespalib::eval::value_type -} // namespace vespalib::eval -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/default_tensor.h b/eval/src/vespa/eval/tensor/default_tensor.h index 202b482e300..456d3333295 100644 --- a/eval/src/vespa/eval/tensor/default_tensor.h +++ b/eval/src/vespa/eval/tensor/default_tensor.h @@ -5,13 +5,11 @@ #include "sparse/sparse_tensor.h" #include "sparse/sparse_tensor_builder.h" -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { struct DefaultTensor { using type = SparseTensor; using builder = SparseTensorBuilder; }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp index 554953288e1..5a16511fe71 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp @@ -26,8 +26,7 @@ #include <vespa/log/log.h> LOG_SETUP(".eval.tensor.default_tensor_engine"); -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { using eval::Aggr; using eval::Aggregator; @@ -390,5 +389,4 @@ DefaultTensorEngine::rename(const Value &a, const std::vector<vespalib::string> //----------------------------------------------------------------------------- -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.h b/eval/src/vespa/eval/tensor/default_tensor_engine.h index 755bdcf6a9d..b7a9e4d43e7 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.h +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.h @@ -4,8 +4,7 @@ #include <vespa/eval/eval/tensor_engine.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /** * This is a tensor engine implementation wrapping the default tensor @@ -34,5 +33,4 @@ public: const Value &rename(const Value &a, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash) const override; }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor.cpp index e775385b623..c183e5c1db3 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor.cpp @@ -75,6 +75,8 @@ DenseTensor::DenseTensor(eval::ValueType &&type_in, checkCellsSize(*this); } +DenseTensor::~DenseTensor() = default; + bool DenseTensor::operator==(const DenseTensor &rhs) const { diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor.h b/eval/src/vespa/eval/tensor/dense/dense_tensor.h index 0da5f570674..3795831c914 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor.h @@ -2,10 +2,6 @@ #pragma once -#include <vespa/eval/tensor/tensor.h> -#include <vespa/eval/tensor/types.h> -#include <vespa/eval/eval/value_type.h> -#include "dense_tensor_cells_iterator.h" #include "dense_tensor_view.h" namespace vespalib::tensor { @@ -17,20 +13,16 @@ namespace vespalib::tensor { class DenseTensor : public DenseTensorView { public: - typedef std::unique_ptr<DenseTensor> UP; - using Cells = std::vector<double>; - -private: - eval::ValueType _type; - Cells _cells; - -public: DenseTensor(); - ~DenseTensor() {} + ~DenseTensor() override; DenseTensor(const eval::ValueType &type_in, const Cells &cells_in); DenseTensor(const eval::ValueType &type_in, Cells &&cells_in); DenseTensor(eval::ValueType &&type_in, Cells &&cells_in); bool operator==(const DenseTensor &rhs) const; +private: + eval::ValueType _type; + Cells _cells; + }; } diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp index 25478510587..fa1e59c87db 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp @@ -74,7 +74,7 @@ apply(const DenseTensorView &lhs, const Tensor &rhs, Function &&func) } const DenseTensor *dense = dynamic_cast<const DenseTensor *>(&rhs); if (dense) { - return apply(lhs, DenseTensorView(*dense), func); + return apply(lhs, *dense, func); } return Tensor::UP(); } diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp index 5d52e5f6e0e..cd4738cf1ee 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.cpp @@ -6,7 +6,6 @@ #include <limits> #include <algorithm> - using vespalib::IllegalArgumentException; using vespalib::make_string; @@ -83,8 +82,7 @@ DenseTensorBuilder::calculateCellAddress() const auto &dim = _dimensions[i]; if (label == UNDEFINED_LABEL) { throw IllegalArgumentException(make_string("Label for dimension '%s' is undefined. " - "Expected a value in the range [0, %u>", - dim.name.c_str(), dim.size)); + "Expected a value in the range [0, %u>", dim.name.c_str(), dim.size)); } result += (label * multiplier); multiplier *= dim.size; @@ -102,12 +100,10 @@ DenseTensorBuilder::DenseTensorBuilder() { } -DenseTensorBuilder::~DenseTensorBuilder() { -} +DenseTensorBuilder::~DenseTensorBuilder() = default; DenseTensorBuilder::Dimension -DenseTensorBuilder::defineDimension(const vespalib::string &dimension, - size_t dimensionSize) +DenseTensorBuilder::defineDimension(const vespalib::string &dimension, size_t dimensionSize) { auto itr = _dimensionsEnum.find(dimension); if (itr != _dimensionsEnum.end()) { @@ -135,8 +131,7 @@ DenseTensorBuilder::addLabel(Dimension dimension, size_t label) Dimension mappedDimension = _dimensionsMapping[dimension]; const auto &dim = _dimensions[mappedDimension]; validateLabelInRange(label, dim.size, dim.name); - validateLabelNotSpecified(_addressBuilder[mappedDimension], - dim.name); + validateLabelNotSpecified(_addressBuilder[mappedDimension], dim.name); _addressBuilder[mappedDimension] = label; return *this; } @@ -154,14 +149,13 @@ DenseTensorBuilder::addCell(double value) return *this; } -Tensor::UP +std::unique_ptr<DenseTensor> DenseTensorBuilder::build() { if (_cells.empty()) { allocateCellsStorage(); } - Tensor::UP result = std::make_unique<DenseTensor>(makeValueType(std::move(_dimensions)), - std::move(_cells)); + auto result = std::make_unique<DenseTensor>(makeValueType(std::move(_dimensions)), std::move(_cells)); _dimensionsEnum.clear(); _dimensions.clear(); DenseTensor::Cells().swap(_cells); diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h index 3969a9335b8..05cd88b1319 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_builder.h @@ -15,12 +15,11 @@ class DenseTensorBuilder { public: using Dimension = TensorBuilder::Dimension; - private: vespalib::hash_map<vespalib::string, size_t> _dimensionsEnum; std::vector<eval::ValueType::Dimension> _dimensions; - DenseTensor::Cells _cells; - std::vector<size_t> _addressBuilder; + DenseTensor::Cells _cells; + std::vector<size_t> _addressBuilder; std::vector<Dimension> _dimensionsMapping; void allocateCellsStorage(); @@ -34,7 +33,7 @@ public: Dimension defineDimension(const vespalib::string &dimension, size_t dimensionSize); DenseTensorBuilder &addLabel(Dimension dimension, size_t label); DenseTensorBuilder &addCell(double value); - Tensor::UP build(); + std::unique_ptr<DenseTensor> build(); }; } diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h index 447d8a4f805..caf92d6c8c7 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_cells_iterator.h @@ -2,10 +2,7 @@ #pragma once -#include <vespa/eval/tensor/tensor.h> -#include <vespa/eval/tensor/types.h> #include <vespa/eval/eval/value_type.h> -#include <vespa/eval/tensor/tensor.h> #include <vespa/vespalib/util/arrayref.h> namespace vespalib::tensor { diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.hpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.hpp index 8480e7418e1..98db89dd2a7 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.hpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_reduce.hpp @@ -25,7 +25,7 @@ public: ~DimensionReducer(); template <typename Function> - DenseTensor::UP + std::unique_ptr<DenseTensorView> reduceCells(CellsRef cellsIn, Function &&func) { auto itr_in = cellsIn.cbegin(); auto itr_out = _cellsResult.begin(); @@ -54,7 +54,7 @@ public: namespace { template <typename Function> -DenseTensor::UP +std::unique_ptr<DenseTensorView> reduce(const DenseTensorView &tensor, const vespalib::string &dimensionToRemove, Function &&func) { DimensionReducer reducer(tensor.fast_type(), dimensionToRemove); @@ -70,9 +70,9 @@ reduce(const DenseTensorView &tensor, const std::vector<vespalib::string> &dimen if (dimensions.size() == 1) { return reduce(tensor, dimensions[0], func); } else if (dimensions.size() > 0) { - DenseTensor::UP result = reduce(tensor, dimensions[0], func); + std::unique_ptr<DenseTensorView> result = reduce(tensor, dimensions[0], func); for (size_t i = 1; i < dimensions.size(); ++i) { - DenseTensor::UP tmpResult = reduce(DenseTensorView(*result), dimensions[i], func); + std::unique_ptr<DenseTensorView> tmpResult = reduce(*result, dimensions[i], func); result = std::move(tmpResult); } return result; diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp index 164ec042384..73b2e7b3ffb 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp @@ -134,14 +134,6 @@ bool sameCells(DenseTensorView::CellsRef lhs, DenseTensorView::CellsRef rhs) } - -DenseTensorView::DenseTensorView(const DenseTensor &rhs) - : _typeRef(rhs.fast_type()), - _cellsRef(rhs.cellsRef()) -{ -} - - bool DenseTensorView::operator==(const DenseTensorView &rhs) const { diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h index 11ed9639cc6..09b6b72375e 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h @@ -2,15 +2,11 @@ #pragma once -#include <vespa/eval/tensor/tensor.h> -#include <vespa/eval/tensor/types.h> -#include <vespa/eval/eval/value_type.h> #include "dense_tensor_cells_iterator.h" +#include <vespa/eval/tensor/tensor.h> namespace vespalib::tensor { -class DenseTensor; - /** * A view to a dense tensor where all dimensions are indexed. * Tensor cells are stored in an underlying array according to the order of the dimensions. @@ -23,26 +19,15 @@ public: using CellsIterator = DenseTensorCellsIterator; using Address = std::vector<eval::ValueType::Dimension::size_type>; -private: - const eval::ValueType &_typeRef; - Tensor::UP reduce_all(join_fun_t op, const std::vector<vespalib::string> &dimensions) const; -protected: - CellsRef _cellsRef; - - void initCellsRef(CellsRef cells_in) { - _cellsRef = cells_in; - } - -public: - explicit DenseTensorView(const DenseTensor &rhs); DenseTensorView(const eval::ValueType &type_in, CellsRef cells_in) : _typeRef(type_in), _cellsRef(cells_in) {} - DenseTensorView(const eval::ValueType &type_in) - : _typeRef(type_in), - _cellsRef() + explicit DenseTensorView(const eval::ValueType &type_in) + : _typeRef(type_in), + _cellsRef() {} + const eval::ValueType &fast_type() const { return _typeRef; } const CellsRef &cellsRef() const { return _cellsRef; } bool operator==(const DenseTensorView &rhs) const; @@ -60,6 +45,15 @@ public: Tensor::UP clone() const override; eval::TensorSpec toSpec() const override; void accept(TensorVisitor &visitor) const override; +protected: + void initCellsRef(CellsRef cells_in) { + _cellsRef = cells_in; + } +private: + Tensor::UP reduce_all(join_fun_t op, const std::vector<vespalib::string> &dimensions) const; + + const eval::ValueType &_typeRef; + CellsRef _cellsRef; }; } diff --git a/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h index 2132f861896..260e71b6f76 100644 --- a/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h +++ b/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h @@ -44,7 +44,7 @@ public: MutableDenseTensorView(eval::ValueType type_in); MutableDenseTensorView(eval::ValueType type_in, CellsRef cells_in); void setCells(CellsRef cells_in) { - _cellsRef = cells_in; + initCellsRef(cells_in); } void setUnboundDimensions(const uint32_t *unboundDimSizeBegin, const uint32_t *unboundDimSizeEnd) { _concreteType.setUnboundDimensions(unboundDimSizeBegin, unboundDimSizeEnd); diff --git a/eval/src/vespa/eval/tensor/join_tensors.h b/eval/src/vespa/eval/tensor/join_tensors.h index 86e5913d8f5..271a6b0195d 100644 --- a/eval/src/vespa/eval/tensor/join_tensors.h +++ b/eval/src/vespa/eval/tensor/join_tensors.h @@ -5,8 +5,7 @@ #include "tensor.h" #include "direct_tensor_builder.h" -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { /* * Join the cells of two tensors. @@ -44,5 +43,4 @@ joinTensorsNegated(const TensorImplType &lhs, return builder.build(); } -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/serialization/common.h b/eval/src/vespa/eval/tensor/serialization/common.h new file mode 100644 index 00000000000..40b1840be6e --- /dev/null +++ b/eval/src/vespa/eval/tensor/serialization/common.h @@ -0,0 +1,9 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +namespace vespalib::tensor { + +enum class SerializeFormat {FLOAT, DOUBLE}; + +} diff --git a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp index feb811a92de..4b1ccc8db5d 100644 --- a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp +++ b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp @@ -3,48 +3,48 @@ #include "dense_binary_format.h" #include <vespa/eval/tensor/dense/dense_tensor.h> #include <vespa/vespalib/objects/nbostream.h> +#include <vespa/vespalib/util/exceptions.h> #include <cassert> using vespalib::nbostream; -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { + +using Dimension = eval::ValueType::Dimension; + namespace { eval::ValueType -makeValueType(std::vector<eval::ValueType::Dimension> &&dimensions) { +makeValueType(std::vector<Dimension> &&dimensions) { return (dimensions.empty() ? eval::ValueType::double_type() : eval::ValueType::tensor_type(std::move(dimensions))); } -} - -void -DenseBinaryFormat::serialize(nbostream &stream, const DenseTensorView &tensor) -{ - stream.putInt1_4Bytes(tensor.fast_type().dimensions().size()); +size_t +encodeDimensions(nbostream &stream, const eval::ValueType & type) { + stream.putInt1_4Bytes(type.dimensions().size()); size_t cellsSize = 1; - for (const auto &dimension : tensor.fast_type().dimensions()) { + for (const auto &dimension : type.dimensions()) { stream.writeSmallString(dimension.name); stream.putInt1_4Bytes(dimension.size); cellsSize *= dimension.size; } - DenseTensorView::CellsRef cells = tensor.cellsRef(); - assert(cells.size() == cellsSize); + return cellsSize; +} + +template<typename T> +void +encodeCells(nbostream &stream, DenseTensorView::CellsRef cells) { for (const auto &value : cells) { - stream << value; + stream << static_cast<T>(value); } } - -std::unique_ptr<DenseTensor> -DenseBinaryFormat::deserialize(nbostream &stream) -{ +size_t +decodeDimensions(nbostream & stream, std::vector<Dimension> & dimensions) { vespalib::string dimensionName; - std::vector<eval::ValueType::Dimension> dimensions; - DenseTensor::Cells cells; size_t dimensionsSize = stream.getInt1_4Bytes(); size_t dimensionSize; size_t cellsSize = 1; @@ -54,16 +54,76 @@ DenseBinaryFormat::deserialize(nbostream &stream) dimensions.emplace_back(dimensionName, dimensionSize); cellsSize *= dimensionSize; } - cells.reserve(cellsSize); - double cellValue = 0.0; + return cellsSize; +} + +template<typename T, typename V> +void +decodeCells(nbostream &stream, size_t cellsSize, V & cells) { + T cellValue = 0.0; for (size_t i = 0; i < cellsSize; ++i) { stream >> cellValue; cells.emplace_back(cellValue); } - return std::make_unique<DenseTensor>(makeValueType(std::move(dimensions)), - std::move(cells)); } +template <typename V> +void decodeCells(SerializeFormat format, nbostream &stream, size_t cellsSize, V & cells) +{ + switch (format) { + case SerializeFormat::DOUBLE: + decodeCells<double>(stream, cellsSize, cells); + break; + case SerializeFormat::FLOAT: + decodeCells<float>(stream, cellsSize, cells); + break; + } +} + +} -} // namespace vespalib::tensor -} // namespace vespalib +void +DenseBinaryFormat::serialize(nbostream &stream, const DenseTensorView &tensor) +{ + size_t cellsSize = encodeDimensions(stream, tensor.fast_type()); + + DenseTensorView::CellsRef cells = tensor.cellsRef(); + assert(cells.size() == cellsSize); + switch (_format) { + case SerializeFormat::DOUBLE: + encodeCells<double>(stream, cells); + break; + case SerializeFormat::FLOAT: + encodeCells<float>(stream, cells); + break; + } +} + +std::unique_ptr<DenseTensor> +DenseBinaryFormat::deserialize(nbostream &stream) +{ + std::vector<Dimension> dimensions; + size_t cellsSize = decodeDimensions(stream,dimensions); + DenseTensor::Cells cells; + cells.reserve(cellsSize); + + decodeCells(_format, stream, cellsSize, cells); + + return std::make_unique<DenseTensor>(makeValueType(std::move(dimensions)), std::move(cells)); +} + +template <typename T> +void +DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<T> & cells) +{ + std::vector<Dimension> dimensions; + size_t cellsSize = decodeDimensions(stream,dimensions); + cells.clear(); + cells.reserve(cellsSize); + decodeCells(_format, stream, cellsSize, cells); +} + +template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<double> & cells); +template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<float> & cells); + +} diff --git a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h index 8019648ffcb..f9847d37784 100644 --- a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h +++ b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h @@ -2,12 +2,13 @@ #pragma once +#include "common.h" #include <memory> -namespace vespalib { +#include <vector> -class nbostream; +namespace vespalib { class nbostream; } -namespace tensor { +namespace vespalib::tensor { class DenseTensor; class DenseTensorView; @@ -18,9 +19,15 @@ class DenseTensorView; class DenseBinaryFormat { public: - static void serialize(nbostream &stream, const DenseTensorView &tensor); - static std::unique_ptr<DenseTensor> deserialize(nbostream &stream); + DenseBinaryFormat(SerializeFormat format) : _format(format) { } + void serialize(nbostream &stream, const DenseTensorView &tensor); + std::unique_ptr<DenseTensor> deserialize(nbostream &stream); + + // This is a temporary method untill we get full support for typed tensors + template <typename T> + void deserializeCellsOnly(nbostream &stream, std::vector<T> & cells); +private: + SerializeFormat _format; }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/serialization/slime_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/slime_binary_format.cpp index 7ae3957dc0f..ece3c2e4a07 100644 --- a/eval/src/vespa/eval/tensor/serialization/slime_binary_format.cpp +++ b/eval/src/vespa/eval/tensor/serialization/slime_binary_format.cpp @@ -10,8 +10,7 @@ #include <vespa/vespalib/data/slime/slime.h> #include <vespa/vespalib/data/memory.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { using slime::Inserter; @@ -58,13 +57,10 @@ SlimeBinaryFormatSerializer::SlimeBinaryFormatSerializer(Inserter &inserter) } -SlimeBinaryFormatSerializer::~SlimeBinaryFormatSerializer() -{ -} +SlimeBinaryFormatSerializer::~SlimeBinaryFormatSerializer() = default; void -SlimeBinaryFormatSerializer::visit(const TensorAddress &address, - double value) +SlimeBinaryFormatSerializer::visit(const TensorAddress &address, double value) { Cursor &cellCursor = _cells.addObject(); writeTensorAddress(cellCursor, address); @@ -101,5 +97,4 @@ SlimeBinaryFormat::serialize(const Tensor &tensor) } -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/serialization/slime_binary_format.h b/eval/src/vespa/eval/tensor/serialization/slime_binary_format.h index f1366c64e2c..c9e9ff2c3e9 100644 --- a/eval/src/vespa/eval/tensor/serialization/slime_binary_format.h +++ b/eval/src/vespa/eval/tensor/serialization/slime_binary_format.h @@ -4,13 +4,11 @@ #include <memory> -namespace vespalib { +namespace vespalib { class Slime; } -class Slime; +namespace vespalib::slime { struct Inserter; } -namespace slime { struct Inserter; } - -namespace tensor { +namespace vespalib::tensor { class Tensor; class TensorBuilder; @@ -25,5 +23,4 @@ public: static std::unique_ptr<Slime> serialize(const Tensor &tensor); }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp index bd0c5b25f93..79d1aaa83a8 100644 --- a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp +++ b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp @@ -11,8 +11,7 @@ using vespalib::nbostream; -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { namespace { @@ -59,13 +58,10 @@ SparseBinaryFormatSerializer::SparseBinaryFormatSerializer() } -SparseBinaryFormatSerializer::~SparseBinaryFormatSerializer() -{ -} +SparseBinaryFormatSerializer::~SparseBinaryFormatSerializer() = default; void -SparseBinaryFormatSerializer::visit(const TensorAddress &address, - double value) +SparseBinaryFormatSerializer::visit(const TensorAddress &address, double value) { ++_numCells; writeTensorAddress(_cells, _type, address); @@ -74,8 +70,7 @@ SparseBinaryFormatSerializer::visit(const TensorAddress &address, void -SparseBinaryFormatSerializer::serialize(nbostream &stream, - const Tensor &tensor) +SparseBinaryFormatSerializer::serialize(nbostream &stream, const Tensor &tensor) { _type = tensor.type(); tensor.accept(*this); @@ -121,5 +116,4 @@ SparseBinaryFormat::deserialize(nbostream &stream, TensorBuilder &builder) } -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h index db05574dfce..89f6947ad43 100644 --- a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h +++ b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h @@ -2,11 +2,9 @@ #pragma once -namespace vespalib { +namespace vespalib { class nbostream; } -class nbostream; - -namespace tensor { +namespace vespalib::tensor { class Tensor; class TensorBuilder; @@ -21,5 +19,4 @@ public: static void deserialize(nbostream &stream, TensorBuilder &builder); }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp index fe35ce4c831..4ca037e82a4 100644 --- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp +++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp @@ -11,20 +11,64 @@ #include <vespa/eval/tensor/wrapped_simple_tensor.h> #include <vespa/log/log.h> +#include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/util/exceptions.h> + LOG_SETUP(".eval.tensor.serialization.typed_binary_format"); using vespalib::nbostream; -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { + +namespace { + +constexpr uint32_t SPARSE_BINARY_FORMAT_TYPE = 1u; +constexpr uint32_t DENSE_BINARY_FORMAT_TYPE = 2u; +constexpr uint32_t MIXED_BINARY_FORMAT_TYPE = 3u; +constexpr uint32_t SPARSE_BINARY_FORMAT_WITH_CELLTYPE = 5u; //Future +constexpr uint32_t DENSE_BINARY_FORMAT_WITH_CELLTYPE = 6u; +constexpr uint32_t MIXED_BINARY_FORMAT_WITH_CELLTYPE = 7u; //Future + +constexpr uint32_t DOUBLE_VALUE_TYPE = 0; +constexpr uint32_t FLOAT_VALUE_TYPE = 1; + +uint32_t +format2Encoding(SerializeFormat format) { + switch (format) { + case SerializeFormat::DOUBLE: + return DOUBLE_VALUE_TYPE; + case SerializeFormat::FLOAT: + return FLOAT_VALUE_TYPE; + } + abort(); +} + +SerializeFormat +encoding2Format(uint32_t serializedType) { + switch (serializedType) { + case DOUBLE_VALUE_TYPE: + return SerializeFormat::DOUBLE; + case FLOAT_VALUE_TYPE: + return SerializeFormat::FLOAT; + default: + throw IllegalArgumentException(make_string("Received unknown tensor value type = %u. Only 0(double), or 1(float) are legal.", serializedType)); + } +} +} void -TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor) +TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor, SerializeFormat format) { if (auto denseTensor = dynamic_cast<const DenseTensorView *>(&tensor)) { - stream.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE); - DenseBinaryFormat::serialize(stream, *denseTensor); + if (format != SerializeFormat::DOUBLE) { + stream.putInt1_4Bytes(DENSE_BINARY_FORMAT_WITH_CELLTYPE); + stream.putInt1_4Bytes(format2Encoding(format)); + DenseBinaryFormat(format).serialize(stream, *denseTensor); + } else { + stream.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE); + DenseBinaryFormat(SerializeFormat::DOUBLE).serialize(stream, *denseTensor); + } } else if (auto wrapped = dynamic_cast<const WrappedSimpleTensor *>(&tensor)) { eval::SimpleTensor::encode(wrapped->get(), stream); } else { @@ -45,15 +89,33 @@ TypedBinaryFormat::deserialize(nbostream &stream) return builder.build(); } if (formatId == DENSE_BINARY_FORMAT_TYPE) { - return DenseBinaryFormat::deserialize(stream); + return DenseBinaryFormat(SerializeFormat::DOUBLE).deserialize(stream); + } + if (formatId == DENSE_BINARY_FORMAT_WITH_CELLTYPE) { + return DenseBinaryFormat(encoding2Format(stream.getInt1_4Bytes())).deserialize(stream); } if (formatId == MIXED_BINARY_FORMAT_TYPE) { stream.adjustReadPos(read_pos - stream.rp()); return std::make_unique<WrappedSimpleTensor>(eval::SimpleTensor::decode(stream)); } - LOG_ABORT("should not be reached"); + abort(); } +template <typename T> +void +TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<T> & cells) +{ + auto formatId = stream.getInt1_4Bytes(); + if (formatId == DENSE_BINARY_FORMAT_TYPE) { + return DenseBinaryFormat(SerializeFormat::DOUBLE).deserializeCellsOnly(stream, cells); + } + if (formatId == DENSE_BINARY_FORMAT_WITH_CELLTYPE) { + return DenseBinaryFormat(encoding2Format(stream.getInt1_4Bytes())).deserializeCellsOnly(stream, cells); + } + abort(); +} -} // namespace vespalib::tensor -} // namespace vespalib +template void TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<double> & cells); +template void TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<float> & cells); + +} diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h index c655210907f..717d51effef 100644 --- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h +++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h @@ -2,30 +2,32 @@ #pragma once +#include "common.h" #include <memory> -#include <cstdint> +#include <vector> -namespace vespalib { +namespace vespalib { class nbostream; } -class nbostream; - -namespace tensor { +namespace vespalib::tensor { class Tensor; -class TensorBuilder; /** * Class for serializing a tensor. */ class TypedBinaryFormat { - static constexpr uint32_t SPARSE_BINARY_FORMAT_TYPE = 1u; - static constexpr uint32_t DENSE_BINARY_FORMAT_TYPE = 2u; - static constexpr uint32_t MIXED_BINARY_FORMAT_TYPE = 3u; public: - static void serialize(nbostream &stream, const Tensor &tensor); + static void serialize(nbostream &stream, const Tensor &tensor, SerializeFormat format); + static void serialize(nbostream &stream, const Tensor &tensor) { + serialize(stream, tensor, SerializeFormat::DOUBLE); + } + static std::unique_ptr<Tensor> deserialize(nbostream &stream); + + // This is a temporary method until we get full support for typed tensors + template <typename T> + static void deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<T> & cells); }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/tensor.cpp b/eval/src/vespa/eval/tensor/tensor.cpp index 8715a864f68..51c94aab5b0 100644 --- a/eval/src/vespa/eval/tensor/tensor.cpp +++ b/eval/src/vespa/eval/tensor/tensor.cpp @@ -4,8 +4,7 @@ #include "default_tensor_engine.h" #include <sstream> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { Tensor::Tensor() : eval::Tensor(DefaultTensorEngine::ref()) @@ -34,5 +33,4 @@ operator<<(std::ostream &out, const Tensor &value) return out; } -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/tensor.h b/eval/src/vespa/eval/tensor/tensor.h index 4061ed9c115..edf5fa710e3 100644 --- a/eval/src/vespa/eval/tensor/tensor.h +++ b/eval/src/vespa/eval/tensor/tensor.h @@ -9,9 +9,8 @@ #include <vespa/eval/eval/tensor_spec.h> #include <vespa/eval/eval/value_type.h> -namespace vespalib { -namespace eval { struct BinaryOperation; } -namespace tensor { +namespace vespalib::eval { struct BinaryOperation; } +namespace vespalib::tensor { class TensorVisitor; class CellValues; @@ -66,5 +65,4 @@ public: std::ostream &operator<<(std::ostream &out, const Tensor &value); -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/tensor_address.cpp b/eval/src/vespa/eval/tensor/tensor_address.cpp index afadcf2c668..a68fc5d3353 100644 --- a/eval/src/vespa/eval/tensor/tensor_address.cpp +++ b/eval/src/vespa/eval/tensor/tensor_address.cpp @@ -4,19 +4,18 @@ #include <algorithm> #include <ostream> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { const vespalib::string TensorAddress::Element::UNDEFINED_LABEL = "(undefined)"; -TensorAddress::Element::~Element() {} +TensorAddress::Element::~Element() = default; TensorAddress::TensorAddress() : _elements() { } -TensorAddress::~TensorAddress() {} +TensorAddress::~TensorAddress() = default; TensorAddress::TensorAddress(const Elements &elements_in) : _elements(elements_in) @@ -87,5 +86,4 @@ operator<<(std::ostream &out, const TensorAddress &value) return out; } -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/tensor_address_element_iterator.h b/eval/src/vespa/eval/tensor/tensor_address_element_iterator.h index e413712362f..01710105840 100644 --- a/eval/src/vespa/eval/tensor/tensor_address_element_iterator.h +++ b/eval/src/vespa/eval/tensor/tensor_address_element_iterator.h @@ -2,11 +2,10 @@ #pragma once -#include <vespa/vespalib/stllike/hash_set.h> +#include <vespa/vespalib/stllike/string.h> namespace vespalib::tensor { -using DimensionsSet = vespalib::hash_set<vespalib::stringref>; /** * An iterator for tensor address elements used to simplify 3-way merge diff --git a/eval/src/vespa/eval/tensor/tensor_builder.h b/eval/src/vespa/eval/tensor/tensor_builder.h index 05238b27df5..30eef5f9c54 100644 --- a/eval/src/vespa/eval/tensor/tensor_builder.h +++ b/eval/src/vespa/eval/tensor/tensor_builder.h @@ -4,8 +4,7 @@ #include <vespa/vespalib/stllike/string.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { class Tensor; @@ -30,5 +29,4 @@ public: virtual std::unique_ptr<Tensor> build() = 0; }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/tensor_factory.cpp b/eval/src/vespa/eval/tensor/tensor_factory.cpp index f88ae22c083..0b7fa3b9c2e 100644 --- a/eval/src/vespa/eval/tensor/tensor_factory.cpp +++ b/eval/src/vespa/eval/tensor/tensor_factory.cpp @@ -5,12 +5,10 @@ #include "tensor_builder.h" #include <vespa/eval/tensor/dense/dense_tensor_builder.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { std::unique_ptr<Tensor> -TensorFactory::create(const TensorCells &cells, - TensorBuilder &builder) { +TensorFactory::create(const TensorCells &cells, TensorBuilder &builder) { for (const auto &cell : cells) { for (const auto &addressElem : cell.first) { const auto &dimension = addressElem.first; @@ -30,9 +28,7 @@ TensorFactory::create(const TensorCells &cells, std::unique_ptr<Tensor> -TensorFactory::create(const TensorCells &cells, - const TensorDimensions &dimensions, - TensorBuilder &builder) { +TensorFactory::create(const TensorCells &cells, const TensorDimensions &dimensions, TensorBuilder &builder) { for (const auto &dimension : dimensions) { builder.define_dimension(dimension); } @@ -47,17 +43,12 @@ TensorFactory::createDense(const DenseTensorCells &cells) DenseTensorBuilder builder; for (const auto &cell : cells) { for (const auto &addressElem : cell.first) { - dimensionSizes[addressElem.first] = - std::max(dimensionSizes[addressElem.first], - (addressElem.second + 1)); + dimensionSizes[addressElem.first] = std::max(dimensionSizes[addressElem.first], (addressElem.second + 1)); } } - std::map<std::string, - typename DenseTensorBuilder::Dimension> dimensionEnums; + std::map<std::string, typename DenseTensorBuilder::Dimension> dimensionEnums; for (const auto &dimensionElem : dimensionSizes) { - dimensionEnums[dimensionElem.first] = - builder.defineDimension(dimensionElem.first, - dimensionElem.second); + dimensionEnums[dimensionElem.first] = builder.defineDimension(dimensionElem.first, dimensionElem.second); } for (const auto &cell : cells) { for (const auto &addressElem : cell.first) { @@ -71,5 +62,4 @@ TensorFactory::createDense(const DenseTensorCells &cells) } -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/tensor_factory.h b/eval/src/vespa/eval/tensor/tensor_factory.h index 5fe31afc4dd..5364c28c8ff 100644 --- a/eval/src/vespa/eval/tensor/tensor_factory.h +++ b/eval/src/vespa/eval/tensor/tensor_factory.h @@ -4,8 +4,7 @@ #include "types.h" -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { class Tensor; @@ -20,11 +19,9 @@ public: static std::unique_ptr<Tensor> create(const TensorCells &cells, TensorBuilder &builder); static std::unique_ptr<Tensor> - create(const TensorCells &cells, const TensorDimensions &dimensions, - TensorBuilder &builder); + create(const TensorCells &cells, const TensorDimensions &dimensions, TensorBuilder &builder); static std::unique_ptr<Tensor> createDense(const DenseTensorCells &cells); }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/tensor_mapper.cpp b/eval/src/vespa/eval/tensor/tensor_mapper.cpp index c91237e4994..dbf1965d441 100644 --- a/eval/src/vespa/eval/tensor/tensor_mapper.cpp +++ b/eval/src/vespa/eval/tensor/tensor_mapper.cpp @@ -53,9 +53,7 @@ SparseTensorMapper(const ValueType &type) } template <class TensorT> -SparseTensorMapper<TensorT>::~SparseTensorMapper() -{ -} +SparseTensorMapper<TensorT>::~SparseTensorMapper() = default; template <class TensorT> std::unique_ptr<Tensor> diff --git a/eval/src/vespa/eval/tensor/tensor_mapper.h b/eval/src/vespa/eval/tensor/tensor_mapper.h index 99994bd15e8..95c6cce8fc6 100644 --- a/eval/src/vespa/eval/tensor/tensor_mapper.h +++ b/eval/src/vespa/eval/tensor/tensor_mapper.h @@ -4,8 +4,7 @@ #include <vespa/eval/eval/value_type.h> -namespace vespalib { -namespace tensor { +namespace vespalib::tensor { class Tensor; @@ -42,5 +41,4 @@ public: }; -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index c4c3021d607..0612cee040c 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -14,6 +14,24 @@ import static com.yahoo.vespa.flags.FetchVector.Dimension.HOSTNAME; import static com.yahoo.vespa.flags.FetchVector.Dimension.NODE_TYPE; /** + * Definitions of feature flags. + * + * <p>To use feature flags, define the flag in this class as an "unbound" flag, e.g. {@link UnboundBooleanFlag} + * or {@link UnboundStringFlag}. At the location you want to get the value of the flag, you need the following:</p> + * + * <ol> + * <li>The unbound flag</li> + * <li>A {@link FlagSource}. The flag source is typically available as an injectible component. Binding + * an unbound flag to a flag source produces a (bound) flag, e.g. {@link BooleanFlag} and {@link StringFlag}.</li> + * <li>If you would like your flag value to be dependent on e.g. the application ID, then 1. you should + * declare this in the unbound flag definition in this file (referring to + * {@link FetchVector.Dimension#APPLICATION_ID}), and 2. specify the application ID when retrieving the value, e.g. + * {@link BooleanFlag#with(FetchVector.Dimension, String)}. See {@link FetchVector} for more info.</li> + * </ol> + * + * <p>Once the code is in place, you can override the flag value. This depends on the flag source, but typically + * there is a REST API for updating the flags in the config server, which is the root of all flag sources in the zone.</p> + * * @author hakonhall */ public class Flags { @@ -137,6 +155,12 @@ public class Flags { "Takes effect at redeployment", APPLICATION_ID); + public static final UnboundBooleanFlag REDIRECT_LEGACY_DNS_NAMES = defineFeatureFlag( + "redirect-legacy-dns", false, + "Redirect legacy DNS names to the main DNS name", + "Takes effect on deployment through controller", + APPLICATION_ID); + /** WARNING: public for testing: All flags should be defined in {@link Flags}. */ public static UnboundBooleanFlag defineFeatureFlag(String flagId, boolean defaultValue, String description, String modificationEffect, FetchVector.Dimension... dimensions) { diff --git a/jdisc_core/src/main/java/com/yahoo/jdisc/Request.java b/jdisc_core/src/main/java/com/yahoo/jdisc/Request.java index 061d803b978..466e74202c1 100644 --- a/jdisc_core/src/main/java/com/yahoo/jdisc/Request.java +++ b/jdisc_core/src/main/java/com/yahoo/jdisc/Request.java @@ -12,6 +12,7 @@ import com.yahoo.jdisc.service.CurrentContainer; import com.yahoo.jdisc.service.ServerProvider; import java.net.URI; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -35,7 +36,7 @@ import java.util.concurrent.TimeUnit; */ public class Request extends AbstractResource { - private final Map<String, Object> context = new HashMap<>(); + private final Map<String, Object> context = Collections.synchronizedMap(new HashMap<>()); private final HeaderFields headers = new HeaderFields(); private final Container container; private final Request parent; @@ -205,10 +206,6 @@ public class Request extends AbstractResource { * <p>Returns the named application context objects. This data is not intended for network transport, rather they * are intended for passing shared data between components of an Application.</p> * - * <p>Modifying the context map is a thread-unsafe operation -- any changes made after calling {@link - * #connect(ResponseHandler)} might never become visible to other threads, and might throw - * ConcurrentModificationExceptions in other threads.</p> - * * @return The context map. */ public Map<String, Object> context() { diff --git a/jdisc_core/src/main/java/com/yahoo/jdisc/application/UriPattern.java b/jdisc_core/src/main/java/com/yahoo/jdisc/application/UriPattern.java index 0e6e5d28260..350d8170987 100644 --- a/jdisc_core/src/main/java/com/yahoo/jdisc/application/UriPattern.java +++ b/jdisc_core/src/main/java/com/yahoo/jdisc/application/UriPattern.java @@ -65,7 +65,7 @@ public class UriPattern implements Comparable<UriPattern> { if (!matcher.find()) { throw new IllegalArgumentException(uri); } - scheme = GlobPattern.compile(resolvePatternComponent(matcher.group(1))); + scheme = GlobPattern.compile(normalizeScheme(resolvePatternComponent(matcher.group(1)))); host = GlobPattern.compile(resolvePatternComponent(matcher.group(2))); port = resolvePortPattern(matcher.group(4)); path = GlobPattern.compile(resolvePatternComponent(matcher.group(7))); @@ -91,7 +91,7 @@ public class UriPattern implements Comparable<UriPattern> { return null; } // Match scheme before host because it has a higher chance of differing (e.g. http versus https) - GlobPattern.Match schemeMatch = scheme.match(resolveUriComponent(uri.getScheme())); + GlobPattern.Match schemeMatch = scheme.match(normalizeScheme(resolveUriComponent(uri.getScheme()))); if (schemeMatch == null) { return null; } @@ -172,6 +172,11 @@ public class UriPattern implements Comparable<UriPattern> { } } + private static String normalizeScheme(String scheme) { + if (scheme.equals("https")) return "http"; // handle 'https' in bindings and uris as 'http' + return scheme; + } + /** * <p>This class holds the result of a {@link UriPattern#match(URI)} operation. It contains methods to inspect the * groups captured during matching, where a <em>group</em> is defined as a sequence of characters matches by a diff --git a/jdisc_core/src/test/java/com/yahoo/jdisc/application/UriPatternTestCase.java b/jdisc_core/src/test/java/com/yahoo/jdisc/application/UriPatternTestCase.java index c91a7134c3a..d2499bbf369 100644 --- a/jdisc_core/src/test/java/com/yahoo/jdisc/application/UriPatternTestCase.java +++ b/jdisc_core/src/test/java/com/yahoo/jdisc/application/UriPatternTestCase.java @@ -295,6 +295,15 @@ public class UriPatternTestCase { assertMatch(httpsPattern, "https://host/path", NO_GROUPS); } + @Test + public void requireThatHttpsSchemeIsHandledAsHttp() { + UriPattern httpPattern = new UriPattern("http://host:80/path"); + assertMatch(httpPattern, "https://host:80/path", NO_GROUPS); + + UriPattern httpsPattern = new UriPattern("https://host:443/path"); + assertMatch(httpsPattern, "http://host:443/path", NO_GROUPS); + } + private static void assertIllegalPattern(String uri) { try { new UriPattern(uri); diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java index 1e92fbef967..4239d2120cf 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollector.java @@ -16,10 +16,9 @@ import javax.servlet.AsyncListener; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; - import java.io.IOException; -import java.util.HashMap; -import java.util.Map; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.Future; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; @@ -40,19 +39,25 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G GET, PATCH, POST, PUT, DELETE, OPTIONS, HEAD, OTHER } + public enum HttpScheme { + HTTP, HTTPS, OTHER + } + private static final String[] HTTP_RESPONSE_GROUPS = { Metrics.RESPONSES_1XX, Metrics.RESPONSES_2XX, Metrics.RESPONSES_3XX, Metrics.RESPONSES_4XX, Metrics.RESPONSES_5XX, Metrics.RESPONSES_401, Metrics.RESPONSES_403}; private final AtomicLong inFlight = new AtomicLong(); - private final LongAdder statistics[][]; + private final LongAdder statistics[][][]; public HttpResponseStatisticsCollector() { super(); - statistics = new LongAdder[HttpMethod.values().length][]; - for (int method = 0; method < statistics.length; method++) { - statistics[method] = new LongAdder[HTTP_RESPONSE_GROUPS.length]; - for (int group = 0; group < HTTP_RESPONSE_GROUPS.length; group++) { - statistics[method][group] = new LongAdder(); + statistics = new LongAdder[HttpScheme.values().length][HttpMethod.values().length][]; + for (int scheme = 0; scheme < HttpScheme.values().length; ++scheme) { + for (int method = 0; method < HttpMethod.values().length; method++) { + statistics[scheme][method] = new LongAdder[HTTP_RESPONSE_GROUPS.length]; + for (int group = 0; group < HTTP_RESPONSE_GROUPS.length; group++) { + statistics[scheme][method][group] = new LongAdder(); + } } } } @@ -110,10 +115,11 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G private void observeEndOfRequest(Request request, HttpServletResponse flushableResponse) throws IOException { int group = groupIndex(request); if (group >= 0) { + HttpScheme scheme = getScheme(request); HttpMethod method = getMethod(request); - statistics[method.ordinal()][group].increment(); + statistics[scheme.ordinal()][method.ordinal()][group].increment(); if (group == 5 || group == 6) { // if 401/403, also increment 4xx - statistics[method.ordinal()][3].increment(); + statistics[scheme.ordinal()][method.ordinal()][3].increment(); } } @@ -146,6 +152,17 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G } } + private HttpScheme getScheme(Request request) { + switch (request.getScheme()) { + case "http": + return HttpScheme.HTTP; + case "https": + return HttpScheme.HTTPS; + default: + return HttpScheme.OTHER; + } + } + private HttpMethod getMethod(Request request) { switch (request.getMethod()) { case "GET": @@ -167,17 +184,18 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G } } - public Map<String, Map<String, Long>> takeStatisticsByMethod() { - Map<String, Map<String, Long>> ret = new HashMap<>(); - - for (HttpMethod method : HttpMethod.values()) { - int methodIndex = method.ordinal(); - Map<String, Long> methodStats = new HashMap<>(); - ret.put(method.toString(), methodStats); - - for (int group = 0; group < HTTP_RESPONSE_GROUPS.length; group++) { - long value = statistics[methodIndex][group].sumThenReset(); - methodStats.put(HTTP_RESPONSE_GROUPS[group], value); + public List<StatisticsEntry> takeStatistics() { + var ret = new ArrayList<StatisticsEntry>(); + for (HttpScheme scheme : HttpScheme.values()) { + int schemeIndex = scheme.ordinal(); + for (HttpMethod method : HttpMethod.values()) { + int methodIndex = method.ordinal(); + for (int group = 0; group < HTTP_RESPONSE_GROUPS.length; group++) { + long value = statistics[schemeIndex][methodIndex][group].sumThenReset(); + if (value > 0) { + ret.add(new StatisticsEntry(scheme.name().toLowerCase(), method.name(), HTTP_RESPONSE_GROUPS[group], value)); + } + } } } return ret; @@ -216,4 +234,19 @@ public class HttpResponseStatisticsCollector extends HandlerWrapper implements G FutureCallback futureCallback = shutdown.get(); return futureCallback != null && futureCallback.isDone(); } + + public static class StatisticsEntry { + public final String scheme; + public final String method; + public final String name; + public final long value; + + + public StatisticsEntry(String scheme, String method, String name, long value) { + this.scheme = scheme; + this.method = method; + this.name = name; + this.value = value; + } + } } diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscServerConnector.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscServerConnector.java index 6b371473a57..556d80d3772 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscServerConnector.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscServerConnector.java @@ -17,6 +17,7 @@ import java.net.SocketException; import java.nio.channels.ServerSocketChannel; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.logging.Level; import java.util.logging.Logger; @@ -28,7 +29,7 @@ class JDiscServerConnector extends ServerConnector { public static final String REQUEST_ATTRIBUTE = JDiscServerConnector.class.getName(); private final static Logger log = Logger.getLogger(JDiscServerConnector.class.getName()); private final Metric.Context metricCtx; - private final Map<String, Metric.Context> requestMetricContextCache = new ConcurrentHashMap<>(); + private final Map<RequestDimensions, Metric.Context> requestMetricContextCache = new ConcurrentHashMap<>(); private final ServerConnectionStatistics statistics; private final boolean tcpKeepAlive; private final boolean tcpNoDelay; @@ -124,9 +125,12 @@ class JDiscServerConnector extends ServerConnector { public Metric.Context getRequestMetricContext(HttpServletRequest request) { String method = request.getMethod(); - return requestMetricContextCache.computeIfAbsent(method, ignored -> { + String scheme = request.getScheme(); + var requestDimensions = new RequestDimensions(method, scheme); + return requestMetricContextCache.computeIfAbsent(requestDimensions, ignored -> { Map<String, Object> dimensions = createConnectorDimensions(listenPort, connectorName); dimensions.put(JettyHttpServer.Metrics.METHOD_DIMENSION, method); + dimensions.put(JettyHttpServer.Metrics.SCHEME_DIMENSION, scheme); return metric.createContext(dimensions); }); } @@ -142,4 +146,27 @@ class JDiscServerConnector extends ServerConnector { return props; } + private static class RequestDimensions { + final String method; + final String scheme; + + RequestDimensions(String method, String scheme) { + this.method = method; + this.scheme = scheme; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RequestDimensions that = (RequestDimensions) o; + return Objects.equals(method, that.method) && Objects.equals(scheme, that.scheme); + } + + @Override + public int hashCode() { + return Objects.hash(method, scheme); + } + } + } diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java index 0dbc5f59f67..30a1b1d885c 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java @@ -8,7 +8,6 @@ import com.yahoo.component.ComponentId; import com.yahoo.component.provider.ComponentRegistry; import com.yahoo.container.logging.AccessLog; import com.yahoo.jdisc.Metric; -import com.yahoo.jdisc.Metric.Context; import com.yahoo.jdisc.application.OsgiFramework; import com.yahoo.jdisc.http.ServerConfig; import com.yahoo.jdisc.http.ServletPathsConfig; @@ -71,6 +70,7 @@ public class JettyHttpServer extends AbstractServerProvider { String NAME_DIMENSION = "serverName"; String PORT_DIMENSION = "serverPort"; String METHOD_DIMENSION = "httpMethod"; + String SCHEME_DIMENSION = "scheme"; String NUM_OPEN_CONNECTIONS = "serverNumOpenConnections"; String NUM_CONNECTIONS_OPEN_MAX = "serverConnectionsOpenMax"; @@ -357,13 +357,12 @@ public class JettyHttpServer extends AbstractServerProvider { } private void addResponseMetrics(HttpResponseStatisticsCollector statisticsCollector) { - Map<String, Map<String, Long>> statistics = statisticsCollector.takeStatisticsByMethod(); - statistics.forEach((httpMethod, statsByResponseType) -> { + for (var metricEntry : statisticsCollector.takeStatistics()) { Map<String, Object> dimensions = new HashMap<>(); - dimensions.put(Metrics.METHOD_DIMENSION, httpMethod); - Context ctx = metric.createContext(dimensions); - statsByResponseType.forEach((group, value) -> metric.add(group, value, ctx)); - }); + dimensions.put(Metrics.METHOD_DIMENSION, metricEntry.method); + dimensions.put(Metrics.SCHEME_DIMENSION, metricEntry.scheme); + metric.add(metricEntry.name, metricEntry.value, metric.createContext(dimensions)); + } } private void setConnectorMetrics(JDiscServerConnector connector) { diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java index 3c23a2b0937..df2308f6dd0 100644 --- a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java +++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpResponseStatisticsCollectorTest.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.jdisc.http.server.jetty; +import com.yahoo.jdisc.http.server.jetty.HttpResponseStatisticsCollector.StatisticsEntry; import com.yahoo.jdisc.http.server.jetty.JettyHttpServer.Metrics; import org.eclipse.jetty.http.HttpFields; import org.eclipse.jetty.http.HttpURI; @@ -22,10 +23,9 @@ import org.testng.annotations.Test; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; - import java.io.IOException; import java.nio.ByteBuffer; -import java.util.Map; +import java.util.List; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; @@ -40,55 +40,62 @@ public class HttpResponseStatisticsCollectorTest { @Test public void statistics_are_aggregated_by_category() throws Exception { - testRequest(300, "GET"); - testRequest(301, "GET"); - testRequest(200, "GET"); + testRequest("http", 300, "GET"); + testRequest("http", 301, "GET"); + testRequest("http", 200, "GET"); - Map<String, Map<String, Long>> stats = collector.takeStatisticsByMethod(); - assertThat(stats.get("GET").get(Metrics.RESPONSES_2XX), equalTo(1L)); - assertThat(stats.get("GET").get(Metrics.RESPONSES_3XX), equalTo(2L)); + var stats = collector.takeStatistics(); + assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_2XX, 1L); + assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_3XX, 2L); } @Test - public void statistics_are_grouped_by_http_method() throws Exception { - testRequest(200, "GET"); - testRequest(200, "PUT"); - testRequest(200, "POST"); - testRequest(200, "POST"); - testRequest(404, "GET"); - - Map<String, Map<String, Long>> stats = collector.takeStatisticsByMethod(); - assertThat(stats.get("GET").get(Metrics.RESPONSES_2XX), equalTo(1L)); - assertThat(stats.get("GET").get(Metrics.RESPONSES_4XX), equalTo(1L)); - assertThat(stats.get("PUT").get(Metrics.RESPONSES_2XX), equalTo(1L)); - assertThat(stats.get("POST").get(Metrics.RESPONSES_2XX), equalTo(2L)); + public void statistics_are_grouped_by_http_method_and_scheme() throws Exception { + testRequest("http", 200, "GET"); + testRequest("http", 200, "PUT"); + testRequest("http", 200, "POST"); + testRequest("http", 200, "POST"); + testRequest("http", 404, "GET"); + testRequest("https", 404, "GET"); + testRequest("https", 200, "POST"); + testRequest("https", 200, "POST"); + testRequest("https", 200, "POST"); + testRequest("https", 200, "POST"); + + var stats = collector.takeStatistics(); + assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_2XX, 1L); + assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_4XX, 1L); + assertStatisticsEntryPresent(stats, "http", "PUT", Metrics.RESPONSES_2XX, 1L); + assertStatisticsEntryPresent(stats, "http", "POST", Metrics.RESPONSES_2XX, 2L); + assertStatisticsEntryPresent(stats, "https", "GET", Metrics.RESPONSES_4XX, 1L); + assertStatisticsEntryPresent(stats, "https", "POST", Metrics.RESPONSES_2XX, 4L); } @Test public void statistics_include_grouped_and_single_statuscodes() throws Exception { - testRequest(401, "GET"); - testRequest(404, "GET"); - testRequest(403, "GET"); + testRequest("http", 401, "GET"); + testRequest("http", 404, "GET"); + testRequest("http", 403, "GET"); - Map<String, Map<String, Long>> stats = collector.takeStatisticsByMethod(); - assertThat(stats.get("GET").get(Metrics.RESPONSES_4XX), equalTo(3L)); - assertThat(stats.get("GET").get(Metrics.RESPONSES_401), equalTo(1L)); - assertThat(stats.get("GET").get(Metrics.RESPONSES_403), equalTo(1L)); + var stats = collector.takeStatistics(); + assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_4XX, 3L); + assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_401, 1L); + assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_403, 1L); } @Test public void retrieving_statistics_resets_the_counters() throws Exception { - testRequest(200, "GET"); - testRequest(200, "GET"); + testRequest("http", 200, "GET"); + testRequest("http", 200, "GET"); - Map<String, Map<String, Long>> stats = collector.takeStatisticsByMethod(); - assertThat(stats.get("GET").get(Metrics.RESPONSES_2XX), equalTo(2L)); + var stats = collector.takeStatistics(); + assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_2XX, 2L); - testRequest(200, "GET"); + testRequest("http", 200, "GET"); - stats = collector.takeStatisticsByMethod(); - assertThat(stats.get("GET").get(Metrics.RESPONSES_2XX), equalTo(1L)); + stats = collector.takeStatistics(); + assertStatisticsEntryPresent(stats, "http", "GET", Metrics.RESPONSES_2XX, 1L); } @BeforeTest @@ -116,9 +123,9 @@ public class HttpResponseStatisticsCollectorTest { server.start(); } - private Request testRequest(int responseCode, String httpMethod) throws Exception { + private Request testRequest(String scheme, int responseCode, String httpMethod) throws Exception { HttpChannel channel = new HttpChannel(connector, new HttpConfiguration(), null, new DummyTransport()); - MetaData.Request metaData = new MetaData.Request(httpMethod, new HttpURI("http://foo/bar"), HttpVersion.HTTP_1_1, new HttpFields()); + MetaData.Request metaData = new MetaData.Request(httpMethod, new HttpURI(scheme + "://foo/bar"), HttpVersion.HTTP_1_1, new HttpFields()); Request req = channel.getRequest(); req.setMetaData(metaData); @@ -127,6 +134,15 @@ public class HttpResponseStatisticsCollectorTest { return req; } + private static void assertStatisticsEntryPresent(List<StatisticsEntry> result, String scheme, String method, String name, long expectedValue) { + long value = result.stream() + .filter(entry -> entry.method.equals(method) && entry.scheme.equals(scheme) && entry.name.equals(name)) + .mapToLong(entry -> entry.value) + .findAny() + .orElseThrow(() -> new AssertionError(String.format("Not matching entry in result (scheme=%s, method=%s, name=%s)", scheme, method, name))); + assertThat(value, equalTo(expectedValue)); + } + private final class DummyTransport implements HttpTransport { @Override public void send(Response info, boolean head, ByteBuffer content, boolean lastContent, Callback callback) { diff --git a/jrt/pom.xml b/jrt/pom.xml index 5208c0417cc..e9383654e30 100644 --- a/jrt/pom.xml +++ b/jrt/pom.xml @@ -34,6 +34,16 @@ <artifactId>security-utils</artifactId> <version>${project.version}</version> <scope>compile</scope> + <exclusions> + <exclusion> <!-- not needed --> + <groupId>org.apache.httpcomponents</groupId> + <artifactId>httpclient</artifactId> + </exclusion> + <exclusion> <!-- not needed --> + <groupId>org.apache.httpcomponents</groupId> + <artifactId>httpcore</artifactId> + </exclusion> + </exclusions> </dependency> <dependency> <!-- required due to bug in maven dependency resolving - bouncycastle is compile scope in security-utils, yet it is not part of test scope here --> <groupId>org.bouncycastle</groupId> diff --git a/logd/src/logd/legacy_forwarder.cpp b/logd/src/logd/legacy_forwarder.cpp index c0f74d205e7..851e4458f77 100644 --- a/logd/src/logd/legacy_forwarder.cpp +++ b/logd/src/logd/legacy_forwarder.cpp @@ -11,6 +11,7 @@ #include <vespa/vespalib/util/stringfmt.h> #include <fcntl.h> #include <unistd.h> +#include <sstream> #include <vespa/log/log.h> LOG_SETUP(""); @@ -126,12 +127,12 @@ void LegacyForwarder::forwardLine(std::string_view line) { assert(_logserver_fd >= 0); - assert (line.size() > 0); assert (line.size() < 1024*1024); - assert (line[line.size() - 1] == '\n'); if (parseLine(line)) { - forwardText(line.data(), line.size()); + std::ostringstream line_copy; + line_copy << line << std::endl; + forwardText(line_copy.str().data(), line_copy.str().size()); } } diff --git a/logd/src/logd/watcher.cpp b/logd/src/logd/watcher.cpp index a92ad456e9f..fca9cd648bb 100644 --- a/logd/src/logd/watcher.cpp +++ b/logd/src/logd/watcher.cpp @@ -222,7 +222,7 @@ Watcher::watchfile() } while (nnl != nullptr && elapsed(tickStart) < 1) { ++nnl; - _forwarder.forwardLine(std::string_view(l, (nnl - l))); + _forwarder.forwardLine(std::string_view(l, (nnl - l) - 1)); ssize_t wsize = nnl - l; offset += wsize; l = nnl; diff --git a/logd/src/tests/legacy_forwarder/legacy_forwarder_test.cpp b/logd/src/tests/legacy_forwarder/legacy_forwarder_test.cpp index d3339894819..67d47a49384 100644 --- a/logd/src/tests/legacy_forwarder/legacy_forwarder_test.cpp +++ b/logd/src/tests/legacy_forwarder/legacy_forwarder_test.cpp @@ -40,7 +40,7 @@ struct ForwardFixture { timer.SetNow(); std::stringstream ss; ss << std::fixed << timer.Secs(); - ss << "\texample.yahoo.com\t7518/34779\tlogd\tlogdemon\tevent\tstarted/1 name=\"logdemon\"\n"; + ss << "\texample.yahoo.com\t7518/34779\tlogd\tlogdemon\tevent\tstarted/1 name=\"logdemon\""; return ss.str(); } @@ -50,7 +50,7 @@ struct ForwardFixture { int rfd = open(fname.c_str(), O_RDONLY); char *buffer[2048]; ssize_t bytes = read(rfd, buffer, 2048); - ssize_t expected = doForward ? logLine.length() : 0; + ssize_t expected = doForward ? logLine.length() + 1 : 0; EXPECT_EQUAL(expected, bytes); close(rfd); } diff --git a/logd/src/tests/watcher/watcher_test.cpp b/logd/src/tests/watcher/watcher_test.cpp index c2b379cc1a4..fffaac17058 100644 --- a/logd/src/tests/watcher/watcher_test.cpp +++ b/logd/src/tests/watcher/watcher_test.cpp @@ -71,8 +71,7 @@ struct DummyForwarder : public Forwarder { void sendMode() override { ++sendModeCount; } void forwardLine(std::string_view log_line) override { std::lock_guard guard(lock); - assert(log_line.size() > 0u); - lines.emplace_back(log_line.substr(0, log_line.size() - 1)); + lines.emplace_back(log_line); cond.notify_all(); } void flush() override { } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java index c4acfeb3235..9c8f6238731 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java @@ -29,9 +29,17 @@ public class OrderedTensorType { private final long[] innerSizesVespa; private final int[] dimensionMap; - private OrderedTensorType(List<TensorType.Dimension> dimensions) { + private OrderedTensorType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) { this.dimensions = Collections.unmodifiableList(dimensions); - this.type = new TensorType.Builder(dimensions).build(); + this.type = new TensorType.Builder(valueType, dimensions).build(); + this.innerSizesOriginal = new long[dimensions.size()]; + this.innerSizesVespa = new long[dimensions.size()]; + this.dimensionMap = createDimensionMap(); + } + + private OrderedTensorType(TensorType type) { + this.dimensions = type.dimensions(); + this.type = type; this.innerSizesOriginal = new long[dimensions.size()]; this.innerSizesVespa = new long[dimensions.size()]; this.dimensionMap = createDimensionMap(); @@ -136,11 +144,11 @@ public class OrderedTensorType { renamedDimensions.add(TensorType.Dimension.mapped(newName.get())); } } - return new OrderedTensorType(renamedDimensions); + return new OrderedTensorType(type.valueType(), renamedDimensions); } public OrderedTensorType rename(String dimensionPrefix) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(type.valueType()); for (int i = 0; i < dimensions.size(); ++ i) { String dimensionName = dimensionPrefix + i; Optional<Long> dimSize = dimensions.get(i).size(); @@ -154,7 +162,7 @@ public class OrderedTensorType { } public static OrderedTensorType standardType(OrderedTensorType type) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(type.type().valueType()); for (int i = 0; i < type.dimensions().size(); ++ i) { TensorType.Dimension dim = type.dimensions().get(i); String dimensionName = "d" + i; @@ -193,18 +201,18 @@ public class OrderedTensorType { * where dimensions are listed in the order of this rather than the natural order of their names. */ public static OrderedTensorType fromSpec(String typeSpec) { - return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec)); + return new OrderedTensorType(TensorType.fromSpec(typeSpec)); } - public static OrderedTensorType fromDimensionList(List<Long> dims) { - return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ... + public static OrderedTensorType fromDimensionList(TensorType.Value valueType, List<Long> dimensions) { + return fromDimensionList(valueType, dimensions, "d"); // standard naming convention: d0, d1, ... } - private static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); - for (int i = 0; i < dims.size(); ++ i) { + private static OrderedTensorType fromDimensionList(TensorType.Value valueType, List<Long> dimensions, String dimensionPrefix) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(valueType); + for (int i = 0; i < dimensions.size(); ++ i) { String dimensionName = dimensionPrefix + i; - Long dimSize = dims.get(i); + Long dimSize = dimensions.get(i); if (dimSize >= 0) { builder.add(TensorType.Dimension.indexed(dimensionName, dimSize)); } else { @@ -216,9 +224,15 @@ public class OrderedTensorType { public static class Builder { + private final TensorType.Value valueType; private final List<TensorType.Dimension> dimensions; public Builder() { + this(TensorType.Value.DOUBLE); + } + + public Builder(TensorType.Value valueType) { + this.valueType = valueType; this.dimensions = new ArrayList<>(); } @@ -228,7 +242,7 @@ public class OrderedTensorType { } public OrderedTensorType build() { - return new OrderedTensorType(dimensions); + return new OrderedTensorType(valueType, dimensions); } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index dd2add973e4..a469e666d93 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -105,8 +105,8 @@ class GraphImporter { if (isArgumentTensor(name, onnxGraph)) { Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph); if (valueInfoProto == null) - throw new IllegalArgumentException("Could not find argument tensor: " + name); - OrderedTensorType type = TypeConverter.fromOnnxType(valueInfoProto.getType()); + throw new IllegalArgumentException("Could not find argument tensor '" + name + "'"); + OrderedTensorType type = TypeConverter.typeFrom(valueInfoProto.getType()); operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type); intermediateGraph.inputs(intermediateGraph.defaultSignature()) @@ -114,7 +114,7 @@ class GraphImporter { } else if (isConstantTensor(name, onnxGraph)) { Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph); - OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList()); + OrderedTensorType defaultType = TypeConverter.typeFrom(tensorProto); operation = new Constant(intermediateGraph.name(), name, defaultType); operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java index f251a14213b..29d600fa7c6 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java @@ -30,13 +30,10 @@ class TypeConverter { } } - static OrderedTensorType fromOnnxType(Onnx.TypeProto type) { - return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ... - } - - private static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) { + static OrderedTensorType typeFrom(Onnx.TypeProto type) { + String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... Onnx.TensorShapeProto shape = type.getTensorType().getShape(); - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(type.getTensorType().getElemType())); for (int i = 0; i < shape.getDimCount(); ++ i) { String dimensionName = dimensionPrefix + i; Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); @@ -49,4 +46,28 @@ class TypeConverter { return builder.build(); } + static OrderedTensorType typeFrom(Onnx.TensorProto tensor) { + return OrderedTensorType.fromDimensionList(toValueType(tensor.getDataType()), + tensor.getDimsList()); + } + + private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) { + switch (dataType) { + case FLOAT: return TensorType.Value.FLOAT; + case DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case BOOL: return TensorType.Value.FLOAT; + case INT8: return TensorType.Value.FLOAT; + case INT16: return TensorType.Value.FLOAT; + case INT32: return TensorType.Value.FLOAT; + case INT64: return TensorType.Value.DOUBLE; + case UINT8: return TensorType.Value.FLOAT; + case UINT16: return TensorType.Value.FLOAT; + case UINT32: return TensorType.Value.FLOAT; + case UINT64: return TensorType.Value.DOUBLE; + default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); + } + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java index 1a564661ccb..7ae50a0549d 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java @@ -21,20 +21,15 @@ public class ConcatV2 extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) { - return null; - } + if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) return null; IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input - if (!concatDimOp.getConstantValue().isPresent()) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + - "concat dimension must be a constant."); - } + if ( ! concatDimOp.getConstantValue().isPresent()) + throw new IllegalArgumentException("ConcatV2 in " + name + ": Concat dimension must be a constant."); + Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor(); - if (concatDimTensor.type().rank() != 0) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + - "concat dimension must be a scalar."); - } + if (concatDimTensor.type().rank() != 0) + throw new IllegalArgumentException("ConcatV2 in " + name + ": Concat dimension must be a scalar."); OrderedTensorType aType = inputs.get(0).type().get(); concatDimensionIndex = (int)concatDimTensor.asDouble(); @@ -42,10 +37,9 @@ public class ConcatV2 extends IntermediateOperation { for (int i = 1; i < inputs.size() - 1; ++i) { OrderedTensorType bType = inputs.get(i).type().get(); - if (bType.rank() != aType.rank()) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + - "inputs must have save rank."); - } + if (bType.rank() != aType.rank()) + throw new IllegalArgumentException("ConcatV2 in " + name + ": Inputs must have the same rank."); + for (int j = 0; j < aType.rank(); ++j) { long dimSizeA = aType.dimensions().get(j).size().orElse(-1L); long dimSizeB = bType.dimensions().get(j).size().orElse(-1L); @@ -58,7 +52,7 @@ public class ConcatV2 extends IntermediateOperation { } } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); int dimensionIndex = 0; for (TensorType.Dimension dimension : aType.dimensions()) { if (dimensionIndex == concatDimensionIndex) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index 8ae6d81b8d4..c64b9ded601 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -27,20 +27,15 @@ public class ExpandDims extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; IntermediateOperation axisOperation = inputs().get(1); if (!axisOperation.getConstantValue().isPresent()) { - throw new IllegalArgumentException("ExpandDims in " + name + ": " + - "axis must be a constant."); + throw new IllegalArgumentException("ExpandDims in " + name + ": Axis must be a constant."); } Tensor axis = axisOperation.getConstantValue().get().asTensor(); - if (axis.type().rank() != 0) { - throw new IllegalArgumentException("ExpandDims in " + name + ": " + - "axis argument must be a scalar."); - } + if (axis.type().rank() != 0) + throw new IllegalArgumentException("ExpandDims in " + name + ": Axis argument must be a scalar."); OrderedTensorType inputType = inputs.get(0).type().get(); int dimensionToInsert = (int)axis.asDouble(); @@ -48,7 +43,7 @@ public class ExpandDims extends IntermediateOperation { dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); expandDimensions = new ArrayList<>(); int dimensionIndex = 0; for (TensorType.Dimension dimension : inputType.dimensions()) { @@ -66,12 +61,10 @@ public class ExpandDims extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputFunctionsPresent(2)) { - return null; - } + if ( ! allInputFunctionsPresent(2)) return null; // multiply with a generated tensor created from the reduced dimensions - TensorType.Builder typeBuilder = new TensorType.Builder(); + TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); for (String name : expandDimensions) { typeBuilder.indexed(name, 1); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 3b77f9527ca..0ee54f839bc 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -9,6 +9,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.TensorFunction; @@ -17,6 +18,7 @@ import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.function.Function; +import java.util.stream.Collectors; /** * Wraps an imported operation node and produces the respective Vespa tensor @@ -161,6 +163,19 @@ public abstract class IntermediateOperation { } /** + * Returns the largest value type among the input value types. + * This should only be called after it has been verified that input types are available. + * + * @throws IllegalArgumentException if a type cannot be uniquely determined + * @throws RuntimeException if called when input types are not available + */ + TensorType.Value resultValueType() { + return TensorType.Value.largestOf(inputs.stream() + .map(input -> input.type().get().type().valueType()) + .collect(Collectors.toList())); + } + + /** * A method signature input and output has the form name:index. * This returns the name part without the index. */ diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index fed95e13bb7..c2d75153586 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -22,13 +22,12 @@ public class Join extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + OrderedTensorType a = largestInput().type().get(); OrderedTensorType b = smallestInput().type().get(); - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType()); int sizeDifference = a.rank() - b.rank(); for (int i = 0; i < a.rank(); ++i) { TensorType.Dimension aDim = a.dimensions().get(i); @@ -52,12 +51,8 @@ public class Join extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } - if (!allInputFunctionsPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + if ( ! allInputFunctionsPresent(2)) return null; IntermediateOperation a = largestInput(); IntermediateOperation b = smallestInput(); @@ -92,9 +87,8 @@ public class Join extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!allInputTypesPresent(2)) { - return; - } + if ( ! allInputTypesPresent(2)) return; + OrderedTensorType a = largestInput().type().get(); OrderedTensorType b = smallestInput().type().get(); int sizeDifference = a.rank() - b.rank(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 1dbfd6e40dc..9a76662529d 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -17,10 +17,9 @@ public class MatMul extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + if ( ! allInputTypesPresent(2)) return null; + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); return typeBuilder.build(); @@ -28,9 +27,8 @@ public class MatMul extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + OrderedTensorType aType = inputs.get(0).type().get(); OrderedTensorType bType = inputs.get(1).type().get(); if (aType.type().rank() < 2 || bType.type().rank() < 2) @@ -48,9 +46,8 @@ public class MatMul extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!allInputTypesPresent(2)) { - return; - } + if ( ! allInputTypesPresent(2)) return; + List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); @@ -69,4 +66,5 @@ public class MatMul extends IntermediateOperation { renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this); renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java index 4be220db9d5..d8e9950c61f 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java @@ -32,13 +32,11 @@ public class Mean extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + IntermediateOperation reductionIndices = inputs.get(1); - if (!reductionIndices.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Mean in " + name + ": " + - "reduction indices must be a constant."); + if ( ! reductionIndices.getConstantValue().isPresent()) { + throw new IllegalArgumentException("Mean in " + name + ": Reduction indices must be a constant."); } Tensor indices = reductionIndices.getConstantValue().get().asTensor(); reduceDimensions = new ArrayList<>(); @@ -59,14 +57,14 @@ public class Mean extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + + TensorFunction inputFunction = inputs.get(0).function().get(); TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions); if (shouldKeepDimensions()) { // multiply with a generated tensor created from the reduced dimensions - TensorType.Builder typeBuilder = new TensorType.Builder(); + TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); for (String name : reduceDimensions) { typeBuilder.indexed(name, 1); } @@ -99,9 +97,9 @@ public class Mean extends IntermediateOperation { } private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType()); for (TensorType.Dimension dimension: inputType.type().dimensions()) { - if (!reduceDimensions.contains(dimension.name())) { + if ( ! reduceDimensions.contains(dimension.name())) { builder.add(dimension); } else if (keepDimensions) { builder.add(TensorType.Dimension.indexed(dimension.name(), 1L)); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index 18f3cc1cc39..4a0fe236c9f 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -32,18 +32,16 @@ public class Reshape extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + IntermediateOperation newShape = inputs.get(1); - if (!newShape.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Reshape in " + name + ": " + - "shape input must be a constant."); - } + if ( ! newShape.getConstantValue().isPresent()) + throw new IllegalArgumentException("Reshape in " + name + ": Shape input must be a constant."); + Tensor shape = newShape.getConstantValue().get().asTensor(); OrderedTensorType inputType = inputs.get(0).type().get(); - OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType()); int dimensionIndex = 0; for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { Tensor.Cell cell = cellIterator.next(); @@ -61,12 +59,9 @@ public class Reshape extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } - if (!allInputFunctionsPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + if ( ! allInputFunctionsPresent(2)) return null; + OrderedTensorType inputType = inputs.get(0).type().get(); TensorFunction inputFunction = inputs.get(0).function().get(); return reshape(inputFunction, inputType.type(), type.type()); @@ -80,9 +75,8 @@ public class Reshape extends IntermediateOperation { } public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { - if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) { + if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); - } // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, // then use the dimension order of the new shape to roll back into a tensor. @@ -96,20 +90,17 @@ public class Reshape extends IntermediateOperation { TensorType transformationType = new TensorType.Builder(inputType, outputType).build(); Generate transformTensor = new Generate(transformationType, - new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); - - TensorFunction outputFunction = new Reduce( - new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()), - Reduce.Aggregator.sum, - inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList())); + new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); - return outputFunction; + return new Reduce(new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList())); } private static ExpressionNode unrollTensorExpression(TensorType type) { - if (type.rank() == 0) { + if (type.rank() == 0) return new ConstantNode(DoubleValue.zero); - } + List<ExpressionNode> children = new ArrayList<>(); List<ArithmeticOperator> operators = new ArrayList<>(); int size = 1; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java index 361729a8c14..79f3012c327 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java @@ -19,11 +19,10 @@ public class Shape extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(1)) { - return null; - } + if ( ! allInputTypesPresent(1)) return null; + OrderedTensorType inputType = inputs.get(0).type().get(); - return new OrderedTensorType.Builder() + return new OrderedTensorType.Builder(resultValueType()) .add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size())) .build(); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java index 2eeefcbe8a2..52d40144f61 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java @@ -25,9 +25,8 @@ public class Squeeze extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(1)) { - return null; - } + if ( ! allInputTypesPresent(1)) return null; + OrderedTensorType inputType = inputs.get(0).type().get(); squeezeDimensions = new ArrayList<>(); @@ -51,9 +50,8 @@ public class Squeeze extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputFunctionsPresent(1)) { - return null; - } + if ( ! allInputFunctionsPresent(1)) return null; + TensorFunction inputFunction = inputs.get(0).function().get(); return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); } @@ -73,7 +71,7 @@ public class Squeeze extends IntermediateOperation { } private OrderedTensorType reducedType(OrderedTensorType inputType) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType()); for (TensorType.Dimension dimension: inputType.type().dimensions()) { if ( ! squeezeDimensions.contains(dimension.name())) { builder.add(dimension); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java index cb838cd67b1..a07c0fdf4dc 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java @@ -51,7 +51,7 @@ class GraphImporter { String nodeName = node.getName(); String modelName = graph.name(); int nodePort = IntermediateOperation.indexPartOf(nodeName); - OrderedTensorType nodeType = TypeConverter.fromTensorFlowType(node); + OrderedTensorType nodeType = TypeConverter.typeFrom(node); AttributeConverter attributes = AttributeConverter.convert(node); switch (node.getOp().toLowerCase()) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java index 6c92ffa6055..9cba388d00e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java @@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import org.tensorflow.DataType; import org.tensorflow.framework.TensorProto; import java.nio.ByteBuffer; @@ -27,7 +28,7 @@ public class TensorConverter { } private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { - TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix); + TensorType type = TypeConverter.typeFrom(tfTensor, dimensionPrefix); Values values = readValuesOf(tfTensor); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); for (int i = 0; i < values.size(); i++) @@ -53,16 +54,6 @@ public class TensorConverter { return builder.build(); } - private static TensorType toVespaTensorType(long[] shape, String dimensionPrefix) { - TensorType.Builder b = new TensorType.Builder(); - int dimensionIndex = 0; - for (long dimensionSize : shape) { - if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... - b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize); - } - return b.build(); - } - public static Long tensorSize(TensorType type) { Long size = 1L; for (TensorType.Dimension dimension : type.dimensions()) { @@ -85,7 +76,7 @@ public class TensorConverter { case INT64: return new LongValues(tfTensor); } throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + - tfTensor.dataType() + " to a Vespa tensor"); + tfTensor.dataType() + " to a Vespa tensor"); } private static Values readValuesOf(TensorProto tensorProto) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java index 63a605ce97a..d8ddb01b650 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java @@ -5,11 +5,10 @@ package ai.vespa.rankingexpression.importer.tensorflow; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.DataType; import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.TensorShapeProto; -import java.util.List; - /** * Converts and verifies TensorFlow tensor types into Vespa tensor types. * @@ -22,7 +21,7 @@ class TypeConverter { if (shape != null) { if (shape.getDimCount() != type.rank()) { throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " + - "does not match Vespa shape"); + "does not match Vespa shape"); } for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) { int vespaIndex = type.dimensionMap(tensorFlowIndex); @@ -30,33 +29,16 @@ class TypeConverter { TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex); if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) { throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " + - "does not match Vespa dimensions"); + "does not match Vespa dimensions"); } } } } - private static TensorShapeProto tensorFlowShape(NodeDef node) { - AttrValue attrValueList = node.getAttrMap().get("_output_shapes"); - if (attrValueList == null) { - throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + - "does not exist"); - } - if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) { - throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + - "is not of expected type"); - } - List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList(); - return shapeList.get(0); // support multiple outputs? - } - - static OrderedTensorType fromTensorFlowType(NodeDef node) { - return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ... - } - - private static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + static OrderedTensorType typeFrom(NodeDef node) { + String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... TensorShapeProto shape = tensorFlowShape(node); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(tensorFlowValueType(node))); for (int i = 0; i < shape.getDimCount(); ++ i) { String dimensionName = dimensionPrefix + i; TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i); @@ -69,4 +51,71 @@ class TypeConverter { return builder.build(); } + static TensorType typeFrom(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { + TensorType.Builder b = new TensorType.Builder(toValueType(tfTensor.dataType())); + int dimensionIndex = 0; + for (long dimensionSize : tfTensor.shape()) { + if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... + b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize); + } + return b.build(); + } + + private static TensorShapeProto tensorFlowShape(NodeDef node) { + AttrValue attrValueList = node.getAttrMap().get("_output_shapes"); + if (attrValueList == null) + throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + + "does not exist"); + if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) + throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + + "is not of expected type"); + + return attrValueList.getList().getShape(0); // support multiple outputs? + } + + private static DataType tensorFlowValueType(NodeDef node) { + AttrValue attrValueList = node.getAttrMap().get("dtypes"); + if (attrValueList == null) + return DataType.DT_DOUBLE; // default. This will usually (always?) be used. TODO: How can we do better? + if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) + return DataType.DT_DOUBLE; // default + + return attrValueList.getList().getType(0); // support multiple outputs? + } + + private static TensorType.Value toValueType(DataType dataType) { + switch (dataType) { + case DT_FLOAT: return TensorType.Value.FLOAT; + case DT_DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case DT_BOOL: return TensorType.Value.FLOAT; + case DT_BFLOAT16: return TensorType.Value.FLOAT; + case DT_HALF: return TensorType.Value.FLOAT; + case DT_INT8: return TensorType.Value.FLOAT; + case DT_INT16: return TensorType.Value.FLOAT; + case DT_INT32: return TensorType.Value.FLOAT; + case DT_INT64: return TensorType.Value.DOUBLE; + case DT_UINT8: return TensorType.Value.FLOAT; + case DT_UINT16: return TensorType.Value.FLOAT; + case DT_UINT32: return TensorType.Value.FLOAT; + case DT_UINT64: return TensorType.Value.DOUBLE; + default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); + } + } + + private static TensorType.Value toValueType(org.tensorflow.DataType dataType) { + switch (dataType) { + case FLOAT: return TensorType.Value.FLOAT; + case DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case BOOL: return TensorType.Value.FLOAT; + case INT32: return TensorType.Value.FLOAT; + case UINT8: return TensorType.Value.FLOAT; + case INT64: return TensorType.Value.DOUBLE; + default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); + } + } + } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java index afe699d6e05..61f332327be 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java @@ -13,9 +13,10 @@ public class OrderedTensorTypeTestCase { @Test public void testToFromSpec() { String spec = "tensor(b[],c{},a[3])"; + String orderedSpec = "tensor(a[3],b[],c{})"; OrderedTensorType type = OrderedTensorType.fromSpec(spec); - assertEquals(spec, type.toString()); - assertEquals("tensor(a[3],b[],c{})", type.type().toString()); + assertEquals(orderedSpec, type.toString()); + assertEquals(orderedSpec, type.type().toString()); } } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index 424e4d6c57c..07814687dc6 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -43,14 +43,14 @@ public class OnnxMnistSoftmaxImportTestCase { // Check inputs assertEquals(1, model.inputs().size()); assertTrue(model.inputs().containsKey("Placeholder")); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder")); + assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"), model.inputs().get("Placeholder")); // Check signature ImportedMlFunction output = model.defaultSignature().outputFunction("add", "add"); assertNotNull(output); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", output.expression()); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), + assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"), model.inputs().get(model.defaultSignature().inputs().get("Placeholder"))); assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString()); } diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 79c633b9617..b8c51f4e33d 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -886,6 +886,7 @@ "public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()", "public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()", "public final com.yahoo.tensor.TensorType tensorTypeArgument()", + "public final com.yahoo.tensor.TensorType$Value optionalTensorValueTypeParameter()", "public final void tensorTypeDimension(com.yahoo.tensor.TensorType$Builder)", "public final java.lang.String tensorFunctionName()", "public final com.yahoo.searchlib.rankingexpression.rule.Function unaryFunctionName()", diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 2f173ad0266..c83de4ced0a 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -598,9 +598,12 @@ Reduce.Aggregator tensorReduceAggregator() : TensorType tensorTypeArgument() : { - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder; + TensorType.Value valueType; } { + valueType = optionalTensorValueTypeParameter() + { builder = new TensorType.Builder(valueType); } <LBRACE> ( tensorTypeDimension(builder) ) ? ( <COMMA> tensorTypeDimension(builder) ) * @@ -608,6 +611,15 @@ TensorType tensorTypeArgument() : { return builder.build(); } } +TensorType.Value optionalTensorValueTypeParameter() : +{ + String valueType = "double"; +} +{ + ( <LT> valueType = identifier() <GT> )? + { return TensorTypeParser.toValueType(valueType); } +} + // NOTE: Only indexed bound dimensions are parsed currently, as that is what we need void tensorTypeDimension(TensorType.Builder builder) : { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index f2122bb5da9..f7e38862883 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -238,6 +238,8 @@ public class EvaluationTestCase { "{{x:0}:1}", "{}", "{{y:0,z:0}:1}"); tester.assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:0}:3 }", "tensor(x{}):{ {x:1}:5 }"); + tester.assertEvaluates("tensor<float>(x{}):{}", + "tensor0 * tensor1", "{ {x:0}:3 }", "tensor<float>(x{}):{ {x:1}:5 }"); tester.assertEvaluates("{ {x:0}:15 }", "tensor0 * tensor1", "{ {x:0}:3 }", "{ {x:0}:5 }"); tester.assertEvaluates("{ {x:0,y:0}:15 }", diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java index ba0db4de5e1..488930a8eb9 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java @@ -40,7 +40,7 @@ public class EvaluationTester { int argumentIndex = 0; for (String argumentString : tensorArgumentStrings) { Tensor argument; - if (argumentString.startsWith("tensor(")) // explicitly decided type + if (argumentString.startsWith("tensor")) // explicitly decided type argument = Tensor.from(argumentString); else // use mappedTensors+dimensions in tensor to decide type argument = Tensor.from(typeFrom(argumentString, mappedTensors), argumentString); diff --git a/searchlib/src/tests/aggregator/perdocexpr.cpp b/searchlib/src/tests/aggregator/perdocexpr.cpp index 66d2e48194d..1b85fb8f427 100644 --- a/searchlib/src/tests/aggregator/perdocexpr.cpp +++ b/searchlib/src/tests/aggregator/perdocexpr.cpp @@ -1325,6 +1325,75 @@ TEST("testAggregationResults") { FloatResultNode(15.54)); } +TEST("test Average over integer") { + AggregationResult::Configure conf; + AverageAggregationResult avg; + avg.setExpression(createScalarInt(I4)).select(conf, conf); + avg.aggregate(0, 0); + EXPECT_EQUAL(I4, avg.getAverage().getInteger()); +} + +TEST("test Average over float") { + AggregationResult::Configure conf; + AverageAggregationResult avg; + avg.setExpression(createScalarFloat(I4)).select(conf, conf); + avg.aggregate(0, 0); + EXPECT_EQUAL(I4, avg.getAverage().getInteger()); +} + +TEST("test Average over numeric string") { + AggregationResult::Configure conf; + AverageAggregationResult avg; + avg.setExpression(createScalarString("7.8")).select(conf, conf); + avg.aggregate(0, 0); + EXPECT_EQUAL(7.8, avg.getAverage().getFloat()); +} + +TEST("test Average over non-numeric string") { + AggregationResult::Configure conf; + AverageAggregationResult avg; + avg.setExpression(createScalarString("ABC")).select(conf, conf); + avg.aggregate(0, 0); + EXPECT_EQUAL(0, avg.getAverage().getInteger()); +} + +TEST("test Sum over integer") { + AggregationResult::Configure conf; + SumAggregationResult sum; + sum.setExpression(createScalarInt(I4)).select(conf, conf); + sum.aggregate(0, 0); + sum.aggregate(0, 0); + EXPECT_EQUAL(I4*2, sum.getSum().getInteger()); +} + +TEST("test Sum over float") { + AggregationResult::Configure conf; + SumAggregationResult sum; + sum.setExpression(createScalarFloat(I4)).select(conf, conf); + sum.aggregate(0, 0); + sum.aggregate(0, 0); + EXPECT_EQUAL(I4*2, sum.getSum().getInteger()); +} + +TEST("test Sum over numeric string") { + AggregationResult::Configure conf; + SumAggregationResult sum; + sum.setExpression(createScalarString("7.8")).select(conf, conf); + sum.aggregate(0, 0); + sum.aggregate(0, 0); + EXPECT_EQUAL(7.8*2, sum.getSum().getFloat()); +} + +TEST("test Sum over non-numeric string") { + AggregationResult::Configure conf; + SumAggregationResult sum; + sum.setExpression(createScalarString("ABC")).select(conf, conf); + sum.aggregate(0, 0); + sum.aggregate(0, 0); + EXPECT_EQUAL(0, sum.getSum().getInteger()); +} + + TEST("testGrouping") { AttributeGuard attr1 = createInt64Attribute(); ExpressionNode::UP result1(new CountAggregationResult()); diff --git a/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp b/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp index 42ce9725f91..54c77fb25a7 100644 --- a/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp +++ b/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp @@ -6,6 +6,10 @@ #include <vespa/searchlib/fef/test/ftlib.h> #include <vespa/searchlib/fef/test/rankresult.h> #include <vespa/searchlib/fef/test/dummy_dependency_handler.h> +#include <vespa/eval/tensor/tensor.h> +#include <vespa/eval/tensor/serialization/typed_binary_format.h> +#include <vespa/vespalib/objects/nbostream.h> +#include <vespa/eval/tensor/dense/dense_tensor.h> using namespace search; using namespace search::attribute; @@ -104,7 +108,26 @@ struct ArrayFixture : FixtureBase { } template <typename ExpectedType> - void check_prepare_state_output(const vespalib::string& input_vector) { + void check_prepare_state_output(const vespalib::tensor::Tensor & tensor, vespalib::tensor::SerializeFormat format, const ExpectedType & expected) { + vespalib::nbostream os; + vespalib::tensor::TypedBinaryFormat::serialize(os, tensor, format); + vespalib::string input_vector(os.c_str(), os.size()); + check_prepare_state_output(".tensor", input_vector, expected); + } + + template <typename ExpectedType> + void check_prepare_state_output(const vespalib::string& input_vector, const ExpectedType & expected) { + check_prepare_state_output("", input_vector, expected); + } + template <typename T> + static void verify(const dotproduct::ArrayParam<T> & a, const dotproduct::ArrayParam<T> & b) { + ASSERT_EQUAL(a.values.size(), b.values.size()); + for (size_t i(0); i < a.values.size(); i++) { + EXPECT_EQUAL(a.values[i], b.values[i]); + } + } + template <typename ExpectedType> + void check_prepare_state_output(const vespalib::string & postfix, const vespalib::string& input_vector, const ExpectedType & expected) { FtFeatureTest feature(_factory, ""); DotProductBlueprint bp; DummyDependencyHandler dependency_handler(bp); @@ -116,7 +139,7 @@ struct ArrayFixture : FixtureBase { FieldType::ATTRIBUTE, schema::CollectionType::ARRAY, imported_attr->getName()); bp.setup(feature.getIndexEnv(), params); - feature.getQueryEnv().getProperties().add("dotProduct.fancyvector", input_vector); + feature.getQueryEnv().getProperties().add("dotProduct.fancyvector" + postfix, input_vector); auto& obj_store = feature.getQueryEnv().getObjectStore(); bp.prepareSharedState(feature.getQueryEnv(), obj_store); // Resulting name is very implementation defined. But at least the tests will break if it changes. @@ -124,13 +147,12 @@ struct ArrayFixture : FixtureBase { ASSERT_TRUE(parsed != nullptr); const auto* as_object = dynamic_cast<const ExpectedType*>(parsed); ASSERT_TRUE(as_object != nullptr); - // We don't test the parsed output values here; that's the responsibility of other tests. + verify(expected, *as_object); } - void check_all_float_executions(feature_t expected, - const vespalib::string& vector, - DocId doc_id, - const vespalib::string& shared_param = "") { + void check_all_float_executions(feature_t expected, const vespalib::string& vector, + DocId doc_id, const vespalib::string& shared_param = "") + { check_executions<double>([this](auto float_type){ this->setup_float_mappings(float_type); }, {{BasicType::FLOAT, BasicType::DOUBLE}}, expected, vector, doc_id, shared_param); @@ -155,22 +177,46 @@ TEST_F("Zero-length float/double array query vector evaluates to zero", ArrayFix TEST_F("prepareSharedState emits i64 vector for i32 imported attribute", ArrayFixture) { f.setup_integer_mappings(BasicType::INT32); - f.template check_prepare_state_output<dotproduct::ArrayParam<int64_t>>("[101 202 303]"); + f.template check_prepare_state_output("[101 202 303]", dotproduct::ArrayParam<int64_t>({101, 202, 303})); } TEST_F("prepareSharedState emits i64 vector for i64 imported attribute", ArrayFixture) { f.setup_integer_mappings(BasicType::INT64); - f.template check_prepare_state_output<dotproduct::ArrayParam<int64_t>>("[101 202 303]"); + f.template check_prepare_state_output("[101 202 303]", dotproduct::ArrayParam<int64_t>({101, 202, 303})); } TEST_F("prepareSharedState emits double vector for float imported attribute", ArrayFixture) { f.setup_float_mappings(BasicType::FLOAT); - f.template check_prepare_state_output<dotproduct::ArrayParam<double>>("[10.1 20.2 30.3]"); + f.template check_prepare_state_output("[10.1 20.2 30.3]", dotproduct::ArrayParam<double>({10.1, 20.2, 30.3})); } TEST_F("prepareSharedState emits double vector for double imported attribute", ArrayFixture) { f.setup_float_mappings(BasicType::DOUBLE); - f.template check_prepare_state_output<dotproduct::ArrayParam<double>>("[10.1 20.2 30.3]"); + f.template check_prepare_state_output("[10.1 20.2 30.3]", dotproduct::ArrayParam<double>({10.1, 20.2, 30.3})); +} + +TEST_F("prepareSharedState handles tensor as float from tensor for double imported attribute", ArrayFixture) { + f.setup_float_mappings(BasicType::DOUBLE); + vespalib::tensor::DenseTensor tensor(vespalib::eval::ValueType::from_spec("tensor(x[3])"), {10.1, 20.2, 30.3}); + f.template check_prepare_state_output(tensor, vespalib::tensor::SerializeFormat::FLOAT, dotproduct::ArrayParam<double>({10.1, 20.2, 30.3})); +} + +TEST_F("prepareSharedState handles tensor as double from tensor for double imported attribute", ArrayFixture) { + f.setup_float_mappings(BasicType::DOUBLE); + vespalib::tensor::DenseTensor tensor(vespalib::eval::ValueType::from_spec("tensor(x[3])"), {10.1, 20.2, 30.3}); + f.template check_prepare_state_output(tensor, vespalib::tensor::SerializeFormat::DOUBLE, dotproduct::ArrayParam<double>({10.1, 20.2, 30.3})); +} + +TEST_F("prepareSharedState handles tensor as float from tensor for float imported attribute", ArrayFixture) { + f.setup_float_mappings(BasicType::FLOAT); + vespalib::tensor::DenseTensor tensor(vespalib::eval::ValueType::from_spec("tensor(x[3])"), {10.1, 20.2, 30.3}); + f.template check_prepare_state_output(tensor, vespalib::tensor::SerializeFormat::FLOAT, dotproduct::ArrayParam<float>({10.1, 20.2, 30.3})); +} + +TEST_F("prepareSharedState handles tensor as double from tensor for float imported attribute", ArrayFixture) { + f.setup_float_mappings(BasicType::FLOAT); + vespalib::tensor::DenseTensor tensor(vespalib::eval::ValueType::from_spec("tensor(x[3])"), {10.1, 20.2, 30.3}); + f.template check_prepare_state_output(tensor, vespalib::tensor::SerializeFormat::DOUBLE, dotproduct::ArrayParam<float>({10.1, 20.2, 30.3})); } TEST_F("Dense i32/i64 array dot product can be evaluated with pre-parsed object parameter", ArrayFixture) { diff --git a/searchlib/src/tests/grouping/grouping_test.cpp b/searchlib/src/tests/grouping/grouping_test.cpp index fea18619ef9..0750d30f60d 100644 --- a/searchlib/src/tests/grouping/grouping_test.cpp +++ b/searchlib/src/tests/grouping/grouping_test.cpp @@ -313,10 +313,9 @@ Test::testAggregationSimple() ctx.add(FloatAttrBuilder("float").add(3).add(7).add(15).sp()); ctx.add(StringAttrBuilder("string").add("3").add("7").add("15").sp()); - char strsum[3] = {-101, '5', 0}; - testAggregationSimpleSum(ctx, SumAggregationResult(), Int64ResultNode(25), FloatResultNode(25), StringResultNode(strsum)); - testAggregationSimpleSum(ctx, MinAggregationResult(), Int64ResultNode(3), FloatResultNode(3), StringResultNode("15")); - testAggregationSimpleSum(ctx, MaxAggregationResult(), Int64ResultNode(15), FloatResultNode(15), StringResultNode("7")); + TEST_DO(testAggregationSimpleSum(ctx, SumAggregationResult(), Int64ResultNode(25), FloatResultNode(25), StringResultNode("25"))); + TEST_DO(testAggregationSimpleSum(ctx, MinAggregationResult(), Int64ResultNode(3), FloatResultNode(3), StringResultNode("15"))); + TEST_DO(testAggregationSimpleSum(ctx, MaxAggregationResult(), Int64ResultNode(15), FloatResultNode(15), StringResultNode("7"))); } #define MU std::make_unique @@ -630,6 +629,14 @@ createAggr(SingleResultNode::UP r, ExpressionNode::UP e) { return aggr; } +template<typename T> +ExpressionNode::UP +createNumAggr(NumericResultNode::UP r, ExpressionNode::UP e) { + std::unique_ptr<T> aggr = MU<T>(std::move(r)); + aggr->setExpression(std::move(e)); + return aggr; +} + void Test::testAggregationGroupCapping() { @@ -680,13 +687,13 @@ Test::testAggregationGroupCapping() Group expect; expect.addChild(Group().setId(Int64ResultNode(7)).setRank(RawRank(7)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(7), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(7), MU<AttributeNode>("attr"))) .addOrderBy(MU<AggregationRefNode>(0), false)) .addChild(Group().setId(Int64ResultNode(8)).setRank(RawRank(8)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(8), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(8), MU<AttributeNode>("attr"))) .addOrderBy(MU<AggregationRefNode>(0), false)) .addChild(Group().setId(Int64ResultNode(9)).setRank(RawRank(9)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(9), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(9), MU<AttributeNode>("attr"))) .addOrderBy(MU<AggregationRefNode>(0), false)); EXPECT_TRUE(testAggregation(ctx, request, expect)); @@ -701,13 +708,13 @@ Test::testAggregationGroupCapping() Group expect = Group() .addChild(Group().setId(Int64ResultNode(1)).setRank(RawRank(1)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(1), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(1), MU<AttributeNode>("attr"))) .addOrderBy(MU<AggregationRefNode>(0), true)) .addChild(Group().setId(Int64ResultNode(2)).setRank(RawRank(2)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(2), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(2), MU<AttributeNode>("attr"))) .addOrderBy(MU<AggregationRefNode>(0), true)) .addChild(Group().setId(Int64ResultNode(3)).setRank(RawRank(3)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(3), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(3), MU<AttributeNode>("attr"))) .addOrderBy(MU<AggregationRefNode>(0), true)); EXPECT_TRUE(testAggregation(ctx, request, expect)); @@ -726,13 +733,13 @@ Test::testAggregationGroupCapping() Group expect; expect.addChild(Group().setId(Int64ResultNode(7)).setRank(RawRank(7)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(7), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(7), MU<AttributeNode>("attr"))) .addOrderBy(AddFunctionNode().appendArg(MU<AggregationRefNode>(0)).appendArg(MU<ConstantNode>(MU<Int64ResultNode>(3))).setResult(Int64ResultNode(10)), false)) .addChild(Group().setId(Int64ResultNode(8)).setRank(RawRank(8)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(8), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(8), MU<AttributeNode>("attr"))) .addOrderBy(AddFunctionNode().appendArg(MU<AggregationRefNode>(0)).appendArg(MU<ConstantNode>(MU<Int64ResultNode>(3))).setResult(Int64ResultNode(11)), false)) .addChild(Group().setId(Int64ResultNode(9)).setRank(RawRank(9)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(9), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(9), MU<AttributeNode>("attr"))) .addOrderBy(AddFunctionNode().appendArg(MU<AggregationRefNode>(0)).appendArg(MU<ConstantNode>(MU<Int64ResultNode>(3))).setResult(Int64ResultNode(12)), false)); EXPECT_TRUE(testAggregation(ctx, request, expect)); diff --git a/searchlib/src/tests/groupingengine/groupingengine_test.cpp b/searchlib/src/tests/groupingengine/groupingengine_test.cpp index 3920667c1d6..da9b8d62305 100644 --- a/searchlib/src/tests/groupingengine/groupingengine_test.cpp +++ b/searchlib/src/tests/groupingengine/groupingengine_test.cpp @@ -615,6 +615,14 @@ createAggr(SingleResultNode::UP r, ExpressionNode::UP e) { return aggr; } +template<typename T> +ExpressionNode::UP +createNumAggr(NumericResultNode::UP r, ExpressionNode::UP e) { + std::unique_ptr<T> aggr = MU<T>(std::move(r)); + aggr->setExpression(std::move(e)); + return aggr; +} + void Test::testAggregationGroupCapping() { @@ -670,13 +678,13 @@ Test::testAggregationGroupCapping() Group expect; expect.setId(NullResultNode()) .addChild(Group().setId(Int64ResultNode(7)).setRank(RawRank(7)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(7), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(7), MU<AttributeNode>("attr"))) .addOrderBy(MU<AggregationRefNode>(0), false)) .addChild(Group().setId(Int64ResultNode(8)).setRank(RawRank(8)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(8), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(8), MU<AttributeNode>("attr"))) .addOrderBy(MU<AggregationRefNode>(0), false)) .addChild(Group().setId(Int64ResultNode(9)).setRank(RawRank(9)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(9), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(9), MU<AttributeNode>("attr"))) .addOrderBy(MU<AggregationRefNode>(0), false)); EXPECT_TRUE(testAggregation(ctx, request, expect)); @@ -693,13 +701,13 @@ Test::testAggregationGroupCapping() Group expect; expect.setId(NullResultNode()) .addChild(Group().setId(Int64ResultNode(1)).setRank(RawRank(1)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(1), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(1), MU<AttributeNode>("attr"))) .addOrderBy(MU<AggregationRefNode>(0), true)) .addChild(Group().setId(Int64ResultNode(2)).setRank(RawRank(2)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(2), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(2), MU<AttributeNode>("attr"))) .addOrderBy(MU<AggregationRefNode>(0), true)) .addChild(Group().setId(Int64ResultNode(3)).setRank(RawRank(3)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(3), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(3), MU<AttributeNode>("attr"))) .addOrderBy(MU<AggregationRefNode>(0), true)); EXPECT_TRUE(testAggregation(ctx, request, expect)); @@ -718,13 +726,13 @@ Test::testAggregationGroupCapping() Group expect = Group() .addChild(Group().setId(Int64ResultNode(7)).setRank(RawRank(7)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(7), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(7), MU<AttributeNode>("attr"))) .addOrderBy(AddFunctionNode().appendArg(MU<AggregationRefNode>(0)).appendArg(MU<ConstantNode>(MU<Int64ResultNode>(3))).setResult(Int64ResultNode(10)), false)) .addChild(Group().setId(Int64ResultNode(8)).setRank(RawRank(8)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(8), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(8), MU<AttributeNode>("attr"))) .addOrderBy(AddFunctionNode().appendArg(MU<AggregationRefNode>(0)).appendArg(MU<ConstantNode>(MU<Int64ResultNode>(3))).setResult(Int64ResultNode(11)), false)) .addChild(Group().setId(Int64ResultNode(9)).setRank(RawRank(9)) - .addAggregationResult(createAggr<SumAggregationResult>(MU<Int64ResultNode>(9), MU<AttributeNode>("attr"))) + .addAggregationResult(createNumAggr<SumAggregationResult>(MU<Int64ResultNode>(9), MU<AttributeNode>("attr"))) .addOrderBy(AddFunctionNode().appendArg(MU<AggregationRefNode>(0)).appendArg(MU<ConstantNode>(MU<Int64ResultNode>(3))).setResult(Int64ResultNode(12)), false)); EXPECT_TRUE(testAggregation(ctx, request, expect)); diff --git a/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp b/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp index d40cdd5f13e..d6eb3a033a2 100644 --- a/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp +++ b/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp @@ -17,6 +17,17 @@ bool isReady(const ResultNode *myRes, const ResultNode &ref) { return (myRes != 0 && myRes->getClass().id() == ref.getClass().id()); } +template<typename Wanted, typename Fallback> +std::unique_ptr<Wanted> +createAndEnsureWanted(const ResultNode & result) { + std::unique_ptr<ResultNode> tmp = result.createBaseType(); + if (dynamic_cast<Wanted *>(tmp.get()) != nullptr) { + return std::unique_ptr<Wanted>(static_cast<Wanted *>(tmp.release())); + } else { + return std::make_unique<Fallback>(); + } +} + } // namespace search::aggregation::<unnamed> @@ -38,14 +49,14 @@ IMPLEMENT_AGGREGATIONRESULT(ExpressionCountAggregationResult, AggregationResult) IMPLEMENT_AGGREGATIONRESULT(StandardDeviationAggregationResult, AggregationResult); AggregationResult::AggregationResult() : - _expressionTree(new ExpressionTree()), + _expressionTree(std::make_shared<ExpressionTree>()), _tag(-1) { } AggregationResult::AggregationResult(const AggregationResult &) = default; AggregationResult & AggregationResult::operator = (const AggregationResult &) = default; -AggregationResult::~AggregationResult() { } +AggregationResult::~AggregationResult() = default; void AggregationResult::aggregate(const document::Document & doc, HitRank rank) { @@ -66,14 +77,16 @@ AggregationResult::aggregate(DocId docId, HitRank rank) { } } -bool AggregationResult::Configure::check(const vespalib::Identifiable &obj) const +bool +AggregationResult::Configure::check(const vespalib::Identifiable &obj) const { return obj.inherits(AggregationResult::classId); } -void AggregationResult::Configure::execute(vespalib::Identifiable &obj) +void +AggregationResult::Configure::execute(vespalib::Identifiable &obj) { - AggregationResult & a(static_cast<AggregationResult &>(obj)); + auto & a(static_cast<AggregationResult &>(obj)); a.prepare(); } @@ -85,37 +98,40 @@ AggregationResult::setExpression(ExpressionNode::UP expr) return *this; } -void CountAggregationResult::onPrepare(const ResultNode & result, bool useForInit) +void +CountAggregationResult::onPrepare(const ResultNode & result, bool useForInit) { (void) result; (void) useForInit; } -void SumAggregationResult::onPrepare(const ResultNode & result, bool useForInit) +void +SumAggregationResult::onPrepare(const ResultNode & result, bool useForInit) { if (isReady(_sum.get(), result)) { return; } - _sum.reset(dynamic_cast<SingleResultNode *>(result.createBaseType().release())); + _sum = createAndEnsureWanted<NumericResultNode, FloatResultNode>(result); if ( useForInit ) { _sum->set(result); } } -MinAggregationResult::MinAggregationResult() : AggregationResult() { } +MinAggregationResult::MinAggregationResult() = default; MinAggregationResult::MinAggregationResult(const ResultNode::CP &result) : AggregationResult() { setResult(result); } -MinAggregationResult::~MinAggregationResult() { } +MinAggregationResult::~MinAggregationResult() = default; -void MinAggregationResult::onPrepare(const ResultNode & result, bool useForInit) +void +MinAggregationResult::onPrepare(const ResultNode & result, bool useForInit) { if (isReady(_min.get(), result)) { return; } - _min.reset(dynamic_cast<SingleResultNode *>(result.createBaseType().release())); + _min = createAndEnsureWanted<SingleResultNode, FloatResultNode>(result); if ( !useForInit ) { _min->setMax(); } else { @@ -123,19 +139,20 @@ void MinAggregationResult::onPrepare(const ResultNode & result, bool useForInit) } } -MaxAggregationResult::MaxAggregationResult() : AggregationResult(), _max() { } +MaxAggregationResult::MaxAggregationResult() = default; MaxAggregationResult::MaxAggregationResult(const SingleResultNode & max) : AggregationResult(), _max(max) { } -MaxAggregationResult::~MaxAggregationResult() { } +MaxAggregationResult::~MaxAggregationResult() = default; -void MaxAggregationResult::onPrepare(const ResultNode & result, bool useForInit) +void +MaxAggregationResult::onPrepare(const ResultNode & result, bool useForInit) { if (isReady(_max.get(), result)) { return; } - _max.reset(dynamic_cast<SingleResultNode *>(result.createBaseType().release())); + _max = createAndEnsureWanted<SingleResultNode, FloatResultNode>(result); if ( !useForInit ) { _max->setMin(); ///Should figure out how to set min too for float. } else { @@ -143,29 +160,33 @@ void MaxAggregationResult::onPrepare(const ResultNode & result, bool useForInit) } } -void AverageAggregationResult::onPrepare(const ResultNode & result, bool useForInit) +void +AverageAggregationResult::onPrepare(const ResultNode & result, bool useForInit) { if (isReady(_sum.get(), result)) { return; } - _sum.reset(dynamic_cast<NumericResultNode *>(result.createBaseType().release())); + _sum = createAndEnsureWanted<NumericResultNode, FloatResultNode>(result); if ( useForInit ) { _sum->set(result); } } -void XorAggregationResult::onPrepare(const ResultNode & result, bool useForInit) +void +XorAggregationResult::onPrepare(const ResultNode & result, bool useForInit) { (void) result; (void) useForInit; } -void SumAggregationResult::onMerge(const AggregationResult & b) +void +SumAggregationResult::onMerge(const AggregationResult & b) { _sum->add(*static_cast<const SumAggregationResult &>(b)._sum); } -void SumAggregationResult::onAggregate(const ResultNode & result) +void +SumAggregationResult::onAggregate(const ResultNode & result) { if (result.isMultiValue()) { static_cast<const ResultNodeVector &>(result).flattenSum(*_sum); @@ -174,17 +195,20 @@ void SumAggregationResult::onAggregate(const ResultNode & result) } } -void SumAggregationResult::onReset() +void +SumAggregationResult::onReset() { - _sum.reset(static_cast<SingleResultNode *>(_sum->getClass().create())); + _sum.reset(static_cast<NumericResultNode *>(_sum->getClass().create())); } -void CountAggregationResult::onMerge(const AggregationResult & b) +void +CountAggregationResult::onMerge(const AggregationResult & b) { _count.add(static_cast<const CountAggregationResult &>(b)._count); } -void CountAggregationResult::onAggregate(const ResultNode & result) +void +CountAggregationResult::onAggregate(const ResultNode & result) { if (result.isMultiValue()) { _count += static_cast<const ResultNodeVector &>(result).size(); @@ -193,17 +217,20 @@ void CountAggregationResult::onAggregate(const ResultNode & result) } } -void CountAggregationResult::onReset() +void +CountAggregationResult::onReset() { setCount(0); } -void MaxAggregationResult::onMerge(const AggregationResult & b) +void +MaxAggregationResult::onMerge(const AggregationResult & b) { _max->max(*static_cast<const MaxAggregationResult &>(b)._max); } -void MaxAggregationResult::onAggregate(const ResultNode & result) +void +MaxAggregationResult::onAggregate(const ResultNode & result) { if (result.isMultiValue()) { static_cast<const ResultNodeVector &>(result).flattenMax(*_max); @@ -212,18 +239,21 @@ void MaxAggregationResult::onAggregate(const ResultNode & result) } } -void MaxAggregationResult::onReset() +void +MaxAggregationResult::onReset() { _max.reset(static_cast<SingleResultNode *>(_max->getClass().create())); _max->setMin(); } -void MinAggregationResult::onMerge(const AggregationResult & b) +void +MinAggregationResult::onMerge(const AggregationResult & b) { _min->min(*static_cast<const MinAggregationResult &>(b)._min); } -void MinAggregationResult::onAggregate(const ResultNode & result) +void +MinAggregationResult::onAggregate(const ResultNode & result) { if (result.isMultiValue()) { static_cast<const ResultNodeVector &>(result).flattenMin(*_min); @@ -232,22 +262,25 @@ void MinAggregationResult::onAggregate(const ResultNode & result) } } -void MinAggregationResult::onReset() +void +MinAggregationResult::onReset() { _min.reset(static_cast<SingleResultNode *>(_min->getClass().create())); _min->setMax(); } -AverageAggregationResult::~AverageAggregationResult() {} +AverageAggregationResult::~AverageAggregationResult() = default; -void AverageAggregationResult::onMerge(const AggregationResult & b) +void +AverageAggregationResult::onMerge(const AggregationResult & b) { - const AverageAggregationResult & avg(static_cast<const AverageAggregationResult &>(b)); + const auto & avg(static_cast<const AverageAggregationResult &>(b)); _sum->add(*avg._sum); _count += avg._count; } -void AverageAggregationResult::onAggregate(const ResultNode & result) +void +AverageAggregationResult::onAggregate(const ResultNode & result) { if (result.isMultiValue()) { static_cast<const ResultNodeVector &>(result).flattenSum(*_sum); @@ -258,13 +291,15 @@ void AverageAggregationResult::onAggregate(const ResultNode & result) } } -void AverageAggregationResult::onReset() +void +AverageAggregationResult::onReset() { _count = 0; _sum.reset(static_cast<NumericResultNode *>(_sum->getClass().create())); } -const NumericResultNode & AverageAggregationResult::getAverage() const +const NumericResultNode & +AverageAggregationResult::getAverage() const { _averageScratchPad = _sum; if ( _count > 0 ) { @@ -275,12 +310,14 @@ const NumericResultNode & AverageAggregationResult::getAverage() const return *_averageScratchPad; } -void XorAggregationResult::onMerge(const AggregationResult & b) +void +XorAggregationResult::onMerge(const AggregationResult & b) { _xor.xorOp(static_cast<const XorAggregationResult &>(b)._xor); } -void XorAggregationResult::onAggregate(const ResultNode & result) +void +XorAggregationResult::onAggregate(const ResultNode & result) { if (result.isMultiValue()) { for (size_t i(0), m(static_cast<const ResultNodeVector &>(result).size()); i < m; i++) { @@ -291,21 +328,24 @@ void XorAggregationResult::onAggregate(const ResultNode & result) } } -void XorAggregationResult::onReset() +void +XorAggregationResult::onReset() { _xor = 0; } static FieldBase _G_tagField("tag"); -Serializer & AggregationResult::onSerialize(Serializer & os) const +Serializer & +AggregationResult::onSerialize(Serializer & os) const { return (os << *_expressionTree).put(_G_tagField, _tag); } -Deserializer & AggregationResult::onDeserialize(Deserializer & is) +Deserializer & +AggregationResult::onDeserialize(Deserializer & is) { - _expressionTree.reset(new ExpressionTree()); + _expressionTree = std::make_shared<ExpressionTree>(); return (is >> *_expressionTree).get(_G_tagField, _tag); } @@ -315,18 +355,21 @@ AggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const visit(visitor, "expression", _expressionTree); } -void AggregationResult::selectMembers(const vespalib::ObjectPredicate & predicate, vespalib::ObjectOperation & operation) +void +AggregationResult::selectMembers(const vespalib::ObjectPredicate & predicate, vespalib::ObjectOperation & operation) { _expressionTree->select(predicate,operation); } -Serializer & CountAggregationResult::onSerialize(Serializer & os) const +Serializer & +CountAggregationResult::onSerialize(Serializer & os) const { AggregationResult::onSerialize(os); return _count.serialize(os); } -Deserializer & CountAggregationResult::onDeserialize(Deserializer & is) +Deserializer & +CountAggregationResult::onDeserialize(Deserializer & is) { AggregationResult::onDeserialize(is); return _count.deserialize(is); @@ -339,27 +382,27 @@ CountAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const visit(visitor, "count", _count); } -Serializer & SumAggregationResult::onSerialize(Serializer & os) const +Serializer & +SumAggregationResult::onSerialize(Serializer & os) const { AggregationResult::onSerialize(os); return os << _sum; } -Deserializer & SumAggregationResult::onDeserialize(Deserializer & is) +Deserializer & +SumAggregationResult::onDeserialize(Deserializer & is) { AggregationResult::onDeserialize(is); return is >> _sum; } -SumAggregationResult::SumAggregationResult() - : AggregationResult(), - _sum() -{ } -SumAggregationResult::SumAggregationResult(SingleResultNode::UP sum) +SumAggregationResult::SumAggregationResult() = default; + +SumAggregationResult::SumAggregationResult(NumericResultNode::UP sum) : AggregationResult(), _sum(sum.release()) { } -SumAggregationResult::~SumAggregationResult() {} +SumAggregationResult::~SumAggregationResult() = default; void SumAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const @@ -368,13 +411,15 @@ SumAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const visit(visitor, "sum", _sum); } -Serializer & MinAggregationResult::onSerialize(Serializer & os) const +Serializer & +MinAggregationResult::onSerialize(Serializer & os) const { AggregationResult::onSerialize(os); return os << _min; } -Deserializer & MinAggregationResult::onDeserialize(Deserializer & is) +Deserializer & +MinAggregationResult::onDeserialize(Deserializer & is) { AggregationResult::onDeserialize(is); return is >> _min; @@ -387,13 +432,15 @@ MinAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const visit(visitor, "min", _min); } -Serializer & MaxAggregationResult::onSerialize(Serializer & os) const +Serializer & +MaxAggregationResult::onSerialize(Serializer & os) const { AggregationResult::onSerialize(os); return os << _max; } -Deserializer & MaxAggregationResult::onDeserialize(Deserializer & is) +Deserializer & +MaxAggregationResult::onDeserialize(Deserializer & is) { AggregationResult::onDeserialize(is); return is >> _max; @@ -406,16 +453,19 @@ MaxAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const visit(visitor, "max", _max); } -static FieldBase _G_countField("count"); -static FieldBase _G_sumField("sum"); +namespace { + FieldBase _G_countField("count"); +} -Serializer & AverageAggregationResult::onSerialize(Serializer & os) const +Serializer & +AverageAggregationResult::onSerialize(Serializer & os) const { AggregationResult::onSerialize(os); return os.put(_G_countField, _count) << _sum; } -Deserializer & AverageAggregationResult::onDeserialize(Deserializer & is) +Deserializer & +AverageAggregationResult::onDeserialize(Deserializer & is) { AggregationResult::onDeserialize(is); return is.get(_G_countField, _count) >> _sum; @@ -429,13 +479,15 @@ AverageAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const visit(visitor, "sum", _sum); } -Serializer & XorAggregationResult::onSerialize(Serializer & os) const +Serializer & +XorAggregationResult::onSerialize(Serializer & os) const { AggregationResult::onSerialize(os); return _xor.serialize(os); } -Deserializer & XorAggregationResult::onDeserialize(Deserializer & is) +Deserializer & +XorAggregationResult::onDeserialize(Deserializer & is) { AggregationResult::onDeserialize(is); return _xor.deserialize(is); @@ -451,7 +503,8 @@ XorAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const namespace { // Calculates the sum of all buckets. template <int BucketBits, typename HashT> -int calculateRank(const Sketch<BucketBits, HashT> &sketch) { +int +calculateRank(const Sketch<BucketBits, HashT> &sketch) { if (sketch.getClassId() == SparseSketch<BucketBits, HashT>::classId) { return static_cast<const SparseSketch<BucketBits, HashT>&>(sketch) .getSize(); @@ -465,13 +518,14 @@ int calculateRank(const Sketch<BucketBits, HashT> &sketch) { } } // namespace -void ExpressionCountAggregationResult::onMerge(const AggregationResult &r) { - const ExpressionCountAggregationResult &result = - Identifiable::cast<const ExpressionCountAggregationResult &>(r); +void +ExpressionCountAggregationResult::onMerge(const AggregationResult &r) { + const auto & result = Identifiable::cast<const ExpressionCountAggregationResult &>(r); _hll.merge(result._hll); _rank.set(calculateRank(_hll.getSketch())); } -void ExpressionCountAggregationResult::onAggregate(const ResultNode &result) { +void +ExpressionCountAggregationResult::onAggregate(const ResultNode &result) { size_t hash = result.hash(); const unsigned int seed = 42; hash = XXH32(&hash, sizeof(hash), seed); @@ -479,36 +533,38 @@ void ExpressionCountAggregationResult::onAggregate(const ResultNode &result) { // almost the same ordering as the actual estimates. _rank += _hll.aggregate(hash); } -void ExpressionCountAggregationResult::onReset() { +void +ExpressionCountAggregationResult::onReset() { _hll = HyperLogLog<PRECISION>(); _rank.set(0); } -Serializer &ExpressionCountAggregationResult::onSerialize( - Serializer &os) const { +Serializer & +ExpressionCountAggregationResult::onSerialize(Serializer &os) const { AggregationResult::onSerialize(os); _hll.serialize(os); return os; } -Deserializer &ExpressionCountAggregationResult::onDeserialize( - Deserializer &is) { +Deserializer & +ExpressionCountAggregationResult::onDeserialize(Deserializer &is) { AggregationResult::onDeserialize(is); _hll.deserialize(is); _rank.set(calculateRank(_hll.getSketch())); return is; } -ExpressionCountAggregationResult::ExpressionCountAggregationResult() : AggregationResult(), _hll() { } -ExpressionCountAggregationResult::~ExpressionCountAggregationResult() {} +ExpressionCountAggregationResult::ExpressionCountAggregationResult() = default; +ExpressionCountAggregationResult::~ExpressionCountAggregationResult() = default; StandardDeviationAggregationResult::StandardDeviationAggregationResult() - : AggregationResult(), _count(), _sum(), _sumOfSquared(), _stdDevScratchPad() + : AggregationResult(), _count(), _sum(), _sumOfSquared(), _stdDevScratchPad() { _stdDevScratchPad.reset(new expression::FloatResultNode()); } -StandardDeviationAggregationResult::~StandardDeviationAggregationResult() {} +StandardDeviationAggregationResult::~StandardDeviationAggregationResult() = default; -const NumericResultNode& StandardDeviationAggregationResult::getStandardDeviation() const noexcept +const NumericResultNode& +StandardDeviationAggregationResult::getStandardDeviation() const noexcept { if (_count == 0) { _stdDevScratchPad->set(Int64ResultNode(0)); @@ -520,15 +576,16 @@ const NumericResultNode& StandardDeviationAggregationResult::getStandardDeviatio return *_stdDevScratchPad; } -void StandardDeviationAggregationResult::onMerge(const AggregationResult &r) { - const StandardDeviationAggregationResult &result = - Identifiable::cast<const StandardDeviationAggregationResult &>(r); +void +StandardDeviationAggregationResult::onMerge(const AggregationResult &r) { + const auto & result = Identifiable::cast<const StandardDeviationAggregationResult &>(r); _count += result._count; _sum.add(result._sum); _sumOfSquared.add(result._sumOfSquared); } -void StandardDeviationAggregationResult::onAggregate(const ResultNode &result) { +void +StandardDeviationAggregationResult::onAggregate(const ResultNode &result) { if (result.isMultiValue()) { static_cast<const ResultNodeVector &>(result).flattenSum(_sum); static_cast<const ResultNodeVector &>(result).flattenSumOfSquared(_sumOfSquared); @@ -542,14 +599,16 @@ void StandardDeviationAggregationResult::onAggregate(const ResultNode &result) { } } -void StandardDeviationAggregationResult::onReset() +void +StandardDeviationAggregationResult::onReset() { _count = 0; _sum.set(0.0); _sumOfSquared.set(0.0); } -Serializer & StandardDeviationAggregationResult::onSerialize(Serializer & os) const +Serializer & +StandardDeviationAggregationResult::onSerialize(Serializer & os) const { AggregationResult::onSerialize(os); double sum = _sum.getFloat(); @@ -557,7 +616,8 @@ Serializer & StandardDeviationAggregationResult::onSerialize(Serializer & os) co return os << _count << sum << sumOfSquared; } -Deserializer & StandardDeviationAggregationResult::onDeserialize(Deserializer & is) +Deserializer & +StandardDeviationAggregationResult::onDeserialize(Deserializer & is) { AggregationResult::onDeserialize(is); double sum; @@ -568,7 +628,8 @@ Deserializer & StandardDeviationAggregationResult::onDeserialize(Deserializer & return r; } -void StandardDeviationAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const +void +StandardDeviationAggregationResult::visitMembers(vespalib::ObjectVisitor &visitor) const { AggregationResult::visitMembers(visitor); visit(visitor, "count", _count); diff --git a/searchlib/src/vespa/searchlib/aggregation/aggregationresult.h b/searchlib/src/vespa/searchlib/aggregation/aggregationresult.h index 765dcf23050..8587511497f 100644 --- a/searchlib/src/vespa/searchlib/aggregation/aggregationresult.h +++ b/searchlib/src/vespa/searchlib/aggregation/aggregationresult.h @@ -39,7 +39,7 @@ public: AggregationResult & operator = (const AggregationResult &); AggregationResult(AggregationResult &&) = default; AggregationResult & operator = (AggregationResult &&) = default; - ~AggregationResult(); + ~AggregationResult() override; class Configure : public vespalib::ObjectOperation, public vespalib::ObjectPredicate { private: @@ -73,7 +73,7 @@ private: void onPrepare(bool preserveAccurateTypes) override { (void) preserveAccurateTypes; } bool onExecute() const override { return true; } - void prepare() { if (getExpression() != NULL) { prepare(&getExpression()->getResult(), false); } } + void prepare() { if (getExpression() != nullptr) { prepare(&getExpression()->getResult(), false); } } void prepare(const ResultNode * result, bool useForInit) { if (result) { onPrepare(*result, useForInit); } } virtual void onPrepare(const ResultNode & result, bool useForInit) = 0; virtual void onMerge(const AggregationResult & b) = 0; diff --git a/searchlib/src/vespa/searchlib/aggregation/averageaggregationresult.h b/searchlib/src/vespa/searchlib/aggregation/averageaggregationresult.h index 3d3395c63fc..96c6c34796a 100644 --- a/searchlib/src/vespa/searchlib/aggregation/averageaggregationresult.h +++ b/searchlib/src/vespa/searchlib/aggregation/averageaggregationresult.h @@ -12,7 +12,7 @@ public: using NumericResultNode = expression::NumericResultNode; DECLARE_AGGREGATIONRESULT(AverageAggregationResult); AverageAggregationResult() : _sum(), _count(0) {} - ~AverageAggregationResult(); + ~AverageAggregationResult() override; void visitMembers(vespalib::ObjectVisitor &visitor) const override; const NumericResultNode & getAverage() const; const NumericResultNode & getSum() const { return *_sum; } diff --git a/searchlib/src/vespa/searchlib/aggregation/sumaggregationresult.h b/searchlib/src/vespa/searchlib/aggregation/sumaggregationresult.h index 7309520c00d..aae77066817 100644 --- a/searchlib/src/vespa/searchlib/aggregation/sumaggregationresult.h +++ b/searchlib/src/vespa/searchlib/aggregation/sumaggregationresult.h @@ -2,24 +2,24 @@ #pragma once #include "aggregationresult.h" -#include <vespa/searchlib/expression/singleresultnode.h> +#include <vespa/searchlib/expression/numericresultnode.h> namespace search::aggregation { class SumAggregationResult : public AggregationResult { public: - using SingleResultNode = expression::SingleResultNode; + using NumericResultNode = expression::NumericResultNode; DECLARE_AGGREGATIONRESULT(SumAggregationResult); SumAggregationResult(); - SumAggregationResult(SingleResultNode::UP sum); - ~SumAggregationResult(); + SumAggregationResult(NumericResultNode::UP sum); + ~SumAggregationResult() override; void visitMembers(vespalib::ObjectVisitor &visitor) const override; - const SingleResultNode & getSum() const { return *_sum; } + const NumericResultNode & getSum() const { return *_sum; } private: const ResultNode & onGetRank() const override { return getSum(); } void onPrepare(const ResultNode & result, bool useForInit) override; - SingleResultNode::CP _sum; + NumericResultNode::CP _sum; }; } diff --git a/searchlib/src/vespa/searchlib/expression/numericresultnode.h b/searchlib/src/vespa/searchlib/expression/numericresultnode.h index f14454e9403..e4c7d11b2d5 100644 --- a/searchlib/src/vespa/searchlib/expression/numericresultnode.h +++ b/searchlib/src/vespa/searchlib/expression/numericresultnode.h @@ -1,10 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once -#include <vespa/searchlib/expression/singleresultnode.h> +#include "singleresultnode.h" -namespace search { -namespace expression { +namespace search::expression { class NumericResultNode : public SingleResultNode { @@ -19,5 +18,3 @@ public: }; } -} - diff --git a/searchlib/src/vespa/searchlib/expression/singleresultnode.h b/searchlib/src/vespa/searchlib/expression/singleresultnode.h index 2417c15934b..663f6f8954f 100644 --- a/searchlib/src/vespa/searchlib/expression/singleresultnode.h +++ b/searchlib/src/vespa/searchlib/expression/singleresultnode.h @@ -3,8 +3,7 @@ #include "resultnode.h" -namespace search { -namespace expression { +namespace search::expression { class SingleResultNode : public ResultNode { @@ -26,5 +25,3 @@ public: }; } -} - diff --git a/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp b/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp index 1f554cb9af7..6eff09b65ab 100644 --- a/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp +++ b/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp @@ -7,8 +7,7 @@ using search::tensor::ITensorAttribute; using vespalib::eval::Tensor; using vespalib::tensor::MutableDenseTensorView; -namespace search { -namespace features { +namespace search::features { DenseTensorAttributeExecutor:: DenseTensorAttributeExecutor(const ITensorAttribute *attribute) @@ -24,5 +23,4 @@ DenseTensorAttributeExecutor::execute(uint32_t docId) outputs().set_object(0, _tensorView); } -} // namespace features -} // namespace search +} diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp index dffa3bb28b5..1dcd3e35580 100644 --- a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp +++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp @@ -12,6 +12,8 @@ #include <type_traits> #include <vespa/log/log.h> +#include <vespa/eval/tensor/serialization/typed_binary_format.h> +#include <vespa/vespalib/objects/nbostream.h> LOG_SETUP(".features.dotproduct"); @@ -340,11 +342,21 @@ ArrayParam<T>::ArrayParam(const Property & prop) { parseVectors(prop, values, indexes); } +template <typename T> +ArrayParam<T>::ArrayParam(vespalib::nbostream & stream) { + vespalib::tensor::TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(stream, values); +} + +template <typename T> +ArrayParam<T>::~ArrayParam() = default; + + // Explicit instantiation since these are inspected by unit tests. // FIXME this feels a bit dirty, consider breaking up ArrayParam to remove dependencies // on templated vector parsing. This is why it's defined in this translation unit as it is. -template struct ArrayParam<int64_t>; +template ArrayParam<int64_t>::ArrayParam(const Property & prop); template struct ArrayParam<double>; +template struct ArrayParam<float>; } // namespace dotproduct @@ -609,43 +621,63 @@ fef::Anything::UP attemptParseArrayQueryVector(const IAttributeVector & attribut } // anon ns +const IAttributeVector * +DotProductBlueprint::upgradeIfNecessary(const IAttributeVector * attribute, const IQueryEnvironment & env) const { + if ((attribute->getCollectionType() == attribute::CollectionType::WSET) && + attribute->hasEnum() && + (attribute->isStringType() || attribute->isIntegerType())) + { + attribute = env.getAttributeContext().getAttributeStableEnum(getAttribute(env)); + } + return attribute; +} + void DotProductBlueprint::prepareSharedState(const IQueryEnvironment & env, IObjectStore & store) const { _attribute = env.getAttributeContext().getAttribute(getAttribute(env)); const IAttributeVector * attribute = _attribute; - if (attribute != nullptr) { - if ((attribute->getCollectionType() == attribute::CollectionType::WSET) && - attribute->hasEnum() && - (attribute->isStringType() || attribute->isIntegerType())) - { - attribute = env.getAttributeContext().getAttributeStableEnum(getAttribute(env)); + if (attribute == nullptr) return; + + attribute = upgradeIfNecessary(attribute, env); + fef::Anything::UP arguments; + if (attribute->getCollectionType() == attribute::CollectionType::ARRAY) { + Property tensorBlob = env.getProperties().lookup(getBaseName(), _queryVector, "tensor"); + if (attribute->isFloatingPointType() && tensorBlob.found() && !tensorBlob.get().empty()) { + const Property::Value & blob = tensorBlob.get(); + vespalib::nbostream stream(blob.data(), blob.size()); + if (attribute->getBasicType() == BasicType::FLOAT) { + arguments = std::make_unique<ArrayParam<float>>(stream); + } else { + arguments = std::make_unique<ArrayParam<double>>(stream); + } + } else { + Property prop = env.getProperties().lookup(getBaseName(), _queryVector); + if (prop.found() && !prop.get().empty()) { + arguments = attemptParseArrayQueryVector(*attribute, prop); + } } + } else if (attribute->getCollectionType() == attribute::CollectionType::WSET) { Property prop = env.getProperties().lookup(getBaseName(), _queryVector); if (prop.found() && !prop.get().empty()) { - fef::Anything::UP arguments; - if (attribute->getCollectionType() == attribute::CollectionType::WSET) { - if (attribute->isStringType() && attribute->hasEnum()) { + if (attribute->isStringType() && attribute->hasEnum()) { + dotproduct::wset::EnumVector vector(attribute); + WeightedSetParser::parse(prop.get(), vector); + } else if (attribute->isIntegerType()) { + if (attribute->hasEnum()) { dotproduct::wset::EnumVector vector(attribute); WeightedSetParser::parse(prop.get(), vector); - } else if (attribute->isIntegerType()) { - if (attribute->hasEnum()) { - dotproduct::wset::EnumVector vector(attribute); - WeightedSetParser::parse(prop.get(), vector); - } else { - dotproduct::wset::IntegerVector vector; - WeightedSetParser::parse(prop.get(), vector); - } + } else { + dotproduct::wset::IntegerVector vector; + WeightedSetParser::parse(prop.get(), vector); } - // TODO actually use the parsed output for wset operations! - } else if (attribute->getCollectionType() == attribute::CollectionType::ARRAY) { - arguments = attemptParseArrayQueryVector(*attribute, prop); - } - if (arguments.get()) { - store.add(getBaseName() + "." + _queryVector + "." + OBJECT, std::move(arguments)); } + // TODO actually use the parsed output for wset operations! } } + if (arguments) { + store.add(getBaseName() + "." + _queryVector + "." + OBJECT, std::move(arguments)); + } } FeatureExecutor & @@ -657,12 +689,7 @@ DotProductBlueprint::createExecutor(const IQueryEnvironment & env, vespalib::Sta getAttribute(env).c_str()); return stash.create<SingleZeroValueExecutor>(); } - if ((attribute->getCollectionType() == attribute::CollectionType::WSET) && - attribute->hasEnum() && - (attribute->isStringType() || attribute->isIntegerType())) - { - attribute = env.getAttributeContext().getAttributeStableEnum(getAttribute(env)); - } + attribute = upgradeIfNecessary(attribute, env); const fef::Anything * argument = env.getObjectStore().get(getBaseName() + "." + _queryVector + "." + OBJECT); if (argument != nullptr) { return createFromObject(attribute, *argument, stash); diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.h b/searchlib/src/vespa/searchlib/features/dotproductfeature.h index b6107a1a271..089066cb5f6 100644 --- a/searchlib/src/vespa/searchlib/features/dotproductfeature.h +++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.h @@ -10,6 +10,7 @@ #include <vespa/vespalib/stllike/hash_map.hpp> namespace search::fef { class Property; } +namespace vespalib { class nbostream; } namespace search::features { @@ -34,6 +35,9 @@ struct Converter<vespalib::string, const char *> { template <typename T> struct ArrayParam : public fef::Anything { ArrayParam(const fef::Property & prop); + ArrayParam(vespalib::nbostream & stream); + ArrayParam(std::vector<T> v) : values(std::move(v)) {} + ~ArrayParam() override; std::vector<T> values; std::vector<uint32_t> indexes; }; @@ -260,12 +264,14 @@ private: */ class DotProductBlueprint : public fef::Blueprint { private: + using IAttributeVector = attribute::IAttributeVector; vespalib::string _defaultAttribute; vespalib::string _queryVector; - mutable const attribute::IAttributeVector * _attribute; + mutable const IAttributeVector * _attribute; vespalib::string getAttribute(const fef::IQueryEnvironment & env) const; + const IAttributeVector * upgradeIfNecessary(const IAttributeVector * attribute, const fef::IQueryEnvironment & env) const; public: DotProductBlueprint(); diff --git a/searchlib/src/vespa/searchlib/features/queryfeature.cpp b/searchlib/src/vespa/searchlib/features/queryfeature.cpp index c5488581d29..eb7eb427283 100644 --- a/searchlib/src/vespa/searchlib/features/queryfeature.cpp +++ b/searchlib/src/vespa/searchlib/features/queryfeature.cpp @@ -3,9 +3,9 @@ #include "queryfeature.h" #include "utils.h" #include "valuefeature.h" +#include "constant_tensor_executor.h" #include <vespa/document/datatype/tensor_data_type.h> -#include <vespa/searchlib/features/constant_tensor_executor.h> #include <vespa/searchlib/fef/featureexecutor.h> #include <vespa/searchlib/fef/indexproperties.h> #include <vespa/searchlib/fef/properties.h> @@ -25,8 +25,7 @@ using document::TensorDataType; using vespalib::eval::ValueType; using search::fef::FeatureType; -namespace search { -namespace features { +namespace search::features { namespace { @@ -65,25 +64,21 @@ QueryBlueprint::QueryBlueprint() : { } -QueryBlueprint::~QueryBlueprint() -{ -} +QueryBlueprint::~QueryBlueprint() = default; void -QueryBlueprint::visitDumpFeatures(const IIndexEnvironment &, - IDumpFeatureVisitor &) const +QueryBlueprint::visitDumpFeatures(const IIndexEnvironment &, IDumpFeatureVisitor &) const { } Blueprint::UP QueryBlueprint::createInstance() const { - return Blueprint::UP(new QueryBlueprint()); + return std::make_unique<QueryBlueprint>(); } bool -QueryBlueprint::setup(const IIndexEnvironment &env, - const ParameterList ¶ms) +QueryBlueprint::setup(const IIndexEnvironment &env, const ParameterList ¶ms) { _key = params[0].getValue(); _key2 = "$"; @@ -107,19 +102,18 @@ QueryBlueprint::setup(const IIndexEnvironment &env, FeatureType output_type = _valueType.is_tensor() ? FeatureType::object(_valueType) : FeatureType::number(); - describeOutput("out", "The value looked up in query properties using the given key.", - output_type); + describeOutput("out", "The value looked up in query properties using the given key.", output_type); return true; } namespace { FeatureExecutor & -createTensorExecutor(const search::fef::IQueryEnvironment &env, +createTensorExecutor(const IQueryEnvironment &env, const vespalib::string &queryKey, const ValueType &valueType, vespalib::Stash &stash) { - search::fef::Property prop = env.getProperties().lookup(queryKey); + Property prop = env.getProperties().lookup(queryKey); if (prop.found() && !prop.get().empty()) { const vespalib::string &value = prop.get(); vespalib::nbostream stream(value.data(), value.size()); @@ -156,5 +150,4 @@ QueryBlueprint::createExecutor(const IQueryEnvironment &env, vespalib::Stash &st } } -} // namespace features -} // namespace search +} diff --git a/searchlib/src/vespa/searchlib/queryeval/fake_result.cpp b/searchlib/src/vespa/searchlib/queryeval/fake_result.cpp index 9786593637e..288c0f5d1d0 100644 --- a/searchlib/src/vespa/searchlib/queryeval/fake_result.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/fake_result.cpp @@ -16,6 +16,9 @@ FakeResult::FakeResult(const FakeResult &) = default; FakeResult::~FakeResult() = default; +FakeResult & +FakeResult::operator=(const FakeResult &) = default; + std::ostream &operator << (std::ostream &out, const FakeResult &result) { const std::vector<FakeResult::Document> &doc = result.inspect(); if (doc.size() == 0) { diff --git a/searchlib/src/vespa/searchlib/queryeval/fake_result.h b/searchlib/src/vespa/searchlib/queryeval/fake_result.h index ecb7dd377b9..ddf1fa61b63 100644 --- a/searchlib/src/vespa/searchlib/queryeval/fake_result.h +++ b/searchlib/src/vespa/searchlib/queryeval/fake_result.h @@ -48,6 +48,7 @@ public: FakeResult(); FakeResult(const FakeResult &); ~FakeResult(); + FakeResult &operator=(const FakeResult &); FakeResult &doc(uint32_t docId) { _documents.push_back(Document(docId)); diff --git a/security-utils/pom.xml b/security-utils/pom.xml index 10dec598915..f7704762250 100644 --- a/security-utils/pom.xml +++ b/security-utils/pom.xml @@ -31,6 +31,16 @@ <artifactId>jackson-databind</artifactId> <scope>compile</scope> </dependency> + <dependency> + <groupId>org.apache.httpcomponents</groupId> + <artifactId>httpclient</artifactId> + <scope>compile</scope> + </dependency> + <dependency> + <groupId>org.apache.httpcomponents</groupId> + <artifactId>httpcore</artifactId> + <scope>compile</scope> + </dependency> <!-- test scope --> <dependency> diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/VespaHttpClientBuilder.java b/security-utils/src/main/java/com/yahoo/security/tls/https/VespaHttpClientBuilder.java new file mode 100644 index 00000000000..9fa51fc36cb --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/https/VespaHttpClientBuilder.java @@ -0,0 +1,109 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.https; + +import com.yahoo.security.tls.MixedMode; +import com.yahoo.security.tls.TlsContext; +import com.yahoo.security.tls.TransportSecurityUtils; +import org.apache.http.HttpRequest; +import org.apache.http.HttpRequestInterceptor; +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.conn.HttpClientConnectionManager; +import org.apache.http.conn.ssl.NoopHostnameVerifier; +import org.apache.http.conn.ssl.SSLConnectionSocketFactory; +import org.apache.http.impl.client.HttpClientBuilder; +import org.apache.http.protocol.HttpContext; + +import javax.net.ssl.SSLParameters; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Http client builder for internal Vespa communications over http/https. + * + * Notes: + * - hostname verification is not enabled - CN/SAN verification is assumed to be handled by the underlying x509 trust manager. + * - custom connection managers must be configured through {@link #createBuilder(ConnectionManagerFactory)}. Do not call {@link HttpClientBuilder#setConnectionManager(HttpClientConnectionManager)}. + * + * @author bjorncs + */ +public class VespaHttpClientBuilder { + + private static final Logger log = Logger.getLogger(VespaHttpClientBuilder.class.getName()); + + public interface ConnectionManagerFactory { + HttpClientConnectionManager create(SSLConnectionSocketFactory sslSocketFactory); + } + + private VespaHttpClientBuilder() {} + + public static HttpClientBuilder create() { + return createBuilder(null); + } + + public static HttpClientBuilder create(ConnectionManagerFactory connectionManagerFactory) { + return createBuilder(connectionManagerFactory); + } + + private static HttpClientBuilder createBuilder(ConnectionManagerFactory connectionManagerFactory) { + var builder = HttpClientBuilder.create(); + addSslSocketFactory(builder, connectionManagerFactory); + addTlsAwareRequestInterceptor(builder); + return builder; + } + + private static void addSslSocketFactory(HttpClientBuilder builder, ConnectionManagerFactory connectionManagerFactory) { + TransportSecurityUtils.createTlsContext() + .ifPresent(tlsContext -> { + log.log(Level.FINE, "Adding ssl socket factory to client"); + SSLConnectionSocketFactory socketFactory = createSslSocketFactory(tlsContext); + if (connectionManagerFactory != null) { + builder.setConnectionManager(connectionManagerFactory.create(socketFactory)); + } else { + builder.setSSLSocketFactory(socketFactory); + } + }); + } + + private static void addTlsAwareRequestInterceptor(HttpClientBuilder builder) { + if (TransportSecurityUtils.isTransportSecurityEnabled() + && TransportSecurityUtils.getInsecureMixedMode() != MixedMode.PLAINTEXT_CLIENT_MIXED_SERVER) { + log.log(Level.FINE, "Adding request interceptor to client"); + builder.addInterceptorFirst(new HttpToHttpsRewritingRequestInterceptor()); + } + } + + private static SSLConnectionSocketFactory createSslSocketFactory(TlsContext tlsContext) { + SSLParameters parameters = tlsContext.parameters(); + return new SSLConnectionSocketFactory(tlsContext.context(), parameters.getProtocols(), parameters.getCipherSuites(), new NoopHostnameVerifier()); + } + + static class HttpToHttpsRewritingRequestInterceptor implements HttpRequestInterceptor { + @Override + public void process(HttpRequest request, HttpContext context) { + if (request instanceof HttpRequestBase) { + HttpRequestBase httpUriRequest = (HttpRequestBase) request; + httpUriRequest.setURI(rewriteUri(httpUriRequest.getURI())); + } else { + log.log(Level.FINE, () -> "Not a HttpRequestBase - skipping URI rewriting: " + request.getClass().getName()); + } + } + + private static URI rewriteUri(URI originalUri) { + if (!originalUri.getScheme().equals("http")) { + return originalUri; + } + int port = originalUri.getPort(); + int rewrittenPort = port != -1 ? port : 80; + try { + URI rewrittenUri = new URIBuilder(originalUri).setScheme("https").setPort(rewrittenPort).build(); + log.log(Level.FINE, () -> String.format("Uri rewritten from '%s' to '%s'", originalUri, rewrittenUri)); + return rewrittenUri; + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + } +} diff --git a/security-utils/src/test/java/com/yahoo/security/tls/https/VespaHttpClientBuilderTest.java b/security-utils/src/test/java/com/yahoo/security/tls/https/VespaHttpClientBuilderTest.java new file mode 100644 index 00000000000..10b8458359c --- /dev/null +++ b/security-utils/src/test/java/com/yahoo/security/tls/https/VespaHttpClientBuilderTest.java @@ -0,0 +1,39 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.https; + +import com.yahoo.security.tls.https.VespaHttpClientBuilder.HttpToHttpsRewritingRequestInterceptor; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.protocol.BasicHttpContext; +import org.junit.Test; + +import java.net.URI; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +/** + * @author bjorncs + */ +public class VespaHttpClientBuilderTest { + + @Test + public void request_interceptor_modifies_scheme_of_requests() { + verifyProcessedUriMatchesExpectedOutput("http://dummyhostname:8080/a/path/to/resource?query=value", + "https://dummyhostname:8080/a/path/to/resource?query=value"); + } + + @Test + public void request_interceptor_add_handles_implicit_http_port() { + verifyProcessedUriMatchesExpectedOutput("http://dummyhostname/a/path/to/resource?query=value", + "https://dummyhostname:80/a/path/to/resource?query=value"); + } + + private static void verifyProcessedUriMatchesExpectedOutput(String inputUri, String expectedOutputUri) { + var interceptor = new HttpToHttpsRewritingRequestInterceptor(); + HttpGet request = new HttpGet(inputUri); + interceptor.process(request, new BasicHttpContext()); + URI modifiedUri = request.getURI(); + URI expectedUri = URI.create(expectedOutputUri); + assertThat(modifiedUri).isEqualTo(expectedUri); + } + +}
\ No newline at end of file diff --git a/vespa-documentgen-plugin/etc/complex/music3.sd b/vespa-documentgen-plugin/etc/complex/music3.sd index 65f37029d04..45ce11fd581 100644 --- a/vespa-documentgen-plugin/etc/complex/music3.sd +++ b/vespa-documentgen-plugin/etc/complex/music3.sd @@ -4,5 +4,8 @@ search music3 { field mu3 type string { } + field pos type position { + + } } } diff --git a/vespa-documentgen-plugin/src/main/java/com/yahoo/vespa/DocumentGenMojo.java b/vespa-documentgen-plugin/src/main/java/com/yahoo/vespa/DocumentGenMojo.java index 7e73d6b5915..bc34a4ac3df 100644 --- a/vespa-documentgen-plugin/src/main/java/com/yahoo/vespa/DocumentGenMojo.java +++ b/vespa-documentgen-plugin/src/main/java/com/yahoo/vespa/DocumentGenMojo.java @@ -3,11 +3,14 @@ package com.yahoo.vespa; import com.yahoo.collections.Pair; import com.yahoo.document.ArrayDataType; +import com.yahoo.document.CollectionDataType; import com.yahoo.document.DataType; import com.yahoo.document.Field; import com.yahoo.document.MapDataType; +import com.yahoo.document.PositionDataType; import com.yahoo.document.ReferenceDataType; import com.yahoo.document.StructDataType; +import com.yahoo.document.StructuredDataType; import com.yahoo.document.TensorDataType; import com.yahoo.document.WeightedSetDataType; import com.yahoo.document.annotation.AnnotationReferenceDataType; @@ -18,7 +21,6 @@ import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; import org.apache.maven.plugin.AbstractMojo; -import org.apache.maven.plugin.MojoFailureException; import org.apache.maven.plugins.annotations.Component; import org.apache.maven.plugins.annotations.LifecyclePhase; import org.apache.maven.plugins.annotations.Mojo; @@ -31,6 +33,7 @@ import java.io.FilenameFilter; import java.io.IOException; import java.io.Writer; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Date; import java.util.HashMap; @@ -468,6 +471,9 @@ public class DocumentGenMojo extends AbstractMojo { exportHashCode(allUniqueFields, out, 1, "(getDataType() != null ? getDataType().hashCode() : 0) + getId().hashCode()"); exportEquals(className, allUniqueFields, out, 1); Set<DataType> exportedStructs = exportStructTypes(docType.getTypes(), out, 1, null); + if (hasAnyPositionField(allUniqueFields)) { + exportedStructs = exportStructTypes(Arrays.asList(PositionDataType.INSTANCE), out, 1, exportedStructs); + } docTypes.put(docType.getName(), packageName+"."+className); for (DataType exportedStruct : exportedStructs) { structTypes.put(exportedStruct.getName(), packageName+"."+className+"."+className(exportedStruct.getName())); @@ -475,6 +481,25 @@ public class DocumentGenMojo extends AbstractMojo { out.write("}\n"); } + private static boolean hasAnyPostionDataType(DataType dt) { + if (dt instanceof CollectionDataType) { + return hasAnyPostionDataType(((CollectionDataType)dt).getNestedType()); + } else if (dt instanceof StructuredDataType) { + return hasAnyPositionField(((StructuredDataType)dt).getFields()); + } else { + return PositionDataType.INSTANCE.equals(dt); + } + } + + private static boolean hasAnyPositionField(Collection<Field> fields) { + for (Field f : fields) { + if (hasAnyPostionDataType(f.getDataType())) { + return true; + } + } + return true; + } + private Collection<Field> getAllUniqueFields(Boolean multipleInheritance, Collection<Field> allFields) { if (multipleInheritance) { Map<String, Field> seen = new HashMap<>(); @@ -732,7 +757,8 @@ public class DocumentGenMojo extends AbstractMojo { ind(ind)+" * Input struct type: "+structType.getName()+"\n" + ind(ind)+" * Date: "+new Date()+"\n" + ind(ind)+" */\n" + - ind(ind)+"@com.yahoo.document.Generated public static class "+structClassName+" extends com.yahoo.document.datatypes.Struct {\n\n" + + ind(ind)+"@com.yahoo.document.Generated\n" + + ind(ind) + "public static class "+structClassName+" extends com.yahoo.document.datatypes.Struct {\n\n" + ind(ind+1)+"/** The type of this.*/\n" + ind(ind+1)+"public static final com.yahoo.document.StructDataType type = getStructType();\n\n"); out.write(ind(ind+1)+"public "+structClassName+"() {\n" + diff --git a/vespa-documentgen-plugin/src/test/java/com/yahoo/vespa/DocumentGenTest.java b/vespa-documentgen-plugin/src/test/java/com/yahoo/vespa/DocumentGenTest.java index b21f38c586a..c195e116bf0 100644 --- a/vespa-documentgen-plugin/src/test/java/com/yahoo/vespa/DocumentGenTest.java +++ b/vespa-documentgen-plugin/src/test/java/com/yahoo/vespa/DocumentGenTest.java @@ -5,8 +5,6 @@ import com.yahoo.document.DataType; import com.yahoo.document.StructDataType; import com.yahoo.document.WeightedSetDataType; import com.yahoo.searchdefinition.Search; -import org.apache.maven.plugin.MojoExecutionException; -import org.apache.maven.plugin.MojoFailureException; import org.junit.Test; import java.io.File; @@ -19,7 +17,7 @@ import static org.junit.Assert.fail; public class DocumentGenTest { @Test - public void testMusic() throws MojoExecutionException, MojoFailureException { + public void testMusic() { DocumentGenMojo mojo = new DocumentGenMojo(); mojo.execute(new File("etc/music/"), new File("target/generated-test-sources/vespa-documentgen-plugin/"), "com.yahoo.vespa.document"); Map<String, Search> searches = mojo.getSearches(); @@ -28,19 +26,21 @@ public class DocumentGenTest { } @Test - public void testComplex() throws MojoFailureException { + public void testComplex() { DocumentGenMojo mojo = new DocumentGenMojo(); mojo.execute(new File("etc/complex/"), new File("target/generated-test-sources/vespa-documentgen-plugin/"), "com.yahoo.vespa.document"); Map<String, Search> searches = mojo.getSearches(); assertEquals(searches.get("video").getDocument("video").getField("weight").getDataType(), DataType.FLOAT); assertEquals(searches.get("book").getDocument("book").getField("sw1").getDataType(), DataType.FLOAT); + assertTrue(searches.get("music3").getDocument("music3").getField("pos").getDataType() instanceof StructDataType); + assertEquals(searches.get("music3").getDocument("music3").getField("pos").getDataType().getName(), "position"); assertTrue(searches.get("book").getDocument("book").getField("mystruct").getDataType() instanceof StructDataType); assertTrue(searches.get("book").getDocument("book").getField("mywsfloat").getDataType() instanceof WeightedSetDataType); assertTrue(((WeightedSetDataType)(searches.get("book").getDocument("book").getField("mywsfloat").getDataType())).getNestedType() == DataType.FLOAT); } @Test - public void testLocalApp() throws MojoFailureException { + public void testLocalApp() { DocumentGenMojo mojo = new DocumentGenMojo(); mojo.execute(new File("etc/localapp/"), new File("target/generated-test-sources/vespa-documentgen-plugin/"), "com.yahoo.vespa.document"); Map<String, Search> searches = mojo.getSearches(); @@ -51,7 +51,7 @@ public class DocumentGenTest { } @Test - public void testEmptyPkgNameForbidden() throws MojoFailureException { + public void testEmptyPkgNameForbidden() { DocumentGenMojo mojo = new DocumentGenMojo(); try { mojo.execute(new File("etc/localapp/"), new File("target/generated-test-sources/vespa-documentgen-plugin/"), ""); diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 239efa0f89c..43388e4e18d 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -947,7 +947,7 @@ "public java.lang.String toString()", "public boolean equals(java.lang.Object)", "public long denseSubspaceSize()", - "public static com.yahoo.tensor.TensorType createPartialType(java.util.List)" + "public static com.yahoo.tensor.TensorType createPartialType(com.yahoo.tensor.TensorType$Value, java.util.List)" ], "fields": [] }, @@ -1162,11 +1162,10 @@ ], "methods": [ "public void <init>()", - "public void <init>(com.yahoo.tensor.TensorType$ValueType)", + "public void <init>(com.yahoo.tensor.TensorType$Value)", "public varargs void <init>(com.yahoo.tensor.TensorType[])", - "public varargs void <init>(com.yahoo.tensor.TensorType$ValueType, com.yahoo.tensor.TensorType[])", "public void <init>(java.lang.Iterable)", - "public void <init>(com.yahoo.tensor.TensorType$ValueType, java.lang.Iterable)", + "public void <init>(com.yahoo.tensor.TensorType$Value, java.lang.Iterable)", "public int rank()", "public com.yahoo.tensor.TensorType$Builder set(com.yahoo.tensor.TensorType$Dimension)", "public com.yahoo.tensor.TensorType$Builder indexed(java.lang.String, long)", @@ -1270,7 +1269,7 @@ ], "fields": [] }, - "com.yahoo.tensor.TensorType$ValueType": { + "com.yahoo.tensor.TensorType$Value": { "superClass": "java.lang.Enum", "interfaces": [], "attributes": [ @@ -1279,12 +1278,14 @@ "enum" ], "methods": [ - "public static com.yahoo.tensor.TensorType$ValueType[] values()", - "public static com.yahoo.tensor.TensorType$ValueType valueOf(java.lang.String)" + "public static com.yahoo.tensor.TensorType$Value[] values()", + "public static com.yahoo.tensor.TensorType$Value valueOf(java.lang.String)", + "public static com.yahoo.tensor.TensorType$Value largestOf(java.util.List)", + "public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)" ], "fields": [ - "public static final enum com.yahoo.tensor.TensorType$ValueType DOUBLE", - "public static final enum com.yahoo.tensor.TensorType$ValueType FLOAT" + "public static final enum com.yahoo.tensor.TensorType$Value DOUBLE", + "public static final enum com.yahoo.tensor.TensorType$Value FLOAT" ] }, "com.yahoo.tensor.TensorType": { @@ -1294,9 +1295,8 @@ "public" ], "methods": [ - "public final com.yahoo.tensor.TensorType$ValueType valueType()", - "public final com.yahoo.tensor.TensorType valueType(com.yahoo.tensor.TensorType$ValueType)", "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", + "public com.yahoo.tensor.TensorType$Value valueType()", "public int rank()", "public java.util.List dimensions()", "public java.util.Set dimensionNames()", @@ -1325,7 +1325,7 @@ "methods": [ "public void <init>()", "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", - "public static java.util.List dimensionsFromSpec(java.lang.String)" + "public static com.yahoo.tensor.TensorType$Value toValueType(java.lang.String)" ], "fields": [] }, diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 08878edeb83..c06cb2a0986 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -319,7 +319,7 @@ public class MixedTensor implements Tensor { } public TensorType createBoundType() { - TensorType.Builder typeBuilder = new TensorType.Builder(); + TensorType.Builder typeBuilder = new TensorType.Builder(type().valueType()); for (int i = 0; i < type.dimensions().size(); ++i) { TensorType.Dimension dimension = type.dimensions().get(i); if (!dimension.isIndexed()) { @@ -355,8 +355,8 @@ public class MixedTensor implements Tensor { this.type = type; this.mappedDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList()); this.indexedDimensions = type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList()); - this.sparseType = createPartialType(mappedDimensions); - this.denseType = createPartialType(indexedDimensions); + this.sparseType = createPartialType(type.valueType(), mappedDimensions); + this.denseType = createPartialType(type.valueType(), indexedDimensions); } public long indexOf(TensorAddress address) { @@ -476,8 +476,8 @@ public class MixedTensor implements Tensor { } - public static TensorType createPartialType(List<TensorType.Dimension> dimensions) { - TensorType.Builder builder = new TensorType.Builder(); + public static TensorType createPartialType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) { + TensorType.Builder builder = new TensorType.Builder(valueType); for (TensorType.Dimension dimension : dimensions) { builder.set(dimension); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index fa32d385004..45a9992c9ad 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -11,14 +11,14 @@ class TensorParser { static Tensor tensorFrom(String tensorString, Optional<TensorType> type) { tensorString = tensorString.trim(); try { - if (tensorString.startsWith("tensor(")) { + if (tensorString.startsWith("tensor")) { int colonIndex = tensorString.indexOf(':'); String typeString = tensorString.substring(0, colonIndex); String valueString = tensorString.substring(colonIndex + 1); TensorType typeFromString = TensorTypeParser.fromSpec(typeString); if (type.isPresent() && ! type.get().equals(typeFromString)) throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " + - "passed type " + type); + "passed type " + type.get()); return tensorFromValueString(valueString, typeFromString); } else if (tensorString.startsWith("{")) { @@ -48,7 +48,7 @@ class TensorParser { addressBody = addressBody.substring(1); // remove key start if (addressBody.isEmpty()) return TensorType.empty; // Empty key - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(TensorType.Value.DOUBLE); for (String elementString : addressBody.split(",")) { String[] pair = elementString.split(":"); if (pair.length != 2) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 036f5e3ee5d..df78f3dfc3a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -4,6 +4,7 @@ package com.yahoo.tensor; import com.google.common.collect.ImmutableList; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.LinkedHashMap; @@ -24,25 +25,40 @@ import java.util.stream.Collectors; */ public class TensorType { - public enum ValueType { DOUBLE, FLOAT}; + /** The permissible cell value types. Default is double. */ + public enum Value { - /** The empty tensor type - which is the same as a double */ - public static final TensorType empty = new TensorType(ValueType.DOUBLE, Collections.emptyList()); + // Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below + DOUBLE, FLOAT; - private ValueType valueType; + public static Value largestOf(List<Value> values) { + if (values.isEmpty()) return Value.DOUBLE; // Default + Value largest = null; + for (Value value : values) { + if (largest == null) + largest = value; + else + largest = largestOf(largest, value); + } + return largest; + } - public final ValueType valueType() { return valueType; } + public static Value largestOf(Value value1, Value value2) { + if (value1 == DOUBLE || value2 == DOUBLE) return DOUBLE; + return FLOAT; + } - //TODO Remove once value type is wired in were it should. - public final TensorType valueType(ValueType valueType) { - this.valueType = valueType; - return this; - } + }; + + /** The empty tensor type - which is the same as a double */ + public static final TensorType empty = new TensorType(Value.DOUBLE, Collections.emptyList()); + + private final Value valueType; /** Sorted list of the dimensions of this */ private final ImmutableList<Dimension> dimensions; - private TensorType(ValueType valueType, Collection<Dimension> dimensions) { + private TensorType(Value valueType, Collection<Dimension> dimensions) { this.valueType = valueType; List<Dimension> dimensionList = new ArrayList<>(dimensions); Collections.sort(dimensionList); @@ -64,6 +80,9 @@ public class TensorType { return TensorTypeParser.fromSpec(specString); } + /** Returns the numeric type of the cell values of this */ + public Value valueType() { return valueType; } + /** Returns the number of dimensions of this: dimensions().size() */ public int rank() { return dimensions.size(); } @@ -149,10 +168,14 @@ public class TensorType { } @Override - public boolean equals(Object other) { - if (this == other) return true; - if (other == null || getClass() != other.getClass()) return false; - return dimensions.equals(((TensorType)other).dimensions); + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + TensorType other = (TensorType)o; + if ( this.valueType != other.valueType) return false; + if ( ! this.dimensions.equals(other.dimensions)) return false; + return true; } /** Returns whether the given type has the same dimension names as this */ @@ -173,7 +196,7 @@ public class TensorType { if (this.equals(other)) return Optional.of(this); // shortcut if (this.dimensions.size() != other.dimensions.size()) return Optional.empty(); - Builder b = new Builder(); + Builder b = new Builder(TensorType.Value.largestOf(valueType, other.valueType)); for (int i = 0; i < dimensions.size(); i++) { Dimension thisDim = this.dimensions().get(i); Dimension otherDim = other.dimensions().get(i); @@ -386,14 +409,14 @@ public class TensorType { private final Map<String, Dimension> dimensions = new LinkedHashMap<>(); - private final ValueType valueType; + private final Value valueType; - /** Creates an empty builder with cells of type double*/ + /** Creates an empty builder with cells of type double */ public Builder() { - this(ValueType.DOUBLE); + this(Value.DOUBLE); } - public Builder(ValueType valueType) { + public Builder(Value valueType) { this.valueType = valueType; } @@ -403,23 +426,22 @@ public class TensorType { * If the same dimension is indexed with different size restrictions the largest size will be used. * If it is size restricted in one argument but not the other it will not be size restricted. * If it is indexed in one and mapped in the other it will become mapped. + * + * The value type will be the largest of the value types of the input types */ public Builder(TensorType ... types) { - this(ValueType.DOUBLE, types); - } - public Builder(ValueType valueType, TensorType ... types) { - this.valueType = valueType; + this.valueType = TensorType.Value.largestOf(Arrays.stream(types).map(type -> type.valueType()).collect(Collectors.toList())); for (TensorType type : types) addDimensionsOf(type); } - /** - * Creates a builder from the given dimensions. - */ + /** Creates a builder from the given dimensions, having double as the value type */ public Builder(Iterable<Dimension> dimensions) { - this(ValueType.DOUBLE, dimensions); + this(Value.DOUBLE, dimensions); } - public Builder(ValueType valueType, Iterable<Dimension> dimensions) { + + /** Creates a builder from the given value type and dimensions */ + public Builder(Value valueType, Iterable<Dimension> dimensions) { this.valueType = valueType; for (TensorType.Dimension dimension : dimensions) { dimension(dimension); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java index 32ad6171e57..d5f77be0dd0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java @@ -2,8 +2,10 @@ package com.yahoo.tensor; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -11,26 +13,36 @@ import java.util.regex.Pattern; * Class for parsing a tensor type spec. * * @author geirst + * @author bratseth */ public class TensorTypeParser { - private final static String START_STRING = "tensor("; + private final static String START_STRING = "tensor"; private final static String END_STRING = ")"; private static final Pattern indexedPattern = Pattern.compile("(\\w+)\\[(\\d*)\\]"); private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}"); public static TensorType fromSpec(String specString) { - return new TensorType.Builder(dimensionsFromSpec(specString)).build(); - } + if ( ! specString.startsWith(START_STRING) || ! specString.endsWith(END_STRING)) + throw formatException(specString); + String specBody = specString.substring(START_STRING.length(), specString.length() - END_STRING.length()); - public static List<TensorType.Dimension> dimensionsFromSpec(String specString) { - if ( ! specString.startsWith(START_STRING) || !specString.endsWith(END_STRING)) { - throw new IllegalArgumentException("Tensor type spec must start with '" + START_STRING + "'" + - " and end with '" + END_STRING + "', but was '" + specString + "'"); + String dimensionsSpec; + TensorType.Value valueType; + if (specBody.startsWith("(")) { + valueType = TensorType.Value.DOUBLE; // no value type spec: Use default + dimensionsSpec = specBody.substring(1); + } + else { + int parenthesisIndex = specBody.indexOf("("); + if (parenthesisIndex < 0) + throw formatException(specString); + valueType = parseValueTypeSpec(specBody.substring(0, parenthesisIndex), specString); + dimensionsSpec = specBody.substring(parenthesisIndex + 1); } - String dimensionsSpec = specString.substring(START_STRING.length(), specString.length() - END_STRING.length()); - if (dimensionsSpec.isEmpty()) return Collections.emptyList(); + + if (dimensionsSpec.isEmpty()) return new TensorType.Builder(valueType, Collections.emptyList()).build(); List<TensorType.Dimension> dimensions = new ArrayList<>(); for (String element : dimensionsSpec.split(",")) { @@ -38,10 +50,30 @@ public class TensorTypeParser { boolean success = tryParseIndexedDimension(trimmedElement, dimensions) || tryParseMappedDimension(trimmedElement, dimensions); if ( ! success) - throw new IllegalArgumentException("Failed parsing element '" + element + - "' in type spec '" + specString + "'"); + throw formatException(specString, "Dimension '" + element + "' is on the wrong format"); + } + return new TensorType.Builder(valueType, dimensions).build(); + } + + public static TensorType.Value toValueType(String valueTypeString) { + switch (valueTypeString) { + case "double" : return TensorType.Value.DOUBLE; + case "float" : return TensorType.Value.FLOAT; + default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" + + " but was '" + valueTypeString + "'"); + } + } + + private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) { + if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">")) + throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>")); + + try { + return toValueType(valueTypeSpec.substring(1, valueTypeSpec.length() - 1)); + } + catch (IllegalArgumentException e) { + throw formatException(fullSpecString, e.getMessage()); } - return dimensions; } private static boolean tryParseIndexedDimension(String element, List<TensorType.Dimension> dimensions) { @@ -69,5 +101,21 @@ public class TensorTypeParser { return false; } + + private static IllegalArgumentException formatException(String spec) { + return formatException(spec, Optional.empty()); + } + + private static IllegalArgumentException formatException(String spec, String errorDetail) { + return formatException(spec, Optional.of(errorDetail)); + } + + private static IllegalArgumentException formatException(String spec, Optional<String> errorDetail) { + throw new IllegalArgumentException("A tensor type spec must be on the form " + + "tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was '" + spec + "'. " + + errorDetail.map(s -> s + ". ").orElse("") + + "Examples: tensor(x[]), tensor<float>(name{}, x[10])"); + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 91ab4f9d046..a48ac19fbff 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -73,8 +73,8 @@ public class Concat extends PrimitiveTensorFunction { MutableLong concatSize = new MutableLong(0); a.sizeOfDimension(dimension).ifPresent(concatSize::add); b.sizeOfDimension(dimension).ifPresent(concatSize::add); - builder.set(TensorType.Dimension.indexed(dimension, concatSize.get())); - */ + builder.set(TensorType.Dimension.indexed(dimension, concatSize.get())); + */ } return builder.build(); } @@ -141,7 +141,11 @@ public class Concat extends PrimitiveTensorFunction { if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) throw new IllegalArgumentException("Concat requires an indexed tensor, " + "but got a tensor with type " + tensor.type()); - Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build(); + Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(tensor.type().valueType()) + .indexed(dimensionName, 1) + .build()) + .cell(1,0) + .build(); return tensor.multiply(unitTensor); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 62ee471fcf4..062e0d92e80 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -386,13 +386,12 @@ public class Join extends PrimitiveTensorFunction { return true; } - /** - * Returns common dimension of a and b as a new tensor type - */ + /** Returns common dimension of a and b as a new tensor type */ private static TensorType commonDimensions(Tensor a, Tensor b) { - TensorType.Builder typeBuilder = new TensorType.Builder(); TensorType aType = a.type(); TensorType bType = b.type(); + TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.Value.largestOf(aType.valueType(), + bType.valueType())); for (int i = 0; i < aType.dimensions().size(); ++i) { TensorType.Dimension aDim = aType.dimensions().get(i); for (int j = 0; j < bType.dimensions().size(); ++j) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 54d7710c9dc..017dc3920e6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -61,8 +61,8 @@ public class Reduce extends PrimitiveTensorFunction { } public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) { - if (reduceDimensions.isEmpty()) return TensorType.empty; // means reduce all - TensorType.Builder b = new TensorType.Builder(); + TensorType.Builder b = new TensorType.Builder(inputType.valueType()); + if (reduceDimensions.isEmpty()) return b.build(); // means reduce all for (TensorType.Dimension dimension : inputType.dimensions()) { if ( ! reduceDimensions.contains(dimension.name())) b.dimension(dimension); @@ -109,8 +109,8 @@ public class Reduce extends PrimitiveTensorFunction { } private static TensorType type(TensorType argumentType, List<String> dimensions) { - if (dimensions.isEmpty()) return TensorType.empty; // means reduce all - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(argumentType.valueType()); + if (dimensions.isEmpty()) return builder.build(); // means reduce all for (TensorType.Dimension dimension : argumentType.dimensions()) if ( ! dimensions.contains(dimension.name())) // keep builder.dimension(dimension); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index b268e33b418..db950e6c8b9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -268,7 +268,8 @@ public class ReduceJoin extends CompositeTensorFunction { } private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) { - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(TensorType.Value.largestOf(a.type().valueType(), + b.type().valueType())); for (TensorType.Dimension aDim : a.type().dimensions()) { for (TensorType.Dimension bDim : b.type().dimensions()) { if (aDim.name().equals(bDim.name())) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index e18af235d59..5694684956e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -75,7 +75,7 @@ public class Rename extends PrimitiveTensorFunction { } private TensorType type(TensorType type) { - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(type.valueType()); for (TensorType.Dimension dimension : type.dimensions()) builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); return builder.build(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index 500c436516f..ecd4f7d1965 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -43,7 +43,7 @@ public class DenseBinaryFormat implements BinaryFormat { encodeCells(buffer, tensor); } - private void encodeValueType(GrowableByteBuffer buffer, TensorType.ValueType valueType) { + private void encodeValueType(GrowableByteBuffer buffer, TensorType.Value valueType) { switch (valueType) { case DOUBLE: if (encodeType != EncodeType.DOUBLE_IS_DEFAULT) { @@ -100,7 +100,7 @@ public class DenseBinaryFormat implements BinaryFormat { sizes = sizesFromType(serializedType); } else { - type = decodeType(buffer, TensorType.ValueType.DOUBLE); + type = decodeType(buffer, TensorType.Value.DOUBLE); sizes = sizesFromType(type); } Tensor.Builder builder = Tensor.Builder.of(type, sizes); @@ -108,16 +108,16 @@ public class DenseBinaryFormat implements BinaryFormat { return builder.build(); } - private TensorType decodeType(GrowableByteBuffer buffer, TensorType.ValueType valueType) { - TensorType.ValueType serializedValueType = TensorType.ValueType.DOUBLE; - if ((valueType != TensorType.ValueType.DOUBLE) || (encodeType != EncodeType.DOUBLE_IS_DEFAULT)) { + private TensorType decodeType(GrowableByteBuffer buffer, TensorType.Value valueType) { + TensorType.Value serializedValueType = TensorType.Value.DOUBLE; + if ((valueType != TensorType.Value.DOUBLE) || (encodeType != EncodeType.DOUBLE_IS_DEFAULT)) { int type = buffer.getInt1_4Bytes(); switch (type) { case DOUBLE_VALUE_TYPE: - serializedValueType = TensorType.ValueType.DOUBLE; + serializedValueType = TensorType.Value.DOUBLE; break; case FLOAT_VALUE_TYPE: - serializedValueType = TensorType.ValueType.FLOAT; + serializedValueType = TensorType.Value.FLOAT; break; default: throw new IllegalArgumentException("Received tensor value type '" + serializedValueType + "'. Only 0(double), or 1(float) are legal."); @@ -141,7 +141,7 @@ public class DenseBinaryFormat implements BinaryFormat { return builder.build(); } - private void decodeCells(TensorType.ValueType valueType, DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { + private void decodeCells(TensorType.Value valueType, DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { switch (valueType) { case DOUBLE: decodeCellsAsDouble(sizes, buffer, builder); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java index acaeb3ef5ba..284dfea2141 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java @@ -78,7 +78,7 @@ class MixedBinaryFormat implements BinaryFormat { TensorType serializedType = decodeType(buffer); if ( ! serializedType.isAssignableTo(type)) throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + - " cannot be assigned to type " + type); + " cannot be assigned to type " + type); } else { type = decodeType(buffer); @@ -103,7 +103,7 @@ class MixedBinaryFormat implements BinaryFormat { private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type) { List<TensorType.Dimension> sparseDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList()); - TensorType sparseType = MixedTensor.createPartialType(sparseDimensions); + TensorType sparseType = MixedTensor.createPartialType(type.valueType(), sparseDimensions); long denseSubspaceSize = builder.denseSubspaceSize(); int numBlocks = 1; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index f2c5d4e2bd8..9b298f1dffb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -23,7 +23,9 @@ public class TypedBinaryFormat { private static final int SPARSE_BINARY_FORMAT_TYPE = 1; private static final int DENSE_BINARY_FORMAT_TYPE = 2; private static final int MIXED_BINARY_FORMAT_TYPE = 3; - private static final int TYPED_DENSE_BINARY_FORMAT_TYPE = 4; + private static final int SPARSE_BINARY_FORMAT_WITH_CELLTYPE = 5; + private static final int DENSE_BINARY_FORMAT_WITH_CELLTYPE = 6; + private static final int MIXED_BINARY_FORMAT_WITH_CELLTYPE = 7; public static byte[] encode(Tensor tensor) { GrowableByteBuffer buffer = new GrowableByteBuffer(); @@ -38,7 +40,7 @@ public class TypedBinaryFormat { new DenseBinaryFormat(DenseBinaryFormat.EncodeType.DOUBLE_IS_DEFAULT).encode(buffer, tensor); break; default: - buffer.putInt1_4Bytes(TYPED_DENSE_BINARY_FORMAT_TYPE); + buffer.putInt1_4Bytes(DENSE_BINARY_FORMAT_WITH_CELLTYPE); new DenseBinaryFormat(DenseBinaryFormat.EncodeType.NO_DEFAULT).encode(buffer, tensor); break; } @@ -67,7 +69,7 @@ public class TypedBinaryFormat { case MIXED_BINARY_FORMAT_TYPE: return new MixedBinaryFormat().decode(type, buffer); case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat().decode(type, buffer); case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat(DenseBinaryFormat.EncodeType.DOUBLE_IS_DEFAULT).decode(type, buffer); - case TYPED_DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat(DenseBinaryFormat.EncodeType.NO_DEFAULT).decode(type, buffer); + case DENSE_BINARY_FORMAT_WITH_CELLTYPE: return new DenseBinaryFormat(DenseBinaryFormat.EncodeType.NO_DEFAULT).decode(type, buffer); default: throw new IllegalArgumentException("Binary format type " + formatType + " is unknown"); } } diff --git a/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java b/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java index 9602bdb8d94..f6fed9d33ed 100644 --- a/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java @@ -69,16 +69,6 @@ public class BoundingBoxParserTestCase { all1234(parser); } - /** - * Tests various legal inputs and print the output - */ - @Test - public void testPrint() { - String here = "n=63.418417 E=10.433033 S=37.7 W=-122.02"; - parser = new BoundingBoxParser(here); - System.out.println(here+" -> "+parser); - } - @Test public void testGeoPlanetExample() { /* example XML: diff --git a/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java index e8ceab44c78..7cf4bddaa01 100644 --- a/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java @@ -57,7 +57,6 @@ public class BinaryFormatTestCase { @Test public void testZigZagConversion() { - System.out.println("test zigzag conversion"); assertThat(encode_zigzag(0), is((long)0)); assertThat(decode_zigzag(encode_zigzag(0)), is(0L)); @@ -88,7 +87,6 @@ public class BinaryFormatTestCase { @Test public void testDoubleConversion() { - System.out.println("test double conversion"); assertThat(encode_double(0.0), is(0L)); assertThat(decode_double(encode_double(0.0)), is(0.0)); @@ -116,7 +114,6 @@ public class BinaryFormatTestCase { @Test public void testTypeAndMetaMangling() { - System.out.println("test type and meta mangling"); for (byte type = 0; type < TYPE_LIMIT; ++type) { for (int meta = 0; meta < META_LIMIT; ++meta) { byte mangled = encode_type_and_meta(type, meta); @@ -126,10 +123,8 @@ public class BinaryFormatTestCase { } } - // was testCmprUlong @Test - public void testCmprLong() { - System.out.println("test compressed long"); + public void testCompressedLong() { { long value = 0; byte[] wanted = { 0 }; @@ -217,11 +212,8 @@ public class BinaryFormatTestCase { // testWriteBytes -> buffered IO test // testReadByte -> buffered IO test // testReadBytes -> buffered IO test - @Test - public void testTypeAndSize() { - System.out.println("test type and size conversion"); - + public void testTypeAndSizeConversion() { for (byte type = 0; type < TYPE_LIMIT; ++type) { for (long size = 0; size < 500; ++size) { BufferedOutput expect = new BufferedOutput(); @@ -271,8 +263,7 @@ public class BinaryFormatTestCase { } @Test - public void testTypeAndBytes() { - System.out.println("test encoding and decoding of type and bytes"); + public void testEncodingAndDecodingOfTypeAndBytes() { for (byte type = 0; type < TYPE_LIMIT; ++type) { for (int n = 0; n < MAX_NUM_SIZE; ++n) { for (int pre = 0; (pre == 0) || (pre < n); ++pre) { @@ -307,9 +298,7 @@ public class BinaryFormatTestCase { } @Test - public void testEmpty() { - System.out.println("test encoding empty slime"); - + public void testEncodingEmptySlime() { Slime slime = new Slime(); BufferedOutput expect = new BufferedOutput(); expect.put((byte)0); // num symbols @@ -321,8 +310,7 @@ public class BinaryFormatTestCase { } @Test - public void testBasic() { - System.out.println("test encoding slime holding a single basic value"); + public void testEncodingSlimeHoldingASingleBasicValue() { { Slime slime = new Slime(); slime.setBool(false); @@ -427,8 +415,7 @@ public class BinaryFormatTestCase { } @Test - public void testArray() { - System.out.println("test encoding slime holding an array of various basic values"); + public void testEncodingSlimeArray() { Slime slime = new Slime(); Cursor c = slime.setArray(); byte[] data = { 'd', 'a', 't', 'a' }; @@ -452,8 +439,7 @@ public class BinaryFormatTestCase { } @Test - public void testObject() { - System.out.println("test encoding slime holding an object of various basic values"); + public void testEncodingSlimeObject() { Slime slime = new Slime(); Cursor c = slime.setObject(); byte[] data = { 'd', 'a', 't', 'a' }; @@ -478,8 +464,7 @@ public class BinaryFormatTestCase { } @Test - public void testNesting() { - System.out.println("test encoding slime holding a more complex structure"); + public void testEncodingComplexSlimeStructure() { Slime slime = new Slime(); Cursor c1 = slime.setObject(); c1.setLong("bar", 10); @@ -503,8 +488,7 @@ public class BinaryFormatTestCase { } @Test - public void testSymbolReuse() { - System.out.println("test encoding slime reusing symbols"); + public void testEncodingSlimeReusingSymbols() { Slime slime = new Slime(); Cursor c1 = slime.setArray(); { @@ -533,8 +517,7 @@ public class BinaryFormatTestCase { } @Test - public void testOptionalDecodeOrder() { - System.out.println("test decoding slime with different symbol order"); + public void testDecodingSlimeWithDifferentSymbolOrder() { byte[] data = { 5, // num symbols 1, 'd', 1, 'e', 1, 'f', 1, 'b', 1, 'c', // symbol table @@ -564,4 +547,5 @@ public class BinaryFormatTestCase { assertThat(c.field("f").asData(), is(expd)); assertThat(c.entry(5).valid(), is(false)); // not ARRAY } + } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java index f7a0a3cdb7d..d3bb702175a 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java @@ -58,10 +58,13 @@ public class TensorTypeTestCase { @Test public void requireThatIllegalSyntaxInSpecThrowsException() { - assertIllegalTensorType("foo(x[10])", "Tensor type spec must start with 'tensor(' and end with ')', but was 'foo(x[10])'"); - assertIllegalTensorType("tensor(x_@[10])", "Failed parsing element 'x_@[10]' in type spec 'tensor(x_@[10])'"); - assertIllegalTensorType("tensor(x[10a])", "Failed parsing element 'x[10a]' in type spec 'tensor(x[10a])'"); - assertIllegalTensorType("tensor(x{10})", "Failed parsing element 'x{10}' in type spec 'tensor(x{10})'"); + assertIllegalTensorType("foo(x[10])", "but was 'foo(x[10])'."); + assertIllegalTensorType("tensor(x_@[10])", "Dimension 'x_@[10]' is on the wrong format"); + assertIllegalTensorType("tensor(x[10a])", "Dimension 'x[10a]' is on the wrong format"); + assertIllegalTensorType("tensor(x{10})", "Dimension 'x{10}' is on the wrong format"); + assertIllegalTensorType("tensor<(x{})", " Value type spec must be enclosed in <>"); + assertIllegalTensorType("tensor<>(x{})", "Value type must be"); + assertIllegalTensorType("tensor<notavalue>(x{})", "Value type must be"); } @Test @@ -88,6 +91,13 @@ public class TensorTypeTestCase { assertIsConvertibleTo("tensor(x{},y[10])", "tensor(x{},y[])"); } + @Test + public void testValueType() { + assertValueType(TensorType.Value.DOUBLE, "tensor(x[])"); + assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])"); + assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])"); + } + private static void assertTensorType(String typeSpec) { assertTensorType(typeSpec, typeSpec); } @@ -121,4 +131,8 @@ public class TensorTypeTestCase { assertFalse(TensorType.fromSpec(specificType).isConvertibleTo(TensorType.fromSpec(generalType))); } + private void assertValueType(TensorType.Value expectedValueType, String tensorTypeSpec) { + assertEquals(expectedValueType, TensorType.fromSpec(tensorTypeSpec).valueType()); + } + } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java index e8b17812f32..5d1bc7b0c3f 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java @@ -55,7 +55,7 @@ public class DenseBinaryFormatTestCase { @Test public void requireThatFloatSerializationFormatDoNotChange() { - byte[] encodedTensor = new byte[]{4, // binary format type + byte[] encodedTensor = new byte[]{6, // binary format type 1, // float type 2, // dimension count 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size @@ -63,27 +63,21 @@ public class DenseBinaryFormatTestCase { 64, 0, 0, 0, // value 1 64, 64, 0, 0, // value 2 }; - Tensor tensor = Tensor.from("tensor(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}"); - tensor.type().valueType(TensorType.ValueType.FLOAT); - assertEquals(Arrays.toString(encodedTensor), - Arrays.toString(TypedBinaryFormat.encode(tensor))); + Tensor tensor = Tensor.from("tensor<float>(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}"); + assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); } @Test public void testSerializationOfDifferentValueTypes() { - assertSerialization(TensorType.ValueType.DOUBLE, "tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); - assertSerialization(TensorType.ValueType.FLOAT, "tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor<double>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor<float>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); } private void assertSerialization(String tensorString) { - assertSerialization(TensorType.ValueType.DOUBLE, Tensor.from(tensorString)); - } - private void assertSerialization(TensorType.ValueType valueType, String tensorString) { - assertSerialization(valueType, Tensor.from(tensorString)); + assertSerialization(Tensor.from(tensorString)); } - private void assertSerialization(TensorType.ValueType valueType, Tensor tensor) { - tensor.type().valueType(valueType); + private void assertSerialization(Tensor tensor) { assertSerialization(tensor, tensor.type()); } diff --git a/vespalog/src/vespa/log/log.cpp b/vespalog/src/vespa/log/log.cpp index 8e3ed9a18ba..a43c6dd0416 100644 --- a/vespalog/src/vespa/log/log.cpp +++ b/vespalog/src/vespa/log/log.cpp @@ -379,7 +379,7 @@ Logger::doEventProgress(const char *name, double value, double total) void Logger::doEventCount(const char *name, uint64_t value) { - doLog(event, "", 0, "count/1 name=\"%s\" value=%lu", name, value); + doLog(event, "", 0, "count/1 name=\"%s\" value=%" PRIu64, name, value); } void diff --git a/vespalog/src/vespa/log/log_message.cpp b/vespalog/src/vespa/log/log_message.cpp index 77f9b619e9f..8ce7df93a12 100644 --- a/vespalog/src/vespa/log/log_message.cpp +++ b/vespalog/src/vespa/log/log_message.cpp @@ -31,18 +31,29 @@ find_tab(std::string_view log_line, const char *tab_name, std::string_view::size } int64_t -parse_time_field(std::string time_field) +parse_time_subfield(std::string time_subfield, const std::string &time_field) { - std::istringstream time_stream(time_field); - time_stream.imbue(clocale); - double logtime = 0; - time_stream >> logtime; - if (!time_stream.eof()) { + std::istringstream subfield_stream(time_subfield); + subfield_stream.imbue(clocale); + int64_t result = 0; + subfield_stream >> result; + if (!subfield_stream.eof()) { std::ostringstream os; os << "Bad time field: " << time_field; throw BadLogLineException(os.str()); } - return logtime * 1000000000; + return result; +} + +int64_t +parse_time_field(std::string time_field) +{ + auto dotPos = time_field.find('.'); + int64_t log_time = parse_time_subfield(time_field.substr(0, dotPos), time_field) * 1000000000; + if (dotPos != std::string::npos) { + log_time += parse_time_subfield((time_field.substr(dotPos + 1) + "000000000").substr(0, 9), time_field); + } + return log_time; } struct PidFieldParser |