From 37b9d6e9b7bfb60afbce261fb4712497226661a3 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Tue, 31 Dec 2024 09:51:27 +0100 Subject: [PATCH] Fix accumulated values for backdated queries --- src/accumulator.rs | 4 +- src/accumulator/accumulated_map.rs | 39 +++++++++++--- src/function.rs | 4 +- src/function/maybe_changed_after.rs | 52 +++++++++++++++---- src/function/memo.rs | 3 ++ src/function/specify.rs | 2 + src/ingredient.rs | 24 ++++++++- src/input.rs | 6 +-- src/input/input_field.rs | 7 +-- src/interned.rs | 6 +-- src/key.rs | 5 +- src/tracked_struct.rs | 6 +-- src/tracked_struct/tracked_field.rs | 10 ++-- tests/accumulated_backdate.rs | 79 +++++++++++++++++++++++++++++ 14 files changed, 206 insertions(+), 41 deletions(-) create mode 100644 tests/accumulated_backdate.rs diff --git a/src/accumulator.rs b/src/accumulator.rs index 8cb38e2e1..929da03e6 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -12,7 +12,7 @@ use accumulated_map::AccumulatedMap; use crate::{ cycle::CycleRecoveryStrategy, - ingredient::{fmt_index, Ingredient, Jar}, + ingredient::{fmt_index, Ingredient, Jar, MaybeChangedAfter}, plumbing::JarAux, zalsa::IngredientIndex, zalsa_local::QueryOrigin, @@ -106,7 +106,7 @@ impl Ingredient for IngredientImpl { _db: &dyn Database, _input: Option, _revision: Revision, - ) -> bool { + ) -> MaybeChangedAfter { panic!("nothing should ever depend on an accumulator directly") } diff --git a/src/accumulator/accumulated_map.rs b/src/accumulator/accumulated_map.rs index 5c9e92543..648e161cb 100644 --- a/src/accumulator/accumulated_map.rs +++ b/src/accumulator/accumulated_map.rs @@ -1,3 +1,6 @@ +use std::ops::{BitOr, BitOrAssign}; + +use crossbeam::atomic::AtomicCell; use rustc_hash::FxHashMap; use crate::IngredientIndex; @@ -10,7 +13,7 @@ pub struct AccumulatedMap { /// [`InputAccumulatedValues::Empty`] if any input read during the query's execution /// has any direct or indirect accumulated values. 
- inputs: InputAccumulatedValues, + inputs: AtomicCell, } impl AccumulatedMap { @@ -22,18 +25,22 @@ impl AccumulatedMap { } /// Adds the accumulated state of an input to this accumulated map. - pub(crate) fn add_input(&mut self, input: InputAccumulatedValues) { + pub(crate) fn add_input(&self, input: InputAccumulatedValues) { if input.is_any() { - self.inputs = InputAccumulatedValues::Any; + self.inputs.store(InputAccumulatedValues::Any); } } + pub(crate) fn set_inputs(&self, input: InputAccumulatedValues) { + self.inputs.store(input); + } + /// Returns whether an input of the associated query has any accumulated values. /// /// Note: Use [`InputAccumulatedValues::from_map`] to check if the associated query itself /// or any of its inputs has accumulated values. pub(crate) fn inputs(&self) -> InputAccumulatedValues { - self.inputs + self.inputs.load() } pub fn extend_with_accumulated( @@ -60,7 +67,7 @@ impl Clone for AccumulatedMap { .iter() .map(|(&key, value)| (key, value.cloned())) .collect(), - inputs: self.inputs, + inputs: AtomicCell::new(self.inputs.load()), } } } @@ -70,7 +77,7 @@ impl Clone for AccumulatedMap { /// Knowning whether any input has accumulated values makes aggregating the accumulated values /// cheaper because we can skip over entire subtrees. #[derive(Copy, Clone, Debug, Default)] -pub(crate) enum InputAccumulatedValues { +pub enum InputAccumulatedValues { /// The query nor any of its inputs have any accumulated values. 
#[default] Empty, @@ -82,7 +89,7 @@ pub(crate) enum InputAccumulatedValues { impl InputAccumulatedValues { pub(crate) fn from_map(accumulated: &AccumulatedMap) -> Self { if accumulated.map.is_empty() { - accumulated.inputs + accumulated.inputs.load() } else { Self::Any } @@ -96,3 +103,21 @@ impl InputAccumulatedValues { matches!(self, Self::Empty) } } + +impl BitOr for InputAccumulatedValues { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { + if rhs.is_any() { + InputAccumulatedValues::Any + } else { + self + } + } +} + +impl BitOrAssign for InputAccumulatedValues { + fn bitor_assign(&mut self, rhs: Self) { + *self = *self | rhs; + } +} \ No newline at end of file diff --git a/src/function.rs b/src/function.rs index b06be486c..7b69e52dc 100644 --- a/src/function.rs +++ b/src/function.rs @@ -3,7 +3,7 @@ use std::{any::Any, fmt, sync::Arc}; use crate::{ accumulator::accumulated_map::AccumulatedMap, cycle::CycleRecoveryStrategy, - ingredient::fmt_index, + ingredient::{fmt_index, MaybeChangedAfter}, key::DatabaseKeyIndex, plumbing::JarAux, salsa_struct::SalsaStructInDb, @@ -194,7 +194,7 @@ where db: &dyn Database, input: Option, revision: Revision, - ) -> bool { + ) -> MaybeChangedAfter { let key = input.unwrap(); let db = db.as_view::(); self.maybe_changed_after(db, key, revision) diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index b1d671a36..f93502a56 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -1,4 +1,6 @@ use crate::{ + accumulator::accumulated_map::InputAccumulatedValues, + ingredient::MaybeChangedAfter, key::DatabaseKeyIndex, zalsa::{Zalsa, ZalsaDatabase}, zalsa_local::{ActiveQueryGuard, EdgeKind, QueryOrigin}, @@ -16,7 +18,7 @@ where db: &'db C::DbView, id: Id, revision: Revision, - ) -> bool { + ) -> MaybeChangedAfter { let (zalsa, zalsa_local) = db.zalsas(); zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database()); @@ -29,7 +31,11 @@ where let 
memo_guard = self.get_memo_from_table_for(zalsa, id); if let Some(memo) = &memo_guard { if self.shallow_verify_memo(db, zalsa, database_key_index, memo) { - return memo.revisions.changed_at > revision; + return if memo.revisions.changed_at > revision { + MaybeChangedAfter::Yes + } else { + MaybeChangedAfter::No(memo.revisions.accumulated.inputs()) + }; } drop(memo_guard); // release the arc-swap guard before cold path if let Some(mcs) = self.maybe_changed_after_cold(db, id, revision) { @@ -39,7 +45,7 @@ where } } else { // No memo? Assume has changed. - return true; + return MaybeChangedAfter::Yes; } } } @@ -49,7 +55,7 @@ where db: &'db C::DbView, key_index: Id, revision: Revision, - ) -> Option { + ) -> Option { let (zalsa, zalsa_local) = db.zalsas(); let database_key_index = self.database_key_index(key_index); @@ -63,7 +69,7 @@ where // Load the current memo, if any. let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index) else { - return Some(true); + return Some(MaybeChangedAfter::Yes); }; tracing::debug!( @@ -74,7 +80,11 @@ where // Check if the inputs are still valid and we can just compare `changed_at`. if self.deep_verify_memo(db, &old_memo, &active_query) { - return Some(old_memo.revisions.changed_at > revision); + return Some(if old_memo.revisions.changed_at > revision { + MaybeChangedAfter::Yes + } else { + MaybeChangedAfter::No(old_memo.revisions.accumulated.inputs()) + }); } // If inputs have changed, but we have an old value, we can re-execute. @@ -84,11 +94,18 @@ where if old_memo.value.is_some() { let memo = self.execute(db, active_query, Some(old_memo)); let changed_at = memo.revisions.changed_at; - return Some(changed_at > revision); + + return Some(if changed_at > revision { + MaybeChangedAfter::Yes + } else { + MaybeChangedAfter::No( + memo.revisions.accumulated.inputs() | InputAccumulatedValues::from_map(&memo.revisions.accumulated) + ) + }); } // Otherwise, nothing for it: have to consider the value to have changed. 
- Some(true) + Some(MaybeChangedAfter::Yes) } /// True if the memo's value and `changed_at` time is still valid in this revision. @@ -117,7 +134,12 @@ where if memo.check_durability(zalsa) { // No input of the suitable durability has changed since last verified. let db = db.as_dyn_database(); - memo.mark_as_verified(db, revision_now, database_key_index); + memo.mark_as_verified( + db, + revision_now, + database_key_index, + memo.revisions.accumulated.inputs(), + ); memo.mark_outputs_as_verified(db, database_key_index); return true; } @@ -151,6 +173,8 @@ where return true; } + let mut inputs = InputAccumulatedValues::default(); + match &old_memo.revisions.origin { QueryOrigin::Assigned(_) => { // If the value was assigneed by another query, @@ -185,10 +209,15 @@ where for &(edge_kind, dependency_index) in edges.input_outputs.iter() { match edge_kind { EdgeKind::Input => { - if dependency_index + match dependency_index .maybe_changed_after(db.as_dyn_database(), last_verified_at) { - return false; + MaybeChangedAfter::Yes => { + return false; + } + MaybeChangedAfter::No(input_accumulated) => { + inputs |= input_accumulated; + } } } EdgeKind::Output => { @@ -220,6 +249,7 @@ where db.as_dyn_database(), zalsa.current_revision(), database_key_index, + inputs, ); true } diff --git a/src/function/memo.rs b/src/function/memo.rs index 304c6c314..38d01133a 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use crossbeam::atomic::AtomicCell; +use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::zalsa_local::QueryOrigin; use crate::{ key::DatabaseKeyIndex, zalsa::Zalsa, zalsa_local::QueryRevisions, Event, EventKind, Id, @@ -143,6 +144,7 @@ impl Memo { db: &dyn crate::Database, revision_now: Revision, database_key_index: DatabaseKeyIndex, + accumulated: InputAccumulatedValues, ) { db.salsa_event(&|| Event { thread_id: std::thread::current().id(), @@ -152,6 +154,7 @@ impl Memo { }); 
self.verified_at.store(revision_now); + self.revisions.accumulated.set_inputs(accumulated); } pub(super) fn mark_outputs_as_verified( diff --git a/src/function/specify.rs b/src/function/specify.rs index 9eccad65b..5d4105169 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -1,6 +1,7 @@ use crossbeam::atomic::AtomicCell; use crate::{ + accumulator::accumulated_map::InputAccumulatedValues, tracked_struct::TrackedStructInDb, zalsa::ZalsaDatabase, zalsa_local::{QueryOrigin, QueryRevisions}, @@ -127,6 +128,7 @@ where db.as_dyn_database(), zalsa.current_revision(), database_key_index, + InputAccumulatedValues::Empty, ); } } diff --git a/src/ingredient.rs b/src/ingredient.rs index 8a46205d2..858d5009e 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -4,7 +4,7 @@ use std::{ }; use crate::{ - accumulator::accumulated_map::AccumulatedMap, + accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}, cycle::CycleRecoveryStrategy, zalsa::{IngredientIndex, MemoIngredientIndex}, zalsa_local::QueryOrigin, @@ -61,7 +61,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { db: &'db dyn Database, input: Option, revision: Revision, - ) -> bool; + ) -> MaybeChangedAfter; /// What were the inputs (if any) that were used to create the value at `key_index`. fn origin(&self, db: &dyn Database, key_index: Id) -> Option; @@ -172,3 +172,23 @@ pub(crate) fn fmt_index( write!(fmt, "{debug_name}()") } } + +#[derive(Copy, Clone, Debug)] +pub enum MaybeChangedAfter { + /// The query result hasn't changed. + /// + /// The inner value tracks whether the memo or any of its dependencies have an accumulated value. + No(InputAccumulatedValues), + + /// The query's result has changed since the last revision or the query isn't cached yet. 
+ Yes, +} + +impl From for MaybeChangedAfter { + fn from(value: bool) -> Self { + match value { + true => MaybeChangedAfter::Yes, + false => MaybeChangedAfter::No(InputAccumulatedValues::Empty), + } + } +} diff --git a/src/input.rs b/src/input.rs index eac540abf..e014648ac 100644 --- a/src/input.rs +++ b/src/input.rs @@ -15,7 +15,7 @@ use crate::{ accumulator::accumulated_map::InputAccumulatedValues, cycle::CycleRecoveryStrategy, id::{AsId, FromId}, - ingredient::{fmt_index, Ingredient}, + ingredient::{fmt_index, Ingredient, MaybeChangedAfter}, key::{DatabaseKeyIndex, DependencyIndex}, plumbing::{Jar, JarAux, Stamp}, table::{memo::MemoTable, sync::SyncTable, Slot, Table}, @@ -222,10 +222,10 @@ impl Ingredient for IngredientImpl { _db: &dyn Database, _input: Option, _revision: Revision, - ) -> bool { + ) -> MaybeChangedAfter { // Input ingredients are just a counter, they store no data, they are immortal. // Their *fields* are stored in function ingredients elsewhere. - false + MaybeChangedAfter::No(InputAccumulatedValues::Empty) } fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { diff --git a/src/input/input_field.rs b/src/input/input_field.rs index fd3082256..d52ff673c 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -1,5 +1,5 @@ use crate::cycle::CycleRecoveryStrategy; -use crate::ingredient::{fmt_index, Ingredient}; +use crate::ingredient::{fmt_index, Ingredient, MaybeChangedAfter}; use crate::input::Configuration; use crate::zalsa::IngredientIndex; use crate::zalsa_local::QueryOrigin; @@ -54,11 +54,12 @@ where db: &dyn Database, input: Option, revision: Revision, - ) -> bool { + ) -> MaybeChangedAfter { let zalsa = db.zalsa(); let input = input.unwrap(); let value = >::data(zalsa, input); - value.stamps[self.field_index].changed_at > revision + + MaybeChangedAfter::from(value.stamps[self.field_index].changed_at > revision) } fn origin(&self, _db: &dyn Database, _key_index: Id) -> Option { diff --git a/src/interned.rs 
b/src/interned.rs index 62ef4fe62..6eebaed9f 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -1,7 +1,7 @@ use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::durability::Durability; use crate::id::AsId; -use crate::ingredient::fmt_index; +use crate::ingredient::{fmt_index, MaybeChangedAfter}; use crate::key::DependencyIndex; use crate::plumbing::{Jar, JarAux}; use crate::table::memo::MemoTable; @@ -225,8 +225,8 @@ where _db: &dyn Database, _input: Option, revision: Revision, - ) -> bool { - revision < self.reset_at + ) -> MaybeChangedAfter { + MaybeChangedAfter::from(revision < self.reset_at) } fn cycle_recovery_strategy(&self) -> crate::cycle::CycleRecoveryStrategy { diff --git a/src/key.rs b/src/key.rs index 67ce3e3e8..2eb7b24e5 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,4 +1,5 @@ -use crate::{cycle::CycleRecoveryStrategy, zalsa::IngredientIndex, Database, Id}; +use crate::{ cycle::CycleRecoveryStrategy, + ingredient::MaybeChangedAfter, zalsa::IngredientIndex, Database, Id}; /// An integer that uniquely identifies a particular query instance within the /// database. Used to track dependencies between queries. 
Fully ordered and @@ -51,7 +52,7 @@ impl DependencyIndex { &self, db: &dyn Database, last_verified_at: crate::Revision, - ) -> bool { + ) -> MaybeChangedAfter { db.zalsa() .lookup_ingredient(self.ingredient_index) .maybe_changed_after(db, self.key_index, last_verified_at) diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 97d3f9680..d43b1737d 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -6,7 +6,7 @@ use tracked_field::FieldIngredientImpl; use crate::{ accumulator::accumulated_map::InputAccumulatedValues, cycle::CycleRecoveryStrategy, - ingredient::{fmt_index, Ingredient, Jar, JarAux}, + ingredient::{fmt_index, Ingredient, Jar, JarAux, MaybeChangedAfter}, key::{DatabaseKeyIndex, DependencyIndex}, plumbing::ZalsaLocal, runtime::StampedValue, @@ -587,8 +587,8 @@ where _db: &dyn Database, _input: Option, _revision: Revision, - ) -> bool { - false + ) -> MaybeChangedAfter { + MaybeChangedAfter::No(InputAccumulatedValues::Empty) } fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index ff1909397..112b9eed3 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -1,6 +1,10 @@ use std::marker::PhantomData; -use crate::{ingredient::Ingredient, zalsa::IngredientIndex, Database, Id}; +use crate::{ + ingredient::{Ingredient, MaybeChangedAfter}, + zalsa::IngredientIndex, + Database, Id, +}; use super::{Configuration, Value}; @@ -53,12 +57,12 @@ where db: &'db dyn Database, input: Option, revision: crate::Revision, - ) -> bool { + ) -> MaybeChangedAfter { let zalsa = db.zalsa(); let id = input.unwrap(); let data = >::data(zalsa.table(), id); let field_changed_at = data.revisions[self.field_index]; - field_changed_at > revision + MaybeChangedAfter::from(field_changed_at > revision) } fn origin( diff --git a/tests/accumulated_backdate.rs b/tests/accumulated_backdate.rs new file mode 100644 index 
000000000..9361d65c2
--- /dev/null
+++ b/tests/accumulated_backdate.rs
@@ -0,0 +1,79 @@
+//! Tests that accumulated values are correctly accounted for
+//! when backdating a value.
+
+mod common;
+use common::LogDatabase;
+
+use expect_test::expect;
+use salsa::{Accumulator, Setter};
+use test_log::test;
+
+#[salsa::input]
+struct File {
+    content: String,
+}
+
+#[salsa::accumulator]
+struct Log(#[allow(dead_code)] String);
+
+#[salsa::tracked]
+fn compile(db: &dyn LogDatabase, input: File) -> u32 {
+    dbg!("Compile");
+
+    parse(db, input)
+}
+
+#[salsa::tracked]
+fn parse(db: &dyn LogDatabase, input: File) -> u32 {
+    let value: Result<u32, _> = input.content(db).parse();
+
+    match dbg!(value) {
+        Ok(value) => value,
+        Err(error) => {
+            Log(error.to_string()).accumulate(db);
+            0
+        }
+    }
+}
+
+#[test]
+fn backdate() {
+    let mut db = common::LoggerDatabase::default();
+
+    let input = File::new(&db, "0".to_string());
+
+    let logs = compile::accumulated::<Log>(&db, input);
+    expect![[r#"[]"#]].assert_eq(&format!("{:#?}", logs));
+
+    input.set_content(&mut db).to("a".to_string());
+    let logs = compile::accumulated::<Log>(&db, input);
+
+    expect![[r#"
+        [
+            Log(
+                "invalid digit found in string",
+            ),
+        ]"#]]
+    .assert_eq(&format!("{:#?}", logs));
+}
+
+#[test]
+fn backdate_no_diagnostics() {
+    let mut db = common::LoggerDatabase::default();
+
+    let input = File::new(&db, "a".to_string());
+
+    let logs = compile::accumulated::<Log>(&db, input);
+    expect![[r#"
+        [
+            Log(
+                "invalid digit found in string",
+            ),
+        ]"#]]
+    .assert_eq(&format!("{:#?}", logs));
+
+    input.set_content(&mut db).to("0".to_string());
+    let logs = compile::accumulated::<Log>(&db, input);
+
+    expect![[r#"[]"#]].assert_eq(&format!("{:#?}", logs));
+}