Skip to content

Commit

Permalink
feat: support running gen_trace in WASM
Browse files Browse the repository at this point in the history
  • Loading branch information
mellowcroc committed Dec 15, 2024
1 parent 7677735 commit 661d7a8
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 74 deletions.
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ wgpu = "23.0.0"
flume = "0.11.0"
pollster = "0.3"
once_cell = "1.20.2"
wasm-bindgen = "0.2.84"
wasm-bindgen-futures = "0.4.45"

[profile.bench]
codegen-units = 1
Expand Down
10 changes: 10 additions & 0 deletions crates/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ wgpu.workspace = true
flume.workspace = true
pollster.workspace = true
once_cell.workspace = true
wasm-bindgen.workspace = true
wasm-bindgen-futures.workspace = true
wasm-bindgen-test = "0.3.43"
js-sys = "0.3.66"
web-sys = { version = "0.3", features = ["console", "Performance", "Window"] }

[dev-dependencies]
aligned = "0.4.2"
Expand All @@ -49,6 +54,11 @@ default-features = false
features = ["html_reports"]
version = "0.5.1"

# NOTE(review): Cargo documents `rustflags` under `[target.<triple>]` as a `.cargo/config.toml`
# setting rather than a `Cargo.toml` manifest key — confirm this table actually takes effect here.
[target.wasm32-unknown-unknown]
rustflags = [
"-C", "link-args=-z stack-size=1500000",
]

[lib]
bench = false
crate-type = ["cdylib", "lib"]
Expand Down
207 changes: 146 additions & 61 deletions crates/prover/src/core/backend/gpu/gen_trace_interpolate_columns.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#[cfg(not(target_family = "wasm"))]
use std::time::Instant;

use bytemuck::{Pod, Zeroable};
use wgpu::util::DeviceExt;

const N_ROWS: u32 = 1 << 5;
const N_ROWS: u32 = 32;
const N_STATE: u32 = 16;
const N_INSTANCES_PER_ROW: u32 = 1 << N_LOG_INSTANCES_PER_ROW;
const N_LOG_INSTANCES_PER_ROW: u32 = 3;
Expand Down Expand Up @@ -47,15 +48,6 @@ struct GpuLookupData {
final_state: [[GpuBaseColumn; N_STATE as usize]; N_INSTANCES_PER_ROW as usize],
}

impl From<GpuLookupData> for LookupData {
fn from(value: GpuLookupData) -> Self {
LookupData {
initial_state: value.initial_state.map(|c| c.map(|c| c.into())),
final_state: value.final_state.map(|c| c.map(|c| c.into())),
}
}
}

#[derive(Debug, Clone, Copy)]
#[repr(C)]
struct GpuM31 {
Expand All @@ -75,15 +67,6 @@ struct GpuBaseColumn {
length: u32,
}

impl GpuBaseColumn {
fn zeros(length: u32) -> Self {
GpuBaseColumn {
data: [[GpuM31 { data: 0 }; N_LANES as usize]; N_ROWS as usize],
length,
}
}
}

impl From<GpuBaseColumn> for BaseColumn {
fn from(value: GpuBaseColumn) -> Self {
BaseColumn {
Expand All @@ -103,29 +86,90 @@ impl From<GpuBaseColumn> for BaseColumn {
}
}

// Product of the column, lane, state and per-row instance constants; currently unused, hence the
// `dead_code` allowance.
// NOTE(review): the last factor is the log2 value `N_LOG_INSTANCES_PER_ROW`, not
// `N_INSTANCES_PER_ROW` — confirm which one is intended before relying on this constant.
#[allow(dead_code)]
const GEN_TRACE_OUTPUT_SIZE: usize =
N_COLUMNS as usize * N_LANES as usize * N_STATE as usize * N_LOG_INSTANCES_PER_ROW as usize;

#[derive(Clone, Debug)]
#[derive(Clone, Debug, Copy)]
#[repr(C)]
struct GenTraceOutput {
trace: [GpuBaseColumn; N_COLUMNS as usize],
lookup_data: GpuLookupData,
}

impl Default for GenTraceOutput {
fn default() -> Self {
GenTraceOutput {
trace: std::array::from_fn(|_| GpuBaseColumn::zeros(1 << N_LANES)),
lookup_data: GpuLookupData {
initial_state: std::array::from_fn(|_| {
std::array::from_fn(|_| GpuBaseColumn::zeros(1 << N_LANES))
}),
final_state: std::array::from_fn(|_| {
std::array::from_fn(|_| GpuBaseColumn::zeros(1 << N_LANES))
}),
},
/// CPU-side result of the trace generation shader, decoded from the raw bytes of the mapped
/// staging buffer by `GenTraceOutputVec::from_bytes`.
#[derive(Clone, Debug)]
#[repr(C)]
struct GenTraceOutputVec {
/// Decoded trace columns; `from_bytes` produces at most `N_COLUMNS` of them.
trace: Vec<BaseColumn>,
/// Lookup data decoded from the bytes that follow the trace columns.
lookup_data: LookupData,
}

impl GenTraceOutputVec {
    /// Decodes the GPU output layout — `N_COLUMNS` consecutive `GpuBaseColumn` records followed
    /// by one `GpuLookupData` record — into owned CPU-side values.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` is shorter than the combined size of those records.
    fn from_bytes(bytes: &[u8]) -> Self {
        let column_bytes = std::mem::size_of::<GpuBaseColumn>();
        let lookup_bytes = std::mem::size_of::<GpuLookupData>();
        assert!(bytes.len() >= column_bytes * N_COLUMNS as usize + lookup_bytes);

        // The trace columns sit at the front of the buffer, one fixed-size record each.
        let mut trace = Vec::with_capacity(N_COLUMNS as usize);
        for raw_column in bytes.chunks(column_bytes).take(N_COLUMNS as usize) {
            trace.push(BaseColumn::from_bytes(raw_column));
        }

        // The lookup data record starts right after the last trace column.
        let lookup_offset = column_bytes * N_COLUMNS as usize;
        let lookup_region = &bytes[lookup_offset..lookup_offset + lookup_bytes];
        let lookup_data = LookupData::from_bytes(lookup_region);

        Self { trace, lookup_data }
    }
}

impl BaseColumn {
    /// Decodes one `GpuBaseColumn` record from the front of `bytes` and converts it into a
    /// `BaseColumn` through the existing `From<GpuBaseColumn>` conversion.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` holds fewer than `size_of::<GpuBaseColumn>()` bytes.
    fn from_bytes(bytes: &[u8]) -> Self {
        // The record read below is a `GpuBaseColumn`, so its size (not `size_of::<Self>()`) is
        // what must be available in `bytes`.
        assert!(bytes.len() >= std::mem::size_of::<GpuBaseColumn>());
        // SAFETY: the assert above guarantees that `bytes` covers a full `GpuBaseColumn`, and
        // `read_unaligned` places no alignment requirement on the source pointer, which a byte
        // slice cannot promise. This assumes every bit pattern is a valid `GpuBaseColumn`
        // (it appears to hold only integer data) — TODO confirm against the struct definition.
        let gpu_column =
            unsafe { std::ptr::read_unaligned(bytes.as_ptr().cast::<GpuBaseColumn>()) };
        gpu_column.into()
    }
}

impl LookupData {
    /// Decodes the initial and final state tables from `bytes`.
    ///
    /// The layout read here is two back-to-back regions (initial state, then final state), each
    /// holding `N_INSTANCES_PER_ROW` groups of `N_STATE` `GpuBaseColumn` records.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` is shorter than both regions combined.
    fn from_bytes(bytes: &[u8]) -> Self {
        type StateColumns = [[BaseColumn; N_STATE as usize]; N_INSTANCES_PER_ROW as usize];

        let column_bytes = std::mem::size_of::<GpuBaseColumn>();
        let instance_bytes = column_bytes * N_STATE as usize;
        let region_bytes = instance_bytes * N_INSTANCES_PER_ROW as usize;
        assert!(bytes.len() >= region_bytes * 2);

        // Both regions share one layout, so a single decoder handles either of them.
        let decode_region = |region: &[u8]| -> StateColumns {
            let mut instances = Vec::with_capacity(N_INSTANCES_PER_ROW as usize);
            for raw_instance in region
                .chunks(instance_bytes)
                .take(N_INSTANCES_PER_ROW as usize)
            {
                let columns: Vec<BaseColumn> = raw_instance
                    .chunks(column_bytes)
                    .take(N_STATE as usize)
                    .map(|raw_column| BaseColumn::from_bytes(raw_column))
                    .collect();
                let columns: [BaseColumn; N_STATE as usize] = columns.try_into().unwrap();
                instances.push(columns);
            }
            instances.try_into().unwrap()
        };

        let initial_state = decode_region(bytes);
        let final_state = decode_region(&bytes[region_bytes..]);

        Self {
            initial_state,
            final_state,
        }
    }
}
Expand Down Expand Up @@ -177,7 +221,20 @@ impl ByteSerialize for BaseColumn {}
impl ByteSerialize for GpuStateData {}
impl ByteSerialize for GenTraceOutput {}
impl ByteSerialize for ShaderResult {}
pub async fn gen_trace_interpolate_columns(log_n_rows: u32) -> (Vec<BaseColumn>, LookupData) {

/// Bundle of wgpu objects built by `init` and handed to `gen_trace_interpolate_columns`, which
/// submits the recorded commands and reads the staging buffers back.
// Not every field is read by the visible caller (e.g. `instance` and `adapter`), which is
// presumably why `dead_code` is allowed here — TODO confirm.
#[allow(dead_code)]
struct WgpuInstance {
instance: wgpu::Instance,
adapter: wgpu::Adapter,
device: wgpu::Device,
queue: wgpu::Queue,
// Mapped for reading and decoded into the trace and lookup data.
staging_buffer: wgpu::Buffer,
// Only referenced from commented-out debugging code in the visible caller.
state_staging_buffer: wgpu::Buffer,
// Mapped for reading and decoded as a `ShaderResult`.
shader_result_staging_buffer: wgpu::Buffer,
// Finished and submitted to `queue` by the caller.
encoder: wgpu::CommandEncoder,
}

async fn init(log_n_rows: u32) -> WgpuInstance {
let instance = wgpu::Instance::default();
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
Expand Down Expand Up @@ -352,9 +409,6 @@ pub async fn gen_trace_interpolate_columns(log_n_rows: u32) -> (Vec<BaseColumn>,
label: Some("Gen Trace Command Encoder"),
});

// === GPU FFT Timing Start ===
let gpu_start = Instant::now();

// Dispatch the compute shader
{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
Expand Down Expand Up @@ -410,8 +464,28 @@ pub async fn gen_trace_interpolate_columns(log_n_rows: u32) -> (Vec<BaseColumn>,
shader_result_staging_buffer.size(),
);

WgpuInstance {
instance,
adapter,
device,
queue,
staging_buffer,
state_staging_buffer,
shader_result_staging_buffer,
encoder,
}
}

pub async fn gen_trace_interpolate_columns(log_n_rows: u32) -> (Vec<BaseColumn>, LookupData) {
let instance = init(log_n_rows).await;

#[cfg(not(target_family = "wasm"))]
let gpu_start = Instant::now();
#[cfg(target_family = "wasm")]
let gpu_start = web_sys::window().unwrap().performance().unwrap().now();

// Submit the commands
queue.submit(Some(encoder.finish()));
instance.queue.submit(Some(instance.encoder.finish()));

// let buffer_slice = state_staging_buffer.slice(..);
// let (sender, receiver) = flume::bounded(1);
Expand All @@ -427,44 +501,55 @@ pub async fn gen_trace_interpolate_columns(log_n_rows: u32) -> (Vec<BaseColumn>,
// println!("State data: {:?}", _result);
// }

let buffer_slice = shader_result_staging_buffer.slice(..);
let buffer_slice = instance.shader_result_staging_buffer.slice(..);
let (sender, receiver) = flume::bounded(1);
buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
device.poll(wgpu::Maintain::wait()).panic_on_timeout();
instance
.device
.poll(wgpu::Maintain::wait())
.panic_on_timeout();

if let Ok(Ok(())) = receiver.recv_async().await {
let data = buffer_slice.get_mapped_range();
let _result = *ShaderResult::from_bytes(&data);
drop(data);
shader_result_staging_buffer.unmap();
instance.shader_result_staging_buffer.unmap();

// _result.values.iter().enumerate().for_each(|(i, v)| {
// println!("Shader result[{}]: {:?}", i, v);
// });
}

// Wait for the GPU to finish and map the staging buffer
let buffer_slice = staging_buffer.slice(..);
let buffer_slice = instance.staging_buffer.slice(..);
let (sender, receiver) = flume::bounded(1);
buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
device.poll(wgpu::Maintain::wait()).panic_on_timeout();
let mut result = GenTraceOutput::default();
if let Ok(Ok(())) = receiver.recv_async().await {
instance
.device
.poll(wgpu::Maintain::wait())
.panic_on_timeout();
let result = async {
receiver.recv_async().await.unwrap().unwrap();
let data = buffer_slice.get_mapped_range();
result = GenTraceOutput::from_bytes(&data).clone();

let output = GenTraceOutputVec::from_bytes(&data);
drop(data);
staging_buffer.unmap();
instance.staging_buffer.unmap();

// let trace: Vec<BaseColumn> = result.trace.clone().map(|c| c.into()).to_vec();
// for (i, c) in trace.iter().enumerate() {
// println!("GPU Trace[{}].data[0]: {:?}", i, c.data[0]);
// }
// let lookup_data = result.lookup_data.into();
}
println!("Poseidon generate trace time: {:?}", gpu_start.elapsed());
let output_trace: Vec<BaseColumn> =
output.trace.clone().into_iter().map(|c| c.into()).collect();
(output_trace, output.lookup_data.into())
};

let (trace, lookup_data) = result.await;

#[cfg(not(target_family = "wasm"))]
println!("GPU time: {:?}", gpu_start.elapsed());

#[cfg(target_family = "wasm")]
let gpu_end = web_sys::window().unwrap().performance().unwrap().now();
#[cfg(target_family = "wasm")]
web_sys::console::log_1(&format!("GPU time: {:?}", gpu_end - gpu_start).into());

(
result.trace.map(|c| c.into()).to_vec(),
result.lookup_data.into(),
)
(trace, lookup_data)
}
Loading

0 comments on commit 661d7a8

Please sign in to comment.