// This file is provided under The MIT License as part of RiptideNetworking.
// Copyright (c) Tom Weiland
// For additional information please see the included LICENSE.md file or view it on GitHub:
// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md
using Riptide.Transports;
using Riptide.Utils;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
namespace Riptide
{
/// <summary>A server that can accept connections from <see cref="Client"/>s.</summary>
public class Server : Peer
{
    /// <summary>Invoked when a client connects.</summary>
    public event EventHandler<ServerConnectedEventArgs> ClientConnected;
    /// <summary>Invoked when a connection fails to be fully established.</summary>
    public event EventHandler<ServerConnectionFailedEventArgs> ConnectionFailed;
    /// <summary>Invoked when a message is received.</summary>
    public event EventHandler<MessageReceivedEventArgs> MessageReceived;
    /// <summary>Invoked when a client disconnects.</summary>
    public event EventHandler<ServerDisconnectedEventArgs> ClientDisconnected;

    /// <summary>Whether or not the server is currently running.</summary>
    public bool IsRunning { get; private set; }
    /// <summary>The local port that the server is running on.</summary>
    public ushort Port => transport.Port;
    /// <summary>Sets the default timeout time for future connections and updates the <see cref="Connection.TimeoutTime"/> of all connected clients.</summary>
    public override int TimeoutTime
    {
        set
        {
            defaultTimeout = value;
            foreach (Connection connection in clients.Values)
                connection.TimeoutTime = defaultTimeout;
        }
    }
    /// <summary>The maximum number of concurrent connections.</summary>
    public ushort MaxClientCount { get; private set; }
    /// <summary>The number of currently connected clients.</summary>
    public int ClientCount => clients.Count;
    /// <summary>An array of all the currently connected clients.</summary>
    /// <remarks>The position of each <see cref="Connection"/> instance in the array does not correspond to that client's numeric ID (except by coincidence).</remarks>
    public Connection[] Clients => clients.Values.ToArray();
    /// <summary>Encapsulates a method that handles a message from a client.</summary>
    /// <param name="fromClientId">The numeric ID of the client from whom the message was received.</param>
    /// <param name="message">The message that was received.</param>
    public delegate void MessageHandler(ushort fromClientId, Message message);
    /// <summary>Encapsulates a method that determines whether or not to accept a client's connection attempt.</summary>
    /// <param name="pendingConnection">The connection that is waiting to be accepted or rejected.</param>
    /// <param name="connectMessage">A message containing any additional data the client included with the connection attempt.</param>
    public delegate void ConnectionAttemptHandler(Connection pendingConnection, Message connectMessage);
    /// <summary>An optional method which determines whether or not to accept a client's connection attempt.</summary>
    /// <remarks>
    /// The first parameter is the pending connection and the second parameter is a message containing any additional data the
    /// client included with the connection attempt. If you choose to subscribe a method to this delegate, you should use it to call either
    /// <see cref="Accept(Connection)"/> or <see cref="Reject(Connection, Message)"/>. Not doing so will result in the connection hanging until the client times out.
    /// </remarks>
    public ConnectionAttemptHandler HandleConnection;
    /// <summary>Stores which message IDs have auto relaying enabled. Relaying is disabled entirely when this is <see langword="null"/>.</summary>
    public MessageRelayFilter RelayFilter;

    /// <summary>Currently pending connections which are waiting to be accepted or rejected.</summary>
    private readonly List<Connection> pendingConnections;
    /// <summary>Currently connected clients, keyed by their numeric ID.</summary>
    private Dictionary<ushort, Connection> clients;
    /// <summary>Clients that have timed out and need to be removed from <see cref="clients"/>.</summary>
    private readonly List<Connection> timedOutClients;
    /// <summary>Methods used to handle messages, accessible by their corresponding message IDs.</summary>
    private Dictionary<ushort, MessageHandler> messageHandlers;
    /// <summary>The underlying transport's server that is used for sending and receiving data.</summary>
    private IServer transport;
    /// <summary>All currently unused client IDs.</summary>
    private Queue<ushort> availableClientIds;

    /// <summary>Handles initial setup.</summary>
    /// <param name="transport">The transport to use for sending and receiving data.</param>
    /// <param name="logName">The name to use when logging messages via <see cref="RiptideLogger"/>.</param>
    public Server(IServer transport, string logName = "SERVER") : base(logName)
    {
        this.transport = transport;
        pendingConnections = new List<Connection>();
        clients = new Dictionary<ushort, Connection>();
        timedOutClients = new List<Connection>();
    }

    /// <summary>Handles initial setup using the built-in UDP transport.</summary>
    /// <param name="logName">The name to use when logging messages via <see cref="RiptideLogger"/>.</param>
    public Server(string logName = "SERVER") : this(new Transports.Udp.UdpServer(), logName) { }

    /// <summary>Stops the server if it's running and swaps out the transport it's using.</summary>
    /// <param name="newTransport">The new underlying transport server to use for sending and receiving data.</param>
    /// <remarks>This method does not automatically restart the server. To continue accepting connections, <see cref="Start"/> must be called again.</remarks>
    public void ChangeTransport(IServer newTransport)
    {
        Stop();
        transport = newTransport;
    }

    /// <summary>Starts the server.</summary>
    /// <param name="port">The local port on which to start the server.</param>
    /// <param name="maxClientCount">The maximum number of concurrent connections to allow.</param>
    /// <param name="messageHandlerGroupId">The ID of the group of message handler methods to use when building <see cref="messageHandlers"/>.</param>
    /// <param name="useMessageHandlers">Whether or not the server should use the built-in message handler system.</param>
    /// <remarks>Setting <paramref name="useMessageHandlers"/> to <see langword="false"/> will disable the automatic detection and execution of methods with the <see cref="MessageHandlerAttribute"/>, which is beneficial if you prefer to handle messages via the <see cref="MessageReceived"/> event.</remarks>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxClientCount"/> is greater than <c>ushort.MaxValue - 1</c>.</exception>
    public void Start(ushort port, ushort maxClientCount, byte messageHandlerGroupId = 0, bool useMessageHandlers = true)
    {
        // Validate before changing any state. Previously this only failed inside InitializeClientIds, after
        // IncreaseActiveCount had already run, and because IsRunning never became true, Stop() could not undo it.
        if (maxClientCount > ushort.MaxValue - 1)
            throw new ArgumentOutOfRangeException(nameof(maxClientCount), $"A server's max client count may not exceed {ushort.MaxValue - 1}!");

        Stop();

        IncreaseActiveCount();
        this.useMessageHandlers = useMessageHandlers;
        if (useMessageHandlers)
            CreateMessageHandlersDictionary(messageHandlerGroupId);

        MaxClientCount = maxClientCount;
        clients = new Dictionary<ushort, Connection>(maxClientCount);
        InitializeClientIds();

        SubToTransportEvents();
        transport.Start(port);

        StartTime();
        Heartbeat();
        IsRunning = true;
        RiptideLogger.Log(LogType.Info, LogName, $"Started on port {port}.");
    }

    /// <summary>Subscribes appropriate methods to the transport's events.</summary>
    private void SubToTransportEvents()
    {
        transport.Connected += HandleConnectionAttempt;
        transport.DataReceived += HandleData;
        transport.Disconnected += TransportDisconnected;
    }

    /// <summary>Unsubscribes methods from all of the transport's events.</summary>
    private void UnsubFromTransportEvents()
    {
        transport.Connected -= HandleConnectionAttempt;
        transport.DataReceived -= HandleData;
        transport.Disconnected -= TransportDisconnected;
    }

    /// <inheritdoc/>
    protected override void CreateMessageHandlersDictionary(byte messageHandlerGroupId)
    {
        MethodInfo[] methods = FindMessageHandlers();

        messageHandlers = new Dictionary<ushort, MessageHandler>(methods.Length);
        foreach (MethodInfo method in methods)
        {
            MessageHandlerAttribute attribute = method.GetCustomAttribute<MessageHandlerAttribute>();
            if (attribute.GroupId != messageHandlerGroupId)
                continue;

            if (!method.IsStatic)
                throw new NonStaticHandlerException(method.DeclaringType, method.Name);

            Delegate serverMessageHandler = Delegate.CreateDelegate(typeof(MessageHandler), method, false);
            if (serverMessageHandler != null)
            {
                // It's a message handler for Server instances
                if (messageHandlers.TryGetValue(attribute.MessageId, out MessageHandler existingHandler))
                    throw new DuplicateHandlerException(attribute.MessageId, method, existingHandler.GetMethodInfo());

                messageHandlers.Add(attribute.MessageId, (MessageHandler)serverMessageHandler);
            }
            else
            {
                // It's not a message handler for Server instances, but it might be one for Client instances
                if (Delegate.CreateDelegate(typeof(Client.MessageHandler), method, false) == null)
                    throw new InvalidHandlerSignatureException(method.DeclaringType, method.Name);
            }
        }
    }

    /// <summary>Handles an incoming connection attempt.</summary>
    private void HandleConnectionAttempt(object _, ConnectedEventArgs e)
    {
        e.Connection.Initialize(this, defaultTimeout);
    }

    /// <summary>Handles a connect message.</summary>
    /// <param name="connection">The client that sent the connect message.</param>
    /// <param name="connectMessage">The connect message.</param>
    private void HandleConnect(Connection connection, Message connectMessage)
    {
        connection.SetPending();

        if (HandleConnection == null)
            AcceptConnection(connection);
        else if (ClientCount < MaxClientCount)
        {
            if (!clients.ContainsValue(connection) && !pendingConnections.Contains(connection))
            {
                pendingConnections.Add(connection);
                Send(Message.Create(MessageHeader.Connect), connection); // Inform the client we've received the connection attempt
                HandleConnection(connection, connectMessage); // Externally determines whether to accept
            }
            else
                Reject(connection, RejectReason.AlreadyConnected);
        }
        else
            Reject(connection, RejectReason.ServerFull);
    }

    /// <summary>Accepts the given pending connection.</summary>
    /// <param name="connection">The connection to accept.</param>
    public void Accept(Connection connection)
    {
        if (pendingConnections.Remove(connection))
            AcceptConnection(connection);
        else
            RiptideLogger.Log(LogType.Warning, LogName, $"Couldn't accept connection from {connection} because no such connection was pending!");
    }

    /// <summary>Rejects the given pending connection.</summary>
    /// <param name="connection">The connection to reject.</param>
    /// <param name="message">Data that should be sent to the client being rejected. Use <see cref="Message.Create()"/> to get an empty message instance.</param>
    public void Reject(Connection connection, Message message = null)
    {
        if (message != null && message.ReadBits != 0)
            RiptideLogger.Log(LogType.Error, LogName, $"Use the parameterless 'Message.Create()' overload when setting rejection data!");

        if (pendingConnections.Remove(connection))
            Reject(connection, message == null ? RejectReason.Rejected : RejectReason.Custom, message);
        else
            RiptideLogger.Log(LogType.Warning, LogName, $"Couldn't reject connection from {connection} because no such connection was pending!");
    }

    /// <summary>Accepts the given pending connection, assigning it a client ID and sending it the welcome message.</summary>
    /// <param name="connection">The connection to accept.</param>
    private void AcceptConnection(Connection connection)
    {
        if (ClientCount < MaxClientCount)
        {
            if (!clients.ContainsValue(connection))
            {
                ushort clientId = GetAvailableClientId();
                connection.Id = clientId;
                clients.Add(clientId, connection);
                connection.ResetTimeout();
                connection.SendWelcome();
                return;
            }
            else
                Reject(connection, RejectReason.AlreadyConnected);
        }
        else
            Reject(connection, RejectReason.ServerFull);
    }

    /// <summary>Rejects the given pending connection.</summary>
    /// <param name="connection">The connection to reject.</param>
    /// <param name="reason">The reason why the connection is being rejected.</param>
    /// <param name="rejectMessage">Data that should be sent to the client being rejected.</param>
    private void Reject(Connection connection, RejectReason reason, Message rejectMessage = null)
    {
        if (reason != RejectReason.AlreadyConnected)
        {
            // Sending a reject message about the client already being connected could theoretically be exploited to obtain information
            // on other connected clients, although in practice that seems very unlikely. However, under normal circumstances, clients
            // should never actually encounter a scenario where they are "already connected".
            Message message = Message.Create(MessageHeader.Reject);
            message.AddByte((byte)reason);
            if (reason == RejectReason.Custom)
                message.AddMessage(rejectMessage);

            for (int i = 0; i < 3; i++) // Send the rejection message a few times to increase the odds of it arriving
                connection.Send(message, false);

            message.Release();
        }

        // The connection is torn down below, so it must not stay acceptable via Accept() (this matters when
        // a still-pending connection is rejected as AlreadyConnected; for other callers Remove is a no-op)
        pendingConnections.Remove(connection);

        connection.ResetTimeout(); // Keep the connection alive for a moment so the same client can't immediately attempt to connect again
        connection.LocalDisconnect();

        RiptideLogger.Log(LogType.Info, LogName, $"Rejected connection from {connection}: {Helper.GetReasonString(reason)}.");
    }

    /// <summary>Checks if clients have timed out.</summary>
    internal override void Heartbeat()
    {
        foreach (Connection connection in clients.Values)
            if (connection.HasTimedOut)
                timedOutClients.Add(connection);

        foreach (Connection connection in pendingConnections)
            if (connection.HasConnectAttemptTimedOut)
                timedOutClients.Add(connection);

        // Disconnect outside of the loops above, since LocalDisconnect mutates the collections being iterated
        foreach (Connection connection in timedOutClients)
            LocalDisconnect(connection, DisconnectReason.TimedOut);

        timedOutClients.Clear();

        ExecuteLater(HeartbeatInterval, new HeartbeatEvent(this));
    }

    /// <inheritdoc/>
    public override void Update()
    {
        base.Update();
        transport.Poll();
        HandleMessages();
    }

    /// <inheritdoc/>
    protected override void Handle(Message message, MessageHeader header, Connection connection)
    {
        switch (header)
        {
            // User messages
            case MessageHeader.Unreliable:
            case MessageHeader.Reliable:
                OnMessageReceived(message, connection);
                break;

            // Internal messages
            case MessageHeader.Ack:
                connection.HandleAck(message);
                break;
            case MessageHeader.Connect:
                HandleConnect(connection, message);
                break;
            case MessageHeader.Heartbeat:
                connection.HandleHeartbeat(message);
                break;
            case MessageHeader.Disconnect:
                LocalDisconnect(connection, DisconnectReason.Disconnected);
                break;
            case MessageHeader.Welcome:
                if (connection.HandleWelcomeResponse(message))
                    OnClientConnected(connection);
                break;
            default:
                RiptideLogger.Log(LogType.Warning, LogName, $"Unexpected message header '{header}'! Discarding {message.BytesInUse} bytes received from {connection}.");
                break;
        }

        // Every handled message is returned to the pool exactly once, here
        message.Release();
    }

    /// <summary>Sends a message to a given client.</summary>
    /// <param name="message">The message to send.</param>
    /// <param name="toClient">The numeric ID of the client to send the message to.</param>
    /// <param name="shouldRelease">Whether or not to return the message to the pool after it is sent.</param>
    /// <remarks>If no client with the given ID is connected, nothing is sent, but the message is still released when <paramref name="shouldRelease"/> is <see langword="true"/>.</remarks>
    public void Send(Message message, ushort toClient, bool shouldRelease = true)
    {
        if (clients.TryGetValue(toClient, out Connection connection))
            Send(message, connection, shouldRelease);
        else if (shouldRelease)
            message.Release(); // Honor shouldRelease even with no recipient (as SendToAll does), so the pooled message isn't leaked
    }

    /// <summary>Sends a message to a given client.</summary>
    /// <param name="message">The message to send.</param>
    /// <param name="toClient">The client to send the message to.</param>
    /// <param name="shouldRelease">Whether or not to return the message to the pool after it is sent.</param>
    /// <returns>The value returned by <c>toClient.Send</c> (NOTE(review): presumably the message's sequence ID for reliable sends — confirm).</returns>
    public ushort Send(Message message, Connection toClient, bool shouldRelease = true) => toClient.Send(message, shouldRelease);

    /// <summary>Sends a message to all connected clients.</summary>
    /// <param name="message">The message to send.</param>
    /// <param name="shouldRelease">Whether or not to return the message to the pool after it is sent.</param>
    public void SendToAll(Message message, bool shouldRelease = true)
    {
        foreach (Connection client in clients.Values)
            client.Send(message, false);

        if (shouldRelease)
            message.Release();
    }

    /// <summary>Sends a message to all connected clients except the given one.</summary>
    /// <param name="message">The message to send.</param>
    /// <param name="exceptToClientId">The numeric ID of the client to not send the message to.</param>
    /// <param name="shouldRelease">Whether or not to return the message to the pool after it is sent.</param>
    public void SendToAll(Message message, ushort exceptToClientId, bool shouldRelease = true)
    {
        foreach (Connection client in clients.Values)
            if (client.Id != exceptToClientId)
                client.Send(message, false);

        if (shouldRelease)
            message.Release();
    }

    /// <summary>Retrieves the client with the given ID, if a client with that ID is currently connected.</summary>
    /// <param name="id">The ID of the client to retrieve.</param>
    /// <param name="client">The retrieved client.</param>
    /// <returns><see langword="true"/> if a client with the given ID was connected; otherwise <see langword="false"/>.</returns>
    public bool TryGetClient(ushort id, out Connection client) => clients.TryGetValue(id, out client);

    /// <summary>Disconnects a specific client.</summary>
    /// <param name="id">The numeric ID of the client to disconnect.</param>
    /// <param name="message">Data that should be sent to the client being disconnected. Use <see cref="Message.Create()"/> to get an empty message instance.</param>
    public void DisconnectClient(ushort id, Message message = null)
    {
        if (message != null && message.ReadBits != 0)
            RiptideLogger.Log(LogType.Error, LogName, $"Use the parameterless 'Message.Create()' overload when setting disconnection data!");

        if (clients.TryGetValue(id, out Connection client))
        {
            SendDisconnect(client, DisconnectReason.Kicked, message);
            LocalDisconnect(client, DisconnectReason.Kicked);
        }
        else
            RiptideLogger.Log(LogType.Warning, LogName, $"Couldn't disconnect client {id} because it wasn't connected!");
    }

    /// <summary>Disconnects the given client.</summary>
    /// <param name="client">The client to disconnect.</param>
    /// <param name="message">Data that should be sent to the client being disconnected. Use <see cref="Message.Create()"/> to get an empty message instance.</param>
    public void DisconnectClient(Connection client, Message message = null)
    {
        if (message != null && message.ReadBits != 0)
            RiptideLogger.Log(LogType.Error, LogName, $"Use the parameterless 'Message.Create()' overload when setting disconnection data!");

        // Compare by identity, not just by ID: a stale connection whose ID has been reassigned to another client
        // must not cause that other client to be kicked
        if (clients.TryGetValue(client.Id, out Connection connected) && ReferenceEquals(connected, client))
        {
            SendDisconnect(client, DisconnectReason.Kicked, message);
            LocalDisconnect(client, DisconnectReason.Kicked);
        }
        else
            RiptideLogger.Log(LogType.Warning, LogName, $"Couldn't disconnect client {client.Id} because it wasn't connected!");
    }

    /// <inheritdoc/>
    internal override void Disconnect(Connection connection, DisconnectReason reason)
    {
        if (connection.IsConnected && connection.CanQualityDisconnect)
            LocalDisconnect(connection, reason);
    }

    /// <summary>Cleans up the local side of the given connection.</summary>
    /// <param name="client">The client to disconnect.</param>
    /// <param name="reason">The reason why the client is being disconnected.</param>
    private void LocalDisconnect(Connection client, DisconnectReason reason)
    {
        if (client.Peer != this)
            return; // Client does not belong to this Server instance

        transport.Close(client);

        // A connection that drops or times out while still awaiting Accept/Reject must leave the pending list;
        // otherwise Accept() could still "accept" it later and Heartbeat may keep picking it up again
        pendingConnections.Remove(client);

        // Only free the ID if this exact connection holds it, so a stale connection object whose ID was reassigned
        // can't evict the new client or enqueue the same ID twice
        if (clients.TryGetValue(client.Id, out Connection registered) && ReferenceEquals(registered, client))
        {
            clients.Remove(client.Id);
            availableClientIds.Enqueue(client.Id);
        }

        if (client.IsConnected)
            OnClientDisconnected(client, reason); // Only run if the client was ever actually connected
        else if (client.IsPending)
            OnConnectionFailed(client);

        client.LocalDisconnect();
    }

    /// <summary>What to do when the transport disconnects a client.</summary>
    private void TransportDisconnected(object sender, Transports.DisconnectedEventArgs e)
    {
        LocalDisconnect(e.Connection, e.Reason);
    }

    /// <summary>Stops the server.</summary>
    public void Stop()
    {
        if (!IsRunning)
            return;

        pendingConnections.Clear();
        byte[] disconnectBytes = { (byte)MessageHeader.Disconnect, (byte)DisconnectReason.ServerStopped };
        foreach (Connection client in clients.Values)
            client.Send(disconnectBytes, disconnectBytes.Length);
        clients.Clear();

        transport.Shutdown();
        UnsubFromTransportEvents();

        DecreaseActiveCount();

        StopTime();
        IsRunning = false;
        RiptideLogger.Log(LogType.Info, LogName, "Server stopped.");
    }

    /// <summary>Initializes available client IDs.</summary>
    /// <remarks>IDs run from 1 to <see cref="MaxClientCount"/>, so 0 is never handed out by the queue.</remarks>
    private void InitializeClientIds()
    {
        // Guards the loop below: with MaxClientCount == ushort.MaxValue, "i <= MaxClientCount" would never become false
        if (MaxClientCount > ushort.MaxValue - 1)
            throw new Exception($"A server's max client count may not exceed {ushort.MaxValue - 1}!");

        availableClientIds = new Queue<ushort>(MaxClientCount);
        for (ushort i = 1; i <= MaxClientCount; i++)
            availableClientIds.Enqueue(i);
    }

    /// <summary>Retrieves an available client ID.</summary>
    /// <returns>The client ID. 0 if none were available.</returns>
    private ushort GetAvailableClientId()
    {
        if (availableClientIds.Count > 0)
            return availableClientIds.Dequeue();

        RiptideLogger.Log(LogType.Error, LogName, "No available client IDs, assigned 0!");
        return 0;
    }

    #region Messages
    /// <summary>Sends a disconnect message.</summary>
    /// <param name="client">The client to send the disconnect message to.</param>
    /// <param name="reason">Why the client is being disconnected.</param>
    /// <param name="disconnectMessage">Optional custom data that should be sent to the client being disconnected.</param>
    private void SendDisconnect(Connection client, DisconnectReason reason, Message disconnectMessage)
    {
        Message message = Message.Create(MessageHeader.Disconnect);
        message.AddByte((byte)reason);

        if (reason == DisconnectReason.Kicked && disconnectMessage != null)
            message.AddMessage(disconnectMessage);

        Send(message, client);
    }

    /// <summary>Sends a client connected message.</summary>
    /// <param name="newClient">The newly connected client.</param>
    private void SendClientConnected(Connection newClient)
    {
        Message message = Message.Create(MessageHeader.ClientConnected);
        message.AddUShort(newClient.Id);

        SendToAll(message, newClient.Id);
    }

    /// <summary>Sends a client disconnected message.</summary>
    /// <param name="id">The numeric ID of the client that disconnected.</param>
    private void SendClientDisconnected(ushort id)
    {
        Message message = Message.Create(MessageHeader.ClientDisconnected);
        message.AddUShort(id);

        SendToAll(message);
    }
    #endregion

    #region Events
    /// <summary>Invokes the <see cref="ClientConnected"/> event.</summary>
    /// <param name="client">The newly connected client.</param>
    protected virtual void OnClientConnected(Connection client)
    {
        RiptideLogger.Log(LogType.Info, LogName, $"Client {client.Id} ({client}) connected successfully!");
        SendClientConnected(client);
        ClientConnected?.Invoke(this, new ServerConnectedEventArgs(client));
    }

    /// <summary>Invokes the <see cref="ConnectionFailed"/> event.</summary>
    /// <param name="connection">The connection that failed to be fully established.</param>
    protected virtual void OnConnectionFailed(Connection connection)
    {
        RiptideLogger.Log(LogType.Info, LogName, $"Client {connection} stopped responding before the connection was fully established!");
        ConnectionFailed?.Invoke(this, new ServerConnectionFailedEventArgs(connection));
    }

    /// <summary>Invokes the <see cref="MessageReceived"/> event and initiates handling of the received message.</summary>
    /// <param name="message">The received message.</param>
    /// <param name="fromConnection">The client from which the message was received.</param>
    /// <remarks>The message is not released here; the caller (<see cref="Handle"/>) releases it after this returns.</remarks>
    protected virtual void OnMessageReceived(Message message, Connection fromConnection)
    {
        ushort messageId = (ushort)message.GetVarULong();
        if (RelayFilter != null && RelayFilter.ShouldRelay(messageId))
        {
            // The message should be automatically relayed to clients instead of being handled on the server.
            // Don't release it here: Handle releases it once this returns, and releasing twice would put the
            // same instance back into the pool twice.
            SendToAll(message, fromConnection.Id, false);
            return;
        }

        MessageReceived?.Invoke(this, new MessageReceivedEventArgs(fromConnection, messageId, message));

        if (useMessageHandlers)
        {
            if (messageHandlers.TryGetValue(messageId, out MessageHandler messageHandler))
                messageHandler(fromConnection.Id, message);
            else
                RiptideLogger.Log(LogType.Warning, LogName, $"No message handler method found for message ID {messageId}!");
        }
    }

    /// <summary>Invokes the <see cref="ClientDisconnected"/> event.</summary>
    /// <param name="connection">The client that disconnected.</param>
    /// <param name="reason">The reason for the disconnection.</param>
    protected virtual void OnClientDisconnected(Connection connection, DisconnectReason reason)
    {
        RiptideLogger.Log(LogType.Info, LogName, $"Client {connection.Id} ({connection}) disconnected: {Helper.GetReasonString(reason)}.");
        SendClientDisconnected(connection.Id);
        ClientDisconnected?.Invoke(this, new ServerDisconnectedEventArgs(connection, reason));
    }
    #endregion
}
}