// Copyright (c) quickfixengine.org  All rights reserved.
//
// This file may be distributed under the terms of the quickfixengine.org
// license as defined by quickfixengine.org and appearing in the file
// LICENSE included in the packaging of this file.
//
// This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING
// THE WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A
// PARTICULAR PURPOSE.
//
// See http://www.quickfixengine.org/LICENSE for licensing information.
//
// Contact ask@quickfixengine.org if any conditions of this licensing
// are not clear to you.

package quickfix

import (
	"bufio"
	"bytes"
	"crypto/tls"
	"io"
	"net"
	"runtime/debug"
	"strconv"
	"sync"

	proxyproto "github.com/pires/go-proxyproto"

	"quantex.com/qfixdpl/quickfix/config"
)

// Acceptor accepts connections from FIX clients and manages the associated sessions.
type Acceptor struct {
	// app is the Application supplied to NewAcceptor; it is passed to every session that is created.
	app Application
	// settings holds the global and per-session configuration read by NewAcceptor and Start.
	settings *Settings
	// logFactory creates the global log and is passed to every session that is created.
	logFactory LogFactory
	// storeFactory is passed to every session that is created.
	storeFactory MessageStoreFactory
	// globalLog receives acceptor-level events (connection errors, invalid messages, ...).
	globalLog Log
	// sessions holds the statically configured sessions created by NewAcceptor,
	// keyed by their SessionID with the Qualifier cleared.
	sessions map[SessionID]*session
	// sessionGroup tracks the goroutines started by Start that run the static sessions
	// and the dynamic sessions loop; Stop waits on it.
	sessionGroup sync.WaitGroup
	// listenerShutdown tracks the per-listener accept loops; Stop waits on it
	// after closing the listeners.
	listenerShutdown sync.WaitGroup
	// dynamicSessions is the value of config.DynamicSessions. When set, a session is
	// created on the fly for an incoming SessionID that is not in sessions.
	dynamicSessions bool
	// dynamicQualifier is the value of config.DynamicQualifier (only read when
	// config.DynamicSessions is present). When set, every incoming connection is
	// given a fresh numeric SessionID.Qualifier.
	dynamicQualifier bool
	// dynamicQualifierCount is incremented per incoming connection and used as the
	// Qualifier when dynamicQualifier is set.
	dynamicQualifierCount int
	// dynamicSessionChan is created by Start when dynamicSessions is set and hands
	// newly created dynamic sessions to dynamicSessionsLoop; Stop closes it.
	dynamicSessionChan chan *session
	// sessionAddr maps a SessionID to the remote net.Addr of its connection;
	// it is read by RemoteAddr.
	sessionAddr sync.Map
	// sessionHostPort maps a SessionID to its configured config.SocketAcceptPort
	// (session-level setting, falling back to the global one). Populated by Start.
	sessionHostPort map[SessionID]int
	// listeners maps a listen address ("host:port") to its net.Listener. Populated by Start.
	listeners map[string]net.Listener
	// connectionValidator is the optional validator set through SetConnectionValidator.
	connectionValidator ConnectionValidator
	// tlsConfig is either set through SetTLSConfig or loaded from the global settings by Start.
	tlsConfig *tls.Config
	// newListenerCallback is either set through SetNewListenerCallback or defaulted by Start.
	newListenerCallback NewListenerCallback
	// sessionFactory is embedded; it provides createSession.
	sessionFactory
}

// ConnectionValidator is an interface allowing to implement a custom authentication logic.
type ConnectionValidator interface {
	// Validate the connection for validity. This can be a part of authentication process.
	// For example, you may tie up a SenderCompID to an IP range, or to a specific TLS certificate as a part of mTLS.
	// A non-nil error makes the acceptor log the failure and drop the connection.
	Validate(netConn net.Conn, session SessionID) error
}

// NewListenerCallback is a function that returns a net.Listener for the given address and tls.Config struct.
type NewListenerCallback func(address string, tlsConfig *tls.Config) (net.Listener, error) // Start accepting connections. func (a *Acceptor) Start() (err error) { socketAcceptHost := "" if a.settings.GlobalSettings().HasSetting(config.SocketAcceptHost) { if socketAcceptHost, err = a.settings.GlobalSettings().Setting(config.SocketAcceptHost); err != nil { return } } a.sessionHostPort = make(map[SessionID]int) a.listeners = make(map[string]net.Listener) for sessionID, sessionSettings := range a.settings.SessionSettings() { if sessionSettings.HasSetting(config.SocketAcceptPort) { if a.sessionHostPort[sessionID], err = sessionSettings.IntSetting(config.SocketAcceptPort); err != nil { return } } else if a.sessionHostPort[sessionID], err = a.settings.GlobalSettings().IntSetting(config.SocketAcceptPort); err != nil { return } address := net.JoinHostPort(socketAcceptHost, strconv.Itoa(a.sessionHostPort[sessionID])) a.listeners[address] = nil } if a.tlsConfig == nil { var tlsConfig *tls.Config if tlsConfig, err = loadTLSConfig(a.settings.GlobalSettings()); err != nil { return } a.tlsConfig = tlsConfig } if a.newListenerCallback == nil { a.newListenerCallback = func(address string, tlsConfig *tls.Config) (net.Listener, error) { if tlsConfig != nil { return tls.Listen("tcp", address, a.tlsConfig) } return net.Listen("tcp", address) } } var useTCPProxy bool if a.settings.GlobalSettings().HasSetting(config.UseTCPProxy) { if useTCPProxy, err = a.settings.GlobalSettings().BoolSetting(config.UseTCPProxy); err != nil { return } } for address := range a.listeners { if a.listeners[address], err = a.newListenerCallback(address, a.tlsConfig); err != nil { return } else if useTCPProxy { a.listeners[address] = &proxyproto.Listener{Listener: a.listeners[address]} } } for _, s := range a.sessions { a.sessionGroup.Add(1) go func(s *session) { s.run() a.sessionGroup.Done() }(s) } if a.dynamicSessions { a.dynamicSessionChan = make(chan *session) a.sessionGroup.Add(1) go func() { 
a.dynamicSessionsLoop() a.sessionGroup.Done() }() } a.listenerShutdown.Add(len(a.listeners)) for _, listener := range a.listeners { go a.listenForConnections(listener) } return } // Stop logs out existing sessions, close their connections, and stop accepting new connections. func (a *Acceptor) Stop() { defer func() { _ = recover() // suppress sending on closed channel error }() for _, listener := range a.listeners { listener.Close() } a.listenerShutdown.Wait() if a.dynamicSessions { close(a.dynamicSessionChan) } for _, session := range a.sessions { session.stop() } a.sessionGroup.Wait() for sessionID := range a.sessions { err := UnregisterSession(sessionID) if err != nil { return } } } // RemoteAddr gets remote IP address for a given session. func (a *Acceptor) RemoteAddr(sessionID SessionID) (net.Addr, bool) { addr, ok := a.sessionAddr.Load(sessionID) if !ok || addr == nil { return nil, false } val, ok := addr.(net.Addr) return val, ok } // NewAcceptor creates and initializes a new Acceptor. 
func NewAcceptor(app Application, storeFactory MessageStoreFactory, settings *Settings, logFactory LogFactory) (a *Acceptor, err error) { a = &Acceptor{ app: app, storeFactory: storeFactory, settings: settings, logFactory: logFactory, sessions: make(map[SessionID]*session), sessionHostPort: make(map[SessionID]int), listeners: make(map[string]net.Listener), } if a.settings.GlobalSettings().HasSetting(config.DynamicSessions) { if a.dynamicSessions, err = settings.globalSettings.BoolSetting(config.DynamicSessions); err != nil { return } if a.settings.GlobalSettings().HasSetting(config.DynamicQualifier) { if a.dynamicQualifier, err = settings.globalSettings.BoolSetting(config.DynamicQualifier); err != nil { return } } } if a.globalLog, err = logFactory.Create(); err != nil { return } for sessionID, sessionSettings := range settings.SessionSettings() { sessID := sessionID sessID.Qualifier = "" if _, dup := a.sessions[sessID]; dup { return a, errDuplicateSessionID } if a.sessions[sessID], err = a.createSession(sessionID, storeFactory, sessionSettings, logFactory, app); err != nil { return } } return } func (a *Acceptor) listenForConnections(listener net.Listener) { defer a.listenerShutdown.Done() for { netConn, err := listener.Accept() if err != nil { return } go func() { a.handleConnection(netConn) }() } } func (a *Acceptor) invalidMessage(msg *bytes.Buffer, err error) { a.globalLog.OnEventf("Invalid Message: %s, %v", msg.Bytes(), err.Error()) } func (a *Acceptor) handleConnection(netConn net.Conn) { defer func() { if err := recover(); err != nil { a.globalLog.OnEventf("Connection Terminated with Panic: %s", debug.Stack()) } if err := netConn.Close(); err != nil { a.globalLog.OnEvent(err.Error()) } }() reader := bufio.NewReader(netConn) parser := newParser(reader) msgBytes, err := parser.ReadMessage() if err != nil { if err == io.EOF { a.globalLog.OnEvent("Connection Terminated") } else { a.globalLog.OnEvent(err.Error()) } return } msg := NewMessage() err = 
ParseMessage(msg, msgBytes) if err != nil { a.invalidMessage(msgBytes, err) return } var beginString FIXString if err := msg.Header.GetField(tagBeginString, &beginString); err != nil { a.invalidMessage(msgBytes, err) return } var senderCompID FIXString if err := msg.Header.GetField(tagSenderCompID, &senderCompID); err != nil { a.invalidMessage(msgBytes, err) return } var senderSubID FIXString if msg.Header.Has(tagSenderSubID) { if err := msg.Header.GetField(tagSenderSubID, &senderSubID); err != nil { a.invalidMessage(msgBytes, err) return } } var senderLocationID FIXString if msg.Header.Has(tagSenderLocationID) { if err := msg.Header.GetField(tagSenderLocationID, &senderLocationID); err != nil { a.invalidMessage(msgBytes, err) return } } var targetCompID FIXString if err := msg.Header.GetField(tagTargetCompID, &targetCompID); err != nil { a.invalidMessage(msgBytes, err) return } var targetSubID FIXString if msg.Header.Has(tagTargetSubID) { if err := msg.Header.GetField(tagTargetSubID, &targetSubID); err != nil { a.invalidMessage(msgBytes, err) return } } var targetLocationID FIXString if msg.Header.Has(tagTargetLocationID) { if err := msg.Header.GetField(tagTargetLocationID, &targetLocationID); err != nil { a.invalidMessage(msgBytes, err) return } } sessID := SessionID{BeginString: string(beginString), SenderCompID: string(targetCompID), SenderSubID: string(targetSubID), SenderLocationID: string(targetLocationID), TargetCompID: string(senderCompID), TargetSubID: string(senderSubID), TargetLocationID: string(senderLocationID), } localConnectionPort := netConn.LocalAddr().(*net.TCPAddr).Port if expectedPort, ok := a.sessionHostPort[sessID]; ok && expectedPort != localConnectionPort { a.globalLog.OnEventf("Session %v not found for incoming message: %s", sessID, msgBytes) return } // We have a session ID and a network connection. This seems to be a good place for any custom authentication logic. 
if a.connectionValidator != nil { if err := a.connectionValidator.Validate(netConn, sessID); err != nil { a.globalLog.OnEventf("Unable to validate a connection for session %v: %v", sessID, err.Error()) return } } if a.dynamicQualifier { a.dynamicQualifierCount++ sessID.Qualifier = strconv.Itoa(a.dynamicQualifierCount) } session, ok := a.sessions[sessID] if !ok { if !a.dynamicSessions { a.globalLog.OnEventf("Session %v not found for incoming message: %s", sessID, msgBytes) return } dynamicSession, err := a.sessionFactory.createSession(sessID, a.storeFactory, a.settings.globalSettings.clone(), a.logFactory, a.app) if err != nil { a.globalLog.OnEventf("Dynamic session %v failed to create: %v", sessID, err) return } a.dynamicSessionChan <- dynamicSession session = dynamicSession defer session.stop() } a.sessionAddr.Store(sessID, netConn.RemoteAddr()) msgIn := make(chan fixIn, session.InChanCapacity) msgOut := make(chan []byte) if err := session.connect(msgIn, msgOut); err != nil { a.globalLog.OnEventf("Unable to accept session %v connection: %v", sessID, err.Error()) return } go func() { msgIn <- fixIn{msgBytes, parser.lastRead} readLoop(parser, msgIn, a.globalLog) }() writeLoop(netConn, msgOut, a.globalLog) } func (a *Acceptor) dynamicSessionsLoop() { var id int var sessions = map[int]*session{} var complete = make(chan int) defer close(complete) LOOP: for { select { case session, ok := <-a.dynamicSessionChan: if !ok { for _, oldSession := range sessions { oldSession.stop() } break LOOP } id++ sessionID := id sessions[sessionID] = session go func() { session.run() err := UnregisterSession(session.sessionID) if err != nil { a.globalLog.OnEventf("Unregister dynamic session %v failed: %v", session.sessionID, err) return } complete <- sessionID }() case id := <-complete: session, ok := sessions[id] if ok { a.sessionAddr.Delete(session.sessionID) delete(sessions, id) } else { a.globalLog.OnEventf("Missing dynamic session %v!", id) } } } if len(sessions) == 0 { return } for 
id := range complete { delete(sessions, id) if len(sessions) == 0 { return } } } // SetConnectionValidator sets an optional connection validator. // Use it when you need a custom authentication logic that includes lower level interactions, // like mTLS auth or IP whitelistening. // To remove a previously set validator call it with a nil value: // // a.SetConnectionValidator(nil) func (a *Acceptor) SetConnectionValidator(validator ConnectionValidator) { a.connectionValidator = validator } // SetTLSConfig allows the creator of the Acceptor to specify a fully customizable tls.Config of their choice, // which will be used in the Start() method. // // Note: when the caller explicitly provides a tls.Config with this function, // it takes precendent over TLS settings specified in the acceptor's settings.GlobalSettings(), // meaning that the `settings.GlobalSettings()` object is not inspected or used for the creation of the tls.Config. func (a *Acceptor) SetTLSConfig(tlsConfig *tls.Config) { a.tlsConfig = tlsConfig } // SetNewListenerCallback allows the creator of the Acceptor to specify the callback used to create each net.Listener // which will be used in the Start() method. func (a *Acceptor) SetNewListenerCallback(cb NewListenerCallback) { a.newListenerCallback = cb }