cetus/dumpbinlog-tool/transaction.py
2019-02-21 15:39:58 +08:00

267 lines
9.1 KiB
Python

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
from pymysqlreplication import BinLogStreamReader
from pymysqlreplication.row_event import (
RowsEvent,
DeleteRowsEvent,
UpdateRowsEvent,
WriteRowsEvent,
)
from pymysqlreplication.event import (
QueryEvent,
XidEvent,
RotateEvent,
GtidEvent,
XAPrepareEvent
)
from binascii import unhexlify
import re
class Transaction:
    """One transaction reassembled from binlog events.

    Carries the GTID, the ``last_committed`` / ``sequence_number`` pair copied
    from the GTID event, the binlog coordinates and the list of SQL statements
    collected for the transaction.
    """

    # Values stored in ``self.type``.
    TRX_NORMAL = 'NORMAL'
    TRX_XA_PREPARE = 'XA_PREPARE'
    TRX_XA_COMMIT = 'XA_COMMIT'

    def __init__(self, gtid: str = None, trx_type=TRX_NORMAL, xid: str = None,
                 log_file: str = None, log_pos: int = None):
        """Create an empty transaction.

        Args:
            gtid: GTID of the transaction, if already known.
            trx_type: one of the ``TRX_*`` constants.
            xid: XA transaction id, if any.
            log_file: name of the binlog file the transaction was read from.
            log_pos: binlog position recorded for the transaction.
        """
        self.gtid = gtid
        self.last_committed = None
        self.sequence_number = None
        self.binlog_file = log_file
        self.last_log_pos = log_pos
        self.sql_list = []
        self.XID = xid
        self.type = trx_type
        self.timestamp = 0
        # Set when at least one row event was dropped by a schema filter.
        self.is_skip_schema = False

    def isvalid(self) -> bool:
        """Return True when GTID, commit interval and binlog file are all set."""
        # bool() so the declared return type holds; the bare ``and`` chain
        # would otherwise leak None or the binlog file name to the caller.
        return bool(self.gtid
                    and self.last_committed is not None
                    and self.sequence_number is not None
                    and self.binlog_file)

    def dump(self):
        """Print the transaction header and every collected SQL statement."""
        print('GTID:', self.gtid,
              'last_committed:', self.last_committed,
              'sequence_number:', self.sequence_number,
              '\nfile:', self.binlog_file, 'pos:', self.last_log_pos,
              '\ntype:', self.type)
        if self.is_skip_schema:
            print('Some SQL in this trx is skipped')
        print('SQL:')
        for sql in self.sql_list:
            print(sql)
        print()

    def brief(self) -> str:
        """Return the first word of every statement, joined by ``..``."""
        return '..'.join(sql.split(' ')[0] for sql in self.sql_list)

    def interleaved(self, other: 'Transaction') -> bool:
        """Return True when the commit intervals of the two transactions overlap.

        Each transaction spans ``(last_committed, sequence_number)``; both
        intervals must be well-formed (``last_committed < sequence_number``).
        """
        assert self.last_committed < self.sequence_number
        assert other.last_committed < other.sequence_number
        # Plain interval-overlap test.  The former extra branch
        # (self.last_committed < other.last_committed < self.sequence_number)
        # is implied by this condition once the assertions above hold.
        return (other.last_committed < self.sequence_number
                and self.last_committed < other.sequence_number)

    def __repr__(self) -> str:
        return '{} {} {} {} {} {}'.format(self.gtid, self.type, self.XID,
                                          self.binlog_file, self.last_log_pos,
                                          self.timestamp)
def qstr(obj) -> str:
    """Return ``str(obj)`` as a single-quoted SQL string literal.

    Backslashes and single quotes inside the value are escaped so that the
    generated statement stays well-formed when the value contains them.
    """
    text = str(obj)
    # Double backslashes first, then double the quotes ('' is the standard
    # SQL spelling of a quote inside a literal).
    # NOTE(review): backslash doubling assumes the target server does not run
    # with NO_BACKSLASH_ESCAPES - confirm against the deployment.
    text = text.replace('\\', '\\\\').replace("'", "''")
    return "'{}'".format(text)
def sql_delete(table, row) -> str:
    """Build a DELETE statement that matches *row* column by column.

    Args:
        table: table name to delete from, used verbatim.
        row: mapping of column name to the value of the deleted row.

    Returns:
        ``delete from <table> where c1=<v1> and c2=<v2> ...``; columns whose
        value is ``None`` are matched with ``IS NULL``.
    """
    conditions = [
        # A NULL column can only be matched with IS NULL; rendering None
        # through qstr() would compare against the string 'None'.
        str(col) + ' IS NULL' if val is None else str(col) + '=' + qstr(val)
        for col, val in row.items()
    ]
    return 'delete from {} where {}'.format(table, ' and '.join(conditions))
def sql_update(table, before_row, after_row) -> str:
    """Build an UPDATE statement from the before/after images of a row.

    Args:
        table: table name to update, used verbatim.
        before_row: mapping of column name to value before the change; used
            for the WHERE clause.
        after_row: mapping of column name to value after the change; used
            for the SET clause.

    Returns:
        ``update <table> set c1=<v1>,... where c1=<old1> and ...``.  ``None``
        becomes ``NULL`` in the SET clause and ``IS NULL`` in the WHERE clause.
    """
    assignments = ','.join(
        str(col) + '=' + ('NULL' if val is None else qstr(val))
        for col, val in after_row.items()
    )
    conditions = ' and '.join(
        # A NULL column can only be matched with IS NULL; rendering None
        # through qstr() would compare against the string 'None'.
        str(col) + (' IS NULL' if val is None else '=' + qstr(val))
        for col, val in before_row.items()
    )
    return 'update {} set {} where {}'.format(table, assignments, conditions)
def sql_insert(table, row) -> str:
    """Build an INSERT statement for one row.

    Args:
        table: table name to insert into, used verbatim.
        row: mapping of column name to value; ``None`` is written as NULL,
            every other value is quoted with ``qstr``.

    Returns:
        ``insert into <table>(c1,c2,...) values(<v1>,<v2>,...)``.
    """
    column_list = ','.join(str(name) for name in row.keys())
    value_list = ','.join(
        'NULL' if row[name] is None else qstr(row[name])
        for name in row.keys()
    )
    return 'insert into {}({}) values({})'.format(table, column_list, value_list)
def is_ddl(sql: str) -> bool:
    """Return True when *sql* starts with one of the recognised DDL statements.

    ``/* ... */`` comments are removed and whitespace is collapsed before the
    case-insensitive prefix comparison.
    """
    ddl_prefixes = (
        'create table', 'drop table', 'create index', 'drop index',
        'truncate table', 'alter table', 'alter index',
        'create database', 'drop database', 'create user', 'drop user',
    )
    # Raw string: '\*' is not a valid escape in a normal string literal and
    # triggers a warning on current Python versions.
    no_comment = re.sub(r'/\*.*?\*/', '', sql, flags=re.S)
    normalized = ' '.join(no_comment.lower().split())
    return normalized.startswith(ddl_prefixes)
def is_ddl_database(sql: str) -> bool:
    """Return True when *sql* is a ``create database`` / ``drop database``.

    ``/* ... */`` comments are removed and whitespace is collapsed before the
    case-insensitive prefix comparison.
    """
    ddl_prefixes = ('create database', 'drop database')
    # Raw string: '\*' is not a valid escape in a normal string literal and
    # triggers a warning on current Python versions.
    no_comment = re.sub(r'/\*.*?\*/', '', sql, flags=re.S)
    normalized = ' '.join(no_comment.lower().split())
    return normalized.startswith(ddl_prefixes)
class BinlogTrxReader:
    """Group the events of a binlog stream into ``Transaction`` objects.

    Iterating over the reader yields one ``Transaction`` per call to
    ``fetch_one`` until that method returns ``None``.
    """

    def __init__(self, config,
                 server_id,
                 blocking,
                 resume_stream,
                 log_file=None,
                 log_pos=None,
                 auto_position=None):
        """Open the binlog stream and read the filter settings from *config*.

        Args:
            config: object providing ``BINLOG_MYSQL``, ``SKIP_SCHEMAS``,
                ``ALLOW_TABLES`` and ``IGNORE_DDL``.
            server_id, blocking, resume_stream, log_file, log_pos,
            auto_position: passed through to ``BinLogStreamReader``.
        """
        self.event_stream = BinLogStreamReader(
            connection_settings=config.BINLOG_MYSQL,
            server_id=server_id,
            blocking=blocking,
            resume_stream=resume_stream,
            log_file=log_file,
            log_pos=log_pos,
            auto_position=auto_position
        )
        # Seed the current binlog file name so that a GtidEvent seen before
        # any RotateEvent does not fail with AttributeError in fetch_one().
        self.current_file = log_file
        self._SKIP_SCHEMAS = config.SKIP_SCHEMAS
        self._ALLOW_TABLES = config.ALLOW_TABLES
        self._IGNORE_DDL = config.IGNORE_DDL

    def __iter__(self):
        # Two-argument iter(): call fetch_one() until it returns None.
        return iter(self.fetch_one, None)

    def _get_xid(self, event: QueryEvent) -> str:
        """Decode the XID from the query text of an ``XA ...`` statement.

        The third space-separated token is taken, cut at the first comma,
        stripped of its two leading characters and its last character, and
        the remainder is hex-decoded.
        """
        # presumably the token looks like X'<hex>',X'<hex>',<n> - verify
        # against the statements actually written to the binlog.
        sql = event.query
        assert sql.lower().startswith('xa')
        all_id = sql.split(' ')[2]
        hex_id = all_id.split(',')[0]
        return unhexlify(hex_id[2:-1]).decode()

    def fetch_one(self) -> Transaction:
        """Read events until one transaction is complete and return it.

        Returns ``None`` (implicitly) when the event stream ends before a
        transaction is finished; events collected up to that point are
        dropped with the local ``Transaction`` object.
        """
        trx = Transaction()
        for event in self.event_stream:
            if isinstance(event, RotateEvent):
                self.current_file = event.next_binlog
            elif isinstance(event, GtidEvent):
                trx.timestamp = event.timestamp
                trx.gtid = event.gtid
                trx.last_committed = event.last_committed
                trx.sequence_number = event.sequence_number
                trx.binlog_file = self.current_file
            else:
                finished = self._feed_event(trx, event)
                if finished:
                    trx.last_log_pos = event.packet.log_pos
                    self._trim(trx)
                    return trx

    def _process_rows_event(self, trx: Transaction, event: RowsEvent):
        """Translate a row event into SQL statements appended to *trx*.

        Events for schemas in ``SKIP_SCHEMAS`` are dropped and flagged on the
        transaction; when ``ALLOW_TABLES`` is non-empty, events for tables
        outside it are dropped silently.
        """
        if self._SKIP_SCHEMAS and event.schema in self._SKIP_SCHEMAS:
            trx.is_skip_schema = True
            return
        if self._ALLOW_TABLES and (event.schema, event.table) not in self._ALLOW_TABLES:
            return
        table = "%s.%s" % (event.schema, event.table)
        if isinstance(event, DeleteRowsEvent):
            trx.sql_list += [sql_delete(table, row['values']) for row in event.rows]
        elif isinstance(event, UpdateRowsEvent):
            trx.sql_list += [sql_update(table, row["before_values"], row["after_values"])
                             for row in event.rows]
        elif isinstance(event, WriteRowsEvent):
            trx.sql_list += [sql_insert(table, row['values']) for row in event.rows]

    def _feed_event(self, trx: Transaction, event) -> bool:
        """Add *event* to *trx*.

        Returns:
            True when the event completes the transaction, False otherwise.
        """
        if isinstance(event, RowsEvent):
            self._process_rows_event(trx, event)
            return False
        elif isinstance(event, XidEvent):
            trx.sql_list.append('commit')
            assert trx.isvalid()
            return True
        elif isinstance(event, XAPrepareEvent):
            if event.one_phase:
                trx.sql_list.append('commit')
            else:
                trx.type = trx.TRX_XA_PREPARE
                trx.XID = event.xid
            assert trx.isvalid()
            return True
        elif isinstance(event, QueryEvent):
            sql = event.query
            if sql.startswith('XA START'):
                trx.sql_list.append('START TRANSACTION')
            elif sql.startswith('XA ROLLBACK'):
                # NOTE(review): the rollback is recorded but does not finish
                # the transaction, so it stays in the accumulating trx -
                # confirm this is intended.
                trx.sql_list.append('ROLLBACK')
            elif sql.startswith('XA END'):
                pass
            elif sql.startswith('XA COMMIT'):
                trx.sql_list.append('COMMIT')
                trx.type = Transaction.TRX_XA_COMMIT
                trx.XID = self._get_xid(event)
                assert trx.isvalid()
                return True
            elif is_ddl(sql):
                if self._IGNORE_DDL:
                    # The DDL text is dropped but the transaction still ends.
                    return True
                # Select the schema first, except for create/drop database.
                if event.schema_length and not is_ddl_database(sql):
                    trx.sql_list.append('use '+event.schema.decode())
                trx.sql_list.append(sql)
                assert trx.isvalid()
                return True
            else:
                trx.sql_list.append(sql)
        return False

    def _trim(self, trx: Transaction):
        """Strip transaction wrappers that carry no useful statement.

        For a normal transaction, ``begin, <stmt>, commit`` is reduced to
        ``<stmt>`` and ``begin, commit`` to an empty list.  For an XA prepare
        that only contains ``start transaction`` the list is emptied.
        """
        if trx.type == Transaction.TRX_NORMAL:
            if len(trx.sql_list) == 3 \
                    and trx.sql_list[0].lower() == 'begin' \
                    and trx.sql_list[2].lower() == 'commit':
                trx.sql_list = trx.sql_list[1:2]
            if len(trx.sql_list) == 2 \
                    and trx.sql_list[0].lower() == 'begin' \
                    and trx.sql_list[1].lower() == 'commit':
                trx.sql_list = []
        elif trx.type == Transaction.TRX_XA_PREPARE:
            if len(trx.sql_list) == 1 \
                    and trx.sql_list[0].lower() == 'start transaction':
                trx.sql_list = []
        else:
            pass

    def close(self):
        """Close the underlying binlog event stream."""
        self.event_stream.close()