diff options
13 files changed, 130 insertions, 15 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index 5100e3683e5..e93738f2afb 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -3,6 +3,7 @@ #include <vespa/document/fieldvalue/fieldvalues.h> #include <vespa/document/update/documentupdate.h> #include <vespa/document/base/testdocman.h> +#include <vespa/document/datatype/tensor_data_type.h> #include <vespa/document/update/addvalueupdate.h> #include <vespa/document/update/arithmeticvalueupdate.h> @@ -944,7 +945,8 @@ DocumentUpdateTest::testTensorAssignUpdate() CPPUNIT_ASSERT(!doc->getValue("tensor")); Document updated(*doc); FieldValue::UP new_value(createTensorFieldValue()); - testValueUpdate(AssignValueUpdate(*new_value), *DataType::TENSOR); + auto tensorDataType = std::make_unique<TensorDataType>(); + testValueUpdate(AssignValueUpdate(*new_value), *tensorDataType); DocumentUpdate upd(docMan.getTypeRepo(), *doc->getDataType(), doc->getId()); upd.addUpdate(FieldUpdate(upd.getType().getField("tensor")).addUpdate(AssignValueUpdate(*new_value))); upd.applyTo(updated); @@ -978,7 +980,8 @@ DocumentUpdateTest::testTensorAddUpdate() auto oldTensor = createTensorFieldValueWith2Cells(); updated.setValue(updated.getField("tensor"), *oldTensor); CPPUNIT_ASSERT(*doc != updated); - testValueUpdate(*createTensorAddUpdate(), *DataType::TENSOR); + auto tensorDataType = std::make_unique<TensorDataType>(); + testValueUpdate(*createTensorAddUpdate(), *tensorDataType); std::string expTensorAddUpdateString("TensorAddUpdate(" "{TensorFieldValue: " "{\"dimensions\":[\"x\",\"y\"]," @@ -1006,7 +1009,8 @@ DocumentUpdateTest::testTensorModifyUpdate() auto oldTensor = createTensorFieldValueWith2Cells(); updated.setValue(updated.getField("tensor"), *oldTensor); CPPUNIT_ASSERT(*doc != updated); - testValueUpdate(*createTensorModifyUpdate(), *DataType::TENSOR); + auto tensorDataType = std::make_unique<TensorDataType>(); + testValueUpdate(*createTensorModifyUpdate(), *tensorDataType); std::string expTensorModifyUpdateString("TensorModifyUpdate(replace," "{TensorFieldValue: " "{\"dimensions\":[\"x\",\"y\"]," diff --git a/document/src/vespa/document/datatype/CMakeLists.txt b/document/src/vespa/document/datatype/CMakeLists.txt index d6b432b9b04..8e527677b50 100644 --- a/document/src/vespa/document/datatype/CMakeLists.txt +++ b/document/src/vespa/document/datatype/CMakeLists.txt @@ -13,6 +13,7 @@ vespa_add_library(document_datatypes OBJECT primitivedatatype.cpp structdatatype.cpp structureddatatype.cpp + tensor_data_type.cpp urldatatype.cpp weightedsetdatatype.cpp referencedatatype.cpp diff --git a/document/src/vespa/document/datatype/datatype.cpp b/document/src/vespa/document/datatype/datatype.cpp index 8c17ca4e383..fae606839a7 100644 --- a/document/src/vespa/document/datatype/datatype.cpp +++ b/document/src/vespa/document/datatype/datatype.cpp @@ -28,7 +28,6 @@ DocumentType DOCUMENT_OBJ("document"); WeightedSetDataType TAG_OBJ(*DataType::STRING, true, true); PrimitiveDataType URI_OBJ(DataType::T_URI); PrimitiveDataType PREDICATE_OBJ(DataType::T_PREDICATE); -PrimitiveDataType TENSOR_OBJ(DataType::T_TENSOR); } // namespace @@ -45,7 +44,6 @@ const DocumentType *const DataType::DOCUMENT(&DOCUMENT_OBJ); const DataType *const DataType::TAG(&TAG_OBJ); const DataType *const DataType::URI(&URI_OBJ); const DataType *const DataType::PREDICATE(&PREDICATE_OBJ); -const DataType *const DataType::TENSOR(&TENSOR_OBJ); namespace { @@ -113,7 +111,6 @@ DataType::getDefaultDataTypes() types.push_back(TAG); types.push_back(URI); types.push_back(PREDICATE); - types.push_back(TENSOR); return types; } diff --git a/document/src/vespa/document/datatype/datatype.h b/document/src/vespa/document/datatype/datatype.h index 723e7c69ed6..95f0e8b9a64 100644 --- a/document/src/vespa/document/datatype/datatype.h +++ b/document/src/vespa/document/datatype/datatype.h @@ -96,7 +96,6 @@ public: static const DataType *const TAG; static const DataType *const URI; static const DataType *const PREDICATE; - static const DataType *const TENSOR; /** Used by type manager to fetch default types to register. */ static std::vector<const DataType *> getDefaultDataTypes(); diff --git a/document/src/vespa/document/datatype/primitivedatatype.cpp b/document/src/vespa/document/datatype/primitivedatatype.cpp index 7ec47f52d9c..e831cd4e8f4 100644 --- a/document/src/vespa/document/datatype/primitivedatatype.cpp +++ b/document/src/vespa/document/datatype/primitivedatatype.cpp @@ -69,7 +69,6 @@ PrimitiveDataType::createFieldValue() const case T_BOOL: return std::make_unique<BoolFieldValue>(); case T_BYTE: return std::make_unique<ByteFieldValue>(); case T_PREDICATE: return std::make_unique<PredicateFieldValue>(); - case T_TENSOR: return std::make_unique<TensorFieldValue>(); } LOG_ABORT("getId() returned value out of range"); } diff --git a/document/src/vespa/document/datatype/tensor_data_type.cpp b/document/src/vespa/document/datatype/tensor_data_type.cpp new file mode 100644 index 00000000000..df799509ab1 --- /dev/null +++ b/document/src/vespa/document/datatype/tensor_data_type.cpp @@ -0,0 +1,42 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "tensor_data_type.h" +#include <vespa/document/fieldvalue/tensorfieldvalue.h> +#include <vespa/vespalib/util/exceptions.h> +#include <sstream> + +namespace document { + +IMPLEMENT_IDENTIFIABLE_ABSTRACT(TensorDataType, DataType); + +TensorDataType::TensorDataType() + : PrimitiveDataType(DataType::T_TENSOR) +{ +} + +FieldValue::UP +TensorDataType::createFieldValue() const +{ + return std::make_unique<TensorFieldValue>(); +} + +TensorDataType* +TensorDataType::clone() const +{ + return new TensorDataType(*this); +} + +void +TensorDataType::print(std::ostream& out, bool verbose, const std::string& indent) const +{ + (void) verbose; (void) indent; + out << "TensorDataType()"; +} + +std::unique_ptr<const TensorDataType> +TensorDataType::fromSpec([[maybe_unused]] const vespalib::string &spec) +{ + return std::make_unique<const TensorDataType>(); +} + +} // document diff --git a/document/src/vespa/document/datatype/tensor_data_type.h b/document/src/vespa/document/datatype/tensor_data_type.h new file mode 100644 index 00000000000..aafc72eabfb --- /dev/null +++ b/document/src/vespa/document/datatype/tensor_data_type.h @@ -0,0 +1,23 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "primitivedatatype.h" + +namespace document { + +/* + * This class describes a tensor type. + */ +class TensorDataType : public PrimitiveDataType { +public: + TensorDataType(); + + std::unique_ptr<FieldValue> createFieldValue() const override; + TensorDataType* clone() const override; + void print(std::ostream&, bool verbose, const std::string& indent) const override; + static std::unique_ptr<const TensorDataType> fromSpec(const vespalib::string &spec); + + DECLARE_IDENTIFIABLE_ABSTRACT(TensorDataType); +}; + +} diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp index 82a87575fbe..588cc7dbaaf 100644 --- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp +++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp @@ -1,7 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "tensorfieldvalue.h" -#include <vespa/document/datatype/datatype.h> +#include <vespa/document/datatype/tensor_data_type.h> #include <vespa/vespalib/util/xmlstream.h> #include <vespa/eval/tensor/tensor.h> #include <vespa/eval/tensor/serialization/slime_binary_format.h> @@ -16,8 +16,20 @@ using namespace vespalib::xml; namespace document { +namespace { + +TensorDataType emptyTensorDataType; + +} + TensorFieldValue::TensorFieldValue() + : TensorFieldValue(&emptyTensorDataType) +{ +} + +TensorFieldValue::TensorFieldValue(const TensorDataType *dataType) : FieldValue(), + _dataType(dataType), _tensor(), _altered(true) { @@ -25,6 +37,7 @@ TensorFieldValue::TensorFieldValue() TensorFieldValue::TensorFieldValue(const TensorFieldValue &rhs) : FieldValue(), + _dataType(rhs._dataType), _tensor(), _altered(true) { @@ -36,6 +49,7 @@ TensorFieldValue::TensorFieldValue(const TensorFieldValue &rhs) TensorFieldValue::TensorFieldValue(TensorFieldValue &&rhs) : FieldValue(), + _dataType(rhs._dataType), _tensor(), _altered(true) { @@ -52,6 +66,7 @@ TensorFieldValue & TensorFieldValue::operator=(const TensorFieldValue &rhs) { if (this != &rhs) { + _dataType = rhs._dataType; if (rhs._tensor) { _tensor = rhs._tensor->clone(); } else { @@ -90,7 +105,7 @@ TensorFieldValue::accept(ConstFieldValueVisitor &visitor) const const DataType * TensorFieldValue::getDataType() const { - return DataType::TENSOR; + return _dataType; } diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h index 5c831601d1a..ff3fceb980a 100644 --- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h +++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h @@ -8,15 +8,19 @@ namespace vespalib { namespace tensor { class Tensor; } } namespace document { +class TensorDataType; + /** * Field value representing a tensor. */ class TensorFieldValue : public FieldValue { private: + const TensorDataType *_dataType; std::unique_ptr<vespalib::tensor::Tensor> _tensor; bool _altered; public: TensorFieldValue(); + TensorFieldValue(const TensorDataType *dataType); TensorFieldValue(const TensorFieldValue &rhs); TensorFieldValue(TensorFieldValue &&rhs); ~TensorFieldValue(); diff --git a/document/src/vespa/document/repo/documenttyperepo.cpp b/document/src/vespa/document/repo/documenttyperepo.cpp index a320750e0d5..bdecd521f44 100644 --- a/document/src/vespa/document/repo/documenttyperepo.cpp +++ b/document/src/vespa/document/repo/documenttyperepo.cpp @@ -10,6 +10,7 @@ #include <vespa/document/datatype/urldatatype.h> #include <vespa/document/datatype/weightedsetdatatype.h> #include <vespa/document/datatype/referencedatatype.h> +#include <vespa/document/datatype/tensor_data_type.h> #include <vespa/vespalib/stllike/hash_map.hpp> #include <vespa/vespalib/util/exceptions.h> #include <vespa/document/config/config-documenttypes.h> @@ -66,6 +67,7 @@ void DeleteMapContent(Map &m) { class Repo { vector<const DataType *> _owned_types; hash_map<int32_t, const DataType *> _types; + hash_map<string, const DataType *> _tensorTypes; hash_map<string, const DataType *> _name_map; public: @@ -75,13 +77,16 @@ public: bool addDataType(const DataType &type); template <typename T> void addDataType(unique_ptr<T> type); + const DataType &addTensorType(const string &spec); const DataType *lookup(int32_t id) const; const DataType *lookup(stringref name) const; const DataType &findOrThrow(int32_t id) const; + const DataType &findOrThrowOrCreate(int32_t id, const string &detailedType); }; void Repo::inherit(const Repo &parent) { _types.insert(parent._types.begin(), parent._types.end()); + _tensorTypes.insert(parent._tensorTypes.begin(), parent._tensorTypes.end()); _name_map.insert(parent._name_map.begin(), parent._name_map.end()); } @@ -114,6 +119,18 @@ void Repo::addDataType(unique_ptr<T> type) { } } + +const DataType & +Repo::addTensorType(const string &spec) +{ + auto type = TensorDataType::fromSpec(spec); + auto insres = _tensorTypes.insert(std::make_pair(spec, type.get())); + if (insres.second) { + _owned_types.push_back(type.release()); + } + return *insres.first->second; +} + template <typename Map> typename Map::mapped_type FindPtr(const Map &m, typename Map::key_type key) { typename Map::const_iterator it = m.find(key); @@ -139,6 +156,15 @@ const DataType &Repo::findOrThrow(int32_t id) const { throw IllegalArgumentException(make_string("Unknown datatype %d", id)); } +const DataType & +Repo::findOrThrowOrCreate(int32_t id, const string &detailedType) +{ + if (id != DataType::T_TENSOR) { + return findOrThrow(id); + } + return addTensorType(detailedType); +} + class AnnotationTypeRepo { vector<const AnnotationType *> _owned_types; hash_map<int32_t, AnnotationType *> _annotation_types; @@ -231,11 +257,11 @@ void setAnnotationDataTypes(const vector<DocumenttypesConfig::Documenttype::Anno typedef DocumenttypesConfig::Documenttype::Datatype Datatype; -void addField(const Datatype::Sstruct::Field &field, const Repo &repo, StructDataType &struct_type, bool isHeaderField) +void addField(const Datatype::Sstruct::Field &field, Repo &repo, StructDataType &struct_type, bool isHeaderField) { LOG(spam, "Adding field %s to %s (header: %s)", field.name.c_str(), struct_type.getName().c_str(), isHeaderField ? "yes" : "no"); - const DataType &field_type = repo.findOrThrow(field.datatype); + const DataType &field_type = repo.findOrThrowOrCreate(field.datatype, field.detailedtype); struct_type.addField(Field(field.name, field.id, field_type, isHeaderField)); } diff --git a/document/src/vespa/document/update/tensor_add_update.cpp b/document/src/vespa/document/update/tensor_add_update.cpp index 6d251b04937..96dc4d5e215 100644 --- a/document/src/vespa/document/update/tensor_add_update.cpp +++ b/document/src/vespa/document/update/tensor_add_update.cpp @@ -3,6 +3,7 @@ #include "tensor_add_update.h" #include <vespa/document/base/exceptions.h> #include <vespa/document/base/field.h> +#include <vespa/document/datatype/tensor_data_type.h> #include <vespa/document/fieldvalue/document.h> #include <vespa/document/fieldvalue/tensorfieldvalue.h> #include <vespa/document/serialization/vespadocumentdeserializer.h> @@ -71,7 +72,7 @@ TensorAddUpdate::operator==(const ValueUpdate &other) const void TensorAddUpdate::checkCompatibility(const Field& field) const { - if (field.getDataType() != *DataType::TENSOR) { + if (field.getDataType().getClass().id() != TensorDataType::classId) { throw IllegalArgumentException(make_string( "Can not perform tensor add update on non-tensor field '%s'.", field.getName().data()), VESPA_STRLOC); diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp index bb846581697..962543984d2 100644 --- a/document/src/vespa/document/update/tensor_modify_update.cpp +++ b/document/src/vespa/document/update/tensor_modify_update.cpp @@ -3,6 +3,7 @@ #include "tensor_modify_update.h" #include <vespa/document/base/exceptions.h> #include <vespa/document/base/field.h> +#include <vespa/document/datatype/tensor_data_type.h> #include <vespa/document/fieldvalue/document.h> #include <vespa/document/fieldvalue/tensorfieldvalue.h> #include <vespa/document/serialization/vespadocumentdeserializer.h> @@ -127,7 +128,7 @@ TensorModifyUpdate::operator==(const ValueUpdate &other) const void TensorModifyUpdate::checkCompatibility(const Field& field) const { - if (field.getDataType() != *DataType::TENSOR) { + if (field.getDataType().getClass().id() != TensorDataType::classId) { throw IllegalArgumentException(make_string( "Can not perform tensor modify update on non-tensor field '%s'.", field.getName().data()), VESPA_STRLOC); diff --git a/searchlib/src/vespa/searchlib/index/doctypebuilder.cpp b/searchlib/src/vespa/searchlib/index/doctypebuilder.cpp index 836d16825f5..aabb2f3c296 100644 --- a/searchlib/src/vespa/searchlib/index/doctypebuilder.cpp +++ b/searchlib/src/vespa/searchlib/index/doctypebuilder.cpp @@ -2,6 +2,7 @@ #include "doctypebuilder.h" #include <vespa/document/datatype/urldatatype.h> +#include <vespa/document/datatype/tensor_data_type.h> #include <vespa/document/repo/configbuilder.h> using namespace document; @@ -9,6 +10,8 @@ using namespace document; namespace search::index { namespace { +TensorDataType tensorDataType; + const DataType *convert(Schema::DataType type) { switch (type) { case schema::DataType::BOOL: @@ -33,7 +36,7 @@ const DataType *convert(Schema::DataType type) { case schema::DataType::BOOLEANTREE: return DataType::PREDICATE; case schema::DataType::TENSOR: - return DataType::TENSOR; + return &tensorDataType; default: break; } |