summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2021-06-21 08:52:05 +0000
committerHåvard Pettersen <havardpe@oath.com>2021-06-21 08:52:05 +0000
commit4bf43f52a153cc2dea91aae8e48cf0f782f511f6 (patch)
treebe70c9b16895ed16bed92ed88ceae120c105ae10 /eval
parent13d53f151b8fef06e77c82aa380a0d11bf053a79 (diff)
use common implementation for 'bit' function
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/eval/extract_bit.h13
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp8
-rw-r--r--eval/src/vespa/eval/eval/operation.cpp8
3 files changed, 17 insertions, 12 deletions
diff --git a/eval/src/vespa/eval/eval/extract_bit.h b/eval/src/vespa/eval/eval/extract_bit.h
new file mode 100644
index 00000000000..ecf56b33b02
--- /dev/null
+++ b/eval/src/vespa/eval/eval/extract_bit.h
@@ -0,0 +1,13 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+namespace vespalib::eval {
+
+inline double extract_bit(double a, double b) {
+ int8_t value = (int8_t) a;
+ uint32_t n = (uint32_t) b;
+ return ((n < 8) && bool(value & (1 << n))) ? 1.0 : 0.0;
+}
+
+}
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
index 9a99c4fedd7..2a9b7815aa8 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
@@ -4,6 +4,7 @@
#include "llvm_wrapper.h"
#include <vespa/eval/eval/node_visitor.h>
#include <vespa/eval/eval/node_traverser.h>
+#include <vespa/eval/eval/extract_bit.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/IR/IRBuilder.h>
@@ -29,12 +30,7 @@ double vespalib_eval_approx(double a, double b) { return (vespalib::approx_equal
double vespalib_eval_relu(double a) { return std::max(a, 0.0); }
double vespalib_eval_sigmoid(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); }
double vespalib_eval_elu(double a) { return (a < 0) ? std::exp(a) - 1.0 : a; }
-double vespalib_eval_bit(double a, double b) {
- // must match Bit::f
- int8_t value = (int8_t) a;
- uint32_t n = (uint32_t) b;
- return ((n < 8) && bool(value & (1 << n))) ? 1.0 : 0.0;
-}
+double vespalib_eval_bit(double a, double b) { return vespalib::eval::extract_bit(a, b); }
using vespalib::eval::gbdt::Forest;
using resolve_function = double (*)(void *ctx, size_t idx);
diff --git a/eval/src/vespa/eval/eval/operation.cpp b/eval/src/vespa/eval/eval/operation.cpp
index 36b922539f4..a82a79e6bc4 100644
--- a/eval/src/vespa/eval/eval/operation.cpp
+++ b/eval/src/vespa/eval/eval/operation.cpp
@@ -3,6 +3,7 @@
#include "operation.h"
#include "function.h"
#include "key_gen.h"
+#include "extract_bit.h"
#include <vespa/vespalib/util/approx.h>
#include <algorithm>
@@ -50,12 +51,7 @@ double Relu::f(double a) { return std::max(a, 0.0); }
double Sigmoid::f(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); }
double Elu::f(double a) { return (a < 0) ? std::exp(a) - 1 : a; }
double Erf::f(double a) { return std::erf(a); }
-double Bit::f(double a, double b) {
- // must match vespalib_eval_bit
- int8_t value = (int8_t) a;
- uint32_t n = (uint32_t) b;
- return ((n < 8) && bool(value & (1 << n))) ? 1.0 : 0.0;
-}
+double Bit::f(double a, double b) { return extract_bit(a, b); }
//-----------------------------------------------------------------------------
double Inv::f(double a) { return (1.0 / a); }
double Square::f(double a) { return (a * a); }