Files
qfixpt/quickfix/acceptor.go
2026-03-12 12:14:13 -03:00

453 lines
13 KiB
Go

// Copyright (c) quickfixengine.org All rights reserved.
//
// This file may be distributed under the terms of the quickfixengine.org
// license as defined by quickfixengine.org and appearing in the file
// LICENSE included in the packaging of this file.
//
// This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING
// THE WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A
// PARTICULAR PURPOSE.
//
// See http://www.quickfixengine.org/LICENSE for licensing information.
//
// Contact ask@quickfixengine.org if any conditions of this licensing
// are not clear to you.
package quickfix
import (
"bufio"
"bytes"
"crypto/tls"
"io"
"net"
"runtime/debug"
"strconv"
"sync"
proxyproto "github.com/pires/go-proxyproto"
"quantex.com/qfixpt/quickfix/config"
)
// Acceptor accepts connections from FIX clients and manages the associated sessions.
//
// Create one with NewAcceptor, call Start to open the listeners and run the sessions,
// and Stop to shut everything down.
type Acceptor struct {
// app is passed to every session this acceptor creates, configured or dynamic.
app Application
// settings supplies the global and per-session configuration read by NewAcceptor and Start.
settings *Settings
// logFactory and storeFactory are passed to createSession for every session.
logFactory LogFactory
storeFactory MessageStoreFactory
// globalLog receives acceptor-level events that are not tied to a running session
// (rejected or invalid connections, recovered panics, unregister failures).
globalLog Log
// sessions holds the configured sessions, keyed by SessionID with the Qualifier cleared.
sessions map[SessionID]*session
// sessionGroup counts the goroutines running configured sessions plus dynamicSessionsLoop;
// Stop waits on it.
sessionGroup sync.WaitGroup
// listenerShutdown counts running listenForConnections goroutines; Stop waits on it after
// closing the listeners.
listenerShutdown sync.WaitGroup
// dynamicSessions (DynamicSessions setting) creates a session on demand when an incoming
// SessionID matches no configured session.
dynamicSessions bool
// dynamicQualifier (DynamicQualifier setting) assigns a new numeric Qualifier to each
// accepted connection before the session lookup.
dynamicQualifier bool
// dynamicQualifierCount is the last numeric Qualifier handed out. It is incremented from
// per-connection handleConnection goroutines, so access to it must be synchronized.
dynamicQualifierCount int
// dynamicSessionChan hands dynamically created sessions to dynamicSessionsLoop; Stop closes it.
dynamicSessionChan chan *session
// sessionAddr maps a SessionID to the net.Addr of the most recently accepted peer for it.
sessionAddr sync.Map
// sessionHostPort is the accept port configured for each session; a connection arriving on a
// different local port is rejected.
sessionHostPort map[SessionID]int
// listeners holds one listener per distinct accept address; Start fills it in.
listeners map[string]net.Listener
// connectionValidator, when non-nil, must approve each connection (see SetConnectionValidator).
connectionValidator ConnectionValidator
// tlsConfig, when set before Start, overrides TLS settings from the configuration (see SetTLSConfig).
tlsConfig *tls.Config
// newListenerCallback creates each listener; Start installs a default when it is nil.
newListenerCallback NewListenerCallback
// sessionFactory provides createSession, used for both configured and dynamic sessions.
sessionFactory
}
// ConnectionValidator is a hook for custom authentication of incoming connections.
//
// The Acceptor calls Validate once per accepted connection. The call happens after the
// SessionID has been built from the first message's header and the local port has been
// checked, and before the connection is attached to a session. A non-nil error rejects the
// connection, which is then closed. The SessionID passed in never carries a dynamic
// qualifier, because that is assigned afterwards.
type ConnectionValidator interface {
	// Validate decides whether netConn may be used for session. For example, an implementation
	// can tie a SenderCompID to an IP range, or to a specific client certificate under mTLS.
	Validate(netConn net.Conn, session SessionID) error
}
// NewListenerCallback builds the net.Listener that an Acceptor accepts connections on.
//
// address is a host:port pair. tlsConfig is the acceptor's TLS configuration and may be nil,
// in which case the default implementation listens over plain TCP. Install a custom callback
// with SetNewListenerCallback.
type NewListenerCallback func(address string, tlsConfig *tls.Config) (net.Listener, error)
// Start opens the accept listeners and starts running the sessions.
//
// Each session's SocketAcceptPort (falling back to the global one) is joined with the global
// SocketAcceptHost to form its listen address. Sessions that share an address share one
// listener. TLS is used when a tls.Config was supplied through SetTLSConfig or can be loaded
// from the global settings. Listeners are wrapped with PROXY-protocol support when
// UseTCPProxy is set.
//
// If any listener fails to open, the listeners opened so far are closed, a.listeners is left
// empty, and the error is returned. No session goroutines have been started at that point.
func (a *Acceptor) Start() (err error) {
	socketAcceptHost := ""
	if a.settings.GlobalSettings().HasSetting(config.SocketAcceptHost) {
		if socketAcceptHost, err = a.settings.GlobalSettings().Setting(config.SocketAcceptHost); err != nil {
			return
		}
	}

	// Resolve every session's accept port. Only the set of distinct addresses is collected here;
	// a.listeners receives an entry only once its listener has actually been opened.
	a.sessionHostPort = make(map[SessionID]int)
	a.listeners = make(map[string]net.Listener)
	addresses := make(map[string]struct{})
	for sessionID, sessionSettings := range a.settings.SessionSettings() {
		if sessionSettings.HasSetting(config.SocketAcceptPort) {
			if a.sessionHostPort[sessionID], err = sessionSettings.IntSetting(config.SocketAcceptPort); err != nil {
				return
			}
		} else if a.sessionHostPort[sessionID], err = a.settings.GlobalSettings().IntSetting(config.SocketAcceptPort); err != nil {
			return
		}
		addresses[net.JoinHostPort(socketAcceptHost, strconv.Itoa(a.sessionHostPort[sessionID]))] = struct{}{}
	}

	// An explicit SetTLSConfig takes precedence over TLS settings in the configuration.
	if a.tlsConfig == nil {
		var tlsConfig *tls.Config
		if tlsConfig, err = loadTLSConfig(a.settings.GlobalSettings()); err != nil {
			return
		}
		a.tlsConfig = tlsConfig
	}
	if a.newListenerCallback == nil {
		a.newListenerCallback = func(address string, tlsConfig *tls.Config) (net.Listener, error) {
			// Honour the config passed in rather than reaching back into a.tlsConfig.
			if tlsConfig != nil {
				return tls.Listen("tcp", address, tlsConfig)
			}
			return net.Listen("tcp", address)
		}
	}

	var useTCPProxy bool
	if a.settings.GlobalSettings().HasSetting(config.UseTCPProxy) {
		if useTCPProxy, err = a.settings.GlobalSettings().BoolSetting(config.UseTCPProxy); err != nil {
			return
		}
	}

	for address := range addresses {
		var listener net.Listener
		if listener, err = a.newListenerCallback(address, a.tlsConfig); err != nil {
			// Do not leak the listeners that were already opened.
			for _, opened := range a.listeners {
				_ = opened.Close()
			}
			a.listeners = make(map[string]net.Listener)
			return
		}
		if useTCPProxy {
			listener = &proxyproto.Listener{Listener: listener}
		}
		a.listeners[address] = listener
	}

	for _, s := range a.sessions {
		a.sessionGroup.Add(1)
		go func(s *session) {
			s.run()
			a.sessionGroup.Done()
		}(s)
	}
	if a.dynamicSessions {
		a.dynamicSessionChan = make(chan *session)
		a.sessionGroup.Add(1)
		go func() {
			a.dynamicSessionsLoop()
			a.sessionGroup.Done()
		}()
	}

	// One accept loop per listener; Stop waits for all of them after closing the listeners.
	a.listenerShutdown.Add(len(a.listeners))
	for _, listener := range a.listeners {
		go a.listenForConnections(listener)
	}
	return
}
// Stop logs out existing sessions, closes their connections, and stops accepting new connections.
//
// Shutdown proceeds in the following order:
//   - Close every listener and wait for every accept loop to exit.
//   - When dynamic sessions are enabled, close the dynamic-session channel.
//   - Stop every configured session and wait for all session goroutines.
//   - Unregister the configured sessions.
//
// An unregister failure is logged and does not prevent the remaining sessions from being
// unregistered.
func (a *Acceptor) Stop() {
	defer func() {
		// Panics raised in Stop are deliberately suppressed. An example is closing
		// a.dynamicSessionChan a second time when Stop is called twice, or closing it when
		// Start never created it.
		_ = recover()
	}()

	for _, listener := range a.listeners {
		// Skip empty slots rather than panicking on a nil listener.
		if listener == nil {
			continue
		}
		_ = listener.Close() // best-effort shutdown; a Close error is not actionable here
	}
	a.listenerShutdown.Wait()

	if a.dynamicSessions {
		close(a.dynamicSessionChan)
	}
	for _, session := range a.sessions {
		session.stop()
	}
	a.sessionGroup.Wait()

	for sessionID := range a.sessions {
		if err := UnregisterSession(sessionID); err != nil {
			a.globalLog.OnEventf("Unregister session %v failed: %v", sessionID, err)
		}
	}
}
// RemoteAddr returns the remote address of the peer most recently accepted for sessionID.
//
// The boolean result is false when no address is recorded for the session, or when the
// recorded value is not a net.Addr.
func (a *Acceptor) RemoteAddr(sessionID SessionID) (net.Addr, bool) {
	stored, found := a.sessionAddr.Load(sessionID)
	if found && stored != nil {
		remote, isAddr := stored.(net.Addr)
		return remote, isAddr
	}
	return nil, false
}
// NewAcceptor builds an Acceptor for the sessions described by settings.
//
// It reads the global DynamicSessions setting and, only when that setting is present, the
// DynamicQualifier setting. It then creates the global log and one session per configured
// SessionID. Configured IDs that differ only by Qualifier are rejected with
// errDuplicateSessionID. The sessions do not run until Start is called.
func NewAcceptor(app Application, storeFactory MessageStoreFactory, settings *Settings, logFactory LogFactory) (a *Acceptor, err error) {
	a = &Acceptor{
		app:             app,
		storeFactory:    storeFactory,
		settings:        settings,
		logFactory:      logFactory,
		sessions:        make(map[SessionID]*session),
		sessionHostPort: make(map[SessionID]int),
		listeners:       make(map[string]net.Listener),
	}

	if a.settings.GlobalSettings().HasSetting(config.DynamicSessions) {
		a.dynamicSessions, err = settings.globalSettings.BoolSetting(config.DynamicSessions)
		if err != nil {
			return a, err
		}
		// DynamicQualifier is only consulted when DynamicSessions is configured.
		if a.settings.GlobalSettings().HasSetting(config.DynamicQualifier) {
			a.dynamicQualifier, err = settings.globalSettings.BoolSetting(config.DynamicQualifier)
			if err != nil {
				return a, err
			}
		}
	}

	a.globalLog, err = logFactory.Create()
	if err != nil {
		return a, err
	}

	for configuredID, sessionSettings := range settings.SessionSettings() {
		// Sessions are indexed without their Qualifier, so configured IDs that differ only
		// by Qualifier collide.
		key := configuredID
		key.Qualifier = ""
		if _, taken := a.sessions[key]; taken {
			return a, errDuplicateSessionID
		}
		a.sessions[key], err = a.createSession(configuredID, storeFactory, sessionSettings, logFactory, app)
		if err != nil {
			return a, err
		}
	}
	return a, nil
}
// listenForConnections runs the accept loop for one listener. Each accepted connection is
// served by handleConnection on its own goroutine. Any Accept error ends the loop, including
// the one produced when Stop closes the listener. a.listenerShutdown is marked done on exit.
func (a *Acceptor) listenForConnections(listener net.Listener) {
	defer a.listenerShutdown.Done()
	for {
		conn, acceptErr := listener.Accept()
		if acceptErr != nil {
			return
		}
		go a.handleConnection(conn)
	}
}
// invalidMessage logs to the global log that the raw message in msg was rejected, along with
// the reason err.
func (a *Acceptor) invalidMessage(msg *bytes.Buffer, err error) {
	raw, reason := msg.Bytes(), err.Error()
	a.globalLog.OnEventf("Invalid Message: %s, %v", raw, reason)
}
// handleConnection serves one accepted connection and always closes it on return.
//
// The first message read from netConn identifies the session. Its header's Sender* fields
// become the Target* fields of the acceptor-side SessionID, and vice versa.
//
// The connection is rejected (logged and closed) when:
//   - the first message cannot be read or parsed,
//   - it lacks a required header field,
//   - it arrives on a local port other than the one configured for the session,
//   - the ConnectionValidator refuses it, or
//   - no session matches and dynamic sessions are disabled.
//
// Otherwise the session is connected, the first message is fed to it ahead of readLoop, and
// writeLoop runs on this goroutine. The connection is closed once writeLoop returns. Panics
// raised on this goroutine are recovered and logged.
func (a *Acceptor) handleConnection(netConn net.Conn) {
	defer func() {
		if r := recover(); r != nil {
			a.globalLog.OnEventf("Connection Terminated with Panic: %v\n%s", r, debug.Stack())
		}
		if err := netConn.Close(); err != nil {
			a.globalLog.OnEvent(err.Error())
		}
	}()

	reader := bufio.NewReader(netConn)
	parser := newParser(reader)
	msgBytes, err := parser.ReadMessage()
	if err != nil {
		if err == io.EOF {
			a.globalLog.OnEvent("Connection Terminated")
		} else {
			a.globalLog.OnEvent(err.Error())
		}
		return
	}

	msg := NewMessage()
	err = ParseMessage(msg, msgBytes)
	if err != nil {
		a.invalidMessage(msgBytes, err)
		return
	}

	// BeginString, SenderCompID and TargetCompID are required; the Sub and Location IDs are optional.
	var beginString FIXString
	if err := msg.Header.GetField(tagBeginString, &beginString); err != nil {
		a.invalidMessage(msgBytes, err)
		return
	}
	var senderCompID FIXString
	if err := msg.Header.GetField(tagSenderCompID, &senderCompID); err != nil {
		a.invalidMessage(msgBytes, err)
		return
	}
	var senderSubID FIXString
	if msg.Header.Has(tagSenderSubID) {
		if err := msg.Header.GetField(tagSenderSubID, &senderSubID); err != nil {
			a.invalidMessage(msgBytes, err)
			return
		}
	}
	var senderLocationID FIXString
	if msg.Header.Has(tagSenderLocationID) {
		if err := msg.Header.GetField(tagSenderLocationID, &senderLocationID); err != nil {
			a.invalidMessage(msgBytes, err)
			return
		}
	}
	var targetCompID FIXString
	if err := msg.Header.GetField(tagTargetCompID, &targetCompID); err != nil {
		a.invalidMessage(msgBytes, err)
		return
	}
	var targetSubID FIXString
	if msg.Header.Has(tagTargetSubID) {
		if err := msg.Header.GetField(tagTargetSubID, &targetSubID); err != nil {
			a.invalidMessage(msgBytes, err)
			return
		}
	}
	var targetLocationID FIXString
	if msg.Header.Has(tagTargetLocationID) {
		if err := msg.Header.GetField(tagTargetLocationID, &targetLocationID); err != nil {
			a.invalidMessage(msgBytes, err)
			return
		}
	}

	// The counterparty's Sender* fields are our Target* fields and vice versa.
	sessID := SessionID{BeginString: string(beginString),
		SenderCompID: string(targetCompID), SenderSubID: string(targetSubID), SenderLocationID: string(targetLocationID),
		TargetCompID: string(senderCompID), TargetSubID: string(senderSubID), TargetLocationID: string(senderLocationID),
	}

	// A configured session must arrive on its configured port. Only TCP addresses carry a
	// port. A custom listener (see SetNewListenerCallback) may report another address type, so
	// reject such connections explicitly instead of panicking on an unchecked type assertion.
	if expectedPort, ok := a.sessionHostPort[sessID]; ok {
		localAddr, isTCP := netConn.LocalAddr().(*net.TCPAddr)
		if !isTCP {
			a.globalLog.OnEventf("Unable to verify local port for session %v: unsupported address type %T", sessID, netConn.LocalAddr())
			return
		}
		if localAddr.Port != expectedPort {
			a.globalLog.OnEventf("Session %v not found for incoming message: %s", sessID, msgBytes)
			return
		}
	}

	// We have a session ID and a network connection. This seems to be a good place for any custom authentication logic.
	if a.connectionValidator != nil {
		if err := a.connectionValidator.Validate(netConn, sessID); err != nil {
			a.globalLog.OnEventf("Unable to validate a connection for session %v: %v", sessID, err.Error())
			return
		}
	}

	if a.dynamicQualifier {
		sessID.Qualifier = a.nextDynamicQualifier()
	}

	session, ok := a.sessions[sessID]
	if !ok {
		if !a.dynamicSessions {
			a.globalLog.OnEventf("Session %v not found for incoming message: %s", sessID, msgBytes)
			return
		}
		dynamicSession, err := a.sessionFactory.createSession(sessID, a.storeFactory, a.settings.globalSettings.clone(), a.logFactory, a.app)
		if err != nil {
			a.globalLog.OnEventf("Dynamic session %v failed to create: %v", sessID, err)
			return
		}
		// If Stop has already closed the channel this send panics; the deferred recover above logs it.
		a.dynamicSessionChan <- dynamicSession
		session = dynamicSession
		defer session.stop()
	}

	a.sessionAddr.Store(sessID, netConn.RemoteAddr())
	msgIn := make(chan fixIn, session.InChanCapacity)
	msgOut := make(chan []byte)
	if err := session.connect(msgIn, msgOut); err != nil {
		a.globalLog.OnEventf("Unable to accept session %v connection: %v", sessID, err.Error())
		return
	}

	// Replay the already-consumed first message before handing the parser to readLoop.
	go func() {
		msgIn <- fixIn{msgBytes, parser.lastRead}
		readLoop(parser, msgIn, a.globalLog)
	}()
	writeLoop(netConn, msgOut, a.globalLog)
}

// dynamicQualifierMu guards Acceptor.dynamicQualifierCount. The counter is bumped from
// concurrently running handleConnection goroutines (one per accepted connection). The mutex
// is shared by all Acceptors; it is only held for an increment, so contention is negligible.
var dynamicQualifierMu sync.Mutex

// nextDynamicQualifier returns a fresh numeric Qualifier for an incoming connection.
func (a *Acceptor) nextDynamicQualifier() string {
	dynamicQualifierMu.Lock()
	defer dynamicQualifierMu.Unlock()
	a.dynamicQualifierCount++
	return strconv.Itoa(a.dynamicQualifierCount)
}
// dynamicSessionsLoop runs each session received on a.dynamicSessionChan on its own goroutine.
// When a session's run returns, it is unregistered and its recorded remote address is
// forgotten.
//
// When the channel is closed (Stop does this), every still-running dynamic session is stopped.
// The loop then blocks until each one has reported completion, so the caller's
// a.sessionGroup.Done happens only after all dynamic sessions have finished.
func (a *Acceptor) dynamicSessionsLoop() {
	var nextID int
	active := map[int]*session{}
	complete := make(chan int)
	defer close(complete)

	// release forgets a finished dynamic session and its recorded remote address.
	release := func(id int) {
		s, ok := active[id]
		if !ok {
			a.globalLog.OnEventf("Missing dynamic session %v!", id)
			return
		}
		a.sessionAddr.Delete(s.sessionID)
		delete(active, id)
	}

	for {
		select {
		case s, ok := <-a.dynamicSessionChan:
			if !ok {
				for _, old := range active {
					old.stop()
				}
				for len(active) > 0 {
					release(<-complete)
				}
				return
			}
			nextID++
			id := nextID
			active[id] = s
			go func() {
				s.run()
				if err := UnregisterSession(s.sessionID); err != nil {
					a.globalLog.OnEventf("Unregister dynamic session %v failed: %v", s.sessionID, err)
				}
				// Always report completion, even when unregistering failed. Otherwise the
				// shutdown drain above waits forever and Stop deadlocks.
				complete <- id
			}()
		case id := <-complete:
			release(id)
		}
	}
}
// SetConnectionValidator installs an optional ConnectionValidator. Every accepted connection
// must pass it before being attached to a session. Use it for authentication that needs
// lower-level details of the connection, such as mTLS client certificates or IP allowlisting.
// Passing nil removes a previously installed validator:
//
//	a.SetConnectionValidator(nil)
func (a *Acceptor) SetConnectionValidator(validator ConnectionValidator) {
	a.connectionValidator = validator
}
// SetTLSConfig supplies the tls.Config that Start uses for the acceptor's listeners.
//
// A non-nil config set before Start takes precedence over any TLS settings in the acceptor's
// settings.GlobalSettings(). In that case Start does not inspect those settings to build a
// tls.Config. Setting nil lets Start load the TLS configuration from the settings as usual.
func (a *Acceptor) SetTLSConfig(tlsConfig *tls.Config) {
	a.tlsConfig = tlsConfig
}
// SetNewListenerCallback overrides how Start creates each net.Listener.
//
// When no callback is set, Start uses a default that calls tls.Listen when a TLS config is
// present and net.Listen otherwise. Whatever listener the callback returns is still wrapped
// with PROXY-protocol support when UseTCPProxy is enabled.
func (a *Acceptor) SetNewListenerCallback(cb NewListenerCallback) {
	a.newListenerCallback = cb
}