diff options
51 files changed, 952 insertions, 309 deletions
diff --git a/client/js/app/yarn.lock b/client/js/app/yarn.lock index e9dc5bf25fe..ebf1c99db13 100644 --- a/client/js/app/yarn.lock +++ b/client/js/app/yarn.lock @@ -1311,70 +1311,70 @@ resolved "https://registry.yarnpkg.com/@remix-run/router/-/router-1.14.2.tgz#4d58f59908d9197ba3179310077f25c88e49ed17" integrity sha512-ACXpdMM9hmKZww21yEqWwiLws/UPLhNKvimN8RrYSqPSvB3ov7sLvAcfvaxePeLvccTQKGdkDIhLYApZVDFuKg== -"@rollup/rollup-android-arm-eabi@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.9.4.tgz#b1094962742c1a0349587040bc06185e2a667c9b" - integrity sha512-ub/SN3yWqIv5CWiAZPHVS1DloyZsJbtXmX4HxUTIpS0BHm9pW5iYBo2mIZi+hE3AeiTzHz33blwSnhdUo+9NpA== - -"@rollup/rollup-android-arm64@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.9.4.tgz#96eb86fb549e05b187f2ad06f51d191a23cb385a" - integrity sha512-ehcBrOR5XTl0W0t2WxfTyHCR/3Cq2jfb+I4W+Ch8Y9b5G+vbAecVv0Fx/J1QKktOrgUYsIKxWAKgIpvw56IFNA== - -"@rollup/rollup-darwin-arm64@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.9.4.tgz#2456630c007cc5905cb368acb9ff9fc04b2d37be" - integrity sha512-1fzh1lWExwSTWy8vJPnNbNM02WZDS8AW3McEOb7wW+nPChLKf3WG2aG7fhaUmfX5FKw9zhsF5+MBwArGyNM7NA== - -"@rollup/rollup-darwin-x64@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.9.4.tgz#97742214fc7dfd47a0f74efba6f5ae264e29c70c" - integrity sha512-Gc6cukkF38RcYQ6uPdiXi70JB0f29CwcQ7+r4QpfNpQFVHXRd0DfWFidoGxjSx1DwOETM97JPz1RXL5ISSB0pA== - -"@rollup/rollup-linux-arm-gnueabihf@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.9.4.tgz#cd933e61d6f689c9cdefde424beafbd92cfe58e2" - integrity sha512-g21RTeFzoTl8GxosHbnQZ0/JkuFIB13C3T7Y0HtKzOXmoHhewLbVTFBQZu+z5m9STH6FZ7L/oPgU4Nm5ErN2fw== - 
-"@rollup/rollup-linux-arm64-gnu@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.9.4.tgz#33b09bf462f1837afc1e02a1b352af6b510c78a6" - integrity sha512-TVYVWD/SYwWzGGnbfTkrNpdE4HON46orgMNHCivlXmlsSGQOx/OHHYiQcMIOx38/GWgwr/po2LBn7wypkWw/Mg== - -"@rollup/rollup-linux-arm64-musl@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.9.4.tgz#50257fb248832c2308064e3764a16273b6ee4615" - integrity sha512-XcKvuendwizYYhFxpvQ3xVpzje2HHImzg33wL9zvxtj77HvPStbSGI9czrdbfrf8DGMcNNReH9pVZv8qejAQ5A== - -"@rollup/rollup-linux-riscv64-gnu@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.9.4.tgz#09589e4e1a073cf56f6249b77eb6c9a8e9b613a8" - integrity sha512-LFHS/8Q+I9YA0yVETyjonMJ3UA+DczeBd/MqNEzsGSTdNvSJa1OJZcSH8GiXLvcizgp9AlHs2walqRcqzjOi3A== - -"@rollup/rollup-linux-x64-gnu@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.9.4.tgz#bd312bb5b5f02e54d15488605d15cfd3f90dda7c" - integrity sha512-dIYgo+j1+yfy81i0YVU5KnQrIJZE8ERomx17ReU4GREjGtDW4X+nvkBak2xAUpyqLs4eleDSj3RrV72fQos7zw== - -"@rollup/rollup-linux-x64-musl@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.9.4.tgz#25b3bede85d86438ce28cc642842d10d867d40e9" - integrity sha512-RoaYxjdHQ5TPjaPrLsfKqR3pakMr3JGqZ+jZM0zP2IkDtsGa4CqYaWSfQmZVgFUCgLrTnzX+cnHS3nfl+kB6ZQ== - -"@rollup/rollup-win32-arm64-msvc@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.9.4.tgz#95957067eb107f571da1d81939f017d37b4958d3" - integrity sha512-T8Q3XHV+Jjf5e49B4EAaLKV74BbX7/qYBRQ8Wop/+TyyU0k+vSjiLVSHNWdVd1goMjZcbhDmYZUYW5RFqkBNHQ== - -"@rollup/rollup-win32-ia32-msvc@4.9.4": - version "4.9.4" - resolved 
"https://registry.yarnpkg.com/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.9.4.tgz#71b6facad976db527863f698692c6964c0b6e10e" - integrity sha512-z+JQ7JirDUHAsMecVydnBPWLwJjbppU+7LZjffGf+Jvrxq+dVjIE7By163Sc9DKc3ADSU50qPVw0KonBS+a+HQ== - -"@rollup/rollup-win32-x64-msvc@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.9.4.tgz#16295ccae354707c9bc6842906bdeaad4f3ba7a5" - integrity sha512-LfdGXCV9rdEify1oxlN9eamvDSjv9md9ZVMAbNHA87xqIfFCxImxan9qZ8+Un54iK2nnqPlbnSi4R54ONtbWBw== +"@rollup/rollup-android-arm-eabi@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.9.5.tgz#b752b6c88a14ccfcbdf3f48c577ccc3a7f0e66b9" + integrity sha512-idWaG8xeSRCfRq9KpRysDHJ/rEHBEXcHuJ82XY0yYFIWnLMjZv9vF/7DOq8djQ2n3Lk6+3qfSH8AqlmHlmi1MA== + +"@rollup/rollup-android-arm64@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.9.5.tgz#33757c3a448b9ef77b6f6292d8b0ec45c87e9c1a" + integrity sha512-f14d7uhAMtsCGjAYwZGv6TwuS3IFaM4ZnGMUn3aCBgkcHAYErhV1Ad97WzBvS2o0aaDv4mVz+syiN0ElMyfBPg== + +"@rollup/rollup-darwin-arm64@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.9.5.tgz#5234ba62665a3f443143bc8bcea9df2cc58f55fb" + integrity sha512-ndoXeLx455FffL68OIUrVr89Xu1WLzAG4n65R8roDlCoYiQcGGg6MALvs2Ap9zs7AHg8mpHtMpwC8jBBjZrT/w== + +"@rollup/rollup-darwin-x64@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.9.5.tgz#981256c054d3247b83313724938d606798a919d1" + integrity sha512-UmElV1OY2m/1KEEqTlIjieKfVwRg0Zwg4PLgNf0s3glAHXBN99KLpw5A5lrSYCa1Kp63czTpVll2MAqbZYIHoA== + +"@rollup/rollup-linux-arm-gnueabihf@4.9.5": + version "4.9.5" + resolved 
"https://registry.yarnpkg.com/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.9.5.tgz#120678a5a2b3a283a548dbb4d337f9187a793560" + integrity sha512-Q0LcU61v92tQB6ae+udZvOyZ0wfpGojtAKrrpAaIqmJ7+psq4cMIhT/9lfV6UQIpeItnq/2QDROhNLo00lOD1g== + +"@rollup/rollup-linux-arm64-gnu@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.9.5.tgz#c99d857e2372ece544b6f60b85058ad259f64114" + integrity sha512-dkRscpM+RrR2Ee3eOQmRWFjmV/payHEOrjyq1VZegRUa5OrZJ2MAxBNs05bZuY0YCtpqETDy1Ix4i/hRqX98cA== + +"@rollup/rollup-linux-arm64-musl@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.9.5.tgz#3064060f568a5718c2a06858cd6e6d24f2ff8632" + integrity sha512-QaKFVOzzST2xzY4MAmiDmURagWLFh+zZtttuEnuNn19AiZ0T3fhPyjPPGwLNdiDT82ZE91hnfJsUiDwF9DClIQ== + +"@rollup/rollup-linux-riscv64-gnu@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.9.5.tgz#987d30b5d2b992fff07d055015991a57ff55fbad" + integrity sha512-HeGqmRJuyVg6/X6MpE2ur7GbymBPS8Np0S/vQFHDmocfORT+Zt76qu+69NUoxXzGqVP1pzaY6QIi0FJWLC3OPA== + +"@rollup/rollup-linux-x64-gnu@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.9.5.tgz#85946ee4d068bd12197aeeec2c6f679c94978a49" + integrity sha512-Dq1bqBdLaZ1Gb/l2e5/+o3B18+8TI9ANlA1SkejZqDgdU/jK/ThYaMPMJpVMMXy2uRHvGKbkz9vheVGdq3cJfA== + +"@rollup/rollup-linux-x64-musl@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.9.5.tgz#fe0b20f9749a60eb1df43d20effa96c756ddcbd4" + integrity sha512-ezyFUOwldYpj7AbkwyW9AJ203peub81CaAIVvckdkyH8EvhEIoKzaMFJj0G4qYJ5sw3BpqhFrsCc30t54HV8vg== + +"@rollup/rollup-win32-arm64-msvc@4.9.5": + version "4.9.5" + resolved 
"https://registry.yarnpkg.com/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.9.5.tgz#422661ef0e16699a234465d15b2c1089ef963b2a" + integrity sha512-aHSsMnUw+0UETB0Hlv7B/ZHOGY5bQdwMKJSzGfDfvyhnpmVxLMGnQPGNE9wgqkLUs3+gbG1Qx02S2LLfJ5GaRQ== + +"@rollup/rollup-win32-ia32-msvc@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.9.5.tgz#7b73a145891c202fbcc08759248983667a035d85" + integrity sha512-AiqiLkb9KSf7Lj/o1U3SEP9Zn+5NuVKgFdRIZkvd4N0+bYrTOovVd0+LmYCPQGbocT4kvFyK+LXCDiXPBF3fyA== + +"@rollup/rollup-win32-x64-msvc@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.9.5.tgz#10491ccf4f63c814d4149e0316541476ea603602" + integrity sha512-1q+mykKE3Vot1kaFJIDoUFv5TuW+QQVaf2FmTT9krg86pQrGStOSJJ0Zil7CFagyxDuouTepzt5Y5TVzyajOdQ== "@sinclair/typebox@^0.27.8": version "0.27.8" @@ -4670,17 +4670,17 @@ react-refresh@^0.14.0: integrity sha512-wViHqhAd8OHeLS/IRMJjTSDHF3U9eWi62F/MledQGPdJGDhodXJ9PBLNGr6WWL7qlH12Mt3TyTpbS+hGXMjCzQ== react-router-dom@^6: - version "6.21.2" - resolved "https://registry.yarnpkg.com/react-router-dom/-/react-router-dom-6.21.2.tgz#5fba851731a194fa32c31990c4829c5e247f650a" - integrity sha512-tE13UukgUOh2/sqYr6jPzZTzmzc70aGRP4pAjG2if0IP3aUT+sBtAKUJh0qMh0zylJHGLmzS+XWVaON4UklHeg== + version "6.21.3" + resolved "https://registry.yarnpkg.com/react-router-dom/-/react-router-dom-6.21.3.tgz#ef3a7956a3699c7b82c21fcb3dbc63c313ed8c5d" + integrity sha512-kNzubk7n4YHSrErzjLK72j0B5i969GsuCGazRl3G6j1zqZBLjuSlYBdVdkDOgzGdPIffUOc9nmgiadTEVoq91g== dependencies: "@remix-run/router" "1.14.2" - react-router "6.21.2" + react-router "6.21.3" -react-router@6.21.2: - version "6.21.2" - resolved "https://registry.yarnpkg.com/react-router/-/react-router-6.21.2.tgz#8820906c609ae7e4e8f926cc8eb5ce161428b956" - integrity sha512-jJcgiwDsnaHIeC+IN7atO0XiSRCrOsQAHHbChtJxmgqG2IaYQXSnhqGb5vk2CU/wBQA12Zt+TkbuJjIn65gzbA== +react-router@6.21.3: 
+ version "6.21.3" + resolved "https://registry.yarnpkg.com/react-router/-/react-router-6.21.3.tgz#8086cea922c2bfebbb49c6594967418f1f167d70" + integrity sha512-a0H638ZXULv1OdkmiK6s6itNhoy33ywxmUFT/xtSoVyf9VnC7n7+VT4LjVzdIHSaF5TIh9ylUgxMXksHTgGrKg== dependencies: "@remix-run/router" "1.14.2" @@ -4845,25 +4845,25 @@ rimraf@^3.0.2: glob "^7.1.3" rollup@^4.2.0: - version "4.9.4" - resolved "https://registry.yarnpkg.com/rollup/-/rollup-4.9.4.tgz#37bc0c09ae6b4538a9c974f4d045bb64b2e7c27c" - integrity sha512-2ztU7pY/lrQyXSCnnoU4ICjT/tCG9cdH3/G25ERqE3Lst6vl2BCM5hL2Nw+sslAvAf+ccKsAq1SkKQALyqhR7g== + version "4.9.5" + resolved "https://registry.yarnpkg.com/rollup/-/rollup-4.9.5.tgz#62999462c90f4c8b5d7c38fc7161e63b29101b05" + integrity sha512-E4vQW0H/mbNMw2yLSqJyjtkHY9dslf/p0zuT1xehNRqUTBOFMqEjguDvqhXr7N7r/4ttb2jr4T41d3dncmIgbQ== dependencies: "@types/estree" "1.0.5" optionalDependencies: - "@rollup/rollup-android-arm-eabi" "4.9.4" - "@rollup/rollup-android-arm64" "4.9.4" - "@rollup/rollup-darwin-arm64" "4.9.4" - "@rollup/rollup-darwin-x64" "4.9.4" - "@rollup/rollup-linux-arm-gnueabihf" "4.9.4" - "@rollup/rollup-linux-arm64-gnu" "4.9.4" - "@rollup/rollup-linux-arm64-musl" "4.9.4" - "@rollup/rollup-linux-riscv64-gnu" "4.9.4" - "@rollup/rollup-linux-x64-gnu" "4.9.4" - "@rollup/rollup-linux-x64-musl" "4.9.4" - "@rollup/rollup-win32-arm64-msvc" "4.9.4" - "@rollup/rollup-win32-ia32-msvc" "4.9.4" - "@rollup/rollup-win32-x64-msvc" "4.9.4" + "@rollup/rollup-android-arm-eabi" "4.9.5" + "@rollup/rollup-android-arm64" "4.9.5" + "@rollup/rollup-darwin-arm64" "4.9.5" + "@rollup/rollup-darwin-x64" "4.9.5" + "@rollup/rollup-linux-arm-gnueabihf" "4.9.5" + "@rollup/rollup-linux-arm64-gnu" "4.9.5" + "@rollup/rollup-linux-arm64-musl" "4.9.5" + "@rollup/rollup-linux-riscv64-gnu" "4.9.5" + "@rollup/rollup-linux-x64-gnu" "4.9.5" + "@rollup/rollup-linux-x64-musl" "4.9.5" + "@rollup/rollup-win32-arm64-msvc" "4.9.5" + "@rollup/rollup-win32-ia32-msvc" "4.9.5" + "@rollup/rollup-win32-x64-msvc" "4.9.5" 
fsevents "~2.3.2" rsvp@^4.8.4: @@ -5474,9 +5474,9 @@ v8-to-istanbul@^9.0.1: convert-source-map "^1.6.0" vite@^5.0.5: - version "5.0.11" - resolved "https://registry.yarnpkg.com/vite/-/vite-5.0.11.tgz#31562e41e004cb68e1d51f5d2c641ab313b289e4" - integrity sha512-XBMnDjZcNAw/G1gEiskiM1v6yzM4GE5aMGvhWTlHAYYhxb7S3/V1s3m2LDHa8Vh6yIWYYB0iJwsEaS523c4oYA== + version "5.0.12" + resolved "https://registry.yarnpkg.com/vite/-/vite-5.0.12.tgz#8a2ffd4da36c132aec4adafe05d7adde38333c47" + integrity sha512-4hsnEkG3q0N4Tzf1+t6NdN9dg/L3BM+q8SWgbSPnJvrgH2kgdyzfVJwbR1ic69/4uMJJ/3dqDZZE5/WwqW8U1w== dependencies: esbuild "^0.19.3" postcss "^8.4.32" diff --git a/config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java b/config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java index c959634019d..e920672646f 100644 --- a/config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java @@ -1,11 +1,17 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.schema; +import com.yahoo.document.Document; import com.yahoo.schema.document.Stemming; import com.yahoo.schema.parser.ParseException; import com.yahoo.schema.processing.ImportedFieldsResolver; import com.yahoo.schema.processing.OnnxModelTypeResolver; import com.yahoo.vespa.documentmodel.DocumentSummary; +import com.yahoo.vespa.indexinglanguage.expressions.AttributeExpression; +import com.yahoo.vespa.indexinglanguage.expressions.Expression; +import com.yahoo.vespa.indexinglanguage.expressions.InputExpression; +import com.yahoo.vespa.indexinglanguage.expressions.ScriptExpression; +import com.yahoo.vespa.indexinglanguage.expressions.StatementExpression; import com.yahoo.vespa.model.test.utils.DeployLoggerStub; import org.junit.jupiter.api.Test; diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java index cac297f061a..c10220b4d95 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java @@ -544,6 +544,7 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye static void checkIfActiveHasChanged(Session session, Session activeSession, boolean ignoreStaleSessionFailure) { long activeSessionAtCreate = session.getActiveSessionAtCreate(); log.log(Level.FINE, () -> activeSession.logPre() + "active session id at create time=" + activeSessionAtCreate); + if (activeSessionAtCreate == 0) return; // No active session at create time, or session created for indeterminate app. 
long sessionId = session.getSessionId(); long activeSessionSessionId = activeSession.getSessionId(); @@ -554,7 +555,7 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye String errMsg = activeSession.logPre() + "Cannot activate session " + sessionId + " because the currently active session (" + activeSessionSessionId + ") has changed since session " + sessionId + " was created (was " + - (activeSessionAtCreate == 0 ? "empty" : activeSessionAtCreate) + " at creation time)"; + activeSessionAtCreate + " at creation time)"; if (ignoreStaleSessionFailure) { log.warning(errMsg + " (Continuing because of force.)"); } else { diff --git a/configserver/src/main/resources/configserver-app/services.xml b/configserver/src/main/resources/configserver-app/services.xml index 2e666089152..069b7ffc496 100644 --- a/configserver/src/main/resources/configserver-app/services.xml +++ b/configserver/src/main/resources/configserver-app/services.xml @@ -111,6 +111,7 @@ </handler> <handler id='com.yahoo.vespa.config.server.http.v2.ApplicationApiHandler' bundle='configserver'> <binding>http://*/application/v2/tenant/*/prepareandactivate</binding> + <binding>http://*/application/v2/tenant/*/prepareandactivate/*</binding> </handler> <handler id='com.yahoo.vespa.config.server.http.v2.SessionContentHandler' bundle='configserver'> <binding>http://*/application/v2/tenant/*/session/*/content/*</binding> diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandlerTest.java index 36892860295..d4aa0676c4f 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandlerTest.java @@ -212,15 +212,34 @@ class ApplicationApiHandlerTest { @Test void testActivationFailuresAndRetries() throws Exception { - // 
Prepare session 2, but fail on hosts; this session will be activated later. - provisioner.activationFailure(new ParentHostUnavailableException("host still booting")); + // Prepare session 2, and activate it successfully. verifyResponse(post(minimalPrepareParams, zip(appPackage), Map.of()), 200, """ { "log": [ ], - "message": "Session 2 for tenant 'test' prepared, but activation failed: host still booting", + "message": "Session 2 for tenant 'test' prepared and activated.", "session-id": "2", + "activated": true, + "tenant": "test", + "url": "http://host:123/application/v2/tenant/test/application/default/environment/prod/region/default/instance/default", + "configChangeActions": { + "restart": [ ], + "refeed": [ ], + "reindex": [ ] + } + } + """); + + // Prepare session 3, but fail on hosts; this session will be activated later. + provisioner.activationFailure(new ParentHostUnavailableException("host still booting")); + verifyResponse(post(minimalPrepareParams, zip(appPackage), Map.of()), + 200, + """ + { + "log": [ ], + "message": "Session 3 for tenant 'test' prepared, but activation failed: host still booting", + "session-id": "3", "activated": false, "tenant": "test", "url": "http://host:123/application/v2/tenant/test/application/default/environment/prod/region/default/instance/default", @@ -232,15 +251,15 @@ class ApplicationApiHandlerTest { } """); - // Prepare session 3, but fail on lock; this session will become outdated later. + // Prepare session 4, but fail on lock; this session will become outdated later. 
provisioner.activationFailure(new ApplicationLockException("lock timeout")); verifyResponse(post(minimalPrepareParams, zip(appPackage), Map.of()), 200, """ { "log": [ ], - "message": "Session 3 for tenant 'test' prepared, but activation failed: lock timeout", - "session-id": "3", + "message": "Session 4 for tenant 'test' prepared, but activation failed: lock timeout", + "session-id": "4", "activated": false, "tenant": "test", "url": "http://host:123/application/v2/tenant/test/application/default/environment/prod/region/default/instance/default", @@ -263,9 +282,9 @@ class ApplicationApiHandlerTest { } """); - // Retry only activation of session 2, but fail again with hosts. + // Retry only activation of session 3, but fail again with hosts. provisioner.activationFailure(new ParentHostUnavailableException("host still booting")); - verifyResponse(put(2, Map.of()), + verifyResponse(put(3, Map.of()), 409, """ { @@ -274,9 +293,9 @@ class ApplicationApiHandlerTest { } """); - // Retry only activation of session 2, but fail again with lock. + // Retry only activation of session 3, but fail again with lock. provisioner.activationFailure(new ApplicationLockException("lock timeout")); - verifyResponse(put(2, Map.of()), + verifyResponse(put(3, Map.of()), 500, """ { @@ -285,33 +304,33 @@ class ApplicationApiHandlerTest { } """); - // Retry only activation of session 2, and succeed! + // Retry only activation of session 3, and succeed! provisioner.activationFailure(null); - verifyResponse(put(2, Map.of()), + verifyResponse(put(3, Map.of()), 200, """ { - "message": "Session 2 for test.default.default activated" + "message": "Session 3 for test.default.default activated" } """); - // Retry only activation of session 3, but fail because it is now based on an outdated session. - verifyResponse(put(3, Map.of()), + // Retry only activation of session 4, but fail because it is now based on an outdated session. 
+ verifyResponse(put(4, Map.of()), 409, """ { "error-code": "ACTIVATION_CONFLICT", - "message": "app:test.default.default Cannot activate session 3 because the currently active session (2) has changed since session 3 was created (was empty at creation time)" + "message": "app:test.default.default Cannot activate session 4 because the currently active session (3) has changed since session 4 was created (was 2 at creation time)" } """); - // Retry activation of session 2 again, and fail. - verifyResponse(put(2, Map.of()), + // Retry activation of session 3 again, and fail. + verifyResponse(put(3, Map.of()), 400, """ { "error-code": "BAD_REQUEST", - "message": "app:test.default.default Session 2 is already active" + "message": "app:test.default.default Session 3 is already active" } """); } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java index bef766e7ef9..70f6e405a92 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java @@ -15,6 +15,7 @@ public class ConversionContext { private final String destination; private final CompiledQueryProfileRegistry registry; private final Map<String, Embedder> embedders; + private final Map<String, String> contextValues; private final Language language; public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Embedder embedder, @@ -30,6 +31,7 @@ public class ConversionContext { this.embedders = embedders; this.language = context.containsKey("language") ? 
Language.fromLanguageTag(context.get("language")) : Language.UNKNOWN; + this.contextValues = context; } /** Returns the local name of the field which will receive the converted value (or null when this is empty) */ @@ -44,6 +46,9 @@ public class ConversionContext { /** Returns the language, which is never null but may be UNKNOWN */ Language language() { return language; } + /** Returns a read-only map of context key-values which can be looked up during conversion. */ + Map<String,String> contextValues() { return contextValues; } + /** Returns an empty context */ public static ConversionContext empty() { return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of()); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index cfadd79de8f..e16f8e7b0cd 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -48,7 +48,8 @@ public class TensorFieldType extends FieldType { @Override public Object convertFrom(Object o, ConversionContext context) { if (o instanceof SubstituteString) return new SubstituteStringTensor((SubstituteString) o, type); - return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o, context.language()); + return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o, + context.language(), context.contextValues()); } public static TensorFieldType fromTypeString(String s) { diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java index c9f935e5f52..25a5c277dce 100644 --- 
a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java @@ -44,7 +44,8 @@ public class RankProfileInputProperties extends Properties { value = tensorConverter.convertTo(expectedType, name.last(), value, - query.getModel().getLanguage()); + query.getModel().getLanguage(), + context); } } catch (IllegalArgumentException e) { diff --git a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java index 6da53ae699c..94f92c7fd48 100644 --- a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java +++ b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java @@ -19,7 +19,8 @@ import java.util.regex.Pattern; */ public class TensorConverter { - private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])"); + private static final Pattern embedderArgumentAndQuotedTextRegexp = Pattern.compile("^([A-Za-z0-9_@\\-.]+),\\s*([\"'].*[\"'])"); + private static final Pattern embedderArgumentAndReferenceRegexp = Pattern.compile("^([A-Za-z0-9_@\\-.]+),\\s*(@.*)"); private final Map<String, Embedder> embedders; @@ -27,8 +28,9 @@ public class TensorConverter { this.embedders = embedders; } - public Tensor convertTo(TensorType type, String key, Object value, Language language) { - var context = new Embedder.Context(key).setLanguage(language); + public Tensor convertTo(TensorType type, String key, Object value, Language language, + Map<String, String> contextValues) { + var context = new Embedder.Context(key).setLanguage(language).setContextValues(contextValues); Tensor tensor = toTensor(type, value, context); if (tensor == null) return null; if (! 
tensor.type().isAssignableTo(type)) @@ -55,16 +57,16 @@ public class TensorConverter { String embedderId; // Check if arguments specifies an embedder with the format embed(embedder, "text to encode") - Matcher matcher = embedderArgumentRegexp.matcher(argument); - if (matcher.matches()) { + Matcher matcher; + if (( matcher = embedderArgumentAndQuotedTextRegexp.matcher(argument)).matches()) { embedderId = matcher.group(1); + embedder = requireEmbedder(embedderId); argument = matcher.group(2); - if ( ! embedders.containsKey(embedderId)) { - throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + - "Valid embedders are " + validEmbedders(embedders)); - } - embedder = embedders.get(embedderId); - } else if (embedders.size() == 0) { + } else if (( matcher = embedderArgumentAndReferenceRegexp.matcher(argument)).matches()) { + embedderId = matcher.group(1); + embedder = requireEmbedder(embedderId); + argument = matcher.group(2); + } else if (embedders.isEmpty()) { throw new IllegalStateException("No embedders provided"); // should never happen } else if (embedders.size() > 1) { throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " + @@ -74,19 +76,35 @@ public class TensorConverter { embedderId = entry.getKey(); embedder = entry.getValue(); } - return embedder.embed(removeQuotes(argument), embedderContext.copy().setEmbedderId(embedderId), type); + return embedder.embed(resolve(argument, embedderContext), embedderContext.copy().setEmbedderId(embedderId), type); } - private static String removeQuotes(String s) { - if (s.startsWith("'") && s.endsWith("'")) { + private Embedder requireEmbedder(String embedderId) { + if ( ! embedders.containsKey(embedderId)) + throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. 
" + + "Valid embedders are " + validEmbedders(embedders)); + return embedders.get(embedderId); + } + + private static String resolve(String s, Embedder.Context embedderContext) { + if (s.startsWith("'") && s.endsWith("'")) return s.substring(1, s.length() - 1); - } - if (s.startsWith("\"") && s.endsWith("\"")) { + if (s.startsWith("\"") && s.endsWith("\"")) return s.substring(1, s.length() - 1); - } + if (s.startsWith("@")) + return resolveReference(s, embedderContext); return s; } + private static String resolveReference(String s, Embedder.Context embedderContext) { + String referenceKey = s.substring(1); + String referencedValue = embedderContext.getContextValues().get(referenceKey); + if (referencedValue == null) + throw new IllegalArgumentException("Could not resolve query parameter reference '" + referenceKey + + "' used in an embed() argument"); + return referencedValue; + } + private static String validEmbedders(Map<String, Embedder> embedders) { List<String> embedderIds = new ArrayList<>(); embedders.forEach((key, value) -> embedderIds.add(key)); diff --git a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java index 90e21e5f3b0..429b8d1c6cb 100644 --- a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java +++ b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java @@ -185,6 +185,21 @@ public class RankProfileInputTest { assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders, Language.UNKNOWN.languageCode()); } + @Test + void testUnembeddedTensorRankFeatureInRequestReferencedFromAParameter() { + String text = "text to embed into a tensor"; + Tensor embedding1 = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]"); + + Map<String, Embedder> embedders = Map.of( + "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1) + ); + assertEmbedQuery("embed(@param1)", embedding1, embedders, 
null, text); + assertEmbedQuery("embed(emb1, @param1)", embedding1, embedders, null, text); + assertEmbedQueryFails("embed(emb1, @noSuchParam)", embedding1, embedders, + "Could not resolve query parameter reference 'noSuchParam' " + + "used in an embed() argument"); + } + private Query createTensor1Query(String tensorString, String profile, String additionalParams) { return new Query.Builder() .setSchemaInfo(createSchemaInfo()) @@ -202,18 +217,24 @@ public class RankProfileInputTest { } private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders) { - assertEmbedQuery(embed, expected, embedders, null); + assertEmbedQuery(embed, expected, embedders, null, null); } private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language) { + assertEmbedQuery(embed, expected, embedders, language, null); + } + private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language, String param1Value) { String languageParam = language == null ? "" : "&language=" + language; + String param1 = param1Value == null ? "" : "&param1=" + urlEncode(param1Value); + String destination = "query(myTensor4)"; Query query = new Query.Builder().setRequest(HttpRequest.createTestRequest( "?" + urlEncode("ranking.features." 
+ destination) + "=" + urlEncode(embed) + "&ranking=commonProfile" + - languageParam, + languageParam + + param1, com.yahoo.jdisc.http.HttpRequest.Method.GET)) .setSchemaInfo(createSchemaInfo()) .setQueryProfile(createQueryProfile()) @@ -230,7 +251,7 @@ public class RankProfileInputTest { if (t.getMessage().equals(errMsg)) return; t = t.getCause(); } - fail("Error '" + errMsg + "' not thrown"); + fail("Exception with message '" + errMsg + "' not thrown"); } private CompiledQueryProfile createQueryProfile() { diff --git a/dependency-versions/pom.xml b/dependency-versions/pom.xml index cbeb4d01fe4..9af479a3494 100644 --- a/dependency-versions/pom.xml +++ b/dependency-versions/pom.xml @@ -66,7 +66,7 @@ <!-- Athenz dependencies. Make sure these dependencies match those in Vespa's internal repositories --> <athenz.vespa.version>1.11.50</athenz.vespa.version> - <aws-sdk.vespa.version>1.12.639</aws-sdk.vespa.version> + <aws-sdk.vespa.version>1.12.640</aws-sdk.vespa.version> <!-- Athenz END --> <!-- WARNING: If you change curator version, you also need to update @@ -120,7 +120,7 @@ <mimepull.vespa.version>1.10.0</mimepull.vespa.version> <mockito.vespa.version>5.9.0</mockito.vespa.version> <mojo-executor.vespa.version>2.4.0</mojo-executor.vespa.version> - <netty.vespa.version>4.1.105.Final</netty.vespa.version> + <netty.vespa.version>4.1.106.Final</netty.vespa.version> <netty-tcnative.vespa.version>2.0.62.Final</netty-tcnative.vespa.version> <onnxruntime.vespa.version>1.16.3</onnxruntime.vespa.version> <opennlp.vespa.version>2.3.1</opennlp.vespa.version> diff --git a/document/src/main/java/com/yahoo/document/json/TokenBuffer.java b/document/src/main/java/com/yahoo/document/json/TokenBuffer.java index a9cd3cc87a8..c1ac239d5f0 100644 --- a/document/src/main/java/com/yahoo/document/json/TokenBuffer.java +++ b/document/src/main/java/com/yahoo/document/json/TokenBuffer.java @@ -159,7 +159,7 @@ public class TokenBuffer { if (name.equals(currentName()) && 
current().isScalarValue()) { toReturn = tokens.get(position); } else { - i = tokens.iterator(); + i = rest().iterator(); i.next(); // just ignore the first value, as we know it's not what // we're looking for, and it's nesting effect is already // included diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json index 1ffb879e57e..dc6a62cc463 100644 --- a/linguistics/abi-spec.json +++ b/linguistics/abi-spec.json @@ -344,7 +344,9 @@ "public java.lang.String getDestination()", "public com.yahoo.language.process.Embedder$Context setDestination(java.lang.String)", "public java.lang.String getEmbedderId()", - "public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)" + "public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)", + "public java.util.Map getContextValues()", + "public com.yahoo.language.process.Embedder$Context setContextValues(java.util.Map)" ], "fields" : [ ] }, diff --git a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java index fa141977d5d..f6fdb86d01f 100644 --- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java +++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java @@ -88,6 +88,7 @@ public interface Embedder { private Language language = Language.UNKNOWN; private String destination; private String embedderId = "unknown"; + private Map<String, String> contextValues; public Context(String destination) { this.destination = destination; @@ -138,6 +139,15 @@ public interface Embedder { this.embedderId = embedderId; return this; } + + /** Returns a read-only map of context key-values which can be looked up during conversion. 
*/ + public Map<String, String> getContextValues() { return contextValues; } + + public Context setContextValues(Map<String, String> contextValues) { + this.contextValues = contextValues; + return this; + } + } class FailingEmbedder implements Embedder { diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 853009873a1..28f8c4e252f 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -10,9 +10,9 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.embedding.SpladeEmbedderConfig; import com.yahoo.language.huggingface.HuggingFaceTokenizer; import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.DirectIndexedAddress; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.nio.file.Paths; import java.util.List; @@ -152,10 +152,15 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { String dimension = tensorType.dimensions().get(0).name(); //Iterate over the vocab dimension and find the max value for each sequence token long [] tokens = new long[1]; + DirectIndexedAddress directAddress = modelOutput.directAddress(); + directAddress.setIndex(0,0); for (int v = 0; v < vocabSize; v++) { double maxValue = 0.0d; + directAddress.setIndex(2, v); + long increment = directAddress.getStride(1); + long directIndex = directAddress.getDirectIndex(); for (int s = 0; s < sequenceLength; s++) { - double value = modelOutput.get(0, s, v); // batch, sequence, vocab + double value = modelOutput.get(directIndex + s * increment); if (value > maxValue) { maxValue = value; } diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java 
b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java index 07f2aea4ab6..d1a06d8c7ff 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java @@ -28,6 +28,7 @@ import java.nio.ShortBuffer; import java.util.HashMap; import java.util.Map; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; @@ -101,70 +102,67 @@ class TensorConverter { throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensorInfo.type); } + interface Short2Float { + float convert(short value); + } + + private static void extractTensor(FloatBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(DoubleBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(ByteBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(ShortBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(ShortBuffer buffer, Short2Float converter, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, converter.convert(buffer.get(i))); + } + private static void extractTensor(IntBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void 
extractTensor(LongBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + static Tensor toVespaTensor(OnnxValue onnxValue) { if ( ! (onnxValue instanceof OnnxTensor onnxTensor)) { throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported"); } TensorInfo tensorInfo = onnxTensor.getInfo(); - TensorType type = toVespaType(onnxTensor.getInfo()); - DimensionSizes sizes = sizesFromType(type); - + DimensionSizes sizes = DimensionSizes.of(type); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type, sizes); - long totalSize = sizes.totalSize(); - if (tensorInfo.type == OnnxJavaType.FLOAT) { - FloatBuffer buffer = onnxTensor.getFloatBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.DOUBLE) { - DoubleBuffer buffer = onnxTensor.getDoubleBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT8) { - ByteBuffer buffer = onnxTensor.getByteBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT16) { - ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT32) { - IntBuffer buffer = onnxTensor.getIntBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT64) { - LongBuffer buffer = onnxTensor.getLongBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.FLOAT16) { - ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; 
i < totalSize; i++) - builder.cellByDirectIndex(i, Fp16Conversions.fp16ToFloat(buffer.get())); - } - else if (tensorInfo.type == OnnxJavaType.BFLOAT16) { - ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, Fp16Conversions.bf16ToFloat((buffer.get()))); - } - else { - throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type); + long totalSizeAsLong = sizes.totalSize(); + if (totalSizeAsLong > Integer.MAX_VALUE) { + throw new IllegalArgumentException("TotalSize=" + totalSizeAsLong + " currently limited at INTEGER.MAX_VALUE"); + } + + int totalSize = (int) totalSizeAsLong; + switch (tensorInfo.type) { + case FLOAT -> extractTensor(onnxTensor.getFloatBuffer(), builder, totalSize); + case DOUBLE -> extractTensor(onnxTensor.getDoubleBuffer(), builder, totalSize); + case INT8 -> extractTensor(onnxTensor.getByteBuffer(), builder, totalSize); + case INT16 -> extractTensor(onnxTensor.getShortBuffer(), builder, totalSize); + case INT32 -> extractTensor(onnxTensor.getIntBuffer(), builder, totalSize); + case INT64 -> extractTensor(onnxTensor.getLongBuffer(), builder, totalSize); + case FLOAT16 -> extractTensor(onnxTensor.getShortBuffer(), Fp16Conversions::fp16ToFloat, builder, totalSize); + case BFLOAT16 -> extractTensor(onnxTensor.getShortBuffer(), Fp16Conversions::bf16ToFloat, builder, totalSize); + default -> throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type); } return builder.build(); } - static private DimensionSizes sizesFromType(TensorType type) { - DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size()); - for (int i = 0; i < type.dimensions().size(); i++) - builder.set(i, type.dimensions().get(i).size().get()); - return builder.build(); - } - static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) { return 
infoMap.entrySet().stream().collect(Collectors.toMap(e -> asValidName(e.getKey()), e -> toVespaType(e.getValue().getInfo()))); diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java index 9ecb0e3e162..b48051814ab 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java @@ -49,11 +49,11 @@ public class SpladeEmbedderTest { String text = "what was the manhattan project in this context it was a secret project to develop a nuclear weapon in world war" + " ii the project was led by the united states with the support of the united kingdom and canada"; Long now = System.currentTimeMillis(); - int n = 10; + int n = 1000; // Takes around 7s on Intel core i9 2.4Ghz (macbook pro, 2019) for (int i = 0; i < n; i++) { assertEmbed("tensor<float>(t{})", text, indexingContext); } - Long elapsed = (System.currentTimeMillis() - now)/1000; + Long elapsed = System.currentTimeMillis() - now; System.out.println("Elapsed time: " + elapsed + " ms"); } diff --git a/searchlib/src/tests/attribute/direct_multi_term_blueprint/direct_multi_term_blueprint_test.cpp b/searchlib/src/tests/attribute/direct_multi_term_blueprint/direct_multi_term_blueprint_test.cpp index a94ca9d890e..87b771af8e6 100644 --- a/searchlib/src/tests/attribute/direct_multi_term_blueprint/direct_multi_term_blueprint_test.cpp +++ b/searchlib/src/tests/attribute/direct_multi_term_blueprint/direct_multi_term_blueprint_test.cpp @@ -26,6 +26,7 @@ using LookupKey = IDirectPostingStore::LookupKey; struct IntegerKey : public LookupKey { int64_t _value; IntegerKey(int64_t value_in) : _value(value_in) {} + IntegerKey(const vespalib::string&) : _value() { abort(); } vespalib::stringref asString() const override { abort(); } bool asInteger(int64_t& value) const override { value = _value; return true; } }; @@ -33,6 +34,7 
@@ struct IntegerKey : public LookupKey { struct StringKey : public LookupKey { vespalib::string _value; StringKey(int64_t value_in) : _value(std::to_string(value_in)) {} + StringKey(const vespalib::string& value_in) : _value(value_in) {} vespalib::stringref asString() const override { return _value; } bool asInteger(int64_t&) const override { abort(); } }; @@ -78,6 +80,10 @@ populate_attribute(AttributeType& attr, const std::vector<DataType>& values) for (auto docid : range(300, 128)) { attr.update(docid, values[3]); } + if (values.size() > 4) { + attr.update(40, values[4]); + attr.update(41, values[5]); + } attr.commit(true); } @@ -93,7 +99,7 @@ make_attribute(CollectionType col_type, BasicType type, bool field_is_filter) auto attr = test::AttributeBuilder(field_name, cfg).docs(num_docs).get(); if (type == BasicType::STRING) { populate_attribute<StringAttribute, vespalib::string>(dynamic_cast<StringAttribute&>(*attr), - {"1", "3", "100", "300"}); + {"1", "3", "100", "300", "foo", "Foo"}); } else { populate_attribute<IntegerAttribute, int64_t>(dynamic_cast<IntegerAttribute&>(*attr), {1, 3, 100, 300}); @@ -223,15 +229,16 @@ public: tfmd.tagAsNotNeeded(); } } - template <typename BlueprintType> - void add_term_helper(BlueprintType& b, int64_t term_value) { + template <typename BlueprintType, typename TermType> + void add_term_helper(BlueprintType& b, TermType term_value) { if (integer_type) { b.addTerm(IntegerKey(term_value), 1, estimate); } else { b.addTerm(StringKey(term_value), 1, estimate); } } - void add_term(int64_t term_value) { + template <typename TermType> + void add_term(TermType term_value) { if (single_type) { if (in_operator) { add_term_helper(dynamic_cast<SingleInBlueprintType&>(*blueprint), term_value); @@ -246,8 +253,18 @@ public: } } } - std::unique_ptr<SearchIterator> create_leaf_search() const { - return blueprint->createLeafSearch(tfmda, true); + void add_terms(const std::vector<int64_t>& term_values) { + for (auto value : term_values) { + 
add_term(value); + } + } + void add_terms(const std::vector<vespalib::string>& term_values) { + for (auto value : term_values) { + add_term(value); + } + } + std::unique_ptr<SearchIterator> create_leaf_search(bool strict = true) const { + return blueprint->createLeafSearch(tfmda, strict); } vespalib::string resolve_iterator_with_unpack() const { if (in_operator) { @@ -262,7 +279,7 @@ expect_hits(const Docids& exp_docids, SearchIterator& itr) { SimpleResult exp(exp_docids); SimpleResult act; - act.search(itr); + act.search(itr, doc_id_limit); EXPECT_EQ(exp, act); } @@ -294,8 +311,7 @@ INSTANTIATE_TEST_SUITE_P(DefaultInstantiation, TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_for_none_filter_field) { setup(false, true); - add_term(1); - add_term(3); + add_terms({1, 3}); auto itr = create_leaf_search(); EXPECT_THAT(itr->asString(), StartsWith(resolve_iterator_with_unpack())); expect_hits({10, 30, 31}, *itr); @@ -307,8 +323,7 @@ TEST_P(DirectMultiTermBlueprintTest, bitvectors_used_instead_of_btree_iterators_ if (!in_operator) { return; } - add_term(1); - add_term(100); + add_terms({1, 100}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 2); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); @@ -322,8 +337,7 @@ TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_instead_of_bitvectors_ if (in_operator) { return; } - add_term(1); - add_term(100); + add_terms({1, 100}); auto itr = create_leaf_search(); EXPECT_THAT(itr->asString(), StartsWith(iterator_unpack_docid_and_weights)); expect_hits(concat({10}, range(100, 128)), *itr); @@ -332,10 +346,7 @@ TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_instead_of_bitvectors_ TEST_P(DirectMultiTermBlueprintTest, bitvectors_and_btree_iterators_used_for_filter_field) { setup(true, true); - add_term(1); - add_term(3); - add_term(100); - add_term(300); + add_terms({1, 3, 100, 300}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 3); expect_or_child(*itr, 0, 
"search::BitVectorIteratorStrictT"); @@ -347,8 +358,7 @@ TEST_P(DirectMultiTermBlueprintTest, bitvectors_and_btree_iterators_used_for_fil TEST_P(DirectMultiTermBlueprintTest, only_bitvectors_used_for_filter_field) { setup(true, true); - add_term(100); - add_term(300); + add_terms({100, 300}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 2); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); @@ -359,8 +369,7 @@ TEST_P(DirectMultiTermBlueprintTest, only_bitvectors_used_for_filter_field) TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_for_filter_field_when_ranking_not_needed) { setup(true, false); - add_term(1); - add_term(3); + add_terms({1, 3}); auto itr = create_leaf_search(); EXPECT_THAT(itr->asString(), StartsWith(iterator_unpack_none)); expect_hits({10, 30, 31}, *itr); @@ -369,10 +378,7 @@ TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_for_filter_field_when_ TEST_P(DirectMultiTermBlueprintTest, bitvectors_and_btree_iterators_used_for_filter_field_when_ranking_not_needed) { setup(true, false); - add_term(1); - add_term(3); - add_term(100); - add_term(300); + add_terms({1, 3, 100, 300}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 3); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); @@ -384,8 +390,7 @@ TEST_P(DirectMultiTermBlueprintTest, bitvectors_and_btree_iterators_used_for_fil TEST_P(DirectMultiTermBlueprintTest, only_bitvectors_used_for_filter_field_when_ranking_not_needed) { setup(true, false); - add_term(100); - add_term(300); + add_terms({100, 300}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 2); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); @@ -393,4 +398,41 @@ TEST_P(DirectMultiTermBlueprintTest, only_bitvectors_used_for_filter_field_when_ expect_hits(concat(range(100, 128), range(300, 128)), *itr); } +TEST_P(DirectMultiTermBlueprintTest, hash_filter_used_for_non_strict_iterator_with_10_or_more_terms) +{ + setup(true, true); + if (!single_type) { + 
return; + } + add_terms({1, 3, 3, 3, 3, 3, 3, 3, 3, 3}); + auto itr = create_leaf_search(false); + EXPECT_THAT(itr->asString(), StartsWith("search::attribute::MultiTermHashFilter")); + expect_hits({10, 30, 31}, *itr); +} + +TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_for_non_strict_iterator_with_9_or_less_terms) +{ + setup(true, true); + if (!single_type) { + return; + } + add_terms({1, 3, 3, 3, 3, 3, 3, 3, 3}); + auto itr = create_leaf_search(false); + EXPECT_THAT(itr->asString(), StartsWith(iterator_unpack_docid)); + expect_hits({10, 30, 31}, *itr); +} + +TEST_P(DirectMultiTermBlueprintTest, hash_filter_with_string_folding_used_for_non_strict_iterator) +{ + setup(true, true); + if (!single_type || integer_type) { + return; + } + // "foo" matches documents with "foo" (40) and "Foo" (41). + add_terms({"foo", "3", "3", "3", "3", "3", "3", "3", "3", "3"}); + auto itr = create_leaf_search(false); + EXPECT_THAT(itr->asString(), StartsWith("search::attribute::MultiTermHashFilter")); + expect_hits({30, 31, 40, 41}, *itr); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp index fe6149e6fba..97b3d88c25e 100644 --- a/searchlib/src/tests/query/streaming_query_test.cpp +++ b/searchlib/src/tests/query/streaming_query_test.cpp @@ -23,9 +23,9 @@ using TermType = QueryTerm::Type; using search::fef::SimpleTermData; using search::fef::MatchData; -void assertHit(const Hit & h, size_t expWordpos, size_t expContext, int32_t weight) { +void assertHit(const Hit & h, size_t expWordpos, size_t exp_field_id, int32_t weight) { EXPECT_EQ(h.wordpos(), expWordpos); - EXPECT_EQ(h.context(), expContext); + EXPECT_EQ(h.field_id(), exp_field_id); EXPECT_EQ(h.weight(), weight); } @@ -479,11 +479,11 @@ TEST(StreamingQueryTest, test_phrase_evaluate) p->evaluateHits(hits); ASSERT_EQ(3u, hits.size()); EXPECT_EQ(hits[0].wordpos(), 2u); - EXPECT_EQ(hits[0].context(), 0u); + 
EXPECT_EQ(hits[0].field_id(), 0u); EXPECT_EQ(hits[1].wordpos(), 6u); - EXPECT_EQ(hits[1].context(), 1u); + EXPECT_EQ(hits[1].field_id(), 1u); EXPECT_EQ(hits[2].wordpos(), 2u); - EXPECT_EQ(hits[2].context(), 3u); + EXPECT_EQ(hits[2].field_id(), 3u); ASSERT_EQ(4u, p->getFieldInfoSize()); EXPECT_EQ(p->getFieldInfo(0).getHitOffset(), 0u); EXPECT_EQ(p->getFieldInfo(0).getHitCount(), 1u); @@ -847,22 +847,22 @@ TEST(StreamingQueryTest, test_same_element_evaluate) sameElem->evaluateHits(hits); EXPECT_EQ(4u, hits.size()); EXPECT_EQ(0u, hits[0].wordpos()); - EXPECT_EQ(2u, hits[0].context()); + EXPECT_EQ(2u, hits[0].field_id()); EXPECT_EQ(0u, hits[0].elemId()); EXPECT_EQ(130, hits[0].weight()); EXPECT_EQ(0u, hits[1].wordpos()); - EXPECT_EQ(2u, hits[1].context()); + EXPECT_EQ(2u, hits[1].field_id()); EXPECT_EQ(2u, hits[1].elemId()); EXPECT_EQ(140, hits[1].weight()); EXPECT_EQ(0u, hits[2].wordpos()); - EXPECT_EQ(2u, hits[2].context()); + EXPECT_EQ(2u, hits[2].field_id()); EXPECT_EQ(4u, hits[2].elemId()); EXPECT_EQ(150, hits[2].weight()); EXPECT_EQ(0u, hits[3].wordpos()); - EXPECT_EQ(2u, hits[3].context()); + EXPECT_EQ(2u, hits[3].field_id()); EXPECT_EQ(5u, hits[3].elemId()); EXPECT_EQ(160, hits[3].weight()); EXPECT_TRUE(sameElem->evaluate()); diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp index 01148c11c9c..d0353ab8947 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp @@ -6,12 +6,12 @@ #include <vespa/searchlib/common/bitvector.h> #include <vespa/searchlib/fef/matchdatalayout.h> #include <vespa/searchlib/query/query_term_ucs4.h> +#include <vespa/searchlib/queryeval/filter_wrapper.h> +#include <vespa/searchlib/queryeval/orsearch.h> #include <vespa/searchlib/queryeval/weighted_set_term_search.h> #include <vespa/vespalib/objects/visit.h> 
-#include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/stllike/hash_map.hpp> -#include <vespa/searchlib/queryeval/filter_wrapper.h> -#include <vespa/searchlib/queryeval/orsearch.h> +#include <vespa/vespalib/util/stringfmt.h> namespace search { @@ -38,6 +38,7 @@ class StringEnumWrapper : public AttrWrapper { public: using TokenT = uint32_t; + static constexpr bool unpack_weights = true; explicit StringEnumWrapper(const IAttributeVector & attr) : AttrWrapper(attr) {} auto mapToken(const ISearchContext &context) const { @@ -52,6 +53,7 @@ class IntegerWrapper : public AttrWrapper { public: using TokenT = uint64_t; + static constexpr bool unpack_weights = true; explicit IntegerWrapper(const IAttributeVector & attr) : AttrWrapper(attr) {} std::vector<int64_t> mapToken(const ISearchContext &context) const { std::vector<int64_t> result; diff --git a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h index 42651854599..485427391ad 100644 --- a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h +++ b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h @@ -35,6 +35,8 @@ private: using IteratorType = typename PostingStoreType::IteratorType; using IteratorWeights = std::variant<std::reference_wrapper<const std::vector<int32_t>>, std::vector<int32_t>>; + bool use_hash_filter(bool strict) const; + IteratorWeights create_iterators(std::vector<IteratorType>& btree_iterators, std::vector<std::unique_ptr<queryeval::SearchIterator>>& bitvectors, bool use_bitvector_when_available, diff --git a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp index 6fc8ac63026..0a3b24142a5 100644 --- a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp +++ b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp @@ -8,6 +8,7 @@ 
#include <vespa/searchlib/queryeval/emptysearch.h> #include <vespa/searchlib/queryeval/filter_wrapper.h> #include <vespa/searchlib/queryeval/orsearch.h> +#include <cmath> #include <memory> #include <type_traits> @@ -38,6 +39,43 @@ DirectMultiTermBlueprint<PostingStoreType, SearchType>::DirectMultiTermBlueprint template <typename PostingStoreType, typename SearchType> DirectMultiTermBlueprint<PostingStoreType, SearchType>::~DirectMultiTermBlueprint() = default; + +template <typename PostingStoreType, typename SearchType> +bool +DirectMultiTermBlueprint<PostingStoreType, SearchType>::use_hash_filter(bool strict) const +{ + if (strict || _iattr.hasMultiValue()) { + return false; + } + // The following very simplified formula was created after analysing performance of the IN operator + // on a 10M document corpus using a machine with an Intel Xeon 2.5 GHz CPU with 48 cores and 256 Gb of memory: + // https://github.com/vespa-engine/system-test/tree/master/tests/performance/in_operator + // + // The following 25 test cases were used to calculate the cost of using btree iterators (strict): + // op_hits_ratios = [5, 10, 50, 100, 200] * tokens_in_op = [1, 5, 10, 100, 1000] + // For each case we calculate the latency difference against the case with tokens_in_op=1 and the same op_hits_ratio. + // This indicates the extra time used to produce the same number of hits when having multiple tokens in the operator. + // The latency diff is divided with the number of hits produced and convert to nanoseconds: + // 10M * (op_hits_ratio / 1000) * 1000 * 1000 + // Based on the numbers we can approximate the cost per document (in nanoseconds) as: + // 8.0 (ns) * log2(tokens_in_op). + // NOTE: This is very simplified. Ideally we should also take into consideration the hit estimate of this blueprint, + // as the cost per document will be lower when producing few hits. + // + // In addition, the following 28 test cases were used to calculate the cost of using the hash filter (non-strict). 
+ // filter_hits_ratios = [1, 5, 10, 50, 100, 150, 200] x op_hits_ratios = [200] x tokens_in_op = [5, 10, 100, 1000] + // The code was altered to always using the hash filter for non-strict iterators. + // For each case we calculate the latency difference against a case from above with tokens_in_op=1 that produce a similar number of hits. + // This indicates the extra time used to produce the same number of hits when using the hash filter. + // The latency diff is divided with the number of hits the test filter produces and convert to nanoseconds: + // 10M * (filter_hits_ratio / 1000) * 1000 * 1000 + // Based on the numbers we calculate the average cost per document (in nanoseconds) as 26.0 ns. + + float hash_filter_cost_per_doc_ns = 26.0; + float btree_iterator_cost_per_doc_ns = 8.0 * std::log2(_terms.size()); + return hash_filter_cost_per_doc_ns < btree_iterator_cost_per_doc_ns; +} + template <typename PostingStoreType, typename SearchType> typename DirectMultiTermBlueprint<PostingStoreType, SearchType>::IteratorWeights DirectMultiTermBlueprint<PostingStoreType, SearchType>::create_iterators(std::vector<IteratorType>& btree_iterators, @@ -96,14 +134,21 @@ DirectMultiTermBlueprint<PostingStoreType, SearchType>::create_search_helper(con if (_terms.empty()) { return std::make_unique<queryeval::EmptySearch>(); } + auto& tfmd = *tfmda[0]; + bool field_is_filter = getState().fields()[0].isFilter(); + if constexpr (SearchType::supports_hash_filter) { + if (use_hash_filter(strict)) { + return SearchType::create_hash_filter(tfmd, (filter_search || field_is_filter), + _weights, _terms, + _iattr, _attr, _dictionary_snapshot); + } + } std::vector<IteratorType> btree_iterators; std::vector<queryeval::SearchIterator::UP> bitvectors; const size_t num_children = _terms.size(); btree_iterators.reserve(num_children); - auto& tfmd = *tfmda[0]; bool use_bit_vector_when_available = filter_search || !_attr.has_always_btree_iterator(); auto weights = create_iterators(btree_iterators, 
bitvectors, use_bit_vector_when_available, tfmd, strict); - bool field_is_filter = getState().fields()[0].isFilter(); if constexpr (!SearchType::require_btree_iterators) { auto multi_term = !btree_iterators.empty() ? SearchType::create(tfmd, (filter_search || field_is_filter), std::move(weights), std::move(btree_iterators)) diff --git a/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h index 9c3ea258fdc..43500366f21 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h +++ b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h @@ -23,6 +23,7 @@ class MultiTermHashFilter final : public queryeval::SearchIterator public: using Key = typename WrapperType::TokenT; using TokenMap = vespalib::hash_map<Key, int32_t, vespalib::hash<Key>, std::equal_to<Key>, vespalib::hashtable_base::and_modulator>; + static constexpr bool unpack_weights = WrapperType::unpack_weights; private: fef::TermFieldMatchData& _tfmd; diff --git a/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp index 96d5b3ac1f3..e67bda147d7 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp @@ -42,10 +42,14 @@ template <typename WrapperType> void MultiTermHashFilter<WrapperType>::doUnpack(uint32_t docId) { - _tfmd.reset(docId); - fef::TermFieldMatchDataPosition pos; - pos.setElementWeight(_weight); - _tfmd.appendPosition(pos); + if constexpr (unpack_weights) { + _tfmd.reset(docId); + fef::TermFieldMatchDataPosition pos; + pos.setElementWeight(_weight); + _tfmd.appendPosition(pos); + } else { + _tfmd.resetOnlyDocId(docId); + } } } diff --git a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt index 05a75f4662e..76119a6d58f 100644 --- 
a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt @@ -2,6 +2,7 @@ vespa_add_library(searchlib_query_streaming OBJECT SOURCES dot_product_term.cpp + fuzzy_term.cpp in_term.cpp multi_term.cpp nearest_neighbor_query_node.cpp diff --git a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp index 1871bda564d..b3bfbb0e86b 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp @@ -24,7 +24,7 @@ DotProductTerm::build_scores(Scores& scores) const for (const auto& term : _terms) { auto& hl = term->evaluateHits(hl_store); for (auto& hit : hl) { - scores[hit.context()] += ((int64_t)term->weight().percent()) * hit.weight(); + scores[hit.field_id()] += ((int64_t)term->weight().percent()) * hit.weight(); } } } diff --git a/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.cpp new file mode 100644 index 00000000000..f33fe44369a --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.cpp @@ -0,0 +1,43 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+#include "fuzzy_term.h" + +namespace search::streaming { + +namespace { + +constexpr bool normalizing_implies_cased(Normalizing norm) noexcept { + return (norm == Normalizing::NONE); +} + +} + +FuzzyTerm::FuzzyTerm(std::unique_ptr<QueryNodeResultBase> result_base, stringref term, + const string& index, Type type, Normalizing normalizing, + uint8_t max_edits, uint32_t prefix_size) + : QueryTerm(std::move(result_base), term, index, type, normalizing), + _dfa_matcher(), + _fallback_matcher() +{ + setFuzzyMaxEditDistance(max_edits); + setFuzzyPrefixLength(prefix_size); + + std::string_view term_view(term.data(), term.size()); + const bool cased = normalizing_implies_cased(normalizing); + if (attribute::DfaFuzzyMatcher::supports_max_edits(max_edits)) { + _dfa_matcher = std::make_unique<attribute::DfaFuzzyMatcher>(term_view, max_edits, prefix_size, cased); + } else { + _fallback_matcher = std::make_unique<vespalib::FuzzyMatcher>(term_view, max_edits, prefix_size, cased); + } +} + +FuzzyTerm::~FuzzyTerm() = default; + +bool FuzzyTerm::is_match(std::string_view term) const { + if (_dfa_matcher) { + return _dfa_matcher->is_match(term); + } else { + return _fallback_matcher->isMatch(term); + } +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.h b/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.h new file mode 100644 index 00000000000..c6c88b18969 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.h @@ -0,0 +1,34 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+#pragma once + +#include "queryterm.h" +#include <vespa/searchlib/attribute/dfa_fuzzy_matcher.h> +#include <vespa/vespalib/fuzzy/fuzzy_matcher.h> +#include <memory> +#include <string_view> + +namespace search::streaming { + +/** + * Query term that matches candidate field terms that are within a query-specified + * maximum number of edits (add, delete or substitute a character), with case + * sensitivity controlled by the provided Normalizing mode. + * + * Optionally, terms may be prefixed-locked, which enforces field terms to have a + * particular prefix and where edits are only counted for the remaining term suffix. + */ +class FuzzyTerm : public QueryTerm { + std::unique_ptr<attribute::DfaFuzzyMatcher> _dfa_matcher; + std::unique_ptr<vespalib::FuzzyMatcher> _fallback_matcher; +public: + FuzzyTerm(std::unique_ptr<QueryNodeResultBase> result_base, stringref term, + const string& index, Type type, Normalizing normalizing, + uint8_t max_edits, uint32_t prefix_size); + ~FuzzyTerm() override; + + [[nodiscard]] FuzzyTerm* as_fuzzy_term() noexcept override { return this; } + + [[nodiscard]] bool is_match(std::string_view term) const; +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/hit.h b/searchlib/src/vespa/searchlib/query/streaming/hit.h index a798d293491..1e467a895ac 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/hit.h +++ b/searchlib/src/vespa/searchlib/query/streaming/hit.h @@ -9,19 +9,17 @@ namespace search::streaming { class Hit { public: - Hit(uint32_t pos_, uint32_t context_, uint32_t elemId_, int32_t weight_) noexcept - : _position(pos_ | (context_<<24)), + Hit(uint32_t pos_, uint32_t field_id_, uint32_t elemId_, int32_t weight_) noexcept + : _position(pos_ | (field_id_<<24)), _elemId(elemId_), _weight(weight_) { } int32_t weight() const { return _weight; } uint32_t pos() const { return _position; } uint32_t wordpos() const { return _position & 0xffffff; } - uint32_t context() const { return _position >> 24; } + uint32_t 
field_id() const noexcept { return _position >> 24; } uint32_t elemId() const { return _elemId; } - bool operator < (const Hit & b) const { return cmp(b) < 0; } private: - int cmp(const Hit & b) const { return _position - b._position; } uint32_t _position; uint32_t _elemId; int32_t _weight; diff --git a/searchlib/src/vespa/searchlib/query/streaming/in_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/in_term.cpp index 3e75f4a5114..c164db69ba1 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/in_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/in_term.cpp @@ -29,9 +29,9 @@ InTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_ for (const auto& term : _terms) { auto& hl = term->evaluateHits(hl_store); for (auto& hit : hl) { - if (!prev_field_id.has_value() || prev_field_id.value() != hit.context()) { - prev_field_id = hit.context(); - matching_field_ids.insert(hit.context()); + if (!prev_field_id.has_value() || prev_field_id.value() != hit.field_id()) { + prev_field_id = hit.field_id(); + matching_field_ids.insert(hit.field_id()); } } } diff --git a/searchlib/src/vespa/searchlib/query/streaming/multi_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/multi_term.cpp index dd34b9b7e73..f5a09892551 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/multi_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/multi_term.cpp @@ -51,11 +51,4 @@ MultiTerm::evaluate() const return false; } -const HitList& -MultiTerm::evaluateHits(HitList& hl) const -{ - hl.clear(); - return hl; -} - } diff --git a/searchlib/src/vespa/searchlib/query/streaming/multi_term.h b/searchlib/src/vespa/searchlib/query/streaming/multi_term.h index 3bb69e29693..6f795c31356 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/multi_term.h +++ b/searchlib/src/vespa/searchlib/query/streaming/multi_term.h @@ -32,7 +32,6 @@ public: MultiTerm* as_multi_term() noexcept override { return this; } void reset() override; bool 
evaluate() const override; - const HitList& evaluateHits(HitList& hl) const override; virtual void unpack_match_data(uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data) = 0; const std::vector<std::unique_ptr<QueryTerm>>& get_terms() const noexcept { return _terms; } }; diff --git a/searchlib/src/vespa/searchlib/query/streaming/query.cpp b/searchlib/src/vespa/searchlib/query/streaming/query.cpp index ca742aabe26..618922eced9 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/query.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/query.cpp @@ -208,7 +208,7 @@ SameElementQueryNode::evaluateHits(HitList & hl) const currMatchCount++; if ((currMatchCount+1) == numFields) { Hit h = nextHL[indexVector[currMatchCount]]; - hl.emplace_back(0, h.context(), h.elemId(), h.weight()); + hl.emplace_back(0, h.field_id(), h.elemId(), h.weight()); currMatchCount = 0; indexVector[0]++; } @@ -260,26 +260,26 @@ PhraseQueryNode::evaluateHits(HitList & hl) const const auto & currHit = curr->evaluateHits(tmpHL)[currIndex]; size_t firstPosition = currHit.pos(); uint32_t currElemId = currHit.elemId(); - uint32_t currContext = currHit.context(); + uint32_t curr_field_id = currHit.field_id(); const HitList & nextHL = next->evaluateHits(tmpHL); int diff(0); size_t nextIndexMax = nextHL.size(); while ((nextIndex < nextIndexMax) && - ((nextHL[nextIndex].context() < currContext) || - ((nextHL[nextIndex].context() == currContext) && (nextHL[nextIndex].elemId() <= currElemId))) && + ((nextHL[nextIndex].field_id() < curr_field_id) || + ((nextHL[nextIndex].field_id() == curr_field_id) && (nextHL[nextIndex].elemId() <= currElemId))) && ((diff = nextHL[nextIndex].pos()-firstPosition) < 1)) { nextIndex++; } - if ((diff == 1) && (nextHL[nextIndex].context() == currContext) && (nextHL[nextIndex].elemId() == currElemId)) { + if ((diff == 1) && (nextHL[nextIndex].field_id() == curr_field_id) && (nextHL[nextIndex].elemId() == currElemId)) { currPhraseLen++; if 
((currPhraseLen+1) == fullPhraseLen) { Hit h = nextHL[indexVector[currPhraseLen]]; hl.push_back(h); - const QueryTerm::FieldInfo & fi = next->getFieldInfo(h.context()); - updateFieldInfo(h.context(), hl.size() - 1, fi.getFieldLength()); + const QueryTerm::FieldInfo & fi = next->getFieldInfo(h.field_id()); + updateFieldInfo(h.field_id(), hl.size() - 1, fi.getFieldLength()); currPhraseLen = 0; indexVector[0]++; } diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp index 2ee515f062a..e71529a8aca 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp @@ -1,7 +1,8 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include "query.h" +#include "fuzzy_term.h" #include "nearest_neighbor_query_node.h" +#include "query.h" #include "regexp_term.h" #include <vespa/searchlib/parsequery/stackdumpiterator.h> #include <vespa/searchlib/query/streaming/dot_product_term.h> @@ -147,17 +148,16 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor } else { Normalizing normalize_mode = factory.normalizing_mode(ssIndex); std::unique_ptr<QueryTerm> qt; - if (sTerm != TermType::REGEXP) { - qt = std::make_unique<QueryTerm>(factory.create(), ssTerm, ssIndex, sTerm, normalize_mode); - } else { + if (sTerm == TermType::REGEXP) { qt = std::make_unique<RegexpTerm>(factory.create(), ssTerm, ssIndex, TermType::REGEXP, normalize_mode); + } else if (sTerm == TermType::FUZZYTERM) { + qt = std::make_unique<FuzzyTerm>(factory.create(), ssTerm, ssIndex, TermType::FUZZYTERM, normalize_mode, + queryRep.getFuzzyMaxEditDistance(), queryRep.getFuzzyPrefixLength()); + } else [[likely]] { + qt = std::make_unique<QueryTerm>(factory.create(), ssTerm, ssIndex, sTerm, normalize_mode); } qt->setWeight(queryRep.GetWeight()); 
qt->setUniqueId(queryRep.getUniqueId()); - if (qt->isFuzzy()) { - qt->setFuzzyMaxEditDistance(queryRep.getFuzzyMaxEditDistance()); - qt->setFuzzyPrefixLength(queryRep.getFuzzyPrefixLength()); - } if (allowRewrite && possibleFloat(*qt, ssTerm) && factory.allow_float_terms_rewrite(ssIndex)) { auto phrase = std::make_unique<PhraseQueryNode>(); auto dotPos = ssTerm.find('.'); diff --git a/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp b/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp index 3e05d381ee2..fb002ec1867 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp @@ -185,4 +185,10 @@ QueryTerm::as_regexp_term() noexcept return nullptr; } +FuzzyTerm* +QueryTerm::as_fuzzy_term() noexcept +{ + return nullptr; +} + } diff --git a/searchlib/src/vespa/searchlib/query/streaming/queryterm.h b/searchlib/src/vespa/searchlib/query/streaming/queryterm.h index cd2bdd7eaec..b4dfa98ebe5 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/queryterm.h +++ b/searchlib/src/vespa/searchlib/query/streaming/queryterm.h @@ -11,6 +11,7 @@ namespace search::streaming { +class FuzzyTerm; class NearestNeighborQueryNode; class MultiTerm; class RegexpTerm; @@ -64,7 +65,7 @@ public: QueryTerm & operator = (QueryTerm &&) = delete; ~QueryTerm() override; bool evaluate() const override; - const HitList & evaluateHits(HitList & hl) const override; + const HitList & evaluateHits(HitList & hl) const final override; void reset() override; void getLeaves(QueryTermList & tl) override; void getLeaves(ConstQueryTermList & tl) const override; @@ -95,6 +96,7 @@ public: virtual NearestNeighborQueryNode* as_nearest_neighbor_query_node() noexcept; virtual MultiTerm* as_multi_term() noexcept; virtual RegexpTerm* as_regexp_term() noexcept; + virtual FuzzyTerm* as_fuzzy_term() noexcept; protected: using QueryNodeResultBaseContainer = std::unique_ptr<QueryNodeResultBase>; string _index; diff --git 
a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp index 90d0be5d43c..d2d706eef3d 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp @@ -25,7 +25,7 @@ WeightedSetTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchDat for (const auto& term : _terms) { auto& hl = term->evaluateHits(hl_store); for (auto& hit : hl) { - scores[hit.context()].emplace_back(term->weight().percent()); + scores[hit.field_id()].emplace_back(term->weight().percent()); } } auto num_fields = td.numFields(); diff --git a/searchlib/src/vespa/searchlib/queryeval/dot_product_search.h b/searchlib/src/vespa/searchlib/queryeval/dot_product_search.h index e49fcbcb5bc..c74c2d4e9a7 100644 --- a/searchlib/src/vespa/searchlib/queryeval/dot_product_search.h +++ b/searchlib/src/vespa/searchlib/queryeval/dot_product_search.h @@ -27,6 +27,7 @@ protected: public: static constexpr bool filter_search = false; static constexpr bool require_btree_iterators = true; + static constexpr bool supports_hash_filter = false; // TODO: use MultiSearch::Children to pass ownership static SearchIterator::UP create(const std::vector<SearchIterator*> &children, diff --git a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp index 0c57b21aba6..0be014474d0 100644 --- a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp @@ -1,10 +1,13 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
#include "weighted_set_term_search.h" +#include <vespa/searchcommon/attribute/i_search_context.h> +#include <vespa/searchcommon/attribute/iattributevector.h> +#include <vespa/searchlib/attribute/i_direct_posting_store.h> +#include <vespa/searchlib/attribute/multi_term_hash_filter.hpp> #include <vespa/searchlib/common/bitvector.h> -#include <vespa/searchlib/attribute/multi_term_or_filter_search.h> #include <vespa/vespalib/objects/visit.h> -#include <vespa/searchcommon/attribute/i_search_context.h> +#include <vespa/vespalib/stllike/hash_map.hpp> #include "iterator_pack.h" #include "blueprint.h" @@ -237,4 +240,98 @@ WeightedSetTermSearch::create(fef::TermFieldMatchData &tmd, return create_helper_resolve_pack<DocidWithWeightIterator, DocidWithWeightIteratorPack>(tmd, is_filter_search, std::move(weights), std::move(iterators)); } +namespace { + +class HashFilterWrapper { +protected: + const attribute::IAttributeVector& _attr; +public: + HashFilterWrapper(const attribute::IAttributeVector& attr) : _attr(attr) {} +}; + +template <bool unpack_weights_t> +class StringHashFilterWrapper : public HashFilterWrapper { +public: + using TokenT = attribute::IAttributeVector::EnumHandle; + static constexpr bool unpack_weights = unpack_weights_t; + StringHashFilterWrapper(const attribute::IAttributeVector& attr) + : HashFilterWrapper(attr) + {} + auto mapToken(const IDirectPostingStore::LookupResult& term, const IDirectPostingStore& store, vespalib::datastore::EntryRef dict_snapshot) const { + std::vector<TokenT> result; + store.collect_folded(term.enum_idx, dict_snapshot, [&](vespalib::datastore::EntryRef ref) { result.emplace_back(ref.ref()); }); + return result; + } + TokenT getToken(uint32_t docid) const { + return _attr.getEnum(docid); + } +}; + +template <bool unpack_weights_t> +class IntegerHashFilterWrapper : public HashFilterWrapper { +public: + using TokenT = attribute::IAttributeVector::largeint_t; + static constexpr bool unpack_weights = unpack_weights_t; + 
IntegerHashFilterWrapper(const attribute::IAttributeVector& attr) + : HashFilterWrapper(attr) + {} + auto mapToken(const IDirectPostingStore::LookupResult& term, + const IDirectPostingStore& store, + vespalib::datastore::EntryRef) const { + std::vector<TokenT> result; + result.emplace_back(store.get_integer_value(term.enum_idx)); + return result; + } + TokenT getToken(uint32_t docid) const { + return _attr.getInt(docid); + } +}; + +template <typename WrapperType> +SearchIterator::UP +create_hash_filter_helper(fef::TermFieldMatchData& tfmd, + const std::vector<int32_t>& weights, + const std::vector<IDirectPostingStore::LookupResult>& terms, + const attribute::IAttributeVector& attr, + const IDirectPostingStore& posting_store, + vespalib::datastore::EntryRef dict_snapshot) +{ + using FilterType = attribute::MultiTermHashFilter<WrapperType>; + typename FilterType::TokenMap tokens; + WrapperType wrapper(attr); + for (size_t i = 0; i < terms.size(); ++i) { + for (auto token : wrapper.mapToken(terms[i], posting_store, dict_snapshot)) { + tokens[token] = weights[i]; + } + } + return std::make_unique<FilterType>(tfmd, wrapper, std::move(tokens)); +} + +} + +SearchIterator::UP +WeightedSetTermSearch::create_hash_filter(search::fef::TermFieldMatchData& tmd, + bool is_filter_search, + const std::vector<int32_t>& weights, + const std::vector<IDirectPostingStore::LookupResult>& terms, + const attribute::IAttributeVector& attr, + const IDirectPostingStore& posting_store, + vespalib::datastore::EntryRef dict_snapshot) +{ + if (attr.isStringType()) { + if (is_filter_search) { + return create_hash_filter_helper<StringHashFilterWrapper<false>>(tmd, weights, terms, attr, posting_store, dict_snapshot); + } else { + return create_hash_filter_helper<StringHashFilterWrapper<true>>(tmd, weights, terms, attr, posting_store, dict_snapshot); + } + } else { + assert(attr.isIntegerType()); + if (is_filter_search) { + return create_hash_filter_helper<IntegerHashFilterWrapper<false>>(tmd, 
weights, terms, attr, posting_store, dict_snapshot); + } else { + return create_hash_filter_helper<IntegerHashFilterWrapper<true>>(tmd, weights, terms, attr, posting_store, dict_snapshot); + } + } +} + } diff --git a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h index 7e47928fb90..d078fd5babc 100644 --- a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h +++ b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h @@ -11,6 +11,8 @@ #include <variant> #include <vector> +namespace search::attribute { class IAttributeVector; } + namespace search::fef { class TermFieldMatchData; } namespace search::queryeval { @@ -30,6 +32,8 @@ public: static constexpr bool filter_search = false; // Whether this iterator requires btree iterators for all tokens/terms used by the operator. static constexpr bool require_btree_iterators = false; + // Whether this supports creating a hash filter iterator; + static constexpr bool supports_hash_filter = true; // TODO: pass ownership with unique_ptr static SearchIterator::UP create(const std::vector<SearchIterator *> &children, @@ -48,6 +52,14 @@ public: std::variant<std::reference_wrapper<const std::vector<int32_t>>, std::vector<int32_t>> weights, std::vector<DocidWithWeightIterator> &&iterators); + static SearchIterator::UP create_hash_filter(search::fef::TermFieldMatchData& tmd, + bool is_filter_search, + const std::vector<int32_t>& weights, + const std::vector<IDirectPostingStore::LookupResult>& terms, + const attribute::IAttributeVector& attr, + const IDirectPostingStore& posting_store, + vespalib::datastore::EntryRef dictionary_snapshot); + // used during docsum fetching to identify matching elements // initRange must be called before use. // doSeek/doUnpack must not be called. 
diff --git a/streamingvisitors/src/tests/searcher/searcher_test.cpp b/streamingvisitors/src/tests/searcher/searcher_test.cpp index 24877866c1b..ee2c5e2b5c7 100644 --- a/streamingvisitors/src/tests/searcher/searcher_test.cpp +++ b/streamingvisitors/src/tests/searcher/searcher_test.cpp @@ -3,6 +3,7 @@ #include <vespa/vespalib/testkit/testapp.h> #include <vespa/document/fieldvalue/fieldvalues.h> +#include <vespa/searchlib/query/streaming/fuzzy_term.h> #include <vespa/searchlib/query/streaming/regexp_term.h> #include <vespa/searchlib/query/streaming/queryterm.h> #include <vespa/vsm/searcher/boolfieldsearcher.h> @@ -18,10 +19,14 @@ #include <vespa/vsm/searcher/utf8suffixstringfieldsearcher.h> #include <vespa/vsm/searcher/tokenizereader.h> #include <vespa/vsm/vsm/snippetmodifier.h> +#include <concepts> +#include <charconv> +#include <stdexcept> using namespace document; using search::streaming::HitList; using search::streaming::QueryNodeResultFactory; +using search::streaming::FuzzyTerm; using search::streaming::RegexpTerm; using search::streaming::QueryTerm; using search::streaming::Normalizing; @@ -58,6 +63,46 @@ public: } }; +namespace { + +template <std::integral T> +std::string_view maybe_consume_into(std::string_view str, T& val_out) { + auto [ptr, ec] = std::from_chars(str.data(), str.data() + str.size(), val_out); + if (ec != std::errc()) { + return str; + } + return str.substr(ptr - str.data()); +} + +// Parse optional max edits and prefix lock length from term string. +// Syntax: +// "term" -> {2, 0, "term"} (default max edits & prefix length) +// "{1}term" -> {1, 0, "term"} +// "{1,3}term" -> {1, 3, "term"} +// +// Note: this is not a "proper" parser (it accepts empty numeric values); only for testing! 
+std::tuple<uint8_t, uint32_t, std::string_view> parse_fuzzy_params(std::string_view term) { + if (term.empty() || term[0] != '{') { + return {2, 0, term}; + } + uint8_t max_edits = 2; + uint32_t prefix_length = 0; + term = maybe_consume_into(term.substr(1), max_edits); + if (term.empty() || (term[0] != ',' && term[0] != '}')) { + throw std::invalid_argument("malformed fuzzy params at (or after) max_edits"); + } + if (term[0] == '}') { + return {max_edits, prefix_length, term.substr(1)}; + } + term = maybe_consume_into(term.substr(1), prefix_length); + if (term.empty() || term[0] != '}') { + throw std::invalid_argument("malformed fuzzy params at (or after) prefix_length"); + } + return {max_edits, prefix_length, term.substr(1)}; +} + +} + class Query { private: @@ -66,10 +111,14 @@ private: ParsedQueryTerm pqt = parseQueryTerm(term); ParsedTerm pt = parseTerm(pqt.second); std::string effective_index = pqt.first.empty() ? "index" : pqt.first; - if (pt.second != TermType::REGEXP) { - qtv.push_back(std::make_unique<QueryTerm>(eqnr.create(), pt.first, effective_index, pt.second, normalizing)); + if (pt.second == TermType::REGEXP) { + qtv.push_back(std::make_unique<RegexpTerm>(eqnr.create(), pt.first, effective_index, TermType::REGEXP, normalizing)); + } else if (pt.second == TermType::FUZZYTERM) { + auto [max_edits, prefix_length, actual_term] = parse_fuzzy_params(pt.first); + qtv.push_back(std::make_unique<FuzzyTerm>(eqnr.create(), vespalib::stringref(actual_term.data(), actual_term.size()), + effective_index, TermType::FUZZYTERM, normalizing, max_edits, prefix_length)); } else { - qtv.push_back(std::make_unique<RegexpTerm>(eqnr.create(), pt.first, effective_index, pt.second, normalizing)); + qtv.push_back(std::make_unique<QueryTerm>(eqnr.create(), pt.first, effective_index, pt.second, normalizing)); } } for (const auto & i : qtv) { @@ -100,6 +149,8 @@ public: return std::make_pair(term.substr(1, term.size() - 1), TermType::SUFFIXTERM); } else if (term[0] == '#') { // 
magic regex enabler return std::make_pair(term.substr(1), TermType::REGEXP); + } else if (term[0] == '%') { // equally magic fuzzy enabler + return std::make_pair(term.substr(1), TermType::FUZZYTERM); } else if (term[term.size() - 1] == '*') { return std::make_pair(term.substr(0, term.size() - 1), TermType::PREFIXTERM); } else { @@ -477,31 +528,54 @@ testStrChrFieldSearcher(StrChrFieldSearcher & fs) return true; } - TEST("verify correct term parsing") { - ASSERT_TRUE(Query::parseQueryTerm("index:term").first == "index"); - ASSERT_TRUE(Query::parseQueryTerm("index:term").second == "term"); - ASSERT_TRUE(Query::parseQueryTerm("term").first.empty()); - ASSERT_TRUE(Query::parseQueryTerm("term").second == "term"); - ASSERT_TRUE(Query::parseTerm("*substr*").first == "substr"); - ASSERT_TRUE(Query::parseTerm("*substr*").second == TermType::SUBSTRINGTERM); - ASSERT_TRUE(Query::parseTerm("*suffix").first == "suffix"); - ASSERT_TRUE(Query::parseTerm("*suffix").second == TermType::SUFFIXTERM); - ASSERT_TRUE(Query::parseTerm("prefix*").first == "prefix"); - ASSERT_TRUE(Query::parseTerm("prefix*").second == TermType::PREFIXTERM); - ASSERT_TRUE(Query::parseTerm("#regex").first == "regex"); - ASSERT_TRUE(Query::parseTerm("#regex").second == TermType::REGEXP); - ASSERT_TRUE(Query::parseTerm("term").first == "term"); - ASSERT_TRUE(Query::parseTerm("term").second == TermType::WORD); - } - - TEST("suffix matching") { - EXPECT_EQUAL(assertMatchTermSuffix("a", "vespa"), true); - EXPECT_EQUAL(assertMatchTermSuffix("spa", "vespa"), true); - EXPECT_EQUAL(assertMatchTermSuffix("vespa", "vespa"), true); - EXPECT_EQUAL(assertMatchTermSuffix("vvespa", "vespa"), false); - EXPECT_EQUAL(assertMatchTermSuffix("fspa", "vespa"), false); - EXPECT_EQUAL(assertMatchTermSuffix("v", "vespa"), false); - } +TEST("parsing of test-only fuzzy term params can extract numeric values") { + uint8_t max_edits = 0; + uint32_t prefix_length = 1234; + std::string_view out; + + std::tie(max_edits, prefix_length, out) 
= parse_fuzzy_params("myterm"); + EXPECT_EQUAL(max_edits, 2u); + EXPECT_EQUAL(prefix_length, 0u); + EXPECT_EQUAL(out, "myterm"); + + std::tie(max_edits, prefix_length, out) = parse_fuzzy_params("{3}myterm"); + EXPECT_EQUAL(max_edits, 3u); + EXPECT_EQUAL(prefix_length, 0u); + EXPECT_EQUAL(out, "myterm"); + + std::tie(max_edits, prefix_length, out) = parse_fuzzy_params("{2,70}myterm"); + EXPECT_EQUAL(max_edits, 2u); + EXPECT_EQUAL(prefix_length, 70u); + EXPECT_EQUAL(out, "myterm"); +} + +TEST("verify correct term parsing") { + ASSERT_TRUE(Query::parseQueryTerm("index:term").first == "index"); + ASSERT_TRUE(Query::parseQueryTerm("index:term").second == "term"); + ASSERT_TRUE(Query::parseQueryTerm("term").first.empty()); + ASSERT_TRUE(Query::parseQueryTerm("term").second == "term"); + ASSERT_TRUE(Query::parseTerm("*substr*").first == "substr"); + ASSERT_TRUE(Query::parseTerm("*substr*").second == TermType::SUBSTRINGTERM); + ASSERT_TRUE(Query::parseTerm("*suffix").first == "suffix"); + ASSERT_TRUE(Query::parseTerm("*suffix").second == TermType::SUFFIXTERM); + ASSERT_TRUE(Query::parseTerm("prefix*").first == "prefix"); + ASSERT_TRUE(Query::parseTerm("prefix*").second == TermType::PREFIXTERM); + ASSERT_TRUE(Query::parseTerm("#regex").first == "regex"); + ASSERT_TRUE(Query::parseTerm("#regex").second == TermType::REGEXP); + ASSERT_TRUE(Query::parseTerm("%fuzzy").first == "fuzzy"); + ASSERT_TRUE(Query::parseTerm("%fuzzy").second == TermType::FUZZYTERM); + ASSERT_TRUE(Query::parseTerm("term").first == "term"); + ASSERT_TRUE(Query::parseTerm("term").second == TermType::WORD); +} + +TEST("suffix matching") { + EXPECT_EQUAL(assertMatchTermSuffix("a", "vespa"), true); + EXPECT_EQUAL(assertMatchTermSuffix("spa", "vespa"), true); + EXPECT_EQUAL(assertMatchTermSuffix("vespa", "vespa"), true); + EXPECT_EQUAL(assertMatchTermSuffix("vvespa", "vespa"), false); + EXPECT_EQUAL(assertMatchTermSuffix("fspa", "vespa"), false); + EXPECT_EQUAL(assertMatchTermSuffix("v", "vespa"), false); +} 
TEST("Test basic strchrfield searchers") { { @@ -654,6 +728,96 @@ TEST("utf8 flexible searcher handles regexes with explicit anchoring") { TEST_DO(assertString(fs, "#^foo$", "oo", Hits())); } +TEST("utf8 flexible searcher handles fuzzy search in uncased mode") { + UTF8FlexibleStringFieldSearcher fs(0); + // Term syntax (only applies to these tests): + // %{k}term => fuzzy match "term" with max edits k + // %{k,p}term => fuzzy match "term" with max edits k, prefix lock length p + + // DFA is used for k in {1, 2} + TEST_DO(assertString(fs, "%{1}abc", "abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}ABC", "abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}abc", "ABC", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}Abc", "abd", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}abc", "ABCD", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}abc", "abcde", Hits())); + TEST_DO(assertString(fs, "%{2}abc", "abcde", Hits().add(0))); + TEST_DO(assertString(fs, "%{2}abc", "xabcde", Hits())); + // Fallback to non-DFA matcher when k not in {1, 2} + TEST_DO(assertString(fs, "%{3}abc", "abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{3}abc", "XYZ", Hits().add(0))); + TEST_DO(assertString(fs, "%{3}abc", "XYZ!", Hits())); +} + +TEST("utf8 flexible searcher handles fuzzy search in cased mode") { + UTF8FlexibleStringFieldSearcher fs(0); + fs.normalize_mode(Normalizing::NONE); + TEST_DO(assertString(fs, "%{1}abc", "abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}abc", "Abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}ABC", "abc", Hits())); + TEST_DO(assertString(fs, "%{2}Abc", "abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{2}abc", "AbC", Hits().add(0))); + TEST_DO(assertString(fs, "%{3}abc", "ABC", Hits().add(0))); + TEST_DO(assertString(fs, "%{3}abc", "ABCD", Hits())); +} + +TEST("utf8 flexible searcher handles fuzzy search with prefix locking") { + UTF8FlexibleStringFieldSearcher fs(0); + // DFA + TEST_DO(assertString(fs, "%{1,4}zoid", "zoi", 
Hits())); + TEST_DO(assertString(fs, "%{1,4}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,4}zoid", "ZOID", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoid", Hits())); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "ZoidBerg", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "ZoidBergg", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoidborg", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoidblergh", Hits())); + TEST_DO(assertString(fs, "%{2,4}zoidberg", "zoidblergh", Hits().add(0))); + // Fallback + TEST_DO(assertString(fs, "%{3,4}zoidberg", "zoidblergh", Hits().add(0))); + TEST_DO(assertString(fs, "%{3,4}zoidberg", "zoidbooorg", Hits().add(0))); + TEST_DO(assertString(fs, "%{3,4}zoidberg", "zoidzooorg", Hits())); + + fs.normalize_mode(Normalizing::NONE); + // DFA + TEST_DO(assertString(fs, "%{1,4}zoid", "ZOID", Hits())); + TEST_DO(assertString(fs, "%{1,4}ZOID", "zoid", Hits())); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoidBerg", Hits().add(0))); // 1 edit + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoidBblerg", Hits())); // 2 edits, 1 max + TEST_DO(assertString(fs, "%{2,4}zoidberg", "zoidBblerg", Hits().add(0))); // 2 edits, 2 max + // Fallback + TEST_DO(assertString(fs, "%{3,4}zoidberg", "zoidBERG", Hits())); // 4 edits, 3 max + TEST_DO(assertString(fs, "%{4,4}zoidberg", "zoidBERG", Hits().add(0))); // 4 edits, 4 max +} + +TEST("utf8 flexible searcher fuzzy match with max_edits=0 implies exact match") { + UTF8FlexibleStringFieldSearcher fs(0); + TEST_DO(assertString(fs, "%{0}zoid", "zoi", Hits())); + TEST_DO(assertString(fs, "%{0,4}zoid", "zoi", Hits())); + TEST_DO(assertString(fs, "%{0}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{0}zoid", "ZOID", Hits().add(0))); + TEST_DO(assertString(fs, "%{0,4}zoid", "ZOID", Hits().add(0))); + fs.normalize_mode(Normalizing::NONE); + TEST_DO(assertString(fs, "%{0}zoid", "ZOID", Hits())); + TEST_DO(assertString(fs, 
"%{0,4}zoid", "ZOID", Hits())); + TEST_DO(assertString(fs, "%{0}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{0,4}zoid", "zoid", Hits().add(0))); +} + +TEST("utf8 flexible searcher caps oversized fuzzy prefix length to term length") { + UTF8FlexibleStringFieldSearcher fs(0); + // DFA + TEST_DO(assertString(fs, "%{1,5}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,9001}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,9001}zoid", "boid", Hits())); + // Fallback + TEST_DO(assertString(fs, "%{0,5}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{5,5}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{0,9001}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{5,9001}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{5,9001}zoid", "boid", Hits())); +} + TEST("bool search") { BoolFieldSearcher fs(0); TEST_DO(assertBool(fs, "true", true, true)); diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp index 6b15b7cb88e..a350bfa7b21 100644 --- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp @@ -300,7 +300,7 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap // optimize for hitlist giving all hits for a single field in one chunk for (const Hit & hit : hitList) { - uint32_t fieldId = hit.context(); + uint32_t fieldId = hit.field_id(); if (fieldId != lastFieldId) { // reset to notfound/unknown values tmd = nullptr; diff --git a/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.h b/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.h index c5bca6f3899..bb3aa6fdd10 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.h +++ b/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.h @@ -46,7 +46,7 @@ public: explicit FieldSearcher(FieldIdT fId) noexcept : FieldSearcher(fId, false) 
{} FieldSearcher(FieldIdT fId, bool defaultPrefix) noexcept; ~FieldSearcher() override; - virtual std::unique_ptr<FieldSearcher> duplicate() const = 0; + [[nodiscard]] virtual std::unique_ptr<FieldSearcher> duplicate() const = 0; bool search(const StorageDocument & doc); virtual void prepare(search::streaming::QueryTermList& qtl, const SharedSearcherBuf& buf, const vsm::FieldPathMapT& field_paths, search::fef::IQueryEnvironment& query_env); diff --git a/streamingvisitors/src/vespa/vsm/searcher/strchrfieldsearcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/strchrfieldsearcher.cpp index c0a0249125f..98e88e45b3a 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/strchrfieldsearcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/strchrfieldsearcher.cpp @@ -34,7 +34,7 @@ bool StrChrFieldSearcher::matchDoc(const FieldRef & fieldRef) } } else { for (auto qt : _qtl) { - if (fieldRef.size() >= qt->termLen() || qt->isRegex()) { + if (fieldRef.size() >= qt->termLen() || qt->isRegex() || qt->isFuzzy()) { _words += matchTerm(fieldRef, *qt); } else { _words += countWords(fieldRef); @@ -49,8 +49,8 @@ size_t StrChrFieldSearcher::shortestTerm() const size_t mintsz(_qtl.front()->termLen()); for (auto it=_qtl.begin()+1, mt=_qtl.end(); it != mt; it++) { const QueryTerm & qt = **it; - if (qt.isRegex()) { - return 0; // Must avoid "too short query term" optimization when using regex + if (qt.isRegex() || qt.isFuzzy()) { + return 0; // Must avoid "too short query term" optimization when using regex or fuzzy } mintsz = std::min(mintsz, qt.termLen()); } diff --git a/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.cpp index c6deb6eacd1..d648d2e252e 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.cpp @@ -1,5 +1,6 @@ // Copyright Vespa.ai. 
Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "utf8flexiblestringfieldsearcher.h" +#include <vespa/searchlib/query/streaming/fuzzy_term.h> #include <vespa/searchlib/query/streaming/regexp_term.h> #include <cassert> @@ -40,6 +41,19 @@ UTF8FlexibleStringFieldSearcher::match_regexp(const FieldRef & f, search::stream } size_t +UTF8FlexibleStringFieldSearcher::match_fuzzy(const FieldRef & f, search::streaming::QueryTerm & qt) +{ + auto* fuzzy_term = qt.as_fuzzy_term(); + assert(fuzzy_term != nullptr); + // TODO delegate to matchTermExact if max edits == 0? + // - needs to avoid folding to have consistent normalization semantics + if (fuzzy_term->is_match({f.data(), f.size()})) { + addHit(qt, 0); + } + return countWords(f); +} + +size_t UTF8FlexibleStringFieldSearcher::matchTerm(const FieldRef & f, QueryTerm & qt) { if (qt.isPrefix()) { @@ -57,6 +71,9 @@ UTF8FlexibleStringFieldSearcher::matchTerm(const FieldRef & f, QueryTerm & qt) } else if (qt.isRegex()) { LOG(debug, "Use regexp match for term '%s:%s'", qt.index().c_str(), qt.getTerm()); return match_regexp(f, qt); + } else if (qt.isFuzzy()) { + LOG(debug, "Use fuzzy match for term '%s:%s'", qt.index().c_str(), qt.getTerm()); + return match_fuzzy(f, qt); } else { if (substring()) { LOG(debug, "Use substring match for term '%s:%s'", qt.index().c_str(), qt.getTerm()); diff --git a/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.h b/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.h index cd1715ad158..a5f6ad46246 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.h +++ b/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.h @@ -25,6 +25,7 @@ private: size_t matchTerms(const FieldRef & f, size_t shortestTerm) override; size_t match_regexp(const FieldRef & f, search::streaming::QueryTerm & qt); + size_t match_fuzzy(const FieldRef & f, search::streaming::QueryTerm & qt); 
public: std::unique_ptr<FieldSearcher> duplicate() const override; diff --git a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp index 9c8bb2f185a..63d2007cecf 100644 --- a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp +++ b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp @@ -133,7 +133,7 @@ FieldSearchSpec::reconfig(const QueryTerm & term) (term.isSuffix() && _arg1 != "suffix") || (term.isExactstring() && _arg1 != "exact") || (term.isPrefix() && _arg1 == "suffix") || - term.isRegex()) + (term.isRegex() || term.isFuzzy())) { _searcher = std::make_unique<UTF8FlexibleStringFieldSearcher>(id()); propagate_settings_to_searcher(); diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 5d88b2d2829..174ce6332db 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -720,6 +720,20 @@ ], "fields" : [ ] }, + "com.yahoo.tensor.DirectIndexedAddress" : { + "superClass" : "java.lang.Object", + "interfaces" : [ ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public void setIndex(int, int)", + "public long getDirectIndex()", + "public long getStride(int)" + ], + "fields" : [ ] + }, "com.yahoo.tensor.IndexedDoubleTensor$BoundDoubleBuilder" : { "superClass" : "com.yahoo.tensor.IndexedTensor$BoundBuilder", "interfaces" : [ ], @@ -894,6 +908,8 @@ "public java.util.Iterator subspaceIterator(java.util.Set, com.yahoo.tensor.DimensionSizes)", "public java.util.Iterator subspaceIterator(java.util.Set)", "public varargs double get(long[])", + "public double get(com.yahoo.tensor.DirectIndexedAddress)", + "public com.yahoo.tensor.DirectIndexedAddress directAddress()", "public varargs float getFloat(long[])", "public double get(com.yahoo.tensor.TensorAddress)", "public boolean has(com.yahoo.tensor.TensorAddress)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java new 
file mode 100644 index 00000000000..37752361876 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java @@ -0,0 +1,38 @@ +package com.yahoo.tensor; + +/** + * Utility class for efficient access and iteration along dimensions in Indexed tensors. + * Usage: Use setIndex to lock the indexes of the dimensions that don't change in this iteration. + * long base = addr.getDirectIndex(); + * long stride = addr.getStride(dimension) + * i = 0...size_of_dimension + * double value = tensor.get(base + i * stride); + */ +public final class DirectIndexedAddress { + private final DimensionSizes sizes; + private final int [] indexes; + private long directIndex; + private DirectIndexedAddress(DimensionSizes sizes) { + this.sizes = sizes; + indexes = new int[sizes.dimensions()]; + directIndex = 0; + } + static DirectIndexedAddress of(DimensionSizes sizes) { + return new DirectIndexedAddress(sizes); + } + /** Sets the current index of a dimension */ + public void setIndex(int dimension, int index) { + if (index < 0 || index >= sizes.size(dimension)) { + throw new IndexOutOfBoundsException("Index " + index + " outside of [0," + sizes.size(dimension) + ">"); + } + int diff = index - indexes[dimension]; + directIndex += getStride(dimension) * diff; + indexes[dimension] = index; + } + /** Retrieve the index that can be used for direct lookup in an indexed tensor. 
*/ + public long getDirectIndex() { return directIndex; } + /** returns the stride to be used for the given dimension */ + public long getStride(int dimension) { + return sizes.productOfDimensionsAfter(dimension); + } +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 1319675f5d4..93cdc3f630f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -93,6 +93,10 @@ public abstract class IndexedTensor implements Tensor { return get(toValueIndex(indexes, dimensionSizes)); } + public double get(DirectIndexedAddress address) { + return get(address.getDirectIndex()); + } + public DirectIndexedAddress directAddress() { return DirectIndexedAddress.of(dimensionSizes); } /** * Returns the value at the given indexes as a float * diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java index 0a6c821e64e..afc95d295f0 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java @@ -9,6 +9,7 @@ import java.util.Iterator; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -96,6 +97,38 @@ public class IndexedTensorTestCase { } @Test + public void testDirectIndexedAddress() { + TensorType type = new TensorType.Builder().indexed("v", 3) + .indexed("w", wSize) + .indexed("x", xSize) + .indexed("y", ySize) + .indexed("z", zSize) + .build(); + var directAddress = DirectIndexedAddress.of(DimensionSizes.of(type)); + assertThrows(ArrayIndexOutOfBoundsException.class, () -> directAddress.getStride(5)); + assertThrows(IndexOutOfBoundsException.class, () -> directAddress.setIndex(4, 7)); + 
assertEquals(wSize*xSize*ySize*zSize, directAddress.getStride(0)); + assertEquals(xSize*ySize*zSize, directAddress.getStride(1)); + assertEquals(ySize*zSize, directAddress.getStride(2)); + assertEquals(zSize, directAddress.getStride(3)); + assertEquals(1, directAddress.getStride(4)); + assertEquals(0, directAddress.getDirectIndex()); + directAddress.setIndex(0,1); + assertEquals(1 * directAddress.getStride(0), directAddress.getDirectIndex()); + directAddress.setIndex(1,1); + assertEquals(1 * directAddress.getStride(0) + 1 * directAddress.getStride(1), directAddress.getDirectIndex()); + directAddress.setIndex(2,2); + directAddress.setIndex(3,2); + directAddress.setIndex(4,2); + long expected = 1 * directAddress.getStride(0) + + 1 * directAddress.getStride(1) + + 2 * directAddress.getStride(2) + + 2 * directAddress.getStride(3) + + 2 * directAddress.getStride(4); + assertEquals(expected, directAddress.getDirectIndex()); + } + + @Test public void testUnboundBuilding() { TensorType type = new TensorType.Builder().indexed("w") .indexed("v") |