// KService.cs (20 KB)
  1. using System;
  2. using System.Collections.Concurrent;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Net;
  7. using System.Net.Sockets;
  8. using System.Runtime.InteropServices;
  9. namespace ET
  10. {
  11. public static class KcpProtocalType
  12. {
  13. public const byte SYN = 1;
  14. public const byte ACK = 2;
  15. public const byte FIN = 3;
  16. public const byte MSG = 4;
  17. }
  18. public enum ServiceType
  19. {
  20. Outer,
  21. Inner,
  22. }
  23. public sealed class KService: AService
  24. {
  25. // KService创建的时间
  26. private readonly long startTime;
  27. // 当前时间 - KService创建的时间, 线程安全
  28. public uint TimeNow
  29. {
  30. get
  31. {
  32. return (uint) (TimeHelper.ClientNow() - this.startTime);
  33. }
  34. }
  35. private Socket socket;
  36. #region 回调方法
  37. static KService()
  38. {
  39. //Kcp.KcpSetLog(KcpLog);
  40. Kcp.KcpSetoutput(KcpOutput);
  41. }
  42. private static readonly byte[] logBuffer = new byte[1024];
  43. #if ENABLE_IL2CPP
  44. [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
  45. #endif
  46. private static void KcpLog(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
  47. {
  48. try
  49. {
  50. Marshal.Copy(bytes, logBuffer, 0, len);
  51. Log.Info(logBuffer.ToStr(0, len));
  52. }
  53. catch (Exception e)
  54. {
  55. Log.Error(e);
  56. }
  57. }
  58. #if ENABLE_IL2CPP
  59. [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
  60. #endif
  61. private static int KcpOutput(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
  62. {
  63. try
  64. {
  65. if (kcp == IntPtr.Zero)
  66. {
  67. return 0;
  68. }
  69. if (!KChannel.KcpPtrChannels.TryGetValue(kcp, out KChannel kChannel))
  70. {
  71. return 0;
  72. }
  73. kChannel.Output(bytes, len);
  74. }
  75. catch (Exception e)
  76. {
  77. Log.Error(e);
  78. return len;
  79. }
  80. return len;
  81. }
  82. #endregion
  83. public KService(ThreadSynchronizationContext threadSynchronizationContext, IPEndPoint ipEndPoint, ServiceType serviceType)
  84. {
  85. this.ServiceType = serviceType;
  86. this.ThreadSynchronizationContext = threadSynchronizationContext;
  87. this.startTime = TimeHelper.ClientNow();
  88. this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
  89. if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
  90. {
  91. this.socket.SendBufferSize = Kcp.OneM * 64;
  92. this.socket.ReceiveBufferSize = Kcp.OneM * 64;
  93. }
  94. this.socket.Bind(ipEndPoint);
  95. if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
  96. {
  97. const uint IOC_IN = 0x80000000;
  98. const uint IOC_VENDOR = 0x18000000;
  99. uint SIO_UDP_CONNRESET = IOC_IN | IOC_VENDOR | 12;
  100. this.socket.IOControl((int) SIO_UDP_CONNRESET, new[] { Convert.ToByte(false) }, null);
  101. }
  102. }
  103. public KService(ThreadSynchronizationContext threadSynchronizationContext, ServiceType serviceType)
  104. {
  105. this.ServiceType = serviceType;
  106. this.ThreadSynchronizationContext = threadSynchronizationContext;
  107. this.startTime = TimeHelper.ClientNow();
  108. this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
  109. // 作为客户端不需要修改发送跟接收缓冲区大小
  110. this.socket.Bind(new IPEndPoint(IPAddress.Any, 0));
  111. if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
  112. {
  113. const uint IOC_IN = 0x80000000;
  114. const uint IOC_VENDOR = 0x18000000;
  115. uint SIO_UDP_CONNRESET = IOC_IN | IOC_VENDOR | 12;
  116. this.socket.IOControl((int) SIO_UDP_CONNRESET, new[] { Convert.ToByte(false) }, null);
  117. }
  118. }
  119. public void ChangeAddress(long id, IPEndPoint address)
  120. {
  121. KChannel kChannel = this.Get(id);
  122. if (kChannel == null)
  123. {
  124. return;
  125. }
  126. Log.Info($"channel change address: {id} {address}");
  127. kChannel.RemoteAddress = address;
  128. }
  129. // 保存所有的channel
  130. private readonly Dictionary<long, KChannel> idChannels = new Dictionary<long, KChannel>();
  131. private readonly Dictionary<long, KChannel> localConnChannels = new Dictionary<long, KChannel>();
  132. private readonly Dictionary<long, KChannel> waitConnectChannels = new Dictionary<long, KChannel>();
  133. private readonly byte[] cache = new byte[8192];
  134. private EndPoint ipEndPoint = new IPEndPoint(IPAddress.Any, 0);
  135. // 下帧要更新的channel
  136. private readonly HashSet<long> updateChannels = new HashSet<long>();
  137. // 下次时间更新的channel
  138. private readonly MultiMap<long, long> timeId = new MultiMap<long, long>();
  139. private readonly List<long> timeOutTime = new List<long>();
  140. // 记录最小时间,不用每次都去MultiMap取第一个值
  141. private long minTime;
  142. public override bool IsDispose()
  143. {
  144. return this.socket == null;
  145. }
  146. public override void Dispose()
  147. {
  148. foreach (long channelId in this.idChannels.Keys.ToArray())
  149. {
  150. this.Remove(channelId);
  151. }
  152. this.socket.Close();
  153. this.socket = null;
  154. }
  155. private IPEndPoint CloneAddress()
  156. {
  157. IPEndPoint ip = (IPEndPoint) this.ipEndPoint;
  158. return new IPEndPoint(ip.Address, ip.Port);
  159. }
  160. private void Recv()
  161. {
  162. if (this.socket == null)
  163. {
  164. return;
  165. }
  166. while (socket != null && this.socket.Available > 0)
  167. {
  168. int messageLength = this.socket.ReceiveFrom(this.cache, ref this.ipEndPoint);
  169. // 长度小于1,不是正常的消息
  170. if (messageLength < 1)
  171. {
  172. continue;
  173. }
  174. // accept
  175. byte flag = this.cache[0];
  176. // conn从100开始,如果为1,2,3则是特殊包
  177. uint remoteConn = 0;
  178. uint localConn = 0;
  179. try
  180. {
  181. KChannel kChannel = null;
  182. switch (flag)
  183. {
  184. #if NOT_UNITY
  185. case KcpProtocalType.SYN: // accept
  186. {
  187. // 长度!=5,不是SYN消息
  188. if (messageLength < 9)
  189. {
  190. break;
  191. }
  192. string realAddress = null;
  193. remoteConn = BitConverter.ToUInt32(this.cache, 1);
  194. if (messageLength > 9)
  195. {
  196. realAddress = this.cache.ToStr(9, messageLength - 9);
  197. }
  198. remoteConn = BitConverter.ToUInt32(this.cache, 1);
  199. localConn = BitConverter.ToUInt32(this.cache, 5);
  200. this.waitConnectChannels.TryGetValue(remoteConn, out kChannel);
  201. if (kChannel == null)
  202. {
  203. localConn = CreateRandomLocalConn();
  204. // 已存在同样的localConn,则不处理,等待下次sync
  205. if (this.localConnChannels.ContainsKey(localConn))
  206. {
  207. break;
  208. }
  209. long id = this.CreateAcceptChannelId(localConn);
  210. if (this.idChannels.ContainsKey(id))
  211. {
  212. break;
  213. }
  214. kChannel = new KChannel(id, localConn, remoteConn, this.socket, this.CloneAddress(), this);
  215. this.idChannels.Add(kChannel.Id, kChannel);
  216. this.waitConnectChannels.Add(kChannel.RemoteConn, kChannel); // 连接上了或者超时后会删除
  217. this.localConnChannels.Add(kChannel.LocalConn, kChannel);
  218. kChannel.RealAddress = realAddress;
  219. IPEndPoint realEndPoint = kChannel.RealAddress == null? kChannel.RemoteAddress : NetworkHelper.ToIPEndPoint(kChannel.RealAddress);
  220. this.OnAccept(kChannel.Id, realEndPoint);
  221. }
  222. if (kChannel.RemoteConn != remoteConn)
  223. {
  224. break;
  225. }
  226. // 地址跟上次的不一致则跳过
  227. if (kChannel.RealAddress != realAddress)
  228. {
  229. Log.Error($"kchannel syn address diff: {kChannel.Id} {kChannel.RealAddress} {realAddress}");
  230. break;
  231. }
  232. try
  233. {
  234. byte[] buffer = this.cache;
  235. buffer.WriteTo(0, KcpProtocalType.ACK);
  236. buffer.WriteTo(1, kChannel.LocalConn);
  237. buffer.WriteTo(5, kChannel.RemoteConn);
  238. Log.Info($"kservice syn: {kChannel.Id} {remoteConn} {localConn}");
  239. this.socket.SendTo(buffer, 0, 9, SocketFlags.None, kChannel.RemoteAddress);
  240. }
  241. catch (Exception e)
  242. {
  243. Log.Error(e);
  244. kChannel.OnError(ErrorCode.ERR_SocketCantSend);
  245. }
  246. break;
  247. }
  248. #endif
  249. case KcpProtocalType.ACK: // connect返回
  250. // 长度!=9,不是connect消息
  251. if (messageLength != 9)
  252. {
  253. break;
  254. }
  255. remoteConn = BitConverter.ToUInt32(this.cache, 1);
  256. localConn = BitConverter.ToUInt32(this.cache, 5);
  257. kChannel = this.GetByLocalConn(localConn);
  258. if (kChannel != null)
  259. {
  260. Log.Info($"kservice ack: {kChannel.Id} {remoteConn} {localConn}");
  261. kChannel.RemoteConn = remoteConn;
  262. kChannel.HandleConnnect();
  263. }
  264. break;
  265. case KcpProtocalType.FIN: // 断开
  266. // 长度!=13,不是DisConnect消息
  267. if (messageLength != 13)
  268. {
  269. break;
  270. }
  271. remoteConn = BitConverter.ToUInt32(this.cache, 1);
  272. localConn = BitConverter.ToUInt32(this.cache, 5);
  273. int error = BitConverter.ToInt32(this.cache, 9);
  274. // 处理chanel
  275. kChannel = this.GetByLocalConn(localConn);
  276. if (kChannel == null)
  277. {
  278. break;
  279. }
  280. // 校验remoteConn,防止第三方攻击
  281. if (kChannel.RemoteConn != remoteConn)
  282. {
  283. break;
  284. }
  285. Log.Info($"kservice recv fin: {kChannel.Id} {localConn} {remoteConn} {error}");
  286. kChannel.OnError(ErrorCode.ERR_PeerDisconnect);
  287. break;
  288. case KcpProtocalType.MSG: // 断开
  289. // 长度<9,不是Msg消息
  290. if (messageLength < 9)
  291. {
  292. break;
  293. }
  294. // 处理chanel
  295. remoteConn = BitConverter.ToUInt32(this.cache, 1);
  296. localConn = BitConverter.ToUInt32(this.cache, 5);
  297. kChannel = this.GetByLocalConn(localConn);
  298. if (kChannel == null)
  299. {
  300. // 通知对方断开
  301. this.Disconnect(localConn, remoteConn, ErrorCode.ERR_KcpNotFoundChannel, (IPEndPoint) this.ipEndPoint, 1);
  302. break;
  303. }
  304. // 校验remoteConn,防止第三方攻击
  305. if (kChannel.RemoteConn != remoteConn)
  306. {
  307. break;
  308. }
  309. kChannel.HandleRecv(this.cache, 5, messageLength - 5);
  310. break;
  311. }
  312. }
  313. catch (Exception e)
  314. {
  315. Log.Error($"kservice error: {flag} {remoteConn} {localConn}\n{e}");
  316. }
  317. }
  318. }
  319. public KChannel Get(long id)
  320. {
  321. KChannel channel;
  322. this.idChannels.TryGetValue(id, out channel);
  323. return channel;
  324. }
  325. private KChannel GetByLocalConn(uint localConn)
  326. {
  327. KChannel channel;
  328. this.localConnChannels.TryGetValue(localConn, out channel);
  329. return channel;
  330. }
  331. protected override void Get(long id, IPEndPoint address)
  332. {
  333. if (this.idChannels.TryGetValue(id, out KChannel kChannel))
  334. {
  335. return;
  336. }
  337. try
  338. {
  339. // 低32bit是localConn
  340. uint localConn = (uint) ((ulong) id & uint.MaxValue);
  341. kChannel = new KChannel(id, localConn, this.socket, address, this);
  342. this.idChannels.Add(id, kChannel);
  343. this.localConnChannels.Add(kChannel.LocalConn, kChannel);
  344. }
  345. catch (Exception e)
  346. {
  347. Log.Error($"kservice get error: {id}\n{e}");
  348. }
  349. }
  350. public override void Remove(long id)
  351. {
  352. if (!this.idChannels.TryGetValue(id, out KChannel kChannel))
  353. {
  354. return;
  355. }
  356. Log.Info($"kservice remove channel: {id} {kChannel.LocalConn} {kChannel.RemoteConn}");
  357. this.idChannels.Remove(id);
  358. this.localConnChannels.Remove(kChannel.LocalConn);
  359. if (this.waitConnectChannels.TryGetValue(kChannel.RemoteConn, out KChannel waitChannel))
  360. {
  361. if (waitChannel.LocalConn == kChannel.LocalConn)
  362. {
  363. this.waitConnectChannels.Remove(kChannel.RemoteConn);
  364. }
  365. }
  366. kChannel.Dispose();
  367. }
  368. private void Disconnect(uint localConn, uint remoteConn, int error, IPEndPoint address, int times)
  369. {
  370. try
  371. {
  372. if (this.socket == null)
  373. {
  374. return;
  375. }
  376. byte[] buffer = this.cache;
  377. buffer.WriteTo(0, KcpProtocalType.FIN);
  378. buffer.WriteTo(1, localConn);
  379. buffer.WriteTo(5, remoteConn);
  380. buffer.WriteTo(9, (uint) error);
  381. for (int i = 0; i < times; ++i)
  382. {
  383. this.socket.SendTo(buffer, 0, 13, SocketFlags.None, address);
  384. }
  385. }
  386. catch (Exception e)
  387. {
  388. Log.Error($"Disconnect error {localConn} {remoteConn} {error} {address} {e}");
  389. }
  390. Log.Info($"channel send fin: {localConn} {remoteConn} {address} {error}");
  391. }
  392. protected override void Send(long channelId, long actorId, MemoryStream stream)
  393. {
  394. KChannel channel = this.Get(channelId);
  395. if (channel == null)
  396. {
  397. return;
  398. }
  399. channel.Send(actorId, stream);
  400. }
  401. // 服务端需要看channel的update时间是否已到
  402. public void AddToUpdateNextTime(long time, long id)
  403. {
  404. if (time == 0)
  405. {
  406. this.updateChannels.Add(id);
  407. return;
  408. }
  409. if (time < this.minTime)
  410. {
  411. this.minTime = time;
  412. }
  413. this.timeId.Add(time, id);
  414. }
  415. public override void Update()
  416. {
  417. this.Recv();
  418. this.TimerOut();
  419. foreach (long id in updateChannels)
  420. {
  421. KChannel kChannel = this.Get(id);
  422. if (kChannel == null)
  423. {
  424. continue;
  425. }
  426. if (kChannel.Id == 0)
  427. {
  428. continue;
  429. }
  430. kChannel.Update();
  431. }
  432. this.updateChannels.Clear();
  433. this.RemoveConnectTimeoutChannels();
  434. }
  435. private void RemoveConnectTimeoutChannels()
  436. {
  437. using (ListComponent<long> waitRemoveChannels = ListComponent<long>.Create())
  438. {
  439. foreach (long channelId in this.waitConnectChannels.Keys)
  440. {
  441. this.waitConnectChannels.TryGetValue(channelId, out KChannel kChannel);
  442. if (kChannel == null)
  443. {
  444. Log.Error($"RemoveConnectTimeoutChannels not found kchannel: {channelId}");
  445. continue;
  446. }
  447. // 连接上了要马上删除
  448. if (kChannel.IsConnected)
  449. {
  450. waitRemoveChannels.List.Add(channelId);
  451. }
  452. // 10秒连接超时
  453. if (this.TimeNow > kChannel.CreateTime + 10 * 1000)
  454. {
  455. waitRemoveChannels.List.Add(channelId);
  456. }
  457. }
  458. foreach (long channelId in waitRemoveChannels.List)
  459. {
  460. this.waitConnectChannels.Remove(channelId);
  461. }
  462. }
  463. }
  464. // 计算到期需要update的channel
  465. private void TimerOut()
  466. {
  467. if (this.timeId.Count == 0)
  468. {
  469. return;
  470. }
  471. uint timeNow = this.TimeNow;
  472. if (timeNow < this.minTime)
  473. {
  474. return;
  475. }
  476. this.timeOutTime.Clear();
  477. foreach (KeyValuePair<long, List<long>> kv in this.timeId)
  478. {
  479. long k = kv.Key;
  480. if (k > timeNow)
  481. {
  482. minTime = k;
  483. break;
  484. }
  485. this.timeOutTime.Add(k);
  486. }
  487. foreach (long k in this.timeOutTime)
  488. {
  489. foreach (long v in this.timeId[k])
  490. {
  491. this.updateChannels.Add(v);
  492. }
  493. this.timeId.Remove(k);
  494. }
  495. }
  496. }
  497. }