diff options
86 files changed, 1579 insertions, 781 deletions
diff --git a/client/js/app/yarn.lock b/client/js/app/yarn.lock index e9dc5bf25fe..ebf1c99db13 100644 --- a/client/js/app/yarn.lock +++ b/client/js/app/yarn.lock @@ -1311,70 +1311,70 @@ resolved "https://registry.yarnpkg.com/@remix-run/router/-/router-1.14.2.tgz#4d58f59908d9197ba3179310077f25c88e49ed17" integrity sha512-ACXpdMM9hmKZww21yEqWwiLws/UPLhNKvimN8RrYSqPSvB3ov7sLvAcfvaxePeLvccTQKGdkDIhLYApZVDFuKg== -"@rollup/rollup-android-arm-eabi@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.9.4.tgz#b1094962742c1a0349587040bc06185e2a667c9b" - integrity sha512-ub/SN3yWqIv5CWiAZPHVS1DloyZsJbtXmX4HxUTIpS0BHm9pW5iYBo2mIZi+hE3AeiTzHz33blwSnhdUo+9NpA== - -"@rollup/rollup-android-arm64@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.9.4.tgz#96eb86fb549e05b187f2ad06f51d191a23cb385a" - integrity sha512-ehcBrOR5XTl0W0t2WxfTyHCR/3Cq2jfb+I4W+Ch8Y9b5G+vbAecVv0Fx/J1QKktOrgUYsIKxWAKgIpvw56IFNA== - -"@rollup/rollup-darwin-arm64@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.9.4.tgz#2456630c007cc5905cb368acb9ff9fc04b2d37be" - integrity sha512-1fzh1lWExwSTWy8vJPnNbNM02WZDS8AW3McEOb7wW+nPChLKf3WG2aG7fhaUmfX5FKw9zhsF5+MBwArGyNM7NA== - -"@rollup/rollup-darwin-x64@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.9.4.tgz#97742214fc7dfd47a0f74efba6f5ae264e29c70c" - integrity sha512-Gc6cukkF38RcYQ6uPdiXi70JB0f29CwcQ7+r4QpfNpQFVHXRd0DfWFidoGxjSx1DwOETM97JPz1RXL5ISSB0pA== - -"@rollup/rollup-linux-arm-gnueabihf@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.9.4.tgz#cd933e61d6f689c9cdefde424beafbd92cfe58e2" - integrity sha512-g21RTeFzoTl8GxosHbnQZ0/JkuFIB13C3T7Y0HtKzOXmoHhewLbVTFBQZu+z5m9STH6FZ7L/oPgU4Nm5ErN2fw== - -"@rollup/rollup-linux-arm64-gnu@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.9.4.tgz#33b09bf462f1837afc1e02a1b352af6b510c78a6" - integrity sha512-TVYVWD/SYwWzGGnbfTkrNpdE4HON46orgMNHCivlXmlsSGQOx/OHHYiQcMIOx38/GWgwr/po2LBn7wypkWw/Mg== - -"@rollup/rollup-linux-arm64-musl@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.9.4.tgz#50257fb248832c2308064e3764a16273b6ee4615" - integrity sha512-XcKvuendwizYYhFxpvQ3xVpzje2HHImzg33wL9zvxtj77HvPStbSGI9czrdbfrf8DGMcNNReH9pVZv8qejAQ5A== - -"@rollup/rollup-linux-riscv64-gnu@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.9.4.tgz#09589e4e1a073cf56f6249b77eb6c9a8e9b613a8" - integrity sha512-LFHS/8Q+I9YA0yVETyjonMJ3UA+DczeBd/MqNEzsGSTdNvSJa1OJZcSH8GiXLvcizgp9AlHs2walqRcqzjOi3A== - -"@rollup/rollup-linux-x64-gnu@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.9.4.tgz#bd312bb5b5f02e54d15488605d15cfd3f90dda7c" - integrity sha512-dIYgo+j1+yfy81i0YVU5KnQrIJZE8ERomx17ReU4GREjGtDW4X+nvkBak2xAUpyqLs4eleDSj3RrV72fQos7zw== - -"@rollup/rollup-linux-x64-musl@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.9.4.tgz#25b3bede85d86438ce28cc642842d10d867d40e9" - integrity sha512-RoaYxjdHQ5TPjaPrLsfKqR3pakMr3JGqZ+jZM0zP2IkDtsGa4CqYaWSfQmZVgFUCgLrTnzX+cnHS3nfl+kB6ZQ== - -"@rollup/rollup-win32-arm64-msvc@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.9.4.tgz#95957067eb107f571da1d81939f017d37b4958d3" - integrity sha512-T8Q3XHV+Jjf5e49B4EAaLKV74BbX7/qYBRQ8Wop/+TyyU0k+vSjiLVSHNWdVd1goMjZcbhDmYZUYW5RFqkBNHQ== - -"@rollup/rollup-win32-ia32-msvc@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.9.4.tgz#71b6facad976db527863f698692c6964c0b6e10e" - integrity sha512-z+JQ7JirDUHAsMecVydnBPWLwJjbppU+7LZjffGf+Jvrxq+dVjIE7By163Sc9DKc3ADSU50qPVw0KonBS+a+HQ== - -"@rollup/rollup-win32-x64-msvc@4.9.4": - version "4.9.4" - resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.9.4.tgz#16295ccae354707c9bc6842906bdeaad4f3ba7a5" - integrity sha512-LfdGXCV9rdEify1oxlN9eamvDSjv9md9ZVMAbNHA87xqIfFCxImxan9qZ8+Un54iK2nnqPlbnSi4R54ONtbWBw== +"@rollup/rollup-android-arm-eabi@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.9.5.tgz#b752b6c88a14ccfcbdf3f48c577ccc3a7f0e66b9" + integrity sha512-idWaG8xeSRCfRq9KpRysDHJ/rEHBEXcHuJ82XY0yYFIWnLMjZv9vF/7DOq8djQ2n3Lk6+3qfSH8AqlmHlmi1MA== + +"@rollup/rollup-android-arm64@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.9.5.tgz#33757c3a448b9ef77b6f6292d8b0ec45c87e9c1a" + integrity sha512-f14d7uhAMtsCGjAYwZGv6TwuS3IFaM4ZnGMUn3aCBgkcHAYErhV1Ad97WzBvS2o0aaDv4mVz+syiN0ElMyfBPg== + +"@rollup/rollup-darwin-arm64@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.9.5.tgz#5234ba62665a3f443143bc8bcea9df2cc58f55fb" + integrity sha512-ndoXeLx455FffL68OIUrVr89Xu1WLzAG4n65R8roDlCoYiQcGGg6MALvs2Ap9zs7AHg8mpHtMpwC8jBBjZrT/w== + +"@rollup/rollup-darwin-x64@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.9.5.tgz#981256c054d3247b83313724938d606798a919d1" + integrity sha512-UmElV1OY2m/1KEEqTlIjieKfVwRg0Zwg4PLgNf0s3glAHXBN99KLpw5A5lrSYCa1Kp63czTpVll2MAqbZYIHoA== + +"@rollup/rollup-linux-arm-gnueabihf@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.9.5.tgz#120678a5a2b3a283a548dbb4d337f9187a793560" + integrity sha512-Q0LcU61v92tQB6ae+udZvOyZ0wfpGojtAKrrpAaIqmJ7+psq4cMIhT/9lfV6UQIpeItnq/2QDROhNLo00lOD1g== + +"@rollup/rollup-linux-arm64-gnu@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.9.5.tgz#c99d857e2372ece544b6f60b85058ad259f64114" + integrity sha512-dkRscpM+RrR2Ee3eOQmRWFjmV/payHEOrjyq1VZegRUa5OrZJ2MAxBNs05bZuY0YCtpqETDy1Ix4i/hRqX98cA== + +"@rollup/rollup-linux-arm64-musl@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.9.5.tgz#3064060f568a5718c2a06858cd6e6d24f2ff8632" + integrity sha512-QaKFVOzzST2xzY4MAmiDmURagWLFh+zZtttuEnuNn19AiZ0T3fhPyjPPGwLNdiDT82ZE91hnfJsUiDwF9DClIQ== + +"@rollup/rollup-linux-riscv64-gnu@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.9.5.tgz#987d30b5d2b992fff07d055015991a57ff55fbad" + integrity sha512-HeGqmRJuyVg6/X6MpE2ur7GbymBPS8Np0S/vQFHDmocfORT+Zt76qu+69NUoxXzGqVP1pzaY6QIi0FJWLC3OPA== + +"@rollup/rollup-linux-x64-gnu@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.9.5.tgz#85946ee4d068bd12197aeeec2c6f679c94978a49" + integrity sha512-Dq1bqBdLaZ1Gb/l2e5/+o3B18+8TI9ANlA1SkejZqDgdU/jK/ThYaMPMJpVMMXy2uRHvGKbkz9vheVGdq3cJfA== + +"@rollup/rollup-linux-x64-musl@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.9.5.tgz#fe0b20f9749a60eb1df43d20effa96c756ddcbd4" + integrity sha512-ezyFUOwldYpj7AbkwyW9AJ203peub81CaAIVvckdkyH8EvhEIoKzaMFJj0G4qYJ5sw3BpqhFrsCc30t54HV8vg== + +"@rollup/rollup-win32-arm64-msvc@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.9.5.tgz#422661ef0e16699a234465d15b2c1089ef963b2a" + integrity sha512-aHSsMnUw+0UETB0Hlv7B/ZHOGY5bQdwMKJSzGfDfvyhnpmVxLMGnQPGNE9wgqkLUs3+gbG1Qx02S2LLfJ5GaRQ== + +"@rollup/rollup-win32-ia32-msvc@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.9.5.tgz#7b73a145891c202fbcc08759248983667a035d85" + integrity sha512-AiqiLkb9KSf7Lj/o1U3SEP9Zn+5NuVKgFdRIZkvd4N0+bYrTOovVd0+LmYCPQGbocT4kvFyK+LXCDiXPBF3fyA== + +"@rollup/rollup-win32-x64-msvc@4.9.5": + version "4.9.5" + resolved "https://registry.yarnpkg.com/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.9.5.tgz#10491ccf4f63c814d4149e0316541476ea603602" + integrity sha512-1q+mykKE3Vot1kaFJIDoUFv5TuW+QQVaf2FmTT9krg86pQrGStOSJJ0Zil7CFagyxDuouTepzt5Y5TVzyajOdQ== "@sinclair/typebox@^0.27.8": version "0.27.8" @@ -4670,17 +4670,17 @@ react-refresh@^0.14.0: integrity sha512-wViHqhAd8OHeLS/IRMJjTSDHF3U9eWi62F/MledQGPdJGDhodXJ9PBLNGr6WWL7qlH12Mt3TyTpbS+hGXMjCzQ== react-router-dom@^6: - version "6.21.2" - resolved "https://registry.yarnpkg.com/react-router-dom/-/react-router-dom-6.21.2.tgz#5fba851731a194fa32c31990c4829c5e247f650a" - integrity sha512-tE13UukgUOh2/sqYr6jPzZTzmzc70aGRP4pAjG2if0IP3aUT+sBtAKUJh0qMh0zylJHGLmzS+XWVaON4UklHeg== + version "6.21.3" + resolved "https://registry.yarnpkg.com/react-router-dom/-/react-router-dom-6.21.3.tgz#ef3a7956a3699c7b82c21fcb3dbc63c313ed8c5d" + integrity sha512-kNzubk7n4YHSrErzjLK72j0B5i969GsuCGazRl3G6j1zqZBLjuSlYBdVdkDOgzGdPIffUOc9nmgiadTEVoq91g== dependencies: "@remix-run/router" "1.14.2" - react-router "6.21.2" + react-router "6.21.3" -react-router@6.21.2: - version "6.21.2" - resolved "https://registry.yarnpkg.com/react-router/-/react-router-6.21.2.tgz#8820906c609ae7e4e8f926cc8eb5ce161428b956" - integrity sha512-jJcgiwDsnaHIeC+IN7atO0XiSRCrOsQAHHbChtJxmgqG2IaYQXSnhqGb5vk2CU/wBQA12Zt+TkbuJjIn65gzbA== +react-router@6.21.3: + version "6.21.3" + resolved "https://registry.yarnpkg.com/react-router/-/react-router-6.21.3.tgz#8086cea922c2bfebbb49c6594967418f1f167d70" + integrity sha512-a0H638ZXULv1OdkmiK6s6itNhoy33ywxmUFT/xtSoVyf9VnC7n7+VT4LjVzdIHSaF5TIh9ylUgxMXksHTgGrKg== dependencies: "@remix-run/router" "1.14.2" @@ -4845,25 +4845,25 @@ rimraf@^3.0.2: glob "^7.1.3" rollup@^4.2.0: - version "4.9.4" - resolved "https://registry.yarnpkg.com/rollup/-/rollup-4.9.4.tgz#37bc0c09ae6b4538a9c974f4d045bb64b2e7c27c" - integrity sha512-2ztU7pY/lrQyXSCnnoU4ICjT/tCG9cdH3/G25ERqE3Lst6vl2BCM5hL2Nw+sslAvAf+ccKsAq1SkKQALyqhR7g== + version "4.9.5" + resolved "https://registry.yarnpkg.com/rollup/-/rollup-4.9.5.tgz#62999462c90f4c8b5d7c38fc7161e63b29101b05" + integrity sha512-E4vQW0H/mbNMw2yLSqJyjtkHY9dslf/p0zuT1xehNRqUTBOFMqEjguDvqhXr7N7r/4ttb2jr4T41d3dncmIgbQ== dependencies: "@types/estree" "1.0.5" optionalDependencies: - "@rollup/rollup-android-arm-eabi" "4.9.4" - "@rollup/rollup-android-arm64" "4.9.4" - "@rollup/rollup-darwin-arm64" "4.9.4" - "@rollup/rollup-darwin-x64" "4.9.4" - "@rollup/rollup-linux-arm-gnueabihf" "4.9.4" - "@rollup/rollup-linux-arm64-gnu" "4.9.4" - "@rollup/rollup-linux-arm64-musl" "4.9.4" - "@rollup/rollup-linux-riscv64-gnu" "4.9.4" - "@rollup/rollup-linux-x64-gnu" "4.9.4" - "@rollup/rollup-linux-x64-musl" "4.9.4" - "@rollup/rollup-win32-arm64-msvc" "4.9.4" - "@rollup/rollup-win32-ia32-msvc" "4.9.4" - "@rollup/rollup-win32-x64-msvc" "4.9.4" + "@rollup/rollup-android-arm-eabi" "4.9.5" + "@rollup/rollup-android-arm64" "4.9.5" + "@rollup/rollup-darwin-arm64" "4.9.5" + "@rollup/rollup-darwin-x64" "4.9.5" + "@rollup/rollup-linux-arm-gnueabihf" "4.9.5" + "@rollup/rollup-linux-arm64-gnu" "4.9.5" + "@rollup/rollup-linux-arm64-musl" "4.9.5" + "@rollup/rollup-linux-riscv64-gnu" "4.9.5" + "@rollup/rollup-linux-x64-gnu" "4.9.5" + "@rollup/rollup-linux-x64-musl" "4.9.5" + "@rollup/rollup-win32-arm64-msvc" "4.9.5" + "@rollup/rollup-win32-ia32-msvc" "4.9.5" + "@rollup/rollup-win32-x64-msvc" "4.9.5" fsevents "~2.3.2" rsvp@^4.8.4: @@ -5474,9 +5474,9 @@ v8-to-istanbul@^9.0.1: convert-source-map "^1.6.0" vite@^5.0.5: - version "5.0.11" - resolved "https://registry.yarnpkg.com/vite/-/vite-5.0.11.tgz#31562e41e004cb68e1d51f5d2c641ab313b289e4" - integrity sha512-XBMnDjZcNAw/G1gEiskiM1v6yzM4GE5aMGvhWTlHAYYhxb7S3/V1s3m2LDHa8Vh6yIWYYB0iJwsEaS523c4oYA== + version "5.0.12" + resolved "https://registry.yarnpkg.com/vite/-/vite-5.0.12.tgz#8a2ffd4da36c132aec4adafe05d7adde38333c47" + integrity sha512-4hsnEkG3q0N4Tzf1+t6NdN9dg/L3BM+q8SWgbSPnJvrgH2kgdyzfVJwbR1ic69/4uMJJ/3dqDZZE5/WwqW8U1w== dependencies: esbuild "^0.19.3" postcss "^8.4.32" diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index 502b054f84e..9b3e236612a 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -615,18 +615,6 @@ public class RankProfile implements Cloneable { .orElse(Set.of()); } - private void addSummaryFeature(ReferenceNode feature) { - if (summaryFeatures == null) - summaryFeatures = new LinkedHashSet<>(); - summaryFeatures.add(feature); - } - - private void addMatchFeature(ReferenceNode feature) { - if (matchFeatures == null) - matchFeatures = new LinkedHashSet<>(); - matchFeatures.add(feature); - } - private void addImplicitMatchFeatures(List<FeatureList> list) { if (hiddenMatchFeatures == null) hiddenMatchFeatures = new LinkedHashSet<>(); @@ -642,15 +630,19 @@ public class RankProfile implements Cloneable { /** Adds the content of the given feature list to the internal list of summary features. */ public void addSummaryFeatures(FeatureList features) { + if (summaryFeatures == null) + summaryFeatures = new LinkedHashSet<>(); for (ReferenceNode feature : features) { - addSummaryFeature(feature); + summaryFeatures.add(feature); } } /** Adds the content of the given feature list to the internal list of match features. */ public void addMatchFeatures(FeatureList features) { + if (matchFeatures == null) + matchFeatures = new LinkedHashSet<>(); for (ReferenceNode feature : features) { - addMatchFeature(feature); + matchFeatures.add(feature); } } @@ -661,20 +653,16 @@ public class RankProfile implements Cloneable { .orElse(Set.of()); } - private void addRankFeature(ReferenceNode feature) { - if (rankFeatures == null) - rankFeatures = new LinkedHashSet<>(); - rankFeatures.add(feature); - } - /** * Adds the content of the given feature list to the internal list of rank features. * * @param features The features to add. */ public void addRankFeatures(FeatureList features) { + if (rankFeatures == null) + rankFeatures = new LinkedHashSet<>(); for (ReferenceNode feature : features) { - addRankFeature(feature); + rankFeatures.add(feature); } } diff --git a/config-model/src/test/derived/rankprofileinheritance/child.sd b/config-model/src/test/derived/rankprofileinheritance/child.sd index 2517d0731f5..8348a62838c 100644 --- a/config-model/src/test/derived/rankprofileinheritance/child.sd +++ b/config-model/src/test/derived/rankprofileinheritance/child.sd @@ -39,4 +39,15 @@ schema child { } + rank-profile profile5 inherits profile1 { + match-features { + attribute(field3) + } + } + + rank-profile profile6 inherits profile1 { + summary-features { } + match-features { } + } + } diff --git a/config-model/src/test/derived/rankprofileinheritance/rank-profiles.cfg b/config-model/src/test/derived/rankprofileinheritance/rank-profiles.cfg index a3bc6791412..ccf52da3b5e 100644 --- a/config-model/src/test/derived/rankprofileinheritance/rank-profiles.cfg +++ b/config-model/src/test/derived/rankprofileinheritance/rank-profiles.cfg @@ -52,3 +52,23 @@ rankprofile[].fef.property[].name "vespa.feature.rename" rankprofile[].fef.property[].value "rankingExpression(function4)" rankprofile[].fef.property[].name "vespa.feature.rename" rankprofile[].fef.property[].value "function4" +rankprofile[].name "profile5" +rankprofile[].fef.property[].name "rankingExpression(function1).rankingScript" +rankprofile[].fef.property[].value "attribute(field1) + 5" +rankprofile[].fef.property[].name "rankingExpression(function1b).rankingScript" +rankprofile[].fef.property[].value "attribute(field1) + 42" +rankprofile[].fef.property[].name "vespa.summary.feature" +rankprofile[].fef.property[].value "attribute(field1)" +rankprofile[].fef.property[].name "vespa.summary.feature" +rankprofile[].fef.property[].value "rankingExpression(function1)" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "attribute(field3)" +rankprofile[].fef.property[].name "vespa.feature.rename" +rankprofile[].fef.property[].value "rankingExpression(function1)" +rankprofile[].fef.property[].name "vespa.feature.rename" +rankprofile[].fef.property[].value "function1" +rankprofile[].name "profile6" +rankprofile[].fef.property[].name "rankingExpression(function1).rankingScript" +rankprofile[].fef.property[].value "attribute(field1) + 5" +rankprofile[].fef.property[].name "rankingExpression(function1b).rankingScript" +rankprofile[].fef.property[].value "attribute(field1) + 42" diff --git a/config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java b/config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java index c959634019d..e920672646f 100644 --- a/config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java @@ -1,11 +1,17 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema; +import com.yahoo.document.Document; import com.yahoo.schema.document.Stemming; import com.yahoo.schema.parser.ParseException; import com.yahoo.schema.processing.ImportedFieldsResolver; import com.yahoo.schema.processing.OnnxModelTypeResolver; import com.yahoo.vespa.documentmodel.DocumentSummary; +import com.yahoo.vespa.indexinglanguage.expressions.AttributeExpression; +import com.yahoo.vespa.indexinglanguage.expressions.Expression; +import com.yahoo.vespa.indexinglanguage.expressions.InputExpression; +import com.yahoo.vespa.indexinglanguage.expressions.ScriptExpression; +import com.yahoo.vespa.indexinglanguage.expressions.StatementExpression; import com.yahoo.vespa.model.test.utils.DeployLoggerStub; import org.junit.jupiter.api.Test; diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java index cac297f061a..32f4d2b653c 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java @@ -43,6 +43,7 @@ import com.yahoo.slime.Slime; import com.yahoo.transaction.NestedTransaction; import com.yahoo.transaction.Transaction; import com.yahoo.vespa.applicationmodel.InfrastructureApplication; +import com.yahoo.vespa.config.server.application.ActiveTokenFingerprints; import com.yahoo.vespa.config.server.application.ActiveTokenFingerprints.Token; import com.yahoo.vespa.config.server.application.ActiveTokenFingerprintsClient; import com.yahoo.vespa.config.server.application.Application; @@ -54,7 +55,6 @@ import com.yahoo.vespa.config.server.application.ClusterReindexing; import com.yahoo.vespa.config.server.application.ClusterReindexingStatusClient; import com.yahoo.vespa.config.server.application.CompressedApplicationInputStream; import com.yahoo.vespa.config.server.application.ConfigConvergenceChecker; -import com.yahoo.vespa.config.server.application.ActiveTokenFingerprints; import com.yahoo.vespa.config.server.application.DefaultClusterReindexingStatusClient; import com.yahoo.vespa.config.server.application.FileDistributionStatus; import com.yahoo.vespa.config.server.application.HttpProxy; @@ -111,7 +111,6 @@ import java.time.Duration; import java.time.Instant; import java.util.Collection; import java.util.Comparator; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -135,7 +134,6 @@ import static com.yahoo.vespa.config.server.tenant.TenantRepository.HOSTED_VESPA import static com.yahoo.vespa.curator.Curator.CompletionWaiter; import static com.yahoo.yolean.Exceptions.uncheck; import static java.nio.file.Files.readAttributes; -import static java.util.stream.Collectors.toMap; /** * The API for managing applications. @@ -544,6 +542,7 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye static void checkIfActiveHasChanged(Session session, Session activeSession, boolean ignoreStaleSessionFailure) { long activeSessionAtCreate = session.getActiveSessionAtCreate(); log.log(Level.FINE, () -> activeSession.logPre() + "active session id at create time=" + activeSessionAtCreate); + if (activeSessionAtCreate == 0) return; // No active session at create time, or session created for indeterminate app. long sessionId = session.getSessionId(); long activeSessionSessionId = activeSession.getSessionId(); @@ -554,7 +553,7 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye String errMsg = activeSession.logPre() + "Cannot activate session " + sessionId + " because the currently active session (" + activeSessionSessionId + ") has changed since session " + sessionId + " was created (was " + - (activeSessionAtCreate == 0 ? "empty" : activeSessionAtCreate) + " at creation time)"; + activeSessionAtCreate + " at creation time)"; if (ignoreStaleSessionFailure) { log.warning(errMsg + " (Continuing because of force.)"); } else { @@ -956,30 +955,25 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye } public void deleteExpiredLocalSessions() { - Map<Tenant, Collection<LocalSession>> sessionsPerTenant = new HashMap<>(); - tenantRepository.getAllTenants() - .forEach(tenant -> sessionsPerTenant.put(tenant, tenant.getSessionRepository().getLocalSessions())); - - Set<ApplicationId> applicationIds = new HashSet<>(); - sessionsPerTenant.values() - .forEach(sessionList -> sessionList.stream() - .map(Session::getOptionalApplicationId) - .filter(Optional::isPresent) - .forEach(appId -> applicationIds.add(appId.get()))); - - Map<ApplicationId, Long> activeSessions = new HashMap<>(); - applicationIds.forEach(applicationId -> getActiveSession(applicationId).ifPresent(session -> activeSessions.put(applicationId, session.getSessionId()))); - sessionsPerTenant.keySet().forEach(tenant -> tenant.getSessionRepository().deleteExpiredSessions(activeSessions)); + for (Tenant tenant : tenantRepository.getAllTenants()) { + tenant.getSessionRepository().deleteExpiredSessions(session -> sessionIsActiveForItsApplication(tenant, session)); + } } public int deleteExpiredRemoteSessions(Clock clock) { return tenantRepository.getAllTenants() .stream() - .map(tenant -> tenant.getSessionRepository().deleteExpiredRemoteSessions(clock)) + .map(tenant -> tenant.getSessionRepository().deleteExpiredRemoteSessions(clock, session -> sessionIsActiveForItsApplication(tenant, session))) .mapToInt(i -> i) .sum(); } + private boolean sessionIsActiveForItsApplication(Tenant tenant, Session session) { + Optional<ApplicationId> owner = session.getOptionalApplicationId(); + if (owner.isEmpty()) return true; // Chicken out ~(˘▾˘)~ + return tenant.getApplicationRepo().activeSessionOf(owner.get()).equals(Optional.of(session.getSessionId())); + } + // ---------------- Tenant operations ---------------------------------------------------------------- diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java index dcc5d7caa0d..4f10b1215cf 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.config.server.maintenance; import com.yahoo.config.FileReference; import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.TenantName; import com.yahoo.config.subscription.ConfigSourceSet; import com.yahoo.jrt.Supervisor; import com.yahoo.jrt.Transport; @@ -19,8 +20,10 @@ import com.yahoo.vespa.filedistribution.FileReferenceDownload; import java.io.File; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.concurrent.Future; import java.util.logging.Logger; import static com.yahoo.vespa.config.server.filedistribution.FileDistributionUtil.fileReferenceExistsOnDisk; @@ -51,39 +54,60 @@ public class ApplicationPackageMaintainer extends ConfigServerMaintainer { @Override protected double maintain() { int attempts = 0; - int failures = 0; - - for (var applicationId : applicationRepository.listApplications()) { - if (shuttingDown()) - break; - - log.finest(() -> "Verifying application package for " + applicationId); - Optional<Session> session = applicationRepository.getActiveSession(applicationId); - if (session.isEmpty()) continue; // App might be deleted after call to listApplications() or not activated yet (bootstrap phase) - - Optional<FileReference> appFileReference = session.get().getApplicationPackageReference(); - if (appFileReference.isPresent()) { - long sessionId = session.get().getSessionId(); - attempts++; - if (!fileReferenceExistsOnDisk(downloadDirectory, appFileReference.get())) { - log.fine(() -> "Downloading application package with file reference " + appFileReference + - " for " + applicationId + " (session " + sessionId + ")"); - - FileReferenceDownload download = new FileReferenceDownload(appFileReference.get(), - this.getClass().getSimpleName(), - false); - if (fileDownloader.getFile(download).isEmpty()) { - failures++; - log.info("Downloading application package (" + appFileReference + ")" + - " for " + applicationId + " (session " + sessionId + ") unsuccessful. " + - "Can be ignored unless it happens many times over a long period of time, retries is expected"); + int[] failures = new int[1]; + + List<Runnable> futureDownloads = new ArrayList<>(); + for (TenantName tenantName : applicationRepository.tenantRepository().getAllTenantNames()) { + for (Session session : applicationRepository.tenantRepository().getTenant(tenantName).getSessionRepository().getRemoteSessions()) { + if (shuttingDown()) + break; + + switch (session.getStatus()) { + case PREPARE, ACTIVATE: + break; + default: continue; + } + + var applicationId = session.getApplicationId(); + log.finest(() -> "Verifying application package for " + applicationId); + + Optional<FileReference> appFileReference = session.getApplicationPackageReference(); + if (appFileReference.isPresent()) { + long sessionId = session.getSessionId(); + attempts++; + if (!fileReferenceExistsOnDisk(downloadDirectory, appFileReference.get())) { + log.fine(() -> "Downloading application package with file reference " + appFileReference + + " for " + applicationId + " (session " + sessionId + ")"); + + FileReferenceDownload download = new FileReferenceDownload(appFileReference.get(), + this.getClass().getSimpleName(), + false); + Future<Optional<File>> futureDownload = fileDownloader.getFutureFileOrTimeout(download); + futureDownloads.add(() -> { + try { + if (futureDownload.get().isPresent()) { + createLocalSessionIfMissing(applicationId, sessionId); + return; + } + } + catch (Exception ignored) { } + failures[0]++; + log.info("Downloading application package (" + appFileReference + ")" + + " for " + applicationId + " (session " + sessionId + ") unsuccessful. " + + "Can be ignored unless it happens many times over a long period of time, retries is expected"); + }); + } + else { + createLocalSessionIfMissing(applicationId, sessionId); } } - createLocalSessionIfMissing(applicationId, sessionId); } } - return asSuccessFactorDeviation(attempts, failures); + + futureDownloads.forEach(Runnable::run); + + return asSuccessFactorDeviation(attempts, failures[0]); } private static FileDownloader createFileDownloader(ApplicationRepository applicationRepository, @@ -92,7 +116,7 @@ public class ApplicationPackageMaintainer extends ConfigServerMaintainer { List<String> otherConfigServersInCluster = getOtherConfigServersInCluster(applicationRepository.configserverConfig()); ConfigSourceSet configSourceSet = new ConfigSourceSet(otherConfigServersInCluster); ConnectionPool connectionPool = new FileDistributionConnectionPool(configSourceSet, supervisor); - return new FileDownloader(connectionPool, supervisor, downloadDirectory, Duration.ofSeconds(300)); + return new FileDownloader(connectionPool, supervisor, downloadDirectory, Duration.ofSeconds(60)); } @Override diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java index 52c11ed0e93..2f0d8b4065d 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java @@ -79,6 +79,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; import java.util.logging.Level; import java.util.logging.Logger; @@ -369,7 +370,7 @@ public class SessionRepository { return session; } - public int deleteExpiredRemoteSessions(Clock clock) { + public int deleteExpiredRemoteSessions(Clock clock, Predicate<Session> sessionIsActiveForApplication) { Duration expiryTime = Duration.ofSeconds(expiryTimeFlag.value()); List<Long> remoteSessionsFromZooKeeper = getRemoteSessionsFromZooKeeper(); log.log(Level.FINE, () -> "Remote sessions for tenant " + tenantName + ": " + remoteSessionsFromZooKeeper); @@ -377,11 +378,11 @@ public class SessionRepository { int deleted = 0; // Avoid deleting too many in one run int deleteMax = (int) Math.min(1000, Math.max(50, remoteSessionsFromZooKeeper.size() * 0.05)); - for (long sessionId : remoteSessionsFromZooKeeper) { + for (Long sessionId : remoteSessionsFromZooKeeper) { Session session = remoteSessionCache.get(sessionId); if (session == null) session = new RemoteSession(tenantName, sessionId, createSessionZooKeeperClient(sessionId)); - if (session.getStatus() == Session.Status.ACTIVATE) continue; + if (session.getStatus() == Session.Status.ACTIVATE && sessionIsActiveForApplication.test(session)) continue; if (sessionHasExpired(session.getCreateTime(), expiryTime, clock)) { log.log(Level.FINE, () -> "Remote session " + sessionId + " for " + tenantName + " has expired, deleting it"); deleteRemoteSessionFromZooKeeper(session); @@ -616,7 +617,7 @@ public class SessionRepository { // ---------------- Common stuff ---------------------------------------------------------------- - public void deleteExpiredSessions(Map<ApplicationId, Long> activeSessions) { + public void deleteExpiredSessions(Predicate<Session> sessionIsActiveForApplication) { log.log(Level.FINE, () -> "Deleting expired local sessions for tenant '" + tenantName + "'"); Set<Long> sessionIdsToDelete = new HashSet<>(); Set<Long> newSessions = findNewSessionsInFileSystem(); @@ -650,8 +651,7 @@ public class SessionRepository { Optional<ApplicationId> applicationId = session.getOptionalApplicationId(); if (applicationId.isEmpty()) continue; - Long activeSession = activeSessions.get(applicationId.get()); - if (activeSession == null || activeSession != sessionId) { + if ( ! sessionIsActiveForApplication.test(session)) { sessionIdsToDelete.add(sessionId); log.log(Level.FINE, () -> "Will delete inactive session " + sessionId + " created " + createTime + " for '" + applicationId + "'"); diff --git a/configserver/src/main/resources/configserver-app/services.xml b/configserver/src/main/resources/configserver-app/services.xml index 2e666089152..069b7ffc496 100644 --- a/configserver/src/main/resources/configserver-app/services.xml +++ b/configserver/src/main/resources/configserver-app/services.xml @@ -111,6 +111,7 @@ </handler> <handler id='com.yahoo.vespa.config.server.http.v2.ApplicationApiHandler' bundle='configserver'> <binding>http://*/application/v2/tenant/*/prepareandactivate</binding> + <binding>http://*/application/v2/tenant/*/prepareandactivate/*</binding> </handler> <handler id='com.yahoo.vespa.config.server.http.v2.SessionContentHandler' bundle='configserver'> <binding>http://*/application/v2/tenant/*/session/*/content/*</binding> diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/filedistribution/FileServerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/filedistribution/FileServerTest.java index 891284a3a0e..e0a58888109 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/filedistribution/FileServerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/filedistribution/FileServerTest.java @@ -206,7 +206,7 @@ public class FileServerTest { super(FileDownloader.emptyConnectionPool(), new Supervisor(new Transport("mock")).setDropEmptyBuffers(true), downloadDirectory, - Duration.ofMillis(100), + Duration.ofMillis(1000), Duration.ofMillis(100)); } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandlerTest.java index 36892860295..d4aa0676c4f 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandlerTest.java @@ -212,15 +212,34 @@ class ApplicationApiHandlerTest { @Test void testActivationFailuresAndRetries() throws Exception { - // Prepare session 2, but fail on hosts; this session will be activated later. - provisioner.activationFailure(new ParentHostUnavailableException("host still booting")); + // Prepare session 2, and activate it successfully. verifyResponse(post(minimalPrepareParams, zip(appPackage), Map.of()), 200, """ { "log": [ ], - "message": "Session 2 for tenant 'test' prepared, but activation failed: host still booting", + "message": "Session 2 for tenant 'test' prepared and activated.", "session-id": "2", + "activated": true, + "tenant": "test", + "url": "http://host:123/application/v2/tenant/test/application/default/environment/prod/region/default/instance/default", + "configChangeActions": { + "restart": [ ], + "refeed": [ ], + "reindex": [ ] + } + } + """); + + // Prepare session 3, but fail on hosts; this session will be activated later. + provisioner.activationFailure(new ParentHostUnavailableException("host still booting")); + verifyResponse(post(minimalPrepareParams, zip(appPackage), Map.of()), + 200, + """ + { + "log": [ ], + "message": "Session 3 for tenant 'test' prepared, but activation failed: host still booting", + "session-id": "3", "activated": false, "tenant": "test", "url": "http://host:123/application/v2/tenant/test/application/default/environment/prod/region/default/instance/default", @@ -232,15 +251,15 @@ class ApplicationApiHandlerTest { } """); - // Prepare session 3, but fail on lock; this session will become outdated later. + // Prepare session 4, but fail on lock; this session will become outdated later. provisioner.activationFailure(new ApplicationLockException("lock timeout")); verifyResponse(post(minimalPrepareParams, zip(appPackage), Map.of()), 200, """ { "log": [ ], - "message": "Session 3 for tenant 'test' prepared, but activation failed: lock timeout", - "session-id": "3", + "message": "Session 4 for tenant 'test' prepared, but activation failed: lock timeout", + "session-id": "4", "activated": false, "tenant": "test", "url": "http://host:123/application/v2/tenant/test/application/default/environment/prod/region/default/instance/default", @@ -263,9 +282,9 @@ class ApplicationApiHandlerTest { } """); - // Retry only activation of session 2, but fail again with hosts. + // Retry only activation of session 3, but fail again with hosts. provisioner.activationFailure(new ParentHostUnavailableException("host still booting")); - verifyResponse(put(2, Map.of()), + verifyResponse(put(3, Map.of()), 409, """ { @@ -274,9 +293,9 @@ class ApplicationApiHandlerTest { } """); - // Retry only activation of session 2, but fail again with lock. + // Retry only activation of session 3, but fail again with lock. provisioner.activationFailure(new ApplicationLockException("lock timeout")); - verifyResponse(put(2, Map.of()), + verifyResponse(put(3, Map.of()), 500, """ { @@ -285,33 +304,33 @@ class ApplicationApiHandlerTest { } """); - // Retry only activation of session 2, and succeed! + // Retry only activation of session 3, and succeed! provisioner.activationFailure(null); - verifyResponse(put(2, Map.of()), + verifyResponse(put(3, Map.of()), 200, """ { - "message": "Session 2 for test.default.default activated" + "message": "Session 3 for test.default.default activated" } """); - // Retry only activation of session 3, but fail because it is now based on an outdated session. - verifyResponse(put(3, Map.of()), + // Retry only activation of session 4, but fail because it is now based on an outdated session. + verifyResponse(put(4, Map.of()), 409, """ { "error-code": "ACTIVATION_CONFLICT", - "message": "app:test.default.default Cannot activate session 3 because the currently active session (2) has changed since session 3 was created (was empty at creation time)" + "message": "app:test.default.default Cannot activate session 4 because the currently active session (3) has changed since session 4 was created (was 2 at creation time)" } """); - // Retry activation of session 2 again, and fail. - verifyResponse(put(2, Map.of()), + // Retry activation of session 3 again, and fail. + verifyResponse(put(3, Map.of()), 400, """ { "error-code": "BAD_REQUEST", - "message": "app:test.default.default Session 2 is already active" + "message": "app:test.default.default Session 3 is already active" } """); } diff --git a/container-search/src/main/java/com/yahoo/prelude/query/parser/AllParser.java b/container-search/src/main/java/com/yahoo/prelude/query/parser/AllParser.java index 0010291de66..43bd175b348 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/parser/AllParser.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/parser/AllParser.java @@ -11,6 +11,7 @@ import com.yahoo.prelude.query.OrItem; import com.yahoo.prelude.query.PhraseItem; import com.yahoo.prelude.query.QueryCanonicalizer; import com.yahoo.prelude.query.RankItem; +import com.yahoo.prelude.query.TrueItem; import com.yahoo.prelude.query.WeakAndItem; import com.yahoo.search.query.QueryTree; import com.yahoo.search.query.parser.ParserEnvironment; @@ -79,8 +80,8 @@ public class AllParser extends SimpleParser { // Combine the items Item topLevel = and; - if (not != null && topLevel != null) { - not.setPositiveItem(topLevel); + if (not != null) { + not.setPositiveItem(topLevel != null ? topLevel : new TrueItem()); topLevel = not; } @@ -130,6 +131,7 @@ public class AllParser extends SimpleParser { if ( ! tokens.skip(MINUS)) return null; if (tokens.currentIsNoIgnore(SPACE)) return null; var itemAndExplicitIndex = indexableItem(); + item = itemAndExplicitIndex.getFirst(); boolean explicitIndex = itemAndExplicitIndex.getSecond(); if (item == null) { diff --git a/container-search/src/main/java/com/yahoo/prelude/query/parser/AnyParser.java b/container-search/src/main/java/com/yahoo/prelude/query/parser/AnyParser.java index efc804fcf1f..1bbc21768b5 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/parser/AnyParser.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/parser/AnyParser.java @@ -12,6 +12,7 @@ import com.yahoo.prelude.query.OrItem; import com.yahoo.prelude.query.PhraseItem; import com.yahoo.prelude.query.RankItem; import com.yahoo.prelude.query.TermItem; +import com.yahoo.prelude.query.TrueItem; import com.yahoo.search.query.parser.ParserEnvironment; import java.util.Iterator; @@ -106,9 +107,8 @@ public class AnyParser extends SimpleParser { } return rank; } else if ((topLevelItem instanceof RankItem) - && (item instanceof RankItem) + && (item instanceof RankItem itemAsRank) && (((RankItem) item).getItem(0) instanceof OrItem)) { - RankItem itemAsRank = (RankItem) item; OrItem or = (OrItem) itemAsRank.getItem(0); ((RankItem) topLevelItem).addItem(0, or); @@ -139,8 +139,10 @@ public class AnyParser extends SimpleParser { if (root instanceof PhraseItem) { root.setFilter(true); } - for (Iterator<Item> i = ((CompositeItem) root).getItemIterator(); i.hasNext();) { - markAllTermsAsFilters(i.next()); + if (root instanceof CompositeItem composite) { + for (Iterator<Item> i = composite.getItemIterator(); i.hasNext(); ) { + markAllTermsAsFilters(i.next()); + } } } } @@ -206,8 +208,7 @@ public class AnyParser extends SimpleParser { return root; } - if (root instanceof RankItem) { - RankItem rootAsRank = (RankItem) root; + if (root instanceof RankItem rootAsRank) { Item firstChild = rootAsRank.getItem(0); if (firstChild instanceof NotItem) { @@ -228,7 +229,6 @@ public class AnyParser extends SimpleParser { } NotItem not = new NotItem(); - not.addPositiveItem(root); not.addNegativeItem(item); return not; diff --git a/container-search/src/main/java/com/yahoo/prelude/query/parser/SimpleParser.java b/container-search/src/main/java/com/yahoo/prelude/query/parser/SimpleParser.java index deab2be9d00..ea0cd2312a6 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/parser/SimpleParser.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/parser/SimpleParser.java @@ -130,14 +130,13 @@ abstract class SimpleParser extends StructuredParser { } } if (not != null && not.getPositiveItem() instanceof TrueItem) { - // Incomplete not, only negatives - - + // Incomplete not, only negatives - simplify when possible if (topLevelItem != null && topLevelItem != not) { // => neutral rank items becomes implicit positives not.addPositiveItem(getItemAsPositiveItem(topLevelItem, not)); return not; - } else { // Only negatives - ignore them - return null; + } else { + return not; } } if (topLevelItem != null) { diff --git a/container-search/src/main/java/com/yahoo/prelude/query/parser/WebParser.java b/container-search/src/main/java/com/yahoo/prelude/query/parser/WebParser.java index 06ea583c53f..75396a8714f 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/parser/WebParser.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/parser/WebParser.java @@ -6,6 +6,7 @@ import com.yahoo.prelude.query.CompositeItem; import com.yahoo.prelude.query.Item; import com.yahoo.prelude.query.NotItem; import com.yahoo.prelude.query.OrItem; +import com.yahoo.prelude.query.TrueItem; import com.yahoo.prelude.query.WordItem; import com.yahoo.search.query.parser.ParserEnvironment; @@ -69,8 +70,8 @@ public class WebParser extends AllParser { if (or != null) topLevel = or; - if (not != null && topLevel != null) { - not.setPositiveItem(topLevel); + if (not != null) { + not.setPositiveItem(topLevel != null ? topLevel : new TrueItem()); topLevel = not; } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java index bef766e7ef9..70f6e405a92 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java @@ -15,6 +15,7 @@ public class ConversionContext { private final String destination; private final CompiledQueryProfileRegistry registry; private final Map<String, Embedder> embedders; + private final Map<String, String> contextValues; private final Language language; public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Embedder embedder, @@ -30,6 +31,7 @@ public class ConversionContext { this.embedders = embedders; this.language = context.containsKey("language") ? Language.fromLanguageTag(context.get("language")) : Language.UNKNOWN; + this.contextValues = context; } /** Returns the local name of the field which will receive the converted value (or null when this is empty) */ @@ -44,6 +46,9 @@ public class ConversionContext { /** Returns the language, which is never null but may be UNKNOWN */ Language language() { return language; } + /** Returns a read-only map of context key-values which can be looked up during conversion. */ + Map<String,String> contextValues() { return contextValues; } + /** Returns an empty context */ public static ConversionContext empty() { return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of()); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index cfadd79de8f..e16f8e7b0cd 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -48,7 +48,8 @@ public class TensorFieldType extends FieldType { @Override public Object convertFrom(Object o, ConversionContext context) { if (o instanceof SubstituteString) return new SubstituteStringTensor((SubstituteString) o, type); - return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o, context.language()); + return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o, + context.language(), context.contextValues()); } public static TensorFieldType fromTypeString(String s) { diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java index c9f935e5f52..25a5c277dce 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java @@ -44,7 +44,8 @@ public class RankProfileInputProperties extends Properties { value = tensorConverter.convertTo(expectedType, name.last(), value, - query.getModel().getLanguage()); + query.getModel().getLanguage(), + context); } } catch (IllegalArgumentException e) { diff --git a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java index 6da53ae699c..94f92c7fd48 100644 --- a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java +++ b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java @@ -19,7 +19,8 @@ import java.util.regex.Pattern; */ public class TensorConverter { - private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])"); + private static final Pattern embedderArgumentAndQuotedTextRegexp = Pattern.compile("^([A-Za-z0-9_@\\-.]+),\\s*([\"'].*[\"'])"); + private static final Pattern embedderArgumentAndReferenceRegexp = Pattern.compile("^([A-Za-z0-9_@\\-.]+),\\s*(@.*)"); private final Map<String, Embedder> embedders; @@ -27,8 +28,9 @@ public class TensorConverter { this.embedders = embedders; } - public Tensor convertTo(TensorType type, String key, Object value, Language language) { - var context = new Embedder.Context(key).setLanguage(language); + public Tensor convertTo(TensorType type, String key, Object value, Language language, + Map<String, String> contextValues) { + var context = new Embedder.Context(key).setLanguage(language).setContextValues(contextValues); Tensor tensor = toTensor(type, value, context); if (tensor == null) return null; if (! tensor.type().isAssignableTo(type)) @@ -55,16 +57,16 @@ public class TensorConverter { String embedderId; // Check if arguments specifies an embedder with the format embed(embedder, "text to encode") - Matcher matcher = embedderArgumentRegexp.matcher(argument); - if (matcher.matches()) { + Matcher matcher; + if (( matcher = embedderArgumentAndQuotedTextRegexp.matcher(argument)).matches()) { embedderId = matcher.group(1); + embedder = requireEmbedder(embedderId); argument = matcher.group(2); - if ( ! embedders.containsKey(embedderId)) { - throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + - "Valid embedders are " + validEmbedders(embedders)); - } - embedder = embedders.get(embedderId); - } else if (embedders.size() == 0) { + } else if (( matcher = embedderArgumentAndReferenceRegexp.matcher(argument)).matches()) { + embedderId = matcher.group(1); + embedder = requireEmbedder(embedderId); + argument = matcher.group(2); + } else if (embedders.isEmpty()) { throw new IllegalStateException("No embedders provided"); // should never happen } else if (embedders.size() > 1) { throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " + @@ -74,19 +76,35 @@ public class TensorConverter { embedderId = entry.getKey(); embedder = entry.getValue(); } - return embedder.embed(removeQuotes(argument), embedderContext.copy().setEmbedderId(embedderId), type); + return embedder.embed(resolve(argument, embedderContext), embedderContext.copy().setEmbedderId(embedderId), type); } - private static String removeQuotes(String s) { - if (s.startsWith("'") && s.endsWith("'")) { + private Embedder requireEmbedder(String embedderId) { + if ( ! embedders.containsKey(embedderId)) + throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + + "Valid embedders are " + validEmbedders(embedders)); + return embedders.get(embedderId); + } + + private static String resolve(String s, Embedder.Context embedderContext) { + if (s.startsWith("'") && s.endsWith("'")) return s.substring(1, s.length() - 1); - } - if (s.startsWith("\"") && s.endsWith("\"")) { + if (s.startsWith("\"") && s.endsWith("\"")) return s.substring(1, s.length() - 1); - } + if (s.startsWith("@")) + return resolveReference(s, embedderContext); return s; } + private static String resolveReference(String s, Embedder.Context embedderContext) { + String referenceKey = s.substring(1); + String referencedValue = embedderContext.getContextValues().get(referenceKey); + if (referencedValue == null) + throw new IllegalArgumentException("Could not resolve query parameter reference '" + referenceKey + + "' used in an embed() argument"); + return referencedValue; + } + private static String validEmbedders(Map<String, Embedder> embedders) { List<String> embedderIds = new ArrayList<>(); embedders.forEach((key, value) -> embedderIds.add(key)); diff --git a/container-search/src/test/java/com/yahoo/prelude/query/parser/test/ParseTestCase.java b/container-search/src/test/java/com/yahoo/prelude/query/parser/test/ParseTestCase.java index 4383be184fa..156c34e5005 100644 --- a/container-search/src/test/java/com/yahoo/prelude/query/parser/test/ParseTestCase.java +++ b/container-search/src/test/java/com/yahoo/prelude/query/parser/test/ParseTestCase.java @@ -21,6 +21,7 @@ import com.yahoo.prelude.query.RankItem; import com.yahoo.prelude.query.SubstringItem; import com.yahoo.prelude.query.SuffixItem; import com.yahoo.prelude.query.TaggableItem; +import com.yahoo.prelude.query.TrueItem; import com.yahoo.prelude.query.WordItem; import com.yahoo.language.process.SpecialTokens; import com.yahoo.prelude.query.parser.TestLinguistics; @@ -262,17 +263,28 @@ public class ParseTestCase { @Test void testNotOnly() { - tester.assertParsed(null, "-foobar", Query.Type.ALL); + Item item = tester.assertParsed("-foobar", "-foobar", Query.Type.ALL); + assertTrue(item instanceof NotItem); + assertTrue(((NotItem) item).getPositiveItem() instanceof TrueItem); } @Test - void testMultipleNotsOnlt() { - tester.assertParsed(null, "-foo -bar -foobar", Query.Type.ALL); + void testNotOnlyAny() { + Item item = tester.assertParsed("-foobar", "-foobar", Query.Type.ANY); + assertTrue(item instanceof NotItem); + assertTrue(((NotItem) item).getPositiveItem() instanceof TrueItem); + } + + @Test + void testMultipleNotsOnly() { + Item item = tester.assertParsed("-foo -bar -foobar", "-foo -bar -foobar", Query.Type.ALL); + assertTrue(item instanceof NotItem); + assertTrue(((NotItem) item).getPositiveItem() instanceof TrueItem); } @Test void testOnlyNotComposite() { - tester.assertParsed(null, "-(foo bar baz)", Query.Type.ALL); + tester.assertParsed("-(AND foo bar baz)", "-(foo bar baz)", Query.Type.ALL); } @Test @@ -391,7 +403,7 @@ public class ParseTestCase { @Test void testMinusAndPluses() { - tester.assertParsed(null, "--test+-if", Query.Type.ANY); + tester.assertParsed("-(AND test if)", "--test+-if", Query.Type.ANY); } @Test @@ -1305,7 +1317,9 @@ public class ParseTestCase { @Test void testNotFilterEmptyQuery() { - tester.assertParsed(null, "", "-foo", Query.Type.ANY); + Item item = tester.assertParsed("-|foo", "", "-foo", Query.Type.ANY); + assertTrue(item instanceof NotItem); + assertTrue(((NotItem) item).getPositiveItem() instanceof TrueItem); } @Test @@ -1380,7 +1394,7 @@ public class ParseTestCase { @Test void testMultitermNotFilterEmptyQuery() { - tester.assertParsed(null, "", "-foo -foz", Query.Type.ANY); + tester.assertParsed("-|foo -|foz", "", "-foo -foz", Query.Type.ANY); } @Test @@ -2320,17 +2334,19 @@ public class ParseTestCase { @Test void testNotOnlyWeb() { - tester.assertParsed(null, "-foobar", Query.Type.WEB); + Item item = tester.assertParsed("-foobar", "-foobar", Query.Type.WEB); + assertTrue(item instanceof NotItem); + assertTrue(((NotItem)item).getPositiveItem() instanceof TrueItem); } @Test void testMultipleNotsOnltWeb() { - tester.assertParsed(null, "-foo -bar -foobar", Query.Type.WEB); + tester.assertParsed("-foo -bar -foobar", "-foo -bar -foobar", Query.Type.WEB); } @Test void testOnlyNotCompositeWeb() { - tester.assertParsed(null, "-(foo bar baz)", Query.Type.WEB); + tester.assertParsed("-(AND foo bar baz)", "-(foo bar baz)", Query.Type.WEB); } @Test diff --git a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java index 90e21e5f3b0..429b8d1c6cb 100644 --- a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java +++ b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java @@ -185,6 +185,21 @@ public class RankProfileInputTest { assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders, Language.UNKNOWN.languageCode()); } + @Test + void testUnembeddedTensorRankFeatureInRequestReferencedFromAParameter() { + String text = "text to embed into a tensor"; + Tensor embedding1 = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]"); + + Map<String, Embedder> embedders = Map.of( + "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1) + ); + assertEmbedQuery("embed(@param1)", embedding1, embedders, null, text); + assertEmbedQuery("embed(emb1, @param1)", embedding1, embedders, null, text); + assertEmbedQueryFails("embed(emb1, @noSuchParam)", embedding1, embedders, + "Could not resolve query parameter reference 'noSuchParam' " + + "used in an embed() argument"); + } + private Query createTensor1Query(String tensorString, String profile, String additionalParams) { return new Query.Builder() .setSchemaInfo(createSchemaInfo()) @@ -202,18 +217,24 @@ public class RankProfileInputTest { } private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders) { - assertEmbedQuery(embed, expected, embedders, null); + assertEmbedQuery(embed, expected, embedders, null, null); } private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language) { + assertEmbedQuery(embed, expected, embedders, language, null); + } + private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language, String param1Value) { String languageParam = language == null ? "" : "&language=" + language; + String param1 = param1Value == null ? "" : "¶m1=" + urlEncode(param1Value); + String destination = "query(myTensor4)"; Query query = new Query.Builder().setRequest(HttpRequest.createTestRequest( "?" + urlEncode("ranking.features." + destination) + "=" + urlEncode(embed) + "&ranking=commonProfile" + - languageParam, + languageParam + + param1, com.yahoo.jdisc.http.HttpRequest.Method.GET)) .setSchemaInfo(createSchemaInfo()) .setQueryProfile(createQueryProfile()) @@ -230,7 +251,7 @@ public class RankProfileInputTest { if (t.getMessage().equals(errMsg)) return; t = t.getCause(); } - fail("Error '" + errMsg + "' not thrown"); + fail("Exception with message '" + errMsg + "' not thrown"); } private CompiledQueryProfile createQueryProfile() { diff --git a/dependency-versions/pom.xml b/dependency-versions/pom.xml index cbeb4d01fe4..8128cdf7cd7 100644 --- a/dependency-versions/pom.xml +++ b/dependency-versions/pom.xml @@ -66,7 +66,7 @@ <!-- Athenz dependencies. Make sure these dependencies match those in Vespa's internal repositories --> <athenz.vespa.version>1.11.50</athenz.vespa.version> - <aws-sdk.vespa.version>1.12.639</aws-sdk.vespa.version> + <aws-sdk.vespa.version>1.12.641</aws-sdk.vespa.version> <!-- Athenz END --> <!-- WARNING: If you change curator version, you also need to update @@ -90,7 +90,7 @@ <commons-compress.vespa.version>1.25.0</commons-compress.vespa.version> <commons-cli.vespa.version>1.6.0</commons-cli.vespa.version> <curator.vespa.version>5.6.0</curator.vespa.version> - <dropwizard.metrics.vespa.version>4.2.23</dropwizard.metrics.vespa.version> <!-- ZK 3.9.1 requires this --> + <dropwizard.metrics.vespa.version>4.2.24</dropwizard.metrics.vespa.version> <!-- ZK 3.9.1 requires this --> <eclipse-collections.vespa.version>11.1.0</eclipse-collections.vespa.version> <eclipse-sisu.vespa.version>0.9.0.M2</eclipse-sisu.vespa.version> <failureaccess.vespa.version>1.0.2</failureaccess.vespa.version> @@ -120,7 +120,7 @@ <mimepull.vespa.version>1.10.0</mimepull.vespa.version> <mockito.vespa.version>5.9.0</mockito.vespa.version> <mojo-executor.vespa.version>2.4.0</mojo-executor.vespa.version> - <netty.vespa.version>4.1.105.Final</netty.vespa.version> + <netty.vespa.version>4.1.106.Final</netty.vespa.version> <netty-tcnative.vespa.version>2.0.62.Final</netty-tcnative.vespa.version> <onnxruntime.vespa.version>1.16.3</onnxruntime.vespa.version> <opennlp.vespa.version>2.3.1</opennlp.vespa.version> diff --git a/document/src/main/java/com/yahoo/document/json/TokenBuffer.java b/document/src/main/java/com/yahoo/document/json/TokenBuffer.java index a9cd3cc87a8..c1ac239d5f0 100644 --- a/document/src/main/java/com/yahoo/document/json/TokenBuffer.java +++ b/document/src/main/java/com/yahoo/document/json/TokenBuffer.java @@ -159,7 +159,7 @@ public class TokenBuffer { if (name.equals(currentName()) && current().isScalarValue()) { toReturn = tokens.get(position); } else { - i = tokens.iterator(); + i = rest().iterator(); i.next(); // just ignore the first value, as we know it's not what // we're looking for, and it's nesting effect is already // included diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java index 2854ef8836a..72f0fb977d5 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java @@ -74,7 +74,17 @@ public class FileDownloader implements AutoCloseable { } } - Future<Optional<File>> getFutureFile(FileReferenceDownload fileReferenceDownload) { + /** Returns a future that times out if download takes too long, and return empty on unsuccessful download. */ + public Future<Optional<File>> getFutureFileOrTimeout(FileReferenceDownload fileReferenceDownload) { + return getFutureFile(fileReferenceDownload) + .orTimeout(timeout.toMillis(), TimeUnit.MILLISECONDS) + .exceptionally(thrown -> { + fileReferenceDownloader.failedDownloading(fileReferenceDownload.fileReference()); + return Optional.empty(); + }); + } + + CompletableFuture<Optional<File>> getFutureFile(FileReferenceDownload fileReferenceDownload) { FileReference fileReference = fileReferenceDownload.fileReference(); Optional<File> file = getFileFromFileSystem(fileReference); @@ -135,7 +145,7 @@ public class FileDownloader implements AutoCloseable { } /** Start downloading, the future returned will be complete()d by receiving method in {@link FileReceiver} */ - private synchronized Future<Optional<File>> startDownload(FileReferenceDownload fileReferenceDownload) { + private synchronized CompletableFuture<Optional<File>> startDownload(FileReferenceDownload fileReferenceDownload) { return fileReferenceDownloader.startDownload(fileReferenceDownload); } diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java index 450801ce530..5ad197e8633 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java @@ -15,6 +15,7 @@ import java.time.Instant; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -67,7 +68,7 @@ public class FileReferenceDownloader { int retryCount = 0; Connection connection = connectionPool.getCurrent(); do { - backoff(retryCount); + backoff(retryCount, end); if (FileDownloader.fileReferenceExists(fileReference, downloadDirectory)) return; @@ -79,24 +80,26 @@ public class FileReferenceDownloader { // exist on just one config server, and which one could be different for each file reference), so // switch to a new connection for every retry connection = connectionPool.switchConnection(connection); - } while (retryCount < 5 || Instant.now().isAfter(end)); + } while (Instant.now().isBefore(end)); fileReferenceDownload.future().completeExceptionally(new RuntimeException("Failed getting " + fileReference)); downloads.remove(fileReference); } - private void backoff(int retryCount) { + private void backoff(int retryCount, Instant end) { if (retryCount > 0) { try { - long sleepTime = Math.min(120_000, (long) (Math.pow(2, retryCount)) * sleepBetweenRetries.toMillis()); - Thread.sleep(sleepTime); + long sleepTime = Math.min(120_000, + Math.min((long) (Math.pow(2, retryCount)) * sleepBetweenRetries.toMillis(), + Duration.between(Instant.now(), end).toMillis())); + if (sleepTime > 0) Thread.sleep(sleepTime); } catch (InterruptedException e) { /* ignored */ } } } - Future<Optional<File>> startDownload(FileReferenceDownload fileReferenceDownload) { + CompletableFuture<Optional<File>> startDownload(FileReferenceDownload fileReferenceDownload) { FileReference fileReference = fileReferenceDownload.fileReference(); Optional<FileReferenceDownload> inProgress = downloads.get(fileReference); if (inProgress.isPresent()) return inProgress.get().future(); diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json index 1ffb879e57e..dc6a62cc463 100644 --- a/linguistics/abi-spec.json +++ b/linguistics/abi-spec.json @@ -344,7 +344,9 @@ "public java.lang.String getDestination()", "public com.yahoo.language.process.Embedder$Context setDestination(java.lang.String)", "public java.lang.String getEmbedderId()", - "public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)" + "public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)", + "public java.util.Map getContextValues()", + "public com.yahoo.language.process.Embedder$Context setContextValues(java.util.Map)" ], "fields" : [ ] }, diff --git a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java index fa141977d5d..f6fdb86d01f 100644 --- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java +++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java @@ -88,6 +88,7 @@ public interface Embedder { private Language language = Language.UNKNOWN; private String destination; private String embedderId = "unknown"; + private Map<String, String> contextValues; public Context(String destination) { this.destination = destination; @@ -138,6 +139,15 @@ public interface Embedder { this.embedderId = embedderId; return this; } + + /** Returns a read-only map of context key-values which can be looked up during conversion. */ + public Map<String, String> getContextValues() { return contextValues; } + + public Context setContextValues(Map<String, String> contextValues) { + this.contextValues = contextValues; + return this; + } + } class FailingEmbedder implements Embedder { diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 853009873a1..28f8c4e252f 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -10,9 +10,9 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.embedding.SpladeEmbedderConfig; import com.yahoo.language.huggingface.HuggingFaceTokenizer; import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.DirectIndexedAddress; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.nio.file.Paths; import java.util.List; @@ -152,10 +152,15 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { String dimension = tensorType.dimensions().get(0).name(); //Iterate over the vocab dimension and find the max value for each sequence token long [] tokens = new long[1]; + DirectIndexedAddress directAddress = modelOutput.directAddress(); + directAddress.setIndex(0,0); for (int v = 0; v < vocabSize; v++) { double maxValue = 0.0d; + directAddress.setIndex(2, v); + long increment = directAddress.getStride(1); + long directIndex = directAddress.getDirectIndex(); for (int s = 0; s < sequenceLength; s++) { - double value = modelOutput.get(0, s, v); // batch, sequence, vocab + double value = modelOutput.get(directIndex + s * increment); if (value > maxValue) { maxValue = value; } diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java index 07f2aea4ab6..d1a06d8c7ff 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java @@ -28,6 +28,7 @@ import java.nio.ShortBuffer; import java.util.HashMap; import java.util.Map; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; @@ -101,70 +102,67 @@ class TensorConverter { throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensorInfo.type); } + interface Short2Float { + float convert(short value); + } + + private static void extractTensor(FloatBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(DoubleBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(ByteBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(ShortBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(ShortBuffer buffer, Short2Float converter, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, converter.convert(buffer.get(i))); + } + private static void extractTensor(IntBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(LongBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + static Tensor toVespaTensor(OnnxValue onnxValue) { if ( ! (onnxValue instanceof OnnxTensor onnxTensor)) { throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported"); } TensorInfo tensorInfo = onnxTensor.getInfo(); - TensorType type = toVespaType(onnxTensor.getInfo()); - DimensionSizes sizes = sizesFromType(type); - + DimensionSizes sizes = DimensionSizes.of(type); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type, sizes); - long totalSize = sizes.totalSize(); - if (tensorInfo.type == OnnxJavaType.FLOAT) { - FloatBuffer buffer = onnxTensor.getFloatBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.DOUBLE) { - DoubleBuffer buffer = onnxTensor.getDoubleBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT8) { - ByteBuffer buffer = onnxTensor.getByteBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT16) { - ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT32) { - IntBuffer buffer = onnxTensor.getIntBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT64) { - LongBuffer buffer = onnxTensor.getLongBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.FLOAT16) { - ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, Fp16Conversions.fp16ToFloat(buffer.get())); - } - else if (tensorInfo.type == OnnxJavaType.BFLOAT16) { - ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, Fp16Conversions.bf16ToFloat((buffer.get()))); - } - else { - throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type); + long totalSizeAsLong = sizes.totalSize(); + if (totalSizeAsLong > Integer.MAX_VALUE) { + throw new IllegalArgumentException("TotalSize=" + totalSizeAsLong + " currently limited at INTEGER.MAX_VALUE"); + } + + int totalSize = (int) totalSizeAsLong; + switch (tensorInfo.type) { + case FLOAT -> extractTensor(onnxTensor.getFloatBuffer(), builder, totalSize); + case DOUBLE -> extractTensor(onnxTensor.getDoubleBuffer(), builder, totalSize); + case INT8 -> extractTensor(onnxTensor.getByteBuffer(), builder, totalSize); + case INT16 -> extractTensor(onnxTensor.getShortBuffer(), builder, totalSize); + case INT32 -> extractTensor(onnxTensor.getIntBuffer(), builder, totalSize); + case INT64 -> extractTensor(onnxTensor.getLongBuffer(), builder, totalSize); + case FLOAT16 -> extractTensor(onnxTensor.getShortBuffer(), Fp16Conversions::fp16ToFloat, builder, totalSize); + case BFLOAT16 -> extractTensor(onnxTensor.getShortBuffer(), Fp16Conversions::bf16ToFloat, builder, totalSize); + default -> throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type); } return builder.build(); } - static private DimensionSizes sizesFromType(TensorType type) { - DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size()); - for (int i = 0; i < type.dimensions().size(); i++) - builder.set(i, type.dimensions().get(i).size().get()); - return builder.build(); - } - static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) { return infoMap.entrySet().stream().collect(Collectors.toMap(e -> asValidName(e.getKey()), e -> toVespaType(e.getValue().getInfo()))); diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java index 9ecb0e3e162..b48051814ab 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java @@ -49,11 +49,11 @@ public class SpladeEmbedderTest { String text = "what was the manhattan project in this context it was a secret project to develop a nuclear weapon in world war" + " ii the project was led by the united states with the support of the united kingdom and canada"; Long now = System.currentTimeMillis(); - int n = 10; + int n = 1000; // Takes around 7s on Intel core i9 2.4Ghz (macbook pro, 2019) for (int i = 0; i < n; i++) { assertEmbed("tensor<float>(t{})", text, indexingContext); } - Long elapsed = (System.currentTimeMillis() - now)/1000; + Long elapsed = System.currentTimeMillis() - now; System.out.println("Elapsed time: " + elapsed + " ms"); } diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 591f0eb8b37..97aa42f79c9 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -177,7 +177,7 @@ List<ReferenceNode> featureList() : ReferenceNode exp; } { - ( ( exp = feature() { ret.add(exp); } )+ <EOF> ) + ( ( exp = feature() { ret.add(exp); } )* <EOF> ) { return ret; } } diff --git a/searchlib/src/tests/attribute/direct_multi_term_blueprint/direct_multi_term_blueprint_test.cpp b/searchlib/src/tests/attribute/direct_multi_term_blueprint/direct_multi_term_blueprint_test.cpp index a94ca9d890e..87b771af8e6 100644 --- a/searchlib/src/tests/attribute/direct_multi_term_blueprint/direct_multi_term_blueprint_test.cpp +++ b/searchlib/src/tests/attribute/direct_multi_term_blueprint/direct_multi_term_blueprint_test.cpp @@ -26,6 +26,7 @@ using LookupKey = IDirectPostingStore::LookupKey; struct IntegerKey : public LookupKey { int64_t _value; IntegerKey(int64_t value_in) : _value(value_in) {} + IntegerKey(const vespalib::string&) : _value() { abort(); } vespalib::stringref asString() const override { abort(); } bool asInteger(int64_t& value) const override { value = _value; return true; } }; @@ -33,6 +34,7 @@ struct IntegerKey : public LookupKey { struct StringKey : public LookupKey { vespalib::string _value; StringKey(int64_t value_in) : _value(std::to_string(value_in)) {} + StringKey(const vespalib::string& value_in) : _value(value_in) {} vespalib::stringref asString() const override { return _value; } bool asInteger(int64_t&) const override { abort(); } }; @@ -78,6 +80,10 @@ populate_attribute(AttributeType& attr, const std::vector<DataType>& values) for (auto docid : range(300, 128)) { attr.update(docid, values[3]); } + if (values.size() > 4) { + attr.update(40, values[4]); + attr.update(41, values[5]); + } attr.commit(true); } @@ -93,7 +99,7 @@ make_attribute(CollectionType col_type, BasicType type, bool field_is_filter) auto attr = test::AttributeBuilder(field_name, cfg).docs(num_docs).get(); if (type == BasicType::STRING) { populate_attribute<StringAttribute, vespalib::string>(dynamic_cast<StringAttribute&>(*attr), - {"1", "3", "100", "300"}); + {"1", "3", "100", "300", "foo", "Foo"}); } else { populate_attribute<IntegerAttribute, int64_t>(dynamic_cast<IntegerAttribute&>(*attr), {1, 3, 100, 300}); @@ -223,15 +229,16 @@ public: tfmd.tagAsNotNeeded(); } } - template <typename BlueprintType> - void add_term_helper(BlueprintType& b, int64_t term_value) { + template <typename BlueprintType, typename TermType> + void add_term_helper(BlueprintType& b, TermType term_value) { if (integer_type) { b.addTerm(IntegerKey(term_value), 1, estimate); } else { b.addTerm(StringKey(term_value), 1, estimate); } } - void add_term(int64_t term_value) { + template <typename TermType> + void add_term(TermType term_value) { if (single_type) { if (in_operator) { add_term_helper(dynamic_cast<SingleInBlueprintType&>(*blueprint), term_value); @@ -246,8 +253,18 @@ public: } } } - std::unique_ptr<SearchIterator> create_leaf_search() const { - return blueprint->createLeafSearch(tfmda, true); + void add_terms(const std::vector<int64_t>& term_values) { + for (auto value : term_values) { + add_term(value); + } + } + void add_terms(const std::vector<vespalib::string>& term_values) { + for (auto value : term_values) { + add_term(value); + } + } + std::unique_ptr<SearchIterator> create_leaf_search(bool strict = true) const { + return blueprint->createLeafSearch(tfmda, strict); } vespalib::string resolve_iterator_with_unpack() const { if (in_operator) { @@ -262,7 +279,7 @@ expect_hits(const Docids& exp_docids, SearchIterator& itr) { SimpleResult exp(exp_docids); SimpleResult act; - act.search(itr); + act.search(itr, doc_id_limit); EXPECT_EQ(exp, act); } @@ -294,8 +311,7 @@ INSTANTIATE_TEST_SUITE_P(DefaultInstantiation, TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_for_none_filter_field) { setup(false, true); - add_term(1); - add_term(3); + add_terms({1, 3}); auto itr = create_leaf_search(); EXPECT_THAT(itr->asString(), StartsWith(resolve_iterator_with_unpack())); expect_hits({10, 30, 31}, *itr); @@ -307,8 +323,7 @@ TEST_P(DirectMultiTermBlueprintTest, bitvectors_used_instead_of_btree_iterators_ if (!in_operator) { return; } - add_term(1); - add_term(100); + add_terms({1, 100}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 2); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); @@ -322,8 +337,7 @@ TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_instead_of_bitvectors_ if (in_operator) { return; } - add_term(1); - add_term(100); + add_terms({1, 100}); auto itr = create_leaf_search(); EXPECT_THAT(itr->asString(), StartsWith(iterator_unpack_docid_and_weights)); expect_hits(concat({10}, range(100, 128)), *itr); @@ -332,10 +346,7 @@ TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_instead_of_bitvectors_ TEST_P(DirectMultiTermBlueprintTest, bitvectors_and_btree_iterators_used_for_filter_field) { setup(true, true); - add_term(1); - add_term(3); - add_term(100); - add_term(300); + add_terms({1, 3, 100, 300}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 3); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); @@ -347,8 +358,7 @@ TEST_P(DirectMultiTermBlueprintTest, bitvectors_and_btree_iterators_used_for_fil TEST_P(DirectMultiTermBlueprintTest, only_bitvectors_used_for_filter_field) { setup(true, true); - add_term(100); - add_term(300); + add_terms({100, 300}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 2); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); @@ -359,8 +369,7 @@ TEST_P(DirectMultiTermBlueprintTest, only_bitvectors_used_for_filter_field) TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_for_filter_field_when_ranking_not_needed) { setup(true, false); - add_term(1); - add_term(3); + add_terms({1, 3}); auto itr = create_leaf_search(); EXPECT_THAT(itr->asString(), StartsWith(iterator_unpack_none)); expect_hits({10, 30, 31}, *itr); @@ -369,10 +378,7 @@ TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_for_filter_field_when_ TEST_P(DirectMultiTermBlueprintTest, bitvectors_and_btree_iterators_used_for_filter_field_when_ranking_not_needed) { setup(true, false); - add_term(1); - add_term(3); - add_term(100); - add_term(300); + add_terms({1, 3, 100, 300}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 3); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); @@ -384,8 +390,7 @@ TEST_P(DirectMultiTermBlueprintTest, bitvectors_and_btree_iterators_used_for_fil TEST_P(DirectMultiTermBlueprintTest, only_bitvectors_used_for_filter_field_when_ranking_not_needed) { setup(true, false); - add_term(100); - add_term(300); + add_terms({100, 300}); auto itr = create_leaf_search(); expect_or_iterator(*itr, 2); expect_or_child(*itr, 0, "search::BitVectorIteratorStrictT"); @@ -393,4 +398,41 @@ TEST_P(DirectMultiTermBlueprintTest, only_bitvectors_used_for_filter_field_when_ expect_hits(concat(range(100, 128), range(300, 128)), *itr); } +TEST_P(DirectMultiTermBlueprintTest, hash_filter_used_for_non_strict_iterator_with_10_or_more_terms) +{ + setup(true, true); + if (!single_type) { + return; + } + add_terms({1, 3, 3, 3, 3, 3, 3, 3, 3, 3}); + auto itr = create_leaf_search(false); + EXPECT_THAT(itr->asString(), StartsWith("search::attribute::MultiTermHashFilter")); + expect_hits({10, 30, 31}, *itr); +} + +TEST_P(DirectMultiTermBlueprintTest, btree_iterators_used_for_non_strict_iterator_with_9_or_less_terms) +{ + setup(true, true); + if (!single_type) { + return; + } + add_terms({1, 3, 3, 3, 3, 3, 3, 3, 3}); + auto itr = create_leaf_search(false); + EXPECT_THAT(itr->asString(), StartsWith(iterator_unpack_docid)); + expect_hits({10, 30, 31}, *itr); +} + +TEST_P(DirectMultiTermBlueprintTest, hash_filter_with_string_folding_used_for_non_strict_iterator) +{ + setup(true, true); + if (!single_type || integer_type) { + return; + } + // "foo" matches documents with "foo" (40) and "Foo" (41). + add_terms({"foo", "3", "3", "3", "3", "3", "3", "3", "3", "3"}); + auto itr = create_leaf_search(false); + EXPECT_THAT(itr->asString(), StartsWith("search::attribute::MultiTermHashFilter")); + expect_hits({30, 31, 40, 41}, *itr); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp index fe6149e6fba..35732eaf7e5 100644 --- a/searchlib/src/tests/query/streaming_query_test.cpp +++ b/searchlib/src/tests/query/streaming_query_test.cpp @@ -23,10 +23,11 @@ using TermType = QueryTerm::Type; using search::fef::SimpleTermData; using search::fef::MatchData; -void assertHit(const Hit & h, size_t expWordpos, size_t expContext, int32_t weight) { - EXPECT_EQ(h.wordpos(), expWordpos); - EXPECT_EQ(h.context(), expContext); - EXPECT_EQ(h.weight(), weight); +void assertHit(const Hit & h, uint32_t exp_field_id, uint32_t exp_element_id, int32_t exp_element_weight, size_t exp_position) { + EXPECT_EQ(h.field_id(), exp_field_id); + EXPECT_EQ(h.element_id(), exp_element_id); + EXPECT_EQ(h.element_weight(), exp_element_weight); + EXPECT_EQ(h.position(), exp_position); } @@ -449,65 +450,69 @@ TEST(StreamingQueryTest, test_phrase_evaluate) } // field 0 - terms[0]->add(0, 0, 0, 1); - terms[1]->add(1, 0, 0, 1); - terms[2]->add(2, 0, 0, 1); - terms[0]->add(7, 0, 0, 1); - terms[1]->add(8, 0, 1, 1); - terms[2]->add(9, 0, 0, 1); + terms[0]->add(0, 0, 1, 0); + terms[1]->add(0, 0, 1, 1); + terms[2]->add(0, 0, 1, 2); + terms[0]->add(0, 0, 1, 7); + terms[1]->add(0, 1, 1, 8); + terms[2]->add(0, 0, 1, 9); // field 1 - terms[0]->add(4, 1, 0, 1); - terms[1]->add(5, 1, 0, 1); - terms[2]->add(6, 1, 0, 1); + terms[0]->add(1, 0, 1, 4); + terms[1]->add(1, 0, 1, 5); + terms[2]->add(1, 0, 1, 6); // field 2 (not complete match) - terms[0]->add(1, 2, 0, 1); - terms[1]->add(2, 2, 0, 1); - terms[2]->add(4, 2, 0, 1); + terms[0]->add(2, 0, 1, 1); + terms[1]->add(2, 0, 1, 2); + terms[2]->add(2, 0, 1, 4); // field 3 - terms[0]->add(0, 3, 0, 1); - terms[1]->add(1, 3, 0, 1); - terms[2]->add(2, 3, 0, 1); + terms[0]->add(3, 0, 1, 0); + terms[1]->add(3, 0, 1, 1); + terms[2]->add(3, 0, 1, 2); // field 4 (not complete match) - terms[0]->add(1, 4, 0, 1); - terms[1]->add(2, 4, 0, 1); + terms[0]->add(4, 0, 1, 1); + terms[1]->add(4, 0, 1, 2); // field 5 (not complete match) - terms[0]->add(2, 5, 0, 1); - terms[1]->add(1, 5, 0, 1); - terms[2]->add(0, 5, 0, 1); + terms[0]->add(5, 0, 1, 2); + terms[1]->add(5, 0, 1, 1); + terms[2]->add(5, 0, 1, 0); HitList hits; auto * p = static_cast<PhraseQueryNode *>(phrases[0]); p->evaluateHits(hits); ASSERT_EQ(3u, hits.size()); - EXPECT_EQ(hits[0].wordpos(), 2u); - EXPECT_EQ(hits[0].context(), 0u); - EXPECT_EQ(hits[1].wordpos(), 6u); - EXPECT_EQ(hits[1].context(), 1u); - EXPECT_EQ(hits[2].wordpos(), 2u); - EXPECT_EQ(hits[2].context(), 3u); + EXPECT_EQ(0u, hits[0].field_id()); + EXPECT_EQ(0u, hits[0].element_id()); + EXPECT_EQ(2u, hits[0].position()); + EXPECT_EQ(1u, hits[1].field_id()); + EXPECT_EQ(0u, hits[1].element_id()); + EXPECT_EQ(6u, hits[1].position()); + EXPECT_EQ(3u, hits[2].field_id()); + EXPECT_EQ(0u, hits[2].element_id()); + EXPECT_EQ(2u, hits[2].position()); ASSERT_EQ(4u, p->getFieldInfoSize()); - EXPECT_EQ(p->getFieldInfo(0).getHitOffset(), 0u); - EXPECT_EQ(p->getFieldInfo(0).getHitCount(), 1u); - EXPECT_EQ(p->getFieldInfo(1).getHitOffset(), 1u); - EXPECT_EQ(p->getFieldInfo(1).getHitCount(), 1u); - EXPECT_EQ(p->getFieldInfo(2).getHitOffset(), 0u); // invalid, but will never be used - EXPECT_EQ(p->getFieldInfo(2).getHitCount(), 0u); - EXPECT_EQ(p->getFieldInfo(3).getHitOffset(), 2u); - EXPECT_EQ(p->getFieldInfo(3).getHitCount(), 1u); + EXPECT_EQ(0u, p->getFieldInfo(0).getHitOffset()); + EXPECT_EQ(1u, p->getFieldInfo(0).getHitCount()); + EXPECT_EQ(1u, p->getFieldInfo(1).getHitOffset()); + EXPECT_EQ(1u, p->getFieldInfo(1).getHitCount()); + EXPECT_EQ(0u, p->getFieldInfo(2).getHitOffset()); // invalid, but will never be used + EXPECT_EQ(0u, p->getFieldInfo(2).getHitCount()); + EXPECT_EQ(2u, p->getFieldInfo(3).getHitOffset()); + EXPECT_EQ(1u, p->getFieldInfo(3).getHitCount()); EXPECT_TRUE(p->evaluate()); } TEST(StreamingQueryTest, test_hit) { - // positions (0 - (2^24-1)) - assertHit(Hit(0, 0, 0, 0), 0, 0, 0); - assertHit(Hit(256, 0, 0, 1), 256, 0, 1); - assertHit(Hit(16777215, 0, 0, -1), 16777215, 0, -1); - assertHit(Hit(16777216, 0, 0, 1), 0, 1, 1); // overflow - - // contexts (0 - 255) - assertHit(Hit(0, 1, 0, 1), 0, 1, 1); - assertHit(Hit(0, 255, 0, 1), 0, 255, 1); - assertHit(Hit(0, 256, 0, 1), 0, 0, 1); // overflow + // field id + assertHit(Hit( 1, 0, 1, 0), 1, 0, 1, 0); + assertHit(Hit(255, 0, 1, 0), 255, 0, 1, 0); + assertHit(Hit(256, 0, 1, 0), 256, 0, 1, 0); + + // positions + assertHit(Hit(0, 0, 0, 0), 0, 0, 0, 0); + assertHit(Hit(0, 0, 1, 256), 0, 0, 1, 256); + assertHit(Hit(0, 0, -1, 16777215), 0, 0, -1, 16777215); + assertHit(Hit(0, 0, 1, 16777216), 0, 0, 1, 16777216); + } void assertInt8Range(const std::string &term, bool expAdjusted, int64_t expLow, int64_t expHigh) { @@ -824,47 +829,47 @@ TEST(StreamingQueryTest, test_same_element_evaluate) } // field 0 - terms[0]->add(1, 0, 0, 10); - terms[0]->add(2, 0, 1, 20); - terms[0]->add(3, 0, 2, 30); - terms[0]->add(4, 0, 3, 40); - terms[0]->add(5, 0, 4, 50); - terms[0]->add(6, 0, 5, 60); - - terms[1]->add(7, 1, 0, 70); - terms[1]->add(8, 1, 1, 80); - terms[1]->add(9, 1, 2, 90); - terms[1]->add(10, 1, 4, 100); - terms[1]->add(11, 1, 5, 110); - terms[1]->add(12, 1, 6, 120); - - terms[2]->add(13, 2, 0, 130); - terms[2]->add(14, 2, 2, 140); - terms[2]->add(15, 2, 4, 150); - terms[2]->add(16, 2, 5, 160); - terms[2]->add(17, 2, 6, 170); + terms[0]->add(0, 0, 10, 1); + terms[0]->add(0, 1, 20, 2); + terms[0]->add(0, 2, 30, 3); + terms[0]->add(0, 3, 40, 4); + terms[0]->add(0, 4, 50, 5); + terms[0]->add(0, 5, 60, 6); + + terms[1]->add(1, 0, 70, 7); + terms[1]->add(1, 1, 80, 8); + terms[1]->add(1, 2, 90, 9); + terms[1]->add(1, 4, 100, 10); + terms[1]->add(1, 5, 110, 11); + terms[1]->add(1, 6, 120, 12); + + terms[2]->add(2, 0, 130, 13); + terms[2]->add(2, 2, 140, 14); + terms[2]->add(2, 4, 150, 15); + terms[2]->add(2, 5, 160, 16); + terms[2]->add(2, 6, 170, 17); HitList hits; sameElem->evaluateHits(hits); EXPECT_EQ(4u, hits.size()); - EXPECT_EQ(0u, hits[0].wordpos()); - EXPECT_EQ(2u, hits[0].context()); - EXPECT_EQ(0u, hits[0].elemId()); - EXPECT_EQ(130, hits[0].weight()); - - EXPECT_EQ(0u, hits[1].wordpos()); - EXPECT_EQ(2u, hits[1].context()); - EXPECT_EQ(2u, hits[1].elemId()); - EXPECT_EQ(140, hits[1].weight()); - - EXPECT_EQ(0u, hits[2].wordpos()); - EXPECT_EQ(2u, hits[2].context()); - EXPECT_EQ(4u, hits[2].elemId()); - EXPECT_EQ(150, hits[2].weight()); - - EXPECT_EQ(0u, hits[3].wordpos()); - EXPECT_EQ(2u, hits[3].context()); - EXPECT_EQ(5u, hits[3].elemId()); - EXPECT_EQ(160, hits[3].weight()); + EXPECT_EQ(2u, hits[0].field_id()); + EXPECT_EQ(0u, hits[0].element_id()); + EXPECT_EQ(130, hits[0].element_weight()); + EXPECT_EQ(0u, hits[0].position()); + + EXPECT_EQ(2u, hits[1].field_id()); + EXPECT_EQ(2u, hits[1].element_id()); + EXPECT_EQ(140, hits[1].element_weight()); + EXPECT_EQ(0u, hits[1].position()); + + EXPECT_EQ(2u, hits[2].field_id()); + EXPECT_EQ(4u, hits[2].element_id()); + EXPECT_EQ(150, hits[2].element_weight()); + EXPECT_EQ(0u, hits[2].position()); + + EXPECT_EQ(2u, hits[3].field_id()); + EXPECT_EQ(5u, hits[3].element_id()); + EXPECT_EQ(160, hits[3].element_weight()); + EXPECT_EQ(0u, hits[3].position()); EXPECT_TRUE(sameElem->evaluate()); } @@ -917,8 +922,8 @@ TEST(StreamingQueryTest, test_in_term) td.lookupField(12)->setHandle(1); EXPECT_FALSE(term.evaluate()); auto& q = *term.get_terms().front(); - q.add(0, 11, 0, 1); - q.add(0, 12, 0, 1); + q.add(11, 0, 1, 0); + q.add(12, 0, 1, 0); EXPECT_TRUE(term.evaluate()); MatchData md(MatchData::params().numTermFields(2)); term.unpack_match_data(23, td, md); @@ -944,11 +949,11 @@ TEST(StreamingQueryTest, dot_product_term) td.lookupField(12)->setHandle(1); EXPECT_FALSE(term.evaluate()); auto& q0 = *term.get_terms()[0]; - q0.add(0, 11, 0, -13); - q0.add(0, 12, 0, -17); + q0.add(11, 0, -13, 0); + q0.add(12, 0, -17, 0); auto& q1 = *term.get_terms()[1]; - q1.add(0, 11, 0, 4); - q1.add(0, 12, 0, 9); + q1.add(11, 0, 4, 0); + q1.add(12, 0, 9, 0); EXPECT_TRUE(term.evaluate()); MatchData md(MatchData::params().numTermFields(2)); term.unpack_match_data(23, td, md); @@ -989,11 +994,11 @@ check_wand_term(double limit, const vespalib::string& label) td.lookupField(12)->setHandle(1); EXPECT_FALSE(term.evaluate()); auto& q0 = *term.get_terms()[0]; - q0.add(0, 11, 0, 17); - q0.add(0, 12, 0, 13); + q0.add(11, 0, 17, 0); + q0.add(12, 0, 13, 0); auto& q1 = *term.get_terms()[1]; - q1.add(0, 11, 0, 9); - q1.add(0, 12, 0, 4); + q1.add(11, 0, 9, 0); + q1.add(12, 0, 4, 0); EXPECT_EQ(limit < exp_wand_score_field_11, term.evaluate()); MatchData md(MatchData::params().numTermFields(2)); term.unpack_match_data(23, td, md); @@ -1043,11 +1048,11 @@ TEST(StreamingQueryTest, weighted_set_term) td.lookupField(12)->setHandle(1); EXPECT_FALSE(term.evaluate()); auto& q0 = *term.get_terms()[0]; - q0.add(0, 11, 0, 10); - q0.add(0, 12, 0, 10); + q0.add(11, 0, 10, 0); + q0.add(12, 0, 10, 0); auto& q1 = *term.get_terms()[1]; - q1.add(0, 11, 0, 10); - q1.add(0, 12, 0, 10); + q1.add(11, 0, 10, 0); + q1.add(12, 0, 10, 0); EXPECT_TRUE(term.evaluate()); MatchData md(MatchData::params().numTermFields(2)); term.unpack_match_data(23, td, md); diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp index 01148c11c9c..d0353ab8947 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp @@ -6,12 +6,12 @@ #include <vespa/searchlib/common/bitvector.h> #include <vespa/searchlib/fef/matchdatalayout.h> #include <vespa/searchlib/query/query_term_ucs4.h> +#include <vespa/searchlib/queryeval/filter_wrapper.h> +#include <vespa/searchlib/queryeval/orsearch.h> #include <vespa/searchlib/queryeval/weighted_set_term_search.h> #include <vespa/vespalib/objects/visit.h> -#include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/stllike/hash_map.hpp> -#include <vespa/searchlib/queryeval/filter_wrapper.h> -#include <vespa/searchlib/queryeval/orsearch.h> +#include <vespa/vespalib/util/stringfmt.h> namespace search { @@ -38,6 +38,7 @@ class StringEnumWrapper : public AttrWrapper { public: using TokenT = uint32_t; + static constexpr bool unpack_weights = true; explicit StringEnumWrapper(const IAttributeVector & attr) : AttrWrapper(attr) {} auto mapToken(const ISearchContext &context) const { @@ -52,6 +53,7 @@ class IntegerWrapper : public AttrWrapper { public: using TokenT = uint64_t; + static constexpr bool unpack_weights = true; explicit IntegerWrapper(const IAttributeVector & attr) : AttrWrapper(attr) {} std::vector<int64_t> mapToken(const ISearchContext &context) const { std::vector<int64_t> result; diff --git a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h index 42651854599..485427391ad 100644 --- a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h +++ b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h @@ -35,6 +35,8 @@ private: using IteratorType = typename PostingStoreType::IteratorType; using IteratorWeights = std::variant<std::reference_wrapper<const std::vector<int32_t>>, std::vector<int32_t>>; + bool use_hash_filter(bool strict) const; + IteratorWeights create_iterators(std::vector<IteratorType>& btree_iterators, std::vector<std::unique_ptr<queryeval::SearchIterator>>& bitvectors, bool use_bitvector_when_available, diff --git a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp index 6fc8ac63026..0a3b24142a5 100644 --- a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp +++ b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp @@ -8,6 +8,7 @@ #include <vespa/searchlib/queryeval/emptysearch.h> #include <vespa/searchlib/queryeval/filter_wrapper.h> #include <vespa/searchlib/queryeval/orsearch.h> +#include <cmath> #include <memory> #include <type_traits> @@ -38,6 +39,43 @@ DirectMultiTermBlueprint<PostingStoreType, SearchType>::DirectMultiTermBlueprint template <typename PostingStoreType, typename SearchType> DirectMultiTermBlueprint<PostingStoreType, SearchType>::~DirectMultiTermBlueprint() = default; + +template <typename PostingStoreType, typename SearchType> +bool +DirectMultiTermBlueprint<PostingStoreType, SearchType>::use_hash_filter(bool strict) const +{ + if (strict || _iattr.hasMultiValue()) { + return false; + } + // The following very simplified formula was created after analysing performance of the IN operator + // on a 10M document corpus using a machine with an Intel Xeon 2.5 GHz CPU with 48 cores and 256 Gb of memory: + // https://github.com/vespa-engine/system-test/tree/master/tests/performance/in_operator + // + // The following 25 test cases were used to calculate the cost of using btree iterators (strict): + // op_hits_ratios = [5, 10, 50, 100, 200] * tokens_in_op = [1, 5, 10, 100, 1000] + // For each case we calculate the latency difference against the case with tokens_in_op=1 and the same op_hits_ratio. + // This indicates the extra time used to produce the same number of hits when having multiple tokens in the operator. + // The latency diff is divided with the number of hits produced and convert to nanoseconds: + // 10M * (op_hits_ratio / 1000) * 1000 * 1000 + // Based on the numbers we can approximate the cost per document (in nanoseconds) as: + // 8.0 (ns) * log2(tokens_in_op). + // NOTE: This is very simplified. Ideally we should also take into consideration the hit estimate of this blueprint, + // as the cost per document will be lower when producing few hits. + // + // In addition, the following 28 test cases were used to calculate the cost of using the hash filter (non-strict). + // filter_hits_ratios = [1, 5, 10, 50, 100, 150, 200] x op_hits_ratios = [200] x tokens_in_op = [5, 10, 100, 1000] + // The code was altered to always using the hash filter for non-strict iterators. + // For each case we calculate the latency difference against a case from above with tokens_in_op=1 that produce a similar number of hits. + // This indicates the extra time used to produce the same number of hits when using the hash filter. + // The latency diff is divided with the number of hits the test filter produces and convert to nanoseconds: + // 10M * (filter_hits_ratio / 1000) * 1000 * 1000 + // Based on the numbers we calculate the average cost per document (in nanoseconds) as 26.0 ns. + + float hash_filter_cost_per_doc_ns = 26.0; + float btree_iterator_cost_per_doc_ns = 8.0 * std::log2(_terms.size()); + return hash_filter_cost_per_doc_ns < btree_iterator_cost_per_doc_ns; +} + template <typename PostingStoreType, typename SearchType> typename DirectMultiTermBlueprint<PostingStoreType, SearchType>::IteratorWeights DirectMultiTermBlueprint<PostingStoreType, SearchType>::create_iterators(std::vector<IteratorType>& btree_iterators, @@ -96,14 +134,21 @@ DirectMultiTermBlueprint<PostingStoreType, SearchType>::create_search_helper(con if (_terms.empty()) { return std::make_unique<queryeval::EmptySearch>(); } + auto& tfmd = *tfmda[0]; + bool field_is_filter = getState().fields()[0].isFilter(); + if constexpr (SearchType::supports_hash_filter) { + if (use_hash_filter(strict)) { + return SearchType::create_hash_filter(tfmd, (filter_search || field_is_filter), + _weights, _terms, + _iattr, _attr, _dictionary_snapshot); + } + } std::vector<IteratorType> btree_iterators; std::vector<queryeval::SearchIterator::UP> bitvectors; const size_t num_children = _terms.size(); btree_iterators.reserve(num_children); - auto& tfmd = *tfmda[0]; bool use_bit_vector_when_available = filter_search || !_attr.has_always_btree_iterator(); auto weights = create_iterators(btree_iterators, bitvectors, use_bit_vector_when_available, tfmd, strict); - bool field_is_filter = getState().fields()[0].isFilter(); if constexpr (!SearchType::require_btree_iterators) { auto multi_term = !btree_iterators.empty() ? SearchType::create(tfmd, (filter_search || field_is_filter), std::move(weights), std::move(btree_iterators)) diff --git a/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h index 9c3ea258fdc..43500366f21 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h +++ b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h @@ -23,6 +23,7 @@ class MultiTermHashFilter final : public queryeval::SearchIterator public: using Key = typename WrapperType::TokenT; using TokenMap = vespalib::hash_map<Key, int32_t, vespalib::hash<Key>, std::equal_to<Key>, vespalib::hashtable_base::and_modulator>; + static constexpr bool unpack_weights = WrapperType::unpack_weights; private: fef::TermFieldMatchData& _tfmd; diff --git a/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp index 96d5b3ac1f3..e67bda147d7 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp @@ -42,10 +42,14 @@ template <typename WrapperType> void MultiTermHashFilter<WrapperType>::doUnpack(uint32_t docId) { - _tfmd.reset(docId); - fef::TermFieldMatchDataPosition pos; - pos.setElementWeight(_weight); - _tfmd.appendPosition(pos); + if constexpr (unpack_weights) { + _tfmd.reset(docId); + fef::TermFieldMatchDataPosition pos; + pos.setElementWeight(_weight); + _tfmd.appendPosition(pos); + } else { + _tfmd.resetOnlyDocId(docId); + } } } diff --git a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt index 05a75f4662e..76119a6d58f 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt @@ -2,6 +2,7 @@ vespa_add_library(searchlib_query_streaming OBJECT SOURCES dot_product_term.cpp + fuzzy_term.cpp in_term.cpp multi_term.cpp nearest_neighbor_query_node.cpp diff --git a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp index 1871bda564d..09840d9a126 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/dot_product_term.cpp @@ -24,7 +24,7 @@ DotProductTerm::build_scores(Scores& scores) const for (const auto& term : _terms) { auto& hl = term->evaluateHits(hl_store); for (auto& hit : hl) { - scores[hit.context()] += ((int64_t)term->weight().percent()) * hit.weight(); + scores[hit.field_id()] += ((int64_t)term->weight().percent()) * hit.element_weight(); } } } diff --git a/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.cpp new file mode 100644 index 00000000000..f33fe44369a --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.cpp @@ -0,0 +1,43 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "fuzzy_term.h" + +namespace search::streaming { + +namespace { + +constexpr bool normalizing_implies_cased(Normalizing norm) noexcept { + return (norm == Normalizing::NONE); +} + +} + +FuzzyTerm::FuzzyTerm(std::unique_ptr<QueryNodeResultBase> result_base, stringref term, + const string& index, Type type, Normalizing normalizing, + uint8_t max_edits, uint32_t prefix_size) + : QueryTerm(std::move(result_base), term, index, type, normalizing), + _dfa_matcher(), + _fallback_matcher() +{ + setFuzzyMaxEditDistance(max_edits); + setFuzzyPrefixLength(prefix_size); + + std::string_view term_view(term.data(), term.size()); + const bool cased = normalizing_implies_cased(normalizing); + if (attribute::DfaFuzzyMatcher::supports_max_edits(max_edits)) { + _dfa_matcher = std::make_unique<attribute::DfaFuzzyMatcher>(term_view, max_edits, prefix_size, cased); + } else { + _fallback_matcher = std::make_unique<vespalib::FuzzyMatcher>(term_view, max_edits, prefix_size, cased); + } +} + +FuzzyTerm::~FuzzyTerm() = default; + +bool FuzzyTerm::is_match(std::string_view term) const { + if (_dfa_matcher) { + return _dfa_matcher->is_match(term); + } else { + return _fallback_matcher->isMatch(term); + } +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.h b/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.h new file mode 100644 index 00000000000..c6c88b18969 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/fuzzy_term.h @@ -0,0 +1,34 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "queryterm.h" +#include <vespa/searchlib/attribute/dfa_fuzzy_matcher.h> +#include <vespa/vespalib/fuzzy/fuzzy_matcher.h> +#include <memory> +#include <string_view> + +namespace search::streaming { + +/** + * Query term that matches candidate field terms that are within a query-specified + * maximum number of edits (add, delete or substitute a character), with case + * sensitivity controlled by the provided Normalizing mode. + * + * Optionally, terms may be prefixed-locked, which enforces field terms to have a + * particular prefix and where edits are only counted for the remaining term suffix. + */ +class FuzzyTerm : public QueryTerm { + std::unique_ptr<attribute::DfaFuzzyMatcher> _dfa_matcher; + std::unique_ptr<vespalib::FuzzyMatcher> _fallback_matcher; +public: + FuzzyTerm(std::unique_ptr<QueryNodeResultBase> result_base, stringref term, + const string& index, Type type, Normalizing normalizing, + uint8_t max_edits, uint32_t prefix_size); + ~FuzzyTerm() override; + + [[nodiscard]] FuzzyTerm* as_fuzzy_term() noexcept override { return this; } + + [[nodiscard]] bool is_match(std::string_view term) const; +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/hit.h b/searchlib/src/vespa/searchlib/query/streaming/hit.h index a798d293491..cd72555ea66 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/hit.h +++ b/searchlib/src/vespa/searchlib/query/streaming/hit.h @@ -8,23 +8,21 @@ namespace search::streaming { class Hit { + uint32_t _field_id; + uint32_t _element_id; + int32_t _element_weight; + uint32_t _position; public: - Hit(uint32_t pos_, uint32_t context_, uint32_t elemId_, int32_t weight_) noexcept - : _position(pos_ | (context_<<24)), - _elemId(elemId_), - _weight(weight_) + Hit(uint32_t field_id_, uint32_t element_id_, int32_t element_weight_, uint32_t position_) noexcept + : _field_id(field_id_), + _element_id(element_id_), + _element_weight(element_weight_), + _position(position_) { } - int32_t weight() const { return _weight; } - uint32_t pos() const { return _position; } - uint32_t wordpos() const { return _position & 0xffffff; } - uint32_t context() const { return _position >> 24; } - uint32_t elemId() const { return _elemId; } - bool operator < (const Hit & b) const { return cmp(b) < 0; } -private: - int cmp(const Hit & b) const { return _position - b._position; } - uint32_t _position; - uint32_t _elemId; - int32_t _weight; + uint32_t field_id() const noexcept { return _field_id; } + uint32_t element_id() const { return _element_id; } + int32_t element_weight() const { return _element_weight; } + uint32_t position() const { return _position; } }; using HitList = std::vector<Hit>; diff --git a/searchlib/src/vespa/searchlib/query/streaming/in_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/in_term.cpp index 3e75f4a5114..c164db69ba1 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/in_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/in_term.cpp @@ -29,9 +29,9 @@ InTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_ for (const auto& term : _terms) { auto& hl = term->evaluateHits(hl_store); for (auto& hit : hl) { - if (!prev_field_id.has_value() || prev_field_id.value() != hit.context()) { - prev_field_id = hit.context(); - matching_field_ids.insert(hit.context()); + if (!prev_field_id.has_value() || prev_field_id.value() != hit.field_id()) { + prev_field_id = hit.field_id(); + matching_field_ids.insert(hit.field_id()); } } } diff --git a/searchlib/src/vespa/searchlib/query/streaming/multi_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/multi_term.cpp index dd34b9b7e73..f5a09892551 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/multi_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/multi_term.cpp @@ -51,11 +51,4 @@ MultiTerm::evaluate() const return false; } -const HitList& -MultiTerm::evaluateHits(HitList& hl) const -{ - hl.clear(); - return hl; -} - } diff --git a/searchlib/src/vespa/searchlib/query/streaming/multi_term.h b/searchlib/src/vespa/searchlib/query/streaming/multi_term.h index 3bb69e29693..6f795c31356 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/multi_term.h +++ b/searchlib/src/vespa/searchlib/query/streaming/multi_term.h @@ -32,7 +32,6 @@ public: MultiTerm* as_multi_term() noexcept override { return this; } void reset() override; bool evaluate() const override; - const HitList& evaluateHits(HitList& hl) const override; virtual void unpack_match_data(uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data) = 0; const std::vector<std::unique_ptr<QueryTerm>>& get_terms() const noexcept { return _terms; } }; diff --git a/searchlib/src/vespa/searchlib/query/streaming/query.cpp b/searchlib/src/vespa/searchlib/query/streaming/query.cpp index ca742aabe26..196de23c236 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/query.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/query.cpp @@ -31,7 +31,7 @@ const HitList & QueryConnector::evaluateHits(HitList & hl) const { if (evaluate()) { - hl.emplace_back(1, 0, 0, 1); + hl.emplace_back(0, 0, 1, 1); } return hl; } @@ -196,19 +196,19 @@ SameElementQueryNode::evaluateHits(HitList & hl) const unsigned int & nextIndex = indexVector[currMatchCount+1]; const auto & currHit = curr->evaluateHits(tmpHL)[currIndex]; - uint32_t currElemId = currHit.elemId(); + uint32_t currElemId = currHit.element_id(); const HitList & nextHL = next->evaluateHits(tmpHL); size_t nextIndexMax = nextHL.size(); - while ((nextIndex < nextIndexMax) && (nextHL[nextIndex].elemId() < currElemId)) { + while ((nextIndex < nextIndexMax) && (nextHL[nextIndex].element_id() < currElemId)) { nextIndex++; } - if ((nextIndex < nextIndexMax) && (nextHL[nextIndex].elemId() == currElemId)) { + if ((nextIndex < nextIndexMax) && (nextHL[nextIndex].element_id() == currElemId)) { currMatchCount++; if ((currMatchCount+1) == numFields) { Hit h = nextHL[indexVector[currMatchCount]]; - hl.emplace_back(0, h.context(), h.elemId(), h.weight()); + hl.emplace_back(h.field_id(), h.element_id(), h.element_weight(), 0); currMatchCount = 0; indexVector[0]++; } @@ -238,6 +238,15 @@ PhraseQueryNode::addChild(QueryNode::UP child) { AndQueryNode::addChild(std::move(child)); } +namespace { + +// TODO: Remove when rewriting PhraseQueryNode::evaluateHits +uint32_t legacy_pos(const Hit& hit) { + return ((hit.position() & 0xffffff) | ((hit.field_id() & 0xff) << 24)); +} + +} + const HitList & PhraseQueryNode::evaluateHits(HitList & hl) const { @@ -258,28 +267,28 @@ PhraseQueryNode::evaluateHits(HitList & hl) const unsigned int & nextIndex = indexVector[currPhraseLen+1]; const auto & currHit = curr->evaluateHits(tmpHL)[currIndex]; - size_t firstPosition = currHit.pos(); - uint32_t currElemId = currHit.elemId(); - uint32_t currContext = currHit.context(); + size_t firstPosition = legacy_pos(currHit); + uint32_t currElemId = currHit.element_id(); + uint32_t curr_field_id = currHit.field_id(); const HitList & nextHL = next->evaluateHits(tmpHL); int diff(0); size_t nextIndexMax = nextHL.size(); while ((nextIndex < nextIndexMax) && - ((nextHL[nextIndex].context() < currContext) || - ((nextHL[nextIndex].context() == currContext) && (nextHL[nextIndex].elemId() <= currElemId))) && - ((diff = nextHL[nextIndex].pos()-firstPosition) < 1)) + ((nextHL[nextIndex].field_id() < curr_field_id) || + ((nextHL[nextIndex].field_id() == curr_field_id) && (nextHL[nextIndex].element_id() <= currElemId))) && + ((diff = legacy_pos(nextHL[nextIndex])-firstPosition) < 1)) { nextIndex++; } - if ((diff == 1) && (nextHL[nextIndex].context() == currContext) && (nextHL[nextIndex].elemId() == currElemId)) { + if ((diff == 1) && (nextHL[nextIndex].field_id() == curr_field_id) && (nextHL[nextIndex].element_id() == currElemId)) { currPhraseLen++; if ((currPhraseLen+1) == fullPhraseLen) { Hit h = nextHL[indexVector[currPhraseLen]]; hl.push_back(h); - const QueryTerm::FieldInfo & fi = next->getFieldInfo(h.context()); - updateFieldInfo(h.context(), hl.size() - 1, fi.getFieldLength()); + const QueryTerm::FieldInfo & fi = next->getFieldInfo(h.field_id()); + updateFieldInfo(h.field_id(), hl.size() - 1, fi.getFieldLength()); currPhraseLen = 0; indexVector[0]++; } diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp index 2ee515f062a..e71529a8aca 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp @@ -1,7 +1,8 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include "query.h" +#include "fuzzy_term.h" #include "nearest_neighbor_query_node.h" +#include "query.h" #include "regexp_term.h" #include <vespa/searchlib/parsequery/stackdumpiterator.h> #include <vespa/searchlib/query/streaming/dot_product_term.h> @@ -147,17 +148,16 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor } else { Normalizing normalize_mode = factory.normalizing_mode(ssIndex); std::unique_ptr<QueryTerm> qt; - if (sTerm != TermType::REGEXP) { - qt = std::make_unique<QueryTerm>(factory.create(), ssTerm, ssIndex, sTerm, normalize_mode); - } else { + if (sTerm == TermType::REGEXP) { qt = std::make_unique<RegexpTerm>(factory.create(), ssTerm, ssIndex, TermType::REGEXP, normalize_mode); + } else if (sTerm == TermType::FUZZYTERM) { + qt = std::make_unique<FuzzyTerm>(factory.create(), ssTerm, ssIndex, TermType::FUZZYTERM, normalize_mode, + queryRep.getFuzzyMaxEditDistance(), queryRep.getFuzzyPrefixLength()); + } else [[likely]] { + qt = std::make_unique<QueryTerm>(factory.create(), ssTerm, ssIndex, sTerm, normalize_mode); } qt->setWeight(queryRep.GetWeight()); qt->setUniqueId(queryRep.getUniqueId()); - if (qt->isFuzzy()) { - qt->setFuzzyMaxEditDistance(queryRep.getFuzzyMaxEditDistance()); - qt->setFuzzyPrefixLength(queryRep.getFuzzyPrefixLength()); - } if (allowRewrite && possibleFloat(*qt, ssTerm) && factory.allow_float_terms_rewrite(ssIndex)) { auto phrase = std::make_unique<PhraseQueryNode>(); auto dotPos = ssTerm.find('.'); diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp index d72a3371846..af8ce7c9994 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp @@ -1,5 +1,6 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "querynoderesultbase.h" +#include <ostream> namespace search::streaming { diff --git a/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp b/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp index 3e05d381ee2..e0b78633af3 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp @@ -162,9 +162,9 @@ void QueryTerm::resizeFieldId(size_t fieldNo) } } -void QueryTerm::add(unsigned pos, unsigned context, uint32_t elemId, int32_t weight_) +void QueryTerm::add(uint32_t field_id, uint32_t element_id, int32_t element_weight, uint32_t position) { - _hitList.emplace_back(pos, context, elemId, weight_); + _hitList.emplace_back(field_id, element_id, element_weight, position); } NearestNeighborQueryNode* @@ -185,4 +185,10 @@ QueryTerm::as_regexp_term() noexcept return nullptr; } +FuzzyTerm* +QueryTerm::as_fuzzy_term() noexcept +{ + return nullptr; +} + } diff --git a/searchlib/src/vespa/searchlib/query/streaming/queryterm.h b/searchlib/src/vespa/searchlib/query/streaming/queryterm.h index cd2bdd7eaec..627fae0532d 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/queryterm.h +++ b/searchlib/src/vespa/searchlib/query/streaming/queryterm.h @@ -11,6 +11,7 @@ namespace search::streaming { +class FuzzyTerm; class NearestNeighborQueryNode; class MultiTerm; class RegexpTerm; @@ -64,7 +65,7 @@ public: QueryTerm & operator = (QueryTerm &&) = delete; ~QueryTerm() override; bool evaluate() const override; - const HitList & evaluateHits(HitList & hl) const override; + const HitList & evaluateHits(HitList & hl) const final override; void reset() override; void getLeaves(QueryTermList & tl) override; void getLeaves(ConstQueryTermList & tl) const override; @@ -73,7 +74,7 @@ public: /// Gives you all phrases of this tree. Indicating that they are all const. void getPhrases(ConstQueryNodeRefList & tl) const override; - void add(unsigned pos, unsigned context, uint32_t elemId, int32_t weight); + void add(uint32_t field_id, uint32_t element_id, int32_t element_weight, uint32_t position); EncodingBitMap encoding() const { return _encoding; } size_t termLen() const { return getTermLen(); } const string & index() const { return _index; } @@ -95,6 +96,7 @@ public: virtual NearestNeighborQueryNode* as_nearest_neighbor_query_node() noexcept; virtual MultiTerm* as_multi_term() noexcept; virtual RegexpTerm* as_regexp_term() noexcept; + virtual FuzzyTerm* as_fuzzy_term() noexcept; protected: using QueryNodeResultBaseContainer = std::unique_ptr<QueryNodeResultBase>; string _index; diff --git a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp index 90d0be5d43c..d2d706eef3d 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp @@ -25,7 +25,7 @@ WeightedSetTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchDat for (const auto& term : _terms) { auto& hl = term->evaluateHits(hl_store); for (auto& hit : hl) { - scores[hit.context()].emplace_back(term->weight().percent()); + scores[hit.field_id()].emplace_back(term->weight().percent()); } } auto num_fields = td.numFields(); diff --git a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h index 4473e0fa45b..3d8a5fba843 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h +++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h @@ -10,7 +10,6 @@ namespace search::streaming { * A weighted set query term for streaming search. */ class WeightedSetTerm : public MultiTerm { - double _score_threshold; public: WeightedSetTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string& index, uint32_t num_terms); ~WeightedSetTerm() override; diff --git a/searchlib/src/vespa/searchlib/queryeval/dot_product_search.h b/searchlib/src/vespa/searchlib/queryeval/dot_product_search.h index e49fcbcb5bc..c74c2d4e9a7 100644 --- a/searchlib/src/vespa/searchlib/queryeval/dot_product_search.h +++ b/searchlib/src/vespa/searchlib/queryeval/dot_product_search.h @@ -27,6 +27,7 @@ protected: public: static constexpr bool filter_search = false; static constexpr bool require_btree_iterators = true; + static constexpr bool supports_hash_filter = false; // TODO: use MultiSearch::Children to pass ownership static SearchIterator::UP create(const std::vector<SearchIterator*> &children, diff --git a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp index 0c57b21aba6..0be014474d0 100644 --- a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp @@ -1,10 +1,13 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "weighted_set_term_search.h" +#include <vespa/searchcommon/attribute/i_search_context.h> +#include <vespa/searchcommon/attribute/iattributevector.h> +#include <vespa/searchlib/attribute/i_direct_posting_store.h> +#include <vespa/searchlib/attribute/multi_term_hash_filter.hpp> #include <vespa/searchlib/common/bitvector.h> -#include <vespa/searchlib/attribute/multi_term_or_filter_search.h> #include <vespa/vespalib/objects/visit.h> -#include <vespa/searchcommon/attribute/i_search_context.h> +#include <vespa/vespalib/stllike/hash_map.hpp> #include "iterator_pack.h" #include "blueprint.h" @@ -237,4 +240,98 @@ WeightedSetTermSearch::create(fef::TermFieldMatchData &tmd, return create_helper_resolve_pack<DocidWithWeightIterator, DocidWithWeightIteratorPack>(tmd, is_filter_search, std::move(weights), std::move(iterators)); } +namespace { + +class HashFilterWrapper { +protected: + const attribute::IAttributeVector& _attr; +public: + HashFilterWrapper(const attribute::IAttributeVector& attr) : _attr(attr) {} +}; + +template <bool unpack_weights_t> +class StringHashFilterWrapper : public HashFilterWrapper { +public: + using TokenT = attribute::IAttributeVector::EnumHandle; + static constexpr bool unpack_weights = unpack_weights_t; + StringHashFilterWrapper(const attribute::IAttributeVector& attr) + : HashFilterWrapper(attr) + {} + auto mapToken(const IDirectPostingStore::LookupResult& term, const IDirectPostingStore& store, vespalib::datastore::EntryRef dict_snapshot) const { + std::vector<TokenT> result; + store.collect_folded(term.enum_idx, dict_snapshot, [&](vespalib::datastore::EntryRef ref) { result.emplace_back(ref.ref()); }); + return result; + } + TokenT getToken(uint32_t docid) const { + return _attr.getEnum(docid); + } +}; + +template <bool unpack_weights_t> +class IntegerHashFilterWrapper : public HashFilterWrapper { +public: + using TokenT = attribute::IAttributeVector::largeint_t; + static constexpr bool unpack_weights = unpack_weights_t; + IntegerHashFilterWrapper(const attribute::IAttributeVector& attr) + : HashFilterWrapper(attr) + {} + auto mapToken(const IDirectPostingStore::LookupResult& term, + const IDirectPostingStore& store, + vespalib::datastore::EntryRef) const { + std::vector<TokenT> result; + result.emplace_back(store.get_integer_value(term.enum_idx)); + return result; + } + TokenT getToken(uint32_t docid) const { + return _attr.getInt(docid); + } +}; + +template <typename WrapperType> +SearchIterator::UP +create_hash_filter_helper(fef::TermFieldMatchData& tfmd, + const std::vector<int32_t>& weights, + const std::vector<IDirectPostingStore::LookupResult>& terms, + const attribute::IAttributeVector& attr, + const IDirectPostingStore& posting_store, + vespalib::datastore::EntryRef dict_snapshot) +{ + using FilterType = attribute::MultiTermHashFilter<WrapperType>; + typename FilterType::TokenMap tokens; + WrapperType wrapper(attr); + for (size_t i = 0; i < terms.size(); ++i) { + for (auto token : wrapper.mapToken(terms[i], posting_store, dict_snapshot)) { + tokens[token] = weights[i]; + } + } + return std::make_unique<FilterType>(tfmd, wrapper, std::move(tokens)); +} + +} + +SearchIterator::UP +WeightedSetTermSearch::create_hash_filter(search::fef::TermFieldMatchData& tmd, + bool is_filter_search, + const std::vector<int32_t>& weights, + const std::vector<IDirectPostingStore::LookupResult>& terms, + const attribute::IAttributeVector& attr, + const IDirectPostingStore& posting_store, + vespalib::datastore::EntryRef dict_snapshot) +{ + if (attr.isStringType()) { + if (is_filter_search) { + return create_hash_filter_helper<StringHashFilterWrapper<false>>(tmd, weights, terms, attr, posting_store, dict_snapshot); + } else { + return create_hash_filter_helper<StringHashFilterWrapper<true>>(tmd, weights, terms, attr, posting_store, dict_snapshot); + } + } else { + assert(attr.isIntegerType()); + if (is_filter_search) { + return create_hash_filter_helper<IntegerHashFilterWrapper<false>>(tmd, weights, terms, attr, posting_store, dict_snapshot); + } else { + return create_hash_filter_helper<IntegerHashFilterWrapper<true>>(tmd, weights, terms, attr, posting_store, dict_snapshot); + } + } +} + } diff --git a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h index 7e47928fb90..d078fd5babc 100644 --- a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h +++ b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h @@ -11,6 +11,8 @@ #include <variant> #include <vector> +namespace search::attribute { class IAttributeVector; } + namespace search::fef { class TermFieldMatchData; } namespace search::queryeval { @@ -30,6 +32,8 @@ public: static constexpr bool filter_search = false; // Whether this iterator requires btree iterators for all tokens/terms used by the operator. static constexpr bool require_btree_iterators = false; + // Whether this supports creating a hash filter iterator; + static constexpr bool supports_hash_filter = true; // TODO: pass ownership with unique_ptr static SearchIterator::UP create(const std::vector<SearchIterator *> &children, @@ -48,6 +52,14 @@ public: std::variant<std::reference_wrapper<const std::vector<int32_t>>, std::vector<int32_t>> weights, std::vector<DocidWithWeightIterator> &&iterators); + static SearchIterator::UP create_hash_filter(search::fef::TermFieldMatchData& tmd, + bool is_filter_search, + const std::vector<int32_t>& weights, + const std::vector<IDirectPostingStore::LookupResult>& terms, + const attribute::IAttributeVector& attr, + const IDirectPostingStore& posting_store, + vespalib::datastore::EntryRef dictionary_snapshot); + // used during docsum fetching to identify matching elements // initRange must be called before use. // doSeek/doUnpack must not be called. diff --git a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp index 93e35e4c6d2..c9518b29884 100644 --- a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp +++ b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp @@ -84,7 +84,7 @@ RankProcessorTest::test_unpack_match_data_for_term_node(bool interleaved_feature EXPECT_EQ(invalid_id, tfmd->getDocId()); RankProcessor::unpack_match_data(1, *md, *_query_wrapper); EXPECT_EQ(invalid_id, tfmd->getDocId()); - node->add(0, field_id, 0, 1); + node->add(field_id, 0, 1, 0); auto& field_info = node->getFieldInfo(field_id); field_info.setHitCount(mock_num_occs); field_info.setFieldLength(mock_field_length); diff --git a/streamingvisitors/src/tests/searcher/searcher_test.cpp b/streamingvisitors/src/tests/searcher/searcher_test.cpp index 24877866c1b..705e14c11a5 100644 --- a/streamingvisitors/src/tests/searcher/searcher_test.cpp +++ b/streamingvisitors/src/tests/searcher/searcher_test.cpp @@ -3,6 +3,7 @@ #include <vespa/vespalib/testkit/testapp.h> #include <vespa/document/fieldvalue/fieldvalues.h> +#include <vespa/searchlib/query/streaming/fuzzy_term.h> #include <vespa/searchlib/query/streaming/regexp_term.h> #include <vespa/searchlib/query/streaming/queryterm.h> #include <vespa/vsm/searcher/boolfieldsearcher.h> @@ -18,10 +19,14 @@ #include <vespa/vsm/searcher/utf8suffixstringfieldsearcher.h> #include <vespa/vsm/searcher/tokenizereader.h> #include <vespa/vsm/vsm/snippetmodifier.h> +#include <concepts> +#include <charconv> +#include <stdexcept> using namespace document; using search::streaming::HitList; using search::streaming::QueryNodeResultFactory; +using search::streaming::FuzzyTerm; using search::streaming::RegexpTerm; using search::streaming::QueryTerm; using search::streaming::Normalizing; @@ -58,6 +63,46 @@ public: } }; +namespace { + +template <std::integral T> +std::string_view maybe_consume_into(std::string_view str, T& val_out) { + auto [ptr, ec] = std::from_chars(str.data(), str.data() + str.size(), val_out); + if (ec != std::errc()) { + return str; + } + return str.substr(ptr - str.data()); +} + +// Parse optional max edits and prefix lock length from term string. +// Syntax: +// "term" -> {2, 0, "term"} (default max edits & prefix length) +// "{1}term" -> {1, 0, "term"} +// "{1,3}term" -> {1, 3, "term"} +// +// Note: this is not a "proper" parser (it accepts empty numeric values); only for testing! +std::tuple<uint8_t, uint32_t, std::string_view> parse_fuzzy_params(std::string_view term) { + if (term.empty() || term[0] != '{') { + return {2, 0, term}; + } + uint8_t max_edits = 2; + uint32_t prefix_length = 0; + term = maybe_consume_into(term.substr(1), max_edits); + if (term.empty() || (term[0] != ',' && term[0] != '}')) { + throw std::invalid_argument("malformed fuzzy params at (or after) max_edits"); + } + if (term[0] == '}') { + return {max_edits, prefix_length, term.substr(1)}; + } + term = maybe_consume_into(term.substr(1), prefix_length); + if (term.empty() || term[0] != '}') { + throw std::invalid_argument("malformed fuzzy params at (or after) prefix_length"); + } + return {max_edits, prefix_length, term.substr(1)}; +} + +} + class Query { private: @@ -66,10 +111,14 @@ private: ParsedQueryTerm pqt = parseQueryTerm(term); ParsedTerm pt = parseTerm(pqt.second); std::string effective_index = pqt.first.empty() ? "index" : pqt.first; - if (pt.second != TermType::REGEXP) { - qtv.push_back(std::make_unique<QueryTerm>(eqnr.create(), pt.first, effective_index, pt.second, normalizing)); + if (pt.second == TermType::REGEXP) { + qtv.push_back(std::make_unique<RegexpTerm>(eqnr.create(), pt.first, effective_index, TermType::REGEXP, normalizing)); + } else if (pt.second == TermType::FUZZYTERM) { + auto [max_edits, prefix_length, actual_term] = parse_fuzzy_params(pt.first); + qtv.push_back(std::make_unique<FuzzyTerm>(eqnr.create(), vespalib::stringref(actual_term.data(), actual_term.size()), + effective_index, TermType::FUZZYTERM, normalizing, max_edits, prefix_length)); } else { - qtv.push_back(std::make_unique<RegexpTerm>(eqnr.create(), pt.first, effective_index, pt.second, normalizing)); + qtv.push_back(std::make_unique<QueryTerm>(eqnr.create(), pt.first, effective_index, pt.second, normalizing)); } } for (const auto & i : qtv) { @@ -100,6 +149,8 @@ public: return std::make_pair(term.substr(1, term.size() - 1), TermType::SUFFIXTERM); } else if (term[0] == '#') { // magic regex enabler return std::make_pair(term.substr(1), TermType::REGEXP); + } else if (term[0] == '%') { // equally magic fuzzy enabler + return std::make_pair(term.substr(1), TermType::FUZZYTERM); } else if (term[term.size() - 1] == '*') { return std::make_pair(term.substr(0, term.size() - 1), TermType::PREFIXTERM); } else { @@ -349,7 +400,8 @@ assertSearch(FieldSearcher & fs, const StringList & query, const FieldValue & fv EXPECT_EQUAL(hl.size(), exp[i].size()); ASSERT_TRUE(hl.size() == exp[i].size()); for (size_t j = 0; j < hl.size(); ++j) { - EXPECT_EQUAL((size_t)hl[j].pos(), exp[i][j]); + EXPECT_EQUAL(0u, hl[j].field_id()); + EXPECT_EQUAL((size_t)hl[j].position(), exp[i][j]); } } } @@ -477,31 +529,54 @@ testStrChrFieldSearcher(StrChrFieldSearcher & fs) return true; } - TEST("verify correct term parsing") { - ASSERT_TRUE(Query::parseQueryTerm("index:term").first == "index"); - ASSERT_TRUE(Query::parseQueryTerm("index:term").second == "term"); - ASSERT_TRUE(Query::parseQueryTerm("term").first.empty()); - ASSERT_TRUE(Query::parseQueryTerm("term").second == "term"); - ASSERT_TRUE(Query::parseTerm("*substr*").first == "substr"); - ASSERT_TRUE(Query::parseTerm("*substr*").second == TermType::SUBSTRINGTERM); - ASSERT_TRUE(Query::parseTerm("*suffix").first == "suffix"); - ASSERT_TRUE(Query::parseTerm("*suffix").second == TermType::SUFFIXTERM); - ASSERT_TRUE(Query::parseTerm("prefix*").first == "prefix"); - ASSERT_TRUE(Query::parseTerm("prefix*").second == TermType::PREFIXTERM); - ASSERT_TRUE(Query::parseTerm("#regex").first == "regex"); - ASSERT_TRUE(Query::parseTerm("#regex").second == TermType::REGEXP); - ASSERT_TRUE(Query::parseTerm("term").first == "term"); - ASSERT_TRUE(Query::parseTerm("term").second == TermType::WORD); - } - - TEST("suffix matching") { - EXPECT_EQUAL(assertMatchTermSuffix("a", "vespa"), true); - EXPECT_EQUAL(assertMatchTermSuffix("spa", "vespa"), true); - EXPECT_EQUAL(assertMatchTermSuffix("vespa", "vespa"), true); - EXPECT_EQUAL(assertMatchTermSuffix("vvespa", "vespa"), false); - EXPECT_EQUAL(assertMatchTermSuffix("fspa", "vespa"), false); - EXPECT_EQUAL(assertMatchTermSuffix("v", "vespa"), false); - } +TEST("parsing of test-only fuzzy term params can extract numeric values") { + uint8_t max_edits = 0; + uint32_t prefix_length = 1234; + std::string_view out; + + std::tie(max_edits, prefix_length, out) = parse_fuzzy_params("myterm"); + EXPECT_EQUAL(max_edits, 2u); + EXPECT_EQUAL(prefix_length, 0u); + EXPECT_EQUAL(out, "myterm"); + + std::tie(max_edits, prefix_length, out) = parse_fuzzy_params("{3}myterm"); + EXPECT_EQUAL(max_edits, 3u); + EXPECT_EQUAL(prefix_length, 0u); + EXPECT_EQUAL(out, "myterm"); + + std::tie(max_edits, prefix_length, out) = parse_fuzzy_params("{2,70}myterm"); + EXPECT_EQUAL(max_edits, 2u); + EXPECT_EQUAL(prefix_length, 70u); + EXPECT_EQUAL(out, "myterm"); +} + +TEST("verify correct term parsing") { + ASSERT_TRUE(Query::parseQueryTerm("index:term").first == "index"); + ASSERT_TRUE(Query::parseQueryTerm("index:term").second == "term"); + ASSERT_TRUE(Query::parseQueryTerm("term").first.empty()); + ASSERT_TRUE(Query::parseQueryTerm("term").second == "term"); + ASSERT_TRUE(Query::parseTerm("*substr*").first == "substr"); + ASSERT_TRUE(Query::parseTerm("*substr*").second == TermType::SUBSTRINGTERM); + ASSERT_TRUE(Query::parseTerm("*suffix").first == "suffix"); + ASSERT_TRUE(Query::parseTerm("*suffix").second == TermType::SUFFIXTERM); + ASSERT_TRUE(Query::parseTerm("prefix*").first == "prefix"); + ASSERT_TRUE(Query::parseTerm("prefix*").second == TermType::PREFIXTERM); + ASSERT_TRUE(Query::parseTerm("#regex").first == "regex"); + ASSERT_TRUE(Query::parseTerm("#regex").second == TermType::REGEXP); + ASSERT_TRUE(Query::parseTerm("%fuzzy").first == "fuzzy"); + ASSERT_TRUE(Query::parseTerm("%fuzzy").second == TermType::FUZZYTERM); + ASSERT_TRUE(Query::parseTerm("term").first == "term"); + ASSERT_TRUE(Query::parseTerm("term").second == TermType::WORD); +} + +TEST("suffix matching") { + EXPECT_EQUAL(assertMatchTermSuffix("a", "vespa"), true); + EXPECT_EQUAL(assertMatchTermSuffix("spa", "vespa"), true); + EXPECT_EQUAL(assertMatchTermSuffix("vespa", "vespa"), true); + EXPECT_EQUAL(assertMatchTermSuffix("vvespa", "vespa"), false); + EXPECT_EQUAL(assertMatchTermSuffix("fspa", "vespa"), false); + EXPECT_EQUAL(assertMatchTermSuffix("v", "vespa"), false); +} TEST("Test basic strchrfield searchers") { { @@ -654,6 +729,96 @@ TEST("utf8 flexible searcher handles regexes with explicit anchoring") { TEST_DO(assertString(fs, "#^foo$", "oo", Hits())); } +TEST("utf8 flexible searcher handles fuzzy search in uncased mode") { + UTF8FlexibleStringFieldSearcher fs(0); + // Term syntax (only applies to these tests): + // %{k}term => fuzzy match "term" with max edits k + // %{k,p}term => fuzzy match "term" with max edits k, prefix lock length p + + // DFA is used for k in {1, 2} + TEST_DO(assertString(fs, "%{1}abc", "abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}ABC", "abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}abc", "ABC", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}Abc", "abd", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}abc", "ABCD", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}abc", "abcde", Hits())); + TEST_DO(assertString(fs, "%{2}abc", "abcde", Hits().add(0))); + TEST_DO(assertString(fs, "%{2}abc", "xabcde", Hits())); + // Fallback to non-DFA matcher when k not in {1, 2} + TEST_DO(assertString(fs, "%{3}abc", "abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{3}abc", "XYZ", Hits().add(0))); + TEST_DO(assertString(fs, "%{3}abc", "XYZ!", Hits())); +} + +TEST("utf8 flexible searcher handles fuzzy search in cased mode") { + UTF8FlexibleStringFieldSearcher fs(0); + fs.normalize_mode(Normalizing::NONE); + TEST_DO(assertString(fs, "%{1}abc", "abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}abc", "Abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{1}ABC", "abc", Hits())); + TEST_DO(assertString(fs, "%{2}Abc", "abc", Hits().add(0))); + TEST_DO(assertString(fs, "%{2}abc", "AbC", Hits().add(0))); + TEST_DO(assertString(fs, "%{3}abc", "ABC", Hits().add(0))); + TEST_DO(assertString(fs, "%{3}abc", "ABCD", Hits())); +} + +TEST("utf8 flexible searcher handles fuzzy search with prefix locking") { + UTF8FlexibleStringFieldSearcher fs(0); + // DFA + TEST_DO(assertString(fs, "%{1,4}zoid", "zoi", Hits())); + TEST_DO(assertString(fs, "%{1,4}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,4}zoid", "ZOID", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoid", Hits())); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "ZoidBerg", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "ZoidBergg", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoidborg", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoidblergh", Hits())); + TEST_DO(assertString(fs, "%{2,4}zoidberg", "zoidblergh", Hits().add(0))); + // Fallback + TEST_DO(assertString(fs, "%{3,4}zoidberg", "zoidblergh", Hits().add(0))); + TEST_DO(assertString(fs, "%{3,4}zoidberg", "zoidbooorg", Hits().add(0))); + TEST_DO(assertString(fs, "%{3,4}zoidberg", "zoidzooorg", Hits())); + + fs.normalize_mode(Normalizing::NONE); + // DFA + TEST_DO(assertString(fs, "%{1,4}zoid", "ZOID", Hits())); + TEST_DO(assertString(fs, "%{1,4}ZOID", "zoid", Hits())); + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoidBerg", Hits().add(0))); // 1 edit + TEST_DO(assertString(fs, "%{1,4}zoidberg", "zoidBblerg", Hits())); // 2 edits, 1 max + TEST_DO(assertString(fs, "%{2,4}zoidberg", "zoidBblerg", Hits().add(0))); // 2 edits, 2 max + // Fallback + TEST_DO(assertString(fs, "%{3,4}zoidberg", "zoidBERG", Hits())); // 4 edits, 3 max + TEST_DO(assertString(fs, "%{4,4}zoidberg", "zoidBERG", Hits().add(0))); // 4 edits, 4 max +} + +TEST("utf8 flexible searcher fuzzy match with max_edits=0 implies exact match") { + UTF8FlexibleStringFieldSearcher fs(0); + TEST_DO(assertString(fs, "%{0}zoid", "zoi", Hits())); + TEST_DO(assertString(fs, "%{0,4}zoid", "zoi", Hits())); + TEST_DO(assertString(fs, "%{0}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{0}zoid", "ZOID", Hits().add(0))); + TEST_DO(assertString(fs, "%{0,4}zoid", "ZOID", Hits().add(0))); + fs.normalize_mode(Normalizing::NONE); + TEST_DO(assertString(fs, "%{0}zoid", "ZOID", Hits())); + TEST_DO(assertString(fs, "%{0,4}zoid", "ZOID", Hits())); + TEST_DO(assertString(fs, "%{0}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{0,4}zoid", "zoid", Hits().add(0))); +} + +TEST("utf8 flexible searcher caps oversized fuzzy prefix length to term length") { + UTF8FlexibleStringFieldSearcher fs(0); + // DFA + TEST_DO(assertString(fs, "%{1,5}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,9001}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{1,9001}zoid", "boid", Hits())); + // Fallback + TEST_DO(assertString(fs, "%{0,5}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{5,5}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{0,9001}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{5,9001}zoid", "zoid", Hits().add(0))); + TEST_DO(assertString(fs, "%{5,9001}zoid", "boid", Hits())); +} + TEST("bool search") { BoolFieldSearcher fs(0); TEST_DO(assertBool(fs, "true", true, true)); diff --git a/streamingvisitors/src/vespa/searchvisitor/matching_elements_filler.cpp b/streamingvisitors/src/vespa/searchvisitor/matching_elements_filler.cpp index 095141c0359..d574101cc89 100644 --- a/streamingvisitors/src/vespa/searchvisitor/matching_elements_filler.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/matching_elements_filler.cpp @@ -109,7 +109,7 @@ Matcher::add_matching_elements(const vespalib::string& field_name, uint32_t doc_ { _elements.clear(); for (auto& hit : hit_list) { - _elements.emplace_back(hit.elemId()); + _elements.emplace_back(hit.element_id()); } if (_elements.size() > 1) { std::sort(_elements.begin(), _elements.end()); diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp index 6b15b7cb88e..3fc7f351151 100644 --- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp @@ -300,7 +300,7 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap // optimize for hitlist giving all hits for a single field in one chunk for (const Hit & hit : hitList) { - uint32_t fieldId = hit.context(); + uint32_t fieldId = hit.field_id(); if (fieldId != lastFieldId) { // reset to notfound/unknown values tmd = nullptr; @@ -335,8 +335,8 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap } if (tmd != nullptr) { // adjust so that the position for phrase terms equals the match for the first term - TermFieldMatchDataPosition pos(hit.elemId(), hit.wordpos() - term.getPosAdjust(), - hit.weight(), fieldLen); + TermFieldMatchDataPosition pos(hit.element_id(), hit.position() - term.getPosAdjust(), + hit.element_weight(), fieldLen); tmd->appendPosition(pos); LOG(debug, "Append elemId(%u),position(%u), weight(%d), tfmd.weight(%d)", pos.getElementId(), pos.getPosition(), pos.getElementWeight(), tmd->getWeight()); diff --git a/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.h b/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.h index c5bca6f3899..e339e4bdf5a 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.h +++ b/streamingvisitors/src/vespa/vsm/searcher/fieldsearcher.h @@ -46,7 +46,7 @@ public: explicit FieldSearcher(FieldIdT fId) noexcept : FieldSearcher(fId, false) {} FieldSearcher(FieldIdT fId, bool defaultPrefix) noexcept; ~FieldSearcher() override; - virtual std::unique_ptr<FieldSearcher> duplicate() const = 0; + [[nodiscard]] virtual std::unique_ptr<FieldSearcher> duplicate() const = 0; bool search(const StorageDocument & doc); virtual void prepare(search::streaming::QueryTermList& qtl, const SharedSearcherBuf& buf, const vsm::FieldPathMapT& field_paths, search::fef::IQueryEnvironment& query_env); @@ -106,7 +106,7 @@ protected: * For each call to onValue() a batch of words are processed, and the position is local to this batch. **/ void addHit(search::streaming::QueryTerm & qt, uint32_t pos) const { - qt.add(_words + pos, field(), _currentElementId, _currentElementWeight); + qt.add(field(), _currentElementId, _currentElementWeight, _words + pos); } public: static search::byte _foldLowCase[256]; diff --git a/streamingvisitors/src/vespa/vsm/searcher/strchrfieldsearcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/strchrfieldsearcher.cpp index c0a0249125f..98e88e45b3a 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/strchrfieldsearcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/strchrfieldsearcher.cpp @@ -34,7 +34,7 @@ bool StrChrFieldSearcher::matchDoc(const FieldRef & fieldRef) } } else { for (auto qt : _qtl) { - if (fieldRef.size() >= qt->termLen() || qt->isRegex()) { + if (fieldRef.size() >= qt->termLen() || qt->isRegex() || qt->isFuzzy()) { _words += matchTerm(fieldRef, *qt); } else { _words += countWords(fieldRef); @@ -49,8 +49,8 @@ size_t StrChrFieldSearcher::shortestTerm() const size_t mintsz(_qtl.front()->termLen()); for (auto it=_qtl.begin()+1, mt=_qtl.end(); it != mt; it++) { const QueryTerm & qt = **it; - if (qt.isRegex()) { - return 0; // Must avoid "too short query term" optimization when using regex + if (qt.isRegex() || qt.isFuzzy()) { + return 0; // Must avoid "too short query term" optimization when using regex or fuzzy } mintsz = std::min(mintsz, qt.termLen()); } diff --git a/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.cpp index c6deb6eacd1..d648d2e252e 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.cpp @@ -1,5 +1,6 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "utf8flexiblestringfieldsearcher.h" +#include <vespa/searchlib/query/streaming/fuzzy_term.h> #include <vespa/searchlib/query/streaming/regexp_term.h> #include <cassert> @@ -40,6 +41,19 @@ UTF8FlexibleStringFieldSearcher::match_regexp(const FieldRef & f, search::stream } size_t +UTF8FlexibleStringFieldSearcher::match_fuzzy(const FieldRef & f, search::streaming::QueryTerm & qt) +{ + auto* fuzzy_term = qt.as_fuzzy_term(); + assert(fuzzy_term != nullptr); + // TODO delegate to matchTermExact if max edits == 0? + // - needs to avoid folding to have consistent normalization semantics + if (fuzzy_term->is_match({f.data(), f.size()})) { + addHit(qt, 0); + } + return countWords(f); +} + +size_t UTF8FlexibleStringFieldSearcher::matchTerm(const FieldRef & f, QueryTerm & qt) { if (qt.isPrefix()) { @@ -57,6 +71,9 @@ UTF8FlexibleStringFieldSearcher::matchTerm(const FieldRef & f, QueryTerm & qt) } else if (qt.isRegex()) { LOG(debug, "Use regexp match for term '%s:%s'", qt.index().c_str(), qt.getTerm()); return match_regexp(f, qt); + } else if (qt.isFuzzy()) { + LOG(debug, "Use fuzzy match for term '%s:%s'", qt.index().c_str(), qt.getTerm()); + return match_fuzzy(f, qt); } else { if (substring()) { LOG(debug, "Use substring match for term '%s:%s'", qt.index().c_str(), qt.getTerm()); diff --git a/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.h b/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.h index cd1715ad158..a5f6ad46246 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.h +++ b/streamingvisitors/src/vespa/vsm/searcher/utf8flexiblestringfieldsearcher.h @@ -25,6 +25,7 @@ private: size_t matchTerms(const FieldRef & f, size_t shortestTerm) override; size_t match_regexp(const FieldRef & f, search::streaming::QueryTerm & qt); + size_t match_fuzzy(const FieldRef & f, search::streaming::QueryTerm & qt); public: std::unique_ptr<FieldSearcher> duplicate() const override; diff --git a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp index 9c8bb2f185a..63d2007cecf 100644 --- a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp +++ b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp @@ -133,7 +133,7 @@ FieldSearchSpec::reconfig(const QueryTerm & term) (term.isSuffix() && _arg1 != "suffix") || (term.isExactstring() && _arg1 != "exact") || (term.isPrefix() && _arg1 == "suffix") || - term.isRegex()) + (term.isRegex() || term.isFuzzy())) { _searcher = std::make_unique<UTF8FlexibleStringFieldSearcher>(id()); propagate_settings_to_searcher(); diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 5d88b2d2829..df75a6f6d1f 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -720,6 +720,20 @@ ], "fields" : [ ] }, + "com.yahoo.tensor.DirectIndexedAddress" : { + "superClass" : "java.lang.Object", + "interfaces" : [ ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public void setIndex(int, int)", + "public long getDirectIndex()", + "public long getStride(int)" + ], + "fields" : [ ] + }, "com.yahoo.tensor.IndexedDoubleTensor$BoundDoubleBuilder" : { "superClass" : "com.yahoo.tensor.IndexedTensor$BoundBuilder", "interfaces" : [ ], @@ -894,8 +908,11 @@ "public java.util.Iterator subspaceIterator(java.util.Set, com.yahoo.tensor.DimensionSizes)", "public java.util.Iterator subspaceIterator(java.util.Set)", "public varargs double get(long[])", + "public double get(com.yahoo.tensor.DirectIndexedAddress)", + "public com.yahoo.tensor.DirectIndexedAddress directAddress()", "public varargs float getFloat(long[])", "public double get(com.yahoo.tensor.TensorAddress)", + "public java.lang.Double getAsDouble(com.yahoo.tensor.TensorAddress)", "public boolean has(com.yahoo.tensor.TensorAddress)", "public abstract double get(long)", "public abstract float getFloat(long)", @@ -952,6 +969,7 @@ "public int sizeAsInt()", "public double get(com.yahoo.tensor.TensorAddress)", "public boolean has(com.yahoo.tensor.TensorAddress)", + "public java.lang.Double getAsDouble(com.yahoo.tensor.TensorAddress)", "public java.util.Iterator cellIterator()", "public java.util.Iterator valueIterator()", "public java.util.Map cells()", @@ -1032,6 +1050,7 @@ "public com.yahoo.tensor.TensorType type()", "public long size()", "public double get(com.yahoo.tensor.TensorAddress)", + "public java.lang.Double getAsDouble(com.yahoo.tensor.TensorAddress)", "public boolean has(com.yahoo.tensor.TensorAddress)", "public java.util.Iterator cellIterator()", "public java.util.Iterator valueIterator()", @@ -1158,6 +1177,7 @@ "public int sizeAsInt()", "public abstract double get(com.yahoo.tensor.TensorAddress)", "public abstract boolean has(com.yahoo.tensor.TensorAddress)", + "public abstract java.lang.Double getAsDouble(com.yahoo.tensor.TensorAddress)", "public abstract java.util.Iterator cellIterator()", "public abstract java.util.Iterator valueIterator()", "public abstract java.util.Map cells()", @@ -1454,6 +1474,9 @@ ], "methods" : [ "public void <init>(com.yahoo.tensor.TensorType$Value, java.util.Collection)", + "public boolean hasIndexedDimensions()", + "public boolean hasMappedDimensions()", + "public boolean hasOnlyIndexedBoundDimensions()", "public static varargs com.yahoo.tensor.TensorType$Value combinedValueType(com.yahoo.tensor.TensorType[])", "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", "public com.yahoo.tensor.TensorType$Value valueType()", @@ -1464,6 +1487,7 @@ "public java.util.Set dimensionNames()", "public java.util.Optional dimension(java.lang.String)", "public java.util.Optional indexOfDimension(java.lang.String)", + "public int indexOfDimensionAsInt(java.lang.String)", "public java.util.Optional sizeOfDimension(java.lang.String)", "public boolean isAssignableTo(com.yahoo.tensor.TensorType)", "public boolean isConvertibleTo(com.yahoo.tensor.TensorType)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java new file mode 100644 index 00000000000..37752361876 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java @@ -0,0 +1,38 @@ +package com.yahoo.tensor; + +/** + * Utility class for efficient access and iteration along dimensions in Indexed tensors. + * Usage: Use setIndex to lock the indexes of the dimensions that don't change in this iteration. + * long base = addr.getDirectIndex(); + * long stride = addr.getStride(dimension) + * i = 0...size_of_dimension + * double value = tensor.get(base + i * stride); + */ +public final class DirectIndexedAddress { + private final DimensionSizes sizes; + private final int [] indexes; + private long directIndex; + private DirectIndexedAddress(DimensionSizes sizes) { + this.sizes = sizes; + indexes = new int[sizes.dimensions()]; + directIndex = 0; + } + static DirectIndexedAddress of(DimensionSizes sizes) { + return new DirectIndexedAddress(sizes); + } + /** Sets the current index of a dimension */ + public void setIndex(int dimension, int index) { + if (index < 0 || index >= sizes.size(dimension)) { + throw new IndexOutOfBoundsException("Index " + index + " outside of [0," + sizes.size(dimension) + ">"); + } + int diff = index - indexes[dimension]; + directIndex += getStride(dimension) * diff; + indexes[dimension] = index; + } + /** Retrieve the index that can be used for direct lookup in an indexed tensor. */ + public long getDirectIndex() { return directIndex; } + /** returns the stride to be used for the given dimension */ + public long getStride(int dimension) { + return sizes.productOfDimensionsAfter(dimension); + } +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 1319675f5d4..f26174d9576 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -93,6 +93,10 @@ public abstract class IndexedTensor implements Tensor { return get(toValueIndex(indexes, dimensionSizes)); } + public double get(DirectIndexedAddress address) { + return get(address.getDirectIndex()); + } + public DirectIndexedAddress directAddress() { return DirectIndexedAddress.of(dimensionSizes); } /** * Returns the value at the given indexes as a float * @@ -116,6 +120,17 @@ public abstract class IndexedTensor implements Tensor { } @Override + public Double getAsDouble(TensorAddress address) { + try { + long index = toValueIndex(address, dimensionSizes, type); + if (index < 0 || size() <= index) return null; + return get(index); + } catch (IllegalArgumentException e) { + return null; + } + } + + @Override public boolean has(TensorAddress address) { try { long index = toValueIndex(address, dimensionSizes, type); @@ -160,9 +175,10 @@ public abstract class IndexedTensor implements Tensor { long valueIndex = 0; for (int i = 0; i < address.size(); i++) { - if (address.numericLabel(i) >= sizes.size(i)) + long label = address.numericLabel(i); + if (label >= sizes.size(i)) throw new IllegalArgumentException(address + " is not within the bounds of " + type); - valueIndex += sizes.productOfDimensionsAfter(i) * address.numericLabel(i); + valueIndex += sizes.productOfDimensionsAfter(i) * label; } return valueIndex; } @@ -277,7 +293,7 @@ public abstract class IndexedTensor implements Tensor { } public static Builder of(TensorType type) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type)); else return new UnboundBuilder(type); @@ -291,7 +307,7 @@ public abstract class IndexedTensor implements Tensor { * must not be further mutated by the caller */ public static Builder of(TensorType type, float[] values) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type), values); else return new UnboundBuilder(type); @@ -305,7 +321,7 @@ public abstract class IndexedTensor implements Tensor { * must not be further mutated by the caller */ public static Builder of(TensorType type, double[] values) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type), values); else return new UnboundBuilder(type); @@ -611,11 +627,11 @@ public abstract class IndexedTensor implements Tensor { private final class ValueIterator implements Iterator<Double> { - private long count = 0; + private int count = 0; @Override public boolean hasNext() { - return count < size(); + return count < sizeAsInt(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index 5471ea65b97..3e0df5f2261 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -41,6 +41,9 @@ public class MappedTensor implements Tensor { public boolean has(TensorAddress address) { return cells.containsKey(address); } @Override + public Double getAsDouble(TensorAddress address) { return cells.get(address); } + + @Override public Iterator<Cell> cellIterator() { return new CellIteratorAdaptor(cells.entrySet().iterator()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 74b338fb503..95d1d70118a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -2,13 +2,15 @@ package com.yahoo.tensor; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.yahoo.tensor.impl.NumericTensorAddress; import com.yahoo.tensor.impl.StringTensorAddress; import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -29,7 +31,6 @@ public class MixedTensor implements Tensor { /** The dimension specification for this tensor */ private final TensorType type; - private final int denseSubspaceSize; // XXX consider using "record" instead /** only exposed for internal use; subject to change without notice */ @@ -51,45 +52,15 @@ public class MixedTensor implements Tensor { } } - /** The cells in the tensor */ - private final List<DenseSubspace> denseSubspaces; - /** only exposed for internal use; subject to change without notice */ - public List<DenseSubspace> getInternalDenseSubspaces() { return denseSubspaces; } + public List<DenseSubspace> getInternalDenseSubspaces() { return index.denseSubspaces; } /** An index structure over the cell list */ private final Index index; - private MixedTensor(TensorType type, List<DenseSubspace> denseSubspaces, Index index) { + private MixedTensor(TensorType type, Index index) { this.type = type; - this.denseSubspaceSize = index.denseSubspaceSize(); - this.denseSubspaces = List.copyOf(denseSubspaces); this.index = index; - if (this.denseSubspaceSize < 1) { - throw new IllegalStateException("invalid dense subspace size: " + denseSubspaceSize); - } - long count = 0; - for (var block : this.denseSubspaces) { - if (index.sparseMap.get(block.sparseAddress) != count) { - throw new IllegalStateException("map vs list mismatch: block #" - + count - + " address maps to #" - + index.sparseMap.get(block.sparseAddress)); - } - if (block.cells.length != denseSubspaceSize) { - throw new IllegalStateException("dense subspace size mismatch, expected " - + denseSubspaceSize - + " cells, but got: " - + block.cells.length); - } - ++count; - } - if (count != index.sparseMap.size()) { - throw new IllegalStateException("mismatch: list size is " - + count - + " but map size is " - + index.sparseMap.size()); - } } /** Returns the tensor type */ @@ -98,32 +69,34 @@ public class MixedTensor implements Tensor { /** Returns the size of the tensor measured in number of cells */ @Override - public long size() { return denseSubspaces.size() * denseSubspaceSize; } + public long size() { return index.denseSubspaces.size() * index.denseSubspaceSize; } /** Returns the value at the given address */ @Override public double get(TensorAddress address) { - int blockNum = index.blockIndexOf(address); - if (blockNum < 0 || blockNum > denseSubspaces.size()) { + var block = index.blockOf(address); + int denseOffset = index.denseOffsetOf(address); + if (block == null || denseOffset < 0 || denseOffset >= block.cells.length) { return 0.0; } + return block.cells[denseOffset]; + } + + @Override + public Double getAsDouble(TensorAddress address) { + var block = index.blockOf(address); int denseOffset = index.denseOffsetOf(address); - var block = denseSubspaces.get(blockNum); - if (denseOffset < 0 || denseOffset >= block.cells.length) { - return 0.0; + if (block == null || denseOffset < 0 || denseOffset >= block.cells.length) { + return null; } return block.cells[denseOffset]; } @Override public boolean has(TensorAddress address) { - int blockNum = index.blockIndexOf(address); - if (blockNum < 0 || blockNum > denseSubspaces.size()) { - return false; - } + var block = index.blockOf(address); int denseOffset = index.denseOffsetOf(address); - var block = denseSubspaces.get(blockNum); - return (denseOffset >= 0 && denseOffset < block.cells.length); + return (block != null && denseOffset >= 0 && denseOffset < block.cells.length); } /** @@ -136,20 +109,26 @@ public class MixedTensor implements Tensor { @Override public Iterator<Cell> cellIterator() { return new Iterator<>() { - final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator(); + final Iterator<DenseSubspace> blockIterator = index.denseSubspaces.iterator(); DenseSubspace currBlock = null; - int currOffset = denseSubspaceSize; + final long[] labels = new long[index.indexedDimensions.size()]; + int currOffset = index.denseSubspaceSize; + int prevOffset = -1; @Override public boolean hasNext() { - return (currOffset < denseSubspaceSize || blockIterator.hasNext()); + return (currOffset < index.denseSubspaceSize || blockIterator.hasNext()); } @Override public Cell next() { - if (currOffset == denseSubspaceSize) { + if (currOffset == index.denseSubspaceSize) { currBlock = blockIterator.next(); currOffset = 0; } - TensorAddress fullAddr = index.fullAddressOf(currBlock.sparseAddress, currOffset); + if (currOffset != prevOffset) { // Optimization for index.denseSubspaceSize == 1 + index.denseOffsetToAddress(currOffset, labels); + } + TensorAddress fullAddr = index.fullAddressOf(currBlock.sparseAddress, labels); + prevOffset = currOffset; double value = currBlock.cells[currOffset++]; return new Cell(fullAddr, value); } @@ -163,16 +142,16 @@ public class MixedTensor implements Tensor { @Override public Iterator<Double> valueIterator() { return new Iterator<>() { - final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator(); + final Iterator<DenseSubspace> blockIterator = index.denseSubspaces.iterator(); double[] currBlock = null; - int currOffset = denseSubspaceSize; + int currOffset = index.denseSubspaceSize; @Override public boolean hasNext() { - return (currOffset < denseSubspaceSize || blockIterator.hasNext()); + return (currOffset < index.denseSubspaceSize || blockIterator.hasNext()); } @Override public Double next() { - if (currOffset == denseSubspaceSize) { + if (currOffset == index.denseSubspaceSize) { currBlock = blockIterator.next().cells; currOffset = 0; } @@ -198,24 +177,22 @@ public class MixedTensor implements Tensor { throw new IllegalArgumentException("MixedTensor.withType: types are not compatible. Current type: '" + this.type + "', requested type: '" + type + "'"); } - return new MixedTensor(other, denseSubspaces, index); + return new MixedTensor(other, index); } @Override public Tensor remove(Set<TensorAddress> addresses) { var indexBuilder = new Index.Builder(type); - List<DenseSubspace> list = new ArrayList<>(); - for (var block : denseSubspaces) { + for (var block : index.denseSubspaces) { if ( ! addresses.contains(block.sparseAddress)) { // assumption: addresses only contain the sparse part - indexBuilder.addBlock(block.sparseAddress, list.size()); - list.add(block); + indexBuilder.addBlock(block); } } - return new MixedTensor(type, list, indexBuilder.build()); + return new MixedTensor(type, indexBuilder.build()); } @Override - public int hashCode() { return Objects.hash(type, denseSubspaces); } + public int hashCode() { return Objects.hash(type, index.denseSubspaces); } @Override public String toString() { @@ -250,13 +227,14 @@ public class MixedTensor implements Tensor { /** Returns the size of dense subspaces */ public long denseSubspaceSize() { - return denseSubspaceSize; + return index.denseSubspaceSize; } /** * Base class for building mixed tensors. */ public abstract static class Builder implements Tensor.Builder { + static final int INITIAL_HASH_CAPACITY = 1000; final TensorType type; @@ -266,10 +244,11 @@ public class MixedTensor implements Tensor { * a temporary structure while finding dimension bounds. */ public static Builder of(TensorType type) { - if (type.dimensions().stream().anyMatch(d -> d instanceof TensorType.IndexedUnboundDimension)) { - return new UnboundBuilder(type); + //TODO Wire in expected map size to avoid expensive resize + if (type.hasIndexedUnboundDimensions()) { + return new UnboundBuilder(type, INITIAL_HASH_CAPACITY); } else { - return new BoundBuilder(type); + return new BoundBuilder(type, INITIAL_HASH_CAPACITY); } } @@ -307,13 +286,14 @@ public class MixedTensor implements Tensor { public static class BoundBuilder extends Builder { /** For each sparse partial address, hold a dense subspace */ - private final Map<TensorAddress, double[]> denseSubspaceMap = new HashMap<>(); + private final Map<TensorAddress, double[]> denseSubspaceMap; private final Index.Builder indexBuilder; private final Index index; private final TensorType denseSubtype; - private BoundBuilder(TensorType type) { + private BoundBuilder(TensorType type, int expectedSize) { super(type); + denseSubspaceMap = new LinkedHashMap<>(expectedSize, 0.5f); indexBuilder = new Index.Builder(type); index = indexBuilder.index(); denseSubtype = new TensorType(type.valueType(), @@ -325,10 +305,7 @@ public class MixedTensor implements Tensor { } private double[] denseSubspace(TensorAddress sparseAddress) { - if (!denseSubspaceMap.containsKey(sparseAddress)) { - denseSubspaceMap.put(sparseAddress, new double[(int)denseSubspaceSize()]); - } - return denseSubspaceMap.get(sparseAddress); + return denseSubspaceMap.computeIfAbsent(sparseAddress, (key) -> new double[(int)denseSubspaceSize()]); } public IndexedTensor.DirectIndexBuilder denseSubspaceBuilder(TensorAddress sparseAddress) { @@ -363,19 +340,20 @@ public class MixedTensor implements Tensor { @Override public MixedTensor build() { - List<DenseSubspace> list = new ArrayList<>(); - for (Map.Entry<TensorAddress, double[]> entry : denseSubspaceMap.entrySet()) { + //TODO This can be solved more efficiently with a single map. + Set<Map.Entry<TensorAddress, double[]>> entrySet = denseSubspaceMap.entrySet(); + for (Map.Entry<TensorAddress, double[]> entry : entrySet) { TensorAddress sparsePart = entry.getKey(); double[] denseSubspace = entry.getValue(); var block = new DenseSubspace(sparsePart, denseSubspace); - indexBuilder.addBlock(sparsePart, list.size()); - list.add(block); + indexBuilder.addBlock(block); } - return new MixedTensor(type, list, indexBuilder.build()); + return new MixedTensor(type, indexBuilder.build()); } public static BoundBuilder of(TensorType type) { - return new BoundBuilder(type); + //TODO Wire in expected map size to avoid expensive resize + return new BoundBuilder(type, INITIAL_HASH_CAPACITY); } } @@ -392,9 +370,9 @@ public class MixedTensor implements Tensor { private final Map<TensorAddress, Double> cells; private final long[] dimensionBounds; - private UnboundBuilder(TensorType type) { + private UnboundBuilder(TensorType type, int expectedSize) { super(type); - cells = new HashMap<>(); + cells = new LinkedHashMap<>(expectedSize, 0.5f); dimensionBounds = new long[type.dimensions().size()]; } @@ -413,7 +391,7 @@ public class MixedTensor implements Tensor { @Override public MixedTensor build() { TensorType boundType = createBoundType(); - BoundBuilder builder = new BoundBuilder(boundType); + BoundBuilder builder = new BoundBuilder(boundType, cells.size()); for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) { builder.cell(cell.getKey(), cell.getValue()); } @@ -444,7 +422,8 @@ public class MixedTensor implements Tensor { } public static UnboundBuilder of(TensorType type) { - return new UnboundBuilder(type); + //TODO Wire in expected map size to avoid expensive resize + return new UnboundBuilder(type, INITIAL_HASH_CAPACITY); } } @@ -461,8 +440,10 @@ public class MixedTensor implements Tensor { private final TensorType denseType; private final List<TensorType.Dimension> mappedDimensions; private final List<TensorType.Dimension> indexedDimensions; + private final int [] indexedDimensionsSize; private ImmutableMap<TensorAddress, Integer> sparseMap; + private List<DenseSubspace> denseSubspaces; private final int denseSubspaceSize; static private int computeDSS(List<TensorType.Dimension> dimensions) { @@ -478,17 +459,31 @@ public class MixedTensor implements Tensor { this.type = type; this.mappedDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).toList(); this.indexedDimensions = type.dimensions().stream().filter(TensorType.Dimension::isIndexed).toList(); + this.indexedDimensionsSize = new int[indexedDimensions.size()]; + for (int i = 0; i < indexedDimensions.size(); i++) { + long dimensionSize = indexedDimensions.get(i).size().orElseThrow(() -> + new IllegalArgumentException("Unknown size of indexed dimension.")); + indexedDimensionsSize[i] = (int)dimensionSize; + } + this.sparseType = createPartialType(type.valueType(), mappedDimensions); this.denseType = createPartialType(type.valueType(), indexedDimensions); this.denseSubspaceSize = computeDSS(this.indexedDimensions); + if (this.denseSubspaceSize < 1) { + throw new IllegalStateException("invalid dense subspace size: " + denseSubspaceSize); + } } - int blockIndexOf(TensorAddress address) { + private DenseSubspace blockOf(TensorAddress address) { TensorAddress sparsePart = sparsePartialAddress(address); - return sparseMap.getOrDefault(sparsePart, -1); + Integer blockNum = sparseMap.get(sparsePart); + if (blockNum == null || blockNum >= denseSubspaces.size()) { + return null; + } + return denseSubspaces.get(blockNum); } - int denseOffsetOf(TensorAddress address) { + private int denseOffsetOf(TensorAddress address) { long innerSize = 1; long offset = 0; for (int i = type.dimensions().size(); --i >= 0; ) { @@ -519,38 +514,32 @@ public class MixedTensor implements Tensor { return builder.build(); } - private TensorAddress denseOffsetToAddress(long denseOffset) { + private void denseOffsetToAddress(long denseOffset, long [] labels) { if (denseOffset < 0 || denseOffset > denseSubspaceSize) { throw new IllegalArgumentException("Offset out of bounds"); } long restSize = denseOffset; long innerSize = denseSubspaceSize; - long[] labels = new long[indexedDimensions.size()]; for (int i = 0; i < labels.length; ++i) { - TensorType.Dimension dimension = indexedDimensions.get(i); - long dimensionSize = dimension.size().orElseThrow(() -> - new IllegalArgumentException("Unknown size of indexed dimension.")); - - innerSize /= dimensionSize; + innerSize /= indexedDimensionsSize[i]; labels[i] = restSize / innerSize; restSize %= innerSize; } - return TensorAddress.of(labels); } - TensorAddress fullAddressOf(TensorAddress sparsePart, long denseOffset) { - TensorAddress densePart = denseOffsetToAddress(denseOffset); + private TensorAddress fullAddressOf(TensorAddress sparsePart, long [] densePart) { String[] labels = new String[type.dimensions().size()]; int mappedIndex = 0; int indexedIndex = 0; - for (TensorType.Dimension d : type.dimensions()) { + for (int i = 0; i < type.dimensions().size(); i++) { + TensorType.Dimension d = type.dimensions().get(i); if (d.isIndexed()) { - labels[mappedIndex + indexedIndex] = densePart.label(indexedIndex); + labels[i] = NumericTensorAddress.asString(densePart[indexedIndex]); indexedIndex++; } else { - labels[mappedIndex + indexedIndex] = sparsePart.label(mappedIndex); + labels[i] = sparsePart.label(mappedIndex); mappedIndex++; } } @@ -606,8 +595,7 @@ public class MixedTensor implements Tensor { b.append(", "); // start brackets - for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) - b.append("["); + b.append("[".repeat(Math.max(0, indexes.nextDimensionsAtStart()))); // value switch (type.valueType()) { @@ -620,32 +608,38 @@ public class MixedTensor implements Tensor { } // end bracket - for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) - b.append("]"); + b.append("]".repeat(Math.max(0, indexes.nextDimensionsAtEnd()))); } return index; } private double getDouble(int subspaceIndex, int denseOffset, MixedTensor tensor) { - return tensor.denseSubspaces.get(subspaceIndex).cells[denseOffset]; + return tensor.index.denseSubspaces.get(subspaceIndex).cells[denseOffset]; } - static class Builder { + private static class Builder { private final Index index; - private final ImmutableMap.Builder<TensorAddress, Integer> builder; + private final ImmutableMap.Builder<TensorAddress, Integer> builder = new ImmutableMap.Builder<>(); + private final ImmutableList.Builder<DenseSubspace> listBuilder = new ImmutableList.Builder<>(); + private int count = 0; Builder(TensorType type) { index = new Index(type); - builder = new ImmutableMap.Builder<>(); } - void addBlock(TensorAddress address, int sz) { - builder.put(address, sz); + void addBlock(DenseSubspace block) { + if (block.cells.length != index.denseSubspaceSize) { + throw new IllegalStateException("dense subspace size mismatch, expected " + index.denseSubspaceSize + + " cells, but got: " + block.cells.length); + } + builder.put(block.sparseAddress, count++); + listBuilder.add(block); } Index build() { index.sparseMap = builder.build(); + index.denseSubspaces = listBuilder.build(); return index; } @@ -655,27 +649,16 @@ public class MixedTensor implements Tensor { } } - private static class DenseSubspaceBuilder implements IndexedTensor.DirectIndexBuilder { - - private final TensorType type; - private final double[] values; - - public DenseSubspaceBuilder(TensorType type, double[] values) { - this.type = type; - this.values = values; - } - - @Override - public TensorType type() { return type; } + private record DenseSubspaceBuilder(TensorType type, double[] values) implements IndexedTensor.DirectIndexBuilder { @Override public void cellByDirectIndex(long index, double value) { - values[(int)index] = value; + values[(int) index] = value; } @Override public void cellByDirectIndex(long index, float value) { - values[(int)index] = value; + values[(int) index] = value; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index cc8e1602adb..d034ac551f8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -90,6 +90,8 @@ public interface Tensor { /** Returns true if this cell exists */ boolean has(TensorAddress address); + /** null = no value present. More efficient that if (t.has(key)) t.get(key) */ + Double getAsDouble(TensorAddress address); /** * Returns the cell of this in some undefined order. @@ -113,7 +115,7 @@ public interface Tensor { * @throws IllegalStateException if this does not have zero dimensions and one value */ default double asDouble() { - if (type().dimensions().size() > 0) + if (!type().dimensions().isEmpty()) throw new IllegalStateException("Require a dimensionless tensor but has " + type()); if (size() == 0) return Double.NaN; return valueIterator().next(); @@ -553,8 +555,8 @@ public interface Tensor { /** Creates a suitable builder for the given type */ static Builder of(TensorType type) { - boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed()); + boolean containsIndexed = type.hasIndexedDimensions(); + boolean containsMapped = type.hasMappedDimensions(); if (containsIndexed && containsMapped) return MixedTensor.Builder.of(type); if (containsMapped) @@ -565,8 +567,8 @@ public interface Tensor { /** Creates a suitable builder for the given type */ static Builder of(TensorType type, DimensionSizes dimensionSizes) { - boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed()); + boolean containsIndexed = type.hasIndexedDimensions(); + boolean containsMapped = type.hasMappedDimensions(); if (containsIndexed && containsMapped) return MixedTensor.Builder.of(type); if (containsMapped) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index f841b7757fb..1b88a5d1b2f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -3,10 +3,12 @@ package com.yahoo.tensor; import com.yahoo.tensor.impl.NumericTensorAddress; import com.yahoo.tensor.impl.StringTensorAddress; +import net.jpountz.xxhash.XXHash32; +import net.jpountz.xxhash.XXHashFactory; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Objects; -import java.util.Optional; /** * An immutable address to a tensor cell. This simply supplies a value to each dimension @@ -16,6 +18,8 @@ import java.util.Optional; */ public abstract class TensorAddress implements Comparable<TensorAddress> { + private static final XXHash32 hasher = XXHashFactory.fastestJavaInstance().hash32(); + public static TensorAddress of(String[] labels) { return StringTensorAddress.of(labels); } @@ -28,6 +32,8 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return NumericTensorAddress.of(labels); } + private int cached_hash = 0; + /** Returns the number of labels in this */ public abstract int size(); @@ -62,12 +68,17 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { @Override public int hashCode() { - int result = 1; + if (cached_hash != 0) return cached_hash; + + int hash = 0; for (int i = 0; i < size(); i++) { - if (label(i) != null) - result = 31 * result + label(i).hashCode(); + String label = label(i); + if (label != null) { + byte [] buf = label.getBytes(StandardCharsets.UTF_8); + hash = hasher.hash(buf, 0, buf.length, hash); + } } - return result; + return cached_hash = hash; } @Override @@ -138,10 +149,10 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { public Builder add(String dimension, String label) { Objects.requireNonNull(dimension, "dimension cannot be null"); Objects.requireNonNull(label, "label cannot be null"); - Optional<Integer> labelIndex = type.indexOfDimension(dimension); - if ( labelIndex.isEmpty()) + int labelIndex = type.indexOfDimensionAsInt(dimension); + if ( labelIndex < 0) throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'"); - labels[labelIndex.get()] = label; + labels[labelIndex] = label; return this; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index b30b664a5f7..dcfee88d599 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -1,6 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import com.google.common.collect.ImmutableSet; import com.yahoo.text.Ascii7BitMatcher; import java.util.ArrayList; @@ -86,16 +87,20 @@ public class TensorType { /** Sorted list of the dimensions of this */ private final List<Dimension> dimensions; + private final Set<String> dimensionNames; private final TensorType mappedSubtype; private final TensorType indexedSubtype; + private final int indexedUnBoundCount; // only used to initialize the "empty" instance private TensorType() { this.valueType = Value.DOUBLE; this.dimensions = List.of(); + this.dimensionNames = Set.of(); this.mappedSubtype = this; this.indexedSubtype = this; + indexedUnBoundCount = 0; } public TensorType(Value valueType, Collection<Dimension> dimensions) { @@ -103,12 +108,25 @@ public class TensorType { List<Dimension> dimensionList = new ArrayList<>(dimensions); Collections.sort(dimensionList); this.dimensions = List.copyOf(dimensionList); + ImmutableSet.Builder<String> namesbuilder = new ImmutableSet.Builder<>(); + int indexedBoundCount = 0, indexedUnBoundCount = 0, mappedCount = 0; + for (Dimension dimension : dimensionList) { + namesbuilder.add(dimension.name()); + Dimension.Type type = dimension.type(); + switch (type) { + case indexedUnbound -> indexedUnBoundCount++; + case indexedBound -> indexedBoundCount++; + case mapped -> mappedCount++; + } + } + this.indexedUnBoundCount = indexedUnBoundCount; + dimensionNames = namesbuilder.build(); - if (dimensionList.stream().allMatch(Dimension::isIndexed)) { + if (mappedCount == 0) { mappedSubtype = empty; indexedSubtype = this; } - else if (dimensionList.stream().noneMatch(Dimension::isIndexed)) { + else if ((indexedBoundCount + indexedUnBoundCount) == 0) { mappedSubtype = this; indexedSubtype = empty; } @@ -118,6 +136,11 @@ public class TensorType { } } + public boolean hasIndexedDimensions() { return indexedSubtype != empty; } + public boolean hasMappedDimensions() { return mappedSubtype != empty; } + public boolean hasOnlyIndexedBoundDimensions() { return !hasMappedDimensions() && ! hasIndexedUnboundDimensions(); } + boolean hasIndexedUnboundDimensions() { return indexedUnBoundCount > 0; } + static public Value combinedValueType(TensorType ... types) { List<Value> valueTypes = new ArrayList<>(); for (TensorType type : types) { @@ -161,7 +184,7 @@ public class TensorType { /** Returns an immutable set of the names of the dimensions of this */ public Set<String> dimensionNames() { - return dimensions.stream().map(Dimension::name).collect(Collectors.toSet()); + return dimensionNames; } /** Returns the dimension with this name, or empty if not present */ @@ -176,6 +199,13 @@ public class TensorType { return Optional.of(i); return Optional.empty(); } + /** Returns the 0-base index of this dimension, or empty if it is not present */ + public int indexOfDimensionAsInt(String dimension) { + for (int i = 0; i < dimensions.size(); i++) + if (dimensions.get(i).name().equals(dimension)) + return i; + return -1; + } /* Returns the bound of this dimension if it is present and bound in this, empty otherwise */ public Optional<Long> sizeOfDimension(String dimension) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 8d8fe2b356f..866b710b72e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -134,7 +134,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET return tensor; } else { // extend tensor with this dimension - if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) + if (tensor.type().hasMappedDimensions()) throw new IllegalArgumentException("Concat requires an indexed tensor, " + "but got a tensor with type " + tensor.type()); Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index 3b6e03186a3..b595b1a40cd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -40,7 +40,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens @Override public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { - if (arguments.size() != 0) + if (!arguments.isEmpty()) throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size()); return this; } @@ -79,7 +79,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() { var result = new ArrayList<TensorFunction<NAMETYPE>>(); for (var fun : cells.values()) { - fun.asTensorFunction().ifPresent(tf -> result.add(tf)); + fun.asTensorFunction().ifPresent(result::add); } return result; } @@ -133,7 +133,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens IndexedDynamicTensor(TensorType type, List<ScalarFunction<NAMETYPE>> cells) { super(type); - if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)) + if ( ! type.hasOnlyIndexedBoundDimensions()) throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " + "only indexed, bound dimensions, but this has " + type); this.cells = List.copyOf(cells); @@ -142,7 +142,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() { var result = new ArrayList<TensorFunction<NAMETYPE>>(); for (var fun : cells) { - fun.asTensorFunction().ifPresent(tf -> result.add(tf)); + fun.asTensorFunction().ifPresent(result::add); } return result; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 1ded16636d3..e0ac549651c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -114,7 +114,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) { - long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); + int joinedRank = (int)Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); Iterator<Double> aIterator = a.valueIterator(); Iterator<Double> bIterator = b.valueIterator(); IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build()); @@ -129,8 +129,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> aCell = i.next(); var key = aCell.getKey(); - if (b.has(key)) { - builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key))); + Double bVal = b.getAsDouble(key); + if (bVal != null) { + builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal)); } } return builder.build(); @@ -170,7 +171,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP Iterator<Tensor.Cell> superspace, long superspaceSize, boolean reversedArgumentOrder, IndexedTensor.Builder builder, DoubleBinaryOperator combinator) { - long joinedLength = Math.min(subspaceSize, superspaceSize); + int joinedLength = (int)Math.min(subspaceSize, superspaceSize); if (reversedArgumentOrder) { for (int i = 0; i < joinedLength; i++) { Tensor.Cell supercell = superspace.next(); @@ -206,11 +207,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> supercell = i.next(); TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes); - if (subspace.has(subaddress)) { - double subspaceValue = subspace.get(subaddress); + Double subspaceValue = subspace.getAsDouble(subaddress); + if (subspaceValue != null) { builder.cell(supercell.getKey(), - reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) - : combinator.applyAsDouble(subspaceValue, supercell.getValue())); + reversedArgumentOrder + ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) + : combinator.applyAsDouble(subspaceValue, supercell.getValue())); } } return builder.build(); @@ -252,6 +254,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder, DoubleBinaryOperator combinator) { Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()); + int sharedDimensionSize = sharedDimensions.size(); // Expensive to compute size after intersection Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()); DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize); @@ -263,7 +266,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP // for each combination of dimensions in a which is also in b while (aSubspace.hasNext()) { Tensor.Cell aCell = aSubspace.next(); - PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions); + PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize); // for each matching combination of dimensions ony in b for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) { Tensor.Cell bCell = bSubspace.next(); @@ -275,8 +278,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } } - private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) { - PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size()); + private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, + Set<String> retainDimensions, int sharedDimensionSize) { + PartialAddress.Builder builder = new PartialAddress.Builder(sharedDimensionSize); for (int i = 0; i < addressType.dimensions().size(); i++) if (retainDimensions.contains(addressType.dimensions().get(i).name())) builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i)); @@ -331,12 +335,11 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP int[] bIndexesInJoined = mapIndexes(b.type(), joinedType); // Iterate once through the smaller tensor and construct a hash map for common dimensions - Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>(); + Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>(a.sizeAsInt()); for (Iterator<Tensor.Cell> cellIterator = a.cellIterator(); cellIterator.hasNext(); ) { Tensor.Cell aCell = cellIterator.next(); TensorAddress partialCommonAddress = partialCommonAddress(aCell, aIndexesInCommon); - aCellsByCommonAddress.putIfAbsent(partialCommonAddress, new ArrayList<>()); - aCellsByCommonAddress.get(partialCommonAddress).add(aCell); + aCellsByCommonAddress.computeIfAbsent(partialCommonAddress, (key) -> new ArrayList<>()).add(aCell); } // Iterate once through the larger tensor and use the hash map to find joinable cells @@ -359,7 +362,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } /** - * Returns the an array having one entry in order for each dimension of fromType + * Returns an array having one entry in order for each dimension of fromType * containing the index at which toType contains the same dimension name. * That is, if the returned array contains n at index i then * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) @@ -368,7 +371,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP static int[] mapIndexes(TensorType fromType, TensorType toType) { int[] toIndexes = new int[fromType.dimensions().size()]; for (int i = 0; i < fromType.dimensions().size(); i++) - toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1); + toIndexes[i] = toType.indexOfDimensionAsInt(fromType.dimensions().get(i).name()); return toIndexes; } @@ -390,8 +393,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP private static boolean mapContent(TensorAddress from, String[] to, int[] indexMap) { for (int i = 0; i < from.size(); i++) { int toIndex = indexMap[i]; - if (to[toIndex] != null && ! to[toIndex].equals(from.label(i))) return false; - to[toIndex] = from.label(i); + String label = from.label(i); + if (to[toIndex] != null && ! to[toIndex].equals(label)) return false; + to[toIndex] = label; } return true; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java index 59394785382..ddad91dc060 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java @@ -121,10 +121,11 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> aCell = i.next(); var key = aCell.getKey(); - if (! b.has(key)) { + Double bVal = b.getAsDouble(key); + if (bVal == null) { builder.cell(key, aCell.getValue()); } else if (combinator != null) { - builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key))); + builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal)); } } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index fe20c41174a..77e82b818a7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.impl.Convert; import com.yahoo.tensor.impl.StringTensorAddress; import java.util.ArrayList; @@ -136,9 +137,7 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); TensorAddress reducedAddress = reduceDimensions(indexesToKeep, cell.getKey()); - ValueAggregator aggr = aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator)); - if (aggr == null) - aggr = aggregatingCells.get(reducedAddress); + ValueAggregator aggr = aggregatingCells.computeIfAbsent(reducedAddress, (key) ->ValueAggregator.ofType(aggregator)); aggr.aggregate(cell.getValue()); } Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType); @@ -172,14 +171,15 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); ) valueAggregator.aggregate(i.next()); - return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); + return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue()).build(); } private static Tensor reduceIndexedVector(IndexedTensor argument, Aggregator aggregator) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); - for (int i = 0; i < argument.dimensionSizes().size(0); i++) + int dimensionSize = Convert.safe2Int(argument.dimensionSizes().size(0)); + for (int i = 0; i < dimensionSize ; i++) valueAggregator.aggregate(argument.get(i)); - return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); + return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue()).build(); } static abstract class ValueAggregator { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index aece782d296..2d5a0518747 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -92,11 +92,11 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N return false; if ( ! (a instanceof IndexedTensor)) return false; - if ( ! (a.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))) + if ( ! (a.type().hasOnlyIndexedBoundDimensions())) return false; if ( ! (b instanceof IndexedTensor)) return false; - if ( ! (b.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))) + if ( ! (b.type().hasOnlyIndexedBoundDimensions())) return false; TensorType commonDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java new file mode 100644 index 00000000000..e2cb64fdd1f --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java @@ -0,0 +1,16 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.impl; + +/** + * Utility to make common conversions safe + * + * @author baldersheim + */ +public class Convert { + public static int safe2Int(long value) { + if (value > Integer.MAX_VALUE || value < Integer.MIN_VALUE) { + throw new IndexOutOfBoundsException("value = " + value + ", which is too large to fit in an int"); + } + return (int) value; + } +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 444ce02b14a..771b74633d9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -21,10 +21,8 @@ import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Slice; import java.util.ArrayList; -import java.util.HashSet; import java.util.Iterator; import java.util.List; -import java.util.Set; /** * Writes tensors on the JSON format used in Vespa tensor document fields: @@ -60,8 +58,7 @@ public class JsonFormat { // Short form for a single mapped dimension Cursor parent = root == null ? slime.setObject() : root.setObject("cells"); encodeSingleDimensionCells((MappedTensor) tensor, parent); - } else if (tensor instanceof MixedTensor && - tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped)) { + } else if (tensor instanceof MixedTensor && tensor.type().hasMappedDimensions()) { // Short form for a mixed tensor boolean singleMapped = tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() == 1; Cursor parent = root == null ? ( singleMapped ? slime.setObject() : slime.setArray() ) @@ -204,7 +201,7 @@ public class JsonFormat { if (root.field("cells").valid() && ! primitiveContent(root.field("cells"))) decodeCells(root.field("cells"), builder); - else if (root.field("values").valid() && builder.type().dimensions().stream().allMatch(d -> d.isIndexed())) + else if (root.field("values").valid() && ! builder.type().hasMappedDimensions()) decodeValuesAtTop(root.field("values"), builder); else if (root.field("blocks").valid()) decodeBlocks(root.field("blocks"), builder); @@ -298,14 +295,14 @@ public class JsonFormat { /** Decodes a tensor value directly at the root, where the format is decided by the tensor type. */ private static void decodeDirectValue(Inspector root, Tensor.Builder builder) { - boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped); + boolean hasIndexed = builder.type().hasIndexedDimensions(); + boolean hasMapped = builder.type().hasMappedDimensions(); if (isArrayOfObjects(root)) decodeCells(root, builder); else if ( ! hasMapped) decodeValuesAtTop(root, builder); - else if (hasMapped && hasIndexed) + else if (hasIndexed) decodeBlocks(root, builder); else decodeCells(root, builder); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index d4b18c73f11..0a5c713f3e2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -55,8 +55,8 @@ public class TypedBinaryFormat { } private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) { - boolean hasMappedDimensions = tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped); - boolean hasIndexedDimensions = tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); + boolean hasMappedDimensions = tensor.type().hasMappedDimensions(); + boolean hasIndexedDimensions = tensor.type().hasIndexedDimensions(); boolean isMixed = hasMappedDimensions && hasIndexedDimensions; // TODO: Encoding as indexed if the implementation is mixed is not yet supported so use mixed format instead diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java index 0a6c821e64e..afc95d295f0 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java @@ -9,6 +9,7 @@ import java.util.Iterator; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -96,6 +97,38 @@ public class IndexedTensorTestCase { } @Test + public void testDirectIndexedAddress() { + TensorType type = new TensorType.Builder().indexed("v", 3) + .indexed("w", wSize) + .indexed("x", xSize) + .indexed("y", ySize) + .indexed("z", zSize) + .build(); + var directAddress = DirectIndexedAddress.of(DimensionSizes.of(type)); + assertThrows(ArrayIndexOutOfBoundsException.class, () -> directAddress.getStride(5)); + assertThrows(IndexOutOfBoundsException.class, () -> directAddress.setIndex(4, 7)); + assertEquals(wSize*xSize*ySize*zSize, directAddress.getStride(0)); + assertEquals(xSize*ySize*zSize, directAddress.getStride(1)); + assertEquals(ySize*zSize, directAddress.getStride(2)); + assertEquals(zSize, directAddress.getStride(3)); + assertEquals(1, directAddress.getStride(4)); + assertEquals(0, directAddress.getDirectIndex()); + directAddress.setIndex(0,1); + assertEquals(1 * directAddress.getStride(0), directAddress.getDirectIndex()); + directAddress.setIndex(1,1); + assertEquals(1 * directAddress.getStride(0) + 1 * directAddress.getStride(1), directAddress.getDirectIndex()); + directAddress.setIndex(2,2); + directAddress.setIndex(3,2); + directAddress.setIndex(4,2); + long expected = 1 * directAddress.getStride(0) + + 1 * directAddress.getStride(1) + + 2 * directAddress.getStride(2) + + 2 * directAddress.getStride(3) + + 2 * directAddress.getStride(4); + assertEquals(expected, directAddress.getDirectIndex()); + } + + @Test public void testUnboundBuilding() { TensorType type = new TensorType.Builder().indexed("w") .indexed("v") diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java index 5c4d5f1ffcf..74237a218fb 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java @@ -9,8 +9,10 @@ import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.TensorFunction; -import java.util.*; -import java.util.stream.Collectors; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + /** * Microbenchmark of tensor operations. @@ -91,7 +93,7 @@ public class TensorFunctionBenchmark { .value(random.nextDouble()); } } - return Collections.singletonList(builder.build()); + return List.of(builder.build()); } private static TensorType vectorType(TensorType.Builder builder, String name, TensorType.Dimension.Type type, int size) { @@ -107,45 +109,51 @@ public class TensorFunctionBenchmark { public static void main(String[] args) { double time = 0; - // ---------------- Mapped with extra space (sidesteps current special-case optimizations): - // 7.8 ms - time = new TensorFunctionBenchmark().benchmark(1000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); - System.out.printf("Mapped vectors, x space time per join: %1$8.3f ms\n", time); - // 7.7 ms - time = new TensorFunctionBenchmark().benchmark(1000, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); - System.out.printf("Mapped matrix, x space time per join: %1$8.3f ms\n", time); + // ---------------- Indexed unbound: + time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); + System.out.printf("Indexed unbound vectors, time per join: %1$8.3f ms\n", time); + time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); + System.out.printf("Indexed unbound matrix, time per join: %1$8.3f ms\n", time); + + // ---------------- Indexed bound: + time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); + System.out.printf("Indexed bound vectors, time per join: %1$8.3f ms\n", time); + + time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); + System.out.printf("Indexed bound matrix, time per join: %1$8.3f ms\n", time); // ---------------- Mapped: - // 2.1 ms time = new TensorFunctionBenchmark().benchmark(5000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false); System.out.printf("Mapped vectors, time per join: %1$8.3f ms\n", time); - // 7.0 ms + time = new TensorFunctionBenchmark().benchmark(1000, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false); System.out.printf("Mapped matrix, time per join: %1$8.3f ms\n", time); // ---------------- Indexed (unbound) with extra space (sidesteps current special-case optimizations): - // 14.5 ms time = new TensorFunctionBenchmark().benchmark(500, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true); System.out.printf("Indexed vectors, x space time per join: %1$8.3f ms\n", time); - // 8.9 ms time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true); System.out.printf("Indexed matrix, x space time per join: %1$8.3f ms\n", time); - // ---------------- Indexed unbound: - // 0.14 ms - time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); - System.out.printf("Indexed unbound vectors, time per join: %1$8.3f ms\n", time); - // 0.44 ms - time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); - System.out.printf("Indexed unbound matrix, time per join: %1$8.3f ms\n", time); + // ---------------- Mapped with extra space (sidesteps current special-case optimizations): + time = new TensorFunctionBenchmark().benchmark(1000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); + System.out.printf("Mapped vectors, x space time per join: %1$8.3f ms\n", time); - // ---------------- Indexed bound: - // 0.32 ms - time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); - System.out.printf("Indexed bound vectors, time per join: %1$8.3f ms\n", time); - // 0.44 ms - time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); - System.out.printf("Indexed bound matrix, time per join: %1$8.3f ms\n", time); + time = new TensorFunctionBenchmark().benchmark(1000, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); + System.out.printf("Mapped matrix, x space time per join: %1$8.3f ms\n", time); + + /* 2.4Ghz Intel Core i9, Macbook Pro 2019 + * Indexed unbound vectors, time per join: 0,067 ms + * Indexed unbound matrix, time per join: 0,107 ms + * Indexed bound vectors, time per join: 0,068 ms + * Indexed bound matrix, time per join: 0,105 ms + * Mapped vectors, time per join: 1,342 ms + * Mapped matrix, time per join: 3,448 ms + * Indexed vectors, x space time per join: 6,398 ms + * Indexed matrix, x space time per join: 3,220 ms + * Mapped vectors, x space time per join: 14,984 ms + * Mapped matrix, x space time per join: 19,873 ms + */ } } diff --git a/vespalib/src/vespa/fastlib/text/unicodeutil.cpp b/vespalib/src/vespa/fastlib/text/unicodeutil.cpp index e29b91d6522..bd4ff5d93a9 100644 --- a/vespalib/src/vespa/fastlib/text/unicodeutil.cpp +++ b/vespalib/src/vespa/fastlib/text/unicodeutil.cpp @@ -11,15 +11,15 @@ namespace { class Initialize { public: - Initialize() { Fast_UnicodeUtil::InitTables(); } + Initialize() noexcept { Fast_UnicodeUtil::InitTables(); } }; -Initialize _G_Initializer; +Initialize _g_initializer; } void -Fast_UnicodeUtil::InitTables() +Fast_UnicodeUtil::InitTables() noexcept { /** * Hack for Katakana accent marks (torgeir) @@ -29,8 +29,7 @@ Fast_UnicodeUtil::InitTables() } char * -Fast_UnicodeUtil::utf8ncopy(char *dst, const ucs4_t *src, - int maxdst, int maxsrc) +Fast_UnicodeUtil::utf8ncopy(char *dst, const ucs4_t *src, int maxdst, int maxsrc) noexcept { char * p = dst; char * edst = dst + maxdst; @@ -83,7 +82,7 @@ Fast_UnicodeUtil::utf8ncopy(char *dst, const ucs4_t *src, int -Fast_UnicodeUtil::utf8cmp(const char *s1, const ucs4_t *s2) +Fast_UnicodeUtil::utf8cmp(const char *s1, const ucs4_t *s2) noexcept { ucs4_t i1; ucs4_t i2; @@ -101,7 +100,7 @@ Fast_UnicodeUtil::utf8cmp(const char *s1, const ucs4_t *s2) } size_t -Fast_UnicodeUtil::ucs4strlen(const ucs4_t *str) +Fast_UnicodeUtil::ucs4strlen(const ucs4_t *str) noexcept { const ucs4_t *p = str; while (*p++ != 0) { @@ -111,7 +110,7 @@ Fast_UnicodeUtil::ucs4strlen(const ucs4_t *str) } ucs4_t * -Fast_UnicodeUtil::ucs4copy(ucs4_t *dst, const char *src) +Fast_UnicodeUtil::ucs4copy(ucs4_t *dst, const char *src) noexcept { ucs4_t i; ucs4_t *p; @@ -127,7 +126,7 @@ Fast_UnicodeUtil::ucs4copy(ucs4_t *dst, const char *src) } ucs4_t -Fast_UnicodeUtil::GetUTF8CharNonAscii(unsigned const char *&src) +Fast_UnicodeUtil::GetUTF8CharNonAscii(unsigned const char *&src) noexcept { ucs4_t retval; @@ -222,7 +221,7 @@ Fast_UnicodeUtil::GetUTF8CharNonAscii(unsigned const char *&src) } ucs4_t -Fast_UnicodeUtil::GetUTF8Char(unsigned const char *&src) +Fast_UnicodeUtil::GetUTF8Char(unsigned const char *&src) noexcept { return (*src >= 0x80) ? GetUTF8CharNonAscii(src) @@ -246,7 +245,7 @@ Fast_UnicodeUtil::GetUTF8Char(unsigned const char *&src) #define UTF8_STARTCHAR(c) (!((c) & 0x80) || ((c) & 0x40)) int Fast_UnicodeUtil::UTF8move(unsigned const char* start, size_t length, - unsigned const char*& pos, off_t offset) + unsigned const char*& pos, off_t offset) noexcept { int increment = offset > 0 ? 1 : -1; unsigned const char* p = pos; diff --git a/vespalib/src/vespa/fastlib/text/unicodeutil.h b/vespalib/src/vespa/fastlib/text/unicodeutil.h index 87c09826948..740cc9381b7 100644 --- a/vespalib/src/vespa/fastlib/text/unicodeutil.h +++ b/vespalib/src/vespa/fastlib/text/unicodeutil.h @@ -16,7 +16,7 @@ using ucs4_t = uint32_t; * Used to examine properties of unicode characters, and * provide fast conversion methods between often used encodings. */ -class Fast_UnicodeUtil { +class Fast_UnicodeUtil final { private: /** * Is true when the tables have been initialized. Is set by @@ -46,9 +46,8 @@ private: }; public: - virtual ~Fast_UnicodeUtil() { } /** Initialize the ISO 8859-1 static tables. */ - static void InitTables(); + static void InitTables() noexcept; /** Indicates an invalid UTF-8 character sequence. */ enum { _BadUTF8Char = 0xfffffffeu }; @@ -64,7 +63,7 @@ public: * one or more of the properties alphabetic, ideographic, * combining char, decimal digit char, private use, extender. */ - static bool IsWordChar(ucs4_t testchar) { + static bool IsWordChar(ucs4_t testchar) noexcept { return (testchar < 65536 && (_compCharProps[testchar >> 8][testchar & 255] & _wordcharProp) != 0); @@ -80,8 +79,8 @@ public: * @return The next UCS4 character, or _BadUTF8Char if the * next character is invalid. */ - static ucs4_t GetUTF8Char(const unsigned char *& src); - static ucs4_t GetUTF8Char(const char *& src) { + static ucs4_t GetUTF8Char(const unsigned char *& src) noexcept; + static ucs4_t GetUTF8Char(const char *& src) noexcept { const unsigned char *temp = reinterpret_cast<const unsigned char *>(src); ucs4_t res = GetUTF8Char(temp); src = reinterpret_cast<const char *>(temp); @@ -94,7 +93,7 @@ public: * @param i The UCS4 character. * @return Pointer to the next position in dst after the putted byte(s). */ - static char *utf8cput(char *dst, ucs4_t i) { + static char *utf8cput(char *dst, ucs4_t i) noexcept { if (i < 128) *dst++ = i; else if (i < 0x800) { @@ -132,14 +131,14 @@ public: * @param src The UTF-8 source buffer. * @return A pointer to the destination string. */ - static ucs4_t *ucs4copy(ucs4_t *dst, const char *src); + static ucs4_t *ucs4copy(ucs4_t *dst, const char *src) noexcept; /** * Get the length of the UTF-8 representation of an UCS4 character. * @param i The UCS4 character. * @return The number of bytes required for the UTF-8 representation. */ - static size_t utf8clen(ucs4_t i) { + static size_t utf8clen(ucs4_t i) noexcept { if (i < 128) return 1; else if (i < 0x800) @@ -159,7 +158,7 @@ public: * @param testchar The character to lowercase. * @return The lowercase of the input, if defined. Else the input character. */ - static ucs4_t ToLower(ucs4_t testchar) + static ucs4_t ToLower(ucs4_t testchar) noexcept { ucs4_t ret; if (testchar < 65536) { @@ -182,14 +181,14 @@ public: * @return Number of bytes moved, or -1 if out of range */ static int UTF8move(unsigned const char* start, size_t length, - unsigned const char*& pos, off_t offset); + unsigned const char*& pos, off_t offset) noexcept; /** * Find the number of characters in an UCS4 string. * @param str The UCS4 string. * @return The number of characters. */ - static size_t ucs4strlen(const ucs4_t *str); + static size_t ucs4strlen(const ucs4_t *str) noexcept; /** * Convert UCS4 to UTF-8, bounded by max lengths. @@ -199,7 +198,7 @@ public: * @param maxsrc The maximum number of characters to convert from src. * @return A pointer to the destination. */ - static char *utf8ncopy(char *dst, const ucs4_t *src, int maxdst, int maxsrc); + static char *utf8ncopy(char *dst, const ucs4_t *src, int maxdst, int maxsrc) noexcept; /** @@ -210,7 +209,7 @@ public: * if s1 is, respectively, less than, matching, or greater than s2. * NB Only used in local test */ - static int utf8cmp(const char *s1, const ucs4_t *s2); + static int utf8cmp(const char *s1, const ucs4_t *s2) noexcept; /** * Test for terminal punctuation. @@ -218,7 +217,7 @@ public: * @return true if testchar is a terminal punctuation character, * i.e. if it has the terminal punctuation char property. */ - static bool IsTerminalPunctuationChar(ucs4_t testchar) { + static bool IsTerminalPunctuationChar(ucs4_t testchar) noexcept { return (testchar < 65536 && (_compCharProps[testchar >> 8][testchar & 255] & _terminalPunctuationCharProp) != 0); @@ -233,10 +232,10 @@ public: * @return The next UCS4 character, or _BadUTF8Char if the * next character is invalid. */ - static ucs4_t GetUTF8CharNonAscii(unsigned const char *&src); + static ucs4_t GetUTF8CharNonAscii(unsigned const char *&src) noexcept; // this is really an alias of the above function - static ucs4_t GetUTF8CharNonAscii(const char *&src) { + static ucs4_t GetUTF8CharNonAscii(const char *&src) noexcept { unsigned const char *temp = reinterpret_cast<unsigned const char *>(src); ucs4_t res = GetUTF8CharNonAscii(temp); src = reinterpret_cast<const char *>(temp); |