summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--client/go/internal/cli/cmd/feed.go25
-rw-r--r--client/go/internal/vespa/document/dispatcher.go21
-rw-r--r--client/go/internal/vespa/document/dispatcher_test.go20
-rw-r--r--client/go/internal/vespa/document/http.go15
-rw-r--r--client/go/internal/vespa/document/http_test.go23
-rw-r--r--searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp82
-rw-r--r--searchlib/src/vespa/searchlib/tensor/CMakeLists.txt1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/angular_distance.cpp1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp55
-rw-r--r--searchlib/src/vespa/searchlib/tensor/bound_distance_function.h48
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp7
-rw-r--r--searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp70
-rw-r--r--searchlib/src/vespa/searchlib/tensor/euclidean_distance.h12
-rw-r--r--searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp57
-rw-r--r--searchlib/src/vespa/searchlib/tensor/temporary_vector_store.h32
15 files changed, 341 insertions, 128 deletions
diff --git a/client/go/internal/cli/cmd/feed.go b/client/go/internal/cli/cmd/feed.go
index 9d90233e62a..a6447ef8d2e 100644
--- a/client/go/internal/cli/cmd/feed.go
+++ b/client/go/internal/cli/cmd/feed.go
@@ -18,10 +18,11 @@ import (
func addFeedFlags(cmd *cobra.Command, options *feedOptions) {
cmd.PersistentFlags().IntVar(&options.connections, "connections", 8, "The number of connections to use")
cmd.PersistentFlags().StringVar(&options.compression, "compression", "auto", `Compression mode to use. Default is "auto" which compresses large documents. Must be "auto", "gzip" or "none"`)
+ cmd.PersistentFlags().IntVar(&options.timeoutSecs, "timeout", 0, "Invididual feed operation timeout in seconds. 0 to disable")
+ cmd.PersistentFlags().IntVar(&options.doomSecs, "max-failure-seconds", 0, "Exit if given number of seconds elapse without any successful operations. 0 to disable")
+ cmd.PersistentFlags().BoolVar(&options.verbose, "verbose", false, "Verbose mode. Print successful operations in addition to errors")
cmd.PersistentFlags().StringVar(&options.route, "route", "", "Target Vespa route for feed operations")
cmd.PersistentFlags().IntVar(&options.traceLevel, "trace", 0, "The trace level of network traffic. 0 to disable")
- cmd.PersistentFlags().IntVar(&options.timeoutSecs, "timeout", 0, "Feed operation timeout in seconds. 0 to disable")
- cmd.PersistentFlags().BoolVar(&options.verbose, "verbose", false, "Verbose mode. Print successful operations in addition to errors")
memprofile := "memprofile"
cpuprofile := "cpuprofile"
cmd.PersistentFlags().StringVar(&options.memprofile, memprofile, "", "Write a heap profile to given file")
@@ -38,28 +39,29 @@ type feedOptions struct {
verbose bool
traceLevel int
timeoutSecs int
- memprofile string
- cpuprofile string
+ doomSecs int
+
+ memprofile string
+ cpuprofile string
}
func newFeedCmd(cli *CLI) *cobra.Command {
var options feedOptions
cmd := &cobra.Command{
- Use: "feed FILE",
+ Use: "feed FILE [FILE]...",
Short: "Feed documents to a Vespa cluster",
Long: `Feed documents to a Vespa cluster.
-A high performance feeding client. This can be used to feed large amounts of
-documents to a Vespa cluster efficiently.
+This command can be used to feed large amounts of documents to a Vespa cluster
+efficiently.
The contents of FILE must be either a JSON array or JSON objects separated by
newline (JSONL).
If FILE is a single dash ('-'), documents will be read from standard input.
`,
- Example: `$ vespa feed documents.jsonl
-$ cat documents.jsonl | vespa feed -
-`,
+ Example: `$ vespa feed docs.jsonl moredocs.json
+$ cat docs.jsonl | vespa feed -`,
Args: cobra.MinimumNArgs(1),
DisableAutoGenTag: true,
SilenceUsage: true,
@@ -131,8 +133,7 @@ func feed(files []string, options feedOptions, cli *CLI) error {
NowFunc: cli.now,
}, clients)
throttler := document.NewThrottler(options.connections)
- // TODO(mpolden): Make doom duration configurable
- circuitBreaker := document.NewCircuitBreaker(10*time.Second, 0)
+ circuitBreaker := document.NewCircuitBreaker(10*time.Second, time.Duration(options.doomSecs)*time.Second)
dispatcher := document.NewDispatcher(client, throttler, circuitBreaker, cli.Stderr, options.verbose)
start := cli.now()
for _, name := range files {
diff --git a/client/go/internal/vespa/document/dispatcher.go b/client/go/internal/vespa/document/dispatcher.go
index 0f3d39d5a78..51e60e4e131 100644
--- a/client/go/internal/vespa/document/dispatcher.go
+++ b/client/go/internal/vespa/document/dispatcher.go
@@ -163,11 +163,15 @@ func (d *Dispatcher) enqueue(op documentOp) error {
}
d.mu.Unlock()
group.add(op, op.attempts > 0)
- d.enqueueWithSlot(group)
+ d.dispatch(op.document.Id, group)
return nil
}
-func (d *Dispatcher) enqueueWithSlot(group *documentGroup) {
+func (d *Dispatcher) dispatch(id Id, group *documentGroup) {
+ if !d.canDispatch() {
+ d.msgs <- fmt.Sprintf("refusing to dispatch document %s: too many errors", id)
+ return
+ }
d.acquireSlot()
d.workerWg.Add(1)
go func() {
@@ -177,6 +181,19 @@ func (d *Dispatcher) enqueueWithSlot(group *documentGroup) {
d.throttler.Sent()
}
+func (d *Dispatcher) canDispatch() bool {
+ switch d.circuitBreaker.State() {
+ case CircuitClosed:
+ return true
+ case CircuitHalfOpen:
+ time.Sleep(time.Second)
+ return true
+ case CircuitOpen:
+ return false
+ }
+ panic("invalid circuit state")
+}
+
func (d *Dispatcher) acquireSlot() {
for atomic.LoadInt64(&d.inflightCount) >= d.throttler.TargetInflight() {
time.Sleep(time.Millisecond)
diff --git a/client/go/internal/vespa/document/dispatcher_test.go b/client/go/internal/vespa/document/dispatcher_test.go
index c8f8e550ba4..2e2e9a5abbd 100644
--- a/client/go/internal/vespa/document/dispatcher_test.go
+++ b/client/go/internal/vespa/document/dispatcher_test.go
@@ -36,6 +36,12 @@ func (f *mockFeeder) Send(doc Document) Result {
return result
}
+type mockCircuitBreaker struct{ state CircuitState }
+
+func (c *mockCircuitBreaker) Success() {}
+func (c *mockCircuitBreaker) Error(err error) {}
+func (c *mockCircuitBreaker) State() CircuitState { return c.state }
+
func TestDispatcher(t *testing.T) {
feeder := &mockFeeder{}
clock := &manualClock{tick: time.Second}
@@ -131,6 +137,20 @@ func TestDispatcherOrderingWithFailures(t *testing.T) {
assert.Equal(t, 6, len(feeder.documents))
}
+func TestDispatcherOpenCircuit(t *testing.T) {
+ feeder := &mockFeeder{}
+ doc := Document{Id: mustParseId("id:ns:type::doc1"), Operation: OperationPut}
+ clock := &manualClock{tick: time.Second}
+ throttler := newThrottler(8, clock.now)
+ breaker := &mockCircuitBreaker{}
+ dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard, false)
+ dispatcher.Enqueue(doc)
+ breaker.state = CircuitOpen
+ dispatcher.Enqueue(doc)
+ dispatcher.Close()
+ assert.Equal(t, 1, len(feeder.documents))
+}
+
func BenchmarkDocumentDispatching(b *testing.B) {
feeder := &mockFeeder{}
clock := &manualClock{tick: time.Second}
diff --git a/client/go/internal/vespa/document/http.go b/client/go/internal/vespa/document/http.go
index 0530144747a..877bcc5edce 100644
--- a/client/go/internal/vespa/document/http.go
+++ b/client/go/internal/vespa/document/http.go
@@ -11,6 +11,7 @@ import (
"net/url"
"strconv"
"strings"
+ "sync"
"sync/atomic"
"time"
@@ -31,6 +32,7 @@ type Client struct {
httpClients []countingHTTPClient
now func() time.Time
sendCount int32
+ gzippers sync.Pool
}
// ClientOptions specifices the configuration options of a feed client.
@@ -78,11 +80,13 @@ func NewClient(options ClientOptions, httpClients []util.HTTPClient) *Client {
if nowFunc == nil {
nowFunc = time.Now
}
- return &Client{
+ c := &Client{
options: options,
httpClients: countingClients,
now: nowFunc,
}
+ c.gzippers.New = func() any { return gzip.NewWriter(io.Discard) }
+ return c
}
func (c *Client) queryParams() url.Values {
@@ -167,18 +171,25 @@ func (c *Client) leastBusyClient() *countingHTTPClient {
return &leastBusy
}
+func (c *Client) gzipWriter(w io.Writer) *gzip.Writer {
+ gzipWriter := c.gzippers.Get().(*gzip.Writer)
+ gzipWriter.Reset(w)
+ return gzipWriter
+}
+
func (c *Client) createRequest(method, url string, body []byte) (*http.Request, error) {
var r io.Reader
useGzip := c.options.Compression == CompressionGzip || (c.options.Compression == CompressionAuto && len(body) > 512)
if useGzip {
var buf bytes.Buffer
- w := gzip.NewWriter(&buf)
+ w := c.gzipWriter(&buf)
if _, err := w.Write(body); err != nil {
return nil, err
}
if err := w.Close(); err != nil {
return nil, err
}
+ c.gzippers.Put(w)
r = &buf
} else {
r = bytes.NewReader(body)
diff --git a/client/go/internal/vespa/document/http_test.go b/client/go/internal/vespa/document/http_test.go
index 314113c53be..f67368b5128 100644
--- a/client/go/internal/vespa/document/http_test.go
+++ b/client/go/internal/vespa/document/http_test.go
@@ -293,3 +293,26 @@ func TestClientFeedURL(t *testing.T) {
}
}
}
+
+func benchmarkClientSend(b *testing.B, compression Compression, document Document) {
+ httpClient := mock.HTTPClient{}
+ client := NewClient(ClientOptions{
+ Compression: compression,
+ BaseURL: "https://example.com:1337",
+ Timeout: time.Duration(5 * time.Second),
+ }, []util.HTTPClient{&httpClient})
+ b.ResetTimer() // ignore setup
+ for n := 0; n < b.N; n++ {
+ client.Send(document)
+ }
+}
+
+func BenchmarkClientSend(b *testing.B) {
+ doc := Document{Create: true, Id: mustParseId("id:ns:type::doc1"), Operation: OperationUpdate, Body: []byte(`{"fields":{"foo": "my document"}}`)}
+ benchmarkClientSend(b, CompressionNone, doc)
+}
+
+func BenchmarkClientSendCompressed(b *testing.B) {
+ doc := Document{Create: true, Id: mustParseId("id:ns:type::doc1"), Operation: OperationUpdate, Body: []byte(`{"fields":{"foo": "my document"}}`)}
+ benchmarkClientSend(b, CompressionGzip, doc)
+}
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
index 86b83b2c651..ae283f3f2b2 100644
--- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
+++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
@@ -44,6 +44,30 @@ void verify_geo_miles(const DistanceFunction *dist_fun,
}
}
+double computeEuclideanChecked(TypedCells a, TypedCells b) {
+ static EuclideanDistanceFunctionFactory<Int8Float> i8f_dff;
+ static EuclideanDistanceFunctionFactory<float> flt_dff;
+ static EuclideanDistanceFunctionFactory<double> dbl_dff;
+ auto d_n = dbl_dff.for_query_vector(a);
+ auto d_f = flt_dff.for_query_vector(a);
+ auto d_r = dbl_dff.for_query_vector(b);
+ auto d_i = dbl_dff.for_insertion_vector(a);
+ // normal:
+ double result = d_n->calc(b);
+ // insert is exactly same:
+ EXPECT_EQ(d_i->calc(b), result);
+ // reverse:
+ EXPECT_DOUBLE_EQ(d_r->calc(a), result);
+ // float factory:
+ EXPECT_FLOAT_EQ(d_f->calc(b), result);
+ if (a.type == vespalib::eval::CellType::INT8 ||
+ b.type == vespalib::eval::CellType::INT8)
+ {
+ auto d_8 = i8f_dff.for_query_vector(a);
+ EXPECT_DOUBLE_EQ(d_8->calc(b), result);
+ }
+ return result;
+}
TEST(DistanceFunctionsTest, euclidean_gives_expected_score)
{
@@ -59,15 +83,56 @@ TEST(DistanceFunctionsTest, euclidean_gives_expected_score)
std::vector<double> p5{0.0,-1.0, 0.0};
std::vector<double> p6{1.0, 2.0, 2.0};
- double n4 = euclid->calc(t(p0), t(p4));
+ double n4 = computeEuclideanChecked(t(p0), t(p4));
EXPECT_FLOAT_EQ(n4, 1.0);
- double d12 = euclid->calc(t(p1), t(p2));
+ double d12 = computeEuclideanChecked(t(p1), t(p2));
EXPECT_EQ(d12, 2.0);
EXPECT_DOUBLE_EQ(euclid->to_rawscore(d12), 1.0/(1.0 + sqrt(2.0)));
double threshold = euclid->convert_threshold(8.0);
EXPECT_EQ(threshold, 64.0);
threshold = euclid->convert_threshold(0.5);
EXPECT_EQ(threshold, 0.25);
+
+ // simple hand-checked distances:
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p0)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p1)), 1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p2)), 1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p3)), 1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p5)), 1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p6)), 9.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p1)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p2)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p3)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p5)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p6)), 8.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p2), t(p2)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p2), t(p3)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p2), t(p5)), 4.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p2), t(p6)), 6.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p3), t(p3)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p3), t(p5)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p3), t(p6)), 6.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p5), t(p5)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p5), t(p6)), 14.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p6), t(p6)), 0.0);
+
+ // smoke test for bfloat16:
+ std::vector<vespalib::BFloat16> bf16v;
+ bf16v.emplace_back(1.0);
+ bf16v.emplace_back(1.0);
+ bf16v.emplace_back(1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p0)), 3.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p1)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p2)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p3)), 2.0);
+ EXPECT_FLOAT_EQ(computeEuclideanChecked(t(bf16v), t(p4)), 0.5857863);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p5)), 6.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p6)), 2.0);
}
TEST(DistanceFunctionsTest, euclidean_int8_smoketest)
@@ -81,14 +146,13 @@ TEST(DistanceFunctionsTest, euclidean_int8_smoketest)
std::vector<Int8Float> p5{0.0,-1.0, 0.0};
std::vector<Int8Float> p7{-1.0, 2.0, -2.0};
- EXPECT_DOUBLE_EQ(1.0, euclid->calc(t(p0), t(p1)));
- EXPECT_DOUBLE_EQ(1.0, euclid->calc(t(p0), t(p5)));
- EXPECT_DOUBLE_EQ(9.0, euclid->calc(t(p0), t(p7)));
-
- EXPECT_DOUBLE_EQ(2.0, euclid->calc(t(p1), t(p5)));
- EXPECT_DOUBLE_EQ(12.0, euclid->calc(t(p1), t(p7)));
- EXPECT_DOUBLE_EQ(14.0, euclid->calc(t(p5), t(p7)));
+ EXPECT_DOUBLE_EQ(1.0, computeEuclideanChecked(t(p0), t(p1)));
+ EXPECT_DOUBLE_EQ(1.0, computeEuclideanChecked(t(p0), t(p5)));
+ EXPECT_DOUBLE_EQ(9.0, computeEuclideanChecked(t(p0), t(p7)));
+ EXPECT_DOUBLE_EQ(2.0, computeEuclideanChecked(t(p1), t(p5)));
+ EXPECT_DOUBLE_EQ(12.0, computeEuclideanChecked(t(p1), t(p7)));
+ EXPECT_DOUBLE_EQ(14.0, computeEuclideanChecked(t(p5), t(p7)));
}
double computeAngularChecked(TypedCells a, TypedCells b) {
diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
index 090042e5b83..1783e0da1dd 100644
--- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
@@ -34,6 +34,7 @@ vespa_add_library(searchlib_tensor OBJECT
serialized_tensor_ref.cpp
small_subspaces_buffer_type.cpp
subspace_type.cpp
+ temporary_vector_store.cpp
tensor_attribute.cpp
tensor_attribute_loader.cpp
tensor_attribute_saver.cpp
diff --git a/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp b/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp
index 5101373c047..85eac76728c 100644
--- a/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "angular_distance.h"
+#include "temporary_vector_store.h"
using vespalib::typify_invoke;
using vespalib::eval::TypifyCellType;
diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp
index 56edbf9fede..33b94e5218c 100644
--- a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.cpp
@@ -1,58 +1,3 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "bound_distance_function.h"
-#include <vespa/log/log.h>
-
-LOG_SETUP(".searchlib.tensor.bound_distance_function");
-
-using vespalib::ConstArrayRef;
-using vespalib::ArrayRef;
-using vespalib::eval::CellType;
-using vespalib::eval::TypedCells;
-
-namespace search::tensor {
-
-namespace {
-
-template<typename FromType, typename ToType>
-ConstArrayRef<ToType>
-convert_cells(ArrayRef<ToType> space, TypedCells cells)
-{
- assert(cells.size == space.size());
- auto old_cells = cells.typify<FromType>();
- ToType *p = space.data();
- for (FromType value : old_cells) {
- ToType conv(value);
- *p++ = conv;
- }
- return space;
-}
-
-template <typename ToType>
-struct ConvertCellsSelector
-{
- template <typename FromType> static auto invoke(ArrayRef<ToType> dst, TypedCells src) {
- return convert_cells<FromType, ToType>(dst, src);
- }
-};
-
-} // namespace
-
-template <typename FloatType>
-ConstArrayRef<FloatType>
-TemporaryVectorStore<FloatType>::internal_convert(TypedCells cells, size_t offset) {
- LOG_ASSERT(cells.size * 2 == _tmpSpace.size());
- ArrayRef<FloatType> where(_tmpSpace.data() + offset, cells.size);
- using MyTypify = vespalib::eval::TypifyCellType;
- using MySelector = ConvertCellsSelector<FloatType>;
- ConstArrayRef<FloatType> result = vespalib::typify_invoke<1,MyTypify,MySelector>(cells.type, where, cells);
- return result;
-}
-
-template class TemporaryVectorStore<float>;
-template class TemporaryVectorStore<double>;
-
-template class ConvertingBoundDistance<float>;
-template class ConvertingBoundDistance<double>;
-
-}
diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
index 0f51e8a33ef..5d602a52227 100644
--- a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
+++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
@@ -5,7 +5,6 @@
#include <memory>
#include <vespa/eval/eval/cell_type.h>
#include <vespa/eval/eval/typed_cells.h>
-#include <vespa/vespalib/util/array.h>
#include <vespa/vespalib/util/arrayref.h>
#include "distance_function.h"
@@ -43,51 +42,4 @@ public:
double limit) const = 0;
};
-
-/** helper class - temporary storage of possibly-converted vector cells */
-template <typename FloatType>
-class TemporaryVectorStore {
-private:
- vespalib::Array<FloatType> _tmpSpace;
- vespalib::ConstArrayRef<FloatType> internal_convert(vespalib::eval::TypedCells cells, size_t offset);
-public:
- TemporaryVectorStore(size_t vectorSize) : _tmpSpace(vectorSize * 2) {}
- vespalib::ConstArrayRef<FloatType> storeLhs(vespalib::eval::TypedCells cells) {
- return internal_convert(cells, 0);
- }
- vespalib::ConstArrayRef<FloatType> convertRhs(vespalib::eval::TypedCells cells) {
- if (vespalib::eval::get_cell_type<FloatType>() == cells.type) [[likely]] {
- return cells.unsafe_typify<FloatType>();
- } else {
- return internal_convert(cells, cells.size);
- }
- }
-};
-
-template<typename FloatType>
-class ConvertingBoundDistance : public BoundDistanceFunction {
- mutable TemporaryVectorStore<FloatType> _tmpSpace;
- const vespalib::eval::TypedCells _lhs;
- const DistanceFunction &_df;
-public:
- ConvertingBoundDistance(const vespalib::eval::TypedCells& lhs, const DistanceFunction &df)
- : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()),
- _tmpSpace(lhs.size),
- _lhs(_tmpSpace.storeLhs(lhs)),
- _df(df)
- {}
- double calc(const vespalib::eval::TypedCells& rhs) const override {
- return _df.calc(_lhs, vespalib::eval::TypedCells(_tmpSpace.convertRhs(rhs)));
- }
- double convert_threshold(double threshold) const override {
- return _df.convert_threshold(threshold);
- }
- double to_rawscore(double distance) const override {
- return _df.to_rawscore(distance);
- }
- double calc_with_limit(const vespalib::eval::TypedCells& rhs, double limit) const override {
- return _df.calc_with_limit(_lhs, vespalib::eval::TypedCells(_tmpSpace.convertRhs(rhs)), limit);
- }
-};
-
}
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
index 7ccca655943..cca492ef212 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
@@ -100,6 +100,13 @@ make_distance_function_factory(search::attribute::DistanceMetric variant,
}
return std::make_unique<AngularDistanceFunctionFactory<float>>();
}
+ if (variant == DistanceMetric::Euclidean) {
+ switch (cell_type) {
+ case CellType::DOUBLE: return std::make_unique<EuclideanDistanceFunctionFactory<double>>();
+ case CellType::INT8: return std::make_unique<EuclideanDistanceFunctionFactory<vespalib::eval::Int8Float>>();
+ default: return std::make_unique<EuclideanDistanceFunctionFactory<float>>();
+ }
+ }
auto df = make_distance_function(variant, cell_type);
return std::make_unique<SimpleDistanceFunctionFactory>(std::move(df));
}
diff --git a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
index c83f1821321..92d4e7af406 100644
--- a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "euclidean_distance.h"
+#include "temporary_vector_store.h"
using vespalib::typify_invoke;
using vespalib::eval::TypifyCellType;
@@ -48,4 +49,73 @@ SquaredEuclideanDistance::calc_with_limit(const vespalib::eval::TypedCells& lhs,
template class SquaredEuclideanDistanceHW<float>;
template class SquaredEuclideanDistanceHW<double>;
+using vespalib::eval::Int8Float;
+
+template<typename FloatType>
+class BoundEuclideanDistance : public BoundDistanceFunction {
+private:
+ const vespalib::hwaccelrated::IAccelrated & _computer;
+ mutable TemporaryVectorStore<FloatType> _tmpSpace;
+ const vespalib::ConstArrayRef<FloatType> _lhs_vector;
+ static const double *cast(const double * p) { return p; }
+ static const float *cast(const float * p) { return p; }
+ static const int8_t *cast(const Int8Float * p) { return reinterpret_cast<const int8_t *>(p); }
+public:
+ BoundEuclideanDistance(const vespalib::eval::TypedCells& lhs)
+ : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()),
+ _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()),
+ _tmpSpace(lhs.size),
+ _lhs_vector(_tmpSpace.storeLhs(lhs))
+ {}
+ double calc(const vespalib::eval::TypedCells& rhs) const override {
+ size_t sz = _lhs_vector.size();
+ vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs);
+ assert(sz == rhs_vector.size());
+ auto a = &_lhs_vector[0];
+ auto b = &rhs_vector[0];
+ return _computer.squaredEuclideanDistance(cast(a), cast(b), sz);
+ }
+ double convert_threshold(double threshold) const override {
+ return threshold*threshold;
+ }
+ double to_rawscore(double distance) const override {
+ double d = sqrt(distance);
+ double score = 1.0 / (1.0 + d);
+ return score;
+ }
+ double calc_with_limit(const vespalib::eval::TypedCells& rhs, double limit) const override {
+ vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs);
+ double sum = 0.0;
+ size_t sz = _lhs_vector.size();
+ assert(sz == rhs_vector.size());
+ for (size_t i = 0; i < sz && sum <= limit; ++i) {
+ double diff = _lhs_vector[i] - rhs_vector[i];
+ sum += diff*diff;
+ }
+ return sum;
+ }
+};
+
+template class BoundEuclideanDistance<Int8Float>;
+template class BoundEuclideanDistance<float>;
+template class BoundEuclideanDistance<double>;
+
+template <typename FloatType>
+BoundDistanceFunction::UP
+EuclideanDistanceFunctionFactory<FloatType>::for_query_vector(const vespalib::eval::TypedCells& lhs) {
+ using DFT = BoundEuclideanDistance<FloatType>;
+ return std::make_unique<DFT>(lhs);
+}
+
+template <typename FloatType>
+BoundDistanceFunction::UP
+EuclideanDistanceFunctionFactory<FloatType>::for_insertion_vector(const vespalib::eval::TypedCells& lhs) {
+ using DFT = BoundEuclideanDistance<FloatType>;
+ return std::make_unique<DFT>(lhs);
+}
+
+template class EuclideanDistanceFunctionFactory<Int8Float>;
+template class EuclideanDistanceFunctionFactory<float>;
+template class EuclideanDistanceFunctionFactory<double>;
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h
index 6505ea119ea..b406f0d3d1a 100644
--- a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h
+++ b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h
@@ -3,6 +3,7 @@
#pragma once
#include "distance_function.h"
+#include "distance_function_factory.h"
#include <vespa/eval/eval/typed_cells.h>
#include <vespa/vespalib/hwaccelrated/iaccelrated.h>
#include <cmath>
@@ -78,4 +79,15 @@ private:
const vespalib::hwaccelrated::IAccelrated & _computer;
};
+
+template <typename FloatType>
+class EuclideanDistanceFunctionFactory : public DistanceFunctionFactory {
+public:
+ EuclideanDistanceFunctionFactory()
+ : DistanceFunctionFactory(vespalib::eval::get_cell_type<FloatType>())
+ {}
+ BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override;
+ BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override;
+};
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp b/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp
new file mode 100644
index 00000000000..cc45f857d9f
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp
@@ -0,0 +1,57 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "temporary_vector_store.h"
+
+#include <vespa/log/log.h>
+
+LOG_SETUP(".searchlib.tensor.temporary_vector_store");
+
+using vespalib::ConstArrayRef;
+using vespalib::ArrayRef;
+using vespalib::eval::CellType;
+using vespalib::eval::TypedCells;
+
+namespace search::tensor {
+
+namespace {
+
+template<typename FromType, typename ToType>
+ConstArrayRef<ToType>
+convert_cells(ArrayRef<ToType> space, TypedCells cells)
+{
+ assert(cells.size == space.size());
+ auto old_cells = cells.typify<FromType>();
+ ToType *p = space.data();
+ for (FromType value : old_cells) {
+ ToType conv(value);
+ *p++ = conv;
+ }
+ return space;
+}
+
+template <typename ToType>
+struct ConvertCellsSelector
+{
+ template <typename FromType> static auto invoke(ArrayRef<ToType> dst, TypedCells src) {
+ return convert_cells<FromType, ToType>(dst, src);
+ }
+};
+
+} // namespace
+
+template <typename FloatType>
+ConstArrayRef<FloatType>
+TemporaryVectorStore<FloatType>::internal_convert(TypedCells cells, size_t offset) {
+ LOG_ASSERT(cells.size * 2 == _tmpSpace.size());
+ ArrayRef<FloatType> where(_tmpSpace.data() + offset, cells.size);
+ using MyTypify = vespalib::eval::TypifyCellType;
+ using MySelector = ConvertCellsSelector<FloatType>;
+ ConstArrayRef<FloatType> result = vespalib::typify_invoke<1,MyTypify,MySelector>(cells.type, where, cells);
+ return result;
+}
+
+template class TemporaryVectorStore<vespalib::eval::Int8Float>;
+template class TemporaryVectorStore<float>;
+template class TemporaryVectorStore<double>;
+
+}
diff --git a/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.h b/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.h
new file mode 100644
index 00000000000..cd816621f91
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.h
@@ -0,0 +1,32 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <memory>
+#include <vespa/eval/eval/cell_type.h>
+#include <vespa/eval/eval/typed_cells.h>
+#include <vespa/vespalib/util/arrayref.h>
+
+namespace search::tensor {
+
+/** helper class - temporary storage of possibly-converted vector cells */
+template <typename FloatType>
+class TemporaryVectorStore {
+private:
+ std::vector<FloatType> _tmpSpace;
+ vespalib::ConstArrayRef<FloatType> internal_convert(vespalib::eval::TypedCells cells, size_t offset);
+public:
+ TemporaryVectorStore(size_t vectorSize) : _tmpSpace(vectorSize * 2) {}
+ vespalib::ConstArrayRef<FloatType> storeLhs(vespalib::eval::TypedCells cells) {
+ return internal_convert(cells, 0);
+ }
+ vespalib::ConstArrayRef<FloatType> convertRhs(vespalib::eval::TypedCells cells) {
+ if (vespalib::eval::get_cell_type<FloatType>() == cells.type) [[likely]] {
+ return cells.unsafe_typify<FloatType>();
+ } else {
+ return internal_convert(cells, cells.size);
+ }
+ }
+};
+
+}