diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 284cc42212..e7ea2ac77b 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -11,7 +11,9 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightninglabs/neutrino/cache" @@ -166,14 +168,9 @@ type PinnedSyncers map[route.Vertex]struct{} // Config defines the configuration for the service. ALL elements within the // configuration MUST be non-nil for the service to carry out its duties. type Config struct { - // ChainHash is a hash that indicates which resident chain of the - // AuthenticatedGossiper. Any announcements that don't match this - // chain hash will be ignored. - // - // TODO(roasbeef): eventually make into map so can de-multiplex - // incoming announcements - // * also need to do same for Notifier - ChainHash chainhash.Hash + // ChainParams holds the chain parameters for the active network this + // node is participating on. + ChainParams *chaincfg.Params // Graph is the subsystem which is responsible for managing the // topology of lightning network. After incoming channel, node, channel @@ -359,6 +356,12 @@ type Config struct { // updates for a channel and returns true if the channel should be // considered a zombie based on these timestamps. IsStillZombieChannel func(time.Time, time.Time) bool + + // chainHash is a hash that indicates which resident chain of the + // AuthenticatedGossiper. Any announcements that don't match this + // chain hash will be ignored. This is an internal config value obtained + // from ChainParams. + chainHash *chainhash.Hash } // processedNetworkMsg is a wrapper around networkMsg and a boolean. It is @@ -518,6 +521,8 @@ type AuthenticatedGossiper struct { // New creates a new AuthenticatedGossiper instance, initialized with the // passed configuration parameters. func New(cfg Config, selfKeyDesc *keychain.KeyDescriptor) *AuthenticatedGossiper { + cfg.chainHash = cfg.ChainParams.GenesisHash + gossiper := &AuthenticatedGossiper{ selfKey: selfKeyDesc.PubKey, selfKeyLoc: selfKeyDesc.KeyLocator, @@ -538,7 +543,7 @@ func New(cfg Config, selfKeyDesc *keychain.KeyDescriptor) *AuthenticatedGossiper } gossiper.syncMgr = newSyncManager(&SyncManagerCfg{ - ChainHash: cfg.ChainHash, + ChainHash: *cfg.chainHash, ChanSeries: cfg.ChanSeries, RotateTicker: cfg.RotateTicker, HistoricalSyncTicker: cfg.HistoricalSyncTicker, @@ -1946,9 +1951,28 @@ func (d *AuthenticatedGossiper) processRejectedEdge( // fetchPKScript fetches the output script for the given SCID. func (d *AuthenticatedGossiper) fetchPKScript(chanID *lnwire.ShortChannelID) ( - []byte, error) { + txscript.ScriptClass, btcutil.Address, error) { + + pkScript, err := lnwallet.FetchPKScriptWithQuit( + d.cfg.ChainIO, chanID, d.quit, + ) + if err != nil { + return txscript.WitnessUnknownTy, nil, err + } + + scriptClass, addrs, _, err := txscript.ExtractPkScriptAddrs( + pkScript, d.cfg.ChainParams, + ) + if err != nil { + return txscript.WitnessUnknownTy, nil, err + } + + if len(addrs) != 1 { + return txscript.WitnessUnknownTy, nil, fmt.Errorf("expected "+ + "1 address, got: %d", len(addrs)) + } - return lnwallet.FetchPKScriptWithQuit(d.cfg.ChainIO, chanID, d.quit) + return scriptClass, addrs[0], nil } // addNode processes the given node announcement, and adds it to our channel @@ -2448,10 +2472,10 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // We'll ignore any channel announcements that target any chain other // than the set of chains we know of. - if !bytes.Equal(ann.ChainHash[:], d.cfg.ChainHash[:]) { + if !bytes.Equal(ann.ChainHash[:], d.cfg.chainHash[:]) { err := fmt.Errorf("ignoring ChannelAnnouncement1 from chain=%v"+ ", gossiper on chain=%v", ann.ChainHash, - d.cfg.ChainHash) + d.cfg.chainHash) log.Errorf(err.Error()) key := newRejectCacheKey( @@ -2837,9 +2861,9 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // We'll ignore any channel updates that target any chain other than // the set of chains we know of. - if !bytes.Equal(upd.ChainHash[:], d.cfg.ChainHash[:]) { + if !bytes.Equal(upd.ChainHash[:], d.cfg.chainHash[:]) { err := fmt.Errorf("ignoring ChannelUpdate from chain=%v, "+ - "gossiper on chain=%v", upd.ChainHash, d.cfg.ChainHash) + "gossiper on chain=%v", upd.ChainHash, d.cfg.chainHash) log.Errorf(err.Error()) key := newRejectCacheKey( diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index db632cdafe..ed209102f9 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -16,6 +16,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" @@ -766,7 +767,8 @@ func createTestCtx(t *testing.T, startHeight uint32, isChanPeer bool) ( } gossiper := New(Config{ - Notifier: notifier, + ChainParams: &chaincfg.MainNetParams, + Notifier: notifier, Broadcast: func(senders map[route.Vertex]struct{}, msgs ...lnwire.Message) error { @@ -1480,6 +1482,7 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { //nolint:lll gossiper := New(Config{ + ChainParams: &chaincfg.MainNetParams, Notifier: ctx.gossiper.cfg.Notifier, Broadcast: ctx.gossiper.cfg.Broadcast, NotifyWhenOnline: ctx.gossiper.reliableSender.cfg.NotifyWhenOnline, diff --git a/go.mod b/go.mod index 3a559abb68..56e64e7a3b 100644 --- a/go.mod +++ b/go.mod @@ -210,3 +210,5 @@ replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-d go 1.22.6 retract v0.0.2 + +replace github.com/lightningnetwork/lnd/tlv => github.com/ellemouton/lnd/tlv v0.0.0-20241012094556-298ff9eaed58 diff --git a/go.sum b/go.sum index ea51f3e958..9db5b97bea 100644 --- a/go.sum +++ b/go.sum @@ -185,6 +185,8 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4 github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/ellemouton/lnd/tlv v0.0.0-20241012094556-298ff9eaed58 h1:j+Sr9J/exZ3CAx6GaToJKUUkll385erxNiwTwZqegec= +github.com/ellemouton/lnd/tlv v0.0.0-20241012094556-298ff9eaed58/go.mod h1:/CmY4VbItpOldksocmGT4lxiJqRP9oLxwSZOda2kzNQ= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= @@ -465,8 +467,6 @@ github.com/lightningnetwork/lnd/sqldb v1.0.4 h1:9cMwPxcrLQG8UmyZO4q8SpR7NmxSwBMb github.com/lightningnetwork/lnd/sqldb v1.0.4/go.mod h1:4cQOkdymlZ1znnjuRNvMoatQGJkRneTj2CoPSPaQhWo= github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM= github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA= -github.com/lightningnetwork/lnd/tlv v1.2.6 h1:icvQG2yDr6k3ZuZzfRdG3EJp6pHurcuh3R6dg0gv/Mw= -github.com/lightningnetwork/lnd/tlv v1.2.6/go.mod h1:/CmY4VbItpOldksocmGT4lxiJqRP9oLxwSZOda2kzNQ= github.com/lightningnetwork/lnd/tor v1.1.2 h1:3zv9z/EivNFaMF89v3ciBjCS7kvCj4ZFG7XvD2Qq0/k= github.com/lightningnetwork/lnd/tor v1.1.2/go.mod h1:j7T9uJ2NLMaHwE7GiBGnpYLn4f7NRoTM6qj+ul6/ycA= github.com/ltcsuite/ltcd v0.0.0-20190101042124-f37f8bf35796 h1:sjOGyegMIhvgfq5oaue6Td+hxZuf3tDC8lAPrFldqFw= diff --git a/lnwire/announcement_signatures_2.go b/lnwire/announcement_signatures_2.go index a104470321..526b995485 100644 --- a/lnwire/announcement_signatures_2.go +++ b/lnwire/announcement_signatures_2.go @@ -3,6 +3,8 @@ package lnwire import ( "bytes" "io" + + "github.com/lightningnetwork/lnd/tlv" ) // AnnounceSignatures2 is a direct message between two endpoints of a @@ -14,27 +16,40 @@ type AnnounceSignatures2 struct { // Channel id is better for users and debugging and short channel id is // used for quick test on existence of the particular utxo inside the // blockchain, because it contains information about block. - ChannelID ChannelID + ChannelID tlv.RecordT[tlv.TlvType0, ChannelID] // ShortChannelID is the unique description of the funding transaction. // It is constructed with the most significant 3 bytes as the block // height, the next 3 bytes indicating the transaction index within the // block, and the least significant two bytes indicating the output // index which pays to the channel. - ShortChannelID ShortChannelID + ShortChannelID tlv.RecordT[tlv.TlvType2, ShortChannelID] // PartialSignature is the combination of the partial Schnorr signature // created for the node's bitcoin key with the partial signature created // for the node's node ID key. - PartialSignature PartialSig - - // ExtraOpaqueData is the set of data that was appended to this - // message, some of which we may not actually know how to iterate or - // parse. By holding onto this data, we ensure that we're able to - // properly validate the set of signatures that cover these new fields, - // and ensure we're able to make upgrades to the network in a forwards - // compatible manner. - ExtraOpaqueData ExtraOpaqueData + PartialSignature tlv.RecordT[tlv.TlvType4, PartialSig] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraFieldsInSignedRange map[uint64][]byte +} + +// NewAnnSigs2 is a constructor for AnnounceSignatures2. +func NewAnnSigs2(chanID ChannelID, scid ShortChannelID, + partialSig PartialSig) *AnnounceSignatures2 { + + return &AnnounceSignatures2{ + ChannelID: tlv.NewRecordT[tlv.TlvType0, ChannelID](chanID), + ShortChannelID: tlv.NewRecordT[tlv.TlvType2, ShortChannelID]( + scid, + ), + PartialSignature: tlv.NewRecordT[tlv.TlvType4, PartialSig]( + partialSig, + ), + ExtraFieldsInSignedRange: make(map[uint64][]byte, 0), + } } // A compile time check to ensure AnnounceSignatures2 implements the @@ -46,32 +61,29 @@ var _ Message = (*AnnounceSignatures2)(nil) // // This is part of the lnwire.Message interface. func (a *AnnounceSignatures2) Decode(r io.Reader, _ uint32) error { - return ReadElements(r, - &a.ChannelID, - &a.ShortChannelID, - &a.PartialSignature, - &a.ExtraOpaqueData, - ) -} - -// Encode serializes the target AnnounceSignatures2 into the passed io.Writer -// observing the protocol version specified. -// -// This is part of the lnwire.Message interface. -func (a *AnnounceSignatures2) Encode(w *bytes.Buffer, _ uint32) error { - if err := WriteChannelID(w, a.ChannelID); err != nil { + stream, err := tlv.NewStream(ProduceRecordsSorted( + &a.ChannelID, &a.ShortChannelID, &a.PartialSignature, + )...) + if err != nil { return err } - if err := WriteShortChannelID(w, a.ShortChannelID); err != nil { + typeMap, err := stream.DecodeWithParsedTypesP2P(r) + if err != nil { return err } - if err := WriteElement(w, a.PartialSignature); err != nil { - return err - } + a.ExtraFieldsInSignedRange = ExtraSignedFieldsFromTypeMap(typeMap) + + return nil +} - return WriteBytes(w, a.ExtraOpaqueData) +// Encode serializes the target AnnounceSignatures2 into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (a *AnnounceSignatures2) Encode(w *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(a, w) } // MsgType returns the integer uniquely identifying this message type on the @@ -82,16 +94,34 @@ func (a *AnnounceSignatures2) MsgType() MessageType { return MsgAnnounceSignatures2 } +// AllRecords returns all the TLV records for the message. This will include all +// the records we know about along with any that we don't know about but that +// fall in the signed TLV range. +// +// NOTE: this is part of the PureTLVMessage interface. +func (a *AnnounceSignatures2) AllRecords() []tlv.Record { + recordProducers := []tlv.RecordProducer{ + &a.ChannelID, &a.ShortChannelID, + &a.PartialSignature, + } + + recordProducers = append(recordProducers, RecordsAsProducers( + tlv.MapToRecords(a.ExtraFieldsInSignedRange), + )...) + + return ProduceRecordsSorted(recordProducers...) +} + // SCID returns the ShortChannelID of the channel. // // NOTE: this is part of the AnnounceSignatures interface. func (a *AnnounceSignatures2) SCID() ShortChannelID { - return a.ShortChannelID + return a.ShortChannelID.Val } // ChanID returns the ChannelID identifying the channel. // // NOTE: this is part of the AnnounceSignatures interface. func (a *AnnounceSignatures2) ChanID() ChannelID { - return a.ChannelID + return a.ChannelID.Val } diff --git a/lnwire/channel_announcement_2.go b/lnwire/channel_announcement_2.go index 074e7d0842..a236966689 100644 --- a/lnwire/channel_announcement_2.go +++ b/lnwire/channel_announcement_2.go @@ -12,9 +12,6 @@ import ( // ChannelAnnouncement2 message is used to announce the existence of a taproot // channel between two peers in the network. type ChannelAnnouncement2 struct { - // Signature is a Schnorr signature over the TLV stream of the message. - Signature Sig - // ChainHash denotes the target chain that this channel was opened // within. This value should be the genesis hash of the target chain. ChainHash tlv.RecordT[tlv.TlvType0, chainhash.Hash] @@ -59,47 +56,103 @@ type ChannelAnnouncement2 struct { // the funding output is a pure 2-of-2 MuSig aggregate public key. MerkleRootHash tlv.OptionalRecordT[tlv.TlvType16, [32]byte] - // ExtraOpaqueData is the set of data that was appended to this - // message, some of which we may not actually know how to iterate or - // parse. By holding onto this data, we ensure that we're able to - // properly validate the set of signatures that cover these new fields, - // and ensure we're able to make upgrades to the network in a forwards - // compatible manner. - ExtraOpaqueData ExtraOpaqueData + // Signature is a Schnorr signature over serialised signed-range TLV + // stream of the message. + Signature tlv.RecordT[tlv.TlvType160, Sig] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraFieldsInSignedRange map[uint64][]byte } -// Decode deserializes a serialized AnnounceSignatures1 stored in the passed -// io.Reader observing the specified protocol version. +// Encode serializes the target AnnounceSignatures1 into the passed io.Writer +// observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (c *ChannelAnnouncement2) Decode(r io.Reader, _ uint32) error { - err := ReadElement(r, &c.Signature) - if err != nil { - return err - } - c.Signature.ForceSchnorr() +func (c *ChannelAnnouncement2) Encode(w *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(c, w) +} + +// AllRecords returns all the TLV records for the message. This will include all +// the records we know about along with any that we don't know about but that +// fall in the signed TLV range. +// +// NOTE: this is part of the PureTLVMessage interface. +func (c *ChannelAnnouncement2) AllRecords() []tlv.Record { + recordProducers := append( + c.allNonSignatureRecordProducers(), &c.Signature, + ) - return c.DecodeTLVRecords(r) + return ProduceRecordsSorted(recordProducers...) } -// DecodeTLVRecords decodes only the TLV section of the message. -func (c *ChannelAnnouncement2) DecodeTLVRecords(r io.Reader) error { - // First extract into extra opaque data. - var tlvRecords ExtraOpaqueData - if err := ReadElements(r, &tlvRecords); err != nil { - return err +func (c *ChannelAnnouncement2) allNonSignatureRecordProducers() []tlv.RecordProducer { + // The chain-hash record is only included if it is _not_ equal to the + // bitcoin mainnet genisis block hash. + var recordProducers []tlv.RecordProducer + if !c.ChainHash.Val.IsEqual(chaincfg.MainNetParams.GenesisHash) { + hash := tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() + hash.Val = c.ChainHash.Val + + recordProducers = append(recordProducers, &hash) } + recordProducers = append(recordProducers, + &c.Features, &c.ShortChannelID, &c.Capacity, &c.NodeID1, + &c.NodeID2, + ) + + c.BitcoinKey1.WhenSome(func(key tlv.RecordT[tlv.TlvType12, [33]byte]) { + recordProducers = append(recordProducers, &key) + }) + + c.BitcoinKey2.WhenSome(func(key tlv.RecordT[tlv.TlvType14, [33]byte]) { + recordProducers = append(recordProducers, &key) + }) + + c.MerkleRootHash.WhenSome( + func(hash tlv.RecordT[tlv.TlvType16, [32]byte]) { + recordProducers = append(recordProducers, &hash) + }, + ) + + recordProducers = append(recordProducers, RecordsAsProducers( + tlv.MapToRecords(c.ExtraFieldsInSignedRange), + )...) + + return recordProducers +} + +// Decode deserializes a serialized AnnounceSignatures1 stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *ChannelAnnouncement2) Decode(r io.Reader, _ uint32) error { var ( chainHash = tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() btcKey1 = tlv.ZeroRecordT[tlv.TlvType12, [33]byte]() btcKey2 = tlv.ZeroRecordT[tlv.TlvType14, [33]byte]() merkleRootHash = tlv.ZeroRecordT[tlv.TlvType16, [32]byte]() ) - typeMap, err := tlvRecords.ExtractRecords( - &chainHash, &c.Features, &c.ShortChannelID, &c.Capacity, - &c.NodeID1, &c.NodeID2, &btcKey1, &btcKey2, &merkleRootHash, - ) + stream, err := tlv.NewStream(ProduceRecordsSorted( + &chainHash, + &c.Features, + &c.ShortChannelID, + &c.Capacity, + &c.NodeID1, + &c.NodeID2, + &btcKey1, + &btcKey2, + &merkleRootHash, + &c.Signature, + )...) + if err != nil { + return err + } + c.Signature.Val.ForceSchnorr() + + typeMap, err := stream.DecodeWithParsedTypesP2P(r) if err != nil { return err } @@ -122,68 +175,68 @@ func (c *ChannelAnnouncement2) DecodeTLVRecords(r io.Reader) error { c.MerkleRootHash = tlv.SomeRecordT(merkleRootHash) } - if len(tlvRecords) != 0 { - c.ExtraOpaqueData = tlvRecords - } + c.ExtraFieldsInSignedRange = ExtraSignedFieldsFromTypeMap(typeMap) return nil } -// Encode serializes the target AnnounceSignatures1 into the passed io.Writer -// observing the protocol version specified. -// -// This is part of the lnwire.Message interface. -func (c *ChannelAnnouncement2) Encode(w *bytes.Buffer, _ uint32) error { - _, err := w.Write(c.Signature.RawBytes()) +// DecodeNonSigTLVRecords decodes only the TLV section of the message. +func (c *ChannelAnnouncement2) DecodeNonSigTLVRecords(r io.Reader) error { + var ( + chainHash = tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() + btcKey1 = tlv.ZeroRecordT[tlv.TlvType12, [33]byte]() + btcKey2 = tlv.ZeroRecordT[tlv.TlvType14, [33]byte]() + merkleRootHash = tlv.ZeroRecordT[tlv.TlvType16, [32]byte]() + ) + stream, err := tlv.NewStream(ProduceRecordsSorted( + &chainHash, + &c.Features, + &c.ShortChannelID, + &c.Capacity, + &c.NodeID1, + &c.NodeID2, + &btcKey1, + &btcKey2, + &merkleRootHash, + )...) if err != nil { return err } - _, err = c.DataToSign() + + typeMap, err := stream.DecodeWithParsedTypesP2P(r) if err != nil { return err } - return WriteBytes(w, c.ExtraOpaqueData) -} + // By default, the chain-hash is the bitcoin mainnet genesis block hash. + c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash + if _, ok := typeMap[c.ChainHash.TlvType()]; ok { + c.ChainHash.Val = chainHash.Val + } -// DataToSign encodes the data to be signed into the ExtraOpaqueData member and -// returns it. -func (c *ChannelAnnouncement2) DataToSign() ([]byte, error) { - // The chain-hash record is only included if it is _not_ equal to the - // bitcoin mainnet genisis block hash. - var recordProducers []tlv.RecordProducer - if !c.ChainHash.Val.IsEqual(chaincfg.MainNetParams.GenesisHash) { - hash := tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() - hash.Val = c.ChainHash.Val + if _, ok := typeMap[c.BitcoinKey1.TlvType()]; ok { + c.BitcoinKey1 = tlv.SomeRecordT(btcKey1) + } - recordProducers = append(recordProducers, &hash) + if _, ok := typeMap[c.BitcoinKey2.TlvType()]; ok { + c.BitcoinKey2 = tlv.SomeRecordT(btcKey2) } - recordProducers = append(recordProducers, - &c.Features, &c.ShortChannelID, &c.Capacity, &c.NodeID1, - &c.NodeID2, - ) + if _, ok := typeMap[c.MerkleRootHash.TlvType()]; ok { + c.MerkleRootHash = tlv.SomeRecordT(merkleRootHash) + } - c.BitcoinKey1.WhenSome(func(key tlv.RecordT[tlv.TlvType12, [33]byte]) { - recordProducers = append(recordProducers, &key) - }) + c.ExtraFieldsInSignedRange = ExtraSignedFieldsFromTypeMap(typeMap) - c.BitcoinKey2.WhenSome(func(key tlv.RecordT[tlv.TlvType14, [33]byte]) { - recordProducers = append(recordProducers, &key) - }) + return nil +} - c.MerkleRootHash.WhenSome( - func(hash tlv.RecordT[tlv.TlvType16, [32]byte]) { - recordProducers = append(recordProducers, &hash) - }, +// EncodeAllNonSigFields encodes the entire message to the given writer but +// excludes the signature field. +func (c *ChannelAnnouncement2) EncodeAllNonSigFields(w io.Writer) error { + return EncodeRecordsTo( + w, ProduceRecordsSorted(c.allNonSignatureRecordProducers()...), ) - - err := EncodeMessageExtraData(&c.ExtraOpaqueData, recordProducers...) - if err != nil { - return nil, err - } - - return c.ExtraOpaqueData, nil } // MsgType returns the integer uniquely identifying this message type on the @@ -198,6 +251,10 @@ func (c *ChannelAnnouncement2) MsgType() MessageType { // lnwire.Message interface. var _ Message = (*ChannelAnnouncement2)(nil) +// A compile time check to ensure ChannelAnnouncement2 implements the +// lnwire.PureTLVMessage interface. +var _ PureTLVMessage = (*ChannelAnnouncement2)(nil) + // Node1KeyBytes returns the bytes representing the public key of node 1 in the // channel. // diff --git a/lnwire/channel_id.go b/lnwire/channel_id.go index 1615eb7471..5c9eca34fb 100644 --- a/lnwire/channel_id.go +++ b/lnwire/channel_id.go @@ -3,10 +3,12 @@ package lnwire import ( "encoding/binary" "encoding/hex" + "io" "math" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/tlv" ) const ( @@ -36,6 +38,40 @@ func (c ChannelID) String() string { return hex.EncodeToString(c[:]) } +// Record returns a TLV record that can be used to encode/decode a ChannelID +// to/from a TLV stream. +func (c *ChannelID) Record() tlv.Record { + return tlv.MakeStaticRecord(0, c, 32, encodeChannelID, decodeChannelID) +} + +func encodeChannelID(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*ChannelID); ok { + bigSize := [32]byte(*v) + + return tlv.EBytes32(w, &bigSize, buf) + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.ChannelID") +} + +func decodeChannelID(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if v, ok := val.(*ChannelID); ok { + var id [32]byte + err := tlv.DBytes32(r, &id, buf, l) + if err != nil { + return err + } + + *v = id + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "lnwire.ChannelID", l, l) +} + // NewChanIDFromOutPoint converts a target OutPoint into a ChannelID that is // usable within the network. In order to convert the OutPoint into a ChannelID, // we XOR the lower 2-bytes of the txid within the OutPoint with the big-endian diff --git a/lnwire/channel_update_2.go b/lnwire/channel_update_2.go index 79a76aad61..08be40a7c5 100644 --- a/lnwire/channel_update_2.go +++ b/lnwire/channel_update_2.go @@ -22,10 +22,6 @@ const ( // HTLCs and other parameters. This message is also used to redeclare initially // set channel parameters. type ChannelUpdate2 struct { - // Signature is used to validate the announced data and prove the - // ownership of node id. - Signature Sig - // ChainHash denotes the target chain that this channel was opened // within. This value should be the genesis hash of the target chain. // Along with the short channel ID, this uniquely identifies the @@ -74,10 +70,22 @@ type ChannelUpdate2 struct { // millionth of a satoshi. FeeProportionalMillionths tlv.RecordT[tlv.TlvType18, uint32] - // ExtraOpaqueData is the set of data that was appended to this message - // to fill out the full maximum transport message size. These fields can - // be used to specify optional data such as custom TLV fields. - ExtraOpaqueData ExtraOpaqueData + // Signature is used to validate the announced data and prove the + // ownership of node id. + Signature tlv.RecordT[tlv.TlvType160, Sig] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraFieldsInSignedRange map[uint64][]byte +} + +// Encode serializes the target ChannelUpdate2 into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *ChannelUpdate2) Encode(w *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(c, w) } // Decode deserializes a serialized ChannelUpdate2 stored in the passed @@ -85,17 +93,6 @@ type ChannelUpdate2 struct { // // This is part of the lnwire.Message interface. func (c *ChannelUpdate2) Decode(r io.Reader, _ uint32) error { - err := ReadElement(r, &c.Signature) - if err != nil { - return err - } - c.Signature.ForceSchnorr() - - return c.DecodeTLVRecords(r) -} - -// DecodeTLVRecords decodes only the TLV section of the message. -func (c *ChannelUpdate2) DecodeTLVRecords(r io.Reader) error { // First extract into extra opaque data. var tlvRecords ExtraOpaqueData if err := ReadElements(r, &tlvRecords); err != nil { @@ -111,10 +108,12 @@ func (c *ChannelUpdate2) DecodeTLVRecords(r io.Reader) error { &secondPeer, &c.CLTVExpiryDelta, &c.HTLCMinimumMsat, &c.HTLCMaximumMsat, &c.FeeBaseMsat, &c.FeeProportionalMillionths, + &c.Signature, ) if err != nil { return err } + c.Signature.Val.ForceSchnorr() // By default, the chain-hash is the bitcoin mainnet genesis block hash. c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash @@ -150,38 +149,21 @@ func (c *ChannelUpdate2) DecodeTLVRecords(r io.Reader) error { c.FeeProportionalMillionths.Val = defaultFeeProportionalMillionths //nolint:lll } - if len(tlvRecords) != 0 { - c.ExtraOpaqueData = tlvRecords - } + c.ExtraFieldsInSignedRange = ExtraSignedFieldsFromTypeMap(typeMap) return nil } -// Encode serializes the target ChannelUpdate2 into the passed io.Writer -// observing the protocol version specified. +// AllRecords returns all the TLV records for the message. This will include all +// the records we know about along with any that we don't know about but that +// fall in the signed TLV range. // -// This is part of the lnwire.Message interface. -func (c *ChannelUpdate2) Encode(w *bytes.Buffer, _ uint32) error { - _, err := w.Write(c.Signature.RawBytes()) - if err != nil { - return err - } - - _, err = c.DataToSign() - if err != nil { - return err - } - - return WriteBytes(w, c.ExtraOpaqueData) -} +// NOTE: this is part of the PureTLVMessage interface. +func (c *ChannelUpdate2) AllRecords() []tlv.Record { + var recordProducers []tlv.RecordProducer -// DataToSign is used to retrieve part of the announcement message which should -// be signed. For the ChannelUpdate2 message, this includes the serialised TLV -// records. -func (c *ChannelUpdate2) DataToSign() ([]byte, error) { // The chain-hash record is only included if it is _not_ equal to the // bitcoin mainnet genisis block hash. - var recordProducers []tlv.RecordProducer if !c.ChainHash.Val.IsEqual(chaincfg.MainNetParams.GenesisHash) { hash := tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() hash.Val = c.ChainHash.Val @@ -190,7 +172,7 @@ func (c *ChannelUpdate2) DataToSign() ([]byte, error) { } recordProducers = append(recordProducers, - &c.ShortChannelID, &c.BlockHeight, + &c.ShortChannelID, &c.BlockHeight, &c.Signature, ) // Only include the disable flags if any bit is set. @@ -225,12 +207,11 @@ func (c *ChannelUpdate2) DataToSign() ([]byte, error) { ) } - err := EncodeMessageExtraData(&c.ExtraOpaqueData, recordProducers...) - if err != nil { - return nil, err - } + recordProducers = append(recordProducers, RecordsAsProducers( + tlv.MapToRecords(c.ExtraFieldsInSignedRange), + )...) - return c.ExtraOpaqueData, nil + return ProduceRecordsSorted(recordProducers...) } // MsgType returns the integer uniquely identifying this message type on the @@ -241,8 +222,14 @@ func (c *ChannelUpdate2) MsgType() MessageType { return MsgChannelUpdate2 } -func (c *ChannelUpdate2) ExtraData() ExtraOpaqueData { - return c.ExtraOpaqueData +func (c *ChannelUpdate2) ExtraData() (ExtraOpaqueData, error) { + var buf *bytes.Buffer + err := EncodeRecordsTo(buf, tlv.MapToRecords(c.ExtraFieldsInSignedRange)) + if err != nil { + return nil, err + } + + return buf.Bytes(), nil } // A compile time check to ensure ChannelUpdate2 implements the diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 6b9630f58a..0dbe672579 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -420,8 +420,26 @@ func TestEmptyMessageUnknownType(t *testing.T) { // randCustomRecords generates a random set of custom records for testing. func randCustomRecords(t *testing.T, r *rand.Rand) CustomRecords { + customRecords := randTLVMap(t, r, MinCustomRecordsTlvType) + + // Validate the custom records as a sanity check. + err := CustomRecords(customRecords).Validate() + require.NoError(t, err) + + return customRecords +} + +// randSignedRangeRecords generates a random set of signed records in the +// second "signed" tlv range for pure TLV messages. +func randSignedRangeRecords(t *testing.T, r *rand.Rand) CustomRecords { + return randTLVMap(t, r, pureTLVSignedSecondRangeStart) +} + +func randTLVMap(t *testing.T, r *rand.Rand, + rangeStart uint64) map[uint64][]byte { + var ( - customRecords = CustomRecords{} + m = make(map[uint64][]byte) // We'll generate a random number of records, between 1 and 10. numRecords = r.Intn(9) + 1 @@ -432,21 +450,17 @@ func randCustomRecords(t *testing.T, r *rand.Rand) CustomRecords { // Keys must be equal to or greater than // MinCustomRecordsTlvType. keyOffset := uint64(r.Intn(100)) - key := MinCustomRecordsTlvType + keyOffset + key := rangeStart + keyOffset // Values are byte slices of any length. value := make([]byte, r.Intn(10)) _, err := r.Read(value) require.NoError(t, err) - customRecords[key] = value + m[key] = value } - // Validate the custom records as a sanity check. - err := customRecords.Validate() - require.NoError(t, err) - - return customRecords + return m } // TestLightningWireProtocol uses the testing/quick package to create a series @@ -1505,37 +1519,29 @@ func TestLightningWireProtocol(t *testing.T) { MsgAnnounceSignatures2: func(v []reflect.Value, r *rand.Rand) { - req := AnnounceSignatures2{ - ShortChannelID: NewShortChanIDFromInt( - uint64(r.Int63()), - ), - ExtraOpaqueData: make([]byte, 0), - } + var req AnnounceSignatures2 - _, err := r.Read(req.ChannelID[:]) + req.ExtraFieldsInSignedRange = randSignedRangeRecords( + t, r, + ) + + _, err := r.Read(req.ChannelID.Val[:]) require.NoError(t, err) partialSig, err := randPartialSig(r) require.NoError(t, err) - req.PartialSignature = *partialSig - - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make( - []byte, numExtraBytes, - ) - _, err := r.Read(req.ExtraOpaqueData[:]) - require.NoError(t, err) - } + req.PartialSignature.Val = *partialSig v[0] = reflect.ValueOf(req) }, MsgChannelAnnouncement2: func(v []reflect.Value, r *rand.Rand) { - req := ChannelAnnouncement2{ - Signature: testSchnorrSig, - ExtraOpaqueData: make([]byte, 0), - } + var req ChannelAnnouncement2 + + req.Signature.Val = testSchnorrSig + req.ExtraFieldsInSignedRange = randSignedRangeRecords( + t, r, + ) req.ShortChannelID.Val = NewShortChanIDFromInt( uint64(r.Int63()), @@ -1584,23 +1590,16 @@ func TestLightningWireProtocol(t *testing.T) { } } - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make( - []byte, numExtraBytes, - ) - _, err := r.Read(req.ExtraOpaqueData[:]) - require.NoError(t, err) - } - v[0] = reflect.ValueOf(req) }, MsgChannelUpdate2: func(v []reflect.Value, r *rand.Rand) { - req := ChannelUpdate2{ - Signature: testSchnorrSig, - ExtraOpaqueData: make([]byte, 0), - } + var req ChannelUpdate2 + req.ExtraFieldsInSignedRange = randSignedRangeRecords( + t, r, + ) + + req.Signature.Val = testSchnorrSig req.ShortChannelID.Val = NewShortChanIDFromInt( uint64(r.Int63()), ) @@ -1661,15 +1660,6 @@ func TestLightningWireProtocol(t *testing.T) { ChanUpdateDisableOutgoing } - numExtraBytes := r.Int31n(1000) - if numExtraBytes > 0 { - req.ExtraOpaqueData = make( - []byte, numExtraBytes, - ) - _, err := r.Read(req.ExtraOpaqueData[:]) - require.NoError(t, err) - } - v[0] = reflect.ValueOf(req) }, } diff --git a/lnwire/pure_tlv.go b/lnwire/pure_tlv.go new file mode 100644 index 0000000000..7336f1cd94 --- /dev/null +++ b/lnwire/pure_tlv.go @@ -0,0 +1,89 @@ +package lnwire + +import ( + "bytes" + + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + pureTLVUnsignedRangeOneStart = 160 + pureTLVSignedSecondRangeStart = 1000000000 + pureTLVUnsignedRangeTwoStart = 3000000000 +) + +// PureTLVMessage describes an LN message that is a pure TLV stream. If the +// message includes a signature, it will sign all the TLV records in the +// inclusive ranges: 0 to 159 and 1000000000 to 2999999999. +type PureTLVMessage interface { + // AllRecords returns all the TLV records for the message. This will + // include all the records we know about along with any that we don't + // know about but that fall in the signed TLV range. + AllRecords() []tlv.Record +} + +// EncodePureTLVMessage encodes the given PureTLVMessage to the given buffer. +func EncodePureTLVMessage(msg PureTLVMessage, buf *bytes.Buffer) error { + return EncodeRecordsTo(buf, msg.AllRecords()) +} + +// SerialiseFieldsToSign serialises all the records from the given +// PureTLVMessage that fall within the signed TLV range. +func SerialiseFieldsToSign(msg PureTLVMessage) ([]byte, error) { + // Filter out all the fields not in the signed ranges. + var signedRecords []tlv.Record + for _, record := range msg.AllRecords() { + if InUnsignedRange(record.Type()) { + continue + } + + signedRecords = append(signedRecords, record) + } + + var buf bytes.Buffer + if err := EncodeRecordsTo(&buf, signedRecords); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// InUnsignedRange returns true if the given TLV type falls outside the TLV +// ranges that the signature of a pure TLV message will cover. +func InUnsignedRange(t tlv.Type) bool { + return (t >= pureTLVUnsignedRangeOneStart && + t < pureTLVSignedSecondRangeStart) || + t >= pureTLVUnsignedRangeTwoStart +} + +// ExtraSignedFieldsFromTypeMap is a helper that can be used alongside calls to +// the tlv.Stream DecodeWithParsedTypesP2P or DecodeWithParsedTypes methods to +// extract the tlv type and value pairs in the defined PureTLVMessage signed +// range which we have not handled with any of our defined Records. These +// methods will return a tlv.TypeMap containing the records that were extracted +// from an io.Reader. If the record was know and handled by a defined record, +// then the value accompanying the record's type in the map will be nil. +// Otherwise, if the record was unhandled, it will be non-nil. +func ExtraSignedFieldsFromTypeMap(m tlv.TypeMap) map[uint64][]byte { + extraFields := make(map[uint64][]byte) + for t, v := range m { + // If the value in the type map is nil, then it indicates that + // we know this type, and it was handled by one of the records + // we passed to the decode function vai the TLV stream. + if v == nil { + continue + } + + // No need to keep this field if it is unknown to us and is not + // in the sign range. + if InUnsignedRange(t) { + continue + } + + // Otherwise, this is an un-handled type, so we keep track of + // it for signature validation and re-encoding later on. + extraFields[uint64(t)] = v + } + + return extraFields +} diff --git a/lnwire/pure_tlv_test.go b/lnwire/pure_tlv_test.go new file mode 100644 index 0000000000..eb91d63a83 --- /dev/null +++ b/lnwire/pure_tlv_test.go @@ -0,0 +1,389 @@ +package lnwire + +import ( + "bytes" + "io" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// TestPureTLVMessages tests the forwards compatibility of two versions of the +// same Lightning Network message that uses the Pure TLV format. This in essence +// tests that and older client is able to verify the signature over relevant +// data in a newer client's message. +func TestPureTLVMessage(t *testing.T) { + t.Parallel() + + var ( + _, pkA = btcec.PrivKeyFromBytes([]byte{1}) + _, pkB = btcec.PrivKeyFromBytes([]byte{2}) + capacity = MilliSatoshi(100) + ) + + // Test encode and decode of MsgV1 as is. + t.Run("Encode and Decode of MsgV1", func(t *testing.T) { + t.Parallel() + + msgOld := newMsgV1(pkA, &capacity) + + buf := bytes.NewBuffer(nil) + require.NoError(t, msgOld.Encode(buf, 0)) + + var msgOld2 MsgV1 + require.NoError(t, msgOld2.Decode(buf, 0)) + + require.Equal(t, msgOld, &msgOld2) + }) + + // Test encode and decode of MsgV2 as is. + t.Run("Encode and Decode of MsgV2", func(t *testing.T) { + t.Parallel() + + msgNew := newMsgV2( + pkA, &capacity, pkB, []byte{1, 2, 3, 4}, 90, 100, true, + ) + + buf := bytes.NewBuffer(nil) + require.NoError(t, msgNew.Encode(buf, 0)) + + var msgNew2 MsgV2 + require.NoError(t, msgNew2.Decode(buf, 0)) + + require.Equal(t, msgNew, &msgNew2) + }) + + // Create a MsgV2 and decode it into a MsgV1. Both the new client + // (MsgV2) and old client (MsgV1) should be able to generate the same + // digest that will be used to create and validate the signture. + t.Run("Encode MsgV2 and decode via MsgV1", func(t *testing.T) { + t.Parallel() + + var ( + buf = bytes.NewBuffer(nil) + msgV2 = newMsgV2( + pkA, &capacity, pkB, []byte{1, 2, 3, 4}, 100, + 90, true, + ) + ) + require.NoError(t, msgV2.Encode(buf, 0)) + + // Get the serialised bytes that would be signed for msgV2. + signData1, err := SerialiseFieldsToSign(msgV2) + require.NoError(t, err) + + // Decoding via the old message should store some of the extra + // fields. + var msgV1 MsgV1 + require.NoError(t, msgV1.Decode(buf, 0)) + require.NotEmpty(t, msgV1.ExtraFieldsInSignedRange) + + // Show that the extra fields map contains unknown fields in the + // signed range but not unknown fields in the unsigned range. + _, ok := msgV1.ExtraFieldsInSignedRange[uint64(msgV2.Num.TlvType())] //nolint:lll + require.True(t, ok) + _, ok = msgV1.ExtraFieldsInSignedRange[uint64(msgV2.Other.TlvType())] //nolint:lll + require.False(t, ok) + + // The serialised bytes to verify the signature against should + // be the same though. + signData2, err := SerialiseFieldsToSign(&msgV1) + require.NoError(t, err) + + require.Equal(t, signData1, signData2) + + // Re-encoding via the old message should keep the extra fields. + buf = bytes.NewBuffer(nil) + require.NoError(t, msgV1.Encode(buf, 0)) + + var msgV1ReEncoded MsgV1 + require.NoError(t, msgV1ReEncoded.Decode(buf, 0)) + + require.Equal(t, &msgV1, &msgV1ReEncoded) + }) +} + +// MsgV1 represents a more minimal, first version of a Lightning Network +// message. +type MsgV1 struct { + // Two known fields in the signed range. + NodeKey tlv.RecordT[tlv.TlvType0, *btcec.PublicKey] + Capacity tlv.OptionalRecordT[tlv.TlvType1, MilliSatoshi] + + // Signature in the unsigned range. + Signature tlv.RecordT[tlv.TlvType160, Sig] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraFieldsInSignedRange map[uint64][]byte +} + +var _ Message = (*MsgV1)(nil) +var _ PureTLVMessage = (*MsgV1)(nil) + +// newMsgV1 is a constructor for MsgV1. +func newMsgV1(nodeKey *btcec.PublicKey, capacity *MilliSatoshi) *MsgV1 { + newMsg := &MsgV1{ + NodeKey: tlv.NewPrimitiveRecord[tlv.TlvType0]( + nodeKey, + ), + Signature: tlv.NewRecordT[tlv.TlvType160]( + testSchnorrSig, + ), + ExtraFieldsInSignedRange: make(map[uint64][]byte), + } + + if capacity != nil { + newMsg.Capacity = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType1](*capacity), + ) + } + + return newMsg +} + +// Decode deserializes a serialized MsgV1 in the passed io.Reader. +// +// This is part of the lnwire.Message interface. +func (g *MsgV1) Decode(r io.Reader, _ uint32) error { + var capacity = tlv.ZeroRecordT[tlv.TlvType1, MilliSatoshi]() + stream, err := tlv.NewStream( + ProduceRecordsSorted( + &g.NodeKey, + &capacity, + &g.Signature, + )..., + ) + if err != nil { + return err + } + g.Signature.Val.ForceSchnorr() + + typeMap, err := stream.DecodeWithParsedTypesP2P(r) + if err != nil { + return err + } + + if _, ok := typeMap[g.Capacity.TlvType()]; ok { + g.Capacity = tlv.SomeRecordT(capacity) + } + + g.ExtraFieldsInSignedRange = ExtraSignedFieldsFromTypeMap(typeMap) + + return nil +} + +// Encode serializes the target MsgV1 into the passed buffer. +// +// This is part of the lnwire.Message interface. +func (g *MsgV1) Encode(buf *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(g, buf) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (g *MsgV1) MsgType() MessageType { + return 7777 +} + +// AllRecords returns all the TLV records for the message. This will +// include all the records we know about along with any that we don't +// know about but that fall in the signed TLV range. +// +// This is part of the PureTLVMessage interface. +func (g *MsgV1) AllRecords() []tlv.Record { + recordProducers := []tlv.RecordProducer{ + &g.NodeKey, + &g.Signature, + } + recordProducers = append( + recordProducers, + RecordsAsProducers( + tlv.MapToRecords(g.ExtraFieldsInSignedRange), + )..., + ) + + g.Capacity.WhenSome( + func(capacity tlv.RecordT[tlv.TlvType1, MilliSatoshi]) { + recordProducers = append(recordProducers, &capacity) + }, + ) + + return ProduceRecordsSorted(recordProducers...) +} + +// MsgV2 represents a newer version of MsgV1 which contains more fields both in +// the unsigned and signed TLV ranges. +type MsgV2 struct { + NodeKey tlv.RecordT[tlv.TlvType0, *btcec.PublicKey] + Capacity tlv.OptionalRecordT[tlv.TlvType1, MilliSatoshi] + + // An additional fields (optional) in the signed range. + BitcoinKey tlv.OptionalRecordT[tlv.TlvType3, *btcec.PublicKey] + + // A zero length TLV in the signed range. + SecondPeer tlv.OptionalRecordT[tlv.TlvType5, TrueBoolean] + + // Signature in the unsigned range. + Signature tlv.RecordT[tlv.TlvType160, Sig] + + // Another field in the unsigned range. An older node can throw this + // away. + SPVProof tlv.RecordT[tlv.TlvType161, []byte] + + // A new field in the second signed range. An older node should keep + // this since it is part of the serialised message that is signed. + Num tlv.RecordT[tlv.TlvType1000000000, uint8] + + // Another field in the second unsigned-range. Older nodes may throw + // this away and it won't affect the digest used for signature creation + // and validation. + Other tlv.RecordT[tlv.TlvType3000000000, uint8] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraFieldsInSignedRange map[uint64][]byte +} + +// newMsgV2 is a constructor for MsgV2. +func newMsgV2(nodeKey *btcec.PublicKey, capacity *MilliSatoshi, + btcKey *btcec.PublicKey, spvProof []byte, num, other uint8, + secondPeer bool) *MsgV2 { + + newMsg := &MsgV2{ + NodeKey: tlv.NewPrimitiveRecord[tlv.TlvType0](nodeKey), + SPVProof: tlv.NewPrimitiveRecord[tlv.TlvType161](spvProof), + Num: tlv.NewPrimitiveRecord[tlv.TlvType1000000000](num), + Other: tlv.NewPrimitiveRecord[tlv.TlvType3000000000](num), + Signature: tlv.NewRecordT[tlv.TlvType160]( + testSchnorrSig, + ), + ExtraFieldsInSignedRange: make(map[uint64][]byte), + } + + if secondPeer { + newMsg.SecondPeer = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType5](TrueBoolean{}), + ) + } + + if capacity != nil { + newMsg.Capacity = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType1](*capacity), + ) + } + + if btcKey != nil { + newMsg.BitcoinKey = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType3](btcKey), + ) + } + + return newMsg +} + +// Decode deserializes a serialized MsgV2 in the passed io.Reader. +// +// This is part of the lnwire.Message interface. +func (g *MsgV2) Decode(r io.Reader, _ uint32) error { + var ( + capacity = tlv.ZeroRecordT[tlv.TlvType1, MilliSatoshi]() + btcKey = tlv.ZeroRecordT[tlv.TlvType3, *btcec.PublicKey]() + secondPeer = tlv.ZeroRecordT[tlv.TlvType5, TrueBoolean]() + ) + + stream, err := tlv.NewStream( + ProduceRecordsSorted( + &g.NodeKey, + &capacity, + &btcKey, + &secondPeer, + &g.Signature, + &g.SPVProof, + &g.Num, + &g.Other, + )..., + ) + if err != nil { + return err + } + g.Signature.Val.ForceSchnorr() + + typeMap, err := stream.DecodeWithParsedTypesP2P(r) + if err != nil { + return err + } + + if _, ok := typeMap[g.Capacity.TlvType()]; ok { + g.Capacity = tlv.SomeRecordT(capacity) + } + + if _, ok := typeMap[g.SecondPeer.TlvType()]; ok { + g.SecondPeer = tlv.SomeRecordT(secondPeer) + } + + if _, ok := typeMap[g.BitcoinKey.TlvType()]; ok { + g.BitcoinKey = tlv.SomeRecordT(btcKey) + } + + g.ExtraFieldsInSignedRange = ExtraSignedFieldsFromTypeMap(typeMap) + + return nil +} + +// Encode serializes the target MsgV2 into the passed buffer. +// +// This is part of the lnwire.Message interface. +func (g *MsgV2) Encode(buf *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(g, buf) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (g *MsgV2) MsgType() MessageType { + return 7779 +} + +// AllRecords returns all the TLV records for the message. This will +// include all the records we know about along with any that we don't +// know about but that fall in the signed TLV range. +// +// This is part of the PureTLVMessage interface. +func (g *MsgV2) AllRecords() []tlv.Record { + recordProducers := []tlv.RecordProducer{ + &g.NodeKey, + &g.Signature, + &g.SPVProof, + &g.Num, + &g.Other, + } + recordProducers = append(recordProducers, RecordsAsProducers( + tlv.MapToRecords(g.ExtraFieldsInSignedRange), + )...) + + g.Capacity.WhenSome( + func(cap tlv.RecordT[tlv.TlvType1, MilliSatoshi]) { + recordProducers = append(recordProducers, &cap) + }, + ) + g.BitcoinKey.WhenSome( + func(key tlv.RecordT[tlv.TlvType3, *btcec.PublicKey]) { + recordProducers = append(recordProducers, &key) + }, + ) + g.SecondPeer.WhenSome( + func(second tlv.RecordT[tlv.TlvType5, TrueBoolean]) { + recordProducers = append(recordProducers, &second) + }, + ) + + return ProduceRecordsSorted(recordProducers...) +} diff --git a/netann/channel_announcement.go b/netann/channel_announcement.go index 9644a523ff..d5cf967f8e 100644 --- a/netann/channel_announcement.go +++ b/netann/channel_announcement.go @@ -8,7 +8,9 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" @@ -108,7 +110,8 @@ func CreateChanAnnouncement(chanProof *models.ChannelAuthProof, // FetchPkScript defines a function that can be used to fetch the output script // for the transaction with the given SCID. -type FetchPkScript func(*lnwire.ShortChannelID) ([]byte, error) +type FetchPkScript func(*lnwire.ShortChannelID) (txscript.ScriptClass, + btcutil.Address, error) // ValidateChannelAnn validates the channel announcement. func ValidateChannelAnn(a lnwire.ChannelAnnouncement, @@ -202,24 +205,124 @@ func validateChannelAnn1(a *lnwire.ChannelAnnouncement1) error { func validateChannelAnn2(a *lnwire.ChannelAnnouncement2, fetchPkScript FetchPkScript) error { + // Next, we fetch the funding transaction's PK script. We need this so + // that we know what type of channel we will be validating: P2WSH or + // P2TR. + scriptClass, scriptAddr, err := fetchPkScript(&a.ShortChannelID.Val) + if err != nil { + return err + } + + var keys []*btcec.PublicKey + + switch scriptClass { + case txscript.WitnessV0ScriptHashTy: + keys, err = chanAnn2P2WSHMuSig2Keys(a) + if err != nil { + return err + } + case txscript.WitnessV1TaprootTy: + keys, err = chanAnn2P2TRMuSig2Keys(a, scriptAddr) + if err != nil { + return err + } + default: + return fmt.Errorf("invalid on-chain pk script type for "+ + "channel_announcement_2: %s", scriptClass) + } + + // Do a MuSig2 aggregation of the keys to obtain the aggregate key that + // the signature will be validated against. + aggKey, _, _, err := musig2.AggregateKeys(keys, true) + if err != nil { + return err + } + + // Get the message that the signature should have signed. dataHash, err := ChanAnn2DigestToSign(a) if err != nil { return err } - sig, err := a.Signature.ToSignature() + // Obtain the signature. + sig, err := a.Signature.Val.ToSignature() if err != nil { return err } + // Check that the signature is valid for the aggregate key given the + // message digest. + if !sig.Verify(dataHash.CloneBytes(), aggKey.FinalKey) { + return fmt.Errorf("invalid sig") + } + + return nil +} + +// chanAnn2P2WSHMuSig2Keys returns the set of keys that should be used to +// construct the aggregate key that the signature in an +// lnwire.ChannelAnnouncement2 message should be verified against in the case +// where the channel being announced is a P2WSH channel. +func chanAnn2P2WSHMuSig2Keys(a *lnwire.ChannelAnnouncement2) ( + []*btcec.PublicKey, error) { + nodeKey1, err := btcec.ParsePubKey(a.NodeID1.Val[:]) if err != nil { - return err + return nil, err } nodeKey2, err := btcec.ParsePubKey(a.NodeID2.Val[:]) if err != nil { - return err + return nil, err + } + + btcKeyMissingErrString := "bitcoin key %d missing for announcement " + + "of a P2WSH channel" + + btcKey1Bytes, err := a.BitcoinKey1.UnwrapOrErr( + fmt.Errorf(btcKeyMissingErrString, 1), + ) + if err != nil { + return nil, err + } + + btcKey1, err := btcec.ParsePubKey(btcKey1Bytes.Val[:]) + if err != nil { + return nil, err + } + + btcKey2Bytes, err := a.BitcoinKey2.UnwrapOrErr( + fmt.Errorf(btcKeyMissingErrString, 2), + ) + if err != nil { + return nil, err + } + + btcKey2, err := btcec.ParsePubKey(btcKey2Bytes.Val[:]) + if err != nil { + return nil, err + } + + return []*btcec.PublicKey{ + nodeKey1, nodeKey2, btcKey1, btcKey2, + }, nil +} + +// chanAnn2P2TRMuSig2Keys returns the set of keys that should be used to +// construct the aggregate key that the signature in an +// lnwire.ChannelAnnouncement2 message should be verified against in the case +// where the channel being announced is a P2TR channel. +func chanAnn2P2TRMuSig2Keys(a *lnwire.ChannelAnnouncement2, + scriptAddr btcutil.Address) ([]*btcec.PublicKey, error) { + + nodeKey1, err := btcec.ParsePubKey(a.NodeID1.Val[:]) + if err != nil { + return nil, err + } + + nodeKey2, err := btcec.ParsePubKey(a.NodeID2.Val[:]) + if err != nil { + return nil, err } keys := []*btcec.PublicKey{ @@ -240,49 +343,36 @@ func validateChannelAnn2(a *lnwire.ChannelAnnouncement2, bitcoinKey1, err := btcec.ParsePubKey(btcKey1.Val[:]) if err != nil { - return err + return nil, err } bitcoinKey2, err := btcec.ParsePubKey(btcKey2.Val[:]) if err != nil { - return err + return nil, err } keys = append(keys, bitcoinKey1, bitcoinKey2) } else { - // If bitcoin keys are not provided, then we need to get the - // on-chain output key since this will be the 3rd key in the - // 3-of-3 MuSig2 signature. - pkScript, err := fetchPkScript(&a.ShortChannelID.Val) - if err != nil { - return err - } - - outputKey, err := schnorr.ParsePubKey(pkScript[2:]) + // If bitcoin keys are not provided, then the on-chain output + // key is considered the 3rd key in the 3-of-3 MuSig2 signature. + outputKey, err := schnorr.ParsePubKey( + scriptAddr.ScriptAddress(), + ) if err != nil { - return err + return nil, err } keys = append(keys, outputKey) } - aggKey, _, _, err := musig2.AggregateKeys(keys, true) - if err != nil { - return err - } - - if !sig.Verify(dataHash.CloneBytes(), aggKey.FinalKey) { - return fmt.Errorf("invalid sig") - } - - return nil + return keys, nil } // ChanAnn2DigestToSign computes the digest of the message to be signed. func ChanAnn2DigestToSign(a *lnwire.ChannelAnnouncement2) (*chainhash.Hash, error) { - data, err := a.DataToSign() + data, err := lnwire.SerialiseFieldsToSign(a) if err != nil { return nil, err } diff --git a/netann/channel_announcement_test.go b/netann/channel_announcement_test.go index 61db16b16e..f439f59172 100644 --- a/netann/channel_announcement_test.go +++ b/netann/channel_announcement_test.go @@ -9,6 +9,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/input" @@ -76,20 +77,25 @@ func TestChanAnnounce2Validation(t *testing.T) { t.Parallel() t.Run( - "test 4-of-4 MuSig2 channel announcement", - test4of4MuSig2ChanAnnouncement, + "test 4-of-4 MuSig2 P2TR channel announcement", + test4of4MuSig2P2TRChanAnnouncement, ) t.Run( - "test 3-of-3 MuSig2 channel announcement", + "test 3-of-3 MuSig2 P2TR channel announcement", test3of3MuSig2ChanAnnouncement, ) + + t.Run( + "test 4-of-4 MuSig2 P2WSH channel announcement", + test4of4MuSig2P2WSHChanAnnouncement, + ) } -// test4of4MuSig2ChanAnnouncement covers the case where both bitcoin keys are -// present in the channel announcement. In this case, the signature should be -// a 4-of-4 MuSig2. -func test4of4MuSig2ChanAnnouncement(t *testing.T) { +// test4of4MuSig2P2TRChanAnnouncement covers the case where the funding +// transaction PK script is a P2WSH. In this case, the signature should be valid +// for the MuSig2 4-of-4 aggregation of the node keys and the bitcoin keys. +func test4of4MuSig2P2WSHChanAnnouncement(t *testing.T) { t.Parallel() // Generate the keys for node 1 and node2. @@ -162,10 +168,138 @@ func test4of4MuSig2ChanAnnouncement(t *testing.T) { sig, err := lnwire.NewSigFromSignature(s) require.NoError(t, err) - ann.Signature = sig + ann.Signature.Val = sig + + // Create an accurate representation of what the on-chain pk script will + // look like. For this case, it is only important that we get the + // correct script class. + multiSigScript, err := input.GenMultiSigScript( + node1.btcPub.SerializeCompressed(), + node2.btcPub.SerializeCompressed(), + ) + require.NoError(t, err) + + scriptHash, err := input.WitnessScriptHash(multiSigScript) + require.NoError(t, err) + pkAddr, err := btcutil.NewAddressScriptHash( + scriptHash, &chaincfg.MainNetParams, + ) + require.NoError(t, err) + + // Create a mock tx fetcher that returns the expected script class and + // pk address. + fetchTx := func(*lnwire.ShortChannelID) (txscript.ScriptClass, + btcutil.Address, error) { + + return txscript.WitnessV0ScriptHashTy, pkAddr, nil + } // Validate the announcement. - require.NoError(t, ValidateChannelAnn(ann, nil)) + require.NoError(t, ValidateChannelAnn(ann, fetchTx)) +} + +// test4of4MuSig2P2TRChanAnnouncement covers the case where both bitcoin keys +// are present in the channel announcement 2 and the funding transaction PK +// script is a P2TR. In this case, the signature should be a 4-of-4 MuSig2. +func test4of4MuSig2P2TRChanAnnouncement(t *testing.T) { + t.Parallel() + + // Generate the keys for node 1 and node2. + node1, node2 := genChanAnnKeys(t) + + // Build the unsigned channel announcement. + ann := buildUnsignedChanAnnouncement(node1, node2, true) + + // Serialise the bytes that need to be signed. + msg, err := ChanAnn2DigestToSign(ann) + require.NoError(t, err) + + var msgBytes [32]byte + copy(msgBytes[:], msg.CloneBytes()) + + // Generate the 4 nonces required for producing the signature. + var ( + node1NodeNonce = genNonceForPubKey(t, node1.nodePub) + node1BtcNonce = genNonceForPubKey(t, node1.btcPub) + node2NodeNonce = genNonceForPubKey(t, node2.nodePub) + node2BtcNonce = genNonceForPubKey(t, node2.btcPub) + ) + + nonceAgg, err := musig2.AggregateNonces([][66]byte{ + node1NodeNonce.PubNonce, + node1BtcNonce.PubNonce, + node2NodeNonce.PubNonce, + node2BtcNonce.PubNonce, + }) + require.NoError(t, err) + + pubKeys := []*btcec.PublicKey{ + node1.nodePub, node2.nodePub, node1.btcPub, node2.btcPub, + } + + // Let Node1 sign the announcement message with its node key. + psA1, err := musig2.Sign( + node1NodeNonce.SecNonce, node1.nodePriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Let Node1 sign the announcement message with its bitcoin key. + psA2, err := musig2.Sign( + node1BtcNonce.SecNonce, node1.btcPriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Let Node2 sign the announcement message with its node key. + psB1, err := musig2.Sign( + node2NodeNonce.SecNonce, node2.nodePriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Let Node2 sign the announcement message with its bitcoin key. + psB2, err := musig2.Sign( + node2BtcNonce.SecNonce, node2.btcPriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Finally, combine the partial signatures from Node1 and Node2 and add + // the signature to the announcement message. + s := musig2.CombineSigs(psA1.R, []*musig2.PartialSignature{ + psA1, psA2, psB1, psB2, + }) + + sig, err := lnwire.NewSigFromSignature(s) + require.NoError(t, err) + + ann.Signature.Val = sig + + // Create an accurate representation of what the on-chain pk script will + // look like. For this case, it is only important that we get the + // correct script class. + combinedKey, _, _, err := musig2.AggregateKeys( + []*btcec.PublicKey{node1.btcPub, node2.btcPub}, true, + ) + require.NoError(t, err) + + pkAddr, err := btcutil.NewAddressTaproot( + combinedKey.FinalKey.SerializeCompressed()[1:], + &chaincfg.MainNetParams, + ) + require.NoError(t, err) + + // Create a mock tx fetcher that returns the expected script class and + // pk address. + fetchTx := func(*lnwire.ShortChannelID) (txscript.ScriptClass, + btcutil.Address, error) { + + return txscript.WitnessV1TaprootTy, pkAddr, nil + } + + // Validate the announcement. + require.NoError(t, ValidateChannelAnn(ann, fetchTx)) } // test3of3MuSig2ChanAnnouncement covers the case where no bitcoin keys are @@ -220,14 +354,17 @@ func test3of3MuSig2ChanAnnouncement(t *testing.T) { }) require.NoError(t, err) - pkScript, err := input.PayToTaprootScript(outputKey) + pkAddr, err := btcutil.NewAddressTaproot( + outputKey.SerializeCompressed()[1:], &chaincfg.MainNetParams, + ) require.NoError(t, err) - // We'll pass in a mock tx fetcher that will return the funding output - // containing this key. This is needed since the output key can not be - // determined from the channel announcement itself. - fetchTx := func(chanID *lnwire.ShortChannelID) ([]byte, error) { - return pkScript, nil + // Create a mock tx fetcher that returns the expected script class + // and pk address. + fetchTx := func(*lnwire.ShortChannelID) (txscript.ScriptClass, + btcutil.Address, error) { + + return txscript.WitnessV1TaprootTy, pkAddr, nil } pubKeys := []*btcec.PublicKey{node1.nodePub, node2.nodePub, outputKey} @@ -262,7 +399,7 @@ func test3of3MuSig2ChanAnnouncement(t *testing.T) { sig, err := lnwire.NewSigFromSignature(s) require.NoError(t, err) - ann.Signature = sig + ann.Signature.Val = sig // Validate the announcement. require.NoError(t, ValidateChannelAnn(ann, fetchTx)) diff --git a/netann/channel_update.go b/netann/channel_update.go index af91abdd24..bf64f10f36 100644 --- a/netann/channel_update.go +++ b/netann/channel_update.go @@ -235,7 +235,7 @@ func verifyChannelUpdate2Signature(c *lnwire.ChannelUpdate2, return fmt.Errorf("unable to reconstruct message data: %w", err) } - nodeSig, err := c.Signature.ToSignature() + nodeSig, err := c.Signature.Val.ToSignature() if err != nil { return err } @@ -323,7 +323,7 @@ func ChanUpdate2DigestTag() []byte { // chanUpdate2DigestToSign computes the digest of the ChannelUpdate2 message to // be signed. func chanUpdate2DigestToSign(c *lnwire.ChannelUpdate2) ([]byte, error) { - data, err := c.DataToSign() + data, err := lnwire.SerialiseFieldsToSign(c) if err != nil { return nil, err } diff --git a/server.go b/server.go index 0fc95dc756..eb7d5813d9 100644 --- a/server.go +++ b/server.go @@ -1073,7 +1073,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, Graph: s.graphBuilder, ChainIO: s.cc.ChainIO, Notifier: s.cc.ChainNotifier, - ChainHash: *s.cfg.ActiveNetParams.GenesisHash, + ChainParams: s.cfg.ActiveNetParams.Params, Broadcast: s.BroadcastMessage, ChanSeries: chanSeries, NotifyWhenOnline: s.NotifyWhenOnline,