diff options
252 files changed, 6415 insertions, 5215 deletions
diff --git a/athenz-identity-provider-service/pom.xml b/athenz-identity-provider-service/pom.xml index 86d4defa861..982cb89f2bf 100644 --- a/athenz-identity-provider-service/pom.xml +++ b/athenz-identity-provider-service/pom.xml @@ -131,6 +131,14 @@ <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-compiler-plugin</artifactId> + <configuration> + <compilerArgs> + <arg>-Xlint:all</arg> + <arg>-Xlint:-deprecation</arg> + <arg>-Xlint:-serial</arg> + <arg>-Werror</arg> + </compilerArgs> + </configuration> </plugin> </plugins> </build> diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslKeyStoreConfigurator.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslKeyStoreConfigurator.java index f1fc938d3ea..2a517e06ae2 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslKeyStoreConfigurator.java +++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/AthenzSslKeyStoreConfigurator.java @@ -23,11 +23,11 @@ import java.security.GeneralSecurityException; import java.security.KeyStore; import java.security.KeyStoreException; import java.security.PrivateKey; -import java.security.SecureRandom; import java.security.cert.X509Certificate; import java.time.Duration; import java.time.Instant; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -45,7 +45,6 @@ import static com.yahoo.vespa.hosted.athenz.instanceproviderservice.impl.Utils.g @SuppressWarnings("unused") // Component injected into Jetty connector factory public class AthenzSslKeyStoreConfigurator extends AbstractComponent implements SslKeyStoreConfigurator { private static final Logger log = Logger.getLogger(AthenzSslKeyStoreConfigurator.class.getName()); - private static final SecureRandom secureRandom = new SecureRandom(); private static final String CERTIFICATE_ALIAS = "athenz"; private static final Duration EXPIRATION_MARGIN = Duration.ofHours(6); @@ -172,12 +171,7 @@ public class AthenzSslKeyStoreConfigurator extends AbstractComponent implements } private static char[] generateKeystorePassword() { - int length = 128; - char[] pwd = new char[length]; - for (int i = 0; i < length; i++) { - pwd[i] = (char) secureRandom.nextInt(); - } - return pwd; + return UUID.randomUUID().toString().toCharArray(); } private class AthenzCertificateUpdater implements Runnable { diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGenerator.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGenerator.java index 728406c297f..59126fd023f 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGenerator.java +++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGenerator.java @@ -7,6 +7,7 @@ import com.yahoo.net.HostName; import com.yahoo.vespa.athenz.api.AthenzService; import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper; import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocument; +import com.yahoo.vespa.athenz.identityprovider.api.IdentityType; import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument; import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId; import com.yahoo.vespa.hosted.athenz.instanceproviderservice.KeyProvider; @@ -27,7 +28,10 @@ import java.util.Objects; import java.util.Set; /** + * Generates a signed identity document for a given hostname and type + * * @author mortent + * @author bjorncs */ public class IdentityDocumentGenerator { @@ -47,10 +51,10 @@ public class IdentityDocumentGenerator { this.keyProvider = keyProvider; } - public SignedIdentityDocument generateSignedIdentityDocument(String hostname) { + public SignedIdentityDocument generateSignedIdentityDocument(String hostname, IdentityType identityType) { Node node = nodeRepository.getNode(hostname).orElseThrow(() -> new RuntimeException("Unable to find node " + hostname)); try { - IdentityDocument identityDocument = generateIdDocument(node); + IdentityDocument identityDocument = generateIdDocument(node, identityType); String identityDocumentString = Utils.getMapper().writeValueAsString(EntityBindingsMapper.toIdentityDocumentEntity(identityDocument)); String encodedIdentityDocument = @@ -70,13 +74,18 @@ public class IdentityDocumentGenerator { toZoneDnsSuffix(zone, zoneConfig.certDnsSuffix()), new AthenzService(zoneConfig.domain(), zoneConfig.serviceName()), URI.create(zoneConfig.ztsUrl()), - SignedIdentityDocument.DEFAULT_DOCUMENT_VERSION); + SignedIdentityDocument.DEFAULT_DOCUMENT_VERSION, + identityDocument.configServerHostname(), + identityDocument.instanceHostname(), + identityDocument.createdAt(), + identityDocument.ipAddresses(), + identityType); } catch (Exception e) { throw new RuntimeException("Exception generating identity document: " + e.getMessage(), e); } } - private IdentityDocument generateIdDocument(Node node) { + private IdentityDocument generateIdDocument(Node node, IdentityType identityType) { Allocation allocation = node.allocation().orElseThrow(() -> new RuntimeException("No allocation for node " + node.hostname())); VespaUniqueInstanceId providerUniqueId = new VespaUniqueInstanceId( allocation.membership().index(), @@ -85,17 +94,10 @@ public class IdentityDocumentGenerator { allocation.owner().application().value(), allocation.owner().tenant().value(), zone.region().value(), - zone.environment().value()); + zone.environment().value(), + identityType); - // TODO: Hack to allow access from docker containers to non-ipv6 services. - // Remove when yca-bridge is no longer needed Set<String> ips = new HashSet<>(node.ipAddresses()); - if(node.parentHostname().isPresent()) { - String parentHostName = node.parentHostname().get(); - nodeRepository.getNode(parentHostName) - .map(Node::ipAddresses) - .ifPresent(ips::addAll); - } return new IdentityDocument( providerUniqueId, HostName.getLocalhost(), diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentResource.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentResource.java index 93668006e26..219e12c7223 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentResource.java +++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentResource.java @@ -6,6 +6,7 @@ import com.yahoo.container.jaxrs.annotation.Component; import com.yahoo.jdisc.http.servlet.ServletRequest; import com.yahoo.log.LogLevel; import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper; +import com.yahoo.vespa.athenz.identityprovider.api.IdentityType; import com.yahoo.vespa.athenz.identityprovider.api.bindings.IdentityDocumentApi; import com.yahoo.vespa.athenz.identityprovider.api.bindings.SignedIdentityDocumentEntity; import com.yahoo.vespa.hosted.provision.restapi.v2.filter.NodePrincipal; @@ -18,7 +19,6 @@ import javax.ws.rs.InternalServerErrorException; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; -import javax.ws.rs.QueryParam; import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import java.util.logging.Logger; @@ -41,15 +41,7 @@ public class IdentityDocumentResource implements IdentityDocumentApi { this.request = request; } - /** - * @deprecated Use {@link #getNodeIdentityDocument(String)} and {@link #getTenantIdentityDocument(String)} instead. - */ - @GET - @Produces(MediaType.APPLICATION_JSON) - @Deprecated - @Override - // TODO Make this method private when the rest api is not longer in use - public SignedIdentityDocumentEntity getIdentityDocument(@QueryParam("hostname") String hostname) { + private SignedIdentityDocumentEntity getIdentityDocument(String hostname, IdentityType identityType) { if (hostname == null) { throw new BadRequestException("The 'hostname' query parameter is missing"); } @@ -67,7 +59,7 @@ public class IdentityDocumentResource implements IdentityDocumentApi { throw new ForbiddenException(); } try { - return EntityBindingsMapper.toSignedIdentityDocumentEntity(identityDocumentGenerator.generateSignedIdentityDocument(hostname)); + return EntityBindingsMapper.toSignedIdentityDocumentEntity(identityDocumentGenerator.generateSignedIdentityDocument(hostname, identityType)); } catch (Exception e) { String message = String.format("Unable to generate identity doument for '%s': %s", hostname, e.getMessage()); log.log(LogLevel.ERROR, message, e); @@ -80,7 +72,7 @@ public class IdentityDocumentResource implements IdentityDocumentApi { @Path("/node/{host}") @Override public SignedIdentityDocumentEntity getNodeIdentityDocument(@PathParam("host") String host) { - return getIdentityDocument(host); + return getIdentityDocument(host, IdentityType.NODE); } @GET @@ -88,7 +80,7 @@ public class IdentityDocumentResource implements IdentityDocumentApi { @Path("/tenant/{host}") @Override public SignedIdentityDocumentEntity getTenantIdentityDocument(@PathParam("host") String host) { - return getIdentityDocument(host); + return getIdentityDocument(host, IdentityType.TENANT); } } diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidator.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidator.java index e457df37946..0201c46b253 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidator.java +++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidator.java @@ -82,6 +82,7 @@ public class InstanceValidator { } // If/when we dont care about logging exactly whats wrong, this can be simplified + // TODO Use identity type to determine if this check should be performed boolean isSameIdentityAsInServicesXml(ApplicationId applicationId, String domain, String service) { Optional<ApplicationInfo> applicationInfo = superModelProvider.getSuperModel().getApplicationInfo(applicationId); diff --git a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGeneratorTest.java b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGeneratorTest.java index d7b061ca2f1..078ef1b7e39 100644 --- a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGeneratorTest.java +++ b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/identitydocument/IdentityDocumentGeneratorTest.java @@ -15,6 +15,7 @@ import com.yahoo.config.provision.SystemName; import com.yahoo.config.provision.TenantName; import com.yahoo.config.provision.Zone; import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper; +import com.yahoo.vespa.athenz.identityprovider.api.IdentityType; import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument; import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId; import com.yahoo.vespa.athenz.identityprovider.api.bindings.SignedIdentityDocumentEntity; @@ -81,7 +82,7 @@ public class IdentityDocumentGeneratorTest { AthenzProviderServiceConfig config = getAthenzProviderConfig("domain", "service", dnsSuffix, ZONE); IdentityDocumentGenerator identityDocumentGenerator = new IdentityDocumentGenerator(config, nodeRepository, ZONE, keyProvider); - SignedIdentityDocument signedIdentityDocument = identityDocumentGenerator.generateSignedIdentityDocument(containerHostname); + SignedIdentityDocument signedIdentityDocument = identityDocumentGenerator.generateSignedIdentityDocument(containerHostname, IdentityType.TENANT); // Verify attributes assertEquals(containerHostname, signedIdentityDocument.identityDocument().instanceHostname()); @@ -92,11 +93,11 @@ public class IdentityDocumentGeneratorTest { assertEquals(expectedZoneDnsSuffix, signedIdentityDocument.dnsSuffix()); VespaUniqueInstanceId expectedProviderUniqueId = - new VespaUniqueInstanceId(0, "default", "default", "application", "tenant", region, environment); + new VespaUniqueInstanceId(0, "default", "default", "application", "tenant", region, environment, IdentityType.TENANT); assertEquals(expectedProviderUniqueId, signedIdentityDocument.providerUniqueId()); - // Validate that both parent and container ips are present - assertThat(signedIdentityDocument.identityDocument().ipAddresses(), Matchers.containsInAnyOrder("127.0.0.1", "::1")); + // Validate that container ips are present + assertThat(signedIdentityDocument.identityDocument().ipAddresses(), Matchers.containsInAnyOrder("::1")); SignedIdentityDocumentEntity signedIdentityDocumentEntity = EntityBindingsMapper.toSignedIdentityDocumentEntity(signedIdentityDocument); diff --git a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidatorTest.java b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidatorTest.java index 54786c86cd3..54411b424eb 100644 --- a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidatorTest.java +++ b/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/instanceconfirmation/InstanceValidatorTest.java @@ -143,7 +143,12 @@ public class InstanceValidatorTest { "dnssuffix", "service", URI.create("http://localhost/zts"), - 1)); + 1, + identityDocument.configServerHostname, + identityDocument.instanceHostname, + identityDocument.createdAt, + identityDocument.ipAddresses, + null)); // TODO Remove support for legacy representation without type } catch (Exception e) { throw new RuntimeException(e); } diff --git a/bundle-plugin-test/pom.xml b/bundle-plugin-test/pom.xml index 5ae5496b1b0..53be71352c8 100644 --- a/bundle-plugin-test/pom.xml +++ b/bundle-plugin-test/pom.xml @@ -48,6 +48,14 @@ <artifactId>scala-library</artifactId> <scope>provided</scope> </dependency> + + <dependency> + <!-- Added to verify that module-info.class can be handled by bundle-plugin without throwing an exception. --> + <groupId>javax.xml.bind</groupId> + <artifactId>jaxb-api</artifactId> + <version>2.3.0</version> + </dependency> + </dependencies> <build> <plugins> diff --git a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.scala b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.scala index 903ad94e9e8..539684f2024 100644 --- a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.scala +++ b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.scala @@ -9,7 +9,7 @@ import collection.mutable * Picks up classes used in class files. * @author tonytv */ -private class AnalyzeClassVisitor extends ClassVisitor(Opcodes.ASM5) with AnnotationVisitorTrait with AttributeVisitorTrait { +private class AnalyzeClassVisitor extends ClassVisitor(Opcodes.ASM6) with AnnotationVisitorTrait with AttributeVisitorTrait { private var name : String = null protected val imports : ImportsSet = mutable.Set() protected var exportPackageAnnotation: Option[ExportPackageAnnotation] = None @@ -32,7 +32,7 @@ private class AnalyzeClassVisitor extends ClassVisitor(Opcodes.ASM5) with Annota imports ++= getClassName(Type.getType(desc)).toList AnalyzeSignatureVisitor.analyzeField(signature, this) - new FieldVisitor(Opcodes.ASM5) with SubVisitorTrait with AttributeVisitorTrait with AnnotationVisitorTrait { + new FieldVisitor(Opcodes.ASM6) with SubVisitorTrait with AttributeVisitorTrait with AnnotationVisitorTrait { val analyzeClassVisitor = AnalyzeClassVisitor.this override def visitAnnotation(desc: String, visible: Boolean): AnnotationVisitor = super.visitAnnotation(desc, visible) @@ -68,7 +68,7 @@ private class AnalyzeClassVisitor extends ClassVisitor(Opcodes.ASM5) with Annota def visitExportPackage(): AnnotationVisitor = { def defaultVersionValue[T](name: String) = classOf[Version].getMethod(name).getDefaultValue().asInstanceOf[T] - new AnnotationVisitor(Opcodes.ASM5) { + new AnnotationVisitor(Opcodes.ASM6) { var major: Int = defaultVersionValue("major") var minor: Int = defaultVersionValue("minor") var micro: Int = defaultVersionValue("micro") diff --git a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeMethodVisitor.scala b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeMethodVisitor.scala index 535ee2832c8..a8032b6a912 100644 --- a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeMethodVisitor.scala +++ b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeMethodVisitor.scala @@ -8,7 +8,7 @@ import org.objectweb.asm._ * @author tonytv */ private class AnalyzeMethodVisitor(val analyzeClassVisitor : AnalyzeClassVisitor) - extends MethodVisitor(Opcodes.ASM5) with AnnotationVisitorTrait with AttributeVisitorTrait with SubVisitorTrait { + extends MethodVisitor(Opcodes.ASM6) with AnnotationVisitorTrait with AttributeVisitorTrait with SubVisitorTrait { override def visitParameterAnnotation(parameter: Int, desc: String, visible: Boolean): AnnotationVisitor = super.visitParameterAnnotation(parameter, desc, visible) diff --git a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeSignatureVisitor.scala b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeSignatureVisitor.scala index 58a43b04d20..5bb8304cf1e 100644 --- a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeSignatureVisitor.scala +++ b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnalyzeSignatureVisitor.scala @@ -10,7 +10,7 @@ import org.objectweb.asm.signature.{SignatureReader, SignatureVisitor} */ private class AnalyzeSignatureVisitor(val analyzeClassVisitor: AnalyzeClassVisitor) - extends SignatureVisitor(Opcodes.ASM5) + extends SignatureVisitor(Opcodes.ASM6) with SubVisitorTrait { diff --git a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnnotationVisitorTrait.scala b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnnotationVisitorTrait.scala index 0ceaced1440..0bf6ee4a6b4 100644 --- a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnnotationVisitorTrait.scala +++ b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/AnnotationVisitorTrait.scala @@ -17,7 +17,7 @@ private trait AnnotationVisitorTrait { } def visitAnnotationDefault(): AnnotationVisitor = - new AnnotationVisitor(Opcodes.ASM5) { + new AnnotationVisitor(Opcodes.ASM6) { override def visit(name: String, value: AnyRef) {} override def visitEnum(name: String, desc: String, value: String) { diff --git a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/package.scala b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/package.scala index d217f720d1a..631884c58e3 100644 --- a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/package.scala +++ b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/classanalysis/package.scala @@ -8,7 +8,10 @@ package object classanalysis { type ImportsSet = mutable.Set[String] def internalNameToClassName(internalClassName: String) : Option[String] = { - getClassName(Type.getObjectType(internalClassName)) + internalClassName match { + case null => None + case _ => getClassName(Type.getObjectType(internalClassName)) + } } def getClassName(aType: Type): Option[String] = { diff --git a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/mojo/GenerateOsgiManifestMojo.scala b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/mojo/GenerateOsgiManifestMojo.scala index d66edf88702..67ce45ed7c6 100644 --- a/bundle-plugin/src/main/scala/com/yahoo/container/plugin/mojo/GenerateOsgiManifestMojo.scala +++ b/bundle-plugin/src/main/scala/com/yahoo/container/plugin/mojo/GenerateOsgiManifestMojo.scala @@ -210,7 +210,7 @@ class GenerateOsgiManifestMojo extends AbstractMojo { private def analyzeProjectClasses() : PackageTally = { val outputDirectory = new File(project.getBuild.getOutputDirectory) - val analyzedClasses = allDescendantFiles(outputDirectory).filter(_.getName.endsWith(".class")). + val analyzedClasses = allDescendantFiles(outputDirectory).filter(file => isClassToAnalyze(file.getName)). map(Analyze.analyzeClass) PackageTally.fromAnalyzedClassFiles(analyzedClasses) @@ -230,7 +230,7 @@ class GenerateOsgiManifestMojo extends AbstractMojo { for { entry <- toStream(jarFile.entries()) if !entry.isDirectory - if entry.getName.endsWith(".class") + if isClassToAnalyze(entry.getName) metaData = analyzeClass(jarFile, entry) } yield metaData @@ -278,6 +278,9 @@ object GenerateOsgiManifestMojo { } } + def isClassToAnalyze(name: String): Boolean = + name.endsWith(".class") && ! name.endsWith("module-info.class") + def emptyToNone(str: String) = Option(str) map {_.trim} filterNot {_.isEmpty} } diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ValidationOverrides.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ValidationOverrides.java index 11f9add6b25..441ef273a6f 100644 --- a/config-model-api/src/main/java/com/yahoo/config/application/api/ValidationOverrides.java +++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ValidationOverrides.java @@ -66,6 +66,11 @@ public class ValidationOverrides { return false; } + public static String toAllowMessage(ValidationId id) { + return "To allow this add <allow until='yyyy-mm-dd'>" + id + "</allow> to validation-overrides.xml" + + ", see https://docs.vespa.ai/documentation/reference/validation-overrides.html"; + } + /** Returns the XML form of this, or null if it was not created by fromXml, nor is empty */ public String xmlForm() { return xmlForm; } @@ -155,7 +160,9 @@ public class ValidationOverrides { /** Returns "validationId: message" */ @Override - public String getMessage() { return validationId + ": " + super.getMessage(); } + public String getMessage() { + return validationId + ": " + super.getMessage() + ". " + toAllowMessage(validationId); + } } diff --git a/config-model-fat/pom.xml b/config-model-fat/pom.xml index 3ef9925510c..649d8a37bf6 100644 --- a/config-model-fat/pom.xml +++ b/config-model-fat/pom.xml @@ -25,6 +25,13 @@ <artifactId>guava</artifactId> <version>13.0.1</version> </dependency> + <dependency> + <!-- TODO: can probably be removed. Added to get the same set of embedded deps with maven-bundle-plugin 3.5 as with 2.4. --> + <groupId>com.yahoo.vespa</groupId> + <artifactId>annotations</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> <groupId>com.yahoo.vespa</groupId> @@ -114,8 +121,6 @@ <plugin> <groupId>org.apache.felix</groupId> <artifactId>maven-bundle-plugin</artifactId> - <!-- version >= 2.5.0 causes java.lang.ArrayIndexOutOfBoundsException: 176 --> - <version>2.4.0</version> <extensions>true</extensions> <configuration> <instructions> diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java new file mode 100644 index 00000000000..effa261be3b --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java @@ -0,0 +1,674 @@ +package com.yahoo.searchdefinition.expressiontransforms; + +import com.google.common.base.Joiner; +import com.yahoo.collections.Pair; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.application.provider.FilesApplicationPackage; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.FeatureNames; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.tensor.serialization.TypedBinaryFormat; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.StringReader; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Base class for replacing instances of a pseudofeature for imported ML + * ranking models with native Vespa ranking expressions. + * + * @author bratseth + * @author lesters + */ +abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { + + ExpressionNode transformFromImportedModel(ImportedModel model, + ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles) { + // Add constants + Set<String> constantsReplacedByMacros = new HashSet<>(); + model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); + model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, + constantsReplacedByMacros, k, v)); + + // Find the specified expression + ImportedModel.Signature signature = chooseSignature(model, store.arguments().signature()); + String output = chooseOutput(signature, store.arguments().output()); + if (signature.skippedOutputs().containsKey(output)) { + String message = "Could not import model output '" + output + "'"; + if (!signature.skippedOutputs().get(output).isEmpty()) { + message += ": " + signature.skippedOutputs().get(output); + } + if (!signature.importWarnings().isEmpty()) { + message += ": " + String.join(", ", signature.importWarnings()); + } + throw new IllegalArgumentException(message); + } + + RankingExpression expression = model.expressions().get(output); + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); + verifyRequiredMacros(expression, model, profile, queryProfiles); + addGeneratedMacros(model, profile); + reduceBatchDimensions(expression, model, profile, queryProfiles); + + model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v)); + + store.writeConverted(expression); + return expression.getRoot(); + } + + ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { + for (Pair<String, Tensor> constant : store.readSmallConstants()) + profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); + + for (RankingConstant constant : store.readLargeConstants()) { + if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName())) + profile.getSearch().addRankingConstant(constant); + } + + for (Pair<String, RankingExpression> macro : store.readMacros()) { + addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); + } + + return store.readConverted().getRoot(); + } + + /** + * Returns the specified, existing signature, or the only signature if none is specified. + * Throws IllegalArgumentException in all other cases. + */ + private ImportedModel.Signature chooseSignature(ImportedModel importResult, Optional<String> signatureName) { + if ( ! signatureName.isPresent()) { + if (importResult.signatures().size() == 0) + throw new IllegalArgumentException("No signatures are available"); + if (importResult.signatures().size() > 1) + throw new IllegalArgumentException("Model has multiple signatures (" + + Joiner.on(", ").join(importResult.signatures().keySet()) + + "), one must be specified " + + "as a second argument to tensorflow()"); + return importResult.signatures().values().stream().findFirst().get(); + } + else { + ImportedModel.Signature signature = importResult.signatures().get(signatureName.get()); + if (signature == null) + throw new IllegalArgumentException("Model does not have the specified signature '" + + signatureName.get() + "'"); + return signature; + } + } + + /** + * Returns the specified, existing output expression, or the only output expression if no output name is specified. + * Throws IllegalArgumentException in all other cases. + */ + private String chooseOutput(ImportedModel.Signature signature, Optional<String> outputName) { + if ( ! outputName.isPresent()) { + if (signature.outputs().size() == 0) + throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature)); + if (signature.outputs().size() > 1) + throw new IllegalArgumentException(signature + " has multiple outputs (" + + Joiner.on(", ").join(signature.outputs().keySet()) + + "), one must be specified " + + "as a third argument to tensorflow()"); + return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get()); + } + else { + String output = signature.outputs().get(outputName.get()); + if (output == null) { + if (signature.skippedOutputs().containsKey(outputName.get())) + throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " + + signature.skippedOutputs().get(outputName.get())); + else + throw new IllegalArgumentException("Model does not have the specified output '" + + outputName.get() + "'"); + } + return output; + } + } + + private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + store.writeSmallConstant(constantName, constantValue); + profile.addConstant(constantName, asValue(constantValue)); + } + + private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, + Set<String> constantsReplacedByMacros, + String constantName, Tensor constantValue) { + RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); + if (macroOverridingConstant != null) { + TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); + if ( ! macroType.equals(constantValue.type())) + throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + + typeMismatchExplanation(constantValue.type(), macroType)); + constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later + } + else { + Path constantPath = store.writeLargeConstant(constantName, constantValue); + if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { + profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), + constantPath.toString())); + } + } + } + + private void transformGeneratedMacro(ModelStore store, + Set<String> constantsReplacedByMacros, + String macroName, RankingExpression expression) { + + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); + store.writeMacro(macroName, expression); + } + + private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { + if (profile.getMacros().containsKey(macroName)) { + throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists."); + } + profile.addMacro(macroName, false); // todo: inline if only used once + RankProfile.Macro macro = profile.getMacros().get(macroName); + macro.setRankingExpression(expression); + macro.setTextualExpression(expression.getRoot().toString()); + } + + private String skippedOutputsDescription(ImportedModel.Signature signature) { + if (signature.skippedOutputs().isEmpty()) return ""; + StringBuilder b = new StringBuilder(": "); + signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); + return b.toString(); + } + + /** + * Verify that the macros referred in the given expression exists in the given rank profile, + * and return tensors of the types specified in requiredMacros. + */ + private void verifyRequiredMacros(RankingExpression expression, ImportedModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { + Set<String> macroNames = new HashSet<>(); + addMacroNamesIn(expression.getRoot(), macroNames, model); + for (String macroName : macroNames) { + TensorType requiredType = model.requiredMacros().get(macroName); + if (requiredType == null) continue; // Not a required macro + + RankProfile.Macro macro = profile.getMacros().get(macroName); + if (macro == null) + throw new IllegalArgumentException("Model refers input '" + macroName + + "' of type " + requiredType + " but this macro is not present in " + + profile); + // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second + // phase and summary features), as it may only resolve correctly given those bindings + // Or, probably better, annotate the macros with type constraints here and verify during general + // type verification + TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles)); + if ( actualType == null) + throw new IllegalArgumentException("Model refers input '" + macroName + + "' of type " + requiredType + + " which must be produced by a macro in the rank profile, but " + + "this macro references a feature which is not declared"); + if ( ! actualType.isAssignableTo(requiredType)) + throw new IllegalArgumentException("Model refers input '" + macroName + "'. " + + typeMismatchExplanation(requiredType, actualType)); + } + } + + private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) { + return "The required type of this is " + requiredType + ", but this macro returns " + actualType + + (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " + + "in query profile types - see the documentation." + : ""); + } + + /** + * Add the generated macros to the rank profile + */ + private void addGeneratedMacros(ImportedModel model, RankProfile profile) { + model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v)); + } + + /** + * Check if batch dimensions of inputs can be reduced out. If the input + * macro specifies that a single exemplar should be evaluated, we can + * reduce the batch dimension out. + */ + private void reduceBatchDimensions(RankingExpression expression, ImportedModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { + TypeContext<Reference> typeContext = profile.typeContext(queryProfiles); + TensorType typeBeforeReducing = expression.getRoot().type(typeContext); + + // Check generated macros for inputs to reduce + Set<String> macroNames = new HashSet<>(); + addMacroNamesIn(expression.getRoot(), macroNames, model); + for (String macroName : macroNames) { + if ( ! model.macros().containsKey(macroName)) { + continue; + } + RankProfile.Macro macro = profile.getMacros().get(macroName); + if (macro == null) { + throw new IllegalArgumentException("Model refers to generated macro '" + macroName + + "but this macro is not present in " + profile); + } + RankingExpression macroExpression = macro.getRankingExpression(); + macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext)); + } + + // Check expression for inputs to reduce + ExpressionNode root = expression.getRoot(); + root = reduceBatchDimensionsAtInput(root, model, typeContext); + TensorType typeAfterReducing = root.type(typeContext); + root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); + expression.setRoot(root); + } + + private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model, + TypeContext<Reference> typeContext) { + if (node instanceof TensorFunctionNode) { + TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); + if (tensorFunction instanceof Rename) { + List<ExpressionNode> children = ((TensorFunctionNode)node).children(); + if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) children.get(0); + if (model.requiredMacros().containsKey(referenceNode.getName())) { + return reduceBatchDimensionExpression(tensorFunction, typeContext); + } + } + } + } + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) node; + if (model.requiredMacros().containsKey(referenceNode.getName())) { + return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); + } + } + if (node instanceof CompositeNode) { + List<ExpressionNode> children = ((CompositeNode)node).children(); + List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); + for (ExpressionNode child : children) { + transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); + } + return ((CompositeNode)node).setChildren(transformedChildren); + } + return node; + } + + private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) { + TensorFunction result = function; + TensorType type = function.type(context); + if (type.dimensions().size() > 1) { + List<String> reduceDimensions = new ArrayList<>(); + for (TensorType.Dimension dimension : type.dimensions()) { + if (dimension.size().orElse(-1L) == 1) { + reduceDimensions.add(dimension.name()); + } + } + if (reduceDimensions.size() > 0) { + result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); + } + } + return new TensorFunctionNode(result); + } + + /** + * If batch dimensions have been reduced away above, bring them back here + * for any following computation of the tensor. + * Todo: determine when this is not necessary! + */ + private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { + if (after.equals(before)) { + return node; + } + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (TensorType.Dimension dimension : before.dimensions()) { + if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { + typeBuilder.indexed(dimension.name(), 1); + } + } + TensorType expandDimensionsType = typeBuilder.build(); + if (expandDimensionsType.dimensions().size() > 0) { + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); + Generate generatedFunction = new Generate(expandDimensionsType, + new GeneratorLambdaFunctionNode(expandDimensionsType, + generatedExpression) + .asLongListToDoubleOperator()); + Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply()); + return new TensorFunctionNode(expand); + } + return node; + } + + /** + * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. + * This method does that for the given expression and returns the result. + */ + private RankingExpression replaceConstantsByMacros(RankingExpression expression, + Set<String> constantsReplacedByMacros) { + if (constantsReplacedByMacros.isEmpty()) return expression; + return new RankingExpression(expression.getName(), + replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); + } + + private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { + if (node instanceof ReferenceNode) { + Reference reference = ((ReferenceNode)node).reference(); + if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { + String argument = reference.simpleArgument().get(); + if (constantsReplacedByMacros.contains(argument)) + return new ReferenceNode(argument); + } + } + if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above + CompositeNode composite = (CompositeNode)node; + return composite.setChildren(composite.children().stream() + .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) + .collect(Collectors.toList())); + } + return node; + } + + private void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) { + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode)node; + if (referenceNode.getOutput() == null) { // macro references cannot specify outputs + names.add(referenceNode.getName()); + if (model.macros().containsKey(referenceNode.getName())) { + addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model); + } + } + } + else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode)node).children()) + addMacroNamesIn(child, names, model); + } + } + + private Value asValue(Tensor tensor) { + if (tensor.type().rank() == 0) + return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors + else + return new TensorValue(tensor); + } + + /** + * Provides read/write access to the correct directories of the application package given by the feature arguments + */ + static class ModelStore { + + private final ApplicationPackage application; + private final FeatureArguments arguments; + + ModelStore(ApplicationPackage application, FeatureArguments arguments) { + this.application = application; + this.arguments = arguments; + } + + public FeatureArguments arguments() { return arguments; } + + public boolean hasStoredModel() { + try { + return application.getFile(arguments.expressionPath()).exists(); + } + catch (UnsupportedOperationException e) { + return false; + } + } + + /** + * Returns the directory which contains the source model to use for these arguments + */ + public File modelDir() { + return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); + } + + /** + * Adds this expression to the application package, such that it can be read later. + */ + void writeConverted(RankingExpression expression) { + application.getFile(arguments.expressionPath()) + .writeFile(new StringReader(expression.getRoot().toString())); + } + + /** Reads the previously stored ranking expression for these arguments */ + RankingExpression readConverted() { + try { + return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); + } + catch (IOException e) { + throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e); + } + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + } + } + + /** Adds this macro expression to the application package to it can be read later. */ + void writeMacro(String name, RankingExpression expression) { + application.getFile(arguments.macrosPath()).appendFile(name + "\t" + + expression.getRoot().toString() + "\n"); + } + + /** Reads the previously stored macro expressions for these arguments */ + List<Pair<String, RankingExpression>> readMacros() { + try { + ApplicationFile file = application.getFile(arguments.macrosPath()); + if (!file.exists()) return Collections.emptyList(); + + List<Pair<String, RankingExpression>> macros = new ArrayList<>(); + BufferedReader reader = new BufferedReader(file.createReader()); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + try { + RankingExpression expression = new RankingExpression(parts[1]); + macros.add(new Pair<>(name, expression)); + } + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + } + } + return macros; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Reads the information about all the large (aka ranking) constants stored in the application package + * (the constant value itself is replicated with file distribution). + */ + List<RankingConstant> readLargeConstants() { + try { + List<RankingConstant> constants = new ArrayList<>(); + for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { + String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); + constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Adds this constant to the application package as a file, + * such that it can be distributed using file distribution. + * + * @return the path to the stored constant, relative to the application package root + */ + Path writeLargeConstant(String name, Tensor constant) { + Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); + + // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: + Path constantPath = constantsPath.append(name + ".tbf"); + + // Remember the constant in a file we replicate in ZooKeeper + application.getFile(arguments.largeConstantsPath().append(name + ".constant")) + .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); + + // Write content explicitly as a file on the file system as this is distributed using file distribution + createIfNeeded(constantsPath); + IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); + return correct(constantPath); + } + + private List<Pair<String, Tensor>> readSmallConstants() { + try { + ApplicationFile file = application.getFile(arguments.smallConstantsPath()); + if (!file.exists()) return Collections.emptyList(); + + List<Pair<String, Tensor>> constants = new ArrayList<>(); + BufferedReader reader = new BufferedReader(file.createReader()); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + TensorType type = TensorType.fromSpec(parts[1]); + Tensor tensor = Tensor.from(type, parts[2]); + constants.add(new Pair<>(name, tensor)); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Append this constant to the single file used for small constants distributed as config + */ + public void writeSmallConstant(String name, Tensor constant) { + // Secret file format for remembering constants: + application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + + constant.type().toString() + "\t" + + constant.toString() + "\n"); + } + + /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ + private Path correct(Path path) { + if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) + && ! path.elements().contains(FilesApplicationPackage.preprocessed)) { + return Path.fromString(FilesApplicationPackage.preprocessed).append(path); + } + else { + return path; + } + } + + private void createIfNeeded(Path path) { + File dir = application.getFileReference(path); + if ( ! dir.exists()) { + if (!dir.mkdirs()) + throw new IllegalStateException("Could not create " + dir); + } + } + + } + + /** Encapsulates the arguments to the import feature */ + static abstract class FeatureArguments { + + Path modelPath; + + /** Optional arguments */ + Optional<String> signature, output; + + /** Returns modelPath with slashes replaced by underscores */ + public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); } + + /** Returns relative path to this model below the "models/" dir in the application package */ + public Path modelPath() { return modelPath; } + public Optional<String> signature() { return signature; } + public Optional<String> output() { return output; } + + /** Path to the small constants file */ + public Path smallConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); + } + + /** Path to the large (ranking) constants directory */ + public Path largeConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); + } + + /** Path to the macros file */ + public Path macrosPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt"); + } + + public Path expressionPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR + .append(modelPath).append("expressions").append(expressionFileName()); + } + + private String expressionFileName() { + StringBuilder fileName = new StringBuilder(); + signature.ifPresent(s -> fileName.append(s).append(".")); + output.ifPresent(s -> fileName.append(s).append(".")); + if (fileName.length() == 0) // single signature and output + fileName.append("single."); + fileName.append("expression"); + return fileName.toString(); + } + + Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { + if (argumentIndex >= arguments.expressions().size()) + return Optional.empty(); + return Optional.of(asString(arguments.expressions().get(argumentIndex))); + } + + String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); + } + + private String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); + } + + private boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; + } + + } +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index 1c41ad8284e..44eeb364603 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -2,58 +2,20 @@ package com.yahoo.searchdefinition.expressiontransforms; -import com.google.common.base.Joiner; -import com.yahoo.collections.Pair; -import com.yahoo.config.application.api.ApplicationFile; -import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.model.application.provider.FilesApplicationPackage; -import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchdefinition.FeatureNames; import com.yahoo.searchdefinition.RankProfile; -import com.yahoo.searchdefinition.RankingConstant; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxImporter; -import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxModel; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.TypeContext; -import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Join; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.ScalarFunctions; -import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.tensor.serialization.TypedBinaryFormat; -import java.io.BufferedReader; -import java.io.File; -import java.io.IOException; -import java.io.StringReader; import java.io.UncheckedIOException; -import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; /** * Replaces instances of the onnx(model-path, output) @@ -63,12 +25,12 @@ import java.util.stream.Collectors; * @author bratseth * @author lesters */ -public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { +public class OnnxFeatureConverter extends MLImportFeatureConverter { private final OnnxImporter onnxImporter = new OnnxImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ - private final Map<Path, OnnxModel> importedModels = new HashMap<>(); + private final Map<Path, ImportedModel> importedModels = new HashMap<>(); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -84,7 +46,8 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans if ( ! feature.getName().equals("onnx")) return feature; try { - ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments()); + FeatureArguments arguments = new OnnxFeatureArguments(feature.getArguments()); + ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments); if ( ! store.hasStoredModel()) // not converted yet - access Onnx model files return transformFromOnnxModel(store, context.rankProfile(), context.queryProfiles()); else @@ -98,597 +61,24 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans private ExpressionNode transformFromOnnxModel(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles) { - OnnxModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), + ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), k -> onnxImporter.importModel(store.arguments().modelName(), - store.onnxModelDir())); - - // Add constants - Set<String> constantsReplacedByMacros = new HashSet<>(); - model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); - model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, - constantsReplacedByMacros, k, v)); - - // Find the specified expression - String output = chooseOutput(model, store.arguments().output()); - if (model.skippedOutputs().containsKey(output)) { - String message = "Could not import Onnx model output '" + output + "'"; - if (!model.skippedOutputs().get(output).isEmpty()) { - message += ": " + model.skippedOutputs().get(output); - } - if (!model.importWarnings().isEmpty()) { - message += ": " + String.join(", ", model.importWarnings()); - } - throw new IllegalArgumentException(message); - } - - RankingExpression expression = model.expressions().get(output); - expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); - verifyRequiredMacros(expression, model, profile, queryProfiles); - addGeneratedMacros(model, profile); - reduceBatchDimensions(expression, model, profile, queryProfiles); - - model.macros().forEach((k, v) -> transformGeneratedMacro(store, profile, constantsReplacedByMacros, k, v)); - - store.writeConverted(expression); - return expression.getRoot(); - } - - private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { - for (Pair<String, Tensor> constant : store.readSmallConstants()) - profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); - - for (RankingConstant constant : store.readLargeConstants()) { - if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName())) - profile.getSearch().addRankingConstant(constant); - } - - for (Pair<String, RankingExpression> macro : store.readMacros()) { - addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); - } - - return store.readConverted().getRoot(); - } - - /** - * Returns the specified, existing output expression, or the only output expression if no output name is specified. - * Throws IllegalArgumentException in all other cases. - */ - private String chooseOutput(OnnxModel model, Optional<String> outputName) { - if ( ! outputName.isPresent()) { - if (model.outputs().size() == 0) - throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(model)); - if (model.outputs().size() > 1) - throw new IllegalArgumentException("Onnx model has multiple outputs (" + - Joiner.on(", ").join(model.outputs().keySet()) + - "), one must be specified " + - "as a second argument to onnx()"); - return model.outputs().get(model.outputs().keySet().stream().findFirst().get()); - } - else { - String output = model.outputs().get(outputName.get()); - if (output == null) { - if (model.skippedOutputs().containsKey(outputName.get())) - throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " + - model.skippedOutputs().get(outputName.get())); - else - throw new IllegalArgumentException("Model does not have the specified output '" + - outputName.get() + "'"); - } - return output; - } + store.modelDir())); + return transformFromImportedModel(model, store, profile, queryProfiles); } - private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { - store.writeSmallConstant(constantName, constantValue); - profile.addConstant(constantName, asValue(constantValue)); - } - - private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, - Set<String> constantsReplacedByMacros, - String constantName, Tensor constantValue) { - RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); - if (macroOverridingConstant != null) { - TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); - if ( ! macroType.equals(constantValue.type())) - throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + - "The required type of this is " + constantValue.type() + - ", but the macro returns " + macroType); - constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later - } - else { - Path constantPath = store.writeLargeConstant(constantName, constantValue); - if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { - profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), - constantPath.toString())); - } - } - } - - private void transformGeneratedMacro(ModelStore store, RankProfile profile, - Set<String> constantsReplacedByMacros, - String macroName, RankingExpression expression) { - - expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); - store.writeMacro(macroName, expression); - } - - private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { - if (profile.getMacros().containsKey(macroName)) { - throw new IllegalArgumentException("Generated Onnx macro '" + macroName + "' already exists."); - } - profile.addMacro(macroName, false); // todo: inline if only used once - RankProfile.Macro macro = profile.getMacros().get(macroName); - macro.setRankingExpression(expression); - macro.setTextualExpression(expression.getRoot().toString()); - } - - private String skippedOutputsDescription(OnnxModel model) { - if (model.skippedOutputs().isEmpty()) return ""; - StringBuilder b = new StringBuilder(": "); - model.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); - return b.toString(); - } - - /** - * Verify that the macros referred in the given expression exists in the given rank profile, - * and return tensors of the types specified in requiredMacros. - */ - private void verifyRequiredMacros(RankingExpression expression, OnnxModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { - Set<String> macroNames = new HashSet<>(); - addMacroNamesIn(expression.getRoot(), macroNames, model); - for (String macroName : macroNames) { - TensorType requiredType = model.requiredMacros().get(macroName); - if (requiredType == null) continue; // Not a required macro - - RankProfile.Macro macro = profile.getMacros().get(macroName); - if (macro == null) - throw new IllegalArgumentException("Model refers Placeholder '" + macroName + - "' of type " + requiredType + " but this macro is not present in " + - profile); - // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second - // phase and summary features), as it may only resolve correctly given those bindings - // Or, probably better, annotate the macros with type constraints here and verify during general - // type verification - TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles)); - if ( actualType == null) - throw new IllegalArgumentException("Model refers input '" + macroName + - "' of type " + requiredType + - " which must be produced by a macro in the rank profile, but " + - "this macro references a feature which is not declared"); - if ( ! actualType.isAssignableTo(requiredType)) - throw new IllegalArgumentException("Model refers input '" + macroName + - "' of type " + requiredType + - " which must be produced by a macro in the rank profile, but " + - "this macro produces type " + actualType); - } - } - - /** - * Add the generated macros to the rank profile - */ - private void addGeneratedMacros(OnnxModel model, RankProfile profile) { - model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v)); - } - - /** - * Check if batch dimensions of inputs can be reduced out. If the input - * macro specifies that a single exemplar should be evaluated, we can - * reduce the batch dimension out. - */ - private void reduceBatchDimensions(RankingExpression expression, OnnxModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { - TypeContext<Reference> typeContext = profile.typeContext(queryProfiles); - TensorType typeBeforeReducing = expression.getRoot().type(typeContext); - - // Check generated macros for inputs to reduce - Set<String> macroNames = new HashSet<>(); - addMacroNamesIn(expression.getRoot(), macroNames, model); - for (String macroName : macroNames) { - if ( ! model.macros().containsKey(macroName)) { - continue; - } - RankProfile.Macro macro = profile.getMacros().get(macroName); - if (macro == null) { - throw new IllegalArgumentException("Model refers to generated macro '" + macroName + - "but this macro is not present in " + profile); - } - RankingExpression macroExpression = macro.getRankingExpression(); - macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext)); - } - - // Check expression for inputs to reduce - ExpressionNode root = expression.getRoot(); - root = reduceBatchDimensionsAtInput(root, model, typeContext); - TensorType typeAfterReducing = root.type(typeContext); - root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); - expression.setRoot(root); - } - - private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, OnnxModel model, - TypeContext<Reference> typeContext) { - if (node instanceof TensorFunctionNode) { - TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); - if (tensorFunction instanceof Rename) { - List<ExpressionNode> children = ((TensorFunctionNode)node).children(); - if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.requiredMacros().containsKey(referenceNode.getName())) { - return reduceBatchDimensionExpression(tensorFunction, typeContext); - } - } - } - } - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) node; - if (model.requiredMacros().containsKey(referenceNode.getName())) { - return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); - } - } - if (node instanceof CompositeNode) { - List<ExpressionNode> children = ((CompositeNode)node).children(); - List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); - for (ExpressionNode child : children) { - transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); - } - return ((CompositeNode)node).setChildren(transformedChildren); - } - return node; - } - - private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) { - TensorFunction result = function; - TensorType type = function.type(context); - if (type.dimensions().size() > 1) { - List<String> reduceDimensions = new ArrayList<>(); - for (TensorType.Dimension dimension : type.dimensions()) { - if (dimension.size().orElse(-1L) == 1) { - reduceDimensions.add(dimension.name()); - } - } - if (reduceDimensions.size() > 0) { - result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); - } - } - return new TensorFunctionNode(result); - } - - /** - * If batch dimensions have been reduced away above, bring them back here - * for any following computation of the tensor. - * Todo: determine when this is not necessary! - */ - private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { - if (after.equals(before)) { - return node; - } - TensorType.Builder typeBuilder = new TensorType.Builder(); - for (TensorType.Dimension dimension : before.dimensions()) { - if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { - typeBuilder.indexed(dimension.name(), 1); - } - } - TensorType expandDimensionsType = typeBuilder.build(); - if (expandDimensionsType.dimensions().size() > 0) { - ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); - Generate generatedFunction = new Generate(expandDimensionsType, - new GeneratorLambdaFunctionNode(expandDimensionsType, - generatedExpression) - .asLongListToDoubleOperator()); - Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply()); - return new TensorFunctionNode(expand); - } - return node; - } - - /** - * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. - * This method does that for the given expression and returns the result. - */ - private RankingExpression replaceConstantsByMacros(RankingExpression expression, - Set<String> constantsReplacedByMacros) { - if (constantsReplacedByMacros.isEmpty()) return expression; - return new RankingExpression(expression.getName(), - replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); - } - - private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { - if (node instanceof ReferenceNode) { - Reference reference = ((ReferenceNode)node).reference(); - if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { - String argument = reference.simpleArgument().get(); - if (constantsReplacedByMacros.contains(argument)) - return new ReferenceNode(argument); - } - } - if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above - CompositeNode composite = (CompositeNode)node; - return composite.setChildren(composite.children().stream() - .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) - .collect(Collectors.toList())); - } - return node; - } - - private void addMacroNamesIn(ExpressionNode node, Set<String> names, OnnxModel model) { - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode)node; - if (referenceNode.getOutput() == null) { // macro references cannot specify outputs - names.add(referenceNode.getName()); - if (model.macros().containsKey(referenceNode.getName())) { - addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model); - } - } - } - else if (node instanceof CompositeNode) { - for (ExpressionNode child : ((CompositeNode)node).children()) - addMacroNamesIn(child, names, model); - } - } - - private Value asValue(Tensor tensor) { - if (tensor.type().rank() == 0) - return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors - else - return new TensorValue(tensor); - } - - /** - * Provides read/write access to the correct directories of the application package given by the feature arguments - */ - private static class ModelStore { - - private final ApplicationPackage application; - private final FeatureArguments arguments; - - public ModelStore(ApplicationPackage application, Arguments arguments) { - this.application = application; - this.arguments = new FeatureArguments(arguments); - } - - public FeatureArguments arguments() { return arguments; } - - public boolean hasStoredModel() { - try { - return application.getFile(arguments.expressionPath()).exists(); - } - catch (UnsupportedOperationException e) { - return false; - } - } - - /** - * Returns the directory which contains the source model to use for these arguments - */ - public File onnxModelDir() { - return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); - } - - /** - * Adds this expression to the application package, such that it can be read later. - */ - public void writeConverted(RankingExpression expression) { - application.getFile(arguments.expressionPath()) - .writeFile(new StringReader(expression.getRoot().toString())); - } - - /** Reads the previously stored ranking expression for these arguments */ - public RankingExpression readConverted() { - try { - return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); - } - catch (IOException e) { - throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); - } - } - - /** Adds this macro expression to the application package to it can be read later. */ - public void writeMacro(String name, RankingExpression expression) { - application.getFile(arguments.macrosPath()).appendFile(name + "\t" + - expression.getRoot().toString() + "\n"); - } - - /** Reads the previously stored macro expressions for these arguments */ - public List<Pair<String, RankingExpression>> readMacros() { - try { - ApplicationFile file = application.getFile(arguments.macrosPath()); - if (!file.exists()) return Collections.emptyList(); - - List<Pair<String, RankingExpression>> macros = new ArrayList<>(); - BufferedReader reader = new BufferedReader(file.createReader()); - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String name = parts[0]; - try { - RankingExpression expression = new RankingExpression(parts[1]); - macros.add(new Pair<>(name, expression)); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); - } - } - return macros; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Reads the information about all the large (aka ranking) constants stored in the application package - * (the constant value itself is replicated with file distribution). - */ - public List<RankingConstant> readLargeConstants() { - try { - List<RankingConstant> constants = new ArrayList<>(); - for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { - String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); - constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); - } - return constants; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Adds this constant to the application package as a file, - * such that it can be distributed using file distribution. - * - * @return the path to the stored constant, relative to the application package root - */ - public Path writeLargeConstant(String name, Tensor constant) { - Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); - - // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: - Path constantPath = constantsPath.append(name + ".tbf"); - - // Remember the constant in a file we replicate in ZooKeeper - application.getFile(arguments.largeConstantsPath().append(name + ".constant")) - .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); - - // Write content explicitly as a file on the file system as this is distributed using file distribution - createIfNeeded(constantsPath); - IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); - return correct(constantPath); - } - - private List<Pair<String, Tensor>> readSmallConstants() { - try { - ApplicationFile file = application.getFile(arguments.smallConstantsPath()); - if (!file.exists()) return Collections.emptyList(); - - List<Pair<String, Tensor>> constants = new ArrayList<>(); - BufferedReader reader = new BufferedReader(file.createReader()); - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String name = parts[0]; - TensorType type = TensorType.fromSpec(parts[1]); - Tensor tensor = Tensor.from(type, parts[2]); - constants.add(new Pair<>(name, tensor)); - } - return constants; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Append this constant to the single file used for small constants distributed as config - */ - public void writeSmallConstant(String name, Tensor constant) { - // Secret file format for remembering constants: - application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + - constant.type().toString() + "\t" + - constant.toString() + "\n"); - } - - /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ - private Path correct(Path path) { - if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) - && ! path.elements().contains(FilesApplicationPackage.preprocessed)) { - return Path.fromString(FilesApplicationPackage.preprocessed).append(path); - } - else { - return path; - } - } - - private void createIfNeeded(Path path) { - File dir = application.getFileReference(path); - if ( ! dir.exists()) { - if (!dir.mkdirs()) - throw new IllegalStateException("Could not create " + dir); - } - } - - } - - /** Encapsulates the 1, 2 or 3 arguments to a onnx feature */ - private static class FeatureArguments { - - private final Path modelPath; - - /** Optional arguments */ - private final Optional<String> output; - - public FeatureArguments(Arguments arguments) { + static class OnnxFeatureArguments extends FeatureArguments { + public OnnxFeatureArguments(Arguments arguments) { if (arguments.isEmpty()) throw new IllegalArgumentException("An onnx node must take an argument pointing to " + - "the onnx model directory under [application]/models"); + "the tensorflow model directory under [application]/models"); if (arguments.expressions().size() > 3) throw new IllegalArgumentException("An onnx feature can have at most 2 arguments"); modelPath = Path.fromString(asString(arguments.expressions().get(0))); output = optionalArgument(1, arguments); + signature = Optional.of("default"); } - - /** Returns modelPath with slashes replaced by underscores */ - public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); } - - /** Returns relative path to this model below the "models/" dir in the application package */ - public Path modelPath() { return modelPath; } - public Optional<String> output() { return output; } - - /** Path to the small constants file */ - public Path smallConstantsPath() { - return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); - } - - /** Path to the large (ranking) constants directory */ - public Path largeConstantsPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); - } - - /** Path to the macros file */ - public Path macrosPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt"); - } - - public Path expressionPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR - .append(modelPath).append("expressions").append(expressionFileName()); - } - - private String expressionFileName() { - StringBuilder fileName = new StringBuilder(); - output.ifPresent(s -> fileName.append(s).append(".")); - if (fileName.length() == 0) // single signature and output - fileName.append("single."); - fileName.append("expression"); - return fileName.toString(); - } - - private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { - if (argumentIndex >= arguments.expressions().size()) - return Optional.empty(); - return Optional.of(asString(arguments.expressions().get(argumentIndex))); - } - - private String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as onnx argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); - } - - private String stripQuotes(String s) { - if ( ! isQuoteSign(s.codePointAt(0))) return s; - if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) - throw new IllegalArgumentException("onnx argument [" + s + "] is missing endquote"); - return s.substring(1, s.length()-1); - } - - private boolean isQuoteSign(int c) { - return c == '\'' || c == '"'; - } - } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 41da32f64c3..27e1ad51b33 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -1,59 +1,19 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; -import com.google.common.base.Joiner; -import com.yahoo.collections.Pair; -import com.yahoo.config.application.api.ApplicationFile; -import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.model.application.provider.FilesApplicationPackage; -import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchdefinition.FeatureNames; import com.yahoo.searchdefinition.RankProfile; -import com.yahoo.searchdefinition.RankingConstant; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel.Signature; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.TypeContext; -import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Join; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.ScalarFunctions; -import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.tensor.serialization.TypedBinaryFormat; -import java.io.BufferedReader; -import java.io.File; -import java.io.IOException; -import java.io.StringReader; import java.io.UncheckedIOException; -import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; -import java.util.List; import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; /** * Replaces instances of the tensorflow(model-path, signature, output) @@ -62,12 +22,12 @@ import java.util.stream.Collectors; * * @author bratseth */ -public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { +public class TensorFlowFeatureConverter extends MLImportFeatureConverter { private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ - private final Map<Path, TensorFlowModel> importedModels = new HashMap<>(); + private final Map<Path, ImportedModel> importedModels = new HashMap<>(); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -83,7 +43,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil if ( ! feature.getName().equals("tensorflow")) return feature; try { - ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments()); + FeatureArguments arguments = new TensorFlowFeatureArguments(feature.getArguments()); + ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments); if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles()); else @@ -95,565 +56,19 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } private ExpressionNode transformFromTensorFlowModel(ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles) { - TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), - k -> tensorFlowImporter.importModel(store.arguments().modelName(), - store.tensorFlowModelDir())); - - // Add constants - Set<String> constantsReplacedByMacros = new HashSet<>(); - model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); - model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, - constantsReplacedByMacros, k, v)); - - // Find the specified expression - Signature signature = chooseSignature(model, store.arguments().signature()); - String output = chooseOutput(signature, store.arguments().output()); - if (signature.skippedOutputs().containsKey(output)) { - String message = "Could not import TensorFlow model output '" + output + "'"; - if (!signature.skippedOutputs().get(output).isEmpty()) { - message += ": " + signature.skippedOutputs().get(output); - } - if (!signature.importWarnings().isEmpty()) { - message += ": " + String.join(", ", signature.importWarnings()); - } - throw new IllegalArgumentException(message); - } - - RankingExpression expression = model.expressions().get(output); - expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); - verifyRequiredMacros(expression, model, profile, queryProfiles); - addGeneratedMacros(model, profile); - reduceBatchDimensions(expression, model, profile, queryProfiles); - - model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v)); - - store.writeConverted(expression); - return expression.getRoot(); - } - - private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { - for (Pair<String, Tensor> constant : store.readSmallConstants()) - profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); - - for (RankingConstant constant : store.readLargeConstants()) { - if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName())) - profile.getSearch().addRankingConstant(constant); - } - - for (Pair<String, RankingExpression> macro : store.readMacros()) { - addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); - } - - return store.readConverted().getRoot(); - } - - /** - * Returns the specified, existing signature, or the only signature if none is specified. - * Throws IllegalArgumentException in all other cases. - */ - private Signature chooseSignature(TensorFlowModel importResult, Optional<String> signatureName) { - if ( ! signatureName.isPresent()) { - if (importResult.signatures().size() == 0) - throw new IllegalArgumentException("No signatures are available"); - if (importResult.signatures().size() > 1) - throw new IllegalArgumentException("Model has multiple signatures (" + - Joiner.on(", ").join(importResult.signatures().keySet()) + - "), one must be specified " + - "as a second argument to tensorflow()"); - return importResult.signatures().values().stream().findFirst().get(); - } - else { - Signature signature = importResult.signatures().get(signatureName.get()); - if (signature == null) - throw new IllegalArgumentException("Model does not have the specified signature '" + - signatureName.get() + "'"); - return signature; - } - } - - /** - * Returns the specified, existing output expression, or the only output expression if no output name is specified. - * Throws IllegalArgumentException in all other cases. - */ - private String chooseOutput(Signature signature, Optional<String> outputName) { - if ( ! outputName.isPresent()) { - if (signature.outputs().size() == 0) - throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature)); - if (signature.outputs().size() > 1) - throw new IllegalArgumentException(signature + " has multiple outputs (" + - Joiner.on(", ").join(signature.outputs().keySet()) + - "), one must be specified " + - "as a third argument to tensorflow()"); - return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get()); - } - else { - String output = signature.outputs().get(outputName.get()); - if (output == null) { - if (signature.skippedOutputs().containsKey(outputName.get())) - throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " + - signature.skippedOutputs().get(outputName.get())); - else - throw new IllegalArgumentException("Model does not have the specified output '" + - outputName.get() + "'"); - } - return output; - } - } - - private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { - store.writeSmallConstant(constantName, constantValue); - profile.addConstant(constantName, asValue(constantValue)); - } - - private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, - Set<String> constantsReplacedByMacros, - String constantName, Tensor constantValue) { - RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); - if (macroOverridingConstant != null) { - TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); - if ( ! macroType.equals(constantValue.type())) - throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + - typeMismatchExplanation(constantValue.type(), macroType)); - constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later - } - else { - Path constantPath = store.writeLargeConstant(constantName, constantValue); - if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { - profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), - constantPath.toString())); - } - } - } - - private void transformGeneratedMacro(ModelStore store, - Set<String> constantsReplacedByMacros, - String macroName, RankingExpression expression) { - - expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); - store.writeMacro(macroName, expression); - } - - private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { - if (profile.getMacros().containsKey(macroName)) { - throw new IllegalArgumentException("Generated TensorFlow macro '" + macroName + "' already exists."); - } - profile.addMacro(macroName, false); // todo: inline if only used once - RankProfile.Macro macro = profile.getMacros().get(macroName); - macro.setRankingExpression(expression); - macro.setTextualExpression(expression.getRoot().toString()); - } - - private String skippedOutputsDescription(TensorFlowModel.Signature signature) { - if (signature.skippedOutputs().isEmpty()) return ""; - StringBuilder b = new StringBuilder(": "); - signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); - return b.toString(); + RankProfile profile, + QueryProfileRegistry queryProfiles) { + ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), + k -> tensorFlowImporter.importModel(store.arguments().modelName(), + store.modelDir())); + return transformFromImportedModel(model, store, profile, queryProfiles); } - /** - * Verify that the macros referred in the given expression exists in the given rank profile, - * and return tensors of the types specified in requiredMacros. - */ - private void verifyRequiredMacros(RankingExpression expression, TensorFlowModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { - Set<String> macroNames = new HashSet<>(); - addMacroNamesIn(expression.getRoot(), macroNames, model); - for (String macroName : macroNames) { - TensorType requiredType = model.requiredMacros().get(macroName); - if (requiredType == null) continue; // Not a required macro - - RankProfile.Macro macro = profile.getMacros().get(macroName); - if (macro == null) - throw new IllegalArgumentException("Model refers placeholder '" + macroName + - "' of type " + requiredType + " but this macro is not present in " + - profile); - // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second - // phase and summary features), as it may only resolve correctly given those bindings - // Or, probably better, annotate the macros with type constraints here and verify during general - // type verification - TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles)); - if ( actualType == null) - throw new IllegalArgumentException("Model refers placeholder '" + macroName + - "' of type " + requiredType + - " which must be produced by a macro in the rank profile, but " + - "this macro references a feature which is not declared"); - if ( ! actualType.isAssignableTo(requiredType)) - throw new IllegalArgumentException("Model refers placeholder '" + macroName + "'. " + - typeMismatchExplanation(requiredType, actualType)); - } - } - - private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) { - return "The required type of this is " + requiredType + ", but this macro returns " + actualType + - (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " + - "in query profile types - see the documentation." - : ""); - } - - /** - * Add the generated macros to the rank profile - */ - private void addGeneratedMacros(TensorFlowModel model, RankProfile profile) { - model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v)); - } - - /** - * Check if batch dimensions of inputs can be reduced out. If the input - * macro specifies that a single exemplar should be evaluated, we can - * reduce the batch dimension out. - */ - private void reduceBatchDimensions(RankingExpression expression, TensorFlowModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { - TypeContext<Reference> typeContext = profile.typeContext(queryProfiles); - TensorType typeBeforeReducing = expression.getRoot().type(typeContext); - - // Check generated macros for inputs to reduce - Set<String> macroNames = new HashSet<>(); - addMacroNamesIn(expression.getRoot(), macroNames, model); - for (String macroName : macroNames) { - if ( ! model.macros().containsKey(macroName)) { - continue; - } - RankProfile.Macro macro = profile.getMacros().get(macroName); - if (macro == null) { - throw new IllegalArgumentException("Model refers to generated macro '" + macroName + - "but this macro is not present in " + profile); - } - RankingExpression macroExpression = macro.getRankingExpression(); - macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext)); - } - - // Check expression for inputs to reduce - ExpressionNode root = expression.getRoot(); - root = reduceBatchDimensionsAtInput(root, model, typeContext); - TensorType typeAfterReducing = root.type(typeContext); - root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); - expression.setRoot(root); - } - - private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, TensorFlowModel model, - TypeContext<Reference> typeContext) { - if (node instanceof TensorFunctionNode) { - TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); - if (tensorFunction instanceof Rename) { - List<ExpressionNode> children = ((TensorFunctionNode)node).children(); - if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.requiredMacros().containsKey(referenceNode.getName())) { - return reduceBatchDimensionExpression(tensorFunction, typeContext); - } - } - } - } - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) node; - if (model.requiredMacros().containsKey(referenceNode.getName())) { - return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); - } - } - if (node instanceof CompositeNode) { - List<ExpressionNode> children = ((CompositeNode)node).children(); - List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); - for (ExpressionNode child : children) { - transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); - } - return ((CompositeNode)node).setChildren(transformedChildren); - } - return node; - } - - private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) { - TensorFunction result = function; - TensorType type = function.type(context); - if (type.dimensions().size() > 1) { - List<String> reduceDimensions = new ArrayList<>(); - for (TensorType.Dimension dimension : type.dimensions()) { - if (dimension.size().orElse(-1L) == 1) { - reduceDimensions.add(dimension.name()); - } - } - if (reduceDimensions.size() > 0) { - result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); - } - } - return new TensorFunctionNode(result); - } - - /** - * If batch dimensions have been reduced away above, bring them back here - * for any following computation of the tensor. - * Todo: determine when this is not necessary! - */ - private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { - if (after.equals(before)) { - return node; - } - TensorType.Builder typeBuilder = new TensorType.Builder(); - for (TensorType.Dimension dimension : before.dimensions()) { - if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { - typeBuilder.indexed(dimension.name(), 1); - } - } - TensorType expandDimensionsType = typeBuilder.build(); - if (expandDimensionsType.dimensions().size() > 0) { - ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); - Generate generatedFunction = new Generate(expandDimensionsType, - new GeneratorLambdaFunctionNode(expandDimensionsType, - generatedExpression) - .asLongListToDoubleOperator()); - Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply()); - return new TensorFunctionNode(expand); - } - return node; - } - - /** - * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. - * This method does that for the given expression and returns the result. - */ - private RankingExpression replaceConstantsByMacros(RankingExpression expression, - Set<String> constantsReplacedByMacros) { - if (constantsReplacedByMacros.isEmpty()) return expression; - return new RankingExpression(expression.getName(), - replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); - } - - private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { - if (node instanceof ReferenceNode) { - Reference reference = ((ReferenceNode)node).reference(); - if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { - String argument = reference.simpleArgument().get(); - if (constantsReplacedByMacros.contains(argument)) - return new ReferenceNode(argument); - } - } - if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above - CompositeNode composite = (CompositeNode)node; - return composite.setChildren(composite.children().stream() - .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) - .collect(Collectors.toList())); - } - return node; - } - - private void addMacroNamesIn(ExpressionNode node, Set<String> names, TensorFlowModel model) { - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode)node; - if (referenceNode.getOutput() == null) { // macro references cannot specify outputs - names.add(referenceNode.getName()); - if (model.macros().containsKey(referenceNode.getName())) { - addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model); - } - } - } - else if (node instanceof CompositeNode) { - for (ExpressionNode child : ((CompositeNode)node).children()) - addMacroNamesIn(child, names, model); - } - } - - private Value asValue(Tensor tensor) { - if (tensor.type().rank() == 0) - return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors - else - return new TensorValue(tensor); - } - - /** - * Provides read/write access to the correct directories of the application package given by the feature arguments - */ - private static class ModelStore { - - private final ApplicationPackage application; - private final FeatureArguments arguments; - - public ModelStore(ApplicationPackage application, Arguments arguments) { - this.application = application; - this.arguments = new FeatureArguments(arguments); - } - - - - public FeatureArguments arguments() { return arguments; } - - public boolean hasStoredModel() { - try { - return application.getFile(arguments.expressionPath()).exists(); - } - catch (UnsupportedOperationException e) { - return false; - } - } - - /** - * Returns the directory which (if hasTensorFlowModels is true) - * contains the source model to use for these arguments - */ - public File tensorFlowModelDir() { - return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); - } - - /** - * Adds this expression to the application package, such that it can be read later. - */ - public void writeConverted(RankingExpression expression) { - application.getFile(arguments.expressionPath()) - .writeFile(new StringReader(expression.getRoot().toString())); - } - - /** Reads the previously stored ranking expression for these arguments */ - public RankingExpression readConverted() { - try { - return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); - } - catch (IOException e) { - throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); - } - } - - /** Adds this macro expression to the application package to it can be read later. */ - public void writeMacro(String name, RankingExpression expression) { - application.getFile(arguments.macrosPath()).appendFile(name + "\t" + - expression.getRoot().toString() + "\n"); - } - - /** Reads the previously stored macro expressions for these arguments */ - public List<Pair<String, RankingExpression>> readMacros() { - try { - ApplicationFile file = application.getFile(arguments.macrosPath()); - if (!file.exists()) return Collections.emptyList(); - - List<Pair<String, RankingExpression>> macros = new ArrayList<>(); - BufferedReader reader = new BufferedReader(file.createReader()); - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String name = parts[0]; - try { - RankingExpression expression = new RankingExpression(parts[1]); - macros.add(new Pair<>(name, expression)); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); - } - } - return macros; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Reads the information about all the large (aka ranking) constants stored in the application package - * (the constant value itself is replicated with file distribution). - */ - public List<RankingConstant> readLargeConstants() { - try { - List<RankingConstant> constants = new ArrayList<>(); - for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { - String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); - constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); - } - return constants; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Adds this constant to the application package as a file, - * such that it can be distributed using file distribution. - * - * @return the path to the stored constant, relative to the application package root - */ - public Path writeLargeConstant(String name, Tensor constant) { - Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); - - // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: - Path constantPath = constantsPath.append(name + ".tbf"); - - // Remember the constant in a file we replicate in ZooKeeper - application.getFile(arguments.largeConstantsPath().append(name + ".constant")) - .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); - - // Write content explicitly as a file on the file system as this is distributed using file distribution - createIfNeeded(constantsPath); - IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); - return correct(constantPath); - } - - private List<Pair<String, Tensor>> readSmallConstants() { - try { - ApplicationFile file = application.getFile(arguments.smallConstantsPath()); - if (!file.exists()) return Collections.emptyList(); - - List<Pair<String, Tensor>> constants = new ArrayList<>(); - BufferedReader reader = new BufferedReader(file.createReader()); - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String name = parts[0]; - TensorType type = TensorType.fromSpec(parts[1]); - Tensor tensor = Tensor.from(type, parts[2]); - constants.add(new Pair<>(name, tensor)); - } - return constants; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Append this constant to the single file used for small constants distributed as config - */ - public void writeSmallConstant(String name, Tensor constant) { - // Secret file format for remembering constants: - application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + - constant.type().toString() + "\t" + - constant.toString() + "\n"); - } - - /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ - private Path correct(Path path) { - if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) - && ! path.elements().contains(FilesApplicationPackage.preprocessed)) { - return Path.fromString(FilesApplicationPackage.preprocessed).append(path); - } - else { - return path; - } - } - - private void createIfNeeded(Path path) { - File dir = application.getFileReference(path); - if ( ! dir.exists()) { - if (!dir.mkdirs()) - throw new IllegalStateException("Could not create " + dir); - } - } - - } - - /** Encapsulates the 1, 2 or 3 arguments to a tensorflow feature */ - private static class FeatureArguments { - - private final Path modelPath; - - /** Optional arguments */ - private final Optional<String> signature, output; - - public FeatureArguments(Arguments arguments) { + static class TensorFlowFeatureArguments extends FeatureArguments { + public TensorFlowFeatureArguments(Arguments arguments) { if (arguments.isEmpty()) throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + - "the tensorflow model directory under [application]/models"); + "the tensorflow model directory under [application]/models"); if (arguments.expressions().size() > 3) throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments"); @@ -661,68 +76,6 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil signature = optionalArgument(1, arguments); output = optionalArgument(2, arguments); } - - /** Returns modelPath with slashes replaced by underscores */ - public String modelName() { return modelPath.toString().replace('/', '_'); } - - /** Returns relative path to this model below the "models/" dir in the application package */ - public Path modelPath() { return modelPath; } - public Optional<String> signature() { return signature; } - public Optional<String> output() { return output; } - - /** Path to the small constants file */ - public Path smallConstantsPath() { - return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); - } - - /** Path to the large (ranking) constants directory */ - public Path largeConstantsPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); - } - - /** Path to the macros file */ - public Path macrosPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt"); - } - - public Path expressionPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR - .append(modelPath).append("expressions").append(expressionFileName()); - } - - private String expressionFileName() { - StringBuilder fileName = new StringBuilder(); - signature.ifPresent(s -> fileName.append(s).append(".")); - output.ifPresent(s -> fileName.append(s).append(".")); - if (fileName.length() == 0) // single signature and output - fileName.append("single."); - fileName.append("expression"); - return fileName.toString(); - } - - private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { - if (argumentIndex >= arguments.expressions().size()) - return Optional.empty(); - return Optional.of(asString(arguments.expressions().get(argumentIndex))); - } - - private String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); - } - - private String stripQuotes(String s) { - if ( ! isQuoteSign(s.codePointAt(0))) return s; - if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) - throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote"); - return s.substring(1, s.length()-1); - } - - private boolean isQuoteSign(int c) { - return c == '\'' || c == '"'; - } - } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/Host.java b/config-model/src/main/java/com/yahoo/vespa/model/Host.java index 0adfe9e4bdb..624a9fd4da7 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/Host.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/Host.java @@ -42,16 +42,14 @@ public final class Host extends AbstractConfigProducer<AbstractConfigProducer<?> private void checkName(HostSystem parent, String hostname) { // Give a warning if the host does not exist - // Host exists - warn if given hostname is not a fully qualified one. - String canonical = hostname; try { - canonical = parent.getCanonicalHostname(hostname); + Object address = java.net.InetAddress.getByName(hostname); } catch (UnknownHostException e) { - deployLogger().log(Level.WARNING, "Unable to find canonical hostname of host: " + hostname); + deployLogger().log(Level.WARNING, "Unable to lookup IP address of host: " + hostname); } - if ((null != canonical) && (! hostname.equals(canonical))) { + if (! hostname.contains(".")) { deployLogger().log(Level.WARNING, "Host named '" + hostname + "' may not receive any config " + - "since it does not match its canonical hostname: " + canonical); + "since it is not a canonical hostname"); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java index 6467199d9f9..fc46ed18dde 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java @@ -110,6 +110,13 @@ public class VespaMetricSet { metrics.add(new Metric("jdisc.memory_mappings.max")); metrics.add(new Metric("jdisc.open_file_descriptors.max")); + metrics.add(new Metric("jdisc.gc.count.average")); + metrics.add(new Metric("jdisc.gc.count.max")); + metrics.add(new Metric("jdisc.gc.count.last")); + metrics.add(new Metric("jdisc.gc.ms.average")); + metrics.add(new Metric("jdisc.gc.ms.max")); + metrics.add(new Metric("jdisc.gc.ms.last")); + metrics.add(new Metric("jdisc.deactivated_containers.total.last")); metrics.add(new Metric("jdisc.deactivated_containers.with_retained_refs.last")); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 1c54d12d8b3..d9beab6e2f2 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -37,15 +37,6 @@ public class RankingExpressionWithOnnxTestCase { } @Test - public void testOnnxReference() throws ParseException { - RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", - "onnx('mnist_softmax.onnx')"); - search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); - assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); - } - - @Test public void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "onnx('mnist_softmax.onnx')", @@ -122,13 +113,6 @@ public class RankingExpressionWithOnnxTestCase { } @Test - public void testOnnxReferenceSpecifyingOutput() { - RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", - "onnx('mnist_softmax.onnx', 'add')"); - search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - } - - @Test public void testOnnxReferenceMissingMacro() throws ParseException { try { RankProfileSearchFixture search = new RankProfileSearchFixture( @@ -145,7 +129,7 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx'): " + - "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + + "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + "not present in rank profile 'my_profile'", Exceptions.toMessageString(expected)); } @@ -163,8 +147,8 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx'): " + - "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) which must be produced " + - "by a macro in the rank profile, but this macro produces type tensor(d0[2],d5[10])", + "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " + + "but this macro returns tensor(d0[2],d5[10])", Exceptions.toMessageString(expected)); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index d288a396732..7228af2b0de 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -162,7 +162,7 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved'): " + - "Model refers placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + + "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + "not present in rank profile 'my_profile'", Exceptions.toMessageString(expected)); } @@ -179,7 +179,7 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved'): " + - "Model refers placeholder 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " + + "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " + "but this macro returns tensor(d0[2],d5[10])", Exceptions.toMessageString(expected)); } @@ -305,9 +305,9 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testMacroGeneration() { - final String expression = "join(join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; + final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))"; - final String macroExpression2 = "join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))"; + final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))"; RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", "tensorflow('mnist/saved')", @@ -316,15 +316,15 @@ public class RankingExpressionWithTensorFlowTestCase { "input", new StoringApplicationPackage(applicationDir)); search.assertFirstPhaseExpression(expression, "my_profile"); - search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile"); - search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile"); + search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); + search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile"); } @Test public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException { - final String expression = "join(join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; + final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))"; - final String macroExpression2 = "join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))"; + final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", @@ -335,8 +335,8 @@ public class RankingExpressionWithTensorFlowTestCase { application); search.assertFirstPhaseExpression(expression, "my_profile"); assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); - search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile"); - search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile"); + search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); + search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile"); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); @@ -353,8 +353,8 @@ public class RankingExpressionWithTensorFlowTestCase { storedApplication); searchFromStored.assertFirstPhaseExpression(expression, "my_profile"); assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); - searchFromStored.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile"); - searchFromStored.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile"); + searchFromStored.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); + searchFromStored.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile"); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); @@ -465,7 +465,7 @@ public class RankingExpressionWithTensorFlowTestCase { } - public static class StoringApplicationPackageFile extends ApplicationFile { + static class StoringApplicationPackageFile extends ApplicationFile { /** The path to the application package root */ private final Path root; diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ClusterSizeReductionValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ClusterSizeReductionValidatorTest.java index 765acf9e27b..4c3583ba0ae 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ClusterSizeReductionValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ClusterSizeReductionValidatorTest.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.application.validation.change; +import com.yahoo.config.application.api.ValidationId; +import com.yahoo.config.application.api.ValidationOverrides; import com.yahoo.config.model.api.ConfigChangeAction; import com.yahoo.config.model.api.ConfigChangeRefeedAction; import com.yahoo.vespa.model.VespaModel; @@ -33,7 +35,8 @@ public class ClusterSizeReductionValidatorTest { fail("Expected exception due to cluster size reduction"); } catch (IllegalArgumentException expected) { - assertEquals("cluster-size-reduction: Size reduction in 'default' is too large. Current size: 30, new size: 14. New size must be at least 50% of the current size", + assertEquals("cluster-size-reduction: Size reduction in 'default' is too large. Current size: 30, new size: 14. New size must be at least 50% of the current size. " + + ValidationOverrides.toAllowMessage(ValidationId.clusterSizeReduction), Exceptions.toMessageString(expected)); } } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java index 25ad6dbc620..ee58ca67b02 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.application.validation.change; +import com.yahoo.config.application.api.ValidationId; +import com.yahoo.config.application.api.ValidationOverrides; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.application.validation.ValidationTester; import com.yahoo.yolean.Exceptions; @@ -24,7 +26,8 @@ public class ContentClusterRemovalValidatorTest { fail("Expected exception due to content cluster id change"); } catch (IllegalArgumentException expected) { - assertEquals("content-cluster-removal: Content cluster 'contentClusterId' is removed. This will cause loss of all data in this cluster", + assertEquals("content-cluster-removal: Content cluster 'contentClusterId' is removed. This will cause loss of all data in this cluster. " + + ValidationOverrides.toAllowMessage(ValidationId.contentClusterRemoval), Exceptions.toMessageString(expected)); } } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentTypeRemovalValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentTypeRemovalValidatorTest.java index a52c6d7c7a2..ca45520711e 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentTypeRemovalValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentTypeRemovalValidatorTest.java @@ -1,6 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.application.validation.change; +import com.yahoo.config.application.api.ValidationId; +import com.yahoo.config.application.api.ValidationOverrides; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.application.validation.ValidationTester; import com.yahoo.yolean.Exceptions; @@ -28,7 +30,8 @@ public class ContentTypeRemovalValidatorTest { } catch (IllegalArgumentException expected) { assertEquals("content-type-removal: Type 'music' is removed in content cluster 'test'. " + - "This will cause loss of all data of this type", + "This will cause loss of all data of this type. " + + ValidationOverrides.toAllowMessage(ValidationId.contentTypeRemoval), Exceptions.toMessageString(expected)); } } diff --git a/config/src/main/java/com/yahoo/config/subscription/impl/ConfigSubscription.java b/config/src/main/java/com/yahoo/config/subscription/impl/ConfigSubscription.java index 76241c560e4..2b6f9e24dd3 100644 --- a/config/src/main/java/com/yahoo/config/subscription/impl/ConfigSubscription.java +++ b/config/src/main/java/com/yahoo/config/subscription/impl/ConfigSubscription.java @@ -204,7 +204,7 @@ public abstract class ConfigSubscription<T extends ConfigInstance> { void setInternalRedeploy(boolean internalRedeploy) { ConfigState<T> prev = config.get(); - this.config.set(new ConfigState<>(prev.isGenerationChanged(), prev.getGeneration(), prev.isConfigChanged(), internalRedeploy, prev.getConfig())); + this.config.set(new ConfigState<>(prev.isGenerationChanged(), prev.getGeneration(), internalRedeploy, prev.isConfigChanged(), prev.getConfig())); } /** diff --git a/config/src/main/java/com/yahoo/config/subscription/impl/JRTConfigRequester.java b/config/src/main/java/com/yahoo/config/subscription/impl/JRTConfigRequester.java index 88af414e28d..243c9e932a8 100644 --- a/config/src/main/java/com/yahoo/config/subscription/impl/JRTConfigRequester.java +++ b/config/src/main/java/com/yahoo/config/subscription/impl/JRTConfigRequester.java @@ -30,10 +30,10 @@ import com.yahoo.vespa.config.protocol.Trace; * as context, and puts the requests objects on a queue on the subscription, * for handling by the user thread. * - * @author vegardh - * @since 5.1 + * @author Vegard Havdal */ public class JRTConfigRequester implements RequestWaiter { + private static final Logger log = Logger.getLogger(JRTConfigRequester.class.getName()); public static final ConfigSourceSet defaultSourceSet = ConfigSourceSet.createDefault(); private static final int TRACELEVEL = 6; diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java index 4c32e635391..e9d400591e8 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java @@ -23,7 +23,7 @@ import com.yahoo.path.Path; import com.yahoo.slime.Slime; import com.yahoo.transaction.NestedTransaction; import com.yahoo.vespa.config.server.application.Application; -import com.yahoo.vespa.config.server.application.ApplicationConvergenceChecker; +import com.yahoo.vespa.config.server.application.ConfigConvergenceChecker; import com.yahoo.vespa.config.server.application.ApplicationSet; import com.yahoo.vespa.config.server.application.FileDistributionStatus; import com.yahoo.vespa.config.server.application.HttpProxy; @@ -88,7 +88,7 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye private final TenantRepository tenantRepository; private final Optional<Provisioner> hostProvisioner; - private final ApplicationConvergenceChecker convergeChecker; + private final ConfigConvergenceChecker convergeChecker; private final HttpProxy httpProxy; private final Clock clock; private final DeployLogger logger = new SilentDeployLogger(); @@ -99,22 +99,22 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye @Inject public ApplicationRepository(TenantRepository tenantRepository, HostProvisionerProvider hostProvisionerProvider, - ApplicationConvergenceChecker applicationConvergenceChecker, + ConfigConvergenceChecker configConvergenceChecker, HttpProxy httpProxy, ConfigserverConfig configserverConfig) { this(tenantRepository, hostProvisionerProvider.getHostProvisioner(), - applicationConvergenceChecker, httpProxy, configserverConfig, Clock.systemUTC(), new FileDistributionStatus()); + configConvergenceChecker, httpProxy, configserverConfig, Clock.systemUTC(), new FileDistributionStatus()); } // For testing public ApplicationRepository(TenantRepository tenantRepository, Provisioner hostProvisioner, Clock clock) { - this(tenantRepository, new ApplicationConvergenceChecker(), hostProvisioner, clock); + this(tenantRepository, new ConfigConvergenceChecker(), hostProvisioner, clock); } public ApplicationRepository(TenantRepository tenantRepository, - ApplicationConvergenceChecker convergenceChecker, + ConfigConvergenceChecker convergenceChecker, Provisioner hostProvisioner, Clock clock) { this(tenantRepository, Optional.of(hostProvisioner), @@ -124,14 +124,14 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye private ApplicationRepository(TenantRepository tenantRepository, Optional<Provisioner> hostProvisioner, - ApplicationConvergenceChecker applicationConvergenceChecker, + ConfigConvergenceChecker configConvergenceChecker, HttpProxy httpProxy, ConfigserverConfig configserverConfig, Clock clock, FileDistributionStatus fileDistributionStatus) { this.tenantRepository = tenantRepository; this.hostProvisioner = hostProvisioner; - this.convergeChecker = applicationConvergenceChecker; + this.convergeChecker = configConvergenceChecker; this.httpProxy = httpProxy; this.clock = clock; this.configserverConfig = configserverConfig; @@ -373,12 +373,12 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye // ---------------- Convergence ---------------------------------------------------------------- - public HttpResponse serviceConvergenceCheck(ApplicationId applicationId, String hostname, URI uri) { - return convergeChecker.serviceConvergenceCheck(getApplication(applicationId), hostname, uri); + public HttpResponse checkServiceForConfigConvergence(ApplicationId applicationId, String hostAndPort, URI uri) { + return convergeChecker.checkService(getApplication(applicationId), hostAndPort, uri); } - public HttpResponse serviceListToCheckForConfigConvergence(ApplicationId applicationId, URI uri) { - return convergeChecker.serviceListToCheckForConfigConvergence(getApplication(applicationId), uri); + public HttpResponse servicesToCheckForConfigConvergence(ApplicationId applicationId, URI uri) { + return convergeChecker.servicesToCheck(getApplication(applicationId), uri); } // ---------------- Session operations ---------------------------------------------------------------- diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerBootstrap.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerBootstrap.java index 9793a441355..916fde97e35 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerBootstrap.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerBootstrap.java @@ -3,56 +3,70 @@ package com.yahoo.vespa.config.server; import com.google.inject.Inject; import com.yahoo.component.AbstractComponent; +import com.yahoo.concurrent.DaemonThreadFactory; +import com.yahoo.container.handler.VipStatus; import com.yahoo.container.jdisc.state.StateMonitor; import com.yahoo.log.LogLevel; import com.yahoo.vespa.config.server.rpc.RpcServer; import com.yahoo.vespa.config.server.version.VersionState; +import java.time.Duration; +import java.time.Instant; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + /** * Main component that bootstraps and starts config server threads. * - * @author lulf - * @since 5.1 + * If config server has been upgraded to a new version since the last time it was running it will redeploy all + * applications. If that is done successfully the RPC server will start and the health status code will change from + * 'initializing' to 'up' and the config server will be put into rotation (start serving status.html with 200 OK) + * + * @author Ulf Lilleengen + * @author hmusum */ public class ConfigServerBootstrap extends AbstractComponent implements Runnable { private static final java.util.logging.Logger log = java.util.logging.Logger.getLogger(ConfigServerBootstrap.class.getName()); + private static final ExecutorService rpcServerExecutor = Executors.newSingleThreadExecutor(new DaemonThreadFactory("config server RPC server")); + private static final String vipStatusClusterIdentifier = "configserver"; private final ApplicationRepository applicationRepository; private final RpcServer server; private final Thread serverThread; private final VersionState versionState; private final StateMonitor stateMonitor; + private final VipStatus vipStatus; // The tenants object is injected so that all initial requests handlers are // added to the rpc server before it starts answering rpc requests. @SuppressWarnings("WeakerAccess") @Inject public ConfigServerBootstrap(ApplicationRepository applicationRepository, RpcServer server, - VersionState versionState, StateMonitor stateMonitor) { - this(applicationRepository, server, versionState, stateMonitor, true); + VersionState versionState, StateMonitor stateMonitor, VipStatus vipStatus) { + this(applicationRepository, server, versionState, stateMonitor, vipStatus, true); } // For testing only - ConfigServerBootstrap(ApplicationRepository applicationRepository, RpcServer server, - VersionState versionState, StateMonitor stateMonitor, boolean startMainThread) { + ConfigServerBootstrap(ApplicationRepository applicationRepository, RpcServer server, VersionState versionState, + StateMonitor stateMonitor, VipStatus vipStatus, boolean startMainThread) { this.applicationRepository = applicationRepository; this.server = server; this.versionState = versionState; this.stateMonitor = stateMonitor; this.serverThread = new Thread(this, "configserver main"); + this.vipStatus = vipStatus; + initializing(); // Initially take server out of rotation if (startMainThread) start(); } - private void start() { - serverThread.start(); - } - @Override public void deconstruct() { log.log(LogLevel.INFO, "Stopping config server"); + down(); server.stop(); + rpcServerExecutor.shutdown(); try { serverThread.join(); } catch (InterruptedException e) { @@ -74,9 +88,8 @@ public class ConfigServerBootstrap extends AbstractComponent implements Runnable return; // Status will not be set to 'up' since we return here } } - stateMonitor.status(StateMonitor.Status.up); - log.log(LogLevel.INFO, "Starting RPC server"); - server.run(); + startRpcServer(); + up(); do { try { Thread.sleep(1000); @@ -85,13 +98,51 @@ public class ConfigServerBootstrap extends AbstractComponent implements Runnable break; } } while (server.isRunning()); + down(); log.log(LogLevel.INFO, "RPC server stopped"); - stateMonitor.status(StateMonitor.Status.down); } StateMonitor.Status status() { return stateMonitor.status(); } + private void start() { + serverThread.start(); + } + + private void up() { + stateMonitor.status(StateMonitor.Status.up); + vipStatus.addToRotation(vipStatusClusterIdentifier); + } + + private void down() { + stateMonitor.status(StateMonitor.Status.down); + vipStatus.removeFromRotation(vipStatusClusterIdentifier); + } + + private void initializing() { + // This is default value (from config), so not strictly necessary + stateMonitor.status(StateMonitor.Status.initializing); + vipStatus.removeFromRotation(vipStatusClusterIdentifier); + } + + private void startRpcServer() { + log.log(LogLevel.INFO, "Starting RPC server"); + rpcServerExecutor.execute(server); + + Instant end = Instant.now().plus(Duration.ofSeconds(10)); + while (!server.isRunning() && Instant.now().isBefore(end)) { + try { + Thread.sleep(10); + } catch (InterruptedException e) { + log.log(LogLevel.ERROR, "Got interrupted", e); + break; + } + } + if (!server.isRunning()) + throw new RuntimeException("RPC server not started in 10 seconds"); + log.log(LogLevel.INFO, "RPC server started"); + } + } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceChecker.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java index 58168a7526f..4978f5f274d 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceChecker.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ConfigConvergenceChecker.java @@ -29,11 +29,10 @@ import java.util.stream.Collectors; /** * Checks for convergence of config generation for a given application. * - * @author lulf + * @author Ulf Lilleengen * @author hmusum */ -public class ApplicationConvergenceChecker extends AbstractComponent { - +public class ConfigConvergenceChecker extends AbstractComponent { private static final String statePath = "/state/v1/"; private static final String configSubPath = "config"; private final static Set<String> serviceTypesToCheck = new HashSet<>(Arrays.asList( @@ -49,15 +48,15 @@ public class ApplicationConvergenceChecker extends AbstractComponent { private final Client client = ClientBuilder.newClient(); @Inject - public ApplicationConvergenceChecker() { - this(ApplicationConvergenceChecker::createStateApi); + public ConfigConvergenceChecker() { + this(ConfigConvergenceChecker::createStateApi); } - public ApplicationConvergenceChecker(StateApiFactory stateApiFactory) { + public ConfigConvergenceChecker(StateApiFactory stateApiFactory) { this.stateApiFactory = stateApiFactory; } - public ServiceListResponse serviceListToCheckForConfigConvergence(Application application, URI uri) { + public ServiceListResponse servicesToCheck(Application application, URI uri) { List<ServiceInfo> servicesToCheck = new ArrayList<>(); application.getModel().getHosts() .forEach(host -> host.getServices().stream() @@ -69,7 +68,7 @@ public class ApplicationConvergenceChecker extends AbstractComponent { currentGeneration); } - public ServiceResponse serviceConvergenceCheck(Application application, String hostAndPortToCheck, URI uri) { + public ServiceResponse checkService(Application application, String hostAndPortToCheck, URI uri) { Long wantedGeneration = application.getApplicationGeneration(); try { if (! hostInApplication(application, hostAndPortToCheck)) @@ -157,8 +156,7 @@ public class ApplicationConvergenceChecker extends AbstractComponent { return false; } - static class ServiceListResponse extends JSONResponse { - final Cursor debug; + private static class ServiceListResponse extends JSONResponse { // Pre-condition: servicesToCheck has a state port private ServiceListResponse(int status, List<ServiceInfo> servicesToCheck, URI uri, long wantedGeneration, @@ -178,40 +176,28 @@ public class ApplicationConvergenceChecker extends AbstractComponent { object.setLong("currentGeneration", currentGeneration); object.setLong("wantedGeneration", wantedGeneration); object.setBool("converged", currentGeneration >= wantedGeneration); - // TODO: Remove debug when clients are not using it anymore - debug = object.setObject("debug"); - debug.setLong("wantedGeneration", wantedGeneration); } } static class ServiceResponse extends JSONResponse { - final Cursor debug; private ServiceResponse(int status, URI uri, String hostname, Long wantedGeneration) { super(status); object.setString("url", uri.toString()); object.setString("host", hostname); object.setLong("wantedGeneration", wantedGeneration); - // TODO: Remove debug when clients are not using it anymore - debug = object.setObject("debug"); - debug.setString("host", hostname); - debug.setLong("wantedGeneration", wantedGeneration); } static ServiceResponse createOkResponse(URI uri, String hostname, Long wantedGeneration, Long currentGeneration, boolean converged) { ServiceResponse serviceResponse = new ServiceResponse(200, uri, hostname, wantedGeneration); serviceResponse.object.setBool("converged", converged); serviceResponse.object.setLong("currentGeneration", currentGeneration); - // TODO: Remove debug when clients are not using it anymore - serviceResponse.debug.setLong("currentGeneration", currentGeneration); return serviceResponse; } static ServiceResponse createHostNotFoundInAppResponse(URI uri, String hostname, Long wantedGeneration) { ServiceResponse serviceResponse = new ServiceResponse(410, uri, hostname, wantedGeneration); serviceResponse.object.setString("problem", "Host:port (service) no longer part of application, refetch list of services."); - // TODO: Remove debug when clients are not using it anymore - serviceResponse.debug.setString("problem", "Host:port (service) no longer part of application, refetch list of services."); return serviceResponse; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionImpl.java index 2db89c2e8ed..36d76bbfc79 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionImpl.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/filedistribution/FileDistributionImpl.java @@ -10,7 +10,6 @@ import com.yahoo.jrt.Spec; import com.yahoo.jrt.StringArray; import com.yahoo.jrt.Supervisor; import com.yahoo.jrt.Target; -import com.yahoo.jrt.Transport; import com.yahoo.log.LogLevel; import com.yahoo.vespa.defaults.Defaults; @@ -24,11 +23,12 @@ import java.util.logging.Logger; public class FileDistributionImpl implements FileDistribution { private final static Logger log = Logger.getLogger(FileDistributionImpl.class.getName()); - private final Supervisor supervisor = new Supervisor(new Transport()); + private final Supervisor supervisor; private final File fileReferencesDir; - public FileDistributionImpl(ConfigserverConfig configserverConfig) { + public FileDistributionImpl(ConfigserverConfig configserverConfig, Supervisor supervisor) { this.fileReferencesDir = new File(Defaults.getDefaults().underVespaHome(configserverConfig.fileReferencesDir())); + this.supervisor = supervisor; } @Override diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java index 6bca8b1c562..473ec913f50 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java @@ -59,7 +59,9 @@ public class ApplicationHandler extends HttpHandler { Tenant tenant = verifyTenantAndApplication(applicationId); if (isServiceConvergeRequest(request)) { - return applicationRepository.serviceConvergenceCheck(applicationId, getHostNameFromRequest(request), request.getUri()); + // Expects both hostname and port in the request (hostname:port) + String hostAndPort = getHostNameFromRequest(request); + return applicationRepository.checkServiceForConfigConvergence(applicationId, hostAndPort, request.getUri()); } if (isClusterControllerStatusRequest(request)) { @@ -86,7 +88,7 @@ public class ApplicationHandler extends HttpHandler { } if (isServiceConvergeListRequest(request)) { - return applicationRepository.serviceListToCheckForConfigConvergence(applicationId, request.getUri()); + return applicationRepository.servicesToCheckForConfigConvergence(applicationId, request.getUri()); } if (isFiledistributionStatusRequest(request)) { diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java index c6a390caf86..2a53f9ee45c 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java @@ -51,7 +51,7 @@ public class ConfigServerMaintenance extends AbstractComponent { this.defaultInterval = Duration.ofMinutes(configserverConfig.maintainerIntervalMinutes()); // TODO: Want job control or feature flag to control when to run this, for now use a very // long interval to avoid running the maintainer - this.tenantsMaintainerInterval = isCd || isTest + this.tenantsMaintainerInterval = isCd || isTest || configserverConfig.region().equals("us-central-1") ? defaultInterval : Duration.ofMinutes(configserverConfig.tenantsMaintainerIntervalMinutes()); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/FileDistributionMaintainer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/FileDistributionMaintainer.java index 2664a0bde8c..1d16283d938 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/FileDistributionMaintainer.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/FileDistributionMaintainer.java @@ -31,9 +31,10 @@ public class FileDistributionMaintainer extends Maintainer { @Override protected void maintain() { - // TODO: For now only deletes files in CD system + // TODO: Delete files in all zones boolean deleteFiles = (SystemName.from(configserverConfig.system()) == SystemName.cd) - || Environment.from(configserverConfig.environment()).isTest(); + || Environment.from(configserverConfig.environment()).isTest() + || configserverConfig.region().equals("us-central-1"); applicationRepository.deleteUnusedFiledistributionReferences(fileReferencesDir, deleteFiles); } } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/RpcServer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/RpcServer.java index 9de587ac17b..f1cf479a38a 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/RpcServer.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/RpcServer.java @@ -166,7 +166,7 @@ public class RpcServer implements Runnable, ReloadListener, TenantListener { } public void run() { - log.log(LogLevel.INFO, "Rpc server listening on port " + spec.port()); + log.log(LogLevel.INFO, "Rpc will listen on port " + spec.port()); try { Acceptor acceptor = supervisor.listen(spec); isRunning = true; diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java index 8394494adca..15bc3c1fb46 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/FileDistributionFactory.java @@ -3,6 +3,8 @@ package com.yahoo.vespa.config.server.session; import com.google.inject.Inject; import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.jrt.Supervisor; +import com.yahoo.jrt.Transport; import com.yahoo.vespa.config.server.filedistribution.FileDistributionImpl; import com.yahoo.vespa.config.server.filedistribution.FileDistributionProvider; @@ -17,6 +19,7 @@ import java.io.File; public class FileDistributionFactory { private final ConfigserverConfig configserverConfig; + private final Supervisor supervisor = new Supervisor(new Transport()); @Inject public FileDistributionFactory(ConfigserverConfig configserverConfig) { @@ -24,7 +27,7 @@ public class FileDistributionFactory { } public FileDistributionProvider createProvider(File applicationPackage) { - return new FileDistributionProvider(applicationPackage, new FileDistributionImpl(configserverConfig)); + return new FileDistributionProvider(applicationPackage, new FileDistributionImpl(configserverConfig, supervisor)); } } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/Tenant.java b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/Tenant.java index 3557d7bf9ab..8d21a1b8c03 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/Tenant.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/Tenant.java @@ -11,7 +11,6 @@ import com.yahoo.vespa.config.server.session.LocalSessionRepo; import com.yahoo.vespa.config.server.session.RemoteSessionRepo; import com.yahoo.vespa.config.server.session.SessionFactory; import com.yahoo.vespa.curator.Curator; -import org.apache.zookeeper.Op; import org.apache.zookeeper.data.Stat; import java.time.Instant; diff --git a/configserver/src/main/resources/configserver-app/services.xml b/configserver/src/main/resources/configserver-app/services.xml index 2eeefda63e7..79797854689 100644 --- a/configserver/src/main/resources/configserver-app/services.xml +++ b/configserver/src/main/resources/configserver-app/services.xml @@ -10,6 +10,10 @@ <initialStatus>initializing</initialStatus> </config> + <config name="container.core.vip-status"> + <initiallyInRotation>false</initiallyInRotation> + </config> + <accesslog type="vespa" fileNamePattern="logs/vespa/configserver/access.log.%Y%m%d%H%M%S" rotationScheme="date" compressOnRotation="true" symlinkName="access.log" /> <preprocess:include file='access-logging.xml' required='false' /> @@ -37,7 +41,7 @@ <component id="com.yahoo.vespa.config.server.host.ConfigRequestHostLivenessTracker" bundle="configserver" /> <component id="com.yahoo.container.jdisc.metric.state.StateMetricConsumerFactory" bundle="container-disc" /> <component id="com.yahoo.config.provision.Zone" bundle="config-provisioning" /> - <component id="com.yahoo.vespa.config.server.application.ApplicationConvergenceChecker" bundle="configserver" /> + <component id="com.yahoo.vespa.config.server.application.ConfigConvergenceChecker" bundle="configserver" /> <component id="com.yahoo.vespa.config.server.application.HttpProxy" bundle="configserver" /> <component id="com.yahoo.vespa.config.server.filedistribution.FileServer" bundle="configserver" /> <component id="com.yahoo.vespa.config.server.maintenance.ConfigServerMaintenance" bundle="configserver" /> diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/ConfigServerBootstrapTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/ConfigServerBootstrapTest.java index 67cc87ae223..992d46d3115 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/ConfigServerBootstrapTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/ConfigServerBootstrapTest.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.config.server; import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.container.handler.VipStatus; import com.yahoo.container.jdisc.config.HealthMonitorConfig; import com.yahoo.container.jdisc.state.StateMonitor; import com.yahoo.jdisc.core.SystemTimer; @@ -43,13 +44,17 @@ public class ConfigServerBootstrapTest { assertTrue(versionState.isUpgraded()); RpcServer rpcServer = createRpcServer(configserverConfig); - ConfigServerBootstrap bootstrap = new ConfigServerBootstrap(tester.applicationRepository(), rpcServer, versionState, createStateMonitor()); - waitUntil(() -> bootstrap.status() == StateMonitor.Status.up, "failed waiting for status 'up'"); + VipStatus vipStatus = new VipStatus(); + ConfigServerBootstrap bootstrap = new ConfigServerBootstrap(tester.applicationRepository(), rpcServer, versionState, createStateMonitor(), vipStatus); + assertFalse(vipStatus.isInRotation()); waitUntil(rpcServer::isRunning, "failed waiting for Rpc server running"); + waitUntil(() -> bootstrap.status() == StateMonitor.Status.up, "failed waiting for status 'up'"); + waitUntil(vipStatus::isInRotation, "failed waiting for server to be in rotation"); bootstrap.deconstruct(); assertEquals(StateMonitor.Status.down, bootstrap.status()); assertFalse(rpcServer.isRunning()); + assertFalse(vipStatus.isInRotation()); } @Test @@ -69,13 +74,17 @@ public class ConfigServerBootstrapTest { .resolve("sessions/2/services.xml")); RpcServer rpcServer = createRpcServer(configserverConfig); + VipStatus vipStatus = new VipStatus(); ConfigServerBootstrap bootstrap = new ConfigServerBootstrap(tester.applicationRepository(), rpcServer, versionState, - createStateMonitor(), false /* do not call run method */); + createStateMonitor(), vipStatus, false /* do not call run method */); + assertFalse(vipStatus.isInRotation()); // Call method directly, to be sure that it is finished redeploying all applications and we can check status bootstrap.run(); - // App is invalid, so bootstrapping was unsuccessful. Status should be 'initializing' and rpc server should not be running + // App is invalid, bootstrapping was unsuccessful. Status should be 'initializing', + // rpc server should not be running and it should be out of rotation assertEquals(StateMonitor.Status.initializing, bootstrap.status()); assertFalse(rpcServer.isRunning()); + assertFalse(vipStatus.isInRotation()); } private void waitUntil(BooleanSupplier booleanSupplier, String messageIfWaitingFails) throws InterruptedException { diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/TestWithCurator.java b/configserver/src/test/java/com/yahoo/vespa/config/server/TestWithCurator.java deleted file mode 100644 index 5dce0607f90..00000000000 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/TestWithCurator.java +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.config.server; - -import com.yahoo.vespa.curator.Curator; -import com.yahoo.vespa.curator.mock.MockCurator; -import com.yahoo.vespa.config.server.zookeeper.ConfigCurator; -import org.apache.curator.framework.CuratorFramework; -import org.junit.Before; - -/** - * For tests that require a Curator instance - * - * @author lulf - * @since 5.16 - */ -public class TestWithCurator { - - protected ConfigCurator configCurator; - protected CuratorFramework curatorFramework; - protected Curator curator; - - @Before - public void setupZKProvider() throws Exception { - curator = new MockCurator(); - configCurator = ConfigCurator.create(curator); - curatorFramework = curator.framework(); - } - -} diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceCheckerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/application/ConfigConvergenceCheckerTest.java index 399169c122a..71052c8b463 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/application/ApplicationConvergenceCheckerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/application/ConfigConvergenceCheckerTest.java @@ -25,21 +25,23 @@ import java.util.Arrays; import java.util.HashMap; import java.util.Map; -import static com.yahoo.vespa.config.server.application.ApplicationConvergenceChecker.ServiceResponse; +import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static com.yahoo.vespa.config.server.application.ConfigConvergenceChecker.ServiceResponse; /** * @author Ulf Lilleengen */ -public class ApplicationConvergenceCheckerTest { +public class ConfigConvergenceCheckerTest { private static final ObjectMapper mapper = new ObjectMapper(); private final TenantName tenant = TenantName.from("mytenant"); private final ApplicationId appId = ApplicationId.from(tenant, ApplicationName.from("myapp"), InstanceName.from("myinstance")); private Application application; + private ConfigConvergenceChecker checker; private Map<URI, Long> currentGeneration; - private ApplicationConvergenceChecker checker; @Rule public TemporaryFolder folder = new TemporaryFolder(); @@ -54,7 +56,7 @@ public class ApplicationConvergenceCheckerTest { Version.fromIntValues(0, 0, 0), MetricUpdater.createTestUpdater(), appId); currentGeneration = new HashMap<>(); - checker = new ApplicationConvergenceChecker( + checker = new ConfigConvergenceChecker( (client, serviceUri) -> () -> asJson("{\"config\":{\"generation\":" + currentGeneration.getOrDefault(serviceUri, 3L) + "}}")); @@ -62,37 +64,27 @@ public class ApplicationConvergenceCheckerTest { @Test public void service_convergence() throws Exception { - ServiceResponse serviceResponse = checker.serviceConvergenceCheck(application, - "localhost:1337", - URI.create("http://foo:234/serviceconverge/localhost:1337")); + ServiceResponse serviceResponse = checker.checkService(application, + "localhost:1337", + URI.create("http://foo:234/serviceconverge/localhost:1337")); assertEquals(200, serviceResponse.getStatus()); assertJsonEquals("{\n" + " \"url\": \"http://foo:234/serviceconverge/localhost:1337\",\n" + " \"host\": \"localhost:1337\",\n" + " \"wantedGeneration\": 3,\n" + - " \"debug\": {\n" + - " \"host\": \"localhost:1337\",\n" + - " \"wantedGeneration\": 3,\n" + - " \"currentGeneration\": 3\n" + - " },\n" + " \"converged\": true,\n" + " \"currentGeneration\": 3\n" + "}", SessionHandlerTest.getRenderedString(serviceResponse)); - ServiceResponse hostMissingResponse = checker.serviceConvergenceCheck(application, - "notPresent:1337", - URI.create("http://foo:234/serviceconverge/notPresent:1337")); + ServiceResponse hostMissingResponse = checker.checkService(application, + "notPresent:1337", + URI.create("http://foo:234/serviceconverge/notPresent:1337")); assertEquals(410, hostMissingResponse.getStatus()); assertJsonEquals("{\n" + " \"url\": \"http://foo:234/serviceconverge/notPresent:1337\",\n" + " \"host\": \"notPresent:1337\",\n" + " \"wantedGeneration\": 3,\n" + - " \"debug\": {\n" + - " \"host\": \"notPresent:1337\",\n" + - " \"wantedGeneration\": 3,\n" + - " \"problem\": \"Host:port (service) no longer part of application, refetch list of services.\"\n" + - " },\n" + " \"problem\": \"Host:port (service) no longer part of application, refetch list of services.\"\n" + "}", SessionHandlerTest.getRenderedString(hostMissingResponse)); @@ -100,8 +92,7 @@ public class ApplicationConvergenceCheckerTest { @Test public void service_list_convergence() throws Exception { - HttpResponse serviceListResponse = checker.serviceListToCheckForConfigConvergence(application, - URI.create("http://foo:234/serviceconverge")); + HttpResponse serviceListResponse = checker.servicesToCheck(application, URI.create("http://foo:234/serviceconverge")); assertEquals(200, serviceListResponse.getStatus()); assertJsonEquals("{\n" + " \"services\": [\n" + @@ -115,10 +106,7 @@ public class ApplicationConvergenceCheckerTest { " \"url\": \"http://foo:234/serviceconverge\",\n" + " \"currentGeneration\": 3,\n" + " \"wantedGeneration\": 3,\n" + - " \"converged\": true,\n" + - " \"debug\": {\n" + - " \"wantedGeneration\": 3\n" + - " }\n" + + " \"converged\": true\n" + "}", SessionHandlerTest.getRenderedString(serviceListResponse)); @@ -132,7 +120,7 @@ public class ApplicationConvergenceCheckerTest { Version.fromIntValues(0, 0, 0), MetricUpdater.createTestUpdater(), appId); currentGeneration.put(URI.create("http://host2:1234"), 4L); - serviceListResponse = checker.serviceListToCheckForConfigConvergence(application, URI.create("http://foo:234/serviceconverge")); + serviceListResponse = checker.servicesToCheck(application, URI.create("http://foo:234/serviceconverge")); assertEquals(200, serviceListResponse.getStatus()); assertJsonEquals("{\n" + " \"services\": [\n" + @@ -152,10 +140,7 @@ public class ApplicationConvergenceCheckerTest { " \"url\": \"http://foo:234/serviceconverge\",\n" + " \"currentGeneration\": 3,\n" + " \"wantedGeneration\": 4,\n" + - " \"converged\": false,\n" + - " \"debug\": {\n" + - " \"wantedGeneration\": 4\n" + - " }\n" + + " \"converged\": false\n" + "}", SessionHandlerTest.getRenderedString(serviceListResponse)); } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java index a4bfb1de221..3fa1b3fdb5e 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java @@ -5,9 +5,12 @@ import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.TenantName; import com.yahoo.text.Utf8; import com.yahoo.vespa.config.server.MockReloadHandler; -import com.yahoo.vespa.config.server.TestWithCurator; import com.yahoo.vespa.config.server.tenant.TenantRepository; +import com.yahoo.vespa.curator.Curator; +import com.yahoo.vespa.curator.mock.MockCurator; +import org.apache.curator.framework.CuratorFramework; +import org.junit.Before; import org.junit.Test; import java.util.Arrays; @@ -19,12 +22,20 @@ import static org.junit.Assert.*; /** * @author Ulf Lilleengen - * @since 5.1 */ -public class TenantApplicationsTest extends TestWithCurator { +public class TenantApplicationsTest { private static final TenantName tenantName = TenantName.from("tenant"); + private Curator curator; + private CuratorFramework curatorFramework; + + @Before + public void setup() { + curator = new MockCurator(); + curatorFramework = curator.framework(); + } + @Test public void require_that_applications_are_read_from_zookeeper() throws Exception { writeApplicationData(createApplicationId("foo"), 3L); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/deploy/ZooKeeperClientTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/deploy/ZooKeeperClientTest.java index 945c7d60750..c4a0fd9f3f0 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/deploy/ZooKeeperClientTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/deploy/ZooKeeperClientTest.java @@ -8,7 +8,6 @@ import com.yahoo.config.application.api.FileRegistry; import com.yahoo.config.model.application.provider.*; import com.yahoo.config.provision.*; import com.yahoo.path.Path; -import com.yahoo.vespa.config.server.TestWithCurator; import com.yahoo.vespa.config.server.zookeeper.ZKApplicationPackage; import com.yahoo.vespa.curator.mock.MockCurator; import com.yahoo.vespa.config.server.zookeeper.ConfigCurator; @@ -25,13 +24,12 @@ import java.util.*; import static org.hamcrest.core.Is.is; import static org.junit.Assert.*; - /** * Unit tests for ZooKeeperClient. * * @author hmusum */ -public class ZooKeeperClientTest extends TestWithCurator { +public class ZooKeeperClientTest { @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder(); @@ -41,7 +39,7 @@ public class ZooKeeperClientTest extends TestWithCurator { @Before public void setupZK() throws IOException { - this.zk = ConfigCurator.create(curator); + zk = ConfigCurator.create(new MockCurator()); ZooKeeperClient zkc = new ZooKeeperClient(zk, new BaseDeployLogger(), true, Path.fromString(appPath)); ApplicationPackage app = FilesApplicationPackage.fromFileWithDeployData(new File("src/test/apps/zkfeed"), new DeployData("foo", diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java index 2d39efb9013..d8c5e33ca65 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java @@ -12,7 +12,7 @@ import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.jdisc.Response; import com.yahoo.vespa.config.server.ApplicationRepository; import com.yahoo.vespa.config.server.TestComponentRegistry; -import com.yahoo.vespa.config.server.application.ApplicationConvergenceChecker; +import com.yahoo.vespa.config.server.application.ConfigConvergenceChecker; import com.yahoo.vespa.config.server.application.HttpProxy; import com.yahoo.vespa.config.server.http.HandlerTest; import com.yahoo.vespa.config.server.http.HttpErrorResponse; @@ -69,7 +69,7 @@ public class ApplicationHandlerTest { tenantRepository.addTenant(TenantBuilder.create(componentRegistry, foobar)); provisioner = new SessionHandlerTest.MockProvisioner(); applicationRepository = new ApplicationRepository(tenantRepository, - new ApplicationConvergenceChecker(stateApiFactory), + new ConfigConvergenceChecker(stateApiFactory), provisioner, Clock.systemUTC()); listApplicationsHandler = new ListApplicationsHandler(ListApplicationsHandler.testOnlyContext(), tenantRepository, @@ -163,7 +163,7 @@ public class ApplicationHandlerTest { HttpProxy mockHttpProxy = mock(HttpProxy.class); ApplicationRepository applicationRepository = new ApplicationRepository(tenantRepository, HostProvisionerProvider.withProvisioner(provisioner), - new ApplicationConvergenceChecker(stateApiFactory), + new ConfigConvergenceChecker(stateApiFactory), mockHttpProxy, new ConfigserverConfig(new ConfigserverConfig.Builder())); ApplicationHandler mockHandler = createApplicationHandler(applicationRepository); @@ -276,10 +276,10 @@ public class ApplicationHandlerTest { return createApplicationHandler().handle(HttpRequest.createTestRequest(restartUrl, com.yahoo.jdisc.http.HttpRequest.Method.GET)); } - private static class MockStateApiFactory implements ApplicationConvergenceChecker.StateApiFactory { + private static class MockStateApiFactory implements ConfigConvergenceChecker.StateApiFactory { boolean createdApi = false; @Override - public ApplicationConvergenceChecker.StateApi createStateApi(Client client, URI serviceUri) { + public ConfigConvergenceChecker.StateApi createStateApi(Client client, URI serviceUri) { createdApi = true; return () -> { try { diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionRepoTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionRepoTest.java index 987dd8a6c4d..829dfb978b2 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionRepoTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionRepoTest.java @@ -6,13 +6,13 @@ import com.yahoo.test.ManualClock; import com.yahoo.config.provision.TenantName; import com.yahoo.vespa.config.server.GlobalComponentRegistry; import com.yahoo.vespa.config.server.TestComponentRegistry; -import com.yahoo.vespa.config.server.TestWithCurator; import com.yahoo.vespa.config.server.application.MemoryTenantApplications; import com.yahoo.vespa.config.server.deploy.TenantFileSystemDirs; import com.yahoo.io.IOUtils; import com.yahoo.vespa.config.server.host.HostRegistry; import com.yahoo.vespa.config.server.http.SessionHandlerTest; +import com.yahoo.vespa.curator.mock.MockCurator; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -29,7 +29,7 @@ import static org.junit.Assert.fail; /** * @author Ulf Lilleengen */ -public class LocalSessionRepoTest extends TestWithCurator { +public class LocalSessionRepoTest { private File testApp = new File("src/test/apps/app"); private LocalSessionRepo repo; @@ -45,7 +45,7 @@ public class LocalSessionRepoTest extends TestWithCurator { } private void setupSessions(TenantName tenantName, boolean createInitialSessions) throws Exception { - GlobalComponentRegistry globalComponentRegistry = new TestComponentRegistry.Builder().curator(curator).build(); + GlobalComponentRegistry globalComponentRegistry = new TestComponentRegistry.Builder().curator(new MockCurator()).build(); TenantFileSystemDirs tenantFileSystemDirs = new TenantFileSystemDirs(temporaryFolder.newFolder(), tenantName); if (createInitialSessions) { IOUtils.copyDirectory(testApp, new File(tenantFileSystemDirs.sessionsPath(), "1")); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionRepoTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionRepoTest.java index 88997d29572..4bbfea48254 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionRepoTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionRepoTest.java @@ -13,12 +13,12 @@ import com.yahoo.text.Utf8; import com.yahoo.transaction.Transaction; import com.yahoo.vespa.config.server.TestComponentRegistry; -import com.yahoo.vespa.config.server.TestWithCurator; import com.yahoo.vespa.config.server.application.TenantApplications; import com.yahoo.vespa.config.server.tenant.Tenant; import com.yahoo.vespa.config.server.tenant.TenantBuilder; import com.yahoo.vespa.config.server.tenant.TenantRepository; import com.yahoo.vespa.curator.Curator; +import com.yahoo.vespa.curator.mock.MockCurator; import org.junit.Before; import org.junit.Test; @@ -32,16 +32,17 @@ import java.util.function.LongPredicate; /** * @author Ulf Lilleengen - * @since 5.1 */ -public class RemoteSessionRepoTest extends TestWithCurator { +public class RemoteSessionRepoTest { private static final TenantName tenantName = TenantName.defaultName(); private RemoteSessionRepo remoteSessionRepo; + private Curator curator; @Before - public void setupFacade() throws Exception { + public void setupFacade() { + curator = new MockCurator(); Tenant tenant = TenantBuilder.create(new TestComponentRegistry.Builder() .curator(curator) .build(), @@ -75,7 +76,7 @@ public class RemoteSessionRepoTest extends TestWithCurator { } @Test - public void testCreateSession() throws Exception { + public void testCreateSession() { createSession(3l, true); assertSessionExists(3l); } @@ -99,7 +100,7 @@ public class RemoteSessionRepoTest extends TestWithCurator { // repo even if it had bad data (by making getSessionIdForApplication() in FailingTenantApplications // throw an exception). @Test - public void testBadApplicationRepoOnActivate() throws Exception { + public void testBadApplicationRepoOnActivate() { long sessionId = 3L; TenantApplications applicationRepo = new FailingTenantApplications(); TenantName mytenant = TenantName.from("mytenant"); @@ -116,7 +117,7 @@ public class RemoteSessionRepoTest extends TestWithCurator { private void assertStatusChange(long sessionId, Session.Status status) throws Exception { Path statePath = TenantRepository.getSessionsPath(tenantName).append("" + sessionId).append(ConfigCurator.SESSIONSTATE_ZK_SUBPATH); curator.create(statePath); - curatorFramework.setData().forPath(statePath.getAbsolute(), Utf8.toBytes(status.toString())); + curator.framework().setData().forPath(statePath.getAbsolute(), Utf8.toBytes(status.toString())); System.out.println("Setting status " + status + " for " + sessionId); assertSessionStatus(sessionId, status); } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionTest.java index 39fe27e5adb..b57d2d1a1a1 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/RemoteSessionTest.java @@ -22,9 +22,7 @@ import com.yahoo.vespa.model.VespaModelFactory; import org.junit.Before; import org.junit.Test; -import org.xml.sax.SAXException; -import java.io.IOException; import java.time.Clock; import java.time.Instant; import java.time.LocalDate; @@ -42,8 +40,7 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; /** - * @author lulf - * @since 5.1 + * @author Ulf Lilleengen */ public class RemoteSessionTest { @@ -52,7 +49,7 @@ public class RemoteSessionTest { private Curator curator; @Before - public void setupTest() throws Exception { + public void setupTest() { curator = new MockCurator(); } @@ -66,7 +63,7 @@ public class RemoteSessionTest { } @Test - public void require_that_applications_are_loaded() throws IOException, SAXException { + public void require_that_applications_are_loaded() { RemoteSession session = createSession(3, Arrays.asList(new MockModelFactory(), new VespaModelFactory(new NullConfigModelRegistry())), Clock.systemUTC()); session.loadPrepared(); ApplicationSet applicationSet = session.ensureApplicationLoaded(); @@ -84,7 +81,7 @@ public class RemoteSessionTest { } @Test(expected = IllegalArgumentException.class) - public void require_that_new_invalid_application_throws_exception() throws IOException, SAXException { + public void require_that_new_invalid_application_throws_exception() { MockModelFactory failingFactory = new MockModelFactory(); failingFactory.vespaVersion = Version.fromIntValues(1, 2, 0); failingFactory.throwOnLoad = true; @@ -98,7 +95,7 @@ public class RemoteSessionTest { } @Test - public void require_that_application_incompatible_with_latestmajor_is_loaded_on_earlier_major() throws IOException, SAXException { + public void require_that_application_incompatible_with_latestmajor_is_loaded_on_earlier_major() { MockModelFactory okFactory1 = new MockModelFactory(); okFactory1.vespaVersion = Version.fromIntValues(1, 1, 0); okFactory1.throwOnLoad = false; @@ -116,7 +113,7 @@ public class RemoteSessionTest { } @Test - public void require_that_old_invalid_application_does_not_throw_exception_if_skipped() throws IOException, SAXException { + public void require_that_old_invalid_application_does_not_throw_exception_if_skipped() { MockModelFactory failingFactory = new MockModelFactory(); failingFactory.vespaVersion = Version.fromIntValues(1, 1, 0); failingFactory.throwOnLoad = true; @@ -131,7 +128,7 @@ public class RemoteSessionTest { } @Test - public void require_that_old_invalid_application_does_not_throw_exception_if_skipped_also_across_major_versions() throws IOException, SAXException { + public void require_that_old_invalid_application_does_not_throw_exception_if_skipped_also_across_major_versions() { MockModelFactory failingFactory = new MockModelFactory(); failingFactory.vespaVersion = Version.fromIntValues(1, 0, 0); failingFactory.throwOnLoad = true; @@ -146,7 +143,7 @@ public class RemoteSessionTest { } @Test - public void require_that_old_invalid_application_does_not_throw_exception_if_skipped_also_when_new_major_is_incompatible() throws IOException, SAXException { + public void require_that_old_invalid_application_does_not_throw_exception_if_skipped_also_when_new_major_is_incompatible() { MockModelFactory failingFactory = new MockModelFactory(); failingFactory.vespaVersion = Version.fromIntValues(1, 0, 0); failingFactory.throwOnLoad = true; @@ -166,7 +163,7 @@ public class RemoteSessionTest { } @Test - public void require_that_an_application_package_can_limit_to_one_major_version() throws IOException, SAXException { + public void require_that_an_application_package_can_limit_to_one_major_version() { ApplicationPackage application = new MockApplicationPackage.Builder().withServices("<services major-version='2' version=\"1.0\"></services>").build(); @@ -186,7 +183,7 @@ public class RemoteSessionTest { } @Test - public void require_that_session_status_is_updated() throws IOException, SAXException { + public void require_that_session_status_is_updated() { SessionZooKeeperClient zkc = new MockSessionZKClient(curator, tenantName, 3); RemoteSession session = createSession(3, zkc, Clock.systemUTC()); assertThat(session.getStatus(), is(Session.Status.NEW)); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionFactoryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionFactoryTest.java index 0ca487cfb67..6c8be2ac2f3 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionFactoryTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionFactoryTest.java @@ -15,7 +15,8 @@ import com.yahoo.vespa.config.server.http.CompressedApplicationInputStream; import com.yahoo.vespa.config.server.http.CompressedApplicationInputStreamTest; import com.yahoo.vespa.config.server.http.v2.ApplicationApiHandler; -import com.yahoo.vespa.config.server.tenant.TestWithTenant; +import com.yahoo.vespa.config.server.tenant.TenantRepository; +import com.yahoo.vespa.curator.mock.MockCurator; import org.json.JSONException; import org.json.JSONObject; import org.junit.Before; @@ -33,12 +34,13 @@ import static org.junit.Assert.assertTrue; /** * @author Ulf Lilleengen */ -public class SessionFactoryTest extends TestWithTenant { +public class SessionFactoryTest { private SessionFactory factory; @Before public void setup_test() { - factory = tenant.getSessionFactory(); + TenantRepository tenantRepository = new TenantRepository(new TestComponentRegistry.Builder().curator(new MockCurator()).build()); + factory = tenantRepository.defaultTenant().getSessionFactory(); } @Test diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java index 4fbd7fe7232..92fb67fdd54 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java @@ -32,11 +32,11 @@ import com.yahoo.vespa.config.server.provision.HostProvisionerProvider; import com.yahoo.vespa.config.server.tenant.Rotations; import com.yahoo.vespa.config.server.zookeeper.ConfigCurator; +import com.yahoo.vespa.curator.mock.MockCurator; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; -import org.xml.sax.SAXException; import java.io.File; import java.io.IOException; @@ -52,13 +52,15 @@ import static org.junit.Assert.*; * @author lulf * @since 5.1 */ -public class SessionPreparerTest extends TestWithCurator { +public class SessionPreparerTest { private static final Path tenantPath = Path.createRoot(); private static final Path sessionsPath = tenantPath.append("sessions").append("testapp"); private static final File testApp = new File("src/test/apps/app"); private static final File invalidTestApp = new File("src/test/apps/illegalApp"); + private MockCurator curator; + private ConfigCurator configCurator; private SessionPreparer preparer; private TestComponentRegistry componentRegistry; private MockFileDistributionFactory fileDistributionFactory; @@ -69,6 +71,8 @@ public class SessionPreparerTest extends TestWithCurator { @Before public void setUp() { + curator = new MockCurator(); + configCurator = ConfigCurator.create(curator); componentRegistry = new TestComponentRegistry.Builder().curator(curator).build(); fileDistributionFactory = (MockFileDistributionFactory)componentRegistry.getFileDistributionFactory(); preparer = createPreparer(); @@ -99,13 +103,13 @@ public class SessionPreparerTest extends TestWithCurator { } @Test(expected = InvalidApplicationException.class) - public void require_that_application_validation_exception_is_not_caught() throws IOException, SAXException { + public void require_that_application_validation_exception_is_not_caught() throws IOException { FilesApplicationPackage app = getApplicationPackage(invalidTestApp); preparer.prepare(getContext(app), getLogger(), new PrepareParams.Builder().build(), Optional.empty(), tenantPath, Instant.now()); } @Test - public void require_that_application_validation_exception_is_ignored_if_forced() throws IOException, SAXException { + public void require_that_application_validation_exception_is_ignored_if_forced() throws IOException { FilesApplicationPackage app = getApplicationPackage(invalidTestApp); preparer.prepare(getContext(app), getLogger(), new PrepareParams.Builder().ignoreValidationErrors(true).timeoutBudget(TimeoutBudgetTest.day()).build(), @@ -250,18 +254,18 @@ public class SessionPreparerTest extends TestWithCurator { return FilesApplicationPackage.fromFile(appDir); } - DeployHandlerLogger getLogger() { + private DeployHandlerLogger getLogger() { return getLogger(false); } - DeployHandlerLogger getLogger(boolean verbose) { + private DeployHandlerLogger getLogger(boolean verbose) { return new DeployHandlerLogger(new Slime().get(), verbose, new ApplicationId.Builder().tenant("testtenant").applicationName("testapp").build()); } private static class FailingModelFactory extends TestModelFactory { private final RuntimeException exception; - public FailingModelFactory(Version vespaVersion, RuntimeException exception) { + FailingModelFactory(Version vespaVersion, RuntimeException exception) { super(vespaVersion); this.exception = exception; } @@ -279,7 +283,7 @@ public class SessionPreparerTest extends TestWithCurator { private static class ConfigChangeActionsModelFactory extends TestModelFactory { private final ConfigChangeAction action; - public ConfigChangeActionsModelFactory(Version vespaVersion, ConfigChangeAction action) { + ConfigChangeActionsModelFactory(Version vespaVersion, ConfigChangeAction action) { super(vespaVersion); this.action = action; } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClientTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClientTest.java index 98ba3d4e178..522a21a47b3 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClientTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClientTest.java @@ -4,8 +4,10 @@ package com.yahoo.vespa.config.server.session; import com.yahoo.path.Path; import com.yahoo.config.provision.ApplicationId; import com.yahoo.text.Utf8; -import com.yahoo.vespa.config.server.TestWithCurator; import com.yahoo.vespa.config.server.zookeeper.ConfigCurator; +import com.yahoo.vespa.curator.Curator; +import com.yahoo.vespa.curator.mock.MockCurator; +import org.junit.Before; import org.junit.Test; import java.util.concurrent.TimeUnit; @@ -15,10 +17,18 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; /** - * @author lulf - * @since 5.1 + * @author Ulf Lilleengen */ -public class SessionZooKeeperClientTest extends TestWithCurator { +public class SessionZooKeeperClientTest { + + private Curator curator; + private ConfigCurator configCurator; + + @Before + public void setup() { + curator = new MockCurator(); + configCurator = ConfigCurator.create(curator); + } @Test public void require_that_status_can_be_updated() { diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java index f47ed69ad14..046369edce0 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java @@ -8,9 +8,10 @@ import com.yahoo.config.provision.Version; import com.yahoo.vespa.config.server.application.ApplicationSet; import com.yahoo.vespa.config.server.ServerCache; import com.yahoo.vespa.config.server.TestComponentRegistry; -import com.yahoo.vespa.config.server.TestWithCurator; import com.yahoo.vespa.config.server.application.Application; import com.yahoo.vespa.config.server.monitoring.MetricUpdater; +import com.yahoo.vespa.curator.Curator; +import com.yahoo.vespa.curator.mock.MockCurator; import com.yahoo.vespa.model.VespaModel; import org.junit.After; import org.junit.Before; @@ -30,17 +31,20 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -public class TenantRepositoryTest extends TestWithCurator { +public class TenantRepositoryTest { + private static final TenantName tenant1 = TenantName.from("tenant1"); + private static final TenantName tenant2 = TenantName.from("tenant2"); + private static final TenantName tenant3 = TenantName.from("tenant3"); + private TenantRepository tenantRepository; private TestComponentRegistry globalComponentRegistry; private TenantRequestHandlerTest.MockReloadListener listener; private MockTenantListener tenantListener; - private final TenantName tenant1 = TenantName.from("tenant1"); - private final TenantName tenant2 = TenantName.from("tenant2"); - private final TenantName tenant3 = TenantName.from("tenant3"); + private Curator curator; @Before public void setupSessions() { + curator = new MockCurator(); globalComponentRegistry = new TestComponentRegistry.Builder().curator(curator).build(); listener = (TenantRequestHandlerTest.MockReloadListener)globalComponentRegistry.getReloadListener(); tenantListener = (MockTenantListener)globalComponentRegistry.getTenantListener(); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRequestHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRequestHandlerTest.java index cecbab2d9ec..d517eb195a7 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRequestHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRequestHandlerTest.java @@ -24,8 +24,6 @@ import com.yahoo.vespa.config.server.host.HostRegistries; import com.yahoo.vespa.config.server.ReloadListener; import com.yahoo.vespa.config.server.ServerCache; import com.yahoo.vespa.config.server.TestComponentRegistry; -import com.yahoo.vespa.config.server.TestConfigDefinitionRepo; -import com.yahoo.vespa.config.server.TestWithCurator; import com.yahoo.vespa.config.server.rpc.UncompressedConfigResponseFactory; import com.yahoo.vespa.config.server.application.Application; import com.yahoo.config.provision.ApplicationId; @@ -37,6 +35,8 @@ import com.yahoo.vespa.config.server.monitoring.MetricUpdater; import com.yahoo.vespa.config.server.monitoring.Metrics; import com.yahoo.vespa.config.server.session.RemoteSession; import com.yahoo.vespa.config.server.session.SessionZooKeeperClient; +import com.yahoo.vespa.curator.Curator; +import com.yahoo.vespa.curator.mock.MockCurator; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.VespaModelFactory; @@ -58,7 +58,7 @@ import static org.junit.Assert.*; /** * @author Ulf Lilleengen */ -public class TenantRequestHandlerTest extends TestWithCurator { +public class TenantRequestHandlerTest { private static final Version vespaVersion = new VespaModelFactory(new NullConfigModelRegistry()).getVersion(); private TenantRequestHandler server; @@ -67,6 +67,7 @@ public class TenantRequestHandlerTest extends TestWithCurator { private File app2 = new File("src/test/apps/cs2"); private TenantName tenant = TenantName.from("mytenant"); private TestComponentRegistry componentRegistry; + private Curator curator; @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); @@ -77,6 +78,8 @@ public class TenantRequestHandlerTest extends TestWithCurator { @Before public void setUp() throws IOException { + curator = new MockCurator(); + feedApp(app1, 1, defaultApp(), false); Metrics sh = Metrics.createTestMetrics(); List<ReloadListener> listeners = new ArrayList<>(); @@ -86,10 +89,7 @@ public class TenantRequestHandlerTest extends TestWithCurator { } private void feedApp(File appDir, long sessionId, ApplicationId appId, boolean internalRedeploy) throws IOException { - SessionZooKeeperClient zkc = new SessionZooKeeperClient(curator, configCurator, - TenantRepository.getSessionsPath(tenant).append(String.valueOf(sessionId)), - new TestConfigDefinitionRepo(), - "", Optional.empty()); + SessionZooKeeperClient zkc = new SessionZooKeeperClient(curator, TenantRepository.getSessionsPath(tenant).append(String.valueOf(sessionId))); zkc.writeApplicationId(appId); File app = tempFolder.newFolder(); IOUtils.copyDirectory(appDir, app); @@ -107,17 +107,14 @@ public class TenantRequestHandlerTest extends TestWithCurator { AllocatedHosts.withHosts(Collections.emptySet())); } - private ApplicationSet reloadConfig(long id, Clock clock) { - return reloadConfig(id, "default", clock); + private ApplicationSet reloadConfig(long sessionId, Clock clock) { + return reloadConfig(sessionId, "default", clock); } - private ApplicationSet reloadConfig(long id, String application, Clock clock) { - SessionZooKeeperClient zkc = new SessionZooKeeperClient(curator, configCurator, - TenantRepository.getSessionsPath(tenant).append(String.valueOf(id)), - new TestConfigDefinitionRepo(), - "", Optional.empty()); + private ApplicationSet reloadConfig(long sessionId, String application, Clock clock) { + SessionZooKeeperClient zkc = new SessionZooKeeperClient(curator, TenantRepository.getSessionsPath(tenant).append(String.valueOf(sessionId))); zkc.writeApplicationId(new ApplicationId.Builder().tenant(tenant).applicationName(application).build()); - RemoteSession session = new RemoteSession(tenant, id, componentRegistry, zkc, clock); + RemoteSession session = new RemoteSession(tenant, sessionId, componentRegistry, zkc, clock); return session.ensureApplicationLoaded(); } @@ -207,10 +204,7 @@ public class TenantRequestHandlerTest extends TestWithCurator { @Test public void testResolveForAppId() { long id = 1L; - SessionZooKeeperClient zkc = new SessionZooKeeperClient(curator, configCurator, - TenantRepository.getSessionsPath(tenant).append(String.valueOf(id)), - new TestConfigDefinitionRepo(), - "", Optional.empty()); + SessionZooKeeperClient zkc = new SessionZooKeeperClient(curator, TenantRepository.getSessionsPath(tenant).append(String.valueOf(id))); ApplicationId appId = new ApplicationId.Builder() .tenant(tenant) .applicationName("myapp").instanceName("myinst").build(); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java index 1975899355c..1b3afeb353b 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java @@ -4,7 +4,6 @@ package com.yahoo.vespa.config.server.tenant; import com.google.common.testing.EqualsTester; import com.yahoo.config.provision.TenantName; import com.yahoo.vespa.config.server.TestComponentRegistry; -import com.yahoo.vespa.config.server.TestWithCurator; import com.yahoo.vespa.config.server.application.MemoryTenantApplications; import org.junit.Before; import org.junit.Test; @@ -14,10 +13,9 @@ import static org.hamcrest.Matchers.is; import static org.junit.Assert.*; /** - * @author lulf - * @since 5.3 + * @author Ulf Lilleengen */ -public class TenantTest extends TestWithCurator { +public class TenantTest { private final TestComponentRegistry componentRegistry = new TestComponentRegistry.Builder().build(); private Tenant t1; diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TestWithTenant.java b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TestWithTenant.java deleted file mode 100644 index 67fb320d821..00000000000 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TestWithTenant.java +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.config.server.tenant; - -import com.yahoo.vespa.config.server.TestComponentRegistry; -import com.yahoo.vespa.config.server.TestWithCurator; -import org.junit.Before; - -/** - * Utility for a test using a single default tenant. - * - * @author lulf - * @since 5.35 - */ -public class TestWithTenant extends TestWithCurator { - - protected TenantRepository tenantRepository; - protected Tenant tenant; - - @Before - public void setupTenant() throws Exception { - tenantRepository = new TenantRepository(new TestComponentRegistry.Builder().curator(curator).build()); - tenant = tenantRepository.defaultTenant(); - } - -} diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/InitializedCounterTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/InitializedCounterTest.java index b444e09f558..888cbb7a68b 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/InitializedCounterTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/InitializedCounterTest.java @@ -1,11 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.config.server.zookeeper; -import com.yahoo.vespa.config.server.TestWithCurator; -import org.junit.Before; -import org.junit.Rule; +import com.yahoo.vespa.curator.mock.MockCurator; import org.junit.Test; -import org.junit.rules.TemporaryFolder; import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertThat; @@ -13,20 +10,15 @@ import static org.junit.Assert.assertThat; /** * @author Ulf Lilleengen */ -public class InitializedCounterTest extends TestWithCurator { +public class InitializedCounterTest { - @Rule - public TemporaryFolder folder = new TemporaryFolder(); - - @Before - public void setupZK() { + @Test + public void requireThatCounterIsInitializedFromNumberOfSessions() { + ConfigCurator configCurator = ConfigCurator.create(new MockCurator()); configCurator.createNode("/sessions"); configCurator.createNode("/sessions/1"); configCurator.createNode("/sessions/2"); - } - @Test - public void requireThatCounterIsInitializedFromNumberOfSessions() { InitializedCounter counter = new InitializedCounter(configCurator, "/counter", "/sessions"); assertThat(counter.counter.get(), is(2l)); } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackageTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackageTest.java index f0c74d19af9..06908dbab51 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackageTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackageTest.java @@ -21,14 +21,15 @@ import com.yahoo.config.provision.Version; import com.yahoo.config.provisioning.FlavorsConfig; import com.yahoo.path.Path; import com.yahoo.text.Utf8; -import com.yahoo.vespa.config.server.TestWithCurator; +import com.yahoo.vespa.curator.mock.MockCurator; +import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; import com.yahoo.io.IOUtils; -public class ZKApplicationPackageTest extends TestWithCurator { +public class ZKApplicationPackageTest { private static final String APP = "src/test/apps/zkapp"; private static final String TEST_FLAVOR_NAME = "test-flavor"; @@ -37,9 +38,16 @@ public class ZKApplicationPackageTest extends TestWithCurator { Collections.singleton(new HostSpec("foo.yahoo.com", Collections.emptyList(), TEST_FLAVOR, Optional.empty(), Optional.of(com.yahoo.component.Version.fromString("6.0.1"))))); + private ConfigCurator configCurator; + @Rule public TemporaryFolder tmpDir = new TemporaryFolder(); + @Before + public void setup() { + configCurator = ConfigCurator.create(new MockCurator()); + } + @Test public void testBasicZKFeed() throws IOException { feed(configCurator, new File(APP)); diff --git a/container-core/src/main/java/com/yahoo/container/handler/VipStatus.java b/container-core/src/main/java/com/yahoo/container/handler/VipStatus.java index bcd6e930ee3..d7457140dae 100644 --- a/container-core/src/main/java/com/yahoo/container/handler/VipStatus.java +++ b/container-core/src/main/java/com/yahoo/container/handler/VipStatus.java @@ -6,6 +6,7 @@ import java.util.Map; import com.google.inject.Inject; import com.yahoo.container.QrSearchersConfig; +import com.yahoo.container.core.VipStatusConfig; /** * API for programmatically removing the container from VIP rotation. @@ -15,15 +16,22 @@ import com.yahoo.container.QrSearchersConfig; public class VipStatus { private final Map<Object, Boolean> clusters = new IdentityHashMap<>(); + private final VipStatusConfig vipStatusConfig; public VipStatus() { - this(null); + this(null, new VipStatusConfig(new VipStatusConfig.Builder())); } - @Inject public VipStatus(QrSearchersConfig dispatchers) { + this(dispatchers, new VipStatusConfig(new VipStatusConfig.Builder())); + } + + // TODO: Why use QrSearchersConfig here? Remove and inject ComponentRegistry<ClusterSearcher> instead? + @Inject + public VipStatus(QrSearchersConfig dispatchers, VipStatusConfig vipStatusConfig) { // the config is not used for anything, it's just a dummy to create a // dependency link to which dispatchers are used + this.vipStatusConfig = vipStatusConfig; } /** @@ -55,14 +63,14 @@ public class VipStatus { /** * Tell whether the container is connected to any active services at all. * - * @return true if at least one service or cluster is up, or if no services + * @return true if at least one service or cluster is up, or value is taken from config if no services * are registered (yet) */ public boolean isInRotation() { synchronized (clusters) { - // if no stored state, try serving + // if no stored state, use config to decide whether to serve or not if (clusters.size() == 0) { - return true; + return vipStatusConfig.initiallyInRotation(); } for (Boolean inRotation : clusters.values()) { if (inRotation) { diff --git a/container-core/src/main/resources/configdefinitions/vip-status.def b/container-core/src/main/resources/configdefinitions/vip-status.def index 44da7292f05..1e364419ab8 100644 --- a/container-core/src/main/resources/configdefinitions/vip-status.def +++ b/container-core/src/main/resources/configdefinitions/vip-status.def @@ -6,9 +6,12 @@ namespace=container.core ## rotation, ignoring any status file. noSearchBackendsImpliesOutOfService bool default=true -## Whether to return hard coded reply or serve "status.html" from disk +## Whether to return hard-coded reply or serve "status.html" from disk accessdisk bool default=false ## The file to serve as the status file. -## If the paht is relative vespa home is prepended +## If the path is relative vespa home is prepended statusfile string default="share/qrsdocs/status.html" + +## The initial rotation state when no information is known about backend clusters +initiallyInRotation bool default=true diff --git a/container-dependency-versions/pom.xml b/container-dependency-versions/pom.xml index b4af6800768..f546a4e36d2 100644 --- a/container-dependency-versions/pom.xml +++ b/container-dependency-versions/pom.xml @@ -459,7 +459,7 @@ <properties> <bouncycastle.version>1.58</bouncycastle.version> - <felix.version>5.0.1</felix.version> + <felix.version>5.4.0</felix.version> <findbugs.version>1.3.9</findbugs.version> <guava.version>18.0</guava.version> <guice.version>3.0</guice.version> diff --git a/container-dev/pom.xml b/container-dev/pom.xml index 53153a05c4a..ff67e8db9fe 100644 --- a/container-dev/pom.xml +++ b/container-dev/pom.xml @@ -123,10 +123,6 @@ <version>${project.version}</version> <exclusions> <exclusion> - <groupId>org.ow2.asm</groupId> - <artifactId>asm</artifactId> - </exclusion> - <exclusion> <groupId>org.scala-lang</groupId> <artifactId>scala-library</artifactId> </exclusion> diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/metric/GarbageCollectionMetrics.java b/container-disc/src/main/java/com/yahoo/container/jdisc/metric/GarbageCollectionMetrics.java new file mode 100644 index 00000000000..04fd8572ad4 --- /dev/null +++ b/container-disc/src/main/java/com/yahoo/container/jdisc/metric/GarbageCollectionMetrics.java @@ -0,0 +1,94 @@ +package com.yahoo.container.jdisc.metric; + +import com.yahoo.jdisc.Metric; + +import java.lang.management.GarbageCollectorMXBean; +import java.lang.management.ManagementFactory; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.Map; + +/** + * @author ollivir + */ +public class GarbageCollectionMetrics { + private static final String GC_COUNT = "jdisc.gc.count"; + private static final String GC_TIME = "jdisc.gc.ms"; + private static final String DIMENSION_KEY = "gcName"; + + public static final Duration REPORTING_INTERVAL = Duration.ofSeconds(62); + + static class GcStats { + private final Instant when; + private final long count; + private final Duration totalRuntime; + + private GcStats(Instant when, long count, Duration totalRuntime) { + this.when = when; + this.count = count; + this.totalRuntime = totalRuntime; + } + } + + private Map<String, LinkedList<GcStats>> gcStatistics; + + private final Clock clock; + + public GarbageCollectionMetrics(Clock clock) { + this.clock = clock; + this.gcStatistics = new HashMap<>(); + collectGcStatistics(clock.instant()); + } + + private void collectGcStatistics(Instant now) { + for (GarbageCollectorMXBean gcBean : ManagementFactory.getGarbageCollectorMXBeans()) { + String gcName = gcBean.getName().replace(" ", ""); + GcStats stats = new GcStats(now, gcBean.getCollectionCount(), Duration.ofMillis(gcBean.getCollectionTime())); + + LinkedList<GcStats> window = gcStatistics.computeIfAbsent(gcName, anyName -> new LinkedList<>()); + window.addLast(stats); + } + } + + private void cleanStatistics(Instant now) { + Instant oldestToKeep = now.minus(REPORTING_INTERVAL); + + for(Iterator<Map.Entry<String, LinkedList<GcStats>>> it = gcStatistics.entrySet().iterator(); it.hasNext(); ) { + Map.Entry<String, LinkedList<GcStats>> entry = it.next(); + LinkedList<GcStats> history = entry.getValue(); + while(history.isEmpty() == false && oldestToKeep.isAfter(history.getFirst().when)) { + history.removeFirst(); + } + if(history.isEmpty()) { + it.remove(); + } + } + } + + public void emitMetrics(Metric metric) { + Instant now = clock.instant(); + + collectGcStatistics(now); + cleanStatistics(now); + + for (Map.Entry<String, LinkedList<GcStats>> item : gcStatistics.entrySet()) { + GcStats reference = item.getValue().getFirst(); + GcStats latest = item.getValue().getLast(); + Map<String, String> contextData = new HashMap<>(); + contextData.put(DIMENSION_KEY, item.getKey()); + Metric.Context gcContext = metric.createContext(contextData); + + metric.set(GC_COUNT, latest.count - reference.count, gcContext); + metric.set(GC_TIME, latest.totalRuntime.minus(reference.totalRuntime).toMillis(), gcContext); + } + } + + // partial exposure for testing + Map<String, LinkedList<GcStats>> getGcStatistics() { + return gcStatistics; + } +} diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java b/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java index 22b049c9ab7..c2ef789e8fc 100644 --- a/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java +++ b/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java @@ -10,6 +10,7 @@ import java.nio.file.DirectoryStream; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.time.Clock; import java.time.Duration; import java.util.Timer; import java.util.TimerTask; @@ -89,10 +90,12 @@ public class MetricUpdater extends AbstractComponent { private final Runtime runtime = Runtime.getRuntime(); private final Metric metric; private final ContainerWatchdogMetrics containerWatchdogMetrics; + private final GarbageCollectionMetrics garbageCollectionMetrics; public UpdaterTask(Metric metric, ContainerWatchdogMetrics containerWatchdogMetrics) { this.metric = metric; this.containerWatchdogMetrics = containerWatchdogMetrics; + this.garbageCollectionMetrics = new GarbageCollectionMetrics(Clock.systemUTC()); } @SuppressWarnings("deprecation") @@ -109,9 +112,10 @@ public class MetricUpdater extends AbstractComponent { metric.set(TOTAL_MEMORY_BYTES, totalMemory, null); metric.set(MEMORY_MAPPINGS_COUNT, count_mappings(), null); metric.set(OPEN_FILE_DESCRIPTORS, count_open_files(), null); + containerWatchdogMetrics.emitMetrics(metric); + garbageCollectionMetrics.emitMetrics(metric); } - } private static class TimerScheduler implements Scheduler { diff --git a/container-disc/src/test/java/com/yahoo/container/jdisc/metric/GarbageCollectionMetricsTest.java b/container-disc/src/test/java/com/yahoo/container/jdisc/metric/GarbageCollectionMetricsTest.java new file mode 100644 index 00000000000..61d8763b852 --- /dev/null +++ b/container-disc/src/test/java/com/yahoo/container/jdisc/metric/GarbageCollectionMetricsTest.java @@ -0,0 +1,57 @@ +package com.yahoo.container.jdisc.metric; + +import com.yahoo.jdisc.Metric; +import com.yahoo.test.ManualClock; +import org.junit.Test; + +import java.lang.management.ManagementFactory; +import java.time.Duration; +import java.util.LinkedList; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * @author ollivir + */ +public class GarbageCollectionMetricsTest { + @Test + public void gc_metrics_are_collected_in_a_sliding_window() { + ManualClock clock = new ManualClock(); + Metric metric = mock(Metric.class); + int garbageCollectors = ManagementFactory.getGarbageCollectorMXBeans().size(); + + Duration interval = GarbageCollectionMetrics.REPORTING_INTERVAL; + GarbageCollectionMetrics garbageCollectionMetrics = new GarbageCollectionMetrics(clock); + assertThat(garbageCollectionMetrics.getGcStatistics().keySet().size(), is(garbageCollectors)); + + clock.advance(interval.minus(Duration.ofMillis(10))); + garbageCollectionMetrics.emitMetrics(metric); + assertWindowLengths(garbageCollectionMetrics, 2); + + clock.advance(Duration.ofMillis(10)); + garbageCollectionMetrics.emitMetrics(metric); + assertWindowLengths(garbageCollectionMetrics, 3); + + clock.advance(Duration.ofMillis(10)); + garbageCollectionMetrics.emitMetrics(metric); + assertWindowLengths(garbageCollectionMetrics, 3); + + clock.advance(interval); + garbageCollectionMetrics.emitMetrics(metric); + assertWindowLengths(garbageCollectionMetrics, 2); + + verify(metric, times(garbageCollectors * 4 * 2)).set(anyString(), any(), any()); + } + + private static void assertWindowLengths(GarbageCollectionMetrics gcm, int count) { + for(LinkedList<GarbageCollectionMetrics.GcStats> window: gcm.getGcStatistics().values()) { + assertThat(window.size(), is(count)); + } + } +} diff --git a/container-disc/src/test/java/com/yahoo/container/jdisc/metric/MetricUpdaterTest.java b/container-disc/src/test/java/com/yahoo/container/jdisc/metric/MetricUpdaterTest.java index f10af7593a4..e9e04eab3b4 100644 --- a/container-disc/src/test/java/com/yahoo/container/jdisc/metric/MetricUpdaterTest.java +++ b/container-disc/src/test/java/com/yahoo/container/jdisc/metric/MetricUpdaterTest.java @@ -5,6 +5,7 @@ import com.yahoo.jdisc.Metric; import com.yahoo.jdisc.statistics.ContainerWatchdogMetrics; import org.junit.Test; +import java.lang.management.ManagementFactory; import java.time.Duration; import static org.mockito.Matchers.any; @@ -20,11 +21,13 @@ public class MetricUpdaterTest { @Test public void metrics_are_updated_in_scheduler_cycle() throws InterruptedException { + int gcCount = ManagementFactory.getGarbageCollectorMXBeans().size(); + Metric metric = mock(Metric.class); ContainerWatchdogMetrics containerWatchdogMetrics = mock(ContainerWatchdogMetrics.class); new MetricUpdater(new MockScheduler(), metric, containerWatchdogMetrics); verify(containerWatchdogMetrics, times(1)).emitMetrics(any()); - verify(metric, times(8)).set(anyString(), any(), any()); + verify(metric, times(8 + 2 * gcCount)).set(anyString(), any(), any()); } private static class MockScheduler implements MetricUpdater.Scheduler { diff --git a/container-jersey2/pom.xml b/container-jersey2/pom.xml index 26dfa762032..c5ed7d872bf 100644 --- a/container-jersey2/pom.xml +++ b/container-jersey2/pom.xml @@ -52,11 +52,6 @@ <dependency> <groupId>org.ow2.asm</groupId> <artifactId>asm</artifactId> - <version>5.0.3</version> - </dependency> - <dependency> - <groupId>org.scala-lang</groupId> - <artifactId>scala-library</artifactId> </dependency> </dependencies> <build> diff --git a/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/ComponentGraphProvider.java b/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/ComponentGraphProvider.java new file mode 100644 index 00000000000..7ff9646cb27 --- /dev/null +++ b/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/ComponentGraphProvider.java @@ -0,0 +1,73 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.container.servlet.jersey; + +import com.yahoo.container.di.config.ResolveDependencyException; +import com.yahoo.container.di.config.RestApiContext; +import com.yahoo.container.jaxrs.annotation.Component; +import org.glassfish.hk2.api.Injectee; +import org.glassfish.hk2.api.InjectionResolver; +import org.glassfish.hk2.api.ServiceHandle; + +import javax.inject.Singleton; + +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Resolves jdisc container components for jersey 2 components. + * + * @author Tony Vaagenes + * @author ollivir + */ +@Singleton // jersey2 requirement: InjectionResolvers must be in the Singleton scope +public class ComponentGraphProvider implements InjectionResolver<Component> { + private Collection<RestApiContext.Injectable> injectables; + + public ComponentGraphProvider(Collection<RestApiContext.Injectable> injectables) { + this.injectables = injectables; + } + + @Override + public Object resolve(Injectee injectee, ServiceHandle<?> root) { + Class<?> wantedClass; + Type type = injectee.getRequiredType(); + if (type instanceof Class) { + wantedClass = (Class<?>) type; + } else { + throw new UnsupportedOperationException("Only classes are supported, got " + type); + } + + List<RestApiContext.Injectable> componentsWithMatchingType = new ArrayList<>(); + for (RestApiContext.Injectable injectable : injectables) { + if (wantedClass.isInstance(injectable.instance)) { + componentsWithMatchingType.add(injectable); + } + } + + if (componentsWithMatchingType.size() == 1) { + return componentsWithMatchingType.get(0).instance; + } else { + String injectionDescription = "class '" + wantedClass + "' to inject into Jersey resource/provider '" + + injectee.getInjecteeClass() + "')"; + if (componentsWithMatchingType.size() > 1) { + String ids = componentsWithMatchingType.stream().map(c -> c.id.toString()).collect(Collectors.joining(",")); + throw new ResolveDependencyException("Multiple components found of " + injectionDescription + ": " + ids); + } else { + throw new ResolveDependencyException("Could not find a component of " + injectionDescription + "."); + } + } + } + + @Override + public boolean isMethodParameterIndicator() { + return true; + } + + @Override + public boolean isConstructorParameterIndicator() { + return true; + } +} diff --git a/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/JerseyApplication.java b/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/JerseyApplication.java new file mode 100644 index 00000000000..4c4e43bc8d5 --- /dev/null +++ b/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/JerseyApplication.java @@ -0,0 +1,25 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.container.servlet.jersey; + +import javax.ws.rs.core.Application; + +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; + +/** + * @author Tony Vaagenes + * @author ollivir + */ +public class JerseyApplication extends Application { + private Set<Class<?>> classes; + + public JerseyApplication(Collection<Class<?>> resourcesAndProviderClasses) { + this.classes = new HashSet<>(resourcesAndProviderClasses); + } + + @Override + public Set<Class<?>> getClasses() { + return classes; + } +} diff --git a/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/JerseyServletProvider.java b/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/JerseyServletProvider.java new file mode 100644 index 00000000000..1dbe410ba54 --- /dev/null +++ b/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/JerseyServletProvider.java @@ -0,0 +1,118 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.container.servlet.jersey; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import com.fasterxml.jackson.jaxrs.json.JacksonJaxbJsonProvider; +import com.yahoo.container.di.componentgraph.Provider; +import com.yahoo.container.di.config.RestApiContext; +import com.yahoo.container.di.config.RestApiContext.BundleInfo; +import com.yahoo.container.jaxrs.annotation.Component; +import org.eclipse.jetty.servlet.ServletHolder; +import org.glassfish.hk2.api.InjectionResolver; +import org.glassfish.hk2.api.TypeLiteral; +import org.glassfish.hk2.utilities.Binder; +import org.glassfish.hk2.utilities.binding.AbstractBinder; +import org.glassfish.jersey.media.multipart.MultiPartFeature; +import org.glassfish.jersey.server.ResourceConfig; +import org.glassfish.jersey.servlet.ServletContainer; +import org.objectweb.asm.ClassReader; + +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.container.servlet.jersey.util.ResourceConfigUtil.registerComponent; + +/** + * @author Tony Vaagenes + * @author ollivir + */ +public class JerseyServletProvider implements Provider<ServletHolder> { + private final ServletHolder jerseyServletHolder; + + public JerseyServletProvider(RestApiContext restApiContext) { + this.jerseyServletHolder = new ServletHolder(new ServletContainer(resourceConfig(restApiContext))); + } + + private ResourceConfig resourceConfig(RestApiContext restApiContext) { + final ResourceConfig resourceConfig = ResourceConfig + .forApplication(new JerseyApplication(resourcesAndProviders(restApiContext.getBundles()))); + + registerComponent(resourceConfig, componentInjectorBinder(restApiContext)); + registerComponent(resourceConfig, jacksonDatatypeJdk8Provider()); + resourceConfig.register(MultiPartFeature.class); + + return resourceConfig; + } + + private static Collection<Class<?>> resourcesAndProviders(Collection<BundleInfo> bundles) { + final List<Class<?>> ret = new ArrayList<>(); + + for (BundleInfo bundle : bundles) { + for (String classEntry : bundle.getClassEntries()) { + Optional<String> className = detectResourceOrProvider(bundle.classLoader, classEntry); + className.ifPresent(cname -> ret.add(loadClass(bundle.symbolicName, bundle.classLoader, cname))); + } + } + return ret; + } + + private static Optional<String> detectResourceOrProvider(ClassLoader bundleClassLoader, String classEntry) { + try (InputStream inputStream = getResourceAsStream(bundleClassLoader, classEntry)) { + ResourceOrProviderClassVisitor visitor = ResourceOrProviderClassVisitor.visit(new ClassReader(inputStream)); + return visitor.getJerseyClassName(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static InputStream getResourceAsStream(ClassLoader bundleClassLoader, String classEntry) { + InputStream is = bundleClassLoader.getResourceAsStream(classEntry); + if (is == null) { + throw new RuntimeException("No entry " + classEntry + " in bundle " + bundleClassLoader); + } else { + return is; + } + } + + private static Class<?> loadClass(String bundleSymbolicName, ClassLoader classLoader, String className) { + try { + return classLoader.loadClass(className); + } catch (Exception e) { + throw new RuntimeException("Failed loading class " + className + " from bundle " + bundleSymbolicName, e); + } + } + + private static Binder componentInjectorBinder(RestApiContext restApiContext) { + final ComponentGraphProvider componentGraphProvider = new ComponentGraphProvider(restApiContext.getInjectableComponents()); + final TypeLiteral<InjectionResolver<Component>> componentAnnotationType = new TypeLiteral<InjectionResolver<Component>>() { + }; + + return new AbstractBinder() { + @Override + public void configure() { + bind(componentGraphProvider).to(componentAnnotationType); + } + }; + } + + private static JacksonJaxbJsonProvider jacksonDatatypeJdk8Provider() { + JacksonJaxbJsonProvider provider = new JacksonJaxbJsonProvider(); + provider.setMapper(new ObjectMapper().registerModule(new Jdk8Module()).registerModule(new JavaTimeModule())); + return provider; + } + + @Override + public ServletHolder get() { + return jerseyServletHolder; + } + + @Override + public void deconstruct() { + } +} diff --git a/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/ResourceOrProviderClassVisitor.java b/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/ResourceOrProviderClassVisitor.java new file mode 100644 index 00000000000..7cb47ac6118 --- /dev/null +++ b/container-jersey2/src/main/java/com/yahoo/container/servlet/jersey/ResourceOrProviderClassVisitor.java @@ -0,0 +1,103 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.container.servlet.jersey; + +import org.objectweb.asm.AnnotationVisitor; +import org.objectweb.asm.ClassReader; +import org.objectweb.asm.ClassVisitor; +import org.objectweb.asm.Opcodes; +import org.objectweb.asm.Type; + +import javax.ws.rs.Path; +import javax.ws.rs.ext.Provider; + +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; + +/** + * @author Tony Vaagenes + * @author ollivir + */ +public class ResourceOrProviderClassVisitor extends ClassVisitor { + private String className = null; + private boolean isPublic = false; + private boolean isAbstract = false; + + private boolean isInnerClass = false; + private boolean isStatic = false; + + private boolean isAnnotated = false; + + public ResourceOrProviderClassVisitor() { + super(Opcodes.ASM6); + } + + public Optional<String> getJerseyClassName() { + if (isJerseyClass()) { + return Optional.of(getClassName()); + } else { + return Optional.empty(); + } + } + + public boolean isJerseyClass() { + return isAnnotated && isPublic && !isAbstract && (!isInnerClass || isStatic); + } + + public String getClassName() { + assert (className != null); + return org.objectweb.asm.Type.getObjectType(className).getClassName(); + } + + @Override + public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) { + isPublic = isPublic(access); + className = name; + isAbstract = isAbstract(access); + } + + @Override + public void visitInnerClass(String name, String outerName, String innerName, int access) { + assert (className != null); + + if (name.equals(className)) { + isInnerClass = true; + isStatic = isStatic(access); + } + } + + @Override + public AnnotationVisitor visitAnnotation(String desc, boolean visible) { + isAnnotated |= annotationClassDescriptors.contains(desc); + return null; + } + + private static Set<String> annotationClassDescriptors = new HashSet<>(); + + static { + annotationClassDescriptors.add(Type.getDescriptor(Path.class)); + annotationClassDescriptors.add(Type.getDescriptor(Provider.class)); + } + + private static boolean isPublic(int access) { + return isSet(Opcodes.ACC_PUBLIC, access); + } + + private static boolean isStatic(int access) { + return isSet(Opcodes.ACC_STATIC, access); + } + + private static boolean isAbstract(int access) { + return isSet(Opcodes.ACC_ABSTRACT, access); + } + + private static boolean isSet(int bits, int access) { + return (access & bits) == bits; + } + + public static ResourceOrProviderClassVisitor visit(ClassReader classReader) { + ResourceOrProviderClassVisitor visitor = new ResourceOrProviderClassVisitor(); + classReader.accept(visitor, ClassReader.SKIP_DEBUG | ClassReader.SKIP_CODE | ClassReader.SKIP_FRAMES); + return visitor; + } +} diff --git a/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/ComponentGraphProvider.scala b/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/ComponentGraphProvider.scala deleted file mode 100644 index cabde3680a4..00000000000 --- a/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/ComponentGraphProvider.scala +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.container.servlet.jersey - -import javax.inject.Singleton - -import com.yahoo.container.di.config.{ResolveDependencyException, RestApiContext} -import com.yahoo.container.jaxrs.annotation.Component -import org.glassfish.hk2.api.{ServiceHandle, Injectee, InjectionResolver} - -/** - * Resolves jdisc container components for jersey 2 components. - * Similar to Gjoran's ComponentGraphProvider for jersey 1. - * @author tonytv - */ -@Singleton //jersey2 requirement: InjectionResolvers must be in the Singleton scope -class ComponentGraphProvider(injectables: Traversable[RestApiContext.Injectable]) extends InjectionResolver[Component] { - override def resolve(injectee: Injectee, root: ServiceHandle[_]): AnyRef = { - val wantedClass = injectee.getRequiredType match { - case c: Class[_] => c - case unsupported => throw new UnsupportedOperationException("Only classes are supported, got " + unsupported) - } - - val componentsWithMatchingType = injectables.filter{ injectable => - wantedClass.isInstance(injectable.instance) } - - val injectionDescription = - s"class '$wantedClass' to inject into Jersey resource/provider '${injectee.getInjecteeClass}')" - - if (componentsWithMatchingType.size > 1) - throw new ResolveDependencyException(s"Multiple components found of $injectionDescription: " + - componentsWithMatchingType.map(_.id).mkString(",")) - - componentsWithMatchingType.headOption.map(_.instance).getOrElse { - throw new ResolveDependencyException(s"Could not find a component of $injectionDescription.") - } - } - - override def isMethodParameterIndicator: Boolean = true - override def isConstructorParameterIndicator: Boolean = true -} diff --git a/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/JerseyApplication.scala b/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/JerseyApplication.scala deleted file mode 100644 index eea41003984..00000000000 --- a/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/JerseyApplication.scala +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.container.servlet.jersey - -import javax.ws.rs.core.Application - -import scala.collection.JavaConverters._ - -/** - * @author tonytv - */ -class JerseyApplication(resourcesAndProviderClasses: Set[Class[_]]) extends Application { - private val classes: java.util.Set[Class[_]] = resourcesAndProviderClasses.asJava - - override def getClasses = classes - override def getSingletons = super.getSingletons -} diff --git a/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/JerseyServletProvider.scala b/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/JerseyServletProvider.scala deleted file mode 100644 index f0eff54dc16..00000000000 --- a/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/JerseyServletProvider.scala +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.container.servlet.jersey - -import java.io.{IOException, InputStream} - -import com.fasterxml.jackson.databind.ObjectMapper -import com.fasterxml.jackson.datatype.jdk8.Jdk8Module -import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule -import com.fasterxml.jackson.jaxrs.json.JacksonJaxbJsonProvider -import com.yahoo.container.di.componentgraph.Provider -import com.yahoo.container.di.config.RestApiContext -import com.yahoo.container.di.config.RestApiContext.BundleInfo -import com.yahoo.container.jaxrs.annotation.Component -import com.yahoo.container.servlet.jersey.util.ResourceConfigUtil.registerComponent -import org.eclipse.jetty.servlet.ServletHolder -import org.glassfish.hk2.api.{InjectionResolver, TypeLiteral} -import org.glassfish.hk2.utilities.Binder -import org.glassfish.hk2.utilities.binding.AbstractBinder -import org.glassfish.jersey.media.multipart.MultiPartFeature -import org.glassfish.jersey.server.ResourceConfig -import org.glassfish.jersey.servlet.ServletContainer -import org.objectweb.asm.ClassReader - -import scala.collection.JavaConverters._ -import scala.util.control.Exception - - -/** - * @author tonytv - */ -class JerseyServletProvider(restApiContext: RestApiContext) extends Provider[ServletHolder] { - private val jerseyServletHolder = new ServletHolder(new ServletContainer(resourceConfig(restApiContext))) - - private def resourceConfig(restApiContext: RestApiContext) = { - val resourceConfig = ResourceConfig.forApplication( - new JerseyApplication(resourcesAndProviders(restApiContext.getBundles.asScala))) - - registerComponent(resourceConfig, componentInjectorBinder(restApiContext)) - registerComponent(resourceConfig, jacksonDatatypeJdk8Provider) - resourceConfig.register(classOf[MultiPartFeature]) - - resourceConfig - } - - def resourcesAndProviders(bundles: Traversable[BundleInfo]) = - (for { - bundle <- bundles.view - classEntry <- bundle.getClassEntries.asScala - className <- detectResourceOrProvider(bundle.classLoader, classEntry) - } yield loadClass(bundle.symbolicName, bundle.classLoader, className)).toSet - - - def detectResourceOrProvider(bundleClassLoader: ClassLoader, classEntry: String): Option[String] = { - using(getResourceAsStream(bundleClassLoader, classEntry)) { inputStream => - val visitor = ResourceOrProviderClassVisitor.visit(new ClassReader(inputStream)) - visitor.getJerseyClassName - } - } - - private def getResourceAsStream(bundleClassLoader: ClassLoader, classEntry: String) = { - bundleClassLoader.getResourceAsStream(classEntry) match { - case null => throw new RuntimeException(s"No entry $classEntry in bundle $bundleClassLoader") - case stream => stream - } - - } - - def using[T <: InputStream, R](stream: T)(f: T => R): R = { - try { - f(stream) - } finally { - Exception.ignoring(classOf[IOException]) { - stream.close() - } - } - } - - def loadClass(bundleSymbolicName: String, classLoader: ClassLoader, className: String) = { - try { - classLoader.loadClass(className) - } catch { - case e: Exception => throw new RuntimeException(s"Failed loading class $className from bundle $bundleSymbolicName", e) - } - } - - def componentInjectorBinder(restApiContext: RestApiContext): Binder = { - val componentGraphProvider = new ComponentGraphProvider(restApiContext.getInjectableComponents.asScala) - val componentAnnotationType = new TypeLiteral[InjectionResolver[Component]] {} - - new AbstractBinder { - override def configure() { - bind(componentGraphProvider).to(componentAnnotationType) - } - } - } - - def jacksonDatatypeJdk8Provider: JacksonJaxbJsonProvider = { - val provider = new JacksonJaxbJsonProvider() - provider.setMapper( - new ObjectMapper() - .registerModule(new Jdk8Module) - .registerModule(new JavaTimeModule)) - provider - } - - override def get() = jerseyServletHolder - override def deconstruct() {} -} - diff --git a/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/ResourceOrProviderClassVisitor.scala b/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/ResourceOrProviderClassVisitor.scala deleted file mode 100644 index c015f11360e..00000000000 --- a/container-jersey2/src/main/scala/com/yahoo/container/servlet/jersey/ResourceOrProviderClassVisitor.scala +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.container.servlet.jersey - -import javax.ws.rs.Path -import javax.ws.rs.ext.Provider - -import org.objectweb.asm.{ClassVisitor, Opcodes, Type, AnnotationVisitor, ClassReader} - - -/** - * @author tonytv - */ -class ResourceOrProviderClassVisitor private () extends ClassVisitor(Opcodes.ASM5) { - private var className: String = null - private var isPublic: Boolean = false - private var isAbstract = false - - private var isInnerClass: Boolean = false - private var isStatic: Boolean = false - - private var isAnnotated: Boolean = false - - def getJerseyClassName: Option[String] = { - if (isJerseyClass) Some(getClassName) - else None - } - - def isJerseyClass: Boolean = { - isAnnotated && isPublic && !isAbstract && - (!isInnerClass || isStatic) - } - - def getClassName = { - assert (className != null) - Type.getObjectType(className).getClassName - } - - override def visit(version: Int, access: Int, name: String, signature: String, superName: String, interfaces: Array[String]) { - isPublic = ResourceOrProviderClassVisitor.isPublic(access) - className = name - isAbstract = ResourceOrProviderClassVisitor.isAbstract(access) - } - - override def visitInnerClass(name: String, outerName: String, innerName: String, access: Int) { - assert (className != null) - - if (name == className) { - isInnerClass = true - isStatic = ResourceOrProviderClassVisitor.isStatic(access) - } - } - - override def visitAnnotation(desc: String, visible: Boolean): AnnotationVisitor = { - isAnnotated |= ResourceOrProviderClassVisitor.annotationClassDescriptors(desc) - null - } -} - - -object ResourceOrProviderClassVisitor { - val annotationClassDescriptors = Set(classOf[Path], classOf[Provider]) map Type.getDescriptor - - def isPublic = isSet(Opcodes.ACC_PUBLIC) _ - def isStatic = isSet(Opcodes.ACC_STATIC) _ - def isAbstract = isSet(Opcodes.ACC_ABSTRACT) _ - - private def isSet(bits: Int)(access: Int): Boolean = (access & bits) == bits - - def visit(classReader: ClassReader): ResourceOrProviderClassVisitor = { - val visitor = new ResourceOrProviderClassVisitor - classReader.accept(visitor, ClassReader.SKIP_DEBUG | ClassReader.SKIP_CODE | ClassReader.SKIP_FRAMES) - visitor - } -} diff --git a/container/pom.xml b/container/pom.xml index d252a5eee4a..32a7947d6d5 100644 --- a/container/pom.xml +++ b/container/pom.xml @@ -21,6 +21,12 @@ <groupId>com.yahoo.vespa</groupId> <artifactId>container-dev</artifactId> <version>${project.version}</version> + <exclusions> + <exclusion> + <groupId>org.ow2.asm</groupId> + <artifactId>asm</artifactId> + </exclusion> + </exclusions> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/Organization.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/Organization.java index 00c0d87554a..776002f31cb 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/Organization.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/organization/Organization.java @@ -3,7 +3,6 @@ package com.yahoo.vespa.hosted.controller.api.integration.organization; import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; -import java.io.UncheckedIOException; import java.net.URI; import java.time.Duration; import java.util.List; @@ -87,8 +86,9 @@ public interface Organization { * * @param issueId ID of the issue to escalate. * @param propertyId PropertyId of the tenant owning the application for which the issue was filed. + * @return User that was assigned issue as a result of the escalation, if any */ - default boolean escalate(IssueId issueId, PropertyId propertyId) { + default Optional<User> escalate(IssueId issueId, PropertyId propertyId) { List<? extends List<? extends User>> contacts = contactsFor(propertyId); Optional<User> assignee = assigneeOf(issueId); @@ -101,9 +101,9 @@ public interface Organization { for (int level = assigneeLevel + 1; level < contacts.size(); level++) for (User target : contacts.get(level)) if (reassign(issueId, target)) - return true; + return Optional.of(target); - return false; + return Optional.empty(); } /** diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java index 295b1adbca9..295e0102782 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/Application.java @@ -10,6 +10,7 @@ import com.yahoo.config.provision.Environment; import com.yahoo.vespa.hosted.controller.api.integration.MetricsService.ApplicationMetrics; import com.yahoo.vespa.hosted.controller.api.integration.organization.IssueId; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; +import com.yahoo.vespa.hosted.controller.application.ApplicationActivity; import com.yahoo.vespa.hosted.controller.application.ApplicationRotation; import com.yahoo.vespa.hosted.controller.application.ApplicationVersion; import com.yahoo.vespa.hosted.controller.application.Change; @@ -142,14 +143,21 @@ public class Application { */ public Change outstandingChange() { return outstandingChange; } + /** Returns ID of the last ownership issue filed for this */ public Optional<IssueId> ownershipIssueId() { return ownershipIssueId; } + /** Returns metrics for this */ public ApplicationMetrics metrics() { return metrics; } + /** Returns activity for this */ + public ApplicationActivity activity() { + return ApplicationActivity.from(deployments.values()); + } + /** * Returns the oldest platform version this has deployed in a permanent zone (not test or staging). */ diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java index 8b0dc35e16b..f0e278c3e6d 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java @@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableList; import com.yahoo.component.Version; import com.yahoo.config.application.api.DeploymentSpec; import com.yahoo.config.application.api.ValidationId; +import com.yahoo.config.application.api.ValidationOverrides; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.Environment; import com.yahoo.config.provision.TenantName; @@ -109,7 +110,7 @@ public class ApplicationController { this.artifactRepository = artifactRepository; this.rotationRepository = new RotationRepository(rotationsConfig, this, curator); - this.deploymentTrigger = new DeploymentTrigger(controller, curator, buildService, clock); + this.deploymentTrigger = new DeploymentTrigger(controller, buildService, clock); for (Application application : curator.readApplications()) { lockIfPresent(application.id(), this::store); @@ -256,7 +257,7 @@ public class ApplicationController { LockedApplication application = new LockedApplication(new Application(id), lock); store(application); log.info("Created " + application); - return application; + return application.get(); } } @@ -285,7 +286,7 @@ public class ApplicationController { } else { JobType jobType = JobType.from(controller.system(), zone) .orElseThrow(() -> new IllegalArgumentException("No job found for zone " + zone)); - Optional<JobStatus> job = Optional.ofNullable(application.deploymentJobs().jobStatus().get(jobType)); + Optional<JobStatus> job = Optional.ofNullable(application.get().deploymentJobs().jobStatus().get(jobType)); if ( ! job.isPresent() || ! job.get().lastTriggered().isPresent() || job.get().lastCompleted().isPresent() && job.get().lastCompleted().get().at().isAfter(job.get().lastTriggered().get().at())) @@ -297,8 +298,8 @@ public class ApplicationController { applicationVersion = preferOldestVersion ? triggered.sourceApplication().orElse(triggered.application()) : triggered.application(); - applicationPackage = new ApplicationPackage(artifactRepository.getApplicationPackage(application.id(), applicationVersion.id())); - validateRun(application, zone, platformVersion, applicationVersion); + applicationPackage = new ApplicationPackage(artifactRepository.getApplicationPackage(application.get().id(), applicationVersion.id())); + validateRun(application.get(), zone, platformVersion, applicationVersion); } validate(applicationPackage.deploymentSpec()); @@ -323,7 +324,7 @@ public class ApplicationController { application = withRotation(application, zone); Set<String> rotationNames = new HashSet<>(); Set<String> cnames = new HashSet<>(); - application.rotation().ifPresent(applicationRotation -> { + application.get().rotation().ifPresent(applicationRotation -> { rotationNames.add(applicationRotation.id().asString()); cnames.add(applicationRotation.dnsName()); cnames.add(applicationRotation.secureDnsName()); @@ -366,15 +367,15 @@ public class ApplicationController { /** Makes sure the application has a global rotation, if eligible. */ private LockedApplication withRotation(LockedApplication application, ZoneId zone) { - if (zone.environment() == Environment.prod && application.deploymentSpec().globalServiceId().isPresent()) { + if (zone.environment() == Environment.prod && application.get().deploymentSpec().globalServiceId().isPresent()) { try (RotationLock rotationLock = rotationRepository.lock()) { - Rotation rotation = rotationRepository.getOrAssignRotation(application, rotationLock); + Rotation rotation = rotationRepository.getOrAssignRotation(application.get(), rotationLock); application = application.with(rotation.id()); store(application); // store assigned rotation even if deployment fails - registerRotationInDns(rotation, application.rotation().get().dnsName()); - registerRotationInDns(rotation, application.rotation().get().secureDnsName()); - registerRotationInDns(rotation, application.rotation().get().oathDnsName()); + registerRotationInDns(rotation, application.get().rotation().get().dnsName()); + registerRotationInDns(rotation, application.get().rotation().get().secureDnsName()); + registerRotationInDns(rotation, application.get().rotation().get().oathDnsName()); } } return application; @@ -394,22 +395,23 @@ public class ApplicationController { } private LockedApplication deleteRemovedDeployments(LockedApplication application) { - List<Deployment> deploymentsToRemove = application.productionDeployments().values().stream() - .filter(deployment -> ! application.deploymentSpec().includes(deployment.zone().environment(), - Optional.of(deployment.zone().region()))) + List<Deployment> deploymentsToRemove = application.get().productionDeployments().values().stream() + .filter(deployment -> ! application.get().deploymentSpec().includes(deployment.zone().environment(), + Optional.of(deployment.zone().region()))) .collect(Collectors.toList()); if (deploymentsToRemove.isEmpty()) return application; - if ( ! application.validationOverrides().allows(ValidationId.deploymentRemoval, clock.instant())) - throw new IllegalArgumentException(ValidationId.deploymentRemoval.value() + ": " + application + + if ( ! application.get().validationOverrides().allows(ValidationId.deploymentRemoval, clock.instant())) + throw new IllegalArgumentException(ValidationId.deploymentRemoval.value() + ": " + application.get() + " is deployed in " + deploymentsToRemove.stream() .map(deployment -> deployment.zone().region().value()) .collect(Collectors.joining(", ")) + ", but does not include " + (deploymentsToRemove.size() > 1 ? "these zones" : "this zone") + - " in deployment.xml"); + " in deployment.xml. " + + ValidationOverrides.toAllowMessage(ValidationId.deploymentRemoval)); LockedApplication applicationWithRemoval = application; for (Deployment deployment : deploymentsToRemove) @@ -418,10 +420,11 @@ public class ApplicationController { } private LockedApplication deleteUnreferencedDeploymentJobs(LockedApplication application) { - for (JobType job : application.deploymentJobs().jobStatus().keySet()) { + for (JobType job : application.get().deploymentJobs().jobStatus().keySet()) { Optional<ZoneId> zone = job.zone(controller.system()); - if ( ! job.isProduction() || (zone.isPresent() && application.deploymentSpec().includes(zone.get().environment(), zone.map(ZoneId::region)))) + if ( ! job.isProduction() || (zone.isPresent() && application.get().deploymentSpec().includes( + zone.get().environment(), zone.map(ZoneId::region)))) continue; application = application.withoutDeploymentJob(job); } @@ -493,7 +496,7 @@ public class ApplicationController { // TODO: Make this one transaction when database is moved to ZooKeeper instances.forEach(id -> lockOrThrow(id, application -> { - if ( ! application.deployments().isEmpty()) + if ( ! application.get().deployments().isEmpty()) throw new IllegalArgumentException("Could not delete '" + application + "': It has active deployments"); Tenant tenant = controller.tenants().tenant(id.tenant()).get(); @@ -518,7 +521,7 @@ public class ApplicationController { * @param application a locked application to store */ public void store(LockedApplication application) { - curator.writeApplication(application); + curator.writeApplication(application.get()); } /** @@ -572,7 +575,7 @@ public class ApplicationController { */ private LockedApplication deactivate(LockedApplication application, ZoneId zone) { try { - configServer.deactivate(new DeploymentId(application.id(), zone)); + configServer.deactivate(new DeploymentId(application.get().id(), zone)); } catch (NoInstanceException ignored) { // ok; already gone diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java index 913adf06f22..3207d4b8399 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedApplication.java @@ -26,17 +26,29 @@ import com.yahoo.vespa.hosted.controller.rotation.RotationId; import java.time.Instant; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.OptionalLong; /** - * A combination of an application instance and a lock for that application. Provides methods for updating application - * fields. + * An application that has been locked for modification. Provides methods for modifying an application's fields. * * @author mpolden * @author jvenstad */ -public class LockedApplication extends Application { +public class LockedApplication { + + private final Lock lock; + private final ApplicationId id; + private final DeploymentSpec deploymentSpec; + private final ValidationOverrides validationOverrides; + private final Map<ZoneId, Deployment> deployments; + private final DeploymentJobs deploymentJobs; + private final Change change; + private final Change outstandingChange; + private final Optional<IssueId> ownershipIssueId; + private final ApplicationMetrics metrics; + private final Optional<RotationId> rotation; /** * Used to create a locked application @@ -44,180 +56,172 @@ public class LockedApplication extends Application { * @param application The application to lock. * @param lock The lock for the application. */ - LockedApplication(Application application, @SuppressWarnings("unused") Lock lock) { - this(new Builder(application)); - } - - private LockedApplication(Builder builder) { - super(builder.applicationId, builder.deploymentSpec, builder.validationOverrides, - builder.deployments, builder.deploymentJobs, builder.deploying, - builder.outstandingChange, builder.ownershipIssueId, builder.metrics, builder.rotation); + LockedApplication(Application application, Lock lock) { + this(Objects.requireNonNull(lock, "lock cannot be null"), application.id(), + application.deploymentSpec(), application.validationOverrides(), + application.deployments(), + application.deploymentJobs(), application.change(), application.outstandingChange(), + application.ownershipIssueId(), application.metrics(), + application.rotation().map(ApplicationRotation::id)); + } + + private LockedApplication(Lock lock, ApplicationId id, + DeploymentSpec deploymentSpec, ValidationOverrides validationOverrides, + Map<ZoneId, Deployment> deployments, DeploymentJobs deploymentJobs, Change change, + Change outstandingChange, Optional<IssueId> ownershipIssueId, ApplicationMetrics metrics, + Optional<RotationId> rotation) { + this.lock = lock; + this.id = id; + this.deploymentSpec = deploymentSpec; + this.validationOverrides = validationOverrides; + this.deployments = deployments; + this.deploymentJobs = deploymentJobs; + this.change = change; + this.outstandingChange = outstandingChange; + this.ownershipIssueId = ownershipIssueId; + this.metrics = metrics; + this.rotation = rotation; + } + + /** Returns a read-only copy of this */ + public Application get() { + return new Application(id, deploymentSpec, validationOverrides, deployments, deploymentJobs, change, + outstandingChange, ownershipIssueId, metrics, rotation); } public LockedApplication withProjectId(OptionalLong projectId) { - return new LockedApplication(new Builder(this).with(deploymentJobs().withProjectId(projectId))); + return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments, + deploymentJobs.withProjectId(projectId), change, outstandingChange, + ownershipIssueId, metrics, rotation); } public LockedApplication withDeploymentIssueId(IssueId issueId) { - return new LockedApplication(new Builder(this).with(deploymentJobs().with(issueId))); + return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments, + deploymentJobs.with(issueId), change, outstandingChange, + ownershipIssueId, metrics, rotation); } - public LockedApplication withJobCompletion(long projectId, JobType jobType, JobStatus.JobRun completion, Optional<DeploymentJobs.JobError> jobError) { - return new LockedApplication(new Builder(this).with(deploymentJobs().withCompletion(projectId, jobType, completion, jobError)) - ); + public LockedApplication withJobCompletion(long projectId, JobType jobType, JobStatus.JobRun completion, + Optional<DeploymentJobs.JobError> jobError) { + return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments, + deploymentJobs.withCompletion(projectId, jobType, completion, jobError), + change, outstandingChange, ownershipIssueId, metrics, rotation); } public LockedApplication withJobTriggering(JobType jobType, JobStatus.JobRun job) { - return new LockedApplication(new Builder(this).with(deploymentJobs().withTriggering(jobType, job))); + return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments, + deploymentJobs.withTriggering(jobType, job), change, outstandingChange, + ownershipIssueId, metrics, rotation); } public LockedApplication withNewDeployment(ZoneId zone, ApplicationVersion applicationVersion, Version version, Instant instant) { // Use info from previous deployment if available, otherwise create a new one. - Deployment previousDeployment = deployments().getOrDefault(zone, new Deployment(zone, applicationVersion, - version, instant)); + Deployment previousDeployment = deployments.getOrDefault(zone, new Deployment(zone, applicationVersion, + version, instant)); Deployment newDeployment = new Deployment(zone, applicationVersion, version, instant, previousDeployment.clusterUtils(), previousDeployment.clusterInfo(), - previousDeployment.metrics()); + previousDeployment.metrics(), + previousDeployment.activity()); return with(newDeployment); } public LockedApplication withClusterUtilization(ZoneId zone, Map<ClusterSpec.Id, ClusterUtilization> clusterUtilization) { - Deployment deployment = deployments().get(zone); + Deployment deployment = deployments.get(zone); if (deployment == null) return this; // No longer deployed in this zone. return with(deployment.withClusterUtils(clusterUtilization)); } public LockedApplication withClusterInfo(ZoneId zone, Map<ClusterSpec.Id, ClusterInfo> clusterInfo) { - Deployment deployment = deployments().get(zone); + Deployment deployment = deployments.get(zone); if (deployment == null) return this; // No longer deployed in this zone. return with(deployment.withClusterInfo(clusterInfo)); } + public LockedApplication recordActivityAt(Instant instant, ZoneId zone) { + Deployment deployment = deployments.get(zone); + if (deployment == null) return this; + return with(deployment.recordActivityAt(instant)); + } + public LockedApplication with(ZoneId zone, DeploymentMetrics deploymentMetrics) { - Deployment deployment = deployments().get(zone); + Deployment deployment = deployments.get(zone); if (deployment == null) return this; // No longer deployed in this zone. return with(deployment.withMetrics(deploymentMetrics)); } public LockedApplication withoutDeploymentIn(ZoneId zone) { - Map<ZoneId, Deployment> deployments = new LinkedHashMap<>(deployments()); + Map<ZoneId, Deployment> deployments = new LinkedHashMap<>(this.deployments); deployments.remove(zone); - return new LockedApplication(new Builder(this).with(deployments)); + return with(deployments); } public LockedApplication withoutDeploymentJob(DeploymentJobs.JobType jobType) { - return new LockedApplication(new Builder(this).with(deploymentJobs().without(jobType))); + return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments, + deploymentJobs.without(jobType), change, outstandingChange, + ownershipIssueId, metrics, rotation); } public LockedApplication with(DeploymentSpec deploymentSpec) { - return new LockedApplication(new Builder(this).with(deploymentSpec)); + return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments, + deploymentJobs, change, outstandingChange, + ownershipIssueId, metrics, rotation); } public LockedApplication with(ValidationOverrides validationOverrides) { - return new LockedApplication(new Builder(this).with(validationOverrides)); + return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments, + deploymentJobs, change, outstandingChange, + ownershipIssueId, metrics, rotation); } public LockedApplication withChange(Change change) { - return new LockedApplication(new Builder(this).withChange(change)); + return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments, + deploymentJobs, change, outstandingChange, + ownershipIssueId, metrics, rotation); } public LockedApplication withOutstandingChange(Change outstandingChange) { - return new LockedApplication(new Builder(this).withOutstandingChange(outstandingChange)); + return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments, + deploymentJobs, change, outstandingChange, + ownershipIssueId, metrics, rotation); } public LockedApplication withOwnershipIssueId(IssueId issueId) { - return new LockedApplication(new Builder(this).withOwnershipIssueId(Optional.ofNullable(issueId))); + return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments, + deploymentJobs, change, outstandingChange, + Optional.ofNullable(issueId), metrics, rotation); } public LockedApplication with(MetricsService.ApplicationMetrics metrics) { - return new LockedApplication(new Builder(this).with(metrics)); + return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments, + deploymentJobs, change, outstandingChange, + ownershipIssueId, metrics, rotation); } public LockedApplication with(RotationId rotation) { - return new LockedApplication(new Builder(this).with(rotation)); + return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments, + deploymentJobs, change, outstandingChange, + ownershipIssueId, metrics, Optional.of(rotation)); } /** Don't expose non-leaf sub-objects. */ private LockedApplication with(Deployment deployment) { - Map<ZoneId, Deployment> deployments = new LinkedHashMap<>(deployments()); + Map<ZoneId, Deployment> deployments = new LinkedHashMap<>(this.deployments); deployments.put(deployment.zone(), deployment); - return new LockedApplication(new Builder(this).with(deployments)); - } - - private static class Builder { - - private final ApplicationId applicationId; - private DeploymentSpec deploymentSpec; - private ValidationOverrides validationOverrides; - private Map<ZoneId, Deployment> deployments; - private DeploymentJobs deploymentJobs; - private Change deploying; - private Change outstandingChange; - private Optional<IssueId> ownershipIssueId; - private ApplicationMetrics metrics; - private Optional<RotationId> rotation; - - private Builder(Application application) { - this.applicationId = application.id(); - this.deploymentSpec = application.deploymentSpec(); - this.validationOverrides = application.validationOverrides(); - this.deployments = application.deployments(); - this.deploymentJobs = application.deploymentJobs(); - this.deploying = application.change(); - this.outstandingChange = application.outstandingChange(); - this.ownershipIssueId = application.ownershipIssueId(); - this.metrics = application.metrics(); - this.rotation = application.rotation().map(ApplicationRotation::id); - } - - private Builder with(DeploymentSpec deploymentSpec) { - this.deploymentSpec = deploymentSpec; - return this; - } - - private Builder with(ValidationOverrides validationOverrides) { - this.validationOverrides = validationOverrides; - return this; - } - - private Builder with(Map<ZoneId, Deployment> deployments) { - this.deployments = deployments; - return this; - } - - private Builder with(DeploymentJobs deploymentJobs) { - this.deploymentJobs = deploymentJobs; - return this; - } - - private Builder withChange(Change deploying) { - this.deploying = deploying; - return this; - } - - private Builder withOutstandingChange(Change outstandingChange) { - this.outstandingChange = outstandingChange; - return this; - } - - private Builder withOwnershipIssueId(Optional<IssueId> ownershipIssueId) { - this.ownershipIssueId = ownershipIssueId; - return this; - } - - private Builder with(ApplicationMetrics metrics) { - this.metrics = metrics; - return this; - } - - private Builder with(RotationId rotation) { - this.rotation = Optional.of(rotation); - return this; - } + return with(deployments); + } + + private LockedApplication with(Map<ZoneId, Deployment> deployments) { + return new LockedApplication(lock, id, deploymentSpec, validationOverrides, deployments, + deploymentJobs, change, outstandingChange, + ownershipIssueId, metrics, rotation); + } + @Override + public String toString() { + return "application '" + id + "'"; } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationActivity.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationActivity.java new file mode 100644 index 00000000000..ddd519382a6 --- /dev/null +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationActivity.java @@ -0,0 +1,56 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.application; + +import java.time.Instant; +import java.util.Collection; +import java.util.Comparator; +import java.util.Optional; +import java.util.function.Function; + +/** + * Recent activity in an application. + * + * @author mpolden + */ +public class ApplicationActivity { + + public static final ApplicationActivity none = new ApplicationActivity(Optional.empty(), Optional.empty()); + + private final Optional<Instant> lastQueried; + private final Optional<Instant> lastWritten; + + private ApplicationActivity(Optional<Instant> lastQueried, Optional<Instant> lastWritten) { + this.lastQueried = lastQueried; + this.lastWritten = lastWritten; + } + + /** The last time any deployment in this was queried */ + public Optional<Instant> lastQueried() { + return lastQueried; + } + + /** The last time any deployment in this was written */ + public Optional<Instant> lastWritten() { + return lastWritten; + } + + public static ApplicationActivity from(Collection<Deployment> deployments) { + Optional<Instant> lastQueried = lastActivity(deployments, DeploymentActivity::lastQueried); + Optional<Instant> lastWritten = lastActivity(deployments, DeploymentActivity::lastWritten); + if (!lastQueried.isPresent() && !lastWritten.isPresent()) { + return none; + } + return new ApplicationActivity(lastQueried, lastWritten); + } + + private static Optional<Instant> lastActivity(Collection<Deployment> deployments, + Function<DeploymentActivity, Optional<Instant>> activityField) { + return deployments.stream() + .map(Deployment::activity) + .map(activityField) + .filter(Optional::isPresent) + .map(Optional::get) + .max(Comparator.naturalOrder()); + } + +} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationPackage.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationPackage.java index 6df8e901653..40e2e4a92d1 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationPackage.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/ApplicationPackage.java @@ -34,9 +34,8 @@ public class ApplicationPackage { * it must not be further changed by the caller. */ public ApplicationPackage(byte[] zippedContent) { - Objects.requireNonNull(zippedContent, "The application package content cannot be null"); + this.zippedContent = Objects.requireNonNull(zippedContent, "The application package content cannot be null"); this.contentHash = DigestUtils.shaHex(zippedContent); - this.zippedContent = zippedContent; this.deploymentSpec = extractFile("deployment.xml", zippedContent).map(DeploymentSpec::fromXml).orElse(DeploymentSpec.empty); this.validationOverrides = extractFile("validation-overrides.xml", zippedContent).map(ValidationOverrides::fromXml).orElse(ValidationOverrides.empty); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java index 8fa0c6da49c..0a062427a8a 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Deployment.java @@ -6,6 +6,7 @@ import com.yahoo.config.provision.ClusterSpec.Id; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import java.time.Instant; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -25,27 +26,25 @@ public class Deployment { private final Map<Id, ClusterUtilization> clusterUtils; private final Map<Id, ClusterInfo> clusterInfo; private final DeploymentMetrics metrics; + private final DeploymentActivity activity; public Deployment(ZoneId zone, ApplicationVersion applicationVersion, Version version, Instant deployTime) { - this(zone, applicationVersion, version, deployTime, new HashMap<>(), new HashMap<>(), new DeploymentMetrics()); + this(zone, applicationVersion, version, deployTime, Collections.emptyMap(), Collections.emptyMap(), + new DeploymentMetrics(), DeploymentActivity.none); } public Deployment(ZoneId zone, ApplicationVersion applicationVersion, Version version, Instant deployTime, - Map<Id, ClusterUtilization> clusterUtils, Map<Id, ClusterInfo> clusterInfo, DeploymentMetrics metrics) { - Objects.requireNonNull(zone, "zone cannot be null"); - Objects.requireNonNull(applicationVersion, "applicationVersion cannot be null"); - Objects.requireNonNull(version, "version cannot be null"); - Objects.requireNonNull(deployTime, "deployTime cannot be null"); - Objects.requireNonNull(clusterUtils, "clusterUtils cannot be null"); - Objects.requireNonNull(clusterInfo, "clusterInfo cannot be null"); - Objects.requireNonNull(metrics, "deployment metrics cannot be null"); - this.zone = zone; - this.applicationVersion = applicationVersion; - this.version = version; - this.deployTime = deployTime; - this.clusterUtils = clusterUtils; - this.clusterInfo = clusterInfo; - this.metrics = metrics; + Map<Id, ClusterUtilization> clusterUtils, Map<Id, ClusterInfo> clusterInfo, + DeploymentMetrics metrics, + DeploymentActivity activity) { + this.zone = Objects.requireNonNull(zone, "zone cannot be null"); + this.applicationVersion = Objects.requireNonNull(applicationVersion, "applicationVersion cannot be null"); + this.version = Objects.requireNonNull(version, "version cannot be null"); + this.deployTime = Objects.requireNonNull(deployTime, "deployTime cannot be null"); + this.clusterUtils = Objects.requireNonNull(clusterUtils, "clusterUtils cannot be null"); + this.clusterInfo = Objects.requireNonNull(clusterInfo, "clusterInfo cannot be null"); + this.metrics = Objects.requireNonNull(metrics, "deploymentMetrics cannot be null"); + this.activity = Objects.requireNonNull(activity, "activity cannot be null"); } /** Returns the zone this was deployed to */ @@ -60,29 +59,42 @@ public class Deployment { /** Returns the time this was deployed */ public Instant at() { return deployTime; } + /** Returns metrics for this */ + public DeploymentMetrics metrics() { + return metrics; + } + + /** Returns activity for this */ + public DeploymentActivity activity() { return activity; } + + /** Returns information about the clusters allocated to this */ public Map<Id, ClusterInfo> clusterInfo() { return clusterInfo; } + /** Returns utilization of the clusters allocated to this */ public Map<Id, ClusterUtilization> clusterUtils() { return clusterUtils; } + public Deployment recordActivityAt(Instant instant) { + return new Deployment(zone, applicationVersion, version, deployTime, clusterUtils, clusterInfo, metrics, + activity.recordAt(instant, metrics)); + } + public Deployment withClusterUtils(Map<Id, ClusterUtilization> clusterUtilization) { - return new Deployment(zone, applicationVersion, version, deployTime, clusterUtilization, clusterInfo, metrics); + return new Deployment(zone, applicationVersion, version, deployTime, clusterUtilization, clusterInfo, metrics, + activity); } public Deployment withClusterInfo(Map<Id, ClusterInfo> newClusterInfo) { - return new Deployment(zone, applicationVersion, version, deployTime, clusterUtils, newClusterInfo, metrics); + return new Deployment(zone, applicationVersion, version, deployTime, clusterUtils, newClusterInfo, metrics, + activity); } public Deployment withMetrics(DeploymentMetrics metrics) { - return new Deployment(zone, applicationVersion, version, deployTime, clusterUtils, clusterInfo, metrics); - } - - /** @return Key metrics for the deployment (application level) like QPS and document count */ - public DeploymentMetrics metrics() { - return metrics; + return new Deployment(zone, applicationVersion, version, deployTime, clusterUtils, clusterInfo, metrics, + activity); } /** @@ -109,4 +121,5 @@ public class Deployment { public String toString() { return "deployment to " + zone + " of " + applicationVersion + " on version " + version + " at " + deployTime; } + } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentActivity.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentActivity.java new file mode 100644 index 00000000000..d4635212e80 --- /dev/null +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/DeploymentActivity.java @@ -0,0 +1,55 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.application; + +import java.time.Instant; +import java.util.Objects; +import java.util.Optional; + +/** + * Recent activity in a deployment. + * + * @author mpolden + */ +public class DeploymentActivity { + + /** Query rates at or below this threshold indicate inactivity */ + private static final double inactivityThreshold = 0; + + public static final DeploymentActivity none = new DeploymentActivity(Optional.empty(), Optional.empty()); + + private final Optional<Instant> lastQueried; + private final Optional<Instant> lastWritten; + + private DeploymentActivity(Optional<Instant> lastQueried, Optional<Instant> lastWritten) { + this.lastQueried = Objects.requireNonNull(lastQueried, "lastQueried must be non-null"); + this.lastWritten = Objects.requireNonNull(lastWritten, "lastWritten must be non-null"); + } + + /** The last time this deployment received queries (search) */ + public Optional<Instant> lastQueried() { + return lastQueried; + } + + /** The last time this deployment received writes (feed) */ + public Optional<Instant> lastWritten() { + return lastWritten; + } + + /** Record activity using given metrics */ + public DeploymentActivity recordAt(Instant instant, DeploymentMetrics metrics) { + return new DeploymentActivity(activityAt(instant, lastQueried, metrics.queriesPerSecond()), + activityAt(instant, lastWritten, metrics.writesPerSecond())); + } + + public static DeploymentActivity create(Optional<Instant> queriedAt, Optional<Instant> writtenAt) { + if (!queriedAt.isPresent() && !writtenAt.isPresent()) { + return none; + } + return new DeploymentActivity(queriedAt, writtenAt); + } + + private static Optional<Instant> activityAt(Instant newInstant, Optional<Instant> oldInstant, double rate) { + return rate > inactivityThreshold ? Optional.of(newInstant) : oldInstant; + } + +} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/concurrent/Locks.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/concurrent/Locks.java deleted file mode 100644 index 6168812203a..00000000000 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/concurrent/Locks.java +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.controller.concurrent; - -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.locks.ReentrantLock; - -/** - * Holds a map of locks indexed on keys of a given type. - * This is suitable in cases where exclusive access should be granted to any one of a set of keyed objects and - * there is a finite collection of keyed objects. - * - * The returned locks are reentrant (i.e the owning thread may call lock multiple times) and auto-closable. - * - * Typical use is - * <code> - * try (Lock lock = locks.lock(id)) { - * exclusive use of the object with key id - * } - * </code> - * - * @author bratseth - */ -public class Locks<TYPE> { - - private final Map<TYPE, ReentrantLock> locks = new ConcurrentHashMap<>(); - - private final long timeoutMs; - - public Locks(int timeout, TimeUnit timeoutUnit) { - timeoutMs = timeoutUnit.toMillis(timeout); - } - - /** - * Locks key. This will block until the key is acquired. - * Users of this <b>must</b> close any lock acquired. - * - * @param key the key to lock - * @return the acquired lock - * @throws TimeoutException if the lock could not be acquired within the timeout - */ - public Lock lock(TYPE key) { - try { - ReentrantLock lock = locks.computeIfAbsent(key, k -> new ReentrantLock(true)); - boolean acquired = lock.tryLock(timeoutMs, TimeUnit.MILLISECONDS); - if ( ! acquired) - throw new TimeoutException("Timed out waiting for the lock to " + key); - return new Lock(lock); - } catch (InterruptedException e) { - throw new RuntimeException("Interrupted while waiting for lock of " + key); - } - } - -} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentOrder.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentOrder.java index 405c8d17263..1c535a5a331 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentOrder.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentOrder.java @@ -3,7 +3,6 @@ package com.yahoo.vespa.hosted.controller.deployment; import com.yahoo.config.application.api.DeploymentSpec; import com.yahoo.config.provision.SystemName; -import com.yahoo.vespa.hosted.controller.Controller; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import com.yahoo.vespa.hosted.controller.application.Deployment; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs; @@ -30,8 +29,7 @@ public class DeploymentOrder { private final Supplier<SystemName> system; public DeploymentOrder(Supplier<SystemName> system) { - Objects.requireNonNull(system, "system may not be null"); - this.system = system; + this.system = Objects.requireNonNull(system, "system may not be null"); } /** Returns jobs for given deployment spec, in the order they are declared */ @@ -46,25 +44,25 @@ public class DeploymentOrder { public List<JobStatus> sortBy(DeploymentSpec deploymentSpec, Collection<JobStatus> jobStatus) { List<DeploymentJobs.JobType> sortedJobs = jobsFrom(deploymentSpec); return jobStatus.stream() - .sorted(comparingInt(job -> sortedJobs.indexOf(job.type()))) - .collect(collectingAndThen(toList(), Collections::unmodifiableList)); + .sorted(comparingInt(job -> sortedJobs.indexOf(job.type()))) + .collect(collectingAndThen(toList(), Collections::unmodifiableList)); } /** Returns deployments sorted according to declared zones */ public List<Deployment> sortBy(List<DeploymentSpec.DeclaredZone> zones, Collection<Deployment> deployments) { List<ZoneId> productionZones = zones.stream() - .filter(z -> z.region().isPresent()) - .map(z -> ZoneId.from(z.environment(), z.region().get())) - .collect(toList()); + .filter(z -> z.region().isPresent()) + .map(z -> ZoneId.from(z.environment(), z.region().get())) + .collect(toList()); return deployments.stream() - .sorted(comparingInt(deployment -> productionZones.indexOf(deployment.zone()))) - .collect(collectingAndThen(toList(), Collections::unmodifiableList)); + .sorted(comparingInt(deployment -> productionZones.indexOf(deployment.zone()))) + .collect(collectingAndThen(toList(), Collections::unmodifiableList)); } /** Resolve job from deployment step */ public JobType toJob(DeploymentSpec.DeclaredZone zone) { return JobType.from(system.get(), zone.environment(), zone.region().orElse(null)) - .orElseThrow(() -> new IllegalArgumentException("Invalid zone " + zone)); + .orElseThrow(() -> new IllegalArgumentException("Invalid zone " + zone)); } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java index e902206ad8b..63a6ac234ff 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTrigger.java @@ -20,7 +20,6 @@ import com.yahoo.vespa.hosted.controller.application.DeploymentJobs.JobReport; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs.JobType; import com.yahoo.vespa.hosted.controller.application.JobStatus; import com.yahoo.vespa.hosted.controller.application.JobStatus.JobRun; -import com.yahoo.vespa.hosted.controller.persistence.CuratorDb; import java.time.Clock; import java.time.Duration; @@ -78,14 +77,11 @@ public class DeploymentTrigger { private final DeploymentOrder order; private final BuildService buildService; - public DeploymentTrigger(Controller controller, CuratorDb curator, BuildService buildService, Clock clock) { - Objects.requireNonNull(controller, "controller cannot be null"); - Objects.requireNonNull(curator, "curator cannot be null"); - Objects.requireNonNull(clock, "clock cannot be null"); - this.controller = controller; - this.clock = clock; + public DeploymentTrigger(Controller controller, BuildService buildService, Clock clock) { + this.controller = Objects.requireNonNull(controller, "controller cannot be null"); + this.buildService = Objects.requireNonNull(buildService, "buildService cannot be null"); + this.clock = Objects.requireNonNull(clock, "clock cannot be null"); this.order = new DeploymentOrder(controller::system); - this.buildService = buildService; } public DeploymentOrder deploymentOrder() { @@ -116,15 +112,15 @@ public class DeploymentTrigger { triggering = JobRun.triggering(controller.systemVersion(), applicationVersion, Optional .empty(), Optional.empty(), "Application commit", clock.instant()); if (report.success()) { - if (acceptNewApplicationVersion(application)) - application = application.withChange(application.change().with(applicationVersion)) + if (acceptNewApplicationVersion(application.get())) + application = application.withChange(application.get().change().with(applicationVersion)) .withOutstandingChange(Change.empty()); else application = application.withOutstandingChange(Change.of(applicationVersion)); } } else { - triggering = application.deploymentJobs().statusOf(report.jobType()).flatMap(JobStatus::lastTriggered) + triggering = application.get().deploymentJobs().statusOf(report.jobType()).flatMap(JobStatus::lastTriggered) .orElseThrow(() -> new IllegalStateException("Notified of completion of " + report.jobType().jobName() + " for " + report.applicationId() + ", but that has neither been triggered nor deployed")); } @@ -132,7 +128,7 @@ public class DeploymentTrigger { report.jobType(), triggering.completion(report.buildNumber(), clock.instant()), report.jobError()); - application = application.withChange(remainingChange(application)); + application = application.withChange(remainingChange(application.get())); applications().store(application); }); } @@ -216,9 +212,9 @@ public class DeploymentTrigger { */ public void triggerChange(ApplicationId applicationId, Change change) { applications().lockOrThrow(applicationId, application -> { - if (application.changeAt(controller.clock().instant()).isPresent() && ! application.deploymentJobs().hasFailures()) + if (application.get().changeAt(controller.clock().instant()).isPresent() && ! application.get().deploymentJobs().hasFailures()) throw new IllegalArgumentException("Could not start " + change + " on " + application + ": " + - application.change() + " is already in progress"); + application.get().change() + " is already in progress"); application = application.withChange(change); if (change.application().isPresent()) application = application.withOutstandingChange(Change.empty()); @@ -230,7 +226,7 @@ public class DeploymentTrigger { /** Cancels a platform upgrade of the given application, and an application upgrade as well if {@code keepApplicationChange}. */ public void cancelChange(ApplicationId applicationId, boolean keepApplicationChange) { applications().lockOrThrow(applicationId, application -> { - applications().store(application.withChange(application.change().application() + applications().store(application.withChange(application.get().change().application() .filter(__ -> keepApplicationChange) .map(Change::of) .orElse(Change.empty()))); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainer.java index 821efba013d..4dacb2e32d6 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainer.java @@ -15,8 +15,8 @@ import java.util.logging.Level; import java.util.logging.Logger; /** - * Retrieve deployment metrics like qps and document count from the metric service and - * update the applications with this info. + * Retrieve deployment metrics such as QPS and document count from the metric service and + * update applications with this info. * * @author smorgrav */ @@ -39,19 +39,19 @@ public class DeploymentMetricsMaintainer extends Maintainer { for (Deployment deployment : application.deployments().values()) { MetricsService.DeploymentMetrics deploymentMetrics = controller().metricsService() .getDeploymentMetrics(application.id(), deployment.zone()); - DeploymentMetrics appMetrics = new DeploymentMetrics(deploymentMetrics.queriesPerSecond(), + DeploymentMetrics newMetrics = new DeploymentMetrics(deploymentMetrics.queriesPerSecond(), deploymentMetrics.writesPerSecond(), deploymentMetrics.documentCount(), deploymentMetrics.queryLatencyMillis(), deploymentMetrics.writeLatencyMillis()); controller().applications().lockIfPresent(application.id(), lockedApplication -> - controller().applications().store(lockedApplication.with(deployment.zone(), appMetrics))); + controller().applications().store(lockedApplication.with(deployment.zone(), newMetrics) + .recordActivityAt(controller().clock().instant(), deployment.zone()))); } - } - catch (UncheckedIOException e) { + } catch (UncheckedIOException e) { if (!hasWarned) // produce only one warning per maintenance interval - log.log(Level.WARNING, "Failed talking to YAMAS: " + Exceptions.toMessageString(e) + + log.log(Level.WARNING, "Failed to query metrics service: " + Exceptions.toMessageString(e) + ". Retrying in " + maintenanceInterval()); hasWarned = true; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/Upgrader.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/Upgrader.java index bd8b8fc8747..22cbe942932 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/Upgrader.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/Upgrader.java @@ -7,7 +7,6 @@ import com.yahoo.vespa.curator.Lock; import com.yahoo.vespa.hosted.controller.Application; import com.yahoo.vespa.hosted.controller.Controller; import com.yahoo.vespa.hosted.controller.application.ApplicationList; -import com.yahoo.vespa.hosted.controller.application.Change; import com.yahoo.vespa.hosted.controller.persistence.CuratorDb; import com.yahoo.vespa.hosted.controller.versions.VespaVersion; import com.yahoo.vespa.hosted.controller.versions.VespaVersion.Confidence; @@ -18,6 +17,7 @@ import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.logging.Level; import java.util.logging.Logger; @@ -36,7 +36,7 @@ public class Upgrader extends Maintainer { public Upgrader(Controller controller, Duration interval, JobControl jobControl, CuratorDb curator) { super(controller, interval, jobControl); - this.curator = curator; + this.curator = Objects.requireNonNull(curator, "curator cannot be null"); } /** diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java index 21eea21ba68..6ad2452b2a2 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializer.java @@ -20,6 +20,7 @@ import com.yahoo.vespa.hosted.controller.application.Change; import com.yahoo.vespa.hosted.controller.application.ClusterInfo; import com.yahoo.vespa.hosted.controller.application.ClusterUtilization; import com.yahoo.vespa.hosted.controller.application.Deployment; +import com.yahoo.vespa.hosted.controller.application.DeploymentActivity; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs.JobError; import com.yahoo.vespa.hosted.controller.application.DeploymentMetrics; @@ -68,6 +69,8 @@ public class ApplicationSerializer { private final String repositoryField = "repositoryField"; private final String branchField = "branchField"; private final String commitField = "commitField"; + private final String lastQueriedField = "lastQueried"; + private final String lastWrittenField = "lastWritten"; // DeploymentJobs fields private final String projectIdField = "projectId"; @@ -148,10 +151,12 @@ public class ApplicationSerializer { toSlime(deployment.applicationVersion(), object.setObject(applicationPackageRevisionField)); clusterInfoToSlime(deployment.clusterInfo(), object); clusterUtilsToSlime(deployment.clusterUtils(), object); - metricsToSlime(deployment.metrics(), object); + deploymentMetricsToSlime(deployment.metrics(), object); + deployment.activity().lastQueried().ifPresent(instant -> object.setLong(lastQueriedField, instant.toEpochMilli())); + deployment.activity().lastWritten().ifPresent(instant -> object.setLong(lastWrittenField, instant.toEpochMilli())); } - private void metricsToSlime(DeploymentMetrics metrics, Cursor object) { + private void deploymentMetricsToSlime(DeploymentMetrics metrics, Cursor object) { Cursor root = object.setObject(deploymentMetricsField); root.setDouble(deploymentMetricsQPSField, metrics.queriesPerSecond()); root.setDouble(deploymentMetricsWPSField, metrics.writesPerSecond()); @@ -289,19 +294,17 @@ public class ApplicationSerializer { Instant.ofEpochMilli(deploymentObject.field(deployTimeField).asLong()), clusterUtilsMapFromSlime(deploymentObject.field(clusterUtilsField)), clusterInfoMapFromSlime(deploymentObject.field(clusterInfoField)), - deploymentMetricsFromSlime(deploymentObject.field(deploymentMetricsField))); + deploymentMetricsFromSlime(deploymentObject.field(deploymentMetricsField)), + DeploymentActivity.create(optionalInstant(deploymentObject.field(lastQueriedField)), + optionalInstant(deploymentObject.field(lastWrittenField)))); } private DeploymentMetrics deploymentMetricsFromSlime(Inspector object) { - - double queriesPerSecond = object.field(deploymentMetricsQPSField).asDouble(); - double writesPerSecond = object.field(deploymentMetricsWPSField).asDouble(); - double documentCount = object.field(deploymentMetricsDocsField).asDouble(); - double queryLatencyMillis = object.field(deploymentMetricsQueryLatencyField).asDouble(); - double writeLatencyMills = object.field(deploymentMetricsWriteLatencyField).asDouble(); - - return new DeploymentMetrics(queriesPerSecond, writesPerSecond, - documentCount, queryLatencyMillis, writeLatencyMills); + return new DeploymentMetrics(object.field(deploymentMetricsQPSField).asDouble(), + object.field(deploymentMetricsWPSField).asDouble(), + object.field(deploymentMetricsDocsField).asDouble(), + object.field(deploymentMetricsQueryLatencyField).asDouble(), + object.field(deploymentMetricsWriteLatencyField).asDouble()); } private Map<ClusterSpec.Id, ClusterInfo> clusterInfoMapFromSlime(Inspector object) { @@ -426,4 +429,9 @@ public class ApplicationSerializer { return SlimeUtils.optionalString(field); } + private Optional<Instant> optionalInstant(Inspector field) { + OptionalLong value = optionalLong(field); + return value.isPresent() ? Optional.of(Instant.ofEpochMilli(value.getAsLong())) : Optional.empty(); + } + } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java index 10088ba3fea..3eced6d943e 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java @@ -425,6 +425,11 @@ public class ApplicationApiHandler extends LoggingRequestHandler { metricsObject.setDouble("queryServiceQuality", application.metrics().queryServiceQuality()); metricsObject.setDouble("writeServiceQuality", application.metrics().writeServiceQuality()); + // Activity + Cursor activity = object.setObject("activity"); + application.activity().lastQueried().ifPresent(lastQueried -> activity.setLong("queriedAt", lastQueried.toEpochMilli())); + application.activity().lastWritten().ifPresent(lastQueried -> activity.setLong("writtenAt", lastQueried.toEpochMilli())); + application.ownershipIssueId().ifPresent(issueId -> object.setString("ownershipIssueId", issueId.value())); application.deploymentJobs().issueId().ifPresent(issueId -> object.setString("deploymentIssueId", issueId.value())); } @@ -468,6 +473,12 @@ public class ApplicationApiHandler extends LoggingRequestHandler { .ifPresent(i -> response.setString("screwdriverId", String.valueOf(i))); sourceRevisionToSlime(deployment.applicationVersion().source(), response); + Cursor activity = response.setObject("activity"); + deployment.activity().lastQueried().ifPresent(instant -> activity.setLong("lastQueried", + instant.toEpochMilli())); + deployment.activity().lastWritten().ifPresent(instant -> activity.setLong("lastWritten", + instant.toEpochMilli())); + // Cost DeploymentCost appCost = deployment.calculateCost(); Cursor costObject = response.setObject("cost"); @@ -672,7 +683,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { ApplicationId id = ApplicationId.from(tenantName, applicationName, "default"); controller.applications().lockOrThrow(id, application -> { - controller.applications().deploymentTrigger().triggerChange(application.id(), Change.of(version)); + controller.applications().deploymentTrigger().triggerChange(application.get().id(), Change.of(version)); }); return new MessageResponse("Triggered deployment of application '" + id + "' on version " + version); } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java index 0de153fc3f9..c24c8693688 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.hosted.controller; import com.yahoo.component.Version; import com.yahoo.config.application.api.ValidationId; +import com.yahoo.config.application.api.ValidationOverrides; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ApplicationName; import com.yahoo.config.provision.Environment; @@ -179,7 +180,9 @@ public class ControllerTest { fail("Expected exception due to illegal production deployment removal"); } catch (IllegalArgumentException e) { - assertEquals("deployment-removal: application 'tenant1.app1' is deployed in corp-us-east-1, but does not include this zone in deployment.xml", e.getMessage()); + assertEquals("deployment-removal: application 'tenant1.app1' is deployed in corp-us-east-1, but does not include this zone in deployment.xml. " + + ValidationOverrides.toAllowMessage(ValidationId.deploymentRemoval), + e.getMessage()); } assertNotNull("Zone was not removed", applications.require(app1.id()).deployments().get(productionCorpUsEast1.zone(main).get())); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java index 98189613bd0..cf2fa182d0a 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTester.java @@ -13,7 +13,6 @@ import com.yahoo.vespa.curator.mock.MockCurator; import com.yahoo.vespa.hosted.controller.api.application.v4.model.DeployOptions; import com.yahoo.vespa.hosted.controller.api.identifiers.Property; import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; -import com.yahoo.vespa.hosted.controller.api.identifiers.ScrewdriverId; import com.yahoo.vespa.hosted.controller.api.integration.BuildService; import com.yahoo.vespa.hosted.controller.api.integration.chef.ChefMock; import com.yahoo.vespa.hosted.controller.api.integration.deployment.ArtifactRepository; @@ -41,6 +40,7 @@ import com.yahoo.vespa.hosted.rotation.config.RotationsConfig; import java.util.Optional; import java.util.OptionalLong; +import java.util.function.Supplier; import java.util.logging.Logger; import static org.junit.Assert.assertNotNull; @@ -64,25 +64,28 @@ public final class ControllerTester { private final ArtifactRepositoryMock artifactRepository; private final EntityService entityService; private final MockBuildService buildService; + private final MockMetricsService metricsService; private Controller controller; - public ControllerTester(ManualClock clock, RotationsConfig rotationsConfig, MockCuratorDb curatorDb) { + public ControllerTester(ManualClock clock, RotationsConfig rotationsConfig, MockCuratorDb curatorDb, + MockMetricsService metricsService) { this(new AthenzDbMock(), clock, new ConfigServerMock(new ZoneRegistryMock()), new ZoneRegistryMock(), new GitHubMock(), curatorDb, rotationsConfig, - new MemoryNameService(), new ArtifactRepositoryMock(), new MemoryEntityService(), new MockBuildService()); + new MemoryNameService(), new ArtifactRepositoryMock(), new MemoryEntityService(), new MockBuildService(), + metricsService); } public ControllerTester(ManualClock clock) { - this(clock, defaultRotationsConfig(), new MockCuratorDb()); + this(clock, defaultRotationsConfig(), new MockCuratorDb(), new MockMetricsService()); } public ControllerTester(RotationsConfig rotationsConfig) { - this(new ManualClock(), rotationsConfig, new MockCuratorDb()); + this(new ManualClock(), rotationsConfig, new MockCuratorDb(), new MockMetricsService()); } public ControllerTester(MockCuratorDb curatorDb) { - this(new ManualClock(), defaultRotationsConfig(), curatorDb); + this(new ManualClock(), defaultRotationsConfig(), curatorDb, new MockMetricsService()); } public ControllerTester() { @@ -93,7 +96,8 @@ public final class ControllerTester { ConfigServerMock configServer, ZoneRegistryMock zoneRegistry, GitHubMock gitHub, CuratorDb curator, RotationsConfig rotationsConfig, MemoryNameService nameService, ArtifactRepositoryMock artifactRepository, - EntityService entityService, MockBuildService buildService) { + EntityService entityService, MockBuildService buildService, + MockMetricsService metricsService) { this.athenzDb = athenzDb; this.clock = clock; this.configServer = configServer; @@ -105,8 +109,10 @@ public final class ControllerTester { this.artifactRepository = artifactRepository; this.entityService = entityService; this.buildService = buildService; + this.metricsService = metricsService; this.controller = createController(curator, rotationsConfig, configServer, clock, gitHub, zoneRegistry, - athenzDb, nameService, artifactRepository, entityService, buildService); + athenzDb, nameService, artifactRepository, entityService, buildService, + metricsService); // Make root logger use time from manual clock Logger.getLogger("").getHandlers()[0].setFilter( @@ -138,10 +144,12 @@ public final class ControllerTester { public MockBuildService buildService() { return buildService; } + public MockMetricsService metricsService() { return metricsService; } + /** Create a new controller instance. Useful to verify that controller state is rebuilt from persistence */ public final void createNewController() { controller = createController(curator, rotationsConfig, configServer, clock, gitHub, zoneRegistry, athenzDb, - nameService, artifactRepository, entityService, buildService); + nameService, artifactRepository, entityService, buildService, metricsService); } /** Creates the given tenant and application and deploys it */ @@ -241,6 +249,10 @@ public final class ControllerTester { new DeployOptions(false, Optional.empty(), false, deployCurrentVersion)); } + public Supplier<Application> application(ApplicationId application) { + return () -> controller().applications().require(application); + } + /** Used by ApplicationSerializerTest to avoid breaking encapsulation. Should not be used by anything else */ public static LockedApplication writable(Application application) { return new LockedApplication(application, new Lock("/test", new MockCurator())); @@ -251,7 +263,7 @@ public final class ControllerTester { GitHubMock gitHub, ZoneRegistryMock zoneRegistryMock, AthenzDbMock athensDb, MemoryNameService nameService, ArtifactRepository artifactRepository, EntityService entityService, - BuildService buildService) { + BuildService buildService, MockMetricsService metricsService) { Controller controller = new Controller(curator, rotationsConfig, gitHub, @@ -260,7 +272,7 @@ public final class ControllerTester { new MemoryGlobalRoutingService(), zoneRegistryMock, configServer, - new MockMetricsService(), + metricsService, nameService, new MockRoutingGenerator(), new ChefMock(), diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/MockMetricsService.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/MockMetricsService.java index 88bbb582564..67a4139ecf1 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/MockMetricsService.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/MockMetricsService.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.hosted.controller.integration; import com.yahoo.config.provision.ApplicationId; +import com.yahoo.vespa.hosted.controller.api.integration.MetricsService; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; import java.util.HashMap; @@ -10,16 +11,28 @@ import java.util.Map; /** * @author bratseth */ -public class MockMetricsService implements com.yahoo.vespa.hosted.controller.api.integration.MetricsService { +public class MockMetricsService implements MetricsService { + + private final Map<String, Double> metrics = new HashMap<>(); + + public MockMetricsService setMetric(String key, Double value) { + metrics.put(key, value); + return this; + } @Override public ApplicationMetrics getApplicationMetrics(ApplicationId application) { - return new ApplicationMetrics(0.5, 0.7); + return new ApplicationMetrics(metrics.getOrDefault("queryServiceQuality", 0.5), + metrics.getOrDefault("writeServiceQuality", 0.7)); } @Override public DeploymentMetrics getDeploymentMetrics(ApplicationId application, ZoneId zone) { - return new DeploymentMetrics(1, 2, 3, 4, 5); + return new DeploymentMetrics(metrics.getOrDefault("queriesPerSecond", 1D), + metrics.getOrDefault("writesPerSecond", 2D), + metrics.getOrDefault("docoumentCount", 3D).longValue(), + metrics.getOrDefault("queryLatencyMillis", 4D), + metrics.getOrDefault("writeLatencyMillis", 5D)); } @Override diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainerTest.java index 148d11e8b38..a651210767d 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/DeploymentMetricsMaintainerTest.java @@ -12,37 +12,69 @@ import com.yahoo.vespa.hosted.controller.persistence.MockCuratorDb; import org.junit.Test; import java.time.Duration; +import java.time.Instant; +import java.util.function.Supplier; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; /** * @author smorgrav + * @author mpolden */ public class DeploymentMetricsMaintainerTest { @Test public void maintain() { ControllerTester tester = new ControllerTester(); - ApplicationId app = tester.createAndDeploy("tenant1", "domain1", "app1", Environment.dev, 123).id(); + ApplicationId appId = tester.createAndDeploy("tenant1", "domain1", "app1", + Environment.dev, 123).id(); + DeploymentMetricsMaintainer maintainer = new DeploymentMetricsMaintainer(tester.controller(), + Duration.ofDays(1), + new JobControl(new MockCuratorDb())); + Supplier<Application> app = tester.application(appId); + Supplier<Deployment> deployment = () -> app.get().deployments().values().stream().findFirst().get(); - // Pre condition: no metric info on neither application nor deployment - assertEquals(0, tester.controller().applications().require(app).metrics().queryServiceQuality(), 0); - Deployment deployment = tester.controller().applications().get(app).get().deployments().values().stream().findAny().get(); - assertEquals(0, deployment.metrics().documentCount(), 0); + // No metrics gathered yet + assertEquals(0, app.get().metrics().queryServiceQuality(), 0); + assertEquals(0, deployment.get().metrics().documentCount(), 0); + assertFalse("Never received any queries", deployment.get().activity().lastQueried().isPresent()); + assertFalse("Never received any writes", deployment.get().activity().lastWritten().isPresent()); - DeploymentMetricsMaintainer maintainer = new DeploymentMetricsMaintainer(tester.controller(), Duration.ofMinutes(10), new JobControl(new MockCuratorDb())); + // Metrics are gathered and saved to application maintainer.maintain(); + assertEquals(0.5, app.get().metrics().queryServiceQuality(), Double.MIN_VALUE); + assertEquals(0.7, app.get().metrics().writeServiceQuality(), Double.MIN_VALUE); + assertEquals(1, deployment.get().metrics().queriesPerSecond(), Double.MIN_VALUE); + assertEquals(2, deployment.get().metrics().writesPerSecond(), Double.MIN_VALUE); + assertEquals(3, deployment.get().metrics().documentCount(), Double.MIN_VALUE); + assertEquals(4, deployment.get().metrics().queryLatencyMillis(), Double.MIN_VALUE); + assertEquals(5, deployment.get().metrics().writeLatencyMillis(), Double.MIN_VALUE); + Instant t1 = tester.clock().instant(); + assertEquals(t1, deployment.get().activity().lastQueried().get()); + assertEquals(t1, deployment.get().activity().lastWritten().get()); - // Post condition: - Application application = tester.controller().applications().require(app); - assertEquals(0.5, application.metrics().queryServiceQuality(), Double.MIN_VALUE); - assertEquals(0.7, application.metrics().writeServiceQuality(), Double.MIN_VALUE); - deployment = application.deployments().values().stream().findAny().get(); - assertEquals(1, deployment.metrics().queriesPerSecond(), Double.MIN_VALUE); - assertEquals(2, deployment.metrics().writesPerSecond(), Double.MIN_VALUE); - assertEquals(3, deployment.metrics().documentCount(), Double.MIN_VALUE); - assertEquals(4, deployment.metrics().queryLatencyMillis(), Double.MIN_VALUE); - assertEquals(5, deployment.metrics().writeLatencyMillis(), Double.MIN_VALUE); + // Time passes. Activity is updated as app is still receiving traffic + tester.clock().advance(Duration.ofHours(1)); + Instant t2 = tester.clock().instant(); + maintainer.maintain(); + assertEquals(t2, deployment.get().activity().lastQueried().get()); + assertEquals(t2, deployment.get().activity().lastWritten().get()); + + // Query traffic disappears. Query activity time is no longer updated + tester.clock().advance(Duration.ofHours(1)); + Instant t3 = tester.clock().instant(); + tester.metricsService().setMetric("queriesPerSecond", 0D); + maintainer.maintain(); + assertEquals(t2, deployment.get().activity().lastQueried().get()); + assertEquals(t3, deployment.get().activity().lastWritten().get()); + + // Feed traffic disappears. Feed activity time is no longer updated + tester.clock().advance(Duration.ofHours(1)); + tester.metricsService().setMetric("writesPerSecond", 0D); + maintainer.maintain(); + assertEquals(t2, deployment.get().activity().lastQueried().get()); + assertEquals(t3, deployment.get().activity().lastWritten().get()); } } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java index f6bf3bdd8cf..5c5827fa167 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/ApplicationSerializerTest.java @@ -9,7 +9,6 @@ import com.yahoo.config.provision.ClusterSpec; import com.yahoo.slime.Slime; import com.yahoo.vespa.config.SlimeUtils; import com.yahoo.vespa.hosted.controller.Application; -import com.yahoo.vespa.hosted.controller.ControllerTester; import com.yahoo.vespa.hosted.controller.api.integration.MetricsService; import com.yahoo.vespa.hosted.controller.api.integration.organization.IssueId; import com.yahoo.vespa.hosted.controller.api.integration.zone.ZoneId; @@ -18,6 +17,7 @@ import com.yahoo.vespa.hosted.controller.application.Change; import com.yahoo.vespa.hosted.controller.application.ClusterInfo; import com.yahoo.vespa.hosted.controller.application.ClusterUtilization; import com.yahoo.vespa.hosted.controller.application.Deployment; +import com.yahoo.vespa.hosted.controller.application.DeploymentActivity; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs; import com.yahoo.vespa.hosted.controller.application.DeploymentJobs.JobError; import com.yahoo.vespa.hosted.controller.application.DeploymentMetrics; @@ -42,7 +42,6 @@ import static com.yahoo.config.provision.SystemName.main; import static com.yahoo.vespa.hosted.controller.ControllerTester.writable; import static java.util.Optional.empty; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; /** * @author bratseth @@ -56,7 +55,6 @@ public class ApplicationSerializerTest { @Test public void testSerialization() { - ControllerTester tester = new ControllerTester(); DeploymentSpec deploymentSpec = DeploymentSpec.fromXml("<deployment version='1.0'>" + " <staging/>" + "</deployment>"); @@ -68,9 +66,12 @@ public class ApplicationSerializerTest { ApplicationVersion applicationVersion1 = ApplicationVersion.from(new SourceRevision("repo1", "branch1", "commit1"), 31); ApplicationVersion applicationVersion2 = ApplicationVersion .from(new SourceRevision("repo1", "branch1", "commit1"), 32); + Instant activityAt = Instant.parse("2018-06-01T10:15:30.00Z"); deployments.add(new Deployment(zone1, applicationVersion1, Version.fromString("1.2.3"), Instant.ofEpochMilli(3))); // One deployment without cluster info and utils deployments.add(new Deployment(zone2, applicationVersion2, Version.fromString("1.2.3"), Instant.ofEpochMilli(5), - createClusterUtils(3, 0.2), createClusterInfo(3, 4),new DeploymentMetrics(2,3,4,5,6))); + createClusterUtils(3, 0.2), createClusterInfo(3, 4), + new DeploymentMetrics(2,3,4,5,6), + DeploymentActivity.create(Optional.of(activityAt), Optional.of(activityAt)))); OptionalLong projectId = OptionalLong.of(123L); List<JobStatus> statusList = new ArrayList<>(); @@ -111,6 +112,8 @@ public class ApplicationSerializerTest { assertEquals(original.deployments().get(zone2).version(), serialized.deployments().get(zone2).version()); assertEquals(original.deployments().get(zone1).at(), serialized.deployments().get(zone1).at()); assertEquals(original.deployments().get(zone2).at(), serialized.deployments().get(zone2).at()); + assertEquals(original.deployments().get(zone2).activity().lastQueried().get(), serialized.deployments().get(zone2).activity().lastQueried().get()); + assertEquals(original.deployments().get(zone2).activity().lastWritten().get(), serialized.deployments().get(zone2).activity().lastWritten().get()); assertEquals(original.deploymentJobs().projectId(), serialized.deploymentJobs().projectId()); assertEquals(original.deploymentJobs().jobStatus().size(), serialized.deploymentJobs().jobStatus().size()); @@ -146,34 +149,33 @@ public class ApplicationSerializerTest { // Test metrics assertEquals(original.metrics().queryServiceQuality(), serialized.metrics().queryServiceQuality(), Double.MIN_VALUE); assertEquals(original.metrics().writeServiceQuality(), serialized.metrics().writeServiceQuality(), Double.MIN_VALUE); - - assertEquals(2, serialized.deployments().get(zone2).metrics().queriesPerSecond(), Double.MIN_VALUE); - assertEquals(3, serialized.deployments().get(zone2).metrics().writesPerSecond(), Double.MIN_VALUE); - assertEquals(4, serialized.deployments().get(zone2).metrics().documentCount(), Double.MIN_VALUE); - assertEquals(5, serialized.deployments().get(zone2).metrics().queryLatencyMillis(), Double.MIN_VALUE); - assertEquals(6, serialized.deployments().get(zone2).metrics().writeLatencyMillis(), Double.MIN_VALUE); + assertEquals(original.deployments().get(zone2).metrics().queriesPerSecond(), serialized.deployments().get(zone2).metrics().queriesPerSecond(), Double.MIN_VALUE); + assertEquals(original.deployments().get(zone2).metrics().writesPerSecond(), serialized.deployments().get(zone2).metrics().writesPerSecond(), Double.MIN_VALUE); + assertEquals(original.deployments().get(zone2).metrics().documentCount(), serialized.deployments().get(zone2).metrics().documentCount(), Double.MIN_VALUE); + assertEquals(original.deployments().get(zone2).metrics().queryLatencyMillis(), serialized.deployments().get(zone2).metrics().queryLatencyMillis(), Double.MIN_VALUE); + assertEquals(original.deployments().get(zone2).metrics().writeLatencyMillis(), serialized.deployments().get(zone2).metrics().writeLatencyMillis(), Double.MIN_VALUE); { // test more deployment serialization cases - Application original2 = writable(original).withChange(Change.of(ApplicationVersion.from(new SourceRevision("repo1", "branch1", "commit1"), 42))); + Application original2 = writable(original).withChange(Change.of(ApplicationVersion.from(new SourceRevision("repo1", "branch1", "commit1"), 42))).get(); Application serialized2 = applicationSerializer.fromSlime(applicationSerializer.toSlime(original2)); assertEquals(original2.change(), serialized2.change()); assertEquals(serialized2.change().application().get().source(), original2.change().application().get().source()); - Application original3 = writable(original).withChange(Change.of(ApplicationVersion.from(new SourceRevision("a", "b", "c"), 42))); + Application original3 = writable(original).withChange(Change.of(ApplicationVersion.from(new SourceRevision("a", "b", "c"), 42))).get(); Application serialized3 = applicationSerializer.fromSlime(applicationSerializer.toSlime(original3)); assertEquals(original3.change(), serialized3.change()); assertEquals(serialized3.change().application().get().source(), original3.change().application().get().source()); - Application original4 = writable(original).withChange(Change.empty()); + Application original4 = writable(original).withChange(Change.empty()).get(); Application serialized4 = applicationSerializer.fromSlime(applicationSerializer.toSlime(original4)); assertEquals(original4.change(), serialized4.change()); - Application original5 = writable(original).withChange(Change.of(ApplicationVersion.from(new SourceRevision("a", "b", "c"), 42))); + Application original5 = writable(original).withChange(Change.of(ApplicationVersion.from(new SourceRevision("a", "b", "c"), 42))).get(); Application serialized5 = applicationSerializer.fromSlime(applicationSerializer.toSlime(original5)); assertEquals(original5.change(), serialized5.change()); - Application original6 = writable(original).withOutstandingChange(Change.of(ApplicationVersion.from(new SourceRevision("a", "b", "c"), 42))); + Application original6 = writable(original).withOutstandingChange(Change.of(ApplicationVersion.from(new SourceRevision("a", "b", "c"), 42))).get(); Application serialized6 = applicationSerializer.fromSlime(applicationSerializer.toSlime(original6)); assertEquals(original6.outstandingChange(), serialized6.outstandingChange()); } @@ -210,15 +212,6 @@ public class ApplicationSerializerTest { } @Test - public void testLegacySerialization() { - Application applicationWithSuccessfulJob = applicationSerializer.fromSlime(applicationSlime(false)); - assertFalse("No job error for successful job", applicationWithSuccessfulJob.deploymentJobs().jobStatus().get(DeploymentJobs.JobType.systemTest).jobError().isPresent()); - - Application applicationWithFailingJob = applicationSerializer.fromSlime(applicationSlime(true)); - assertEquals(JobError.unknown, applicationWithFailingJob.deploymentJobs().jobStatus().get(DeploymentJobs.JobType.systemTest).jobError().get()); - } - - @Test public void testCompleteApplicationDeserialization() throws Exception { byte[] applicationJson = Files.readAllBytes(testData.resolve("complete-application.json")); applicationSerializer.fromSlime(SlimeUtils.jsonToSlime(applicationJson)); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java index 8d734ec549c..545ee529635 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java @@ -57,6 +57,7 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.net.URI; import java.nio.charset.StandardCharsets; +import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -1123,7 +1124,8 @@ public class ApplicationApiTest extends ControllerContainerTest { lockedApplication = lockedApplication .withClusterInfo(deployment.zone(), clusterInfo) .withClusterUtilization(deployment.zone(), clusterUtils) - .with(deployment.zone(), metrics); + .with(deployment.zone(), metrics) + .recordActivityAt(Instant.parse("2018-06-01T10:15:30.00Z"), deployment.zone()); } controllerTester.controller().applications().store(lockedApplication); }); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json index 30070e509c7..f8c4c26d6a8 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application-without-change-multiple-deployments.json @@ -249,5 +249,9 @@ "metrics": { "queryServiceQuality": 0.5, "writeServiceQuality": 0.7 + }, + "activity": { + "queriedAt": 1527848130000, + "writtenAt": 1527848130000 } } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json index dabeb3239aa..fc0f83c2cdc 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application.json @@ -234,5 +234,9 @@ "metrics": { "queryServiceQuality": 0.5, "writeServiceQuality": 0.7 + }, + "activity": { + "queriedAt": 1527848130000, + "writtenAt": 1527848130000 } } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json index 174bb2f1ba7..8bb1ee83282 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/application1-recursive.json @@ -222,6 +222,10 @@ "queryServiceQuality": 0.5, "writeServiceQuality": 0.7 }, + "activity": { + "queriedAt": 1527848130000, + "writtenAt": 1527848130000 + }, "ownershipIssueId": "321", "deploymentIssueId": "123" } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json index 9174e7dd8b2..79e86b5f7f4 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/deployment.json @@ -16,6 +16,10 @@ "gitRepository": "repository1", "gitBranch": "master", "gitCommit": "commit1", + "activity": { + "lastQueried": 1527848130000, + "lastWritten": 1527848130000 + }, "cost": { "tco": 74, "waste": 0, diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json index d13a4dac116..8fccd738554 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/dev-us-west-1.json @@ -15,8 +15,10 @@ "revision": "(ignore)", "deployTimeEpochMs": "(ignore)", "screwdriverId": "123", - - + "activity": { + "lastQueried": 1527848130000, + "lastWritten": 1527848130000 + }, "cost": { "tco": 74, "waste": 0, diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-corp-us-east-1.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-corp-us-east-1.json index 0f16bee308d..066e840fe16 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-corp-us-east-1.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/prod-corp-us-east-1.json @@ -22,6 +22,10 @@ "gitRepository": "repository1", "gitBranch": "master", "gitCommit": "commit1", + "activity": { + "lastQueried": 1527848130000, + "lastWritten": 1527848130000 + }, "cost": { "tco": 74, "waste": 0, diff --git a/docprocs/src/test/java/com/yahoo/docprocs/indexing/DocumentScriptTestCase.java b/docprocs/src/test/java/com/yahoo/docprocs/indexing/DocumentScriptTestCase.java index 5b1a4412b41..419b60432c4 100644 --- a/docprocs/src/test/java/com/yahoo/docprocs/indexing/DocumentScriptTestCase.java +++ b/docprocs/src/test/java/com/yahoo/docprocs/indexing/DocumentScriptTestCase.java @@ -1,11 +1,13 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.docprocs.indexing; +import com.yahoo.document.ArrayDataType; import com.yahoo.document.DataType; import com.yahoo.document.Document; import com.yahoo.document.DocumentType; import com.yahoo.document.DocumentUpdate; import com.yahoo.document.Field; +import com.yahoo.document.MapDataType; import com.yahoo.document.StructDataType; import com.yahoo.document.annotation.SpanTree; import com.yahoo.document.annotation.SpanTrees; @@ -16,6 +18,7 @@ import com.yahoo.document.datatypes.StringFieldValue; import com.yahoo.document.datatypes.Struct; import com.yahoo.document.datatypes.WeightedSet; import com.yahoo.document.fieldpathupdate.AssignFieldPathUpdate; +import com.yahoo.document.fieldpathupdate.FieldPathUpdate; import com.yahoo.document.update.FieldUpdate; import com.yahoo.document.update.MapValueUpdate; import com.yahoo.document.update.ValueUpdate; @@ -30,6 +33,7 @@ import org.junit.Test; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -157,6 +161,82 @@ public class DocumentScriptTestCase { assertSpanTrees(str, "mySpanTree"); } + private class FieldPathFixture { + final DocumentType type; + final StructDataType structType; + final DataType structMap; + final DataType structArray; + + FieldPathFixture() { + type = newDocumentType(); + structType = new StructDataType("mystruct"); + structType.addField(new Field("title", DataType.STRING)); + structType.addField(new Field("rating", DataType.INT)); + structArray = new ArrayDataType(structType); + type.addField(new Field("structarray", structArray)); + structMap = new MapDataType(DataType.STRING, structType); + type.addField(new Field("structmap", structMap)); + type.addField(new Field("structfield", structType)); + } + + DocumentUpdate executeWithUpdate(String fieldName, FieldPathUpdate updateIn) { + DocumentUpdate update = new DocumentUpdate(type, "doc:scheme:"); + update.addFieldPathUpdate(updateIn); + return newScript(type, fieldName).execute(ADAPTER_FACTORY, update); + } + + FieldPathUpdate executeWithUpdateAndExpectFieldPath(String fieldName, FieldPathUpdate updateIn) { + DocumentUpdate update = executeWithUpdate(fieldName, updateIn); + assertEquals(1, update.getFieldPathUpdates().size()); + return update.getFieldPathUpdates().get(0); + } + } + + @Test + public void array_field_path_updates_survive_indexing_scripts() { + FieldPathFixture f = new FieldPathFixture(); + + Struct newElemValue = new Struct(f.structType); + newElemValue.setFieldValue("title", "iron moose 2, the moosening"); + + FieldPathUpdate updated = f.executeWithUpdateAndExpectFieldPath("structarray", new AssignFieldPathUpdate(f.type, "structarray[10]", newElemValue)); + + assertTrue(updated instanceof AssignFieldPathUpdate); + AssignFieldPathUpdate assignUpdate = (AssignFieldPathUpdate)updated; + assertEquals("structarray[10]", assignUpdate.getOriginalFieldPath()); + assertEquals(newElemValue, assignUpdate.getFieldValue()); + } + + @Test + public void map_field_path_updates_survive_indexing_scripts() { + FieldPathFixture f = new FieldPathFixture(); + + Struct newElemValue = new Struct(f.structType); + newElemValue.setFieldValue("title", "iron moose 3, moose in new york"); + + FieldPathUpdate updated = f.executeWithUpdateAndExpectFieldPath("structmap", new AssignFieldPathUpdate(f.type, "structmap{foo}", newElemValue)); + + assertTrue(updated instanceof AssignFieldPathUpdate); + AssignFieldPathUpdate assignUpdate = (AssignFieldPathUpdate)updated; + assertEquals("structmap{foo}", assignUpdate.getOriginalFieldPath()); + assertEquals(newElemValue, assignUpdate.getFieldValue()); + } + + @Test + public void nested_struct_fieldpath_update_is_not_converted_to_regular_field_value_update() { + FieldPathFixture f = new FieldPathFixture(); + + StringFieldValue newTitleValue = new StringFieldValue("iron moose 4, moose with a vengeance"); + DocumentUpdate update = f.executeWithUpdate("structfield", new AssignFieldPathUpdate(f.type, "structfield.title", newTitleValue)); + + assertEquals(1, update.getFieldPathUpdates().size()); + assertEquals(0, update.getFieldUpdates().size()); + assertTrue(update.getFieldPathUpdates().get(0) instanceof AssignFieldPathUpdate); + AssignFieldPathUpdate assignUpdate = (AssignFieldPathUpdate)update.getFieldPathUpdates().get(0); + assertEquals("structfield.title", assignUpdate.getOriginalFieldPath()); + assertEquals(newTitleValue, assignUpdate.getFieldValue()); + } + private static FieldValue processDocument(FieldValue fieldValue) { DocumentType docType = new DocumentType("myDocumentType"); docType.addField("myField", fieldValue.getDataType()); @@ -184,11 +264,15 @@ public class DocumentScriptTestCase { return update.getFieldUpdate("myField").getValueUpdate(0); } + private static DocumentScript newScript(DocumentType docType, String fieldName) { + return new DocumentScript(docType.getName(), Collections.singletonList(fieldName), + new StatementExpression(new InputExpression(fieldName), + new IndexExpression(fieldName))); + } + private static DocumentScript newScript(DocumentType docType) { String fieldName = docType.getFields().iterator().next().getName(); - return new DocumentScript(docType.getName(), Arrays.asList(fieldName), - new StatementExpression(new InputExpression(fieldName), - new IndexExpression(fieldName))); + return newScript(docType, fieldName); } private static StringFieldValue newString(String... spanTrees) { @@ -210,6 +294,7 @@ public class DocumentScriptTestCase { DocumentType type = new DocumentType("documentType"); type.addField("documentField", DataType.STRING); type.addField("extraField", DataType.STRING); + return type; } diff --git a/document/src/main/java/com/yahoo/document/datatypes/Array.java b/document/src/main/java/com/yahoo/document/datatypes/Array.java index e37a32f28f4..01326bcea62 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/Array.java +++ b/document/src/main/java/com/yahoo/document/datatypes/Array.java @@ -290,7 +290,8 @@ public final class Array<T extends FieldValue> extends CollectionFieldValue<T> i if (pos < fieldPath.size()) { switch (fieldPath.get(pos).getType()) { case ARRAY_INDEX: - return iterateSubset(fieldPath.get(pos).getLookupIndex(), fieldPath.get(pos).getLookupIndex(), fieldPath, null, pos + 1, handler); + final int elemIndex = fieldPath.get(pos).getLookupIndex(); + return iterateSubset(elemIndex, elemIndex, fieldPath, null, pos + 1, handler); case VARIABLE: { FieldPathIteratorHandler.IndexValue val = handler.getVariables().get(fieldPath.get(pos).getVariableName()); if (val != null) { diff --git a/fat-model-dependencies/pom.xml b/fat-model-dependencies/pom.xml index 1415ca6e5aa..0011d108b98 100644 --- a/fat-model-dependencies/pom.xml +++ b/fat-model-dependencies/pom.xml @@ -16,13 +16,6 @@ <groupId>com.yahoo.vespa</groupId> <artifactId>config-model</artifactId> <version>${project.version}</version> - <exclusions> - <exclusion> - <!-- Large, and installed separately as part of Vespa --> - <groupId>org.tensorflow</groupId> - <artifactId>libtensorflow_jni</artifactId> - </exclusion> - </exclusions> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldPathUpdateHelper.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldPathUpdateHelper.java index 171c6a8eb9a..5c170fe147e 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldPathUpdateHelper.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldPathUpdateHelper.java @@ -20,19 +20,10 @@ public abstract class FieldPathUpdateHelper { if (!(update instanceof AssignFieldPathUpdate)) { return false; } - for (FieldPathEntry entry : update.getFieldPath()) { - switch (entry.getType()) { - case STRUCT_FIELD: - case MAP_ALL_KEYS: - case MAP_ALL_VALUES: - continue; - case ARRAY_INDEX: - case MAP_KEY: - case VARIABLE: - return false; - } - } - return true; + // Only consider field path updates that touch a top-level field as 'complete', + // as these may be converted to regular field value updates. + return ((update.getFieldPath().size() == 1) + && update.getFieldPath().get(0).getType() == FieldPathEntry.Type.STRUCT_FIELD); } public static void applyUpdate(FieldPathUpdate update, Document doc) { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/IdentityFieldPathUpdateAdapter.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/IdentityFieldPathUpdateAdapter.java new file mode 100644 index 00000000000..42c9bd8c10c --- /dev/null +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/IdentityFieldPathUpdateAdapter.java @@ -0,0 +1,68 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.indexinglanguage; + +import com.yahoo.document.DataType; +import com.yahoo.document.Document; +import com.yahoo.document.DocumentUpdate; +import com.yahoo.document.FieldPath; +import com.yahoo.document.datatypes.FieldValue; +import com.yahoo.document.fieldpathupdate.FieldPathUpdate; +import com.yahoo.vespa.indexinglanguage.expressions.Expression; +import com.yahoo.vespa.indexinglanguage.expressions.FieldValueAdapter; + +/** + * No-op update adapter which simply passes through the input update unchanged. + * I.e. getOutput() will return a DocumentUpdate containing only the FieldPathUpdate + * the IdentityFieldPathUpdateAdapter was created with. All other applicable calls are + * forwarded to the provided DocumentAdapter instance. + * + * This removes the need for a potentially lossy round-trip of update -> synthetic document -> update. + */ +public class IdentityFieldPathUpdateAdapter implements UpdateAdapter { + + private final FieldPathUpdate update; + private final DocumentAdapter fwdAdapter; + + public IdentityFieldPathUpdateAdapter(FieldPathUpdate update, DocumentAdapter fwdAdapter) { + this.update = update; + this.fwdAdapter = fwdAdapter; + } + + @Override + public DocumentUpdate getOutput() { + Document doc = fwdAdapter.getFullOutput(); + DocumentUpdate upd = new DocumentUpdate(doc.getDataType(), doc.getId()); + upd.addFieldPathUpdate(update); + return upd; + } + + @Override + public Expression getExpression(Expression expression) { + return expression; + } + + @Override + public FieldValue getInputValue(String fieldName) { + return fwdAdapter.getInputValue(fieldName); + } + + @Override + public FieldValue getInputValue(FieldPath fieldPath) { + return fwdAdapter.getInputValue(fieldPath); + } + + @Override + public FieldValueAdapter setOutputValue(Expression exp, String fieldName, FieldValue fieldValue) { + return fwdAdapter.setOutputValue(exp, fieldName, fieldValue); + } + + @Override + public DataType getInputType(Expression exp, String fieldName) { + return fwdAdapter.getInputType(exp, fieldName); + } + + @Override + public void tryOutputType(Expression exp, String fieldName, DataType valueType) { + fwdAdapter.tryOutputType(exp, fieldName, valueType); + } +} diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/SimpleAdapterFactory.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/SimpleAdapterFactory.java index 2ad09dfbdc4..509bdcaa32d 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/SimpleAdapterFactory.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/SimpleAdapterFactory.java @@ -49,10 +49,12 @@ public class SimpleAdapterFactory implements AdapterFactory { Document complete = new Document(docType, upd.getId()); for (FieldPathUpdate fieldUpd : upd) { if (FieldPathUpdateHelper.isComplete(fieldUpd)) { + // A 'complete' field path update is basically a regular top-level field update + // in wolf's clothing. Convert it to a regular field update to be friendlier + // towards the search core backend. FieldPathUpdateHelper.applyUpdate(fieldUpd, complete); } else { - Document partial = FieldPathUpdateHelper.newPartialDocument(docId, fieldUpd); - ret.add(new FieldPathUpdateAdapter(newDocumentAdapter(partial, true), fieldUpd)); + ret.add(new IdentityFieldPathUpdateAdapter(fieldUpd, newDocumentAdapter(complete, true))); } } for (FieldUpdate fieldUpd : upd.getFieldUpdates()) { diff --git a/jdisc_http_service/pom.xml b/jdisc_http_service/pom.xml index 6373189e738..f41994c4916 100644 --- a/jdisc_http_service/pom.xml +++ b/jdisc_http_service/pom.xml @@ -175,7 +175,6 @@ <extensions>true</extensions> <configuration> <discPreInstallBundle> - asm-debug-all-${asm-debug-all.version}.jar, bcpkix-jdk15on-${bouncycastle.version}.jar, bcprov-jdk15on-${bouncycastle.version}.jar, javax.servlet-api-3.1.0.jar, @@ -188,8 +187,6 @@ jetty-servlet-${jetty.version}.jar, jetty-servlets-${jetty.version}.jar, jetty-util-${jetty.version}.jar, - org.apache.aries.spifly.dynamic.bundle-${aries.spifly.version}.jar, - org.apache.aries.util-${aries.util.version}.jar, component-jar-with-dependencies.jar </discPreInstallBundle> </configuration> diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerConformanceTest.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerConformanceTest.java index 80c1cb8b458..77411fc080e 100644 --- a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerConformanceTest.java +++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerConformanceTest.java @@ -323,7 +323,7 @@ public class HttpServerConformanceTest extends ServerProviderConformanceTest { @Override @Test public void testRequestContentWriteExceptionAfterResponseWriteWithSyncCompletion() throws Throwable { - new TestRunner().expect(success()) + new TestRunner().expect(anyOf(success(), successNoContent())) .execute(); } diff --git a/jdisc_jetty/pom.xml b/jdisc_jetty/pom.xml index 0f8a5ba19e2..404476f7bf2 100644 --- a/jdisc_jetty/pom.xml +++ b/jdisc_jetty/pom.xml @@ -16,10 +16,6 @@ <packaging>jar</packaging> <dependencies> <dependency> - <groupId>org.apache.aries.spifly</groupId> - <artifactId>org.apache.aries.spifly.dynamic.bundle</artifactId> - </dependency> - <dependency> <groupId>org.eclipse.jetty</groupId> <artifactId>jetty-continuation</artifactId> </dependency> diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java index 7f2d1f1eff7..a7bf22591d4 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/StorageMaintainer.java @@ -276,8 +276,8 @@ public class StorageMaintainer { */ public void handleCoreDumpsForContainer(ContainerName containerName, NodeSpec node, boolean force) { // Sample number of coredumps on the host - try { - numberOfCoredumpsOnHost.sample(Files.list(environment.pathInNodeAdminToDoneCoredumps()).count()); + try (Stream<Path> files = Files.list(environment.pathInNodeAdminToDoneCoredumps())) { + numberOfCoredumpsOnHost.sample(files.count()); } catch (IOException e) { // Ignore for now - this is either test or a misconfiguration } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java index f7e9c3ca1d8..ff85c49bb13 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java @@ -1,6 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.node.admin.maintenance.identity; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import com.yahoo.vespa.athenz.api.AthenzService; import com.yahoo.vespa.athenz.client.zts.DefaultZtsClient; import com.yahoo.vespa.athenz.client.zts.InstanceIdentity; @@ -9,7 +11,7 @@ import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider; import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper; import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocumentClient; import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument; -import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId; +import com.yahoo.vespa.athenz.identityprovider.api.bindings.SignedIdentityDocumentEntity; import com.yahoo.vespa.athenz.identityprovider.client.DefaultIdentityDocumentClient; import com.yahoo.vespa.athenz.identityprovider.client.InstanceCsrGenerator; import com.yahoo.vespa.athenz.tls.AthenzIdentityVerifier; @@ -19,9 +21,9 @@ import com.yahoo.vespa.athenz.tls.KeyUtils; import com.yahoo.vespa.athenz.tls.Pkcs10Csr; import com.yahoo.vespa.athenz.tls.SslContextBuilder; import com.yahoo.vespa.athenz.tls.X509CertificateUtils; +import com.yahoo.vespa.athenz.utils.SiaUtils; import com.yahoo.vespa.hosted.dockerapi.ContainerName; import com.yahoo.vespa.hosted.node.admin.component.Environment; -import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeSpec; import com.yahoo.vespa.hosted.node.admin.util.PrefixLogger; import javax.net.ssl.SSLContext; @@ -38,7 +40,6 @@ import java.security.cert.X509Certificate; import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.util.Set; import static java.util.Collections.singleton; @@ -53,12 +54,15 @@ public class AthenzCredentialsMaintainer { private static final Duration REFRESH_PERIOD = Duration.ofDays(1); private static final Path CONTAINER_SIA_DIRECTORY = Paths.get("/var/lib/sia"); + private static final ObjectMapper mapper = new ObjectMapper().registerModule(new JavaTimeModule()); + private final boolean enabled; private final PrefixLogger log; private final String hostname; private final Path trustStorePath; private final Path privateKeyFile; private final Path certificateFile; + private final Path identityDocumentFile; private final AthenzService containerIdentity; private final URI ztsEndpoint; private final Clock clock; @@ -66,8 +70,6 @@ public class AthenzCredentialsMaintainer { private final IdentityDocumentClient identityDocumentClient; private final InstanceCsrGenerator csrGenerator; private final AthenzService configserverIdentity; - private final String zoneRegion; - private final String zoneEnvironment; public AthenzCredentialsMaintainer(String hostname, Environment environment, @@ -82,8 +84,9 @@ public class AthenzCredentialsMaintainer { this.configserverIdentity = environment.getConfigserverAthenzIdentity(); this.csrGenerator = new InstanceCsrGenerator(environment.getCertificateDnsSuffix()); this.trustStorePath = environment.getTrustStorePath(); - this.privateKeyFile = getPrivateKeyFile(containerSiaDirectory, containerIdentity); - this.certificateFile = getCertificateFile(containerSiaDirectory, containerIdentity); + this.privateKeyFile = SiaUtils.getPrivateKeyFile(containerSiaDirectory, containerIdentity); + this.certificateFile = SiaUtils.getCertificateFile(containerSiaDirectory, containerIdentity); + this.identityDocumentFile = containerSiaDirectory.resolve("vespa-node-identity-document.json"); this.hostIdentityProvider = hostIdentityProvider; this.identityDocumentClient = new DefaultIdentityDocumentClient( @@ -91,15 +94,12 @@ public class AthenzCredentialsMaintainer { hostIdentityProvider, new AthenzIdentityVerifier(singleton(configserverIdentity))); this.clock = Clock.systemUTC(); - this.zoneRegion = environment.getRegion(); - this.zoneEnvironment = environment.getEnvironment(); } /** - * @param nodeSpec Node specification * @return Returns true if credentials were updated */ - public boolean converge(NodeSpec nodeSpec) { + public boolean converge() { try { if (!enabled) { log.debug("Feature disabled on this host - not fetching certificate"); @@ -107,26 +107,25 @@ public class AthenzCredentialsMaintainer { } log.debug("Checking certificate"); Instant now = clock.instant(); - VespaUniqueInstanceId instanceId = getVespaUniqueInstanceId(nodeSpec); - Set<String> ipAddresses = nodeSpec.getIpAddresses(); - if (!Files.exists(privateKeyFile) || !Files.exists(certificateFile)) { - log.info("Certificate and/or private key file does not exist"); + if (!Files.exists(privateKeyFile) || !Files.exists(certificateFile) || !Files.exists(identityDocumentFile)) { + log.info("Certificate/private key/identity document file does not exist"); Files.createDirectories(privateKeyFile.getParent()); Files.createDirectories(certificateFile.getParent()); - registerIdentity(instanceId, ipAddresses); + Files.createDirectories(identityDocumentFile.getParent()); + registerIdentity(); return true; } X509Certificate certificate = readCertificateFromFile(); Instant expiry = certificate.getNotAfter().toInstant(); if (isCertificateExpired(expiry, now)) { log.info(String.format("Certificate has expired (expiry=%s)", expiry.toString())); - registerIdentity(instanceId, ipAddresses); + registerIdentity(); return true; } Duration age = Duration.between(certificate.getNotBefore().toInstant(), now); if (shouldRefreshCredentials(age)) { log.info(String.format("Certificate is ready to be refreshed (age=%s)", age.toString())); - refreshIdentity(instanceId, ipAddresses); + refreshIdentity(); return true; } log.debug("Certificate is still valid"); @@ -148,19 +147,6 @@ public class AthenzCredentialsMaintainer { } } - private VespaUniqueInstanceId getVespaUniqueInstanceId(NodeSpec nodeSpec) { - NodeSpec.Membership membership = nodeSpec.getMembership().get(); - NodeSpec.Owner owner = nodeSpec.getOwner().get(); - return new VespaUniqueInstanceId( - membership.getIndex(), - membership.getClusterId(), - owner.getInstance(), - owner.getApplication(), - owner.getTenant(), - zoneRegion, - zoneEnvironment); - } - private boolean shouldRefreshCredentials(Duration age) { return age.compareTo(REFRESH_PERIOD) >= 0; } @@ -174,32 +160,32 @@ public class AthenzCredentialsMaintainer { return now.isAfter(expiry.minus(EXPIRY_MARGIN)); } - private void registerIdentity(VespaUniqueInstanceId instanceId, Set<String> ipAddresses) { + private void registerIdentity() { KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA); - Pkcs10Csr csr = csrGenerator.generateCsr(containerIdentity, instanceId, ipAddresses, keyPair); SignedIdentityDocument signedIdentityDocument = identityDocumentClient.getNodeIdentityDocument(hostname); + Pkcs10Csr csr = csrGenerator.generateCsr( + containerIdentity, signedIdentityDocument.providerUniqueId(), signedIdentityDocument.ipAddresses(), keyPair); try (ZtsClient ztsClient = new DefaultZtsClient(ztsEndpoint, hostIdentityProvider)) { InstanceIdentity instanceIdentity = ztsClient.registerInstance( configserverIdentity, containerIdentity, - instanceId.asDottedString(), + signedIdentityDocument.providerUniqueId().asDottedString(), EntityBindingsMapper.toAttestationData(signedIdentityDocument), false, csr); + writeIdentityDocument(signedIdentityDocument); writePrivateKeyAndCertificate(keyPair.getPrivate(), instanceIdentity.certificate()); log.info("Instance successfully registered and credentials written to file"); } catch (IOException e) { throw new UncheckedIOException(e); - } catch (Exception e) { - // TODO Change close() in ZtsClient to not throw checked exception - throw new RuntimeException(e); } } - private void refreshIdentity(VespaUniqueInstanceId instanceId, Set<String> ipAddresses) { + private void refreshIdentity() { + SignedIdentityDocument identityDocument = readIdentityDocument(); KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA); - Pkcs10Csr csr = csrGenerator.generateCsr(containerIdentity, instanceId, ipAddresses, keyPair); + Pkcs10Csr csr = csrGenerator.generateCsr(containerIdentity, identityDocument.providerUniqueId(), identityDocument.ipAddresses(), keyPair); SSLContext containerIdentitySslContext = new SslContextBuilder() .withKeyStore(privateKeyFile.toFile(), certificateFile.toFile()) @@ -210,16 +196,34 @@ public class AthenzCredentialsMaintainer { ztsClient.refreshInstance( configserverIdentity, containerIdentity, - instanceId.asDottedString(), + identityDocument.providerUniqueId().asDottedString(), false, csr); writePrivateKeyAndCertificate(keyPair.getPrivate(), instanceIdentity.certificate()); log.info("Instance successfully refreshed and credentials written to file"); } catch (IOException e) { throw new UncheckedIOException(e); - } catch (Exception e) { - // TODO Change close() in ZtsClient to not throw checked exception - throw new RuntimeException(e); + } + } + + private SignedIdentityDocument readIdentityDocument() { + try { + SignedIdentityDocumentEntity entity = mapper.readValue(identityDocumentFile.toFile(), SignedIdentityDocumentEntity.class); + return EntityBindingsMapper.toSignedIdentityDocument(entity); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private void writeIdentityDocument(SignedIdentityDocument signedIdentityDocument) { + try { + SignedIdentityDocumentEntity entity = + EntityBindingsMapper.toSignedIdentityDocumentEntity(signedIdentityDocument); + Path tempIdentityDocumentFile = toTempPath(identityDocumentFile); + mapper.writeValue(tempIdentityDocumentFile.toFile(), entity); + Files.move(tempIdentityDocumentFile, identityDocumentFile, StandardCopyOption.ATOMIC_MOVE); + } catch (IOException e) { + throw new UncheckedIOException(e); } } @@ -237,18 +241,4 @@ public class AthenzCredentialsMaintainer { return Paths.get(file.toAbsolutePath().toString() + ".tmp"); } - // TODO Move to vespa-athenz - private static Path getPrivateKeyFile(Path root, AthenzService service) { - return root - .resolve("keys") - .resolve(String.format("%s.%s.key.pem", service.getDomain().getName(), service.getName())); - } - - // TODO Move to vespa-athenz - private static Path getCertificateFile(Path root, AthenzService service) { - return root - .resolve("certs") - .resolve(String.format("%s.%s.cert.pem", service.getDomain().getName(), service.getName())); - } - } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java index 7fa9a90b744..5f1b7aefcfe 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java @@ -498,7 +498,7 @@ public class NodeAgentImpl implements NodeAgent { runLocalResumeScriptIfNeeded(node); - athenzCredentialsMaintainer.converge(node); + athenzCredentialsMaintainer.converge(); doBeforeConverge(node); diff --git a/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoredumpHandler.java b/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoredumpHandler.java index 99dfdb48334..63c74c17dd5 100644 --- a/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoredumpHandler.java +++ b/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/CoredumpHandler.java @@ -72,7 +72,7 @@ class CoredumpHandler { FileHelper.deleteDirectories(doneCoredumpsPath, Duration.ofDays(10), Optional.empty()); } - private void handleNewCoredumps() throws IOException { + private void handleNewCoredumps() { Path processingCoredumps = enqueueCoredumps(); processAndReportCoredumps(processingCoredumps); } @@ -82,12 +82,12 @@ class CoredumpHandler { * Moves a coredump to a new directory under the processing/ directory. Limit to only processing * one coredump at the time, starting with the oldest. */ - Path enqueueCoredumps() throws IOException { + Path enqueueCoredumps() { Path processingCoredumpsPath = coredumpsPath.resolve(PROCESSING_DIRECTORY_NAME); processingCoredumpsPath.toFile().mkdirs(); - if (Files.list(processingCoredumpsPath).count() > 0) return processingCoredumpsPath; + if (!FileHelper.listContentsOfDirectory(processingCoredumpsPath).isEmpty()) return processingCoredumpsPath; - Files.list(coredumpsPath) + FileHelper.listContentsOfDirectory(coredumpsPath).stream() .filter(path -> path.toFile().isFile() && ! path.getFileName().toString().startsWith(".")) .min((Comparator.comparingLong(o -> o.toFile().lastModified()))) .ifPresent(coredumpPath -> { @@ -101,10 +101,10 @@ class CoredumpHandler { return processingCoredumpsPath; } - void processAndReportCoredumps(Path processingCoredumpsPath) throws IOException { + void processAndReportCoredumps(Path processingCoredumpsPath) { doneCoredumpsPath.toFile().mkdirs(); - Files.list(processingCoredumpsPath) + FileHelper.listContentsOfDirectory(processingCoredumpsPath).stream() .filter(path -> path.toFile().isDirectory()) .forEach(coredumpDirectory -> { try { @@ -130,7 +130,7 @@ class CoredumpHandler { String collectMetadata(Path coredumpDirectory, Map<String, Object> nodeAttributes) throws IOException { Path metadataPath = coredumpDirectory.resolve(METADATA_FILE_NAME); if (!Files.exists(metadataPath)) { - Path coredumpPath = Files.list(coredumpDirectory).findFirst() + Path coredumpPath = FileHelper.listContentsOfDirectory(coredumpDirectory).stream().findFirst() .orElseThrow(() -> new RuntimeException("No coredump file found in processing directory " + coredumpDirectory)); Map<String, Object> metadata = coreCollector.collect(coredumpPath, installStatePath); metadata.putAll(nodeAttributes); diff --git a/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/FileHelper.java b/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/FileHelper.java index ae872042853..7b93e7ad98d 100644 --- a/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/FileHelper.java +++ b/node-maintainer/src/main/java/com/yahoo/vespa/hosted/node/maintainer/FileHelper.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.hosted.node.maintainer; import java.io.IOException; +import java.io.UncheckedIOException; import java.nio.file.Files; import java.nio.file.LinkOption; import java.nio.file.NoSuchFileException; @@ -63,7 +64,7 @@ public class FileHelper { throw new IllegalArgumentException("Number of files to keep must be a positive number"); } - List<Path> pathsInDeleteDir = Files.list(basePath) + List<Path> pathsInDeleteDir = listContentsOfDirectory(basePath).stream() .filter(Files::isRegularFile) .sorted(Comparator.comparing(FileHelper::getLastModifiedTime)) .skip(nMostRecentToKeep) @@ -153,13 +154,16 @@ public class FileHelper { return pattern == null || pattern.matcher(path.getFileName().toString()).find(); } - static List<Path> listContentsOfDirectory(Path basePath) { + /** + * @return list all files in a directory, returns empty list if directory does not exist + */ + public static List<Path> listContentsOfDirectory(Path basePath) { try (Stream<Path> directoryStream = Files.list(basePath)) { return directoryStream.collect(Collectors.toList()); } catch (NoSuchFileException ignored) { return Collections.emptyList(); } catch (IOException e) { - throw new RuntimeException("Failed to list contents of directory " + basePath.toAbsolutePath(), e); + throw new UncheckedIOException("Failed to list contents of directory " + basePath.toAbsolutePath(), e); } } @@ -167,7 +171,7 @@ public class FileHelper { try { return Files.getLastModifiedTime(path, LinkOption.NOFOLLOW_LINKS); } catch (IOException e) { - throw new RuntimeException("Failed to get last modified time of " + path.toAbsolutePath(), e); + throw new UncheckedIOException("Failed to get last modified time of " + path.toAbsolutePath(), e); } } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/GroupPreparer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/GroupPreparer.java index 93704d244b5..d31b4438a38 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/GroupPreparer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/GroupPreparer.java @@ -9,7 +9,6 @@ import com.yahoo.transaction.Mutex; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeRepository; -import java.time.Clock; import java.util.List; /** @@ -67,6 +66,7 @@ public class GroupPreparer { allocation.offer(prioritizer.prioritize()); if (! allocation.fullfilled()) throw new OutOfCapacityException("Could not satisfy " + requestedNodes + " for " + cluster + + " in " + application.toShortString() + outOfCapacityDetails(allocation)); // Extend reservation for already reserved nodes diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DockerProvisioningTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DockerProvisioningTest.java index e1b1d74c6d0..2cabee98c0d 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DockerProvisioningTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DockerProvisioningTest.java @@ -160,13 +160,13 @@ public class DockerProvisioningTest { assertEquals(setOf("host1", "host2"), hostsOf(tester.getNodes(application1, Node.State.active))); try { - ApplicationId application2 = tester.makeApplicationId(); + ApplicationId application2 = ApplicationId.from("tenant1", "app1", "default"); prepareAndActivate(application2, 3, false, tester); fail("Expected allocation failure"); } catch (Exception e) { assertEquals("No room for 3 nodes as 2 of 4 hosts are exclusive", - "Could not satisfy request for 3 nodes of flavor 'dockerSmall' for container cluster 'myContainer' group 0 6.39: Not enough nodes available due to host exclusivity constraints.", + "Could not satisfy request for 3 nodes of flavor 'dockerSmall' for container cluster 'myContainer' group 0 6.39 in tenant1.app1: Not enough nodes available due to host exclusivity constraints.", e.getMessage()); } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java index c0cead74f5f..11c7832091b 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/v2/filter/NodeIdentifierTest.java @@ -29,6 +29,7 @@ import java.security.cert.X509Certificate; import java.time.Instant; import java.util.Optional; +import static com.yahoo.vespa.athenz.identityprovider.api.IdentityType.*; import static com.yahoo.vespa.athenz.tls.KeyAlgorithm.RSA; import static com.yahoo.vespa.athenz.tls.SignatureAlgorithm.SHA256_WITH_RSA; import static java.util.Collections.emptySet; @@ -161,7 +162,7 @@ public class NodeIdentifierTest { Pkcs10Csr csr = Pkcs10CsrBuilder .fromKeypair(new X500Principal("CN=" + TENANT_NODE_IDENTITY), KEYPAIR, SHA256_WITH_RSA) .build(); - VespaUniqueInstanceId vespaUniqueInstanceId = new VespaUniqueInstanceId(clusterIndex, clusterId, INSTANCE_ID, application, tenant, region, environment); + VespaUniqueInstanceId vespaUniqueInstanceId = new VespaUniqueInstanceId(clusterIndex, clusterId, INSTANCE_ID, application, tenant, region, environment, NODE); X509Certificate certificate = X509CertificateBuilder .fromCsr(csr, ATHENZ_YAHOO_CA_CERT.getSubjectX500Principal(), Instant.EPOCH, Instant.EPOCH.plusSeconds(60), KEYPAIR.getPrivate(), SHA256_WITH_RSA, 1) .addSubjectAlternativeName(vespaUniqueInstanceId.asDottedString() + ".instanceid.athenz.provider-name.vespa.yahoo.cloud") diff --git a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/ServiceMonitorInstanceLookupService.java b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/ServiceMonitorInstanceLookupService.java index a09ec29dada..d1d5f3e8c95 100644 --- a/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/ServiceMonitorInstanceLookupService.java +++ b/orchestrator/src/main/java/com/yahoo/vespa/orchestrator/ServiceMonitorInstanceLookupService.java @@ -46,7 +46,7 @@ public class ServiceMonitorInstanceLookupService implements InstanceLookupServic return Optional.empty(); } if (applicationInstancesUsingHost.size() > 1) { - throw new AssertionError( + throw new IllegalStateException( "Major assumption broken: Multiple application instances contain host " + hostName.s() + ": " + applicationInstancesUsingHost); } diff --git a/parent/pom.xml b/parent/pom.xml index 411cc5ede9e..10e93d4ffbf 100644 --- a/parent/pom.xml +++ b/parent/pom.xml @@ -81,7 +81,7 @@ <plugin> <groupId>org.apache.felix</groupId> <artifactId>maven-bundle-plugin</artifactId> - <version>2.4.0</version> + <version>3.5.0</version> </plugin> <plugin> <groupId>org.apache.maven.plugins</groupId> @@ -498,11 +498,6 @@ <version>${antlr4.version}</version> </dependency> <dependency> - <groupId>org.apache.aries.spifly</groupId> - <artifactId>org.apache.aries.spifly.dynamic.bundle</artifactId> - <version>${aries.spifly.version}</version> - </dependency> - <dependency> <groupId>org.apache.commons</groupId> <artifactId>commons-lang3</artifactId> <version>3.1</version> @@ -686,9 +681,6 @@ <properties> <antlr.version>3.5.2</antlr.version> <antlr4.version>4.5</antlr4.version> - <aries.spifly.version>1.0.8</aries.spifly.version> - <aries.util.version>1.0.0</aries.util.version> - <asm-debug-all.version>5.0.3</asm-debug-all.version> <!-- Athenz dependencies. Make sure these dependencies matches those in Vespa's internal repositories --> <athenz.version>1.7.43</athenz.version> <commons-lang.version>2.6</commons-lang.version> diff --git a/searchcore/src/tests/proton/matching/query_test.cpp b/searchcore/src/tests/proton/matching/query_test.cpp index 9adb86147b6..61823a17f09 100644 --- a/searchcore/src/tests/proton/matching/query_test.cpp +++ b/searchcore/src/tests/proton/matching/query_test.cpp @@ -107,6 +107,8 @@ class Test : public vespalib::TestApp { void requireThatParallelWandBlueprintsAreCreatedCorrectly(); void requireThatWhiteListBlueprintCanBeUsed(); void requireThatSameElementTermsAreProperlyPrefixed(); + void requireThatSameElementDoesNotAllocateMatchData(); + void requireThatSameElementIteratorsCanBeBuilt(); public: ~Test(); @@ -181,12 +183,7 @@ Node::UP buildQueryTree(const ViewResolver &resolver, query_builder.addPhrase(2, field, 7, Weight(0)); query_builder.addStringTerm(phrase_term, field, 8, Weight(0)); query_builder.addStringTerm(phrase_term, field, 9, Weight(0)); -#if 0 - //Todo add testing when SameElement blueprints are ready - query_builder.addSameElement(2, field); - query_builder.addStringTerm(string_term, field, 10, Weight(0)); - query_builder.addStringTerm(prefix_term, field, 11, Weight(0)); -#endif + Node::UP node = query_builder.build(); ResolveViewVisitor visitor(resolver, idxEnv); @@ -194,6 +191,19 @@ Node::UP buildQueryTree(const ViewResolver &resolver, return node; } +Node::UP buildSameElementQueryTree(const ViewResolver &resolver, + const search::fef::IIndexEnvironment &idxEnv) +{ + QueryBuilder<ProtonNodeTypes> query_builder; + query_builder.addSameElement(2, field); + query_builder.addStringTerm(string_term, field, 0, Weight(0)); + query_builder.addStringTerm(prefix_term, field, 1, Weight(0)); + Node::UP node = query_builder.build(); + ResolveViewVisitor visitor(resolver, idxEnv); + node->accept(visitor); + return node; +} + void Test::requireThatMatchDataIsReserved() { Node::UP node = buildQueryTree(ViewResolver(), plain_index_env); @@ -883,6 +893,7 @@ make_same_element_stack_dump(const vespalib::string &prefix, const vespalib::str query->accept(sem); return query; } + void Test::requireThatSameElementTermsAreProperlyPrefixed() { @@ -915,6 +926,32 @@ Test::requireThatSameElementTermsAreProperlyPrefixed() EXPECT_EQUAL(dynamic_cast<ProtonStringTerm *>(root->getChildren()[1])->getView(), "abc.abc.f2"); } +void +Test::requireThatSameElementDoesNotAllocateMatchData() +{ + Node::UP node = buildSameElementQueryTree(ViewResolver(), plain_index_env); + MatchDataLayout mdl; + MatchDataReserveVisitor visitor(mdl); + node->accept(visitor); + MatchData::UP match_data = mdl.createMatchData(); + EXPECT_EQUAL(0u, match_data->getNumTermFields()); +} + +void +Test::requireThatSameElementIteratorsCanBeBuilt() { + Node::UP node = buildSameElementQueryTree(ViewResolver(), plain_index_env); + FakeSearchContext context(10); + context.addIdx(0).idx(0).getFake() + .addResult(field, string_term, FakeResult() + .doc(4).elem(1).pos(0).doc(8).elem(1).pos(0)) + .addResult(field, prefix_term, FakeResult() + .doc(4).elem(2).pos(0).doc(8).elem(1).pos(1)); + SearchIterator::UP iterator = getIterator(*node, context); + ASSERT_TRUE(iterator.get()); + EXPECT_TRUE(!iterator->seek(4)); + EXPECT_TRUE(iterator->seek(8)); +} + Test::~Test() = default; int @@ -937,7 +974,6 @@ Test::Main() TEST_CALL(requireThatNearIteratorsCanBeBuilt); TEST_CALL(requireThatONearIteratorsCanBeBuilt); TEST_CALL(requireThatPhraseIteratorsCanBeBuilt); - //TODO Add SameElement testing TEST_CALL(requireThatUnknownFieldActsEmpty); TEST_CALL(requireThatIllegalFieldsAreIgnored); TEST_CALL(requireThatQueryGluesEverythingTogether); @@ -949,7 +985,8 @@ Test::Main() TEST_CALL(requireThatParallelWandBlueprintsAreCreatedCorrectly); TEST_CALL(requireThatWhiteListBlueprintCanBeUsed); TEST_CALL(requireThatSameElementTermsAreProperlyPrefixed); - + TEST_CALL(requireThatSameElementDoesNotAllocateMatchData); + TEST_CALL(requireThatSameElementIteratorsCanBeBuilt); TEST_DONE(); } diff --git a/searchcore/src/tests/proton/matching/querynodes_test.cpp b/searchcore/src/tests/proton/matching/querynodes_test.cpp index 7b6fdd1ae88..6607019cccc 100644 --- a/searchcore/src/tests/proton/matching/querynodes_test.cpp +++ b/searchcore/src/tests/proton/matching/querynodes_test.cpp @@ -25,6 +25,7 @@ #include <vespa/searchlib/queryeval/ranksearch.h> #include <vespa/searchlib/queryeval/searchiterator.h> #include <vespa/searchlib/queryeval/simple_phrase_search.h> +#include <vespa/searchlib/queryeval/same_element_search.h> #include <vespa/searchlib/queryeval/sourceblendersearch.h> #include <vespa/searchlib/queryeval/fake_search.h> #include <vespa/searchlib/queryeval/fake_requestcontext.h> @@ -39,28 +40,30 @@ using search::fef::FieldInfo; using search::fef::FieldType; using search::fef::MatchData; using search::fef::MatchDataLayout; -using search::fef::TermFieldMatchData; using search::fef::TermFieldHandle; +using search::fef::TermFieldMatchData; using search::fef::TermFieldMatchDataArray; using search::fef::test::IndexEnvironment; using search::query::Node; using search::query::QueryBuilder; +using search::queryeval::AndNotSearch; +using search::queryeval::AndSearch; +using search::queryeval::Blueprint; +using search::queryeval::EmptySearch; +using search::queryeval::FakeRequestContext; +using search::queryeval::FakeResult; +using search::queryeval::FakeSearch; +using search::queryeval::FieldSpec; using search::queryeval::ISourceSelector; using search::queryeval::NearSearch; using search::queryeval::ONearSearch; using search::queryeval::OrSearch; -using search::queryeval::AndSearch; -using search::queryeval::AndNotSearch; using search::queryeval::RankSearch; -using search::queryeval::Blueprint; +using search::queryeval::SameElementSearch; using search::queryeval::SearchIterator; -using search::queryeval::SourceBlenderSearch; -using search::queryeval::FieldSpec; using search::queryeval::Searchable; -using search::queryeval::FakeSearch; -using search::queryeval::FakeResult; -using search::queryeval::FakeRequestContext; using search::queryeval::SimplePhraseSearch; +using search::queryeval::SourceBlenderSearch; using std::string; using std::vector; using namespace proton::matching; @@ -287,6 +290,20 @@ SearchIterator *getParent<ONear>(SearchIterator *a, SearchIterator *b) { } template <> +SearchIterator *getParent<SameElement>(SearchIterator *a, SearchIterator *b) { + std::vector<SearchIterator::UP> children; + children.emplace_back(a); + children.emplace_back(b); + TermFieldMatchDataArray data; + static TermFieldMatchData tmd; + // we only check how many term/field combinations + // are below the SameElement parent: + // two terms searching in one index field + data.add(&tmd).add(&tmd); + return new SameElementSearch(nullptr, std::move(children), data, true); +} + +template <> SearchIterator *getParent<Or>(SearchIterator *a, SearchIterator *b) { return getSimpleParent<OrSearch>(a, b); } @@ -422,6 +439,7 @@ void checkProperBlending() { TEST_DO(checkOneFieldNoAttributesOneIndex<T>()); } + template <typename T> void checkProperBlendingWithParent() { IteratorStructureTest structure_test; @@ -454,6 +472,24 @@ void checkProperBlendingWithParent() { EXPECT_EQUAL(expected->asString(), structure_test.getIteratorAsString<T>()); } +template <> +void checkProperBlendingWithParent<SameElement>() { + using T = SameElement; + IteratorStructureTest structure_test; + structure_test.setFieldCount(1); + structure_test.setAttributeCount(0); + structure_test.setIndexCount(2); + + SearchIterator::UP expected( + getParent<T>(Blender() + .add(SourceId(0), getTerm(phrase_term1, field[0], source_tag[0])) + .add(SourceId(1), getTerm(phrase_term1, field[0], source_tag[1])), + Blender(bothStrict<T>()) + .add(SourceId(0), getTerm(phrase_term2, field[0], source_tag[0])) + .add(SourceId(1), getTerm(phrase_term2, field[0], source_tag[1])))); + EXPECT_EQUAL(expected->asString(), structure_test.getIteratorAsString<T>()); +} + TEST("requireThatTermNodeSearchIteratorsGetProperBlending") { TEST_DO(checkProperBlending<Term>()); } @@ -463,8 +499,7 @@ TEST("requireThatPhrasesGetProperBlending") { } TEST("requireThatSameElementGetProperBlending") { - //TODO SameEelement needs proper testing/implementation - //TEST_DO(checkProperBlending<SameElement>()); + TEST_DO(checkProperBlendingWithParent<SameElement>()); } TEST("requireThatNearGetProperBlending") { diff --git a/searchcore/src/tests/proton/matching/resolveviewvisitor_test.cpp b/searchcore/src/tests/proton/matching/resolveviewvisitor_test.cpp index 5ea2bcc982b..4fd079949d5 100644 --- a/searchcore/src/tests/proton/matching/resolveviewvisitor_test.cpp +++ b/searchcore/src/tests/proton/matching/resolveviewvisitor_test.cpp @@ -136,6 +136,23 @@ TEST_F("require that equiv nodes resolve view from children", Fixture) { EXPECT_EQUAL(field2, base.field(1).field_name); } +TEST_F("require that view is resolved for SameElement children", Fixture) { + ViewResolver resolver; + resolver.add(view, field1); + + QueryBuilder<ProtonNodeTypes> builder; + builder.addSameElement(2, ""); + ProtonStringTerm &my_term = builder.addStringTerm(term, view, 42, weight); + builder.addStringTerm(term, field2, 43, weight); + Node::UP node = builder.build(); + + ResolveViewVisitor visitor(resolver, f.index_environment); + node->accept(visitor); + + ASSERT_EQUAL(1u, my_term.numFields()); + EXPECT_EQUAL(field1, my_term.field(0).field_name); +} + } // namespace TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index 721214f9e94..4b49f17f74e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -1,5 +1,4 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; @@ -13,76 +12,61 @@ import java.util.Map; import java.util.regex.Pattern; /** - * The result of importing a TensorFlow model into Vespa. - * - A set of signatures which are named collections of inputs and outputs. - * - A set of named constant tensors represented by Variable nodes in TensorFlow. - * - A list of warning messages. + * The result of importing a model (TensorFlow or ONNX) into Vespa. * * @author bratseth */ -// This object can be built incrementally within this package, but is immutable when observed from outside the package -public class TensorFlowModel { +public class ImportedModel { - private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); + private static final String defaultSignatureName = "default"; + private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); private final String name; + private final Map<String, Signature> signatures = new HashMap<>(); + private final Map<String, TensorType> arguments = new HashMap<>(); + private final Map<String, Tensor> smallConstants = new HashMap<>(); + private final Map<String, Tensor> largeConstants = new HashMap<>(); + private final Map<String, RankingExpression> expressions = new HashMap<>(); + private final Map<String, RankingExpression> macros = new HashMap<>(); + private final Map<String, TensorType> requiredMacros = new HashMap<>(); + /** - * Creates a TensorFlow model + * Creates a new imported model. * * @param name the name of this mode, containing only characters in [A-Za-z0-9_] */ - public TensorFlowModel(String name) { + public ImportedModel(String name) { if ( ! nameRegexp.matcher(name).matches()) - throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" + - name + "'"); + throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + + name + "'"); this.name = name; } /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ public String name() { return name; } - private final Map<String, Signature> signatures = new HashMap<>(); - private final Map<String, TensorType> arguments = new HashMap<>(); - private final Map<String, Tensor> smallConstants = new HashMap<>(); - private final Map<String, Tensor> largeConstants = new HashMap<>(); - private final Map<String, RankingExpression> expressions = new HashMap<>(); - private final Map<String, RankingExpression> macros = new HashMap<>(); - private final Map<String, TensorType> requiredMacros = new HashMap<>(); - - void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } - void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } - void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } - void expression(String name, RankingExpression expression) { expressions.put(name, expression); } - void macro(String name, RankingExpression expression) { macros.put(name, expression); } - void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } - - /** Returns the given signature. If it does not already exist it is added to this. */ - Signature signature(String name) { - return signatures.computeIfAbsent(name, Signature::new); - } - /** Returns an immutable map of the arguments ("Placeholders") of this */ public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } /** * Returns an immutable map of the small constants of this. * These should have sizes up to a few kb at most, and correspond to constant - * values given in the TensorFlow source. + * values given in the TensorFlow or ONNX source. */ public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); } /** * Returns an immutable map of the large constants of this. - * These can have sizes in gigabytes and must be distributed to nodes separately from configuration, - * and correspond to Variable files stored separately in TensorFlow. + * These can have sizes in gigabytes and must be distributed to nodes separately from configuration. + * For TensorFlow this corresponds to Variable files stored separately. */ public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); } /** - * Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes - * which are not Placeholders or Variables (which instead become respectively arguments and constants). - * Note that only nodes recursively referenced by a placeholder are added. + * Returns an immutable map of the expressions of this - corresponding to graph nodes + * which are not Inputs/Placeholders or Variables (which instead become respectively arguments and constants). + * Note that only nodes recursively referenced by a placeholder/input are added. */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } @@ -95,9 +79,26 @@ public class TensorFlowModel { /** Returns an immutable map of the signatures of this */ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } + /** Returns the given signature. If it does not already exist it is added to this. */ + Signature signature(String name) { + return signatures.computeIfAbsent(name, Signature::new); + } + + /** Convenience method for returning a default signature */ + Signature defaultSignature() { return signature(defaultSignatureName); } + + void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } + void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } + void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } + void expression(String name, RankingExpression expression) { expressions.put(name, expression); } + void macro(String name, RankingExpression expression) { macros.put(name, expression); } + void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } + /** - * A signature is a set of named inputs and outputs, where the inputs maps to argument ("placeholder") names+types, - * and outputs maps to expressions nodes. + * A signature is a set of named inputs and outputs, where the inputs maps to argument + * ("placeholder") names+types, and outputs maps to expressions nodes. + * Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit + * concept of signatures. For now, we handle ONNX models as having a single signature. */ public class Signature { @@ -107,19 +108,14 @@ public class TensorFlowModel { private final Map<String, String> skippedOutputs = new HashMap<>(); private final List<String> importWarnings = new ArrayList<>(); - Signature(String name) { + public Signature(String name) { this.name = name; } - void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } - void output(String name, String expressionName) { outputs.put(name, expressionName); } - void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } - void importWarning(String warning) { importWarnings.add(warning); } - public String name() { return name; } /** Returns the result this is part of */ - TensorFlowModel owner() { return TensorFlowModel.this; } + public ImportedModel owner() { return ImportedModel.this; } /** * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name @@ -127,7 +123,7 @@ public class TensorFlowModel { */ public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); } - /** Returns owner().arguments().get(inputs.get(name)), e.g the type of the argument this input references */ + /** Returns the type of the argument this input references */ public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); } /** Returns an immutable list of the expression names of this */ @@ -144,12 +140,17 @@ public class TensorFlowModel { */ public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } - /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */ + /** Returns the expression this output references */ public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); } @Override public String toString() { return "signature '" + name + "'"; } + void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } + void output(String name, String expressionName) { outputs.put(name, expressionName); } + void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } + void importWarning(String warning) { importWarnings.add(warning); } + } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java new file mode 100644 index 00000000000..a658833b426 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java @@ -0,0 +1,242 @@ +package com.yahoo.searchlib.rankingexpression.integration.ml; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.yolean.Exceptions; + +import java.io.File; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Logger; + +/** + * Base class for importing ML models (ONNX/TensorFlow) as native Vespa + * ranking expressions. The general mechanism for import is for the + * specific ML platform import implementations to create an + * IntermediateGraph. This class offers common code to convert the + * IntermediateGraph to Vespa ranking expressions and macros. + * + * @author lesters + */ +public abstract class ModelImporter { + + private static final Logger log = Logger.getLogger(ModelImporter.class.getName()); + + /** + * The main import function. + */ + public abstract ImportedModel importModel(String modelName, String modelPath); + + public ImportedModel importModel(String modelName, File modelDir) { + return importModel(modelName, modelDir.toString()); + } + + /** + * Takes an IntermediateGraph and converts it to a ImportedModel containing + * the actual Vespa ranking expressions. + */ + static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph) { + ImportedModel model = new ImportedModel(graph.name()); + + graph.optimize(); + + importSignatures(graph, model); + importExpressions(graph, model); + reportWarnings(graph, model); + logVariableTypes(graph); + + return model; + } + + private static void importSignatures(IntermediateGraph graph, ImportedModel model) { + for (String signatureName : graph.signatures()) { + ImportedModel.Signature signature = model.signature(signatureName); + for (Map.Entry<String, String> input : graph.inputs(signatureName).entrySet()) { + signature.input(input.getKey(), input.getValue()); + } + for (Map.Entry<String, String> output : graph.outputs(signatureName).entrySet()) { + signature.output(output.getKey(), output.getValue()); + } + } + } + + private static boolean isSignatureInput(ImportedModel model, IntermediateOperation operation) { + for (ImportedModel.Signature signature : model.signatures().values()) { + for (String inputName : signature.inputs().values()) { + if (inputName.equals(operation.name())) { + return true; + } + } + } + return false; + } + + private static boolean isSignatureOutput(ImportedModel model, IntermediateOperation operation) { + for (ImportedModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + if (outputName.equals(operation.name())) { + return true; + } + } + } + return false; + } + + /** + * Convert intermediate representation to Vespa ranking expressions. + */ + static void importExpressions(IntermediateGraph graph, ImportedModel model) { + for (ImportedModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + try { + Optional<TensorFunction> function = importExpression(graph.get(outputName), model); + if (!function.isPresent()) { + signature.skippedOutput(outputName, "No valid output function could be found."); + } + } + catch (IllegalArgumentException e) { + signature.skippedOutput(outputName, Exceptions.toMessageString(e)); + } + } + } + } + + private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) { + if (!operation.type().isPresent()) { + return Optional.empty(); + } + if (operation.isConstant()) { + return importConstant(operation, model); + } + importExpressionInputs(operation, model); + importRankingExpression(operation, model); + importArgumentExpression(operation, model); + importMacroExpression(operation, model); + + return operation.function(); + } + + private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) { + operation.inputs().forEach(input -> importExpression(input, model)); + } + + private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) { + String name = operation.vespaName(); + if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { + return operation.function(); + } + + Value value = operation.getConstantValue().orElseThrow(() -> + new IllegalArgumentException("Operation '" + operation.vespaName() + "' " + + "is constant but does not have a value.")); + if ( ! (value instanceof TensorValue)) { + return operation.function(); // scalar values are inserted directly into the expression + } + + Tensor tensor = value.asTensor(); + if (tensor.type().rank() == 0) { + model.smallConstant(name, tensor); + } else { + model.largeConstant(name, tensor); + } + return operation.function(); + } + + private static void importRankingExpression(IntermediateOperation operation, ImportedModel model) { + if (operation.function().isPresent()) { + String name = operation.name(); + if (!model.expressions().containsKey(name)) { + TensorFunction function = operation.function().get(); + + if (isSignatureOutput(model, operation)) { + OrderedTensorType operationType = operation.type().get(); + OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType); + if ( ! operationType.equals(standardNamingType)) { + List<String> renameFrom = operationType.dimensionNames(); + List<String> renameTo = standardNamingType.dimensionNames(); + function = new Rename(function, renameFrom, renameTo); + } + } + + try { + // We add all intermediate nodes imported as separate expressions. Only + // those referenced from the output will be used. We parse the + // TensorFunction here to convert it to a RankingExpression tree. + model.expression(name, new RankingExpression(name, function.toString())); + } + catch (ParseException e) { + throw new RuntimeException("Imported function " + function + + " cannot be parsed as a ranking expression", e); + } + } + } + } + + private static void importArgumentExpression(IntermediateOperation operation, ImportedModel model) { + if (operation.isInput()) { + // All inputs must have dimensions with standard naming convention: d0, d1, ... + OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get()); + model.argument(operation.vespaName(), standardNamingConvention.type()); + model.requiredMacro(operation.vespaName(), standardNamingConvention.type()); + } + } + + private static void importMacroExpression(IntermediateOperation operation, ImportedModel model) { + if (operation.macro().isPresent()) { + TensorFunction function = operation.macro().get(); + try { + model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString())); + } + catch (ParseException e) { + throw new RuntimeException("Tensorflow function " + function + + " cannot be parsed as a ranking expression", e); + } + } + } + + /** + * Add any import warnings to the signature in the ImportedModel. + */ + private static void reportWarnings(IntermediateGraph graph, ImportedModel model) { + for (ImportedModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + reportWarnings(graph.get(outputName), model); + } + } + } + + private static void reportWarnings(IntermediateOperation operation, ImportedModel model) { + for (String warning : operation.warnings()) { + model.defaultSignature().importWarning(warning); + } + for (IntermediateOperation input : operation.inputs()) { + reportWarnings(input, model); + } + } + + /** + * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type. + * This allows users to learn the exact types (including dimension order after renaming) of the Variables + * such that these can be converted and fed to a parent document independently of the rest of the model + * for fast model weight updates. + */ + private static void logVariableTypes(IntermediateGraph graph) { + for (IntermediateOperation operation : graph.operations()) { + if ( ! (operation instanceof Constant)) continue; + if ( ! operation.type().isPresent()) continue; // will not happen + log.info("Importing TensorFlow variable " + operation.name() + " as " + operation.vespaName() + + " of type " + operation.type().get()); + } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java new file mode 100644 index 00000000000..d3dd2a1d418 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java @@ -0,0 +1,30 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchlib.rankingexpression.integration.ml; + +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter; +import onnx.Onnx; + +import java.io.FileInputStream; +import java.io.IOException; + +/** + * Converts a ONNX model into a ranking expression and set of constants. + * + * @author lesters + */ +public class OnnxImporter extends ModelImporter { + + @Override + public ImportedModel importModel(String modelName, String modelPath) { + try (FileInputStream inputStream = new FileInputStream(modelPath)) { + Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); + IntermediateGraph graph = GraphImporter.importGraph(modelName, model); + return convertIntermediateGraphToModel(graph); + } catch (IOException e) { + throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e); + } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java new file mode 100644 index 00000000000..ff584559a83 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java @@ -0,0 +1,47 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.ml; + +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter; +import org.tensorflow.SavedModelBundle; + +import java.io.IOException; + +/** + * Converts a saved TensorFlow model into a ranking expression and set of constants. + * + * @author bratseth + * @author lesters + */ +public class TensorFlowImporter extends ModelImporter { + + /** + * Imports a saved TensorFlow model from a directory. + * The model should be saved as a .pbtxt or .pb file. + * The name of the model is taken as the db/pbtxt file name (not including the file ending). + * + * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_] + * @param modelDir the directory containing the TensorFlow model files to import + */ + public ImportedModel importModel(String modelName, String modelDir) { + try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { + return importModel(modelName, model); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); + } + } + + /** Imports a TensorFlow model */ + ImportedModel importModel(String modelName, SavedModelBundle model) { + try { + IntermediateGraph graph = GraphImporter.importGraph(modelName, model); + return convertIntermediateGraphToModel(graph); + } + catch (IOException e) { + throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e); + } + } + + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java index c5ac7ace0fc..e1294ec3e01 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java @@ -1,7 +1,8 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter; import com.yahoo.tensor.serialization.JsonFormat; import com.yahoo.yolean.Exceptions; import org.tensorflow.SavedModelBundle; @@ -24,7 +25,7 @@ public class VariableConverter { */ public static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) { try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) { - return JsonFormat.encode(TensorConverter.toVespaTensor(TensorFlowImporter.readVariable(tensorFlowVariableName, + return JsonFormat.encode(TensorConverter.toVespaTensor(GraphImporter.readVariable(tensorFlowVariableName, bundle), OrderedTensorType.fromSpec(orderedTypeSpec))); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java index 2524417cee0..38f1d2329e2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java @@ -1,7 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; import java.util.ArrayDeque; import java.util.ArrayList; @@ -47,7 +47,7 @@ public class DimensionRenamer { /** * Add a constraint between dimension names. */ - public void addConstraint(String from, String to, Constraint pred, OnnxOperation operation) { + public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) { Arc arc = new Arc(from, to, operation); Arc opposite = arc.opposite(); constraints.put(arc, pred); @@ -175,9 +175,9 @@ public class DimensionRenamer { private final String from; private final String to; - private final OnnxOperation operation; + private final IntermediateOperation operation; - Arc(String from, String to, OnnxOperation operation) { + Arc(String from, String to, IntermediateOperation operation) { this.from = from; this.to = to; this.operation = operation; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java new file mode 100644 index 00000000000..39a8b211d09 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java @@ -0,0 +1,107 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchlib.rankingexpression.integration.ml.importer; + +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * Holds an intermediate representation of an imported ONNX or TensorFlow + * graph. After this intermediate representation is constructed, it is used to + * simplify and optimize the computational graph and then converted into the + * final ImportedModel that holds the Vespa ranking expressions for the model. + * + * @author lesters + */ +public class IntermediateGraph { + + private final String modelName; + private final Map<String, IntermediateOperation> index = new HashMap<>(); + private final Map<String, GraphSignature> signatures = new HashMap<>(); + + private static class GraphSignature { + final Map<String, String> inputs = new HashMap<>(); + final Map<String, String> outputs = new HashMap<>(); + } + + public IntermediateGraph(String modelName) { + this.modelName = modelName; + } + + public String name() { + return modelName; + } + + public IntermediateOperation put(String key, IntermediateOperation operation) { + return index.put(key, operation); + } + + public IntermediateOperation get(String key) { + return index.get(key); + } + + public Set<String> signatures() { + return signatures.keySet(); + } + + public Map<String, String> inputs(String signature) { + return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).inputs; + } + + public Map<String, String> outputs(String signature) { + return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).outputs; + } + + public String defaultSignature() { + return "default"; + } + + public boolean alreadyImported(String key) { + return index.containsKey(key); + } + + public Collection<IntermediateOperation> operations() { + return index.values(); + } + + public void optimize() { + renameDimensions(); + } + + /** + * Find dimension names to avoid excessive renaming while evaluating the model. + */ + private void renameDimensions() { + DimensionRenamer renamer = new DimensionRenamer(); + for (String signature : signatures()) { + for (String output : outputs(signature).values()) { + addDimensionNameConstraints(index.get(output), renamer); + } + } + renamer.solve(); + for (String signature : signatures()) { + for (String output : outputs(signature).values()) { + renameDimensions(index.get(output), renamer); + } + } + } + + private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) { + if (operation.type().isPresent()) { + operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); + operation.addDimensionNameConstraints(renamer); + } + } + + private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) { + if (operation.type().isPresent()) { + operation.inputs().forEach(input -> renameDimensions(input, renamer)); + operation.renameDimensions(renamer); + } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java index 812e9b8d678..209d73a9f38 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java @@ -1,9 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer; import com.yahoo.tensor.TensorType; -import onnx.Onnx; +import com.yahoo.tensor.TensorTypeParser; import java.util.ArrayList; import java.util.Collections; @@ -13,9 +13,9 @@ import java.util.stream.Collectors; /** * A Vespa tensor type is ordered by the lexicographical ordering of dimension - * names. ONNX tensors have an explicit ordering of their dimensions. + * names. Imported tensors have an explicit ordering of their dimensions. * During import, we need to track the Vespa dimension that matches the - * corresponding ONNX dimension as the ordering can change after + * corresponding imported dimension as the ordering can change after * dimension renaming. That is the purpose of this class. * * @author lesters @@ -25,14 +25,14 @@ public class OrderedTensorType { private final TensorType type; private final List<TensorType.Dimension> dimensions; - private final long[] innerSizesOnnx; + private final long[] innerSizesOriginal; private final long[] innerSizesVespa; private final int[] dimensionMap; private OrderedTensorType(List<TensorType.Dimension> dimensions) { this.dimensions = Collections.unmodifiableList(dimensions); this.type = new TensorType.Builder(dimensions).build(); - this.innerSizesOnnx = new long[dimensions.size()]; + this.innerSizesOriginal = new long[dimensions.size()]; this.innerSizesVespa = new long[dimensions.size()]; this.dimensionMap = createDimensionMap(); } @@ -54,10 +54,10 @@ public class OrderedTensorType { if (numDimensions == 0) { return null; } - innerSizesOnnx[numDimensions - 1] = 1; + innerSizesOriginal[numDimensions - 1] = 1; innerSizesVespa[numDimensions - 1] = 1; for (int i = numDimensions - 1; --i >= 0; ) { - innerSizesOnnx[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOnnx[i+1]; + innerSizesOriginal[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOriginal[i+1]; innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1]; } int[] mapping = new int[numDimensions]; @@ -74,11 +74,15 @@ public class OrderedTensorType { return mapping; } + public int dimensionMap(int originalIndex) { + return dimensionMap[originalIndex]; + } + /** - * When dimension ordering between Vespa and Onnx differs, i.e. + * When dimension ordering between Vespa and imported differs, i.e. * after dimension renaming, use the dimension map to read in values * so that they are correctly laid out in memory for Vespa. - * Used when importing tensors from Onnx. + * Used when importing tensors. */ public int toDirectIndex(int index) { if (dimensions.size() == 0) { @@ -90,9 +94,9 @@ public class OrderedTensorType { int directIndex = 0; long rest = index; for (int i = 0; i < dimensions.size(); ++i) { - long address = rest / innerSizesOnnx[i]; + long address = rest / innerSizesOriginal[i]; directIndex += innerSizesVespa[dimensionMap[i]] * address; - rest %= innerSizesOnnx[i]; + rest %= innerSizesOriginal[i]; } return directIndex; } @@ -116,22 +120,6 @@ public class OrderedTensorType { return true; } - public void verifyType(Onnx.TypeProto typeProto) { - Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape(); - if (shape != null) { - if (shape.getDimCount() != type.rank()) { - throw new IllegalArgumentException("Onnx shape of does not match Vespa shape"); - } - for (int onnxIndex = 0; onnxIndex < dimensions.size(); ++onnxIndex) { - int vespaIndex = dimensionMap[onnxIndex]; - Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex); - TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex); - if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) { - throw new IllegalArgumentException("TensorFlow dimensions of does not match Vespa dimensions"); - } - } - } - } public OrderedTensorType rename(DimensionRenamer renamer) { List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size()); for (TensorType.Dimension dimension : dimensions) { @@ -151,18 +139,13 @@ public class OrderedTensorType { return new OrderedTensorType(renamedDimensions); } - public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) { - return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ... - } - - public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) { - Onnx.TensorShapeProto shape = type.getTensorType().getShape(); - Builder builder = new Builder(shape); - for (int i = 0; i < shape.getDimCount(); ++ i) { + public OrderedTensorType rename(String dimensionPrefix) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + for (int i = 0; i < dimensions.size(); ++ i) { String dimensionName = dimensionPrefix + i; - Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); - if (onnxDimension.getDimValue() >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue())); + Optional<Long> dimSize = dimensions.get(i).size(); + if (dimSize.isPresent() && dimSize.get() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, dimSize.get())); } else { builder.add(TensorType.Dimension.indexed(dimensionName)); } @@ -170,13 +153,13 @@ public class OrderedTensorType { return builder.build(); } - public static OrderedTensorType fromOnnxType(List<Long> dims, String dimensionPrefix) { - Builder builder = new Builder(); - for (int i = 0; i < dims.size(); ++ i) { - String dimensionName = dimensionPrefix + i; - Long dimSize = dims.get(i); - if (dimSize >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, dimSize)); + public static OrderedTensorType standardType(OrderedTensorType type) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + for (int i = 0; i < type.dimensions().size(); ++ i) { + TensorType.Dimension dim = type.dimensions().get(i); + String dimensionName = "d" + i; + if (dim.size().isPresent() && dim.size().get() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get())); } else { builder.add(TensorType.Dimension.indexed(dimensionName)); } @@ -184,13 +167,46 @@ public class OrderedTensorType { return builder.build(); } - public static OrderedTensorType standardType(OrderedTensorType type) { - Builder builder = new Builder(); - for (int i = 0; i < type.dimensions().size(); ++ i) { - TensorType.Dimension dim = type.dimensions().get(i); - String dimensionName = "d" + i; - if (dim.size().isPresent() && dim.size().get() >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get())); + public static Long tensorSize(TensorType type) { + Long size = 1L; + for (TensorType.Dimension dimension : type.dimensions()) { + size *= dimensionSize(dimension); + } + return size; + } + + public static Long dimensionSize(TensorType.Dimension dim) { + return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size")); + } + + /** + * Returns a string representation of this: A standard tensor type string where dimensions + * are listed in the order of this rather than in the natural order of their names. + */ + @Override + public String toString() { + return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")"; + } + + /** + * Creates an instance from the string representation of this: A standard tensor type string + * where dimensions are listed in the order of this rather than the natural order of their names. + */ + public static OrderedTensorType fromSpec(String typeSpec) { + return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec)); + } + + public static OrderedTensorType fromDimensionList(List<Long> dims) { + return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ... + } + + public static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + for (int i = 0; i < dims.size(); ++ i) { + String dimensionName = dimensionPrefix + i; + Long dimSize = dims.get(i); + if (dimSize >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, dimSize)); } else { builder.add(TensorType.Dimension.indexed(dimensionName)); } @@ -200,45 +216,13 @@ public class OrderedTensorType { public static class Builder { - private final Onnx.TensorShapeProto shape; private final List<TensorType.Dimension> dimensions; - public Builder(Onnx.TensorShapeProto shape) { - this.shape = shape; - this.dimensions = new ArrayList<>(shape.getDimCount()); - } - public Builder() { - this.shape = null; this.dimensions = new ArrayList<>(); } public Builder add(TensorType.Dimension vespaDimension) { - if (shape != null) { - int index = dimensions.size(); - Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(index); - long size = onnxDimension.getDimValue(); - if (size >= 0) { - if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) { - throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " + - "dimension types"); - } - if (!vespaDimension.size().isPresent()) { - throw new IllegalArgumentException("Tensor dimension is indexed bound but does " + - "not have a size"); - } - if (vespaDimension.size().get() != size) { - throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " + - "dimension sizes. TensorFlow: " + size + " Vespa: " + - vespaDimension.size().get()); - } - } else { - if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) { - throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " + - "dimension types"); - } - } - } this.dimensions.add(vespaDimension); return this; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java new file mode 100644 index 00000000000..3fe92440cae --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java @@ -0,0 +1,216 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx; + +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape; +import com.yahoo.tensor.functions.ScalarFunctions; +import onnx.Onnx; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * Converts an ONNX graph to a Vespa IntermediateGraph which is the basis + * for generating Vespa ranking expressions. + * + * @author lesters + */ +public class GraphImporter { + + public static IntermediateOperation mapOperation(Onnx.NodeProto node, + List<IntermediateOperation> inputs, + IntermediateGraph graph) { + String nodeName = node.getName(); + String modelName = graph.name(); + + switch (node.getOpType().toLowerCase()) { + case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); + case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); + case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); + case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin()); + case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan()); + case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil()); + case "concat": return new ConcatV2(modelName, nodeName, inputs); + case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); + case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); + case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); + case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal()); + case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); + case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); + case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater()); + case "identity": return new Identity(modelName, nodeName, inputs); + case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less()); + case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log()); + case "matmul": return new MatMul(modelName, nodeName, inputs); + case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max()); + case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min()); + case "mean": return new Join(modelName, nodeName, inputs, ScalarFunctions.mean()); + case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); + case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg()); + case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow()); + case "reshape": return new Reshape(modelName, nodeName, inputs); + case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal()); + case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); + case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); + case "shape": return new Shape(modelName, nodeName, inputs); + case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin()); + case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); + case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); + case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan()); + case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); + } + + IntermediateOperation op = new NoOp(modelName, node.getName(), inputs); + op.warning("Operation '" + node.getOpType() + "' is currently not implemented"); + return op; + } + + public static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) { + Onnx.GraphProto onnxGraph = model.getGraph(); + + IntermediateGraph intermediateGraph = new IntermediateGraph(modelName); + importOperations(onnxGraph, intermediateGraph); + verifyOutputTypes(onnxGraph, intermediateGraph); + + return intermediateGraph; + } + + private static void importOperations(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) { + for (Onnx.ValueInfoProto valueInfo : onnxGraph.getOutputList()) { + importOperation(valueInfo.getName(), onnxGraph, intermediateGraph); + } + } + + private static IntermediateOperation importOperation(String name, + Onnx.GraphProto onnxGraph, + IntermediateGraph intermediateGraph) { + if (intermediateGraph.alreadyImported(name)) { + return intermediateGraph.get(name); + } + IntermediateOperation operation; + if (isArgumentTensor(name, onnxGraph)) { + Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph); + if (valueInfoProto == null) + throw new IllegalArgumentException("Could not find argument tensor: " + name); + OrderedTensorType type = TypeConverter.fromOnnxType(valueInfoProto.getType()); + operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type); + + intermediateGraph.inputs(intermediateGraph.defaultSignature()) + .put(IntermediateOperation.namePartOf(name), operation.vespaName()); + + } else if (isConstantTensor(name, onnxGraph)) { + Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph); + OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList()); + operation = new Constant(intermediateGraph.name(), name, defaultType); + operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); + + } else { + Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph); + List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph); + operation = mapOperation(node, inputs, intermediateGraph); + + if (isOutputNode(name, onnxGraph)) { + intermediateGraph.outputs(intermediateGraph.defaultSignature()) + .put(IntermediateOperation.namePartOf(name), operation.vespaName()); + } + } + intermediateGraph.put(operation.vespaName(), operation); + + return operation; + } + + private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) { + Onnx.ValueInfoProto value = getArgumentTensor(name, graph); + Onnx.TensorProto tensor = getConstantTensor(name, graph); + return value != null && tensor == null; + } + + private static boolean isConstantTensor(String name, Onnx.GraphProto graph) { + Onnx.ValueInfoProto value = getArgumentTensor(name, graph); + Onnx.TensorProto tensor = getConstantTensor(name, graph); + return value != null && tensor != null; + } + + private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) { + for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) { + if (valueInfo.getName().equals(name)) { + return valueInfo; + } + } + return null; + } + + private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) { + for (Onnx.TensorProto tensorProto : graph.getInitializerList()) { + if (tensorProto.getName().equals(name)) { + return tensorProto; + } + } + return null; + } + + private static boolean isOutputNode(String name, Onnx.GraphProto graph) { + return getOutputNode(name, graph) != null; + } + + private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) { + for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) { + if (valueInfo.getName().equals(name)) { + return valueInfo; + } + String nodeName = IntermediateOperation.namePartOf(valueInfo.getName()); + if (nodeName.equals(name)) { + return valueInfo; + } + } + return null; + } + + private static List<IntermediateOperation> importOperationInputs(Onnx.NodeProto node, + Onnx.GraphProto onnxGraph, + IntermediateGraph intermediateGraph) { + return node.getInputList().stream() + .map(nodeName -> importOperation(nodeName, onnxGraph, intermediateGraph)) + .collect(Collectors.toList()); + } + + private static void verifyOutputTypes(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) { + for (String outputName : intermediateGraph.outputs(intermediateGraph.defaultSignature()).values()) { + IntermediateOperation operation = intermediateGraph.get(outputName); + Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, onnxGraph); + OrderedTensorType type = operation.type().orElseThrow( + () -> new IllegalArgumentException("Output of '" + outputName + "' has no type.")); + TypeConverter.verifyType(onnxNode.getType(), type); + } + } + + private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) { + boolean hasPortNumber = nodeName.contains(":"); + for (Onnx.NodeProto node : graph.getNodeList()) { + if (hasPortNumber) { + for (String outputName : node.getOutputList()) { + if (outputName.equals(nodeName)) { + return node; + } + } + } else if (node.getName().equals(nodeName)) { + return node; + } + } + throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph"); + } +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java index 2912db03b5f..18856d4a25f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java @@ -1,17 +1,16 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx; import com.google.protobuf.ByteString; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; import onnx.Onnx; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; -import java.util.List; /** * Converts Onnx tensors into Vespa tensors. @@ -29,7 +28,6 @@ public class TensorConverter { return builder.build(); } - /* todo: support more types */ private static Values readValuesOf(Onnx.TensorProto tensorProto) { if (tensorProto.hasRawData()) { switch (tensorProto.getDataType()) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java new file mode 100644 index 00000000000..715c55d8323 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java @@ -0,0 +1,52 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx; + +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import onnx.Onnx; + +/** + * Converts and verifies ONNX tensor types into Vespa tensor types. + * + * @author lesters + */ +public class TypeConverter { + + public static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) { + Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape(); + if (shape != null) { + if (shape.getDimCount() != type.rank()) { + throw new IllegalArgumentException("Onnx shape of does not match Vespa shape"); + } + for (int onnxIndex = 0; onnxIndex < type.dimensions().size(); ++onnxIndex) { + int vespaIndex = type.dimensionMap(onnxIndex); + Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex); + TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex); + if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) { + throw new IllegalArgumentException("Onnx dimensions of does not match Vespa dimensions"); + } + } + } + } + + public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) { + return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ... + } + + public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) { + Onnx.TensorShapeProto shape = type.getTensorType().getShape(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + for (int i = 0; i < shape.getDimCount(); ++ i) { + String dimensionName = dimensionPrefix + i; + Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); + if (onnxDimension.getDimValue() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue())); + } else { + builder.add(TensorType.Dimension.indexed(dimensionName)); + } + } + return builder.build(); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java index 1619c11427a..7fc2aae87d1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java @@ -1,28 +1,29 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; +import java.util.Collections; import java.util.List; -public class Placeholder extends TensorFlowOperation { +public class Argument extends IntermediateOperation { private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ... - public Placeholder(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); - standardNamingType = OrderedTensorType.fromTensorFlowType(node); + public Argument(String modelName, String nodeName, OrderedTensorType type) { + super(modelName, nodeName, Collections.emptyList()); + this.type = type.rename(vespaName() + "_"); + standardNamingType = OrderedTensorType.standardType(type); } @Override protected OrderedTensorType lazyGetType() { - return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_"); + return type; } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java index 4f5d61d75f9..1b8c62fe0e9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java @@ -1,38 +1,37 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.Optional; -public class ConcatV2 extends TensorFlowOperation { +public class ConcatV2 extends IntermediateOperation { private String concatDimensionName; - public ConcatV2(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public ConcatV2(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override protected OrderedTensorType lazyGetType() { - if (!inputs.stream().map(TensorFlowOperation::type).allMatch(Optional::isPresent)) { + if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) { return null; } - TensorFlowOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input + IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input if (!concatDimOp.getConstantValue().isPresent()) { - throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " + + throw new IllegalArgumentException("ConcatV2 in " + name + ": " + "concat dimension must be a constant."); } Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor(); if (concatDimTensor.type().rank() != 0) { - throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " + + throw new IllegalArgumentException("ConcatV2 in " + name + ": " + "concat dimension must be a scalar."); } @@ -44,7 +43,7 @@ public class ConcatV2 extends TensorFlowOperation { for (int i = 1; i < inputs.size() - 1; ++i) { OrderedTensorType bType = inputs.get(i).type().get(); if (bType.rank() != aType.rank()) { - throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " + + throw new IllegalArgumentException("ConcatV2 in " + name + ": " + "inputs must have save rank."); } for (int j = 0; j < aType.rank(); ++j) { @@ -53,13 +52,13 @@ public class ConcatV2 extends TensorFlowOperation { if (j == concatDim) { concatDimSize += dimSizeB; } else if (dimSizeA != dimSizeB) { - throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " + + throw new IllegalArgumentException("ConcatV2 in " + name + ": " + "input dimension " + j + " differs in input tensors."); } } } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); int dimensionIndex = 0; for (TensorType.Dimension dimension : aType.dimensions()) { if (dimensionIndex == concatDim) { @@ -75,7 +74,7 @@ public class ConcatV2 extends TensorFlowOperation { @Override protected TensorFunction lazyGetFunction() { - if (!inputs.stream().map(TensorFlowOperation::function).allMatch(Optional::isPresent)) { + if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) { return null; } TensorFunction result = inputs.get(0).function().get(); @@ -88,7 +87,7 @@ public class ConcatV2 extends TensorFlowOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!inputs.stream().map(TensorFlowOperation::type).allMatch(Optional::isPresent)) { + if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) { return; } OrderedTensorType a = inputs.get(0).type().get(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java index 718e2a4b3c2..3c0f8569c47 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java @@ -1,36 +1,38 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.Optional; -public class Const extends TensorFlowOperation { +public class Const extends IntermediateOperation { - public Const(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + private final AttributeMap attributeMap; + + public Const(String modelName, + String nodeName, + List<IntermediateOperation> inputs, + AttributeMap attributeMap, + OrderedTensorType type) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + this.type = type.rename(vespaName() + "_"); setConstantValue(value()); } @Override protected OrderedTensorType lazyGetType() { - return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_"); + return type; } @Override @@ -55,7 +57,7 @@ public class Const extends TensorFlowOperation { /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ @Override public String vespaName() { - return modelName() + "_" + super.vespaName(); + return modelName + "_" + super.vespaName(); } @Override @@ -77,24 +79,11 @@ public class Const extends TensorFlowOperation { } private Value value() { - if ( ! node.getAttrMap().containsKey("value")) { - throw new IllegalArgumentException("Node '" + node.getName() + "' of type " + - "const has missing 'value' attribute"); - } - AttrValue attrValue = node.getAttrMap().get("value"); - if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) { - return new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type().get().type())); - } - if (attrValue.getValueCase() == AttrValue.ValueCase.B) { - return new BooleanValue(attrValue.getB()); - } - if (attrValue.getValueCase() == AttrValue.ValueCase.I) { - return new DoubleValue(attrValue.getI()); - } - if (attrValue.getValueCase() == AttrValue.ValueCase.F) { - return new DoubleValue(attrValue.getF()); + Optional<Value> value = attributeMap.get("value", type); + if ( ! value.isPresent()) { + throw new IllegalArgumentException("Node '" + name + "' of type " + + "const has missing or non-recognized 'value' attribute"); } - throw new IllegalArgumentException("Requesting value of constant in " + - node.getName() + " but type is not recognized."); + return value.get(); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java index 13043a61a8e..5e4abeaa234 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java @@ -1,38 +1,34 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; -import onnx.Onnx; import java.util.Collections; import java.util.Optional; -public class Constant extends OnnxOperation { +public class Constant extends IntermediateOperation { - final String modelName; - final Onnx.TensorProto tensorProto; + private final String modelName; - public Constant(String modelName, Onnx.TensorProto tensorProto) { - super(null, Collections.emptyList()); + public Constant(String modelName, String nodeName, OrderedTensorType type) { + super(modelName, nodeName, Collections.emptyList()); this.modelName = modelName; - this.tensorProto = tensorProto; + this.type = type.rename(vespaName() + "_"); } /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ @Override public String vespaName() { - return modelName + "_" + vespaName(tensorProto.getName()); + return modelName + "_" + vespaName(name); } @Override protected OrderedTensorType lazyGetType() { - return OrderedTensorType.fromOnnxType(tensorProto.getDimsList(), vespaName() + "_"); + return type; } @Override @@ -40,9 +36,14 @@ public class Constant extends OnnxOperation { return null; // will be added by function() since this is constant. } + /** + * Constant values are sent in via the constantValueFunction, as the + * dimension names and thus the data layout depends on the dimension + * renaming which happens after the conversion to intermediate graph. + */ @Override public Optional<Value> getConstantValue() { - return Optional.of(new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); + return Optional.ofNullable(constantValueFunction).map(func -> func.apply(type)); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java index 2d0f4c7042b..742ed8b89ab 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java @@ -1,9 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; @@ -12,18 +12,17 @@ import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Generate; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.List; import java.util.Optional; -public class ExpandDims extends TensorFlowOperation { +public class ExpandDims extends IntermediateOperation { private List<String> expandDimensions; - public ExpandDims(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public ExpandDims(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override @@ -32,14 +31,14 @@ public class ExpandDims extends TensorFlowOperation { return null; } - TensorFlowOperation axisOperation = inputs().get(1); + IntermediateOperation axisOperation = inputs().get(1); if (!axisOperation.getConstantValue().isPresent()) { - throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " + + throw new IllegalArgumentException("ExpandDims in " + name + ": " + "axis must be a constant."); } Tensor axis = axisOperation.getConstantValue().get().asTensor(); if (axis.type().rank() != 0) { - throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " + + throw new IllegalArgumentException("ExpandDims in " + name + ": " + "axis argument must be a scalar."); } @@ -49,7 +48,7 @@ public class ExpandDims extends TensorFlowOperation { dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); expandDimensions = new ArrayList<>(); int dimensionIndex = 0; for (TensorType.Dimension dimension : inputType.dimensions()) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java index 1408e7e04f0..d29bd4b7a9e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java @@ -1,22 +1,21 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.List; -public class Identity extends TensorFlowOperation { +public class Identity extends IntermediateOperation { - public Identity(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public Identity(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ @Override public String vespaName() { - return modelName() + "_" + super.vespaName(); + return modelName + "_" + super.vespaName(); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java index 3687bba8b85..43de29cedd5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java @@ -1,17 +1,16 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; -import com.yahoo.searchlib.rankingexpression.RankingExpression; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; + import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.Collections; @@ -20,43 +19,40 @@ import java.util.Optional; import java.util.function.Function; /** - * Wraps a TensorFlow node and produces the respective Vespa tensor operation. - * During import, a graph of these operations are constructed. Then, the - * types are used to deduce sensible dimension names using the - * DimensionRenamer. After the types have been renamed, the proper - * Vespa expressions can be extracted. + * Wraps an imported operation node and produces the respective Vespa tensor + * operation. During import, a graph of these operations are constructed. Then, + * the types are used to deduce sensible dimension names using the + * DimensionRenamer. After the types have been renamed, the proper Vespa + * expressions can be extracted. * * @author lesters */ -public abstract class TensorFlowOperation { - - protected final static String MACRO_PREFIX = "tf_macro_"; +public abstract class IntermediateOperation { - private final String modelName; + private final static String MACRO_PREFIX = "imported_ml_macro_"; - protected final NodeDef node; - protected final int port; - protected final List<TensorFlowOperation> inputs; - protected final List<TensorFlowOperation> outputs = new ArrayList<>(); - protected final List<String> importWarnings = new ArrayList<>(); + protected final String name; + protected final String modelName; + protected final List<IntermediateOperation> inputs; + protected final List<IntermediateOperation> outputs = new ArrayList<>(); protected OrderedTensorType type; protected TensorFunction function; protected TensorFunction macro = null; + private final List<String> importWarnings = new ArrayList<>(); private Value constantValue = null; - private List<TensorFlowOperation> controlInputs = Collections.emptyList(); + private List<IntermediateOperation> controlInputs = Collections.emptyList(); - TensorFlowOperation(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + protected Function<OrderedTensorType, Value> constantValueFunction = null; + + IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) { + this.name = name; this.modelName = modelName; - this.node = node; - this.port = port; this.inputs = Collections.unmodifiableList(inputs); this.inputs.forEach(i -> i.outputs.add(this)); } - protected String modelName() { return modelName; } - protected abstract OrderedTensorType lazyGetType(); protected abstract TensorFunction lazyGetFunction(); @@ -65,9 +61,6 @@ public abstract class TensorFlowOperation { if (type == null) { type = lazyGetType(); } - if (type != null) { - type.verifyType(node); - } return Optional.ofNullable(type); } @@ -87,14 +80,14 @@ public abstract class TensorFlowOperation { return Optional.ofNullable(function); } - /** Return TensorFlow node */ - public NodeDef node() { return node; } + /** Returns original name of this operation node */ + public String name() { return name; } /** Return unmodifiable list of inputs */ - public List<TensorFlowOperation> inputs() { return inputs; } + public List<IntermediateOperation> inputs() { return inputs; } /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */ - public List<TensorFlowOperation> outputs() { return Collections.unmodifiableList(outputs); } + public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); } /** Returns a Vespa ranking expression that should be added as a macro */ public Optional<TensorFunction> macro() { return Optional.ofNullable(macro); } @@ -109,22 +102,34 @@ public abstract class TensorFlowOperation { public boolean isInput() { return false; } /** Return true if this node is constant */ - public boolean isConstant() { return inputs.stream().allMatch(TensorFlowOperation::isConstant); } + public boolean isConstant() { return inputs.stream().allMatch(IntermediateOperation::isConstant); } /** Sets the constant value */ public void setConstantValue(Value value) { constantValue = value; } /** Gets the constant value if it exists */ - public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); } + public Optional<Value> getConstantValue() { + if (constantValue != null) { + return Optional.of(constantValue); + } + if (constantValueFunction != null) { + return Optional.of(constantValueFunction.apply(type)); + } + return Optional.empty(); + } + + /** Set the constant value function */ + public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { this.constantValueFunction = func; } /** Sets the external control inputs */ - public void setControlInputs(List<TensorFlowOperation> inputs) { this.controlInputs = inputs; } + public void setControlInputs(List<IntermediateOperation> inputs) { this.controlInputs = inputs; } /** Retrieve the control inputs for this operation */ - public List<TensorFlowOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); } + public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); } /** Retrieve the valid Vespa name of this node */ - public String vespaName() { return node.getName() != null ? node.getName().replace('/', '_') : null; } + public String vespaName() { return vespaName(name); } + public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; } /** Retrieve the valid Vespa name of this node if it is a macro */ public String macroName() { return vespaName() != null ? MACRO_PREFIX + modelName + "_" + vespaName() : null; } @@ -135,23 +140,48 @@ public abstract class TensorFlowOperation { /** Set an input warning */ public void warning(String warning) { importWarnings.add(warning); } - boolean verifyInputs(int expected, Function<TensorFlowOperation, Optional<?>> func) { - if (!controlInputs.stream().map(func).allMatch(Optional::isPresent)) { - return false; - } + boolean verifyInputs(int expected, Function<IntermediateOperation, Optional<?>> func) { if (inputs.size() != expected) { throw new IllegalArgumentException("Expected " + expected + " inputs " + - "for '" + node.getName() + "', got " + inputs.size()); + "for '" + name + "', got " + inputs.size()); } return inputs.stream().map(func).allMatch(Optional::isPresent); } boolean allInputTypesPresent(int expected) { - return verifyInputs(expected, TensorFlowOperation::type); + return verifyInputs(expected, IntermediateOperation::type); } boolean allInputFunctionsPresent(int expected) { - return verifyInputs(expected, TensorFlowOperation::function); + return verifyInputs(expected, IntermediateOperation::function); + } + + /** + * A method signature input and output has the form name:index. + * This returns the name part without the index. + */ + public static String namePartOf(String name) { + name = name.startsWith("^") ? name.substring(1) : name; + return name.split(":")[0]; + } + + /** + * This return the output index part. Indexes are used for nodes with + * multiple outputs. + */ + public static int indexPartOf(String name) { + int i = name.indexOf(":"); + return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1)); + } + + /** + * An interface mapping operation attributes to Vespa Values. + * Adapter for differences in ONNX/TensorFlow. + */ + public interface AttributeMap { + Optional<Value> get(String key); + Optional<Value> get(String key, OrderedTensorType type); + Optional<List<Value>> getList(String key); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java index fe2004a528d..8413ed74118 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java @@ -1,24 +1,22 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.TensorFunction; -import onnx.Onnx; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.function.DoubleBinaryOperator; -public class Join extends OnnxOperation { +public class Join extends IntermediateOperation { private final DoubleBinaryOperator operator; - public Join(Onnx.NodeProto node, List<OnnxOperation> inputs, DoubleBinaryOperator operator) { - super(node, inputs); + public Join(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) { + super(modelName, nodeName, inputs); this.operator = operator; } @@ -61,8 +59,8 @@ public class Join extends OnnxOperation { return null; } - OnnxOperation a = largestInput(); - OnnxOperation b = smallestInput(); + IntermediateOperation a = largestInput(); + IntermediateOperation b = smallestInput(); List<String> aDimensionsToReduce = new ArrayList<>(); List<String> bDimensionsToReduce = new ArrayList<>(); @@ -107,13 +105,13 @@ public class Join extends OnnxOperation { } } - private OnnxOperation largestInput() { + private IntermediateOperation largestInput() { OrderedTensorType a = inputs.get(0).type().get(); OrderedTensorType b = inputs.get(1).type().get(); return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1); } - private OnnxOperation smallestInput() { + private IntermediateOperation smallestInput() { OrderedTensorType a = inputs.get(0).type().get(); OrderedTensorType b = inputs.get(1).type().get(); return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java index c015f5ecba8..f54ae83052f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java @@ -1,20 +1,19 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.Optional; import java.util.function.DoubleUnaryOperator; -public class Map extends TensorFlowOperation { +public class Map extends IntermediateOperation { private final DoubleUnaryOperator operator; - public Map(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleUnaryOperator operator) { - super(modelName, node, inputs, port); + public Map(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleUnaryOperator operator) { + super(modelName, nodeName, inputs); this.operator = operator; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java index 1b388e2ae89..52e223f9518 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java @@ -1,21 +1,18 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; -import onnx.Onnx; -import java.util.Collections; import java.util.List; import java.util.Optional; -import java.util.function.DoubleBinaryOperator; -public class MatMul extends OnnxOperation { +public class MatMul extends IntermediateOperation { - public MatMul(Onnx.NodeProto node, List<OnnxOperation> inputs) { - super(node, inputs); + public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java index 3eba872c6a0..95a77c07590 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java @@ -1,9 +1,10 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; @@ -13,20 +14,20 @@ import com.yahoo.tensor.functions.Generate; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Optional; -public class Mean extends TensorFlowOperation { +public class Mean extends IntermediateOperation { + private final AttributeMap attributeMap; private List<String> reduceDimensions; - public Mean(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public Mean(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; } @Override @@ -34,9 +35,9 @@ public class Mean extends TensorFlowOperation { if (!allInputTypesPresent(2)) { return null; } - TensorFlowOperation reductionIndices = inputs.get(1); + IntermediateOperation reductionIndices = inputs.get(1); if (!reductionIndices.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Mean in " + node.getName() + ": " + + throw new IllegalArgumentException("Mean in " + name + ": " + "reduction indices must be a constant."); } Tensor indices = reductionIndices.getConstantValue().get().asTensor(); @@ -54,7 +55,7 @@ public class Mean extends TensorFlowOperation { return reducedType(inputType, shouldKeepDimensions()); } - // todo: optimization: if keepDims and one reduce dimension that has size 1: same as identity. + // optimization: if keepDims and one reduce dimension that has size 1: same as identity. @Override protected TensorFunction lazyGetFunction() { @@ -93,12 +94,12 @@ public class Mean extends TensorFlowOperation { } private boolean shouldKeepDimensions() { - AttrValue keepDimsAttr = node.getAttrMap().get("keep_dims"); - return keepDimsAttr != null && keepDimsAttr.getB(); + Optional<Value> keepDims = attributeMap.get("keep_dims"); + return keepDims.isPresent() && keepDims.get().asBoolean(); } private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); for (TensorType.Dimension dimension: inputType.type().dimensions()) { if (!reduceDimensions.contains(dimension.name())) { builder.add(dimension); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java index 4c95e67e184..9d9eca47b1c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java @@ -1,21 +1,20 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.List; -public class Merge extends TensorFlowOperation { +public class Merge extends IntermediateOperation { - public Merge(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public Merge(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override protected OrderedTensorType lazyGetType() { - for (TensorFlowOperation operation : inputs) { + for (IntermediateOperation operation : inputs) { if (operation.type().isPresent()) { return operation.type().get(); } @@ -25,7 +24,7 @@ public class Merge extends TensorFlowOperation { @Override protected TensorFunction lazyGetFunction() { - for (TensorFlowOperation operation : inputs) { + for (IntermediateOperation operation : inputs) { if (operation.function().isPresent()) { return operation.function().get(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java new file mode 100644 index 00000000000..19ba146492c --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java @@ -0,0 +1,26 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.Collections; +import java.util.List; + +public class NoOp extends IntermediateOperation { + + public NoOp(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, Collections.emptyList()); // don't propagate inputs + } + + @Override + protected OrderedTensorType lazyGetType() { + return null; + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java index 65ce7f00e34..9299ae9be12 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java @@ -1,17 +1,16 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.Optional; -public class PlaceholderWithDefault extends TensorFlowOperation { +public class PlaceholderWithDefault extends IntermediateOperation { - public PlaceholderWithDefault(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public PlaceholderWithDefault(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java index e7d90e5fc1f..e91c2305f7d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java @@ -1,10 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; @@ -19,19 +18,18 @@ import com.yahoo.tensor.functions.Generate; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; -import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize; +import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize; -public class Reshape extends TensorFlowOperation { +public class Reshape extends IntermediateOperation { - public Reshape(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override @@ -39,15 +37,15 @@ public class Reshape extends TensorFlowOperation { if (!allInputTypesPresent(2)) { return null; } - TensorFlowOperation newShape = inputs.get(1); + IntermediateOperation newShape = inputs.get(1); if (!newShape.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Reshape in " + node.getName() + ": " + + throw new IllegalArgumentException("Reshape in " + name + ": " + "shape input must be a constant."); } Tensor shape = newShape.getConstantValue().get().asTensor(); OrderedTensorType inputType = inputs.get(0).type().get(); - OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(node); + OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(); int dimensionIndex = 0; for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { Tensor.Cell cell = cellIterator.next(); @@ -124,7 +122,7 @@ public class Reshape extends TensorFlowOperation { operators.add(0, ArithmeticOperator.MULTIPLY); children.add(0, new ConstantNode(new DoubleValue(size))); } - size *= TensorConverter.dimensionSize(dimension); + size *= OrderedTensorType.dimensionSize(dimension); if (i > 0) { operators.add(0, ArithmeticOperator.PLUS); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java index 5fdcb5a695f..927a4a368f9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java @@ -1,24 +1,23 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.function.DoubleBinaryOperator; -import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.dimensionSize; -import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize; +import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.dimensionSize; +import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize; -public class Select extends TensorFlowOperation { +public class Select extends IntermediateOperation { - public Select(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public Select(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override @@ -39,7 +38,7 @@ public class Select extends TensorFlowOperation { if (!allInputFunctionsPresent(3)) { return null; } - TensorFlowOperation conditionOperation = inputs().get(0); + IntermediateOperation conditionOperation = inputs().get(0); TensorFunction a = inputs().get(1).function().get(); TensorFunction b = inputs().get(2).function().get(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java index af49d2c108b..da566909adc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java @@ -1,20 +1,19 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.List; -public class Shape extends TensorFlowOperation { +public class Shape extends IntermediateOperation { - public Shape(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public Shape(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); createConstantValue(); } @@ -24,7 +23,7 @@ public class Shape extends TensorFlowOperation { return null; } OrderedTensorType inputType = inputs.get(0).type().get(); - return new OrderedTensorType.Builder(node) + return new OrderedTensorType.Builder() .add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size())) .build(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java index 17ce9e8b7cb..c750c47e27e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java @@ -1,26 +1,26 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; -public class Squeeze extends TensorFlowOperation { +public class Squeeze extends IntermediateOperation { + private final AttributeMap attributeMap; private List<String> squeezeDimensions; - public Squeeze(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public Squeeze(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; } @Override @@ -31,20 +31,21 @@ public class Squeeze extends TensorFlowOperation { OrderedTensorType inputType = inputs.get(0).type().get(); squeezeDimensions = new ArrayList<>(); - AttrValue squeezeDimsAttr = node.getAttrMap().get("squeeze_dims"); - if (squeezeDimsAttr == null) { + Optional<List<Value>> squeezeDimsAttr = attributeMap.getList("squeeze_dims"); + if ( ! squeezeDimsAttr.isPresent()) { squeezeDimensions = inputType.type().dimensions().stream(). - filter(dim -> TensorConverter.dimensionSize(dim) == 1). + filter(dim -> OrderedTensorType.dimensionSize(dim) == 1). map(TensorType.Dimension::name). collect(Collectors.toList()); } else { - squeezeDimensions = squeezeDimsAttr.getList().getIList().stream(). + squeezeDimensions = squeezeDimsAttr.get().stream().map(Value::asDouble).map(Double::intValue). map(i -> i < 0 ? inputType.type().dimensions().size() - i : i). - map(i -> inputType.type().dimensions().get(i.intValue())). - filter(dim -> TensorConverter.dimensionSize(dim) == 1). + map(i -> inputType.type().dimensions().get(i)). + filter(dim -> OrderedTensorType.dimensionSize(dim) == 1). map(TensorType.Dimension::name). collect(Collectors.toList()); } + return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType); } @@ -72,7 +73,7 @@ public class Squeeze extends TensorFlowOperation { } private OrderedTensorType reducedType(OrderedTensorType inputType) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); for (TensorType.Dimension dimension: inputType.type().dimensions()) { if ( ! squeezeDimensions.contains(dimension.name())) { builder.add(dimension); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java index de4d8862fd6..0171d1ea171 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java @@ -1,17 +1,19 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.Optional; -public class Switch extends TensorFlowOperation { +public class Switch extends IntermediateOperation { - public Switch(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + private final int port; + + public Switch(String modelName, String nodeName, List<IntermediateOperation> inputs, int port) { + super(modelName, nodeName, inputs); + this.port = port; } @Override @@ -21,7 +23,7 @@ public class Switch extends TensorFlowOperation { } Optional<OrderedTensorType> predicate = inputs.get(1).type(); if (predicate.get().type().rank() != 0) { - throw new IllegalArgumentException("Switch in " + node.getName() + ": " + + throw new IllegalArgumentException("Switch in " + name + ": " + "predicate must be a scalar"); } return inputs.get(0).type().orElse(null); @@ -29,13 +31,13 @@ public class Switch extends TensorFlowOperation { @Override protected TensorFunction lazyGetFunction() { - TensorFlowOperation predicateOperation = inputs().get(1); + IntermediateOperation predicateOperation = inputs().get(1); if (!predicateOperation.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Switch in " + node.getName() + ": " + + throw new IllegalArgumentException("Switch in " + name + ": " + "predicate must be a constant"); } if (port < 0 || port > 1) { - throw new IllegalArgumentException("Switch in " + node.getName() + ": " + + throw new IllegalArgumentException("Switch in " + name + ": " + "choice should be boolean"); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java new file mode 100644 index 00000000000..a815cbc3944 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java @@ -0,0 +1,85 @@ +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; + +import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * Converts TensorFlow node attributes to Vespa attribute values. + * + * @author lesters + */ +public class AttributeConverter implements IntermediateOperation.AttributeMap { + + private final Map<String, AttrValue> attributeMap; + + public AttributeConverter(NodeDef node) { + attributeMap = node.getAttrMap(); + } + + public static AttributeConverter convert(NodeDef node) { + return new AttributeConverter(node); + } + + @Override + public Optional<Value> get(String key) { + if (attributeMap.containsKey(key)) { + AttrValue attrValue = attributeMap.get(key); + if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) { + return Optional.empty(); // requires type + } + if (attrValue.getValueCase() == AttrValue.ValueCase.B) { + return Optional.of(new BooleanValue(attrValue.getB())); + } + if (attrValue.getValueCase() == AttrValue.ValueCase.I) { + return Optional.of(new DoubleValue(attrValue.getI())); + } + if (attrValue.getValueCase() == AttrValue.ValueCase.F) { + return Optional.of(new DoubleValue(attrValue.getF())); + } + } + return Optional.empty(); + } + + @Override + public Optional<Value> get(String key, OrderedTensorType type) { + if (attributeMap.containsKey(key)) { + AttrValue attrValue = attributeMap.get(key); + if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) { + return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type.type()))); + } + } + return get(key); + } + + @Override + public Optional<List<Value>> getList(String key) { + if (attributeMap.containsKey(key)) { + AttrValue attrValue = attributeMap.get(key); + if (attrValue.getValueCase() == AttrValue.ValueCase.LIST) { + AttrValue.ListValue listValue = attrValue.getList(); + if ( ! listValue.getBList().isEmpty()) { + return Optional.of(listValue.getBList().stream().map(BooleanValue::new).collect(Collectors.toList())); + } + if ( ! listValue.getIList().isEmpty()) { + return Optional.of(listValue.getIList().stream().map(DoubleValue::new).collect(Collectors.toList())); + } + if ( ! listValue.getFList().isEmpty()) { + return Optional.of(listValue.getFList().stream().map(DoubleValue::new).collect(Collectors.toList())); + } + // add the rest + } + } + return Optional.empty(); + } +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java new file mode 100644 index 00000000000..e1b292f9e61 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java @@ -0,0 +1,234 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; + +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Const; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ExpandDims; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Mean; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Merge; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.PlaceholderWithDefault; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Select; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Squeeze; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Switch; +import com.yahoo.tensor.functions.ScalarFunctions; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.MetaGraphDef; +import org.tensorflow.framework.NodeDef; +import org.tensorflow.framework.SignatureDef; +import org.tensorflow.framework.TensorInfo; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Converts a TensorFlow graph to a Vespa IntermediateGraph which is the basis + * for generating Vespa ranking expressions. + * + * @author lesters + */ +public class GraphImporter { + + public static IntermediateOperation mapOperation(NodeDef node, + List<IntermediateOperation> inputs, + IntermediateGraph graph) { + String nodeName = node.getName(); + String modelName = graph.name(); + int nodePort = IntermediateOperation.indexPartOf(nodeName); + OrderedTensorType nodeType = TypeConverter.fromTensorFlowType(node); + AttributeConverter attributes = AttributeConverter.convert(node); + + switch (node.getOp().toLowerCase()) { + // array ops + case "concatv2": return new ConcatV2(modelName, nodeName, inputs); + case "const": return new Const(modelName, nodeName, inputs, attributes, nodeType); + case "expanddims": return new ExpandDims(modelName, nodeName, inputs); + case "identity": return new Identity(modelName, nodeName, inputs); + case "placeholder": return new Argument(modelName, nodeName, nodeType); + case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs); + case "reshape": return new Reshape(modelName, nodeName, inputs); + case "shape": return new Shape(modelName, nodeName, inputs); + case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); + + // control flow + case "merge": return new Merge(modelName, nodeName, inputs); + case "switch": return new Switch(modelName, nodeName, inputs, nodePort); + + // math ops + case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); + case "add_n": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); + case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); + case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); + case "realdiv": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); + case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); + case "matmul": return new MatMul(modelName, nodeName, inputs); + case "maximum": return new Join(modelName, nodeName, inputs, ScalarFunctions.max()); + case "mean": return new Mean(modelName, nodeName, inputs, attributes); + case "reducemean": return new Mean(modelName, nodeName, inputs, attributes); + case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); + case "multiply": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); + case "rsqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.rsqrt()); + case "select": return new Select(modelName, nodeName, inputs); + case "where3": return new Select(modelName, nodeName, inputs); + case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); + case "squareddifference": return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference()); + case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + case "subtract": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + + // nn ops + case "biasadd": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); + case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); + case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); + case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); + + // state ops + case "variable": return new Constant(modelName, nodeName, nodeType); + case "variablev2": return new Constant(modelName, nodeName, nodeType); + + // evaluation no-ops + case "stopgradient":return new Identity(modelName, nodeName, inputs); + case "noop": return new NoOp(modelName, nodeName, inputs); + + } + + IntermediateOperation op = new NoOp(modelName, node.getName(), inputs); + op.warning("Operation '" + node.getOp() + "' is currently not implemented"); + return op; + } + + public static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException { + MetaGraphDef tfGraph = MetaGraphDef.parseFrom(bundle.metaGraphDef()); + + IntermediateGraph intermediateGraph = new IntermediateGraph(modelName); + importSignatures(tfGraph, intermediateGraph); + importOperations(tfGraph, intermediateGraph, bundle); + verifyOutputTypes(tfGraph, intermediateGraph); + + return intermediateGraph; + } + + private static void importSignatures(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) { + for (java.util.Map.Entry<String, SignatureDef> signatureEntry : tfGraph.getSignatureDefMap().entrySet()) { + String signatureName = signatureEntry.getKey(); + java.util.Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap(); + for (java.util.Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) { + String inputName = input.getKey(); + String nodeName = input.getValue().getName(); + intermediateGraph.inputs(signatureName).put(inputName, IntermediateOperation.namePartOf(nodeName)); + } + java.util.Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap(); + for (java.util.Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) { + String outputName = output.getKey(); + String nodeName = output.getValue().getName(); + intermediateGraph.outputs(signatureName).put(outputName, IntermediateOperation.namePartOf(nodeName)); + } + } + } + + private static void importOperations(MetaGraphDef tfGraph, + IntermediateGraph intermediateGraph, + SavedModelBundle bundle) { + for (String signatureName : intermediateGraph.signatures()) { + for (String outputName : intermediateGraph.outputs(signatureName).values()) { + importOperation(outputName, tfGraph.getGraphDef(), intermediateGraph, bundle); + } + } + } + + private static IntermediateOperation importOperation(String nodeName, + GraphDef tfGraph, + IntermediateGraph intermediateGraph, + SavedModelBundle bundle) { + if (intermediateGraph.alreadyImported(nodeName)) { + return intermediateGraph.get(nodeName); + } + NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(nodeName), tfGraph); + List<IntermediateOperation> inputs = importOperationInputs(node, tfGraph, intermediateGraph, bundle); + IntermediateOperation operation = mapOperation(node, inputs, intermediateGraph); + intermediateGraph.put(nodeName, operation); + + List<IntermediateOperation> controlInputs = importControlInputs(node, tfGraph, intermediateGraph, bundle); + if (controlInputs.size() > 0) { + operation.setControlInputs(controlInputs); + } + + if (operation.isConstant()) { + operation.setConstantValueFunction( + type -> new TensorValue(TensorConverter.toVespaTensor(readVariable(nodeName, bundle), type))); + } + + return operation; + } + + private static List<IntermediateOperation> importOperationInputs(NodeDef node, + GraphDef tfGraph, + IntermediateGraph intermediateGraph, + SavedModelBundle bundle) { + return node.getInputList().stream() + .filter(name -> ! isControlDependency(name)) + .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle)) + .collect(Collectors.toList()); + } + + private static List<IntermediateOperation> importControlInputs(NodeDef node, + GraphDef tfGraph, + IntermediateGraph intermediateGraph, + SavedModelBundle bundle) { + return node.getInputList().stream() + .filter(nodeName -> isControlDependency(nodeName)) + .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle)) + .collect(Collectors.toList()); + } + + private static boolean isControlDependency(String name) { + return name.startsWith("^"); + } + + private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef tfGraph) { + for (NodeDef node : tfGraph.getNodeList()) { + if (node.getName().equals(name)) { + return node; + } + } + throw new IllegalArgumentException("Could not find node '" + name + "'"); + } + + public static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) { + Session.Runner fetched = bundle.session().runner().fetch(name); + List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); + if (importedTensors.size() != 1) + throw new IllegalStateException("Expected 1 tensor from fetching " + name + + ", but got " + importedTensors.size()); + return importedTensors.get(0); + } + + private static void verifyOutputTypes(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) { + for (String signatureName : intermediateGraph.signatures()) { + for (String outputName : intermediateGraph.outputs(signatureName).values()) { + IntermediateOperation operation = intermediateGraph.get(outputName); + NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(operation.name()), tfGraph.getGraphDef()); + OrderedTensorType type = operation.type().orElseThrow( + () -> new IllegalArgumentException("Output of '" + outputName + "' has no type.")); + TypeConverter.verifyType(node, type); + } + } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java index 3f55e622fdf..d2d0acfc964 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java new file mode 100644 index 00000000000..67ad1edc312 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java @@ -0,0 +1,72 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; + +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; +import org.tensorflow.framework.TensorShapeProto; + +import java.util.List; + +/** + * Converts and verifies TensorFlow tensor types into Vespa tensor types. + * + * @author lesters + */ +public class TypeConverter { + + public static void verifyType(NodeDef node, OrderedTensorType type) { + TensorShapeProto shape = tensorFlowShape(node); + if (shape != null) { + if (shape.getDimCount() != type.rank()) { + throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " + + "does not match Vespa shape"); + } + for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) { + int vespaIndex = type.dimensionMap(tensorFlowIndex); + TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex); + TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex); + if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) { + throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " + + "does not match Vespa dimensions"); + } + } + } + } + + private static TensorShapeProto tensorFlowShape(NodeDef node) { + AttrValue attrValueList = node.getAttrMap().get("_output_shapes"); + if (attrValueList == null) { + throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + + "does not exist"); + } + if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) { + throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + + "is not of expected type"); + } + List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList(); + return shapeList.get(0); // support multiple outputs? + } + + public static OrderedTensorType fromTensorFlowType(NodeDef node) { + return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ... + } + + public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + TensorShapeProto shape = tensorFlowShape(node); + for (int i = 0; i < shape.getDimCount(); ++ i) { + String dimensionName = dimensionPrefix + i; + TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i); + if (tensorFlowDimension.getSize() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize())); + } else { + builder.add(TensorType.Dimension.indexed(dimensionName)); + } + } + return builder.build(); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java index 5cff8b03d40..1530754cc43 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java @@ -3,6 +3,6 @@ * ONNX integration */ @ExportPackage -package com.yahoo.searchlib.rankingexpression.integration.onnx; +package com.yahoo.searchlib.rankingexpression.integration.ml; import com.yahoo.osgi.annotation.ExportPackage; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java deleted file mode 100644 index fa1f929cc80..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java +++ /dev/null @@ -1,326 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchlib.rankingexpression.integration.onnx; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Constant; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Argument; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OperationMapper; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.yolean.Exceptions; -import onnx.Onnx; - -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.logging.Logger; -import java.util.stream.Collectors; - -/** - * Converts a ONNX model into a ranking expression and set of constants. - * - * @author lesters - */ -public class OnnxImporter { - - private static final Logger log = Logger.getLogger(OnnxImporter.class.getName()); - - public OnnxModel importModel(String modelName, File modelDir) { - return importModel(modelName, modelDir.toString()); - } - - public OnnxModel importModel(String modelName, String modelPath) { - try (FileInputStream inputStream = new FileInputStream(modelPath)) { - Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); - return importModel(modelName, model); - } catch (IOException e) { - throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e); - } - } - - public OnnxModel importModel(String modelName, Onnx.ModelProto model) { - return importGraph(modelName, model.getGraph()); - } - - private static OnnxModel importGraph(String modelName, Onnx.GraphProto graph) { - OnnxModel model = new OnnxModel(modelName); - OperationIndex index = new OperationIndex(); - - importNodes(graph, model, index); - verifyOutputTypes(graph, model, index); - findDimensionNames(model, index); - importExpressions(model, index); - - reportWarnings(model, index); - - return model; - } - - private static void importNodes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) { - for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) { - importNode(valueInfo.getName(), graph, model, index); - } - } - - private static OnnxOperation importNode(String name, Onnx.GraphProto graph, OnnxModel model, OperationIndex index) { - if (index.alreadyImported(name)) { - return index.get(name); - } - OnnxOperation operation; - if (isArgumentTensor(name, graph)) { - operation = new Argument(getArgumentTensor(name, graph)); - model.input(OnnxOperation.namePartOf(name), operation.vespaName()); - } else if (isConstantTensor(name, graph)) { - operation = new Constant(model.name(), getConstantTensor(name, graph)); - } else { - Onnx.NodeProto node = getNodeFromGraph(name, graph); - List<OnnxOperation> inputs = importNodeInputs(node, graph, model, index); - operation = OperationMapper.get(node, inputs); - if (isOutputNode(name, graph)) { - model.output(OnnxOperation.namePartOf(name), operation.vespaName()); - } - } - index.put(operation.vespaName(), operation); - - return operation; - } - - private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) { - Onnx.ValueInfoProto value = getArgumentTensor(name, graph); - Onnx.TensorProto tensor = getConstantTensor(name, graph); - return value != null && tensor == null; - } - - private static boolean isConstantTensor(String name, Onnx.GraphProto graph) { - Onnx.ValueInfoProto value = getArgumentTensor(name, graph); - Onnx.TensorProto tensor = getConstantTensor(name, graph); - return value != null && tensor != null; - } - - private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) { - for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) { - if (valueInfo.getName().equals(name)) { - return valueInfo; - } - } - return null; - } - - private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) { - for (Onnx.TensorProto tensorProto : graph.getInitializerList()) { - if (tensorProto.getName().equals(name)) { - return tensorProto; - } - } - return null; - } - - private static boolean isOutputNode(String name, Onnx.GraphProto graph) { - return getOutputNode(name, graph) != null; - } - - private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) { - for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) { - if (valueInfo.getName().equals(name)) { - return valueInfo; - } - String nodeName = OnnxOperation.namePartOf(valueInfo.getName()); - if (nodeName.equals(name)) { - return valueInfo; - } - } - return null; - } - - private static List<OnnxOperation> importNodeInputs(Onnx.NodeProto node, - Onnx.GraphProto graph, - OnnxModel model, - OperationIndex index) { - return node.getInputList().stream() - .map(nodeName -> importNode(nodeName, graph, model, index)) - .collect(Collectors.toList()); - } - - private static void verifyOutputTypes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) { - for (String outputName : model.outputs().values()) { - OnnxOperation operation = index.get(outputName); - Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, graph); - operation.type().orElseThrow( - () -> new IllegalArgumentException("Output of '" + outputName + "' has no type.")) - .verifyType(onnxNode.getType()); - } - } - - - /** Find dimension names to avoid excessive renaming while evaluating the model. */ - private static void findDimensionNames(OnnxModel model, OperationIndex index) { - DimensionRenamer renamer = new DimensionRenamer(); - for (String output : model.outputs().values()) { - addDimensionNameConstraints(index.get(output), renamer); - } - renamer.solve(); - for (String output : model.outputs().values()) { - renameDimensions(index.get(output), renamer); - } - } - - private static void addDimensionNameConstraints(OnnxOperation operation, DimensionRenamer renamer) { - if (operation.type().isPresent()) { - operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); - operation.addDimensionNameConstraints(renamer); - } - } - - private static void renameDimensions(OnnxOperation operation, DimensionRenamer renamer) { - if (operation.type().isPresent()) { - operation.inputs().forEach(input -> renameDimensions(input, renamer)); - operation.renameDimensions(renamer); - } - } - - private static void importExpressions(OnnxModel model, OperationIndex index) { - for (String outputName : model.outputs().values()) { - try { - Optional<TensorFunction> function = importExpression(index.get(outputName), model); - if (!function.isPresent()) { - model.skippedOutput(outputName, "No valid output function could be found."); - } - } - catch (IllegalArgumentException e) { - model.skippedOutput(outputName, Exceptions.toMessageString(e)); - } - } - } - - private static Optional<TensorFunction> importExpression(OnnxOperation operation, OnnxModel model) { - if (!operation.type().isPresent()) { - return Optional.empty(); - } - if (operation.isConstant()) { - return importConstant(operation, model); - } - importInputExpressions(operation, model); - importRankingExpression(operation, model); - importArgumentExpression(operation, model); - - return operation.function(); - } - - private static void importInputExpressions(OnnxOperation operation, OnnxModel model) { - operation.inputs().forEach(input -> importExpression(input, model)); - } - - private static Optional<TensorFunction> importConstant(OnnxOperation operation, OnnxModel model) { - String name = operation.vespaName(); - if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { - return operation.function(); - } - - Value value = operation.getConstantValue().orElseThrow(() -> - new IllegalArgumentException("Operation '" + operation.vespaName() + "' " + - "is constant but does not have a value.")); - if ( ! (value instanceof TensorValue)) { - return operation.function(); // scalar values are inserted directly into the expression - } - - Tensor tensor = value.asTensor(); - if (tensor.type().rank() == 0) { - model.smallConstant(name, tensor); - } else { - model.largeConstant(name, tensor); - } - return operation.function(); - } - - private static void importRankingExpression(OnnxOperation operation, OnnxModel model) { - if (operation.function().isPresent()) { - String name = operation.vespaName(); - if (!model.expressions().containsKey(name)) { - TensorFunction function = operation.function().get(); - - if (model.outputs().containsKey(name)) { - OrderedTensorType operationType = operation.type().get(); - OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType); - if ( ! operationType.equals(standardNamingType)) { - List<String> renameFrom = operationType.dimensionNames(); - List<String> renameTo = standardNamingType.dimensionNames(); - function = new Rename(function, renameFrom, renameTo); - } - } - - try { - // We add all intermediate nodes imported as separate expressions. Only - // those referenced from the output will be used. We parse the - // TensorFunction here to convert it to a RankingExpression tree. - model.expression(name, new RankingExpression(name, function.toString())); - } - catch (ParseException e) { - throw new RuntimeException("Tensorflow function " + function + - " cannot be parsed as a ranking expression", e); - } - } - } - } - - private static void importArgumentExpression(OnnxOperation operation, OnnxModel model) { - if (operation.isInput()) { - // All inputs must have dimensions with standard naming convention: d0, d1, ... - OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get()); - model.argument(operation.vespaName(), standardNamingConvention.type()); - model.requiredMacro(operation.vespaName(), standardNamingConvention.type()); - } - } - - private static void reportWarnings(OnnxModel model, OperationIndex index) { - for (String output : model.outputs().values()) { - reportWarnings(model, index.get(output)); - } - } - - private static void reportWarnings(OnnxModel model, OnnxOperation operation) { - for (String warning : operation.warnings()) { - model.importWarning(warning); - } - for (OnnxOperation input : operation.inputs()) { - reportWarnings(model, input); - } - } - - private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) { - boolean hasPortNumber = nodeName.contains(":"); - for (Onnx.NodeProto node : graph.getNodeList()) { - if (hasPortNumber) { - for (String outputName : node.getOutputList()) { - if (outputName.equals(nodeName)) { - return node; - } - } - } else if (node.getName().equals(nodeName)) { - return node; - } - } - throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph"); - } - - private static class OperationIndex { - private final Map<String, OnnxOperation> index = new HashMap<>(); - public OnnxOperation put(String key, OnnxOperation operation) { return index.put(key, operation); } - public OnnxOperation get(String key) { return index.get(key); } - public boolean alreadyImported(String key) { return index.containsKey(key); } - public Collection<OnnxOperation> operations() { return index.values(); } - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java deleted file mode 100644 index bd53afefc3f..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchlib.rankingexpression.integration.onnx; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.regex.Pattern; - -/** - * The result of importing an ONNX model into Vespa. - * - * @author bratseth - * @author lesters - */ -public class OnnxModel { - - private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); - - private final String name; - - public OnnxModel(String name) { - if ( ! nameRegexp.matcher(name).matches()) - throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" + - name + "'"); - this.name = name; - } - - /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ - public String name() { return name; } - - private final Map<String, String> inputs = new HashMap<>(); - private final Map<String, String> outputs = new HashMap<>(); - private final Map<String, String> skippedOutputs = new HashMap<>(); - private final List<String> importWarnings = new ArrayList<>(); - - private final Map<String, TensorType> arguments = new HashMap<>(); - private final Map<String, Tensor> smallConstants = new HashMap<>(); - private final Map<String, Tensor> largeConstants = new HashMap<>(); - private final Map<String, RankingExpression> expressions = new HashMap<>(); - private final Map<String, RankingExpression> macros = new HashMap<>(); - private final Map<String, TensorType> requiredMacros = new HashMap<>(); - - void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } - void output(String name, String expressionName) { outputs.put(name, expressionName); } - void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } - void importWarning(String warning) { importWarnings.add(warning); } - void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } - void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } - void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } - void expression(String name, RankingExpression expression) { expressions.put(name, expression); } - void macro(String name, RankingExpression expression) { macros.put(name, expression); } - void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } - - /** - * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name - * to argument (Placeholder) name in the owner of this - */ - public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); } - - /** Returns arguments().get(inputs.get(name)), e.g the type of the argument this input references */ - public TensorType inputArgument(String inputName) { return arguments().get(inputs.get(inputName)); } - - /** Returns an immutable list of the expression names of this */ - public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); } - - /** - * Returns an immutable list of the outputs of this which could not be imported, - * with a string detailing the reason for each - */ - public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); } - - /** - * Returns an immutable list of possibly non-fatal warnings encountered during import. - */ - public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } - - /** Returns expressions().get(outputs.get(outputName)), e.g the expression this output references */ - public RankingExpression outputExpression(String outputName) { return expressions().get(outputs.get(outputName)); } - - /** Returns an immutable map of the arguments (inputs) of this */ - public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } - - /** - * Returns an immutable map of the small constants of this. - */ - public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); } - - /** - * Returns an immutable map of the large constants of this. - */ - public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); } - - /** - * Returns an immutable map of the expressions of this - corresponding to ONNX nodes - * which are not inputs or constants. - */ - public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } - - /** Returns an immutable map of macros that are part of this model */ - public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); } - - /** Returns an immutable map of the macros that must be provided by the environment running this model */ - public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java deleted file mode 100644 index 12090145d3a..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer; - -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Join; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.MatMul; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.NoOp; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation; -import com.yahoo.tensor.functions.ScalarFunctions; -import onnx.Onnx; - -import java.util.List; - -public class OperationMapper { - - public static OnnxOperation get(Onnx.NodeProto node, List<OnnxOperation> inputs) { - switch (node.getOpType().toLowerCase()) { - case "add": return new Join(node, inputs, ScalarFunctions.add()); - case "matmul": return new MatMul(node, inputs); - } - - OnnxOperation op = new NoOp(node, inputs); - op.warning("Operation '" + node.getOpType() + "' is currently not implemented"); - return op; - } -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java deleted file mode 100644 index a8d8d63daf4..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; - -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.VariableTensor; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.TensorFunction; -import onnx.Onnx; - -import java.util.Collections; -import java.util.List; - -public class Argument extends OnnxOperation { - - private Onnx.ValueInfoProto valueInfo; - private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ... - - public Argument(Onnx.ValueInfoProto valueInfoProto) { - super(null, Collections.emptyList()); - valueInfo = valueInfoProto; - standardNamingType = OrderedTensorType.fromOnnxType(valueInfo.getType()); - } - - @Override - public String vespaName() { - return vespaName(valueInfo.getName()); - } - - @Override - protected OrderedTensorType lazyGetType() { - return OrderedTensorType.fromOnnxType(valueInfo.getType(), vespaName() + "_"); - } - - @Override - protected TensorFunction lazyGetFunction() { - TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type()); - if (!standardNamingType.equals(type)) { - List<String> renameFrom = standardNamingType.dimensionNames(); - List<String> renameTo = type.dimensionNames(); - output = new Rename(output, renameFrom, renameTo); - } - return output; - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } - } - - @Override - public boolean isInput() { - return true; - } - - @Override - public boolean isConstant() { - return false; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java deleted file mode 100644 index b1136a0ce0a..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; - -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; -import com.yahoo.tensor.functions.TensorFunction; -import onnx.Onnx; - -import java.util.Collections; -import java.util.List; - -public class NoOp extends OnnxOperation { - - public NoOp(Onnx.NodeProto node, List<OnnxOperation> inputs) { - super(node, Collections.emptyList()); // don't propagate inputs - } - - @Override - protected OrderedTensorType lazyGetType() { - return null; - } - - @Override - protected TensorFunction lazyGetFunction() { - return null; - } - - @Override - public boolean isConstant() { - return true; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java deleted file mode 100644 index 30f7b4f4711..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; - -import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.tensor.functions.TensorFunction; -import onnx.Onnx; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.function.Function; - -/** - * Wraps an ONNX node and produces the respective Vespa tensor operation. - * During import, a graph of these operations are constructed. Then, the - * types are used to deduce sensible dimension names using the - * DimensionRenamer. After the types have been renamed, the proper - * Vespa expressions can be extracted. - * - * @author lesters - */ -public abstract class OnnxOperation { - - protected final Onnx.NodeProto node; // can be null for onnx inputs and constants - protected final List<OnnxOperation> inputs; - protected final List<OnnxOperation> outputs = new ArrayList<>(); - protected final List<String> importWarnings = new ArrayList<>(); - - protected OrderedTensorType type; - protected TensorFunction function; - protected Value constantValue = null; - - OnnxOperation(Onnx.NodeProto node, List<OnnxOperation> inputs) { - this.node = node; - this.inputs = Collections.unmodifiableList(inputs); - this.inputs.forEach(i -> i.outputs.add(this)); - } - - protected abstract OrderedTensorType lazyGetType(); - protected abstract TensorFunction lazyGetFunction(); - - /** Returns the Vespa tensor type of this operation if it exists */ - public Optional<OrderedTensorType> type() { - if (type == null) { - type = lazyGetType(); - } - return Optional.ofNullable(type); - } - - /** Returns the Vespa tensor function implementing all operations from this node with inputs */ - public Optional<TensorFunction> function() { - if (function == null) { - if (isConstant()) { - ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName())); - function = new TensorFunctionNode.TensorFunctionExpressionNode(constant); - } else { - function = lazyGetFunction(); - } - } - return Optional.ofNullable(function); - } - - /** Return Onnx node */ - public Onnx.NodeProto node() { return node; } - - /** Return unmodifiable list of inputs */ - public List<OnnxOperation> inputs() { return inputs; } - - /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */ - public List<OnnxOperation> outputs() { return Collections.unmodifiableList(outputs); } - - /** Add dimension name constraints for this operation */ - public void addDimensionNameConstraints(DimensionRenamer renamer) { } - - /** Performs dimension rename for this operation */ - public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); } - - /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */ - public boolean isInput() { return false; } - - /** Return true if this node is constant */ - public boolean isConstant() { return inputs.stream().allMatch(OnnxOperation::isConstant); } - - /** Gets the constant value if it exists */ - public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); } - - /** Retrieve the valid Vespa name of this node */ - public String vespaName() { return vespaName(node.getName()); } - public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; } - - /** Retrieve the list of warnings produced during its lifetime */ - public List<String> warnings() { return Collections.unmodifiableList(importWarnings); } - - /** Set an input warning */ - public void warning(String warning) { importWarnings.add(warning); } - - boolean verifyInputs(int expected, Function<OnnxOperation, Optional<?>> func) { - if (inputs.size() != expected) { - throw new IllegalArgumentException("Expected " + expected + " inputs " + - "for '" + node.getName() + "', got " + inputs.size()); - } - return inputs.stream().map(func).allMatch(Optional::isPresent); - } - - boolean allInputTypesPresent(int expected) { - return verifyInputs(expected, OnnxOperation::type); - } - - boolean allInputFunctionsPresent(int expected) { - return verifyInputs(expected, OnnxOperation::function); - } - - /** - * A method signature input and output has the form name:index. - * This returns the name part without the index. - */ - public static String namePartOf(String name) { - name = name.startsWith("^") ? name.substring(1) : name; - return name.split(":")[0]; - } - - /** - * This return the output index part. Indexes are used for nodes with - * multiple outputs. - */ - public static int indexPartOf(String name) { - int i = name.indexOf(":"); - return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1)); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java deleted file mode 100644 index e3c72830095..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ /dev/null @@ -1,411 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OperationMapper; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.yolean.Exceptions; -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Session; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.MetaGraphDef; -import org.tensorflow.framework.NodeDef; -import org.tensorflow.framework.SignatureDef; -import org.tensorflow.framework.TensorInfo; - -import java.io.File; -import java.io.IOException; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.logging.Logger; -import java.util.stream.Collectors; - -/** - * Converts a saved TensorFlow model into a ranking expression and set of constants. - * - * @author bratseth - * @author lesters - */ -public class TensorFlowImporter { - - private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName()); - - /** - * Imports a saved TensorFlow model from a directory. - * The model should be saved as a .pbtxt or .pb file. - * The name of the model is taken as the db/pbtxt file name (not including the file ending). - * - * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_] - * @param modelDir the directory containing the TensorFlow model files to import - */ - public TensorFlowModel importModel(String modelName, String modelDir) { - try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { - - return importModel(modelName, model); - } - catch (IllegalArgumentException e) { - throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); - } - } - - public TensorFlowModel importModel(String modelName, File modelDir) { - return importModel(modelName, modelDir.toString()); - } - - /** Imports a TensorFlow model */ - public TensorFlowModel importModel(String modelName, SavedModelBundle model) { - try { - return importGraph(modelName, MetaGraphDef.parseFrom(model.metaGraphDef()), model); - } - catch (IOException e) { - throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e); - } - } - - /** - * Imports the TensorFlow graph by first importing the tensor types, then - * finding a suitable set of dimensions names for each - * placeholder/constant/variable, then importing the expressions. - */ - private static TensorFlowModel importGraph(String modelName, MetaGraphDef graph, SavedModelBundle bundle) { - TensorFlowModel model = new TensorFlowModel(modelName); - OperationIndex index = new OperationIndex(); - - importSignatures(graph, model); - importNodes(graph, model, index); - findDimensionNames(model, index); - importExpressions(model, index, bundle); - - reportWarnings(model, index); - logVariableTypes(index); - - return model; - } - - private static void importSignatures(MetaGraphDef graph, TensorFlowModel model) { - for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { - String signatureName = signatureEntry.getKey(); - TensorFlowModel.Signature signature = model.signature(signatureName); - - Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap(); - for (Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) { - String inputName = input.getKey(); - signature.input(inputName, namePartOf(input.getValue().getName())); - } - - Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap(); - for (Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) { - String outputName = output.getKey(); - signature.output(outputName, namePartOf(output.getValue().getName())); - } - } - } - - private static boolean isSignatureInput(TensorFlowModel model, TensorFlowOperation operation) { - for (TensorFlowModel.Signature signature : model.signatures().values()) { - for (String inputName : signature.inputs().values()) { - if (inputName.equals(operation.node().getName())) { - return true; - } - } - } - return false; - } - - private static boolean isSignatureOutput(TensorFlowModel model, TensorFlowOperation operation) { - for (TensorFlowModel.Signature signature : model.signatures().values()) { - for (String outputName : signature.outputs().values()) { - if (outputName.equals(operation.node().getName())) { - return true; - } - } - } - return false; - } - - private static void importNodes(MetaGraphDef graph, TensorFlowModel model, OperationIndex index) { - for (TensorFlowModel.Signature signature : model.signatures().values()) { - for (String outputName : signature.outputs().values()) { - importNode(model.name(), outputName, graph.getGraphDef(), index); - } - } - } - - private static TensorFlowOperation importNode(String modelName, String nodeName, GraphDef graph, OperationIndex index) { - if (index.alreadyImported(nodeName)) { - return index.get(nodeName); - } - NodeDef node = getTensorFlowNodeFromGraph(namePartOf(nodeName), graph); - List<TensorFlowOperation> inputs = importNodeInputs(modelName, node, graph, index); - TensorFlowOperation operation = OperationMapper.get(modelName, node, inputs, portPartOf(nodeName)); - index.put(nodeName, operation); - - List<TensorFlowOperation> controlInputs = importControlInputs(modelName, node, graph, index); - if (controlInputs.size() > 0) { - operation.setControlInputs(controlInputs); - } - - return operation; - } - - private static List<TensorFlowOperation> importNodeInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) { - return node.getInputList().stream() - .filter(name -> ! isControlDependency(name)) - .map(nodeName -> importNode(modelName, nodeName, graph, index)) - .collect(Collectors.toList()); - } - - private static List<TensorFlowOperation> importControlInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) { - return node.getInputList().stream() - .filter(nodeName -> isControlDependency(nodeName)) - .map(nodeName -> importNode(modelName, nodeName, graph, index)) - .collect(Collectors.toList()); - } - - private static boolean isControlDependency(String name) { - return name.startsWith("^"); - } - - /** Find dimension names to avoid excessive renaming while evaluating the model. */ - private static void findDimensionNames(TensorFlowModel model, OperationIndex index) { - DimensionRenamer renamer = new DimensionRenamer(); - for (TensorFlowModel.Signature signature : model.signatures().values()) { - for (String output : signature.outputs().values()) { - addDimensionNameConstraints(index.get(output), renamer); - } - } - renamer.solve(); - for (TensorFlowModel.Signature signature : model.signatures().values()) { - for (String output : signature.outputs().values()) { - renameDimensions(index.get(output), renamer); - } - } - } - - private static void addDimensionNameConstraints(TensorFlowOperation operation, DimensionRenamer renamer) { - if (operation.type().isPresent()) { - operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); - operation.addDimensionNameConstraints(renamer); - } - } - - private static void renameDimensions(TensorFlowOperation operation, DimensionRenamer renamer) { - if (operation.type().isPresent()) { - operation.inputs().forEach(input -> renameDimensions(input, renamer)); - operation.renameDimensions(renamer); - } - } - - private static void importExpressions(TensorFlowModel model, OperationIndex index, SavedModelBundle bundle) { - for (TensorFlowModel.Signature signature : model.signatures().values()) { - for (String outputName : signature.outputs().values()) { - try { - Optional<TensorFunction> function = importExpression(index.get(outputName), model, bundle); - if (!function.isPresent()) { - signature.skippedOutput(outputName, "No valid output function could be found."); - } - } - catch (IllegalArgumentException e) { - signature.skippedOutput(outputName, Exceptions.toMessageString(e)); - } - } - } - } - - private static Optional<TensorFunction> importExpression(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) { - if (!operation.type().isPresent()) { - return Optional.empty(); - } - if (operation.isConstant()) { - return importConstant(model, operation, bundle); - } - - importInputExpressions(operation, model, bundle); - importRankingExpression(model, operation); - importInputExpression(model, operation); - importMacroExpression(model, operation); - - return operation.function(); - } - - private static void importInputExpressions(TensorFlowOperation operation, TensorFlowModel model, - SavedModelBundle bundle) { - operation.inputs().forEach(input -> importExpression(input, model, bundle)); - } - - private static void importMacroExpression(TensorFlowModel model, TensorFlowOperation operation) { - if (operation.macro().isPresent()) { - TensorFunction function = operation.macro().get(); - try { - model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString())); - } - catch (ParseException e) { - throw new RuntimeException("Tensorflow function " + function + - " cannot be parsed as a ranking expression", e); - } - } - } - - private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation, - SavedModelBundle bundle) { - String name = operation.vespaName(); - if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { - return operation.function(); - } - - Tensor tensor; - if (operation.getConstantValue().isPresent()) { - Value value = operation.getConstantValue().get(); - if ( ! (value instanceof TensorValue)) { - return operation.function(); // scalar values are inserted directly into the expression - } - tensor = value.asTensor(); - } else { - // Here we use the type from the operation, which will have correct dimension names after name resolving - tensor = TensorConverter.toVespaTensor(readVariable(operation.node().getName(), bundle), - operation.type().get()); - operation.setConstantValue(new TensorValue(tensor)); - } - - if (tensor.type().rank() == 0) { - model.smallConstant(name, tensor); - } else { - model.largeConstant(name, tensor); - } - return operation.function(); - } - - static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) { - Session.Runner fetched = bundle.session().runner().fetch(name); - List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); - if (importedTensors.size() != 1) - throw new IllegalStateException("Expected 1 tensor from fetching " + name + - ", but got " + importedTensors.size()); - return importedTensors.get(0); - } - - private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) { - if (operation.function().isPresent()) { - String name = operation.node().getName(); - if (!model.expressions().containsKey(operation.node().getName())) { - TensorFunction function = operation.function().get(); - - // Make sure output adheres to standard naming convention - if (isSignatureOutput(model, operation)) { - OrderedTensorType operationType = operation.type().get(); - OrderedTensorType standardNamingType = OrderedTensorType.fromTensorFlowType(operation.node()); - if ( ! operationType.equals(standardNamingType)) { - List<String> renameFrom = operationType.dimensionNames(); - List<String> renameTo = standardNamingType.dimensionNames(); - function = new Rename(function, renameFrom, renameTo); - } - } - - try { - // We add all intermediate nodes imported as separate expressions. Only - // those referenced in a signature output will be used. We parse the - // TensorFunction here to convert it to a RankingExpression tree. - model.expression(name, new RankingExpression(name, function.toString())); - } - catch (ParseException e) { - throw new RuntimeException("Tensorflow function " + function + - " cannot be parsed as a ranking expression", e); - } - } - } - } - - private static void importInputExpression(TensorFlowModel model, TensorFlowOperation operation) { - if (operation.isInput() && isSignatureInput(model, operation)) { - // All inputs must have dimensions with standard naming convention: d0, d1, ... - OrderedTensorType standardNamingConvention = OrderedTensorType.fromTensorFlowType(operation.node()); - model.argument(operation.node().getName(), standardNamingConvention.type()); - model.requiredMacro(operation.vespaName(), standardNamingConvention.type()); - } - } - - private static void reportWarnings(TensorFlowModel model, OperationIndex index) { - for (TensorFlowModel.Signature signature : model.signatures().values()) { - for (String output : signature.outputs().values()) { - reportWarnings(index.get(output), signature); - } - } - } - - /** - * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type. - * This allows users to learn the exact types (including dimension order after renaming) of the Variables - * such that these can be converted and fed to a parent document independently of the rest of the model - * for fast model weight updates. - */ - private static void logVariableTypes(OperationIndex index) { - for (TensorFlowOperation operation : index.operations()) { - if ( ! (operation instanceof Variable)) continue; - if ( ! operation.type().isPresent()) continue; // will not happen - - log.info("Importing TensorFlow variable " + operation.node().getName() + " as " + operation.vespaName() + - " of type " + operation.type().get()); - } - } - - private static void reportWarnings(TensorFlowOperation operation, TensorFlowModel.Signature signature) { - for (String warning : operation.warnings()) { - signature.importWarning(warning); - } - for (TensorFlowOperation input : operation.inputs()) { - reportWarnings(input, signature); - } - } - - private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef graph) { - for (NodeDef node : graph.getNodeList()) { - if (node.getName().equals(name)) { - return node; - } - } - throw new IllegalArgumentException("Could not find node '" + name + "'"); - } - - /** - * A method signature input and output has the form name:index. - * This returns the name part without the index. - */ - private static String namePartOf(String name) { - name = name.startsWith("^") ? name.substring(1) : name; - return name.split(":")[0]; - } - - /** - * This return the output port part. Indexes are used for nodes with - * multiple outputs. - */ - private static int portPartOf(String name) { - int i = name.indexOf(":"); - return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1)); - } - - private static class OperationIndex { - - private final Map<String, TensorFlowOperation> index = new HashMap<>(); - public TensorFlowOperation put(String key, TensorFlowOperation operation) { return index.put(key, operation); } - public TensorFlowOperation get(String key) { return index.get(key); } - public boolean alreadyImported(String key) { return index.containsKey(key); } - public Collection<TensorFlowOperation> operations() { return index.values(); } - - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java deleted file mode 100644 index c1665d066a4..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; - -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation; - -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Deque; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; - -/** - * A constraint satisfier to find suitable dimension names to reduce the - * amount of necessary renaming during evaluation of an imported model. - * - * @author lesters - */ -public class DimensionRenamer { - - private final String dimensionPrefix; - private final Map<String, List<Integer>> variables = new HashMap<>(); - private final Map<Arc, Constraint> constraints = new HashMap<>(); - private final Map<String, Integer> renames = new HashMap<>(); - - private int iterations = 0; - - public DimensionRenamer() { - this("d"); - } - - public DimensionRenamer(String dimensionPrefix) { - this.dimensionPrefix = dimensionPrefix; - } - - /** - * Add a dimension name variable. - */ - public void addDimension(String name) { - variables.computeIfAbsent(name, d -> new ArrayList<>()); - } - - /** - * Add a constraint between dimension names. - */ - public void addConstraint(String from, String to, Constraint pred, TensorFlowOperation operation) { - Arc arc = new Arc(from, to, operation); - Arc opposite = arc.opposite(); - constraints.put(arc, pred); - constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric - } - - /** - * Retrieve resulting name of dimension after solving for constraints. - */ - public Optional<String> dimensionNameOf(String name) { - if (!renames.containsKey(name)) { - return Optional.empty(); - } - return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name))); - } - - /** - * Perform iterative arc consistency until we have found a solution. After - * an initial iteration, the variables (dimensions) will have multiple - * valid values. Find a single valid assignment by iteratively locking one - * dimension after another, and running the arc consistency algorithm - * multiple times. - * - * This requires having constraints that result in an absolute ordering: - * equals, lesserThan and greaterThan do that, but adding notEquals does - * not typically result in a guaranteed ordering. If that is needed, the - * algorithm below needs to be adapted with a backtracking (tree) search - * to find solutions. - */ - public void solve(int maxIterations) { - initialize(); - - // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts - - for (String dimension : variables.keySet()) { - List<Integer> values = variables.get(dimension); - if (values.size() > 1) { - if (!ac3()) { - throw new IllegalArgumentException("Dimension renamer unable to find a solution."); - } - values.sort(Integer::compare); - variables.put(dimension, Collections.singletonList(values.get(0))); - } - renames.put(dimension, variables.get(dimension).get(0)); - if (iterations > maxIterations) { - throw new IllegalArgumentException("Dimension renamer unable to find a solution within " + - maxIterations + " iterations"); - } - } - - // Todo: handle failure more gracefully: - // If a solution can't be found, look at the operation node in the arc - // with the most remaining constraints, and inject a rename operation. - // Then run this algorithm again. - } - - public void solve() { - solve(100000); - } - - private void initialize() { - for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) { - List<Integer> values = variable.getValue(); - for (int i = 0; i < variables.size(); ++i) { - values.add(i); // invariant: values are in increasing order - } - } - } - - private boolean ac3() { - Deque<Arc> workList = new ArrayDeque<>(constraints.keySet()); - while (!workList.isEmpty()) { - Arc arc = workList.pop(); - iterations += 1; - if (revise(arc)) { - if (variables.get(arc.from).size() == 0) { - return false; // no solution found - } - for (Arc constraint : constraints.keySet()) { - if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) { - workList.add(constraint); - } - } - } - } - return true; - } - - private boolean revise(Arc arc) { - boolean revised = false; - for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) { - Integer from = fromIterator.next(); - boolean satisfied = false; - for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) { - Integer to = toIterator.next(); - if (constraints.get(arc).test(from, to)) { - satisfied = true; - } - } - if (!satisfied) { - fromIterator.remove(); - revised = true; - } - } - return revised; - } - - public interface Constraint { - boolean test(Integer x, Integer y); - } - - public static boolean equals(Integer x, Integer y) { - return Objects.equals(x, y); - } - - public static boolean lesserThan(Integer x, Integer y) { - return x < y; - } - - public static boolean greaterThan(Integer x, Integer y) { - return x > y; - } - - private static class Arc { - - private final String from; - private final String to; - private final TensorFlowOperation operation; - - Arc(String from, String to, TensorFlowOperation operation) { - this.from = from; - this.to = to; - this.operation = operation; - } - - Arc opposite() { - return new Arc(to, from, operation); - } - - @Override - public int hashCode() { - return Objects.hash(from, to); - } - - @Override - public boolean equals(Object obj) { - if (obj == null || !(obj instanceof Arc)) { - return false; - } - Arc other = (Arc) obj; - return Objects.equals(from, other.from) && Objects.equals(to, other.to); - } - - @Override - public String toString() { - return String.format("%s -> %s", from, to); - } - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java deleted file mode 100644 index b665413a6b2..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; - -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ConcatV2; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ExpandDims; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Identity; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Join; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Map; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Matmul; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Mean; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Merge; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.NoOp; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Placeholder; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.PlaceholderWithDefault; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Reshape; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Select; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Shape; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Squeeze; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Switch; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable; -import com.yahoo.tensor.functions.ScalarFunctions; -import org.tensorflow.framework.NodeDef; - -import java.util.List; - -/** - * Maps from TensorFlow operations to Vespa operations. - * - * @author bratseth - * @author lesters - */ -public class OperationMapper { - - public static TensorFlowOperation get(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - switch (node.getOp().toLowerCase()) { - // array ops - case "concatv2": return new ConcatV2(modelName, node, inputs, port); - case "const": return new Const(modelName, node, inputs, port); - case "expanddims": return new ExpandDims(modelName, node, inputs, port); - case "identity": return new Identity(modelName, node, inputs, port); - case "placeholder": return new Placeholder(modelName, node, inputs, port); - case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, node, inputs, port); - case "reshape": return new Reshape(modelName, node, inputs, port); - case "shape": return new Shape(modelName, node, inputs, port); - case "squeeze": return new Squeeze(modelName, node, inputs, port); - - // control flow - case "merge": return new Merge(modelName, node, inputs, port); - case "switch": return new Switch(modelName, node, inputs, port); - - // math ops - case "add": return new Join(modelName, node, inputs, port, ScalarFunctions.add()); - case "add_n": return new Join(modelName, node, inputs, port, ScalarFunctions.add()); - case "acos": return new Map(modelName, node, inputs, port, ScalarFunctions.acos()); - case "div": return new Join(modelName, node, inputs, port, ScalarFunctions.divide()); - case "realdiv": return new Join(modelName, node, inputs, port, ScalarFunctions.divide()); - case "floor": return new Map(modelName, node, inputs, port, ScalarFunctions.floor()); - case "matmul": return new Matmul(modelName, node, inputs, port); - case "maximum": return new Join(modelName, node, inputs, port, ScalarFunctions.max()); - case "mean": return new Mean(modelName, node, inputs, port); - case "reducemean": return new Mean(modelName, node, inputs, port); - case "mul": return new Join(modelName, node, inputs, port, ScalarFunctions.multiply()); - case "multiply": return new Join(modelName, node, inputs, port, ScalarFunctions.multiply()); - case "rsqrt": return new Map(modelName, node, inputs, port, ScalarFunctions.rsqrt()); - case "select": return new Select(modelName, node, inputs, port); - case "where3": return new Select(modelName, node, inputs, port); - case "sigmoid": return new Map(modelName, node, inputs, port, ScalarFunctions.sigmoid()); - case "squareddifference": return new Join(modelName, node, inputs, port, ScalarFunctions.squareddifference()); - case "sub": return new Join(modelName, node, inputs, port, ScalarFunctions.subtract()); - case "subtract": return new Join(modelName, node, inputs, port, ScalarFunctions.subtract()); - - // nn ops - case "biasadd": return new Join(modelName, node, inputs, port, ScalarFunctions.add()); - case "elu": return new Map(modelName, node, inputs, port, ScalarFunctions.elu()); - case "relu": return new Map(modelName, node, inputs, port, ScalarFunctions.relu()); - case "selu": return new Map(modelName, node, inputs, port, ScalarFunctions.selu()); - - // state ops - case "variable": return new Variable(modelName, node, inputs, port); - case "variablev2": return new Variable(modelName, node, inputs, port); - - // evaluation no-ops - case "stopgradient":return new Identity(modelName, node, inputs, port); - case "noop": return new NoOp(modelName, node, inputs, port); - } - - TensorFlowOperation op = new NoOp(modelName, node, inputs, port); - op.warning("Operation '" + node.getOp() + "' is currently not implemented"); - return op; - } - -} - - - diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java deleted file mode 100644 index 03a65333192..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java +++ /dev/null @@ -1,255 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; - -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.TensorTypeParser; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.NodeDef; -import org.tensorflow.framework.TensorShapeProto; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; - -/** - * A Vespa tensor type is ordered by the lexicographical ordering of dimension - * names. TensorFlow tensors have an explicit ordering of their dimensions. - * During import, we need to track the Vespa dimension that matches the - * corresponding TensorFlow dimension as the ordering can change after - * dimension renaming. That is the purpose of this class. - * - * @author lesters - */ -public class OrderedTensorType { - - private final TensorType type; - private final List<TensorType.Dimension> dimensions; - - private final long[] innerSizesTensorFlow; - private final long[] innerSizesVespa; - private final int[] dimensionMap; - - private OrderedTensorType(List<TensorType.Dimension> dimensions) { - this.dimensions = Collections.unmodifiableList(dimensions); - this.type = new TensorType.Builder(dimensions).build(); - this.innerSizesTensorFlow = new long[dimensions.size()]; - this.innerSizesVespa = new long[dimensions.size()]; - this.dimensionMap = createDimensionMap(); - } - - public TensorType type() { - return this.type; - } - - public int rank() { return dimensions.size(); } - - public List<TensorType.Dimension> dimensions() { - return dimensions; - } - - public List<String> dimensionNames() { - return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList()); - } - - private int[] createDimensionMap() { - int numDimensions = dimensions.size(); - if (numDimensions == 0) { - return null; - } - innerSizesTensorFlow[numDimensions - 1] = 1; - innerSizesVespa[numDimensions - 1] = 1; - for (int i = numDimensions - 1; --i >= 0; ) { - innerSizesTensorFlow[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesTensorFlow[i+1]; - innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1]; - } - int[] mapping = new int[numDimensions]; - for (int i = 0; i < numDimensions; ++i) { - TensorType.Dimension dim1 = dimensions().get(i); - for (int j = 0; j < numDimensions; ++j) { - TensorType.Dimension dim2 = type.dimensions().get(j); - if (dim1.equals(dim2)) { - mapping[i] = j; - break; - } - } - } - return mapping; - } - - /** - * When dimension ordering between Vespa and TensorFlow differs, i.e. - * after dimension renaming, use the dimension map to read in values - * so that they are correctly laid out in memory for Vespa. - * Used when importing tensors from TensorFlow. - */ - public int toDirectIndex(int index) { - if (dimensions.size() == 0) { - return 0; - } - if (dimensionMap == null) { - throw new IllegalArgumentException("Dimension map is not available"); - } - int directIndex = 0; - long rest = index; - for (int i = 0; i < dimensions.size(); ++i) { - long address = rest / innerSizesTensorFlow[i]; - directIndex += innerSizesVespa[dimensionMap[i]] * address; - rest %= innerSizesTensorFlow[i]; - } - return directIndex; - } - - @Override - public boolean equals(Object obj) { - if (obj == null || !(obj instanceof OrderedTensorType)) { - return false; - } - OrderedTensorType other = (OrderedTensorType) obj; - if (dimensions.size() != dimensions.size()) { - return false; - } - List<TensorType.Dimension> thisDimensions = this.dimensions(); - List<TensorType.Dimension> otherDimensions = other.dimensions(); - for (int i = 0; i < thisDimensions.size(); ++i) { - if (!thisDimensions.get(i).equals(otherDimensions.get(i))) { - return false; - } - } - return true; - } - - public void verifyType(NodeDef node) { - TensorShapeProto shape = tensorFlowShape(node); - if (shape != null) { - if (shape.getDimCount() != type.rank()) { - throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " + - "does not match Vespa shape"); - } - for (int tensorFlowIndex = 0; tensorFlowIndex < dimensions.size(); ++tensorFlowIndex) { - int vespaIndex = dimensionMap[tensorFlowIndex]; - TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex); - TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex); - if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) { - throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " + - "does not match Vespa dimensions"); - } - } - } - } - - private static TensorShapeProto tensorFlowShape(NodeDef node) { - AttrValue attrValueList = node.getAttrMap().get("_output_shapes"); - if (attrValueList == null) { - throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + - "does not exist"); - } - if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) { - throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + - "is not of expected type"); - } - List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList(); - return shapeList.get(0); // support multiple outputs? - } - - public OrderedTensorType rename(DimensionRenamer renamer) { - List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size()); - for (TensorType.Dimension dimension : dimensions) { - String oldName = dimension.name(); - Optional<String> newName = renamer.dimensionNameOf(oldName); - if (!newName.isPresent()) - return this; // presumably, already renamed - TensorType.Dimension.Type dimensionType = dimension.type(); - if (dimensionType == TensorType.Dimension.Type.indexedBound) { - renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get())); - } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) { - renamedDimensions.add(TensorType.Dimension.indexed(newName.get())); - } else if (dimensionType == TensorType.Dimension.Type.mapped) { - renamedDimensions.add(TensorType.Dimension.mapped(newName.get())); - } - } - return new OrderedTensorType(renamedDimensions); - } - - /** - * Returns a string representation of this: A standard tensor type string where dimensions - * are listed in the order of this rather than in the natural order of their names. - */ - @Override - public String toString() { - return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")"; - } - - /** - * Creates an instance from the string representation of this: A standard tensor type string - * where dimensions are listed in the order of this rather than the natural order of their names. - */ - public static OrderedTensorType fromSpec(String typeSpec) { - return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec)); - } - - public static OrderedTensorType fromTensorFlowType(NodeDef node) { - return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ... - } - - public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) { - Builder builder = new Builder(node); - TensorShapeProto shape = tensorFlowShape(node); - for (int i = 0; i < shape.getDimCount(); ++ i) { - String dimensionName = dimensionPrefix + i; - TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i); - if (tensorFlowDimension.getSize() >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize())); - } else { - builder.add(TensorType.Dimension.indexed(dimensionName)); - } - } - return builder.build(); - } - - public static class Builder { - - private final TensorShapeProto shape; - private final List<TensorType.Dimension> dimensions; - - public Builder(NodeDef node) { - this.shape = tensorFlowShape(node); - this.dimensions = new ArrayList<>(shape.getDimCount()); - } - - public Builder add(TensorType.Dimension vespaDimension) { - int index = dimensions.size(); - TensorShapeProto.Dim tensorFlowDimension = shape.getDim(index); - long size = tensorFlowDimension.getSize(); - if (size >= 0) { - if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) { - throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + - "dimension types"); - } - if (!vespaDimension.size().isPresent()) { - throw new IllegalArgumentException("Tensor dimension is indexed bound but does " + - "not have a size"); - } - if (vespaDimension.size().get() != size) { - throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + - "dimension sizes. TensorFlow: " + size + " Vespa: " + - vespaDimension.size().get()); - } - } else { - if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) { - throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + - "dimension types"); - } - } - this.dimensions.add(vespaDimension); - return this; - } - - public OrderedTensorType build() { - return new OrderedTensorType(dimensions); - } - - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java deleted file mode 100644 index 6cbfe0dfb05..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; - -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; - -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import java.util.function.DoubleBinaryOperator; - -public class Join extends TensorFlowOperation { - - private final DoubleBinaryOperator operator; - - public Join(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleBinaryOperator operator) { - super(modelName, node, inputs, port); - this.operator = operator; - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } - OrderedTensorType a = largestInput().type().get(); - OrderedTensorType b = smallestInput().type().get(); - - // Well now we have potentially entered the wonderful world of "broadcasting" - // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html - // In broadcasting, the size of each dimension is compared element-wise, - // starting with the trailing dimensions and working forward. A special - // case occurs when the size of one dimension is 1, while the other is not. - // Then the dimension with size 1 is "stretched" to be of compatible size. - // - // An example: - // - // Tensor A: d0[5], d1[1], d2[3], d3[1] - // Tensor B: d1[4], d2[1], d3[2] - // - // In TensorFlow and using the above rules of broadcasting, the resulting - // type is: - // d0[5], d1[4], d2[3], d2[2] - // - // However, in Vespa's tensor logic, the join of the two above tensors would - // result in a tensor of type: - // d0[5], d1[1], d2[1], d3[1] - // - // By reducing the dimensions of size 1 in each tensor before joining, - // we get equal results as in TensorFlow. - - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node); - int sizeDifference = a.rank() - b.rank(); - for (int i = 0; i < a.rank(); ++i) { - TensorType.Dimension aDim = a.dimensions().get(i); - long size = aDim.size().orElse(-1L); - - if (i - sizeDifference >= 0) { - TensorType.Dimension bDim = b.dimensions().get(i - sizeDifference); - size = Math.max(size, bDim.size().orElse(-1L)); - } - - if (aDim.type() == TensorType.Dimension.Type.indexedBound) { - builder.add(TensorType.Dimension.indexed(aDim.name(), size)); - } else if (aDim.type() == TensorType.Dimension.Type.indexedUnbound) { - builder.add(TensorType.Dimension.indexed(aDim.name())); - } else if (aDim.type() == TensorType.Dimension.Type.mapped) { - builder.add(TensorType.Dimension.mapped(aDim.name())); - } - } - return builder.build(); - } - - @Override - protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } - if (!allInputFunctionsPresent(2)) { - return null; - } - - TensorFlowOperation a = largestInput(); - TensorFlowOperation b = smallestInput(); - - List<String> aDimensionsToReduce = new ArrayList<>(); - List<String> bDimensionsToReduce = new ArrayList<>(); - int sizeDifference = a.type().get().rank() - b.type().get().rank(); - for (int i = 0; i < b.type().get().rank(); ++i) { - TensorType.Dimension bDim = b.type().get().dimensions().get(i); - TensorType.Dimension aDim = a.type().get().dimensions().get(i + sizeDifference); - long bSize = bDim.size().orElse(-1L); - long aSize = aDim.size().orElse(-1L); - if (bSize == 1L && aSize != 1L) { - bDimensionsToReduce.add(bDim.name()); - } - if (aSize == 1L && bSize != 1L) { - aDimensionsToReduce.add(bDim.name()); - } - } - - TensorFunction aReducedFunction = a.function().get(); - if (aDimensionsToReduce.size() > 0) { - aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce); - } - TensorFunction bReducedFunction = b.function().get(); - if (bDimensionsToReduce.size() > 0) { - bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce); - } - - return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator); - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!allInputTypesPresent(2)) { - return; - } - OrderedTensorType a = largestInput().type().get(); - OrderedTensorType b = smallestInput().type().get(); - int sizeDifference = a.rank() - b.rank(); - for (int i = 0; i < b.rank(); ++i) { - String bDim = b.dimensions().get(i).name(); - String aDim = a.dimensions().get(i + sizeDifference).name(); - renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); - } - } - - private TensorFlowOperation largestInput() { - OrderedTensorType a = inputs.get(0).type().get(); - OrderedTensorType b = inputs.get(1).type().get(); - return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1); - } - - private TensorFlowOperation smallestInput() { - OrderedTensorType a = inputs.get(0).type().get(); - OrderedTensorType b = inputs.get(1).type().get(); - return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java deleted file mode 100644 index b2b9530a161..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; - -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; - -import java.util.List; -import java.util.Optional; - -public class Matmul extends TensorFlowOperation { - - public Matmul(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node); - typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); - typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); - return typeBuilder.build(); - } - - @Override - protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } - OrderedTensorType aType = inputs.get(0).type().get(); - OrderedTensorType bType = inputs.get(1).type().get(); - if (aType.type().rank() < 2 || bType.type().rank() < 2) - throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); - if (aType.type().rank() != bType.type().rank()) - throw new IllegalArgumentException("Tensors in matmul must have the same rank"); - - Optional<TensorFunction> aFunction = inputs.get(0).function(); - Optional<TensorFunction> bFunction = inputs.get(1).function(); - if (!aFunction.isPresent() || !bFunction.isPresent()) { - return null; - } - return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name()); - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!allInputTypesPresent(2)) { - return; - } - List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); - List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); - - String aDim0 = aDimensions.get(0).name(); - String aDim1 = aDimensions.get(1).name(); - String bDim0 = bDimensions.get(0).name(); - String bDim1 = bDimensions.get(1).name(); - - // The second dimension of a should have the same name as the first dimension of b - renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this); - - // The first dimension of a should have a different name than the second dimension of b - renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this); - - // For efficiency, the dimensions to join over should be innermost - soft constraint - renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this); - renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java deleted file mode 100644 index d558ec89e87..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; - -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; - -import java.util.Collections; -import java.util.List; - -public class NoOp extends TensorFlowOperation { - - public NoOp(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, Collections.emptyList(), port); // don't propagate inputs - } - - @Override - protected OrderedTensorType lazyGetType() { - return null; - } - - @Override - protected TensorFunction lazyGetFunction() { - return null; - } - - @Override - public boolean isConstant() { - return true; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java deleted file mode 100644 index b18a8a9b212..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; - -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; - -import java.util.List; - -public class Variable extends TensorFlowOperation { - - public Variable(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); - } - - /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ - @Override - public String vespaName() { - return modelName() + "_" + super.vespaName(); - } - - @Override - protected OrderedTensorType lazyGetType() { - return OrderedTensorType.fromTensorFlowType(node, super.vespaName() + "_"); - } - - @Override - protected TensorFunction lazyGetFunction() { - return null; // will be added by function() since this is constant. - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } - } - - @Override - public boolean isConstant() { - return true; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java deleted file mode 100644 index 9e53990a9d6..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -/** - * Tensorflow integration - */ -@ExportPackage -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.osgi.annotation.ExportPackage; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/Benchmark.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/Benchmark.java index bb2110a0f5f..51a1b09b9fa 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/Benchmark.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/Benchmark.java @@ -16,7 +16,7 @@ import java.util.LinkedList; import java.util.List; /** - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen */ public final class Benchmark { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/StreamEvaluationBenchmark.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/StreamEvaluationBenchmark.java index 280ffc6278b..760e056327c 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/StreamEvaluationBenchmark.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/StreamEvaluationBenchmark.java @@ -147,11 +147,6 @@ public class StreamEvaluationBenchmark { new StreamEvaluationBenchmark().run(); } - private void assertEqualish(double a,double b) { - if (Math.abs(a-b) >= Math.abs((a+b)/100000000) ) - throw new RuntimeException("Expected value " + a + " but optimized evaluation produced " + b); - } - private void bindStreamingFeatures(Map<String, Double> featureItem, Context context) { for (Map.Entry<String, Double> feature : featureItem.entrySet()) context.put(feature.getKey(), feature.getValue()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java index 0f5eec93feb..bf9684082f4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; import com.yahoo.searchlib.rankingexpression.RankingExpression; import org.junit.Test; @@ -15,7 +15,7 @@ public class BatchNormImportTestCase { @Test public void testBatchNormImport() { TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved"); - TensorFlowModel.Signature signature = model.get().signature("serving_default"); + ImportedModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java index 74b0d11f1d6..c8c7ec798bb 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java @@ -1,6 +1,6 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; import org.junit.Test; import static org.junit.Assert.assertTrue; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java index 50a467ec581..a63c7346335 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.TensorType; @@ -24,7 +24,7 @@ public class DropoutImportTestCase { assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), model.get().requiredMacros().get("X")); - TensorFlowModel.Signature signature = model.get().signature("serving_default"); + ImportedModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); @@ -32,7 +32,7 @@ public class DropoutImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("outputs/Maximum", output.getName()); - assertEquals("join(join(tf_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))", + assertEquals("join(join(imported_ml_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))", output.getRoot().toString()); model.assertEqualResult("X", output.getName()); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java index 9f919c452d6..bd7644be23b 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; @@ -45,7 +45,7 @@ public class MnistSoftmaxImportTestCase { // Check signatures assertEquals(1, model.get().signatures().size()); - TensorFlowModel.Signature signature = model.get().signatures().get("serving_default"); + ImportedModel.Signature signature = model.get().signatures().get("serving_default"); assertNotNull(signature); // ... signature inputs diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java index 4b68cd40a08..a7926cd2e02 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java @@ -1,11 +1,9 @@ -package com.yahoo.searchlib.rankingexpression.integration.onnx; +package com.yahoo.searchlib.rankingexpression.integration.ml; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -24,7 +22,7 @@ public class OnnxMnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() throws IOException { - OnnxModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"); + ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"); // Check constants assertEquals(2, model.largeConstants().size()); @@ -48,7 +46,7 @@ public class OnnxMnistSoftmaxImportTestCase { model.requiredMacros().get("Placeholder")); // Check outputs - RankingExpression output = model.outputExpression("add"); + RankingExpression output = model.defaultSignature().outputExpression("add"); assertNotNull(output); assertEquals("add", output.getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", @@ -68,13 +66,12 @@ public class OnnxMnistSoftmaxImportTestCase { } private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) { - SavedModelBundle tensorFlowModel = SavedModelBundle.load(path, "serve"); - TensorFlowModel model = new TensorFlowImporter().importModel("test", tensorFlowModel); + ImportedModel model = new TensorFlowImporter().importModel("test", path); return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input); } private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) { - OnnxModel model = new OnnxImporter().importModel("test", path); + ImportedModel model = new OnnxImporter().importModel("test", path); return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input); } @@ -83,14 +80,7 @@ public class OnnxMnistSoftmaxImportTestCase { return expression.evaluate(context).asTensor(); } - private Context contextFrom(TensorFlowModel result) { - MapContext context = new MapContext(); - result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); - result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); - return context; - } - - private Context contextFrom(OnnxModel result) { + private Context contextFrom(ImportedModel result) { MapContext context = new MapContext(); result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java index beec2ab1ead..b2443082ab1 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java @@ -1,6 +1,6 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java index 7ca16939477..723c5f27914 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java @@ -1,11 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; @@ -28,7 +28,7 @@ import static org.junit.Assert.assertEquals; public class TestableTensorFlowModel { private SavedModelBundle tensorFlowModel; - private TensorFlowModel model; + private ImportedModel model; // Sizes of the input vector private final int d0Size = 1; @@ -39,7 +39,7 @@ public class TestableTensorFlowModel { model = new TensorFlowImporter().importModel(modelName, tensorFlowModel); } - public TensorFlowModel get() { return model; } + public ImportedModel get() { return model; } public void assertEqualResult(String inputName, String operationName) { Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName); @@ -66,7 +66,7 @@ public class TestableTensorFlowModel { return TensorConverter.toVespaTensor(results.get(0)); } - private Context contextFrom(TensorFlowModel result) { + private Context contextFrom(ImportedModel result) { MapContext context = new MapContext(); result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); @@ -81,7 +81,7 @@ public class TestableTensorFlowModel { return b.build(); } - private void evaluateMacro(Context context, TensorFlowModel model, String macroName) { + private void evaluateMacro(Context context, ImportedModel model, String macroName) { if (!context.names().contains(macroName)) { RankingExpression e = model.macros().get(macroName); evaluateMacroDependencies(context, model, e.getRoot()); @@ -89,7 +89,7 @@ public class TestableTensorFlowModel { } } - private void evaluateMacroDependencies(Context context, TensorFlowModel model, ExpressionNode node) { + private void evaluateMacroDependencies(Context context, ImportedModel model, ExpressionNode node) { if (node instanceof ReferenceNode) { String name = node.toString(); if (model.macros().containsKey(name)) { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java index 051c2c60c95..f94098e6255 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java @@ -1,4 +1,4 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; import org.junit.Test; diff --git a/searchsummary/CMakeLists.txt b/searchsummary/CMakeLists.txt index 5f6e8881f13..4df636e0219 100644 --- a/searchsummary/CMakeLists.txt +++ b/searchsummary/CMakeLists.txt @@ -24,6 +24,7 @@ vespa_define_module( TESTS src/tests/docsumformat src/tests/docsummary + src/tests/docsummary/attribute_combiner src/tests/docsummary/slime_summary src/tests/extractkeywords ) diff --git a/searchsummary/src/tests/docsummary/attribute_combiner/CMakeLists.txt b/searchsummary/src/tests/docsummary/attribute_combiner/CMakeLists.txt new file mode 100644 index 00000000000..df323b9c982 --- /dev/null +++ b/searchsummary/src/tests/docsummary/attribute_combiner/CMakeLists.txt @@ -0,0 +1,8 @@ +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(searchsummary_attribute_combiner_test_app TEST + SOURCES + attribute_combiner_test.cpp + DEPENDS + searchsummary +) +vespa_add_test(NAME searchsummary_attribute_combiner_test_app COMMAND searchsummary_attribute_combiner_test_app) diff --git a/searchsummary/src/tests/docsummary/attribute_combiner/attribute_combiner_test.cpp b/searchsummary/src/tests/docsummary/attribute_combiner/attribute_combiner_test.cpp new file mode 100644 index 00000000000..97fafd0a446 --- /dev/null +++ b/searchsummary/src/tests/docsummary/attribute_combiner/attribute_combiner_test.cpp @@ -0,0 +1,217 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/searchcommon/common/undefinedvalues.h> +#include <vespa/searchlib/attribute/attributefactory.h> +#include <vespa/searchlib/attribute/attributemanager.h> +#include <vespa/searchlib/attribute/attributevector.h> +#include <vespa/searchlib/attribute/attributevector.hpp> +#include <vespa/searchlib/attribute/floatbase.h> +#include <vespa/searchlib/attribute/integerbase.h> +#include <vespa/searchlib/attribute/stringbase.h> +#include <vespa/searchlib/util/slime_output_raw_buf_adapter.h> +#include <vespa/searchsummary/docsummary/docsumstate.h> +#include <vespa/searchsummary/docsummary/docsum_field_writer_state.h> +#include <vespa/searchsummary/docsummary/attribute_combiner_dfw.h> +#include <vespa/vespalib/data/slime/slime.h> +#include <vespa/vespalib/testkit/testapp.h> + +#include <vespa/log/log.h> +LOG_SETUP("attribute_combiner_test"); + +using search::AttributeFactory; +using search::AttributeManager; +using search::AttributeVector; +using search::IntegerAttribute; +using search::FloatingPointAttribute; +using search::StringAttribute; +using search::attribute::BasicType; +using search::attribute::CollectionType; +using search::attribute::Config; +using search::attribute::IAttributeVector; +using search::attribute::getUndefined; +using search::docsummary::AttributeCombinerDFW; +using search::docsummary::GetDocsumsState; +using search::docsummary::GetDocsumsStateCallback; +using search::docsummary::IDocsumEnvironment; +using search::docsummary::IDocsumFieldWriter; + +namespace { + +vespalib::string +toCompactJsonString(const vespalib::Slime &slime) +{ + vespalib::SimpleBuffer buf; + vespalib::slime::JsonFormat::encode(slime, buf, true); + return buf.get().make_string(); +} + +struct FieldBlock { + vespalib::string input; + vespalib::Slime slime; + search::RawBuf binary; + vespalib::string json; + + explicit FieldBlock(const vespalib::string &jsonInput) + : input(jsonInput), slime(), binary(1024), json() + { + size_t used = vespalib::slime::JsonFormat::decode(jsonInput, slime); + EXPECT_TRUE(used > 0); + json = toCompactJsonString(slime); + search::SlimeOutputRawBufAdapter adapter(binary); + vespalib::slime::BinaryFormat::encode(slime, adapter); + } + const char *data() const { return binary.GetDrainPos(); } + size_t dataLen() const { return binary.GetUsedLen(); } +}; + +struct AttributeManagerFixture +{ + AttributeManager mgr; + + AttributeManagerFixture(); + + ~AttributeManagerFixture(); + + template <typename AttributeType, typename ValueType> + void + buildAttribute(const vespalib::string &name, + BasicType type, + std::vector<std::vector<ValueType>> values); + + void + buildStringAttribute(const vespalib::string &name, + std::vector<std::vector<vespalib::string>> values); + void + buildFloatAttribute(const vespalib::string &name, + std::vector<std::vector<double>> values); + + void + buildIntegerAttribute(const vespalib::string &name, + BasicType type, + std::vector<std::vector<IAttributeVector::largeint_t>> values); +}; + +AttributeManagerFixture::AttributeManagerFixture() + : mgr() +{ + buildStringAttribute("array.name", {{"n1.1", "n1.2"}, {"n2"}, {"n3.1", "n3.2"}, {"", "n4.2"}}); + buildIntegerAttribute("array.val", BasicType::Type::INT8, {{ 10, 11}, {20, 21 }, {30}, { getUndefined<int8_t>(), 41}}); + buildFloatAttribute("array.fval", {{ 110.0}, { 120.0, 121.0 }, { 130.0, 131.0}, { getUndefined<double>(), 141.0 }}); +} + +AttributeManagerFixture::~AttributeManagerFixture() = default; + +template <typename AttributeType, typename ValueType> +void +AttributeManagerFixture::buildAttribute(const vespalib::string &name, + BasicType type, + std::vector<std::vector<ValueType>> values) +{ + Config cfg(type, CollectionType::Type::ARRAY); + auto attrBase = AttributeFactory::createAttribute(name, cfg); + EXPECT_TRUE(attrBase); + auto attr = std::dynamic_pointer_cast<AttributeType>(attrBase); + EXPECT_TRUE(attr); + attr->addReservedDoc(); + for (const auto &docValues : values) { + uint32_t docId = 0; + EXPECT_TRUE(attr->addDoc(docId)); + EXPECT_NOT_EQUAL(0u, docId); + for (const auto &value : docValues) { + attr->append(docId, value, 1); + } + attr->commit(); + } + EXPECT_TRUE(mgr.add(attr)); +} + +void +AttributeManagerFixture::buildStringAttribute(const vespalib::string &name, + std::vector<std::vector<vespalib::string>> values) +{ + buildAttribute<StringAttribute, vespalib::string>(name, BasicType::Type::STRING, std::move(values)); +} + +void +AttributeManagerFixture::buildFloatAttribute(const vespalib::string &name, + std::vector<std::vector<double>> values) +{ + buildAttribute<FloatingPointAttribute, double>(name, BasicType::Type::DOUBLE, std::move(values)); +} + +void +AttributeManagerFixture::buildIntegerAttribute(const vespalib::string &name, + BasicType type, + std::vector<std::vector<IAttributeVector::largeint_t>> values) +{ + buildAttribute<IntegerAttribute, IAttributeVector::largeint_t>(name, type, std::move(values)); +} + + +class DummyStateCallback : public GetDocsumsStateCallback +{ +public: + void FillSummaryFeatures(GetDocsumsState *, IDocsumEnvironment *) override { } + void FillRankFeatures(GetDocsumsState *, IDocsumEnvironment *) override { } + void ParseLocation(GetDocsumsState *) override { } + ~DummyStateCallback() override { } +}; + + +struct Fixture +{ + AttributeManagerFixture attrs; + std::unique_ptr<IDocsumFieldWriter> writer; + DummyStateCallback stateCallback; + GetDocsumsState state; + + Fixture(); + ~Fixture(); + void assertWritten(const vespalib::string &exp, uint32_t docId); +}; + +Fixture::Fixture() + : attrs(), + writer(AttributeCombinerDFW::create("array", attrs.mgr)), + stateCallback(), + state(stateCallback) +{ + EXPECT_TRUE(writer->setFieldWriterStateIndex(0)); + state._attrCtx = attrs.mgr.createContext(); + state._fieldWriterStates.resize(1); +} + +Fixture::~Fixture() +{ +} + +void +Fixture::assertWritten(const vespalib::string &expectedJson, uint32_t docId) +{ + vespalib::Slime target; + vespalib::slime::SlimeInserter inserter(target); + writer->insertField(docId, nullptr, &state, search::docsummary::RES_JSONSTRING, inserter); + search::RawBuf binary(1024); + vespalib::string json = toCompactJsonString(target); + search::SlimeOutputRawBufAdapter adapter(binary); + vespalib::slime::BinaryFormat::encode(target, adapter); + FieldBlock block(expectedJson); + if (!EXPECT_EQUAL(block.dataLen(), binary.GetUsedLen()) || + !EXPECT_EQUAL(0, memcmp(block.data(), binary.GetDrainPos(), block.dataLen()))) { + LOG(error, "Expected '%s'", expectedJson.c_str()); + LOG(error, "Expected normalized '%s'", block.json.c_str()); + LOG(error, "Got '%s'", json.c_str()); + } +} + +TEST_F("require that attributes combiner dfw generates correct slime output for array of struct", Fixture()) +{ + f.assertWritten("[ { fval: 110.0, name: \"n1.1\", val: 10}, { name: \"n1.2\", val: 11}]", 1); + f.assertWritten("[ { fval: 120.0, name: \"n2\", val: 20}, { fval: 121.0, val: 21 }]", 2); + f.assertWritten("[ { fval: 130.0, name: \"n3.1\", val: 30}, { fval: 131.0, name: \"n3.2\"} ]", 3); + f.assertWritten("[ { }, { fval: 141.0, name: \"n4.2\", val: 41} ]", 4); +} + +} + +TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchsummary/src/vespa/searchsummary/docsummary/CMakeLists.txt b/searchsummary/src/vespa/searchsummary/docsummary/CMakeLists.txt index 9009f0bcbc7..ce54e7b0ea7 100644 --- a/searchsummary/src/vespa/searchsummary/docsummary/CMakeLists.txt +++ b/searchsummary/src/vespa/searchsummary/docsummary/CMakeLists.txt @@ -1,6 +1,9 @@ # Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. vespa_add_library(searchsummary_docsummary OBJECT SOURCES + array_attribute_combiner_dfw.cpp + attribute_combiner_dfw.cpp + attribute_field_writer.cpp resultclass.cpp resultconfig.cpp resultpacker.cpp diff --git a/searchsummary/src/vespa/searchsummary/docsummary/array_attribute_combiner_dfw.cpp b/searchsummary/src/vespa/searchsummary/docsummary/array_attribute_combiner_dfw.cpp new file mode 100644 index 00000000000..84e329f159d --- /dev/null +++ b/searchsummary/src/vespa/searchsummary/docsummary/array_attribute_combiner_dfw.cpp @@ -0,0 +1,89 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "array_attribute_combiner_dfw.h" +#include "docsum_field_writer_state.h" +#include "attribute_field_writer.h" +#include <vespa/searchcommon/attribute/iattributecontext.h> +#include <vespa/searchcommon/attribute/iattributevector.h> +#include <vespa/vespalib/data/slime/cursor.h> + +using search::attribute::IAttributeContext; +using search::attribute::IAttributeVector; +using vespalib::slime::Cursor; + +namespace search::docsummary { + +namespace { + +class ArrayAttributeFieldWriterState : public DocsumFieldWriterState +{ + std::vector<std::unique_ptr<AttributeFieldWriter>> _writers; + +public: + ArrayAttributeFieldWriterState(const std::vector<vespalib::string> &fieldNames, + const std::vector<vespalib::string> &attributeNames, + IAttributeContext &context); + ~ArrayAttributeFieldWriterState() override; + void insertField(uint32_t docId, vespalib::slime::Inserter &target) override; +}; + +ArrayAttributeFieldWriterState::ArrayAttributeFieldWriterState(const std::vector<vespalib::string> &fieldNames, + const std::vector<vespalib::string> &attributeNames, + IAttributeContext &context) + : DocsumFieldWriterState() +{ + size_t fields = fieldNames.size(); + _writers.reserve(fields); + for (uint32_t field = 0; field < fields; ++field) { + const IAttributeVector *attr = context.getAttribute(attributeNames[field]); + if (attr != nullptr) { + _writers.emplace_back(AttributeFieldWriter::create(fieldNames[field], *attr)); + } + } +} + +ArrayAttributeFieldWriterState::~ArrayAttributeFieldWriterState() = default; + +void +ArrayAttributeFieldWriterState::insertField(uint32_t docId, vespalib::slime::Inserter &target) +{ + uint32_t elems = 0; + for (auto &writer : _writers) { + writer->fetch(docId); + if (elems < writer->size()) { + elems = writer->size(); + } + } + Cursor &arr = target.insertArray(); + for (uint32_t idx = 0; idx < elems; ++idx) { + Cursor &obj = arr.addObject(); + for (auto &writer : _writers) { + writer->print(idx, obj); + } + } +} + +} + +ArrayAttributeCombinerDFW::ArrayAttributeCombinerDFW(const vespalib::string &fieldName, + const std::vector<vespalib::string> &fields) + : AttributeCombinerDFW(fieldName), + _fields(fields), + _attributeNames() +{ + _attributeNames.reserve(_fields.size()); + vespalib::string prefix = fieldName + "."; + for (const auto &field : _fields) { + _attributeNames.emplace_back(prefix + field); + } +} + +ArrayAttributeCombinerDFW::~ArrayAttributeCombinerDFW() = default; + +std::unique_ptr<DocsumFieldWriterState> +ArrayAttributeCombinerDFW::allocFieldWriterState(IAttributeContext &context) +{ + return std::make_unique<ArrayAttributeFieldWriterState>(_fields, _attributeNames, context); +} + +} diff --git a/searchsummary/src/vespa/searchsummary/docsummary/array_attribute_combiner_dfw.h b/searchsummary/src/vespa/searchsummary/docsummary/array_attribute_combiner_dfw.h new file mode 100644 index 00000000000..c02d2bd5da6 --- /dev/null +++ b/searchsummary/src/vespa/searchsummary/docsummary/array_attribute_combiner_dfw.h @@ -0,0 +1,29 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "attribute_combiner_dfw.h" + +namespace search::attribute { class IAttributeContext; } + +namespace search::docsummary { + +class DocsumFieldWriterState; + +/* + * This class reads values from multiple struct field attributes and + * inserts them as an array of struct. + */ +class ArrayAttributeCombinerDFW : public AttributeCombinerDFW +{ + std::vector<vespalib::string> _fields; + std::vector<vespalib::string> _attributeNames; + + std::unique_ptr<DocsumFieldWriterState> allocFieldWriterState(search::attribute::IAttributeContext &context) override; +public: + ArrayAttributeCombinerDFW(const vespalib::string &fieldName, + const std::vector<vespalib::string> &fields); + ~ArrayAttributeCombinerDFW() override; +}; + +} diff --git a/searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.cpp b/searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.cpp new file mode 100644 index 00000000000..b532cfb273a --- /dev/null +++ b/searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.cpp @@ -0,0 +1,141 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "attribute_combiner_dfw.h" +#include "array_attribute_combiner_dfw.h" +#include "docsum_field_writer_state.h" +#include "docsumstate.h" +#include <vespa/searchlib/attribute/attributeguard.h> +#include <vespa/searchlib/attribute/attributevector.h> +#include <vespa/searchlib/attribute/iattributemanager.h> +#include <algorithm> + +#include <vespa/log/log.h> +LOG_SETUP(".searchsummary.docsummary.attribute_combiner_dfw"); + +using search::AttributeGuard; +using search::AttributeVector; +using search::attribute::CollectionType; + +namespace search::docsummary { + +namespace { + +class StructFields +{ + std::vector<vespalib::string> _mapFields; + std::vector<vespalib::string> _arrayFields; + bool _hasMapKey; + bool _error; + +public: + StructFields(const vespalib::string &fieldName, const IAttributeManager &attrMgr); + ~StructFields(); + const std::vector<vespalib::string> &getMapFields() const { return _mapFields; } + const std::vector<vespalib::string> &getArrayFields() const { return _arrayFields; } + bool hasMapKey() const { return _hasMapKey; } + bool getError() const { return _error; } +}; + + +StructFields::StructFields(const vespalib::string &fieldName, const IAttributeManager &attrMgr) + : _mapFields(), + _arrayFields(), + _hasMapKey(false), + _error(false) +{ + // Note: Doesn't handle imported attributes + std::vector<AttributeGuard> attrs; + attrMgr.getAttributeList(attrs); + vespalib::string prefix = fieldName + "."; + vespalib::string keyName = prefix + "key"; + vespalib::string valuePrefix = prefix + "value."; + for (const auto &guard : attrs) { + vespalib::string name = guard->getName(); + if (name.substr(0, prefix.size()) != prefix) { + continue; + } + auto collType = guard->getCollectionType(); + if (collType != CollectionType::Type::ARRAY) { + LOG(warning, "Attribute %s is not an array attribute", name.c_str()); + _error = true; + break; + } + if (name.substr(0, valuePrefix.size()) == valuePrefix) { + _mapFields.emplace_back(name.substr(valuePrefix.size())); + } else { + _arrayFields.emplace_back(name.substr(prefix.size())); + if (name == keyName) { + _hasMapKey = true; + } + } + } + if (!_error) { + std::sort(_arrayFields.begin(), _arrayFields.end()); + std::sort(_mapFields.begin(), _mapFields.end()); + if (!_mapFields.empty()) { + if (!_hasMapKey) { + LOG(warning, "Missing key attribute '%s', have value attributes for map", keyName.c_str()); + _error = true; + } else if (_arrayFields.size() != 1u) { + LOG(warning, "Could not determine if field '%s' is array or map of struct", fieldName.c_str()); + _error = true; + } + } + } +} + +StructFields::~StructFields() = default; + +} + +AttributeCombinerDFW::AttributeCombinerDFW(const vespalib::string &fieldName) + : IDocsumFieldWriter(), + _stateIndex(0), + _fieldName(fieldName) +{ +} + +AttributeCombinerDFW::~AttributeCombinerDFW() = default; + +bool +AttributeCombinerDFW::IsGenerated() const +{ + return true; +} + +bool +AttributeCombinerDFW::setFieldWriterStateIndex(uint32_t fieldWriterStateIndex) +{ + _stateIndex = fieldWriterStateIndex; + return true; +} + +std::unique_ptr<IDocsumFieldWriter> +AttributeCombinerDFW::create(const vespalib::string &fieldName, IAttributeManager &attrMgr) +{ + StructFields structFields(fieldName, attrMgr); + if (structFields.getError()) { + return std::unique_ptr<IDocsumFieldWriter>(); + } else if (!structFields.getMapFields().empty()) { + LOG(warning, "map of struct is not yet supported for field '%s'", fieldName.c_str()); + return std::unique_ptr<IDocsumFieldWriter>(); + } + return std::make_unique<ArrayAttributeCombinerDFW>(fieldName, structFields.getArrayFields()); +} + +void +AttributeCombinerDFW::insertField(uint32_t docid, + GeneralResult *, + GetDocsumsState *state, + ResType, + vespalib::slime::Inserter &target) +{ + auto &fieldWriterState = state->_fieldWriterStates[_stateIndex]; + if (!fieldWriterState) { + fieldWriterState = allocFieldWriterState(*state->_attrCtx); + } + fieldWriterState->insertField(docid, target); +} + +} + diff --git a/searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.h b/searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.h new file mode 100644 index 00000000000..ef54522a923 --- /dev/null +++ b/searchsummary/src/vespa/searchsummary/docsummary/attribute_combiner_dfw.h @@ -0,0 +1,36 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "docsumfieldwriter.h" + +namespace search::attribute { class IAttributeContext; } + +namespace search::docsummary { + +class DocsumFieldWriterState; +class DynamicDocsumWriter; + +/* + * This class reads values from multiple struct field attributes and + * inserts them as an array of struct or a map of struct. + */ +class AttributeCombinerDFW : public IDocsumFieldWriter +{ +protected: + uint32_t _stateIndex; + vespalib::string _fieldName; + AttributeCombinerDFW(const vespalib::string &fieldName); +protected: + virtual std::unique_ptr<DocsumFieldWriterState> allocFieldWriterState(search::attribute::IAttributeContext &context) = 0; +public: + ~AttributeCombinerDFW() override; + bool IsGenerated() const override; + bool setFieldWriterStateIndex(uint32_t fieldWriterStateIndex) override; + static std::unique_ptr<IDocsumFieldWriter> create(const vespalib::string &fieldName, IAttributeManager &attrMgr); + void insertField(uint32_t docid, GeneralResult *gres, GetDocsumsState *state, + ResType type, vespalib::slime::Inserter &target) override; +}; + +} + diff --git a/searchsummary/src/vespa/searchsummary/docsummary/attribute_field_writer.cpp b/searchsummary/src/vespa/searchsummary/docsummary/attribute_field_writer.cpp new file mode 100644 index 00000000000..2eebe7137dc --- /dev/null +++ b/searchsummary/src/vespa/searchsummary/docsummary/attribute_field_writer.cpp @@ -0,0 +1,172 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "attribute_field_writer.h" +#include <vespa/searchcommon/attribute/attributecontent.h> +#include <vespa/searchcommon/common/undefinedvalues.h> +#include <vespa/vespalib/data/slime/cursor.h> +#include <cassert> + +using search::attribute::BasicType; +using search::attribute::IAttributeVector; +using search::attribute::getUndefined; +using vespalib::slime::Cursor; + +namespace search::docsummary { + +AttributeFieldWriter::AttributeFieldWriter(const vespalib::string &fieldName, + const IAttributeVector &attr) + : _fieldName(fieldName), + _attr(attr), + _size(0) +{ +} + +AttributeFieldWriter::~AttributeFieldWriter() = default; + +namespace { + +template <class Content> +class WriteField : public AttributeFieldWriter +{ +protected: + Content _content; + + WriteField(const vespalib::string &fieldName, const IAttributeVector &attr); + ~WriteField() override; +private: + void fetch(uint32_t docId) override; +}; + +class WriteStringField : public WriteField<search::attribute::ConstCharContent> +{ +public: + WriteStringField(const vespalib::string &fieldName, + const IAttributeVector &attr); + ~WriteStringField() override; + void print(uint32_t idx, Cursor &cursor) override; +}; + + +class WriteFloatField : public WriteField<search::attribute::FloatContent> +{ +public: + WriteFloatField(const vespalib::string &fieldName, + const IAttributeVector &attr); + ~WriteFloatField() override; + void print(uint32_t idx, Cursor &cursor) override; +}; + +class WriteIntField : public WriteField<search::attribute::IntegerContent> +{ + IAttributeVector::largeint_t _undefined; +public: + WriteIntField(const vespalib::string &fieldName, + const IAttributeVector &attr, + IAttributeVector::largeint_t undefined); + ~WriteIntField() override; + void print(uint32_t idx, Cursor &cursor) override; +}; + +template <class Content> +WriteField<Content>::WriteField(const vespalib::string &fieldName, const IAttributeVector &attr) + : AttributeFieldWriter(fieldName, attr), + _content() +{ +} + +template <class Content> +WriteField<Content>::~WriteField() = default; + +template <class Content> +void +WriteField<Content>::fetch(uint32_t docId) +{ + _content.fill(_attr, docId); + _size = _content.size(); +} + +WriteStringField::WriteStringField(const vespalib::string &fieldName, + const IAttributeVector &attr) + : WriteField(fieldName, attr) +{ +} + +WriteStringField::~WriteStringField() = default; + +void +WriteStringField::print(uint32_t idx, Cursor &cursor) +{ + if (idx < _size) { + const char *s = _content[idx]; + if (s[0] != '\0') { + cursor.setString(_fieldName, vespalib::Memory(s)); + } + } +} + +WriteFloatField::WriteFloatField(const vespalib::string &fieldName, + const IAttributeVector &attr) + : WriteField(fieldName, attr) +{ +} + +WriteFloatField::~WriteFloatField() = default; + +void +WriteFloatField::print(uint32_t idx, Cursor &cursor) +{ + if (idx < _size) { + double val = _content[idx]; + if (!search::attribute::isUndefined(val)) { + cursor.setDouble(_fieldName, val); + } + } +} + +WriteIntField::WriteIntField(const vespalib::string &fieldName, + const IAttributeVector &attr, + IAttributeVector::largeint_t undefined) + : WriteField(fieldName, attr), + _undefined(undefined) +{ +} + +WriteIntField::~WriteIntField() = default; + +void +WriteIntField::print(uint32_t idx, Cursor &cursor) +{ + if (idx < _size) { + auto val = _content[idx]; + if (val != _undefined) { + cursor.setLong(_fieldName, _content[idx]); + } + } +} + +} + +std::unique_ptr<AttributeFieldWriter> +AttributeFieldWriter::create(const vespalib::string &fieldName, const IAttributeVector &attr) +{ + switch (attr.getBasicType()) { + case BasicType::INT8: + return std::make_unique<WriteIntField>(fieldName, attr, getUndefined<int8_t>()); + case BasicType::INT16: + return std::make_unique<WriteIntField>(fieldName, attr, getUndefined<int16_t>()); + case BasicType::INT32: + return std::make_unique<WriteIntField>(fieldName, attr, getUndefined<int32_t>()); + case BasicType::INT64: + return std::make_unique<WriteIntField>(fieldName, attr, getUndefined<int64_t>()); + case BasicType::FLOAT: + case BasicType::DOUBLE: + return std::make_unique<WriteFloatField>(fieldName, attr); + case BasicType::STRING: + return std::make_unique<WriteStringField>(fieldName, attr); + default: + assert(false); + abort(); + } +} + +} diff --git a/searchsummary/src/vespa/searchsummary/docsummary/attribute_field_writer.h b/searchsummary/src/vespa/searchsummary/docsummary/attribute_field_writer.h new file mode 100644 index 00000000000..104455a0e79 --- /dev/null +++ b/searchsummary/src/vespa/searchsummary/docsummary/attribute_field_writer.h @@ -0,0 +1,34 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/vespalib/data/memory.h> + +namespace search::attribute { class IAttributeVector; } +namespace vespalib::slime { class Cursor; } + +namespace search::docsummary { + +/* + * This class reads values from a struct field attribute and inserts + * them into proper position in an array of struct or map of struct. + * If the value to be inserted is considered to be undefined then + * the value is not inserted. + */ +class AttributeFieldWriter +{ +protected: + const vespalib::Memory _fieldName; + const search::attribute::IAttributeVector &_attr; + size_t _size; +public: + AttributeFieldWriter(const vespalib::string &fieldName, + const search::attribute::IAttributeVector &attr); + virtual ~AttributeFieldWriter(); + virtual void fetch(uint32_t docId) = 0; + virtual void print(uint32_t idx, vespalib::slime::Cursor &cursor) = 0; + static std::unique_ptr<AttributeFieldWriter> create(const vespalib::string &fieldName, const search::attribute::IAttributeVector &attr); + uint32_t size() const { return _size; } +}; + +} diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_state.h b/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_state.h new file mode 100644 index 00000000000..940cfd6ce06 --- /dev/null +++ b/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_state.h @@ -0,0 +1,21 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +namespace vespalib::slime { class Inserter; } + +namespace search::docsummary { + +/* + * A subclass of this class can be instantiated by a document field writer to + * track extra state during handling of a document summary request and + * insert the field value using that state. + */ +class DocsumFieldWriterState +{ +public: + virtual void insertField(uint32_t docId, vespalib::slime::Inserter &target) = 0; + virtual ~DocsumFieldWriterState() = default; +}; + +} diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.cpp b/searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.cpp index 7b463352155..18e7e471663 100644 --- a/searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.cpp +++ b/searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.cpp @@ -21,6 +21,12 @@ using search::common::Location; const vespalib::string IDocsumFieldWriter::_empty(""); +bool +IDocsumFieldWriter::setFieldWriterStateIndex(uint32_t) +{ + return false; // Don't need any field writer state by default +} + //-------------------------------------------------------------------------- EmptyDFW::EmptyDFW() { } diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.h b/searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.h index abce5c12227..51079f7736e 100644 --- a/searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.h +++ b/searchsummary/src/vespa/searchsummary/docsummary/docsumfieldwriter.h @@ -40,6 +40,7 @@ public: } void setIndex(size_t v) { _index = v; } size_t getIndex() const { return _index; } + virtual bool setFieldWriterStateIndex(uint32_t fieldWriterStateIndex); private: size_t _index; static const vespalib::string _empty; diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsumstate.cpp b/searchsummary/src/vespa/searchsummary/docsummary/docsumstate.cpp index 91953612f6a..b0431b6e6ac 100644 --- a/searchsummary/src/vespa/searchsummary/docsummary/docsumstate.cpp +++ b/searchsummary/src/vespa/searchsummary/docsummary/docsumstate.cpp @@ -4,6 +4,7 @@ #include <vespa/juniper/rpinterface.h> #include <vespa/searchcommon/attribute/iattributecontext.h> #include <vespa/searchlib/common/location.h> +#include "docsum_field_writer_state.h" namespace search { namespace docsummary { @@ -19,6 +20,7 @@ GetDocsumsState::GetDocsumsState(GetDocsumsStateCallback &callback) _docSumFieldSpace(_docSumFieldSpaceStore, sizeof(_docSumFieldSpaceStore)), // only alloc buffer if needed _attrCtx(), _attributes(), + _fieldWriterStates(), _jsonStringer(), _parsedLocation(), _summaryFeatures(NULL), diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsumstate.h b/searchsummary/src/vespa/searchsummary/docsummary/docsumstate.h index 4ffed79043e..fa47d5244eb 100644 --- a/searchsummary/src/vespa/searchsummary/docsummary/docsumstate.h +++ b/searchsummary/src/vespa/searchsummary/docsummary/docsumstate.h @@ -23,6 +23,7 @@ namespace search::docsummary { class GetDocsumsState; class IDocsumEnvironment; class KeywordExtractor; +class DocsumFieldWriterState; class GetDocsumsStateCallback { @@ -70,6 +71,7 @@ public: char _docSumFieldSpaceStore[2048]; std::unique_ptr<search::attribute::IAttributeContext> _attrCtx; std::vector<const search::attribute::IAttributeVector *> _attributes; + std::vector<std::unique_ptr<DocsumFieldWriterState>> _fieldWriterStates; vespalib::JSONStringer _jsonStringer; // used by AbsDistanceDFW diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.cpp b/searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.cpp index bf660b1319b..abd1780b773 100644 --- a/searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.cpp +++ b/searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.cpp @@ -2,6 +2,7 @@ #include "docsumwriter.h" #include "docsumstate.h" +#include "docsum_field_writer_state.h" #include <vespa/searchlib/common/transport.h> #include <vespa/searchlib/util/slime_output_raw_buf_adapter.h> #include <vespa/searchlib/attribute/iattributemanager.h> @@ -77,7 +78,6 @@ DynamicDocsumWriter::resolveInputClass(ResolveClassInfo &rci, uint32_t id) const } } - static void convertEntry(GetDocsumsState *state, const ResConfigEntry *resCfg, const ResEntry *entry, @@ -194,6 +194,7 @@ DynamicDocsumWriter::DynamicDocsumWriter( ResultConfig *config, KeywordExtractor _defaultOutputClass(ResultConfig::NoClassID()), _numClasses(config->GetNumResultClasses()), _numEnumValues(config->GetFieldNameEnum().GetNumEntries()), + _numFieldWriterStates(0), _classInfoTable(nullptr), _overrideTable(nullptr) { @@ -267,6 +268,9 @@ DynamicDocsumWriter::Override(const char *fieldName, IDocsumFieldWriter *writer) writer->setIndex(fieldEnumValue); _overrideTable[fieldEnumValue] = writer; + if (writer->setFieldWriterStateIndex(_numFieldWriterStates)) { + ++_numFieldWriterStates; + } for (ResultConfig::iterator it(_resultConfig->begin()), mt(_resultConfig->end()); it != mt; it++) { @@ -288,6 +292,7 @@ DynamicDocsumWriter::InitState(IAttributeManager & attrMan, GetDocsumsState *sta state->_kwExtractor = _keywordExtractor; state->_attrCtx = attrMan.createContext(); state->_attributes.resize(_numEnumValues); + state->_fieldWriterStates.resize(_numFieldWriterStates); for (size_t i(0); i < state->_attributes.size(); i++) { const IDocsumFieldWriter *fw = _overrideTable[i]; if (fw) { diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.h b/searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.h index 6ef21a71e74..92b26d5cf14 100644 --- a/searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.h +++ b/searchsummary/src/vespa/searchsummary/docsummary/docsumwriter.h @@ -54,6 +54,7 @@ private: uint32_t _defaultOutputClass; uint32_t _numClasses; uint32_t _numEnumValues; + uint32_t _numFieldWriterStates; ResultClass::DynamicInfo *_classInfoTable; IDocsumFieldWriter **_overrideTable; diff --git a/service-monitor/pom.xml b/service-monitor/pom.xml index 70f9d4aa655..b8065ed3636 100644 --- a/service-monitor/pom.xml +++ b/service-monitor/pom.xml @@ -64,6 +64,12 @@ <version>${project.version}</version> </dependency> <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>vespa-athenz</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> <groupId>com.google.inject</groupId> <artifactId>guice</artifactId> <scope>provided</scope> @@ -76,6 +82,23 @@ <scope>provided</scope> </dependency> <dependency> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-core</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-databind</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.httpcomponents</groupId> + <artifactId>httpclient</artifactId> + <version>4.5</version> + <!-- This is necessary to get 4.4's HostnameVerifier API of SSLConnectionSocketFactory::new --> + <scope>compile</scope> + </dependency> + <dependency> <groupId>junit</groupId> <artifactId>junit</artifactId> <scope>test</scope> diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/ServiceStatusProvider.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/ServiceStatusProvider.java index 35003313775..75e61eef772 100644 --- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/ServiceStatusProvider.java +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/ServiceStatusProvider.java @@ -11,7 +11,13 @@ import com.yahoo.vespa.applicationmodel.ServiceType; * @author hakon */ public interface ServiceStatusProvider { - /** Get the {@link ServiceStatus} of a particular service. */ + /** + * Get the {@link ServiceStatus} of a particular service. + * + * <p>{@link ServiceStatus#NOT_CHECKED NOT_CHECKED} must be returned if the + * service status provider does does not monitor the service status for + * the particular application, cluster, service type, and config id. + */ ServiceStatus getStatus(ApplicationId applicationId, ClusterId clusterId, ServiceType serviceType, diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ApplicationInstanceGenerator.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ApplicationInstanceGenerator.java index ec2702bcfaf..cbdcce125cc 100644 --- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ApplicationInstanceGenerator.java +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ApplicationInstanceGenerator.java @@ -1,13 +1,148 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.service.monitor.application; +import com.yahoo.config.model.api.ApplicationInfo; +import com.yahoo.config.model.api.HostInfo; +import com.yahoo.config.model.api.ServiceInfo; +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.Zone; import com.yahoo.vespa.applicationmodel.ApplicationInstance; +import com.yahoo.vespa.applicationmodel.ApplicationInstanceId; +import com.yahoo.vespa.applicationmodel.ClusterId; +import com.yahoo.vespa.applicationmodel.ConfigId; +import com.yahoo.vespa.applicationmodel.HostName; +import com.yahoo.vespa.applicationmodel.ServiceCluster; +import com.yahoo.vespa.applicationmodel.ServiceClusterKey; +import com.yahoo.vespa.applicationmodel.ServiceInstance; +import com.yahoo.vespa.applicationmodel.ServiceStatus; +import com.yahoo.vespa.applicationmodel.ServiceType; +import com.yahoo.vespa.applicationmodel.TenantId; import com.yahoo.vespa.service.monitor.ServiceStatusProvider; +import com.yahoo.vespa.service.monitor.internal.ServiceId; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static com.yahoo.vespa.service.monitor.application.ConfigServerApplication.CONFIG_SERVER_APPLICATION; /** + * Class to generate an ApplicationInstance given service status for a standard (deployed) application. + * * @author hakon */ -public interface ApplicationInstanceGenerator { - /** Make an ApplicationInstance based on current service status. */ - ApplicationInstance makeApplicationInstance(ServiceStatusProvider serviceStatusProvider); +public class ApplicationInstanceGenerator { + public static final String CLUSTER_ID_PROPERTY_NAME = "clustername"; + + private final ApplicationInfo applicationInfo; + private final Zone zone; + + public ApplicationInstanceGenerator(ApplicationInfo applicationInfo, Zone zone) { + this.applicationInfo = applicationInfo; + this.zone = zone; + } + + public ApplicationInstance makeApplicationInstance(ServiceStatusProvider serviceStatusProvider) { + Map<ServiceClusterKey, Set<ServiceInstance>> groupedServiceInstances = new HashMap<>(); + + for (HostInfo host : applicationInfo.getModel().getHosts()) { + HostName hostName = new HostName(host.getHostname()); + for (ServiceInfo serviceInfo : host.getServices()) { + ServiceClusterKey serviceClusterKey = toServiceClusterKey(serviceInfo); + ServiceInstance serviceInstance = + toServiceInstance( + applicationInfo.getApplicationId(), + serviceClusterKey.clusterId(), + serviceInfo, + hostName, + serviceStatusProvider); + + if (!groupedServiceInstances.containsKey(serviceClusterKey)) { + groupedServiceInstances.put(serviceClusterKey, new HashSet<>()); + } + groupedServiceInstances.get(serviceClusterKey).add(serviceInstance); + } + } + + Set<ServiceCluster> serviceClusters = groupedServiceInstances.entrySet().stream() + .map(entry -> new ServiceCluster( + entry.getKey().clusterId(), + entry.getKey().serviceType(), + entry.getValue())) + .collect(Collectors.toSet()); + + ApplicationInstance applicationInstance = new ApplicationInstance( + new TenantId(applicationInfo.getApplicationId().tenant().toString()), + toApplicationInstanceId(applicationInfo, zone), + serviceClusters); + + // Fill back-references + for (ServiceCluster serviceCluster : applicationInstance.serviceClusters()) { + serviceCluster.setApplicationInstance(applicationInstance); + for (ServiceInstance serviceInstance : serviceCluster.serviceInstances()) { + serviceInstance.setServiceCluster(serviceCluster); + } + } + + return applicationInstance; + } + + private ServiceInstance toServiceInstance( + ApplicationId applicationId, + ClusterId clusterId, + ServiceInfo serviceInfo, + HostName hostName, + ServiceStatusProvider serviceStatusProvider) { + ConfigId configId = toConfigId(serviceInfo); + + ServiceStatus status = serviceStatusProvider.getStatus( + applicationId, + clusterId, + toServiceType(serviceInfo), configId); + + return new ServiceInstance(configId, hostName, status); + } + + private ApplicationInstanceId toApplicationInstanceId(ApplicationInfo applicationInfo, Zone zone) { + if (applicationInfo.getApplicationId().equals(CONFIG_SERVER_APPLICATION.getApplicationId())) { + // Removing this historical discrepancy would break orchestration during rollout. + // An alternative may be to use a feature flag and flip it between releases, + // once that's available. + return new ApplicationInstanceId(applicationInfo.getApplicationId().application().value()); + } else { + return new ApplicationInstanceId(String.format("%s:%s:%s:%s", + applicationInfo.getApplicationId().application().value(), + zone.environment().value(), + zone.region().value(), + applicationInfo.getApplicationId().instance().value())); + } + } + + public static ServiceId getServiceId(ApplicationInfo applicationInfo, ServiceInfo serviceInfo) { + return new ServiceId( + applicationInfo.getApplicationId(), + getClusterId(serviceInfo), + toServiceType(serviceInfo), + toConfigId(serviceInfo)); + } + + private static ClusterId getClusterId(ServiceInfo serviceInfo) { + return new ClusterId(serviceInfo.getProperty(CLUSTER_ID_PROPERTY_NAME).orElse("")); + } + + private static ServiceClusterKey toServiceClusterKey(ServiceInfo serviceInfo) { + ClusterId clusterId = getClusterId(serviceInfo); + ServiceType serviceType = toServiceType(serviceInfo); + return new ServiceClusterKey(clusterId, serviceType); + } + + private static ServiceType toServiceType(ServiceInfo serviceInfo) { + return new ServiceType(serviceInfo.getServiceType()); + } + + private static ConfigId toConfigId(ServiceInfo serviceInfo) { + return new ConfigId(serviceInfo.getConfigId()); + } } diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ConfigServerAppGenerator.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ConfigServerAppGenerator.java deleted file mode 100644 index 76ca59cf583..00000000000 --- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ConfigServerAppGenerator.java +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.service.monitor.application; - -import com.yahoo.vespa.applicationmodel.ApplicationInstance; -import com.yahoo.vespa.applicationmodel.ConfigId; -import com.yahoo.vespa.applicationmodel.HostName; -import com.yahoo.vespa.applicationmodel.ServiceCluster; -import com.yahoo.vespa.applicationmodel.ServiceInstance; -import com.yahoo.vespa.applicationmodel.ServiceStatus; -import com.yahoo.vespa.service.monitor.ServiceStatusProvider; - -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * Class for generating an ApplicationInstance for the synthesized config server application. - * - * @author hakon - */ -public class ConfigServerAppGenerator implements ApplicationInstanceGenerator { - private final List<String> hostnames; - - public ConfigServerAppGenerator(List<String> hostnames) { - this.hostnames = hostnames; - } - - @Override - public ApplicationInstance makeApplicationInstance(ServiceStatusProvider statusProvider) { - Set<ServiceInstance> serviceInstances = hostnames.stream() - .map(hostname -> makeServiceInstance(hostname, statusProvider)) - .collect(Collectors.toSet()); - - ServiceCluster serviceCluster = new ServiceCluster( - ConfigServerApplication.CLUSTER_ID, - ConfigServerApplication.SERVICE_TYPE, - serviceInstances); - - Set<ServiceCluster> serviceClusters = new HashSet<>(); - serviceClusters.add(serviceCluster); - - ApplicationInstance applicationInstance = new ApplicationInstance( - ConfigServerApplication.TENANT_ID, - ConfigServerApplication.APPLICATION_INSTANCE_ID, - serviceClusters); - - // Fill back-references - serviceCluster.setApplicationInstance(applicationInstance); - for (ServiceInstance serviceInstance : serviceCluster.serviceInstances()) { - serviceInstance.setServiceCluster(serviceCluster); - } - - return applicationInstance; - } - - private ServiceInstance makeServiceInstance(String hostname, ServiceStatusProvider statusProvider) { - ConfigId configId = new ConfigId(ConfigServerApplication.CONFIG_ID_PREFIX + hostname); - ServiceStatus status = statusProvider.getStatus( - ConfigServerApplication.CONFIG_SERVER_APPLICATION.getApplicationId(), - ConfigServerApplication.CLUSTER_ID, - ConfigServerApplication.SERVICE_TYPE, - configId); - - return new ServiceInstance(configId, new HostName(hostname), status); - } -} diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ConfigServerApplication.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ConfigServerApplication.java index 132bb0927b8..5ad38cebcfc 100644 --- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ConfigServerApplication.java +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ConfigServerApplication.java @@ -1,12 +1,26 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.service.monitor.application; +import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.config.model.api.ApplicationInfo; +import com.yahoo.config.model.api.HostInfo; +import com.yahoo.config.model.api.PortInfo; +import com.yahoo.config.model.api.ServiceInfo; import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.NodeType; import com.yahoo.vespa.applicationmodel.ApplicationInstanceId; import com.yahoo.vespa.applicationmodel.ClusterId; +import com.yahoo.vespa.applicationmodel.ConfigId; import com.yahoo.vespa.applicationmodel.ServiceType; import com.yahoo.vespa.applicationmodel.TenantId; +import com.yahoo.vespa.service.monitor.internal.ModelGenerator; +import com.yahoo.vespa.service.monitor.internal.health.ApplicationHealthMonitor; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; /** * A service/application model of the config server with health status. @@ -21,8 +35,44 @@ public class ConfigServerApplication extends HostedVespaApplication { public static final ServiceType SERVICE_TYPE = new ServiceType("configserver"); public static final String CONFIG_ID_PREFIX = "configid."; + public static ConfigId configIdFrom(int index) { + return new ConfigId(CONFIG_ID_PREFIX + index); + } + private ConfigServerApplication() { super("zone-config-servers", NodeType.config, ClusterSpec.Type.admin, ClusterSpec.Id.from("zone-config-servers")); } + + public ApplicationInfo makeApplicationInfo(ConfigserverConfig config) { + List<HostInfo> hostInfos = new ArrayList<>(); + List<ConfigserverConfig.Zookeeperserver> zooKeeperServers = config.zookeeperserver(); + for (int index = 0; index < zooKeeperServers.size(); ++index) { + String hostname = zooKeeperServers.get(index).hostname(); + hostInfos.add(makeHostInfo(hostname, config.httpport(), index)); + } + + return new ApplicationInfo( + CONFIG_SERVER_APPLICATION.getApplicationId(), + 0, + new HostsModel(hostInfos)); + } + + private static HostInfo makeHostInfo(String hostname, int port, int configIndex) { + PortInfo portInfo = new PortInfo(port, ApplicationHealthMonitor.PORT_TAGS_HEALTH); + + Map<String, String> properties = new HashMap<>(); + properties.put(ModelGenerator.CLUSTER_ID_PROPERTY_NAME, CLUSTER_ID.s()); + + ServiceInfo serviceInfo = new ServiceInfo( + // service name == service type for the first service of each type on each host + SERVICE_TYPE.s(), + SERVICE_TYPE.s(), + Collections.singletonList(portInfo), + properties, + configIdFrom(configIndex).s(), + hostname); + + return new HostInfo(hostname, Collections.singletonList(serviceInfo)); + } } diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/DeployedAppGenerator.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/DeployedAppGenerator.java deleted file mode 100644 index 2691a8bf1ee..00000000000 --- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/DeployedAppGenerator.java +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.service.monitor.application; - -import com.yahoo.config.model.api.ApplicationInfo; -import com.yahoo.config.model.api.HostInfo; -import com.yahoo.config.model.api.ServiceInfo; -import com.yahoo.config.provision.ApplicationId; -import com.yahoo.config.provision.Zone; -import com.yahoo.vespa.applicationmodel.ApplicationInstance; -import com.yahoo.vespa.applicationmodel.ApplicationInstanceId; -import com.yahoo.vespa.applicationmodel.ClusterId; -import com.yahoo.vespa.applicationmodel.ConfigId; -import com.yahoo.vespa.applicationmodel.HostName; -import com.yahoo.vespa.applicationmodel.ServiceCluster; -import com.yahoo.vespa.applicationmodel.ServiceClusterKey; -import com.yahoo.vespa.applicationmodel.ServiceInstance; -import com.yahoo.vespa.applicationmodel.ServiceStatus; -import com.yahoo.vespa.applicationmodel.ServiceType; -import com.yahoo.vespa.applicationmodel.TenantId; -import com.yahoo.vespa.service.monitor.ServiceStatusProvider; - -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * Class to generate an ApplicationInstance given service status for a standard (deployed) application. - * - * @author hakon - */ -public class DeployedAppGenerator implements ApplicationInstanceGenerator { - public static final String CLUSTER_ID_PROPERTY_NAME = "clustername"; - - private final ApplicationInfo applicationInfo; - private final Zone zone; - - public DeployedAppGenerator(ApplicationInfo applicationInfo, Zone zone) { - this.applicationInfo = applicationInfo; - this.zone = zone; - } - - @Override - public ApplicationInstance makeApplicationInstance(ServiceStatusProvider serviceStatusProvider) { - Map<ServiceClusterKey, Set<ServiceInstance>> groupedServiceInstances = new HashMap<>(); - - for (HostInfo host : applicationInfo.getModel().getHosts()) { - HostName hostName = new HostName(host.getHostname()); - for (ServiceInfo serviceInfo : host.getServices()) { - ServiceClusterKey serviceClusterKey = toServiceClusterKey(serviceInfo); - ServiceInstance serviceInstance = - toServiceInstance( - applicationInfo.getApplicationId(), - serviceClusterKey.clusterId(), - serviceInfo, - hostName, - serviceStatusProvider); - - if (!groupedServiceInstances.containsKey(serviceClusterKey)) { - groupedServiceInstances.put(serviceClusterKey, new HashSet<>()); - } - groupedServiceInstances.get(serviceClusterKey).add(serviceInstance); - } - } - - Set<ServiceCluster> serviceClusters = groupedServiceInstances.entrySet().stream() - .map(entry -> new ServiceCluster( - entry.getKey().clusterId(), - entry.getKey().serviceType(), - entry.getValue())) - .collect(Collectors.toSet()); - - ApplicationInstance applicationInstance = new ApplicationInstance( - new TenantId(applicationInfo.getApplicationId().tenant().toString()), - toApplicationInstanceId(applicationInfo, zone), - serviceClusters); - - // Fill back-references - for (ServiceCluster serviceCluster : applicationInstance.serviceClusters()) { - serviceCluster.setApplicationInstance(applicationInstance); - for (ServiceInstance serviceInstance : serviceCluster.serviceInstances()) { - serviceInstance.setServiceCluster(serviceCluster); - } - } - - return applicationInstance; - } - - static ClusterId getClusterId(ServiceInfo serviceInfo) { - return new ClusterId(serviceInfo.getProperty(CLUSTER_ID_PROPERTY_NAME).orElse("")); - } - - private ServiceClusterKey toServiceClusterKey(ServiceInfo serviceInfo) { - ClusterId clusterId = getClusterId(serviceInfo); - ServiceType serviceType = toServiceType(serviceInfo); - return new ServiceClusterKey(clusterId, serviceType); - } - - private ServiceInstance toServiceInstance( - ApplicationId applicationId, - ClusterId clusterId, - ServiceInfo serviceInfo, - HostName hostName, - ServiceStatusProvider serviceStatusProvider) { - ConfigId configId = new ConfigId(serviceInfo.getConfigId()); - - ServiceStatus status = serviceStatusProvider.getStatus( - applicationId, - clusterId, - toServiceType(serviceInfo), configId); - - return new ServiceInstance(configId, hostName, status); - } - - private ApplicationInstanceId toApplicationInstanceId(ApplicationInfo applicationInfo, Zone zone) { - return new ApplicationInstanceId(String.format("%s:%s:%s:%s", - applicationInfo.getApplicationId().application().value(), - zone.environment().value(), - zone.region().value(), - applicationInfo.getApplicationId().instance().value())); - } - - private ServiceType toServiceType(ServiceInfo serviceInfo) { - return new ServiceType(serviceInfo.getServiceType()); - } -} diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/HostsModel.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/HostsModel.java new file mode 100644 index 00000000000..225ffb0adbc --- /dev/null +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/HostsModel.java @@ -0,0 +1,75 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.service.monitor.application; + +import com.yahoo.config.FileReference; +import com.yahoo.config.model.api.FileDistribution; +import com.yahoo.config.model.api.HostInfo; +import com.yahoo.config.model.api.Model; +import com.yahoo.config.provision.AllocatedHosts; +import com.yahoo.vespa.config.ConfigKey; +import com.yahoo.vespa.config.ConfigPayload; +import com.yahoo.vespa.config.buildergen.ConfigDefinition; + +import java.time.Instant; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Set; + +/** + * Model that only supports the subset necessary to create an ApplicationInstance. + * + * @author hakon + */ +public class HostsModel implements Model { + private final Collection<HostInfo> hosts; + + public HostsModel(List<HostInfo> hosts) { + this.hosts = Collections.unmodifiableCollection(hosts); + } + + @Override + public Collection<HostInfo> getHosts() { + return hosts; + } + + @Override + public ConfigPayload getConfig(ConfigKey<?> configKey, ConfigDefinition configDefinition) { + throw new UnsupportedOperationException(); + } + + @Override + public Set<ConfigKey<?>> allConfigsProduced() { + throw new UnsupportedOperationException(); + } + + @Override + public Set<String> allConfigIds() { + throw new UnsupportedOperationException(); + } + + @Override + public void distributeFiles(FileDistribution fileDistribution) { + throw new UnsupportedOperationException(); + } + + @Override + public Set<FileReference> fileReferences() { + throw new UnsupportedOperationException(); + } + + @Override + public AllocatedHosts allocatedHosts() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean allowModelVersionMismatch(Instant now) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean skipOldConfigModels(Instant now) { + throw new UnsupportedOperationException(); + } +} diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ZoneApplication.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ZoneApplication.java index 6bbf0cb6d1d..c10015d3bfa 100644 --- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ZoneApplication.java +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/application/ZoneApplication.java @@ -21,8 +21,8 @@ public class ZoneApplication { .createHostedVespaApplicationId("routing"); public static boolean isNodeAdminService(ApplicationId applicationId, - ClusterId clusterId, - ServiceType serviceType) { + ClusterId clusterId, + ServiceType serviceType) { return Objects.equals(applicationId, ZONE_APPLICATION_ID) && Objects.equals(serviceType, ServiceType.CONTAINER) && Objects.equals(clusterId, ClusterId.NODE_ADMIN); diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/DuperModel.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/DuperModel.java new file mode 100644 index 00000000000..80e0bfd2710 --- /dev/null +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/DuperModel.java @@ -0,0 +1,42 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.service.monitor.internal; + +import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.config.model.api.ApplicationInfo; +import com.yahoo.config.model.api.SuperModel; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static com.yahoo.vespa.service.monitor.application.ConfigServerApplication.CONFIG_SERVER_APPLICATION; + +/** + * The {@code DuperModel} unites the {@link com.yahoo.config.model.api.SuperModel SuperModel} + * with the synthetically produced applications like the config server application. + * + * @author hakon + */ +public class DuperModel { + private final List<ApplicationInfo> staticApplicationInfos = new ArrayList<>(); + + public DuperModel(ConfigserverConfig configServerConfig) { + // Single-tenant applications have the config server as part of the application model. + // TODO: Add health monitoring for config server when part of application model. + if (configServerConfig.multitenant()) { + staticApplicationInfos.add(CONFIG_SERVER_APPLICATION.makeApplicationInfo(configServerConfig)); + } + } + + /** For testing. */ + DuperModel(ApplicationInfo... staticApplicationInfos) { + this.staticApplicationInfos.addAll(Arrays.asList(staticApplicationInfos)); + } + + public List<ApplicationInfo> getApplicationInfos(SuperModel superModelSnapshot) { + List<ApplicationInfo> allApplicationInfos = new ArrayList<>(); + allApplicationInfos.addAll(staticApplicationInfos); + allApplicationInfos.addAll(superModelSnapshot.getAllApplicationInfos()); + return allApplicationInfos; + } +} diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/DuperModelListener.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/DuperModelListener.java new file mode 100644 index 00000000000..235c7db5c36 --- /dev/null +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/DuperModelListener.java @@ -0,0 +1,28 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.service.monitor.internal; + +import com.yahoo.config.model.api.ApplicationInfo; +import com.yahoo.config.model.api.SuperModel; +import com.yahoo.config.provision.ApplicationId; + +/** + * Interface for listening for changes to the {@link DuperModel}. + * + * @author hakon + */ +public interface DuperModelListener { + /** + * An application has been activated: + * + * <ul> + * <li>A synthetic application like the config server application has been added/"activated" + * <li>A super model application has been activated (see + * {@link com.yahoo.config.model.api.SuperModelListener#applicationActivated(SuperModel, ApplicationInfo) + * SuperModelListener} + * </ul> + */ + void applicationActivated(ApplicationInfo application); + + /** Application has been removed. */ + void applicationRemoved(ApplicationId id); +} diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ModelGenerator.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ModelGenerator.java index 9da449289a7..ad2f223acf8 100644 --- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ModelGenerator.java +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ModelGenerator.java @@ -1,56 +1,40 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.service.monitor.internal; -import com.yahoo.config.model.api.SuperModel; +import com.yahoo.config.model.api.ApplicationInfo; import com.yahoo.config.provision.Zone; import com.yahoo.vespa.applicationmodel.ApplicationInstance; import com.yahoo.vespa.applicationmodel.ApplicationInstanceReference; import com.yahoo.vespa.service.monitor.ServiceModel; import com.yahoo.vespa.service.monitor.ServiceStatusProvider; import com.yahoo.vespa.service.monitor.application.ApplicationInstanceGenerator; -import com.yahoo.vespa.service.monitor.application.ConfigServerAppGenerator; -import com.yahoo.vespa.service.monitor.application.DeployedAppGenerator; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.function.Function; import java.util.stream.Collectors; /** - * Util to convert SuperModel to ServiceModel and application model classes + * Util to make ServiceModel and its related application model classes */ public class ModelGenerator { public static final String CLUSTER_ID_PROPERTY_NAME = "clustername"; - private final List<ApplicationInstanceGenerator> staticGenerators; - - public ModelGenerator(List<String> configServerHosts) { - if (configServerHosts.isEmpty()) { - staticGenerators = Collections.emptyList(); - } else { - staticGenerators = Collections.singletonList(new ConfigServerAppGenerator(configServerHosts)); - } - } - /** * Create service model based primarily on super model. * * If the configServerhosts is non-empty, a config server application is added. */ - ServiceModel toServiceModel( - SuperModel superModel, - Zone zone, - ServiceStatusProvider serviceStatusProvider) { - List<ApplicationInstanceGenerator> generators = new ArrayList<>(staticGenerators); - superModel.getAllApplicationInfos() - .forEach(info -> generators.add(new DeployedAppGenerator(info, zone))); - - Map<ApplicationInstanceReference, ApplicationInstance> applicationInstances = generators.stream() - .map(generator -> generator.makeApplicationInstance(serviceStatusProvider)) - .collect(Collectors.toMap(ApplicationInstance::reference, Function.identity())); + public ServiceModel toServiceModel(List<ApplicationInfo> allApplicationInfos, + Zone zone, + ServiceStatusProvider serviceStatusProvider) { + Map<ApplicationInstanceReference, ApplicationInstance> applicationInstances = + allApplicationInfos.stream() + .map(info -> new ApplicationInstanceGenerator(info, zone) + .makeApplicationInstance(serviceStatusProvider)) + .collect(Collectors.toMap(ApplicationInstance::reference, Function.identity())); return new ServiceModel(applicationInstances); } + } diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/MonitorManager.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/MonitorManager.java index 49863672c43..1edf3a18215 100644 --- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/MonitorManager.java +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/MonitorManager.java @@ -1,11 +1,10 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.service.monitor.internal;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.service.monitor.internal; -import com.yahoo.config.model.api.SuperModelListener; import com.yahoo.vespa.service.monitor.ServiceStatusProvider; /** * @author hakon */ -public interface MonitorManager extends SuperModelListener, ServiceStatusProvider { +public interface MonitorManager extends DuperModelListener, ServiceStatusProvider { } diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ServiceId.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ServiceId.java new file mode 100644 index 00000000000..993ea7fed5c --- /dev/null +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ServiceId.java @@ -0,0 +1,75 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.service.monitor.internal; + +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.vespa.applicationmodel.ClusterId; +import com.yahoo.vespa.applicationmodel.ConfigId; +import com.yahoo.vespa.applicationmodel.ServiceType; + +import javax.annotation.concurrent.Immutable; +import java.util.Objects; + +/** + * Identifies a service. + * + * @author hakon + */ +@Immutable +public class ServiceId { + private final ApplicationId applicationId; + private final ClusterId clusterId; + private final ServiceType serviceType; + private final ConfigId configId; + + public ServiceId(ApplicationId applicationId, + ClusterId clusterId, + ServiceType serviceType, + ConfigId configId) { + this.applicationId = applicationId; + this.clusterId = clusterId; + this.serviceType = serviceType; + this.configId = configId; + } + + public ApplicationId getApplicationId() { + return applicationId; + } + + public ClusterId getClusterId() { + return clusterId; + } + + public ServiceType getServiceType() { + return serviceType; + } + + public ConfigId getConfigId() { + return configId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ServiceId serviceId = (ServiceId) o; + return Objects.equals(applicationId, serviceId.applicationId) && + Objects.equals(clusterId, serviceId.clusterId) && + Objects.equals(serviceType, serviceId.serviceType) && + Objects.equals(configId, serviceId.configId); + } + + @Override + public int hashCode() { + return Objects.hash(applicationId, clusterId, serviceType, configId); + } + + @Override + public String toString() { + return "ServiceId{" + + "applicationId=" + applicationId + + ", clusterId=" + clusterId + + ", serviceType=" + serviceType + + ", configId=" + configId + + '}'; + } +} diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ServiceMonitorImpl.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ServiceMonitorImpl.java index 97c4fdda0f3..bd8fd4a50e0 100644 --- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ServiceMonitorImpl.java +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/ServiceMonitorImpl.java @@ -14,10 +14,7 @@ import com.yahoo.vespa.service.monitor.ServiceMonitor; import com.yahoo.vespa.service.monitor.internal.health.HealthMonitorManager; import com.yahoo.vespa.service.monitor.internal.slobrok.SlobrokMonitorManagerImpl; -import java.util.Collections; -import java.util.List; import java.util.Map; -import java.util.stream.Collectors; public class ServiceMonitorImpl implements ServiceMonitor { private final ServiceModelCache serviceModelCache; @@ -32,30 +29,20 @@ public class ServiceMonitorImpl implements ServiceMonitor { Zone zone = superModelProvider.getZone(); ServiceMonitorMetrics metrics = new ServiceMonitorMetrics(metric, timer); - UnionMonitorManager monitorManager = new UnionMonitorManager( - slobrokMonitorManager, - healthMonitorManager, - configserverConfig); + DuperModel duperModel = new DuperModel(configserverConfig); + UnionMonitorManager monitorManager = + new UnionMonitorManager(slobrokMonitorManager, healthMonitorManager); SuperModelListenerImpl superModelListener = new SuperModelListenerImpl( monitorManager, metrics, - new ModelGenerator(toConfigServerList(configserverConfig)), + duperModel, + new ModelGenerator(), zone); superModelListener.start(superModelProvider); serviceModelCache = new ServiceModelCache(superModelListener, timer); } - private List<String> toConfigServerList(ConfigserverConfig configserverConfig) { - if (configserverConfig.multitenant()) { - return configserverConfig.zookeeperserver().stream() - .map(ConfigserverConfig.Zookeeperserver::hostname) - .collect(Collectors.toList()); - } - - return Collections.emptyList(); - } - @Override public Map<ApplicationInstanceReference, ApplicationInstance> getAllApplicationInstances() { return serviceModelCache.get().getAllApplicationInstances(); diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImpl.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImpl.java index b2f3617131b..f509809c33d 100644 --- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImpl.java +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImpl.java @@ -8,7 +8,9 @@ import com.yahoo.config.model.api.SuperModelProvider; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.Zone; import com.yahoo.vespa.service.monitor.ServiceModel; +import com.yahoo.vespa.service.monitor.ServiceStatusProvider; +import java.util.List; import java.util.function.Supplier; import java.util.logging.Logger; @@ -16,6 +18,7 @@ public class SuperModelListenerImpl implements SuperModelListener, Supplier<Serv private static final Logger logger = Logger.getLogger(SuperModelListenerImpl.class.getName()); private final ServiceMonitorMetrics metrics; + private final DuperModel duperModel; private final ModelGenerator modelGenerator; private final Zone zone; @@ -27,10 +30,12 @@ public class SuperModelListenerImpl implements SuperModelListener, Supplier<Serv SuperModelListenerImpl(MonitorManager monitorManager, ServiceMonitorMetrics metrics, + DuperModel duperModel, ModelGenerator modelGenerator, Zone zone) { this.monitorManager = monitorManager; this.metrics = metrics; + this.duperModel = duperModel; this.modelGenerator = modelGenerator; this.zone = zone; } @@ -41,8 +46,7 @@ public class SuperModelListenerImpl implements SuperModelListener, Supplier<Serv // since applicationActivated()/applicationRemoved() may be called // asynchronously even before snapshot() returns. this.superModel = superModelProvider.snapshot(this); - superModel.getAllApplicationInfos().stream().forEach(application -> - monitorManager.applicationActivated(superModel, application)); + duperModel.getApplicationInfos(superModel).forEach(monitorManager::applicationActivated); } } @@ -50,7 +54,7 @@ public class SuperModelListenerImpl implements SuperModelListener, Supplier<Serv public void applicationActivated(SuperModel superModel, ApplicationInfo application) { synchronized (monitor) { this.superModel = superModel; - monitorManager.applicationActivated(superModel, application); + monitorManager.applicationActivated(application); } } @@ -58,7 +62,7 @@ public class SuperModelListenerImpl implements SuperModelListener, Supplier<Serv public void applicationRemoved(SuperModel superModel, ApplicationId id) { synchronized (monitor) { this.superModel = superModel; - monitorManager.applicationRemoved(superModel, id); + monitorManager.applicationRemoved(id); } } @@ -71,7 +75,9 @@ public class SuperModelListenerImpl implements SuperModelListener, Supplier<Serv dummy(measurement); // WARNING: The slobrok monitor manager may be out-of-sync with super model (no locking) - return modelGenerator.toServiceModel(superModel, zone, monitorManager); + List<ApplicationInfo> applicationInfos = duperModel.getApplicationInfos(superModel); + + return modelGenerator.toServiceModel(applicationInfos, zone, (ServiceStatusProvider) monitorManager); } } diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManager.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManager.java index 82d2043bd17..81cf6f2af5e 100644 --- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManager.java +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManager.java @@ -1,16 +1,12 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.service.monitor.internal; -import com.yahoo.cloud.config.ConfigserverConfig; import com.yahoo.config.model.api.ApplicationInfo; -import com.yahoo.config.model.api.SuperModel; import com.yahoo.config.provision.ApplicationId; import com.yahoo.vespa.applicationmodel.ClusterId; import com.yahoo.vespa.applicationmodel.ConfigId; import com.yahoo.vespa.applicationmodel.ServiceStatus; import com.yahoo.vespa.applicationmodel.ServiceType; -import com.yahoo.vespa.service.monitor.application.ConfigServerApplication; -import com.yahoo.vespa.service.monitor.application.ZoneApplication; import com.yahoo.vespa.service.monitor.internal.health.HealthMonitorManager; import com.yahoo.vespa.service.monitor.internal.slobrok.SlobrokMonitorManagerImpl; @@ -20,14 +16,11 @@ import com.yahoo.vespa.service.monitor.internal.slobrok.SlobrokMonitorManagerImp public class UnionMonitorManager implements MonitorManager { private final SlobrokMonitorManagerImpl slobrokMonitorManager; private final HealthMonitorManager healthMonitorManager; - private final ConfigserverConfig configserverConfig; UnionMonitorManager(SlobrokMonitorManagerImpl slobrokMonitorManager, - HealthMonitorManager healthMonitorManager, - ConfigserverConfig configserverConfig) { + HealthMonitorManager healthMonitorManager) { this.slobrokMonitorManager = slobrokMonitorManager; this.healthMonitorManager = healthMonitorManager; - this.configserverConfig = configserverConfig; } @Override @@ -35,33 +28,25 @@ public class UnionMonitorManager implements MonitorManager { ClusterId clusterId, ServiceType serviceType, ConfigId configId) { - - if (applicationId.equals(ConfigServerApplication.CONFIG_SERVER_APPLICATION.getApplicationId())) { - // todo: use health - return ServiceStatus.NOT_CHECKED; + // Trust the new health monitoring status if it actually monitors the particular service. + ServiceStatus status = healthMonitorManager.getStatus(applicationId, clusterId, serviceType, configId); + if (status != ServiceStatus.NOT_CHECKED) { + return status; } - MonitorManager monitorManager = useHealth(applicationId, clusterId, serviceType) ? - healthMonitorManager : - slobrokMonitorManager; - - return monitorManager.getStatus(applicationId, clusterId, serviceType, configId); + // fallback is the older slobrok + return slobrokMonitorManager.getStatus(applicationId, clusterId, serviceType, configId); } @Override - public void applicationActivated(SuperModel superModel, ApplicationInfo application) { - slobrokMonitorManager.applicationActivated(superModel, application); - healthMonitorManager.applicationActivated(superModel, application); + public void applicationActivated(ApplicationInfo application) { + slobrokMonitorManager.applicationActivated(application); + healthMonitorManager.applicationActivated(application); } @Override - public void applicationRemoved(SuperModel superModel, ApplicationId id) { - slobrokMonitorManager.applicationRemoved(superModel, id); - healthMonitorManager.applicationRemoved(superModel, id); - } - - private boolean useHealth(ApplicationId applicationId, ClusterId clusterId, ServiceType serviceType) { - return !configserverConfig.nodeAdminInContainer() && - ZoneApplication.isNodeAdminService(applicationId, clusterId, serviceType); + public void applicationRemoved(ApplicationId id) { + slobrokMonitorManager.applicationRemoved(id); + healthMonitorManager.applicationRemoved(id); } } diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/ApplicationHealthMonitor.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/ApplicationHealthMonitor.java new file mode 100644 index 00000000000..bd2658db8aa --- /dev/null +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/ApplicationHealthMonitor.java @@ -0,0 +1,102 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.service.monitor.internal.health; + +import com.yahoo.config.model.api.ApplicationInfo; +import com.yahoo.config.model.api.HostInfo; +import com.yahoo.config.model.api.PortInfo; +import com.yahoo.config.model.api.ServiceInfo; +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.HostName; +import com.yahoo.vespa.applicationmodel.ClusterId; +import com.yahoo.vespa.applicationmodel.ConfigId; +import com.yahoo.vespa.applicationmodel.ServiceStatus; +import com.yahoo.vespa.applicationmodel.ServiceType; +import com.yahoo.vespa.service.monitor.ServiceStatusProvider; +import com.yahoo.vespa.service.monitor.application.ApplicationInstanceGenerator; +import com.yahoo.vespa.service.monitor.internal.ServiceId; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Responsible for monitoring a whole application using /state/v1/health. + * + * @author hakon + */ +public class ApplicationHealthMonitor implements ServiceStatusProvider, AutoCloseable { + public static final String PORT_TAG_STATE = "STATE"; + public static final String PORT_TAG_HTTP = "HTTP"; + /** Port tags implying /state/v1/health is served */ + public static final List<String> PORT_TAGS_HEALTH = + Collections.unmodifiableList(Arrays.asList(PORT_TAG_HTTP, PORT_TAG_STATE)); + + private final Map<ServiceId, HealthMonitor> healthMonitors; + + public static ApplicationHealthMonitor startMonitoring(ApplicationInfo application) { + return new ApplicationHealthMonitor(makeHealthMonitors(application)); + } + + private ApplicationHealthMonitor(Map<ServiceId, HealthMonitor> healthMonitors) { + this.healthMonitors = healthMonitors; + } + + @Override + public ServiceStatus getStatus(ApplicationId applicationId, + ClusterId clusterId, + ServiceType serviceType, + ConfigId configId) { + ServiceId serviceId = new ServiceId(applicationId, clusterId, serviceType, configId); + HealthMonitor monitor = healthMonitors.get(serviceId); + if (monitor == null) { + return ServiceStatus.NOT_CHECKED; + } + + return monitor.getStatus(); + } + + @Override + public void close() { + healthMonitors.values().forEach(HealthMonitor::close); + healthMonitors.clear(); + } + + private static Map<ServiceId, HealthMonitor> makeHealthMonitors(ApplicationInfo application) { + Map<ServiceId, HealthMonitor> healthMonitors = new HashMap<>(); + for (HostInfo hostInfo : application.getModel().getHosts()) { + for (ServiceInfo serviceInfo : hostInfo.getServices()) { + for (PortInfo portInfo : serviceInfo.getPorts()) { + maybeCreateHealthMonitor( + application, + hostInfo, + serviceInfo, + portInfo) + .ifPresent(healthMonitor -> healthMonitors.put( + ApplicationInstanceGenerator.getServiceId(application, serviceInfo), + healthMonitor)); + } + } + } + return healthMonitors; + } + + private static Optional<HealthMonitor> maybeCreateHealthMonitor( + ApplicationInfo applicationInfo, + HostInfo hostInfo, + ServiceInfo serviceInfo, + PortInfo portInfo) { + if (portInfo.getTags().containsAll(PORT_TAGS_HEALTH)) { + HostName hostname = HostName.from(hostInfo.getHostname()); + HealthEndpoint endpoint = HealthEndpoint.forHttp(hostname, portInfo.getPort()); + // todo: make HealthMonitor + // HealthMonitor healthMonitor = new HealthMonitor(endpoint); + // healthMonitor.startMonitoring(); + return Optional.empty(); + } + + return Optional.empty(); + } +} diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthClient.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthClient.java new file mode 100644 index 00000000000..43a02a385be --- /dev/null +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthClient.java @@ -0,0 +1,139 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.service.monitor.internal.health; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.yahoo.vespa.athenz.api.AthenzService; +import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider; +import org.apache.http.HttpEntity; +import org.apache.http.HttpResponse; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.config.Registry; +import org.apache.http.config.RegistryBuilder; +import org.apache.http.conn.ConnectionKeepAliveStrategy; +import org.apache.http.conn.HttpClientConnectionManager; +import org.apache.http.conn.socket.ConnectionSocketFactory; +import org.apache.http.conn.ssl.SSLConnectionSocketFactory; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.DefaultConnectionKeepAliveStrategy; +import org.apache.http.impl.client.HttpClients; +import org.apache.http.impl.conn.BasicHttpClientConnectionManager; +import org.apache.http.protocol.HttpContext; +import org.apache.http.util.EntityUtils; + +import javax.net.ssl.SSLContext; + +/** + * @author hakon + */ +public class HealthClient implements AutoCloseable, ServiceIdentityProvider.Listener { + private static final ObjectMapper mapper = new ObjectMapper(); + private static final long MAX_CONTENT_LENGTH = 1L << 20; // 1 MB + private static final int DEFAULT_TIMEOUT_MILLIS = 1_000; + + private static final ConnectionKeepAliveStrategy KEEP_ALIVE_STRATEGY = + new DefaultConnectionKeepAliveStrategy() { + @Override + public long getKeepAliveDuration(HttpResponse response, HttpContext context) { + long keepAlive = super.getKeepAliveDuration(response, context); + if (keepAlive == -1) { + // Keep connections alive 60 seconds if a keep-alive value + // has not be explicitly set by the server + keepAlive = 60000; + } + return keepAlive; + } + }; + + private final HealthEndpoint endpoint; + + private volatile CloseableHttpClient httpClient; + + public HealthClient(HealthEndpoint endpoint) { + this.endpoint = endpoint; + } + + public void start() { + endpoint.getServiceIdentityProvider().ifPresent(provider -> { + onCredentialsUpdate(provider.getIdentitySslContext(), null); + provider.addIdentityListener(this); + }); + } + + @Override + public void onCredentialsUpdate(SSLContext sslContext, AthenzService ignored) { + SSLConnectionSocketFactory socketFactory = + new SSLConnectionSocketFactory(sslContext, endpoint.getHostnameVerifier().orElse(null)); + + Registry<ConnectionSocketFactory> registry = RegistryBuilder.<ConnectionSocketFactory>create() + .register("https", socketFactory) + .build(); + + HttpClientConnectionManager connectionManager = new BasicHttpClientConnectionManager(registry); + + RequestConfig requestConfig = RequestConfig.custom() + .setConnectTimeout(DEFAULT_TIMEOUT_MILLIS) // establishment of connection + .setConnectionRequestTimeout(DEFAULT_TIMEOUT_MILLIS) // connection from connection manager + .setSocketTimeout(DEFAULT_TIMEOUT_MILLIS) // waiting for data + .build(); + + this.httpClient = HttpClients.custom() + .setKeepAliveStrategy(KEEP_ALIVE_STRATEGY) + .setConnectionManager(connectionManager) + .disableAutomaticRetries() + .setDefaultRequestConfig(requestConfig) + .build(); + } + + public HealthInfo getHealthInfo() { + try { + return probeHealth(); + } catch (Exception e) { + return HealthInfo.fromException(e); + } + } + + @Override + public void close() { + endpoint.getServiceIdentityProvider().ifPresent(provider -> provider.removeIdentityListener(this)); + + try { + httpClient.close(); + } catch (Exception e) { + // ignore + } + httpClient = null; + } + + private HealthInfo probeHealth() throws Exception { + HttpGet httpget = new HttpGet(endpoint.getStateV1HealthUrl().toString()); + CloseableHttpResponse httpResponse; + + CloseableHttpClient httpClient = this.httpClient; + if (httpClient == null) { + throw new IllegalStateException("HTTP client has closed"); + } + + httpResponse = httpClient.execute(httpget); + + int httpStatusCode = httpResponse.getStatusLine().getStatusCode(); + if (httpStatusCode < 200 || httpStatusCode >= 300) { + return HealthInfo.fromBadHttpStatusCode(httpStatusCode); + } + + HttpEntity bodyEntity = httpResponse.getEntity(); + long contentLength = bodyEntity.getContentLength(); + if (contentLength > MAX_CONTENT_LENGTH) { + throw new IllegalArgumentException("Content too long: " + contentLength + " bytes"); + } + String body = EntityUtils.toString(bodyEntity); + HealthResponse healthResponse = mapper.readValue(body, HealthResponse.class); + + if (healthResponse.status == null || healthResponse.status.code == null) { + return HealthInfo.fromHealthStatusCode(HealthResponse.Status.DEFAULT_STATUS); + } else { + return HealthInfo.fromHealthStatusCode(healthResponse.status.code); + } + } +} diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthEndpoint.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthEndpoint.java new file mode 100644 index 00000000000..e9d17a9ab70 --- /dev/null +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthEndpoint.java @@ -0,0 +1,57 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.service.monitor.internal.health; + +import com.yahoo.config.provision.HostName; +import com.yahoo.vespa.athenz.api.AthenzIdentity; +import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider; +import com.yahoo.vespa.athenz.tls.AthenzIdentityVerifier; + +import javax.net.ssl.HostnameVerifier; +import java.net.URL; +import java.util.Collections; +import java.util.Optional; + +import static com.yahoo.yolean.Exceptions.uncheck; + +/** + * @author hakon + */ +class HealthEndpoint { + private final URL url; + private final Optional<HostnameVerifier> hostnameVerifier; + private final Optional<ServiceIdentityProvider> serviceIdentityProvider; + + static HealthEndpoint forHttp(HostName hostname, int port) { + URL url = uncheck(() -> new URL("http", hostname.value(), port, "/state/v1/health")); + return new HealthEndpoint(url, Optional.empty(), Optional.empty()); + } + + static HealthEndpoint forHttps(HostName hostname, + int port, + ServiceIdentityProvider serviceIdentityProvider, + AthenzIdentity remoteIdentity) { + URL url = uncheck(() -> new URL("https", hostname.value(), port, "/state/v1/health")); + HostnameVerifier peerVerifier = new AthenzIdentityVerifier(Collections.singleton(remoteIdentity)); + return new HealthEndpoint(url, Optional.of(serviceIdentityProvider), Optional.of(peerVerifier)); + } + + private HealthEndpoint(URL url, + Optional<ServiceIdentityProvider> serviceIdentityProvider, + Optional<HostnameVerifier> hostnameVerifier) { + this.url = url; + this.serviceIdentityProvider = serviceIdentityProvider; + this.hostnameVerifier = hostnameVerifier; + } + + public URL getStateV1HealthUrl() { + return url; + } + + public Optional<ServiceIdentityProvider> getServiceIdentityProvider() { + return serviceIdentityProvider; + } + + public Optional<HostnameVerifier> getHostnameVerifier() { + return hostnameVerifier; + } +} diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthInfo.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthInfo.java new file mode 100644 index 00000000000..a3fe3cb3106 --- /dev/null +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthInfo.java @@ -0,0 +1,75 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.service.monitor.internal.health; + +import com.yahoo.vespa.applicationmodel.ServiceStatus; +import com.yahoo.yolean.Exceptions; + +import java.time.Instant; +import java.util.Optional; +import java.util.OptionalInt; + +/** + * The result of a health lookup. + * + * @author hakon + */ +public class HealthInfo { + public static final String UP_STATUS_CODE = "up"; + + private final Optional<Exception> exception; + private final OptionalInt httpStatusCode; + private final Optional<String> healthStatusCode; + private final Instant time; + + static HealthInfo fromException(Exception exception) { + return new HealthInfo(Optional.of(exception), OptionalInt.empty(), Optional.empty()); + } + + static HealthInfo fromBadHttpStatusCode(int httpStatusCode) { + return new HealthInfo(Optional.empty(), OptionalInt.of(httpStatusCode), Optional.empty()); + } + + static HealthInfo fromHealthStatusCode(String healthStatusCode) { + return new HealthInfo(Optional.empty(), OptionalInt.empty(), Optional.of(healthStatusCode)); + } + + static HealthInfo empty() { + return new HealthInfo(Optional.empty(), OptionalInt.empty(), Optional.empty()); + } + + private HealthInfo(Optional<Exception> exception, + OptionalInt httpStatusCode, + Optional<String> healthStatusCode) { + this.exception = exception; + this.httpStatusCode = httpStatusCode; + this.healthStatusCode = healthStatusCode; + this.time = Instant.now(); + } + + public boolean isHealthy() { + return healthStatusCode.map(UP_STATUS_CODE::equals).orElse(false); + } + + public ServiceStatus toSerivceStatus() { + return isHealthy() ? ServiceStatus.UP : ServiceStatus.DOWN; + } + + public Instant time() { + return time; + } + + @Override + public String toString() { + if (isHealthy()) { + return UP_STATUS_CODE; + } else if (healthStatusCode.isPresent()) { + return "Bad health status code '" + healthStatusCode.get() + "'"; + } else if (exception.isPresent()) { + return Exceptions.toMessageString(exception.get()); + } else if (httpStatusCode.isPresent()) { + return "Bad HTTP response status code " + httpStatusCode.getAsInt(); + } else { + return "No health info available"; + } + } +} diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitor.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitor.java new file mode 100644 index 00000000000..fd809b32918 --- /dev/null +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitor.java @@ -0,0 +1,73 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.service.monitor.internal.health; + +import com.yahoo.log.LogLevel; +import com.yahoo.vespa.applicationmodel.ServiceStatus; + +import java.time.Duration; +import java.util.Random; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; + +/** + * Used to monitor the health of a single URL endpoint. + * + * @author hakon + */ +public class HealthMonitor implements AutoCloseable { + private static final Logger logger = Logger.getLogger(HealthMonitor.class.getName()); + private static final Duration DELAY = Duration.ofSeconds(20); + // About 'static': Javadoc says "Instances of java.util.Random are threadsafe." + private static final Random random = new Random(); + + private final ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(1); + private final HealthClient healthClient; + + private volatile HealthInfo lastHealthInfo = HealthInfo.empty(); + + public HealthMonitor(HealthEndpoint stateV1HealthEndpoint) { + this.healthClient = new HealthClient(stateV1HealthEndpoint); + } + + /** For testing. */ + HealthMonitor(HealthClient healthClient) { + this.healthClient = healthClient; + } + + public void startMonitoring() { + healthClient.start(); + executor.scheduleWithFixedDelay( + this::updateSynchronously, + initialDelayInSeconds(DELAY.getSeconds()), + DELAY.getSeconds(), + TimeUnit.SECONDS); + } + + public ServiceStatus getStatus() { + // todo: return lastHealthInfo.toServiceStatus(); + return ServiceStatus.NOT_CHECKED; + } + + @Override + public void close() { + executor.shutdown(); + + try { + executor.awaitTermination(2, TimeUnit.SECONDS); + } catch (InterruptedException e) { + logger.log(LogLevel.INFO, "Interrupted while waiting for health monitor termination: " + + e.getMessage()); + } + + healthClient.close(); + } + + private long initialDelayInSeconds(long maxInitialDelayInSeconds) { + return random.nextLong() % maxInitialDelayInSeconds; + } + + private void updateSynchronously() { + lastHealthInfo = healthClient.getHealthInfo(); + } +} diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorManager.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorManager.java index 5a4b7251ae2..473ef5e3a94 100644 --- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorManager.java +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorManager.java @@ -2,8 +2,8 @@ package com.yahoo.vespa.service.monitor.internal.health; import com.google.inject.Inject; +import com.yahoo.cloud.config.ConfigserverConfig; import com.yahoo.config.model.api.ApplicationInfo; -import com.yahoo.config.model.api.SuperModel; import com.yahoo.config.provision.ApplicationId; import com.yahoo.vespa.applicationmodel.ClusterId; import com.yahoo.vespa.applicationmodel.ConfigId; @@ -12,19 +12,38 @@ import com.yahoo.vespa.applicationmodel.ServiceType; import com.yahoo.vespa.service.monitor.application.ZoneApplication; import com.yahoo.vespa.service.monitor.internal.MonitorManager; +import java.util.HashMap; +import java.util.Map; + /** * @author hakon */ public class HealthMonitorManager implements MonitorManager { + private final Map<ApplicationId, ApplicationHealthMonitor> healthMonitors = new HashMap<>(); + private final ConfigserverConfig configserverConfig; + @Inject - public HealthMonitorManager() {} + public HealthMonitorManager(ConfigserverConfig configserverConfig) { + this.configserverConfig = configserverConfig; + } @Override - public void applicationActivated(SuperModel superModel, ApplicationInfo application) { + public void applicationActivated(ApplicationInfo application) { + if (applicationMonitored(application.getApplicationId())) { + ApplicationHealthMonitor monitor = + ApplicationHealthMonitor.startMonitoring(application); + healthMonitors.put(application.getApplicationId(), monitor); + } } @Override - public void applicationRemoved(SuperModel superModel, ApplicationId id) { + public void applicationRemoved(ApplicationId id) { + if (applicationMonitored(id)) { + ApplicationHealthMonitor monitor = healthMonitors.remove(id); + if (monitor != null) { + monitor.close(); + } + } } @Override @@ -32,13 +51,18 @@ public class HealthMonitorManager implements MonitorManager { ClusterId clusterId, ServiceType serviceType, ConfigId configId) { - // TODO: Do proper health check - if (ZoneApplication.isNodeAdminService(applicationId, clusterId, serviceType)) { + if (!configserverConfig.nodeAdminInContainer() && + ZoneApplication.isNodeAdminService(applicationId, clusterId, serviceType)) { + // If node admin doesn't run in a JDisc container, it must be monitored with health. + // TODO: Do proper health check return ServiceStatus.UP; } - throw new IllegalArgumentException("Health monitoring not implemented for application " + - applicationId.toShortString() + ", cluster " + clusterId.s() + ", serviceType " + - serviceType); + return ServiceStatus.NOT_CHECKED; + } + + private boolean applicationMonitored(ApplicationId id) { + // todo: health-check config server + return false; } } diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthResponse.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthResponse.java new file mode 100644 index 00000000000..574523ad564 --- /dev/null +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/health/HealthResponse.java @@ -0,0 +1,35 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.service.monitor.internal.health; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.yahoo.text.JSON; + +/** + * Response entity from /state/v1/health + * + * @author hakon + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public class HealthResponse { + @JsonProperty("status") + public Status status = new Status(); + + @JsonIgnoreProperties(ignoreUnknown = true) + public static class Status { + public static final String DEFAULT_STATUS = "down"; + + @JsonProperty("code") + public String code = DEFAULT_STATUS; + + @Override + public String toString() { + return "{ \"code\": \"" + JSON.escape(code) + "\" }"; + } + } + + @Override + public String toString() { + return "{ \"status\": " + status.toString() + " }"; + } +} diff --git a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImpl.java b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImpl.java index aaaab22e742..68958c94dfd 100644 --- a/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImpl.java +++ b/service-monitor/src/main/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImpl.java @@ -3,8 +3,6 @@ package com.yahoo.vespa.service.monitor.internal.slobrok; import com.google.inject.Inject; import com.yahoo.config.model.api.ApplicationInfo; -import com.yahoo.config.model.api.SuperModel; -import com.yahoo.config.model.api.SuperModelListener; import com.yahoo.config.provision.ApplicationId; import com.yahoo.jrt.slobrok.api.Mirror; import com.yahoo.log.LogLevel; @@ -13,6 +11,7 @@ import com.yahoo.vespa.applicationmodel.ConfigId; import com.yahoo.vespa.applicationmodel.ServiceStatus; import com.yahoo.vespa.applicationmodel.ServiceType; import com.yahoo.vespa.service.monitor.SlobrokApi; +import com.yahoo.vespa.service.monitor.application.ConfigServerApplication; import com.yahoo.vespa.service.monitor.internal.MonitorManager; import java.util.HashMap; @@ -21,7 +20,7 @@ import java.util.Optional; import java.util.function.Supplier; import java.util.logging.Logger; -public class SlobrokMonitorManagerImpl implements SuperModelListener, SlobrokApi, MonitorManager { +public class SlobrokMonitorManagerImpl implements SlobrokApi, MonitorManager { private static final Logger logger = Logger.getLogger(SlobrokMonitorManagerImpl.class.getName()); @@ -40,7 +39,11 @@ public class SlobrokMonitorManagerImpl implements SuperModelListener, SlobrokApi } @Override - public void applicationActivated(SuperModel superModel, ApplicationInfo application) { + public void applicationActivated(ApplicationInfo application) { + if (!applicationMonitoredWithSlobrok(application.getApplicationId())) { + return; + } + synchronized (monitor) { SlobrokMonitor slobrokMonitor = slobrokMonitors.computeIfAbsent( application.getApplicationId(), @@ -50,7 +53,11 @@ public class SlobrokMonitorManagerImpl implements SuperModelListener, SlobrokApi } @Override - public void applicationRemoved(SuperModel superModel, ApplicationId id) { + public void applicationRemoved(ApplicationId id) { + if (!applicationMonitoredWithSlobrok(id)) { + return; + } + synchronized (monitor) { SlobrokMonitor slobrokMonitor = slobrokMonitors.remove(id); if (slobrokMonitor == null) { @@ -79,6 +86,10 @@ public class SlobrokMonitorManagerImpl implements SuperModelListener, SlobrokApi ClusterId clusterId, ServiceType serviceType, ConfigId configId) { + if (!applicationMonitoredWithSlobrok(applicationId)) { + return ServiceStatus.NOT_CHECKED; + } + Optional<String> slobrokServiceName = findSlobrokServiceName(serviceType, configId); if (slobrokServiceName.isPresent()) { synchronized (monitor) { @@ -95,6 +106,14 @@ public class SlobrokMonitorManagerImpl implements SuperModelListener, SlobrokApi } } + private boolean applicationMonitoredWithSlobrok(ApplicationId applicationId) { + if (applicationId.equals(ConfigServerApplication.CONFIG_SERVER_APPLICATION.getApplicationId())) { + return false; + } + + return true; + } + /** * Get the Slobrok service name of the service, or empty if the service * is not registered with Slobrok. diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/application/ConfigServerAppGeneratorTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/application/ApplicationInstanceGeneratorTest.java index 58f99786017..899cc59bb34 100644 --- a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/application/ConfigServerAppGeneratorTest.java +++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/application/ApplicationInstanceGeneratorTest.java @@ -1,22 +1,27 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.service.monitor.application; +import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.config.model.api.ApplicationInfo; +import com.yahoo.config.provision.Zone; import com.yahoo.vespa.applicationmodel.ApplicationInstance; import com.yahoo.vespa.applicationmodel.ServiceStatus; import com.yahoo.vespa.service.monitor.ServiceStatusProvider; +import com.yahoo.vespa.service.monitor.internal.ConfigserverUtil; import org.junit.Test; import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; +import static com.yahoo.vespa.service.monitor.application.ConfigServerApplication.CONFIG_SERVER_APPLICATION; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class ConfigServerAppGeneratorTest { +public class ApplicationInstanceGeneratorTest { private static final String configServer1 = "cfg1.yahoo.com"; private static final String configServer2 = "cfg2.yahoo.com"; private static final String configServer3 = "cfg3.yahoo.com"; @@ -28,9 +33,17 @@ public class ConfigServerAppGeneratorTest { private final ServiceStatusProvider statusProvider = mock(ServiceStatusProvider.class); @Test - public void toApplicationInstance() throws Exception { + public void toApplicationInstance() { when(statusProvider.getStatus(any(), any(), any(), any())).thenReturn(ServiceStatus.NOT_CHECKED); - ApplicationInstance applicationInstance = new ConfigServerAppGenerator(configServerList) + ConfigserverConfig config = ConfigserverUtil.create( + true, + true, + configServer1, + configServer2, + configServer3); + Zone zone = mock(Zone.class); + ApplicationInfo configServer = CONFIG_SERVER_APPLICATION.makeApplicationInfo(config); + ApplicationInstance applicationInstance = new ApplicationInstanceGenerator(configServer, zone) .makeApplicationInstance(statusProvider); assertEquals( diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/ConfigserverUtil.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/ConfigserverUtil.java new file mode 100644 index 00000000000..85df02949a6 --- /dev/null +++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/ConfigserverUtil.java @@ -0,0 +1,52 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.service.monitor.internal; + +import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.config.model.api.ApplicationInfo; +import com.yahoo.vespa.service.monitor.application.ConfigServerApplication; + +/** + * @author hakon + */ +public class ConfigserverUtil { + /** Create a ConfigserverConfig with the given settings. */ + public static ConfigserverConfig create( + boolean nodeAdminInContainer, + boolean multitenant, + String configServerHostname1, + String configServerHostname2, + String configServerHostname3) { + return new ConfigserverConfig( + new ConfigserverConfig.Builder() + .nodeAdminInContainer(nodeAdminInContainer) + .multitenant(multitenant) + .zookeeperserver(new ConfigserverConfig.Zookeeperserver.Builder().hostname(configServerHostname1).port(1)) + .zookeeperserver(new ConfigserverConfig.Zookeeperserver.Builder().hostname(configServerHostname2).port(2)) + .zookeeperserver(new ConfigserverConfig.Zookeeperserver.Builder().hostname(configServerHostname3).port(3))); + } + + public static ConfigserverConfig createExampleConfigserverConfig() { + return create(true, true, "cfg1", "cfg2", "cfg3"); + } + + public static ConfigserverConfig createExampleConfigserverConfig(boolean nodeAdminInContainer, + boolean multitenant) { + return create(nodeAdminInContainer, multitenant, "cfg1", "cfg2", "cfg3"); + } + + public static ApplicationInfo makeConfigServerApplicationInfo( + String configServerHostname1, + String configServerHostname2, + String configServerHostname3) { + return ConfigServerApplication.CONFIG_SERVER_APPLICATION.makeApplicationInfo(create( + true, + true, + configServerHostname1, + configServerHostname2, + configServerHostname3)); + } + + public static ApplicationInfo makeExampleConfigServer() { + return makeConfigServerApplicationInfo("cfg1", "cfg2", "cfg3"); + } +} diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/DuperModelTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/DuperModelTest.java new file mode 100644 index 00000000000..c9d19d0ccd9 --- /dev/null +++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/DuperModelTest.java @@ -0,0 +1,53 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.service.monitor.internal; + +import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.config.model.api.ApplicationInfo; +import com.yahoo.config.model.api.SuperModel; +import com.yahoo.vespa.applicationmodel.ServiceStatus; +import com.yahoo.vespa.service.monitor.ServiceStatusProvider; +import com.yahoo.vespa.service.monitor.application.ConfigServerApplication; +import org.junit.Test; + +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * @author hakon + */ +public class DuperModelTest { + private final ServiceStatusProvider statusProvider = mock(ServiceStatusProvider.class); + + @Test + public void toApplicationInstance() { + when(statusProvider.getStatus(any(), any(), any(), any())).thenReturn(ServiceStatus.NOT_CHECKED); + ConfigserverConfig config = ConfigserverUtil.createExampleConfigserverConfig(); + DuperModel duperModel = new DuperModel(config); + SuperModel superModel = mock(SuperModel.class); + ApplicationInfo superModelApplicationInfo = mock(ApplicationInfo.class); + when(superModel.getAllApplicationInfos()).thenReturn(Collections.singletonList(superModelApplicationInfo)); + List<ApplicationInfo> applicationInfos = duperModel.getApplicationInfos(superModel); + assertEquals(2, applicationInfos.size()); + assertEquals(ConfigServerApplication.CONFIG_SERVER_APPLICATION.getApplicationId(), applicationInfos.get(0).getApplicationId()); + assertSame(superModelApplicationInfo, applicationInfos.get(1)); + } + + @Test + public void toApplicationInstanceInSingleTenantMode() { + when(statusProvider.getStatus(any(), any(), any(), any())).thenReturn(ServiceStatus.NOT_CHECKED); + ConfigserverConfig config = ConfigserverUtil.createExampleConfigserverConfig(true, false); + DuperModel duperModel = new DuperModel(config); + SuperModel superModel = mock(SuperModel.class); + ApplicationInfo superModelApplicationInfo = mock(ApplicationInfo.class); + when(superModel.getAllApplicationInfos()).thenReturn(Collections.singletonList(superModelApplicationInfo)); + List<ApplicationInfo> applicationInfos = duperModel.getApplicationInfos(superModel); + assertEquals(1, applicationInfos.size()); + assertSame(superModelApplicationInfo, applicationInfos.get(0)); + } +} diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/ModelGeneratorTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/ModelGeneratorTest.java index a21691ee4d0..5a57451a298 100644 --- a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/ModelGeneratorTest.java +++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/ModelGeneratorTest.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.service.monitor.internal; +import com.yahoo.cloud.config.ConfigserverConfig; import com.yahoo.config.model.api.SuperModel; import com.yahoo.config.provision.Environment; import com.yahoo.config.provision.RegionName; @@ -15,13 +16,9 @@ import com.yahoo.vespa.service.monitor.application.ConfigServerApplication; import com.yahoo.vespa.service.monitor.internal.slobrok.SlobrokMonitorManagerImpl; import org.junit.Test; -import java.util.Collections; import java.util.Iterator; -import java.util.List; import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.Stream; import static org.junit.Assert.assertEquals; import static org.mockito.Matchers.any; @@ -35,13 +32,12 @@ public class ModelGeneratorTest { private final int PORT = 2; @Test - public void toApplicationModelWithConfigServerApplication() throws Exception { - SuperModel superModel = - ExampleModel.createExampleSuperModelWithOneRpcPort(HOSTNAME, PORT); + public void toApplicationModel() throws Exception { + SuperModel superModel = ExampleModel.createExampleSuperModelWithOneRpcPort(HOSTNAME, PORT); - List<String> configServerHosts = Stream.of("cfg1", "cfg2", "cfg3") - .collect(Collectors.toList()); - ModelGenerator modelGenerator = new ModelGenerator(configServerHosts); + ConfigserverConfig config = ConfigserverUtil.createExampleConfigserverConfig(); + DuperModel duperModel = new DuperModel(config); + ModelGenerator modelGenerator = new ModelGenerator(); Zone zone = new Zone(Environment.from(ENVIRONMENT), RegionName.from(REGION)); @@ -51,7 +47,7 @@ public class ModelGeneratorTest { ServiceModel serviceModel = modelGenerator.toServiceModel( - superModel, + duperModel.getApplicationInfos(superModel), zone, slobrokMonitorManager); @@ -78,32 +74,6 @@ public class ModelGeneratorTest { } } - @Test - public void toApplicationModel() throws Exception { - SuperModel superModel = - ExampleModel.createExampleSuperModelWithOneRpcPort(HOSTNAME, PORT); - ModelGenerator modelGenerator = new ModelGenerator(Collections.emptyList()); - - Zone zone = new Zone(Environment.from(ENVIRONMENT), RegionName.from(REGION)); - - SlobrokMonitorManagerImpl slobrokMonitorManager = mock(SlobrokMonitorManagerImpl.class); - when(slobrokMonitorManager.getStatus(any(), any(), any(), any())) - .thenReturn(ServiceStatus.UP); - - ServiceModel serviceModel = - modelGenerator.toServiceModel( - superModel, - zone, - slobrokMonitorManager); - - Map<ApplicationInstanceReference, - ApplicationInstance> applicationInstances = - serviceModel.getAllApplicationInstances(); - - assertEquals(1, applicationInstances.size()); - verifyOtherApplication(applicationInstances.values().iterator().next()); - } - private void verifyOtherApplication(ApplicationInstance applicationInstance) { assertEquals(String.format("%s:%s:%s:%s:%s", ExampleModel.TENANT, diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImplTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImplTest.java index 83bad0ddb2a..eb6d6d583f7 100644 --- a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImplTest.java +++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/SuperModelListenerImplTest.java @@ -14,6 +14,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -22,11 +23,13 @@ public class SuperModelListenerImplTest { public void sanityCheck() { SlobrokMonitorManagerImpl slobrokMonitorManager = mock(SlobrokMonitorManagerImpl.class); ServiceMonitorMetrics metrics = mock(ServiceMonitorMetrics.class); + DuperModel duperModel = mock(DuperModel.class); ModelGenerator modelGenerator = mock(ModelGenerator.class); Zone zone = mock(Zone.class); SuperModelListenerImpl listener = new SuperModelListenerImpl( slobrokMonitorManager, metrics, + duperModel, modelGenerator, zone); @@ -38,13 +41,15 @@ public class SuperModelListenerImplTest { ApplicationInfo application2 = mock(ApplicationInfo.class); List<ApplicationInfo> applications = Stream.of(application1, application2) .collect(Collectors.toList()); - when(superModel.getAllApplicationInfos()).thenReturn(applications); + when(duperModel.getApplicationInfos(superModel)).thenReturn(applications); listener.start(superModelProvider); - verify(slobrokMonitorManager).applicationActivated(superModel, application1); - verify(slobrokMonitorManager).applicationActivated(superModel, application2); + verify(duperModel, times(1)).getApplicationInfos(superModel); + verify(slobrokMonitorManager).applicationActivated(application1); + verify(slobrokMonitorManager).applicationActivated(application2); ServiceModel serviceModel = listener.get(); - verify(modelGenerator).toServiceModel(superModel, zone, slobrokMonitorManager); + verify(duperModel, times(2)).getApplicationInfos(superModel); + verify(modelGenerator).toServiceModel(applications, zone, slobrokMonitorManager); } }
\ No newline at end of file diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManagerTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManagerTest.java index b7c3ed8e1e1..79916e43712 100644 --- a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManagerTest.java +++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/UnionMonitorManagerTest.java @@ -1,95 +1,44 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.service.monitor.internal; -import com.yahoo.cloud.config.ConfigserverConfig; -import com.yahoo.config.provision.ApplicationId; -import com.yahoo.vespa.applicationmodel.ClusterId; import com.yahoo.vespa.applicationmodel.ConfigId; -import com.yahoo.vespa.applicationmodel.ServiceType; +import com.yahoo.vespa.applicationmodel.ServiceStatus; import com.yahoo.vespa.service.monitor.internal.health.HealthMonitorManager; import com.yahoo.vespa.service.monitor.internal.slobrok.SlobrokMonitorManagerImpl; import org.junit.Test; import static com.yahoo.vespa.applicationmodel.ClusterId.NODE_ADMIN; +import static com.yahoo.vespa.applicationmodel.ServiceStatus.*; +import static com.yahoo.vespa.applicationmodel.ServiceStatus.NOT_CHECKED; +import static com.yahoo.vespa.applicationmodel.ServiceStatus.UP; import static com.yahoo.vespa.applicationmodel.ServiceType.CONTAINER; import static com.yahoo.vespa.service.monitor.application.ZoneApplication.ZONE_APPLICATION_ID; +import static org.junit.Assert.assertSame; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class UnionMonitorManagerTest { - @Test - public void nodeAdminInContainer() { - testWith( - true, - ZONE_APPLICATION_ID, - NODE_ADMIN, - CONTAINER, - 1, - 0); - } - - @Test - public void nodeAdminOutsideContainer() { - boolean inContainer = false; - - // When nodeAdminInContainer is set, then only the node admin cluster should use health - testWith( - inContainer, - ZONE_APPLICATION_ID, - NODE_ADMIN, - CONTAINER, - 0, - 1); - - testWith( - inContainer, - ApplicationId.fromSerializedForm("a:b:default"), - NODE_ADMIN, - CONTAINER, - 1, - 0); + private final SlobrokMonitorManagerImpl slobrokMonitorManager = mock(SlobrokMonitorManagerImpl.class); + private final HealthMonitorManager healthMonitorManager = mock(HealthMonitorManager.class); - testWith( - inContainer, - ZONE_APPLICATION_ID, - new ClusterId("foo"), - CONTAINER, - 1, - 0); + private final UnionMonitorManager manager = new UnionMonitorManager( + slobrokMonitorManager, + healthMonitorManager); - testWith( - inContainer, - ZONE_APPLICATION_ID, - NODE_ADMIN, - new ServiceType("foo"), - 1, - 0); + @Test + public void verifyHealthTakesPriority() { + testWith(UP, DOWN, UP); + testWith(NOT_CHECKED, DOWN, DOWN); + testWith(NOT_CHECKED, NOT_CHECKED, NOT_CHECKED); } - private void testWith(boolean nodeAdminInContainer, - ApplicationId applicationId, - ClusterId clusterId, - ServiceType serviceType, - int expectedSlobrokCalls, - int expectedHealthCalls) { - SlobrokMonitorManagerImpl slobrokMonitorManager = mock(SlobrokMonitorManagerImpl.class); - HealthMonitorManager healthMonitorManager = mock(HealthMonitorManager.class); - - ConfigserverConfig.Builder builder = new ConfigserverConfig.Builder(); - builder.nodeAdminInContainer(nodeAdminInContainer); - ConfigserverConfig config = new ConfigserverConfig(builder); - - - UnionMonitorManager manager = new UnionMonitorManager( - slobrokMonitorManager, - healthMonitorManager, - config); - - manager.getStatus(applicationId, clusterId, serviceType, new ConfigId("config-id")); - - verify(slobrokMonitorManager, times(expectedSlobrokCalls)).getStatus(any(), any(), any(), any()); - verify(healthMonitorManager, times(expectedHealthCalls)).getStatus(any(), any(), any(), any()); + private void testWith(ServiceStatus healthStatus, + ServiceStatus slobrokStatus, + ServiceStatus expectedStatus) { + when(healthMonitorManager.getStatus(any(), any(), any(), any())).thenReturn(healthStatus); + when(slobrokMonitorManager.getStatus(any(), any(), any(), any())).thenReturn(slobrokStatus); + ServiceStatus status = manager.getStatus(ZONE_APPLICATION_ID, NODE_ADMIN, CONTAINER, new ConfigId("config-id")); + assertSame(expectedStatus, status); } }
\ No newline at end of file diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/ApplicationHealthMonitorTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/ApplicationHealthMonitorTest.java new file mode 100644 index 00000000000..51b0503565f --- /dev/null +++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/ApplicationHealthMonitorTest.java @@ -0,0 +1,24 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.service.monitor.internal.health; + +import com.yahoo.vespa.applicationmodel.ServiceStatus; +import com.yahoo.vespa.service.monitor.application.ConfigServerApplication; +import com.yahoo.vespa.service.monitor.internal.ConfigserverUtil; +import org.junit.Test; + +import static com.yahoo.vespa.applicationmodel.ServiceStatus.NOT_CHECKED; +import static org.junit.Assert.assertEquals; + +public class ApplicationHealthMonitorTest { + @Test + public void sanityCheck() { + ApplicationHealthMonitor monitor = ApplicationHealthMonitor.startMonitoring( + ConfigserverUtil.makeExampleConfigServer()); + ServiceStatus status = monitor.getStatus( + ConfigServerApplication.CONFIG_SERVER_APPLICATION.getApplicationId(), + ConfigServerApplication.CLUSTER_ID, + ConfigServerApplication.SERVICE_TYPE, + ConfigServerApplication.configIdFrom(0)); + assertEquals(NOT_CHECKED, status); + } +}
\ No newline at end of file diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorManagerTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorManagerTest.java new file mode 100644 index 00000000000..b9d25406f9b --- /dev/null +++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorManagerTest.java @@ -0,0 +1,49 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.service.monitor.internal.health; + +import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.config.model.api.ApplicationInfo; +import com.yahoo.vespa.applicationmodel.ClusterId; +import com.yahoo.vespa.applicationmodel.ConfigId; +import com.yahoo.vespa.applicationmodel.ServiceStatus; +import com.yahoo.vespa.applicationmodel.ServiceType; +import com.yahoo.vespa.service.monitor.application.ZoneApplication; +import com.yahoo.vespa.service.monitor.internal.ConfigserverUtil; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class HealthMonitorManagerTest { + @Test + public void addRemove() { + ConfigserverConfig config = ConfigserverUtil.createExampleConfigserverConfig(); + HealthMonitorManager manager = new HealthMonitorManager(config); + ApplicationInfo applicationInfo = ConfigserverUtil.makeExampleConfigServer(); + manager.applicationActivated(applicationInfo); + manager.applicationRemoved(applicationInfo.getApplicationId()); + } + + @Test + public void withNodeAdmin() { + ConfigserverConfig config = ConfigserverUtil.createExampleConfigserverConfig(); + HealthMonitorManager manager = new HealthMonitorManager(config); + ServiceStatus status = manager.getStatus( + ZoneApplication.ZONE_APPLICATION_ID, + ClusterId.NODE_ADMIN, + ServiceType.CONTAINER, + new ConfigId("config-id-1")); + assertEquals(ServiceStatus.NOT_CHECKED, status); + } + + @Test + public void withHostAdmin() { + ConfigserverConfig config = ConfigserverUtil.createExampleConfigserverConfig(false, true); + HealthMonitorManager manager = new HealthMonitorManager(config); + ServiceStatus status = manager.getStatus( + ZoneApplication.ZONE_APPLICATION_ID, + ClusterId.NODE_ADMIN, + ServiceType.CONTAINER, + new ConfigId("config-id-1")); + assertEquals(ServiceStatus.UP, status); + } +}
\ No newline at end of file diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorTest.java new file mode 100644 index 00000000000..cca1530ad97 --- /dev/null +++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/health/HealthMonitorTest.java @@ -0,0 +1,21 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.service.monitor.internal.health; + +import com.yahoo.vespa.applicationmodel.ServiceStatus; +import org.junit.Test; + +import java.net.MalformedURLException; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; + +public class HealthMonitorTest { + @Test + public void basicTests() throws MalformedURLException { + HealthClient healthClient = mock(HealthClient.class); + try (HealthMonitor monitor = new HealthMonitor(healthClient)) { + monitor.startMonitoring(); + assertEquals(ServiceStatus.NOT_CHECKED, monitor.getStatus()); + } + } +}
\ No newline at end of file diff --git a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImplTest.java b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImplTest.java index 8e4443df83b..a567559980b 100644 --- a/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImplTest.java +++ b/service-monitor/src/test/java/com/yahoo/vespa/service/monitor/internal/slobrok/SlobrokMonitorManagerImplTest.java @@ -2,7 +2,7 @@ package com.yahoo.vespa.service.monitor.internal.slobrok; import com.yahoo.config.model.api.ApplicationInfo; -import com.yahoo.config.model.api.SuperModel; +import com.yahoo.config.provision.ApplicationId; import com.yahoo.vespa.applicationmodel.ClusterId; import com.yahoo.vespa.applicationmodel.ConfigId; import com.yahoo.vespa.applicationmodel.ServiceStatus; @@ -28,18 +28,19 @@ public class SlobrokMonitorManagerImplTest { private final SlobrokMonitorManagerImpl slobrokMonitorManager = new SlobrokMonitorManagerImpl(slobrokMonitorFactory); private final SlobrokMonitor slobrokMonitor = mock(SlobrokMonitor.class); - private final SuperModel superModel = mock(SuperModel.class); + private final ApplicationId applicationId = ApplicationId.from("tenant", "app", "instance"); private final ApplicationInfo application = mock(ApplicationInfo.class); private final ClusterId clusterId = new ClusterId("cluster-id"); @Before public void setup() { when(slobrokMonitorFactory.get()).thenReturn(slobrokMonitor); + when(application.getApplicationId()).thenReturn(applicationId); } @Test public void testActivationOfApplication() { - slobrokMonitorManager.applicationActivated(superModel, application); + slobrokMonitorManager.applicationActivated(application); verify(slobrokMonitorFactory, times(1)).get(); } @@ -51,14 +52,14 @@ public class SlobrokMonitorManagerImplTest { @Test public void testGetStatus_ApplicationInSlobrok() { - slobrokMonitorManager.applicationActivated(superModel, application); + slobrokMonitorManager.applicationActivated(application); when(slobrokMonitor.registeredInSlobrok("config.id")).thenReturn(true); assertEquals(ServiceStatus.UP, getStatus("topleveldispatch")); } @Test public void testGetStatus_ServiceNotInSlobrok() { - slobrokMonitorManager.applicationActivated(superModel, application); + slobrokMonitorManager.applicationActivated(application); when(slobrokMonitor.registeredInSlobrok("config.id")).thenReturn(false); assertEquals(ServiceStatus.DOWN, getStatus("topleveldispatch")); } diff --git a/valgrind-suppressions.txt b/valgrind-suppressions.txt index 2df6c9c5691..2587552ceff 100644 --- a/valgrind-suppressions.txt +++ b/valgrind-suppressions.txt @@ -339,3 +339,20 @@ fun:__static_initialization_and_destruction_0 ... } +{ + Apparent memory leak on Fedora 28. + Memcheck:Leak + match-leak-kinds: possible + fun:malloc + fun:tsearch + fun:__add_to_environ + fun:setenv +} +{ + Apparent memory leak on Fedora 28. + Memcheck:Leak + match-leak-kinds: possible + fun:malloc + fun:__add_to_environ + fun:setenv +} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java index 1504119d9cc..ab127b19bf1 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java @@ -10,8 +10,13 @@ import com.yahoo.vespa.athenz.identityprovider.api.bindings.SignedIdentityDocume import com.yahoo.vespa.athenz.identityprovider.api.bindings.VespaUniqueInstanceIdEntity; import com.yahoo.vespa.athenz.utils.AthenzIdentities; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Path; import java.util.Base64; +import static com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId.fromDottedString; + /** * Utility class for mapping objects model types and their Jackson binding versions. * @@ -33,7 +38,7 @@ public class EntityBindingsMapper { public static VespaUniqueInstanceId toVespaUniqueInstanceId(VespaUniqueInstanceIdEntity entity) { return new VespaUniqueInstanceId( - entity.clusterIndex, entity.clusterId, entity.instance, entity.application, entity.tenant, entity.region, entity.environment); + entity.clusterIndex, entity.clusterId, entity.instance, entity.application, entity.tenant, entity.region, entity.environment, entity.type != null ? IdentityType.fromId(entity.type) : null); // TODO Remove support for legacy representation without type } public static IdentityDocument toIdentityDocument(IdentityDocumentEntity entity) { @@ -50,17 +55,22 @@ public class EntityBindingsMapper { toIdentityDocument(entity.identityDocument), entity.signature, entity.signingKeyVersion, - VespaUniqueInstanceId.fromDottedString(entity.providerUniqueId), + fromDottedString(entity.providerUniqueId), entity.dnsSuffix, (AthenzService) AthenzIdentities.from(entity.providerService), entity.ztsEndpoint, - entity.documentVersion); + entity.documentVersion, + entity.configServerHostname, + entity.instanceHostname, + entity.createdAt, + entity.ipAddresses, + entity.identityType != null ? IdentityType.fromId(entity.identityType) : null); // TODO Remove support for legacy representation without type } public static VespaUniqueInstanceIdEntity toVespaUniqueInstanceIdEntity(VespaUniqueInstanceId model) { return new VespaUniqueInstanceIdEntity( model.tenant(), model.application(), model.environment(), model.region(), - model.instance(), model.clusterId(), model.clusterIndex()); + model.instance(), model.clusterId(), model.clusterIndex(), model.type() != null ? model.type().id() : null); // TODO Remove support for legacy representation without type } public static IdentityDocumentEntity toIdentityDocumentEntity(IdentityDocument model) { @@ -84,10 +94,33 @@ public class EntityBindingsMapper { model.dnsSuffix(), model.providerService().getFullName(), model.ztsEndpoint(), - model.documentVersion()); + model.documentVersion(), + model.configServerHostname(), + model.instanceHostname(), + model.createdAt(), + model.ipAddresses(), + model.identityType() != null ? model.identityType().id() : null); // TODO Remove support for legacy representation without type } catch (JsonProcessingException e) { throw new RuntimeException(e); } } + public static SignedIdentityDocument readSignedIdentityDocumentFromFile(Path file) { + try { + SignedIdentityDocumentEntity entity = mapper.readValue(file.toFile(), SignedIdentityDocumentEntity.class); + return EntityBindingsMapper.toSignedIdentityDocument(entity); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public static void writeSignedIdentityDocumentToFile(Path file, SignedIdentityDocument document) { + try { + SignedIdentityDocumentEntity entity = EntityBindingsMapper.toSignedIdentityDocumentEntity(document); + mapper.writeValue(file.toFile(), entity); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocument.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocument.java index 8da2bd0a343..82d0a3d622c 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocument.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocument.java @@ -8,7 +8,9 @@ import java.util.Set; * The identity document that contains the instance specific information * * @author bjorncs + * @deprecated Will soon be inlined into {@link SignedIdentityDocument} */ +@Deprecated public class IdentityDocument { private final VespaUniqueInstanceId providerUniqueId; private final String configServerHostname; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityType.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityType.java new file mode 100644 index 00000000000..4ca2e34a618 --- /dev/null +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityType.java @@ -0,0 +1,25 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.athenz.identityprovider.api; + +import java.util.Arrays; + +/** + * Represents the types of identities that the configserver can provide. + * + * @author bjorncs + */ +public enum IdentityType {TENANT("tenant"), NODE("node"); + private final String id; + + IdentityType(String id) { this.id = id; } + + public String id() { return id; } + + public static IdentityType fromId(String id) { + return Arrays.stream(values()) + .filter(v -> v.id.equals(id)) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("Invalid id: " + id)); + } +} + diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java index d184efc0221..60be42544c7 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java @@ -4,6 +4,8 @@ package com.yahoo.vespa.athenz.identityprovider.api; import com.yahoo.vespa.athenz.api.AthenzService; import java.net.URI; +import java.time.Instant; +import java.util.Set; /** * A signed identity document which contains a {@link IdentityDocument} @@ -22,6 +24,11 @@ public class SignedIdentityDocument { private final AthenzService providerService; private final URI ztsEndpoint; private final int documentVersion; + private final String configServerHostname; + private final String instanceHostname; + private final Instant createdAt; + private final Set<String> ipAddresses; + private final IdentityType identityType; public SignedIdentityDocument(IdentityDocument identityDocument, String signature, @@ -30,7 +37,12 @@ public class SignedIdentityDocument { String dnsSuffix, AthenzService providerService, URI ztsEndpoint, - int documentVersion) { + int documentVersion, + String configServerHostname, + String instanceHostname, + Instant createdAt, + Set<String> ipAddresses, + IdentityType identityType) { this.identityDocument = identityDocument; this.signature = signature; this.signingKeyVersion = signingKeyVersion; @@ -39,6 +51,11 @@ public class SignedIdentityDocument { this.providerService = providerService; this.ztsEndpoint = ztsEndpoint; this.documentVersion = documentVersion; + this.configServerHostname = configServerHostname; + this.instanceHostname = instanceHostname; + this.createdAt = createdAt; + this.ipAddresses = ipAddresses; + this.identityType = identityType; } public IdentityDocument identityDocument() { @@ -72,4 +89,24 @@ public class SignedIdentityDocument { public int documentVersion() { return documentVersion; } + + public String configServerHostname() { + return configServerHostname; + } + + public String instanceHostname() { + return instanceHostname; + } + + public Instant createdAt() { + return createdAt; + } + + public Set<String> ipAddresses() { + return ipAddresses; + } + + public IdentityType identityType() { + return identityType; + } } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceId.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceId.java index 5539ba53882..be94cc59691 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceId.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceId.java @@ -4,6 +4,8 @@ package com.yahoo.vespa.athenz.identityprovider.api; import java.util.Objects; /** + * Represents the unique instance id as used in Vespa's integration with Athenz Copper Argos + * * @author bjorncs */ public class VespaUniqueInstanceId { @@ -15,6 +17,7 @@ public class VespaUniqueInstanceId { private final String tenant; private final String region; private final String environment; + private final IdentityType type; public VespaUniqueInstanceId(int clusterIndex, String clusterId, @@ -22,7 +25,8 @@ public class VespaUniqueInstanceId { String application, String tenant, String region, - String environment) { + String environment, + IdentityType type) { this.clusterIndex = clusterIndex; this.clusterId = clusterId; this.instance = instance; @@ -30,21 +34,43 @@ public class VespaUniqueInstanceId { this.tenant = tenant; this.region = region; this.environment = environment; + this.type = type; } + // TODO Remove support for legacy representation without type + @Deprecated + public VespaUniqueInstanceId(int clusterIndex, + String clusterId, + String instance, + String application, + String tenant, + String region, + String environment) { + this(clusterIndex, clusterId, instance, application, tenant, region, environment, null); + } + + + // TODO Remove support for legacy representation without type public static VespaUniqueInstanceId fromDottedString(String instanceId) { String[] tokens = instanceId.split("\\."); - if (tokens.length != 7) { + if (tokens.length != 7 && tokens.length != 8) { throw new IllegalArgumentException("Invalid instance id: " + instanceId); } return new VespaUniqueInstanceId( - Integer.parseInt(tokens[0]), tokens[1], tokens[2], tokens[3], tokens[4], tokens[5], tokens[6]); + Integer.parseInt(tokens[0]), tokens[1], tokens[2], tokens[3], tokens[4], tokens[5], tokens[6], tokens.length == 8 ? IdentityType.fromId(tokens[7]) : null); } + // TODO Remove support for legacy representation without type public String asDottedString() { - return String.format( - "%d.%s.%s.%s.%s.%s.%s", - clusterIndex, clusterId, instance, application, tenant, region, environment); + if (type != null) { + return String.format( + "%d.%s.%s.%s.%s.%s.%s.%s", + clusterIndex, clusterId, instance, application, tenant, region, environment, type.id()); + } else { + return String.format( + "%d.%s.%s.%s.%s.%s.%s", + clusterIndex, clusterId, instance, application, tenant, region, environment); + } } public int clusterIndex() { @@ -75,6 +101,8 @@ public class VespaUniqueInstanceId { return environment; } + public IdentityType type() { return type; } + @Override public String toString() { return "VespaUniqueInstanceId{" + @@ -85,6 +113,7 @@ public class VespaUniqueInstanceId { ", tenant='" + tenant + '\'' + ", region='" + region + '\'' + ", environment='" + environment + '\'' + + ", type=" + type + '}'; } @@ -99,11 +128,12 @@ public class VespaUniqueInstanceId { Objects.equals(application, that.application) && Objects.equals(tenant, that.tenant) && Objects.equals(region, that.region) && - Objects.equals(environment, that.environment); + Objects.equals(environment, that.environment) && + type == that.type; } @Override public int hashCode() { - return Objects.hash(clusterIndex, clusterId, instance, application, tenant, region, environment); + return Objects.hash(clusterIndex, clusterId, instance, application, tenant, region, environment, type); } } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentApi.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentApi.java index 775a49349a3..fc5392411c1 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentApi.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentApi.java @@ -5,7 +5,6 @@ import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; -import javax.ws.rs.QueryParam; import javax.ws.rs.core.MediaType; /** @@ -16,11 +15,6 @@ public interface IdentityDocumentApi { @GET @Produces(MediaType.APPLICATION_JSON) - @Deprecated - SignedIdentityDocumentEntity getIdentityDocument(@QueryParam("hostname") String hostname); - - @GET - @Produces(MediaType.APPLICATION_JSON) @Path("/node/{host}") SignedIdentityDocumentEntity getNodeIdentityDocument(@PathParam("host") String host); diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentEntity.java index 58a4f1e24bf..b4b2e82ab0e 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentEntity.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentEntity.java @@ -10,8 +10,10 @@ import java.util.Set; /** * @author bjorncs + * @deprecated Will soon be inlined into {@link SignedIdentityDocumentEntity} */ @JsonIgnoreProperties(ignoreUnknown = true) +@Deprecated public class IdentityDocumentEntity { @JsonProperty("provider-unique-id") diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java index e397b81ef9e..aa514b3caf3 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java @@ -11,8 +11,10 @@ import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import java.io.IOException; import java.io.UncheckedIOException; import java.net.URI; +import java.time.Instant; import java.util.Base64; import java.util.Objects; +import java.util.Set; /** * @author bjorncs @@ -31,6 +33,11 @@ public class SignedIdentityDocumentEntity { @JsonProperty("provider-service") public final String providerService; @JsonProperty("zts-endpoint") public final URI ztsEndpoint; @JsonProperty("document-version") public final int documentVersion; + @JsonProperty("configserver-hostname") public final String configServerHostname; + @JsonProperty("instance-hostname") public final String instanceHostname; + @JsonProperty("created-at") public final Instant createdAt; + @JsonProperty("ip-addresses") public final Set<String> ipAddresses; + @JsonProperty("identity-type") public final String identityType; @JsonCreator public SignedIdentityDocumentEntity(@JsonProperty("identity-document") String rawIdentityDocument, @@ -40,7 +47,12 @@ public class SignedIdentityDocumentEntity { @JsonProperty("dns-suffix") String dnsSuffix, @JsonProperty("provider-service") String providerService, @JsonProperty("zts-endpoint") URI ztsEndpoint, - @JsonProperty("document-version") int documentVersion) { + @JsonProperty("document-version") int documentVersion, + @JsonProperty("configserver-hostname") String configServerHostname, + @JsonProperty("instance-hostname") String instanceHostname, + @JsonProperty("created-at") Instant createdAt, + @JsonProperty("ip-addresses") Set<String> ipAddresses, + @JsonProperty("identity-type") String identityType) { this.rawIdentityDocument = rawIdentityDocument; this.identityDocument = parseIdentityDocument(rawIdentityDocument); this.signature = signature; @@ -50,6 +62,11 @@ public class SignedIdentityDocumentEntity { this.providerService = providerService; this.ztsEndpoint = ztsEndpoint; this.documentVersion = documentVersion; + this.configServerHostname = configServerHostname; + this.instanceHostname = instanceHostname; + this.createdAt = createdAt; + this.ipAddresses = ipAddresses; + this.identityType = identityType; } private static IdentityDocumentEntity parseIdentityDocument(String rawIdentityDocument) { @@ -73,7 +90,16 @@ public class SignedIdentityDocumentEntity { ", identityDocument=" + identityDocument + ", signature='" + signature + '\'' + ", signingKeyVersion=" + signingKeyVersion + + ", providerUniqueId='" + providerUniqueId + '\'' + + ", dnsSuffix='" + dnsSuffix + '\'' + + ", providerService='" + providerService + '\'' + + ", ztsEndpoint=" + ztsEndpoint + ", documentVersion=" + documentVersion + + ", configServerHostname='" + configServerHostname + '\'' + + ", instanceHostname='" + instanceHostname + '\'' + + ", createdAt=" + createdAt + + ", ipAddresses=" + ipAddresses + + ", identityType=" + identityType + '}'; } @@ -86,11 +112,20 @@ public class SignedIdentityDocumentEntity { documentVersion == that.documentVersion && Objects.equals(rawIdentityDocument, that.rawIdentityDocument) && Objects.equals(identityDocument, that.identityDocument) && - Objects.equals(signature, that.signature); + Objects.equals(signature, that.signature) && + Objects.equals(providerUniqueId, that.providerUniqueId) && + Objects.equals(dnsSuffix, that.dnsSuffix) && + Objects.equals(providerService, that.providerService) && + Objects.equals(ztsEndpoint, that.ztsEndpoint) && + Objects.equals(configServerHostname, that.configServerHostname) && + Objects.equals(instanceHostname, that.instanceHostname) && + Objects.equals(createdAt, that.createdAt) && + Objects.equals(ipAddresses, that.ipAddresses) && + Objects.equals(identityType, identityType); } @Override public int hashCode() { - return Objects.hash(rawIdentityDocument, identityDocument, signature, signingKeyVersion, documentVersion); + return Objects.hash(rawIdentityDocument, identityDocument, signature, signingKeyVersion, providerUniqueId, dnsSuffix, providerService, ztsEndpoint, documentVersion, configServerHostname, instanceHostname, createdAt, ipAddresses, identityType); } } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/VespaUniqueInstanceIdEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/VespaUniqueInstanceIdEntity.java index 3c521e992ad..3fdbb49b28e 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/VespaUniqueInstanceIdEntity.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/VespaUniqueInstanceIdEntity.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.athenz.identityprovider.api.bindings; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.Objects; @@ -26,14 +27,18 @@ public class VespaUniqueInstanceIdEntity { public final String clusterId; @JsonProperty("cluster-index") public final int clusterIndex; + @JsonProperty("type") + public final String type; + @JsonCreator public VespaUniqueInstanceIdEntity(@JsonProperty("tenant") String tenant, @JsonProperty("application") String application, @JsonProperty("environment") String environment, @JsonProperty("region") String region, @JsonProperty("instance") String instance, @JsonProperty("cluster-id") String clusterId, - @JsonProperty("cluster-index") int clusterIndex) { + @JsonProperty("cluster-index") int clusterIndex, + @JsonProperty("type") String type) { this.tenant = tenant; this.application = application; this.environment = environment; @@ -41,8 +46,21 @@ public class VespaUniqueInstanceIdEntity { this.instance = instance; this.clusterId = clusterId; this.clusterIndex = clusterIndex; + this.type = type; } + @Deprecated + public VespaUniqueInstanceIdEntity(String tenant, + String application, + String environment, + String region, + String instance, + String clusterId, + int clusterIndex) { + this(tenant, application, environment, region, instance, clusterId, clusterIndex, null); + } + + @Override public String toString() { return "VespaUniqueInstanceIdEntity{" + @@ -53,6 +71,7 @@ public class VespaUniqueInstanceIdEntity { ", instance='" + instance + '\'' + ", clusterId='" + clusterId + '\'' + ", clusterIndex=" + clusterIndex + + ", type='" + type + '\'' + '}'; } @@ -67,11 +86,12 @@ public class VespaUniqueInstanceIdEntity { Objects.equals(environment, that.environment) && Objects.equals(region, that.region) && Objects.equals(instance, that.instance) && - Objects.equals(clusterId, that.clusterId); + Objects.equals(clusterId, that.clusterId) && + Objects.equals(type, that.type); } @Override public int hashCode() { - return Objects.hash(tenant, application, environment, region, instance, clusterId, clusterIndex); + return Objects.hash(tenant, application, environment, region, instance, clusterId, clusterIndex, type); } } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java index 96e93ca419d..e8ef2d9f97e 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.athenz.identityprovider.client; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import com.yahoo.container.core.identity.IdentityConfig; import com.yahoo.vespa.athenz.api.AthenzService; import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper; @@ -28,7 +29,7 @@ import static com.yahoo.vespa.athenz.tls.KeyStoreType.JKS; */ class AthenzCredentialsService { - private static final ObjectMapper mapper = new ObjectMapper(); + private static final ObjectMapper mapper = new ObjectMapper().registerModule(new JavaTimeModule()); private final IdentityConfig identityConfig; private final IdentityDocumentClient identityDocumentClient; diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java index 90d1312c9f9..b9aba6e66b0 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java @@ -2,14 +2,12 @@ package com.yahoo.vespa.athenz.identityprovider.client; import com.fasterxml.jackson.databind.ObjectMapper; -import com.yahoo.vespa.athenz.api.AthenzService; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider; import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper; import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocumentClient; import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument; -import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId; import com.yahoo.vespa.athenz.identityprovider.api.bindings.SignedIdentityDocumentEntity; -import com.yahoo.vespa.athenz.utils.AthenzIdentities; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpUriRequest; import org.apache.http.client.methods.RequestBuilder; @@ -34,7 +32,7 @@ import java.util.function.Supplier; public class DefaultIdentityDocumentClient implements IdentityDocumentClient { private static final String IDENTITY_DOCUMENT_API = "/athenz/v1/provider/identity-document/"; - private static final ObjectMapper objectMapper = new ObjectMapper(); + private static final ObjectMapper objectMapper = new ObjectMapper().registerModule(new JavaTimeModule()); private final Supplier<SSLContext> sslContextSupplier; private final HostnameVerifier hostnameVerifier; @@ -82,15 +80,7 @@ public class DefaultIdentityDocumentClient implements IdentityDocumentClient { String responseContent = EntityUtils.toString(response.getEntity()); if (HttpStatus.isSuccess(response.getStatusLine().getStatusCode())) { SignedIdentityDocumentEntity entity = objectMapper.readValue(responseContent, SignedIdentityDocumentEntity.class); - return new SignedIdentityDocument( - EntityBindingsMapper.toIdentityDocument(entity.identityDocument), - entity.signature, - entity.signingKeyVersion, - VespaUniqueInstanceId.fromDottedString(entity.providerUniqueId), - entity.dnsSuffix, - (AthenzService) AthenzIdentities.from(entity.providerService), - entity.ztsEndpoint, - entity.documentVersion); + return EntityBindingsMapper.toSignedIdentityDocument(entity); } else { throw new RuntimeException( String.format( diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceIdTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceIdTest.java index 8c4e4c1262d..86b6c566987 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceIdTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/VespaUniqueInstanceIdTest.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.athenz.identityprovider.api; import org.junit.Test; +import static com.yahoo.vespa.athenz.identityprovider.api.IdentityType.*; import static org.junit.Assert.*; /** @@ -12,6 +13,18 @@ public class VespaUniqueInstanceIdTest { @Test public void can_serialize_to_and_deserialize_from_string() { VespaUniqueInstanceId id = + new VespaUniqueInstanceId(1, "cluster-id", "instance", "application", "tenant", "region", "environment", TENANT); + String stringRepresentation = id.asDottedString(); + String expectedStringRepresentation = "1.cluster-id.instance.application.tenant.region.environment.tenant"; + assertEquals(expectedStringRepresentation, stringRepresentation); + VespaUniqueInstanceId deserializedId = VespaUniqueInstanceId.fromDottedString(stringRepresentation); + assertEquals(id, deserializedId); + } + + // TODO Remove support for legacy representation without type + @Test + public void supports_legacy_representation_without_type() { + VespaUniqueInstanceId id = new VespaUniqueInstanceId(1, "cluster-id", "instance", "application", "tenant", "region", "environment"); String stringRepresentation = id.asDottedString(); String expectedStringRepresentation = "1.cluster-id.instance.application.tenant.region.environment"; diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java index 2e9b29f5327..7ad465a7d80 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java @@ -11,6 +11,7 @@ import com.yahoo.test.ManualClock; import com.yahoo.vespa.athenz.api.AthenzService; import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper; import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocument; +import com.yahoo.vespa.athenz.identityprovider.api.IdentityType; import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument; import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId; import com.yahoo.vespa.athenz.tls.KeyStoreBuilder; @@ -132,7 +133,7 @@ public class AthenzIdentityProviderImplTest { } private static String getIdentityDocument() throws JsonProcessingException { - VespaUniqueInstanceId instanceId = new VespaUniqueInstanceId(0, "default", "default", "application", "tenant", "us-north-1", "dev"); + VespaUniqueInstanceId instanceId = new VespaUniqueInstanceId(0, "default", "default", "application", "tenant", "us-north-1", "dev", IdentityType.TENANT); SignedIdentityDocument signedIdentityDocument = new SignedIdentityDocument( new IdentityDocument(instanceId, "localhost", "x.y.com", Instant.EPOCH, Collections.emptySet()), "dummysignature", @@ -141,7 +142,12 @@ public class AthenzIdentityProviderImplTest { "dev-us-north-1.vespa.cloud", new AthenzService("vespa.vespa.provider_dev_us-north-1"), URI.create("https://zts:4443/zts/v1"), - 1); + 1, + "localhost", + "x.y.com", + Instant.EPOCH, + Collections.emptySet(), + IdentityType.TENANT); return new ObjectMapper().registerModule(new JavaTimeModule()) .writeValueAsString(EntityBindingsMapper.toSignedIdentityDocumentEntity(signedIdentityDocument)); diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/CommandLineArguments.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/CommandLineArguments.java index 671038c852a..84d3b320772 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/CommandLineArguments.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/CommandLineArguments.java @@ -11,10 +11,15 @@ import io.airlift.command.Command; import io.airlift.command.HelpOption; import io.airlift.command.Option; import io.airlift.command.SingleCommand; +import org.apache.http.Header; +import org.apache.http.ParseException; import org.apache.http.conn.ssl.NoopHostnameVerifier; import org.apache.http.conn.ssl.SSLConnectionSocketFactory; +import org.apache.http.message.BasicLineParser; import javax.inject.Inject; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.TimeUnit; /** @@ -53,6 +58,15 @@ public class CommandLineArguments { return null; } + for (String header : cmdArgs.headers) { + try { + cmdArgs.parsedHeaders.add(BasicLineParser.parseHeader(header, null)); + } catch (ParseException e) { + System.err.printf("Invalid header: '%s' (%s)%n", header, e.getMessage()); + return null; + } + } + return cmdArgs; } @@ -180,6 +194,12 @@ public class CommandLineArguments { description = "Skip hostname verification when using TLS") private boolean insecure = false; + @Option(name = {"--header"}, + description = "Add http header to every request. Header must have the format '<Name>: <Value>'. Use this parameter multiple times for multiple headers") + private List<String> headers = new ArrayList<>(); + + private final List<Header> parsedHeaders = new ArrayList<>(); + int getWhenVerboseEnabledPrintMessageForEveryXDocuments() { return whenVerboseEnabledPrintMessageForEveryXDocuments; } @@ -192,6 +212,8 @@ public class CommandLineArguments { SessionParams createSessionParams(boolean useJson) { final int minThrottleValue = useDynamicThrottlingArg ? 10 : 0; + ConnectionParams.Builder connectionParamsBuilder = new ConnectionParams.Builder(); + parsedHeaders.forEach(header -> connectionParamsBuilder.addHeader(header.getName(), header.getValue())); SessionParams.Builder builder = new SessionParams.Builder() .setFeedParams( new FeedParams.Builder() @@ -208,7 +230,7 @@ public class CommandLineArguments { .build() ) .setConnectionParams( - new ConnectionParams.Builder() + connectionParamsBuilder .setHostnameVerifier(insecure ? NoopHostnameVerifier.INSTANCE : SSLConnectionSocketFactory.getDefaultHostnameVerifier()) .setNumPersistentConnectionsPerEndpoint(16) diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/runner/CommandLineArgumentsTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/runner/CommandLineArgumentsTest.java index e0d93a7fa18..84a69520a84 100644 --- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/runner/CommandLineArgumentsTest.java +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/runner/CommandLineArgumentsTest.java @@ -7,7 +7,13 @@ import com.yahoo.vespa.http.client.config.SessionParams; import org.junit.Test; import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -109,6 +115,7 @@ public class CommandLineArgumentsTest { add("debugport", "7890"); args.add("--verbose"); args.add("--useTls"); + add("header", "Header-Name: Header-Value"); CommandLineArguments arguments = CommandLineArguments.build(asArray()); SessionParams params = arguments.createSessionParams(true /* use json */); assertThat(params.getClientQueueSize(), is(3456)); @@ -116,6 +123,7 @@ public class CommandLineArgumentsTest { assertThat(params.getClusters().get(0).getEndpoints().get(0).getPort(), is(1234)); assertThat(params.getClusters().get(0).getEndpoints().get(0).isUseSsl(), is(true)); assertThat(params.getConnectionParams().getUseCompression(), is(true)); + assertThat(params.getConnectionParams().getHeaders().size(), is(1)); assertThat(params.getFeedParams().getRoute(), is("routeValue")); assertThat(params.getFeedParams().getDataFormat(), is(FeedParams.DataFormat.JSON_UTF8)); assertThat(params.getFeedParams().getLocalQueueTimeOut(), is(2345000L)); @@ -124,6 +132,31 @@ public class CommandLineArgumentsTest { } @Test + public void testAddingMultipleHttpHeaders() { + add("host", "hostValue"); + String header1Name = "Header-Name-1"; + String header1Value = "Header-Value"; + add("header", header1Name + ": " + header1Value); + String header2Name = "Header-Name-2"; + String header2Value = "Another-Header-Value"; + add("header", header2Name + ": " + header2Value); + + CommandLineArguments arguments = CommandLineArguments.build(asArray()); + SessionParams params = arguments.createSessionParams(true /* use json */); + + List<Map.Entry<String, String>> headers = new ArrayList<>(params.getConnectionParams().getHeaders()); + headers.sort(Comparator.comparing(Map.Entry::getKey)); + + assertThat(headers.size(), is(2)); + Map.Entry<String, String> actualHeader1 = headers.get(0); + assertThat(actualHeader1.getKey(), is(header1Name)); + assertThat(actualHeader1.getValue(), is(header1Value)); + Map.Entry<String, String> actualHeader2 = headers.get(1); + assertThat(actualHeader2.getKey(), is(header2Name)); + assertThat(actualHeader2.getValue(), is(header2Value)); + } + + @Test public void testMultiHost() { add("file", "fileValue.json"); add("port", "1234"); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index 944755c9db2..3a66eef258d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -22,22 +22,37 @@ public class ScalarFunctions { public static DoubleBinaryOperator add() { return new Add(); } public static DoubleBinaryOperator divide() { return new Divide(); } public static DoubleBinaryOperator equal() { return new Equal(); } + public static DoubleBinaryOperator greater() { return new Greater(); } + public static DoubleBinaryOperator less() { return new Less(); } public static DoubleBinaryOperator max() { return new Max(); } public static DoubleBinaryOperator min() { return new Min(); } + public static DoubleBinaryOperator mean() { return new Mean(); } public static DoubleBinaryOperator multiply() { return new Multiply(); } + public static DoubleBinaryOperator pow() { return new Pow(); } public static DoubleBinaryOperator squareddifference() { return new SquaredDifference(); } public static DoubleBinaryOperator subtract() { return new Subtract(); } + public static DoubleUnaryOperator abs() { return new Abs(); } public static DoubleUnaryOperator acos() { return new Acos(); } + public static DoubleUnaryOperator asin() { return new Asin(); } + public static DoubleUnaryOperator atan() { return new Atan(); } + public static DoubleUnaryOperator ceil() { return new Ceil(); } + public static DoubleUnaryOperator cos() { return new Cos(); } public static DoubleUnaryOperator elu() { return new Elu(); } public static DoubleUnaryOperator exp() { return new Exp(); } public static DoubleUnaryOperator floor() { return new Floor(); } + public static DoubleUnaryOperator log() { return new Log(); } + public static DoubleUnaryOperator neg() { return new Neg(); } + public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); } public static DoubleUnaryOperator relu() { return new Relu(); } public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); } public static DoubleUnaryOperator selu() { return new Selu(); } + public static DoubleUnaryOperator sin() { return new Sin(); } public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); } public static DoubleUnaryOperator sqrt() { return new Sqrt(); } public static DoubleUnaryOperator square() { return new Square(); } + public static DoubleUnaryOperator tan() { return new Tan(); } + public static DoubleUnaryOperator tanh() { return new Tanh(); } public static Function<List<Long>, Double> random() { return new Random(); } public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); } @@ -59,6 +74,20 @@ public class ScalarFunctions { public String toString() { return "f(a,b)(a==b)"; } } + public static class Greater implements DoubleBinaryOperator { + @Override + public double applyAsDouble(double left, double right) { return left > right ? 1 : 0; } + @Override + public String toString() { return "f(a,b)(a > b)"; } + } + + public static class Less implements DoubleBinaryOperator { + @Override + public double applyAsDouble(double left, double right) { return left < right ? 1 : 0; } + @Override + public String toString() { return "f(a,b)(a < b)"; } + } + public static class Max implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return Math.max(left, right); } @@ -73,6 +102,13 @@ public class ScalarFunctions { public String toString() { return "f(a,b)(min(a, b))"; } } + public static class Mean implements DoubleBinaryOperator { + @Override + public double applyAsDouble(double left, double right) { return (left + right) / 2; } + @Override + public String toString() { return "f(a,b)((a + b) / 2)"; } + } + public static class Multiply implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return left * right; } @@ -80,6 +116,13 @@ public class ScalarFunctions { public String toString() { return "f(a,b)(a * b)"; } } + public static class Pow implements DoubleBinaryOperator { + @Override + public double applyAsDouble(double left, double right) { return Math.pow(left, right); } + @Override + public String toString() { return "f(a,b)(pow(a, b))"; } + } + public static class Divide implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return left / right; } @@ -104,6 +147,13 @@ public class ScalarFunctions { // Unary operators ------------------------------------------------------------------------------ + public static class Abs implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.abs(operand); } + @Override + public String toString() { return "f(a)(fabs(a))"; } + } + public static class Acos implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { return Math.acos(operand); } @@ -111,6 +161,34 @@ public class ScalarFunctions { public String toString() { return "f(a)(acos(a))"; } } + public static class Asin implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.asin(operand); } + @Override + public String toString() { return "f(a)(asin(a))"; } + } + + public static class Atan implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.atan(operand); } + @Override + public String toString() { return "f(a)(atan(a))"; } + } + + public static class Ceil implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.ceil(operand); } + @Override + public String toString() { return "f(a)(ceil(a))"; } + } + + public static class Cos implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.cos(operand); } + @Override + public String toString() { return "f(a)(cos(a))"; } + } + public static class Elu implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { return operand < 0 ? Math.exp(operand) -1 : operand; } @@ -132,6 +210,26 @@ public class ScalarFunctions { public String toString() { return "f(a)(floor(a))"; } } + public static class Log implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.log(operand); } + @Override + public String toString() { return "f(a)(log(a))"; } + } + + public static class Neg implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return -operand; } + @Override + public String toString() { return "f(a)(-a)"; } + } + + public static class Reciprocal implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return 1.0 / operand; } + @Override + public String toString() { return "f(a)(1 / a)"; } + } public static class Relu implements DoubleUnaryOperator { @Override @@ -150,6 +248,13 @@ public class ScalarFunctions { public String toString() { return String.format("f(a)(%f * if(a >= 0, a, %f*(exp(a)-1)))", scale, alpha); } } + public static class Sin implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.sin(operand); } + @Override + public String toString() { return "f(a)(sin(a))"; } + } + public static class Rsqrt implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { return 1.0 / Math.sqrt(operand); } @@ -172,15 +277,29 @@ public class ScalarFunctions { } public static class Square implements DoubleUnaryOperator { - @Override public double applyAsDouble(double operand) { return operand * operand; } - @Override public String toString() { return "f(a)(a * a)"; } + } + public static class Tan implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.tan(operand); } + @Override + public String toString() { return "f(a)(tan(a))"; } } + public static class Tanh implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.tanh(operand); } + @Override + public String toString() { return "f(a)(tanh(a))"; } + } + + + + // Variable-length operators ----------------------------------------------------------------------------- public static class EqualElements implements Function<List<Long>, Double> { |