diff --git a/libs/server/Custom/CustomCommandManager.cs b/libs/server/Custom/CustomCommandManager.cs index 580748544b..4ca58d7b9a 100644 --- a/libs/server/Custom/CustomCommandManager.cs +++ b/libs/server/Custom/CustomCommandManager.cs @@ -2,8 +2,8 @@ // Licensed under the MIT license. using System; -using System.Collections.Generic; -using System.Threading; +using System.Collections.Concurrent; +using System.Diagnostics; namespace Garnet.server { @@ -12,147 +12,104 @@ namespace Garnet.server /// public class CustomCommandManager { - internal static readonly ushort StartOffset = (ushort)(RespCommandExtensions.LastValidCommand + 1); - internal static readonly int MaxRegistrations = 512 - StartOffset; // Temporary fix to reduce map sizes - internal static readonly byte TypeIdStartOffset = (byte)(GarnetObjectTypeExtensions.LastObjectType + 1); - internal static readonly int MaxTypeRegistrations = (byte)(GarnetObjectTypeExtensions.FirstSpecialObjectType) - TypeIdStartOffset; - - internal readonly CustomRawStringCommand[] rawStringCommandMap; - internal readonly CustomObjectCommandWrapper[] objectCommandMap; - internal readonly CustomTransaction[] transactionProcMap; - internal readonly CustomProcedureWrapper[] customProcedureMap; - internal int RawStringCommandId = 0; - internal int ObjectTypeId = 0; - internal int TransactionProcId = 0; - internal int CustomProcedureId = 0; - - internal int CustomCommandsInfoCount => CustomCommandsInfo.Count; - internal readonly Dictionary CustomCommandsInfo = new(StringComparer.OrdinalIgnoreCase); - internal readonly Dictionary CustomCommandsDocs = new(StringComparer.OrdinalIgnoreCase); + internal static readonly int MinMapSize = 8; + internal static readonly byte TypeIdStartOffset = byte.MaxValue - (byte)GarnetObjectTypeExtensions.FirstSpecialObjectType; + + private ConcurrentExpandableMap rawStringCommandMap; + private ConcurrentExpandableMap objectCommandMap; + private ConcurrentExpandableMap transactionProcMap; + private ConcurrentExpandableMap customProcedureMap; + + internal int CustomCommandsInfoCount => customCommandsInfo.Count; + internal readonly ConcurrentDictionary customCommandsInfo = new(StringComparer.OrdinalIgnoreCase); + internal readonly ConcurrentDictionary customCommandsDocs = new(StringComparer.OrdinalIgnoreCase); /// /// Create new custom command manager /// public CustomCommandManager() { - rawStringCommandMap = new CustomRawStringCommand[MaxRegistrations]; - objectCommandMap = new CustomObjectCommandWrapper[MaxTypeRegistrations]; - transactionProcMap = new CustomTransaction[MaxRegistrations]; // can increase up to byte.MaxValue - customProcedureMap = new CustomProcedureWrapper[MaxRegistrations]; + rawStringCommandMap = new ConcurrentExpandableMap(MinMapSize, + (ushort)RespCommand.INVALID - 1, + (ushort)RespCommandExtensions.LastValidCommand + 1); + objectCommandMap = new ConcurrentExpandableMap(MinMapSize, + (byte)GarnetObjectTypeExtensions.FirstSpecialObjectType - 1, + (byte)GarnetObjectTypeExtensions.LastObjectType + 1); + transactionProcMap = new ConcurrentExpandableMap(MinMapSize, 0, byte.MaxValue); + customProcedureMap = new ConcurrentExpandableMap(MinMapSize, 0, byte.MaxValue); } internal int Register(string name, CommandType type, CustomRawStringFunctions customFunctions, RespCommandsInfo commandInfo, RespCommandDocs commandDocs, long expirationTicks) { - int id = Interlocked.Increment(ref RawStringCommandId) - 1; - if (id >= MaxRegistrations) + if (!rawStringCommandMap.TryGetNextId(out var cmdId)) throw new Exception("Out of registration space"); - - rawStringCommandMap[id] = new CustomRawStringCommand(name, (ushort)id, type, customFunctions, expirationTicks); - if (commandInfo != null) CustomCommandsInfo.Add(name, commandInfo); - if (commandDocs != null) CustomCommandsDocs.Add(name, commandDocs); - return id; + Debug.Assert(cmdId <= ushort.MaxValue); + var newCmd = new CustomRawStringCommand(name, (ushort)cmdId, type, customFunctions, expirationTicks); + var setSuccessful = rawStringCommandMap.TrySetValue(cmdId, ref newCmd); + Debug.Assert(setSuccessful); + if (commandInfo != null) customCommandsInfo.AddOrUpdate(name, commandInfo, (_, _) => commandInfo); + if (commandDocs != null) customCommandsDocs.AddOrUpdate(name, commandDocs, (_, _) => commandDocs); + return cmdId; } internal int Register(string name, Func proc, RespCommandsInfo commandInfo = null, RespCommandDocs commandDocs = null) { - int id = Interlocked.Increment(ref TransactionProcId) - 1; - if (id >= MaxRegistrations) + if (!transactionProcMap.TryGetNextId(out var cmdId)) throw new Exception("Out of registration space"); - - transactionProcMap[id] = new CustomTransaction(name, (byte)id, proc); - if (commandInfo != null) CustomCommandsInfo.Add(name, commandInfo); - if (commandDocs != null) CustomCommandsDocs.Add(name, commandDocs); - return id; + Debug.Assert(cmdId <= byte.MaxValue); + + var newCmd = new CustomTransaction(name, (byte)cmdId, proc); + var setSuccessful = transactionProcMap.TrySetValue(cmdId, ref newCmd); + Debug.Assert(setSuccessful); + if (commandInfo != null) customCommandsInfo.AddOrUpdate(name, commandInfo, (_, _) => commandInfo); + if (commandDocs != null) customCommandsDocs.AddOrUpdate(name, commandDocs, (_, _) => commandDocs); + return cmdId; } internal int RegisterType(CustomObjectFactory factory) { - for (int i = 0; i < ObjectTypeId; i++) - if (objectCommandMap[i].factory == factory) - throw new Exception($"Type already registered with ID {i}"); - - int type; - do - { - type = Interlocked.Increment(ref ObjectTypeId) - 1; - if (type >= MaxTypeRegistrations) - throw new Exception("Out of registration space"); - } while (objectCommandMap[type] != null); + if (objectCommandMap.TryGetFirstId(c => c.factory == factory, out var dupRegistrationId)) + throw new Exception($"Type already registered with ID {dupRegistrationId}"); - objectCommandMap[type] = new CustomObjectCommandWrapper((byte)type, factory); - - return type; - } - - internal void RegisterType(int objectTypeId, CustomObjectFactory factory) - { - if (objectTypeId >= MaxTypeRegistrations) - throw new Exception("Type is outside registration space"); + if (!objectCommandMap.TryGetNextId(out var cmdId)) + throw new Exception("Out of registration space"); + Debug.Assert(cmdId <= byte.MaxValue); - if (ObjectTypeId <= objectTypeId) ObjectTypeId = objectTypeId + 1; - for (int i = 0; i < ObjectTypeId; i++) - if (objectCommandMap[i].factory == factory) - throw new Exception($"Type already registered with ID {i}"); + var newCmd = new CustomObjectCommandWrapper((byte)cmdId, factory); + var setSuccessful = objectCommandMap.TrySetValue(cmdId, ref newCmd); + Debug.Assert(setSuccessful); - objectCommandMap[objectTypeId] = new CustomObjectCommandWrapper((byte)objectTypeId, factory); + return cmdId; } - internal (int objectTypeId, int subCommand) Register(string name, CommandType commandType, CustomObjectFactory factory, RespCommandsInfo commandInfo, RespCommandDocs commandDocs) + internal (int objectTypeId, int subCommand) Register(string name, CommandType commandType, CustomObjectFactory factory, RespCommandsInfo commandInfo, RespCommandDocs commandDocs, CustomObjectFunctions customObjectFunctions = null) { - int objectTypeId = -1; - for (int i = 0; i < ObjectTypeId; i++) - { - if (objectCommandMap[i].factory == factory) { objectTypeId = i; break; } - } - - if (objectTypeId == -1) + if (!objectCommandMap.TryGetFirstId(c => c.factory == factory, out var typeId)) { - objectTypeId = Interlocked.Increment(ref ObjectTypeId) - 1; - if (objectTypeId >= MaxTypeRegistrations) + if (!objectCommandMap.TryGetNextId(out typeId)) throw new Exception("Out of registration space"); - objectCommandMap[objectTypeId] = new CustomObjectCommandWrapper((byte)objectTypeId, factory); - } - var wrapper = objectCommandMap[objectTypeId]; + Debug.Assert(typeId <= byte.MaxValue); - int subCommand = Interlocked.Increment(ref wrapper.CommandId) - 1; - if (subCommand >= byte.MaxValue) - throw new Exception("Out of registration space"); - wrapper.commandMap[subCommand] = new CustomObjectCommand(name, (byte)objectTypeId, (byte)subCommand, commandType, wrapper.factory); - - if (commandInfo != null) CustomCommandsInfo.Add(name, commandInfo); - if (commandDocs != null) CustomCommandsDocs.Add(name, commandDocs); - - return (objectTypeId, subCommand); - } - - internal (int objectTypeId, int subCommand) Register(string name, CommandType commandType, CustomObjectFactory factory, CustomObjectFunctions customObjectFunctions, RespCommandsInfo commandInfo, RespCommandDocs commandDocs) - { - var objectTypeId = -1; - for (var i = 0; i < ObjectTypeId; i++) - { - if (objectCommandMap[i].factory == factory) { objectTypeId = i; break; } - } - - if (objectTypeId == -1) - { - objectTypeId = Interlocked.Increment(ref ObjectTypeId) - 1; - if (objectTypeId >= MaxTypeRegistrations) - throw new Exception("Out of registration space"); - objectCommandMap[objectTypeId] = new CustomObjectCommandWrapper((byte)objectTypeId, factory); + var newCmd = new CustomObjectCommandWrapper((byte)typeId, factory); + var setSuccessful = objectCommandMap.TrySetValue(typeId, ref newCmd); + Debug.Assert(setSuccessful); } - var wrapper = objectCommandMap[objectTypeId]; - - int subCommand = Interlocked.Increment(ref wrapper.CommandId) - 1; - if (subCommand >= byte.MaxValue) + objectCommandMap.TryGetValue(typeId, out var wrapper); + if (!wrapper.commandMap.TryGetNextId(out var scId)) throw new Exception("Out of registration space"); - wrapper.commandMap[subCommand] = new CustomObjectCommand(name, (byte)objectTypeId, (byte)subCommand, commandType, wrapper.factory, customObjectFunctions); - if (commandInfo != null) CustomCommandsInfo.Add(name, commandInfo); - if (commandDocs != null) CustomCommandsDocs.Add(name, commandDocs); + Debug.Assert(scId <= byte.MaxValue); + var newSubCmd = new CustomObjectCommand(name, (byte)typeId, (byte)scId, commandType, wrapper.factory, + customObjectFunctions); + var scSetSuccessful = wrapper.commandMap.TrySetValue(scId, ref newSubCmd); + Debug.Assert(scSetSuccessful); - return (objectTypeId, subCommand); + if (commandInfo != null) customCommandsInfo.AddOrUpdate(name, commandInfo, (_, _) => commandInfo); + if (commandDocs != null) customCommandsDocs.AddOrUpdate(name, commandDocs, (_, _) => commandDocs); + + return (typeId, scId); } /// @@ -166,80 +123,59 @@ internal void RegisterType(int objectTypeId, CustomObjectFactory factory) /// internal int Register(string name, Func customProcedure, RespCommandsInfo commandInfo = null, RespCommandDocs commandDocs = null) { - int id = Interlocked.Increment(ref CustomProcedureId) - 1; - if (id >= MaxRegistrations) + if (!customProcedureMap.TryGetNextId(out var cmdId)) throw new Exception("Out of registration space"); - customProcedureMap[id] = new CustomProcedureWrapper(name, (byte)id, customProcedure, this); - if (commandInfo != null) CustomCommandsInfo.Add(name, commandInfo); - if (commandDocs != null) CustomCommandsDocs.Add(name, commandDocs); - return id; + Debug.Assert(cmdId <= byte.MaxValue); + + var newCmd = new CustomProcedureWrapper(name, (byte)cmdId, customProcedure, this); + var setSuccessful = customProcedureMap.TrySetValue(cmdId, ref newCmd); + Debug.Assert(setSuccessful); + + if (commandInfo != null) customCommandsInfo.AddOrUpdate(name, commandInfo, (_, _) => commandInfo); + if (commandDocs != null) customCommandsDocs.AddOrUpdate(name, commandDocs, (_, _) => commandDocs); + return cmdId; } - internal bool Match(ReadOnlySpan command, out CustomRawStringCommand cmd) + internal bool TryGetCustomProcedure(int id, out CustomProcedureWrapper value) + => customProcedureMap.TryGetValue(id, out value); + + internal bool TryGetCustomTransactionProcedure(int id, out CustomTransaction value) + => transactionProcMap.TryGetValue(id, out value); + + internal bool TryGetCustomCommand(int id, out CustomRawStringCommand value) + => rawStringCommandMap.TryGetValue(id, out value); + + internal bool TryGetCustomObjectCommand(int id, out CustomObjectCommandWrapper value) + => objectCommandMap.TryGetValue(id, out value); + + internal bool TryGetCustomObjectSubCommand(int id, int subId, out CustomObjectCommand value) { - for (int i = 0; i < RawStringCommandId; i++) - { - cmd = rawStringCommandMap[i]; - if (cmd != null && command.SequenceEqual(new ReadOnlySpan(cmd.name))) - return true; - } - cmd = null; - return false; + value = default; + return objectCommandMap.TryGetValue(id, out var wrapper) && + wrapper.commandMap.TryGetValue(subId, out value); } + internal bool Match(ReadOnlySpan command, out CustomRawStringCommand cmd) + => rawStringCommandMap.MatchCommandSafe(command, out cmd); + internal bool Match(ReadOnlySpan command, out CustomTransaction cmd) - { - for (int i = 0; i < TransactionProcId; i++) - { - cmd = transactionProcMap[i]; - if (cmd != null && command.SequenceEqual(new ReadOnlySpan(cmd.name))) - return true; - } - cmd = null; - return false; - } + => transactionProcMap.MatchCommandSafe(command, out cmd); internal bool Match(ReadOnlySpan command, out CustomObjectCommand cmd) - { - for (int i = 0; i < ObjectTypeId; i++) - { - var wrapper = objectCommandMap[i]; - if (wrapper != null) - { - for (int j = 0; j < wrapper.CommandId; j++) - { - cmd = wrapper.commandMap[j]; - if (cmd != null && command.SequenceEqual(new ReadOnlySpan(cmd.name))) - return true; - } - } - else break; - } - cmd = null; - return false; - } + => objectCommandMap.MatchSubCommandSafe(command, out cmd); internal bool Match(ReadOnlySpan command, out CustomProcedureWrapper cmd) - { - for (int i = 0; i < CustomProcedureId; i++) - { - cmd = customProcedureMap[i]; - if (cmd != null && command.SequenceEqual(new ReadOnlySpan(cmd.Name))) - return true; - } - cmd = null; - return false; - } + => customProcedureMap.MatchCommandSafe(command, out cmd); internal bool TryGetCustomCommandInfo(string cmdName, out RespCommandsInfo respCommandsInfo) { - return this.CustomCommandsInfo.TryGetValue(cmdName, out respCommandsInfo); + return this.customCommandsInfo.TryGetValue(cmdName, out respCommandsInfo); } internal bool TryGetCustomCommandDocs(string cmdName, out RespCommandDocs respCommandsDocs) { - return this.CustomCommandsDocs.TryGetValue(cmdName, out respCommandsDocs); + return this.customCommandsDocs.TryGetValue(cmdName, out respCommandsDocs); } } } \ No newline at end of file diff --git a/libs/server/Custom/CustomCommandManagerSession.cs b/libs/server/Custom/CustomCommandManagerSession.cs index 8cf7e4ba1f..5664a561ac 100644 --- a/libs/server/Custom/CustomCommandManagerSession.cs +++ b/libs/server/Custom/CustomCommandManagerSession.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +using System.Diagnostics; using Garnet.common; namespace Garnet.server @@ -13,53 +14,81 @@ internal sealed class CustomCommandManagerSession readonly CustomCommandManager customCommandManager; // These session specific arrays are indexed by the same ID as the arrays in CustomCommandManager - readonly (CustomTransactionProcedure, int)[] sessionTransactionProcMap; - readonly CustomProcedure[] sessionCustomProcMap; - + ExpandableMap sessionTransactionProcMap; + ExpandableMap sessionCustomProcMap; public CustomCommandManagerSession(CustomCommandManager customCommandManager) { this.customCommandManager = customCommandManager; - sessionTransactionProcMap = new (CustomTransactionProcedure, int)[CustomCommandManager.MaxRegistrations]; - sessionCustomProcMap = new CustomProcedure[CustomCommandManager.MaxRegistrations]; + sessionTransactionProcMap = new ExpandableMap(CustomCommandManager.MinMapSize, 0, byte.MaxValue); + sessionCustomProcMap = new ExpandableMap(CustomCommandManager.MinMapSize, 0, byte.MaxValue); } public CustomProcedure GetCustomProcedure(int id, RespServerSession respServerSession) { - if (sessionCustomProcMap[id] == null) + if (!sessionCustomProcMap.TryGetValue(id, out var customProc)) { - var entry = customCommandManager.customProcedureMap[id] ?? throw new GarnetException($"Custom procedure {id} not found"); - sessionCustomProcMap[id] = entry.CustomProcedureFactory(); - sessionCustomProcMap[id].respServerSession = respServerSession; + if (!customCommandManager.TryGetCustomProcedure(id, out var entry)) + throw new GarnetException($"Custom procedure {id} not found"); + + customProc = entry.CustomProcedureFactory(); + customProc.respServerSession = respServerSession; + var setSuccessful = sessionCustomProcMap.TrySetValue(id, ref customProc); + Debug.Assert(setSuccessful); } - return sessionCustomProcMap[id]; + return customProc; } - public (CustomTransactionProcedure, int) GetCustomTransactionProcedure(int id, RespServerSession respServerSession, TransactionManager txnManager, ScratchBufferManager scratchBufferManager) + public CustomTransactionProcedure GetCustomTransactionProcedure(int id, RespServerSession respServerSession, TransactionManager txnManager, ScratchBufferManager scratchBufferManager, out int arity) { - if (sessionTransactionProcMap[id].Item1 == null) + if (sessionTransactionProcMap.Exists(id)) { - var entry = customCommandManager.transactionProcMap[id] ?? throw new GarnetException($"Transaction procedure {id} not found"); - _ = customCommandManager.CustomCommandsInfo.TryGetValue(entry.NameStr, out var cmdInfo); - return GetCustomTransactionProcedure(entry, respServerSession, txnManager, scratchBufferManager, cmdInfo?.Arity ?? 0); + ref var customTranProc = ref sessionTransactionProcMap.GetValueByRef(id); + if (customTranProc.Procedure != null) + { + arity = customTranProc.Arity; + return customTranProc.Procedure; + } } - return sessionTransactionProcMap[id]; + + if (!customCommandManager.TryGetCustomTransactionProcedure(id, out var entry)) + throw new GarnetException($"Transaction procedure {id} not found"); + _ = customCommandManager.customCommandsInfo.TryGetValue(entry.NameStr, out var cmdInfo); + arity = cmdInfo?.Arity ?? 0; + return GetCustomTransactionProcedureAndSetArity(entry, respServerSession, txnManager, scratchBufferManager, cmdInfo?.Arity ?? 0); } - public (CustomTransactionProcedure, int) GetCustomTransactionProcedure(CustomTransaction entry, RespServerSession respServerSession, TransactionManager txnManager, ScratchBufferManager scratchBufferManager, int arity) + private CustomTransactionProcedure GetCustomTransactionProcedureAndSetArity(CustomTransaction entry, RespServerSession respServerSession, TransactionManager txnManager, ScratchBufferManager scratchBufferManager, int arity) { int id = entry.id; - if (sessionTransactionProcMap[id].Item1 == null) + + var customTranProc = new CustomTransactionProcedureWithArity(entry.proc(), arity) { - sessionTransactionProcMap[id].Item1 = entry.proc(); - sessionTransactionProcMap[id].Item2 = arity; + Procedure = + { + txnManager = txnManager, + scratchBufferManager = scratchBufferManager, + respServerSession = respServerSession + } + }; + var setSuccessful = sessionTransactionProcMap.TrySetValue(id, ref customTranProc); + Debug.Assert(setSuccessful); + + return customTranProc.Procedure; + } + + private struct CustomTransactionProcedureWithArity + { + public CustomTransactionProcedure Procedure { get; } - sessionTransactionProcMap[id].Item1.txnManager = txnManager; - sessionTransactionProcMap[id].Item1.scratchBufferManager = scratchBufferManager; - sessionTransactionProcMap[id].Item1.respServerSession = respServerSession; + public int Arity { get; } + + public CustomTransactionProcedureWithArity(CustomTransactionProcedure procedure, int arity) + { + this.Procedure = procedure; + this.Arity = arity; } - return sessionTransactionProcMap[id]; } } } \ No newline at end of file diff --git a/libs/server/Custom/CustomCommandRegistration.cs b/libs/server/Custom/CustomCommandRegistration.cs index ddf7f830a7..a03f4421df 100644 --- a/libs/server/Custom/CustomCommandRegistration.cs +++ b/libs/server/Custom/CustomCommandRegistration.cs @@ -233,9 +233,9 @@ public override void Register(CustomCommandManager customCommandManager) RegisterArgs.Name, RegisterArgs.CommandType, factory, - RegisterArgs.ObjectCommand, RegisterArgs.CommandInfo, - RegisterArgs.CommandDocs); + RegisterArgs.CommandDocs, + RegisterArgs.ObjectCommand); } } diff --git a/libs/server/Custom/CustomObjectCommand.cs b/libs/server/Custom/CustomObjectCommand.cs index 96e7d168da..0f7ec804a5 100644 --- a/libs/server/Custom/CustomObjectCommand.cs +++ b/libs/server/Custom/CustomObjectCommand.cs @@ -3,10 +3,11 @@ namespace Garnet.server { - public class CustomObjectCommand + public class CustomObjectCommand : ICustomCommand { + public byte[] Name { get; } + public readonly string NameStr; - public readonly byte[] name; public readonly byte id; public readonly byte subid; public readonly CommandType type; @@ -16,14 +17,12 @@ public class CustomObjectCommand internal CustomObjectCommand(string name, byte id, byte subid, CommandType type, CustomObjectFactory factory, CustomObjectFunctions functions = null) { NameStr = name.ToUpperInvariant(); - this.name = System.Text.Encoding.ASCII.GetBytes(NameStr); + this.Name = System.Text.Encoding.ASCII.GetBytes(NameStr); this.id = id; this.subid = subid; this.type = type; this.factory = factory; this.functions = functions; } - - internal GarnetObjectType GetObjectType() => (GarnetObjectType)(id + CustomCommandManager.TypeIdStartOffset); } } \ No newline at end of file diff --git a/libs/server/Custom/CustomObjectCommandWrapper.cs b/libs/server/Custom/CustomObjectCommandWrapper.cs index 57b8ce4194..5c5a0d5ce8 100644 --- a/libs/server/Custom/CustomObjectCommandWrapper.cs +++ b/libs/server/Custom/CustomObjectCommandWrapper.cs @@ -8,16 +8,18 @@ namespace Garnet.server /// class CustomObjectCommandWrapper { + static readonly int MinMapSize = 8; + static readonly byte MaxSubId = 31; // RespInputHeader uses the 3 MSBs of SubId, so SubId must fit in the 5 LSBs + public readonly byte id; public readonly CustomObjectFactory factory; - public int CommandId = 0; - public readonly CustomObjectCommand[] commandMap; + public ConcurrentExpandableMap commandMap; public CustomObjectCommandWrapper(byte id, CustomObjectFactory functions) { this.id = id; this.factory = functions; - this.commandMap = new CustomObjectCommand[byte.MaxValue]; + this.commandMap = new ConcurrentExpandableMap(MinMapSize, 0, MaxSubId); } } } \ No newline at end of file diff --git a/libs/server/Custom/CustomProcedureWrapper.cs b/libs/server/Custom/CustomProcedureWrapper.cs index fa7b1e2349..aac96b0b93 100644 --- a/libs/server/Custom/CustomProcedureWrapper.cs +++ b/libs/server/Custom/CustomProcedureWrapper.cs @@ -22,10 +22,11 @@ public abstract bool Execute(TGarnetApi garnetApi, ref CustomProcedu where TGarnetApi : IGarnetApi; } - class CustomProcedureWrapper + class CustomProcedureWrapper : ICustomCommand { + public byte[] Name { get; } + public readonly string NameStr; - public readonly byte[] Name; public readonly byte Id; public readonly Func CustomProcedureFactory; diff --git a/libs/server/Custom/CustomRawStringCommand.cs b/libs/server/Custom/CustomRawStringCommand.cs index 0959cab9f1..1dec27cf9d 100644 --- a/libs/server/Custom/CustomRawStringCommand.cs +++ b/libs/server/Custom/CustomRawStringCommand.cs @@ -3,10 +3,11 @@ namespace Garnet.server { - public class CustomRawStringCommand + public class CustomRawStringCommand : ICustomCommand { + public byte[] Name { get; } + public readonly string NameStr; - public readonly byte[] name; public readonly ushort id; public readonly CommandType type; public readonly CustomRawStringFunctions functions; @@ -15,13 +16,11 @@ public class CustomRawStringCommand internal CustomRawStringCommand(string name, ushort id, CommandType type, CustomRawStringFunctions functions, long expirationTicks) { NameStr = name.ToUpperInvariant(); - this.name = System.Text.Encoding.ASCII.GetBytes(NameStr); + this.Name = System.Text.Encoding.ASCII.GetBytes(NameStr); this.id = id; this.type = type; this.functions = functions; this.expirationTicks = expirationTicks; } - - internal RespCommand GetRespCommand() => (RespCommand)(id + CustomCommandManager.StartOffset); } } \ No newline at end of file diff --git a/libs/server/Custom/CustomRespCommands.cs b/libs/server/Custom/CustomRespCommands.cs index f8e7e5f4b1..1c1362923f 100644 --- a/libs/server/Custom/CustomRespCommands.cs +++ b/libs/server/Custom/CustomRespCommands.cs @@ -52,7 +52,7 @@ private bool TryTransactionProc(byte id, CustomTransactionProcedure proc, int st public bool RunTransactionProc(byte id, ref CustomProcedureInput procInput, ref MemoryResult output) { var proc = customCommandManagerSession - .GetCustomTransactionProcedure(id, this, txnManager, scratchBufferManager).Item1; + .GetCustomTransactionProcedure(id, this, txnManager, scratchBufferManager, out _); return txnManager.RunTransactionProc(id, ref procInput, proc, ref output); } @@ -226,7 +226,7 @@ public bool InvokeCustomRawStringCommand(ref TGarnetApi storageApi, var sbKey = key.SpanByte; var inputArg = customCommand.expirationTicks > 0 ? DateTimeOffset.UtcNow.Ticks + customCommand.expirationTicks : customCommand.expirationTicks; customCommandParseState.InitializeWithArguments(args); - var rawStringInput = new RawStringInput(customCommand.GetRespCommand(), ref customCommandParseState, arg1: inputArg); + var rawStringInput = new RawStringInput((RespCommand)customCommand.id, ref customCommandParseState, arg1: inputArg); var _output = new SpanByteAndMemory(null); GarnetStatus status; @@ -290,7 +290,7 @@ public bool InvokeCustomObjectCommand(ref TGarnetApi storageApi, Cus var keyBytes = key.ToArray(); // Prepare input - var header = new RespInputHeader(customObjCommand.GetObjectType()) { SubId = customObjCommand.subid }; + var header = new RespInputHeader((GarnetObjectType)customObjCommand.id) { SubId = customObjCommand.subid }; customCommandParseState.InitializeWithArguments(args); var input = new ObjectInput(header, ref customCommandParseState); diff --git a/libs/server/Custom/CustomTransaction.cs b/libs/server/Custom/CustomTransaction.cs index 7e42170444..0a7a851a23 100644 --- a/libs/server/Custom/CustomTransaction.cs +++ b/libs/server/Custom/CustomTransaction.cs @@ -6,10 +6,11 @@ namespace Garnet.server { - class CustomTransaction + class CustomTransaction : ICustomCommand { + public byte[] Name { get; } + public readonly string NameStr; - public readonly byte[] name; public readonly byte id; public readonly Func proc; @@ -18,7 +19,7 @@ internal CustomTransaction(string name, byte id, Func + /// This interface describes an API for a map of items of type T whose keys are a specified range of IDs (can be descending / ascending) + /// The size of the underlying array containing the items doubles in size as needed. + /// + /// + internal interface IExpandableMap + { + /// + /// Checks if ID is mapped to a value in underlying array + /// + /// Item ID + /// True if ID exists + bool Exists(int id); + + /// + /// Try to get item by ID + /// + /// Item ID + /// Item value + /// True if item found + bool TryGetValue(int id, out T value); + + /// + /// Try to get item by ref by ID + /// + /// Item ID + /// Item value + ref T GetValueByRef(int id); + + /// + /// Try to set item by ID + /// + /// Item ID + /// Item value + /// True if actual size of map should be updated (true by default) + /// True if assignment succeeded + bool TrySetValue(int id, ref T value, bool updateSize = true); + + /// + /// Get next item ID for assignment + /// + /// Item ID + /// True if item ID available + bool TryGetNextId(out int id); + + /// + /// Find first ID in map of item that fulfills specified predicate + /// + /// Predicate + /// ID if found, otherwise -1 + /// True if ID found + bool TryGetFirstId(Func predicate, out int id); + } + + /// + /// This struct defines a map of items of type T whose keys are a specified range of IDs (can be descending / ascending) + /// The size of the underlying array containing the items doubles in size as needed. + /// This struct is not thread-safe, for a thread-safe option see ConcurrentExpandableMap. + /// + /// Type of item to store + internal struct ExpandableMap : IExpandableMap + { + /// + /// The underlying array containing the items + /// + internal T[] Map { get; private set; } + + /// + /// The actual size of the map + /// i.e. the max index of an inserted item + 1 (not the size of the underlying array) + /// + internal int ActualSize { get; private set; } + + // The last requested index for assignment + int currIndex = -1; + // Initial array size + readonly int minSize; + // Value of min item ID + readonly int minId; + // Value of max item ID + readonly int maxSize; + // True if item IDs are in descending order + readonly bool descIds; + + /// + /// Creates a new instance of ExpandableMap + /// + /// Initial size of underlying array + /// The minimal item ID value + /// The maximal item ID value (can be smaller than minId for descending order of IDs) + public ExpandableMap(int minSize, int minId, int maxId) + { + this.Map = null; + this.minSize = minSize; + this.minId = minId; + this.maxSize = Math.Abs(maxId - minId) + 1; + this.descIds = minId > maxId; + } + + /// + public bool TryGetValue(int id, out T value) + { + value = default; + var idx = GetIndexFromId(id); + if (idx < 0 || idx >= ActualSize) + return false; + + value = Map[idx]; + return true; + } + + /// + public bool Exists(int id) + { + var idx = GetIndexFromId(id); + return idx >= 0 && idx < ActualSize; + } + + /// + public ref T GetValueByRef(int id) + { + var idx = GetIndexFromId(id); + if (idx < 0 || idx >= ActualSize) + throw new ArgumentOutOfRangeException(nameof(idx)); + + return ref Map[idx]; + } + + /// + public bool TrySetValue(int id, ref T value, bool updateSize = true) => + TrySetValue(id, ref value, false, updateSize); + + /// + public bool TryGetNextId(out int id) + { + id = -1; + var nextIdx = ++currIndex; + + if (nextIdx >= maxSize) + return false; + id = GetIdFromIndex(nextIdx); + + return true; + } + + /// + public bool TryGetFirstId(Func predicate, out int id) + { + id = -1; + for (var i = 0; i < ActualSize; i++) + { + if (predicate(Map[i])) + { + id = GetIdFromIndex(i); + return true; + } + } + + return false; + } + + /// + /// Get next item ID for assignment with atomic incrementation of underlying index + /// + /// Item ID + /// True if item ID available + public bool TryGetNextIdSafe(out int id) + { + id = -1; + var nextIdx = Interlocked.Increment(ref currIndex); + + if (nextIdx >= maxSize) + return false; + id = GetIdFromIndex(nextIdx); + + return true; + } + + /// + /// Try to update the actual size of the map based on the inserted item ID + /// + /// The inserted item ID + /// True if should not do actual update + /// True if actual size should be updated (or was updated if noUpdate is false) + internal bool TryUpdateSize(int id, bool noUpdate = false) + { + var idx = GetIndexFromId(id); + + // Should not update the size if the index is out of bounds + // or if index is smaller than the current actual size + if (idx < 0 || idx < ActualSize || idx >= maxSize) return false; + + if (!noUpdate) + ActualSize = idx + 1; + + return true; + } + + /// + /// Try to set item by ID + /// + /// Item ID + /// Item value + /// True if should not attempt to expand the underlying array + /// True if should update actual size of the map + /// True if assignment succeeded + internal bool TrySetValue(int id, ref T value, bool noExpansion, bool updateSize) + { + var idx = GetIndexFromId(id); + if (idx < 0 || idx >= maxSize) return false; + + // If index within array bounds, set item + if (Map != null && idx < Map.Length) + { + Map[idx] = value; + if (updateSize) TryUpdateSize(id); + return true; + } + + if (noExpansion) return false; + + // Double new array size until item can fit + var newSize = Map != null ? Math.Max(Map.Length, minSize) : minSize; + while (idx >= newSize) + { + newSize = Math.Min(maxSize, newSize * 2); + } + + // Create new array, copy existing items and set new item + var newMap = new T[newSize]; + if (Map != null) + { + Array.Copy(Map, newMap, Map.Length); + } + + Map = newMap; + Map[idx] = value; + if (updateSize) TryUpdateSize(id); + return true; + } + + /// + /// Maps map index to item ID + /// + /// Map index + /// Item ID + private int GetIdFromIndex(int index) => descIds ? minId - index : index; + + /// + /// Maps an item ID to a map index + /// + /// Item ID + /// Map index + private int GetIndexFromId(int id) => descIds ? minId - id : id; + } + + /// + /// This struct defines a map of items of type T whose keys are a specified range of IDs (can be descending / ascending) + /// The size of the underlying array containing the items doubles in size as needed + /// This struct is thread-safe with regard to the underlying array pointer. + /// + /// Type of item to store + internal struct ConcurrentExpandableMap : IExpandableMap + { + /// + /// Reader-writer lock for the underlying item array + /// + internal SingleWriterMultiReaderLock eMapLock = new(); + + /// + /// The underlying non-concurrent ExpandableMap (should be accessed using the eMapLock) + /// + internal ExpandableMap eMapUnsafe; + + /// + /// Creates a new instance of ConcurrentExpandableMap + /// + /// Initial size of underlying array + /// The minimal item ID value + /// The maximal item ID value (can be smaller than minId for descending order of IDs) + public ConcurrentExpandableMap(int minSize, int minId, int maxId) + { + this.eMapUnsafe = new ExpandableMap(minSize, minId, maxId); + } + + /// + public bool TryGetValue(int id, out T value) + { + value = default; + eMapLock.ReadLock(); + try + { + return eMapUnsafe.TryGetValue(id, out value); + } + finally + { + eMapLock.ReadUnlock(); + } + } + + /// + public bool Exists(int id) + { + eMapLock.ReadLock(); + try + { + return eMapUnsafe.Exists(id); + } + finally + { + eMapLock.ReadUnlock(); + } + } + + /// + public ref T GetValueByRef(int id) + { + eMapLock.ReadLock(); + try + { + return ref eMapUnsafe.GetValueByRef(id); + } + finally + { + eMapLock.ReadUnlock(); + } + } + + /// + public bool TrySetValue(int id, ref T value, bool updateSize = true) + { + var shouldUpdateSize = false; + + // Try to perform set without taking a write lock first + eMapLock.ReadLock(); + try + { + // Try to set value without expanding map + if (eMapUnsafe.TrySetValue(id, ref value, true, false)) + { + // Check if map size should be updated + if (!updateSize || !eMapUnsafe.TryUpdateSize(id, true)) + return true; + shouldUpdateSize = true; + } + } + finally + { + eMapLock.ReadUnlock(); + } + + eMapLock.WriteLock(); + try + { + // Value already set, just update map size + if (shouldUpdateSize) + { + eMapUnsafe.TryUpdateSize(id); + return true; + } + + // Try to set value with expanding the map, if needed + return eMapUnsafe.TrySetValue(id, ref value, false, true); + } + finally + { + eMapLock.WriteUnlock(); + } + } + + /// + public bool TryGetNextId(out int id) + { + return eMapUnsafe.TryGetNextIdSafe(out id); + } + + /// + public bool TryGetFirstId(Func predicate, out int id) + { + id = -1; + eMapLock.ReadLock(); + try + { + return eMapUnsafe.TryGetFirstId(predicate, out id); + } + finally + { + eMapLock.ReadUnlock(); + } + } + } + + /// + /// Extension methods for ConcurrentExpandableMap + /// + internal static class ConcurrentExpandableMapExtensions + { + /// + /// Match command name with existing commands in map and return first matching instance + /// + /// Type of command + /// Current instance of ConcurrentExpandableMap + /// Command name to match + /// Value of command found + /// True if command found + internal static bool MatchCommandSafe(this ConcurrentExpandableMap eMap, ReadOnlySpan cmd, out T value) + where T : ICustomCommand + { + value = default; + eMap.eMapLock.ReadLock(); + try + { + for (var i = 0; i < eMap.eMapUnsafe.ActualSize; i++) + { + var currCmd = eMap.eMapUnsafe.Map[i]; + if (currCmd != null && cmd.SequenceEqual(new ReadOnlySpan(currCmd.Name))) + { + value = currCmd; + return true; + } + } + } + finally + { + eMap.eMapLock.ReadUnlock(); + } + + return false; + } + + /// + /// Match sub-command name with existing sub-commands in map and return first matching instance + /// + /// Type of command + /// Current instance of ConcurrentExpandableMap + /// Sub-command name to match + /// Value of sub-command found + /// + internal static bool MatchSubCommandSafe(this ConcurrentExpandableMap eMap, ReadOnlySpan cmd, out CustomObjectCommand value) + where T : CustomObjectCommandWrapper + { + value = default; + eMap.eMapLock.ReadLock(); + try + { + for (var i = 0; i < eMap.eMapUnsafe.ActualSize; i++) + { + if (eMap.eMapUnsafe.Map[i] != null && eMap.eMapUnsafe.Map[i].commandMap.MatchCommandSafe(cmd, out value)) + return true; + } + } + finally + { + eMap.eMapLock.ReadUnlock(); + } + + return false; + } + } +} \ No newline at end of file diff --git a/libs/server/Custom/ICustomCommand.cs b/libs/server/Custom/ICustomCommand.cs new file mode 100644 index 0000000000..7a0a21a4f5 --- /dev/null +++ b/libs/server/Custom/ICustomCommand.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +namespace Garnet.server +{ + /// + /// Interface for custom commands + /// + interface ICustomCommand + { + /// + /// Name of command + /// + byte[] Name { get; } + } +} \ No newline at end of file diff --git a/libs/server/Module/ModuleRegistrar.cs b/libs/server/Module/ModuleRegistrar.cs index 891b2df33d..623384b96c 100644 --- a/libs/server/Module/ModuleRegistrar.cs +++ b/libs/server/Module/ModuleRegistrar.cs @@ -152,7 +152,7 @@ public ModuleActionStatus RegisterCommand(string name, CustomObjectFactory facto if (string.IsNullOrEmpty(name) || factory == null || command == null) return ModuleActionStatus.InvalidRegistrationInfo; - customCommandManager.Register(name, type, factory, command, commandInfo, commandDocs); + customCommandManager.Register(name, type, factory, commandInfo, commandDocs, command); return ModuleActionStatus.Success; } diff --git a/libs/server/Objects/Types/GarnetObjectSerializer.cs b/libs/server/Objects/Types/GarnetObjectSerializer.cs index 2563371269..1ebdc7b20d 100644 --- a/libs/server/Objects/Types/GarnetObjectSerializer.cs +++ b/libs/server/Objects/Types/GarnetObjectSerializer.cs @@ -13,14 +13,14 @@ namespace Garnet.server /// public sealed class GarnetObjectSerializer : BinaryObjectSerializer { - readonly CustomObjectCommandWrapper[] customCommands; + readonly CustomCommandManager customCommandManager; /// /// Constructor /// public GarnetObjectSerializer(CustomCommandManager customCommandManager) { - this.customCommands = customCommandManager.objectCommandMap; + this.customCommandManager = customCommandManager; } /// @@ -58,8 +58,9 @@ private IGarnetObject DeserializeInternal(BinaryReader binaryReader) private IGarnetObject CustomDeserialize(byte type, BinaryReader binaryReader) { - if (type < CustomCommandManager.TypeIdStartOffset) return null; - return customCommands[type - CustomCommandManager.TypeIdStartOffset].factory.Deserialize(type, binaryReader); + if (type < CustomCommandManager.TypeIdStartOffset || + !customCommandManager.TryGetCustomObjectCommand(type, out var cmd)) return null; + return cmd.factory.Deserialize(type, binaryReader); } /// diff --git a/libs/server/Objects/Types/GarnetObjectType.cs b/libs/server/Objects/Types/GarnetObjectType.cs index 69ad2e793b..ddbc40f8f4 100644 --- a/libs/server/Objects/Types/GarnetObjectType.cs +++ b/libs/server/Objects/Types/GarnetObjectType.cs @@ -33,6 +33,11 @@ public enum GarnetObjectType : byte // Any new special type inserted here should update GarnetObjectTypeExtensions.FirstSpecialObjectType + /// + /// Special type indicating PEXPIRE command + /// + PExpire = 0xf8, + /// /// Special type indicating EXPIRETIME command /// @@ -44,40 +49,35 @@ public enum GarnetObjectType : byte PExpireTime = 0xfa, /// - /// Special type indicating PERSIST command - /// - Persist = 0xfd, - - /// - /// Special type indicating TTL command + /// Indicating a Custom Object command /// - Ttl = 0xfe, + All = 0xfb, /// - /// Special type indicating EXPIRE command + /// Special type indicating PTTL command /// - Expire = 0xff, + PTtl = 0xfc, /// - /// Special type indicating PEXPIRE command + /// Special type indicating PERSIST command /// - PExpire = 0xf8, + Persist = 0xfd, /// - /// Special type indicating PTTL command + /// Special type indicating TTL command /// - PTtl = 0xfc, + Ttl = 0xfe, /// - /// Indicating a Custom Object command + /// Special type indicating EXPIRE command /// - All = 0xfb + Expire = 0xff, } public static class GarnetObjectTypeExtensions { internal const GarnetObjectType LastObjectType = GarnetObjectType.Set; - internal const GarnetObjectType FirstSpecialObjectType = GarnetObjectType.ExpireTime; + internal const GarnetObjectType FirstSpecialObjectType = GarnetObjectType.PExpire; } } \ No newline at end of file diff --git a/libs/server/Resp/BasicCommands.cs b/libs/server/Resp/BasicCommands.cs index cbb2008504..860dd691b2 100644 --- a/libs/server/Resp/BasicCommands.cs +++ b/libs/server/Resp/BasicCommands.cs @@ -988,7 +988,7 @@ private void WriteCOMMANDResponse() var resultSb = new StringBuilder(); var cmdCount = 0; - foreach (var customCmd in storeWrapper.customCommandManager.CustomCommandsInfo.Values) + foreach (var customCmd in storeWrapper.customCommandManager.customCommandsInfo.Values) { cmdCount++; resultSb.Append(customCmd.RespFormat); @@ -1082,7 +1082,7 @@ private bool NetworkCOMMAND_DOCS() resultSb.Append(cmdDocs.RespFormat); } - foreach (var customCmd in storeWrapper.customCommandManager.CustomCommandsDocs.Values) + foreach (var customCmd in storeWrapper.customCommandManager.customCommandsDocs.Values) { docsCount++; resultSb.Append(customCmd.RespFormat); diff --git a/libs/server/Resp/RespServerSession.cs b/libs/server/Resp/RespServerSession.cs index 9ccd49a9f9..f924e21fd7 100644 --- a/libs/server/Resp/RespServerSession.cs +++ b/libs/server/Resp/RespServerSession.cs @@ -776,8 +776,8 @@ bool NetworkCustomTxn() // Perform the operation TryTransactionProc(currentCustomTransaction.id, customCommandManagerSession - .GetCustomTransactionProcedure(currentCustomTransaction.id, this, txnManager, scratchBufferManager) - .Item1); + .GetCustomTransactionProcedure(currentCustomTransaction.id, this, txnManager, + scratchBufferManager, out _)); currentCustomTransaction = null; return true; } @@ -816,7 +816,7 @@ private bool NetworkCustomRawStringCmd(ref TGarnetApi storageApi) } // Perform the operation - TryCustomRawStringCommand(currentCustomRawStringCommand.GetRespCommand(), + TryCustomRawStringCommand((RespCommand)currentCustomRawStringCommand.id, currentCustomRawStringCommand.expirationTicks, currentCustomRawStringCommand.type, ref storageApi); currentCustomRawStringCommand = null; return true; @@ -832,7 +832,7 @@ bool NetworkCustomObjCmd(ref TGarnetApi storageApi) } // Perform the operation - TryCustomObjectCommand(currentCustomObjectCommand.GetObjectType(), currentCustomObjectCommand.subid, + TryCustomObjectCommand((GarnetObjectType)currentCustomObjectCommand.id, currentCustomObjectCommand.subid, currentCustomObjectCommand.type, ref storageApi); currentCustomObjectCommand = null; return true; @@ -840,7 +840,7 @@ bool NetworkCustomObjCmd(ref TGarnetApi storageApi) private bool IsCommandArityValid(string cmdName, int count) { - if (storeWrapper.customCommandManager.CustomCommandsInfo.TryGetValue(cmdName, out var cmdInfo)) + if (storeWrapper.customCommandManager.customCommandsInfo.TryGetValue(cmdName, out var cmdInfo)) { Debug.Assert(cmdInfo != null, "Custom command info should not be null"); if ((cmdInfo.Arity > 0 && count != cmdInfo.Arity - 1) || diff --git a/libs/server/Servers/RegisterApi.cs b/libs/server/Servers/RegisterApi.cs index ed2995280b..fd90cf9073 100644 --- a/libs/server/Servers/RegisterApi.cs +++ b/libs/server/Servers/RegisterApi.cs @@ -57,14 +57,6 @@ public int NewTransactionProc(string name, Func proc public int NewType(CustomObjectFactory factory) => provider.StoreWrapper.customCommandManager.RegisterType(factory); - /// - /// Register object type with server, with specific type ID [0-55] - /// - /// Type ID for factory - /// Factory for object type - public void NewType(int type, CustomObjectFactory factory) - => provider.StoreWrapper.customCommandManager.RegisterType(type, factory); - /// /// Register custom command with Garnet /// @@ -76,7 +68,7 @@ public void NewType(int type, CustomObjectFactory factory) /// RESP command docs /// ID of the registered command public (int objectTypeId, int subCommandId) NewCommand(string name, CommandType commandType, CustomObjectFactory factory, CustomObjectFunctions customObjectFunctions, RespCommandsInfo commandInfo = null, RespCommandDocs commandDocs = null) - => provider.StoreWrapper.customCommandManager.Register(name, commandType, factory, customObjectFunctions, commandInfo, commandDocs); + => provider.StoreWrapper.customCommandManager.Register(name, commandType, factory, commandInfo, commandDocs, customObjectFunctions); /// /// Register custom procedure with Garnet diff --git a/libs/server/Storage/Functions/FunctionsState.cs b/libs/server/Storage/Functions/FunctionsState.cs index bb2aa8e16e..055ad9f675 100644 --- a/libs/server/Storage/Functions/FunctionsState.cs +++ b/libs/server/Storage/Functions/FunctionsState.cs @@ -11,25 +11,33 @@ namespace Garnet.server /// internal sealed class FunctionsState { + private readonly CustomCommandManager customCommandManager; + public readonly TsavoriteLog appendOnlyFile; - public readonly CustomRawStringCommand[] customCommands; - public readonly CustomObjectCommandWrapper[] customObjectCommands; public readonly WatchVersionMap watchVersionMap; public readonly MemoryPool memoryPool; public readonly CacheSizeTracker objectStoreSizeTracker; public readonly GarnetObjectSerializer garnetObjectSerializer; public bool StoredProcMode; - public FunctionsState(TsavoriteLog appendOnlyFile, WatchVersionMap watchVersionMap, CustomRawStringCommand[] customCommands, CustomObjectCommandWrapper[] customObjectCommands, + public FunctionsState(TsavoriteLog appendOnlyFile, WatchVersionMap watchVersionMap, CustomCommandManager customCommandManager, MemoryPool memoryPool, CacheSizeTracker objectStoreSizeTracker, GarnetObjectSerializer garnetObjectSerializer) { this.appendOnlyFile = appendOnlyFile; this.watchVersionMap = watchVersionMap; - this.customCommands = customCommands; - this.customObjectCommands = customObjectCommands; + this.customCommandManager = customCommandManager; this.memoryPool = memoryPool ?? MemoryPool.Shared; this.objectStoreSizeTracker = objectStoreSizeTracker; this.garnetObjectSerializer = garnetObjectSerializer; } + + public CustomRawStringFunctions GetCustomCommandFunctions(int id) + => customCommandManager.TryGetCustomCommand(id, out var cmd) ? cmd.functions : null; + + public CustomObjectFactory GetCustomObjectFactory(int id) + => customCommandManager.TryGetCustomObjectCommand(id, out var cmd) ? cmd.factory : null; + + public CustomObjectFunctions GetCustomObjectSubCommandFunctions(int id, int subId) + => customCommandManager.TryGetCustomObjectSubCommand(id, subId, out var cmd) ? cmd.functions : null; } } \ No newline at end of file diff --git a/libs/server/Storage/Functions/MainStore/RMWMethods.cs b/libs/server/Storage/Functions/MainStore/RMWMethods.cs index ec5b8b1462..9d875ac357 100644 --- a/libs/server/Storage/Functions/MainStore/RMWMethods.cs +++ b/libs/server/Storage/Functions/MainStore/RMWMethods.cs @@ -30,11 +30,10 @@ public bool NeedInitialUpdate(ref SpanByte key, ref RawStringInput input, ref Sp case RespCommand.GETEX: return false; default: - if ((ushort)input.header.cmd >= CustomCommandManager.StartOffset) + if (input.header.cmd > RespCommandExtensions.LastValidCommand) { (IMemoryOwner Memory, int Length) outp = (output.Memory, 0); - var ret = functionsState - .customCommands[(ushort)input.header.cmd - CustomCommandManager.StartOffset].functions + var ret = functionsState.GetCustomCommandFunctions((ushort)input.header.cmd) .NeedInitialUpdate(key.AsReadOnlySpan(), ref input, ref outp); output.Memory = outp.Memory; output.Length = outp.Length; @@ -178,9 +177,9 @@ public bool InitialUpdater(ref SpanByte key, ref RawStringInput input, ref SpanB default: value.UnmarkExtraMetadata(); - if ((ushort)input.header.cmd >= CustomCommandManager.StartOffset) + if (input.header.cmd > RespCommandExtensions.LastValidCommand) { - var functions = functionsState.customCommands[(ushort)input.header.cmd - CustomCommandManager.StartOffset].functions; + var functions = functionsState.GetCustomCommandFunctions((ushort)input.header.cmd); // compute metadata size for result var expiration = input.arg1; metadataSize = expiration switch @@ -505,10 +504,10 @@ private bool InPlaceUpdaterWorker(ref SpanByte key, ref RawStringInput input, re return false; default: - var cmd = (ushort)input.header.cmd; - if (cmd >= CustomCommandManager.StartOffset) + var cmd = input.header.cmd; + if (cmd > RespCommandExtensions.LastValidCommand) { - var functions = functionsState.customCommands[cmd - CustomCommandManager.StartOffset].functions; + var functions = functionsState.GetCustomCommandFunctions((ushort)cmd); var expiration = input.arg1; if (expiration == -1) { @@ -583,10 +582,10 @@ public bool NeedCopyUpdate(ref SpanByte key, ref RawStringInput input, ref SpanB } return true; default: - if ((ushort)input.header.cmd >= CustomCommandManager.StartOffset) + if (input.header.cmd > RespCommandExtensions.LastValidCommand) { (IMemoryOwner Memory, int Length) outp = (output.Memory, 0); - var ret = functionsState.customCommands[(ushort)input.header.cmd - CustomCommandManager.StartOffset].functions + var ret = functionsState.GetCustomCommandFunctions((ushort)input.header.cmd) .NeedCopyUpdate(key.AsReadOnlySpan(), ref input, oldValue.AsReadOnlySpan(), ref outp); output.Memory = outp.Memory; output.Length = outp.Length; @@ -818,9 +817,9 @@ public bool CopyUpdater(ref SpanByte key, ref RawStringInput input, ref SpanByte break; default: - if ((ushort)input.header.cmd >= CustomCommandManager.StartOffset) + if (input.header.cmd > RespCommandExtensions.LastValidCommand) { - var functions = functionsState.customCommands[(ushort)input.header.cmd - CustomCommandManager.StartOffset].functions; + var functions = functionsState.GetCustomCommandFunctions((ushort)input.header.cmd); var expiration = input.arg1; if (expiration == 0) { diff --git a/libs/server/Storage/Functions/MainStore/ReadMethods.cs b/libs/server/Storage/Functions/MainStore/ReadMethods.cs index cd0a0be785..5447708a09 100644 --- a/libs/server/Storage/Functions/MainStore/ReadMethods.cs +++ b/libs/server/Storage/Functions/MainStore/ReadMethods.cs @@ -19,11 +19,11 @@ public bool SingleReader(ref SpanByte key, ref RawStringInput input, ref SpanByt return false; var cmd = input.header.cmd; - if ((ushort)cmd >= CustomCommandManager.StartOffset) + if (cmd > RespCommandExtensions.LastValidCommand) { var valueLength = value.LengthWithoutMetadata; (IMemoryOwner Memory, int Length) output = (dst.Memory, 0); - var ret = functionsState.customCommands[(ushort)cmd - CustomCommandManager.StartOffset].functions + var ret = functionsState.GetCustomCommandFunctions((ushort)cmd) .Reader(key.AsReadOnlySpan(), ref input, value.AsReadOnlySpan(), ref output, ref readInfo); Debug.Assert(valueLength <= value.LengthWithoutMetadata); dst.Memory = output.Memory; @@ -50,11 +50,11 @@ public bool ConcurrentReader(ref SpanByte key, ref RawStringInput input, ref Spa } var cmd = input.header.cmd; - if ((ushort)cmd >= CustomCommandManager.StartOffset) + if (cmd > RespCommandExtensions.LastValidCommand) { var valueLength = value.LengthWithoutMetadata; (IMemoryOwner Memory, int Length) output = (dst.Memory, 0); - var ret = functionsState.customCommands[(ushort)cmd - CustomCommandManager.StartOffset].functions + var ret = functionsState.GetCustomCommandFunctions((ushort)cmd) .Reader(key.AsReadOnlySpan(), ref input, value.AsReadOnlySpan(), ref output, ref readInfo); Debug.Assert(valueLength <= value.LengthWithoutMetadata); dst.Memory = output.Memory; diff --git a/libs/server/Storage/Functions/MainStore/VarLenInputMethods.cs b/libs/server/Storage/Functions/MainStore/VarLenInputMethods.cs index b0b3803465..442cf7a769 100644 --- a/libs/server/Storage/Functions/MainStore/VarLenInputMethods.cs +++ b/libs/server/Storage/Functions/MainStore/VarLenInputMethods.cs @@ -119,9 +119,9 @@ public int GetRMWInitialValueLength(ref RawStringInput input) return sizeof(int) + ndigits; default: - if ((ushort)cmd >= CustomCommandManager.StartOffset) + if (cmd > RespCommandExtensions.LastValidCommand) { - var functions = functionsState.customCommands[(ushort)cmd - CustomCommandManager.StartOffset].functions; + var functions = functionsState.GetCustomCommandFunctions((ushort)cmd); // Compute metadata size for result int metadataSize = input.arg1 switch { @@ -236,9 +236,9 @@ public int GetRMWModifiedValueLength(ref SpanByte t, ref RawStringInput input) return sizeof(int) + t.Length + valueLength; default: - if ((ushort)cmd >= CustomCommandManager.StartOffset) + if (cmd > RespCommandExtensions.LastValidCommand) { - var functions = functionsState.customCommands[(ushort)cmd - CustomCommandManager.StartOffset].functions; + var functions = functionsState.GetCustomCommandFunctions((ushort)cmd); // compute metadata for result var metadataSize = input.arg1 switch { diff --git a/libs/server/Storage/Functions/ObjectStore/PrivateMethods.cs b/libs/server/Storage/Functions/ObjectStore/PrivateMethods.cs index 91698a8f00..88f08e9d53 100644 --- a/libs/server/Storage/Functions/ObjectStore/PrivateMethods.cs +++ b/libs/server/Storage/Functions/ObjectStore/PrivateMethods.cs @@ -184,9 +184,8 @@ static bool EvaluateObjectExpireInPlace(ExpireOption optionType, bool expiryExis [MethodImpl(MethodImplOptions.AggressiveInlining)] private CustomObjectFunctions GetCustomObjectCommand(ref ObjectInput input, GarnetObjectType type) { - var objectId = (byte)((byte)type - CustomCommandManager.TypeIdStartOffset); var cmdId = input.header.SubId; - var customObjectCommand = functionsState.customObjectCommands[objectId].commandMap[cmdId].functions; + var customObjectCommand = functionsState.GetCustomObjectSubCommandFunctions((byte)type, cmdId); return customObjectCommand; } diff --git a/libs/server/Storage/Functions/ObjectStore/RMWMethods.cs b/libs/server/Storage/Functions/ObjectStore/RMWMethods.cs index 8a28bc1e1e..01d8c562bb 100644 --- a/libs/server/Storage/Functions/ObjectStore/RMWMethods.cs +++ b/libs/server/Storage/Functions/ObjectStore/RMWMethods.cs @@ -55,8 +55,7 @@ public bool InitialUpdater(ref byte[] key, ref ObjectInput input, ref IGarnetObj Debug.Assert(type != GarnetObjectType.Expire && type != GarnetObjectType.PExpire && type != GarnetObjectType.Persist, "Expire and Persist commands should have been handled already by NeedInitialUpdate."); var customObjectCommand = GetCustomObjectCommand(ref input, type); - var objectId = (byte)((byte)type - CustomCommandManager.TypeIdStartOffset); - value = functionsState.customObjectCommands[objectId].factory.Create((byte)type); + value = functionsState.GetCustomObjectFactory((byte)type).Create((byte)type); (IMemoryOwner Memory, int Length) outp = (output.spanByteAndMemory.Memory, 0); var result = customObjectCommand.InitialUpdater(key, ref input, value, ref outp, ref rmwInfo); diff --git a/libs/server/StoreWrapper.cs b/libs/server/StoreWrapper.cs index b94833916f..4c3ef3caf3 100644 --- a/libs/server/StoreWrapper.cs +++ b/libs/server/StoreWrapper.cs @@ -217,7 +217,7 @@ public string GetIp() } internal FunctionsState CreateFunctionsState() - => new(appendOnlyFile, versionMap, customCommandManager.rawStringCommandMap, customCommandManager.objectCommandMap, null, objectStoreSizeTracker, GarnetObjectSerializer); + => new(appendOnlyFile, versionMap, customCommandManager, null, objectStoreSizeTracker, GarnetObjectSerializer); internal void Recover() { diff --git a/libs/server/Transaction/TxnRespCommands.cs b/libs/server/Transaction/TxnRespCommands.cs index 14186a0537..e18d97a7c7 100644 --- a/libs/server/Transaction/TxnRespCommands.cs +++ b/libs/server/Transaction/TxnRespCommands.cs @@ -266,7 +266,7 @@ private bool NetworkRUNTXP() try { - (proc, arity) = customCommandManagerSession.GetCustomTransactionProcedure(txId, this, txnManager, scratchBufferManager); + proc = customCommandManagerSession.GetCustomTransactionProcedure(txId, this, txnManager, scratchBufferManager, out arity); } catch (Exception e) { diff --git a/test/Garnet.test/RespCustomCommandTests.cs b/test/Garnet.test/RespCustomCommandTests.cs index 5d238af79d..cef9353517 100644 --- a/test/Garnet.test/RespCustomCommandTests.cs +++ b/test/Garnet.test/RespCustomCommandTests.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; using System.Linq; @@ -1088,5 +1089,195 @@ public void CustomProcedureInvokingInvalidCommandTest() var result = db.Execute("PROCINVALIDCMD", "key"); ClassicAssert.AreEqual("OK", (string)result); } + + [Test] + [TestCase(true)] + [TestCase(false)] + public void MultiRegisterCommandTest(bool sync) + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + var regCount = 24; + var regCmdTasks = new Task[regCount]; + for (var i = 0; i < regCount; i++) + { + var idx = i; + regCmdTasks[i] = new Task(() => server.Register.NewCommand($"SETIFPM{idx + 1}", CommandType.ReadModifyWrite, new SetIfPMCustomCommand(), + new RespCommandsInfo { Arity = 4 })); + } + + for (var i = 0; i < regCount; i++) + { + if (sync) + { + regCmdTasks[i].RunSynchronously(); + } + else + { + regCmdTasks[i].Start(); + } + } + + if (!sync) Task.WhenAll(regCmdTasks); + + for (var i = 0; i < regCount; i++) + { + var key = $"mykey{i + 1}"; + var origValue = "foovalue0"; + db.StringSet(key, origValue); + + var newValue1 = "foovalue1"; + db.Execute($"SETIFPM{i + 1}", key, newValue1, "foo"); + + // This conditional set should pass (prefix matches) + string retValue = db.StringGet(key); + ClassicAssert.AreEqual(newValue1, retValue); + } + } + + [Test] + [TestCase(true)] + [TestCase(false)] + public void MultiRegisterSubCommandTest(bool sync) + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + var factory = new MyDictFactory(); + server.Register.NewCommand("MYDICTGET", CommandType.Read, factory, new MyDictGet(), new RespCommandsInfo { Arity = 3 }); + + // Only able to register 31 sub-commands, try to register 32 + var regCount = 32; + var failedTaskIdAndMessage = new ConcurrentBag<(int, string)>(); + var regCmdTasks = new Task[regCount]; + for (var i = 0; i < regCount; i++) + { + var idx = i; + regCmdTasks[i] = new Task(() => + { + try + { + server.Register.NewCommand($"MYDICTSET{idx + 1}", + CommandType.ReadModifyWrite, factory, new MyDictSet(), new RespCommandsInfo { Arity = 4 }); + } + catch (Exception e) + { + failedTaskIdAndMessage.Add((idx, e.Message)); + } + }); + } + + for (var i = 0; i < regCount; i++) + { + if (sync) + { + regCmdTasks[i].RunSynchronously(); + } + else + { + regCmdTasks[i].Start(); + } + } + + if (!sync) Task.WaitAll(regCmdTasks); + + // Exactly one registration should fail + ClassicAssert.AreEqual(1, failedTaskIdAndMessage.Count); + failedTaskIdAndMessage.TryTake(out var failedTaskResult); + + var failedTaskId = failedTaskResult.Item1; + var failedTaskMessage = failedTaskResult.Item2; + ClassicAssert.AreEqual("Out of registration space", failedTaskMessage); + + var mainkey = "key"; + + // Check that all registrations worked except the failed one + for (var i = 0; i < regCount; i++) + { + if (i == failedTaskId) continue; + var key1 = $"mykey{i + 1}"; + var value1 = $"foovalue{i + 1}"; + db.Execute($"MYDICTSET{i + 1}", mainkey, key1, value1); + + var retValue = db.Execute("MYDICTGET", mainkey, key1); + ClassicAssert.AreEqual(value1, (string)retValue); + } + } + + [Test] + public void MultiRegisterTxnTest() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + var regCount = byte.MaxValue + 1; + for (var i = 0; i < regCount; i++) + { + server.Register.NewTransactionProc($"GETTWOKEYSNOTXN{i + 1}", () => new GetTwoKeysNoTxn(), new RespCommandsInfo { Arity = 3 }); + } + + try + { + // This register should fail as there could only be byte.MaxValue + 1 transactions registered + server.Register.NewTransactionProc($"GETTWOKEYSNOTXN{byte.MaxValue + 3}", () => new GetTwoKeysNoTxn(), new RespCommandsInfo { Arity = 3 }); + Assert.Fail(); + } + catch (Exception e) + { + ClassicAssert.AreEqual("Out of registration space", e.Message); + } + + for (var i = 0; i < regCount; i++) + { + var readkey1 = $"readkey{i + 1}.1"; + var value1 = $"foovalue{i + 1}.1"; + db.StringSet(readkey1, value1); + + var readkey2 = $"readkey{i + 1}.2"; + var value2 = $"foovalue{i + 1}.2"; + db.StringSet(readkey2, value2); + + var result = db.Execute($"GETTWOKEYSNOTXN{i + 1}", readkey1, readkey2); + + ClassicAssert.AreEqual(value1, ((string[])result)?[0]); + ClassicAssert.AreEqual(value2, ((string[])result)?[1]); + } + } + + [Test] + public void MultiRegisterProcTest() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + var regCount = byte.MaxValue + 1; + for (var i = 0; i < regCount; i++) + { + server.Register.NewProcedure($"SUM{i + 1}", () => new Sum()); + } + + try + { + // This register should fail as there could only be byte.MaxValue + 1 procedures registered + server.Register.NewProcedure($"SUM{byte.MaxValue + 3}", () => new Sum()); + Assert.Fail(); + } + catch (Exception e) + { + ClassicAssert.AreEqual("Out of registration space", e.Message); + } + + db.StringSet("key1", "10"); + db.StringSet("key2", "35"); + db.StringSet("key3", "20"); + + for (var i = 0; i < regCount; i++) + { + // Include non-existent and string keys as well + var retValue = db.Execute($"SUM{i + 1}", "key1", "key2", "key3", "key4"); + ClassicAssert.AreEqual("65", retValue.ToString()); + } + } } } \ No newline at end of file