KService.cs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583
  1. using System;
  2. using System.Collections.Concurrent;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Net;
  7. using System.Net.Sockets;
  8. using System.Runtime.InteropServices;
  9. namespace ET
  10. {
  11. public static class KcpProtocalType
  12. {
  13. public const byte SYN = 1;
  14. public const byte ACK = 2;
  15. public const byte FIN = 3;
  16. public const byte MSG = 4;
  17. }
  18. public enum ServiceType
  19. {
  20. Outer,
  21. Inner,
  22. }
  23. public sealed class KService: AService
  24. {
  25. // KService创建的时间
  26. private readonly long startTime;
  27. // 当前时间 - KService创建的时间, 线程安全
  28. public uint TimeNow
  29. {
  30. get
  31. {
  32. return (uint) (TimeHelper.ClientNow() - this.startTime);
  33. }
  34. }
  35. private Socket socket;
  36. #region 回调方法
  37. static KService()
  38. {
  39. //Kcp.KcpSetLog(KcpLog);
  40. Kcp.KcpSetoutput(KcpOutput);
  41. }
  42. private static readonly byte[] logBuffer = new byte[1024];
  43. #if ENABLE_IL2CPP
  44. [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
  45. #endif
  46. private static void KcpLog(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
  47. {
  48. try
  49. {
  50. Marshal.Copy(bytes, logBuffer, 0, len);
  51. Log.Info(logBuffer.ToStr(0, len));
  52. }
  53. catch (Exception e)
  54. {
  55. Log.Error(e);
  56. }
  57. }
  58. #if ENABLE_IL2CPP
  59. [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
  60. #endif
  61. private static int KcpOutput(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
  62. {
  63. try
  64. {
  65. if (kcp == IntPtr.Zero)
  66. {
  67. return 0;
  68. }
  69. if (!KChannel.KcpPtrChannels.TryGetValue(kcp, out KChannel kChannel))
  70. {
  71. return 0;
  72. }
  73. kChannel.Output(bytes, len);
  74. }
  75. catch (Exception e)
  76. {
  77. Log.Error(e);
  78. return len;
  79. }
  80. return len;
  81. }
  82. #endregion
  83. public KService(ThreadSynchronizationContext threadSynchronizationContext, IPEndPoint ipEndPoint, ServiceType serviceType)
  84. {
  85. this.ServiceType = serviceType;
  86. this.ThreadSynchronizationContext = threadSynchronizationContext;
  87. this.startTime = TimeHelper.ClientNow();
  88. this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
  89. if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
  90. {
  91. this.socket.SendBufferSize = Kcp.OneM * 64;
  92. this.socket.ReceiveBufferSize = Kcp.OneM * 64;
  93. }
  94. this.socket.Bind(ipEndPoint);
  95. if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
  96. {
  97. const uint IOC_IN = 0x80000000;
  98. const uint IOC_VENDOR = 0x18000000;
  99. uint SIO_UDP_CONNRESET = IOC_IN | IOC_VENDOR | 12;
  100. this.socket.IOControl((int) SIO_UDP_CONNRESET, new[] { Convert.ToByte(false) }, null);
  101. }
  102. }
  103. public KService(ThreadSynchronizationContext threadSynchronizationContext, ServiceType serviceType)
  104. {
  105. this.ServiceType = serviceType;
  106. this.ThreadSynchronizationContext = threadSynchronizationContext;
  107. this.startTime = TimeHelper.ClientNow();
  108. this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
  109. // 作为客户端不需要修改发送跟接收缓冲区大小
  110. this.socket.Bind(new IPEndPoint(IPAddress.Any, 0));
  111. if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
  112. {
  113. const uint IOC_IN = 0x80000000;
  114. const uint IOC_VENDOR = 0x18000000;
  115. uint SIO_UDP_CONNRESET = IOC_IN | IOC_VENDOR | 12;
  116. this.socket.IOControl((int) SIO_UDP_CONNRESET, new[] { Convert.ToByte(false) }, null);
  117. }
  118. }
  119. public void ChangeAddress(long id, IPEndPoint address)
  120. {
  121. KChannel kChannel = this.Get(id);
  122. if (kChannel == null)
  123. {
  124. return;
  125. }
  126. Log.Info($"channel change address: {id} {address}");
  127. kChannel.RemoteAddress = address;
  128. }
  129. // 保存所有的channel
  130. private readonly Dictionary<long, KChannel> idChannels = new Dictionary<long, KChannel>();
  131. private readonly Dictionary<long, KChannel> localConnChannels = new Dictionary<long, KChannel>();
  132. private readonly Dictionary<long, KChannel> waitConnectChannels = new Dictionary<long, KChannel>();
  133. private readonly byte[] cache = new byte[8192];
  134. private EndPoint ipEndPoint = new IPEndPoint(IPAddress.Any, 0);
  135. // 下帧要更新的channel
  136. private readonly HashSet<long> updateChannels = new HashSet<long>();
  137. // 下次时间更新的channel
  138. private readonly MultiMap<long, long> timeId = new MultiMap<long, long>();
  139. private readonly List<long> timeOutTime = new List<long>();
  140. // 记录最小时间,不用每次都去MultiMap取第一个值
  141. private long minTime;
  142. private List<long> waitRemoveChannels = new List<long>();
  143. public override bool IsDispose()
  144. {
  145. return this.socket == null;
  146. }
  147. public override void Dispose()
  148. {
  149. foreach (long channelId in this.idChannels.Keys.ToArray())
  150. {
  151. this.Remove(channelId);
  152. }
  153. this.socket.Close();
  154. this.socket = null;
  155. }
  156. private IPEndPoint CloneAddress()
  157. {
  158. IPEndPoint ip = (IPEndPoint) this.ipEndPoint;
  159. return new IPEndPoint(ip.Address, ip.Port);
  160. }
  161. private void Recv()
  162. {
  163. if (this.socket == null)
  164. {
  165. return;
  166. }
  167. while (socket != null && this.socket.Available > 0)
  168. {
  169. int messageLength = this.socket.ReceiveFrom(this.cache, ref this.ipEndPoint);
  170. // 长度小于1,不是正常的消息
  171. if (messageLength < 1)
  172. {
  173. continue;
  174. }
  175. // accept
  176. byte flag = this.cache[0];
  177. // conn从100开始,如果为1,2,3则是特殊包
  178. uint remoteConn = 0;
  179. uint localConn = 0;
  180. try
  181. {
  182. KChannel kChannel = null;
  183. switch (flag)
  184. {
  185. #if NOT_UNITY
  186. case KcpProtocalType.SYN: // accept
  187. {
  188. // 长度!=5,不是SYN消息
  189. if (messageLength < 9)
  190. {
  191. break;
  192. }
  193. string realAddress = null;
  194. remoteConn = BitConverter.ToUInt32(this.cache, 1);
  195. if (messageLength > 9)
  196. {
  197. realAddress = this.cache.ToStr(9, messageLength - 9);
  198. }
  199. remoteConn = BitConverter.ToUInt32(this.cache, 1);
  200. localConn = BitConverter.ToUInt32(this.cache, 5);
  201. this.waitConnectChannels.TryGetValue(remoteConn, out kChannel);
  202. if (kChannel == null)
  203. {
  204. localConn = CreateRandomLocalConn();
  205. // 已存在同样的localConn,则不处理,等待下次sync
  206. if (this.localConnChannels.ContainsKey(localConn))
  207. {
  208. break;
  209. }
  210. long id = this.CreateAcceptChannelId(localConn);
  211. if (this.idChannels.ContainsKey(id))
  212. {
  213. break;
  214. }
  215. kChannel = new KChannel(id, localConn, remoteConn, this.socket, this.CloneAddress(), this);
  216. this.idChannels.Add(kChannel.Id, kChannel);
  217. this.waitConnectChannels.Add(kChannel.RemoteConn, kChannel); // 连接上了或者超时后会删除
  218. this.localConnChannels.Add(kChannel.LocalConn, kChannel);
  219. kChannel.RealAddress = realAddress;
  220. IPEndPoint realEndPoint = kChannel.RealAddress == null? kChannel.RemoteAddress : NetworkHelper.ToIPEndPoint(kChannel.RealAddress);
  221. this.OnAccept(kChannel.Id, realEndPoint);
  222. }
  223. if (kChannel.RemoteConn != remoteConn)
  224. {
  225. break;
  226. }
  227. // 地址跟上次的不一致则跳过
  228. if (kChannel.RealAddress != realAddress)
  229. {
  230. Log.Error($"kchannel syn address diff: {kChannel.Id} {kChannel.RealAddress} {realAddress}");
  231. break;
  232. }
  233. try
  234. {
  235. byte[] buffer = this.cache;
  236. buffer.WriteTo(0, KcpProtocalType.ACK);
  237. buffer.WriteTo(1, kChannel.LocalConn);
  238. buffer.WriteTo(5, kChannel.RemoteConn);
  239. Log.Info($"kservice syn: {kChannel.Id} {remoteConn} {localConn}");
  240. this.socket.SendTo(buffer, 0, 9, SocketFlags.None, kChannel.RemoteAddress);
  241. }
  242. catch (Exception e)
  243. {
  244. Log.Error(e);
  245. kChannel.OnError(ErrorCore.ERR_SocketCantSend);
  246. }
  247. break;
  248. }
  249. #endif
  250. case KcpProtocalType.ACK: // connect返回
  251. // 长度!=9,不是connect消息
  252. if (messageLength != 9)
  253. {
  254. break;
  255. }
  256. remoteConn = BitConverter.ToUInt32(this.cache, 1);
  257. localConn = BitConverter.ToUInt32(this.cache, 5);
  258. kChannel = this.GetByLocalConn(localConn);
  259. if (kChannel != null)
  260. {
  261. Log.Info($"kservice ack: {kChannel.Id} {remoteConn} {localConn}");
  262. kChannel.RemoteConn = remoteConn;
  263. kChannel.HandleConnnect();
  264. }
  265. break;
  266. case KcpProtocalType.FIN: // 断开
  267. // 长度!=13,不是DisConnect消息
  268. if (messageLength != 13)
  269. {
  270. break;
  271. }
  272. remoteConn = BitConverter.ToUInt32(this.cache, 1);
  273. localConn = BitConverter.ToUInt32(this.cache, 5);
  274. int error = BitConverter.ToInt32(this.cache, 9);
  275. // 处理chanel
  276. kChannel = this.GetByLocalConn(localConn);
  277. if (kChannel == null)
  278. {
  279. break;
  280. }
  281. // 校验remoteConn,防止第三方攻击
  282. if (kChannel.RemoteConn != remoteConn)
  283. {
  284. break;
  285. }
  286. Log.Info($"kservice recv fin: {kChannel.Id} {localConn} {remoteConn} {error}");
  287. kChannel.OnError(ErrorCore.ERR_PeerDisconnect);
  288. break;
  289. case KcpProtocalType.MSG: // 断开
  290. // 长度<9,不是Msg消息
  291. if (messageLength < 9)
  292. {
  293. break;
  294. }
  295. // 处理chanel
  296. remoteConn = BitConverter.ToUInt32(this.cache, 1);
  297. localConn = BitConverter.ToUInt32(this.cache, 5);
  298. kChannel = this.GetByLocalConn(localConn);
  299. if (kChannel == null)
  300. {
  301. // 通知对方断开
  302. this.Disconnect(localConn, remoteConn, ErrorCore.ERR_KcpNotFoundChannel, (IPEndPoint) this.ipEndPoint, 1);
  303. break;
  304. }
  305. // 校验remoteConn,防止第三方攻击
  306. if (kChannel.RemoteConn != remoteConn)
  307. {
  308. break;
  309. }
  310. kChannel.HandleRecv(this.cache, 5, messageLength - 5);
  311. break;
  312. }
  313. }
  314. catch (Exception e)
  315. {
  316. Log.Error($"kservice error: {flag} {remoteConn} {localConn}\n{e}");
  317. }
  318. }
  319. }
  320. public KChannel Get(long id)
  321. {
  322. KChannel channel;
  323. this.idChannels.TryGetValue(id, out channel);
  324. return channel;
  325. }
  326. private KChannel GetByLocalConn(uint localConn)
  327. {
  328. KChannel channel;
  329. this.localConnChannels.TryGetValue(localConn, out channel);
  330. return channel;
  331. }
  332. protected override void Get(long id, IPEndPoint address)
  333. {
  334. if (this.idChannels.TryGetValue(id, out KChannel kChannel))
  335. {
  336. return;
  337. }
  338. try
  339. {
  340. // 低32bit是localConn
  341. uint localConn = (uint) ((ulong) id & uint.MaxValue);
  342. kChannel = new KChannel(id, localConn, this.socket, address, this);
  343. this.idChannels.Add(id, kChannel);
  344. this.localConnChannels.Add(kChannel.LocalConn, kChannel);
  345. }
  346. catch (Exception e)
  347. {
  348. Log.Error($"kservice get error: {id}\n{e}");
  349. }
  350. }
  351. public override void Remove(long id)
  352. {
  353. if (!this.idChannels.TryGetValue(id, out KChannel kChannel))
  354. {
  355. return;
  356. }
  357. Log.Info($"kservice remove channel: {id} {kChannel.LocalConn} {kChannel.RemoteConn}");
  358. this.idChannels.Remove(id);
  359. this.localConnChannels.Remove(kChannel.LocalConn);
  360. if (this.waitConnectChannels.TryGetValue(kChannel.RemoteConn, out KChannel waitChannel))
  361. {
  362. if (waitChannel.LocalConn == kChannel.LocalConn)
  363. {
  364. this.waitConnectChannels.Remove(kChannel.RemoteConn);
  365. }
  366. }
  367. kChannel.Dispose();
  368. }
  369. private void Disconnect(uint localConn, uint remoteConn, int error, IPEndPoint address, int times)
  370. {
  371. try
  372. {
  373. if (this.socket == null)
  374. {
  375. return;
  376. }
  377. byte[] buffer = this.cache;
  378. buffer.WriteTo(0, KcpProtocalType.FIN);
  379. buffer.WriteTo(1, localConn);
  380. buffer.WriteTo(5, remoteConn);
  381. buffer.WriteTo(9, (uint) error);
  382. for (int i = 0; i < times; ++i)
  383. {
  384. this.socket.SendTo(buffer, 0, 13, SocketFlags.None, address);
  385. }
  386. }
  387. catch (Exception e)
  388. {
  389. Log.Error($"Disconnect error {localConn} {remoteConn} {error} {address} {e}");
  390. }
  391. Log.Info($"channel send fin: {localConn} {remoteConn} {address} {error}");
  392. }
  393. protected override void Send(long channelId, long actorId, MemoryStream stream)
  394. {
  395. KChannel channel = this.Get(channelId);
  396. if (channel == null)
  397. {
  398. return;
  399. }
  400. channel.Send(actorId, stream);
  401. }
  402. // 服务端需要看channel的update时间是否已到
  403. public void AddToUpdateNextTime(long time, long id)
  404. {
  405. if (time == 0)
  406. {
  407. this.updateChannels.Add(id);
  408. return;
  409. }
  410. if (time < this.minTime)
  411. {
  412. this.minTime = time;
  413. }
  414. this.timeId.Add(time, id);
  415. }
  416. public override void Update()
  417. {
  418. this.Recv();
  419. this.TimerOut();
  420. foreach (long id in updateChannels)
  421. {
  422. KChannel kChannel = this.Get(id);
  423. if (kChannel == null)
  424. {
  425. continue;
  426. }
  427. if (kChannel.Id == 0)
  428. {
  429. continue;
  430. }
  431. kChannel.Update();
  432. }
  433. this.updateChannels.Clear();
  434. this.RemoveConnectTimeoutChannels();
  435. }
  436. private void RemoveConnectTimeoutChannels()
  437. {
  438. waitRemoveChannels.Clear();
  439. foreach (long channelId in this.waitConnectChannels.Keys)
  440. {
  441. this.waitConnectChannels.TryGetValue(channelId, out KChannel kChannel);
  442. if (kChannel == null)
  443. {
  444. Log.Error($"RemoveConnectTimeoutChannels not found kchannel: {channelId}");
  445. continue;
  446. }
  447. // 连接上了要马上删除
  448. if (kChannel.IsConnected)
  449. {
  450. waitRemoveChannels.Add(channelId);
  451. }
  452. // 10秒连接超时
  453. if (this.TimeNow > kChannel.CreateTime + 10 * 1000)
  454. {
  455. waitRemoveChannels.Add(channelId);
  456. }
  457. }
  458. foreach (long channelId in waitRemoveChannels)
  459. {
  460. this.waitConnectChannels.Remove(channelId);
  461. }
  462. }
  463. // 计算到期需要update的channel
  464. private void TimerOut()
  465. {
  466. if (this.timeId.Count == 0)
  467. {
  468. return;
  469. }
  470. uint timeNow = this.TimeNow;
  471. if (timeNow < this.minTime)
  472. {
  473. return;
  474. }
  475. this.timeOutTime.Clear();
  476. foreach (KeyValuePair<long, List<long>> kv in this.timeId)
  477. {
  478. long k = kv.Key;
  479. if (k > timeNow)
  480. {
  481. minTime = k;
  482. break;
  483. }
  484. this.timeOutTime.Add(k);
  485. }
  486. foreach (long k in this.timeOutTime)
  487. {
  488. foreach (long v in this.timeId[k])
  489. {
  490. this.updateChannels.Add(v);
  491. }
  492. this.timeId.Remove(k);
  493. }
  494. }
  495. }
  496. }