From 8c0b3018b51d4da9c0b0a9ba06121439495fa64a Mon Sep 17 00:00:00 2001 From: jdomnitz <380352+jdomnitz@users.noreply.github.com> Date: Fri, 13 Sep 2024 17:30:54 -0400 Subject: [PATCH] Security improvements --- README.md | 4 +--- ZWaveDotNet/CommandClasses/Configuration.cs | 19 +++++++--------- ZWaveDotNet/CommandClasses/Security2.cs | 20 ++++++++++++++--- .../CommandClasses/TransportService.cs | 8 +++---- ZWaveDotNet/Entities/Controller.cs | 2 +- ZWaveDotNet/Security/AES.cs | 8 +++---- ZWaveDotNet/Security/SecurityManager.cs | 22 ++++++++++++++++++- ZWaveDotNet/Util/MemoryUtil.cs | 9 ++++++-- 8 files changed, 63 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index c6f21fe..dcffec4 100644 --- a/README.md +++ b/README.md @@ -17,9 +17,7 @@ An implementation of ZWave Plus using the 2024a public specification. * See our [Examples Page](Examples.md) #### Work in progress: -* Multicast is not yet exposed -* Security2 multicast is not implemented +* Multicast is not fully implemented (including secure multicast) * Transport CC is receive only and not fully implemented (Very few devices use this) -* Node interviews are only partially implemented according to the spec Testers, Tickets, Feedback and PRs are welcome. \ No newline at end of file diff --git a/ZWaveDotNet/CommandClasses/Configuration.cs b/ZWaveDotNet/CommandClasses/Configuration.cs index a856320..04c0e6e 100644 --- a/ZWaveDotNet/CommandClasses/Configuration.cs +++ b/ZWaveDotNet/CommandClasses/Configuration.cs @@ -49,31 +49,28 @@ public async Task Get(byte parameter, CancellationToken can public async Task SetDefault(byte parameter, CancellationToken cancellationToken = default) { - await Set(parameter, 0, 0, cancellationToken, true); + await Set(parameter, 0, cancellationToken, true); } public async Task Set(byte parameter, sbyte value, CancellationToken cancellationToken = default) { - await Set(parameter, value, 0, cancellationToken); + await Set(parameter, value, cancellationToken); } public async Task Set(byte parameter, short value, CancellationToken cancellationToken = default) { - await Set(parameter, value, 0, cancellationToken); + await Set(parameter, value, cancellationToken); } public async Task Set(byte parameter, int value, CancellationToken cancellationToken = default) { - await Set(parameter, value, 0, cancellationToken); + await Set(parameter, value, cancellationToken); } - private async Task Set(byte parameter, int value, byte size, CancellationToken cancellationToken = default, bool reset = false) + private async Task Set(byte parameter, int value, CancellationToken cancellationToken = default, bool resetToDefault = false) { - if (size == 0) - { - ReportMessage response = await SendReceive(ConfigurationCommand.Get, ConfigurationCommand.Report, cancellationToken); - size = response.Payload.Span[1]; - } + ReportMessage response = await SendReceive(ConfigurationCommand.Get, ConfigurationCommand.Report, cancellationToken); + byte size = response.Payload.Span[1]; var values = new byte[size]; switch (size) @@ -90,7 +87,7 @@ private async Task Set(byte parameter, int value, byte size, CancellationToken c default: throw new NotSupportedException($"Size:{size} is not supported"); } - if (reset) + if (resetToDefault) size |= 0x80; await SendCommand(ConfigurationCommand.Set, cancellationToken, new[] { parameter, size }.Concat(values).ToArray()); } diff --git a/ZWaveDotNet/CommandClasses/Security2.cs b/ZWaveDotNet/CommandClasses/Security2.cs index 6d67099..e31b203 100644 --- a/ZWaveDotNet/CommandClasses/Security2.cs +++ b/ZWaveDotNet/CommandClasses/Security2.cs @@ -262,6 +262,11 @@ public async Task Encapsulate(List payload, SecurityManager.RecordType? ty decoded = Decrypt(msg, controller, networkKey, ad, ref i); if (decoded != null) break; + else if (controller.SecurityManager.HasKey(msg.SourceNodeID, SecurityManager.RecordType.ECDH_TEMP)) + { + using (CancellationTokenSource cts = new CancellationTokenSource(3000)) + await controller.Nodes[msg.SourceNodeID].GetCommandClass()!.KexFail(KexFailType.KEX_FAIL_KEY_VERIFY).ConfigureAwait(false); + } else if (i == 2) { try @@ -364,13 +369,16 @@ protected override async Task Handle(ReportMessage message) Log.Verbose("Kex Set Received: " + kexReport.ToString()); if (kexReport.Echo) { - //kexReport is the granted keys - //TODO - Send KexFail if attempting to get more keys than we granted if (controller.SecurityManager == null) return SupervisionStatus.Fail; KeyExchangeReport? requestedKeys = controller.SecurityManager.GetRequestedKeys(node.ID); if (requestedKeys != null) { + if (requestedKeys.Keys != kexReport.Keys) + { + await KexFail(KexFailType.KEX_FAIL_AUTH); + return SupervisionStatus.Fail; + } requestedKeys.Echo = true; Log.Verbose("Responding: " + requestedKeys.ToString()); CommandMessage reportKex = new CommandMessage(controller, node.ID, endpoint, commandClass, (byte)Security2Command.KEXReport, false, requestedKeys.ToBytes()); @@ -393,7 +401,13 @@ protected override async Task Handle(ReportMessage message) Log.Verbose("Network Key Get Received"); byte[] resp = new byte[17]; SecurityKey key = (SecurityKey)message.Payload.Span[0]; - //TODO - Verify this was granted + KeyExchangeReport? grantedKeys = controller.SecurityManager.GetRequestedKeys(node.ID); + if (grantedKeys == null || (grantedKeys.Keys & key) != key) + { + await KexFail(KexFailType.KEX_FAIL_KEY_GET); + Log.Error("Network Key Get Received for an ungranted key"); + return SupervisionStatus.Fail; + } resp[0] = (byte)key; switch (key) { diff --git a/ZWaveDotNet/CommandClasses/TransportService.cs b/ZWaveDotNet/CommandClasses/TransportService.cs index e8b677a..1d54abb 100644 --- a/ZWaveDotNet/CommandClasses/TransportService.cs +++ b/ZWaveDotNet/CommandClasses/TransportService.cs @@ -74,7 +74,7 @@ public static void Transmit (List payload) sessionId = (byte)((msg.Payload.Span[1] & 0xF0) >> 4); if ((msg.Payload.Span[1] & 0x8) == 0x8) { - //We skip extensions for now + //No extensions are defined yet Log.Information("Transport Service skipped an extension"); msg.Payload = msg.Payload.Slice(msg.Payload.Span[2] + 3); } @@ -82,7 +82,7 @@ public static void Transmit (List payload) msg.Payload = msg.Payload.Slice(2); chk = crc.ComputeChecksum(msg.Payload.Slice(0, msg.Payload.Length - 2)); if (chk[0] == msg.Payload.Span[msg.Payload.Length - 2] && chk[1] == msg.Payload.Span[msg.Payload.Length - 1]) - Log.Verbose("Transport Checksum is OK"); + Log.Debug("Transport Checksum is OK"); buff = new byte[datagramLen]; msg.Payload.Slice(0, msg.Payload.Length - 2).CopyTo(buff); key = GetKey(msg.SourceNodeID, sessionId); @@ -101,7 +101,7 @@ public static void Transmit (List payload) ushort datagramOffset = (ushort)(((msg.Payload.Span[1] & 0x7) << 8) | msg.Payload.Span[2]); if ((msg.Payload.Span[1] & 0x8) == 0x8) { - //We skip extensions for now + //No extensions are defined yet Log.Information("Transport Service skipped an extension"); msg.Payload = msg.Payload.Slice(msg.Payload.Span[3] + 4); } @@ -109,7 +109,7 @@ public static void Transmit (List payload) msg.Payload = msg.Payload.Slice(3); chk = crc.ComputeChecksum(msg.Payload.Slice(0, msg.Payload.Length - 2)); if (chk[0] == msg.Payload.Span[msg.Payload.Length - 2] && chk[1] == msg.Payload.Span[msg.Payload.Length - 1]) - Log.Verbose("Transport Checksum is OK"); + Log.Debug("Transport Checksum is OK"); key = GetKey(msg.SourceNodeID, sessionId); if (!buffers.TryGetValue(key, out buff)) { diff --git a/ZWaveDotNet/Entities/Controller.cs b/ZWaveDotNet/Entities/Controller.cs index af817c5..afc18fe 100644 --- a/ZWaveDotNet/Entities/Controller.cs +++ b/ZWaveDotNet/Entities/Controller.cs @@ -265,7 +265,7 @@ public async Task> GetRandom(byte length, CancellationToken cancell if (random == null || random.Data.Span[0] == 0x0) //TODO - Status Enums { Memory planB = new byte[length]; - new Random().NextBytes(planB.Span); + RandomNumberGenerator.Fill(planB.Span); return planB; } return random!.Data.Slice(2); diff --git a/ZWaveDotNet/Security/AES.cs b/ZWaveDotNet/Security/AES.cs index c1dbaf6..f67d62d 100644 --- a/ZWaveDotNet/Security/AES.cs +++ b/ZWaveDotNet/Security/AES.cs @@ -23,13 +23,13 @@ public static class AES public struct KeyTuple { public byte[] KeyCCM; - public byte[] PString; - public byte[] MPAN; + public byte[] PersonalizationString; + public byte[] keyMPAN; public KeyTuple(byte[] keyCCM, byte[] pString, byte[] mPAN) { this.KeyCCM = keyCCM; - this.PString = pString; - this.MPAN = mPAN; + this.PersonalizationString = pString; + this.keyMPAN = mPAN; } } diff --git a/ZWaveDotNet/Security/SecurityManager.cs b/ZWaveDotNet/Security/SecurityManager.cs index 7da09d2..4ce62f4 100644 --- a/ZWaveDotNet/Security/SecurityManager.cs +++ b/ZWaveDotNet/Security/SecurityManager.cs @@ -84,7 +84,7 @@ public void GrantKey(ushort nodeId, RecordType type, byte[]? key = null, bool te if (key == null) throw new ArgumentNullException(nameof(key)); AES.KeyTuple keyTuple = AES.CKDFExpand(key, temp); - StoreKey(nodeId, type, keyTuple.KeyCCM, keyTuple.PString, keyTuple.MPAN); + StoreKey(nodeId, type, keyTuple.KeyCCM, keyTuple.PersonalizationString, keyTuple.keyMPAN); } private void StoreKey(ushort nodeId, RecordType type, byte[]? keyCCM, byte[]? pString, byte[]? mPAN) @@ -143,6 +143,11 @@ public RecordType[] GetKeys(ushort nodeId) return null; } + public bool HasKey(ushort nodeId, RecordType key) + { + return GetKey(nodeId, key) != null; + } + public void RevokeKey(ushort nodeId, RecordType type) { if (keys.TryGetValue(nodeId, out List? keyLst)) @@ -217,6 +222,21 @@ public bool IsSequenceNew(ushort nodeId, byte sequence) return null; } + public Memory? CurrentMpanNonce(byte groupID, byte[] keyMPAN) + { + if (mpanRecords.TryGetValue(groupID, out MpanRecord? record)) + { + Memory result = new byte[16]; + using (Aes aes = Aes.Create()) + { + aes.Key = keyMPAN; + aes.EncryptEcb(record.Bytes.Span, result.Span, PaddingMode.None); + } + return result; + } + return null; + } + public Memory? NextMpanNonce(byte groupID, byte[] keyMPAN) { if (mpanRecords.TryGetValue(groupID, out MpanRecord? record)) diff --git a/ZWaveDotNet/Util/MemoryUtil.cs b/ZWaveDotNet/Util/MemoryUtil.cs index 14e5947..53eb142 100644 --- a/ZWaveDotNet/Util/MemoryUtil.cs +++ b/ZWaveDotNet/Util/MemoryUtil.cs @@ -73,8 +73,13 @@ public static void Increment(Span mem) public static string Print(Memory mem) { - StringBuilder ret = new StringBuilder(mem.Length * 3); - foreach (byte b in mem.Span) + return Print(mem.Span); + } + + public static string Print(ReadOnlySpan span) + { + StringBuilder ret = new StringBuilder(span.Length * 3); + foreach (byte b in span) { if (ret.Length > 0) ret.Append(' ');