434 lines
9.3 KiB
Go
434 lines
9.3 KiB
Go
|
// Package godrv implements a MySQL driver for Go's database/sql package.
|
||
|
package godrv
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"database/sql/driver"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"github.com/ziutek/mymysql/mysql"
|
||
|
"github.com/ziutek/mymysql/native"
|
||
|
"io"
|
||
|
"net"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
type conn struct {
|
||
|
my mysql.Conn
|
||
|
}
|
||
|
|
||
|
type rowsRes struct {
|
||
|
row mysql.Row
|
||
|
my mysql.Result
|
||
|
simpleQuery mysql.Stmt
|
||
|
}
|
||
|
|
||
|
// errFilter maps errors that indicate a broken connection to
// driver.ErrBadConn, so database/sql can retry on a fresh connection.
//
// An unexpected EOF while reading, or any net.Error, is treated as a bad
// connection. Every other error, including nil, is returned unchanged.
func errFilter(err error) error {
	if _, isNetErr := err.(net.Error); isNetErr || err == io.ErrUnexpectedEOF {
		return driver.ErrBadConn
	}
	return err
}
|
||
|
|
||
|
// join concatenates the strings in a with no separator between them.
// A nil or empty slice yields "".
func join(a []string) string {
	// strings.Join sizes the result once and copies each piece, which is
	// exactly what the previous hand-written loop did.
	return strings.Join(a, "")
}
|
||
|
|
||
|
// parseQuery interpolates args into the '?' placeholders of query so that it
// can be sent to the server as a plain text query.
//
// Return values:
//   - query unchanged, nil: there are no args.
//   - "", nil: query contains a single or double quote. A '?' may then be
//     part of a string literal, so client-side substitution is not
//     attempted. Exec/Query translate the empty result into driver.ErrSkip.
//   - "", error: query has fewer placeholders than args, or an argument has
//     a type that cannot be rendered as SQL.
//
// Strings and byte slices are escaped with the connection's Escape and
// single-quoted. time.Time values are formatted with mysql.TimeFormat.
func (c conn) parseQuery(query string, args []driver.Value) (string, error) {
	if len(args) == 0 {
		return query, nil
	}
	if strings.ContainsAny(query, `'"`) {
		return "", nil
	}
	// Alternating pieces: text before a placeholder, rendered argument, ...,
	// trailing text.
	q := make([]string, 2*len(args)+1)
	n := 0
	for _, a := range args {
		i := strings.IndexRune(query, '?')
		if i == -1 {
			return "", errors.New("number of parameters doesn't match number of placeholders")
		}
		var s string
		switch v := a.(type) {
		case nil:
			s = "NULL"
		case string:
			s = "'" + c.my.Escape(v) + "'"
		case []byte:
			s = "'" + c.my.Escape(string(v)) + "'"
		case int64:
			s = strconv.FormatInt(v, 10)
		case time.Time:
			s = "'" + v.Format(mysql.TimeFormat) + "'"
		case bool:
			if v {
				s = "1"
			} else {
				s = "0"
			}
		case float64:
			// Precision -1 yields the shortest representation that parses
			// back to exactly v. A fixed precision of 12 silently dropped
			// significant digits of float64 arguments.
			s = strconv.FormatFloat(v, 'e', -1, 64)
		default:
			// Report unsupported types as an error instead of panicking in
			// library code. The message now includes the offending value
			// and its type.
			return "", fmt.Errorf("%v (%T) can't be handled by godrv", v, v)
		}
		q[n] = query[:i]
		q[n+1] = s
		query = query[i+1:]
		n += 2
	}
	q[n] = query
	return join(q), nil
}
|
||
|
|
||
|
// Exec interpolates args into query (see parseQuery) and runs it directly
// on the connection.
//
// When the query cannot be interpolated client-side, Exec returns
// driver.ErrSkip so that database/sql takes its non-Execer path (a prepared
// statement). Network-level failures are reported as driver.ErrBadConn.
func (c conn) Exec(query string, args []driver.Value) (driver.Result, error) {
	sqlText, parseErr := c.parseQuery(query, args)
	switch {
	case parseErr != nil:
		return nil, parseErr
	case sqlText == "":
		return nil, driver.ErrSkip
	}
	result, startErr := c.my.Start(sqlText)
	if startErr != nil {
		return nil, errFilter(startErr)
	}
	return &rowsRes{my: result}, nil
}
|
||
|
|
||
|
// textQuery is a sentinel statement stored in rowsRes.simpleQuery for results
// produced by conn.Query. Next uses it to recognise text-protocol results.
// Close and Next never Delete it.
var textQuery mysql.Stmt = new(native.Stmt)
|
||
|
|
||
|
// Query interpolates args into query (see parseQuery) and runs it as a text
// query. The returned rows are tagged with textQuery so Next can apply the
// temporal-type workaround.
//
// When the query cannot be interpolated client-side, Query returns
// driver.ErrSkip so that database/sql falls back to a prepared statement.
func (c conn) Query(query string, args []driver.Value) (driver.Rows, error) {
	sqlText, parseErr := c.parseQuery(query, args)
	if parseErr != nil {
		return nil, parseErr
	}
	if sqlText == "" {
		return nil, driver.ErrSkip
	}
	result, startErr := c.my.Start(sqlText)
	if startErr != nil {
		return nil, errFilter(startErr)
	}
	rows := &rowsRes{my: result, simpleQuery: textQuery}
	rows.row = result.MakeRow()
	return rows, nil
}
|
||
|
|
||
|
type stmt struct {
|
||
|
my mysql.Stmt
|
||
|
args []interface{}
|
||
|
}
|
||
|
|
||
|
// run copies args into the statement's reusable argument buffer and executes
// the prepared statement. Network-level failures become driver.ErrBadConn.
//
// NOTE(review): indexes s.args by position, so len(args) must not exceed
// len(s.args). Presumably database/sql's NumInput check guarantees this.
func (s *stmt) run(args []driver.Value) (*rowsRes, error) {
	for idx := range args {
		s.args[idx] = args[idx]
	}
	result, runErr := s.my.Run(s.args...)
	if runErr != nil {
		return nil, errFilter(runErr)
	}
	return &rowsRes{my: result}, nil
}
|
||
|
|
||
|
// Prepare prepares query on the server. The returned statement carries an
// argument buffer sized to the statement's parameter count.
func (c conn) Prepare(query string) (driver.Stmt, error) {
	prepared, prepErr := c.my.Prepare(query)
	if prepErr != nil {
		return nil, errFilter(prepErr)
	}
	buf := make([]interface{}, prepared.NumParam())
	return &stmt{my: prepared, args: buf}, nil
}
|
||
|
|
||
|
func (c conn) Close() (err error) {
|
||
|
err = c.my.Close()
|
||
|
c.my = nil
|
||
|
if err != nil {
|
||
|
err = errFilter(err)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
type tx struct {
|
||
|
my mysql.Transaction
|
||
|
}
|
||
|
|
||
|
// Begin starts a transaction on the connection.
func (c conn) Begin() (driver.Tx, error) {
	trans, beginErr := c.my.Begin()
	if beginErr != nil {
		return nil, errFilter(beginErr)
	}
	return tx{my: trans}, nil
}
|
||
|
|
||
|
func (t tx) Commit() (err error) {
|
||
|
err = t.my.Commit()
|
||
|
if err != nil {
|
||
|
err = errFilter(err)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (t tx) Rollback() (err error) {
|
||
|
err = t.my.Rollback()
|
||
|
if err != nil {
|
||
|
err = errFilter(err)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (s *stmt) Close() (err error) {
|
||
|
err = s.my.Delete()
|
||
|
s.my = nil
|
||
|
if err != nil {
|
||
|
err = errFilter(err)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (s *stmt) NumInput() int {
|
||
|
return s.my.NumParam()
|
||
|
}
|
||
|
|
||
|
// Exec executes the prepared statement with args and returns its result.
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
	result, err := s.run(args)
	return result, err
}
|
||
|
|
||
|
// Query executes the prepared statement with args and returns its rows. It
// allocates the row buffer that Next scans into.
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
	rows, runErr := s.run(args)
	if runErr == nil {
		rows.row = rows.my.MakeRow()
		return rows, nil
	}
	return nil, runErr
}
|
||
|
|
||
|
func (r *rowsRes) LastInsertId() (int64, error) {
|
||
|
return int64(r.my.InsertId()), nil
|
||
|
}
|
||
|
|
||
|
func (r *rowsRes) RowsAffected() (int64, error) {
|
||
|
return int64(r.my.AffectedRows()), nil
|
||
|
}
|
||
|
|
||
|
func (r *rowsRes) Columns() []string {
|
||
|
flds := r.my.Fields()
|
||
|
cls := make([]string, len(flds))
|
||
|
for i, f := range flds {
|
||
|
cls[i] = f.Name
|
||
|
}
|
||
|
return cls
|
||
|
}
|
||
|
|
||
|
func (r *rowsRes) Close() error {
|
||
|
if r.my == nil {
|
||
|
return nil // closed before
|
||
|
}
|
||
|
if err := r.my.End(); err != nil {
|
||
|
return errFilter(err)
|
||
|
}
|
||
|
if r.simpleQuery != nil && r.simpleQuery != textQuery {
|
||
|
if err := r.simpleQuery.Delete(); err != nil {
|
||
|
return errFilter(err)
|
||
|
}
|
||
|
}
|
||
|
r.my = nil
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// location is the time zone Next uses to convert DATE, DATETIME, TIMESTAMP
// and NEWDATE values from text queries into time.Time. Change it with
// SetLocation.
var location *time.Location = time.Local
|
||
|
|
||
|
// Next scans the next row of the result into dest.
//
// For text-protocol results (rows tagged with textQuery), non-NULL DATE,
// DATETIME, TIMESTAMP and NEWDATE columns are forced to time.Time using the
// package-wide location. That is time.Local unless changed globally with
// SetLocation.
// NOTE(review): results of prepared statements rely on mymysql's own
// conversion; confirm which zone it uses.
//
// Once the rows are exhausted, Next deletes the associated statement (never
// textQuery), marks the result as finished and returns io.EOF. Every later
// call also returns io.EOF.
func (r *rowsRes) Next(dest []driver.Value) error {
	if r.my == nil {
		// Already exhausted or closed.
		return io.EOF
	}

	scanErr := r.my.ScanRow(r.row)
	if scanErr == io.EOF {
		if st := r.simpleQuery; st != nil && st != textQuery {
			if delErr := st.Delete(); delErr != nil {
				return errFilter(delErr)
			}
		}
		r.my = nil
		return io.EOF
	}
	if scanErr != nil {
		return errFilter(scanErr)
	}

	if r.simpleQuery == textQuery {
		// Workaround: temporal columns from text queries must be converted
		// to time.Time explicitly.
		fields := r.my.Fields()
		for idx := range fields {
			if r.row[idx] == nil {
				continue
			}
			switch fields[idx].Type {
			case native.MYSQL_TYPE_TIMESTAMP, native.MYSQL_TYPE_DATETIME,
				native.MYSQL_TYPE_DATE, native.MYSQL_TYPE_NEWDATE:
				r.row[idx] = r.row.ForceTime(idx, location)
			}
		}
	}

	for idx := range r.row {
		dest[idx] = driver.Value(r.row[idx])
	}
	return nil
}
|
||
|
|
||
|
type Driver struct {
|
||
|
// Defaults
|
||
|
proto, laddr, raddr, user, passwd, db string
|
||
|
timeout time.Duration
|
||
|
dialer Dialer
|
||
|
|
||
|
initCmds []string
|
||
|
}
|
||
|
|
||
|
// Open opens a new connection. The uri needs to have the following syntax:
//
//	[PROTOCOL_SPECIFIC*]DBNAME/USER/PASSWD
//
// The protocol-specific part may be empty, which means connecting with the
// driver's default protocol and address. Currently possible forms:
//
//	DBNAME/USER/PASSWD
//	unix:SOCKPATH*DBNAME/USER/PASSWD
//	unix:SOCKPATH,OPTIONS*DBNAME/USER/PASSWD
//	tcp:ADDR*DBNAME/USER/PASSWD
//	tcp:ADDR,OPTIONS*DBNAME/USER/PASSWD
//	cloudsql:INSTANCE*DBNAME/USER/PASSWD
//
// OPTIONS can contain a comma separated list of options in the form:
//
//	opt1=VAL1,opt2=VAL2,boolopt3,boolopt4
//
// Currently implemented options, in addition to default MySQL variables:
//
//	laddr   - local address/port (eg. 1.2.3.4:0)
//	timeout - connect timeout in format accepted by time.ParseDuration
//
// Any other option is registered on the connection as an initialisation
// command "SET opt=VAL". A bare boolopt becomes "SET boolopt=true".
func (drv *Driver) Open(uri string) (driver.Conn, error) {
	cfg := *drv // copy default configuration
	pd := strings.SplitN(uri, "*", 2)
	var connCommands []string
	if len(pd) == 2 {
		// Parse protocol part of URI
		p := strings.SplitN(pd[0], ":", 2)
		if len(p) != 2 {
			return nil, errors.New("Wrong protocol part of URI")
		}
		cfg.proto = p[0]
		options := strings.Split(p[1], ",")
		cfg.raddr = options[0]
		for _, o := range options[1:] {
			kv := strings.SplitN(o, "=", 2)
			k, v := kv[0], "true" // a bare option is a boolean flag
			if len(kv) == 2 {
				v = kv[1]
			}
			switch k {
			case "laddr":
				cfg.laddr = v
			case "timeout":
				to, err := time.ParseDuration(v)
				if err != nil {
					return nil, err
				}
				cfg.timeout = to
			default:
				connCommands = append(connCommands, "SET "+k+"="+v)
			}
		}
		// Remove protocol part
		pd = pd[1:]
	}
	// Parse database part of URI
	dup := strings.SplitN(pd[0], "/", 3)
	if len(dup) != 3 {
		return nil, errors.New("Wrong database part of URI")
	}
	cfg.db = dup[0]
	cfg.user = dup[1]
	cfg.passwd = dup[2]

	c := conn{mysql.New(
		cfg.proto, cfg.laddr, cfg.raddr, cfg.user, cfg.passwd, cfg.db,
	)}
	if cfg.dialer != nil {
		// Use the hook captured in cfg when Open ran. A later SetDialer on
		// the Driver is not observed by this connection.
		dialer := func(proto, laddr, raddr string, timeout time.Duration) (
			net.Conn, error) {
			// Dialer's user/dbname arguments are meant for logging only:
			// pass the database name here, never the password.
			return cfg.dialer(proto, laddr, raddr, cfg.user, cfg.db, timeout)
		}
		c.my.SetDialer(dialer)
	}

	// Establish the connection
	c.my.SetTimeout(cfg.timeout)
	for _, q := range cfg.initCmds {
		c.my.Register(q) // Register initialisation commands
	}
	for _, q := range connCommands {
		c.my.Register(q)
	}
	if err := c.my.Connect(); err != nil {
		return nil, errFilter(err)
	}
	c.my.NarrowTypeSet(true)
	c.my.FullFieldInfo(false)
	return &c, nil
}
|
||
|
|
||
|
// Register registers initialisation commands.
|
||
|
// This is workaround, see http://codereview.appspot.com/5706047
|
||
|
func (drv *Driver) Register(query string) {
|
||
|
drv.initCmds = append(d.initCmds, query)
|
||
|
}
|
||
|
|
||
|
// Dialer is a hook for dialing connections to MySQL. If it returns (nil, nil)
// the hook is skipped and normal dialing proceeds. The user and dbname
// arguments are supplied for logging purposes only.
type Dialer func(
	proto, laddr, raddr, user, dbname string,
	timeout time.Duration,
) (net.Conn, error)
|
||
|
|
||
|
// SetDialer sets custom Dialer used by Driver to make connections
|
||
|
func (drv *Driver) SetDialer(dialer Dialer) {
|
||
|
drv.dialer = dialer
|
||
|
}
|
||
|
|
||
|
// Driver automatically registered in database/sql
|
||
|
var d = Driver{proto: "tcp", raddr: "127.0.0.1:3306"}
|
||
|
|
||
|
// Register calls Register method on driver registered in database/sql
|
||
|
func Register(query string) {
|
||
|
d.Register(query)
|
||
|
}
|
||
|
|
||
|
// SetDialer calls SetDialer method on driver registered in database/sql
|
||
|
func SetDialer(dialer Dialer) {
|
||
|
d.SetDialer(dialer)
|
||
|
}
|
||
|
|
||
|
// init makes every connection of the default driver start with
// "SET NAMES utf8", then registers that driver with database/sql as
// "mymysql".
func init() {
	d.Register("SET NAMES utf8")
	sql.Register("mymysql", &d)
}
|
||
|
|
||
|
// Version returns mymysql version string
|
||
|
func Version() string {
|
||
|
return mysql.Version()
|
||
|
}
|
||
|
|
||
|
// SetLocation changes the time zone used to convert DATE, DATETIME and
// TIMESTAMP values obtained from the server into time.Time.
//
// NOTE(review): this writes a package variable without synchronization;
// presumably it is meant to be called during program initialization.
func SetLocation(loc *time.Location) {
	location = loc
}
|