dumpbinlog

tsthght 2019-02-21 15:39:58 +08:00
parent c84e535e7f
commit 6ff354b05a
35 changed files with 3502 additions and 0 deletions


@@ -0,0 +1,27 @@
# Comments must be on their own line; a comment appended after a value is parsed as part of the value
# Do not quote values
[BINLOG_MYSQL]
host=10.120.12.55
port=3406
user=cetus_app
password=cetus_app123
[OUTPUT_MYSQL]
host=10.120.12.55
port=7002
user=cetus_app
password=cetus_app123
[DEFAULT]
#log_file='bj-10-238-7-7-bin.000010',
#log_pos=4,
auto_position=6a765e25-318d-11e8-a762-246e9616dff4:1-3945180
skip_schemas=proxy_heart_beat
log_level=DEBUG
#ignore_ddl=true
#only_sharding_table=~/sharding.json
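A minimal sketch of how these settings are read (using Python's standard configparser, which is what cetus_config.py below relies on). Note that configparser also exposes keys placed under [DEFAULT] as defaults for every other section:

    import configparser

    conf = configparser.ConfigParser()
    conf.read('binlog.conf')
    host = conf['BINLOG_MYSQL']['host']              # values are plain, unquoted strings
    port = int(conf['BINLOG_MYSQL']['port'])         # numbers must be converted explicitly
    auto_pos = conf['DEFAULT'].get('auto_position')  # optional keys come back as None when absent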


@@ -0,0 +1,77 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import json
import configparser
import logging
class CetusConf:
def __init__(self, sharding_conf_file):
with open(os.path.expanduser(sharding_conf_file)) as f:
self.shard_config = json.load(f)
def sharded_and_single_tables(self):
return self.sharded_tables() + self.single_tables()
def sharded_tables(self):
return [(t['db'], t['table']) for t in self.shard_config['table']]
def single_tables(self):
return [(t['db'], t['table']) for t in self.shard_config['single_tables']]
class BinlogConfig:
BINLOG_MYSQL = {}
OUTPUT_MYSQL = {}
SKIP_SCHEMAS = None
BINLOG_POS = {}
IGNORE_DDL = False
LOG_LEVEL = logging.INFO
ONLY_SHARDING_TABLE = None
ALLOW_TABLES = None # tables loaded from the file referenced by ONLY_SHARDING_TABLE
def __init__(self, filename):
conf = configparser.ConfigParser()
conf.read(filename)
self.BINLOG_MYSQL['host'] = conf['BINLOG_MYSQL']['host']
self.BINLOG_MYSQL['port'] = int(conf['BINLOG_MYSQL']['port'])
self.BINLOG_MYSQL['user'] = conf['BINLOG_MYSQL']['user']
self.BINLOG_MYSQL['passwd'] = conf['BINLOG_MYSQL']['password']
self.OUTPUT_MYSQL['host'] = conf['OUTPUT_MYSQL']['host']
self.OUTPUT_MYSQL['port'] = int(conf['OUTPUT_MYSQL']['port'])
self.OUTPUT_MYSQL['user'] = conf['OUTPUT_MYSQL']['user']
self.OUTPUT_MYSQL['password'] = conf['OUTPUT_MYSQL']['password']
skip_schemas = conf['DEFAULT'].get('skip_schemas', None)
if skip_schemas:
self.SKIP_SCHEMAS = [s.strip() for s in skip_schemas.split(',')]
sharding_file = conf['DEFAULT'].get('only_sharding_table', None)
if sharding_file:
self.ONLY_SHARDING_TABLE = sharding_file
pos = conf['DEFAULT'].get('auto_position', None)
if pos:
self.BINLOG_POS['auto_position'] = pos
log_file = conf['DEFAULT'].get('log_file', None)
if log_file:
self.BINLOG_POS['log_file'] = log_file
log_pos = conf['DEFAULT'].get('log_pos', None)
if log_pos:
self.BINLOG_POS['log_pos'] = int(log_pos)
level = conf['DEFAULT'].get('log_level', 'INFO')
if level == 'DEBUG':
self.LOG_LEVEL = logging.DEBUG
if conf['DEFAULT'].getboolean('ignore_ddl', False):
self.IGNORE_DDL = True
def dump(self):
print(' SKIP_SCHEMAS:', self.SKIP_SCHEMAS)
print(' BINLOG_POS:', self.BINLOG_POS)
print(' IGNORE_DDL:', self.IGNORE_DDL)
print(' LOG_LEVEL:', self.LOG_LEVEL)
print(' ONLY_SHARDING_TABLE:', self.ONLY_SHARDING_TABLE)
print('')
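A usage sketch for the two classes above (file paths are assumptions; the sharding file layout is inferred from the keys CetusConf reads, i.e. top-level "table" and "single_tables" lists whose entries carry "db" and "table"):

    config = BinlogConfig('binlog.conf')        # parses the sections shown in binlog.conf
    config.dump()

    # hypothetical sharding.json accepted by CetusConf:
    # {"table":         [{"db": "db1", "table": "t_orders"}],
    #  "single_tables": [{"db": "db1", "table": "t_users"}]}
    cetus = CetusConf('~/sharding.json')
    tables = cetus.sharded_and_single_tables()  # [('db1', 't_orders'), ('db1', 't_users')]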

dumpbinlog-tool/dumpbinlog.py Executable file

@@ -0,0 +1,84 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from executor import TransactionDispatcher
from transaction import Transaction, BinlogTrxReader
import cetus_config
import logging
import queue
import sys
import os
import recovery
import logger
import argparse
transaction_queue = queue.Queue()
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-d', dest='base_dir', type=str, default=os.getcwd(), help='Set working directory')
parser.add_argument('-n', dest='dry_run', action='store_true', help='Do not write to MySQL, just print')
args = parser.parse_args()
os.chdir(args.base_dir)
print("Set working dir: ", args.base_dir)
config = cetus_config.BinlogConfig('binlog.conf')
config.dump()
logger.init_logger(config.LOG_LEVEL)
main_logger = logging.getLogger('main')
main_logger.info('APP START')
if config.ONLY_SHARDING_TABLE:
conf = cetus_config.CetusConf(config.ONLY_SHARDING_TABLE)
config.ALLOW_TABLES = conf.sharded_and_single_tables()
prev_execution = None
pos_config = config.BINLOG_POS
if os.path.exists('progress.log'):
prev_execution = recovery.read('progress.log')
if prev_execution is not None:
print('Found previous execution log, recover from it.')
main_logger.info('Found previous execution log, recover from it.')
#pos_config = {'auto_position': prev_execution.gtid_executed}
pos_config = {'log_file': prev_execution.start_log_file,
'log_pos': prev_execution.start_log_pos}
stream = BinlogTrxReader(
config=config,
server_id=100,
blocking=True,
resume_stream=True,
**pos_config
)
trx_queue = queue.Queue(500)
dispatcher = TransactionDispatcher(config.OUTPUT_MYSQL,
trx_queue=trx_queue,
prev_execution=prev_execution,
max_connections=20)
dispatcher.start()
try:
for trx in stream:
if args.dry_run:
print(trx)
else:
trx_queue.put(trx) # will block if full
except KeyboardInterrupt:
print('KeyboardInterrupt, Control-C or error in threads, exiting')
main_logger.info('KeyboardInterrupt, Control-C or error in threads, exiting')
except Exception as e:
main_logger.info(e)
raise
finally:
stream.close()
dispatcher.quit()
dispatcher.join()
main_logger.info('APP END')
if __name__ == "__main__":
main()
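A typical invocation would therefore be ./dumpbinlog.py -d /path/to/workdir to replay transactions into the [OUTPUT_MYSQL] server, or ./dumpbinlog.py -d /path/to/workdir -n for a dry run that only prints them. The working directory is expected to contain binlog.conf, and progress.log (written by the 'trx' logger) is what allows the tool to resume after a restart.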

dumpbinlog-tool/executor.py Normal file

@@ -0,0 +1,270 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
from transaction import Transaction
import queue
import threading
import _thread
import pymysql
import os
import time
import logging
import signal
import traceback
from recovery import PreviousExecution
g_jobs = queue.Queue()
g_jobs_done = queue.Queue()
g_shutdown = False
def force_exit():
os._exit(1)
#os.kill(os.getpid(), signal.SIGKILL)
class MysqlConnection:
'''Wrapper class for pymysql.Connection.
It has a status and remembers the XID of a prepared XA transaction.
'''
XA_PREPARED = 'XA_PREPARED'
IDLE = 'IDLE'
BUSY = 'BUSY'
def __init__(self, conn: pymysql.Connection):
self.connection = conn
self.status = self.IDLE
self._xid = None
self.xa_prepare_ignored = False
@property
def xid(self):
assert self.status == self.XA_PREPARED
return self._xid
@xid.setter
def xid(self, xid):
self._xid = xid
def __repr__(self):
return 'mysql: '+str(self.connection.thread_id())
class Job:
def __init__(self, trx: Transaction, conn: MysqlConnection):
'''Runs in the Dispatcher thread.
Status changes always happen in the Dispatcher thread.
'''
self._trx = trx
self._conn = conn
self._conn.status = MysqlConnection.BUSY
def execute(self):
'''run in worker threads
'''
global g_shutdown
print("Job::execute", self)
print(self._trx.brief())
main_logger = logging.getLogger('main')
trx_logger = logging.getLogger('trx')
with self._conn.connection.cursor() as cursor:
for sql in self._trx.sql_list:
try:
main_logger.debug(str(self) + ' executing: ' + sql)
cursor.execute(sql)
main_logger.debug(str(self) + ' OK')
except pymysql.err.Error as e:
thrd = threading.current_thread().name
main_logger.error(thrd + ' Failed SQL: ' + sql)
main_logger.error(thrd + str(e))
print(thrd + ' Failed SQL: ' + sql)
print(thrd + str(e))
if not g_shutdown:
g_shutdown = True
main_logger.warning(thrd + ' Sending KeyboardInterrupt to main thread')
print(thrd + ' Sending KeyboardInterrupt to main thread')
_thread.interrupt_main()
return
trx_logger.info(self._trx)
def after_execute(self):
'''run in Dispatcher thread
'''
#print("after_execute", self)
if self._trx.type == Transaction.TRX_XA_PREPARE:
self._conn.xid = self._trx.XID
self._conn.status = MysqlConnection.XA_PREPARED
self._conn.xa_prepare_ignored = (len(self._trx.sql_list) == 0)
main_logger = logging.getLogger('main')
main_logger.debug('Hold conn: {} for XA PREPARE: {}'.format(
self._conn, self._trx.XID))
else:
self._conn.status = MysqlConnection.IDLE
self._conn.xa_prepare_ignored = False
def __repr__(self):
return '{} {} {}'.format(threading.current_thread().name,
self._trx.gtid, self._conn)
def _worker():
while True:
job = g_jobs.get(block=True)
if job is not None:
job.execute()
g_jobs_done.put(job)
else:
# thread exit
main_logger = logging.getLogger('main')
main_logger.info(threading.current_thread().name + ' exit')
g_jobs.put(None)
break
class TransactionDispatcher(threading.Thread):
def __init__(self, mysql_config: dict,
trx_queue: queue.Queue=None,
prev_execution: PreviousExecution=None,
max_connections: int=20):
super().__init__(name='Dispatcher-Thread')
self._mysql_config = mysql_config
self._mysql_config['autocommit'] = True
self._connections = []
self._max_connections = max_connections
self._trx_queue = trx_queue
self._threads = set()
self._create_threads()
self._running_transactions = set()
self._prev_execution = prev_execution
def quit(self):
global g_shutdown
g_shutdown = True
if self._trx_queue.empty():
self._trx_queue.put(None)
def _create_threads(self):
for i in range(self._max_connections):
new_th = threading.Thread(target=_worker) # TODO: weak ref
new_th.start()
self._threads.add(new_th)
main_logger = logging.getLogger('main')
main_logger.info('New thread: ' + new_th.name)
def _can_do_parallel(self, trx: Transaction) -> bool:
if len(self._running_transactions) == 0:
return True
# interleaved with any running transaction
return any(trx.interleaved(rt) for rt in self._running_transactions)
def _get_appropriate_connection(self, trx: Transaction) -> MysqlConnection:
main_logger = logging.getLogger('main')
if len(self._connections) < self._max_connections:
try:
conn = pymysql.connect(**self._mysql_config)
except pymysql.err.OperationalError as e:
main_logger.error(e)
traceback.print_exc()
force_exit()
mycon = MysqlConnection(conn)
self._connections.append(mycon)
main_logger.info('New conn: ' + str(mycon))
if trx.type == Transaction.TRX_XA_COMMIT:
conn = next((c for c in self._connections \
if c.status==MysqlConnection.XA_PREPARED \
and c.xid==trx.XID), None)
if conn is None:
main_logger.warning('XA commit {} cannot be paired yet, waiting...'.format(trx.XID))
else:
main_logger.debug('Got pair conn: {} for XA COMMIT: {}'.format(conn, trx.XID))
# if the corresponding XA_PREPARE transaction is ignored,
# also ignore the COMMIT. TODO: should not modify trx here
if conn and conn.xa_prepare_ignored:
main_logger.warning('XA prepared ignored, also ignoring XA commit')
trx.sql_list = []
return conn
else:
conn = next((c for c in self._connections \
if c.status==MysqlConnection.IDLE), None)
if conn is None:
main_logger.warning('No idle conn for {} {}'.format(trx.gtid, trx.type))
return conn
def _complete_jobs(self) -> int:
count = 0
while True:
try:
'''
Once at least one thread is vacant (count > 0), we get a chance to arrange
new jobs, so drain the rest of the queue in non-blocking mode.
'''
job = g_jobs_done.get(block=(count==0), timeout=0.5)
job.after_execute()
self._running_transactions.remove(job._trx)
count += 1
except queue.Empty:
break
return count
def _drain_jobs(self, filename: str) -> int:
'''The `interval rule` only works inside one binlog file.
When a new file starts, all jobs from the previous file must complete first.
'''
main_logger = logging.getLogger('main')
count = 0
while len(self._running_transactions) > 0:
main_logger.info('Draining jobs in file: %s', filename)
try:
job = g_jobs_done.get(timeout=5)
job.after_execute()
self._running_transactions.remove(job._trx)
count += 1
except queue.Empty:
break
return count
def _join_all_threads(self):
g_jobs.put(None)
while any(t.is_alive() for t in self._threads):
time.sleep(0.5)
def run(self):
main_logger = logging.getLogger('main')
global g_shutdown
current_file = ''
while True:
if g_shutdown:
self._join_all_threads()
main_logger.info(threading.current_thread().name + ' exit')
break
trx = self._trx_queue.get()
if trx is None or \
(self._prev_execution and self._prev_execution.executed(trx.gtid)):
continue
if trx.binlog_file != current_file:
self._drain_jobs(current_file)
current_file = trx.binlog_file
if self._can_do_parallel(trx): # ensure the parallelism check runs before acquiring a connection
conn = self._get_appropriate_connection(trx) # can return None
else:
conn = None
while conn is None or not self._can_do_parallel(trx):
self._complete_jobs() # will block if no available thread
if conn is None:
conn = self._get_appropriate_connection(trx)
global g_shutdown
if g_shutdown is True:
break
else:
self._running_transactions.add(trx)
g_jobs.put(Job(trx, conn))
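For orientation, a sketch of the parallelism test the dispatcher depends on. Transaction.interleaved() is defined in transaction.py, which is not part of this excerpt, so the body below is an assumption based on MySQL 5.7's logical-clock fields (last_committed and sequence_number, read by GtidEvent later in this diff): two transactions may be applied concurrently when their commit windows overlap.

    # hypothetical sketch only; the real predicate lives in transaction.py
    def interleaved(a, b):
        # a and b are assumed to carry the logical-clock values taken from GtidEvent
        return (a.last_committed < b.sequence_number and
                b.last_committed < a.sequence_number)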

dumpbinlog-tool/logger.py Normal file

@@ -0,0 +1,26 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
import logging
main_formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
trx_formatter = logging.Formatter('%(asctime)s %(message)s')
def _setup_logger(name, log_file, formatter, level=logging.INFO):
"""Function setup as many loggers as you want"""
import os
print('logger work dir:', os.getcwd())
handler = logging.FileHandler(log_file)
handler.setFormatter(formatter)
logger = logging.getLogger(name)
logger.setLevel(level)
logger.addHandler(handler)
return logger
def init_logger(level):
_setup_logger('trx', 'progress.log', trx_formatter)
_setup_logger('main', 'sqldump.log', main_formatter, level=level)
# use with logging.getLogger(name)
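A short usage sketch: after init_logger(), the two named loggers write to separate files in the working directory, and progress.log doubles as the recovery journal that dumpbinlog.py reads on startup.

    import logging
    import logger

    logger.init_logger(logging.DEBUG)
    logging.getLogger('main').info('diagnostic message')       # appended to sqldump.log
    logging.getLogger('trx').info('executed transaction ...')  # appended to progress.log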


@@ -0,0 +1,23 @@
"""
Python MySQL Replication:
Pure Python Implementation of MySQL replication protocol build on top of
PyMYSQL.
Licence
=======
Copyright 2012 Julien Duponchelle
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from .binlogstream import BinLogStreamReader


@@ -0,0 +1,7 @@
import sys
if sys.version_info > (3,):
text_type = str
else:
text_type = unicode


@@ -0,0 +1,559 @@
# -*- coding: utf-8 -*-
import pymysql
import struct
from pymysql.constants.COMMAND import COM_BINLOG_DUMP, COM_REGISTER_SLAVE
from pymysql.cursors import DictCursor
from pymysql.util import int2byte
from .packet import BinLogPacketWrapper
from .constants.BINLOG import TABLE_MAP_EVENT, ROTATE_EVENT, FORMAT_DESCRIPTION_EVENT
from .gtid import GtidSet
from .event import (
QueryEvent, RotateEvent, FormatDescriptionEvent,
XidEvent, GtidEvent, StopEvent, XAPrepareEvent,
BeginLoadQueryEvent, ExecuteLoadQueryEvent,
HeartbeatLogEvent, NotImplementedEvent)
from .exceptions import BinLogNotEnabled
from .row_event import (
UpdateRowsEvent, WriteRowsEvent, DeleteRowsEvent, TableMapEvent)
try:
from pymysql.constants.COMMAND import COM_BINLOG_DUMP_GTID
except ImportError:
# Handle old pymysql versions
# See: https://github.com/PyMySQL/PyMySQL/pull/261
COM_BINLOG_DUMP_GTID = 0x1e
# 2013 Connection Lost
# 2006 MySQL server has gone away
MYSQL_EXPECTED_ERROR_CODES = [2013, 2006]
class ReportSlave(object):
"""Represent the values that you may report when connecting as a slave
to a master. SHOW SLAVE HOSTS related"""
hostname = ''
username = ''
password = ''
port = 0
def __init__(self, value):
"""
Attributes:
value: string, tuple or dict
if a string, it is used as the hostname
if a tuple, it is used as (hostname, user, password, port)
"""
if isinstance(value, (tuple, list)):
try:
self.hostname = value[0]
self.username = value[1]
self.password = value[2]
self.port = int(value[3])
except IndexError:
pass
elif isinstance(value, dict):
for key in ['hostname', 'username', 'password', 'port']:
try:
setattr(self, key, value[key])
except KeyError:
pass
else:
self.hostname = value
def __repr__(self):
return '<ReportSlave hostname=%s username=%s password=%s port=%d>' %\
(self.hostname, self.username, self.password, self.port)
def encoded(self, server_id, master_id=0):
"""
server_id: the slave server-id
master_id: usually 0. Appears as "master id" in SHOW SLAVE HOSTS
on the master. Unknown what else it impacts.
"""
# 1 [15] COM_REGISTER_SLAVE
# 4 server-id
# 1 slaves hostname length
# string[$len] slaves hostname
# 1 slaves user len
# string[$len] slaves user
# 1 slaves password len
# string[$len] slaves password
# 2 slaves mysql-port
# 4 replication rank
# 4 master-id
lhostname = len(self.hostname.encode())
lusername = len(self.username.encode())
lpassword = len(self.password.encode())
packet_len = (1 + # command
4 + # server-id
1 + # hostname length
lhostname +
1 + # username length
lusername +
1 + # password length
lpassword +
2 + # slave mysql port
4 + # replication rank
4) # master-id
MAX_STRING_LEN = 257 # one byte for length + 256 chars
return (struct.pack('<i', packet_len) +
int2byte(COM_REGISTER_SLAVE) +
struct.pack('<L', server_id) +
struct.pack('<%dp' % min(MAX_STRING_LEN, lhostname + 1),
self.hostname.encode()) +
struct.pack('<%dp' % min(MAX_STRING_LEN, lusername + 1),
self.username.encode()) +
struct.pack('<%dp' % min(MAX_STRING_LEN, lpassword + 1),
self.password.encode()) +
struct.pack('<H', self.port) +
struct.pack('<l', 0) +
struct.pack('<l', master_id))
class BinLogStreamReader(object):
"""Connect to replication stream and read event
"""
report_slave = None
def __init__(self, connection_settings, server_id, ctl_connection_settings=None, resume_stream=False,
blocking=False, only_events=None, log_file=None, log_pos=None,
filter_non_implemented_events=True,
ignored_events=None, auto_position=None,
only_tables=None, ignored_tables=None,
only_schemas=None, ignored_schemas=None,
freeze_schema=False, skip_to_timestamp=None,
report_slave=None, slave_uuid=None,
pymysql_wrapper=None,
fail_on_table_metadata_unavailable=False,
slave_heartbeat=None):
"""
Attributes:
ctl_connection_settings: Connection settings for cluster holding schema information
resume_stream: Start from the given position, from the latest event of
the binlog, or from the oldest available event
blocking: Read on stream is blocking
only_events: Array of allowed events
ignored_events: Array of ignored events
log_file: Set replication start log file
log_pos: Set replication start log pos (resume_stream should be true)
auto_position: Use master_auto_position gtid to set position
only_tables: An array with the tables you want to watch (only works
in binlog_format ROW)
ignored_tables: An array with the tables you want to skip
only_schemas: An array with the schemas you want to watch
ignored_schemas: An array with the schemas you want to skip
freeze_schema: If true do not support ALTER TABLE. It's faster.
skip_to_timestamp: Ignore all events until reaching specified timestamp.
report_slave: Report slave in SHOW SLAVE HOSTS.
slave_uuid: Report slave_uuid in SHOW SLAVE HOSTS.
fail_on_table_metadata_unavailable: Should raise exception if we can't get
table information on row_events
slave_heartbeat: (seconds) Should master actively send heartbeat on
connection. This also reduces traffic in GTID replication
on replication resumption (in case there are many events to skip in the
binlog). See MASTER_HEARTBEAT_PERIOD in mysql documentation
for semantics
"""
self.__connection_settings = connection_settings
self.__connection_settings.setdefault("charset", "utf8")
self.__connected_stream = False
self.__connected_ctl = False
self.__resume_stream = resume_stream
self.__blocking = blocking
self._ctl_connection_settings = ctl_connection_settings
if ctl_connection_settings:
self._ctl_connection_settings.setdefault("charset", "utf8")
self.__only_tables = only_tables
self.__ignored_tables = ignored_tables
self.__only_schemas = only_schemas
self.__ignored_schemas = ignored_schemas
self.__freeze_schema = freeze_schema
self.__allowed_events = self._allowed_event_list(
only_events, ignored_events, filter_non_implemented_events)
self.__fail_on_table_metadata_unavailable = fail_on_table_metadata_unavailable
# We can't filter on packet level TABLE_MAP and rotate event because
# we need them for handling other operations
self.__allowed_events_in_packet = frozenset(
[TableMapEvent, RotateEvent]).union(self.__allowed_events)
self.__server_id = server_id
self.__use_checksum = False
# Store table meta information
self.table_map = {}
self.log_pos = log_pos
self.log_file = log_file
self.auto_position = auto_position
self.skip_to_timestamp = skip_to_timestamp
if report_slave:
self.report_slave = ReportSlave(report_slave)
self.slave_uuid = slave_uuid
self.slave_heartbeat = slave_heartbeat
if pymysql_wrapper:
self.pymysql_wrapper = pymysql_wrapper
else:
self.pymysql_wrapper = pymysql.connect
self.mysql_version = (0, 0, 0)
def close(self):
if self.__connected_stream:
self._stream_connection.close()
self.__connected_stream = False
if self.__connected_ctl:
# break reference cycle between stream reader and underlying
# mysql connection object
self._ctl_connection._get_table_information = None
self._ctl_connection.close()
self.__connected_ctl = False
def __connect_to_ctl(self):
if not self._ctl_connection_settings:
self._ctl_connection_settings = dict(self.__connection_settings)
self._ctl_connection_settings["db"] = "information_schema"
self._ctl_connection_settings["cursorclass"] = DictCursor
self._ctl_connection = self.pymysql_wrapper(**self._ctl_connection_settings)
self._ctl_connection._get_table_information = self.__get_table_information
self.__connected_ctl = True
def __checksum_enabled(self):
"""Return True if binlog-checksum = CRC32. Only for MySQL > 5.6"""
cur = self._stream_connection.cursor()
cur.execute("SHOW GLOBAL VARIABLES LIKE 'BINLOG_CHECKSUM'")
result = cur.fetchone()
cur.close()
if result is None:
return False
var, value = result[:2]
if value == 'NONE':
return False
return True
def _register_slave(self):
if not self.report_slave:
return
packet = self.report_slave.encoded(self.__server_id)
if pymysql.__version__ < "0.6":
self._stream_connection.wfile.write(packet)
self._stream_connection.wfile.flush()
self._stream_connection.read_packet()
else:
self._stream_connection._write_bytes(packet)
self._stream_connection._next_seq_id = 1
self._stream_connection._read_packet()
def __connect_to_stream(self):
# log_pos (4) -- position in the binlog-file to start the stream with
# flags (2) BINLOG_DUMP_NON_BLOCK (0 or 1)
# server_id (4) -- server id of this slave
# log_file (string.EOF) -- filename of the binlog on the master
self._stream_connection = self.pymysql_wrapper(**self.__connection_settings)
self.__use_checksum = self.__checksum_enabled()
# If checksum is enabled we need to inform the server that
# we support it
if self.__use_checksum:
cur = self._stream_connection.cursor()
cur.execute("set @master_binlog_checksum= @@global.binlog_checksum")
cur.close()
if self.slave_uuid:
cur = self._stream_connection.cursor()
cur.execute("set @slave_uuid= '%s'" % self.slave_uuid)
cur.close()
if self.slave_heartbeat:
# 4294967 is documented as the max value for heartbeats
net_timeout = float(self.__connection_settings.get('read_timeout',
4294967))
# If the heartbeat period is longer than the read timeout, the connection
# would be dropped before a heartbeat arrives; this is also the behavior in mysql
heartbeat = float(min(net_timeout/2., self.slave_heartbeat))
if heartbeat > 4294967:
heartbeat = 4294967
# master_heartbeat_period is nanoseconds
heartbeat = int(heartbeat * 1000000000)
cur = self._stream_connection.cursor()
cur.execute("set @master_heartbeat_period= %d" % heartbeat)
cur.close()
self._register_slave()
if not self.auto_position:
# the position info is only valid when both log_file and log_pos are
# provided; if not, get the current position from the master
if self.log_file is None or self.log_pos is None:
cur = self._stream_connection.cursor()
cur.execute("SHOW MASTER STATUS")
master_status = cur.fetchone()
if master_status is None:
raise BinLogNotEnabled()
self.log_file, self.log_pos = master_status[:2]
cur.close()
prelude = struct.pack('<i', len(self.log_file) + 11) \
+ int2byte(COM_BINLOG_DUMP)
if self.__resume_stream:
prelude += struct.pack('<I', self.log_pos)
else:
prelude += struct.pack('<I', 4)
if self.__blocking:
prelude += struct.pack('<h', 0)
else:
prelude += struct.pack('<h', 1)
prelude += struct.pack('<I', self.__server_id)
prelude += self.log_file.encode()
else:
# Format for mysql packet master_auto_position
#
# All fields are little endian
# All fields are unsigned
# Packet length uint 4bytes
# Packet type byte 1byte == 0x1e
# Binlog flags ushort 2bytes == 0 (for retrocompatibility)
# Server id uint 4bytes
# binlognamesize uint 4bytes
# binlogname str Nbytes N = binlognamesize
# Zeroified
# binlog position uint 4bytes == 4
# payload_size uint 4bytes
# What come next, is the payload, where the slave gtid_executed
# is sent to the master
# n_sid ulong 8bytes == which size is the gtid_set
# | sid uuid 16bytes UUID as a binary
# | n_intervals ulong 8bytes == how many intervals are sent for this gtid
# | | start ulong 8bytes Start position of this interval
# | | stop ulong 8bytes Stop position of this interval
# A gtid set looks like:
# 19d69c1e-ae97-4b8c-a1ef-9e12ba966457:1-3:8-10,
# 1c2aad49-ae92-409a-b4df-d05a03e4702e:42-47:80-100:130-140
#
# In this particular gtid set, 19d69c1e-ae97-4b8c-a1ef-9e12ba966457:1-3:8-10
# is the first member of the set, it is called a gtid.
# In this gtid, 19d69c1e-ae97-4b8c-a1ef-9e12ba966457 is the sid
# and has two intervals, 1-3 and 8-10; 1 is the start position of the first interval
# 3 is the stop position of the first interval.
gtid_set = GtidSet(self.auto_position)
encoded_data_size = gtid_set.encoded_length
header_size = (2 + # binlog_flags
4 + # server_id
4 + # binlog_name_info_size
4 + # empty binlog name
8 + # binlog_pos_info_size
4) # encoded_data_size
prelude = b'' + struct.pack('<i', header_size + encoded_data_size) \
+ int2byte(COM_BINLOG_DUMP_GTID)
# binlog_flags = 0 (2 bytes)
prelude += struct.pack('<H', 0)
# server_id (4 bytes)
prelude += struct.pack('<I', self.__server_id)
# binlog_name_info_size (4 bytes)
prelude += struct.pack('<I', 3)
# empty_binlog_name (4 bytes)
prelude += b'\0\0\0'
# binlog_pos_info (8 bytes)
prelude += struct.pack('<Q', 4)
# encoded_data_size (4 bytes)
prelude += struct.pack('<I', gtid_set.encoded_length)
# encoded_data
prelude += gtid_set.encoded()
if pymysql.__version__ < "0.6":
self._stream_connection.wfile.write(prelude)
self._stream_connection.wfile.flush()
else:
self._stream_connection._write_bytes(prelude)
self._stream_connection._next_seq_id = 1
self.__connected_stream = True
def fetchone(self):
while True:
if not self.__connected_stream:
self.__connect_to_stream()
if not self.__connected_ctl:
self.__connect_to_ctl()
try:
if pymysql.__version__ < "0.6":
pkt = self._stream_connection.read_packet()
else:
pkt = self._stream_connection._read_packet()
except pymysql.OperationalError as error:
code, message = error.args
if code in MYSQL_EXPECTED_ERROR_CODES:
self._stream_connection.close()
self.__connected_stream = False
continue
raise
if pkt.is_eof_packet():
self.close()
return None
if not pkt.is_ok_packet():
continue
binlog_event = BinLogPacketWrapper(pkt, self.table_map,
self._ctl_connection,
self.mysql_version,
self.__use_checksum,
self.__allowed_events_in_packet,
self.__only_tables,
self.__ignored_tables,
self.__only_schemas,
self.__ignored_schemas,
self.__freeze_schema,
self.__fail_on_table_metadata_unavailable)
if binlog_event.event_type == ROTATE_EVENT:
self.log_pos = binlog_event.event.position
self.log_file = binlog_event.event.next_binlog
# Table Id in binlog are NOT persistent in MySQL - they are in-memory identifiers
# that means that when MySQL master restarts, it will reuse same table id for different tables
# which will cause errors for us since our in-memory map will try to decode row data with
# wrong table schema.
# The fix is to rely on the fact that MySQL will also rotate to a new binlog file every time it
# restarts. That means every rotation we see *could* be a sign of restart and so potentially
# invalidates all our cached table id to schema mappings. This means we have to load them all
# again for each logfile which is potentially wasted effort but we can't really do much better
# without being broken in restart case
self.table_map = {}
elif binlog_event.log_pos:
self.log_pos = binlog_event.log_pos
# This check must not occur before clearing the ``table_map`` as a
# result of a RotateEvent.
#
# The first RotateEvent in a binlog file has a timestamp of
# zero. If the server has moved to a new log and not written a
# timestamped RotateEvent at the end of the previous log, the
# RotateEvent at the beginning of the new log will be ignored
# if the caller provided a positive ``skip_to_timestamp``
# value. This will result in the ``table_map`` becoming
# corrupt.
#
# https://dev.mysql.com/doc/internals/en/event-data-for-specific-event-types.html
# From the MySQL Internals Manual:
#
# ROTATE_EVENT is generated locally and written to the binary
# log on the master. It is written to the relay log on the
# slave when FLUSH LOGS occurs, and when receiving a
# ROTATE_EVENT from the master. In the latter case, there
# will be two rotate events in total originating on different
# servers.
#
# There are conditions under which the terminating
# log-rotation event does not occur. For example, the server
# might crash.
if self.skip_to_timestamp and binlog_event.timestamp < self.skip_to_timestamp:
continue
if binlog_event.event_type == TABLE_MAP_EVENT and \
binlog_event.event is not None:
self.table_map[binlog_event.event.table_id] = \
binlog_event.event.get_table()
# event is None if we filtered it at the packet level
# we also filter events that are not allowed
if binlog_event.event is None or (binlog_event.event.__class__ not in self.__allowed_events):
continue
if binlog_event.event_type == FORMAT_DESCRIPTION_EVENT:
self.mysql_version = binlog_event.event.mysql_version
return binlog_event.event
def _allowed_event_list(self, only_events, ignored_events,
filter_non_implemented_events):
if only_events is not None:
events = set(only_events)
else:
events = set((
QueryEvent,
RotateEvent,
StopEvent,
FormatDescriptionEvent,
XAPrepareEvent,
XidEvent,
GtidEvent,
BeginLoadQueryEvent,
ExecuteLoadQueryEvent,
UpdateRowsEvent,
WriteRowsEvent,
DeleteRowsEvent,
TableMapEvent,
HeartbeatLogEvent,
NotImplementedEvent,
))
if ignored_events is not None:
for e in ignored_events:
events.remove(e)
if filter_non_implemented_events:
try:
events.remove(NotImplementedEvent)
except KeyError:
pass
return frozenset(events)
def __get_table_information(self, schema, table):
for i in range(1, 3):
try:
if not self.__connected_ctl:
self.__connect_to_ctl()
cur = self._ctl_connection.cursor()
cur.execute("""
SELECT
COLUMN_NAME, COLLATION_NAME, CHARACTER_SET_NAME,
COLUMN_COMMENT, COLUMN_TYPE, COLUMN_KEY
FROM
information_schema.columns
WHERE
table_schema = %s AND table_name = %s
""", (schema, table))
return cur.fetchall()
except pymysql.OperationalError as error:
code, message = error.args
if code in MYSQL_EXPECTED_ERROR_CODES:
self.__connected_ctl = False
continue
else:
raise error
def __iter__(self):
return iter(self.fetchone, None)
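A minimal reading loop over BinLogStreamReader, sketched from the constructor and iterator above (connection values are placeholders; pass either log_file/log_pos or auto_position, as dumpbinlog.py does through its pos_config):

    stream = BinLogStreamReader(
        connection_settings={'host': '127.0.0.1', 'port': 3306,
                             'user': 'repl', 'passwd': 'secret'},  # placeholders
        server_id=100,
        blocking=True,
        resume_stream=True,
        auto_position='6a765e25-318d-11e8-a762-246e9616dff4:1-3945180')
    for binlog_event in stream:   # __iter__ wraps fetchone()
        binlog_event.dump()
    stream.close()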


@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
bitCountInByte = [
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8,
]
# Calculate the total bit count in a bitmap
def BitCount(bitmap):
n = 0
for i in range(0, len(bitmap)):
bit = bitmap[i]
if type(bit) is str:
bit = ord(bit)
n += bitCountInByte[bit]
return n
# Get the bit set at offset position in bitmap
def BitGet(bitmap, position):
bit = bitmap[int(position / 8)]
if type(bit) is str:
bit = ord(bit)
return bit & (1 << (position & 7))
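A quick illustration of the two helpers (the byte value is an arbitrary example):

    bitmap = b'\x0b'                 # 0b00001011: bits 0, 1 and 3 are set
    assert BitCount(bitmap) == 3
    assert BitGet(bitmap, 1) != 0    # non-zero result means the bit at that offset is set
    assert BitGet(bitmap, 2) == 0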


@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
import struct
from .constants import FIELD_TYPE
class Column(object):
"""Definition of a column
"""
def __init__(self, *args, **kwargs):
if len(args) == 3:
self.__parse_column_definition(*args)
else:
self.__dict__.update(kwargs)
def __parse_column_definition(self, column_type, column_schema, packet):
self.type = column_type
self.name = column_schema["COLUMN_NAME"]
self.collation_name = column_schema["COLLATION_NAME"]
self.character_set_name = column_schema["CHARACTER_SET_NAME"]
self.comment = column_schema["COLUMN_COMMENT"]
self.unsigned = column_schema["COLUMN_TYPE"].find("unsigned") != -1
self.type_is_bool = False
self.is_primary = column_schema["COLUMN_KEY"] == "PRI"
if self.type == FIELD_TYPE.VARCHAR:
self.max_length = struct.unpack('<H', packet.read(2))[0]
elif self.type == FIELD_TYPE.DOUBLE:
self.size = packet.read_uint8()
elif self.type == FIELD_TYPE.FLOAT:
self.size = packet.read_uint8()
elif self.type == FIELD_TYPE.TIMESTAMP2:
self.fsp = packet.read_uint8()
elif self.type == FIELD_TYPE.DATETIME2:
self.fsp = packet.read_uint8()
elif self.type == FIELD_TYPE.TIME2:
self.fsp = packet.read_uint8()
elif self.type == FIELD_TYPE.TINY and \
column_schema["COLUMN_TYPE"] == "tinyint(1)":
self.type_is_bool = True
elif self.type == FIELD_TYPE.VAR_STRING or \
self.type == FIELD_TYPE.STRING:
self.__read_string_metadata(packet, column_schema)
elif self.type == FIELD_TYPE.BLOB:
self.length_size = packet.read_uint8()
elif self.type == FIELD_TYPE.GEOMETRY:
self.length_size = packet.read_uint8()
elif self.type == FIELD_TYPE.JSON:
self.length_size = packet.read_uint8()
elif self.type == FIELD_TYPE.NEWDECIMAL:
self.precision = packet.read_uint8()
self.decimals = packet.read_uint8()
elif self.type == FIELD_TYPE.BIT:
bits = packet.read_uint8()
bytes = packet.read_uint8()
self.bits = (bytes * 8) + bits
self.bytes = int((self.bits + 7) / 8)
def __read_string_metadata(self, packet, column_schema):
metadata = (packet.read_uint8() << 8) + packet.read_uint8()
real_type = metadata >> 8
if real_type == FIELD_TYPE.SET or real_type == FIELD_TYPE.ENUM:
self.type = real_type
self.size = metadata & 0x00ff
self.__read_enum_metadata(column_schema)
else:
self.max_length = (((metadata >> 4) & 0x300) ^ 0x300) \
+ (metadata & 0x00ff)
def __read_enum_metadata(self, column_schema):
enums = column_schema["COLUMN_TYPE"]
if self.type == FIELD_TYPE.ENUM:
self.enum_values = enums.replace('enum(', '')\
.replace(')', '').replace('\'', '').split(',')
else:
self.set_values = enums.replace('set(', '')\
.replace(')', '').replace('\'', '').split(',')
def __eq__(self, other):
return self.data == other.data
def __ne__(self, other):
return not self.__eq__(other)
def serializable_data(self):
return self.data
@property
def data(self):
return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith('_'))


@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
UNKNOWN_EVENT = 0x00
START_EVENT_V3 = 0x01
QUERY_EVENT = 0x02
STOP_EVENT = 0x03
ROTATE_EVENT = 0x04
INTVAR_EVENT = 0x05
LOAD_EVENT = 0x06
SLAVE_EVENT = 0x07
CREATE_FILE_EVENT = 0x08
APPEND_BLOCK_EVENT = 0x09
EXEC_LOAD_EVENT = 0x0a
DELETE_FILE_EVENT = 0x0b
NEW_LOAD_EVENT = 0x0c
RAND_EVENT = 0x0d
USER_VAR_EVENT = 0x0e
FORMAT_DESCRIPTION_EVENT = 0x0f
XID_EVENT = 0x10
BEGIN_LOAD_QUERY_EVENT = 0x11
EXECUTE_LOAD_QUERY_EVENT = 0x12
TABLE_MAP_EVENT = 0x13
PRE_GA_WRITE_ROWS_EVENT = 0x14
PRE_GA_UPDATE_ROWS_EVENT = 0x15
PRE_GA_DELETE_ROWS_EVENT = 0x16
WRITE_ROWS_EVENT_V1 = 0x17
UPDATE_ROWS_EVENT_V1 = 0x18
DELETE_ROWS_EVENT_V1 = 0x19
INCIDENT_EVENT = 0x1a
HEARTBEAT_LOG_EVENT = 0x1b
IGNORABLE_LOG_EVENT = 0x1c
ROWS_QUERY_LOG_EVENT = 0x1d
WRITE_ROWS_EVENT_V2 = 0x1e
UPDATE_ROWS_EVENT_V2 = 0x1f
DELETE_ROWS_EVENT_V2 = 0x20
GTID_LOG_EVENT = 0x21
ANONYMOUS_GTID_LOG_EVENT = 0x22
PREVIOUS_GTIDS_LOG_EVENT = 0x23
XA_PREPARE_EVENT = 0x26
# INTVAR types
INTVAR_INVALID_INT_EVENT = 0x00
INTVAR_LAST_INSERT_ID_EVENT = 0x01
INTVAR_INSERT_ID_EVENT = 0x02


@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
# Original code from PyMySQL
# Copyright (c) 2010 PyMySQL contributors
#
#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is
#furnished to do so, subject to the following conditions:
#
#The above copyright notice and this permission notice shall be included in
#all copies or substantial portions of the Software.
#
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
#THE SOFTWARE.
DECIMAL = 0
TINY = 1
SHORT = 2
LONG = 3
FLOAT = 4
DOUBLE = 5
NULL = 6
TIMESTAMP = 7
LONGLONG = 8
INT24 = 9
DATE = 10
TIME = 11
DATETIME = 12
YEAR = 13
NEWDATE = 14
VARCHAR = 15
BIT = 16
TIMESTAMP2 = 17
DATETIME2 = 18
TIME2 = 19
JSON = 245 # Introduced in 5.7.8
NEWDECIMAL = 246
ENUM = 247
SET = 248
TINY_BLOB = 249
MEDIUM_BLOB = 250
LONG_BLOB = 251
BLOB = 252
VAR_STRING = 253
STRING = 254
GEOMETRY = 255
CHAR = TINY
INTERVAL = ENUM


@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
from .BINLOG import *
from .FIELD_TYPE import *


@@ -0,0 +1,329 @@
# -*- coding: utf-8 -*-
import binascii
import struct
import datetime
from pymysql.util import byte2int, int2byte
class BinLogEvent(object):
def __init__(self, from_packet, event_size, table_map, ctl_connection,
mysql_version=(0,0,0),
only_tables=None,
ignored_tables=None,
only_schemas=None,
ignored_schemas=None,
freeze_schema=False,
fail_on_table_metadata_unavailable=False):
self.packet = from_packet
self.table_map = table_map
self.event_type = self.packet.event_type
self.timestamp = self.packet.timestamp
self.event_size = event_size
self._ctl_connection = ctl_connection
self.mysql_version = mysql_version
self._fail_on_table_metadata_unavailable = fail_on_table_metadata_unavailable
# The event has been fully processed; if _processed is False
# the event will be skipped
self._processed = True
self.complete = True
def _read_table_id(self):
# Table ID is a 6-byte little-endian number;
# pad it with two zero bytes before unpacking as uint64
table_id = self.packet.read(6) + int2byte(0) + int2byte(0)
return struct.unpack('<Q', table_id)[0]
def dump(self):
print("=== %s ===" % (self.__class__.__name__))
print("Date: %s" % (datetime.datetime.fromtimestamp(self.timestamp)
.isoformat()))
print("Log position: %d" % self.packet.log_pos)
print("Event size: %d" % (self.event_size))
print("Read bytes: %d" % (self.packet.read_bytes))
self._dump()
print()
def _dump(self):
"""Core data dumped for the event"""
pass
class GtidEvent(BinLogEvent):
"""GTID change in binlog event
"""
def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(GtidEvent, self).__init__(from_packet, event_size, table_map,
ctl_connection, **kwargs)
self.commit_flag = byte2int(self.packet.read(1)) == 1
self.sid = self.packet.read(16)
self.gno = struct.unpack('<Q', self.packet.read(8))[0]
self.lt_type = byte2int(self.packet.read(1))
if self.mysql_version >= (5, 7):
self.last_committed = struct.unpack('<Q', self.packet.read(8))[0]
self.sequence_number = struct.unpack('<Q', self.packet.read(8))[0]
@property
def gtid(self):
"""GTID = source_id:transaction_id
Eg: 3E11FA47-71CA-11E1-9E33-C80AA9429562:23
See: http://dev.mysql.com/doc/refman/5.6/en/replication-gtids-concepts.html"""
nibbles = binascii.hexlify(self.sid).decode('ascii')
gtid = '%s-%s-%s-%s-%s:%d' % (
nibbles[:8], nibbles[8:12], nibbles[12:16], nibbles[16:20], nibbles[20:], self.gno
)
return gtid
def _dump(self):
print("Commit: %s" % self.commit_flag)
print("GTID_NEXT: %s" % self.gtid)
if hasattr(self, "last_committed"):
print("last_committed: %d" % self.last_committed)
print("sequence_number: %d" % self.sequence_number)
def __repr__(self):
return '<GtidEvent "%s">' % self.gtid
class RotateEvent(BinLogEvent):
"""Change MySQL bin log file
Attributes:
position: Position inside next binlog
next_binlog: Name of next binlog file
"""
def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(RotateEvent, self).__init__(from_packet, event_size, table_map,
ctl_connection, **kwargs)
self.position = struct.unpack('<Q', self.packet.read(8))[0]
self.next_binlog = self.packet.read(event_size - 8).decode()
def dump(self):
print("=== %s ===" % (self.__class__.__name__))
print("Position: %d" % self.position)
print("Next binlog file: %s" % self.next_binlog)
print()
class XAPrepareEvent(BinLogEvent):
"""An XA prepare event is generated for a XA prepared transaction.
Like Xid_event it contans XID of the *prepared* transaction
Attributes:
one_phase: current XA transaction commit method
xid: serialized XID representation of XA transaction
"""
def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(XAPrepareEvent, self).__init__(from_packet, event_size, table_map,
ctl_connection, **kwargs)
# one_phase is True: XA COMMIT ... ONE PHASE
# one_phase is False: XA PREPARE
self.one_phase = (self.packet.read(1) != b'\x00')
self.xid_format_id = struct.unpack('<I', self.packet.read(4))[0]
gtrid_length = struct.unpack('<I', self.packet.read(4))[0]
bqual_length = struct.unpack('<I', self.packet.read(4))[0]
self.xid_gtrid = self.packet.read(gtrid_length)
self.xid_bqual = self.packet.read(bqual_length)
@property
def xid(self):
return self.xid_gtrid.decode() + self.xid_bqual.decode()
def _dump(self):
print("One phase: %s" % self.one_phase)
print("XID formatID: %d" % self.xid_format_id)
print("XID: %s" % self.xid)
class FormatDescriptionEvent(BinLogEvent):
def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(FormatDescriptionEvent, self).__init__(from_packet, event_size, table_map,
ctl_connection, **kwargs)
self.binlog_version = struct.unpack('<H', self.packet.read(2))
self.mysql_version_str = self.packet.read(50).rstrip(b'\0').decode()
numbers = self.mysql_version_str.split('-')[0]
self.mysql_version = tuple(map(int, numbers.split('.')))
def _dump(self):
print("Binlog version: %s" % self.binlog_version)
print("MySQL version: %s" % self.mysql_version_str)
class StopEvent(BinLogEvent):
pass
class XidEvent(BinLogEvent):
"""A COMMIT event
Attributes:
xid: Transaction ID for 2PC
"""
def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(XidEvent, self).__init__(from_packet, event_size, table_map,
ctl_connection, **kwargs)
self.xid = struct.unpack('<Q', self.packet.read(8))[0]
def _dump(self):
super(XidEvent, self)._dump()
print("Transaction ID: %d" % (self.xid))
class HeartbeatLogEvent(BinLogEvent):
"""A Heartbeat event
Heartbeats are sent by the master only if there are no unsent events in the
binary log file for a period longer than the interval defined by
MASTER_HEARTBEAT_PERIOD connection setting.
A mysql server will also send one to the slave for each skipped
event in the log. I (baloo) believe the intention is to make the slave
bump its position so that if a disconnection occurs, the slave only
reconnects from the last skipped position (see Binlog_sender::send_events
in sql/rpl_binlog_sender.cc). That makes 106 bytes of data for skipped
event in the binlog. *this is also the case with GTID replication*. To
mitigate such behavior, you are expected to keep the binlog small (see
max_binlog_size, defaults to 1G).
In any case, the timestamp is 0 (as in 1970-01-01T00:00:00).
Attributes:
ident: Name of the current binlog
"""
def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(HeartbeatLogEvent, self).__init__(from_packet, event_size,
table_map, ctl_connection,
**kwargs)
self.ident = self.packet.read(event_size).decode()
def _dump(self):
super(HeartbeatLogEvent, self)._dump()
print("Current binlog: %s" % (self.ident))
class QueryEvent(BinLogEvent):
'''This event is triggered when a query is run on the database.
Only replicated queries are logged.'''
def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(QueryEvent, self).__init__(from_packet, event_size, table_map,
ctl_connection, **kwargs)
# Post-header
self.slave_proxy_id = self.packet.read_uint32()
self.execution_time = self.packet.read_uint32()
self.schema_length = byte2int(self.packet.read(1))
self.error_code = self.packet.read_uint16()
self.status_vars_length = self.packet.read_uint16()
# Payload
self.status_vars = self.packet.read(self.status_vars_length)
self.schema = self.packet.read(self.schema_length)
self.packet.advance(1)
self.query = self.packet.read(event_size - 13 - self.status_vars_length
- self.schema_length - 1).decode("utf-8")
#string[EOF] query
def _dump(self):
super(QueryEvent, self)._dump()
print("Schema: %s" % (self.schema))
print("Execution time: %d" % (self.execution_time))
print("Query: %s" % (self.query))
class BeginLoadQueryEvent(BinLogEvent):
"""
Attributes:
file_id
block-data
"""
def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(BeginLoadQueryEvent, self).__init__(from_packet, event_size, table_map,
ctl_connection, **kwargs)
# Payload
self.file_id = self.packet.read_uint32()
self.block_data = self.packet.read(event_size - 4)
def _dump(self):
super(BeginLoadQueryEvent, self)._dump()
print("File id: %d" % (self.file_id))
print("Block data: %s" % (self.block_data))
class ExecuteLoadQueryEvent(BinLogEvent):
"""
Attributes:
slave_proxy_id
execution_time
schema_length
error_code
status_vars_length
file_id
start_pos
end_pos
dup_handling_flags
"""
def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(ExecuteLoadQueryEvent, self).__init__(from_packet, event_size, table_map,
ctl_connection, **kwargs)
# Post-header
self.slave_proxy_id = self.packet.read_uint32()
self.execution_time = self.packet.read_uint32()
self.schema_length = self.packet.read_uint8()
self.error_code = self.packet.read_uint16()
self.status_vars_length = self.packet.read_uint16()
# Payload
self.file_id = self.packet.read_uint32()
self.start_pos = self.packet.read_uint32()
self.end_pos = self.packet.read_uint32()
self.dup_handling_flags = self.packet.read_uint8()
def _dump(self):
super(ExecuteLoadQueryEvent, self)._dump()
print("Slave proxy id: %d" % (self.slave_proxy_id))
print("Execution time: %d" % (self.execution_time))
print("Schema length: %d" % (self.schema_length))
print("Error code: %d" % (self.error_code))
print("Status vars length: %d" % (self.status_vars_length))
print("File id: %d" % (self.file_id))
print("Start pos: %d" % (self.start_pos))
print("End pos: %d" % (self.end_pos))
print("Dup handling flags: %d" % (self.dup_handling_flags))
class IntvarEvent(BinLogEvent):
"""
Attributes:
type
value
"""
def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(IntvarEvent, self).__init__(from_packet, event_size, table_map,
ctl_connection, **kwargs)
# Payload
self.type = self.packet.read_uint8()
self.value = self.packet.read_uint32()
def _dump(self):
super(IntvarEvent, self)._dump()
print("type: %d" % (self.type))
print("Value: %d" % (self.value))
class NotImplementedEvent(BinLogEvent):
def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(NotImplementedEvent, self).__init__(
from_packet, event_size, table_map, ctl_connection, **kwargs)
self.packet.advance(event_size)
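The event classes above are what callers pass to BinLogStreamReader's only_events/ignored_events filters; for instance, a GTID- and query-oriented reader could be restricted like this (a sketch, not code from this repository):

    stream = BinLogStreamReader(
        connection_settings={'host': '127.0.0.1', 'port': 3306,
                             'user': 'repl', 'passwd': 'secret'},  # placeholders
        server_id=101,
        only_events=[GtidEvent, QueryEvent, XidEvent, XAPrepareEvent, RotateEvent])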


@@ -0,0 +1,8 @@
class TableMetadataUnavailableError(Exception):
def __init__(self, table):
Exception.__init__(self,"Unable to find metadata for table {0}".format(table))
class BinLogNotEnabled(Exception):
def __init__(self):
Exception.__init__(self, "MySQL binary logging is not enabled.")


@@ -0,0 +1,299 @@
# -*- coding: utf-8 -*-
import re
import struct
import binascii
from io import BytesIO
def overlap(i1, i2):
return i1[0] < i2[1] and i1[1] > i2[0]
def contains(i1, i2):
return i2[0] >= i1[0] and i2[1] <= i1[1]
class Gtid(object):
"""A mysql GTID is composed of a server-id and a set of right-open
intervals [a,b), and represent all transactions x that happened on
server SID such as
x <= a < b
The human representation of it, though, is either represented by a
single transaction number A=a (when only one transaction is covered,
ie b = a+1)
SID:A
Or a closed interval [A,B] for at least two transactions (note, in that
case, that b=B+1)
SID:A-B
We can also have a mix of ranges for a given SID:
SID:1-2:4:6-74
For convenience, a Gtid accepts adding Gtid's to it and will merge
the existing interval representation. Adding TXN 3 to the human
representation above would produce:
SID:1-4:6-74
and adding 5 to this new result:
SID:1-74
Adding an already present transaction number (one that overlaps) will
raise an exception.
Adding a Gtid with a different SID will raise an exception.
"""
@staticmethod
def parse_interval(interval):
"""
We parse a human-generated string here. So our end value b
is incremented to conform to the internal representation format.
"""
m = re.search('^([0-9]+)(?:-([0-9]+))?$', interval)
if not m:
raise ValueError('GTID format is incorrect: %r' % (interval, ))
a = int(m.group(1))
b = int(m.group(2) or a)
return (a, b+1)
@staticmethod
def parse(gtid):
m = re.search('^([0-9a-fA-F]{8}(?:-[0-9a-fA-F]{4}){3}-[0-9a-fA-F]{12})'
'((?::[0-9-]+)+)$', gtid)
if not m:
raise ValueError('GTID format is incorrect: %r' % (gtid, ))
sid = m.group(1)
intervals = m.group(2)
intervals_parsed = [Gtid.parse_interval(x)
for x in intervals.split(':')[1:]]
return (sid, intervals_parsed)
def __add_interval(self, itvl):
"""
Use the internal representation format and add it
to our intervals, merging if required.
"""
new = []
if itvl[0] > itvl[1]:
raise Exception('Malformed interval %s' % (itvl,))
if any(overlap(x, itvl) for x in self.intervals):
raise Exception('Overlapping interval %s' % (itvl,))
## Merge: arrange interval to fit existing set
for existing in sorted(self.intervals):
if itvl[0] == existing[1]:
itvl = (existing[0], itvl[1])
continue
if itvl[1] == existing[0]:
itvl = (itvl[0], existing[1])
continue
new.append(existing)
self.intervals = sorted(new + [itvl])
def __sub_interval(self, itvl):
"""Using the internal representation, remove an interval"""
new = []
if itvl[0] > itvl[1]:
raise Exception('Malformed interval %s' % (itvl,))
if not any(overlap(x, itvl) for x in self.intervals):
# No raise
return
## Merge: arrange existing set around interval
for existing in sorted(self.intervals):
if overlap(existing, itvl):
if existing[0] < itvl[0]:
new.append((existing[0], itvl[0]))
if existing[1] > itvl[1]:
new.append((itvl[1], existing[1]))
else:
new.append(existing)
self.intervals = new
def __contains__(self, other):
if other.sid != self.sid:
return False
return all(any(contains(me, them) for me in self.intervals)
for them in other.intervals)
def __init__(self, gtid, sid=None, intervals=[]):
if sid:
intervals = intervals
else:
sid, intervals = Gtid.parse(gtid)
self.sid = sid
self.intervals = []
for itvl in intervals:
self.__add_interval(itvl)
def __add__(self, other):
"""Include the transactions of this gtid. Raise if the
attempted merge has different SID"""
if self.sid != other.sid:
raise Exception('Attempt to merge different SID'
'%s != %s' % (self.sid, other.sid))
result = Gtid(str(self))
for itvl in other.intervals:
result.__add_interval(itvl)
return result
def __sub__(self, other):
"""Remove intervals. Do not raise, if different SID simply
ignore"""
result = Gtid(str(self))
if self.sid != other.sid:
return result
for itvl in other.intervals:
result.__sub_interval(itvl)
return result
def __cmp__(self, other):
if other.sid != self.sid:
return cmp(self.sid, other.sid)
return cmp(self.intervals, other.intervals)
def __str__(self):
"""We represent the human value here - a single number
for one transaction, or a closed interval (decrementing b)"""
return '%s:%s' % (self.sid,
':'.join(('%d-%d' % (x[0], x[1]-1)) if x[0] +1 != x[1]
else str(x[0])
for x in self.intervals))
def __repr__(self):
return '<Gtid "%s">' % self
@property
def encoded_length(self):
return (16 + # sid
8 + # n_intervals
2 * # stop/start
8 * # stop/start mark encoded as int64
len(self.intervals))
def encode(self):
buffer = b''
# sid
buffer += binascii.unhexlify(self.sid.replace('-', ''))
# n_intervals
buffer += struct.pack('<Q', len(self.intervals))
for interval in self.intervals:
# Start position
buffer += struct.pack('<Q', interval[0])
# Stop position
buffer += struct.pack('<Q', interval[1])
return buffer
@classmethod
def decode(cls, payload):
assert isinstance(payload, BytesIO), \
'payload is expected to be a BytesIO'
sid = b''
sid = sid + binascii.hexlify(payload.read(4))
sid = sid + b'-'
sid = sid + binascii.hexlify(payload.read(2))
sid = sid + b'-'
sid = sid + binascii.hexlify(payload.read(2))
sid = sid + b'-'
sid = sid + binascii.hexlify(payload.read(2))
sid = sid + b'-'
sid = sid + binascii.hexlify(payload.read(6))
(n_intervals,) = struct.unpack('<Q', payload.read(8))
intervals = []
for i in range(0, n_intervals):
start, end = struct.unpack('<QQ', payload.read(16))
intervals.append((start, end-1))
return cls('%s:%s' % (sid.decode('ascii'), ':'.join([
'%d-%d' % x
if isinstance(x, tuple)
else '%d' % x
for x in intervals])))
class GtidSet(object):
def __init__(self, gtid_set):
def _to_gtid(element):
if isinstance(element, Gtid):
return element
return Gtid(element.strip(' \n'))
if not gtid_set:
self.gtids = []
elif isinstance(gtid_set, (list, set)):
self.gtids = [_to_gtid(x) for x in gtid_set]
else:
self.gtids = [Gtid(x.strip(' \n')) for x in gtid_set.split(',')]
def merge_gtid(self, gtid):
new_gtids = []
for existing in self.gtids:
if existing.sid == gtid.sid:
new_gtids.append(existing + gtid)
else:
new_gtids.append(existing)
if gtid.sid not in (x.sid for x in new_gtids):
new_gtids.append(gtid)
self.gtids = new_gtids
def __contains__(self, other):
if isinstance(other, Gtid):
return any(other in x for x in self.gtids)
raise NotImplementedError
def __add__(self, other):
if isinstance(other, Gtid):
new = GtidSet(self.gtids)
new.merge_gtid(other)
return new
raise NotImplementedError
def __str__(self):
return ','.join(str(x) for x in self.gtids)
def __repr__(self):
return '<GtidSet %r>' % self.gtids
@property
def encoded_length(self):
return (8 + # n_sids
sum(x.encoded_length for x in self.gtids))
def encoded(self):
return b'' + (struct.pack('<Q', len(self.gtids)) +
b''.join(x.encode() for x in self.gtids))
encode = encoded
@classmethod
def decode(cls, payload):
assert isinstance(payload, BytesIO), \
'payload is expected to be a BytesIO'
(n_sid,) = struct.unpack('<Q', payload.read(8))
return cls([Gtid.decode(payload) for _ in range(0, n_sid)])
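A small sketch of the merge behaviour described in the Gtid docstring (the UUID is the example one used in event.py):

    gtid = Gtid('3E11FA47-71CA-11E1-9E33-C80AA9429562:1-2:4:6-74')
    gtid += Gtid('3E11FA47-71CA-11E1-9E33-C80AA9429562:3')
    print(gtid)        # 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-4:6-74

    gtid_set = GtidSet('3E11FA47-71CA-11E1-9E33-C80AA9429562:1-4:6-74')
    print(Gtid('3E11FA47-71CA-11E1-9E33-C80AA9429562:10-20') in gtid_set)  # True
    print(gtid_set.encoded_length)   # 8 + (16 + 8 + 2 * 2 * 8) = 64 bytes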


@@ -0,0 +1,470 @@
# -*- coding: utf-8 -*-
import struct
from pymysql.util import byte2int
from pymysqlreplication import constants, event, row_event
# Constants from PyMYSQL source code
NULL_COLUMN = 251
UNSIGNED_CHAR_COLUMN = 251
UNSIGNED_SHORT_COLUMN = 252
UNSIGNED_INT24_COLUMN = 253
UNSIGNED_INT64_COLUMN = 254
UNSIGNED_CHAR_LENGTH = 1
UNSIGNED_SHORT_LENGTH = 2
UNSIGNED_INT24_LENGTH = 3
UNSIGNED_INT64_LENGTH = 8
JSONB_TYPE_SMALL_OBJECT = 0x0
JSONB_TYPE_LARGE_OBJECT = 0x1
JSONB_TYPE_SMALL_ARRAY = 0x2
JSONB_TYPE_LARGE_ARRAY = 0x3
JSONB_TYPE_LITERAL = 0x4
JSONB_TYPE_INT16 = 0x5
JSONB_TYPE_UINT16 = 0x6
JSONB_TYPE_INT32 = 0x7
JSONB_TYPE_UINT32 = 0x8
JSONB_TYPE_INT64 = 0x9
JSONB_TYPE_UINT64 = 0xA
JSONB_TYPE_DOUBLE = 0xB
JSONB_TYPE_STRING = 0xC
JSONB_TYPE_OPAQUE = 0xF
JSONB_LITERAL_NULL = 0x0
JSONB_LITERAL_TRUE = 0x1
JSONB_LITERAL_FALSE = 0x2
def read_offset_or_inline(packet, large):
t = packet.read_uint8()
if t in (JSONB_TYPE_LITERAL,
JSONB_TYPE_INT16, JSONB_TYPE_UINT16):
return (t, None, packet.read_binary_json_type_inlined(t))
if large and t in (JSONB_TYPE_INT32, JSONB_TYPE_UINT32):
return (t, None, packet.read_binary_json_type_inlined(t))
if large:
return (t, packet.read_uint32(), None)
return (t, packet.read_uint16(), None)
class BinLogPacketWrapper(object):
"""
Bin Log Packet Wrapper. It uses an existing packet object, and wraps
around it, exposing useful variables while still providing access
to the original packet object's variables and methods.
"""
__event_map = {
# event
constants.QUERY_EVENT: event.QueryEvent,
constants.ROTATE_EVENT: event.RotateEvent,
constants.FORMAT_DESCRIPTION_EVENT: event.FormatDescriptionEvent,
constants.XID_EVENT: event.XidEvent,
constants.INTVAR_EVENT: event.IntvarEvent,
constants.GTID_LOG_EVENT: event.GtidEvent,
constants.STOP_EVENT: event.StopEvent,
constants.BEGIN_LOAD_QUERY_EVENT: event.BeginLoadQueryEvent,
constants.EXECUTE_LOAD_QUERY_EVENT: event.ExecuteLoadQueryEvent,
constants.HEARTBEAT_LOG_EVENT: event.HeartbeatLogEvent,
constants.XA_PREPARE_EVENT: event.XAPrepareEvent,
# row_event
constants.UPDATE_ROWS_EVENT_V1: row_event.UpdateRowsEvent,
constants.WRITE_ROWS_EVENT_V1: row_event.WriteRowsEvent,
constants.DELETE_ROWS_EVENT_V1: row_event.DeleteRowsEvent,
constants.UPDATE_ROWS_EVENT_V2: row_event.UpdateRowsEvent,
constants.WRITE_ROWS_EVENT_V2: row_event.WriteRowsEvent,
constants.DELETE_ROWS_EVENT_V2: row_event.DeleteRowsEvent,
constants.TABLE_MAP_EVENT: row_event.TableMapEvent,
#5.6 GTID enabled replication events
constants.ANONYMOUS_GTID_LOG_EVENT: event.NotImplementedEvent,
constants.PREVIOUS_GTIDS_LOG_EVENT: event.NotImplementedEvent
}
def __init__(self, from_packet, table_map,
ctl_connection,
mysql_version,
use_checksum,
allowed_events,
only_tables,
ignored_tables,
only_schemas,
ignored_schemas,
freeze_schema,
fail_on_table_metadata_unavailable):
# -1 because we ignore the ok byte
self.read_bytes = 0
# Used when we want to override a value in the data buffer
self.__data_buffer = b''
self.packet = from_packet
self.charset = ctl_connection.charset
# OK value
# timestamp
# event_type
# server_id
# log_pos
# flags
unpack = struct.unpack('<cIcIIIH', self.packet.read(20))
# Header
self.timestamp = unpack[1]
self.event_type = byte2int(unpack[2])
self.server_id = unpack[3]
self.event_size = unpack[4]
# position of the next event
self.log_pos = unpack[5]
self.flags = unpack[6]
# MySQL 5.6 and more if binlog-checksum = CRC32
if use_checksum:
event_size_without_header = self.event_size - 23
else:
event_size_without_header = self.event_size - 19
self.event = None
event_class = self.__event_map.get(self.event_type, event.NotImplementedEvent)
if event_class not in allowed_events:
return
self.event = event_class(self, event_size_without_header, table_map,
ctl_connection,
mysql_version=mysql_version,
only_tables=only_tables,
ignored_tables=ignored_tables,
only_schemas=only_schemas,
ignored_schemas=ignored_schemas,
freeze_schema=freeze_schema,
fail_on_table_metadata_unavailable=fail_on_table_metadata_unavailable)
if self.event._processed == False:
self.event = None
def read(self, size):
size = int(size)
self.read_bytes += size
if len(self.__data_buffer) > 0:
data = self.__data_buffer[:size]
self.__data_buffer = self.__data_buffer[size:]
if len(data) == size:
return data
else:
return data + self.packet.read(size - len(data))
return self.packet.read(size)
def unread(self, data):
        '''Push data back into the data buffer. It is used when you want
        to extract a bit from a value and let the rest of the code
        read the data normally.'''
self.read_bytes -= len(data)
self.__data_buffer += data
def advance(self, size):
size = int(size)
self.read_bytes += size
buffer_len = len(self.__data_buffer)
if buffer_len > 0:
self.__data_buffer = self.__data_buffer[size:]
if size > buffer_len:
self.packet.advance(size - buffer_len)
else:
self.packet.advance(size)
def read_length_coded_binary(self):
"""Read a 'Length Coded Binary' number from the data buffer.
Length coded numbers can be anywhere from 1 to 9 bytes depending
on the value of the first byte.
From PyMYSQL source code
"""
c = byte2int(self.read(1))
if c == NULL_COLUMN:
return None
if c < UNSIGNED_CHAR_COLUMN:
return c
elif c == UNSIGNED_SHORT_COLUMN:
return self.unpack_uint16(self.read(UNSIGNED_SHORT_LENGTH))
elif c == UNSIGNED_INT24_COLUMN:
return self.unpack_int24(self.read(UNSIGNED_INT24_LENGTH))
elif c == UNSIGNED_INT64_COLUMN:
return self.unpack_int64(self.read(UNSIGNED_INT64_LENGTH))
def read_length_coded_string(self):
"""Read a 'Length Coded String' from the data buffer.
A 'Length Coded String' consists first of a length coded
(unsigned, positive) integer represented in 1-9 bytes followed by
that many bytes of binary data. (For example "cat" would be "3cat".)
From PyMYSQL source code
"""
length = self.read_length_coded_binary()
if length is None:
return None
return self.read(length).decode()
def __getattr__(self, key):
if hasattr(self.packet, key):
return getattr(self.packet, key)
raise AttributeError("%s instance has no attribute '%s'" %
(self.__class__, key))
def read_int_be_by_size(self, size):
        '''Read a big endian integer value based on the byte count'''
if size == 1:
return struct.unpack('>b', self.read(size))[0]
elif size == 2:
return struct.unpack('>h', self.read(size))[0]
elif size == 3:
return self.read_int24_be()
elif size == 4:
return struct.unpack('>i', self.read(size))[0]
elif size == 5:
return self.read_int40_be()
elif size == 8:
            return struct.unpack('>q', self.read(size))[0]
def read_uint_by_size(self, size):
        '''Read a little endian integer value based on the byte count'''
if size == 1:
return self.read_uint8()
elif size == 2:
return self.read_uint16()
elif size == 3:
return self.read_uint24()
elif size == 4:
return self.read_uint32()
elif size == 5:
return self.read_uint40()
elif size == 6:
return self.read_uint48()
elif size == 7:
return self.read_uint56()
elif size == 8:
return self.read_uint64()
def read_length_coded_pascal_string(self, size):
"""Read a string with length coded using pascal style.
The string start by the size of the string
"""
length = self.read_uint_by_size(size)
return self.read(length)
def read_variable_length_string(self):
"""Read a variable length string where the first 1-5 bytes stores the
length of the string.
For each byte, the first bit being high indicates another byte must be
read.
"""
byte = 0x80
length = 0
bits_read = 0
while byte & 0x80 != 0:
byte = byte2int(self.read(1))
length = length | ((byte & 0x7f) << bits_read)
bits_read = bits_read + 7
return self.read(length)
def read_int24(self):
a, b, c = struct.unpack("BBB", self.read(3))
res = a | (b << 8) | (c << 16)
if res >= 0x800000:
res -= 0x1000000
return res
def read_int24_be(self):
a, b, c = struct.unpack('BBB', self.read(3))
res = (a << 16) | (b << 8) | c
if res >= 0x800000:
res -= 0x1000000
return res
def read_uint8(self):
return struct.unpack('<B', self.read(1))[0]
def read_int16(self):
return struct.unpack('<h', self.read(2))[0]
def read_uint16(self):
return struct.unpack('<H', self.read(2))[0]
def read_uint24(self):
a, b, c = struct.unpack("<BBB", self.read(3))
return a + (b << 8) + (c << 16)
def read_uint32(self):
return struct.unpack('<I', self.read(4))[0]
def read_int32(self):
return struct.unpack('<i', self.read(4))[0]
def read_uint40(self):
a, b = struct.unpack("<BI", self.read(5))
return a + (b << 8)
def read_int40_be(self):
a, b = struct.unpack(">IB", self.read(5))
return b + (a << 8)
def read_uint48(self):
a, b, c = struct.unpack("<HHH", self.read(6))
return a + (b << 16) + (c << 32)
def read_uint56(self):
a, b, c = struct.unpack("<BHI", self.read(7))
return a + (b << 8) + (c << 24)
def read_uint64(self):
return struct.unpack('<Q', self.read(8))[0]
def read_int64(self):
return struct.unpack('<q', self.read(8))[0]
def unpack_uint16(self, n):
return struct.unpack('<H', n[0:2])[0]
def unpack_int24(self, n):
try:
return struct.unpack('B', n[0])[0] \
+ (struct.unpack('B', n[1])[0] << 8) \
+ (struct.unpack('B', n[2])[0] << 16)
except TypeError:
return n[0] + (n[1] << 8) + (n[2] << 16)
def unpack_int32(self, n):
try:
return struct.unpack('B', n[0])[0] \
+ (struct.unpack('B', n[1])[0] << 8) \
+ (struct.unpack('B', n[2])[0] << 16) \
+ (struct.unpack('B', n[3])[0] << 24)
except TypeError:
return n[0] + (n[1] << 8) + (n[2] << 16) + (n[3] << 24)
def read_binary_json(self, size):
length = self.read_uint_by_size(size)
payload = self.read(length)
self.unread(payload)
t = self.read_uint8()
return self.read_binary_json_type(t, length)
def read_binary_json_type(self, t, length):
large = (t in (JSONB_TYPE_LARGE_OBJECT, JSONB_TYPE_LARGE_ARRAY))
if t in (JSONB_TYPE_SMALL_OBJECT, JSONB_TYPE_LARGE_OBJECT):
return self.read_binary_json_object(length - 1, large)
elif t in (JSONB_TYPE_SMALL_ARRAY, JSONB_TYPE_LARGE_ARRAY):
return self.read_binary_json_array(length - 1, large)
elif t in (JSONB_TYPE_STRING,):
return self.read_variable_length_string()
elif t in (JSONB_TYPE_LITERAL,):
value = self.read_uint8()
if value == JSONB_LITERAL_NULL:
return None
elif value == JSONB_LITERAL_TRUE:
return True
elif value == JSONB_LITERAL_FALSE:
return False
elif t == JSONB_TYPE_INT16:
return self.read_int16()
elif t == JSONB_TYPE_UINT16:
return self.read_uint16()
elif t in (JSONB_TYPE_DOUBLE,):
return struct.unpack('<d', self.read(8))[0]
elif t == JSONB_TYPE_INT32:
return self.read_int32()
elif t == JSONB_TYPE_UINT32:
return self.read_uint32()
elif t == JSONB_TYPE_INT64:
return self.read_int64()
elif t == JSONB_TYPE_UINT64:
return self.read_uint64()
raise ValueError('Json type %d is not handled' % t)
def read_binary_json_type_inlined(self, t):
if t == JSONB_TYPE_LITERAL:
value = self.read_uint16()
if value == JSONB_LITERAL_NULL:
return None
elif value == JSONB_LITERAL_TRUE:
return True
elif value == JSONB_LITERAL_FALSE:
return False
elif t == JSONB_TYPE_INT16:
return self.read_int16()
elif t == JSONB_TYPE_UINT16:
return self.read_uint16()
elif t == JSONB_TYPE_INT32:
return self.read_int32()
elif t == JSONB_TYPE_UINT32:
return self.read_uint32()
raise ValueError('Json type %d is not handled' % t)
def read_binary_json_object(self, length, large):
if large:
elements = self.read_uint32()
size = self.read_uint32()
else:
elements = self.read_uint16()
size = self.read_uint16()
if size > length:
raise ValueError('Json length is larger than packet length')
if large:
key_offset_lengths = [(
self.read_uint32(), # offset (we don't actually need that)
self.read_uint16() # size of the key
) for _ in range(elements)]
else:
key_offset_lengths = [(
self.read_uint16(), # offset (we don't actually need that)
self.read_uint16() # size of key
) for _ in range(elements)]
value_type_inlined_lengths = [read_offset_or_inline(self, large)
for _ in range(elements)]
keys = [self.read(x[1]) for x in key_offset_lengths]
out = {}
for i in range(elements):
if value_type_inlined_lengths[i][1] is None:
data = value_type_inlined_lengths[i][2]
else:
t = value_type_inlined_lengths[i][0]
data = self.read_binary_json_type(t, length)
out[keys[i]] = data
return out
def read_binary_json_array(self, length, large):
if large:
elements = self.read_uint32()
size = self.read_uint32()
else:
elements = self.read_uint16()
size = self.read_uint16()
if size > length:
raise ValueError('Json length is larger than packet length')
values_type_offset_inline = [
read_offset_or_inline(self, large)
for _ in range(elements)]
def _read(x):
if x[1] is None:
return x[2]
return self.read_binary_json_type(x[0], length)
return [_read(x) for x in values_type_offset_inline]
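# --- Illustrative only (not part of the vendored module): a sketch of how the
# 'Length Coded Binary' rules used by read_length_coded_binary() map onto raw
# bytes. The helper works on a plain BytesIO instead of a live replication packet.
if __name__ == '__main__':
    from io import BytesIO

    def _decode_length_coded_int(buf):
        c = buf.read(1)[0]
        if c == NULL_COLUMN:                      # 251 -> NULL
            return None
        if c < UNSIGNED_CHAR_COLUMN:              # 0..250 -> the value itself
            return c
        if c == UNSIGNED_SHORT_COLUMN:            # 252 -> 2 more bytes, little endian
            return struct.unpack('<H', buf.read(2))[0]
        if c == UNSIGNED_INT24_COLUMN:            # 253 -> 3 more bytes
            a, b, c3 = buf.read(3)
            return a | (b << 8) | (c3 << 16)
        return struct.unpack('<Q', buf.read(8))[0]  # 254 -> 8 more bytes

    assert _decode_length_coded_int(BytesIO(b'\x05')) == 5
    assert _decode_length_coded_int(BytesIO(b'\xfc\x34\x12')) == 0x1234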


@ -0,0 +1,629 @@
# -*- coding: utf-8 -*-
import struct
import decimal
import datetime
import json
from pymysql.util import byte2int
from pymysql.charset import charset_to_encoding
from .event import BinLogEvent
from .exceptions import TableMetadataUnavailableError
from .constants import FIELD_TYPE
from .constants import BINLOG
from .column import Column
from .table import Table
from .bitmap import BitCount, BitGet
class RowsEvent(BinLogEvent):
def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(RowsEvent, self).__init__(from_packet, event_size, table_map,
ctl_connection, **kwargs)
self.__rows = None
self.__only_tables = kwargs["only_tables"]
self.__ignored_tables = kwargs["ignored_tables"]
self.__only_schemas = kwargs["only_schemas"]
self.__ignored_schemas = kwargs["ignored_schemas"]
#Header
self.table_id = self._read_table_id()
# Additional information
try:
self.primary_key = table_map[self.table_id].data["primary_key"]
self.schema = self.table_map[self.table_id].schema
self.table = self.table_map[self.table_id].table
        except KeyError:  # if we have filtered out the corresponding TableMapEvent
self._processed = False
return
if self.__only_tables is not None and self.table not in self.__only_tables:
self._processed = False
return
elif self.__ignored_tables is not None and self.table in self.__ignored_tables:
self._processed = False
return
if self.__only_schemas is not None and self.schema not in self.__only_schemas:
self._processed = False
return
elif self.__ignored_schemas is not None and self.schema in self.__ignored_schemas:
self._processed = False
return
#Event V2
if self.event_type == BINLOG.WRITE_ROWS_EVENT_V2 or \
self.event_type == BINLOG.DELETE_ROWS_EVENT_V2 or \
self.event_type == BINLOG.UPDATE_ROWS_EVENT_V2:
self.flags, self.extra_data_length = struct.unpack('<HH', self.packet.read(4))
self.extra_data = self.packet.read(self.extra_data_length / 8)
else:
self.flags = struct.unpack('<H', self.packet.read(2))[0]
#Body
self.number_of_columns = self.packet.read_length_coded_binary()
self.columns = self.table_map[self.table_id].columns
if len(self.columns) == 0: # could not read the table metadata, probably already dropped
self.complete = False
if self._fail_on_table_metadata_unavailable:
raise TableMetadataUnavailableError(self.table)
def __is_null(self, null_bitmap, position):
bit = null_bitmap[int(position / 8)]
if type(bit) is str:
bit = ord(bit)
return bit & (1 << (position % 8))
def _read_column_data(self, cols_bitmap):
"""Use for WRITE, UPDATE and DELETE events.
Return an array of column data
"""
values = {}
# null bitmap length = (bits set in 'columns-present-bitmap'+7)/8
# See http://dev.mysql.com/doc/internals/en/rows-event.html
null_bitmap = self.packet.read((BitCount(cols_bitmap) + 7) / 8)
nullBitmapIndex = 0
nb_columns = len(self.columns)
for i in range(0, nb_columns):
column = self.columns[i]
name = self.table_map[self.table_id].columns[i].name
unsigned = self.table_map[self.table_id].columns[i].unsigned
if BitGet(cols_bitmap, i) == 0:
values[name] = None
continue
if self.__is_null(null_bitmap, nullBitmapIndex):
values[name] = None
elif column.type == FIELD_TYPE.TINY:
if unsigned:
values[name] = struct.unpack("<B", self.packet.read(1))[0]
else:
values[name] = struct.unpack("<b", self.packet.read(1))[0]
elif column.type == FIELD_TYPE.SHORT:
if unsigned:
values[name] = struct.unpack("<H", self.packet.read(2))[0]
else:
values[name] = struct.unpack("<h", self.packet.read(2))[0]
elif column.type == FIELD_TYPE.LONG:
if unsigned:
values[name] = struct.unpack("<I", self.packet.read(4))[0]
else:
values[name] = struct.unpack("<i", self.packet.read(4))[0]
elif column.type == FIELD_TYPE.INT24:
if unsigned:
values[name] = self.packet.read_uint24()
else:
values[name] = self.packet.read_int24()
elif column.type == FIELD_TYPE.FLOAT:
values[name] = struct.unpack("<f", self.packet.read(4))[0]
elif column.type == FIELD_TYPE.DOUBLE:
values[name] = struct.unpack("<d", self.packet.read(8))[0]
elif column.type == FIELD_TYPE.VARCHAR or \
column.type == FIELD_TYPE.STRING:
if column.max_length > 255:
values[name] = self.__read_string(2, column)
else:
values[name] = self.__read_string(1, column)
elif column.type == FIELD_TYPE.NEWDECIMAL:
values[name] = self.__read_new_decimal(column)
elif column.type == FIELD_TYPE.BLOB:
values[name] = self.__read_string(column.length_size, column)
elif column.type == FIELD_TYPE.DATETIME:
values[name] = self.__read_datetime()
elif column.type == FIELD_TYPE.TIME:
values[name] = self.__read_time()
elif column.type == FIELD_TYPE.DATE:
values[name] = self.__read_date()
elif column.type == FIELD_TYPE.TIMESTAMP:
t_time = self.packet.read_uint32()
if t_time == 0:
values[name] = '0000-00-00 00:00:00'
else:
values[name] = datetime.datetime.fromtimestamp(t_time)
# For new date format:
elif column.type == FIELD_TYPE.DATETIME2:
values[name] = self.__read_datetime2(column)
elif column.type == FIELD_TYPE.TIME2:
values[name] = self.__read_time2(column)
elif column.type == FIELD_TYPE.TIMESTAMP2:
t_time = self.packet.read_int_be_by_size(4)
if t_time == 0:
values[name] = '0000-00-00 00:00:00'
else:
values[name] = self.__add_fsp_to_time(
datetime.datetime.fromtimestamp(
t_time), column)
elif column.type == FIELD_TYPE.LONGLONG:
if unsigned:
values[name] = self.packet.read_uint64()
else:
values[name] = self.packet.read_int64()
elif column.type == FIELD_TYPE.YEAR:
values[name] = self.packet.read_uint8() + 1900
elif column.type == FIELD_TYPE.ENUM:
values[name] = column.enum_values[
self.packet.read_uint_by_size(column.size) - 1]
elif column.type == FIELD_TYPE.SET:
# We read set columns as a bitmap telling us which options
# are enabled
bit_mask = self.packet.read_uint_by_size(column.size)
values[name] = set(
val for idx, val in enumerate(column.set_values)
if bit_mask & 2 ** idx
) or None
elif column.type == FIELD_TYPE.BIT:
values[name] = self.__read_bit(column)
elif column.type == FIELD_TYPE.GEOMETRY:
values[name] = self.packet.read_length_coded_pascal_string(
column.length_size)
elif column.type == FIELD_TYPE.JSON:
values[name] = self.packet.read_binary_json(column.length_size)
else:
raise NotImplementedError("Unknown MySQL column type: %d" %
(column.type))
nullBitmapIndex += 1
return values
def __add_fsp_to_time(self, time, column):
"""Read and add the fractional part of time
For more details about new date format:
http://dev.mysql.com/doc/internals/en/date-and-time-data-type-representation.html
"""
microsecond = self.__read_fsp(column)
if microsecond > 0:
time = time.replace(microsecond=microsecond)
return time
def __read_fsp(self, column):
read = 0
if column.fsp == 1 or column.fsp == 2:
read = 1
elif column.fsp == 3 or column.fsp == 4:
read = 2
elif column.fsp == 5 or column.fsp == 6:
read = 3
if read > 0:
microsecond = self.packet.read_int_be_by_size(read)
if column.fsp % 2:
return int(microsecond / 10)
else:
return microsecond
return 0
def __read_string(self, size, column):
string = self.packet.read_length_coded_pascal_string(size)
if column.character_set_name is not None:
string = string.decode(charset_to_encoding(column.character_set_name))
return string
def __read_bit(self, column):
"""Read MySQL BIT type"""
resp = ""
for byte in range(0, column.bytes):
current_byte = ""
data = self.packet.read_uint8()
if byte == 0:
if column.bytes == 1:
end = column.bits
else:
end = column.bits % 8
if end == 0:
end = 8
else:
end = 8
for bit in range(0, end):
if data & (1 << bit):
current_byte += "1"
else:
current_byte += "0"
resp += current_byte[::-1]
return resp
def __read_time(self):
time = self.packet.read_uint24()
date = datetime.timedelta(
hours=int(time / 10000),
minutes=int((time % 10000) / 100),
seconds=int(time % 100))
return date
def __read_time2(self, column):
"""TIME encoding for nonfractional part:
1 bit sign (1= non-negative, 0= negative)
1 bit unused (reserved for future extensions)
10 bits hour (0-838)
6 bits minute (0-59)
6 bits second (0-59)
---------------------
24 bits = 3 bytes
"""
data = self.packet.read_int_be_by_size(3)
sign = 1 if self.__read_binary_slice(data, 0, 1, 24) else -1
if sign == -1:
            # negative integers are stored as 2's complement,
            # so take the 2's complement again to get the right value.
data = ~data + 1
t = datetime.timedelta(
hours=sign*self.__read_binary_slice(data, 2, 10, 24),
minutes=self.__read_binary_slice(data, 12, 6, 24),
seconds=self.__read_binary_slice(data, 18, 6, 24),
microseconds=self.__read_fsp(column)
)
return t
def __read_date(self):
time = self.packet.read_uint24()
if time == 0: # nasty mysql 0000-00-00 dates
return None
year = (time & ((1 << 15) - 1) << 9) >> 9
month = (time & ((1 << 4) - 1) << 5) >> 5
day = (time & ((1 << 5) - 1))
if year == 0 or month == 0 or day == 0:
return None
date = datetime.date(
year=year,
month=month,
day=day
)
return date
def __read_datetime(self):
value = self.packet.read_uint64()
if value == 0: # nasty mysql 0000-00-00 dates
return None
date = value / 1000000
time = int(value % 1000000)
year = int(date / 10000)
month = int((date % 10000) / 100)
day = int(date % 100)
if year == 0 or month == 0 or day == 0:
return None
date = datetime.datetime(
year=year,
month=month,
day=day,
hour=int(time / 10000),
minute=int((time % 10000) / 100),
second=int(time % 100))
return date
def __read_datetime2(self, column):
"""DATETIME
1 bit sign (1= non-negative, 0= negative)
17 bits year*13+month (year 0-9999, month 0-12)
5 bits day (0-31)
5 bits hour (0-23)
6 bits minute (0-59)
6 bits second (0-59)
---------------------------
40 bits = 5 bytes
"""
data = self.packet.read_int_be_by_size(5)
year_month = self.__read_binary_slice(data, 1, 17, 40)
try:
t = datetime.datetime(
year=int(year_month / 13),
month=year_month % 13,
day=self.__read_binary_slice(data, 18, 5, 40),
hour=self.__read_binary_slice(data, 23, 5, 40),
minute=self.__read_binary_slice(data, 28, 6, 40),
second=self.__read_binary_slice(data, 34, 6, 40))
except ValueError:
self.__read_fsp(column)
return None
return self.__add_fsp_to_time(t, column)
def __read_new_decimal(self, column):
"""Read MySQL's new decimal format introduced in MySQL 5"""
# This project was a great source of inspiration for
# understanding this storage format.
# https://github.com/jeremycole/mysql_binlog
digits_per_integer = 9
compressed_bytes = [0, 1, 1, 2, 2, 3, 3, 4, 4, 4]
integral = (column.precision - column.decimals)
uncomp_integral = int(integral / digits_per_integer)
uncomp_fractional = int(column.decimals / digits_per_integer)
comp_integral = integral - (uncomp_integral * digits_per_integer)
comp_fractional = column.decimals - (uncomp_fractional
* digits_per_integer)
# Support negative
        # The sign is encoded in the high bit of the byte
# But this bit can also be used in the value
value = self.packet.read_uint8()
if value & 0x80 != 0:
res = ""
mask = 0
else:
mask = -1
res = "-"
self.packet.unread(struct.pack('<B', value ^ 0x80))
size = compressed_bytes[comp_integral]
if size > 0:
value = self.packet.read_int_be_by_size(size) ^ mask
res += str(value)
for i in range(0, uncomp_integral):
value = struct.unpack('>i', self.packet.read(4))[0] ^ mask
res += '%09d' % value
res += "."
for i in range(0, uncomp_fractional):
value = struct.unpack('>i', self.packet.read(4))[0] ^ mask
res += '%09d' % value
size = compressed_bytes[comp_fractional]
if size > 0:
value = self.packet.read_int_be_by_size(size) ^ mask
res += '%0*d' % (comp_fractional, value)
return decimal.Decimal(res)
def __read_binary_slice(self, binary, start, size, data_length):
"""
Read a part of binary data and extract a number
binary: the data
start: From which bit (1 to X)
size: How many bits should be read
data_length: data size
"""
binary = binary >> data_length - (start + size)
mask = ((1 << size) - 1)
return binary & mask
def _dump(self):
super(RowsEvent, self)._dump()
print("Table: %s.%s" % (self.schema, self.table))
print("Affected columns: %d" % self.number_of_columns)
print("Changed rows: %d" % (len(self.rows)))
def _fetch_rows(self):
self.__rows = []
if not self.complete:
return
while self.packet.read_bytes + 1 < self.event_size:
self.__rows.append(self._fetch_one_row())
@property
def rows(self):
if self.__rows is None:
self._fetch_rows()
return self.__rows
class DeleteRowsEvent(RowsEvent):
"""This event is trigger when a row in the database is removed
For each row you have a hash with a single key: values which contain the data of the removed line.
"""
def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(DeleteRowsEvent, self).__init__(from_packet, event_size,
table_map, ctl_connection, **kwargs)
if self._processed:
self.columns_present_bitmap = self.packet.read(
(self.number_of_columns + 7) / 8)
def _fetch_one_row(self):
row = {}
row["values"] = self._read_column_data(self.columns_present_bitmap)
return row
def _dump(self):
super(DeleteRowsEvent, self)._dump()
print("Values:")
for row in self.rows:
print("--")
for key in row["values"]:
print("*", key, ":", row["values"][key])
class WriteRowsEvent(RowsEvent):
"""This event is triggered when a row in database is added
For each row you have a hash with a single key: values which contain the data of the new line.
"""
def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(WriteRowsEvent, self).__init__(from_packet, event_size,
table_map, ctl_connection, **kwargs)
if self._processed:
self.columns_present_bitmap = self.packet.read(
(self.number_of_columns + 7) / 8)
def _fetch_one_row(self):
row = {}
row["values"] = self._read_column_data(self.columns_present_bitmap)
return row
def _dump(self):
super(WriteRowsEvent, self)._dump()
print("Values:")
for row in self.rows:
print("--")
for key in row["values"]:
print("*", key, ":", row["values"][key])
class UpdateRowsEvent(RowsEvent):
"""This event is triggered when a row in the database is changed
For each row you got a hash with two keys:
* before_values
* after_values
Depending of your MySQL configuration the hash can contains the full row or only the changes:
http://dev.mysql.com/doc/refman/5.6/en/replication-options-binary-log.html#sysvar_binlog_row_image
"""
def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(UpdateRowsEvent, self).__init__(from_packet, event_size,
table_map, ctl_connection, **kwargs)
if self._processed:
#Body
self.columns_present_bitmap = self.packet.read(
(self.number_of_columns + 7) / 8)
self.columns_present_bitmap2 = self.packet.read(
(self.number_of_columns + 7) / 8)
def _fetch_one_row(self):
row = {}
row["before_values"] = self._read_column_data(self.columns_present_bitmap)
row["after_values"] = self._read_column_data(self.columns_present_bitmap2)
return row
def _dump(self):
super(UpdateRowsEvent, self)._dump()
print("Affected columns: %d" % self.number_of_columns)
print("Values:")
for row in self.rows:
print("--")
for key in row["before_values"]:
print("*%s:%s=>%s" % (key,
row["before_values"][key],
row["after_values"][key]))
class TableMapEvent(BinLogEvent):
"""This evenement describe the structure of a table.
It's sent before a change happens on a table.
An end user of the lib should have no usage of this
"""
def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(TableMapEvent, self).__init__(from_packet, event_size,
table_map, ctl_connection, **kwargs)
self.__only_tables = kwargs["only_tables"]
self.__ignored_tables = kwargs["ignored_tables"]
self.__only_schemas = kwargs["only_schemas"]
self.__ignored_schemas = kwargs["ignored_schemas"]
self.__freeze_schema = kwargs["freeze_schema"]
# Post-Header
self.table_id = self._read_table_id()
if self.table_id in table_map and self.__freeze_schema:
self._processed = False
return
self.flags = struct.unpack('<H', self.packet.read(2))[0]
# Payload
self.schema_length = byte2int(self.packet.read(1))
self.schema = self.packet.read(self.schema_length).decode()
self.packet.advance(1)
self.table_length = byte2int(self.packet.read(1))
self.table = self.packet.read(self.table_length).decode()
if self.__only_tables is not None and self.table not in self.__only_tables:
self._processed = False
return
elif self.__ignored_tables is not None and self.table in self.__ignored_tables:
self._processed = False
return
if self.__only_schemas is not None and self.schema not in self.__only_schemas:
self._processed = False
return
elif self.__ignored_schemas is not None and self.schema in self.__ignored_schemas:
self._processed = False
return
self.packet.advance(1)
self.column_count = self.packet.read_length_coded_binary()
self.columns = []
if self.table_id in table_map:
self.column_schemas = table_map[self.table_id].column_schemas
else:
self.column_schemas = self._ctl_connection._get_table_information(self.schema, self.table)
if len(self.column_schemas) != 0:
# Read columns meta data
column_types = list(self.packet.read(self.column_count))
self.packet.read_length_coded_binary()
for i in range(0, len(column_types)):
column_type = column_types[i]
try:
column_schema = self.column_schemas[i]
except IndexError:
                    # this is a dirty hack to prevent row events containing columns which have been dropped prior
                    # to pymysqlreplication starting, but replayed from the binlog, from blowing up the service.
                    # TODO: this does not address the issue if a column other than the last one is dropped
column_schema = {
'COLUMN_NAME': '__dropped_col_{i}__'.format(i=i),
'COLLATION_NAME': None,
'CHARACTER_SET_NAME': None,
'COLUMN_COMMENT': None,
'COLUMN_TYPE': 'BLOB', # we don't know what it is, so let's not do anything with it.
'COLUMN_KEY': '',
}
col = Column(byte2int(column_type), column_schema, from_packet)
self.columns.append(col)
self.table_obj = Table(self.column_schemas, self.table_id, self.schema,
self.table, self.columns)
# TODO: get this information instead of trashing data
# n NULL-bitmask, length: (column-length * 8) / 7
def get_table(self):
return self.table_obj
def _dump(self):
super(TableMapEvent, self)._dump()
print("Table id: %d" % (self.table_id))
print("Schema: %s" % (self.schema))
print("Table: %s" % (self.table))
print("Columns: %s" % (self.column_count))


@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
class Table(object):
def __init__(self, column_schemas, table_id, schema, table, columns, primary_key=None):
if primary_key is None:
primary_key = [c.data["name"] for c in columns if c.data["is_primary"]]
if len(primary_key) == 0:
primary_key = ''
elif len(primary_key) == 1:
primary_key, = primary_key
else:
primary_key = tuple(primary_key)
self.__dict__.update({
"column_schemas": column_schemas,
"table_id": table_id,
"schema": schema,
"table": table,
"columns": columns,
"primary_key": primary_key
})
@property
def data(self):
return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith('_'))
def __eq__(self, other):
return self.data == other.data
def __ne__(self, other):
return not self.__eq__(other)
def serializable_data(self):
return self.data
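# --- Illustrative only (not part of the vendored module): how the primary_key
# argument is normalised by the constructor above. _FakeColumn is a stand-in
# exposing only the .data dict that Table reads; real Column objects come from
# column.py.
if __name__ == '__main__':
    class _FakeColumn:
        def __init__(self, name, is_primary):
            self.data = {'name': name, 'is_primary': is_primary}

    t = Table({}, 1, 'db1', 't1', [_FakeColumn('id', True), _FakeColumn('name', False)])
    assert t.primary_key == 'id'          # a single-column key collapses to its name
    t2 = Table({}, 2, 'db1', 't2', [_FakeColumn('a', True), _FakeColumn('b', True)])
    assert t2.primary_key == ('a', 'b')   # composite keys become a tuple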

68
dumpbinlog-tool/readme.md Normal file

@ -0,0 +1,68 @@
## Cetus data-migration catch-up tool user manual
### 1 Introduction
The main purpose of this tool is to convert binlog into SQL for Cetus data migration.
### 2 Usage
#### 2.1 Configuration file
The configuration file `binlog.conf` has three sections: `[BINLOG_MYSQL]`, `[OUTPUT_MYSQL]` and `[DEFAULT]`.
`[BINLOG_MYSQL]` holds the account information of the MySQL server that produces the binlog; `[OUTPUT_MYSQL]` holds the account information of the MySQL server that the parsed SQL is sent to; `[DEFAULT]` holds the tool's own options.
#### 2.2 Parameters
The basic parameters are described below:
```
# account information of the MySQL server that produces the binlog
[BINLOG_MYSQL]
host=172.17.0.4
port=3306
user=ght
password=123456
# account information of the MySQL server that the parsed SQL is sent to
# during scale-out this can be the newly deployed Cetus
[OUTPUT_MYSQL]
host=172.17.0.2
port=6002
user=ght
password=123456
[DEFAULT]
# start position for parsing the binlog
log_file=binlog.000001
log_pos=351
# schemas to skip: any SQL parsed for these schemas is ignored
skip_schemas=proxy_heart_beat
# log level
log_level=DEBUG
# whether to ignore DDL statements
ignore_ddl=true
# only parse the listed sharded tables
# only operations on these tables are output; operations on other tables (e.g. global tables) are dropped
# compatible with the Cetus configuration file
only_sharding_table=/data/sharding.json
```
#### 2.3 Resuming from a checkpoint
Progress is logged to `workdir/progress.log`. On the next start the tool automatically resumes from there; if you do not want to resume, **delete this file before starting**.
#### 2.4 Startup and options
At startup you can pass the `-d` option to set the working directory, i.e. `workdir`.
The startup command looks like this:
```
chmod +x ./dumpbinlog.py
./dumpbinlog.py
```
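For example, to use `/data/dumpbinlog-workdir` as the working directory (the path here is only illustrative):
```
chmod +x ./dumpbinlog.py
./dumpbinlog.py -d /data/dumpbinlog-workdir
```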


@ -0,0 +1,91 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import logging
from transaction import Transaction
class PreviousExecution:
    def __init__(self, from_trx: Transaction, trxs: list):
gid, sid = from_trx.gtid.split(':')
self.gtid_executed = gid+':1-'+sid #TODO: this is wrong, find better solution
self.start_log_file = from_trx.binlog_file
self.start_log_pos = from_trx.last_log_pos
self.overlapped_gtids = {t.gtid:False for t in trxs}
self._all_tested = False
self._transactions = trxs
def executed(self, gtid: str) -> bool:
if self._all_tested:
return False
if self.overlapped_gtids.get(gtid) is not None: # True or False, but not None
self.overlapped_gtids[gtid] = True # mark used
self._all_tested = all(self.overlapped_gtids.values())
trx = self._get_trx(gtid)
            # if an XA_PREPARE doesn't have a paired XA_COMMIT, we execute it again
            if trx.type == Transaction.TRX_XA_PREPARE:
                has_pair = self._has_paired_trx(trx)
                if not has_pair:
                    main_logger = logging.getLogger('main')
                    main_logger.warning("gtid:{} doesn't have a paired XA_COMMIT, will be executed again".format(gtid))
                return has_pair
else:
return True
else:
return False
def _get_trx(self, gtid: str) -> Transaction:
trx = next((t for t in self._transactions if t.gtid==gtid), None)
assert trx is not None
return trx
def _has_paired_trx(self, preparetrx: Transaction) -> bool:
return any(preparetrx.XID==trx.XID for trx in self._transactions\
if trx.type==Transaction.TRX_XA_COMMIT)
def reverse_readline(filename, buf_size=8192):
"""a generator that returns the lines of a file in reverse order"""
with open(filename) as fh:
segment = None
offset = 0
fh.seek(0, os.SEEK_END)
file_size = remaining_size = fh.tell()
while remaining_size > 0:
offset = min(file_size, offset + buf_size)
fh.seek(file_size - offset)
buffer = fh.read(min(remaining_size, buf_size))
remaining_size -= buf_size
lines = buffer.split('\n')
# the first line of the buffer is probably not a complete line so
# we'll save it and append it to the last line of the next buffer
# we read
if segment is not None:
                # if the previous chunk started right at the beginning of a line,
                # do not concat the segment onto the last line of the new chunk;
                # instead, yield the segment first
                if buffer[-1] != '\n':
lines[-1] += segment
else:
yield segment
segment = lines[0]
for index in range(len(lines) - 1, 0, -1):
if len(lines[index]):
yield lines[index]
# Don't yield None if the file was empty
if segment is not None:
yield segment
def read(filename) -> PreviousExecution:
transactions = []
for i, line in enumerate(reverse_readline(filename)):
if i >= 50:
break
s = line.split()
transactions.append(Transaction(s[2], s[3], s[4], s[5], int(s[6])))
if len(transactions) < 1:
return None
transactions = sorted(transactions, key=lambda trx: trx.gtid)
return PreviousExecution(transactions[0], transactions[1:])
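# --- Illustrative only (not part of the original module): reverse_readline()
# yields lines last-to-first, which is why read() above only has to inspect the
# tail of progress.log. The temporary file name and contents are made up.
if __name__ == '__main__':
    import tempfile
    with tempfile.NamedTemporaryFile('w', suffix='.log', delete=False) as f:
        f.write('line1\nline2\nline3\n')
        path = f.name
    print(list(reverse_readline(path)))   # ['line3', 'line2', 'line1']
    os.remove(path)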


@ -0,0 +1,266 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
from pymysqlreplication import BinLogStreamReader
from pymysqlreplication.row_event import (
RowsEvent,
DeleteRowsEvent,
UpdateRowsEvent,
WriteRowsEvent,
)
from pymysqlreplication.event import (
QueryEvent,
XidEvent,
RotateEvent,
GtidEvent,
XAPrepareEvent
)
from binascii import unhexlify
import re
class Transaction:
TRX_NORMAL = 'NORMAL'
TRX_XA_PREPARE = 'XA_PREPARE'
TRX_XA_COMMIT = 'XA_COMMIT'
def __init__(self, gtid:str=None, trx_type=TRX_NORMAL, xid:str=None,
log_file:str=None, log_pos:int=None):
self.gtid = gtid
self.last_committed = None
self.sequence_number = None
self.binlog_file = log_file
self.last_log_pos = log_pos
self.sql_list = []
self.XID = xid
self.type = trx_type
self.timestamp = 0
self.is_skip_schema = False
def isvalid(self) -> bool:
return self.gtid and self.last_committed is not None and \
self.sequence_number is not None and self.binlog_file
def dump(self):
print('GTID:', self.gtid,
'last_committed:', self.last_committed,
'sequence_number:', self.sequence_number,
'\nfile:', self.binlog_file, 'pos:', self.last_log_pos,
'\ntype:', self.type)
if self.is_skip_schema:
print('Some SQL in this trx is skipped')
print('SQL:')
for sql in self.sql_list:
print(sql)
print()
def brief(self) -> str:
return '..'.join(sql.split(' ')[0] for sql in self.sql_list)
def interleaved(self, other: 'Transaction') -> bool:
assert self.last_committed < self.sequence_number
assert other.last_committed < other.sequence_number
if self.last_committed < other.last_committed and \
self.sequence_number > other.last_committed:
return True
elif other.last_committed < self.sequence_number and \
other.sequence_number > self.last_committed:
return True
else:
return False
def __repr__(self) -> str:
return '{} {} {} {} {} {}'.format(self.gtid, self.type, self.XID,
self.binlog_file, self.last_log_pos,
self.timestamp)
def qstr(obj) -> str:
return "'{}'".format(str(obj))
def sql_delete(table, row) -> str:
sql = "delete from {} where ".format(table)
sql += ' and '.join([str(k)+'='+qstr(v) for k, v in row.items()])
return sql
def sql_update(table, before_row, after_row) -> str:
sql = 'update {} set '.format(table)
ct = 0
l = len(after_row.items())
for k, v in after_row.items():
ct += 1
if v is None:
sql += (str(k) + '=' + 'NULL')
else:
sql += (str(k) + '=' + qstr(v))
if ct != l:
sql += ','
sql += ' where '
sql += ' and '.join([str(k)+'='+qstr(v) for k, v in before_row.items()])
return sql
def sql_insert(table, row) -> str:
sql = 'insert into {}('.format(table)
keys = row.keys()
sql += ','.join([str(k) for k in keys])
sql += ') values('
ct = 0
l = len(keys)
for k in keys:
ct+= 1
if row[k] is None:
sql +="NULL"
else:
sql += qstr(row[k])
if ct != l:
sql += ','
sql += ')'
return sql
def is_ddl(sql: str) -> bool:
ddl_pattern = ['create table', 'drop table', 'create index', 'drop index',
'truncate table', 'alter table', 'alter index', 'create database', 'drop database', 'create user', 'drop user']
    no_comment = re.sub(r'/\*.*?\*/', '', sql, flags=re.S)
formatted = ' '.join(no_comment.lower().split())
return any(formatted.startswith(x) for x in ddl_pattern)
def is_ddl_database(sql: str) -> bool:
ddl_pattern = ['create database', 'drop database']
    no_comment = re.sub(r'/\*.*?\*/', '', sql, flags=re.S)
formatted = ' '.join(no_comment.lower().split())
return any(formatted.startswith(x) for x in ddl_pattern)
class BinlogTrxReader:
def __init__(self, config,
server_id,
blocking,
resume_stream,
log_file=None,
log_pos=None,
auto_position=None):
self.event_stream = BinLogStreamReader(
connection_settings=config.BINLOG_MYSQL,
server_id=server_id,
blocking=blocking,
resume_stream=resume_stream,
log_file=log_file,
log_pos=log_pos,
auto_position=auto_position
)
self._SKIP_SCHEMAS = config.SKIP_SCHEMAS
self._ALLOW_TABLES = config.ALLOW_TABLES
self._IGNORE_DDL = config.IGNORE_DDL
def __iter__(self):
return iter(self.fetch_one, None)
def _get_xid(self, event:QueryEvent) -> str:
sql = event.query
assert sql.lower().startswith('xa')
all_id = sql.split(' ')[2]
hex_id = all_id.split(',')[0]
return unhexlify(hex_id[2:-1]).decode()
def fetch_one(self) -> Transaction:
sql_events = [DeleteRowsEvent, WriteRowsEvent,
UpdateRowsEvent, QueryEvent, XidEvent,
XAPrepareEvent]
trx = Transaction()
for event in self.event_stream:
if isinstance(event, RotateEvent):
self.current_file = event.next_binlog
elif isinstance(event, GtidEvent):
trx.timestamp = event.timestamp
trx.gtid = event.gtid
trx.last_committed = event.last_committed
trx.sequence_number = event.sequence_number
trx.binlog_file = self.current_file
else:
finished = self._feed_event(trx, event)
if finished:
trx.last_log_pos = event.packet.log_pos
self._trim(trx)
return trx
def _process_rows_event(self, trx: Transaction, event: RowsEvent):
if self._SKIP_SCHEMAS and event.schema in self._SKIP_SCHEMAS:
trx.is_skip_schema = True
return
if self._ALLOW_TABLES and (event.schema, event.table) not in self._ALLOW_TABLES:
return
table = "%s.%s" % (event.schema, event.table)
if isinstance(event, DeleteRowsEvent):
trx.sql_list += [sql_delete(table, row['values']) for row in event.rows]
elif isinstance(event, UpdateRowsEvent):
trx.sql_list += [sql_update(table, row["before_values"], row["after_values"])\
for row in event.rows]
elif isinstance(event, WriteRowsEvent):
trx.sql_list += [sql_insert(table, row['values']) for row in event.rows]
def _feed_event(self, trx: Transaction, event) -> bool:
'''return: is this transaction finished
'''
if isinstance(event, RowsEvent):
self._process_rows_event(trx, event)
return False
elif isinstance(event, XidEvent):
trx.sql_list.append('commit')
assert trx.isvalid()
return True
elif isinstance(event, XAPrepareEvent):
if event.one_phase:
trx.sql_list.append('commit')
else:
trx.type = trx.TRX_XA_PREPARE
trx.XID = event.xid
assert trx.isvalid()
return True
elif isinstance(event, QueryEvent):
sql = event.query
if sql.startswith('XA START'):
trx.sql_list.append('START TRANSACTION')
elif sql.startswith('XA ROLLBACK'):
trx.sql_list.append('ROLLBACK')
elif sql.startswith('XA END'):
pass
elif sql.startswith('XA COMMIT'):
trx.sql_list.append('COMMIT')
trx.type = Transaction.TRX_XA_COMMIT
trx.XID = self._get_xid(event)
assert trx.isvalid()
return True
elif is_ddl(sql):
if self._IGNORE_DDL:
return True
if event.schema_length and not is_ddl_database(sql):
trx.sql_list.append('use '+event.schema.decode())
trx.sql_list.append(sql)
assert trx.isvalid()
return True
else:
trx.sql_list.append(sql)
return False
def _trim(self, trx: Transaction):
if trx.type == Transaction.TRX_NORMAL:
if len(trx.sql_list) == 3 \
and trx.sql_list[0].lower() == 'begin' \
and trx.sql_list[2].lower() == 'commit':
trx.sql_list = trx.sql_list[1:2]
if len(trx.sql_list) == 2 \
and trx.sql_list[0].lower() == 'begin' \
and trx.sql_list[1].lower() == 'commit':
trx.sql_list = []
elif trx.type == Transaction.TRX_XA_PREPARE:
if len(trx.sql_list) == 1 \
and trx.sql_list[0].lower() == 'start transaction':
trx.sql_list = []
else:
pass
def close(self):
self.event_stream.close()
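# --- Illustrative only (not part of the original module): how the SQL builders
# above turn decoded row images into statements. Table name and values are made up.
if __name__ == '__main__':
    row = {'id': 1, 'name': 'cetus'}
    print(sql_insert('db1.t1', row))   # insert into db1.t1(id,name) values('1','cetus')
    print(sql_update('db1.t1', row, {'id': 1, 'name': None}))
    #   update db1.t1 set id='1',name=NULL where id='1' and name='cetus'
    print(sql_delete('db1.t1', row))   # delete from db1.t1 where id='1' and name='cetus'
    assert is_ddl('CREATE TABLE t (id INT)') and not is_ddl('INSERT INTO t VALUES (1)')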