diff options
22 files changed, 418 insertions, 55 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index e0c6f8572e0..e7283849178 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -170,6 +170,11 @@ std::unique_ptr<Tensor> createTensorWith2Cells() { {{{"x", "9"}, {"y", "9"}}, 11} }, {"x", "y"}); } +std::unique_ptr<Tensor> createExpectedUpdatedTensorWith2Cells() { + return createTensor({ {{{"x", "8"}, {"y", "9"}}, 2}, + {{{"x", "9"}, {"y", "9"}}, 11} }, {"x", "y"}); +} + FieldValue::UP createTensorFieldValueWith2Cells() { auto fv(std::make_unique<TensorFieldValue>()); *fv = createTensorWith2Cells(); @@ -953,7 +958,8 @@ DocumentUpdateTest::testTensorModifyUpdate() TestDocMan docMan; Document::UP doc(docMan.createDocument()); Document updated(*doc); - updated.setValue(updated.getField("tensor"), *createTensorFieldValueWith2Cells()); + auto oldTensor = createTensorFieldValueWith2Cells(); + updated.setValue(updated.getField("tensor"), *oldTensor); CPPUNIT_ASSERT(*doc != updated); testValueUpdate(*createTensorModifyUpdate(), *DataType::TENSOR); DocumentUpdate upd(docMan.getTypeRepo(), *doc->getDataType(), doc->getId()); @@ -962,9 +968,8 @@ DocumentUpdateTest::testTensorModifyUpdate() FieldValue::UP fval(updated.getValue("tensor")); CPPUNIT_ASSERT(fval); auto &tensor = asTensor(*fval); - // TODO: Check that tensor is correctly modified. - // For now, value is unchanged. - CPPUNIT_ASSERT(tensor.equals(*createTensorWith2Cells())); + auto expectedUpdatedTensor = createExpectedUpdatedTensorWith2Cells(); + CPPUNIT_ASSERT(tensor.equals(*expectedUpdatedTensor)); } void diff --git a/document/src/vespa/document/update/tensormodifyupdate.cpp b/document/src/vespa/document/update/tensormodifyupdate.cpp index 87da385a57a..5529adaf5ce 100644 --- a/document/src/vespa/document/update/tensormodifyupdate.cpp +++ b/document/src/vespa/document/update/tensormodifyupdate.cpp @@ -7,7 +7,9 @@ #include <vespa/document/fieldvalue/tensorfieldvalue.h> #include <vespa/document/util/serializableexceptions.h> #include <vespa/document/serialization/vespadocumentdeserializer.h> +#include <vespa/eval/eval/operation.h> #include <vespa/eval/tensor/tensor.h> +#include <vespa/eval/tensor/cell_values.h> #include <vespa/vespalib/objects/nbostream.h> #include <vespa/vespalib/stllike/asciistream.h> #include <vespa/vespalib/util/stringfmt.h> @@ -19,8 +21,37 @@ using vespalib::IllegalStateException; using vespalib::tensor::Tensor; using vespalib::make_string; +using join_fun_t = double (*)(double, double); + namespace document { +namespace { + +double +replace(double, double b) +{ + return b; +} + +join_fun_t +getJoinFunction(TensorModifyUpdate::Operation operation) +{ + using Operation = TensorModifyUpdate::Operation; + + switch (operation) { + case Operation::REPLACE: + return replace; + case Operation::ADD: + return vespalib::eval::operation::Add::f; + case Operation::MUL: + return vespalib::eval::operation::Mul::f; + default: + throw IllegalArgumentException("Bad operation", VESPA_STRLOC); + } +} + +} + IMPLEMENT_IDENTIFIABLE(TensorModifyUpdate, ValueUpdate); TensorModifyUpdate::TensorModifyUpdate() @@ -92,9 +123,12 @@ TensorModifyUpdate::applyTo(FieldValue& value) const if (value.inherits(TensorFieldValue::classId)) { TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value); auto &oldTensor = tensorFieldValue.getAsTensorPtr(); - // TODO: Apply operation with tensor - auto newTensor = oldTensor->clone(); - tensorFieldValue = std::move(newTensor); + auto &cellTensor = _tensor->getAsTensorPtr(); + if (cellTensor) { + vespalib::tensor::CellValues cellValues(static_cast<const vespalib::tensor::SparseTensor &>(*cellTensor)); + auto newTensor = oldTensor->modify(getJoinFunction(_operation), cellValues); + tensorFieldValue = std::move(newTensor); + } } else { std::string err = make_string( "Unable to perform a tensor modify update on a \"%s\" field " diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt index 22479952270..a5bb6fec31e 100644 --- a/eval/CMakeLists.txt +++ b/eval/CMakeLists.txt @@ -39,6 +39,7 @@ vespa_define_module( src/tests/tensor/tensor_address src/tests/tensor/tensor_conformance src/tests/tensor/tensor_mapper + src/tests/tensor/tensor_modify src/tests/tensor/tensor_performance src/tests/tensor/tensor_serialization src/tests/tensor/tensor_slime_serialization diff --git a/eval/src/tests/tensor/tensor_modify/CMakeLists.txt b/eval/src/tests/tensor/tensor_modify/CMakeLists.txt new file mode 100644 index 00000000000..2d4055db7e2 --- /dev/null +++ b/eval/src/tests/tensor/tensor_modify/CMakeLists.txt @@ -0,0 +1,8 @@ +# Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(eval_tensor_modify_test_app TEST + SOURCES + tensor_modify_test.cpp + DEPENDS + vespaeval +) +vespa_add_test(NAME eval_tensor_modify_test_app COMMAND eval_tensor_modify_test_app) diff --git a/eval/src/tests/tensor/tensor_modify/tensor_modify_test.cpp b/eval/src/tests/tensor/tensor_modify/tensor_modify_test.cpp new file mode 100644 index 00000000000..67dc033fb80 --- /dev/null +++ b/eval/src/tests/tensor/tensor_modify/tensor_modify_test.cpp @@ -0,0 +1,68 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/testkit/test_kit.h> +#include <vespa/vespalib/util/stringfmt.h> +#include <vespa/eval/tensor/cell_values.h> +#include <vespa/eval/tensor/sparse/sparse_tensor.h> +#include <vespa/eval/tensor/default_tensor_engine.h> +#include <vespa/eval/eval/tensor_spec.h> + +using vespalib::eval::Value; +using vespalib::eval::TensorSpec; +using namespace vespalib::tensor; + +namespace { + +double +replace(double, double b) +{ + return b; +} + +template <typename Tensor> +const Tensor *asTensor(Value &value) +{ + auto *tensor = dynamic_cast<const Tensor *>(value.as_tensor()); + ASSERT_TRUE(tensor); + return tensor; +} + +} + +void checkUpdate(const TensorSpec &source, const TensorSpec &update, const TensorSpec &expect) { + auto sourceValue = DefaultTensorEngine::ref().from_spec(source); + auto sourceTensor = asTensor<Tensor>(*sourceValue); + auto updateValue = DefaultTensorEngine::ref().from_spec(update); + auto updateTensor = asTensor<SparseTensor>(*updateValue); + const CellValues cellValues(*updateTensor); + auto actualTensor = sourceTensor->modify(replace, cellValues); + auto actual = actualTensor->toSpec(); + auto expectValue = DefaultTensorEngine::ref().from_spec(expect); + auto expectTensor = asTensor<Tensor>(*expectValue); + auto expectPadded = expectTensor->toSpec(); + EXPECT_EQUAL(actual, expectPadded); +} + +TEST("require that sparse tensors can be modified") { + checkUpdate(TensorSpec("tensor(x{},y{})") + .add({{"x","8"},{"y","9"}}, 11) + .add({{"x","9"},{"y","9"}}, 11), + TensorSpec("tensor(x{},y{})") + .add({{"x","8"},{"y","9"}}, 2), + TensorSpec("tensor(x{},y{})") + .add({{"x","8"},{"y","9"}}, 2) + .add({{"x","9"},{"y","9"}}, 11)); +} + +TEST("require that dense tensors can be modified") { + checkUpdate(TensorSpec("tensor(x[10],y[10])") + .add({{"x",8},{"y",9}}, 11) + .add({{"x",9},{"y",9}}, 11), + TensorSpec("tensor(x{},y{})") + .add({{"x","8"},{"y","9"}}, 2), + TensorSpec("tensor(x[10],y[10])") + .add({{"x",8},{"y",9}}, 2) + .add({{"x",9},{"y",9}}, 11)); +} + +TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/vespa/eval/tensor/cell_values.h b/eval/src/vespa/eval/tensor/cell_values.h new file mode 100644 index 00000000000..dabdcde3294 --- /dev/null +++ b/eval/src/vespa/eval/tensor/cell_values.h @@ -0,0 +1,28 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/eval/tensor/tensor_visitor.h> +#include <vespa/eval/tensor/sparse/sparse_tensor.h> + +namespace vespalib::tensor { + +/* + * A collection of tensor cells, used as argument for modifying a subset + * of cells in a tensor. + */ +class CellValues { + const SparseTensor &_tensor; + +public: + CellValues(const SparseTensor &tensor) + : _tensor(tensor) + { + } + + void accept(TensorVisitor &visitor) const { + _tensor.accept(visitor); + } +}; + +} diff --git a/eval/src/vespa/eval/tensor/dense/CMakeLists.txt b/eval/src/vespa/eval/tensor/dense/CMakeLists.txt index 2cca7f9b6d0..9d54dd26763 100644 --- a/eval/src/vespa/eval/tensor/dense/CMakeLists.txt +++ b/eval/src/vespa/eval/tensor/dense/CMakeLists.txt @@ -10,8 +10,10 @@ vespa_add_library(eval_tensor_dense OBJECT dense_replace_type_function.cpp dense_tensor.cpp dense_tensor_address_combiner.cpp + dense_tensor_address_mapper.cpp dense_tensor_builder.cpp dense_tensor_cells_iterator.cpp + dense_tensor_modify.cpp dense_tensor_reduce.cpp dense_tensor_view.cpp dense_xw_product_function.cpp diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_mapper.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_mapper.cpp new file mode 100644 index 00000000000..09bc546c982 --- /dev/null +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_mapper.cpp @@ -0,0 +1,48 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "dense_tensor_address_mapper.h" +#include <vespa/eval/eval/value_type.h> +#include <vespa/eval/tensor/types.h> +#include <vespa/eval/tensor/tensor_address.h> +#include <vespa/eval/tensor/tensor_address_element_iterator.h> + +namespace vespalib::tensor { + +uint32_t +DenseTensorAddressMapper::mapLabelToNumber(stringref label) +{ + uint32_t result = 0; + for (char c : label) { + if (c < '0' || c > '9') { + return BAD_LABEL; // bad char + } + result = result * 10 + (c - '0'); + if (result > 100000000) { + return BAD_LABEL; // overflow + } + } + return result; +} + +uint32_t +DenseTensorAddressMapper::mapAddressToIndex(const TensorAddress &address, const eval::ValueType type) +{ + uint32_t idx = 0; + TensorAddressElementIterator<TensorAddress> addressIterator(address); + for (const auto &dimension : type.dimensions()) { + if (addressIterator.skipToDimension(dimension.name)) { + uint32_t label = mapLabelToNumber(addressIterator.label()); + if (label == BAD_LABEL || label >= dimension.size) { + return BAD_ADDRESS; + } + idx = idx * dimension.size + label; + addressIterator.next(); + } else { + // output dimension not in input + idx = idx * dimension.size; + } + } + return idx; +} + +} diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_mapper.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_mapper.h new file mode 100644 index 00000000000..cf20fc6ad3c --- /dev/null +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_mapper.h @@ -0,0 +1,27 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <limits> +#include <cstdint> + +namespace vespalib::eval { class ValueType; } +namespace vespalib { class stringref; } + +namespace vespalib::tensor { + +class TensorAddress; + +/** + * Utility class for mapping of tensor adress to index + */ +class DenseTensorAddressMapper +{ +public: + static constexpr uint32_t BAD_LABEL = std::numeric_limits<uint32_t>::max(); + static constexpr uint32_t BAD_ADDRESS = std::numeric_limits<uint32_t>::max(); + static uint32_t mapLabelToNumber(stringref label); + static uint32_t mapAddressToIndex(const TensorAddress &address, const eval::ValueType type); +}; + +} diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_modify.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_modify.cpp new file mode 100644 index 00000000000..4e2940f2516 --- /dev/null +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_modify.cpp @@ -0,0 +1,33 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "dense_tensor_modify.h" +#include "dense_tensor_address_mapper.h" +#include "dense_tensor.h" + +namespace vespalib::tensor { + +DenseTensorModify::DenseTensorModify(join_fun_t op, const eval::ValueType &type, Cells cells) + : _op(op), + _type(type), + _cells(std::move(cells)) +{ +} + +DenseTensorModify::~DenseTensorModify() = default; + +void +DenseTensorModify::visit(const TensorAddress &address, double value) +{ + uint32_t idx = DenseTensorAddressMapper::mapAddressToIndex(address, _type); + if (idx != DenseTensorAddressMapper::BAD_ADDRESS) { + _cells[idx] = _op(_cells[idx], value); + } +} + +std::unique_ptr<Tensor> +DenseTensorModify::build() +{ + return std::make_unique<DenseTensor>(std::move(_type), std::move(_cells)); +} + +} diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_modify.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_modify.h new file mode 100644 index 00000000000..848e6e559c2 --- /dev/null +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_modify.h @@ -0,0 +1,30 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/eval/tensor/tensor_visitor.h> +#include "dense_tensor_view.h" + +namespace vespalib::tensor { + +/* + * This class handles tensor modify update on a dense tensor. + * For all cells visited, a join function is applied to determine + * the new cell value. + */ +class DenseTensorModify : public TensorVisitor +{ + using join_fun_t = Tensor::join_fun_t; + using Cells = DenseTensorView::Cells; + join_fun_t _op; + eval::ValueType _type; + Cells _cells; + +public: + DenseTensorModify(join_fun_t op, const eval::ValueType &type, Cells cells); + ~DenseTensorModify(); + void visit(const TensorAddress &address, double value) override; + std::unique_ptr<Tensor> build(); +}; + +} diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp index f2ddcb38698..f758b0ec915 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp @@ -3,9 +3,11 @@ #include "dense_tensor_view.h" #include "dense_tensor_apply.hpp" #include "dense_tensor_reduce.hpp" +#include "dense_tensor_modify.h" #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/util/exceptions.h> #include <vespa/vespalib/stllike/asciistream.h> +#include <vespa/eval/tensor/cell_values.h> #include <vespa/eval/tensor/tensor_address_builder.h> #include <vespa/eval/tensor/tensor_visitor.h> #include <vespa/eval/eval/operation.h> @@ -280,4 +282,12 @@ DenseTensorView::reduce(join_fun_t op, const std::vector<vespalib::string> &dime : reduce_all(op, dimensions); } +std::unique_ptr<Tensor> +DenseTensorView::modify(join_fun_t op, const CellValues &cellValues) const +{ + DenseTensorModify modifier(op, _typeRef, Cells(_cellsRef.cbegin(), _cellsRef.cend())); + cellValues.accept(modifier); + return modifier.build(); +} + } diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h index 7104a1f86f0..5aedcf6fb8d 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h @@ -53,6 +53,7 @@ public: Tensor::UP apply(const CellFunction &func) const override; Tensor::UP join(join_fun_t function, const Tensor &arg) const override; Tensor::UP reduce(join_fun_t op, const std::vector<vespalib::string> &dimensions) const override; + std::unique_ptr<Tensor> modify(join_fun_t op, const CellValues &cellValues) const override; bool equals(const Tensor &arg) const override; Tensor::UP clone() const override; eval::TensorSpec toSpec() const override; diff --git a/eval/src/vespa/eval/tensor/sparse/CMakeLists.txt b/eval/src/vespa/eval/tensor/sparse/CMakeLists.txt index 4cd682d8953..db31b08e3e4 100644 --- a/eval/src/vespa/eval/tensor/sparse/CMakeLists.txt +++ b/eval/src/vespa/eval/tensor/sparse/CMakeLists.txt @@ -6,6 +6,7 @@ vespa_add_library(eval_tensor_sparse OBJECT sparse_tensor_address_padder.cpp sparse_tensor_address_reducer.cpp sparse_tensor_match.cpp + sparse_tensor_modify.cpp sparse_tensor_builder.cpp sparse_tensor_unsorted_address_builder.cpp ) diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp index 1f57e297ea1..dd2befd4df8 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp @@ -5,6 +5,8 @@ #include "sparse_tensor_match.h" #include "sparse_tensor_apply.hpp" #include "sparse_tensor_reduce.hpp" +#include "sparse_tensor_modify.h" +#include <vespa/eval/tensor/cell_values.h> #include <vespa/eval/tensor/tensor_address_builder.h> #include <vespa/eval/tensor/tensor_apply.h> #include <vespa/eval/tensor/tensor_visitor.h> @@ -186,6 +188,17 @@ SparseTensor::reduce(join_fun_t op, return sparse::reduce(*this, dimensions, op); } +std::unique_ptr<Tensor> +SparseTensor::modify(join_fun_t op, const CellValues &cellValues) const +{ + Stash stash; + Cells cells; + copyCells(cells, _cells, stash); + SparseTensorModify modifier(op, _type, std::move(stash), std::move(cells)); + cellValues.accept(modifier); + return modifier.build(); +} + } VESPALIB_HASH_MAP_INSTANTIATE_H_E_M(vespalib::tensor::SparseTensorAddressRef, double, vespalib::hash<vespalib::tensor::SparseTensorAddressRef>, diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h index 3eeb122f48c..36a0c246d25 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h @@ -45,6 +45,7 @@ public: Tensor::UP apply(const CellFunction &func) const override; Tensor::UP join(join_fun_t function, const Tensor &arg) const override; Tensor::UP reduce(join_fun_t op, const std::vector<vespalib::string> &dimensions) const override; + std::unique_ptr<Tensor> modify(join_fun_t op, const CellValues &cellValues) const override; bool equals(const Tensor &arg) const override; Tensor::UP clone() const override; eval::TensorSpec toSpec() const override; diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_modify.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_modify.cpp new file mode 100644 index 00000000000..91d0fa3fcdd --- /dev/null +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_modify.cpp @@ -0,0 +1,45 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "sparse_tensor_modify.h" +#include <vespa/eval/tensor/tensor_address_element_iterator.h> + +namespace vespalib::tensor { + +SparseTensorModify::SparseTensorModify(join_fun_t op, const eval::ValueType &type, Stash &&stash, Cells &&cells) + : _op(op), + _type(type), + _stash(std::move(stash)), + _cells(std::move(cells)), + _addressBuilder() +{ +} + +SparseTensorModify::~SparseTensorModify() = default; + +void +SparseTensorModify::visit(const TensorAddress &address, double value) +{ + TensorAddressElementIterator addressElementIterator(address); + + _addressBuilder.clear(); + for (const auto &dimension : _type.dimensions()) { + if (addressElementIterator.skipToDimension(dimension.name)) { + _addressBuilder.add(addressElementIterator.label()); + } else { + _addressBuilder.addUndefined(); + } + } + auto addressRef = _addressBuilder.getAddressRef(); + auto cellItr = _cells.find(addressRef); + if (cellItr != _cells.end()) { + cellItr->second = _op(cellItr->second, value); + } +} + +std::unique_ptr<Tensor> +SparseTensorModify::build() +{ + return std::make_unique<SparseTensor>(std::move(_type), std::move(_cells), std::move(_stash)); +} + +} diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_modify.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_modify.h new file mode 100644 index 00000000000..17a2ad3a2c1 --- /dev/null +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_modify.h @@ -0,0 +1,33 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/eval/tensor/tensor_visitor.h> +#include "sparse_tensor.h" +#include "sparse_tensor_address_builder.h" + +namespace vespalib::tensor { + +/* + * This class handles tensor modify update on a sparse tensor. + * For all cells visited, a join function is applied to determine + * the new cell value. + */ +class SparseTensorModify : public TensorVisitor +{ + using join_fun_t = Tensor::join_fun_t; + using Cells = SparseTensor::Cells; + join_fun_t _op; + eval::ValueType _type; + Stash _stash; + Cells _cells; + SparseTensorAddressBuilder _addressBuilder; + +public: + SparseTensorModify(join_fun_t op, const eval::ValueType &type, Stash &&stash, Cells &&cells); + ~SparseTensorModify(); + void visit(const TensorAddress &address, double value) override; + std::unique_ptr<Tensor> build(); +}; + +} diff --git a/eval/src/vespa/eval/tensor/tensor.h b/eval/src/vespa/eval/tensor/tensor.h index 8e31448e026..842fed436cd 100644 --- a/eval/src/vespa/eval/tensor/tensor.h +++ b/eval/src/vespa/eval/tensor/tensor.h @@ -14,6 +14,7 @@ namespace eval { class BinaryOperation; } namespace tensor { class TensorVisitor; +class CellValues; /** * Interface for operations on a tensor (sparse multi-dimensional array). @@ -33,6 +34,12 @@ struct Tensor : public eval::Tensor virtual Tensor::UP apply(const CellFunction &func) const = 0; virtual Tensor::UP join(join_fun_t function, const Tensor &arg) const = 0; virtual Tensor::UP reduce(join_fun_t op, const std::vector<vespalib::string> &dimensions) const = 0; + /* + * Creates a new tensor by modifying the underlying cells matching + * the given cells applying a join function to determine the new + * cell value. + */ + virtual std::unique_ptr<Tensor> modify(join_fun_t op, const CellValues &cellValues) const = 0; virtual bool equals(const Tensor &arg) const = 0; // want to remove, but needed by document virtual Tensor::UP clone() const = 0; // want to remove, but needed by document virtual eval::TensorSpec toSpec() const = 0; diff --git a/eval/src/vespa/eval/tensor/tensor_mapper.cpp b/eval/src/vespa/eval/tensor/tensor_mapper.cpp index f1039b08816..c91237e4994 100644 --- a/eval/src/vespa/eval/tensor/tensor_mapper.cpp +++ b/eval/src/vespa/eval/tensor/tensor_mapper.cpp @@ -8,6 +8,7 @@ #include "wrapped_simple_tensor.h" #include <vespa/eval/tensor/sparse/direct_sparse_tensor_builder.h> #include <vespa/eval/tensor/dense/dense_tensor.h> +#include <vespa/eval/tensor/dense/dense_tensor_address_mapper.h> #include <vespa/vespalib/stllike/hash_map.hpp> #include <limits> @@ -103,25 +104,6 @@ SparseTensorMapper<TensorT>::map(const Tensor &tensor, //----------------------------------------------------------------------------- -static constexpr uint32_t BAD_LABEL = std::numeric_limits<uint32_t>::max(); -static constexpr uint32_t BAD_ADDRESS = std::numeric_limits<uint32_t>::max(); - -uint32_t mapLabelToNumber(vespalib::stringref label) { - uint32_t result = 0; - for (char c : label) { - if (c < '0' || c > '9') { - return BAD_LABEL; // bad char - } - result = result * 10 + (c - '0'); - if (result > 100000000) { - return BAD_LABEL; // overflow - } - } - return result; -} - -//----------------------------------------------------------------------------- - class TensorTypeMapper : public TensorVisitor { ValueType _type; @@ -148,8 +130,8 @@ TensorTypeMapper::addressOK(const TensorAddress &address) for (const auto &dimension : _type.dimensions()) { if (addressIterator.skipToDimension(dimension.name)) { if (dimension.is_indexed()) { - uint32_t label = mapLabelToNumber(addressIterator.label()); - if (label == BAD_LABEL || + uint32_t label = DenseTensorAddressMapper::mapLabelToNumber(addressIterator.label()); + if (label == DenseTensorAddressMapper::BAD_LABEL || (dimension.is_bound() && label >= dimIterator->size)) { return false; } @@ -171,8 +153,8 @@ TensorTypeMapper::expandUnboundDimensions(const TensorAddress &address) for (const auto &dimension : _type.dimensions()) { if (addressIterator.skipToDimension(dimension.name)) { if (dimension.is_indexed()) { - uint32_t label = mapLabelToNumber(addressIterator.label()); - if (label != BAD_LABEL && + uint32_t label = DenseTensorAddressMapper::mapLabelToNumber(addressIterator.label()); + if (label != DenseTensorAddressMapper::BAD_LABEL && !dimension.is_bound() && label >= dimIterator->size) { dimIterator->size = label + 1; @@ -266,32 +248,11 @@ DenseTensorMapper::build() std::move(_cells)); } -uint32_t -DenseTensorMapper::mapAddressToIndex(const TensorAddress &address) -{ - uint32_t idx = 0; - TensorAddressElementIterator<TensorAddress> addressIterator(address); - for (const auto &dimension : _type.dimensions()) { - if (addressIterator.skipToDimension(dimension.name)) { - uint32_t label = mapLabelToNumber(addressIterator.label()); - if (label == BAD_LABEL || label >= dimension.size) { - return BAD_ADDRESS; - } - idx = idx * dimension.size + label; - addressIterator.next(); - } else { - // output dimension not in input - idx = idx * dimension.size; - } - } - return idx; -} - void DenseTensorMapper::visit(const TensorAddress &address, double value) { - uint32_t idx = mapAddressToIndex(address); - if (idx != BAD_ADDRESS) { + uint32_t idx = DenseTensorAddressMapper::mapAddressToIndex(address, _type); + if (idx != DenseTensorAddressMapper::BAD_ADDRESS) { assert(idx < _cells.size()); _cells[idx] += value; } @@ -340,8 +301,8 @@ WrappedTensorMapper::visit(const TensorAddress &address, double value) for (const auto &dimension: _type.dimensions()) { if (addressIterator.skipToDimension(dimension.name)) { if (dimension.is_indexed()) { - uint32_t label = mapLabelToNumber(addressIterator.label()); - if ((label == BAD_LABEL) || (label >= dimension.size)) { + uint32_t label = DenseTensorAddressMapper::mapLabelToNumber(addressIterator.label()); + if ((label == DenseTensorAddressMapper::BAD_LABEL) || (label >= dimension.size)) { return; // bad address; ignore cell } addr.emplace(dimension.name, label); diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp index 5c6c70099ad..394335d9b67 100644 --- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp +++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp @@ -77,4 +77,10 @@ WrappedSimpleTensor::reduce(join_fun_t, const std::vector<vespalib::string> &) c LOG_ABORT("should not be reached"); } +std::unique_ptr<Tensor> +WrappedSimpleTensor::modify(join_fun_t, const CellValues &) const +{ + LOG_ABORT("should not be reached"); +} + } // namespace vespalib::tensor diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h index ae7907845e1..511fbc3c795 100644 --- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h +++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h @@ -38,6 +38,7 @@ public: Tensor::UP apply(const CellFunction &) const override; Tensor::UP join(join_fun_t, const Tensor &) const override; Tensor::UP reduce(join_fun_t, const std::vector<vespalib::string> &) const override; + std::unique_ptr<Tensor> modify(join_fun_t, const CellValues &) const override; }; } // namespace vespalib::tensor |