Skip to content

Commit c7a57c8

Browse files
committed
Drop wire::write and replace encode_msg! macro
Now that we consistently use `wire::Message` everywhere, it's easier to simply use `Message::write`/`Type::write` instead of heaving yet another `wire::write` around. Here we drop `wire::write`, replace the `encode_msg` macro with a method that takes `wire::Message`, and convert a bunch of additional places to move semantics.
1 parent d0e22fa commit c7a57c8

File tree

3 files changed

+39
-72
lines changed

3 files changed

+39
-72
lines changed

lightning/src/ln/peer_channel_encryptor.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ use crate::prelude::*;
1212
use crate::ln::msgs;
1313
use crate::ln::msgs::LightningError;
1414
use crate::ln::wire;
15+
use crate::ln::wire::Type;
1516
use crate::sign::{NodeSigner, Recipient};
17+
use crate::util::ser::Writeable;
1618

1719
use bitcoin::hashes::sha256::Hash as Sha256;
1820
use bitcoin::hashes::{Hash, HashEngine};
@@ -570,7 +572,9 @@ impl PeerChannelEncryptor {
570572
// for the 2-byte message type prefix and its MAC.
571573
let mut res = VecWriter(Vec::with_capacity(MSG_BUF_ALLOC_SIZE));
572574
res.0.resize(16 + 2, 0);
573-
wire::write(&message, &mut res).expect("In-memory messages must never fail to serialize");
575+
576+
message.type_id().write(&mut res).expect("In-memory messages must never fail to serialize");
577+
message.write(&mut res).expect("In-memory messages must never fail to serialize");
574578

575579
self.encrypt_message_with_header_0s(&mut res.0);
576580
res.0

lightning/src/ln/peer_handler.rs

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,12 +1144,11 @@ impl From<LightningError> for MessageHandlingError {
11441144
}
11451145
}
11461146

1147-
macro_rules! encode_msg {
1148-
($msg: expr) => {{
1149-
let mut buffer = VecWriter(Vec::with_capacity(MSG_BUF_ALLOC_SIZE));
1150-
wire::write($msg, &mut buffer).unwrap();
1151-
buffer.0
1152-
}};
1147+
fn encode_message<T: wire::Type>(message: wire::Message<T>) -> Vec<u8> {
1148+
let mut buffer = VecWriter(Vec::with_capacity(MSG_BUF_ALLOC_SIZE));
1149+
message.type_id().write(&mut buffer).expect("In-memory messages must never fail to serialize");
1150+
message.write(&mut buffer).expect("In-memory messages must never fail to serialize");
1151+
buffer.0
11531152
}
11541153

11551154
impl<Descriptor: SocketDescriptor, CM: Deref, OM: Deref, L: Deref, NS: Deref, SM: Deref>
@@ -2071,7 +2070,7 @@ where
20712070
for msg in msgs_to_forward.drain(..) {
20722071
self.forward_broadcast_msg(
20732072
&*peers,
2074-
&msg,
2073+
msg,
20752074
peer_node_id.as_ref().map(|(pk, _)| pk),
20762075
false,
20772076
);
@@ -2665,22 +2664,25 @@ where
26652664
/// unless `allow_large_buffer` is set, in which case the message will be treated as critical
26662665
/// and delivered no matter the available buffer space.
26672666
fn forward_broadcast_msg(
2668-
&self, peers: &HashMap<Descriptor, Mutex<Peer>>, msg: &BroadcastGossipMessage,
2667+
&self, peers: &HashMap<Descriptor, Mutex<Peer>>, msg: BroadcastGossipMessage,
26692668
except_node: Option<&PublicKey>, allow_large_buffer: bool,
26702669
) {
26712670
match msg {
2672-
BroadcastGossipMessage::ChannelAnnouncement(ref msg) => {
2671+
BroadcastGossipMessage::ChannelAnnouncement(msg) => {
26732672
log_gossip!(self.logger, "Sending message to all peers except {:?} or the announced channel's counterparties: {:?}", except_node, msg);
2674-
let encoded_msg = encode_msg!(msg);
26752673
let our_channel = self.our_node_id == msg.contents.node_id_1
26762674
|| self.our_node_id == msg.contents.node_id_2;
2677-
2675+
let scid = msg.contents.short_channel_id;
2676+
let node_id_1 = msg.contents.node_id_1;
2677+
let node_id_2 = msg.contents.node_id_2;
2678+
let msg: Message<<CMH::Target as CustomMessageReader>::CustomMessage> =
2679+
Message::ChannelAnnouncement(msg);
2680+
let encoded_msg = encode_message(msg);
26782681
for (_, peer_mutex) in peers.iter() {
26792682
let mut peer = peer_mutex.lock().unwrap();
26802683
if !peer.handshake_complete() {
26812684
continue;
26822685
}
2683-
let scid = msg.contents.short_channel_id;
26842686
if !our_channel && !peer.should_forward_channel_announcement(scid) {
26852687
continue;
26862688
}
@@ -2697,9 +2699,7 @@ where
26972699
continue;
26982700
}
26992701
if let Some((_, their_node_id)) = peer.their_node_id {
2700-
if their_node_id == msg.contents.node_id_1
2701-
|| their_node_id == msg.contents.node_id_2
2702-
{
2702+
if their_node_id == node_id_1 || their_node_id == node_id_2 {
27032703
continue;
27042704
}
27052705
}
@@ -2712,23 +2712,25 @@ where
27122712
peer.gossip_broadcast_buffer.push_back(encoded_message);
27132713
}
27142714
},
2715-
BroadcastGossipMessage::NodeAnnouncement(ref msg) => {
2715+
BroadcastGossipMessage::NodeAnnouncement(msg) => {
27162716
log_gossip!(
27172717
self.logger,
27182718
"Sending message to all peers except {:?} or the announced node: {:?}",
27192719
except_node,
27202720
msg
27212721
);
2722-
let encoded_msg = encode_msg!(msg);
27232722
let our_announcement = self.our_node_id == msg.contents.node_id;
2723+
let msg_node_id = msg.contents.node_id;
27242724

2725+
let msg: Message<<CMH::Target as CustomMessageReader>::CustomMessage> =
2726+
Message::NodeAnnouncement(msg);
2727+
let encoded_msg = encode_message(msg);
27252728
for (_, peer_mutex) in peers.iter() {
27262729
let mut peer = peer_mutex.lock().unwrap();
27272730
if !peer.handshake_complete() {
27282731
continue;
27292732
}
2730-
let node_id = msg.contents.node_id;
2731-
if !our_announcement && !peer.should_forward_node_announcement(node_id) {
2733+
if !our_announcement && !peer.should_forward_node_announcement(msg_node_id) {
27322734
continue;
27332735
}
27342736
debug_assert!(peer.their_node_id.is_some());
@@ -2744,7 +2746,7 @@ where
27442746
continue;
27452747
}
27462748
if let Some((_, their_node_id)) = peer.their_node_id {
2747-
if their_node_id == msg.contents.node_id {
2749+
if their_node_id == msg_node_id {
27482750
continue;
27492751
}
27502752
}
@@ -2764,15 +2766,16 @@ where
27642766
except_node,
27652767
msg
27662768
);
2767-
let encoded_msg = encode_msg!(msg);
2768-
let our_channel = self.our_node_id == *node_id_1 || self.our_node_id == *node_id_2;
2769-
2769+
let our_channel = self.our_node_id == node_id_1 || self.our_node_id == node_id_2;
2770+
let scid = msg.contents.short_channel_id;
2771+
let msg: Message<<CMH::Target as CustomMessageReader>::CustomMessage> =
2772+
Message::ChannelUpdate(msg);
2773+
let encoded_msg = encode_message(msg);
27702774
for (_, peer_mutex) in peers.iter() {
27712775
let mut peer = peer_mutex.lock().unwrap();
27722776
if !peer.handshake_complete() {
27732777
continue;
27742778
}
2775-
let scid = msg.contents.short_channel_id;
27762779
if !our_channel && !peer.should_forward_channel_announcement(scid) {
27772780
continue;
27782781
}
@@ -3251,7 +3254,7 @@ where
32513254
let forward = BroadcastGossipMessage::ChannelAnnouncement(msg);
32523255
self.forward_broadcast_msg(
32533256
peers,
3254-
&forward,
3257+
forward,
32553258
None,
32563259
from_chan_handler,
32573260
);
@@ -3272,7 +3275,7 @@ where
32723275
};
32733276
self.forward_broadcast_msg(
32743277
peers,
3275-
&forward,
3278+
forward,
32763279
None,
32773280
from_chan_handler,
32783281
);
@@ -3296,7 +3299,7 @@ where
32963299
};
32973300
self.forward_broadcast_msg(
32983301
peers,
3299-
&forward,
3302+
forward,
33003303
None,
33013304
from_chan_handler,
33023305
);
@@ -3315,7 +3318,7 @@ where
33153318
let forward = BroadcastGossipMessage::NodeAnnouncement(msg);
33163319
self.forward_broadcast_msg(
33173320
peers,
3318-
&forward,
3321+
forward,
33193322
None,
33203323
from_chan_handler,
33213324
);
@@ -3803,7 +3806,7 @@ where
38033806
let _ = self.message_handler.route_handler.handle_node_announcement(None, &msg);
38043807
self.forward_broadcast_msg(
38053808
&*self.peers.read().unwrap(),
3806-
&BroadcastGossipMessage::NodeAnnouncement(msg),
3809+
BroadcastGossipMessage::NodeAnnouncement(msg),
38073810
None,
38083811
true,
38093812
);
@@ -4618,7 +4621,8 @@ mod tests {
46184621
assert_eq!(peer.gossip_broadcast_buffer.len(), 1);
46194622

46204623
let pending_msg = &peer.gossip_broadcast_buffer[0];
4621-
let expected = encode_msg!(&msg_100);
4624+
let msg: Message<()> = Message::ChannelUpdate(msg_100);
4625+
let expected = encode_message(msg);
46224626
assert_eq!(expected, pending_msg.fetch_encoded_msg_with_type_pfx());
46234627
}
46244628
}

lightning/src/ln/wire.rs

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -425,19 +425,6 @@ where
425425
}
426426
}
427427

428-
/// Writes a message to the data buffer encoded as a 2-byte big-endian type and a variable-length
429-
/// payload.
430-
///
431-
/// # Errors
432-
///
433-
/// Returns an I/O error if the write could not be completed.
434-
pub(crate) fn write<M: Type + Writeable, W: Writer>(
435-
message: &M, buffer: &mut W,
436-
) -> Result<(), io::Error> {
437-
message.type_id().write(buffer)?;
438-
message.write(buffer)
439-
}
440-
441428
mod encode {
442429
/// Defines a constant type identifier for reading messages from the wire.
443430
pub trait Encode {
@@ -737,34 +724,6 @@ mod tests {
737724
}
738725
}
739726

740-
#[test]
741-
fn write_message_with_type() {
742-
let message = msgs::Pong { byteslen: 2u16 };
743-
let mut buffer = Vec::new();
744-
assert!(write(&message, &mut buffer).is_ok());
745-
746-
let type_length = ::core::mem::size_of::<u16>();
747-
let (type_bytes, payload_bytes) = buffer.split_at(type_length);
748-
assert_eq!(u16::from_be_bytes(type_bytes.try_into().unwrap()), msgs::Pong::TYPE);
749-
assert_eq!(payload_bytes, &ENCODED_PONG[type_length..]);
750-
}
751-
752-
#[test]
753-
fn read_message_encoded_with_write() {
754-
let message = msgs::Pong { byteslen: 2u16 };
755-
let mut buffer = Vec::new();
756-
assert!(write(&message, &mut buffer).is_ok());
757-
758-
let decoded_message = read(&mut &buffer[..], &IgnoringMessageHandler {}).unwrap();
759-
match decoded_message {
760-
Message::Pong(msgs::Pong { byteslen: 2u16 }) => (),
761-
Message::Pong(msgs::Pong { byteslen }) => {
762-
panic!("Expected byteslen {}; found: {}", message.byteslen, byteslen);
763-
},
764-
_ => panic!("Expected pong message; found message type: {}", decoded_message.type_id()),
765-
}
766-
}
767-
768727
#[test]
769728
fn is_even_message_type() {
770729
let message = Message::<()>::Unknown(42);

0 commit comments

Comments
 (0)