KService.cs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5. using System.Net;
  6. using System.Net.Sockets;
  7. using System.Runtime.InteropServices;
  8. namespace ET
  9. {
  /// <summary>
  /// First-byte flags that identify the kind of UDP datagram exchanged by <c>KService</c>.
  /// </summary>
  public static class KcpProtocalType
  {
      // --- connection handshake / teardown / payload ---

      /// <summary>Connection request.</summary>
      public const byte SYN = 1;

      /// <summary>Reply to a connection request.</summary>
      public const byte ACK = 2;

      /// <summary>Disconnect notification.</summary>
      public const byte FIN = 3;

      /// <summary>KCP payload datagram.</summary>
      public const byte MSG = 4;

      // --- router-related flags (not handled inside KService itself) ---

      public const byte RouterReconnect = 10;

      public const byte RouterAck = 11;

      public const byte RouterSYN = 12;
  }
  /// <summary>
  /// Which side of the network a <c>KService</c> instance serves.
  /// </summary>
  public enum ServiceType
  {
      /// <summary>Outer network service.</summary>
      Outer = 0,

      /// <summary>Inner network service.</summary>
      Inner = 1,
  }
  /// <summary>
  /// UDP/KCP network service. It owns a single UDP socket, demultiplexes incoming datagrams by their
  /// first-byte flag (see <see cref="KcpProtocalType"/>) and routes them to <c>KChannel</c> instances.
  /// It also schedules per-channel KCP updates.
  /// </summary>
  public sealed class KService: AService
  {
      // Creation time of this KService, taken from TimeHelper.ClientNow().
      public long StartTime;

      // Current time minus StartTime, truncated to uint. Presumably milliseconds: the connect timeout in
      // RemoveConnectTimeoutChannels treats 10 * 1000 as ten seconds.
      // NOTE(review): the original comment called this thread-safe; it only reads StartTime and the clock.
      public uint TimeNow => (uint) (TimeHelper.ClientNow() - this.StartTime);

      private Socket socket;

      #region Callbacks

      static KService()
      {
          // Native KCP logging is disabled; uncomment to route KCP log output through KcpLog.
          //Kcp.KcpSetLog(KcpLog);
          Kcp.KcpSetoutput(KcpOutput);
      }

      // Scratch buffer for KcpLog; longer native messages are truncated to its length.
      private static readonly byte[] logBuffer = new byte[1024];

  #if ENABLE_IL2CPP
      // NOTE(review): this names the KcpOutput delegate type although KcpLog returns void;
      // confirm whether a dedicated log delegate type should be referenced here.
      [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
  #endif
      /// <summary>Native KCP log callback: copies the message into logBuffer and writes it with Log.Info.</summary>
      private static void KcpLog(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
      {
          try
          {
              // Clamp to the scratch buffer so an oversized message is truncated instead of throwing.
              int count = Math.Min(len, logBuffer.Length);
              Marshal.Copy(bytes, logBuffer, 0, count);
              Log.Info(logBuffer.ToStr(0, count));
          }
          catch (Exception e)
          {
              Log.Error(e);
          }
      }

  #if ENABLE_IL2CPP
      [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
  #endif
      /// <summary>
      /// Native KCP output callback. Looks up the channel keyed by <paramref name="user"/> in
      /// KChannel.kChannels and hands it the outgoing bytes. Errors are logged, and
      /// <paramref name="len"/> is returned in every case.
      /// </summary>
      private static int KcpOutput(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
      {
          try
          {
              KChannel kChannel = KChannel.kChannels[(uint) user];
              kChannel.Output(bytes, len);
          }
          catch (Exception e)
          {
              Log.Error(e);
          }

          return len;
      }

      #endregion

      #region Main thread

      /// <summary>
      /// Creates a service bound to <paramref name="ipEndPoint"/>, with 64 MB socket send/receive buffers.
      /// </summary>
      public KService(ThreadSynchronizationContext threadSynchronizationContext, IPEndPoint ipEndPoint, ServiceType serviceType)
      {
          this.ServiceType = serviceType;
          this.ThreadSynchronizationContext = threadSynchronizationContext;
          this.StartTime = TimeHelper.ClientNow();
          this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
          this.socket.SendBufferSize = Kcp.OneM * 64;
          this.socket.ReceiveBufferSize = Kcp.OneM * 64;
          this.socket.Bind(ipEndPoint);
          DisableUdpConnectionReset(this.socket);
      }

      /// <summary>
      /// Creates a service bound to an ephemeral local port (IPAddress.Any, port 0).
      /// </summary>
      public KService(ThreadSynchronizationContext threadSynchronizationContext, ServiceType serviceType)
      {
          this.ServiceType = serviceType;
          this.ThreadSynchronizationContext = threadSynchronizationContext;
          this.StartTime = TimeHelper.ClientNow();
          this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
          // As a client there is no need to change the send/receive buffer sizes.
          this.socket.Bind(new IPEndPoint(IPAddress.Any, 0));
          DisableUdpConnectionReset(this.socket);
      }

      /// <summary>
      /// On Windows, issues SIO_UDP_CONNRESET with a false argument on <paramref name="socket"/>.
      /// Without it, an ICMP "port unreachable" reply can make later UDP receives fail with a
      /// connection-reset error. This does nothing on other platforms.
      /// </summary>
      private static void DisableUdpConnectionReset(Socket socket)
      {
          if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
          {
              return;
          }

          const uint IOC_IN = 0x80000000;
          const uint IOC_VENDOR = 0x18000000;
          const uint SIO_UDP_CONNRESET = IOC_IN | IOC_VENDOR | 12;
          socket.IOControl(unchecked((int) SIO_UDP_CONNRESET), new[] { Convert.ToByte(false) }, null);
      }

      /// <summary>
      /// Sets the remote address of channel <paramref name="id"/>. With NET_THREAD defined, the change is
      /// posted to ThreadSynchronizationContext. Unknown ids are ignored.
      /// </summary>
      public void ChangeAddress(long id, IPEndPoint address)
      {
  #if NET_THREAD
          this.ThreadSynchronizationContext.Post(() =>
          {
  #endif
              KChannel kChannel = this.Get(id);
              if (kChannel == null)
              {
                  return;
              }

              Log.Info($"channel change address: {id} {address}");
              kChannel.RemoteAddress = address;
  #if NET_THREAD
          }
          );
  #endif
      }

      #endregion

      #region Network thread

      // All channels by channel id.
      private readonly Dictionary<long, KChannel> idChannels = new Dictionary<long, KChannel>();

      // All channels by their local conn.
      private readonly Dictionary<long, KChannel> localConnChannels = new Dictionary<long, KChannel>();

      // Accepted channels still completing the handshake, keyed by remote conn.
      private readonly Dictionary<long, KChannel> waitConnectChannels = new Dictionary<long, KChannel>();

      // Scratch list for RemoveConnectTimeoutChannels.
      private readonly List<long> waitRemoveChannels = new List<long>();

      // Shared receive/send scratch buffer (8 KB).
      private readonly byte[] cache = new byte[8192];

      // Sender of the most recently received datagram (filled by ReceiveFrom).
      private EndPoint ipEndPoint = new IPEndPoint(IPAddress.Any, 0);

      // Used on the network thread only.
      private readonly Random random = new Random(Guid.NewGuid().GetHashCode());

      // Channels to update on the next frame.
      private readonly HashSet<long> updateChannels = new HashSet<long>();

      // Snapshot of updateChannels taken at the start of each Update pass.
      private readonly List<long> updateSnapshot = new List<long>();

      // Channels to update once a given time is reached (time -> channel ids).
      private readonly MultiMap<long, long> timeId = new MultiMap<long, long>();

      // Scratch list of expired keys for TimerOut.
      private readonly List<long> timeOutTime = new List<long>();

      // Lower bound on the earliest key in timeId, so TimerOut does not have to read the MultiMap every frame.
      private long minTime;

      public override bool IsDispose()
      {
          return this.socket == null;
      }

      /// <summary>
      /// Removes every channel, then closes the socket. Calling it again after disposal does nothing.
      /// </summary>
      public override void Dispose()
      {
          // A second Dispose would otherwise dereference the already-null socket.
          if (this.IsDispose())
          {
              return;
          }

          foreach (long channelId in this.idChannels.Keys.ToArray())
          {
              this.Remove(channelId);
          }

          this.socket.Close();
          this.socket = null;
      }

      // Copies ipEndPoint, which ReceiveFrom overwrites on every call.
      private IPEndPoint CloneAddress()
      {
          IPEndPoint ip = (IPEndPoint) this.ipEndPoint;
          return new IPEndPoint(ip.Address, ip.Port);
      }

      /// <summary>
      /// Reads every datagram currently available on the socket and dispatches it by its first-byte flag.
      /// </summary>
      private void Recv()
      {
          if (this.socket == null)
          {
              return;
          }

          while (this.socket != null && this.socket.Available > 0)
          {
              int messageLength;
              try
              {
                  messageLength = this.socket.ReceiveFrom(this.cache, ref this.ipEndPoint);
              }
              catch (SocketException e)
              {
                  // One bad datagram (for example one larger than the receive cache) must not abort the
                  // whole network update. Stop reading for this frame and retry on the next one.
                  Log.Error($"kservice recv error: {e}");
                  break;
              }

              // An empty datagram has no flag byte.
              if (messageLength < 1)
              {
                  continue;
              }

              byte flag = this.cache[0];

              // Declared outside the try so the error log below can report them.
              uint remoteConn = 0;
              uint localConn = 0;
              try
              {
                  KChannel kChannel = null;
                  switch (flag)
                  {
  #if NOT_CLIENT
                      case KcpProtocalType.SYN: // accept
                      {
                          // Layout: flag(1) remoteConn(4) localConn(4) [realAddress]. Less than 9 bytes is not a SYN.
                          if (messageLength < 9)
                          {
                              break;
                          }

                          remoteConn = BitConverter.ToUInt32(this.cache, 1);
                          localConn = BitConverter.ToUInt32(this.cache, 5);

                          string realAddress = null;
                          if (messageLength > 9)
                          {
                              realAddress = this.cache.ToStr(9, messageLength - 9);
                          }

                          this.waitConnectChannels.TryGetValue(remoteConn, out kChannel);
                          if (kChannel == null)
                          {
                              // realAddress comes from the datagram. Parse it before registering anything so a
                              // malformed value throws (caught below) without leaving a half-registered channel.
                              IPEndPoint parsedRealEndPoint = realAddress == null? null : NetworkHelper.ToIPEndPoint(realAddress);

                              localConn = CreateRandomLocalConn(this.random);
                              // The random localConn is already in use: drop this SYN and wait for the next one.
                              if (this.localConnChannels.ContainsKey(localConn))
                              {
                                  break;
                              }

                              long id = this.CreateAcceptChannelId(localConn);
                              if (this.idChannels.ContainsKey(id))
                              {
                                  break;
                              }

                              kChannel = new KChannel(id, localConn, remoteConn, this.socket, this.CloneAddress(), this);
                              this.idChannels.Add(kChannel.Id, kChannel);
                              // Removed by RemoveConnectTimeoutChannels once connected or timed out.
                              this.waitConnectChannels.Add(kChannel.RemoteConn, kChannel);
                              this.localConnChannels.Add(kChannel.LocalConn, kChannel);

                              kChannel.RealAddress = realAddress;
                              IPEndPoint realEndPoint = realAddress == null? kChannel.RemoteAddress : parsedRealEndPoint;
                              this.OnAccept(kChannel.Id, realEndPoint);
                          }

                          if (kChannel.RemoteConn != remoteConn)
                          {
                              break;
                          }

                          // Skip a repeated SYN whose real address differs from the first one.
                          if (kChannel.RealAddress != realAddress)
                          {
                              Log.Error($"kchannel syn address diff: {kChannel.Id} {kChannel.RealAddress} {realAddress}");
                              break;
                          }

                          try
                          {
                              byte[] buffer = this.cache;
                              buffer.WriteTo(0, KcpProtocalType.ACK);
                              buffer.WriteTo(1, kChannel.LocalConn);
                              buffer.WriteTo(5, kChannel.RemoteConn);
                              Log.Info($"kservice syn: {kChannel.Id} {remoteConn} {localConn}");
                              this.socket.SendTo(buffer, 0, 9, SocketFlags.None, kChannel.RemoteAddress);
                          }
                          catch (Exception e)
                          {
                              Log.Error(e);
                              kChannel.OnError(ErrorCode.ERR_SocketCantSend);
                          }

                          break;
                      }
  #endif
                      case KcpProtocalType.ACK: // reply to our connect
                          // Layout: flag(1) remoteConn(4) localConn(4). Any other length is not an ACK.
                          if (messageLength != 9)
                          {
                              break;
                          }

                          remoteConn = BitConverter.ToUInt32(this.cache, 1);
                          localConn = BitConverter.ToUInt32(this.cache, 5);
                          kChannel = this.GetByLocalConn(localConn);
                          if (kChannel != null)
                          {
                              Log.Info($"kservice ack: {kChannel.Id} {remoteConn} {localConn}");
                              kChannel.RemoteConn = remoteConn;
                              kChannel.HandleConnnect();
                          }

                          break;
                      case KcpProtocalType.FIN: // disconnect
                          // Layout: flag(1) remoteConn(4) localConn(4) error(4). Any other length is not a FIN.
                          if (messageLength != 13)
                          {
                              break;
                          }

                          remoteConn = BitConverter.ToUInt32(this.cache, 1);
                          localConn = BitConverter.ToUInt32(this.cache, 5);
                          int error = BitConverter.ToInt32(this.cache, 9);

                          kChannel = this.GetByLocalConn(localConn);
                          if (kChannel == null)
                          {
                              break;
                          }

                          // Check remoteConn so a third party cannot close the channel.
                          if (kChannel.RemoteConn != remoteConn)
                          {
                              break;
                          }

                          Log.Info($"kservice recv fin: {kChannel.Id} {localConn} {remoteConn} {error}");
                          kChannel.OnError(ErrorCode.ERR_PeerDisconnect);
                          break;
                      case KcpProtocalType.MSG: // message payload
                          // Layout: flag(1) remoteConn(4) localConn(4) ... Less than 9 bytes is not a MSG.
                          if (messageLength < 9)
                          {
                              break;
                          }

                          remoteConn = BitConverter.ToUInt32(this.cache, 1);
                          localConn = BitConverter.ToUInt32(this.cache, 5);
                          kChannel = this.GetByLocalConn(localConn);
                          if (kChannel == null)
                          {
                              // Unknown channel: tell the peer to disconnect.
                              this.Disconnect(localConn, remoteConn, ErrorCode.ERR_KcpNotFoundChannel, (IPEndPoint) this.ipEndPoint, 1);
                              break;
                          }

                          // Check remoteConn so a third party cannot inject data.
                          if (kChannel.RemoteConn != remoteConn)
                          {
                              break;
                          }

                          // Everything from offset 5 (localConn onward) goes to the channel unchanged.
                          kChannel.HandleRecv(this.cache, 5, messageLength - 5);
                          break;
                  }
              }
              catch (Exception e)
              {
                  Log.Error($"kservice error: {flag} {remoteConn} {localConn}\n{e}");
              }
          }
      }

      private KChannel Get(long id)
      {
          this.idChannels.TryGetValue(id, out KChannel channel);
          return channel;
      }

      private KChannel GetByLocalConn(uint localConn)
      {
          this.localConnChannels.TryGetValue(localConn, out KChannel channel);
          return channel;
      }

      /// <summary>
      /// Creates an outgoing channel <paramref name="id"/> towards <paramref name="address"/> unless one
      /// already exists. The low 32 bits of the id are used as the channel's localConn.
      /// </summary>
      protected override void Get(long id, IPEndPoint address)
      {
          if (this.idChannels.ContainsKey(id))
          {
              return;
          }

          // The low 32 bits are the localConn.
          uint localConn = (uint) ((ulong) id & uint.MaxValue);

          // Check the collision before creating the channel. Otherwise it would reach idChannels but not
          // localConnChannels, leaving the two maps inconsistent.
          if (this.localConnChannels.ContainsKey(localConn))
          {
              Log.Error($"kservice get error, localConn already in use: {id} {localConn}");
              return;
          }

          try
          {
              KChannel channel = new KChannel(id, localConn, this.socket, address, this);
              this.idChannels.Add(id, channel);
              this.localConnChannels.Add(channel.LocalConn, channel);
          }
          catch (Exception e)
          {
              Log.Error($"kservice get error: {id}\n{e}");
          }
      }

      /// <summary>
      /// Removes channel <paramref name="id"/> from every lookup table and disposes it. Unknown ids are ignored.
      /// </summary>
      public override void Remove(long id)
      {
          if (!this.idChannels.TryGetValue(id, out KChannel kChannel))
          {
              return;
          }

          Log.Info($"kservice remove channel: {id} {kChannel.LocalConn} {kChannel.RemoteConn}");
          this.idChannels.Remove(id);
          this.localConnChannels.Remove(kChannel.LocalConn);

          // Only remove the wait entry if it belongs to this very channel.
          if (this.waitConnectChannels.TryGetValue(kChannel.RemoteConn, out KChannel waitChannel) && waitChannel.LocalConn == kChannel.LocalConn)
          {
              this.waitConnectChannels.Remove(kChannel.RemoteConn);
          }

          kChannel.Dispose();
      }

      /// <summary>
      /// Sends a 13-byte FIN (flag, localConn, remoteConn, error) to <paramref name="address"/>
      /// <paramref name="times"/> times. Send errors are logged, not thrown.
      /// </summary>
      private void Disconnect(uint localConn, uint remoteConn, int error, IPEndPoint address, int times)
      {
          try
          {
              if (this.socket == null)
              {
                  return;
              }

              byte[] buffer = this.cache;
              buffer.WriteTo(0, KcpProtocalType.FIN);
              buffer.WriteTo(1, localConn);
              buffer.WriteTo(5, remoteConn);
              buffer.WriteTo(9, (uint) error);
              for (int i = 0; i < times; ++i)
              {
                  this.socket.SendTo(buffer, 0, 13, SocketFlags.None, address);
              }
          }
          catch (Exception e)
          {
              Log.Error($"Disconnect error {localConn} {remoteConn} {error} {address} {e}");
          }

          Log.Info($"channel send fin: {localConn} {remoteConn} {address} {error}");
      }

      protected override void Send(long channelId, long actorId, MemoryStream stream)
      {
          KChannel channel = this.Get(channelId);
          if (channel == null)
          {
              return;
          }

          channel.Send(actorId, stream);
      }

      /// <summary>
      /// Schedules channel <paramref name="id"/> for update: next frame when <paramref name="time"/> is 0,
      /// otherwise once TimeNow reaches <paramref name="time"/>.
      /// </summary>
      public void AddToUpdateNextTime(long time, long id)
      {
          if (time == 0)
          {
              this.updateChannels.Add(id);
              return;
          }

          if (time < this.minTime)
          {
              this.minTime = time;
          }

          this.timeId.Add(time, id);
      }

      /// <summary>
      /// One network tick: receive datagrams, release due timers, update scheduled channels,
      /// and expire pending handshakes.
      /// </summary>
      public override void Update()
      {
          this.Recv();
          this.TimerOut();

          // Work on a snapshot. KChannel.Update may schedule more updates through AddToUpdateNextTime,
          // and changing updateChannels while enumerating it would throw.
          this.updateSnapshot.Clear();
          this.updateSnapshot.AddRange(this.updateChannels);
          this.updateChannels.Clear();

          foreach (long id in this.updateSnapshot)
          {
              KChannel kChannel = this.Get(id);
              if (kChannel == null)
              {
                  continue;
              }

              if (kChannel.Id == 0)
              {
                  continue;
              }

              kChannel.Update();
          }

          this.RemoveConnectTimeoutChannels();
      }

      /// <summary>
      /// Removes entries from waitConnectChannels once their channel has connected or 10 * 1000 TimeNow
      /// units have passed since its CreateTime. The channels stay in idChannels.
      /// </summary>
      private void RemoveConnectTimeoutChannels()
      {
          this.waitRemoveChannels.Clear();
          foreach (KeyValuePair<long, KChannel> kv in this.waitConnectChannels)
          {
              KChannel kChannel = kv.Value;
              if (kChannel == null)
              {
                  Log.Error($"RemoveConnectTimeoutChannels not found kchannel: {kv.Key}");
                  continue;
              }

              // Connected channels leave at once; the others time out after 10 seconds.
              if (kChannel.IsConnected || this.TimeNow > kChannel.CreateTime + 10 * 1000)
              {
                  this.waitRemoveChannels.Add(kv.Key);
              }
          }

          foreach (long channelId in this.waitRemoveChannels)
          {
              this.waitConnectChannels.Remove(channelId);
          }
      }

      /// <summary>
      /// Moves every channel whose scheduled time has passed into updateChannels, and sets minTime to the
      /// first key still in the future.
      /// </summary>
      private void TimerOut()
      {
          if (this.timeId.Count == 0)
          {
              return;
          }

          uint timeNow = this.TimeNow;
          if (timeNow < this.minTime)
          {
              return;
          }

          this.timeOutTime.Clear();
          // Assumes MultiMap enumerates keys in ascending order, so the first future key ends the scan.
          foreach (KeyValuePair<long, List<long>> kv in this.timeId)
          {
              long k = kv.Key;
              if (k > timeNow)
              {
                  this.minTime = k;
                  break;
              }

              this.timeOutTime.Add(k);
          }

          foreach (long k in this.timeOutTime)
          {
              foreach (long v in this.timeId[k])
              {
                  this.updateChannels.Add(v);
              }

              this.timeId.Remove(k);
          }
      }

      #endregion
  }
  501. }