Skip to content

Commit

Permalink
[feat]支持ProxyProtocol协议,MqttServer能够获取经过代理之前的真实客户端IP地址,支持nginx/haprox…
Browse files Browse the repository at this point in the history
…y等反向代理
  • Loading branch information
nnhy committed Nov 3, 2024
1 parent 727e10c commit 3701f6c
Show file tree
Hide file tree
Showing 9 changed files with 300 additions and 25 deletions.
6 changes: 3 additions & 3 deletions NewLife.MQTT/Clusters/ClusterExchange.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class ClusterExchange : DisposeBase, ITracerFeature
/// <summary>订阅主题</summary>
/// <param name="session"></param>
/// <param name="message"></param>
public virtual void Subscribe(INetSession session, SubscribeMessage message)
public virtual void Subscribe(MqttSession session, SubscribeMessage message)
{
var myEndpoint = Cluster.GetNodeInfo().EndPoint;
var infos = message.Requests.Select(e => new SubscriptionInfo
Expand Down Expand Up @@ -55,7 +55,7 @@ public virtual void Subscribe(INetSession session, SubscribeMessage message)
/// <summary>取消主题订阅</summary>
/// <param name="session"></param>
/// <param name="message"></param>
public virtual void Unsubscribe(INetSession session, UnsubscribeMessage message)
public virtual void Unsubscribe(MqttSession session, UnsubscribeMessage message)
{
var myEndpoint = Cluster.GetNodeInfo().EndPoint;
var infos = message.TopicFilters.Select(e => new SubscriptionInfo
Expand Down Expand Up @@ -117,7 +117,7 @@ public void RemoveSubscription(SubscriptionInfo info)
/// </remarks>
/// <param name="session"></param>
/// <param name="message"></param>
public virtual void Publish(INetSession session, PublishMessage message)
public virtual void Publish(MqttSession session, PublishMessage message)
{
PublishInfo? info = null;

Expand Down
4 changes: 1 addition & 3 deletions NewLife.MQTT/Handlers/IMqttHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
using NewLife.Log;
using NewLife.MQTT.Clusters;
using NewLife.MQTT.Messaging;
using NewLife.Net;
using NewLife.Serialization;

namespace NewLife.MQTT.Handlers;

Expand Down Expand Up @@ -48,7 +46,7 @@ public interface IMqttHandler
public class MqttHandler : IMqttHandler, ITracerFeature, ILogFeature
{
/// <summary>网络会话</summary>
public INetSession Session { get; set; } = null!;
public MqttSession Session { get; set; } = null!;

/// <summary>消息交换机</summary>
public IMqttExchange? Exchange { get; set; }
Expand Down
7 changes: 7 additions & 0 deletions NewLife.MQTT/MqttClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using NewLife.Data;
using NewLife.Log;
using NewLife.MQTT.Messaging;
using NewLife.MQTT.ProxyProtocol;
using NewLife.Net;
using NewLife.Serialization;
using NewLife.Threading;
Expand Down Expand Up @@ -65,6 +66,9 @@ public class MqttClient : DisposeBase
/// </summary>
public Boolean Reconnect { get; set; } = true;

/// <summary>启用ProxyProtocol。仿造MQTT报文通过nginx/haproxy时的封包,仅用于测试,实际应用没有意义,默认false</summary>
public Boolean EnableProxyProtocol { get; set; }

/// <summary>
/// 连接成功后赋值为true
/// </summary>
Expand Down Expand Up @@ -152,6 +156,8 @@ private void Init()

client.Log = Log;
client.Timeout = Timeout;

if (EnableProxyProtocol) client.Add(new ProxyCodec());
client.Add(new MqttCodec());

// 关闭Tcp延迟以合并小包的算法,降低延迟
Expand Down Expand Up @@ -440,6 +446,7 @@ private void Client_Closed(Object sender, EventArgs e)
Disconnected?.Invoke(this, e);

if (Disposed || !Reconnect) return;

WriteLog("尝试重新连接");
ConnectAsync().GetAwaiter();
}
Expand Down
27 changes: 20 additions & 7 deletions NewLife.MQTT/MqttServer.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using NewLife.Data;
using System.Net;
using NewLife.Data;
using NewLife.Log;
using NewLife.Model;
using NewLife.MQTT.Clusters;
using NewLife.MQTT.Handlers;
using NewLife.MQTT.Messaging;
using NewLife.MQTT.ProxyProtocol;
using NewLife.Net;
using NewLife.Serialization;

Expand Down Expand Up @@ -37,20 +39,20 @@ public class MqttServer : NetServer<MqttSession>
/// <summary>启动</summary>
protected override void OnStart()
{
if (ServiceProvider == null) throw new NotSupportedException("未配置服务提供者ServiceProvider");
var provider = (ServiceProvider ?? ObjectContainer.Provider) ?? throw new NotSupportedException("未配置服务提供者ServiceProvider");

Name = $"Mqtt{Port}";

//JsonHost ??= ServiceProvider.GetService<IJsonHost>() ?? JsonHelper.Default;
Encoder ??= ServiceProvider.GetService<IPacketEncoder>() ?? new DefaultPacketEncoder();
Encoder ??= provider.GetService<IPacketEncoder>() ?? new DefaultPacketEncoder();
if (Encoder is DefaultPacketEncoder encoder)
{
var jsonHost = ServiceProvider.GetService<IJsonHost>();
var jsonHost = provider.GetService<IJsonHost>();
if (jsonHost != null) encoder.JsonHost = jsonHost;
}

var exchange = Exchange;
exchange ??= ServiceProvider.GetService<IMqttExchange>();
exchange ??= provider.GetService<IMqttExchange>();
exchange ??= new MqttExchange();

if (exchange is ITracerFeature feature)
Expand All @@ -61,6 +63,8 @@ protected override void OnStart()
// 创建集群
CreateCluster();

// 解码ProxyProtocol
Add(new ProxyCodec());
Add(new MqttCodec());

base.OnStart();
Expand All @@ -76,10 +80,11 @@ protected virtual void CreateCluster()

if (cluster != null)
{
var provider = (ServiceProvider ?? ObjectContainer.Provider) ?? throw new NotSupportedException("未配置服务提供者ServiceProvider");
var exchange = Exchange ?? throw new NotSupportedException("未配置消息交换机Exchange");

// 启动集群服务
cluster.ServiceProvider = ServiceProvider;
cluster.ServiceProvider = provider;
cluster.Log = Log;
cluster.Start();

Expand All @@ -96,7 +101,7 @@ protected virtual void CreateCluster()
}

// 创建集群交换机
var exchange2 = ServiceProvider?.GetService<ClusterExchange>();
var exchange2 = provider?.GetService<ClusterExchange>();
exchange2 ??= new ClusterExchange();

exchange2.Cluster = cluster;
Expand All @@ -120,6 +125,9 @@ protected override void OnStop(String? reason)
/// <summary>会话</summary>
public class MqttSession : NetSession<MqttServer>
{
/// <summary>远程地址信息。经过代理之前的地址</summary>
public new NetUri Remote { get; set; } = null!;

/// <summary>指令处理器</summary>
public IMqttHandler MqttHandler { get; set; } = null!;

Expand All @@ -140,6 +148,7 @@ protected override void OnConnected()
mqttHandler.Encoder = Host.Encoder;
}

Remote = base.Remote;
MqttHandler = handler;

base.OnConnected();
Expand Down Expand Up @@ -168,6 +177,10 @@ protected override void OnReceive(ReceivedEventArgs e)
using var span = Host.Tracer?.NewSpan($"mqtt:{msg.Type}", msg);
try
{
// 在连接指令中,修正远程地址,替换为经过代理之前的地址
//todo: 暂时通过覆盖Remote属性实现,后续考虑在NetSession中直接支持设置Remote属性,或者在核心库支持ProxyProtocol协议
if (msg.Type is MqttType.Connect && e.Remote != null) Remote.EndPoint = e.Remote;

// 执行处理器
result = MqttHandler?.Process(msg);
}
Expand Down
71 changes: 71 additions & 0 deletions NewLife.MQTT/ProxyProtocol/ProxyCodec.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
using System.Net;
using NewLife.Data;
using NewLife.Model;
using NewLife.Net;

namespace NewLife.MQTT.ProxyProtocol;

/// <summary>ProxyProtocol编码器</summary>
public class ProxyCodec : Handler
{
/// <summary>解析数据包,如果是ProxyProtocol协议则解码后返回</summary>
/// <param name="context"></param>
/// <param name="message"></param>
/// <returns></returns>
public override Object? Read(IHandlerContext context, Object message)
{
if (message is IPacket pk)
{
var data = pk.GetSpan();
if (ProxyMessage.FastValidHeader(data))
{
var msg = new ProxyMessage();
var rs = msg.Read(data);
if (rs > 0)
{
if (context is IExtend ext)
{
ext["Proxy"] = msg;

if (context is NetHandlerContext ctx && ctx.Data is ReceivedEventArgs e)
{
// 修改远程地址
e.Remote = msg.GetClient().EndPoint;
}
}

message = pk.Slice(rs);
}
}
}

return base.Read(context, message);
}

/// <summary>编码数据包,加上ProxyProtocol头部</summary>
/// <param name="context"></param>
/// <param name="message"></param>
/// <returns></returns>
public override Object? Write(IHandlerContext context, Object message)
{
if (message is IPacket pk && context.Owner is ISocketRemote remote)
{
var msg = new ProxyMessage
{
Protocol = "TCP4",
ClientIP = remote.Local.Address + "",
ClientPort = remote.Local.Port,
ProxyIP = remote.Remote.Address + "",
ProxyPort = remote.Remote.Port,
};

var ap = new ArrayPacket(msg.ToPacket().GetBytes());
//ap.Append(pk);
ap.Next = pk;

return base.Write(context, ap);
}

return base.Write(context, message);
}
}
97 changes: 97 additions & 0 deletions NewLife.MQTT/ProxyProtocol/ProxyMessage.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
using System.Net;
using System.Text;
using NewLife.Collections;
using NewLife.Data;
using NewLife.Net;

namespace NewLife.MQTT.ProxyProtocol;

/// <summary>ProxyProtocol协议消息</summary>
public class ProxyMessage
{
#region 属性
/// <summary>内部协议。TCP4等</summary>
public String? Protocol { get; set; }

/// <summary>客户端IP地址</summary>
public String? ClientIP { get; set; }

/// <summary>代理IP地址</summary>
public String? ProxyIP { get; set; }

/// <summary>客户端端口</summary>
public Int32 ClientPort { get; set; }

/// <summary>代理端口</summary>
public Int32 ProxyPort { get; set; }
#endregion

#region 核心读写方法
private static readonly Byte[] _Magic = [(Byte)'P', (Byte)'R', (Byte)'O', (Byte)'X', (Byte)'Y', (Byte)' '];
private static readonly Byte[] _NewLine = [(Byte)'\r', (Byte)'\n'];

/// <summary>快速验证协议头</summary>
/// <param name="data"></param>
/// <returns></returns>
public static Boolean FastValidHeader(ReadOnlySpan<Byte> data) => data.StartsWith(_Magic);

/// <summary>解析协议</summary>
/// <param name="data"></param>
/// <returns></returns>
public Int32 Read(ReadOnlySpan<Byte> data)
{
if (!data.StartsWith(_Magic)) return -1;

var p = _Magic.Length;
var p2 = data[p..].IndexOf(_NewLine);
if (p2 <= 0) return -1;

data = data[..(p + p2)];
var ss = Encoding.ASCII.GetString(data).Split(' ');
if (ss == null || ss.Length < 6) return -1;

Protocol = ss[1];
ClientIP = ss[2];
ProxyIP = ss[3];
ClientPort = ss[4].ToInt();
ProxyPort = ss[5].ToInt();

return p + p2 + _NewLine.Length;
}

/// <summary>转为数据包</summary>
/// <returns></returns>
public String ToPacket()
{
if (Protocol.IsNullOrEmpty()) Protocol = "TCP4";

var sb = Pool.StringBuilder.Get();
sb.Append("PROXY ");
sb.Append(Protocol);
sb.Append(' ');
sb.Append(ClientIP);
sb.Append(' ');
sb.Append(ProxyIP);
sb.Append(' ');
sb.Append(ClientPort);
sb.Append(' ');
sb.Append(ProxyPort);
sb.Append("\r\n");

return sb.Return(true);
}

/// <summary>获取客户端结点</summary>
/// <returns></returns>
public NetUri GetClient()
{
var uri = new NetUri
{
Host = ClientIP,
Port = ClientPort
};

return uri;
}
#endregion
}
Loading

0 comments on commit 3701f6c

Please sign in to comment.