diff options
author | Håvard Pettersen <havardpe@oath.com> | 2017-11-03 14:16:33 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2017-11-03 15:15:02 +0000 |
commit | 21868a8bb71720ad277d706ed6f08400ebfeb497 (patch) | |
tree | 83ccb7f593829dadd3ed631c46671203c8aa7cf1 /eval/src | |
parent | 4d5f2f54588e1d9888e4bd491361d55f5b1ed9da (diff) |
remove 'equal' concept for tensors
Diffstat (limited to 'eval/src')
-rw-r--r-- | eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp | 49 | ||||
-rw-r--r-- | eval/src/tests/eval/tensor_function/tensor_function_test.cpp | 4 | ||||
-rw-r--r-- | eval/src/tests/eval/value_cache/tensor_loader_test.cpp | 4 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/simple_tensor.cpp | 27 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/simple_tensor.h | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/simple_tensor_engine.cpp | 6 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/simple_tensor_engine.h | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/tensor.cpp | 5 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/tensor_engine.h | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/test/tensor_conformance.cpp | 59 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/value.cpp | 6 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/value.h | 6 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/default_tensor_engine.cpp | 10 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/default_tensor_engine.h | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp | 8 |
15 files changed, 27 insertions, 161 deletions
diff --git a/eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp b/eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp index beec4ed928b..150b86f27ce 100644 --- a/eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp +++ b/eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp @@ -13,8 +13,7 @@ using Cells = SimpleTensor::Cells; using Address = SimpleTensor::Address; using Stash = vespalib::Stash; -// need to specify numbers explicitly as size_t to avoid ambiguous behavior for 0 -constexpr size_t operator "" _z (unsigned long long int n) { return n; } +TensorSpec to_spec(const Tensor &a) { return a.engine().to_spec(a); } const Tensor &unwrap(const Value &value) { ASSERT_TRUE(value.is_tensor()); @@ -56,28 +55,8 @@ TEST("require that simple tensors can be built using tensor spec") { .add({{"w", "yyy"}, {"x", 1}, {"y", "yyy"}, {"z", 0}}, 0.0) .add({{"w", "yyy"}, {"x", 1}, {"y", "yyy"}, {"z", 1}}, 4.0); auto full_tensor = SimpleTensorEngine::ref().create(full_spec); - SimpleTensor expect_tensor(ValueType::from_spec("tensor(w{},x[2],y{},z[2])"), - CellBuilder() - .add({{"xxx"}, {0_z}, {"xxx"}, {0_z}}, 1.0) - .add({{"xxx"}, {0_z}, {"xxx"}, {1_z}}, 0.0) - .add({{"xxx"}, {0_z}, {"yyy"}, {0_z}}, 0.0) - .add({{"xxx"}, {0_z}, {"yyy"}, {1_z}}, 2.0) - .add({{"xxx"}, {1_z}, {"xxx"}, {0_z}}, 0.0) - .add({{"xxx"}, {1_z}, {"xxx"}, {1_z}}, 0.0) - .add({{"xxx"}, {1_z}, {"yyy"}, {0_z}}, 0.0) - .add({{"xxx"}, {1_z}, {"yyy"}, {1_z}}, 0.0) - .add({{"yyy"}, {0_z}, {"xxx"}, {0_z}}, 0.0) - .add({{"yyy"}, {0_z}, {"xxx"}, {1_z}}, 0.0) - .add({{"yyy"}, {0_z}, {"yyy"}, {0_z}}, 0.0) - .add({{"yyy"}, {0_z}, {"yyy"}, {1_z}}, 0.0) - .add({{"yyy"}, {1_z}, {"xxx"}, {0_z}}, 3.0) - .add({{"yyy"}, {1_z}, {"xxx"}, {1_z}}, 0.0) - .add({{"yyy"}, {1_z}, {"yyy"}, {0_z}}, 0.0) - .add({{"yyy"}, {1_z}, {"yyy"}, {1_z}}, 4.0) - .build()); - EXPECT_EQUAL(expect_tensor, *tensor); - EXPECT_EQUAL(expect_tensor, *full_tensor); - EXPECT_EQUAL(full_spec, tensor->engine().to_spec(*tensor)); + EXPECT_EQUAL(full_spec, to_spec(*tensor)); + EXPECT_EQUAL(full_spec, to_spec(*full_tensor)); }; TEST("require that simple tensors can have their values negated") { @@ -92,10 +71,10 @@ TEST("require that simple tensors can have their values negated") { .add({{"x","2"},{"y","1"}}, 3) .add({{"x","1"},{"y","2"}}, -5)); auto result = tensor->map([](double a){ return -a; }); - EXPECT_EQUAL(*expect, *result); + EXPECT_EQUAL(to_spec(*expect), to_spec(*result)); Stash stash; const Value &result2 = SimpleTensorEngine::ref().map(TensorValue(*tensor), operation::Neg::f, stash); - EXPECT_EQUAL(*expect, unwrap(result2)); + EXPECT_EQUAL(to_spec(*expect), to_spec(unwrap(result2))); } TEST("require that simple tensors can be multiplied with each other") { @@ -117,10 +96,10 @@ TEST("require that simple tensors can be multiplied with each other") { .add({{"x","2"},{"y","1"},{"z","2"}}, 39) .add({{"x","1"},{"y","2"},{"z","1"}}, 55)); auto result = SimpleTensor::join(*lhs, *rhs, [](double a, double b){ return (a * b); }); - EXPECT_EQUAL(*expect, *result); + EXPECT_EQUAL(to_spec(*expect), to_spec(*result)); Stash stash; const Value &result2 = SimpleTensorEngine::ref().join(TensorValue(*lhs), TensorValue(*rhs), operation::Mul::f, stash); - EXPECT_EQUAL(*expect, unwrap(result2)); + EXPECT_EQUAL(to_spec(*expect), to_spec(unwrap(result2))); } TEST("require that simple tensors support dimension reduction") { @@ -147,21 +126,21 @@ TEST("require that simple tensors support dimension reduction") { auto result_sum_y = tensor->reduce(aggr_sum, {"y"}); auto result_sum_x = tensor->reduce(aggr_sum, {"x"}); auto result_sum_all = tensor->reduce(aggr_sum, {"x", "y"}); - EXPECT_EQUAL(*expect_sum_y, *result_sum_y); - EXPECT_EQUAL(*expect_sum_x, *result_sum_x); - EXPECT_EQUAL(*expect_sum_all, *result_sum_all); + EXPECT_EQUAL(to_spec(*expect_sum_y), to_spec(*result_sum_y)); + EXPECT_EQUAL(to_spec(*expect_sum_x), to_spec(*result_sum_x)); + EXPECT_EQUAL(to_spec(*expect_sum_all), to_spec(*result_sum_all)); const Value &result_sum_y_2 = SimpleTensorEngine::ref().reduce(TensorValue(*tensor), Aggr::SUM, {"y"}, stash); const Value &result_sum_x_2 = SimpleTensorEngine::ref().reduce(TensorValue(*tensor), Aggr::SUM, {"x"}, stash); const Value &result_sum_all_2 = SimpleTensorEngine::ref().reduce(TensorValue(*tensor), Aggr::SUM, {"x", "y"}, stash); const Value &result_sum_all_3 = SimpleTensorEngine::ref().reduce(TensorValue(*tensor), Aggr::SUM, {}, stash); - EXPECT_EQUAL(*expect_sum_y, unwrap(result_sum_y_2)); - EXPECT_EQUAL(*expect_sum_x, unwrap(result_sum_x_2)); + EXPECT_EQUAL(to_spec(*expect_sum_y), to_spec(unwrap(result_sum_y_2))); + EXPECT_EQUAL(to_spec(*expect_sum_x), to_spec(unwrap(result_sum_x_2))); EXPECT_TRUE(result_sum_all_2.is_double()); EXPECT_TRUE(result_sum_all_3.is_double()); EXPECT_EQUAL(21, result_sum_all_2.as_double()); EXPECT_EQUAL(21, result_sum_all_3.as_double()); - EXPECT_EQUAL(*result_sum_y, *result_sum_y); - EXPECT_NOT_EQUAL(*result_sum_y, *result_sum_x); + EXPECT_EQUAL(to_spec(*result_sum_y), to_spec(*result_sum_y)); + EXPECT_NOT_EQUAL(to_spec(*result_sum_y), to_spec(*result_sum_x)); } TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp index 5b2d0848f64..8bd86621bf6 100644 --- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp +++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp @@ -102,7 +102,9 @@ void verify_equal(const Tensor &expect, const Value &value) { const Tensor *tensor = value.as_tensor(); ASSERT_TRUE(tensor != nullptr); ASSERT_EQUAL(&expect.engine(), &tensor->engine()); - EXPECT_TRUE(expect.engine().equal(expect, *tensor)); + auto expect_spec = expect.engine().to_spec(expect); + auto value_spec = tensor->engine().to_spec(*tensor); + EXPECT_EQUAL(expect_spec, value_spec); } TEST("require that tensor injection works") { diff --git a/eval/src/tests/eval/value_cache/tensor_loader_test.cpp b/eval/src/tests/eval/value_cache/tensor_loader_test.cpp index 20a77eb9fe3..ee8e502815f 100644 --- a/eval/src/tests/eval/value_cache/tensor_loader_test.cpp +++ b/eval/src/tests/eval/value_cache/tensor_loader_test.cpp @@ -42,8 +42,8 @@ std::unique_ptr<Tensor> make_mixed_tensor() { void verify_tensor(std::unique_ptr<Tensor> expect, ConstantValue::UP actual) { const auto &engine = expect->engine(); ASSERT_EQUAL(engine.type_of(*expect), actual->type()); - EXPECT_TRUE(&engine == &actual->value().as_tensor()->engine()); - EXPECT_TRUE(engine.equal(*expect, *actual->value().as_tensor())); + ASSERT_TRUE(&engine == &actual->value().as_tensor()->engine()); + EXPECT_EQUAL(engine.to_spec(*expect), engine.to_spec(*actual->value().as_tensor())); } TEST_F("require that invalid types loads an empty double", ConstantTensorLoader(SimpleTensorEngine::ref())) { diff --git a/eval/src/vespa/eval/eval/simple_tensor.cpp b/eval/src/vespa/eval/eval/simple_tensor.cpp index 75c170d48ba..e39e926708d 100644 --- a/eval/src/vespa/eval/eval/simple_tensor.cpp +++ b/eval/src/vespa/eval/eval/simple_tensor.cpp @@ -611,33 +611,6 @@ SimpleTensor::create(const TensorSpec &spec) return builder.build(); } -bool -SimpleTensor::equal(const SimpleTensor &a, const SimpleTensor &b) -{ - if (a.type() != b.type()) { - return false; - } - TypeAnalyzer type_info(a.type(), b.type()); - View view_a(a, type_info.overlap_a); - View view_b(b, type_info.overlap_b); - const CellRef *pos_a = view_a.refs_begin(); - const CellRef *end_a = view_a.refs_end(); - const CellRef *pos_b = view_b.refs_begin(); - const CellRef *end_b = view_b.refs_end(); - ViewMatcher::CrossCompare cmp(view_a.selector(), view_b.selector()); - while ((pos_a != end_a) && (pos_b != end_b)) { - if (cmp.compare(pos_a->get(), pos_b->get()) != ViewMatcher::CrossCompare::Result::EQUAL) { - return false; - } - if (pos_a->get().value != pos_b->get().value) { - return false; - } - ++pos_a; - ++pos_b; - } - return ((pos_a == end_a) && (pos_b == end_b)); -} - std::unique_ptr<SimpleTensor> SimpleTensor::join(const SimpleTensor &a, const SimpleTensor &b, join_fun_t function) { diff --git a/eval/src/vespa/eval/eval/simple_tensor.h b/eval/src/vespa/eval/eval/simple_tensor.h index ec154ff969a..366796f00d8 100644 --- a/eval/src/vespa/eval/eval/simple_tensor.h +++ b/eval/src/vespa/eval/eval/simple_tensor.h @@ -88,7 +88,6 @@ public: std::unique_ptr<SimpleTensor> reduce(Aggregator &aggr, const std::vector<vespalib::string> &dimensions) const; std::unique_ptr<SimpleTensor> rename(const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to) const; static std::unique_ptr<SimpleTensor> create(const TensorSpec &spec); - static bool equal(const SimpleTensor &a, const SimpleTensor &b); static std::unique_ptr<SimpleTensor> join(const SimpleTensor &a, const SimpleTensor &b, join_fun_t function); static std::unique_ptr<SimpleTensor> concat(const SimpleTensor &a, const SimpleTensor &b, const vespalib::string &dimension); static void encode(const SimpleTensor &tensor, nbostream &output); diff --git a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp index d69715cab22..21498ca2ff1 100644 --- a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp +++ b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp @@ -47,12 +47,6 @@ SimpleTensorEngine::type_of(const Tensor &tensor) const return to_simple(tensor).type(); } -bool -SimpleTensorEngine::equal(const Tensor &a, const Tensor &b) const -{ - return SimpleTensor::equal(to_simple(a), to_simple(b)); -} - vespalib::string SimpleTensorEngine::to_string(const Tensor &tensor) const { diff --git a/eval/src/vespa/eval/eval/simple_tensor_engine.h b/eval/src/vespa/eval/eval/simple_tensor_engine.h index bc6d0166bd1..c751f2f6b49 100644 --- a/eval/src/vespa/eval/eval/simple_tensor_engine.h +++ b/eval/src/vespa/eval/eval/simple_tensor_engine.h @@ -20,7 +20,6 @@ public: static const TensorEngine &ref() { return _engine; }; ValueType type_of(const Tensor &tensor) const override; - bool equal(const Tensor &a, const Tensor &b) const override; vespalib::string to_string(const Tensor &tensor) const override; TensorSpec to_spec(const Tensor &tensor) const override; diff --git a/eval/src/vespa/eval/eval/tensor.cpp b/eval/src/vespa/eval/eval/tensor.cpp index ed50d33de9b..926606f8e26 100644 --- a/eval/src/vespa/eval/eval/tensor.cpp +++ b/eval/src/vespa/eval/eval/tensor.cpp @@ -2,6 +2,7 @@ #include "tensor.h" #include "tensor_engine.h" +#include "tensor_spec.h" namespace vespalib { namespace eval { @@ -9,7 +10,9 @@ namespace eval { bool operator==(const Tensor &lhs, const Tensor &rhs) { - return ((&lhs.engine() == &rhs.engine()) && lhs.engine().equal(lhs, rhs)); + auto lhs_spec = lhs.engine().to_spec(lhs); + auto rhs_spec = rhs.engine().to_spec(rhs); + return (lhs_spec == rhs_spec); } std::ostream & diff --git a/eval/src/vespa/eval/eval/tensor_engine.h b/eval/src/vespa/eval/eval/tensor_engine.h index d33c1ba0ed2..00927f0c1b1 100644 --- a/eval/src/vespa/eval/eval/tensor_engine.h +++ b/eval/src/vespa/eval/eval/tensor_engine.h @@ -41,7 +41,6 @@ struct TensorEngine using Aggr = eval::Aggr; virtual ValueType type_of(const Tensor &tensor) const = 0; - virtual bool equal(const Tensor &a, const Tensor &b) const = 0; virtual vespalib::string to_string(const Tensor &tensor) const = 0; virtual TensorSpec to_spec(const Tensor &tensor) const = 0; diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp index 617aa75c945..2a7253454ff 100644 --- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp +++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp @@ -395,63 +395,6 @@ struct TestContext { //------------------------------------------------------------------------- - void verify_equal(const TensorSpec &a, const TensorSpec &b) { - auto ta = tensor(a); - auto tb = tensor(b); - EXPECT_EQUAL(a, b); - EXPECT_EQUAL(*ta, *tb); - TensorSpec spec = engine.to_spec(*ta); - TensorSpec ref_spec = ref_engine.to_spec(*ref_engine.create(a)); - EXPECT_EQUAL(spec, ref_spec); - } - - void test_tensor_equality() { - TEST_DO(verify_equal(spec(), spec())); - TEST_DO(verify_equal(spec(10.0), spec(10.0))); - TEST_DO(verify_equal(spec(x()), spec(x()))); - TEST_DO(verify_equal(spec(x({"a"}), Seq({1})), spec(x({"a"}), Seq({1})))); - TEST_DO(verify_equal(spec({x({"a"}),y({"a"})}, Seq({1})), spec({y({"a"}),x({"a"})}, Seq({1})))); - TEST_DO(verify_equal(spec(x(3)), spec(x(3)))); - TEST_DO(verify_equal(spec({x(1),y(1)}, Seq({1})), spec({y(1),x(1)}, Seq({1})))); - TEST_DO(verify_equal(spec({x({"a"}),y(1)}, Seq({1})), spec({y(1),x({"a"})}, Seq({1})))); - TEST_DO(verify_equal(spec({y({"a"}),x(1)}, Seq({1})), spec({x(1),y({"a"})}, Seq({1})))); - } - - //------------------------------------------------------------------------- - - void verify_not_equal(const TensorSpec &a, const TensorSpec &b) { - auto ta = tensor(a); - auto tb = tensor(b); - EXPECT_NOT_EQUAL(a, b); - EXPECT_NOT_EQUAL(b, a); - EXPECT_NOT_EQUAL(*ta, *tb); - EXPECT_NOT_EQUAL(*tb, *ta); - } - - void test_tensor_inequality() { - TEST_DO(verify_not_equal(spec(1.0), spec(2.0))); - TEST_DO(verify_not_equal(spec(), spec(x()))); - TEST_DO(verify_not_equal(spec(), spec(x(1)))); - TEST_DO(verify_not_equal(spec(x()), spec(x(1)))); - TEST_DO(verify_not_equal(spec(x()), spec(y()))); - TEST_DO(verify_not_equal(spec(x(1)), spec(x(2)))); - TEST_DO(verify_not_equal(spec(x(1)), spec(y(1)))); - TEST_DO(verify_not_equal(spec(x({"a"}), Seq({1})), spec(x({"a"}), Seq({2})))); - TEST_DO(verify_not_equal(spec(x({"a"}), Seq({1})), spec(x({"b"}), Seq({1})))); - TEST_DO(verify_not_equal(spec(x({"a"}), Seq({1})), spec({x({"a"}),y({"a"})}, Seq({1})))); - TEST_DO(verify_not_equal(spec(x(1), Seq({1})), spec(x(1), Seq({2})))); - TEST_DO(verify_not_equal(spec(x(1), Seq({1})), spec(x(2), Seq({1}), Bits({1,0})))); - TEST_DO(verify_not_equal(spec(x(2), Seq({1,1}), Bits({1,0})), - spec(x(2), Seq({1,1}), Bits({0,1})))); - TEST_DO(verify_not_equal(spec(x(1), Seq({1})), spec({x(1),y(1)}, Seq({1})))); - TEST_DO(verify_not_equal(spec({x({"a"}),y(1)}, Seq({1})), spec({x({"a"}),y(1)}, Seq({2})))); - TEST_DO(verify_not_equal(spec({x({"a"}),y(1)}, Seq({1})), spec({x({"b"}),y(1)}, Seq({1})))); - TEST_DO(verify_not_equal(spec({x(2),y({"a"})}, Seq({1}), Bits({1,0})), - spec({x(2),y({"a"})}, Seq({X,1}), Bits({0,1})))); - } - - //------------------------------------------------------------------------- - void verify_reduce_result(const Eval &eval, const TensorSpec &a, const Eval::Result &expect) { TEST_DO(verify_result(eval.eval(engine, a), expect)); } @@ -989,8 +932,6 @@ struct TestContext { void run_tests() { TEST_DO(test_tensor_create_type()); - TEST_DO(test_tensor_equality()); - TEST_DO(test_tensor_inequality()); TEST_DO(test_tensor_reduce()); TEST_DO(test_tensor_map()); TEST_DO(test_tensor_apply()); diff --git a/eval/src/vespa/eval/eval/value.cpp b/eval/src/vespa/eval/eval/value.cpp index 0118d95e5cb..456d80c0ff0 100644 --- a/eval/src/vespa/eval/eval/value.cpp +++ b/eval/src/vespa/eval/eval/value.cpp @@ -14,12 +14,6 @@ TensorValue::as_double() const return _tensor->as_double(); } -bool -TensorValue::equal(const Value &rhs) const -{ - return (rhs.is_tensor() && _tensor->engine().equal(*_tensor, *rhs.as_tensor())); -} - ValueType TensorValue::type() const { diff --git a/eval/src/vespa/eval/eval/value.h b/eval/src/vespa/eval/eval/value.h index 0d727db6b91..8826faed140 100644 --- a/eval/src/vespa/eval/eval/value.h +++ b/eval/src/vespa/eval/eval/value.h @@ -27,7 +27,6 @@ struct Value { virtual double as_double() const { return 0.0; } virtual bool as_bool() const { return false; } virtual const Tensor *as_tensor() const { return nullptr; } - virtual bool equal(const Value &rhs) const = 0; virtual ValueType type() const = 0; virtual ~Value() {} }; @@ -36,7 +35,6 @@ struct ErrorValue : public Value { static ErrorValue instance; bool is_error() const override { return true; } double as_double() const override { return error_value; } - bool equal(const Value &) const override { return false; } ValueType type() const override { return ValueType::error_type(); } }; @@ -49,9 +47,6 @@ public: bool is_double() const override { return true; } double as_double() const override { return _value; } bool as_bool() const override { return (_value != 0.0); } - bool equal(const Value &rhs) const override { - return (rhs.is_double() && (_value == rhs.as_double())); - } ValueType type() const override { return ValueType::double_type(); } }; @@ -66,7 +61,6 @@ public: bool is_tensor() const override { return true; } double as_double() const override; const Tensor *as_tensor() const override { return _tensor; } - bool equal(const Value &rhs) const override; ValueType type() const override; }; diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp index 2082b7efd25..7adb95f69ca 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp @@ -97,16 +97,6 @@ DefaultTensorEngine::type_of(const Tensor &tensor) const return my_tensor.getType(); } -bool -DefaultTensorEngine::equal(const Tensor &a, const Tensor &b) const -{ - assert(&a.engine() == this); - assert(&b.engine() == this); - const tensor::Tensor &my_a = static_cast<const tensor::Tensor &>(a); - const tensor::Tensor &my_b = static_cast<const tensor::Tensor &>(b); - return my_a.equals(my_b); -} - vespalib::string DefaultTensorEngine::to_string(const Tensor &tensor) const { diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.h b/eval/src/vespa/eval/tensor/default_tensor_engine.h index abdce6edb62..bbb03aceb1f 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.h +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.h @@ -20,7 +20,6 @@ public: static const TensorEngine &ref() { return _engine; }; ValueType type_of(const Tensor &tensor) const override; - bool equal(const Tensor &a, const Tensor &b) const override; vespalib::string to_string(const Tensor &tensor) const override; TensorSpec to_spec(const Tensor &tensor) const override; diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp index a407a46610b..534854732c7 100644 --- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp +++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp @@ -4,6 +4,7 @@ #include "tensor_address_builder.h" #include "tensor_visitor.h" #include <vespa/eval/eval/simple_tensor_engine.h> +#include <vespa/eval/eval/tensor_spec.h> #include <vespa/vespalib/util/stringfmt.h> namespace vespalib::tensor { @@ -11,10 +12,9 @@ namespace vespalib::tensor { bool WrappedSimpleTensor::equals(const Tensor &arg) const { - if (auto other = dynamic_cast<const WrappedSimpleTensor *>(&arg)) { - return eval::SimpleTensor::equal(_tensor, other->_tensor); - } - return false; + auto lhs_spec = _tensor.engine().to_spec(_tensor); + auto rhs_spec = arg.engine().to_spec(arg); + return (lhs_spec == rhs_spec); } vespalib::string |