Skip to content

Commit

Permalink
Linux implementation of hat::memory_protector
Browse files Browse the repository at this point in the history
  • Loading branch information
ZeroMemes committed Oct 7, 2024
1 parent 82f4466 commit a149d40
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ set(LIBHAT_SRC
src/Scanner.cpp
src/System.cpp

src/os/linux/MemoryProtector.cpp

src/os/unix/System.cpp

src/os/win32/MemoryProtector.cpp
Expand Down
15 changes: 15 additions & 0 deletions src/Utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once

#include <bit>
#include <cstdint>

namespace hat::detail {

constexpr uintptr_t fast_align_down(uintptr_t address, size_t alignment) {
return address & ~static_cast<uintptr_t>(alignment - 1);
}

constexpr uintptr_t fast_align_up(uintptr_t address, size_t alignment) {
return (address + alignment - 1) & ~static_cast<uintptr_t>(alignment - 1);
}
}
84 changes: 84 additions & 0 deletions src/os/linux/MemoryProtector.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#include <libhat/Defines.hpp>
#ifdef LIBHAT_LINUX

#include <charconv>
#include <fstream>
#include <optional>
#include <string>

#include <libhat/MemoryProtector.hpp>
#include <libhat/System.hpp>
#include "../../Utils.hpp"

#include <sys/mman.h>

namespace hat {

static int to_system_prot(const protection flags) {
int prot = 0;
if (static_cast<bool>(flags & protection::Read)) prot |= PROT_READ;
if (static_cast<bool>(flags & protection::Write)) prot |= PROT_WRITE;
if (static_cast<bool>(flags & protection::Execute)) prot |= PROT_EXEC;
return prot;
}

static std::optional<int> get_page_prot(const uintptr_t address) {
std::ifstream f("/proc/self/maps");
std::string s;
while (std::getline(f, s)) {
const char* it = s.data();
const char* end = s.data() + s.size();
std::from_chars_result res{};

uintptr_t pageBegin;
res = std::from_chars(it, end, pageBegin, 16);
if (res.ec != std::errc{} || res.ptr == end) {
continue;
}
it = res.ptr + 1; // +1 to skip the hyphen

uintptr_t pageEnd;
res = std::from_chars(it, end, pageEnd, 16);
if (res.ec != std::errc{} || res.ptr == end) {
continue;
}
it = res.ptr + 1; // +1 to skip the space

std::string_view remaining{it, end};
if (address >= pageBegin && address < pageEnd && remaining.size() >= 3) {
int prot = 0;
if (remaining[0] == 'r') prot |= PROT_READ;
if (remaining[1] == 'w') prot |= PROT_WRITE;
if (remaining[2] == 'x') prot |= PROT_EXECUTE;
return prot;
}
}
return std::nullopt;
}

memory_protector::memory_protector(const uintptr_t address, const size_t size, const protection flags) : address(address), size(size) {
const auto pageSize = hat::get_system().page_size;

const auto oldProt = get_page_prot(address);
if (!oldProt) {
return; // Failure indicated via is_set()
}

this->oldProtection = static_cast<uint32_t>(*oldProt);
this->set = 0 == mprotect(
reinterpret_cast<void*>(detail::fast_align_down(address, pageSize)),
static_cast<size_t>(detail::fast_align_up(size, pageSize)),
to_system_prot(flags)
);
}

void memory_protector::restore() {
const auto pageSize = hat::get_system().page_size;
mprotect(
reinterpret_cast<void*>(detail::fast_align_down(address, pageSize)),
static_cast<size_t>(detail::fast_align_up(size, pageSize)),
this->oldProtection
);
}
}
#endif
10 changes: 8 additions & 2 deletions src/os/win32/MemoryProtector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
#include <Windows.h>

namespace hat {
static DWORD ToWinProt(const protection flags) {

static DWORD to_system_prot(const protection flags) {
const bool r = static_cast<bool>(flags & protection::Read);
const bool w = static_cast<bool>(flags & protection::Write);
const bool x = static_cast<bool>(flags & protection::Execute);
Expand All @@ -20,7 +21,12 @@ namespace hat {
}

memory_protector::memory_protector(const uintptr_t address, const size_t size, const protection flags) : address(address), size(size) {
this->set = 0 != VirtualProtect(reinterpret_cast<LPVOID>(this->address), this->size, ToWinProt(flags), reinterpret_cast<PDWORD>(&this->oldProtection));
this->set = 0 != VirtualProtect(
reinterpret_cast<LPVOID>(this->address),
this->size,
to_system_prot(flags),
reinterpret_cast<PDWORD>(&this->oldProtection)
);
}

void memory_protector::restore() {
Expand Down

0 comments on commit a149d40

Please sign in to comment.