diff options
author | Håvard Pettersen <havardpe@oath.com> | 2019-10-23 12:47:30 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2019-10-23 12:47:30 +0000 |
commit | cb6ff89a2bad626b215b50779a4301e5619295b6 (patch) | |
tree | bcf40dc1a40d47d6c3803fe906fcd1e3519d23b7 /eval/src | |
parent | dfec83cec5f464b09b72fc48ff1a5e114523aded (diff) |
use float in vm forest implementation
Diffstat (limited to 'eval/src')
-rw-r--r-- | eval/src/vespa/eval/eval/vm_forest.cpp | 51 |
1 file changed, 34 insertions, 17 deletions
diff --git a/eval/src/vespa/eval/eval/vm_forest.cpp b/eval/src/vespa/eval/eval/vm_forest.cpp index 127114d0ca5..ff72c3f6521 100644 --- a/eval/src/vespa/eval/eval/vm_forest.cpp +++ b/eval/src/vespa/eval/eval/vm_forest.cpp @@ -25,18 +25,25 @@ constexpr uint32_t INVERTED = 3; // bits: 20 4 4 4 // // LEAF: [const] -// bits: 64 +// bits: 32 // // LESS: [<feature+types>][const][skip] -// bits 32 64 32 +// bits 32 32 32 // // IN: [<feature+types>][skip|set size](set size)X[const] // bits 32 24 8 64 +// Note: We need to use double for set membership checks (IN) due to +// string hashing. + const double *as_double_ptr(const uint32_t *pos) { return reinterpret_cast<const double*>(pos); } +const float *as_float_ptr(const uint32_t *pos) { + return reinterpret_cast<const float*>(pos); +} + bool find_in(double value, const double *set, const double *end) { for (; set < end; ++set) { if (value == *set) { @@ -48,15 +55,15 @@ bool find_in(double value, const double *set, const double *end) { double less_only_find_leaf(const double *input, const uint32_t *pos, uint32_t node_type) { for (;;) { - if (input[pos[0] >> 12] < *as_double_ptr(pos + 1)) { + if (input[pos[0] >> 12] < *as_float_ptr(pos + 1)) { node_type = (pos[0] & 0xf0) >> 4; - pos += 4; + pos += 3; } else { node_type = (pos[0] & 0xf); - pos += 4 + pos[3]; + pos += 3 + pos[2]; } if (node_type == LEAF) { - return *as_double_ptr(pos); + return *as_float_ptr(pos); } } } @@ -64,15 +71,15 @@ double less_only_find_leaf(const double *input, const uint32_t *pos, uint32_t no double general_find_leaf(const double *input, const uint32_t *pos, uint32_t node_type) { for (;;) { if (node_type == LESS) { - if (input[pos[0] >> 12] < *as_double_ptr(pos + 1)) { + if (input[pos[0] >> 12] < *as_float_ptr(pos + 1)) { node_type = (pos[0] & 0xf0) >> 4; - pos += 4; + pos += 3; } else { node_type = (pos[0] & 0xf); - pos += 4 + pos[3]; + pos += 3 + pos[2]; } if (node_type == LEAF) { - return *as_double_ptr(pos); + return *as_float_ptr(pos); } } else if 
(node_type == IN) { if (find_in(input[pos[0] >> 12], as_double_ptr(pos + 2), @@ -85,18 +92,18 @@ double general_find_leaf(const double *input, const uint32_t *pos, uint32_t node pos += (2 + (2 * (pos[1] & 0xff))) + (pos[1] >> 8); } if (node_type == LEAF) { - return *as_double_ptr(pos); + return *as_float_ptr(pos); } } else { - if (input[pos[0] >> 12] >= *as_double_ptr(pos + 1)) { + if (input[pos[0] >> 12] >= *as_float_ptr(pos + 1)) { node_type = (pos[0] & 0xf); - pos += 4 + pos[3]; + pos += 3 + pos[2]; } else { node_type = (pos[0] & 0xf0) >> 4; - pos += 4; + pos += 3; } if (node_type == LEAF) { - return *as_double_ptr(pos); + return *as_float_ptr(pos); } } } @@ -104,7 +111,7 @@ double general_find_leaf(const double *input, const uint32_t *pos, uint32_t node //----------------------------------------------------------------------------- -void encode_const(double value, std::vector<uint32_t> &model_out) { +void encode_large_const(double value, std::vector<uint32_t> &model_out) { union { double d[1]; uint32_t i[2]; @@ -115,6 +122,16 @@ void encode_const(double value, std::vector<uint32_t> &model_out) { model_out.push_back(buf.i[1]); } +void encode_const(float value, std::vector<uint32_t> &model_out) { + union { + float f[1]; + uint32_t i[1]; + } buf; + assert(sizeof(buf) == sizeof(float)); + buf.f[0] = value; + model_out.push_back(buf.i[0]); +} + uint32_t encode_node(const nodes::Node &node_in, std::vector<uint32_t> &model_out); void encode_less(const nodes::Less &less, @@ -146,7 +163,7 @@ void encode_in(const nodes::In &in, size_t set_size_idx = model_out.size(); model_out.push_back(in.num_entries()); for (size_t i = 0; i < in.num_entries(); ++i) { - encode_const(in.get_entry(i).get_const_value(), model_out); + encode_large_const(in.get_entry(i).get_const_value(), model_out); } size_t left_idx = model_out.size(); uint32_t left_type = encode_node(left_child, model_out); |