Просмотр исходного кода

kcp增加检验,防止第三方消息攻击

tanghai 7 лет назад
Родитель
Сommit
8d167f44a6

BIN
Server/ThirdParty/Libs/libkcp.so


+ 1 - 3
Tools/cwRsync/Config/exclude.txt

@@ -12,7 +12,5 @@ Unity/Assets/Res/
 Unity/Assets/Resources/
 Unity/Assets/Bundles/
 Unity/Assets/Scenes/
-Unity/.vs/
-Unity/.idea/
-Server/.vs/
 .vs/
+.idea/

BIN
Unity/Assets/Plugins/Android/libs/armeabi-v7a/libkcp.so


BIN
Unity/Assets/Plugins/Android/libs/x86/libkcp.so


BIN
Unity/Assets/Plugins/x86/kcp.dll


BIN
Unity/Assets/Plugins/x86_64/kcp.dll


+ 5 - 0
Unity/Assets/Scripts/Base/Helper/ByteHelper.cs

@@ -76,6 +76,11 @@ namespace ETModel
 			bytes[offset + 3] = (byte)((num & 0xff000000) >> 24);
 		}
 		
+		public static void WriteTo(this byte[] bytes, int offset, byte num)
+		{
+			bytes[offset] = num;
+		}
+		
 		public static void WriteTo(this byte[] bytes, int offset, short num)
 		{
 			bytes[offset] = (byte)(num & 0xff);

+ 24 - 19
Unity/Assets/Scripts/Module/Message/Network/KCP/KChannel.cs

@@ -28,7 +28,7 @@ namespace ETModel
 		private readonly Queue<WaitSendBuffer> sendBuffer = new Queue<WaitSendBuffer>();
 
 		private bool isConnected;
-		public bool IsRecvFirstKcpMessage { get; set; }
+		public bool isRecvFirstKcpMessage;
 		private readonly IPEndPoint remoteEndPoint;
 
 		private uint lastRecvTime;
@@ -57,8 +57,9 @@ namespace ETModel
 			);
 			Kcp.KcpNodelay(this.kcp, 1, 10, 1, 1);
 			Kcp.KcpWndsize(this.kcp, 256, 256);
+			Kcp.KcpSetmtu(this.kcp, 470);
 			this.isConnected = true;
-			this.IsRecvFirstKcpMessage = false;
+			this.isRecvFirstKcpMessage = false;
 			this.lastRecvTime = kService.TimeNow;
 		}
 
@@ -70,7 +71,7 @@ namespace ETModel
 			this.LocalConn = localConn;
 			this.socket = socket;
 			this.remoteEndPoint = remoteEndPoint;
-			this.IsRecvFirstKcpMessage = false;
+			this.isRecvFirstKcpMessage = false;
 			this.lastRecvTime = kService.TimeNow;
 			this.Connect();
 		}
@@ -156,6 +157,7 @@ namespace ETModel
 			);
 			Kcp.KcpNodelay(this.kcp, 1, 10, 1, 1);
 			Kcp.KcpWndsize(this.kcp, 256, 256);
+			Kcp.KcpSetmtu(this.kcp, 470);
 
 			this.isConnected = true;
 			this.lastRecvTime = this.GetService().TimeNow;
@@ -171,16 +173,16 @@ namespace ETModel
 			}
 
 			// 如果channel已经收到过消息,则不再响应连接请求
-			if (this.IsRecvFirstKcpMessage)
+			if (this.isRecvFirstKcpMessage)
 			{
 				return;
 			}
 			try
 			{
 				this.packet.Bytes.WriteTo(0, KcpProtocalType.ACK);
-				this.packet.Bytes.WriteTo(4, LocalConn);
-				this.packet.Bytes.WriteTo(8, RemoteConn);
-				this.socket.SendTo(this.packet.Bytes, 0, 12, SocketFlags.None, remoteEndPoint);
+				this.packet.Bytes.WriteTo(1, LocalConn);
+				this.packet.Bytes.WriteTo(5, RemoteConn);
+				this.socket.SendTo(this.packet.Bytes, 0, 9, SocketFlags.None, remoteEndPoint);
 			}
 			catch (Exception e)
 			{
@@ -198,8 +200,8 @@ namespace ETModel
 			{
 				uint timeNow = this.GetService().TimeNow;
 				this.packet.Bytes.WriteTo(0, KcpProtocalType.SYN);
-				this.packet.Bytes.WriteTo(4, this.LocalConn);
-				this.socket.SendTo(this.packet.Bytes, 0, 8, SocketFlags.None, remoteEndPoint);
+				this.packet.Bytes.WriteTo(1, this.LocalConn);
+				this.socket.SendTo(this.packet.Bytes, 0, 5, SocketFlags.None, remoteEndPoint);
 
 				// 200毫秒后再次update发送connect请求
 				this.GetService().AddToUpdateNextTime(timeNow + 200, this.Id);
@@ -220,10 +222,10 @@ namespace ETModel
 			try
 			{
 				this.packet.Bytes.WriteTo(0, KcpProtocalType.FIN);
-				this.packet.Bytes.WriteTo(4, this.LocalConn);
-				this.packet.Bytes.WriteTo(8, this.RemoteConn);
-				this.packet.Bytes.WriteTo(12, (uint)this.Error);
-				this.socket.SendTo(this.packet.Bytes, 0, 16, SocketFlags.None, remoteEndPoint);
+				this.packet.Bytes.WriteTo(1, this.LocalConn);
+				this.packet.Bytes.WriteTo(5, this.RemoteConn);
+				this.packet.Bytes.WriteTo(9, (uint)this.Error);
+				this.socket.SendTo(this.packet.Bytes, 0, 13, SocketFlags.None, remoteEndPoint);
 			}
 			catch (Exception e)
 			{
@@ -294,7 +296,7 @@ namespace ETModel
 			}
 		}
 
-		public void HandleRecv(byte[] date, int length)
+		public void HandleRecv(byte[] date, int offset, int length)
 		{
 			if (this.IsDisposed)
 			{
@@ -302,13 +304,13 @@ namespace ETModel
 			}
 
 			// 收到了kcp消息则将自己从连接状态移除
-			if (!this.IsRecvFirstKcpMessage)
+			if (!this.isRecvFirstKcpMessage)
 			{
 				this.GetService().RemoveFromWaitConnectChannels(this.RemoteConn);
-				this.IsRecvFirstKcpMessage = true;
+				this.isRecvFirstKcpMessage = true;
 			}
 
-			Kcp.KcpInput(this.kcp, date, length);
+			Kcp.KcpInput(this.kcp, date, offset, length);
 			this.GetService().AddToUpdateNextTime(0, this.Id);
 
 			while (true)
@@ -361,8 +363,11 @@ namespace ETModel
 					return;
 				}
 
-				Marshal.Copy(bytes, this.packet.Bytes, 0, count);
-				this.socket.SendTo(this.packet.Bytes, 0, count, SocketFlags.None, this.remoteEndPoint);
+				this.packet.Bytes.WriteTo(0, KcpProtocalType.MSG);
+				// 每个消息头部写下该channel的id;
+				this.packet.Bytes.WriteTo(1, this.LocalConn);
+				Marshal.Copy(bytes, this.packet.Bytes, 5, count);
+				this.socket.SendTo(this.packet.Bytes, 0, count + 5, SocketFlags.None, this.remoteEndPoint);
 			}
 			catch (Exception e)
 			{

+ 35 - 23
Unity/Assets/Scripts/Module/Message/Network/KCP/KService.cs

@@ -8,9 +8,10 @@ namespace ETModel
 {
 	public static class KcpProtocalType
 	{
-		public const uint SYN = 1;
-		public const uint ACK = 2;
-		public const uint FIN = 3;
+		public const byte SYN = 1;
+		public const byte ACK = 2;
+		public const byte FIN = 3;
+		public const byte MSG = 4;
 	}
 
 	public sealed class KService : AService
@@ -123,23 +124,23 @@ namespace ETModel
 					continue;
 				}
 
-				// 长度小于4,不是正常的消息
-				if (messageLength < 4)
+				// 长度小于1,不是正常的消息
+				if (messageLength < 1)
 				{
 					continue;
 				}
 				// accept
-				uint conn = BitConverter.ToUInt32(this.cache, 0);
-
+				byte flag = this.cache[0];
+				
 				// conn从1000开始,如果为1,2,3则是特殊包
 				uint remoteConn = 0;
 				uint localConn = 0;
 				KChannel kChannel = null;
-				switch (conn)
+				switch (flag)
 				{
 					case KcpProtocalType.SYN:  // accept
-											   // 长度!=8,不是accpet消息
-						if (messageLength != 8)
+						// 长度!=5,不是accpet消息
+						if (messageLength != 5)
 						{
 							break;
 						}
@@ -147,7 +148,7 @@ namespace ETModel
 						IPEndPoint acceptIpEndPoint = (IPEndPoint)this.ipEndPoint;
 						this.ipEndPoint = new IPEndPoint(0, 0);
 
-						remoteConn = BitConverter.ToUInt32(this.cache, 4);
+						remoteConn = BitConverter.ToUInt32(this.cache, 1);
 
 						// 如果等待连接状态,则重新响应请求
 						if (this.waitConnectChannels.TryGetValue(remoteConn, out kChannel))
@@ -167,13 +168,13 @@ namespace ETModel
 
 						break;
 					case KcpProtocalType.ACK:  // connect返回
-											   // 长度!=12,不是connect消息
-						if (messageLength != 12)
+						// 长度!=9,不是connect消息
+						if (messageLength != 9)
 						{
 							break;
 						}
-						remoteConn = BitConverter.ToUInt32(this.cache, 4);
-						localConn = BitConverter.ToUInt32(this.cache, 8);
+						remoteConn = BitConverter.ToUInt32(this.cache, 1);
+						localConn = BitConverter.ToUInt32(this.cache, 5);
 
 						kChannel = this.GetKChannel(localConn);
 						if (kChannel != null)
@@ -182,32 +183,43 @@ namespace ETModel
 						}
 						break;
 					case KcpProtocalType.FIN:  // 断开
-											   // 长度!=12,不是DisConnect消息
-						if (messageLength != 16)
+						// 长度!=13,不是DisConnect消息
+						if (messageLength != 13)
 						{
 							break;
 						}
 
-						remoteConn = BitConverter.ToUInt32(this.cache, 4);
-						localConn = BitConverter.ToUInt32(this.cache, 8);
+						remoteConn = BitConverter.ToUInt32(this.cache, 1);
+						localConn = BitConverter.ToUInt32(this.cache, 5);
 
 						// 处理chanel
 						kChannel = this.GetKChannel(localConn);
 						if (kChannel != null)
 						{
+							// 校验remoteConn,防止第三方攻击
 							if (kChannel.RemoteConn == remoteConn)
 							{
 								kChannel.Disconnect(ErrorCode.ERR_PeerDisconnect);
 							}
 						}
 						break;
-					default:  // 接收
-							  // 处理chanel
-						localConn = conn;
+					case KcpProtocalType.MSG:  // 断开
+						// 长度<9,不是Msg消息
+						if (messageLength < 9)
+						{
+							break;
+						}
+						// 处理chanel
+						remoteConn = BitConverter.ToUInt32(this.cache, 1);
+						localConn = BitConverter.ToUInt32(this.cache, 5);
 						kChannel = this.GetKChannel(localConn);
 						if (kChannel != null)
 						{
-							kChannel.HandleRecv(this.cache, messageLength);
+							// 校验remoteConn,防止第三方攻击
+							if (kChannel.RemoteConn == remoteConn)
+							{
+								kChannel.HandleRecv(this.cache, 5, messageLength - 5);
+							}
 						}
 						break;
 				}

+ 3 - 3
Unity/Assets/Scripts/Module/Message/Network/KCP/Kcp.cs

@@ -24,7 +24,7 @@ namespace ETModel
         [DllImport(KcpDLL, CallingConvention=CallingConvention.Cdecl)]
         public static extern uint ikcp_getconv(IntPtr ptr);
         [DllImport(KcpDLL, CallingConvention=CallingConvention.Cdecl)]
-        public static extern int ikcp_input(IntPtr kcp, byte[] data, long size);
+        public static extern int ikcp_input(IntPtr kcp, byte[] data, long offset, long size);
         [DllImport(KcpDLL, CallingConvention=CallingConvention.Cdecl)]
         public static extern int ikcp_nodelay(IntPtr kcp, int nodelay, int interval, int resend, int nc);
         [DllImport(KcpDLL, CallingConvention=CallingConvention.Cdecl)]
@@ -68,9 +68,9 @@ namespace ETModel
             return ikcp_getconv(ptr);
         }
 
-        public static int KcpInput(IntPtr kcp, byte[] data, long size)
+        public static int KcpInput(IntPtr kcp, byte[] data, long offset, long size)
         {
-            return ikcp_input(kcp, data, size);
+            return ikcp_input(kcp, data, offset, size);
         }
 
         public static int KcpNodelay(IntPtr kcp, int nodelay, int interval, int resend, int nc)