KService.cs 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643
  1. using System;
  2. using System.Collections.Concurrent;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Net;
  7. using System.Net.Sockets;
  8. using System.Runtime.InteropServices;
  9. namespace ET
  10. {
  // Tag values carried in the first byte of every UDP datagram handled by KService;
  // KService.Recv switches on this byte to pick the handler for the packet.
  public static class KcpProtocalType
  {
      public const byte
          SYN = 1,                  // handshake request, handled on the accepting side
          ACK = 2,                  // handshake reply, handled on the connecting side
          FIN = 3,                  // disconnect notice: [FIN][conn][conn][error]
          MSG = 4,                  // KCP payload datagram
          RouterReconnectSYN = 5,   // reconnect request; sender address may have changed
          RouterReconnectACK = 6,   // reply sent for a RouterReconnectSYN
          RouterSYN = 7,            // not dispatched by KService.Recv
          RouterACK = 8;            // not dispatched by KService.Recv
  }
  // Kind of service passed to the KService constructors.
  // NOTE(review): presumably Outer = client-facing and Inner = server-to-server — confirm with callers.
  public enum ServiceType
  {
      Outer = 0,
      Inner = 1,
  }
  // UDP service that carries KCP channels over one socket. Incoming datagrams are dispatched
  // by their first byte (KcpProtocalType); the service owns the channel lookup tables and
  // schedules per-channel KCP updates.
  public sealed class KService: AService
  {
      // Client time (TimeHelper.ClientNow) at which this KService was created.
      private readonly long startTime;

      // Current time minus creation time. The original author marks this thread-safe: it reads
      // only the readonly startTime and TimeHelper.ClientNow().
      // NOTE(review): the uint cast wraps after roughly 49.7 days of uptime.
      public uint TimeNow
      {
          get
          {
              return (uint) (TimeHelper.ClientNow() - this.startTime);
          }
      }

      // The single UDP socket shared by every channel; null once disposed.
      private Socket socket;

      #region Callbacks

      static KService()
      {
          // Native KCP log forwarding is deliberately left disabled.
          //Kcp.KcpSetLog(KcpLog);
          Kcp.KcpSetoutput(KcpOutput);
      }

      // Scratch buffer used by KcpLog to copy native log text into managed memory.
      private static readonly byte[] logBuffer = new byte[1024];

      // Native log callback: copies the message (truncated to logBuffer.Length) and logs it.
  #if ENABLE_IL2CPP
      // NOTE(review): this names the KcpOutput delegate type on the log callback; it probably
      // should name the log delegate type instead — confirm against the Kcp bindings.
      [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
  #endif
      private static void KcpLog(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
      {
          try
          {
              // Truncate over-long messages instead of letting Marshal.Copy overrun the buffer and throw.
              int count = Math.Min(len, logBuffer.Length);
              Marshal.Copy(bytes, logBuffer, 0, count);
              Log.Info(logBuffer.ToStr(0, count));
          }
          catch (Exception e)
          {
              Log.Error(e);
          }
      }

      // Native output callback: hands an outgoing KCP datagram to the channel owning the kcp handle.
      // Returns 0 when the handle is null or unknown; otherwise len, including after a logged error.
  #if ENABLE_IL2CPP
      [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
  #endif
      private static int KcpOutput(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
      {
          try
          {
              if (kcp == IntPtr.Zero)
              {
                  return 0;
              }
              if (!KChannel.KcpPtrChannels.TryGetValue(kcp, out KChannel kChannel))
              {
                  return 0;
              }
              kChannel.Output(bytes, len);
          }
          catch (Exception e)
          {
              Log.Error(e);
          }
          return len;
      }

      #endregion

      // Server-side constructor: binds to ipEndPoint and enlarges the socket buffers
      // (except on macOS, where the original code skips that step).
      public KService(ThreadSynchronizationContext threadSynchronizationContext, IPEndPoint ipEndPoint, ServiceType serviceType)
      {
          this.ServiceType = serviceType;
          this.ThreadSynchronizationContext = threadSynchronizationContext;
          this.startTime = TimeHelper.ClientNow();
          this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
          if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
          {
              this.socket.SendBufferSize = Kcp.OneM * 64;
              this.socket.ReceiveBufferSize = Kcp.OneM * 64;
          }
          this.socket.Bind(ipEndPoint);
          DisableUdpConnectionReset(this.socket);
      }

      // Client-side constructor: binds to any local address on an OS-chosen port.
      public KService(ThreadSynchronizationContext threadSynchronizationContext, ServiceType serviceType)
      {
          this.ServiceType = serviceType;
          this.ThreadSynchronizationContext = threadSynchronizationContext;
          this.startTime = TimeHelper.ClientNow();
          this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
          // As a client there is no need to change the send/receive buffer sizes.
          this.socket.Bind(new IPEndPoint(IPAddress.Any, 0));
          DisableUdpConnectionReset(this.socket);
      }

      // Windows only: sets SIO_UDP_CONNRESET to false so an ICMP "port unreachable" for an
      // earlier send is not reported as a connection reset on later UDP receives.
      private static void DisableUdpConnectionReset(Socket udpSocket)
      {
          if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
          {
              return;
          }
          const uint IOC_IN = 0x80000000;
          const uint IOC_VENDOR = 0x18000000;
          // Must stay a non-const local: a constant (int) cast of a value > int.MaxValue does not compile.
          uint SIO_UDP_CONNRESET = IOC_IN | IOC_VENDOR | 12;
          udpSocket.IOControl((int) SIO_UDP_CONNRESET, new[] { Convert.ToByte(false) }, null);
      }

      // Points an existing channel at a new remote address; no-op when the id is unknown.
      public void ChangeAddress(long id, IPEndPoint address)
      {
          KChannel kChannel = this.Get(id);
          if (kChannel == null)
          {
              return;
          }
          Log.Info($"channel change address: {id} {address}");
          kChannel.RemoteAddress = address;
      }

      // All channels, keyed by channel id.
      private readonly Dictionary<long, KChannel> idChannels = new Dictionary<long, KChannel>();
      // All channels, keyed by their local conn.
      private readonly Dictionary<long, KChannel> localConnChannels = new Dictionary<long, KChannel>();
      // Accepted channels still in the handshake, keyed by remote conn; pruned by RemoveConnectTimeoutChannels.
      private readonly Dictionary<long, KChannel> waitConnectChannels = new Dictionary<long, KChannel>();
      // Scratch buffer for received datagrams and for the handshake/FIN replies built from them.
      private readonly byte[] cache = new byte[8192];
      // Filled by ReceiveFrom with the sender of the datagram currently being processed.
      private EndPoint ipEndPoint = new IPEndPoint(IPAddress.Any, 0);
      // Channels to update on the next Update() call.
      private readonly HashSet<long> updateChannels = new HashSet<long>();
      // Channels to update once TimeNow reaches a given time (time -> channel ids).
      private readonly MultiMap<long, long> timeId = new MultiMap<long, long>();
      // Scratch list of expired times collected by TimerOut.
      private readonly List<long> timeOutTime = new List<long>();
      // Cached smallest pending time, so TimerOut need not walk the MultiMap every frame.
      private long minTime;
      // Scratch list of keys removed from waitConnectChannels.
      private readonly List<long> waitRemoveChannels = new List<long>();

      public override bool IsDispose()
      {
          return this.socket == null;
      }

      // Removes (and disposes) every channel, then closes the socket. Safe to call more than once.
      public override void Dispose()
      {
          if (this.socket == null)
          {
              return;
          }
          foreach (long channelId in this.idChannels.Keys.ToArray())
          {
              this.Remove(channelId);
          }
          this.socket.Close();
          this.socket = null;
      }

      // Copies the current sender endpoint; ipEndPoint itself is overwritten by the next ReceiveFrom.
      private IPEndPoint CloneAddress()
      {
          IPEndPoint ip = (IPEndPoint) this.ipEndPoint;
          return new IPEndPoint(ip.Address, ip.Port);
      }

      // Drains every datagram currently queued on the socket and dispatches it by its first byte.
      // Errors while handling one datagram are logged and do not stop the loop.
      private void Recv()
      {
          if (this.socket == null)
          {
              return;
          }

          while (this.socket != null && this.socket.Available > 0)
          {
              int messageLength;
              try
              {
                  messageLength = this.socket.ReceiveFrom(this.cache, ref this.ipEndPoint);
              }
              catch (SocketException e)
              {
                  // e.g. a datagram larger than cache on Windows (SocketError.MessageSize). Stop draining
                  // for this frame instead of letting the exception escape Update().
                  Log.Error($"kservice recv error: {e.SocketErrorCode}\n{e}");
                  break;
              }

              // Fewer than 1 byte: not a valid message.
              if (messageLength < 1)
              {
                  continue;
              }

              byte flag = this.cache[0];

              // Original note: conn values start at 100; 1, 2, 3 denote special packets.
              uint remoteConn = 0;
              uint localConn = 0;

              try
              {
                  KChannel kChannel = null;
                  switch (flag)
                  {
  #if NOT_UNITY
                      case KcpProtocalType.RouterReconnectSYN:
                      {
                          // Layout: [flag][remoteConn:4][localConn:4][connectId:4] = exactly 13 bytes.
                          if (messageLength != 13)
                          {
                              break;
                          }
                          // Always null on this path; kept so the log lines below keep their shape.
                          string realAddress = null;
                          remoteConn = BitConverter.ToUInt32(this.cache, 1);
                          localConn = BitConverter.ToUInt32(this.cache, 5);
                          uint connectId = BitConverter.ToUInt32(this.cache, 9);
                          this.localConnChannels.TryGetValue(localConn, out kChannel);
                          if (kChannel == null)
                          {
                              Log.Warning($"kchannel reconnect not found channel: {localConn} {remoteConn} {realAddress}");
                              break;
                          }

                          // localConn must be checked: on a client reconnect it is always the same.
                          if (localConn != kChannel.LocalConn)
                          {
                              Log.Warning($"kchannel reconnect localconn error: {kChannel.Id} {localConn} {remoteConn} {realAddress} {kChannel.LocalConn}");
                              break;
                          }

                          if (remoteConn != kChannel.RemoteConn)
                          {
                              Log.Warning($"kchannel reconnect remoteconn error: {kChannel.Id} {localConn} {remoteConn} {realAddress} {kChannel.RemoteConn}");
                              break;
                          }

                          // The router address may change on reconnect. This switch is done only here, after
                          // both conns were verified, and never on ordinary MSG packets.
                          if (!Equals(kChannel.RemoteAddress, this.ipEndPoint))
                          {
                              kChannel.RemoteAddress = this.CloneAddress();
                          }

                          try
                          {
                              byte[] buffer = this.cache;
                              buffer.WriteTo(0, KcpProtocalType.RouterReconnectACK);
                              buffer.WriteTo(1, kChannel.LocalConn);
                              buffer.WriteTo(5, kChannel.RemoteConn);
                              buffer.WriteTo(9, connectId);
                              this.socket.SendTo(buffer, 0, 13, SocketFlags.None, this.ipEndPoint);
                          }
                          catch (Exception e)
                          {
                              Log.Error(e);
                              kChannel.OnError(ErrorCore.ERR_SocketCantSend);
                          }

                          break;
                      }
                      case KcpProtocalType.SYN: // accept
                      {
                          // Layout: [flag][remoteConn:4][localConn:4][optional realAddress text]; at least 9 bytes.
                          if (messageLength < 9)
                          {
                              break;
                          }

                          string realAddress = null;
                          if (messageLength > 9)
                          {
                              realAddress = this.cache.ToStr(9, messageLength - 9);
                          }

                          remoteConn = BitConverter.ToUInt32(this.cache, 1);
                          localConn = BitConverter.ToUInt32(this.cache, 5);

                          this.waitConnectChannels.TryGetValue(remoteConn, out kChannel);
                          if (kChannel == null)
                          {
                              // realAddress comes from the network: parse it before registering anything, so a
                              // malformed value cannot leave a half-registered channel in the lookup tables.
                              IPEndPoint parsedRealAddress = realAddress == null? null : NetworkHelper.ToIPEndPoint(realAddress);

                              localConn = CreateRandomLocalConn();
                              // The localConn is already taken: ignore this SYN and wait for the next one.
                              if (this.localConnChannels.ContainsKey(localConn))
                              {
                                  break;
                              }
                              long id = this.CreateAcceptChannelId(localConn);
                              if (this.idChannels.ContainsKey(id))
                              {
                                  break;
                              }
                              kChannel = new KChannel(id, localConn, remoteConn, this.socket, this.CloneAddress(), this);
                              this.idChannels.Add(kChannel.Id, kChannel);
                              this.waitConnectChannels.Add(kChannel.RemoteConn, kChannel); // removed once connected or timed out
                              this.localConnChannels.Add(kChannel.LocalConn, kChannel);

                              kChannel.RealAddress = realAddress;

                              IPEndPoint realEndPoint = parsedRealAddress ?? kChannel.RemoteAddress;
                              this.OnAccept(kChannel.Id, realEndPoint);
                          }
                          if (kChannel.RemoteConn != remoteConn)
                          {
                              break;
                          }

                          // Skip when the address differs from the one sent in the first SYN.
                          if (kChannel.RealAddress != realAddress)
                          {
                              Log.Error($"kchannel syn address diff: {kChannel.Id} {kChannel.RealAddress} {realAddress}");
                              break;
                          }

                          try
                          {
                              // Reply: [ACK][LocalConn][RemoteConn], 9 bytes.
                              byte[] buffer = this.cache;
                              buffer.WriteTo(0, KcpProtocalType.ACK);
                              buffer.WriteTo(1, kChannel.LocalConn);
                              buffer.WriteTo(5, kChannel.RemoteConn);
                              Log.Info($"kservice syn: {kChannel.Id} {remoteConn} {localConn}");
                              this.socket.SendTo(buffer, 0, 9, SocketFlags.None, kChannel.RemoteAddress);
                          }
                          catch (Exception e)
                          {
                              Log.Error(e);
                              kChannel.OnError(ErrorCore.ERR_SocketCantSend);
                          }

                          break;
                      }
  #endif
                      case KcpProtocalType.ACK: // reply to our connect
                          // Layout: [flag][remoteConn:4][localConn:4] = exactly 9 bytes.
                          if (messageLength != 9)
                          {
                              break;
                          }

                          remoteConn = BitConverter.ToUInt32(this.cache, 1);
                          localConn = BitConverter.ToUInt32(this.cache, 5);
                          kChannel = this.GetByLocalConn(localConn);
                          if (kChannel != null)
                          {
                              Log.Info($"kservice ack: {kChannel.Id} {remoteConn} {localConn}");
                              kChannel.RemoteConn = remoteConn;
                              kChannel.HandleConnnect();
                          }

                          break;
                      case KcpProtocalType.FIN: // disconnect
                          // Layout: [flag][remoteConn:4][localConn:4][error:4] = exactly 13 bytes.
                          if (messageLength != 13)
                          {
                              break;
                          }

                          remoteConn = BitConverter.ToUInt32(this.cache, 1);
                          localConn = BitConverter.ToUInt32(this.cache, 5);
                          int error = BitConverter.ToInt32(this.cache, 9);

                          kChannel = this.GetByLocalConn(localConn);
                          if (kChannel == null)
                          {
                              break;
                          }

                          // Verify remoteConn so a third party cannot tear down the channel.
                          if (kChannel.RemoteConn != remoteConn)
                          {
                              break;
                          }

                          Log.Info($"kservice recv fin: {kChannel.Id} {localConn} {remoteConn} {error}");
                          kChannel.OnError(ErrorCore.ERR_PeerDisconnect);

                          break;
                      case KcpProtocalType.MSG: // data
                          // Layout: [flag][remoteConn:4][KCP segment...]; bytes 5..8 are read as localConn.
                          if (messageLength < 9)
                          {
                              break;
                          }

                          remoteConn = BitConverter.ToUInt32(this.cache, 1);
                          localConn = BitConverter.ToUInt32(this.cache, 5);
                          kChannel = this.GetByLocalConn(localConn);
                          if (kChannel == null)
                          {
                              // Tell the peer to disconnect.
                              this.Disconnect(localConn, remoteConn, ErrorCore.ERR_KcpNotFoundChannel, (IPEndPoint) this.ipEndPoint, 1);
                              break;
                          }

                          // Verify remoteConn so a third party cannot inject data.
                          if (kChannel.RemoteConn != remoteConn)
                          {
                              break;
                          }

                          // Hand over everything from offset 5, i.e. starting with the localConn field.
                          kChannel.HandleRecv(this.cache, 5, messageLength - 5);
                          break;
                  }
              }
              catch (Exception e)
              {
                  Log.Error($"kservice error: {flag} {remoteConn} {localConn}\n{e}");
              }
          }
      }

      // Returns the channel with this id, or null.
      public KChannel Get(long id)
      {
          KChannel channel;
          this.idChannels.TryGetValue(id, out channel);
          return channel;
      }

      // Returns the channel with this local conn, or null.
      private KChannel GetByLocalConn(uint localConn)
      {
          KChannel channel;
          this.localConnChannels.TryGetValue(localConn, out channel);
          return channel;
      }

      // Creates an outgoing (connecting) channel for id unless one already exists.
      // The low 32 bits of id are used as the channel's localConn.
      protected override void Get(long id, IPEndPoint address)
      {
          if (this.idChannels.ContainsKey(id))
          {
              return;
          }

          try
          {
              uint localConn = (uint) ((ulong) id & uint.MaxValue);
              // Check before creating the channel: a duplicate localConn would otherwise leave the new
              // channel in idChannels while localConnChannels still routes to the old one.
              if (this.localConnChannels.ContainsKey(localConn))
              {
                  Log.Error($"kservice get error: {id} localConn already in use: {localConn}");
                  return;
              }
              KChannel kChannel = new KChannel(id, localConn, this.socket, address, this);
              this.idChannels.Add(id, kChannel);
              this.localConnChannels.Add(kChannel.LocalConn, kChannel);
          }
          catch (Exception e)
          {
              Log.Error($"kservice get error: {id}\n{e}");
          }
      }

      // Unregisters the channel from every lookup table and disposes it; no-op for unknown ids.
      public override void Remove(long id)
      {
          if (!this.idChannels.TryGetValue(id, out KChannel kChannel))
          {
              return;
          }
          Log.Info($"kservice remove channel: {id} {kChannel.LocalConn} {kChannel.RemoteConn}");
          this.idChannels.Remove(id);
          this.localConnChannels.Remove(kChannel.LocalConn);
          // Only drop the waiting entry if it belongs to this channel and not to another one
          // that happens to share the remote conn.
          if (this.waitConnectChannels.TryGetValue(kChannel.RemoteConn, out KChannel waitChannel)
              && waitChannel.LocalConn == kChannel.LocalConn)
          {
              this.waitConnectChannels.Remove(kChannel.RemoteConn);
          }
          kChannel.Dispose();
      }

      // Sends a FIN datagram [FIN][localConn][remoteConn][error] `times` times to address.
      // Send failures are logged; nothing is sent once the service is disposed.
      private void Disconnect(uint localConn, uint remoteConn, int error, IPEndPoint address, int times)
      {
          try
          {
              if (this.socket == null)
              {
                  return;
              }

              byte[] buffer = this.cache;
              buffer.WriteTo(0, KcpProtocalType.FIN);
              buffer.WriteTo(1, localConn);
              buffer.WriteTo(5, remoteConn);
              buffer.WriteTo(9, (uint) error);
              for (int i = 0; i < times; ++i)
              {
                  this.socket.SendTo(buffer, 0, 13, SocketFlags.None, address);
              }
          }
          catch (Exception e)
          {
              Log.Error($"Disconnect error {localConn} {remoteConn} {error} {address} {e}");
          }

          Log.Info($"channel send fin: {localConn} {remoteConn} {address} {error}");
      }

      // Forwards the stream to the channel's Send; silently dropped for unknown channels.
      protected override void Send(long channelId, long actorId, MemoryStream stream)
      {
          KChannel channel = this.Get(channelId);
          if (channel == null)
          {
              return;
          }

          channel.Send(actorId, stream);
      }

      // Schedules channel `id` for a KCP update: next Update() when time is 0, otherwise once
      // TimeNow reaches `time`.
      public void AddToUpdateNextTime(long time, long id)
      {
          if (time == 0)
          {
              this.updateChannels.Add(id);
              return;
          }
          if (time < this.minTime)
          {
              this.minTime = time;
          }
          this.timeId.Add(time, id);
      }

      // One service tick: receive datagrams, promote due timers, update scheduled channels,
      // then prune the handshake wait list.
      public override void Update()
      {
          this.Recv();

          this.TimerOut();

          foreach (long id in this.updateChannels)
          {
              KChannel kChannel = this.Get(id);
              if (kChannel == null)
              {
                  continue;
              }

              if (kChannel.Id == 0)
              {
                  continue;
              }

              kChannel.Update();
          }

          this.updateChannels.Clear();

          this.RemoveConnectTimeoutChannels();
      }

      // Drops entries from waitConnectChannels whose channel is connected or older than 10 s.
      private void RemoveConnectTimeoutChannels()
      {
          this.waitRemoveChannels.Clear();

          foreach (KeyValuePair<long, KChannel> kv in this.waitConnectChannels)
          {
              long channelId = kv.Key;
              KChannel kChannel = kv.Value;
              if (kChannel == null)
              {
                  Log.Error($"RemoveConnectTimeoutChannels not found kchannel: {channelId}");
                  continue;
              }

              // Connected channels leave the list right away; others time out after 10 seconds.
              if (kChannel.IsConnected || this.TimeNow > kChannel.CreateTime + 10 * 1000)
              {
                  this.waitRemoveChannels.Add(channelId);
              }
          }

          foreach (long channelId in this.waitRemoveChannels)
          {
              this.waitConnectChannels.Remove(channelId);
          }
      }

      // Moves every channel whose scheduled time has passed into updateChannels.
      // NOTE(review): the early break assumes MultiMap enumerates keys in ascending order — confirm.
      private void TimerOut()
      {
          if (this.timeId.Count == 0)
          {
              return;
          }

          uint timeNow = this.TimeNow;

          if (timeNow < this.minTime)
          {
              return;
          }

          this.timeOutTime.Clear();

          foreach (KeyValuePair<long, List<long>> kv in this.timeId)
          {
              long k = kv.Key;
              if (k > timeNow)
              {
                  this.minTime = k;
                  break;
              }

              this.timeOutTime.Add(k);
          }

          foreach (long k in this.timeOutTime)
          {
              foreach (long v in this.timeId[k])
              {
                  this.updateChannels.Add(v);
              }

              this.timeId.Remove(k);
          }
      }
  }
  549. }