diff options
59 files changed, 3156 insertions, 516 deletions
diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/FleetController.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/FleetController.java index ba35243c14d..364184331a8 100644 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/FleetController.java +++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/FleetController.java @@ -661,6 +661,12 @@ public class FleetController implements NodeStateOrHostInfoChangeHandler, NodeAd } private boolean broadcastClusterStateToEligibleNodes() { + // If there's a pending DB store we have not yet been able to store the + // current state bundle to ZK and must therefore _not_ allow it to be published. + if (database.hasPendingClusterStateMetaDataStore()) { + log.log(LogLevel.DEBUG, "Can't publish current cluster state as it has one or more pending ZooKeeper stores"); + return false; + } boolean sentAny = false; // Give nodes a fair chance to respond first time to state gathering requests, so we don't // disturb system when we take over. Allow anyways if we have states from all nodes. diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/database/DatabaseHandler.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/database/DatabaseHandler.java index f2b1b523aba..f30b86130c2 100644 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/database/DatabaseHandler.java +++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/database/DatabaseHandler.java @@ -352,6 +352,15 @@ public class DatabaseHandler { doNextZooKeeperTask(context); } + // TODO should we expand this to cover _any_ pending ZK write? + public boolean hasPendingClusterStateMetaDataStore() { + synchronized (databaseMonitor) { + return ((zooKeeperAddress != null) && + ((pendingStore.clusterStateBundle != null) || + (pendingStore.lastSystemStateVersion != null))); + } + } + public ClusterStateBundle getLatestClusterStateBundle() throws InterruptedException { log.log(LogLevel.DEBUG, () -> String.format("Fleetcontroller %d: Retrieving latest cluster state bundle from ZooKeeper", nodeIndex)); synchronized (databaseMonitor) { diff --git a/config-model/src/main/perl/vespa-deploy b/config-model/src/main/perl/vespa-deploy index ffde937bea0..fede8b994c1 100755 --- a/config-model/src/main/perl/vespa-deploy +++ b/config-model/src/main/perl/vespa-deploy @@ -87,7 +87,7 @@ readConfFile(); use strict; use warnings; use feature qw(switch say); -use vars qw/ $opt_c $opt_h $opt_n $opt_v $opt_f $opt_t $opt_a $opt_e $opt_E $opt_r $opt_i $opt_p $opt_z $opt_H $opt_R $opt_F $opt_V /; +use vars qw/ $opt_c $opt_h $opt_n $opt_v $opt_f $opt_t $opt_a $opt_e $opt_E $opt_r $opt_i $opt_p $opt_H $opt_R $opt_F $opt_V /; use Env qw($HOME); use JSON; use Getopt::Std; @@ -101,9 +101,6 @@ my $configsource_url_used_file = "$cloudconfig_dir/deploy-configsource-url-used" my $pathPrefix; -my $siaPath; -my $siaCertsPath; -my $siaKeysPath; my $tenant = "default"; my $application = "default"; my $environment = "prod"; @@ -112,8 +109,7 @@ my $instance = "default"; my $version = "v2"; my $configserver = ""; my $port = "19071"; -my $cert = ""; -getopts('c:fhnt:ve:E:r:a:i:p:z:HR:F:V:'); +getopts('c:fhnt:ve:E:r:a:i:p:HR:F:V:'); if ($opt_h) { usage(); @@ -148,18 +144,8 @@ if ($opt_p) { $port = $opt_p; } -if ($opt_z) { - $cert = $opt_z; -} - $pathPrefix = "/application/v2/tenant/$tenant/session"; -$siaPath = "/var/lib/sia/"; - -$siaCertsPath = $siaPath . "certs/"; - -$siaKeysPath = $siaPath . "keys/"; - create_cloudconfig_dir(); $session_id_file = "$cloudconfig_dir/$tenant/deploy-session-id"; @@ -167,10 +153,7 @@ $session_id_file = "$cloudconfig_dir/$tenant/deploy-session-id"; my $command = shift; $command ||= "help"; -my $curl_command = 'curl -A vespa-deploy --silent --show-error --connect-timeout 30 --max-time 1200'; -if ($cert) { - $curl_command = $curl_command . " -k --cert " . $siaCertsPath . $cert . ".cert.pem --key " . $siaKeysPath . $cert . ".key.pem "; -} +my $curl_command = $VESPA_HOME . '/libexec/vespa/vespa-curl-wrapper -A vespa-deploy --silent --show-error --connect-timeout 30 --max-time 1200'; my $CURL_PUT = $curl_command . ' --write-out \%{http_code} --request PUT'; my $CURL_GET = $curl_command . ' --request GET'; @@ -264,8 +247,6 @@ sub usage { print " '-t <timeout>' (timeout in seconds)\n"; print " '-c <server>' (config server hostname)\n"; print " '-p <port>' (config server http port)\n"; - print " '-z <cert>' (cert/key name)\n\n"; - print "Try 'vespa-deploy help <command>' to get more help\n"; } @@ -347,11 +328,7 @@ sub get_configsource_url { my @configsources; if ($configserver and $configserver ne "") { - if ($cert and $cert ne "") { - @configsources = ('https://' . $configserver . ':' . $port . '/'); - } else { - @configsources = ('http://' . $configserver . ':' . $port . '/'); - } + @configsources = ('http://' . $configserver . ':' . $port . '/'); } else { @configsources = split(' ', `$VESPA_HOME/bin/vespa-print-default configservers_http`); } diff --git a/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/FileDistributionAndUrlDownload.java b/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/FileDistributionAndUrlDownload.java index 4eef3c40df4..0b7de6ed562 100644 --- a/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/FileDistributionAndUrlDownload.java +++ b/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/FileDistributionAndUrlDownload.java @@ -4,11 +4,8 @@ package com.yahoo.vespa.config.proxy.filedistribution; import com.yahoo.config.subscription.ConfigSourceSet; import com.yahoo.jrt.Supervisor; import com.yahoo.vespa.config.JRTConnectionPool; -import com.yahoo.vespa.filedistribution.FileDistributionRpcServer; import com.yahoo.vespa.filedistribution.FileDownloader; -import java.util.stream.Stream; - /** * Keeps track of file distribution and url download rpc servers. * diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDistributionRpcServer.java b/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/FileDistributionRpcServer.java index cc76eef014f..0d72a1f02b6 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDistributionRpcServer.java +++ b/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/FileDistributionRpcServer.java @@ -1,5 +1,5 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.filedistribution; +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.config.proxy.filedistribution; import com.yahoo.concurrent.DaemonThreadFactory; import com.yahoo.config.FileReference; @@ -11,6 +11,8 @@ import com.yahoo.jrt.StringArray; import com.yahoo.jrt.StringValue; import com.yahoo.jrt.Supervisor; import com.yahoo.log.LogLevel; +import com.yahoo.vespa.filedistribution.FileDownloader; +import com.yahoo.vespa.filedistribution.FileReferenceDownload; import java.io.File; import java.util.Arrays; @@ -27,7 +29,7 @@ import java.util.stream.Collectors; * * @author hmusum */ -public class FileDistributionRpcServer { +class FileDistributionRpcServer { private final static Logger log = Logger.getLogger(FileDistributionRpcServer.class.getName()); @@ -36,13 +38,13 @@ public class FileDistributionRpcServer { private final ExecutorService rpcDownloadExecutor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors(), new DaemonThreadFactory("Rpc executor")); - public FileDistributionRpcServer(Supervisor supervisor, FileDownloader downloader) { + FileDistributionRpcServer(Supervisor supervisor, FileDownloader downloader) { this.supervisor = supervisor; this.downloader = downloader; declareFileDistributionMethods(); } - public void close() { + void close() { rpcDownloadExecutor.shutdownNow(); try { rpcDownloadExecutor.awaitTermination(10, TimeUnit.SECONDS); @@ -78,8 +80,6 @@ public class FileDistributionRpcServer { private static final int baseFileProviderErrorCode = baseErrorCode + 0x1000; private static final int fileReferenceDoesNotExists = baseFileProviderErrorCode; - private static final int fileReferenceRemoved = fileReferenceDoesNotExists + 1; - private static final int fileReferenceInternalError = fileReferenceRemoved + 1; private void getFile(Request req) { req.detach(); @@ -116,19 +116,16 @@ public class FileDistributionRpcServer { private void downloadFile(Request req) { FileReference fileReference = new FileReference(req.parameters().get(0).asString()); log.log(LogLevel.DEBUG, () -> "getFile() called for file reference '" + fileReference.value() + "'"); - Optional<File> pathToFile = downloader.getFile(fileReference); - try { - if (pathToFile.isPresent()) { - req.returnValues().add(new StringValue(pathToFile.get().getAbsolutePath())); - log.log(LogLevel.DEBUG, () -> "File reference '" + fileReference.value() + "' available at " + pathToFile.get()); - } else { - log.log(LogLevel.INFO, "File reference '" + fileReference.value() + "' not found, returning error"); - req.setError(fileReferenceDoesNotExists, "File reference '" + fileReference.value() + "' not found"); - } - } catch (Throwable e) { - log.log(LogLevel.WARNING, "File reference '" + fileReference.value() + "' got exception: " + e.getMessage()); - req.setError(fileReferenceInternalError, "File reference '" + fileReference.value() + "' removed"); + Optional<File> file = downloader.getFile(fileReference); + if (file.isPresent()) { + new RequestTracker().trackRequest(file.get().getParentFile()); + req.returnValues().add(new StringValue(file.get().getAbsolutePath())); + log.log(LogLevel.DEBUG, () -> "File reference '" + fileReference.value() + "' available at " + file.get()); + } else { + log.log(LogLevel.INFO, "File reference '" + fileReference.value() + "' not found, returning error"); + req.setError(fileReferenceDoesNotExists, "File reference '" + fileReference.value() + "' not found"); } + req.returnRequest(); } diff --git a/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/RequestTracker.java b/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/RequestTracker.java new file mode 100644 index 00000000000..47f478ea4d7 --- /dev/null +++ b/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/RequestTracker.java @@ -0,0 +1,30 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.config.proxy.filedistribution; + +import com.yahoo.log.LogLevel; + +import java.io.File; +import java.time.Instant; +import java.util.logging.Logger; + +/** + * Set last modification time for a file reference or downloaded url, to be able + * to later clean up file references or urls not used for a long time. + * + * @author hmusum + */ +class RequestTracker { + + private final static Logger log = Logger.getLogger(RequestTracker.class.getName()); + + void trackRequest(File file) { + String absolutePath = file.getAbsolutePath(); + if ( ! file.exists()) + log.log(LogLevel.WARNING, "Could not find file '" + absolutePath + "'"); + + if ( ! file.setLastModified(Instant.now().toEpochMilli())) + log.log(LogLevel.WARNING, "Could not set last modified timestamp for '" + absolutePath + "'"); + } + +} diff --git a/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/UrlDownloadRpcServer.java b/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/UrlDownloadRpcServer.java index 9d89f1d10b2..cdf079631fe 100644 --- a/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/UrlDownloadRpcServer.java +++ b/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/UrlDownloadRpcServer.java @@ -42,7 +42,7 @@ class UrlDownloadRpcServer { private final static Logger log = Logger.getLogger(UrlDownloadRpcServer.class.getName()); private static final String CONTENTS_FILE_NAME = "contents"; - private static final String LAST_MODFIED_FILE_NAME = "lastmodified"; + private static final String LAST_MODIFIED_FILE_NAME = "lastmodified"; private final File downloadBaseDir; private final ExecutorService rpcDownloadExecutor = Executors.newFixedThreadPool(Math.max(8, Runtime.getRuntime().availableProcessors()), @@ -110,16 +110,17 @@ class UrlDownloadRpcServer { if (contentsPath.exists() && contentsPath.length() > 0) { writeLastModifiedTimestamp(downloadDir, connection.getLastModified()); + new RequestTracker().trackRequest(downloadDir); req.returnValues().add(new StringValue(contentsPath.getAbsolutePath())); log.log(LogLevel.DEBUG, () -> "URL '" + url + "' available at " + contentsPath); + log.log(LogLevel.INFO, String.format("Download of URL '%s' done in %.3f seconds", + url, (System.currentTimeMillis() -start) / 1000.0)); } else { log.log(LogLevel.ERROR, "Downloaded URL '" + url + "' not found, returning error"); req.setError(DOES_NOT_EXIST, "Downloaded '" + url + "' not found"); } } } - long end = System.currentTimeMillis(); - log.log(LogLevel.INFO, String.format("Download of URL '%s' done in %.3f seconds", url, (end-start) / 1000.0)); } private static String urlToDirName(String uri) { @@ -137,7 +138,7 @@ class UrlDownloadRpcServer { } private static long readLastModifiedTimestamp(File downloadDir) throws IOException { - File lastModified = new File(downloadDir, LAST_MODFIED_FILE_NAME); + File lastModified = new File(downloadDir, LAST_MODIFIED_FILE_NAME); if (lastModified.exists() && lastModified.length() > 0) { try (BufferedReader br = new BufferedReader(new FileReader(lastModified))) { String timestamp = br.readLine(); @@ -148,7 +149,7 @@ class UrlDownloadRpcServer { } private static void writeLastModifiedTimestamp(File downloadDir, long timestamp) throws IOException { - File lastModified = new File(downloadDir, LAST_MODFIED_FILE_NAME); + File lastModified = new File(downloadDir, LAST_MODIFIED_FILE_NAME); try (BufferedWriter lastModifiedWriter = new BufferedWriter(new FileWriter(lastModified.getAbsolutePath()))) { lastModifiedWriter.write(Long.toString(timestamp)); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java index 4d0df545c39..d55e07540d6 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java @@ -67,7 +67,7 @@ public class ConfigConvergenceChecker extends AbstractComponent { /** Check all services in given application. Returns the minimum current generation of all services */ public ServiceListResponse servicesToCheck(Application application, URI requestUrl, Duration timeoutPerService) { - log.log(LogLevel.INFO, "Finding services to check config convergence for in '" + application); + log.log(LogLevel.DEBUG, () -> "Finding services to check config convergence for in '" + application); List<ServiceInfo> servicesToCheck = new ArrayList<>(); application.getModel().getHosts() .forEach(host -> host.getServices().stream() diff --git a/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java b/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java index 61ce9d98e69..b2f5d104890 100644 --- a/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java @@ -73,14 +73,13 @@ public class BlendingSearcher extends Searcher { } /** - * Produce a single blended result list from a group of hitgroups. + * Produce a single blended hit list from a group of hitgroups. * - * It is assumed that the results are ordered in hitgroups. If not, the blend will not be performed + * This assumes that all hits are organized into hitgroups. If not, blending will not be performed. */ protected Result blendResults(Result result, Query q, int offset, int hits, Execution execution) { //Assert that there are more than one hitgroup and that there are only hitgroups on the lowest level - boolean foundNonGroup = false; Iterator<Hit> hitIterator = result.hits().iterator(); List<HitGroup> groups = new ArrayList<>(); @@ -89,14 +88,14 @@ public class BlendingSearcher extends Searcher { if (hit instanceof HitGroup) { groups.add((HitGroup)hit); hitIterator.remove(); - } else if(!hit.isMeta()) { + } else if ( ! hit.isMeta()) { foundNonGroup = true; } } - if(foundNonGroup) { + if( foundNonGroup) { result.hits().addError(ErrorMessage.createUnspecifiedError("Blendingsearcher could not blend - there are toplevel hits" + - " that are not hitgroups")); + " that are not hitgroups")); return result; } if (groups.size() == 0) { diff --git a/container-search/src/main/java/com/yahoo/vespa/streamingvisitors/VdsVisitor.java b/container-search/src/main/java/com/yahoo/vespa/streamingvisitors/VdsVisitor.java index 5288b28cad1..32b48f0f8ae 100644 --- a/container-search/src/main/java/com/yahoo/vespa/streamingvisitors/VdsVisitor.java +++ b/container-search/src/main/java/com/yahoo/vespa/streamingvisitors/VdsVisitor.java @@ -41,6 +41,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Logger; /** @@ -90,8 +91,16 @@ class VdsVisitor extends VisitorDataHandler implements Visitor { } private static class MessageBusVisitorSessionFactory implements VisitorSessionFactory { - private static final LoadTypeSet loadTypes = new LoadTypeSet("client"); - private static final DocumentAccess access = new MessageBusDocumentAccess(new MessageBusParams(loadTypes)); + private static final Object initMonitor = new Object(); + private static final AtomicReference<MessageBusVisitorSessionFactory> instance = new AtomicReference<>(); + + private final LoadTypeSet loadTypes; + private final DocumentAccess access; + + private MessageBusVisitorSessionFactory() { + loadTypes = new LoadTypeSet("client"); + access = new MessageBusDocumentAccess(new MessageBusParams(loadTypes)); + } @Override public VisitorSession createVisitorSession(VisitorParameters params) throws ParseException { @@ -102,10 +111,32 @@ class VdsVisitor extends VisitorDataHandler implements Visitor { public LoadTypeSet getLoadTypeSet() { return loadTypes; } + + /** + * Returns a single, shared instance of this class which is lazily created in a thread-safe + * manner the first time this method is invoked. + * + * May throw any config-related exception if subscription fails. + */ + static MessageBusVisitorSessionFactory sharedInstance() { + var ref = instance.getAcquire(); + if (ref != null) { + return ref; + } + synchronized (initMonitor) { + ref = instance.getAcquire(); + if (ref != null) { + return ref; + } + ref = new MessageBusVisitorSessionFactory(); + instance.setRelease(ref); + } + return ref; + } } public VdsVisitor(Query query, String searchCluster, Route route, String documentType) { - this(query, searchCluster, route, documentType, new MessageBusVisitorSessionFactory()); + this(query, searchCluster, route, documentType, MessageBusVisitorSessionFactory.sharedInstance()); } public VdsVisitor(Query query, String searchCluster, Route route, diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/NGramExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/NGramExpression.java index 816d6a62fd0..d91338e3d3f 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/NGramExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/NGramExpression.java @@ -48,6 +48,10 @@ public final class NGramExpression extends Expression { @Override protected void doExecute(ExecutionContext ctx) { StringFieldValue input = (StringFieldValue)ctx.getValue(); + if (input.getSpanTree(SpanTrees.LINGUISTICS) != null) { + // This expression is already executed for this input instance + return; + } SpanList spanList = input.setSpanTree(new SpanTree(SpanTrees.LINGUISTICS)).spanList(); int lastPosition = 0; for (Iterator<GramSplitter.Gram> it = linguistics.getGramSplitter().split(input.getString(), gramSize); it.hasNext();) { diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/expressions/NGramTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/expressions/NGramTestCase.java index bad1407c7c1..0b217d5ba9a 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/expressions/NGramTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/expressions/NGramTestCase.java @@ -77,6 +77,22 @@ public class NGramTestCase { assertFalse(i.hasNext()); } + @Test + public void requireThatExecuteCanBeCalledMultipleTimes() { + ExecutionContext context = new ExecutionContext(new SimpleTestAdapter()); + context.setValue(new StringFieldValue("some random text string")); + NGramExpression expression = new NGramExpression(new SimpleLinguistics(), 3); + + expression.execute(context); + SpanTree firstTree = ((StringFieldValue)context.getValue()).getSpanTree(SpanTrees.LINGUISTICS); + assertNotNull(firstTree); + + expression.execute(context); + SpanTree secondTree = ((StringFieldValue)context.getValue()).getSpanTree(SpanTrees.LINGUISTICS); + // The span tree instance should be the same. + assertEquals(firstTree, secondTree); + } + private void assertSpan(int from, int length, boolean gram, Iterator<SpanNode> i, SpanTree tree) { assertSpan(from, length, gram, i, tree, null); } diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/ExternalMetrics.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/ExternalMetrics.java index 64ede137e8e..017b2c57370 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/ExternalMetrics.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/ExternalMetrics.java @@ -35,12 +35,13 @@ import static java.util.stream.Collectors.toCollection; public class ExternalMetrics { private static final Logger log = Logger.getLogger(ExternalMetrics.class.getName()); + // NOTE: node service id must be kept in sync with the same constant _value_ used in docker-api:Metrics.java + public static final ServiceId VESPA_NODE_SERVICE_ID = toServiceId("vespa.node"); + public static final DimensionId ROLE_DIMENSION = toDimensionId("role"); public static final DimensionId STATE_DIMENSION = toDimensionId("state"); public static final DimensionId ORCHESTRATOR_STATE_DIMENSION = toDimensionId("orchestratorState"); - public static final ServiceId VESPA_NODE_SERVICE_ID = toServiceId("vespa.node"); - private volatile List<MetricsPacket.Builder> metrics = new ArrayList<>(); private final MetricsConsumers consumers; @@ -58,7 +59,6 @@ public class ExternalMetrics { log.log(DEBUG, () -> "Setting new external metrics with " + externalPackets.size() + " metrics packets."); externalPackets.forEach(packet -> { packet.addConsumers(consumers.getAllConsumers()) - .service(VESPA_NODE_SERVICE_ID) .retainMetrics(metricsToRetain()) .applyOutputNames(outputNamesById()); }); diff --git a/metrics-proxy/src/test/java/ai/vespa/metricsproxy/core/MetricsManagerTest.java b/metrics-proxy/src/test/java/ai/vespa/metricsproxy/core/MetricsManagerTest.java index e441c353292..bc83712ac70 100644 --- a/metrics-proxy/src/test/java/ai/vespa/metricsproxy/core/MetricsManagerTest.java +++ b/metrics-proxy/src/test/java/ai/vespa/metricsproxy/core/MetricsManagerTest.java @@ -15,6 +15,7 @@ import ai.vespa.metricsproxy.metric.dimensions.NodeDimensions; import ai.vespa.metricsproxy.metric.dimensions.NodeDimensionsConfig; import ai.vespa.metricsproxy.metric.model.DimensionId; import ai.vespa.metricsproxy.metric.model.MetricsPacket; +import ai.vespa.metricsproxy.metric.model.ServiceId; import ai.vespa.metricsproxy.service.DownService; import ai.vespa.metricsproxy.service.DummyService; import ai.vespa.metricsproxy.service.VespaService; @@ -38,6 +39,7 @@ import static ai.vespa.metricsproxy.metric.model.ServiceId.toServiceId; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -162,6 +164,21 @@ public class MetricsManagerTest { } @Test + public void application_from_extra_metrics_packets_is_used_as_service_in_result_packets() { + final ServiceId serviceId = toServiceId("custom-service"); + metricsManager.setExtraMetrics(ImmutableList.of( + new MetricsPacket.Builder(serviceId) + .putMetrics(ImmutableList.of(new Metric(WHITELISTED_METRIC_ID, 0))))); + + List<MetricsPacket> packets = metricsManager.getMetrics(testServices, Instant.EPOCH); + MetricsPacket extraPacket = null; + for (MetricsPacket packet : packets) { + if (packet.service.equals(serviceId)) extraPacket = packet; + } + assertNotNull(extraPacket); + } + + @Test public void extra_dimensions_are_added_to_metrics_packets_that_do_not_have_those_dimensions() { metricsManager.setExtraMetrics(ImmutableList.of( new MetricsPacket.Builder(toServiceId("foo")) diff --git a/metrics-proxy/src/test/java/ai/vespa/metricsproxy/http/GenericMetricsHandlerTest.java b/metrics-proxy/src/test/java/ai/vespa/metricsproxy/http/GenericMetricsHandlerTest.java index 29ab8c66694..dc89e5bb9f2 100644 --- a/metrics-proxy/src/test/java/ai/vespa/metricsproxy/http/GenericMetricsHandlerTest.java +++ b/metrics-proxy/src/test/java/ai/vespa/metricsproxy/http/GenericMetricsHandlerTest.java @@ -37,6 +37,7 @@ import java.util.concurrent.Executors; import static ai.vespa.metricsproxy.core.VespaMetrics.INSTANCE_DIMENSION_ID; import static ai.vespa.metricsproxy.http.GenericMetricsHandler.DEFAULT_PUBLIC_CONSUMER_ID; +import static ai.vespa.metricsproxy.metric.ExternalMetrics.VESPA_NODE_SERVICE_ID; import static ai.vespa.metricsproxy.metric.model.ServiceId.toServiceId; import static ai.vespa.metricsproxy.metric.model.StatusCode.DOWN; import static ai.vespa.metricsproxy.metric.model.json.JacksonUtil.createObjectMapper; @@ -74,7 +75,7 @@ public class GenericMetricsHandlerTest { public static void setup() { MetricsManager metricsManager = TestUtil.createMetricsManager(vespaServices, getMetricsConsumers(), getApplicationDimensions(), getNodeDimensions()); metricsManager.setExtraMetrics(ImmutableList.of( - new MetricsPacket.Builder(toServiceId("foo")) + new MetricsPacket.Builder(VESPA_NODE_SERVICE_ID) .timestamp(Instant.now().getEpochSecond()) .putMetrics(ImmutableList.of(new Metric(CPU_METRIC, 12.345))))); GenericMetricsHandler handler = new GenericMetricsHandler(Executors.newSingleThreadExecutor(), metricsManager, vespaServices, getMetricsConsumers()); diff --git a/metrics-proxy/src/test/java/ai/vespa/metricsproxy/metric/ExternalMetricsTest.java b/metrics-proxy/src/test/java/ai/vespa/metricsproxy/metric/ExternalMetricsTest.java index 11c271d46e4..2cce2f66039 100644 --- a/metrics-proxy/src/test/java/ai/vespa/metricsproxy/metric/ExternalMetricsTest.java +++ b/metrics-proxy/src/test/java/ai/vespa/metricsproxy/metric/ExternalMetricsTest.java @@ -8,6 +8,7 @@ import ai.vespa.metricsproxy.core.ConsumersConfig; import ai.vespa.metricsproxy.core.MetricsConsumers; import ai.vespa.metricsproxy.metric.model.ConsumerId; import ai.vespa.metricsproxy.metric.model.MetricsPacket; +import ai.vespa.metricsproxy.metric.model.ServiceId; import com.google.common.collect.ImmutableList; import org.junit.Test; @@ -38,15 +39,17 @@ public class ExternalMetricsTest { } @Test - public void service_id_is_set_to_vespa_node_id() { + public void service_id_from_extra_packets_is_not_replaced() { + final ServiceId SERVICE_ID = toServiceId("do-not-replace"); + MetricsConsumers noConsumers = new MetricsConsumers(new ConsumersConfig.Builder().build()); ExternalMetrics externalMetrics = new ExternalMetrics(noConsumers); externalMetrics.setExtraMetrics(ImmutableList.of( - new MetricsPacket.Builder(toServiceId("replace_with_vespa_node_id")))); + new MetricsPacket.Builder(SERVICE_ID))); List<MetricsPacket.Builder> packets = externalMetrics.getMetrics(); assertEquals(1, packets.size()); - assertEquals(VESPA_NODE_SERVICE_ID, packets.get(0).build().service); + assertEquals(SERVICE_ID, packets.get(0).build().service); } @Test diff --git a/metrics-proxy/src/test/java/ai/vespa/metricsproxy/rpc/RpcMetricsTest.java b/metrics-proxy/src/test/java/ai/vespa/metricsproxy/rpc/RpcMetricsTest.java index d4777618546..d6084e3e03a 100644 --- a/metrics-proxy/src/test/java/ai/vespa/metricsproxy/rpc/RpcMetricsTest.java +++ b/metrics-proxy/src/test/java/ai/vespa/metricsproxy/rpc/RpcMetricsTest.java @@ -17,7 +17,9 @@ import com.yahoo.jrt.Transport; import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import java.util.List; @@ -34,6 +36,8 @@ import static org.hamcrest.CoreMatchers.notNullValue; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; /** * @author jobergum @@ -41,13 +45,60 @@ import static org.junit.Assert.assertThat; */ public class RpcMetricsTest { - private static final String METRICS_RESPONSE_CCL = - getFileContents("metrics-storage-simple.json").trim(); + private static final String METRICS_RESPONSE = getFileContents("metrics-storage-simple.json").trim(); + private static final String EXTRA_APP = "extra"; + + private static class RpcClient implements AutoCloseable { + private final Supervisor supervisor; + private final Target target; + + RpcClient(int port) { + supervisor = new Supervisor(new Transport()); + target = supervisor.connect(new Spec("localhost", port)); + } + + @Override + public void close() { + target.close(); + supervisor.transport().shutdown().join(); + } + } + + @Test + public void extra_metrics_are_added_to_output() throws Exception { + String extraMetricsPayload = "{\"timestamp\":1557754772,\"application\":\"" + EXTRA_APP + + "\",\"metrics\":{\"foo.count\":3},\"dimensions\":{\"role\":\"extra-role\"}}"; + + try (IntegrationTester tester = new IntegrationTester()) { + try (RpcClient rpcClient = new RpcClient(tester.rpcPort())) { + Request req = new Request("setExtraMetrics"); + req.parameters().add(new StringValue(extraMetricsPayload)); + invoke(req, rpcClient, false); + String allServicesResponse = getMetricsForYamas(ALL_SERVICES, rpcClient).trim(); + + // Verify that application is used as serviceId, and that metric exists. + JSONObject extraMetrics = findExtraMetricsObject(allServicesResponse); + assertThat(extraMetrics.getJSONObject("metrics").getInt("foo.count"), is(3)); + assertThat(extraMetrics.getJSONObject("dimensions").getString("role"), is("extra-role")); + } + } + } + + private JSONObject findExtraMetricsObject(String jsonResponse) throws JSONException { + JSONArray metrics = new JSONObject(jsonResponse).getJSONArray("metrics"); + for (int i = 0; i < metrics.length(); i++) { + JSONObject jsonObject = metrics.getJSONObject(i); + assertTrue(jsonObject.has("application")); + if (jsonObject.getString("application").equals(EXTRA_APP)) return jsonObject; + } + fail("Metrics from setExtraMetrics was missing."); + throw new RuntimeException(); + } @Test public void testGetMetrics() throws Exception { try (IntegrationTester tester = new IntegrationTester()) { - tester.httpServer().setResponse(METRICS_RESPONSE_CCL); + tester.httpServer().setResponse(METRICS_RESPONSE); List<VespaService> services = tester.vespaServices().getInstancesById(SERVICE_1_CONFIG_ID); assertThat("#Services should be 1 for config id " + SERVICE_1_CONFIG_ID, services.size(), is(1)); @@ -62,34 +113,29 @@ public class RpcMetricsTest { Metric m2 = metrics.getMetric("bar.count"); assertNotNull("Did not find expected metric with name 'bar.count'", m2); - // Setup RPC client - Supervisor supervisor = new Supervisor(new Transport()); - Target target = supervisor.connect(new Spec("localhost", tester.rpcPort())); + try (RpcClient rpcClient = new RpcClient(tester.rpcPort())) { + verifyMetricsFromRpcRequest(qrserver, rpcClient); - verifyMetricsFromRpcRequest(qrserver, target); + services = tester.vespaServices().getInstancesById(SERVICE_2_CONFIG_ID); + assertThat("#Services should be 1 for config id " + SERVICE_2_CONFIG_ID, services.size(), is(1)); - services = tester.vespaServices().getInstancesById(SERVICE_2_CONFIG_ID); - assertThat("#Services should be 1 for config id " + SERVICE_2_CONFIG_ID, services.size(), is(1)); + VespaService storageService = services.get(0); + verfiyMetricsFromServiceObject(storageService); - VespaService storageService = services.get(0); - verfiyMetricsFromServiceObject(storageService); + String metricsById = getMetricsById(storageService.getConfigId(), rpcClient); + assertThat(metricsById, is("'storage.cluster.storage.storage.0'.foo_count=1 ")); - String metricsById = getMetricsById(storageService.getConfigId(), target); - assertThat(metricsById, is("'storage.cluster.storage.storage.0'.foo_count=1 ")); + String jsonResponse = getMetricsForYamas("non-existing", rpcClient).trim(); + assertThat(jsonResponse, is("105: No service with name 'non-existing'")); - String jsonResponse = getMetricsForYamas("non-existing", target).trim(); - assertThat(jsonResponse, is("105: No service with name 'non-existing'")); + verifyMetricsFromRpcRequestForAllServices(rpcClient); - verifyMetricsFromRpcRequestForAllServices(target); - - // Shutdown RPC - target.close(); - supervisor.transport().shutdown().join(); + } } } - private static void verifyMetricsFromRpcRequest(VespaService service, Target target) throws JSONException { - String jsonResponse = getMetricsForYamas(service.getMonitoringName(), target).trim(); + private static void verifyMetricsFromRpcRequest(VespaService service, RpcClient client) throws JSONException { + String jsonResponse = getMetricsForYamas(service.getMonitoringName(), client).trim(); JSONArray metrics = new JSONObject(jsonResponse).getJSONArray("metrics"); assertThat("Expected 3 metric messages", metrics.length(), is(3)); for (int i = 0; i < metrics.length() - 1; i++) { // The last "metric message" contains only status code/message @@ -124,18 +170,18 @@ public class RpcMetricsTest { assertThat("Metric foo did not contain correct dimension for key = bar", foo.getDimensions().get(toDimensionId("bar")), is("foo")); } - private void verifyMetricsFromRpcRequestForAllServices(Target target) throws JSONException { + private void verifyMetricsFromRpcRequestForAllServices(RpcClient client) throws JSONException { // Verify that metrics for all services can be retrieved in one request. - String allServicesResponse = getMetricsForYamas(ALL_SERVICES, target).trim(); + String allServicesResponse = getMetricsForYamas(ALL_SERVICES, client).trim(); JSONArray allServicesMetrics = new JSONObject(allServicesResponse).getJSONArray("metrics"); assertThat(allServicesMetrics.length(), is(5)); } @Test - public void testGetAllMetricNames() { + public void testGetAllMetricNames() throws Exception { try (IntegrationTester tester = new IntegrationTester()) { - tester.httpServer().setResponse(METRICS_RESPONSE_CCL); + tester.httpServer().setResponse(METRICS_RESPONSE); List<VespaService> services = tester.vespaServices().getInstancesById(SERVICE_1_CONFIG_ID); assertThat(services.size(), is(1)); @@ -144,52 +190,48 @@ public class RpcMetricsTest { Metric m = metrics.getMetric("foo.count"); assertNotNull("Did not find expected metric with name 'foo.count'", m); - Metric m2 = metrics.getMetric("bar.count"); assertNotNull("Did not find expected metric with name 'bar'", m2); - // Setup RPC - Supervisor supervisor = new Supervisor(new Transport()); - Target target = supervisor.connect(new Spec("localhost", tester.rpcPort())); - - String response = getAllMetricNamesForService(services.get(0).getMonitoringName(), VESPA_CONSUMER_ID, target); - assertThat(response, is("foo.count=ON;output-name=foo_count,bar.count=OFF,")); - - // Shutdown RPC - target.close(); - supervisor.transport().shutdown().join(); + try (RpcClient rpcClient = new RpcClient(tester.rpcPort())) { + String response = getAllMetricNamesForService(services.get(0).getMonitoringName(), VESPA_CONSUMER_ID, rpcClient); + assertThat(response, is("foo.count=ON;output-name=foo_count,bar.count=OFF,")); + } } } - private static String getMetricsForYamas(String service, Target target) { + private static String getMetricsForYamas(String service, RpcClient client) { Request req = new Request("getMetricsForYamas"); req.parameters().add(new StringValue(service)); - return invoke(req, target); + return invoke(req, client, true); } - private String getMetricsById(String service, Target target) { + private String getMetricsById(String service, RpcClient client) { Request req = new Request("getMetricsById"); req.parameters().add(new StringValue(service)); - return invoke(req, target); + return invoke(req, client, true); } - private String getAllMetricNamesForService(String service, ConsumerId consumer, Target target) { + private String getAllMetricNamesForService(String service, ConsumerId consumer, RpcClient client) { Request req = new Request("getAllMetricNamesForService"); req.parameters().add(new StringValue(service)); req.parameters().add(new StringValue(consumer.id)); - return invoke(req, target); + return invoke(req, client, true); } - private static String invoke(Request req, Target target) { + private static String invoke(Request req, RpcClient client, boolean expectReturnValue) { String returnValue; - target.invokeSync(req, 20.0); + client.target.invokeSync(req, 20.0); if (req.checkReturnTypes("s")) { returnValue = req.returnValues().get(0).asString(); - } else { + } else if (expectReturnValue) { System.out.println(req.methodName() + " from rpcserver - Invocation failed " + req.errorCode() + ": " + req.errorMessage()); returnValue = req.errorCode() + ": " + req.errorMessage(); } + else { + return ""; + } return returnValue; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index 9e9f66be700..0f563a75b11 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -2,179 +2,179 @@ package ai.vespa.rankingexpression.importer; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.operations.Rename; +import com.yahoo.collections.ListMap; -import java.util.ArrayDeque; import java.util.ArrayList; -import java.util.Collections; -import java.util.Deque; +import java.util.Comparator; import java.util.HashMap; -import java.util.Iterator; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; +import java.util.stream.Collectors; /** - * A constraint satisfier to find suitable dimension names to reduce the + * A constraint solver which finds suitable dimension names to reduce the * amount of necessary renaming during evaluation of an imported model. * * @author lesters + * @author bratseth */ public class DimensionRenamer { + private static final Logger log = Logger.getLogger(DimensionRenamer.class.getName()); + private final String dimensionPrefix; - private final Map<String, List<Integer>> variables = new HashMap<>(); - private final Map<Arc, Constraint> constraints = new HashMap<>(); - private final Map<String, Integer> renames = new HashMap<>(); - private int iterations = 0; + /** The graph we are renaming the dimensions of */ + private final IntermediateGraph graph; + + /** The set of dimensions to find a solution for */ + private final Set<String> dimensions = new HashSet<>(); + + /** The constraints on the dimension name assignment */ + private final ListMap<Arc, Constraint> constraints = new ListMap<>(); + + /** The solution to this, or null if no solution is found yet */ + private Map<String, Integer> renames = null; - public DimensionRenamer() { - this("d"); + public DimensionRenamer(IntermediateGraph graph) { + this(graph, "d"); } - public DimensionRenamer(String dimensionPrefix) { + public DimensionRenamer(IntermediateGraph graph, String dimensionPrefix) { + this.graph = graph; this.dimensionPrefix = dimensionPrefix; } - /** - * Add a dimension name variable. - */ - public void addDimension(String name) { - variables.computeIfAbsent(name, d -> new ArrayList<>()); - } + /** Add a dimension to the set of dimensions to be renamed */ + public void addDimension(String name) { dimensions.add(name); } + + /** Add a constraint between dimension names */ + public void addConstraint(String from, String to, Constraint constraint, IntermediateOperation operation) { + if (constraint instanceof EqualConstraint && from.equals(to)) return; - /** - * Add a constraint between dimension names. - */ - public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) { Arc arc = new Arc(from, to, operation); - Arc opposite = arc.opposite(); - constraints.put(arc, pred); - constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric + constraints.put(arc, constraint); + constraints.put(arc.opposite(), constraint.opposite()); // make constraint graph symmetric } - /** - * Retrieve resulting name of dimension after solving for constraints. - */ - public Optional<String> dimensionNameOf(String name) { - if (!renames.containsKey(name)) { - return Optional.empty(); - } - return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name))); + void solve() { + log.log(Level.FINE, () -> "Rename problem:\n" + constraintsToString(constraints)); + renames = solve(100000); + log.log(Level.FINE, () -> "Rename solution:\n" + renamesToString(renames)); } - /** - * Perform iterative arc consistency until we have found a solution. After - * an initial iteration, the variables (dimensions) will have multiple - * valid values. Find a single valid assignment by iteratively locking one - * dimension after another, and running the arc consistency algorithm - * multiple times. - * - * This requires having constraints that result in an absolute ordering: - * equals, lesserThan and greaterThan do that, but adding notEquals does - * not typically result in a guaranteed ordering. If that is needed, the - * algorithm below needs to be adapted with a backtracking (tree) search - * to find solutions. - */ - private void solve(int maxIterations) { - initialize(); - - // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts + private Map<String, Integer> solve(int maxIterations) { + Map<String, Integer> solution = solveWithOrWithoutSoftConstraints(maxIterations); + if (solution != null) return solution; - for (String dimension : variables.keySet()) { - List<Integer> values = variables.get(dimension); - if (values.size() > 1) { - if (!ac3()) { - throw new IllegalArgumentException("Dimension renamer unable to find a solution."); - } - values.sort(Integer::compare); - variables.put(dimension, Collections.singletonList(values.get(0))); - } - renames.put(dimension, variables.get(dimension).get(0)); - if (iterations > maxIterations) { - throw new IllegalArgumentException("Dimension renamer unable to find a solution within " + - maxIterations + " iterations"); - } + for (RenameTarget target : prioritizedRenameTargets()) { + System.out.println("Trying rename " + target); + target.insertRename(this); + solution = solveWithOrWithoutSoftConstraints(maxIterations); + if (solution != null) return solution; + target.uninsertRename(this); } - - // Todo: handle failure more gracefully: - // If a solution can't be found, look at the operation node in the arc - // with the most remaining constraints, and inject a rename operation. - // Then run this algorithm again. + throw new IllegalArgumentException("Could not find a dimension naming solution " + + "given constraints\n" + constraintsToString(constraints)); } - void solve() { - solve(100000); + private Map<String, Integer> solveWithOrWithoutSoftConstraints(int maxIterations) { + Map<String, Integer> solution = NamingConstraintSolver.solve(dimensions, constraints, maxIterations); + if ( solution == null) { + ListMap<Arc, Constraint> hardConstraints = new ListMap<>(); + boolean anyRemoved = copyHard(constraints, hardConstraints); + if (anyRemoved) + solution = NamingConstraintSolver.solve(dimensions, hardConstraints, maxIterations); + } + return solution; } - private void initialize() { - for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) { - List<Integer> values = variable.getValue(); - for (int i = 0; i < variables.size(); ++i) { - values.add(i); // invariant: values are in increasing order + /** Removes soft constraints and returns whether something was removed */ + private boolean copyHard(ListMap<Arc, Constraint> source, ListMap<Arc, Constraint> target) { + boolean removed = false; + for (var entry : source.entrySet()) { + Arc arc = entry.getKey(); + for (Constraint constraint : entry.getValue()) { + if ( ! constraint.isSoft()) + target.put(arc, constraint); + else + removed = true; } } + return removed; } - private boolean ac3() { - Deque<Arc> workList = new ArrayDeque<>(constraints.keySet()); - while (!workList.isEmpty()) { - Arc arc = workList.pop(); - iterations += 1; - if (revise(arc)) { - if (variables.get(arc.from).size() == 0) { - return false; // no solution found - } - for (Arc constraint : constraints.keySet()) { - if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) { - workList.add(constraint); - } - } - } + private List<RenameTarget> prioritizedRenameTargets() { + Map<IntermediateOperation, Integer> constraintsPerOperation = new HashMap<>(); + + for (var constraint : constraints.entrySet()) { + constraintsPerOperation.compute(constraint.getKey().operation, + (operation, count) -> count == null ? 1 : ++count); } - return true; - } + List<IntermediateOperation> prioritizedOperations = + constraintsPerOperation.entrySet().stream() + .sorted(Comparator.comparingInt(entry -> - entry.getValue())) + .map(entry -> entry.getKey()) + .collect(Collectors.toList()); - private boolean revise(Arc arc) { - boolean revised = false; - for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) { - Integer from = fromIterator.next(); - boolean satisfied = false; - for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) { - Integer to = toIterator.next(); - if (constraints.get(arc).test(from, to)) { - satisfied = true; + List<RenameTarget> targets = new ArrayList<>(); + for (IntermediateOperation operation : prioritizedOperations) { + for (int i = 0; i < operation.inputs().size(); i++) { + Optional<OrderedTensorType> inputType = operation.inputs().get(i).type(); + if (inputType.isEmpty()) continue; + for (String dimensionName : inputType.get().dimensionNames()) { + RenameTarget target = new RenameTarget(operation, i, dimensionName, graph); + if (target.rootKey != null) // TODO: Inserting renames under non-roots is not implemented + targets.add(target); } } - if (!satisfied) { - fromIterator.remove(); - revised = true; - } } - return revised; - } - - public interface Constraint { - boolean test(Integer x, Integer y); + return targets; } - public static boolean equals(Integer x, Integer y) { - return Objects.equals(x, y); + /** + * Retrieve resulting name of a dimension after solving for constraints, or empty if no + * solution is found yet, or this dimension was not added before finding a solution. + */ + public Optional<String> dimensionNameOf(String name) { + if ( renames == null || ! renames.containsKey(name)) + return Optional.empty(); + return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name))); } - public static boolean lesserThan(Integer x, Integer y) { - return x < y; + private static String renamesToString(Map<String, Integer> renames) { + return renames.entrySet().stream() + .map(e -> " " + e.getKey() + " -> " + e.getValue()) + .collect(Collectors.joining("\n")); } - public static boolean greaterThan(Integer x, Integer y) { - return x > y; + private static String constraintsToString(ListMap<Arc, Constraint> constraints) { + StringBuilder b = new StringBuilder(); + for (var entry : constraints.entrySet()) { + Arc arc = entry.getKey(); + for (Constraint constraint : entry.getValue()) { + if (constraint.isOpposite()) continue; // noise + b.append(" "); + if (constraint.isSoft()) + b.append("(soft) "); + b.append(arc.from).append(" ").append(constraint).append(" ").append(arc.to); + b.append(" (origin: ").append(arc.operation).append(")\n"); + } + } + return b.toString(); } - private static class Arc { + static class Arc { - private final String from; - private final String to; + final String from; + final String to; private final IntermediateOperation operation; Arc(String from, String to, IntermediateOperation operation) { @@ -194,7 +194,7 @@ public class DimensionRenamer { @Override public boolean equals(Object obj) { - if (obj == null || !(obj instanceof Arc)) { + if (!(obj instanceof Arc)) { return false; } Arc other = (Arc) obj; @@ -203,8 +203,185 @@ public class DimensionRenamer { @Override public String toString() { - return String.format("%s -> %s", from, to); + return from + " -> " + to; + } + } + + public static abstract class Constraint { + + private final boolean soft, opposite; + + protected Constraint(boolean soft, boolean opposite) { + this.soft = soft; + this.opposite = opposite; + } + + abstract boolean test(Integer x, Integer y); + abstract Constraint opposite(); + + /** Returns whether this constraint can be violated if that is necessary to achieve a solution */ + boolean isSoft() { return soft; } + + /** Returns whether this is an opposite of another constraint */ + boolean isOpposite() { return opposite; } + + public static Constraint equal(boolean soft) { return new EqualConstraint(soft, false); } + public static Constraint notEqual(boolean soft) { return new NotEqualConstraint(soft, false); } + public static Constraint lessThan(boolean soft) { return new LessThanConstraint(soft, false); } + public static Constraint greaterThan(boolean soft) { return new GreaterThanConstraint(soft, false); } + + } + + private static class EqualConstraint extends Constraint { + + private EqualConstraint(boolean soft, boolean opposite) { + super(soft, opposite); + } + + @Override + public boolean test(Integer x, Integer y) { return Objects.equals(x, y); } + + @Override + public Constraint opposite() { return new EqualConstraint(isSoft(), true); } + + @Override + public String toString() { return "=="; } + + } + + private static class NotEqualConstraint extends Constraint { + + private NotEqualConstraint(boolean soft, boolean opposite) { + super(soft, opposite); + } + + @Override + public boolean test(Integer x, Integer y) { return ! Objects.equals(x, y); } + + @Override + public Constraint opposite() { return new NotEqualConstraint(isSoft(), true); } + + @Override + public String toString() { return "!="; } + + } + + private static class LessThanConstraint extends Constraint { + + private LessThanConstraint(boolean soft, boolean opposite) { + super(soft, opposite); + } + + @Override + public boolean test(Integer x, Integer y) { return x < y; } + + @Override + public Constraint opposite() { return new GreaterThanConstraint(isSoft(), true); } + + @Override + public String toString() { return "<"; } + + } + + private static class GreaterThanConstraint extends Constraint { + + private GreaterThanConstraint(boolean soft, boolean opposite) { + super(soft, opposite); + } + + @Override + public boolean test(Integer x, Integer y) { return x > y; } + + @Override + public Constraint opposite() { return new LessThanConstraint(isSoft(), true); } + + @Override + public String toString() { return ">"; } + + } + + /** + * An operation and an input number which we may want to insert a rename operation at. + * That is, we may want to change op(..., input, ...) to op(..., rename(input), ...). + * + * This class is (and must remain) immutable. + */ + private static class RenameTarget { + + final IntermediateOperation operation; + final int inputNumber; + final String dimensionName; + final IntermediateGraph graph; + + /** + * Returns the key of this operation in the root operations of the graph, + * or null if it is not a root operation + */ + final String rootKey; + + public RenameTarget(IntermediateOperation operation, int inputNumber, String dimensionName, IntermediateGraph graph) { + this.operation = operation; + this.inputNumber = inputNumber; + this.dimensionName = dimensionName; + this.rootKey = findRootKey(operation, graph); + this.graph = graph; + } + + public IntermediateOperation input() { + return operation.inputs().get(inputNumber); + } + + private static String findRootKey(IntermediateOperation operation, IntermediateGraph graph) { + for (var entry : graph.operations().entrySet()) { + if (entry.getValue() == operation) + return entry.getKey(); + } + return null; + } + + /** Inserts a rename operation if possible. Returns whether an operation was inserted. */ + private boolean insertRename(DimensionRenamer renamer) { + Rename rename = new Rename(operation.modelName(), + dimensionName, + renamer.dimensionPrefix + renamer.dimensions.size(), + input()); + + List<IntermediateOperation> newInputs = new ArrayList<>(operation.inputs()); + newInputs.set(inputNumber, rename); + IntermediateOperation newOperation = operation.withInputs(newInputs); + if (rootKey == null) + throw new IllegalStateException("Renaming non-roots is not implemented"); + graph.put(rootKey, newOperation); + + removeConstraintsOf(operation, renamer); + rename.addDimensionNameConstraints(renamer); + newOperation.addDimensionNameConstraints(renamer); + return true; + } + + /** Undo what insertRenameOperation has done: Set back the original operation and remove+add constraints */ + private void uninsertRename(DimensionRenamer renamer) { + IntermediateOperation newOperation = graph.operations().get(rootKey); + Rename rename = (Rename)newOperation.inputs().get(inputNumber); + graph.put(rootKey, operation); + + removeConstraintsOf(rename, renamer); + removeConstraintsOf(newOperation, renamer); + operation.addDimensionNameConstraints(renamer); + } + + private void removeConstraintsOf(IntermediateOperation operation, DimensionRenamer renamer) { + for (Arc key : new HashSet<>(renamer.constraints.keySet())) { + if (key.operation == operation) + renamer.constraints.removeAll(key); + } + } + + @Override + public String toString() { + return operation + ", input " + inputNumber; } + } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index 0c570261ae7..a9be1bbd40e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -262,9 +262,12 @@ public class ImportedModel implements ImportedMlModel { /** Returns the expression this output references as an imported function */ public ImportedMlFunction outputFunction(String outputName, String functionName) { + RankingExpression outputExpression = owner().expressions().get(outputs.get(outputName)); + if (outputExpression == null) + throw new IllegalArgumentException("Missing output '" + outputName + "' in " + this); return new ImportedMlFunction(functionName, new ArrayList<>(inputs.values()), - owner().expressions().get(outputs.get(outputName)).getRoot().toString(), + outputExpression.getRoot().toString(), asStrings(inputMap()), Optional.empty()); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java index aec98d06874..6c583d960bd 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java @@ -3,9 +3,11 @@ package ai.vespa.rankingexpression.importer; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.operations.MatMul; import java.util.Collection; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; @@ -20,7 +22,7 @@ import java.util.Set; public class IntermediateGraph { private final String modelName; - private final Map<String, IntermediateOperation> index = new HashMap<>(); + private final Map<String, IntermediateOperation> operations = new HashMap<>(); private final Map<String, GraphSignature> signatures = new HashMap<>(); private static class GraphSignature { @@ -37,11 +39,11 @@ public class IntermediateGraph { } public IntermediateOperation put(String key, IntermediateOperation operation) { - return index.put(key, operation); + return operations.put(key, operation); } public IntermediateOperation get(String key) { - return index.get(key); + return operations.get(key); } public Set<String> signatures() { @@ -61,11 +63,11 @@ public class IntermediateGraph { } public boolean alreadyImported(String key) { - return index.containsKey(key); + return operations.containsKey(key); } - public Collection<IntermediateOperation> operations() { - return index.values(); + public Map<String, IntermediateOperation> operations() { + return operations; } void optimize() { @@ -76,16 +78,16 @@ public class IntermediateGraph { * Find dimension names to avoid excessive renaming while evaluating the model. */ private void renameDimensions() { - DimensionRenamer renamer = new DimensionRenamer(); + DimensionRenamer renamer = new DimensionRenamer(this); for (String signature : signatures()) { for (String output : outputs(signature).values()) { - addDimensionNameConstraints(index.get(output), renamer); + addDimensionNameConstraints(operations.get(output), renamer); } } renamer.solve(); for (String signature : signatures()) { for (String output : outputs(signature).values()) { - renameDimensions(index.get(output), renamer); + renameDimensions(operations.get(output), renamer); } } } @@ -104,4 +106,16 @@ public class IntermediateGraph { } } + @Override + public String toString() { + return "intermediate graph for '" + modelName + "'"; + } + + public String toFullString() { + StringBuilder b = new StringBuilder(); + for (var input : operations.entrySet()) + b.append(input.getKey()).append(": ").append(input.getValue().toFullString()).append("\n"); + return b.toString(); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index 99bfa08db43..b587a9200ec 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -11,12 +11,14 @@ import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.text.ExpressionFormatter; import com.yahoo.yolean.Exceptions; import java.io.File; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.logging.Level; import java.util.logging.Logger; /** @@ -50,6 +52,8 @@ public abstract class ModelImporter implements MlModelImporter { */ protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { ImportedModel model = new ImportedModel(graph.name(), modelSource); + log.log(Level.FINER, () -> "Intermediate graph created from '" + modelSource + "':\n" + + ExpressionFormatter.inTwoColumnMode(70, 50).format(graph.toFullString())); graph.optimize(); @@ -223,7 +227,7 @@ public abstract class ModelImporter implements MlModelImporter { * for fast model weight updates. */ private static void logVariableTypes(IntermediateGraph graph) { - for (IntermediateOperation operation : graph.operations()) { + for (IntermediateOperation operation : graph.operations().values()) { if ( ! (operation instanceof Constant)) continue; if ( ! operation.type().isPresent()) continue; // will not happen log.info("Importing model variable " + operation.name() + " as " + operation.vespaName() + diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java new file mode 100644 index 00000000000..21cc6b27dad --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java @@ -0,0 +1,126 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer; + +import com.yahoo.collections.ListMap; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Solves a dimension naming constraint problem. + * + * @author lesters + * @author bratseth + */ +class NamingConstraintSolver { + + private final ListMap<String, Integer> possibleAssignments; + private final ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints; + + private int iterations = 0; + private final int maxIterations; + + private NamingConstraintSolver(Set<String> dimensions, + ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints, + int maxIterations) { + this.possibleAssignments = allPossibilities(dimensions); + this.constraints = constraints; + this.maxIterations = maxIterations; + } + + /** Returns a list containing a list of all assignment possibilities for each of the given dimensions */ + private static ListMap<String, Integer> allPossibilities(Set<String> dimensions) { + ListMap<String, Integer> all = new ListMap<>(); + for (String dimension : dimensions) { + for (int i = 0; i < dimensions.size(); ++i) + all.put(dimension, i); + } + return all; + } + + /** + * Try the solve the constraint problem given in the arguments, and put the result in renames. + * + * This is done by performing iterative arc consistency until we have found a solution. + * After an initial iteration, the dimensions will have multiple + * valid values. Find a single valid assignment by iteratively locking one + * dimension after another, and running the arc consistency algorithm + * multiple times. + * + * This requires having constraints that result in an absolute ordering: + * equal, lessThan and greaterThan do that, but not necessarily notEqual + * If that is needed, the algorithm needs to be adapted with a backtracking + * (tree) search + * + * @return the solution in the form of the renames to perform + */ + private Map<String, Integer> trySolve() { + // TODO: Evaluate possible improved efficiency by using a heuristic such as min-conflicts + + Map<String, Integer> solution = new HashMap<>(); + for (String dimension : possibleAssignments.keySet()) { + List<Integer> values = possibleAssignments.get(dimension); + if (values.size() > 1) { + if ( ! ac3()) return null; + values.sort(Integer::compare); + possibleAssignments.replace(dimension, values.get(0)); + } + solution.put(dimension, possibleAssignments.get(dimension).get(0)); // Pick the first available solution + if (iterations > maxIterations) return null; + } + return solution; + } + + private boolean ac3() { + Deque<DimensionRenamer.Arc> workList = new ArrayDeque<>(constraints.keySet()); + while ( ! workList.isEmpty()) { + DimensionRenamer.Arc arc = workList.pop(); + iterations++; + if (revise(arc)) { + if (possibleAssignments.get(arc.from).isEmpty()) return false; + + for (DimensionRenamer.Arc constraint : constraints.keySet()) { + if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) + workList.add(constraint); + } + } + } + return true; + } + + private boolean revise(DimensionRenamer.Arc arc) { + boolean revised = false; + for (Iterator<Integer> fromIterator = possibleAssignments.get(arc.from).iterator(); fromIterator.hasNext(); ) { + Integer from = fromIterator.next(); + boolean satisfied = false; + for (Iterator<Integer> toIterator = possibleAssignments.get(arc.to).iterator(); toIterator.hasNext(); ) { + Integer to = toIterator.next(); + if (constraints.get(arc).stream().allMatch(constraint -> constraint.test(from, to))) + satisfied = true; + } + if ( ! satisfied) { + fromIterator.remove(); + revised = true; + } + } + return revised; + } + + /** + * Attempts to solve the given naming problem. The input maps are never modified. + * + * @return the solution as a map from existing names to name ids represented as integers, or NULL + * if no solution could be found + */ + public static Map<String, Integer> solve(Set<String> dimensions, + ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints, + int maxIterations) { + return new NamingConstraintSolver(dimensions, constraints, maxIterations).trySolve(); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java index 9115dc99b82..1cb8f3a2951 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java @@ -7,8 +7,11 @@ import com.yahoo.tensor.TensorTypeParser; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; /** @@ -131,11 +134,17 @@ public class OrderedTensorType { public OrderedTensorType rename(DimensionRenamer renamer) { List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size()); + Map<String, String> new2Old = new HashMap<>(); // Just to create meaningful error messages for (TensorType.Dimension dimension : dimensions) { String oldName = dimension.name(); Optional<String> newName = renamer.dimensionNameOf(oldName); - if (!newName.isPresent()) - return this; // presumably, already renamed + if ( newName.isEmpty()) return this; // presumably already renamed + + if (new2Old.containsKey(newName.get())) + throw new IllegalArgumentException("Can not rename '" + oldName + "' to '" + newName.get() + "' in " + this + + " as '" + new2Old.get(newName.get()) + "' should also be renamed to it"); + new2Old.put(newName.get(), oldName); + TensorType.Dimension.Type dimensionType = dimension.type(); if (dimensionType == TensorType.Dimension.Type.indexedBound) { renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get())); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java index d6ea00ca453..9f62a27a3b9 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java @@ -29,7 +29,7 @@ public class Argument extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type()); - if (!standardNamingType.equals(type)) { + if ( ! standardNamingType.equals(type)) { List<String> renameFrom = standardNamingType.dimensionNames(); List<String> renameTo = type.dimensionNames(); output = new Rename(output, renameFrom, renameTo); @@ -39,9 +39,7 @@ public class Argument extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } + addConstraintsFrom(type, renamer); } @Override @@ -54,4 +52,22 @@ public class Argument extends IntermediateOperation { return false; } + @Override + public Argument withInputs(List<IntermediateOperation> inputs) { + if ( ! inputs.isEmpty()) + throw new IllegalArgumentException("Argument cannot take inputs"); + return new Argument(modelName(), name(), type); + } + + @Override + public String operationName() { return "Argument"; } + + @Override + public String toString() { return "Argument(" + standardNamingType + ")"; } + + @Override + public String toFullString() { + return "\t" + lazyGetType() + ":\tArgument(" + standardNamingType + ")"; + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java index 7ae50a0549d..7787caa83ce 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.functions.TensorFunction; import java.util.List; import java.util.Optional; +import java.util.stream.Collectors; public class ConcatV2 extends IntermediateOperation { @@ -89,7 +90,7 @@ public class ConcatV2 extends IntermediateOperation { OrderedTensorType b = inputs.get(i).type().get(); String bDim = b.dimensions().get(concatDimensionIndex).name(); String aDim = a.dimensions().get(concatDimensionIndex).name(); - renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); + renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this); } } @@ -99,4 +100,12 @@ public class ConcatV2 extends IntermediateOperation { concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName); } + @Override + public ConcatV2 withInputs(List<IntermediateOperation> inputs) { + return new ConcatV2(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "ConcatV2"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java index 41d421b1f5a..d13c1ad5f3c 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java @@ -62,9 +62,7 @@ public class Const extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } + addConstraintsFrom(type, renamer); } @Override @@ -86,4 +84,23 @@ public class Const extends IntermediateOperation { } return value.get(); } + + @Override + public Const withInputs(List<IntermediateOperation> inputs) { + return new Const(modelName(), name(), inputs, attributeMap, type); + } + + @Override + public String operationName() { return "Const"; } + + @Override + public String toString() { + return "Const(" + type + ")"; + } + + @Override + public String toFullString() { + return "\t" + lazyGetType() + ":\tConst(" + type + ")"; + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java index a1cc83296b0..1eaaf705220 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java @@ -8,6 +8,7 @@ import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; import java.util.Collections; +import java.util.List; import java.util.Optional; public class Constant extends IntermediateOperation { @@ -48,9 +49,7 @@ public class Constant extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } + addConstraintsFrom(type, renamer); } @Override @@ -58,4 +57,24 @@ public class Constant extends IntermediateOperation { return true; } + @Override + public Constant withInputs(List<IntermediateOperation> inputs) { + if ( ! inputs.isEmpty()) + throw new IllegalArgumentException("Constant cannot take inputs"); + return new Constant(modelName(), name(), type); + } + + @Override + public String operationName() { return "Constant"; } + + @Override + public String toString() { + return "Constant(" + type + ")"; + } + + @Override + public String toFullString() { + return "\t" + lazyGetType() + ":\tConstant(" + type + ")"; + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index c64b9ded601..e6cc96d48ad 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -30,7 +30,7 @@ public class ExpandDims extends IntermediateOperation { if ( ! allInputTypesPresent(2)) return null; IntermediateOperation axisOperation = inputs().get(1); - if (!axisOperation.getConstantValue().isPresent()) { + if ( ! axisOperation.getConstantValue().isPresent()) { throw new IllegalArgumentException("ExpandDims in " + name + ": Axis must be a constant."); } Tensor axis = axisOperation.getConstantValue().get().asTensor(); @@ -47,18 +47,23 @@ public class ExpandDims extends IntermediateOperation { expandDimensions = new ArrayList<>(); int dimensionIndex = 0; for (TensorType.Dimension dimension : inputType.dimensions()) { - if (dimensionIndex == dimensionToInsert) { - String name = String.format("%s_%d", vespaName(), dimensionIndex); - expandDimensions.add(name); - typeBuilder.add(TensorType.Dimension.indexed(name, 1L)); - } + if (dimensionIndex == dimensionToInsert) + addDimension(dimensionIndex, typeBuilder); typeBuilder.add(dimension); dimensionIndex++; } - + if (dimensionToInsert == inputType.dimensions().size()) { // Insert last dimension + addDimension(dimensionIndex, typeBuilder); + } return typeBuilder.build(); } + private void addDimension(int dimensionIndex, OrderedTensorType.Builder typeBuilder) { + String name = String.format("%s_%d", vespaName(), dimensionIndex); + expandDimensions.add(name); + typeBuilder.add(TensorType.Dimension.indexed(name, 1L)); + } + @Override protected TensorFunction lazyGetFunction() { if ( ! allInputFunctionsPresent(2)) return null; @@ -88,7 +93,7 @@ public class ExpandDims extends IntermediateOperation { List<String> renamedDimensions = new ArrayList<>(expandDimensions.size()); for (String name : expandDimensions) { Optional<String> newName = renamer.dimensionNameOf(name); - if (!newName.isPresent()) { + if ( ! newName.isPresent()) { return; // presumably, already renamed } renamedDimensions.add(newName.get()); @@ -96,4 +101,12 @@ public class ExpandDims extends IntermediateOperation { expandDimensions = renamedDimensions; } + @Override + public ExpandDims withInputs(List<IntermediateOperation> inputs) { + return new ExpandDims(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "ExpandDims"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java index c2787aa14d4..5463f645355 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java @@ -32,4 +32,12 @@ public class Identity extends IntermediateOperation { return inputs.get(0).function().orElse(null); } + @Override + public Identity withInputs(List<IntermediateOperation> inputs) { + return new Identity(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "Identity"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 0ee54f839bc..c3980b8fe93 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -58,6 +59,8 @@ public abstract class IntermediateOperation { protected abstract OrderedTensorType lazyGetType(); protected abstract TensorFunction lazyGetFunction(); + public String modelName() { return modelName; } + /** Returns the Vespa tensor type of this operation if it exists */ public Optional<OrderedTensorType> type() { if (type == null) { @@ -99,6 +102,20 @@ public abstract class IntermediateOperation { /** Add dimension name constraints for this operation */ public void addDimensionNameConstraints(DimensionRenamer renamer) { } + /** Conveinence method to adds dimensions and constraints of the given tensor type */ + protected void addConstraintsFrom(OrderedTensorType type, DimensionRenamer renamer) { + for (int i = 0; i < type.dimensions().size(); i++) { + renamer.addDimension(type.dimensions().get(i).name()); + + // Each dimension is distinct: + for (int j = i + 1; j < type.dimensions().size(); j++) { + renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(), + DimensionRenamer.Constraint.notEqual(false), + this); + } + } + } + /** Performs dimension rename for this operation */ public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); } @@ -175,6 +192,12 @@ public abstract class IntermediateOperation { .collect(Collectors.toList())); } + public abstract IntermediateOperation withInputs(List<IntermediateOperation> inputs); + + String asString(Optional<OrderedTensorType> type) { + return type.map(t -> t.toString()).orElse("(unknown)"); + } + /** * A method signature input and output has the form name:index. * This returns the name part without the index. @@ -203,4 +226,19 @@ public abstract class IntermediateOperation { Optional<List<Value>> getList(String key); } + public abstract String operationName(); + + @Override + public String toString() { + return operationName() + + inputs().stream().map(input -> asString(input.type())).collect(Collectors.joining(", ")) + + ")"; + } + + public String toFullString() { + return "\t" + lazyGetType() + ":\t" + operationName() + + inputs().stream().map(input -> input.toFullString()).collect(Collectors.joining(", ")) + + ")"; + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index c2d75153586..adb54474812 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -95,7 +95,7 @@ public class Join extends IntermediateOperation { for (int i = 0; i < b.rank(); ++i) { String bDim = b.dimensions().get(i).name(); String aDim = a.dimensions().get(i + sizeDifference).name(); - renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); + renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this); } } @@ -111,4 +111,12 @@ public class Join extends IntermediateOperation { return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1); } + @Override + public Join withInputs(List<IntermediateOperation> inputs) { + return new Join(modelName(), name(), inputs, operator); + } + + @Override + public String operationName() { return "Join"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java index e0842d820f9..ea39e289c48 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java @@ -34,4 +34,12 @@ public class Map extends IntermediateOperation { return new com.yahoo.tensor.functions.Map(input.get(), operator); } + @Override + public Map withInputs(List<IntermediateOperation> inputs) { + return new Map(modelName(), name(), inputs, operator); + } + + @Override + public String operationName() { return "Map"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 9a76662529d..434261c6077 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.text.ExpressionFormatter; import java.util.List; import java.util.Optional; @@ -51,20 +52,40 @@ public class MatMul extends IntermediateOperation { List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); + assertTwoDimensions(aDimensions, inputs.get(0), "first argument"); + assertTwoDimensions(bDimensions, inputs.get(1), "second argument"); + String aDim0 = aDimensions.get(0).name(); String aDim1 = aDimensions.get(1).name(); String bDim0 = bDimensions.get(0).name(); String bDim1 = bDimensions.get(1).name(); // The second dimension of a should have the same name as the first dimension of b - renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this); + renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this); // The first dimension of a should have a different name than the second dimension of b - renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this); + renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this); // For efficiency, the dimensions to join over should be innermost - soft constraint - renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this); - renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); + renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this); + renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this); + } + + private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) { + if (dimensions.size() >= 2) return; + + + throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + + " but got just " + dimensions + " from\n" + + ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString())); } + @Override + public MatMul withInputs(List<IntermediateOperation> inputs) { + return new MatMul(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "MatMul"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java index d8e9950c61f..215edf88c4f 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java @@ -91,6 +91,11 @@ public class Mean extends IntermediateOperation { reduceDimensions = renamedDimensions; } + @Override + public Mean withInputs(List<IntermediateOperation> inputs) { + return new Mean(modelName(), name(), inputs, attributeMap); + } + private boolean shouldKeepDimensions() { Optional<Value> keepDims = attributeMap.get("keep_dims"); return keepDims.isPresent() && keepDims.get().asBoolean(); @@ -108,4 +113,7 @@ public class Mean extends IntermediateOperation { return builder.build(); } + @Override + public String operationName() { return "Mean"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java index ce0c58971d0..671cfe852a7 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java @@ -32,4 +32,12 @@ public class Merge extends IntermediateOperation { return null; } + @Override + public Merge withInputs(List<IntermediateOperation> inputs) { + return new Merge(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "Merge"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java index 4c5ce33b1b5..35d89cf6ab6 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java @@ -23,4 +23,12 @@ public class NoOp extends IntermediateOperation { return null; } + @Override + public NoOp withInputs(List<IntermediateOperation> inputs) { + return new NoOp(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "NoOp"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java index e5e5c29f8f1..177ef8d5e17 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java @@ -45,4 +45,12 @@ public class PlaceholderWithDefault extends IntermediateOperation { return true; // not true if we add to function } + @Override + public PlaceholderWithDefault withInputs(List<IntermediateOperation> inputs) { + return new PlaceholderWithDefault(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "PlaceholdeWithDefault"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java new file mode 100644 index 00000000000..abc431233be --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java @@ -0,0 +1,67 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; + +/** + * Renames a tensor dimension to relax dimension constraints + * + * @author bratseth + */ +public class Rename extends IntermediateOperation { + + private final String from, to; + + public Rename(String modelName, String from, String to, IntermediateOperation input) { + super(modelName, "rename", List.of(input)); + this.from = from; + this.to = to; + } + + @Override + boolean allInputFunctionsPresent(int expected) { + return super.allInputFunctionsPresent(expected); + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(1)) return null; + + OrderedTensorType inputType = inputs.get(0).type().orElse(null); + if (inputType == null) return null; + + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(inputType.type().valueType()); + for (TensorType.Dimension dimension : inputType.dimensions()) + builder.add(dimension.withName(dimension.name().equals(from) ? to : dimension.name())); + return builder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputFunctionsPresent(1)) return null; + return new com.yahoo.tensor.functions.Rename(inputs.get(0).function().orElse(null), from, to); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + renamer.addDimension(to); + } + + @Override + public Rename withInputs(List<IntermediateOperation> inputs) { + if (inputs.size() != 1) + throw new IllegalArgumentException("Rename require 1 input, not " + inputs.size()); + return new Rename(modelName(), from, to, inputs.get(0)); + } + + @Override + public String operationName() { return "Rename"; } + +} + + diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index 4a0fe236c9f..a210ed13f5d 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -74,6 +74,11 @@ public class Reshape extends IntermediateOperation { } } + @Override + public Reshape withInputs(List<IntermediateOperation> inputs) { + return new Reshape(modelName(), name(), inputs); + } + public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); @@ -119,4 +124,7 @@ public class Reshape extends IntermediateOperation { return new ArithmeticNode(children, operators); } + @Override + public String operationName() { return "Reshape"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java index dc690329a8d..35a1b6e2b0e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java @@ -81,8 +81,16 @@ public class Select extends IntermediateOperation { String bDim1 = bDimensions.get(1).name(); // These tensors should have the same dimension names - renamer.addConstraint(aDim0, bDim0, DimensionRenamer::equals, this); - renamer.addConstraint(aDim1, bDim1, DimensionRenamer::equals, this); + renamer.addConstraint(aDim0, bDim0, DimensionRenamer.Constraint.equal(false), this); + renamer.addConstraint(aDim1, bDim1, DimensionRenamer.Constraint.equal(false), this); } + @Override + public Select withInputs(List<IntermediateOperation> inputs) { + return new Select(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "Select"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java index 79f3012c327..57175092b5c 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java @@ -37,6 +37,11 @@ public class Shape extends IntermediateOperation { return true; } + @Override + public Shape withInputs(List<IntermediateOperation> inputs) { + return new Shape(modelName(), name(), inputs); + } + private void createConstantValue() { if (!allInputTypesPresent(1)) { return; @@ -50,4 +55,7 @@ public class Shape extends IntermediateOperation { this.setConstantValue(new TensorValue(builder.build())); } + @Override + public String operationName() { return "Shape"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java index cdacbe1656a..032ffb88a46 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java @@ -37,4 +37,12 @@ public class Softmax extends IntermediateOperation { return new com.yahoo.tensor.functions.Softmax(inputFunction, dimension); } + @Override + public Softmax withInputs(List<IntermediateOperation> inputs) { + return new Softmax(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "SoftMax"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java index 52d40144f61..56d9b542093 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java @@ -70,6 +70,11 @@ public class Squeeze extends IntermediateOperation { squeezeDimensions = renamedDimensions; } + @Override + public Squeeze withInputs(List<IntermediateOperation> inputs) { + return new Squeeze(modelName(), name(), inputs, attributeMap); + } + private OrderedTensorType reducedType(OrderedTensorType inputType) { OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType()); for (TensorType.Dimension dimension: inputType.type().dimensions()) { @@ -80,4 +85,7 @@ public class Squeeze extends IntermediateOperation { return builder.build(); } + @Override + public String operationName() { return "Squeeze"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java index 46b95233d11..c8cd235f50e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java @@ -107,4 +107,12 @@ public class Sum extends IntermediateOperation { return builder.build(); } + @Override + public Sum withInputs(List<IntermediateOperation> inputs) { + return new Sum(modelName(), name(), inputs, attributeMap); + } + + @Override + public String operationName() { return "Sum"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java index 39702690bfa..4beafc68909 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java @@ -42,6 +42,14 @@ public class Switch extends IntermediateOperation { return predicate == port ? inputs().get(0).function().get() : null; } + @Override + public Switch withInputs(List<IntermediateOperation> inputs) { + return new Switch(modelName(), name(), inputs, port); + } + + @Override + public String operationName() { return "Switch"; } + } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java index cf8dd6e8e71..793258868ee 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java @@ -9,7 +9,7 @@ public class DimensionRenamerTest { @Test public void testMnistRenaming() { - DimensionRenamer renamer = new DimensionRenamer(); + DimensionRenamer renamer = new DimensionRenamer(new IntermediateGraph("test")); renamer.addDimension("first_dimension_of_x"); renamer.addDimension("second_dimension_of_x"); @@ -18,17 +18,17 @@ public class DimensionRenamerTest { renamer.addDimension("first_dimension_of_b"); // which dimension to join on matmul - renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null); + renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer.Constraint.equal(false), null); // other dimensions in matmul can't be equal - renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null); + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer.Constraint.lessThan(false), null); // for efficiency, put dimension to join on innermost - renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null); - renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null); + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer.Constraint.lessThan(true), null); + renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer.Constraint.greaterThan(true), null); // bias - renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null); + renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer.Constraint.equal(false), null); renamer.solve(); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java new file mode 100644 index 00000000000..6500a380190 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java @@ -0,0 +1,28 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.tensorflow; + +import ai.vespa.rankingexpression.importer.ImportedModel; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; +import org.junit.Assert; +import org.junit.Test; + +import static org.junit.Assert.assertNotNull; + +/** + * @author bratseth + */ +public class Issue9662TestCase { + + @Test + public void testImporting() { + TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/9662"); + ImportedModel.Signature signature = model.get().signature("serving_default"); + Assert.assertEquals("Should have no skipped outputs", + 0, model.get().signature("serving_default").skippedOutputs().size()); + + ImportedMlFunction output = signature.outputFunction("output", "output"); + assertNotNull(output); + model.assertEqualResultSum("input_embedding_user_guid", "dense_out/MatMul", 0.00001); + } + +} diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java index 9d2f8cf0692..75fa2ed7933 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java @@ -49,7 +49,7 @@ public class TestableTensorFlowModel { public ImportedModel get() { return model; } - /** Compare that summing the tensors produce the same result to within some tolerance delta */ + /** Compare that computing the expressions produce the same result to within some tolerance delta */ public void assertEqualResultSum(String inputName, String operationName, double delta) { Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName); Context context = contextFrom(model); diff --git a/model-integration/src/test/models/tensorflow/9662/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/9662/saved_model.pbtxt new file mode 100644 index 00000000000..83c601edfc0 --- /dev/null +++ b/model-integration/src/test/models/tensorflow/9662/saved_model.pbtxt @@ -0,0 +1,1318 @@ +saved_model_schema_version: 1 +meta_graphs { + meta_info_def { + stripped_op_list { + op { + name: "Add" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_STRING + } + } + } + } + op { + name: "BiasAdd" + input_arg { + name: "value" + type_attr: "T" + } + input_arg { + name: "bias" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } + } + } + op { + name: "Const" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "value" + type: "tensor" + } + attr { + name: "dtype" + type: "type" + } + } + op { + name: "ExpandDims" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "dim" + type_attr: "Tdim" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tdim" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "MatMul" + input_arg { + name: "a" + type_attr: "T" + } + input_arg { + name: "b" + type_attr: "T" + } + output_arg { + name: "product" + type_attr: "T" + } + attr { + name: "transpose_a" + type: "bool" + default_value { + b: false + } + } + attr { + name: "transpose_b" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Maximum" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + is_commutative: true + } + op { + name: "Mul" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "Placeholder" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + } + op { + name: "PlaceholderWithDefault" + input_arg { + name: "input" + type_attr: "dtype" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + } + } + op { + name: "Rsqrt" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Sigmoid" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Square" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Sub" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Sum" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + } + tags: "serve" + tensorflow_version: "1.13.1" + tensorflow_git_version: "b\'v1.13.1-0-g6612da8951\'" + } + graph_def { + node { + name: "keras_learning_phase/input" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_BOOL + tensor_shape { + } + bool_val: false + } + } + } + } + node { + name: "keras_learning_phase" + op: "PlaceholderWithDefault" + input: "keras_learning_phase/input" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "Dot/l2_normalize/Maximum/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 9.999999960041972e-13 + } + } + } + } + node { + name: "Dot/l2_normalize/Sum/reduction_indices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "dense_out/kernel" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + } + float_val: 0.1835838258266449 + } + } + } + } + node { + name: "input_embedding_user_guid" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 32 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 32 + } + } + } + } + } + node { + name: "Dot/l2_normalize_1/Square" + op: "Square" + input: "input_embedding_user_guid" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 32 + } + } + } + } + } + } + node { + name: "Dot/l2_normalize/Sum" + op: "Sum" + input: "Dot/l2_normalize_1/Square" + input: "Dot/l2_normalize/Sum/reduction_indices" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: true + } + } + } + node { + name: "Dot/l2_normalize/Maximum" + op: "Maximum" + input: "Dot/l2_normalize/Sum" + input: "Dot/l2_normalize/Maximum/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "Dot/l2_normalize_1/Rsqrt" + op: "Rsqrt" + input: "Dot/l2_normalize/Maximum" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "Dot/l2_normalize_1" + op: "Mul" + input: "input_embedding_user_guid" + input: "Dot/l2_normalize_1/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 32 + } + } + } + } + } + } + node { + name: "Dot/Mul" + op: "Mul" + input: "Dot/l2_normalize_1" + input: "Dot/l2_normalize_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 32 + } + } + } + } + } + } + node { + name: "Dot/Sum" + op: "Sum" + input: "Dot/Mul" + input: "Dot/l2_normalize/Sum/reduction_indices" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "batch_normalization_v1/moving_variance" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 1.0 + } + } + } + } + node { + name: "Dot/ExpandDims" + op: "ExpandDims" + input: "Dot/Sum" + input: "Dot/l2_normalize/Sum/reduction_indices" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "batch_normalization_v1/moving_mean" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "batch_normalization_v1/batchnorm/add/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000474974513 + } + } + } + } + node { + name: "batch_normalization_v1/batchnorm/add" + op: "Add" + input: "batch_normalization_v1/moving_variance" + input: "batch_normalization_v1/batchnorm/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "batch_normalization_v1/batchnorm/Rsqrt" + op: "Rsqrt" + input: "batch_normalization_v1/batchnorm/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "batch_normalization_v1/batchnorm/mul" + op: "Mul" + input: "batch_normalization_v1/batchnorm/Rsqrt" + input: "batch_normalization_v1/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "batch_normalization_v1/batchnorm/mul_1" + op: "Mul" + input: "Dot/ExpandDims" + input: "batch_normalization_v1/batchnorm/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "batch_normalization_v1/batchnorm/mul_2" + op: "Mul" + input: "batch_normalization_v1/moving_mean" + input: "batch_normalization_v1/batchnorm/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "batch_normalization_v1/batchnorm/sub" + op: "Sub" + input: "batch_normalization_v1/moving_mean" + input: "batch_normalization_v1/batchnorm/mul_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "batch_normalization_v1/batchnorm/add_1" + op: "Add" + input: "batch_normalization_v1/batchnorm/mul_1" + input: "batch_normalization_v1/batchnorm/sub" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "dense_out/MatMul" + op: "MatMul" + input: "batch_normalization_v1/batchnorm/add_1" + input: "dense_out/kernel" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "dense_out/BiasAdd" + op: "BiasAdd" + input: "dense_out/MatMul" + input: "batch_normalization_v1/moving_mean" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + } + node { + name: "dense_out/Sigmoid" + op: "Sigmoid" + input: "dense_out/BiasAdd" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + versions { + } + } + signature_def { + key: "serving_default" + value { + inputs { + key: "input_embedding_user_guid" + value { + name: "input_embedding_user_guid:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 32 + } + } + } + } + outputs { + key: "output" + value { + name: "dense_out/Sigmoid:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + method_name: "tensorflow/serving/predict" + } + } +} diff --git a/searchlib/src/tests/attribute/reference_attribute/CMakeLists.txt b/searchlib/src/tests/attribute/reference_attribute/CMakeLists.txt index b5aa87e32f1..6638bf886b7 100644 --- a/searchlib/src/tests/attribute/reference_attribute/CMakeLists.txt +++ b/searchlib/src/tests/attribute/reference_attribute/CMakeLists.txt @@ -4,5 +4,6 @@ vespa_add_executable(searchlib_reference_attribute_test_app TEST reference_attribute_test.cpp DEPENDS searchlib + gtest ) vespa_add_test(NAME searchlib_reference_attribute_test_app COMMAND searchlib_reference_attribute_test_app) diff --git a/searchlib/src/tests/attribute/reference_attribute/reference_attribute_test.cpp b/searchlib/src/tests/attribute/reference_attribute/reference_attribute_test.cpp index e534153d004..d7428f02ba5 100644 --- a/searchlib/src/tests/attribute/reference_attribute/reference_attribute_test.cpp +++ b/searchlib/src/tests/attribute/reference_attribute/reference_attribute_test.cpp @@ -1,28 +1,39 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <vespa/log/log.h> -LOG_SETUP("reference_attribute_test"); -#include <vespa/vespalib/testkit/testapp.h> -#include <vespa/vespalib/test/insertion_operators.h> -#include <vespa/vespalib/util/traits.h> -#include <vespa/vespalib/io/fileutil.h> + +#include <vespa/document/base/documentid.h> #include <vespa/searchlib/attribute/attributeguard.h> #include <vespa/searchlib/attribute/reference_attribute.h> -#include <vespa/searchlib/common/i_gid_to_lid_mapper_factory.h> #include <vespa/searchlib/common/i_gid_to_lid_mapper.h> +#include <vespa/searchlib/common/i_gid_to_lid_mapper_factory.h> +#include <vespa/searchlib/fef/termfieldmatchdata.h> +#include <vespa/searchlib/query/queryterm.h> +#include <vespa/searchlib/queryeval/fake_result.h> +#include <vespa/searchlib/queryeval/searchiterator.h> #include <vespa/searchlib/test/mock_gid_to_lid_mapping.h> -#include <vespa/document/base/documentid.h> +#include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/io/fileutil.h> +#include <vespa/vespalib/test/insertion_operators.h> +#include <vespa/vespalib/util/traits.h> -using vespalib::MemoryUsage; -using vespalib::ArrayRef; +#include <vespa/log/log.h> +LOG_SETUP("reference_attribute_test"); + +using document::DocumentId; +using document::GlobalId; using generation_t = vespalib::GenerationHandler::generation_t; +using search::AttributeGuard; +using search::AttributeVector; +using search::QueryTermSimple; +using search::attribute::BasicType; +using search::attribute::Config; using search::attribute::Reference; using search::attribute::ReferenceAttribute; -using search::attribute::Config; -using search::attribute::BasicType; -using search::AttributeVector; -using search::AttributeGuard; -using document::GlobalId; -using document::DocumentId; +using search::attribute::SearchContextParams; +using search::fef::TermFieldMatchData; +using search::queryeval::FakeResult; +using search::queryeval::SearchIterator; +using vespalib::ArrayRef; +using vespalib::MemoryUsage; namespace { @@ -36,8 +47,7 @@ vespalib::string doc3("id:test:music::3"); } -struct MyGidToLidMapperFactory : public search::attribute::test::MockGidToLidMapperFactory -{ +struct MyGidToLidMapperFactory : public search::attribute::test::MockGidToLidMapperFactory { MyGidToLidMapperFactory() { _map.insert({toGid(doc1), 10}); _map.insert({toGid(doc2), 17}); @@ -55,8 +65,7 @@ struct MyGidToLidMapperFactory : public search::attribute::test::MockGidToLidMap } }; -class LidCollector -{ +class LidCollector { std::vector<uint32_t> &_lids; public: LidCollector(std::vector<uint32_t> &lids) @@ -66,16 +75,17 @@ public: void operator()(uint32_t lid) { _lids.push_back(lid); } }; -struct Fixture -{ +struct ReferenceAttributeTest : public ::testing::Test { std::shared_ptr<ReferenceAttribute> _attr; - Fixture() + ReferenceAttributeTest() : _attr() { resetAttr(); } + ~ReferenceAttributeTest() {} + AttributeVector &attr() { return *_attr; } @@ -122,36 +132,34 @@ struct Fixture void commit() { attr().commit(); } - void assertNoRef(uint32_t doc) - { + void assertNoRef(uint32_t doc) { EXPECT_TRUE(get(doc) == nullptr); } void assertRef(vespalib::stringref str, uint32_t doc) { const GlobalId *gid = get(doc); - EXPECT_TRUE(gid != nullptr); - EXPECT_EQUAL(toGid(str), *gid); + ASSERT_TRUE(gid != nullptr); + EXPECT_EQ(toGid(str), *gid); } void assertTargetLid(uint32_t doc, uint32_t expTargetLid) { auto ref = getRef(doc); - EXPECT_TRUE(ref != nullptr); - EXPECT_EQUAL(expTargetLid, ref->lid()); - EXPECT_EQUAL(expTargetLid, _attr->getTargetLid(doc)); + ASSERT_TRUE(ref != nullptr); + EXPECT_EQ(expTargetLid, ref->lid()); + EXPECT_EQ(expTargetLid, _attr->getTargetLid(doc)); } void assertNoTargetLid(uint32_t doc) { auto ref = getRef(doc); EXPECT_TRUE(ref == nullptr); - EXPECT_EQUAL(0u, _attr->getTargetLid(doc)); + EXPECT_EQ(0u, _attr->getTargetLid(doc)); } - void assertLids(uint32_t targetLid, std::vector<uint32_t> expLids) - { + void assertLids(uint32_t targetLid, std::vector<uint32_t> expLids) { std::vector<uint32_t> lids; LidCollector collector(lids); _attr->foreach_lid(targetLid, collector); - EXPECT_EQUAL(expLids, lids); + EXPECT_EQ(expLids, lids); } void save() { @@ -176,9 +184,8 @@ struct Fixture } oldStatus = newStatus; } - EXPECT_GREATER(iterLimit, iter); - LOG(info, - "iter = %" PRIu64 ", memory usage %" PRIu64 ", -> %" PRIu64, + EXPECT_GT(iterLimit, iter); + LOG(info, "iter = %" PRIu64 ", memory usage %" PRIu64 ", -> %" PRIu64, iter, oldStatus.getUsed(), newStatus.getUsed()); } @@ -197,137 +204,138 @@ struct Fixture } }; -TEST_F("require that we can instantiate reference attribute", Fixture) +TEST_F(ReferenceAttributeTest, reference_attribute_can_be_instantiated) { - f.ensureDocIdLimit(5); - f.set(1, toGid(doc1)); - f.set(2, toGid(doc2)); - f.commit(); - - TEST_DO(f.assertNoRef(3)); - TEST_DO(f.assertRef(doc1, 1)); - TEST_DO(f.assertRef(doc2, 2)); + ensureDocIdLimit(5); + set(1, toGid(doc1)); + set(2, toGid(doc2)); + commit(); + + assertNoRef(3); + assertRef(doc1, 1); + assertRef(doc2, 2); } -TEST_F("require that we can set new reference for a document", Fixture) +TEST_F(ReferenceAttributeTest, new_reference_for_a_document_can_be_set) { - f.ensureDocIdLimit(5); - f.set(1, toGid(doc1)); - f.set(2, toGid(doc2)); - f.set(3, toGid(doc2)); - f.commit(); - TEST_DO(f.assertNoRef(4)); - TEST_DO(f.assertRef(doc1, 1)); - TEST_DO(f.assertRef(doc2, 2)); - TEST_DO(f.assertRef(doc2, 3)); - f.set(2, toGid(doc1)); - f.commit(); - TEST_DO(f.assertNoRef(4)); - TEST_DO(f.assertRef(doc1, 1)); - TEST_DO(f.assertRef(doc1, 2)); - TEST_DO(f.assertRef(doc2, 3)); + ensureDocIdLimit(5); + set(1, toGid(doc1)); + set(2, toGid(doc2)); + set(3, toGid(doc2)); + commit(); + assertNoRef(4); + assertRef(doc1, 1); + assertRef(doc2, 2); + assertRef(doc2, 3); + set(2, toGid(doc1)); + commit(); + assertNoRef(4); + assertRef(doc1, 1); + assertRef(doc1, 2); + assertRef(doc2, 3); } -TEST_F("require that we can clear reference for a document", Fixture) +TEST_F(ReferenceAttributeTest, reference_for_a_document_can_be_cleared) { - f.ensureDocIdLimit(5); - f.set(2, toGid(doc2)); - f.commit(); - TEST_DO(f.assertRef(doc2, 2)); - f.clear(2); - f.commit(); - TEST_DO(f.assertNoRef(2)); - f.clear(2); - f.commit(); - TEST_DO(f.assertNoRef(2)); + ensureDocIdLimit(5); + set(2, toGid(doc2)); + commit(); + assertRef(doc2, 2); + clear(2); + commit(); + assertNoRef(2); + clear(2); + commit(); + assertNoRef(2); } -TEST_F("require that read guard protects reference", Fixture) +TEST_F(ReferenceAttributeTest, read_guard_protects_references) { - f.ensureDocIdLimit(5); - f.set(2, toGid(doc2)); - f.commit(); - const GlobalId *gid = f.get(2); - EXPECT_TRUE(gid != nullptr); - EXPECT_EQUAL(toGid(doc2), *gid); + ensureDocIdLimit(5); + set(2, toGid(doc2)); + commit(); + const GlobalId *gid = get(2); + ASSERT_TRUE(gid != nullptr); + EXPECT_EQ(toGid(doc2), *gid); { - AttributeGuard guard(f._attr); - f.clear(2); - f.commit(); - EXPECT_EQUAL(toGid(doc2), *gid); + AttributeGuard guard(_attr); + clear(2); + commit(); + EXPECT_EQ(toGid(doc2), *gid); } - f.commit(); - EXPECT_NOT_EQUAL(toGid(doc2), *gid); + commit(); + EXPECT_NE(toGid(doc2), *gid); } -TEST_F("require that we can compact attribute", Fixture) +TEST_F(ReferenceAttributeTest, attribute_can_be_compacted) { - f.ensureDocIdLimit(5); - f.set(1, toGid(doc1)); - f.set(2, toGid(doc2)); - f.commit(); - TEST_DO(f.triggerCompaction(100000)); - TEST_DO(f.assertNoRef(3)); - TEST_DO(f.assertRef(doc1, 1)); - TEST_DO(f.assertRef(doc2, 2)); + ensureDocIdLimit(5); + set(1, toGid(doc1)); + set(2, toGid(doc2)); + commit(); + triggerCompaction(100000); + assertNoRef(3); + assertRef(doc1, 1); + assertRef(doc2, 2); } -TEST_F("require that we can save and load attribute", Fixture) +TEST_F(ReferenceAttributeTest, attribute_can_be_saved_and_loaded) { - f.ensureDocIdLimit(5); - f.set(1, toGid(doc1)); - f.set(2, toGid(doc2)); - f.set(4, toGid(doc1)); - f.commit(); - f.save(); - f.load(); - EXPECT_EQUAL(5u, f.attr().getNumDocs()); - TEST_DO(f.assertNoRef(3)); - TEST_DO(f.assertRef(doc1, 1)); - TEST_DO(f.assertRef(doc2, 2)); - TEST_DO(f.assertRef(doc1, 4)); + ensureDocIdLimit(5); + set(1, toGid(doc1)); + set(2, toGid(doc2)); + set(4, toGid(doc1)); + commit(); + save(); + load(); + EXPECT_EQ(5u, attr().getNumDocs()); + assertNoRef(3); + assertRef(doc1, 1); + assertRef(doc2, 2); + assertRef(doc1, 4); EXPECT_TRUE(vespalib::unlink("test.dat")); EXPECT_TRUE(vespalib::unlink("test.udat")); } -TEST_F("require that update() uses gid-mapper to set target lid", Fixture) +TEST_F(ReferenceAttributeTest, update_uses_gid_mapper_to_set_target_lid) { - f.ensureDocIdLimit(6); + ensureDocIdLimit(6); auto factory = std::make_shared<MyGidToLidMapperFactory>(); - f.setGidToLidMapperFactory(factory); - f.set(1, toGid(doc1)); - f.set(2, toGid(doc2)); - f.set(4, toGid(doc1)); - f.set(5, toGid(doc3)); - f.commit(); - TEST_DO(f.assertTargetLid(1, 10)); - TEST_DO(f.assertTargetLid(2, 17)); - TEST_DO(f.assertNoTargetLid(3)); - TEST_DO(f.assertTargetLid(4, 10)); - TEST_DO(f.assertTargetLid(5, 0)); + setGidToLidMapperFactory(factory); + set(1, toGid(doc1)); + set(2, toGid(doc2)); + set(4, toGid(doc1)); + set(5, toGid(doc3)); + commit(); + assertTargetLid(1, 10); + assertTargetLid(2, 17); + assertNoTargetLid(3); + assertTargetLid(4, 10); + assertTargetLid(5, 0); } -TEST_F("require that notifyReferencedPut() updates lid-2-lid mapping", Fixture) +TEST_F(ReferenceAttributeTest, notifyReferencedPut_updates_lid_2_lid_mapping) { - f.ensureDocIdLimit(4); - f.set(1, toGid(doc1)); - f.set(2, toGid(doc2)); - f.set(3, toGid(doc1)); - f.commit(); - TEST_DO(f.assertTargetLid(1, 0)); - TEST_DO(f.assertTargetLid(2, 0)); - TEST_DO(f.assertTargetLid(3, 0)); - f.notifyReferencedPut(toGid(doc1), 10); - f.notifyReferencedPut(toGid(doc2), 20); - f.notifyReferencedPut(toGid(doc3), 30); - TEST_DO(f.assertTargetLid(1, 10)); - TEST_DO(f.assertTargetLid(2, 20)); - TEST_DO(f.assertTargetLid(3, 10)); + ensureDocIdLimit(4); + set(1, toGid(doc1)); + set(2, toGid(doc2)); + set(3, toGid(doc1)); + commit(); + assertTargetLid(1, 0); + assertTargetLid(2, 0); + assertTargetLid(3, 0); + notifyReferencedPut(toGid(doc1), 10); + notifyReferencedPut(toGid(doc2), 20); + notifyReferencedPut(toGid(doc3), 30); + assertTargetLid(1, 10); + assertTargetLid(2, 20); + assertTargetLid(3, 10); } namespace { -void preparePopulateTargetLids(Fixture &f) +void +preparePopulateTargetLids(ReferenceAttributeTest &f) { f.ensureDocIdLimit(6); f.set(1, toGid(doc1)); @@ -335,91 +343,136 @@ void preparePopulateTargetLids(Fixture &f) f.set(3, toGid(doc1)); f.set(4, toGid(doc3)); f.commit(); - TEST_DO(f.assertTargetLid(1, 0)); - TEST_DO(f.assertTargetLid(2, 0)); - TEST_DO(f.assertTargetLid(3, 0)); - TEST_DO(f.assertTargetLid(4, 0)); - TEST_DO(f.assertNoTargetLid(5)); + f.assertTargetLid(1, 0); + f.assertTargetLid(2, 0); + f.assertTargetLid(3, 0); + f.assertTargetLid(4, 0); + f.assertNoTargetLid(5); } -void checkPopulateTargetLids(Fixture &f) +void +checkPopulateTargetLids(ReferenceAttributeTest &f) { auto factory = std::make_shared<MyGidToLidMapperFactory>(); f.setGidToLidMapperFactory(factory); - TEST_DO(f.assertTargetLid(1, 10)); - TEST_DO(f.assertTargetLid(2, 17)); - TEST_DO(f.assertTargetLid(3, 10)); - TEST_DO(f.assertTargetLid(4, 0)); - TEST_DO(f.assertNoTargetLid(5)); - TEST_DO(f.assertLids(0, { })); - TEST_DO(f.assertLids(10, { 1, 3})); - TEST_DO(f.assertLids(17, { 2 })); - TEST_DO(f.assertLids(18, { })); + f.assertTargetLid(1, 10); + f.assertTargetLid(2, 17); + f.assertTargetLid(3, 10); + f.assertTargetLid(4, 0); + f.assertNoTargetLid(5); + f.assertLids(0, { }); + f.assertLids(10, { 1, 3}); + f.assertLids(17, { 2 }); + f.assertLids(18, { }); } } -TEST_F("require that populateTargetLids() uses gid-mapper to update lid-2-lid mapping", Fixture) +TEST_F(ReferenceAttributeTest, populateTargetLids_uses_gid_mapper_to_update_lid_2_lid_mapping) { - TEST_DO(preparePopulateTargetLids(f)); - TEST_DO(checkPopulateTargetLids(f)); + preparePopulateTargetLids(*this); + checkPopulateTargetLids(*this); } -TEST_F("require that populateTargetLids() uses gid-mapper to update lid-2-lid mapping after load", Fixture) +TEST_F(ReferenceAttributeTest, populateTargetLids_uses_gid_mapper_to_update_lid_2_lid_mapping_after_load) { - TEST_DO(preparePopulateTargetLids(f)); - f.save(); - f.load(); - TEST_DO(checkPopulateTargetLids(f)); + preparePopulateTargetLids(*this); + save(); + load(); + checkPopulateTargetLids(*this); EXPECT_TRUE(vespalib::unlink("test.dat")); EXPECT_TRUE(vespalib::unlink("test.udat")); } -TEST_F("Require that notifyReferencedPut and notifyReferencedRemove changes reverse mapping", Fixture) +TEST_F(ReferenceAttributeTest, notifyReferencedPut_and_notifyReferencedRemove_changes_reverse_mapping) { - TEST_DO(preparePopulateTargetLids(f)); - TEST_DO(f.assertLids(10, { })); - TEST_DO(f.assertLids(11, { })); - f.notifyReferencedPut(toGid(doc1), 10); - TEST_DO(f.assertLids(10, { 1, 3})); - TEST_DO(f.assertLids(11, { })); - f.notifyReferencedPut(toGid(doc1), 11); - TEST_DO(f.assertLids(10, { })); - TEST_DO(f.assertLids(11, { 1, 3})); - f.notifyReferencedRemove(toGid(doc1)); - TEST_DO(f.assertLids(10, { })); - TEST_DO(f.assertLids(11, { })); + preparePopulateTargetLids(*this); + assertLids(10, { }); + assertLids(11, { }); + notifyReferencedPut(toGid(doc1), 10); + assertLids(10, { 1, 3}); + assertLids(11, { }); + notifyReferencedPut(toGid(doc1), 11); + assertLids(10, { }); + assertLids(11, { 1, 3}); + notifyReferencedRemove(toGid(doc1)); + assertLids(10, { }); + assertLids(11, { }); } -TEST_F("Require that we track unique gids", Fixture) +TEST_F(ReferenceAttributeTest, unique_gids_are_tracked) { - EXPECT_EQUAL(0u, f.getUniqueGids()); - f.notifyReferencedPut(toGid(doc1), 10); - EXPECT_EQUAL(1u, f.getUniqueGids()); - f.ensureDocIdLimit(3); - f.set(1, toGid(doc1)); - f.commit(); - EXPECT_EQUAL(1u, f.getUniqueGids()); - TEST_DO(f.assertTargetLid(1, 10)); - TEST_DO(f.assertLids(10, { 1 })); - f.set(2, toGid(doc2)); - f.commit(); - EXPECT_EQUAL(2u, f.getUniqueGids()); - TEST_DO(f.assertTargetLid(2, 0)); - f.notifyReferencedPut(toGid(doc2), 17); - EXPECT_EQUAL(2u, f.getUniqueGids()); - TEST_DO(f.assertTargetLid(2, 17)); - TEST_DO(f.assertLids(17, { 2 })); - f.clear(1); - f.notifyReferencedRemove(toGid(doc2)); - EXPECT_EQUAL(2u, f.getUniqueGids()); - TEST_DO(f.assertNoTargetLid(1)); - TEST_DO(f.assertTargetLid(2, 0)); - TEST_DO(f.assertLids(10, { })); - TEST_DO(f.assertLids(17, { })); - f.clear(2); - f.notifyReferencedRemove(toGid(doc1)); - EXPECT_EQUAL(0u, f.getUniqueGids()); + EXPECT_EQ(0u, getUniqueGids()); + notifyReferencedPut(toGid(doc1), 10); + EXPECT_EQ(1u, getUniqueGids()); + ensureDocIdLimit(3); + set(1, toGid(doc1)); + commit(); + EXPECT_EQ(1u, getUniqueGids()); + assertTargetLid(1, 10); + assertLids(10, { 1 }); + set(2, toGid(doc2)); + commit(); + EXPECT_EQ(2u, getUniqueGids()); + assertTargetLid(2, 0); + notifyReferencedPut(toGid(doc2), 17); + EXPECT_EQ(2u, getUniqueGids()); + assertTargetLid(2, 17); + assertLids(17, { 2 }); + clear(1); + notifyReferencedRemove(toGid(doc2)); + EXPECT_EQ(2u, getUniqueGids()); + assertNoTargetLid(1); + assertTargetLid(2, 0); + assertLids(10, { }); + assertLids(17, { }); + clear(2); + notifyReferencedRemove(toGid(doc1)); + EXPECT_EQ(0u, getUniqueGids()); +} + +struct ReferenceAttributeSearchTest : public ReferenceAttributeTest { + + constexpr static uint32_t doc_id_limit = 6; + + ReferenceAttributeSearchTest() + : ReferenceAttributeTest() + { + ensureDocIdLimit(doc_id_limit); + set(1, toGid(doc1)); + set(3, toGid(doc2)); + set(4, toGid(doc1)); + commit(); + } + + FakeResult perform_search(SearchIterator& itr) { + FakeResult result; + itr.initFullRange(); + for (uint32_t doc_id = 1; doc_id < doc_id_limit; ++doc_id) { + if (itr.seek(doc_id)) { + result.doc(doc_id); + } + } + return result; + } + + void expect_search_result(const std::string& term, const FakeResult& expected) { + auto ctx = _attr->getSearch(std::make_unique<QueryTermSimple>(term, QueryTermSimple::WORD), + SearchContextParams()); + TermFieldMatchData tfmd; + auto itr = ctx->createIterator(&tfmd, false); + FakeResult actual = perform_search(*itr); + EXPECT_EQ(expected, actual); + } + +}; + +TEST_F(ReferenceAttributeSearchTest, can_be_searched_by_document_id) +{ + expect_search_result(doc1, FakeResult().doc(1).doc(4)); + expect_search_result(doc2, FakeResult().doc(3)); + expect_search_result(doc3, FakeResult()); + expect_search_result("invalid document id", FakeResult()); } -TEST_MAIN() { TEST_RUN_ALL(); } +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp b/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp index f91108d066b..82539214ea9 100644 --- a/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp +++ b/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp @@ -1,21 +1,28 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include "reference_attribute.h" -#include "reference_attribute_saver.h" #include "attributesaver.h" #include "readerbase.h" -#include <vespa/searchlib/common/i_gid_to_lid_mapper_factory.h> +#include "reference_attribute.h" +#include "reference_attribute_saver.h" +#include <vespa/document/base/documentid.h> +#include <vespa/document/base/idstringexception.h> #include <vespa/searchlib/common/i_gid_to_lid_mapper.h> -#include <vespa/vespalib/datastore/unique_store_builder.h> +#include <vespa/searchlib/common/i_gid_to_lid_mapper_factory.h> +#include <vespa/searchlib/query/queryterm.h> +#include <vespa/vespalib/data/fileheader.h> #include <vespa/vespalib/datastore/datastore.hpp> #include <vespa/vespalib/datastore/unique_store.hpp> -#include <vespa/vespalib/data/fileheader.h> +#include <vespa/vespalib/datastore/unique_store_builder.h> #include <vespa/log/log.h> LOG_SETUP(".searchlib.attribute.reference_attribute"); namespace search::attribute { +using document::DocumentId; +using document::GlobalId; +using document::IdParseException; + namespace { // minimum dead bytes in unique store before consider compaction @@ -265,7 +272,7 @@ ReferenceAttribute::update(DocId doc, const GlobalId &gid) } const Reference * -ReferenceAttribute::getReference(DocId doc) +ReferenceAttribute::getReference(DocId doc) const { assert(doc < _indices.size()); EntryRef ref = _indices[doc]; @@ -411,6 +418,56 @@ ReferenceAttribute::onShrinkLidSpace() setNumDocs(committedDocIdLimit); } +namespace { + +class ReferenceSearchContext : public AttributeVector::SearchContext { +private: + const ReferenceAttribute& _ref_attr; + GlobalId _term; + +public: + ReferenceSearchContext(const ReferenceAttribute& ref_attr, const GlobalId& term) + : AttributeVector::SearchContext(ref_attr), + _ref_attr(ref_attr), + _term(term) + { + } + bool valid() const override { + return _term != GlobalId(); + } + int32_t onFind(DocId docId, int32_t elementId, int32_t& weight) const override { + if (elementId != 0) { + return -1; + } + auto* ref = _ref_attr.getReference(docId); + if (ref == nullptr) { + return -1; + } + weight = 1; + return (_term == ref->gid()) ? 0 : -1; + } + int32_t onFind(DocId docId, int32_t elementId) const override { + int32_t weight; + return onFind(docId, elementId, weight); + } +}; + +} + +AttributeVector::SearchContext::UP +ReferenceAttribute::getSearch(QueryTermSimpleUP term, const attribute::SearchContextParams& params) const +{ + (void) params; + GlobalId gid; + try { + DocumentId docId(term->getTerm()); + gid = docId.getGlobalId(); + } catch (const IdParseException&) { + // The query term is not valid, which will result in an empty search iterator. + } + return std::make_unique<ReferenceSearchContext>(*this, gid); +} + IMPLEMENT_IDENTIFIABLE_ABSTRACT(ReferenceAttribute, AttributeVector); } diff --git a/searchlib/src/vespa/searchlib/attribute/reference_attribute.h b/searchlib/src/vespa/searchlib/attribute/reference_attribute.h index 87d5a5c27bb..87d624eb21f 100644 --- a/searchlib/src/vespa/searchlib/attribute/reference_attribute.h +++ b/searchlib/src/vespa/searchlib/attribute/reference_attribute.h @@ -3,8 +3,8 @@ #pragma once #include "not_implemented_attribute.h" -#include "reference_mappings.h" #include "reference.h" +#include "reference_mappings.h" #include <vespa/vespalib/datastore/unique_store.h> #include <vespa/vespalib/util/rcuvector.h> @@ -71,7 +71,7 @@ public: bool addDoc(DocId &doc) override; uint32_t clearDoc(DocId doc) override; void update(DocId doc, const GlobalId &gid); - const Reference *getReference(DocId doc); + const Reference *getReference(DocId doc) const; void setGidToLidMapperFactory(std::shared_ptr<IGidToLidMapperFactory> gidToLidMapperFactory); std::shared_ptr<IGidToLidMapperFactory> getGidToLidMapperFactory() const { return _gidToLidMapperFactory; } TargetLids getTargetLids() const { return _referenceMappings.getTargetLids(); } @@ -91,6 +91,8 @@ public: foreach_lid(uint32_t targetLid, FunctionType &&func) const { _referenceMappings.foreach_lid(targetLid, std::forward<FunctionType>(func)); } + + SearchContext::UP getSearch(QueryTermSimpleUP term, const attribute::SearchContextParams& params) const override; }; } diff --git a/security-tools/src/main/java/com/yahoo/vespa/security/tool/securityenv/Main.java b/security-tools/src/main/java/com/yahoo/vespa/security/tool/securityenv/Main.java index ae18700246c..367d7b9dd83 100644 --- a/security-tools/src/main/java/com/yahoo/vespa/security/tool/securityenv/Main.java +++ b/security-tools/src/main/java/com/yahoo/vespa/security/tool/securityenv/Main.java @@ -51,17 +51,15 @@ public class Main { Map<OutputVariable, String> outputVariables = new TreeMap<>(); Optional<TransportSecurityOptions> options = TransportSecurityUtils.getOptions(envVars); - if (options.isPresent()) { + MixedMode mixedMode = TransportSecurityUtils.getInsecureMixedMode(envVars); + if (options.isPresent() && mixedMode != MixedMode.PLAINTEXT_CLIENT_MIXED_SERVER) { outputVariables.put(OutputVariable.TLS_ENABLED, "1"); options.get().getCaCertificatesFile() .ifPresent(caCertFile -> outputVariables.put(OutputVariable.CA_CERTIFICATE, caCertFile.toString())); - MixedMode mixedMode = TransportSecurityUtils.getInsecureMixedMode(envVars); - if (mixedMode != MixedMode.PLAINTEXT_CLIENT_MIXED_SERVER) { - options.get().getCertificatesFile() - .ifPresent(certificateFile -> outputVariables.put(OutputVariable.CERTIFICATE, certificateFile.toString())); - options.get().getPrivateKeyFile() - .ifPresent(privateKeyFile -> outputVariables.put(OutputVariable.PRIVATE_KEY, privateKeyFile.toString())); - } + options.get().getCertificatesFile() + .ifPresent(certificateFile -> outputVariables.put(OutputVariable.CERTIFICATE, certificateFile.toString())); + options.get().getPrivateKeyFile() + .ifPresent(privateKeyFile -> outputVariables.put(OutputVariable.PRIVATE_KEY, privateKeyFile.toString())); } shell.writeOutputVariables(stdOut, outputVariables); EnumSet<OutputVariable> unusedVariables = outputVariables.isEmpty() diff --git a/security-tools/src/main/sh/vespa-curl-wrapper b/security-tools/src/main/sh/vespa-curl-wrapper index 7c2f31d7719..da857984c01 100755 --- a/security-tools/src/main/sh/vespa-curl-wrapper +++ b/security-tools/src/main/sh/vespa-curl-wrapper @@ -6,26 +6,23 @@ set -e -. $(vespa-security-env) +eval $(vespa-security-env) -CURL_PARAMETERS=$1 -CONFIGSERVER_URI_WITHOUT_SCHEME=$2 +CURL_PARAMETERS=("$@") if [ -n "${VESPA_TLS_ENABLED}" ] then - CONFIGSERVER_URI="https://${CONFIGSERVER_URI_WITHOUT_SCHEME}" -else - CONFIGSERVER_URI="http://${CONFIGSERVER_URI_WITHOUT_SCHEME}" + CURL_PARAMETERS=("${CURL_PARAMETERS[@]/http:/https:}") fi if [ -n "${VESPA_TLS_CA_CERT}" ] then - CURL_PARAMETERS="--cacert \"${VESPA_TLS_CA_CERT}\" ${CURL_PARAMETERS}" + CURL_PARAMETERS=("--cacert" "${VESPA_TLS_CA_CERT}" "${CURL_PARAMETERS[@]}") fi if [[ -n "${VESPA_TLS_CERT}" && -n "${VESPA_TLS_PRIVATE_KEY}" ]] then - CURL_PARAMETERS="--cert \"${VESPA_TLS_CERT}\" --key \"${VESPA_TLS_PRIVATE_KEY}\" ${CURL_PARAMETERS}" + CURL_PARAMETERS=("--cert" "${VESPA_TLS_CERT}" "--key" "${VESPA_TLS_PRIVATE_KEY}" "${CURL_PARAMETERS[@]}") fi -curl ${CURL_PARAMETERS} "${CONFIGSERVER_URI}" +curl "${CURL_PARAMETERS[@]}" diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index a16127931e9..6f37b9edea4 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -2576,6 +2576,20 @@ ], "fields": [] }, + "com.yahoo.text.ExpressionFormatter": { + "superClass": "java.lang.Object", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public java.lang.String format(java.lang.String)", + "public static java.lang.String on(java.lang.String)", + "public static com.yahoo.text.ExpressionFormatter withLineLength(int)", + "public static com.yahoo.text.ExpressionFormatter inTwoColumnMode(int, int)" + ], + "fields": [] + }, "com.yahoo.text.ForwardWriter": { "superClass": "com.yahoo.text.GenericWriter", "interfaces": [], diff --git a/vespajlib/src/main/java/com/yahoo/collections/ListMap.java b/vespajlib/src/main/java/com/yahoo/collections/ListMap.java index e851362a99d..479850beb1a 100644 --- a/vespajlib/src/main/java/com/yahoo/collections/ListMap.java +++ b/vespajlib/src/main/java/com/yahoo/collections/ListMap.java @@ -23,6 +23,12 @@ public class ListMap<K, V> { this(HashMap.class); } + /** Copy constructor. This will not be frozen even if the argument map is */ + public ListMap(ListMap<K, V> original) { + map = new HashMap<>(); + original.map.forEach((k, v) -> this.map.put(k, new ArrayList<>(v))); + } + @SuppressWarnings("unchecked") public ListMap(@SuppressWarnings("rawtypes") Class<? extends Map> implementation) { try { @@ -45,6 +51,27 @@ public class ListMap<K, V> { list.add(value); } + /** Put a key without adding a new value, such that there is an empty list of values if no values are already added */ + public void put(K key) { + List<V> list = map.get(key); + if (list == null) { + list = new ArrayList<>(); + map.put(key, list); + } + } + + /** Put this map in the state where it has just the given value of the given key */ + public void replace(K key, V value) { + List<V> list = map.get(key); + if (list == null) { + put(key); + } + else { + list.clear(); + list.add(value); + } + } + public void removeAll(K key) { map.remove(key); } @@ -73,13 +100,13 @@ public class ListMap<K, V> { /** * Returns the List containing the elements with this key, or an empty list - * if there are no elements for this key. The list returned is unmodifiable. + * if there are no elements for this key. + * The returned list can be modified to add and remove values if the value exists. */ public List<V> get(K key) { List<V> list = map.get(key); - if (list == null) - return ImmutableList.of();; - return ImmutableList.copyOf(list); + if (list == null) return ImmutableList.of();; + return list; } /** The same as get */ diff --git a/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java b/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java new file mode 100644 index 00000000000..280b75f9cbb --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java @@ -0,0 +1,180 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.text; + +/** + * Formats any parenthesis expression. + * In addition to the obvious this can also operate in "two column mode", + * wherein each chunk that will be formatted on a separate line may optionally + * contain a prefix marked by a start and end tab sign which will be printed in a left column of the given fixed size. + * The prefix itself is not formatted but will be cut if too long. + * + * @author bratseth + */ +public class ExpressionFormatter { + + private static final int indentUnit = 2; + + /** The size of the first column, or 0 if none */ + private final int firstColumnLength; + + /** + * The desired size of the second column (or the entire line if no first column), + * or 0 to split into multiple lines as much as possible. + * Setting this collects larger chunks to one line across markup + * but will not split too long lines that have no markup. + */ + private final int secondColumnLength; + + private ExpressionFormatter(int firstColumnLength, int secondColumnLength) { + this.firstColumnLength = firstColumnLength; + this.secondColumnLength = secondColumnLength; + } + + public String format(String parenthesisExpression) { + StringBuilder b = new StringBuilder(); + format(parenthesisExpression, 0, b); + while (b.length() > 0 && Character.isWhitespace(b.charAt(b.length() - 1))) + b.setLength(b.length() - 1); + return b.toString(); + } + + private void format(String expression, int indent, StringBuilder b) { + if (expression.isEmpty()) return; + expression = appendFirstColumn(expression, b); + + Markup next = Markup.next(expression); + + appendIndent( ! next.isClose() || next.position() > 0 ? indent : indent - 2, b); + + int endOfBalancedChunk = endOfBalancedChunk(expression, Math.max(0, secondColumnLength - indent)); + if (next.isEmpty()) { + b.append(expression); + } + else if (endOfBalancedChunk > 0) { + b.append(expression, 0, endOfBalancedChunk + 1).append("\n"); + format(expression.substring(endOfBalancedChunk + 1), indent, b); + } + else if (next.isComma()) { + b.append(expression, 0, next.position() + 1).append("\n"); + format(expression.substring(next.position() + 1), indent, b); + } + else { + if ( next.isClose() && next.position() > 0) { // content before end parenthesis: content, newline, then end parenthesis + b.append(expression, 0, next.position()).append("\n"); + appendFirstColumn(")", b); + appendIndent(indent - 2, b); + b.append(")\n"); + } + else { + b.append(expression, 0, next.position() + 1).append("\n"); + } + format(expression.substring(next.position() + 1), indent + (next.isOpen() ? indentUnit : -indentUnit), b); + } + } + + /** Returns the position of the end of a balanced chunk of at most the given size, or 0 if there is no such chunk */ + private int endOfBalancedChunk(String expression, int maxSize) { + int chunkSize = 0; + int i = 0; + int nesting = 0; + while (i < maxSize && i < expression.length()) { + if (expression.charAt(i) == '\t') return chunkSize; + if (expression.charAt(i) == '(') nesting++; + if (expression.charAt(i) == ')') nesting--; + if (nesting < 0) return chunkSize; + if (nesting == 0 && ( expression.charAt(i)==')' || expression.charAt(i)==',')) + chunkSize = i; + i++; + } + return chunkSize; + } + + private String appendFirstColumn(String expression, StringBuilder b) { + if (firstColumnLength == 0) return expression; + + while (expression.charAt(0) == ' ') + expression = expression.substring(1); + + if (expression.charAt(0) == '\t') { + int tab2 = expression.indexOf('\t', 1); + if (tab2 >= 0) { + String firstColumn = expression.substring(1, tab2); + b.append(asSize(firstColumnLength, firstColumn)).append(" "); + return expression.substring(tab2 + 1); + } + } + appendIndent(firstColumnLength + 1, b); + return expression; + } + + private void appendIndent(int indent, StringBuilder b) { + b.append(" ".repeat(Math.max(0, indent))); + } + + private String asSize(int size, String s) { + if (s.length() > size) + return s.substring(0, size); + else + return s + " ".repeat(size - s.length()); + } + + /** Convenience method creating a formatter and using it to format the given expression */ + public static String on(String parenthesisExpression) { + return new ExpressionFormatter(0, 80).format(parenthesisExpression); + } + + public static ExpressionFormatter withLineLength(int maxLineLength) { + return new ExpressionFormatter(0, maxLineLength); + } + + public static ExpressionFormatter inTwoColumnMode(int firstColumnSize, int secondColumnSize) { + return new ExpressionFormatter(firstColumnSize, secondColumnSize); + } + + /** Contains the next position of each kind of markup, or Integer.MAX_VALUE if not present */ + private static class Markup { + + final int open, close, comma; + + private Markup(int open, int close, int comma) { + this.open = open; + this.close = close; + this.comma = comma; + } + + int position() { + return Math.min(Math.min(open, close), comma); + } + + boolean isOpen() { + return open < close && open < comma; + } + + boolean isClose() { + return close < open && close < comma; + } + + boolean isComma() { + return comma < open && comma < close; + } + + boolean isEmpty() { + return open == Integer.MAX_VALUE && close == Integer.MAX_VALUE && comma == Integer.MAX_VALUE; + } + + static Markup next(String expression) { + int nextOpen = expression.indexOf('('); + int nextClose = expression.indexOf(')'); + int nextComma = expression.indexOf(','); + if (nextOpen < 0) + nextOpen = Integer.MAX_VALUE; + if (nextClose < 0) + nextClose = Integer.MAX_VALUE; + if (nextComma < 0) + nextComma = Integer.MAX_VALUE; + return new Markup(nextOpen, nextClose, nextComma); + } + + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java b/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java new file mode 100644 index 00000000000..7251ccef521 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java @@ -0,0 +1,190 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.text; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class ExpressionFormatterTest { + + @Test + public void testBasic() { + String expected = + "foo(\n" + + " bar(\n" + + " baz(\n" + + " )\n" + + " )\n" + + ")"; + assertPrettyPrint(expected, "foo(bar(baz()))", 0); + } + + @Test + public void testBasicDense() { + assertPrettyPrint("foo(bar(baz()))", "foo(bar(baz()))", 50); + } + + @Test + public void testArgument() { + String expected = + "foo(\n" + + " bar(\n" + + " baz(\n" + + " hello world\n" + + " )\n" + + " )\n" + + ")"; + assertPrettyPrint(expected, "foo(bar(baz(hello world)))", 0); + } + + @Test + public void testMultipleArguments() { + String expected = + "foo(\n" + + " bar(\n" + + " baz(\n" + + " hello world,\n" + + " 37\n" + + " )\n" + + " )\n" + + ")"; + assertPrettyPrint(expected, "foo(bar(baz(hello world,37)))", 0); + } + + @Test + public void testMultipleArgumentsSemiDense() { + String expected = + "foo(\n" + + " bar(\n" + + " baz(hi,37),\n" + + " baz(\n" + + " hello world,\n" + + " 37\n" + + " )\n" + + " )\n" + + ")"; + assertPrettyPrint(expected, "foo(bar(baz(hi,37),baz(hello world,37)))", 15); + } + + @Test + public void testUnmatchedStart() { + String expected = + "foo(\n" + + " (\n" + + " bar(\n" + + " baz(\n" + + " )\n" + + " )\n" + + " )"; + assertPrettyPrint(expected, "foo((bar(baz()))", 0); + } + + @Test + public void testUnmatchedEnd() { + String expected = + "foo(\n" + + " bar(\n" + + " baz(\n" + + " )\n" + + " )\n" + + ")\n" + + ")"; + assertPrettyPrint(expected, "foo(bar(baz())))", 0); + } + + @Test + public void testNoParenthesis() { + String expected = + "foo bar baz"; + assertPrettyPrint(expected, "foo bar baz", 0); + } + + @Test + public void testEmpty() { + String expected = + ""; + assertPrettyPrint(expected, "", 0); + } + + @Test + public void test2ColumnMode() { + String expected = + "1: foo(\n" + + " bar(\n" + + " baz(\n" + + "2: hello world\n" + + " )\n" + + "t(o )\n" + + " )"; + ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 0); + assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(\t2:\thello world)\tt(o)@olong:\t))")); + } + + @Test + public void test2ColumnModeMultipleArguments() { + String expected = + "1: foo(\n" + + " bar(\n" + + " baz(\n" + + "2: hello world,\n" + + "3: 37\n" + + " )\n" + + "t(o )\n" + + " )"; + ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 0); + assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(\t2:\thello world,\t3:\t37)\tt(o)@olong:\t))")); + } + + @Test + public void test2ColumnModeMultipleArgumentsSemiDense() { + String expected = + "1: foo(\n" + + " bar(\n" + + " baz(hi,37),\n" + + " boz(\n" + + "2: hello world,\n" + + "3: 5\n" + + " )\n" + + "t(o )\n" + + " )"; + ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 15); + assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(hi,37),boz(\t2:\thello world,\t3:\t5)\tt(o)@olong:\t))")); + } + + @Test + public void test2ColumnModeMultipleArgumentsWithSpaces() { + String expected = + " foo(\n" + + "1: bar(\n" + + " baz(\n" + + "2: hello world,\n" + + "3: 37\n" + + " )\n" + + "t(o )\n" + + " )"; + ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 0); + assertEquals(expected, pp.format("foo(\t1:\tbar(baz(\t2:\thello world, \t3:\t37)\tt(o)@olong:\t))")); + } + + @Test + public void testTwoColumnLambdaFunction() { + String expected = + " join(\n" + + " a,\n" + + " join(\n" + + " b, c, f(a, b)(a * b)\n" + + " )\n" + + " , f(a, b)(a * b)\n" + + " )"; + ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(5, 25); + assertEquals(expected, pp.format("join(a, join(b, c, f(a, b)(a * b)), f(a, b)(a * b))")); + } + + private void assertPrettyPrint(String expected, String expression, int lineLength) { + assertEquals(expected, ExpressionFormatter.withLineLength(lineLength).format(expression)); + } + +} |