From ff9885cd0296700eaa2f35096e9abf3226c6d5ec Mon Sep 17 00:00:00 2001 From: fasiondog Date: Sun, 4 Apr 2021 22:27:54 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=20db=20connect?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../utilities/db_connect/DBConnectBase.h | 45 ++++-- .../utilities/db_connect/SQLStatementBase.h | 16 +- .../hikyuu/utilities/db_connect/TableMacro.h | 140 +++++++++++++----- .../db_connect/mysql/MySQLStatement.h | 5 + .../db_connect/sqlite/SQLiteStatement.h | 5 + .../hikyuu/utilities/test_sqlite.cpp | 6 +- hikyuu_cpp/unit_test/xmake.lua | 8 +- 7 files changed, 158 insertions(+), 67 deletions(-) diff --git a/hikyuu_cpp/hikyuu/utilities/db_connect/DBConnectBase.h b/hikyuu_cpp/hikyuu/utilities/db_connect/DBConnectBase.h index a7e029ec..97fa1166 100644 --- a/hikyuu_cpp/hikyuu/utilities/db_connect/DBConnectBase.h +++ b/hikyuu_cpp/hikyuu/utilities/db_connect/DBConnectBase.h @@ -86,9 +86,11 @@ public: * driver->save(a); * } * @endcode + * @param item 待保持的记录 + * @param autotrans 启动事务 */ template - void save(T& item); + void save(T& item, bool autotrans = true); /** * 批量保存 @@ -194,15 +196,37 @@ inline int DBConnectBase::queryInt(const string& query) { //------------------------------------------------------------------------- template -void DBConnectBase::save(T& item) { - if (item.id() == 0) { - SQLStatementPtr st = getStatement(T::getInsertSQL()); - item.save(st); - st->exec(); - } else { - SQLStatementPtr st = getStatement(T::getUpdateSQL()); - item.update(st); - st->exec(); +void DBConnectBase::save(T& item, bool autotrans) { + SQLStatementPtr st = + item.id() == 0 ? getStatement(T::getInsertSQL()) : getStatement(T::getUpdateSQL()); + if (autotrans) { + transaction(); + } + + try { + if (item.id() == 0) { + item.save(st); + st->exec(); + item.id(st->getLastRowid()); + } else { + item.update(st); + st->exec(); + } + + if (autotrans) { + commit(); + } + } catch (std::exception& e) { + if (autotrans) { + rollback(); + } + HKU_THROW("failed save! sql: {}! {}", st->getSqlString(), e.what()); + + } catch (...) { + if (autotrans) { + rollback(); + } + HKU_THROW("failed save! sql: {}! Unknown error!", st->getSqlString()); } } @@ -222,6 +246,7 @@ void DBConnectBase::batchSave(InputIterator first, InputIterator last, bool auto for (InputIterator iter = first; iter != last; ++iter) { iter->save(st); st->exec(); + iter->id(st->getLastRowid()); } if (autotrans) { diff --git a/hikyuu_cpp/hikyuu/utilities/db_connect/SQLStatementBase.h b/hikyuu_cpp/hikyuu/utilities/db_connect/SQLStatementBase.h index 28ad0731..7b025b2a 100644 --- a/hikyuu_cpp/hikyuu/utilities/db_connect/SQLStatementBase.h +++ b/hikyuu_cpp/hikyuu/utilities/db_connect/SQLStatementBase.h @@ -87,7 +87,7 @@ public: void bind(int idx, const T&, const Args&... rest); /** 获取执行INSERT时最后插入记录的 rowid,非线程安全 */ - uint64_t getLastRowid() const; + uint64_t getLastRowid(); /** 获取表格列数 */ int getNumColumns() const; @@ -116,9 +116,10 @@ public: //------------------------------------------------------------------------- // 子类接口 //------------------------------------------------------------------------- - virtual bool sub_isValid() const = 0; ///< 子类接口 @see isValid - virtual void sub_exec() = 0; ///< 子类接口 @see exec - virtual bool sub_moveNext() = 0; ///< 子类接口 @see moveNext + virtual bool sub_isValid() const = 0; ///< 子类接口 @see isValid + virtual void sub_exec() = 0; ///< 子类接口 @see exec + virtual bool sub_moveNext() = 0; ///< 子类接口 @see moveNext + virtual uint64_t sub_getLastRowid() = 0; ///< 子类接口 @see getLastRowid(); virtual void sub_bindNull(int idx) = 0; ///< 子类接口 @see bind virtual void sub_bindInt(int idx, int64_t value) = 0; ///< 子类接口 @see bind @@ -138,14 +139,13 @@ private: protected: DBConnectBase* m_driver; ///< 数据库连接 string m_sql_string; ///< 原始 SQL 语句 - uint64_t m_last_rowid; ///< INSERT时获取最后插入记录的rowid }; /** @ingroup DBConnect */ typedef shared_ptr SQLStatementPtr; inline SQLStatementBase ::SQLStatementBase(DBConnectBase* driver, const string& sql_statement) -: m_driver(driver), m_sql_string(sql_statement), m_last_rowid(0) { +: m_driver(driver), m_sql_string(sql_statement) { HKU_CHECK(driver, "driver is null!"); } @@ -195,8 +195,8 @@ inline void SQLStatementBase::bindBlob(int idx, const string& item) { sub_bindBlob(idx, item); } -inline uint64_t SQLStatementBase::getLastRowid() const { - return m_last_rowid; +inline uint64_t SQLStatementBase::getLastRowid() { + return sub_getLastRowid(); } inline int SQLStatementBase::getNumColumns() const { diff --git a/hikyuu_cpp/hikyuu/utilities/db_connect/TableMacro.h b/hikyuu_cpp/hikyuu/utilities/db_connect/TableMacro.h index f7f06751..7a73aa0d 100644 --- a/hikyuu_cpp/hikyuu/utilities/db_connect/TableMacro.h +++ b/hikyuu_cpp/hikyuu/utilities/db_connect/TableMacro.h @@ -17,12 +17,15 @@ namespace hku { #define TABLE_BIND1(table, f1) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`) values (?)"; \ } \ @@ -44,12 +47,15 @@ public: \ #define TABLE_BIND2(table, f1, f2) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`) values (?,?)"; \ } \ @@ -71,12 +77,15 @@ public: \ #define TABLE_BIND3(table, f1, f2, f3) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`) values (?,?,?)"; \ } \ @@ -98,12 +107,15 @@ public: #define TABLE_BIND4(table, f1, f2, f3, f4) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 \ "`) values (?,?,?,?)"; \ @@ -127,12 +139,15 @@ public: #define TABLE_BIND5(table, f1, f2, f3, f4, f5) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 "`,`" #f5 \ "`) values (?,?,?,?,?)"; \ @@ -156,12 +171,15 @@ public: #define TABLE_BIND6(table, f1, f2, f3, f4, f5, f6) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 "`,`" #f5 "`,`" #f6 \ "`) values (?,?,?,?,?,?)"; \ @@ -186,12 +204,15 @@ public: #define TABLE_BIND7(table, f1, f2, f3, f4, f5, f6, f7) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 "`,`" #f5 "`,`" #f6 \ "`,`" #f7 "`) values (?,?,?,?,?,?,?)"; \ @@ -216,12 +237,15 @@ public: #define TABLE_BIND8(table, f1, f2, f3, f4, f5, f6, f7, f8) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 "`,`" #f5 "`,`" #f6 \ "`,`" #f7 "`,`" #f8 "`) values (?,?,?,?,?,?,?,?)"; \ @@ -246,12 +270,15 @@ public: #define TABLE_BIND9(table, f1, f2, f3, f4, f5, f6, f7, f8, f9) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 "`,`" #f5 "`,`" #f6 \ "`,`" #f7 "`,`" #f8 "`,`" #f9 "`) values (?,?,?,?,?,?,?,?,?)"; \ @@ -276,12 +303,15 @@ public: #define TABLE_BIND10(table, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 "`,`" #f5 "`,`" #f6 \ "`,`" #f7 "`,`" #f8 "`,`" #f9 "`,`" #f10 "`) values (?,?,?,?,?,?,?,?,?,?)"; \ @@ -306,12 +336,15 @@ public: #define TABLE_BIND11(table, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 "`,`" #f5 "`,`" #f6 \ "`,`" #f7 "`,`" #f8 "`,`" #f9 "`,`" #f10 "`,`" #f11 \ @@ -338,12 +371,15 @@ public: #define TABLE_BIND12(table, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 "`,`" #f5 "`,`" #f6 \ "`,`" #f7 "`,`" #f8 "`,`" #f9 "`,`" #f10 "`,`" #f11 "`,`" #f12 \ @@ -370,12 +406,15 @@ public: #define TABLE_BIND13(table, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 "`,`" #f5 "`,`" #f6 \ "`,`" #f7 "`,`" #f8 "`,`" #f9 "`,`" #f10 "`,`" #f11 "`,`" #f12 "`,`" #f13 \ @@ -403,12 +442,15 @@ public: #define TABLE_BIND14(table, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 "`,`" #f5 "`,`" #f6 \ "`,`" #f7 "`,`" #f8 "`,`" #f9 "`,`" #f10 "`,`" #f11 "`,`" #f12 "`,`" #f13 \ @@ -436,12 +478,15 @@ public: #define TABLE_BIND15(table, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 "`,`" #f5 "`,`" #f6 \ "`,`" #f7 "`,`" #f8 "`,`" #f9 "`,`" #f10 "`,`" #f11 "`,`" #f12 "`,`" #f13 \ @@ -469,12 +514,15 @@ public: #define TABLE_BIND16(table, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 "`,`" #f5 "`,`" #f6 \ "`,`" #f7 "`,`" #f8 "`,`" #f9 "`,`" #f10 "`,`" #f11 "`,`" #f12 "`,`" #f13 \ @@ -505,12 +553,15 @@ public: #define TABLE_BIND17(table, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, \ f17) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 "`,`" #f5 "`,`" #f6 \ "`,`" #f7 "`,`" #f8 "`,`" #f9 "`,`" #f10 "`,`" #f11 "`,`" #f12 "`,`" #f13 \ @@ -543,12 +594,15 @@ public: #define TABLE_BIND18(table, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, \ f17, f18) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 "`,`" #f5 "`,`" #f6 \ "`,`" #f7 "`,`" #f8 "`,`" #f9 "`,`" #f10 "`,`" #f11 "`,`" #f12 "`,`" #f13 \ @@ -582,12 +636,15 @@ public: #define TABLE_BIND19(table, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, \ f17, f18, f19) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 "`,`" #f5 "`,`" #f6 \ "`,`" #f7 "`,`" #f8 "`,`" #f9 "`,`" #f10 "`,`" #f11 "`,`" #f12 "`,`" #f13 \ @@ -621,12 +678,15 @@ public: #define TABLE_BIND20(table, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, \ f17, f18, f19, f20) \ private: \ - int64_t m_id = 0; \ + uint64_t m_id = 0; \ \ public: \ - int64_t id() const { \ + uint64_t id() const { \ return m_id; \ } \ + void id(uint64_t val) { \ + m_id = val; \ + } \ static const char* getInsertSQL() { \ return "insert into `" #table "` (`" #f1 "`,`" #f2 "`,`" #f3 "`,`" #f4 "`,`" #f5 "`,`" #f6 \ "`,`" #f7 "`,`" #f8 "`,`" #f9 "`,`" #f10 "`,`" #f11 "`,`" #f12 "`,`" #f13 \ diff --git a/hikyuu_cpp/hikyuu/utilities/db_connect/mysql/MySQLStatement.h b/hikyuu_cpp/hikyuu/utilities/db_connect/mysql/MySQLStatement.h index 1205d9c2..fbd3b035 100755 --- a/hikyuu_cpp/hikyuu/utilities/db_connect/mysql/MySQLStatement.h +++ b/hikyuu_cpp/hikyuu/utilities/db_connect/mysql/MySQLStatement.h @@ -35,6 +35,7 @@ public: virtual bool sub_isValid() const override; virtual void sub_exec() override; virtual bool sub_moveNext() override; + virtual uint64_t sub_getLastRowid() override; virtual void sub_bindNull(int idx) override; virtual void sub_bindInt(int idx, int64_t value) override; @@ -67,6 +68,10 @@ private: vector m_result_error; }; +inline uint64_t MySQLStatement::sub_getLastRowid() { + return mysql_stmt_insert_id(m_stmt); +} + } // namespace hku #endif /* HIYUU_DB_CONNECT_MYSQL_MYSQLSTATEMENT_H */ \ No newline at end of file diff --git a/hikyuu_cpp/hikyuu/utilities/db_connect/sqlite/SQLiteStatement.h b/hikyuu_cpp/hikyuu/utilities/db_connect/sqlite/SQLiteStatement.h index 5c3510fe..042b4cc1 100644 --- a/hikyuu_cpp/hikyuu/utilities/db_connect/sqlite/SQLiteStatement.h +++ b/hikyuu_cpp/hikyuu/utilities/db_connect/sqlite/SQLiteStatement.h @@ -36,6 +36,7 @@ public: virtual bool sub_isValid() const override; virtual void sub_exec() override; virtual bool sub_moveNext() override; + virtual uint64_t sub_getLastRowid() override; virtual void sub_bindNull(int idx) override; virtual void sub_bindInt(int idx, int64_t value) override; @@ -65,6 +66,10 @@ inline bool SQLiteStatement::sub_isValid() const { return m_stmt ? true : false; } +inline uint64_t SQLiteStatement::sub_getLastRowid() { + return sqlite3_last_insert_rowid(m_db); +} + } /* namespace hku */ #endif /* HIKYUU_DB_CONNECT_SQLITE_SQLITESTATEMENT_H */ \ No newline at end of file diff --git a/hikyuu_cpp/unit_test/hikyuu/utilities/test_sqlite.cpp b/hikyuu_cpp/unit_test/hikyuu/utilities/test_sqlite.cpp index 08343f86..858d7b14 100644 --- a/hikyuu_cpp/unit_test/hikyuu/utilities/test_sqlite.cpp +++ b/hikyuu_cpp/unit_test/hikyuu/utilities/test_sqlite.cpp @@ -154,7 +154,7 @@ TEST_CASE("test_sqlite") { con->exec("drop table ttt"); } - { + /*{ con->exec( R"(CREATE TABLE "perf_test" ( "id" INTEGER UNIQUE, @@ -180,9 +180,9 @@ TEST_CASE("test_sqlite") { t_list.push_back(PerformancTest(std::to_string(i), i)); } { - SPEND_TIME_MSG(batch, "insert mysql, total records: {}", total); + SPEND_TIME_MSG(batch, "insert sqlite, total records: {}", total); con->batchSave(t_list.begin(), t_list.end()); } con->exec("drop table perf_test"); - } + }*/ } diff --git a/hikyuu_cpp/unit_test/xmake.lua b/hikyuu_cpp/unit_test/xmake.lua index be2cb769..110fe929 100644 --- a/hikyuu_cpp/unit_test/xmake.lua +++ b/hikyuu_cpp/unit_test/xmake.lua @@ -14,7 +14,7 @@ target("unit-test") set_default(false) end - add_packages("fmt", "spdlog", "doctest", "mysql") + add_packages("fmt", "spdlog", "doctest", "mysql", "sqlite3") add_includedirs("..") @@ -59,7 +59,7 @@ target("small-test") if get_config("test") == "all" then set_default(false) end - add_packages("fmt", "spdlog", "doctest", "mysql") + add_packages("fmt", "spdlog", "doctest", "mysql", "sqlite3") add_includedirs("..") --add_defines("BOOST_TEST_DYN_LINK") @@ -92,8 +92,4 @@ target("small-test") -- add files add_files("./hikyuu/hikyuu/**.cpp"); add_files("./hikyuu/test_main.cpp") - - add_packages("sqlite3") - add_files("./hikyuu/utilities/test_sqlite.cpp") - target_end() \ No newline at end of file