diff --git a/src/lazy.rs b/src/lazy.rs index d8807a9..c163f8e 100644 --- a/src/lazy.rs +++ b/src/lazy.rs @@ -5,6 +5,12 @@ use dfdx::{ use std::any::{Any, TypeId}; +#[derive(Debug)] +pub struct MemoryMappedTensor { + tensor: Tensor, + mmap: memmap2::Mmap, +} + #[derive(Debug)] #[non_exhaustive] pub enum LazyTensor { @@ -13,24 +19,35 @@ pub enum LazyTensor { shape: S, move_to_ram: bool, }, - Cpu(Tensor), + MemoryMapped(Option>), #[cfg(feature = "cuda")] Cuda(Tensor), } impl LazyTensor { pub fn defer_load(&mut self) { - if let Self::Disk { - path: _, - shape: _, - move_to_ram, - } = self - { + if let Self::Disk { move_to_ram, .. } = self { *move_to_ram = true; } } } +impl Drop for LazyTensor { + fn drop(&mut self) { + match self { + Self::Disk { .. } => {} + Self::MemoryMapped(tensor) => { + if let Some(MemoryMappedTensor { tensor, .. }) = std::mem::take(tensor) { + // since this tensor doesn't own the vec, we need to forget it so it doesn't get dropped + std::mem::forget(tensor); + } + } + #[cfg(feature = "cuda")] + Self::Cuda(tensor) => {} + } + } +} + impl LazyTensor { pub fn num_bytes(&self) -> usize { self.shape().num_elements() * std::mem::size_of::() @@ -42,24 +59,15 @@ impl LazyTensor { fn shape(&self) -> S { match self { - Self::Disk { - path: _, - shape, - move_to_ram: _, - } => *shape, - Self::Cpu(tensor) => *tensor.shape(), + Self::Disk { shape, .. } => *shape, + Self::MemoryMapped(tensor) => *tensor.as_ref().unwrap().tensor.shape(), #[cfg(feature = "cuda")] Self::Cuda(tensor) => *tensor.shape(), } } pub fn move_to_ram + TensorFromVec + CopySlice>(&mut self, device: &D) { - if let Self::Disk { - path: _, - shape: _, - move_to_ram, - } = self - { + if let Self::Disk { move_to_ram, .. } = self { if *move_to_ram { self.get_on(device); } @@ -73,78 +81,96 @@ impl LazyTensor { let shape = self.shape(); let numel = shape.num_elements(); - match &self { + match self { Self::Disk { path, shape, move_to_ram, } => { - let mut loaded = device.zeros_like(shape); - let file = std::fs::File::open(path).unwrap(); - let mmap = unsafe { memmap2::Mmap::map(&file).unwrap() }; - let bytes: &[u8] = &mmap; - let ptr = bytes.as_ptr() as *const E; - assert!(bytes.len() < (isize::MAX as usize)); - assert_eq!(bytes.len(), numel * std::mem::size_of::()); - assert_eq!(ptr.align_offset(std::mem::align_of::()), 0); - // # Safety - // - assertion checks for byte length - // - non-null because we created from bytes slice - // - aligned due to assertion - let slice = unsafe { std::slice::from_raw_parts(ptr, numel) }; - loaded.copy_from(slice); - - if *move_to_ram { + if !*move_to_ram { + let mut loaded = device.zeros_like(shape); + let file = std::fs::File::open(path).unwrap(); + let mmap = unsafe { memmap2::Mmap::map(&file).unwrap() }; + let bytes: &[u8] = &mmap; + let ptr = bytes.as_ptr() as *const E; + assert!(bytes.len() < (isize::MAX as usize)); + assert_eq!(bytes.len(), numel * std::mem::size_of::()); + assert_eq!(ptr.align_offset(std::mem::align_of::()), 0); + // # Safety + // - assertion checks for byte length + // - non-null because we created from bytes slice + // - aligned due to assertion + let slice = unsafe { std::slice::from_raw_parts(ptr, numel) }; + loaded.copy_from(slice); + loaded + } else { if TypeId::of::() == TypeId::of::() { + let file = std::fs::File::open(path).unwrap(); + let mmap = unsafe { memmap2::Mmap::map(&file).unwrap() }; + let bytes: &[u8] = &mmap; + let ptr = bytes.as_ptr() as *mut E; + assert!(bytes.len() < (isize::MAX as usize)); + assert_eq!(bytes.len(), numel * std::mem::size_of::()); + assert_eq!(ptr.align_offset(std::mem::align_of::()), 0); + // # Safety + // TODO + let vec = unsafe { Vec::from_raw_parts(ptr, numel, numel) }; + let loaded = device.tensor_from_vec(vec, *shape); let tensor: Box = Box::new(loaded.clone()); - *self = Self::Cpu(*tensor.downcast().unwrap()); + *self = Self::MemoryMapped(Some(MemoryMappedTensor { + tensor: *tensor.downcast().unwrap(), + mmap, + })); + loaded } else { #[cfg(feature = "cuda")] if TypeId::of::() == TypeId::of::() { - let tensor: Box = Box::new(loaded.clone()); - *self = Self::Cuda(*tensor.downcast().unwrap()); + let mut loaded = device.zeros_like(shape); + let file = std::fs::File::open(path).unwrap(); + let mmap = unsafe { memmap2::Mmap::map(&file).unwrap() }; + let bytes: &[u8] = &mmap; + let ptr = bytes.as_ptr() as *const E; + assert!(bytes.len() < (isize::MAX as usize)); + assert_eq!(bytes.len(), numel * std::mem::size_of::()); + assert_eq!(ptr.align_offset(std::mem::align_of::()), 0); + // # Safety + // - assertion checks for byte length + // - non-null because we created from bytes slice + // - aligned due to assertion + let slice = unsafe { std::slice::from_raw_parts(ptr, numel) }; + loaded.copy_from(slice); + let t: Box = Box::new(tensor.as_ref().unwrap().tensor.clone()); + self = Self::Cuda(*tensor.downcast().unwrap()); + loaded } else { panic!("Unsupported device found (not Cpu/Cuda"); } #[cfg(not(feature = "cuda"))] - panic!("Unsupported device found (not Cpu/Cuda"); + panic!("Unsupported device found (not Cpu/Cuda)"); } } - loaded } - Self::Cpu(tensor) => { - if TypeId::of::() == TypeId::of::() { - // Here since we know `D` is of type `Cpu`, we can just clone the tensor. - // However we can't easily return `tensor.clone()` because of the generic - // type. - // - // One idea might be to use std::mem::transmute, however that gives us - // an error about depedendly sized types for some reason. - // - // Instead we can go through Box and downcast it, which basically - // goes through pointers to do this. - let t: Box = Box::new(tensor.clone()); - *t.downcast().unwrap() - } else { - let mut loaded = device.zeros_like(tensor.shape()); - let buf = tensor.as_vec(); - loaded.copy_from(&buf); - loaded - } + Self::MemoryMapped(tensor) => { + // Here since we know `D` is of type `Cpu`, we can just clone the tensor. + // However we can't easily return `tensor.clone()` because of the generic + // type. + // + // One idea might be to use std::mem::transmute, however that gives us + // an error about depedendly sized types for some reason. + // + // Instead we can go through Box and downcast it, which basically + // goes through pointers to do this. + assert_eq!(TypeId::of::(), TypeId::of::()); + let t: Box = Box::new(tensor.as_ref().unwrap().tensor.clone()); + *t.downcast().unwrap() } #[cfg(feature = "cuda")] Self::Cuda(tensor) => { - if TypeId::of::() == TypeId::of::() { - // See comment in corresponding Self::CPU branch. - let t: Box = Box::new(tensor.clone()); - *t.downcast().unwrap() - } else { - let mut loaded = device.zeros_like(tensor.shape()); - let buf = tensor.as_vec(); - loaded.copy_from(&buf); - loaded - } + // See comment in corresponding Self::CPU branch. + assert_eq!(TypeId::of::(), TypeId::of::()); + let t: Box = Box::new(tensor.clone()); + *t.downcast().unwrap() } } }