434 lines
13 KiB
Go
434 lines
13 KiB
Go
// Copyright (c) quickfixengine.org All rights reserved.
|
|
//
|
|
// This file may be distributed under the terms of the quickfixengine.org
|
|
// license as defined by quickfixengine.org and appearing in the file
|
|
// LICENSE included in the packaging of this file.
|
|
//
|
|
// This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING
|
|
// THE WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A
|
|
// PARTICULAR PURPOSE.
|
|
//
|
|
// See http://www.quickfixengine.org/LICENSE for licensing information.
|
|
//
|
|
// Contact ask@quickfixengine.org if any conditions of this licensing
|
|
// are not clear to you.
|
|
|
|
package sql
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"regexp"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
|
|
"quantex.com/qfixpt/quickfix"
|
|
"quantex.com/qfixpt/quickfix/config"
|
|
)
|
|
|
|
const (
|
|
defaultMessagesTable = "messages"
|
|
defaultSessionsTable = "sessions"
|
|
)
|
|
|
|
type sqlStoreFactory struct {
|
|
settings *quickfix.Settings
|
|
}
|
|
|
|
type sqlStore struct {
|
|
sessionID quickfix.SessionID
|
|
cache quickfix.MessageStore
|
|
sqlDriver string
|
|
sqlDataSourceName string
|
|
sqlConnMaxLifetime time.Duration
|
|
db *sql.DB
|
|
placeholder placeholderFunc
|
|
messagesTable string
|
|
sessionsTable string
|
|
|
|
sqlUpdateSeqNums string
|
|
sqlInsertSession string
|
|
sqlGetSeqNums string
|
|
sqlUpdateMessage string
|
|
sqlInsertMessage string
|
|
sqlGetMessages string
|
|
sqlUpdateSession string
|
|
sqlUpdateSenderSeqNum string
|
|
sqlUpdateTargetSeqNum string
|
|
sqlDeleteMessages string
|
|
}
|
|
|
|
type placeholderFunc func(int) string
|
|
|
|
var rePlaceholder = regexp.MustCompile(`\?`)
|
|
|
|
func sqlString(raw string, placeholder placeholderFunc) string {
|
|
if placeholder == nil {
|
|
return raw
|
|
}
|
|
idx := 0
|
|
return rePlaceholder.ReplaceAllStringFunc(raw, func(_ string) string {
|
|
p := placeholder(idx)
|
|
idx++
|
|
return p
|
|
})
|
|
}
|
|
|
|
func postgresPlaceholder(i int) string {
|
|
return fmt.Sprintf("$%d", i+1)
|
|
}
|
|
|
|
// NewStoreFactory returns a sql-based implementation of MessageStoreFactory.
|
|
func NewStoreFactory(settings *quickfix.Settings) quickfix.MessageStoreFactory {
|
|
return sqlStoreFactory{settings: settings}
|
|
}
|
|
|
|
// Create creates a new SQLStore implementation of the MessageStore interface.
|
|
func (f sqlStoreFactory) Create(sessionID quickfix.SessionID) (msgStore quickfix.MessageStore, err error) {
|
|
globalSettings := f.settings.GlobalSettings()
|
|
dynamicSessions, _ := globalSettings.BoolSetting(config.DynamicSessions)
|
|
|
|
sessionSettings, ok := f.settings.SessionSettings()[sessionID]
|
|
if !ok {
|
|
if dynamicSessions {
|
|
sessionSettings = globalSettings
|
|
} else {
|
|
return nil, fmt.Errorf("unknown session: %v", sessionID)
|
|
}
|
|
}
|
|
|
|
sqlDriver, err := sessionSettings.Setting(config.SQLStoreDriver)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sqlDataSourceName, err := sessionSettings.Setting(config.SQLStoreDataSourceName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
messagesTableName := defaultMessagesTable
|
|
if name, err := sessionSettings.Setting(config.SQLStoreMessagesTableName); err == nil {
|
|
messagesTableName = name
|
|
}
|
|
|
|
sessionsTableName := defaultSessionsTable
|
|
if name, err := sessionSettings.Setting(config.SQLStoreSessionsTableName); err == nil {
|
|
sessionsTableName = name
|
|
}
|
|
|
|
sqlConnMaxLifetime := 0 * time.Second
|
|
if sessionSettings.HasSetting(config.SQLStoreConnMaxLifetime) {
|
|
sqlConnMaxLifetime, err = sessionSettings.DurationSetting(config.SQLStoreConnMaxLifetime)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return newSQLStore(sessionID, sqlDriver, sqlDataSourceName, messagesTableName, sessionsTableName, sqlConnMaxLifetime)
|
|
}
|
|
|
|
func newSQLStore(sessionID quickfix.SessionID, driver, dataSourceName, messagesTableName, sessionsTableName string, connMaxLifetime time.Duration) (store *sqlStore, err error) {
|
|
|
|
memStore, memErr := quickfix.NewMemoryStoreFactory().Create(sessionID)
|
|
if memErr != nil {
|
|
err = errors.Wrap(memErr, "cache creation")
|
|
return
|
|
}
|
|
|
|
store = &sqlStore{
|
|
sessionID: sessionID,
|
|
cache: memStore,
|
|
sqlDriver: driver,
|
|
sqlDataSourceName: dataSourceName,
|
|
sqlConnMaxLifetime: connMaxLifetime,
|
|
messagesTable: messagesTableName,
|
|
sessionsTable: sessionsTableName,
|
|
}
|
|
if err = store.cache.Reset(); err != nil {
|
|
err = errors.Wrap(err, "cache reset")
|
|
return
|
|
}
|
|
|
|
if store.sqlDriver == "postgres" || store.sqlDriver == "pgx" {
|
|
store.placeholder = postgresPlaceholder
|
|
}
|
|
|
|
if store.db, err = sql.Open(store.sqlDriver, store.sqlDataSourceName); err != nil {
|
|
return nil, err
|
|
}
|
|
store.db.SetConnMaxLifetime(store.sqlConnMaxLifetime)
|
|
|
|
if err = store.db.Ping(); err != nil { // ensure immediate connection
|
|
return nil, err
|
|
}
|
|
|
|
store.setSQLStatements()
|
|
|
|
if err = store.populateCache(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return store, nil
|
|
}
|
|
|
|
func (store *sqlStore) setSQLStatements() {
|
|
idColumns := `beginstring, session_qualifier, sendercompid, sendersubid, senderlocid, targetcompid, targetsubid, targetlocid`
|
|
idPlaceholders := `?,?,?,?,?,?,?,?`
|
|
idWhereClause := `beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? AND targetcompid=? AND targetsubid=? AND targetlocid=?`
|
|
|
|
store.sqlInsertMessage = fmt.Sprintf(`INSERT INTO %s (
|
|
msgseqnum, message, %s) VALUES (?, ?, %s)`,
|
|
store.messagesTable, idColumns, idPlaceholders)
|
|
|
|
store.sqlUpdateMessage = fmt.Sprintf(`UPDATE %s SET message=? WHERE %s AND msgseqnum=?`,
|
|
store.messagesTable, idWhereClause)
|
|
|
|
store.sqlGetMessages = fmt.Sprintf(`SELECT message FROM %s WHERE %s AND msgseqnum>=? AND msgseqnum<=? ORDER BY msgseqnum`,
|
|
store.messagesTable, idWhereClause)
|
|
|
|
store.sqlDeleteMessages = fmt.Sprintf(`DELETE FROM %s WHERE %s`,
|
|
store.messagesTable, idWhereClause)
|
|
|
|
store.sqlInsertSession = fmt.Sprintf(`INSERT INTO %s (
|
|
creation_time, incoming_seqnum, outgoing_seqnum, %s) VALUES (?, ?, ?, %s)`,
|
|
store.sessionsTable, idColumns, idPlaceholders)
|
|
|
|
store.sqlGetSeqNums = fmt.Sprintf(`SELECT creation_time, incoming_seqnum, outgoing_seqnum FROM %s WHERE %s`,
|
|
store.sessionsTable, idWhereClause)
|
|
|
|
store.sqlUpdateSession = fmt.Sprintf(`UPDATE %s SET creation_time=?, incoming_seqnum=?, outgoing_seqnum=? WHERE %s`,
|
|
store.sessionsTable, idWhereClause)
|
|
|
|
store.sqlUpdateSenderSeqNum = fmt.Sprintf(`UPDATE %s SET outgoing_seqnum=? WHERE %s`,
|
|
store.sessionsTable, idWhereClause)
|
|
|
|
store.sqlUpdateTargetSeqNum = fmt.Sprintf(`UPDATE %s SET incoming_seqnum=? WHERE %s`,
|
|
store.sessionsTable, idWhereClause)
|
|
|
|
store.sqlUpdateSeqNums = fmt.Sprintf(`UPDATE %s SET incoming_seqnum=?, outgoing_seqnum=? WHERE %s`,
|
|
store.sessionsTable, idWhereClause)
|
|
}
|
|
|
|
// Reset deletes the store records and sets the seqnums back to 1.
|
|
func (store *sqlStore) Reset() error {
|
|
s := store.sessionID
|
|
_, err := store.db.Exec(sqlString(store.sqlDeleteMessages, store.placeholder),
|
|
s.BeginString, s.Qualifier,
|
|
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
|
|
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err = store.cache.Reset(); err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = store.db.Exec(sqlString(store.sqlUpdateSession, store.placeholder),
|
|
store.cache.CreationTime(), store.cache.NextTargetMsgSeqNum(), store.cache.NextSenderMsgSeqNum(),
|
|
s.BeginString, s.Qualifier,
|
|
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
|
|
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
|
|
|
|
return err
|
|
}
|
|
|
|
// Refresh reloads the store from the database.
|
|
func (store *sqlStore) Refresh() error {
|
|
if err := store.cache.Reset(); err != nil {
|
|
return err
|
|
}
|
|
return store.populateCache()
|
|
}
|
|
|
|
func (store *sqlStore) populateCache() error {
|
|
s := store.sessionID
|
|
var creationTime time.Time
|
|
var incomingSeqNum, outgoingSeqNum int
|
|
row := store.db.QueryRow(sqlString(store.sqlGetSeqNums, store.placeholder),
|
|
s.BeginString, s.Qualifier,
|
|
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
|
|
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
|
|
|
|
err := row.Scan(&creationTime, &incomingSeqNum, &outgoingSeqNum)
|
|
|
|
// session record found, load it
|
|
if err == nil {
|
|
store.cache.SetCreationTime(creationTime)
|
|
if err = store.cache.SetNextTargetMsgSeqNum(incomingSeqNum); err != nil {
|
|
return errors.Wrap(err, "cache set next target")
|
|
}
|
|
if err = store.cache.SetNextSenderMsgSeqNum(outgoingSeqNum); err != nil {
|
|
return errors.Wrap(err, "cache set next sender")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// fatal error, give up
|
|
if err != sql.ErrNoRows {
|
|
return err
|
|
}
|
|
|
|
// session record not found, create it
|
|
_, err = store.db.Exec(sqlString(store.sqlInsertSession, store.placeholder),
|
|
store.cache.CreationTime(),
|
|
store.cache.NextTargetMsgSeqNum(),
|
|
store.cache.NextSenderMsgSeqNum(),
|
|
s.BeginString, s.Qualifier,
|
|
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
|
|
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
|
|
|
|
return err
|
|
}
|
|
|
|
// NextSenderMsgSeqNum returns the next MsgSeqNum that will be sent.
|
|
func (store *sqlStore) NextSenderMsgSeqNum() int {
|
|
return store.cache.NextSenderMsgSeqNum()
|
|
}
|
|
|
|
// NextTargetMsgSeqNum returns the next MsgSeqNum that should be received.
|
|
func (store *sqlStore) NextTargetMsgSeqNum() int {
|
|
return store.cache.NextTargetMsgSeqNum()
|
|
}
|
|
|
|
// SetNextSenderMsgSeqNum sets the next MsgSeqNum that will be sent.
|
|
func (store *sqlStore) SetNextSenderMsgSeqNum(next int) error {
|
|
s := store.sessionID
|
|
_, err := store.db.Exec(sqlString(store.sqlUpdateSenderSeqNum, store.placeholder),
|
|
next, s.BeginString, s.Qualifier,
|
|
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
|
|
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return store.cache.SetNextSenderMsgSeqNum(next)
|
|
}
|
|
|
|
// SetNextTargetMsgSeqNum sets the next MsgSeqNum that should be received.
|
|
func (store *sqlStore) SetNextTargetMsgSeqNum(next int) error {
|
|
s := store.sessionID
|
|
_, err := store.db.Exec(sqlString(store.sqlUpdateTargetSeqNum, store.placeholder),
|
|
next, s.BeginString, s.Qualifier,
|
|
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
|
|
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return store.cache.SetNextTargetMsgSeqNum(next)
|
|
}
|
|
|
|
// IncrNextSenderMsgSeqNum increments the next MsgSeqNum that will be sent.
|
|
func (store *sqlStore) IncrNextSenderMsgSeqNum() error {
|
|
if err := store.SetNextSenderMsgSeqNum(store.cache.NextSenderMsgSeqNum() + 1); err != nil {
|
|
return errors.Wrap(err, "store next")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// IncrNextTargetMsgSeqNum increments the next MsgSeqNum that should be received.
|
|
func (store *sqlStore) IncrNextTargetMsgSeqNum() error {
|
|
if err := store.SetNextTargetMsgSeqNum(store.cache.NextTargetMsgSeqNum() + 1); err != nil {
|
|
return errors.Wrap(err, "store next")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CreationTime returns the creation time of the store.
|
|
func (store *sqlStore) CreationTime() time.Time {
|
|
return store.cache.CreationTime()
|
|
}
|
|
|
|
// SetCreationTime is a no-op for SQLStore.
|
|
func (store *sqlStore) SetCreationTime(_ time.Time) {
|
|
}
|
|
|
|
func (store *sqlStore) SaveMessage(seqNum int, msg []byte) error {
|
|
s := store.sessionID
|
|
|
|
_, err := store.db.Exec(sqlString(store.sqlInsertMessage, store.placeholder),
|
|
seqNum, string(msg),
|
|
s.BeginString, s.Qualifier,
|
|
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
|
|
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
|
|
|
|
return err
|
|
}
|
|
|
|
func (store *sqlStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []byte) error {
|
|
s := store.sessionID
|
|
|
|
tx, err := store.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
_, err = tx.Exec(sqlString(store.sqlInsertMessage, store.placeholder),
|
|
seqNum, string(msg),
|
|
s.BeginString, s.Qualifier,
|
|
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
|
|
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
next := store.cache.NextSenderMsgSeqNum() + 1
|
|
_, err = tx.Exec(sqlString(store.sqlUpdateSenderSeqNum, store.placeholder),
|
|
next, s.BeginString, s.Qualifier,
|
|
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
|
|
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return store.cache.SetNextSenderMsgSeqNum(next)
|
|
}
|
|
|
|
func (store *sqlStore) IterateMessages(beginSeqNum, endSeqNum int, cb func([]byte) error) error {
|
|
s := store.sessionID
|
|
rows, err := store.db.Query(sqlString(store.sqlGetMessages, store.placeholder),
|
|
s.BeginString, s.Qualifier,
|
|
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
|
|
s.TargetCompID, s.TargetSubID, s.TargetLocationID,
|
|
beginSeqNum, endSeqNum)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
for rows.Next() {
|
|
var message string
|
|
if err = rows.Scan(&message); err != nil {
|
|
return err
|
|
} else if err = cb([]byte(message)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return rows.Err()
|
|
}
|
|
|
|
func (store *sqlStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) {
|
|
var msgs [][]byte
|
|
err := store.IterateMessages(beginSeqNum, endSeqNum, func(msg []byte) error {
|
|
msgs = append(msgs, msg)
|
|
return nil
|
|
})
|
|
return msgs, err
|
|
}
|
|
|
|
// Close closes the store's database connection.
|
|
func (store *sqlStore) Close() error {
|
|
if store.db != nil {
|
|
store.db.Close()
|
|
store.db = nil
|
|
}
|
|
return nil
|
|
}
|