KService.cs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5. using System.Net;
  6. using System.Net.Sockets;
  7. using System.Runtime.InteropServices;
  8. namespace ET
  9. {
  /// <summary>
  /// One-byte packet type tags. <c>KService.Recv</c> reads the tag from the first byte of every datagram.
  /// </summary>
  public static class KcpProtocalType
  {
      /// <summary>Connection request; handled by the accept path of <c>KService.Recv</c>.</summary>
      public const byte SYN = 1;

      /// <summary>Reply to a connection request; carries both connection ids.</summary>
      public const byte ACK = 2;

      /// <summary>Disconnect notification; carries both connection ids and an error code.</summary>
      public const byte FIN = 3;

      /// <summary>Data packet that is handed to the matching channel.</summary>
      public const byte MSG = 4;

      // The router tags below are not referenced by KService itself;
      // presumably they are consumed by the router component (confirm there).
      public const byte RouterReconnect = 10;
      public const byte RouterAck = 11;
      public const byte RouterSYN = 12;
  }
  /// <summary>
  /// Kind of service, passed to the <c>KService</c> constructors.
  /// The names suggest external-facing vs. internal traffic (confirm against callers).
  /// </summary>
  public enum ServiceType
  {
      Outer = 0,
      Inner = 1,
  }
  25. public sealed class KService: AService
  26. {
  // Value of TimeHelper.ClientNow() captured when this KService was constructed.
  public long StartTime;

  // Time elapsed since StartTime, truncated to uint.
  // The original note calls this thread-safe; that depends on TimeHelper and is not verifiable from this file.
  public uint TimeNow => (uint) (TimeHelper.ClientNow() - this.StartTime);

  // UDP socket shared by every channel of this service; null once the service is disposed.
  private Socket socket;
  38. #region 回调方法
  // Registers the static output callback with the KCP wrapper once per process.
  static KService()
  {
      // Log callback registration is left disabled, as in the original:
      //Kcp.KcpSetLog(KcpLog);
      Kcp.KcpSetoutput(KService.KcpOutput);
  }
  // Scratch buffer for KcpLog. It is static and shared, so concurrent KcpLog calls would overwrite each other.
  private static readonly byte[] logBuffer = new byte[1024];

#if ENABLE_IL2CPP
  [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
#endif
  /// <summary>
  /// Log callback for the KCP wrapper: copies the text into managed memory and writes it to the info log.
  /// </summary>
  /// <param name="bytes">Pointer to the message bytes.</param>
  /// <param name="len">Number of bytes at <paramref name="bytes"/>.</param>
  /// <param name="kcp">Unused.</param>
  /// <param name="user">Unused.</param>
  private static void KcpLog(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
  {
      try
      {
          // A message longer than the scratch buffer used to make Marshal.Copy throw and the
          // text was lost; log the part that fits instead.
          int count = Math.Min(len, logBuffer.Length);
          Marshal.Copy(bytes, logBuffer, 0, count);
          Log.Info(logBuffer.ToStr(0, count));
      }
      catch (Exception e)
      {
          // Never let an exception escape into the caller of this callback.
          Log.Error(e);
      }
  }
  60. #if ENABLE_IL2CPP
  61. [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
  62. #endif
  63. private static int KcpOutput(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
  64. {
  65. try
  66. {
  67. KChannel kChannel = KChannel.kChannels[(uint) user];
  68. kChannel.Output(bytes, len);
  69. }
  70. catch (Exception e)
  71. {
  72. Log.Error(e);
  73. return len;
  74. }
  75. return len;
  76. }
  77. #endregion
  78. #region 主线程
  /// <summary>
  /// Creates a service whose UDP socket is bound to <paramref name="ipEndPoint"/>.
  /// Send and receive buffers are enlarged to 64 * Kcp.OneM, except on macOS.
  /// </summary>
  /// <exception cref="SocketException">The socket could not be configured or bound; the socket is closed before rethrowing.</exception>
  public KService(ThreadSynchronizationContext threadSynchronizationContext, IPEndPoint ipEndPoint, ServiceType serviceType)
  {
      this.ServiceType = serviceType;
      this.ThreadSynchronizationContext = threadSynchronizationContext;
      this.StartTime = TimeHelper.ClientNow();
      this.socket = CreateBoundSocket(ipEndPoint, true);
  }

  /// <summary>
  /// Creates a service bound to any local address on an OS-assigned port.
  /// Per the original note this is the client-side constructor, so socket buffer sizes are left at their defaults.
  /// </summary>
  /// <exception cref="SocketException">The socket could not be configured or bound; the socket is closed before rethrowing.</exception>
  public KService(ThreadSynchronizationContext threadSynchronizationContext, ServiceType serviceType)
  {
      this.ServiceType = serviceType;
      this.ThreadSynchronizationContext = threadSynchronizationContext;
      this.StartTime = TimeHelper.ClientNow();
      this.socket = CreateBoundSocket(new IPEndPoint(IPAddress.Any, 0), false);
  }

  // Creates, configures and binds the IPv4 UDP socket used by both constructors.
  // The socket is closed if any step fails, so a throwing constructor does not leak the handle.
  private static Socket CreateBoundSocket(IPEndPoint localEndPoint, bool enlargeBuffers)
  {
      Socket udpSocket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
      try
      {
          if (enlargeBuffers && !RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
          {
              udpSocket.SendBufferSize = Kcp.OneM * 64;
              udpSocket.ReceiveBufferSize = Kcp.OneM * 64;
          }

          udpSocket.Bind(localEndPoint);

          if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
          {
              // Switch off SIO_UDP_CONNRESET so that an ICMP "port unreachable" caused by one peer
              // is not surfaced as a connection-reset error on this shared socket.
              const uint IOC_IN = 0x80000000;
              const uint IOC_VENDOR = 0x18000000;
              const uint SIO_UDP_CONNRESET = IOC_IN | IOC_VENDOR | 12;
              // The control code does not fit in a signed int; unchecked keeps the bit pattern
              // regardless of the project's overflow-checking setting.
              udpSocket.IOControl(unchecked((int) SIO_UDP_CONNRESET), new[] { Convert.ToByte(false) }, null);
          }

          return udpSocket;
      }
      catch
      {
          udpSocket.Close();
          throw;
      }
  }
  /// <summary>
  /// Points the channel with the given id at a new remote address; does nothing if the channel is unknown.
  /// When NET_THREAD is defined the change is posted to the service's synchronization context.
  /// </summary>
  public void ChangeAddress(long id, IPEndPoint address)
  {
#if NET_THREAD
      this.ThreadSynchronizationContext.Post(() =>
      {
#endif
          KChannel target = this.Get(id);
          if (target != null)
          {
              Log.Info($"channel change address: {id} {address}");
              target.RemoteAddress = address;
          }
#if NET_THREAD
      });
#endif
  }
  133. #endregion
  134. #region 网络线程
  // Every live channel, keyed by channel id.
  private readonly Dictionary<long, KChannel> idChannels = new Dictionary<long, KChannel>();

  // The same channels, keyed by their local connection id.
  private readonly Dictionary<long, KChannel> localConnChannels = new Dictionary<long, KChannel>();

  // Accepted channels that are still connecting, keyed by the remote connection id.
  private readonly Dictionary<long, KChannel> waitConnectChannels = new Dictionary<long, KChannel>();

  // Scratch list reused by RemoveConnectTimeoutChannels.
  private readonly List<long> waitRemoveChannels = new List<long>();

  // Shared scratch buffer: Recv receives into it, and the SYN reply and Disconnect build their packets in it.
  private readonly byte[] cache = new byte[8192];

  // Sender of the datagram most recently read by Recv (passed by ref to ReceiveFrom).
  private EndPoint ipEndPoint = new IPEndPoint(IPAddress.Any, 0);

  // Used on the network thread (per the original note) to generate local connection ids.
  private readonly Random random = new Random(Guid.NewGuid().GetHashCode());

  // Channels to update on the next Update() pass.
  private readonly HashSet<long> updateChannels = new HashSet<long>();

  // Channels to update at a later time: due time -> channel ids.
  private readonly MultiMap<long, long> timeId = new MultiMap<long, long>();

  // Scratch list of due times collected by TimerOut.
  private readonly List<long> timeOutTime = new List<long>();

  // Cached smallest due time, so TimerOut does not have to read the first MultiMap entry on every call.
  private long minTime;
  // The service counts as disposed once its socket has been released (Dispose sets the field to null).
  public override bool IsDispose() => this.socket == null;
  /// <summary>
  /// Removes every channel, then closes the socket. Safe to call more than once.
  /// </summary>
  public override void Dispose()
  {
      // Already disposed: a second call used to dereference the null socket below.
      if (this.socket == null)
      {
          return;
      }

      // Snapshot the keys because Remove mutates idChannels.
      foreach (long channelId in this.idChannels.Keys.ToArray())
      {
          this.Remove(channelId);
      }

      this.socket.Close();
      this.socket = null;
  }
  // Returns an independent copy of the last sender's endpoint. The ipEndPoint field is
  // passed by ref to ReceiveFrom, so it must not be handed to a channel directly.
  private IPEndPoint CloneAddress()
  {
      IPEndPoint sender = (IPEndPoint) this.ipEndPoint;
      IPAddress address = sender.Address;
      int port = sender.Port;
      return new IPEndPoint(address, port);
  }
  /// <summary>
  /// Reads every datagram currently available on the socket and dispatches it by its first byte
  /// (see <see cref="KcpProtocalType"/>). Layout of all handled packets: byte 0 = type,
  /// bytes 1-4 = sender's connection id, bytes 5-8 = receiver's connection id.
  /// Errors while receiving or handling a datagram are logged and never propagate to the caller.
  /// </summary>
  private void Recv()
  {
      if (this.socket == null)
      {
          return;
      }

      // socket is re-checked on every pass because Dispose() sets it to null.
      while (socket != null && this.socket.Available > 0)
      {
          int messageLength;
          try
          {
              messageLength = this.socket.ReceiveFrom(this.cache, ref this.ipEndPoint);
          }
          catch (SocketException e)
          {
              // Any sender can trigger this, e.g. with a datagram larger than the cache buffer.
              // It must not escape Update().
              Log.Error($"kservice receive error: {e.SocketErrorCode}\n{e}");
              if (e.SocketErrorCode == SocketError.MessageSize)
              {
                  // The oversized datagram has been consumed; keep draining.
                  continue;
              }

              // Unknown socket failure: stop for this pass instead of risking a busy loop.
              break;
          }

          // Shorter than one byte: not a valid message.
          if (messageLength < 1)
          {
              continue;
          }

          byte flag = this.cache[0];

          // Per the original note, connection ids start at 100 and the values 1, 2, 3 mark special packets.
          uint remoteConn = 0;
          uint localConn = 0;
          try
          {
              KChannel kChannel = null;
              switch (flag)
              {
#if NOT_CLIENT
                  case KcpProtocalType.SYN: // accept
                  {
                      // A SYN needs at least the 9-byte header.
                      if (messageLength < 9)
                      {
                          break;
                      }

                      // Anything after the header is the sender's real address as text.
                      string realAddress = null;
                      if (messageLength > 9)
                      {
                          realAddress = this.cache.ToStr(9, messageLength - 9);
                      }

                      remoteConn = BitConverter.ToUInt32(this.cache, 1);
                      localConn = BitConverter.ToUInt32(this.cache, 5);
                      this.waitConnectChannels.TryGetValue(remoteConn, out kChannel);
                      if (kChannel == null)
                      {
                          localConn = CreateRandomLocalConn(this.random);
                          // localConn already in use: ignore this SYN and wait for the peer to send another.
                          if (this.localConnChannels.ContainsKey(localConn))
                          {
                              break;
                          }

                          long id = this.CreateAcceptChannelId(localConn);
                          if (this.idChannels.ContainsKey(id))
                          {
                              break;
                          }

                          // Parse the peer-supplied address BEFORE the channel is registered: if it
                          // is malformed the exception is logged by the outer catch and no
                          // half-registered channel is left behind.
                          IPEndPoint parsedRealAddress = realAddress == null? null : NetworkHelper.ToIPEndPoint(realAddress);

                          kChannel = new KChannel(id, localConn, remoteConn, this.socket, this.CloneAddress(), this);
                          this.idChannels.Add(kChannel.Id, kChannel);
                          this.waitConnectChannels.Add(kChannel.RemoteConn, kChannel); // removed once connected or timed out
                          this.localConnChannels.Add(kChannel.LocalConn, kChannel);

                          kChannel.RealAddress = realAddress;

                          IPEndPoint realEndPoint = parsedRealAddress ?? kChannel.RemoteAddress;
                          this.OnAccept(kChannel.Id, realEndPoint);
                      }

                      if (kChannel.RemoteConn != remoteConn)
                      {
                          break;
                      }

                      // A repeated SYN carrying a different real address than the first one is rejected.
                      if (kChannel.RealAddress != realAddress)
                      {
                          Log.Error($"kchannel syn address diff: {kChannel.Id} {kChannel.RealAddress} {realAddress}");
                          break;
                      }

                      try
                      {
                          // Build the ACK in place; the received bytes are no longer needed.
                          byte[] buffer = this.cache;
                          buffer.WriteTo(0, KcpProtocalType.ACK);
                          buffer.WriteTo(1, kChannel.LocalConn);
                          buffer.WriteTo(5, kChannel.RemoteConn);
                          Log.Info($"kservice syn: {kChannel.Id} {remoteConn} {localConn}");
                          this.socket.SendTo(buffer, 0, 9, SocketFlags.None, kChannel.RemoteAddress);
                      }
                      catch (Exception e)
                      {
                          Log.Error(e);
                          kChannel.OnError(ErrorCode.ERR_SocketCantSend);
                      }

                      break;
                  }
#endif
                  case KcpProtocalType.ACK: // reply to our connect
                      // An ACK is exactly 9 bytes.
                      if (messageLength != 9)
                      {
                          break;
                      }

                      remoteConn = BitConverter.ToUInt32(this.cache, 1);
                      localConn = BitConverter.ToUInt32(this.cache, 5);
                      kChannel = this.GetByLocalConn(localConn);
                      if (kChannel != null)
                      {
                          Log.Info($"kservice ack: {kChannel.Id} {remoteConn} {localConn}");
                          kChannel.RemoteConn = remoteConn;
                          kChannel.HandleConnnect();
                      }

                      break;
                  case KcpProtocalType.FIN: // disconnect
                      // A FIN is exactly 13 bytes: header plus a 4-byte error code.
                      if (messageLength != 13)
                      {
                          break;
                      }

                      remoteConn = BitConverter.ToUInt32(this.cache, 1);
                      localConn = BitConverter.ToUInt32(this.cache, 5);
                      int error = BitConverter.ToInt32(this.cache, 9);

                      kChannel = this.GetByLocalConn(localConn);
                      if (kChannel == null)
                      {
                          break;
                      }

                      // Verify remoteConn so a third party cannot close the channel.
                      if (kChannel.RemoteConn != remoteConn)
                      {
                          break;
                      }

                      // The peer's error code is only logged; the channel always gets ERR_PeerDisconnect.
                      Log.Info($"kservice recv fin: {kChannel.Id} {localConn} {remoteConn} {error}");
                      kChannel.OnError(ErrorCode.ERR_PeerDisconnect);
                      break;
                  case KcpProtocalType.MSG: // data
                      // A data packet needs at least the 9-byte header.
                      if (messageLength < 9)
                      {
                          break;
                      }

                      remoteConn = BitConverter.ToUInt32(this.cache, 1);
                      localConn = BitConverter.ToUInt32(this.cache, 5);
                      kChannel = this.GetByLocalConn(localConn);
                      if (kChannel == null)
                      {
                          // Unknown channel: tell the sender to disconnect.
                          this.Disconnect(localConn, remoteConn, ErrorCode.ERR_KcpNotFoundChannel, (IPEndPoint) this.ipEndPoint, 1);
                          break;
                      }

                      // Verify remoteConn so a third party cannot inject data.
                      if (kChannel.RemoteConn != remoteConn)
                      {
                          break;
                      }

                      // The channel receives everything from offset 5 onwards.
                      kChannel.HandleRecv(this.cache, 5, messageLength - 5);
                      break;
              }
          }
          catch (Exception e)
          {
              Log.Error($"kservice error: {flag} {remoteConn} {localConn}\n{e}");
          }
      }
  }
  // Looks up a channel by id; returns null when no such channel exists.
  private KChannel Get(long id)
  {
      return this.idChannels.TryGetValue(id, out KChannel found)? found : null;
  }
  // Looks up a channel by its local connection id; returns null when no such channel exists.
  private KChannel GetByLocalConn(uint localConn)
  {
      if (this.localConnChannels.TryGetValue(localConn, out KChannel found))
      {
          return found;
      }

      return null;
  }
  /// <summary>
  /// Ensures a channel with the given id exists, creating one for <paramref name="address"/> if needed.
  /// Creation failures are logged, not thrown.
  /// </summary>
  protected override void Get(long id, IPEndPoint address)
  {
      if (this.idChannels.ContainsKey(id))
      {
          return;
      }

      try
      {
          // The low 32 bits of the id are the local connection id.
          uint localConn = (uint)((ulong) id & uint.MaxValue);
          KChannel created = new KChannel(id, localConn, this.socket, address, this);
          this.idChannels.Add(id, created);
          this.localConnChannels.Add(created.LocalConn, created);
      }
      catch (Exception ex)
      {
          Log.Error($"kservice get error: {id}\n{ex}");
      }
  }
  /// <summary>
  /// Unregisters the channel with the given id from all lookup tables and disposes it.
  /// Does nothing if the id is unknown.
  /// </summary>
  public override void Remove(long id)
  {
      KChannel kChannel;
      if (!this.idChannels.TryGetValue(id, out kChannel))
      {
          return;
      }

      Log.Info($"kservice remove channel: {id} {kChannel.LocalConn} {kChannel.RemoteConn}");
      this.idChannels.Remove(id);
      this.localConnChannels.Remove(kChannel.LocalConn);

      // Drop the pending-connect entry only when it belongs to this channel; another channel
      // may be waiting under the same remote connection id.
      bool pendingIsThisChannel =
              this.waitConnectChannels.TryGetValue(kChannel.RemoteConn, out KChannel waitChannel)
              && waitChannel.LocalConn == kChannel.LocalConn;
      if (pendingIsThisChannel)
      {
          this.waitConnectChannels.Remove(kChannel.RemoteConn);
      }

      kChannel.Dispose();
  }
  /// <summary>
  /// Sends a 13-byte FIN packet (type, localConn, remoteConn, error) to <paramref name="address"/>
  /// <paramref name="times"/> times. Send failures are logged, not thrown. Nothing is sent or
  /// logged when the service has already been disposed.
  /// </summary>
  private void Disconnect(uint localConn, uint remoteConn, int error, IPEndPoint address, int times)
  {
      try
      {
          if (this.socket == null)
          {
              return;
          }

          // The packet is assembled in the shared scratch buffer.
          byte[] packet = this.cache;
          packet.WriteTo(0, KcpProtocalType.FIN);
          packet.WriteTo(1, localConn);
          packet.WriteTo(5, remoteConn);
          packet.WriteTo(9, (uint) error);

          int remaining = times;
          while (remaining > 0)
          {
              this.socket.SendTo(packet, 0, 13, SocketFlags.None, address);
              --remaining;
          }
      }
      catch (Exception ex)
      {
          Log.Error($"Disconnect error {localConn} {remoteConn} {error} {address} {ex}");
      }

      // Written even when the send above failed.
      Log.Info($"channel send fin: {localConn} {remoteConn} {address} {error}");
  }
  // Forwards the message to the channel with the given id; silently ignored when the channel is unknown.
  protected override void Send(long channelId, long actorId, MemoryStream stream)
  {
      this.Get(channelId)?.Send(actorId, stream);
  }
  /// <summary>
  /// Schedules a channel update. A <paramref name="time"/> of 0 means "on the next Update() pass";
  /// any other value is a due time on the same clock as <see cref="TimeNow"/>, checked by TimerOut.
  /// </summary>
  public void AddToUpdateNextTime(long time, long id)
  {
      if (time != 0)
      {
          // Keep the cached minimum in step so TimerOut can bail out early.
          this.minTime = Math.Min(this.minTime, time);
          this.timeId.Add(time, id);
          return;
      }

      this.updateChannels.Add(id);
  }
  /// <summary>
  /// One service tick: receive pending datagrams, promote due timers, update the scheduled
  /// channels, then prune the pending-connect table.
  /// </summary>
  public override void Update()
  {
      this.Recv();
      this.TimerOut();

      foreach (long id in updateChannels)
      {
          KChannel kChannel = this.Get(id);
          // Skip channels that were removed meanwhile or whose Id is 0.
          if (kChannel != null && kChannel.Id != 0)
          {
              kChannel.Update();
          }
      }

      this.updateChannels.Clear();
      this.RemoveConnectTimeoutChannels();
  }
  /// <summary>
  /// Drops entries from waitConnectChannels whose channel has connected or has been pending for
  /// more than 10 * 1000 time units (10 seconds per the original note). Only the pending-connect
  /// entry is removed; the channel itself is untouched.
  /// </summary>
  private void RemoveConnectTimeoutChannels()
  {
      this.waitRemoveChannels.Clear();

      foreach (KeyValuePair<long, KChannel> pending in this.waitConnectChannels)
      {
          // The key is the one the channel was registered under in waitConnectChannels.
          long channelId = pending.Key;
          KChannel kChannel = pending.Value;
          if (kChannel == null)
          {
              Log.Error($"RemoveConnectTimeoutChannels not found kchannel: {channelId}");
              continue;
          }

          // Connected: no longer pending, remove right away.
          if (kChannel.IsConnected)
          {
              this.waitRemoveChannels.Add(channelId);
          }

          // Connect timeout of 10 seconds.
          if (this.TimeNow > kChannel.CreateTime + 10 * 1000)
          {
              this.waitRemoveChannels.Add(channelId);
          }
      }

      // Removal is deferred so the dictionary is not modified while it is being enumerated.
      foreach (long channelId in this.waitRemoveChannels)
      {
          this.waitConnectChannels.Remove(channelId);
      }
  }
  /// <summary>
  /// Moves every channel whose scheduled time has been reached from timeId into updateChannels.
  /// </summary>
  private void TimerOut()
  {
      if (this.timeId.Count == 0)
      {
          return;
      }

      uint now = this.TimeNow;
      if (now < this.minTime)
      {
          // Nothing is due yet.
          return;
      }

      this.timeOutTime.Clear();

      // Stops at the first key in the future, which becomes the new minimum
      // (this relies on timeId enumerating in ascending key order — confirm in MultiMap).
      foreach (KeyValuePair<long, List<long>> entry in this.timeId)
      {
          long dueTime = entry.Key;
          if (dueTime > now)
          {
              this.minTime = dueTime;
              break;
          }

          this.timeOutTime.Add(dueTime);
      }

      // Second pass, so timeId is not modified while it is being enumerated.
      foreach (long dueTime in this.timeOutTime)
      {
          this.updateChannels.UnionWith(this.timeId[dueTime]);
          this.timeId.Remove(dueTime);
      }
  }
  502. #endregion
  503. }
  504. }