From a7f49d893e252c4457705f6bcde436cf89f3692a Mon Sep 17 00:00:00 2001 From: ihmc3jn09hk Date: Sun, 26 Apr 2020 23:23:03 +0800 Subject: [PATCH] Support url safe base64 codec (#417) Co-authored-by: antao --- lib/inc/drogon/utils/Utilities.h | 3 +- lib/src/Utilities.cc | 79 ++++++++++++++++++++++++++------ unittest/Base64Unittest.cpp | 57 +++++++++++++++++++++++ unittest/CMakeLists.txt | 4 +- 4 files changed, 127 insertions(+), 16 deletions(-) create mode 100644 unittest/Base64Unittest.cpp diff --git a/lib/inc/drogon/utils/Utilities.h b/lib/inc/drogon/utils/Utilities.h index a9c070fc..53f89ddb 100644 --- a/lib/inc/drogon/utils/Utilities.h +++ b/lib/inc/drogon/utils/Utilities.h @@ -66,7 +66,8 @@ std::string getUuid(); /// Encode the string to base64 format. std::string base64Encode(const unsigned char *bytes_to_encode, - unsigned int in_len); + unsigned int in_len, + bool url_safe = false); /// Decode the base64 format string. std::string base64Decode(const std::string &encoded_string); diff --git a/lib/src/Utilities.cc b/lib/src/Utilities.cc index 95b726d3..e523c201 100644 --- a/lib/src/Utilities.cc +++ b/lib/src/Utilities.cc @@ -53,9 +53,55 @@ static const std::string base64Chars = "abcdefghijklmnopqrstuvwxyz" "0123456789+/"; +static const std::string urlBase64Chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789-_"; +class Base64CharMap +{ + public: + Base64CharMap() + { + char index = 0; + for (char c = 'A'; c <= 'Z'; ++c) + { + charMap_[c] = index++; + } + for (char c = 'a'; c <= 'z'; ++c) + { + charMap_[c] = index++; + } + for (char c = '0'; c <= '9'; ++c) + { + charMap_[c] = index++; + } + charMap_['+'] = charMap_['-'] = index++; + charMap_['/'] = charMap_['_'] = index; + charMap_[0] = 0xff; + } + char getIndex(const char c) const noexcept + { + return charMap_[c]; + } + + private: + char charMap_[256]{0}; +}; +const static Base64CharMap base64CharMap; + static inline bool isBase64(unsigned char c) { - return (isalnum(c) || (c == '+') || (c == '/')); + if (isalnum(c)) + return true; + switch (c) + { + case '+': + case '/': + case '-': + case '_': + return true; + } + return false; } bool isInteger(const std::string &str) @@ -314,13 +360,16 @@ std::string getUuid() } std::string base64Encode(const unsigned char *bytes_to_encode, - unsigned int in_len) + unsigned int in_len, + bool url_safe) { std::string ret; int i = 0; unsigned char char_array_3[3]; unsigned char char_array_4[4]; + const std::string &charSet = url_safe ? urlBase64Chars : base64Chars; + while (in_len--) { char_array_3[i++] = *(bytes_to_encode++); @@ -334,7 +383,7 @@ std::string base64Encode(const unsigned char *bytes_to_encode, char_array_4[3] = char_array_3[2] & 0x3f; for (i = 0; (i < 4); ++i) - ret += base64Chars[char_array_4[i]]; + ret += charSet[char_array_4[i]]; i = 0; } } @@ -352,12 +401,11 @@ std::string base64Encode(const unsigned char *bytes_to_encode, char_array_4[3] = char_array_3[2] & 0x3f; for (int j = 0; (j < i + 1); ++j) - ret += base64Chars[char_array_4[j]]; + ret += charSet[char_array_4[j]]; while ((i++ < 3)) ret += '='; } - return ret; } @@ -378,8 +426,9 @@ std::vector base64DecodeToVector(const std::string &encoded_string) if (i == 4) { for (i = 0; i < 4; ++i) - char_array_4[i] = - static_cast(base64Chars.find(char_array_4[i])); + { + char_array_4[i] = base64CharMap.getIndex(char_array_4[i]); + } char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); @@ -399,8 +448,9 @@ std::vector base64DecodeToVector(const std::string &encoded_string) char_array_4[j] = 0; for (int j = 0; j < 4; ++j) - char_array_4[j] = - static_cast(base64Chars.find(char_array_4[j])); + { + char_array_4[j] = base64CharMap.getIndex(char_array_4[j]); + } char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); @@ -431,9 +481,9 @@ std::string base64Decode(const std::string &encoded_string) if (i == 4) { for (i = 0; i < 4; ++i) - char_array_4[i] = static_cast( - base64Chars.find(char_array_4[i])); - + { + char_array_4[i] = base64CharMap.getIndex(char_array_4[i]); + } char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + @@ -452,8 +502,9 @@ std::string base64Decode(const std::string &encoded_string) char_array_4[j] = 0; for (int j = 0; j < 4; ++j) - char_array_4[j] = - static_cast(base64Chars.find(char_array_4[j])); + { + char_array_4[j] = base64CharMap.getIndex(char_array_4[j]); + } char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); diff --git a/unittest/Base64Unittest.cpp b/unittest/Base64Unittest.cpp new file mode 100644 index 00000000..5dbd9c45 --- /dev/null +++ b/unittest/Base64Unittest.cpp @@ -0,0 +1,57 @@ +#include +#include +#include + +TEST(Base64, base64) +{ + std::string in{"drogon framework"}; + auto out = drogon::utils::base64Encode((const unsigned char *)in.data(), + in.length()); + auto out2 = drogon::utils::base64Decode(out); + EXPECT_EQ(out, "ZHJvZ29uIGZyYW1ld29yaw=="); + EXPECT_EQ(out2, in); +} + +TEST(Base64, base64_long_string) +{ + std::string in; + for (int i = 0; i < 100000; ++i) + { + in.append(1, char(i)); + } + auto out = drogon::utils::base64Encode((const unsigned char *)in.data(), + in.length()); + auto out2 = drogon::utils::base64Decode(out); + EXPECT_EQ(out2, in); +} + +TEST(Base64, base64_url) +{ + std::string in{"drogon framework"}; + auto out = drogon::utils::base64Encode((const unsigned char *)in.data(), + in.length(), + true); + auto out2 = drogon::utils::base64Decode(out); + EXPECT_EQ(out, "ZHJvZ29uIGZyYW1ld29yaw=="); + EXPECT_EQ(out2, in); +} + +TEST(Base64, base64_long_string_url) +{ + std::string in; + for (int i = 0; i < 100000; ++i) + { + in.append(1, char(i)); + } + auto out = drogon::utils::base64Encode((const unsigned char *)in.data(), + in.length(), + true); + auto out2 = drogon::utils::base64Decode(out); + EXPECT_EQ(out2, in); +} + +int main(int argc, char **argv) +{ + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/unittest/CMakeLists.txt b/unittest/CMakeLists.txt index c3fa9677..e8ce0659 100644 --- a/unittest/CMakeLists.txt +++ b/unittest/CMakeLists.txt @@ -4,6 +4,7 @@ add_executable(gzip_unittest GzipUnittest.cpp) add_executable(md5_unittest MD5Unittest.cpp ../lib/src/ssl_funcs/Md5.cc) add_executable(sha1_unittest SHA1Unittest.cpp ../lib/src/ssl_funcs/Sha1.cc) add_executable(ostringstream_unittest OStringStreamUnitttest.cpp) +add_executable(base64_unittest Base64Unittest.cpp) if(Brotli_FOUND) add_executable(brotli_unittest BrotliUnittest.cpp) endif() @@ -14,7 +15,8 @@ set(UNITTEST_TARGETS gzip_unittest md5_unittest sha1_unittest - ostringstream_unittest) + ostringstream_unittest + base64_unittest) if(Brotli_FOUND) set(UNITTEST_TARGETS ${UNITTEST_TARGETS} brotli_unittest) endif()