diff --git a/src/main/resources/logback.xml b/src/main/resources/logback.xml
index ca540748f..29eb63363 100644
--- a/src/main/resources/logback.xml
+++ b/src/main/resources/logback.xml
@@ -22,7 +22,7 @@
-
+
diff --git a/src/main/scala/uclid/UclidMain.scala b/src/main/scala/uclid/UclidMain.scala
index 1874ceeae..6adf31d67 100644
--- a/src/main/scala/uclid/UclidMain.scala
+++ b/src/main/scala/uclid/UclidMain.scala
@@ -468,6 +468,9 @@ object UclidMain {
passManager.addPass(new ModuleTypeChecker())
// optimisation, has previously been called
passManager.addPass(new SemanticAnalyzer())
+ // reorder statements if necessary.
+ // Pass MUST be run after variable renamers
+ passManager.addPass(new BlockSorter())
// known bugs in the following passes
if (config.enumToNumeric) passManager.addPass(new EnumTypeAnalysis())
if (config.enumToNumeric) passManager.addPass(new EnumTypeRenamer("BV"))
diff --git a/src/main/scala/uclid/lang/BlockFlattener.scala b/src/main/scala/uclid/lang/BlockFlattener.scala
index 85915e37d..81993b478 100644
--- a/src/main/scala/uclid/lang/BlockFlattener.scala
+++ b/src/main/scala/uclid/lang/BlockFlattener.scala
@@ -138,9 +138,11 @@ class BlockFlattenerPass extends RewritePass {
val stmtsP = rewriter.rewriteStatements(blk.stmts, context + blk.vars)
(stmtsP, varDecls)
}
+
override def rewriteBlock(blkStmt : BlockStmt, context : Scope) : Option[Statement] = {
logger.debug("==> [%s] Input:\n%s".format(analysis.passName, blkStmt.toString()))
val init = (List.empty[Statement], Map.empty[Identifier, Type])
+
val (stmtsP, mapOut) = blkStmt.stmts.foldLeft(init) {
(acc, st) => {
val (stP, mapOut) = st match {
@@ -150,6 +152,7 @@ class BlockFlattenerPass extends RewritePass {
(acc._1 ++ stP, mapOut)
}
}
+
val vars = mapOut.map(p => BlockVarsDecl(List(p._1), p._2))
val result = BlockStmt(blkStmt.vars ++ vars, stmtsP)
logger.debug("<== Result:\n" + result.toString())
@@ -170,6 +173,61 @@ class BlockFlattener() extends ASTRewriter(BlockFlattener.getName(), new BlockFl
override val repeatUntilNoChange = true
}
+// This pass changes the order of any statements in a block so that
+// assignments to state variables are moved to the end of the block.
+// This avoids issues where one submodule reads from a variable after another
+// submodule has written to it, without introducing additional variables
+// It must be run after the procedures have been inlined and converted into SSA.
+class BlockSorterPass extends RewritePass {
+ lazy val logger = Logger(classOf[BlockSorterPass])
+
+ override def rewriteBlock(blkStmt : BlockStmt, context : Scope) : Option[Statement] = {
+ logger.debug("==> [%s] Input:\n%s".format(analysis.passName, blkStmt.toString()))
+ val init = (List.empty[Statement], Map.empty[Identifier, Type])
+
+ def isStateVarAssign(st : Statement) : Boolean = {
+ st match {
+ case AssignStmt(lhss, rhs) => {
+ lhss.exists {
+ case LhsId(id) => context.map.contains(id) && context.map(id).isInstanceOf[Scope.StateVar] && !(id.toString contains "_ucld_")
+ case _ => false
+ }
+ }
+ case _ => false
+ }
+ }
+
+ val endStmts = blkStmt.stmts.filter(st => isStateVarAssign(st))
+ logger.debug("Moving the following statements to the end of the block: " + endStmts.toString())
+ val filteredStmts = blkStmt.stmts.filter(endStmts.contains(_)==false)
+ logger.debug("Moving these statements to the start of the block " + filteredStmts.toString())
+ val result = BlockStmt(blkStmt.vars, filteredStmts++endStmts)
+ logger.debug("<== Result:\n" + result.toString())
+ Some(result)
+ }
+}
+
+object BlockSorter {
+ var index = 0
+ def getName() : String = {
+ index += 1
+ "BlockSorter:" + index.toString()
+ }
+}
+
+class BlockSorter() extends ASTRewriter(BlockSorter.getName(), new BlockSorterPass())
+{
+ override val repeatUntilNoChange = true
+ // Don't reorder procedural code
+ override def visitInit(init : InitDecl, context : Scope) : Option[InitDecl] = Some(init)
+ // Don't reorder procedural code
+ override def visitProcedure(proc : ProcedureDecl, contextIn : Scope) : Option[ProcedureDecl] = Some(proc);
+
+}
+
+
+
+
object Optimizer {
var index = 0
def getName() : String = {
diff --git a/src/test/scala/VerifierSpec.scala b/src/test/scala/VerifierSpec.scala
index f8c0104bc..bf698bac1 100644
--- a/src/test/scala/VerifierSpec.scala
+++ b/src/test/scala/VerifierSpec.scala
@@ -484,6 +484,9 @@ class ModuleVerifSpec extends AnyFlatSpec {
"test-module-import-0.ucl" should "verify all assertions." in {
VerifierSpec.expectedFails("./test/test-module-import-0.ucl", 0)
}
+ "test-module-ordering.ucl" should "verify all assertions." in {
+ VerifierSpec.expectedFails("./test/test-module-ordering.ucl", 0)
+ }
"test-type-import.ucl" should "verify all assertions." in {
VerifierSpec.expectedFails("./test/test-type-import.ucl", 0)
}
diff --git a/test/test-module-ordering.ucl b/test/test-module-ordering.ucl
new file mode 100644
index 000000000..6683ba88d
--- /dev/null
+++ b/test/test-module-ordering.ucl
@@ -0,0 +1,45 @@
+module test {
+ input a : integer;
+ output b : integer;
+
+ init {
+ b = 0;
+ }
+
+ next {
+ b' = a+1;
+ }
+}
+
+module main {
+ var x: integer;
+ var y: integer;
+
+ // test1 reads in x and updates y'=x+1
+ instance test1 : test(a : (x), b: (y));
+ // test2 reads in y and updates x'=y+1
+ instance test2 : test(a : (y), b: (x));
+
+ init {
+ x = 0;
+ y = 0;
+
+ }
+
+ next {
+ // both assertions should pass regardless of the ordering of these statements
+ next(test1);
+ next(test2);
+ }
+
+ invariant test1_lt2: test1.b < 2;
+ invariant test2lt2: test2.b < 2;
+
+ control {
+ print_module;
+ v = bmc(1);
+ check;
+ print_results;
+ v.print_cex;
+ }
+}