aboutsummaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2019-10-23 12:47:30 +0000
committerHåvard Pettersen <havardpe@oath.com>2019-10-23 12:47:30 +0000
commitcb6ff89a2bad626b215b50779a4301e5619295b6 (patch)
treebcf40dc1a40d47d6c3803fe906fcd1e3519d23b7 /eval
parentdfec83cec5f464b09b72fc48ff1a5e114523aded (diff)
use float in vm forest implementation
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/eval/vm_forest.cpp51
1 files changed, 34 insertions, 17 deletions
diff --git a/eval/src/vespa/eval/eval/vm_forest.cpp b/eval/src/vespa/eval/eval/vm_forest.cpp
index 127114d0ca5..ff72c3f6521 100644
--- a/eval/src/vespa/eval/eval/vm_forest.cpp
+++ b/eval/src/vespa/eval/eval/vm_forest.cpp
@@ -25,18 +25,25 @@ constexpr uint32_t INVERTED = 3;
// bits: 20 4 4 4
//
// LEAF: [const]
-// bits: 64
+// bits: 32
//
// LESS: [<feature+types>][const][skip]
-// bits 32 64 32
+// bits 32 32 32
//
// IN: [<feature+types>][skip|set size](set size)X[const]
// bits 32 24 8 64
+// Note: We need to use double for set membership checks (IN) due to
+// string hashing.
+
// Reinterprets a position in the uint32_t-encoded model as a pointer to
// double, e.g. to read the 64-bit set constants stored in an IN node.
const double *as_double_ptr(const uint32_t *pos) {
return reinterpret_cast<const double*>(pos);
}
+const float *as_float_ptr(const uint32_t *pos) { // view one 32-bit model word as a float (used for LESS thresholds and LEAF values)
+ return reinterpret_cast<const float*>(pos);
+}
+
bool find_in(double value, const double *set, const double *end) {
for (; set < end; ++set) {
if (value == *set) {
@@ -48,15 +55,15 @@ bool find_in(double value, const double *set, const double *end) {
double less_only_find_leaf(const double *input, const uint32_t *pos, uint32_t node_type) { // follows LESS nodes starting at 'pos' until a LEAF is reached and returns its value; the incoming node_type is overwritten before it is read, so the first node is always treated as LESS
for (;;) {
- if (input[pos[0] >> 12] < *as_double_ptr(pos + 1)) {
+ if (input[pos[0] >> 12] < *as_float_ptr(pos + 1)) { // feature index is pos[0] >> 12; the threshold is the 32-bit float stored at pos[1]
node_type = (pos[0] & 0xf0) >> 4; // type of the node to continue with when the comparison is true
- pos += 4;
+ pos += 3; // a LESS node now spans 3 words (was 4 with a 64-bit const); continue right after it
} else {
node_type = (pos[0] & 0xf); // type of the node to continue with when the comparison is false
- pos += 4 + pos[3];
+ pos += 3 + pos[2]; // step over this node plus the number of words given by its [skip] field (pos[2])
}
if (node_type == LEAF) {
- return *as_double_ptr(pos);
+ return *as_float_ptr(pos); // leaf constant is stored as float and converted to double by the return
}
}
}
@@ -64,15 +71,15 @@ double less_only_find_leaf(const double *input, const uint32_t *pos, uint32_t no
double general_find_leaf(const double *input, const uint32_t *pos, uint32_t node_type) {
for (;;) {
if (node_type == LESS) {
- if (input[pos[0] >> 12] < *as_double_ptr(pos + 1)) {
+ if (input[pos[0] >> 12] < *as_float_ptr(pos + 1)) {
node_type = (pos[0] & 0xf0) >> 4;
- pos += 4;
+ pos += 3;
} else {
node_type = (pos[0] & 0xf);
- pos += 4 + pos[3];
+ pos += 3 + pos[2];
}
if (node_type == LEAF) {
- return *as_double_ptr(pos);
+ return *as_float_ptr(pos);
}
} else if (node_type == IN) {
if (find_in(input[pos[0] >> 12], as_double_ptr(pos + 2),
@@ -85,18 +92,18 @@ double general_find_leaf(const double *input, const uint32_t *pos, uint32_t node
pos += (2 + (2 * (pos[1] & 0xff))) + (pos[1] >> 8);
}
if (node_type == LEAF) {
- return *as_double_ptr(pos);
+ return *as_float_ptr(pos);
}
} else {
- if (input[pos[0] >> 12] >= *as_double_ptr(pos + 1)) {
+ if (input[pos[0] >> 12] >= *as_float_ptr(pos + 1)) {
node_type = (pos[0] & 0xf);
- pos += 4 + pos[3];
+ pos += 3 + pos[2];
} else {
node_type = (pos[0] & 0xf0) >> 4;
- pos += 4;
+ pos += 3;
}
if (node_type == LEAF) {
- return *as_double_ptr(pos);
+ return *as_float_ptr(pos);
}
}
}
@@ -104,7 +111,7 @@ double general_find_leaf(const double *input, const uint32_t *pos, uint32_t node
//-----------------------------------------------------------------------------
-void encode_const(double value, std::vector<uint32_t> &model_out) {
+void encode_large_const(double value, std::vector<uint32_t> &model_out) {
union {
double d[1];
uint32_t i[2];
@@ -115,6 +122,16 @@ void encode_const(double value, std::vector<uint32_t> &model_out) {
model_out.push_back(buf.i[1]);
}
+void encode_const(float value, std::vector<uint32_t> &model_out) { // appends 'value' to model_out as a single 32-bit word
+ union {
+ float f[1];
+ uint32_t i[1];
+ } buf; // float and uint32_t views of the same storage
+ assert(sizeof(buf) == sizeof(float)); // the union must be exactly one float wide, i.e. float and uint32_t have the same size
+ buf.f[0] = value;
+ model_out.push_back(buf.i[0]); // NOTE(review): reading the inactive union member is not guaranteed by ISO C++; consider std::memcpy if portability matters
+}
+
uint32_t encode_node(const nodes::Node &node_in, std::vector<uint32_t> &model_out);
void encode_less(const nodes::Less &less,
@@ -146,7 +163,7 @@ void encode_in(const nodes::In &in,
size_t set_size_idx = model_out.size();
model_out.push_back(in.num_entries());
for (size_t i = 0; i < in.num_entries(); ++i) {
- encode_const(in.get_entry(i).get_const_value(), model_out);
+ encode_large_const(in.get_entry(i).get_const_value(), model_out);
}
size_t left_idx = model_out.size();
uint32_t left_type = encode_node(left_child, model_out);