diff options
86 files changed, 963 insertions, 561 deletions
diff --git a/config-application-package/src/main/java/com/yahoo/config/model/application/AbstractApplicationPackage.java b/config-application-package/src/main/java/com/yahoo/config/model/application/AbstractApplicationPackage.java new file mode 100644 index 00000000000..c616784c7be --- /dev/null +++ b/config-application-package/src/main/java/com/yahoo/config/model/application/AbstractApplicationPackage.java @@ -0,0 +1,44 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.config.model.application; + +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.application.Xml; +import org.w3c.dom.Document; +import org.w3c.dom.Element; +import org.w3c.dom.Node; +import org.w3c.dom.NodeList; + +import java.util.Map; +import java.util.HashMap; + +/** + * Common code for all implementations of ApplicationPackage + * + * @author arnej + */ +public abstract class AbstractApplicationPackage implements ApplicationPackage { + + @Override + public Map<String,String> legacyOverrides() { + Map<String, String> result = new HashMap<>(); + try { + Document services = Xml.getDocument(getServices()); + NodeList legacyNodes = services.getElementsByTagName("legacy"); + for (int i=0; i < legacyNodes.getLength(); i++) { + var flagNodes = legacyNodes.item(i).getChildNodes(); + for (int j = 0; j < flagNodes.getLength(); ++j) { + var flagNode = flagNodes.item(j); + if (flagNode.getNodeType() == Node.ELEMENT_NODE) { + String key = flagNode.getNodeName(); + String value = flagNode.getTextContent(); + result.put(key, value); + } + } + } + } catch (Exception e) { + // nothing: This method does not validate that services.xml exists, or that it is valid xml. + } + return result; + } + +} diff --git a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java index 1223f438029..4cde4e7afaa 100644 --- a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java +++ b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java @@ -18,6 +18,7 @@ import com.yahoo.config.provision.ApplicationName; import com.yahoo.config.provision.InstanceName; import com.yahoo.config.provision.TenantName; import com.yahoo.config.provision.Zone; +import com.yahoo.config.model.application.AbstractApplicationPackage; import com.yahoo.io.HexDump; import com.yahoo.io.IOUtils; import com.yahoo.io.reader.NamedReader; @@ -72,7 +73,7 @@ import static com.yahoo.text.Lowercase.toLowerCase; * * @author Vegard Havdal */ -public class FilesApplicationPackage implements ApplicationPackage { +public class FilesApplicationPackage extends AbstractApplicationPackage { /** * The name of the subdirectory (below the original application package root) diff --git a/config-application-package/src/test/java/com/yahoo/config/model/application/provider/FilesApplicationPackageTest.java b/config-application-package/src/test/java/com/yahoo/config/model/application/provider/FilesApplicationPackageTest.java index ec68ed73864..ae6f9373e16 100644 --- a/config-application-package/src/test/java/com/yahoo/config/model/application/provider/FilesApplicationPackageTest.java +++ b/config-application-package/src/test/java/com/yahoo/config/model/application/provider/FilesApplicationPackageTest.java @@ -103,6 +103,16 @@ public class FilesApplicationPackageTest { } @Test + public void testLegacyOverrides() throws IOException { + File appDir = new File("src/test/resources/app-legacy-overrides"); + ApplicationPackage app = FilesApplicationPackage.fromFile(appDir); + var overrides = app.legacyOverrides(); + assertEquals(2, overrides.size()); + assertEquals("something here", overrides.get("foo-bar")); + assertEquals("false", overrides.get("v7-geo-positions")); + } + + @Test public void failOnEmptyServicesXml() throws IOException { File appDir = temporaryFolder.newFolder(); IOUtils.copyDirectory(new File("src/test/resources/multienvapp"), appDir); diff --git a/config-application-package/src/test/resources/app-legacy-overrides/hosts.xml b/config-application-package/src/test/resources/app-legacy-overrides/hosts.xml new file mode 100644 index 00000000000..64a07644038 --- /dev/null +++ b/config-application-package/src/test/resources/app-legacy-overrides/hosts.xml @@ -0,0 +1,10 @@ +<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<hosts xmlns:deploy="vespa" xmlns:preprocess="properties"> + <preprocess:properties> + <node1.hostname>foo.yahoo.com</node1.hostname> + <node1.hostname deploy:environment="dev">bar.yahoo.com</node1.hostname> + </preprocess:properties> + <host name="${node1.hostname}"> + <alias>node1</alias> + </host> +</hosts> diff --git a/config-application-package/src/test/resources/app-legacy-overrides/schemas/music.sd b/config-application-package/src/test/resources/app-legacy-overrides/schemas/music.sd new file mode 100644 index 00000000000..7da7c49c162 --- /dev/null +++ b/config-application-package/src/test/resources/app-legacy-overrides/schemas/music.sd @@ -0,0 +1,8 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +search music { + document music { + field f type string { + indexing: index | summary + } + } +} diff --git a/config-application-package/src/test/resources/app-legacy-overrides/services.xml b/config-application-package/src/test/resources/app-legacy-overrides/services.xml new file mode 100644 index 00000000000..5f8201336ef --- /dev/null +++ b/config-application-package/src/test/resources/app-legacy-overrides/services.xml @@ -0,0 +1,16 @@ +<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<services version='1.0'> + <legacy> + <foo-bar>something here</foo-bar> + <v7-geo-positions>false</v7-geo-positions> + </legacy> + <admin version='2.0'> + <adminserver hostalias='node0'/> + </admin> + <content version='1.0' id='foo'> + <redundancy>1</redundancy> + <documents> + <document type="music.sd" mode="index" /> + </documents> + </content> +</services> diff --git a/config-model-api/abi-spec.json b/config-model-api/abi-spec.json index 8ad9f66ee6a..ddc8c60ca31 100644 --- a/config-model-api/abi-spec.json +++ b/config-model-api/abi-spec.json @@ -127,6 +127,7 @@ "public void writeMetaData()", "public java.util.Optional getAllocatedHosts()", "public java.util.Map getFileRegistries()", + "public java.util.Map legacyOverrides()", "public java.util.Collection getSearchDefinitions()", "public abstract java.util.Collection getSchemas()", "public com.yahoo.config.application.api.ApplicationPackage preprocess(com.yahoo.config.provision.Zone, com.yahoo.config.application.api.DeployLogger)" diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java index 5c36de38c9b..d07df82fda1 100644 --- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java +++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java @@ -238,6 +238,10 @@ public interface ApplicationPackage { return Collections.emptyMap(); } + default Map<String, String> legacyOverrides() { + return Collections.emptyMap(); + } + /** * @deprecated use {@link #getSchemas()} instead */ diff --git a/config-model/pom.xml b/config-model/pom.xml index 18b8432645e..d42d5af8975 100644 --- a/config-model/pom.xml +++ b/config-model/pom.xml @@ -306,8 +306,6 @@ <arg>-Xlint:-rawtypes</arg> <arg>-Xlint:-unchecked</arg> <arg>-Xlint:-serial</arg> - <arg>-Xlint:-cast</arg> - <arg>-Xlint:-overloads</arg> <arg>-Werror</arg> </compilerArgs> </configuration> diff --git a/config-model/src/main/java/com/yahoo/config/model/ConfigModelRepo.java b/config-model/src/main/java/com/yahoo/config/model/ConfigModelRepo.java index 4591578d7e5..8d192414871 100644 --- a/config-model/src/main/java/com/yahoo/config/model/ConfigModelRepo.java +++ b/config-model/src/main/java/com/yahoo/config/model/ConfigModelRepo.java @@ -133,6 +133,10 @@ public class ConfigModelRepo implements ConfigModelRepoAdder, Serializable, Iter for (Element servicesElement : children) { String tagName = servicesElement.getTagName(); + if (tagName.equals("legacy")) { + // for enabling legacy features from old vespa versions + continue; + } if (tagName.equals("config")) { // TODO: disallow on Vespa 8 continue; diff --git a/config-model/src/main/resources/schema/services.rnc b/config-model/src/main/resources/schema/services.rnc index 758fa107ee8..c8467898639 100644 --- a/config-model/src/main/resources/schema/services.rnc +++ b/config-model/src/main/resources/schema/services.rnc @@ -12,6 +12,7 @@ include "legacygenericcluster.rnc" start = element services { attribute version { "1.0" }? & attribute application-type { "hosted-infrastructure" }? & + element legacy { element v7-geo-positions { xsd:boolean } }? & LegacyGenericCluster* & GenericCluster* & GenericConfig* & diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ActivatedModelsBuilder.java b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ActivatedModelsBuilder.java index 687c897c88b..54ceb394ee6 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ActivatedModelsBuilder.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ActivatedModelsBuilder.java @@ -102,7 +102,7 @@ public class ActivatedModelsBuilder extends ModelsBuilder<Application> { ) { log.log(Level.FINE, () -> String.format("Loading model version %s for session %s application %s", modelFactory.version(), applicationGeneration, applicationId)); - ModelContext.Properties modelContextProperties = createModelContextProperties(applicationId); + ModelContext.Properties modelContextProperties = createModelContextProperties(applicationId, applicationPackage); Provisioned provisioned = new Provisioned(); ModelContext modelContext = new ModelContextImpl( applicationPackage, @@ -146,14 +146,14 @@ public class ActivatedModelsBuilder extends ModelsBuilder<Application> { return Optional.of(value); } - private ModelContext.Properties createModelContextProperties(ApplicationId applicationId) { + private ModelContext.Properties createModelContextProperties(ApplicationId applicationId, ApplicationPackage applicationPackage) { return new ModelContextImpl.Properties(applicationId, configserverConfig, zone(), ImmutableSet.copyOf(new ContainerEndpointsCache(TenantRepository.getTenantPath(tenant), curator).read(applicationId)), false, // We may be bootstrapping, but we only know and care during prepare false, // Always false, assume no one uses it when activating - flagSource, + LegacyFlags.from(applicationPackage, flagSource), new EndpointCertificateMetadataStore(curator, TenantRepository.getTenantPath(tenant)) .readEndpointCertificateMetadata(applicationId) .flatMap(new EndpointCertificateRetriever(secretStore)::readEndpointCertificateSecrets), diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/LegacyFlags.java b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/LegacyFlags.java new file mode 100644 index 00000000000..80467c80196 --- /dev/null +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/LegacyFlags.java @@ -0,0 +1,46 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.config.server.modelfactory; + +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.vespa.flags.Flags; +import com.yahoo.vespa.flags.FlagSource; +import com.yahoo.vespa.flags.InMemoryFlagSource; +import com.yahoo.vespa.flags.OrderedFlagSource; + +import java.util.Map; + + +/** + * @author arnej + */ +public class LegacyFlags { + + public static final String GEO_POSITIONS = "v7-geo-positions"; + public static final String FOO_BAR = "foo-bar"; // for testing + + private static FlagSource buildFrom(Map<String, String> legacyOverrides) { + var flags = new InMemoryFlagSource(); + for (var entry : legacyOverrides.entrySet()) { + String key = entry.getKey(); + String value = entry.getValue(); + boolean legacyWanted = Boolean.valueOf(value); + switch (key) { + case GEO_POSITIONS: + flags = flags.withBooleanFlag(Flags.USE_V8_GEO_POSITIONS.id(), ! legacyWanted); + break; + case FOO_BAR: + // ignored + break; + default: + throw new IllegalArgumentException("Unknown legacy override: "+key); + } + } + return flags; + } + + public static FlagSource from(ApplicationPackage pkg, FlagSource input) { + var overrides = buildFrom(pkg.legacyOverrides()); + FlagSource result = new OrderedFlagSource(overrides, input); + return result; + } +} diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java index aaacc9f69e0..e4a0fa81f94 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java @@ -35,6 +35,7 @@ import com.yahoo.vespa.config.server.deploy.ZooKeeperDeployer; import com.yahoo.vespa.config.server.filedistribution.FileDistributionFactory; import com.yahoo.vespa.config.server.host.HostValidator; import com.yahoo.vespa.config.server.http.InvalidApplicationException; +import com.yahoo.vespa.config.server.modelfactory.LegacyFlags; import com.yahoo.vespa.config.server.modelfactory.ModelFactoryRegistry; import com.yahoo.vespa.config.server.modelfactory.PreparedModelsBuilder; import com.yahoo.vespa.config.server.provision.HostProvisionerProvider; @@ -195,7 +196,7 @@ public class SessionPreparer { Set.copyOf(containerEndpoints), params.isBootstrap(), currentActiveApplicationSet.isEmpty(), - flagSource, + LegacyFlags.from(applicationPackage, flagSource), endpointCertificateSecrets, athenzDomain, params.quota(), diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackage.java b/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackage.java index 806b67758c2..6a483c38aee 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackage.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackage.java @@ -10,6 +10,7 @@ import com.yahoo.config.application.api.ComponentInfo; import com.yahoo.config.application.api.FileRegistry; import com.yahoo.config.application.api.UnparsedConfigDefinition; import com.yahoo.config.codegen.DefParser; +import com.yahoo.config.model.application.AbstractApplicationPackage; import com.yahoo.config.provision.AllocatedHosts; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.serialization.AllocatedHostsSerializer; @@ -44,7 +45,7 @@ import static com.yahoo.vespa.config.server.zookeeper.ZKApplication.USERAPP_ZK_S * * @author Tony Vaagenes */ -public class ZKApplicationPackage implements ApplicationPackage { +public class ZKApplicationPackage extends AbstractApplicationPackage { private final ZKApplication zkApplication; diff --git a/configserver/src/test/apps/legacy-flag/schemas/music.sd b/configserver/src/test/apps/legacy-flag/schemas/music.sd new file mode 100644 index 00000000000..f4b11d1e8e4 --- /dev/null +++ b/configserver/src/test/apps/legacy-flag/schemas/music.sd @@ -0,0 +1,50 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +# A basic search definition - called music, should be saved to music.sd +search music { + + # It contains one document type only - called music as well + document music { + + field title type string { + indexing: summary | index # How this field should be indexed + # index-to: title, default # Create two indexes + weight: 75 # Ranking importancy of this field, used by the built in nativeRank feature + } + + field artist type string { + indexing: summary | attribute | index + # index-to: artist, default + + weight: 25 + } + + field year type int { + indexing: summary | attribute + } + + # Increase query + field popularity type int { + indexing: summary | attribute + } + + field url type uri { + indexing: summary | index + } + + } + + rank-profile default inherits default { + first-phase { + expression: nativeRank(title,artist) + attribute(popularity) + } + + } + + rank-profile textmatch inherits default { + first-phase { + expression: nativeRank(title,artist) + } + + } + +} diff --git a/configserver/src/test/apps/legacy-flag/services.xml b/configserver/src/test/apps/legacy-flag/services.xml new file mode 100644 index 00000000000..4c9ac84b4f2 --- /dev/null +++ b/configserver/src/test/apps/legacy-flag/services.xml @@ -0,0 +1,31 @@ +<?xml version="1.0" encoding="utf-8" ?> +<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<services version="1.0"> + + <legacy> + <v7-geo-positions>false</v7-geo-positions> + <foo-bar>true</foo-bar> + </legacy> + + <admin version="2.0"> + <adminserver hostalias="node1"/> + <logserver hostalias="node1" /> + </admin> + + <content version="1.0"> + <redundancy>1</redundancy> + <documents> + <document type="music" mode="index"/> + </documents> + <nodes> + <node hostalias="node1" distribution-key="0"/> + </nodes> + </content> + + <container version="1.0"> + <nodes> + <node hostalias="node1" /> + </nodes> + </container> + +</services> diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/application/LegacyFlagsTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/application/LegacyFlagsTest.java new file mode 100644 index 00000000000..1bebe9089f1 --- /dev/null +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/application/LegacyFlagsTest.java @@ -0,0 +1,68 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.config.server.application; + +import com.yahoo.vespa.config.server.modelfactory.LegacyFlags; +import com.yahoo.cloud.config.ModelConfig; +import com.yahoo.component.Version; +import com.yahoo.config.model.application.provider.FilesApplicationPackage; +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ApplicationName; +import com.yahoo.config.provision.InstanceName; +import com.yahoo.config.provision.TenantName; +import com.yahoo.vespa.config.server.ModelStub; +import com.yahoo.vespa.config.server.ServerCache; +import com.yahoo.vespa.config.server.monitoring.MetricUpdater; +import com.yahoo.vespa.config.server.monitoring.Metrics; +import com.yahoo.vespa.flags.Flags; +import com.yahoo.vespa.flags.FlagSource; +import com.yahoo.vespa.flags.InMemoryFlagSource; +import com.yahoo.vespa.model.VespaModel; +import org.junit.Before; +import org.junit.Test; + +import com.yahoo.config.ConfigInstance; +import com.yahoo.vespa.config.ConfigKey; +import com.yahoo.vespa.config.ConfigDefinitionKey; +import com.yahoo.vespa.config.buildergen.ConfigDefinition; +import com.yahoo.document.config.DocumentmanagerConfig; + +import java.io.File; +import java.nio.charset.StandardCharsets; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * @author arnej + */ +public class LegacyFlagsTest { + + @Test + public void testThatLegacyOverridesWork() throws Exception { + File testApp = new File("src/test/apps/legacy-flag"); + var appPkg = FilesApplicationPackage.fromFile(testApp); + var flag = Flags.USE_V8_GEO_POSITIONS.bindTo(LegacyFlags.from(appPkg, new InMemoryFlagSource())); + assertTrue(flag.value()); + /* rest here tests that having a "legacy" XML tag doesn't break other things, but without actually using it: */ + VespaModel model = new VespaModel(appPkg); + ApplicationId applicationId = new ApplicationId.Builder().tenant("foo").applicationName("foo").build(); + ServerCache cache = new ServerCache(); + Application app = new Application(model, cache, 1L, new Version(1, 2, 3), + new MetricUpdater(Metrics.createTestMetrics(), Metrics.createDimensions(applicationId)), applicationId); + assertNotNull(app.getModel()); + /* + // Note: no feature flags active with this code path + ConfigDefinitionKey cdk = new ConfigDefinitionKey(DocumentmanagerConfig.CONFIG_DEF_NAME, DocumentmanagerConfig.CONFIG_DEF_NAMESPACE); + ConfigDefinition cdd = new ConfigDefinition(cdk.getName(), DocumentmanagerConfig.CONFIG_DEF_SCHEMA); + var cfgBuilder = app.getModel().getConfigInstance(new ConfigKey<>(cdk.getName(), "client", cdk.getNamespace()), cdd); + assertTrue(cfgBuilder instanceof DocumentmanagerConfig.Builder); + var cfg = ((DocumentmanagerConfig.Builder)cfgBuilder).build(); + // no effect from legacy override seen here: + System.out.println("CFG: "+cfg); + */ + } + +} diff --git a/container-messagebus/src/main/java/com/yahoo/messagebus/jdisc/MbusClient.java b/container-messagebus/src/main/java/com/yahoo/messagebus/jdisc/MbusClient.java index 7fa294422d6..7571cfb0969 100644 --- a/container-messagebus/src/main/java/com/yahoo/messagebus/jdisc/MbusClient.java +++ b/container-messagebus/src/main/java/com/yahoo/messagebus/jdisc/MbusClient.java @@ -43,6 +43,7 @@ public final class MbusClient extends AbstractResource implements ClientProvider this.session = session; this.sessionReference = session.refer(this); thread = new Thread(new SenderTask(), "mbus-client-" + threadId.getAndIncrement()); + thread.setDaemon(true); } @Override @@ -79,6 +80,11 @@ public final class MbusClient extends AbstractResource implements ClientProvider log.log(Level.FINE, "Destroying message bus client."); sessionReference.close(); done = true; + try { + thread.join(); + } catch (InterruptedException e) { + log.log(Level.WARNING, "Interrupted while joining thread on destroy.", e); + } } @Override @@ -121,7 +127,7 @@ public final class MbusClient extends AbstractResource implements ClientProvider if (error == null) { return true; } - if (error.isFatal()) { + if (error.isFatal() || done) { final Reply reply = new EmptyReply(); reply.swapState(request.getMessage()); reply.addError(error); diff --git a/container-messagebus/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java b/container-messagebus/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java index e7b1fc3e71d..084a5f82268 100644 --- a/container-messagebus/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java +++ b/container-messagebus/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java @@ -75,8 +75,8 @@ public final class MbusServer extends AbstractResource implements ServerProvider return; } if (state == State.STOPPED) { - // We might need to detect requests originating from the same JVM, as they nede to fail fast - // As they are holding references to the container preventing proper shutdown. + // We might need to detect requests originating from the same JVM, as they need to fail fast + // as they are holding references to the container preventing proper shutdown. dispatchErrorReply(msg, ErrorCode.SESSION_BUSY, "MBusServer has been closed."); return; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java index 2480c81755e..1601af2612b 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java @@ -396,22 +396,23 @@ public class JobController { }); logs.flush(id); metric.jobFinished(run.id().job(), finishedRun.status()); + + DeploymentId deploymentId = new DeploymentId(unlockedRun.id().application(), unlockedRun.id().job().type().zone(controller.system())); + (unlockedRun.versions().targetApplication().isDeployedDirectly() ? + Stream.of(unlockedRun.id().type()) : + JobType.allIn(controller.system()).stream().filter(jobType -> !jobType.environment().isManuallyDeployed())) + .flatMap(jobType -> controller.jobController().runs(unlockedRun.id().application(), jobType).values().stream()) + .mapToLong(r -> r.versions().targetApplication().buildNumber().orElse(Integer.MAX_VALUE)) + .min() + .ifPresent(oldestBuild -> { + if (unlockedRun.versions().targetApplication().isDeployedDirectly()) + controller.applications().applicationStore().pruneDevDiffs(deploymentId, oldestBuild); + else + controller.applications().applicationStore().pruneDiffs(deploymentId.applicationId().tenant(), deploymentId.applicationId().application(), oldestBuild); + }); + return finishedRun; }); - - DeploymentId deploymentId = new DeploymentId(unlockedRun.id().application(), unlockedRun.id().job().type().zone(controller.system())); - (unlockedRun.versions().targetApplication().isDeployedDirectly() ? - Stream.of(unlockedRun.id().type()) : - JobType.allIn(controller.system()).stream().filter(jobType -> !jobType.environment().isManuallyDeployed())) - .flatMap(jobType -> controller.jobController().runs(unlockedRun.id().application(), jobType).values().stream()) - .mapToLong(run -> run.versions().targetApplication().buildNumber().orElse(Integer.MAX_VALUE)) - .min() - .ifPresent(oldestBuild -> { - if (unlockedRun.versions().targetApplication().isDeployedDirectly()) - controller.applications().applicationStore().pruneDevDiffs(deploymentId, oldestBuild); - else - controller.applications().applicationStore().pruneDiffs(deploymentId.applicationId().tenant(), deploymentId.applicationId().application(), oldestBuild); - }); } finally { for (Lock lock : locks) diff --git a/document/pom.xml b/document/pom.xml index 1bfb18767eb..3faada08553 100644 --- a/document/pom.xml +++ b/document/pom.xml @@ -136,7 +136,6 @@ <arg>-Xlint:-serial</arg> <arg>-Xlint:-rawtypes</arg> <arg>-Xlint:-unchecked</arg> - <arg>-Xlint:-cast</arg> <arg>-Werror</arg> </compilerArgs> </configuration> diff --git a/fsa/pom.xml b/fsa/pom.xml index 25e1167a240..db863ba5522 100644 --- a/fsa/pom.xml +++ b/fsa/pom.xml @@ -44,16 +44,6 @@ <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-compiler-plugin</artifactId> - <configuration> - <compilerArgs> - <arg>-Xlint:all</arg> - <arg>-Xlint:-fallthrough</arg> - <arg>-Xlint:-serial</arg> - <arg>-Xlint:-rawtypes</arg> - <arg>-Xlint:-unchecked</arg> - <arg>-Werror</arg> - </compilerArgs> - </configuration> </plugin> <plugin> <groupId>com.yahoo.vespa</groupId> diff --git a/fsa/src/main/java/com/yahoo/fsa/FSA.java b/fsa/src/main/java/com/yahoo/fsa/FSA.java index a964b32c54a..fcc940a335c 100644 --- a/fsa/src/main/java/com/yahoo/fsa/FSA.java +++ b/fsa/src/main/java/com/yahoo/fsa/FSA.java @@ -188,10 +188,10 @@ public class FSA implements Closeable { */ public Item(FSA fsa, int state) { this.fsa = fsa; - this.string = new java.util.Stack(); + this.string = new java.util.Stack<>(); this.symbol = 0; this.state = state; - this.stack = new java.util.Stack(); + this.stack = new java.util.Stack<>(); } /** @@ -199,7 +199,7 @@ public class FSA implements Closeable { */ public Item(Item item) { this.fsa = item.fsa; - this.string = new java.util.Stack(); + this.string = new java.util.Stack<>(); for (java.util.Iterator<Byte> itr = item.string.iterator(); itr.hasNext(); ) { byte b = itr.next(); this.string.push(b); @@ -415,7 +415,7 @@ public class FSA implements Closeable { if ((mmap == null) || !mmap.isDirect()) return; try { - Class unsafeClass; + Class<?> unsafeClass; try { unsafeClass = Class.forName("sun.misc.Unsafe"); } catch (Exception ex) { @@ -468,8 +468,8 @@ public class FSA implements Closeable { * @return the loaded FSA * @throws RuntimeException if the class could not be loaded */ - public static FSA loadFromResource(String resourceFileName,Class loadingClass) { - URL fsaUrl=loadingClass.getResource(resourceFileName); + public static FSA loadFromResource(String resourceFileName, Class<?> loadingClass) { + URL fsaUrl = loadingClass.getResource(resourceFileName); if ( ! "file".equals(fsaUrl.getProtocol())) { throw new RuntimeException("Could not open non-file url '" + fsaUrl + "' as a file input stream: " + "The classloader of " + loadingClass + "' does not return file urls"); diff --git a/fsa/src/main/java/com/yahoo/fsa/segmenter/Segmenter.java b/fsa/src/main/java/com/yahoo/fsa/segmenter/Segmenter.java index 7c3e76996bb..4edac362131 100644 --- a/fsa/src/main/java/com/yahoo/fsa/segmenter/Segmenter.java +++ b/fsa/src/main/java/com/yahoo/fsa/segmenter/Segmenter.java @@ -60,7 +60,7 @@ public class Segmenter { public Segments segment(String[] tokens) { Segments segments = new Segments(tokens); - LinkedList detectors = new LinkedList(); + LinkedList<Detector> detectors = new LinkedList<>(); int i=0; @@ -68,9 +68,9 @@ public class Segmenter { while(i<tokens.length){ detectors.add(new Detector(fsa.getState(), i)); - ListIterator det_it = detectors.listIterator(); + ListIterator<Detector> det_it = detectors.listIterator(); while(det_it.hasNext()){ - Detector d = (Detector)det_it.next(); + Detector d = det_it.next(); d.state().deltaWord(tokens[i]); if(d.state().isFinal()){ segments.add(new Segment(d.index(),i+1,d.state().data().getInt(0))); diff --git a/fsa/src/main/java/com/yahoo/fsa/segmenter/Segments.java b/fsa/src/main/java/com/yahoo/fsa/segmenter/Segments.java index 89368e2bf8f..e3bfe956a5c 100644 --- a/fsa/src/main/java/com/yahoo/fsa/segmenter/Segments.java +++ b/fsa/src/main/java/com/yahoo/fsa/segmenter/Segments.java @@ -8,7 +8,7 @@ import java.util.LinkedList; * * @author Peter Boros */ -public class Segments extends LinkedList { +public class Segments extends LinkedList<Segment> { public final static int SEGMENTATION_WEIGHTED = 0; public final static int SEGMENTATION_WEIGHTED_BIAS10 = 1; @@ -43,10 +43,12 @@ public class Segments extends LinkedList { } } - public void add(Segment s) + @Override + public boolean add(Segment s) { - super.add(s); + var result = super.add(s); _map[s.beg()][s.end()]=super.size()-1; + return result; } private void addMissingSingles() @@ -76,8 +78,8 @@ public class Segments extends LinkedList { if(idx<0 || idx>=super.size()){ return null; } - String s = new String(_tokens[((Segment)(super.get(idx))).beg()]); - for(int i=((Segment)(super.get(idx))).beg()+1;i<((Segment)(super.get(idx))).end();i++){ + String s = new String(_tokens[super.get(idx).beg()]); + for(int i = super.get(idx).beg() + 1; i < super.get(idx).end(); i++){ s += " " + _tokens[i]; } return s; @@ -88,7 +90,7 @@ public class Segments extends LinkedList { if(idx<0 || idx>=super.size()){ return -1; } - return ((Segment)(super.get(idx))).beg(); + return super.get(idx).beg(); } public int end(int idx) @@ -96,7 +98,7 @@ public class Segments extends LinkedList { if(idx<0 || idx>=super.size()){ return -1; } - return ((Segment)(super.get(idx))).end(); + return super.get(idx).end(); } public int len(int idx) @@ -104,7 +106,7 @@ public class Segments extends LinkedList { if(idx<0 || idx>=super.size()){ return -1; } - return ((Segment)(super.get(idx))).len(); + return super.get(idx).len(); } public int conn(int idx) @@ -112,9 +114,10 @@ public class Segments extends LinkedList { if(idx<0 || idx>=super.size()){ return -1; } - return ((Segment)(super.get(idx))).conn(); + return super.get(idx).conn(); } + @SuppressWarnings("fallthrough") public Segments segmentation(int method) { Segments smnt = new Segments(_tokens); @@ -170,7 +173,7 @@ public class Segments extends LinkedList { } id = bestid; while(id!=-1){ - smnt.add(((Segment)(super.get(id)))); + smnt.add(super.get(id)); id=nextid[id]; } break; @@ -189,7 +192,7 @@ public class Segments extends LinkedList { next = i; } } - smnt.add((Segment)(super.get(bestid))); + smnt.add(super.get(bestid)); pos=next; } break; @@ -302,7 +305,7 @@ public class Segments extends LinkedList { } // add segment - smnt.add((Segment)(super.get(bestid))); + smnt.add(super.get(bestid)); // check right side if(e>end(bestid)){ diff --git a/fsa/src/main/java/com/yahoo/fsa/topicpredictor/TopicPredictor.java b/fsa/src/main/java/com/yahoo/fsa/topicpredictor/TopicPredictor.java index 7049ad5495d..52dae951165 100644 --- a/fsa/src/main/java/com/yahoo/fsa/topicpredictor/TopicPredictor.java +++ b/fsa/src/main/java/com/yahoo/fsa/topicpredictor/TopicPredictor.java @@ -59,7 +59,7 @@ public class TopicPredictor extends MetaData { * as opposed to the two-argument version. * @param segment The segment string to find (all) topics for. * @return (Linked)List of PredictedTopic objects. */ - public List getPredictedTopics(String segment) { + public List<PredictedTopic> getPredictedTopics(String segment) { return getPredictedTopics(segment, 0); } @@ -70,8 +70,8 @@ public class TopicPredictor extends MetaData { * @param segment The segment string to find topics for. * @param maxTopics The max number of topics to return, 0 for all topics * @return (Linked)List of PredictedTopic objects. */ - public List getPredictedTopics(String segment, int maxTopics) { - List predictedTopics = new LinkedList(); + public List<PredictedTopic> getPredictedTopics(String segment, int maxTopics) { + List<PredictedTopic> predictedTopics = new LinkedList<>(); int segIdx = getSegmentIndex(segment); int[][] topicArr = getTopicArray(segIdx, maxTopics); diff --git a/fsa/src/test/java/com/yahoo/fsa/test/FSADataTestCase.java b/fsa/src/test/java/com/yahoo/fsa/test/FSADataTestCase.java index 3e9efc68558..28faaea1373 100644 --- a/fsa/src/test/java/com/yahoo/fsa/test/FSADataTestCase.java +++ b/fsa/src/test/java/com/yahoo/fsa/test/FSADataTestCase.java @@ -34,6 +34,7 @@ public class FSADataTestCase { this.numExceptions = 0; this.numAsserts = 0; } + @Override public void run() { for (long i = 0; i < numRuns; ++i) { state.start(); diff --git a/fsa/src/test/java/com/yahoo/fsa/test/FSAIteratorTestCase.java b/fsa/src/test/java/com/yahoo/fsa/test/FSAIteratorTestCase.java index 645536e596b..e99998e16f2 100644 --- a/fsa/src/test/java/com/yahoo/fsa/test/FSAIteratorTestCase.java +++ b/fsa/src/test/java/com/yahoo/fsa/test/FSAIteratorTestCase.java @@ -94,7 +94,7 @@ public class FSAIteratorTestCase { @Test public void testIteratorEmpty1() { state.delta("b"); - java.util.Iterator i = fsa.iterator(state); + FSA.Iterator i = fsa.iterator(state); assertFalse(i.hasNext()); try { i.next(); @@ -107,7 +107,7 @@ public class FSAIteratorTestCase { @Test public void testIteratorEmpty2() { state.delta("daciac"); - java.util.Iterator i = fsa.iterator(state); + FSA.Iterator i = fsa.iterator(state); assertFalse(i.hasNext()); try { i.next(); @@ -119,7 +119,7 @@ public class FSAIteratorTestCase { @Test public void testIteratorRemove() { - java.util.Iterator i = fsa.iterator(state); + FSA.Iterator i = fsa.iterator(state); try { i.remove(); assertFalse(true); diff --git a/hosted-api/src/main/java/ai/vespa/hosted/api/ControllerHttpClient.java b/hosted-api/src/main/java/ai/vespa/hosted/api/ControllerHttpClient.java index 4a79857955a..1352220166c 100644 --- a/hosted-api/src/main/java/ai/vespa/hosted/api/ControllerHttpClient.java +++ b/hosted-api/src/main/java/ai/vespa/hosted/api/ControllerHttpClient.java @@ -40,6 +40,7 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.OptionalLong; import java.util.Set; import java.util.concurrent.Callable; @@ -480,10 +481,11 @@ public abstract class ControllerHttpClient { // Note: Much more data in response, only the interesting parts of response are included in InstanceInfo for now private static InstanceInfo toInstanceInfo(HttpResponse<byte[]> response, ApplicationId applicationId) { - Set<ZoneId> zones = new HashSet<>(); + List<ZoneDeployment> zones = new ArrayList<>(); toInspector(response).field("instances").traverse((ArrayTraverser) (___, entryObject) -> - zones.add(ZoneId.from(entryObject.field("environment").asString(), - entryObject.field("region").asString()))); + zones.add(new ZoneDeployment(ZoneId.from(entryObject.field("environment").asString(), + entryObject.field("region").asString()), + entryObject.field("url").valid() ? Optional.of(entryObject.field("url").asString()) : Optional.empty()))); return new InstanceInfo(applicationId, zones); } @@ -561,21 +563,33 @@ public abstract class ControllerHttpClient { public static class InstanceInfo { private final ApplicationId applicationId; - private final Set<ZoneId> zones; + private final List<ZoneDeployment> zones; - InstanceInfo(ApplicationId applicationId, Set<ZoneId> zones) { + InstanceInfo(ApplicationId applicationId, List<ZoneDeployment> zones) { this.applicationId = applicationId; this.zones = zones; } - public ApplicationId applicationId() { - return applicationId; - } + public ApplicationId applicationId() { return applicationId; } + + public List<ZoneDeployment> zones() { return zones; } + + } - public Set<ZoneId> zones() { - return zones; + public static class ZoneDeployment { + + private final ZoneId zone; + private final Optional<String> uri; + + public ZoneDeployment(ZoneId zone, Optional<String> uri) { + this.zone = zone; + this.uri = uri; } + public ZoneId zone() { return zone; } + + public boolean isDeployed() { return uri.isPresent(); } + } } diff --git a/messagebus/src/main/java/com/yahoo/messagebus/MessageBus.java b/messagebus/src/main/java/com/yahoo/messagebus/MessageBus.java index 085978375a6..8611801b9a9 100644 --- a/messagebus/src/main/java/com/yahoo/messagebus/MessageBus.java +++ b/messagebus/src/main/java/com/yahoo/messagebus/MessageBus.java @@ -59,7 +59,7 @@ import java.util.logging.Logger; */ public class MessageBus implements ConfigHandler, NetworkOwner, MessageHandler, ReplyHandler { - private static Logger log = Logger.getLogger(MessageBus.class.getName()); + private final static Logger log = Logger.getLogger(MessageBus.class.getName()); private final AtomicBoolean destroyed = new AtomicBoolean(false); private final ProtocolRepository protocolRepository = new ProtocolRepository(); private final AtomicReference<Map<String, RoutingTable>> tablesRef = new AtomicReference<>(null); diff --git a/model-integration/pom.xml b/model-integration/pom.xml index 47ad6375491..41139394690 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -71,17 +71,6 @@ <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-compiler-plugin</artifactId> - <configuration> - <compilerArgs> - <arg>-Xlint:all</arg> - <arg>-Xlint:-rawtypes</arg> - <arg>-Xlint:-unchecked</arg> - <arg>-Xlint:-serial</arg> - <arg>-Xlint:-cast</arg> - <arg>-Xlint:-overloads</arg> - <arg>-Werror</arg> - </compilerArgs> - </configuration> </plugin> <plugin> <groupId>com.github.os72</groupId> diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index b4b21d388b5..5627327d429 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -100,7 +101,7 @@ public abstract class ModelImporter implements MlModelImporter { for (ImportedModel.Signature signature : model.signatures().values()) { for (String outputName : signature.outputs().values()) { try { - Optional<TensorFunction> function = importExpression(graph.get(outputName), model); + Optional<TensorFunction<Reference>> function = importExpression(graph.get(outputName), model); if (function.isEmpty()) { signature.skippedOutput(outputName, "No valid output function could be found."); } @@ -112,7 +113,7 @@ public abstract class ModelImporter implements MlModelImporter { } } - private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) { + private static Optional<TensorFunction<Reference>> importExpression(IntermediateOperation operation, ImportedModel model) { if (model.expressions().containsKey(operation.name())) { return operation.function(); } @@ -134,7 +135,7 @@ public abstract class ModelImporter implements MlModelImporter { operation.inputs().forEach(input -> importExpression(input, model)); } - private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) { + private static Optional<TensorFunction<Reference>> importConstant(IntermediateOperation operation, ImportedModel model) { String name = operation.vespaName(); if (model.hasLargeConstant(name) || model.hasSmallConstant(name)) { return operation.function(); @@ -160,7 +161,7 @@ public abstract class ModelImporter implements MlModelImporter { if (operation.function().isPresent()) { String name = operation.name(); if ( ! model.expressions().containsKey(name)) { - TensorFunction function = operation.function().get(); + TensorFunction<Reference> function = operation.function().get(); if (isSignatureOutput(model, operation)) { OrderedTensorType operationType = operation.type().get(); @@ -168,7 +169,7 @@ public abstract class ModelImporter implements MlModelImporter { if ( ! operationType.equals(standardNamingType)) { List<String> renameFrom = operationType.dimensionNames(); List<String> renameTo = standardNamingType.dimensionNames(); - function = new Rename(function, renameFrom, renameTo); + function = new Rename<Reference>(function, renameFrom, renameTo); } } @@ -196,7 +197,7 @@ public abstract class ModelImporter implements MlModelImporter { private static void importFunctionExpression(IntermediateOperation operation, ImportedModel model) { if (operation.rankingExpressionFunction().isPresent()) { - TensorFunction function = operation.rankingExpressionFunction().get(); + TensorFunction<Reference> function = operation.rankingExpressionFunction().get(); try { model.function(operation.rankingExpressionFunctionName(), new RankingExpression(operation.rankingExpressionFunctionName(), function.toString())); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java index 37f5ae9dd29..b77960ff3fb 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.TensorTypeParser; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java index e58b5341e6b..bda2f16f9e2 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer.operations; +import com.yahoo.searchlib.rankingexpression.Reference; import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.DimensionRenamer; import com.yahoo.tensor.evaluation.VariableTensor; @@ -26,12 +27,12 @@ public class Argument extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { - TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type()); + protected TensorFunction<Reference> lazyGetFunction() { + TensorFunction<Reference> output = new VariableTensor<Reference>(vespaName(), standardNamingType.type()); if ( ! standardNamingType.equals(type)) { List<String> renameFrom = standardNamingType.dimensionNames(); List<String> renameTo = type.dimensionNames(); - output = new Rename(output, renameFrom, renameTo); + output = new Rename<Reference>(output, renameFrom, renameTo); } return output; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java index bf10eb2457b..9484545c9c1 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.TensorFunction; @@ -25,15 +26,15 @@ public class ConcatReduce extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(inputs.size())) return null; - TensorFunction result = inputs.get(0).function().get(); + TensorFunction<Reference> result = inputs.get(0).function().get(); for (int i = 1; i < inputs.size(); ++i) { - TensorFunction b = inputs.get(i).function().get(); - result = new com.yahoo.tensor.functions.Concat(result, b, tmpDimensionName); + TensorFunction<Reference> b = inputs.get(i).function().get(); + result = new com.yahoo.tensor.functions.Concat<>(result, b, tmpDimensionName); } - return new com.yahoo.tensor.functions.Reduce(result, aggregator, tmpDimensionName); + return new com.yahoo.tensor.functions.Reduce<>(result, aggregator, tmpDimensionName); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java index 9f3b15cddbd..6cb810aff94 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; @@ -68,14 +69,14 @@ public class ConcatV2 extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) { return null; } - TensorFunction result = inputs.get(0).function().get(); + TensorFunction<Reference> result = inputs.get(0).function().get(); for (int i = 1; i < inputs.size() - 1; ++i) { - TensorFunction b = inputs.get(i).function().get(); - result = new com.yahoo.tensor.functions.Concat(result, b, concatDimensionName); + TensorFunction<Reference> b = inputs.get(i).function().get(); + result = new com.yahoo.tensor.functions.Concat<>(result, b, concatDimensionName); } return result; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java index 859702dec40..d68b632bf61 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java @@ -35,7 +35,7 @@ public class Const extends IntermediateOperation { } @Override - public Optional<TensorFunction> function() { + public Optional<TensorFunction<Reference>> function() { if (function == null) { function = lazyGetFunction(); } @@ -43,7 +43,7 @@ public class Const extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { ExpressionNode expressionNode; if (type.type().rank() == 0 && getConstantValue().isPresent()) { expressionNode = new ConstantNode(getConstantValue().get().asDoubleValue()); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java index a381b2cb8a0..cdc408b3e70 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.functions.TensorFunction; @@ -23,7 +24,7 @@ public class Constant extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { return null; // will be added by function() since this is constant. } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java index c48e5592a56..d88fc34725e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; @@ -60,10 +61,10 @@ public class ConstantOfShape extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputTypesPresent(1)) return null; ExpressionNode valueExpr = new ConstantNode(new DoubleValue(valueToFillWith)); - TensorFunction function = Generate.bound(type.type(), wrapScalar(valueExpr)); + TensorFunction<Reference> function = Generate.bound(type.type(), wrapScalar(valueExpr)); return function; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java index eda188b339f..6d57adbd888 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java @@ -74,7 +74,7 @@ public class Expand extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!allInputFunctionsPresent(2)) return null; IntermediateOperation input = inputs.get(0); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index 027532cd02d..83132b0669c 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -4,6 +4,7 @@ package ai.vespa.rankingexpression.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; @@ -65,7 +66,7 @@ public class ExpandDims extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(2)) return null; // multiply with a generated tensor created from the reduced dimensions @@ -75,9 +76,9 @@ public class ExpandDims extends IntermediateOperation { } TensorType generatedType = typeBuilder.build(); ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); - Generate generatedFunction = new Generate(generatedType, + Generate<Reference> generatedFunction = new Generate<>(generatedType, new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); - return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply()); + return new com.yahoo.tensor.functions.Join<>(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply()); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java index bab9c47ca9a..cd0c4da6d0f 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java @@ -71,7 +71,7 @@ public class Gather extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(2)) return null; IntermediateOperation data = inputs.get(0); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java index 4b3208fdeb0..1f447f2a575 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; @@ -78,7 +79,7 @@ public class Gemm extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! check2or3InputsPresent()) return null; OrderedTensorType aType = inputs.get(0).type().get(); @@ -86,29 +87,29 @@ public class Gemm extends IntermediateOperation { if (aType.type().rank() != 2 || bType.type().rank() != 2) throw new IllegalArgumentException("Tensors in Gemm must have rank of exactly 2"); - Optional<TensorFunction> aFunction = inputs.get(0).function(); - Optional<TensorFunction> bFunction = inputs.get(1).function(); + Optional<TensorFunction<Reference>> aFunction = inputs.get(0).function(); + Optional<TensorFunction<Reference>> bFunction = inputs.get(1).function(); if (aFunction.isEmpty() || bFunction.isEmpty()) { return null; } String joinDimension = aType.dimensions().get(1 - transposeA).name(); - TensorFunction AxB = new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), joinDimension); - TensorFunction alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( + TensorFunction<Reference> AxB = new com.yahoo.tensor.functions.Matmul<>(aFunction.get(), bFunction.get(), joinDimension); + TensorFunction<Reference> alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( new ArithmeticNode( new TensorFunctionNode(AxB), ArithmeticOperator.MULTIPLY, new ConstantNode(new DoubleValue(alpha)))); if (inputs.size() == 3) { - Optional<TensorFunction> cFunction = inputs.get(2).function(); - TensorFunction betaxC = new TensorFunctionNode.ExpressionTensorFunction( + Optional<TensorFunction<Reference>> cFunction = inputs.get(2).function(); + TensorFunction<Reference> betaxC = new TensorFunctionNode.ExpressionTensorFunction( new ArithmeticNode( new TensorFunctionNode(cFunction.get()), ArithmeticOperator.MULTIPLY, new ConstantNode(new DoubleValue(beta)))); - return new com.yahoo.tensor.functions.Join(alphaxAxB, betaxC, ScalarFunctions.add()); + return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add()); } return alphaxAxB; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java index f096cb1e54f..ab840e708a7 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; @@ -20,7 +21,7 @@ public class Identity extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!allInputFunctionsPresent(1)) return null; return inputs.get(0).function().orElse(null); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 6ebb478715a..6378442c6d0 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -45,8 +45,8 @@ public abstract class IntermediateOperation { protected final List<IntermediateOperation> outputs = new ArrayList<>(); protected OrderedTensorType type; - protected TensorFunction function; - protected TensorFunction rankingExpressionFunction = null; + protected TensorFunction<Reference> function; + protected TensorFunction<Reference> rankingExpressionFunction = null; protected boolean exportAsRankingFunction = false; private boolean hasRenamedDimensions = false; @@ -65,7 +65,7 @@ public abstract class IntermediateOperation { } protected abstract OrderedTensorType lazyGetType(); - protected abstract TensorFunction lazyGetFunction(); + protected abstract TensorFunction<Reference> lazyGetFunction(); public String modelName() { return modelName; } @@ -78,14 +78,14 @@ public abstract class IntermediateOperation { } /** Returns the Vespa tensor function implementing all operations from this node with inputs */ - public Optional<TensorFunction> function() { + public Optional<TensorFunction<Reference>> function() { if (function == null) { if (isConstant()) { ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName())); function = new TensorFunctionNode.ExpressionTensorFunction(constant); } else if (outputs.size() > 1 || exportAsRankingFunction) { rankingExpressionFunction = lazyGetFunction(); - function = new VariableTensor(rankingExpressionFunctionName(), type.type()); + function = new VariableTensor<Reference>(rankingExpressionFunctionName(), type.type()); } else { function = lazyGetFunction(); } @@ -103,7 +103,7 @@ public abstract class IntermediateOperation { public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); } /** Returns a function that should be added as a ranking expression function */ - public Optional<TensorFunction> rankingExpressionFunction() { + public Optional<TensorFunction<Reference>> rankingExpressionFunction() { return Optional.ofNullable(rankingExpressionFunction); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index 92b5f2e743b..667641dc33a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.ScalarFunctions; @@ -53,7 +54,7 @@ public class Join extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputTypesPresent(2)) return null; if ( ! allInputFunctionsPresent(2)) return null; @@ -63,7 +64,7 @@ public class Join extends IntermediateOperation { if (mapOperator.isPresent()) { IntermediateOperation input = inputs.get(0); input.removeDuplicateOutputsTo(this); // avoids unnecessary function export - return new com.yahoo.tensor.functions.Map(input.function().get(), mapOperator.get()); + return new com.yahoo.tensor.functions.Map<Reference>(input.function().get(), mapOperator.get()); } } @@ -86,23 +87,23 @@ public class Join extends IntermediateOperation { } } - TensorFunction aReducedFunction = a.function().get(); + TensorFunction<Reference> aReducedFunction = a.function().get(); if (aDimensionsToReduce.size() > 0) { - aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce); + aReducedFunction = new Reduce<Reference>(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce); } - TensorFunction bReducedFunction = b.function().get(); + TensorFunction<Reference> bReducedFunction = b.function().get(); if (bDimensionsToReduce.size() > 0) { - bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce); + bReducedFunction = new Reduce<Reference>(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce); } // retain order of inputs if (a == inputs.get(1)) { - TensorFunction temp = bReducedFunction; + TensorFunction<Reference> temp = bReducedFunction; bReducedFunction = aReducedFunction; aReducedFunction = temp; } - return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator); + return new com.yahoo.tensor.functions.Join<Reference>(aReducedFunction, bReducedFunction, operator); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java index 1fd0f72f416..c9b03ba9b85 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; @@ -26,12 +27,12 @@ public class Map extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!allInputFunctionsPresent(1)) { return null; } - Optional<TensorFunction> input = inputs.get(0).function(); - return new com.yahoo.tensor.functions.Map(input.get(), operator); + Optional<TensorFunction<Reference>> input = inputs.get(0).function(); + return new com.yahoo.tensor.functions.Map<>(input.get(), operator); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 673df9be36b..7d64a023e27 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; @@ -58,24 +59,24 @@ public class MatMul extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputTypesPresent(2)) return null; if ( ! allInputFunctionsPresent(2)) return null; OrderedTensorType typeA = inputs.get(0).type().get(); OrderedTensorType typeB = inputs.get(1).type().get(); - TensorFunction functionA = handleBroadcasting(inputs.get(0).function().get(), typeA, typeB); - TensorFunction functionB = handleBroadcasting(inputs.get(1).function().get(), typeB, typeA); + TensorFunction<Reference> functionA = handleBroadcasting(inputs.get(0).function().get(), typeA, typeB); + TensorFunction<Reference> functionB = handleBroadcasting(inputs.get(1).function().get(), typeB, typeA); - return new com.yahoo.tensor.functions.Reduce( - new Join(functionA, functionB, ScalarFunctions.multiply()), + return new com.yahoo.tensor.functions.Reduce<Reference>( + new Join<Reference>(functionA, functionB, ScalarFunctions.multiply()), Reduce.Aggregator.sum, typeA.dimensions().get(typeA.rank() - 1).name()); } - private TensorFunction handleBroadcasting(TensorFunction tensorFunction, OrderedTensorType typeA, OrderedTensorType typeB) { - List<Slice.DimensionValue> slices = new ArrayList<>(); + private TensorFunction<Reference> handleBroadcasting(TensorFunction<Reference> tensorFunction, OrderedTensorType typeA, OrderedTensorType typeB) { + List<Slice.DimensionValue<Reference>> slices = new ArrayList<>(); for (int i = 0; i < typeA.rank() - 2; ++i) { long dimSizeA = typeA.dimensions().get(i).size().get(); String dimNameA = typeA.dimensionNames().get(i); @@ -84,11 +85,11 @@ public class MatMul extends IntermediateOperation { long dimSizeB = typeB.dimensions().get(j).size().get(); if (dimSizeB > dimSizeA && dimSizeA == 1) { ExpressionNode dimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero)); - slices.add(new Slice.DimensionValue(Optional.of(dimNameA), wrapScalar(dimensionExpression))); + slices.add(new Slice.DimensionValue<>(Optional.of(dimNameA), wrapScalar(dimensionExpression))); } } } - return slices.size() == 0 ? tensorFunction : new Slice(tensorFunction, slices); + return slices.size() == 0 ? tensorFunction : new Slice<>(tensorFunction, slices); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java index a4a47ca8ce7..fd262b2892c 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import ai.vespa.rankingexpression.importer.DimensionRenamer; @@ -56,12 +57,12 @@ public class Mean extends IntermediateOperation { // optimization: if keepDims and one reduce dimension that has size 1: same as identity. @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputTypesPresent(2)) return null; - TensorFunction inputFunction = inputs.get(0).function().get(); - TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions); + TensorFunction<Reference> inputFunction = inputs.get(0).function().get(); + TensorFunction<Reference> output = new Reduce<>(inputFunction, Reduce.Aggregator.avg, reduceDimensions); if (shouldKeepDimensions()) { // multiply with a generated tensor created from the reduced dimensions TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); @@ -70,9 +71,9 @@ public class Mean extends IntermediateOperation { } TensorType generatedType = typeBuilder.build(); ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); - Generate generatedFunction = new Generate(generatedType, + Generate<Reference> generatedFunction = new Generate<>(generatedType, new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); - output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); + output = new com.yahoo.tensor.functions.Join<>(output, generatedFunction, ScalarFunctions.multiply()); } return output; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java index f208cc97d4f..e2b5930f114 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; @@ -23,7 +24,7 @@ public class Merge extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { for (IntermediateOperation operation : inputs) { if (operation.function().isPresent()) { return operation.function().get(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java index 1d76fa3f0a7..d8055d548ad 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.TensorFunction; import java.util.Collections; @@ -19,7 +20,7 @@ public class NoOp extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { return null; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java index 7b0547be7d2..164e3dc5e11 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.TensorFunction; import onnx.Onnx.TensorProto.DataType; @@ -30,13 +31,13 @@ public class OnnxCast extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(1)) return null; - TensorFunction input = inputs.get(0).function().get(); + TensorFunction<Reference> input = inputs.get(0).function().get(); switch (toType) { case BOOL: - return new com.yahoo.tensor.functions.Map(input, new AsBool()); + return new com.yahoo.tensor.functions.Map<>(input, new AsBool()); case INT8: case INT16: case INT32: @@ -45,7 +46,7 @@ public class OnnxCast extends IntermediateOperation { case UINT16: case UINT32: case UINT64: - return new com.yahoo.tensor.functions.Map(input, new AsInt()); + return new com.yahoo.tensor.functions.Map<>(input, new AsInt()); case FLOAT: case DOUBLE: case FLOAT16: diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java index 2be8fc0dc4e..97818f4c27d 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; @@ -65,14 +66,14 @@ public class OnnxConcat extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) { return null; } - TensorFunction result = inputs.get(0).function().get(); + TensorFunction<Reference> result = inputs.get(0).function().get(); for (int i = 1; i < inputs.size(); ++i) { - TensorFunction b = inputs.get(i).function().get(); - result = new com.yahoo.tensor.functions.Concat(result, b, concatDimensionName); + TensorFunction<Reference> b = inputs.get(i).function().get(); + result = new com.yahoo.tensor.functions.Concat<>(result, b, concatDimensionName); } return result; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java index 79123cb0380..675e18da637 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -36,7 +37,7 @@ public class OnnxConstant extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { return null; // will be added by function() since this is constant. } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java index 3456a24f5dd..c0f825f9092 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; @@ -22,7 +23,7 @@ public class PlaceholderWithDefault extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!allInputFunctionsPresent(1)) { return null; } @@ -32,7 +33,7 @@ public class PlaceholderWithDefault extends IntermediateOperation { } @Override - public Optional<TensorFunction> rankingExpressionFunction() { + public Optional<TensorFunction<Reference>> rankingExpressionFunction() { // For now, it is much more efficient to assume we always will return // the default value, as we can prune away large parts of the expression // tree by having it calculated as a constant. If a case arises where diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java index 81a9e4996b4..5c4e8cd6cd0 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; @@ -58,7 +59,7 @@ public class Range extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputTypesPresent(3)) return null; String dimensionName = type().get().dimensionNames().get(0); ExpressionNode startExpr = new ConstantNode(new DoubleValue(start)); @@ -66,7 +67,7 @@ public class Range extends IntermediateOperation { ExpressionNode dimExpr = new EmbracedNode(new ReferenceNode(dimensionName)); ExpressionNode stepExpr = new ArithmeticNode(deltaExpr, ArithmeticOperator.MULTIPLY, dimExpr); ExpressionNode addExpr = new ArithmeticNode(startExpr, ArithmeticOperator.PLUS, stepExpr); - TensorFunction function = Generate.bound(type.type(), wrapScalar(addExpr)); + TensorFunction<Reference> function = Generate.bound(type.type(), wrapScalar(addExpr)); return function; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java index 8e49ce15265..b7a8a4a4e43 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; @@ -72,14 +73,14 @@ public class Reduce extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputTypesPresent(1)) return null; - TensorFunction inputFunction = inputs.get(0).function().get(); + TensorFunction<Reference> inputFunction = inputs.get(0).function().get(); if (preOperator != null) { - inputFunction = new com.yahoo.tensor.functions.Map(inputFunction, preOperator); + inputFunction = new com.yahoo.tensor.functions.Map<>(inputFunction, preOperator); } - TensorFunction output = new com.yahoo.tensor.functions.Reduce(inputFunction, aggregator, reduceDimensions); + TensorFunction<Reference> output = new com.yahoo.tensor.functions.Reduce<>(inputFunction, aggregator, reduceDimensions); if (shouldKeepDimensions()) { // multiply with a generated tensor created from the reduced dimensions TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); @@ -88,12 +89,12 @@ public class Reduce extends IntermediateOperation { } TensorType generatedType = typeBuilder.build(); ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); - Generate generatedFunction = new Generate(generatedType, + Generate<Reference> generatedFunction = new Generate<>(generatedType, new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); - output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); + output = new com.yahoo.tensor.functions.Join<>(output, generatedFunction, ScalarFunctions.multiply()); } if (postOperator != null) { - output = new com.yahoo.tensor.functions.Map(output, postOperator); + output = new com.yahoo.tensor.functions.Map<>(output, postOperator); } return output; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java index 724e49084ee..d80058dfa07 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; @@ -43,9 +44,9 @@ public class Rename extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(1)) return null; - return new com.yahoo.tensor.functions.Rename(inputs.get(0).function().orElse(null), from, to); + return new com.yahoo.tensor.functions.Rename<>(inputs.get(0).function().orElse(null), from, to); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index 57a43158c0d..7b675fa79af 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -110,12 +110,12 @@ public class Reshape extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent) ) return null; if ( ! inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent) ) return null; OrderedTensorType inputType = inputs.get(0).type().get(); - TensorFunction inputFunction = inputs.get(0).function().get(); + TensorFunction<Reference> inputFunction = inputs.get(0).function().get(); return reshape(inputFunction, inputType, type); } @@ -129,7 +129,7 @@ public class Reshape extends IntermediateOperation { return new Reshape(modelName(), name(), inputs, attributeMap); } - public TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) { + public TensorFunction<Reference> reshape(TensorFunction<Reference> inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) { if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type()))) throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java index a189ff9c07c..9836217866b 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ScalarFunctions; @@ -34,13 +35,13 @@ public class Select extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!allInputFunctionsPresent(3)) { return null; } IntermediateOperation conditionOperation = inputs().get(0); - TensorFunction a = inputs().get(1).function().get(); - TensorFunction b = inputs().get(2).function().get(); + TensorFunction<Reference> a = inputs().get(1).function().get(); + TensorFunction<Reference> b = inputs().get(2).function().get(); // Shortcut: if we know during import which tensor to select, do that directly here. if (conditionOperation.getConstantValue().isPresent()) { @@ -61,13 +62,13 @@ public class Select extends IntermediateOperation { // from 'x'. We do this by individually joining 'x' and 'y' with // 'condition', and then joining the resulting two tensors. - TensorFunction conditionFunction = conditionOperation.function().get(); - TensorFunction aCond = new com.yahoo.tensor.functions.Join(a, conditionFunction, ScalarFunctions.multiply()); - TensorFunction bCond = new com.yahoo.tensor.functions.Join(b, conditionFunction, new DoubleBinaryOperator() { + TensorFunction<Reference> conditionFunction = conditionOperation.function().get(); + TensorFunction<Reference> aCond = new com.yahoo.tensor.functions.Join<>(a, conditionFunction, ScalarFunctions.multiply()); + TensorFunction<Reference> bCond = new com.yahoo.tensor.functions.Join<>(b, conditionFunction, new DoubleBinaryOperator() { @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); } @Override public String toString() { return "f(a,b)(a * (1-b))"; } }); - return new com.yahoo.tensor.functions.Join(aCond, bCond, ScalarFunctions.add()); + return new com.yahoo.tensor.functions.Join<>(aCond, bCond, ScalarFunctions.add()); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java index 28e0115810a..c1cffd4243e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; @@ -28,7 +29,7 @@ public class Shape extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { return null; // will be added by function() since this is constant. } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java index ac5d66e22c1..91b7064b19c 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java @@ -143,7 +143,7 @@ public class Slice extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (inputs.size() < 1 || inputs.get(0).function().isEmpty()) { return null; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java index 6001bef87ed..d7060b9d440 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Map; import com.yahoo.tensor.functions.Reduce; @@ -34,12 +35,12 @@ public class Softmax extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(1)) return null; List<String> reduceDimensions = reduceDimensions(); - TensorFunction input = inputs.get(0).function().get(); - TensorFunction sum = new Reduce(input, Reduce.Aggregator.sum, reduceDimensions); - TensorFunction div = new Join(input, sum, ScalarFunctions.divide()); + TensorFunction<Reference> input = inputs.get(0).function().get(); + TensorFunction<Reference> sum = new Reduce<>(input, Reduce.Aggregator.sum, reduceDimensions); + TensorFunction<Reference> div = new Join<>(input, sum, ScalarFunctions.divide()); return div; } @@ -93,13 +94,13 @@ public class Softmax extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(1)) return null; List<String> reduceDimensions = reduceDimensions(); - TensorFunction input = inputs.get(0).function().get(); - TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions); - TensorFunction cap = new Join(input, max, ScalarFunctions.subtract()); // to avoid overflow - TensorFunction exp = new Map(cap, ScalarFunctions.exp()); + TensorFunction<Reference> input = inputs.get(0).function().get(); + TensorFunction<Reference> max = new Reduce<>(input, Reduce.Aggregator.max, reduceDimensions); + TensorFunction<Reference> cap = new Join<>(input, max, ScalarFunctions.subtract()); // to avoid overflow + TensorFunction<Reference> exp = new Map<>(cap, ScalarFunctions.exp()); return exp; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java index 2e586b38c71..6f720716adb 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java @@ -84,7 +84,7 @@ public class Split extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!allInputFunctionsPresent(1)) return null; IntermediateOperation input = inputs.get(0); @@ -104,7 +104,7 @@ public class Split extends IntermediateOperation { com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); - TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); + TensorFunction<Reference> generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); return generate; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java index 07110b9b966..9229d6af254 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; @@ -52,11 +53,11 @@ public class Squeeze extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(1)) return null; - TensorFunction inputFunction = inputs.get(0).function().get(); - return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); + TensorFunction<Reference> inputFunction = inputs.get(0).function().get(); + return new Reduce<>(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java index b8ca114343d..902144cfea2 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; @@ -56,11 +57,11 @@ public class Sum extends IntermediateOperation { // optimization: if keepDims and one reduce dimension that has size 1: same as identity. @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputTypesPresent(2)) return null; - TensorFunction inputFunction = inputs.get(0).function().get(); - TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.sum, reduceDimensions); + TensorFunction<Reference> inputFunction = inputs.get(0).function().get(); + TensorFunction<Reference> output = new Reduce<>(inputFunction, Reduce.Aggregator.sum, reduceDimensions); if (shouldKeepDimensions()) { // multiply with a generated tensor created from the reduced dimensions TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); @@ -69,9 +70,9 @@ public class Sum extends IntermediateOperation { } TensorType generatedType = typeBuilder.build(); ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); - Generate generatedFunction = new Generate(generatedType, + Generate<Reference> generatedFunction = new Generate<>(generatedType, new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); - output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); + output = new com.yahoo.tensor.functions.Join<>(output, generatedFunction, ScalarFunctions.multiply()); } return output; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java index f41140075d1..502f0769350 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; @@ -29,7 +30,7 @@ public class Switch extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { IntermediateOperation predicateOperation = inputs().get(1); if (!predicateOperation.getConstantValue().isPresent()) { throw new IllegalArgumentException("Switch in " + name + ": predicate must be a constant"); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java index 7fe5e831391..4bfab284cc2 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java @@ -62,7 +62,7 @@ public class Tile extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!allInputFunctionsPresent(2)) return null; IntermediateOperation input = inputs.get(0); @@ -85,7 +85,7 @@ public class Tile extends IntermediateOperation { com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); - TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); + TensorFunction<Reference> generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); return generate; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java index add24e665e6..ef51b11884a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; @@ -36,7 +37,7 @@ public class Transpose extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if (!allInputFunctionsPresent(1)) return null; return inputs.get(0).function().orElse(null); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java index bd3130a7cd1..a73b5a4c6ef 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -64,7 +65,7 @@ public class Unsqueeze extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputFunctionsPresent(1)) return null; // multiply with a generated tensor created from the expanded dimensions @@ -74,9 +75,9 @@ public class Unsqueeze extends IntermediateOperation { } TensorType generatedType = typeBuilder.build(); ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); - Generate generatedFunction = new Generate(generatedType, + Generate<Reference> generatedFunction = new Generate<>(generatedType, new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); - return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply()); + return new com.yahoo.tensor.functions.Join<>(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply()); } @Override diff --git a/model-integration/src/main/javacc/ModelParser.jj b/model-integration/src/main/javacc/ModelParser.jj index 9944b88a745..6f6f3508beb 100644 --- a/model-integration/src/main/javacc/ModelParser.jj +++ b/model-integration/src/main/javacc/ModelParser.jj @@ -170,7 +170,7 @@ void input() : void function() : { String name, expression, parameter; - List parameters = new ArrayList(); + List< String > parameters = new ArrayList< String >(); } { ( <FUNCTION> name = identifier() diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java index dfc4e98d409..3ef96cdf166 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.operations.Constant; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; @@ -703,7 +704,7 @@ public class OnnxOperationsTestCase { return builder.build(); } - private TensorFunction optimizeAndRename(String opName, IntermediateOperation op) { + private TensorFunction<Reference> optimizeAndRename(String opName, IntermediateOperation op) { IntermediateGraph graph = new IntermediateGraph(modelName); graph.put(opName, op); graph.outputs(graph.defaultSignature()).put(opName, opName); @@ -717,7 +718,7 @@ public class OnnxOperationsTestCase { if ( ! operationType.equals(standardNamingType)) { List<String> renameFrom = operationType.dimensionNames(); List<String> renameTo = standardNamingType.dimensionNames(); - TensorFunction func = new Rename(new ConstantTensor(tensor), renameFrom, renameTo); + TensorFunction<Reference> func = new Rename<>(new ConstantTensor<Reference>(tensor), renameFrom, renameTo); return func.evaluate(); } return tensor; diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java index f3b58b47ef0..9b988e9a379 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java @@ -264,7 +264,7 @@ public class NodeAgentImpl implements NodeAgent { ContainerResources currentResources = existingContainer.get().resources(); ContainerResources wantedResources = currentResources.withUnlimitedCpus(); - if ( ! wantedResources.equals(currentResources)) { + if ( ! warmUpDuration(context).isNegative() && ! wantedResources.equals(currentResources)) { context.log(logger, "Updating container resources: %s -> %s", existingContainer.get().resources().toStringCpu(), wantedResources.toStringCpu()); containerOperations.updateContainer(context, existingContainer.get().id(), wantedResources); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java index 4d63863a917..3db68a27234 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java @@ -27,6 +27,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.UUID; import java.util.stream.Collectors; /** @@ -63,7 +64,7 @@ public final class Node implements Nodelike { /** Creates a node builder in the initial state (reserved) */ public static Node.Builder reserve(Set<String> ipAddresses, String hostname, String parentHostname, NodeResources resources, NodeType type) { - return new Node.Builder("fake-" + hostname, hostname, new Flavor(resources), State.reserved, type) + return new Node.Builder(UUID.randomUUID().toString(), hostname, new Flavor(resources), State.reserved, type) .ipConfig(IP.Config.ofEmptyPool(ipAddresses)) .parentHostname(parentHostname); } @@ -140,7 +141,7 @@ public final class Node implements Nodelike { * * - OpenStack: UUID * - AWS: Instance ID - * - Linux containers: fake-[hostname] + * - Linux containers: UUID */ public String id() { return id; } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/docker-container1.json b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/docker-container1.json index c9784c7e610..b5efc69d6db 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/docker-container1.json +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/docker-container1.json @@ -1,6 +1,6 @@ { "url": "http://localhost:8080/nodes/v2/node/test-node-pool-102-2", - "id": "fake-test-node-pool-102-2", + "id": "(ignore)", "state": "active", "type": "tenant", "hostname": "test-node-pool-102-2", diff --git a/searchlib/src/vespa/searchlib/diskindex/field_merger.cpp b/searchlib/src/vespa/searchlib/diskindex/field_merger.cpp index ee5d72f1daa..6179f06c9da 100644 --- a/searchlib/src/vespa/searchlib/diskindex/field_merger.cpp +++ b/searchlib/src/vespa/searchlib/diskindex/field_merger.cpp @@ -13,7 +13,7 @@ #include <vespa/searchlib/index/schemautil.h> #include <vespa/searchlib/util/filekit.h> #include <vespa/searchlib/util/dirtraverse.h> -#include <vespa/searchlib/util/postingpriorityqueue.h> +#include <vespa/searchlib/util/posting_priority_queue_merger.hpp> #include <vespa/vespalib/io/fileutil.h> #include <vespa/vespalib/stllike/asciistream.h> #include <vespa/vespalib/util/exceptions.h> @@ -52,16 +52,20 @@ createTmpPath(const vespalib::string & base, uint32_t index) { FieldMerger::FieldMerger(uint32_t id, const FusionOutputIndex& fusion_out_index, std::shared_ptr<IFlushToken> flush_token) : _id(id), - _field_dir(fusion_out_index.get_path() + "/" + SchemaUtil::IndexIterator(fusion_out_index.get_schema(), id).getName()), + _field_name(SchemaUtil::IndexIterator(fusion_out_index.get_schema(), id).getName()), + _field_dir(fusion_out_index.get_path() + "/" + _field_name), _fusion_out_index(fusion_out_index), _flush_token(std::move(flush_token)), _word_readers(), _word_heap(), + _word_aggregator(), _word_num_mappings(), _num_word_ids(0), _readers(), _heap(), - _writer() + _writer(), + _state(State::MERGE_START), + _failed(false) { } @@ -107,14 +111,14 @@ bool FieldMerger::open_input_word_readers() { _word_readers.reserve(_fusion_out_index.get_old_indexes().size()); - _word_heap = std::make_unique<PostingPriorityQueue<DictionaryWordReader>>(); + _word_heap = std::make_unique<PostingPriorityQueueMerger<DictionaryWordReader, WordAggregator>>(); SchemaUtil::IndexIterator index(_fusion_out_index.get_schema(), _id); for (auto & oi : _fusion_out_index.get_old_indexes()) { auto reader(std::make_unique<DictionaryWordReader>()); const vespalib::string &tmpindexpath = createTmpPath(_field_dir, oi.getIndex()); const vespalib::string &oldindexpath = oi.getPath(); vespalib::string wordMapName = tmpindexpath + "/old2new.dat"; - vespalib::string fieldDir(oldindexpath + "/" + index.getName()); + vespalib::string fieldDir(oldindexpath + "/" + _field_name); vespalib::string dictName(fieldDir + "/dictionary"); const Schema &oldSchema = oi.getSchema(); if (!index.hasOldFields(oldSchema)) { @@ -163,24 +167,33 @@ FieldMerger::read_mapping_files() } bool -FieldMerger::renumber_word_ids() +FieldMerger::renumber_word_ids_start() { - SchemaUtil::IndexIterator index(_fusion_out_index.get_schema(), _id); - vespalib::string indexName = index.getName(); - LOG(debug, "Renumber word IDs for field %s", indexName.c_str()); - - WordAggregator out; - + LOG(debug, "Renumber word IDs for field %s", _field_name.c_str()); if (!open_input_word_readers()) { return false; } - _word_heap->merge(out, 4, *_flush_token); + _word_aggregator = std::make_unique<WordAggregator>(); + return true; +} + +bool +FieldMerger::renumber_word_ids_main() +{ + _word_heap->merge(*_word_aggregator, 4, *_flush_token); if (_flush_token->stop_requested()) { return false; } assert(_word_heap->empty()); + return true; +} + +bool +FieldMerger::renumber_word_ids_finish() +{ _word_heap.reset(); - _num_word_ids = out.getWordNum(); + _num_word_ids = _word_aggregator->getWordNum(); + _word_aggregator.reset(); // Close files for (auto &i : _word_readers) { @@ -193,11 +206,21 @@ FieldMerger::renumber_word_ids() if (!read_mapping_files()) { return false; } - LOG(debug, "Finished renumbering words IDs for field %s", indexName.c_str()); + LOG(debug, "Finished renumbering words IDs for field %s", _field_name.c_str()); return true; } +void +FieldMerger::renumber_word_ids_failed() +{ + _failed = true; + if (_flush_token->stop_requested()) { + return; + } + LOG(error, "Could not renumber field word ids for field %s dir %s", _field_name.c_str(), _field_dir.c_str()); +} + std::shared_ptr<FieldLengthScanner> FieldMerger::allocate_field_length_scanner() { @@ -226,7 +249,6 @@ FieldMerger::open_input_field_readers() _readers.reserve(_fusion_out_index.get_old_indexes().size()); SchemaUtil::IndexIterator index(_fusion_out_index.get_schema(), _id); auto field_length_scanner = allocate_field_length_scanner(); - vespalib::string indexName = index.getName(); for (const auto &oi : _fusion_out_index.get_old_indexes()) { const Schema &oldSchema = oi.getSchema(); if (!index.hasOldFields(oldSchema)) { @@ -234,7 +256,7 @@ FieldMerger::open_input_field_readers() } auto reader = FieldReader::allocFieldReader(index, oldSchema, field_length_scanner); reader->setup(_word_num_mappings[oi.getIndex()], oi.getDocIdMapping()); - if (!reader->open(oi.getPath() + "/" + indexName + "/", _fusion_out_index.get_tune_file_indexing()._read)) { + if (!reader->open(oi.getPath() + "/" + _field_name + "/", _fusion_out_index.get_tune_file_indexing()._read)) { return false; } _readers.push_back(std::move(reader)); @@ -322,7 +344,7 @@ FieldMerger::select_cooked_or_raw_features(FieldReader& reader) bool FieldMerger::setup_merge_heap() { - _heap = std::make_unique<PostingPriorityQueue<FieldReader>>(); + _heap = std::make_unique<PostingPriorityQueueMerger<FieldReader, FieldWriter>>(); for (auto &reader : _readers) { if (!select_cooked_or_raw_features(*reader)) { return false; @@ -338,12 +360,10 @@ FieldMerger::setup_merge_heap() } bool -FieldMerger::merge_postings() +FieldMerger::merge_postings_start() { - SchemaUtil::IndexIterator index(_fusion_out_index.get_schema(), _id); /* OUTPUT */ _writer = std::make_unique<FieldWriter>(_fusion_out_index.get_doc_id_limit(), _num_word_ids); - vespalib::string indexName = index.getName(); if (!open_input_field_readers()) { return false; @@ -351,15 +371,23 @@ FieldMerger::merge_postings() if (!open_field_writer()) { return false; } - if (!setup_merge_heap()) { - return false; - } + return setup_merge_heap(); +} +bool +FieldMerger::merge_postings_main() +{ _heap->merge(*_writer, 4, *_flush_token); if (_flush_token->stop_requested()) { return false; } assert(_heap->empty()); + return true; +} + +bool +FieldMerger::merge_postings_finish() +{ _heap.reset(); for (auto &reader : _readers) { @@ -375,55 +403,116 @@ FieldMerger::merge_postings() return true; } -bool -FieldMerger::merge_field() +void +FieldMerger::merge_postings_failed() +{ + _failed = true; + if (_flush_token->stop_requested()) { + return; + } + throw IllegalArgumentException(make_string("Could not merge field postings for field %s dir %s", + _field_name.c_str(), _field_dir.c_str())); +} + +void +FieldMerger::merge_field_start() { const Schema &schema = _fusion_out_index.get_schema(); SchemaUtil::IndexIterator index(schema, _id); - const vespalib::string &indexName = index.getName(); SchemaUtil::IndexSettings settings = index.getIndexSettings(); if (settings.hasError()) { - return false; + _failed = true; + return; } if (FileKit::hasStamp(_field_dir + "/.mergeocc_done")) { - return true; + _state = State::MERGE_DONE; + return; } vespalib::mkdir(_field_dir, false); - LOG(debug, "merge_field for field %s dir %s", indexName.c_str(), _field_dir.c_str()); + LOG(debug, "merge_field for field %s dir %s", _field_name.c_str(), _field_dir.c_str()); make_tmp_dirs(); - if (!renumber_word_ids()) { - if (_flush_token->stop_requested()) { - return false; - } - LOG(error, "Could not renumber field word ids for field %s dir %s", indexName.c_str(), _field_dir.c_str()); - return false; + if (!renumber_word_ids_start()) { + renumber_word_ids_failed(); + return; } + _state = State::RENUMBER_WORD_IDS; +} - // Tokamak - bool res = merge_postings(); +void +FieldMerger::merge_field_finish() +{ + bool res = merge_postings_finish(); if (!res) { - if (_flush_token->stop_requested()) { - return false; - } - throw IllegalArgumentException(make_string("Could not merge field postings for field %s dir %s", - indexName.c_str(), _field_dir.c_str())); + merge_postings_failed(); + _failed = true; + return; } if (!FileKit::createStamp(_field_dir + "/.mergeocc_done")) { - return false; + _failed = true; + return; } vespalib::File::sync(_field_dir); if (!clean_tmp_dirs()) { - return false; + _failed = true; + return; } - LOG(debug, "Finished merge_field for field %s dir %s", indexName.c_str(), _field_dir.c_str()); + LOG(debug, "Finished merge_field for field %s dir %s", _field_name.c_str(), _field_dir.c_str()); - return true; + _state = State::MERGE_DONE; +} + +void +FieldMerger::process_merge_field() +{ + switch (_state) { + case State::MERGE_START: + merge_field_start(); + break; + case State::RENUMBER_WORD_IDS: + if (!renumber_word_ids_main()) { + renumber_word_ids_failed(); + } else { + _state = State::RENUMBER_WORD_IDS_FINISH; + } + break; + case State::RENUMBER_WORD_IDS_FINISH: + if (!renumber_word_ids_finish()) { + renumber_word_ids_failed(); + } else if (!merge_postings_start()) { + merge_postings_failed(); + } else { + _state = State::MERGE_POSTINGS; + } + break; + case State::MERGE_POSTINGS: + if (!merge_postings_main()) { + merge_postings_failed(); + } else { + _state = State::MERGE_POSTINGS_FINISH; + } + break; + case State::MERGE_POSTINGS_FINISH: + merge_field_finish(); + break; + case State::MERGE_DONE: + default: + LOG_ABORT("should not be reached"); + } +} + +bool +FieldMerger::merge_field() +{ + while (!_failed && _state != State::MERGE_DONE) { + process_merge_field(); + } + return !_failed; } } diff --git a/searchlib/src/vespa/searchlib/diskindex/field_merger.h b/searchlib/src/vespa/searchlib/diskindex/field_merger.h index a57005e18a4..c5ce337e845 100644 --- a/searchlib/src/vespa/searchlib/diskindex/field_merger.h +++ b/searchlib/src/vespa/searchlib/diskindex/field_merger.h @@ -9,7 +9,7 @@ namespace search { class IFlushToken; -template <class IN> class PostingPriorityQueue; +template <class Reader, class Writer> class PostingPriorityQueueMerger; } namespace search::diskindex { @@ -19,6 +19,7 @@ class FieldLengthScanner; class FieldReader; class FieldWriter; class FusionOutputIndex; +class WordAggregator; class WordNumMapping; /* @@ -28,32 +29,54 @@ class FieldMerger { using WordNumMappingList = std::vector<WordNumMapping>; + enum class State { + MERGE_START, + RENUMBER_WORD_IDS, + RENUMBER_WORD_IDS_FINISH, + MERGE_POSTINGS, + MERGE_POSTINGS_FINISH, + MERGE_DONE + }; + uint32_t _id; + vespalib::string _field_name; vespalib::string _field_dir; const FusionOutputIndex& _fusion_out_index; std::shared_ptr<IFlushToken> _flush_token; std::vector<std::unique_ptr<DictionaryWordReader>> _word_readers; - std::unique_ptr<PostingPriorityQueue<DictionaryWordReader>> _word_heap; + std::unique_ptr<PostingPriorityQueueMerger<DictionaryWordReader, WordAggregator>> _word_heap; + std::unique_ptr<WordAggregator> _word_aggregator; WordNumMappingList _word_num_mappings; uint64_t _num_word_ids; std::vector<std::unique_ptr<FieldReader>> _readers; - std::unique_ptr<PostingPriorityQueue<FieldReader>> _heap; + std::unique_ptr<PostingPriorityQueueMerger<FieldReader, FieldWriter>> _heap; std::unique_ptr<FieldWriter> _writer; + State _state; + bool _failed; void make_tmp_dirs(); bool clean_tmp_dirs(); bool open_input_word_readers(); bool read_mapping_files(); - bool renumber_word_ids(); + bool renumber_word_ids_start(); + bool renumber_word_ids_main(); + bool renumber_word_ids_finish(); + void renumber_word_ids_failed(); std::shared_ptr<FieldLengthScanner> allocate_field_length_scanner(); bool open_input_field_readers(); bool open_field_writer(); bool select_cooked_or_raw_features(FieldReader& reader); bool setup_merge_heap(); - bool merge_postings(); + bool merge_postings_start(); + bool merge_postings_main(); + bool merge_postings_finish(); + void merge_postings_failed(); public: FieldMerger(uint32_t id, const FusionOutputIndex& fusion_out_index, std::shared_ptr<IFlushToken> flush_token); ~FieldMerger(); + void merge_field_start(); + void merge_field_finish(); + void process_merge_field(); // Called multiple times bool merge_field(); }; diff --git a/searchlib/src/vespa/searchlib/diskindex/fusion.cpp b/searchlib/src/vespa/searchlib/diskindex/fusion.cpp index eafbbac361b..1afed18cb48 100644 --- a/searchlib/src/vespa/searchlib/diskindex/fusion.cpp +++ b/searchlib/src/vespa/searchlib/diskindex/fusion.cpp @@ -3,46 +3,29 @@ #include "fusion.h" #include "fusion_input_index.h" #include "field_merger.h" -#include "fieldreader.h" -#include "dictionarywordreader.h" -#include "field_length_scanner.h" -#include <vespa/vespalib/util/stringfmt.h> -#include <vespa/searchlib/bitcompression/posocc_fields_params.h> +#include <vespa/fastos/file.h> +#include <vespa/searchlib/common/documentsummary.h> #include <vespa/searchlib/common/i_flush_token.h> -#include <vespa/searchlib/index/field_length_info.h> -#include <vespa/searchlib/util/filekit.h> +#include <vespa/searchlib/index/schemautil.h> #include <vespa/searchlib/util/dirtraverse.h> -#include <vespa/searchlib/util/postingpriorityqueue.h> #include <vespa/vespalib/io/fileutil.h> -#include <vespa/searchlib/common/documentsummary.h> +#include <vespa/vespalib/util/count_down_latch.h> #include <vespa/vespalib/util/error.h> +#include <vespa/vespalib/util/exceptions.h> #include <vespa/vespalib/util/lambdatask.h> -#include <vespa/vespalib/util/count_down_latch.h> -#include <vespa/vespalib/stllike/asciistream.h> #include <vespa/document/util/queue.h> -#include <sstream> #include <vespa/log/log.h> -#include <vespa/vespalib/util/exceptions.h> LOG_SETUP(".diskindex.fusion"); -using search::FileKit; -using search::PostingPriorityQueue; using search::common::FileHeaderContext; -using search::diskindex::DocIdMapping; -using search::diskindex::WordNumMapping; using search::docsummary::DocumentSummary; -using search::index::FieldLengthInfo; -using search::bitcompression::PosOccFieldParams; -using search::bitcompression::PosOccFieldsParams; -using search::index::PostingListParams; using search::index::Schema; using search::index::SchemaUtil; using search::index::schema::DataType; using vespalib::getLastErrorString; using vespalib::IllegalArgumentException; -using vespalib::make_string; namespace search::diskindex { diff --git a/searchlib/src/vespa/searchlib/diskindex/fusion.h b/searchlib/src/vespa/searchlib/diskindex/fusion.h index 22dda4d6edf..7e0b70dca36 100644 --- a/searchlib/src/vespa/searchlib/diskindex/fusion.h +++ b/searchlib/src/vespa/searchlib/diskindex/fusion.h @@ -3,12 +3,10 @@ #pragma once #include "fusion_output_index.h" - #include <vespa/vespalib/util/threadexecutor.h> namespace search { class IFlushToken; -template <class IN> class PostingPriorityQueue; class TuneFileIndexing; } diff --git a/searchlib/src/vespa/searchlib/test/fakedata/fakememtreeocc.cpp b/searchlib/src/vespa/searchlib/test/fakedata/fakememtreeocc.cpp index 8ae3fe0bdad..64d194c5c7e 100644 --- a/searchlib/src/vespa/searchlib/test/fakedata/fakememtreeocc.cpp +++ b/searchlib/src/vespa/searchlib/test/fakedata/fakememtreeocc.cpp @@ -5,7 +5,7 @@ #include <vespa/searchlib/common/flush_token.h> #include <vespa/searchlib/memoryindex/posting_iterator.h> #include <vespa/searchlib/queryeval/iterators.h> -#include <vespa/searchlib/util/postingpriorityqueue.h> +#include <vespa/searchlib/util/posting_priority_queue_merger.hpp> #include <vespa/vespalib/datastore/buffer_type.hpp> #include <vespa/vespalib/btree/btreeiterator.hpp> #include <vespa/vespalib/btree/btreenode.hpp> @@ -352,7 +352,7 @@ FakeMemTreeOccFactory::setup(const std::vector<const FakeWord *> &fws) ++wordIdx; } - PostingPriorityQueue<FakeWord::RandomizedReader> heap; + PostingPriorityQueueMerger<FakeWord::RandomizedReader, FakeWord::RandomizedWriter> heap; std::vector<FakeWord::RandomizedReader>::iterator i(r.begin()); std::vector<FakeWord::RandomizedReader>::iterator ie(r.end()); FlushToken flush_token; diff --git a/searchlib/src/vespa/searchlib/util/posting_priority_queue.h b/searchlib/src/vespa/searchlib/util/posting_priority_queue.h new file mode 100644 index 00000000000..01ae0995806 --- /dev/null +++ b/searchlib/src/vespa/searchlib/util/posting_priority_queue.h @@ -0,0 +1,59 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vector> + +namespace search { + +/* + * Provide priority queue semantics for a set of posting readers. + */ +template <class Reader> +class PostingPriorityQueue +{ +public: + class Ref + { + Reader *_ref; + public: + Ref(Reader *ref) + : _ref(ref) + { + } + + bool operator<(const Ref &rhs) const { return *_ref < *rhs._ref; } + Reader *get() const noexcept { return _ref; } + }; + + using Vector = std::vector<Ref>; + Vector _vec; + + PostingPriorityQueue() + : _vec() + { + } + + bool empty() const { return _vec.empty(); } + void clear() { _vec.clear(); } + void initialAdd(Reader *it) { _vec.push_back(Ref(it)); } + + /* + * Sort vector after a set of initial add operations, so lowest() + * and adjust() can be used. + */ + void sort() { std::sort(_vec.begin(), _vec.end()); } + + /* + * Return lowest value. Assumes vector is sorted. + */ + Reader *lowest() const { return _vec.front().get(); } + + /* + * The vector might no longer be sorted since the first element has changed + * value. Perform adjustments to make vector sorted again. + */ + void adjust(); +}; + +} diff --git a/searchlib/src/vespa/searchlib/util/posting_priority_queue.hpp b/searchlib/src/vespa/searchlib/util/posting_priority_queue.hpp new file mode 100644 index 00000000000..69bf7bc547c --- /dev/null +++ b/searchlib/src/vespa/searchlib/util/posting_priority_queue.hpp @@ -0,0 +1,35 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "posting_priority_queue.h" + +namespace search { + +template <class Reader> +void +PostingPriorityQueue<Reader>::adjust() +{ + typedef typename Vector::iterator VIT; + if (!_vec.front().get()->isValid()) { + _vec.erase(_vec.begin()); // Iterator no longer valid + return; + } + if (_vec.size() == 1) { // Only one iterator left + return; + } + // Peform binary search to find first element higher than changed value + VIT gt = std::upper_bound(_vec.begin() + 1, _vec.end(), _vec.front()); + VIT to = _vec.begin(); + VIT from = to; + ++from; + Ref changed = *to; // Remember changed value + while (from != gt) { // Shift elements to make space for changed value + *to = *from; + ++from; + ++to; + } + *to = changed; // Save changed value at right location +} + +} diff --git a/searchlib/src/vespa/searchlib/util/posting_priority_queue_merger.h b/searchlib/src/vespa/searchlib/util/posting_priority_queue_merger.h new file mode 100644 index 00000000000..8dd941a2c13 --- /dev/null +++ b/searchlib/src/vespa/searchlib/util/posting_priority_queue_merger.h @@ -0,0 +1,32 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "posting_priority_queue.h" + +namespace search { + +/* + * Provide priority queue semantics for a set of posting readers with + * merging to a posting writer. + */ +template <class Reader, class Writer> +class PostingPriorityQueueMerger : public PostingPriorityQueue<Reader> +{ +public: + using Parent = PostingPriorityQueue<Reader>; + using Vector = typename Parent::Vector; + using Parent::_vec; + using Parent::adjust; + using Parent::empty; + using Parent::lowest; + using Parent::sort; + + void mergeHeap(Writer& writer, const IFlushToken& flush_token) __attribute__((noinline)); + static void mergeOne(Writer& writer, Reader& reader, const IFlushToken &flush_token) __attribute__((noinline)); + static void mergeTwo(Writer& writer, Reader& reader1, Reader& reader2, const IFlushToken& flush_token) __attribute__((noinline)); + static void mergeSmall(Writer& writer, typename Vector::iterator ib, typename Vector::iterator ie, const IFlushToken &flush_token) __attribute__((noinline)); + void merge(Writer& writer, uint32_t heapLimit, const IFlushToken& flush_token) __attribute__((noinline)); +}; + +} diff --git a/searchlib/src/vespa/searchlib/util/posting_priority_queue_merger.hpp b/searchlib/src/vespa/searchlib/util/posting_priority_queue_merger.hpp new file mode 100644 index 00000000000..d33356cee4a --- /dev/null +++ b/searchlib/src/vespa/searchlib/util/posting_priority_queue_merger.hpp @@ -0,0 +1,114 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "posting_priority_queue.hpp" +#include "posting_priority_queue_merger.h" + +namespace search { + +template <class Reader, class Writer> +void +PostingPriorityQueueMerger<Reader, Writer>::mergeHeap(Writer& writer, const IFlushToken& flush_token) +{ + while (!empty() && !flush_token.stop_requested()) { + Reader *low = lowest(); + low->write(writer); + low->read(); + adjust(); + } +} + +template <class Reader, class Writer> +void +PostingPriorityQueueMerger<Reader, Writer>::mergeOne(Writer& writer, Reader& reader, const IFlushToken& flush_token) +{ + while (reader.isValid() && !flush_token.stop_requested()) { + reader.write(writer); + reader.read(); + } +} + +template <class Reader, class Writer> +void +PostingPriorityQueueMerger<Reader, Writer>::mergeTwo(Writer& writer, Reader& reader1, Reader& reader2, const IFlushToken& flush_token) +{ + while (!flush_token.stop_requested()) { + Reader &low = reader2 < reader1 ? reader2 : reader1; + low.write(writer); + low.read(); + if (!low.isValid()) + break; + } +} + +template <class Reader, class Writer> +void +PostingPriorityQueueMerger<Reader, Writer>::mergeSmall(Writer& writer, typename Vector::iterator ib, typename Vector::iterator ie, const IFlushToken& flush_token) +{ + while (!flush_token.stop_requested()) { + typename Vector::iterator i = ib; + Reader *low = i->get(); + for (++i; i != ie; ++i) + if (*i->get() < *low) + low = i->get(); + low->write(writer); + low->read(); + if (!low->isValid()) + break; + } +} + +template <class Reader, class Writer> +void +PostingPriorityQueueMerger<Reader, Writer>::merge(Writer& writer, uint32_t heapLimit, const IFlushToken& flush_token) +{ + if (_vec.empty()) + return; + for (typename Vector::iterator i = _vec.begin(), ie = _vec.end(); i != ie; + ++i) { + assert(i->get()->isValid()); + } + if (_vec.size() >= heapLimit) { + sort(); + void (PostingPriorityQueueMerger::*mergeHeapFunc)(Writer& writer, const IFlushToken& flush_token) = + &PostingPriorityQueueMerger::mergeHeap; + (this->*mergeHeapFunc)(writer, flush_token); + return; + } + while (!flush_token.stop_requested()) { + if (_vec.size() == 1) { + void (*mergeOneFunc)(Writer& writer, Reader& reader, const IFlushToken& flush_token) = + &PostingPriorityQueueMerger::mergeOne; + (*mergeOneFunc)(writer, *_vec.front().get(), flush_token); + _vec.clear(); + return; + } + if (_vec.size() == 2) { + void (*mergeTwoFunc)(Writer& writer, Reader& reader1, Reader& reader2, const IFlushToken& flush_token) = + &PostingPriorityQueueMerger::mergeTwo; + (*mergeTwoFunc)(writer, *_vec[0].get(), *_vec[1].get(), flush_token); + } else { + void (*mergeSmallFunc)(Writer& writer, + typename Vector::iterator ib, + typename Vector::iterator ie, + const IFlushToken& flush_token) = + &PostingPriorityQueueMerger::mergeSmall; + (*mergeSmallFunc)(writer, _vec.begin(), _vec.end(), flush_token); + } + for (typename Vector::iterator i = _vec.begin(), ie = _vec.end(); + i != ie; ++i) { + if (!i->get()->isValid()) { + _vec.erase(i); + break; + } + } + for (typename Vector::iterator i = _vec.begin(), ie = _vec.end(); + i != ie; ++i) { + assert(i->get()->isValid()); + } + assert(!_vec.empty()); + } +} + +} diff --git a/searchlib/src/vespa/searchlib/util/postingpriorityqueue.h b/searchlib/src/vespa/searchlib/util/postingpriorityqueue.h deleted file mode 100644 index c263c6bc470..00000000000 --- a/searchlib/src/vespa/searchlib/util/postingpriorityqueue.h +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#pragma once - -#include <vector> - -namespace search -{ - -/* - * Provide priority queue semantics for a set of posting inputs. - */ -template <class IN> -class PostingPriorityQueue -{ -public: - class Ref - { - IN *_ref; - public: - Ref(IN *ref) - : _ref(ref) - { - } - - bool - operator<(const Ref &rhs) const - { - return *_ref < *rhs._ref; - } - - IN * - get() const - { - return _ref; - } - }; - - typedef std::vector<Ref> Vector; - Vector _vec; - - PostingPriorityQueue() - : _vec() - { - } - - bool - empty() const - { - return _vec.empty(); - } - - void - clear() - { - _vec.clear(); - } - - void - initialAdd(IN *it) - { - _vec.push_back(Ref(it)); - } - - /* - * Sort vector after a set of initial add operations, so lowest() - * and adjust() can be used. - */ - void - sort() - { - std::sort(_vec.begin(), _vec.end()); - } - - /* - * Return lowest value. Assumes vector is sorted. - */ - IN * - lowest() const - { - return _vec.front().get(); - } - - /* - * The vector might no longer be sorted since the first element has changed - * value. Perform adjustments to make vector sorted again. - */ - void - adjust(); - - - template <class OUT> - void - mergeHeap(OUT &out, const IFlushToken& flush_token) __attribute__((noinline)); - - template <class OUT> - static void - mergeOne(OUT &out, IN &in, const IFlushToken &flush_token) __attribute__((noinline)); - - template <class OUT> - static void - mergeTwo(OUT &out, IN &in1, IN &in2, const IFlushToken& flush_token) __attribute__((noinline)); - - template <class OUT> - static void - mergeSmall(OUT &out, - typename Vector::iterator ib, - typename Vector::iterator ie, - const IFlushToken &flush_token) - __attribute__((noinline)); - - template <class OUT> - void - merge(OUT &out, uint32_t heapLimit, const IFlushToken& flush_token) __attribute__((noinline)); -}; - - -template <class IN> -void -PostingPriorityQueue<IN>::adjust() -{ - typedef typename Vector::iterator VIT; - if (!_vec.front().get()->isValid()) { - _vec.erase(_vec.begin()); // Iterator no longer valid - return; - } - if (_vec.size() == 1) // Only one iterator left - return; - // Peform binary search to find first element higher than changed value - VIT gt = std::upper_bound(_vec.begin() + 1, _vec.end(), _vec.front()); - VIT to = _vec.begin(); - VIT from = to; - ++from; - Ref changed = *to; // Remember changed value - while (from != gt) { // Shift elements to make space for changed value - *to = *from; - ++from; - ++to; - } - *to = changed; // Save changed value at right location -} - - -template <class IN> -template <class OUT> -void -PostingPriorityQueue<IN>::mergeHeap(OUT &out, const IFlushToken& flush_token) -{ - while (!empty() && !flush_token.stop_requested()) { - IN *low = lowest(); - low->write(out); - low->read(); - adjust(); - } -} - - -template <class IN> -template <class OUT> -void -PostingPriorityQueue<IN>::mergeOne(OUT &out, IN &in, const IFlushToken& flush_token) -{ - while (in.isValid() && !flush_token.stop_requested()) { - in.write(out); - in.read(); - } -} - -template <class IN> -template <class OUT> -void -PostingPriorityQueue<IN>::mergeTwo(OUT &out, IN &in1, IN &in2, const IFlushToken& flush_token) -{ - while (!flush_token.stop_requested()) { - IN &low = in2 < in1 ? in2 : in1; - low.write(out); - low.read(); - if (!low.isValid()) - break; - } -} - - -template <class IN> -template <class OUT> -void -PostingPriorityQueue<IN>::mergeSmall(OUT &out, - typename Vector::iterator ib, - typename Vector::iterator ie, - const IFlushToken& flush_token) -{ - while (!flush_token.stop_requested()) { - typename Vector::iterator i = ib; - IN *low = i->get(); - for (++i; i != ie; ++i) - if (*i->get() < *low) - low = i->get(); - low->write(out); - low->read(); - if (!low->isValid()) - break; - } -} - - -template <class IN> -template <class OUT> -void -PostingPriorityQueue<IN>::merge(OUT &out, uint32_t heapLimit, const IFlushToken& flush_token) -{ - if (_vec.empty()) - return; - for (typename Vector::iterator i = _vec.begin(), ie = _vec.end(); i != ie; - ++i) { - assert(i->get()->isValid()); - } - if (_vec.size() >= heapLimit) { - sort(); - void (PostingPriorityQueue::*mergeHeapFunc)(OUT &out, const IFlushToken& flush_token) = - &PostingPriorityQueue::mergeHeap; - (this->*mergeHeapFunc)(out, flush_token); - return; - } - while (!flush_token.stop_requested()) { - if (_vec.size() == 1) { - void (*mergeOneFunc)(OUT &out, IN &in, const IFlushToken& flush_token) = - &PostingPriorityQueue<IN>::mergeOne; - (*mergeOneFunc)(out, *_vec.front().get(), flush_token); - _vec.clear(); - return; - } - if (_vec.size() == 2) { - void (*mergeTwoFunc)(OUT &out, IN &in1, IN &in2, const IFlushToken& flush_token) = - &PostingPriorityQueue<IN>::mergeTwo; - (*mergeTwoFunc)(out, *_vec[0].get(), *_vec[1].get(), flush_token); - } else { - void (*mergeSmallFunc)(OUT &out, - typename Vector::iterator ib, - typename Vector::iterator ie, - const IFlushToken& flush_token) = - &PostingPriorityQueue::mergeSmall; - (*mergeSmallFunc)(out, _vec.begin(), _vec.end(), flush_token); - } - for (typename Vector::iterator i = _vec.begin(), ie = _vec.end(); - i != ie; ++i) { - if (!i->get()->isValid()) { - _vec.erase(i); - break; - } - } - for (typename Vector::iterator i = _vec.begin(), ie = _vec.end(); - i != ie; ++i) { - assert(i->get()->isValid()); - } - assert(!_vec.empty()); - } -} - - -} // namespace search - |