Files
qfixpt/quickfix/quickfix_test.go
2026-03-12 12:14:13 -03:00

307 lines
7.5 KiB
Go

// Copyright (c) quickfixengine.org All rights reserved.
//
// This file may be distributed under the terms of the quickfixengine.org
// license as defined by quickfixengine.org and appearing in the file
// LICENSE included in the packaging of this file.
//
// This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING
// THE WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A
// PARTICULAR PURPOSE.
//
// See http://www.quickfixengine.org/LICENSE for licensing information.
//
// Contact ask@quickfixengine.org if any conditions of this licensing
// are not clear to you.
package quickfix
import (
"time"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"quantex.com/qfixpt/quickfix/internal"
)
// QuickFIXSuite is a testify suite extended with assertion helpers for
// FIX messages and field maps (see MessageType, FieldEquals and
// MessageEqualsBytes).
type QuickFIXSuite struct {
	suite.Suite
}
// KnowsFieldMap is the read-only view of a field map that the assertion
// helpers in this file need.
type KnowsFieldMap interface {
	// Has reports whether the tag is set.
	Has(Tag) bool

	// GetString returns the tag's value as a string.
	GetString(Tag) (string, MessageRejectError)

	// GetInt returns the tag's value as an int.
	GetInt(Tag) (int, MessageRejectError)

	// GetField reads the tag's value into the supplied FieldValueReader.
	GetField(Tag, FieldValueReader) MessageRejectError
}
// MessageType asserts that the MsgType field in msg's header equals msgType.
func (s *QuickFIXSuite) MessageType(msgType string, msg *Message) {
	var header KnowsFieldMap = msg.Header
	s.FieldEquals(tagMsgType, msgType, header)
}
// FieldEquals asserts that tag is present in fieldMap and that its value
// equals expectedValue.
//
// The dynamic type of expectedValue selects how the field is read: string,
// int and bool are supported. Any other type fails the test immediately.
// A missing tag also aborts the test, since nothing could be compared.
func (s *QuickFIXSuite) FieldEquals(tag Tag, expectedValue interface{}, fieldMap KnowsFieldMap) {
	s.Require().True(fieldMap.Has(tag), "Tag %v not set", tag)

	switch want := expectedValue.(type) {
	case string:
		got, rejectErr := fieldMap.GetString(tag)
		s.Nil(rejectErr)
		s.Equal(want, got)

	case int:
		got, rejectErr := fieldMap.GetInt(tag)
		s.Nil(rejectErr)
		s.Equal(want, got)

	case bool:
		// Booleans have no dedicated getter on KnowsFieldMap, so decode
		// through the generic GetField into a FIXBoolean.
		var got FIXBoolean
		rejectErr := fieldMap.GetField(tag, &got)
		s.Nil(rejectErr)
		s.Equal(want, got.Bool())

	default:
		s.FailNow("Field type not handled")
	}
}
// MessageEqualsBytes asserts that the bytes produced by building msg equal
// expectedBytes. Both sides are compared as strings so a failure prints
// readable FIX text rather than byte slices.
func (s *QuickFIXSuite) MessageEqualsBytes(expectedBytes []byte, msg *Message) {
	actualBytes := msg.build()
	// testify's Equal takes (expected, actual); keeping that order makes the
	// failure output label each side correctly.
	s.Equal(string(expectedBytes), string(actualBytes))
}
// MockStore is a memoryStore whose Refresh method is routed through a
// testify mock, so tests can set expectations on Refresh while every other
// store operation uses the embedded in-memory implementation.
type MockStore struct {
	mock.Mock
	memoryStore
}
// Refresh records the call on the mock and returns the error configured as
// its first return argument.
func (s *MockStore) Refresh() error {
	args := s.Called()
	return args.Error(0)
}
// MockApp is a testify-mocked application used by the session test rig.
// Besides recording calls, it remembers the most recent outbound messages
// so tests can compare them with what was actually sent.
type MockApp struct {
	mock.Mock

	// decorateToAdmin, when non-nil, is invoked by ToAdmin with the outgoing
	// message.
	decorateToAdmin func(*Message)

	// lastToAdmin is the message most recently passed to ToAdmin.
	lastToAdmin *Message

	// lastToApp is the message most recently passed to ToApp.
	lastToApp *Message
}
// OnCreate is a no-op; unlike the other callbacks it is not recorded on the
// mock.
func (e *MockApp) OnCreate(_ SessionID) {}
// OnLogon records the call on the mock.
func (e *MockApp) OnLogon(_ SessionID) {
	_ = e.Called()
}
// OnLogout records the call on the mock.
func (e *MockApp) OnLogout(_ SessionID) {
	_ = e.Called()
}
// FromAdmin records the call on the mock and returns the first configured
// return argument if it is a MessageRejectError; otherwise it returns nil.
func (e *MockApp) FromAdmin(_ *Message, _ SessionID) (reject MessageRejectError) {
	configured := e.Called().Get(0)
	rejectErr, isReject := configured.(MessageRejectError)
	if !isReject {
		return nil
	}
	return rejectErr
}
// ToAdmin records the call on the mock, applies the optional
// decorateToAdmin hook to msg, and then remembers msg as the last admin
// message.
func (e *MockApp) ToAdmin(msg *Message, _ SessionID) {
	e.Called()

	if decorate := e.decorateToAdmin; decorate != nil {
		decorate(msg)
	}

	e.lastToAdmin = msg
}
// ToApp remembers msg as the last application message, records the call on
// the mock, and returns the error configured as the first return argument.
func (e *MockApp) ToApp(msg *Message, _ SessionID) (err error) {
	e.lastToApp = msg

	args := e.Called()
	return args.Error(0)
}
// FromApp records the call on the mock and returns the first configured
// return argument if it is a MessageRejectError; otherwise it returns nil.
func (e *MockApp) FromApp(_ *Message, _ SessionID) (reject MessageRejectError) {
	configured := e.Called().Get(0)
	rejectErr, isReject := configured.(MessageRejectError)
	if !isReject {
		return nil
	}
	return rejectErr
}
// MessageFactory builds test messages with monotonically increasing
// sequence numbers.
type MessageFactory struct {
	// seqNum is the sequence number given to the most recently built
	// message; it is incremented before each new message is built.
	seqNum int
}

// SetNextSeqNum arranges for the next built message to carry sequence
// number next.
func (m *MessageFactory) SetNextSeqNum(next int) {
	// The counter is pre-incremented when a message is built, so store the
	// value just before the desired one.
	lastUsed := next - 1
	m.seqNum = lastUsed
}
// buildMessage advances the factory's sequence counter and returns a new
// message whose header carries BeginString FIX.4.2, SenderCompID "TW",
// TargetCompID "ISLD", the current time as SendingTime, the new sequence
// number, and the given msgType.
func (m *MessageFactory) buildMessage(msgType string) *Message {
	m.seqNum++

	var (
		beginString = FIXString(string(BeginStringFIX42))
		sender      = FIXString("TW")
		target      = FIXString("ISLD")
		sendingTime = FIXUTCTimestamp{Time: time.Now()}
		seqNum      = FIXInt(m.seqNum)
		kind        = FIXString(msgType)
	)

	built := NewMessage()
	built.Header.
		SetField(tagBeginString, beginString).
		SetField(tagSenderCompID, sender).
		SetField(tagTargetCompID, target).
		SetField(tagSendingTime, sendingTime).
		SetField(tagMsgSeqNum, seqNum).
		SetField(tagMsgType, kind)
	return built
}
// Logout builds a Logout message with the next sequence number.
func (m *MessageFactory) Logout() *Message {
	msgType := string(msgTypeLogout)
	return m.buildMessage(msgType)
}
// NewOrderSingle builds a message of type "D" with the next sequence
// number.
func (m *MessageFactory) NewOrderSingle() *Message {
	const msgType = "D"
	return m.buildMessage(msgType)
}
// Heartbeat builds a Heartbeat message with the next sequence number.
func (m *MessageFactory) Heartbeat() *Message {
	msgType := string(msgTypeHeartbeat)
	return m.buildMessage(msgType)
}
// Logon builds a Logon message with the next sequence number.
func (m *MessageFactory) Logon() *Message {
	msgType := string(msgTypeLogon)
	return m.buildMessage(msgType)
}
// ResendRequest builds a ResendRequest message whose body asks for messages
// starting at beginSeqNo. EndSeqNo is always set to 0.
func (m *MessageFactory) ResendRequest(beginSeqNo int) *Message {
	request := m.buildMessage(string(msgTypeResendRequest))

	body := &request.Body
	body.SetField(tagBeginSeqNo, FIXInt(beginSeqNo))
	body.SetField(tagEndSeqNo, FIXInt(0))

	return request
}
// SequenceReset builds a SequenceReset message whose body sets NewSeqNo to
// seqNo.
func (m *MessageFactory) SequenceReset(seqNo int) *Message {
	reset := m.buildMessage(string(msgTypeSequenceReset))
	newSeqNo := FIXInt(seqNo)
	reset.Body.SetField(tagNewSeqNo, newSeqNo)
	return reset
}
// MockSessionReceiver captures the raw messages a session writes to its
// outbound channel so tests can inspect them.
type MockSessionReceiver struct {
	// sendChannel is the channel the session under test sends on.
	sendChannel chan []byte
}

// newMockSessionReceiver returns a receiver whose channel buffers up to 10
// messages.
func newMockSessionReceiver() MockSessionReceiver {
	var receiver MockSessionReceiver
	receiver.sendChannel = make(chan []byte, 10)
	return receiver
}

// LastMessage performs a non-blocking receive from the send channel.
//
// It returns the next queued message with ok == true. When nothing is
// queued it returns (nil, true). ok is false only when the channel has been
// closed and drained.
func (p *MockSessionReceiver) LastMessage() (msg []byte, ok bool) {
	select {
	case received, open := <-p.sendChannel:
		return received, open
	default:
		// Nothing pending, but the channel is still open.
		return nil, true
	}
}
// SessionSuiteRig bundles everything a session test needs: the assertion
// suite, a message factory, a mocked application and store, the session
// under test, and a receiver capturing what the session sends.
type SessionSuiteRig struct {
	QuickFIXSuite
	MessageFactory

	// MockApp is installed as the session's application by Init.
	MockApp MockApp

	// MockStore is installed as the session's store by Init.
	MockStore MockStore

	*session

	// Receiver owns the channel installed as the session's messageOut by
	// Init.
	Receiver MockSessionReceiver
}
// Init resets the rig to a fresh state: new mocks, a new message factory,
// a new receiver, and a new session wired to them with a null log and a
// MaxLatency of 120 seconds.
func (s *SessionSuiteRig) Init() {
	s.MockApp = MockApp{}
	s.MockStore = MockStore{}
	s.MessageFactory = MessageFactory{}
	s.Receiver = newMockSessionReceiver()

	id := SessionID{BeginString: "FIX.4.2", TargetCompID: "TW", SenderCompID: "ISLD"}
	events := make(chan internal.Event)

	s.session = &session{
		sessionID:    id,
		store:        &s.MockStore,
		application:  &s.MockApp,
		log:          nullLog{},
		messageOut:   s.Receiver.sendChannel,
		sessionEvent: events,
	}

	const maxLatency = 120 * time.Second
	s.MaxLatency = maxLatency
}
// State asserts that the session's current state has the same type as
// state.
func (s *SessionSuiteRig) State(state sessionState) {
	current := s.session.State
	s.IsType(state, current, "session state should be %v", state)
}
// MessageSentEquals takes the next message from the receiver and asserts
// that the channel is still open, that a message was actually sent, and
// that its bytes match msg.
func (s *SessionSuiteRig) MessageSentEquals(msg *Message) {
	sent, connected := s.Receiver.LastMessage()

	s.True(connected, "Should be connected")
	s.NotNil(sent, "Message should have been sent")
	s.MessageEqualsBytes(sent, msg)
}
// LastToAppMessageSent asserts that the message most recently passed to the
// mock application's ToApp is the next message the session sent.
func (s *SessionSuiteRig) LastToAppMessageSent() {
	// Abort with a clear message if ToApp was never called, mirroring
	// LastToAdminMessageSent, instead of handing a nil message on for
	// comparison.
	require.NotNil(s.T(), s.MockApp.lastToApp, "No ToApp received")
	s.MessageSentEquals(s.MockApp.lastToApp)
}
// LastToAdminMessageSent asserts that the message most recently passed to
// the mock application's ToAdmin is the next message the session sent. The
// test is aborted if ToAdmin was never called.
func (s *SessionSuiteRig) LastToAdminMessageSent() {
	lastAdmin := s.MockApp.lastToAdmin
	require.NotNil(s.T(), lastAdmin, "No ToAdmin received")
	s.MessageSentEquals(lastAdmin)
}
// NotStopped asserts that the session reports it is not stopped.
func (s *SessionSuiteRig) NotStopped() {
	stopped := s.session.Stopped()
	s.False(stopped, "session should not be stopped")
}
// Stopped asserts that the session reports it is stopped.
func (s *SessionSuiteRig) Stopped() {
	stopped := s.session.Stopped()
	s.True(stopped, "session should be stopped")
}
// Disconnected asserts that the receiver yields no message and reports
// ok == false, which LastMessage does only once the outbound channel has
// been closed.
func (s *SessionSuiteRig) Disconnected() {
	pending, open := s.Receiver.LastMessage()

	s.Nil(pending, "Expect disconnect, not message")
	s.False(open, "Expect disconnect")
}
// NoMessageSent asserts that the receiver has no pending message. Whether
// the channel is open or closed is deliberately ignored.
func (s *SessionSuiteRig) NoMessageSent() {
	pending, _ := s.Receiver.LastMessage()
	s.Nil(pending, "no message should be sent but got %s", pending)
}
// NoMessageQueued asserts that the session's toSend queue is empty.
func (s *SessionSuiteRig) NoMessageQueued() {
	queued := s.session.toSend
	s.Empty(queued, "no messages should be queueud")
}
// ExpectStoreReset asserts that both the sender and target next sequence
// numbers in the store are 1.
func (s *SessionSuiteRig) ExpectStoreReset() {
	const initialSeqNum = 1

	s.NextSenderMsgSeqNum(initialSeqNum)
	s.NextTargetMsgSeqNum(initialSeqNum)
}
// NextTargetMsgSeqNum asserts that the store's next target sequence number
// equals expected.
func (s *SessionSuiteRig) NextTargetMsgSeqNum(expected int) {
	actual := s.session.store.NextTargetMsgSeqNum()
	s.Equal(expected, actual, "NextTargetMsgSeqNum should be %v ", expected)
}
// NextSenderMsgSeqNum asserts that the store's next sender sequence number
// equals expected.
func (s *SessionSuiteRig) NextSenderMsgSeqNum(expected int) {
	actual := s.session.store.NextSenderMsgSeqNum()
	s.Equal(expected, actual, "NextSenderMsgSeqNum should be %v", expected)
}
// IncrNextSenderMsgSeqNum increments the store's next sender sequence
// number and aborts the test if the store reports an error.
func (s *SessionSuiteRig) IncrNextSenderMsgSeqNum() {
	err := s.session.store.IncrNextSenderMsgSeqNum()
	s.Require().Nil(err)
}
// IncrNextTargetMsgSeqNum increments the store's next target sequence
// number and aborts the test if the store reports an error.
func (s *SessionSuiteRig) IncrNextTargetMsgSeqNum() {
	err := s.session.store.IncrNextTargetMsgSeqNum()
	s.Require().Nil(err)
}
// NoMessagePersisted asserts that the store holds no message for seqNum.
func (s *SessionSuiteRig) NoMessagePersisted(seqNum int) {
	stored, err := s.session.store.GetMessages(seqNum, seqNum)

	s.Nil(err)
	s.Empty(stored, "The message should not be persisted")
}
// MessagePersisted asserts that exactly one message is stored under msg's
// MsgSeqNum and that the stored bytes match msg.
func (s *SessionSuiteRig) MessagePersisted(msg *Message) {
	// err is declared as a plain error up front because it receives both a
	// MessageRejectError (from GetInt) and an error (from GetMessages).
	var err error
	seqNum, err := msg.Header.GetInt(tagMsgSeqNum)
	s.Nil(err, "message should have seq num")

	persistedMessages, err := s.session.store.GetMessages(seqNum, seqNum)
	s.Nil(err)

	// Use Require so the test stops here on a length mismatch; a plain
	// assertion would continue and the index below would panic when nothing
	// was stored.
	s.Require().Len(persistedMessages, 1, "a message should be stored at %v", seqNum)
	s.MessageEqualsBytes(persistedMessages[0], msg)
}