// This file is provided under The MIT License as part of RiptideNetworking.
// Copyright (c) Tom Weiland
// For additional information please see the included LICENSE.md file or view it on GitHub:
// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md
using Riptide.Utils;
using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;
namespace Riptide.Transports.Tcp
{
/// <summary>Represents a TCP connection between the local <see cref="TcpPeer"/> and a remote endpoint.</summary>
public class TcpConnection : Connection, IEquatable<TcpConnection>
{
    /// <summary>The endpoint representing the other end of the connection.</summary>
    public readonly IPEndPoint RemoteEndPoint;

    /// <summary>Whether or not the server has received a connection attempt from this connection.</summary>
    internal bool DidReceiveConnect;

    /// <summary>The socket to use for sending and receiving.</summary>
    private readonly Socket socket;
    /// <summary>The local peer this connection is associated with.</summary>
    private readonly TcpPeer peer;
    /// <summary>An array to receive message size values into.</summary>
    private readonly byte[] sizeBytes = new byte[sizeof(int)];
    /// <summary>The size of the next message to be received, or 0 if no size value is currently pending.</summary>
    private int nextMessageSize;

    /// <summary>Initializes the connection.</summary>
    /// <param name="socket">The socket to use for sending and receiving.</param>
    /// <param name="remoteEndPoint">The endpoint representing the other end of the connection.</param>
    /// <param name="peer">The local peer this connection is associated with.</param>
    internal TcpConnection(Socket socket, IPEndPoint remoteEndPoint, TcpPeer peer)
    {
        RemoteEndPoint = remoteEndPoint;
        this.socket = socket;
        this.peer = peer;
    }

    /// <inheritdoc/>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="amount"/> is 0.</exception>
    protected internal override void Send(byte[] dataBuffer, int amount)
    {
        if (amount == 0)
            throw new ArgumentOutOfRangeException(nameof(amount), "Sending 0 bytes is not allowed!");

        try
        {
            if (socket.Connected)
            {
                // Each message is framed as a 4-byte size value followed by the payload, so the receiving
                // end knows how many bytes belong to this message.
                Converter.FromInt(amount, peer.SendBuffer, 0);
                Array.Copy(dataBuffer, 0, peer.SendBuffer, sizeof(int), amount); // TODO: consider sending length separately with an extra socket.Send call instead of copying the data an extra time
                socket.Send(peer.SendBuffer, amount + sizeof(int), SocketFlags.None);
            }
        }
        catch (SocketException)
        {
            // Deliberately best-effort: send failures are swallowed here.
            // May want to consider triggering a disconnect here (perhaps depending on the type
            // of SocketException)? Timeout should catch disconnections, but disconnecting
            // explicitly might be better...
        }
    }

    /// <summary>Polls the socket and checks if any data was received.</summary>
    /// <remarks>
    /// Keeps reading for as long as complete messages are available. A size value which is negative or which
    /// exceeds the capacity of the peer's receive buffer is treated as a transport error and disconnects the
    /// connection, as the message could never be received into that buffer.
    /// </remarks>
    internal void Receive()
    {
        bool tryReceiveMore = true;
        while (tryReceiveMore)
        {
            int byteCount = 0;
            bool receivedInvalidSize = false;
            try
            {
                if (nextMessageSize > 0)
                {
                    // We already have a size value
                    tryReceiveMore = TryReceiveMessage(out byteCount);
                }
                else if (socket.Available >= sizeof(int))
                {
                    // We have enough bytes for a complete size value
                    socket.Receive(sizeBytes, sizeof(int), SocketFlags.None);
                    nextMessageSize = Converter.ToInt(sizeBytes, 0);

                    if (nextMessageSize < 0 || nextMessageSize > peer.ReceiveBuffer.Length)
                    {
                        // The size value comes straight off the wire and can't be trusted. Passing a size larger
                        // than the buffer to Socket.Receive would throw an ArgumentOutOfRangeException, and a
                        // negative size means the stream is no longer in sync, so give up on this connection.
                        nextMessageSize = 0;
                        tryReceiveMore = false;
                        receivedInvalidSize = true;
                    }
                    else if (nextMessageSize > 0)
                        tryReceiveMore = TryReceiveMessage(out byteCount);
                    // A size of exactly 0 is skipped, and the loop goes on to look for the next size value
                }
                else
                    tryReceiveMore = false;
            }
            catch (SocketException ex)
            {
                tryReceiveMore = false;
                switch (ex.SocketErrorCode)
                {
                    case SocketError.Interrupted:
                    case SocketError.NotSocket:
                        peer.OnDisconnected(this, DisconnectReason.TransportError);
                        break;
                    case SocketError.ConnectionReset:
                        peer.OnDisconnected(this, DisconnectReason.Disconnected);
                        break;
                    case SocketError.TimedOut:
                        peer.OnDisconnected(this, DisconnectReason.TimedOut);
                        break;
                    case SocketError.MessageSize:
                        break;
                    default:
                        break;
                }
            }
            catch (ObjectDisposedException)
            {
                tryReceiveMore = false;
                peer.OnDisconnected(this, DisconnectReason.TransportError);
            }
            catch (NullReferenceException)
            {
                tryReceiveMore = false;
                peer.OnDisconnected(this, DisconnectReason.TransportError);
            }

            // Raised outside of the try block so that exceptions thrown by the disconnect handling aren't
            // mistaken for socket failures and handled a second time by the catch clauses above
            if (receivedInvalidSize)
                peer.OnDisconnected(this, DisconnectReason.TransportError);

            if (byteCount > 0)
                peer.OnDataReceived(byteCount, this);
        }
    }

    /// <summary>Receives a message, if all of its data is ready to be received.</summary>
    /// <param name="receivedByteCount">How many bytes were received.</param>
    /// <returns>Whether or not all of the message's data was ready to be received.</returns>
    private bool TryReceiveMessage(out int receivedByteCount)
    {
        if (socket.Available >= nextMessageSize)
        {
            // We have enough bytes to read the complete message
            receivedByteCount = socket.Receive(peer.ReceiveBuffer, nextMessageSize, SocketFlags.None);
            nextMessageSize = 0; // Reset so that the next read starts with a fresh size value
            return true;
        }

        // Leave nextMessageSize untouched so that the next poll retries this same message
        receivedByteCount = 0;
        return false;
    }

    /// <summary>Closes the connection.</summary>
    internal void Close()
    {
        socket.Close();
    }

    /// <inheritdoc/>
    public override string ToString() => RemoteEndPoint.ToStringBasedOnIPFormat();

    /// <inheritdoc/>
    public override bool Equals(object obj) => Equals(obj as TcpConnection);

    /// <inheritdoc/>
    public bool Equals(TcpConnection other)
    {
        if (other is null)
            return false;

        if (ReferenceEquals(this, other))
            return true;

        // Two connections are considered equal when they point at the same remote endpoint
        return RemoteEndPoint.Equals(other.RemoteEndPoint);
    }

    /// <inheritdoc/>
    public override int GetHashCode()
    {
        // Derived from the remote endpoint only, in line with Equals
        return -288961498 + EqualityComparer<IPEndPoint>.Default.GetHashCode(RemoteEndPoint);
    }

#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
    public static bool operator ==(TcpConnection left, TcpConnection right)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
    {
        if (left is null)
        {
            if (right is null)
                return true;

            return false; // Only the left side is null
        }

        // Equals handles case of null on right side
        return left.Equals(right);
    }

#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
    public static bool operator !=(TcpConnection left, TcpConnection right) => !(left == right);
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
}
}