aboutsummaryrefslogtreecommitdiffstats
path: root/document/src
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-10-14 14:36:58 +0000
committerArne Juul <arnej@verizonmedia.com>2020-10-14 15:05:50 +0000
commit3d89b1eaf56c686a9e795f3a544050173ee503e3 (patch)
tree1bda3b37f47f2944d6f62262c55d8e8a3fa587f6 /document/src
parent3d2645d8593874be4da8e5f73cd5a7e2cecfd399 (diff)
fix TensorFieldValue::compare to be correct (but slow)
Diffstat (limited to 'document/src')
-rw-r--r--document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp23
1 files changed, 18 insertions, 5 deletions
diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp
index 56d7b6ab078..c3b593732b9 100644
--- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp
+++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp
@@ -4,6 +4,7 @@
#include <vespa/document/base/exceptions.h>
#include <vespa/document/datatype/tensor_data_type.h>
#include <vespa/vespalib/util/xmlstream.h>
+#include <vespa/eval/eval/engine_or_factory.h>
#include <vespa/eval/eval/tensor_spec.h>
#include <vespa/eval/tensor/tensor.h>
#include <vespa/eval/tensor/default_tensor_engine.h>
@@ -11,6 +12,7 @@
#include <cassert>
using vespalib::tensor::Tensor;
+using vespalib::eval::EngineOrFactory;
using vespalib::eval::TensorSpec;
using vespalib::eval::ValueType;
using Engine = vespalib::tensor::DefaultTensorEngine;
@@ -218,13 +220,24 @@ TensorFieldValue::compare(const FieldValue &other) const
if (!rhs._tensor) {
return 1;
}
- if (_tensor->equals(*rhs._tensor)) {
+ // equal pointers always means identical
+ if (_tensor.get() == rhs._tensor.get()) {
return 0;
}
- assert(_tensor.get() != rhs._tensor.get());
- // XXX: Wrong, compares identity of tensors instead of values
- // Note: sorting can be dangerous due to this.
- return ((_tensor.get() < rhs._tensor.get()) ? -1 : 1);
+ // compare just the type first:
+ auto lhs_type = _tensor->type().to_spec();
+ auto rhs_type = rhs._tensor->type().to_spec();
+ int type_cmp = lhs_type.compare(rhs_type);
+ if (type_cmp != 0) {
+ return type_cmp;
+ }
+ // Compare the actual tensors by converting to TensorSpec strings.
+ // TODO: this can be very slow, check if it might be used for anything
+ // performance-critical.
+ auto engine = EngineOrFactory::get();
+ auto lhs_spec = engine.to_spec(*_tensor).to_string();
+ auto rhs_spec = engine.to_spec(*rhs._tensor).to_string();
+ return lhs_spec.compare(rhs_spec);
}
IMPLEMENT_IDENTIFIABLE(TensorFieldValue, FieldValue);