tanghai 11 лет назад
Родитель
Сommit
7f36a36e4c
2 измененных файлов с 137 добавлено и 64 удалено
  1. 18 27
      CSharp/Platform/TNet/TPoller.cs
  2. 119 37
      CSharp/Platform/TNet/TSocket.cs

+ 18 - 27
CSharp/Platform/TNet/TPoller.cs

@@ -1,17 +1,17 @@
-using System.Collections.Concurrent;
+using System;
+using System.Collections.Concurrent;
 using System.Collections.Generic;
 
 namespace TNet
 {
 	public class TPoller
 	{
-		private readonly BlockingCollection<TSocketState> blockingCollection = new BlockingCollection<TSocketState>();
-
-		public HashSet<TSocket> CanWriteSocket = new HashSet<TSocket>();
-
-		public void Add(TSocketState tSocketState)
+		// 线程同步队列,发送接收socket回调都放到该队列,由poll线程统一执行
+		private readonly BlockingCollection<Action> blockingCollection = new BlockingCollection<Action>();
+		
+		public void Add(Action action)
 		{
-			this.blockingCollection.Add(tSocketState);
+			this.blockingCollection.Add(action);
 		}
 
 		public void Dispose()
@@ -20,38 +20,29 @@ namespace TNet
 
 		public void RunOnce(int timeout)
 		{
-			foreach (TSocket socket in CanWriteSocket)
-			{
-				if (socket.IsSending)
-				{
-					continue;
-				}
-				socket.BeginSend();
-			}
-			this.CanWriteSocket.Clear();
-
-			TSocketState socketState;
-			if (!this.blockingCollection.TryTake(out socketState, timeout))
+			// 处理读写线程的回调
+			Action action;
+			if (!this.blockingCollection.TryTake(out action, timeout))
 			{
 				return;
 			}
 
-			var stateQueue = new Queue<TSocketState>();
-			stateQueue.Enqueue(socketState);
+			var queue = new Queue<Action>();
+			queue.Enqueue(action);
 
 			while (true)
 			{
-				if (!this.blockingCollection.TryTake(out socketState, 0))
+				if (!this.blockingCollection.TryTake(out action, 0))
 				{
 					break;
 				}
-				stateQueue.Enqueue(socketState);
+				queue.Enqueue(action);
 			}
 
-			while (stateQueue.Count > 0)
+			while (queue.Count > 0)
 			{
-				TSocketState state = stateQueue.Dequeue();
-				state.Run();
+				Action a = queue.Dequeue();
+				a();
 			}
 		}
 
@@ -59,7 +50,7 @@ namespace TNet
 		{
 			while (true)
 			{
-				this.RunOnce(1);
+				this.RunOnce(10);
 			}
 		}
 	}

+ 119 - 37
CSharp/Platform/TNet/TSocket.cs

@@ -1,18 +1,9 @@
 using System;
+using System.Net;
 using System.Net.Sockets;
 
 namespace TNet
 {
-	public class TSocketState
-	{
-		public Action Action { get; set; }
-
-		public void Run()
-		{
-			this.Action();
-		}
-	}
-
 	public class TSocket: IDisposable
 	{
 		private Socket socket;
@@ -21,7 +12,8 @@ namespace TNet
 		private readonly SocketAsyncEventArgs outSocketAsyncEventArgs = new SocketAsyncEventArgs();
 		private readonly TBuffer recvBuffer = new TBuffer();
 		private readonly TBuffer sendBuffer = new TBuffer();
-		public bool IsSending { get; private set; }
+		public Action RecvAction { get; set; }
+		public Action<TSocket> AcceptAction { get; set; }
 
 		public TSocket(TPoller poller)
 		{
@@ -29,7 +21,14 @@ namespace TNet
 			this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
 			this.outSocketAsyncEventArgs.Completed += this.OnComplete;
 			this.innSocketAsyncEventArgs.Completed += this.OnComplete;
-			this.IsSending = false;
+		}
+
+		public TSocket(TPoller poller, Socket socket)
+		{
+			this.poller = poller;
+			this.socket = socket;
+			this.outSocketAsyncEventArgs.Completed += this.OnComplete;
+			this.innSocketAsyncEventArgs.Completed += this.OnComplete;
 		}
 
 		public void Dispose()
@@ -40,53 +39,67 @@ namespace TNet
 			}
 			socket.Dispose();
 			this.socket = null;
-			poller.CanWriteSocket.Remove(this);
 		}
 
 		public void Connect(string host, int port)
 		{
-			socket.ConnectAsync(this.innSocketAsyncEventArgs);
+			if (socket.ConnectAsync(this.innSocketAsyncEventArgs))
+			{
+				return;
+			}
+
+			this.poller.Add(this.OnConnComplete);
 		}
 
-		public int CanRecvSize
+		public void Accept(int port)
 		{
-			get
-			{
-				return this.recvBuffer.Count;
-			}
+			this.socket.Bind(new IPEndPoint(IPAddress.Any, port));
+			this.socket.Listen(100);
+			this.BeginAccept();
 		}
 
-		public void Recv(byte[] buffer)
+		public bool Recv(byte[] buffer)
 		{
+			if (buffer.Length > this.RecvSize)
+			{
+				return false;
+			}
 			this.recvBuffer.RecvFrom(buffer);
+			return true;
 		}
 
 		public void Send(byte[] buffer)
 		{
+			bool needBeginSend = this.sendBuffer.Count == 0;
 			this.sendBuffer.SendTo(buffer);
-			// 如果正在发送,则不做可发送标记
-			if (this.IsSending)
+			if (needBeginSend)
 			{
-				return;
+				this.BeginSend();
 			}
-			if (this.poller.CanWriteSocket.Contains(this))
+		}
+
+		public int RecvSize
+		{
+			get
 			{
-				return;
+				return this.recvBuffer.Count;
 			}
-			this.poller.CanWriteSocket.Add(this);
 		}
 
 		private void OnComplete(object sender, SocketAsyncEventArgs e)
 		{
-			Action action = () => { };
+			Action action;
 			switch (e.LastOperation)
 			{
 				case SocketAsyncOperation.Accept:
+					action = () => this.OnAcceptComplete(e.AcceptSocket);
+					e.AcceptSocket = null;
 					break;
 				case SocketAsyncOperation.Connect:
 					action = this.OnConnComplete;
 					break;
 				case SocketAsyncOperation.Disconnect:
+					action = this.OnDisconnect;
 					break;
 				case SocketAsyncOperation.Receive:
 					action = () => this.OnRecvComplete(e.BytesTransferred);
@@ -97,32 +110,67 @@ namespace TNet
 				default:
 					throw new ArgumentOutOfRangeException();
 			}
-			TSocketState socketState = new TSocketState
-			{
-				Action = action,
-			};
 			
-			this.poller.Add(socketState);
+			this.poller.Add(action);
+		}
+
+		private void OnDisconnect()
+		{
+			this.Dispose();
+		}
+
+		private void OnAcceptComplete(Socket sock)
+		{
+			if (this.socket == null)
+			{
+				return;
+			}
+
+			TSocket newSocket = new TSocket(poller, sock);
+			if (this.AcceptAction != null)
+			{
+				this.AcceptAction(newSocket);
+			}
+			this.BeginAccept();
 		}
 
 		private void OnConnComplete()
 		{
+			if (this.socket == null)
+			{
+				return;
+			}
 			this.BeginRecv();
 		}
 
 		private void OnRecvComplete(int bytesTransferred)
 		{
+			if (this.socket == null)
+			{
+				return;
+			}
 			this.recvBuffer.LastIndex += bytesTransferred;
 			if (this.recvBuffer.LastIndex == TBuffer.ChunkSize)
 			{
 				this.recvBuffer.LastIndex = 0;
 				this.recvBuffer.AddLast();
 			}
+
 			this.BeginRecv();
+
+			if (this.RecvAction != null)
+			{
+				this.RecvAction();
+			}
 		}
 
 		private void OnSendComplete(int bytesTransferred)
 		{
+			if (this.socket == null)
+			{
+				return;
+			}
+
 			this.sendBuffer.FirstIndex += bytesTransferred;
 			if (this.sendBuffer.FirstIndex == TBuffer.ChunkSize)
 			{
@@ -133,7 +181,6 @@ namespace TNet
 			// 如果没有数据可以发送,则返回
 			if (this.sendBuffer.Count == 0)
 			{
-				this.IsSending = false;
 				return;
 			}
 
@@ -141,15 +188,45 @@ namespace TNet
 			this.BeginSend();
 		}
 
+		private void BeginAccept()
+		{
+			if (this.socket == null)
+			{
+				return;
+			}
+
+			if (this.socket.AcceptAsync(this.innSocketAsyncEventArgs))
+			{
+				return;
+			}
+			Action action = () => this.OnAcceptComplete(this.innSocketAsyncEventArgs.AcceptSocket);
+			this.poller.Add(action);
+		}
+
 		private void BeginRecv()
 		{
+			if (this.socket == null)
+			{
+				return;
+			}
+
 			this.innSocketAsyncEventArgs.SetBuffer(this.recvBuffer.Last, this.recvBuffer.LastIndex, TBuffer.ChunkSize - this.recvBuffer.LastIndex);
-			this.socket.ReceiveAsync(this.innSocketAsyncEventArgs);
+			if (this.socket.ReceiveAsync(this.innSocketAsyncEventArgs))
+			{
+				return;
+			}
+
+			Action action = () => this.OnRecvComplete(this.innSocketAsyncEventArgs.BytesTransferred);
+			this.poller.Add(action);
 		}
 
-		public void BeginSend()
+		private void BeginSend()
 		{
-			this.IsSending = true;
+			if (this.socket == null)
+			{
+				return;
+			}
+
 			int count = 0;
 			if (TBuffer.ChunkSize - this.sendBuffer.FirstIndex < this.sendBuffer.Count)
 			{
@@ -160,7 +237,12 @@ namespace TNet
 				count = this.sendBuffer.Count;
 			}
 			this.outSocketAsyncEventArgs.SetBuffer(this.sendBuffer.First, this.sendBuffer.FirstIndex, count);
-			this.socket.SendAsync(outSocketAsyncEventArgs);
+			if (this.socket.SendAsync(outSocketAsyncEventArgs))
+			{
+				return;
+			}
+			Action action = () => this.OnSendComplete(this.outSocketAsyncEventArgs.BytesTransferred);
+			this.poller.Add(action);
 		}
 	}
 }