Skip to content

Commit

Permalink
Impl stats lin sol write json
Browse files Browse the repository at this point in the history
  • Loading branch information
cpmech committed Oct 21, 2023
1 parent 46f76a6 commit 8fe50b5
Showing 1 changed file with 62 additions and 10 deletions.
72 changes: 62 additions & 10 deletions russell_sparse/src/stats_lin_sol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use russell_lab::{format_nanoseconds, get_num_threads, using_intel_mkl};
use serde::{Deserialize, Serialize};
use serde_json;
use std::ffi::OsStr;
use std::fs::File;
use std::fs::{self, File};
use std::io::BufReader;
use std::path::Path;

Expand Down Expand Up @@ -179,15 +179,7 @@ impl StatsLinSol {

/// Gets a JSON representation of the stats structure
pub fn get_json(&mut self) -> String {
self.output.openmp_num_threads = get_num_threads();
self.time_nanoseconds.total_ifs =
self.time_nanoseconds.initialize + self.time_nanoseconds.factorize + self.time_nanoseconds.solve;
self.time_human.read_matrix = format_nanoseconds(self.time_nanoseconds.read_matrix);
self.time_human.initialize = format_nanoseconds(self.time_nanoseconds.initialize);
self.time_human.factorize = format_nanoseconds(self.time_nanoseconds.factorize);
self.time_human.solve = format_nanoseconds(self.time_nanoseconds.solve);
self.time_human.total_ifs = format_nanoseconds(self.time_nanoseconds.total_ifs);
self.time_human.verify = format_nanoseconds(self.time_nanoseconds.verify);
self.compute_derived_values();
serde_json::to_string_pretty(&self).unwrap()
}

Expand All @@ -206,6 +198,38 @@ impl StatsLinSol {
let stat = serde_json::from_reader(buffered).map_err(|_| "cannot parse JSON file")?;
Ok(stat)
}

/// Writes a JSON file with the results
///
/// # Input
///
/// * `full_path` -- may be a String, &str, or Path
pub fn write_json<P>(&mut self, full_path: &P) -> Result<(), StrError>
where
P: AsRef<OsStr> + ?Sized,
{
self.compute_derived_values();
let path = Path::new(full_path).to_path_buf();
if let Some(p) = path.parent() {
fs::create_dir_all(p).map_err(|_| "cannot create directory")?;
}
let mut file = File::create(&path).map_err(|_| "cannot create file")?;
serde_json::to_writer_pretty(&mut file, &self).map_err(|_| "cannot write file")?;
Ok(())
}

/// Computes derived values
fn compute_derived_values(&mut self) {
self.output.openmp_num_threads = get_num_threads();
self.time_nanoseconds.total_ifs =
self.time_nanoseconds.initialize + self.time_nanoseconds.factorize + self.time_nanoseconds.solve;
self.time_human.read_matrix = format_nanoseconds(self.time_nanoseconds.read_matrix);
self.time_human.initialize = format_nanoseconds(self.time_nanoseconds.initialize);
self.time_human.factorize = format_nanoseconds(self.time_nanoseconds.factorize);
self.time_human.solve = format_nanoseconds(self.time_nanoseconds.solve);
self.time_human.total_ifs = format_nanoseconds(self.time_nanoseconds.total_ifs);
self.time_human.verify = format_nanoseconds(self.time_nanoseconds.verify);
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -275,4 +299,32 @@ mod tests {
assert_eq!(stats.matrix.name, "pre2");
assert_eq!(stats.matrix.symmetry, "None");
}

#[test]
fn write_json_works() {
let mut stats = StatsLinSol::new();
const ONE_SECOND: u128 = 1000000000;
stats.time_nanoseconds.read_matrix = ONE_SECOND;
stats.time_nanoseconds.initialize = ONE_SECOND;
stats.time_nanoseconds.factorize = ONE_SECOND * 2;
stats.time_nanoseconds.solve = ONE_SECOND * 3;
stats.time_nanoseconds.verify = ONE_SECOND * 4;
let path = "/tmp/russell/write_json_works.json";
stats.write_json(path).unwrap();
let res = StatsLinSol::read_json(path).unwrap();
assert!(res.output.openmp_num_threads > 0);
assert_eq!(res.time_nanoseconds.read_matrix, ONE_SECOND);
assert_eq!(res.time_nanoseconds.initialize, ONE_SECOND);
assert_eq!(res.time_nanoseconds.factorize, ONE_SECOND * 2);
assert_eq!(res.time_nanoseconds.solve, ONE_SECOND * 3);
assert_eq!(res.time_nanoseconds.total_ifs, ONE_SECOND * 6);
assert_eq!(res.time_nanoseconds.verify, ONE_SECOND * 4);
assert_eq!(res.time_nanoseconds.total_ifs, ONE_SECOND * 6);
assert_eq!(res.time_human.read_matrix, "1s");
assert_eq!(res.time_human.initialize, "1s");
assert_eq!(res.time_human.factorize, "2s");
assert_eq!(res.time_human.solve, "3s");
assert_eq!(res.time_human.total_ifs, "6s");
assert_eq!(res.time_human.verify, "4s");
}
}

0 comments on commit 8fe50b5

Please sign in to comment.