summary | refs | log | tree | commit | diff | stats
path: root/document
diff options
context:
space:
mode:
author: Tor Egge <Tor.Egge@broadpark.no> 2019-02-13 12:31:58 +0100
committer: Tor Egge <Tor.Egge@broadpark.no> 2019-02-13 12:34:18 +0100
commit: e6935ceb834b912a0b63b4e7fb9e1c5bceadaab0 (patch)
tree: 6bb285c89a183e5aefa0fa911d3de6eba213fb5a /document
parent: b9869d95dd4d80e23f15d610756924aaa12ea28b (diff)
Prepare for tracking tensor type in document module (C++), aligning
with the Java implementation.
Diffstat (limited to 'document')
-rw-r--r-- document/src/tests/documentupdatetestcase.cpp 10
-rw-r--r-- document/src/vespa/document/datatype/CMakeLists.txt 1
-rw-r--r-- document/src/vespa/document/datatype/datatype.cpp 3
-rw-r--r-- document/src/vespa/document/datatype/datatype.h 1
-rw-r--r-- document/src/vespa/document/datatype/primitivedatatype.cpp 1
-rw-r--r-- document/src/vespa/document/datatype/tensor_data_type.cpp 42
-rw-r--r-- document/src/vespa/document/datatype/tensor_data_type.h 23
-rw-r--r-- document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp 19
-rw-r--r-- document/src/vespa/document/fieldvalue/tensorfieldvalue.h 4
-rw-r--r-- document/src/vespa/document/repo/documenttyperepo.cpp 30
-rw-r--r-- document/src/vespa/document/update/tensor_add_update.cpp 3
-rw-r--r-- document/src/vespa/document/update/tensor_modify_update.cpp 3
12 files changed, 126 insertions, 14 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index 5100e3683e5..e93738f2afb 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -3,6 +3,7 @@
#include <vespa/document/fieldvalue/fieldvalues.h>
#include <vespa/document/update/documentupdate.h>
#include <vespa/document/base/testdocman.h>
+#include <vespa/document/datatype/tensor_data_type.h>
#include <vespa/document/update/addvalueupdate.h>
#include <vespa/document/update/arithmeticvalueupdate.h>
@@ -944,7 +945,8 @@ DocumentUpdateTest::testTensorAssignUpdate()
CPPUNIT_ASSERT(!doc->getValue("tensor"));
Document updated(*doc);
FieldValue::UP new_value(createTensorFieldValue());
- testValueUpdate(AssignValueUpdate(*new_value), *DataType::TENSOR);
+ auto tensorDataType = std::make_unique<TensorDataType>();
+ testValueUpdate(AssignValueUpdate(*new_value), *tensorDataType);
DocumentUpdate upd(docMan.getTypeRepo(), *doc->getDataType(), doc->getId());
upd.addUpdate(FieldUpdate(upd.getType().getField("tensor")).addUpdate(AssignValueUpdate(*new_value)));
upd.applyTo(updated);
@@ -978,7 +980,8 @@ DocumentUpdateTest::testTensorAddUpdate()
auto oldTensor = createTensorFieldValueWith2Cells();
updated.setValue(updated.getField("tensor"), *oldTensor);
CPPUNIT_ASSERT(*doc != updated);
- testValueUpdate(*createTensorAddUpdate(), *DataType::TENSOR);
+ auto tensorDataType = std::make_unique<TensorDataType>();
+ testValueUpdate(*createTensorAddUpdate(), *tensorDataType);
std::string expTensorAddUpdateString("TensorAddUpdate("
"{TensorFieldValue: "
"{\"dimensions\":[\"x\",\"y\"],"
@@ -1006,7 +1009,8 @@ DocumentUpdateTest::testTensorModifyUpdate()
auto oldTensor = createTensorFieldValueWith2Cells();
updated.setValue(updated.getField("tensor"), *oldTensor);
CPPUNIT_ASSERT(*doc != updated);
- testValueUpdate(*createTensorModifyUpdate(), *DataType::TENSOR);
+ auto tensorDataType = std::make_unique<TensorDataType>();
+ testValueUpdate(*createTensorModifyUpdate(), *tensorDataType);
std::string expTensorModifyUpdateString("TensorModifyUpdate(replace,"
"{TensorFieldValue: "
"{\"dimensions\":[\"x\",\"y\"],"
diff --git a/document/src/vespa/document/datatype/CMakeLists.txt b/document/src/vespa/document/datatype/CMakeLists.txt
index d6b432b9b04..8e527677b50 100644
--- a/document/src/vespa/document/datatype/CMakeLists.txt
+++ b/document/src/vespa/document/datatype/CMakeLists.txt
@@ -13,6 +13,7 @@ vespa_add_library(document_datatypes OBJECT
primitivedatatype.cpp
structdatatype.cpp
structureddatatype.cpp
+ tensor_data_type.cpp
urldatatype.cpp
weightedsetdatatype.cpp
referencedatatype.cpp
diff --git a/document/src/vespa/document/datatype/datatype.cpp b/document/src/vespa/document/datatype/datatype.cpp
index 8c17ca4e383..fae606839a7 100644
--- a/document/src/vespa/document/datatype/datatype.cpp
+++ b/document/src/vespa/document/datatype/datatype.cpp
@@ -28,7 +28,6 @@ DocumentType DOCUMENT_OBJ("document");
WeightedSetDataType TAG_OBJ(*DataType::STRING, true, true);
PrimitiveDataType URI_OBJ(DataType::T_URI);
PrimitiveDataType PREDICATE_OBJ(DataType::T_PREDICATE);
-PrimitiveDataType TENSOR_OBJ(DataType::T_TENSOR);
} // namespace
@@ -45,7 +44,6 @@ const DocumentType *const DataType::DOCUMENT(&DOCUMENT_OBJ);
const DataType *const DataType::TAG(&TAG_OBJ);
const DataType *const DataType::URI(&URI_OBJ);
const DataType *const DataType::PREDICATE(&PREDICATE_OBJ);
-const DataType *const DataType::TENSOR(&TENSOR_OBJ);
namespace {
@@ -113,7 +111,6 @@ DataType::getDefaultDataTypes()
types.push_back(TAG);
types.push_back(URI);
types.push_back(PREDICATE);
- types.push_back(TENSOR);
return types;
}
diff --git a/document/src/vespa/document/datatype/datatype.h b/document/src/vespa/document/datatype/datatype.h
index 723e7c69ed6..95f0e8b9a64 100644
--- a/document/src/vespa/document/datatype/datatype.h
+++ b/document/src/vespa/document/datatype/datatype.h
@@ -96,7 +96,6 @@ public:
static const DataType *const TAG;
static const DataType *const URI;
static const DataType *const PREDICATE;
- static const DataType *const TENSOR;
/** Used by type manager to fetch default types to register. */
static std::vector<const DataType *> getDefaultDataTypes();
diff --git a/document/src/vespa/document/datatype/primitivedatatype.cpp b/document/src/vespa/document/datatype/primitivedatatype.cpp
index 7ec47f52d9c..e831cd4e8f4 100644
--- a/document/src/vespa/document/datatype/primitivedatatype.cpp
+++ b/document/src/vespa/document/datatype/primitivedatatype.cpp
@@ -69,7 +69,6 @@ PrimitiveDataType::createFieldValue() const
case T_BOOL: return std::make_unique<BoolFieldValue>();
case T_BYTE: return std::make_unique<ByteFieldValue>();
case T_PREDICATE: return std::make_unique<PredicateFieldValue>();
- case T_TENSOR: return std::make_unique<TensorFieldValue>();
}
LOG_ABORT("getId() returned value out of range");
}
diff --git a/document/src/vespa/document/datatype/tensor_data_type.cpp b/document/src/vespa/document/datatype/tensor_data_type.cpp
new file mode 100644
index 00000000000..df799509ab1
--- /dev/null
+++ b/document/src/vespa/document/datatype/tensor_data_type.cpp
@@ -0,0 +1,42 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "tensor_data_type.h"
+#include <vespa/document/fieldvalue/tensorfieldvalue.h>
+#include <vespa/vespalib/util/exceptions.h>
+#include <sstream>
+
+namespace document {
+
+IMPLEMENT_IDENTIFIABLE_ABSTRACT(TensorDataType, DataType);
+
+TensorDataType::TensorDataType()
+ : PrimitiveDataType(DataType::T_TENSOR)
+{
+}
+
+FieldValue::UP
+TensorDataType::createFieldValue() const
+{
+ return std::make_unique<TensorFieldValue>();
+}
+
+TensorDataType*
+TensorDataType::clone() const
+{
+ return new TensorDataType(*this);
+}
+
+void
+TensorDataType::print(std::ostream& out, bool verbose, const std::string& indent) const
+{
+ (void) verbose; (void) indent;
+ out << "TensorDataType()";
+}
+
+std::unique_ptr<const TensorDataType>
+TensorDataType::fromSpec([[maybe_unused]] const vespalib::string &spec)
+{
+ return std::make_unique<const TensorDataType>();
+}
+
+} // document
diff --git a/document/src/vespa/document/datatype/tensor_data_type.h b/document/src/vespa/document/datatype/tensor_data_type.h
new file mode 100644
index 00000000000..aafc72eabfb
--- /dev/null
+++ b/document/src/vespa/document/datatype/tensor_data_type.h
@@ -0,0 +1,23 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#pragma once
+
+#include "primitivedatatype.h"
+
+namespace document {
+
+/*
+ * This class describes a tensor type.
+ */
+class TensorDataType : public PrimitiveDataType {
+public:
+ TensorDataType();
+
+ std::unique_ptr<FieldValue> createFieldValue() const override;
+ TensorDataType* clone() const override;
+ void print(std::ostream&, bool verbose, const std::string& indent) const override;
+ static std::unique_ptr<const TensorDataType> fromSpec(const vespalib::string &spec);
+
+ DECLARE_IDENTIFIABLE_ABSTRACT(TensorDataType);
+};
+
+}
diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp
index 82a87575fbe..588cc7dbaaf 100644
--- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp
+++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp
@@ -1,7 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "tensorfieldvalue.h"
-#include <vespa/document/datatype/datatype.h>
+#include <vespa/document/datatype/tensor_data_type.h>
#include <vespa/vespalib/util/xmlstream.h>
#include <vespa/eval/tensor/tensor.h>
#include <vespa/eval/tensor/serialization/slime_binary_format.h>
@@ -16,8 +16,20 @@ using namespace vespalib::xml;
namespace document {
+namespace {
+
+TensorDataType emptyTensorDataType;
+
+}
+
TensorFieldValue::TensorFieldValue()
+ : TensorFieldValue(&emptyTensorDataType)
+{
+}
+
+TensorFieldValue::TensorFieldValue(const TensorDataType *dataType)
: FieldValue(),
+ _dataType(dataType),
_tensor(),
_altered(true)
{
@@ -25,6 +37,7 @@ TensorFieldValue::TensorFieldValue()
TensorFieldValue::TensorFieldValue(const TensorFieldValue &rhs)
: FieldValue(),
+ _dataType(rhs._dataType),
_tensor(),
_altered(true)
{
@@ -36,6 +49,7 @@ TensorFieldValue::TensorFieldValue(const TensorFieldValue &rhs)
TensorFieldValue::TensorFieldValue(TensorFieldValue &&rhs)
: FieldValue(),
+ _dataType(rhs._dataType),
_tensor(),
_altered(true)
{
@@ -52,6 +66,7 @@ TensorFieldValue &
TensorFieldValue::operator=(const TensorFieldValue &rhs)
{
if (this != &rhs) {
+ _dataType = rhs._dataType;
if (rhs._tensor) {
_tensor = rhs._tensor->clone();
} else {
@@ -90,7 +105,7 @@ TensorFieldValue::accept(ConstFieldValueVisitor &visitor) const
const DataType *
TensorFieldValue::getDataType() const
{
- return DataType::TENSOR;
+ return _dataType;
}
diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h
index 5c831601d1a..ff3fceb980a 100644
--- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h
+++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h
@@ -8,15 +8,19 @@ namespace vespalib { namespace tensor { class Tensor; } }
namespace document {
+class TensorDataType;
+
/**
* Field value representing a tensor.
*/
class TensorFieldValue : public FieldValue {
private:
+ const TensorDataType *_dataType;
std::unique_ptr<vespalib::tensor::Tensor> _tensor;
bool _altered;
public:
TensorFieldValue();
+ TensorFieldValue(const TensorDataType *dataType);
TensorFieldValue(const TensorFieldValue &rhs);
TensorFieldValue(TensorFieldValue &&rhs);
~TensorFieldValue();
diff --git a/document/src/vespa/document/repo/documenttyperepo.cpp b/document/src/vespa/document/repo/documenttyperepo.cpp
index a320750e0d5..bdecd521f44 100644
--- a/document/src/vespa/document/repo/documenttyperepo.cpp
+++ b/document/src/vespa/document/repo/documenttyperepo.cpp
@@ -10,6 +10,7 @@
#include <vespa/document/datatype/urldatatype.h>
#include <vespa/document/datatype/weightedsetdatatype.h>
#include <vespa/document/datatype/referencedatatype.h>
+#include <vespa/document/datatype/tensor_data_type.h>
#include <vespa/vespalib/stllike/hash_map.hpp>
#include <vespa/vespalib/util/exceptions.h>
#include <vespa/document/config/config-documenttypes.h>
@@ -66,6 +67,7 @@ void DeleteMapContent(Map &m) {
class Repo {
vector<const DataType *> _owned_types;
hash_map<int32_t, const DataType *> _types;
+ hash_map<string, const DataType *> _tensorTypes;
hash_map<string, const DataType *> _name_map;
public:
@@ -75,13 +77,16 @@ public:
bool addDataType(const DataType &type);
template <typename T> void addDataType(unique_ptr<T> type);
+ const DataType &addTensorType(const string &spec);
const DataType *lookup(int32_t id) const;
const DataType *lookup(stringref name) const;
const DataType &findOrThrow(int32_t id) const;
+ const DataType &findOrThrowOrCreate(int32_t id, const string &detailedType);
};
void Repo::inherit(const Repo &parent) {
_types.insert(parent._types.begin(), parent._types.end());
+ _tensorTypes.insert(parent._tensorTypes.begin(), parent._tensorTypes.end());
_name_map.insert(parent._name_map.begin(), parent._name_map.end());
}
@@ -114,6 +119,18 @@ void Repo::addDataType(unique_ptr<T> type) {
}
}
+
+const DataType &
+Repo::addTensorType(const string &spec)
+{
+ auto type = TensorDataType::fromSpec(spec);
+ auto insres = _tensorTypes.insert(std::make_pair(spec, type.get()));
+ if (insres.second) {
+ _owned_types.push_back(type.release());
+ }
+ return *insres.first->second;
+}
+
template <typename Map>
typename Map::mapped_type FindPtr(const Map &m, typename Map::key_type key) {
typename Map::const_iterator it = m.find(key);
@@ -139,6 +156,15 @@ const DataType &Repo::findOrThrow(int32_t id) const {
throw IllegalArgumentException(make_string("Unknown datatype %d", id));
}
+const DataType &
+Repo::findOrThrowOrCreate(int32_t id, const string &detailedType)
+{
+ if (id != DataType::T_TENSOR) {
+ return findOrThrow(id);
+ }
+ return addTensorType(detailedType);
+}
+
class AnnotationTypeRepo {
vector<const AnnotationType *> _owned_types;
hash_map<int32_t, AnnotationType *> _annotation_types;
@@ -231,11 +257,11 @@ void setAnnotationDataTypes(const vector<DocumenttypesConfig::Documenttype::Anno
typedef DocumenttypesConfig::Documenttype::Datatype Datatype;
-void addField(const Datatype::Sstruct::Field &field, const Repo &repo, StructDataType &struct_type, bool isHeaderField)
+void addField(const Datatype::Sstruct::Field &field, Repo &repo, StructDataType &struct_type, bool isHeaderField)
{
LOG(spam, "Adding field %s to %s (header: %s)",
field.name.c_str(), struct_type.getName().c_str(), isHeaderField ? "yes" : "no");
- const DataType &field_type = repo.findOrThrow(field.datatype);
+ const DataType &field_type = repo.findOrThrowOrCreate(field.datatype, field.detailedtype);
struct_type.addField(Field(field.name, field.id, field_type, isHeaderField));
}
diff --git a/document/src/vespa/document/update/tensor_add_update.cpp b/document/src/vespa/document/update/tensor_add_update.cpp
index 6d251b04937..96dc4d5e215 100644
--- a/document/src/vespa/document/update/tensor_add_update.cpp
+++ b/document/src/vespa/document/update/tensor_add_update.cpp
@@ -3,6 +3,7 @@
#include "tensor_add_update.h"
#include <vespa/document/base/exceptions.h>
#include <vespa/document/base/field.h>
+#include <vespa/document/datatype/tensor_data_type.h>
#include <vespa/document/fieldvalue/document.h>
#include <vespa/document/fieldvalue/tensorfieldvalue.h>
#include <vespa/document/serialization/vespadocumentdeserializer.h>
@@ -71,7 +72,7 @@ TensorAddUpdate::operator==(const ValueUpdate &other) const
void
TensorAddUpdate::checkCompatibility(const Field& field) const
{
- if (field.getDataType() != *DataType::TENSOR) {
+ if (field.getDataType().getClass().id() != TensorDataType::classId) {
throw IllegalArgumentException(make_string(
"Can not perform tensor add update on non-tensor field '%s'.",
field.getName().data()), VESPA_STRLOC);
diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp
index bb846581697..962543984d2 100644
--- a/document/src/vespa/document/update/tensor_modify_update.cpp
+++ b/document/src/vespa/document/update/tensor_modify_update.cpp
@@ -3,6 +3,7 @@
#include "tensor_modify_update.h"
#include <vespa/document/base/exceptions.h>
#include <vespa/document/base/field.h>
+#include <vespa/document/datatype/tensor_data_type.h>
#include <vespa/document/fieldvalue/document.h>
#include <vespa/document/fieldvalue/tensorfieldvalue.h>
#include <vespa/document/serialization/vespadocumentdeserializer.h>
@@ -127,7 +128,7 @@ TensorModifyUpdate::operator==(const ValueUpdate &other) const
void
TensorModifyUpdate::checkCompatibility(const Field& field) const
{
- if (field.getDataType() != *DataType::TENSOR) {
+ if (field.getDataType().getClass().id() != TensorDataType::classId) {
throw IllegalArgumentException(make_string(
"Can not perform tensor modify update on non-tensor field '%s'.",
field.getName().data()), VESPA_STRLOC);