diff options
author | Håvard Pettersen <havardpe@oath.com> | 2021-06-21 08:52:05 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2021-06-21 08:52:05 +0000 |
commit | 4bf43f52a153cc2dea91aae8e48cf0f782f511f6 (patch) | |
tree | be70c9b16895ed16bed92ed88ceae120c105ae10 /eval | |
parent | 13d53f151b8fef06e77c82aa380a0d11bf053a79 (diff) |
use common implementation for 'bit' function
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/eval/extract_bit.h | 13 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp | 8 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/operation.cpp | 8 |
3 files changed, 17 insertions, 12 deletions
diff --git a/eval/src/vespa/eval/eval/extract_bit.h b/eval/src/vespa/eval/eval/extract_bit.h new file mode 100644 index 00000000000..ecf56b33b02 --- /dev/null +++ b/eval/src/vespa/eval/eval/extract_bit.h @@ -0,0 +1,13 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +namespace vespalib::eval { + +inline double extract_bit(double a, double b) { + int8_t value = (int8_t) a; + uint32_t n = (uint32_t) b; + return ((n < 8) && bool(value & (1 << n))) ? 1.0 : 0.0; +} + +} diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp index 9a99c4fedd7..2a9b7815aa8 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp @@ -4,6 +4,7 @@ #include "llvm_wrapper.h" #include <vespa/eval/eval/node_visitor.h> #include <vespa/eval/eval/node_traverser.h> +#include <vespa/eval/eval/extract_bit.h> #include <llvm/IR/Verifier.h> #include <llvm/Support/TargetSelect.h> #include <llvm/IR/IRBuilder.h> @@ -29,12 +30,7 @@ double vespalib_eval_approx(double a, double b) { return (vespalib::approx_equal double vespalib_eval_relu(double a) { return std::max(a, 0.0); } double vespalib_eval_sigmoid(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); } double vespalib_eval_elu(double a) { return (a < 0) ? std::exp(a) - 1.0 : a; } -double vespalib_eval_bit(double a, double b) { - // must match Bit::f - int8_t value = (int8_t) a; - uint32_t n = (uint32_t) b; - return ((n < 8) && bool(value & (1 << n))) ? 1.0 : 0.0; -} +double vespalib_eval_bit(double a, double b) { return vespalib::eval::extract_bit(a, b); } using vespalib::eval::gbdt::Forest; using resolve_function = double (*)(void *ctx, size_t idx); diff --git a/eval/src/vespa/eval/eval/operation.cpp b/eval/src/vespa/eval/eval/operation.cpp index 36b922539f4..a82a79e6bc4 100644 --- a/eval/src/vespa/eval/eval/operation.cpp +++ b/eval/src/vespa/eval/eval/operation.cpp @@ -3,6 +3,7 @@ #include "operation.h" #include "function.h" #include "key_gen.h" +#include "extract_bit.h" #include <vespa/vespalib/util/approx.h> #include <algorithm> @@ -50,12 +51,7 @@ double Relu::f(double a) { return std::max(a, 0.0); } double Sigmoid::f(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); } double Elu::f(double a) { return (a < 0) ? std::exp(a) - 1 : a; } double Erf::f(double a) { return std::erf(a); } -double Bit::f(double a, double b) { - // must match vespalib_eval_bit - int8_t value = (int8_t) a; - uint32_t n = (uint32_t) b; - return ((n < 8) && bool(value & (1 << n))) ? 1.0 : 0.0; -} +double Bit::f(double a, double b) { return extract_bit(a, b); } //----------------------------------------------------------------------------- double Inv::f(double a) { return (1.0 / a); } double Square::f(double a) { return (a * a); } |