KService.cs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587
  1. using System;
  2. using System.Collections.Concurrent;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Net;
  7. using System.Net.Sockets;
  8. using System.Runtime.InteropServices;
  9. namespace ET
  10. {
  /// <summary>
  /// Packet-type tags: the first byte of every datagram handled by <c>KService.Recv</c>.
  /// </summary>
  public static class KcpProtocalType
  {
      /// <summary>Connection request; the accepting side answers it with <see cref="ACK"/>.</summary>
      public const byte SYN = 0x01;

      /// <summary>Reply to a <see cref="SYN"/>, carrying the sender's and receiver's conn ids.</summary>
      public const byte ACK = 0x02;

      /// <summary>Disconnect notice, carrying both conn ids and an error code.</summary>
      public const byte FIN = 0x03;

      /// <summary>Data packet whose payload is handed to the channel's receive path.</summary>
      public const byte MSG = 0x04;
  }
  /// <summary>
  /// Which network a KService is serving.
  /// NOTE(review): presumably Outer = client-facing traffic and Inner = server-to-server
  /// traffic — confirm against the callers that construct KService.
  /// </summary>
  public enum ServiceType
  {
      Outer = 0,
      Inner = 1,
  }
  /// <summary>
  /// KCP-over-UDP service: owns a single UDP socket and multiplexes KChannels over it.
  /// Every datagram starts with a KcpProtocalType byte (SYN / ACK / FIN / MSG) that
  /// <see cref="Recv"/> dispatches on. Channels are indexed by channel id, by local conn,
  /// and (for accepted channels whose handshake is still pending) by the peer's conn.
  /// </summary>
  public sealed class KService: AService
  {
      // Value of TimeHelper.ClientNow() when this KService was constructed.
      private readonly long startTime;

      /// <summary>
      /// Time elapsed since construction: TimeHelper.ClientNow() - startTime, truncated to uint.
      /// It only reads the readonly startTime, so it is as thread-safe as TimeHelper.ClientNow().
      /// </summary>
      /// <remarks>
      /// NOTE(review): the unit is presumably milliseconds (RemoveConnectTimeoutChannels treats
      /// 10 * 1000 as ten seconds). If so, the uint cast wraps after ~49.7 days of uptime while
      /// keys already stored in timeId stay large — confirm whether long-running servers care.
      /// </remarks>
      public uint TimeNow
      {
          get
          {
              return (uint) (TimeHelper.ClientNow() - this.startTime);
          }
      }

      private Socket socket;

      #region Callbacks

      // Runs once before KService is first used: installs the native KCP output callback.
      static KService()
      {
          // Native KCP logging is disabled; uncomment to route it through KcpLog.
          //Kcp.KcpSetLog(KcpLog);
          Kcp.KcpSetoutput(KcpOutput);
      }

      // Scratch buffer used by KcpLog; shared by every KService instance.
      private static readonly byte[] logBuffer = new byte[1024];

      /// <summary>
      /// Native KCP log callback (only active if Kcp.KcpSetLog(KcpLog) is enabled in the static
      /// constructor). Copies the native text into logBuffer and writes it with Log.Info.
      /// Lines longer than logBuffer are truncated rather than dropped.
      /// </summary>
  #if ENABLE_IL2CPP
      // NOTE(review): this names the KcpOutput delegate although KcpLog returns void; if a
      // KcpLog delegate type exists, typeof(KcpLog) is presumably what IL2CPP needs — confirm.
      [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
  #endif
      private static void KcpLog(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
      {
          try
          {
              // Clamp: a line longer than logBuffer used to make Marshal.Copy throw and the
              // whole message was lost.
              int count = Math.Min(len, logBuffer.Length);
              Marshal.Copy(bytes, logBuffer, 0, count);
              Log.Info(logBuffer.ToStr(0, count));
          }
          catch (Exception e)
          {
              Log.Error(e);
          }
      }

      /// <summary>
      /// Native KCP output callback: forwards the bytes KCP wants to send to the KChannel
      /// registered for <paramref name="kcp"/> in KChannel.KcpPtrChannels.
      /// </summary>
      /// <returns>
      /// 0 when <paramref name="kcp"/> is null or has no registered channel; otherwise
      /// <paramref name="len"/>, including when forwarding threw (the exception is logged).
      /// </returns>
  #if ENABLE_IL2CPP
      [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
  #endif
      private static int KcpOutput(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
      {
          try
          {
              if (kcp == IntPtr.Zero)
              {
                  return 0;
              }
              if (!KChannel.KcpPtrChannels.TryGetValue(kcp, out KChannel kChannel))
              {
                  return 0;
              }
              kChannel.Output(bytes, len);
          }
          catch (Exception e)
          {
              Log.Error(e);
          }
          return len;
      }

      #endregion

      /// <summary>
      /// Creates a service bound to <paramref name="ipEndPoint"/> (presumably the listening /
      /// accepting side). Send and receive buffers are enlarged except on macOS.
      /// </summary>
      public KService(ThreadSynchronizationContext threadSynchronizationContext, IPEndPoint ipEndPoint, ServiceType serviceType)
      {
          this.ServiceType = serviceType;
          this.ThreadSynchronizationContext = threadSynchronizationContext;
          this.startTime = TimeHelper.ClientNow();
          this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
          // NOTE(review): macOS keeps the OS default buffer sizes — presumably because such large
          // values are rejected there; confirm.
          if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
          {
              this.socket.SendBufferSize = Kcp.OneM * 64;
              this.socket.ReceiveBufferSize = Kcp.OneM * 64;
          }
          this.socket.Bind(ipEndPoint);
          DisableUdpConnReset(this.socket);
      }

      /// <summary>
      /// Creates a client-side service bound to an ephemeral local port on any address.
      /// As a client it keeps the default send/receive buffer sizes.
      /// </summary>
      public KService(ThreadSynchronizationContext threadSynchronizationContext, ServiceType serviceType)
      {
          this.ServiceType = serviceType;
          this.ThreadSynchronizationContext = threadSynchronizationContext;
          this.startTime = TimeHelper.ClientNow();
          this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
          this.socket.Bind(new IPEndPoint(IPAddress.Any, 0));
          DisableUdpConnReset(this.socket);
      }

      // On Windows, sets SIO_UDP_CONNRESET to false so an ICMP "port unreachable" caused by an
      // earlier SendTo does not surface as a connection-reset error on the next receive.
      // No-op on other platforms.
      private static void DisableUdpConnReset(Socket udpSocket)
      {
          if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
          {
              return;
          }
          const uint IOC_IN = 0x80000000;
          const uint IOC_VENDOR = 0x18000000;
          uint SIO_UDP_CONNRESET = IOC_IN | IOC_VENDOR | 12;
          udpSocket.IOControl((int) SIO_UDP_CONNRESET, new[] { Convert.ToByte(false) }, null);
      }

      /// <summary>
      /// Points channel <paramref name="id"/> at a new remote address. Does nothing if the
      /// channel does not exist.
      /// </summary>
      public void ChangeAddress(long id, IPEndPoint address)
      {
          KChannel kChannel = this.Get(id);
          if (kChannel == null)
          {
              return;
          }
          Log.Info($"channel change address: {id} {address}");
          kChannel.RemoteAddress = address;
      }

      // Every channel, keyed by channel id.
      private readonly Dictionary<long, KChannel> idChannels = new Dictionary<long, KChannel>();
      // Every channel, keyed by its LocalConn (the conn the peer puts in ACK/FIN/MSG packets).
      private readonly Dictionary<long, KChannel> localConnChannels = new Dictionary<long, KChannel>();
      // Accepted channels whose handshake is pending, keyed by the peer's conn (RemoteConn);
      // RemoveConnectTimeoutChannels drops them once connected or after the connect timeout.
      private readonly Dictionary<long, KChannel> waitConnectChannels = new Dictionary<long, KChannel>();
      // Receive buffer; also reused to build outgoing ACK and FIN packets.
      private readonly byte[] cache = new byte[8192];
      // Sender of the most recent datagram (ReceiveFrom overwrites it on every receive).
      private EndPoint ipEndPoint = new IPEndPoint(IPAddress.Any, 0);
      // Channels to update on the next Update().
      private readonly HashSet<long> updateChannels = new HashSet<long>();
      // Snapshot of updateChannels taken by Update() so the set may change while channels update.
      private readonly List<long> updatingChannels = new List<long>();
      // Scheduled update time -> ids of the channels due at that time.
      private readonly MultiMap<long, long> timeId = new MultiMap<long, long>();
      // Scratch list of expired timeId keys, reused by TimerOut.
      private readonly List<long> timeOutTime = new List<long>();
      // Never greater than the smallest key in timeId, so TimerOut can return early without
      // touching the MultiMap (AddToUpdateNextTime lowers it, TimerOut raises it).
      private long minTime;
      // Scratch list reused by RemoveConnectTimeoutChannels.
      private readonly List<long> waitRemoveChannels = new List<long>();

      /// <summary>True once <see cref="Dispose"/> has closed the socket.</summary>
      public override bool IsDispose()
      {
          return this.socket == null;
      }

      /// <summary>
      /// Removes (and disposes) every channel, then closes the socket. Safe to call more than once.
      /// </summary>
      public override void Dispose()
      {
          // A second call used to throw NullReferenceException on socket.Close().
          if (this.socket == null)
          {
              return;
          }
          foreach (long channelId in this.idChannels.Keys.ToArray())
          {
              this.Remove(channelId);
          }
          this.socket.Close();
          this.socket = null;
      }

      // Copies ipEndPoint, because ReceiveFrom overwrites that field on every receive.
      private IPEndPoint CloneAddress()
      {
          IPEndPoint ip = (IPEndPoint) this.ipEndPoint;
          return new IPEndPoint(ip.Address, ip.Port);
      }

      /// <summary>
      /// Drains every datagram currently available on the socket and dispatches it on its
      /// first byte. An exception while handling one packet is logged and the loop continues;
      /// a socket error while receiving is logged and draining stops until the next Update().
      /// </summary>
      private void Recv()
      {
          if (this.socket == null)
          {
              return;
          }

          while (this.socket != null && this.socket.Available > 0)
          {
              int messageLength;
              try
              {
                  messageLength = this.socket.ReceiveFrom(this.cache, ref this.ipEndPoint);
              }
              catch (SocketException e)
              {
                  // e.g. on Windows a datagram longer than cache fails with MessageSize. Stop for
                  // this frame instead of letting the exception escape Update().
                  Log.Error($"kservice recv error: {e.SocketErrorCode}\n{e}");
                  break;
              }

              // Shorter than 1 byte: not a valid packet.
              if (messageLength < 1)
              {
                  continue;
              }

              byte flag = this.cache[0];

              // Original note: conn values start at 100; 1, 2, 3 mark special packets.
              uint remoteConn = 0;
              uint localConn = 0;

              try
              {
                  KChannel kChannel = null;
                  switch (flag)
                  {
  #if NOT_UNITY
                      case KcpProtocalType.SYN: // accept (only compiled when NOT_UNITY is defined)
                      {
                          // [flag][remoteConn:4][localConn:4][optional real-address string]
                          if (messageLength < 9)
                          {
                              break;
                          }

                          remoteConn = BitConverter.ToUInt32(this.cache, 1);
                          localConn = BitConverter.ToUInt32(this.cache, 5);

                          string realAddress;
                          if (messageLength > 9)
                          {
                              realAddress = this.cache.ToStr(9, messageLength - 9);
                          }
                          else
                          {
                              realAddress = this.ipEndPoint.ToString();
                          }

                          this.waitConnectChannels.TryGetValue(remoteConn, out kChannel);
                          if (kChannel == null)
                          {
                              localConn = CreateRandomLocalConn();
                              // This localConn is already taken: ignore, and wait for the next SYN.
                              if (this.localConnChannels.ContainsKey(localConn))
                              {
                                  break;
                              }

                              long id = this.CreateAcceptChannelId(localConn);
                              if (this.idChannels.ContainsKey(id))
                              {
                                  break;
                              }

                              // Parse before registering anything: realAddress comes from the
                              // packet, and a parse failure after registration used to leave a
                              // channel in every map that OnAccept was never told about.
                              IPEndPoint realEndPoint = NetworkHelper.ToIPEndPoint(realAddress);

                              kChannel = new KChannel(id, localConn, remoteConn, this.socket, this.CloneAddress(), this);
                              this.idChannels.Add(kChannel.Id, kChannel);
                              this.waitConnectChannels.Add(kChannel.RemoteConn, kChannel); // removed once connected or timed out
                              this.localConnChannels.Add(kChannel.LocalConn, kChannel);
                              kChannel.RealAddress = realAddress;
                              this.OnAccept(kChannel.Id, realEndPoint);
                          }

                          if (kChannel.RemoteConn != remoteConn)
                          {
                              break;
                          }

                          // A retransmitted SYN must come from the same real address as the first one.
                          if (kChannel.RealAddress != realAddress)
                          {
                              Log.Error($"kchannel syn address diff: {kChannel.Id} {kChannel.RealAddress} {realAddress}");
                              break;
                          }

                          try
                          {
                              // Answer every SYN (first or retransmitted) with an ACK.
                              byte[] buffer = this.cache;
                              buffer.WriteTo(0, KcpProtocalType.ACK);
                              buffer.WriteTo(1, kChannel.LocalConn);
                              buffer.WriteTo(5, kChannel.RemoteConn);
                              Log.Info($"kservice syn: {kChannel.Id} {remoteConn} {localConn}");
                              this.socket.SendTo(buffer, 0, 9, SocketFlags.None, kChannel.RemoteAddress);
                          }
                          catch (Exception e)
                          {
                              Log.Error(e);
                              kChannel.OnError(ErrorCore.ERR_SocketCantSend);
                          }

                          break;
                      }
  #endif
                      case KcpProtocalType.ACK: // reply to our connect request
                          // Exactly 9 bytes: [flag][remoteConn:4][localConn:4].
                          if (messageLength != 9)
                          {
                              break;
                          }

                          remoteConn = BitConverter.ToUInt32(this.cache, 1);
                          localConn = BitConverter.ToUInt32(this.cache, 5);
                          kChannel = this.GetByLocalConn(localConn);
                          // Ignore ACKs once connected: otherwise a late or forged ACK could
                          // overwrite RemoteConn, which FIN/MSG rely on to reject third parties.
                          if (kChannel != null && !kChannel.IsConnected)
                          {
                              Log.Info($"kservice ack: {kChannel.Id} {remoteConn} {localConn}");
                              kChannel.RemoteConn = remoteConn;
                              kChannel.HandleConnnect();
                          }

                          break;
                      case KcpProtocalType.FIN: // disconnect
                          // Exactly 13 bytes: [flag][remoteConn:4][localConn:4][error:4].
                          if (messageLength != 13)
                          {
                              break;
                          }

                          remoteConn = BitConverter.ToUInt32(this.cache, 1);
                          localConn = BitConverter.ToUInt32(this.cache, 5);
                          int error = BitConverter.ToInt32(this.cache, 9);

                          kChannel = this.GetByLocalConn(localConn);
                          if (kChannel == null)
                          {
                              break;
                          }

                          // Verify remoteConn to reject packets forged by a third party.
                          if (kChannel.RemoteConn != remoteConn)
                          {
                              break;
                          }

                          Log.Info($"kservice recv fin: {kChannel.Id} {localConn} {remoteConn} {error}");
                          kChannel.OnError(ErrorCore.ERR_PeerDisconnect);
                          break;
                      case KcpProtocalType.MSG: // data
                          // At least 9 bytes: [flag][remoteConn:4][localConn:4 ...payload].
                          if (messageLength < 9)
                          {
                              break;
                          }

                          remoteConn = BitConverter.ToUInt32(this.cache, 1);
                          localConn = BitConverter.ToUInt32(this.cache, 5);
                          kChannel = this.GetByLocalConn(localConn);
                          if (kChannel == null)
                          {
                              // Unknown channel: tell the sender to disconnect.
                              this.Disconnect(localConn, remoteConn, ErrorCore.ERR_KcpNotFoundChannel, (IPEndPoint) this.ipEndPoint, 1);
                              break;
                          }

                          // Verify remoteConn to reject packets forged by a third party.
                          if (kChannel.RemoteConn != remoteConn)
                          {
                              break;
                          }

                          // Everything from offset 5 goes to HandleRecv, so the localConn field is
                          // also the first 4 bytes of that payload (presumably the KCP conv header).
                          kChannel.HandleRecv(this.cache, 5, messageLength - 5);
                          break;
                  }
              }
              catch (Exception e)
              {
                  Log.Error($"kservice error: {flag} {remoteConn} {localConn}\n{e}");
              }
          }
      }

      /// <summary>Returns the channel with the given id, or null.</summary>
      public KChannel Get(long id)
      {
          return this.idChannels.TryGetValue(id, out KChannel channel) ? channel : null;
      }

      // Returns the channel whose LocalConn is localConn, or null.
      private KChannel GetByLocalConn(uint localConn)
      {
          return this.localConnChannels.TryGetValue(localConn, out KChannel channel) ? channel : null;
      }

      /// <summary>
      /// Creates a connecting channel for <paramref name="id"/> towards <paramref name="address"/>
      /// unless one already exists. The low 32 bits of the id are used as its LocalConn.
      /// Errors are logged, not thrown.
      /// </summary>
      protected override void Get(long id, IPEndPoint address)
      {
          if (this.idChannels.TryGetValue(id, out KChannel kChannel))
          {
              return;
          }

          try
          {
              // Low 32 bits are the localConn.
              uint localConn = (uint) ((ulong) id & uint.MaxValue);
              kChannel = new KChannel(id, localConn, this.socket, address, this);
              this.idChannels.Add(id, kChannel);
              // NOTE(review): if this Add throws (localConn already used), the channel stays in
              // idChannels only and can never be found by ACK/MSG — confirm whether that matters.
              this.localConnChannels.Add(kChannel.LocalConn, kChannel);
          }
          catch (Exception e)
          {
              Log.Error($"kservice get error: {id}\n{e}");
          }
      }

      /// <summary>
      /// Unregisters channel <paramref name="id"/> from every index and disposes it.
      /// Does nothing for an unknown id.
      /// </summary>
      public override void Remove(long id)
      {
          if (!this.idChannels.TryGetValue(id, out KChannel kChannel))
          {
              return;
          }

          Log.Info($"kservice remove channel: {id} {kChannel.LocalConn} {kChannel.RemoteConn}");
          this.idChannels.Remove(id);
          this.localConnChannels.Remove(kChannel.LocalConn);
          // Only drop the pending-handshake entry if it belongs to this very channel.
          if (this.waitConnectChannels.TryGetValue(kChannel.RemoteConn, out KChannel waitChannel))
          {
              if (waitChannel.LocalConn == kChannel.LocalConn)
              {
                  this.waitConnectChannels.Remove(kChannel.RemoteConn);
              }
          }

          kChannel.Dispose();
      }

      /// <summary>
      /// Sends <paramref name="times"/> copies of a 13-byte FIN packet
      /// [FIN][localConn:4][remoteConn:4][error:4] to <paramref name="address"/>.
      /// Builds the packet in the shared receive cache, so it must only be called once the
      /// current datagram has been parsed. Send errors are logged, not thrown.
      /// </summary>
      private void Disconnect(uint localConn, uint remoteConn, int error, IPEndPoint address, int times)
      {
          try
          {
              if (this.socket == null)
              {
                  return;
              }

              byte[] buffer = this.cache;
              buffer.WriteTo(0, KcpProtocalType.FIN);
              buffer.WriteTo(1, localConn);
              buffer.WriteTo(5, remoteConn);
              buffer.WriteTo(9, (uint) error);
              for (int i = 0; i < times; ++i)
              {
                  this.socket.SendTo(buffer, 0, 13, SocketFlags.None, address);
              }
          }
          catch (Exception e)
          {
              Log.Error($"Disconnect error {localConn} {remoteConn} {error} {address} {e}");
          }

          Log.Info($"channel send fin: {localConn} {remoteConn} {address} {error}");
      }

      /// <summary>Sends <paramref name="stream"/> on the given channel; ignored if it does not exist.</summary>
      protected override void Send(long channelId, long actorId, MemoryStream stream)
      {
          KChannel channel = this.Get(channelId);
          if (channel == null)
          {
              return;
          }

          channel.Send(actorId, stream);
      }

      /// <summary>
      /// Schedules channel <paramref name="id"/> for an update: <paramref name="time"/> == 0
      /// means on the next Update(); otherwise once TimeNow reaches <paramref name="time"/>
      /// (checked by TimerOut).
      /// </summary>
      public void AddToUpdateNextTime(long time, long id)
      {
          if (time == 0)
          {
              this.updateChannels.Add(id);
              return;
          }

          if (time < this.minTime)
          {
              this.minTime = time;
          }

          this.timeId.Add(time, id);
      }

      /// <summary>
      /// One service tick. It receives and dispatches pending datagrams, moves due timers into
      /// updateChannels, updates those channels, and prunes the pending-handshake list.
      /// </summary>
      public override void Update()
      {
          this.Recv();
          this.TimerOut();

          // Iterate a snapshot and drop each id from updateChannels just before updating it.
          // Enumerating updateChannels directly threw InvalidOperationException if anything
          // called from KChannel.Update added to it (e.g. AddToUpdateNextTime with time 0);
          // an id re-added that way is now kept for the next Update().
          this.updatingChannels.Clear();
          this.updatingChannels.AddRange(this.updateChannels);
          foreach (long id in this.updatingChannels)
          {
              this.updateChannels.Remove(id);

              KChannel kChannel = this.Get(id);
              if (kChannel == null)
              {
                  continue;
              }

              if (kChannel.Id == 0)
              {
                  continue;
              }

              kChannel.Update();
          }

          this.RemoveConnectTimeoutChannels();
      }

      // Drops pending-handshake entries that have connected, or that are older than the
      // 10-second connect timeout. Only removes them from waitConnectChannels.
      private void RemoveConnectTimeoutChannels()
      {
          this.waitRemoveChannels.Clear();
          foreach (KeyValuePair<long, KChannel> kv in this.waitConnectChannels)
          {
              KChannel kChannel = kv.Value;
              if (kChannel == null)
              {
                  Log.Error($"RemoveConnectTimeoutChannels not found kchannel: {kv.Key}");
                  continue;
              }

              // Connected channels leave immediately; unconnected ones after 10 seconds.
              if (kChannel.IsConnected || this.TimeNow > kChannel.CreateTime + 10 * 1000)
              {
                  this.waitRemoveChannels.Add(kv.Key);
              }
          }

          foreach (long channelId in this.waitRemoveChannels)
          {
              this.waitConnectChannels.Remove(channelId);
          }
      }

      // Moves every channel whose scheduled time is <= TimeNow from timeId into updateChannels,
      // and sets minTime to the first unexpired key.
      // NOTE(review): the early break assumes MultiMap enumerates keys in ascending order
      // (presumably SortedDictionary-based) — confirm.
      private void TimerOut()
      {
          if (this.timeId.Count == 0)
          {
              return;
          }

          uint timeNow = this.TimeNow;
          if (timeNow < this.minTime)
          {
              return;
          }

          this.timeOutTime.Clear();
          foreach (KeyValuePair<long, List<long>> kv in this.timeId)
          {
              long k = kv.Key;
              if (k > timeNow)
              {
                  this.minTime = k;
                  break;
              }

              this.timeOutTime.Add(k);
          }

          // Second pass: the MultiMap cannot be modified while it is being enumerated above.
          foreach (long k in this.timeOutTime)
          {
              foreach (long v in this.timeId[k])
              {
                  this.updateChannels.Add(v);
              }

              this.timeId.Remove(k);
          }
      }
  }
  500. }