| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583 |
- using System;
- using System.Collections.Concurrent;
- using System.Collections.Generic;
- using System.IO;
- using System.Linq;
- using System.Net;
- using System.Net.Sockets;
- using System.Runtime.InteropServices;
- namespace ET
- {
// Packet-type tags. KService.Recv reads the tag from the first byte of every
// datagram and dispatches on it.
public static class KcpProtocalType
{
    // Connection request; handled by the accept branch of Recv.
    public const byte SYN = 1;

    // Reply to a connection request; completes the connect on the receiving channel.
    public const byte ACK = 2;

    // Disconnect notification (13-byte packet carrying an error code).
    public const byte FIN = 3;

    // Data packet; the payload is forwarded to the owning channel.
    public const byte MSG = 4;
}
// Kind of network a service is used on. KService only stores the value
// (see its constructors).
// NOTE(review): presumably Outer = client-facing and Inner = server-to-server; confirm with callers.
public enum ServiceType
{
    Outer = 0,
    Inner = 1,
}
- public sealed class KService: AService
- {
// Value of TimeHelper.ClientNow() captured when this KService was constructed.
private readonly long startTime;

// Time elapsed since this KService was created, truncated to uint.
// The original author marks this getter as thread-safe.
// NOTE(review): the unit appears to be milliseconds (the 10-second connect
// timeout elsewhere in this class is written as 10 * 1000) - confirm against TimeHelper.
public uint TimeNow => (uint) (TimeHelper.ClientNow() - this.startTime);

// UDP socket shared by all channels of this service; null once the service is disposed.
private Socket socket;
- #region 回调方法
// Type initializer: registers the static KcpOutput method as the KCP output
// callback. The log callback is intentionally not registered.
static KService()
{
    // Disabled: Kcp.KcpSetLog(KcpLog);
    Kcp.KcpSetoutput(KcpOutput);
}
// Scratch buffer that native log text is copied into before being decoded.
// NOTE(review): this is shared static state; concurrent callbacks would
// overwrite each other's text.
private static readonly byte[] logBuffer = new byte[1024];
#if ENABLE_IL2CPP
[AOT.MonoPInvokeCallback(typeof(KcpOutput))]
#endif
// Native log callback: copies `len` bytes from `bytes` into managed memory and
// writes them to Log.Info. `kcp` and `user` are unused.
// Exceptions are logged and never propagate back into native code.
private static void KcpLog(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
{
    try
    {
        // Messages longer than the scratch buffer used to make Marshal.Copy throw
        // and were lost entirely; log the part that fits instead.
        int count = Math.Min(len, logBuffer.Length);
        Marshal.Copy(bytes, logBuffer, 0, count);
        Log.Info(logBuffer.ToStr(0, count));
    }
    catch (Exception e)
    {
        Log.Error(e);
    }
}
#if ENABLE_IL2CPP
[AOT.MonoPInvokeCallback(typeof(KcpOutput))]
#endif
// KCP output callback: looks up the channel that owns the native `kcp` handle
// and hands it the outgoing bytes.
// Returns 0 when the handle is null or not registered in KChannel.KcpPtrChannels,
// otherwise `len` (also when the channel's Output throws; the exception is logged).
private static int KcpOutput(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
{
    try
    {
        if (kcp == IntPtr.Zero || !KChannel.KcpPtrChannels.TryGetValue(kcp, out KChannel owner))
        {
            return 0;
        }

        owner.Output(bytes, len);
    }
    catch (Exception e)
    {
        // Never let a managed exception escape into the native caller.
        Log.Error(e);
    }

    return len;
}
- #endregion
// Creates a service whose UDP socket is bound to the given local endpoint.
// On every platform except macOS the socket send/receive buffers are enlarged
// to Kcp.OneM * 64.
// Throws whatever the socket setup throws; the socket is closed before the
// exception propagates.
public KService(ThreadSynchronizationContext threadSynchronizationContext, IPEndPoint ipEndPoint, ServiceType serviceType)
{
    this.ServiceType = serviceType;
    this.ThreadSynchronizationContext = threadSynchronizationContext;
    this.startTime = TimeHelper.ClientNow();
    this.socket = CreateBoundSocket(ipEndPoint, true);
}

// Creates a service bound to an OS-assigned port on IPAddress.Any.
// Per the original author this is the client-side form, which does not need
// enlarged socket buffers.
public KService(ThreadSynchronizationContext threadSynchronizationContext, ServiceType serviceType)
{
    this.ServiceType = serviceType;
    this.ThreadSynchronizationContext = threadSynchronizationContext;
    this.startTime = TimeHelper.ClientNow();
    this.socket = CreateBoundSocket(new IPEndPoint(IPAddress.Any, 0), false);
}

// Creates an IPv4 UDP socket, optionally enlarges its buffers, binds it and
// applies the Windows-only connection-reset workaround. If any step fails the
// socket is closed so the handle does not leak out of a failed constructor.
private static Socket CreateBoundSocket(IPEndPoint localEndPoint, bool enlargeBuffers)
{
    Socket udpSocket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
    try
    {
        if (enlargeBuffers && !RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
        {
            udpSocket.SendBufferSize = Kcp.OneM * 64;
            udpSocket.ReceiveBufferSize = Kcp.OneM * 64;
        }

        udpSocket.Bind(localEndPoint);
        DisableUdpConnReset(udpSocket);
        return udpSocket;
    }
    catch
    {
        udpSocket.Close();
        throw;
    }
}

// On Windows, turns off SIO_UDP_CONNRESET so that an ICMP "port unreachable"
// triggered by an earlier send is not reported as a connection-reset error on
// later receives. Does nothing on other platforms.
private static void DisableUdpConnReset(Socket udpSocket)
{
    if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
    {
        return;
    }

    const uint IOC_IN = 0x80000000;
    const uint IOC_VENDOR = 0x18000000;
    const uint SIO_UDP_CONNRESET = IOC_IN | IOC_VENDOR | 12;
    // The control code does not fit in a positive int; reinterpret the bits.
    udpSocket.IOControl(unchecked((int) SIO_UDP_CONNRESET), new[] { Convert.ToByte(false) }, null);
}
// Points an existing channel at a new remote address.
// Does nothing when no channel is registered under `id`.
public void ChangeAddress(long id, IPEndPoint address)
{
    KChannel target = this.Get(id);
    if (target != null)
    {
        Log.Info($"channel change address: {id} {address}");
        target.RemoteAddress = address;
    }
}
// Every live channel, keyed by channel id.
private readonly Dictionary<long, KChannel> idChannels = new Dictionary<long, KChannel>();

// The same channels keyed by their LocalConn (widened from uint to long).
private readonly Dictionary<long, KChannel> localConnChannels = new Dictionary<long, KChannel>();

// Accepted channels still waiting to finish connecting, keyed by RemoteConn.
// Entries are removed once the channel is connected or the connect times out.
private readonly Dictionary<long, KChannel> waitConnectChannels = new Dictionary<long, KChannel>();

// Scratch buffer used both for received datagrams and for building the small
// control packets (ACK / FIN) this service sends.
private readonly byte[] cache = new byte[8192];

// Receives the sender address of the most recent datagram (passed by ref to ReceiveFrom).
private EndPoint ipEndPoint = new IPEndPoint(IPAddress.Any, 0);

// Ids of channels to update on the next frame.
private readonly HashSet<long> updateChannels = new HashSet<long>();

// Channels scheduled for a later update: due time -> channel ids.
private readonly MultiMap<long, long> timeId = new MultiMap<long, long>();

// Reusable list of due times collected by TimerOut.
private readonly List<long> timeOutTime = new List<long>();

// Cached smallest due time, so the MultiMap's first key is not fetched every frame.
private long minTime;

// Reusable list of waitConnectChannels keys to drop.
private List<long> waitRemoveChannels = new List<long>();
// True once the socket has been released (Dispose sets it to null).
public override bool IsDispose() => this.socket == null;
// Removes every channel, then closes and clears the socket.
// Safe to call more than once: later calls return immediately instead of
// dereferencing the already-cleared socket.
public override void Dispose()
{
    if (this.socket == null)
    {
        return;
    }

    // Iterate over a snapshot of the keys because Remove mutates idChannels.
    foreach (long channelId in this.idChannels.Keys.ToArray())
    {
        this.Remove(channelId);
    }

    // Null-conditional in case removing a channel re-entered Dispose and
    // already released the socket.
    this.socket?.Close();
    this.socket = null;
}
-
// Returns a fresh IPEndPoint with the address and port currently held in the
// ipEndPoint field, so a channel gets its own instance rather than the field
// that ReceiveFrom updates by ref for every datagram.
private IPEndPoint CloneAddress()
{
    IPEndPoint sender = (IPEndPoint) this.ipEndPoint;
    IPAddress senderAddress = sender.Address;
    int senderPort = sender.Port;
    return new IPEndPoint(senderAddress, senderPort);
}
// Drains every datagram currently queued on the socket and dispatches each one
// by its first byte (see KcpProtocalType). All packets handled here share the
// layout: flag(1) + remoteConn(4) + localConn(4) + type-specific data.
// Exceptions thrown while handling a packet are logged and do not stop the loop.
// NOTE(review): ReceiveFrom itself is outside the try block, so a
// SocketException raised by it propagates to the caller.
private void Recv()
{
    if (this.socket == null)
    {
        return;
    }

    // Re-check the socket on every iteration: handling a packet may dispose the service.
    while (this.socket != null && this.socket.Available > 0)
    {
        int messageLength = this.socket.ReceiveFrom(this.cache, ref this.ipEndPoint);

        // Shorter than one byte: not a valid message.
        if (messageLength < 1)
        {
            continue;
        }

        byte flag = this.cache[0];

        // Filled in by the handlers so the error log below can report them.
        uint remoteConn = 0;
        uint localConn = 0;

        try
        {
            switch (flag)
            {
#if NOT_UNITY
                case KcpProtocalType.SYN:
                    this.RecvSyn(messageLength, ref remoteConn, ref localConn);
                    break;
#endif
                case KcpProtocalType.ACK:
                    this.RecvAck(messageLength, ref remoteConn, ref localConn);
                    break;
                case KcpProtocalType.FIN:
                    this.RecvFin(messageLength, ref remoteConn, ref localConn);
                    break;
                case KcpProtocalType.MSG:
                    this.RecvMsg(messageLength, ref remoteConn, ref localConn);
                    break;
            }
        }
        catch (Exception e)
        {
            Log.Error($"kservice error: {flag} {remoteConn} {localConn}\n{e}");
        }
    }
}

#if NOT_UNITY
// SYN: a peer asks to connect. Creates and registers a channel on the first SYN
// for a given remoteConn, then answers with an ACK (also for repeated SYNs).
private void RecvSyn(int messageLength, ref uint remoteConn, ref uint localConn)
{
    // Shorter than the 9-byte header: not a SYN.
    if (messageLength < 9)
    {
        return;
    }

    remoteConn = BitConverter.ToUInt32(this.cache, 1);

    // Anything after the header is the peer's real address as text.
    string realAddress = null;
    if (messageLength > 9)
    {
        realAddress = this.cache.ToStr(9, messageLength - 9);
    }

    localConn = BitConverter.ToUInt32(this.cache, 5);

    this.waitConnectChannels.TryGetValue(remoteConn, out KChannel kChannel);
    if (kChannel == null)
    {
        localConn = CreateRandomLocalConn();
        // localConn already in use: ignore this packet and wait for the peer's next SYN.
        if (this.localConnChannels.ContainsKey(localConn))
        {
            return;
        }

        long id = this.CreateAcceptChannelId(localConn);
        if (this.idChannels.ContainsKey(id))
        {
            return;
        }

        kChannel = new KChannel(id, localConn, remoteConn, this.socket, this.CloneAddress(), this);
        this.idChannels.Add(kChannel.Id, kChannel);
        // Removed again once the channel is connected or the connect times out.
        this.waitConnectChannels.Add(kChannel.RemoteConn, kChannel);
        this.localConnChannels.Add(kChannel.LocalConn, kChannel);

        kChannel.RealAddress = realAddress;
        IPEndPoint realEndPoint = kChannel.RealAddress == null? kChannel.RemoteAddress : NetworkHelper.ToIPEndPoint(kChannel.RealAddress);
        this.OnAccept(kChannel.Id, realEndPoint);
    }

    if (kChannel.RemoteConn != remoteConn)
    {
        return;
    }

    // Skip the packet if the real address differs from the one seen previously.
    if (kChannel.RealAddress != realAddress)
    {
        Log.Error($"kchannel syn address diff: {kChannel.Id} {kChannel.RealAddress} {realAddress}");
        return;
    }

    try
    {
        // The receive buffer is reused to build the 9-byte ACK reply.
        byte[] buffer = this.cache;
        buffer.WriteTo(0, KcpProtocalType.ACK);
        buffer.WriteTo(1, kChannel.LocalConn);
        buffer.WriteTo(5, kChannel.RemoteConn);
        Log.Info($"kservice syn: {kChannel.Id} {remoteConn} {localConn}");
        this.socket.SendTo(buffer, 0, 9, SocketFlags.None, kChannel.RemoteAddress);
    }
    catch (Exception e)
    {
        Log.Error(e);
        kChannel.OnError(ErrorCore.ERR_SocketCantSend);
    }
}
#endif

// ACK: the peer answered our connect. Records its conn id on the channel and
// completes the connection.
private void RecvAck(int messageLength, ref uint remoteConn, ref uint localConn)
{
    // An ACK is exactly 9 bytes.
    if (messageLength != 9)
    {
        return;
    }

    remoteConn = BitConverter.ToUInt32(this.cache, 1);
    localConn = BitConverter.ToUInt32(this.cache, 5);

    KChannel kChannel = this.GetByLocalConn(localConn);
    if (kChannel == null)
    {
        return;
    }

    Log.Info($"kservice ack: {kChannel.Id} {remoteConn} {localConn}");
    kChannel.RemoteConn = remoteConn;
    kChannel.HandleConnnect();
}

// FIN: the peer disconnected. Reports ERR_PeerDisconnect on the matching channel.
private void RecvFin(int messageLength, ref uint remoteConn, ref uint localConn)
{
    // A FIN is exactly 13 bytes: header plus a 4-byte error code.
    if (messageLength != 13)
    {
        return;
    }

    remoteConn = BitConverter.ToUInt32(this.cache, 1);
    localConn = BitConverter.ToUInt32(this.cache, 5);
    int error = BitConverter.ToInt32(this.cache, 9);

    KChannel kChannel = this.GetByLocalConn(localConn);
    if (kChannel == null)
    {
        return;
    }

    // Verify remoteConn to guard against third-party (spoofed) packets.
    if (kChannel.RemoteConn != remoteConn)
    {
        return;
    }

    Log.Info($"kservice recv fin: {kChannel.Id} {localConn} {remoteConn} {error}");
    kChannel.OnError(ErrorCore.ERR_PeerDisconnect);
}

// MSG: a data packet. Forwards everything after the flag and remoteConn bytes
// to the owning channel, or tells the sender to disconnect when no channel
// is registered under localConn.
private void RecvMsg(int messageLength, ref uint remoteConn, ref uint localConn)
{
    // Shorter than the 9-byte header: not a MSG.
    if (messageLength < 9)
    {
        return;
    }

    remoteConn = BitConverter.ToUInt32(this.cache, 1);
    localConn = BitConverter.ToUInt32(this.cache, 5);

    KChannel kChannel = this.GetByLocalConn(localConn);
    if (kChannel == null)
    {
        // Unknown channel: notify the sender so it disconnects.
        this.Disconnect(localConn, remoteConn, ErrorCore.ERR_KcpNotFoundChannel, (IPEndPoint) this.ipEndPoint, 1);
        return;
    }

    // Verify remoteConn to guard against third-party (spoofed) packets.
    if (kChannel.RemoteConn != remoteConn)
    {
        return;
    }

    kChannel.HandleRecv(this.cache, 5, messageLength - 5);
}
// Returns the channel registered under `id`, or null when there is none.
public KChannel Get(long id)
{
    return this.idChannels.TryGetValue(id, out KChannel found) ? found : null;
}
-
// Returns the channel whose LocalConn equals `localConn`, or null when there is none.
private KChannel GetByLocalConn(uint localConn)
{
    return this.localConnChannels.TryGetValue(localConn, out KChannel found) ? found : null;
}
// Ensures a channel exists for `id`: if none is registered, creates one that
// targets `address` and registers it by id and by LocalConn.
// Failures during creation or registration are logged, not thrown.
protected override void Get(long id, IPEndPoint address)
{
    if (this.idChannels.ContainsKey(id))
    {
        return;
    }

    try
    {
        // The low 32 bits of the id are the localConn.
        ulong lowBits = (ulong) id & uint.MaxValue;
        uint localConn = (uint) lowBits;

        KChannel created = new KChannel(id, localConn, this.socket, address, this);
        this.idChannels.Add(id, created);
        this.localConnChannels.Add(created.LocalConn, created);
    }
    catch (Exception e)
    {
        Log.Error($"kservice get error: {id}\n{e}");
    }
}
// Unregisters the channel with the given id from every lookup table and
// disposes it. Does nothing when the id is unknown.
public override void Remove(long id)
{
    if (!this.idChannels.TryGetValue(id, out KChannel kChannel))
    {
        return;
    }

    Log.Info($"kservice remove channel: {id} {kChannel.LocalConn} {kChannel.RemoteConn}");

    this.idChannels.Remove(id);
    this.localConnChannels.Remove(kChannel.LocalConn);

    // Only drop the pending-connect entry if it belongs to this channel; another
    // channel may be registered under the same RemoteConn.
    bool pendingEntryIsOurs =
            this.waitConnectChannels.TryGetValue(kChannel.RemoteConn, out KChannel waitChannel)
            && waitChannel.LocalConn == kChannel.LocalConn;
    if (pendingEntryIsOurs)
    {
        this.waitConnectChannels.Remove(kChannel.RemoteConn);
    }

    kChannel.Dispose();
}
// Sends a 13-byte FIN packet (flag, localConn, remoteConn, error) to `address`
// `times` times. Returns silently when the service has no socket.
// Send failures are logged rather than thrown.
// Note: the "channel send fin" line is written even after a failed send.
private void Disconnect(uint localConn, uint remoteConn, int error, IPEndPoint address, int times)
{
    try
    {
        if (this.socket == null)
        {
            return;
        }

        // The shared scratch buffer is reused to build the packet.
        byte[] packet = this.cache;
        packet.WriteTo(0, KcpProtocalType.FIN);
        packet.WriteTo(1, localConn);
        packet.WriteTo(5, remoteConn);
        packet.WriteTo(9, (uint) error);

        int sent = 0;
        while (sent < times)
        {
            this.socket.SendTo(packet, 0, 13, SocketFlags.None, address);
            ++sent;
        }
    }
    catch (Exception e)
    {
        Log.Error($"Disconnect error {localConn} {remoteConn} {error} {address} {e}");
    }

    Log.Info($"channel send fin: {localConn} {remoteConn} {address} {error}");
}
-
// Forwards a message to the channel registered under `channelId`.
// The message is dropped when no such channel exists.
protected override void Send(long channelId, long actorId, MemoryStream stream)
{
    KChannel target = this.Get(channelId);
    if (target != null)
    {
        target.Send(actorId, stream);
    }
}
// Schedules a channel update. A time of 0 means "update on the next frame";
// any other value queues the channel under that due time, which TimerOut
// checks against TimeNow.
public void AddToUpdateNextTime(long time, long id)
{
    if (time == 0)
    {
        this.updateChannels.Add(id);
    }
    else
    {
        // Keep the cached minimum in step with the earliest queued time.
        if (time < this.minTime)
        {
            this.minTime = time;
        }

        this.timeId.Add(time, id);
    }
}
// One service tick: read pending datagrams, move due timers into the update
// set, update every queued channel, then prune the pending-connect table.
public override void Update()
{
    this.Recv();

    this.TimerOut();

    foreach (long channelId in this.updateChannels)
    {
        KChannel kChannel = this.Get(channelId);

        // Skip channels that have been removed or whose Id is 0.
        if (kChannel != null && kChannel.Id != 0)
        {
            kChannel.Update();
        }
    }

    this.updateChannels.Clear();

    this.RemoveConnectTimeoutChannels();
}
// Drops entries from waitConnectChannels once the channel has connected or has
// been pending for more than 10 seconds. Only the pending-connect entry is
// removed; the channel itself is not disposed here.
private void RemoveConnectTimeoutChannels()
{
    this.waitRemoveChannels.Clear();

    // Sample the clock once so every channel is judged against the same instant.
    uint timeNow = this.TimeNow;

    // Enumerate key/value pairs instead of looking each key up again.
    foreach (KeyValuePair<long, KChannel> pending in this.waitConnectChannels)
    {
        long channelId = pending.Key;
        KChannel kChannel = pending.Value;
        if (kChannel == null)
        {
            Log.Error($"RemoveConnectTimeoutChannels not found kchannel: {channelId}");
            continue;
        }

        // Connected channels leave the table immediately; the rest after 10 seconds.
        // A single condition queues each key at most once.
        if (kChannel.IsConnected || timeNow > kChannel.CreateTime + 10 * 1000)
        {
            this.waitRemoveChannels.Add(channelId);
        }
    }

    // Remove after the enumeration; the dictionary must not change while being iterated.
    foreach (long channelId in this.waitRemoveChannels)
    {
        this.waitConnectChannels.Remove(channelId);
    }
}
// Moves every channel whose scheduled time has arrived into updateChannels.
// The first due time found that is still in the future becomes the new minTime.
private void TimerOut()
{
    if (this.timeId.Count == 0)
    {
        return;
    }

    uint now = this.TimeNow;

    // Nothing can be due before the cached minimum.
    if (now < this.minTime)
    {
        return;
    }

    // Pass 1: collect the due keys; the map cannot be modified while enumerating.
    this.timeOutTime.Clear();
    foreach (KeyValuePair<long, List<long>> entry in this.timeId)
    {
        long dueTime = entry.Key;
        if (dueTime <= now)
        {
            this.timeOutTime.Add(dueTime);
            continue;
        }

        this.minTime = dueTime;
        break;
    }

    // Pass 2: queue the channels and drop the expired keys.
    foreach (long dueTime in this.timeOutTime)
    {
        foreach (long channelId in this.timeId[dueTime])
        {
            this.updateChannels.Add(channelId);
        }

        this.timeId.Remove(dueTime);
    }
}
- }
- }
|