diff --git a/.github/workflows/ci-bdnbenchmark.yml b/.github/workflows/ci-bdnbenchmark.yml index 7ca620bc73..5f3ee17074 100644 --- a/.github/workflows/ci-bdnbenchmark.yml +++ b/.github/workflows/ci-bdnbenchmark.yml @@ -42,7 +42,7 @@ jobs: os: [ ubuntu-latest, windows-latest ] framework: [ 'net8.0' ] configuration: [ 'Release' ] - test: [ 'Operations.BasicOperations', 'Operations.ObjectOperations', 'Operations.HashObjectOperations', 'Cluster.ClusterMigrate', 'Cluster.ClusterOperations', 'Lua.LuaScripts', 'Operations.CustomOperations', 'Operations.RawStringOperations', 'Operations.ScriptOperations','Network.BasicOperations', 'Network.RawStringOperations' ] + test: [ 'Operations.BasicOperations', 'Operations.ObjectOperations', 'Operations.HashObjectOperations', 'Cluster.ClusterMigrate', 'Cluster.ClusterOperations', 'Lua.LuaScripts', 'Lua.LuaScriptCacheOperations','Lua.LuaRunnerOperations','Operations.CustomOperations', 'Operations.RawStringOperations', 'Operations.ScriptOperations','Network.BasicOperations', 'Network.RawStringOperations' ] steps: - name: Check out code uses: actions/checkout@v4 diff --git a/benchmark/BDN.benchmark/Lua/LuaParams.cs b/benchmark/BDN.benchmark/Lua/LuaParams.cs index 85c23fa9d1..3426a74786 100644 --- a/benchmark/BDN.benchmark/Lua/LuaParams.cs +++ b/benchmark/BDN.benchmark/Lua/LuaParams.cs @@ -1,26 +1,38 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +using BenchmarkDotNet.Code; +using Garnet.server; + namespace BDN.benchmark.Lua { /// - /// Cluster parameters + /// Lua parameters /// - public struct LuaParams + public readonly struct LuaParams { + public readonly LuaMemoryManagementMode Mode { get; } + public readonly bool MemoryLimit { get; } + /// /// Constructor /// - public LuaParams() + public LuaParams(LuaMemoryManagementMode mode, bool memoryLimit) { + Mode = mode; + MemoryLimit = memoryLimit; } + /// + /// Get the equivalent . + /// + public LuaOptions CreateOptions() + => new(Mode, MemoryLimit ? "2m" : ""); + /// /// String representation /// public override string ToString() - { - return "None"; - } + => $"{Mode},{(MemoryLimit ? "Limit" : "None")}"; } } \ No newline at end of file diff --git a/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs b/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs index 07d28e641b..a100249efe 100644 --- a/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs +++ b/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs @@ -139,9 +139,13 @@ public unsafe class LuaRunnerOperations /// Lua parameters provider /// public IEnumerable LuaParamsProvider() - { - yield return new(); - } + => [ + new(LuaMemoryManagementMode.Native, false), + new(LuaMemoryManagementMode.Tracked, false), + new(LuaMemoryManagementMode.Tracked, true), + new(LuaMemoryManagementMode.Managed, false), + new(LuaMemoryManagementMode.Managed, true), + ]; private EmbeddedRespServer server; private RespServerSession session; @@ -151,16 +155,21 @@ public IEnumerable LuaParamsProvider() private LuaRunner smallCompileRunner; private LuaRunner largeCompileRunner; + private LuaOptions opts; + [GlobalSetup] public void GlobalSetup() { - server = new EmbeddedRespServer(new GarnetServerOptions() { EnableLua = true, QuietMode = true }); + opts = Params.CreateOptions(); + + server = new EmbeddedRespServer(new GarnetServerOptions() { EnableLua = true, QuietMode = true, LuaOptions = opts }); session = server.GetRespSession(); - paramsRunner = new LuaRunner("return nil"); - smallCompileRunner = new LuaRunner(SmallScript); - largeCompileRunner = new LuaRunner(LargeScript); + paramsRunner = new LuaRunner(opts, "return nil"); + + smallCompileRunner = new LuaRunner(opts, SmallScript); + largeCompileRunner = new LuaRunner(opts, LargeScript); } [GlobalCleanup] @@ -171,6 +180,18 @@ public void GlobalCleanup() paramsRunner.Dispose(); } + [IterationSetup] + public void IterationSetup() + { + session.EnterAndGetResponseObject(); + } + + [IterationCleanup] + public void IterationCleanup() + { + session.ExitAndReturnResponseObject(); + } + [Benchmark] public void ResetParametersSmall() { @@ -194,13 +215,13 @@ public void ResetParametersLarge() [Benchmark] public void ConstructSmall() { - using var runner = new LuaRunner(SmallScript); + using var runner = new LuaRunner(opts, SmallScript); } [Benchmark] public void ConstructLarge() { - using var runner = new LuaRunner(LargeScript); + using var runner = new LuaRunner(opts, LargeScript); } [Benchmark] diff --git a/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs b/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs index a996e45469..430a92ca59 100644 --- a/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs +++ b/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs @@ -22,9 +22,13 @@ public class LuaScriptCacheOperations /// Lua parameters provider /// public IEnumerable LuaParamsProvider() - { - yield return new(); - } + => [ + new(LuaMemoryManagementMode.Native, false), + new(LuaMemoryManagementMode.Tracked, false), + new(LuaMemoryManagementMode.Tracked, true), + new(LuaMemoryManagementMode.Managed, false), + new(LuaMemoryManagementMode.Managed, true), + ]; private EmbeddedRespServer server; private StoreWrapper storeWrapper; @@ -38,7 +42,9 @@ public IEnumerable LuaParamsProvider() [GlobalSetup] public void GlobalSetup() { - server = new EmbeddedRespServer(new GarnetServerOptions() { EnableLua = true, QuietMode = true }); + var options = Params.CreateOptions(); + + server = new EmbeddedRespServer(new GarnetServerOptions() { EnableLua = true, QuietMode = true, LuaOptions = options }); storeWrapper = server.StoreWrapper; sessionScriptCache = new SessionScriptCache(storeWrapper, new GarnetNoAuthAuthenticator()); session = server.GetRespSession(); diff --git a/benchmark/BDN.benchmark/Lua/LuaScripts.cs b/benchmark/BDN.benchmark/Lua/LuaScripts.cs index 3cff4b30e3..fc2a2c7269 100644 --- a/benchmark/BDN.benchmark/Lua/LuaScripts.cs +++ b/benchmark/BDN.benchmark/Lua/LuaScripts.cs @@ -22,9 +22,13 @@ public unsafe class LuaScripts /// Lua parameters provider /// public IEnumerable LuaParamsProvider() - { - yield return new(); - } + => [ + new(LuaMemoryManagementMode.Native, false), + new(LuaMemoryManagementMode.Tracked, false), + new(LuaMemoryManagementMode.Tracked, true), + new(LuaMemoryManagementMode.Managed, false), + new(LuaMemoryManagementMode.Managed, true), + ]; LuaRunner r1, r2, r3, r4; readonly string[] keys = ["key1"]; @@ -32,13 +36,15 @@ public IEnumerable LuaParamsProvider() [GlobalSetup] public void GlobalSetup() { - r1 = new LuaRunner("return"); + var options = Params.CreateOptions(); + + r1 = new LuaRunner(options, "return"); r1.CompileForRunner(); - r2 = new LuaRunner("return 1 + 1"); + r2 = new LuaRunner(options, "return 1 + 1"); r2.CompileForRunner(); - r3 = new LuaRunner("return KEYS[1]"); + r3 = new LuaRunner(options, "return KEYS[1]"); r3.CompileForRunner(); - r4 = new LuaRunner("return redis.call(KEYS[1])"); + r4 = new LuaRunner(options, "return redis.call(KEYS[1])"); r4.CompileForRunner(); } diff --git a/benchmark/BDN.benchmark/Operations/OperationsBase.cs b/benchmark/BDN.benchmark/Operations/OperationsBase.cs index 8677331de3..c47a550a4f 100644 --- a/benchmark/BDN.benchmark/Operations/OperationsBase.cs +++ b/benchmark/BDN.benchmark/Operations/OperationsBase.cs @@ -55,6 +55,7 @@ public virtual void GlobalSetup() QuietMode = true, EnableLua = true, DisablePubSub = true, + LuaOptions = new(LuaMemoryManagementMode.Native, ""), }; if (Params.useAof) { diff --git a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs index 8a6cd34f08..b8298912eb 100644 --- a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs +++ b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs @@ -4,8 +4,10 @@ using System.Runtime.CompilerServices; using System.Security.Cryptography; using System.Text; +using BDN.benchmark.Lua; using BenchmarkDotNet.Attributes; using Embedded.server; +using Garnet.server; namespace BDN.benchmark.Operations { @@ -13,7 +15,7 @@ namespace BDN.benchmark.Operations /// Benchmark for SCRIPT LOAD, SCRIPT EXISTS, EVAL, and EVALSHA /// [MemoryDiagnoser] - public unsafe class ScriptOperations : OperationsBase + public unsafe class ScriptOperations { // Small script that does 1 operation and no logic const string SmallScriptText = @"return redis.call('GET', KEYS[1]);"; @@ -156,9 +158,54 @@ public unsafe class ScriptOperations : OperationsBase static ReadOnlySpan ARRAY_RETURN => "*3\r\n$4\r\nEVAL\r\n$22\r\nreturn {1, 2, 3, 4, 5}\r\n$1\r\n0\r\n"u8; Request arrayReturn; - public override void GlobalSetup() + + /// + /// Lua parameters + /// + [ParamsSource(nameof(LuaParamsProvider))] + public LuaParams Params { get; set; } + + /// + /// Lua parameters provider + /// + public static IEnumerable LuaParamsProvider() + => [ + new(LuaMemoryManagementMode.Native, false), + new(LuaMemoryManagementMode.Tracked, false), + new(LuaMemoryManagementMode.Tracked, true), + new(LuaMemoryManagementMode.Managed, false), + new(LuaMemoryManagementMode.Managed, true), + ]; + + /// + /// Batch size per method invocation + /// With a batchSize of 100, we have a convenient conversion of latency to throughput: + /// 5 us = 20 Mops/sec + /// 10 us = 10 Mops/sec + /// 20 us = 5 Mops/sec + /// 25 us = 4 Mops/sec + /// 100 us = 1 Mops/sec + /// + internal const int batchSize = 100; + internal EmbeddedRespServer server; + internal RespServerSession session; + + /// + /// Setup + /// + [GlobalSetup] + public void GlobalSetup() { - base.GlobalSetup(); + var opts = new GarnetServerOptions + { + QuietMode = true, + EnableLua = true, + LuaOptions = Params.CreateOptions(), + }; + + server = new EmbeddedRespServer(opts); + + session = server.GetRespSession(); SetupOperation(ref scriptLoad, SCRIPT_LOAD); @@ -216,6 +263,16 @@ public override void GlobalSetup() SetupOperation(ref evalShaLargeScript, largeScriptEvals); } + /// + /// Cleanup + /// + [GlobalCleanup] + public virtual void GlobalCleanup() + { + session.Dispose(); + server.Dispose(); + } + [Benchmark] public void ScriptLoad() { @@ -263,5 +320,36 @@ public void ArrayReturn() { Send(arrayReturn); } + + private void Send(Request request) + { + _ = session.TryConsumeMessages(request.bufferPtr, request.buffer.Length); + } + + private unsafe void SetupOperation(ref Request request, ReadOnlySpan operation, int batchSize = batchSize) + { + request.buffer = GC.AllocateArray(operation.Length * batchSize, pinned: true); + request.bufferPtr = (byte*)Unsafe.AsPointer(ref request.buffer[0]); + for (int i = 0; i < batchSize; i++) + operation.CopyTo(new Span(request.buffer).Slice(i * operation.Length)); + } + + private unsafe void SetupOperation(ref Request request, string operation, int batchSize = batchSize) + { + request.buffer = GC.AllocateUninitializedArray(operation.Length * batchSize, pinned: true); + for (var i = 0; i < batchSize; i++) + { + var start = i * operation.Length; + Encoding.UTF8.GetBytes(operation, request.buffer.AsSpan().Slice(start, operation.Length)); + } + request.bufferPtr = (byte*)Unsafe.AsPointer(ref request.buffer[0]); + } + + private unsafe void SetupOperation(ref Request request, List operationBytes) + { + request.buffer = GC.AllocateUninitializedArray(operationBytes.Count, pinned: true); + operationBytes.CopyTo(request.buffer); + request.bufferPtr = (byte*)Unsafe.AsPointer(ref request.buffer[0]); + } } } \ No newline at end of file diff --git a/libs/host/Configuration/Options.cs b/libs/host/Configuration/Options.cs index b32cf175a3..8922292bf9 100644 --- a/libs/host/Configuration/Options.cs +++ b/libs/host/Configuration/Options.cs @@ -8,6 +8,7 @@ using System.Linq; using System.Net; using System.Reflection; +using System.Runtime.InteropServices; using System.Security.Cryptography.X509Certificates; using CommandLine; using Garnet.server; @@ -508,6 +509,14 @@ internal sealed class Options [Option("fail-on-recovery-error", Default = false, Required = false, HelpText = "Server bootup should fail if errors happen during bootup of AOF and checkpointing")] public bool? FailOnRecoveryError { get; set; } + [Option("lua-memory-management-mode", Default = LuaMemoryManagementMode.Native, Required = false, HelpText = "Memory management mode for Lua scripts, must be set to LimittedNative or Managed to impose script limits")] + public LuaMemoryManagementMode LuaMemoryManagementMode { get; set; } + + [MemorySizeValidation(false)] + [ForbiddenWithOption(nameof(LuaMemoryManagementMode), nameof(LuaMemoryManagementMode.Native))] + [Option("lua-script-memory-limit", Default = null, HelpText = "Memory limit for a Lua instances while running a script, lua-memory-management-mode must be set to something other than Native to use this flag")] + public string LuaScriptMemoryLimit { get; set; } + /// /// This property contains all arguments that were not parsed by the command line argument parser /// @@ -718,7 +727,8 @@ public GarnetServerOptions GetServerOptions(ILogger logger = null) IndexResizeFrequencySecs = IndexResizeFrequencySecs, IndexResizeThreshold = IndexResizeThreshold, LoadModuleCS = LoadModuleCS, - FailOnRecoveryError = FailOnRecoveryError.GetValueOrDefault() + FailOnRecoveryError = FailOnRecoveryError.GetValueOrDefault(), + LuaOptions = EnableLua.GetValueOrDefault() ? new LuaOptions(LuaMemoryManagementMode, LuaScriptMemoryLimit, logger) : null, }; } diff --git a/libs/host/Configuration/OptionsValidators.cs b/libs/host/Configuration/OptionsValidators.cs index ea7f64682e..94d97cced8 100644 --- a/libs/host/Configuration/OptionsValidators.cs +++ b/libs/host/Configuration/OptionsValidators.cs @@ -563,4 +563,42 @@ protected override ValidationResult IsValid(object value, ValidationContext vali return base.IsValid(value, validationContext); } } + + /// + /// Forbids a config option from being set if the another option has particular values. + /// + [AttributeUsage(AttributeTargets.Property)] + internal sealed class ForbiddenWithOptionAttribute : ValidationAttribute + { + private readonly string otherOptionName; + private readonly string[] forbiddenValues; + + internal ForbiddenWithOptionAttribute(string otherOptionName, string forbiddenValue, params string[] otherForbiddenValues) + { + this.otherOptionName = otherOptionName; + forbiddenValues = [forbiddenValue, .. otherForbiddenValues]; + } + + /// + protected override ValidationResult IsValid(object value, ValidationContext validationContext) + { + var optionIsSet = value != null && !(value is string valueStr && string.IsNullOrEmpty(valueStr)); + if (optionIsSet) + { + var propAccessor = validationContext.ObjectInstance?.GetType()?.GetProperty(otherOptionName, BindingFlags.Instance | BindingFlags.Public); + if (propAccessor != null) + { + var otherOptionValue = propAccessor.GetValue(validationContext.ObjectInstance); + var otherOptionValueAsString = otherOptionValue is string strVal ? strVal : otherOptionValue?.ToString(); + + if (forbiddenValues.Contains(otherOptionValueAsString, StringComparer.OrdinalIgnoreCase)) + { + return new ValidationResult($"{nameof(validationContext.DisplayName)} cannot be set with {otherOptionName} has value '{otherOptionValueAsString}'"); + } + } + } + + return ValidationResult.Success; + } + } } \ No newline at end of file diff --git a/libs/host/defaults.conf b/libs/host/defaults.conf index 5a19569653..4c08ebec06 100644 --- a/libs/host/defaults.conf +++ b/libs/host/defaults.conf @@ -342,5 +342,11 @@ "LoadModuleCS": null, /* Fails if encounters error during AOF replay or checkpointing */ - "FailOnRecoveryError": false + "FailOnRecoveryError": false, + + /* Lua uses the default, unmanaged and untracked, allocator */ + "LuaMemoryManagementMode": "Native", + + /* Lua limits are ignored for Native, but can be set for other modes */ + "LuaScriptMemoryLimit": null } \ No newline at end of file diff --git a/libs/server/ArgSlice/ScratchBufferManager.cs b/libs/server/ArgSlice/ScratchBufferManager.cs index e2106bfaa7..42fa904690 100644 --- a/libs/server/ArgSlice/ScratchBufferManager.cs +++ b/libs/server/ArgSlice/ScratchBufferManager.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.Numerics; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Text; using Garnet.common; @@ -297,15 +298,19 @@ void ExpandScratchBufferIfNeeded(int newLength) ExpandScratchBuffer(scratchBufferOffset + newLength); } - void ExpandScratchBuffer(int newLength) + void ExpandScratchBuffer(int newLength, int? copyLengthOverride = null) { if (newLength < 64) newLength = 64; else newLength = (int)BitOperations.RoundUpToPowerOf2((uint)newLength + 1); var _scratchBuffer = GC.AllocateArray(newLength, true); - var _scratchBufferHead = (byte*)Unsafe.AsPointer(ref _scratchBuffer[0]); - if (scratchBufferOffset > 0) - new ReadOnlySpan(scratchBufferHead, scratchBufferOffset).CopyTo(new Span(_scratchBufferHead, scratchBufferOffset)); + var _scratchBufferHead = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetArrayDataReference(_scratchBuffer)); + + var copyLength = copyLengthOverride ?? scratchBufferOffset; + if (copyLength > 0) + { + new ReadOnlySpan(scratchBufferHead, copyLength).CopyTo(new Span(_scratchBufferHead, copyLength)); + } scratchBuffer = _scratchBuffer; scratchBufferHead = _scratchBufferHead; } @@ -326,8 +331,12 @@ public ArgSlice GetSliceFromTail(int length) /// /// Force backing buffer to grow. + /// + /// provides a way to force a chunk at the start of the + /// previous buffer be copied into the new buffer, even if this + /// doesn't consider that chunk in use. /// - public void GrowBuffer() + public void GrowBuffer(int? copyLengthOverride = null) { if (scratchBuffer == null) { @@ -335,7 +344,7 @@ public void GrowBuffer() } else { - ExpandScratchBuffer(scratchBuffer.Length + 1); + ExpandScratchBuffer(scratchBuffer.Length + 1, copyLengthOverride); } } } diff --git a/libs/server/Lua/ILuaAllocator.cs b/libs/server/Lua/ILuaAllocator.cs new file mode 100644 index 0000000000..78be36a9ea --- /dev/null +++ b/libs/server/Lua/ILuaAllocator.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +namespace Garnet.server +{ + /// + /// Common interface for Lua allocators. + /// + /// Lua itself has a somewhat esoteric alloc interface, + /// this maps to something akin to malloc/free/realloc though + /// with some C# niceties. + /// + /// Note that all returned references must be pinned, as Lua is not aware + /// of the .NET GC. + /// + internal interface ILuaAllocator + { + /// + /// Allocate a new chunk of memory of at least size. + /// + /// Note that 0-sized allocations MUST succeed and MUST return a non-null reference. + /// A 0-sized allocation will NOT be dereferenced. + /// + /// If it cannot be satisfied, must be set to false. + /// + ref byte AllocateNew(int sizeBytes, out bool failed); + + /// + /// Free an allocation previously obtained from + /// or . + /// + /// will match sizeBytes or newSizeBytes, respectively, that + /// was passed when the allocation was obtained. + /// + void Free(ref byte start, int sizeBytes); + + /// + /// Resize an existing allocation. + /// + /// Data up to must be preserved. + /// + /// will be a previously obtained, non-null, allocation. + /// will be the sizeBytes or newSizeBytes that was passed when the allocaiton was obtained. + /// + /// This should resize in place if possible. + /// + /// If in place resizeing is not possible, the previously allocation will be freed if this method succeeds. + /// + /// If this allocation cannot be satisifed, must be set to false. + /// + ref byte ResizeAllocation(ref byte start, int oldSizeBytes, int newSizeBytes, out bool failed); + } +} \ No newline at end of file diff --git a/libs/server/Lua/LuaLimitedManagedAllocator.cs b/libs/server/Lua/LuaLimitedManagedAllocator.cs new file mode 100644 index 0000000000..7d63b55ee8 --- /dev/null +++ b/libs/server/Lua/LuaLimitedManagedAllocator.cs @@ -0,0 +1,1088 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Numerics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Garnet.server +{ + /// + /// Provides a mapping of Lua allocations onto the POH. + /// + /// Pre-allocates the full maximum allocation. + /// + /// + /// This is a really naive allocator, just has a free list tracked with pointers in block headers. + /// + /// The hope is that Lua's GC keeps the actual use of this down substantially. + /// + /// A small optimization is trying to keep a big free block at the head of the free list. + /// + internal sealed class LuaLimitedManagedAllocator : ILuaAllocator + { + /// + /// Minimum size we'll round all Lua allocs up to. + /// + /// Based on largest "normal" type Lua will allocate and the needs of . + /// + private const int LuaAllocMinSizeBytes = 16; + + /// + /// Represents a block of memory in this mapper. + /// + /// Blocks always have a size of at least . + /// + /// Blocks are either free or in use: + /// - if free, will be positive. + /// - if in use, will be negative. + /// + /// If free, the block is in a doublely linked list of free blocks. + /// - is undefined if block is in use, 0 if block is tail of list, and ref (as long) to next block otherwise + /// - is undefined if block is in use, 0 if block is head of list, and ref (as long) to previous block otherwise + /// + [StructLayout(LayoutKind.Explicit, Size = StructSizeBytes)] + private struct BlockHeader + { + public const int DataOffset = sizeof(int); + public const int StructSizeBytes = DataOffset + LuaAllocMinSizeBytes; + + /// + /// Size of the block - always valid. + /// + /// If negative, block is in use. + /// + /// If you don't already know the state of the block, use . + /// If you do, use this as it elides a conditional. + /// + [FieldOffset(0)] + public int SizeBytesRaw; + + /// + /// Ref of previous block if this block is in the free list. + /// + /// Only valid if the block is free. + /// + /// Most reads should go through or . + /// + [FieldOffset(sizeof(int))] + public long PrevFreeBlockRefRaw; + + /// + /// Ref of next block if this block is in the free list. + /// + /// Only valid if block is free. + /// + /// Most reads should go through or . + /// + [FieldOffset(sizeof(int) + sizeof(long))] + public long NextFreeBlockRefRaw; + + /// + /// True if block is free. + /// + public readonly bool IsFree + => SizeBytesRaw > 0; + + /// + /// True if block is in use. + /// + public readonly bool IsInUse + => SizeBytesRaw < 0; + + /// + /// Size of the block in bytes. + /// + /// Prefer to checking directly, as it accounts for state bits. + /// + public readonly int SizeBytes + { + get + { + Debug.Assert(IsFree || IsInUse, "Illegal state, neither free nor in use"); + + if (IsFree) + { + return SizeBytesRaw; + } + else + { + return -SizeBytesRaw; + } + } + } + + /// + /// Ref (as a long) of the next block in the free list, if any. + /// + public readonly long NextFreeBlockRef + { + get + { + Debug.Assert(IsFree, "Can't be in free list if allocated"); + + return NextFreeBlockRefRaw; + } + } + + /// + /// Ref of the previous block in the free list, if any. + /// + public readonly long PrevFreeBlockRef + { + get + { + Debug.Assert(IsFree, "Can't be in free list if allocated"); + + return PrevFreeBlockRefRaw; + } + } + + /// + /// Grab a reference to return to users. + /// + [UnscopedRef] + public ref byte DataReference + => ref Unsafe.AddByteOffset(ref Unsafe.As(ref this), DataOffset); + + /// + /// For debugging purposes, all the data covered by this block. + /// + public ReadOnlySpan Data + => MemoryMarshal.CreateReadOnlySpan(ref DataReference, SizeBytes); + + /// + /// Mark block free. + /// + /// After this, the block can be placed in the free list. + /// + public void MarkFree() + { + Debug.Assert(IsInUse, "Double free"); + + SizeBytesRaw = -SizeBytesRaw; + } + + /// + /// Mark block in use. + /// + /// After this, the and properties cannot be accessed. + /// + public void MarkInUse() + { + Debug.Assert(IsFree, "Already allocated"); + + SizeBytesRaw = -SizeBytesRaw; + } + + /// + /// Get a reference to the next block in the free list, or a reference BEFORE . + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public readonly unsafe ref BlockHeader GetNextFreeBlockRef() + => ref Unsafe.AsRef((void*)NextFreeBlockRef); + + /// + /// Get a reference to the prev block in the free list, or a reference BEFORE . + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public readonly unsafe ref BlockHeader GetPrevFreeBlockRef() + => ref Unsafe.AsRef((void*)PrevFreeBlockRef); + + /// + /// Get the block the comes after this one in data. + /// + /// If we're at the end of data, returns a a ref BEFORE . + /// + [UnscopedRef] + public ref BlockHeader GetNextAdjacentBlockRef(ref byte dataStartRef, int finalOffset) + { + ref var selfRef = ref Unsafe.As(ref this); + ref var nextRef = ref Unsafe.Add(ref selfRef, DataOffset + SizeBytes); + + var nextOffset = Unsafe.ByteOffset(ref dataStartRef, ref nextRef); + + if (nextOffset >= finalOffset) + { + // For symmetry with GetXXXFreeBlockRef return of (void*)0 + unsafe + { + return ref Unsafe.AsRef((void*)0); + } + } + + return ref Unsafe.As(ref nextRef); + } + + /// + /// Get a value that can be stashed in or + /// that refers to this block. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public unsafe long GetRefVal() + => (long)Unsafe.AsPointer(ref this); + } + + // Very basic allocation story, we just keep block metadata and + // a free list in the array. + private readonly Memory data; + + private unsafe void* dataStartPtr; + private unsafe void* freeListStartPtr; + + private int debugAllocatedBytes; + + /// + /// For testing purposes, how many bytes are allocated to Lua at the moment. + /// + /// This is how many bytes we've handed out, not the size of the backing array on the POH. + /// + /// Only available in DEBUG to avoid updating this in RELEASE builds. + /// + internal int AllocatedBytes + => debugAllocatedBytes; + + /// + /// For testing purposes, the number of blocks tracked in the free list. + /// + internal int FreeBlockCount + { + get + { + ref var dataStartRef = ref GetDataStartRef(); + + ref var cur = ref GetFreeList(); + + var ret = 0; + + while (IsValidBlockRef(ref cur)) + { + ret++; + + cur = ref cur.GetNextFreeBlockRef(); + } + + return ret; + } + } + + /// + /// For testing purposes, the size of the initial block. + /// + /// Does not care if the block is free or allocated. + /// + internal int FirstBlockSizeBytes + { + get + { + ref var dataStartRef = ref GetDataStartRef(); + + ref var firstBlock = ref Unsafe.As(ref dataStartRef); + + return firstBlock.SizeBytes; + } + } + + internal LuaLimitedManagedAllocator(int backingArraySize) + { + Debug.Assert(backingArraySize >= BlockHeader.StructSizeBytes, "Too small to ever allocate"); + + // Pinned because Lua is oblivious to .NET's GC + data = GC.AllocateUninitializedArray(backingArraySize, pinned: true); + unsafe + { + dataStartPtr = Unsafe.AsPointer(ref MemoryMarshal.GetReference(data.Span)); + } + + ref var dataStartRef = ref GetDataStartRef(); + + // Initialize the first block as a free block covering the whole allocation + unsafe + { + freeListStartPtr = Unsafe.AsPointer(ref dataStartRef); + } + ref var firstBlock = ref GetFreeList(); + firstBlock.SizeBytesRaw = backingArraySize - BlockHeader.DataOffset; + firstBlock.NextFreeBlockRefRaw = 0; + firstBlock.PrevFreeBlockRefRaw = 0; + } + + /// + /// Allocate a new chunk which can fit at least of data. + /// + /// Sets to true if the allocation failed. + /// + /// If the allocation failes, the returned ref will be null. + /// + public ref byte AllocateNew(int sizeBytes, out bool failed) + { + ref var dataStartRef = ref GetDataStartRef(); + + var actualSizeBytes = RoundToMinAlloc(sizeBytes); + + var firstAttempt = true; + + tryAgain: + ref var freeList = ref GetFreeList(); + ref var cur = ref freeList; + while (IsValidBlockRef(ref cur)) + { + Debug.Assert(cur.IsFree, "Free list corrupted"); + + // We know this block is free, so touch size directly + var freeBlockSize = cur.SizeBytesRaw; + + if (freeBlockSize < actualSizeBytes) + { + // Couldn't fit in the block, move on + cur = ref cur.GetNextFreeBlockRef(); + continue; + } + + if (ShouldSplit(ref cur, actualSizeBytes)) + { + ref var newBlock = ref SplitFreeBlock(ref dataStartRef, ref cur, actualSizeBytes); + Debug.Assert(cur.IsFree && cur.SizeBytes == actualSizeBytes, "Split produced unexpected block"); + + freeBlockSize = actualSizeBytes; + + // We know both these blocks are free, so use size directly + if (!Unsafe.AreSame(ref cur, ref freeList) && newBlock.SizeBytesRaw > freeList.SizeBytesRaw) + { + // Move the split block to the head of the list if it's large enough AND we had to traverse + // the free list at all + // + // Idea is to keep bigger things towards the front of the free list if we're spending any time + // chasing pointers + MoveToHeadOfFreeList(ref dataStartRef, ref freeList, ref newBlock); + } + } + + // Cur will work, so remove it from the free list, mark it, and return it + RemoveFromFreeList(ref dataStartRef, ref cur); + cur.MarkInUse(); + + UpdateDebugAllocatedBytes(freeBlockSize); + + failed = false; + return ref cur.DataReference; + } + + // Expensively compact if we're going to fail anyway + if (firstAttempt) + { + firstAttempt = false; + if (TryCoalesceAllFreeBlocks()) + { + goto tryAgain; + } + } + + // Even after compaction we failed to find a large enough block + failed = true; + return ref Unsafe.NullRef(); + } + + /// + /// Return a chunk of memory previously acquired by or + /// . + /// + /// Previously returned (non-null) value. + /// Size passed to last or call. + public void Free(ref byte start, int sizeBytes) + { + ref var dataStartRef = ref GetDataStartRef(); + + ref var blockRef = ref GetBlockRef(ref dataStartRef, ref start); + + Debug.Assert(blockRef.IsInUse, "Should be in use"); + + blockRef.MarkFree(); + AddToFreeList(ref dataStartRef, ref blockRef); + Debug.Assert(blockRef.IsFree, "Should be free by this point"); + + // We know the block is free, so touch size directly + UpdateDebugAllocatedBytes(-blockRef.SizeBytesRaw); + } + + /// + /// Akin to , except reuses the original allocation given in if possible. + /// + public ref byte ResizeAllocation(ref byte start, int oldSizeBytes, int newSizeBytes, out bool failed) + { + ref var dataStartRef = ref GetDataStartRef(); + + ref var curBlock = ref GetBlockRef(ref dataStartRef, ref start); + Debug.Assert(curBlock.IsInUse, "Shouldn't be resizing an allocation that isn't in use"); + + // For everything else, move things up to a reasonable multiple + var actualSizeBytes = RoundToMinAlloc(newSizeBytes); + + // We know the block is in use, so use raw size directly + if (-curBlock.SizeBytesRaw >= newSizeBytes) + { + // Existing allocation is large enough + + if (ShouldSplit(ref curBlock, actualSizeBytes)) + { + // Move the unused space into a new block + + SplitInUseBlock(ref dataStartRef, ref curBlock, actualSizeBytes); + + // And try and coalesce that new block into the largest possible one + ref var nextBlock = ref curBlock.GetNextAdjacentBlockRef(ref dataStartRef, data.Length); + if (IsValidBlockRef(ref nextBlock)) + { + while (TryCoalesceSingleBlock(ref dataStartRef, ref nextBlock, out _)) + { + } + } + } + + failed = false; + return ref start; + } + + // Attempt to grow the allocation in place + var keepInPlace = false; + + while (TryCoalesceSingleBlock(ref dataStartRef, ref curBlock, out var updatedCurBlockSizeBytes)) + { + if (updatedCurBlockSizeBytes >= actualSizeBytes) + { + keepInPlace = true; + break; + } + } + + if (keepInPlace) + { + // We built a big enough block, so use it + if (ShouldSplit(ref curBlock, actualSizeBytes)) + { + // We coalesced such that there's a lot of empty space at the end of this block, peel it off for later reuse + ref var newBlock = ref SplitInUseBlock(ref dataStartRef, ref curBlock, actualSizeBytes); + Debug.Assert(curBlock.IsInUse && curBlock.SizeBytes == actualSizeBytes, "Split produced unexpected block"); + + ref var freeList = ref GetFreeList(); + + // We know freeList is valid and both blocks are free, so check size directly + if (!Unsafe.AreSame(ref newBlock, ref freeList) && newBlock.SizeBytesRaw > freeList.SizeBytesRaw) + { + MoveToHeadOfFreeList(ref dataStartRef, ref freeList, ref newBlock); + } + } + + failed = false; + return ref start; + } + + // We couldn't resize in place, so we need to copy into a new alloc + ref var newAlloc = ref AllocateNew(newSizeBytes, out failed); + if (failed) + { + // Couldn't get a new allocation - per spec this leaves the old allocation alone + return ref Unsafe.NullRef(); + } + + // Copy the data over + var copyLen = newSizeBytes < oldSizeBytes ? newSizeBytes : oldSizeBytes; + + var copyFrom = MemoryMarshal.CreateReadOnlySpan(ref curBlock.DataReference, copyLen); + var copyInto = MemoryMarshal.CreateSpan(ref newAlloc, copyLen); + + copyFrom.CopyTo(copyInto); + + // Free the old alloc now that data is copied out + Free(ref start, oldSizeBytes); + + return ref newAlloc; + } + + /// + /// Returns true if this reference might have been handed out by this allocator. + /// + internal bool ContainsRef(ref byte startRef) + { + ref var dataStartRef = ref GetDataStartRef(); + + var delta = Unsafe.ByteOffset(ref dataStartRef, ref startRef); + + return delta >= 0 && delta < data.Length; + } + + /// + /// Validate the allocator. + /// + /// For testing purposes only. + /// + internal void CheckCorrectness() + { + ref var dataStartRef = ref GetDataStartRef(); + + // Check for cycles in free lists + { + // Basic tortoise and hare: + // - freeSlow moves forward 1 link at a time + // - freeFast moves forward 2 links at a time + // - if freeSlow == freeFast then we have a cycle + + ref var freeSlow = ref GetFreeList(); + ref var freeFast = ref freeSlow; + if (IsValidBlockRef(ref freeFast)) + { + freeFast = ref freeFast.GetNextFreeBlockRef(); + } + + while (IsValidBlockRef(ref freeSlow) && IsValidBlockRef(ref freeFast)) + { + freeSlow = ref freeSlow.GetNextFreeBlockRef(); + freeFast = ref freeFast.GetNextFreeBlockRef(); + if (IsValidBlockRef(ref freeFast)) + { + freeFast = ref freeFast.GetNextFreeBlockRef(); + } + + Check(!Unsafe.AreSame(ref freeSlow, ref freeFast), "Cycle exists in free list"); + } + } + + var walkFreeBlocks = 0; + var walkFreeBytes = 0L; + + // Walk the free list, counting and check that pointer make sense + { + ref var prevFree = ref Unsafe.NullRef(); + ref var curFree = ref GetFreeList(); + while (IsValidBlockRef(ref curFree)) + { + Check(curFree.IsFree, "Allocated block in free list"); + Check(ContainsRef(ref Unsafe.As(ref curFree)), "Free block not in managed bounds"); + + walkFreeBlocks++; + + Check(curFree.SizeBytes > 0, "Illegal size for a free block"); + + walkFreeBytes += curFree.SizeBytes; + + if (IsValidBlockRef(ref prevFree)) + { + ref var prevFromCur = ref curFree.GetPrevFreeBlockRef(); + Check(Unsafe.AreSame(ref prevFree, ref prevFromCur), "Prev link invalid"); + } + + prevFree = ref curFree; + curFree = ref curFree.GetNextFreeBlockRef(); + } + } + + var scanFreeBlocks = 0; + var scanFreeBytes = 0L; + var scanAllocatedBlocks = 0; + var scanAllocatedBytes = 0L; + + // Scan the whole array, counting free and allocated blocks + { + ref var cur = ref Unsafe.As(ref dataStartRef); + while (IsValidBlockRef(ref cur)) + { + Check(ContainsRef(ref Unsafe.As(ref cur)), "Free block not in managed bounds"); + + if (cur.IsFree) + { + scanFreeBlocks++; + scanFreeBytes += cur.SizeBytes; + } + else + { + Check(cur.IsInUse, "Illegal block state"); + + scanAllocatedBlocks++; + scanAllocatedBytes += cur.SizeBytes; + } + + cur = ref cur.GetNextAdjacentBlockRef(ref dataStartRef, data.Length); + } + } + + Check(scanFreeBlocks == walkFreeBlocks, "Free block mismatch"); + Check(scanFreeBytes == walkFreeBytes, "Free bytes mismatch"); + + DebugCheck(scanAllocatedBytes == AllocatedBytes, "Allocated bytes mismatch"); + + var totalBlocks = scanAllocatedBlocks + scanFreeBlocks; + var totalBytes = scanAllocatedBytes + scanFreeBytes; + + var expectedOverhead = totalBlocks * BlockHeader.DataOffset; + + var allBytes = totalBytes + expectedOverhead; + Check(allBytes == data.Length, "Bytes unaccounted for"); + + // Throws if shouldBe is false + static void Check(bool shouldBe, string errorMsg, [CallerArgumentExpression(nameof(shouldBe))] string shouldBeExpr = null) + { + if (shouldBe) + { + return; + } + + throw new InvalidOperationException($"Check failed: {errorMsg} ({shouldBeExpr})"); + } + + // Throws if shouldBe is false, but only in DEBUG builds + [Conditional("DEBUG")] + static void DebugCheck(bool shouldBe, string errorMsg, [CallerArgumentExpression(nameof(shouldBe))] string shouldBeExpr = null) + => Check(shouldBe, errorMsg, shouldBeExpr); + } + + /// + /// Do a very expensive pass attempting to coalesce free blocks as much as possible. + /// + internal bool TryCoalesceAllFreeBlocks() + { + ref var dataStartRef = ref GetDataStartRef(); + + var madeProgress = false; + ref var cur = ref GetFreeList(); + while (IsValidBlockRef(ref cur)) + { + // Coalesce this block repeatedly, so runs of free blocks are collapsed into one + while (TryCoalesceSingleBlock(ref dataStartRef, ref cur, out _)) + { + madeProgress = true; + } + + cur = ref cur.GetNextFreeBlockRef(); + } + + return madeProgress; + } + + /// + /// Add a the given free block to the free list. + /// + private void AddToFreeList(ref byte dataStartRef, ref BlockHeader block) + { + Debug.Assert(block.IsFree, "Can only add free blocks to list"); + + ref var freeListStartRef = ref GetFreeList(); + if (!IsValidBlockRef(ref freeListStartRef)) + { + // Free list is empty + + block.PrevFreeBlockRefRaw = 0; + block.NextFreeBlockRefRaw = 0; + + unsafe + { + freeListStartPtr = Unsafe.AsPointer(ref block); + } + } + else + { + Debug.Assert(freeListStartRef.IsFree, "Free list is corrupted"); + + // Free list isn't empty - block can go in either the head position + // or the one-past-the-head position depending on if block is bigger + // than the free list head + + var freeListStartRefVal = freeListStartRef.GetRefVal(); + var blockRefVal = block.GetRefVal(); + + // We know both blocks are free, so use size directly + if (block.SizeBytesRaw >= freeListStartRef.SizeBytesRaw) + { + // Block is larger, prefer allocating out of it + + // Move block to head of list, and old head to immediately after block + freeListStartRef.PrevFreeBlockRefRaw = blockRefVal; + block.NextFreeBlockRefRaw = freeListStartRefVal; + block.PrevFreeBlockRefRaw = 0; + + // Block is new head of free list + unsafe + { + freeListStartPtr = Unsafe.AsPointer(ref block); + } + } + else + { + // Block is smaller, it's a second choice for allocations + + ref var oldFreeListNextBlock = ref freeListStartRef.GetNextFreeBlockRef(); + + // Move block to right after free list head + block.PrevFreeBlockRefRaw = freeListStartRefVal; + block.NextFreeBlockRefRaw = freeListStartRef.NextFreeBlockRef; + freeListStartRef.NextFreeBlockRefRaw = blockRefVal; + + if (IsValidBlockRef(ref oldFreeListNextBlock)) + { + // Update back-pointer in next block to point to block + oldFreeListNextBlock.PrevFreeBlockRefRaw = blockRefVal; + } + + // Head of free list is unchanged + } + } + } + + /// + /// Removes the given block from the free list. + /// + private void RemoveFromFreeList(ref byte dataStartRef, ref BlockHeader block) + { + Debug.Assert(IsValidBlockRef(ref GetFreeList()), "Shouldn't be removing from free list if free list is empty"); + Debug.Assert(block.IsFree, "Only valid for free blocks"); + + var blockRefVal = block.GetRefVal(); + + ref var prevFreeBlock = ref block.GetPrevFreeBlockRef(); + ref var nextFreeBlock = ref block.GetNextFreeBlockRef(); + + // We've got a few different states we could be in here: + // 1. block is the only thing in the free list + // => freeList == block, prevFreeBlock == null, nextFreeBlock == null + // 2. block is the first thing in the free list, but not the only thing + // => freeList == block, prevFreeBlock == null, nextFreeBlock != null, nextFreeBLock.prev == block + // 3. block is the last thing in the free list, but not the first + // => freeList is valid, freeListStart != block, prevFreeBlock != null, prevFreeBlock.next == block, nextFreeBlock == null + // 4. block is in the middle of the list somewhere + // => freeList is valid, freeListStart != block, prevFreeBlock != null, prefFreeBlock.next == block, nextFreeBlock != null, nextFreeBlock.prev = next + + ref var freeList = ref GetFreeList(); + + if (Unsafe.AreSame(ref freeList, ref block)) + { + Debug.Assert(!IsValidBlockRef(ref prevFreeBlock), "Should be no prev pointer if block is head of free list"); + + if (!IsValidBlockRef(ref nextFreeBlock)) + { + // We're in state #1 - block is the only thing in the free list + + // Remove this last block from the free list, leaving it empty + unsafe + { + // For simplicity, treat empty free lists the same way we do end of the free list + freeListStartPtr = (void*)0; + } + return; + } + else + { + // We're in state #2 - block is first thing in the free list, but not the last + Debug.Assert(nextFreeBlock.PrevFreeBlockRef == blockRefVal, "Broken chain in free list"); + + // NextFreeBlock is new head, it needs to prev point now + nextFreeBlock.PrevFreeBlockRefRaw = 0; + unsafe + { + freeListStartPtr = (void*)block.NextFreeBlockRef; + } + return; + } + } + else + { + Debug.Assert(IsValidBlockRef(ref prevFreeBlock), "Should always have a prev pointer if not head of free list"); + + if (!IsValidBlockRef(ref nextFreeBlock)) + { + // We're in state #3 - block is last thing in the free list, but not the first + Debug.Assert(prevFreeBlock.NextFreeBlockRef == blockRefVal, "Broken chain in free list"); + + // Remove pointer to this block from it's preceeding block + prevFreeBlock.NextFreeBlockRefRaw = 0; + return; + } + else + { + // We're in state #4 - block is just in the middle of the free list somewhere, but not last or first + Debug.Assert(prevFreeBlock.NextFreeBlockRef == blockRefVal, "Broken chain in free list"); + Debug.Assert(nextFreeBlock.PrevFreeBlockRef == blockRefVal, "Broken chain in free list"); + + // Prev needs to skip this block when going forward + prevFreeBlock.NextFreeBlockRefRaw = block.NextFreeBlockRef; + + // Next needs to skip this block when going back + nextFreeBlock.PrevFreeBlockRefRaw = block.PrevFreeBlockRef; + return; + } + } + } + + /// + /// Move this given block to the head of the free list. + /// + /// This assumes that is not already at the head of the free list, and is not + /// empty. + /// + private void MoveToHeadOfFreeList(ref byte dataStartRef, ref BlockHeader freeList, ref BlockHeader block) + { + Debug.Assert(block.IsFree, "Should be free"); + Debug.Assert(freeList.IsFree, "Free list corrupted"); + Debug.Assert(!IsValidBlockRef(ref freeList.GetPrevFreeBlockRef()), "Free list corrupted"); + Debug.Assert(IsValidBlockRef(ref freeList.GetNextFreeBlockRef()), "Free list corrupted"); + + // Because block is not head of the free list, we have two cases here + // 1. Block is in the middle of the list somewhere + // 2. Block is the tail of the list + + ref var prevBlock = ref block.GetPrevFreeBlockRef(); + Debug.Assert(IsValidBlockRef(ref prevBlock), "Block shouldn't be head of free list"); + + ref var nextBlock = ref block.GetNextFreeBlockRef(); + if (IsValidBlockRef(ref nextBlock)) + { + // Case 1 - block is in the middle of the list + + // Update prev.next so it points to block.next and next.prev so it points to block.prev + prevBlock.NextFreeBlockRefRaw = block.NextFreeBlockRef; + nextBlock.PrevFreeBlockRefRaw = block.PrevFreeBlockRef; + } + else + { + // Case 2 - block is at the end of the list + + // Remove prev's pointer to block + prevBlock.NextFreeBlockRefRaw = 0; + } + + // Move block to the head of the list + freeList.PrevFreeBlockRefRaw = block.GetRefVal(); + block.NextFreeBlockRefRaw = freeList.GetRefVal(); + block.PrevFreeBlockRefRaw = 0; + + // Update freeListStartPtr to refer to block + unsafe + { + freeListStartPtr = Unsafe.AsPointer(ref block); + } + } + + /// + /// Attempt to coalesce a block with its adjacent block. + /// + /// can be free or allocated, but coalescing will only succeed + /// if the adjacent block is free. + /// + /// is set only if the return is true. + /// + private bool TryCoalesceSingleBlock(ref byte dataStartRef, ref BlockHeader block, out int newBlockSizeBytes) + { + ref var nextBlock = ref block.GetNextAdjacentBlockRef(ref dataStartRef, data.Length); + if (!IsValidBlockRef(ref nextBlock) || !nextBlock.IsFree) + { + Unsafe.SkipInit(out newBlockSizeBytes); + return false; + } + + // We know nextBlock is free, so touch size directly + var nextBlockSizeBytes = nextBlock.SizeBytesRaw; + + RemoveFromFreeList(ref dataStartRef, ref nextBlock); + + if (block.IsFree) + { + newBlockSizeBytes = nextBlockSizeBytes + block.SizeBytesRaw + BlockHeader.DataOffset; + + block.SizeBytesRaw = newBlockSizeBytes; + } + else + { + // Because we merged a free block into an allocated one we need to update the allocated byte total + + // We know the block is in use, so block.SizeBytesRaw is -SizeBytes + newBlockSizeBytes = nextBlockSizeBytes - block.SizeBytesRaw + BlockHeader.DataOffset; + UpdateDebugAllocatedBytes(block.SizeBytesRaw); + + block.SizeBytesRaw = -newBlockSizeBytes; + + UpdateDebugAllocatedBytes(newBlockSizeBytes); + } + + return true; + } + + /// + /// Split an in use block, such that the current block ends up with a size equal to . + /// + /// Returns a reference to the NEW free block. + /// + private ref BlockHeader SplitInUseBlock(ref byte dataStartRef, ref BlockHeader curBlock, int curBlockUpdateSizeBytes) + { + Debug.Assert(curBlock.IsInUse, "Only valid for in use blocks"); + + // We know the block is in use, so touch size bytes directly + var oldSizeBytes = -curBlock.SizeBytesRaw; + + ref var newBlock = ref SplitCommon(ref curBlock, curBlockUpdateSizeBytes); + + // New block needs to be placed in free list + AddToFreeList(ref dataStartRef, ref newBlock); + + // Because we split some bytes out of an allocated block, that means we need to remove those from allocation tracking + UpdateDebugAllocatedBytes(-oldSizeBytes); + + // We know the block is in use, so -SizeBytesRaw will be SizeBytes (unconditionally) + UpdateDebugAllocatedBytes(-curBlock.SizeBytesRaw); + + return ref newBlock; + } + + /// + /// Split a free block such that the current block ends up with a size equal to . + /// + /// Returns a reference to the NEW block. + /// + private ref BlockHeader SplitFreeBlock(ref byte dataStartRef, ref BlockHeader curBlock, int curBlockUpdateSizeBytes) + { + Debug.Assert(curBlock.IsFree, "Only valid for free blocks"); + + ref var oldNextBlock = ref curBlock.GetNextFreeBlockRef(); + + ref var newBlock = ref SplitCommon(ref curBlock, curBlockUpdateSizeBytes); + + var curBlockRefVal = curBlock.GetRefVal(); + + // Update newBlock + newBlock.PrevFreeBlockRefRaw = curBlockRefVal; + newBlock.NextFreeBlockRefRaw = curBlock.NextFreeBlockRef; + Debug.Assert(newBlock.IsFree, "New block shoud be free"); + Debug.Assert(ContainsRef(ref Unsafe.As(ref newBlock)), "New block out of managed memory"); + + var newBlockRefVal = newBlock.GetRefVal(); + + // Update curBlock + curBlock.NextFreeBlockRefRaw = newBlockRefVal; + Debug.Assert(curBlock.IsFree, "New block shoud be free"); + Debug.Assert(ContainsRef(ref Unsafe.As(ref curBlock)), "Split block out of managed memory"); + + // Update the old next block if it exists + if (IsValidBlockRef(ref oldNextBlock)) + { + Debug.Assert(oldNextBlock.IsFree, "Should have been in free list"); + oldNextBlock.PrevFreeBlockRefRaw = newBlockRefVal; + } + + return ref newBlock; + } + + /// + /// Grab the start of the managed memory we're allocating out of. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private unsafe ref byte GetDataStartRef() + => ref Unsafe.AsRef(dataStartPtr); + + /// + /// Get the start of the free list. + /// + /// If the free list is empty, returns a null ref. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private unsafe ref BlockHeader GetFreeList() + => ref Unsafe.AsRef(freeListStartPtr); + + /// + /// Turn a reference obtained from back into a reference. + /// + private ref BlockHeader GetBlockRef(ref byte dataStartRef, ref byte userDataRef) + { + Debug.Assert(!Unsafe.AreSame(ref dataStartRef, ref userDataRef), "User data is actually 0 size alloc, that doesn't make sense"); + + ref var blockStartByteRef = ref Unsafe.Add(ref userDataRef, -BlockHeader.DataOffset); + + Debug.Assert(!Unsafe.IsAddressLessThan(ref blockStartByteRef, ref dataStartRef), "User data is before managed memory"); + Debug.Assert(!Unsafe.IsAddressGreaterThan(ref blockStartByteRef, ref Unsafe.Add(ref dataStartRef, data.Length - 1)), "User data is after managed memory"); + + return ref Unsafe.As(ref blockStartByteRef); + } + + /// + /// In DEBUG builds, keep track of how many bytes we've allocated for testing purposes. + /// + [Conditional("DEBUG")] + private void UpdateDebugAllocatedBytes(int by) + => debugAllocatedBytes += by; + + /// + /// Common logic for splitting blocks. + /// + /// Block here can either be free or in use, so don't make any assumptionsin here. + /// + private static ref BlockHeader SplitCommon(ref BlockHeader curBlock, int curBlockUpdateSizeBytes) + { + Debug.Assert(curBlockUpdateSizeBytes >= LuaAllocMinSizeBytes, "Shouldn't split an existing block to be this small"); + + ref var curBlockData = ref curBlock.DataReference; + ref var newBlockStartByteRef = ref Unsafe.AddByteOffset(ref curBlockData, curBlockUpdateSizeBytes); + ref var newBlock = ref Unsafe.As(ref newBlockStartByteRef); + + int newBlockSizeBytes; + if (curBlock.IsFree) + { + // We know curBlock is free, so touch SizeBytesRaw directly + newBlockSizeBytes = curBlock.SizeBytesRaw - BlockHeader.DataOffset - curBlockUpdateSizeBytes; + + curBlock.SizeBytesRaw = curBlockUpdateSizeBytes; + } + else + { + Debug.Assert(curBlock.IsInUse, "Invalid block state"); + + // We know curBlock is is inuse, so touch SizeBytesRaw directly + newBlockSizeBytes = -curBlock.SizeBytesRaw - BlockHeader.DataOffset - curBlockUpdateSizeBytes; + + curBlock.SizeBytesRaw = -curBlockUpdateSizeBytes; + } + + Debug.Assert(newBlockSizeBytes >= LuaAllocMinSizeBytes, "Shouldn't create a new block this small"); + + // The new block is always free, so positive size + newBlock.SizeBytesRaw = newBlockSizeBytes; + + + return ref newBlock; + } + + /// + /// Check if a block should be split if it's used to serve a claim of the given size. + /// + private static bool ShouldSplit(ref BlockHeader block, int claimedBytes) + { + var unusedBytes = block.SizeBytes - claimedBytes; + + return unusedBytes >= BlockHeader.StructSizeBytes; + } + + /// + /// Turn requested bytes into the actual number of bytes we're going to reserve. + /// + private static int RoundToMinAlloc(int sizeBytes) + { + Debug.Assert(BitOperations.IsPow2(LuaAllocMinSizeBytes), "Assumes min allocation is a power of 2"); + + // To avoid special casing 0, we can reserve up to an extra LuaAllocMinSizeBytes + var ret = (sizeBytes + LuaAllocMinSizeBytes) & ~(LuaAllocMinSizeBytes - 1); + + Debug.Assert(ret > 0, "Rounding logic invalid - not positive"); + Debug.Assert(ret >= sizeBytes, "Rounding logic invalid - did not round up"); + Debug.Assert(ret % LuaAllocMinSizeBytes == 0, "Rounding logic invalid - not a whole multiple of step size"); + + return ret; + } + + /// + /// Check if is valid. + /// + /// Pulled out to indicate intent, it's basically just a null check + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe bool IsValidBlockRef(ref BlockHeader blockRef) + => Unsafe.AsPointer(ref blockRef) != (void*)0; + } +} \ No newline at end of file diff --git a/libs/server/Lua/LuaManagedAllocator.cs b/libs/server/Lua/LuaManagedAllocator.cs new file mode 100644 index 0000000000..72d777143e --- /dev/null +++ b/libs/server/Lua/LuaManagedAllocator.cs @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Numerics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Garnet.server +{ + /// + /// Provides a mapping of Lua allocations on to the POH. + /// + /// Unlike , allocations are not limited. + /// + /// + /// This is implemented in terms of because + /// it is expected to be ununusual to want allocations on the POH but unlimitted. + /// + internal sealed class LuaManagedAllocator : ILuaAllocator + { + private const int DefaultAllocationBytes = 2 * 1_024 * 1_024; + + private readonly List subAllocators = []; + public LuaManagedAllocator() { } + + /// + public ref byte AllocateNew(int sizeBytes, out bool failed) + { + for (var i = 0; i < subAllocators.Count; i++) + { + var alloc = subAllocators[i]; + + ref var ret = ref alloc.AllocateNew(sizeBytes, out failed); + if (!failed) + { + return ref ret; + } + } + + var newAllocSize = DefaultAllocationBytes; + + // Need to account for overhead in LuaLimitedManagedAllocator blocks + var minAllocSize = sizeBytes + sizeof(int); + if (newAllocSize < minAllocSize) + { + newAllocSize = (int)BitOperations.RoundUpToPowerOf2((nuint)minAllocSize); + if (newAllocSize < 0) + { + // Handle overflow + failed = true; + return ref Unsafe.NullRef(); + } + } + + var newAlloc = new LuaLimitedManagedAllocator(newAllocSize); + subAllocators.Add(newAlloc); + + ref var newRet = ref newAlloc.AllocateNew(sizeBytes, out failed); + Debug.Assert(!failed, "Failure shouldn't be possible with new alloc"); + + return ref newRet; + } + + /// + public void Free(ref byte start, int sizeBytes) + { + for (var i = 0; i < subAllocators.Count; i++) + { + var alloc = subAllocators[i]; + + if (alloc.ContainsRef(ref start)) + { + alloc.Free(ref start, sizeBytes); + + return; + } + } + + throw new InvalidOperationException("Allocation could not be found"); + } + + /// + public ref byte ResizeAllocation(ref byte start, int oldSizeBytes, int newSizeBytes, out bool failed) + { + for (var i = 0; i < subAllocators.Count; i++) + { + var alloc = subAllocators[i]; + + if (alloc.ContainsRef(ref start)) + { + ref var ret = ref alloc.ResizeAllocation(ref start, oldSizeBytes, newSizeBytes, out failed); + if (!failed) + { + return ref ret; + } + + // Have to make a copy into a new allocation + ref var newAlloc = ref AllocateNew(newSizeBytes, out failed); + if (failed) + { + return ref Unsafe.NullRef(); + } + + var copyLen = newSizeBytes < oldSizeBytes ? newSizeBytes : oldSizeBytes; + var from = MemoryMarshal.CreateReadOnlySpan(ref start, copyLen); + var to = MemoryMarshal.CreateSpan(ref newAlloc, copyLen); + from.CopyTo(to); + + // Release the old allocation + alloc.Free(ref start, oldSizeBytes); + + failed = false; + return ref newAlloc; + } + } + + throw new InvalidOperationException("Allocation could not be found"); + } + } +} \ No newline at end of file diff --git a/libs/server/Lua/LuaOptions.cs b/libs/server/Lua/LuaOptions.cs new file mode 100644 index 0000000000..395d20ff38 --- /dev/null +++ b/libs/server/Lua/LuaOptions.cs @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System.Runtime.InteropServices; +using Microsoft.Extensions.Logging; + +namespace Garnet.server +{ + /// + /// Options for Lua scripting. + /// + public sealed class LuaOptions + { + private readonly ILogger logger; + + public LuaMemoryManagementMode MemoryManagementMode = LuaMemoryManagementMode.Native; + public string MemoryLimit = ""; + + /// + /// Construct options with default options. + /// + public LuaOptions(ILogger logger = null) + { + this.logger = logger; + } + + /// + /// Construct options with specific settings. + /// + public LuaOptions(LuaMemoryManagementMode memoryMode, string memoryLimit, ILogger logger = null) : this(logger) + { + MemoryManagementMode = memoryMode; + MemoryLimit = memoryLimit; + } + + /// + /// Get the memory limit, if any, for each script invocation. + /// + internal int? GetMemoryLimitBytes() + { + if (string.IsNullOrEmpty(MemoryLimit)) + { + return null; + } + + if (MemoryManagementMode == LuaMemoryManagementMode.Native) + { + logger?.LogWarning("Lua script memory limit is ignored when mode = {MemoryManagementMode}", MemoryManagementMode); + return null; + } + + var ret = GarnetServerOptions.ParseSize(MemoryLimit); + if (ret is > int.MaxValue or < 1_024) + { + logger?.LogWarning("Lua script memory limit is out of range [1K, 2GB] = {MemoryLimit} and will be ignored", MemoryLimit); + return null; + } + + return (int)ret; + } + } + + /// + /// Different Lua supported memory modes. + /// + public enum LuaMemoryManagementMode + { + /// + /// Uses default Lua allocator - .NET host is unaware of allocations. + /// + Native = 0, + + /// + /// Uses and informs .NET host of the allocations. + /// + /// Limits are inexactly applied due to native memory allocation overhead. + /// + Tracked = 1, + + /// + /// Places allocations on the POH using a naive, free-list based, allocator. + /// + /// Limits are pre-allocated when scripts runs, which can increase allocation pressure. + /// + Managed = 2, + } +} \ No newline at end of file diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 38851209da..d9525df2da 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -74,7 +74,6 @@ private unsafe struct RunnerAdapter : IResponseAdapter private readonly ScratchBufferManager bufferManager; private byte* origin; private byte* curHead; - private byte* curEnd; internal RunnerAdapter(ScratchBufferManager bufferManager) { @@ -84,7 +83,7 @@ internal RunnerAdapter(ScratchBufferManager bufferManager) var scratchSpace = bufferManager.FullBuffer(); origin = curHead = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(scratchSpace)); - curEnd = curHead + scratchSpace.Length; + BufferEnd = curHead + scratchSpace.Length; } #pragma warning disable CS9084 // Struct member returns 'this' or other instance members by reference @@ -94,8 +93,7 @@ public unsafe ref byte* BufferCur #pragma warning restore CS9084 /// - public unsafe byte* BufferEnd - => curEnd; + public unsafe byte* BufferEnd { get; private set; } /// /// Gets a span that covers the responses as written so far. @@ -118,28 +116,21 @@ public void SendAndReset() var len = (int)(curHead - origin); // We don't actually send anywhere, we grow the backing array - bufferManager.GrowBuffer(); + // + // Since we're managing the start/end pointers outside of the buffer + // we need to signal that the buffer has data to copy + bufferManager.GrowBuffer(copyLengthOverride: len); var scratchSpace = bufferManager.FullBuffer(); origin = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(scratchSpace)); - curEnd = origin + scratchSpace.Length; + BufferEnd = origin + scratchSpace.Length; curHead = origin + len; } } const string LoaderBlock = @" import = function () end -redis = {} -function redis.call(...) - return garnet_call(...) -end -function redis.status_reply(text) - return text -end -function redis.error_reply(text) - return { err = 'ERR ' .. text } -end KEYS = {} ARGV = {} sandbox_env = { @@ -163,7 +154,6 @@ function redis.error_reply(text) rawequal = rawequal; rawget = rawget; rawset = rawset; - redis = redis; select = select; -- explicitly not allowing setfenv string = string; @@ -180,64 +170,62 @@ function redis.error_reply(text) } -- do resets in the Lua side to minimize pinvokes function reset_keys_and_argv(fromKey, fromArgv) - local keyCount = #KEYS + local keyRef = KEYS + local keyCount = #keyRef for i = fromKey, keyCount do - KEYS[i] = nil + table.remove(keyRef) end - local argvCount = #ARGV + local argvRef = ARGV + local argvCount = #argvRef for i = fromArgv, argvCount do - ARGV[i] = nil + table.remove(argvRef) end end -- responsible for sandboxing user provided code function load_sandboxed(source) - if (not source) then return nil end - local rawFunc, err = load(source, nil, nil, sandbox_env) + -- move into a local to avoid global lookup + local garnetCallRef = garnet_call - -- compilation error is returned directly - if err then - return rawFunc, err - end + sandbox_env.redis = { + status_reply = function(text) + return text + end, - -- otherwise we wrap the compiled function in a helper - return function() - local rawRet = rawFunc() + error_reply = function(text) + return { err = 'ERR ' .. text } + end, - -- handle ok response wrappers without crossing the pinvoke boundary - -- err response wrapper requires a bit more work, but is also rarer - if rawRet and type(rawRet) == ""table"" and rawRet.ok then - return rawRet.ok + call = function(...) + return garnetCallRef(...) end + } - return rawRet - end + local rawFunc, err = load(source, nil, nil, sandbox_env) + + return err, rawFunc end "; private static readonly ReadOnlyMemory LoaderBlockBytes = Encoding.UTF8.GetBytes(LoaderBlock); - // Rooted to keep function pointer alive - readonly LuaFunction garnetCall; - // References into Registry on the Lua side // // These are mix of objects we regularly update, // constants we want to avoid copying from .NET to Lua, // and the compiled function definition. - readonly int sandboxEnvRegistryIndex; readonly int keysTableRegistryIndex; readonly int argvTableRegistryIndex; readonly int loadSandboxedRegistryIndex; readonly int resetKeysAndArgvRegistryIndex; readonly int okConstStringRegistryIndex; + readonly int okLowerConstStringRegistryIndex; readonly int errConstStringRegistryIndex; readonly int noSessionAvailableConstStringRegistryIndex; readonly int pleaseSpecifyRedisCallConstStringRegistryIndex; readonly int errNoAuthConstStringRegistryIndex; readonly int errUnknownConstStringRegistryIndex; readonly int errBadArgConstStringRegistryIndex; - int functionRegistryIndex; readonly ReadOnlyMemory source; readonly ScratchBufferNetworkSender scratchBufferNetworkSender; @@ -248,41 +236,61 @@ function load_sandboxed(source) readonly TxnKeyEntries txnKeyEntries; readonly bool txnMode; + // The Lua registry index under which the user supplied function is stored post-compilation + int functionRegistryIndex; + // This cannot be readonly, as it is a mutable struct LuaStateWrapper state; + // We need to temporarily store these for P/Invoke reasons + // You shouldn't be touching them outside of the Compile and Run methods + + RunnerAdapter runnerAdapter; + RespResponseAdapter sessionAdapter; + RespServerSession preambleOuterSession; + int preambleKeyAndArgvCount; + int preambleNKeys; + string[] preambleKeys; + string[] preambleArgv; + int keyLength, argvLength; /// /// Creates a new runner with the source of the script /// - public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null) + public unsafe LuaRunner(LuaMemoryManagementMode memMode, int? memLimitBytes, ReadOnlyMemory source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null) { this.source = source; this.txnMode = txnMode; this.respServerSession = respServerSession; this.scratchBufferNetworkSender = scratchBufferNetworkSender; - this.scratchBufferManager = respServerSession?.scratchBufferManager ?? new(); this.logger = logger; - sandboxEnvRegistryIndex = -1; + scratchBufferManager = respServerSession?.scratchBufferManager ?? new(); + keysTableRegistryIndex = -1; argvTableRegistryIndex = -1; loadSandboxedRegistryIndex = -1; functionRegistryIndex = -1; - // TODO: custom allocator? - state = new LuaStateWrapper(new Lua()); + state = new LuaStateWrapper(memMode, memLimitBytes, this.logger); + delegate* unmanaged[Cdecl] garnetCall; if (txnMode) { txnKeyEntries = new TxnKeyEntries(16, respServerSession.storageSession.lockableContext, respServerSession.storageSession.objectStoreLockableContext); - garnetCall = garnet_call_txn; + garnetCall = &LuaRunnerTrampolines.GarnetCallWithTransaction; } else { - garnetCall = garnet_call; + garnetCall = &LuaRunnerTrampolines.GarnetCallNoTransaction; + } + + if (respServerSession == null) + { + // During benchmarking and testing this can happen, so just redirect once instead of on each redis.call + garnetCall = &LuaRunnerTrampolines.GarnetCallNoSession; } var loadRes = state.LoadBuffer(LoaderBlockBytes.Span); @@ -294,29 +302,46 @@ public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSe var sandboxRes = state.PCall(0, -1); if (sandboxRes != LuaStatus.OK) { - throw new GarnetException("Could not initialize Lua sandbox state"); + string errMsg; + try + { + if (state.StackTop >= 1) + { + // We control the definition of LoaderBlock, so we know this is a string + state.KnownStringToBuffer(1, out var errSpan); + errMsg = Encoding.UTF8.GetString(errSpan); + } + else + { + errMsg = "No error provided"; + } + } + catch + { + errMsg = "Error when fetching pcall error"; + } + + throw new GarnetException($"Could not initialize Lua sandbox state: {errMsg}"); } // Register garnet_call in global namespace - state.Register("garnet_call", garnetCall); - - state.GetGlobal(LuaType.Table, "sandbox_env"); - sandboxEnvRegistryIndex = state.Ref(); + state.Register("garnet_call\0"u8, garnetCall); - state.GetGlobal(LuaType.Table, "KEYS"); + state.GetGlobal(LuaType.Table, "KEYS\0"u8); keysTableRegistryIndex = state.Ref(); - state.GetGlobal(LuaType.Table, "ARGV"); + state.GetGlobal(LuaType.Table, "ARGV\0"u8); argvTableRegistryIndex = state.Ref(); - state.GetGlobal(LuaType.Function, "load_sandboxed"); + state.GetGlobal(LuaType.Function, "load_sandboxed\0"u8); loadSandboxedRegistryIndex = state.Ref(); - state.GetGlobal(LuaType.Function, "reset_keys_and_argv"); + state.GetGlobal(LuaType.Function, "reset_keys_and_argv\0"u8); resetKeysAndArgvRegistryIndex = state.Ref(); // Commonly used strings, register them once so we don't have to copy them over each time we need them okConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_OK); + okLowerConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_ok); errConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_err); noSessionAvailableConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_No_session_available); pleaseSpecifyRedisCallConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_ERR_Please_specify_at_least_one_argument_for_this_redis_lib_call); @@ -330,8 +355,8 @@ public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSe /// /// Creates a new runner with the source of the script /// - public LuaRunner(string source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null) - : this(Encoding.UTF8.GetBytes(source), txnMode, respServerSession, scratchBufferNetworkSender, logger) + public LuaRunner(LuaOptions options, string source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null) + : this(options.MemoryManagementMode, options.GetMemoryLimitBytes(), Encoding.UTF8.GetBytes(source), txnMode, respServerSession, scratchBufferNetworkSender, logger) { } @@ -340,7 +365,7 @@ public LuaRunner(string source, bool txnMode = false, RespServerSession respServ /// /// So instead we stash them in the Registry and load them by index /// - int ConstantStringToRegistry(ReadOnlySpan str) + private int ConstantStringToRegistry(ReadOnlySpan str) { state.PushBuffer(str); return state.Ref(); @@ -353,30 +378,86 @@ int ConstantStringToRegistry(ReadOnlySpan str) /// public unsafe void CompileForRunner() { - var adapter = new RunnerAdapter(scratchBufferManager); - CompileCommon(ref adapter); - - var resp = adapter.Response; - var respStart = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(resp)); - var respEnd = respStart + resp.Length; - if (RespReadUtils.TryReadErrorAsSpan(out var errSpan, ref respStart, respEnd)) + runnerAdapter = new RunnerAdapter(scratchBufferManager); + try + { + LuaRunnerTrampolines.SetCallbackContext(this); + state.PushCFunction(&LuaRunnerTrampolines.CompileForRunner); + var res = state.PCall(0, 0); + if (res == LuaStatus.OK) + { + var resp = runnerAdapter.Response; + var respStart = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(resp)); + var respEnd = respStart + resp.Length; + if (RespReadUtils.TryReadErrorAsSpan(out var errSpan, ref respStart, respEnd)) + { + var errStr = Encoding.UTF8.GetString(errSpan); + throw new GarnetException(errStr); + } + } + else + { + throw new GarnetException($"Internal Lua Error: {res}"); + } + } + finally { - var errStr = Encoding.UTF8.GetString(errSpan); - throw new GarnetException(errStr); + runnerAdapter = default; + LuaRunnerTrampolines.ClearCallbackContext(this); } } + /// + /// Actually compiles for runner. + /// + /// If you call this directly and Lua encounters an error, the process will crash. + /// + /// Call instead. + /// + internal int UnsafeCompileForRunner() + => CompileCommon(ref runnerAdapter); + /// /// Compile script for a . /// /// Any errors encountered are written out as Resp errors. /// - public void CompileForSession(RespServerSession session) + public unsafe bool CompileForSession(RespServerSession session) { - var adapter = new RespResponseAdapter(session); - CompileCommon(ref adapter); + sessionAdapter = new RespResponseAdapter(session); + + try + { + LuaRunnerTrampolines.SetCallbackContext(this); + state.PushCFunction(&LuaRunnerTrampolines.CompileForSession); + var res = state.PCall(0, 0); + if (res != LuaStatus.OK) + { + while (!RespWriteUtils.WriteError("Internal Lua Error"u8, ref session.dcurr, session.dend)) + session.SendAndReset(); + + return false; + } + + return true; + } + finally + { + sessionAdapter = default; + LuaRunnerTrampolines.ClearCallbackContext(this); + } } + /// + /// Actually compiles for runner. + /// + /// If you call this directly and Lua encounters an error, the process will crash. + /// + /// Call instead. + /// + internal int UnsafeCompileForSession() + => CompileCommon(ref sessionAdapter); + /// /// Drops compiled function, just for benchmarking purposes. /// @@ -392,7 +473,7 @@ public void ResetCompilation() /// /// Compile script, writing errors out to given response. /// - unsafe void CompileCommon(ref TResponse resp) + private unsafe int CompileCommon(ref TResponse resp) where TResponse : struct, IResponseAdapter { const int NeededStackSpace = 2; @@ -401,105 +482,61 @@ unsafe void CompileCommon(ref TResponse resp) state.ExpectLuaStackEmpty(); - try - { - state.ForceMinimumStackCapacity(NeededStackSpace); - - state.PushInteger(loadSandboxedRegistryIndex); - _ = state.RawGet(LuaType.Function, (int)LuaRegistry.Index); - - state.PushBuffer(source.Span); - state.Call(1, -1); // Multiple returns allowed - - var numRets = state.StackTop; - - if (numRets == 0) - { - while (!RespWriteUtils.WriteError("Shouldn't happen, no returns from load_sandboxed"u8, ref resp.BufferCur, resp.BufferEnd)) - resp.SendAndReset(); - - return; - } - else if (numRets == 1) - { - var returnType = state.Type(1); - if (returnType != LuaType.Function) - { - var errStr = $"Could not compile function, got back a {returnType}"; - while (!RespWriteUtils.WriteError(errStr, ref resp.BufferCur, resp.BufferEnd)) - resp.SendAndReset(); - - return; - } + state.ForceMinimumStackCapacity(NeededStackSpace); - functionRegistryIndex = state.Ref(); - } - else if (numRets == 2) - { - state.CheckBuffer(2, out var errorBuf); + _ = state.RawGetInteger(LuaType.Function, (int)LuaRegistry.Index, loadSandboxedRegistryIndex); + state.PushBuffer(source.Span); + state.Call(1, 2); - var errStr = $"Compilation error: {Encoding.UTF8.GetString(errorBuf)}"; - while (!RespWriteUtils.WriteError(errStr, ref resp.BufferCur, resp.BufferEnd)) - resp.SendAndReset(); + // Now the stack will have two things on it: + // 1. The error (nil if not error) + // 2. The function (nil if error) - state.Pop(2); + if (state.Type(1) == LuaType.Nil) + { + // No error, success! - return; - } - else - { - state.Pop(numRets); + Debug.Assert(state.Type(2) == LuaType.Function, "Unexpected type returned from load_sandboxed"); - throw new GarnetException($"Unexpected error compiling, got too many replies back: reply count = {numRets}"); - } - } - catch (Exception ex) - { - logger?.LogError(ex, "CreateFunction threw an exception"); - throw; + functionRegistryIndex = state.Ref(); } - finally + else { - state.ExpectLuaStackEmpty(); + // We control the definition of load_sandboxed, so we know this will be a string + state.KnownStringToBuffer(1, out var errorBuf); + + var errStr = $"Compilation error: {Encoding.UTF8.GetString(errorBuf)}"; + while (!RespWriteUtils.WriteError(errStr, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); } + + return 0; } /// /// Dispose the runner /// public void Dispose() - { - state.Dispose(); - } + => state.Dispose(); /// /// Entry point for redis.call method from a Lua script (non-transactional mode) /// - public int garnet_call(IntPtr luaStatePtr) + public int GarnetCall(nint luaStatePtr) { state.CallFromLuaEntered(luaStatePtr); - if (respServerSession == null) - { - return NoSessionResponse(); - } - - return ProcessCommandFromScripting(respServerSession.basicGarnetApi); + return ProcessCommandFromScripting(ref respServerSession.basicGarnetApi); } /// /// Entry point for redis.call method from a Lua script (transactional mode) /// - public int garnet_call_txn(IntPtr luaStatePtr) + public int GarnetCallWithTransaction(nint luaStatePtr) { state.CallFromLuaEntered(luaStatePtr); - if (respServerSession == null) - { - return NoSessionResponse(); - } - - return ProcessCommandFromScripting(respServerSession.lockableGarnetApi); + return ProcessCommandFromScripting(ref respServerSession.lockableGarnetApi); } /// @@ -507,10 +544,12 @@ public int garnet_call_txn(IntPtr luaStatePtr) /// /// This is used in benchmarking. /// - int NoSessionResponse() + internal int NoSessionResponse(nint luaStatePtr) { const int NeededStackSpace = 1; + state.CallFromLuaEntered(luaStatePtr); + state.ForceMinimumStackCapacity(NeededStackSpace); state.PushNil(); @@ -520,7 +559,7 @@ int NoSessionResponse() /// /// Entry point method for executing commands from a Lua Script /// - unsafe int ProcessCommandFromScripting(TGarnetApi api) + unsafe int ProcessCommandFromScripting(ref TGarnetApi api) where TGarnetApi : IGarnetApi { const int AdditionalStackSpace = 1; @@ -529,7 +568,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) { var argCount = state.StackTop; - if (argCount == 0) + if (argCount <= 0) { return LuaStaticError(pleaseSpecifyRedisCallConstStringRegistryIndex); } @@ -646,7 +685,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) /// /// Cause a Lua error to be raised with a message previously registered. /// - int LuaStaticError(int constStringRegistryIndex) + private int LuaStaticError(int constStringRegistryIndex) { const int NeededStackSize = 1; @@ -661,7 +700,7 @@ int LuaStaticError(int constStringRegistryIndex) /// /// Pushes result onto state stack and returns 1, or raises an error and never returns. /// - unsafe int ProcessResponse(byte* ptr, int length) + private unsafe int ProcessResponse(byte* ptr, int length) { const int NeededStackSize = 3; @@ -774,26 +813,76 @@ unsafe int ProcessResponse(byte* ptr, int length) /// /// Response is written directly into the . /// - public void RunForSession(int count, RespServerSession outerSession) + public unsafe void RunForSession(int count, RespServerSession outerSession) { - const int NeededStackSize = 3; + const int NeededStackSize = 2; state.ForceMinimumStackCapacity(NeededStackSize); + preambleOuterSession = outerSession; + preambleKeyAndArgvCount = count; + try + { + LuaRunnerTrampolines.SetCallbackContext(this); + + try + { + state.PushCFunction(&LuaRunnerTrampolines.RunPreambleForSession); + var callRes = state.PCall(0, 0); + if (callRes != LuaStatus.OK) + { + while (!RespWriteUtils.WriteError("Internal Lua Error"u8, ref outerSession.dcurr, outerSession.dend)) + outerSession.SendAndReset(); + + return; + } + } + finally + { + preambleOuterSession = null; + } + + var adapter = new RespResponseAdapter(outerSession); + + if (txnMode && preambleNKeys > 0) + { + RunInTransaction(ref adapter); + } + else + { + RunCommon(ref adapter); + } + } + finally + { + LuaRunnerTrampolines.ClearCallbackContext(this); + } + } + + /// + /// Setups a script to be run. + /// + /// If you call this directly and Lua encounters an error, the process will crash. + /// + /// Call instead. + /// + internal int UnsafeRunPreambleForSession() + { + state.ExpectLuaStackEmpty(); + scratchBufferManager.Reset(); - var parseState = outerSession.parseState; + ref var parseState = ref preambleOuterSession.parseState; var offset = 1; - var nKeys = parseState.GetInt(offset++); - count--; - ResetParameters(nKeys, count - nKeys); + var nKeys = preambleNKeys = parseState.GetInt(offset++); + preambleKeyAndArgvCount--; + ResetParameters(nKeys, preambleKeyAndArgvCount - nKeys); if (nKeys > 0) { // Get KEYS on the stack - state.PushInteger(keysTableRegistryIndex); - state.RawGet(LuaType.Table, (int)LuaRegistry.Index); + _ = state.RawGetInteger(LuaType.Table, (int)LuaRegistry.Index, keysTableRegistryIndex); for (var i = 0; i < nKeys; i++) { @@ -807,9 +896,8 @@ public void RunForSession(int count, RespServerSession outerSession) } // Equivalent to KEYS[i+1] = key - state.PushInteger(i + 1); state.PushBuffer(key.ReadOnlySpan); - state.RawSet(1); + state.RawSetInteger(1, i + 1); offset++; } @@ -817,23 +905,21 @@ public void RunForSession(int count, RespServerSession outerSession) // Remove KEYS from the stack state.Pop(1); - count -= nKeys; + preambleKeyAndArgvCount -= nKeys; } - if (count > 0) + if (preambleKeyAndArgvCount > 0) { // Get ARGV on the stack - state.PushInteger(argvTableRegistryIndex); - state.RawGet(LuaType.Table, (int)LuaRegistry.Index); + _ = state.RawGetInteger(LuaType.Table, (int)LuaRegistry.Index, argvTableRegistryIndex); - for (var i = 0; i < count; i++) + for (var i = 0; i < preambleKeyAndArgvCount; i++) { ref var argv = ref parseState.GetArgSliceByRef(offset); // Equivalent to ARGV[i+1] = argv - state.PushInteger(i + 1); state.PushBuffer(argv.ReadOnlySpan); - state.RawSet(1); + state.RawSetInteger(1, i + 1); offset++; } @@ -842,16 +928,7 @@ public void RunForSession(int count, RespServerSession outerSession) state.Pop(1); } - var adapter = new RespResponseAdapter(outerSession); - - if (txnMode && nKeys > 0) - { - RunInTransaction(ref adapter); - } - else - { - RunCommon(ref adapter); - } + return 0; } /// @@ -861,44 +938,69 @@ public void RunForSession(int count, RespServerSession outerSession) /// public unsafe object RunForRunner(string[] keys = null, string[] argv = null) { - scratchBufferManager?.Reset(); - LoadParametersForRunner(keys, argv); + const int NeededStackSize = 2; - var adapter = new RunnerAdapter(scratchBufferManager); + state.ForceMinimumStackCapacity(NeededStackSize); - if (txnMode && keys?.Length > 0) + try { - // Add keys to the transaction - foreach (var key in keys) + LuaRunnerTrampolines.SetCallbackContext(this); + + try { - var _key = scratchBufferManager.CreateArgSlice(key); - txnKeyEntries.AddKey(_key, false, Tsavorite.core.LockType.Exclusive); - if (!respServerSession.storageSession.objectStoreLockableContext.IsNull) - txnKeyEntries.AddKey(_key, true, Tsavorite.core.LockType.Exclusive); + preambleKeys = keys; + preambleArgv = argv; + + state.PushCFunction(&LuaRunnerTrampolines.RunPreambleForRunner); + state.Call(0, 0); + } + finally + { + preambleKeys = preambleArgv = null; } - RunInTransaction(ref adapter); - } - else - { - RunCommon(ref adapter); - } + RunnerAdapter adapter; + if (txnMode && keys?.Length > 0) + { + // Add keys to the transaction + foreach (var key in keys) + { + var _key = scratchBufferManager.CreateArgSlice(key); + txnKeyEntries.AddKey(_key, false, Tsavorite.core.LockType.Exclusive); + if (!respServerSession.storageSession.objectStoreLockableContext.IsNull) + txnKeyEntries.AddKey(_key, true, Tsavorite.core.LockType.Exclusive); + } + + adapter = new(scratchBufferManager); + RunInTransaction(ref adapter); + } + else + { + adapter = new(scratchBufferManager); + RunCommon(ref adapter); + } - var resp = adapter.Response; - var respCur = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(resp)); - var respEnd = respCur + resp.Length; + var resp = adapter.Response; + var respCur = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(resp)); + var respEnd = respCur + resp.Length; - if (RespReadUtils.TryReadErrorAsSpan(out var errSpan, ref respCur, respEnd)) - { - var errStr = Encoding.UTF8.GetString(errSpan); - throw new GarnetException(errStr); - } + if (RespReadUtils.TryReadErrorAsSpan(out var errSpan, ref respCur, respEnd)) + { + var errStr = Encoding.UTF8.GetString(errSpan); + throw new GarnetException(errStr); + } - var ret = MapRespToObject(ref respCur, respEnd); - Debug.Assert(respCur == respEnd, "Should have fully consumed response"); + var ret = MapRespToObject(ref respCur, respEnd); + Debug.Assert(respCur == respEnd, "Should have fully consumed response"); - return ret; + return ret; + } + finally + { + LuaRunnerTrampolines.ClearCallbackContext(this); + } + // Convert a RESP response into an object to return static object MapRespToObject(ref byte* cur, byte* end) { switch (*cur) @@ -953,10 +1055,26 @@ static object MapRespToObject(ref byte* cur, byte* end) } } + /// + /// Setups a script to be run. + /// + /// If you call this directly and Lua encounters an error, the process will crash. + /// + /// Call instead. + /// + internal int UnsafeRunPreambleForRunner() + { + state.ExpectLuaStackEmpty(); + + scratchBufferManager?.Reset(); + + return LoadParametersForRunner(preambleKeys, preambleArgv); + } + /// /// Calls after setting up appropriate state for a transaction. /// - void RunInTransaction(ref TResponse response) + private void RunInTransaction(ref TResponse response) where TResponse : struct, IResponseAdapter { try @@ -990,13 +1108,12 @@ internal void ResetParameters(int nKeys, int nArgs) if (keyLength > nKeys || argvLength > nArgs) { - state.RawGetInteger(LuaType.Function, (int)LuaRegistry.Index, resetKeysAndArgvRegistryIndex); + _ = state.RawGetInteger(LuaType.Function, (int)LuaRegistry.Index, resetKeysAndArgvRegistryIndex); state.PushInteger(nKeys + 1); state.PushInteger(nArgs + 1); - var resetRes = state.PCall(2, 0); - Debug.Assert(resetRes == LuaStatus.OK, "Resetting should never fail"); + state.Call(2, 0); } keyLength = nKeys; @@ -1006,7 +1123,7 @@ internal void ResetParameters(int nKeys, int nArgs) /// /// Takes .NET strings for keys and args and pushes them into KEYS and ARGV globals. /// - void LoadParametersForRunner(string[] keys, string[] argv) + private int LoadParametersForRunner(string[] keys, string[] argv) { const int NeededStackSize = 2; @@ -1017,8 +1134,7 @@ void LoadParametersForRunner(string[] keys, string[] argv) if (keys != null) { // get KEYS on the stack - state.PushInteger(keysTableRegistryIndex); - _ = state.RawGet(LuaType.Table, (int)LuaRegistry.Index); + _ = state.RawGetInteger(LuaType.Table, (int)LuaRegistry.Index, keysTableRegistryIndex); for (var i = 0; i < keys.Length; i++) { @@ -1035,8 +1151,7 @@ void LoadParametersForRunner(string[] keys, string[] argv) if (argv != null) { // get ARGV on the stack - state.PushInteger(argvTableRegistryIndex); - _ = state.RawGet(LuaType.Table, (int)LuaRegistry.Index); + _ = state.RawGetInteger(LuaType.Table, (int)LuaRegistry.Index, argvTableRegistryIndex); for (var i = 0; i < argv.Length; i++) { @@ -1050,6 +1165,9 @@ void LoadParametersForRunner(string[] keys, string[] argv) state.Pop(1); } + return 0; + + // Convert string into a span, using buffer for storage static void PrepareString(string raw, ScratchBufferManager buffer, out ReadOnlySpan strBytes) { var maxLen = Encoding.UTF8.GetMaxByteCount(raw.Length); @@ -1066,7 +1184,7 @@ static void PrepareString(string raw, ScratchBufferManager buffer, out ReadOnlyS /// /// Runs the precompiled Lua function. /// - unsafe void RunCommon(ref TResponse resp) + private unsafe void RunCommon(ref TResponse resp) where TResponse : struct, IResponseAdapter { const int NeededStackSize = 2; @@ -1078,8 +1196,7 @@ unsafe void RunCommon(ref TResponse resp) { state.ForceMinimumStackCapacity(NeededStackSize); - state.PushInteger(functionRegistryIndex); - _ = state.RawGet(LuaType.Function, (int)LuaRegistry.Index); + _ = state.RawGetInteger(LuaType.Function, (int)LuaRegistry.Index, functionRegistryIndex); var callRes = state.PCall(0, 1); if (callRes == LuaStatus.OK) @@ -1119,7 +1236,23 @@ unsafe void RunCommon(ref TResponse resp) { // Redis does not respect metatables, so RAW access is ok here - // If the key err is in there, we need to short circuit + // If the key "ok" is in there, we need to short circuit + state.PushConstantString(okLowerConstStringRegistryIndex); + var okType = state.RawGet(null, 1); + if (okType == LuaType.String) + { + WriteString(this, ref resp); + + // Remove table from stack + state.Pop(1); + + return; + } + + // Remove whatever we read from the table under the "ok" key + state.Pop(1); + + // If the key "err" is in there, we need to short circuit state.PushConstantString(errConstStringRegistryIndex); var errType = state.RawGet(null, 1); @@ -1153,11 +1286,29 @@ unsafe void RunCommon(ref TResponse resp) } else if (state.StackTop == 1) { - if (state.CheckBuffer(1, out var errBuf)) + // PCall will put error in a string + state.KnownStringToBuffer(1, out var errBuf); + + if (errBuf.Length >= 4 && MemoryMarshal.Read("ERR "u8) == Unsafe.As(ref MemoryMarshal.GetReference(errBuf))) { + // Response came back with a ERR, already - just pass it along while (!RespWriteUtils.WriteError(errBuf, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); } + else + { + // Otherwise, this is probably a Lua error - and those aren't very descriptive + // So slap some more information in + + while (!RespWriteUtils.WriteDirect("-ERR Lua encountered an error: "u8, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + while (!RespWriteUtils.WriteDirect(errBuf, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + while (!RespWriteUtils.WriteDirect("\r\n"u8, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + } state.Pop(1); @@ -1280,7 +1431,7 @@ static void WriteArray(LuaRunner runner, ref TResponse resp) } } - while (!RespWriteUtils.WriteArrayLength((int)trueLen, ref resp.BufferCur, resp.BufferEnd)) + while (!RespWriteUtils.WriteArrayLength(trueLen, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); for (var i = 1; i <= trueLen; i++) @@ -1316,4 +1467,107 @@ static void WriteArray(LuaRunner runner, ref TResponse resp) } } } + + /// + /// Holds static functions for Lua-to-.NET interop. + /// + /// We annotate these as "unmanaged callers only" as a micro-optimization. + /// See: https://devblogs.microsoft.com/dotnet/improvements-in-native-code-interop-in-net-5-0/#unmanagedcallersonly + /// + internal static class LuaRunnerTrampolines + { + [ThreadStatic] + private static LuaRunner callbackContext; + + /// + /// Set a that will be available in trampolines. + /// + /// This assumes the same thread is used to call into Lua. + /// + /// Call when finished to avoid extending + /// the lifetime of the . + /// + internal static void SetCallbackContext(LuaRunner context) + { + Debug.Assert(callbackContext == null, "Expected null context"); + callbackContext = context; + } + + /// + /// Clear a previously set + /// + internal static void ClearCallbackContext(LuaRunner context) + { + Debug.Assert(ReferenceEquals(callbackContext, context), "Expected context to match"); + callbackContext = null; + } + + /// + /// Entry point for Lua PCall'ing into . + /// + /// We need this indirection to allow Lua to detect and report internal and memory errors + /// without crashing the process. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int CompileForRunner(nint _) + => callbackContext.UnsafeCompileForRunner(); + + /// + /// Entry point for Lua PCall'ing into . + /// + /// We need this indirection to allow Lua to detect and report internal and memory errors + /// without crashing the process. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int CompileForSession(nint _) + => callbackContext.UnsafeCompileForSession(); + + /// + /// Entry point for Lua PCall'ing into . + /// + /// We need this indirection to allow Lua to detect and report internal and memory errors + /// without crashing the process. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int RunPreambleForRunner(nint _) + => callbackContext.UnsafeRunPreambleForRunner(); + + /// + /// Entry point for Lua PCall'ing into . + /// + /// We need this indirection to allow Lua to detect and report internal and memory errors + /// without crashing the process. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int RunPreambleForSession(nint _) + => callbackContext.UnsafeRunPreambleForSession(); + + /// + /// Entry point for Lua calling back into Garnet via redis.call(...). + /// + /// This entry point is for when there isn't an active . + /// This should only happen during testing. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int GarnetCallNoSession(nint luaState) + => callbackContext.NoSessionResponse(luaState); + + /// + /// Entry point for Lua calling back into Garnet via redis.call(...). + /// + /// This entry point is for when a transaction is in effect. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int GarnetCallWithTransaction(nint luaState) + => callbackContext.GarnetCallWithTransaction(luaState); + + /// + /// Entry point for Lua calling back into Garnet via redis.call(...). + /// + /// This entry point is for when a transaction is not necessary. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int GarnetCallNoTransaction(nint luaState) + => callbackContext.GarnetCall(luaState); + } } \ No newline at end of file diff --git a/libs/server/Lua/LuaStateWrapper.cs b/libs/server/Lua/LuaStateWrapper.cs index 51ac6b0e0a..136507d15c 100644 --- a/libs/server/Lua/LuaStateWrapper.cs +++ b/libs/server/Lua/LuaStateWrapper.cs @@ -4,9 +4,11 @@ using System; using System.Diagnostics; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Text; using Garnet.common; using KeraLua; +using Microsoft.Extensions.Logging; namespace Garnet.server { @@ -20,13 +22,49 @@ internal struct LuaStateWrapper : IDisposable { private const int LUA_MINSTACK = 20; - private readonly Lua state; + private GCHandle customAllocatorHandle; + private nint state; private int curStackSize; - internal LuaStateWrapper(Lua state) + /// + /// Current top item in the stack. + /// + /// 0 implies the stack is empty. + /// + internal int StackTop { get; private set; } + + internal unsafe LuaStateWrapper(LuaMemoryManagementMode memMode, int? memLimitBytes, ILogger logger) { - this.state = state; + // As an emergency thing, we need a logger (any logger) if Lua is going to panic + // TODO: Consider a better way to do this? + LuaStateWrapperTrampolines.PanicLogger ??= logger; + + ILuaAllocator customAllocator = + (memMode, memLimitBytes) switch + { + (LuaMemoryManagementMode.Native, null) => null, + (LuaMemoryManagementMode.Tracked, _) => new LuaTrackedAllocator(memLimitBytes), + (LuaMemoryManagementMode.Managed, null) => new LuaManagedAllocator(), + (LuaMemoryManagementMode.Managed, _) => new LuaLimitedManagedAllocator(memLimitBytes.Value), + _ => throw new InvalidOperationException($"Unexpected mode/limit combination: {memMode}/{memLimitBytes}") + }; + + if (customAllocator != null) + { + customAllocatorHandle = GCHandle.Alloc(customAllocator, GCHandleType.Normal); + var stateUserData = (nint)customAllocatorHandle; + + state = NativeMethods.NewState(&LuaStateWrapperTrampolines.LuaAllocateBytes, stateUserData); + } + else + { + state = NativeMethods.NewState(); + } + + NativeMethods.OpenLibs(state); + + _ = NativeMethods.AtPanic(state, &LuaStateWrapperTrampolines.LuaAtPanic); curStackSize = LUA_MINSTACK; StackTop = 0; @@ -35,17 +73,20 @@ internal LuaStateWrapper(Lua state) } /// - public readonly void Dispose() + public void Dispose() { - state.Dispose(); - } + if (state != 0) + { + NativeMethods.Close(state); + state = 0; + } - /// - /// Current top item in the stack. - /// - /// 0 implies the stack is empty. - /// - internal int StackTop { get; private set; } + if (customAllocatorHandle.IsAllocated) + { + customAllocatorHandle.Free(); + customAllocatorHandle = default; + } + } /// /// Call when ambient state indicates that the Lua stack is in fact empty. @@ -77,7 +118,7 @@ internal void ForceMinimumStackCapacity(int additionalCapacity) } var needed = additionalCapacity - availableSpace; - if (!state.CheckStack(needed)) + if (!NativeMethods.CheckStack(state, needed)) { throw new GarnetException("Could not reserve additional capacity on the Lua stack"); } @@ -93,9 +134,9 @@ internal void ForceMinimumStackCapacity(int additionalCapacity) [MethodImpl(MethodImplOptions.AggressiveInlining)] internal void CallFromLuaEntered(IntPtr luaStatePtr) { - Debug.Assert(luaStatePtr == state.Handle, "Unexpected Lua state presented"); + Debug.Assert(luaStatePtr == state, "Unexpected Lua state presented"); - StackTop = NativeMethods.GetTop(state.Handle); + StackTop = NativeMethods.GetTop(state); curStackSize = StackTop > LUA_MINSTACK ? StackTop : LUA_MINSTACK; } @@ -107,7 +148,7 @@ internal readonly bool CheckBuffer(int index, out ReadOnlySpan str) { AssertLuaStackIndexInBounds(index); - return NativeMethods.CheckBuffer(state.Handle, index, out str); + return NativeMethods.CheckBuffer(state, index, out str); } /// @@ -118,7 +159,7 @@ internal readonly LuaType Type(int stackIndex) { AssertLuaStackIndexInBounds(stackIndex); - return NativeMethods.Type(state.Handle, stackIndex); + return NativeMethods.Type(state, stackIndex); } /// @@ -131,7 +172,7 @@ internal void PushBuffer(ReadOnlySpan buffer) { AssertLuaStackNotFull(); - NativeMethods.PushBuffer(state.Handle, buffer); + _ = ref NativeMethods.PushBuffer(state, buffer); UpdateStackTop(1); } @@ -143,7 +184,7 @@ internal void PushNil() { AssertLuaStackNotFull(); - NativeMethods.PushNil(state.Handle); + NativeMethods.PushNil(state); UpdateStackTop(1); } @@ -155,7 +196,7 @@ internal void PushInteger(long number) { AssertLuaStackNotFull(); - NativeMethods.PushInteger(state.Handle, number); + NativeMethods.PushInteger(state, number); UpdateStackTop(1); } @@ -168,7 +209,7 @@ internal void PushBoolean(bool b) { AssertLuaStackNotFull(); - NativeMethods.PushBoolean(state.Handle, b); + NativeMethods.PushBoolean(state, b); UpdateStackTop(1); } @@ -180,26 +221,27 @@ internal void PushBoolean(bool b) [MethodImpl(MethodImplOptions.AggressiveInlining)] internal void Pop(int num) { - NativeMethods.Pop(state.Handle, num); + NativeMethods.Pop(state, num); UpdateStackTop(-num); } /// - /// This should be used for all Calls into Lua. + /// This should be used for all PCalls into Lua. /// /// Maintains and to minimize p/invoke calls. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void Call(int args, int rets) + internal LuaStatus PCall(int args, int rets) { - // We have to copy this off, as once we Call curStackTop could be modified + // We have to copy this off, as once we PCall curStackTop could be modified var oldStackTop = StackTop; - state.Call(args, rets); - if (rets < 0) + var res = NativeMethods.PCall(state, args, rets); + + if (res != LuaStatus.OK || rets < 0) { - StackTop = NativeMethods.GetTop(state.Handle); + StackTop = NativeMethods.GetTop(state); AssertLuaStackExpected(); } else @@ -208,23 +250,26 @@ internal void Call(int args, int rets) var update = newPosition - StackTop; UpdateStackTop(update); } + + return res; } /// - /// This should be used for all PCalls into Lua. + /// This should be used for all Calls into Lua. /// /// Maintains and to minimize p/invoke calls. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal LuaStatus PCall(int args, int rets) + internal void Call(int args, int rets) { - // We have to copy this off, as once we Call curStackTop could be modified + // We have to copy this off, as once we PCall curStackTop could be modified var oldStackTop = StackTop; - var res = state.PCall(args, rets, 0); - if (res != LuaStatus.OK || rets < 0) + NativeMethods.Call(state, args, rets); + + if (rets < 0) { - StackTop = NativeMethods.GetTop(state.Handle); + StackTop = NativeMethods.GetTop(state); AssertLuaStackExpected(); } else @@ -233,8 +278,6 @@ internal LuaStatus PCall(int args, int rets) var update = newPosition - StackTop; UpdateStackTop(update); } - - return res; } /// @@ -247,24 +290,10 @@ internal void RawSetInteger(int stackIndex, int tableIndex) { AssertLuaStackIndexInBounds(stackIndex); - state.RawSetInteger(stackIndex, tableIndex); + NativeMethods.RawSetInteger(state, stackIndex, tableIndex); UpdateStackTop(-1); } - /// - /// This should be used for all RawSets into Lua. - /// - /// Maintains and to minimize p/invoke calls. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void RawSet(int stackIndex) - { - AssertLuaStackIndexInBounds(stackIndex); - - state.RawSet(stackIndex); - UpdateStackTop(-2); - } - /// /// This should be used for all RawGetIntegers into Lua. /// @@ -276,7 +305,7 @@ internal LuaType RawGetInteger(LuaType? expectedType, int stackIndex, int tableI AssertLuaStackIndexInBounds(stackIndex); AssertLuaStackNotFull(); - var actual = state.RawGetInteger(stackIndex, tableIndex); + var actual = NativeMethods.RawGetInteger(state, stackIndex, tableIndex); Debug.Assert(expectedType == null || actual == expectedType, "Unexpected type received"); UpdateStackTop(1); @@ -294,7 +323,7 @@ internal readonly LuaType RawGet(LuaType? expectedType, int stackIndex) { AssertLuaStackIndexInBounds(stackIndex); - var actual = state.RawGet(stackIndex); + var actual = NativeMethods.RawGet(state, stackIndex); Debug.Assert(expectedType == null || actual == expectedType, "Unexpected type received"); AssertLuaStackExpected(); @@ -310,7 +339,7 @@ internal readonly LuaType RawGet(LuaType? expectedType, int stackIndex) [MethodImpl(MethodImplOptions.AggressiveInlining)] internal int Ref() { - var ret = state.Ref(LuaRegistry.Index); + var ret = NativeMethods.Ref(state, (int)LuaRegistry.Index); UpdateStackTop(-1); return ret; @@ -323,9 +352,7 @@ internal int Ref() /// [MethodImpl(MethodImplOptions.AggressiveInlining)] internal readonly void Unref(LuaRegistry registry, int reference) - { - state.Unref(registry, reference); - } + => NativeMethods.Unref(state, (int)registry, reference); /// /// This should be used for all CreateTables into Lua. @@ -337,7 +364,7 @@ internal void CreateTable(int numArr, int numRec) { AssertLuaStackNotFull(); - state.CreateTable(numArr, numRec); + NativeMethods.CreateTable(state, numArr, numRec); UpdateStackTop(1); } @@ -346,11 +373,11 @@ internal void CreateTable(int numArr, int numRec) /// /// Maintains and to minimize p/invoke calls. /// - internal void GetGlobal(LuaType expectedType, string globalName) + internal void GetGlobal(LuaType expectedType, ReadOnlySpan nullTerminatedGlobalName) { AssertLuaStackNotFull(); - var type = state.GetGlobal(globalName); + var type = NativeMethods.GetGlobal(state, nullTerminatedGlobalName); Debug.Assert(type == expectedType, "Unexpected type received"); UpdateStackTop(1); @@ -368,7 +395,7 @@ internal LuaStatus LoadBuffer(ReadOnlySpan buffer) { AssertLuaStackNotFull(); - var ret = NativeMethods.LoadBuffer(state.Handle, buffer); + var ret = NativeMethods.LoadBuffer(state, buffer); UpdateStackTop(1); @@ -389,9 +416,9 @@ internal readonly void KnownStringToBuffer(int stackIndex, out ReadOnlySpan @@ -402,7 +429,7 @@ internal readonly double CheckNumber(int stackIndex) { AssertLuaStackIndexInBounds(stackIndex); - return state.CheckNumber(stackIndex); + return NativeMethods.CheckNumber(state, stackIndex); } /// @@ -413,7 +440,7 @@ internal readonly bool ToBoolean(int stackIndex) { AssertLuaStackIndexInBounds(stackIndex); - return NativeMethods.ToBoolean(state.Handle, stackIndex); + return NativeMethods.ToBoolean(state, stackIndex); } /// @@ -424,14 +451,29 @@ internal readonly long RawLen(int stackIndex) { AssertLuaStackIndexInBounds(stackIndex); - return state.RawLen(stackIndex); + return NativeMethods.RawLen(state, stackIndex); + } + + /// + /// This should be used for all PushCFunctions into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal unsafe void PushCFunction(delegate* unmanaged[Cdecl] function) + { + NativeMethods.PushCFunction(state, (nint)function); + UpdateStackTop(1); } /// /// Call to register a function in the Lua global namespace. /// - internal readonly void Register(string name, LuaFunction func) - => state.Register(name, func); + internal unsafe void Register(ReadOnlySpan nullTerminatedName, delegate* unmanaged[Cdecl] function) + { + PushCFunction(function); + + NativeMethods.SetGlobal(state, nullTerminatedName); + UpdateStackTop(-1); + } /// /// This should be used to push all known constants strings into Lua. @@ -439,7 +481,7 @@ internal readonly void Register(string name, LuaFunction func) /// This avoids extra copying of data between .NET and Lua. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void PushConstantString(int constStringRegistryIndex, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) + internal void PushConstantString(int constStringRegistryIndex) => RawGetInteger(LuaType.String, (int)LuaRegistry.Index, constStringRegistryIndex); // Rarely used @@ -449,7 +491,7 @@ internal void PushConstantString(int constStringRegistryIndex, [CallerFilePath] /// internal void ClearStack() { - state.SetTop(0); + NativeMethods.SetTop(state, 0); StackTop = 0; AssertLuaStackExpected(); @@ -463,6 +505,7 @@ internal int RaiseError(string msg) ClearStack(); var b = Encoding.UTF8.GetBytes(msg); + PushBuffer(b); return RaiseErrorFromStack(); } @@ -473,7 +516,7 @@ internal readonly int RaiseErrorFromStack() { Debug.Assert(StackTop != 0, "Expected error message on the stack"); - return state.Error(); + return NativeMethods.Error(state); } /// @@ -492,7 +535,6 @@ private void UpdateStackTop(int by) /// Check that the given index refers to a valid part of the stack. /// [Conditional("DEBUG")] - [MethodImpl(MethodImplOptions.NoInlining)] private readonly void AssertLuaStackIndexInBounds(int stackIndex) { Debug.Assert(stackIndex == (int)LuaRegistry.Index || (stackIndex > 0 && stackIndex <= StackTop), "Lua stack index out of bounds"); @@ -502,20 +544,133 @@ private readonly void AssertLuaStackIndexInBounds(int stackIndex) /// Check that the Lua stack top is where expected in DEBUG builds. /// [Conditional("DEBUG")] - [MethodImpl(MethodImplOptions.NoInlining)] private readonly void AssertLuaStackExpected() { - Debug.Assert(NativeMethods.GetTop(state.Handle) == StackTop, "Lua stack not where expected"); + Debug.Assert(NativeMethods.GetTop(state) == StackTop, "Lua stack not where expected"); } /// /// Check that there's space to push some number of elements. /// [Conditional("DEBUG")] - [MethodImpl(MethodImplOptions.NoInlining)] private readonly void AssertLuaStackNotFull(int probe = 1) { Debug.Assert((StackTop + probe) <= curStackSize, "Lua stack should have been grown before pushing"); } } + + /// + /// Holds static functions for Lua-to-.NET interop. + /// + /// We annotate these as "unmanaged callers only" as a micro-optimization. + /// See: https://devblogs.microsoft.com/dotnet/improvements-in-native-code-interop-in-net-5-0/#unmanagedcallersonly + /// + internal static class LuaStateWrapperTrampolines + { + /// + /// Controls the singular logger that will be used for panic invocations. + /// + /// Because Lua is panic'ing, we're about to crash. This being process wide is hacky, + /// but we're in hacky situtation. + /// + internal static ILogger PanicLogger { get; set; } + + /// + /// Called when Lua encounters an unrecoverable error. + /// + /// When this returns, Lua is going to terminate the process, so plan accordingly. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int LuaAtPanic(nint luaState) + { + // See: https://www.lua.org/manual/5.4/manual.html#4.4 + + string errorMsg; + + var stackTop = NativeMethods.GetTop(luaState); + if (stackTop >= 1) + { + if (NativeMethods.CheckBuffer(luaState, stackTop, out var msgBuf)) + { + errorMsg = Encoding.UTF8.GetString(msgBuf); + } + else + { + var type = NativeMethods.Type(luaState, stackTop); + + errorMsg = $"Unexpected error type: {type}"; + } + } + else + { + errorMsg = "No error on stack"; + } + + PanicLogger?.LogCritical("Lua Panic '{errorMsg}', stack size {stackTop}", errorMsg, stackTop); + + return 0; + } + + /// + /// Provides data for Lua allocations. + /// + /// All returns data must be PINNED or unmanaged, Lua does not allow it to move. + /// + /// If allocation cannot be performed, null is returned. + /// + /// Pointer to user data provided during + /// Either null (if new alloc) or pointer to existing allocation being resized or freed. + /// If is not null, the value passed when allocation was obtained or resized. + /// The desired size of the allocation, in bytes. + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static unsafe nint LuaAllocateBytes(nint udPtr, nint ptr, nuint osize, nuint nsize) + { + // See: https://www.lua.org/manual/5.4/manual.html#lua_Alloc + + var handle = (GCHandle)udPtr; + Debug.Assert(handle.IsAllocated, "GCHandle should always be valid"); + + var customAllocator = (ILuaAllocator)handle.Target; + + if (ptr != IntPtr.Zero) + { + // Now osize is the size used to (re)allocate ptr last + + ref var dataRef = ref Unsafe.AsRef((void*)ptr); + + if (nsize == 0) + { + customAllocator.Free(ref dataRef, (int)osize); + + return 0; + } + else + { + ref var ret = ref customAllocator.ResizeAllocation(ref dataRef, (int)osize, (int)nsize, out var failed); + if (failed) + { + return 0; + } + + var retPtr = (nint)Unsafe.AsPointer(ref ret); + + return retPtr; + } + } + else + { + // Now osize is the size of the object being allocated, but nsize is the desired size + + ref var ret = ref customAllocator.AllocateNew((int)nsize, out var failed); + if (failed) + { + return 0; + } + + var retPtr = (nint)Unsafe.AsPointer(ref ret); + + return retPtr; + } + } + } } \ No newline at end of file diff --git a/libs/server/Lua/LuaTrackedAllocator.cs b/libs/server/Lua/LuaTrackedAllocator.cs new file mode 100644 index 0000000000..f66ff502c1 --- /dev/null +++ b/libs/server/Lua/LuaTrackedAllocator.cs @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Garnet.server +{ + /// + /// Lua allocator which still uses native memory, but tracks it and can impose limits. + /// + /// is also used to inform .NET host of native allocations. + /// + internal unsafe class LuaTrackedAllocator : ILuaAllocator + { + private readonly int? limitBytes; + + private int allocatedBytes; + + internal LuaTrackedAllocator(int? limitBytes) + { + this.limitBytes = limitBytes; + } + + /// + public ref byte AllocateNew(int sizeBytes, out bool failed) + { + if (limitBytes != null) + { + if (allocatedBytes + sizeBytes > limitBytes) + { + failed = true; + return ref Unsafe.NullRef(); + } + + allocatedBytes += sizeBytes; + } + + if (sizeBytes > 0) + { + GC.AddMemoryPressure(sizeBytes); + } + + failed = false; + + var ptr = NativeMemory.Alloc((nuint)sizeBytes); + return ref Unsafe.AsRef(ptr); + } + + /// + public void Free(ref byte start, int sizeBytes) + { + NativeMemory.Free(Unsafe.AsPointer(ref start)); + + if (sizeBytes > 0) + { + allocatedBytes -= sizeBytes; + + GC.RemoveMemoryPressure(sizeBytes); + } + } + + /// + public ref byte ResizeAllocation(ref byte start, int oldSizeBytes, int newSizeBytes, out bool failed) + { + var delta = newSizeBytes - oldSizeBytes; + + if (delta == 0) + { + failed = false; + return ref start; + } + + if (limitBytes != null) + { + if (allocatedBytes + delta > limitBytes) + { + failed = true; + return ref Unsafe.NullRef(); + } + + allocatedBytes += delta; + } + + if (delta < 0) + { + GC.RemoveMemoryPressure(-delta); + } + else + { + GC.AddMemoryPressure(delta); + } + + failed = false; + + var ptr = NativeMemory.Realloc(Unsafe.AsPointer(ref start), (nuint)newSizeBytes); + return ref Unsafe.AsRef(ptr); + } + } +} \ No newline at end of file diff --git a/libs/server/Lua/NativeMethods.cs b/libs/server/Lua/NativeMethods.cs index cc6bcf085b..76a9fda963 100644 --- a/libs/server/Lua/NativeMethods.cs +++ b/libs/server/Lua/NativeMethods.cs @@ -2,10 +2,13 @@ // Licensed under the MIT license. using System; +using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using KeraLua; using charptr_t = nint; +using intptr_t = nint; +using lua_CFunction = nint; using lua_State = nint; using size_t = nuint; @@ -42,6 +45,139 @@ internal static partial class NativeMethods [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] private static partial LuaStatus luaL_loadbufferx(lua_State luaState, charptr_t buff, size_t sz, charptr_t name, charptr_t mode); + /// + /// see: https://www.lua.org/manual/5.4/manual.html#luaL_newstate + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial nint luaL_newstate(); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_newstate + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial nint lua_newstate(lua_CFunction allocFunc, charptr_t ud); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#luaL_openlibs + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial void luaL_openlibs(lua_State luaState); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_close + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial void lua_close(lua_State luaState); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_checkstack + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial int lua_checkstack(lua_State luaState, int n); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#luaL_checknumber + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial double luaL_checknumber(lua_State luaState, int n); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_rawlen + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial int lua_rawlen(lua_State luaState, int n); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_pcallk + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial int lua_pcallk(lua_State luaState, int nargs, int nresults, int msgh, nint ctx, nint k); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_callk + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial int lua_callk(lua_State luaState, int nargs, int nresults, nint ctx, nint k); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_rawseti + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial void lua_rawseti(lua_State luaState, int index, long i); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_rawset + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial void lua_rawset(lua_State luaState, int index); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_rawgeti + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial int lua_rawgeti(lua_State luaState, int index, long n); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_rawget + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial int lua_rawget(lua_State luaState, int index); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#luaL_ref + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial int luaL_ref(lua_State luaState, int index); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#luaL_unref + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial int luaL_unref(lua_State luaState, int index, int refVal); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_createtable + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial void lua_createtable(lua_State luaState, int narr, int nrec); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_getglobal + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial int lua_getglobal(lua_State luaState, charptr_t name); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_setglobal + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial void lua_setglobal(lua_State luaState, charptr_t name); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_error + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial int lua_error(lua_State luaState); + // GC Transition suppressed - only do this after auditing the Lua method and confirming constant-ish, fast, runtime w/o allocations /// @@ -51,7 +187,7 @@ internal static partial class NativeMethods /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_gettop /// [LibraryImport(LuaLibraryName)] - [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])] private static partial int lua_gettop(lua_State luaState); /// @@ -63,7 +199,7 @@ internal static partial class NativeMethods /// see: https://www.lua.org/source/5.4/lapi.c.html#index2value /// [LibraryImport(LuaLibraryName)] - [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])] private static partial LuaType lua_type(lua_State L, int index); /// @@ -73,7 +209,7 @@ internal static partial class NativeMethods /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_pushnil /// [LibraryImport(LuaLibraryName)] - [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])] private static partial void lua_pushnil(lua_State L); /// @@ -83,7 +219,7 @@ internal static partial class NativeMethods /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_pushinteger /// [LibraryImport(LuaLibraryName)] - [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])] private static partial void lua_pushinteger(lua_State L, long num); /// @@ -93,7 +229,7 @@ internal static partial class NativeMethods /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_pushboolean /// [LibraryImport(LuaLibraryName)] - [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])] private static partial void lua_pushboolean(lua_State L, int b); /// @@ -103,9 +239,20 @@ internal static partial class NativeMethods /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_toboolean /// [LibraryImport(LuaLibraryName)] - [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])] private static partial int lua_toboolean(lua_State L, int ix); + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_tointegerx + /// + /// We should always have checked this is actually a number before calling, + /// so the expensive paths won't be taken. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_tointegerx + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])] + private static partial long lua_tointegerx(lua_State L, int idex, intptr_t pisnum); + /// /// see: https://www.lua.org/manual/5.4/manual.html#lua_settop /// @@ -113,9 +260,31 @@ internal static partial class NativeMethods /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_settop /// [LibraryImport(LuaLibraryName)] - [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])] private static partial void lua_settop(lua_State L, int num); + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_atpanic + /// + /// Just changing a global value, should be quick. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_atpanic + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])] + private static partial lua_CFunction lua_atpanic(lua_State luaState, lua_CFunction panicf); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_pushcclosure + /// + /// We never call this with n != 0, so does very little. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_pushcclosure + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])] + private static partial void lua_pushcclosure(lua_State luaState, lua_CFunction fn, int n); + + // Helper methods for using the pinvokes defined above + /// /// Returns true if the given index on the stack holds a string or a number. /// @@ -129,19 +298,15 @@ internal static partial class NativeMethods /// internal static bool CheckBuffer(lua_State luaState, int index, out ReadOnlySpan str) { - var type = lua_type(luaState, index); - - if (type is not LuaType.String and not LuaType.Number) - { - str = []; - return false; - } - + // See: https://www.lua.org/source/5.4/lapi.c.html#lua_tolstring + // + // If lua_tolstring fails, it will set len == 0 and start == NULL var start = lua_tolstring(luaState, index, out var len); + unsafe { str = new ReadOnlySpan((byte*)start, (int)len); - return true; + return start != (charptr_t)(void*)null; } } @@ -168,12 +333,15 @@ internal static void KnownStringToBuffer(lua_State luaState, int index, out Read /// /// Provided data is copied, and can be reused once this call returns. /// - internal static unsafe void PushBuffer(lua_State luaState, ReadOnlySpan str) + internal static unsafe ref byte PushBuffer(lua_State luaState, ReadOnlySpan str) { + nint inLuaPtr; fixed (byte* ptr = str) { - _ = lua_pushlstring(luaState, (charptr_t)ptr, (size_t)str.Length); + inLuaPtr = lua_pushlstring(luaState, (charptr_t)ptr, (size_t)str.Length); } + + return ref Unsafe.AsRef((void*)inLuaPtr); } /// @@ -239,6 +407,14 @@ internal static void PushBoolean(lua_State luaState, bool b) internal static bool ToBoolean(lua_State luaState, int index) => lua_toboolean(luaState, index) != 0; + /// + /// Read a long off the stack + /// + /// Differs from by suppressing GC transition. + /// + internal static long ToInteger(lua_State luaState, int index) + => lua_tointegerx(luaState, index, 0); + /// /// Remove some number of items from the stack. /// @@ -246,5 +422,172 @@ internal static bool ToBoolean(lua_State luaState, int index) /// internal static void Pop(lua_State luaState, int num) => lua_settop(luaState, -num - 1); + + /// + /// Update the panic function. + /// + /// Differs from by taking a function pointer + /// and suppressing GC transition. + /// + internal static unsafe nint AtPanic(lua_State luaState, delegate* unmanaged[Cdecl] panicFunc) + => lua_atpanic(luaState, (nint)panicFunc); + + /// + /// Create a new Lua state. + /// + internal static unsafe nint NewState() + => luaL_newstate(); + + /// + /// Create a new Lua state. + /// + /// Differs from by taking a function pointer. + /// + internal static unsafe nint NewState(delegate* unmanaged[Cdecl] allocFunc, nint ud) + => lua_newstate((nint)allocFunc, ud); + + /// + /// Open all standard Lua libraries. + /// + internal static void OpenLibs(lua_State luaState) + => luaL_openlibs(luaState); + + /// + /// Close the state, releasing all associated resources. + /// + internal static void Close(lua_State luaState) + => lua_close(luaState); + + /// + /// Reserve space on the stack, returning false if that was not possible. + /// + internal static bool CheckStack(lua_State luaState, int n) + => lua_checkstack(luaState, n) == 1; + + /// + /// Read a number, as a double, out of the stack. + /// + internal static double CheckNumber(lua_State luaState, int n) + => luaL_checknumber(luaState, n); + + /// + /// Gets the length of an object on the stack, ignoring metatable methods. + /// + internal static int RawLen(lua_State luaState, int n) + => lua_rawlen(luaState, n); + + /// + /// Push a function onto the stack. + /// + internal static void PushCFunction(lua_State luaState, nint ptr) + => lua_pushcclosure(luaState, ptr, 0); + + /// + /// Perform a protected call with the given number of arguments, expecting the given number of returns. + /// + internal static LuaStatus PCall(lua_State luaState, int nargs, int nrets) + => (LuaStatus)lua_pcallk(luaState, nargs, nrets, 0, 0, 0); + + /// + /// Perform a call with the given number of arguments, expecting the given number of returns. + /// + internal static void Call(lua_State luaState, int nargs, int nrets) + => lua_callk(luaState, nargs, nrets, 0, 0); + + /// + /// Equivalent of t[i] = v, where t is the table at the given index and v is the value on the top of the stack. + /// + /// Ignores metatable methods. + /// + internal static void RawSetInteger(lua_State luaState, int index, long i) + => lua_rawseti(luaState, index, i); + + /// + /// Equivalent to t[k] = v, where t is the value at the given index, v is the value on the top of the stack, and k is the value just below the top. + /// + /// Ignores metatable methods. + /// + internal static void RawSet(lua_State luaState, int index) + => lua_rawset(luaState, index); + + /// + /// Pushes onto the stack the value t[n], where t is the table at the given index. + /// + /// Ignores metatable methods. + /// + internal static LuaType RawGetInteger(lua_State luaState, int index, long n) + => (LuaType)lua_rawgeti(luaState, index, n); + + /// + /// Pushes onto the stack the value t[k], where t is the value at the given index and k is the value on the top of the stack. + /// + /// Ignores metatable methods. + /// + internal static LuaType RawGet(lua_State luaState, int index) + => (LuaType)lua_rawget(luaState, index); + + /// + /// Creates and reference a reference in that table at the given index. + /// + internal static int Ref(lua_State luaState, int index) + => luaL_ref(luaState, index); + + /// + /// Free a ref previously created with . + /// + internal static int Unref(lua_State luaState, int index, int refVal) + => luaL_unref(luaState, index, refVal); + + /// + /// Create a new table and push it on the stack. + /// + /// Reserves capacity for the given number of elements and hints at capacity for non-sequence records. + /// + internal static void CreateTable(lua_State luaState, int elements, int records) + => lua_createtable(luaState, elements, records); + + /// + /// Load a global under the given name onto the stack. + /// + internal static unsafe LuaType GetGlobal(lua_State luaState, ReadOnlySpan nullTerminatedName) + { + Debug.Assert(nullTerminatedName[^1] == 0, "Global name must be null terminated"); + + fixed (byte* ptr = nullTerminatedName) + { + return (LuaType)lua_getglobal(luaState, (nint)ptr); + } + } + + /// + /// Pops the top item on the stack, and stores it under the given name as a global. + /// + internal static unsafe void SetGlobal(lua_State luaState, ReadOnlySpan nullTerminatedName) + { + Debug.Assert(nullTerminatedName[^1] == 0, "Global name must be null terminated"); + + fixed (byte* ptr = nullTerminatedName) + { + lua_setglobal(luaState, (nint)ptr); + } + } + + /// + /// Sets the index of the top elements on the stack. + /// + /// 0 == empty + /// + /// Items above this point can no longer be safely accessed. + /// + internal static void SetTop(lua_State lua_State, int top) + => lua_settop(lua_State, top); + + /// + /// Raise an error, using the top of the stack as an error item. + /// + /// This method never returns, so be careful calling it. + /// + internal static int Error(lua_State luaState) + => lua_error(luaState); } } \ No newline at end of file diff --git a/libs/server/Lua/SessionScriptCache.cs b/libs/server/Lua/SessionScriptCache.cs index 4dfd1ac296..4d48154f83 100644 --- a/libs/server/Lua/SessionScriptCache.cs +++ b/libs/server/Lua/SessionScriptCache.cs @@ -5,11 +5,9 @@ using System.Collections.Generic; using System.Diagnostics; using System.Security.Cryptography; -using Garnet.common; using Garnet.server.ACL; using Garnet.server.Auth; using Microsoft.Extensions.Logging; -using Tsavorite.core; namespace Garnet.server { @@ -28,6 +26,9 @@ internal sealed class SessionScriptCache : IDisposable readonly Dictionary scriptCache = []; readonly byte[] hash = new byte[SHA1Len / 2]; + readonly LuaMemoryManagementMode memoryManagementMode; + readonly int? memoryLimitBytes; + public SessionScriptCache(StoreWrapper storeWrapper, IGarnetAuthenticator authenticator, ILogger logger = null) { this.storeWrapper = storeWrapper; @@ -35,6 +36,10 @@ public SessionScriptCache(StoreWrapper storeWrapper, IGarnetAuthenticator authen scratchBufferNetworkSender = new ScratchBufferNetworkSender(); processor = new RespServerSession(0, scratchBufferNetworkSender, storeWrapper, null, authenticator, false); + + // There's some parsing involved in these, so save them off per-session + memoryManagementMode = storeWrapper.serverOptions.LuaOptions.MemoryManagementMode; + memoryLimitBytes = storeWrapper.serverOptions.LuaOptions.GetMemoryLimitBytes(); } public void Dispose() @@ -74,21 +79,31 @@ internal bool TryLoad(RespServerSession session, ReadOnlySpan source, Scri { var sourceOnHeap = source.ToArray(); - runner = new LuaRunner(sourceOnHeap, storeWrapper.serverOptions.LuaTransactionMode, processor, scratchBufferNetworkSender, logger); - runner.CompileForSession(session); - - // Need to make sure the key is on the heap, so move it over - // - // There's an implicit assumption that all callers are using unmanaged memory. - // If that becomes untrue, there's an optimization opportunity to re-use the - // managed memory here. - var into = GC.AllocateUninitializedArray(SHA1Len, pinned: true); - digest.CopyTo(into); - - ScriptHashKey storeKeyDigest = new(into); - digestOnHeap = storeKeyDigest; - - _ = scriptCache.TryAdd(storeKeyDigest, runner); + runner = new LuaRunner(memoryManagementMode, memoryLimitBytes, sourceOnHeap, storeWrapper.serverOptions.LuaTransactionMode, processor, scratchBufferNetworkSender, logger); + + // If compilation fails, an error is written out + if (runner.CompileForSession(session)) + { + // Need to make sure the key is on the heap, so move it over + // + // There's an implicit assumption that all callers are using unmanaged memory. + // If that becomes untrue, there's an optimization opportunity to re-use the + // managed memory here. + var into = GC.AllocateUninitializedArray(SHA1Len, pinned: true); + digest.CopyTo(into); + + ScriptHashKey storeKeyDigest = new(into); + digestOnHeap = storeKeyDigest; + + _ = scriptCache.TryAdd(storeKeyDigest, runner); + } + else + { + runner.Dispose(); + + digestOnHeap = null; + return false; + } } catch (Exception ex) { diff --git a/libs/server/Resp/CmdStrings.cs b/libs/server/Resp/CmdStrings.cs index cb32cfc432..8f32e0bbea 100644 --- a/libs/server/Resp/CmdStrings.cs +++ b/libs/server/Resp/CmdStrings.cs @@ -353,6 +353,7 @@ static partial class CmdStrings // Lua scripting strings public static ReadOnlySpan LUA_OK => "OK"u8; + public static ReadOnlySpan LUA_ok => "ok"u8; public static ReadOnlySpan LUA_err => "err"u8; public static ReadOnlySpan LUA_No_session_available => "No session available"u8; public static ReadOnlySpan LUA_ERR_Please_specify_at_least_one_argument_for_this_redis_lib_call => "ERR Please specify at least one argument for this redis lib call"u8; diff --git a/libs/server/Resp/RespServerSession.cs b/libs/server/Resp/RespServerSession.cs index d0e357e2bf..aaf064f57a 100644 --- a/libs/server/Resp/RespServerSession.cs +++ b/libs/server/Resp/RespServerSession.cs @@ -394,6 +394,21 @@ public override int TryConsumeMessages(byte* reqBuffer, int bytesReceived) return readHead; } + /// + /// For testing purposes, call and update state accordingly. + /// + internal void EnterAndGetResponseObject() + => networkSender.EnterAndGetResponseObject(out dcurr, out dend); + + /// + /// For testing purposes, call and update state accordingly. + /// + internal void ExitAndReturnResponseObject() + { + networkSender.ExitAndReturnResponseObject(); + dcurr = dend = (byte*)0; + } + internal void SetTransactionMode(bool enable) => txnManager.state = enable ? TxnState.Running : TxnState.None; diff --git a/libs/server/Servers/GarnetServerOptions.cs b/libs/server/Servers/GarnetServerOptions.cs index 05dce6ca04..23d310842b 100644 --- a/libs/server/Servers/GarnetServerOptions.cs +++ b/libs/server/Servers/GarnetServerOptions.cs @@ -404,6 +404,8 @@ public class GarnetServerOptions : ServerOptions public bool EnableObjectStoreReadCache = false; + public LuaOptions LuaOptions; + /// /// Constructor /// diff --git a/libs/server/Servers/ServerOptions.cs b/libs/server/Servers/ServerOptions.cs index 342b563020..efa6208859 100644 --- a/libs/server/Servers/ServerOptions.cs +++ b/libs/server/Servers/ServerOptions.cs @@ -228,7 +228,7 @@ public void GetSettings() /// /// /// - protected static long ParseSize(string value) + public static long ParseSize(string value) { char[] suffix = ['k', 'm', 'g', 't', 'p']; long result = 0; diff --git a/test/BDNPerfTests/BDN_Benchmark_Config.json b/test/BDNPerfTests/BDN_Benchmark_Config.json index fcb08f23b7..89ee51f692 100644 --- a/test/BDNPerfTests/BDN_Benchmark_Config.json +++ b/test/BDNPerfTests/BDN_Benchmark_Config.json @@ -85,10 +85,90 @@ "expected_MSet_None": 0 }, "BDN.benchmark.Lua.LuaScripts.*": { - "expected_Script1_None": 0, - "expected_Script2_None": 24, - "expected_Script3_None": 32, - "expected_Script4_None": 0 + "expected_Script1_Managed,Limit": 0, + "expected_Script2_Managed,Limit": 24, + "expected_Script3_Managed,Limit": 32, + "expected_Script4_Managed,Limit": 0, + "expected_Script1_Managed,None": 0, + "expected_Script2_Managed,None": 24, + "expected_Script3_Managed,None": 32, + "expected_Script4_Managed,None": 0, + "expected_Script1_Native,None": 0, + "expected_Script2_Native,None": 24, + "expected_Script3_Native,None": 32, + "expected_Script4_Native,None": 0, + "expected_Script1_Tracked,Limit": 0, + "expected_Script2_Tracked,Limit": 24, + "expected_Script3_Tracked,Limit": 32, + "expected_Script4_Tracked,Limit": 0, + "expected_Script1_Tracked,None": 0, + "expected_Script2_Tracked,None": 24, + "expected_Script3_Tracked,None": 32, + "expected_Script4_Tracked,None": 0 + }, + "BDN.benchmark.Lua.LuaRunnerOperations.*": { + "expected_ResetParametersSmall_Managed,Limit": 0, + "expected_ResetParametersLarge_Managed,Limit": 0, + "expected_ConstructSmall_Managed,Limit": 2097602, + "expected_ConstructLarge_Managed,Limit": 2100666, + "expected_CompileForSessionSmall_Managed,Limit": 99, + "expected_CompileForSessionLarge_Managed,Limit": 0, + "expected_ResetParametersSmall_Managed,None": 0, + "expected_ResetParametersLarge_Managed,None": 0, + "expected_ConstructSmall_Managed,None": 2097674, + "expected_ConstructLarge_Managed,None": 2100738, + "expected_CompileForSessionSmall_Managed,None": 512, + "expected_CompileForSessionLarge_Managed,None": 0, + "expected_ResetParametersSmall_Native,None": 0, + "expected_ResetParametersLarge_Native,None": 0, + "expected_ConstructSmall_Native,None": 328, + "expected_ConstructLarge_Native,None": 3392, + "expected_CompileForSessionSmall_Native,None": 0, + "expected_CompileForSessionLarge_Native,None": 0, + "expected_ResetParametersSmall_Tracked,Limit": 0, + "expected_ResetParametersLarge_Tracked,Limit": 0, + "expected_ConstructSmall_Tracked,Limit": 402, + "expected_ConstructLarge_Tracked,Limit": 3465, + "expected_CompileForSessionSmall_Tracked,Limit": 0, + "expected_CompileForSessionLarge_Tracked,Limit": 0, + "expected_ResetParametersSmall_Tracked,None": 0, + "expected_ResetParametersLarge_Tracked,None": 0, + "expected_ConstructSmall_Tracked,None": 362, + "expected_ConstructLarge_Tracked,None": 3425, + "expected_CompileForSessionSmall_Tracked,None": 0, + "expected_CompileForSessionLarge_Tracked,None": 0 + }, + "BDN.benchmark.Lua.LuaScriptCacheOperations.*": { + "expected_LookupHit_Managed,Limit": 1312, + "expected_LookupMiss_Managed,Limit": 688, + "expected_LoadOuterHit_Managed,Limit": 688, + "expected_LoadInnerHit_Managed,Limit": 2098272, + "expected_LoadMiss_Managed,Limit": 1312, + "expected_Digest_Managed,Limit": 1312, + "expected_LookupHit_Managed,None": 688, + "expected_LookupMiss_Managed,None": 1312, + "expected_LoadOuterHit_Managed,None": 688, + "expected_LoadInnerHit_Managed,None": 2097760, + "expected_LoadMiss_Managed,None": 1312, + "expected_Digest_Managed,None": 688, + "expected_LookupHit_Native,None": 1312, + "expected_LookupMiss_Native,None": 1312, + "expected_LoadOuterHit_Native,None": 688, + "expected_LoadInnerHit_Native,None": 1040, + "expected_LoadMiss_Native,None": 1312, + "expected_Digest_Native,None": 688, + "expected_LookupHit_Tracked,Limit": 688, + "expected_LookupMiss_Tracked,Limit": 1312, + "expected_LoadOuterHit_Tracked,Limit": 688, + "expected_LoadInnerHit_Tracked,Limit": 1072, + "expected_LoadMiss_Tracked,Limit": 1264, + "expected_Digest_Tracked,Limit": 1312, + "expected_LookupHit_Tracked,None": 688, + "expected_LookupMiss_Tracked,None": 1312, + "expected_LoadOuterHit_Tracked,None": 1312, + "expected_LoadInnerHit_Tracked,None": 1696, + "expected_LoadMiss_Tracked,None": 1312, + "expected_Digest_Tracked,None": 1312 }, "BDN.benchmark.Operations.CustomOperations.*": { "expected_CustomRawStringCommand_ACL": 0, @@ -159,7 +239,7 @@ "expected_Eval_None": 0, "expected_EvalSha_None": 0, "expected_SmallScript_None": 0, - "expected_LargeScript_None": 12, + "expected_LargeScript_None": 23, "expected_ArrayReturn_None": 0 }, "BDN.benchmark.Network.RawStringOperations.*": { diff --git a/test/BDNPerfTests/run_bdnperftest.ps1 b/test/BDNPerfTests/run_bdnperftest.ps1 index 5c88e0be87..afece51daa 100644 --- a/test/BDNPerfTests/run_bdnperftest.ps1 +++ b/test/BDNPerfTests/run_bdnperftest.ps1 @@ -73,6 +73,11 @@ param ($ResultsLine, $columnNum) $columns = $ResultsLine.Trim('|').Split('|') $column = $columns | ForEach-Object { $_.Trim() } $foundValue = $column[$columnNum] + if ($foundValue -eq "NA") { + Write-Error -Message "The value for the column was NA which means that the BDN test failed and didn't generate performance metrics. Verify the BDN test ran successfully." + exit + } + if ($foundValue -eq "-") { $foundValue = "0" } diff --git a/test/Garnet.test.cluster/ClusterTestContext.cs b/test/Garnet.test.cluster/ClusterTestContext.cs index 8b1529879e..8bc92d8668 100644 --- a/test/Garnet.test.cluster/ClusterTestContext.cs +++ b/test/Garnet.test.cluster/ClusterTestContext.cs @@ -126,7 +126,9 @@ public void CreateInstances( bool disablePubSub = true, int metricsSamplingFrequency = 0, bool enableLua = false, - bool asyncReplay = false) + bool asyncReplay = false, + LuaMemoryManagementMode luaMemoryMode = LuaMemoryManagementMode.Native, + string luaMemoryLimit = "") { endpoints = TestUtils.GetEndPoints(shards, 7000); nodes = TestUtils.CreateGarnetCluster( @@ -159,7 +161,9 @@ public void CreateInstances( authenticationSettings: authenticationSettings, metricsSamplingFrequency: metricsSamplingFrequency, enableLua: enableLua, - asyncReplay: asyncReplay); + asyncReplay: asyncReplay, + luaMemoryMode: luaMemoryMode, + luaMemoryLimit: luaMemoryLimit); foreach (var node in nodes) node.Start(); diff --git a/test/Garnet.test/GarnetServerConfigTests.cs b/test/Garnet.test/GarnetServerConfigTests.cs index adbc02d881..b3aac12f40 100644 --- a/test/Garnet.test/GarnetServerConfigTests.cs +++ b/test/Garnet.test/GarnetServerConfigTests.cs @@ -10,6 +10,7 @@ using System.Text.Json.Serialization; using CommandLine; using Garnet.common; +using Garnet.server; using Microsoft.Extensions.Logging; using NUnit.Framework; using NUnit.Framework.Legacy; @@ -246,5 +247,80 @@ public void ImportExportConfigAzure() deviceFactory.Delete(new FileDescriptor { directoryName = "" }); } } + + [Test] + public void LuaMemoryOptions() + { + // Defaults to Native with no limit + { + var args = new[] { "--lua" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableLua); + ClassicAssert.AreEqual(LuaMemoryManagementMode.Native, options.LuaMemoryManagementMode); + ClassicAssert.IsNull(options.LuaScriptMemoryLimit); + } + + // Native with limit rejected + { + var args = new[] { "--lua", "--lua-script-memory-limit", "10m" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsFalse(parseSuccessful); + } + + // Tracked with no limit works + { + var args = new[] { "--lua", "--lua-memory-management-mode", "Tracked" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableLua); + ClassicAssert.AreEqual(LuaMemoryManagementMode.Tracked, options.LuaMemoryManagementMode); + ClassicAssert.IsNull(options.LuaScriptMemoryLimit); + } + + // Tracked with limit works + { + var args = new[] { "--lua", "--lua-memory-management-mode", "Tracked", "--lua-script-memory-limit", "10m" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableLua); + ClassicAssert.AreEqual(LuaMemoryManagementMode.Tracked, options.LuaMemoryManagementMode); + ClassicAssert.AreEqual("10m", options.LuaScriptMemoryLimit); + } + + // Tracked with bad limit rejected + { + var args = new[] { "--lua", "--lua-memory-management-mode", "Tracked", "--lua-script-memory-limit", "10Q" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsFalse(parseSuccessful); + } + + // Managed with no limit works + { + var args = new[] { "--lua", "--lua-memory-management-mode", "Managed" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableLua); + ClassicAssert.AreEqual(LuaMemoryManagementMode.Managed, options.LuaMemoryManagementMode); + ClassicAssert.IsNull(options.LuaScriptMemoryLimit); + } + + // Managed with limit works + { + var args = new[] { "--lua", "--lua-memory-management-mode", "Managed", "--lua-script-memory-limit", "10m" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableLua); + ClassicAssert.AreEqual(LuaMemoryManagementMode.Managed, options.LuaMemoryManagementMode); + ClassicAssert.AreEqual("10m", options.LuaScriptMemoryLimit); + } + + // Managed with bad limit rejected + { + var args = new[] { "--lua", "--lua-memory-management-mode", "Managed", "--lua-script-memory-limit", "10Q" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsFalse(parseSuccessful); + } + } } } \ No newline at end of file diff --git a/test/Garnet.test/LuaScriptRunnerTests.cs b/test/Garnet.test/LuaScriptRunnerTests.cs index b7a7a7f0f7..7a9bef4ae8 100644 --- a/test/Garnet.test/LuaScriptRunnerTests.cs +++ b/test/Garnet.test/LuaScriptRunnerTests.cs @@ -1,6 +1,12 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using Garnet.common; using Garnet.server; using NUnit.Framework; @@ -15,51 +21,51 @@ internal class LuaScriptRunnerTests public void CannotRunUnsafeScript() { // Try to load an assembly - using (var runner = new LuaRunner("luanet.load_assembly('mscorlib')")) + using (var runner = new LuaRunner(new(), "luanet.load_assembly('mscorlib')")) { runner.CompileForRunner(); var ex = Assert.Throws(() => runner.RunForRunner()); - ClassicAssert.AreEqual("[string \"luanet.load_assembly('mscorlib')\"]:1: attempt to index a nil value (global 'luanet')", ex.Message); + ClassicAssert.AreEqual("ERR Lua encountered an error: [string \"luanet.load_assembly('mscorlib')\"]:1: attempt to index a nil value (global 'luanet')", ex.Message); } // Try to call a OS function - using (var runner = new LuaRunner("os = require('os'); return os.time();")) + using (var runner = new LuaRunner(new(), "os = require('os'); return os.time();")) { runner.CompileForRunner(); var ex = Assert.Throws(() => runner.RunForRunner()); - ClassicAssert.AreEqual("[string \"os = require('os'); return os.time();\"]:1: attempt to call a nil value (global 'require')", ex.Message); + ClassicAssert.AreEqual("ERR Lua encountered an error: [string \"os = require('os'); return os.time();\"]:1: attempt to call a nil value (global 'require')", ex.Message); } // Try to execute the input stream - using (var runner = new LuaRunner("dofile();")) + using (var runner = new LuaRunner(new(), "dofile();")) { runner.CompileForRunner(); var ex = Assert.Throws(() => runner.RunForRunner()); - ClassicAssert.AreEqual("[string \"dofile();\"]:1: attempt to call a nil value (global 'dofile')", ex.Message); + ClassicAssert.AreEqual("ERR Lua encountered an error: [string \"dofile();\"]:1: attempt to call a nil value (global 'dofile')", ex.Message); } // Try to call a windows executable - using (var runner = new LuaRunner("require \"notepad\"")) + using (var runner = new LuaRunner(new(), "require \"notepad\"")) { runner.CompileForRunner(); var ex = Assert.Throws(() => runner.RunForRunner()); - ClassicAssert.AreEqual("[string \"require \"notepad\"\"]:1: attempt to call a nil value (global 'require')", ex.Message); + ClassicAssert.AreEqual("ERR Lua encountered an error: [string \"require \"notepad\"\"]:1: attempt to call a nil value (global 'require')", ex.Message); } // Try to call an OS function - using (var runner = new LuaRunner("os.exit();")) + using (var runner = new LuaRunner(new(), "os.exit();")) { runner.CompileForRunner(); var ex = Assert.Throws(() => runner.RunForRunner()); - ClassicAssert.AreEqual("[string \"os.exit();\"]:1: attempt to index a nil value (global 'os')", ex.Message); + ClassicAssert.AreEqual("ERR Lua encountered an error: [string \"os.exit();\"]:1: attempt to index a nil value (global 'os')", ex.Message); } // Try to include a new .net library - using (var runner = new LuaRunner("import ('System.Diagnostics');")) + using (var runner = new LuaRunner(new(), "import ('System.Diagnostics');")) { runner.CompileForRunner(); var ex = Assert.Throws(() => runner.RunForRunner()); - ClassicAssert.AreEqual("[string \"import ('System.Diagnostics');\"]:1: attempt to call a nil value (global 'import')", ex.Message); + ClassicAssert.AreEqual("ERR Lua encountered an error: [string \"import ('System.Diagnostics');\"]:1: attempt to call a nil value (global 'import')", ex.Message); } } @@ -67,14 +73,14 @@ public void CannotRunUnsafeScript() public void CanLoadScript() { // Code with error - using (var runner = new LuaRunner("local;")) + using (var runner = new LuaRunner(new(), "local;")) { var ex = Assert.Throws(runner.CompileForRunner); ClassicAssert.AreEqual("Compilation error: [string \"local;\"]:1: expected near ';'", ex.Message); } // Code without error - using (var runner = new LuaRunner("local list; list = 1; return list;")) + using (var runner = new LuaRunner(new(), "local list; list = 1; return list;")) { runner.CompileForRunner(); } @@ -87,7 +93,7 @@ public void CanRunScript() string[] args = ["arg1", "arg2", "arg3"]; // Run code without errors - using (var runner = new LuaRunner("local list; list = ARGV[1] ; return list;")) + using (var runner = new LuaRunner(new(), "local list; list = ARGV[1] ; return list;")) { runner.CompileForRunner(); var res = runner.RunForRunner(keys, args); @@ -95,7 +101,7 @@ public void CanRunScript() } // Run code with errors - using (var runner = new LuaRunner("local list; list = ; return list;")) + using (var runner = new LuaRunner(new(), "local list; list = ; return list;")) { var ex = Assert.Throws(runner.CompileForRunner); ClassicAssert.AreEqual("Compilation error: [string \"local list; list = ; return list;\"]:1: unexpected symbol near ';'", ex.Message); @@ -105,7 +111,7 @@ public void CanRunScript() [Test] public void KeysAndArgsCleared() { - using (var runner = new LuaRunner("return { KEYS[1], ARGV[1], KEYS[2], ARGV[2] }")) + using (var runner = new LuaRunner(new(), "return { KEYS[1], ARGV[1], KEYS[2], ARGV[2] }")) { runner.CompileForRunner(); var res1 = runner.RunForRunner(["hello", "world"], ["fizz", "buzz"]); @@ -130,5 +136,374 @@ public void KeysAndArgsCleared() ClassicAssert.AreEqual("345", (string)obj3[2]); } } + + [Test] + public unsafe void LuaLimittedManaged() + { + const int Iters = 20; + const int TotalAllocSizeBytes = 1_024 * 1_024; + + var rand = new Random(2024_12_16); // Repeatable, but random + + // Special cases + { + var luaAlloc = new LuaLimitedManagedAllocator(TotalAllocSizeBytes); + luaAlloc.CheckCorrectness(); + + // 0 sized should work + ref var zero0 = ref luaAlloc.AllocateNew(0, out var failed0); + ClassicAssert.IsFalse(failed0); + ClassicAssert.IsFalse(Unsafe.IsNullRef(ref zero0)); + ref var zero1 = ref luaAlloc.AllocateNew(0, out var failed1); + ClassicAssert.IsFalse(failed1); + ClassicAssert.IsFalse(Unsafe.IsNullRef(ref zero1)); + luaAlloc.CheckCorrectness(); + + luaAlloc.Free(ref zero0, 0); + luaAlloc.Free(ref zero1, 0); + luaAlloc.CheckCorrectness(); + + // Impossibly large fails + ref var failedRef = ref luaAlloc.AllocateNew(TotalAllocSizeBytes * 2, out var failedLarge); + ClassicAssert.IsTrue(failedLarge); + ClassicAssert.IsTrue(Unsafe.IsNullRef(ref failedRef)); + + luaAlloc.CheckCorrectness(); + } + + // Fill whole allocation + { + var luaAlloc = new LuaLimitedManagedAllocator(TotalAllocSizeBytes); + var freeSpace = luaAlloc.FirstBlockSizeBytes; + + for (var i = 0; i < Iters; i++) + { + for (var size = 1; size <= 64; size *= 2) + { + var lastSize = 0L; + var toFree = new List(); + while (true) + { + ref var newData = ref luaAlloc.AllocateNew(size, out var failed); + if (failed) + { + break; + } + DebugClassicAssertIsTrue(luaAlloc.AllocatedBytes > lastSize); + lastSize = luaAlloc.AllocatedBytes; + + var into = new Span(Unsafe.AsPointer(ref newData), size); + into.Fill((byte)toFree.Count); + + toFree.Add((nint)Unsafe.AsPointer(ref newData)); + } + luaAlloc.CheckCorrectness(); + + for (var j = 0; j < toFree.Count; j++) + { + var ptr = toFree[j]; + + var into = new Span((void*)ptr, size); + ClassicAssert.IsFalse(into.ContainsAnyExcept((byte)j)); + + luaAlloc.Free(ref Unsafe.AsRef((void*)ptr), size); + } + luaAlloc.CheckCorrectness(); + + DebugClassicAssertAreEqual(0, luaAlloc.AllocatedBytes); + + _ = luaAlloc.TryCoalesceAllFreeBlocks(); + luaAlloc.CheckCorrectness(); + + ClassicAssert.AreEqual(freeSpace, luaAlloc.FirstBlockSizeBytes); + } + } + } + + // Repeated realloc preserves data and doesn't move. + { + var luaAlloc = new LuaLimitedManagedAllocator(TotalAllocSizeBytes); + var freeSpace = luaAlloc.FirstBlockSizeBytes; + + for (var i = 0; i < Iters; i++) + { + for (var initialSize = 1; initialSize <= 64; initialSize *= 2) + { + ref var initialData = ref luaAlloc.AllocateNew(initialSize, out var failed); + ClassicAssert.IsFalse(failed); + + var size = initialSize; + var val = 1; + + ref var curData = ref initialData; + + MemoryMarshal.CreateSpan(ref curData, initialSize).Fill((byte)val); + + while (true) + { + var newSize = size + rand.Next(4 * 1024) + 1; + ref var newData = ref luaAlloc.ResizeAllocation(ref curData, size, newSize, out failed); + if (failed) + { + ClassicAssert.IsTrue(Unsafe.IsNullRef(ref newData)); + break; + } + + // Byte totals are believable + DebugClassicAssertIsTrue(luaAlloc.AllocatedBytes >= newSize); + + // Shouldn't have moved + ClassicAssert.IsTrue(Unsafe.AreSame(ref initialData, ref newData)); + + // Data preserved + var oldData = MemoryMarshal.CreateReadOnlySpan(ref newData, size); + ClassicAssert.IsFalse(oldData.ContainsAnyExcept((byte)val)); + + // Mutate to check for faults + val++; + MemoryMarshal.CreateSpan(ref newData, newSize).Fill((byte)val); + + // Continue + size = newSize; + curData = ref newData; + } + luaAlloc.CheckCorrectness(); + + // Check final correctness + var finalData = MemoryMarshal.CreateReadOnlySpan(ref curData, size); + ClassicAssert.IsFalse(finalData.ContainsAnyExcept((byte)val)); + + // Hand the one block back, which should fully free everything + luaAlloc.Free(ref curData, size); + + DebugClassicAssertAreEqual(0, luaAlloc.AllocatedBytes); + ClassicAssert.AreEqual(freeSpace, luaAlloc.FirstBlockSizeBytes); + + luaAlloc.CheckCorrectness(); + } + } + } + + // Basic fixed sized allocs + { + const int AllocSize = 16; + + var luaAlloc = new LuaLimitedManagedAllocator(TotalAllocSizeBytes); + var freeSpace = luaAlloc.FirstBlockSizeBytes; + + for (var i = 0; i < Iters; i++) + { + DebugClassicAssertAreEqual(0, luaAlloc.AllocatedBytes); + ClassicAssert.AreEqual(1, luaAlloc.FreeBlockCount); + + var numOps = rand.Next(50) + 1; + + var toFree = new List(); + + // Do a bunch of allocs, all should succeed + for (var j = 0; j < numOps; j++) + { + ref var newData = ref luaAlloc.AllocateNew(AllocSize, out var failed); + ClassicAssert.False(failed); + ClassicAssert.False(Unsafe.IsNullRef(ref newData)); + + var into = new Span(Unsafe.AsPointer(ref newData), AllocSize); + into.Fill((byte)j); + + toFree.Add((nint)Unsafe.AsPointer(ref newData)); + } + luaAlloc.CheckCorrectness(); + + // Each block should be served out of a split, so free list should stay at 1 + ClassicAssert.AreEqual(1, luaAlloc.FreeBlockCount); + + // Check that data wasn't corrupted + for (var j = 0; j < toFree.Count; j++) + { + var ptr = toFree[j]; + var data = new Span((void*)ptr, AllocSize); + var expected = (byte)j; + + ClassicAssert.IsFalse(data.ContainsAnyExcept(expected)); + } + + // Free in a random order + toFree = toFree.Select(p => (Pointer: p, Order: rand.Next())).OrderBy(t => t.Order).Select(t => t.Pointer).ToList(); + for (var j = 0; j < toFree.Count; j++) + { + var ptr = toFree[j]; + ref var asData = ref Unsafe.AsRef((void*)ptr); + + luaAlloc.Free(ref asData, AllocSize); + } + luaAlloc.CheckCorrectness(); + + // Check that all free's didn't corrupt anything + DebugClassicAssertAreEqual(0, luaAlloc.AllocatedBytes); + + // Check that all memory is reclaimable + _ = luaAlloc.TryCoalesceAllFreeBlocks(); + ClassicAssert.AreEqual(freeSpace, luaAlloc.FirstBlockSizeBytes); + luaAlloc.CheckCorrectness(); + } + } + + // Random operations with variable sized allocs + { + var luaAlloc = new LuaLimitedManagedAllocator(TotalAllocSizeBytes); + var freeSpace = luaAlloc.FirstBlockSizeBytes; + + for (var i = 0; i < Iters; i++) + { + DebugClassicAssertAreEqual(0, luaAlloc.AllocatedBytes); + ClassicAssert.AreEqual(1, luaAlloc.FreeBlockCount); + + var toFree = new List<(nint Pointer, byte Expected, int AllocSize)>(); + for (var j = 0; j < 1_000; j++) + { + var op = rand.Next(4); + switch (op) + { + // Allocate + case 0: + { + var allocSize = rand.Next(4 * 1024) + 1; + + ref var newData = ref luaAlloc.AllocateNew(allocSize, out var failed); + ClassicAssert.IsFalse(failed); + ClassicAssert.IsFalse(Unsafe.IsNullRef(ref newData)); + + var into = new Span(Unsafe.AsPointer(ref newData), allocSize); + into.Fill((byte)j); + + toFree.Add(((nint)Unsafe.AsPointer(ref newData), (byte)j, allocSize)); + } + + break; + + // Reallocate + case 1: + { + if (toFree.Count == 0) + { + goto case 0; + } + + var reallocIx = rand.Next(toFree.Count); + var (ptr, expected, size) = toFree[reallocIx]; + + ref var reallocRef = ref Unsafe.AsRef((void*)ptr); + + int newSizeBytes; + if (rand.Next(2) == 0) + { + newSizeBytes = size + rand.Next(32) + 1; + } + else + { + newSizeBytes = size - rand.Next(size); + if (newSizeBytes == 0) + { + goto case 0; + } + } + + ref var updatedRef = ref luaAlloc.ResizeAllocation(ref reallocRef, size, newSizeBytes, out var failed); + ClassicAssert.IsFalse(failed); + + if (newSizeBytes <= size) + { + // Shrink should always leave in place + ClassicAssert.IsTrue(Unsafe.AreSame(ref updatedRef, ref reallocRef)); + } + + var toCheck = MemoryMarshal.CreateReadOnlySpan(ref updatedRef, Math.Min(size, newSizeBytes)); + ClassicAssert.IsFalse(toCheck.ContainsAnyExcept(expected)); + + toFree.RemoveAt(reallocIx); + + var toFill = MemoryMarshal.CreateSpan(ref updatedRef, newSizeBytes); + toFill.Fill((byte)j); + + toFree.Add(((nint)Unsafe.AsPointer(ref updatedRef), (byte)j, newSizeBytes)); + } + + break; + + // Free + case 2: + { + if (toFree.Count == 0) + { + goto case 0; + } + + var freeIx = rand.Next(toFree.Count); + var (ptr, expected, size) = toFree[freeIx]; + + toFree.RemoveAt(freeIx); + luaAlloc.Free(ref Unsafe.AsRef((void*)ptr), size); + } + + break; + + // Validate + case 3: + { + if (toFree.Count == 0) + { + goto case 0; + } + + foreach (var (ptr, expected, size) in toFree) + { + var data = new Span((void*)ptr, size); + ClassicAssert.IsFalse(data.ContainsAnyExcept(expected)); + } + } + + break; + + default: + ClassicAssert.Fail("Unexpected operation"); + break; + } + luaAlloc.CheckCorrectness(); + } + luaAlloc.CheckCorrectness(); + + // Validate and free everything that's left + foreach (var (ptr, expected, size) in toFree) + { + ref var asData = ref Unsafe.AsRef((void*)ptr); + + var data = new Span((void*)ptr, size); + ClassicAssert.IsFalse(data.ContainsAnyExcept(expected)); + + luaAlloc.Free(ref asData, size); + } + luaAlloc.CheckCorrectness(); + + // Check that all free's didn't corrupt anything + DebugClassicAssertAreEqual(0, luaAlloc.AllocatedBytes); + + // Full coalesce gets contiguous blocks back + _ = luaAlloc.TryCoalesceAllFreeBlocks(); + ClassicAssert.AreEqual(freeSpace, luaAlloc.FirstBlockSizeBytes); + + luaAlloc.CheckCorrectness(); + } + } + + // In DEBUG builds, assert condition is true + [Conditional("DEBUG")] + static void DebugClassicAssertIsTrue(bool condition) + => ClassicAssert.IsTrue(condition); + + // In DEBUG builds, assert objects are equal + [Conditional("DEBUG")] + static void DebugClassicAssertAreEqual(object expected, object actual) + => ClassicAssert.AreEqual(expected, actual); + } } } \ No newline at end of file diff --git a/test/Garnet.test/LuaScriptTests.cs b/test/Garnet.test/LuaScriptTests.cs index ace0593340..8b531401f6 100644 --- a/test/Garnet.test/LuaScriptTests.cs +++ b/test/Garnet.test/LuaScriptTests.cs @@ -5,22 +5,37 @@ using System.Linq; using System.Text; using System.Threading.Tasks; +using Garnet.server; using NUnit.Framework; using NUnit.Framework.Legacy; using StackExchange.Redis; namespace Garnet.test { - [TestFixture] + // Limits chosen here to allow completion - if you have to bump them up, consider that you might have introduced a regression + [TestFixture(LuaMemoryManagementMode.Native, "")] + [TestFixture(LuaMemoryManagementMode.Tracked, "")] + [TestFixture(LuaMemoryManagementMode.Tracked, "13m")] + [TestFixture(LuaMemoryManagementMode.Managed, "")] + [TestFixture(LuaMemoryManagementMode.Managed, "15m")] public class LuaScriptTests { + private readonly LuaMemoryManagementMode allocMode; + private readonly string limitBytes; + protected GarnetServer server; + public LuaScriptTests(LuaMemoryManagementMode allocMode, string limitBytes) + { + this.allocMode = allocMode; + this.limitBytes = limitBytes; + } + [SetUp] public void Setup() { TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); - server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, enableLua: true); + server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, enableLua: true, luaMemoryMode: allocMode, luaMemoryLimit: limitBytes); server.Start(); } @@ -303,7 +318,11 @@ public void UseEvalShaLightClient() { var randPostFix = rnd.Next(1, 1000); valueKey = $"{valueKey}{randPostFix}"; + var r = lightClientRequest.SendCommand($"EVALSHA {sha1SetScript} 1 {nameKey}{randPostFix} {valueKey}", 1); + // Check for error reply + ClassicAssert.IsTrue(r[0] != '-'); + var g = Encoding.ASCII.GetString(lightClientRequest.SendCommand($"get {nameKey}{randPostFix}", 1)); var fstEndOfLine = g.IndexOf('\n', StringComparison.OrdinalIgnoreCase) + 1; var strKeyValue = g.Substring(fstEndOfLine, valueKey.Length); @@ -442,7 +461,7 @@ public void ComplexLuaTest3() using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); var db = redis.GetDatabase(0); - for (int i = 0; i < 10; i++) + for (var i = 0; i < 10; i++) { var response1 = (string[])db.ScriptEvaluate(script1, ["key1", "key2"]); ClassicAssert.AreEqual(2, response1.Length); @@ -615,8 +634,8 @@ public void NumberArgumentCoercion() using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); var db = redis.GetDatabase(); - db.StringSet("2", "hello"); - db.StringSet("2.1", "world"); + _ = db.StringSet("2", "hello"); + _ = db.StringSet("2.1", "world"); var res = (string)db.ScriptEvaluate("return redis.call('GET', 2.1)"); ClassicAssert.AreEqual("world", res); @@ -681,14 +700,15 @@ public void ComplexLuaReturns() { if (i != 1) { - tableDepth.Append(", "); + _ = tableDepth.Append(", "); } - tableDepth.Append("{ "); - tableDepth.Append(i); + _ = tableDepth.Append("{ "); + _ = tableDepth.Append(i); } + for (var i = 1; i <= Depth; i++) { - tableDepth.Append(" }"); + _ = tableDepth.Append(" }"); } var script = "return " + tableDepth.ToString(); @@ -785,5 +805,38 @@ public void MetatableReturn() ClassicAssert.AreEqual("jkl", (string)ret[3]); ClassicAssert.AreEqual("def", (string)ret[4]); } + + [Test] + public void IntentionalOOM() + { + if (string.IsNullOrEmpty(limitBytes)) + { + ClassicAssert.Ignore("No memory limit enabled"); + return; + } + + const string ScriptOOMText = @" +local foo = 'abcdefghijklmnopqrstuvwxyz' +if @Ctrl == 'OOM' then + for i = 1, 10000000 do + foo = foo .. foo + end +end + +return foo"; + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var scriptOOM = LuaScript.Prepare(ScriptOOMText); + var loadedScriptOOM = scriptOOM.Load(redis.GetServers()[0]); + + // OOM actually happens and is reported + var exc = ClassicAssert.Throws(() => db.ScriptEvaluate(loadedScriptOOM, new { Ctrl = "OOM" })); + ClassicAssert.AreEqual("ERR Lua encountered an error: not enough memory", exc.Message); + + // We can still run the script without issue (with non-crashing args) afterwards + var res = db.ScriptEvaluate(loadedScriptOOM, new { Ctrl = "Safe" }); + ClassicAssert.AreEqual("abcdefghijklmnopqrstuvwxyz", (string)res); + } } } \ No newline at end of file diff --git a/test/Garnet.test/TestUtils.cs b/test/Garnet.test/TestUtils.cs index 221843ce97..9bf375a706 100644 --- a/test/Garnet.test/TestUtils.cs +++ b/test/Garnet.test/TestUtils.cs @@ -217,7 +217,9 @@ public static GarnetServer CreateGarnetServer( ILogger logger = null, IEnumerable loadModulePaths = null, string pubSubPageSize = null, - bool asyncReplay = false) + bool asyncReplay = false, + LuaMemoryManagementMode luaMemoryMode = LuaMemoryManagementMode.Native, + string luaMemoryLimit = "") { if (UseAzureStorage) IgnoreIfNotRunningAzureTests(); @@ -296,7 +298,8 @@ public static GarnetServer CreateGarnetServer( LoadModuleCS = loadModulePaths, EnableReadCache = enableReadCache, EnableObjectStoreReadCache = enableObjectStoreReadCache, - ReplicationOffsetMaxLag = asyncReplay ? -1 : 0 + ReplicationOffsetMaxLag = asyncReplay ? -1 : 0, + LuaOptions = enableLua ? new LuaOptions(luaMemoryMode, luaMemoryLimit, logger) : null, }; if (!string.IsNullOrEmpty(pubSubPageSize)) @@ -396,7 +399,9 @@ public static GarnetServer[] CreateGarnetCluster( AadAuthenticationSettings authenticationSettings = null, int metricsSamplingFrequency = 0, bool enableLua = false, - bool asyncReplay = false) + bool asyncReplay = false, + LuaMemoryManagementMode luaMemoryMode = LuaMemoryManagementMode.Native, + string luaMemoryLimit = "") { if (UseAzureStorage) IgnoreIfNotRunningAzureTests(); @@ -438,7 +443,9 @@ public static GarnetServer[] CreateGarnetCluster( aadAuthenticationSettings: authenticationSettings, metricsSamplingFrequency: metricsSamplingFrequency, enableLua: enableLua, - asyncReplay: asyncReplay); + asyncReplay: asyncReplay, + luaMemoryMode: luaMemoryMode, + luaMemoryLimit: luaMemoryLimit); ClassicAssert.IsNotNull(opts); int iter = 0; @@ -486,7 +493,9 @@ public static GarnetServerOptions GetGarnetServerOptions( int metricsSamplingFrequency = 0, bool enableLua = false, bool asyncReplay = false, - ILogger logger = null) + ILogger logger = null, + LuaMemoryManagementMode luaMemoryMode = LuaMemoryManagementMode.Native, + string luaMemoryLimit = "") { if (UseAzureStorage) IgnoreIfNotRunningAzureTests(); @@ -571,7 +580,8 @@ public static GarnetServerOptions GetGarnetServerOptions( ClusterUsername = authUsername, ClusterPassword = authPassword, EnableLua = enableLua, - ReplicationOffsetMaxLag = asyncReplay ? -1 : 0 + ReplicationOffsetMaxLag = asyncReplay ? -1 : 0, + LuaOptions = enableLua ? new LuaOptions(luaMemoryMode, luaMemoryLimit) : null, }; if (lowMemory)