diff --git a/.github/workflows/ci-bdnbenchmark.yml b/.github/workflows/ci-bdnbenchmark.yml
index 7ca620bc73..5f3ee17074 100644
--- a/.github/workflows/ci-bdnbenchmark.yml
+++ b/.github/workflows/ci-bdnbenchmark.yml
@@ -42,7 +42,7 @@ jobs:
os: [ ubuntu-latest, windows-latest ]
framework: [ 'net8.0' ]
configuration: [ 'Release' ]
- test: [ 'Operations.BasicOperations', 'Operations.ObjectOperations', 'Operations.HashObjectOperations', 'Cluster.ClusterMigrate', 'Cluster.ClusterOperations', 'Lua.LuaScripts', 'Operations.CustomOperations', 'Operations.RawStringOperations', 'Operations.ScriptOperations','Network.BasicOperations', 'Network.RawStringOperations' ]
+ test: [ 'Operations.BasicOperations', 'Operations.ObjectOperations', 'Operations.HashObjectOperations', 'Cluster.ClusterMigrate', 'Cluster.ClusterOperations', 'Lua.LuaScripts', 'Lua.LuaScriptCacheOperations','Lua.LuaRunnerOperations','Operations.CustomOperations', 'Operations.RawStringOperations', 'Operations.ScriptOperations','Network.BasicOperations', 'Network.RawStringOperations' ]
steps:
- name: Check out code
uses: actions/checkout@v4
diff --git a/benchmark/BDN.benchmark/Lua/LuaParams.cs b/benchmark/BDN.benchmark/Lua/LuaParams.cs
index 85c23fa9d1..3426a74786 100644
--- a/benchmark/BDN.benchmark/Lua/LuaParams.cs
+++ b/benchmark/BDN.benchmark/Lua/LuaParams.cs
@@ -1,26 +1,38 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
+using BenchmarkDotNet.Code;
+using Garnet.server;
+
namespace BDN.benchmark.Lua
{
///
- /// Cluster parameters
+ /// Lua parameters
///
- public struct LuaParams
+ public readonly struct LuaParams
{
+ public readonly LuaMemoryManagementMode Mode { get; }
+ public readonly bool MemoryLimit { get; }
+
///
/// Constructor
///
- public LuaParams()
+ public LuaParams(LuaMemoryManagementMode mode, bool memoryLimit)
{
+ Mode = mode;
+ MemoryLimit = memoryLimit;
}
+ ///
+ /// Get the equivalent .
+ ///
+ public LuaOptions CreateOptions()
+ => new(Mode, MemoryLimit ? "2m" : "");
+
///
/// String representation
///
public override string ToString()
- {
- return "None";
- }
+ => $"{Mode},{(MemoryLimit ? "Limit" : "None")}";
}
}
\ No newline at end of file
diff --git a/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs b/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs
index 07d28e641b..a100249efe 100644
--- a/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs
+++ b/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs
@@ -139,9 +139,13 @@ public unsafe class LuaRunnerOperations
/// Lua parameters provider
///
public IEnumerable LuaParamsProvider()
- {
- yield return new();
- }
+ => [
+ new(LuaMemoryManagementMode.Native, false),
+ new(LuaMemoryManagementMode.Tracked, false),
+ new(LuaMemoryManagementMode.Tracked, true),
+ new(LuaMemoryManagementMode.Managed, false),
+ new(LuaMemoryManagementMode.Managed, true),
+ ];
private EmbeddedRespServer server;
private RespServerSession session;
@@ -151,16 +155,21 @@ public IEnumerable LuaParamsProvider()
private LuaRunner smallCompileRunner;
private LuaRunner largeCompileRunner;
+ private LuaOptions opts;
+
[GlobalSetup]
public void GlobalSetup()
{
- server = new EmbeddedRespServer(new GarnetServerOptions() { EnableLua = true, QuietMode = true });
+ opts = Params.CreateOptions();
+
+ server = new EmbeddedRespServer(new GarnetServerOptions() { EnableLua = true, QuietMode = true, LuaOptions = opts });
session = server.GetRespSession();
- paramsRunner = new LuaRunner("return nil");
- smallCompileRunner = new LuaRunner(SmallScript);
- largeCompileRunner = new LuaRunner(LargeScript);
+ paramsRunner = new LuaRunner(opts, "return nil");
+
+ smallCompileRunner = new LuaRunner(opts, SmallScript);
+ largeCompileRunner = new LuaRunner(opts, LargeScript);
}
[GlobalCleanup]
@@ -171,6 +180,18 @@ public void GlobalCleanup()
paramsRunner.Dispose();
}
+ [IterationSetup]
+ public void IterationSetup()
+ {
+ session.EnterAndGetResponseObject();
+ }
+
+ [IterationCleanup]
+ public void IterationCleanup()
+ {
+ session.ExitAndReturnResponseObject();
+ }
+
[Benchmark]
public void ResetParametersSmall()
{
@@ -194,13 +215,13 @@ public void ResetParametersLarge()
[Benchmark]
public void ConstructSmall()
{
- using var runner = new LuaRunner(SmallScript);
+ using var runner = new LuaRunner(opts, SmallScript);
}
[Benchmark]
public void ConstructLarge()
{
- using var runner = new LuaRunner(LargeScript);
+ using var runner = new LuaRunner(opts, LargeScript);
}
[Benchmark]
diff --git a/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs b/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs
index a996e45469..430a92ca59 100644
--- a/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs
+++ b/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs
@@ -22,9 +22,13 @@ public class LuaScriptCacheOperations
/// Lua parameters provider
///
public IEnumerable LuaParamsProvider()
- {
- yield return new();
- }
+ => [
+ new(LuaMemoryManagementMode.Native, false),
+ new(LuaMemoryManagementMode.Tracked, false),
+ new(LuaMemoryManagementMode.Tracked, true),
+ new(LuaMemoryManagementMode.Managed, false),
+ new(LuaMemoryManagementMode.Managed, true),
+ ];
private EmbeddedRespServer server;
private StoreWrapper storeWrapper;
@@ -38,7 +42,9 @@ public IEnumerable LuaParamsProvider()
[GlobalSetup]
public void GlobalSetup()
{
- server = new EmbeddedRespServer(new GarnetServerOptions() { EnableLua = true, QuietMode = true });
+ var options = Params.CreateOptions();
+
+ server = new EmbeddedRespServer(new GarnetServerOptions() { EnableLua = true, QuietMode = true, LuaOptions = options });
storeWrapper = server.StoreWrapper;
sessionScriptCache = new SessionScriptCache(storeWrapper, new GarnetNoAuthAuthenticator());
session = server.GetRespSession();
diff --git a/benchmark/BDN.benchmark/Lua/LuaScripts.cs b/benchmark/BDN.benchmark/Lua/LuaScripts.cs
index 3cff4b30e3..fc2a2c7269 100644
--- a/benchmark/BDN.benchmark/Lua/LuaScripts.cs
+++ b/benchmark/BDN.benchmark/Lua/LuaScripts.cs
@@ -22,9 +22,13 @@ public unsafe class LuaScripts
/// Lua parameters provider
///
public IEnumerable LuaParamsProvider()
- {
- yield return new();
- }
+ => [
+ new(LuaMemoryManagementMode.Native, false),
+ new(LuaMemoryManagementMode.Tracked, false),
+ new(LuaMemoryManagementMode.Tracked, true),
+ new(LuaMemoryManagementMode.Managed, false),
+ new(LuaMemoryManagementMode.Managed, true),
+ ];
LuaRunner r1, r2, r3, r4;
readonly string[] keys = ["key1"];
@@ -32,13 +36,15 @@ public IEnumerable LuaParamsProvider()
[GlobalSetup]
public void GlobalSetup()
{
- r1 = new LuaRunner("return");
+ var options = Params.CreateOptions();
+
+ r1 = new LuaRunner(options, "return");
r1.CompileForRunner();
- r2 = new LuaRunner("return 1 + 1");
+ r2 = new LuaRunner(options, "return 1 + 1");
r2.CompileForRunner();
- r3 = new LuaRunner("return KEYS[1]");
+ r3 = new LuaRunner(options, "return KEYS[1]");
r3.CompileForRunner();
- r4 = new LuaRunner("return redis.call(KEYS[1])");
+ r4 = new LuaRunner(options, "return redis.call(KEYS[1])");
r4.CompileForRunner();
}
diff --git a/benchmark/BDN.benchmark/Operations/OperationsBase.cs b/benchmark/BDN.benchmark/Operations/OperationsBase.cs
index 8677331de3..c47a550a4f 100644
--- a/benchmark/BDN.benchmark/Operations/OperationsBase.cs
+++ b/benchmark/BDN.benchmark/Operations/OperationsBase.cs
@@ -55,6 +55,7 @@ public virtual void GlobalSetup()
QuietMode = true,
EnableLua = true,
DisablePubSub = true,
+ LuaOptions = new(LuaMemoryManagementMode.Native, ""),
};
if (Params.useAof)
{
diff --git a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs
index 8a6cd34f08..b8298912eb 100644
--- a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs
+++ b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs
@@ -4,8 +4,10 @@
using System.Runtime.CompilerServices;
using System.Security.Cryptography;
using System.Text;
+using BDN.benchmark.Lua;
using BenchmarkDotNet.Attributes;
using Embedded.server;
+using Garnet.server;
namespace BDN.benchmark.Operations
{
@@ -13,7 +15,7 @@ namespace BDN.benchmark.Operations
/// Benchmark for SCRIPT LOAD, SCRIPT EXISTS, EVAL, and EVALSHA
///
[MemoryDiagnoser]
- public unsafe class ScriptOperations : OperationsBase
+ public unsafe class ScriptOperations
{
// Small script that does 1 operation and no logic
const string SmallScriptText = @"return redis.call('GET', KEYS[1]);";
@@ -156,9 +158,54 @@ public unsafe class ScriptOperations : OperationsBase
static ReadOnlySpan ARRAY_RETURN => "*3\r\n$4\r\nEVAL\r\n$22\r\nreturn {1, 2, 3, 4, 5}\r\n$1\r\n0\r\n"u8;
Request arrayReturn;
- public override void GlobalSetup()
+
+ ///
+ /// Lua parameters
+ ///
+ [ParamsSource(nameof(LuaParamsProvider))]
+ public LuaParams Params { get; set; }
+
+ ///
+ /// Lua parameters provider
+ ///
+ public static IEnumerable LuaParamsProvider()
+ => [
+ new(LuaMemoryManagementMode.Native, false),
+ new(LuaMemoryManagementMode.Tracked, false),
+ new(LuaMemoryManagementMode.Tracked, true),
+ new(LuaMemoryManagementMode.Managed, false),
+ new(LuaMemoryManagementMode.Managed, true),
+ ];
+
+ ///
+ /// Batch size per method invocation
+ /// With a batchSize of 100, we have a convenient conversion of latency to throughput:
+ /// 5 us = 20 Mops/sec
+ /// 10 us = 10 Mops/sec
+ /// 20 us = 5 Mops/sec
+ /// 25 us = 4 Mops/sec
+ /// 100 us = 1 Mops/sec
+ ///
+ internal const int batchSize = 100;
+ internal EmbeddedRespServer server;
+ internal RespServerSession session;
+
+ ///
+ /// Setup
+ ///
+ [GlobalSetup]
+ public void GlobalSetup()
{
- base.GlobalSetup();
+ var opts = new GarnetServerOptions
+ {
+ QuietMode = true,
+ EnableLua = true,
+ LuaOptions = Params.CreateOptions(),
+ };
+
+ server = new EmbeddedRespServer(opts);
+
+ session = server.GetRespSession();
SetupOperation(ref scriptLoad, SCRIPT_LOAD);
@@ -216,6 +263,16 @@ public override void GlobalSetup()
SetupOperation(ref evalShaLargeScript, largeScriptEvals);
}
+ ///
+ /// Cleanup
+ ///
+ [GlobalCleanup]
+ public virtual void GlobalCleanup()
+ {
+ session.Dispose();
+ server.Dispose();
+ }
+
[Benchmark]
public void ScriptLoad()
{
@@ -263,5 +320,36 @@ public void ArrayReturn()
{
Send(arrayReturn);
}
+
+ private void Send(Request request)
+ {
+ _ = session.TryConsumeMessages(request.bufferPtr, request.buffer.Length);
+ }
+
+ private unsafe void SetupOperation(ref Request request, ReadOnlySpan operation, int batchSize = batchSize)
+ {
+ request.buffer = GC.AllocateArray(operation.Length * batchSize, pinned: true);
+ request.bufferPtr = (byte*)Unsafe.AsPointer(ref request.buffer[0]);
+ for (int i = 0; i < batchSize; i++)
+ operation.CopyTo(new Span(request.buffer).Slice(i * operation.Length));
+ }
+
+ private unsafe void SetupOperation(ref Request request, string operation, int batchSize = batchSize)
+ {
+ request.buffer = GC.AllocateUninitializedArray(operation.Length * batchSize, pinned: true);
+ for (var i = 0; i < batchSize; i++)
+ {
+ var start = i * operation.Length;
+ Encoding.UTF8.GetBytes(operation, request.buffer.AsSpan().Slice(start, operation.Length));
+ }
+ request.bufferPtr = (byte*)Unsafe.AsPointer(ref request.buffer[0]);
+ }
+
+ private unsafe void SetupOperation(ref Request request, List operationBytes)
+ {
+ request.buffer = GC.AllocateUninitializedArray(operationBytes.Count, pinned: true);
+ operationBytes.CopyTo(request.buffer);
+ request.bufferPtr = (byte*)Unsafe.AsPointer(ref request.buffer[0]);
+ }
}
}
\ No newline at end of file
diff --git a/libs/host/Configuration/Options.cs b/libs/host/Configuration/Options.cs
index b32cf175a3..8922292bf9 100644
--- a/libs/host/Configuration/Options.cs
+++ b/libs/host/Configuration/Options.cs
@@ -8,6 +8,7 @@
using System.Linq;
using System.Net;
using System.Reflection;
+using System.Runtime.InteropServices;
using System.Security.Cryptography.X509Certificates;
using CommandLine;
using Garnet.server;
@@ -508,6 +509,14 @@ internal sealed class Options
[Option("fail-on-recovery-error", Default = false, Required = false, HelpText = "Server bootup should fail if errors happen during bootup of AOF and checkpointing")]
public bool? FailOnRecoveryError { get; set; }
+ [Option("lua-memory-management-mode", Default = LuaMemoryManagementMode.Native, Required = false, HelpText = "Memory management mode for Lua scripts, must be set to LimittedNative or Managed to impose script limits")]
+ public LuaMemoryManagementMode LuaMemoryManagementMode { get; set; }
+
+ [MemorySizeValidation(false)]
+ [ForbiddenWithOption(nameof(LuaMemoryManagementMode), nameof(LuaMemoryManagementMode.Native))]
+ [Option("lua-script-memory-limit", Default = null, HelpText = "Memory limit for a Lua instances while running a script, lua-memory-management-mode must be set to something other than Native to use this flag")]
+ public string LuaScriptMemoryLimit { get; set; }
+
///
/// This property contains all arguments that were not parsed by the command line argument parser
///
@@ -718,7 +727,8 @@ public GarnetServerOptions GetServerOptions(ILogger logger = null)
IndexResizeFrequencySecs = IndexResizeFrequencySecs,
IndexResizeThreshold = IndexResizeThreshold,
LoadModuleCS = LoadModuleCS,
- FailOnRecoveryError = FailOnRecoveryError.GetValueOrDefault()
+ FailOnRecoveryError = FailOnRecoveryError.GetValueOrDefault(),
+ LuaOptions = EnableLua.GetValueOrDefault() ? new LuaOptions(LuaMemoryManagementMode, LuaScriptMemoryLimit, logger) : null,
};
}
diff --git a/libs/host/Configuration/OptionsValidators.cs b/libs/host/Configuration/OptionsValidators.cs
index ea7f64682e..94d97cced8 100644
--- a/libs/host/Configuration/OptionsValidators.cs
+++ b/libs/host/Configuration/OptionsValidators.cs
@@ -563,4 +563,42 @@ protected override ValidationResult IsValid(object value, ValidationContext vali
return base.IsValid(value, validationContext);
}
}
+
+ ///
+ /// Forbids a config option from being set if the another option has particular values.
+ ///
+ [AttributeUsage(AttributeTargets.Property)]
+ internal sealed class ForbiddenWithOptionAttribute : ValidationAttribute
+ {
+ private readonly string otherOptionName;
+ private readonly string[] forbiddenValues;
+
+ internal ForbiddenWithOptionAttribute(string otherOptionName, string forbiddenValue, params string[] otherForbiddenValues)
+ {
+ this.otherOptionName = otherOptionName;
+ forbiddenValues = [forbiddenValue, .. otherForbiddenValues];
+ }
+
+ ///
+ protected override ValidationResult IsValid(object value, ValidationContext validationContext)
+ {
+ var optionIsSet = value != null && !(value is string valueStr && string.IsNullOrEmpty(valueStr));
+ if (optionIsSet)
+ {
+ var propAccessor = validationContext.ObjectInstance?.GetType()?.GetProperty(otherOptionName, BindingFlags.Instance | BindingFlags.Public);
+ if (propAccessor != null)
+ {
+ var otherOptionValue = propAccessor.GetValue(validationContext.ObjectInstance);
+ var otherOptionValueAsString = otherOptionValue is string strVal ? strVal : otherOptionValue?.ToString();
+
+ if (forbiddenValues.Contains(otherOptionValueAsString, StringComparer.OrdinalIgnoreCase))
+ {
+ return new ValidationResult($"{nameof(validationContext.DisplayName)} cannot be set with {otherOptionName} has value '{otherOptionValueAsString}'");
+ }
+ }
+ }
+
+ return ValidationResult.Success;
+ }
+ }
}
\ No newline at end of file
diff --git a/libs/host/defaults.conf b/libs/host/defaults.conf
index 5a19569653..4c08ebec06 100644
--- a/libs/host/defaults.conf
+++ b/libs/host/defaults.conf
@@ -342,5 +342,11 @@
"LoadModuleCS": null,
/* Fails if encounters error during AOF replay or checkpointing */
- "FailOnRecoveryError": false
+ "FailOnRecoveryError": false,
+
+ /* Lua uses the default, unmanaged and untracked, allocator */
+ "LuaMemoryManagementMode": "Native",
+
+ /* Lua limits are ignored for Native, but can be set for other modes */
+ "LuaScriptMemoryLimit": null
}
\ No newline at end of file
diff --git a/libs/server/ArgSlice/ScratchBufferManager.cs b/libs/server/ArgSlice/ScratchBufferManager.cs
index e2106bfaa7..42fa904690 100644
--- a/libs/server/ArgSlice/ScratchBufferManager.cs
+++ b/libs/server/ArgSlice/ScratchBufferManager.cs
@@ -5,6 +5,7 @@
using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
using System.Text;
using Garnet.common;
@@ -297,15 +298,19 @@ void ExpandScratchBufferIfNeeded(int newLength)
ExpandScratchBuffer(scratchBufferOffset + newLength);
}
- void ExpandScratchBuffer(int newLength)
+ void ExpandScratchBuffer(int newLength, int? copyLengthOverride = null)
{
if (newLength < 64) newLength = 64;
else newLength = (int)BitOperations.RoundUpToPowerOf2((uint)newLength + 1);
var _scratchBuffer = GC.AllocateArray(newLength, true);
- var _scratchBufferHead = (byte*)Unsafe.AsPointer(ref _scratchBuffer[0]);
- if (scratchBufferOffset > 0)
- new ReadOnlySpan(scratchBufferHead, scratchBufferOffset).CopyTo(new Span(_scratchBufferHead, scratchBufferOffset));
+ var _scratchBufferHead = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetArrayDataReference(_scratchBuffer));
+
+ var copyLength = copyLengthOverride ?? scratchBufferOffset;
+ if (copyLength > 0)
+ {
+ new ReadOnlySpan(scratchBufferHead, copyLength).CopyTo(new Span(_scratchBufferHead, copyLength));
+ }
scratchBuffer = _scratchBuffer;
scratchBufferHead = _scratchBufferHead;
}
@@ -326,8 +331,12 @@ public ArgSlice GetSliceFromTail(int length)
///
/// Force backing buffer to grow.
+ ///
+ /// provides a way to force a chunk at the start of the
+ /// previous buffer be copied into the new buffer, even if this
+ /// doesn't consider that chunk in use.
///
- public void GrowBuffer()
+ public void GrowBuffer(int? copyLengthOverride = null)
{
if (scratchBuffer == null)
{
@@ -335,7 +344,7 @@ public void GrowBuffer()
}
else
{
- ExpandScratchBuffer(scratchBuffer.Length + 1);
+ ExpandScratchBuffer(scratchBuffer.Length + 1, copyLengthOverride);
}
}
}
diff --git a/libs/server/Lua/ILuaAllocator.cs b/libs/server/Lua/ILuaAllocator.cs
new file mode 100644
index 0000000000..78be36a9ea
--- /dev/null
+++ b/libs/server/Lua/ILuaAllocator.cs
@@ -0,0 +1,53 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+namespace Garnet.server
+{
+ ///
+ /// Common interface for Lua allocators.
+ ///
+ /// Lua itself has a somewhat esoteric alloc interface,
+ /// this maps to something akin to malloc/free/realloc though
+ /// with some C# niceties.
+ ///
+ /// Note that all returned references must be pinned, as Lua is not aware
+ /// of the .NET GC.
+ ///
+ internal interface ILuaAllocator
+ {
+ ///
+ /// Allocate a new chunk of memory of at least size.
+ ///
+ /// Note that 0-sized allocations MUST succeed and MUST return a non-null reference.
+ /// A 0-sized allocation will NOT be dereferenced.
+ ///
+ /// If it cannot be satisfied, must be set to false.
+ ///
+ ref byte AllocateNew(int sizeBytes, out bool failed);
+
+ ///
+ /// Free an allocation previously obtained from
+ /// or .
+ ///
+ /// will match sizeBytes or newSizeBytes, respectively, that
+ /// was passed when the allocation was obtained.
+ ///
+ void Free(ref byte start, int sizeBytes);
+
+ ///
+ /// Resize an existing allocation.
+ ///
+ /// Data up to must be preserved.
+ ///
+ /// will be a previously obtained, non-null, allocation.
+ /// will be the sizeBytes or newSizeBytes that was passed when the allocaiton was obtained.
+ ///
+ /// This should resize in place if possible.
+ ///
+ /// If in place resizeing is not possible, the previously allocation will be freed if this method succeeds.
+ ///
+ /// If this allocation cannot be satisifed, must be set to false.
+ ///
+ ref byte ResizeAllocation(ref byte start, int oldSizeBytes, int newSizeBytes, out bool failed);
+ }
+}
\ No newline at end of file
diff --git a/libs/server/Lua/LuaLimitedManagedAllocator.cs b/libs/server/Lua/LuaLimitedManagedAllocator.cs
new file mode 100644
index 0000000000..7d63b55ee8
--- /dev/null
+++ b/libs/server/Lua/LuaLimitedManagedAllocator.cs
@@ -0,0 +1,1088 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+using System;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Numerics;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+
+namespace Garnet.server
+{
+ ///
+ /// Provides a mapping of Lua allocations onto the POH.
+ ///
+ /// Pre-allocates the full maximum allocation.
+ ///
+ ///
+ /// This is a really naive allocator, just has a free list tracked with pointers in block headers.
+ ///
+ /// The hope is that Lua's GC keeps the actual use of this down substantially.
+ ///
+ /// A small optimization is trying to keep a big free block at the head of the free list.
+ ///
+ internal sealed class LuaLimitedManagedAllocator : ILuaAllocator
+ {
+ ///
+ /// Minimum size we'll round all Lua allocs up to.
+ ///
+ /// Based on largest "normal" type Lua will allocate and the needs of .
+ ///
+ private const int LuaAllocMinSizeBytes = 16;
+
+ ///
+ /// Represents a block of memory in this mapper.
+ ///
+ /// Blocks always have a size of at least .
+ ///
+ /// Blocks are either free or in use:
+ /// - if free, will be positive.
+ /// - if in use, will be negative.
+ ///
+ /// If free, the block is in a doublely linked list of free blocks.
+ /// - is undefined if block is in use, 0 if block is tail of list, and ref (as long) to next block otherwise
+ /// - is undefined if block is in use, 0 if block is head of list, and ref (as long) to previous block otherwise
+ ///
+ [StructLayout(LayoutKind.Explicit, Size = StructSizeBytes)]
+ private struct BlockHeader
+ {
+ public const int DataOffset = sizeof(int);
+ public const int StructSizeBytes = DataOffset + LuaAllocMinSizeBytes;
+
+ ///
+ /// Size of the block - always valid.
+ ///
+ /// If negative, block is in use.
+ ///
+ /// If you don't already know the state of the block, use .
+ /// If you do, use this as it elides a conditional.
+ ///
+ [FieldOffset(0)]
+ public int SizeBytesRaw;
+
+ ///
+ /// Ref of previous block if this block is in the free list.
+ ///
+ /// Only valid if the block is free.
+ ///
+ /// Most reads should go through or .
+ ///
+ [FieldOffset(sizeof(int))]
+ public long PrevFreeBlockRefRaw;
+
+ ///
+ /// Ref of next block if this block is in the free list.
+ ///
+ /// Only valid if block is free.
+ ///
+ /// Most reads should go through or .
+ ///
+ [FieldOffset(sizeof(int) + sizeof(long))]
+ public long NextFreeBlockRefRaw;
+
+ ///
+ /// True if block is free.
+ ///
+ public readonly bool IsFree
+ => SizeBytesRaw > 0;
+
+ ///
+ /// True if block is in use.
+ ///
+ public readonly bool IsInUse
+ => SizeBytesRaw < 0;
+
+ ///
+ /// Size of the block in bytes.
+ ///
+ /// Prefer to checking directly, as it accounts for state bits.
+ ///
+ public readonly int SizeBytes
+ {
+ get
+ {
+ Debug.Assert(IsFree || IsInUse, "Illegal state, neither free nor in use");
+
+ if (IsFree)
+ {
+ return SizeBytesRaw;
+ }
+ else
+ {
+ return -SizeBytesRaw;
+ }
+ }
+ }
+
+ ///
+ /// Ref (as a long) of the next block in the free list, if any.
+ ///
+ public readonly long NextFreeBlockRef
+ {
+ get
+ {
+ Debug.Assert(IsFree, "Can't be in free list if allocated");
+
+ return NextFreeBlockRefRaw;
+ }
+ }
+
+ ///
+ /// Ref of the previous block in the free list, if any.
+ ///
+ public readonly long PrevFreeBlockRef
+ {
+ get
+ {
+ Debug.Assert(IsFree, "Can't be in free list if allocated");
+
+ return PrevFreeBlockRefRaw;
+ }
+ }
+
+ ///
+ /// Grab a reference to return to users.
+ ///
+ [UnscopedRef]
+ public ref byte DataReference
+ => ref Unsafe.AddByteOffset(ref Unsafe.As(ref this), DataOffset);
+
+ ///
+ /// For debugging purposes, all the data covered by this block.
+ ///
+ public ReadOnlySpan Data
+ => MemoryMarshal.CreateReadOnlySpan(ref DataReference, SizeBytes);
+
+ ///
+ /// Mark block free.
+ ///
+ /// After this, the block can be placed in the free list.
+ ///
+ public void MarkFree()
+ {
+ Debug.Assert(IsInUse, "Double free");
+
+ SizeBytesRaw = -SizeBytesRaw;
+ }
+
+ ///
+ /// Mark block in use.
+ ///
+ /// After this, the and properties cannot be accessed.
+ ///
+ public void MarkInUse()
+ {
+ Debug.Assert(IsFree, "Already allocated");
+
+ SizeBytesRaw = -SizeBytesRaw;
+ }
+
+ ///
+ /// Get a reference to the next block in the free list, or a reference BEFORE .
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public readonly unsafe ref BlockHeader GetNextFreeBlockRef()
+ => ref Unsafe.AsRef((void*)NextFreeBlockRef);
+
+ ///
+ /// Get a reference to the prev block in the free list, or a reference BEFORE .
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public readonly unsafe ref BlockHeader GetPrevFreeBlockRef()
+ => ref Unsafe.AsRef((void*)PrevFreeBlockRef);
+
+ ///
+ /// Get the block the comes after this one in data.
+ ///
+ /// If we're at the end of data, returns a a ref BEFORE .
+ ///
+ [UnscopedRef]
+ public ref BlockHeader GetNextAdjacentBlockRef(ref byte dataStartRef, int finalOffset)
+ {
+ ref var selfRef = ref Unsafe.As(ref this);
+ ref var nextRef = ref Unsafe.Add(ref selfRef, DataOffset + SizeBytes);
+
+ var nextOffset = Unsafe.ByteOffset(ref dataStartRef, ref nextRef);
+
+ if (nextOffset >= finalOffset)
+ {
+ // For symmetry with GetXXXFreeBlockRef return of (void*)0
+ unsafe
+ {
+ return ref Unsafe.AsRef((void*)0);
+ }
+ }
+
+ return ref Unsafe.As(ref nextRef);
+ }
+
+ ///
+ /// Get a value that can be stashed in or
+ /// that refers to this block.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public unsafe long GetRefVal()
+ => (long)Unsafe.AsPointer(ref this);
+ }
+
+ // Very basic allocation story, we just keep block metadata and
+ // a free list in the array.
+ private readonly Memory data;
+
+ private unsafe void* dataStartPtr;
+ private unsafe void* freeListStartPtr;
+
+ private int debugAllocatedBytes;
+
+ ///
+ /// For testing purposes, how many bytes are allocated to Lua at the moment.
+ ///
+ /// This is how many bytes we've handed out, not the size of the backing array on the POH.
+ ///
+ /// Only available in DEBUG to avoid updating this in RELEASE builds.
+ ///
+ internal int AllocatedBytes
+ => debugAllocatedBytes;
+
+ ///
+ /// For testing purposes, the number of blocks tracked in the free list.
+ ///
+ internal int FreeBlockCount
+ {
+ get
+ {
+ ref var dataStartRef = ref GetDataStartRef();
+
+ ref var cur = ref GetFreeList();
+
+ var ret = 0;
+
+ while (IsValidBlockRef(ref cur))
+ {
+ ret++;
+
+ cur = ref cur.GetNextFreeBlockRef();
+ }
+
+ return ret;
+ }
+ }
+
+ ///
+ /// For testing purposes, the size of the initial block.
+ ///
+ /// Does not care if the block is free or allocated.
+ ///
+ internal int FirstBlockSizeBytes
+ {
+ get
+ {
+ ref var dataStartRef = ref GetDataStartRef();
+
+ ref var firstBlock = ref Unsafe.As(ref dataStartRef);
+
+ return firstBlock.SizeBytes;
+ }
+ }
+
+ internal LuaLimitedManagedAllocator(int backingArraySize)
+ {
+ Debug.Assert(backingArraySize >= BlockHeader.StructSizeBytes, "Too small to ever allocate");
+
+ // Pinned because Lua is oblivious to .NET's GC
+ data = GC.AllocateUninitializedArray(backingArraySize, pinned: true);
+ unsafe
+ {
+ dataStartPtr = Unsafe.AsPointer(ref MemoryMarshal.GetReference(data.Span));
+ }
+
+ ref var dataStartRef = ref GetDataStartRef();
+
+ // Initialize the first block as a free block covering the whole allocation
+ unsafe
+ {
+ freeListStartPtr = Unsafe.AsPointer(ref dataStartRef);
+ }
+ ref var firstBlock = ref GetFreeList();
+ firstBlock.SizeBytesRaw = backingArraySize - BlockHeader.DataOffset;
+ firstBlock.NextFreeBlockRefRaw = 0;
+ firstBlock.PrevFreeBlockRefRaw = 0;
+ }
+
+ ///
+ /// Allocate a new chunk which can fit at least of data.
+ ///
+ /// Sets to true if the allocation failed.
+ ///
+ /// If the allocation failes, the returned ref will be null.
+ ///
+ public ref byte AllocateNew(int sizeBytes, out bool failed)
+ {
+ ref var dataStartRef = ref GetDataStartRef();
+
+ var actualSizeBytes = RoundToMinAlloc(sizeBytes);
+
+ var firstAttempt = true;
+
+ tryAgain:
+ ref var freeList = ref GetFreeList();
+ ref var cur = ref freeList;
+ while (IsValidBlockRef(ref cur))
+ {
+ Debug.Assert(cur.IsFree, "Free list corrupted");
+
+ // We know this block is free, so touch size directly
+ var freeBlockSize = cur.SizeBytesRaw;
+
+ if (freeBlockSize < actualSizeBytes)
+ {
+ // Couldn't fit in the block, move on
+ cur = ref cur.GetNextFreeBlockRef();
+ continue;
+ }
+
+ if (ShouldSplit(ref cur, actualSizeBytes))
+ {
+ ref var newBlock = ref SplitFreeBlock(ref dataStartRef, ref cur, actualSizeBytes);
+ Debug.Assert(cur.IsFree && cur.SizeBytes == actualSizeBytes, "Split produced unexpected block");
+
+ freeBlockSize = actualSizeBytes;
+
+ // We know both these blocks are free, so use size directly
+ if (!Unsafe.AreSame(ref cur, ref freeList) && newBlock.SizeBytesRaw > freeList.SizeBytesRaw)
+ {
+ // Move the split block to the head of the list if it's large enough AND we had to traverse
+ // the free list at all
+ //
+ // Idea is to keep bigger things towards the front of the free list if we're spending any time
+ // chasing pointers
+ MoveToHeadOfFreeList(ref dataStartRef, ref freeList, ref newBlock);
+ }
+ }
+
+ // Cur will work, so remove it from the free list, mark it, and return it
+ RemoveFromFreeList(ref dataStartRef, ref cur);
+ cur.MarkInUse();
+
+ UpdateDebugAllocatedBytes(freeBlockSize);
+
+ failed = false;
+ return ref cur.DataReference;
+ }
+
+ // Expensively compact if we're going to fail anyway
+ if (firstAttempt)
+ {
+ firstAttempt = false;
+ if (TryCoalesceAllFreeBlocks())
+ {
+ goto tryAgain;
+ }
+ }
+
+ // Even after compaction we failed to find a large enough block
+ failed = true;
+ return ref Unsafe.NullRef();
+ }
+
+ ///
+ /// Return a chunk of memory previously acquired by or
+ /// .
+ ///
+ /// Previously returned (non-null) value.
+ /// Size passed to last or call.
+ public void Free(ref byte start, int sizeBytes)
+ {
+ ref var dataStartRef = ref GetDataStartRef();
+
+ ref var blockRef = ref GetBlockRef(ref dataStartRef, ref start);
+
+ Debug.Assert(blockRef.IsInUse, "Should be in use");
+
+ blockRef.MarkFree();
+ AddToFreeList(ref dataStartRef, ref blockRef);
+ Debug.Assert(blockRef.IsFree, "Should be free by this point");
+
+ // We know the block is free, so touch size directly
+ UpdateDebugAllocatedBytes(-blockRef.SizeBytesRaw);
+ }
+
+ ///
+ /// Akin to , except reuses the original allocation given in if possible.
+ ///
+ public ref byte ResizeAllocation(ref byte start, int oldSizeBytes, int newSizeBytes, out bool failed)
+ {
+ ref var dataStartRef = ref GetDataStartRef();
+
+ ref var curBlock = ref GetBlockRef(ref dataStartRef, ref start);
+ Debug.Assert(curBlock.IsInUse, "Shouldn't be resizing an allocation that isn't in use");
+
+ // For everything else, move things up to a reasonable multiple
+ var actualSizeBytes = RoundToMinAlloc(newSizeBytes);
+
+ // We know the block is in use, so use raw size directly
+ if (-curBlock.SizeBytesRaw >= newSizeBytes)
+ {
+ // Existing allocation is large enough
+
+ if (ShouldSplit(ref curBlock, actualSizeBytes))
+ {
+ // Move the unused space into a new block
+
+ SplitInUseBlock(ref dataStartRef, ref curBlock, actualSizeBytes);
+
+ // And try and coalesce that new block into the largest possible one
+ ref var nextBlock = ref curBlock.GetNextAdjacentBlockRef(ref dataStartRef, data.Length);
+ if (IsValidBlockRef(ref nextBlock))
+ {
+ while (TryCoalesceSingleBlock(ref dataStartRef, ref nextBlock, out _))
+ {
+ }
+ }
+ }
+
+ failed = false;
+ return ref start;
+ }
+
+ // Attempt to grow the allocation in place
+ var keepInPlace = false;
+
+ while (TryCoalesceSingleBlock(ref dataStartRef, ref curBlock, out var updatedCurBlockSizeBytes))
+ {
+ if (updatedCurBlockSizeBytes >= actualSizeBytes)
+ {
+ keepInPlace = true;
+ break;
+ }
+ }
+
+ if (keepInPlace)
+ {
+ // We built a big enough block, so use it
+ if (ShouldSplit(ref curBlock, actualSizeBytes))
+ {
+ // We coalesced such that there's a lot of empty space at the end of this block, peel it off for later reuse
+ ref var newBlock = ref SplitInUseBlock(ref dataStartRef, ref curBlock, actualSizeBytes);
+ Debug.Assert(curBlock.IsInUse && curBlock.SizeBytes == actualSizeBytes, "Split produced unexpected block");
+
+ ref var freeList = ref GetFreeList();
+
+ // We know freeList is valid and both blocks are free, so check size directly
+ if (!Unsafe.AreSame(ref newBlock, ref freeList) && newBlock.SizeBytesRaw > freeList.SizeBytesRaw)
+ {
+ MoveToHeadOfFreeList(ref dataStartRef, ref freeList, ref newBlock);
+ }
+ }
+
+ failed = false;
+ return ref start;
+ }
+
+ // We couldn't resize in place, so we need to copy into a new alloc
+ ref var newAlloc = ref AllocateNew(newSizeBytes, out failed);
+ if (failed)
+ {
+ // Couldn't get a new allocation - per spec this leaves the old allocation alone
+ return ref Unsafe.NullRef();
+ }
+
+ // Copy the data over
+ var copyLen = newSizeBytes < oldSizeBytes ? newSizeBytes : oldSizeBytes;
+
+ var copyFrom = MemoryMarshal.CreateReadOnlySpan(ref curBlock.DataReference, copyLen);
+ var copyInto = MemoryMarshal.CreateSpan(ref newAlloc, copyLen);
+
+ copyFrom.CopyTo(copyInto);
+
+ // Free the old alloc now that data is copied out
+ Free(ref start, oldSizeBytes);
+
+ return ref newAlloc;
+ }
+
+ ///
+ /// Returns true if this reference might have been handed out by this allocator.
+ ///
+ internal bool ContainsRef(ref byte startRef)
+ {
+ ref var dataStartRef = ref GetDataStartRef();
+
+ var delta = Unsafe.ByteOffset(ref dataStartRef, ref startRef);
+
+ return delta >= 0 && delta < data.Length;
+ }
+
+ ///
+ /// Validate the allocator.
+ ///
+ /// For testing purposes only.
+ ///
+ internal void CheckCorrectness()
+ {
+ ref var dataStartRef = ref GetDataStartRef();
+
+ // Check for cycles in free lists
+ {
+ // Basic tortoise and hare:
+ // - freeSlow moves forward 1 link at a time
+ // - freeFast moves forward 2 links at a time
+ // - if freeSlow == freeFast then we have a cycle
+
+ ref var freeSlow = ref GetFreeList();
+ ref var freeFast = ref freeSlow;
+ if (IsValidBlockRef(ref freeFast))
+ {
+ freeFast = ref freeFast.GetNextFreeBlockRef();
+ }
+
+ while (IsValidBlockRef(ref freeSlow) && IsValidBlockRef(ref freeFast))
+ {
+ freeSlow = ref freeSlow.GetNextFreeBlockRef();
+ freeFast = ref freeFast.GetNextFreeBlockRef();
+ if (IsValidBlockRef(ref freeFast))
+ {
+ freeFast = ref freeFast.GetNextFreeBlockRef();
+ }
+
+ Check(!Unsafe.AreSame(ref freeSlow, ref freeFast), "Cycle exists in free list");
+ }
+ }
+
+ var walkFreeBlocks = 0;
+ var walkFreeBytes = 0L;
+
+ // Walk the free list, counting and check that pointer make sense
+ {
+ ref var prevFree = ref Unsafe.NullRef();
+ ref var curFree = ref GetFreeList();
+ while (IsValidBlockRef(ref curFree))
+ {
+ Check(curFree.IsFree, "Allocated block in free list");
+ Check(ContainsRef(ref Unsafe.As(ref curFree)), "Free block not in managed bounds");
+
+ walkFreeBlocks++;
+
+ Check(curFree.SizeBytes > 0, "Illegal size for a free block");
+
+ walkFreeBytes += curFree.SizeBytes;
+
+ if (IsValidBlockRef(ref prevFree))
+ {
+ ref var prevFromCur = ref curFree.GetPrevFreeBlockRef();
+ Check(Unsafe.AreSame(ref prevFree, ref prevFromCur), "Prev link invalid");
+ }
+
+ prevFree = ref curFree;
+ curFree = ref curFree.GetNextFreeBlockRef();
+ }
+ }
+
+ var scanFreeBlocks = 0;
+ var scanFreeBytes = 0L;
+ var scanAllocatedBlocks = 0;
+ var scanAllocatedBytes = 0L;
+
+ // Scan the whole array, counting free and allocated blocks
+ {
+ ref var cur = ref Unsafe.As(ref dataStartRef);
+ while (IsValidBlockRef(ref cur))
+ {
+ Check(ContainsRef(ref Unsafe.As(ref cur)), "Free block not in managed bounds");
+
+ if (cur.IsFree)
+ {
+ scanFreeBlocks++;
+ scanFreeBytes += cur.SizeBytes;
+ }
+ else
+ {
+ Check(cur.IsInUse, "Illegal block state");
+
+ scanAllocatedBlocks++;
+ scanAllocatedBytes += cur.SizeBytes;
+ }
+
+ cur = ref cur.GetNextAdjacentBlockRef(ref dataStartRef, data.Length);
+ }
+ }
+
+ Check(scanFreeBlocks == walkFreeBlocks, "Free block mismatch");
+ Check(scanFreeBytes == walkFreeBytes, "Free bytes mismatch");
+
+ DebugCheck(scanAllocatedBytes == AllocatedBytes, "Allocated bytes mismatch");
+
+ var totalBlocks = scanAllocatedBlocks + scanFreeBlocks;
+ var totalBytes = scanAllocatedBytes + scanFreeBytes;
+
+ var expectedOverhead = totalBlocks * BlockHeader.DataOffset;
+
+ var allBytes = totalBytes + expectedOverhead;
+ Check(allBytes == data.Length, "Bytes unaccounted for");
+
+ // Throws if shouldBe is false
+ static void Check(bool shouldBe, string errorMsg, [CallerArgumentExpression(nameof(shouldBe))] string shouldBeExpr = null)
+ {
+ if (shouldBe)
+ {
+ return;
+ }
+
+ throw new InvalidOperationException($"Check failed: {errorMsg} ({shouldBeExpr})");
+ }
+
+ // Throws if shouldBe is false, but only in DEBUG builds
+ [Conditional("DEBUG")]
+ static void DebugCheck(bool shouldBe, string errorMsg, [CallerArgumentExpression(nameof(shouldBe))] string shouldBeExpr = null)
+ => Check(shouldBe, errorMsg, shouldBeExpr);
+ }
+
+ ///
+ /// Do a very expensive pass attempting to coalesce free blocks as much as possible.
+ ///
+ internal bool TryCoalesceAllFreeBlocks()
+ {
+ ref var dataStartRef = ref GetDataStartRef();
+
+ var madeProgress = false;
+ ref var cur = ref GetFreeList();
+ while (IsValidBlockRef(ref cur))
+ {
+ // Coalesce this block repeatedly, so runs of free blocks are collapsed into one
+ while (TryCoalesceSingleBlock(ref dataStartRef, ref cur, out _))
+ {
+ madeProgress = true;
+ }
+
+ cur = ref cur.GetNextFreeBlockRef();
+ }
+
+ return madeProgress;
+ }
+
+ ///
+ /// Add a the given free block to the free list.
+ ///
+ private void AddToFreeList(ref byte dataStartRef, ref BlockHeader block)
+ {
+ Debug.Assert(block.IsFree, "Can only add free blocks to list");
+
+ ref var freeListStartRef = ref GetFreeList();
+ if (!IsValidBlockRef(ref freeListStartRef))
+ {
+ // Free list is empty
+
+ block.PrevFreeBlockRefRaw = 0;
+ block.NextFreeBlockRefRaw = 0;
+
+ unsafe
+ {
+ freeListStartPtr = Unsafe.AsPointer(ref block);
+ }
+ }
+ else
+ {
+ Debug.Assert(freeListStartRef.IsFree, "Free list is corrupted");
+
+ // Free list isn't empty - block can go in either the head position
+ // or the one-past-the-head position depending on if block is bigger
+ // than the free list head
+
+ var freeListStartRefVal = freeListStartRef.GetRefVal();
+ var blockRefVal = block.GetRefVal();
+
+ // We know both blocks are free, so use size directly
+ if (block.SizeBytesRaw >= freeListStartRef.SizeBytesRaw)
+ {
+ // Block is larger, prefer allocating out of it
+
+ // Move block to head of list, and old head to immediately after block
+ freeListStartRef.PrevFreeBlockRefRaw = blockRefVal;
+ block.NextFreeBlockRefRaw = freeListStartRefVal;
+ block.PrevFreeBlockRefRaw = 0;
+
+ // Block is new head of free list
+ unsafe
+ {
+ freeListStartPtr = Unsafe.AsPointer(ref block);
+ }
+ }
+ else
+ {
+ // Block is smaller, it's a second choice for allocations
+
+ ref var oldFreeListNextBlock = ref freeListStartRef.GetNextFreeBlockRef();
+
+ // Move block to right after free list head
+ block.PrevFreeBlockRefRaw = freeListStartRefVal;
+ block.NextFreeBlockRefRaw = freeListStartRef.NextFreeBlockRef;
+ freeListStartRef.NextFreeBlockRefRaw = blockRefVal;
+
+ if (IsValidBlockRef(ref oldFreeListNextBlock))
+ {
+ // Update back-pointer in next block to point to block
+ oldFreeListNextBlock.PrevFreeBlockRefRaw = blockRefVal;
+ }
+
+ // Head of free list is unchanged
+ }
+ }
+ }
+
+ ///
+ /// Removes the given block from the free list.
+ ///
+ private void RemoveFromFreeList(ref byte dataStartRef, ref BlockHeader block)
+ {
+ Debug.Assert(IsValidBlockRef(ref GetFreeList()), "Shouldn't be removing from free list if free list is empty");
+ Debug.Assert(block.IsFree, "Only valid for free blocks");
+
+ var blockRefVal = block.GetRefVal();
+
+ ref var prevFreeBlock = ref block.GetPrevFreeBlockRef();
+ ref var nextFreeBlock = ref block.GetNextFreeBlockRef();
+
+ // We've got a few different states we could be in here:
+ // 1. block is the only thing in the free list
+ // => freeList == block, prevFreeBlock == null, nextFreeBlock == null
+ // 2. block is the first thing in the free list, but not the only thing
+ // => freeList == block, prevFreeBlock == null, nextFreeBlock != null, nextFreeBLock.prev == block
+ // 3. block is the last thing in the free list, but not the first
+ // => freeList is valid, freeListStart != block, prevFreeBlock != null, prevFreeBlock.next == block, nextFreeBlock == null
+ // 4. block is in the middle of the list somewhere
+ // => freeList is valid, freeListStart != block, prevFreeBlock != null, prefFreeBlock.next == block, nextFreeBlock != null, nextFreeBlock.prev = next
+
+ ref var freeList = ref GetFreeList();
+
+ if (Unsafe.AreSame(ref freeList, ref block))
+ {
+ Debug.Assert(!IsValidBlockRef(ref prevFreeBlock), "Should be no prev pointer if block is head of free list");
+
+ if (!IsValidBlockRef(ref nextFreeBlock))
+ {
+ // We're in state #1 - block is the only thing in the free list
+
+ // Remove this last block from the free list, leaving it empty
+ unsafe
+ {
+ // For simplicity, treat empty free lists the same way we do end of the free list
+ freeListStartPtr = (void*)0;
+ }
+ return;
+ }
+ else
+ {
+ // We're in state #2 - block is first thing in the free list, but not the last
+ Debug.Assert(nextFreeBlock.PrevFreeBlockRef == blockRefVal, "Broken chain in free list");
+
+ // NextFreeBlock is new head, it needs to prev point now
+ nextFreeBlock.PrevFreeBlockRefRaw = 0;
+ unsafe
+ {
+ freeListStartPtr = (void*)block.NextFreeBlockRef;
+ }
+ return;
+ }
+ }
+ else
+ {
+ Debug.Assert(IsValidBlockRef(ref prevFreeBlock), "Should always have a prev pointer if not head of free list");
+
+ if (!IsValidBlockRef(ref nextFreeBlock))
+ {
+ // We're in state #3 - block is last thing in the free list, but not the first
+ Debug.Assert(prevFreeBlock.NextFreeBlockRef == blockRefVal, "Broken chain in free list");
+
+ // Remove pointer to this block from it's preceeding block
+ prevFreeBlock.NextFreeBlockRefRaw = 0;
+ return;
+ }
+ else
+ {
+ // We're in state #4 - block is just in the middle of the free list somewhere, but not last or first
+ Debug.Assert(prevFreeBlock.NextFreeBlockRef == blockRefVal, "Broken chain in free list");
+ Debug.Assert(nextFreeBlock.PrevFreeBlockRef == blockRefVal, "Broken chain in free list");
+
+ // Prev needs to skip this block when going forward
+ prevFreeBlock.NextFreeBlockRefRaw = block.NextFreeBlockRef;
+
+ // Next needs to skip this block when going back
+ nextFreeBlock.PrevFreeBlockRefRaw = block.PrevFreeBlockRef;
+ return;
+ }
+ }
+ }
+
+ ///
+ /// Move this given block to the head of the free list.
+ ///
+ /// This assumes that is not already at the head of the free list, and is not
+ /// empty.
+ ///
+ private void MoveToHeadOfFreeList(ref byte dataStartRef, ref BlockHeader freeList, ref BlockHeader block)
+ {
+ Debug.Assert(block.IsFree, "Should be free");
+ Debug.Assert(freeList.IsFree, "Free list corrupted");
+ Debug.Assert(!IsValidBlockRef(ref freeList.GetPrevFreeBlockRef()), "Free list corrupted");
+ Debug.Assert(IsValidBlockRef(ref freeList.GetNextFreeBlockRef()), "Free list corrupted");
+
+ // Because block is not head of the free list, we have two cases here
+ // 1. Block is in the middle of the list somewhere
+ // 2. Block is the tail of the list
+
+ ref var prevBlock = ref block.GetPrevFreeBlockRef();
+ Debug.Assert(IsValidBlockRef(ref prevBlock), "Block shouldn't be head of free list");
+
+ ref var nextBlock = ref block.GetNextFreeBlockRef();
+ if (IsValidBlockRef(ref nextBlock))
+ {
+ // Case 1 - block is in the middle of the list
+
+ // Update prev.next so it points to block.next and next.prev so it points to block.prev
+ prevBlock.NextFreeBlockRefRaw = block.NextFreeBlockRef;
+ nextBlock.PrevFreeBlockRefRaw = block.PrevFreeBlockRef;
+ }
+ else
+ {
+ // Case 2 - block is at the end of the list
+
+ // Remove prev's pointer to block
+ prevBlock.NextFreeBlockRefRaw = 0;
+ }
+
+ // Move block to the head of the list
+ freeList.PrevFreeBlockRefRaw = block.GetRefVal();
+ block.NextFreeBlockRefRaw = freeList.GetRefVal();
+ block.PrevFreeBlockRefRaw = 0;
+
+ // Update freeListStartPtr to refer to block
+ unsafe
+ {
+ freeListStartPtr = Unsafe.AsPointer(ref block);
+ }
+ }
+
+ ///
+ /// Attempt to coalesce a block with its adjacent block.
+ ///
+ /// can be free or allocated, but coalescing will only succeed
+ /// if the adjacent block is free.
+ ///
+ /// is set only if the return is true.
+ ///
+ private bool TryCoalesceSingleBlock(ref byte dataStartRef, ref BlockHeader block, out int newBlockSizeBytes)
+ {
+ ref var nextBlock = ref block.GetNextAdjacentBlockRef(ref dataStartRef, data.Length);
+ if (!IsValidBlockRef(ref nextBlock) || !nextBlock.IsFree)
+ {
+ Unsafe.SkipInit(out newBlockSizeBytes);
+ return false;
+ }
+
+ // We know nextBlock is free, so touch size directly
+ var nextBlockSizeBytes = nextBlock.SizeBytesRaw;
+
+ RemoveFromFreeList(ref dataStartRef, ref nextBlock);
+
+ if (block.IsFree)
+ {
+ newBlockSizeBytes = nextBlockSizeBytes + block.SizeBytesRaw + BlockHeader.DataOffset;
+
+ block.SizeBytesRaw = newBlockSizeBytes;
+ }
+ else
+ {
+ // Because we merged a free block into an allocated one we need to update the allocated byte total
+
+ // We know the block is in use, so block.SizeBytesRaw is -SizeBytes
+ newBlockSizeBytes = nextBlockSizeBytes - block.SizeBytesRaw + BlockHeader.DataOffset;
+ UpdateDebugAllocatedBytes(block.SizeBytesRaw);
+
+ block.SizeBytesRaw = -newBlockSizeBytes;
+
+ UpdateDebugAllocatedBytes(newBlockSizeBytes);
+ }
+
+ return true;
+ }
+
+ ///
+ /// Split an in use block, such that the current block ends up with a size equal to .
+ ///
+ /// Returns a reference to the NEW free block.
+ ///
+ private ref BlockHeader SplitInUseBlock(ref byte dataStartRef, ref BlockHeader curBlock, int curBlockUpdateSizeBytes)
+ {
+ Debug.Assert(curBlock.IsInUse, "Only valid for in use blocks");
+
+ // We know the block is in use, so touch size bytes directly
+ var oldSizeBytes = -curBlock.SizeBytesRaw;
+
+ ref var newBlock = ref SplitCommon(ref curBlock, curBlockUpdateSizeBytes);
+
+ // New block needs to be placed in free list
+ AddToFreeList(ref dataStartRef, ref newBlock);
+
+ // Because we split some bytes out of an allocated block, that means we need to remove those from allocation tracking
+ UpdateDebugAllocatedBytes(-oldSizeBytes);
+
+ // We know the block is in use, so -SizeBytesRaw will be SizeBytes (unconditionally)
+ UpdateDebugAllocatedBytes(-curBlock.SizeBytesRaw);
+
+ return ref newBlock;
+ }
+
+ ///
+ /// Split a free block such that the current block ends up with a size equal to .
+ ///
+ /// Returns a reference to the NEW block.
+ ///
+ private ref BlockHeader SplitFreeBlock(ref byte dataStartRef, ref BlockHeader curBlock, int curBlockUpdateSizeBytes)
+ {
+ Debug.Assert(curBlock.IsFree, "Only valid for free blocks");
+
+ ref var oldNextBlock = ref curBlock.GetNextFreeBlockRef();
+
+ ref var newBlock = ref SplitCommon(ref curBlock, curBlockUpdateSizeBytes);
+
+ var curBlockRefVal = curBlock.GetRefVal();
+
+ // Update newBlock
+ newBlock.PrevFreeBlockRefRaw = curBlockRefVal;
+ newBlock.NextFreeBlockRefRaw = curBlock.NextFreeBlockRef;
+ Debug.Assert(newBlock.IsFree, "New block shoud be free");
+ Debug.Assert(ContainsRef(ref Unsafe.As(ref newBlock)), "New block out of managed memory");
+
+ var newBlockRefVal = newBlock.GetRefVal();
+
+ // Update curBlock
+ curBlock.NextFreeBlockRefRaw = newBlockRefVal;
+ Debug.Assert(curBlock.IsFree, "New block shoud be free");
+ Debug.Assert(ContainsRef(ref Unsafe.As(ref curBlock)), "Split block out of managed memory");
+
+ // Update the old next block if it exists
+ if (IsValidBlockRef(ref oldNextBlock))
+ {
+ Debug.Assert(oldNextBlock.IsFree, "Should have been in free list");
+ oldNextBlock.PrevFreeBlockRefRaw = newBlockRefVal;
+ }
+
+ return ref newBlock;
+ }
+
+ ///
+ /// Grab the start of the managed memory we're allocating out of.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private unsafe ref byte GetDataStartRef()
+ => ref Unsafe.AsRef(dataStartPtr);
+
+ ///
+ /// Get the start of the free list.
+ ///
+ /// If the free list is empty, returns a null ref.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private unsafe ref BlockHeader GetFreeList()
+ => ref Unsafe.AsRef(freeListStartPtr);
+
+ ///
+ /// Turn a reference obtained from back into a reference.
+ ///
+ private ref BlockHeader GetBlockRef(ref byte dataStartRef, ref byte userDataRef)
+ {
+ Debug.Assert(!Unsafe.AreSame(ref dataStartRef, ref userDataRef), "User data is actually 0 size alloc, that doesn't make sense");
+
+ ref var blockStartByteRef = ref Unsafe.Add(ref userDataRef, -BlockHeader.DataOffset);
+
+ Debug.Assert(!Unsafe.IsAddressLessThan(ref blockStartByteRef, ref dataStartRef), "User data is before managed memory");
+ Debug.Assert(!Unsafe.IsAddressGreaterThan(ref blockStartByteRef, ref Unsafe.Add(ref dataStartRef, data.Length - 1)), "User data is after managed memory");
+
+ return ref Unsafe.As(ref blockStartByteRef);
+ }
+
+ ///
+ /// In DEBUG builds, keep track of how many bytes we've allocated for testing purposes.
+ ///
+ [Conditional("DEBUG")]
+ private void UpdateDebugAllocatedBytes(int by)
+ => debugAllocatedBytes += by;
+
+ ///
+ /// Common logic for splitting blocks.
+ ///
+ /// Block here can either be free or in use, so don't make any assumptionsin here.
+ ///
+ private static ref BlockHeader SplitCommon(ref BlockHeader curBlock, int curBlockUpdateSizeBytes)
+ {
+ Debug.Assert(curBlockUpdateSizeBytes >= LuaAllocMinSizeBytes, "Shouldn't split an existing block to be this small");
+
+ ref var curBlockData = ref curBlock.DataReference;
+ ref var newBlockStartByteRef = ref Unsafe.AddByteOffset(ref curBlockData, curBlockUpdateSizeBytes);
+ ref var newBlock = ref Unsafe.As(ref newBlockStartByteRef);
+
+ int newBlockSizeBytes;
+ if (curBlock.IsFree)
+ {
+ // We know curBlock is free, so touch SizeBytesRaw directly
+ newBlockSizeBytes = curBlock.SizeBytesRaw - BlockHeader.DataOffset - curBlockUpdateSizeBytes;
+
+ curBlock.SizeBytesRaw = curBlockUpdateSizeBytes;
+ }
+ else
+ {
+ Debug.Assert(curBlock.IsInUse, "Invalid block state");
+
+ // We know curBlock is is inuse, so touch SizeBytesRaw directly
+ newBlockSizeBytes = -curBlock.SizeBytesRaw - BlockHeader.DataOffset - curBlockUpdateSizeBytes;
+
+ curBlock.SizeBytesRaw = -curBlockUpdateSizeBytes;
+ }
+
+ Debug.Assert(newBlockSizeBytes >= LuaAllocMinSizeBytes, "Shouldn't create a new block this small");
+
+ // The new block is always free, so positive size
+ newBlock.SizeBytesRaw = newBlockSizeBytes;
+
+
+ return ref newBlock;
+ }
+
+ ///
+ /// Check if a block should be split if it's used to serve a claim of the given size.
+ ///
+ private static bool ShouldSplit(ref BlockHeader block, int claimedBytes)
+ {
+ var unusedBytes = block.SizeBytes - claimedBytes;
+
+ return unusedBytes >= BlockHeader.StructSizeBytes;
+ }
+
+ ///
+ /// Turn requested bytes into the actual number of bytes we're going to reserve.
+ ///
+ private static int RoundToMinAlloc(int sizeBytes)
+ {
+ Debug.Assert(BitOperations.IsPow2(LuaAllocMinSizeBytes), "Assumes min allocation is a power of 2");
+
+ // To avoid special casing 0, we can reserve up to an extra LuaAllocMinSizeBytes
+ var ret = (sizeBytes + LuaAllocMinSizeBytes) & ~(LuaAllocMinSizeBytes - 1);
+
+ Debug.Assert(ret > 0, "Rounding logic invalid - not positive");
+ Debug.Assert(ret >= sizeBytes, "Rounding logic invalid - did not round up");
+ Debug.Assert(ret % LuaAllocMinSizeBytes == 0, "Rounding logic invalid - not a whole multiple of step size");
+
+ return ret;
+ }
+
+ ///
+ /// Check if is valid.
+ ///
+ /// Pulled out to indicate intent, it's basically just a null check
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static unsafe bool IsValidBlockRef(ref BlockHeader blockRef)
+ => Unsafe.AsPointer(ref blockRef) != (void*)0;
+ }
+}
\ No newline at end of file
diff --git a/libs/server/Lua/LuaManagedAllocator.cs b/libs/server/Lua/LuaManagedAllocator.cs
new file mode 100644
index 0000000000..72d777143e
--- /dev/null
+++ b/libs/server/Lua/LuaManagedAllocator.cs
@@ -0,0 +1,123 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Numerics;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+
+namespace Garnet.server
+{
+ ///
+ /// Provides a mapping of Lua allocations on to the POH.
+ ///
+ /// Unlike , allocations are not limited.
+ ///
+ ///
+ /// This is implemented in terms of because
+ /// it is expected to be ununusual to want allocations on the POH but unlimitted.
+ ///
+ internal sealed class LuaManagedAllocator : ILuaAllocator
+ {
+ private const int DefaultAllocationBytes = 2 * 1_024 * 1_024;
+
+ private readonly List subAllocators = [];
+ public LuaManagedAllocator() { }
+
+ ///
+ public ref byte AllocateNew(int sizeBytes, out bool failed)
+ {
+ for (var i = 0; i < subAllocators.Count; i++)
+ {
+ var alloc = subAllocators[i];
+
+ ref var ret = ref alloc.AllocateNew(sizeBytes, out failed);
+ if (!failed)
+ {
+ return ref ret;
+ }
+ }
+
+ var newAllocSize = DefaultAllocationBytes;
+
+ // Need to account for overhead in LuaLimitedManagedAllocator blocks
+ var minAllocSize = sizeBytes + sizeof(int);
+ if (newAllocSize < minAllocSize)
+ {
+ newAllocSize = (int)BitOperations.RoundUpToPowerOf2((nuint)minAllocSize);
+ if (newAllocSize < 0)
+ {
+ // Handle overflow
+ failed = true;
+ return ref Unsafe.NullRef();
+ }
+ }
+
+ var newAlloc = new LuaLimitedManagedAllocator(newAllocSize);
+ subAllocators.Add(newAlloc);
+
+ ref var newRet = ref newAlloc.AllocateNew(sizeBytes, out failed);
+ Debug.Assert(!failed, "Failure shouldn't be possible with new alloc");
+
+ return ref newRet;
+ }
+
+ ///
+ public void Free(ref byte start, int sizeBytes)
+ {
+ for (var i = 0; i < subAllocators.Count; i++)
+ {
+ var alloc = subAllocators[i];
+
+ if (alloc.ContainsRef(ref start))
+ {
+ alloc.Free(ref start, sizeBytes);
+
+ return;
+ }
+ }
+
+ throw new InvalidOperationException("Allocation could not be found");
+ }
+
+ ///
+ public ref byte ResizeAllocation(ref byte start, int oldSizeBytes, int newSizeBytes, out bool failed)
+ {
+ for (var i = 0; i < subAllocators.Count; i++)
+ {
+ var alloc = subAllocators[i];
+
+ if (alloc.ContainsRef(ref start))
+ {
+ ref var ret = ref alloc.ResizeAllocation(ref start, oldSizeBytes, newSizeBytes, out failed);
+ if (!failed)
+ {
+ return ref ret;
+ }
+
+ // Have to make a copy into a new allocation
+ ref var newAlloc = ref AllocateNew(newSizeBytes, out failed);
+ if (failed)
+ {
+ return ref Unsafe.NullRef();
+ }
+
+ var copyLen = newSizeBytes < oldSizeBytes ? newSizeBytes : oldSizeBytes;
+ var from = MemoryMarshal.CreateReadOnlySpan(ref start, copyLen);
+ var to = MemoryMarshal.CreateSpan(ref newAlloc, copyLen);
+ from.CopyTo(to);
+
+ // Release the old allocation
+ alloc.Free(ref start, oldSizeBytes);
+
+ failed = false;
+ return ref newAlloc;
+ }
+ }
+
+ throw new InvalidOperationException("Allocation could not be found");
+ }
+ }
+}
\ No newline at end of file
diff --git a/libs/server/Lua/LuaOptions.cs b/libs/server/Lua/LuaOptions.cs
new file mode 100644
index 0000000000..395d20ff38
--- /dev/null
+++ b/libs/server/Lua/LuaOptions.cs
@@ -0,0 +1,87 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+using System.Runtime.InteropServices;
+using Microsoft.Extensions.Logging;
+
+namespace Garnet.server
+{
+ ///
+ /// Options for Lua scripting.
+ ///
+ public sealed class LuaOptions
+ {
+ private readonly ILogger logger;
+
+ public LuaMemoryManagementMode MemoryManagementMode = LuaMemoryManagementMode.Native;
+ public string MemoryLimit = "";
+
+ ///
+ /// Construct options with default options.
+ ///
+ public LuaOptions(ILogger logger = null)
+ {
+ this.logger = logger;
+ }
+
+ ///
+ /// Construct options with specific settings.
+ ///
+ public LuaOptions(LuaMemoryManagementMode memoryMode, string memoryLimit, ILogger logger = null) : this(logger)
+ {
+ MemoryManagementMode = memoryMode;
+ MemoryLimit = memoryLimit;
+ }
+
+ ///
+ /// Get the memory limit, if any, for each script invocation.
+ ///
+ internal int? GetMemoryLimitBytes()
+ {
+ if (string.IsNullOrEmpty(MemoryLimit))
+ {
+ return null;
+ }
+
+ if (MemoryManagementMode == LuaMemoryManagementMode.Native)
+ {
+ logger?.LogWarning("Lua script memory limit is ignored when mode = {MemoryManagementMode}", MemoryManagementMode);
+ return null;
+ }
+
+ var ret = GarnetServerOptions.ParseSize(MemoryLimit);
+ if (ret is > int.MaxValue or < 1_024)
+ {
+ logger?.LogWarning("Lua script memory limit is out of range [1K, 2GB] = {MemoryLimit} and will be ignored", MemoryLimit);
+ return null;
+ }
+
+ return (int)ret;
+ }
+ }
+
+ ///
+ /// Different Lua supported memory modes.
+ ///
+ public enum LuaMemoryManagementMode
+ {
+ ///
+ /// Uses default Lua allocator - .NET host is unaware of allocations.
+ ///
+ Native = 0,
+
+ ///
+ /// Uses and informs .NET host of the allocations.
+ ///
+ /// Limits are inexactly applied due to native memory allocation overhead.
+ ///
+ Tracked = 1,
+
+ ///
+ /// Places allocations on the POH using a naive, free-list based, allocator.
+ ///
+ /// Limits are pre-allocated when scripts runs, which can increase allocation pressure.
+ ///
+ Managed = 2,
+ }
+}
\ No newline at end of file
diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs
index 38851209da..d9525df2da 100644
--- a/libs/server/Lua/LuaRunner.cs
+++ b/libs/server/Lua/LuaRunner.cs
@@ -74,7 +74,6 @@ private unsafe struct RunnerAdapter : IResponseAdapter
private readonly ScratchBufferManager bufferManager;
private byte* origin;
private byte* curHead;
- private byte* curEnd;
internal RunnerAdapter(ScratchBufferManager bufferManager)
{
@@ -84,7 +83,7 @@ internal RunnerAdapter(ScratchBufferManager bufferManager)
var scratchSpace = bufferManager.FullBuffer();
origin = curHead = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(scratchSpace));
- curEnd = curHead + scratchSpace.Length;
+ BufferEnd = curHead + scratchSpace.Length;
}
#pragma warning disable CS9084 // Struct member returns 'this' or other instance members by reference
@@ -94,8 +93,7 @@ public unsafe ref byte* BufferCur
#pragma warning restore CS9084
///
- public unsafe byte* BufferEnd
- => curEnd;
+ public unsafe byte* BufferEnd { get; private set; }
///
/// Gets a span that covers the responses as written so far.
@@ -118,28 +116,21 @@ public void SendAndReset()
var len = (int)(curHead - origin);
// We don't actually send anywhere, we grow the backing array
- bufferManager.GrowBuffer();
+ //
+ // Since we're managing the start/end pointers outside of the buffer
+ // we need to signal that the buffer has data to copy
+ bufferManager.GrowBuffer(copyLengthOverride: len);
var scratchSpace = bufferManager.FullBuffer();
origin = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(scratchSpace));
- curEnd = origin + scratchSpace.Length;
+ BufferEnd = origin + scratchSpace.Length;
curHead = origin + len;
}
}
const string LoaderBlock = @"
import = function () end
-redis = {}
-function redis.call(...)
- return garnet_call(...)
-end
-function redis.status_reply(text)
- return text
-end
-function redis.error_reply(text)
- return { err = 'ERR ' .. text }
-end
KEYS = {}
ARGV = {}
sandbox_env = {
@@ -163,7 +154,6 @@ function redis.error_reply(text)
rawequal = rawequal;
rawget = rawget;
rawset = rawset;
- redis = redis;
select = select;
-- explicitly not allowing setfenv
string = string;
@@ -180,64 +170,62 @@ function redis.error_reply(text)
}
-- do resets in the Lua side to minimize pinvokes
function reset_keys_and_argv(fromKey, fromArgv)
- local keyCount = #KEYS
+ local keyRef = KEYS
+ local keyCount = #keyRef
for i = fromKey, keyCount do
- KEYS[i] = nil
+ table.remove(keyRef)
end
- local argvCount = #ARGV
+ local argvRef = ARGV
+ local argvCount = #argvRef
for i = fromArgv, argvCount do
- ARGV[i] = nil
+ table.remove(argvRef)
end
end
-- responsible for sandboxing user provided code
function load_sandboxed(source)
- if (not source) then return nil end
- local rawFunc, err = load(source, nil, nil, sandbox_env)
+ -- move into a local to avoid global lookup
+ local garnetCallRef = garnet_call
- -- compilation error is returned directly
- if err then
- return rawFunc, err
- end
+ sandbox_env.redis = {
+ status_reply = function(text)
+ return text
+ end,
- -- otherwise we wrap the compiled function in a helper
- return function()
- local rawRet = rawFunc()
+ error_reply = function(text)
+ return { err = 'ERR ' .. text }
+ end,
- -- handle ok response wrappers without crossing the pinvoke boundary
- -- err response wrapper requires a bit more work, but is also rarer
- if rawRet and type(rawRet) == ""table"" and rawRet.ok then
- return rawRet.ok
+ call = function(...)
+ return garnetCallRef(...)
end
+ }
- return rawRet
- end
+ local rawFunc, err = load(source, nil, nil, sandbox_env)
+
+ return err, rawFunc
end
";
private static readonly ReadOnlyMemory LoaderBlockBytes = Encoding.UTF8.GetBytes(LoaderBlock);
- // Rooted to keep function pointer alive
- readonly LuaFunction garnetCall;
-
// References into Registry on the Lua side
//
// These are mix of objects we regularly update,
// constants we want to avoid copying from .NET to Lua,
// and the compiled function definition.
- readonly int sandboxEnvRegistryIndex;
readonly int keysTableRegistryIndex;
readonly int argvTableRegistryIndex;
readonly int loadSandboxedRegistryIndex;
readonly int resetKeysAndArgvRegistryIndex;
readonly int okConstStringRegistryIndex;
+ readonly int okLowerConstStringRegistryIndex;
readonly int errConstStringRegistryIndex;
readonly int noSessionAvailableConstStringRegistryIndex;
readonly int pleaseSpecifyRedisCallConstStringRegistryIndex;
readonly int errNoAuthConstStringRegistryIndex;
readonly int errUnknownConstStringRegistryIndex;
readonly int errBadArgConstStringRegistryIndex;
- int functionRegistryIndex;
readonly ReadOnlyMemory source;
readonly ScratchBufferNetworkSender scratchBufferNetworkSender;
@@ -248,41 +236,61 @@ function load_sandboxed(source)
readonly TxnKeyEntries txnKeyEntries;
readonly bool txnMode;
+ // The Lua registry index under which the user supplied function is stored post-compilation
+ int functionRegistryIndex;
+
// This cannot be readonly, as it is a mutable struct
LuaStateWrapper state;
+ // We need to temporarily store these for P/Invoke reasons
+ // You shouldn't be touching them outside of the Compile and Run methods
+
+ RunnerAdapter runnerAdapter;
+ RespResponseAdapter sessionAdapter;
+ RespServerSession preambleOuterSession;
+ int preambleKeyAndArgvCount;
+ int preambleNKeys;
+ string[] preambleKeys;
+ string[] preambleArgv;
+
int keyLength, argvLength;
///
/// Creates a new runner with the source of the script
///
- public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null)
+ public unsafe LuaRunner(LuaMemoryManagementMode memMode, int? memLimitBytes, ReadOnlyMemory source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null)
{
this.source = source;
this.txnMode = txnMode;
this.respServerSession = respServerSession;
this.scratchBufferNetworkSender = scratchBufferNetworkSender;
- this.scratchBufferManager = respServerSession?.scratchBufferManager ?? new();
this.logger = logger;
- sandboxEnvRegistryIndex = -1;
+ scratchBufferManager = respServerSession?.scratchBufferManager ?? new();
+
keysTableRegistryIndex = -1;
argvTableRegistryIndex = -1;
loadSandboxedRegistryIndex = -1;
functionRegistryIndex = -1;
- // TODO: custom allocator?
- state = new LuaStateWrapper(new Lua());
+ state = new LuaStateWrapper(memMode, memLimitBytes, this.logger);
+ delegate* unmanaged[Cdecl] garnetCall;
if (txnMode)
{
txnKeyEntries = new TxnKeyEntries(16, respServerSession.storageSession.lockableContext, respServerSession.storageSession.objectStoreLockableContext);
- garnetCall = garnet_call_txn;
+ garnetCall = &LuaRunnerTrampolines.GarnetCallWithTransaction;
}
else
{
- garnetCall = garnet_call;
+ garnetCall = &LuaRunnerTrampolines.GarnetCallNoTransaction;
+ }
+
+ if (respServerSession == null)
+ {
+ // During benchmarking and testing this can happen, so just redirect once instead of on each redis.call
+ garnetCall = &LuaRunnerTrampolines.GarnetCallNoSession;
}
var loadRes = state.LoadBuffer(LoaderBlockBytes.Span);
@@ -294,29 +302,46 @@ public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSe
var sandboxRes = state.PCall(0, -1);
if (sandboxRes != LuaStatus.OK)
{
- throw new GarnetException("Could not initialize Lua sandbox state");
+ string errMsg;
+ try
+ {
+ if (state.StackTop >= 1)
+ {
+ // We control the definition of LoaderBlock, so we know this is a string
+ state.KnownStringToBuffer(1, out var errSpan);
+ errMsg = Encoding.UTF8.GetString(errSpan);
+ }
+ else
+ {
+ errMsg = "No error provided";
+ }
+ }
+ catch
+ {
+ errMsg = "Error when fetching pcall error";
+ }
+
+ throw new GarnetException($"Could not initialize Lua sandbox state: {errMsg}");
}
// Register garnet_call in global namespace
- state.Register("garnet_call", garnetCall);
-
- state.GetGlobal(LuaType.Table, "sandbox_env");
- sandboxEnvRegistryIndex = state.Ref();
+ state.Register("garnet_call\0"u8, garnetCall);
- state.GetGlobal(LuaType.Table, "KEYS");
+ state.GetGlobal(LuaType.Table, "KEYS\0"u8);
keysTableRegistryIndex = state.Ref();
- state.GetGlobal(LuaType.Table, "ARGV");
+ state.GetGlobal(LuaType.Table, "ARGV\0"u8);
argvTableRegistryIndex = state.Ref();
- state.GetGlobal(LuaType.Function, "load_sandboxed");
+ state.GetGlobal(LuaType.Function, "load_sandboxed\0"u8);
loadSandboxedRegistryIndex = state.Ref();
- state.GetGlobal(LuaType.Function, "reset_keys_and_argv");
+ state.GetGlobal(LuaType.Function, "reset_keys_and_argv\0"u8);
resetKeysAndArgvRegistryIndex = state.Ref();
// Commonly used strings, register them once so we don't have to copy them over each time we need them
okConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_OK);
+ okLowerConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_ok);
errConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_err);
noSessionAvailableConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_No_session_available);
pleaseSpecifyRedisCallConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_ERR_Please_specify_at_least_one_argument_for_this_redis_lib_call);
@@ -330,8 +355,8 @@ public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSe
///
/// Creates a new runner with the source of the script
///
- public LuaRunner(string source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null)
- : this(Encoding.UTF8.GetBytes(source), txnMode, respServerSession, scratchBufferNetworkSender, logger)
+ public LuaRunner(LuaOptions options, string source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null)
+ : this(options.MemoryManagementMode, options.GetMemoryLimitBytes(), Encoding.UTF8.GetBytes(source), txnMode, respServerSession, scratchBufferNetworkSender, logger)
{
}
@@ -340,7 +365,7 @@ public LuaRunner(string source, bool txnMode = false, RespServerSession respServ
///
/// So instead we stash them in the Registry and load them by index
///
- int ConstantStringToRegistry(ReadOnlySpan str)
+ private int ConstantStringToRegistry(ReadOnlySpan str)
{
state.PushBuffer(str);
return state.Ref();
@@ -353,30 +378,86 @@ int ConstantStringToRegistry(ReadOnlySpan str)
///
public unsafe void CompileForRunner()
{
- var adapter = new RunnerAdapter(scratchBufferManager);
- CompileCommon(ref adapter);
-
- var resp = adapter.Response;
- var respStart = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(resp));
- var respEnd = respStart + resp.Length;
- if (RespReadUtils.TryReadErrorAsSpan(out var errSpan, ref respStart, respEnd))
+ runnerAdapter = new RunnerAdapter(scratchBufferManager);
+ try
+ {
+ LuaRunnerTrampolines.SetCallbackContext(this);
+ state.PushCFunction(&LuaRunnerTrampolines.CompileForRunner);
+ var res = state.PCall(0, 0);
+ if (res == LuaStatus.OK)
+ {
+ var resp = runnerAdapter.Response;
+ var respStart = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(resp));
+ var respEnd = respStart + resp.Length;
+ if (RespReadUtils.TryReadErrorAsSpan(out var errSpan, ref respStart, respEnd))
+ {
+ var errStr = Encoding.UTF8.GetString(errSpan);
+ throw new GarnetException(errStr);
+ }
+ }
+ else
+ {
+ throw new GarnetException($"Internal Lua Error: {res}");
+ }
+ }
+ finally
{
- var errStr = Encoding.UTF8.GetString(errSpan);
- throw new GarnetException(errStr);
+ runnerAdapter = default;
+ LuaRunnerTrampolines.ClearCallbackContext(this);
}
}
+ ///
+ /// Actually compiles for runner.
+ ///
+ /// If you call this directly and Lua encounters an error, the process will crash.
+ ///
+ /// Call instead.
+ ///
+ internal int UnsafeCompileForRunner()
+ => CompileCommon(ref runnerAdapter);
+
///
/// Compile script for a .
///
/// Any errors encountered are written out as Resp errors.
///
- public void CompileForSession(RespServerSession session)
+ public unsafe bool CompileForSession(RespServerSession session)
{
- var adapter = new RespResponseAdapter(session);
- CompileCommon(ref adapter);
+ sessionAdapter = new RespResponseAdapter(session);
+
+ try
+ {
+ LuaRunnerTrampolines.SetCallbackContext(this);
+ state.PushCFunction(&LuaRunnerTrampolines.CompileForSession);
+ var res = state.PCall(0, 0);
+ if (res != LuaStatus.OK)
+ {
+ while (!RespWriteUtils.WriteError("Internal Lua Error"u8, ref session.dcurr, session.dend))
+ session.SendAndReset();
+
+ return false;
+ }
+
+ return true;
+ }
+ finally
+ {
+ sessionAdapter = default;
+ LuaRunnerTrampolines.ClearCallbackContext(this);
+ }
}
+ ///
+ /// Actually compiles for runner.
+ ///
+ /// If you call this directly and Lua encounters an error, the process will crash.
+ ///
+ /// Call instead.
+ ///
+ internal int UnsafeCompileForSession()
+ => CompileCommon(ref sessionAdapter);
+
///
/// Drops compiled function, just for benchmarking purposes.
///
@@ -392,7 +473,7 @@ public void ResetCompilation()
///
/// Compile script, writing errors out to given response.
///
- unsafe void CompileCommon(ref TResponse resp)
+ private unsafe int CompileCommon(ref TResponse resp)
where TResponse : struct, IResponseAdapter
{
const int NeededStackSpace = 2;
@@ -401,105 +482,61 @@ unsafe void CompileCommon(ref TResponse resp)
state.ExpectLuaStackEmpty();
- try
- {
- state.ForceMinimumStackCapacity(NeededStackSpace);
-
- state.PushInteger(loadSandboxedRegistryIndex);
- _ = state.RawGet(LuaType.Function, (int)LuaRegistry.Index);
-
- state.PushBuffer(source.Span);
- state.Call(1, -1); // Multiple returns allowed
-
- var numRets = state.StackTop;
-
- if (numRets == 0)
- {
- while (!RespWriteUtils.WriteError("Shouldn't happen, no returns from load_sandboxed"u8, ref resp.BufferCur, resp.BufferEnd))
- resp.SendAndReset();
-
- return;
- }
- else if (numRets == 1)
- {
- var returnType = state.Type(1);
- if (returnType != LuaType.Function)
- {
- var errStr = $"Could not compile function, got back a {returnType}";
- while (!RespWriteUtils.WriteError(errStr, ref resp.BufferCur, resp.BufferEnd))
- resp.SendAndReset();
-
- return;
- }
+ state.ForceMinimumStackCapacity(NeededStackSpace);
- functionRegistryIndex = state.Ref();
- }
- else if (numRets == 2)
- {
- state.CheckBuffer(2, out var errorBuf);
+ _ = state.RawGetInteger(LuaType.Function, (int)LuaRegistry.Index, loadSandboxedRegistryIndex);
+ state.PushBuffer(source.Span);
+ state.Call(1, 2);
- var errStr = $"Compilation error: {Encoding.UTF8.GetString(errorBuf)}";
- while (!RespWriteUtils.WriteError(errStr, ref resp.BufferCur, resp.BufferEnd))
- resp.SendAndReset();
+ // Now the stack will have two things on it:
+ // 1. The error (nil if not error)
+ // 2. The function (nil if error)
- state.Pop(2);
+ if (state.Type(1) == LuaType.Nil)
+ {
+ // No error, success!
- return;
- }
- else
- {
- state.Pop(numRets);
+ Debug.Assert(state.Type(2) == LuaType.Function, "Unexpected type returned from load_sandboxed");
- throw new GarnetException($"Unexpected error compiling, got too many replies back: reply count = {numRets}");
- }
- }
- catch (Exception ex)
- {
- logger?.LogError(ex, "CreateFunction threw an exception");
- throw;
+ functionRegistryIndex = state.Ref();
}
- finally
+ else
{
- state.ExpectLuaStackEmpty();
+ // We control the definition of load_sandboxed, so we know this will be a string
+ state.KnownStringToBuffer(1, out var errorBuf);
+
+ var errStr = $"Compilation error: {Encoding.UTF8.GetString(errorBuf)}";
+ while (!RespWriteUtils.WriteError(errStr, ref resp.BufferCur, resp.BufferEnd))
+ resp.SendAndReset();
}
+
+ return 0;
}
///
/// Dispose the runner
///
public void Dispose()
- {
- state.Dispose();
- }
+ => state.Dispose();
///
/// Entry point for redis.call method from a Lua script (non-transactional mode)
///
- public int garnet_call(IntPtr luaStatePtr)
+ public int GarnetCall(nint luaStatePtr)
{
state.CallFromLuaEntered(luaStatePtr);
- if (respServerSession == null)
- {
- return NoSessionResponse();
- }
-
- return ProcessCommandFromScripting(respServerSession.basicGarnetApi);
+ return ProcessCommandFromScripting(ref respServerSession.basicGarnetApi);
}
///
/// Entry point for redis.call method from a Lua script (transactional mode)
///
- public int garnet_call_txn(IntPtr luaStatePtr)
+ public int GarnetCallWithTransaction(nint luaStatePtr)
{
state.CallFromLuaEntered(luaStatePtr);
- if (respServerSession == null)
- {
- return NoSessionResponse();
- }
-
- return ProcessCommandFromScripting(respServerSession.lockableGarnetApi);
+ return ProcessCommandFromScripting(ref respServerSession.lockableGarnetApi);
}
///
@@ -507,10 +544,12 @@ public int garnet_call_txn(IntPtr luaStatePtr)
///
/// This is used in benchmarking.
///
- int NoSessionResponse()
+ internal int NoSessionResponse(nint luaStatePtr)
{
const int NeededStackSpace = 1;
+ state.CallFromLuaEntered(luaStatePtr);
+
state.ForceMinimumStackCapacity(NeededStackSpace);
state.PushNil();
@@ -520,7 +559,7 @@ int NoSessionResponse()
///
/// Entry point method for executing commands from a Lua Script
///
- unsafe int ProcessCommandFromScripting(TGarnetApi api)
+ unsafe int ProcessCommandFromScripting(ref TGarnetApi api)
where TGarnetApi : IGarnetApi
{
const int AdditionalStackSpace = 1;
@@ -529,7 +568,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api)
{
var argCount = state.StackTop;
- if (argCount == 0)
+ if (argCount <= 0)
{
return LuaStaticError(pleaseSpecifyRedisCallConstStringRegistryIndex);
}
@@ -646,7 +685,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api)
///
/// Cause a Lua error to be raised with a message previously registered.
///
- int LuaStaticError(int constStringRegistryIndex)
+ private int LuaStaticError(int constStringRegistryIndex)
{
const int NeededStackSize = 1;
@@ -661,7 +700,7 @@ int LuaStaticError(int constStringRegistryIndex)
///
/// Pushes result onto state stack and returns 1, or raises an error and never returns.
///
- unsafe int ProcessResponse(byte* ptr, int length)
+ private unsafe int ProcessResponse(byte* ptr, int length)
{
const int NeededStackSize = 3;
@@ -774,26 +813,76 @@ unsafe int ProcessResponse(byte* ptr, int length)
///
/// Response is written directly into the .
///
- public void RunForSession(int count, RespServerSession outerSession)
+ public unsafe void RunForSession(int count, RespServerSession outerSession)
{
- const int NeededStackSize = 3;
+ const int NeededStackSize = 2;
state.ForceMinimumStackCapacity(NeededStackSize);
+ preambleOuterSession = outerSession;
+ preambleKeyAndArgvCount = count;
+ try
+ {
+ LuaRunnerTrampolines.SetCallbackContext(this);
+
+ try
+ {
+ state.PushCFunction(&LuaRunnerTrampolines.RunPreambleForSession);
+ var callRes = state.PCall(0, 0);
+ if (callRes != LuaStatus.OK)
+ {
+ while (!RespWriteUtils.WriteError("Internal Lua Error"u8, ref outerSession.dcurr, outerSession.dend))
+ outerSession.SendAndReset();
+
+ return;
+ }
+ }
+ finally
+ {
+ preambleOuterSession = null;
+ }
+
+ var adapter = new RespResponseAdapter(outerSession);
+
+ if (txnMode && preambleNKeys > 0)
+ {
+ RunInTransaction(ref adapter);
+ }
+ else
+ {
+ RunCommon(ref adapter);
+ }
+ }
+ finally
+ {
+ LuaRunnerTrampolines.ClearCallbackContext(this);
+ }
+ }
+
+ ///
+ /// Setups a script to be run.
+ ///
+ /// If you call this directly and Lua encounters an error, the process will crash.
+ ///
+ /// Call instead.
+ ///
+ internal int UnsafeRunPreambleForSession()
+ {
+ state.ExpectLuaStackEmpty();
+
scratchBufferManager.Reset();
- var parseState = outerSession.parseState;
+ ref var parseState = ref preambleOuterSession.parseState;
var offset = 1;
- var nKeys = parseState.GetInt(offset++);
- count--;
- ResetParameters(nKeys, count - nKeys);
+ var nKeys = preambleNKeys = parseState.GetInt(offset++);
+ preambleKeyAndArgvCount--;
+ ResetParameters(nKeys, preambleKeyAndArgvCount - nKeys);
if (nKeys > 0)
{
// Get KEYS on the stack
- state.PushInteger(keysTableRegistryIndex);
- state.RawGet(LuaType.Table, (int)LuaRegistry.Index);
+ _ = state.RawGetInteger(LuaType.Table, (int)LuaRegistry.Index, keysTableRegistryIndex);
for (var i = 0; i < nKeys; i++)
{
@@ -807,9 +896,8 @@ public void RunForSession(int count, RespServerSession outerSession)
}
// Equivalent to KEYS[i+1] = key
- state.PushInteger(i + 1);
state.PushBuffer(key.ReadOnlySpan);
- state.RawSet(1);
+ state.RawSetInteger(1, i + 1);
offset++;
}
@@ -817,23 +905,21 @@ public void RunForSession(int count, RespServerSession outerSession)
// Remove KEYS from the stack
state.Pop(1);
- count -= nKeys;
+ preambleKeyAndArgvCount -= nKeys;
}
- if (count > 0)
+ if (preambleKeyAndArgvCount > 0)
{
// Get ARGV on the stack
- state.PushInteger(argvTableRegistryIndex);
- state.RawGet(LuaType.Table, (int)LuaRegistry.Index);
+ _ = state.RawGetInteger(LuaType.Table, (int)LuaRegistry.Index, argvTableRegistryIndex);
- for (var i = 0; i < count; i++)
+ for (var i = 0; i < preambleKeyAndArgvCount; i++)
{
ref var argv = ref parseState.GetArgSliceByRef(offset);
// Equivalent to ARGV[i+1] = argv
- state.PushInteger(i + 1);
state.PushBuffer(argv.ReadOnlySpan);
- state.RawSet(1);
+ state.RawSetInteger(1, i + 1);
offset++;
}
@@ -842,16 +928,7 @@ public void RunForSession(int count, RespServerSession outerSession)
state.Pop(1);
}
- var adapter = new RespResponseAdapter(outerSession);
-
- if (txnMode && nKeys > 0)
- {
- RunInTransaction(ref adapter);
- }
- else
- {
- RunCommon(ref adapter);
- }
+ return 0;
}
///
@@ -861,44 +938,69 @@ public void RunForSession(int count, RespServerSession outerSession)
///
public unsafe object RunForRunner(string[] keys = null, string[] argv = null)
{
- scratchBufferManager?.Reset();
- LoadParametersForRunner(keys, argv);
+ const int NeededStackSize = 2;
- var adapter = new RunnerAdapter(scratchBufferManager);
+ state.ForceMinimumStackCapacity(NeededStackSize);
- if (txnMode && keys?.Length > 0)
+ try
{
- // Add keys to the transaction
- foreach (var key in keys)
+ LuaRunnerTrampolines.SetCallbackContext(this);
+
+ try
{
- var _key = scratchBufferManager.CreateArgSlice(key);
- txnKeyEntries.AddKey(_key, false, Tsavorite.core.LockType.Exclusive);
- if (!respServerSession.storageSession.objectStoreLockableContext.IsNull)
- txnKeyEntries.AddKey(_key, true, Tsavorite.core.LockType.Exclusive);
+ preambleKeys = keys;
+ preambleArgv = argv;
+
+ state.PushCFunction(&LuaRunnerTrampolines.RunPreambleForRunner);
+ state.Call(0, 0);
+ }
+ finally
+ {
+ preambleKeys = preambleArgv = null;
}
- RunInTransaction(ref adapter);
- }
- else
- {
- RunCommon(ref adapter);
- }
+ RunnerAdapter adapter;
+ if (txnMode && keys?.Length > 0)
+ {
+ // Add keys to the transaction
+ foreach (var key in keys)
+ {
+ var _key = scratchBufferManager.CreateArgSlice(key);
+ txnKeyEntries.AddKey(_key, false, Tsavorite.core.LockType.Exclusive);
+ if (!respServerSession.storageSession.objectStoreLockableContext.IsNull)
+ txnKeyEntries.AddKey(_key, true, Tsavorite.core.LockType.Exclusive);
+ }
+
+ adapter = new(scratchBufferManager);
+ RunInTransaction(ref adapter);
+ }
+ else
+ {
+ adapter = new(scratchBufferManager);
+ RunCommon(ref adapter);
+ }
- var resp = adapter.Response;
- var respCur = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(resp));
- var respEnd = respCur + resp.Length;
+ var resp = adapter.Response;
+ var respCur = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(resp));
+ var respEnd = respCur + resp.Length;
- if (RespReadUtils.TryReadErrorAsSpan(out var errSpan, ref respCur, respEnd))
- {
- var errStr = Encoding.UTF8.GetString(errSpan);
- throw new GarnetException(errStr);
- }
+ if (RespReadUtils.TryReadErrorAsSpan(out var errSpan, ref respCur, respEnd))
+ {
+ var errStr = Encoding.UTF8.GetString(errSpan);
+ throw new GarnetException(errStr);
+ }
- var ret = MapRespToObject(ref respCur, respEnd);
- Debug.Assert(respCur == respEnd, "Should have fully consumed response");
+ var ret = MapRespToObject(ref respCur, respEnd);
+ Debug.Assert(respCur == respEnd, "Should have fully consumed response");
- return ret;
+ return ret;
+ }
+ finally
+ {
+ LuaRunnerTrampolines.ClearCallbackContext(this);
+ }
+ // Convert a RESP response into an object to return
static object MapRespToObject(ref byte* cur, byte* end)
{
switch (*cur)
@@ -953,10 +1055,26 @@ static object MapRespToObject(ref byte* cur, byte* end)
}
}
+ ///
+ /// Setups a script to be run.
+ ///
+ /// If you call this directly and Lua encounters an error, the process will crash.
+ ///
+ /// Call instead.
+ ///
+ internal int UnsafeRunPreambleForRunner()
+ {
+ state.ExpectLuaStackEmpty();
+
+ scratchBufferManager?.Reset();
+
+ return LoadParametersForRunner(preambleKeys, preambleArgv);
+ }
+
///
/// Calls after setting up appropriate state for a transaction.
///
- void RunInTransaction(ref TResponse response)
+ private void RunInTransaction(ref TResponse response)
where TResponse : struct, IResponseAdapter
{
try
@@ -990,13 +1108,12 @@ internal void ResetParameters(int nKeys, int nArgs)
if (keyLength > nKeys || argvLength > nArgs)
{
- state.RawGetInteger(LuaType.Function, (int)LuaRegistry.Index, resetKeysAndArgvRegistryIndex);
+ _ = state.RawGetInteger(LuaType.Function, (int)LuaRegistry.Index, resetKeysAndArgvRegistryIndex);
state.PushInteger(nKeys + 1);
state.PushInteger(nArgs + 1);
- var resetRes = state.PCall(2, 0);
- Debug.Assert(resetRes == LuaStatus.OK, "Resetting should never fail");
+ state.Call(2, 0);
}
keyLength = nKeys;
@@ -1006,7 +1123,7 @@ internal void ResetParameters(int nKeys, int nArgs)
///
/// Takes .NET strings for keys and args and pushes them into KEYS and ARGV globals.
///
- void LoadParametersForRunner(string[] keys, string[] argv)
+ private int LoadParametersForRunner(string[] keys, string[] argv)
{
const int NeededStackSize = 2;
@@ -1017,8 +1134,7 @@ void LoadParametersForRunner(string[] keys, string[] argv)
if (keys != null)
{
// get KEYS on the stack
- state.PushInteger(keysTableRegistryIndex);
- _ = state.RawGet(LuaType.Table, (int)LuaRegistry.Index);
+ _ = state.RawGetInteger(LuaType.Table, (int)LuaRegistry.Index, keysTableRegistryIndex);
for (var i = 0; i < keys.Length; i++)
{
@@ -1035,8 +1151,7 @@ void LoadParametersForRunner(string[] keys, string[] argv)
if (argv != null)
{
// get ARGV on the stack
- state.PushInteger(argvTableRegistryIndex);
- _ = state.RawGet(LuaType.Table, (int)LuaRegistry.Index);
+ _ = state.RawGetInteger(LuaType.Table, (int)LuaRegistry.Index, argvTableRegistryIndex);
for (var i = 0; i < argv.Length; i++)
{
@@ -1050,6 +1165,9 @@ void LoadParametersForRunner(string[] keys, string[] argv)
state.Pop(1);
}
+ return 0;
+
+ // Convert string into a span, using buffer for storage
static void PrepareString(string raw, ScratchBufferManager buffer, out ReadOnlySpan strBytes)
{
var maxLen = Encoding.UTF8.GetMaxByteCount(raw.Length);
@@ -1066,7 +1184,7 @@ static void PrepareString(string raw, ScratchBufferManager buffer, out ReadOnlyS
///
/// Runs the precompiled Lua function.
///
- unsafe void RunCommon(ref TResponse resp)
+ private unsafe void RunCommon(ref TResponse resp)
where TResponse : struct, IResponseAdapter
{
const int NeededStackSize = 2;
@@ -1078,8 +1196,7 @@ unsafe void RunCommon(ref TResponse resp)
{
state.ForceMinimumStackCapacity(NeededStackSize);
- state.PushInteger(functionRegistryIndex);
- _ = state.RawGet(LuaType.Function, (int)LuaRegistry.Index);
+ _ = state.RawGetInteger(LuaType.Function, (int)LuaRegistry.Index, functionRegistryIndex);
var callRes = state.PCall(0, 1);
if (callRes == LuaStatus.OK)
@@ -1119,7 +1236,23 @@ unsafe void RunCommon(ref TResponse resp)
{
// Redis does not respect metatables, so RAW access is ok here
- // If the key err is in there, we need to short circuit
+ // If the key "ok" is in there, we need to short circuit
+ state.PushConstantString(okLowerConstStringRegistryIndex);
+ var okType = state.RawGet(null, 1);
+ if (okType == LuaType.String)
+ {
+ WriteString(this, ref resp);
+
+ // Remove table from stack
+ state.Pop(1);
+
+ return;
+ }
+
+ // Remove whatever we read from the table under the "ok" key
+ state.Pop(1);
+
+ // If the key "err" is in there, we need to short circuit
state.PushConstantString(errConstStringRegistryIndex);
var errType = state.RawGet(null, 1);
@@ -1153,11 +1286,29 @@ unsafe void RunCommon(ref TResponse resp)
}
else if (state.StackTop == 1)
{
- if (state.CheckBuffer(1, out var errBuf))
+ // PCall will put error in a string
+ state.KnownStringToBuffer(1, out var errBuf);
+
+ if (errBuf.Length >= 4 && MemoryMarshal.Read("ERR "u8) == Unsafe.As(ref MemoryMarshal.GetReference(errBuf)))
{
+ // Response came back with a ERR, already - just pass it along
while (!RespWriteUtils.WriteError(errBuf, ref resp.BufferCur, resp.BufferEnd))
resp.SendAndReset();
}
+ else
+ {
+ // Otherwise, this is probably a Lua error - and those aren't very descriptive
+ // So slap some more information in
+
+ while (!RespWriteUtils.WriteDirect("-ERR Lua encountered an error: "u8, ref resp.BufferCur, resp.BufferEnd))
+ resp.SendAndReset();
+
+ while (!RespWriteUtils.WriteDirect(errBuf, ref resp.BufferCur, resp.BufferEnd))
+ resp.SendAndReset();
+
+ while (!RespWriteUtils.WriteDirect("\r\n"u8, ref resp.BufferCur, resp.BufferEnd))
+ resp.SendAndReset();
+ }
state.Pop(1);
@@ -1280,7 +1431,7 @@ static void WriteArray(LuaRunner runner, ref TResponse resp)
}
}
- while (!RespWriteUtils.WriteArrayLength((int)trueLen, ref resp.BufferCur, resp.BufferEnd))
+ while (!RespWriteUtils.WriteArrayLength(trueLen, ref resp.BufferCur, resp.BufferEnd))
resp.SendAndReset();
for (var i = 1; i <= trueLen; i++)
@@ -1316,4 +1467,107 @@ static void WriteArray(LuaRunner runner, ref TResponse resp)
}
}
}
+
+ ///
+ /// Holds static functions for Lua-to-.NET interop.
+ ///
+ /// We annotate these as "unmanaged callers only" as a micro-optimization.
+ /// See: https://devblogs.microsoft.com/dotnet/improvements-in-native-code-interop-in-net-5-0/#unmanagedcallersonly
+ ///
+ internal static class LuaRunnerTrampolines
+ {
+ [ThreadStatic]
+ private static LuaRunner callbackContext;
+
+ ///
+ /// Set a that will be available in trampolines.
+ ///
+ /// This assumes the same thread is used to call into Lua.
+ ///
+ /// Call when finished to avoid extending
+ /// the lifetime of the .
+ ///
+ internal static void SetCallbackContext(LuaRunner context)
+ {
+ Debug.Assert(callbackContext == null, "Expected null context");
+ callbackContext = context;
+ }
+
+ ///
+ /// Clear a previously set
+ ///
+ internal static void ClearCallbackContext(LuaRunner context)
+ {
+ Debug.Assert(ReferenceEquals(callbackContext, context), "Expected context to match");
+ callbackContext = null;
+ }
+
+ ///
+ /// Entry point for Lua PCall'ing into .
+ ///
+ /// We need this indirection to allow Lua to detect and report internal and memory errors
+ /// without crashing the process.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int CompileForRunner(nint _)
+ => callbackContext.UnsafeCompileForRunner();
+
+ ///
+ /// Entry point for Lua PCall'ing into .
+ ///
+ /// We need this indirection to allow Lua to detect and report internal and memory errors
+ /// without crashing the process.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int CompileForSession(nint _)
+ => callbackContext.UnsafeCompileForSession();
+
+ ///
+ /// Entry point for Lua PCall'ing into .
+ ///
+ /// We need this indirection to allow Lua to detect and report internal and memory errors
+ /// without crashing the process.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int RunPreambleForRunner(nint _)
+ => callbackContext.UnsafeRunPreambleForRunner();
+
+ ///
+ /// Entry point for Lua PCall'ing into .
+ ///
+ /// We need this indirection to allow Lua to detect and report internal and memory errors
+ /// without crashing the process.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int RunPreambleForSession(nint _)
+ => callbackContext.UnsafeRunPreambleForSession();
+
+ ///
+ /// Entry point for Lua calling back into Garnet via redis.call(...).
+ ///
+ /// This entry point is for when there isn't an active .
+ /// This should only happen during testing.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int GarnetCallNoSession(nint luaState)
+ => callbackContext.NoSessionResponse(luaState);
+
+ ///
+ /// Entry point for Lua calling back into Garnet via redis.call(...).
+ ///
+ /// This entry point is for when a transaction is in effect.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int GarnetCallWithTransaction(nint luaState)
+ => callbackContext.GarnetCallWithTransaction(luaState);
+
+ ///
+ /// Entry point for Lua calling back into Garnet via redis.call(...).
+ ///
+ /// This entry point is for when a transaction is not necessary.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int GarnetCallNoTransaction(nint luaState)
+ => callbackContext.GarnetCall(luaState);
+ }
}
\ No newline at end of file
diff --git a/libs/server/Lua/LuaStateWrapper.cs b/libs/server/Lua/LuaStateWrapper.cs
index 51ac6b0e0a..136507d15c 100644
--- a/libs/server/Lua/LuaStateWrapper.cs
+++ b/libs/server/Lua/LuaStateWrapper.cs
@@ -4,9 +4,11 @@
using System;
using System.Diagnostics;
using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
using System.Text;
using Garnet.common;
using KeraLua;
+using Microsoft.Extensions.Logging;
namespace Garnet.server
{
@@ -20,13 +22,49 @@ internal struct LuaStateWrapper : IDisposable
{
private const int LUA_MINSTACK = 20;
- private readonly Lua state;
+ private GCHandle customAllocatorHandle;
+ private nint state;
private int curStackSize;
- internal LuaStateWrapper(Lua state)
+ ///
+ /// Current top item in the stack.
+ ///
+ /// 0 implies the stack is empty.
+ ///
+ internal int StackTop { get; private set; }
+
+ internal unsafe LuaStateWrapper(LuaMemoryManagementMode memMode, int? memLimitBytes, ILogger logger)
{
- this.state = state;
+ // As an emergency thing, we need a logger (any logger) if Lua is going to panic
+ // TODO: Consider a better way to do this?
+ LuaStateWrapperTrampolines.PanicLogger ??= logger;
+
+ ILuaAllocator customAllocator =
+ (memMode, memLimitBytes) switch
+ {
+ (LuaMemoryManagementMode.Native, null) => null,
+ (LuaMemoryManagementMode.Tracked, _) => new LuaTrackedAllocator(memLimitBytes),
+ (LuaMemoryManagementMode.Managed, null) => new LuaManagedAllocator(),
+ (LuaMemoryManagementMode.Managed, _) => new LuaLimitedManagedAllocator(memLimitBytes.Value),
+ _ => throw new InvalidOperationException($"Unexpected mode/limit combination: {memMode}/{memLimitBytes}")
+ };
+
+ if (customAllocator != null)
+ {
+ customAllocatorHandle = GCHandle.Alloc(customAllocator, GCHandleType.Normal);
+ var stateUserData = (nint)customAllocatorHandle;
+
+ state = NativeMethods.NewState(&LuaStateWrapperTrampolines.LuaAllocateBytes, stateUserData);
+ }
+ else
+ {
+ state = NativeMethods.NewState();
+ }
+
+ NativeMethods.OpenLibs(state);
+
+ _ = NativeMethods.AtPanic(state, &LuaStateWrapperTrampolines.LuaAtPanic);
curStackSize = LUA_MINSTACK;
StackTop = 0;
@@ -35,17 +73,20 @@ internal LuaStateWrapper(Lua state)
}
///
- public readonly void Dispose()
+ public void Dispose()
{
- state.Dispose();
- }
+ if (state != 0)
+ {
+ NativeMethods.Close(state);
+ state = 0;
+ }
- ///
- /// Current top item in the stack.
- ///
- /// 0 implies the stack is empty.
- ///
- internal int StackTop { get; private set; }
+ if (customAllocatorHandle.IsAllocated)
+ {
+ customAllocatorHandle.Free();
+ customAllocatorHandle = default;
+ }
+ }
///
/// Call when ambient state indicates that the Lua stack is in fact empty.
@@ -77,7 +118,7 @@ internal void ForceMinimumStackCapacity(int additionalCapacity)
}
var needed = additionalCapacity - availableSpace;
- if (!state.CheckStack(needed))
+ if (!NativeMethods.CheckStack(state, needed))
{
throw new GarnetException("Could not reserve additional capacity on the Lua stack");
}
@@ -93,9 +134,9 @@ internal void ForceMinimumStackCapacity(int additionalCapacity)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal void CallFromLuaEntered(IntPtr luaStatePtr)
{
- Debug.Assert(luaStatePtr == state.Handle, "Unexpected Lua state presented");
+ Debug.Assert(luaStatePtr == state, "Unexpected Lua state presented");
- StackTop = NativeMethods.GetTop(state.Handle);
+ StackTop = NativeMethods.GetTop(state);
curStackSize = StackTop > LUA_MINSTACK ? StackTop : LUA_MINSTACK;
}
@@ -107,7 +148,7 @@ internal readonly bool CheckBuffer(int index, out ReadOnlySpan str)
{
AssertLuaStackIndexInBounds(index);
- return NativeMethods.CheckBuffer(state.Handle, index, out str);
+ return NativeMethods.CheckBuffer(state, index, out str);
}
///
@@ -118,7 +159,7 @@ internal readonly LuaType Type(int stackIndex)
{
AssertLuaStackIndexInBounds(stackIndex);
- return NativeMethods.Type(state.Handle, stackIndex);
+ return NativeMethods.Type(state, stackIndex);
}
///
@@ -131,7 +172,7 @@ internal void PushBuffer(ReadOnlySpan buffer)
{
AssertLuaStackNotFull();
- NativeMethods.PushBuffer(state.Handle, buffer);
+ _ = ref NativeMethods.PushBuffer(state, buffer);
UpdateStackTop(1);
}
@@ -143,7 +184,7 @@ internal void PushNil()
{
AssertLuaStackNotFull();
- NativeMethods.PushNil(state.Handle);
+ NativeMethods.PushNil(state);
UpdateStackTop(1);
}
@@ -155,7 +196,7 @@ internal void PushInteger(long number)
{
AssertLuaStackNotFull();
- NativeMethods.PushInteger(state.Handle, number);
+ NativeMethods.PushInteger(state, number);
UpdateStackTop(1);
}
@@ -168,7 +209,7 @@ internal void PushBoolean(bool b)
{
AssertLuaStackNotFull();
- NativeMethods.PushBoolean(state.Handle, b);
+ NativeMethods.PushBoolean(state, b);
UpdateStackTop(1);
}
@@ -180,26 +221,27 @@ internal void PushBoolean(bool b)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal void Pop(int num)
{
- NativeMethods.Pop(state.Handle, num);
+ NativeMethods.Pop(state, num);
UpdateStackTop(-num);
}
///
- /// This should be used for all Calls into Lua.
+ /// This should be used for all PCalls into Lua.
///
/// Maintains and to minimize p/invoke calls.
///
[MethodImpl(MethodImplOptions.AggressiveInlining)]
- internal void Call(int args, int rets)
+ internal LuaStatus PCall(int args, int rets)
{
- // We have to copy this off, as once we Call curStackTop could be modified
+ // We have to copy this off, as once we PCall curStackTop could be modified
var oldStackTop = StackTop;
- state.Call(args, rets);
- if (rets < 0)
+ var res = NativeMethods.PCall(state, args, rets);
+
+ if (res != LuaStatus.OK || rets < 0)
{
- StackTop = NativeMethods.GetTop(state.Handle);
+ StackTop = NativeMethods.GetTop(state);
AssertLuaStackExpected();
}
else
@@ -208,23 +250,26 @@ internal void Call(int args, int rets)
var update = newPosition - StackTop;
UpdateStackTop(update);
}
+
+ return res;
}
///
- /// This should be used for all PCalls into Lua.
+ /// This should be used for all Calls into Lua.
///
/// Maintains and to minimize p/invoke calls.
///
[MethodImpl(MethodImplOptions.AggressiveInlining)]
- internal LuaStatus PCall(int args, int rets)
+ internal void Call(int args, int rets)
{
- // We have to copy this off, as once we Call curStackTop could be modified
+ // We have to copy this off, as once we PCall curStackTop could be modified
var oldStackTop = StackTop;
- var res = state.PCall(args, rets, 0);
- if (res != LuaStatus.OK || rets < 0)
+ NativeMethods.Call(state, args, rets);
+
+ if (rets < 0)
{
- StackTop = NativeMethods.GetTop(state.Handle);
+ StackTop = NativeMethods.GetTop(state);
AssertLuaStackExpected();
}
else
@@ -233,8 +278,6 @@ internal LuaStatus PCall(int args, int rets)
var update = newPosition - StackTop;
UpdateStackTop(update);
}
-
- return res;
}
///
@@ -247,24 +290,10 @@ internal void RawSetInteger(int stackIndex, int tableIndex)
{
AssertLuaStackIndexInBounds(stackIndex);
- state.RawSetInteger(stackIndex, tableIndex);
+ NativeMethods.RawSetInteger(state, stackIndex, tableIndex);
UpdateStackTop(-1);
}
- ///
- /// This should be used for all RawSets into Lua.
- ///
- /// Maintains and to minimize p/invoke calls.
- ///
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- internal void RawSet(int stackIndex)
- {
- AssertLuaStackIndexInBounds(stackIndex);
-
- state.RawSet(stackIndex);
- UpdateStackTop(-2);
- }
-
///
/// This should be used for all RawGetIntegers into Lua.
///
@@ -276,7 +305,7 @@ internal LuaType RawGetInteger(LuaType? expectedType, int stackIndex, int tableI
AssertLuaStackIndexInBounds(stackIndex);
AssertLuaStackNotFull();
- var actual = state.RawGetInteger(stackIndex, tableIndex);
+ var actual = NativeMethods.RawGetInteger(state, stackIndex, tableIndex);
Debug.Assert(expectedType == null || actual == expectedType, "Unexpected type received");
UpdateStackTop(1);
@@ -294,7 +323,7 @@ internal readonly LuaType RawGet(LuaType? expectedType, int stackIndex)
{
AssertLuaStackIndexInBounds(stackIndex);
- var actual = state.RawGet(stackIndex);
+ var actual = NativeMethods.RawGet(state, stackIndex);
Debug.Assert(expectedType == null || actual == expectedType, "Unexpected type received");
AssertLuaStackExpected();
@@ -310,7 +339,7 @@ internal readonly LuaType RawGet(LuaType? expectedType, int stackIndex)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal int Ref()
{
- var ret = state.Ref(LuaRegistry.Index);
+ var ret = NativeMethods.Ref(state, (int)LuaRegistry.Index);
UpdateStackTop(-1);
return ret;
@@ -323,9 +352,7 @@ internal int Ref()
///
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal readonly void Unref(LuaRegistry registry, int reference)
- {
- state.Unref(registry, reference);
- }
+ => NativeMethods.Unref(state, (int)registry, reference);
///
/// This should be used for all CreateTables into Lua.
@@ -337,7 +364,7 @@ internal void CreateTable(int numArr, int numRec)
{
AssertLuaStackNotFull();
- state.CreateTable(numArr, numRec);
+ NativeMethods.CreateTable(state, numArr, numRec);
UpdateStackTop(1);
}
@@ -346,11 +373,11 @@ internal void CreateTable(int numArr, int numRec)
///
/// Maintains and to minimize p/invoke calls.
///
- internal void GetGlobal(LuaType expectedType, string globalName)
+ internal void GetGlobal(LuaType expectedType, ReadOnlySpan nullTerminatedGlobalName)
{
AssertLuaStackNotFull();
- var type = state.GetGlobal(globalName);
+ var type = NativeMethods.GetGlobal(state, nullTerminatedGlobalName);
Debug.Assert(type == expectedType, "Unexpected type received");
UpdateStackTop(1);
@@ -368,7 +395,7 @@ internal LuaStatus LoadBuffer(ReadOnlySpan buffer)
{
AssertLuaStackNotFull();
- var ret = NativeMethods.LoadBuffer(state.Handle, buffer);
+ var ret = NativeMethods.LoadBuffer(state, buffer);
UpdateStackTop(1);
@@ -389,9 +416,9 @@ internal readonly void KnownStringToBuffer(int stackIndex, out ReadOnlySpan
@@ -402,7 +429,7 @@ internal readonly double CheckNumber(int stackIndex)
{
AssertLuaStackIndexInBounds(stackIndex);
- return state.CheckNumber(stackIndex);
+ return NativeMethods.CheckNumber(state, stackIndex);
}
///
@@ -413,7 +440,7 @@ internal readonly bool ToBoolean(int stackIndex)
{
AssertLuaStackIndexInBounds(stackIndex);
- return NativeMethods.ToBoolean(state.Handle, stackIndex);
+ return NativeMethods.ToBoolean(state, stackIndex);
}
///
@@ -424,14 +451,29 @@ internal readonly long RawLen(int stackIndex)
{
AssertLuaStackIndexInBounds(stackIndex);
- return state.RawLen(stackIndex);
+ return NativeMethods.RawLen(state, stackIndex);
+ }
+
+ ///
+ /// This should be used for all PushCFunctions into Lua.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ internal unsafe void PushCFunction(delegate* unmanaged[Cdecl] function)
+ {
+ NativeMethods.PushCFunction(state, (nint)function);
+ UpdateStackTop(1);
}
///
/// Call to register a function in the Lua global namespace.
///
- internal readonly void Register(string name, LuaFunction func)
- => state.Register(name, func);
+ internal unsafe void Register(ReadOnlySpan nullTerminatedName, delegate* unmanaged[Cdecl] function)
+ {
+ PushCFunction(function);
+
+ NativeMethods.SetGlobal(state, nullTerminatedName);
+ UpdateStackTop(-1);
+ }
///
/// This should be used to push all known constants strings into Lua.
@@ -439,7 +481,7 @@ internal readonly void Register(string name, LuaFunction func)
/// This avoids extra copying of data between .NET and Lua.
///
[MethodImpl(MethodImplOptions.AggressiveInlining)]
- internal void PushConstantString(int constStringRegistryIndex, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1)
+ internal void PushConstantString(int constStringRegistryIndex)
=> RawGetInteger(LuaType.String, (int)LuaRegistry.Index, constStringRegistryIndex);
// Rarely used
@@ -449,7 +491,7 @@ internal void PushConstantString(int constStringRegistryIndex, [CallerFilePath]
///
internal void ClearStack()
{
- state.SetTop(0);
+ NativeMethods.SetTop(state, 0);
StackTop = 0;
AssertLuaStackExpected();
@@ -463,6 +505,7 @@ internal int RaiseError(string msg)
ClearStack();
var b = Encoding.UTF8.GetBytes(msg);
+ PushBuffer(b);
return RaiseErrorFromStack();
}
@@ -473,7 +516,7 @@ internal readonly int RaiseErrorFromStack()
{
Debug.Assert(StackTop != 0, "Expected error message on the stack");
- return state.Error();
+ return NativeMethods.Error(state);
}
///
@@ -492,7 +535,6 @@ private void UpdateStackTop(int by)
/// Check that the given index refers to a valid part of the stack.
///
[Conditional("DEBUG")]
- [MethodImpl(MethodImplOptions.NoInlining)]
private readonly void AssertLuaStackIndexInBounds(int stackIndex)
{
Debug.Assert(stackIndex == (int)LuaRegistry.Index || (stackIndex > 0 && stackIndex <= StackTop), "Lua stack index out of bounds");
@@ -502,20 +544,133 @@ private readonly void AssertLuaStackIndexInBounds(int stackIndex)
/// Check that the Lua stack top is where expected in DEBUG builds.
///
[Conditional("DEBUG")]
- [MethodImpl(MethodImplOptions.NoInlining)]
private readonly void AssertLuaStackExpected()
{
- Debug.Assert(NativeMethods.GetTop(state.Handle) == StackTop, "Lua stack not where expected");
+ Debug.Assert(NativeMethods.GetTop(state) == StackTop, "Lua stack not where expected");
}
///
/// Check that there's space to push some number of elements.
///
[Conditional("DEBUG")]
- [MethodImpl(MethodImplOptions.NoInlining)]
private readonly void AssertLuaStackNotFull(int probe = 1)
{
Debug.Assert((StackTop + probe) <= curStackSize, "Lua stack should have been grown before pushing");
}
}
+
+ ///
+ /// Holds static functions for Lua-to-.NET interop.
+ ///
+ /// We annotate these as "unmanaged callers only" as a micro-optimization.
+ /// See: https://devblogs.microsoft.com/dotnet/improvements-in-native-code-interop-in-net-5-0/#unmanagedcallersonly
+ ///
+ internal static class LuaStateWrapperTrampolines
+ {
+ ///
+ /// Controls the singular logger that will be used for panic invocations.
+ ///
+ /// Because Lua is panic'ing, we're about to crash. This being process wide is hacky,
+ /// but we're in hacky situtation.
+ ///
+ internal static ILogger PanicLogger { get; set; }
+
+ ///
+ /// Called when Lua encounters an unrecoverable error.
+ ///
+ /// When this returns, Lua is going to terminate the process, so plan accordingly.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int LuaAtPanic(nint luaState)
+ {
+ // See: https://www.lua.org/manual/5.4/manual.html#4.4
+
+ string errorMsg;
+
+ var stackTop = NativeMethods.GetTop(luaState);
+ if (stackTop >= 1)
+ {
+ if (NativeMethods.CheckBuffer(luaState, stackTop, out var msgBuf))
+ {
+ errorMsg = Encoding.UTF8.GetString(msgBuf);
+ }
+ else
+ {
+ var type = NativeMethods.Type(luaState, stackTop);
+
+ errorMsg = $"Unexpected error type: {type}";
+ }
+ }
+ else
+ {
+ errorMsg = "No error on stack";
+ }
+
+ PanicLogger?.LogCritical("Lua Panic '{errorMsg}', stack size {stackTop}", errorMsg, stackTop);
+
+ return 0;
+ }
+
+ ///
+ /// Provides data for Lua allocations.
+ ///
+ /// All returns data must be PINNED or unmanaged, Lua does not allow it to move.
+ ///
+ /// If allocation cannot be performed, null is returned.
+ ///
+ /// Pointer to user data provided during
+ /// Either null (if new alloc) or pointer to existing allocation being resized or freed.
+ /// If is not null, the value passed when allocation was obtained or resized.
+ /// The desired size of the allocation, in bytes.
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static unsafe nint LuaAllocateBytes(nint udPtr, nint ptr, nuint osize, nuint nsize)
+ {
+ // See: https://www.lua.org/manual/5.4/manual.html#lua_Alloc
+
+ var handle = (GCHandle)udPtr;
+ Debug.Assert(handle.IsAllocated, "GCHandle should always be valid");
+
+ var customAllocator = (ILuaAllocator)handle.Target;
+
+ if (ptr != IntPtr.Zero)
+ {
+ // Now osize is the size used to (re)allocate ptr last
+
+ ref var dataRef = ref Unsafe.AsRef((void*)ptr);
+
+ if (nsize == 0)
+ {
+ customAllocator.Free(ref dataRef, (int)osize);
+
+ return 0;
+ }
+ else
+ {
+ ref var ret = ref customAllocator.ResizeAllocation(ref dataRef, (int)osize, (int)nsize, out var failed);
+ if (failed)
+ {
+ return 0;
+ }
+
+ var retPtr = (nint)Unsafe.AsPointer(ref ret);
+
+ return retPtr;
+ }
+ }
+ else
+ {
+ // Now osize is the size of the object being allocated, but nsize is the desired size
+
+ ref var ret = ref customAllocator.AllocateNew((int)nsize, out var failed);
+ if (failed)
+ {
+ return 0;
+ }
+
+ var retPtr = (nint)Unsafe.AsPointer(ref ret);
+
+ return retPtr;
+ }
+ }
+ }
}
\ No newline at end of file
diff --git a/libs/server/Lua/LuaTrackedAllocator.cs b/libs/server/Lua/LuaTrackedAllocator.cs
new file mode 100644
index 0000000000..f66ff502c1
--- /dev/null
+++ b/libs/server/Lua/LuaTrackedAllocator.cs
@@ -0,0 +1,101 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+using System;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+
+namespace Garnet.server
+{
+ ///
+ /// Lua allocator which still uses native memory, but tracks it and can impose limits.
+ ///
+ /// is also used to inform .NET host of native allocations.
+ ///
+ internal unsafe class LuaTrackedAllocator : ILuaAllocator
+ {
+ private readonly int? limitBytes;
+
+ private int allocatedBytes;
+
+ internal LuaTrackedAllocator(int? limitBytes)
+ {
+ this.limitBytes = limitBytes;
+ }
+
+ ///
+ public ref byte AllocateNew(int sizeBytes, out bool failed)
+ {
+ if (limitBytes != null)
+ {
+ if (allocatedBytes + sizeBytes > limitBytes)
+ {
+ failed = true;
+ return ref Unsafe.NullRef();
+ }
+
+ allocatedBytes += sizeBytes;
+ }
+
+ if (sizeBytes > 0)
+ {
+ GC.AddMemoryPressure(sizeBytes);
+ }
+
+ failed = false;
+
+ var ptr = NativeMemory.Alloc((nuint)sizeBytes);
+ return ref Unsafe.AsRef(ptr);
+ }
+
+ ///
+ public void Free(ref byte start, int sizeBytes)
+ {
+ NativeMemory.Free(Unsafe.AsPointer(ref start));
+
+ if (sizeBytes > 0)
+ {
+ allocatedBytes -= sizeBytes;
+
+ GC.RemoveMemoryPressure(sizeBytes);
+ }
+ }
+
+ ///
+ public ref byte ResizeAllocation(ref byte start, int oldSizeBytes, int newSizeBytes, out bool failed)
+ {
+ var delta = newSizeBytes - oldSizeBytes;
+
+ if (delta == 0)
+ {
+ failed = false;
+ return ref start;
+ }
+
+ if (limitBytes != null)
+ {
+ if (allocatedBytes + delta > limitBytes)
+ {
+ failed = true;
+ return ref Unsafe.NullRef();
+ }
+
+ allocatedBytes += delta;
+ }
+
+ if (delta < 0)
+ {
+ GC.RemoveMemoryPressure(-delta);
+ }
+ else
+ {
+ GC.AddMemoryPressure(delta);
+ }
+
+ failed = false;
+
+ var ptr = NativeMemory.Realloc(Unsafe.AsPointer(ref start), (nuint)newSizeBytes);
+ return ref Unsafe.AsRef(ptr);
+ }
+ }
+}
\ No newline at end of file
diff --git a/libs/server/Lua/NativeMethods.cs b/libs/server/Lua/NativeMethods.cs
index cc6bcf085b..76a9fda963 100644
--- a/libs/server/Lua/NativeMethods.cs
+++ b/libs/server/Lua/NativeMethods.cs
@@ -2,10 +2,13 @@
// Licensed under the MIT license.
using System;
+using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using KeraLua;
using charptr_t = nint;
+using intptr_t = nint;
+using lua_CFunction = nint;
using lua_State = nint;
using size_t = nuint;
@@ -42,6 +45,139 @@ internal static partial class NativeMethods
[UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
private static partial LuaStatus luaL_loadbufferx(lua_State luaState, charptr_t buff, size_t sz, charptr_t name, charptr_t mode);
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#luaL_newstate
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial nint luaL_newstate();
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_newstate
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial nint lua_newstate(lua_CFunction allocFunc, charptr_t ud);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#luaL_openlibs
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial void luaL_openlibs(lua_State luaState);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_close
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial void lua_close(lua_State luaState);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_checkstack
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial int lua_checkstack(lua_State luaState, int n);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#luaL_checknumber
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial double luaL_checknumber(lua_State luaState, int n);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_rawlen
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial int lua_rawlen(lua_State luaState, int n);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_pcallk
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial int lua_pcallk(lua_State luaState, int nargs, int nresults, int msgh, nint ctx, nint k);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_callk
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial int lua_callk(lua_State luaState, int nargs, int nresults, nint ctx, nint k);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_rawseti
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial void lua_rawseti(lua_State luaState, int index, long i);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_rawset
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial void lua_rawset(lua_State luaState, int index);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_rawgeti
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial int lua_rawgeti(lua_State luaState, int index, long n);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_rawget
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial int lua_rawget(lua_State luaState, int index);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#luaL_ref
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial int luaL_ref(lua_State luaState, int index);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#luaL_unref
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial int luaL_unref(lua_State luaState, int index, int refVal);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_createtable
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial void lua_createtable(lua_State luaState, int narr, int nrec);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_getglobal
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial int lua_getglobal(lua_State luaState, charptr_t name);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_setglobal
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial void lua_setglobal(lua_State luaState, charptr_t name);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_error
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial int lua_error(lua_State luaState);
+
// GC Transition suppressed - only do this after auditing the Lua method and confirming constant-ish, fast, runtime w/o allocations
///
@@ -51,7 +187,7 @@ internal static partial class NativeMethods
/// see: https://www.lua.org/source/5.4/lapi.c.html#lua_gettop
///
[LibraryImport(LuaLibraryName)]
- [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])]
private static partial int lua_gettop(lua_State luaState);
///
@@ -63,7 +199,7 @@ internal static partial class NativeMethods
/// see: https://www.lua.org/source/5.4/lapi.c.html#index2value
///
[LibraryImport(LuaLibraryName)]
- [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])]
private static partial LuaType lua_type(lua_State L, int index);
///
@@ -73,7 +209,7 @@ internal static partial class NativeMethods
/// see: https://www.lua.org/source/5.4/lapi.c.html#lua_pushnil
///
[LibraryImport(LuaLibraryName)]
- [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])]
private static partial void lua_pushnil(lua_State L);
///
@@ -83,7 +219,7 @@ internal static partial class NativeMethods
/// see: https://www.lua.org/source/5.4/lapi.c.html#lua_pushinteger
///
[LibraryImport(LuaLibraryName)]
- [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])]
private static partial void lua_pushinteger(lua_State L, long num);
///
@@ -93,7 +229,7 @@ internal static partial class NativeMethods
/// see: https://www.lua.org/source/5.4/lapi.c.html#lua_pushboolean
///
[LibraryImport(LuaLibraryName)]
- [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])]
private static partial void lua_pushboolean(lua_State L, int b);
///
@@ -103,9 +239,20 @@ internal static partial class NativeMethods
/// see: https://www.lua.org/source/5.4/lapi.c.html#lua_toboolean
///
[LibraryImport(LuaLibraryName)]
- [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])]
private static partial int lua_toboolean(lua_State L, int ix);
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_tointegerx
+ ///
+ /// We should always have checked this is actually a number before calling,
+ /// so the expensive paths won't be taken.
+ /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_tointegerx
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])]
+ private static partial long lua_tointegerx(lua_State L, int idex, intptr_t pisnum);
+
///
/// see: https://www.lua.org/manual/5.4/manual.html#lua_settop
///
@@ -113,9 +260,31 @@ internal static partial class NativeMethods
/// see: https://www.lua.org/source/5.4/lapi.c.html#lua_settop
///
[LibraryImport(LuaLibraryName)]
- [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])]
private static partial void lua_settop(lua_State L, int num);
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_atpanic
+ ///
+ /// Just changing a global value, should be quick.
+ /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_atpanic
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])]
+ private static partial lua_CFunction lua_atpanic(lua_State luaState, lua_CFunction panicf);
+
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_pushcclosure
+ ///
+ /// We never call this with n != 0, so does very little.
+ /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_pushcclosure
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl), typeof(CallConvSuppressGCTransition)])]
+ private static partial void lua_pushcclosure(lua_State luaState, lua_CFunction fn, int n);
+
+ // Helper methods for using the pinvokes defined above
+
///
/// Returns true if the given index on the stack holds a string or a number.
///
@@ -129,19 +298,15 @@ internal static partial class NativeMethods
///
internal static bool CheckBuffer(lua_State luaState, int index, out ReadOnlySpan str)
{
- var type = lua_type(luaState, index);
-
- if (type is not LuaType.String and not LuaType.Number)
- {
- str = [];
- return false;
- }
-
+ // See: https://www.lua.org/source/5.4/lapi.c.html#lua_tolstring
+ //
+ // If lua_tolstring fails, it will set len == 0 and start == NULL
var start = lua_tolstring(luaState, index, out var len);
+
unsafe
{
str = new ReadOnlySpan((byte*)start, (int)len);
- return true;
+ return start != (charptr_t)(void*)null;
}
}
@@ -168,12 +333,15 @@ internal static void KnownStringToBuffer(lua_State luaState, int index, out Read
///
/// Provided data is copied, and can be reused once this call returns.
///
- internal static unsafe void PushBuffer(lua_State luaState, ReadOnlySpan str)
+ internal static unsafe ref byte PushBuffer(lua_State luaState, ReadOnlySpan str)
{
+ nint inLuaPtr;
fixed (byte* ptr = str)
{
- _ = lua_pushlstring(luaState, (charptr_t)ptr, (size_t)str.Length);
+ inLuaPtr = lua_pushlstring(luaState, (charptr_t)ptr, (size_t)str.Length);
}
+
+ return ref Unsafe.AsRef((void*)inLuaPtr);
}
///
@@ -239,6 +407,14 @@ internal static void PushBoolean(lua_State luaState, bool b)
internal static bool ToBoolean(lua_State luaState, int index)
=> lua_toboolean(luaState, index) != 0;
+ ///
+ /// Read a long off the stack
+ ///
+ /// Differs from by suppressing GC transition.
+ ///
+ internal static long ToInteger(lua_State luaState, int index)
+ => lua_tointegerx(luaState, index, 0);
+
///
/// Remove some number of items from the stack.
///
@@ -246,5 +422,172 @@ internal static bool ToBoolean(lua_State luaState, int index)
///
internal static void Pop(lua_State luaState, int num)
=> lua_settop(luaState, -num - 1);
+
+ ///
+ /// Update the panic function.
+ ///
+ /// Differs from by taking a function pointer
+ /// and suppressing GC transition.
+ ///
+ internal static unsafe nint AtPanic(lua_State luaState, delegate* unmanaged[Cdecl] panicFunc)
+ => lua_atpanic(luaState, (nint)panicFunc);
+
+ ///
+ /// Create a new Lua state.
+ ///
+ internal static unsafe nint NewState()
+ => luaL_newstate();
+
+ ///
+ /// Create a new Lua state.
+ ///
+ /// Differs from by taking a function pointer.
+ ///
+ internal static unsafe nint NewState(delegate* unmanaged[Cdecl] allocFunc, nint ud)
+ => lua_newstate((nint)allocFunc, ud);
+
+ ///
+ /// Open all standard Lua libraries.
+ ///
+ internal static void OpenLibs(lua_State luaState)
+ => luaL_openlibs(luaState);
+
+ ///
+ /// Close the state, releasing all associated resources.
+ ///
+ internal static void Close(lua_State luaState)
+ => lua_close(luaState);
+
+ ///
+ /// Reserve space on the stack, returning false if that was not possible.
+ ///
+ internal static bool CheckStack(lua_State luaState, int n)
+ => lua_checkstack(luaState, n) == 1;
+
+ ///
+ /// Read a number, as a double, out of the stack.
+ ///
+ internal static double CheckNumber(lua_State luaState, int n)
+ => luaL_checknumber(luaState, n);
+
+ ///
+ /// Gets the length of an object on the stack, ignoring metatable methods.
+ ///
+ internal static int RawLen(lua_State luaState, int n)
+ => lua_rawlen(luaState, n);
+
+ ///
+ /// Push a function onto the stack.
+ ///
+ internal static void PushCFunction(lua_State luaState, nint ptr)
+ => lua_pushcclosure(luaState, ptr, 0);
+
+ ///
+ /// Perform a protected call with the given number of arguments, expecting the given number of returns.
+ ///
+ internal static LuaStatus PCall(lua_State luaState, int nargs, int nrets)
+ => (LuaStatus)lua_pcallk(luaState, nargs, nrets, 0, 0, 0);
+
+ ///
+ /// Perform a call with the given number of arguments, expecting the given number of returns.
+ ///
+ internal static void Call(lua_State luaState, int nargs, int nrets)
+ => lua_callk(luaState, nargs, nrets, 0, 0);
+
+ ///
+ /// Equivalent of t[i] = v, where t is the table at the given index and v is the value on the top of the stack.
+ ///
+ /// Ignores metatable methods.
+ ///
+ internal static void RawSetInteger(lua_State luaState, int index, long i)
+ => lua_rawseti(luaState, index, i);
+
+ ///
+ /// Equivalent to t[k] = v, where t is the value at the given index, v is the value on the top of the stack, and k is the value just below the top.
+ ///
+ /// Ignores metatable methods.
+ ///
+ internal static void RawSet(lua_State luaState, int index)
+ => lua_rawset(luaState, index);
+
+ ///
+ /// Pushes onto the stack the value t[n], where t is the table at the given index.
+ ///
+ /// Ignores metatable methods.
+ ///
+ internal static LuaType RawGetInteger(lua_State luaState, int index, long n)
+ => (LuaType)lua_rawgeti(luaState, index, n);
+
+ ///
+ /// Pushes onto the stack the value t[k], where t is the value at the given index and k is the value on the top of the stack.
+ ///
+ /// Ignores metatable methods.
+ ///
+ internal static LuaType RawGet(lua_State luaState, int index)
+ => (LuaType)lua_rawget(luaState, index);
+
+ ///
+ /// Creates and reference a reference in that table at the given index.
+ ///
+ internal static int Ref(lua_State luaState, int index)
+ => luaL_ref(luaState, index);
+
+ ///
+ /// Free a ref previously created with .
+ ///
+ internal static int Unref(lua_State luaState, int index, int refVal)
+ => luaL_unref(luaState, index, refVal);
+
+ ///
+ /// Create a new table and push it on the stack.
+ ///
+ /// Reserves capacity for the given number of elements and hints at capacity for non-sequence records.
+ ///
+ internal static void CreateTable(lua_State luaState, int elements, int records)
+ => lua_createtable(luaState, elements, records);
+
+ ///
+ /// Load a global under the given name onto the stack.
+ ///
+ internal static unsafe LuaType GetGlobal(lua_State luaState, ReadOnlySpan nullTerminatedName)
+ {
+ Debug.Assert(nullTerminatedName[^1] == 0, "Global name must be null terminated");
+
+ fixed (byte* ptr = nullTerminatedName)
+ {
+ return (LuaType)lua_getglobal(luaState, (nint)ptr);
+ }
+ }
+
+ ///
+ /// Pops the top item on the stack, and stores it under the given name as a global.
+ ///
+ internal static unsafe void SetGlobal(lua_State luaState, ReadOnlySpan nullTerminatedName)
+ {
+ Debug.Assert(nullTerminatedName[^1] == 0, "Global name must be null terminated");
+
+ fixed (byte* ptr = nullTerminatedName)
+ {
+ lua_setglobal(luaState, (nint)ptr);
+ }
+ }
+
+ ///
+ /// Sets the index of the top elements on the stack.
+ ///
+ /// 0 == empty
+ ///
+ /// Items above this point can no longer be safely accessed.
+ ///
+ internal static void SetTop(lua_State lua_State, int top)
+ => lua_settop(lua_State, top);
+
+ ///
+ /// Raise an error, using the top of the stack as an error item.
+ ///
+ /// This method never returns, so be careful calling it.
+ ///
+ internal static int Error(lua_State luaState)
+ => lua_error(luaState);
}
}
\ No newline at end of file
diff --git a/libs/server/Lua/SessionScriptCache.cs b/libs/server/Lua/SessionScriptCache.cs
index 4dfd1ac296..4d48154f83 100644
--- a/libs/server/Lua/SessionScriptCache.cs
+++ b/libs/server/Lua/SessionScriptCache.cs
@@ -5,11 +5,9 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Security.Cryptography;
-using Garnet.common;
using Garnet.server.ACL;
using Garnet.server.Auth;
using Microsoft.Extensions.Logging;
-using Tsavorite.core;
namespace Garnet.server
{
@@ -28,6 +26,9 @@ internal sealed class SessionScriptCache : IDisposable
readonly Dictionary scriptCache = [];
readonly byte[] hash = new byte[SHA1Len / 2];
+ readonly LuaMemoryManagementMode memoryManagementMode;
+ readonly int? memoryLimitBytes;
+
public SessionScriptCache(StoreWrapper storeWrapper, IGarnetAuthenticator authenticator, ILogger logger = null)
{
this.storeWrapper = storeWrapper;
@@ -35,6 +36,10 @@ public SessionScriptCache(StoreWrapper storeWrapper, IGarnetAuthenticator authen
scratchBufferNetworkSender = new ScratchBufferNetworkSender();
processor = new RespServerSession(0, scratchBufferNetworkSender, storeWrapper, null, authenticator, false);
+
+ // There's some parsing involved in these, so save them off per-session
+ memoryManagementMode = storeWrapper.serverOptions.LuaOptions.MemoryManagementMode;
+ memoryLimitBytes = storeWrapper.serverOptions.LuaOptions.GetMemoryLimitBytes();
}
public void Dispose()
@@ -74,21 +79,31 @@ internal bool TryLoad(RespServerSession session, ReadOnlySpan source, Scri
{
var sourceOnHeap = source.ToArray();
- runner = new LuaRunner(sourceOnHeap, storeWrapper.serverOptions.LuaTransactionMode, processor, scratchBufferNetworkSender, logger);
- runner.CompileForSession(session);
-
- // Need to make sure the key is on the heap, so move it over
- //
- // There's an implicit assumption that all callers are using unmanaged memory.
- // If that becomes untrue, there's an optimization opportunity to re-use the
- // managed memory here.
- var into = GC.AllocateUninitializedArray(SHA1Len, pinned: true);
- digest.CopyTo(into);
-
- ScriptHashKey storeKeyDigest = new(into);
- digestOnHeap = storeKeyDigest;
-
- _ = scriptCache.TryAdd(storeKeyDigest, runner);
+ runner = new LuaRunner(memoryManagementMode, memoryLimitBytes, sourceOnHeap, storeWrapper.serverOptions.LuaTransactionMode, processor, scratchBufferNetworkSender, logger);
+
+ // If compilation fails, an error is written out
+ if (runner.CompileForSession(session))
+ {
+ // Need to make sure the key is on the heap, so move it over
+ //
+ // There's an implicit assumption that all callers are using unmanaged memory.
+ // If that becomes untrue, there's an optimization opportunity to re-use the
+ // managed memory here.
+ var into = GC.AllocateUninitializedArray(SHA1Len, pinned: true);
+ digest.CopyTo(into);
+
+ ScriptHashKey storeKeyDigest = new(into);
+ digestOnHeap = storeKeyDigest;
+
+ _ = scriptCache.TryAdd(storeKeyDigest, runner);
+ }
+ else
+ {
+ runner.Dispose();
+
+ digestOnHeap = null;
+ return false;
+ }
}
catch (Exception ex)
{
diff --git a/libs/server/Resp/CmdStrings.cs b/libs/server/Resp/CmdStrings.cs
index cb32cfc432..8f32e0bbea 100644
--- a/libs/server/Resp/CmdStrings.cs
+++ b/libs/server/Resp/CmdStrings.cs
@@ -353,6 +353,7 @@ static partial class CmdStrings
// Lua scripting strings
public static ReadOnlySpan LUA_OK => "OK"u8;
+ public static ReadOnlySpan LUA_ok => "ok"u8;
public static ReadOnlySpan LUA_err => "err"u8;
public static ReadOnlySpan LUA_No_session_available => "No session available"u8;
public static ReadOnlySpan LUA_ERR_Please_specify_at_least_one_argument_for_this_redis_lib_call => "ERR Please specify at least one argument for this redis lib call"u8;
diff --git a/libs/server/Resp/RespServerSession.cs b/libs/server/Resp/RespServerSession.cs
index d0e357e2bf..aaf064f57a 100644
--- a/libs/server/Resp/RespServerSession.cs
+++ b/libs/server/Resp/RespServerSession.cs
@@ -394,6 +394,21 @@ public override int TryConsumeMessages(byte* reqBuffer, int bytesReceived)
return readHead;
}
+ ///
+ /// For testing purposes, call and update state accordingly.
+ ///
+ internal void EnterAndGetResponseObject()
+ => networkSender.EnterAndGetResponseObject(out dcurr, out dend);
+
+ ///
+ /// For testing purposes, call and update state accordingly.
+ ///
+ internal void ExitAndReturnResponseObject()
+ {
+ networkSender.ExitAndReturnResponseObject();
+ dcurr = dend = (byte*)0;
+ }
+
internal void SetTransactionMode(bool enable)
=> txnManager.state = enable ? TxnState.Running : TxnState.None;
diff --git a/libs/server/Servers/GarnetServerOptions.cs b/libs/server/Servers/GarnetServerOptions.cs
index 05dce6ca04..23d310842b 100644
--- a/libs/server/Servers/GarnetServerOptions.cs
+++ b/libs/server/Servers/GarnetServerOptions.cs
@@ -404,6 +404,8 @@ public class GarnetServerOptions : ServerOptions
public bool EnableObjectStoreReadCache = false;
+ public LuaOptions LuaOptions;
+
///
/// Constructor
///
diff --git a/libs/server/Servers/ServerOptions.cs b/libs/server/Servers/ServerOptions.cs
index 342b563020..efa6208859 100644
--- a/libs/server/Servers/ServerOptions.cs
+++ b/libs/server/Servers/ServerOptions.cs
@@ -228,7 +228,7 @@ public void GetSettings()
///
///
///
- protected static long ParseSize(string value)
+ public static long ParseSize(string value)
{
char[] suffix = ['k', 'm', 'g', 't', 'p'];
long result = 0;
diff --git a/test/BDNPerfTests/BDN_Benchmark_Config.json b/test/BDNPerfTests/BDN_Benchmark_Config.json
index fcb08f23b7..89ee51f692 100644
--- a/test/BDNPerfTests/BDN_Benchmark_Config.json
+++ b/test/BDNPerfTests/BDN_Benchmark_Config.json
@@ -85,10 +85,90 @@
"expected_MSet_None": 0
},
"BDN.benchmark.Lua.LuaScripts.*": {
- "expected_Script1_None": 0,
- "expected_Script2_None": 24,
- "expected_Script3_None": 32,
- "expected_Script4_None": 0
+ "expected_Script1_Managed,Limit": 0,
+ "expected_Script2_Managed,Limit": 24,
+ "expected_Script3_Managed,Limit": 32,
+ "expected_Script4_Managed,Limit": 0,
+ "expected_Script1_Managed,None": 0,
+ "expected_Script2_Managed,None": 24,
+ "expected_Script3_Managed,None": 32,
+ "expected_Script4_Managed,None": 0,
+ "expected_Script1_Native,None": 0,
+ "expected_Script2_Native,None": 24,
+ "expected_Script3_Native,None": 32,
+ "expected_Script4_Native,None": 0,
+ "expected_Script1_Tracked,Limit": 0,
+ "expected_Script2_Tracked,Limit": 24,
+ "expected_Script3_Tracked,Limit": 32,
+ "expected_Script4_Tracked,Limit": 0,
+ "expected_Script1_Tracked,None": 0,
+ "expected_Script2_Tracked,None": 24,
+ "expected_Script3_Tracked,None": 32,
+ "expected_Script4_Tracked,None": 0
+ },
+ "BDN.benchmark.Lua.LuaRunnerOperations.*": {
+ "expected_ResetParametersSmall_Managed,Limit": 0,
+ "expected_ResetParametersLarge_Managed,Limit": 0,
+ "expected_ConstructSmall_Managed,Limit": 2097602,
+ "expected_ConstructLarge_Managed,Limit": 2100666,
+ "expected_CompileForSessionSmall_Managed,Limit": 99,
+ "expected_CompileForSessionLarge_Managed,Limit": 0,
+ "expected_ResetParametersSmall_Managed,None": 0,
+ "expected_ResetParametersLarge_Managed,None": 0,
+ "expected_ConstructSmall_Managed,None": 2097674,
+ "expected_ConstructLarge_Managed,None": 2100738,
+ "expected_CompileForSessionSmall_Managed,None": 512,
+ "expected_CompileForSessionLarge_Managed,None": 0,
+ "expected_ResetParametersSmall_Native,None": 0,
+ "expected_ResetParametersLarge_Native,None": 0,
+ "expected_ConstructSmall_Native,None": 328,
+ "expected_ConstructLarge_Native,None": 3392,
+ "expected_CompileForSessionSmall_Native,None": 0,
+ "expected_CompileForSessionLarge_Native,None": 0,
+ "expected_ResetParametersSmall_Tracked,Limit": 0,
+ "expected_ResetParametersLarge_Tracked,Limit": 0,
+ "expected_ConstructSmall_Tracked,Limit": 402,
+ "expected_ConstructLarge_Tracked,Limit": 3465,
+ "expected_CompileForSessionSmall_Tracked,Limit": 0,
+ "expected_CompileForSessionLarge_Tracked,Limit": 0,
+ "expected_ResetParametersSmall_Tracked,None": 0,
+ "expected_ResetParametersLarge_Tracked,None": 0,
+ "expected_ConstructSmall_Tracked,None": 362,
+ "expected_ConstructLarge_Tracked,None": 3425,
+ "expected_CompileForSessionSmall_Tracked,None": 0,
+ "expected_CompileForSessionLarge_Tracked,None": 0
+ },
+ "BDN.benchmark.Lua.LuaScriptCacheOperations.*": {
+ "expected_LookupHit_Managed,Limit": 1312,
+ "expected_LookupMiss_Managed,Limit": 688,
+ "expected_LoadOuterHit_Managed,Limit": 688,
+ "expected_LoadInnerHit_Managed,Limit": 2098272,
+ "expected_LoadMiss_Managed,Limit": 1312,
+ "expected_Digest_Managed,Limit": 1312,
+ "expected_LookupHit_Managed,None": 688,
+ "expected_LookupMiss_Managed,None": 1312,
+ "expected_LoadOuterHit_Managed,None": 688,
+ "expected_LoadInnerHit_Managed,None": 2097760,
+ "expected_LoadMiss_Managed,None": 1312,
+ "expected_Digest_Managed,None": 688,
+ "expected_LookupHit_Native,None": 1312,
+ "expected_LookupMiss_Native,None": 1312,
+ "expected_LoadOuterHit_Native,None": 688,
+ "expected_LoadInnerHit_Native,None": 1040,
+ "expected_LoadMiss_Native,None": 1312,
+ "expected_Digest_Native,None": 688,
+ "expected_LookupHit_Tracked,Limit": 688,
+ "expected_LookupMiss_Tracked,Limit": 1312,
+ "expected_LoadOuterHit_Tracked,Limit": 688,
+ "expected_LoadInnerHit_Tracked,Limit": 1072,
+ "expected_LoadMiss_Tracked,Limit": 1264,
+ "expected_Digest_Tracked,Limit": 1312,
+ "expected_LookupHit_Tracked,None": 688,
+ "expected_LookupMiss_Tracked,None": 1312,
+ "expected_LoadOuterHit_Tracked,None": 1312,
+ "expected_LoadInnerHit_Tracked,None": 1696,
+ "expected_LoadMiss_Tracked,None": 1312,
+ "expected_Digest_Tracked,None": 1312
},
"BDN.benchmark.Operations.CustomOperations.*": {
"expected_CustomRawStringCommand_ACL": 0,
@@ -159,7 +239,7 @@
"expected_Eval_None": 0,
"expected_EvalSha_None": 0,
"expected_SmallScript_None": 0,
- "expected_LargeScript_None": 12,
+ "expected_LargeScript_None": 23,
"expected_ArrayReturn_None": 0
},
"BDN.benchmark.Network.RawStringOperations.*": {
diff --git a/test/BDNPerfTests/run_bdnperftest.ps1 b/test/BDNPerfTests/run_bdnperftest.ps1
index 5c88e0be87..afece51daa 100644
--- a/test/BDNPerfTests/run_bdnperftest.ps1
+++ b/test/BDNPerfTests/run_bdnperftest.ps1
@@ -73,6 +73,11 @@ param ($ResultsLine, $columnNum)
$columns = $ResultsLine.Trim('|').Split('|')
$column = $columns | ForEach-Object { $_.Trim() }
$foundValue = $column[$columnNum]
+ if ($foundValue -eq "NA") {
+ Write-Error -Message "The value for the column was NA which means that the BDN test failed and didn't generate performance metrics. Verify the BDN test ran successfully."
+ exit
+ }
+
if ($foundValue -eq "-") {
$foundValue = "0"
}
diff --git a/test/Garnet.test.cluster/ClusterTestContext.cs b/test/Garnet.test.cluster/ClusterTestContext.cs
index 8b1529879e..8bc92d8668 100644
--- a/test/Garnet.test.cluster/ClusterTestContext.cs
+++ b/test/Garnet.test.cluster/ClusterTestContext.cs
@@ -126,7 +126,9 @@ public void CreateInstances(
bool disablePubSub = true,
int metricsSamplingFrequency = 0,
bool enableLua = false,
- bool asyncReplay = false)
+ bool asyncReplay = false,
+ LuaMemoryManagementMode luaMemoryMode = LuaMemoryManagementMode.Native,
+ string luaMemoryLimit = "")
{
endpoints = TestUtils.GetEndPoints(shards, 7000);
nodes = TestUtils.CreateGarnetCluster(
@@ -159,7 +161,9 @@ public void CreateInstances(
authenticationSettings: authenticationSettings,
metricsSamplingFrequency: metricsSamplingFrequency,
enableLua: enableLua,
- asyncReplay: asyncReplay);
+ asyncReplay: asyncReplay,
+ luaMemoryMode: luaMemoryMode,
+ luaMemoryLimit: luaMemoryLimit);
foreach (var node in nodes)
node.Start();
diff --git a/test/Garnet.test/GarnetServerConfigTests.cs b/test/Garnet.test/GarnetServerConfigTests.cs
index adbc02d881..b3aac12f40 100644
--- a/test/Garnet.test/GarnetServerConfigTests.cs
+++ b/test/Garnet.test/GarnetServerConfigTests.cs
@@ -10,6 +10,7 @@
using System.Text.Json.Serialization;
using CommandLine;
using Garnet.common;
+using Garnet.server;
using Microsoft.Extensions.Logging;
using NUnit.Framework;
using NUnit.Framework.Legacy;
@@ -246,5 +247,80 @@ public void ImportExportConfigAzure()
deviceFactory.Delete(new FileDescriptor { directoryName = "" });
}
}
+
+ [Test]
+ public void LuaMemoryOptions()
+ {
+ // Defaults to Native with no limit
+ {
+ var args = new[] { "--lua" };
+ var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully);
+ ClassicAssert.IsTrue(parseSuccessful);
+ ClassicAssert.IsTrue(options.EnableLua);
+ ClassicAssert.AreEqual(LuaMemoryManagementMode.Native, options.LuaMemoryManagementMode);
+ ClassicAssert.IsNull(options.LuaScriptMemoryLimit);
+ }
+
+ // Native with limit rejected
+ {
+ var args = new[] { "--lua", "--lua-script-memory-limit", "10m" };
+ var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully);
+ ClassicAssert.IsFalse(parseSuccessful);
+ }
+
+ // Tracked with no limit works
+ {
+ var args = new[] { "--lua", "--lua-memory-management-mode", "Tracked" };
+ var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully);
+ ClassicAssert.IsTrue(parseSuccessful);
+ ClassicAssert.IsTrue(options.EnableLua);
+ ClassicAssert.AreEqual(LuaMemoryManagementMode.Tracked, options.LuaMemoryManagementMode);
+ ClassicAssert.IsNull(options.LuaScriptMemoryLimit);
+ }
+
+ // Tracked with limit works
+ {
+ var args = new[] { "--lua", "--lua-memory-management-mode", "Tracked", "--lua-script-memory-limit", "10m" };
+ var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully);
+ ClassicAssert.IsTrue(parseSuccessful);
+ ClassicAssert.IsTrue(options.EnableLua);
+ ClassicAssert.AreEqual(LuaMemoryManagementMode.Tracked, options.LuaMemoryManagementMode);
+ ClassicAssert.AreEqual("10m", options.LuaScriptMemoryLimit);
+ }
+
+ // Tracked with bad limit rejected
+ {
+ var args = new[] { "--lua", "--lua-memory-management-mode", "Tracked", "--lua-script-memory-limit", "10Q" };
+ var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully);
+ ClassicAssert.IsFalse(parseSuccessful);
+ }
+
+ // Managed with no limit works
+ {
+ var args = new[] { "--lua", "--lua-memory-management-mode", "Managed" };
+ var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully);
+ ClassicAssert.IsTrue(parseSuccessful);
+ ClassicAssert.IsTrue(options.EnableLua);
+ ClassicAssert.AreEqual(LuaMemoryManagementMode.Managed, options.LuaMemoryManagementMode);
+ ClassicAssert.IsNull(options.LuaScriptMemoryLimit);
+ }
+
+ // Managed with limit works
+ {
+ var args = new[] { "--lua", "--lua-memory-management-mode", "Managed", "--lua-script-memory-limit", "10m" };
+ var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully);
+ ClassicAssert.IsTrue(parseSuccessful);
+ ClassicAssert.IsTrue(options.EnableLua);
+ ClassicAssert.AreEqual(LuaMemoryManagementMode.Managed, options.LuaMemoryManagementMode);
+ ClassicAssert.AreEqual("10m", options.LuaScriptMemoryLimit);
+ }
+
+ // Managed with bad limit rejected
+ {
+ var args = new[] { "--lua", "--lua-memory-management-mode", "Managed", "--lua-script-memory-limit", "10Q" };
+ var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully);
+ ClassicAssert.IsFalse(parseSuccessful);
+ }
+ }
}
}
\ No newline at end of file
diff --git a/test/Garnet.test/LuaScriptRunnerTests.cs b/test/Garnet.test/LuaScriptRunnerTests.cs
index b7a7a7f0f7..7a9bef4ae8 100644
--- a/test/Garnet.test/LuaScriptRunnerTests.cs
+++ b/test/Garnet.test/LuaScriptRunnerTests.cs
@@ -1,6 +1,12 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
using Garnet.common;
using Garnet.server;
using NUnit.Framework;
@@ -15,51 +21,51 @@ internal class LuaScriptRunnerTests
public void CannotRunUnsafeScript()
{
// Try to load an assembly
- using (var runner = new LuaRunner("luanet.load_assembly('mscorlib')"))
+ using (var runner = new LuaRunner(new(), "luanet.load_assembly('mscorlib')"))
{
runner.CompileForRunner();
var ex = Assert.Throws(() => runner.RunForRunner());
- ClassicAssert.AreEqual("[string \"luanet.load_assembly('mscorlib')\"]:1: attempt to index a nil value (global 'luanet')", ex.Message);
+ ClassicAssert.AreEqual("ERR Lua encountered an error: [string \"luanet.load_assembly('mscorlib')\"]:1: attempt to index a nil value (global 'luanet')", ex.Message);
}
// Try to call a OS function
- using (var runner = new LuaRunner("os = require('os'); return os.time();"))
+ using (var runner = new LuaRunner(new(), "os = require('os'); return os.time();"))
{
runner.CompileForRunner();
var ex = Assert.Throws(() => runner.RunForRunner());
- ClassicAssert.AreEqual("[string \"os = require('os'); return os.time();\"]:1: attempt to call a nil value (global 'require')", ex.Message);
+ ClassicAssert.AreEqual("ERR Lua encountered an error: [string \"os = require('os'); return os.time();\"]:1: attempt to call a nil value (global 'require')", ex.Message);
}
// Try to execute the input stream
- using (var runner = new LuaRunner("dofile();"))
+ using (var runner = new LuaRunner(new(), "dofile();"))
{
runner.CompileForRunner();
var ex = Assert.Throws(() => runner.RunForRunner());
- ClassicAssert.AreEqual("[string \"dofile();\"]:1: attempt to call a nil value (global 'dofile')", ex.Message);
+ ClassicAssert.AreEqual("ERR Lua encountered an error: [string \"dofile();\"]:1: attempt to call a nil value (global 'dofile')", ex.Message);
}
// Try to call a windows executable
- using (var runner = new LuaRunner("require \"notepad\""))
+ using (var runner = new LuaRunner(new(), "require \"notepad\""))
{
runner.CompileForRunner();
var ex = Assert.Throws(() => runner.RunForRunner());
- ClassicAssert.AreEqual("[string \"require \"notepad\"\"]:1: attempt to call a nil value (global 'require')", ex.Message);
+ ClassicAssert.AreEqual("ERR Lua encountered an error: [string \"require \"notepad\"\"]:1: attempt to call a nil value (global 'require')", ex.Message);
}
// Try to call an OS function
- using (var runner = new LuaRunner("os.exit();"))
+ using (var runner = new LuaRunner(new(), "os.exit();"))
{
runner.CompileForRunner();
var ex = Assert.Throws(() => runner.RunForRunner());
- ClassicAssert.AreEqual("[string \"os.exit();\"]:1: attempt to index a nil value (global 'os')", ex.Message);
+ ClassicAssert.AreEqual("ERR Lua encountered an error: [string \"os.exit();\"]:1: attempt to index a nil value (global 'os')", ex.Message);
}
// Try to include a new .net library
- using (var runner = new LuaRunner("import ('System.Diagnostics');"))
+ using (var runner = new LuaRunner(new(), "import ('System.Diagnostics');"))
{
runner.CompileForRunner();
var ex = Assert.Throws(() => runner.RunForRunner());
- ClassicAssert.AreEqual("[string \"import ('System.Diagnostics');\"]:1: attempt to call a nil value (global 'import')", ex.Message);
+ ClassicAssert.AreEqual("ERR Lua encountered an error: [string \"import ('System.Diagnostics');\"]:1: attempt to call a nil value (global 'import')", ex.Message);
}
}
@@ -67,14 +73,14 @@ public void CannotRunUnsafeScript()
public void CanLoadScript()
{
// Code with error
- using (var runner = new LuaRunner("local;"))
+ using (var runner = new LuaRunner(new(), "local;"))
{
var ex = Assert.Throws(runner.CompileForRunner);
ClassicAssert.AreEqual("Compilation error: [string \"local;\"]:1: expected near ';'", ex.Message);
}
// Code without error
- using (var runner = new LuaRunner("local list; list = 1; return list;"))
+ using (var runner = new LuaRunner(new(), "local list; list = 1; return list;"))
{
runner.CompileForRunner();
}
@@ -87,7 +93,7 @@ public void CanRunScript()
string[] args = ["arg1", "arg2", "arg3"];
// Run code without errors
- using (var runner = new LuaRunner("local list; list = ARGV[1] ; return list;"))
+ using (var runner = new LuaRunner(new(), "local list; list = ARGV[1] ; return list;"))
{
runner.CompileForRunner();
var res = runner.RunForRunner(keys, args);
@@ -95,7 +101,7 @@ public void CanRunScript()
}
// Run code with errors
- using (var runner = new LuaRunner("local list; list = ; return list;"))
+ using (var runner = new LuaRunner(new(), "local list; list = ; return list;"))
{
var ex = Assert.Throws(runner.CompileForRunner);
ClassicAssert.AreEqual("Compilation error: [string \"local list; list = ; return list;\"]:1: unexpected symbol near ';'", ex.Message);
@@ -105,7 +111,7 @@ public void CanRunScript()
[Test]
public void KeysAndArgsCleared()
{
- using (var runner = new LuaRunner("return { KEYS[1], ARGV[1], KEYS[2], ARGV[2] }"))
+ using (var runner = new LuaRunner(new(), "return { KEYS[1], ARGV[1], KEYS[2], ARGV[2] }"))
{
runner.CompileForRunner();
var res1 = runner.RunForRunner(["hello", "world"], ["fizz", "buzz"]);
@@ -130,5 +136,374 @@ public void KeysAndArgsCleared()
ClassicAssert.AreEqual("345", (string)obj3[2]);
}
}
+
+ [Test]
+ public unsafe void LuaLimittedManaged()
+ {
+ const int Iters = 20;
+ const int TotalAllocSizeBytes = 1_024 * 1_024;
+
+ var rand = new Random(2024_12_16); // Repeatable, but random
+
+ // Special cases
+ {
+ var luaAlloc = new LuaLimitedManagedAllocator(TotalAllocSizeBytes);
+ luaAlloc.CheckCorrectness();
+
+ // 0 sized should work
+ ref var zero0 = ref luaAlloc.AllocateNew(0, out var failed0);
+ ClassicAssert.IsFalse(failed0);
+ ClassicAssert.IsFalse(Unsafe.IsNullRef(ref zero0));
+ ref var zero1 = ref luaAlloc.AllocateNew(0, out var failed1);
+ ClassicAssert.IsFalse(failed1);
+ ClassicAssert.IsFalse(Unsafe.IsNullRef(ref zero1));
+ luaAlloc.CheckCorrectness();
+
+ luaAlloc.Free(ref zero0, 0);
+ luaAlloc.Free(ref zero1, 0);
+ luaAlloc.CheckCorrectness();
+
+ // Impossibly large fails
+ ref var failedRef = ref luaAlloc.AllocateNew(TotalAllocSizeBytes * 2, out var failedLarge);
+ ClassicAssert.IsTrue(failedLarge);
+ ClassicAssert.IsTrue(Unsafe.IsNullRef(ref failedRef));
+
+ luaAlloc.CheckCorrectness();
+ }
+
+ // Fill whole allocation
+ {
+ var luaAlloc = new LuaLimitedManagedAllocator(TotalAllocSizeBytes);
+ var freeSpace = luaAlloc.FirstBlockSizeBytes;
+
+ for (var i = 0; i < Iters; i++)
+ {
+ for (var size = 1; size <= 64; size *= 2)
+ {
+ var lastSize = 0L;
+ var toFree = new List();
+ while (true)
+ {
+ ref var newData = ref luaAlloc.AllocateNew(size, out var failed);
+ if (failed)
+ {
+ break;
+ }
+ DebugClassicAssertIsTrue(luaAlloc.AllocatedBytes > lastSize);
+ lastSize = luaAlloc.AllocatedBytes;
+
+ var into = new Span(Unsafe.AsPointer(ref newData), size);
+ into.Fill((byte)toFree.Count);
+
+ toFree.Add((nint)Unsafe.AsPointer(ref newData));
+ }
+ luaAlloc.CheckCorrectness();
+
+ for (var j = 0; j < toFree.Count; j++)
+ {
+ var ptr = toFree[j];
+
+ var into = new Span((void*)ptr, size);
+ ClassicAssert.IsFalse(into.ContainsAnyExcept((byte)j));
+
+ luaAlloc.Free(ref Unsafe.AsRef((void*)ptr), size);
+ }
+ luaAlloc.CheckCorrectness();
+
+ DebugClassicAssertAreEqual(0, luaAlloc.AllocatedBytes);
+
+ _ = luaAlloc.TryCoalesceAllFreeBlocks();
+ luaAlloc.CheckCorrectness();
+
+ ClassicAssert.AreEqual(freeSpace, luaAlloc.FirstBlockSizeBytes);
+ }
+ }
+ }
+
+ // Repeated realloc preserves data and doesn't move.
+ {
+ var luaAlloc = new LuaLimitedManagedAllocator(TotalAllocSizeBytes);
+ var freeSpace = luaAlloc.FirstBlockSizeBytes;
+
+ for (var i = 0; i < Iters; i++)
+ {
+ for (var initialSize = 1; initialSize <= 64; initialSize *= 2)
+ {
+ ref var initialData = ref luaAlloc.AllocateNew(initialSize, out var failed);
+ ClassicAssert.IsFalse(failed);
+
+ var size = initialSize;
+ var val = 1;
+
+ ref var curData = ref initialData;
+
+ MemoryMarshal.CreateSpan(ref curData, initialSize).Fill((byte)val);
+
+ while (true)
+ {
+ var newSize = size + rand.Next(4 * 1024) + 1;
+ ref var newData = ref luaAlloc.ResizeAllocation(ref curData, size, newSize, out failed);
+ if (failed)
+ {
+ ClassicAssert.IsTrue(Unsafe.IsNullRef(ref newData));
+ break;
+ }
+
+ // Byte totals are believable
+ DebugClassicAssertIsTrue(luaAlloc.AllocatedBytes >= newSize);
+
+ // Shouldn't have moved
+ ClassicAssert.IsTrue(Unsafe.AreSame(ref initialData, ref newData));
+
+ // Data preserved
+ var oldData = MemoryMarshal.CreateReadOnlySpan(ref newData, size);
+ ClassicAssert.IsFalse(oldData.ContainsAnyExcept((byte)val));
+
+ // Mutate to check for faults
+ val++;
+ MemoryMarshal.CreateSpan(ref newData, newSize).Fill((byte)val);
+
+ // Continue
+ size = newSize;
+ curData = ref newData;
+ }
+ luaAlloc.CheckCorrectness();
+
+ // Check final correctness
+ var finalData = MemoryMarshal.CreateReadOnlySpan(ref curData, size);
+ ClassicAssert.IsFalse(finalData.ContainsAnyExcept((byte)val));
+
+ // Hand the one block back, which should fully free everything
+ luaAlloc.Free(ref curData, size);
+
+ DebugClassicAssertAreEqual(0, luaAlloc.AllocatedBytes);
+ ClassicAssert.AreEqual(freeSpace, luaAlloc.FirstBlockSizeBytes);
+
+ luaAlloc.CheckCorrectness();
+ }
+ }
+ }
+
+ // Basic fixed sized allocs
+ {
+ const int AllocSize = 16;
+
+ var luaAlloc = new LuaLimitedManagedAllocator(TotalAllocSizeBytes);
+ var freeSpace = luaAlloc.FirstBlockSizeBytes;
+
+ for (var i = 0; i < Iters; i++)
+ {
+ DebugClassicAssertAreEqual(0, luaAlloc.AllocatedBytes);
+ ClassicAssert.AreEqual(1, luaAlloc.FreeBlockCount);
+
+ var numOps = rand.Next(50) + 1;
+
+ var toFree = new List();
+
+ // Do a bunch of allocs, all should succeed
+ for (var j = 0; j < numOps; j++)
+ {
+ ref var newData = ref luaAlloc.AllocateNew(AllocSize, out var failed);
+ ClassicAssert.False(failed);
+ ClassicAssert.False(Unsafe.IsNullRef(ref newData));
+
+ var into = new Span(Unsafe.AsPointer(ref newData), AllocSize);
+ into.Fill((byte)j);
+
+ toFree.Add((nint)Unsafe.AsPointer(ref newData));
+ }
+ luaAlloc.CheckCorrectness();
+
+ // Each block should be served out of a split, so free list should stay at 1
+ ClassicAssert.AreEqual(1, luaAlloc.FreeBlockCount);
+
+ // Check that data wasn't corrupted
+ for (var j = 0; j < toFree.Count; j++)
+ {
+ var ptr = toFree[j];
+ var data = new Span((void*)ptr, AllocSize);
+ var expected = (byte)j;
+
+ ClassicAssert.IsFalse(data.ContainsAnyExcept(expected));
+ }
+
+ // Free in a random order
+ toFree = toFree.Select(p => (Pointer: p, Order: rand.Next())).OrderBy(t => t.Order).Select(t => t.Pointer).ToList();
+ for (var j = 0; j < toFree.Count; j++)
+ {
+ var ptr = toFree[j];
+ ref var asData = ref Unsafe.AsRef((void*)ptr);
+
+ luaAlloc.Free(ref asData, AllocSize);
+ }
+ luaAlloc.CheckCorrectness();
+
+ // Check that all free's didn't corrupt anything
+ DebugClassicAssertAreEqual(0, luaAlloc.AllocatedBytes);
+
+ // Check that all memory is reclaimable
+ _ = luaAlloc.TryCoalesceAllFreeBlocks();
+ ClassicAssert.AreEqual(freeSpace, luaAlloc.FirstBlockSizeBytes);
+ luaAlloc.CheckCorrectness();
+ }
+ }
+
+ // Random operations with variable sized allocs
+ {
+ var luaAlloc = new LuaLimitedManagedAllocator(TotalAllocSizeBytes);
+ var freeSpace = luaAlloc.FirstBlockSizeBytes;
+
+ for (var i = 0; i < Iters; i++)
+ {
+ DebugClassicAssertAreEqual(0, luaAlloc.AllocatedBytes);
+ ClassicAssert.AreEqual(1, luaAlloc.FreeBlockCount);
+
+ var toFree = new List<(nint Pointer, byte Expected, int AllocSize)>();
+ for (var j = 0; j < 1_000; j++)
+ {
+ var op = rand.Next(4);
+ switch (op)
+ {
+ // Allocate
+ case 0:
+ {
+ var allocSize = rand.Next(4 * 1024) + 1;
+
+ ref var newData = ref luaAlloc.AllocateNew(allocSize, out var failed);
+ ClassicAssert.IsFalse(failed);
+ ClassicAssert.IsFalse(Unsafe.IsNullRef(ref newData));
+
+ var into = new Span(Unsafe.AsPointer(ref newData), allocSize);
+ into.Fill((byte)j);
+
+ toFree.Add(((nint)Unsafe.AsPointer(ref newData), (byte)j, allocSize));
+ }
+
+ break;
+
+ // Reallocate
+ case 1:
+ {
+ if (toFree.Count == 0)
+ {
+ goto case 0;
+ }
+
+ var reallocIx = rand.Next(toFree.Count);
+ var (ptr, expected, size) = toFree[reallocIx];
+
+ ref var reallocRef = ref Unsafe.AsRef((void*)ptr);
+
+ int newSizeBytes;
+ if (rand.Next(2) == 0)
+ {
+ newSizeBytes = size + rand.Next(32) + 1;
+ }
+ else
+ {
+ newSizeBytes = size - rand.Next(size);
+ if (newSizeBytes == 0)
+ {
+ goto case 0;
+ }
+ }
+
+ ref var updatedRef = ref luaAlloc.ResizeAllocation(ref reallocRef, size, newSizeBytes, out var failed);
+ ClassicAssert.IsFalse(failed);
+
+ if (newSizeBytes <= size)
+ {
+ // Shrink should always leave in place
+ ClassicAssert.IsTrue(Unsafe.AreSame(ref updatedRef, ref reallocRef));
+ }
+
+ var toCheck = MemoryMarshal.CreateReadOnlySpan(ref updatedRef, Math.Min(size, newSizeBytes));
+ ClassicAssert.IsFalse(toCheck.ContainsAnyExcept(expected));
+
+ toFree.RemoveAt(reallocIx);
+
+ var toFill = MemoryMarshal.CreateSpan(ref updatedRef, newSizeBytes);
+ toFill.Fill((byte)j);
+
+ toFree.Add(((nint)Unsafe.AsPointer(ref updatedRef), (byte)j, newSizeBytes));
+ }
+
+ break;
+
+ // Free
+ case 2:
+ {
+ if (toFree.Count == 0)
+ {
+ goto case 0;
+ }
+
+ var freeIx = rand.Next(toFree.Count);
+ var (ptr, expected, size) = toFree[freeIx];
+
+ toFree.RemoveAt(freeIx);
+ luaAlloc.Free(ref Unsafe.AsRef((void*)ptr), size);
+ }
+
+ break;
+
+ // Validate
+ case 3:
+ {
+ if (toFree.Count == 0)
+ {
+ goto case 0;
+ }
+
+ foreach (var (ptr, expected, size) in toFree)
+ {
+ var data = new Span((void*)ptr, size);
+ ClassicAssert.IsFalse(data.ContainsAnyExcept(expected));
+ }
+ }
+
+ break;
+
+ default:
+ ClassicAssert.Fail("Unexpected operation");
+ break;
+ }
+ luaAlloc.CheckCorrectness();
+ }
+ luaAlloc.CheckCorrectness();
+
+ // Validate and free everything that's left
+ foreach (var (ptr, expected, size) in toFree)
+ {
+ ref var asData = ref Unsafe.AsRef((void*)ptr);
+
+ var data = new Span((void*)ptr, size);
+ ClassicAssert.IsFalse(data.ContainsAnyExcept(expected));
+
+ luaAlloc.Free(ref asData, size);
+ }
+ luaAlloc.CheckCorrectness();
+
+ // Check that all free's didn't corrupt anything
+ DebugClassicAssertAreEqual(0, luaAlloc.AllocatedBytes);
+
+ // Full coalesce gets contiguous blocks back
+ _ = luaAlloc.TryCoalesceAllFreeBlocks();
+ ClassicAssert.AreEqual(freeSpace, luaAlloc.FirstBlockSizeBytes);
+
+ luaAlloc.CheckCorrectness();
+ }
+ }
+
+ // In DEBUG builds, assert condition is true
+ [Conditional("DEBUG")]
+ static void DebugClassicAssertIsTrue(bool condition)
+ => ClassicAssert.IsTrue(condition);
+
+ // In DEBUG builds, assert objects are equal
+ [Conditional("DEBUG")]
+ static void DebugClassicAssertAreEqual(object expected, object actual)
+ => ClassicAssert.AreEqual(expected, actual);
+ }
}
}
\ No newline at end of file
diff --git a/test/Garnet.test/LuaScriptTests.cs b/test/Garnet.test/LuaScriptTests.cs
index ace0593340..8b531401f6 100644
--- a/test/Garnet.test/LuaScriptTests.cs
+++ b/test/Garnet.test/LuaScriptTests.cs
@@ -5,22 +5,37 @@
using System.Linq;
using System.Text;
using System.Threading.Tasks;
+using Garnet.server;
using NUnit.Framework;
using NUnit.Framework.Legacy;
using StackExchange.Redis;
namespace Garnet.test
{
- [TestFixture]
+ // Limits chosen here to allow completion - if you have to bump them up, consider that you might have introduced a regression
+ [TestFixture(LuaMemoryManagementMode.Native, "")]
+ [TestFixture(LuaMemoryManagementMode.Tracked, "")]
+ [TestFixture(LuaMemoryManagementMode.Tracked, "13m")]
+ [TestFixture(LuaMemoryManagementMode.Managed, "")]
+ [TestFixture(LuaMemoryManagementMode.Managed, "15m")]
public class LuaScriptTests
{
+ private readonly LuaMemoryManagementMode allocMode;
+ private readonly string limitBytes;
+
protected GarnetServer server;
+ public LuaScriptTests(LuaMemoryManagementMode allocMode, string limitBytes)
+ {
+ this.allocMode = allocMode;
+ this.limitBytes = limitBytes;
+ }
+
[SetUp]
public void Setup()
{
TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true);
- server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, enableLua: true);
+ server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, enableLua: true, luaMemoryMode: allocMode, luaMemoryLimit: limitBytes);
server.Start();
}
@@ -303,7 +318,11 @@ public void UseEvalShaLightClient()
{
var randPostFix = rnd.Next(1, 1000);
valueKey = $"{valueKey}{randPostFix}";
+
var r = lightClientRequest.SendCommand($"EVALSHA {sha1SetScript} 1 {nameKey}{randPostFix} {valueKey}", 1);
+ // Check for error reply
+ ClassicAssert.IsTrue(r[0] != '-');
+
var g = Encoding.ASCII.GetString(lightClientRequest.SendCommand($"get {nameKey}{randPostFix}", 1));
var fstEndOfLine = g.IndexOf('\n', StringComparison.OrdinalIgnoreCase) + 1;
var strKeyValue = g.Substring(fstEndOfLine, valueKey.Length);
@@ -442,7 +461,7 @@ public void ComplexLuaTest3()
using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
var db = redis.GetDatabase(0);
- for (int i = 0; i < 10; i++)
+ for (var i = 0; i < 10; i++)
{
var response1 = (string[])db.ScriptEvaluate(script1, ["key1", "key2"]);
ClassicAssert.AreEqual(2, response1.Length);
@@ -615,8 +634,8 @@ public void NumberArgumentCoercion()
using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
var db = redis.GetDatabase();
- db.StringSet("2", "hello");
- db.StringSet("2.1", "world");
+ _ = db.StringSet("2", "hello");
+ _ = db.StringSet("2.1", "world");
var res = (string)db.ScriptEvaluate("return redis.call('GET', 2.1)");
ClassicAssert.AreEqual("world", res);
@@ -681,14 +700,15 @@ public void ComplexLuaReturns()
{
if (i != 1)
{
- tableDepth.Append(", ");
+ _ = tableDepth.Append(", ");
}
- tableDepth.Append("{ ");
- tableDepth.Append(i);
+ _ = tableDepth.Append("{ ");
+ _ = tableDepth.Append(i);
}
+
for (var i = 1; i <= Depth; i++)
{
- tableDepth.Append(" }");
+ _ = tableDepth.Append(" }");
}
var script = "return " + tableDepth.ToString();
@@ -785,5 +805,38 @@ public void MetatableReturn()
ClassicAssert.AreEqual("jkl", (string)ret[3]);
ClassicAssert.AreEqual("def", (string)ret[4]);
}
+
+ [Test]
+ public void IntentionalOOM()
+ {
+ if (string.IsNullOrEmpty(limitBytes))
+ {
+ ClassicAssert.Ignore("No memory limit enabled");
+ return;
+ }
+
+ const string ScriptOOMText = @"
+local foo = 'abcdefghijklmnopqrstuvwxyz'
+if @Ctrl == 'OOM' then
+ for i = 1, 10000000 do
+ foo = foo .. foo
+ end
+end
+
+return foo";
+ using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+ var db = redis.GetDatabase();
+
+ var scriptOOM = LuaScript.Prepare(ScriptOOMText);
+ var loadedScriptOOM = scriptOOM.Load(redis.GetServers()[0]);
+
+ // OOM actually happens and is reported
+ var exc = ClassicAssert.Throws(() => db.ScriptEvaluate(loadedScriptOOM, new { Ctrl = "OOM" }));
+ ClassicAssert.AreEqual("ERR Lua encountered an error: not enough memory", exc.Message);
+
+ // We can still run the script without issue (with non-crashing args) afterwards
+ var res = db.ScriptEvaluate(loadedScriptOOM, new { Ctrl = "Safe" });
+ ClassicAssert.AreEqual("abcdefghijklmnopqrstuvwxyz", (string)res);
+ }
}
}
\ No newline at end of file
diff --git a/test/Garnet.test/TestUtils.cs b/test/Garnet.test/TestUtils.cs
index 221843ce97..9bf375a706 100644
--- a/test/Garnet.test/TestUtils.cs
+++ b/test/Garnet.test/TestUtils.cs
@@ -217,7 +217,9 @@ public static GarnetServer CreateGarnetServer(
ILogger logger = null,
IEnumerable loadModulePaths = null,
string pubSubPageSize = null,
- bool asyncReplay = false)
+ bool asyncReplay = false,
+ LuaMemoryManagementMode luaMemoryMode = LuaMemoryManagementMode.Native,
+ string luaMemoryLimit = "")
{
if (UseAzureStorage)
IgnoreIfNotRunningAzureTests();
@@ -296,7 +298,8 @@ public static GarnetServer CreateGarnetServer(
LoadModuleCS = loadModulePaths,
EnableReadCache = enableReadCache,
EnableObjectStoreReadCache = enableObjectStoreReadCache,
- ReplicationOffsetMaxLag = asyncReplay ? -1 : 0
+ ReplicationOffsetMaxLag = asyncReplay ? -1 : 0,
+ LuaOptions = enableLua ? new LuaOptions(luaMemoryMode, luaMemoryLimit, logger) : null,
};
if (!string.IsNullOrEmpty(pubSubPageSize))
@@ -396,7 +399,9 @@ public static GarnetServer[] CreateGarnetCluster(
AadAuthenticationSettings authenticationSettings = null,
int metricsSamplingFrequency = 0,
bool enableLua = false,
- bool asyncReplay = false)
+ bool asyncReplay = false,
+ LuaMemoryManagementMode luaMemoryMode = LuaMemoryManagementMode.Native,
+ string luaMemoryLimit = "")
{
if (UseAzureStorage)
IgnoreIfNotRunningAzureTests();
@@ -438,7 +443,9 @@ public static GarnetServer[] CreateGarnetCluster(
aadAuthenticationSettings: authenticationSettings,
metricsSamplingFrequency: metricsSamplingFrequency,
enableLua: enableLua,
- asyncReplay: asyncReplay);
+ asyncReplay: asyncReplay,
+ luaMemoryMode: luaMemoryMode,
+ luaMemoryLimit: luaMemoryLimit);
ClassicAssert.IsNotNull(opts);
int iter = 0;
@@ -486,7 +493,9 @@ public static GarnetServerOptions GetGarnetServerOptions(
int metricsSamplingFrequency = 0,
bool enableLua = false,
bool asyncReplay = false,
- ILogger logger = null)
+ ILogger logger = null,
+ LuaMemoryManagementMode luaMemoryMode = LuaMemoryManagementMode.Native,
+ string luaMemoryLimit = "")
{
if (UseAzureStorage)
IgnoreIfNotRunningAzureTests();
@@ -571,7 +580,8 @@ public static GarnetServerOptions GetGarnetServerOptions(
ClusterUsername = authUsername,
ClusterPassword = authPassword,
EnableLua = enableLua,
- ReplicationOffsetMaxLag = asyncReplay ? -1 : 0
+ ReplicationOffsetMaxLag = asyncReplay ? -1 : 0,
+ LuaOptions = enableLua ? new LuaOptions(luaMemoryMode, luaMemoryLimit) : null,
};
if (lowMemory)