// This file is provided under The MIT License as part of RiptideNetworking.
// Copyright (c) Tom Weiland
// For additional information please see the included LICENSE.md file or view it on GitHub:
// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md

using Riptide.Utils;
using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;

namespace Riptide.Transports.Tcp
{
    /// <summary>Represents a TCP connection between the local <see cref="TcpPeer"/> and a remote endpoint.</summary>
    /// <remarks>
    /// Messages are framed on the wire as a <c>sizeof(int)</c>-byte length prefix (written and read with
    /// <c>Converter</c>) followed by that many bytes of payload.
    /// </remarks>
    public class TcpConnection : Connection, IEquatable<TcpConnection>
    {
        /// <summary>The endpoint representing the other end of the connection.</summary>
        public readonly IPEndPoint RemoteEndPoint;
        /// <summary>Whether or not the server has received a connection attempt from this connection.</summary>
        internal bool DidReceiveConnect;

        /// <summary>The socket to use for sending and receiving.</summary>
        private readonly Socket socket;
        /// <summary>The local peer this connection is associated with.</summary>
        private readonly TcpPeer peer;
        /// <summary>An array to receive message size values into.</summary>
        private readonly byte[] sizeBytes = new byte[sizeof(int)];
        /// <summary>
        /// The size of the next message to be received. A value &lt;= 0 means no length prefix is pending
        /// and the next bytes on the socket are expected to be a new length prefix.
        /// </summary>
        private int nextMessageSize;

        /// <summary>Initializes the connection.</summary>
        /// <param name="socket">The socket to use for sending and receiving.</param>
        /// <param name="remoteEndPoint">The endpoint representing the other end of the connection.</param>
        /// <param name="peer">The local peer this connection is associated with.</param>
        internal TcpConnection(Socket socket, IPEndPoint remoteEndPoint, TcpPeer peer)
        {
            RemoteEndPoint = remoteEndPoint;
            this.socket = socket;
            this.peer = peer;
        }

        /// <inheritdoc/>
        /// <remarks>
        /// Writes <paramref name="amount"/> as a length prefix into the peer's <c>SendBuffer</c>, copies the payload
        /// in after it and sends both with a single <c>socket.Send</c> call. Nothing is sent if the socket is not
        /// connected, and <see cref="SocketException"/>s raised while sending are ignored.
        /// </remarks>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="amount"/> is 0.</exception>
        protected internal override void Send(byte[] dataBuffer, int amount)
        {
            if (amount == 0)
                throw new ArgumentOutOfRangeException(nameof(amount), "Sending 0 bytes is not allowed!");

            try
            {
                if (socket.Connected)
                {
                    Converter.FromInt(amount, peer.SendBuffer, 0);
                    // TODO: consider sending length separately with an extra socket.Send call instead of copying the data an extra time
                    Array.Copy(dataBuffer, 0, peer.SendBuffer, sizeof(int), amount);
                    socket.Send(peer.SendBuffer, amount + sizeof(int), SocketFlags.None);
                }
            }
            catch (SocketException)
            {
                // May want to consider triggering a disconnect here (perhaps depending on the type
                // of SocketException)? Timeout should catch disconnections, but disconnecting
                // explicitly might be better...
            }
        }

        /// <summary>Polls the socket and checks if any data was received.</summary>
        /// <remarks>
        /// Keeps reading complete, length-prefixed messages for as long as they are available. Each one is handed
        /// to the peer via <c>OnDataReceived</c>. If a message's payload has not fully arrived yet, it is left in
        /// the socket and its size is remembered in <c>nextMessageSize</c> for a later call. Certain socket errors,
        /// a disposed socket, and length prefixes too large for the peer's <c>ReceiveBuffer</c> are reported to the
        /// peer through <c>OnDisconnected</c>.
        /// </remarks>
        internal void Receive()
        {
            bool tryReceiveMore = true;
            while (tryReceiveMore)
            {
                int byteCount = 0;
                try
                {
                    if (nextMessageSize > 0)
                    {
                        // We already have a size value
                        tryReceiveMore = TryReceiveMessage(out byteCount);
                    }
                    else if (socket.Available >= sizeof(int))
                    {
                        // We have enough bytes for a complete size value
                        socket.Receive(sizeBytes, sizeof(int), SocketFlags.None);
                        nextMessageSize = Converter.ToInt(sizeBytes, 0);

                        if (nextMessageSize > peer.ReceiveBuffer.Length)
                        {
                            // The length prefix comes straight off the wire. Without this check, the
                            // socket.Receive call in TryReceiveMessage would throw an ArgumentOutOfRangeException
                            // (size larger than the buffer) that none of the handlers below catch. Either the
                            // stream is corrupt or the remote end is misbehaving, so drop the connection instead.
                            nextMessageSize = 0;
                            tryReceiveMore = false;
                            peer.OnDisconnected(this, DisconnectReason.TransportError);
                        }
                        else if (nextMessageSize > 0)
                        {
                            tryReceiveMore = TryReceiveMessage(out byteCount);
                        }
                    }
                    else
                    {
                        tryReceiveMore = false;
                    }
                }
                catch (SocketException ex)
                {
                    tryReceiveMore = false;
                    switch (ex.SocketErrorCode)
                    {
                        case SocketError.Interrupted:
                        case SocketError.NotSocket:
                            peer.OnDisconnected(this, DisconnectReason.TransportError);
                            break;
                        case SocketError.ConnectionReset:
                            peer.OnDisconnected(this, DisconnectReason.Disconnected);
                            break;
                        case SocketError.TimedOut:
                            peer.OnDisconnected(this, DisconnectReason.TimedOut);
                            break;
                        case SocketError.MessageSize:
                            break;
                        default:
                            // Other socket errors are not treated as a disconnect here; polling simply stops
                            // for this call.
                            break;
                    }
                }
                catch (ObjectDisposedException)
                {
                    // The socket was closed/disposed while polling.
                    tryReceiveMore = false;
                    peer.OnDisconnected(this, DisconnectReason.TransportError);
                }
                catch (NullReferenceException)
                {
                    // NOTE(review): presumably guards against the socket/peer being torn down concurrently;
                    // catching NullReferenceException is fragile, confirm which reference can be null here.
                    tryReceiveMore = false;
                    peer.OnDisconnected(this, DisconnectReason.TransportError);
                }

                if (byteCount > 0)
                    peer.OnDataReceived(byteCount, this);
            }
        }

        /// <summary>Receives a message, if all of its data is ready to be received.</summary>
        /// <param name="receivedByteCount">How many bytes were received (0 if the message was not ready).</param>
        /// <returns>Whether or not all of the message's data was ready to be received.</returns>
        /// <remarks>
        /// Callers must ensure <c>nextMessageSize</c> is positive and no larger than the peer's <c>ReceiveBuffer</c>.
        /// NOTE(review): if a message is larger than the OS socket receive buffer, <c>socket.Available</c> may never
        /// reach <c>nextMessageSize</c> and the message would never be read. Confirm buffer sizing on the peer.
        /// </remarks>
        private bool TryReceiveMessage(out int receivedByteCount)
        {
            if (socket.Available >= nextMessageSize)
            {
                // We have enough bytes to read the complete message
                receivedByteCount = socket.Receive(peer.ReceiveBuffer, nextMessageSize, SocketFlags.None);
                nextMessageSize = 0;
                return true;
            }

            receivedByteCount = 0;
            return false;
        }

        /// <summary>Closes the connection.</summary>
        internal void Close()
        {
            socket.Close();
        }

        /// <inheritdoc/>
        public override string ToString() => RemoteEndPoint.ToStringBasedOnIPFormat();

        /// <inheritdoc/>
        public override bool Equals(object obj) => Equals(obj as TcpConnection);

        /// <inheritdoc/>
        /// <remarks>Two connections are equal when their <see cref="RemoteEndPoint"/>s are equal.</remarks>
        public bool Equals(TcpConnection other)
        {
            if (other is null)
                return false;

            if (ReferenceEquals(this, other))
                return true;

            return RemoteEndPoint.Equals(other.RemoteEndPoint);
        }

        /// <inheritdoc/>
        public override int GetHashCode()
        {
            // Hash only the remote endpoint so the result agrees with Equals.
            return -288961498 + EqualityComparer<IPEndPoint>.Default.GetHashCode(RemoteEndPoint);
        }

        /// <summary>
        /// Determines whether two connections are equal: both <see langword="null"/>, or both pointing at equal
        /// remote endpoints.
        /// </summary>
        /// <param name="left">The first connection to compare.</param>
        /// <param name="right">The second connection to compare.</param>
        /// <returns>Whether <paramref name="left"/> and <paramref name="right"/> are equal.</returns>
        public static bool operator ==(TcpConnection left, TcpConnection right)
        {
            if (left is null)
                return right is null; // Only equal if both sides are null

            // Equals handles the case of null on the right side
            return left.Equals(right);
        }

        /// <summary>Determines whether two connections are not equal.</summary>
        /// <param name="left">The first connection to compare.</param>
        /// <param name="right">The second connection to compare.</param>
        /// <returns>Whether <paramref name="left"/> and <paramref name="right"/> are not equal.</returns>
        public static bool operator !=(TcpConnection left, TcpConnection right) => !(left == right);
    }
}