summaryrefslogtreecommitdiffstats
path: root/eval/src
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2017-11-03 14:16:33 +0000
committerHåvard Pettersen <havardpe@oath.com>2017-11-03 15:15:02 +0000
commit21868a8bb71720ad277d706ed6f08400ebfeb497 (patch)
tree83ccb7f593829dadd3ed631c46671203c8aa7cf1 /eval/src
parent4d5f2f54588e1d9888e4bd491361d55f5b1ed9da (diff)
remove 'equal' concept for tensors
Diffstat (limited to 'eval/src')
-rw-r--r--eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp49
-rw-r--r--eval/src/tests/eval/tensor_function/tensor_function_test.cpp4
-rw-r--r--eval/src/tests/eval/value_cache/tensor_loader_test.cpp4
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor.cpp27
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor.h1
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor_engine.cpp6
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor_engine.h1
-rw-r--r--eval/src/vespa/eval/eval/tensor.cpp5
-rw-r--r--eval/src/vespa/eval/eval/tensor_engine.h1
-rw-r--r--eval/src/vespa/eval/eval/test/tensor_conformance.cpp59
-rw-r--r--eval/src/vespa/eval/eval/value.cpp6
-rw-r--r--eval/src/vespa/eval/eval/value.h6
-rw-r--r--eval/src/vespa/eval/tensor/default_tensor_engine.cpp10
-rw-r--r--eval/src/vespa/eval/tensor/default_tensor_engine.h1
-rw-r--r--eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp8
15 files changed, 27 insertions, 161 deletions
diff --git a/eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp b/eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp
index beec4ed928b..150b86f27ce 100644
--- a/eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp
+++ b/eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp
@@ -13,8 +13,7 @@ using Cells = SimpleTensor::Cells;
using Address = SimpleTensor::Address;
using Stash = vespalib::Stash;
-// need to specify numbers explicitly as size_t to avoid ambiguous behavior for 0
-constexpr size_t operator "" _z (unsigned long long int n) { return n; }
+TensorSpec to_spec(const Tensor &a) { return a.engine().to_spec(a); }
const Tensor &unwrap(const Value &value) {
ASSERT_TRUE(value.is_tensor());
@@ -56,28 +55,8 @@ TEST("require that simple tensors can be built using tensor spec") {
.add({{"w", "yyy"}, {"x", 1}, {"y", "yyy"}, {"z", 0}}, 0.0)
.add({{"w", "yyy"}, {"x", 1}, {"y", "yyy"}, {"z", 1}}, 4.0);
auto full_tensor = SimpleTensorEngine::ref().create(full_spec);
- SimpleTensor expect_tensor(ValueType::from_spec("tensor(w{},x[2],y{},z[2])"),
- CellBuilder()
- .add({{"xxx"}, {0_z}, {"xxx"}, {0_z}}, 1.0)
- .add({{"xxx"}, {0_z}, {"xxx"}, {1_z}}, 0.0)
- .add({{"xxx"}, {0_z}, {"yyy"}, {0_z}}, 0.0)
- .add({{"xxx"}, {0_z}, {"yyy"}, {1_z}}, 2.0)
- .add({{"xxx"}, {1_z}, {"xxx"}, {0_z}}, 0.0)
- .add({{"xxx"}, {1_z}, {"xxx"}, {1_z}}, 0.0)
- .add({{"xxx"}, {1_z}, {"yyy"}, {0_z}}, 0.0)
- .add({{"xxx"}, {1_z}, {"yyy"}, {1_z}}, 0.0)
- .add({{"yyy"}, {0_z}, {"xxx"}, {0_z}}, 0.0)
- .add({{"yyy"}, {0_z}, {"xxx"}, {1_z}}, 0.0)
- .add({{"yyy"}, {0_z}, {"yyy"}, {0_z}}, 0.0)
- .add({{"yyy"}, {0_z}, {"yyy"}, {1_z}}, 0.0)
- .add({{"yyy"}, {1_z}, {"xxx"}, {0_z}}, 3.0)
- .add({{"yyy"}, {1_z}, {"xxx"}, {1_z}}, 0.0)
- .add({{"yyy"}, {1_z}, {"yyy"}, {0_z}}, 0.0)
- .add({{"yyy"}, {1_z}, {"yyy"}, {1_z}}, 4.0)
- .build());
- EXPECT_EQUAL(expect_tensor, *tensor);
- EXPECT_EQUAL(expect_tensor, *full_tensor);
- EXPECT_EQUAL(full_spec, tensor->engine().to_spec(*tensor));
+ EXPECT_EQUAL(full_spec, to_spec(*tensor));
+ EXPECT_EQUAL(full_spec, to_spec(*full_tensor));
};
TEST("require that simple tensors can have their values negated") {
@@ -92,10 +71,10 @@ TEST("require that simple tensors can have their values negated") {
.add({{"x","2"},{"y","1"}}, 3)
.add({{"x","1"},{"y","2"}}, -5));
auto result = tensor->map([](double a){ return -a; });
- EXPECT_EQUAL(*expect, *result);
+ EXPECT_EQUAL(to_spec(*expect), to_spec(*result));
Stash stash;
const Value &result2 = SimpleTensorEngine::ref().map(TensorValue(*tensor), operation::Neg::f, stash);
- EXPECT_EQUAL(*expect, unwrap(result2));
+ EXPECT_EQUAL(to_spec(*expect), to_spec(unwrap(result2)));
}
TEST("require that simple tensors can be multiplied with each other") {
@@ -117,10 +96,10 @@ TEST("require that simple tensors can be multiplied with each other") {
.add({{"x","2"},{"y","1"},{"z","2"}}, 39)
.add({{"x","1"},{"y","2"},{"z","1"}}, 55));
auto result = SimpleTensor::join(*lhs, *rhs, [](double a, double b){ return (a * b); });
- EXPECT_EQUAL(*expect, *result);
+ EXPECT_EQUAL(to_spec(*expect), to_spec(*result));
Stash stash;
const Value &result2 = SimpleTensorEngine::ref().join(TensorValue(*lhs), TensorValue(*rhs), operation::Mul::f, stash);
- EXPECT_EQUAL(*expect, unwrap(result2));
+ EXPECT_EQUAL(to_spec(*expect), to_spec(unwrap(result2)));
}
TEST("require that simple tensors support dimension reduction") {
@@ -147,21 +126,21 @@ TEST("require that simple tensors support dimension reduction") {
auto result_sum_y = tensor->reduce(aggr_sum, {"y"});
auto result_sum_x = tensor->reduce(aggr_sum, {"x"});
auto result_sum_all = tensor->reduce(aggr_sum, {"x", "y"});
- EXPECT_EQUAL(*expect_sum_y, *result_sum_y);
- EXPECT_EQUAL(*expect_sum_x, *result_sum_x);
- EXPECT_EQUAL(*expect_sum_all, *result_sum_all);
+ EXPECT_EQUAL(to_spec(*expect_sum_y), to_spec(*result_sum_y));
+ EXPECT_EQUAL(to_spec(*expect_sum_x), to_spec(*result_sum_x));
+ EXPECT_EQUAL(to_spec(*expect_sum_all), to_spec(*result_sum_all));
const Value &result_sum_y_2 = SimpleTensorEngine::ref().reduce(TensorValue(*tensor), Aggr::SUM, {"y"}, stash);
const Value &result_sum_x_2 = SimpleTensorEngine::ref().reduce(TensorValue(*tensor), Aggr::SUM, {"x"}, stash);
const Value &result_sum_all_2 = SimpleTensorEngine::ref().reduce(TensorValue(*tensor), Aggr::SUM, {"x", "y"}, stash);
const Value &result_sum_all_3 = SimpleTensorEngine::ref().reduce(TensorValue(*tensor), Aggr::SUM, {}, stash);
- EXPECT_EQUAL(*expect_sum_y, unwrap(result_sum_y_2));
- EXPECT_EQUAL(*expect_sum_x, unwrap(result_sum_x_2));
+ EXPECT_EQUAL(to_spec(*expect_sum_y), to_spec(unwrap(result_sum_y_2)));
+ EXPECT_EQUAL(to_spec(*expect_sum_x), to_spec(unwrap(result_sum_x_2)));
EXPECT_TRUE(result_sum_all_2.is_double());
EXPECT_TRUE(result_sum_all_3.is_double());
EXPECT_EQUAL(21, result_sum_all_2.as_double());
EXPECT_EQUAL(21, result_sum_all_3.as_double());
- EXPECT_EQUAL(*result_sum_y, *result_sum_y);
- EXPECT_NOT_EQUAL(*result_sum_y, *result_sum_x);
+ EXPECT_EQUAL(to_spec(*result_sum_y), to_spec(*result_sum_y));
+ EXPECT_NOT_EQUAL(to_spec(*result_sum_y), to_spec(*result_sum_x));
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
index 5b2d0848f64..8bd86621bf6 100644
--- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
+++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
@@ -102,7 +102,9 @@ void verify_equal(const Tensor &expect, const Value &value) {
const Tensor *tensor = value.as_tensor();
ASSERT_TRUE(tensor != nullptr);
ASSERT_EQUAL(&expect.engine(), &tensor->engine());
- EXPECT_TRUE(expect.engine().equal(expect, *tensor));
+ auto expect_spec = expect.engine().to_spec(expect);
+ auto value_spec = tensor->engine().to_spec(*tensor);
+ EXPECT_EQUAL(expect_spec, value_spec);
}
TEST("require that tensor injection works") {
diff --git a/eval/src/tests/eval/value_cache/tensor_loader_test.cpp b/eval/src/tests/eval/value_cache/tensor_loader_test.cpp
index 20a77eb9fe3..ee8e502815f 100644
--- a/eval/src/tests/eval/value_cache/tensor_loader_test.cpp
+++ b/eval/src/tests/eval/value_cache/tensor_loader_test.cpp
@@ -42,8 +42,8 @@ std::unique_ptr<Tensor> make_mixed_tensor() {
void verify_tensor(std::unique_ptr<Tensor> expect, ConstantValue::UP actual) {
const auto &engine = expect->engine();
ASSERT_EQUAL(engine.type_of(*expect), actual->type());
- EXPECT_TRUE(&engine == &actual->value().as_tensor()->engine());
- EXPECT_TRUE(engine.equal(*expect, *actual->value().as_tensor()));
+ ASSERT_TRUE(&engine == &actual->value().as_tensor()->engine());
+ EXPECT_EQUAL(engine.to_spec(*expect), engine.to_spec(*actual->value().as_tensor()));
}
TEST_F("require that invalid types loads an empty double", ConstantTensorLoader(SimpleTensorEngine::ref())) {
diff --git a/eval/src/vespa/eval/eval/simple_tensor.cpp b/eval/src/vespa/eval/eval/simple_tensor.cpp
index 75c170d48ba..e39e926708d 100644
--- a/eval/src/vespa/eval/eval/simple_tensor.cpp
+++ b/eval/src/vespa/eval/eval/simple_tensor.cpp
@@ -611,33 +611,6 @@ SimpleTensor::create(const TensorSpec &spec)
return builder.build();
}
-bool
-SimpleTensor::equal(const SimpleTensor &a, const SimpleTensor &b)
-{
- if (a.type() != b.type()) {
- return false;
- }
- TypeAnalyzer type_info(a.type(), b.type());
- View view_a(a, type_info.overlap_a);
- View view_b(b, type_info.overlap_b);
- const CellRef *pos_a = view_a.refs_begin();
- const CellRef *end_a = view_a.refs_end();
- const CellRef *pos_b = view_b.refs_begin();
- const CellRef *end_b = view_b.refs_end();
- ViewMatcher::CrossCompare cmp(view_a.selector(), view_b.selector());
- while ((pos_a != end_a) && (pos_b != end_b)) {
- if (cmp.compare(pos_a->get(), pos_b->get()) != ViewMatcher::CrossCompare::Result::EQUAL) {
- return false;
- }
- if (pos_a->get().value != pos_b->get().value) {
- return false;
- }
- ++pos_a;
- ++pos_b;
- }
- return ((pos_a == end_a) && (pos_b == end_b));
-}
-
std::unique_ptr<SimpleTensor>
SimpleTensor::join(const SimpleTensor &a, const SimpleTensor &b, join_fun_t function)
{
diff --git a/eval/src/vespa/eval/eval/simple_tensor.h b/eval/src/vespa/eval/eval/simple_tensor.h
index ec154ff969a..366796f00d8 100644
--- a/eval/src/vespa/eval/eval/simple_tensor.h
+++ b/eval/src/vespa/eval/eval/simple_tensor.h
@@ -88,7 +88,6 @@ public:
std::unique_ptr<SimpleTensor> reduce(Aggregator &aggr, const std::vector<vespalib::string> &dimensions) const;
std::unique_ptr<SimpleTensor> rename(const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to) const;
static std::unique_ptr<SimpleTensor> create(const TensorSpec &spec);
- static bool equal(const SimpleTensor &a, const SimpleTensor &b);
static std::unique_ptr<SimpleTensor> join(const SimpleTensor &a, const SimpleTensor &b, join_fun_t function);
static std::unique_ptr<SimpleTensor> concat(const SimpleTensor &a, const SimpleTensor &b, const vespalib::string &dimension);
static void encode(const SimpleTensor &tensor, nbostream &output);
diff --git a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp
index d69715cab22..21498ca2ff1 100644
--- a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp
+++ b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp
@@ -47,12 +47,6 @@ SimpleTensorEngine::type_of(const Tensor &tensor) const
return to_simple(tensor).type();
}
-bool
-SimpleTensorEngine::equal(const Tensor &a, const Tensor &b) const
-{
- return SimpleTensor::equal(to_simple(a), to_simple(b));
-}
-
vespalib::string
SimpleTensorEngine::to_string(const Tensor &tensor) const
{
diff --git a/eval/src/vespa/eval/eval/simple_tensor_engine.h b/eval/src/vespa/eval/eval/simple_tensor_engine.h
index bc6d0166bd1..c751f2f6b49 100644
--- a/eval/src/vespa/eval/eval/simple_tensor_engine.h
+++ b/eval/src/vespa/eval/eval/simple_tensor_engine.h
@@ -20,7 +20,6 @@ public:
static const TensorEngine &ref() { return _engine; };
ValueType type_of(const Tensor &tensor) const override;
- bool equal(const Tensor &a, const Tensor &b) const override;
vespalib::string to_string(const Tensor &tensor) const override;
TensorSpec to_spec(const Tensor &tensor) const override;
diff --git a/eval/src/vespa/eval/eval/tensor.cpp b/eval/src/vespa/eval/eval/tensor.cpp
index ed50d33de9b..926606f8e26 100644
--- a/eval/src/vespa/eval/eval/tensor.cpp
+++ b/eval/src/vespa/eval/eval/tensor.cpp
@@ -2,6 +2,7 @@
#include "tensor.h"
#include "tensor_engine.h"
+#include "tensor_spec.h"
namespace vespalib {
namespace eval {
@@ -9,7 +10,9 @@ namespace eval {
bool
operator==(const Tensor &lhs, const Tensor &rhs)
{
- return ((&lhs.engine() == &rhs.engine()) && lhs.engine().equal(lhs, rhs));
+ auto lhs_spec = lhs.engine().to_spec(lhs);
+ auto rhs_spec = rhs.engine().to_spec(rhs);
+ return (lhs_spec == rhs_spec);
}
std::ostream &
diff --git a/eval/src/vespa/eval/eval/tensor_engine.h b/eval/src/vespa/eval/eval/tensor_engine.h
index d33c1ba0ed2..00927f0c1b1 100644
--- a/eval/src/vespa/eval/eval/tensor_engine.h
+++ b/eval/src/vespa/eval/eval/tensor_engine.h
@@ -41,7 +41,6 @@ struct TensorEngine
using Aggr = eval::Aggr;
virtual ValueType type_of(const Tensor &tensor) const = 0;
- virtual bool equal(const Tensor &a, const Tensor &b) const = 0;
virtual vespalib::string to_string(const Tensor &tensor) const = 0;
virtual TensorSpec to_spec(const Tensor &tensor) const = 0;
diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
index 617aa75c945..2a7253454ff 100644
--- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
+++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
@@ -395,63 +395,6 @@ struct TestContext {
//-------------------------------------------------------------------------
- void verify_equal(const TensorSpec &a, const TensorSpec &b) {
- auto ta = tensor(a);
- auto tb = tensor(b);
- EXPECT_EQUAL(a, b);
- EXPECT_EQUAL(*ta, *tb);
- TensorSpec spec = engine.to_spec(*ta);
- TensorSpec ref_spec = ref_engine.to_spec(*ref_engine.create(a));
- EXPECT_EQUAL(spec, ref_spec);
- }
-
- void test_tensor_equality() {
- TEST_DO(verify_equal(spec(), spec()));
- TEST_DO(verify_equal(spec(10.0), spec(10.0)));
- TEST_DO(verify_equal(spec(x()), spec(x())));
- TEST_DO(verify_equal(spec(x({"a"}), Seq({1})), spec(x({"a"}), Seq({1}))));
- TEST_DO(verify_equal(spec({x({"a"}),y({"a"})}, Seq({1})), spec({y({"a"}),x({"a"})}, Seq({1}))));
- TEST_DO(verify_equal(spec(x(3)), spec(x(3))));
- TEST_DO(verify_equal(spec({x(1),y(1)}, Seq({1})), spec({y(1),x(1)}, Seq({1}))));
- TEST_DO(verify_equal(spec({x({"a"}),y(1)}, Seq({1})), spec({y(1),x({"a"})}, Seq({1}))));
- TEST_DO(verify_equal(spec({y({"a"}),x(1)}, Seq({1})), spec({x(1),y({"a"})}, Seq({1}))));
- }
-
- //-------------------------------------------------------------------------
-
- void verify_not_equal(const TensorSpec &a, const TensorSpec &b) {
- auto ta = tensor(a);
- auto tb = tensor(b);
- EXPECT_NOT_EQUAL(a, b);
- EXPECT_NOT_EQUAL(b, a);
- EXPECT_NOT_EQUAL(*ta, *tb);
- EXPECT_NOT_EQUAL(*tb, *ta);
- }
-
- void test_tensor_inequality() {
- TEST_DO(verify_not_equal(spec(1.0), spec(2.0)));
- TEST_DO(verify_not_equal(spec(), spec(x())));
- TEST_DO(verify_not_equal(spec(), spec(x(1))));
- TEST_DO(verify_not_equal(spec(x()), spec(x(1))));
- TEST_DO(verify_not_equal(spec(x()), spec(y())));
- TEST_DO(verify_not_equal(spec(x(1)), spec(x(2))));
- TEST_DO(verify_not_equal(spec(x(1)), spec(y(1))));
- TEST_DO(verify_not_equal(spec(x({"a"}), Seq({1})), spec(x({"a"}), Seq({2}))));
- TEST_DO(verify_not_equal(spec(x({"a"}), Seq({1})), spec(x({"b"}), Seq({1}))));
- TEST_DO(verify_not_equal(spec(x({"a"}), Seq({1})), spec({x({"a"}),y({"a"})}, Seq({1}))));
- TEST_DO(verify_not_equal(spec(x(1), Seq({1})), spec(x(1), Seq({2}))));
- TEST_DO(verify_not_equal(spec(x(1), Seq({1})), spec(x(2), Seq({1}), Bits({1,0}))));
- TEST_DO(verify_not_equal(spec(x(2), Seq({1,1}), Bits({1,0})),
- spec(x(2), Seq({1,1}), Bits({0,1}))));
- TEST_DO(verify_not_equal(spec(x(1), Seq({1})), spec({x(1),y(1)}, Seq({1}))));
- TEST_DO(verify_not_equal(spec({x({"a"}),y(1)}, Seq({1})), spec({x({"a"}),y(1)}, Seq({2}))));
- TEST_DO(verify_not_equal(spec({x({"a"}),y(1)}, Seq({1})), spec({x({"b"}),y(1)}, Seq({1}))));
- TEST_DO(verify_not_equal(spec({x(2),y({"a"})}, Seq({1}), Bits({1,0})),
- spec({x(2),y({"a"})}, Seq({X,1}), Bits({0,1}))));
- }
-
- //-------------------------------------------------------------------------
-
void verify_reduce_result(const Eval &eval, const TensorSpec &a, const Eval::Result &expect) {
TEST_DO(verify_result(eval.eval(engine, a), expect));
}
@@ -989,8 +932,6 @@ struct TestContext {
void run_tests() {
TEST_DO(test_tensor_create_type());
- TEST_DO(test_tensor_equality());
- TEST_DO(test_tensor_inequality());
TEST_DO(test_tensor_reduce());
TEST_DO(test_tensor_map());
TEST_DO(test_tensor_apply());
diff --git a/eval/src/vespa/eval/eval/value.cpp b/eval/src/vespa/eval/eval/value.cpp
index 0118d95e5cb..456d80c0ff0 100644
--- a/eval/src/vespa/eval/eval/value.cpp
+++ b/eval/src/vespa/eval/eval/value.cpp
@@ -14,12 +14,6 @@ TensorValue::as_double() const
return _tensor->as_double();
}
-bool
-TensorValue::equal(const Value &rhs) const
-{
- return (rhs.is_tensor() && _tensor->engine().equal(*_tensor, *rhs.as_tensor()));
-}
-
ValueType
TensorValue::type() const
{
diff --git a/eval/src/vespa/eval/eval/value.h b/eval/src/vespa/eval/eval/value.h
index 0d727db6b91..8826faed140 100644
--- a/eval/src/vespa/eval/eval/value.h
+++ b/eval/src/vespa/eval/eval/value.h
@@ -27,7 +27,6 @@ struct Value {
virtual double as_double() const { return 0.0; }
virtual bool as_bool() const { return false; }
virtual const Tensor *as_tensor() const { return nullptr; }
- virtual bool equal(const Value &rhs) const = 0;
virtual ValueType type() const = 0;
virtual ~Value() {}
};
@@ -36,7 +35,6 @@ struct ErrorValue : public Value {
static ErrorValue instance;
bool is_error() const override { return true; }
double as_double() const override { return error_value; }
- bool equal(const Value &) const override { return false; }
ValueType type() const override { return ValueType::error_type(); }
};
@@ -49,9 +47,6 @@ public:
bool is_double() const override { return true; }
double as_double() const override { return _value; }
bool as_bool() const override { return (_value != 0.0); }
- bool equal(const Value &rhs) const override {
- return (rhs.is_double() && (_value == rhs.as_double()));
- }
ValueType type() const override { return ValueType::double_type(); }
};
@@ -66,7 +61,6 @@ public:
bool is_tensor() const override { return true; }
double as_double() const override;
const Tensor *as_tensor() const override { return _tensor; }
- bool equal(const Value &rhs) const override;
ValueType type() const override;
};
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
index 2082b7efd25..7adb95f69ca 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
@@ -97,16 +97,6 @@ DefaultTensorEngine::type_of(const Tensor &tensor) const
return my_tensor.getType();
}
-bool
-DefaultTensorEngine::equal(const Tensor &a, const Tensor &b) const
-{
- assert(&a.engine() == this);
- assert(&b.engine() == this);
- const tensor::Tensor &my_a = static_cast<const tensor::Tensor &>(a);
- const tensor::Tensor &my_b = static_cast<const tensor::Tensor &>(b);
- return my_a.equals(my_b);
-}
-
vespalib::string
DefaultTensorEngine::to_string(const Tensor &tensor) const
{
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.h b/eval/src/vespa/eval/tensor/default_tensor_engine.h
index abdce6edb62..bbb03aceb1f 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.h
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.h
@@ -20,7 +20,6 @@ public:
static const TensorEngine &ref() { return _engine; };
ValueType type_of(const Tensor &tensor) const override;
- bool equal(const Tensor &a, const Tensor &b) const override;
vespalib::string to_string(const Tensor &tensor) const override;
TensorSpec to_spec(const Tensor &tensor) const override;
diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
index a407a46610b..534854732c7 100644
--- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
@@ -4,6 +4,7 @@
#include "tensor_address_builder.h"
#include "tensor_visitor.h"
#include <vespa/eval/eval/simple_tensor_engine.h>
+#include <vespa/eval/eval/tensor_spec.h>
#include <vespa/vespalib/util/stringfmt.h>
namespace vespalib::tensor {
@@ -11,10 +12,9 @@ namespace vespalib::tensor {
bool
WrappedSimpleTensor::equals(const Tensor &arg) const
{
- if (auto other = dynamic_cast<const WrappedSimpleTensor *>(&arg)) {
- return eval::SimpleTensor::equal(_tensor, other->_tensor);
- }
- return false;
+ auto lhs_spec = _tensor.engine().to_spec(_tensor);
+ auto rhs_spec = arg.engine().to_spec(arg);
+ return (lhs_spec == rhs_spec);
}
vespalib::string