文章目录
- 前言
- 注册驱动
- 连接器
- 创建连接
- 交互协议
- 读写数据
- 读数据
- 写数据
- mysqlConn
- context超时控制
- 查询
- 发送查询请求
- 读取查询响应
- Exec
- 发送exec请求
- 读取响应
- 预编译
- 客户端预编译
- 服务端预编译
- 生成prepareStmt
- 执行查询操作
- 执行Exec操作
- 事务
- 读取响应
- query响应
- exec响应
- 总结
前言
go-sql-driver/mysql
的核心功能是,实现 database/sql/driver 中定义的接口,提供mysql版本的驱动实现,主要完成以下功能:
- 根据交互协议,完成往mysql server发送请求和解析响应的具体操作
- 执行客户端预编译操作
驱动需要实现的接口,和接口之间的关系如下:
本文阅读源码:https://github.com/go-sql-driver/mysql,版本:v1.8.1
注册驱动
要使用mysql驱动,首先需要注册驱动,注册方式如下:
import (
// 注册驱动
_ "github.com/go-sql-driver/mysql"
)
这个包会用通过init函数注册驱动:
var driverName = "mysql"
func init() {
if driverName != "" {
sql.Register(driverName, &MySQLDriver{})
}
}
往database/sql包的map全局变量注册驱动:
func Register(name string, driver driver.Driver) {
driversMu.Lock()
defer driversMu.Unlock()
if driver == nil {
panic("sql: Register driver is nil")
}
if _, dup := drivers[name]; dup {
panic("sql: Register called twice for driver " + name)
}
// 注册
drivers[name] = driver
}
这样sql.Open("mysql", XXX)
时就能根据mysql的driver生成connecter,然后创建sql.DB实例
连接器
继续看sql.Open方法:
func Open(driverName, dataSourceName string) (*DB, error) {
// 校验驱动是否注册
driversMu.RLock()
driveri, ok := drivers[driverName]
driversMu.RUnlock()
if !ok {
return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
}
// mysql的驱动会命中这个分支
if driverCtx, ok := driveri.(driver.DriverContext); ok {
connector, err := driverCtx.OpenConnector(dataSourceName)
if err != nil {
return nil, err
}
return OpenDB(connector), nil
}
return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
}
MySQLDriver实现了driver.DriverContext
接口,根据dsn解析出配置,生成driver.Connector实例,让sql.DB持有
func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
cfg, err := ParseDSN(dsn)
if err != nil {
return nil, err
}
return newConnector(cfg), nil
}
newConnector:让connector持有cfg配置
func newConnector(cfg *Config) *connector {
encodedAttributes := encodeConnectionAttributes(cfg)
return &connector{
cfg: cfg,
encodedAttributes: encodedAttributes,
}
}
config里有各种配置,例如:
- 账号密码,连接地址
- 超时配置:连接超时时间
Timeout
,读超时时间ReadTimeout
,写超时时间WriteTimeout
- 是否开启客户端预编译:
InterpolateParams
完整的配置如下:
type Config struct {
// non boolean fields
User string // Username
Passwd string // Password (requires User)
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix")
DBName string // Database name
Params map[string]string // Connection parameters
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
Collation string // Connection collation
Loc *time.Location // Location for time.Time values MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
TLSConfig string // TLS configuration name TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger
// boolean fields
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
AllowFallbackToPlaintext bool // Allows fallback to unencrypted connection if server does not support TLS
AllowNativePasswords bool // Allows the native password authentication method
AllowOldPasswords bool // Allows the old insecure password method
CheckConnLiveness bool // Check connections for liveness before using them
ClientFoundRows bool // Return number of matching rows instead of rows changed
ColumnsWithAlias bool // Prepend table alias to column names
InterpolateParams bool // Interpolate placeholders into query string
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections
// unexported fields. new options should be come here
beforeConnect func(context.Context, *Config) error // Invoked before a connection is established
pubKey *rsa.PublicKey // Server public key
timeTruncate time.Duration // Truncate time.Time values to the specified duration}
例如客户端预编译参数interpolateParams
:
- 如果dsn中有
interpolateParams=true
参数,就会被解析到cfg.InterpolateParams中,值为true - cfg被connector持有
- 用connector新建连接时,将cfg传给
mysqlConn
- 用该连接查询时,如果
cfg.InterpolateParams=true
,就会启用客户端预编译
创建连接
根据上一篇文章的分析,sql包要新建连接时,都调Connector.Connect
方法
- 和mysql服务器建立连接,类型为
net.TCPConn
- 创建读缓冲区
- 设置读写超时时间
- 处理握手,鉴权,不是本文的重点,这里省略
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
var err error
cfg := c.cfg
// ...
// New mysqlConn
mc := &mysqlConn{
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
// 使用cfg
cfg: cfg,
connector: c,
}
mc.parseTime = mc.cfg.ParseTime
// Connect to Server
dialsLock.RLock()
dial, ok := dials[mc.cfg.Net]
dialsLock.RUnlock()
if ok {
// ...
} else {
// 设置建立连接的timeout = cfg.Timeout
nd := net.Dialer{Timeout: mc.cfg.Timeout}
// 创建net.Conn类型的连接
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
}
if err != nil {
return nil, err
}
mc.rawConn = mc.netConn
// Enable TCP Keepalives on TCP connections
if tc, ok := mc.netConn.(*net.TCPConn); ok {
// 设为长连接
if err := tc.SetKeepAlive(true); err != nil {
c.cfg.Logger.Print(err)
}
}
// Call startWatcher for context support (From Go 1.8)
mc.startWatcher()
if err := mc.watchCancel(ctx); err != nil {
mc.cleanup()
return nil, err
}
defer mc.finish()
// 创建缓冲区
mc.buf = newBuffer(mc.netConn)
// 设置读写超时时间
mc.buf.timeout = mc.cfg.ReadTimeout
mc.writeTimeout = mc.cfg.WriteTimeout
// 处理握手,鉴权,这里忽略
// Handle DSN Params
err = mc.handleParams()
if err != nil {
mc.Close()
return nil, err
}
return mc, nil
}
交互协议
mysql client和mysql server采用特殊的交互协议:每个报文的结构如下:
- 前3字节为数据包长度,也就是每个报文最多传输
2^24 = 16MB
的数据。如果一次请求或响应需要传输的数据量超过这个大小,需要拆分成多个报文传输 - 第4个字节为序列号sequence,每次读写之前都会对
sequence+1
,接收到mysql server的响应后,检查报文里的sequence是否和本地相同 - 第5个字节:
- 读:表示mysql server返回的响应状态,0为成功,255为出错
- 写:表示client本次请求的类型,例如Query,Ping,Prepare
- 剩下的字节:报文数据
读写数据
接下来看mysql client如何根据通信协议和mysql server交互,也就是如何读写数据
读数据
readPacket用于读一个完整的数据包:
由于每个报文有3个字节表示数据长度,也就是最多表示2^24个字节=16MB
,如果数据包超过这个字节数,就要多次读取报文。最后一个报文的长度为0表示数据读完了
func (mc *mysqlConn) readPacket() ([]byte, error) {
var prevData []byte
for {
// 先读头部的4个字节
data, err := mc.buf.readNext(4)
if err != nil {
// ...
return nil, ErrInvalidConn
}
// 前3个字节为包的长度
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
// 第4个字节为序列号,必须要和本地的序列号一致
if data[3] != mc.sequence {
mc.Close()
if data[3] > mc.sequence {
return nil, ErrPktSyncMul
}
return nil, ErrPktSync
}
// 本地序列号自增
mc.sequence++
// 包长度为0,代表终结上一次的读取,上一次的读取很长
if pktLen == 0 {
// there was no previous packet
if prevData == nil {
mc.log(ErrMalformPkt)
mc.Close()
return nil, ErrInvalidConn
}
// 返回之前读取的数据
return prevData, nil
}
// 读pktLen长度的数据
data, err = mc.buf.readNext(pktLen)
if err != nil {
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
mc.log(err)
mc.Close()
return nil, ErrInvalidConn
}
// 包长度小于最大长度,说明读一次就行
if pktLen < maxPacketSize {
// zero allocationsfor non-split packets
if prevData == nil {
return data, nil
}
return append(prevData, data...), nil
}
// 否则需要读多次,先把这次读到的暂存起来
prevData = append(prevData, data...)
}
}
缓冲区buffer定义如下:
type buffer struct {
// 都缓冲区
buf []byte
// 缓冲区数据不够时,从什么连接读数据
nc net.Conn
// 下次从哪开始读
idx int
// 还剩多少个字节未读
length int
// 读超时时间
timeout time.Duration
// ...
}
readNext:从buffer中读need个字节
func (b *buffer) readNext(need int) ([]byte, error) {
// 剩余字节数不够本次要读的
if b.length < need {
// 调fill从连接中读到缓冲区
if err := b.fill(need); err != nil {
return nil, err
}
}
// 从offset开始读
offset := b.idx
// 更新idx
b.idx += need
// 还能读的字节数减少
b.length -= need
return b.buf[offset:b.idx], nil
}
当缓冲区中字节数不够时,调fill填充:
- 如果buf不够need,扩容
- 将buf中还能读的部分,拷贝到开头
- 每次读之前,设置读超时时间
- 调net.Conn的Read方法,把数据读到缓冲区中
func (b *buffer) fill(need int) error {
n := b.length
// 如果flipcnt一直是0,那就一直用第一个buf
dest := b.dbuf[b.flipcnt&1]
// 如果buf不够need,扩容
if need > len(dest) {
// 容量按照4096的整数倍向上对齐
dest = make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
// 如果新分配的buf不是太大,就放到下一个buf中
if len(dest) <= maxCachedBufSize {
b.dbuf[b.flipcnt&1] = dest
}
}
if n > 0 {
// 将buf中还能读的部分,拷贝到开头
copy(dest[:n], b.buf[b.idx:])
}
b.buf = dest
b.idx = 0
for {
if b.timeout > 0 {
// 每次读之前,设置读超时时间
if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil {
return err
}
}
// 从连接中读数据,写到n开始的位置
nn, err := b.nc.Read(b.buf[n:])
n += nn
switch err {
case nil:
// 如果还读到需要的字节数,继续读
if n < need {
continue
}
b.length = n
return nil
case io.EOF:
if n >= need {
b.length = n
return nil
}
return io.ErrUnexpectedEOF
default:
return err
}
}
}
写数据
往mysql server写数据没用缓冲区,直接往连接写。如果一次写不完,拆分成多个报文写
func (mc *mysqlConn) writePacket(data []byte) error {
// 包长度
pktLen := len(data) - 4
// mc.maxAllowedPacket为配置的最大包长度
if pktLen > mc.maxAllowedPacket {
return ErrPktTooLarge
}
for {
var size int
// 前3个字节为包长度
if pktLen >= maxPacketSize {
data[0] = 0xff
data[1] = 0xff
data[2] = 0xff
size = maxPacketSize
} else {
data[0] = byte(pktLen)
data[1] = byte(pktLen >> 8)
data[2] = byte(pktLen >> 16)
size = pktLen
}
// 第4个字节为序列号
data[3] = mc.sequence
// 设置本次的写超时
if mc.writeTimeout > 0 {
if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
return err
}
}
// 执行写操作
n, err := mc.netConn.Write(data[:4+size])
if err == nil && n == 4+size {
// 写完了,序列号++
mc.sequence++
if size != maxPacketSize {
return nil
}
// 接着写下一个包
pktLen -= size
data = data[size:]
continue
}
// 处理错误
return ErrInvalidConn
}
}
mysqlConn
mysqlConn实现了driver.Conn接口:
type Conn interface {
// 预编译
Prepare(query string) (Stmt, error)
Close() error
// 开启事务
Begin() (Tx, error)
}
mysqlConn字段如下:
type mysqlConn struct {
// 缓冲区
buf buffer
// TCP连接
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection.
// sql执行结果
result mysqlResult // managed by clearResult() and handleOkPacket().
// 连接相关配置
cfg *Config
// 由哪个连接器创建的
connector *connector
maxAllowedPacket int
maxWriteSize int
// 一次写报文的超市时间
writeTimeout time.Duration
// 客户端状态标识
flags clientFlag
// 服务端状态标识
status statusFlag
// 序列号
sequence uint8
parseTime bool
// for context support (Go 1.8+)
watching bool
watcher chan<- context.Context
closech chan struct{}
finished chan<- struct{}
canceled atomicError // set non-nil if conn is canceled
closed atomicBool // set when conn is closed, before closech is closed
}
context超时控制
每个mysqlConn会启动一个Watcher,用于监听一旦ctx的Done后,关闭tcp连接
func (mc *mysqlConn) startWatcher() {
watcher := make(chan context.Context, 1)
mc.watcher = watcher
finished := make(chan struct{})
mc.finished = finished
go func() {
for {
var ctx context.Context
select {
// 从watcher接受一个ctx
case ctx = <-watcher:
case <-mc.closech:
return
}
select {
case <-ctx.Done():
// ctx.Done后,关闭连接
mc.cancel(ctx.Err())
case <-finished:
case <-mc.closech:
return
}
}
}()
}
mysqlConn.cancal
func (mc *mysqlConn) cancel(err error) {
mc.canceled.Set(err)
mc.cleanup()
}
func (mc *mysqlConn) cleanup() {
if mc.closed.Swap(true) {
return
}
// Makes cleanup idempotent
close(mc.closech)
conn := mc.rawConn
if conn == nil {
return
}
// 关闭底层tcp连接
if err := conn.Close(); err != nil {
mc.log(err)
}
}
那啥时候往mc.watcher推ctx呢?每次调QueryContext,ExecContext
时,将参数传入的ctx推给后台watcher,后台监听这个ctx的Done信号。如果sql执行超时,会关闭tcp连接
通过这种方式来实现ctx的超时控制
查询
当调用sql.DB的QueryContext方法时,如果驱动的连接实现了driver.QueryerContext
接口,会转化为调该接口的QueryContext方法,而mysql驱动 mysqlConn实现了该接口:
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
// 将ctx通过channel推给后台go,监控ctx的退出
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
// 执行查询
rows, err := mc.query(query, dargs)
if err != nil {
mc.finish()
return nil, err
}
rows.finish = mc.finish
return rows, err
}
发送查询请求
query:执行查询
- 如果有参数:
- 如果不允许客户端预编译,返回
driver.ErrSkip
,让sql.DB走服务端预编译 - 否则执行客户端预编译
- 如果不允许客户端预编译,返回
- 将命令封装成符合交互协议的报文格式,发送到mysql server
- 读取响应结果,也就是列的信息
// query:sql模板,args:占位符
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
handleOk := mc.clearResult()
if mc.closed.Load() {
mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// 有参数
if len(args) != 0 {
// 如果不允许客户端预编译,返回driver.ErrSkip,让sql.DB走服务端预编译
if !mc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// 尝试客户端的预编译,以减少网络消耗
prepared, err := mc.interpolateParams(query, args)
if err != nil {
return nil, err
}
query = prepared
}
// 将sql发到服务端
err := mc.writeCommandPacketStr(comQuery, query)
if err == nil {
// Read Result
var resLen int
// 读取列的个数
resLen, err = handleOk.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
if resLen == 0 {
rows.rs.done = true
switch err := rows.NextResultSet(); err {
case nil, io.EOF:
return rows, nil
default:
return nil, err
}
}
// 读resLen个列信息到rows.rs.columns中
rows.rs.columns, err = mc.readColumns(resLen)
return rows, err
}
}
return nil, mc.markBadConn(err)
}
将命令write到mysql server:
- 计算数据包长度,加上报文头部的长度,申请空间data
- 将数据拷贝到data中
- 一次性或分批写到mysql server
func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
// Reset Packet Sequence
mc.sequence = 0
// 数据包的长度=1 + len(arg)
pktLen := 1 + len(arg)
// 申请一块buf,尝试复用已经申请好的,长度为pktLen再加4,也就是加上报文头部的长度
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
mc.log(err)
return errBadConnNoWrite
}
// 第4个字节为请求类型
data[4] = command
// 追加上arg
copy(data[5:], arg)
// 一次性或分批写到mysql server
return mc.writePacket(data)
}
读取查询响应
读取响应结果:
对query的响应来说,返回列的个数
func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
// handleOkPacket replaces both values; other cases leave the values unchanged.
mc.result.affectedRows = append(mc.result.affectedRows, 0)
mc.result.insertIds = append(mc.result.insertIds, 0)
// 读一个响应包
data, err := mc.conn().readPacket()
if err == nil {
switch data[0] {
// data[0]=0代表响应成功,Exec函数会走到这,下文再分析
case iOK:
return 0, mc.handleOkPacket(data)
case iERR:
return 0, mc.conn().handleErrorPacket(data)
case iLocalInFile:
return 0, mc.handleInFileRequest(string(data[1:]))
}
// 查询走这
// data[0]是其他值的话,从data里读取列的个数
num, _, _ := readLengthEncodedInteger(data)
return int(num), nil
}
return 0, err
}
先看两个小方法:如何从mysql server响应的字节流中读取一个数字,读一个字符串
readLengthEncodedInteger
:从响应结果b里解析一个数字,返回数字本身,是否为null,数字占用多少个字节
根据b[0]的值不同,分3种情况处理:
- b[0]是251,说明结果是null
- b[0]是为0~250,那么b[0] 就是接下来的数字
- 否则数字由多个字节组成,根据b[0]的值使用的不同的解析方式
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
// See issue #349
if len(b) == 0 {
return 0, true, 1
}
switch b[0] {
// 251: NULL
case 0xfb:
return 0, true, 1
// 252: value of following 2
case 0xfc:
return uint64(b[1]) | uint64(b[2])<<8, false, 3
// 253: value of following 3
case 0xfd:
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4
// 254: value of following 8
case 0xfe:
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
uint64(b[7])<<48 | uint64(b[8])<<56,
false, 9
}
// 0-250: value of first byte
return uint64(b[0]), false, 1
}
readLengthEncodedString:从字节流b中读取一个字符串。先读一个数字num代表字符串的长度,然后读num个字节,就是字符串本身
func readLengthEncodedString(b []byte) ([]byte, bool, int, error) {
// Get length
num, isNull, n := readLengthEncodedInteger(b)
if num < 1 {
return b[n:n], isNull, n, nil
}
n += int(num)
// Check data length
if len(b) >= n {
return b[n-int(num) : n : n], false, n, nil
}
return nil, false, n, io.EOF
}
现在有了列的个数,再从连接中读取mysql server返回的列的详情:
每个列有这些信息:
type mysqlField struct {
tableName string
name string
length uint32
flags fieldFlag
fieldType fieldType
decimals byte
charSet uint8
}
mysqlConn.readColumns
要干的事就是从连接中的字节流中读取并解析数据,为mysqlField
的每个字段赋值
整体来说就是对每个字段都读一个报文,解析报文中的字节流到mysqlField中
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
columns := make([]mysqlField, count)
for i := 0; ; i++ {
// 读一个报文
data, err := mc.readPacket()
if err != nil {
return nil, err
}
// 读完了,返回
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
if i == count {
return columns, nil
}
return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns))
}
// Catalog
pos, err := skipLengthEncodedString(data)
if err != nil {
return nil, err
}
// ... 跳过一些字节
// 读列名
name, _, n, err := readLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}
columns[i].name = string(name)
pos += n
// ... 跳过一些字节
// Filler [uint8]
pos++
// 读charset
columns[i].charSet = data[pos]
pos += 2
// 读length
columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])
pos += 4
// 读字段类型
columns[i].fieldType = fieldType(data[pos])
pos++
// Flags [uint16]
columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
pos += 2
// Decimals [uint8]
columns[i].decimals = data[pos]
}
Exec
当调用sql.DB的ExecContext
方法时,如果驱动的连接实现了driver.ExecerContext
接口,会转化为调该接口的ExecContext方法,而mysql驱动 mysqlConn实现了该接口:
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
// 监听ctx的退出信号
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
defer mc.finish()
return mc.Exec(query, dargs)
}
调mysqlConn.Exec方法
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.closed.Load() {
mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// 客户端预编译相关,和query流程类似
if len(args) != 0 {
if !mc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
prepared, err := mc.interpolateParams(query, args)
if err != nil {
return nil, err
}
query = prepared
}
err := mc.exec(query)
if err == nil {
copied := mc.result
return &copied, err
}
return nil, mc.markBadConn(err)
}
发送exec请求
func (mc *mysqlConn) exec(query string) error {
handleOk := mc.clearResult()
// 将query发往mysql server
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
return mc.markBadConn(err)
}
// 处理响应结果
resLen, err := handleOk.readResultSetHeaderPacket()
if err != nil {
return err
}
// ...
}
读取响应
readResultSetHeaderPacket:读取响应
func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
// handleOkPacket replaces both values; other cases leave the values unchanged.
mc.result.affectedRows = append(mc.result.affectedRows, 0)
mc.result.insertIds = append(mc.result.insertIds, 0)
data, err := mc.conn().readPacket()
if err == nil {
switch data[0] {
// 进入这个分支
case iOK:
return 0, mc.handleOkPacket(data)
// ...
}
return 0, err
}
handleOkPacket:处理OK的响应结果:
就是从字节流中读取本次exec影响的行数affectedRows
,上次插入id insertId
,保存到连接中,供应用层获取
func (mc *okHandler) handleOkPacket(data []byte) error {
var n, m int
var affectedRows, insertId uint64
// 影响行数,n代表了几个字节
affectedRows, _, n = readLengthEncodedInteger(data[1:])
// insertId,m代表读了几个字节
insertId, _, m = readLengthEncodedInteger(data[1+n:])
if len(mc.result.affectedRows) > 0 {
// 保存affectedRows
mc.result.affectedRows[len(mc.result.affectedRows)-1] = int64(affectedRows)
}
if len(mc.result.insertIds) > 0 {
// 保存insertId
mc.result.insertIds[len(mc.result.insertIds)-1] = int64(insertId)
}
// server_status [2 bytes]
mc.status = readStatus(data[1+n+m : 1+n+m+2])
if mc.status&statusMoreResultsExists != 0 {
return nil
}
// warning count [2 bytes]
return nil
}
预编译
预编译操作是将一个sql模板提前发往mysql server. 后续在该 sql 模板下的多笔操作,只需要将对应的参数发往服务端,即可实现对模板的复用,有以下优点:
- 模板复用:sql 模板一次编译,多次复用,可以提高性能
- 语法安全:模板不变的部分和参数可变的部分隔离,防止sql注入
客户端预编译
客户端预编译只有语法安全的效果,没有模板服用的优点,客户端预编译只有语法安全的效果,没有模板服用的优点,但是能减少和mysql server交互的次数
当配置mysql驱动的参数interpolateParams为true
,且没有用预编译时(直接调db.Query,db.Exec),会使用客户端预编译
mysql驱动层面会将query和args拼接成一个完整的sql,怎么防止sql注入?拼接字符串时,在其前后加上\'
客户端预编译函数interpolateParams的代码如下:
- 校验sql中
?
的个数要个args的长度相同 - 将每个
?
替换成实际的参数。如果参数是字符串,在前后拼接\'
// 客户端预编译
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
// ?的个数要个args的长度相同
if strings.Count(query, "?") != len(args) {
return "", driver.ErrSkip
}
buf, err := mc.buf.takeCompleteBuffer()
if err != nil {
// can not take the buffer. Something must be wrong with the connection
mc.log(err)
return "", ErrInvalidConn
}
buf = buf[:0]
argPos := 0
for i := 0; i < len(query); i++ {
// 下一个?的位置
q := strings.IndexByte(query[i:], '?')
// 后面没有?了,跳出循环
if q == -1 {
buf = append(buf, query[i:]...)
break
}
// 拼接上?前面的部分
buf = append(buf, query[i:i+q]...)
i += q
arg := args[argPos]
argPos++
// 对应arg如果是nil,拼接NULL
if arg == nil {
buf = append(buf, "NULL"...)
continue
}
switch v := arg.(type) {
// 参数为其他类型
// 重点在这,如果对应arg是字符串,在其前后拼接上 '
case string:
buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeStringBackslash(buf, v)
} else {
// 将v写进去
buf = escapeStringQuotes(buf, v)
}
buf = append(buf, '\'')
default:
return "", driver.ErrSkip
}
if len(buf)+4 > mc.maxAllowedPacket {
return "", driver.ErrSkip
}
}
if argPos != len(args) {
return "", driver.ErrSkip
}
return string(buf), nil
}
服务端预编译
生成prepareStmt
通过mysqlConn.Prepare得到
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.closed.Load() {
mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
// 将sql模板发往mysql server
err := mc.writeCommandPacketStr(comStmtPrepare, query)
if err != nil {
// STMT_PREPARE is safe to retry. So we can return ErrBadConn here.
mc.log(err)
return nil, driver.ErrBadConn
}
// 构造mysqlStmt实例
stmt := &mysqlStmt{
mc: mc,
}
// 读取prepare响应,主要获得statmentId,由mysql server生成返回
columnCount, err := stmt.readPrepareResultPacket()
if err == nil {
if stmt.paramCount > 0 {
if err = mc.readUntilEOF(); err != nil {
return nil, err
}
}
if columnCount > 0 {
err = mc.readUntilEOF()
}
}
return stmt, err
}
mysqlStmt如下:
type mysqlStmt struct {
// 关联的mysqlConn
mc *mysqlConn
// mysql server返回的stmtId
id uint32
// 参数个数
paramCount int
}
读prepare的响应,获得statmentId
func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
// 都一个报文
data, err := stmt.mc.readPacket()
if err == nil {
// packet indicator [1 byte]
if data[0] != iOK {
return 0, stmt.mc.handleErrorPacket(data)
}
// 前4个字节为 statement id,保存到stmt实例里
stmt.id = binary.LittleEndian.Uint32(data[1:5])
// Column count [16 bit uint]
columnCount := binary.LittleEndian.Uint16(data[5:7])
// 保存sql模板中的参数个数,用于driverStmt.NumInput方法
stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))
return columnCount, nil
}
return 0, err
}
拿到stmt后,看看怎么基于stmt执行query和exec操作
执行查询操作
mysqlStmt实现driver.StmtQueryContext接口方法,只用传参数即可
func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
// 将ctx通过channel推给后台go,监听ctx的Done
if err := stmt.mc.watchCancel(ctx); err != nil {
return nil, err
}
rows, err := stmt.query(dargs)
if err != nil {
stmt.mc.finish()
return nil, err
}
rows.finish = stmt.mc.finish
return rows, err
}
mysqlStmt.query
- 将参数发往mysql server
- 读取响应,列的信息,这部分同普通的query流程
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
if stmt.mc.closed.Load() {
stmt.mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// 将参数args发往mysql server
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, stmt.mc.markBadConn(err)
}
mc := stmt.mc
// Read Result
handleOk := stmt.mc.clearResult()
// 读取列的个数
resLen, err := handleOk.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
rows := new(binaryRows)
if resLen > 0 {
rows.mc = mc
// 读取列的元数据信息
rows.rs.columns, err = mc.readColumns(resLen)
} else {
rows.rs.done = true
switch err := rows.NextResultSet(); err {
case nil, io.EOF:
return rows, nil
default:
return nil, err
}
}
return rows, err
}
writeExecutePacket发送参数:
除了报文头部4个字节加操作类型1个字节外,接下来写4个字节的stmtId
然后往字节数组里写参数,对于每个参数来说,先写数据类型,再写数据的值
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// 参数个数必须要和预编译返回的参数个数一致
if len(args) != stmt.paramCount {
return fmt.Errorf(
"argument count mismatch (got: %d; has: %d)",
len(args),
stmt.paramCount,
)
}
const minPktLen = 4 + 1 + 4 + 1 + 4
mc := stmt.mc
// ...
// Reset packet-sequence
mc.sequence = 0
var data []byte
var err error
// 复用buffer
if len(args) == 0 {
data, err = mc.buf.takeBuffer(minPktLen)
} else {
data, err = mc.buf.takeCompleteBuffer()
// In this case the len(data) == cap(data) which is used to optimise the flow below.
}
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.log(err)
return errBadConnNoWrite
}
// 指令类型为stmtExec
data[4] = comStmtExecute
// 接下来4个字节设为statmentId
data[5] = byte(stmt.id)
data[6] = byte(stmt.id >> 8)
data[7] = byte(stmt.id >> 16)
data[8] = byte(stmt.id >> 24)
// 接下来设一些固定值
// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
data[9] = 0x00
// iteration_count (uint32(1)) [4 bytes]
data[10] = 0x01
data[11] = 0x00
data[12] = 0x00
data[13] = 0x00
// 往data里填充args
if len(args) > 0 {
// ...
for i, arg := range args {
// 先填充每个arg的类型,再填充每个arg的值
}
// ...
}
return mc.writePacket(data)
}
解析响应结果和普通query流程相同,这里不再分析
执行Exec操作
exec操作和query类型,类似只用传stmtId
和参数,解析响应的流程和普通exec类似,这里不再分析
事务
sql标准库定义了driver.Tx,供各个驱动实现:
type Tx interface {
Commit() error
Rollback() error
}
mysql驱动的实现为:
type mysqlTx struct {
mc *mysqlConn
}
mysqlConn调Begin生成一个driver.Tx
:就是往mysql server发送START TRANSACTION
指令
func (mc *mysqlConn) Begin() (driver.Tx, error) {
return mc.begin(false)
}
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
if mc.closed.Load() {
mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn
}
var q string
if readOnly {
q = "START TRANSACTION READ ONLY"
} else {
q = "START TRANSACTION"
}
// 将开始事务命令发给mysql server
err := mc.exec(q)
if err == nil {
// 封装成mysqlTx
return &mysqlTx{mc}, err
}
return nil, mc.markBadConn(err)
}
Commit和Rollback方法如下:
就是发送COMMIT
和ROLLBACK
命令,之后持有的mysqlConn置为空
func (tx *mysqlTx) Commit() (err error) {
if tx.mc == nil || tx.mc.closed.Load() {
return ErrInvalidConn
}
err = tx.mc.exec("COMMIT")
tx.mc = nil
return
}
func (tx *mysqlTx) Rollback() (err error) {
if tx.mc == nil || tx.mc.closed.Load() {
return ErrInvalidConn
}
err = tx.mc.exec("ROLLBACK")
tx.mc = nil
return
}
下面串一下驱动实现的driver.Tx,怎么和sql.Tx配合,实现增删改查操作
我们知道,执行事务上的各种操作一定要在某个固定的连接上,也就是调Begin的那个连接,那sql包是怎么实现的呢?答案就在sql.Tx
上
- 开启事务时,调驱动用连接dc.ci创建driver.Tx
- 将连接,driver.Tx绑定到sql.Tx返回
- 之后在Tx上执行exec,query操作,都用Tx上绑定的连接
在db.BeginTx时:
func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
var tx *Tx
var err error
err = db.retry(func(strategy connReuseStrategy) error {
tx, err = db.begin(ctx, opts, strategy)
return err
})
return tx, err
}
DB.begin
func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) {
// 从连接池获取一个连接·
dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
return db.beginDC(ctx, dc, dc.releaseConn, opts)
}
DB.beginDC
func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), opts *TxOptions) (tx *Tx, err error) {
var txi driver.Tx
keepConnOnRollback := false
withLock(dc, func() {
_, hasSessionResetter := dc.ci.(driver.SessionResetter)
_, hasConnectionValidator := dc.ci.(driver.Validator)
keepConnOnRollback = hasSessionResetter && hasConnectionValidator
// 调驱动用连接dc.ci创建driver.Tx
txi, err = ctxDriverBegin(ctx, opts, dc.ci)
})
if err != nil {
release(err)
return nil, err
}
// 将连接dc,驱动的事务txi绑定到sql.Tx上返回
tx = &Tx{
db: db,
dc: dc,
releaseConn: release,
txi: txi,
cancel: cancel,
keepConnOnRollback: keepConnOnRollback,
ctx: ctx,
}
go tx.awaitDone()
return tx, nil
}
上层sql.Tx
在执行Exec,Query操作时,使用调Begin的那个连接,也就是和事务绑定的连接
例如:在事务Tx上执行ExecContext
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
// 获得一个连接
dc, release, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
// 用连接dc执行exec
return tx.db.execDC(ctx, dc, release, query, args)
}
grabConn:返回Tx上绑定的那个连接
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) {
// ...
// 返回tx.dc
return tx.dc, tx.closemuRUnlockRelease, nil
}
读取响应
query响应
标准库sql/driver中定义了返回结果的接口Rows:
type Rows interface {
// 查询结果的列名
Columns() []string
Close() error
// 将下一行的数据读到dest中,返回io.EOF代表没数据了
Next(dest []Value) error
}
mysql实现的Rows如下:
type mysqlRows struct {
mc *mysqlConn
rs resultSet
finish func()
}
type binaryRows struct {
mysqlRows
}
type textRows struct {
mysqlRows
}
resultSet包含了列的元数据信息,列名信息
type resultSet struct {
columns []mysqlField
columnNames []string
done bool
}
Colunms的实现:
之前读取query响应时,已经将columns信息解析到rows.rs.columns
中了,这里只用提取columnNames
func (rows *mysqlRows) Columns() []string {
// 已经解析过columnNames
if rows.rs.columnNames != nil {
return rows.rs.columnNames
}
columns := make([]string, len(rows.rs.columns))
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
// ...
} else {
// 将columns中的name提取到columnNames中
for i := range columns {
columns[i] = rows.rs.columns[i].name
}
}
rows.rs.columnNames = columns
return columns
}
Close的实现:
上层sql.Rows在close方法中会释放连接,驱动这一层的close主要将缓冲区的数据丢弃
func (rows *mysqlRows) Close() (err error) {
if f := rows.finish; f != nil {
f()
rows.finish = nil
}
mc := rows.mc
if mc == nil {
return nil
}
if err := mc.error(); err != nil {
return err
}
// Remove unread packets from stream
if !rows.rs.done {
err = mc.readUntilEOF()
}
if err == nil {
handleOk := mc.clearResult()
// 将缓冲区剩余的数据读完
if err = handleOk.discardResults(); err != nil {
return err
}
}
// 将关联的连接置空
rows.mc = nil
return err
}
Next的实现:
分为textRows和binaryRows
- binaryRows:用于在预编译模式下接受查询结果
- textRows:用于非预编译模式下接受查询结果
我们看textRows:将数据从缓冲区读到dest中
func (rows *textRows) Next(dest []driver.Value) error {
if mc := rows.mc; mc != nil {
if err := mc.error(); err != nil {
return err
}
// Fetch next row from stream
return rows.readRow(dest)
}
return io.EOF
}
textRows.readRows:
- 从连接中读一个完整的包
- 根据dest的顺序,依次读取每个字符串,根据
rows.rs.columns
的类型,将字符串转换为对应的类型,放到dest中
func (rows *textRows) readRow(dest []driver.Value) error {
mc := rows.mc
if rows.rs.done {
return io.EOF
}
// 读一个包,一次性把这个包的数据都读完
data, err := mc.readPacket()
if err != nil {
return err
}
// 读到EOF了,返回
if data[0] == iEOF && len(data) == 5 {
// server_status [2 bytes]
rows.mc.status = readStatus(data[3:])
rows.rs.done = true
if !rows.HasNextResultSet() {
rows.mc = nil
}
return io.EOF
}
if data[0] == iERR {
rows.mc = nil
return mc.handleErrorPacket(data)
}
// RowSet Packet
var (
n int
isNull bool
pos int = 0
)
// 依次读各个列
for i := range dest {
// Read bytes and convert to string
var buf []byte
// 读一个字符串
buf, isNull, n, err = readLengthEncodedString(data[pos:])
pos += n
if err != nil {
return err
}
if isNull {
dest[i] = nil
continue
}
// 根据不同的类型,解析成不同的列
switch rows.rs.columns[i].fieldType {
case fieldTypeTimestamp,
fieldTypeDateTime,
fieldTypeDate,
fieldTypeNewDate:
if mc.parseTime {
dest[i], err = parseDateTime(buf, mc.cfg.Loc)
} else {
dest[i] = buf
}
case fieldTypeTiny, fieldTypeShort, fieldTypeInt24, fieldTypeYear, fieldTypeLong:
dest[i], err = strconv.ParseInt(string(buf), 10, 64)
case fieldTypeLongLong:
if rows.rs.columns[i].flags&flagUnsigned != 0 {
dest[i], err = strconv.ParseUint(string(buf), 10, 64)
} else {
dest[i], err = strconv.ParseInt(string(buf), 10, 64)
}
case fieldTypeFloat:
var d float64
d, err = strconv.ParseFloat(string(buf), 32)
dest[i] = float32(d)
case fieldTypeDouble:
dest[i], err = strconv.ParseFloat(string(buf), 64)
default:
dest[i] = buf
}
if err != nil {
return err
}
}
return nil
}
binaryRows.readRows原理类似,这里不再做分析
exec响应
在driver.Result定义了驱动应该实现的接口:
type Result interface {
LastInsertId() (int64, error)
RowsAffected() (int64, error)
}
mysql驱动的实现为:
type mysqlResult struct {
// One entry in both slices is created for every executed statement result.
affectedRows []int64
insertIds []int64
}
被mysqlConn持有:
type mysqlConn struct {
buf buffer
netConn net.Conn
rawConn net.Conn
// 持有mysqlResult
result mysqlResult
}
实现接口的两个方法:
func (res *mysqlResult) LastInsertId() (int64, error) {
return res.insertIds[len(res.insertIds)-1], nil
}
func (res *mysqlResult) RowsAffected() (int64, error) {
return res.affectedRows[len(res.affectedRows)-1], nil
}
那这两个值啥时候塞到mysqlConn.result里呢?
执行完exec,读取响应时:
func (mc *okHandler) handleOkPacket(data []byte) error {
var n, m int
var affectedRows, insertId uint64
// 影响行数,n代表了几个字节
affectedRows, _, n = readLengthEncodedInteger(data[1:])
// insertId,m代表读了几个字节
insertId, _, m = readLengthEncodedInteger(data[1+n:])
if len(mc.result.affectedRows) > 0 {
// 保存affectedRows
mc.result.affectedRows[len(mc.result.affectedRows)-1] = int64(affectedRows)
}
if len(mc.result.insertIds) > 0 {
// 保存insertId
mc.result.insertIds[len(mc.result.insertIds)-1] = int64(insertId)
}
// server_status [2 bytes]
mc.status = readStatus(data[1+n+m : 1+n+m+2])
if mc.status&statusMoreResultsExists != 0 {
return nil
}
// warning count [2 bytes]
return nil
}
总结
至此,关于mysql驱动的源码本文已经分析完毕。下一篇文章分析gorm的orm架构设计以及实现原理