diff --git a/Attributes/PacketHandlerAttribute.cs b/Attributes/PacketHandlerAttribute.cs new file mode 100644 index 0000000..d7e4e6c --- /dev/null +++ b/Attributes/PacketHandlerAttribute.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Attributes +{ + [AttributeUsage(AttributeTargets.Method, AllowMultiple = false)] + public class PacketHandlerAttribute : Attribute + { + public ushort packetId; + + public PacketHandlerAttribute(ushort packetId) + { + this.packetId = packetId; + } + } +} diff --git a/Attributes/TransmittableAttribute.cs b/Attributes/TransmittableAttribute.cs new file mode 100644 index 0000000..6a01ef8 --- /dev/null +++ b/Attributes/TransmittableAttribute.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Attributes +{ + [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field, AllowMultiple = false)] + public class TransmittableAttribute : Attribute + { + + } +} diff --git a/Constants.cs b/Constants.cs new file mode 100644 index 0000000..17b54d4 --- /dev/null +++ b/Constants.cs @@ -0,0 +1,34 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; + +namespace KCM +{ + /// + /// Kingdoms and Castles "Contants" (World instances, etc) + /// + /// _T is a Trasnform + /// _O is a GameObject + /// + public static class Constants + { + public static readonly MainMenuMode MainMenuMode = GameState.inst.mainMenuMode; + public static readonly PlayingMode PlayingMode = GameState.inst.playingMode; + public static readonly World World = GameState.inst.world; + + #region "UI" + public static readonly Transform MainMenuUI_T = MainMenuMode.mainMenuUI.transform; + public static readonly GameObject MainMenuUI_O = MainMenuMode.mainMenuUI; + + /* public static readonly Transform TopLevelUI_T = MainMenuUI_T.parent; + public static readonly GameObject TopLevelUI_O = MainMenuUI_T.parent.gameObject;*/ + + public static readonly Transform ChooseModeUI_T = MainMenuMode.chooseModeUI.transform; + public static readonly GameObject ChooseModeUI_O = MainMenuMode.chooseModeUI; + #endregion + + } +} diff --git a/Enums/Difficulty.cs b/Enums/Difficulty.cs new file mode 100644 index 0000000..4296212 --- /dev/null +++ b/Enums/Difficulty.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Enums +{ + public enum Difficulty + { + Paxlon, + Sommern, + Vintar, + Falle + } +} diff --git a/Enums/MenuState.cs b/Enums/MenuState.cs new file mode 100644 index 0000000..a272292 --- /dev/null +++ b/Enums/MenuState.cs @@ -0,0 +1,59 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Enums +{ + public enum MenuState + { + Uninitialized, + // Token: 0x040022F1 RID: 8945 + Menu, + // Token: 0x040022F2 RID: 8946 + ChooseMode, + // Token: 0x040022F3 RID: 8947 + ChooseDifficulty, + // Token: 0x040022F4 RID: 8948 + NewMap, + // Token: 0x040022F5 RID: 8949 + NameAndBanner, + // Token: 0x040022F6 RID: 8950 + PauseMenu, + // Token: 0x040022F7 RID: 8951 + SettingsMenu, + // Token: 0x040022F8 RID: 8952 + Save, + // Token: 0x040022F9 RID: 8953 + Load, + // Token: 0x040022FA RID: 8954 + QuitConfirm, + // Token: 0x040022FB RID: 8955 + ExitConfirm, + // Token: 0x040022FC RID: 8956 + LoadError, + // Token: 0x040022FD RID: 8957 + SendSave, + // Token: 0x040022FE RID: 8958 + Credits, + // Token: 0x040022FF RID: 8959 + Failure, + // Token: 0x04002300 RID: 8960 + KeepDestroyed, + // Token: 0x04002301 RID: 8961 + BannerSelect, + // Token: 0x04002302 RID: 8962 + GameWorkshopUI, + // Token: 0x04002303 RID: 8963 + RivalChoiceUI, + // Token: 0x04002304 RID: 8964 + KingdomShareFromMenu, + // Token: 0x04002305 RID: 8965 + KingdomShareFromGame, + + ServerBrowser, + + ServerLobby + } +} diff --git a/Enums/Packets.cs b/Enums/Packets.cs new file mode 100644 index 0000000..f7da401 --- /dev/null +++ b/Enums/Packets.cs @@ -0,0 +1,49 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Enums +{ + public enum Packets + { + ClientConnected = 25, + PlayerList = 26, + ChatSystemMessage = 27, + ChatMessage = 28, + ServerSettings = 29, + PlayerReady = 30, + PlayerBanner = 31, + KingdomName = 32, + StartGame = 33, + WorldSeed = 34, + + + Building = 50, + BuildingOnPlacement = 51, + + World = 70, + WorldPlace = 71, + FellTree = 72, + ShakeTree = 73, + GrowTree = 74, + UpdateConstruction = 75, + SetSpeed = 76, + CompleteBuild = 77, + WorldPlaceBatch = 78, + ChangeWeather = 79, + ShowModal = 80, + ServerHandshake = 81, + SpawnSiegeDragon = 82, + SpawnMamaDragon = 83, + SpawnBabyDragon = 84, + SaveTransferPacket = 85, + UpdateState = 86, + BuildingStatePacket = 87, + AddVillager = 88, + SetupInitialWorkers = 89, + VillagerTeleportTo = 90, + PlaceKeepRandomly = 91 + } +} diff --git a/ErrorCodeMessages.cs b/ErrorCodeMessages.cs new file mode 100644 index 0000000..d19e6e3 --- /dev/null +++ b/ErrorCodeMessages.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM +{ + public static class ErrorCodeMessages + { + private static readonly Dictionary errorMessages = new Dictionary + { + { "TimedOut", "Your connectioned has timed out." }, + { "Disconnected", "Lost connection to server." }, + { "NoConnection", "Failed to connect to server." }, + }; + + public static string GetMessage(Enum errorCode) + { + if (errorMessages.TryGetValue(errorCode.ToString(), out string message)) + { + return message; + } + return errorCode.ToString(); // Fallback message + } + } +} diff --git a/KCClient.cs b/KCClient.cs new file mode 100644 index 0000000..fa52f44 --- /dev/null +++ b/KCClient.cs @@ -0,0 +1,109 @@ +using Harmony; +using KCM.Enums; +using KCM.Packets; +using KCM.Packets.Handlers; +using KCM.Packets.Lobby; +using KCM.Packets.Network; +using Riptide; +using Steamworks; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; +using static KCM.KCServer; + +namespace KCM +{ + public class KCClient : MonoBehaviour + { + public static Client client = new Client(Main.steamClient); + + public string Name { get; set; } + + public static KCClient inst { get; set; } + + + static KCClient() + { + client.Connected += Client_Connected; + client.ConnectionFailed += Client_ConnectionFailed; + client.Disconnected += Client_Disconnected; + client.MessageReceived += PacketHandler.HandlePacket; + } + + private static void Client_Disconnected(object sender, DisconnectedEventArgs e) + { + Main.helper.Log("Client disconnected event start"); + try + { + if (e.Message != null) + { + Main.helper.Log(e.Message.ToString()); + MessageReceivedEventArgs eargs = new MessageReceivedEventArgs(null, (ushort)Enums.Packets.ShowModal, e.Message); + + if (eargs.MessageId == (ushort)Enums.Packets.ShowModal) + { + ShowModal modalPacket = (ShowModal)PacketHandler.DeserialisePacket(eargs); + + modalPacket.HandlePacketClient(); + } + } + else + { + + GameState.inst.SetNewMode(GameState.inst.mainMenuMode); + ModalManager.ShowModal("Disconnected from Server", ErrorCodeMessages.GetMessage(e.Reason), "Okay", true, () => { Main.TransitionTo(MenuState.ServerBrowser); }); + } + + } + catch (Exception ex) + { + Main.helper.Log("Error handling disconnection message"); + Main.helper.Log(ex.ToString()); + } + Main.helper.Log("Client disconnected event end"); + } + + private static void Client_ConnectionFailed(object sender, ConnectionFailedEventArgs e) + { + Main.helper.Log($"Connection failed: {e.Reason}"); + + ModalManager.ShowModal("Failed to connect", ErrorCodeMessages.GetMessage(e.Reason)); + } + + private static void Client_Connected(object sender, EventArgs e) + { + + + } + + public KCClient(string name) + { + Name = name; + } + + public static void Connect(string ip) + { + Main.helper.Log("Trying to connect to: " + ip); + client.Connect(ip, useMessageHandlers: false); + } + + private void Update() + { + client.Update(); + } + + private void Preload(KCModHelper helper) + { + + helper.Log("Preload run in client"); + } + + private void SceneLoaded(KCModHelper helper) + { + } + } +} diff --git a/KCPlayer.cs b/KCPlayer.cs new file mode 100644 index 0000000..75d7793 --- /dev/null +++ b/KCPlayer.cs @@ -0,0 +1,91 @@ +using KCM.Attributes; +using Riptide; +using Steamworks; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Numerics; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; + +namespace KCM +{ + public class KCPlayer + { + [Transmittable] + public ushort id; + [Transmittable] + public string name; + + public string steamId; + + + [Transmittable] + public string kingdomName; + [Transmittable] + public int banner = 0; + [Transmittable] + public bool ready = false; + + public Player inst; + public GameObject gameObject; + + + public KCPlayer(string name, ushort id, string steamId) + { + if (id != KCClient.client.Id) + { + gameObject = new GameObject($"Client Player ({id} {name})"); + + inst = gameObject.AddComponent(); + var irrigation = gameObject.AddComponent(); + var lmo = gameObject.AddComponent(); + + inst.irrigation = irrigation; + + inst.PlayerLandmassOwner = lmo; + inst.PlayerLandmassOwner.teamId = id * 10 + 2; + + inst.hazardPayWarmup = new Timer(5f); + inst.hazardPayWarmup.Enabled = false; + + bool[] flagsArr = new bool[38]; + for (int i = 0; i < flagsArr.Length; i++) + flagsArr[i] = true; + + var field = typeof(Player).GetField("defaultEnabledFlags", BindingFlags.NonPublic | BindingFlags.Instance); + field.SetValue(inst, flagsArr); + + + + Player oldPlayer = Player.inst; + Player.inst = inst; + + inst.Reset(); + + Player.inst = oldPlayer; + } + else + { + + gameObject = Player.inst.gameObject; + inst = Player.inst; + } + + this.name = name; + this.id = id; + this.steamId = steamId; + this.kingdomName = " "; + } + + public KCPlayer(ushort id, Player player) + { + gameObject = player.gameObject; + inst = player; + + this.id = id; + } + } +} diff --git a/KCServer.cs b/KCServer.cs new file mode 100644 index 0000000..0648689 --- /dev/null +++ b/KCServer.cs @@ -0,0 +1,121 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; +using Riptide; +using Harmony; +using System.Reflection; +using KCM.Packets.Handlers; +using KCM.Packets.Lobby; +using KCM.ServerLobby; +using KCM.Packets; +using KCM.Packets.Network; +using Riptide.Demos.Steam.PlayerHosted; + +namespace KCM +{ + public class KCServer : MonoBehaviour + { + public static Server server = new Server(Main.steamServer); + public static bool started = false; + + static KCServer() + { + //server.registerMessageHandler(typeof(KCServer).GetMethod("ClientJoined")); + + server.MessageReceived += PacketHandler.HandlePacketServer; + } + + public static void StartServer() + { + server = new Server(Main.steamServer); + server.MessageReceived += PacketHandler.HandlePacketServer; + + server.Start(0, 25, useMessageHandlers: false); + + server.ClientConnected += (obj, ev) => + { + Main.helper.Log("Client connected"); + + if (server.ClientCount > LobbyHandler.ServerSettings.MaxPlayers) + { + ShowModal showModal = new ShowModal() { title = "Failed to connect", message = "Server is full." }; + + showModal.Send(ev.Client.Id); + + server.DisconnectClient(ev.Client.Id); //, PacketHandler.SerialisePacket(showModal) + return; + } + + ev.Client.CanQualityDisconnect = false; + + Main.helper.Log("Client ID is: " + ev.Client.Id); + + new ServerHandshake() { clientId = ev.Client.Id, loadingSave = LobbyManager.loadingSave }.Send(ev.Client.Id); + }; + + server.ClientDisconnected += (obj, ev) => + { + new ChatSystemMessage() + { + Message = $"{Main.GetPlayerByClientID(ev.Client.Id).name} has left the server.", + }.SendToAll(); + + Main.kCPlayers.Remove(Main.GetPlayerByClientID(ev.Client.Id).steamId); + Destroy(LobbyHandler.playerEntries.Select(x => x.GetComponent()).Where(x => x.Client == ev.Client.Id).FirstOrDefault().gameObject); + + Main.helper.Log($"Client disconnected. {ev.Reason}"); + }; + + Main.helper.Log($"Listening on port 7777. Max {LobbyHandler.ServerSettings.MaxPlayers} clients."); + + + //Main.kCPlayers.Add(1, new KCPlayer(1, Player.inst)); + + //Player.inst = Main.GetPlayer(); + } + + /*[MessageHandler(25)] + public static void ClientJoined(ushort id, Message message) + { + var name = message.GetString(); + + Main.helper.Log(id.ToString()); + Main.helper.Log($"User connected: {name}"); + + if (id == 1) + { + players.Add(id, new KCPlayer(name, id, Player.inst)); + } + else + { + players.Add(id, new KCPlayer(name, id)); + } + }*/ + + public static bool IsRunning { get { return server.IsRunning; } } + + private void Update() + { + server.Update(); + } + + private void OnApplicationQuit() + { + server.Stop(); + } + + private void Preload(KCModHelper helper) + { + helper.Log("server?"); + + helper.Log("Preload run in server"); + } + + private void SceneLoaded(KCModHelper helper) + { + } + } +} diff --git a/LoadSaveOverrides/MultiplayerSaveContainer.cs b/LoadSaveOverrides/MultiplayerSaveContainer.cs new file mode 100644 index 0000000..918f003 --- /dev/null +++ b/LoadSaveOverrides/MultiplayerSaveContainer.cs @@ -0,0 +1,261 @@ +using Assets.Code; +using Riptide; +using Riptide.Transports; +using Steamworks; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.LoadSaveOverrides +{ + [Serializable] + public class MultiplayerSaveContainer : LoadSaveContainer + { + public Dictionary players = new Dictionary(); + public Dictionary kingdomNames = new Dictionary(); + + public new MultiplayerSaveContainer Pack(object obj) + { + this.CameraSaveData = new Cam.CamSaveData().Pack(Cam.inst); + this.TownNameSaveData = new TownNameUI.TownNameSaveData().Pack(TownNameUI.inst); + + Main.helper.Log($"Saving data for {Main.kCPlayers.Count} ({KCServer.server.ClientCount}) players."); + + //this.PlayerSaveData = new PlayerSaveDataOverride().Pack(Player.inst); + foreach (var player in Main.kCPlayers.Values) + { + Main.helper.Log($"Attempting to pack data for: " + player.name + $"({player.steamId})"); + Main.helper.Log($"{player.inst.ToString()} {player.inst?.gameObject.name}"); + this.players.Add(player.steamId, new Player.PlayerSaveData().Pack(player.inst)); + kingdomNames.Add(player.steamId, player.kingdomName); + + Main.helper.Log($"{players[player.steamId] == null}"); + } + + this.WorldSaveData = new World.WorldSaveData().Pack(World.inst); + this.FishSystemSaveData = new FishSystem.FishSystemSaveData().Pack(FishSystem.inst); + this.JobSystemSaveData = new JobSystem.JobSystemSaveData().Pack(JobSystem.inst); + this.FreeResourceManagerSaveData = new FreeResourceManager.FreeResourceManagerSaveData().Pack(FreeResourceManager.inst); + this.WeatherSaveData = new Weather.WeatherSaveData().Pack(Weather.inst); + this.FireManagerSaveData = new FireManager.FireManagerSaveData().Pack(FireManager.inst); + this.DragonSpawnSaveData = new DragonSpawn.DragonSpawnSaveData().Pack(DragonSpawn.inst); + this.UnitSystemSaveData = new UnitSystem.UnitSystemSaveData().Pack(UnitSystem.inst); + this.RaidSystemSaveData2 = new RaiderSystem.RaiderSystemSaveData2().Pack(RaiderSystem.inst); + this.ShipSystemSaveData = new ShipSystem.ShipSystemSaveData().Pack(ShipSystem.inst); + this.AIBrainsSaveData = new AIBrainsContainer.SaveData().Pack(AIBrainsContainer.inst); + this.SiegeMonsterSaveData = new SiegeMonster.SiegeMonsterSaveData().Pack(null); + this.CartSystemSaveData = new CartSystem.CartSystemSaveData().Pack(CartSystem.inst); + this.SiegeCatapultSystemSaveData = new SiegeCatapultSystem.SiegeCatapultSystemSaveData().Pack(SiegeCatapultSystem.inst); + this.OrdersManagerSaveData = new OrdersManager.OrdersManagerSaveData().Pack(OrdersManager.inst); + this.CustomSaveData = LoadSave.CustomSaveData_DontAccessDirectly; + + return this; + } + + public override object Unpack(object obj) + { + //original Player reset was up here + foreach (var kvp in players) + { + + KCPlayer player; + + if (!Main.kCPlayers.TryGetValue(kvp.Key, out player)) + { + player = new KCPlayer("", 50, kvp.Key); + player.kingdomName = kingdomNames[kvp.Key]; + + Main.kCPlayers.Add(kvp.Key, player); + } + } + + foreach (var player in Main.kCPlayers.Values) + player.inst.Reset(); + + + AIBrainsContainer.inst.ClearAIs(); + this.CameraSaveData.Unpack(Cam.inst); + this.WorldSaveData.Unpack(World.inst); + + bool flag = this.FishSystemSaveData != null; + if (flag) + { + this.FishSystemSaveData.Unpack(FishSystem.inst); + } + this.TownNameSaveData.Unpack(TownNameUI.inst); + + + //TownNameUI.inst.townName = kingdomNames[Main.PlayerSteamID]; + TownNameUI.inst.SetTownName(kingdomNames[Main.PlayerSteamID]); + + Main.helper.Log("Unpacking player data"); + + Player.PlayerSaveData clientPlayerData = null; + + foreach (var kvp in players) + { + if (kvp.Key == SteamUser.GetSteamID().ToString()) + { + Main.helper.Log("Found current client player data. ID: " + SteamUser.GetSteamID().ToString()); + + clientPlayerData = kvp.Value; + } + else + { // Maybe ?? + Main.helper.Log("Loading player data: " + kvp.Key); + + + KCPlayer player; + + if (!Main.kCPlayers.TryGetValue(kvp.Key, out player)) + { + player = new KCPlayer("", 50, kvp.Key); + Main.kCPlayers.Add(kvp.Key, player); + } + + Player oldPlayer = Player.inst; + Player.inst = player.inst; + Main.helper.Log($"Number of landmasses: {World.inst.NumLandMasses}"); + + //Reset was here before unpack + kvp.Value.Unpack(player.inst); + + Player.inst = oldPlayer; + + + player.banner = player.inst.PlayerLandmassOwner.bannerIdx; + player.kingdomName = TownNameUI.inst.townName; + } + } + + clientPlayerData.Unpack(Player.inst); // Unpack the current client player last so that loading of villagers works correctly. + + Main.helper.Log("unpacked player data"); + Main.helper.Log("Setting banner and name"); + + var client = Main.kCPlayers[SteamUser.GetSteamID().ToString()]; + + + client.banner = Player.inst.PlayerLandmassOwner.bannerIdx; + client.kingdomName = TownNameUI.inst.townName; + + Main.helper.Log("Finished unpacking player data"); + + /* + * Not even going to bother fixing AI brains save data yet, not in short-term roadmap + */ + + /*bool flag2 = this.AIBrainsSaveData != null; + if (flag2) + { + this.AIBrainsSaveData.UnpackPrePlayer(AIBrainsContainer.inst); + }*/ + + Main.helper.Log("Unpacking free resource manager"); + this.FreeResourceManagerSaveData.Unpack(FreeResourceManager.inst); + Main.helper.Log("Unpacking job system"); + this.JobSystemSaveData.Unpack(JobSystem.inst); + Main.helper.Log("Unpacking weather"); + this.WeatherSaveData.Unpack(Weather.inst); + Main.helper.Log("Unpacking fire manager"); + this.FireManagerSaveData.Unpack(FireManager.inst); + Main.helper.Log("Unpacking dragon spawn"); + this.DragonSpawnSaveData.Unpack(DragonSpawn.inst); + Main.helper.Log("Unpacking unit system"); + bool flag3 = this.UnitSystemSaveData != null; + if (flag3) + { + this.UnitSystemSaveData.Unpack(UnitSystem.inst); + } + Main.helper.Log("Unpacking siege monster"); + bool flag4 = this.SiegeMonsterSaveData != null; + if (flag4) + { + this.SiegeMonsterSaveData.Unpack(null); + } + Main.helper.Log("Unpacking siege catapult system"); + bool flag5 = this.SiegeCatapultSystemSaveData != null; + if (flag5) + { + this.SiegeCatapultSystemSaveData.Unpack(SiegeCatapultSystem.inst); + } + Main.helper.Log("Unpacking ship system"); + bool flag6 = this.ShipSystemSaveData != null; + if (flag6) + { + this.ShipSystemSaveData.Unpack(ShipSystem.inst); + } + Main.helper.Log("Unpacking cart system"); + bool flag7 = this.CartSystemSaveData != null; + if (flag7) + { + this.CartSystemSaveData.Unpack(CartSystem.inst); + } + Main.helper.Log("Unpacking raid system"); + bool flag8 = this.RaidSystemSaveData2 != null; + if (flag8) + { + this.RaidSystemSaveData2.Unpack(RaiderSystem.inst); + } + Main.helper.Log("Unpacking orders manager"); + bool flag9 = this.OrdersManagerSaveData != null; + if (flag9) + { + this.OrdersManagerSaveData.Unpack(OrdersManager.inst); + } + Main.helper.Log("Unpacking AI brains"); + bool flag10 = this.AIBrainsSaveData != null; + if (flag10) + { + this.AIBrainsSaveData.Unpack(AIBrainsContainer.inst); + } + Main.helper.Log("Unpacking custom save data"); + bool flag11 = this.CustomSaveData != null; + if (flag11) + { + LoadSave.CustomSaveData_DontAccessDirectly = this.CustomSaveData; + } + Main.helper.Log("Unpacking done"); + + + World.inst.UpscaleFeatures(); + Player.inst.RefreshVisibility(true); + for (int i = 0; i < Player.inst.Buildings.Count; i++) + { + Player.inst.Buildings.data[i].UpdateMaterialSelection(); + } + + // Player.inst.loadTickDelay = 1; + Type playerType = typeof(Player); + FieldInfo loadTickDelayField = playerType.GetField("loadTickDelay", BindingFlags.Instance | BindingFlags.NonPublic); + if (loadTickDelayField != null) + { + loadTickDelayField.SetValue(Player.inst, 1); + } + + // UnitSystem.inst.loadTickDelay = 1; + Type unitSystemType = typeof(UnitSystem); + loadTickDelayField = unitSystemType.GetField("loadTickDelay", BindingFlags.Instance | BindingFlags.NonPublic); + if (loadTickDelayField != null) + { + loadTickDelayField.SetValue(UnitSystem.inst, 1); + } + + // JobSystem.inst.loadTickDelay = 1; + Type jobSystemType = typeof(JobSystem); + loadTickDelayField = jobSystemType.GetField("loadTickDelay", BindingFlags.Instance | BindingFlags.NonPublic); + if (loadTickDelayField != null) + { + loadTickDelayField.SetValue(JobSystem.inst, 1); + } + + Main.helper.Log($"Setting kingdom name to: {kingdomNames[Main.PlayerSteamID]}"); + TownNameUI.inst.SetTownName(kingdomNames[Main.PlayerSteamID]); + + return obj; + } + } +} diff --git a/LoadSaveOverrides/MultiplayerSaveDeserializationBinder.cs b/LoadSaveOverrides/MultiplayerSaveDeserializationBinder.cs new file mode 100644 index 0000000..4e36786 --- /dev/null +++ b/LoadSaveOverrides/MultiplayerSaveDeserializationBinder.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Runtime.Serialization; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.LoadSaveOverrides +{ + sealed class MultiplayerSaveDeserializationBinder : SerializationBinder + { + public override Type BindToType(string assemblyName, string typeName) + { + Type typeToDeserialize = null; + + // For each assemblyName/typeName that you want to deserialize to + // a different type, set typeToDeserialize to the desired type. + String exeAssembly = Assembly.GetExecutingAssembly().FullName; + + + // The following line of code returns the type. + typeToDeserialize = Type.GetType(String.Format("{0}, {1}", + typeName, exeAssembly)); + + return typeToDeserialize; + } + } +} diff --git a/Main.cs b/Main.cs new file mode 100644 index 0000000..f574754 --- /dev/null +++ b/Main.cs @@ -0,0 +1,2445 @@ +using Assets.Code; +using Assets.Code.UI; +using Assets.Interface; +using Harmony; +using KCM.Enums; +using KCM.LoadSaveOverrides; +using KCM.Packets.Game; +using KCM.Packets.Game.Dragon; +using KCM.Packets.Game.GameBuilding; +using KCM.Packets.Game.GamePlayer; +using KCM.Packets.Game.GameTrees; +using KCM.Packets.Game.GameVillager; +using KCM.Packets.Game.GameWeather; +using KCM.Packets.Game.GameWorld; +using KCM.Packets.Handlers; +using KCM.Packets.Lobby; +using KCM.StateManagement.BuildingState; +using KCM.StateManagement.Observers; +using KCM.UI; +using Newtonsoft.Json; +using Riptide; +using Riptide.Demos.Steam.PlayerHosted; +using Riptide.Transports.Steam; +using Riptide.Utils; +using Steamworks; +using System; +using System.Collections.Generic; +using System.Configuration.Assemblies; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Reflection.Emit; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.ComTypes; +using System.Runtime.Serialization.Formatters.Binary; +using System.Security.AccessControl; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using TMPro; +using UnityEngine; +using UnityEngine.Events; +using UnityEngine.UI; +using static ModCompiler; +using static World; + +namespace KCM +{ + public class Main : MonoBehaviour + { + public static KCModHelper helper; + public static MenuState menuState = (MenuState)MainMenuMode.State.Uninitialized; + + public static Dictionary kCPlayers = new Dictionary(); + public static Dictionary clientSteamIds = new Dictionary(); + + public static KCPlayer GetPlayerByClientID(ushort clientId) + { + return kCPlayers[clientSteamIds[clientId]]; + } + + public static Player GetPlayerByTeamID(int teamId) // Need to replace building / production types so that the correct player is used. IResourceStorage and IResourceProvider, and jobs + { + try + { + var player = kCPlayers.Values.FirstOrDefault(p => p.inst.PlayerLandmassOwner.teamId == teamId).inst; + + return player; + } + catch (Exception e) + { + if (KCServer.IsRunning || KCClient.client.IsConnected) + { + Main.helper.Log("Failed finding player by teamID: " + teamId + " My teamID is: " + Player.inst.PlayerLandmassOwner.teamId); + Main.helper.Log(kCPlayers.Count.ToString()); + Main.helper.Log(string.Join(", ", kCPlayers.Values.Select(p => p.inst.PlayerLandmassOwner.teamId.ToString()))); + Main.helper.Log(e.Message); + Main.helper.Log(e.StackTrace); + } + } + return Player.inst; + } + + public static Player GetPlayerByBuilding(Building building) + { + try + { + var lmo = World.GetLandmassOwner(building.LandMass()); + + if (lmo == null) // Return the actual player for the client if the landmass owner is null + return Player.inst; + + // Return the player by teamId so that the correct player instance is updated/used on the server + return GetPlayerByTeamID(building.TeamID()); + } + catch (Exception e) + { + Main.helper.Log("Failed finding player by building: " + building.UniqueName); + Main.helper.Log(e.Message); + Main.helper.Log(e.StackTrace); + } + return Player.inst; + } + + public static string PlayerSteamID = SteamUser.GetSteamID().ToString(); + + public static KCMSteamManager KCMSteamManager = null; + public static SteamServer steamServer = new SteamServer(); + public static Riptide.Transports.Steam.SteamClient steamClient = new Riptide.Transports.Steam.SteamClient(steamServer); + + public static ushort currentClient = 0; + + #region "SceneLoaded" + private void SceneLoaded(KCModHelper helper) + { + helper.Log("SceneLoaded run in main"); + RiptideLogger.Initialize(helper.Log, helper.Log, helper.Log, helper.Log, false); + + helper.Log($"{SteamFriends.GetPersonaName()}"); + + + KCMSteamManager = new GameObject("KCMSteamManager").AddComponent(); + DontDestroyOnLoad(KCMSteamManager); + + var lobbyManager = new GameObject("LobbyManager").AddComponent(); + DontDestroyOnLoad(lobbyManager); + + //SteamFriends.InviteUserToGame(new CSteamID(76561198036307537), "test"); + //SteamMatchmaking.lobby + + //Main.helper.Log($"Timer duration for hazardpay {Player.inst.hazardPayWarmup.Duration}"); + + try + { + + SteamFriends.SetRichPresence("status", "Playing Multiplayer"); + + PacketHandler.Initialise(); + + Main.helper.Log(JsonConvert.SerializeObject(World.inst.mapSizeDefs, Formatting.Indented)); + + KaC_Button serverBrowser = new KaC_Button(Constants.MainMenuUI_T.Find("TopLevelUICanvas/TopLevel/Body/ButtonContainer/New").parent) + { + Name = "Multiplayer", + Text = "Multiplayer", + FirstSibling = true, + OnClick = () => + { + //Constants.MainMenuUI_T.Find("TopLevelUICanvas/TopLevel").gameObject.SetActive(false); + SfxSystem.PlayUiSelect(); + + //ServerBrowser.serverBrowserRef.SetActive(true); + TransitionTo(MenuState.ServerBrowser); + } + }; + serverBrowser.Transform.SetSiblingIndex(2); + + + Destroy(Constants.MainMenuUI_T.Find("TopLevelUICanvas/TopLevel/Body/ButtonContainer/Kingdom Share").gameObject); + } + catch (Exception ex) + { + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + if (ex.InnerException != null) + { + Main.helper.Log("----------------------- Inner exception -----------------------"); + Main.helper.Log(ex.InnerException.ToString()); + Main.helper.Log("----------------------- Inner message -----------------------"); + Main.helper.Log(ex.InnerException.Message); + Main.helper.Log("----------------------- Inner stacktrace -----------------------"); + Main.helper.Log(ex.InnerException.StackTrace); + } + } + + } + #endregion + + public static int FixedUpdateInterval = 0; + + private void FixedUpdate() + { + // send batched building placement info + /*if (PlaceHook.QueuedBuildings.Count > 0 && (FixedUpdateInterval % 25 == 0)) + { + foreach (Building building in PlaceHook.QueuedBuildings) + { + new WorldPlace() + { + uniqueName = building.UniqueName, + customName = building.customName, + guid = building.guid, + rotation = building.transform.GetChild(0).rotation, + globalPosition = building.transform.position, + localPosition = building.transform.GetChild(0).localPosition, + built = building.IsBuilt(), + placed = building.IsPlaced(), + open = building.Open, + doBuildAnimation = building.doBuildAnimation, + constructionPaused = building.constructionPaused, + constructionProgress = building.constructionProgress, + life = building.Life, + ModifiedMaxLife = building.ModifiedMaxLife, + //CollectForBuild = CollectForBuild, + yearBuilt = building.YearBuilt, + decayProtection = building.decayProtection, + seenByPlayer = building.seenByPlayer + }.Send(); + } + + PlaceHook.QueuedBuildings.Clear(); + }*/ + + FixedUpdateInterval++; + } + + #region "TransitionTo" + public static void TransitionTo(MenuState state) + { + try + { + ServerBrowser.serverBrowserRef.SetActive(state == MenuState.ServerBrowser); + ServerBrowser.serverLobbyRef.SetActive(state == MenuState.ServerLobby); + + ServerBrowser.KCMUICanvas.gameObject.SetActive((int)state > 21); + helper.Log(((int)state > 21).ToString()); + + GameState.inst.mainMenuMode.TransitionTo((MainMenuMode.State)state); + } + catch (Exception ex) + { + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + if (ex.InnerException != null) + { + Main.helper.Log("----------------------- Inner exception -----------------------"); + Main.helper.Log(ex.InnerException.ToString()); + Main.helper.Log("----------------------- Inner message -----------------------"); + Main.helper.Log(ex.InnerException.Message); + Main.helper.Log("----------------------- Inner stacktrace -----------------------"); + Main.helper.Log(ex.InnerException.StackTrace); + } + } + } + #endregion + + private void Preload(KCModHelper helper) + { + helper.Log("Preload start in main"); + try + { + + + //MainMenuPatches.Patch(); + Main.helper = helper; + helper.Log(helper.modPath); + + var harmony = HarmonyInstance.Create("harmony"); + harmony.PatchAll(Assembly.GetExecutingAssembly()); + + + helper.Log("Preload run in main"); + } + catch (Exception ex) + { + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + } + helper.Log("Preload end in main"); + } + + #region "MainMenu Hooks" + + public static MenuState prevMenuState = MenuState.Uninitialized; + + [HarmonyPatch(typeof(MainMenuMode))] + [HarmonyPatch("TransitionTo")] + public class TransitionToHook + { + private static void Prefix(MainMenuMode.State newState) + { + Main.helper.Log($"Menu set to: {(MenuState)newState}"); + + Main.prevMenuState = Main.menuState; + + if (newState != MainMenuMode.State.Uninitialized) + Main.menuState = (MenuState)newState; + } + } + + [HarmonyPatch(typeof(MainMenuMode))] + [HarmonyPatch("OnClickedClose")] + public class OnClickedCloseHook + { + private static bool Prefix() + { + helper.Log("Transition back"); + + TransitionTo(prevMenuState); + + return false; + } + } + + [HarmonyPatch(typeof(MainMenuMode))] + [HarmonyPatch("OnClickedBackToModeSelect")] + public class OnClickedBackToModeSelectPatch + { + private static bool Prefix() + { + if (KCClient.client.IsConnected) + { + Main.TransitionTo(MenuState.ServerLobby); + SfxSystem.PlayUiCancel(); + + return false; + } + else return true; + } + } + + [HarmonyPatch(typeof(MainMenuMode))] + [HarmonyPatch("OnClickedAcceptNameBanner")] + public class OnClickedAcceptNameBannerPatch + { + private static bool Prefix() + { + if (KCClient.client.IsConnected) + { + Main.TransitionTo(MenuState.ServerLobby); + SfxSystem.PlayUiCancel(); + + return false; + } + else return true; + } + } + #endregion + + #region "TownName Hooks" + [HarmonyPatch(typeof(TownNameUI))] + [HarmonyPatch("SetTownNameQuiet")] + public static class TownNameHook + { + //A function to run after target function invocation + private static void Postfix(TownNameUI __instance) + { + helper.Log($"name set: {__instance.townName}"); + + new KingdomName() { kingdomName = __instance.townName }.Send(); + } + } + #endregion + + #region "ChooseBanner Hooks" + [HarmonyPatch(typeof(ChooseBannerUI))] + [HarmonyPatch("OnAccept")] + public class ChooseBannerUIOnAcceptHook + { + private static void Postfix() + { + if (KCClient.client.IsConnected) + { + var banner = Player.inst.PlayerLandmassOwner.bannerIdx; + + + new PlayerBanner() { banner = banner }.Send(); + //return true; + } + //else return true; + } + } + #endregion + + + [HarmonyPatch(typeof(Keep))] + [HarmonyPatch("OnPlayerPlacement")] + public class KeepHook + { + public static void Postfix() + { + // Your code here + + // Get the name of the last method that called OnPlayerPlacement + string callTree = ""; + List strings = new List(); + + for (int i = 1; i < 10; i++) + { + try + { + string callingMethodName = new StackFrame(i).GetMethod().Name; + strings.Add(callingMethodName); + } + catch + { + strings.Add("Start"); + break; + } + } + + strings.Reverse(); + + Main.helper.Log($"Last {strings.Count} methods in call tree: {string.Join(" -> ", strings)}"); + } + } + + #region "GameUI Hooks" + //GameUI hook for acceptcursorobjplacement + /*[HarmonyPatch(typeof(GameUI), "AcceptCursorObjPlacement")] + public class AcceptCursorObjPlacementHook + { + }*/ + #endregion + + #region "World Hooks" + [HarmonyPatch(typeof(World))] + [HarmonyPatch("Place")] + public class PlaceHook + { + /*public static bool Prefix() + { + if (KCClient.client.IsConnected && !KCServer.IsRunning) + { + if (!new StackFrame(3).GetMethod().kingdomName.Contains("HandlePacket")) + return false; + } + + return true; + }*/ + + public static void Postfix(Building PendingObj) + { + try + { + if (KCClient.client.IsConnected) + { + /*string callTree = ""; + List strings = new List(); + + for (int i = 1; i < 10; i++) + { + try + { + string callingMethodName = new StackFrame(i).GetMethod().Name; + strings.Add($"{callingMethodName} ({i})"); + } + catch + { + strings.Add("Start"); + break; + } + } + + strings.Reverse(); + + Main.helper.Log($"WORLDPLACE Last {strings.Count} methods in call tree: {string.Join(" -> ", strings)}");*/ + + if (new StackFrame(3).GetMethod().Name.Contains("HandlePacket") && !new StackFrame(2).GetMethod().Name.Equals("RandomPlacement")) + return; + + Main.helper.Log($"Called by: {new StackFrame(3).GetMethod().Name}"); + Main.helper.Log($"{KCClient.client.Id} {Main.kCPlayers[PlayerSteamID].name} - Sending building place packet for " + PendingObj.UniqueName); + + // Need to batch building placements to prevent network spam + new WorldPlace() + { + uniqueName = PendingObj.UniqueName, + customName = PendingObj.customName, + guid = PendingObj.guid, + rotation = PendingObj.transform.GetChild(0).rotation, + globalPosition = PendingObj.transform.position, + localPosition = PendingObj.transform.GetChild(0).localPosition, + built = PendingObj.IsBuilt(), + placed = PendingObj.IsPlaced(), + open = PendingObj.Open, + doBuildAnimation = PendingObj.doBuildAnimation, + constructionPaused = PendingObj.constructionPaused, + constructionProgress = PendingObj.constructionProgress, + life = PendingObj.Life, + ModifiedMaxLife = PendingObj.ModifiedMaxLife, + //CollectForBuild = CollectForBuild, + yearBuilt = PendingObj.YearBuilt, + decayProtection = PendingObj.decayProtection, + seenByPlayer = PendingObj.seenByPlayer + }.Send(); + //return true; + } + //else return true; + } + catch (Exception e) + { + Main.helper.Log("World Place error"); + Main.helper.Log(e.Message); + Main.helper.Log(e.StackTrace); + } + } + } + + [HarmonyPatch(typeof(World), "RelationBetween")] + public class WorldRelationBetweenHook + { + public static void Prefix(ref int teamIDA, ref int teamIDB) + { + + //Main.helper.Log($"RelationBetween {teamIDA} and {teamIDB}"); + + if (KCClient.client.IsConnected) + { + if (teamIDA == 0 || teamIDB == 0) + { + if (teamIDA == 0) + teamIDA = Player.inst.PlayerLandmassOwner.teamId; + + if (teamIDB == 0) + teamIDB = Player.inst.PlayerLandmassOwner.teamId; + } + } + } + } + #endregion + + + #region "Player Hooks" + + [HarmonyPatch(typeof(Player), "Reset")] + public class PlayerResetHook + { + public static bool Prefix(Player __instance) + { + if (KCClient.client.IsConnected && __instance.gameObject.name.Contains("Client Player") && !LobbyManager.loadingSave) + { + try + { + var bindingFlags = BindingFlags.Instance | BindingFlags.NonPublic; + + __instance.GetType().GetField("resetting", bindingFlags).SetValue(__instance, true); + //__instance.resetting = true; + __instance.GetType().GetField("poorHealthGracePeriod", bindingFlags).SetValue(__instance, 0f); + //__instance.poorHealthGracePeriod = 0f; + __instance.PlayerLandmassOwner.Gold = 0; + __instance.CurrYear = 0; + + __instance.buildingDamageAnimator.Reset(); + __instance.fruitSystem.Reset(); + __instance.fieldSystem.Reset(); + + bool flag = __instance.DamagedList != null; + if (flag) + { + for (int i = 0; i < __instance.DamagedList.Length; i++) + { + __instance.DamagedList[i].Clear(); + } + __instance.DamagedList = null; + } + + __instance.irrigation.Reset(); + __instance.ClearRegistry(); + bool flag2 = __instance.buildingContainer; + if (flag2) + { + Building[] buildings = __instance.buildingContainer.transform.GetComponentsInChildren(); + for (int j = 0; j < buildings.Length; j++) + { + buildings[j].destroyedWhileInPlay = false; + UnityEngine.Object.Destroy(buildings[j].gameObject); + } + } + UnityEngine.Object.Destroy(__instance.buildingContainer); + __instance.buildingContainer = new GameObject(); + __instance.buildingContainer.name = "Buildings"; + + for (int k = 0; k < __instance.Workers.Count; k++) + { + __instance.Workers.data[k].Shutdown(); + } + __instance.Workers.Clear(); + /*int r = 0; + for (int l = 0; l < __instance.Homeless.Count; l++) + { + bool flag3 = !__instance.Homeless.data[l].shutdown; + if (flag3) + { + r++; + } + }*/ + + __instance.Homeless.Clear(); + + __instance.Residentials.Clear(); + __instance.Buildings.Clear(); + __instance.RadiusBonuses.Clear(); + __instance.WagePayers.Clear(); + + __instance.timeAtFailHappiness = 0f; + __instance.MaxGoldStorage = 0; + __instance.KingdomHappiness = 100; + ReflectionHelper.ClearPrivateListField(__instance, "landMassHappiness"); + //__instance.landMassHappiness.Clear(); + ReflectionHelper.ClearPrivateListField(__instance, "landMassHealth"); + //__instance.landMassHealth.Clear(); + ReflectionHelper.ClearPrivateListField(__instance, "landMassIntegrity"); + //__instance.landMassIntegrity.Clear(); + //__instance.HealthTimer.ForceExpire(); // TO-DO implement timer + __instance.happinessMods.Clear(); + /*for (int m = 0; m < __instance.plagueDeathInfo.Count; m++) + { + __instance.plagueDeathInfo[m].deathQueue.Clear(); + __instance.plagueDeathInfo[m].deaths = 0; + __instance.plagueDeathInfo[m].deathTime = 0f; + }*/ + ReflectionHelper.ClearPrivateListField(__instance, "OldAgeDeathQueue"); + //__instance.OldAgeDeathQueue.Clear(); + __instance.GetType().GetField("deathsThisYear", bindingFlags).SetValue(__instance, 0); + //__instance.deathsThisYear = 0; + __instance.ResetPerLandMassData(); + __instance.ResetTaxRates(); + __instance.ResetCreativeModeOptions(); + __instance.PlayerLandmassOwner.ReleaseOwnership(); + + ReflectionHelper.ClearPrivateListField(__instance, "dockOpenings"); + //__instance.dockOpenings.Clear(); + __instance.GetType().GetField("resetting", bindingFlags).SetValue(__instance, false); + //__instance.resetting = false; + } + catch (Exception e) + { + Main.helper.Log("Error in reset player hook"); + Main.helper.Log(e.Message); + Main.helper.Log(e.StackTrace); + } + + return false; + } + + return true; + } + } + + [HarmonyPatch(typeof(Player), "AddBuilding")] + public class PlayerAddBuildingHook + { + static int step = 1; + static void LogStep(bool reset = false) + { + if (reset) + step = 1; + + Main.helper.Log(step.ToString()); + step++; + } + + public static bool Prefix(Player __instance, Building b) + { + try + { + if (KCClient.client.IsConnected) + { + LogStep(true); + __instance.Buildings.Add(b); + IResourceStorage[] storages = b.GetComponents(); + LogStep(); + for (int i = 0; i < storages.Length; i++) + { + bool flag = !storages[i].IsPrivate(); + if (flag) + { + FreeResourceManager.inst.AddResourceStorage(storages[i]); + } + } + LogStep(); + int landMass = b.LandMass(); + Home res = b.GetComponent(); + bool flag2 = res != null; + LogStep(); + if (flag2) + { + __instance.Residentials.Add(res); + __instance.ResidentialsPerLandmass[landMass].Add(res); + } + WagePayer wagePayer = b.GetComponent(); + LogStep(); + bool flag3 = wagePayer != null; + if (flag3) + { + __instance.WagePayers.Add(wagePayer); + } + RadiusBonus radiusBonus = b.GetComponent(); + LogStep(); + bool flag4 = radiusBonus != null; + if (flag4) + { + __instance.RadiusBonuses.Add(radiusBonus); + } + LogStep(); + var globalBuildingRegistry = __instance.GetType().GetField("globalBuildingRegistry", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(__instance) as ArrayExt; + LogStep(); + var landMassBuildingRegistry = __instance.GetType().GetField("landMassBuildingRegistry", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(__instance) as ArrayExt; + LogStep(); + var unbuiltBuildingsPerLandmass = __instance.GetType().GetField("unbuiltBuildingsPerLandmass", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(__instance) as ArrayExt>; + LogStep(); + + __instance.AddToRegistry(globalBuildingRegistry, b); + LogStep(); + __instance.AddToRegistry(landMassBuildingRegistry.data[landMass].registry, b); + LogStep(); + landMassBuildingRegistry.data[landMass].buildings.Add(b); + LogStep(); + bool flag5 = !b.IsBuilt(); + if (flag5) + { + unbuiltBuildingsPerLandmass.data[landMass].Add(b); + } + LogStep(); + + + return false; + } + } + catch (Exception e) + { + Main.helper.Log("Error in add building hook"); + Main.helper.Log(e.Message); + Main.helper.Log(e.StackTrace); + } + + return true; + } + } + + [HarmonyPatch(typeof(Player), "SetupInitialWorkers")] + public class PlayerSetupInitialWorkersHook + { + public static void Postfix(Keep keep) + { + if (KCClient.client.IsConnected) + { + + + if (new StackFrame(3).GetMethod().Name.Contains("HandlePacket")) + return; + + new SetupInitialWorkersPacket() + { + keepGuid = keep.gameObject.GetComponent().guid + }.Send(); + } + } + } + + [HarmonyPatch(typeof(VillagerSystem), "AddVillager")] + public class PlayerAddVillagerHook + { + public static void Postfix(Villager __result, Vector3 pos) + { + if (KCClient.client.IsConnected) + { + try + { + if (new StackFrame(3).GetMethod().Name.Contains("HandlePacket")) + return; + + if (Enumerable.Range(0, 4).Select(i => new StackFrame(i).GetMethod()?.Name).Any(name => name?.Contains("unpack") == true)) // If called by unpack in the tree, do not run, since clients already unpacked villager data + return; + + new AddVillagerPacket() + { + guid = __result.guid, + }.Send(); + } + catch (Exception e) + { + Main.helper.Log("Error in add villager hook"); + + Main.helper.Log(e.Message); + Main.helper.Log(e.StackTrace); + } + } + } + } + + #endregion + + + #region "Tree Hooks" + + + [HarmonyPatch(typeof(TreeSystem), "FellTree")] + public class TreeSystemFellTreeHook + { + /*static IEnumerable TargetMethods() + { + var tmeth = typeof(TreeSystem).GetMethods(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance | BindingFlags.DeclaredOnly | BindingFlags.Static); + + return tmeth.Cast(); + }*/ + + public static void Postfix(MethodBase __originalMethod, Cell cell, int idx) + { + if (KCClient.client.IsConnected) + { + //Main.helper.Log($"Called by: {new StackFrame(3).GetMethod().kingdomName}"); + + if (new StackFrame(3).GetMethod().Name.Contains("HandlePacket")) + return; + + new FellTree() + { + idx = idx, + x = cell.x, + z = cell.z + }.Send(); + } + } + } + + [HarmonyPatch(typeof(TreeSystem), "ShakeTree")] + public class TreeSystemShakeTreeHook + { + public static void Postfix(MethodBase __originalMethod, int idx) + { + if (KCClient.client.IsConnected) + { + //Main.helper.Log($"Called by: {new StackFrame(3).GetMethod().kingdomName}"); + + if (new StackFrame(3).GetMethod().Name.Contains("HandlePacket")) + return; + + new ShakeTree() + { + idx = idx + }.Send(); + } + } + } + + [HarmonyPatch(typeof(TreeSystem), "GrowTree")] + public class TreeSystemGrowTreeHook + { + /*public static bool Prefix() + { + if (KCClient.client.IsConnected && !KCServer.IsRunning) + { + //Main.helper.Log($"Called by: {new StackFrame(3).GetMethod().kingdomName}"); + + if (!new StackFrame(3).GetMethod().Name.Contains("HandlePacket")) + return false; + + + } + + return true; + }*/ + + // Only server should send this information + public static void Postfix(MethodBase __originalMethod, Cell cell) + { + if (KCServer.IsRunning) + { + //Main.helper.Log($"Called by: {new StackFrame(3).GetMethod().kingdomName}"); + + if (new StackFrame(3).GetMethod().Name.Contains("HandlePacket")) + return; + + new GrowTree() + { + X = cell.x, + Z = cell.z + }.SendToAll(KCClient.client.Id); + } + } + } + + #endregion + + #region "Weather Hooks" + [HarmonyPatch(typeof(Weather), "ChangeWeather")] + public class WeatherChangeWeatherHook + { + public static void Postfix(MethodBase __originalMethod, Weather.WeatherType type) + { + if (KCServer.IsRunning && KCClient.client.IsConnected) + { + //Main.helper.Log($"Called by: {new StackFrame(3).GetMethod().kingdomName}"); + + if (new StackFrame(3).GetMethod().Name.Contains("HandlePacket")) + return; + + if (type != Weather.inst.currentWeather) + new ChangeWeather() + { + weatherType = (int)type + }.Send(); + } + } + } + #endregion + + #region "Building Hooks" + + [HarmonyPatch(typeof(Building), "CompleteBuild")] + public class BuildingCompleteBuildHook + { + public static bool Prefix(MethodBase __originalMethod, Building __instance) + { + if (KCClient.client.IsConnected) + { + Main.helper.Log("Overridden complete build"); + Player player = Main.GetPlayerByTeamID(__instance.TeamID()); + + //Main.helper.Log($"Called by: {new StackFrame(3).GetMethod().kingdomName}"); + typeof(Building).GetField("built", BindingFlags.NonPublic | BindingFlags.Instance).SetValue(__instance, true); + + __instance.UpdateMaterialSelection(); + __instance.SendMessage("OnBuilt", SendMessageOptions.DontRequireReceiver); + + + typeof(Building).GetField("yearBuilt", BindingFlags.NonPublic | BindingFlags.Instance).SetValue(__instance, player.CurrYear); + + typeof(Building).GetMethod("AddAllResourceProviders", BindingFlags.NonPublic | BindingFlags.Instance).Invoke(__instance, null); + + + player.BuildingNowBuilt(__instance); + + typeof(Building).GetMethod("TryAddJobs", BindingFlags.NonPublic | BindingFlags.Instance).Invoke(__instance, null); + __instance.BakePathing(); + + return false; + } + + return true; + } + } + + [HarmonyPatch(typeof(Building), "UpdateConstruction")] + public class BuildingUpdateHook + { + public static void Prefix(Building __instance) + { + try + { + if (KCClient.client.IsConnected) + { + if (__instance.TeamID() == Player.inst.PlayerLandmassOwner.teamId) + StateObserver.RegisterObserver(__instance, new string[] { + "customName", "guid", "UniqueName", "built", "placed", "open", "doBuildAnimation", "constructionPaused", "constructionProgress", "resourceProgress", + "Life", "ModifiedMaxLife", "CollectForBuild", "yearBuilt", "decayProtection", "seenByPlayer", + }, BuildingStateManager.BuildingStateChanged, BuildingStateManager.SendBuildingUpdate); + + //StateObserver.Update(__instance); + } + } + catch (Exception e) + { + helper.Log(e.ToString()); + helper.Log(e.Message); + helper.Log(e.StackTrace); + } + } + } + + #endregion + + #region "Time Hooks" + // TimeManager TrySetSpeed hook + [HarmonyPatch(typeof(SpeedControlUI), "SetSpeed")] + public class SpeedControlUISetSpeedHook + { + private static long lastTime = 0; + + public static bool Prefix() + { + if (KCClient.client.IsConnected) + { + if ((DateTimeOffset.Now.ToUnixTimeMilliseconds() - lastTime) < 250) // Set speed spam fix / hack + return false; + + if (!new StackFrame(3).GetMethod().Name.Contains("HandlePacket")) + return false; + } + + return true; + } + + public static void Postfix(int idx, bool skipNextSfx) + { + if (KCClient.client.IsConnected) + { + /*Main.helper.Log($"set speed Called by 0: {new StackFrame(0).GetMethod()} {new StackFrame(0).GetMethod().Name.Contains("HandlePacket")}"); + Main.helper.Log($"set speed Called by 1: {new StackFrame(1).GetMethod()} {new StackFrame(1).GetMethod().Name.Contains("HandlePacket")}"); + Main.helper.Log($"set speed Called by 2: {new StackFrame(2).GetMethod()} {new StackFrame(2).GetMethod().Name.Contains("HandlePacket")}"); + Main.helper.Log($"set speed Called by 3: {new StackFrame(3).GetMethod()} {new StackFrame(3).GetMethod().Name.Contains("HandlePacket")}");*/ + + if (new StackFrame(3).GetMethod().Name.Contains("HandlePacket")) + return; + + new SetSpeed() + { + speed = idx + }.Send(); + + lastTime = DateTimeOffset.Now.ToUnixTimeMilliseconds(); + } + } + } + #endregion + + #region "SteamManager Hook" + [HarmonyPatch] + public class SteamManagerAwakeHook + { + static IEnumerable TargetMethods() + { + var meth = typeof(SteamManager).GetMethods(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance | BindingFlags.DeclaredOnly); + return meth.Cast(); + } + + public static bool Prefix(MethodBase __originalMethod) + { + return false; + } + } + #endregion + + #region "Dragon Hooks" + + #region "Dragon Spawn Hooks" + [HarmonyPatch(typeof(DragonSpawn), "SpawnSiegeDragon")] + public class DragonSpawnSpawnSiegeDragonHook + { + public static bool Prefix() + { + if (KCClient.client.IsConnected && !KCServer.IsRunning && !new StackFrame(3).GetMethod().Name.Contains("HandlePacket")) + return false; + + return true; + } + public static void Postfix(MethodBase __originalMethod, Vector3 start) + { + if (KCClient.client.IsConnected) + { + //Main.helper.Log($"Called by: {new StackFrame(3).GetMethod().kingdomName}"); + + if (new StackFrame(3).GetMethod().Name.Contains("HandlePacket")) + return; + + new SpawnSiegeDragonPacket() { start = start }.Send(); + } + } + } + + [HarmonyPatch(typeof(DragonSpawn), "SpawnMamaDragon", new Type[] { typeof(Vector3) })] + public class DragonSpawnSpawnMamaDragonHook + { + public static bool Prefix() + { + if (KCClient.client.IsConnected && !KCServer.IsRunning && !new StackFrame(3).GetMethod().Name.Contains("HandlePacket")) + return false; + + return true; + } + public static void Postfix(MethodBase __originalMethod, Vector3 start) + { + if (KCClient.client.IsConnected) + { + //Main.helper.Log($"Called by: {new StackFrame(3).GetMethod().kingdomName}"); + + if (new StackFrame(3).GetMethod().Name.Contains("HandlePacket")) + return; + + new SpawnMamaDragonPacket() { start = start }.Send(); + } + } + } + + [HarmonyPatch(typeof(DragonSpawn), "SpawnBabyDragon", new Type[] { typeof(Vector3) })] + public class DragonSpawnSpawnBabyDragonHook + { + public static bool Prefix() + { + if (KCClient.client.IsConnected && !KCServer.IsRunning && !new StackFrame(3).GetMethod().Name.Contains("HandlePacket")) + return false; + + return true; + } + public static void Postfix(MethodBase __originalMethod, Vector3 start) + { + if (KCClient.client.IsConnected) + { + //Main.helper.Log($"Called by: {new StackFrame(3).GetMethod().kingdomName}"); + + if (new StackFrame(3).GetMethod().Name.Contains("HandlePacket")) + return; + + new SpawnBabyDragonPacket() { start = start }.Send(); + } + } + } + #endregion + + #endregion + + #region "Villager Hooks" + [HarmonyPatch(typeof(Villager), "TeleportTo")] + public class VillagerTeleportToHook + { + public static void Postfix(Villager __instance, Vector3 newPos) + { + if (KCClient.client.IsConnected) + { + if (new StackFrame(3).GetMethod().Name.Contains("HandlePacket")) + return; + + new VillagerTeleportTo() + { + guid = __instance.guid, + pos = newPos + }.Send(); + } + } + } + #endregion + + #region "Job Hooks" + + /*[HarmonyPatch(typeof(Job), "OnEmployeeQuit")] + public class JobOnEmployeeQuitHook + { + public static Player oldPlayer; + + public static void Prefix(Job __instance) + { + if (KCClient.client.IsConnected) + { + oldPlayer = Player.inst; + + Player.inst = Main.GetPlayerByTeamID(World.GetLandmassOwner(__instance.employer.LandMass()).teamId); + } + } + + public static void Postfix(Job __instance) + { + if (KCClient.client.IsConnected) + { + Player.inst = oldPlayer; + } + } + }*/ + + #endregion + + #region "LoadSave Hooks" + [HarmonyPatch(typeof(LoadSave), "GetSaveDir")] + public class LoadSaveGetSaveDirHook + { + public static bool Prefix(ref string __result) + { + Main.helper.Log("Get save dir"); + if (KCClient.client.IsConnected) + { + if (KCServer.IsRunning) + { + + } + __result = Application.persistentDataPath + "/Saves/Multiplayer"; + + return false; + } + + __result = Application.persistentDataPath + "/Saves"; ; + return true; + } + } + + [HarmonyPatch(typeof(LoadSave), "LoadAtPath")] + public class LoadSaveLoadAtPathHook + { + //public static string saveFile = ""; + public static byte[] saveData = new byte[0]; + + public static bool Prefix(string path, string filename, bool visitedWorld) + { + if (KCServer.IsRunning) + { + Main.helper.Log("Trying to load multiplayer save"); + LoadSave.LastLoadDirectory = path; + path = path + "/" + filename; + + + bool flag = !File.Exists(path); + if (!flag) + { + BinaryFormatter bf = new BinaryFormatter(); + bf.Binder = new MultiplayerSaveDeserializationBinder(); + saveData = File.ReadAllBytes(path); + Stream file = new FileStream(path, FileMode.Open); + try + { + MultiplayerSaveContainer loadData = (MultiplayerSaveContainer)bf.Deserialize(file); + loadData.Unpack(null); + Broadcast.OnLoadedEvent.Broadcast(new OnLoadedEvent()); + } + catch (Exception e) + { + GameState.inst.mainMenuMode.TransitionTo(MainMenuMode.State.LoadError); + Main.helper.Log("Error loading save"); + Main.helper.Log(e.Message); + Main.helper.Log(e.StackTrace); + throw; + } + finally + { + bool flag2 = file != null; + if (flag2) + { + file.Close(); + file.Dispose(); + } + } + } + + return false; + } + + return true; + } + } + + [HarmonyPatch(typeof(LoadSave), "Load")] + public class LoadSaveLoadHook + { + public static bool memoryStreamHook = false; + + public static byte[] saveBytes = new byte[0]; + + public static MultiplayerSaveContainer saveContainer; + + public static bool Prefix() + { + if (memoryStreamHook) + { + Main.helper.Log("Attempting to load save from server"); + + using (MemoryStream ms = new MemoryStream(saveBytes)) + { + BinaryFormatter bf = new BinaryFormatter(); + bf.Binder = new MultiplayerSaveDeserializationBinder(); + saveContainer = (MultiplayerSaveContainer)bf.Deserialize(ms); + } + + memoryStreamHook = false; + return false; + } + + return true; + } + } + + [HarmonyPatch(typeof(LoadSave), "Save")] + public class LoadSaveSaveHook + { + private class OutData + { + // Token: 0x04002176 RID: 8566 + public string Path; + + // Token: 0x04002177 RID: 8567 + public MultiplayerSaveContainer LoadSaveContainer; + } + + private static void OutToFile(object data) + { + OutData outData = (OutData)data; + BinaryFormatter bf = new BinaryFormatter(); + Stream file = null; + try + { + file = new FileStream(outData.Path, FileMode.Create, FileAccess.Write); + bf.Serialize(file, outData.LoadSaveContainer); + } + catch (Exception e) + { + LoadSave.AppendToLocalErrorLog(string.Concat(new string[] + { + "Problem during save", + Environment.NewLine, + e.Message, + Environment.NewLine, + e.StackTrace + })); + } + finally + { + bool flag = file != null; + if (flag) + { + file.Close(); + file.Dispose(); + } + } + } + + public static bool Prefix(string pathOverride, UnityAction onCompleteCallback, ref Thread __result) + { + if (KCServer.IsRunning) + { + Directory.CreateDirectory(LoadSave.GetSaveDir()); + Guid guid = Guid.NewGuid(); + string path = (pathOverride != "") ? pathOverride : (LoadSave.GetSaveDir() + "/" + guid); + Directory.CreateDirectory(path); + Thread thread; + try + { + thread = new Thread(new ParameterizedThreadStart(OutToFile)); + + MultiplayerSaveContainer packedData = new MultiplayerSaveContainer().Pack(null); + Broadcast.OnSaveEvent.Broadcast(new OnSaveEvent()); + thread.Start(new OutData + { + LoadSaveContainer = packedData, + Path = path + "/world" + }); + } + catch (Exception e) + { + //LoadSave.ErrorToKingdomLog(e); + Main.helper.Log(e.Message); + Main.helper.Log(e.StackTrace); + throw; + } + finally + { + LoadSave.SaveWorldSummaryData(path); + } + + + // Custom banners not implemented yet + /*try + { + bool usingCustomBanner = Player.inst.usingCustomBanner; + if (usingCustomBanner) + { + File.WriteAllBytes(path + "/custombanner.png", Player.inst.customBannerTexture2D.EncodeToPNG()); + } + } + catch (Exception e2) + { + Main.helper.Log(e2.Message); + }*/ + + try + { + World.inst.TakeScreenshot(path + "/cover", new Func(World.inst.Func_CaptureWorldShot), onCompleteCallback); + } + catch (Exception e3) + { + Main.helper.Log(e3.Message); + } + bool flag = onCompleteCallback != null; + if (flag) + { + bool flag2 = thread != null && thread.ThreadState == System.Threading.ThreadState.Running; + if (flag2) + { + thread.Join(); + } + } + GC.Collect(); + + __result = thread; + + return false; + } + + return true; + } + } + + [HarmonyPatch(typeof(SaveLoadUI), "ClickLoadItem")] + public class SaveLoadUIClickedLoadItemHook + { + public static bool Prefix(SaveLoadUI __instance, string id) + { + if (KCServer.IsRunning) + { + + LoadSave.Load(id); + TransitionTo(MenuState.ServerLobby); + //GameState.inst.SetNewMode(GameState.inst.playingMode); + + + return false; + } + + return true; + } + } + + [HarmonyPatch(typeof(Player.PlayerSaveData), "ProcessBuilding")] + public class PlayerProcessBuildingHook + { + public static bool Prefix(Building.BuildingSaveData structureData, Player p, ref Building __result) + { + if (KCClient.client.IsConnected) + { + + Building Building = GameState.inst.GetPlaceableByUniqueName(structureData.uniqueName); + bool flag = Building; + if (flag) + { + Building building = UnityEngine.Object.Instantiate(Building); + building.transform.position = structureData.globalPosition; + building.Init(); + building.transform.SetParent(p.buildingContainer.transform, true); + structureData.Unpack(building); + p.AddBuilding(building); + + Main.helper.Log($"Loading player id: {p.PlayerLandmassOwner.teamId}"); + Main.helper.Log($"loading building: {building.FriendlyName}"); + Main.helper.Log($" (teamid: {building.TeamID()})"); + Main.helper.Log(p.ToString()); + bool flag2 = building.GetComponent() != null && building.TeamID() == p.PlayerLandmassOwner.teamId; + Main.helper.Log("Set keep? " + flag2); + if (flag2) + { + p.keep = building.GetComponent(); + Main.helper.Log(p.keep.ToString()); + } + __result = building; + } + else + { + Main.helper.Log(structureData.uniqueName + " failed to load correctly"); + __result = null; + } + + return false; + } + + return true; + } + } + + [HarmonyPatch(typeof(Player.PlayerSaveData), "Pack")] + public class PlayerSaveDataPackgHook + { + public static bool Prefix(Player.PlayerSaveData __instance, Player p, ref Player.PlayerSaveData __result) + { + if (KCClient.client.IsConnected) + { + var bindingFlags = BindingFlags.Instance | BindingFlags.NonPublic; + + Main.helper.Log("Running patched player pack method"); + Main.helper.Log("Saving banner system"); + __instance.newBannerSystem = true; + Main.helper.Log("Saving player creativeMode"); + __instance.creativeMode = p.creativeMode; + + //cmo options not used for saving or loading in multiplayer + /**for (int i = 0; i < p.cmoOptionsOn.Length; i++) + { + bool flag = p.cmoOptionsOn[i]; + if (flag) + { + __instance.cmoOptions.Add((Player.CreativeOptions)i); + } + }**/ + + Main.helper.Log("Saving player upgrades"); + __instance.GetType().GetField("upgrades", bindingFlags).SetValue(__instance, new List()); + + + Main.helper.Log("Saving player bannerIdx"); + __instance.bannerIdx = p.PlayerLandmassOwner.bannerIdx; + + Main.helper.Log("Saving player WorkersArray"); + __instance.WorkersArray = new Villager.VillagerSaveData[p.Workers.Count]; + for (int j = 0; j < p.Workers.Count; j++) + { + bool flag2 = p.Workers.data[j] != null; + if (flag2) + { + __instance.WorkersArray[j] = new Villager.VillagerSaveData().Pack(p.Workers.data[j]); + } + } + + Main.helper.Log("Saving player HomelessData"); + __instance.HomelessData = new List(); + for (int k = 0; k < p.Homeless.Count; k++) + { + __instance.HomelessData.Add(p.Homeless.data[k].guid); + } + __instance.structures = new List(); + __instance.subStructures = new List(); + + Main.helper.Log("Saving player structures"); + World.inst.ForEachTile(0, 0, World.inst.GridWidth, World.inst.GridHeight, delegate (int x, int z, Cell cell) + { + bool flag4 = cell.OccupyingStructure.Count > 0; + if (flag4) + { + List occupyingStructureData = new List(); + for (int i3 = 0; i3 < cell.OccupyingStructure.Count; i3++) + { + var building = cell.OccupyingStructure[i3]; + bool flag5 = Vector3.Distance(cell.OccupyingStructure[i3].transform.position.xz(), cell.Position.xz()) <= 1E-05f; + if (flag5 && building.TeamID() == p.PlayerLandmassOwner.teamId) + { + occupyingStructureData.Add(new Building.BuildingSaveData().Pack(cell.OccupyingStructure[i3])); + } + } + bool flag6 = occupyingStructureData.Count > 0; + if (flag6) + { + __instance.structures.Add(occupyingStructureData.ToArray()); + } + } + bool flag7 = cell.SubStructure.Count > 0; + if (flag7) + { + List subStructureData = new List(); + for (int i4 = 0; i4 < cell.SubStructure.Count; i4++) + { + var building = cell.SubStructure[i4]; + bool flag8 = Vector3.Distance(cell.SubStructure[i4].transform.position.xz(), cell.Position.xz()) <= 1E-05f; + if (flag8 && building.TeamID() == p.PlayerLandmassOwner.teamId) + { + subStructureData.Add(new Building.BuildingSaveData().Pack(cell.SubStructure[i4])); + } + } + bool flag9 = subStructureData.Count > 0; + if (flag9) + { + __instance.subStructures.Add(subStructureData.ToArray()); + } + } + }); + + Main.helper.Log($"Saving town happiness"); + __instance.TownHappiness = p.KingdomHappiness; + + Main.helper.Log($"Saving town happiness infos"); + __instance.happinessInfos = p.GetType().GetField("landMassHappiness", bindingFlags).GetValue(p) as List; + + Main.helper.Log($"Saving town integrity infos"); + __instance.integrityInfos = p.GetType().GetField("landMassIntegrity", bindingFlags).GetValue(p) as List; + + Main.helper.Log($"Saving town landmass owner"); + __instance.playerLandmassOwnerSaveData = new LandmassOwner.LandmassOwnerSaveData().Pack(p.PlayerLandmassOwner); + + Main.helper.Log($"Saving town bDidFirstFire"); + __instance.bDidFirstFire = (bool)p.GetType().GetField("bDidFirstFire", bindingFlags).GetValue(p); + + bool flag3 = p.taxRates != null; + if (flag3) + { + + Main.helper.Log($"Saving town tax rates"); + __instance.TaxRates = new float[p.taxRates.Length]; + Array.Copy(p.taxRates, __instance.TaxRates, p.taxRates.Length); + } + + Main.helper.Log($"Saving difficulty"); + __instance.Difficulty = p.difficulty; + + Main.helper.Log($"Saving CurrYear"); + __instance.CurrYear = p.CurrYear; + + Main.helper.Log($"Saving timeAtFailHappiness"); + __instance.timeAtFailHappiness = p.timeAtFailHappiness; + + Main.helper.Log($"Saving happinessMods"); + __instance.happinessMods = p.happinessMods; + + Main.helper.Log($"Saving currConsumption"); + __instance.currConsumptionList = p.currConsumption; + + Main.helper.Log($"Saving lastConsumption"); + __instance.lastConsumptionList = p.lastConsumption; + + Main.helper.Log($"Saving currProduction"); + __instance.currProductionList = p.currProduction; + + Main.helper.Log($"Saving lastProduction"); + __instance.lastProductionList = p.lastProduction; + + Main.helper.Log($"Saving landMassNames"); + __instance.landMassNames = new List(); + for (int l = 0; l < p.LandMassNames.Count; l++) + { + __instance.landMassNames.Add(p.LandMassNames[l]); + } + + Main.helper.Log($"Saving JobPriorityOrder"); + __instance.JobPriorityOrder = new int[p.JobPriorityOrder.Length][]; + __instance.JobEnabledFlag = new bool[p.JobEnabledFlag.Length][]; + for (int m = 0; m < p.JobPriorityOrder.Length; m++) + { + __instance.JobPriorityOrder[m] = new int[p.JobPriorityOrder[m].Length]; + __instance.JobEnabledFlag[m] = new bool[p.JobEnabledFlag[m].Length]; + Array.Copy(p.JobPriorityOrder[m], __instance.JobPriorityOrder[m], __instance.JobPriorityOrder[m].Length); + Array.Copy(p.JobEnabledFlag[m], __instance.JobEnabledFlag[m], __instance.JobEnabledFlag[m].Length); + } + + Main.helper.Log($"Saving JobFilledAvailable"); + __instance.JobFilledAvailable = new int[World.inst.NumLandMasses][]; + __instance.JobCustomMaxEnabledFlag = new bool[World.inst.NumLandMasses][]; + for (int lm = 0; lm < World.inst.NumLandMasses; lm++) + { + __instance.JobFilledAvailable[lm] = new int[38]; + __instance.JobCustomMaxEnabledFlag[lm] = new bool[38]; + for (int n = 0; n < 38; n++) + { + __instance.JobFilledAvailable[lm][n] = p.JobFilledAvailable.data[lm][n, 1]; + } + Array.Copy(p.JobCustomMaxEnabledFlag[lm], __instance.JobCustomMaxEnabledFlag[lm], __instance.JobCustomMaxEnabledFlag[lm].Length); + } + + Main.helper.Log($"Saving CanUseTools"); + __instance.CanUseTools = new bool[p.CanUseTools.Length][]; + for (int i2 = 0; i2 < p.CanUseTools.Length; i2++) + { + __instance.CanUseTools[i2] = new bool[p.CanUseTools[i2].Length]; + Array.Copy(p.CanUseTools[i2], __instance.CanUseTools[i2], __instance.CanUseTools[i2].Length); + } + + Main.helper.Log($"Saving usedCheats"); + __instance.usedCheats = p.hasUsedCheats; + + Main.helper.Log($"Saving nameForOldAgeDeath"); + __instance.nameForOldAgeDeath = (string)p.GetType().GetField("nameForOldAgeDeath", bindingFlags).GetValue(p); + + Main.helper.Log($"Saving deathsThisYear"); + __instance.deathsThisYear = (int)p.GetType().GetField("deathsThisYear", bindingFlags).GetValue(p); + + Main.helper.Log($"Saving poorHealthGracePeriod"); + __instance.poorHealthGracePeriod = (float)p.GetType().GetField("poorHealthGracePeriod", bindingFlags).GetValue(p); + + Main.helper.Log($"Saving dockOpenings"); + __instance.dockOpenings = p.GetType().GetField("dockOpenings", bindingFlags).GetValue(p) as List; + + Main.helper.Log($"Saving tourism"); + __instance.tourism = p.tourism; + + + __result = __instance; + + return false; + } + + return true; + } + } + + /*[HarmonyPatch(typeof(Player.PlayerSaveData), "Unpack")] + public class PlayerSaveDataUnpackHook + { + public static bool Prefix(Player.PlayerSaveData __instance, Player p, ref Player __result) + { + Main.helper.Log("Running patched player unpack method"); + if (KCClient.client.IsConnected) + { + Main.helper.Log("Running patched unpack method"); + + var bindingFlags = BindingFlags.Instance | BindingFlags.NonPublic; + Main.helper.Log("1"); + Weather.inst.weatherTimeScale = 0f; + p.creativeMode = __instance.creativeMode; + p.ResetPerLandMassData(); + Main.helper.Log("2"); + bool flag = __instance.JobPriorityOrder != null && __instance.JobPriorityOrder[0].Length == ((int[])p.GetType().GetField("defaultPriorityOrder", bindingFlags).GetValue(p)).Length; + if (flag) + { + Main.helper.Log(__instance.JobPriorityOrder.Length.ToString()); + Main.helper.Log(__instance.JobEnabledFlag.Length.ToString()); + Main.helper.Log(p.JobEnabledFlag.Length.ToString()); + + for (int i = 0; i < __instance.JobPriorityOrder.Length; i++) + { + Array.Copy(__instance.JobPriorityOrder[i], p.JobPriorityOrder[i], __instance.JobPriorityOrder[i].Length); + Main.helper.Log("2.1"); + Array.Copy(__instance.JobEnabledFlag[i], p.JobEnabledFlag[i], __instance.JobPriorityOrder[i].Length); + Main.helper.Log("2.2"); + } + } + Main.helper.Log("3"); + bool flag2 = __instance.JobFilledAvailable != null && __instance.JobFilledAvailable.Length == p.JobFilledAvailable.data.Length; + if (flag2) + { + for (int lm = 0; lm < World.inst.NumLandMasses; lm++) + { + bool flag3 = __instance.JobFilledAvailable[lm].Length != p.JobFilledAvailable.data[lm].Length / 2; + if (flag3) + { + break; + } + for (int j = 0; j < 38; j++) + { + p.JobFilledAvailable.data[lm][j, 0] = 0; + p.JobFilledAvailable.data[lm][j, 1] = __instance.JobFilledAvailable[lm][j]; + } + Array.Copy(__instance.JobCustomMaxEnabledFlag[lm], p.JobCustomMaxEnabledFlag[lm], __instance.JobCustomMaxEnabledFlag[lm].Length); + } + } + Main.helper.Log("4"); + // not saving creative info + p.ResetCreativeModeOptions(); + + p.KingdomHappiness = __instance.TownHappiness; + Main.helper.Log("5"); + p.GetType().GetField("landMassHappiness", bindingFlags).SetValue(p, __instance.happinessInfos); + var landMassHappiness = p.GetType().GetField("landMassHappiness", bindingFlags).GetValue(p) as List; + + bool flag5 = landMassHappiness == null; + if (flag5) + { + landMassHappiness = new List(); + } + while (landMassHappiness.Count < World.inst.NumLandMasses) + { + landMassHappiness.Add(new Player.HappinessInfo()); + } + + Main.helper.Log("6"); + p.GetType().GetField("landMassHealth", bindingFlags).SetValue(p, __instance.healthInfos); + var landMassHealth = p.GetType().GetField("landMassHealth", bindingFlags).GetValue(p) as List; + + bool flag6 = landMassHealth == null; + if (flag6) + { + landMassHealth = new List(); + } + while (landMassHealth.Count < World.inst.NumLandMasses) + { + landMassHealth.Add(new Player.HealthInfo()); + } + + Main.helper.Log("7"); + + p.GetType().GetField("landMassIntegrity", bindingFlags).SetValue(p, __instance.integrityInfos); + var landMassIntegrity = p.GetType().GetField("landMassIntegrity", bindingFlags).GetValue(p) as List; + + bool flag7 = landMassIntegrity == null; + if (flag7) + { + landMassIntegrity = new List(); + } + while (landMassIntegrity.Count < World.inst.NumLandMasses) + { + landMassIntegrity.Add(new Player.IntegrityInfo()); + } + Main.helper.Log("8"); + p.GetType().GetField("bDidFirstFire", bindingFlags).SetValue(p, __instance.bDidFirstFire); + + bool flag8 = __instance.TaxRates == null; + if (flag8) + { + p.taxRates = new float[World.inst.NumLandMasses]; + int l = 0; + int m = World.inst.NumLandMasses; + while (l < m) + { + p.taxRates[l] = (float)__instance.TaxRate; + l++; + } + } + else + { + p.taxRates = new float[__instance.TaxRates.Length]; + Array.Copy(__instance.TaxRates, p.taxRates, __instance.TaxRates.Length); + } + p.difficulty = __instance.Difficulty; + p.CurrYear = __instance.CurrYear; + p.timeAtFailHappiness = __instance.timeAtFailHappiness; + p.currConsumption = __instance.currConsumptionList; + Main.helper.Log("9"); + while (p.currConsumption.Count < World.inst.NumLandMasses) + { + p.currConsumption.Add(new Player.Consumption()); + } + p.lastConsumption = __instance.lastConsumptionList; + while (p.lastConsumption.Count < World.inst.NumLandMasses) + { + p.lastConsumption.Add(new Player.Consumption()); + } + p.currProduction = __instance.currProductionList; + while (p.currProduction.Count < World.inst.NumLandMasses) + { + p.currProduction.Add(new Player.Production()); + } + p.lastProduction = __instance.lastProductionList; + while (p.lastProduction.Count < World.inst.NumLandMasses) + { + p.lastProduction.Add(new Player.Production()); + } + p.happinessMods = __instance.happinessMods; + Main.helper.Log("10"); + bool flag9 = p.happinessMods == null; + if (flag9) + { + p.happinessMods = new List(); + } + Main.helper.Log("11"); + bool flag10 = __instance.landMassNames == null; + if (flag10) + { + p.ResetLandMassNames(); + } + else + { + p.LandMassNames.Clear(); + for (int n = 0; n < __instance.landMassNames.Count; n++) + { + p.LandMassNames.Add(__instance.landMassNames[n]); + } + } + Main.helper.Log("12"); + bool flag11 = __instance.playerLandmassOwnerSaveData != null; + if (flag11) + { + __instance.playerLandmassOwnerSaveData.Unpack(p.PlayerLandmassOwner); + } + else + { + + var Resources = (ResourceAmount)__instance.GetType().GetField("Resources").GetValue(__instance); + + p.PlayerLandmassOwner.Gold = Resources.Get(FreeResourceType.Gold); + for (int i2 = 0; i2 < __instance.structures.Count; i2++) + { + Building.BuildingSaveData[] occupyingStructureData = __instance.structures[i2]; + for (int h = 0; h < occupyingStructureData.Length; h++) + { + int lIdx = World.inst.GetCellData(occupyingStructureData[h].globalPosition).landMassIdx; + bool flag12 = !p.PlayerLandmassOwner.OwnsLandMass(lIdx); + if (flag12) + { + p.PlayerLandmassOwner.TakeOwnership(lIdx); + } + } + } + } + Main.helper.Log("13"); + bool flag13 = __instance.bannerIdx != -1; + if (flag13) + { + p.SetIndexedBanner(__instance.bannerIdx); + } + bool flag14 = !__instance.newBannerSystem; + Main.helper.Log("14"); + if (flag14) + { + p.SetIndexedBanner(5); + p.SetCustomBannerTexture(World.inst.liverySets[__instance.bannerIdx].bannerMaterial.mainTexture as Texture2D); + } + bool flag15 = p.PlayerLandmassOwner.Gold < 0; + if (flag15) + { + p.PlayerLandmassOwner.Gold = 0; + } + + Main.helper.Log("15"); + var upgrades = __instance.GetType().GetField("upgrades", bindingFlags).GetValue(__instance) as List; + for (int i3 = 0; i3 < upgrades.Count; i3++) + { + p.PlayerLandmassOwner.AddUpgrade(upgrades[i3]); + } + p.Workers.Clear(); + p.Homeless.Clear(); + + Main.helper.Log("16"); + Main.helper.Log($"Loading {__instance.WorkersArray.Length} workers from workers array for {p.PlayerLandmassOwner.teamId}"); + bool flag16 = __instance.WorkersArray != null; + if (flag16) + { + + Main.helper.Log("17"); + for (int i4 = 0; i4 < __instance.WorkersArray.Length; i4++) + { + Villager person = Villager.CreateVillager(); + person.Pos = __instance.WorkersArray[i4].pos; + __instance.WorkersArray[i4].Unpack(person); + p.Workers.Add(person); + } + } + else + { + + Main.helper.Log("18"); + Main.helper.Log($"Loading {__instance.Workers.Count} workers for {p.PlayerLandmassOwner.teamId}"); + for (int i5 = 0; i5 < __instance.Workers.Count; i5++) + { + Villager person2 = Villager.CreateVillager(); + person2.Pos = __instance.Workers[i5].pos; + __instance.Workers[i5].Unpack(person2); + p.Workers.Add(person2); + } + } + + Main.helper.Log("19"); + for (int i6 = 0; i6 < __instance.HomelessData.Count; i6++) + { + Villager worker = p.GetWorker(__instance.HomelessData[i6]); + bool flag17 = worker != null; + if (flag17) + { + p.Homeless.Add(worker); + } + } + Main.helper.Log("20"); + + /*List buildingsToPlace = new List(); + for (int i7 = 0; i7 < __instance.structures.Count; i7++) + { + Building.BuildingSaveData[] occupyingStructureData2 = __instance.structures[i7]; + for (int h2 = 0; h2 < occupyingStructureData2.Length; h2++) + { + Building building = __instance.ProcessBuilding(occupyingStructureData2[h2], p); + bool flag18 = building != null; + if (flag18) + { + buildingsToPlace.Add(new Player.PlayerSaveData.BuildingLoadHelper(building, occupyingStructureData2[h2], h2)); + } + } + } + for (int i8 = 0; i8 < __instance.subStructures.Count; i8++) + { + Building.BuildingSaveData[] occupyingStructureData3 = __instance.subStructures[i8]; + for (int h3 = 0; h3 < occupyingStructureData3.Length; h3++) + { + Building building2 = __instance.ProcessBuilding(occupyingStructureData3[h3], p); + bool flag19 = building2 != null; + if (flag19) + { + buildingsToPlace.Add(new Player.PlayerSaveData.BuildingLoadHelper(building2, occupyingStructureData3[h3], h3 - 1000)); + } + } + } + foreach (Player.PlayerSaveData.BuildingLoadHelper buildingHelper in from x in buildingsToPlace + orderby x.priority + select x) + { + World.inst.PlaceFromLoad(buildingHelper.building); + buildingHelper.buildingSaveData.UnpackStage2(buildingHelper.building); + } + for (int i9 = 0; i9 < p.Workers.Count; i9++) + { + bool flag20 = p.Workers.data[i9].Residence == null; + if (flag20) + { + bool flag21 = !p.Homeless.Contains(p.Workers.data[i9]); + if (flag21) + { + p.Homeless.Add(p.Workers.data[i9]); + Debug.LogError("worker with null residence not in homeless saved data..."); + } + } + } + for (int i10 = 0; i10 < buildingsToPlace.Count; i10++) + { + Player.PlayerSaveData.BuildingLoadHelper buildingHelper2 = buildingsToPlace[i10]; + Road roadComp = buildingHelper2.building.GetComponent(); + bool flag22 = roadComp != null; + if (flag22) + { + roadComp.UpdateRotation(); + } + bool flag23 = buildingHelper2.building.uniqueNameHash == Player.PlayerSaveData.aqueductHash; + if (flag23) + { + buildingHelper2.building.GetComponent().UpdateRotation(); + } + } + Cell[] cellData = World.inst.GetCellsData(); + for (int i11 = 0; i11 < cellData.Length; i11++) + { + List structures = cellData[i11].OccupyingStructure; + for (int j2 = 0; j2 < structures.Count; j2++) + { + bool flag24 = structures[j2].categoryHash == Player.PlayerSaveData.castleblockHash; + if (flag24) + { + structures[j2].GetComponent().UpdateStackPostLoad(); + break; + } + } + } + bool flag25 = Player.inst.keep != null; + if (flag25) + { + p.buildingContainer.BroadcastMessage("OnAnyBuildingAdded", Player.inst.keep.GetComponent(), SendMessageOptions.DontRequireReceiver); + } + World.inst.SetupInitialPathCosts(); + Player.inst.irrigation.UpdateIrrigation(); + Player.inst.CalcMaxResources(null, -1); + bool flag26 = __instance.CanUseTools != null; + if (flag26) + { + int i12 = 0; + while (i12 < __instance.CanUseTools.Length && i12 < p.CanUseTools.Length) + { + Array.Copy(__instance.CanUseTools[i12], p.CanUseTools[i12], __instance.CanUseTools[i12].Length); + i12++; + } + } + ToolInfo.inst.SetTogglesFromPlayerData(); + p.hasUsedCheats = __instance.usedCheats; + p.checkedForceCreative = false; + p.nameForOldAgeDeath = __instance.nameForOldAgeDeath; + p.deathsThisYear = __instance.deathsThisYear; + World.inst.RebuildVillagerGrid(); + ArrayExt keepers = Player.inst.GetBuildingList(World.cemeteryKeeperHash); + for (int i13 = 0; i13 < keepers.Count; i13++) + { + keepers.data[i13].GetComponent().RebuildPathData(); + } + p.ChangeHazardPayActive(p.PlayerLandmassOwner.hazardPay, false); + p.poorHealthGracePeriod = __instance.poorHealthGracePeriod; + p.UpdateFocusedLandmass(); + PopulationUI.inst.regionName.RefreshRegionName(); + p.dockOpenings = __instance.dockOpenings; + bool flag27 = p.dockOpenings == null; + if (flag27) + { + p.dockOpenings = new List(); + } + p.IntegrityTimer.ForceExpire(); + bool flag28 = __instance.tourism != null; + if (flag28) + { + p.tourism = __instance.tourism; + } + else + { + p.tourism = new Player.TourismInfo(); + } + + + __result = p; + + return false; + } + + return true; + } + }*/ + + #endregion + + /** + * + * Find all Player.inst references and reconstruct method with references to client planyer + * + * Instantiating main player object and setting landmass teamid in KCPLayer + * + * E.G instead of Player.inst, it should be Main.kCPlayers[Client].player for example, and the rest of the code is the same + * + * Prefix that sets Player.inst to the right client instance and then calls that instances method? + * + */ + + [HarmonyPatch] + public class PlayerReferencePatch + { + static IEnumerable TargetMethods() + { + Assembly assembly = typeof(Player).Assembly; + + Type[] types = new Type[] { typeof(Player)/*, typeof(World), typeof(LandmassOwner), typeof(Keep), typeof(Villager), typeof(DragonSpawn), typeof(DragonController), typeof(Dragon)*/ }; + + var methodsInNamespace = types + .SelectMany(t => t.GetMethods(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance | BindingFlags.DeclaredOnly).Where(m => !m.IsAbstract)) + .ToList(); + + helper.Log("Methods in namespace: " + methodsInNamespace.Count); + + return methodsInNamespace.ToArray().Cast(); + } + + static IEnumerable Transpiler(MethodBase method, IEnumerable instructions) + { + int PlayerInstCount = 0; + + var codes = new List(instructions); + for (var i = 0; i < codes.Count; i++) + { + if (codes[i].opcode == OpCodes.Ldsfld && codes[i].operand.ToString() == "Player inst") + { + PlayerInstCount++; + + codes[i].opcode = (OpCodes.Ldarg_0); // Replace all instance methods static ref with "this" instead of Player.inst + + // Replace ldsfld Player::inst with the sequence to load from Main.kCPlayers + // Step 1: Load Main.kCPlayers onto the evaluation stack. + //codes[i] = new CodeInstruction(OpCodes.Ldsfld, typeof(Main).GetField("kCPlayers")); + + // Step 2: Load the value of Main.PlayerSteamID onto the evaluation stack as the key + //codes.Insert(++i, new CodeInstruction(OpCodes.Ldsfld, typeof(Main).GetField("PlayerSteamID"))); + + // Step 3: Call Dictionary.get_Item(TKey key) to get the Player instance. + //codes.Insert(++i, new CodeInstruction(OpCodes.Callvirt, typeof(Dictionary).GetMethod("get_Item"))); + + // Now, access the 'inst' field of the fetched Player instance, if necessary. + //codes.Insert(++i, new CodeInstruction(OpCodes.Ldfld, typeof(KCPlayer).GetField("inst"))); + } + } + + if (PlayerInstCount > 0) + Main.helper.Log($"Found {PlayerInstCount} static Player.inst references in {method.Name}"); + + return codes.AsEnumerable(); + } + } + + [HarmonyPatch] + public class BuildingPlayerReferencePatch + { + static IEnumerable TargetMethods() + { + Assembly assembly = typeof(Building).Assembly; + + Type[] types = new Type[] { typeof(Building) }; + + var methodsInNamespace = types + .SelectMany(t => t.GetMethods(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance | BindingFlags.DeclaredOnly).Where(m => !m.IsAbstract)) + .ToList(); + + helper.Log("Methods in namespace: " + methodsInNamespace.Count); + + return methodsInNamespace.ToArray().Cast(); + } + + static IEnumerable Transpiler(MethodBase method, IEnumerable instructions) + { + int PlayerInstCount = 0; + + var codes = new List(instructions); + MethodInfo getPlayerByBuildingMethodInfo = typeof(Main).GetMethod("GetPlayerByBuilding", BindingFlags.Static | BindingFlags.Public); + + for (var i = 0; i < codes.Count; i++) + { + if (codes[i].opcode == OpCodes.Ldsfld && codes[i].operand.ToString() == "Player inst") + { + PlayerInstCount++; + + // Check if the current instruction is ldsfld Player.inst + if (codes[i].opcode == OpCodes.Ldsfld && codes[i].operand.ToString().Contains("Player inst")) + { + // Replace the instruction sequence + // Step 1: Load 'this' for the Building instance + codes[i].opcode = OpCodes.Ldarg_0; + + // Step 2: Call GetPlayerByBuilding(Building instance) static method in Main + var callTeamID = new CodeInstruction(OpCodes.Call, getPlayerByBuildingMethodInfo); + codes.Insert(++i, callTeamID); + } + } + } + + if (PlayerInstCount > 0) + Main.helper.Log($"Found {PlayerInstCount} static building Player.inst references in {method.Name}"); + + return codes.AsEnumerable(); + } + } + + + [HarmonyPatch] + public class PlayerPatch + { + static IEnumerable TargetMethods() + { + var meth = typeof(Player).GetMethods(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance | BindingFlags.DeclaredOnly); + return meth.Cast(); + } + + public static bool Prefix(MethodBase __originalMethod, Player __instance) + { + if (__originalMethod.Name.Equals("Awake") && (KCServer.IsRunning || KCClient.client.IsConnected)) + { + helper.Log("Awake run on player instance while server is running"); + + return false; + } + + if (__originalMethod.Name.Equals("Awake") && __instance.gameObject.name.Contains("Client Player")) + { + helper.Log("Awake run on client instance"); + try + { + //___defaultEnabledFlags = new bool[38]; + //for (int i = 0; i < ___defaultEnabledFlags.Length; i++) + //{ + // ___defaultEnabledFlags[i] = true; + //} + //__instance.PlayerLandmassOwner = __instance.gameObject.AddComponent(); + + + + //helper.Log(__instance.PlayerLandmassOwner.ToString()); + } + catch (Exception e) + { + helper.Log(e.ToString()); + helper.Log(e.Message); + helper.Log(e.StackTrace); + } + return false; + } + + if (__originalMethod.Name.Equals("Update") && __instance.gameObject.name.Contains("Client Player")) + { + //helper.Log("Update run on client instance"); + try + { + //___defaultEnabledFlags = new bool[38]; + //for (int i = 0; i < ___defaultEnabledFlags.Length; i++) + //{ + // ___defaultEnabledFlags[i] = true; + //} + //__instance.PlayerLandmassOwner = __instance.gameObject.AddComponent(); + + + + //helper.Log(__instance.PlayerLandmassOwner.ToString()); + } + catch (Exception e) + { + helper.Log(e.ToString()); + helper.Log(e.Message); + helper.Log(e.StackTrace); + } + return false; + } + + if (__originalMethod.Name.Equals("Update")) + { + //helper.Log($"Update called for: {__instance.gameObject.name}"); + + try + { + if (KCClient.client.IsConnected && !__instance.gameObject.name.Contains("Client Player")) + { + StateObserver.RegisterObserver(__instance, new string[] { + "bannerIdx", "kingdomHappiness", "landMassHappiness", "landMassIntegrity", "bDidFirstFire", "CurrYear", + "timeAtFailHappiness", "hasUsedCheats", "nameForOldAgeDeath", "deathsThisYear", /*"poorHealthGracePeriod",*/ + }); + + //StateObserver.Update(__instance); + } + } + catch (Exception e) + { + helper.Log(e.ToString()); + helper.Log(e.Message); + helper.Log(e.StackTrace); + } + return true; + } + + return true; + } + + public static void Postfix(MethodBase __originalMethod, Player __instance) + { + if (__originalMethod.Name.Equals("Update")) + { + //helper.Log($"Update called for: {__instance.gameObject.name} in POSTFIX"); + + + //helper.Log("CHECKING ALL COMPONENTS IN UPDATE: "); + //Component[] components = __instance.gameObject.GetComponents(); + + //foreach (Component component in components) + //{ + // helper.Log("--- " + component.GetType().kingdomName); + //} + } + } + } + + #region "Unity Log Hooks" + + [HarmonyPatch(typeof(UnityEngine.Debug), "Log", new Type[] { typeof(object) })] + public class DebugLogPatch + { + public static void Prefix(object message) + { + if (Main.kCPlayers.Values.Any((p) => message.ToString().StartsWith(p.inst.PlayerLandmassOwner.teamId.ToString()))) + return; + + //Main.helper.Log($"UNITY 3D DEBUG LOG"); + //Main.helper.Log(message.ToString()); + } + } + + [HarmonyPatch(typeof(UnityEngine.Debug), "Log", new Type[] { typeof(object), typeof(UnityEngine.Object) })] + public class DebugLogCPatch + { + public static void Prefix(object message, UnityEngine.Object context) + { + if (Main.kCPlayers.Values.Any((p) => message.ToString().StartsWith(p.inst.PlayerLandmassOwner.teamId.ToString()))) + return; + + //Main.helper.Log($"UNITY 3D DEBUG LOG"); + //Main.helper.Log(context.ToString()); + //Main.helper.Log(message.ToString()); + } + } + + + [HarmonyPatch(typeof(UnityEngine.Debug), "LogError", new Type[] { typeof(object) })] + public class DebugLogErrorPatch + { + public static void Prefix(object message) + { + if (message.ToString().StartsWith(Player.inst?.PlayerLandmassOwner?.teamId.ToString() ?? "")) + return; + + //Main.helper.Log($"UNITY 3D DEBUG LOG ERROR"); + //Main.helper.Log(message.ToString()); + } + } + + [HarmonyPatch(typeof(UnityEngine.Debug), "LogError", new Type[] { typeof(object), typeof(UnityEngine.Object) })] + public class DebugLogErrorCPatch + { + public static void Prefix(object message, UnityEngine.Object context) + { + if (message.ToString().StartsWith(Player.inst?.PlayerLandmassOwner?.teamId.ToString() ?? "")) + return; + + //Main.helper.Log($"UNITY 3D DEBUG LOG ERROR"); + //Main.helper.Log(context.ToString()); + //Main.helper.Log(message.ToString()); + } + } + + [HarmonyPatch(typeof(UnityEngine.Debug), "LogException", new Type[] { typeof(Exception) })] + public class DebugLogExceptionPatch + { + public static void Prefix(Exception exception) + { + //Main.helper.Log($"UNITY 3D DEBUG LOG EXCEPTION"); + //Main.helper.Log(exception.Message); + //Main.helper.Log(exception.StackTrace); + } + } + + [HarmonyPatch(typeof(UnityEngine.Debug), "LogException", new Type[] { typeof(Exception), typeof(UnityEngine.Object) })] + public class DebugLogExceptionCPatch + { + public static void Prefix(Exception exception, UnityEngine.Object context) + { + //Main.helper.Log($"UNITY 3D DEBUG LOG EXCEPTION"); + //Main.helper.Log(exception.Message); + //Main.helper.Log(exception.StackTrace); + } + } + + + [HarmonyPatch(typeof(UnityEngine.Debug), "LogWarning", new Type[] { typeof(object) })] + public class DebugLogWarningPatch + { + public static void Prefix(object message) + { + if (message.ToString().Contains("Failed to send 928 byte")) + return; + + + //Main.helper.Log($"UNITY 3D DEBUG LOG WARNING"); + //Main.helper.Log(message.ToString()); + } + } + + [HarmonyPatch(typeof(UnityEngine.Debug), "LogWarning", new Type[] { typeof(object), typeof(UnityEngine.Object) })] + public class DebugLogWarningCPatch + { + public static void Prefix(object message, UnityEngine.Object context) + { + if (Main.kCPlayers.Values.Any((p) => message.ToString().StartsWith(p.inst.PlayerLandmassOwner.teamId.ToString()))) + return; + + //Main.helper.Log($"UNITY 3D DEBUG LOG WARNING"); + //Main.helper.Log(context.ToString()); + //Main.helper.Log(message.ToString()); + } + } + + #endregion + } + + public class DetailedTransformData + { + public string name; + public string path; + public Vector3 position; + public Vector3 localPosition; + public Quaternion rotation; + public Quaternion localRotation; + public Vector3 eulerAngles; + public Vector3 localEulerAngles; + public Vector3 localScale; + public Vector3 lossyScale; + public List children; + public List components; + + public DetailedTransformData(Transform transform, string parentPath = "") + { + name = transform.name; + path = parentPath == "" ? name : parentPath + "/" + name; + position = transform.position; + localPosition = transform.localPosition; + rotation = transform.rotation; + localRotation = transform.localRotation; + eulerAngles = transform.eulerAngles; + localEulerAngles = transform.localEulerAngles; + localScale = transform.localScale; + lossyScale = transform.lossyScale; + children = new List(); + + components = transform.GetComponents().Select(c => c.GetType().ToString()).ToList(); + + foreach (Transform child in transform) + { + children.Add(new DetailedTransformData(child, path)); + } + } + } + +} diff --git a/ModalManager.cs b/ModalManager.cs new file mode 100644 index 0000000..0361910 --- /dev/null +++ b/ModalManager.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using TMPro; +using UnityEngine; +using UnityEngine.UI; + +namespace KCM +{ + public class ModalManager + { + static GameObject modalInst; + static bool instantiated = false; + + static TMPro.TextMeshProUGUI tmpTitle; + static TMPro.TextMeshProUGUI tmpDescription; + static Button acceptButton; + + static ModalManager() + { + if (!instantiated) + { + modalInst = GameObject.Instantiate(PrefabManager.modalUIPrefab, Constants.MainMenuUI_T); + modalInst.SetActive(false); + + acceptButton = modalInst.transform.Find("Modal/Container/Button").GetComponent(); + + + + tmpTitle = modalInst.transform.Find("Modal/Container/Title").GetComponent(); + tmpDescription = modalInst.transform.Find("Modal/Container/Description").GetComponent(); + + instantiated = true; + } + else + { + throw new Exception("ModalManager is a singleton and may only be instantiated once"); + } + } + + public static void ShowModal(string title, string message, string buttonText = "Okay", bool withButton = true, Action action = null) + { + tmpTitle.text = title; + tmpDescription.text = message; + + acceptButton.gameObject.SetActive(withButton); + + acceptButton.gameObject.GetComponentInChildren().text = buttonText; + + acceptButton.onClick.RemoveAllListeners(); + + acceptButton.onClick.AddListener(() => + { + modalInst.SetActive(false); // Clicked okay + action?.Invoke(); + }); + + modalInst.SetActive(true); + } + + public static void HideModal() + { + modalInst.SetActive(false); + } + } +} diff --git a/Packets/Game/Dragon/SpawnBabyDragonPacket.cs b/Packets/Game/Dragon/SpawnBabyDragonPacket.cs new file mode 100644 index 0000000..35773b5 --- /dev/null +++ b/Packets/Game/Dragon/SpawnBabyDragonPacket.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; + +namespace KCM.Packets.Game.Dragon +{ + public class SpawnBabyDragonPacket : Packet + { + public override ushort packetId => (int)Enums.Packets.SpawnBabyDragon; + + public Vector3 start { get; set; } + + public override void HandlePacketClient() + { + DragonSpawn.inst.SpawnBabyDragon(start); + } + + public override void HandlePacketServer() + { + //throw new NotImplementedException(); + } + } +} diff --git a/Packets/Game/Dragon/SpawnMamaDragonPacket.cs b/Packets/Game/Dragon/SpawnMamaDragonPacket.cs new file mode 100644 index 0000000..9d3a2e8 --- /dev/null +++ b/Packets/Game/Dragon/SpawnMamaDragonPacket.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; + +namespace KCM.Packets.Game.Dragon +{ + public class SpawnMamaDragonPacket : Packet + { + public override ushort packetId => (int)Enums.Packets.SpawnMamaDragon; + + public Vector3 start { get; set; } + + public override void HandlePacketClient() + { + DragonSpawn.inst.SpawnMamaDragon(start); + } + + public override void HandlePacketServer() + { + //throw new NotImplementedException(); + } + } +} diff --git a/Packets/Game/Dragon/SpawnSiegeDragonPacket.cs b/Packets/Game/Dragon/SpawnSiegeDragonPacket.cs new file mode 100644 index 0000000..7aac5fb --- /dev/null +++ b/Packets/Game/Dragon/SpawnSiegeDragonPacket.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; + +namespace KCM.Packets.Game.Dragon +{ + public class SpawnSiegeDragonPacket : Packet + { + public override ushort packetId => (int)Enums.Packets.SpawnSiegeDragon; + + public Vector3 start { get; set; } + + public override void HandlePacketClient() + { + DragonSpawn.inst.SpawnSiegeDragon(start); + } + + public override void HandlePacketServer() + { + //throw new NotImplementedException(); + } + } +} diff --git a/Packets/Game/GameBuilding/CompleteBuild.cs b/Packets/Game/GameBuilding/CompleteBuild.cs new file mode 100644 index 0000000..726257c --- /dev/null +++ b/Packets/Game/GameBuilding/CompleteBuild.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets.Game.GameBuilding +{ + public class CompleteBuild : Packet + { + public override ushort packetId => (int)Enums.Packets.CompleteBuild; + + public Guid buildingId { get; set; } + + public override void HandlePacketClient() + { + if (KCClient.client.Id == clientId) return; + + Player.inst.GetBuilding(buildingId).CompleteBuild(); + } + + public override void HandlePacketServer() + { + //throw new NotImplementedException(); + } + } +} diff --git a/Packets/Game/GameBuilding/UpdateConstruction.cs b/Packets/Game/GameBuilding/UpdateConstruction.cs new file mode 100644 index 0000000..978db16 --- /dev/null +++ b/Packets/Game/GameBuilding/UpdateConstruction.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets.Game.GameBuilding +{ + public class UpdateConstruction : Packet + { + public override ushort packetId => (int)Enums.Packets.UpdateConstruction; + + public Guid buildingId { get; set; } + public float constructionProgress { get; set; } + + public override void HandlePacketClient() + { + if (KCClient.client.Id == clientId) return; + + //Main.helper.Log($"Received packet from: {clientId} receiving client is {KCClient.client.Id}"); + Player.inst.GetBuilding(buildingId).constructionProgress = constructionProgress; + } + + public override void HandlePacketServer() + { + //throw new NotImplementedException(); + } + } +} diff --git a/Packets/Game/GamePlayer/AddVillagerPacket.cs b/Packets/Game/GamePlayer/AddVillagerPacket.cs new file mode 100644 index 0000000..ae6d450 --- /dev/null +++ b/Packets/Game/GamePlayer/AddVillagerPacket.cs @@ -0,0 +1,42 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; + +namespace KCM.Packets.Game.GamePlayer +{ + public class AddVillagerPacket : Packet + { + public override ushort packetId => (ushort)Enums.Packets.AddVillager; + + public Guid guid { get; set; } + + public override void HandlePacketClient() + { + try + { + if (KCClient.client.Id == clientId) return; + + Main.helper.Log("Received add villager packet from " + player.name + $"({player.id})"); + + Villager v = Villager.CreateVillager(); + v.guid = guid; + + player.inst.Workers.Add(v); + player.inst.Homeless.Add(v); + + } + catch (Exception e) + { + Main.helper.Log("Error handling add villager packet: " + e.Message); + } + } + + public override void HandlePacketServer() + { + //throw new NotImplementedException(); + } + } +} diff --git a/Packets/Game/GamePlayer/SetupInitialWorkersPacket.cs b/Packets/Game/GamePlayer/SetupInitialWorkersPacket.cs new file mode 100644 index 0000000..69926c4 --- /dev/null +++ b/Packets/Game/GamePlayer/SetupInitialWorkersPacket.cs @@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets.Game.GamePlayer +{ + public class SetupInitialWorkersPacket : Packet + { + public override ushort packetId => (ushort)Enums.Packets.SetupInitialWorkers; + + public Guid keepGuid { get; set; } + + + public override void HandlePacketClient() + { + if (KCClient.client.Id == clientId) return; + + /*Keep keep = player.inst.GetBuilding(keepGuid).GetComponent(); + if (keep == null) + { + Main.helper.Log("Keep not found."); + return; + } + + player.inst.SetupInitialWorkers(keep);*/ + } + + public override void HandlePacketServer() + { + + } + } +} diff --git a/Packets/Game/GameTrees/FellTree.cs b/Packets/Game/GameTrees/FellTree.cs new file mode 100644 index 0000000..b1e7df0 --- /dev/null +++ b/Packets/Game/GameTrees/FellTree.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets.Game.GameTrees +{ + public class FellTree : Packet + { + public override ushort packetId => (int)Enums.Packets.FellTree; + + public int idx { get; set; } + + public int x { get; set; } + public int z { get; set; } + + public override void HandlePacketClient() + { + Cell cell = World.inst.GetCellData(x, z); + + TreeSystem.inst.FellTree(cell, idx); + } + + public override void HandlePacketServer() + { + //throw new NotImplementedException(); + } + } +} diff --git a/Packets/Game/GameTrees/GrowTree.cs b/Packets/Game/GameTrees/GrowTree.cs new file mode 100644 index 0000000..cfd16e2 --- /dev/null +++ b/Packets/Game/GameTrees/GrowTree.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets.Game.GameTrees +{ + public class GrowTree : Packet + { + public override ushort packetId => (int)Enums.Packets.GrowTree; + + public int X { get; set; } + public int Z { get; set; } + + public override void HandlePacketClient() + { + Cell cell = World.inst.GetCellData(X, Z); + + TreeSystem.inst.GrowTree(cell); + } + + public override void HandlePacketServer() + { + //throw new NotImplementedException(); + } + } +} diff --git a/Packets/Game/GameTrees/ShakeTree.cs b/Packets/Game/GameTrees/ShakeTree.cs new file mode 100644 index 0000000..7fcea07 --- /dev/null +++ b/Packets/Game/GameTrees/ShakeTree.cs @@ -0,0 +1,25 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets.Game.GameTrees +{ + public class ShakeTree : Packet + { + public override ushort packetId => (int)Enums.Packets.ShakeTree; + + public int idx { get; set; } + + public override void HandlePacketClient() + { + TreeSystem.inst.ShakeTree(idx); + } + + public override void HandlePacketServer() + { + //throw new NotImplementedException(); + } + } +} diff --git a/Packets/Game/GameVillager/VillagerTeleportTo.cs b/Packets/Game/GameVillager/VillagerTeleportTo.cs new file mode 100644 index 0000000..0128fda --- /dev/null +++ b/Packets/Game/GameVillager/VillagerTeleportTo.cs @@ -0,0 +1,38 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; + +namespace KCM.Packets.Game.GameVillager +{ + public class VillagerTeleportTo : Packet + { + public override ushort packetId => (ushort)Enums.Packets.VillagerTeleportTo; + + public Guid guid { get; set; } + public Vector3 pos { get; set; } + + public override void HandlePacketClient() + { + if (KCClient.client.Id == clientId) return; + + try + { + Villager.villagers.data.Where(x => x.guid == guid).FirstOrDefault().TeleportTo(pos); + + Main.helper.Log($"Teleporting villager to {pos.ToString()}"); + } + catch (Exception e) + { + Main.helper.Log("Error handling villager teleport packet: " + e.Message); + } + } + + public override void HandlePacketServer() + { + + } + } +} diff --git a/Packets/Game/GameWeather/ChangeWeather.cs b/Packets/Game/GameWeather/ChangeWeather.cs new file mode 100644 index 0000000..3ddbf1d --- /dev/null +++ b/Packets/Game/GameWeather/ChangeWeather.cs @@ -0,0 +1,25 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets.Game.GameWeather +{ + public class ChangeWeather : Packet + { + public override ushort packetId => (int)Enums.Packets.ChangeWeather; + + public int weatherType { get; set; } + + public override void HandlePacketClient() + { + Weather.CurrentWeather = ((Weather.WeatherType)weatherType); + } + + public override void HandlePacketServer() + { + //throw new NotImplementedException(); + } + } +} diff --git a/Packets/Game/GameWorld/WorldPlace.cs b/Packets/Game/GameWorld/WorldPlace.cs new file mode 100644 index 0000000..34b8249 --- /dev/null +++ b/Packets/Game/GameWorld/WorldPlace.cs @@ -0,0 +1,145 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.Serialization.Formatters.Binary; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; + +namespace KCM.Packets.Game.GameWorld +{ + public class WorldPlace : Packet + { + public override ushort packetId => (int)Enums.Packets.WorldPlace; + + public string customName { get; set; } + public Guid guid { get; set; } + public string uniqueName { get; set; } + public Quaternion rotation { get; set; } + public Vector3 globalPosition { get; set; } + public Vector3 localPosition { get; set; } + public bool built { get; set; } + public bool placed { get; set; } + public bool open { get; set; } + public bool doBuildAnimation { get; set; } + public bool constructionPaused { get; set; } + public float constructionProgress { get; set; } + public float life { get; set; } + public float ModifiedMaxLife { get; set; } + public int yearBuilt { get; set; } + public float decayProtection { get; set; } + public bool seenByPlayer { get; set; } + + public override void HandlePacketClient() + { + if (clientId == KCClient.client.Id) return; //prevent double placing on same client + + PlaceBuilding(); + } + + public override void HandlePacketServer() + { + //PlaceBuilding(); + + //SendToAll(clientId); + } + + public void PlaceBuilding() + { + Main.helper.Log("Received place building packet for " + uniqueName + " from " + player.name + $"({player.id})"); + + //var originalPlayer = Player.inst; + //Player.inst = player.inst; + + Building.BuildingSaveData structureData = new Building.BuildingSaveData() + { + uniqueName = uniqueName, + customName = customName, + guid = guid, + rotation = rotation, + globalPosition = globalPosition, + localPosition = localPosition, + built = built, + placed = placed, + open = open, + doBuildAnimation = doBuildAnimation, + constructionPaused = constructionPaused, + constructionProgress = constructionProgress, + life = life, + ModifiedMaxLife = ModifiedMaxLife, + //CollectForBuild = CollectForBuild, + yearBuilt = yearBuilt, + decayProtection = decayProtection, + seenByPlayer = seenByPlayer + }; + + + //Player originalInst = Player.inst; + //Player.inst = player.inst; + + Building Building = GameState.inst.GetPlaceableByUniqueName(structureData.uniqueName); + bool flag = Building; + if (flag) + { + Building building = UnityEngine.Object.Instantiate(Building); + building.transform.position = structureData.globalPosition; + Main.helper.Log("Building init"); + building.Init(); + building.transform.SetParent(player.inst.buildingContainer.transform, true); + Main.helper.Log("Building unpack"); + structureData.Unpack(building); + + Main.helper.Log(player.inst.ToString()); + Main.helper.Log((player.inst.PlayerLandmassOwner == null).ToString()); + Main.helper.Log(building.LandMass().ToString()); + Main.helper.Log("Player add Building unpacked"); + player.inst.AddBuilding(building); + + try + { + + player.inst.PlayerLandmassOwner.TakeOwnership(building.LandMass()); + bool flag2 = building.GetComponent() != null && building.TeamID() == player.inst.PlayerLandmassOwner.teamId; + Main.helper.Log("Set keep " + flag2); + if (flag2) + { + player.inst.keep = building.GetComponent(); + } + } + catch (Exception e) + { + Main.helper.Log(e.Message); + } + + Main.helper.Log("Place from load"); + Cell cell = World.inst.PlaceFromLoad(building); + Main.helper.Log("unpack stage 2"); + structureData.UnpackStage2(building); + + building.SetVisibleForFog(false); + + Main.helper.Log("Landmass owner take ownership"); + + Main.helper.Log($"{player.id} (team {player.inst.PlayerLandmassOwner.teamId}) banner: {player.inst.PlayerLandmassOwner.bannerIdx} Placed building {building.name} at {building.transform.position}"); + + + //Player.inst = originalInst; // Reset player back to normal // Might not be needed anymore with player ref patches? + + + Main.helper.Log($"Host player Landmass Names Count: {Player.inst.LandMassNames.Count}, Contents: {string.Join(", ", Player.inst.LandMassNames)}"); + Main.helper.Log($"Client player ({player.name}) Landmass Names Count: {player.inst.LandMassNames.Count}, Contents: {string.Join(", ", player.inst.LandMassNames)}"); + + player.inst.LandMassNames[building.LandMass()] = player.kingdomName; + Player.inst.LandMassNames[building.LandMass()] = player.kingdomName; + + //Player.inst = originalPlayer; + } + else + { + Main.helper.Log(structureData.uniqueName + " failed to load correctly"); + } + //building.Init(); + } + + } +} diff --git a/Packets/Game/PlaceKeepRandomly.cs b/Packets/Game/PlaceKeepRandomly.cs new file mode 100644 index 0000000..6916e4d --- /dev/null +++ b/Packets/Game/PlaceKeepRandomly.cs @@ -0,0 +1,102 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; +using UnityEngine.Analytics; + +namespace KCM.Packets.Game +{ + public class PlaceKeepRandomly : Packet + { + public override ushort packetId => (ushort)Enums.Packets.PlaceKeepRandomly; + + public int landmassIdx { get; set; } + + public override void HandlePacketClient() + { + try + { + Building keep = UnityEngine.Object.Instantiate(GameState.inst.GetPlaceableByUniqueName(World.keepName)); + + keep.Init(); + + + Cell[] cells = World.inst.GetCellsData().Where(x => x.landMassIdx == landmassIdx).ToArray(); + Cell keepCell = null; + + + foreach (Cell cell in cells) + { + Cell nearbyStoneCell = FindNearbyStoneCell(cells, cell.x, cell.z, landmassIdx, 15); // Place keep within 15 tiles of stone + Cell nearbyWaterCell = FindNearbyWaterCell(cells, cell.x, cell.z, landmassIdx, 6); // Do not place keep within 6 tiles of water + + + Cell clearCell = FindClearCell(cells, cell.x, cell.z, landmassIdx, 4); // cells in 4 by 4 radius are clear? + + if (clearCell != null & nearbyStoneCell != null && nearbyWaterCell == null && cell.Type == ResourceType.None) + { + Console.WriteLine($"Nearby stone cell found at ({nearbyStoneCell.x}, {nearbyStoneCell.z})"); + keepCell = cell; + + break; + } + else + continue; + + } + + keep.transform.position = keepCell.Position; + + keep.SendMessage("OnPlayerPlacement", SendMessageOptions.DontRequireReceiver); + + + Player.inst.PlayerLandmassOwner.TakeOwnership(keep.LandMass()); + Player.inst.keep = keep.GetComponent(); + Player.inst.RefreshVisibility(true); + RandomPlacement(keep); + + } catch (Exception e) + { + Main.helper.Log($"Error placing keep randomly: {e.Message}"); + } + } + + private void RandomPlacement(Building keep) // This is a hack so I can detect when its being called by this packet + { + World.inst.Place(keep); + + Cam.inst.SetTrackingPos(keep.GetPosition()); + } + + private static Cell FindNearbyStoneCell(Cell[] cells, int x, int z, int landmassIdx, int radius) + { + return cells.FirstOrDefault(cell => IsResourceInRadius(cell, x, z, radius, ResourceType.Stone)); + } + + private static Cell FindNearbyWaterCell(Cell[] cells, int x, int z, int landmassIdx, int radius) + { + return cells.FirstOrDefault(cell => IsResourceInRadius(cell, x, z, radius, ResourceType.Water)); + } + private static Cell FindClearCell(Cell[] cells, int x, int z, int landmassIdx, int radius) + { + return cells.FirstOrDefault(cell => IsResourceInRadius(cell, x, z, radius, ResourceType.None)); + } + + private static bool IsResourceInRadius(Cell cell, int x, int z, int radius, ResourceType desiredResource) + { + bool isWithinRadius = Math.Sqrt((cell.x - x) * (cell.x - x) + (cell.z - z) * (cell.z - z)) <= radius; + bool isNotCentralCell = cell.x != x || cell.z != z; + bool isStoneType = cell.Type == desiredResource; + + bool isWater = desiredResource == ResourceType.Water ? false : cell.deepWater || cell.Type == ResourceType.Water; + + return isWithinRadius && isNotCentralCell && isStoneType && !isWater; + } + + public override void HandlePacketServer() + { + } + } +} diff --git a/Packets/Game/SetSpeed.cs b/Packets/Game/SetSpeed.cs new file mode 100644 index 0000000..319eb71 --- /dev/null +++ b/Packets/Game/SetSpeed.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets.Game +{ + public class SetSpeed : Packet + { + public override ushort packetId => (int)Enums.Packets.SetSpeed; + + public int speed { get; set; } + + public override void HandlePacketClient() + { + if (clientId == KCClient.client.Id) // Prevent speed softlock + return; + + SpeedControlUI.inst.SetSpeed(speed); + } + + public override void HandlePacketServer() + { + //throw new NotImplementedException(); + } + } +} diff --git a/Packets/Handlers/LobbyHandler.cs b/Packets/Handlers/LobbyHandler.cs new file mode 100644 index 0000000..34f8a74 --- /dev/null +++ b/Packets/Handlers/LobbyHandler.cs @@ -0,0 +1,203 @@ +using Assets.Code; +using Assets; +using KCM.Attributes; +using KCM.Packets.Lobby; +using KCM.Packets.Network; +using KCM.ServerLobby; +using KCM.ServerLobby.LobbyChat; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; +using UnityEngine.UI; +using System.Reflection; + +namespace KCM.Packets.Handlers +{ + public class LobbyHandler + { + public static ServerSettings ServerSettings = new ServerSettings(); + + public static List playerEntries = new List(); + + + public static void ClearPlayerList() + { + try + { + foreach (GameObject entry in playerEntries) + GameObject.Destroy(entry); + + playerEntries.Clear(); + + if (!KCServer.IsRunning) + { + Main.kCPlayers.Clear(); + } + } + catch (Exception ex) + { + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + if (ex.InnerException != null) + { + Main.helper.Log("----------------------- Inner exception -----------------------"); + Main.helper.Log(ex.InnerException.ToString()); + Main.helper.Log("----------------------- Inner message -----------------------"); + Main.helper.Log(ex.InnerException.Message); + Main.helper.Log("----------------------- Inner stacktrace -----------------------"); + Main.helper.Log(ex.InnerException.StackTrace); + } + } + } + + public static void AddPlayerEntry(ushort client) + { + try + { + GameObject entry = GameObject.Instantiate(PrefabManager.serverLobbyPlayerEntryPrefab, ServerLobbyScript.PlayerListContent); + entry.SetActive(true); + Main.helper.Log(entry.ToString()); + var s = entry.AddComponent(); + + s.Client = client; + + playerEntries.Add(entry); + } + catch (Exception ex) + { + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + if (ex.InnerException != null) + { + Main.helper.Log("----------------------- Inner exception -----------------------"); + Main.helper.Log(ex.InnerException.ToString()); + Main.helper.Log("----------------------- Inner message -----------------------"); + Main.helper.Log(ex.InnerException.Message); + Main.helper.Log("----------------------- Inner stacktrace -----------------------"); + Main.helper.Log(ex.InnerException.StackTrace); + } + } + } + + public static void AddSystemMessage(string message) + { + try + { + GameObject entry = GameObject.Instantiate(PrefabManager.serverChatSystemEntryPrefab, ServerLobbyScript.PlayerChatContent); + entry.SetActive(true); + chatEntries.Add(entry); + var s = entry.AddComponent(); + + + SnapTo(entry.GetComponent()); + + s.Message = message; + } + catch (Exception ex) + { + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + if (ex.InnerException != null) + { + Main.helper.Log("----------------------- Inner exception -----------------------"); + Main.helper.Log(ex.InnerException.ToString()); + Main.helper.Log("----------------------- Inner message -----------------------"); + Main.helper.Log(ex.InnerException.Message); + Main.helper.Log("----------------------- Inner stacktrace -----------------------"); + Main.helper.Log(ex.InnerException.StackTrace); + } + } + } + + public static List chatEntries = new List(); + + public static void AddChatMessage(ushort client, string player, string message) + { + try + { + GameObject entry = GameObject.Instantiate(PrefabManager.serverChatEntryPrefab, ServerLobbyScript.PlayerChatContent); + entry.SetActive(true); + + chatEntries.Add(entry); + + var s = entry.AddComponent(); + + SnapTo(entry.GetComponent()); + + s.Client = client; + s.PlayerName = player; + s.Message = message; + } + catch (Exception ex) + { + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + if (ex.InnerException != null) + { + Main.helper.Log("----------------------- Inner exception -----------------------"); + Main.helper.Log(ex.InnerException.ToString()); + Main.helper.Log("----------------------- Inner message -----------------------"); + Main.helper.Log(ex.InnerException.Message); + Main.helper.Log("----------------------- Inner stacktrace -----------------------"); + Main.helper.Log(ex.InnerException.StackTrace); + } + } + } + + public static void ClearChatEntries() + { + try + { + foreach (GameObject entry in chatEntries) + GameObject.Destroy(entry); + + chatEntries.Clear(); + } + catch (Exception ex) + { + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + if (ex.InnerException != null) + { + Main.helper.Log("----------------------- Inner exception -----------------------"); + Main.helper.Log(ex.InnerException.ToString()); + Main.helper.Log("----------------------- Inner message -----------------------"); + Main.helper.Log(ex.InnerException.Message); + Main.helper.Log("----------------------- Inner stacktrace -----------------------"); + Main.helper.Log(ex.InnerException.StackTrace); + } + } + } + + + public static void SnapTo(RectTransform target) + { + Canvas.ForceUpdateCanvases(); + + target.parent.parent.parent.GetComponent().normalizedPosition = new Vector2(0, 0); + } + } +} diff --git a/Packets/Handlers/PacketHandler.cs b/Packets/Handlers/PacketHandler.cs new file mode 100644 index 0000000..bbafed6 --- /dev/null +++ b/Packets/Handlers/PacketHandler.cs @@ -0,0 +1,604 @@ +using KCM.Attributes; +using KCM.Packets.Lobby; +using KCM.Packets.Network; +using Riptide; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; + +namespace KCM.Packets.Handlers +{ + public class PacketHandler + { + public static Dictionary Packets = new Dictionary(); + public class PacketRef + { + public IPacket packet; + public PropertyInfo[] properties; + + public PacketRef(IPacket packet, PropertyInfo[] properties) + { + this.packet = packet; + this.properties = properties; + } + } + + + public static Dictionary PacketHandlers = new Dictionary(); + public delegate void PacketHandlerDelegate(IPacket packet); + + public static void Initialise() + { + try + { + Main.helper.Log("Loading Packet Handlers..."); + + //TO-DO Remove this. Packets now have "handle packet" method + #region "Register server packet handlers" + + var serverPacketHandlers = Assembly.GetExecutingAssembly().GetTypes().SelectMany(t => t.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static)) + .Where(m => m.GetCustomAttributes(typeof(PacketHandlerAttribute), false).Length > 0) + .ToArray(); + + foreach (MethodInfo method in serverPacketHandlers) + { + PacketHandlerAttribute attribute = method.GetCustomAttribute(); + + + if (!method.IsStatic) + throw new NonStaticHandlerException(method.DeclaringType, method.Name); + + Delegate packetHandler = Delegate.CreateDelegate(typeof(PacketHandlerDelegate), method, false); + if (packetHandler != null) + { + // It's a message handler for Client instances + if (PacketHandlers.ContainsKey(attribute.packetId)) + { + MethodInfo otherMethodWithId = PacketHandlers[attribute.packetId].GetMethodInfo(); + throw new DuplicateHandlerException(attribute.packetId, method, otherMethodWithId); + } + else + PacketHandlers.Add(attribute.packetId, (PacketHandlerDelegate)packetHandler); + } + else + { + Main.helper.Log($"Failed to register handler: {method.Name}"); + } + } + + Main.helper.Log($"Loaded {PacketHandlers.Count} server handlers"); + + #endregion + + + Main.helper.Log("Loading packets..."); + + var packets = Assembly.GetExecutingAssembly().GetTypes().Where(t => t != null && t.Namespace != null && t.Namespace.StartsWith("KCM.Packets") && !t.IsAbstract && !t.IsInterface && typeof(IPacket).IsAssignableFrom(t)).ToList(); + + foreach (var packet in packets) + { + + IPacket p = (IPacket)Activator.CreateInstance(packet); + var properties = packet.GetProperties(BindingFlags.Instance | BindingFlags.Public).Where(prop => prop.Name != "packetId").ToArray(); + Array.Sort(properties, (x, y) => String.Compare(x.Name, y.Name)); + ushort id = (ushort)p.GetType().GetProperty("packetId").GetValue(p, null); + + if (p.GetType() == typeof(SaveTransferPacket)) + { + Main.helper.Log("SaveTransferPacket"); + Main.helper.Log(string.Join("\n", properties.Select(x => x.Name).ToArray())); + } + + Packets.Add(id, new PacketRef(p, properties)); + Main.helper.Log($"- Registered Packet: {id} {packet.FullName}"); + + } + + Main.helper.Log($"Loaded {Packets.Count} packets"); + } + catch (Exception ex) + { + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + if (ex.InnerException != null) + { + Main.helper.Log("----------------------- Inner exception -----------------------"); + Main.helper.Log(ex.InnerException.ToString()); + Main.helper.Log("----------------------- Inner message -----------------------"); + Main.helper.Log(ex.InnerException.Message); + Main.helper.Log("----------------------- Inner stacktrace -----------------------"); + Main.helper.Log(ex.InnerException.StackTrace); + } + } + } + + public static void HandlePacketServer(object sender, MessageReceivedEventArgs messageReceived) + { + var id = messageReceived.MessageId; + + + IPacket packet = DeserialisePacket(messageReceived); + + //Main.helper.Log($"Server Received packet {Packets[id].packet.GetType().Name} from {messageReceived.FromConnection.Id}"); + + + if (KCServer.IsRunning) + { + try + { + packet.HandlePacketServer(); + + ((Packet)packet).SendToAll(); + } + catch (Exception ex) + { + Main.helper.Log($"Error handling packet {id} {packet.GetType().Name} from {packet.clientId}"); + + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + if (ex.InnerException != null) + { + Main.helper.Log("----------------------- Inner exception -----------------------"); + Main.helper.Log(ex.InnerException.ToString()); + Main.helper.Log("----------------------- Inner message -----------------------"); + Main.helper.Log(ex.InnerException.Message); + Main.helper.Log("----------------------- Inner stacktrace -----------------------"); + Main.helper.Log(ex.InnerException.StackTrace); + } + } + } + } + + public static void HandlePacket(object sender, MessageReceivedEventArgs messageReceived) + { + try + { + var id = messageReceived.MessageId; + + + //Main.helper.Log($"Client Received packet {Packets[id].packet.GetType().Name} from {messageReceived.FromConnection.Id}"); + + IPacket packet = DeserialisePacket(messageReceived); + + //Main.helper.Log($"Client Received packet {Packets[id].packet.GetType().Name} from {packet.clientId}"); + + if (KCClient.client.IsConnected) + { + try + { + packet.HandlePacketClient(); + } + catch (Exception ex) + { + Main.helper.Log($"Error handling packet {id} {packet.GetType().Name} from {packet.clientId}"); + + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + if (ex.InnerException != null) + { + Main.helper.Log("----------------------- Inner exception -----------------------"); + Main.helper.Log(ex.InnerException.ToString()); + Main.helper.Log("----------------------- Inner message -----------------------"); + Main.helper.Log(ex.InnerException.Message); + Main.helper.Log("----------------------- Inner stacktrace -----------------------"); + Main.helper.Log(ex.InnerException.StackTrace); + } + } + } + + /* if (PacketHandlers.TryGetValue(id, out PacketHandlerDelegate handler)) + handler(packet);*/ + + // Main.helper.Log($"{(KCServer.IsRunning ? "Server" : "Client")} Received packet {id} {packet.GetType().kingdomName}"); + //Main.helper.Log($"Found handler: {(handler != null).ToString()}"); + } + catch + { + + } + } + + public static Message SerialisePacket(IPacket packet) + { + + var currentPropName = ""; + try + { + var packetRef = Packets[packet.packetId]; + Message message = Message.Create(MessageSendMode.Reliable, packet.packetId); + + foreach (var prop in packetRef.properties) + { + if (prop.PropertyType.IsEnum) + { + currentPropName = prop.Name; + message.AddInt((int)prop.GetValue(packet, null)); + } + else if (prop.PropertyType == typeof(ushort)) + { + currentPropName = prop.Name; + message.AddUShort((ushort)prop.GetValue(packet, null)); + } + else if (prop.PropertyType == typeof(bool)) + { + currentPropName = prop.Name; + message.AddBool((bool)prop.GetValue(packet, null)); + } + else if (prop.PropertyType == typeof(int)) + { + currentPropName = prop.Name; + message.AddInt((int)prop.GetValue(packet, null)); + } + else if (prop.PropertyType == typeof(string)) + { + currentPropName = prop.Name; + message.AddString((string)prop.GetValue(packet, null)); + } + else if (prop.PropertyType == typeof(float)) + { + currentPropName = prop.Name; + message.AddFloat((float)prop.GetValue(packet, null)); + } + else if (prop.PropertyType == typeof(double)) + { + currentPropName = prop.Name; + message.AddDouble((double)prop.GetValue(packet, null)); + } + else if (prop.PropertyType == typeof(byte[])) + { + currentPropName = prop.Name; + byte[] bytes = (byte[])prop.GetValue(packet, null); + message.AddBytes(bytes, true); + } + else if (prop.PropertyType == typeof(List)) + { + currentPropName = prop.Name; + List list = (List)prop.GetValue(packet, null); + message.AddInt(list.Count); + foreach (var item in list) + message.AddString(item); + } + else if (prop.PropertyType == typeof(List)) + { + currentPropName = prop.Name; + List list = (List)prop.GetValue(packet, null); + message.AddInt(list.Count); + foreach (var item in list) + message.AddBool(item); + } + else if (prop.PropertyType == typeof(List)) + { + currentPropName = prop.Name; + List list = (List)prop.GetValue(packet, null); + message.AddInt(list.Count); + foreach (var item in list) + message.AddUShort(item); + } + else if (prop.PropertyType == typeof(List)) + { + currentPropName = prop.Name; + List list = (List)prop.GetValue(packet, null); + message.AddInt(list.Count); + foreach (var item in list) + message.AddInt(item); + } + + else if (prop.PropertyType.IsGenericType && prop.PropertyType.GetGenericTypeDefinition() == typeof(Dictionary<,>)) + { + currentPropName = prop.Name; + Type[] argumentTypes = prop.PropertyType.GetGenericArguments(); + Type keyType = argumentTypes[0]; + Type valueType = argumentTypes[1]; + + object dictionary = prop.GetValue(packet, null); + + int count = (int)dictionary.GetType().GetProperty("Count").GetValue(dictionary, null); + + var enumerator = ((IEnumerable)dictionary).GetEnumerator(); + while (enumerator.MoveNext()) + { + object key = enumerator.Current.GetType().GetProperty("Key").GetValue(enumerator.Current, null); + object value = enumerator.Current.GetType().GetProperty("Value").GetValue(enumerator.Current, null); + + Main.helper.Log($"Key: {key.GetType()}, Value: {value.GetType()}"); + } + } + else if (prop.PropertyType == typeof(Vector3)) + { + currentPropName = prop.Name; + Vector3 vector = (Vector3)prop.GetValue(packet, null); + message.AddFloat(vector.x); + message.AddFloat(vector.y); + message.AddFloat(vector.z); + } + else if (prop.PropertyType == typeof(Quaternion)) + { + currentPropName = prop.Name; + Quaternion quaternion = (Quaternion)prop.GetValue(packet, null); + message.AddFloat(quaternion.x); + message.AddFloat(quaternion.y); + message.AddFloat(quaternion.z); + message.AddFloat(quaternion.w); + } + else if (prop.PropertyType == typeof(Guid)) + { + currentPropName = prop.Name; + Guid guid = (Guid)prop.GetValue(packet, null); + message.AddBytes(guid.ToByteArray()); + } + else if (prop.PropertyType.IsGenericType && prop.PropertyType.GetGenericTypeDefinition() == typeof(List<>)) + { + currentPropName = prop.Name; + + Type itemType = prop.PropertyType.GetGenericArguments()[0]; + + var list = prop.GetValue(packet, null) as System.Collections.IList; + if (list != null) + { + message.AddInt(list.Count); + + foreach (var item in list) + { + if (itemType.IsClass && itemType != typeof(string) || itemType.IsValueType && !itemType.IsPrimitive) + { + var fields = itemType.GetFields(); // Get fields + Array.Sort(fields, (x, y) => String.Compare(x.Name, y.Name)); + var properties = itemType.GetProperties(); // Get properties + Array.Sort(properties, (x, y) => String.Compare(x.Name, y.Name)); + + + // Serialize fields + foreach (var field in fields) + { + var fieldValue = field.GetValue(item); + AddDynamic(message, fieldValue); + } + + // Serialize properties + foreach (var property in properties) + { + var propertyValue = property.GetValue(item); + AddDynamic(message, propertyValue); + } + } + else + { + AddDynamic(message, item); + } + } + } + } + // You can add more types as needed + } + + + return message; + } + catch (Exception ex) + { + Main.helper.Log($"Failed to serialise packet {packet.packetId} {packet.GetType().Name} at {currentPropName}"); + + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + if (ex.InnerException != null) + { + Main.helper.Log("----------------------- Inner exception -----------------------"); + Main.helper.Log(ex.InnerException.ToString()); + Main.helper.Log("----------------------- Inner message -----------------------"); + Main.helper.Log(ex.InnerException.Message); + Main.helper.Log("----------------------- Inner stacktrace -----------------------"); + Main.helper.Log(ex.InnerException.StackTrace); + } + } + return null; + } + + static void AddDynamic(Message message, object value) + { + if (value is int intValue) + message.AddInt(intValue); + else if (value is string stringValue) + message.AddString(stringValue); + else if (value is bool boolValue) + message.AddBool(boolValue); + else if (value is float floatValue) + message.AddFloat(floatValue); + else if (value is double doubleValue) + message.AddDouble(doubleValue); + else if (value is Vector3 vector) + { + message.AddFloat(vector.x); + message.AddFloat(vector.y); + message.AddFloat(vector.z); + } + else if (value is Quaternion quaternion) + { + message.AddFloat(quaternion.x); + message.AddFloat(quaternion.y); + message.AddFloat(quaternion.z); + message.AddFloat(quaternion.w); + } + else if (value is Guid guid) + message.AddBytes(guid.ToByteArray()); + // Add more type checks as necessary + else + throw new NotImplementedException($"Type {value.GetType()} serialization not implemented."); + } + + + public static IPacket DeserialisePacket(MessageReceivedEventArgs messageReceived) + { + try + { + var message = messageReceived.Message; + var packetRef = Packets[messageReceived.MessageId]; + IPacket p = (IPacket)Activator.CreateInstance(packetRef.packet.GetType()); + + + foreach (var prop in packetRef.properties) + { + if (prop.PropertyType.IsEnum) + { + int enumValue = message.GetInt(); + string enumName = Enum.GetName(prop.PropertyType, enumValue); + + prop.SetValue(p, Enum.Parse(prop.PropertyType, enumName)); + } + else if (prop.PropertyType == typeof(ushort)) + { + prop.SetValue(p, message.GetUShort()); + } + else if (prop.PropertyType == typeof(bool)) + { + prop.SetValue(p, message.GetBool()); + } + else if (prop.PropertyType == typeof(int)) + { + prop.SetValue(p, message.GetInt()); + } + else if (prop.PropertyType == typeof(string)) + { + prop.SetValue(p, message.GetString()); + } + else if (prop.PropertyType == typeof(float)) + { + prop.SetValue(p, message.GetFloat()); + } + else if (prop.PropertyType == typeof(double)) + { + prop.SetValue(p, message.GetDouble()); + } + else if (prop.PropertyType == typeof(byte[])) + { + byte[] bytes = message.GetBytes(); + + prop.SetValue(p, bytes); + } + else if (prop.PropertyType == typeof(List)) + { + int count = message.GetInt(); + List list = new List(); + + for (int i = 0; i < count; i++) + list.Add(message.GetString()); + + prop.SetValue(p, list); + } + else if (prop.PropertyType == typeof(List)) + { + int count = message.GetInt(); + List list = new List(); + + for (int i = 0; i < count; i++) + list.Add(message.GetBool()); + + prop.SetValue(p, list); + } + else if (prop.PropertyType == typeof(List)) + { + int count = message.GetInt(); + List list = new List(); + + for (int i = 0; i < count; i++) + list.Add(message.GetUShort()); + + prop.SetValue(p, list); + } + else if (prop.PropertyType == typeof(List)) + { + int count = message.GetInt(); + List list = new List(); + + for (int i = 0; i < count; i++) + list.Add(message.GetInt()); + + prop.SetValue(p, list); + } + else if (prop.PropertyType.IsGenericType && prop.PropertyType.GetGenericTypeDefinition() == typeof(Dictionary<,>)) + { + IDictionary dictionary = (IDictionary)prop.GetValue(p, null); + Type[] argumentTypes = prop.PropertyType.GetGenericArguments(); + Type keyType = argumentTypes[0]; + Type valueType = argumentTypes[1]; + + + message.AddInt(dictionary.Count); + + foreach (DictionaryEntry entry in dictionary) + { + + //Serialize(entry.Key, message); // Implement this method based on the type of 'Key' + //Serialize(entry.Value, message); // Implement this method based on the type of 'Value' + } + } + else if (prop.PropertyType == typeof(Vector3)) + { + Vector3 vector = new Vector3(message.GetFloat(), message.GetFloat(), message.GetFloat()); + prop.SetValue(p, vector); + } + else if (prop.PropertyType == typeof(Quaternion)) + { + Quaternion quaternion = new Quaternion(message.GetFloat(), message.GetFloat(), message.GetFloat(), message.GetFloat()); + prop.SetValue(p, quaternion); + } + else if (prop.PropertyType == typeof(Guid)) + { + Guid guid = new Guid(message.GetBytes()); + prop.SetValue(p, guid); + } + // You can add more types as needed + } + + if (KCServer.IsRunning) + { + //if (!p.GetType().Name.Contains("Update")) + //Main.helper.Log($"Received packet {messageReceived.MessageId} {p.GetType().Name} from {messageReceived.FromConnection.Id}"); + //Main.helper.Log("Setting packet client id to: " + messageReceived.FromConnection.Id + " for packet: " + p.GetType().Name); + //p.clientId = messageReceived.FromConnection.Id; + } + + return p; + } + catch (Exception ex) + { + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + if (ex.InnerException != null) + { + Main.helper.Log("----------------------- Inner exception -----------------------"); + Main.helper.Log(ex.InnerException.ToString()); + Main.helper.Log("----------------------- Inner message -----------------------"); + Main.helper.Log(ex.InnerException.Message); + Main.helper.Log("----------------------- Inner stacktrace -----------------------"); + Main.helper.Log(ex.InnerException.StackTrace); + } + } + return null; + } + } +} diff --git a/Packets/IPacket.cs b/Packets/IPacket.cs new file mode 100644 index 0000000..efd47e3 --- /dev/null +++ b/Packets/IPacket.cs @@ -0,0 +1,17 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets +{ + public interface IPacket + { + ushort packetId { get; } + ushort clientId { get; set; } + + void HandlePacketServer(); + void HandlePacketClient(); + } +} diff --git a/Packets/Lobby/ChatMessage.cs b/Packets/Lobby/ChatMessage.cs new file mode 100644 index 0000000..300c0cb --- /dev/null +++ b/Packets/Lobby/ChatMessage.cs @@ -0,0 +1,32 @@ +using KCM.Packets.Handlers; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets.Lobby +{ + public class ChatMessage : Packet + { + public override ushort packetId => (int)Enums.Packets.ChatMessage; + + public string PlayerName { get; set; } + public string Message { get; set; } + + public override void HandlePacketServer() + { + //Main.helper.Log("Received chat packet: " + Message); + + //SendToAll(KCClient.client.Id); + //LobbyHandler.AddChatMessage(clientId, PlayerName, Message); + } + + public override void HandlePacketClient() + { + Main.helper.Log("Received chat packet: " + Message); + + LobbyHandler.AddChatMessage(clientId, PlayerName, Message); + } + } +} diff --git a/Packets/Lobby/ChatSystemMessage.cs b/Packets/Lobby/ChatSystemMessage.cs new file mode 100644 index 0000000..702af7e --- /dev/null +++ b/Packets/Lobby/ChatSystemMessage.cs @@ -0,0 +1,26 @@ +using KCM.Packets.Handlers; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets.Lobby +{ + public class ChatSystemMessage : Packet + { + public override ushort packetId => (int)Enums.Packets.ChatSystemMessage; + + public string Message { get; set; } + + public override void HandlePacketServer() + { + //LobbyHandler.AddSystemMessage(Message); + } + + public override void HandlePacketClient() + { + LobbyHandler.AddSystemMessage(Message); + } + } +} diff --git a/Packets/Lobby/KingdomName.cs b/Packets/Lobby/KingdomName.cs new file mode 100644 index 0000000..57f4ac2 --- /dev/null +++ b/Packets/Lobby/KingdomName.cs @@ -0,0 +1,38 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets.Lobby +{ + public class KingdomName : Packet + { + public override ushort packetId => (int)Enums.Packets.KingdomName; + + public string kingdomName { get; set; } + + public override void HandlePacketServer() + { + if (player == null) + return; + Main.helper.Log("Received kingdom name packet"); + + //SendToAll(KCClient.client.Id); + + player.kingdomName = kingdomName; + } + + public override void HandlePacketClient() + { + if (player == null) + return; + + Main.helper.Log("Received kingdom name packet"); + + player.kingdomName = kingdomName; + + Main.helper.Log($"Player {player.name} has joined with their kingdom {player.kingdomName}"); + } + } +} diff --git a/Packets/Lobby/PlayerBanner.cs b/Packets/Lobby/PlayerBanner.cs new file mode 100644 index 0000000..398a1c5 --- /dev/null +++ b/Packets/Lobby/PlayerBanner.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets.Lobby +{ + public class PlayerBanner : Packet + { + public override ushort packetId => (int)Enums.Packets.PlayerBanner; + + public int banner { get; set; } + + public override void HandlePacketServer() + { + //SendToAll(KCClient.client.Id); + + //player.banner = banner; + //player.inst.PlayerLandmassOwner.SetBannerIdx(banner); + } + + public override void HandlePacketClient() + { + player.banner = banner; + player.inst.PlayerLandmassOwner.SetBannerIdx(banner); + + Main.helper.Log($"Player {clientId} ({player.id}) has set banner to {player.banner}"); + } + } +} diff --git a/Packets/Lobby/PlayerList.cs b/Packets/Lobby/PlayerList.cs new file mode 100644 index 0000000..3597da8 --- /dev/null +++ b/Packets/Lobby/PlayerList.cs @@ -0,0 +1,58 @@ +using KCM.Packets.Handlers; +using Steamworks; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets.Lobby +{ + public class PlayerList : Packet + { + public override ushort packetId => (int)Enums.Packets.PlayerList; + + public List playersReady { get; set; } + public List playersName { get; set; } + public List playersBanner { get; set; } + public List playersId { get; set; } + public List playersKingdomName { get; set; } + + public List steamIds { get; set; } + + public override void HandlePacketServer() + { + + } + + public override void HandlePacketClient() + { + LobbyHandler.ClearPlayerList(); + + for (int i = 0; i < playersId.Count; i++) + { + + Main.helper.Log("PlayerList: " + playersName[i] + " " + playersId[i] + " " + steamIds[i]); + + Main.kCPlayers.Add(steamIds[i], new KCPlayer(playersName[i], playersId[i], steamIds[i]) + { + name = playersName[i], + ready = playersReady[i], + banner = playersBanner[i], + kingdomName = playersKingdomName[i] + }); + + + if (Main.clientSteamIds.ContainsKey(playersId[i])) + Main.clientSteamIds[playersId[i]] = steamIds[i]; + else + Main.clientSteamIds.Add(playersId[i], steamIds[i]); + + Main.kCPlayers[steamIds[i]].inst.PlayerLandmassOwner.SetBannerIdx(playersBanner[i]); + + LobbyHandler.AddPlayerEntry(playersId[i]); + } + } + } +} diff --git a/Packets/Lobby/PlayerReady.cs b/Packets/Lobby/PlayerReady.cs new file mode 100644 index 0000000..e8957f5 --- /dev/null +++ b/Packets/Lobby/PlayerReady.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets.Lobby +{ + public class PlayerReady : Packet + { + public override ushort packetId => (int)Enums.Packets.PlayerReady; + + public bool IsReady { get; set; } + + public override void HandlePacketServer() + { + IsReady = !player.ready; + //SendToAll(KCClient.client.Id); + + player.ready = IsReady; + } + + public override void HandlePacketClient() + { + player.ready = IsReady; + } + } +} diff --git a/Packets/Lobby/SaveTransferPacket.cs b/Packets/Lobby/SaveTransferPacket.cs new file mode 100644 index 0000000..18282e6 --- /dev/null +++ b/Packets/Lobby/SaveTransferPacket.cs @@ -0,0 +1,114 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using static KCM.Main; + +namespace KCM.Packets.Lobby +{ + public class SaveTransferPacket : Packet + { + public override ushort packetId => (ushort)Enums.Packets.SaveTransferPacket; + + public static byte[] saveData = new byte[1]; + public static bool[] chunksReceived = new bool[1]; + public static bool loadingSave = false; + public static int received = 0; + + + public int chunkId { get; set; } + public int chunkSize { get; set; } + + public int saveSize { get; set; } + public int saveDataIndex { get; set; } + public int totalChunks { get; set; } + + public byte[] saveDataChunk { get; set; } + + public override void HandlePacketClient() + { + float savePercent = (float)received / (float)saveSize; + + // Initialize saveData and chunksReceived on the first packet received + if (saveData.Length == 1) + { + + Main.helper.Log("Save Transfer started!"); + loadingSave = true; + + ServerLobbyScript.LoadingSave.SetActive(true); + + // save percentage + + + saveData = new byte[saveSize]; + chunksReceived = new bool[totalChunks]; + } + + + // Copy the chunk data into the correct position in saveData + Array.Copy(saveDataChunk, 0, saveData, saveDataIndex, saveDataChunk.Length); + + // Mark this chunk as received + chunksReceived[chunkId] = true; + + // Seek to the next position to write to + received += chunkSize; + + + ServerLobbyScript.ProgressBar.fillAmount = savePercent; + ServerLobbyScript.ProgressBarText.text = (savePercent * 100).ToString("0.00") + "%"; + ServerLobbyScript.ProgressText.text = $"{((float)(received / 1000)).ToString("0.00")} KB / {((float)(saveSize / 1000)).ToString("0.00")} KB"; + + + if (chunkId + 1 == totalChunks) + { + Main.helper.Log($"Received last save transfer packet."); + + Main.helper.Log(WhichIsNotComplete()); + } + + // Check if all chunks have been received + if (IsTransferComplete()) + { + // Handle completed transfer here + Main.helper.Log("Save Transfer complete!"); + + LoadSaveLoadHook.saveBytes = saveData; + LoadSaveLoadHook.memoryStreamHook = true; + + LoadSave.Load(); + + + LoadSaveLoadHook.saveContainer.Unpack(null); + Broadcast.OnLoadedEvent.Broadcast(new OnLoadedEvent()); + + ServerLobbyScript.LoadingSave.SetActive(false); + } + } + + public static bool IsTransferComplete() + { + return chunksReceived.All(x => x == true); + } + + public static string WhichIsNotComplete() + { + string notComplete = ""; + for (int i = 0; i < chunksReceived.Length; i++) + { + if (!chunksReceived[i]) + { + notComplete += i + ", "; + } + } + return notComplete; + } + + public override void HandlePacketServer() + { + } + } +} diff --git a/Packets/Lobby/ServerSettings.cs b/Packets/Lobby/ServerSettings.cs new file mode 100644 index 0000000..3caacd8 --- /dev/null +++ b/Packets/Lobby/ServerSettings.cs @@ -0,0 +1,48 @@ +using KCM.Packets.Handlers; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets.Lobby +{ + public class ServerSettings : Packet + { + public override ushort packetId => (int)Enums.Packets.ServerSettings; + + public string ServerName { get; set; } + public int MaxPlayers { get; set; } + public bool Locked { get; set; } + public string Password { get; set; } + public int Difficulty { get; set; } + public string WorldSeed { get; set; } + public World.MapSize WorldSize { get; set; } + public World.MapBias WorldType { get; set; } + public World.MapRiverLakes WorldRivers { get; set; } + public int PlacementType { get; set; } + public bool FogOfWar { get; set; } + + public ServerSettings() { this.MaxPlayers = 2; this.Password = " "; this.WorldRivers = World.MapRiverLakes.Some; } + + public override void HandlePacketServer() + { + //SetServerSettings(); + } + + public override void HandlePacketClient() + { + SetServerSettings(); + } + + public void SetServerSettings() + { + + LobbyHandler.ServerSettings = this; + + World.inst.mapSize = WorldSize; + World.inst.mapBias = WorldType; + World.inst.mapRiverLakes = WorldRivers; + } + } +} diff --git a/Packets/Lobby/StartGame.cs b/Packets/Lobby/StartGame.cs new file mode 100644 index 0000000..eefcfcf --- /dev/null +++ b/Packets/Lobby/StartGame.cs @@ -0,0 +1,103 @@ +using KCM.Enums; +using Riptide.Demos.Steam.PlayerHosted; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; + +namespace KCM.Packets.Lobby +{ + public class StartGame : Packet + { + public override ushort packetId => (int)Enums.Packets.StartGame; + + public void Start() + { + Main.helper.Log(GameState.inst.mainMenuMode.ToString()); + + // Hide server lobby + Main.TransitionTo((MenuState)200); + + // This is run when user clicks "accept" on choose your map screeen + + try + { + if (!LobbyManager.loadingSave) + { + SpeedControlUI.inst.SetSpeed(0); + + try + { + typeof(MainMenuMode).GetMethod("StartGame", BindingFlags.NonPublic | BindingFlags.Instance).Invoke(GameState.inst.mainMenuMode, null); + } + catch (Exception ex) + { + Main.helper.Log(ex.Message.ToString()); + Main.helper.Log(ex.ToString()); + } + + SpeedControlUI.inst.SetSpeed(0); + } + else + { + LobbyManager.loadingSave = false; + GameState.inst.SetNewMode(GameState.inst.playingMode); + } + } + catch (Exception ex) + { + // Handle exception here + Main.helper.Log(ex.Message.ToString()); + Main.helper.Log(ex.ToString()); + } + } + + public override void HandlePacketClient() + { + Start(); + } + + public override void HandlePacketServer() + { + //Start(); + + + /*AIBrainsContainer.PreStartAIConfig aiConfig = new AIBrainsContainer.PreStartAIConfig(); + int count = 0; + for (int i = 0; i < RivalKingdomSettingsUI.inst.rivalItems.Length; i++) + { + RivalItemUI r = RivalKingdomSettingsUI.inst.rivalItems[i]; + bool flag = r.Enabled && !r.Locked; + if (flag) + { + count++; + } + } + int idx = 0; + aiConfig.startData = new AIBrainsContainer.PreStartAIConfig.AIStartData[count]; + for (int j = 0; j < RivalKingdomSettingsUI.inst.rivalItems.Length; j++) + { + RivalItemUI item = RivalKingdomSettingsUI.inst.rivalItems[j]; + bool flag2 = item.Enabled && !item.Locked; + if (flag2) + { + aiConfig.startData[idx] = new AIBrainsContainer.PreStartAIConfig.AIStartData(); + aiConfig.startData[idx].landmass = item.flag.landmass; + aiConfig.startData[idx].bioCode = item.bannerIdx; + aiConfig.startData[idx].personalityKey = PersonalityCollection.aiPersonalityKeys[0]; + aiConfig.startData[idx].skillLevel = item.GetSkillLevel(); + idx++; + } + } + AIBrainsContainer.inst.aiStartInfo = aiConfig; + bool isControllerActive = GamepadControl.inst.isControllerActive; + if (isControllerActive) + { + ConsoleCursorMenu.inst.PrepForGamepad(); + }*/ + } + } +} diff --git a/Packets/Lobby/WorldSeed.cs b/Packets/Lobby/WorldSeed.cs new file mode 100644 index 0000000..38655c9 --- /dev/null +++ b/Packets/Lobby/WorldSeed.cs @@ -0,0 +1,46 @@ +using KCM.Packets.Handlers; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; + +namespace KCM.Packets.Lobby +{ + public class WorldSeed : Packet + { + public override ushort packetId => (int)Enums.Packets.WorldSeed; + public int Seed { get; set; } + + public override void HandlePacketServer() + { + //SetWorldSeed(); + } + + public override void HandlePacketClient() + { + SetWorldSeed(); + } + + public void SetWorldSeed() + { + try + { + foreach (var player in Main.kCPlayers.Values) + player.inst.Reset(); + + World.inst.Generate(Seed); + Vector3 center = World.inst.GetCellData(World.inst.GridWidth / 2, World.inst.GridHeight / 2).Center; + Cam.inst.SetTrackingPos(center); + } + catch (Exception e) + { + Main.helper.Log("Set world seed packet error"); + Main.helper.Log(e.Message); + Main.helper.Log(e.StackTrace); + Main.helper.Log(e.ToString()); + } + } + } +} diff --git a/Packets/Network/ClientConnected.cs b/Packets/Network/ClientConnected.cs new file mode 100644 index 0000000..6c55639 --- /dev/null +++ b/Packets/Network/ClientConnected.cs @@ -0,0 +1,145 @@ +using KCM.Packets.Handlers; +using KCM.Packets.Lobby; +using Riptide.Demos.Steam.PlayerHosted; +using Steamworks; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using static KCM.Main; + +namespace KCM.Packets.Network +{ + public class ClientConnected : Packet + { + public override ushort packetId => (int)Enums.Packets.ClientConnected; + + public string Name { get; set; } + + public string SteamId { get; set; } + + public override void HandlePacketClient() + { + + Main.helper.Log("Client Player Connected: " + Name + " Id: " + clientId + " SteamID: " + SteamId); + + KCPlayer player; + if (Main.kCPlayers.TryGetValue(SteamId, out player)) + { + player.id = clientId; + player.name = Name; + player.steamId = SteamId; + } + else + Main.kCPlayers.Add(SteamId, new KCPlayer(Name, clientId, SteamId)); + + + if (Main.clientSteamIds.ContainsKey(clientId)) + Main.clientSteamIds[clientId] = SteamId; + else + Main.clientSteamIds.Add(clientId, SteamId); + + + if (!SaveTransferPacket.loadingSave) + LobbyHandler.AddPlayerEntry(clientId); + } + + public override void HandlePacketServer() + { + Main.helper.Log("Server Player Connected: " + Name + " Id: " + clientId + " SteamID: " + SteamId); + + List list = Main.kCPlayers.Select(x => x.Value).OrderBy(x => x.id).ToList(); + + if (list.Count > 0) + new PlayerList() + { + playersBanner = list.Select(x => x.banner).ToList(), + playersReady = list.Select(x => x.ready).ToList(), + playersName = list.Select(x => x.name).ToList(), + playersId = list.Select(x => x.id).ToList(), + playersKingdomName = list.Select(x => x.kingdomName).ToList(), + steamIds = list.Select(x => x.steamId).ToList() + }.SendToAll(KCClient.client.Id); + + new ChatSystemMessage() + { + Message = $"{Name} has joined the server." + }.SendToAll(); + + LobbyHandler.ServerSettings.SendToAll(KCClient.client.Id); + + + if (LobbyManager.loadingSave) + { + if (clientId == KCClient.client.Id) + return; + + byte[] bytes = LoadSaveLoadAtPathHook.saveData; + int chunkSize = 900; // 900 bytes per chunk to fit within packet size limit + + List chunks = SplitByteArrayIntoChunks(bytes, chunkSize); + Main.helper.Log("Save Transfer started!"); + + int sent = 0; + int packetsSent = 0; + + for (int i = 0; i < chunks.Count; i++) + { + var chunk = chunks[i]; + + + new SaveTransferPacket() + { + saveSize = bytes.Length, + saveDataChunk = chunk, + chunkId = i, + chunkSize = chunk.Length, + saveDataIndex = sent, + totalChunks = chunks.Count + }.Send(clientId); + + Main.helper.Log(" "); + + packetsSent++; + sent += chunk.Length; + } + + Main.helper.Log($"Sent {packetsSent} save data chunks to client"); + } + else + { + + new WorldSeed() + { + Seed = World.inst.seed + }.SendToAll(KCClient.client.Id); + } + } + + public static List SplitByteArrayIntoChunks(byte[] source, int chunkSize) + { + var chunks = new List(); + int sourceLength = source.Length; + + for (int i = 0; i < sourceLength; i += chunkSize) + { + // Calculate the length of the current chunk, as the last chunk may be smaller than chunkSize + int currentChunkSize = Math.Min(chunkSize, sourceLength - i); + + // Create a chunk array of the correct size + byte[] chunk = new byte[currentChunkSize]; + + // Copy a segment of the source array into the chunk array + Array.Copy(source, i, chunk, 0, currentChunkSize); + + // Add the chunk to the list of chunks + chunks.Add(chunk); + } + + return chunks; + } + } +} diff --git a/Packets/Network/ServerHandshake.cs b/Packets/Network/ServerHandshake.cs new file mode 100644 index 0000000..a605357 --- /dev/null +++ b/Packets/Network/ServerHandshake.cs @@ -0,0 +1,69 @@ +using KCM.Enums; +using KCM.Packets.Lobby; +using Riptide; +using Riptide.Demos.Steam.PlayerHosted; +using Steamworks; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using System.Xml.Linq; + +namespace KCM.Packets.Network +{ + public class ServerHandshake : Packet + { + public override ushort packetId => (int)Enums.Packets.ServerHandshake; + + public bool loadingSave { get; set; } + + public override void HandlePacketClient() + { + ModalManager.HideModal(); + + Main.TransitionTo(Enums.MenuState.ServerLobby); + + SfxSystem.PlayUiSelect(); + + Cam.inst.desiredDist = 80f; + Cam.inst.desiredPhi = 45f; + CloudSystem.inst.threshold1 = 0.6f; + CloudSystem.inst.threshold2 = 0.8f; + CloudSystem.inst.BaseFreq = 4.5f; + Weather.inst.SetSeason(Weather.Season.Summer); + + //inst = new KCClient(KCServer.IsRunning ? "Ryan" : "Orion"); + KCClient.inst = new KCClient(SteamFriends.GetPersonaName()); + + Main.helper.Log("Sending client connected. Client ID is: " + clientId); + + Main.kCPlayers.Add(Main.PlayerSteamID, new KCPlayer(KCClient.inst.Name, clientId, Main.PlayerSteamID)); + + Player.inst.PlayerLandmassOwner.teamId = clientId * 10 + 2; + + if (loadingSave && KCServer.IsRunning) + Main.TransitionTo(MenuState.Load); + else if (!loadingSave) + { + Main.TransitionTo(MenuState.NameAndBanner); + + } + + + new KingdomName() { kingdomName = TownNameUI.inst.townName, clientId = clientId }.Send(); + + new ClientConnected() + { + clientId = clientId, + Name = KCClient.inst.Name, + SteamId = Main.PlayerSteamID + }.Send(); + } + + public override void HandlePacketServer() + { + //throw new NotImplementedException(); + } + } +} diff --git a/Packets/Packet.cs b/Packets/Packet.cs new file mode 100644 index 0000000..0764cff --- /dev/null +++ b/Packets/Packet.cs @@ -0,0 +1,113 @@ +using KCM.Packets.Handlers; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets +{ + public abstract class Packet : IPacket + { + public abstract ushort packetId { get; } + public ushort clientId { get; set; } + + public KCPlayer player + { + get + { + KCPlayer p = null; + + if (!Main.clientSteamIds.ContainsKey(clientId)) + return null; + + //Main.helper.Log($"SteamID: {Main.GetPlayerByClientID(clientId).steamId} for {clientId} ({Main.GetPlayerByClientID(clientId).id})"); + + if (Main.kCPlayers.TryGetValue(Main.GetPlayerByClientID(clientId).steamId, out p)) + return p; + else + { + Main.helper.Log($"Error getting player from packet {packetId} {this.GetType().Name} from {clientId}"); + } + + return null; + } + } + + public void SendToAll(ushort exceptToClient = 0) + { + try + { + if (exceptToClient == 0) + { + if (KCServer.IsRunning) + KCServer.server.SendToAll(PacketHandler.SerialisePacket(this)); + } + else + { + if (KCServer.IsRunning && exceptToClient != 0) + KCServer.server.SendToAll(PacketHandler.SerialisePacket(this), exceptToClient); + } + } + catch (Exception ex) + { + Main.helper.Log($"Error sending packet to all {packetId} {this.GetType().Name} from {clientId}"); + + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + if (ex.InnerException != null) + { + Main.helper.Log("----------------------- Inner exception -----------------------"); + Main.helper.Log(ex.InnerException.ToString()); + Main.helper.Log("----------------------- Inner message -----------------------"); + Main.helper.Log(ex.InnerException.Message); + Main.helper.Log("----------------------- Inner stacktrace -----------------------"); + Main.helper.Log(ex.InnerException.StackTrace); + } + } + } + + public void Send(ushort toClient = 0) + { + try + { + if (KCClient.client.IsConnected && toClient == 0) + { + this.clientId = KCClient.client.Id; + KCClient.client.Send(PacketHandler.SerialisePacket(this)); + } + else if (KCServer.IsRunning && toClient != 0) + { + KCServer.server.Send(PacketHandler.SerialisePacket(this), toClient); + } + } + catch (Exception ex) + { + Main.helper.Log($"Error sending packet {packetId} {this.GetType().Name} from {clientId}"); + + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + if (ex.InnerException != null) + { + Main.helper.Log("----------------------- Inner exception -----------------------"); + Main.helper.Log(ex.InnerException.ToString()); + Main.helper.Log("----------------------- Inner message -----------------------"); + Main.helper.Log(ex.InnerException.Message); + Main.helper.Log("----------------------- Inner stacktrace -----------------------"); + Main.helper.Log(ex.InnerException.StackTrace); + } + } + } + + public abstract void HandlePacketServer(); + public abstract void HandlePacketClient(); + } +} diff --git a/Packets/ShowModal.cs b/Packets/ShowModal.cs new file mode 100644 index 0000000..b78d247 --- /dev/null +++ b/Packets/ShowModal.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace KCM.Packets +{ + public class ShowModal : Packet + { + public override ushort packetId => (int)Enums.Packets.ShowModal; + + public string title { get; set; } + public string message { get; set; } + + public override void HandlePacketClient() + { + Main.helper.Log("Opening Modal"); + Main.helper.Log("Title: " + title); + Main.helper.Log("Message: " + message); + + ModalManager.ShowModal(title, message); + } + + public override void HandlePacketServer() + { + //throw new NotImplementedException(); + } + } +} diff --git a/Packets/State/BuildingStatePacket.cs b/Packets/State/BuildingStatePacket.cs new file mode 100644 index 0000000..70ea74b --- /dev/null +++ b/Packets/State/BuildingStatePacket.cs @@ -0,0 +1,115 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; + +namespace KCM.Packets.State +{ + public class BuildingStatePacket : Packet + { + public override ushort packetId => (ushort)Enums.Packets.BuildingStatePacket; + + public string customName { get; set; } + public Guid guid { get; set; } + public string uniqueName { get; set; } + public Quaternion rotation { get; set; } + public Vector3 globalPosition { get; set; } + public Vector3 localPosition { get; set; } + public bool built { get; set; } + public bool placed { get; set; } + public bool open { get; set; } + public bool doBuildAnimation { get; set; } + public bool constructionPaused { get; set; } + public float constructionProgress { get; set; } + public float resourceProgress { get; set; } + public float life { get; set; } + public float ModifiedMaxLife { get; set; } + public int yearBuilt { get; set; } + public float decayProtection { get; set; } + public bool seenByPlayer { get; set; } + + + public override void HandlePacketClient() + { + if (clientId == KCClient.client.Id) return; //prevent double placing on same client + + //Main.helper.Log("Received building state packet for: " + uniqueName + " from " + Main.kCPlayers[Main.GetPlayerByClientID(clientId).steamId].name + $"({clientId})"); + + + Building building = player.inst.GetBuilding(guid); + + if (building == null) + { + Main.helper.Log("Building not found."); + return; + } + + try + { + //PrintProperties(); + + building.UniqueName = uniqueName; + building.customName = customName; + + + building.transform.position = this.globalPosition; + building.transform.GetChild(0).rotation = this.rotation; + building.transform.GetChild(0).localPosition = this.localPosition; + + SetPrivateFieldValue(building, "built", built); + SetPrivateFieldValue(building, "placed", placed); + SetPrivateFieldValue(building, "resourceProgress", resourceProgress); + + + building.Open = open; + building.doBuildAnimation = doBuildAnimation; + building.constructionPaused = constructionPaused; + building.constructionProgress = constructionProgress; + building.Life = life; + building.ModifiedMaxLife = ModifiedMaxLife; + + + //building.yearBuilt = yearBuilt; + SetPrivateFieldValue(building, "yearBuilt", yearBuilt); + + building.decayProtection = decayProtection; + //building.seenByPlayer = seenByPlayer; + } + catch (Exception e) + { + Main.helper.Log("Error setting building state"); + Main.helper.Log(e.Message); + Main.helper.Log(e.StackTrace); + } + } + + public override void HandlePacketServer() + { + //throw new NotImplementedException(); + } + + private void SetPrivateFieldValue(object obj, string fieldName, object value) + { + Type type = obj.GetType(); + FieldInfo field = type.GetField(fieldName, BindingFlags.NonPublic | BindingFlags.Instance); + field.SetValue(obj, value); + } + + public void PrintProperties() + { + Type type = typeof(BuildingStatePacket); + + foreach (PropertyInfo property in type.GetProperties()) + { + object value = property.GetValue(this); + string propertyName = property.Name; + + Main.helper.Log($"{propertyName}: {value}"); + } + } + + } +} diff --git a/PrefabManager.cs b/PrefabManager.cs new file mode 100644 index 0000000..083a765 --- /dev/null +++ b/PrefabManager.cs @@ -0,0 +1,54 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; + +namespace KCM +{ + public class PrefabManager + { + public static AssetBundle assetBundle; + public static GameObject serverBrowserPrefab; + public static GameObject serverEntryItemPrefab; + + public static GameObject serverLobbyPrefab; + public static GameObject serverLobbyPlayerEntryPrefab; + public static GameObject serverChatEntryPrefab; + public static GameObject serverChatSystemEntryPrefab; + + public static GameObject modalUIPrefab; + + public void PreScriptLoad(KCModHelper _helper) + { + try + { + //Main.helper = _helper; + + assetBundle = KCModHelper.LoadAssetBundle(_helper.modPath, "serverbrowserpkg"); + + Main.helper.Log(String.Join(", ", assetBundle.GetAllAssetNames())); + + serverBrowserPrefab = assetBundle.LoadAsset("assets/workspace/serverbrowser.prefab") as GameObject; + serverEntryItemPrefab = assetBundle.LoadAsset("assets/workspace/serverentryitem.prefab") as GameObject; + + + serverLobbyPrefab = assetBundle.LoadAsset("assets/workspace/serverlobby.prefab") as GameObject; + serverLobbyPlayerEntryPrefab = assetBundle.LoadAsset("assets/workspace/serverlobbyplayerentry.prefab") as GameObject; + serverChatEntryPrefab = assetBundle.LoadAsset("assets/workspace/serverchatentry.prefab") as GameObject; + serverChatSystemEntryPrefab = assetBundle.LoadAsset("assets/workspace/serverchatsystementry.prefab") as GameObject; + + modalUIPrefab = assetBundle.LoadAsset("assets/workspace/modalui.prefab") as GameObject; + + Main.helper.Log("Loaded assets"); + } + catch (Exception ex) + { + Main.helper.Log(ex.ToString()); + Main.helper.Log(ex.Message); + Main.helper.Log(ex.StackTrace); + } + } + } +} diff --git a/ReflectionHelper/ReflectionHelper.cs b/ReflectionHelper/ReflectionHelper.cs new file mode 100644 index 0000000..56764d4 --- /dev/null +++ b/ReflectionHelper/ReflectionHelper.cs @@ -0,0 +1,39 @@ +using System; +using System.Collections.Generic; +using System.Reflection; + +public static class ReflectionHelper +{ + public static void ClearPrivateListField(object classInstance, string fieldName) + { + // Get the Type object representing the class of the instance + Type classType = classInstance.GetType(); + + // Get the FieldInfo for the specified field name, assuming it's private and an instance field + FieldInfo fieldInfo = classType.GetField(fieldName, BindingFlags.NonPublic | BindingFlags.Instance); + + if (fieldInfo != null) + { + // Get the value of the field (the instance of the list) from the class instance + object fieldValue = fieldInfo.GetValue(classInstance); + + // Check if the field is actually a List + if (fieldValue is List listInstance) + { + // Get the MethodInfo for the Clear method + MethodInfo clearMethodInfo = typeof(List).GetMethod("Clear", BindingFlags.Public | BindingFlags.Instance); + + // Invoke the Clear method on the list instance + clearMethodInfo?.Invoke(listInstance, null); + } + else + { + throw new InvalidOperationException("The specified field is not a List."); + } + } + else + { + throw new ArgumentException("The specified field was not found in the class instance.", nameof(fieldName)); + } + } +} diff --git a/Riptide/Client.cs b/Riptide/Client.cs new file mode 100644 index 0000000..c958e3e --- /dev/null +++ b/Riptide/Client.cs @@ -0,0 +1,418 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using Riptide.Transports; +using Riptide.Utils; +using System; +using System.Collections.Generic; +using System.Reflection; + +namespace Riptide +{ + /// A client that can connect to a . + public class Client : Peer + { + /// Invoked when a connection to the server is established. + public event EventHandler Connected; + /// Invoked when a connection to the server fails to be established. + public event EventHandler ConnectionFailed; + /// Invoked when a message is received. + public event EventHandler MessageReceived; + /// Invoked when disconnected from the server. + public event EventHandler Disconnected; + /// Invoked when another non-local client connects. + public event EventHandler ClientConnected; + /// Invoked when another non-local client disconnects. + public event EventHandler ClientDisconnected; + + /// The client's numeric ID. + public ushort Id => connection.Id; + /// + public short RTT => connection.RTT; + /// + /// This value is slower to accurately represent lasting changes in latency than , but it is less susceptible to changing drastically due to significant—but temporary—jumps in latency. + public short SmoothRTT => connection.SmoothRTT; + /// Sets the client's . + public override int TimeoutTime + { + set + { + defaultTimeout = value; + connection.TimeoutTime = defaultTimeout; + } + } + /// Whether or not the client is currently not trying to connect, pending, nor actively connected. + public bool IsNotConnected => connection is null || connection.IsNotConnected; + /// Whether or not the client is currently in the process of connecting. + public bool IsConnecting => !(connection is null) && connection.IsConnecting; + /// Whether or not the client's connection is currently pending (waiting to be accepted/rejected by the server). + public bool IsPending => !(connection is null) && connection.IsPending; + /// Whether or not the client is currently connected. + public bool IsConnected => !(connection is null) && connection.IsConnected; + /// The client's connection to a server. + // Not an auto property because properties can't be passed as ref/out parameters. Could + // use a local variable in the Connect method, but that's arguably not any cleaner. This + // property will also probably only be used rarely from outside the class/library. + public Connection Connection => connection; + /// Encapsulates a method that handles a message from a server. + /// The message that was received. + public delegate void MessageHandler(Message message); + + /// + private Connection connection; + /// How many connection attempts have been made so far. + private int connectionAttempts; + /// How many connection attempts to make before giving up. + private int maxConnectionAttempts; + /// + private Dictionary messageHandlers; + /// The underlying transport's client that is used for sending and receiving data. + private IClient transport; + /// The message sent when connecting. May include custom data. + private Message connectMessage; + + /// Handles initial setup. + /// The transport to use for sending and receiving data. + /// The name to use when logging messages via . + public Client(IClient transport, string logName = "CLIENT") : base(logName) + { + this.transport = transport; + } + /// Handles initial setup using the built-in UDP transport. + /// The name to use when logging messages via . + public Client(string logName = "CLIENT") : this(new Transports.Udp.UdpClient(), logName) { } + + /// Disconnects the client if it's connected and swaps out the transport it's using. + /// The new transport to use for sending and receiving data. + /// This method does not automatically reconnect to the server. To continue communicating with the server, must be called again. + public void ChangeTransport(IClient newTransport) + { + Disconnect(); + transport = newTransport; + } + + /// Attempts to connect to a server at the given host address. + /// The host address to connect to. + /// How many connection attempts to make before giving up. + /// The ID of the group of message handler methods to use when building . + /// Data that should be sent to the server with the connection attempt. Use to get an empty message instance. + /// Whether or not the client should use the built-in message handler system. + /// + /// Riptide's default transport expects the host address to consist of an IP and port, separated by a colon. For example: 127.0.0.1:7777. If you are using a different transport, check the relevant documentation for what information it requires in the host address. + /// Setting to will disable the automatic detection and execution of methods with the , which is beneficial if you prefer to handle messages via the event. + /// + /// if a connection attempt will be made. if an issue occurred (such as being in an invalid format) and a connection attempt will not be made. + public bool Connect(string hostAddress, int maxConnectionAttempts = 5, byte messageHandlerGroupId = 0, Message message = null, bool useMessageHandlers = true) + { + Disconnect(); + + SubToTransportEvents(); + + if (!transport.Connect(hostAddress, out connection, out string connectError)) + { + RiptideLogger.Log(LogType.Error, LogName, connectError); + UnsubFromTransportEvents(); + return false; + } + + this.maxConnectionAttempts = maxConnectionAttempts; + connectionAttempts = 0; + connection.Initialize(this, defaultTimeout); + IncreaseActiveCount(); + this.useMessageHandlers = useMessageHandlers; + if (useMessageHandlers) + CreateMessageHandlersDictionary(messageHandlerGroupId); + + connectMessage = Message.Create(MessageHeader.Connect); + if (message != null) + { + if (message.ReadBits != 0) + RiptideLogger.Log(LogType.Error, LogName, $"Use the parameterless 'Message.Create()' overload when setting connection attempt data!"); + + connectMessage.AddMessage(message); + message.Release(); + } + + StartTime(); + Heartbeat(); + RiptideLogger.Log(LogType.Info, LogName, $"Connecting to {connection}..."); + return true; + } + + /// Subscribes appropriate methods to the transport's events. + private void SubToTransportEvents() + { + transport.Connected += TransportConnected; + transport.ConnectionFailed += TransportConnectionFailed; + transport.DataReceived += HandleData; + transport.Disconnected += TransportDisconnected; + } + + /// Unsubscribes methods from all of the transport's events. + private void UnsubFromTransportEvents() + { + transport.Connected -= TransportConnected; + transport.ConnectionFailed -= TransportConnectionFailed; + transport.DataReceived -= HandleData; + transport.Disconnected -= TransportDisconnected; + } + + /// + protected override void CreateMessageHandlersDictionary(byte messageHandlerGroupId) + { + MethodInfo[] methods = FindMessageHandlers(); + + messageHandlers = new Dictionary(methods.Length); + foreach (MethodInfo method in methods) + { + MessageHandlerAttribute attribute = method.GetCustomAttribute(); + if (attribute.GroupId != messageHandlerGroupId) + continue; + + if (!method.IsStatic) + throw new NonStaticHandlerException(method.DeclaringType, method.Name); + + Delegate clientMessageHandler = Delegate.CreateDelegate(typeof(MessageHandler), method, false); + if (clientMessageHandler != null) + { + // It's a message handler for Client instances + if (messageHandlers.ContainsKey(attribute.MessageId)) + { + MethodInfo otherMethodWithId = messageHandlers[attribute.MessageId].GetMethodInfo(); + throw new DuplicateHandlerException(attribute.MessageId, method, otherMethodWithId); + } + else + messageHandlers.Add(attribute.MessageId, (MessageHandler)clientMessageHandler); + } + else + { + // It's not a message handler for Client instances, but it might be one for Server instances + if (Delegate.CreateDelegate(typeof(Server.MessageHandler), method, false) == null) + throw new InvalidHandlerSignatureException(method.DeclaringType, method.Name); + } + } + } + + /// + internal override void Heartbeat() + { + if (IsConnecting) + { + // If still trying to connect, send connect messages instead of heartbeats + if (connectionAttempts < maxConnectionAttempts) + { + Send(connectMessage, false); + connectionAttempts++; + } + else + LocalDisconnect(DisconnectReason.NeverConnected); + } + else if (IsPending) + { + // If waiting for the server to accept/reject the connection attempt + if (connection.HasConnectAttemptTimedOut) + { + LocalDisconnect(DisconnectReason.TimedOut); + return; + } + } + else if (IsConnected) + { + // If connected and not timed out, send heartbeats + if (connection.HasTimedOut) + { + LocalDisconnect(DisconnectReason.TimedOut); + return; + } + + connection.SendHeartbeat(); + } + + ExecuteLater(HeartbeatInterval, new HeartbeatEvent(this)); + } + + /// + public override void Update() + { + base.Update(); + transport.Poll(); + HandleMessages(); + } + + /// + protected override void Handle(Message message, MessageHeader header, Connection connection) + { + switch (header) + { + // User messages + case MessageHeader.Unreliable: + case MessageHeader.Reliable: + OnMessageReceived(message); + break; + + // Internal messages + case MessageHeader.Ack: + connection.HandleAck(message); + break; + case MessageHeader.Connect: + connection.SetPending(); + break; + case MessageHeader.Reject: + if (!IsConnected) // Don't disconnect if we are connected + LocalDisconnect(DisconnectReason.ConnectionRejected, message, (RejectReason)message.GetByte()); + break; + case MessageHeader.Heartbeat: + connection.HandleHeartbeatResponse(message); + break; + case MessageHeader.Disconnect: + LocalDisconnect((DisconnectReason)message.GetByte(), message); + break; + case MessageHeader.Welcome: + if (IsConnecting || IsPending) + { + connection.HandleWelcome(message); + OnConnected(); + } + break; + case MessageHeader.ClientConnected: + OnClientConnected(message.GetUShort()); + break; + case MessageHeader.ClientDisconnected: + OnClientDisconnected(message.GetUShort()); + break; + default: + RiptideLogger.Log(LogType.Warning, LogName, $"Unexpected message header '{header}'! Discarding {message.BytesInUse} bytes."); + break; + } + + message.Release(); + } + + /// Sends a message to the server. + /// + public ushort Send(Message message, bool shouldRelease = true) => connection.Send(message, shouldRelease); + + /// Disconnects from the server. + public void Disconnect() + { + if (connection == null || IsNotConnected) + return; + + Send(Message.Create(MessageHeader.Disconnect)); + LocalDisconnect(DisconnectReason.Disconnected); + } + + /// + internal override void Disconnect(Connection connection, DisconnectReason reason) + { + if (connection.IsConnected && connection.CanQualityDisconnect) + LocalDisconnect(reason); + } + + /// Cleans up the local side of the connection. + /// The reason why the client has disconnected. + /// The disconnection or rejection message, potentially containing extra data to be handled externally. + /// The reason why the connection was rejected (if it was rejected). + private void LocalDisconnect(DisconnectReason reason, Message message = null, RejectReason rejectReason = RejectReason.NoConnection) + { + if (IsNotConnected) + return; + + UnsubFromTransportEvents(); + DecreaseActiveCount(); + + StopTime(); + transport.Disconnect(); + + connection.LocalDisconnect(); + + if (reason == DisconnectReason.NeverConnected) + OnConnectionFailed(RejectReason.NoConnection); + else if (reason == DisconnectReason.ConnectionRejected) + OnConnectionFailed(rejectReason, message); + else + OnDisconnected(reason, message); + } + + /// What to do when the transport establishes a connection. + private void TransportConnected(object sender, EventArgs e) { } + + /// What to do when the transport fails to connect. + private void TransportConnectionFailed(object sender, EventArgs e) + { + LocalDisconnect(DisconnectReason.NeverConnected); + } + + /// What to do when the transport disconnects. + private void TransportDisconnected(object sender, Transports.DisconnectedEventArgs e) + { + if (connection == e.Connection) + LocalDisconnect(e.Reason); + } + + #region Events + /// Invokes the event. + protected virtual void OnConnected() + { + connectMessage.Release(); + connectMessage = null; + RiptideLogger.Log(LogType.Info, LogName, "Connected successfully!"); + Connected?.Invoke(this, EventArgs.Empty); + } + + /// Invokes the event. + /// The reason for the connection failure. + /// Additional data related to the failed connection attempt. + protected virtual void OnConnectionFailed(RejectReason reason, Message message = null) + { + connectMessage.Release(); + connectMessage = null; + RiptideLogger.Log(LogType.Info, LogName, $"Connection to server failed: {Helper.GetReasonString(reason)}."); + ConnectionFailed?.Invoke(this, new ConnectionFailedEventArgs(reason, message)); + } + + /// Invokes the event and initiates handling of the received message. + /// The received message. + protected virtual void OnMessageReceived(Message message) + { + ushort messageId = (ushort)message.GetVarULong(); + MessageReceived?.Invoke(this, new MessageReceivedEventArgs(connection, messageId, message)); + + if (useMessageHandlers) + { + if (messageHandlers.TryGetValue(messageId, out MessageHandler messageHandler)) + messageHandler(message); + else + RiptideLogger.Log(LogType.Warning, LogName, $"No message handler method found for message ID {messageId}!"); + } + } + + /// Invokes the event. + /// The reason for the disconnection. + /// Additional data related to the disconnection. + protected virtual void OnDisconnected(DisconnectReason reason, Message message) + { + RiptideLogger.Log(LogType.Info, LogName, $"Disconnected from server: {Helper.GetReasonString(reason)}."); + Disconnected?.Invoke(this, new DisconnectedEventArgs(reason, message)); + } + + /// Invokes the event. + /// The numeric ID of the client that connected. + protected virtual void OnClientConnected(ushort clientId) + { + RiptideLogger.Log(LogType.Info, LogName, $"Client {clientId} connected."); + ClientConnected?.Invoke(this, new ClientConnectedEventArgs(clientId)); + } + + /// Invokes the event. + /// The numeric ID of the client that disconnected. + protected virtual void OnClientDisconnected(ushort clientId) + { + RiptideLogger.Log(LogType.Info, LogName, $"Client {clientId} disconnected."); + ClientDisconnected?.Invoke(this, new ClientDisconnectedEventArgs(clientId)); + } + #endregion + } +} diff --git a/Riptide/Connection.cs b/Riptide/Connection.cs new file mode 100644 index 0000000..f34ac90 --- /dev/null +++ b/Riptide/Connection.cs @@ -0,0 +1,648 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using Riptide.Transports; +using Riptide.Utils; +using System; +using System.Collections.Generic; + +namespace Riptide +{ + /// The state of a connection. + internal enum ConnectionState : byte + { + /// Not connected. No connection has been established or the connection has been closed. + NotConnected, + /// Connecting. Still trying to establish a connection. + Connecting, + /// Connection is pending. The server is still determining whether or not the connection should be allowed. + Pending, + /// Connected. A connection has been established successfully. + Connected, + } + + /// Represents a connection to a or . + public abstract class Connection + { + /// Invoked when the notify message with the given sequence ID is successfully delivered. + public Action NotifyDelivered; + /// Invoked when the notify message with the given sequence ID is lost. + public Action NotifyLost; + /// Invoked when a notify message is received. + public Action NotifyReceived; + /// Invoked when the reliable message with the given sequence ID is successfully delivered. + public Action ReliableDelivered; + + /// The connection's numeric ID. + public ushort Id { get; internal set; } + /// Whether or not the connection is currently not trying to connect, pending, nor actively connected. + public bool IsNotConnected => state == ConnectionState.NotConnected; + /// Whether or not the connection is currently in the process of connecting. + public bool IsConnecting => state == ConnectionState.Connecting; + /// Whether or not the connection is currently pending (waiting to be accepted/rejected by the server). + public bool IsPending => state == ConnectionState.Pending; + /// Whether or not the connection is currently connected. + public bool IsConnected => state == ConnectionState.Connected; + /// The round trip time (ping) of the connection, in milliseconds. -1 if not calculated yet. + public short RTT + { + get => _rtt; + private set + { + SmoothRTT = _rtt == -1 ? value : (short)Math.Max(1f, SmoothRTT * 0.7f + value * 0.3f); + _rtt = value; + } + } + private short _rtt; + /// The smoothed round trip time (ping) of the connection, in milliseconds. -1 if not calculated yet. + /// This value is slower to accurately represent lasting changes in latency than , but it is less susceptible to changing drastically due to significant—but temporary—jumps in latency. + public short SmoothRTT { get; private set; } + /// The time (in milliseconds) after which to disconnect if no heartbeats are received. + public int TimeoutTime { get; set; } + /// Whether or not the connection can time out. + public bool CanTimeout + { + get => _canTimeout; + set + { + if (value) + ResetTimeout(); + + _canTimeout = value; + } + } + private bool _canTimeout; + /// Whether or not the connection can disconnect due to poor connection quality. + /// When this is set to , , , + /// and are ignored and exceeding their values will not trigger a disconnection. + public bool CanQualityDisconnect; + /// The connection's metrics. + public readonly ConnectionMetrics Metrics; + /// The maximum acceptable average number of send attempts it takes to deliver a reliable message. The connection + /// will be closed if this is exceeded more than times in a row. + public int MaxAvgSendAttempts; + /// How many consecutive times can be exceeded before triggering a disconnect. + public int AvgSendAttemptsResilience; + /// The absolute maximum number of times a reliable message may be sent. A single message reaching this threshold will cause a disconnection. + public int MaxSendAttempts; + /// The maximum acceptable loss rate of notify messages. The connection will be closed if this is exceeded more than times in a row. + public float MaxNotifyLoss; + /// How many consecutive times can be exceeded before triggering a disconnect. + public int NotifyLossResilience; + + /// The local peer this connection is associated with. + internal Peer Peer { get; private set; } + /// Whether or not the connection has timed out. + internal bool HasTimedOut => _canTimeout && Peer.CurrentTime - lastHeartbeat > TimeoutTime; + /// Whether or not the connection attempt has timed out. + internal bool HasConnectAttemptTimedOut => _canTimeout && Peer.CurrentTime - lastHeartbeat > Peer.ConnectTimeoutTime; + + /// The sequencer for notify messages. + private readonly NotifySequencer notify; + /// The sequencer for reliable messages. + private readonly ReliableSequencer reliable; + /// The currently pending reliably sent messages whose delivery has not been acknowledged yet. Stored by sequence ID. + private readonly Dictionary pendingMessages; + /// The connection's current state. + private ConnectionState state; + /// The number of consecutive times that the threshold was exceeded. + private int sendAttemptsViolations; + /// The number of consecutive times that the threshold was exceeded. + private int lossRateViolations; + /// The time at which the last heartbeat was received from the other end. + private long lastHeartbeat; + /// The ID of the last ping that was sent. + private byte lastPingId; + /// The ID of the currently pending ping. + private byte pendingPingId; + /// The time at which the currently pending ping was sent. + private long pendingPingSendTime; + + /// Initializes the connection. + protected Connection() + { + Metrics = new ConnectionMetrics(); + notify = new NotifySequencer(this); + reliable = new ReliableSequencer(this); + state = ConnectionState.Connecting; + _rtt = -1; + SmoothRTT = -1; + _canTimeout = true; + CanQualityDisconnect = true; + MaxAvgSendAttempts = 5; + AvgSendAttemptsResilience = 64; + MaxSendAttempts = 15; + MaxNotifyLoss = 0.05f; // 5% + NotifyLossResilience = 64; + pendingMessages = new Dictionary(); + } + + /// Initializes connection data. + /// The which this connection belongs to. + /// The timeout time. + internal void Initialize(Peer peer, int timeoutTime) + { + Peer = peer; + TimeoutTime = timeoutTime; + } + + /// Resets the connection's timeout time. + public void ResetTimeout() + { + lastHeartbeat = Peer.CurrentTime; + } + + /// Sends a message. + /// The message to send. + /// Whether or not to return the message to the pool after it is sent. + /// For reliable and notify messages, the sequence ID that the message was sent with. 0 for unreliable messages. + /// + /// If you intend to continue using the message instance after calling this method, you must set + /// to . can be used to manually return the message to the pool at a later time. + /// + public ushort Send(Message message, bool shouldRelease = true) + { + ushort sequenceId = 0; + if (message.SendMode == MessageSendMode.Notify) + { + sequenceId = notify.InsertHeader(message); + int byteAmount = message.BytesInUse; + Buffer.BlockCopy(message.Data, 0, Message.ByteBuffer, 0, byteAmount); + Send(Message.ByteBuffer, byteAmount); + Metrics.SentNotify(byteAmount); + } + else if (message.SendMode == MessageSendMode.Unreliable) + { + int byteAmount = message.BytesInUse; + Buffer.BlockCopy(message.Data, 0, Message.ByteBuffer, 0, byteAmount); + Send(Message.ByteBuffer, byteAmount); + Metrics.SentUnreliable(byteAmount); + } + else + { + sequenceId = reliable.NextSequenceId; + PendingMessage pendingMessage = PendingMessage.Create(sequenceId, message, this); + pendingMessages.Add(sequenceId, pendingMessage); + pendingMessage.TrySend(); + Metrics.ReliableUniques++; + } + + if (shouldRelease) + message.Release(); + + return sequenceId; + } + + /// Sends data. + /// The array containing the data. + /// The number of bytes in the array which should be sent. + protected internal abstract void Send(byte[] dataBuffer, int amount); + + /// Processes a notify message. + /// The received data. + /// The number of bytes that were received. + /// The message instance to use. + internal void ProcessNotify(byte[] dataBuffer, int amount, Message message) + { + notify.UpdateReceivedAcks(Converter.UShortFromBits(dataBuffer, Message.HeaderBits), Converter.ByteFromBits(dataBuffer, Message.HeaderBits + 16)); + + Metrics.ReceivedNotify(amount); + if (notify.ShouldHandle(Converter.UShortFromBits(dataBuffer, Message.HeaderBits + 24))) + { + Buffer.BlockCopy(dataBuffer, 1, message.Data, 1, amount - 1); // Copy payload + NotifyReceived?.Invoke(message); + } + else + Metrics.NotifyDiscarded++; + } + + /// Determines if the message with the given sequence ID should be handled. + /// The message's sequence ID. + /// Whether or not the message should be handled. + internal bool ShouldHandle(ushort sequenceId) + { + return reliable.ShouldHandle(sequenceId); + } + + /// Cleans up the local side of the connection. + internal void LocalDisconnect() + { + state = ConnectionState.NotConnected; + + foreach (PendingMessage pendingMessage in pendingMessages.Values) + pendingMessage.Clear(); + + pendingMessages.Clear(); + } + + /// Resends the with the given sequence ID. + /// The sequence ID of the message to resend. + private void ResendMessage(ushort sequenceId) + { + if (pendingMessages.TryGetValue(sequenceId, out PendingMessage pendingMessage)) + pendingMessage.RetrySend(); + } + + /// Clears the with the given sequence ID. + /// The sequence ID that was acknowledged. + internal void ClearMessage(ushort sequenceId) + { + if (pendingMessages.TryGetValue(sequenceId, out PendingMessage pendingMessage)) + { + ReliableDelivered?.Invoke(sequenceId); + pendingMessage.Clear(); + pendingMessages.Remove(sequenceId); + UpdateSendAttemptsViolations(); + } + } + + /// Puts the connection in the pending state. + internal void SetPending() + { + if (IsConnecting) + { + state = ConnectionState.Pending; + ResetTimeout(); + } + } + + /// Checks the average send attempts (of reliable messages) and updates accordingly. + private void UpdateSendAttemptsViolations() + { + if (Metrics.RollingReliableSends.Mean > MaxAvgSendAttempts) + { + sendAttemptsViolations++; + if (sendAttemptsViolations >= AvgSendAttemptsResilience) + Peer.Disconnect(this, DisconnectReason.PoorConnection); + } + else + sendAttemptsViolations = 0; + } + + /// Checks the loss rate (of notify messages) and updates accordingly. + private void UpdateLossViolations() + { + if (Metrics.RollingNotifyLossRate > MaxNotifyLoss) + { + lossRateViolations++; + if (lossRateViolations >= NotifyLossResilience) + Peer.Disconnect(this, DisconnectReason.PoorConnection); + } + else + lossRateViolations = 0; + } + + #region Messages + /// Sends an ack message for the given sequence ID. + /// The sequence ID to acknowledge. + /// The sequence ID of the latest message we've received. + /// Sequence IDs of previous messages that we have (or have not received). + private void SendAck(ushort forSeqId, ushort lastReceivedSeqId, Bitfield receivedSeqIds) + { + Message message = Message.Create(MessageHeader.Ack); + message.AddUShort(lastReceivedSeqId); + message.AddUShort(receivedSeqIds.First16); + + if (forSeqId == lastReceivedSeqId) + message.AddBool(false); + else + message.AddBool(true); + message.AddUShort(forSeqId); + + Send(message); + } + + /// Handles an ack message. + /// The ack message to handle. + internal void HandleAck(Message message) + { + ushort remoteLastReceivedSeqId = message.GetUShort(); + ushort remoteAcksBitField = message.GetUShort(); + ushort ackedSeqId = message.GetBool() ? message.GetUShort() : remoteLastReceivedSeqId; + + ClearMessage(ackedSeqId); + reliable.UpdateReceivedAcks(remoteLastReceivedSeqId, remoteAcksBitField); + } + + #region Server + /// Sends a welcome message. + internal void SendWelcome() + { + Message message = Message.Create(MessageHeader.Welcome); + message.AddUShort(Id); + + Send(message); + } + + /// Handles a welcome message on the server. + /// The welcome message to handle. + /// Whether or not the connection is now connected. + internal bool HandleWelcomeResponse(Message message) + { + if (!IsPending) + return false; + + ushort id = message.GetUShort(); + if (Id != id) + RiptideLogger.Log(LogType.Error, Peer.LogName, $"Client has assumed ID {id} instead of {Id}!"); + + state = ConnectionState.Connected; + ResetTimeout(); + return true; + } + + /// Handles a heartbeat message. + /// The heartbeat message to handle. + internal void HandleHeartbeat(Message message) + { + if (!IsConnected) + return; // A client that is not yet fully connected should not be sending heartbeats + + RespondHeartbeat(message.GetByte()); + RTT = message.GetShort(); + + ResetTimeout(); + } + + /// Sends a heartbeat message. + private void RespondHeartbeat(byte pingId) + { + Message message = Message.Create(MessageHeader.Heartbeat); + message.AddByte(pingId); + + Send(message); + } + #endregion + + #region Client + /// Handles a welcome message on the client. + /// The welcome message to handle. + internal void HandleWelcome(Message message) + { + Id = message.GetUShort(); + state = ConnectionState.Connected; + ResetTimeout(); + + RespondWelcome(); + } + + /// Sends a welcome response message. + private void RespondWelcome() + { + Message message = Message.Create(MessageHeader.Welcome); + message.AddUShort(Id); + + Send(message); + } + + /// Sends a heartbeat message. + internal void SendHeartbeat() + { + pendingPingId = lastPingId++; + pendingPingSendTime = Peer.CurrentTime; + + Message message = Message.Create(MessageHeader.Heartbeat); + message.AddByte(pendingPingId); + message.AddShort(RTT); + + Send(message); + } + + /// Handles a heartbeat message. + /// The heartbeat message to handle. + internal void HandleHeartbeatResponse(Message message) + { + byte pingId = message.GetByte(); + + if (pendingPingId == pingId) + RTT = (short)Math.Max(1, Peer.CurrentTime - pendingPingSendTime); + + ResetTimeout(); + } + #endregion + #endregion + + #region Events + /// Invokes the event. + /// The sequence ID of the delivered message. + protected virtual void OnNotifyDelivered(ushort sequenceId) + { + Metrics.DeliveredNotify(); + NotifyDelivered?.Invoke(sequenceId); + UpdateLossViolations(); + } + + /// Invokes the event. + /// The sequence ID of the lost message. + protected virtual void OnNotifyLost(ushort sequenceId) + { + Metrics.LostNotify(); + NotifyLost?.Invoke(sequenceId); + UpdateLossViolations(); + } + #endregion + + #region Message Sequencing + /// Provides functionality for filtering out duplicate messages and determining delivery/loss status. + private abstract class Sequencer + { + /// The next sequence ID to use. + internal ushort NextSequenceId => _nextSequenceId++; + private ushort _nextSequenceId = 1; + + /// The connection this sequencer belongs to. + protected readonly Connection connection; + /// The sequence ID of the latest message that we want to acknowledge. + protected ushort lastReceivedSeqId; + /// Sequence IDs of messages which we have (or have not) received and want to acknowledge. + protected readonly Bitfield receivedSeqIds = new Bitfield(); + /// The sequence ID of the latest message that we've received an ack for. + protected ushort lastAckedSeqId; + /// Sequence IDs of messages we sent and which we have (or have not) received acks for. + protected readonly Bitfield ackedSeqIds = new Bitfield(false); + + /// Initializes the sequencer. + /// The connection this sequencer belongs to. + protected Sequencer(Connection connection) + { + this.connection = connection; + } + + /// Determines whether or not to handle a message with the given sequence ID. + /// The sequence ID in question. + /// Whether or not to handle the message. + internal abstract bool ShouldHandle(ushort sequenceId); + + /// Updates which messages we've received acks for. + /// The latest sequence ID that the other end has received. + /// Sequence IDs which the other end has (or has not) received. + internal abstract void UpdateReceivedAcks(ushort remoteLastReceivedSeqId, ushort remoteReceivedSeqIds); + } + + /// + private class NotifySequencer : Sequencer + { + /// + internal NotifySequencer(Connection connection) : base(connection) { } + + /// Inserts the notify header into the given message. + /// The message to insert the header into. + /// The sequence ID of the message. + internal ushort InsertHeader(Message message) + { + ushort sequenceId = NextSequenceId; + ulong notifyBits = lastReceivedSeqId | ((ulong)receivedSeqIds.First8 << (2 * Converter.BitsPerByte)) | ((ulong)sequenceId << (3 * Converter.BitsPerByte)); + message.SetBits(notifyBits, 5 * Converter.BitsPerByte, Message.HeaderBits); + return sequenceId; + } + + /// + /// Duplicate and out of order messages are filtered out and not handled. + internal override bool ShouldHandle(ushort sequenceId) + { + int sequenceGap = Helper.GetSequenceGap(sequenceId, lastReceivedSeqId); + + if (sequenceGap > 0) + { + // The received sequence ID is newer than the previous one + receivedSeqIds.ShiftBy(sequenceGap); + lastReceivedSeqId = sequenceId; + + if (receivedSeqIds.IsSet(sequenceGap)) + return false; + + receivedSeqIds.Set(sequenceGap); + return true; + } + else + { + // The received sequence ID is older than or the same as the previous one (out of order or duplicate message) + return false; + } + } + + /// + internal override void UpdateReceivedAcks(ushort remoteLastReceivedSeqId, ushort remoteReceivedSeqIds) + { + int sequenceGap = Helper.GetSequenceGap(remoteLastReceivedSeqId, lastAckedSeqId); + + if (sequenceGap > 0) + { + if (sequenceGap > 1) + { + // Deal with messages in the gap + while (sequenceGap > 9) // 9 because a gap of 1 means sequence IDs are consecutive, and notify uses 8 bits for the bitfield. 9 means all 8 bits are in use + { + lastAckedSeqId++; + sequenceGap--; + connection.NotifyLost?.Invoke(lastAckedSeqId); + } + + int bitCount = sequenceGap - 1; + int bit = 1 << bitCount; + for (int i = 0; i < bitCount; i++) + { + lastAckedSeqId++; + bit >>= 1; + if ((remoteReceivedSeqIds & bit) == 0) + connection.OnNotifyLost(lastAckedSeqId); + else + connection.OnNotifyDelivered(lastAckedSeqId); + } + } + + lastAckedSeqId = remoteLastReceivedSeqId; + connection.OnNotifyDelivered(lastAckedSeqId); + } + } + } + + /// + private class ReliableSequencer : Sequencer + { + /// + internal ReliableSequencer(Connection connection) : base(connection) { } + + /// + /// Duplicate messages are filtered out while out of order messages are handled. + internal override bool ShouldHandle(ushort sequenceId) + { + bool doHandle = false; + int sequenceGap = Helper.GetSequenceGap(sequenceId, lastReceivedSeqId); + + if (sequenceGap != 0) + { + // The received sequence ID is different from the previous one + if (sequenceGap > 0) + { + // The received sequence ID is newer than the previous one + if (sequenceGap > 64) + RiptideLogger.Log(LogType.Warning, connection.Peer.LogName, $"The gap between received sequence IDs was very large ({sequenceGap})!"); + + receivedSeqIds.ShiftBy(sequenceGap); + lastReceivedSeqId = sequenceId; + } + else // The received sequence ID is older than the previous one (out of order message) + sequenceGap = -sequenceGap; + + doHandle = !receivedSeqIds.IsSet(sequenceGap); + receivedSeqIds.Set(sequenceGap); + } + + connection.SendAck(sequenceId, lastReceivedSeqId, receivedSeqIds); + return doHandle; + } + + /// Updates which messages we've received acks for. + /// The latest sequence ID that the other end has received. + /// Sequence IDs which the other end has (or has not) received. + internal override void UpdateReceivedAcks(ushort remoteLastReceivedSeqId, ushort remoteReceivedSeqIds) + { + int sequenceGap = Helper.GetSequenceGap(remoteLastReceivedSeqId, lastAckedSeqId); + + if (sequenceGap > 0) + { + // The latest sequence ID that the other end has received is newer than the previous one + if (!ackedSeqIds.HasCapacityFor(sequenceGap, out int overflow)) + { + for (int i = 0; i < overflow; i++) + { + // Resend those messages which haven't been acked and whose sequence IDs are about to be pushed out of the bitfield + if (!ackedSeqIds.CheckAndTrimLast(out int checkedPosition)) + connection.ResendMessage((ushort)(lastAckedSeqId - checkedPosition)); + else + connection.ClearMessage((ushort)(lastAckedSeqId - checkedPosition)); + } + } + + ackedSeqIds.ShiftBy(sequenceGap); + lastAckedSeqId = remoteLastReceivedSeqId; + + for (int i = 0; i < 16; i++) + { + // Clear any messages that have been newly acknowledged + if (!ackedSeqIds.IsSet(i + 1) && (remoteReceivedSeqIds & (1 << i)) != 0) + connection.ClearMessage((ushort)(lastAckedSeqId - (i + 1))); + } + + ackedSeqIds.Combine(remoteReceivedSeqIds); + ackedSeqIds.Set(sequenceGap); // Ensure that the bit corresponding to the previous acked sequence ID is set + connection.ClearMessage(remoteLastReceivedSeqId); + } + else if (sequenceGap < 0) + { + // The latest sequence ID that the other end has received is older than the previous one (out of order ack) + ackedSeqIds.Set(-sequenceGap); + } + else + { + // The latest sequence ID that the other end has received is the same as the previous one (duplicate ack) + ackedSeqIds.Combine(remoteReceivedSeqIds); + } + } + } + #endregion + } +} diff --git a/Riptide/EventArgs.cs b/Riptide/EventArgs.cs new file mode 100644 index 0000000..9f4e216 --- /dev/null +++ b/Riptide/EventArgs.cs @@ -0,0 +1,135 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; + +namespace Riptide +{ + /// Contains event data for when a client connects to the server. + public class ServerConnectedEventArgs : EventArgs + { + /// The newly connected client. + public readonly Connection Client; + + /// Initializes event data. + /// The newly connected client. + public ServerConnectedEventArgs(Connection client) + { + Client = client; + } + } + + /// Contains event data for when a connection fails to be fully established. + public class ServerConnectionFailedEventArgs : EventArgs + { + /// The connection that failed to be established. + public readonly Connection Client; + + /// Initializes event data. + /// The connection that failed to be established. + public ServerConnectionFailedEventArgs(Connection client) + { + Client = client; + } + } + + /// Contains event data for when a client disconnects from the server. + public class ServerDisconnectedEventArgs : EventArgs + { + /// The client that disconnected. + public readonly Connection Client; + /// The reason for the disconnection. + public readonly DisconnectReason Reason; + + /// Initializes event data. + /// The client that disconnected. + /// The reason for the disconnection. + public ServerDisconnectedEventArgs(Connection client, DisconnectReason reason) + { + Client = client; + Reason = reason; + } + } + + /// Contains event data for when a message is received. + public class MessageReceivedEventArgs : EventArgs + { + /// The connection from which the message was received. + public readonly Connection FromConnection; + /// The ID of the message. + public readonly ushort MessageId; + /// The received message. + public readonly Message Message; + + /// Initializes event data. + /// The connection from which the message was received. + /// The ID of the message. + /// The received message. + public MessageReceivedEventArgs(Connection fromConnection, ushort messageId, Message message) + { + FromConnection = fromConnection; + MessageId = messageId; + Message = message; + } + } + + /// Contains event data for when a connection attempt to a server fails. + public class ConnectionFailedEventArgs : EventArgs + { + /// The reason for the connection failure. + public readonly RejectReason Reason; + /// Additional data related to the failed connection attempt (if any). + public readonly Message Message; + + /// Initializes event data. + /// The reason for the connection failure. + /// Additional data related to the failed connection attempt (if any). + public ConnectionFailedEventArgs(RejectReason reason, Message message) + { + Reason = reason; + Message = message; + } + } + + /// Contains event data for when the client disconnects from a server. + public class DisconnectedEventArgs : EventArgs + { + /// The reason for the disconnection. + public readonly DisconnectReason Reason; + /// Additional data related to the disconnection (if any). + public readonly Message Message; + + /// Initializes event data. + /// The reason for the disconnection. + /// Additional data related to the disconnection (if any). + public DisconnectedEventArgs(DisconnectReason reason, Message message) + { + Reason = reason; + Message = message; + } + } + + /// Contains event data for when a non-local client connects to the server. + public class ClientConnectedEventArgs : EventArgs + { + /// The numeric ID of the client that connected. + public readonly ushort Id; + + /// Initializes event data. + /// The numeric ID of the client that connected. + public ClientConnectedEventArgs(ushort id) => Id = id; + } + + /// Contains event data for when a non-local client disconnects from the server. + public class ClientDisconnectedEventArgs : EventArgs + { + /// The numeric ID of the client that disconnected. + public readonly ushort Id; + + /// Initializes event data. + /// The numeric ID of the client that disconnected. + public ClientDisconnectedEventArgs(ushort id) => Id = id; + } +} diff --git a/Riptide/Exceptions.cs b/Riptide/Exceptions.cs new file mode 100644 index 0000000..2d01007 --- /dev/null +++ b/Riptide/Exceptions.cs @@ -0,0 +1,197 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using Riptide.Utils; +using System; +using System.Reflection; + +namespace Riptide +{ + /// The exception that is thrown when a does not contain enough unwritten bits to perform an operation. + public class InsufficientCapacityException : Exception + { + /// The message with insufficient remaining capacity. + public readonly Message RiptideMessage; + /// The name of the type which could not be added to the message. + public readonly string TypeName; + /// The number of available bits the type requires in order to be added successfully. + public readonly int RequiredBits; + + /// Initializes a new instance. + public InsufficientCapacityException() { } + /// Initializes a new instance with a specified error message. + /// The error message that explains the reason for the exception. + public InsufficientCapacityException(string message) : base(message) { } + /// Initializes a new instance with a specified error message and a reference to the inner exception that is the cause of this exception. + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception. If is not a null reference, the current exception is raised in a catch block that handles the inner exception. + public InsufficientCapacityException(string message, Exception inner) : base(message, inner) { } + /// Initializes a new instance and constructs an error message from the given information. + /// The message with insufficient remaining capacity. + /// The number of bits which were attempted to be reserved. + public InsufficientCapacityException(Message message, int reserveBits) : base(GetErrorMessage(message, reserveBits)) + { + RiptideMessage = message; + TypeName = "reservation"; + RequiredBits = reserveBits; + } + /// Initializes a new instance and constructs an error message from the given information. + /// The message with insufficient remaining capacity. + /// The name of the type which could not be added to the message. + /// The number of available bits required for the type to be added successfully. + public InsufficientCapacityException(Message message, string typeName, int requiredBits) : base(GetErrorMessage(message, typeName, requiredBits)) + { + RiptideMessage = message; + TypeName = typeName; + RequiredBits = requiredBits; + } + /// Initializes a new instance and constructs an error message from the given information. + /// The message with insufficient remaining capacity. + /// The length of the array which could not be added to the message. + /// The name of the array's type. + /// The number of available bits required for a single element of the array to be added successfully. + public InsufficientCapacityException(Message message, int arrayLength, string typeName, int requiredBits) : base(GetErrorMessage(message, arrayLength, typeName, requiredBits)) + { + RiptideMessage = message; + TypeName = $"{typeName}[]"; + RequiredBits = requiredBits * arrayLength; + } + + /// Constructs the error message from the given information. + /// The error message. + private static string GetErrorMessage(Message message, int reserveBits) + { + return $"Cannot reserve {reserveBits} {Helper.CorrectForm(reserveBits, "bit")} in a message with {message.UnwrittenBits} " + + $"{Helper.CorrectForm(message.UnwrittenBits, "bit")} of remaining capacity!"; + } + /// Constructs the error message from the given information. + /// The error message. + private static string GetErrorMessage(Message message, string typeName, int requiredBits) + { + return $"Cannot add a value of type '{typeName}' (requires {requiredBits} {Helper.CorrectForm(requiredBits, "bit")}) to " + + $"a message with {message.UnwrittenBits} {Helper.CorrectForm(message.UnwrittenBits, "bit")} of remaining capacity!"; + } + /// Constructs the error message from the given information. + /// The error message. + private static string GetErrorMessage(Message message, int arrayLength, string typeName, int requiredBits) + { + requiredBits *= arrayLength; + return $"Cannot add an array of type '{typeName}[]' with {arrayLength} {Helper.CorrectForm(arrayLength, "element")} (requires {requiredBits} {Helper.CorrectForm(requiredBits, "bit")}) " + + $"to a message with {message.UnwrittenBits} {Helper.CorrectForm(message.UnwrittenBits, "bit")} of remaining capacity!"; + } + } + + /// The exception that is thrown when a method with a is not marked as . + public class NonStaticHandlerException : Exception + { + /// The type containing the handler method. + public readonly Type DeclaringType; + /// The name of the handler method. + public readonly string HandlerMethodName; + + /// Initializes a new instance. + public NonStaticHandlerException() { } + /// Initializes a new instance with a specified error message. + /// The error message that explains the reason for the exception. + public NonStaticHandlerException(string message) : base(message) { } + /// Initializes a new instance with a specified error message and a reference to the inner exception that is the cause of this exception. + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception. If is not a null reference, the current exception is raised in a catch block that handles the inner exception. + public NonStaticHandlerException(string message, Exception inner) : base(message, inner) { } + /// Initializes a new instance and constructs an error message from the given information. + /// The type containing the handler method. + /// The name of the handler method. + public NonStaticHandlerException(Type declaringType, string handlerMethodName) : base(GetErrorMessage(declaringType, handlerMethodName)) + { + DeclaringType = declaringType; + HandlerMethodName = handlerMethodName; + } + + /// Constructs the error message from the given information. + /// The error message. + private static string GetErrorMessage(Type declaringType, string handlerMethodName) + { + return $"'{declaringType.Name}.{handlerMethodName}' is an instance method, but message handler methods must be static!"; + } + } + + /// The exception that is thrown when a method with a does not have an acceptable message handler method signature (either or ). + public class InvalidHandlerSignatureException : Exception + { + /// The type containing the handler method. + public readonly Type DeclaringType; + /// The name of the handler method. + public readonly string HandlerMethodName; + + /// Initializes a new instance. + public InvalidHandlerSignatureException() { } + /// Initializes a new instance with a specified error message. + /// The error message that explains the reason for the exception. + public InvalidHandlerSignatureException(string message) : base(message) { } + /// Initializes a new instance with a specified error message and a reference to the inner exception that is the cause of this exception. + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception. If is not a null reference, the current exception is raised in a catch block that handles the inner exception. + public InvalidHandlerSignatureException(string message, Exception inner) : base(message, inner) { } + /// Initializes a new instance and constructs an error message from the given information. + /// The type containing the handler method. + /// The name of the handler method. + public InvalidHandlerSignatureException(Type declaringType, string handlerMethodName) : base(GetErrorMessage(declaringType, handlerMethodName)) + { + DeclaringType = declaringType; + HandlerMethodName = handlerMethodName; + } + + /// Constructs the error message from the given information. + /// The error message. + private static string GetErrorMessage(Type declaringType, string handlerMethodName) + { + return $"'{declaringType.Name}.{handlerMethodName}' doesn't match any acceptable message handler method signatures! Server message handler methods should have a 'ushort' and a '{nameof(Riptide.Message)}' parameter, while client message handler methods should only have a '{nameof(Riptide.Message)}' parameter."; + } + } + + /// The exception that is thrown when multiple methods with s are set to handle messages with the same ID and have the same method signature. + public class DuplicateHandlerException : Exception + { + /// The message ID with multiple handler methods. + public readonly ushort Id; + /// The type containing the first handler method. + public readonly Type DeclaringType1; + /// The name of the first handler method. + public readonly string HandlerMethodName1; + /// The type containing the second handler method. + public readonly Type DeclaringType2; + /// The name of the second handler method. + public readonly string HandlerMethodName2; + + /// Initializes a new instance with a specified error message. + public DuplicateHandlerException() { } + /// Initializes a new instance with a specified error message. + /// The error message that explains the reason for the exception. + public DuplicateHandlerException(string message) : base(message) { } + /// Initializes a new instance with a specified error message and a reference to the inner exception that is the cause of this exception. + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception. If is not a null reference, the current exception is raised in a catch block that handles the inner exception. + public DuplicateHandlerException(string message, Exception inner) : base(message, inner) { } + /// Initializes a new instance and constructs an error message from the given information. + /// The message ID with multiple handler methods. + /// The first handler method's info. + /// The second handler method's info. + public DuplicateHandlerException(ushort id, MethodInfo method1, MethodInfo method2) : base(GetErrorMessage(id, method1, method2)) + { + Id = id; + DeclaringType1 = method1.DeclaringType; + HandlerMethodName1 = method1.Name; + DeclaringType2 = method2.DeclaringType; + HandlerMethodName2 = method2.Name; + } + + /// Constructs the error message from the given information. + /// The error message. + private static string GetErrorMessage(ushort id, MethodInfo method1, MethodInfo method2) + { + return $"Message handler methods '{method1.DeclaringType.Name}.{method1.Name}' and '{method2.DeclaringType.Name}.{method2.Name}' are both set to handle messages with ID {id}! Only one handler method is allowed per message ID!"; + } + } +} diff --git a/Riptide/IMessageSerializable.cs b/Riptide/IMessageSerializable.cs new file mode 100644 index 0000000..635367c --- /dev/null +++ b/Riptide/IMessageSerializable.cs @@ -0,0 +1,18 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +namespace Riptide +{ + /// Represents a type that can be added to and retrieved from messages using the and methods. + public interface IMessageSerializable + { + /// Adds the type to the message. + /// The message to add the type to. + void Serialize(Message message); + /// Retrieves the type from the message. + /// The message to retrieve the type from. + void Deserialize(Message message); + } +} diff --git a/Riptide/Message.cs b/Riptide/Message.cs new file mode 100644 index 0000000..09b83fe --- /dev/null +++ b/Riptide/Message.cs @@ -0,0 +1,1922 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using Riptide.Transports; +using Riptide.Utils; +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Riptide +{ + /// The send mode of a . + public enum MessageSendMode : byte + { + /// Guarantees order but not delivery. Notifies the sender of what happened via the and + /// events. The receiver must handle notify messages via the event, which is different from the other two send modes. + Notify = MessageHeader.Notify, + /// Guarantees neither delivery nor order. + Unreliable = MessageHeader.Unreliable, + /// Guarantees delivery but not order. + Reliable = MessageHeader.Reliable, + } + + /// Provides functionality for converting data to bytes and vice versa. + public class Message + { + /// The maximum number of bits required for a message's header. + public const int MaxHeaderSize = NotifyHeaderBits; + /// The number of bits used by the . + internal const int HeaderBits = 4; + /// A bitmask that, when applied, only keeps the bits corresponding to the value. + internal const byte HeaderBitmask = (1 << HeaderBits) - 1; + /// The header size for unreliable messages. Does not count the 2 bytes used for the message ID. + /// 4 bits - header. + internal const int UnreliableHeaderBits = HeaderBits; + /// The header size for reliable messages. Does not count the 2 bytes used for the message ID. + /// 4 bits - header, 16 bits - sequence ID. + internal const int ReliableHeaderBits = HeaderBits + 2 * BitsPerByte; + /// The header size for notify messages. + /// 4 bits - header, 24 bits - ack, 16 bits - sequence ID. + internal const int NotifyHeaderBits = HeaderBits + 5 * BitsPerByte; + /// The minimum number of bytes contained in an unreliable message. + internal const int MinUnreliableBytes = UnreliableHeaderBits / BitsPerByte + (UnreliableHeaderBits % BitsPerByte == 0 ? 0 : 1); + /// The minimum number of bytes contained in a reliable message. + internal const int MinReliableBytes = ReliableHeaderBits / BitsPerByte + (ReliableHeaderBits % BitsPerByte == 0 ? 0 : 1); + /// The minimum number of bytes contained in a notify message. + internal const int MinNotifyBytes = NotifyHeaderBits / BitsPerByte + (NotifyHeaderBits % BitsPerByte == 0 ? 0 : 1); + /// The number of bits in a byte. + private const int BitsPerByte = Converter.BitsPerByte; + /// The number of bits in each data segment. + private const int BitsPerSegment = Converter.BitsPerULong; + + /// The maximum number of bytes that a message can contain, including the . + public static int MaxSize { get; private set; } + /// The maximum number of bytes of payload data that a message can contain. This value represents how many bytes can be added to a message on top of the . + public static int MaxPayloadSize + { + get => MaxSize - (MaxHeaderSize / BitsPerByte + (MaxHeaderSize % BitsPerByte == 0 ? 0 : 1)); + set + { + if (Peer.ActiveCount > 0) + throw new InvalidOperationException($"Changing the '{nameof(MaxPayloadSize)}' is not allowed while a {nameof(Server)} or {nameof(Client)} is running!"); + + if (value < 0) + throw new ArgumentOutOfRangeException(nameof(value), $"'{nameof(MaxPayloadSize)}' cannot be negative!"); + + MaxSize = MaxHeaderSize / BitsPerByte + (MaxHeaderSize % BitsPerByte == 0 ? 0 : 1) + value; + maxBitCount = MaxSize * BitsPerByte; + maxArraySize = MaxSize / sizeof(ulong) + (MaxSize % sizeof(ulong) == 0 ? 0 : 1); + ByteBuffer = new byte[MaxSize]; + TrimPool(); // When ActiveSocketCount is 0, this clears the pool + PendingMessage.ClearPool(); + } + } + /// An intermediary buffer to help convert to a byte array when sending. + internal static byte[] ByteBuffer; + /// The maximum number of bits a message can contain. + private static int maxBitCount; + /// The maximum size of the array. + private static int maxArraySize; + + /// How many messages to add to the pool for each or instance that is started. + /// Changes will not affect and instances which are already running until they are restarted. + public static byte InstancesPerPeer { get; set; } = 4; + /// A pool of reusable message instances. + private static readonly List pool = new List(InstancesPerPeer * 2); + + static Message() + { + MaxSize = MaxHeaderSize / BitsPerByte + (MaxHeaderSize % BitsPerByte == 0 ? 0 : 1) + 1225; + maxBitCount = MaxSize * BitsPerByte; + maxArraySize = MaxSize / sizeof(ulong) + (MaxSize % sizeof(ulong) == 0 ? 0 : 1); + ByteBuffer = new byte[MaxSize]; + } + + /// The message's send mode. + public MessageSendMode SendMode { get; private set; } + /// How many bits have been retrieved from the message. + public int ReadBits => readBit; + /// How many unretrieved bits remain in the message. + public int UnreadBits => writeBit - readBit; + /// How many bits have been added to the message. + public int WrittenBits => writeBit; + /// How many more bits can be added to the message. + public int UnwrittenBits => maxBitCount - writeBit; + /// How many of this message's bytes are in use. Rounds up to the next byte because only whole bytes can be sent. + public int BytesInUse => writeBit / BitsPerByte + (writeBit % BitsPerByte == 0 ? 0 : 1); + /// How many bytes have been retrieved from the message. + [Obsolete("Use ReadBits instead.")] public int ReadLength => ReadBits / BitsPerByte + (ReadBits % BitsPerByte == 0 ? 0 : 1); + /// How many more bytes can be retrieved from the message. + [Obsolete("Use UnreadBits instead.")] public int UnreadLength => UnreadBits / BitsPerByte + (UnreadBits % BitsPerByte == 0 ? 0 : 1); + /// How many bytes have been added to the message. + [Obsolete("Use WrittenBits instead.")] public int WrittenLength => WrittenBits / BitsPerByte + (WrittenBits % BitsPerByte == 0 ? 0 : 1); + /// + internal ulong[] Data => data; + + /// The message's data. + private readonly ulong[] data; + /// The next bit to be read. + private int readBit; + /// The next bit to be written. + private int writeBit; + + /// Initializes a reusable instance. + private Message() => data = new ulong[maxArraySize]; + + /// Gets a completely empty message instance with no header. + /// An empty message instance. + public static Message Create() + { + Message message = RetrieveFromPool(); + message.readBit = 0; + message.writeBit = 0; + return message; + } + /// Gets a message instance that can be used for sending. + /// The mode in which the message should be sent. + /// A message instance ready to be sent. + /// This method is primarily intended for use with as notify messages don't have a built-in message ID, and unlike + /// and , this overload does not add a message ID to the message. + public static Message Create(MessageSendMode sendMode) + { + return RetrieveFromPool().Init((MessageHeader)sendMode); + } + /// Gets a message instance that can be used for sending. + /// The mode in which the message should be sent. + /// The message ID. + /// A message instance ready to be sent. + public static Message Create(MessageSendMode sendMode, ushort id) + { + return RetrieveFromPool().Init((MessageHeader)sendMode).AddVarULong(id); + } + /// + /// NOTE: will be cast to a . You should ensure that its value never exceeds that of , otherwise you'll encounter unexpected behaviour when handling messages. + public static Message Create(MessageSendMode sendMode, Enum id) + { + return Create(sendMode, (ushort)(object)id); + } + /// Gets a message instance that can be used for sending. + /// The message's header type. + /// A message instance ready to be sent. + internal static Message Create(MessageHeader header) + { + return RetrieveFromPool().Init(header); + } + + #region Pooling + /// Trims the message pool to a more appropriate size for how many and/or instances are currently running. + public static void TrimPool() + { + if (Peer.ActiveCount == 0) + { + // No Servers or Clients are running, empty the list and reset the capacity + pool.Clear(); + pool.Capacity = InstancesPerPeer * 2; // x2 so there's some buffer room for extra Message instances in the event that more are needed + } + else + { + // Reset the pool capacity and number of Message instances in the pool to what is appropriate for how many Servers & Clients are active + int idealInstanceAmount = Peer.ActiveCount * InstancesPerPeer; + if (pool.Count > idealInstanceAmount) + { + pool.RemoveRange(Peer.ActiveCount * InstancesPerPeer, pool.Count - idealInstanceAmount); + pool.Capacity = idealInstanceAmount * 2; + } + } + } + + /// Retrieves a message instance from the pool. If none is available, a new instance is created. + /// A message instance ready to be used for sending or handling. + private static Message RetrieveFromPool() + { + Message message; + if (pool.Count > 0) + { + message = pool[0]; + pool.RemoveAt(0); + } + else + message = new Message(); + + return message; + } + + /// Returns the message instance to the internal pool so it can be reused. + public void Release() + { + if (pool.Count < pool.Capacity) + { + // Pool exists and there's room + if (!pool.Contains(this)) + pool.Add(this); // Only add it if it's not already in the list, otherwise this method being called twice in a row for whatever reason could cause *serious* issues + } + } + #endregion + + #region Functions + /// Initializes the message so that it can be used for sending. + /// The message's header type. + /// The message, ready to be used for sending. + private Message Init(MessageHeader header) + { + data[0] = (byte)header; + SetHeader(header); + return this; + } + /// Initializes the message so that it can be used for receiving/handling. + /// The first byte of the received data. + /// The message's header type. + /// The number of bytes which this message will contain. + /// The message, ready to be used for handling. + internal Message Init(byte firstByte, int contentLength, out MessageHeader header) + { + data[0] = firstByte; + header = (MessageHeader)(firstByte & HeaderBitmask); + SetHeader(header); + writeBit = contentLength * BitsPerByte; + return this; + } + + /// Sets the message's header bits to the given and determines the appropriate and read/write positions. + /// The header to use for this message. + private void SetHeader(MessageHeader header) + { + if (header == MessageHeader.Notify) + { + readBit = NotifyHeaderBits; + writeBit = NotifyHeaderBits; + SendMode = MessageSendMode.Notify; + } + else if (header >= MessageHeader.Reliable) + { + readBit = ReliableHeaderBits; + writeBit = ReliableHeaderBits; + SendMode = MessageSendMode.Reliable; + } + else + { + readBit = UnreliableHeaderBits; + writeBit = UnreliableHeaderBits; + SendMode = MessageSendMode.Unreliable; + } + } + #endregion + + #region Add & Retrieve Data + #region Message + /// Adds 's unread bits to the message. + /// The message whose unread bits to add. + /// The message that the bits were added to. + /// This method does not move 's internal read position! + public Message AddMessage(Message message) => AddMessage(message, message.UnreadBits, message.readBit); + /// Adds a range of bits from to the message. + /// The message whose bits to add. + /// The number of bits to add. + /// The position in from which to add the bits. + /// The message that the bits were added to. + /// This method does not move 's internal read position! + public Message AddMessage(Message message, int amount, int startBit) + { + if (UnwrittenBits < amount) + throw new InsufficientCapacityException(this, nameof(Message), amount); + + int sourcePos = startBit / BitsPerSegment; + int sourceBit = startBit % BitsPerSegment; + int destPos = writeBit / BitsPerSegment; + int destBit = writeBit % BitsPerSegment; + int bitOffset = destBit - sourceBit; + int destSegments = (writeBit + amount) / BitsPerSegment - destPos + 1; + + if (bitOffset == 0) + { + // Source doesn't need to be shifted, source and dest bits span the same number of segments + ulong firstSegment = message.data[sourcePos]; + if (destBit == 0) + data[destPos] = firstSegment; + else + data[destPos] |= firstSegment & ~((1ul << sourceBit) - 1); + + for (int i = 1; i < destSegments; i++) + data[destPos + i] = message.data[sourcePos + i]; + } + else if (bitOffset > 0) + { + // Source needs to be shifted left, dest bits may span more segments than source bits + ulong firstSegment = message.data[sourcePos] & ~((1ul << sourceBit) - 1); + firstSegment <<= bitOffset; + if (destBit == 0) + data[destPos] = firstSegment; + else + data[destPos] |= firstSegment; + + for (int i = 1; i < destSegments; i++) + data[destPos + i] = (message.data[sourcePos + i - 1] >> (BitsPerSegment - bitOffset)) | (message.data[sourcePos + i] << bitOffset); + } + else + { + // Source needs to be shifted right, source bits may span more segments than dest bits + bitOffset = -bitOffset; + ulong firstSegment = message.data[sourcePos] & ~((1ul << sourceBit) - 1); + firstSegment >>= bitOffset; + if (destBit == 0) + data[destPos] = firstSegment; + else + data[destPos] |= firstSegment; + + int sourceSegments = (startBit + amount) / BitsPerSegment - sourcePos + 1; + for (int i = 1; i < sourceSegments; i++) + { + data[destPos + i - 1] |= message.data[sourcePos + i] << (BitsPerSegment - bitOffset); + data[destPos + i ] = message.data[sourcePos + i] >> bitOffset; + } + } + + writeBit += amount; + data[destPos + destSegments - 1] &= (1ul << (writeBit % BitsPerSegment)) - 1; + return this; + } + #endregion + + #region Bits + /// Moves the message's internal write position by the given of bits, reserving them so they can be set at a later time. + /// The number of bits to reserve. + /// The message instance. + public Message ReserveBits(int amount) + { + if (UnwrittenBits < amount) + throw new InsufficientCapacityException(this, amount); + + int bit = writeBit % BitsPerSegment; + writeBit += amount; + + // Reset the last segment that the reserved range touches, unless it's also the first one, in which case it may already contain data which we don't want to overwrite + if (bit + amount >= BitsPerSegment) + data[writeBit / BitsPerSegment] = 0; + + return this; + } + + /// Moves the message's internal read position by the given of bits, skipping over them. + /// The number of bits to skip. + /// The message instance. + public Message SkipBits(int amount) + { + if (UnreadBits < amount) + RiptideLogger.Log(LogType.Error, $"Message only contains {UnreadBits} unread {Helper.CorrectForm(UnreadBits, "bit")}, which is not enough to skip {amount}!"); + + readBit += amount; + return this; + } + + /// Sets up to 64 bits at the specified position in the message. + /// The bits to write into the message. + /// The number of bits to set. + /// The bit position in the message at which to start writing. + /// The message instance. + /// This method can be used to directly set a range of bits anywhere in the message without moving its internal write position. Data which was previously added to + /// the message and which falls within the range of bits being set will be overwritten, meaning that improper use of this method will likely corrupt the message! + public Message SetBits(ulong bitfield, int amount, int startBit) + { + if (amount > sizeof(ulong) * BitsPerByte) + throw new ArgumentOutOfRangeException(nameof(amount), $"Cannot set more than {sizeof(ulong) * BitsPerByte} bits at a time!"); + + Converter.SetBits(bitfield, amount, data, startBit); + return this; + } + + /// Retrieves up to 8 bits from the specified position in the message. + /// The number of bits to peek. + /// The bit position in the message at which to start peeking. + /// The bits that were retrieved. + /// The message instance. + /// This method can be used to retrieve a range of bits from anywhere in the message without moving its internal read position. + public Message PeekBits(int amount, int startBit, out byte bitfield) + { + if (amount > BitsPerByte) + throw new ArgumentOutOfRangeException(nameof(amount), $"This '{nameof(PeekBits)}' overload cannot be used to peek more than {BitsPerByte} bits at a time!"); + + Converter.GetBits(amount, data, startBit, out bitfield); + return this; + } + /// Retrieves up to 16 bits from the specified position in the message. + /// + public Message PeekBits(int amount, int startBit, out ushort bitfield) + { + if (amount > sizeof(ushort) * BitsPerByte) + throw new ArgumentOutOfRangeException(nameof(amount), $"This '{nameof(PeekBits)}' overload cannot be used to peek more than {sizeof(ushort) * BitsPerByte} bits at a time!"); + + Converter.GetBits(amount, data, startBit, out bitfield); + return this; + } + /// Retrieves up to 32 bits from the specified position in the message. + /// + public Message PeekBits(int amount, int startBit, out uint bitfield) + { + if (amount > sizeof(uint) * BitsPerByte) + throw new ArgumentOutOfRangeException(nameof(amount), $"This '{nameof(PeekBits)}' overload cannot be used to peek more than {sizeof(uint) * BitsPerByte} bits at a time!"); + + Converter.GetBits(amount, data, startBit, out bitfield); + return this; + } + /// Retrieves up to 64 bits from the specified position in the message. + /// + public Message PeekBits(int amount, int startBit, out ulong bitfield) + { + if (amount > sizeof(ulong) * BitsPerByte) + throw new ArgumentOutOfRangeException(nameof(amount), $"This '{nameof(PeekBits)}' overload cannot be used to peek more than {sizeof(ulong) * BitsPerByte} bits at a time!"); + + Converter.GetBits(amount, data, startBit, out bitfield); + return this; + } + + /// Adds up to 8 of the given bits to the message. + /// The bits to add. + /// The number of bits to add. + /// The message that the bits were added to. + public Message AddBits(byte bitfield, int amount) + { + if (amount > BitsPerByte) + throw new ArgumentOutOfRangeException(nameof(amount), $"This '{nameof(AddBits)}' overload cannot be used to add more than {BitsPerByte} bits at a time!"); + + bitfield &= (byte)((1 << amount) - 1); // Discard any bits that are set beyond the ones we're setting + Converter.ByteToBits(bitfield, data, writeBit); + writeBit += amount; + return this; + } + /// Adds up to 16 of the given bits to the message. + /// + public Message AddBits(ushort bitfield, int amount) + { + if (amount > sizeof(ushort) * BitsPerByte) + throw new ArgumentOutOfRangeException(nameof(amount), $"This '{nameof(AddBits)}' overload cannot be used to add more than {sizeof(ushort) * BitsPerByte} bits at a time!"); + + bitfield &= (ushort)((1 << amount) - 1); // Discard any bits that are set beyond the ones we're adding + Converter.UShortToBits(bitfield, data, writeBit); + writeBit += amount; + return this; + } + /// Adds up to 32 of the given bits to the message. + /// + public Message AddBits(uint bitfield, int amount) + { + if (amount > sizeof(uint) * BitsPerByte) + throw new ArgumentOutOfRangeException(nameof(amount), $"This '{nameof(AddBits)}' overload cannot be used to add more than {sizeof(uint) * BitsPerByte} bits at a time!"); + + bitfield &= (1u << (amount - 1) << 1) - 1; // Discard any bits that are set beyond the ones we're adding + Converter.UIntToBits(bitfield, data, writeBit); + writeBit += amount; + return this; + } + /// Adds up to 64 of the given bits to the message. + /// + public Message AddBits(ulong bitfield, int amount) + { + if (amount > sizeof(ulong) * BitsPerByte) + throw new ArgumentOutOfRangeException(nameof(amount), $"This '{nameof(AddBits)}' overload cannot be used to add more than {sizeof(ulong) * BitsPerByte} bits at a time!"); + + bitfield &= (1ul << (amount - 1) << 1) - 1; // Discard any bits that are set beyond the ones we're adding + Converter.ULongToBits(bitfield, data, writeBit); + writeBit += amount; + return this; + } + + /// Retrieves the next bits (up to 8) from the message. + /// The number of bits to retrieve. + /// The bits that were retrieved. + /// The messages that the bits were retrieved from. + public Message GetBits(int amount, out byte bitfield) + { + PeekBits(amount, readBit, out bitfield); + readBit += amount; + return this; + } + /// Retrieves the next bits (up to 16) from the message. + /// + public Message GetBits(int amount, out ushort bitfield) + { + PeekBits(amount, readBit, out bitfield); + readBit += amount; + return this; + } + /// Retrieves the next bits (up to 32) from the message. + /// + public Message GetBits(int amount, out uint bitfield) + { + PeekBits(amount, readBit, out bitfield); + readBit += amount; + return this; + } + /// Retrieves the next bits (up to 64) from the message. + /// + public Message GetBits(int amount, out ulong bitfield) + { + PeekBits(amount, readBit, out bitfield); + readBit += amount; + return this; + } + #endregion + + #region Varint + /// Adds a positive or negative number to the message, using fewer bits for smaller values. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message AddVarLong(long value) => AddVarULong((ulong)Converter.ZigZagEncode(value)); + /// Adds a positive number to the message, using fewer bits for smaller values. + /// The value to add. + /// The message that the value was added to. + /// The value is added in segments of 8 bits, 1 of which is used to indicate whether or not another segment follows. As a result, small values are + /// added to the message using fewer bits, while large values will require a few more bits than they would if they were added via , + /// , , or (or their signed counterparts). + public Message AddVarULong(ulong value) + { + do + { + byte byteValue = (byte)(value & 0b01111111); + value >>= 7; + if (value != 0) // There's more to write + byteValue |= 0b10000000; + + AddByte(byteValue); + } + while (value != 0); + + return this; + } + + /// Retrieves a positive or negative number from the message, using fewer bits for smaller values. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public long GetVarLong() => Converter.ZigZagDecode((long)GetVarULong()); + /// Retrieves a positive number from the message, using fewer bits for smaller values. + /// The value that was retrieved. + /// The value is retrieved in segments of 8 bits, 1 of which is used to indicate whether or not another segment follows. As a result, small values are + /// retrieved from the message using fewer bits, while large values will require a few more bits than they would if they were retrieved via , + /// , , or (or their signed counterparts). + public ulong GetVarULong() + { + ulong byteValue; + ulong value = 0; + int shift = 0; + + do + { + byteValue = GetByte(); + value |= (byteValue & 0b01111111) << shift; + shift += 7; + } + while ((byteValue & 0b10000000) != 0); + + return value; + } + #endregion + + #region Byte & SByte + /// Adds a to the message. + /// The to add. + /// The message that the was added to. + public Message AddByte(byte value) + { + if (UnwrittenBits < BitsPerByte) + throw new InsufficientCapacityException(this, ByteName, BitsPerByte); + + Converter.ByteToBits(value, data, writeBit); + writeBit += BitsPerByte; + return this; + } + + /// Adds an to the message. + /// The to add. + /// The message that the was added to. + public Message AddSByte(sbyte value) + { + if (UnwrittenBits < BitsPerByte) + throw new InsufficientCapacityException(this, SByteName, BitsPerByte); + + Converter.SByteToBits(value, data, writeBit); + writeBit += BitsPerByte; + return this; + } + + /// Retrieves a from the message. + /// The that was retrieved. + public byte GetByte() + { + if (UnreadBits < BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(ByteName, $"{default(byte)}")); + return default(byte); + } + + byte value = Converter.ByteFromBits(data, readBit); + readBit += BitsPerByte; + return value; + } + + /// Retrieves an from the message. + /// The that was retrieved. + public sbyte GetSByte() + { + if (UnreadBits < BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(SByteName, $"{default(sbyte)}")); + return default(sbyte); + } + + sbyte value = Converter.SByteFromBits(data, readBit); + readBit += BitsPerByte; + return value; + } + + /// Adds a array to the message. + /// The array to add. + /// Whether or not to include the length of the array in the message. + /// The message that the array was added to. + public Message AddBytes(byte[] array, bool includeLength = true) + { + if (includeLength) + AddVarULong((uint)array.Length); + + if (UnwrittenBits < array.Length * BitsPerByte) + throw new InsufficientCapacityException(this, array.Length, ByteName, BitsPerByte); + + if (writeBit % BitsPerByte == 0) + { + Buffer.BlockCopy(array, 0, data, writeBit / BitsPerByte, array.Length); + writeBit += array.Length * BitsPerByte; + } + else + { + for (int i = 0; i < array.Length; i++) + { + Converter.ByteToBits(array[i], data, writeBit); + writeBit += BitsPerByte; + } + } + + return this; + } + + /// Adds an array to the message. + /// The array to add. + /// Whether or not to include the length of the array in the message. + /// The message that the array was added to. + public Message AddSBytes(sbyte[] array, bool includeLength = true) + { + if (includeLength) + AddVarULong((uint)array.Length); + + if (UnwrittenBits < array.Length * BitsPerByte) + throw new InsufficientCapacityException(this, array.Length, SByteName, BitsPerByte); + + for (int i = 0; i < array.Length; i++) + { + Converter.SByteToBits(array[i], data, writeBit); + writeBit += BitsPerByte; + } + + return this; + } + + /// Retrieves a array from the message. + /// The array that was retrieved. + public byte[] GetBytes() => GetBytes((int)GetVarULong()); + /// Retrieves a array from the message. + /// The amount of bytes to retrieve. + /// The array that was retrieved. + public byte[] GetBytes(int amount) + { + byte[] array = new byte[amount]; + ReadBytes(amount, array); + return array; + } + /// Populates a array with bytes retrieved from the message. + /// The amount of bytes to retrieve. + /// The array to populate. + /// The position at which to start populating the array. + public void GetBytes(int amount, byte[] intoArray, int startIndex = 0) + { + if (startIndex + amount > intoArray.Length) + throw new ArgumentException(nameof(amount), ArrayNotLongEnoughError(amount, intoArray.Length, startIndex, ByteName)); + + ReadBytes(amount, intoArray, startIndex); + } + + /// Retrieves an array from the message. + /// The array that was retrieved. + public sbyte[] GetSBytes() => GetSBytes((int)GetVarULong()); + /// Retrieves an array from the message. + /// The amount of sbytes to retrieve. + /// The array that was retrieved. + public sbyte[] GetSBytes(int amount) + { + sbyte[] array = new sbyte[amount]; + ReadSBytes(amount, array); + return array; + } + /// Populates a array with bytes retrieved from the message. + /// The amount of sbytes to retrieve. + /// The array to populate. + /// The position at which to start populating . + public void GetSBytes(int amount, sbyte[] intArray, int startIndex = 0) + { + if (startIndex + amount > intArray.Length) + throw new ArgumentException(nameof(amount), ArrayNotLongEnoughError(amount, intArray.Length, startIndex, SByteName)); + + ReadSBytes(amount, intArray, startIndex); + } + + /// Reads a number of bytes from the message and writes them into the given array. + /// The amount of bytes to read. + /// The array to write the bytes into. + /// The position at which to start writing into the array. + private void ReadBytes(int amount, byte[] intoArray, int startIndex = 0) + { + if (UnreadBits < amount * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(amount, ByteName)); + amount = UnreadBits / BitsPerByte; + } + + if (readBit % BitsPerByte == 0) + { + Buffer.BlockCopy(data, readBit / BitsPerByte, intoArray, startIndex, amount); + readBit += amount * BitsPerByte; + } + else + { + for (int i = 0; i < amount; i++) + { + intoArray[startIndex + i] = Converter.ByteFromBits(data, readBit); + readBit += BitsPerByte; + } + } + } + + /// Reads a number of sbytes from the message and writes them into the given array. + /// The amount of sbytes to read. + /// The array to write the sbytes into. + /// The position at which to start writing into the array. + private void ReadSBytes(int amount, sbyte[] intoArray, int startIndex = 0) + { + if (UnreadBits < amount * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(amount, SByteName)); + amount = UnreadBits / BitsPerByte; + } + + for (int i = 0; i < amount; i++) + { + intoArray[startIndex + i] = Converter.SByteFromBits(data, readBit); + readBit += BitsPerByte; + } + } + #endregion + + #region Bool + /// Adds a to the message. + /// The to add. + /// The message that the was added to. + public Message AddBool(bool value) + { + if (UnwrittenBits < 1) + throw new InsufficientCapacityException(this, BoolName, 1); + + Converter.BoolToBit(value, data, writeBit++); + return this; + } + + /// Retrieves a from the message. + /// The that was retrieved. + public bool GetBool() + { + if (UnreadBits < 1) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(BoolName, $"{default(bool)}")); + return default(bool); + } + + return Converter.BoolFromBit(data, readBit++); + } + + /// Adds a array to the message. + /// The array to add. + /// Whether or not to include the length of the array in the message. + /// The message that the array was added to. + public Message AddBools(bool[] array, bool includeLength = true) + { + if (includeLength) + AddVarULong((uint)array.Length); + + if (UnwrittenBits < array.Length) + throw new InsufficientCapacityException(this, array.Length, BoolName, 1); + + for (int i = 0; i < array.Length; i++) + Converter.BoolToBit(array[i], data, writeBit++); + + return this; + } + + /// Retrieves a array from the message. + /// The array that was retrieved. + public bool[] GetBools() => GetBools((int)GetVarULong()); + /// Retrieves a array from the message. + /// The amount of bools to retrieve. + /// The array that was retrieved. + public bool[] GetBools(int amount) + { + bool[] array = new bool[amount]; + ReadBools(amount, array); + return array; + } + /// Populates a array with bools retrieved from the message. + /// The amount of bools to retrieve. + /// The array to populate. + /// The position at which to start populating the array. + public void GetBools(int amount, bool[] intoArray, int startIndex = 0) + { + if (startIndex + amount > intoArray.Length) + throw new ArgumentException(nameof(amount), ArrayNotLongEnoughError(amount, intoArray.Length, startIndex, BoolName)); + + ReadBools(amount, intoArray, startIndex); + } + + /// Reads a number of bools from the message and writes them into the given array. + /// The amount of bools to read. + /// The array to write the bools into. + /// The position at which to start writing into the array. + private void ReadBools(int amount, bool[] intoArray, int startIndex = 0) + { + if (UnreadBits < amount) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(amount, BoolName)); + amount = UnreadBits; + } + + for (int i = 0; i < amount; i++) + intoArray[startIndex + i] = Converter.BoolFromBit(data, readBit++); + } + #endregion + + #region Short & UShort + /// Adds a to the message. + /// The to add. + /// The message that the was added to. + public Message AddShort(short value) + { + if (UnwrittenBits < sizeof(short) * BitsPerByte) + throw new InsufficientCapacityException(this, ShortName, sizeof(short) * BitsPerByte); + + Converter.ShortToBits(value, data, writeBit); + writeBit += sizeof(short) * BitsPerByte; + return this; + } + + /// Adds a to the message. + /// The to add. + /// The message that the was added to. + public Message AddUShort(ushort value) + { + if (UnwrittenBits < sizeof(ushort) * BitsPerByte) + throw new InsufficientCapacityException(this, UShortName, sizeof(ushort) * BitsPerByte); + + Converter.UShortToBits(value, data, writeBit); + writeBit += sizeof(ushort) * BitsPerByte; + return this; + } + + /// Retrieves a from the message. + /// The that was retrieved. + public short GetShort() + { + if (UnreadBits < sizeof(short) * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(ShortName, $"{default(short)}")); + return default(short); + } + + short value = Converter.ShortFromBits(data, readBit); + readBit += sizeof(short) * BitsPerByte; + return value; + } + + /// Retrieves a from the message. + /// The that was retrieved. + public ushort GetUShort() + { + if (UnreadBits < sizeof(ushort) * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(UShortName, $"{default(ushort)}")); + return default(ushort); + } + + ushort value = Converter.UShortFromBits(data, readBit); + readBit += sizeof(ushort) * BitsPerByte; + return value; + } + + /// Adds a array to the message. + /// The array to add. + /// Whether or not to include the length of the array in the message. + /// The message that the array was added to. + public Message AddShorts(short[] array, bool includeLength = true) + { + if (includeLength) + AddVarULong((uint)array.Length); + + if (UnwrittenBits < array.Length * sizeof(short) * BitsPerByte) + throw new InsufficientCapacityException(this, array.Length, ShortName, sizeof(short) * BitsPerByte); + + for (int i = 0; i < array.Length; i++) + { + array[i] = Converter.ShortFromBits(data, readBit); + readBit += sizeof(short) * BitsPerByte; + } + + return this; + } + + /// Adds a array to the message. + /// The array to add. + /// Whether or not to include the length of the array in the message. + /// The message that the array was added to. + public Message AddUShorts(ushort[] array, bool includeLength = true) + { + if (includeLength) + AddVarULong((uint)array.Length); + + if (UnwrittenBits < array.Length * sizeof(ushort) * BitsPerByte) + throw new InsufficientCapacityException(this, array.Length, UShortName, sizeof(ushort) * BitsPerByte); + + for (int i = 0; i < array.Length; i++) + { + array[i] = Converter.UShortFromBits(data, readBit); + readBit += sizeof(ushort) * BitsPerByte; + } + + return this; + } + + /// Retrieves a array from the message. + /// The array that was retrieved. + public short[] GetShorts() => GetShorts((int)GetVarULong()); + /// Retrieves a array from the message. + /// The amount of shorts to retrieve. + /// The array that was retrieved. + public short[] GetShorts(int amount) + { + short[] array = new short[amount]; + ReadShorts(amount, array); + return array; + } + /// Populates a array with shorts retrieved from the message. + /// The amount of shorts to retrieve. + /// The array to populate. + /// The position at which to start populating the array. + public void GetShorts(int amount, short[] intoArray, int startIndex = 0) + { + if (startIndex + amount > intoArray.Length) + throw new ArgumentException(nameof(amount), ArrayNotLongEnoughError(amount, intoArray.Length, startIndex, ShortName)); + + ReadShorts(amount, intoArray, startIndex); + } + + /// Retrieves a array from the message. + /// The array that was retrieved. + public ushort[] GetUShorts() => GetUShorts((int)GetVarULong()); + /// Retrieves a array from the message. + /// The amount of ushorts to retrieve. + /// The array that was retrieved. + public ushort[] GetUShorts(int amount) + { + ushort[] array = new ushort[amount]; + ReadUShorts(amount, array); + return array; + } + /// Populates a array with ushorts retrieved from the message. + /// The amount of ushorts to retrieve. + /// The array to populate. + /// The position at which to start populating the array. + public void GetUShorts(int amount, ushort[] intoArray, int startIndex = 0) + { + if (startIndex + amount > intoArray.Length) + throw new ArgumentException(nameof(amount), ArrayNotLongEnoughError(amount, intoArray.Length, startIndex, UShortName)); + + ReadUShorts(amount, intoArray, startIndex); + } + + /// Reads a number of shorts from the message and writes them into the given array. + /// The amount of shorts to read. + /// The array to write the shorts into. + /// The position at which to start writing into the array. + private void ReadShorts(int amount, short[] intoArray, int startIndex = 0) + { + if (UnreadBits < amount * sizeof(short) * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(amount, ShortName)); + amount = UnreadBits / (sizeof(short) * BitsPerByte); + } + + for (int i = 0; i < amount; i++) + { + intoArray[startIndex + i] = Converter.ShortFromBits(data, readBit); + readBit += sizeof(short) * BitsPerByte; + } + } + + /// Reads a number of ushorts from the message and writes them into the given array. + /// The amount of ushorts to read. + /// The array to write the ushorts into. + /// The position at which to start writing into the array. + private void ReadUShorts(int amount, ushort[] intoArray, int startIndex = 0) + { + if (UnreadBits < amount * sizeof(ushort) * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(amount, UShortName)); + amount = UnreadBits / (sizeof(ushort) * BitsPerByte); + } + + for (int i = 0; i < amount; i++) + { + intoArray[startIndex + i] = Converter.UShortFromBits(data, readBit); + readBit += sizeof(ushort) * BitsPerByte; + } + } + #endregion + + #region Int & UInt + /// Adds an to the message. + /// The to add. + /// The message that the was added to. + public Message AddInt(int value) + { + if (UnwrittenBits < sizeof(int) * BitsPerByte) + throw new InsufficientCapacityException(this, IntName, sizeof(int) * BitsPerByte); + + Converter.IntToBits(value, data, writeBit); + writeBit += sizeof(int) * BitsPerByte; + return this; + } + + /// Adds a to the message. + /// The to add. + /// The message that the was added to. + public Message AddUInt(uint value) + { + if (UnwrittenBits < sizeof(uint) * BitsPerByte) + throw new InsufficientCapacityException(this, UIntName, sizeof(uint) * BitsPerByte); + + Converter.UIntToBits(value, data, writeBit); + writeBit += sizeof(uint) * BitsPerByte; + return this; + } + + /// Retrieves an from the message. + /// The that was retrieved. + public int GetInt() + { + if (UnreadBits < sizeof(int) * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(IntName, $"{default(int)}")); + return default(int); + } + + int value = Converter.IntFromBits(data, readBit); + readBit += sizeof(int) * BitsPerByte; + return value; + } + + /// Retrieves a from the message. + /// The that was retrieved. + public uint GetUInt() + { + if (UnreadBits < sizeof(uint) * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(UIntName, $"{default(uint)}")); + return default(uint); + } + + uint value = Converter.UIntFromBits(data, readBit); + readBit += sizeof(uint) * BitsPerByte; + return value; + } + + /// Adds an array message. + /// The array to add. + /// Whether or not to include the length of the array in the message. + /// The message that the array was added to. + public Message AddInts(int[] array, bool includeLength = true) + { + if (includeLength) + AddVarULong((uint)array.Length); + + if (UnwrittenBits < array.Length * sizeof(int) * BitsPerByte) + throw new InsufficientCapacityException(this, array.Length, IntName, sizeof(int) * BitsPerByte); + + for (int i = 0; i < array.Length; i++) + { + Converter.IntToBits(array[i], data, writeBit); + writeBit += sizeof(int) * BitsPerByte; + } + + return this; + } + + /// Adds a array to the message. + /// The array to add. + /// Whether or not to include the length of the array in the message. + /// The message that the array was added to. + public Message AddUInts(uint[] array, bool includeLength = true) + { + if (includeLength) + AddVarULong((uint)array.Length); + + if (UnwrittenBits < array.Length * sizeof(uint) * BitsPerByte) + throw new InsufficientCapacityException(this, array.Length, UIntName, sizeof(uint) * BitsPerByte); + + for (int i = 0; i < array.Length; i++) + { + Converter.UIntToBits(array[i], data, writeBit); + writeBit += sizeof(uint) * BitsPerByte; + } + + return this; + } + + /// Retrieves an array from the message. + /// The array that was retrieved. + public int[] GetInts() => GetInts((int)GetVarULong()); + /// Retrieves an array from the message. + /// The amount of ints to retrieve. + /// The array that was retrieved. + public int[] GetInts(int amount) + { + int[] array = new int[amount]; + ReadInts(amount, array); + return array; + } + /// Populates an array with ints retrieved from the message. + /// The amount of ints to retrieve. + /// The array to populate. + /// The position at which to start populating the array. + public void GetInts(int amount, int[] intoArray, int startIndex = 0) + { + if (startIndex + amount > intoArray.Length) + throw new ArgumentException(nameof(amount), ArrayNotLongEnoughError(amount, intoArray.Length, startIndex, IntName)); + + ReadInts(amount, intoArray, startIndex); + } + + /// Retrieves a array from the message. + /// The array that was retrieved. + public uint[] GetUInts() => GetUInts((int)GetVarULong()); + /// Retrieves a array from the message. + /// The amount of uints to retrieve. + /// The array that was retrieved. + public uint[] GetUInts(int amount) + { + uint[] array = new uint[amount]; + ReadUInts(amount, array); + return array; + } + /// Populates a array with uints retrieved from the message. + /// The amount of uints to retrieve. + /// The array to populate. + /// The position at which to start populating the array. + public void GetUInts(int amount, uint[] intoArray, int startIndex = 0) + { + if (startIndex + amount > intoArray.Length) + throw new ArgumentException(nameof(amount), ArrayNotLongEnoughError(amount, intoArray.Length, startIndex, UIntName)); + + ReadUInts(amount, intoArray, startIndex); + } + + /// Reads a number of ints from the message and writes them into the given array. + /// The amount of ints to read. + /// The array to write the ints into. + /// The position at which to start writing into the array. + private void ReadInts(int amount, int[] intoArray, int startIndex = 0) + { + if (UnreadBits < amount * sizeof(int) * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(amount, IntName)); + amount = UnreadBits / (sizeof(int) * BitsPerByte); + } + + for (int i = 0; i < amount; i++) + { + intoArray[startIndex + i] = Converter.IntFromBits(data, readBit); + readBit += sizeof(int) * BitsPerByte; + } + } + + /// Reads a number of uints from the message and writes them into the given array. + /// The amount of uints to read. + /// The array to write the uints into. + /// The position at which to start writing into the array. + private void ReadUInts(int amount, uint[] intoArray, int startIndex = 0) + { + if (UnreadBits < amount * sizeof(uint) * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(amount, UIntName)); + amount = UnreadBits / (sizeof(uint) * BitsPerByte); + } + + for (int i = 0; i < amount; i++) + { + intoArray[startIndex + i] = Converter.UIntFromBits(data, readBit); + readBit += sizeof(uint) * BitsPerByte; + } + } + #endregion + + #region Long & ULong + /// Adds a to the message. + /// The to add. + /// The message that the was added to. + public Message AddLong(long value) + { + if (UnwrittenBits < sizeof(long) * BitsPerByte) + throw new InsufficientCapacityException(this, LongName, sizeof(long) * BitsPerByte); + + Converter.LongToBits(value, data, writeBit); + writeBit += sizeof(long) * BitsPerByte; + return this; + } + + /// Adds a to the message. + /// The to add. + /// The message that the was added to. + public Message AddULong(ulong value) + { + if (UnwrittenBits < sizeof(ulong) * BitsPerByte) + throw new InsufficientCapacityException(this, ULongName, sizeof(ulong) * BitsPerByte); + + Converter.ULongToBits(value, data, writeBit); + writeBit += sizeof(ulong) * BitsPerByte; + return this; + } + + /// Retrieves a from the message. + /// The that was retrieved. + public long GetLong() + { + if (UnreadBits < sizeof(long) * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(LongName, $"{default(long)}")); + return default(long); + } + + long value = Converter.LongFromBits(data, readBit); + readBit += sizeof(long) * BitsPerByte; + return value; + } + + /// Retrieves a from the message. + /// The that was retrieved. + public ulong GetULong() + { + if (UnreadBits < sizeof(ulong) * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(ULongName, $"{default(ulong)}")); + return default(ulong); + } + + ulong value = Converter.ULongFromBits(data, readBit); + readBit += sizeof(ulong) * BitsPerByte; + return value; + } + + /// Adds a array to the message. + /// The array to add. + /// Whether or not to include the length of the array in the message. + /// The message that the array was added to. + public Message AddLongs(long[] array, bool includeLength = true) + { + if (includeLength) + AddVarULong((uint)array.Length); + + if (UnwrittenBits < array.Length * sizeof(long) * BitsPerByte) + throw new InsufficientCapacityException(this, array.Length, LongName, sizeof(long) * BitsPerByte); + + for (int i = 0; i < array.Length; i++) + { + Converter.LongToBits(array[i], data, writeBit); + writeBit += sizeof(long) * BitsPerByte; + } + + return this; + } + + /// Adds a array to the message. + /// The array to add. + /// Whether or not to include the length of the array in the message. + /// The message that the array was added to. + public Message AddULongs(ulong[] array, bool includeLength = true) + { + if (includeLength) + AddVarULong((uint)array.Length); + + if (UnwrittenBits < array.Length * sizeof(ulong) * BitsPerByte) + throw new InsufficientCapacityException(this, array.Length, ULongName, sizeof(ulong) * BitsPerByte); + + for (int i = 0; i < array.Length; i++) + { + Converter.ULongToBits(array[i], data, writeBit); + writeBit += sizeof(ulong) * BitsPerByte; + } + + return this; + } + + /// Retrieves a array from the message. + /// The array that was retrieved. + public long[] GetLongs() => GetLongs((int)GetVarULong()); + /// Retrieves a array from the message. + /// The amount of longs to retrieve. + /// The array that was retrieved. + public long[] GetLongs(int amount) + { + long[] array = new long[amount]; + ReadLongs(amount, array); + return array; + } + /// Populates a array with longs retrieved from the message. + /// The amount of longs to retrieve. + /// The array to populate. + /// The position at which to start populating the array. + public void GetLongs(int amount, long[] intoArray, int startIndex = 0) + { + if (startIndex + amount > intoArray.Length) + throw new ArgumentException(nameof(amount), ArrayNotLongEnoughError(amount, intoArray.Length, startIndex, LongName)); + + ReadLongs(amount, intoArray, startIndex); + } + + /// Retrieves a array from the message. + /// The array that was retrieved. + public ulong[] GetULongs() => GetULongs((int)GetVarULong()); + /// Retrieves a array from the message. + /// The amount of ulongs to retrieve. + /// The array that was retrieved. + public ulong[] GetULongs(int amount) + { + ulong[] array = new ulong[amount]; + ReadULongs(amount, array); + return array; + } + /// Populates a array with ulongs retrieved from the message. + /// The amount of ulongs to retrieve. + /// The array to populate. + /// The position at which to start populating the array. + public void GetULongs(int amount, ulong[] intoArray, int startIndex = 0) + { + if (startIndex + amount > intoArray.Length) + throw new ArgumentException(nameof(amount), ArrayNotLongEnoughError(amount, intoArray.Length, startIndex, ULongName)); + + ReadULongs(amount, intoArray, startIndex); + } + + /// Reads a number of longs from the message and writes them into the given array. + /// The amount of longs to read. + /// The array to write the longs into. + /// The position at which to start writing into the array. + private void ReadLongs(int amount, long[] intoArray, int startIndex = 0) + { + if (UnreadBits < amount * sizeof(long) * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(amount, LongName)); + amount = UnreadBits / (sizeof(long) * BitsPerByte); + } + + for (int i = 0; i < amount; i++) + { + intoArray[startIndex + i] = Converter.LongFromBits(data, readBit); + readBit += sizeof(long) * BitsPerByte; + } + } + + /// Reads a number of ulongs from the message and writes them into the given array. + /// The amount of ulongs to read. + /// The array to write the ulongs into. + /// The position at which to start writing into the array. + private void ReadULongs(int amount, ulong[] intoArray, int startIndex = 0) + { + if (UnreadBits < amount * sizeof(ulong) * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(amount, ULongName)); + amount = UnreadBits / (sizeof(ulong) * BitsPerByte); + } + + for (int i = 0; i < amount; i++) + { + intoArray[startIndex + i] = Converter.ULongFromBits(data, readBit); + readBit += sizeof(ulong) * BitsPerByte; + } + } + #endregion + + #region Float + /// Adds a to the message. + /// The to add. + /// The message that the was added to. + public Message AddFloat(float value) + { + if (UnwrittenBits < sizeof(float) * BitsPerByte) + throw new InsufficientCapacityException(this, FloatName, sizeof(float) * BitsPerByte); + + Converter.FloatToBits(value, data, writeBit); + writeBit += sizeof(float) * BitsPerByte; + return this; + } + + /// Retrieves a from the message. + /// The that was retrieved. + public float GetFloat() + { + if (UnreadBits < sizeof(float) * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(FloatName, $"{default(float)}")); + return default(float); + } + + float value = Converter.FloatFromBits(data, readBit); + readBit += sizeof(float) * BitsPerByte; + return value; + } + + /// Adds a array to the message. + /// The array to add. + /// Whether or not to include the length of the array in the message. + /// The message that the array was added to. + public Message AddFloats(float[] array, bool includeLength = true) + { + if (includeLength) + AddVarULong((uint)array.Length); + + if (UnwrittenBits < array.Length * sizeof(float) * BitsPerByte) + throw new InsufficientCapacityException(this, array.Length, FloatName, sizeof(float) * BitsPerByte); + + for (int i = 0; i < array.Length; i++) + { + Converter.FloatToBits(array[i], data, writeBit); + writeBit += sizeof(float) * BitsPerByte; + } + + return this; + } + + /// Retrieves a array from the message. + /// The array that was retrieved. + public float[] GetFloats() => GetFloats((int)GetVarULong()); + /// Retrieves a array from the message. + /// The amount of floats to retrieve. + /// The array that was retrieved. + public float[] GetFloats(int amount) + { + float[] array = new float[amount]; + ReadFloats(amount, array); + return array; + } + /// Populates a array with floats retrieved from the message. + /// The amount of floats to retrieve. + /// The array to populate. + /// The position at which to start populating the array. + public void GetFloats(int amount, float[] intoArray, int startIndex = 0) + { + if (startIndex + amount > intoArray.Length) + throw new ArgumentException(nameof(amount), ArrayNotLongEnoughError(amount, intoArray.Length, startIndex, FloatName)); + + ReadFloats(amount, intoArray, startIndex); + } + + /// Reads a number of floats from the message and writes them into the given array. + /// The amount of floats to read. + /// The array to write the floats into. + /// The position at which to start writing into the array. + private void ReadFloats(int amount, float[] intoArray, int startIndex = 0) + { + if (UnreadBits < amount * sizeof(float) * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(amount, FloatName)); + amount = UnreadBits / (sizeof(float) * BitsPerByte); + } + + for (int i = 0; i < amount; i++) + { + intoArray[startIndex + i] = Converter.FloatFromBits(data, readBit); + readBit += sizeof(float) * BitsPerByte; + } + } + #endregion + + #region Double + /// Adds a to the message. + /// The to add. + /// The message that the was added to. + public Message AddDouble(double value) + { + if (UnwrittenBits < sizeof(double) * BitsPerByte) + throw new InsufficientCapacityException(this, DoubleName, sizeof(double) * BitsPerByte); + + Converter.DoubleToBits(value, data, writeBit); + writeBit += sizeof(double) * BitsPerByte; + return this; + } + + /// Retrieves a from the message. + /// The that was retrieved. + public double GetDouble() + { + if (UnreadBits < sizeof(double) * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(DoubleName, $"{default(double)}")); + return default(double); + } + + double value = Converter.DoubleFromBits(data, readBit); + readBit += sizeof(double) * BitsPerByte; + return value; + } + + /// Adds a array to the message. + /// The array to add. + /// Whether or not to include the length of the array in the message. + /// The message that the array was added to. + public Message AddDoubles(double[] array, bool includeLength = true) + { + if (includeLength) + AddVarULong((uint)array.Length); + + if (UnwrittenBits < array.Length * sizeof(double) * BitsPerByte) + throw new InsufficientCapacityException(this, array.Length, DoubleName, sizeof(double) * BitsPerByte); + + for (int i = 0; i < array.Length; i++) + { + Converter.DoubleToBits(array[i], data, writeBit); + writeBit += sizeof(double) * BitsPerByte; + } + + return this; + } + + /// Retrieves a array from the message. + /// The array that was retrieved. + public double[] GetDoubles() => GetDoubles((int)GetVarULong()); + /// Retrieves a array from the message. + /// The amount of doubles to retrieve. + /// The array that was retrieved. + public double[] GetDoubles(int amount) + { + double[] array = new double[amount]; + ReadDoubles(amount, array); + return array; + } + /// Populates a array with doubles retrieved from the message. + /// The amount of doubles to retrieve. + /// The array to populate. + /// The position at which to start populating the array. + public void GetDoubles(int amount, double[] intoArray, int startIndex = 0) + { + if (startIndex + amount > intoArray.Length) + throw new ArgumentException(nameof(amount), ArrayNotLongEnoughError(amount, intoArray.Length, startIndex, DoubleName)); + + ReadDoubles(amount, intoArray, startIndex); + } + + /// Reads a number of doubles from the message and writes them into the given array. + /// The amount of doubles to read. + /// The array to write the doubles into. + /// The position at which to start writing into the array. + private void ReadDoubles(int amount, double[] intoArray, int startIndex = 0) + { + if (UnreadBits < amount * sizeof(double) * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(amount, DoubleName)); + amount = UnreadBits / (sizeof(double) * BitsPerByte); + } + + for (int i = 0; i < amount; i++) + { + intoArray[startIndex + i] = Converter.DoubleFromBits(data, readBit); + readBit += sizeof(double) * BitsPerByte; + } + } + #endregion + + #region String + /// Adds a to the message. + /// The to add. + /// The message that the was added to. + public Message AddString(string value) + { + AddBytes(Encoding.UTF8.GetBytes(value)); + return this; + } + + /// Retrieves a from the message. + /// The that was retrieved. + public string GetString() + { + int length = (int)GetVarULong(); // Get the length of the string (in bytes, NOT characters) + if (UnreadBits < length * BitsPerByte) + { + RiptideLogger.Log(LogType.Error, NotEnoughBitsError(StringName, "shortened string")); + length = UnreadBits / BitsPerByte; + } + + string value = Encoding.UTF8.GetString(GetBytes(length), 0, length); + return value; + } + + /// Adds a array to the message. + /// The array to add. + /// Whether or not to include the length of the array in the message. + /// The message that the array was added to. + public Message AddStrings(string[] array, bool includeLength = true) + { + if (includeLength) + AddVarULong((uint)array.Length); + + // It'd be ideal to throw an exception here (instead of in AddString) if the entire array isn't going to fit, but since each string could + // be (and most likely is) a different length and some characters use more than a single byte, the only way of doing that would be to loop + // through the whole array here and convert each string to bytes ahead of time, just to get the required byte count. Then if they all fit + // into the message, they would all be converted again when actually being written into the byte array, which is obviously inefficient. + + for (int i = 0; i < array.Length; i++) + AddString(array[i]); + + return this; + } + + /// Retrieves a array from the message. + /// The array that was retrieved. + public string[] GetStrings() => GetStrings((int)GetVarULong()); + /// Retrieves a array from the message. + /// The amount of strings to retrieve. + /// The array that was retrieved. + public string[] GetStrings(int amount) + { + string[] array = new string[amount]; + for (int i = 0; i < array.Length; i++) + array[i] = GetString(); + + return array; + } + /// Populates a array with strings retrieved from the message. + /// The amount of strings to retrieve. + /// The array to populate. + /// The position at which to start populating the array. + public void GetStrings(int amount, string[] intoArray, int startIndex = 0) + { + if (startIndex + amount > intoArray.Length) + throw new ArgumentException(nameof(amount), ArrayNotLongEnoughError(amount, intoArray.Length, startIndex, StringName)); + + for (int i = 0; i < amount; i++) + intoArray[startIndex + i] = GetString(); + } + #endregion + + #region IMessageSerializable Types + /// Adds a serializable to the message. + /// The serializable to add. + /// The message that the serializable was added to. + public Message AddSerializable(T value) where T : IMessageSerializable + { + value.Serialize(this); + return this; + } + + /// Retrieves a serializable from the message. + /// The serializable that was retrieved. + public T GetSerializable() where T : IMessageSerializable, new() + { + T t = new T(); + t.Deserialize(this); + return t; + } + + /// Adds an array of serializables to the message. + /// The array to add. + /// Whether or not to include the length of the array in the message. + /// The message that the array was added to. + public Message AddSerializables(T[] array, bool includeLength = true) where T : IMessageSerializable + { + if (includeLength) + AddVarULong((uint)array.Length); + + for (int i = 0; i < array.Length; i++) + AddSerializable(array[i]); + + return this; + } + + /// Retrieves an array of serializables from the message. + /// The array that was retrieved. + public T[] GetSerializables() where T : IMessageSerializable, new() => GetSerializables((int)GetVarULong()); + /// Retrieves an array of serializables from the message. + /// The amount of serializables to retrieve. + /// The array that was retrieved. + public T[] GetSerializables(int amount) where T : IMessageSerializable, new() + { + T[] array = new T[amount]; + ReadSerializables(amount, array); + return array; + } + /// Populates an array of serializables retrieved from the message. + /// The amount of serializables to retrieve. + /// The array to populate. + /// The position at which to start populating the array. + public void GetSerializables(int amount, T[] intoArray, int startIndex = 0) where T : IMessageSerializable, new() + { + if (startIndex + amount > intoArray.Length) + throw new ArgumentException(nameof(amount), ArrayNotLongEnoughError(amount, intoArray.Length, startIndex, typeof(T).Name)); + + ReadSerializables(amount, intoArray, startIndex); + } + + /// Reads a number of serializables from the message and writes them into the given array. + /// The amount of serializables to read. + /// The array to write the serializables into. + /// The position at which to start writing into . + private void ReadSerializables(int amount, T[] intArray, int startIndex = 0) where T : IMessageSerializable, new() + { + for (int i = 0; i < amount; i++) + intArray[startIndex + i] = GetSerializable(); + } + #endregion + + #region Overload Versions + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(byte value) => AddByte(value); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(sbyte value) => AddSByte(value); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(bool value) => AddBool(value); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(short value) => AddShort(value); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(ushort value) => AddUShort(value); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(int value) => AddInt(value); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(uint value) => AddUInt(value); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(long value) => AddLong(value); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(ulong value) => AddULong(value); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(float value) => AddFloat(value); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(double value) => AddDouble(value); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(string value) => AddString(value); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(T value) where T : IMessageSerializable => AddSerializable(value); + + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(byte[] array, bool includeLength = true) => AddBytes(array, includeLength); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(sbyte[] array, bool includeLength = true) => AddSBytes(array, includeLength); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(bool[] array, bool includeLength = true) => AddBools(array, includeLength); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(short[] array, bool includeLength = true) => AddShorts(array, includeLength); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(ushort[] array, bool includeLength = true) => AddUShorts(array, includeLength); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(int[] array, bool includeLength = true) => AddInts(array, includeLength); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(uint[] array, bool includeLength = true) => AddUInts(array, includeLength); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(long[] array, bool includeLength = true) => AddLongs(array, includeLength); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(ulong[] array, bool includeLength = true) => AddULongs(array, includeLength); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(float[] array, bool includeLength = true) => AddFloats(array, includeLength); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(double[] array, bool includeLength = true) => AddDoubles(array, includeLength); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(string[] array, bool includeLength = true) => AddStrings(array, includeLength); + /// + /// This method is simply an alternative way of calling . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Message Add(T[] array, bool includeLength = true) where T : IMessageSerializable, new() => AddSerializables(array, includeLength); + #endregion + #endregion + + #region Error Messaging + /// The name of a value. + private const string ByteName = "byte"; + /// The name of a value. + private const string SByteName = "sbyte"; + /// The name of a value. + private const string BoolName = "bool"; + /// The name of a value. + private const string ShortName = "short"; + /// The name of a value. + private const string UShortName = "ushort"; + /// The name of an value. + private const string IntName = "int"; + /// The name of a value. + private const string UIntName = "uint"; + /// The name of a value. + private const string LongName = "long"; + /// The name of a value. + private const string ULongName = "ulong"; + /// The name of a value. + private const string FloatName = "float"; + /// The name of a value. + private const string DoubleName = "double"; + /// The name of a value. + private const string StringName = "string"; + /// The name of an array length value. + private const string ArrayLengthName = "array length"; + + /// Constructs an error message for when a message contains insufficient unread bits to retrieve a certain value. + /// The name of the value type for which the retrieval attempt failed. + /// Text describing the value which will be returned. + /// The error message. + private string NotEnoughBitsError(string valueName, string defaultReturn) + { + return $"Message only contains {UnreadBits} unread {Helper.CorrectForm(UnreadBits, "bit")}, which is not enough to retrieve a value of type '{valueName}'! Returning {defaultReturn}."; + } + /// Constructs an error message for when a message contains insufficient unread bits to retrieve an array of values. + /// The expected length of the array. + /// The name of the value type for which the retrieval attempt failed. + /// The error message. + private string NotEnoughBitsError(int arrayLength, string valueName) + { + return $"Message only contains {UnreadBits} unread {Helper.CorrectForm(UnreadBits, "bit")}, which is not enough to retrieve {arrayLength} {Helper.CorrectForm(arrayLength, valueName)}! Returned array will contain default elements."; + } + + /// Constructs an error message for when a number of retrieved values do not fit inside the bounds of the provided array. + /// The number of values being retrieved. + /// The length of the provided array. + /// The position in the array at which to begin writing values. + /// The name of the value type which is being retrieved. + /// The name of the value type in plural form. If left empty, this will be set to with an s appended to it. + /// The error message. + private string ArrayNotLongEnoughError(int amount, int arrayLength, int startIndex, string valueName, string pluralValueName = "") + { + if (string.IsNullOrEmpty(pluralValueName)) + pluralValueName = $"{valueName}s"; + + return $"The amount of {pluralValueName} to retrieve ({amount}) is greater than the number of elements from the start index ({startIndex}) to the end of the given array (length: {arrayLength})!"; + } + #endregion + } +} diff --git a/Riptide/MessageHandlerAttribute.cs b/Riptide/MessageHandlerAttribute.cs new file mode 100644 index 0000000..7df9e87 --- /dev/null +++ b/Riptide/MessageHandlerAttribute.cs @@ -0,0 +1,50 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; + +namespace Riptide +{ + /// Specifies a method as the message handler for messages with the given ID. + /// + /// + /// In order for a method to qualify as a message handler, it must match a valid message handler method signature. s + /// will only use methods marked with this attribute if they match the signature, and s + /// will only use methods marked with this attribute if they match the signature. + /// + /// + /// Methods marked with this attribute which match neither of the valid message handler signatures will not be used by s + /// or s and will cause warnings at runtime. + /// + /// + /// If you want a or to only use a subset of all message handler methods, you can do so by setting up + /// custom message handler groups. Simply set the group ID in the constructor and pass the + /// same value to the or method. This + /// will make that or only use message handlers which have the same group ID. + /// + /// + [AttributeUsage(AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + public sealed class MessageHandlerAttribute : Attribute + { + /// The ID of the message type which this method is meant to handle. + public readonly ushort MessageId; + /// The ID of the group of message handlers which this method belongs to. + public readonly byte GroupId; + + /// Initializes a new instance of the class with the and values. + /// The ID of the message type which this method is meant to handle. + /// The ID of the group of message handlers which this method belongs to. + /// + /// s will only use this method if its signature matches the signature. + /// s will only use this method if its signature matches the signature. + /// This method will be ignored if its signature matches neither of the valid message handler signatures. + /// + public MessageHandlerAttribute(ushort messageId, byte groupId = 0) + { + MessageId = messageId; + GroupId = groupId; + } + } +} diff --git a/Riptide/MessageRelayFilter.cs b/Riptide/MessageRelayFilter.cs new file mode 100644 index 0000000..b312162 --- /dev/null +++ b/Riptide/MessageRelayFilter.cs @@ -0,0 +1,106 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; +using System.Linq; + +namespace Riptide +{ + /// Provides functionality for enabling/disabling automatic message relaying by message type. + public class MessageRelayFilter + { + /// The number of bits an int consists of. + private const int BitsPerInt = sizeof(int) * 8; + + /// An array storing all the bits which represent whether messages of a given ID should be relayed or not. + private int[] filter; + + /// Creates a filter of a given size. + /// How big to make the filter. + /// + /// should be set to the value of the largest message ID, plus 1. For example, if a server will + /// handle messages with IDs 1, 2, 3, 7, and 8, should be set to 9 (8 is the largest possible value, + /// and 8 + 1 = 9) despite the fact that there are only 5 unique message IDs the server will ever handle. + /// + public MessageRelayFilter(int size) => Set(size); + /// Creates a filter based on an enum of message IDs. + /// The enum type. + public MessageRelayFilter(Type idEnum) => Set(GetSizeFromEnum(idEnum)); + /// Creates a filter of a given size and enables relaying for the given message IDs. + /// How big to make the filter. + /// Message IDs to enable auto relaying for. + /// + /// should be set to the value of the largest message ID, plus 1. For example, if a server will + /// handle messages with IDs 1, 2, 3, 7, and 8, should be set to 9 (8 is the largest possible value, + /// and 8 + 1 = 9) despite the fact that there are only 5 unique message IDs the server will ever handle. + /// + public MessageRelayFilter(int size, params ushort[] idsToEnable) + { + Set(size); + EnableIds(idsToEnable); + } + /// Creates a filter based on an enum of message IDs and enables relaying for the given message IDs. + /// The enum type. + /// Message IDs to enable relaying for. + public MessageRelayFilter(Type idEnum, params Enum[] idsToEnable) + { + Set(GetSizeFromEnum(idEnum)); + EnableIds(idsToEnable.Cast().ToArray()); + } + + /// Enables auto relaying for the given message IDs. + /// Message IDs to enable relaying for. + private void EnableIds(ushort[] idsToEnable) + { + for (int i = 0; i < idsToEnable.Length; i++) + EnableRelay(idsToEnable[i]); + } + + /// Calculate the filter size necessary to manage all message IDs in the given enum. + /// The enum type. + /// The appropriate filter size. + /// is not an . + private int GetSizeFromEnum(Type idEnum) + { + if (!idEnum.IsEnum) + throw new ArgumentException($"Parameter '{nameof(idEnum)}' must be an enum type!", nameof(idEnum)); + + return Enum.GetValues(idEnum).Cast().Max() + 1; + } + + /// Sets the filter size. + /// How big to make the filter. + private void Set(int size) + { + filter = new int[size / BitsPerInt + (size % BitsPerInt > 0 ? 1 : 0)]; + } + + /// Enables auto relaying for the given message ID. + /// The message ID to enable relaying for. + public void EnableRelay(ushort forMessageId) + { + filter[forMessageId / BitsPerInt] |= 1 << (forMessageId % BitsPerInt); + } + /// + public void EnableRelay(Enum forMessageId) => EnableRelay((ushort)(object)forMessageId); + + /// Disables auto relaying for the given message ID. + /// The message ID to enable relaying for. + public void DisableRelay(ushort forMessageId) + { + filter[forMessageId / BitsPerInt] &= ~(1 << (forMessageId % BitsPerInt)); + } + /// + public void DisableRelay(Enum forMessageId) => DisableRelay((ushort)(object)forMessageId); + + /// Checks whether or not messages with the given ID should be relayed. + /// The message ID to check. + /// Whether or not messages with the given ID should be relayed. + internal bool ShouldRelay(ushort forMessageId) + { + return (filter[forMessageId / BitsPerInt] & (1 << (forMessageId % BitsPerInt))) != 0; + } + } +} diff --git a/Riptide/Peer.cs b/Riptide/Peer.cs new file mode 100644 index 0000000..deaea40 --- /dev/null +++ b/Riptide/Peer.cs @@ -0,0 +1,242 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using Riptide.Transports; +using Riptide.Utils; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; + +namespace Riptide +{ + /// The reason the connection attempt was rejected. + public enum RejectReason : byte + { + /// No response was received from the server (because the client has no internet connection, the server is offline, no server is listening on the target endpoint, etc.). + NoConnection, + /// The client is already connected. + AlreadyConnected, + /// The server is full. + ServerFull, + /// The connection attempt was rejected. + Rejected, + /// The connection attempt was rejected and custom data may have been included with the rejection message. + Custom + } + + /// The reason for a disconnection. + public enum DisconnectReason : byte + { + /// No connection was ever established. + NeverConnected, + /// The connection attempt was rejected by the server. + ConnectionRejected, + /// The active transport detected a problem with the connection. + TransportError, + /// The connection timed out. + /// + /// This also acts as the fallback reason—if a client disconnects and the message containing the real reason is lost + /// in transmission, it can't be resent as the connection will have already been closed. As a result, the other end will time + /// out the connection after a short period of time and this will be used as the reason. + /// + TimedOut, + /// The client was forcibly disconnected by the server. + Kicked, + /// The server shut down. + ServerStopped, + /// The disconnection was initiated by the client. + Disconnected, + /// The connection's loss and/or resend rates exceeded the maximum acceptable thresholds, or a reliably sent message could not be delivered. + PoorConnection + } + + /// Provides base functionality for and . + public abstract class Peer + { + /// The name to use when logging messages via . + public readonly string LogName; + /// Sets the relevant connections' s. + public abstract int TimeoutTime { set; } + /// The interval (in milliseconds) at which to send and expect heartbeats to be received. + /// Changes to this value will only take effect after the next heartbeat is executed. + public int HeartbeatInterval { get; set; } = 1000; + + /// The number of currently active and instances. + internal static int ActiveCount { get; private set; } + + /// The time (in milliseconds) for which to wait before giving up on a connection attempt. + internal int ConnectTimeoutTime { get; set; } = 10000; + /// The current time. + internal long CurrentTime { get; private set; } + + /// Whether or not the peer should use the built-in message handler system. + protected bool useMessageHandlers; + /// The default time (in milliseconds) after which to disconnect if no heartbeats are received. + protected int defaultTimeout = 5000; + + /// A stopwatch used to track how much time has passed. + private readonly System.Diagnostics.Stopwatch time = new System.Diagnostics.Stopwatch(); + /// Received messages which need to be handled. + private readonly Queue messagesToHandle = new Queue(); + /// A queue of events to execute, ordered by how soon they need to be executed. + private readonly PriorityQueue eventQueue = new PriorityQueue(); + + /// Initializes the peer. + /// The name to use when logging messages via . + public Peer(string logName) + { + LogName = logName; + } + + /// Retrieves methods marked with . + /// An array containing message handler methods. + protected MethodInfo[] FindMessageHandlers() + { + return new MethodInfo[0]; + /*string thisAssemblyName = Assembly.GetExecutingAssembly().GetName().FullName; + return AppDomain.CurrentDomain.GetAssemblies() + .Where(a => a + .GetReferencedAssemblies() + .Any(n => n.FullName == thisAssemblyName)) // Get only assemblies that reference this assembly + .SelectMany(a => a.GetTypes()) + .SelectMany(t => t.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) // Include instance methods in the search so we can show the developer an error instead of silently not adding instance methods to the dictionary + .Where(m => m.GetCustomAttributes(typeof(MessageHandlerAttribute), false).Length > 0) + .ToArray();*/ + } + + /// Builds a dictionary of message IDs and their corresponding message handler methods. + /// The ID of the group of message handler methods to include in the dictionary. + protected abstract void CreateMessageHandlersDictionary(byte messageHandlerGroupId); + + /// Starts tracking how much time has passed. + protected void StartTime() + { + CurrentTime = 0; + time.Restart(); + } + + /// Stops tracking how much time has passed. + protected void StopTime() + { + CurrentTime = 0; + time.Reset(); + eventQueue.Clear(); + } + + /// Beats the heart. + internal abstract void Heartbeat(); + + /// Handles any received messages and invokes any delayed events which need to be invoked. + public virtual void Update() + { + CurrentTime = time.ElapsedMilliseconds; + + while (eventQueue.Count > 0 && eventQueue.PeekPriority() <= CurrentTime) + eventQueue.Dequeue().Invoke(); + } + + /// Sets up a delayed event to be executed after the given time has passed. + /// How long from now to execute the delayed event, in milliseconds. + /// The delayed event to execute later. + internal void ExecuteLater(long inMS, DelayedEvent delayedEvent) + { + eventQueue.Enqueue(delayedEvent, CurrentTime + inMS); + } + + /// Handles all queued messages. + protected void HandleMessages() + { + while (messagesToHandle.Count > 0) + { + MessageToHandle handle = messagesToHandle.Dequeue(); + Handle(handle.Message, handle.Header, handle.FromConnection); + } + } + + /// Handles data received by the transport. + protected void HandleData(object _, DataReceivedEventArgs e) + { + Message message = Message.Create().Init(e.DataBuffer[0], e.Amount, out MessageHeader header); + + if (message.SendMode == MessageSendMode.Notify) + { + if (e.Amount < Message.MinNotifyBytes) + return; + + e.FromConnection.ProcessNotify(e.DataBuffer, e.Amount, message); + } + else if (message.SendMode == MessageSendMode.Unreliable) + { + if (e.Amount > Message.MinUnreliableBytes) + Buffer.BlockCopy(e.DataBuffer, 1, message.Data, 1, e.Amount - 1); + + messagesToHandle.Enqueue(new MessageToHandle(message, header, e.FromConnection)); + e.FromConnection.Metrics.ReceivedUnreliable(e.Amount); + } + else + { + if (e.Amount < Message.MinReliableBytes) + return; + + e.FromConnection.Metrics.ReceivedReliable(e.Amount); + if (e.FromConnection.ShouldHandle(Converter.UShortFromBits(e.DataBuffer, Message.HeaderBits))) + { + Buffer.BlockCopy(e.DataBuffer, 1, message.Data, 1, e.Amount - 1); + messagesToHandle.Enqueue(new MessageToHandle(message, header, e.FromConnection)); + } + else + e.FromConnection.Metrics.ReliableDiscarded++; + } + } + + /// Handles a message. + /// The message to handle. + /// The message's header type. + /// The connection which the message was received on. + protected abstract void Handle(Message message, MessageHeader header, Connection connection); + + /// Disconnects the connection in question. Necessary for connections to be able to initiate disconnections (like in the case of poor connection quality). + /// The connection to disconnect. + /// The reason why the connection is being disconnected. + internal abstract void Disconnect(Connection connection, DisconnectReason reason); + + /// Increases . For use when a new or is started. + protected static void IncreaseActiveCount() + { + ActiveCount++; + } + + /// Decreases . For use when a or is stopped. + protected static void DecreaseActiveCount() + { + ActiveCount--; + if (ActiveCount < 0) + ActiveCount = 0; + } + } + + /// Stores information about a message that needs to be handled. + internal struct MessageToHandle + { + /// The message that needs to be handled. + internal readonly Message Message; + /// The message's header type. + internal readonly MessageHeader Header; + /// The connection on which the message was received. + internal readonly Connection FromConnection; + + /// Handles initialization. + /// The message that needs to be handled. + /// The message's header type. + /// The connection on which the message was received. + public MessageToHandle(Message message, MessageHeader header, Connection fromConnection) + { + Message = message; + Header = header; + FromConnection = fromConnection; + } + } +} diff --git a/Riptide/PendingMessage.cs b/Riptide/PendingMessage.cs new file mode 100644 index 0000000..4828250 --- /dev/null +++ b/Riptide/PendingMessage.cs @@ -0,0 +1,136 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using Riptide.Transports; +using Riptide.Utils; +using System; +using System.Collections.Generic; + +namespace Riptide +{ + /// Represents a currently pending reliably sent message whose delivery has not been acknowledged yet. + internal class PendingMessage + { + /// The time of the latest send attempt. + internal long LastSendTime { get; private set; } + + /// The multiplier used to determine how long to wait before resending a pending message. + private const float RetryTimeMultiplier = 1.2f; + + /// A pool of reusable instances. + private static readonly List pool = new List(); + + /// The to use to send (and resend) the pending message. + private Connection connection; + /// The contents of the message. + private readonly byte[] data; + /// The length in bytes of the message. + private int size; + /// How many send attempts have been made so far. + private byte sendAttempts; + /// Whether the pending message has been cleared or not. + private bool wasCleared; + + /// Handles initial setup. + internal PendingMessage() + { + data = new byte[Message.MaxSize]; + } + + #region Pooling + /// Retrieves a instance and initializes it. + /// The sequence ID of the message. + /// The message that is being sent reliably. + /// The to use to send (and resend) the pending message. + /// An intialized instance. + internal static PendingMessage Create(ushort sequenceId, Message message, Connection connection) + { + PendingMessage pendingMessage = RetrieveFromPool(); + pendingMessage.connection = connection; + + message.SetBits(sequenceId, sizeof(ushort) * Converter.BitsPerByte, Message.HeaderBits); + pendingMessage.size = message.BytesInUse; + Buffer.BlockCopy(message.Data, 0, pendingMessage.data, 0, pendingMessage.size); + + pendingMessage.sendAttempts = 0; + pendingMessage.wasCleared = false; + return pendingMessage; + } + + /// Retrieves a instance from the pool. If none is available, a new instance is created. + /// A instance. + private static PendingMessage RetrieveFromPool() + { + PendingMessage message; + if (pool.Count > 0) + { + message = pool[0]; + pool.RemoveAt(0); + } + else + message = new PendingMessage(); + + return message; + } + + /// Empties the pool. Does not affect instances which are actively pending and therefore not in the pool. + public static void ClearPool() + { + pool.Clear(); + } + + /// Returns the instance to the pool so it can be reused. + private void Release() + { + if (!pool.Contains(this)) + pool.Add(this); // Only add it if it's not already in the list, otherwise this method being called twice in a row for whatever reason could cause *serious* issues + + // TODO: consider doing something to decrease pool capacity if there are far more + // available instance than are needed, which could occur if a large burst of + // messages has to be sent for some reason + } + #endregion + + /// Resends the message. + internal void RetrySend() + { + if (!wasCleared) + { + long time = connection.Peer.CurrentTime; + if (LastSendTime + (connection.SmoothRTT < 0 ? 25 : connection.SmoothRTT / 2) <= time) // Avoid triggering a resend if the latest resend was less than half a RTT ago + TrySend(); + else + connection.Peer.ExecuteLater(connection.SmoothRTT < 0 ? 50 : (long)Math.Max(10, connection.SmoothRTT * RetryTimeMultiplier), new ResendEvent(this, time)); + } + } + + /// Attempts to send the message. + internal void TrySend() + { + if (sendAttempts >= connection.MaxSendAttempts && connection.CanQualityDisconnect) + { + RiptideLogger.Log(LogType.Info, connection.Peer.LogName, $"Could not guarantee delivery of a {(MessageHeader)(data[0] & Message.HeaderBitmask)} message after {sendAttempts} attempts! Disconnecting..."); + connection.Peer.Disconnect(connection, DisconnectReason.PoorConnection); + return; + } + + connection.Send(data, size); + connection.Metrics.SentReliable(size); + + LastSendTime = connection.Peer.CurrentTime; + sendAttempts++; + + connection.Peer.ExecuteLater(connection.SmoothRTT < 0 ? 50 : (long)Math.Max(10, connection.SmoothRTT * RetryTimeMultiplier), new ResendEvent(this, connection.Peer.CurrentTime)); + } + + /// Clears the message. + internal void Clear() + { + connection.Metrics.RollingReliableSends.Add(sendAttempts); + wasCleared = true; + Release(); + } + } +} diff --git a/Riptide/Server.cs b/Riptide/Server.cs new file mode 100644 index 0000000..d8a9fdf --- /dev/null +++ b/Riptide/Server.cs @@ -0,0 +1,598 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using Riptide.Transports; +using Riptide.Utils; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; + +namespace Riptide +{ + /// A server that can accept connections from s. + public class Server : Peer + { + /// Invoked when a client connects. + public event EventHandler ClientConnected; + /// Invoked when a connection fails to be fully established. + public event EventHandler ConnectionFailed; + /// Invoked when a message is received. + public event EventHandler MessageReceived; + /// Invoked when a client disconnects. + public event EventHandler ClientDisconnected; + + /// Whether or not the server is currently running. + public bool IsRunning { get; private set; } + /// The local port that the server is running on. + public ushort Port => transport.Port; + /// Sets the default timeout time for future connections and updates the of all connected clients. + public override int TimeoutTime + { + set + { + defaultTimeout = value; + foreach (Connection connection in clients.Values) + connection.TimeoutTime = defaultTimeout; + } + } + /// The maximum number of concurrent connections. + public ushort MaxClientCount { get; private set; } + /// The number of currently connected clients. + public int ClientCount => clients.Count; + /// An array of all the currently connected clients. + /// The position of each instance in the array does not correspond to that client's numeric ID (except by coincidence). + public Connection[] Clients => clients.Values.ToArray(); + /// Encapsulates a method that handles a message from a client. + /// The numeric ID of the client from whom the message was received. + /// The message that was received. + public delegate void MessageHandler(ushort fromClientId, Message message); + /// Encapsulates a method that determines whether or not to accept a client's connection attempt. + public delegate void ConnectionAttemptHandler(Connection pendingConnection, Message connectMessage); + /// An optional method which determines whether or not to accept a client's connection attempt. + /// The parameter is the pending connection and the parameter is a message containing any additional data the + /// client included with the connection attempt. If you choose to subscribe a method to this delegate, you should use it to call either + /// or . Not doing so will result in the connection hanging until the client times out. + public ConnectionAttemptHandler HandleConnection; + /// Stores which message IDs have auto relaying enabled. Relaying is disabled entirely when this is . + public MessageRelayFilter RelayFilter; + + /// Currently pending connections which are waiting to be accepted or rejected. + private readonly List pendingConnections; + /// Currently connected clients. + private Dictionary clients; + /// Clients that have timed out and need to be removed from . + private readonly List timedOutClients; + /// Methods used to handle messages, accessible by their corresponding message IDs. + private Dictionary messageHandlers; + /// The underlying transport's server that is used for sending and receiving data. + private IServer transport; + /// All currently unused client IDs. + private Queue availableClientIds; + + /// Handles initial setup. + /// The transport to use for sending and receiving data. + /// The name to use when logging messages via . + public Server(IServer transport, string logName = "SERVER") : base(logName) + { + this.transport = transport; + pendingConnections = new List(); + clients = new Dictionary(); + timedOutClients = new List(); + } + /// Handles initial setup using the built-in UDP transport. + /// The name to use when logging messages via . + public Server(string logName = "SERVER") : this(new Transports.Udp.UdpServer(), logName) { } + + /// Stops the server if it's running and swaps out the transport it's using. + /// The new underlying transport server to use for sending and receiving data. + /// This method does not automatically restart the server. To continue accepting connections, must be called again. + public void ChangeTransport(IServer newTransport) + { + Stop(); + transport = newTransport; + } + + /// Starts the server. + /// The local port on which to start the server. + /// The maximum number of concurrent connections to allow. + /// The ID of the group of message handler methods to use when building . + /// Whether or not the server should use the built-in message handler system. + /// Setting to will disable the automatic detection and execution of methods with the , which is beneficial if you prefer to handle messages via the event. + public void Start(ushort port, ushort maxClientCount, byte messageHandlerGroupId = 0, bool useMessageHandlers = true) + { + Stop(); + + IncreaseActiveCount(); + this.useMessageHandlers = useMessageHandlers; + if (useMessageHandlers) + CreateMessageHandlersDictionary(messageHandlerGroupId); + + MaxClientCount = maxClientCount; + clients = new Dictionary(maxClientCount); + InitializeClientIds(); + + SubToTransportEvents(); + transport.Start(port); + + StartTime(); + Heartbeat(); + IsRunning = true; + RiptideLogger.Log(LogType.Info, LogName, $"Started on port {port}."); + } + + /// Subscribes appropriate methods to the transport's events. + private void SubToTransportEvents() + { + transport.Connected += HandleConnectionAttempt; + transport.DataReceived += HandleData; + transport.Disconnected += TransportDisconnected; + } + + /// Unsubscribes methods from all of the transport's events. + private void UnsubFromTransportEvents() + { + transport.Connected -= HandleConnectionAttempt; + transport.DataReceived -= HandleData; + transport.Disconnected -= TransportDisconnected; + } + + /// + protected override void CreateMessageHandlersDictionary(byte messageHandlerGroupId) + { + MethodInfo[] methods = FindMessageHandlers(); + + messageHandlers = new Dictionary(methods.Length); + foreach (MethodInfo method in methods) + { + MessageHandlerAttribute attribute = method.GetCustomAttribute(); + if (attribute.GroupId != messageHandlerGroupId) + continue; + + if (!method.IsStatic) + throw new NonStaticHandlerException(method.DeclaringType, method.Name); + + Delegate serverMessageHandler = Delegate.CreateDelegate(typeof(MessageHandler), method, false); + if (serverMessageHandler != null) + { + // It's a message handler for Server instances + if (messageHandlers.ContainsKey(attribute.MessageId)) + { + MethodInfo otherMethodWithId = messageHandlers[attribute.MessageId].GetMethodInfo(); + throw new DuplicateHandlerException(attribute.MessageId, method, otherMethodWithId); + } + else + messageHandlers.Add(attribute.MessageId, (MessageHandler)serverMessageHandler); + } + else + { + // It's not a message handler for Server instances, but it might be one for Client instances + if (Delegate.CreateDelegate(typeof(Client.MessageHandler), method, false) == null) + throw new InvalidHandlerSignatureException(method.DeclaringType, method.Name); + } + } + } + + /// Handles an incoming connection attempt. + private void HandleConnectionAttempt(object _, ConnectedEventArgs e) + { + e.Connection.Initialize(this, defaultTimeout); + } + + /// Handles a connect message. + /// The client that sent the connect message. + /// The connect message. + private void HandleConnect(Connection connection, Message connectMessage) + { + connection.SetPending(); + + if (HandleConnection == null) + AcceptConnection(connection); + else if (ClientCount < MaxClientCount) + { + if (!clients.ContainsValue(connection) && !pendingConnections.Contains(connection)) + { + pendingConnections.Add(connection); + Send(Message.Create(MessageHeader.Connect), connection); // Inform the client we've received the connection attempt + HandleConnection(connection, connectMessage); // Externally determines whether to accept + } + else + Reject(connection, RejectReason.AlreadyConnected); + } + else + Reject(connection, RejectReason.ServerFull); + } + + /// Accepts the given pending connection. + /// The connection to accept. + public void Accept(Connection connection) + { + if (pendingConnections.Remove(connection)) + AcceptConnection(connection); + else + RiptideLogger.Log(LogType.Warning, LogName, $"Couldn't accept connection from {connection} because no such connection was pending!"); + } + + /// Rejects the given pending connection. + /// The connection to reject. + /// Data that should be sent to the client being rejected. Use to get an empty message instance. + public void Reject(Connection connection, Message message = null) + { + if (message != null && message.ReadBits != 0) + RiptideLogger.Log(LogType.Error, LogName, $"Use the parameterless 'Message.Create()' overload when setting rejection data!"); + + if (pendingConnections.Remove(connection)) + Reject(connection, message == null ? RejectReason.Rejected : RejectReason.Custom, message); + else + RiptideLogger.Log(LogType.Warning, LogName, $"Couldn't reject connection from {connection} because no such connection was pending!"); + } + + /// Accepts the given pending connection. + /// The connection to accept. + private void AcceptConnection(Connection connection) + { + if (ClientCount < MaxClientCount) + { + if (!clients.ContainsValue(connection)) + { + ushort clientId = GetAvailableClientId(); + connection.Id = clientId; + clients.Add(clientId, connection); + connection.ResetTimeout(); + connection.SendWelcome(); + return; + } + else + Reject(connection, RejectReason.AlreadyConnected); + } + else + Reject(connection, RejectReason.ServerFull); + } + + /// Rejects the given pending connection. + /// The connection to reject. + /// The reason why the connection is being rejected. + /// Data that should be sent to the client being rejected. + private void Reject(Connection connection, RejectReason reason, Message rejectMessage = null) + { + if (reason != RejectReason.AlreadyConnected) + { + // Sending a reject message about the client already being connected could theoretically be exploited to obtain information + // on other connected clients, although in practice that seems very unlikely. However, under normal circumstances, clients + // should never actually encounter a scenario where they are "already connected". + + Message message = Message.Create(MessageHeader.Reject); + message.AddByte((byte)reason); + if (reason == RejectReason.Custom) + message.AddMessage(rejectMessage); + + for (int i = 0; i < 3; i++) // Send the rejection message a few times to increase the odds of it arriving + connection.Send(message, false); + + message.Release(); + } + + connection.ResetTimeout(); // Keep the connection alive for a moment so the same client can't immediately attempt to connect again + connection.LocalDisconnect(); + + RiptideLogger.Log(LogType.Info, LogName, $"Rejected connection from {connection}: {Helper.GetReasonString(reason)}."); + } + + /// Checks if clients have timed out. + internal override void Heartbeat() + { + foreach (Connection connection in clients.Values) + if (connection.HasTimedOut) + timedOutClients.Add(connection); + + foreach (Connection connection in pendingConnections) + if (connection.HasConnectAttemptTimedOut) + timedOutClients.Add(connection); + + foreach (Connection connection in timedOutClients) + LocalDisconnect(connection, DisconnectReason.TimedOut); + + timedOutClients.Clear(); + + ExecuteLater(HeartbeatInterval, new HeartbeatEvent(this)); + } + + /// + public override void Update() + { + base.Update(); + transport.Poll(); + HandleMessages(); + } + + /// + protected override void Handle(Message message, MessageHeader header, Connection connection) + { + switch (header) + { + // User messages + case MessageHeader.Unreliable: + case MessageHeader.Reliable: + OnMessageReceived(message, connection); + break; + + // Internal messages + case MessageHeader.Ack: + connection.HandleAck(message); + break; + case MessageHeader.Connect: + HandleConnect(connection, message); + break; + case MessageHeader.Heartbeat: + connection.HandleHeartbeat(message); + break; + case MessageHeader.Disconnect: + LocalDisconnect(connection, DisconnectReason.Disconnected); + break; + case MessageHeader.Welcome: + if (connection.HandleWelcomeResponse(message)) + OnClientConnected(connection); + break; + default: + RiptideLogger.Log(LogType.Warning, LogName, $"Unexpected message header '{header}'! Discarding {message.BytesInUse} bytes received from {connection}."); + break; + } + + message.Release(); + } + + /// Sends a message to a given client. + /// The message to send. + /// The numeric ID of the client to send the message to. + /// Whether or not to return the message to the pool after it is sent. + /// + public void Send(Message message, ushort toClient, bool shouldRelease = true) + { + if (clients.TryGetValue(toClient, out Connection connection)) + Send(message, connection, shouldRelease); + } + /// Sends a message to a given client. + /// The message to send. + /// The client to send the message to. + /// Whether or not to return the message to the pool after it is sent. + /// + public ushort Send(Message message, Connection toClient, bool shouldRelease = true) => toClient.Send(message, shouldRelease); + + /// Sends a message to all connected clients. + /// The message to send. + /// Whether or not to return the message to the pool after it is sent. + /// + public void SendToAll(Message message, bool shouldRelease = true) + { + foreach (Connection client in clients.Values) + client.Send(message, false); + + if (shouldRelease) + message.Release(); + } + /// Sends a message to all connected clients except the given one. + /// The message to send. + /// The numeric ID of the client to not send the message to. + /// Whether or not to return the message to the pool after it is sent. + /// + public void SendToAll(Message message, ushort exceptToClientId, bool shouldRelease = true) + { + foreach (Connection client in clients.Values) + if (client.Id != exceptToClientId) + client.Send(message, false); + + if (shouldRelease) + message.Release(); + } + + /// Retrieves the client with the given ID, if a client with that ID is currently connected. + /// The ID of the client to retrieve. + /// The retrieved client. + /// if a client with the given ID was connected; otherwise . + public bool TryGetClient(ushort id, out Connection client) => clients.TryGetValue(id, out client); + + /// Disconnects a specific client. + /// The numeric ID of the client to disconnect. + /// Data that should be sent to the client being disconnected. Use to get an empty message instance. + public void DisconnectClient(ushort id, Message message = null) + { + if (message != null && message.ReadBits != 0) + RiptideLogger.Log(LogType.Error, LogName, $"Use the parameterless 'Message.Create()' overload when setting disconnection data!"); + + if (clients.TryGetValue(id, out Connection client)) + { + SendDisconnect(client, DisconnectReason.Kicked, message); + LocalDisconnect(client, DisconnectReason.Kicked); + } + else + RiptideLogger.Log(LogType.Warning, LogName, $"Couldn't disconnect client {id} because it wasn't connected!"); + } + + /// Disconnects the given client. + /// The client to disconnect. + /// Data that should be sent to the client being disconnected. Use to get an empty message instance. + public void DisconnectClient(Connection client, Message message = null) + { + if (message != null && message.ReadBits != 0) + RiptideLogger.Log(LogType.Error, LogName, $"Use the parameterless 'Message.Create()' overload when setting disconnection data!"); + + if (clients.ContainsKey(client.Id)) + { + SendDisconnect(client, DisconnectReason.Kicked, message); + LocalDisconnect(client, DisconnectReason.Kicked); + } + else + RiptideLogger.Log(LogType.Warning, LogName, $"Couldn't disconnect client {client.Id} because it wasn't connected!"); + } + + /// + internal override void Disconnect(Connection connection, DisconnectReason reason) + { + if (connection.IsConnected && connection.CanQualityDisconnect) + LocalDisconnect(connection, reason); + } + + /// Cleans up the local side of the given connection. + /// The client to disconnect. + /// The reason why the client is being disconnected. + private void LocalDisconnect(Connection client, DisconnectReason reason) + { + if (client.Peer != this) + return; // Client does not belong to this Server instance + + transport.Close(client); + + if (clients.Remove(client.Id)) + availableClientIds.Enqueue(client.Id); + + if (client.IsConnected) + OnClientDisconnected(client, reason); // Only run if the client was ever actually connected + else if (client.IsPending) + OnConnectionFailed(client); + + client.LocalDisconnect(); + } + + /// What to do when the transport disconnects a client. + private void TransportDisconnected(object sender, Transports.DisconnectedEventArgs e) + { + LocalDisconnect(e.Connection, e.Reason); + } + + /// Stops the server. + public void Stop() + { + if (!IsRunning) + return; + + pendingConnections.Clear(); + byte[] disconnectBytes = { (byte)MessageHeader.Disconnect, (byte)DisconnectReason.ServerStopped }; + foreach (Connection client in clients.Values) + client.Send(disconnectBytes, disconnectBytes.Length); + clients.Clear(); + + transport.Shutdown(); + UnsubFromTransportEvents(); + + DecreaseActiveCount(); + + StopTime(); + IsRunning = false; + RiptideLogger.Log(LogType.Info, LogName, "Server stopped."); + } + + /// Initializes available client IDs. + private void InitializeClientIds() + { + if (MaxClientCount > ushort.MaxValue - 1) + throw new Exception($"A server's max client count may not exceed {ushort.MaxValue - 1}!"); + + availableClientIds = new Queue(MaxClientCount); + for (ushort i = 1; i <= MaxClientCount; i++) + availableClientIds.Enqueue(i); + } + + /// Retrieves an available client ID. + /// The client ID. 0 if none were available. + private ushort GetAvailableClientId() + { + if (availableClientIds.Count > 0) + return availableClientIds.Dequeue(); + + RiptideLogger.Log(LogType.Error, LogName, "No available client IDs, assigned 0!"); + return 0; + } + + #region Messages + /// Sends a disconnect message. + /// The client to send the disconnect message to. + /// Why the client is being disconnected. + /// Optional custom data that should be sent to the client being disconnected. + private void SendDisconnect(Connection client, DisconnectReason reason, Message disconnectMessage) + { + Message message = Message.Create(MessageHeader.Disconnect); + message.AddByte((byte)reason); + + if (reason == DisconnectReason.Kicked && disconnectMessage != null) + message.AddMessage(disconnectMessage); + + Send(message, client); + } + + /// Sends a client connected message. + /// The newly connected client. + private void SendClientConnected(Connection newClient) + { + Message message = Message.Create(MessageHeader.ClientConnected); + message.AddUShort(newClient.Id); + + SendToAll(message, newClient.Id); + } + + /// Sends a client disconnected message. + /// The numeric ID of the client that disconnected. + private void SendClientDisconnected(ushort id) + { + Message message = Message.Create(MessageHeader.ClientDisconnected); + message.AddUShort(id); + + SendToAll(message); + } + #endregion + + #region Events + /// Invokes the event. + /// The newly connected client. + protected virtual void OnClientConnected(Connection client) + { + RiptideLogger.Log(LogType.Info, LogName, $"Client {client.Id} ({client}) connected successfully!"); + SendClientConnected(client); + ClientConnected?.Invoke(this, new ServerConnectedEventArgs(client)); + } + + /// Invokes the event. + /// The connection that failed to be fully established. + protected virtual void OnConnectionFailed(Connection connection) + { + RiptideLogger.Log(LogType.Info, LogName, $"Client {connection} stopped responding before the connection was fully established!"); + ConnectionFailed?.Invoke(this, new ServerConnectionFailedEventArgs(connection)); + } + + /// Invokes the event and initiates handling of the received message. + /// The received message. + /// The client from which the message was received. + protected virtual void OnMessageReceived(Message message, Connection fromConnection) + { + ushort messageId = (ushort)message.GetVarULong(); + if (RelayFilter != null && RelayFilter.ShouldRelay(messageId)) + { + // The message should be automatically relayed to clients instead of being handled on the server + SendToAll(message, fromConnection.Id); + return; + } + + MessageReceived?.Invoke(this, new MessageReceivedEventArgs(fromConnection, messageId, message)); + + if (useMessageHandlers) + { + if (messageHandlers.TryGetValue(messageId, out MessageHandler messageHandler)) + messageHandler(fromConnection.Id, message); + else + RiptideLogger.Log(LogType.Warning, LogName, $"No message handler method found for message ID {messageId}!"); + } + } + + /// Invokes the event. + /// The client that disconnected. + /// The reason for the disconnection. + protected virtual void OnClientDisconnected(Connection connection, DisconnectReason reason) + { + RiptideLogger.Log(LogType.Info, LogName, $"Client {connection.Id} ({connection}) disconnected: {Helper.GetReasonString(reason)}."); + SendClientDisconnected(connection.Id); + ClientDisconnected?.Invoke(this, new ServerDisconnectedEventArgs(connection, reason)); + } + #endregion + } +} diff --git a/Riptide/Transports/EventArgs.cs b/Riptide/Transports/EventArgs.cs new file mode 100644 index 0000000..3088ff5 --- /dev/null +++ b/Riptide/Transports/EventArgs.cs @@ -0,0 +1,61 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +namespace Riptide.Transports +{ + /// Contains event data for when a server's transport successfully establishes a connection to a client. + public class ConnectedEventArgs + { + /// The newly established connection. + public readonly Connection Connection; + + /// Initializes event data. + /// The newly established connection. + public ConnectedEventArgs(Connection connection) + { + Connection = connection; + } + } + + /// Contains event data for when a server's or client's transport receives data. + public class DataReceivedEventArgs + { + /// An array containing the received data. + public readonly byte[] DataBuffer; + /// The number of bytes that were received. + public readonly int Amount; + /// The connection which the data was received from. + public readonly Connection FromConnection; + + /// Initializes event data. + /// An array containing the received data. + /// The number of bytes that were received. + /// The connection which the data was received from. + public DataReceivedEventArgs(byte[] dataBuffer, int amount, Connection fromConnection) + { + DataBuffer = dataBuffer; + Amount = amount; + FromConnection = fromConnection; + } + } + + /// Contains event data for when a server's or client's transport initiates or detects a disconnection. + public class DisconnectedEventArgs + { + /// The closed connection. + public readonly Connection Connection; + /// The reason for the disconnection. + public readonly DisconnectReason Reason; + + /// Initializes event data. + /// The closed connection. + /// The reason for the disconnection. + public DisconnectedEventArgs(Connection connection, DisconnectReason reason) + { + Connection = connection; + Reason = reason; + } + } +} diff --git a/Riptide/Transports/IClient.cs b/Riptide/Transports/IClient.cs new file mode 100644 index 0000000..9d001f6 --- /dev/null +++ b/Riptide/Transports/IClient.cs @@ -0,0 +1,28 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; + +namespace Riptide.Transports +{ + /// Defines methods, properties, and events which every transport's client must implement. + public interface IClient : IPeer + { + /// Invoked when a connection is established at the transport level. + event EventHandler Connected; + /// Invoked when a connection attempt fails at the transport level. + event EventHandler ConnectionFailed; + + /// Starts the transport and attempts to connect to the given host address. + /// The host address to connect to. + /// The pending connection. if an issue occurred. + /// The error message associated with the issue that occurred, if any. + /// if a connection attempt will be made. if an issue occurred (such as being in an invalid format) and a connection attempt will not be made. + bool Connect(string hostAddress, out Connection connection, out string connectError); + + /// Closes the connection to the server. + void Disconnect(); + } +} diff --git a/Riptide/Transports/IPeer.cs b/Riptide/Transports/IPeer.cs new file mode 100644 index 0000000..f6201f0 --- /dev/null +++ b/Riptide/Transports/IPeer.cs @@ -0,0 +1,50 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; + +namespace Riptide.Transports +{ + /// The header type of a . + public enum MessageHeader : byte + { + /// An unreliable user message. + Unreliable, + /// An internal unreliable ack message. + Ack, + /// An internal unreliable connect message. + Connect, + /// An internal unreliable connection rejection message. + Reject, + /// An internal unreliable heartbeat message. + Heartbeat, + /// An internal unreliable disconnect message. + Disconnect, + + /// A notify message. + Notify, + + /// A reliable user message. + Reliable, + /// An internal reliable welcome message. + Welcome, + /// An internal reliable client connected message. + ClientConnected, + /// An internal reliable client disconnected message. + ClientDisconnected, + } + + /// Defines methods, properties, and events which every transport's server and client must implement. + public interface IPeer + { + /// Invoked when data is received by the transport. + event EventHandler DataReceived; + /// Invoked when a disconnection is initiated or detected by the transport. + event EventHandler Disconnected; + + /// Initiates handling of any received messages. + void Poll(); + } +} diff --git a/Riptide/Transports/IServer.cs b/Riptide/Transports/IServer.cs new file mode 100644 index 0000000..d2664d7 --- /dev/null +++ b/Riptide/Transports/IServer.cs @@ -0,0 +1,30 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; + +namespace Riptide.Transports +{ + /// Defines methods, properties, and events which every transport's server must implement. + public interface IServer : IPeer + { + /// Invoked when a connection is established at the transport level. + event EventHandler Connected; + + /// + ushort Port { get; } + + /// Starts the transport and begins listening for incoming connections. + /// The local port on which to listen for connections. + void Start(ushort port); + + /// Closes an active connection. + /// The connection to close. + void Close(Connection connection); + + /// Closes all existing connections and stops listening for new connections. + void Shutdown(); + } +} diff --git a/Riptide/Transports/Tcp/TcpClient.cs b/Riptide/Transports/Tcp/TcpClient.cs new file mode 100644 index 0000000..4d3d587 --- /dev/null +++ b/Riptide/Transports/Tcp/TcpClient.cs @@ -0,0 +1,120 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; +using System.Linq; +using System.Net; +using System.Net.Sockets; + +namespace Riptide.Transports.Tcp +{ + /// A client which can connect to a . + public class TcpClient : TcpPeer, IClient + { + /// + public event EventHandler Connected; + /// + public event EventHandler ConnectionFailed; + /// + public event EventHandler DataReceived; + + /// The connection to the server. + private TcpConnection tcpConnection; + + /// + /// Expects the host address to consist of an IP and port, separated by a colon. For example: 127.0.0.1:7777. + public bool Connect(string hostAddress, out Connection connection, out string connectError) + { + connectError = $"Invalid host address '{hostAddress}'! IP and port should be separated by a colon, for example: '127.0.0.1:7777'."; + if (!ParseHostAddress(hostAddress, out IPAddress ip, out ushort port)) + { + connection = null; + return false; + } + + IPEndPoint remoteEndPoint = new IPEndPoint(ip, port); + socket = new Socket(SocketType.Stream, ProtocolType.Tcp) + { + SendBufferSize = socketBufferSize, + ReceiveBufferSize = socketBufferSize, + NoDelay = true, + }; + + try + { + socket.Connect(remoteEndPoint); // TODO: do something about the fact that this is a blocking call + } + catch (SocketException) + { + // The connection failed, but invoking the transports ConnectionFailed event from + // inside this method will cause problems, so we're just goint to eat the exception, + // call OnConnected(), and let Riptide detect that no connection was established. + } + + connection = tcpConnection = new TcpConnection(socket, remoteEndPoint, this); + OnConnected(); + return true; + } + + /// Parses into and , if possible. + /// The host address to parse. + /// The retrieved IP. + /// The retrieved port. + /// Whether or not was in a valid format. + private bool ParseHostAddress(string hostAddress, out IPAddress ip, out ushort port) + { + string[] ipAndPort = hostAddress.Split(':'); + string ipString = ""; + string portString = ""; + if (ipAndPort.Length > 2) + { + // There was more than one ':' in the host address, might be IPv6 + ipString = string.Join(":", ipAndPort.Take(ipAndPort.Length - 1)); + portString = ipAndPort[ipAndPort.Length - 1]; + } + else if (ipAndPort.Length == 2) + { + // IPv4 + ipString = ipAndPort[0]; + portString = ipAndPort[1]; + } + + port = 0; // Need to make sure a value is assigned in case IP parsing fails + return IPAddress.TryParse(ipString, out ip) && ushort.TryParse(portString, out port); + } + + /// + public void Poll() + { + if (tcpConnection != null) + tcpConnection.Receive(); + } + + /// + public void Disconnect() + { + socket.Close(); + tcpConnection = null; + } + + /// Invokes the event. + protected virtual void OnConnected() + { + Connected?.Invoke(this, EventArgs.Empty); + } + + /// Invokes the event. + protected virtual void OnConnectionFailed() + { + ConnectionFailed?.Invoke(this, EventArgs.Empty); + } + + /// + protected internal override void OnDataReceived(int amount, TcpConnection fromConnection) + { + DataReceived?.Invoke(this, new DataReceivedEventArgs(ReceiveBuffer, amount, fromConnection)); + } + } +} diff --git a/Riptide/Transports/Tcp/TcpConnection.cs b/Riptide/Transports/Tcp/TcpConnection.cs new file mode 100644 index 0000000..77a3d5d --- /dev/null +++ b/Riptide/Transports/Tcp/TcpConnection.cs @@ -0,0 +1,195 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using Riptide.Utils; +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Sockets; + +namespace Riptide.Transports.Tcp +{ + /// Represents a connection to a or . + public class TcpConnection : Connection, IEquatable + { + /// The endpoint representing the other end of the connection. + public readonly IPEndPoint RemoteEndPoint; + + /// Whether or not the server has received a connection attempt from this connection. + internal bool DidReceiveConnect; + + /// The socket to use for sending and receiving. + private readonly Socket socket; + /// The local peer this connection is associated with. + private readonly TcpPeer peer; + /// An array to receive message size values into. + private readonly byte[] sizeBytes = new byte[sizeof(int)]; + /// The size of the next message to be received. + private int nextMessageSize; + + /// Initializes the connection. + /// The socket to use for sending and receiving. + /// The endpoint representing the other end of the connection. + /// The local peer this connection is associated with. + internal TcpConnection(Socket socket, IPEndPoint remoteEndPoint, TcpPeer peer) + { + RemoteEndPoint = remoteEndPoint; + this.socket = socket; + this.peer = peer; + } + + /// + protected internal override void Send(byte[] dataBuffer, int amount) + { + if (amount == 0) + throw new ArgumentOutOfRangeException(nameof(amount), "Sending 0 bytes is not allowed!"); + + try + { + if (socket.Connected) + { + Converter.FromInt(amount, peer.SendBuffer, 0); + Array.Copy(dataBuffer, 0, peer.SendBuffer, sizeof(int), amount); // TODO: consider sending length separately with an extra socket.Send call instead of copying the data an extra time + socket.Send(peer.SendBuffer, amount + sizeof(int), SocketFlags.None); + } + } + catch (SocketException) + { + // May want to consider triggering a disconnect here (perhaps depending on the type + // of SocketException)? Timeout should catch disconnections, but disconnecting + // explicitly might be better... + } + } + + /// Polls the socket and checks if any data was received. + internal void Receive() + { + bool tryReceiveMore = true; + while (tryReceiveMore) + { + int byteCount = 0; + try + { + if (nextMessageSize > 0) + { + // We already have a size value + tryReceiveMore = TryReceiveMessage(out byteCount); + } + else if (socket.Available >= sizeof(int)) + { + // We have enough bytes for a complete size value + socket.Receive(sizeBytes, sizeof(int), SocketFlags.None); + nextMessageSize = Converter.ToInt(sizeBytes, 0); + + if (nextMessageSize > 0) + tryReceiveMore = TryReceiveMessage(out byteCount); + } + else + tryReceiveMore = false; + } + catch (SocketException ex) + { + tryReceiveMore = false; + switch (ex.SocketErrorCode) + { + case SocketError.Interrupted: + case SocketError.NotSocket: + peer.OnDisconnected(this, DisconnectReason.TransportError); + break; + case SocketError.ConnectionReset: + peer.OnDisconnected(this, DisconnectReason.Disconnected); + break; + case SocketError.TimedOut: + peer.OnDisconnected(this, DisconnectReason.TimedOut); + break; + case SocketError.MessageSize: + break; + default: + break; + } + } + catch (ObjectDisposedException) + { + tryReceiveMore = false; + peer.OnDisconnected(this, DisconnectReason.TransportError); + } + catch (NullReferenceException) + { + tryReceiveMore = false; + peer.OnDisconnected(this, DisconnectReason.TransportError); + } + + if (byteCount > 0) + peer.OnDataReceived(byteCount, this); + } + } + + /// Receives a message, if all of its data is ready to be received. + /// How many bytes were received. + /// Whether or not all of the message's data was ready to be received. + private bool TryReceiveMessage(out int receivedByteCount) + { + if (socket.Available >= nextMessageSize) + { + // We have enough bytes to read the complete message + receivedByteCount = socket.Receive(peer.ReceiveBuffer, nextMessageSize, SocketFlags.None); + nextMessageSize = 0; + return true; + } + + receivedByteCount = 0; + return false; + } + + /// Closes the connection. + internal void Close() + { + socket.Close(); + } + + /// + public override string ToString() => RemoteEndPoint.ToStringBasedOnIPFormat(); + + /// + public override bool Equals(object obj) => Equals(obj as TcpConnection); + /// + public bool Equals(TcpConnection other) + { + if (other is null) + return false; + + if (ReferenceEquals(this, other)) + return true; + + return RemoteEndPoint.Equals(other.RemoteEndPoint); + } + + /// + public override int GetHashCode() + { + return -288961498 + EqualityComparer.Default.GetHashCode(RemoteEndPoint); + } + +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member + public static bool operator ==(TcpConnection left, TcpConnection right) +#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member + { + if (left is null) + { + if (right is null) + return true; + + return false; // Only the left side is null + } + + // Equals handles case of null on right side + return left.Equals(right); + } + +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member + public static bool operator !=(TcpConnection left, TcpConnection right) => !(left == right); +#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member + } +} diff --git a/Riptide/Transports/Tcp/TcpPeer.cs b/Riptide/Transports/Tcp/TcpPeer.cs new file mode 100644 index 0000000..87ef391 --- /dev/null +++ b/Riptide/Transports/Tcp/TcpPeer.cs @@ -0,0 +1,56 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; +using System.Net.Sockets; + +namespace Riptide.Transports.Tcp +{ + /// Provides base send & receive functionality for and . + public abstract class TcpPeer + { + /// + public event EventHandler Disconnected; + + /// An array that incoming data is received into. + internal readonly byte[] ReceiveBuffer; + /// An array that outgoing data is sent out of. + internal readonly byte[] SendBuffer; + + /// The default size used for the socket's send and receive buffers. + protected const int DefaultSocketBufferSize = 1024 * 1024; // 1MB + /// The size to use for the socket's send and receive buffers. + protected readonly int socketBufferSize; + /// The main socket, either used for listening for connections or for sending and receiving data. + protected Socket socket; + /// The minimum size that may be used for the socket's send and receive buffers. + private const int MinSocketBufferSize = 256 * 1024; // 256KB + + /// Initializes the transport. + /// How big the socket's send and receive buffers should be. + protected TcpPeer(int socketBufferSize = DefaultSocketBufferSize) + { + if (socketBufferSize < MinSocketBufferSize) + throw new ArgumentOutOfRangeException(nameof(socketBufferSize), $"The minimum socket buffer size is {MinSocketBufferSize}!"); + + this.socketBufferSize = socketBufferSize; + ReceiveBuffer = new byte[Message.MaxSize]; + SendBuffer = new byte[Message.MaxSize + sizeof(int)]; // Need room for the entire message plus the message length (since this is TCP) + } + + /// Handles received data. + /// The number of bytes that were received. + /// The connection from which the data was received. + protected internal abstract void OnDataReceived(int amount, TcpConnection fromConnection); + + /// Invokes the event. + /// The closed connection. + /// The reason for the disconnection. + protected internal virtual void OnDisconnected(Connection connection, DisconnectReason reason) + { + Disconnected?.Invoke(this, new DisconnectedEventArgs(connection, reason)); + } + } +} diff --git a/Riptide/Transports/Tcp/TcpServer.cs b/Riptide/Transports/Tcp/TcpServer.cs new file mode 100644 index 0000000..44d712e --- /dev/null +++ b/Riptide/Transports/Tcp/TcpServer.cs @@ -0,0 +1,157 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Sockets; + +namespace Riptide.Transports.Tcp +{ + /// A server which can accept connections from s. + public class TcpServer : TcpPeer, IServer + { + /// + public event EventHandler Connected; + /// + public event EventHandler DataReceived; + + /// + public ushort Port { get; private set; } + /// The maximum number of pending connections to allow at any given time. + public int MaxPendingConnections { get; private set; } = 5; + + /// Whether or not the server is running. + private bool isRunning = false; + /// The currently open connections, accessible by their endpoints. + private Dictionary connections; + /// Connections that have been closed and need to be removed from . + private readonly List closedConnections = new List(); + /// The IP address to bind the socket to. + private readonly IPAddress listenAddress; + + /// + public TcpServer(int socketBufferSize = DefaultSocketBufferSize) : this(IPAddress.IPv6Any, socketBufferSize) { } + + /// Initializes the transport, binding the socket to a specific IP address. + /// The IP address to bind the socket to. + /// How big the socket's send and receive buffers should be. + public TcpServer(IPAddress listenAddress, int socketBufferSize = DefaultSocketBufferSize) : base(socketBufferSize) + { + this.listenAddress = listenAddress; + } + + /// + public void Start(ushort port) + { + Port = port; + connections = new Dictionary(); + + StartListening(port); + } + + /// Starts listening for connections on the given port. + /// The port to listen on. + private void StartListening(ushort port) + { + if (isRunning) + StopListening(); + + IPEndPoint localEndPoint = new IPEndPoint(listenAddress, port); + socket = new Socket(SocketType.Stream, ProtocolType.Tcp) + { + SendBufferSize = socketBufferSize, + ReceiveBufferSize = socketBufferSize, + NoDelay = true, + }; + socket.Bind(localEndPoint); + socket.Listen(MaxPendingConnections); + + isRunning = true; + } + + /// + public void Poll() + { + if (!isRunning) + return; + + Accept(); + foreach (TcpConnection connection in connections.Values) + connection.Receive(); + + foreach (IPEndPoint endPoint in closedConnections) + connections.Remove(endPoint); + + closedConnections.Clear(); + } + + /// Accepts any pending connections. + private void Accept() + { + if (socket.Poll(0, SelectMode.SelectRead)) + { + Socket acceptedSocket = socket.Accept(); + IPEndPoint fromEndPoint = (IPEndPoint)acceptedSocket.RemoteEndPoint; + if (!connections.ContainsKey(fromEndPoint)) + { + TcpConnection newConnection = new TcpConnection(acceptedSocket, fromEndPoint, this); + connections.Add(fromEndPoint, newConnection); + OnConnected(newConnection); + } + else + acceptedSocket.Close(); + } + } + + /// Stops listening for connections. + private void StopListening() + { + if (!isRunning) + return; + + isRunning = false; + socket.Close(); + } + + /// + public void Close(Connection connection) + { + if (connection is TcpConnection tcpConnection) + { + closedConnections.Add(tcpConnection.RemoteEndPoint); + tcpConnection.Close(); + } + } + + /// + public void Shutdown() + { + StopListening(); + connections.Clear(); + } + + /// Invokes the event. + /// The successfully established connection. + protected virtual void OnConnected(Connection connection) + { + Connected?.Invoke(this, new ConnectedEventArgs(connection)); + } + + /// + protected internal override void OnDataReceived(int amount, TcpConnection fromConnection) + { + if ((MessageHeader)(ReceiveBuffer[0] & Message.HeaderBitmask) == MessageHeader.Connect) + { + if (fromConnection.DidReceiveConnect) + return; + + fromConnection.DidReceiveConnect = true; + } + + DataReceived?.Invoke(this, new DataReceivedEventArgs(ReceiveBuffer, amount, fromConnection)); + } + } +} diff --git a/Riptide/Transports/Udp/UdpClient.cs b/Riptide/Transports/Udp/UdpClient.cs new file mode 100644 index 0000000..09f877e --- /dev/null +++ b/Riptide/Transports/Udp/UdpClient.cs @@ -0,0 +1,111 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; +using System.Linq; +using System.Net; +using System.Net.Sockets; + +namespace Riptide.Transports.Udp +{ + /// A client which can connect to a . + public class UdpClient : UdpPeer, IClient + { + /// + public event EventHandler Connected; + /// + public event EventHandler ConnectionFailed; + /// + public event EventHandler DataReceived; + + /// The connection to the server. + private UdpConnection udpConnection; + + /// + public UdpClient(SocketMode mode = SocketMode.Both, int socketBufferSize = DefaultSocketBufferSize) : base(mode, socketBufferSize) { } + + /// + /// Expects the host address to consist of an IP and port, separated by a colon. For example: 127.0.0.1:7777. + public bool Connect(string hostAddress, out Connection connection, out string connectError) + { + connectError = $"Invalid host address '{hostAddress}'! IP and port should be separated by a colon, for example: '127.0.0.1:7777'."; + if (!ParseHostAddress(hostAddress, out IPAddress ip, out ushort port)) + { + connection = null; + return false; + } + + if ((mode == SocketMode.IPv4Only && ip.AddressFamily == AddressFamily.InterNetworkV6) || (mode == SocketMode.IPv6Only && ip.AddressFamily == AddressFamily.InterNetwork)) + { + // The IP address isn't in an acceptable format for the current socket mode + if (mode == SocketMode.IPv4Only) + connectError = "Connecting to IPv6 addresses is not allowed when running in IPv4 only mode!"; + else + connectError = "Connecting to IPv4 addresses is not allowed when running in IPv6 only mode!"; + + connection = null; + return false; + } + + OpenSocket(); + + connection = udpConnection = new UdpConnection(new IPEndPoint(mode == SocketMode.IPv4Only ? ip : ip.MapToIPv6(), port), this); + OnConnected(); // UDP is connectionless, so from the transport POV everything is immediately ready to send/receive data + return true; + } + + /// Parses into and , if possible. + /// The host address to parse. + /// The retrieved IP. + /// The retrieved port. + /// Whether or not was in a valid format. + private bool ParseHostAddress(string hostAddress, out IPAddress ip, out ushort port) + { + string[] ipAndPort = hostAddress.Split(':'); + string ipString = ""; + string portString = ""; + if (ipAndPort.Length > 2) + { + // There was more than one ':' in the host address, might be IPv6 + ipString = string.Join(":", ipAndPort.Take(ipAndPort.Length - 1)); + portString = ipAndPort[ipAndPort.Length - 1]; + } + else if (ipAndPort.Length == 2) + { + // IPv4 + ipString = ipAndPort[0]; + portString = ipAndPort[1]; + } + + port = 0; // Need to make sure a value is assigned in case IP parsing fails + return IPAddress.TryParse(ipString, out ip) && ushort.TryParse(portString, out port); + } + + /// + public void Disconnect() + { + CloseSocket(); + } + + /// Invokes the event. + protected virtual void OnConnected() + { + Connected?.Invoke(this, EventArgs.Empty); + } + + /// Invokes the event. + protected virtual void OnConnectionFailed() + { + ConnectionFailed?.Invoke(this, EventArgs.Empty); + } + + /// + protected override void OnDataReceived(byte[] dataBuffer, int amount, IPEndPoint fromEndPoint) + { + if (udpConnection.RemoteEndPoint.Equals(fromEndPoint) && !udpConnection.IsNotConnected) + DataReceived?.Invoke(this, new DataReceivedEventArgs(dataBuffer, amount, udpConnection)); + } + } +} diff --git a/Riptide/Transports/Udp/UdpConnection.cs b/Riptide/Transports/Udp/UdpConnection.cs new file mode 100644 index 0000000..39b6758 --- /dev/null +++ b/Riptide/Transports/Udp/UdpConnection.cs @@ -0,0 +1,80 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using Riptide.Utils; +using System; +using System.Collections.Generic; +using System.Net; + +namespace Riptide.Transports.Udp +{ + /// Represents a connection to a or . + public class UdpConnection : Connection, IEquatable + { + /// The endpoint representing the other end of the connection. + public readonly IPEndPoint RemoteEndPoint; + + /// The local peer this connection is associated with. + private readonly UdpPeer peer; + + /// Initializes the connection. + /// The endpoint representing the other end of the connection. + /// The local peer this connection is associated with. + internal UdpConnection(IPEndPoint remoteEndPoint, UdpPeer peer) + { + RemoteEndPoint = remoteEndPoint; + this.peer = peer; + } + + /// + protected internal override void Send(byte[] dataBuffer, int amount) + { + peer.Send(dataBuffer, amount, RemoteEndPoint); + } + + /// + public override string ToString() => RemoteEndPoint.ToStringBasedOnIPFormat(); + + /// + public override bool Equals(object obj) => Equals(obj as UdpConnection); + /// + public bool Equals(UdpConnection other) + { + if (other is null) + return false; + + if (ReferenceEquals(this, other)) + return true; + + return RemoteEndPoint.Equals(other.RemoteEndPoint); + } + + /// + public override int GetHashCode() + { + return -288961498 + EqualityComparer.Default.GetHashCode(RemoteEndPoint); + } + +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member + public static bool operator ==(UdpConnection left, UdpConnection right) +#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member + { + if (left is null) + { + if (right is null) + return true; + + return false; // Only the left side is null + } + + // Equals handles case of null on right side + return left.Equals(right); + } + +#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member + public static bool operator !=(UdpConnection left, UdpConnection right) => !(left == right); +#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member + } +} diff --git a/Riptide/Transports/Udp/UdpPeer.cs b/Riptide/Transports/Udp/UdpPeer.cs new file mode 100644 index 0000000..378bab5 --- /dev/null +++ b/Riptide/Transports/Udp/UdpPeer.cs @@ -0,0 +1,185 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; +using System.Net; +using System.Net.Sockets; + +namespace Riptide.Transports.Udp +{ + /// The kind of socket to create. + public enum SocketMode + { + /// Dual-mode. Works with both IPv4 and IPv6. + Both, + /// IPv4 only mode. + IPv4Only, + /// IPv6 only mode. + IPv6Only + } + + /// Provides base send & receive functionality for and . + public abstract class UdpPeer + { + /// + public event EventHandler Disconnected; + + /// The default size used for the socket's send and receive buffers. + protected const int DefaultSocketBufferSize = 1024 * 1024; // 1MB + /// The minimum size that may be used for the socket's send and receive buffers. + private const int MinSocketBufferSize = 256 * 1024; // 256KB + /// How long to wait for a packet, in microseconds. + private const int ReceivePollingTime = 500000; // 0.5 seconds + + /// Whether to create an IPv4 only, IPv6 only, or dual-mode socket. + protected readonly SocketMode mode; + /// The size to use for the socket's send and receive buffers. + private readonly int socketBufferSize; + /// The array that incoming data is received into. + private readonly byte[] receivedData; + /// The socket to use for sending and receiving. + private Socket socket; + /// Whether or not the transport is running. + private bool isRunning; + /// A reusable endpoint. + private EndPoint remoteEndPoint; + + /// Initializes the transport. + /// Whether to create an IPv4 only, IPv6 only, or dual-mode socket. + /// How big the socket's send and receive buffers should be. + protected UdpPeer(SocketMode mode, int socketBufferSize) + { + if (socketBufferSize < MinSocketBufferSize) + throw new ArgumentOutOfRangeException(nameof(socketBufferSize), $"The minimum socket buffer size is {MinSocketBufferSize}!"); + + this.mode = mode; + this.socketBufferSize = socketBufferSize; + receivedData = new byte[Message.MaxSize]; + } + + /// + public void Poll() + { + Receive(); + } + + /// Opens the socket and starts the transport. + /// The IP address to bind the socket to, if any. + /// The port to bind the socket to. + protected void OpenSocket(IPAddress listenAddress = null, ushort port = 0) + { + if (isRunning) + CloseSocket(); + + if (mode == SocketMode.IPv4Only) + socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + else if (mode == SocketMode.IPv6Only) + socket = new Socket(AddressFamily.InterNetworkV6, SocketType.Dgram, ProtocolType.Udp) { DualMode = false }; + else + socket = new Socket(SocketType.Dgram, ProtocolType.Udp); + + IPAddress any = socket.AddressFamily == AddressFamily.InterNetworkV6 ? IPAddress.IPv6Any : IPAddress.Any; + socket.SendBufferSize = socketBufferSize; + socket.ReceiveBufferSize = socketBufferSize; + socket.Bind(new IPEndPoint(listenAddress == null ? any : listenAddress, port)); + remoteEndPoint = new IPEndPoint(any, 0); + + isRunning = true; + } + + /// Closes the socket and stops the transport. + protected void CloseSocket() + { + if (!isRunning) + return; + + isRunning = false; + socket.Close(); + } + + /// Polls the socket and checks if any data was received. + private void Receive() + { + if (!isRunning) + return; + + bool tryReceiveMore = true; + while (tryReceiveMore) + { + int byteCount = 0; + try + { + if (socket.Available > 0 && socket.Poll(ReceivePollingTime, SelectMode.SelectRead)) + byteCount = socket.ReceiveFrom(receivedData, SocketFlags.None, ref remoteEndPoint); + else + tryReceiveMore = false; + } + catch (SocketException ex) + { + tryReceiveMore = false; + switch (ex.SocketErrorCode) + { + case SocketError.Interrupted: + case SocketError.NotSocket: + isRunning = false; + break; + case SocketError.ConnectionReset: + case SocketError.MessageSize: + case SocketError.TimedOut: + break; + default: + break; + } + } + catch (ObjectDisposedException) + { + tryReceiveMore = false; + isRunning = false; + } + catch (NullReferenceException) + { + tryReceiveMore = false; + isRunning = false; + } + + if (byteCount > 0) + OnDataReceived(receivedData, byteCount, (IPEndPoint)remoteEndPoint); + } + } + + /// Sends data to a given endpoint. + /// The array containing the data. + /// The number of bytes in the array which should be sent. + /// The endpoint to send the data to. + internal void Send(byte[] dataBuffer, int numBytes, IPEndPoint toEndPoint) + { + try + { + if (isRunning) + socket.SendTo(dataBuffer, numBytes, SocketFlags.None, toEndPoint); + } + catch(SocketException) + { + // May want to consider triggering a disconnect here (perhaps depending on the type + // of SocketException)? Timeout should catch disconnections, but disconnecting + // explicitly might be better... + } + } + + /// Handles received data. + /// A byte array containing the received data. + /// The number of bytes in used by the received data. + /// The endpoint from which the data was received. + protected abstract void OnDataReceived(byte[] dataBuffer, int amount, IPEndPoint fromEndPoint); + + /// Invokes the event. + /// The closed connection. + /// The reason for the disconnection. + protected virtual void OnDisconnected(Connection connection, DisconnectReason reason) + { + Disconnected?.Invoke(this, new DisconnectedEventArgs(connection, reason)); + } + } +} diff --git a/Riptide/Transports/Udp/UdpServer.cs b/Riptide/Transports/Udp/UdpServer.cs new file mode 100644 index 0000000..cca2c53 --- /dev/null +++ b/Riptide/Transports/Udp/UdpServer.cs @@ -0,0 +1,93 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; +using System.Collections.Generic; +using System.Net; + +namespace Riptide.Transports.Udp +{ + /// A server which can accept connections from s. + public class UdpServer : UdpPeer, IServer + { + /// + public event EventHandler Connected; + /// + public event EventHandler DataReceived; + + /// + public ushort Port { get; private set; } + + /// The currently open connections, accessible by their endpoints. + private Dictionary connections; + /// The IP address to bind the socket to, if any. + private readonly IPAddress listenAddress; + + /// + public UdpServer(SocketMode mode = SocketMode.Both, int socketBufferSize = DefaultSocketBufferSize) : base(mode, socketBufferSize) { } + + /// Initializes the transport, binding the socket to a specific IP address. + /// The IP address to bind the socket to. + /// How big the socket's send and receive buffers should be. + public UdpServer(IPAddress listenAddress, int socketBufferSize = DefaultSocketBufferSize) : base(SocketMode.Both, socketBufferSize) + { + this.listenAddress = listenAddress; + } + + /// + public void Start(ushort port) + { + Port = port; + connections = new Dictionary(); + + OpenSocket(listenAddress, port); + } + + /// Decides what to do with a connection attempt. + /// The endpoint the connection attempt is coming from. + /// Whether or not the connection attempt was from a new connection. + private bool HandleConnectionAttempt(IPEndPoint fromEndPoint) + { + if (connections.ContainsKey(fromEndPoint)) + return false; + + UdpConnection connection = new UdpConnection(fromEndPoint, this); + connections.Add(fromEndPoint, connection); + OnConnected(connection); + return true; + } + + /// + public void Close(Connection connection) + { + if (connection is UdpConnection udpConnection) + connections.Remove(udpConnection.RemoteEndPoint); + } + + /// + public void Shutdown() + { + CloseSocket(); + connections.Clear(); + } + + /// Invokes the event. + /// The successfully established connection. + protected virtual void OnConnected(Connection connection) + { + Connected?.Invoke(this, new ConnectedEventArgs(connection)); + } + + /// + protected override void OnDataReceived(byte[] dataBuffer, int amount, IPEndPoint fromEndPoint) + { + if ((MessageHeader)(dataBuffer[0] & Message.HeaderBitmask) == MessageHeader.Connect && !HandleConnectionAttempt(fromEndPoint)) + return; + + if (connections.TryGetValue(fromEndPoint, out Connection connection) && !connection.IsNotConnected) + DataReceived?.Invoke(this, new DataReceivedEventArgs(dataBuffer, amount, connection)); + } + } +} diff --git a/Riptide/Utils/Bitfield.cs b/Riptide/Utils/Bitfield.cs new file mode 100644 index 0000000..c991f19 --- /dev/null +++ b/Riptide/Utils/Bitfield.cs @@ -0,0 +1,150 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; +using System.Collections.Generic; + +namespace Riptide.Utils +{ + /// Provides functionality for managing and manipulating a collection of bits. + internal class Bitfield + { + /// The first 8 bits stored in the bitfield. + internal byte First8 => (byte)segments[0]; + /// The first 16 bits stored in the bitfield. + internal ushort First16 => (ushort)segments[0]; + + /// The number of bits which fit into a single segment. + private const int SegmentSize = sizeof(uint) * 8; + /// The segments of the bitfield. + private readonly List segments; + /// Whether or not the bitfield's capacity should dynamically adjust when shifting. + private readonly bool isDynamicCapacity; + /// The current number of bits being stored. + private int count; + /// The current capacity. + private int capacity; + + /// Creates a bitfield. + /// Whether or not the bitfield's capacity should dynamically adjust when shifting. + internal Bitfield(bool isDynamicCapacity = true) + { + segments = new List(4) { 0 }; + capacity = segments.Count * SegmentSize; + this.isDynamicCapacity = isDynamicCapacity; + } + + /// Checks if the bitfield has capacity for the given number of bits. + /// The number of bits for which to check if there is capacity. + /// The number of bits from which there is no capacity for. + /// Whether or not there is sufficient capacity. + internal bool HasCapacityFor(int amount, out int overflow) + { + overflow = count + amount - capacity; + return overflow < 0; + } + + /// Shifts the bitfield by the given amount. + /// How much to shift by. + internal void ShiftBy(int amount) + { + int segmentShift = amount / SegmentSize; // How many WHOLE segments we have to shift by + int bitShift = amount % SegmentSize; // How many bits we have to shift by + + if (!isDynamicCapacity) + count = Math.Min(count + amount, SegmentSize); + else if (!HasCapacityFor(amount, out int _)) + { + Trim(); + count += amount; + + if (count > capacity) + { + int increaseBy = segmentShift + 1; + for (int i = 0; i < increaseBy; i++) + segments.Add(0); + + capacity = segments.Count * SegmentSize; + } + } + else + count += amount; + + int s = segments.Count - 1; + segments[s] <<= bitShift; + s -= 1 + segmentShift; + while (s > -1) + { + ulong shiftedBits = (ulong)segments[s] << bitShift; + segments[s] = (uint)shiftedBits; + + segments[s + 1 + segmentShift] |= (uint)(shiftedBits >> SegmentSize); + s--; + } + } + + /// Checks the last bit in the bitfield, and trims it if it is set to 1. + /// The checked bit's position in the bitfield. + /// Whether or not the checked bit was set. + internal bool CheckAndTrimLast(out int checkedPosition) + { + checkedPosition = count; + uint bitToCheck = (uint)(1 << ((count - 1) % SegmentSize)); + bool isSet = (segments[segments.Count - 1] & bitToCheck) != 0; + count--; + return isSet; + } + + /// Trims all bits from the end of the bitfield until an unset bit is encountered. + private void Trim() + { + while (count > 0 && IsSet(count)) + count--; + } + + /// Sets the given bit to 1. + /// The bit to set. + /// is less than 1. + internal void Set(int bit) + { + if (bit < 1) + throw new ArgumentOutOfRangeException(nameof(bit), $"'{nameof(bit)}' must be greater than zero!"); + + bit--; + int s = bit / SegmentSize; + uint bitToSet = (uint)(1 << (bit % SegmentSize)); + if (s < segments.Count) + segments[s] |= bitToSet; + } + + /// Checks if the given bit is set to 1. + /// The bit to check. + /// Whether or not the bit is set. + /// is less than 1. + internal bool IsSet(int bit) + { + if (bit > count) + return true; + + if (bit < 1) + throw new ArgumentOutOfRangeException(nameof(bit), $"'{nameof(bit)}' must be greater than zero!"); + + bit--; + int s = bit / SegmentSize; + uint bitToCheck = (uint)(1 << (bit % SegmentSize)); + if (s < segments.Count) + return (segments[s] & bitToCheck) != 0; + + return true; + } + + /// Combines this bitfield with the given bits. + /// The bits to OR into the bitfield. + internal void Combine(ushort other) + { + segments[0] |= other; + } + } +} diff --git a/Riptide/Utils/ConnectionMetrics.cs b/Riptide/Utils/ConnectionMetrics.cs new file mode 100644 index 0000000..2986de0 --- /dev/null +++ b/Riptide/Utils/ConnectionMetrics.cs @@ -0,0 +1,202 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +namespace Riptide.Utils +{ + /// Tracks and manages various metrics of a . + public class ConnectionMetrics + { + /// The total number of bytes received across all send modes since the last call, including those in duplicate and, in + /// the case of notify messages, out-of-order packets. Does not include packet header bytes, which may vary by transport. + public int BytesIn => UnreliableBytesIn + NotifyBytesIn + ReliableBytesIn; + /// The total number of bytes sent across all send modes since the last call, including those in automatic resends. + /// Does not include packet header bytes, which may vary by transport. + public int BytesOut => UnreliableBytesOut + NotifyBytesOut + ReliableBytesOut; + /// The total number of messages received across all send modes since the last call, including duplicate and out-of-order notify messages. + public int MessagesIn => UnreliableIn + NotifyIn + ReliableIn; + /// The total number of messages sent across all send modes since the last call, including automatic resends. + public int MessagesOut => UnreliableOut + NotifyOut + ReliableOut; + + /// The total number of bytes received in unreliable messages since the last call. Does not include packet header bytes, which may vary by transport. + public int UnreliableBytesIn { get; private set; } + /// The total number of bytes sent in unreliable messages since the last call. Does not include packet header bytes, which may vary by transport. + public int UnreliableBytesOut { get; internal set; } + /// The number of unreliable messages received since the last call. + public int UnreliableIn { get; private set; } + /// The number of unreliable messages sent since the last call. + public int UnreliableOut { get; internal set; } + + /// The total number of bytes received in notify messages since the last call, including those in duplicate and out-of-order packets. + /// Does not include packet header bytes, which may vary by transport. + public int NotifyBytesIn { get; private set; } + /// The total number of bytes sent in notify messages since the last call. Does not include packet header bytes, which may vary by transport. + public int NotifyBytesOut { get; internal set; } + /// The number of notify messages received since the last call, including duplicate and out-of-order ones. + public int NotifyIn { get; private set; } + /// The number of notify messages sent since the last call. + public int NotifyOut { get; internal set; } + /// The number of duplicate or out-of-order notify messages which were received, but discarded (not handled) since the last call. + public int NotifyDiscarded { get; internal set; } + /// The number of notify messages lost since the last call. + public int NotifyLost { get; private set; } + /// The number of notify messages delivered since the last call. + public int NotifyDelivered { get; private set; } + /// The number of notify messages lost of the last 64 notify messages to be lost or delivered. + public int RollingNotifyLost { get; private set; } + /// The number of notify messages delivered of the last 64 notify messages to be lost or delivered. + public int RollingNotifyDelivered { get; private set; } + /// The loss rate (0-1) among the last 64 notify messages. + public float RollingNotifyLossRate => RollingNotifyLost / 64f; + + /// The total number of bytes received in reliable messages since the last call, including those in duplicate packets. + /// Does not include packet header bytes, which may vary by transport. + public int ReliableBytesIn { get; private set; } + /// The total number of bytes sent in reliable messages since the last call, including those in automatic resends. + /// Does not include packet header bytes, which may vary by transport. + public int ReliableBytesOut { get; internal set; } + /// The number of reliable messages received since the last call, including duplicates. + public int ReliableIn { get; private set; } + /// The number of reliable messages sent since the last call, including automatic resends (each resend adds to this value). + public int ReliableOut { get; internal set; } + /// The number of duplicate reliable messages which were received, but discarded (and not handled) since the last call. + public int ReliableDiscarded { get; internal set; } + /// The number of unique reliable messages sent since the last call. + /// A message only counts towards this the first time it is sent—subsequent resends are not counted. + public int ReliableUniques { get; internal set; } + /// The number of send attempts that were required to deliver recent reliable messages. + public readonly RollingStat RollingReliableSends; + + /// The left-most bit of a , used to store the oldest value in the . + private const ulong ULongLeftBit = 1ul << 63; + /// Which recent notify messages were lost. Each bit corresponds to a message. + private ulong notifyLossTracker; + /// How many of the 's bits are in use. + private int notifyBufferCount; + + /// Initializes metrics. + public ConnectionMetrics() + { + Reset(); + RollingNotifyDelivered = 0; + RollingNotifyLost = 0; + notifyLossTracker = 0; + notifyBufferCount = 0; + RollingReliableSends = new RollingStat(64); + } + + /// Resets all non-rolling metrics to 0. + public void Reset() + { + UnreliableBytesIn = 0; + UnreliableBytesOut = 0; + UnreliableIn = 0; + UnreliableOut = 0; + + NotifyBytesIn = 0; + NotifyBytesOut = 0; + NotifyIn = 0; + NotifyOut = 0; + NotifyDiscarded = 0; + NotifyLost = 0; + NotifyDelivered = 0; + + ReliableBytesIn = 0; + ReliableBytesOut = 0; + ReliableIn = 0; + ReliableOut = 0; + ReliableDiscarded = 0; + ReliableUniques = 0; + } + + /// Updates the metrics associated with receiving an unreliable message. + /// The number of bytes that were received. + internal void ReceivedUnreliable(int byteCount) + { + UnreliableBytesIn += byteCount; + UnreliableIn++; + } + + /// Updates the metrics associated with sending an unreliable message. + /// The number of bytes that were sent. + internal void SentUnreliable(int byteCount) + { + UnreliableBytesOut += byteCount; + UnreliableOut++; + } + + /// Updates the metrics associated with receiving a notify message. + /// The number of bytes that were received. + internal void ReceivedNotify(int byteCount) + { + NotifyBytesIn += byteCount; + NotifyIn++; + } + + /// Updates the metrics associated with sending a notify message. + /// The number of bytes that were sent. + internal void SentNotify(int byteCount) + { + NotifyBytesOut += byteCount; + NotifyOut++; + } + + /// Updates the metrics associated with delivering a notify message. + internal void DeliveredNotify() + { + NotifyDelivered++; + + if (notifyBufferCount < 64) + { + RollingNotifyDelivered++; + notifyBufferCount++; + } + else if ((notifyLossTracker & ULongLeftBit) == 0) + { + // The one being removed from the buffer was not delivered + RollingNotifyDelivered++; + RollingNotifyLost--; + } + + notifyLossTracker <<= 1; + notifyLossTracker |= 1; + } + + /// Updates the metrics associated with losing a notify message. + internal void LostNotify() + { + NotifyLost++; + + if (notifyBufferCount < 64) + { + RollingNotifyLost++; + notifyBufferCount++; + } + else if ((notifyLossTracker & ULongLeftBit) != 0) + { + // The one being removed from the buffer was delivered + RollingNotifyDelivered--; + RollingNotifyLost++; + } + + notifyLossTracker <<= 1; + } + + /// Updates the metrics associated with receiving a reliable message. + /// The number of bytes that were received. + internal void ReceivedReliable(int byteCount) + { + ReliableBytesIn += byteCount; + ReliableIn++; + } + + /// Updates the metrics associated with sending a reliable message. + /// The number of bytes that were sent. + internal void SentReliable(int byteCount) + { + ReliableBytesOut += byteCount; + ReliableOut++; + } + } +} diff --git a/Riptide/Utils/Converter.cs b/Riptide/Utils/Converter.cs new file mode 100644 index 0000000..4bfab74 --- /dev/null +++ b/Riptide/Utils/Converter.cs @@ -0,0 +1,954 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Riptide.Utils +{ + /// Provides functionality for converting bits and bytes to various value types and vice versa. + public class Converter + { + /// The number of bits in a byte. + public const int BitsPerByte = 8; + /// The number of bits in a ulong. + public const int BitsPerULong = sizeof(ulong) * BitsPerByte; + + #region Zig Zag Encoding + /// Zig zag encodes . + /// The value to encode. + /// The zig zag-encoded value. + /// Zig zag encoding allows small negative numbers to be represented as small positive numbers. All positive numbers are doubled and become even numbers, + /// while all negative numbers become positive odd numbers. In contrast, simply casting a negative value to its unsigned counterpart would result in a large positive + /// number which uses the high bit, rendering compression via and ineffective. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int ZigZagEncode(int value) + { + return (value >> 31) ^ (value << 1); + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long ZigZagEncode(long value) + { + return (value >> 63) ^ (value << 1); + } + + /// Zig zag decodes . + /// The value to decode. + /// The zig zag-decoded value. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int ZigZagDecode(int value) + { + return (value >> 1) ^ -(value & 1); + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long ZigZagDecode(long value) + { + return (value >> 1) ^ -(value & 1); + } + #endregion + + #region Bits + /// Takes bits from and writes them into , starting at . + /// The bitfield from which to write the bits into the array. + /// The number of bits to write. + /// The array to write the bits into. + /// The bit position in the array at which to start writing. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void SetBits(byte bitfield, int amount, byte[] array, int startBit) + { + byte mask = (byte)((1 << amount) - 1); + bitfield &= mask; // Discard any bits that are set beyond the ones we're setting + int inverseMask = ~mask; + int pos = startBit / BitsPerByte; + int bit = startBit % BitsPerByte; + if (bit == 0) + array[pos] = (byte)(bitfield | (array[pos] & inverseMask)); + else + { + array[pos ] = (byte)((bitfield << bit) | (array[pos] & ~(mask << bit))); + array[pos + 1] = (byte)((bitfield >> (8 - bit)) | (array[pos + 1] & (inverseMask >> (8 - bit)))); + } + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void SetBits(ushort bitfield, int amount, byte[] array, int startBit) + { + ushort mask = (ushort)((1 << amount) - 1); + bitfield &= mask; // Discard any bits that are set beyond the ones we're setting + int inverseMask = ~mask; + int pos = startBit / BitsPerByte; + int bit = startBit % BitsPerByte; + if (bit == 0) + { + array[pos ] = (byte)(bitfield | (array[pos] & inverseMask)); + array[pos + 1] = (byte)((bitfield >> 8) | (array[pos + 1] & (inverseMask >> 8))); + } + else + { + array[pos ] = (byte)((bitfield << bit) | (array[pos] & ~(mask << bit))); + bitfield >>= 8 - bit; + inverseMask >>= 8 - bit; + array[pos + 1] = (byte)(bitfield | (array[pos + 1] & inverseMask)); + array[pos + 2] = (byte)((bitfield >> 8) | (array[pos + 2] & (inverseMask >> 8))); + } + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void SetBits(uint bitfield, int amount, byte[] array, int startBit) + { + uint mask = (1u << (amount - 1) << 1) - 1; // Perform 2 shifts, doing it in 1 doesn't cause the value to wrap properly + bitfield &= mask; // Discard any bits that are set beyond the ones we're setting + uint inverseMask = ~mask; + int pos = startBit / BitsPerByte; + int bit = startBit % BitsPerByte; + if (bit == 0) + { + array[pos ] = (byte)(bitfield | (array[pos] & inverseMask)); + array[pos + 1] = (byte)((bitfield >> 8) | (array[pos + 1] & (inverseMask >> 8))); + array[pos + 2] = (byte)((bitfield >> 16) | (array[pos + 2] & (inverseMask >> 16))); + array[pos + 3] = (byte)((bitfield >> 24) | (array[pos + 3] & (inverseMask >> 24))); + } + else + { + array[pos ] = (byte)((bitfield << bit) | (array[pos] & ~(mask << bit))); + bitfield >>= 8 - bit; + inverseMask >>= 8 - bit; + array[pos + 1] = (byte)(bitfield | (array[pos + 1] & inverseMask)); + array[pos + 2] = (byte)((bitfield >> 8) | (array[pos + 2] & (inverseMask >> 8))); + array[pos + 3] = (byte)((bitfield >> 16) | (array[pos + 3] & (inverseMask >> 16))); + array[pos + 4] = (byte)((bitfield >> 24) | (array[pos + 4] & ~(mask >> (32 - bit)))); // This one can't use inverseMask because it would have incorrectly zeroed bits + } + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void SetBits(ulong bitfield, int amount, byte[] array, int startBit) + { + ulong mask = (1ul << (amount - 1) << 1) - 1; // Perform 2 shifts, doing it in 1 doesn't cause the value to wrap properly + bitfield &= mask; // Discard any bits that are set beyond the ones we're setting + ulong inverseMask = ~mask; + int pos = startBit / BitsPerByte; + int bit = startBit % BitsPerByte; + if (bit == 0) + { + array[pos ] = (byte)(bitfield | (array[pos] & inverseMask)); + array[pos + 1] = (byte)((bitfield >> 8) | (array[pos + 1] & (inverseMask >> 8))); + array[pos + 2] = (byte)((bitfield >> 16) | (array[pos + 2] & (inverseMask >> 16))); + array[pos + 3] = (byte)((bitfield >> 24) | (array[pos + 3] & (inverseMask >> 24))); + array[pos + 4] = (byte)((bitfield >> 32) | (array[pos + 4] & (inverseMask >> 32))); + array[pos + 5] = (byte)((bitfield >> 40) | (array[pos + 5] & (inverseMask >> 40))); + array[pos + 6] = (byte)((bitfield >> 48) | (array[pos + 6] & (inverseMask >> 48))); + array[pos + 7] = (byte)((bitfield >> 56) | (array[pos + 7] & (inverseMask >> 56))); + } + else + { + array[pos ] = (byte)((bitfield << bit) | (array[pos] & ~(mask << bit))); + bitfield >>= 8 - bit; + inverseMask >>= 8 - bit; + array[pos + 1] = (byte)(bitfield | (array[pos + 1] & inverseMask)); + array[pos + 2] = (byte)((bitfield >> 8) | (array[pos + 2] & (inverseMask >> 8))); + array[pos + 3] = (byte)((bitfield >> 16) | (array[pos + 3] & (inverseMask >> 16))); + array[pos + 4] = (byte)((bitfield >> 24) | (array[pos + 4] & (inverseMask >> 24))); + array[pos + 5] = (byte)((bitfield >> 32) | (array[pos + 5] & (inverseMask >> 32))); + array[pos + 6] = (byte)((bitfield >> 40) | (array[pos + 6] & (inverseMask >> 40))); + array[pos + 7] = (byte)((bitfield >> 48) | (array[pos + 7] & (inverseMask >> 48))); + array[pos + 8] = (byte)((bitfield >> 56) | (array[pos + 8] & ~(mask >> (64 - bit)))); // This one can't use inverseMask because it would have incorrectly zeroed bits + } + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void SetBits(ulong bitfield, int amount, ulong[] array, int startBit) + { + ulong mask = (1ul << (amount - 1) << 1) - 1; // Perform 2 shifts, doing it in 1 doesn't cause the value to wrap properly + bitfield &= mask; // Discard any bits that are set beyond the ones we're setting + int pos = startBit / BitsPerULong; + int bit = startBit % BitsPerULong; + if (bit == 0) + array[pos] = bitfield | array[pos] & ~mask; + else + { + array[pos] = (bitfield << bit) | (array[pos] & ~(mask << bit)); + if (bit + amount >= BitsPerULong) + array[pos + 1] = (bitfield >> (64 - bit)) | (array[pos + 1] & ~(mask >> (64 - bit))); + } + } + + /// Starting at , reads bits from into . + /// The number of bits to read. + /// The array to read the bits from. + /// The bit position in the array at which to start reading. + /// The bitfield into which to write the bits from the array. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void GetBits(int amount, byte[] array, int startBit, out byte bitfield) + { + bitfield = ByteFromBits(array, startBit); + bitfield &= (byte)((1 << amount) - 1); // Discard any bits that are set beyond the ones we're reading + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void GetBits(int amount, byte[] array, int startBit, out ushort bitfield) + { + bitfield = UShortFromBits(array, startBit); + bitfield &= (ushort)((1 << amount) - 1); // Discard any bits that are set beyond the ones we're reading + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void GetBits(int amount, byte[] array, int startBit, out uint bitfield) + { + bitfield = UIntFromBits(array, startBit); + bitfield &= (1u << (amount - 1) << 1) - 1; // Discard any bits that are set beyond the ones we're reading + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void GetBits(int amount, byte[] array, int startBit, out ulong bitfield) + { + bitfield = ULongFromBits(array, startBit); + bitfield &= (1ul << (amount - 1) << 1) - 1; // Discard any bits that are set beyond the ones we're reading + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void GetBits(int amount, ulong[] array, int startBit, out byte bitfield) + { + bitfield = ByteFromBits(array, startBit); + bitfield &= (byte)((1 << amount) - 1); // Discard any bits that are set beyond the ones we're reading + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void GetBits(int amount, ulong[] array, int startBit, out ushort bitfield) + { + bitfield = UShortFromBits(array, startBit); + bitfield &= (ushort)((1 << amount) - 1); // Discard any bits that are set beyond the ones we're reading + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void GetBits(int amount, ulong[] array, int startBit, out uint bitfield) + { + bitfield = UIntFromBits(array, startBit); + bitfield &= (1u << (amount - 1) << 1) - 1; // Discard any bits that are set beyond the ones we're reading + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void GetBits(int amount, ulong[] array, int startBit, out ulong bitfield) + { + bitfield = ULongFromBits(array, startBit); + bitfield &= (1ul << (amount - 1) << 1) - 1; // Discard any bits that are set beyond the ones we're reading + } + #endregion + + #region Byte/SByte + /// Converts to 8 bits and writes them into at . + /// The to convert. + /// The array to write the bits into. + /// The position in the array at which to write the bits. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void SByteToBits(sbyte value, byte[] array, int startBit) => ByteToBits((byte)value, array, startBit); + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void SByteToBits(sbyte value, ulong[] array, int startBit) => ByteToBits((byte)value, array, startBit); + /// Converts to 8 bits and writes them into at . + /// The to convert. + /// The array to write the bits into. + /// The position in the array at which to write the bits. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void ByteToBits(byte value, byte[] array, int startBit) + { + int pos = startBit / BitsPerByte; + int bit = startBit % BitsPerByte; + if (bit == 0) + array[pos] = value; + else + { + array[pos ] |= (byte)(value << bit); + array[pos + 1] = (byte)(value >> (8 - bit)); + } + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void ByteToBits(byte value, ulong[] array, int startBit) => ToBits(value, BitsPerByte, array, startBit); + + /// Converts the 8 bits at in to an . + /// The array to convert the bits from. + /// The position in the array from which to read the bits. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static sbyte SByteFromBits(byte[] array, int startBit) => (sbyte)ByteFromBits(array, startBit); + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static sbyte SByteFromBits(ulong[] array, int startBit) => (sbyte)ByteFromBits(array, startBit); + /// Converts the 8 bits at in to a . + /// The array to convert the bits from. + /// The position in the array from which to read the bits. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static byte ByteFromBits(byte[] array, int startBit) + { + int pos = startBit / BitsPerByte; + int bit = startBit % BitsPerByte; + byte value = array[pos]; + if (bit == 0) + return value; + + value >>= bit; + return (byte)(value | (array[pos + 1] << (8 - bit))); + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static byte ByteFromBits(ulong[] array, int startBit) => (byte)FromBits(BitsPerByte, array, startBit); + #endregion + + #region Bool + /// Converts to a bit and writes it into at . + /// The to convert. + /// The array to write the bit into. + /// The position in the array at which to write the bit. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void BoolToBit(bool value, byte[] array, int startBit) + { + int pos = startBit / BitsPerByte; + int bit = startBit % BitsPerByte; + if (bit == 0) + array[pos] = 0; + + if (value) + array[pos] |= (byte)(1 << bit); + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void BoolToBit(bool value, ulong[] array, int startBit) + { + int pos = startBit / BitsPerULong; + int bit = startBit % BitsPerULong; + if (bit == 0) + array[pos] = 0; + + if (value) + array[pos] |= 1ul << bit; + } + + /// Converts the bit at in to a . + /// The array to convert the bit from. + /// The position in the array from which to read the bit. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool BoolFromBit(byte[] array, int startBit) + { + int pos = startBit / BitsPerByte; + int bit = startBit % BitsPerByte; + return (array[pos] & (1 << bit)) != 0; + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool BoolFromBit(ulong[] array, int startBit) + { + int pos = startBit / BitsPerULong; + int bit = startBit % BitsPerULong; + return (array[pos] & (1ul << bit)) != 0; + } + #endregion + + #region Short/UShort + /// Converts a given to bytes and writes them into the given array. + /// The to convert. + /// The array to write the bytes into. + /// The position in the array at which to write the bytes. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void FromShort(short value, byte[] array, int startIndex) => FromUShort((ushort)value, array, startIndex); + /// Converts a given to bytes and writes them into the given array. + /// The to convert. + /// The array to write the bytes into. + /// The position in the array at which to write the bytes. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void FromUShort(ushort value, byte[] array, int startIndex) + { +#if BIG_ENDIAN + array[startIndex + 1] = (byte)value; + array[startIndex ] = (byte)(value >> 8); +#else + array[startIndex ] = (byte)value; + array[startIndex + 1] = (byte)(value >> 8); +#endif + } + + /// Converts the 2 bytes in the array at to a . + /// The array to read the bytes from. + /// The position in the array at which to read the bytes. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static short ToShort(byte[] array, int startIndex) => (short)ToUShort(array, startIndex); + /// Converts the 2 bytes in the array at to a . + /// The array to read the bytes from. + /// The position in the array at which to read the bytes. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ushort ToUShort(byte[] array, int startIndex) + { +#if BIG_ENDIAN + return (ushort)(array[startIndex + 1] | (array[startIndex ] << 8)); +#else + return (ushort)(array[startIndex ] | (array[startIndex + 1] << 8)); +#endif + } + + /// Converts to 16 bits and writes them into at . + /// The to convert. + /// The array to write the bits into. + /// The position in the array at which to write the bits. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void ShortToBits(short value, byte[] array, int startBit) => UShortToBits((ushort)value, array, startBit); + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void ShortToBits(short value, ulong[] array, int startBit) => UShortToBits((ushort)value, array, startBit); + /// Converts to 16 bits and writes them into at . + /// The to convert. + /// The array to write the bits into. + /// The position in the array at which to write the bits. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void UShortToBits(ushort value, byte[] array, int startBit) + { + int pos = startBit / BitsPerByte; + int bit = startBit % BitsPerByte; + if (bit == 0) + { + array[pos] = (byte)value; + array[pos + 1] = (byte)(value >> 8); + } + else + { + array[pos ] |= (byte)(value << bit); + value >>= 8 - bit; + array[pos + 1] = (byte)value; + array[pos + 2] = (byte)(value >> 8); + } + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void UShortToBits(ushort value, ulong[] array, int startBit) => ToBits(value, sizeof(ushort) * BitsPerByte, array, startBit); + + /// Converts the 16 bits at in to a . + /// The array to convert the bits from. + /// The position in the array from which to read the bits. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static short ShortFromBits(byte[] array, int startBit) => (short)UShortFromBits(array, startBit); + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static short ShortFromBits(ulong[] array, int startBit) => (short)UShortFromBits(array, startBit); + /// Converts the 16 bits at in to a . + /// The array to convert the bits from. + /// The position in the array from which to read the bits. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ushort UShortFromBits(byte[] array, int startBit) + { + int pos = startBit / BitsPerByte; + int bit = startBit % BitsPerByte; + ushort value = (ushort)(array[pos] | (array[pos + 1] << 8)); + if (bit == 0) + return value; + + value >>= bit; + return (ushort)(value | (array[pos + 2] << (16 - bit))); + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ushort UShortFromBits(ulong[] array, int startBit) => (ushort)FromBits(sizeof(ushort) * BitsPerByte, array, startBit); + #endregion + + #region Int/UInt + /// Converts a given to bytes and writes them into the given array. + /// The to convert. + /// The array to write the bytes into. + /// The position in the array at which to write the bytes. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void FromInt(int value, byte[] array, int startIndex) => FromUInt((uint)value, array, startIndex); + /// Converts a given to bytes and writes them into the given array. + /// The to convert. + /// The array to write the bytes into. + /// The position in the array at which to write the bytes. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void FromUInt(uint value, byte[] array, int startIndex) + { +#if BIG_ENDIAN + array[startIndex + 3] = (byte)value; + array[startIndex + 2] = (byte)(value >> 8); + array[startIndex + 1] = (byte)(value >> 16); + array[startIndex ] = (byte)(value >> 24); +#else + array[startIndex ] = (byte)value; + array[startIndex + 1] = (byte)(value >> 8); + array[startIndex + 2] = (byte)(value >> 16); + array[startIndex + 3] = (byte)(value >> 24); +#endif + } + + /// Converts the 4 bytes in the array at to a . + /// The array to read the bytes from. + /// The position in the array at which to read the bytes. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int ToInt(byte[] array, int startIndex) => (int)ToUInt(array, startIndex); + /// Converts the 4 bytes in the array at to a . + /// The array to read the bytes from. + /// The position in the array at which to read the bytes. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint ToUInt(byte[] array, int startIndex) + { +#if BIG_ENDIAN + return (uint)(array[startIndex + 3] | (array[startIndex + 2] << 8) | (array[startIndex + 1] << 16) | (array[startIndex ] << 24)); +#else + return (uint)(array[startIndex ] | (array[startIndex + 1] << 8) | (array[startIndex + 2] << 16) | (array[startIndex + 3] << 24)); +#endif + } + + /// Converts to 32 bits and writes them into at . + /// The to convert. + /// The array to write the bits into. + /// The position in the array at which to write the bits. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void IntToBits(int value, byte[] array, int startBit) => UIntToBits((uint)value, array, startBit); + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void IntToBits(int value, ulong[] array, int startBit) => UIntToBits((uint)value, array, startBit); + /// Converts to 32 bits and writes them into at . + /// The to convert. + /// The array to write the bits into. + /// The position in the array at which to write the bits. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void UIntToBits(uint value, byte[] array, int startBit) + { + int pos = startBit / BitsPerByte; + int bit = startBit % BitsPerByte; + if (bit == 0) + { + array[pos ] = (byte)value; + array[pos + 1] = (byte)(value >> 8); + array[pos + 2] = (byte)(value >> 16); + array[pos + 3] = (byte)(value >> 24); + } + else + { + array[pos ] |= (byte)(value << bit); + value >>= 8 - bit; + array[pos + 1] = (byte)value; + array[pos + 2] = (byte)(value >> 8); + array[pos + 3] = (byte)(value >> 16); + array[pos + 4] = (byte)(value >> 24); + } + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void UIntToBits(uint value, ulong[] array, int startBit) => ToBits(value, sizeof(uint) * BitsPerByte, array, startBit); + + /// Converts the 32 bits at in to an . + /// The array to convert the bits from. + /// The position in the array from which to read the bits. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IntFromBits(byte[] array, int startBit) => (int)UIntFromBits(array, startBit); + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IntFromBits(ulong[] array, int startBit) => (int)UIntFromBits(array, startBit); + /// Converts the 32 bits at in to a . + /// The array to convert the bits from. + /// The position in the array from which to read the bits. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint UIntFromBits(byte[] array, int startBit) + { + int pos = startBit / BitsPerByte; + int bit = startBit % BitsPerByte; + uint value = (uint)(array[pos] | (array[pos + 1] << 8) | (array[pos + 2] << 16) | (array[pos + 3] << 24)); + if (bit == 0) + return value; + + value >>= bit; + return value | (uint)(array[pos + 4] << (32 - bit)); + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint UIntFromBits(ulong[] array, int startBit) => (uint)FromBits(sizeof(uint) * BitsPerByte, array, startBit); + #endregion + + #region Long/ULong + /// Converts a given to bytes and writes them into the given array. + /// The to convert. + /// The array to write the bytes into. + /// The position in the array at which to write the bytes. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void FromLong(long value, byte[] array, int startIndex) => FromULong((ulong)value, array, startIndex); + /// Converts a given to bytes and writes them into the given array. + /// The to convert. + /// The array to write the bytes into. + /// The position in the array at which to write the bytes. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void FromULong(ulong value, byte[] array, int startIndex) + { +#if BIG_ENDIAN + array[startIndex + 7] = (byte)value; + array[startIndex + 6] = (byte)(value >> 8); + array[startIndex + 5] = (byte)(value >> 16); + array[startIndex + 4] = (byte)(value >> 24); + array[startIndex + 3] = (byte)(value >> 32); + array[startIndex + 2] = (byte)(value >> 40); + array[startIndex + 1] = (byte)(value >> 48); + array[startIndex ] = (byte)(value >> 56); +#else + array[startIndex ] = (byte)value; + array[startIndex + 1] = (byte)(value >> 8); + array[startIndex + 2] = (byte)(value >> 16); + array[startIndex + 3] = (byte)(value >> 24); + array[startIndex + 4] = (byte)(value >> 32); + array[startIndex + 5] = (byte)(value >> 40); + array[startIndex + 6] = (byte)(value >> 48); + array[startIndex + 7] = (byte)(value >> 56); +#endif + } + + /// Converts the 8 bytes in the array at to a . + /// The array to read the bytes from. + /// The position in the array at which to read the bytes. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long ToLong(byte[] array, int startIndex) + { +#if BIG_ENDIAN + Array.Reverse(array, startIndex, longLength); +#endif + return BitConverter.ToInt64(array, startIndex); + } + /// Converts the 8 bytes in the array at to a . + /// The array to read the bytes from. + /// The position in the array at which to read the bytes. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong ToULong(byte[] array, int startIndex) + { +#if BIG_ENDIAN + Array.Reverse(array, startIndex, ulongLength); +#endif + return BitConverter.ToUInt64(array, startIndex); + } + + /// Converts to 64 bits and writes them into at . + /// The to convert. + /// The array to write the bits into. + /// The position in the array at which to write the bits. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void LongToBits(long value, byte[] array, int startBit) => ULongToBits((ulong)value, array, startBit); + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void LongToBits(long value, ulong[] array, int startBit) => ULongToBits((ulong)value, array, startBit); + /// Converts to 64 bits and writes them into at . + /// The to convert. + /// The array to write the bits into. + /// The position in the array at which to write the bits. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void ULongToBits(ulong value, byte[] array, int startBit) + { + int pos = startBit / BitsPerByte; + int bit = startBit % BitsPerByte; + if (bit == 0) + { + array[pos ] = (byte)value; + array[pos + 1] = (byte)(value >> 8); + array[pos + 2] = (byte)(value >> 16); + array[pos + 3] = (byte)(value >> 24); + array[pos + 4] = (byte)(value >> 32); + array[pos + 5] = (byte)(value >> 40); + array[pos + 6] = (byte)(value >> 48); + array[pos + 7] = (byte)(value >> 56); + } + else + { + array[pos ] |= (byte)(value << bit); + value >>= 8 - bit; + array[pos + 1] = (byte)value; + array[pos + 2] = (byte)(value >> 8); + array[pos + 3] = (byte)(value >> 16); + array[pos + 4] = (byte)(value >> 24); + array[pos + 5] = (byte)(value >> 32); + array[pos + 6] = (byte)(value >> 40); + array[pos + 7] = (byte)(value >> 48); + array[pos + 8] = (byte)(value >> 56); + } + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void ULongToBits(ulong value, ulong[] array, int startBit) + { + int pos = startBit / BitsPerULong; + int bit = startBit % BitsPerULong; + if (bit == 0) + array[pos] = value; + else + { + array[pos ] |= value << bit; + array[pos + 1] = value >> (BitsPerULong - bit); + } + } + + /// Converts the 64 bits at in to a . + /// The array to convert the bits from. + /// The position in the array from which to read the bits. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long LongFromBits(byte[] array, int startBit) => (long)ULongFromBits(array, startBit); + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long LongFromBits(ulong[] array, int startBit) => (long)ULongFromBits(array, startBit); + /// Converts the 64 bits at in to a . + /// The array to convert the bits from. + /// The position in the array from which to read the bits. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong ULongFromBits(byte[] array, int startBit) + { + int pos = startBit / BitsPerByte; + int bit = startBit % BitsPerByte; + ulong value = BitConverter.ToUInt64(array, pos); + if (bit == 0) + return value; + + value >>= bit; + return value | ((ulong)array[pos + 8] << (64 - bit)); + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong ULongFromBits(ulong[] array, int startBit) + { + int pos = startBit / BitsPerULong; + int bit = startBit % BitsPerULong; + ulong value = array[pos]; + if (bit == 0) + return value; + + value >>= bit; + return value | (array[pos + 1] << (BitsPerULong - bit)); + } + + /// Converts to bits and writes them into at . + /// Meant for values which fit into a , not for s themselves. + /// The value to convert. + /// The size in bits of the value being converted. + /// The array to write the bits into. + /// The position in the array at which to write the bits. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void ToBits(ulong value, int valueSize, ulong[] array, int startBit) + { + int pos = startBit / BitsPerULong; + int bit = startBit % BitsPerULong; + if (bit == 0) + array[pos] = value; + else if (bit + valueSize < BitsPerULong) + array[pos] |= value << bit; + else + { + array[pos] |= value << bit; + array[pos + 1] = value >> (BitsPerULong - bit); + } + } + /// Converts the bits at in to a . + /// Meant for values which fit into a , not for s themselves. + /// The size in bits of the value being converted. + /// The array to convert the bits from. + /// The position in the array from which to read the bits. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ulong FromBits(int valueSize, ulong[] array, int startBit) + { + int pos = startBit / BitsPerULong; + int bit = startBit % BitsPerULong; + ulong value = array[pos]; + if (bit == 0) + return value; + + value >>= bit; + if (bit + valueSize < BitsPerULong) + return value; + + return value | (array[pos + 1] << (BitsPerULong - bit)); + } + #endregion + + #region Float + /// Converts a given to bytes and writes them into the given array. + /// The to convert. + /// The array to write the bytes into. + /// The position in the array at which to write the bytes. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void FromFloat(float value, byte[] array, int startIndex) + { + FloatConverter converter = new FloatConverter { FloatValue = value }; +#if BIG_ENDIAN + array[startIndex + 3] = converter.Byte0; + array[startIndex + 2] = converter.Byte1; + array[startIndex + 1] = converter.Byte2; + array[startIndex ] = converter.Byte3; +#else + array[startIndex ] = converter.Byte0; + array[startIndex + 1] = converter.Byte1; + array[startIndex + 2] = converter.Byte2; + array[startIndex + 3] = converter.Byte3; +#endif + } + + /// Converts the 4 bytes in the array at to a . + /// The array to read the bytes from. + /// The position in the array at which to read the bytes. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float ToFloat(byte[] array, int startIndex) + { +#if BIG_ENDIAN + return new FloatConverter { Byte3 = array[startIndex], Byte2 = array[startIndex + 1], Byte1 = array[startIndex + 2], Byte0 = array[startIndex + 3] }.FloatValue; +#else + return new FloatConverter { Byte0 = array[startIndex], Byte1 = array[startIndex + 1], Byte2 = array[startIndex + 2], Byte3 = array[startIndex + 3] }.FloatValue; +#endif + } + + /// Converts to 32 bits and writes them into at . + /// The to convert. + /// The array to write the bits into. + /// The position in the array at which to write the bits. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void FloatToBits(float value, byte[] array, int startBit) + { + UIntToBits(new FloatConverter { FloatValue = value }.UIntValue, array, startBit); + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void FloatToBits(float value, ulong[] array, int startBit) + { + UIntToBits(new FloatConverter { FloatValue = value }.UIntValue, array, startBit); + } + + /// Converts the 32 bits at in to a . + /// The array to convert the bits from. + /// The position in the array from which to read the bits. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float FloatFromBits(byte[] array, int startBit) + { + return new FloatConverter { UIntValue = UIntFromBits(array, startBit) }.FloatValue; + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float FloatFromBits(ulong[] array, int startBit) + { + return new FloatConverter { UIntValue = UIntFromBits(array, startBit) }.FloatValue; + } + #endregion + + #region Double + /// Converts a given to bytes and writes them into the given array. + /// The to convert. + /// The array to write the bytes into. + /// The position in the array at which to write the bytes. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void FromDouble(double value, byte[] array, int startIndex) + { + DoubleConverter converter = new DoubleConverter { DoubleValue = value }; +#if BIG_ENDIAN + array[startIndex + 7] = converter.Byte0; + array[startIndex + 6] = converter.Byte1; + array[startIndex + 5] = converter.Byte2; + array[startIndex + 4] = converter.Byte3; + array[startIndex + 3] = converter.Byte4; + array[startIndex + 2] = converter.Byte5; + array[startIndex + 1] = converter.Byte6; + array[startIndex ] = converter.Byte7; +#else + array[startIndex ] = converter.Byte0; + array[startIndex + 1] = converter.Byte1; + array[startIndex + 2] = converter.Byte2; + array[startIndex + 3] = converter.Byte3; + array[startIndex + 4] = converter.Byte4; + array[startIndex + 5] = converter.Byte5; + array[startIndex + 6] = converter.Byte6; + array[startIndex + 7] = converter.Byte7; +#endif + } + + /// Converts the 8 bytes in the array at to a . + /// The array to read the bytes from. + /// The position in the array at which to read the bytes. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static double ToDouble(byte[] array, int startIndex) + { +#if BIG_ENDIAN + Array.Reverse(array, startIndex, doubleLength); +#endif + return BitConverter.ToDouble(array, startIndex); + } + + /// Converts to 64 bits and writes them into at . + /// The to convert. + /// The array to write the bits into. + /// The position in the array at which to write the bits. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void DoubleToBits(double value, byte[] array, int startBit) + { + ULongToBits(new DoubleConverter { DoubleValue = value }.ULongValue, array, startBit); + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void DoubleToBits(double value, ulong[] array, int startBit) + { + ULongToBits(new DoubleConverter { DoubleValue = value }.ULongValue, array, startBit); + } + + /// Converts the 64 bits at in to a . + /// The array to convert the bits from. + /// The position in the array from which to read the bits. + /// The converted . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static double DoubleFromBits(byte[] array, int startBit) + { + return new DoubleConverter { ULongValue = ULongFromBits(array, startBit) }.DoubleValue; + } + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static double DoubleFromBits(ulong[] array, int startBit) + { + return new DoubleConverter { ULongValue = ULongFromBits(array, startBit) }.DoubleValue; + } + #endregion + } + + [StructLayout(LayoutKind.Explicit)] + internal struct FloatConverter + { + [FieldOffset(0)] public byte Byte0; + [FieldOffset(1)] public byte Byte1; + [FieldOffset(2)] public byte Byte2; + [FieldOffset(3)] public byte Byte3; + + [FieldOffset(0)] public float FloatValue; + + [FieldOffset(0)] public uint UIntValue; + } + + [StructLayout(LayoutKind.Explicit)] + internal struct DoubleConverter + { + [FieldOffset(0)] public byte Byte0; + [FieldOffset(1)] public byte Byte1; + [FieldOffset(2)] public byte Byte2; + [FieldOffset(3)] public byte Byte3; + [FieldOffset(4)] public byte Byte4; + [FieldOffset(5)] public byte Byte5; + [FieldOffset(6)] public byte Byte6; + [FieldOffset(7)] public byte Byte7; + + [FieldOffset(0)] public double DoubleValue; + + [FieldOffset(0)] public ulong ULongValue; + } +} diff --git a/Riptide/Utils/DelayedEvents.cs b/Riptide/Utils/DelayedEvents.cs new file mode 100644 index 0000000..481d93f --- /dev/null +++ b/Riptide/Utils/DelayedEvents.cs @@ -0,0 +1,61 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using Riptide.Transports; + +namespace Riptide.Utils +{ + /// Executes an action when invoked. + internal abstract class DelayedEvent + { + /// Executes the action. + public abstract void Invoke(); + } + + /// Resends a when invoked. + internal class ResendEvent : DelayedEvent + { + /// The message to resend. + private readonly PendingMessage message; + /// The time at which the resend event was queued. + private readonly long initiatedAtTime; + + /// Initializes the event. + /// The message to resend. + /// The time at which the resend event was queued. + public ResendEvent(PendingMessage message, long initiatedAtTime) + { + this.message = message; + this.initiatedAtTime = initiatedAtTime; + } + + /// + public override void Invoke() + { + if (initiatedAtTime == message.LastSendTime) // If this isn't the case then the message has been resent already + message.RetrySend(); + } + } + + /// Executes a heartbeat when invoked. + internal class HeartbeatEvent : DelayedEvent + { + /// The peer whose heart to beat. + private readonly Peer peer; + + /// Initializes the event. + /// The peer whose heart to beat. + public HeartbeatEvent(Peer peer) + { + this.peer = peer; + } + + /// + public override void Invoke() + { + peer.Heartbeat(); + } + } +} diff --git a/Riptide/Utils/Extensions.cs b/Riptide/Utils/Extensions.cs new file mode 100644 index 0000000..7b8b712 --- /dev/null +++ b/Riptide/Utils/Extensions.cs @@ -0,0 +1,23 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System.Net; + +namespace Riptide.Utils +{ + /// Contains extension methods for various classes. + public static class Extensions + { + /// Takes the 's IP address and port number and converts it to a string, accounting for whether the address is an IPv4 or IPv6 address. + /// A string containing the IP address and port number of the endpoint. + public static string ToStringBasedOnIPFormat(this IPEndPoint endPoint) + { + if (endPoint.Address.IsIPv4MappedToIPv6) + return $"{endPoint.Address.MapToIPv4()}:{endPoint.Port}"; + + return endPoint.ToString(); + } + } +} diff --git a/Riptide/Utils/Helper.cs b/Riptide/Utils/Helper.cs new file mode 100644 index 0000000..99ddc47 --- /dev/null +++ b/Riptide/Utils/Helper.cs @@ -0,0 +1,113 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; + +namespace Riptide.Utils +{ + /// Contains miscellaneous helper methods. + internal class Helper + { + /// The text to log when disconnected due to . + private const string DCNeverConnected = "Never connected"; + /// The text to log when disconnected due to . + private const string DCTransportError = "Transport error"; + /// The text to log when disconnected due to . + private const string DCTimedOut = "Timed out"; + /// The text to log when disconnected due to . + private const string DCKicked = "Kicked"; + /// The text to log when disconnected due to . + private const string DCServerStopped = "Server stopped"; + /// The text to log when disconnected due to . + private const string DCDisconnected = "Disconnected"; + /// The text to log when disconnected due to . + private const string DCPoorConnection = "Poor connection"; + /// The text to log when disconnected or rejected due to an unknown reason. + private const string UnknownReason = "Unknown reason"; + /// The text to log when the connection failed due to . + private const string CRNoConnection = "No connection"; + /// The text to log when the connection failed due to . + private const string CRAlreadyConnected = "This client is already connected"; + /// The text to log when the connection failed due to . + private const string CRServerFull = "Server is full"; + /// The text to log when the connection failed due to . + private const string CRRejected = "Rejected"; + /// The text to log when the connection failed due to . + private const string CRCustom = "Rejected (with custom data)"; + + /// Determines whether or form should be used based on the . + /// The amount that and refer to. + /// The singular form. + /// The plural form. + /// if is 1; otherwise . + internal static string CorrectForm(int amount, string singular, string plural = "") + { + if (string.IsNullOrEmpty(plural)) + plural = $"{singular}s"; + + return amount == 1 ? singular : plural; + } + + /// Calculates the signed gap between sequence IDs, accounting for wrapping. + /// The new sequence ID. + /// The previous sequence ID. + /// The signed gap between the two given sequence IDs. A positive gap means is newer than . A negative gap means is older than . + internal static int GetSequenceGap(ushort seqId1, ushort seqId2) + { + int gap = seqId1 - seqId2; + if (Math.Abs(gap) <= 32768) // Difference is small, meaning sequence IDs are close together + return gap; + else // Difference is big, meaning sequence IDs are far apart + return (seqId1 <= 32768 ? ushort.MaxValue + 1 + seqId1 : seqId1) - (seqId2 <= 32768 ? ushort.MaxValue + 1 + seqId2 : seqId2); + } + + /// Retrieves the appropriate reason string for the given . + /// The to retrieve the string for. + /// The appropriate reason string. + internal static string GetReasonString(DisconnectReason forReason) + { + switch (forReason) + { + case DisconnectReason.NeverConnected: + return DCNeverConnected; + case DisconnectReason.TransportError: + return DCTransportError; + case DisconnectReason.TimedOut: + return DCTimedOut; + case DisconnectReason.Kicked: + return DCKicked; + case DisconnectReason.ServerStopped: + return DCServerStopped; + case DisconnectReason.Disconnected: + return DCDisconnected; + case DisconnectReason.PoorConnection: + return DCPoorConnection; + default: + return $"{UnknownReason} '{forReason}'"; + } + } + /// Retrieves the appropriate reason string for the given . + /// The to retrieve the string for. + /// The appropriate reason string. + internal static string GetReasonString(RejectReason forReason) + { + switch (forReason) + { + case RejectReason.NoConnection: + return CRNoConnection; + case RejectReason.AlreadyConnected: + return CRAlreadyConnected; + case RejectReason.ServerFull: + return CRServerFull; + case RejectReason.Rejected: + return CRRejected; + case RejectReason.Custom: + return CRCustom; + default: + return $"{UnknownReason} '{forReason}'"; + } + } + } +} diff --git a/Riptide/Utils/PriorityQueue.cs b/Riptide/Utils/PriorityQueue.cs new file mode 100644 index 0000000..9c44c40 --- /dev/null +++ b/Riptide/Utils/PriorityQueue.cs @@ -0,0 +1,158 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; +using System.Collections.Generic; + +namespace Riptide.Utils +{ + // PriorityQueue unfortunately doesn't exist in .NET Standard 2.1 + /// Represents a collection of items that have a value and a priority. On dequeue, the item with the lowest priority value is removed. + /// Specifies the type of elements in the queue. + /// Specifies the type of priority associated with enqueued elements. + public class PriorityQueue + { + /// Gets the number of elements contained in the . + public int Count { get; private set; } + + private const int DefaultCapacity = 8; + private Entry[] heap; + private readonly IComparer comparer; + + /// Initializes a new instance of the class. + /// Initial capacity to allocate for the underlying heap array. + public PriorityQueue(int capacity = DefaultCapacity) + { + heap = new Entry[capacity]; + comparer = Comparer.Default; + } + + /// Initializes a new instance of the class with the specified custom priority comparer. + /// Custom comparer dictating the ordering of elements. + /// Initial capacity to allocate for the underlying heap array. + public PriorityQueue(IComparer comparer, int capacity = DefaultCapacity) + { + heap = new Entry[capacity]; + this.comparer = comparer; + } + + /// Adds the specified element and associated priority to the . + /// The element to add. + /// The priority with which to associate the new element. + public void Enqueue(TElement element, TPriority priority) + { + if (Count == heap.Length) + { + // Resizing is necessary + Entry[] temp = new Entry[Count * 2]; + Array.Copy(heap, temp, heap.Length); + heap = temp; + } + + int index = Count; + while (index > 0) + { + int parentIndex = GetParentIndex(index); + if (comparer.Compare(priority, heap[parentIndex].Priority) < 0) + { + heap[index] = heap[parentIndex]; + index = parentIndex; + } + else + break; + } + + heap[index] = new Entry(element, priority); + Count++; + } + + /// Removes and returns the lowest priority element. + public TElement Dequeue() + { + TElement returnValue = heap[0].Element; + + if (Count > 1) + { + int parent = 0; + int leftChild = GetLeftChildIndex(parent); + + while (leftChild < Count) + { + int rightChild = leftChild + 1; + int bestChild = (rightChild < Count && comparer.Compare(heap[rightChild].Priority, heap[leftChild].Priority) < 0) ? rightChild : leftChild; + + heap[parent] = heap[bestChild]; + parent = bestChild; + leftChild = GetLeftChildIndex(parent); + } + + heap[parent] = heap[Count - 1]; + } + + Count--; + return returnValue; + } + + /// Removes the lowest priority element from the and copies it and its associated priority to the and arguments. + /// When this method returns, contains the removed element. + /// When this method returns, contains the priority associated with the removed element. + /// true if the element is successfully removed; false if the is empty. + public bool TryDequeue(out TElement element, out TPriority priority) + { + if (Count > 0) + { + priority = heap[0].Priority; + element = Dequeue(); + return true; + } + { + element = default(TElement); + priority = default(TPriority); + return false; + } + } + + /// Returns the lowest priority element. + public TElement Peek() + { + return heap[0].Element; + } + + /// Returns the priority of the lowest priority element. + public TPriority PeekPriority() + { + return heap[0].Priority; + } + + /// Removes all elements from the . + public void Clear() + { + Array.Clear(heap, 0, heap.Length); + Count = 0; + } + + private static int GetParentIndex(int index) + { + return (index - 1) / 2; + } + + private static int GetLeftChildIndex(int index) + { + return (index * 2) + 1; + } + + private struct Entry + { + internal readonly TEle Element; + internal readonly TPrio Priority; + + public Entry(TEle element, TPrio priority) + { + Element = element; + Priority = priority; + } + } + } +} diff --git a/Riptide/Utils/RiptideLogger.cs b/Riptide/Utils/RiptideLogger.cs new file mode 100644 index 0000000..c3d51ef --- /dev/null +++ b/Riptide/Utils/RiptideLogger.cs @@ -0,0 +1,126 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; +using System.Collections.Generic; + +namespace Riptide.Utils +{ + /// Defines log message types. + public enum LogType + { + /// Logs that are used for investigation during development. + Debug, + /// Logs that provide general information about application flow. + Info, + /// Logs that highlight abnormal or unexpected events in the application flow. + Warning, + /// Logs that highlight problematic events in the application flow which will cause unexpected behavior if not planned for. + Error + } + + /// Provides functionality for logging messages. + public class RiptideLogger + { + /// Whether or not messages will be logged. + public static bool IsDebugLoggingEnabled => logMethods.ContainsKey(LogType.Debug); + /// Whether or not messages will be logged. + public static bool IsInfoLoggingEnabled => logMethods.ContainsKey(LogType.Info); + /// Whether or not messages will be logged. + public static bool IsWarningLoggingEnabled => logMethods.ContainsKey(LogType.Warning); + /// Whether or not messages will be logged. + public static bool IsErrorLoggingEnabled => logMethods.ContainsKey(LogType.Error); + /// Encapsulates a method used to log messages. + /// The message to log. + public delegate void LogMethod(string log); + + /// Log methods, accessible by their + private static readonly Dictionary logMethods = new Dictionary(4); + /// Whether or not to include timestamps when logging messages. + private static bool includeTimestamps; + /// The format to use for timestamps. + private static string timestampFormat; + + /// Initializes with all log types enabled. + /// The method to use when logging all types of messages. + /// Whether or not to include timestamps when logging messages. + /// The format to use for timestamps. + public static void Initialize(LogMethod logMethod, bool includeTimestamps, string timestampFormat = "HH:mm:ss") => Initialize(logMethod, logMethod, logMethod, logMethod, includeTimestamps, timestampFormat); + /// Initializes with the supplied log methods. + /// The method to use when logging debug messages. Set to to disable debug logs. + /// The method to use when logging info messages. Set to to disable info logs. + /// The method to use when logging warning messages. Set to to disable warning logs. + /// The method to use when logging error messages. Set to to disable error logs. + /// Whether or not to include timestamps when logging messages. + /// The format to use for timestamps. + public static void Initialize(LogMethod debugMethod, LogMethod infoMethod, LogMethod warningMethod, LogMethod errorMethod, bool includeTimestamps, string timestampFormat = "HH:mm:ss") + { + logMethods.Clear(); + + if (debugMethod != null) + logMethods.Add(LogType.Debug, debugMethod); + if (infoMethod != null) + logMethods.Add(LogType.Info, infoMethod); + if (warningMethod != null) + logMethods.Add(LogType.Warning, warningMethod); + if (errorMethod != null) + logMethods.Add(LogType.Error, errorMethod); + + RiptideLogger.includeTimestamps = includeTimestamps; + RiptideLogger.timestampFormat = timestampFormat; + } + + /// Enables logging for messages of the given . + /// The type of message to enable logging for. + /// The method to use when logging this type of message. + public static void EnableLoggingFor(LogType logType, LogMethod logMethod) + { + if (logMethods.ContainsKey(logType)) + logMethods[logType] = logMethod; + else + logMethods.Add(logType, logMethod); + } + + /// Disables logging for messages of the given . + /// The type of message to enable logging for. + public static void DisableLoggingFor(LogType logType) => logMethods.Remove(logType); + + /// Logs a message. + /// The type of log message that is being logged. + /// The message to log. + public static void Log(LogType logType, string message) + { + if (logMethods.TryGetValue(logType, out LogMethod logMethod)) + { + if (includeTimestamps) + logMethod($"[{GetTimestamp(DateTime.Now)}]: {message}"); + else + logMethod(message); + } + } + /// Logs a message. + /// The type of log message that is being logged. + /// Who is logging this message. + /// The message to log. + public static void Log(LogType logType, string logName, string message) + { + if (logMethods.TryGetValue(logType, out LogMethod logMethod)) + { + if (includeTimestamps) + logMethod($"[{GetTimestamp(DateTime.Now)}] ({logName}): {message}"); + else + logMethod($"({logName}): {message}"); + } + } + + /// Converts a object to a formatted timestamp string. + /// The time to format. + /// The formatted timestamp. + private static string GetTimestamp(DateTime time) + { + return time.ToString(timestampFormat); + } + } +} diff --git a/Riptide/Utils/RollingStat.cs b/Riptide/Utils/RollingStat.cs new file mode 100644 index 0000000..a3ba524 --- /dev/null +++ b/Riptide/Utils/RollingStat.cs @@ -0,0 +1,93 @@ +// This file is provided under The MIT License as part of RiptideNetworking. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/RiptideNetworking/Riptide/blob/main/LICENSE.md + +using System; +using System.Linq; + +namespace Riptide.Utils +{ + /// Represents a rolling series of numbers. + public class RollingStat + { + /// The position in the array of the latest item. + private int index; + /// How many of the array's slots are in use. + private int slotsFilled; + /// + private double mean; + /// The sum of the mean subtracted from each value in the array. + private double sumOfSquares; + /// The array used to store the values. + private readonly double[] array; + + /// The mean of the stat's values. + public double Mean => mean; + /// The variance of the stat's values. + public double Variance => slotsFilled > 1 ? sumOfSquares / (slotsFilled - 1) : 0; + /// The standard deviation of the stat's values. + public double StandardDev + { + get + { + double variance = Variance; + if (variance >= double.Epsilon) + { + double root = Math.Sqrt(variance); + return double.IsNaN(root) ? 0 : root; + } + + return 0; + } + } + + /// Initializes the stat. + /// The number of values to store. + public RollingStat(int sampleSize) + { + index = 0; + slotsFilled = 0; + mean = 0; + sumOfSquares = 0; + array = new double[sampleSize]; + } + + /// Adds a new value to the stat. + /// The value to add. + public void Add(double value) + { + if (double.IsNaN(value) || double.IsInfinity(value)) + return; + + index %= array.Length; + double oldMean = mean; + double oldValue = array[index]; + array[index] = value; + index++; + + if (slotsFilled == array.Length) + { + double delta = value - oldValue; + mean += delta / slotsFilled; + sumOfSquares += delta * (value - mean + (oldValue - oldMean)); + } + else + { + slotsFilled++; + double delta = value - oldMean; + mean += delta / slotsFilled; + sumOfSquares += delta * (value - mean); + } + } + + /// + public override string ToString() + { + if (slotsFilled == array.Length) + return string.Join(",", array); + + return string.Join(",", array.Take(slotsFilled)); + } + } +} diff --git a/RiptideSteamTransport/KCMSteamManager.cs b/RiptideSteamTransport/KCMSteamManager.cs new file mode 100644 index 0000000..b7e0f70 --- /dev/null +++ b/RiptideSteamTransport/KCMSteamManager.cs @@ -0,0 +1,184 @@ +// The SteamManager is designed to work with Steamworks.NET +// This file is released into the public domain. +// Where that dedication is not recognized you are granted a perpetual, +// irrevocable license to copy and modify this file as you see fit. +// +// Version: 1.0.12 + +#if !(UNITY_STANDALONE_WIN || UNITY_STANDALONE_LINUX || UNITY_STANDALONE_OSX || STEAMWORKS_WIN || STEAMWORKS_LIN_OSX) + +#endif + +using UnityEngine; +#if !DISABLESTEAMWORKS +using System.Collections; +using Steamworks; +using KCM; +using Riptide.Demos.Steam.PlayerHosted; +#endif + +// +// The SteamManager provides a base implementation of Steamworks.NET on which you can build upon. +// It handles the basics of starting up and shutting down the SteamAPI for use. +// +[DisallowMultipleComponent] +public class KCMSteamManager : MonoBehaviour { +#if !DISABLESTEAMWORKS + protected static bool s_EverInitialized = false; + + protected static KCMSteamManager s_instance; + public static KCMSteamManager Instance { + get { + if (s_instance == null) { + return new GameObject("KCMSteamManager").AddComponent(); + } + else { + return s_instance; + } + } + } + + protected bool m_bInitialized = false; + public static bool Initialized { + get { + return Instance.m_bInitialized; + } + } + + protected SteamAPIWarningMessageHook_t m_SteamAPIWarningMessageHook; + + [AOT.MonoPInvokeCallback(typeof(SteamAPIWarningMessageHook_t))] + protected static void SteamAPIDebugTextHook(int nSeverity, System.Text.StringBuilder pchDebugText) { + Main.helper.Log(pchDebugText.ToString()); + } + +#if UNITY_2019_3_OR_NEWER + // In case of disabled Domain Reload, reset static members before entering Play Mode. + [RuntimeInitializeOnLoadMethod(RuntimeInitializeLoadType.SubsystemRegistration)] + private static void InitOnPlayMode() + { + s_EverInitialized = false; + s_instance = null; + } +#endif + + protected virtual void Awake() { + // Only one instance of SteamManager at a time! + Main.helper.Log("Steam awake"); + if (s_instance != null) { + Destroy(gameObject); + return; + } + s_instance = this; + + if(s_EverInitialized) { + // This is almost always an error. + // The most common case where this happens is when SteamManager gets destroyed because of Application.Quit(), + // and then some Steamworks code in some other OnDestroy gets called afterwards, creating a new SteamManager. + // You should never call Steamworks functions in OnDestroy, always prefer OnDisable if possible. + Main.helper.Log("Tried to Initialize the SteamAPI twice in one session!"); + + return; + } + + // We want our SteamManager Instance to persist across scenes. + DontDestroyOnLoad(gameObject); + + if (!Packsize.Test()) { + Main.helper.Log("[Steamworks.NET] Packsize Test returned false, the wrong version of Steamworks.NET is being run in this platform."); + } + + if (!DllCheck.Test()) { + Main.helper.Log("[Steamworks.NET] DllCheck Test returned false, One or more of the Steamworks binaries seems to be the wrong version."); + } + + try { + // If Steam is not running or the game wasn't started through Steam, SteamAPI_RestartAppIfNecessary starts the + // Steam client and also launches this game again if the User owns it. This can act as a rudimentary form of DRM. + + // Once you get a Steam AppID assigned by Valve, you need to replace AppId_t.Invalid with it and + // remove steam_appid.txt from the game depot. eg: "(AppId_t)480" or "new AppId_t(480)". + // See the Valve documentation for more information: https://partner.steamgames.com/doc/sdk/api#initialization_and_shutdown + if (SteamAPI.RestartAppIfNecessary((AppId_t)569480)) { + //Application.Quit(); + Main.helper.Log("Attempted to restart app"); + return; + } + } + catch (System.DllNotFoundException e) { // We catch this exception here, as it will be the first occurrence of it. + Main.helper.Log("[Steamworks.NET] Could not load [lib]steam_api.dll/so/dylib. It's likely not in the correct location. Refer to the README for more details.\n" + e); + + //Application.Quit(); + return; + } + + // Initializes the Steamworks API. + // If this returns false then this indicates one of the following conditions: + // [*] The Steam client isn't running. A running Steam client is required to provide implementations of the various Steamworks interfaces. + // [*] The Steam client couldn't determine the App ID of game. If you're running your application from the executable or debugger directly then you must have a [code-inline]steam_appid.txt[/code-inline] in your game directory next to the executable, with your app ID in it and nothing else. Steam will look for this file in the current working directory. If you are running your executable from a different directory you may need to relocate the [code-inline]steam_appid.txt[/code-inline] file. + // [*] Your application is not running under the same OS user context as the Steam client, such as a different user or administration access level. + // [*] Ensure that you own a license for the App ID on the currently active Steam account. Your game must show up in your Steam library. + // [*] Your App ID is not completely set up, i.e. in Release State: Unavailable, or it's missing default packages. + // Valve's documentation for this is located here: + // https://partner.steamgames.com/doc/sdk/api#initialization_and_shutdown + m_bInitialized = SteamAPI.Init(); + if (!m_bInitialized) { + Main.helper.Log("[Steamworks.NET] SteamAPI_Init() failed. Refer to Valve's documentation or the comment above this line for more information."); + + return; + } + + s_EverInitialized = true; + } + + // This should only ever get called on first load and after an Assembly reload, You should never Disable the Steamworks Manager yourself. + protected virtual void OnEnable() { + if (s_instance == null) { + s_instance = this; + } + + if (!m_bInitialized) { + return; + } + + if (m_SteamAPIWarningMessageHook == null) { + // Set up our callback to receive warning messages from Steam. + // You must launch with "-debug_steamapi" in the launch args to receive warnings. + m_SteamAPIWarningMessageHook = new SteamAPIWarningMessageHook_t(SteamAPIDebugTextHook); + SteamClient.SetWarningMessageHook(m_SteamAPIWarningMessageHook); + } + } + + // OnApplicationQuit gets called too early to shutdown the SteamAPI. + // Because the SteamManager should be persistent and never disabled or destroyed we can shutdown the SteamAPI here. + // Thus it is not recommended to perform any Steamworks work in other OnDestroy functions as the order of execution can not be garenteed upon Shutdown. Prefer OnDisable(). + protected virtual void OnDestroy() { + if (s_instance != this) { + return; + } + + s_instance = null; + + if (!m_bInitialized) { + return; + } + + SteamAPI.Shutdown(); + } + + protected virtual void Update() { + if (!m_bInitialized) { + return; + } + + // Run Steam client callbacks + SteamAPI.RunCallbacks(); + } +#else + public static bool Initialized { + get { + return false; + } + } +#endif //!DISABLESTEAMWORKS +} diff --git a/RiptideSteamTransport/LobbyManager.cs b/RiptideSteamTransport/LobbyManager.cs new file mode 100644 index 0000000..f159086 --- /dev/null +++ b/RiptideSteamTransport/LobbyManager.cs @@ -0,0 +1,175 @@ +using KCM; +using KCM.Enums; +using KCM.Packets.Handlers; +using Steamworks; +using UnityEngine; + +namespace Riptide.Demos.Steam.PlayerHosted +{ + public class LobbyManager : MonoBehaviour + { + private static LobbyManager _singleton; + internal static LobbyManager Singleton + { + get => _singleton; + private set + { + if (_singleton == null) + _singleton = value; + else if (_singleton != value) + { + Debug.Log($"{nameof(LobbyManager)} instance already exists, destroying object!"); + Destroy(value); + } + } + } + + protected Callback lobbyCreated; + protected Callback gameLobbyJoinRequested; + protected Callback lobbyEnter; + + private const string HostAddressKey = "HostAddress"; + private CSteamID lobbyId; + + private void Awake() + { + Singleton = this; + } + + private void Start() + { + + if (!KCMSteamManager.Initialized) + { + Main.helper.Log("Steam is not initialized!"); + return; + } + + lobbyCreated = Callback.Create(OnLobbyCreated); + gameLobbyJoinRequested = Callback.Create(OnGameLobbyJoinRequested); + lobbyEnter = Callback.Create(OnLobbyEnter); + + } + + public static bool loadingSave = false; + + internal void CreateLobby(bool loadingSave = false) + { + var result = SteamMatchmaking.CreateLobby(ELobbyType.k_ELobbyTypePublic, 25); + + LobbyManager.loadingSave = loadingSave; + + } + + private void OnLobbyCreated(LobbyCreated_t callback) + { + + if (callback.m_eResult != EResult.k_EResultOK) + { + //UIManager.Singleton.LobbyCreationFailed(); + Main.helper.Log("Create lobby failed"); + return; + } + + lobbyId = new CSteamID(callback.m_ulSteamIDLobby); + //UIManager.Singleton.LobbyCreationSucceeded(callback.m_ulSteamIDLobby); + + //NetworkManager.Singleton.Server.Start(0, 5, NetworkManager.PlayerHostedDemoMessageHandlerGroupId); + + + KCServer.StartServer(); + + Main.TransitionTo(MenuState.ServerLobby); + + + try + { + Main.helper.Log("About to call connect"); + KCClient.Connect("127.0.0.1"); + + World.inst.Generate(); + ServerLobbyScript.WorldSeed.text = World.inst.GetTextSeed(); + + LobbyHandler.ClearPlayerList(); + + /*Cam.inst.desiredDist = 80f; + Cam.inst.desiredPhi = 45f; + CloudSystem.inst.threshold1 = 0.6f; + CloudSystem.inst.threshold2 = 0.8f; + CloudSystem.inst.BaseFreq = 4.5f; + Weather.inst.SetSeason(Weather.Season.Summer); + + + Main.TransitionTo(MenuState.NameAndBanner);*/ + + ServerBrowser.registerServer = true; + } + catch (System.Exception ex) + { + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + if (ex.InnerException != null) + { + Main.helper.Log("----------------------- Inner exception -----------------------"); + Main.helper.Log(ex.InnerException.ToString()); + Main.helper.Log("----------------------- Inner message -----------------------"); + Main.helper.Log(ex.InnerException.Message); + Main.helper.Log("----------------------- Inner stacktrace -----------------------"); + Main.helper.Log(ex.InnerException.StackTrace); + } + } + + //NetworkManager.Singleton.Client.Connect("127.0.0.1", messageHandlerGroupId: NetworkManager.PlayerHostedDemoMessageHandlerGroupId); + } + + internal void JoinLobby(ulong lobbyId) + { + SteamMatchmaking.JoinLobby(new CSteamID(lobbyId)); + } + + private void OnGameLobbyJoinRequested(GameLobbyJoinRequested_t callback) + { + SteamMatchmaking.JoinLobby(callback.m_steamIDLobby); + } + + private void OnLobbyEnter(LobbyEnter_t callback) + { + if (KCServer.IsRunning) + return; + + lobbyId = new CSteamID(callback.m_ulSteamIDLobby); + CSteamID hostId = SteamMatchmaking.GetLobbyOwner(lobbyId); + + KCClient.Connect(hostId.ToString()); + //UIManager.Singleton.LobbyEntered(); + } + + public void LeaveLobby() + { + //NetworkManager.Singleton.StopServer(); + //NetworkManager.Singleton.DisconnectClient(); + SteamMatchmaking.LeaveLobby(lobbyId); + + if (KCClient.client.IsConnected) + KCClient.client.Disconnect(); + + Main.helper.Log("clear players"); + Main.kCPlayers.Clear(); + LobbyHandler.ClearPlayerList(); + LobbyHandler.ClearChatEntries(); + Main.helper.Log("end clear players"); + + if (KCServer.IsRunning) + KCServer.server.Stop(); + + + + Main.TransitionTo(MenuState.ServerBrowser); + ServerBrowser.registerServer = false; + } + } +} diff --git a/RiptideSteamTransport/Transport/SteamClient.cs b/RiptideSteamTransport/Transport/SteamClient.cs new file mode 100644 index 0000000..9e1698b --- /dev/null +++ b/RiptideSteamTransport/Transport/SteamClient.cs @@ -0,0 +1,204 @@ +// This file is provided under The MIT License as part of RiptideSteamTransport. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/tom-weiland/RiptideSteamTransport/blob/main/LICENSE.md + +using Steamworks; +using System; +using System.Threading.Tasks; +using UnityEngine; + +namespace Riptide.Transports.Steam +{ + public class SteamClient : SteamPeer, IClient + { + public event EventHandler Connected; + public event EventHandler ConnectionFailed; + public event EventHandler DataReceived; + public event EventHandler Disconnected; + + private const string LocalHostName = "localhost"; + private const string LocalHostIP = "127.0.0.1"; + + private SteamConnection steamConnection; + private SteamServer localServer; + private Callback connectionStatusChanged; + + public SteamClient(SteamServer localServer = null) + { + this.localServer = localServer; + } + + public void ChangeLocalServer(SteamServer newLocalServer) + { + localServer = newLocalServer; + } + + public bool Connect(string hostAddress, out Connection connection, out string connectError) + { + connection = null; + + try + { + //SteamGameServerNetworkingUtils.InitRelayNetworkAccess(); + SteamNetworkingUtils.InitRelayNetworkAccess(); + } + catch (Exception ex) + { + connectError = $"Couldn't connect: {ex}"; + return false; + } + + connectError = $"Invalid host address '{hostAddress}'! Expected '{LocalHostIP}' or '{LocalHostName}' for local connections, or a valid Steam ID."; + if (hostAddress == LocalHostIP || hostAddress == LocalHostName) + { + if (localServer == null) + { + connectError = $"No locally running server was specified to connect to! Either pass a {nameof(SteamServer)} instance to your {nameof(SteamClient)}'s constructor or call its {nameof(SteamClient.ChangeLocalServer)} method before attempting to connect locally."; + connection = null; + return false; + } + + connection = steamConnection = ConnectLocal(); + return true; + } + else if (ulong.TryParse(hostAddress, out ulong hostId)) + { + connection = steamConnection = TryConnect(new CSteamID(hostId)); + return connection != null; + } + + return false; + } + + private SteamConnection ConnectLocal() + { + Debug.Log($"{LogName}: Connecting to locally running server..."); + + connectionStatusChanged = Callback.Create(OnConnectionStatusChanged); + CSteamID playerSteamId = SteamUser.GetSteamID(); + + SteamNetworkingIdentity clientIdentity = new SteamNetworkingIdentity(); + clientIdentity.SetSteamID(playerSteamId); + SteamNetworkingIdentity serverIdentity = new SteamNetworkingIdentity(); + serverIdentity.SetSteamID(playerSteamId); + + SteamNetworkingSockets.CreateSocketPair(out HSteamNetConnection connectionToClient, out HSteamNetConnection connectionToServer, false, ref clientIdentity, ref serverIdentity); + + localServer.Add(new SteamConnection(playerSteamId, connectionToClient, this)); + OnConnected(); + return new SteamConnection(playerSteamId, connectionToServer, this); + } + + private SteamConnection TryConnect(CSteamID hostId) + { + try + { + connectionStatusChanged = Callback.Create(OnConnectionStatusChanged); + + SteamNetworkingIdentity serverIdentity = new SteamNetworkingIdentity(); + serverIdentity.SetSteamID(hostId); + + SteamNetworkingConfigValue_t[] options = new SteamNetworkingConfigValue_t[] { }; + HSteamNetConnection connectionToServer = SteamNetworkingSockets.ConnectP2P(ref serverIdentity, 0, options.Length, options); + + ConnectTimeout(); + return new SteamConnection(hostId, connectionToServer, this); + } + catch (Exception ex) + { + Debug.LogException(ex); + OnConnectionFailed(); + return null; + } + } + + private async void ConnectTimeout() // TODO: confirm if this is needed, Riptide *should* take care of timing out the connection + { + Task timeOutTask = Task.Delay(6000); // TODO: use Riptide Client's TimeoutTime + await Task.WhenAny(timeOutTask); + + if (!steamConnection.IsConnected) + OnConnectionFailed(); + } + + private void OnConnectionStatusChanged(SteamNetConnectionStatusChangedCallback_t callback) + { + if (!callback.m_hConn.Equals(steamConnection.SteamNetConnection)) + { + // When connecting via local loopback connection to a locally running SteamServer (aka + // this player is also the host), other external clients that attempt to connect seem + // to trigger ConnectionStatusChanged callbacks for the locally connected client. Not + // 100% sure why this is the case, but returning out of the callback here when the + // connection doesn't match that between local client & server avoids the problem. + return; + } + + switch (callback.m_info.m_eState) + { + case ESteamNetworkingConnectionState.k_ESteamNetworkingConnectionState_Connected: + OnConnected(); + break; + + case ESteamNetworkingConnectionState.k_ESteamNetworkingConnectionState_ClosedByPeer: + SteamNetworkingSockets.CloseConnection(callback.m_hConn, 0, "Closed by peer", false); + OnDisconnected(DisconnectReason.Disconnected); + break; + + case ESteamNetworkingConnectionState.k_ESteamNetworkingConnectionState_ProblemDetectedLocally: + SteamNetworkingSockets.CloseConnection(callback.m_hConn, 0, "Problem detected", false); + OnDisconnected(DisconnectReason.TransportError); + break; + + default: + Debug.Log($"{LogName}: Connection state changed - {callback.m_info.m_eState} | {callback.m_info.m_szEndDebug}"); + break; + } + } + + public void Poll() + { + if (steamConnection != null) + Receive(steamConnection); + } + + // TODO: disable nagle so this isn't needed + //public void Flush() + //{ + // foreach (SteamConnection connection in connections.Values) + // SteamNetworkingSockets.FlushMessagesOnConnection(connection.SteamNetConnection); + //} + + public void Disconnect() + { + if (connectionStatusChanged != null) + { + connectionStatusChanged.Dispose(); + connectionStatusChanged = null; + } + + SteamNetworkingSockets.CloseConnection(steamConnection.SteamNetConnection, 0, "Disconnected", false); + steamConnection = null; + } + + protected virtual void OnConnected() + { + Connected?.Invoke(this, EventArgs.Empty); + } + + protected virtual void OnConnectionFailed() + { + ConnectionFailed?.Invoke(this, EventArgs.Empty); + } + + protected override void OnDataReceived(byte[] dataBuffer, int amount, SteamConnection fromConnection) + { + DataReceived?.Invoke(this, new DataReceivedEventArgs(dataBuffer, amount, fromConnection)); + } + + protected virtual void OnDisconnected(DisconnectReason reason) + { + Disconnected?.Invoke(this, new DisconnectedEventArgs(steamConnection, reason)); + } + } +} diff --git a/RiptideSteamTransport/Transport/SteamConnection.cs b/RiptideSteamTransport/Transport/SteamConnection.cs new file mode 100644 index 0000000..6ff1ce8 --- /dev/null +++ b/RiptideSteamTransport/Transport/SteamConnection.cs @@ -0,0 +1,72 @@ +// This file is provided under The MIT License as part of RiptideSteamTransport. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/tom-weiland/RiptideSteamTransport/blob/main/LICENSE.md + +using Steamworks; +using System; +using System.Collections.Generic; + +namespace Riptide.Transports.Steam +{ + public class SteamConnection : Connection, IEquatable + { + public readonly CSteamID SteamId; + public readonly HSteamNetConnection SteamNetConnection; + + internal bool DidReceiveConnect; + + private readonly SteamPeer peer; + + internal SteamConnection(CSteamID steamId, HSteamNetConnection steamNetConnection, SteamPeer peer) + { + SteamId = steamId; + SteamNetConnection = steamNetConnection; + this.peer = peer; + } + + protected internal override void Send(byte[] dataBuffer, int amount) + { + peer.Send(dataBuffer, amount, SteamNetConnection); + } + + /// + public override string ToString() => SteamNetConnection.ToString(); + + /// + public override bool Equals(object obj) => Equals(obj as SteamConnection); + /// + public bool Equals(SteamConnection other) + { + if (other is null) + return false; + + if (ReferenceEquals(this, other)) + return true; + + return SteamNetConnection.Equals(other.SteamNetConnection); + } + + /// + public override int GetHashCode() + { + return -721414014 + EqualityComparer.Default.GetHashCode(SteamNetConnection); + } + + public static bool operator ==(SteamConnection left, SteamConnection right) + { + if (left is null) + { + if (right is null) + return true; + + return false; // Only the left side is null + } + + // Equals handles case of null on right side + return left.Equals(right); + } + + public static bool operator !=(SteamConnection left, SteamConnection right) => !(left == right); + } +} diff --git a/RiptideSteamTransport/Transport/SteamPeer.cs b/RiptideSteamTransport/Transport/SteamPeer.cs new file mode 100644 index 0000000..49501a1 --- /dev/null +++ b/RiptideSteamTransport/Transport/SteamPeer.cs @@ -0,0 +1,69 @@ +// This file is provided under The MIT License as part of RiptideSteamTransport. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/tom-weiland/RiptideSteamTransport/blob/main/LICENSE.md + +using Steamworks; +using System; +using System.Runtime.InteropServices; +using UnityEngine; + +namespace Riptide.Transports.Steam +{ + public abstract class SteamPeer + { + /// The name to use when logging messages via . + public const string LogName = "STEAM"; + + protected const int MaxMessages = 256; + + private readonly byte[] receiveBuffer; + + protected SteamPeer() + { + receiveBuffer = new byte[Message.MaxSize + sizeof(ushort)]; + } + + protected void Receive(SteamConnection fromConnection) + { + IntPtr[] ptrs = new IntPtr[MaxMessages]; // TODO: remove allocation? + + // TODO: consider using poll groups -> https://partner.steamgames.com/doc/api/ISteamNetworkingSockets#functions_poll_groups + int messageCount = SteamNetworkingSockets.ReceiveMessagesOnConnection(fromConnection.SteamNetConnection, ptrs, MaxMessages); + if (messageCount > 0) + { + for (int i = 0; i < messageCount; i++) + { + SteamNetworkingMessage_t data = Marshal.PtrToStructure(ptrs[i]); + + if (data.m_cbSize > 0) + { + int byteCount = data.m_cbSize; + if (data.m_cbSize > receiveBuffer.Length) + { + Debug.LogWarning($"{LogName}: Can't fully handle {data.m_cbSize} bytes because it exceeds the maximum of {receiveBuffer.Length}. Data will be incomplete!"); + byteCount = receiveBuffer.Length; + } + + Marshal.Copy(data.m_pData, receiveBuffer, 0, data.m_cbSize); + OnDataReceived(receiveBuffer, byteCount, fromConnection); + } + } + } + } + + internal void Send(byte[] dataBuffer, int numBytes, HSteamNetConnection toConnection) + { + GCHandle handle = GCHandle.Alloc(dataBuffer, GCHandleType.Pinned); + IntPtr pDataBuffer = handle.AddrOfPinnedObject(); + + EResult result = SteamNetworkingSockets.SendMessageToConnection(toConnection, pDataBuffer, (uint)numBytes, Constants.k_nSteamNetworkingSend_Unreliable, out long _); + if (result != EResult.k_EResultOK) + Debug.LogWarning($"{LogName}: Failed to send {numBytes} bytes - {result}"); + + handle.Free(); + } + + protected abstract void OnDataReceived(byte[] dataBuffer, int amount, SteamConnection fromConnection); + } +} diff --git a/RiptideSteamTransport/Transport/SteamServer.cs b/RiptideSteamTransport/Transport/SteamServer.cs new file mode 100644 index 0000000..f5989c7 --- /dev/null +++ b/RiptideSteamTransport/Transport/SteamServer.cs @@ -0,0 +1,160 @@ +// This file is provided under The MIT License as part of RiptideSteamTransport. +// Copyright (c) Tom Weiland +// For additional information please see the included LICENSE.md file or view it on GitHub: +// https://github.com/tom-weiland/RiptideSteamTransport/blob/main/LICENSE.md + +using Steamworks; +using System; +using System.Collections.Generic; +using UnityEngine; + +namespace Riptide.Transports.Steam +{ + public class SteamServer : SteamPeer, IServer + { + public event EventHandler Connected; + public event EventHandler DataReceived; + public event EventHandler Disconnected; + + public ushort Port { get; private set; } + + private Dictionary connections; + private HSteamListenSocket listenSocket; + private Callback connectionStatusChanged; + + public void Start(ushort port) + { + Port = port; + connections = new Dictionary(); + + connectionStatusChanged = Callback.Create(OnConnectionStatusChanged); + + try + { +#if UNITY_SERVER + SteamGameServerNetworkingUtils.InitRelayNetworkAccess(); +#else + SteamNetworkingUtils.InitRelayNetworkAccess(); +#endif + } + catch (Exception ex) + { + Debug.LogException(ex); + } + + SteamNetworkingConfigValue_t[] options = new SteamNetworkingConfigValue_t[] { }; + listenSocket = SteamNetworkingSockets.CreateListenSocketP2P(port, options.Length, options); + } + + private void OnConnectionStatusChanged(SteamNetConnectionStatusChangedCallback_t callback) + { + CSteamID clientSteamId = callback.m_info.m_identityRemote.GetSteamID(); + switch (callback.m_info.m_eState) + { + case ESteamNetworkingConnectionState.k_ESteamNetworkingConnectionState_Connecting: + Accept(callback.m_hConn); + break; + + case ESteamNetworkingConnectionState.k_ESteamNetworkingConnectionState_Connected: + Add(new SteamConnection(clientSteamId, callback.m_hConn, this)); + break; + + case ESteamNetworkingConnectionState.k_ESteamNetworkingConnectionState_ClosedByPeer: + SteamNetworkingSockets.CloseConnection(callback.m_hConn, 0, "Closed by peer", false); + OnDisconnected(clientSteamId, DisconnectReason.Disconnected); + break; + + case ESteamNetworkingConnectionState.k_ESteamNetworkingConnectionState_ProblemDetectedLocally: + SteamNetworkingSockets.CloseConnection(callback.m_hConn, 0, "Problem detected", false); + OnDisconnected(clientSteamId, DisconnectReason.TransportError); + break; + + default: + Debug.Log($"{LogName}: {clientSteamId}'s connection state changed - {callback.m_info.m_eState}"); + break; + } + } + + internal void Add(SteamConnection connection) + { + if (!connections.ContainsKey(connection.SteamId)) + { + connections.Add(connection.SteamId, connection); + OnConnected(connection); + } + else + Debug.Log($"{LogName}: Connection from {connection.SteamId} could not be accepted: Already connected"); + } + + private void Accept(HSteamNetConnection connection) + { + EResult result = SteamNetworkingSockets.AcceptConnection(connection); + if (result != EResult.k_EResultOK) + Debug.LogWarning($"{LogName}: Connection could not be accepted: {result}"); + } + + public void Close(Connection connection) + { + if (connection is SteamConnection steamConnection) + { + SteamNetworkingSockets.CloseConnection(steamConnection.SteamNetConnection, 0, "Disconnected by server", false); + connections.Remove(steamConnection.SteamId); + } + } + + public void Poll() + { + foreach (SteamConnection connection in connections.Values) + Receive(connection); + } + + // TODO: disable nagle so this isn't needed + //public void Flush() + //{ + // foreach (SteamConnection connection in connections.Values) + // SteamNetworkingSockets.FlushMessagesOnConnection(connection.SteamNetConnection); + //} + + public void Shutdown() + { + if (connectionStatusChanged != null) + { + connectionStatusChanged.Dispose(); + connectionStatusChanged = null; + } + + foreach (SteamConnection connection in connections.Values) + SteamNetworkingSockets.CloseConnection(connection.SteamNetConnection, 0, "Server stopped", false); + + connections.Clear(); + SteamNetworkingSockets.CloseListenSocket(listenSocket); + } + + protected internal virtual void OnConnected(Connection connection) + { + Connected?.Invoke(this, new ConnectedEventArgs(connection)); + } + + protected override void OnDataReceived(byte[] dataBuffer, int amount, SteamConnection fromConnection) + { + if ((MessageHeader)dataBuffer[0] == MessageHeader.Connect) + { + if (fromConnection.DidReceiveConnect) + return; + + fromConnection.DidReceiveConnect = true; + } + + DataReceived?.Invoke(this, new DataReceivedEventArgs(dataBuffer, amount, fromConnection)); + } + + protected virtual void OnDisconnected(CSteamID steamId, DisconnectReason reason) + { + if (connections.TryGetValue(steamId, out SteamConnection connection)) + { + Disconnected?.Invoke(this, new DisconnectedEventArgs(connection, reason)); + connections.Remove(steamId); + } + } + } +} diff --git a/ServerBrowser/ServerBrowser.cs b/ServerBrowser/ServerBrowser.cs new file mode 100644 index 0000000..f37ebd8 --- /dev/null +++ b/ServerBrowser/ServerBrowser.cs @@ -0,0 +1,493 @@ +using Harmony; +using KCM.Enums; +using KCM.Packets.Handlers; +using Newtonsoft.Json; +using Riptide.Demos.Steam.PlayerHosted; +using Steamworks; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Linq; +using System.Net; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using UnityEngine; +using UnityEngine.Networking; +using UnityEngine.UI; + +namespace KCM +{ + public class ServerBrowser : MonoBehaviour + { + public static GameObject serverBrowserRef = null; + public static Transform serverBrowserContentRef = null; + + public static GameObject serverLobbyRef = null; + public static Transform serverLobbyPlayerRef = null; + public static Transform serverLobbyChatRef = null; + + public static List ServerEntries = new List(); + public static ServerResponse ServerResponse = new ServerResponse(); + + private string databaseId = "6563181402855cfc8b87"; // Replace with your database ID + private string collectionId = "servers"; // Replace with your collection ID + private string projectId = "kcmmasterserver"; // Replace with your project ID + string apiKey = "f80c8f7f5c07a4d4600a7d9954529a8a7897de58c08d9c2b24eaf638dd66e7007917840cfeea5d2673ad397336b9d68ca48375ca6e918c41ddfbdb84a96fa009e9976dacfbaa0a3a8effd79f862f1ea249822e17d26e111c5da48e20ceb0065421fc7fca7e630172a003cc89dd00c5a636b443bc7c8d85149384db9d6d5f6df6"; // Replace with your API key + + private string serverID = string.Empty; + + public static GameObject inst { get; private set; } + public void Awake() + { + inst = serverBrowserRef; + } + + void Start() + { + inst = serverBrowserRef; + StartCoroutine(LobbyHeartbeat()); + } + + public static bool registerServer = false; + int interval = 0; + + IEnumerator LobbyHeartbeat() + { + while (true) + { + string url = $"https://base.ryanpalmer.tech/v1/databases/{databaseId}/collections/{collectionId}/documents"; + + #region "Get Servers (for browser)" + if (serverBrowserRef != null) + { + WebRequest request = WebRequest.Create(url); + request.Method = "GET"; + request.Headers["X-Appwrite-Project"] = projectId; + request.Headers["X-Appwrite-Key"] = apiKey; + + Task task = Task.Run(async () => + { + using (WebResponse response = await request.GetResponseAsync()) + { + using (Stream stream = response.GetResponseStream()) + { + try + { + StreamReader reader = new StreamReader(stream); + string responseText = reader.ReadToEnd(); + + ServerResponse serverResponse = JsonConvert.DeserializeObject(responseText); + + ServerResponse = serverResponse; + + } + catch (Exception ex) + { + Main.helper.Log("----------------------- Main exception -----------------------"); + Main.helper.Log(ex.ToString()); + Main.helper.Log("----------------------- Main message -----------------------"); + Main.helper.Log(ex.Message); + Main.helper.Log("----------------------- Main stacktrace -----------------------"); + Main.helper.Log(ex.StackTrace); + if (ex.InnerException != null) + { + Main.helper.Log("----------------------- Inner exception -----------------------"); + Main.helper.Log(ex.InnerException.ToString()); + Main.helper.Log("----------------------- Inner message -----------------------"); + Main.helper.Log(ex.InnerException.Message); + Main.helper.Log("----------------------- Inner stacktrace -----------------------"); + Main.helper.Log(ex.InnerException.StackTrace); + } + } + } + } + }); + + + yield return new WaitUntil(() => task.IsCompleted); + + DestroyServerEntries(); + + foreach (ServerEntry serverEntry in ServerResponse.Documents) + { + GameObject entry = Instantiate(PrefabManager.serverEntryItemPrefab, serverBrowserContentRef); + var s = entry.AddComponent(); + + s.Name = serverEntry.Name; + s.Host = serverEntry.Host; + s.MaxPlayers = serverEntry.MaxPlayers; + s.Locked = serverEntry.Locked; + s.PlayerCount = serverEntry.PlayerCount; + s.Difficulty = serverEntry.Difficulty; + s.Port = serverEntry.Port; + s.IPAddress = serverEntry.IPAddress; + s.PlayerId = serverEntry.PlayerId; + + ServerEntries.Add(entry); + } + } + #endregion + + #region "Register Server" + if (registerServer) + { + //Main.helper.Log("Register server"); + registerServer = false; + + Task registerTask = Task.Run(() => + { + WebRequest request = WebRequest.Create(url); + request.Method = "POST"; + request.ContentType = "application/json"; + request.Headers["X-Appwrite-Project"] = projectId; + request.Headers["X-Appwrite-Key"] = apiKey; + + serverID = SteamUser.GetSteamID().ToString(); + + string postData = JsonConvert.SerializeObject(new + { + documentId = serverID, + data = new + { + Name = LobbyHandler.ServerSettings.ServerName, + PlayerId = serverID, + Host = KCClient.inst.Name, + PlayerCount = KCServer.server.ClientCount, + MaxPlayers = LobbyHandler.ServerSettings.MaxPlayers, + Difficulty = Enum.GetName(typeof(Difficulty), LobbyHandler.ServerSettings.Difficulty), + Heartbeat = DateTime.UtcNow.ToString("yyyy-MM-ddTHH:mm:ssZ", CultureInfo.InvariantCulture), + //IPAddress = "127.0.0.1", + Port = 7777, + Locked = false + } + }); + + Main.helper.Log(postData); + + using (var streamWriter = new StreamWriter(request.GetRequestStream())) + { + streamWriter.Write(postData); + } + + return request.GetResponse(); + }); + + // Wait until the task is completed + yield return new WaitUntil(() => registerTask.IsCompleted); + + if (registerTask.Exception != null) + { + Main.helper.Log("Register error"); + Main.helper.Log($"Task Exception: {registerTask.Exception}"); + Main.helper.Log($"Task InnerException: {registerTask.Exception.InnerException}"); + using (WebResponse response = registerTask.Result) + { + using (Stream dataStream = response.GetResponseStream()) + { + using (StreamReader reader = new StreamReader(dataStream)) + { + string responseFromServer = reader.ReadToEnd(); + //Main.helper.Log(responseFromServer); + } + } + } + } + else + { + using (WebResponse response = registerTask.Result) + { + using (Stream dataStream = response.GetResponseStream()) + { + using (StreamReader reader = new StreamReader(dataStream)) + { + string responseFromServer = reader.ReadToEnd(); + //Main.helper.Log(responseFromServer); + } + } + } + } + } + #endregion + + #region "Heartbeat" + if (interval >= 8 && KCServer.IsRunning) + { + //Main.helper.Log("Commence heartbeat"); + Task heartbeatTask = Task.Run(() => + { + WebRequest request = WebRequest.Create(url + "/" + serverID); + request.Method = "PATCH"; + request.ContentType = "application/json"; + request.Headers["X-Appwrite-Project"] = projectId; + request.Headers["X-Appwrite-Key"] = apiKey; + + // Create the request body + string postData = JsonConvert.SerializeObject(new + { + data = new + { + Name = LobbyHandler.ServerSettings.ServerName, + Host = KCClient.inst.Name, + PlayerCount = KCServer.server.ClientCount, + MaxPlayers = LobbyHandler.ServerSettings.MaxPlayers, + Difficulty = Enum.GetName(typeof(Difficulty), LobbyHandler.ServerSettings.Difficulty), + Heartbeat = DateTime.UtcNow.ToString("yyyy-MM-ddTHH:mm:ssZ", CultureInfo.InvariantCulture), + //IPAddress = "127.0.0.1", + Locked = LobbyHandler.ServerSettings.Locked + } + }); + + Main.helper.Log(postData); + + using (var streamWriter = new StreamWriter(request.GetRequestStream())) + { + streamWriter.Write(postData); + } + + return request.GetResponse(); + }); + + // Wait until the task is completed + yield return new WaitUntil(() => heartbeatTask.IsCompleted); + + if (heartbeatTask.Exception != null) + { + Main.helper.Log("Heartbeat error"); + Main.helper.Log($"Task Exception: {heartbeatTask.Exception.InnerException}"); + } + else + { + using (WebResponse response = heartbeatTask.Result) + { + using (Stream dataStream = response.GetResponseStream()) + { + using (StreamReader reader = new StreamReader(dataStream)) + { + string responseFromServer = reader.ReadToEnd(); + //Main.helper.Log(responseFromServer); + } + } + } + } + //Main.helper.Log("Master server heartbeat"); + interval = 0; + } + interval += 1; + #endregion + + yield return new WaitForSecondsRealtime(2.0f); + } + } + + public static void DestroyServerEntries() + { + foreach (GameObject entry in ServerEntries) + Destroy(entry); + + ServerEntries.Clear(); + } + + public static Transform KCMUICanvas { get; set; } + + private void SceneLoaded(KCModHelper helper) + { + Main.helper.Log("Serverbrowser scene loaded"); + + + try + { + GameObject kcmUICanvas = Instantiate(Constants.MainMenuUI_T.Find("TopLevelUICanvas").gameObject); + + for (int i = 0; i < kcmUICanvas.transform.childCount; i++) + Destroy(kcmUICanvas.transform.GetChild(i).gameObject); + + kcmUICanvas.name = "KCMUICanvas"; + kcmUICanvas.transform.SetParent(Constants.MainMenuUI_T); + + KCMUICanvas = kcmUICanvas.transform; + + serverBrowserRef = GameObject.Instantiate(PrefabManager.serverBrowserPrefab, KCMUICanvas.transform); + serverBrowserRef.SetActive(false); + serverBrowserContentRef = serverBrowserRef.transform.Find("Container/Scroll View/Viewport/Content"); + + //hides player name prompt + serverBrowserRef.transform.Find("Container/PlayerName").gameObject.SetActive(false); + + + + serverLobbyRef = GameObject.Instantiate(PrefabManager.serverLobbyPrefab, KCMUICanvas.transform); + serverLobbyPlayerRef = serverLobbyRef.transform.Find("Container/PlayerList/Viewport/Content"); + serverLobbyChatRef = serverLobbyRef.transform.Find("Container/PlayerChat/Viewport/Content"); + serverLobbyRef.SetActive(false); + //browser.transform.position = new Vector3(0, 0, 0); + + + var lobbyScript = serverLobbyRef.GetComponent(); + if (lobbyScript == null) + lobbyScript = serverLobbyRef.AddComponent(); + + + Main.helper.Log($"{lobbyScript == null}"); + + + //Create Server + serverBrowserRef.transform.Find("Container/Create").GetComponent