Files
qfixpt/quickfix/store/sql/sql_store.go
2026-03-12 12:14:13 -03:00

434 lines
13 KiB
Go

// Copyright (c) quickfixengine.org All rights reserved.
//
// This file may be distributed under the terms of the quickfixengine.org
// license as defined by quickfixengine.org and appearing in the file
// LICENSE included in the packaging of this file.
//
// This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING
// THE WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A
// PARTICULAR PURPOSE.
//
// See http://www.quickfixengine.org/LICENSE for licensing information.
//
// Contact ask@quickfixengine.org if any conditions of this licensing
// are not clear to you.
package sql
import (
"database/sql"
"fmt"
"regexp"
"time"
"github.com/pkg/errors"
"quantex.com/qfixpt/quickfix"
"quantex.com/qfixpt/quickfix/config"
)
const (
defaultMessagesTable = "messages"
defaultSessionsTable = "sessions"
)
type sqlStoreFactory struct {
settings *quickfix.Settings
}
type sqlStore struct {
sessionID quickfix.SessionID
cache quickfix.MessageStore
sqlDriver string
sqlDataSourceName string
sqlConnMaxLifetime time.Duration
db *sql.DB
placeholder placeholderFunc
messagesTable string
sessionsTable string
sqlUpdateSeqNums string
sqlInsertSession string
sqlGetSeqNums string
sqlUpdateMessage string
sqlInsertMessage string
sqlGetMessages string
sqlUpdateSession string
sqlUpdateSenderSeqNum string
sqlUpdateTargetSeqNum string
sqlDeleteMessages string
}
type placeholderFunc func(int) string
var rePlaceholder = regexp.MustCompile(`\?`)
func sqlString(raw string, placeholder placeholderFunc) string {
if placeholder == nil {
return raw
}
idx := 0
return rePlaceholder.ReplaceAllStringFunc(raw, func(_ string) string {
p := placeholder(idx)
idx++
return p
})
}
func postgresPlaceholder(i int) string {
return fmt.Sprintf("$%d", i+1)
}
// NewStoreFactory returns a sql-based implementation of MessageStoreFactory.
func NewStoreFactory(settings *quickfix.Settings) quickfix.MessageStoreFactory {
return sqlStoreFactory{settings: settings}
}
// Create creates a new SQLStore implementation of the MessageStore interface.
func (f sqlStoreFactory) Create(sessionID quickfix.SessionID) (msgStore quickfix.MessageStore, err error) {
globalSettings := f.settings.GlobalSettings()
dynamicSessions, _ := globalSettings.BoolSetting(config.DynamicSessions)
sessionSettings, ok := f.settings.SessionSettings()[sessionID]
if !ok {
if dynamicSessions {
sessionSettings = globalSettings
} else {
return nil, fmt.Errorf("unknown session: %v", sessionID)
}
}
sqlDriver, err := sessionSettings.Setting(config.SQLStoreDriver)
if err != nil {
return nil, err
}
sqlDataSourceName, err := sessionSettings.Setting(config.SQLStoreDataSourceName)
if err != nil {
return nil, err
}
messagesTableName := defaultMessagesTable
if name, err := sessionSettings.Setting(config.SQLStoreMessagesTableName); err == nil {
messagesTableName = name
}
sessionsTableName := defaultSessionsTable
if name, err := sessionSettings.Setting(config.SQLStoreSessionsTableName); err == nil {
sessionsTableName = name
}
sqlConnMaxLifetime := 0 * time.Second
if sessionSettings.HasSetting(config.SQLStoreConnMaxLifetime) {
sqlConnMaxLifetime, err = sessionSettings.DurationSetting(config.SQLStoreConnMaxLifetime)
if err != nil {
return nil, err
}
}
return newSQLStore(sessionID, sqlDriver, sqlDataSourceName, messagesTableName, sessionsTableName, sqlConnMaxLifetime)
}
func newSQLStore(sessionID quickfix.SessionID, driver, dataSourceName, messagesTableName, sessionsTableName string, connMaxLifetime time.Duration) (store *sqlStore, err error) {
memStore, memErr := quickfix.NewMemoryStoreFactory().Create(sessionID)
if memErr != nil {
err = errors.Wrap(memErr, "cache creation")
return
}
store = &sqlStore{
sessionID: sessionID,
cache: memStore,
sqlDriver: driver,
sqlDataSourceName: dataSourceName,
sqlConnMaxLifetime: connMaxLifetime,
messagesTable: messagesTableName,
sessionsTable: sessionsTableName,
}
if err = store.cache.Reset(); err != nil {
err = errors.Wrap(err, "cache reset")
return
}
if store.sqlDriver == "postgres" || store.sqlDriver == "pgx" {
store.placeholder = postgresPlaceholder
}
if store.db, err = sql.Open(store.sqlDriver, store.sqlDataSourceName); err != nil {
return nil, err
}
store.db.SetConnMaxLifetime(store.sqlConnMaxLifetime)
if err = store.db.Ping(); err != nil { // ensure immediate connection
return nil, err
}
store.setSQLStatements()
if err = store.populateCache(); err != nil {
return nil, err
}
return store, nil
}
func (store *sqlStore) setSQLStatements() {
idColumns := `beginstring, session_qualifier, sendercompid, sendersubid, senderlocid, targetcompid, targetsubid, targetlocid`
idPlaceholders := `?,?,?,?,?,?,?,?`
idWhereClause := `beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? AND targetcompid=? AND targetsubid=? AND targetlocid=?`
store.sqlInsertMessage = fmt.Sprintf(`INSERT INTO %s (
msgseqnum, message, %s) VALUES (?, ?, %s)`,
store.messagesTable, idColumns, idPlaceholders)
store.sqlUpdateMessage = fmt.Sprintf(`UPDATE %s SET message=? WHERE %s AND msgseqnum=?`,
store.messagesTable, idWhereClause)
store.sqlGetMessages = fmt.Sprintf(`SELECT message FROM %s WHERE %s AND msgseqnum>=? AND msgseqnum<=? ORDER BY msgseqnum`,
store.messagesTable, idWhereClause)
store.sqlDeleteMessages = fmt.Sprintf(`DELETE FROM %s WHERE %s`,
store.messagesTable, idWhereClause)
store.sqlInsertSession = fmt.Sprintf(`INSERT INTO %s (
creation_time, incoming_seqnum, outgoing_seqnum, %s) VALUES (?, ?, ?, %s)`,
store.sessionsTable, idColumns, idPlaceholders)
store.sqlGetSeqNums = fmt.Sprintf(`SELECT creation_time, incoming_seqnum, outgoing_seqnum FROM %s WHERE %s`,
store.sessionsTable, idWhereClause)
store.sqlUpdateSession = fmt.Sprintf(`UPDATE %s SET creation_time=?, incoming_seqnum=?, outgoing_seqnum=? WHERE %s`,
store.sessionsTable, idWhereClause)
store.sqlUpdateSenderSeqNum = fmt.Sprintf(`UPDATE %s SET outgoing_seqnum=? WHERE %s`,
store.sessionsTable, idWhereClause)
store.sqlUpdateTargetSeqNum = fmt.Sprintf(`UPDATE %s SET incoming_seqnum=? WHERE %s`,
store.sessionsTable, idWhereClause)
store.sqlUpdateSeqNums = fmt.Sprintf(`UPDATE %s SET incoming_seqnum=?, outgoing_seqnum=? WHERE %s`,
store.sessionsTable, idWhereClause)
}
// Reset deletes the store records and sets the seqnums back to 1.
func (store *sqlStore) Reset() error {
s := store.sessionID
_, err := store.db.Exec(sqlString(store.sqlDeleteMessages, store.placeholder),
s.BeginString, s.Qualifier,
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
if err != nil {
return err
}
if err = store.cache.Reset(); err != nil {
return err
}
_, err = store.db.Exec(sqlString(store.sqlUpdateSession, store.placeholder),
store.cache.CreationTime(), store.cache.NextTargetMsgSeqNum(), store.cache.NextSenderMsgSeqNum(),
s.BeginString, s.Qualifier,
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
return err
}
// Refresh reloads the store from the database.
func (store *sqlStore) Refresh() error {
if err := store.cache.Reset(); err != nil {
return err
}
return store.populateCache()
}
func (store *sqlStore) populateCache() error {
s := store.sessionID
var creationTime time.Time
var incomingSeqNum, outgoingSeqNum int
row := store.db.QueryRow(sqlString(store.sqlGetSeqNums, store.placeholder),
s.BeginString, s.Qualifier,
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
err := row.Scan(&creationTime, &incomingSeqNum, &outgoingSeqNum)
// session record found, load it
if err == nil {
store.cache.SetCreationTime(creationTime)
if err = store.cache.SetNextTargetMsgSeqNum(incomingSeqNum); err != nil {
return errors.Wrap(err, "cache set next target")
}
if err = store.cache.SetNextSenderMsgSeqNum(outgoingSeqNum); err != nil {
return errors.Wrap(err, "cache set next sender")
}
return nil
}
// fatal error, give up
if err != sql.ErrNoRows {
return err
}
// session record not found, create it
_, err = store.db.Exec(sqlString(store.sqlInsertSession, store.placeholder),
store.cache.CreationTime(),
store.cache.NextTargetMsgSeqNum(),
store.cache.NextSenderMsgSeqNum(),
s.BeginString, s.Qualifier,
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
return err
}
// NextSenderMsgSeqNum returns the next MsgSeqNum that will be sent.
func (store *sqlStore) NextSenderMsgSeqNum() int {
return store.cache.NextSenderMsgSeqNum()
}
// NextTargetMsgSeqNum returns the next MsgSeqNum that should be received.
func (store *sqlStore) NextTargetMsgSeqNum() int {
return store.cache.NextTargetMsgSeqNum()
}
// SetNextSenderMsgSeqNum sets the next MsgSeqNum that will be sent.
func (store *sqlStore) SetNextSenderMsgSeqNum(next int) error {
s := store.sessionID
_, err := store.db.Exec(sqlString(store.sqlUpdateSenderSeqNum, store.placeholder),
next, s.BeginString, s.Qualifier,
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
if err != nil {
return err
}
return store.cache.SetNextSenderMsgSeqNum(next)
}
// SetNextTargetMsgSeqNum sets the next MsgSeqNum that should be received.
func (store *sqlStore) SetNextTargetMsgSeqNum(next int) error {
s := store.sessionID
_, err := store.db.Exec(sqlString(store.sqlUpdateTargetSeqNum, store.placeholder),
next, s.BeginString, s.Qualifier,
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
if err != nil {
return err
}
return store.cache.SetNextTargetMsgSeqNum(next)
}
// IncrNextSenderMsgSeqNum increments the next MsgSeqNum that will be sent.
func (store *sqlStore) IncrNextSenderMsgSeqNum() error {
if err := store.SetNextSenderMsgSeqNum(store.cache.NextSenderMsgSeqNum() + 1); err != nil {
return errors.Wrap(err, "store next")
}
return nil
}
// IncrNextTargetMsgSeqNum increments the next MsgSeqNum that should be received.
func (store *sqlStore) IncrNextTargetMsgSeqNum() error {
if err := store.SetNextTargetMsgSeqNum(store.cache.NextTargetMsgSeqNum() + 1); err != nil {
return errors.Wrap(err, "store next")
}
return nil
}
// CreationTime returns the creation time of the store.
func (store *sqlStore) CreationTime() time.Time {
return store.cache.CreationTime()
}
// SetCreationTime is a no-op for SQLStore.
func (store *sqlStore) SetCreationTime(_ time.Time) {
}
func (store *sqlStore) SaveMessage(seqNum int, msg []byte) error {
s := store.sessionID
_, err := store.db.Exec(sqlString(store.sqlInsertMessage, store.placeholder),
seqNum, string(msg),
s.BeginString, s.Qualifier,
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
return err
}
func (store *sqlStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []byte) error {
s := store.sessionID
tx, err := store.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
_, err = tx.Exec(sqlString(store.sqlInsertMessage, store.placeholder),
seqNum, string(msg),
s.BeginString, s.Qualifier,
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
if err != nil {
return err
}
next := store.cache.NextSenderMsgSeqNum() + 1
_, err = tx.Exec(sqlString(store.sqlUpdateSenderSeqNum, store.placeholder),
next, s.BeginString, s.Qualifier,
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
}
return store.cache.SetNextSenderMsgSeqNum(next)
}
func (store *sqlStore) IterateMessages(beginSeqNum, endSeqNum int, cb func([]byte) error) error {
s := store.sessionID
rows, err := store.db.Query(sqlString(store.sqlGetMessages, store.placeholder),
s.BeginString, s.Qualifier,
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
s.TargetCompID, s.TargetSubID, s.TargetLocationID,
beginSeqNum, endSeqNum)
if err != nil {
return err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var message string
if err = rows.Scan(&message); err != nil {
return err
} else if err = cb([]byte(message)); err != nil {
return err
}
}
return rows.Err()
}
func (store *sqlStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) {
var msgs [][]byte
err := store.IterateMessages(beginSeqNum, endSeqNum, func(msg []byte) error {
msgs = append(msgs, msg)
return nil
})
return msgs, err
}
// Close closes the store's database connection.
func (store *sqlStore) Close() error {
if store.db != nil {
store.db.Close()
store.db = nil
}
return nil
}