Skip to content

Commit

Permalink
Update block flattener to correctly flatten concurrent blocks
Browse files Browse the repository at this point in the history
BlockStmts now have a flag that indicates whether blocks are procedural or concurrent. When two blocks are flattened inside a concurrent block, we create extra variables for any read variables that are written to, to avoid variables being read by one block after they are written to by another.
  • Loading branch information
polgreen committed May 9, 2024
1 parent 4b4b2a4 commit b0c07a9
Show file tree
Hide file tree
Showing 26 changed files with 260 additions and 95 deletions.
4 changes: 2 additions & 2 deletions src/main/scala/uclid/SymbolicSimulator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1868,7 +1868,7 @@ class SymbolicSimulator (module : Module) {
case AssignStmt(lhss,rhss) =>
val es = rhss.map(i => evaluate(i, symbolTable, frameTable, frameNumber, scope));
return simulateAssign(lhss, es, symbolTable, label)
case BlockStmt(vars, stmts) =>
case BlockStmt(vars, stmts, _ ) =>
val declaredVars = vars.flatMap(vs => vs.ids.map(v => (v, vs.typ)))
val initSymbolTable = symbolTable
val localSymbolTable = declaredVars.foldLeft(initSymbolTable) {
Expand Down Expand Up @@ -1929,7 +1929,7 @@ class SymbolicSimulator (module : Module) {
}
case AssignStmt(lhss,rhss) =>
return lhss.map(lhs => lhs.ident).toSet
case BlockStmt(vars, stmts) =>
case BlockStmt(vars, stmts, _) =>
val declaredVars : Set[Identifier] = vars.flatMap(vs => vs.ids.map(id => id)).toSet
return writeSets(stmts) -- declaredVars
case IfElseStmt(e,then_branch,else_branch) =>
Expand Down
5 changes: 4 additions & 1 deletion src/main/scala/uclid/UclidMain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ object UclidMain {
val newInitDecl = initAccDecls match {
case Some(initAcc) => initModuleDecls match {
case Some(initMod) => List(InitDecl(BlockStmt(List[BlockVarsDecl](),
List(initAcc.asInstanceOf[InitDecl].body, initMod.asInstanceOf[InitDecl].body))))
List(initAcc.asInstanceOf[InitDecl].body, initMod.asInstanceOf[InitDecl].body), true)))
case None => List(initAcc)
}
case None => initModuleDecls match {
Expand Down Expand Up @@ -468,6 +468,9 @@ object UclidMain {
passManager.addPass(new ModuleTypeChecker())
// optimisation, has previously been called
passManager.addPass(new SemanticAnalyzer())
// reorder statements if necessary.
// Pass MUST be run after variable renamers
//passManager.addPass(new BlockSorter())
// known bugs in the following passes
if (config.enumToNumeric) passManager.addPass(new EnumTypeAnalysis())
if (config.enumToNumeric) passManager.addPass(new EnumTypeRenamer("BV"))
Expand Down
42 changes: 42 additions & 0 deletions src/main/scala/uclid/lang/ASTVisitorUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,48 @@ class ExprRewriter(name: String, rewrites : Map[Expr, Expr])
}
}

// used to rewrite only the read expressions
class ReadSetExprRewriter(name: String, rewrites : Map[Expr, Expr])
extends ASTRewriter(name, new ExprRewriterPass(rewrites))
{
def rewriteExpr(e : Expr, context : Scope) : Expr = {
e match {
case OperatorApplication(OldOperator(), _) => e
case OperatorApplication(HistoryOperator(), _) => e
case _ => visitExpr(e, context).get
}
}

def rewriteStatements(stmts : List[Statement], context : Scope) : List[Statement] = {
return stmts.flatMap(visitStatement(_, context))
}

def rewriteStatement(stmt : Statement, context : Scope) : Option[Statement] = {
visitStatement(stmt, context)
}

// do nothing for the LHS
override def visitLhs(lhs: Lhs, context: Scope): Option[Lhs] = Some(lhs)

override def visitOperatorApp(opapp : OperatorApplication, context : Scope) : Option[Expr] = {

opapp match {
case OperatorApplication(HistoryOperator(), _) => {
Some(opapp)
}
case OperatorApplication(OldOperator(), _) => Some(opapp)
case _ => {
val opAppP = visitOperator(opapp.op, context).flatMap((op) => {
pass.rewriteOperatorApp(OperatorApplication(op, opapp.operands.map(visitExpr(_, context + opapp)).flatten), context)
})
return ASTNode.introducePos(true, true, opAppP, opapp.position) }
}
}

}



// This class has been modified to handle the abstract class: ModifiableEntity.
class OldExprRewriterPass(rewrites : Map[ModifiableEntity, Identifier]) extends RewritePass
{
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/uclid/lang/ASTVistors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1991,7 +1991,7 @@ class ASTRewriter (_passName : String, _pass: RewritePass, setFilename : Boolean
log.debug("visitBlockStatement\n{}", Utils.join(blkStmt.toLines, "\n"))
val contextP = context + blkStmt.vars
val varsP = blkStmt.vars.map(v => visitBlockVars(v, contextP)).flatten
val blkStmtP1 = BlockStmt(varsP, blkStmt.stmts.flatMap(st => visitStatement(st, contextP)))
val blkStmtP1 = BlockStmt(varsP, blkStmt.stmts.flatMap(st => visitStatement(st, contextP)), blkStmt.isProcedural)
val blkStmtP = pass.rewriteBlock(blkStmtP1, context)
return ASTNode.introducePos(setPosition, setFilename, blkStmtP, blkStmt.position)
}
Expand Down
66 changes: 63 additions & 3 deletions src/main/scala/uclid/lang/BlockFlattener.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ package uclid
package lang

import com.typesafe.scalalogging.Logger
import java.util.jar.Attributes.Name


class BlockVariableRenamerPass extends RewritePass {
def renameVarList (vars : List[(Identifier, Type)], context : Scope) : List[(Identifier, Identifier, Type)] = {
Expand All @@ -63,7 +65,7 @@ class BlockVariableRenamerPass extends RewritePass {
val rewriter = new ExprRewriter("BlockVariableRenamerPass:Block", rewriteMap)
val statementsP = rewriter.rewriteStatements(blkStmt.stmts, context + blkStmt.vars)
val varsP = varTuples.map(p => BlockVarsDecl(List(p._2), p._3))
Some(BlockStmt(varsP, statementsP))
Some(BlockStmt(varsP, statementsP, blkStmt.isProcedural))
}
override def rewriteProcedure(proc : ProcedureDecl, context : Scope) : Option[ProcedureDecl] = {
val argTuples = renameVarList(proc.sig.inParams, context)
Expand Down Expand Up @@ -138,10 +140,65 @@ class BlockFlattenerPass extends RewritePass {
val stmtsP = rewriter.rewriteStatements(blk.stmts, context + blk.vars)
(stmtsP, varDecls)
}

def addConcurrentVars (blkStmt : BlockStmt, context: Scope) : BlockStmt = {
val filteredStmts = blkStmt.stmts.filter(_.isInstanceOf[BlockStmt])

if(filteredStmts.size != blkStmt.stmts.size)
logger.debug("BlockFlattener: block contains blk statements and other statements")

val nonSequentialBlockCount = filteredStmts.count(_.asInstanceOf[BlockStmt].isProcedural == false)
logger.debug("Number of blocks: " + filteredStmts.size.toString())

if(!blkStmt.isProcedural && filteredStmts.size >1)
{
val reads = filteredStmts.foldLeft(Set.empty[Identifier]) {
(acc, blk) => {
val readSet = StatementScheduler.readSets(blk.asInstanceOf[BlockStmt].stmts, context)
acc ++ readSet
}
}.filter(id => context.map.contains(id) && context.map(id).isInstanceOf[Scope.StateVar] && !id.name.startsWith("__ucld"))

val writes = filteredStmts.foldLeft(Set.empty[Identifier]) {
(acc, blk) => {
val writeSet = StatementScheduler.writeSets(blk.asInstanceOf[BlockStmt].stmts, context)
acc ++ writeSet
}
}.filter(id => context.map.contains(id) && context.map(id).isInstanceOf[Scope.StateVar])

// create new vars. We only need new variables for the reads that are also written to
// because there should only be
// one write to a variable in a concurrent block. Blocks with more than one write will have been
// caught earlier
val varPairs: Map[Expr, Expr] =
reads.intersect(writes).map(
id => (id.asInstanceOf[Expr] -> NameProvider.get("block_" + id.toString()).asInstanceOf[Expr])).toMap
logger.debug("New vars: " + varPairs.toString())

val rewriter = new ReadSetExprRewriter("BlockFlattener:Rewrite", varPairs)
val stmtsP = rewriter.rewriteStatements(blkStmt.stmts, context + blkStmt.vars)

// create variable declarations for the new read variables.
val vars = varPairs.map(p => BlockVarsDecl(List(p._2.asInstanceOf[Identifier]), context.map(p._1.asInstanceOf[Identifier]).asInstanceOf[Scope.StateVar].typ))
// create assign statements for the new variables
val readVarAssigns = varPairs.map(p => AssignStmt(List(LhsId(p._2.asInstanceOf[Identifier])), List(p._1.asInstanceOf[Expr]))).toList

// new block statement
val blkStmtP = BlockStmt(blkStmt.vars ++ vars, readVarAssigns ++ stmtsP, blkStmt.isProcedural)
logger.debug("New block statement:\n" + blkStmtP.toString())
blkStmtP
}
else{
blkStmt
}
}

override def rewriteBlock(blkStmt : BlockStmt, context : Scope) : Option[Statement] = {
logger.debug("==> [%s] Input:\n%s".format(analysis.passName, blkStmt.toString()))
val init = (List.empty[Statement], Map.empty[Identifier, Type])
val (stmtsP, mapOut) = blkStmt.stmts.foldLeft(init) {

val blkStmtP = addConcurrentVars(blkStmt, context)
val (stmtsP, mapOut) = blkStmtP.stmts.foldLeft(init) {
(acc, st) => {
val (stP, mapOut) = st match {
case blk : BlockStmt => renameBlock(blk, context, acc._2)
Expand All @@ -150,8 +207,9 @@ class BlockFlattenerPass extends RewritePass {
(acc._1 ++ stP, mapOut)
}
}

val vars = mapOut.map(p => BlockVarsDecl(List(p._1), p._2))
val result = BlockStmt(blkStmt.vars ++ vars, stmtsP)
val result = BlockStmt(blkStmtP.vars ++ vars, stmtsP, blkStmt.isProcedural)
logger.debug("<== Result:\n" + result.toString())
Some(result)
}
Expand All @@ -170,6 +228,8 @@ class BlockFlattener() extends ASTRewriter(BlockFlattener.getName(), new BlockFl
override val repeatUntilNoChange = true
}



object Optimizer {
var index = 0
def getName() : String = {
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/uclid/lang/LoopUnroller.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class ForLoopRewriterPass(forStmtsToRewrite: Set[ForStmt]) extends RewritePass {
rewriter.rewriteStatement(st.body, ctx)
}
val stmts = (low to high).foldLeft(List.empty[Statement])((acc, i) => acc ++ rewriteForValue(i).toList)
Some(BlockStmt(List.empty, stmts))
Some(BlockStmt(List.empty, stmts, true))
} else {
Some(st)
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/uclid/lang/MacroRewriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class MacroReplacerPass(macroId : Identifier, newMacroBody : BlockStmt) extends
case _ =>
}
}
BlockStmt(st.vars, leftStmts)
BlockStmt(st.vars, leftStmts, false)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/uclid/lang/ModSetAnalysis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class ModSetRewriterPass() extends RewritePass {
* @param modSetMap The modifies set map inferred by the ModSetAnalysis pass. Should contain a map from procedures to thier inferred modifies sets.
*/
def getStmtModSet(stmt: Statement, modSetMap: Map[Identifier, Set[ModifiableEntity]], varIdSet: Set[Identifier], locVarIdSet: Set[Identifier]): Set[ModifiableEntity] = stmt match {
case BlockStmt(vars, stmts) => {
case BlockStmt(vars, stmts,_) => {
val locVarIdSetP = vars.foldLeft(locVarIdSet)((acc, bvd) => acc ++ bvd.ids.toSet)
stmts.foldLeft(Set.empty[ModifiableEntity])((acc, stmt) => acc ++ getStmtModSet(stmt, modSetMap, varIdSet, locVarIdSetP))
}
Expand Down
Loading

0 comments on commit b0c07a9

Please sign in to comment.