diff options
52 files changed, 721 insertions, 94 deletions
diff --git a/cloud-tenant-base-dependencies-enforcer/OWNERS b/cloud-tenant-base-dependencies-enforcer/OWNERS new file mode 100644 index 00000000000..0a0d219e4eb --- /dev/null +++ b/cloud-tenant-base-dependencies-enforcer/OWNERS @@ -0,0 +1,2 @@ +bjorncs +mortent diff --git a/cloud-tenant-base-dependencies-enforcer/README.md b/cloud-tenant-base-dependencies-enforcer/README.md new file mode 100644 index 00000000000..46d1af7090e --- /dev/null +++ b/cloud-tenant-base-dependencies-enforcer/README.md @@ -0,0 +1,3 @@ +# Dependencies enforcer for cloud-tenant-base parent pom + +Enforces that only allowed dependencies are visible for tenant projects using the cloud-tenant-base parent pom. diff --git a/cloud-tenant-base-dependencies-enforcer/is-base-pom-module.txt b/cloud-tenant-base-dependencies-enforcer/is-base-pom-module.txt new file mode 100644 index 00000000000..56ff7dcddc8 --- /dev/null +++ b/cloud-tenant-base-dependencies-enforcer/is-base-pom-module.txt @@ -0,0 +1 @@ +Used to skip 'hosted-build-vespa-application' profile in 'hosted-tenant-base' diff --git a/cloud-tenant-base-dependencies-enforcer/pom.xml b/cloud-tenant-base-dependencies-enforcer/pom.xml new file mode 100644 index 00000000000..c46cc26dd03 --- /dev/null +++ b/cloud-tenant-base-dependencies-enforcer/pom.xml @@ -0,0 +1,283 @@ +<?xml version="1.0"?> +<!-- Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<project xmlns="http://maven.apache.org/POM/4.0.0" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 + http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <groupId>com.yahoo.vespa</groupId> + <artifactId>cloud-tenant-base</artifactId> + <version>7-SNAPSHOT</version> + <relativePath>../cloud-tenant-base/pom.xml</relativePath> + </parent> + + <artifactId>cloud-tenant-base-dependencies-enforcer</artifactId> + <version>7-SNAPSHOT</version> + <packaging>pom</packaging> + + <!-- MUST BE KEPT IN SYNC WITH container-dependency-versions pom + Copied here because vz-tenant-base does not have a parent. --> + <properties> + <aopalliance.version>1.0</aopalliance.version> + <athenz.version>1.8.49</athenz.version> + <bouncycastle.version>1.65</bouncycastle.version> + <felix.version>6.0.3</felix.version> + <felix.log.version>1.0.1</felix.log.version> + <findbugs.version>1.3.9</findbugs.version> + <guava.version>20.0</guava.version> + <guice.version>3.0</guice.version> + <javax.inject.version>1</javax.inject.version> + <javax.servlet-api.version>3.1.0</javax.servlet-api.version> + <jaxb.version>2.3.0</jaxb.version> + <jetty.version>9.4.30.v20200611</jetty.version> + <org.lz4.version>1.7.1</org.lz4.version> + <org.json.version>20090211</org.json.version> + <slf4j.version>1.7.5</slf4j.version> + <tensorflow.version>1.12.0</tensorflow.version> + <xml-apis.version>1.4.01</xml-apis.version> + + <hk2.version>2.5.0-b32</hk2.version> + <hk2.osgi-resource-locator.version>1.0.1</hk2.osgi-resource-locator.version> + <jackson2.version>2.8.11</jackson2.version> + <jackson-databind.version>${jackson2.version}.6</jackson-databind.version> + <javassist.version>3.20.0-GA</javassist.version> + <javax.annotation-api.version>1.2</javax.annotation-api.version> + <javax.validation-api.version>1.1.0.Final</javax.validation-api.version> + <javax.ws.rs-api.version>2.0.1</javax.ws.rs-api.version> + <jersey2.version>2.25</jersey2.version> + <mimepull.version>1.9.6</mimepull.version> + </properties> + + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-enforcer-plugin</artifactId> + <version>3.0.0-M2</version> + <executions> + <execution> + <!-- To allow running 'mvn enforcer:enforce' from the command line --> + <id>default-cli</id> + <goals> + <goal>enforce</goal> + </goals> + <configuration> + <rules> + <bannedDependencies> + <excludes> + <!-- Only allow explicitly listed dependencies --> + <exclude>*:*:*:*:*:*</exclude> + </excludes> + <includes> + <!-- MUST BE KEPT IN SYNC WITH container-dependencies-enforcer pom --> + <include>aopalliance:aopalliance:[${aopalliance.version}]:jar:provided</include> + <include>com.fasterxml.jackson.core:jackson-annotations:[${jackson2.version}]:jar:provided</include> + <include>com.fasterxml.jackson.core:jackson-core:[${jackson2.version}]:jar:provided</include> + <include>com.fasterxml.jackson.core:jackson-databind:[${jackson-databind.version}]:jar:provided</include> + <include>com.fasterxml.jackson.datatype:jackson-datatype-jdk8:[${jackson2.version}]:jar:provided</include> + <include>com.fasterxml.jackson.datatype:jackson-datatype-jsr310:[${jackson2.version}]:jar:provided</include> + + <!-- Use version range for jax deps, because jersey and junit affect the versions. --> + <include>com.fasterxml.jackson.jaxrs:jackson-jaxrs-base:[2.5.4, ${jackson2.version}]:jar:provided</include> + <include>com.fasterxml.jackson.jaxrs:jackson-jaxrs-json-provider:[2.5.4, ${jackson2.version}]:jar:provided</include> + <include>com.fasterxml.jackson.module:jackson-module-jaxb-annotations:[2.5.4, ${jackson2.version}]:jar:provided</include> + + <include>com.google.code.findbugs:jsr305:[${findbugs.version}]:jar:provided</include> + <include>com.google.guava:guava:[${guava.version}]:jar:provided</include> + <include>com.google.inject.extensions:guice-assistedinject:[${guice.version}]:jar:provided</include> + <include>com.google.inject.extensions:guice-multibindings:[${guice.version}]:jar:provided</include> + <include>com.google.inject:guice:[${guice.version}]:jar:provided:no_aop</include> + <include>com.sun.activation:javax.activation:[1.2.0]:jar:provided</include> + <include>com.sun.xml.bind:jaxb-core:[${jaxb.version}]:jar:provided</include> + <include>com.sun.xml.bind:jaxb-impl:[${jaxb.version}]:jar:provided</include> + <include>commons-logging:commons-logging:[1.1.3]:jar:provided</include> + <include>javax.annotation:javax.annotation-api:[${javax.annotation-api.version}]:jar:provided</include> + <include>javax.inject:javax.inject:[${javax.inject.version}]:jar:provided</include> + <include>javax.servlet:javax.servlet-api:[${javax.servlet-api.version}]:jar:provided</include> + <include>javax.validation:validation-api:[${javax.validation-api.version}]:jar:provided</include> + <include>javax.ws.rs:javax.ws.rs-api:[${javax.ws.rs-api.version}]:jar:provided</include> + <include>javax.xml.bind:jaxb-api:[${jaxb.version}]:jar:provided</include> + <include>net.jcip:jcip-annotations:[1.0]:jar:provided</include> + <include>org.lz4:lz4-java:[${org.lz4.version}]:jar:provided</include> + <include>org.apache.felix:org.apache.felix.framework:[${felix.version}]:jar:provided</include> + <include>org.apache.felix:org.apache.felix.log:[${felix.log.version}]:jar:provided</include> + <include>org.apache.felix:org.apache.felix.main:[${felix.version}]:jar:provided</include> + <include>org.bouncycastle:bcpkix-jdk15on:[${bouncycastle.version}]:jar:provided</include> + <include>org.bouncycastle:bcprov-jdk15on:[${bouncycastle.version}]:jar:provided</include> + <include>org.eclipse.jetty:jetty-http:[${jetty.version}]:jar:provided</include> + <include>org.eclipse.jetty:jetty-io:[${jetty.version}]:jar:provided</include> + <include>org.eclipse.jetty:jetty-util:[${jetty.version}]:jar:provided</include> + <include>org.glassfish.hk2.external:aopalliance-repackaged:[${hk2.version}]:jar:provided</include> + <include>org.glassfish.hk2.external:javax.inject:[${hk2.version}]:jar:provided</include> + <include>org.glassfish.hk2:hk2-api:[${hk2.version}]:jar:provided</include> + <include>org.glassfish.hk2:hk2-locator:[${hk2.version}]:jar:provided</include> + <include>org.glassfish.hk2:hk2-utils:[${hk2.version}]:jar:provided</include> + <include>org.glassfish.hk2:osgi-resource-locator:[${hk2.osgi-resource-locator.version}]:jar:provided</include> + <include>org.glassfish.jersey.bundles.repackaged:jersey-guava:[${jersey2.version}]:jar:provided</include> + <include>org.glassfish.jersey.containers:jersey-container-servlet-core:[${jersey2.version}]:jar:provided</include> + <include>org.glassfish.jersey.containers:jersey-container-servlet:[${jersey2.version}]:jar:provided</include> + <include>org.glassfish.jersey.core:jersey-client:[${jersey2.version}]:jar:provided</include> + <include>org.glassfish.jersey.core:jersey-common:[${jersey2.version}]:jar:provided</include> + <include>org.glassfish.jersey.core:jersey-server:[${jersey2.version}]:jar:provided</include> + <include>org.glassfish.jersey.ext:jersey-entity-filtering:[${jersey2.version}]:jar:provided</include> + <include>org.glassfish.jersey.ext:jersey-proxy-client:[${jersey2.version}]:jar:provided</include> + <include>org.glassfish.jersey.media:jersey-media-jaxb:[${jersey2.version}]:jar:provided</include> + <include>org.glassfish.jersey.media:jersey-media-json-jackson:[${jersey2.version}]:jar:provided</include> + <include>org.glassfish.jersey.media:jersey-media-multipart:[${jersey2.version}]:jar:provided</include> + <include>org.javassist:javassist:[${javassist.version}]:jar:provided</include> + <include>org.json:json:[${org.json.version}]:jar:provided</include> + <include>org.jvnet.mimepull:mimepull:[${mimepull.version}]:jar:provided</include> + <include>org.slf4j:jcl-over-slf4j:[${slf4j.version}]:jar:provided</include> + <include>org.slf4j:log4j-over-slf4j:[${slf4j.version}]:jar:provided</include> + <include>org.slf4j:slf4j-api:[${slf4j.version}]:jar:provided</include> + <include>org.slf4j:slf4j-jdk14:[${slf4j.version}]:jar:provided</include> + <include>xml-apis:xml-apis:[${xml-apis.version}]:jar:provided</include> + + <!-- Vespa provided dependencies --> + <include>com.yahoo.vespa:annotations:*:jar:provided</include> + <include>com.yahoo.vespa:chain:*:jar:provided</include> + <include>com.yahoo.vespa:component:*:jar:provided</include> + <include>com.yahoo.vespa:config-bundle:*:jar:provided</include> + <include>com.yahoo.vespa:config-lib:*:jar:provided</include> + <include>com.yahoo.vespa:config:*:jar:provided</include> + <include>com.yahoo.vespa:configdefinitions:*:jar:provided</include> + <include>com.yahoo.vespa:configgen:*:jar:provided</include> + <include>com.yahoo.vespa:container-accesslogging:*:jar:provided</include> + <include>com.yahoo.vespa:container-core:*:jar:provided</include> + <include>com.yahoo.vespa:container-dev:*:jar:provided</include> + <include>com.yahoo.vespa:container-di:*:jar:provided</include> + <include>com.yahoo.vespa:container-disc:*:jar:provided</include> + <include>com.yahoo.vespa:container-documentapi:*:jar:provided</include> + <include>com.yahoo.vespa:container-jersey2:*:jar:provided</include> + <include>com.yahoo.vespa:container-messagebus:*:jar:provided</include> + <include>com.yahoo.vespa:container-search-and-docproc:*:jar:provided</include> + <include>com.yahoo.vespa:container-search:*:jar:provided</include> + <include>com.yahoo.vespa:container:*:jar:provided</include> + <include>com.yahoo.vespa:defaults:*:jar:provided</include> + <include>com.yahoo.vespa:docproc:*:jar:provided</include> + <include>com.yahoo.vespa:document:*:jar:provided</include> + <include>com.yahoo.vespa:documentapi:*:jar:provided</include> + <include>com.yahoo.vespa:fileacquirer:*:jar:provided</include> + <include>com.yahoo.vespa:fsa:*:jar:provided</include> + <include>com.yahoo.vespa:hosted-zone-api:*:jar:provided</include> + <include>com.yahoo.vespa:http-utils:*:jar:provided</include> + <include>com.yahoo.vespa:jdisc_core:*:jar:provided</include> + <include>com.yahoo.vespa:jdisc_http_service:*:jar:provided</include> + <include>com.yahoo.vespa:jdisc_messagebus_service:*:jar:provided</include> + <include>com.yahoo.vespa:jrt:*:jar:provided</include> + <include>com.yahoo.vespa:linguistics:*:jar:provided</include> + <include>com.yahoo.vespa:messagebus-disc:*:jar:provided</include> + <include>com.yahoo.vespa:messagebus:*:jar:provided</include> + <include>com.yahoo.vespa:model-evaluation:*:jar:provided</include> + <include>com.yahoo.vespa:predicate-search-core:*:jar:provided</include> + <include>com.yahoo.vespa:processing:*:jar:provided</include> + <include>com.yahoo.vespa:provided-dependencies:*:jar:provided</include> + <include>com.yahoo.vespa:provided-yahoo-dependencies:*:jar:provided</include> + <include>com.yahoo.vespa:searchcore:*:jar:provided</include> + <include>com.yahoo.vespa:searchlib:*:jar:provided</include> + <include>com.yahoo.vespa:security-utils:*:jar:provided</include> + <include>com.yahoo.vespa:simplemetrics:*:jar:provided</include> + <include>com.yahoo.vespa:statistics:*:jar:provided</include> + <include>com.yahoo.vespa:vdslib:*:jar:provided</include> + <include>com.yahoo.vespa:vespa-http-client:*:jar:provided</include> + <include>com.yahoo.vespa:vespa_jersey2:*:pom:provided</include> + <include>com.yahoo.vespa:vespaclient-container-plugin:*:jar:provided</include> + <include>com.yahoo.vespa:vespajlib:*:jar:provided</include> + <include>com.yahoo.vespa:vespalog:*:jar:provided</include> + <include>com.yahoo.vespa:yolean:*:jar:provided</include> + + <!-- Vespa test dependencies --> + <include>com.yahoo.vespa:application:*:jar:test</include> + <include>com.yahoo.vespa:cloud-tenant-cd:*:jar:test</include> + <include>com.yahoo.vespa:config-application-package:*:jar:test</include> + <include>com.yahoo.vespa:config-model-api:*:jar:test</include> + <include>com.yahoo.vespa:config-model:*:jar:test</include> + <include>com.yahoo.vespa:config-provisioning:*:jar:test</include> + <include>com.yahoo.vespa:container-search-gui:*:jar:test</include> + <include>com.yahoo.vespa:container-test:*:jar:test</include> + <include>com.yahoo.vespa:hosted-api:*:jar:test</include> + <include>com.yahoo.vespa:indexinglanguage:*:jar:test</include> + <include>com.yahoo.vespa:jdisc_jetty:*:jar:test</include> + <include>com.yahoo.vespa:logd:*:jar:test</include> + <include>com.yahoo.vespa:metrics-proxy:*:jar:test</include> + <include>com.yahoo.vespa:metrics:*:jar:test</include> + <include>com.yahoo.vespa:model-integration:*:jar:test</include> + <include>com.yahoo.vespa:searchsummary:*:jar:test</include> + <include>com.yahoo.vespa:standalone-container:*:jar:test</include> + <include>com.yahoo.vespa:storage:*:jar:test</include> + <include>com.yahoo.vespa:tenant-cd-api:*:jar:test</include> + <include>com.yahoo.vespa:tenant-cd-commons:*:jar:test</include> + <include>com.yahoo.vespa:vespa-athenz:*:jar:test</include> + <include>com.yahoo.vespa:vespa_jersey2:*:pom:test</include> + <include>com.yahoo.vespa:vespaclient-core:*:jar:test</include> + <include>com.yahoo.vespa:vsm:*:jar:test</include> + + <!-- 3rd party test dependencies --> + <include>com.amazonaws:aws-java-sdk-core:1.11.542:jar:test</include> + <include>com.auth0:java-jwt:3.10.0:jar:test</include> + <include>com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:2.6.7:jar:test</include> + <include>com.fasterxml.jackson.dataformat:jackson-dataformat-xml:[${jackson2.version}]:jar:test</include> + <include>com.fasterxml.woodstox:woodstox-core:5.0.3:jar:test</include> + <include>com.google.protobuf:protobuf-java:3.7.0:jar:test</include> + <include>com.ibm.icu:icu4j:57.1:jar:test</include> + <include>com.intellij:annotations:12.0:jar:test</include> + <include>com.optimaize.languagedetector:language-detector:0.6:jar:test</include> + <include>com.thaiopensource:jing:20091111:jar:test</include> + <include>com.yahoo.athenz:athenz-auth-core:[${athenz.version}]:jar:test</include> + <include>com.yahoo.athenz:athenz-client-common:[${athenz.version}]:jar:test</include> + <include>com.yahoo.athenz:athenz-zms-core:[${athenz.version}]:jar:test</include> + <include>com.yahoo.athenz:athenz-zpe-java-client:[${athenz.version}]:jar:test</include> + <include>com.yahoo.athenz:athenz-zts-core:[${athenz.version}]:jar:test</include> + <include>com.yahoo.rdl:rdl-java:1.5.2:jar:test</include> + <include>commons-beanutils:commons-beanutils-core:1.8.0:jar:test</include> + <include>commons-beanutils:commons-beanutils:1.7.0:jar:test</include> + <include>commons-codec:commons-codec:1.11:jar:test</include> + <include>commons-digester:commons-digester:1.8:jar:test</include> + <include>io.airlift:airline:0.7:jar:test</include> + <include>io.jsonwebtoken:jjwt:0.9.1:jar:test</include> + <include>io.prometheus:simpleclient:0.6.0:jar:test</include> + <include>io.prometheus:simpleclient_common:0.6.0:jar:test</include> + <include>joda-time:joda-time:2.8.1:jar:test</include> + <include>net.arnx:jsonic:1.2.11:jar:test</include> + <include>net.java.dev.jna:jna:4.5.2:jar:test</include> + <include>org.abego.treelayout:org.abego.treelayout.core:1.0.1:jar:test</include> + <include>org.antlr:antlr-runtime:3.5.2:jar:test</include> + <include>org.antlr:antlr4-runtime:4.5:jar:test</include> + <include>org.apache.commons:commons-exec:1.3:jar:test</include> + <include>org.apache.commons:commons-math3:3.6.1:jar:test</include> + <include>org.apache.httpcomponents:httpclient:4.5.12:jar:test</include> + <include>org.apache.httpcomponents:httpcore:4.4.13:jar:test</include> + <include>org.apache.opennlp:opennlp-tools:1.8.4:jar:test</include> + <include>org.apiguardian:apiguardian-api:1.1.0:jar:test</include> + <include>org.codehaus.woodstox:stax2-api:3.1.4:jar:test</include> + <include>org.eclipse.jetty:jetty-continuation:[${jetty.version}]:jar:test</include> + <include>org.eclipse.jetty:jetty-jmx:[${jetty.version}]:jar:test</include> + <include>org.eclipse.jetty:jetty-security:[${jetty.version}]:jar:test</include> + <include>org.eclipse.jetty:jetty-server:[${jetty.version}]:jar:test</include> + <include>org.eclipse.jetty:jetty-servlet:[${jetty.version}]:jar:test</include> + <include>org.eclipse.jetty:jetty-servlets:[${jetty.version}]:jar:test</include> + <include>org.hdrhistogram:HdrHistogram:2.1.8:jar:test</include> + <include>org.junit.jupiter:junit-jupiter-api:5.6.2:jar:test</include> + <include>org.junit.jupiter:junit-jupiter-engine:5.6.2:jar:test</include> + <include>org.junit.platform:junit-platform-commons:1.6.2:jar:test</include> + <include>org.junit.platform:junit-platform-engine:1.6.2:jar:test</include> + <include>org.kohsuke:libpam4j:1.11:jar:test</include> + <include>org.opentest4j:opentest4j:1.2.0:jar:test</include> + <include>org.tensorflow:libtensorflow:[${tensorflow.version}]:jar:test</include> + <include>org.tensorflow:libtensorflow_jni:[${tensorflow.version}]:jar:test</include> + <include>org.tensorflow:proto:[${tensorflow.version}]:jar:test</include> + <include>org.tensorflow:tensorflow:[${tensorflow.version}]:jar:test</include> + <include>software.amazon.ion:ion-java:1.0.2:jar:test</include> + <include>xerces:xercesImpl:2.12.0:jar:test</include> + </includes> + </bannedDependencies> + </rules> + <fail>true</fail> + </configuration> + </execution> + </executions> + </plugin> + </plugins> + </build> +</project> diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/http/AccessControl.java b/config-model/src/main/java/com/yahoo/vespa/model/container/http/AccessControl.java index efde2d43350..f04edeb67f4 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/http/AccessControl.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/http/AccessControl.java @@ -144,8 +144,7 @@ public class AccessControl { for (FilterBinding binding : http.getBindings()) { if (binding.chainId().toId().equals(chainId)) { for (FilterBinding otherBinding : http.getBindings()) { - if (!binding.chainId().equals(otherBinding.chainId()) - && effectivelyDuplicateOf(binding.binding(), otherBinding.binding())) { + if (effectivelyDuplicateOf(binding, otherBinding)) { duplicateBindings.add(binding); } } @@ -154,14 +153,17 @@ public class AccessControl { duplicateBindings.forEach(http.getBindings()::remove); } - private static boolean effectivelyDuplicateOf(BindingPattern accessControlBinding, BindingPattern other) { - return accessControlBinding.equals(other) - || (accessControlBinding.path().equals(other.path()) && other.matchesAnyPort()); + private static boolean effectivelyDuplicateOf(FilterBinding accessControlBinding, FilterBinding other) { + if (accessControlBinding.chainId().equals(other.chainId())) return false; // Same filter chain + if (other.type() == FilterBinding.Type.RESPONSE) return false; + return accessControlBinding.binding().equals(other.binding()) + || (accessControlBinding.binding().path().equals(other.binding().path()) && other.binding().matchesAnyPort()); } private static FilterBinding createAccessControlBinding(String path) { return FilterBinding.create( + FilterBinding.Type.REQUEST, new ComponentSpecification(ACCESS_CONTROL_CHAIN_ID.stringValue()), SystemBindingPattern.fromHttpPortAndPath(Integer.toString(HOSTED_CONTAINER_PORT), path)); } @@ -170,6 +172,7 @@ public class AccessControl { BindingPattern rewrittenBinding = SystemBindingPattern.fromHttpPortAndPath( Integer.toString(HOSTED_CONTAINER_PORT), excludedBinding.path()); // only keep path from excluded binding return FilterBinding.create( + FilterBinding.Type.REQUEST, new ComponentSpecification(ACCESS_CONTROL_EXCLUDED_CHAIN_ID.stringValue()), rewrittenBinding); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/http/FilterBinding.java b/config-model/src/main/java/com/yahoo/vespa/model/container/http/FilterBinding.java index 1ca54769683..2921cdc9f11 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/http/FilterBinding.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/http/FilterBinding.java @@ -11,16 +11,20 @@ import java.util.Objects; */ public class FilterBinding { + public enum Type {REQUEST, RESPONSE} + + private final Type type; private final ComponentSpecification chainId; private final BindingPattern binding; - private FilterBinding(ComponentSpecification chainId, BindingPattern binding) { + private FilterBinding(Type type, ComponentSpecification chainId, BindingPattern binding) { + this.type = type; this.chainId = chainId; this.binding = binding; } - public static FilterBinding create(ComponentSpecification chainId, BindingPattern binding) { - return new FilterBinding(chainId, binding); + public static FilterBinding create(Type type, ComponentSpecification chainId, BindingPattern binding) { + return new FilterBinding(type, chainId, binding); } public ComponentSpecification chainId() { @@ -31,17 +35,20 @@ public class FilterBinding { return binding; } + public Type type() { return type; } + @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; FilterBinding that = (FilterBinding) o; - return Objects.equals(chainId, that.chainId) && + return type == that.type && + Objects.equals(chainId, that.chainId) && Objects.equals(binding, that.binding); } @Override public int hashCode() { - return Objects.hash(chainId, binding); + return Objects.hash(type, chainId, binding); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/http/xml/HttpBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/http/xml/HttpBuilder.java index c86d8b206d5..5b360b478fa 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/http/xml/HttpBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/http/xml/HttpBuilder.java @@ -31,6 +31,10 @@ import java.util.logging.Level; */ public class HttpBuilder extends VespaDomBuilder.DomConfigProducerBuilder<Http> { + static final String REQUEST_CHAIN_TAG_NAME = "request-chain"; + static final String RESPONSE_CHAIN_TAG_NAME = "response-chain"; + static final List<String> VALID_FILTER_CHAIN_TAG_NAMES = List.of(REQUEST_CHAIN_TAG_NAME, RESPONSE_CHAIN_TAG_NAME); + @Override protected Http doBuild(DeployState deployState, AbstractConfigProducer ancestor, Element spec) { FilterChains filterChains; @@ -116,18 +120,26 @@ public class HttpBuilder extends VespaDomBuilder.DomConfigProducerBuilder<Http> for (Element child: XML.getChildren(filteringSpec)) { String tagName = child.getTagName(); - if ((tagName.equals("request-chain") || tagName.equals("response-chain"))) { + if (VALID_FILTER_CHAIN_TAG_NAMES.contains(tagName)) { ComponentSpecification chainId = XmlHelper.getIdRef(child); for (Element bindingSpec: XML.getChildren(child, "binding")) { String binding = XML.getValue(bindingSpec); - result.add(FilterBinding.create(chainId, UserBindingPattern.fromPattern(binding))); + result.add(FilterBinding.create(toFilterBindingType(tagName), chainId, UserBindingPattern.fromPattern(binding))); } } } return result; } + private static FilterBinding.Type toFilterBindingType(String chainTag) { + switch (chainTag) { + case REQUEST_CHAIN_TAG_NAME: return FilterBinding.Type.REQUEST; + case RESPONSE_CHAIN_TAG_NAME: return FilterBinding.Type.RESPONSE; + default: throw new IllegalArgumentException("Unknown filter chain tag: " + chainTag); + } + } + static int readPort(ModelElement spec, boolean isHosted, DeployLogger logger) { Integer port = spec.integerAttribute("port"); if (port == null) diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/AccessControlTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/AccessControlTest.java index f5d0c2d1825..92b54a4679d 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/AccessControlTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/AccessControlTest.java @@ -192,7 +192,7 @@ public class AccessControlTest extends ContainerModelBuilderTestBase { } @Test - public void access_control_chains_does_not_contain_duplicate_bindings_to_user_filter_chain() { + public void access_control_chains_does_not_contain_duplicate_bindings_to_user_request_filter_chain() { Http http = createModelAndGetHttp( " <http>", " <handler id='custom.Handler'>", @@ -227,6 +227,46 @@ public class AccessControlTest extends ContainerModelBuilderTestBase { assertThat(actualCustomChainBindings, containsInAnyOrder("http://*/custom-handler/*", "http://*/")); } + @Test + public void access_control_excludes_are_not_affected_by_user_response_filter_chain() { + Http http = createModelAndGetHttp( + " <http>", + " <handler id='custom.Handler'>", + " <binding>http://*/custom-handler/*</binding>", + " </handler>", + " <filtering>", + " <access-control>", + " <exclude>", + " <binding>http://*/custom-handler/*</binding>", + " </exclude>", + " </access-control>", + " <response-chain id='my-custom-response-chain'>", + " <filter id='my-custom-response-filter' />", + " <binding>http://*/custom-handler/*</binding>", + " </response-chain>", + " </filtering>", + " </http>"); + + Set<String> actualExcludeBindings = getFilterBindings(http, AccessControl.ACCESS_CONTROL_EXCLUDED_CHAIN_ID); + assertThat(actualExcludeBindings, containsInAnyOrder( + "http://*:4443/ApplicationStatus", + "http://*:4443/status.html", + "http://*:4443/state/v1", + "http://*:4443/state/v1/*", + "http://*:4443/prometheus/v1", + "http://*:4443/prometheus/v1/*", + "http://*:4443/metrics/v2", + "http://*:4443/metrics/v2/*", + "http://*:4443/", + "http://*:4443/custom-handler/*")); + + Set<String> actualAccessControlBindings = getFilterBindings(http, AccessControl.ACCESS_CONTROL_CHAIN_ID); + assertThat(actualAccessControlBindings, containsInAnyOrder("http://*:4443/*")); + + Set<String> actualCustomChainBindings = getFilterBindings(http, ComponentId.fromString("my-custom-response-chain")); + assertThat(actualCustomChainBindings, containsInAnyOrder("http://*/custom-handler/*")); + } + private Http createModelAndGetHttp(String... httpElement) { List<String> servicesXml = new ArrayList<>(); servicesXml.add("<container version='1.0'>"); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java index 6a3a715e58d..eb13baf3e6b 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java @@ -65,9 +65,7 @@ import com.yahoo.vespa.config.server.tenant.TenantRepository; import com.yahoo.vespa.curator.Curator; import com.yahoo.vespa.curator.Lock; import com.yahoo.vespa.defaults.Defaults; -import com.yahoo.vespa.flags.BooleanFlag; import com.yahoo.vespa.flags.FlagSource; -import com.yahoo.vespa.flags.Flags; import com.yahoo.vespa.flags.InMemoryFlagSource; import com.yahoo.vespa.orchestrator.Orchestrator; @@ -127,7 +125,6 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye private final LogRetriever logRetriever; private final TesterClient testerClient; private final Metric metric; - private final BooleanFlag useTenantMetaData; @Inject public ApplicationRepository(TenantRepository tenantRepository, @@ -177,7 +174,6 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye this.clock = Objects.requireNonNull(clock); this.testerClient = Objects.requireNonNull(testerClient); this.metric = Objects.requireNonNull(metric); - this.useTenantMetaData = Flags.USE_TENANT_META_DATA.bindTo(flagSource); } public static class Builder { @@ -410,10 +406,7 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye checkIfActiveIsNewerThanSessionToBeActivated(prepared.getSessionId(), active.getSessionId()); transaction.add(active.createDeactivateTransaction().operations()); } - - if (useTenantMetaData.value()) - transaction.add(updateMetaDataWithDeployTimestamp(tenant, clock.instant())); - + transaction.add(updateMetaDataWithDeployTimestamp(tenant, clock.instant())); return transaction; } @@ -876,8 +869,6 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye } public Set<TenantName> deleteUnusedTenants(Duration ttlForUnusedTenant, Instant now) { - if ( ! useTenantMetaData.value()) return Set.of(); - return tenantRepository.getAllTenantNames().stream() .filter(tenantName -> activeApplications(tenantName).isEmpty()) .filter(tenantName -> !tenantName.equals(TenantName.defaultName())) // Not allowed to remove 'default' tenant diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java index ecbcb513c03..41377bdf317 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java @@ -18,8 +18,6 @@ import com.yahoo.vespa.config.server.session.SessionRepository; import com.yahoo.vespa.curator.Curator; import com.yahoo.vespa.curator.transaction.CuratorOperations; import com.yahoo.vespa.curator.transaction.CuratorTransaction; -import com.yahoo.vespa.flags.BooleanFlag; -import com.yahoo.vespa.flags.Flags; import org.apache.curator.framework.CuratorFramework; import org.apache.curator.framework.recipes.cache.PathChildrenCacheEvent; import org.apache.curator.framework.state.ConnectionState; @@ -88,7 +86,6 @@ public class TenantRepository { private final ExecutorService bootstrapExecutor; private final ScheduledExecutorService checkForRemovedApplicationsService = new ScheduledThreadPoolExecutor(1); private final Optional<Curator.DirectoryCache> directoryCache; - private final BooleanFlag useTenantMetaData; /** * Creates a new tenant repository @@ -105,7 +102,6 @@ public class TenantRepository { this.tenantListeners.add(componentRegistry.getTenantListener()); this.zkCacheExecutor = componentRegistry.getZkCacheExecutor(); this.zkWatcherExecutor = componentRegistry.getZkWatcherExecutor(); - this.useTenantMetaData = Flags.USE_TENANT_META_DATA.bindTo(componentRegistry.getFlagSource()); curator.framework().getConnectionStateListenable().addListener(this::stateChanged); curator.create(tenantsPath); @@ -230,9 +226,7 @@ public class TenantRepository { private Tenant createTenant(TenantName tenantName, Instant created) { if (tenants.containsKey(tenantName)) { Tenant tenant = getTenant(tenantName); - if (useTenantMetaData.value()) - createAndWriteTenantMetaData(tenant); - + createAndWriteTenantMetaData(tenant); return tenant; } @@ -255,8 +249,7 @@ public class TenantRepository { Tenant tenant = new Tenant(tenantName, sessionRepository, applicationRepo, applicationRepo, created); notifyNewTenant(tenant); tenants.putIfAbsent(tenantName, tenant); - if (useTenantMetaData.value()) - createAndWriteTenantMetaData(tenant); + createAndWriteTenantMetaData(tenant); return tenant; } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java index bd3053dd61f..f879f6c2a2a 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java @@ -165,25 +165,11 @@ public class ApplicationRepositoryTest { LocalSession session = tenant.getSessionRepository().getLocalSession(tenant.getApplicationRepo() .requireActiveSessionOf(applicationId())); session.getAllocatedHosts(); - - assertEquals(Instant.EPOCH, applicationRepository.getTenantMetaData(tenant).lastDeployTimestamp()); - assertEquals(Instant.EPOCH, applicationRepository.getTenantMetaData(tenant).createdTimestamp()); } @Test - public void prepareAndActivateWithTenantMetaData() throws IOException { - InMemoryFlagSource flagSource = new InMemoryFlagSource().withBooleanFlag(Flags.USE_TENANT_META_DATA.id(), false); - setup(flagSource); - - // Tenants created when flag is false has EPOCH as metadata values - Tenant tenant = applicationRepository.getTenant(applicationId()); - assertEquals(Instant.EPOCH.toEpochMilli(), - applicationRepository.getTenantMetaData(tenant).createdTimestamp().toEpochMilli()); - assertEquals(Instant.EPOCH.toEpochMilli(), - applicationRepository.getTenantMetaData(tenant).lastDeployTimestamp().toEpochMilli()); - - // Change flag value to true - flagSource.withBooleanFlag(Flags.USE_TENANT_META_DATA.id(), true); + public void prepareAndActivateWithTenantMetaData() { + Instant startTime = clock.instant(); Duration duration = Duration.ofHours(1); clock.advance(duration); Instant deployTime = clock.instant(); @@ -191,12 +177,9 @@ public class ApplicationRepositoryTest { assertTrue(result.configChangeActions().getRefeedActions().isEmpty()); assertTrue(result.configChangeActions().getRestartActions().isEmpty()); - LocalSession session = tenant.getSessionRepository().getLocalSession(tenant.getApplicationRepo() - .requireActiveSessionOf(applicationId())); - session.getAllocatedHosts(); + Tenant tenant = applicationRepository.getTenant(applicationId()); - // Only last deploy timestamp updated - assertEquals(Instant.EPOCH.toEpochMilli(), + assertEquals(startTime.toEpochMilli(), applicationRepository.getTenantMetaData(tenant).createdTimestamp().toEpochMilli()); assertEquals(deployTime.toEpochMilli(), applicationRepository.getTenantMetaData(tenant).lastDeployTimestamp().toEpochMilli()); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/maintenance/MaintainerTester.java b/configserver/src/test/java/com/yahoo/vespa/config/server/maintenance/MaintainerTester.java index 043841c6acb..7999f9280c0 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/maintenance/MaintainerTester.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/maintenance/MaintainerTester.java @@ -23,7 +23,6 @@ import com.yahoo.vespa.config.server.session.PrepareParams; import com.yahoo.vespa.config.server.tenant.TenantRepository; import com.yahoo.vespa.curator.Curator; import com.yahoo.vespa.curator.mock.MockCurator; -import com.yahoo.vespa.flags.FlagSource; import org.junit.rules.TemporaryFolder; import java.io.File; @@ -39,7 +38,7 @@ class MaintainerTester { private final ApplicationRepository applicationRepository; private final Clock clock; - MaintainerTester(Clock clock, FlagSource flagSource, TemporaryFolder temporaryFolder) throws IOException { + MaintainerTester(Clock clock, TemporaryFolder temporaryFolder) throws IOException { this.clock = clock; this.curator = new MockCurator(); InMemoryProvisioner hostProvisioner = new InMemoryProvisioner(true, "host0", "host1", "host2", "host3", "host4"); @@ -54,7 +53,6 @@ class MaintainerTester { .clock(clock) .configServerConfig(configserverConfig) .provisioner(provisioner) - .flagSource(flagSource) .modelFactoryRegistry(new ModelFactoryRegistry(List.of(new DeployTester.CountingModelFactory(clock)))) .build(); tenantRepository = new TenantRepository(componentRegistry); @@ -63,7 +61,6 @@ class MaintainerTester { .withProvisioner(provisioner) .withOrchestrator(new OrchestratorMock()) .withLogRetriever(new MockLogRetriever()) - .withFlagSource(flagSource) .withClock(clock) .withConfigserverConfig(configserverConfig) .build(); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/maintenance/TenantsMaintainerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/maintenance/TenantsMaintainerTest.java index e6172546ff8..463e5361248 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/maintenance/TenantsMaintainerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/maintenance/TenantsMaintainerTest.java @@ -9,8 +9,6 @@ import com.yahoo.test.ManualClock; import com.yahoo.vespa.config.server.ApplicationRepository; import com.yahoo.vespa.config.server.session.PrepareParams; import com.yahoo.vespa.config.server.tenant.TenantRepository; -import com.yahoo.vespa.flags.FlagSource; -import com.yahoo.vespa.flags.Flags; import com.yahoo.vespa.flags.InMemoryFlagSource; import org.junit.Rule; import org.junit.Test; @@ -31,8 +29,7 @@ public class TenantsMaintainerTest { @Test public void deleteTenantWithNoApplications() throws IOException { ManualClock clock = new ManualClock("2020-06-01T00:00:00"); - FlagSource flagSource = new InMemoryFlagSource().withBooleanFlag(Flags.USE_TENANT_META_DATA.id(), true); - MaintainerTester tester = new MaintainerTester(clock, flagSource, temporaryFolder); + MaintainerTester tester = new MaintainerTester(clock, temporaryFolder); TenantRepository tenantRepository = tester.tenantRepository(); ApplicationRepository applicationRepository = tester.applicationRepository(); File applicationPackage = new File("src/test/apps/hosted"); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationMetaDataGarbageCollector.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationMetaDataGarbageCollector.java new file mode 100644 index 00000000000..9fa3b91f633 --- /dev/null +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ApplicationMetaDataGarbageCollector.java @@ -0,0 +1,29 @@ +package com.yahoo.vespa.hosted.controller.maintenance; + +import com.yahoo.vespa.hosted.controller.Controller; + +import java.time.Duration; +import java.util.logging.Level; +import java.util.logging.Logger; + +public class ApplicationMetaDataGarbageCollector extends ControllerMaintainer { + + private static final Logger log = Logger.getLogger(ApplicationMetaDataGarbageCollector.class.getName()); + + public ApplicationMetaDataGarbageCollector(Controller controller, Duration interval) { + super(controller, interval); + } + + @Override + protected boolean maintain() { + try { + controller().applications().applicationStore().pruneMeta(controller().clock().instant().minus(Duration.ofDays(365))); + return true; + } + catch (Exception e) { + log.log(Level.WARNING, "Exception pruning old application meta data", e); + return false; + } + } + +} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java index 336dc5ddd04..0e72a1b42a7 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java @@ -44,6 +44,7 @@ public class ControllerMaintenance extends AbstractComponent { private final RotationStatusUpdater rotationStatusUpdater; private final ResourceTagMaintainer resourceTagMaintainer; private final SystemRoutingPolicyMaintainer systemRoutingPolicyMaintainer; + private final ApplicationMetaDataGarbageCollector applicationMetaDataGarbageCollector; @Inject @SuppressWarnings("unused") // instantiated by Dependency Injection @@ -73,6 +74,7 @@ public class ControllerMaintenance extends AbstractComponent { rotationStatusUpdater = new RotationStatusUpdater(controller, maintenanceInterval); resourceTagMaintainer = new ResourceTagMaintainer(controller, Duration.ofMinutes(30), controller.serviceRegistry().resourceTagger()); systemRoutingPolicyMaintainer = new SystemRoutingPolicyMaintainer(controller, Duration.ofMinutes(10)); + applicationMetaDataGarbageCollector = new ApplicationMetaDataGarbageCollector(controller, Duration.ofHours(12)); } public Upgrader upgrader() { return upgrader; } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json index acd542b001c..385f0fbc3cf 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json @@ -1,6 +1,9 @@ { "jobs": [ { + "name": "ApplicationMetaDataGarbageCollector" + }, + { "name": "ApplicationOwnershipConfirmer" }, { diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index 7695fbae627..f33ca549cf7 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -368,7 +368,7 @@ public class Flags { public static final UnboundBooleanFlag USE_TENANT_META_DATA = defineFeatureFlag( "use-tenant-meta-data", - false, + true, "Whether config server should write and read tenant metadata", "Takes effect immediately" ); @@ -32,6 +32,7 @@ <module>chain</module> <module>client</module> <module>cloud-tenant-base</module> + <module>cloud-tenant-base-dependencies-enforcer</module> <module>cloud-tenant-cd</module> <module>clustercontroller-apps</module> <module>clustercontroller-apputil</module> diff --git a/searchcore/src/apps/tests/persistenceconformance_test.cpp b/searchcore/src/apps/tests/persistenceconformance_test.cpp index 44fb2770594..26a44606898 100644 --- a/searchcore/src/apps/tests/persistenceconformance_test.cpp +++ b/searchcore/src/apps/tests/persistenceconformance_test.cpp @@ -127,6 +127,7 @@ public: 1, std::make_shared<RankProfilesConfig>(), std::make_shared<matching::RankingConstants>(), + std::make_shared<matching::OnnxModels>(), indexschema, attributes, summary, diff --git a/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp b/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp index 118cad4d8ef..7043c450047 100644 --- a/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp +++ b/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp @@ -12,6 +12,7 @@ #include <vespa/eval/tensor/default_tensor_engine.h> #include <vespa/searchcommon/common/schemaconfigurer.h> #include <vespa/searchcore/config/config-ranking-constants.h> +#include <vespa/searchcore/config/config-onnx-models.h> #include <vespa/searchcore/proton/matching/indexenvironment.h> #include <vespa/searchlib/features/setup.h> #include <vespa/searchlib/fef/fef.h> @@ -28,10 +29,12 @@ using config::ConfigSubscriber; using config::IConfigContext; using config::InvalidConfigException; using proton::matching::IConstantValueRepo; +using proton::matching::OnnxModels; using vespa::config::search::AttributesConfig; using vespa::config::search::IndexschemaConfig; using vespa::config::search::RankProfilesConfig; using vespa::config::search::core::RankingConstantsConfig; +using vespa::config::search::core::OnnxModelsConfig; using vespalib::eval::ConstantValue; using vespalib::eval::TensorSpec; using vespalib::eval::ValueType; @@ -39,17 +42,30 @@ using vespalib::tensor::DefaultTensorEngine; using vespalib::eval::SimpleConstantValue; using vespalib::eval::BadConstantValue; +OnnxModels make_models(const OnnxModelsConfig &modelsCfg) { + OnnxModels::Vector model_list; + for (const auto &entry: modelsCfg.model) { + // TODO(havardpe): resolve model path + vespalib::string model_path = entry.name; + model_path += ".onnx"; + model_list.emplace_back(entry.name, model_path); + } + return OnnxModels(model_list); +} + class App : public FastOS_Application { public: bool verify(const search::index::Schema &schema, const search::fef::Properties &props, - const IConstantValueRepo &repo); + const IConstantValueRepo &repo, + OnnxModels models); bool verifyConfig(const RankProfilesConfig &rankCfg, const IndexschemaConfig &schemaCfg, const AttributesConfig &attributeCfg, - const RankingConstantsConfig &constantsCfg); + const RankingConstantsConfig &constantsCfg, + const OnnxModelsConfig &modelsCfg); int usage(); int Main() override; @@ -77,9 +93,10 @@ struct DummyConstantValueRepo : IConstantValueRepo { bool App::verify(const search::index::Schema &schema, const search::fef::Properties &props, - const IConstantValueRepo &repo) + const IConstantValueRepo &repo, + OnnxModels models) { - proton::matching::IndexEnvironment indexEnv(0, schema, props, repo); + proton::matching::IndexEnvironment indexEnv(0, schema, props, repo, models); search::fef::BlueprintFactory factory; search::features::setup_search_features(factory); search::fef::test::setup_fef_test_plugin(factory); @@ -106,13 +123,15 @@ bool App::verifyConfig(const RankProfilesConfig &rankCfg, const IndexschemaConfig &schemaCfg, const AttributesConfig &attributeCfg, - const RankingConstantsConfig &constantsCfg) + const RankingConstantsConfig &constantsCfg, + const OnnxModelsConfig &modelsCfg) { bool ok = true; search::index::Schema schema; search::index::SchemaBuilder::build(schemaCfg, schema); search::index::SchemaBuilder::build(attributeCfg, schema); DummyConstantValueRepo repo(constantsCfg); + auto models = make_models(modelsCfg); for(size_t i = 0; i < rankCfg.rankprofile.size(); i++) { search::fef::Properties properties; const RankProfilesConfig::Rankprofile &profile = rankCfg.rankprofile[i]; @@ -120,7 +139,7 @@ App::verifyConfig(const RankProfilesConfig &rankCfg, properties.add(profile.fef.property[j].name, profile.fef.property[j].value); } - if (verify(schema, properties, repo)) { + if (verify(schema, properties, repo, models)) { LOG(info, "rank profile '%s': pass", profile.name.c_str()); } else { LOG(error, "rank profile '%s': FAIL", profile.name.c_str()); @@ -157,12 +176,14 @@ App::Main() ConfigHandle<AttributesConfig>::UP attributesHandle = subscriber.subscribe<AttributesConfig>(cfgId); ConfigHandle<IndexschemaConfig>::UP schemaHandle = subscriber.subscribe<IndexschemaConfig>(cfgId); ConfigHandle<RankingConstantsConfig>::UP constantsHandle = subscriber.subscribe<RankingConstantsConfig>(cfgId); + ConfigHandle<OnnxModelsConfig>::UP modelsHandle = subscriber.subscribe<OnnxModelsConfig>(cfgId); subscriber.nextConfig(); ok = verifyConfig(*rankHandle->getConfig(), *schemaHandle->getConfig(), *attributesHandle->getConfig(), - *constantsHandle->getConfig()); + *constantsHandle->getConfig(), + *modelsHandle->getConfig()); } catch (ConfigRuntimeException & e) { LOG(error, "Unable to subscribe to config: %s", e.getMessage().c_str()); } catch (InvalidConfigException & e) { diff --git a/searchcore/src/tests/proton/documentdb/configurer/configurer_test.cpp b/searchcore/src/tests/proton/documentdb/configurer/configurer_test.cpp index 1da35c9f5c3..b2903f00226 100644 --- a/searchcore/src/tests/proton/documentdb/configurer/configurer_test.cpp +++ b/searchcore/src/tests/proton/documentdb/configurer/configurer_test.cpp @@ -646,6 +646,7 @@ TEST("require that maintenance controller should change if some config has chang TEST_DO(assertMaintenanceControllerShouldChange(CCR().setRankProfilesChanged(true))); TEST_DO(assertMaintenanceControllerShouldChange(CCR().setRankingConstantsChanged(true))); + TEST_DO(assertMaintenanceControllerShouldChange(CCR().setOnnxModelsChanged(true))); TEST_DO(assertMaintenanceControllerShouldChange(CCR().setIndexschemaChanged(true))); TEST_DO(assertMaintenanceControllerShouldChange(CCR().setAttributesChanged(true))); TEST_DO(assertMaintenanceControllerShouldChange(CCR().setSummaryChanged(true))); @@ -692,6 +693,7 @@ TEST("require that subdbs should change if relevant config changed") TEST_DO(assertSubDbsShouldChange(CCR().setVisibilityDelayChanged(true))); TEST_DO(assertSubDbsShouldChange(CCR().setRankProfilesChanged(true))); TEST_DO(assertSubDbsShouldChange(CCR().setRankingConstantsChanged(true))); + TEST_DO(assertSubDbsShouldChange(CCR().setOnnxModelsChanged(true))); TEST_DO(assertSubDbsShouldChange(CCR().setSchemaChanged(true))); } diff --git a/searchcore/src/tests/proton/documentdb/documentdbconfig/documentdbconfig_test.cpp b/searchcore/src/tests/proton/documentdb/documentdbconfig/documentdbconfig_test.cpp index a2b824b88ba..aed01ca0192 100644 --- a/searchcore/src/tests/proton/documentdb/documentdbconfig/documentdbconfig_test.cpp +++ b/searchcore/src/tests/proton/documentdb/documentdbconfig/documentdbconfig_test.cpp @@ -17,6 +17,7 @@ using namespace search::index; using namespace search; using namespace vespa::config::search; using proton::matching::RankingConstants; +using proton::matching::OnnxModels; using std::make_shared; using std::shared_ptr; using document::config_builder::DocumenttypesConfigBuilderHelper; @@ -68,6 +69,11 @@ public: _builder.rankingConstants(make_shared<RankingConstants>(constants)); return *this; } + MyConfigBuilder &addOnnxModel() { + OnnxModels::Vector models = {{"my_model_name", "my_model_file"}}; + _builder.onnxModels(make_shared<OnnxModels>(models)); + return *this; + } MyConfigBuilder &addImportedField() { ImportedFieldsConfigBuilder builder; builder.attribute.resize(1); @@ -132,6 +138,7 @@ struct Fixture { fullCfg = MyConfigBuilder(4, schema, repo).addAttribute(). addRankProfile(). addRankingConstant(). + addOnnxModel(). addImportedField(). addSummary(true). addSummarymap(). @@ -166,12 +173,14 @@ struct DelayAttributeAspectFixture { attrCfg = MyConfigBuilder(4, schema, makeDocTypeRepo(true)).addAttribute(). addRankProfile(). addRankingConstant(). + addOnnxModel(). addImportedField(). addSummary(true). addSummarymap(). build(); noAttrCfg = MyConfigBuilder(4, schema, makeDocTypeRepo(hasDocField)).addRankProfile(). addRankingConstant(). + addOnnxModel(). addImportedField(). addSummary(hasDocField). build(); diff --git a/searchcore/src/tests/proton/documentdb/fileconfigmanager/fileconfigmanager_test.cpp b/searchcore/src/tests/proton/documentdb/fileconfigmanager/fileconfigmanager_test.cpp index 2782117d8ae..2352fda65a0 100644 --- a/searchcore/src/tests/proton/documentdb/fileconfigmanager/fileconfigmanager_test.cpp +++ b/searchcore/src/tests/proton/documentdb/fileconfigmanager/fileconfigmanager_test.cpp @@ -28,6 +28,7 @@ using namespace vespa::config::search; using namespace std::chrono_literals; using vespa::config::content::core::BucketspacesConfig; using proton::matching::RankingConstants; +using proton::matching::OnnxModels; typedef DocumentDBConfigHelper DBCM; typedef DocumentDBConfig::DocumenttypesConfigSP DocumenttypesConfigSP; @@ -77,7 +78,9 @@ assertEqualSnapshot(const DocumentDBConfig &exp, const DocumentDBConfig &act) { EXPECT_TRUE(exp.getRankProfilesConfig() == act.getRankProfilesConfig()); EXPECT_TRUE(exp.getRankingConstants() == act.getRankingConstants()); + EXPECT_TRUE(exp.getOnnxModels() == act.getOnnxModels()); EXPECT_EQUAL(0u, exp.getRankingConstants().size()); + EXPECT_EQUAL(0u, exp.getOnnxModels().size()); EXPECT_TRUE(exp.getIndexschemaConfig() == act.getIndexschemaConfig()); EXPECT_TRUE(exp.getAttributesConfig() == act.getAttributesConfig()); EXPECT_TRUE(exp.getSummaryConfig() == act.getSummaryConfig()); @@ -105,6 +108,9 @@ addConfigsThatAreNotSavedToDisk(const DocumentDBConfig &cfg) RankingConstants::Vector constants = {{"my_name", "my_type", "my_path"}}; builder.rankingConstants(std::make_shared<RankingConstants>(constants)); + OnnxModels::Vector models = {{"my_model_name", "my_model_file"}}; + builder.onnxModels(std::make_shared<OnnxModels>(models)); + ImportedFieldsConfigBuilder importedFields; importedFields.attribute.resize(1); importedFields.attribute.back().name = "my_name"; diff --git a/searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp b/searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp index 932ab6f4d14..508a60480d0 100644 --- a/searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp +++ b/searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp @@ -14,6 +14,13 @@ using search::index::schema::DataType; using vespalib::eval::ConstantValue; using SIAF = Schema::ImportedAttributeField; +OnnxModels make_models() { + OnnxModels::Vector list; + list.emplace_back("model1", "path1"); + list.emplace_back("model2", "path2"); + return OnnxModels(list); +} + struct MyConstantValueRepo : public IConstantValueRepo { virtual ConstantValue::UP getConstant(const vespalib::string &) const override { return ConstantValue::UP(); @@ -42,7 +49,7 @@ struct Fixture { Fixture(Schema::UP schema_) : repo(), schema(std::move(schema_)), - env(7, *schema, Properties(), repo) + env(7, *schema, Properties(), repo, make_models()) { } const FieldInfo *assertField(size_t idx, @@ -97,4 +104,10 @@ TEST_F("require that imported attribute fields are extracted in index environmen EXPECT_EQUAL("[documentmetastore]", f.env.getField(2)->name()); } +TEST_F("require that onnx model paths can be obtained", Fixture(buildEmptySchema())) { + EXPECT_EQUAL(f1.env.getOnnxModelFullPath("model1").value(), vespalib::string("path1")); + EXPECT_EQUAL(f1.env.getOnnxModelFullPath("model2").value(), vespalib::string("path2")); + EXPECT_FALSE(f1.env.getOnnxModelFullPath("model3").has_value()); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchcore/src/tests/proton/matching/matching_test.cpp b/searchcore/src/tests/proton/matching/matching_test.cpp index 9d5b67af81c..0ea63bce859 100644 --- a/searchcore/src/tests/proton/matching/matching_test.cpp +++ b/searchcore/src/tests/proton/matching/matching_test.cpp @@ -278,7 +278,7 @@ struct MyWorld { } Matcher::SP createMatcher() { - return std::make_shared<Matcher>(schema, config, clock, queryLimiter, constantValueRepo, 0); + return std::make_shared<Matcher>(schema, config, clock, queryLimiter, constantValueRepo, OnnxModels(), 0); } struct MySearchHandler : ISearchHandler { diff --git a/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp b/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp index a947074a917..1e64a8f4ecb 100644 --- a/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp +++ b/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp @@ -8,6 +8,7 @@ #include <vespa/searchcore/proton/server/i_proton_configurer.h> #include <vespa/searchcore/proton/common/hw_info.h> #include <vespa/searchcore/config/config-ranking-constants.h> +#include <vespa/searchcore/config/config-onnx-models.h> #include <vespa/searchsummary/config/config-juniperrc.h> #include <vespa/document/repo/documenttyperepo.h> #include <vespa/fileacquirer/config-filedistributorrpc.h> @@ -45,6 +46,7 @@ struct DoctypeFixture { AttributesConfigBuilder attributesBuilder; RankProfilesConfigBuilder rankProfilesBuilder; RankingConstantsConfigBuilder rankingConstantsBuilder; + OnnxModelsConfigBuilder onnxModelsBuilder; IndexschemaConfigBuilder indexschemaBuilder; SummaryConfigBuilder summaryBuilder; SummarymapConfigBuilder summarymapBuilder; @@ -100,6 +102,7 @@ struct ConfigTestFixture { set.addBuilder(db.configid, &fixture->attributesBuilder); set.addBuilder(db.configid, &fixture->rankProfilesBuilder); set.addBuilder(db.configid, &fixture->rankingConstantsBuilder); + set.addBuilder(db.configid, &fixture->onnxModelsBuilder); set.addBuilder(db.configid, &fixture->indexschemaBuilder); set.addBuilder(db.configid, &fixture->summaryBuilder); set.addBuilder(db.configid, &fixture->summarymapBuilder); @@ -253,7 +256,7 @@ TEST_FF("require that documentdb config manager subscribes for config", DocumentDBConfigManager(f1.configId + "/typea", "typea")) { f1.addDocType("typea"); const ConfigKeySet keySet(f2.createConfigKeySet()); - ASSERT_EQUAL(8u, keySet.size()); + ASSERT_EQUAL(9u, keySet.size()); ASSERT_TRUE(f1.configEqual("typea", getDocumentDBConfig(f1, f2))); } diff --git a/searchcore/src/tests/proton/proton_configurer/proton_configurer_test.cpp b/searchcore/src/tests/proton/proton_configurer/proton_configurer_test.cpp index 83706d966ae..6190177ac9d 100644 --- a/searchcore/src/tests/proton/proton_configurer/proton_configurer_test.cpp +++ b/searchcore/src/tests/proton/proton_configurer/proton_configurer_test.cpp @@ -19,6 +19,7 @@ #include <vespa/searchcore/proton/server/i_proton_disk_layout.h> #include <vespa/searchsummary/config/config-juniperrc.h> #include <vespa/searchcore/config/config-ranking-constants.h> +#include <vespa/searchcore/config/config-onnx-models.h> #include <vespa/vespalib/gtest/gtest.h> #include <vespa/searchcommon/common/schemaconfigurer.h> #include <vespa/vespalib/util/threadstackexecutor.h> @@ -44,12 +45,14 @@ using std::map; using search::index::Schema; using search::index::SchemaBuilder; using proton::matching::RankingConstants; +using proton::matching::OnnxModels; struct DBConfigFixture { using UP = std::unique_ptr<DBConfigFixture>; AttributesConfigBuilder _attributesBuilder; RankProfilesConfigBuilder _rankProfilesBuilder; RankingConstantsConfigBuilder _rankingConstantsBuilder; + OnnxModelsConfigBuilder _onnxModelsBuilder; IndexschemaConfigBuilder _indexschemaBuilder; SummaryConfigBuilder _summaryBuilder; SummarymapConfigBuilder _summarymapBuilder; @@ -70,6 +73,11 @@ struct DBConfigFixture { return std::make_shared<RankingConstants>(); } + OnnxModels::SP buildOnnxModels() + { + return std::make_shared<OnnxModels>(); + } + DocumentDBConfig::SP getConfig(int64_t generation, std::shared_ptr<DocumenttypesConfig> documentTypes, std::shared_ptr<const DocumentTypeRepo> repo, @@ -80,6 +88,7 @@ struct DBConfigFixture { (generation, std::make_shared<RankProfilesConfig>(_rankProfilesBuilder), buildRankingConstants(), + buildOnnxModels(), std::make_shared<IndexschemaConfig>(_indexschemaBuilder), std::make_shared<AttributesConfig>(_attributesBuilder), std::make_shared<SummaryConfig>(_summaryBuilder), diff --git a/searchcore/src/vespa/searchcore/config/CMakeLists.txt b/searchcore/src/vespa/searchcore/config/CMakeLists.txt index a4f5560c712..915ab147978 100644 --- a/searchcore/src/vespa/searchcore/config/CMakeLists.txt +++ b/searchcore/src/vespa/searchcore/config/CMakeLists.txt @@ -9,4 +9,6 @@ vespa_generate_config(searchcore_fconfig proton.def) install_config_definition(proton.def vespa.config.search.core.proton.def) vespa_generate_config(searchcore_fconfig ranking-constants.def) install_config_definition(ranking-constants.def vespa.config.search.core.ranking-constants.def) +vespa_generate_config(searchcore_fconfig onnx-models.def) +install_config_definition(onnx-models.def vespa.config.search.core.onnx-models.def) vespa_generate_config(searchcore_fconfig hwinfo.def) diff --git a/searchcore/src/vespa/searchcore/proton/matching/CMakeLists.txt b/searchcore/src/vespa/searchcore/proton/matching/CMakeLists.txt index ffbab597118..a4688b5fdca 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/CMakeLists.txt +++ b/searchcore/src/vespa/searchcore/proton/matching/CMakeLists.txt @@ -20,6 +20,7 @@ vespa_add_library(searchcore_matching STATIC match_tools.cpp matcher.cpp matching_stats.cpp + onnx_models.cpp partial_result.cpp query.cpp queryenvironment.cpp diff --git a/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp b/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp index d6d185ccab4..5743a3d44d6 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp @@ -64,13 +64,15 @@ IndexEnvironment::insertField(const search::fef::FieldInfo &field) IndexEnvironment::IndexEnvironment(uint32_t distributionKey, const search::index::Schema &schema, const search::fef::Properties &props, - const IConstantValueRepo &constantValueRepo) + const IConstantValueRepo &constantValueRepo, + OnnxModels onnxModels) : _tableManager(), _properties(props), _fieldNames(), _fields(), _motivation(UNKNOWN), _constantValueRepo(constantValueRepo), + _onnxModels(std::move(onnxModels)), _distributionKey(distributionKey) { _tableManager.addFactory(std::make_shared<search::fef::FunctionTableFactory>(256)); @@ -129,6 +131,15 @@ IndexEnvironment::hintFieldAccess(uint32_t ) const { } void IndexEnvironment::hintAttributeAccess(const string &) const { } +std::optional<vespalib::string> +IndexEnvironment::getOnnxModelFullPath(const vespalib::string &name) const +{ + if (const auto model = _onnxModels.getModel(name)) { + return model->filePath; + } + return std::nullopt; +} + IndexEnvironment::~IndexEnvironment() = default; } diff --git a/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h b/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h index 7da45909577..d0e9a516cd0 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h +++ b/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h @@ -2,6 +2,7 @@ #pragma once +#include "onnx_models.h" #include "i_constant_value_repo.h" #include <vespa/searchlib/fef/fieldinfo.h> #include <vespa/searchlib/fef/iindexenvironment.h> @@ -25,6 +26,7 @@ private: std::vector<search::fef::FieldInfo> _fields; mutable FeatureMotivation _motivation; const IConstantValueRepo &_constantValueRepo; + OnnxModels _onnxModels; uint32_t _distributionKey; @@ -44,11 +46,13 @@ public: * @param schema the index schema * @param props config * @param constantValueRepo repo used to access constant values for ranking + * @param onnxModels processed config about onnx models **/ IndexEnvironment(uint32_t distributionKey, const search::index::Schema &schema, const search::fef::Properties &props, - const IConstantValueRepo &constantValueRepo); + const IConstantValueRepo &constantValueRepo, + OnnxModels onnxModels); const search::fef::Properties &getProperties() const override; uint32_t getNumFields() const override; @@ -65,6 +69,7 @@ public: return _constantValueRepo.getConstant(name); } + std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const override; ~IndexEnvironment() override; }; diff --git a/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp b/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp index 735070002eb..98c4fdaa89a 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp @@ -99,8 +99,8 @@ handleGroupingSession(SessionManager &sessionMgr, GroupingContext & groupingCont } // namespace proton::matching::<unnamed> Matcher::Matcher(const search::index::Schema &schema, const Properties &props, const vespalib::Clock &clock, - QueryLimiter &queryLimiter, const IConstantValueRepo &constantValueRepo, uint32_t distributionKey) - : _indexEnv(distributionKey, schema, props, constantValueRepo), + QueryLimiter &queryLimiter, const IConstantValueRepo &constantValueRepo, OnnxModels onnxModels, uint32_t distributionKey) + : _indexEnv(distributionKey, schema, props, constantValueRepo, std::move(onnxModels)), _blueprintFactory(), _rankSetup(), _viewResolver(ViewResolver::createFromSchema(schema)), diff --git a/searchcore/src/vespa/searchcore/proton/matching/matcher.h b/searchcore/src/vespa/searchcore/proton/matching/matcher.h index 243fdad63ae..39d1fa38007 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/matcher.h +++ b/searchcore/src/vespa/searchcore/proton/matching/matcher.h @@ -89,7 +89,8 @@ public: **/ Matcher(const search::index::Schema &schema, const Properties &props, const vespalib::Clock &clock, QueryLimiter &queryLimiter, - const IConstantValueRepo &constantValueRepo, uint32_t distributionKey); + const IConstantValueRepo &constantValueRepo, OnnxModels onnxModels, + uint32_t distributionKey); const search::fef::IIndexEnvironment &get_index_env() const { return _indexEnv; } diff --git a/searchcore/src/vespa/searchcore/proton/matching/onnx_models.cpp b/searchcore/src/vespa/searchcore/proton/matching/onnx_models.cpp new file mode 100644 index 00000000000..bdcf3e21d8e --- /dev/null +++ b/searchcore/src/vespa/searchcore/proton/matching/onnx_models.cpp @@ -0,0 +1,54 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "onnx_models.h" + +namespace proton::matching { + +OnnxModels::Model::Model(const vespalib::string &name_in, + const vespalib::string &filePath_in) + : name(name_in), + filePath(filePath_in) +{ +} + +OnnxModels::Model::~Model() = default; + +bool +OnnxModels::Model::operator==(const Model &rhs) const +{ + return (name == rhs.name) && + (filePath == rhs.filePath); +} + +OnnxModels::OnnxModels() + : _models() +{ +} + +OnnxModels::~OnnxModels() = default; + +OnnxModels::OnnxModels(const Vector &models) + : _models() +{ + for (const auto &model : models) { + _models.insert(std::make_pair(model.name, model)); + } +} + +bool +OnnxModels::operator==(const OnnxModels &rhs) const +{ + return _models == rhs._models; +} + +const OnnxModels::Model * +OnnxModels::getModel(const vespalib::string &name) const +{ + auto itr = _models.find(name); + if (itr != _models.end()) { + return &itr->second; + } + return nullptr; +} + +} diff --git a/searchcore/src/vespa/searchcore/proton/matching/onnx_models.h b/searchcore/src/vespa/searchcore/proton/matching/onnx_models.h new file mode 100644 index 00000000000..fdaae657711 --- /dev/null +++ b/searchcore/src/vespa/searchcore/proton/matching/onnx_models.h @@ -0,0 +1,43 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/vespalib/stllike/string.h> +#include <map> +#include <vector> + +namespace proton::matching { + +/** + * Class representing a set of configured onnx models, with full path + * for where the models are stored on disk. + */ +class OnnxModels { +public: + struct Model { + vespalib::string name; + vespalib::string filePath; + + Model(const vespalib::string &name_in, + const vespalib::string &filePath_in); + ~Model(); + bool operator==(const Model &rhs) const; + }; + + using Vector = std::vector<Model>; + +private: + using Map = std::map<vespalib::string, Model>; + Map _models; + +public: + using SP = std::shared_ptr<OnnxModels>; + OnnxModels(); + OnnxModels(const Vector &models); + ~OnnxModels(); + bool operator==(const OnnxModels &rhs) const; + const Model *getModel(const vespalib::string &name) const; + size_t size() const { return _models.size(); } +}; + +} diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.cpp b/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.cpp index 712bc553d08..8bcf8440101 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.cpp @@ -11,6 +11,7 @@ #include <vespa/document/config/config-documenttypes.h> #include <vespa/document/repo/documenttyperepo.h> #include <vespa/searchcore/config/config-ranking-constants.h> +#include <vespa/searchcore/config/config-onnx-models.h> #include <vespa/searchcore/proton/attribute/attribute_aspect_delayer.h> #include <vespa/searchcore/proton/common/document_type_inspector.h> #include <vespa/searchcore/proton/common/indexschema_inspector.h> @@ -25,12 +26,14 @@ using search::TuneFileDocumentDB; using search::index::Schema; using vespa::config::search::SummarymapConfig; using vespa::config::search::core::RankingConstantsConfig; +using vespa::config::search::core::OnnxModelsConfig; namespace proton { DocumentDBConfig::ComparisonResult::ComparisonResult() : rankProfilesChanged(false), rankingConstantsChanged(false), + onnxModelsChanged(false), indexschemaChanged(false), attributesChanged(false), summaryChanged(false), @@ -51,6 +54,7 @@ DocumentDBConfig::DocumentDBConfig( int64_t generation, const RankProfilesConfigSP &rankProfiles, const RankingConstants::SP &rankingConstants, + const OnnxModels::SP &onnxModels, const IndexschemaConfigSP &indexschema, const AttributesConfigSP &attributes, const SummaryConfigSP &summary, @@ -70,6 +74,7 @@ DocumentDBConfig::DocumentDBConfig( _generation(generation), _rankProfiles(rankProfiles), _rankingConstants(rankingConstants), + _onnxModels(onnxModels), _indexschema(indexschema), _attributes(attributes), _summary(summary), @@ -94,6 +99,7 @@ DocumentDBConfig(const DocumentDBConfig &cfg) _generation(cfg._generation), _rankProfiles(cfg._rankProfiles), _rankingConstants(cfg._rankingConstants), + _onnxModels(cfg._onnxModels), _indexschema(cfg._indexschema), _attributes(cfg._attributes), _summary(cfg._summary), @@ -117,6 +123,7 @@ DocumentDBConfig::operator==(const DocumentDBConfig & rhs) const { return equals<RankProfilesConfig>(_rankProfiles.get(), rhs._rankProfiles.get()) && equals<RankingConstants>(_rankingConstants.get(), rhs._rankingConstants.get()) && + equals<OnnxModels>(_onnxModels.get(), rhs._onnxModels.get()) && equals<IndexschemaConfig>(_indexschema.get(), rhs._indexschema.get()) && equals<AttributesConfig>(_attributes.get(), rhs._attributes.get()) && equals<SummaryConfig>(_summary.get(), rhs._summary.get()) && @@ -138,6 +145,7 @@ DocumentDBConfig::compare(const DocumentDBConfig &rhs) const ComparisonResult retval; retval.rankProfilesChanged = !equals<RankProfilesConfig>(_rankProfiles.get(), rhs._rankProfiles.get()); retval.rankingConstantsChanged = !equals<RankingConstants>(_rankingConstants.get(), rhs._rankingConstants.get()); + retval.onnxModelsChanged = !equals<OnnxModels>(_onnxModels.get(), rhs._onnxModels.get()); retval.indexschemaChanged = !equals<IndexschemaConfig>(_indexschema.get(), rhs._indexschema.get()); retval.attributesChanged = !equals<AttributesConfig>(_attributes.get(), rhs._attributes.get()); retval.summaryChanged = !equals<SummaryConfig>(_summary.get(), rhs._summary.get()); @@ -161,6 +169,7 @@ DocumentDBConfig::valid() const { return _rankProfiles && _rankingConstants && + _onnxModels && _indexschema && _attributes && _summary && @@ -201,6 +210,7 @@ DocumentDBConfig::makeReplayConfig(const SP & orig) o._generation, emptyConfig(o._rankProfiles), std::make_shared<RankingConstants>(), + std::make_shared<OnnxModels>(), o._indexschema, o._attributes, o._summary, @@ -241,6 +251,7 @@ DocumentDBConfig::newFromAttributesConfig(const AttributesConfigSP &attributes) _generation, _rankProfiles, _rankingConstants, + _onnxModels, _indexschema, attributes, _summary, @@ -276,6 +287,7 @@ DocumentDBConfig::makeDelayedAttributeAspectConfig(const SP &newCfg, const Docum (n._generation, n._rankProfiles, n._rankingConstants, + n._onnxModels, n._indexschema, attributeAspectDelayer.getAttributesConfig(), n._summary, diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.h b/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.h index c4083c3db7a..09fdd5b5b0a 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.h +++ b/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.h @@ -6,6 +6,7 @@ #include <vespa/searchlib/common/tunefileinfo.h> #include <vespa/searchcommon/common/schema.h> #include <vespa/searchcore/proton/matching/ranking_constants.h> +#include <vespa/searchcore/proton/matching/onnx_models.h> #include <vespa/config/retriever/configkeyset.h> #include <vespa/config/retriever/configsnapshot.h> #include <vespa/searchlib/docstore/logdocumentstore.h> @@ -36,6 +37,7 @@ public: public: bool rankProfilesChanged; bool rankingConstantsChanged; + bool onnxModelsChanged; bool indexschemaChanged; bool attributesChanged; bool summaryChanged; @@ -54,6 +56,7 @@ public: ComparisonResult(); ComparisonResult &setRankProfilesChanged(bool val) { rankProfilesChanged = val; return *this; } ComparisonResult &setRankingConstantsChanged(bool val) { rankingConstantsChanged = val; return *this; } + ComparisonResult &setOnnxModelsChanged(bool val) { onnxModelsChanged = val; return *this; } ComparisonResult &setIndexschemaChanged(bool val) { indexschemaChanged = val; return *this; } ComparisonResult &setAttributesChanged(bool val) { attributesChanged = val; return *this; } ComparisonResult &setSummaryChanged(bool val) { summaryChanged = val; return *this; } @@ -91,6 +94,7 @@ public: using RankProfilesConfig = const vespa::config::search::internal::InternalRankProfilesType; using RankProfilesConfigSP = std::shared_ptr<RankProfilesConfig>; using RankingConstants = matching::RankingConstants; + using OnnxModels = matching::OnnxModels; using SummaryConfig = const vespa::config::search::internal::InternalSummaryType; using SummaryConfigSP = std::shared_ptr<SummaryConfig>; using SummarymapConfig = const vespa::config::search::internal::InternalSummarymapType; @@ -109,6 +113,7 @@ private: int64_t _generation; RankProfilesConfigSP _rankProfiles; RankingConstants::SP _rankingConstants; + OnnxModels::SP _onnxModels; IndexschemaConfigSP _indexschema; AttributesConfigSP _attributes; SummaryConfigSP _summary; @@ -145,6 +150,7 @@ public: DocumentDBConfig(int64_t generation, const RankProfilesConfigSP &rankProfiles, const RankingConstants::SP &rankingConstants, + const OnnxModels::SP &onnxModels, const IndexschemaConfigSP &indexschema, const AttributesConfigSP &attributes, const SummaryConfigSP &summary, @@ -172,6 +178,7 @@ public: const RankProfilesConfig &getRankProfilesConfig() const { return *_rankProfiles; } const RankingConstants &getRankingConstants() const { return *_rankingConstants; } + const OnnxModels &getOnnxModels() const { return *_onnxModels; } const IndexschemaConfig &getIndexschemaConfig() const { return *_indexschema; } const AttributesConfig &getAttributesConfig() const { return *_attributes; } const SummaryConfig &getSummaryConfig() const { return *_summary; } @@ -180,6 +187,7 @@ public: const DocumenttypesConfig &getDocumenttypesConfig() const { return *_documenttypes; } const RankProfilesConfigSP &getRankProfilesConfigSP() const { return _rankProfiles; } const RankingConstants::SP &getRankingConstantsSP() const { return _rankingConstants; } + const OnnxModels::SP &getOnnxModelsSP() const { return _onnxModels; } const IndexschemaConfigSP &getIndexschemaConfigSP() const { return _indexschema; } const AttributesConfigSP &getAttributesConfigSP() const { return _attributes; } const SummaryConfigSP &getSummaryConfigSP() const { return _summary; } diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp b/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp index 68e65acb87d..a8996abc856 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp @@ -4,6 +4,7 @@ #include "bootstrapconfig.h" #include <vespa/searchcore/proton/common/hw_info.h> #include <vespa/searchcore/config/config-ranking-constants.h> +#include <vespa/searchcore/config/config-onnx-models.h> #include <vespa/config-imported-fields.h> #include <vespa/config-rank-profiles.h> #include <vespa/config-summarymap.h> @@ -30,6 +31,7 @@ using search::TuneFileDocumentDB; using search::index::Schema; using search::index::SchemaBuilder; using proton::matching::RankingConstants; +using proton::matching::OnnxModels; using vespalib::compression::CompressionConfig; using search::LogDocumentStore; using search::LogDataStore; @@ -46,6 +48,7 @@ DocumentDBConfigManager::createConfigKeySet() const ConfigKeySet set; set.add<RankProfilesConfig, RankingConstantsConfig, + OnnxModelsConfig, IndexschemaConfig, AttributesConfig, SummaryConfig, @@ -228,6 +231,7 @@ DocumentDBConfigManager::update(const ConfigSnapshot &snapshot) { using RankProfilesConfigSP = DocumentDBConfig::RankProfilesConfigSP; using RankingConstantsConfigSP = std::shared_ptr<vespa::config::search::core::RankingConstantsConfig>; + using OnnxModelsConfigSP = std::shared_ptr<vespa::config::search::core::OnnxModelsConfig>; using IndexschemaConfigSP = DocumentDBConfig::IndexschemaConfigSP; using SummaryConfigSP = DocumentDBConfig::SummaryConfigSP; using SummarymapConfigSP = DocumentDBConfig::SummarymapConfigSP; @@ -238,6 +242,7 @@ DocumentDBConfigManager::update(const ConfigSnapshot &snapshot) DocumentDBConfig::SP current = _pendingConfigSnapshot; RankProfilesConfigSP newRankProfilesConfig; matching::RankingConstants::SP newRankingConstants; + matching::OnnxModels::SP newOnnxModels; IndexschemaConfigSP newIndexschemaConfig; MaintenanceConfigSP oldMaintenanceConfig; MaintenanceConfigSP newMaintenanceConfig; @@ -261,6 +266,7 @@ DocumentDBConfigManager::update(const ConfigSnapshot &snapshot) if (current) { newRankProfilesConfig = current->getRankProfilesConfigSP(); newRankingConstants = current->getRankingConstantsSP(); + newOnnxModels = current->getOnnxModelsSP(); newIndexschemaConfig = current->getIndexschemaConfigSP(); oldMaintenanceConfig = current->getMaintenanceConfigSP(); currentGeneration = current->getGeneration(); @@ -294,6 +300,31 @@ DocumentDBConfigManager::update(const ConfigSnapshot &snapshot) } newRankingConstants = std::make_shared<RankingConstants>(constants); } + if (snapshot.isChanged<OnnxModelsConfig>(_configId, currentGeneration)) { + OnnxModelsConfigSP newOnnxModelsConfig = OnnxModelsConfigSP( + snapshot.getConfig<OnnxModelsConfig>(_configId)); + const vespalib::string &spec = _bootstrapConfig->getFiledistributorrpcConfig().connectionspec; + OnnxModels::Vector models; + if (spec != "") { + config::RpcFileAcquirer fileAcquirer(spec); + vespalib::TimeBox timeBox(5*60, 5); + for (const OnnxModelsConfig::Model &rc : newOnnxModelsConfig->model) { + vespalib::string filePath; + LOG(info, "Waiting for file acquirer (name='%s', ref='%s')", + rc.name.c_str(), rc.fileref.c_str()); + while (timeBox.hasTimeLeft() && (filePath == "")) { + filePath = fileAcquirer.wait_for(rc.fileref, timeBox.timeLeft()); + if (filePath == "") { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + } + LOG(info, "Got file path from file acquirer: '%s' (name='%s', ref='%s')", + filePath.c_str(), rc.name.c_str(), rc.fileref.c_str()); + models.emplace_back(rc.name, filePath); + } + } + newOnnxModels = std::make_shared<OnnxModels>(models); + } if (snapshot.isChanged<IndexschemaConfig>(_configId, currentGeneration)) { newIndexschemaConfig = snapshot.getConfig<IndexschemaConfig>(_configId); search::index::Schema schema; @@ -318,6 +349,7 @@ DocumentDBConfigManager::update(const ConfigSnapshot &snapshot) auto newSnapshot = std::make_shared<DocumentDBConfig>(generation, newRankProfilesConfig, newRankingConstants, + newOnnxModels, newIndexschemaConfig, filterImportedAttributes(newAttributesConfig), newSummaryConfig, diff --git a/searchcore/src/vespa/searchcore/proton/server/fileconfigmanager.cpp b/searchcore/src/vespa/searchcore/proton/server/fileconfigmanager.cpp index 395eb5a0ea2..e66043aa422 100644 --- a/searchcore/src/vespa/searchcore/proton/server/fileconfigmanager.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/fileconfigmanager.cpp @@ -345,6 +345,7 @@ FileConfigManager::loadConfig(const DocumentDBConfig ¤tSnapshot, config::DirSpec spec(snapDir); addEmptyFile(snapDir, "ranking-constants.cfg"); + addEmptyFile(snapDir, "onnx-models.cfg"); addEmptyFile(snapDir, "imported-fields.cfg"); DocumentDBConfigHelper dbc(spec, _docTypeName); diff --git a/searchcore/src/vespa/searchcore/proton/server/matchers.cpp b/searchcore/src/vespa/searchcore/proton/server/matchers.cpp index 29e940ca26d..53c96a81134 100644 --- a/searchcore/src/vespa/searchcore/proton/server/matchers.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/matchers.cpp @@ -2,16 +2,19 @@ #include "matchers.h" #include <vespa/searchcore/proton/matching/matcher.h> +#include <vespa/searchcore/proton/matching/onnx_models.h> #include <vespa/vespalib/stllike/hash_map.hpp> namespace proton { +using matching::OnnxModels; + Matchers::Matchers(const vespalib::Clock &clock, matching::QueryLimiter &queryLimiter, const matching::IConstantValueRepo &constantValueRepo) : _rpmap(), _fallback(new matching::Matcher(search::index::Schema(), search::fef::Properties(), - clock, queryLimiter, constantValueRepo, -1)), + clock, queryLimiter, constantValueRepo, OnnxModels(), -1)), _default() { } diff --git a/searchcore/src/vespa/searchcore/proton/server/reconfig_params.cpp b/searchcore/src/vespa/searchcore/proton/server/reconfig_params.cpp index 8ec41ae3e3c..4fc241571ac 100644 --- a/searchcore/src/vespa/searchcore/proton/server/reconfig_params.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/reconfig_params.cpp @@ -15,6 +15,7 @@ ReconfigParams::configHasChanged() const { return _res.rankProfilesChanged || _res.rankingConstantsChanged || + _res.onnxModelsChanged || _res.indexschemaChanged || _res.attributesChanged || _res.summaryChanged || @@ -38,7 +39,7 @@ ReconfigParams::shouldSchemaChange() const bool ReconfigParams::shouldMatchersChange() const { - return _res.rankProfilesChanged || _res.rankingConstantsChanged || shouldSchemaChange(); + return _res.rankProfilesChanged || _res.rankingConstantsChanged || _res.onnxModelsChanged || shouldSchemaChange(); } bool diff --git a/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.cpp b/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.cpp index 713256fd809..8f34484dfe2 100644 --- a/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.cpp @@ -21,6 +21,7 @@ using vespa::config::search::RankProfilesConfig; namespace proton { using matching::Matcher; +using matching::OnnxModels; typedef AttributeReprocessingInitializer::Config ARIConfig; @@ -122,7 +123,8 @@ SearchableDocSubDBConfigurer::~SearchableDocSubDBConfigurer() = default; Matchers::UP SearchableDocSubDBConfigurer::createMatchers(const Schema::SP &schema, - const RankProfilesConfig &cfg) + const RankProfilesConfig &cfg, + const OnnxModels &onnxModels) { auto newMatchers = std::make_unique<Matchers>(_clock, _queryLimiter, _constantValueRepo); for (const auto &profile : cfg.rankprofile) { @@ -132,7 +134,7 @@ SearchableDocSubDBConfigurer::createMatchers(const Schema::SP &schema, properties.add(property.name, property.value); } // schema instance only used during call. - auto profptr = std::make_shared<Matcher>(*schema, properties, _clock, _queryLimiter, _constantValueRepo, _distributionKey); + auto profptr = std::make_shared<Matcher>(*schema, properties, _clock, _queryLimiter, _constantValueRepo, onnxModels, _distributionKey); newMatchers->add(name, profptr); } return newMatchers; @@ -200,7 +202,9 @@ SearchableDocSubDBConfigurer::reconfigure(const DocumentDBConfig &newConfig, Matchers::SP matchers = searchView->getMatchers(); if (params.shouldMatchersChange()) { _constantValueRepo.reconfigure(newConfig.getRankingConstants()); - Matchers::SP newMatchers = createMatchers(newConfig.getSchemaSP(),newConfig.getRankProfilesConfig()); + Matchers::SP newMatchers = createMatchers(newConfig.getSchemaSP(), + newConfig.getRankProfilesConfig(), + newConfig.getOnnxModels()); matchers = newMatchers; shouldMatchViewChange = true; } diff --git a/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.h b/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.h index 6b836544735..0f86520fd0b 100644 --- a/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.h +++ b/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.h @@ -80,7 +80,8 @@ public: ~SearchableDocSubDBConfigurer(); Matchers::UP createMatchers(const search::index::Schema::SP &schema, - const vespa::config::search::RankProfilesConfig &cfg); + const vespa::config::search::RankProfilesConfig &cfg, + const proton::matching::OnnxModels &onnxModels); void reconfigureIndexSearchable(); diff --git a/searchcore/src/vespa/searchcore/proton/server/searchabledocsubdb.cpp b/searchcore/src/vespa/searchcore/proton/server/searchabledocsubdb.cpp index 592d1bc1b52..23ab568c767 100644 --- a/searchcore/src/vespa/searchcore/proton/server/searchabledocsubdb.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/searchabledocsubdb.cpp @@ -201,7 +201,7 @@ SearchableDocSubDB::initViews(const DocumentDBConfig &configSnapshot, const Sess const Schema::SP &schema = configSnapshot.getSchemaSP(); const IIndexManager::SP &indexMgr = getIndexManager(); _constantValueRepo.reconfigure(configSnapshot.getRankingConstants()); - Matchers::SP matchers(_configurer.createMatchers(schema, configSnapshot.getRankProfilesConfig()).release()); + Matchers::SP matchers = _configurer.createMatchers(schema, configSnapshot.getRankProfilesConfig(), configSnapshot.getOnnxModels()); auto matchView = std::make_shared<MatchView>(std::move(matchers), indexMgr->getSearchable(), attrMgr, sessionManager, _metaStoreCtx, _docIdLimit); _rSearchView.set(SearchView::create( diff --git a/searchcore/src/vespa/searchcore/proton/test/documentdb_config_builder.cpp b/searchcore/src/vespa/searchcore/proton/test/documentdb_config_builder.cpp index 5cd092d20ba..a2366a3cb92 100644 --- a/searchcore/src/vespa/searchcore/proton/test/documentdb_config_builder.cpp +++ b/searchcore/src/vespa/searchcore/proton/test/documentdb_config_builder.cpp @@ -31,6 +31,7 @@ DocumentDBConfigBuilder::DocumentDBConfigBuilder(int64_t generation, : _generation(generation), _rankProfiles(std::make_shared<RankProfilesConfig>()), _rankingConstants(std::make_shared<matching::RankingConstants>()), + _onnxModels(std::make_shared<matching::OnnxModels>()), _indexschema(std::make_shared<IndexschemaConfig>()), _attributes(std::make_shared<AttributesConfig>()), _summary(std::make_shared<SummaryConfig>()), @@ -52,6 +53,7 @@ DocumentDBConfigBuilder::DocumentDBConfigBuilder(const DocumentDBConfig &cfg) : _generation(cfg.getGeneration()), _rankProfiles(cfg.getRankProfilesConfigSP()), _rankingConstants(cfg.getRankingConstantsSP()), + _onnxModels(cfg.getOnnxModelsSP()), _indexschema(cfg.getIndexschemaConfigSP()), _attributes(cfg.getAttributesConfigSP()), _summary(cfg.getSummaryConfigSP()), @@ -77,6 +79,7 @@ DocumentDBConfigBuilder::build() _generation, _rankProfiles, _rankingConstants, + _onnxModels, _indexschema, _attributes, _summary, diff --git a/searchcore/src/vespa/searchcore/proton/test/documentdb_config_builder.h b/searchcore/src/vespa/searchcore/proton/test/documentdb_config_builder.h index 4a515cf3b19..68fb5454eef 100644 --- a/searchcore/src/vespa/searchcore/proton/test/documentdb_config_builder.h +++ b/searchcore/src/vespa/searchcore/proton/test/documentdb_config_builder.h @@ -14,6 +14,7 @@ private: int64_t _generation; DocumentDBConfig::RankProfilesConfigSP _rankProfiles; DocumentDBConfig::RankingConstants::SP _rankingConstants; + DocumentDBConfig::OnnxModels::SP _onnxModels; DocumentDBConfig::IndexschemaConfigSP _indexschema; DocumentDBConfig::AttributesConfigSP _attributes; DocumentDBConfig::SummaryConfigSP _summary; @@ -54,6 +55,10 @@ public: _rankingConstants = rankingConstants_in; return *this; } + DocumentDBConfigBuilder &onnxModels(const DocumentDBConfig::OnnxModels::SP &onnxModels_in) { + _onnxModels = onnxModels_in; + return *this; + } DocumentDBConfigBuilder &importedFields(const DocumentDBConfig::ImportedFieldsConfigSP &importedFields_in) { _importedFields = importedFields_in; return *this; diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp index 7a200a46ab2..826984832f6 100644 --- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp +++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp @@ -58,9 +58,7 @@ struct OnnxFeatureTest : ::testing::Test { indexEnv.getProperties().add(expr_name, expr); } void add_onnx(const vespalib::string &name, const vespalib::string &file) { - vespalib::string feature_name = onnx_feature(name); - vespalib::string file_name = feature_name + ".fileref"; - indexEnv.getProperties().add(file_name, file); + indexEnv.addOnnxModel(name, file); } void compile(const vespalib::string &seed) { resolver->addSeed(seed); diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp index 7433021b9b6..b24392ce629 100644 --- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp @@ -66,15 +66,14 @@ OnnxBlueprint::setup(const IIndexEnvironment &env, auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP) ? Onnx::Optimize::DISABLE : Onnx::Optimize::ENABLE; - - // Note: Using the fileref property with the model name as - // fallback to get a file name. This needs to be replaced with an - // actual file reference obtained through config when available. - vespalib::string file_name = env.getProperties().lookup(getName(), "fileref").get(params[0].getValue()); + auto file_name = env.getOnnxModelFullPath(params[0].getValue()); + if (!file_name.has_value()) { + return fail("no model with name '%s' found", params[0].getValue().c_str()); + } try { - _model = std::make_unique<Onnx>(file_name, optimize); + _model = std::make_unique<Onnx>(file_name.value(), optimize); } catch (std::exception &ex) { - return fail("Model setup failed: %s", ex.what()); + return fail("model setup failed: %s", ex.what()); } Onnx::WirePlanner planner; for (size_t i = 0; i < _model->inputs().size(); ++i) { diff --git a/searchlib/src/vespa/searchlib/fef/iindexenvironment.h b/searchlib/src/vespa/searchlib/fef/iindexenvironment.h index bdeead3e852..26e88a98033 100644 --- a/searchlib/src/vespa/searchlib/fef/iindexenvironment.h +++ b/searchlib/src/vespa/searchlib/fef/iindexenvironment.h @@ -3,6 +3,7 @@ #pragma once #include <vespa/vespalib/stllike/string.h> +#include <optional> namespace vespalib::eval { struct ConstantValue; } @@ -120,6 +121,11 @@ public: */ virtual std::unique_ptr<vespalib::eval::ConstantValue> getConstantValue(const vespalib::string &name) const = 0; + /** + * Get the full path of the file containing the given onnx model + **/ + virtual std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const = 0; + virtual uint32_t getDistributionKey() const = 0; /** diff --git a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp index e998e4d18bd..6e2e0b88fbb 100644 --- a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp +++ b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp @@ -54,4 +54,21 @@ IndexEnvironment::addConstantValue(const vespalib::string &name, (void) insertRes; } +std::optional<vespalib::string> +IndexEnvironment::getOnnxModelFullPath(const vespalib::string &name) const +{ + auto pos = _models.find(name); + if (pos != _models.end()) { + return pos->second; + } + return std::nullopt; +} + +void +IndexEnvironment::addOnnxModel(const vespalib::string &name, const vespalib::string &path) +{ + _models[name] = path; +} + + } diff --git a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h index d84cebc7f52..6602d9f8ee9 100644 --- a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h +++ b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h @@ -47,6 +47,7 @@ public: }; using ConstantsMap = std::map<vespalib::string, Constant>; + using ModelMap = std::map<vespalib::string, vespalib::string>; IndexEnvironment(); ~IndexEnvironment(); @@ -83,6 +84,9 @@ public: vespalib::eval::ValueType type, std::unique_ptr<vespalib::eval::Value> value); + std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const override; + void addOnnxModel(const vespalib::string &name, const vespalib::string &path); + private: IndexEnvironment(const IndexEnvironment &); // hide IndexEnvironment & operator=(const IndexEnvironment &); // hide @@ -93,6 +97,7 @@ private: AttributeMap _attrMap; TableManager _tableMan; ConstantsMap _constants; + ModelMap _models; }; } diff --git a/streamingvisitors/src/vespa/searchvisitor/indexenvironment.h b/streamingvisitors/src/vespa/searchvisitor/indexenvironment.h index ac6836b08c5..3bbfb0b23f9 100644 --- a/streamingvisitors/src/vespa/searchvisitor/indexenvironment.h +++ b/streamingvisitors/src/vespa/searchvisitor/indexenvironment.h @@ -73,6 +73,10 @@ public: return vespalib::eval::ConstantValue::UP(); } + std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &) const override { + return std::nullopt; + } + bool addField(const vespalib::string & name, bool isAttribute); search::fef::Properties & getProperties() { return _properties; } |