From a75d00a21c67ad9b77798fd86af0ea801a88d4d7 Mon Sep 17 00:00:00 2001 From: jason <94618524+mellowcroc@users.noreply.github.com> Date: Tue, 10 Dec 2024 15:09:13 +0900 Subject: [PATCH] feat: parallelize generating trace using multiple threads limited by 2^13 poseidon instances due to wgpu buffer size limit --- .../core/backend/gpu/gen_trace_parallel.rs | 499 ++++++++++++++++++ .../core/backend/gpu/gen_trace_parallel.wgsl | 295 +++++++++++ crates/prover/src/core/backend/gpu/mod.rs | 1 + crates/prover/src/core/backend/simd/column.rs | 2 +- crates/prover/src/core/backend/simd/m31.rs | 2 +- crates/prover/src/examples/poseidon/mod.rs | 21 +- 6 files changed, 813 insertions(+), 7 deletions(-) create mode 100644 crates/prover/src/core/backend/gpu/gen_trace_parallel.rs create mode 100644 crates/prover/src/core/backend/gpu/gen_trace_parallel.wgsl diff --git a/crates/prover/src/core/backend/gpu/gen_trace_parallel.rs b/crates/prover/src/core/backend/gpu/gen_trace_parallel.rs new file mode 100644 index 0000000000..76010bdab8 --- /dev/null +++ b/crates/prover/src/core/backend/gpu/gen_trace_parallel.rs @@ -0,0 +1,499 @@ +use std::time::Instant; + +use bytemuck::{Pod, Zeroable}; +use wgpu::util::DeviceExt; + +const N_ROWS: u32 = 1 << 5; +const N_STATE: u32 = 16; +const N_INSTANCES_PER_ROW: u32 = 1 << N_LOG_INSTANCES_PER_ROW; +const N_LOG_INSTANCES_PER_ROW: u32 = 3; +const N_COLUMNS: u32 = N_INSTANCES_PER_ROW * N_COLUMNS_PER_REP; +const N_HALF_FULL_ROUNDS: u32 = 4; +const FULL_ROUNDS: u32 = 2 * N_HALF_FULL_ROUNDS; +const N_PARTIAL_ROUNDS: u32 = 14; +const N_LANES: u32 = 16; +const N_COLUMNS_PER_REP: u32 = N_STATE * (1 + FULL_ROUNDS) + N_PARTIAL_ROUNDS; +const LOG_N_LANES: u32 = 4; +const WORKGROUP_SIZE: u32 = 1; +const THREADS_PER_WORKGROUP: u32 = 1 << 5; + +use crate::core::backend::simd::column::BaseColumn; +#[allow(unused_imports)] +use crate::core::backend::simd::m31::PackedM31; +#[allow(unused_imports)] +use crate::core::backend::Column; +#[allow(unused_imports)] +use crate::core::fields::m31::M31; +#[allow(unused_imports)] +use crate::examples::poseidon::LookupData; + +#[derive(Debug, Clone, Copy, Pod, Zeroable)] +#[repr(C)] +struct Complex { + real: f32, + imag: f32, +} + +#[derive(Debug, Clone, Copy, Pod, Zeroable)] +#[repr(C)] +struct GenTraceInput { + log_n_rows: u32, +} + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +struct GpuLookupData { + initial_state: [[GpuBaseColumn; N_STATE as usize]; N_INSTANCES_PER_ROW as usize], + final_state: [[GpuBaseColumn; N_STATE as usize]; N_INSTANCES_PER_ROW as usize], +} + +impl From for LookupData { + fn from(value: GpuLookupData) -> Self { + LookupData { + initial_state: value.initial_state.map(|c| c.map(|c| c.into())), + final_state: value.final_state.map(|c| c.map(|c| c.into())), + } + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +struct GpuPackedM31 { + data: [u32; N_LANES as usize], +} + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct StateData { + data: [GpuPackedM31; N_STATE as usize], +} + +#[derive(Clone, Debug, Copy)] +pub struct GpuBaseColumn { + data: [GpuPackedM31; N_ROWS as usize], + length: u32, +} + +impl GpuBaseColumn { + fn zeros(length: u32) -> Self { + GpuBaseColumn { + data: [GpuPackedM31 { + data: [0; N_LANES as usize], + }; N_ROWS as usize], + length, + } + } +} + +impl From for BaseColumn { + fn from(value: GpuBaseColumn) -> Self { + BaseColumn { + data: value + .data + .iter() + .map(|f| PackedM31::from_array(f.data.map(|v| M31(v)))) + .collect(), + length: value.length as usize, + } + } +} + +#[allow(dead_code)] +const GEN_TRACE_OUTPUT_SIZE: usize = + N_COLUMNS as usize * N_LANES as usize * N_STATE as usize * N_LOG_INSTANCES_PER_ROW as usize; + +#[derive(Clone, Debug)] +#[repr(C)] +struct GenTraceOutput { + trace: [GpuBaseColumn; N_COLUMNS as usize], + lookup_data: GpuLookupData, +} + +impl Default for GenTraceOutput { + fn default() -> Self { + GenTraceOutput { + trace: std::array::from_fn(|_| GpuBaseColumn::zeros(1 << N_LANES)), + lookup_data: GpuLookupData { + initial_state: std::array::from_fn(|_| { + std::array::from_fn(|_| GpuBaseColumn::zeros(1 << N_LANES)) + }), + final_state: std::array::from_fn(|_| { + std::array::from_fn(|_| GpuBaseColumn::zeros(1 << N_LANES)) + }), + }, + } + } +} + +// impl GenTraceOutput { +// fn into_trace(data: &[u8], log_n_rows: u32) -> Vec { +// result.trace.map(|c| c.into_base_column()).to_vec() +// } +// } + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +struct ShaderResult { + values: [u32; THREADS_PER_WORKGROUP as usize * WORKGROUP_SIZE as usize], +} + +pub trait ByteSerialize: Sized { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts( + (self as *const Self) as *const u8, + std::mem::size_of::(), + ) + } + } + + fn from_bytes(bytes: &[u8]) -> &Self { + println!("from_bytes"); + println!("bytes.len(): {}", bytes.len()); + println!( + "std::mem::size_of::(): {}", + std::mem::size_of::() + ); + assert!(bytes.len() >= std::mem::size_of::()); + unsafe { &*(bytes.as_ptr() as *const Self) } + } +} + +impl ByteSerialize for GenTraceInput {} +impl ByteSerialize for BaseColumn {} +impl ByteSerialize for StateData {} +impl ByteSerialize for GenTraceOutput {} +impl ByteSerialize for ShaderResult {} +pub async fn gen_trace_parallel(log_n_rows: u32) -> (Vec, LookupData) { + let instance = wgpu::Instance::default(); + let adapter = instance + .request_adapter(&wgpu::RequestAdapterOptions { + power_preference: wgpu::PowerPreference::HighPerformance, + compatible_surface: None, + force_fallback_adapter: false, + }) + .await + .unwrap(); + + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: Some("Device"), + required_features: wgpu::Features::SHADER_INT64, + required_limits: wgpu::Limits::default(), + memory_hints: wgpu::MemoryHints::Performance, + }, + None, + ) + .await + .unwrap(); + + let input_data: GenTraceInput = GenTraceInput { log_n_rows }; + + // Create buffers + let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("Input Buffer"), + contents: bytemuck::cast_slice(&[input_data]), + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST, + }); + + let size = std::mem::size_of::(); + println!("std::mem::size_of::: {}", size); + + println!( + "std::mem::size_of::: {}", + std::mem::size_of::() + ); + println!("std::mem::size_of::: {}", std::mem::size_of::()); + println!( + "((std::mem::size_of::() * N_STATE as usize + std::mem::size_of::()) + * N_COLUMNS as usize + * (1 << (log_n_rows - LOG_N_LANES)) as usize): {}", + ((std::mem::size_of::() * N_ROWS as usize + std::mem::size_of::()) + * N_COLUMNS as usize + * (1 << (log_n_rows - LOG_N_LANES)) as usize) + ); + println!( + "std::mem::size_of::: {}", + std::mem::size_of::() + ); + // let buffer_size = (std::mem::size_of::() * N_STATE as usize + // + std::mem::size_of::()) + // * N_COLUMNS as usize + // * (1 << (log_n_rows - LOG_N_LANES)) as usize; + let buffer_size = std::mem::size_of::(); + let output_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Output Buffer"), + size: buffer_size as wgpu::BufferAddress, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + let state_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("State Buffer"), + size: (std::mem::size_of::()) as wgpu::BufferAddress, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + let shader_result = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Shader Result Buffer"), + size: (std::mem::size_of::()) as wgpu::BufferAddress, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + // Load shader + let shader_source = include_str!("gen_trace_parallel.wgsl"); + let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Gen Trace Shader"), + source: wgpu::ShaderSource::Wgsl(shader_source.into()), + }); + + // Get the maximum buffer size supported by the device + let max_buffer_size = device.limits().max_buffer_size; + println!("Maximum buffer size supported: {} bytes", max_buffer_size); + + // Check if our buffer size exceeds the limit + if buffer_size > max_buffer_size as usize { + panic!( + "Required buffer size {} exceeds device maximum of {}", + buffer_size, max_buffer_size + ); + } + + // Bind group layout + let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + entries: &[ + // Binding 0: Input buffer + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + // Binding 1: Output buffer + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + // Binding 2: Debug buffer + wgpu::BindGroupLayoutEntry { + binding: 2, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + // Binding 3: Workgroup result buffer + wgpu::BindGroupLayoutEntry { + binding: 3, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + ], + label: Some("Gen Trace Bind Group Layout"), + }); + + // Create bind group + let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: input_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: output_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 2, + resource: state_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 3, + resource: shader_result.as_entire_binding(), + }, + ], + label: Some("Gen Trace Bind Group"), + }); + + // Pipeline layout + let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + label: Some("Gen Trace Pipeline Layout"), + }); + + // Compute pipeline + let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("Gen Trace Compute Pipeline"), + layout: Some(&pipeline_layout), + module: &shader_module, + entry_point: Some("gen_trace"), + cache: None, + compilation_options: Default::default(), + }); + + // Create encoder + let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Gen Trace Command Encoder"), + }); + + // === GPU FFT Timing Start === + let gpu_start = Instant::now(); + + // Dispatch the compute shader + { + let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("Gen Trace Compute Pass"), + timestamp_writes: None, + }); + compute_pass.set_pipeline(&compute_pipeline); + compute_pass.set_bind_group(0, &bind_group, &[]); + + // Workgroup size defined in shader + compute_pass.dispatch_workgroups(WORKGROUP_SIZE, 1, 1); + } + + // Copy output to staging buffer for read access + let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Staging Buffer"), + size: buffer_size as u64, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, staging_buffer.size()); + + // create storage buffer for debug data + let state_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("State Staging Buffer"), + size: (std::mem::size_of::()) as u64, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + // create storage buffer for workgroup result + let shader_result_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Shader Result Staging Buffer"), + size: (std::mem::size_of::()) as u64, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + encoder.copy_buffer_to_buffer( + &state_buffer, + 0, + &state_staging_buffer, + 0, + state_staging_buffer.size(), + ); + + encoder.copy_buffer_to_buffer( + &shader_result, + 0, + &shader_result_staging_buffer, + 0, + shader_result_staging_buffer.size(), + ); + + // Submit the commands + queue.submit(Some(encoder.finish())); + + // let buffer_slice = state_staging_buffer.slice(..); + // let (sender, receiver) = flume::bounded(1); + // buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap()); + // device.poll(wgpu::Maintain::wait()).panic_on_timeout(); + + // if let Ok(Ok(())) = receiver.recv_async().await { + // let data = buffer_slice.get_mapped_range(); + // let _result = *StateData::from_bytes(&data); + // drop(data); + // state_staging_buffer.unmap(); + + // println!("State data: {:?}", _result); + // } + + let buffer_slice = shader_result_staging_buffer.slice(..); + let (sender, receiver) = flume::bounded(1); + buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap()); + device.poll(wgpu::Maintain::wait()).panic_on_timeout(); + + if let Ok(Ok(())) = receiver.recv_async().await { + let data = buffer_slice.get_mapped_range(); + let _result = *ShaderResult::from_bytes(&data); + drop(data); + shader_result_staging_buffer.unmap(); + + println!("Shader result: {:?}", _result); + } + + // Wait for the GPU to finish and map the staging buffer + let buffer_slice = staging_buffer.slice(..); + let (sender, receiver) = flume::bounded(1); + buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap()); + device.poll(wgpu::Maintain::wait()).panic_on_timeout(); + println!("after device.poll"); + // let mut result = Vec::::new(); + let mut result = GenTraceOutput::default(); + if let Ok(Ok(())) = receiver.recv_async().await { + println!("inside buffer_slice"); + let data = buffer_slice.get_mapped_range(); + println!("data.len(): {}", data.len()); + result = GenTraceOutput::from_bytes(&data).clone(); + // result = GenTraceOutput::into_trace(&data, log_n_rows); + drop(data); + staging_buffer.unmap(); + // println!("result: {:?}", result); + + // let trace = result.trace.clone(); + // for i in 0..trace.len() { + // println!("trace[{}].data.len(): {:?}", i, trace[i].data.len()); + // } + // for i in 0..trace.len() { + // println!("trace[{}].data[15]: {:?}", i, trace[i].data[15]); + // } + // for i in 0..trace[0].data.len() { + // println!("trace[0].data[{}]: {:?}", i, trace[0].data[i]); + // } + // for i in 0..trace[1].data.len() { + // println!("trace[1].data[{}]: {:?}", i, trace[1].data[i]); + // } + // for i in 0..trace[2].data.len() { + // println!("trace[2].data[{}]: {:?}", i, trace[2].data[i]); + // } + // for i in 158..158 * 2 { + // println!("trace[{}].data[{}]: {:?}", i, 0, trace[i].data[0]); + // } + // let trace = result.trace.clone(); + // for i in 0..trace.len() { + // println!("trace[{}].data.len(): {:?}", i, trace[i].data.len()); + // } + } + println!("Poseidon generate trace time: {:?}", gpu_start.elapsed()); + + ( + result.trace.map(|c| c.into()).to_vec(), + result.lookup_data.into(), + ) +} diff --git a/crates/prover/src/core/backend/gpu/gen_trace_parallel.wgsl b/crates/prover/src/core/backend/gpu/gen_trace_parallel.wgsl new file mode 100644 index 0000000000..448d11a1bc --- /dev/null +++ b/crates/prover/src/core/backend/gpu/gen_trace_parallel.wgsl @@ -0,0 +1,295 @@ +const MODULUS_BITS: u32 = 31u; +const P: u32 = 2147483647u; + +// Define constants +const N_ROWS: u32 = 64; +const N_STATE: u32 = 16; +const N_INSTANCES_PER_ROW: u32 = 8; +const N_COLUMNS: u32 = N_INSTANCES_PER_ROW * N_COLUMNS_PER_REP; +const N_HALF_FULL_ROUNDS: u32 = 4; +const FULL_ROUNDS: u32 = 2u * N_HALF_FULL_ROUNDS; +const N_PARTIAL_ROUNDS: u32 = 14; +const N_LANES: u32 = 16; +const N_COLUMNS_PER_REP: u32 = N_STATE * (1 + FULL_ROUNDS) + N_PARTIAL_ROUNDS; +const LOG_N_LANES: u32 = 4; +const THREADS_PER_WORKGROUP: u32 = 64; +const WORKGROUP_SIZE: u32 = 1; +const TOTAL_THREAD_SIZE: u32 = THREADS_PER_WORKGROUP * WORKGROUP_SIZE; + +// Initialize EXTERNAL_ROUND_CONSTS with explicit values +var EXTERNAL_ROUND_CONSTS: array, FULL_ROUNDS> = array, FULL_ROUNDS>( + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), +); + +// Initialize INTERNAL_ROUND_CONSTS with explicit values +var INTERNAL_ROUND_CONSTS: array = array( + 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234 +); + +struct BaseColumn { + data: array, + length: u32, +} + +struct PackedM31 { + data: array, +} + +struct GenTraceInput { + log_n_rows: u32, +} + +struct StateData { + data: array, +} + +struct LookupData { + initial_state: array, N_INSTANCES_PER_ROW>, + final_state: array, N_INSTANCES_PER_ROW>, +} + +struct GenTraceOutput { + trace: array, + lookup_data: LookupData, +} + +struct ShaderResult { + values: array, +} + +@group(0) @binding(0) +var input: GenTraceInput; + +// Output buffer +@group(0) @binding(1) +var output: GenTraceOutput; + +@group(0) @binding(2) +var state_data: StateData; + +@group(0) @binding(3) +var shader_result: ShaderResult; + +fn from_u32(value: u32) -> PackedM31 { + var packedM31 = PackedM31(); + for (var i = 0u; i < N_LANES; i++) { + packedM31.data[i] = value; + } + return packedM31; +} + +fn add(a: PackedM31, b: PackedM31) -> PackedM31 { + var packedM31 = PackedM31(); + for (var i = 0u; i < N_LANES; i++) { + packedM31.data[i] = partial_reduce(a.data[i] + b.data[i]); + } + return packedM31; +} + +fn mul(a: PackedM31, b: PackedM31) -> PackedM31 { + var packedM31 = PackedM31(); + for (var i = 0u; i < N_LANES; i++) { + var temp: u64 = u64(a.data[i]); + temp = temp * u64(b.data[i]); + packedM31.data[i] = full_reduce(temp); + } + return packedM31; +} + +// Partial reduce for values in [0, 2P) +fn partial_reduce(val: u32) -> u32 { + let reduced = val - P; + return select(val, reduced, reduced < val); +} + +fn full_reduce(val: u64) -> u32 { + let first_shift = val >> MODULUS_BITS; + let first_sum = first_shift + val + 1; + let second_shift = first_sum >> MODULUS_BITS; + let final_sum = second_shift + val; + return u32(final_sum & u64(P)); +} + +// Function to apply pow5 operation +fn pow5(x: PackedM31) -> PackedM31 { + return mul(mul(mul(x, x), mul(x, x)), x); +} + +/// Applies the external round matrix. +/// See 5.1 and Appendix B. +fn apply_external_round_matrix(state: array) -> array { + // Applies circ(2M4, M4, M4, M4). + var modified_state = state; + for (var i = 0u; i < 4u; i++) { + let partial_state = array( + state[4 * i], + state[4 * i + 1], + state[4 * i + 2], + state[4 * i + 3], + ); + let modified_partial_state = apply_m4(partial_state); + modified_state[4 * i] = modified_partial_state[0]; + modified_state[4 * i + 1] = modified_partial_state[1]; + modified_state[4 * i + 2] = modified_partial_state[2]; + modified_state[4 * i + 3] = modified_partial_state[3]; + } + for (var j = 0u; j < 4u; j++) { + let s = add(add(modified_state[j], modified_state[j + 4]), add(modified_state[j + 8], modified_state[j + 12])); + for (var i = 0u; i < 4u; i++) { + modified_state[4 * i + j] = add(modified_state[4 * i + j], s); + } + } + return modified_state; +} + +// Applies the internal round matrix. +// mu_i = 2^{i+1} + 1. +// See 5.2. +fn apply_internal_round_matrix(state: array) -> array { + var sum = state[0]; + for (var i = 1u; i < N_STATE; i++) { + sum = add(sum, state[i]); + } + + var result = array(); + for (var i = 0u; i < N_STATE; i++) { + let factor = partial_reduce(1u << (i + 1)); + result[i] = add(mul(from_u32(factor), state[i]), sum); + } + + return result; +} + +/// Applies the M4 MDS matrix described in 5.1. +fn apply_m4(x: array) -> array { + let t0 = add(x[0], x[1]); + let t02 = add(t0, t0); + let t1 = add(x[2], x[3]); + let t12 = add(t1, t1); + let t2 = add(add(x[1], x[1]), t1); + let t3 = add(add(x[3], x[3]), t0); + let t4 = add(add(t12, t12), t3); + let t5 = add(add(t02, t02), t2); + let t6 = add(t3, t5); + let t7 = add(t2, t4); + return array(t6, t5, t7, t4); +} + +@compute @workgroup_size(64) +fn gen_trace( + @builtin(workgroup_id) workgroup_id: vec3, + @builtin(local_invocation_id) local_invocation_id: vec3, + @builtin(global_invocation_id) global_invocation_id: vec3, + @builtin(local_invocation_index) local_invocation_index: u32, + @builtin(num_workgroups) num_workgroups: vec3, +) { + let workgroup_index = + workgroup_id.x + + workgroup_id.y * num_workgroups.x + + workgroup_id.z * num_workgroups.x * num_workgroups.y; + + let global_invocation_index = workgroup_index * WORKGROUP_SIZE + local_invocation_index; + + shader_result.values[global_invocation_index] = global_invocation_index; + + for (var i = 0u; i < N_COLUMNS; i++) { + output.trace[i].length = N_ROWS * N_LANES; + } + + for (var i = 0u; i < N_INSTANCES_PER_ROW; i++) { + for (var j = 0u; j < N_STATE; j++) { + output.lookup_data.initial_state[i][j].length = N_ROWS * N_LANES; + output.lookup_data.final_state[i][j].length = N_ROWS * N_LANES; + } + } + + let log_size = input.log_n_rows; + + var vec_index = global_invocation_index; + // for (var vec_index = 0u; vec_index < (1u << (log_size - LOG_N_LANES)); vec_index++) { + var col_index = 0u; + + for (var rep_i = 0u; rep_i < N_INSTANCES_PER_ROW; rep_i++) { + var state: array = initialize_state(vec_index, rep_i); + + for (var i = 0u; i < N_STATE; i++) { + output.trace[col_index].data[vec_index] = state[i]; + col_index += 1u; + } + + for (var i = 0u; i < N_STATE; i++) { + output.lookup_data.initial_state[rep_i][i].data[vec_index] = state[i]; + output.lookup_data.initial_state[rep_i][i].length = N_ROWS * N_LANES; + } + + // 4 full rounds + for (var i = 0u; i < N_HALF_FULL_ROUNDS; i++) { + for (var j = 0u; j < N_STATE; j++) { + state[j] = add(state[j], from_u32(EXTERNAL_ROUND_CONSTS[i][j])); + } + state = apply_external_round_matrix(state); + for (var j = 0u; j < N_STATE; j++) { + state[j] = pow5(state[j]); + } + for (var j = 0u; j < N_STATE; j++) { + output.trace[col_index].data[vec_index] = state[j]; + col_index += 1u; + } + } + // Partial rounds + for (var i = 0u; i < N_PARTIAL_ROUNDS; i++) { + state[0] = add(state[0], from_u32(INTERNAL_ROUND_CONSTS[i])); + state = apply_internal_round_matrix(state); + state[0] = pow5(state[0]); + output.trace[col_index].data[vec_index] = state[0]; + col_index += 1u; + } + // 4 full rounds + for (var i = 0u; i < N_HALF_FULL_ROUNDS; i++) { + for (var j = 0u; j < N_STATE; j++) { + state[j] = add(state[j], from_u32(EXTERNAL_ROUND_CONSTS[i + N_HALF_FULL_ROUNDS][j])); + } + state = apply_external_round_matrix(state); + for (var j = 0u; j < N_STATE; j++) { + state[j] = pow5(state[j]); + } + for (var j = 0u; j < N_STATE; j++) { + output.trace[col_index].data[vec_index] = state[j]; + col_index += 1u; + } + } + + for (var j = 0u; j < N_STATE; j++) { + output.lookup_data.final_state[rep_i][j].data[vec_index] = state[j]; + } + } + // } +} + +// Function to initialize the state array +fn initialize_state(vec_index: u32, rep_i: u32) -> array { + var state: array; + + for (var state_i = 0u; state_i < N_STATE; state_i++) { + // Initialize each element of the state array + var packed_value = PackedM31(); + + for (var i = 0u; i < N_LANES; i++) { + // Calculate the value based on vec_index, state_i, and rep_i + let value: u32 = vec_index * 16u + i + state_i + rep_i; + // Here, you would typically pack this value into a PackedBaseField equivalent + // For simplicity, we'll just assign it directly + packed_value.data[i] = value; // Replace with actual packing logic if needed + } + state[state_i] = packed_value; + } + + return state; +} diff --git a/crates/prover/src/core/backend/gpu/mod.rs b/crates/prover/src/core/backend/gpu/mod.rs index 7a0c69372d..0edf52c13e 100644 --- a/crates/prover/src/core/backend/gpu/mod.rs +++ b/crates/prover/src/core/backend/gpu/mod.rs @@ -1 +1,2 @@ pub mod gen_trace; +pub mod gen_trace_parallel; diff --git a/crates/prover/src/core/backend/simd/column.rs b/crates/prover/src/core/backend/simd/column.rs index dd5578c0ea..92c2dcc11a 100644 --- a/crates/prover/src/core/backend/simd/column.rs +++ b/crates/prover/src/core/backend/simd/column.rs @@ -33,7 +33,7 @@ impl FieldOps for SimdBackend { } /// An efficient structure for storing and operating on a arbitrary number of [`BaseField`] values. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct BaseColumn { pub data: Vec, /// The number of [`BaseField`]s in the vector. diff --git a/crates/prover/src/core/backend/simd/m31.rs b/crates/prover/src/core/backend/simd/m31.rs index dbeec152f7..90fe56c231 100644 --- a/crates/prover/src/core/backend/simd/m31.rs +++ b/crates/prover/src/core/backend/simd/m31.rs @@ -27,7 +27,7 @@ pub type PackedBaseField = PackedM31; /// /// Implemented with [`std::simd`] to support multiple targets (avx512, neon, wasm etc.). // TODO: Remove `pub` visibility -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, PartialEq)] #[repr(transparent)] pub struct PackedM31(Simd); diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 76184621c3..a886f44b8e 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -1,5 +1,6 @@ //! AIR for Poseidon2 hash function from . +use core::fmt::Debug; use std::ops::{Add, AddAssign, Mul, Sub}; use itertools::Itertools; @@ -195,9 +196,10 @@ pub fn eval_poseidon_constraints(eval: &mut E, lookup_elements: &P eval.finalize_logup(); } +#[derive(Debug, Clone, PartialEq)] pub struct LookupData { - initial_state: [[BaseColumn; N_STATE]; N_INSTANCES_PER_ROW], - final_state: [[BaseColumn; N_STATE]; N_INSTANCES_PER_ROW], + pub initial_state: [[BaseColumn; N_STATE]; N_INSTANCES_PER_ROW], + pub final_state: [[BaseColumn; N_STATE]; N_INSTANCES_PER_ROW], } pub fn gen_trace( log_size: u32, @@ -485,9 +487,18 @@ mod tests { #[test] fn test_gpu_poseidon_constraints() { - use crate::core::backend::gpu::gen_trace::gen_trace as gen_trace_gpu; - - pollster::block_on(gen_trace_gpu()); + // use crate::core::backend::gpu::gen_trace::gen_trace as gen_trace_gpu; + use crate::core::backend::gpu::gen_trace_parallel::gen_trace_parallel as gen_trace_gpu; + + let log_n_instances = 13; + let log_n_instances_per_row = 3; + let log_n_rows = log_n_instances - log_n_instances_per_row; + let (_gpu_trace, _gpu_lookup_data) = pollster::block_on(gen_trace_gpu(log_n_rows)); + + let (_trace, _lookup_data) = gen_trace(log_n_rows); + let _cpu_trace = _trace.into_iter().map(|c| c.values.clone()).collect_vec(); + assert_eq!(_cpu_trace, _gpu_trace); + assert_eq!(_lookup_data, _gpu_lookup_data); } #[test_log::test]