diff --git a/crates/hotfix-message/src/builder.rs b/crates/hotfix-message/src/builder.rs index fff0404b..28b30ac3 100644 --- a/crates/hotfix-message/src/builder.rs +++ b/crates/hotfix-message/src/builder.rs @@ -11,6 +11,8 @@ use std::collections::{HashMap, HashSet}; pub const SOH: u8 = 0x1; +const USER_DEFINED_TAG_MIN: u32 = 5000; + /// Length of the checksum field. /// /// It should always be 7 bytes: @@ -197,25 +199,43 @@ impl MessageBuilder { let mut body = Body::default(); let mut field = next_field; - while message_def.contains_tag(field.tag) { - let tag = field.tag; - body.store_field(field); - - // check if it's the start of a group and parse the group as needed - let field_def = self.get_dict_field_by_tag(tag.get())?; - match message_def.get_group(tag) { - Some(group_def) => { - let (groups, next) = Self::parse_groups(parser, group_def, field_def.tag())?; - #[allow(clippy::expect_used)] - body.set_groups(groups) - .expect("groups are guaranteed to be valid at this point"); - field = next; - } - None => { - field = parser.next_field().ok_or(ParserError::Malformed( - "message ended within the body".to_string(), - ))?; + loop { + if message_def.contains_tag(field.tag) { + let tag = field.tag; + body.store_field(field); + + // check if it's the start of a group and parse the group as needed + let field_def = self.get_dict_field_by_tag(tag.get())?; + match message_def.get_group(tag) { + Some(group_def) => { + let (groups, next) = self.parse_groups( + parser, + group_def, + field_def.tag(), + ParentScope::Message(message_def), + )?; + Self::check_required_fields_in_groups(&groups, group_def)?; + body.set_groups(groups).map_err(|err| { + ParserError::Malformed(format!( + "failed to set groups for tag {}: {err}", + tag.get() + )) + })?; + field = next; + } + None => { + field = parser.next_field().ok_or(ParserError::Malformed( + "message ended within the body".to_string(), + ))?; + } } + } else if self.should_accept_unknown_user_defined(field.tag) { + body.store_field(field); + field = parser.next_field().ok_or(ParserError::Malformed( + "message ended within the body".to_string(), + ))?; + } else { + break; } } @@ -226,6 +246,14 @@ impl MessageBuilder { Ok((body, field)) } + fn should_accept_unknown_user_defined(&self, tag: TagU32) -> bool { + !self.config.validate_user_defined_fields && Self::is_user_defined(tag) + } + + fn is_user_defined(tag: TagU32) -> bool { + tag.get() >= USER_DEFINED_TAG_MIN + } + fn build_trailer(&self, trailer: &mut Trailer, parser: &mut Parser, next_field: Field) { let mut field = Some(next_field); while let Some(f) = field { @@ -238,64 +266,151 @@ impl MessageBuilder { } fn parse_groups( + &self, parser: &mut Parser, group_def: &GroupSpecification, start_tag: TagU32, + parent: ParentScope<'_>, ) -> ParserResult<(Vec, Field)> { let mut groups = vec![]; + let delimiter_tag = group_def.delimiter_tag(); let mut field = parser.next_field().ok_or(ParserError::Malformed( "missing delimiter field".to_string(), ))?; + + if field.tag != delimiter_tag { + return Err(ParserError::InvalidGroupFieldOrder { + tag: field.tag.get(), + group_tag: group_def.number_of_entries_tag().get(), + }); + } + + let mut current_group: Option = None; + let mut next_declared_idx: usize = 0; + loop { - let mut group = RepeatingGroup::new_with_tags(start_tag, group_def.delimiter_tag()); - - // we skip the first field as we've already stored the delimiter - for field_def in group_def.fields().iter() { - let is_required = - field_def.is_required || field_def.tag == group_def.delimiter_tag(); - let current_tag = field.tag; - if field_def.tag == current_tag { - // the next tag is the next expected field's tag in the group, store it and move on - group.store_field(field); - field = if let Some(nested_group_def) = group_def.get_nested_group(current_tag) - { - let (groups, next) = - Self::parse_groups(parser, nested_group_def, current_tag)?; - #[allow(clippy::expect_used)] - group - .set_groups(groups) - .expect("groups are guaranteed to be valid at this point"); - next - } else { - parser - .next_field() - .ok_or(ParserError::Malformed("incomplete group".to_string()))? - } - } else if !is_required { - // this field isn't required in the group, so it's fine to skip it + let current_tag = field.tag; + + if current_tag == delimiter_tag { + // delimiter starts a new group instance + if let Some(g) = current_group.take() { + groups.push(g); + } + let mut new_group = RepeatingGroup::new_with_tags(start_tag, delimiter_tag); + new_group.store_field(field); + current_group = Some(new_group); + next_declared_idx = 1; + field = parser + .next_field() + .ok_or(ParserError::Malformed("incomplete group".to_string()))?; + } else if group_def.contains_tag(current_tag) { + // tag is declared in this group; enforce declaration order + let position = group_def + .fields() + .iter() + .position(|f| f.tag == current_tag) + .ok_or_else(|| { + ParserError::Malformed(format!( + "field {} reported as in group {} but missing from fields()", + current_tag.get(), + group_def.number_of_entries_tag().get() + )) + })?; + if position < next_declared_idx { + return Err(ParserError::InvalidGroupFieldOrder { + tag: current_tag.get(), + group_tag: group_def.number_of_entries_tag().get(), + }); + } + next_declared_idx = position + 1; + + let group = current_group.as_mut().ok_or_else(|| { + ParserError::Malformed(format!( + "no group started before field {} in group {}", + current_tag.get(), + group_def.number_of_entries_tag().get() + )) + })?; + group.store_field(field); + + field = if let Some(nested_group_def) = group_def.get_nested_group(current_tag) { + let (nested_groups, next) = self.parse_groups( + parser, + nested_group_def, + current_tag, + ParentScope::Group(group_def), + )?; + group.set_groups(nested_groups).map_err(|err| { + ParserError::Malformed(format!( + "failed to set nested groups for tag {}: {err}", + current_tag.get() + )) + })?; + next } else { - // the next field in the group is required but the next field in the message isn't it - let err = if group_def.contains_tag(field.tag) { - ParserError::InvalidGroupFieldOrder { - tag: field.tag.get(), - group_tag: group_def.number_of_entries_tag().get(), - } - } else { - ParserError::InvalidField(field.tag.get()) - }; - return Err(err); + parser + .next_field() + .ok_or(ParserError::Malformed("incomplete group".to_string()))? + }; + } else if Self::parent_contains_tag(parent, current_tag) + || self.is_trailer_tag(current_tag) + { + // tag belongs to the enclosing scope (or is a trailer field), so + // close the current group and hand the field back to the caller + if let Some(g) = current_group.take() { + groups.push(g); } + return Ok((groups, field)); + } else if self.should_accept_unknown_user_defined(current_tag) { + let group = current_group.as_mut().ok_or_else(|| { + ParserError::Malformed(format!( + "no group started before unknown user-defined field {} in group {}", + current_tag.get(), + group_def.number_of_entries_tag().get() + )) + })?; + group.store_field(field); + field = parser + .next_field() + .ok_or(ParserError::Malformed("incomplete group".to_string()))?; + } else { + return Err(ParserError::InvalidField(current_tag.get())); } + } + } - // we've checked all fields for this group, - // it's either another group in the repeating group or the end of the repeating group - groups.push(group); + fn parent_contains_tag(parent: ParentScope<'_>, tag: TagU32) -> bool { + match parent { + ParentScope::Message(message_def) => message_def.contains_tag(tag), + ParentScope::Group(group_def) => group_def.contains_tag(tag), + } + } - if !group_def.contains_tag(field.tag) { - return Ok((groups, field)); + fn check_required_fields_in_groups( + groups: &[RepeatingGroup], + group_def: &GroupSpecification, + ) -> ParserResult<()> { + let group_tag = group_def.number_of_entries_tag().get(); + for entry in groups { + let field_map = entry.get_fields(); + + for field_def in group_def.fields() { + if field_def.is_required && !field_map.fields.contains_key(&field_def.tag) { + return Err(ParserError::RequiredFieldMissing { + tag: field_def.tag.get(), + group_tag: Some(group_tag), + }); + } + } + + for (nested_tag, nested_spec) in &group_def.nested_groups { + if let Some(nested_entries) = field_map.groups.get(nested_tag) { + Self::check_required_fields_in_groups(nested_entries, nested_spec)?; + } } } + Ok(()) } fn get_dict_field_by_tag(&self, tag: u32) -> ParserResult> { @@ -407,6 +522,10 @@ fn parser_error_to_parsed_message(err: ParserError, header: Header) -> ParsedMes reason: InvalidReason::InvalidMsgType(msg_type), message: Message::with_header(header), }, + ParserError::RequiredFieldMissing { tag, group_tag } => ParsedMessage::Invalid { + reason: InvalidReason::RequiredFieldMissing { tag, group_tag }, + message: Message::with_header(header), + }, ParserError::Malformed(_) => ParsedMessage::Garbled(GarbledReason::Malformed), } } @@ -416,6 +535,12 @@ struct FieldSpecification { pub(crate) is_required: bool, } +#[derive(Clone, Copy)] +enum ParentScope<'a> { + Message(&'a MessageSpecification), + Group(&'a GroupSpecification), +} + struct GroupSpecification { number_of_entries_tag: TagU32, fields: Vec, @@ -872,4 +997,251 @@ mod tests { assert_eq!(alloc_2.get::(fix44::COMMISSION).unwrap(), 75.0); assert_eq!(alloc_2.get::<&str>(fix44::COMM_TYPE).unwrap(), "2"); } + + const CONFIG_NO_USER_VALIDATION: Config = + Config::with_separator(b'|').validate_user_defined_fields(false); + + fn build_pipe_separated_message(content: &str) -> Vec { + let body_length = content.len(); + let prefix = format!("8=FIX.4.4|9={body_length}|{content}"); + let checksum: u8 = prefix.bytes().fold(0u8, |acc, x| acc.wrapping_add(x)); + format!("{prefix}10={checksum:03}|").into_bytes() + } + + #[test] + fn test_unknown_user_defined_tag_at_body_rejected_by_default() { + let raw = + build_pipe_separated_message("35=D|49=SENDER|56=TARGET|55=AAPL|59=0|20000=custom|"); + let builder = MessageBuilder::new(Dictionary::fix44(), CONFIG).unwrap(); + let parsed = builder.build(&raw); + + assert!(matches!( + parsed, + ParsedMessage::Invalid { + reason: InvalidReason::InvalidField(20000), + .. + } + )); + } + + #[test] + fn test_unknown_user_defined_tag_at_body_accepted_when_validation_disabled() { + let raw = + build_pipe_separated_message("35=D|49=SENDER|56=TARGET|55=AAPL|59=0|20000=custom|"); + let builder = MessageBuilder::new(Dictionary::fix44(), CONFIG_NO_USER_VALIDATION).unwrap(); + let message = builder.build(&raw).into_message().unwrap(); + + let custom_tag = TagU32::new(20000).unwrap(); + let custom_value = message.get_field_map().get_raw(custom_tag).unwrap(); + assert_eq!(custom_value, b"custom"); + + // Known fields should still be reachable. + assert_eq!(message.get::<&str>(fix44::SYMBOL).unwrap(), "AAPL"); + } + + #[test] + fn test_unknown_user_defined_tag_inside_group_stays_in_group_when_validation_disabled() { + let raw = build_pipe_separated_message( + "35=D|49=SENDER|56=TARGET|453=1|448=PARTYA|447=D|452=1|20000=custom|55=AAPL|", + ); + let builder = MessageBuilder::new(Dictionary::fix44(), CONFIG_NO_USER_VALIDATION).unwrap(); + let message = builder.build(&raw).into_message().unwrap(); + + // 20000 doesn't belong to the parent message and isn't a + // trailer field, so with user-defined validation disabled it stays + // inside the current group rather than leaking to the body. + let party = message.get_group(fix44::NO_PARTY_I_DS, 0).unwrap(); + assert_eq!(party.get::<&str>(fix44::PARTY_ID).unwrap(), "PARTYA"); + assert!(message.get_group(fix44::NO_PARTY_I_DS, 1).is_none()); + + let custom_tag = TagU32::new(20000).unwrap(); + let in_group = party.get_field_map().get_raw(custom_tag).unwrap(); + assert_eq!(in_group, b"custom"); + + // The body must NOT have picked up the unknown tag. + assert!(message.get_field_map().get_raw(custom_tag).is_none()); + + // Tag 55 belongs to NewOrderSingle, so it ends the group and lands on the body. + assert_eq!(message.get::<&str>(fix44::SYMBOL).unwrap(), "AAPL"); + } + + #[test] + fn test_unknown_user_defined_tag_inside_group_rejected_when_validation_enabled() { + let raw = build_pipe_separated_message( + "35=D|49=SENDER|56=TARGET|453=1|448=PARTYA|447=D|452=1|20000=custom|55=AAPL|", + ); + let builder = MessageBuilder::new(Dictionary::fix44(), CONFIG).unwrap(); + let parsed = builder.build(&raw); + + assert!(matches!( + parsed, + ParsedMessage::Invalid { + reason: InvalidReason::InvalidField(20000), + .. + } + )); + } + + #[test] + fn test_unknown_user_defined_tag_inside_nested_group_stays_in_nested_group() { + // 20000 appears inside a NoPartySubIDs entry. NoPartySubIDs is the + // immediate parent (a Group) and doesn't contain 20000, so the field + // is kept inside the nested entry. The trailer (10) ends both the + // nested group and the outer NoPartyIDs group cleanly. + let raw = build_pipe_separated_message( + "35=D|49=SENDER|56=TARGET|453=1|448=PARTYA|447=D|452=1|802=1|523=SUB1|803=1|20000=custom|", + ); + let builder = MessageBuilder::new(Dictionary::fix44(), CONFIG_NO_USER_VALIDATION).unwrap(); + let message = builder.build(&raw).into_message().unwrap(); + + let party = message.get_group(fix44::NO_PARTY_I_DS, 0).unwrap(); + let sub = party.get_group(fix44::NO_PARTY_SUB_I_DS.tag(), 0).unwrap(); + assert_eq!(sub.get::<&str>(fix44::PARTY_SUB_ID).unwrap(), "SUB1"); + + let custom_tag = TagU32::new(20000).unwrap(); + // 20000 lives inside the nested NoPartySubIDs entry... + let in_sub = sub.get_field_map().get_raw(custom_tag).unwrap(); + assert_eq!(in_sub, b"custom"); + // ...not on the outer NoPartyIDs entry... + assert!(party.get_field_map().get_raw(custom_tag).is_none()); + // ...nor on the body. + assert!(message.get_field_map().get_raw(custom_tag).is_none()); + } + + #[test] + fn test_group_rejects_when_first_tag_is_not_delimiter() { + // 453 (NoPartyIDs) declares the count but the first tag after it is + // 20000 instead of the delimiter 448. This is a + // hard reject regardless of user-defined-field validation. + let raw = build_pipe_separated_message( + "35=D|49=SENDER|56=TARGET|453=1|20000=x|448=PARTYA|447=D|452=1|", + ); + let builder = MessageBuilder::new(Dictionary::fix44(), CONFIG_NO_USER_VALIDATION).unwrap(); + let parsed = builder.build(&raw); + + assert!(matches!( + parsed, + ParsedMessage::Invalid { + reason: InvalidReason::InvalidOrderInGroup { + tag: 20000, + group_tag: 453, + }, + .. + } + )); + } + + #[test] + fn test_known_outer_field_after_group_fields_ends_group() { + let raw = build_pipe_separated_message( + "35=D|49=SENDER|56=TARGET|453=1|448=PARTYA|447=D|452=1|55=AAPL|59=0|", + ); + let builder = MessageBuilder::new(Dictionary::fix44(), CONFIG).unwrap(); + let message = builder.build(&raw).into_message().unwrap(); + + let party = message.get_group(fix44::NO_PARTY_I_DS, 0).unwrap(); + assert_eq!(party.get::<&str>(fix44::PARTY_ID).unwrap(), "PARTYA"); + assert!(message.get_group(fix44::NO_PARTY_I_DS, 1).is_none()); + + // Symbol must not have been pulled into the group. + assert!(party.get_raw(fix44::SYMBOL).is_none()); + assert_eq!(message.get::<&str>(fix44::SYMBOL).unwrap(), "AAPL"); + assert_eq!(message.get::<&str>(fix44::TIME_IN_FORCE).unwrap(), "0"); + } + + #[test] + fn test_fields_declared_in_group_still_parsed_inside_when_validation_disabled() { + // Same nested-group payload as `nested_repeating_group_entries`, but with + // user-defined-field validation disabled. Tags that *are* declared inside + // the group dictionary (including nested NoPartySubIDs) must still be + // placed inside the group rather than leaking to the body. + let raw = b"8=FIX.4.4|9=247|35=8|34=2|49=Broker|52=20231103-09:30:00|56=Client|11=Order12345|17=Exec12345|150=0|39=0|55=APPL|54=1|38=100|32=50|31=150.00|151=50|14=50|6=150.00|453=2|448=PARTYA|447=D|452=1|802=2|523=SUBPARTYA1|803=1|523=SUBPARTYA2|803=2|448=PARTYB|447=D|452=2|10=129|"; + let builder = MessageBuilder::new(Dictionary::fix44(), CONFIG_NO_USER_VALIDATION).unwrap(); + let message = builder.build(raw).into_message().unwrap(); + + let party_a = message.get_group(fix44::NO_PARTY_I_DS, 0).unwrap(); + let sub = party_a + .get_group(fix44::NO_PARTY_SUB_I_DS.tag(), 0) + .unwrap(); + assert_eq!(sub.get::<&str>(fix44::PARTY_SUB_ID).unwrap(), "SUBPARTYA1"); + + let party_b = message.get_group(fix44::NO_PARTY_I_DS, 1).unwrap(); + assert_eq!(party_b.get::<&str>(fix44::PARTY_ID).unwrap(), "PARTYB"); + } + + #[test] + fn test_group_entry_missing_required_field_is_rejected() { + let dict = Dictionary::fix44(); + let builder = MessageBuilder::new(dict, CONFIG).unwrap(); + + // Find any (msg_type, group, required-non-delim-field) combination so + // we can construct a minimal message that's structurally well-formed + // except for the missing group field. + let candidate = builder + .message_specification + .iter() + .find_map(|(msg_type, message_def)| { + message_def + .groups + .iter() + .find_map(|(group_tag, group_spec)| { + let required_non_delim = group_spec + .fields() + .iter() + .find(|f| f.is_required && f.tag != group_spec.delimiter_tag())?; + Some(( + msg_type.clone(), + group_tag.get(), + group_spec.delimiter_tag().get(), + required_non_delim.tag.get(), + )) + }) + }); + + let Some((msg_type, group_tag, delimiter_tag, missing_tag)) = candidate else { + // The dictionary has no group with a required non-delimiter field. + return; + }; + + let body = format!( + "35={msg_type}|49=S|56=T|34=1|52=20231103-12:00:00|{group_tag}=1|{delimiter_tag}=X|" + ); + let raw = build_pipe_separated_message(&body); + let parsed = builder.build(&raw); + + match parsed { + ParsedMessage::Invalid { + reason: + InvalidReason::RequiredFieldMissing { + tag, + group_tag: Some(g), + }, + .. + } => { + assert_eq!(tag, missing_tag); + assert_eq!(g, group_tag); + } + other => panic!( + "expected RequiredFieldMissing(tag={missing_tag}, group_tag={group_tag}); \ + msg_type={msg_type}, delimiter={delimiter_tag}, got: {}", + match &other { + ParsedMessage::Valid(_) => "Valid".to_string(), + ParsedMessage::Invalid { reason, .. } => match reason { + InvalidReason::InvalidField(t) => format!("InvalidField({t})"), + InvalidReason::InvalidGroup(t) => format!("InvalidGroup({t})"), + InvalidReason::InvalidOrderInGroup { tag, group_tag } => { + format!("InvalidOrderInGroup(tag={tag}, group_tag={group_tag})") + } + InvalidReason::InvalidComponent(s) => format!("InvalidComponent({s})"), + InvalidReason::InvalidMsgType(s) => format!("InvalidMsgType({s})"), + InvalidReason::RequiredFieldMissing { tag, group_tag } => { + format!("RequiredFieldMissing(tag={tag}, group_tag={group_tag:?})") + } + }, + ParsedMessage::Garbled(_) => "Garbled".to_string(), + ParsedMessage::UnexpectedError(_) => "UnexpectedError".to_string(), + } + ), + } + } } diff --git a/crates/hotfix-message/src/encoder.rs b/crates/hotfix-message/src/encoder.rs index 31fd073b..ec48bd4f 100644 --- a/crates/hotfix-message/src/encoder.rs +++ b/crates/hotfix-message/src/encoder.rs @@ -75,7 +75,7 @@ mod tests { msg.set(fix44::PRICE, 150); msg.set(fix44::ORDER_QTY, 60); - let config = Config { separator: b'|' }; + let config = Config::with_separator(b'|'); let raw_message = msg.encode(&config)?; let builder = MessageBuilder::new(Dictionary::fix44(), config)?; @@ -146,7 +146,7 @@ mod tests { party_2.store_field(Field::new(fix44::PARTY_ROLE.tag(), b"2".to_vec())); msg.body.set_groups(vec![party_1, party_2])?; - let config = Config { separator: b'|' }; + let config = Config::with_separator(b'|'); let raw_message = msg.encode(&config)?; let builder = MessageBuilder::new(Dictionary::fix44(), config)?; diff --git a/crates/hotfix-message/src/error.rs b/crates/hotfix-message/src/error.rs index 360331d8..f88c6a83 100644 --- a/crates/hotfix-message/src/error.rs +++ b/crates/hotfix-message/src/error.rs @@ -27,6 +27,11 @@ pub enum ParserError { InvalidComponent(String), #[error("MsgType {0} is not a valid message type")] InvalidMsgType(String), + #[error( + "required field (tag = {tag}) is missing{}", + match group_tag { Some(g) => format!(" from repeating group (group tag = {g})"), None => String::new() } + )] + RequiredFieldMissing { tag: u32, group_tag: Option }, #[error("malformed message: {0}")] Malformed(String), } diff --git a/crates/hotfix-message/src/message.rs b/crates/hotfix-message/src/message.rs index 6fddf40e..fb18fe1f 100644 --- a/crates/hotfix-message/src/message.rs +++ b/crates/hotfix-message/src/message.rs @@ -119,16 +119,28 @@ impl Part for Message { #[derive(Clone, Copy)] pub struct Config { pub(crate) separator: u8, + pub(crate) validate_user_defined_fields: bool, } impl Config { pub const fn with_separator(separator: u8) -> Self { - Self { separator } + Self { + separator, + validate_user_defined_fields: true, + } + } + + pub const fn validate_user_defined_fields(mut self, value: bool) -> Self { + self.validate_user_defined_fields = value; + self } } impl Default for Config { fn default() -> Self { - Self { separator: SOH } + Self { + separator: SOH, + validate_user_defined_fields: true, + } } } diff --git a/crates/hotfix-message/src/parsed_message.rs b/crates/hotfix-message/src/parsed_message.rs index 3b0b0175..59f20bfc 100644 --- a/crates/hotfix-message/src/parsed_message.rs +++ b/crates/hotfix-message/src/parsed_message.rs @@ -27,6 +27,7 @@ pub enum InvalidReason { InvalidOrderInGroup { tag: u32, group_tag: u32 }, InvalidComponent(String), InvalidMsgType(String), + RequiredFieldMissing { tag: u32, group_tag: Option }, } #[derive(Debug)] diff --git a/crates/hotfix/src/config.rs b/crates/hotfix/src/config.rs index 858516b3..b2e1e230 100644 --- a/crates/hotfix/src/config.rs +++ b/crates/hotfix/src/config.rs @@ -114,6 +114,63 @@ pub struct SessionConfig { /// The schedule configuration for the session pub schedule: Option, + + /// The validation configuration for the session + #[serde(default)] + pub validation: ValidationConfig, +} + +#[derive(Clone, Debug, Deserialize)] +/// The configuration of validation rules. +pub struct ValidationConfig { + /// Specifies whether unknown user-defined tags (>= 5000) should cause the message to be rejected. + #[serde(default = "default_true")] + pub validate_user_defined_fields: bool, +} + +impl ValidationConfig { + pub fn builder() -> VerificationConfigBuilder { + VerificationConfigBuilder::default() + } +} + +impl Default for ValidationConfig { + fn default() -> Self { + VerificationConfigBuilder::default().build() + } +} + +pub struct VerificationConfigBuilder { + validate_user_defined_fields: bool, +} + +impl Default for VerificationConfigBuilder { + fn default() -> Self { + Self { + validate_user_defined_fields: true, + } + } +} + +impl VerificationConfigBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn validate_user_defined_fields(mut self, value: bool) -> Self { + self.validate_user_defined_fields = value; + self + } + + pub fn build(self) -> ValidationConfig { + ValidationConfig { + validate_user_defined_fields: self.validate_user_defined_fields, + } + } +} + +fn default_true() -> bool { + true } /// Errors that may occur when loading configuration. @@ -170,6 +227,7 @@ reset_on_logon = false assert_eq!(session_config.tls_config, Some(expected_tls_config)); assert_eq!(session_config.reconnect_interval, 30); assert_eq!(session_config.logon_timeout, 10); + assert!(session_config.validation.validate_user_defined_fields); } #[test] @@ -439,6 +497,45 @@ end_day = "Friday" assert_eq!(session_config.reconnect_interval, 15); } + #[test] + fn test_verification_config_defaults_when_omitted() { + let config_contents = r#" +[[sessions]] +begin_string = "FIX.4.4" +sender_comp_id = "send-comp-id" +target_comp_id = "target-comp-id" +connection_port = 443 +connection_host = "127.0.0.1" +heartbeat_interval = 30 + "#; + + let config: Config = toml::from_str(config_contents).unwrap(); + let session_config = config.sessions.first().unwrap(); + + assert!(session_config.validation.validate_user_defined_fields); + } + + #[test] + fn test_verification_config_can_disable_user_defined_field_validation() { + let config_contents = r#" +[[sessions]] +begin_string = "FIX.4.4" +sender_comp_id = "send-comp-id" +target_comp_id = "target-comp-id" +connection_port = 443 +connection_host = "127.0.0.1" +heartbeat_interval = 30 + +[sessions.validation] +validate_user_defined_fields = false + "#; + + let config: Config = toml::from_str(config_contents).unwrap(); + let session_config = config.sessions.first().unwrap(); + + assert!(!session_config.validation.validate_user_defined_fields); + } + #[test] fn test_load_from_path_success() { let config_contents = r#" diff --git a/crates/hotfix/src/initiator.rs b/crates/hotfix/src/initiator.rs index b86bdca0..782df5d4 100644 --- a/crates/hotfix/src/initiator.rs +++ b/crates/hotfix/src/initiator.rs @@ -309,6 +309,7 @@ mod tests { reconnect_interval: 1, // Short for tests reset_on_logon: false, schedule: None, + validation: Default::default(), } } diff --git a/crates/hotfix/src/message/verification.rs b/crates/hotfix/src/message/verification.rs index ae038168..aaa1b760 100644 --- a/crates/hotfix/src/message/verification.rs +++ b/crates/hotfix/src/message/verification.rs @@ -238,6 +238,7 @@ fn check_target_comp_id( #[cfg(test)] mod tests { use super::{Message, SessionConfig, VerificationFlags, verify_message}; + use crate::config::ValidationConfig; use crate::message::sequence_reset::SequenceReset; use crate::message::verification_issue::{CompIdType, MessageError, VerificationIssue}; use hotfix_message::field_types::Timestamp; @@ -258,6 +259,7 @@ mod tests { reconnect_interval: 0, reset_on_logon: false, schedule: None, + validation: ValidationConfig::default(), } } diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index fc28d81a..440b87cb 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -79,7 +79,8 @@ where let schedule_check_timer = sleep(Duration::from_secs(SCHEDULE_CHECK_INTERVAL)); let dictionary = Self::get_data_dictionary(&config)?; - let message_config = MessageConfig::default(); + let message_config = MessageConfig::default() + .validate_user_defined_fields(config.validation.validate_user_defined_fields); let message_builder = MessageBuilder::new(dictionary, message_config)?; let schedule = config.schedule.as_ref().try_into()?; let ctx = SessionCtx { @@ -184,6 +185,27 @@ where } } } + InvalidReason::RequiredFieldMissing { tag, group_tag } => { + match message.header().get(MSG_SEQ_NUM) { + Ok(msg_seq_num) => { + let text = match group_tag { + Some(group_tag) => { + format!("required tag missing: {tag} (in group {group_tag})") + } + None => format!("required tag missing: {tag}"), + }; + let reject = Reject::new(msg_seq_num) + .session_reject_reason(SessionRejectReason::RequiredTagMissing) + .text(&text); + self.send_message(reject) + .await + .with_send_context("reject for missing required field")?; + } + Err(err) => { + error!("failed to get message seq num: {:?}", err); + } + } + } }, ParsedMessage::UnexpectedError(err) => { error!("unexpected error: {:?}", err); @@ -771,6 +793,7 @@ async fn run_session( mod tests { use super::*; use crate::application::{InboundDecision, OutboundDecision}; + use crate::config::ValidationConfig; use crate::message::OutboundMessage; use crate::store::{Result as StoreResult, StoreError}; use chrono::{DateTime, Datelike, NaiveDate, NaiveTime, TimeDelta, Timelike}; @@ -907,6 +930,7 @@ mod tests { reconnect_interval: 30, reset_on_logon: false, schedule: None, + validation: ValidationConfig::default(), } } diff --git a/crates/hotfix/src/session/test_utils.rs b/crates/hotfix/src/session/test_utils.rs index a7e6cfae..6db46498 100644 --- a/crates/hotfix/src/session/test_utils.rs +++ b/crates/hotfix/src/session/test_utils.rs @@ -91,6 +91,7 @@ pub(crate) fn create_test_ctx(store: FakeMessageStore) -> SessionCtx<(), FakeMes reconnect_interval: 30, reset_on_logon: false, schedule: None, + validation: Default::default(), }, store, application: (), diff --git a/crates/hotfix/tests/connection_test_cases/connect_tests.rs b/crates/hotfix/tests/connection_test_cases/connect_tests.rs index 4e677101..ff1d683a 100644 --- a/crates/hotfix/tests/connection_test_cases/connect_tests.rs +++ b/crates/hotfix/tests/connection_test_cases/connect_tests.rs @@ -26,6 +26,7 @@ fn create_session_config(host: &str, port: u16, tls_config: Option) - reconnect_interval: 30, reset_on_logon: false, schedule: None, + validation: Default::default(), } } diff --git a/crates/hotfix/tests/session_test_cases/common/setup.rs b/crates/hotfix/tests/session_test_cases/common/setup.rs index b6faeea7..7ff6c03f 100644 --- a/crates/hotfix/tests/session_test_cases/common/setup.rs +++ b/crates/hotfix/tests/session_test_cases/common/setup.rs @@ -119,6 +119,7 @@ pub fn create_session_config() -> SessionConfig { reconnect_interval: 30, reset_on_logon: false, schedule: None, + validation: Default::default(), } } diff --git a/examples/load-testing/src/main.rs b/examples/load-testing/src/main.rs index 976f05b1..b4d25685 100644 --- a/examples/load-testing/src/main.rs +++ b/examples/load-testing/src/main.rs @@ -3,7 +3,7 @@ mod messages; use anyhow::Result; use clap::{Parser, ValueEnum}; -use hotfix::config::SessionConfig; +use hotfix::config::{SessionConfig, ValidationConfig}; use hotfix::field_types::{Date, Timestamp}; use hotfix::fix44; use hotfix::fix44::OrdType; @@ -168,5 +168,6 @@ fn get_config() -> SessionConfig { reconnect_interval: 30, reset_on_logon: true, schedule: None, + validation: ValidationConfig::default(), } }