KService.cs 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641
  1. using System;
  2. using System.Collections.Concurrent;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Net;
  7. using System.Net.Sockets;
  8. using System.Runtime.InteropServices;
  9. namespace ET
  10. {
  11. public static class KcpProtocalType
  12. {
  13. public const byte SYN = 1;
  14. public const byte ACK = 2;
  15. public const byte FIN = 3;
  16. public const byte MSG = 4;
  17. public const byte RouterReconnectSYN = 5;
  18. public const byte RouterReconnectACK = 6;
  19. public const byte RouterSYN = 7;
  20. public const byte RouterACK = 8;
  21. }
  22. public enum ServiceType
  23. {
  24. Outer,
  25. Inner,
  26. }
  27. public sealed class KService: AService
  28. {
  29. // KService创建的时间
  30. private readonly long startTime;
  31. // 当前时间 - KService创建的时间, 线程安全
  32. public uint TimeNow
  33. {
  34. get
  35. {
  36. return (uint) (TimeHelper.ClientNow() - this.startTime);
  37. }
  38. }
  39. private Socket socket;
  40. #region 回调方法
  41. static KService()
  42. {
  43. //Kcp.KcpSetLog(KcpLog);
  44. Kcp.KcpSetoutput(KcpOutput);
  45. }
  46. private static readonly byte[] logBuffer = new byte[1024];
  47. #if ENABLE_IL2CPP
  48. [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
  49. #endif
  50. private static void KcpLog(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
  51. {
  52. try
  53. {
  54. Marshal.Copy(bytes, logBuffer, 0, len);
  55. Log.Info(logBuffer.ToStr(0, len));
  56. }
  57. catch (Exception e)
  58. {
  59. Log.Error(e);
  60. }
  61. }
  62. #if ENABLE_IL2CPP
  63. [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
  64. #endif
  65. private static int KcpOutput(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
  66. {
  67. try
  68. {
  69. if (kcp == IntPtr.Zero)
  70. {
  71. return 0;
  72. }
  73. if (!KChannel.KcpPtrChannels.TryGetValue(kcp, out KChannel kChannel))
  74. {
  75. return 0;
  76. }
  77. kChannel.Output(bytes, len);
  78. }
  79. catch (Exception e)
  80. {
  81. Log.Error(e);
  82. return len;
  83. }
  84. return len;
  85. }
  86. #endregion
  87. public KService(ThreadSynchronizationContext threadSynchronizationContext, IPEndPoint ipEndPoint, ServiceType serviceType)
  88. {
  89. this.ServiceType = serviceType;
  90. this.ThreadSynchronizationContext = threadSynchronizationContext;
  91. this.startTime = TimeHelper.ClientNow();
  92. this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
  93. if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
  94. {
  95. this.socket.SendBufferSize = Kcp.OneM * 64;
  96. this.socket.ReceiveBufferSize = Kcp.OneM * 64;
  97. }
  98. this.socket.Bind(ipEndPoint);
  99. if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
  100. {
  101. const uint IOC_IN = 0x80000000;
  102. const uint IOC_VENDOR = 0x18000000;
  103. uint SIO_UDP_CONNRESET = IOC_IN | IOC_VENDOR | 12;
  104. this.socket.IOControl((int) SIO_UDP_CONNRESET, new[] { Convert.ToByte(false) }, null);
  105. }
  106. }
  107. public KService(ThreadSynchronizationContext threadSynchronizationContext, ServiceType serviceType)
  108. {
  109. this.ServiceType = serviceType;
  110. this.ThreadSynchronizationContext = threadSynchronizationContext;
  111. this.startTime = TimeHelper.ClientNow();
  112. this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
  113. // 作为客户端不需要修改发送跟接收缓冲区大小
  114. this.socket.Bind(new IPEndPoint(IPAddress.Any, 0));
  115. if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
  116. {
  117. const uint IOC_IN = 0x80000000;
  118. const uint IOC_VENDOR = 0x18000000;
  119. uint SIO_UDP_CONNRESET = IOC_IN | IOC_VENDOR | 12;
  120. this.socket.IOControl((int) SIO_UDP_CONNRESET, new[] { Convert.ToByte(false) }, null);
  121. }
  122. }
  123. public void ChangeAddress(long id, IPEndPoint address)
  124. {
  125. KChannel kChannel = this.Get(id);
  126. if (kChannel == null)
  127. {
  128. return;
  129. }
  130. Log.Info($"channel change address: {id} {address}");
  131. kChannel.RemoteAddress = address;
  132. }
  133. // 保存所有的channel
  134. private readonly Dictionary<long, KChannel> idChannels = new Dictionary<long, KChannel>();
  135. private readonly Dictionary<long, KChannel> localConnChannels = new Dictionary<long, KChannel>();
  136. private readonly Dictionary<long, KChannel> waitConnectChannels = new Dictionary<long, KChannel>();
  137. private readonly byte[] cache = new byte[8192];
  138. private EndPoint ipEndPoint = new IPEndPoint(IPAddress.Any, 0);
  139. // 下帧要更新的channel
  140. private readonly HashSet<long> updateChannels = new HashSet<long>();
  141. // 下次时间更新的channel
  142. private readonly MultiMap<long, long> timeId = new MultiMap<long, long>();
  143. private readonly List<long> timeOutTime = new List<long>();
  144. // 记录最小时间,不用每次都去MultiMap取第一个值
  145. private long minTime;
  146. private List<long> waitRemoveChannels = new List<long>();
  147. public override bool IsDispose()
  148. {
  149. return this.socket == null;
  150. }
  151. public override void Dispose()
  152. {
  153. foreach (long channelId in this.idChannels.Keys.ToArray())
  154. {
  155. this.Remove(channelId);
  156. }
  157. this.socket.Close();
  158. this.socket = null;
  159. }
  160. private IPEndPoint CloneAddress()
  161. {
  162. IPEndPoint ip = (IPEndPoint) this.ipEndPoint;
  163. return new IPEndPoint(ip.Address, ip.Port);
  164. }
  165. private void Recv()
  166. {
  167. if (this.socket == null)
  168. {
  169. return;
  170. }
  171. while (socket != null && this.socket.Available > 0)
  172. {
  173. int messageLength = this.socket.ReceiveFrom(this.cache, ref this.ipEndPoint);
  174. // 长度小于1,不是正常的消息
  175. if (messageLength < 1)
  176. {
  177. continue;
  178. }
  179. // accept
  180. byte flag = this.cache[0];
  181. // conn从100开始,如果为1,2,3则是特殊包
  182. uint remoteConn = 0;
  183. uint localConn = 0;
  184. try
  185. {
  186. KChannel kChannel = null;
  187. switch (flag)
  188. {
  189. case KcpProtocalType.RouterReconnectSYN:
  190. {
  191. // 长度!=5,不是RouterReconnectSYN消息
  192. if (messageLength != 13)
  193. {
  194. break;
  195. }
  196. string realAddress = null;
  197. remoteConn = BitConverter.ToUInt32(this.cache, 1);
  198. localConn = BitConverter.ToUInt32(this.cache, 5);
  199. uint connectId = BitConverter.ToUInt32(this.cache, 9);
  200. this.localConnChannels.TryGetValue(localConn, out kChannel);
  201. if (kChannel == null)
  202. {
  203. Log.Warning($"kchannel reconnect not found channel: {localConn} {remoteConn} {realAddress}");
  204. break;
  205. }
  206. // 这里必须校验localConn,客户端重连,localConn一定是一样的
  207. if (localConn != kChannel.LocalConn)
  208. {
  209. Log.Warning($"kchannel reconnect localconn error: {kChannel.Id} {localConn} {remoteConn} {realAddress} {kChannel.LocalConn}");
  210. break;
  211. }
  212. if (remoteConn != kChannel.RemoteConn)
  213. {
  214. Log.Warning($"kchannel reconnect remoteconn error: {kChannel.Id} {localConn} {remoteConn} {realAddress} {kChannel.RemoteConn}");
  215. break;
  216. }
  217. // 重连的时候router地址变化, 这个不能放到msg中,必须经过严格的验证才能切换
  218. if (!Equals(kChannel.RemoteAddress, this.ipEndPoint))
  219. {
  220. kChannel.RemoteAddress = this.CloneAddress();
  221. }
  222. try
  223. {
  224. byte[] buffer = this.cache;
  225. buffer.WriteTo(0, KcpProtocalType.RouterReconnectACK);
  226. buffer.WriteTo(1, kChannel.LocalConn);
  227. buffer.WriteTo(5, kChannel.RemoteConn);
  228. buffer.WriteTo(9, connectId);
  229. this.socket.SendTo(buffer, 0, 13, SocketFlags.None, this.ipEndPoint);
  230. }
  231. catch (Exception e)
  232. {
  233. Log.Error(e);
  234. kChannel.OnError(ErrorCore.ERR_SocketCantSend);
  235. }
  236. break;
  237. }
  238. case KcpProtocalType.SYN: // accept
  239. {
  240. // 长度!=5,不是SYN消息
  241. if (messageLength < 9)
  242. {
  243. break;
  244. }
  245. string realAddress = null;
  246. remoteConn = BitConverter.ToUInt32(this.cache, 1);
  247. if (messageLength > 9)
  248. {
  249. realAddress = this.cache.ToStr(9, messageLength - 9);
  250. }
  251. remoteConn = BitConverter.ToUInt32(this.cache, 1);
  252. localConn = BitConverter.ToUInt32(this.cache, 5);
  253. this.waitConnectChannels.TryGetValue(remoteConn, out kChannel);
  254. if (kChannel == null)
  255. {
  256. localConn = CreateRandomLocalConn();
  257. // 已存在同样的localConn,则不处理,等待下次sync
  258. if (this.localConnChannels.ContainsKey(localConn))
  259. {
  260. break;
  261. }
  262. long id = this.CreateAcceptChannelId(localConn);
  263. if (this.idChannels.ContainsKey(id))
  264. {
  265. break;
  266. }
  267. kChannel = new KChannel(id, localConn, remoteConn, this.socket, this.CloneAddress(), this);
  268. this.idChannels.Add(kChannel.Id, kChannel);
  269. this.waitConnectChannels.Add(kChannel.RemoteConn, kChannel); // 连接上了或者超时后会删除
  270. this.localConnChannels.Add(kChannel.LocalConn, kChannel);
  271. kChannel.RealAddress = realAddress;
  272. IPEndPoint realEndPoint = kChannel.RealAddress == null? kChannel.RemoteAddress : NetworkHelper.ToIPEndPoint(kChannel.RealAddress);
  273. this.OnAccept(kChannel.Id, realEndPoint);
  274. }
  275. if (kChannel.RemoteConn != remoteConn)
  276. {
  277. break;
  278. }
  279. // 地址跟上次的不一致则跳过
  280. if (kChannel.RealAddress != realAddress)
  281. {
  282. Log.Error($"kchannel syn address diff: {kChannel.Id} {kChannel.RealAddress} {realAddress}");
  283. break;
  284. }
  285. try
  286. {
  287. byte[] buffer = this.cache;
  288. buffer.WriteTo(0, KcpProtocalType.ACK);
  289. buffer.WriteTo(1, kChannel.LocalConn);
  290. buffer.WriteTo(5, kChannel.RemoteConn);
  291. Log.Info($"kservice syn: {kChannel.Id} {remoteConn} {localConn}");
  292. this.socket.SendTo(buffer, 0, 9, SocketFlags.None, kChannel.RemoteAddress);
  293. }
  294. catch (Exception e)
  295. {
  296. Log.Error(e);
  297. kChannel.OnError(ErrorCore.ERR_SocketCantSend);
  298. }
  299. break;
  300. }
  301. case KcpProtocalType.ACK: // connect返回
  302. // 长度!=9,不是connect消息
  303. if (messageLength != 9)
  304. {
  305. break;
  306. }
  307. remoteConn = BitConverter.ToUInt32(this.cache, 1);
  308. localConn = BitConverter.ToUInt32(this.cache, 5);
  309. kChannel = this.GetByLocalConn(localConn);
  310. if (kChannel != null)
  311. {
  312. Log.Info($"kservice ack: {kChannel.Id} {remoteConn} {localConn}");
  313. kChannel.RemoteConn = remoteConn;
  314. kChannel.HandleConnnect();
  315. }
  316. break;
  317. case KcpProtocalType.FIN: // 断开
  318. // 长度!=13,不是DisConnect消息
  319. if (messageLength != 13)
  320. {
  321. break;
  322. }
  323. remoteConn = BitConverter.ToUInt32(this.cache, 1);
  324. localConn = BitConverter.ToUInt32(this.cache, 5);
  325. int error = BitConverter.ToInt32(this.cache, 9);
  326. // 处理chanel
  327. kChannel = this.GetByLocalConn(localConn);
  328. if (kChannel == null)
  329. {
  330. break;
  331. }
  332. // 校验remoteConn,防止第三方攻击
  333. if (kChannel.RemoteConn != remoteConn)
  334. {
  335. break;
  336. }
  337. Log.Info($"kservice recv fin: {kChannel.Id} {localConn} {remoteConn} {error}");
  338. kChannel.OnError(ErrorCore.ERR_PeerDisconnect);
  339. break;
  340. case KcpProtocalType.MSG: // 断开
  341. // 长度<9,不是Msg消息
  342. if (messageLength < 9)
  343. {
  344. break;
  345. }
  346. // 处理chanel
  347. remoteConn = BitConverter.ToUInt32(this.cache, 1);
  348. localConn = BitConverter.ToUInt32(this.cache, 5);
  349. kChannel = this.GetByLocalConn(localConn);
  350. if (kChannel == null)
  351. {
  352. // 通知对方断开
  353. this.Disconnect(localConn, remoteConn, ErrorCore.ERR_KcpNotFoundChannel, (IPEndPoint) this.ipEndPoint, 1);
  354. break;
  355. }
  356. // 校验remoteConn,防止第三方攻击
  357. if (kChannel.RemoteConn != remoteConn)
  358. {
  359. break;
  360. }
  361. kChannel.HandleRecv(this.cache, 5, messageLength - 5);
  362. break;
  363. }
  364. }
  365. catch (Exception e)
  366. {
  367. Log.Error($"kservice error: {flag} {remoteConn} {localConn}\n{e}");
  368. }
  369. }
  370. }
  371. public KChannel Get(long id)
  372. {
  373. KChannel channel;
  374. this.idChannels.TryGetValue(id, out channel);
  375. return channel;
  376. }
  377. private KChannel GetByLocalConn(uint localConn)
  378. {
  379. KChannel channel;
  380. this.localConnChannels.TryGetValue(localConn, out channel);
  381. return channel;
  382. }
  383. protected override void Get(long id, IPEndPoint address)
  384. {
  385. if (this.idChannels.TryGetValue(id, out KChannel kChannel))
  386. {
  387. return;
  388. }
  389. try
  390. {
  391. // 低32bit是localConn
  392. uint localConn = (uint) ((ulong) id & uint.MaxValue);
  393. kChannel = new KChannel(id, localConn, this.socket, address, this);
  394. this.idChannels.Add(id, kChannel);
  395. this.localConnChannels.Add(kChannel.LocalConn, kChannel);
  396. }
  397. catch (Exception e)
  398. {
  399. Log.Error($"kservice get error: {id}\n{e}");
  400. }
  401. }
  402. public override void Remove(long id)
  403. {
  404. if (!this.idChannels.TryGetValue(id, out KChannel kChannel))
  405. {
  406. return;
  407. }
  408. Log.Info($"kservice remove channel: {id} {kChannel.LocalConn} {kChannel.RemoteConn}");
  409. this.idChannels.Remove(id);
  410. this.localConnChannels.Remove(kChannel.LocalConn);
  411. if (this.waitConnectChannels.TryGetValue(kChannel.RemoteConn, out KChannel waitChannel))
  412. {
  413. if (waitChannel.LocalConn == kChannel.LocalConn)
  414. {
  415. this.waitConnectChannels.Remove(kChannel.RemoteConn);
  416. }
  417. }
  418. kChannel.Dispose();
  419. }
  420. private void Disconnect(uint localConn, uint remoteConn, int error, IPEndPoint address, int times)
  421. {
  422. try
  423. {
  424. if (this.socket == null)
  425. {
  426. return;
  427. }
  428. byte[] buffer = this.cache;
  429. buffer.WriteTo(0, KcpProtocalType.FIN);
  430. buffer.WriteTo(1, localConn);
  431. buffer.WriteTo(5, remoteConn);
  432. buffer.WriteTo(9, (uint) error);
  433. for (int i = 0; i < times; ++i)
  434. {
  435. this.socket.SendTo(buffer, 0, 13, SocketFlags.None, address);
  436. }
  437. }
  438. catch (Exception e)
  439. {
  440. Log.Error($"Disconnect error {localConn} {remoteConn} {error} {address} {e}");
  441. }
  442. Log.Info($"channel send fin: {localConn} {remoteConn} {address} {error}");
  443. }
  444. protected override void Send(long channelId, long actorId, MemoryStream stream)
  445. {
  446. KChannel channel = this.Get(channelId);
  447. if (channel == null)
  448. {
  449. return;
  450. }
  451. channel.Send(actorId, stream);
  452. }
  453. // 服务端需要看channel的update时间是否已到
  454. public void AddToUpdateNextTime(long time, long id)
  455. {
  456. if (time == 0)
  457. {
  458. this.updateChannels.Add(id);
  459. return;
  460. }
  461. if (time < this.minTime)
  462. {
  463. this.minTime = time;
  464. }
  465. this.timeId.Add(time, id);
  466. }
  467. public override void Update()
  468. {
  469. this.Recv();
  470. this.TimerOut();
  471. foreach (long id in updateChannels)
  472. {
  473. KChannel kChannel = this.Get(id);
  474. if (kChannel == null)
  475. {
  476. continue;
  477. }
  478. if (kChannel.Id == 0)
  479. {
  480. continue;
  481. }
  482. kChannel.Update();
  483. }
  484. this.updateChannels.Clear();
  485. this.RemoveConnectTimeoutChannels();
  486. }
  487. private void RemoveConnectTimeoutChannels()
  488. {
  489. waitRemoveChannels.Clear();
  490. foreach (long channelId in this.waitConnectChannels.Keys)
  491. {
  492. this.waitConnectChannels.TryGetValue(channelId, out KChannel kChannel);
  493. if (kChannel == null)
  494. {
  495. Log.Error($"RemoveConnectTimeoutChannels not found kchannel: {channelId}");
  496. continue;
  497. }
  498. // 连接上了要马上删除
  499. if (kChannel.IsConnected)
  500. {
  501. waitRemoveChannels.Add(channelId);
  502. }
  503. // 10秒连接超时
  504. if (this.TimeNow > kChannel.CreateTime + 10 * 1000)
  505. {
  506. waitRemoveChannels.Add(channelId);
  507. }
  508. }
  509. foreach (long channelId in waitRemoveChannels)
  510. {
  511. this.waitConnectChannels.Remove(channelId);
  512. }
  513. }
  514. // 计算到期需要update的channel
  515. private void TimerOut()
  516. {
  517. if (this.timeId.Count == 0)
  518. {
  519. return;
  520. }
  521. uint timeNow = this.TimeNow;
  522. if (timeNow < this.minTime)
  523. {
  524. return;
  525. }
  526. this.timeOutTime.Clear();
  527. foreach (KeyValuePair<long, List<long>> kv in this.timeId)
  528. {
  529. long k = kv.Key;
  530. if (k > timeNow)
  531. {
  532. minTime = k;
  533. break;
  534. }
  535. this.timeOutTime.Add(k);
  536. }
  537. foreach (long k in this.timeOutTime)
  538. {
  539. foreach (long v in this.timeId[k])
  540. {
  541. this.updateChannels.Add(v);
  542. }
  543. this.timeId.Remove(k);
  544. }
  545. }
  546. }
  547. }