aboutsummaryrefslogtreecommitdiffstats
path: root/document/src/vespa/document/update/tensor_add_update.cpp
blob: 4110a94693f23ce0e262dd7ac3d925b9449c5450 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include "tensor_add_update.h"
#include "tensor_partial_update.h"
#include <vespa/document/base/exceptions.h>
#include <vespa/document/base/field.h>
#include <vespa/document/datatype/tensor_data_type.h>
#include <vespa/document/fieldvalue/document.h>
#include <vespa/document/fieldvalue/tensorfieldvalue.h>
#include <vespa/document/serialization/vespadocumentdeserializer.h>
#include <vespa/eval/eval/value.h>
#include <vespa/eval/eval/fast_value.h>
#include <vespa/vespalib/stllike/asciistream.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/util/xmlstream.h>
#include <ostream>

using vespalib::IllegalArgumentException;
using vespalib::IllegalStateException;
using vespalib::make_string;
using vespalib::eval::FastValueBuilderFactory;

namespace document {

// Default constructor: passes the TensorAdd type tag to the ValueUpdate base
// and leaves _tensor empty. In this file, deserialize() is the member that
// later fills _tensor in; until then _tensor must be treated as possibly null.
TensorAddUpdate::TensorAddUpdate()
    : ValueUpdate(TensorAdd),
      TensorUpdate(),
      _tensor()
{
}

// Constructs an add update that takes ownership of the given tensor field
// value (moved into _tensor). The ValueUpdate base is tagged with TensorAdd.
// No null check is performed on the argument here.
TensorAddUpdate::TensorAddUpdate(std::unique_ptr<TensorFieldValue> tensor)
    : ValueUpdate(TensorAdd),
      TensorUpdate(),
      _tensor(std::move(tensor))
{
}

// Defaulted destructor, defined in this translation unit where
// tensorfieldvalue.h is included.
// NOTE(review): presumably out of line so the header can forward-declare
// TensorFieldValue — confirm against tensor_add_update.h.
TensorAddUpdate::~TensorAddUpdate() = default;

// Equality against an arbitrary ValueUpdate.
//
// Returns false when 'other' is not a TensorAdd update. Otherwise the two
// updates are equal when both hold no tensor field value, or when both hold
// one and the held field values compare equal.
//
// A default-constructed update has an empty _tensor, so the pointers must be
// checked before they are dereferenced.
bool
TensorAddUpdate::operator==(const ValueUpdate &other) const
{
    if (other.getType() != TensorAdd) {
        return false;
    }
    const TensorAddUpdate& rhs = static_cast<const TensorAddUpdate&>(other);
    if ( ! _tensor || ! rhs._tensor) {
        // Equal only if neither side holds a tensor field value.
        return ( ! _tensor && ! rhs._tensor);
    }
    if (*_tensor != *rhs._tensor) {
        return false;
    }
    return true;
}


// Verifies that this update can be applied to the given field.
// Throws IllegalArgumentException when the field's data type is not a tensor
// type; returns normally otherwise.
void
TensorAddUpdate::checkCompatibility(const Field& field) const
{
    if (field.getDataType().isTensor()) {
        return; // tensor field: nothing to complain about
    }
    throw IllegalArgumentException(
            make_string("Cannot perform tensor add update on non-tensor field '%s'", field.getName().data()),
            VESPA_STRLOC);
}

// Applies this add update to 'tensor' by delegating to apply_to() with the
// FastValueBuilderFactory instance. The result is whatever apply_to() returns,
// which may be an empty pointer.
std::unique_ptr<vespalib::eval::Value>
TensorAddUpdate::applyTo(const vespalib::eval::Value &tensor) const
{
    const auto &factory = FastValueBuilderFactory::get();
    return apply_to(tensor, factory);
}

// Applies this add update to 'old_tensor', building the result with 'factory'.
//
// Returns the value produced by TensorPartialUpdate::add(), or an empty
// pointer when there is nothing to add: either this update holds no tensor
// field value (e.g. it was default-constructed), or the held field value
// yields no tensor from getAsTensorPtr().
std::unique_ptr<vespalib::eval::Value>
TensorAddUpdate::apply_to(const Value &old_tensor,
                          const ValueBuilderFactory &factory) const
{
    if ( ! _tensor) {
        // _tensor is empty after default construction; do not dereference it.
        return {};
    }
    if (auto addTensor = _tensor->getAsTensorPtr()) {
        return TensorPartialUpdate::add(old_tensor, *addTensor, factory);
    }
    return {};
}

// Applies this add update in place to a field value.
//
// Throws IllegalStateException when 'value' is not a tensor field value.
// Otherwise make_empty_if_not_existing() is called on the field value, the
// update is applied to its current tensor, and a non-empty result is assigned
// back into the field value. Always returns true when it does not throw.
bool
TensorAddUpdate::applyTo(FieldValue& value) const
{
    if ( ! value.isA(FieldValue::Type::TENSOR)) {
        vespalib::string message = make_string("Unable to perform a tensor add update on a '%s' field value",
                                               value.className());
        throw IllegalStateException(message, VESPA_STRLOC);
    }
    auto &target = static_cast<TensorFieldValue &>(value);
    target.make_empty_if_not_existing();
    auto current = target.getAsTensorPtr();
    assert(current);
    // An empty result means the update produced nothing; leave the field value as it is.
    if (auto updated = applyTo(*current)) {
        target = std::move(updated);
    }
    return true;
}

// XML output is not implemented for this update type: a fixed placeholder
// text is written to the stream and the held tensor is not consulted.
void
TensorAddUpdate::printXml(XmlOutputStream& xos) const
{
    xos << "{TensorAddUpdate::printXml not yet implemented}";
}

// Writes a human-readable representation to 'out':
// the indent, then "TensorAddUpdate(", then the held tensor field value
// (if any, printed with the same 'verbose' and 'indent'), then ")".
void
TensorAddUpdate::print(std::ostream& out, bool verbose, const std::string& indent) const
{
    out << indent;
    out << "TensorAddUpdate(";
    if (_tensor != nullptr) {
        // Only print the payload when this update actually holds one.
        _tensor->print(out, verbose, indent);
    }
    out << ")";
}

// Reads the tensor payload of this update from 'stream'.
//
// A fresh field value is created from 'type'. If it is not a tensor field
// value, IllegalStateException is thrown and _tensor is left untouched.
// Otherwise it replaces _tensor and is then filled in by a
// VespaDocumentDeserializer using the newest document serialization version.
void
TensorAddUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream & stream)
{
    auto fieldValue = type.createFieldValue();
    if ( ! fieldValue->isA(FieldValue::Type::TENSOR)) {
        vespalib::string err = make_string("Expected tensor field value, got a '%s' field value",
                                           fieldValue->className());
        throw IllegalStateException(err, VESPA_STRLOC);
    }
    // The type check above makes the downcast valid; ownership moves to _tensor.
    _tensor.reset(static_cast<TensorFieldValue *>(fieldValue.release()));
    VespaDocumentDeserializer deserializer(repo, stream, Document::getNewestSerializationVersion());
    deserializer.read(*_tensor);
}

}