Skip to content

Commit

Permalink
add noescape annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
mhr3 committed May 28, 2023
1 parent b90fa65 commit 4edc66a
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 5 deletions.
24 changes: 19 additions & 5 deletions gozstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,22 @@ func compress(cctx, cctxDict *cctxWrapper, dst, src []byte, cd *CDict, compressi
return dst
}

// noescape hides a pointer from escape analysis. It is the identity function
// but escape analysis doesn't think the output depends on the input.
// noescape is inlined and currently compiles down to zero instructions.
// This is copied from go's strings.Builder. Allows us to use stack-allocated
// slices.
//go:nosplit
//go:nocheckptr
func noescape(p unsafe.Pointer) unsafe.Pointer {
x := uintptr(p)
return unsafe.Pointer(x ^ 0)
}

func compressInternal(cctx, cctxDict *cctxWrapper, dst, src []byte, cd *CDict, compressionLevel int, mustSucceed bool) C.size_t {
dstHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dst))
srcHdr := (*reflect.SliceHeader)(unsafe.Pointer(&src))
// using noescape will allow this to work with stack-allocated slices
dstHdr := (*reflect.SliceHeader)(noescape(unsafe.Pointer(&dst)))
srcHdr := (*reflect.SliceHeader)(noescape(unsafe.Pointer(&src)))

if cd != nil {
result := C.ZSTD_compress_usingCDict_wrapper(
Expand Down Expand Up @@ -180,6 +193,7 @@ func compressInternal(cctx, cctxDict *cctxWrapper, dst, src []byte, cd *CDict, c
if mustSucceed {
ensureNoError("ZSTD_compressCCtx", result)
}

return result
}

Expand Down Expand Up @@ -258,7 +272,7 @@ func decompress(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) ([]byte
}

// Slow path - resize dst to fit decompressed data.
srcHdr := (*reflect.SliceHeader)(unsafe.Pointer(&src))
srcHdr := (*reflect.SliceHeader)(noescape(unsafe.Pointer(&src)))
contentSize := C.ZSTD_getFrameContentSize_wrapper(unsafe.Pointer(srcHdr.Data), C.size_t(len(src)))
switch {
case contentSize == C.ZSTD_CONTENTSIZE_UNKNOWN || contentSize > maxFrameContentSize:
Expand Down Expand Up @@ -290,8 +304,8 @@ func decompress(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) ([]byte

func decompressInternal(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) C.size_t {
var (
dstHdr = (*reflect.SliceHeader)(unsafe.Pointer(&dst))
srcHdr = (*reflect.SliceHeader)(unsafe.Pointer(&src))
dstHdr = (*reflect.SliceHeader)(noescape(unsafe.Pointer(&dst)))
srcHdr = (*reflect.SliceHeader)(noescape(unsafe.Pointer(&src)))
n C.size_t
)
if dd != nil {
Expand Down
51 changes: 51 additions & 0 deletions gozstd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/hex"
"fmt"
"io"
"math/rand"
"runtime"
"strings"
Expand Down Expand Up @@ -54,6 +55,14 @@ func TestDecompressSmallBlockWithoutSingleSegmentFlag(t *testing.T) {
})
}

func TestCompressEmpty(t *testing.T) {
var dst [64]byte
res := Compress(dst[:0], nil)
if len(res) > 0 {
t.Fatalf("unexpected non-empty compressed frame: %X", res)
}
}

func TestDecompressTooLarge(t *testing.T) {
src := []byte{40, 181, 47, 253, 228, 122, 118, 105, 67, 140, 234, 85, 20, 159, 67}
_, err := Decompress(nil, src)
Expand All @@ -70,6 +79,48 @@ func mustUnhex(dataHex string) []byte {
return data
}

func TestCompressWithStackMove(t *testing.T) {
var srcBuf [96]byte

n, err := io.ReadFull(rand.New(rand.NewSource(time.Now().Unix())), srcBuf[:])
if err != nil {
t.Fatalf("cannot fill srcBuf with random data: %s", err)
}

// We're running this twice, because the first run will allocate
// objects in sync.Pool, calls to which extend the stack, and the second
// run can skip those allocations and extend the stack right before
// the CGO call.
// Note that this test might require some go:nosplit annotations
// to force the stack move to happen exactly before the CGO call.
for i := 0; i < 2; i++ {
ch := make(chan struct{})
go func() {
defer close(ch)

var dstBuf [1416]byte

res := Compress(dstBuf[:0], srcBuf[:n])

// make a copy of the result, so the original can remain on the stack
compressedCpy := make([]byte, len(res))
copy(compressedCpy, res)

orig, err := Decompress(nil, compressedCpy)
if err != nil {
panic(fmt.Errorf("cannot decompress: %s", err))
}
if !bytes.Equal(orig, srcBuf[:n]) {
panic(fmt.Errorf("unexpected decompressed data; got %q; want %q", orig, srcBuf[:n]))
}
}()
// wait for the goroutine to finish
<-ch
}

runtime.GC()
}

func TestCompressDecompressDistinctConcurrentDicts(t *testing.T) {
// Build multiple distinct dicts.
var cdicts []*CDict
Expand Down

0 comments on commit 4edc66a

Please sign in to comment.