Theory ConditionalElimination

section ‹Conditional Elimination Phase›

text ‹
This theory presents the specification of the \texttt{ConditionalElimination} phase
within the GraalVM compiler.
The \texttt{ConditionalElimination} phase simplifies any condition of an \textsl{if}
statement that can be implied by the conditions that dominate it.
That is, if condition A implies that condition B \textsl{must} be true,
then condition B is simplified to \texttt{true}.

\begin{lstlisting}[language=java]
if (A) {
  if (B) {
    ...
  }
}
\end{lstlisting}

We begin by defining the individual implication rules used by the phase
in \ref{sec:rules}.
These rules are then lifted to the rewriting of a condition within an \textsl{if}
statement in \ref{sec:lift}.
The traversal algorithm used by the compiler is specified in \ref{sec:traversal}.
›

theory ConditionalElimination
  imports
    Semantics.IRTreeEvalThms
    Proofs.Rewrites
    Proofs.Bisimulation
    OptimizationDSL.Markup
begin

(* Switch off the show_types printing option for this theory. *)
declare [[show_types=false]]

subsection ‹Implication Rules \label{sec:rules}›

text ‹
The following rules determine whether a condition, @{term q1},
implies that another condition, @{term q2}, must be true or must be false.
›

subsubsection ‹Structural Implication›

text ‹
The first method for determining if a condition can be implied by another condition,
is structural implication.
That is, by looking at the structure of the conditions, we can determine the truth value.
For instance, @{term "x == y"} implies that @{term "x < y"} cannot be true.
›

(* Two mutually inductive relations over pairs of condition expressions.
   impliesx:   the first condition establishes the second.
   impliesnot: the first condition refutes the second (notation "_ ⇛¬ _").
   Rules: a condition is related to itself (same); equality and less-than
   comparisons over the same operands refute one another (eq_not_less,
   eq_not_less', less_not_less, less_not_eq, less_not_eq'); and wrapping the
   right-hand side in a logical negation exp[!y] moves a fact from one
   relation to the other (negate_true, negate_false). *)
inductive 
  impliesx :: "IRExpr  IRExpr  bool" ("_  _") and 
  impliesnot :: "IRExpr  IRExpr  bool" ("_ ⇛¬ _") where
  same:          "q  q" |
  eq_not_less:   "exp[x eq y] ⇛¬ exp[x < y]" |
  eq_not_less':  "exp[x eq y] ⇛¬ exp[y < x]" |
  less_not_less: "exp[x < y] ⇛¬ exp[y < x]" |
  less_not_eq:   "exp[x < y] ⇛¬ exp[x eq y]" |
  less_not_eq':  "exp[x < y] ⇛¬ exp[y eq x]" |
  negate_true:   "x ⇛¬ y  x  exp[!y]" |
  negate_false:  "x  y  x ⇛¬ exp[!y]"

(* Packages the two structural relations into one three-valued answer:
   Some True  when x is related to y by the implication relation,
   Some False when x is related to y by the "implies not" relation,
   None       from the `fail` rule, whose premise negates a combination of
              both relations (presumably "neither holds" - confirm, the
              connective is not legible here). *)
inductive implies_complete :: "IRExpr  IRExpr  bool option  bool" where
  implies:
  "x  y  implies_complete x y (Some True)" |
  impliesnot:
  "x ⇛¬ y  implies_complete x y (Some False)" |
  fail:
  "¬((x  y)  (x ⇛¬ y))  implies_complete x y None"


text ‹
The relation @{term "q1 ⇛ q2"} requires that the implication @{term "q1 ⟶ q2"}
is known true (i.e. universally valid).
The relation @{term "q1 ⇛¬ q2"} requires that the implication @{term "q1 ⟶ q2"}
is known false (i.e. @{term "q1 ⟶ ¬ q2"} is universally valid).
If neither @{term "q1 ⇛ q2"} nor @{term "q1 ⇛¬ q2"} then the status is unknown
and the condition cannot be simplified.
›

(* Semantic reading of "q1 implies q2", declared as an infix operator of
   priority 50. It ranges over every m, p and every pair of values v1, v2
   obtained by evaluating q1 and q2 in the context [m, p], and constrains
   val_to_bool v1 against val_to_bool v2 (presumably as an implication -
   confirm, the connectives are not legible here). *)
fun implies_valid :: "IRExpr  IRExpr  bool" (infix "" 50) where
  "implies_valid q1 q2 = 
    (m p v1 v2. ([m, p]  q1  v1)  ([m,p]  q2  v2)  
            (val_to_bool v1  val_to_bool v2))"

(* Semantic reading of "q1 refutes q2", declared as an infix operator of
   priority 50. Same shape as implies_valid, except that the conclusion is
   about the negation of val_to_bool v2. *)
fun impliesnot_valid :: "IRExpr  IRExpr  bool" (infix "" 50) where
  "impliesnot_valid q1 q2 = 
    (m p v1 v2. ([m, p]  q1  v1)  ([m,p]  q2  v2)  
            (val_to_bool v1  ¬val_to_bool v2))"

text ‹
The relation @{term "q1  q2"} means @{term "q1  q2"} is universally valid, 
and the relation @{term "q1  q2"} means @{term "q1  ¬q2"} is universally valid.
›

(* Value-level fact behind the eq_not_less rule: relates the truth of
   val[v1 eq v2] to the falsity of val[v1 < v2].
   The auxiliary fact unfoldEqualDefined covers the case where intval_equals
   is defined: both operands are then IntVal with the same bit width, and both
   comparisons unfold to bool_to_val of a plain comparison. The final step
   discharges the remaining case. *)
lemma eq_not_less_val:
  "val_to_bool(val[v1 eq v2])  ¬val_to_bool(val[v1 < v2])"
  proof -
  have unfoldEqualDefined: "(intval_equals v1 v2  UndefVal) 
        (val_to_bool(intval_equals v1 v2)  (¬(val_to_bool(intval_less_than v1 v2))))"
    subgoal premises p
  proof -
    obtain v1b v1v where v1v: "v1 = IntVal v1b v1v"
      by (metis array_length.cases intval_equals.simps(2,3,4,5) p)
    obtain v2b v2v where v2v: "v2 = IntVal v2b v2v"
      by (metis Value.exhaust_sel intval_equals.simps(6,7,8,9) p)
    have sameWidth: "v1b=v2b"
      by (metis bool_to_val_bin.simps intval_equals.simps(1) p v1v v2v)
    have unfoldEqual: "intval_equals v1 v2 = (bool_to_val (v1v=v2v))"
      by (simp add: sameWidth v1v v2v)
    have unfoldLessThan: "intval_less_than v1 v2 = (bool_to_val (int_signed_value v1b v1v < int_signed_value v2b v2v))"
      by (simp add: sameWidth v1v v2v)
    have val: "((v1v=v2v))  (¬((int_signed_value v1b v1v < int_signed_value v2b v2v)))"
      using sameWidth by auto
    have doubleCast0: "val_to_bool (bool_to_val ((v1v = v2v))) = (v1v = v2v)"
      using bool_to_val.elims val_to_bool.simps(1) by fastforce
    have doubleCast1: "val_to_bool (bool_to_val ((int_signed_value v1b v1v < int_signed_value v2b v2v))) =
                                                 (int_signed_value v1b v1v < int_signed_value v2b v2v)"
      using bool_to_val.elims val_to_bool.simps(1) by fastforce
    then show ?thesis
      using p val unfolding unfoldEqual unfoldLessThan doubleCast0 doubleCast1 by blast
  qed done
  show ?thesis
    by (metis Value.distinct(1) val_to_bool.elims(2) unfoldEqualDefined)
qed

(* Variant of eq_not_less_val with the less-than operands swapped (v2 < v1).
   Fact `a` shows that intval_equals gives the same result with its arguments
   exchanged; the result then follows from eq_not_less_val. *)
lemma eq_not_less'_val:
  "val_to_bool(val[v1 eq v2])  ¬val_to_bool(val[v2 < v1])"
proof -
  have a: "intval_equals v1 v2 = intval_equals v2 v1"
    apply (cases "intval_equals v1 v2 = UndefVal")
    apply (smt (z3) bool_to_val_bin.simps intval_equals.elims intval_equals.simps)
    subgoal premises p
    proof -
      obtain v1b v1v where v1v: "v1 = IntVal v1b v1v"
        by (metis Value.exhaust_sel intval_equals.simps(2,3,4,5) p)
      obtain v2b v2v where v2v: "v2 = IntVal v2b v2v"
        by (metis Value.exhaust_sel intval_equals.simps(6,7,8,9) p)
      then show ?thesis
        by (smt (verit) bool_to_val_bin.simps intval_equals.simps(1) v1v)
    qed done
  show ?thesis
    using a eq_not_less_val by presburger
qed

(* Value-level fact behind the less_not_less rule: relates the truth of
   val[v1 < v2] to the falsity of val[v2 < v1].
   Both operands are shown to be IntVal, both comparisons are unfolded to
   bool_to_val of a comparison on int_signed_value, and the result follows
   from the asymmetry of < on those signed values. *)
lemma less_not_less_val:
  "val_to_bool(val[v1 < v2])  ¬val_to_bool(val[v2 < v1])"
  apply (rule impI)
  subgoal premises p
  proof -
    obtain v1b v1v where v1v: "v1 = IntVal v1b v1v"
      by (metis Value.exhaust_sel intval_less_than.simps(2,3,4,5) p val_to_bool.simps(2))
    obtain v2b v2v where v2v: "v2 = IntVal v2b v2v"
      by (metis Value.exhaust_sel intval_less_than.simps(6,7,8,9) p val_to_bool.simps(2))
    then have unfoldLessThanRHS: "intval_less_than v2 v1 =
                                 (bool_to_val (int_signed_value v2b v2v < int_signed_value v1b v1v))"
      using p v1v by force
    then have unfoldLessThanLHS: "intval_less_than v1 v2 =
                                 (bool_to_val (int_signed_value v1b v1v < int_signed_value v2b v2v))"
      using bool_to_val_bin.simps intval_less_than.simps(1) p v1v v2v val_to_bool.simps(2) by auto
    then have symmetry: "(int_signed_value v2b v2v < int_signed_value v1b v1v) 
                       (¬(int_signed_value v1b v1v < int_signed_value v2b v2v))"
      by simp
    then show ?thesis
      using p unfoldLessThanLHS unfoldLessThanRHS by fastforce
  qed done

(* Value-level fact behind the less_not_eq rule; it is eq_not_less_val read
   in the other direction and is proved directly from it. *)
lemma less_not_eq_val:
  "val_to_bool(val[v1 < v2])  ¬val_to_bool(val[v1 eq v2])"
  using eq_not_less_val by blast 

(* If a UnaryLogicNegation of x evaluates to some value in [m, p], then x
   itself evaluates to an IntVal (for some width b and payload v2). *)
lemma logic_negate_type:
  assumes "[m, p]  UnaryExpr UnaryLogicNegation x  v"
  shows "b v2. [m, p]  x  IntVal b v2"
  using assms
  by (metis UnaryExprE intval_logic_negation.elims unary_eval.simps(4))

(* For an IntVal x with positive bit width b, relates the truth value of
   intval_logic_negation x to the negated truth value of x. Proved by case
   analysis on x after unfolding logic_negate_def. *)
lemma intval_logic_negation_inverse:
  assumes "b > 0"
  assumes "x = IntVal b v"
  shows "val_to_bool (intval_logic_negation x)  ¬(val_to_bool x)"
  using assms by (cases x; auto simp: logic_negate_def) 

(* Expression-tree version of intval_logic_negation_inverse: when y evaluates
   to val and its logical negation evaluates to invval in the same context,
   the truth value of val is related to the negated truth value of invval. *)
lemma logic_negation_relation_tree:
  assumes "[m, p]  y  val"
  assumes "[m, p]  UnaryExpr UnaryLogicNegation y  invval"
  shows "val_to_bool val  ¬(val_to_bool invval)"
  using assms using intval_logic_negation_inverse
  by (metis UnaryExprE evalDet eval_bits_1_64 logic_negate_type unary_eval.simps(4))

text ‹The following theorem shows that the known true/false rules are valid.›

(* Soundness of the structural rules: each of the two syntactic relations is
   tied to its semantic counterpart (implies_valid / impliesnot_valid).
   The proof is by simultaneous rule induction over both relations, with one
   case per introduction rule. Each comparison case appeals to the matching
   *_val lemma plus determinism of evaluation (evalDet); the two negation
   cases use logic_negation_relation_tree. *)
theorem implies_impliesnot_valid:
  shows "((q1  q2)  (q1  q2)) 
         ((q1 ⇛¬ q2)  (q1  q2))"
          (is "(?imp  ?val)  (?notimp  ?notval)")
proof (induct q1 q2  rule: impliesx_impliesnot.induct)
  case (same q)
  then show ?case 
    using evalDet by fastforce
next
  case (eq_not_less x y)
  then show ?case apply auto[1] using eq_not_less_val evalDet by blast
next
  case (eq_not_less' x y)
  then show ?case apply auto[1] using eq_not_less'_val evalDet by blast
next
  case (less_not_less x y)
  then show ?case apply auto[1] using less_not_less_val evalDet by blast
next
  case (less_not_eq x y)
  then show ?case apply auto[1] using less_not_eq_val evalDet by blast
next
  case (less_not_eq' x y)
  then show ?case apply auto[1] using eq_not_less'_val evalDet by metis
next
  case (negate_true x y)
  then show ?case apply auto[1]
    by (metis logic_negation_relation_tree unary_eval.simps(4) unfold_unary)
next
  case (negate_false x y)
  then show ?case apply auto[1]
    by (metis UnaryExpr logic_negation_relation_tree unary_eval.simps(4)) 
qed


subsubsection ‹Type Implication›

text ‹
The second mechanism to determine whether a condition implies another is
to use the type information of the relevant nodes.
For instance, @{term "x < 4"} implies @{term "x < 10"}.
We can show this by strengthening the type, stamp,
of the node @{term x} such that the upper bound is @{term 4}.
Then, when the second condition is reached,
we know that the condition must be true by the upper bound.
›

text ‹
The following relation corresponds to the \texttt{UnaryOpLogicNode.tryFold}
and \texttt{BinaryOpLogicNode.tryFold} methods and their associated
concrete implementations.

We track the refined stamps by mapping nodes to Stamps,
the second parameter to @{term tryFold}.
›

(* tryFold node stamps b: the comparison node `node` is known to evaluate to
   the boolean b, judging only by the stamps recorded for its two inputs.
   - IntegerEqualsNode folds to False when the input stamps are
     alwaysDistinct, and to True when they are neverDistinct.
   - IntegerLessThanNode (both inputs integer stamps) folds to True when the
     upper bound of x lies strictly below the lower bound of y, and to False
     under the fourth rule's comparison of stpi_lower x with stpi_upper y.
   No rule applies to any other node kind. *)
inductive tryFold :: "IRNode  (ID  Stamp)  bool  bool"
  where
  "alwaysDistinct (stamps x) (stamps y) 
     tryFold (IntegerEqualsNode x y) stamps False" |
  "neverDistinct (stamps x) (stamps y) 
     tryFold (IntegerEqualsNode x y) stamps True" |
  "is_IntegerStamp (stamps x);
    is_IntegerStamp (stamps y);
    stpi_upper (stamps x) < stpi_lower (stamps y) 
     tryFold (IntegerLessThanNode x y) stamps True" |
  "is_IntegerStamp (stamps x);
    is_IntegerStamp (stamps y);
    stpi_lower (stamps x)  stpi_upper (stamps y) 
     tryFold (IntegerLessThanNode x y) stamps False"

(* Generate executable code for tryFold with all three arguments as inputs. *)
code_pred (modes: i  i  i  bool) tryFold .

text ‹
Prove that, when the stamp map is valid,
the @{term tryFold} relation correctly predicts the output value with respect to
our evaluation semantics.
›

(* Elimination rule for a single step from (nid,m,h) to (nid',m',h), i.e. a
   step in which the third state component h is the same on both sides. *)
inductive_cases StepE:
  "g, p  (nid,m,h)  (nid',m',h)"


(* A stamp that satisfies is_stamp_empty admits no valid value. *)
lemma is_stamp_empty_valid:
  assumes "is_stamp_empty s"
  shows "¬( val. valid_value val s)"
  using assms is_stamp_empty.simps apply (cases s; auto)
  by (metis linorder_not_le not_less_iff_gr_or_eq order.strict_trans valid_value.elims(2) valid_value.simps(1) valid_value.simps(5))

(* For two valid integer stamps, validity of v with respect to s1 and s2 is
   equated with validity of v with respect to `join s1 s2`.
   Each direction is proved separately by case analysis on both stamps. *)
lemma join_valid:
  assumes "is_IntegerStamp s1  is_IntegerStamp s2"
  assumes "valid_stamp s1  valid_stamp s2"
  shows "(valid_value v s1  valid_value v s2) = valid_value v (join s1 s2)" (is "?lhs = ?rhs")
proof
  assume ?lhs
  then show ?rhs 
   using assms(1) apply (cases s1; cases s2; auto)
   apply (metis Value.inject(1) valid_int)
  by (smt (z3) valid_int valid_stamp.simps(1) valid_value.simps(1))
  next
  assume ?rhs
  then show ?lhs
    using assms apply (cases s1; cases s2; simp)
  by (smt (verit, best) assms(2) valid_int valid_value.simps(1) valid_value.simps(22))
qed

(* If the stamps of nodes x and y are alwaysDistinct (and both are valid
   integer stamps in a well-formed stamp map), then there is no single value
   that both x and y evaluate to.
   Idea: any value produced by x (resp. y) is valid for its stamp, validity
   for both stamps is validity for their join (join_valid), and
   is_stamp_empty_valid is used to rule out such a value. *)
lemma alwaysDistinct_evaluate:
  assumes "wf_stamp g stamps"
  assumes "alwaysDistinct (stamps x) (stamps y)"
  assumes "is_IntegerStamp (stamps x)  is_IntegerStamp (stamps y)  valid_stamp (stamps x)  valid_stamp (stamps y)"
  shows "¬( val . ([g, m, p]  x  val)  ([g, m, p]  y  val))"
proof -
  obtain stampx stampy where stampdef: "stampx = stamps x  stampy = stamps y"
    by simp
  then have xv: " xv . ([g, m, p]  x  xv)  valid_value xv stampx"
    by (meson assms(1) encodeeval.simps eval_in_ids wf_stamp.elims(2))
  from stampdef have yv: " yv . ([g, m, p]  y  yv)  valid_value yv stampy"
    by (meson assms(1) encodeeval.simps eval_in_ids wf_stamp.elims(2))
  have "v. valid_value v (join stampx stampy) = (valid_value v stampx  valid_value v stampy)"
    using assms(3)
    by (simp add: join_valid stampdef)
  then show ?thesis
    using assms unfolding alwaysDistinct.simps
    using is_stamp_empty_valid stampdef xv yv by blast
qed

(* Soundness of the first tryFold rule: an IntegerEqualsNode whose input
   stamps are alwaysDistinct never evaluates to a true value.
   Outline: recover the BinaryExpr BinIntegerEquals representation of the
   node, obtain values xv and yv for both inputs, show both stamps are valid
   integer stamps, apply alwaysDistinct_evaluate to see the inputs cannot
   share a value, and conclude from the shape of the equality result. *)
lemma alwaysDistinct_valid:
  assumes "wf_stamp g stamps"
  assumes "kind g nid = (IntegerEqualsNode x y)"
  assumes "[g, m, p]  nid  v"
  assumes "alwaysDistinct (stamps x) (stamps y)"
  shows "¬(val_to_bool v)"
proof -
  have no_valid: " val. ¬(valid_value val (join (stamps x) (stamps y)))"
    by (smt (verit, best) is_stamp_empty.elims(2) valid_int valid_value.simps(1) assms(1,4)
        alwaysDistinct.simps)
  obtain xe ye where repr: "rep g nid (BinaryExpr BinIntegerEquals xe ye)"
    by (metis assms(2) assms(3) encodeeval.simps rep_integer_equals)
  moreover have evale: "[m, p]  (BinaryExpr BinIntegerEquals xe ye)  v"
    by (metis assms(3) calculation encodeeval.simps repDet)
  moreover have repsub: "rep g x xe  rep g y ye"
    by (metis IRNode.distinct(1955) IRNode.distinct(1997) IRNode.inject(17) IntegerEqualsNodeE assms(2) calculation)
  ultimately obtain xv yv where evalsub: "[g, m, p]  x  xv  [g, m, p]  y  yv"
    by (meson BinaryExprE encodeeval.simps)
  have xvalid: "valid_value xv (stamps x)"
    using assms(1) encode_in_ids encodeeval.simps evalsub wf_stamp.simps by blast
  then have xint: "is_IntegerStamp (stamps x)"
    using assms(4) valid_value.elims(2) by fastforce
  then have xstamp: "valid_stamp (stamps x)"
    using xvalid apply (cases xv; auto) 
    apply (smt (z3) valid_stamp.simps(6) valid_value.elims(1))
    using is_IntegerStamp_def by fastforce
  have yvalid: "valid_value yv (stamps y)"
    using assms(1) encode_in_ids encodeeval.simps evalsub wf_stamp.simps by blast
  then have yint: "is_IntegerStamp (stamps y)"
    using assms(4) valid_value.elims(2) by fastforce
  then have ystamp: "valid_stamp (stamps y)"
    using yvalid apply (cases yv; auto) 
    apply (smt (z3) valid_stamp.simps(6) valid_value.elims(1))
    using is_IntegerStamp_def by fastforce
  have disjoint: "¬( val . ([g, m, p]  x  val)  ([g, m, p]  y  val))"
    using alwaysDistinct_evaluate
    using assms(1) assms(4) xint yint xvalid yvalid xstamp ystamp by simp
  have "v = bin_eval BinIntegerEquals xv yv"
    by (metis BinaryExprE encodeeval.simps evale evalsub graphDet repsub)
  also have "v  UndefVal"
    using evale by auto
  ultimately have "b1 b2. v =  bool_to_val_bin b1 b2 (xv = yv)"
    unfolding bin_eval.simps
    by (smt (z3) Value.inject(1) bool_to_val_bin.simps intval_equals.elims)
  then show ?thesis
    by (metis (mono_tags, lifting) (v::Value)  UndefVal bool_to_val.elims bool_to_val_bin.simps disjoint evalsub val_to_bool.simps(1))
qed
(* Diagnostic: list the oracles the proof above depends on. *)
thm_oracles alwaysDistinct_valid

(* A 64-bit word vv that already fits in b bits (take_bit b vv = vv, with
   0 < b and b bounded by 64) is recovered by converting it to its signed
   b-bit integer value, back to a word, and truncating to b bits again. *)
lemma unwrap_valid:
  assumes "0 < b  b  64"
  assumes "take_bit (b::nat) (vv::64 word) = vv"
  shows "(vv::64 word) = take_bit b (word_of_int (int_signed_value (b::nat) (vv::64 word)))"
  using assms apply auto[1]
  by (simp add: take_bit_signed_take_bit)

(* If asConstant s yields a defined value val, then every value v that is
   valid for the stamp s is equal to val.
   The proof shows s is an IntegerStamp b l h, v is an IntVal of width b
   whose signed value equals l, val is new_int b (word_of_int l), and then
   closes with unwrap_valid. *)
lemma asConstant_valid:
  assumes "asConstant s = val"
  assumes "val  UndefVal"
  assumes "valid_value v s"
  shows "v = val"
proof -
  obtain b l h where s: "s = IntegerStamp b l h"
    using assms(1,2) by (cases s; auto)
  obtain vv where vdef: "v = IntVal b vv"
    using assms(3) s valid_int by blast
  have "l  int_signed_value b vv  int_signed_value b vv  h"
    by (metis (v::Value) = IntVal (b::nat) (vv::64 word) assms(3) s valid_value.simps(1))
  then have veq: "int_signed_value b vv = l"
    by (smt (verit) asConstant.simps(1) assms(1) assms(2) s)
  have valdef: "val = new_int b (word_of_int l)"
    by (metis asConstant.simps(1) assms(1) assms(2) s)
  have "take_bit b vv = vv"
    by (metis (v::Value) = IntVal (b::nat) (vv::64 word) assms(3) s valid_value.simps(1))
  then show ?thesis
    using veq vdef valdef
    using assms(3) s unwrap_valid by force
qed

(* Soundness of the second tryFold rule: an IntegerEqualsNode whose input
   stamps are neverDistinct evaluates to a true value.
   Outline: both stamps denote the same defined constant (asConstant), so by
   asConstant_valid any values of x and y coincide (fact eq); the equality
   result is then bool_to_val_bin of a true comparison. *)
lemma neverDistinct_valid:
  assumes "wf_stamp g stamps"
  assumes "kind g nid = (IntegerEqualsNode x y)"
  assumes "[g, m, p]  nid  v"
  assumes "neverDistinct (stamps x) (stamps y)"
  shows "val_to_bool v"
proof -
  obtain val where constx: "asConstant (stamps x) = val"
    by simp
  moreover have "val  UndefVal"
    using assms(4) calculation by auto
  then have constx: "val = asConstant (stamps y)"
    using calculation assms(4) by force
  obtain xe ye where repr: "rep g nid (BinaryExpr BinIntegerEquals xe ye)"
    by (metis assms(2) assms(3) encodeeval.simps rep_integer_equals)
  moreover have evale: "[m, p]  (BinaryExpr BinIntegerEquals xe ye)  v"
    by (metis assms(3) calculation encodeeval.simps repDet)
  moreover have repsub: "rep g x xe  rep g y ye"
    by (metis IRNode.distinct(1955) IRNode.distinct(1997) IRNode.inject(17) IntegerEqualsNodeE assms(2) calculation)
  ultimately obtain xv yv where evalsub: "[g, m, p]  x  xv  [g, m, p]  y  yv"
    by (meson BinaryExprE encodeeval.simps)
  have xvalid: "valid_value xv (stamps x)"
    using assms(1) encode_in_ids encodeeval.simps evalsub wf_stamp.simps by blast
  then have xint: "is_IntegerStamp (stamps x)"
    using assms(4) valid_value.elims(2) by fastforce
  have yvalid: "valid_value yv (stamps y)"
    using assms(1) encode_in_ids encodeeval.simps evalsub wf_stamp.simps by blast
  then have yint: "is_IntegerStamp (stamps y)"
    using assms(4) valid_value.elims(2) by fastforce
  have eq: "v1 v2. (([g, m, p]  x  v1)  ([g, m, p]  y  v2))  v1 = v2"
    by (metis asConstant_valid assms(4) encodeEvalDet evalsub neverDistinct.elims(1) xvalid yvalid)
  have "v = bin_eval BinIntegerEquals xv yv"
    by (metis BinaryExprE encodeeval.simps evale evalsub graphDet repsub)
  also have "v  UndefVal"
    using evale by auto
  ultimately have "b1 b2. v =  bool_to_val_bin b1 b2 (xv = yv)"
    unfolding bin_eval.simps
    by (smt (z3) Value.inject(1) bool_to_val_bin.simps intval_equals.elims)
  then show ?thesis
    using (v::Value)  UndefVal eq evalsub by fastforce
qed

(* Soundness of the third tryFold rule: if the upper bound of x's stamp lies
   strictly below the lower bound of y's stamp, an IntegerLessThanNode x y
   evaluates to a true value.
   Outline: obtain values xv, yv for the inputs, show both are IntVal of the
   same width b with IntegerStamp stamps, bound xv from above by xh and yv
   from below by yl, and combine with xh < yl. *)
lemma stampUnder_valid:
  assumes "wf_stamp g stamps"
  assumes "kind g nid = (IntegerLessThanNode x y)"
  assumes "[g, m, p]  nid  v"
  assumes "stpi_upper (stamps x) < stpi_lower (stamps y)"
  shows "val_to_bool v"
proof -
  obtain xe ye where repr: "rep g nid (BinaryExpr BinIntegerLessThan xe ye)"
    by (metis assms(2) assms(3) encodeeval.simps rep_integer_less_than)
  moreover have evale: "[m, p]  (BinaryExpr BinIntegerLessThan xe ye)  v"
    by (metis assms(3) calculation encodeeval.simps repDet)
  moreover have repsub: "rep g x xe  rep g y ye"
    by (metis IRNode.distinct(2047) IRNode.distinct(2089) IRNode.inject(18) IntegerLessThanNodeE assms(2) repr)
  ultimately obtain xv yv where evalsub: "[g, m, p]  x  xv  [g, m, p]  y  yv"
    by (meson BinaryExprE encodeeval.simps)
  have vval: "v = intval_less_than xv yv"
    by (metis BinaryExprE bin_eval.simps(14) encodeEvalDet encodeeval.simps evale evalsub repsub)
  then obtain b xvv where "xv = IntVal b xvv"
    by (metis bin_eval.simps(14) defined_eval_is_intval evale evaltree_not_undef is_IntVal_def)
  also have xvalid: "valid_value xv (stamps x)"
    by (meson assms(1) encodeeval.simps eval_in_ids evalsub wf_stamp.elims(2))
  then obtain xl xh where xstamp: "stamps x = IntegerStamp b xl xh"
    using calculation valid_value.simps apply (cases "stamps x"; auto)
    by presburger
  from vval obtain yvv where yint: "yv = IntVal b yvv"
    by (metis Value.collapse(1) bin_eval.simps(14) bool_to_val_bin.simps calculation defined_eval_is_intval evale evaltree_not_undef intval_less_than.simps(1))
  then have yvalid: "valid_value yv (stamps y)"
    using assms(1) encodeeval.simps evalsub no_encoding wf_stamp.simps by blast
  then obtain yl yh where ystamp: "stamps y = IntegerStamp b yl yh"
    using calculation yint valid_value.simps apply (cases "stamps y"; auto)
    by presburger
  have "int_signed_value b xvv  xh"
    using calculation valid_value.simps(1) xstamp xvalid by presburger
  moreover have "yl  int_signed_value b yvv"
    using valid_value.simps(1) yint ystamp yvalid by presburger
  moreover have "xh < yl"
    using assms(4) xstamp ystamp by auto
  ultimately have "int_signed_value b xvv < int_signed_value b yvv"
    by linarith
  then have "val_to_bool (intval_less_than xv yv)"
    by (simp add: (xv::Value) = IntVal (b::nat) (xvv::64 word) yint)
  then show ?thesis
    by (simp add: vval)
qed

(* Soundness of the fourth tryFold rule: under the stated comparison between
   the lower bound of x's stamp and the upper bound of y's stamp, an
   IntegerLessThanNode x y does not evaluate to a true value.
   The proof mirrors stampUnder_valid, this time bounding xv from below by xl
   and yv from above by yh. *)
lemma stampOver_valid:
  assumes "wf_stamp g stamps"
  assumes "kind g nid = (IntegerLessThanNode x y)"
  assumes "[g, m, p]  nid  v"
  assumes "stpi_lower (stamps x)  stpi_upper (stamps y)"
  shows "¬(val_to_bool v)"
proof -
  obtain xe ye where repr: "rep g nid (BinaryExpr BinIntegerLessThan xe ye)"
    by (metis assms(2) assms(3) encodeeval.simps rep_integer_less_than)
  moreover have evale: "[m, p]  (BinaryExpr BinIntegerLessThan xe ye)  v"
    by (metis assms(3) calculation encodeeval.simps repDet)
  moreover have repsub: "rep g x xe  rep g y ye"
    by (metis IRNode.distinct(2047) IRNode.distinct(2089) IRNode.inject(18) IntegerLessThanNodeE assms(2) repr)
  ultimately obtain xv yv where evalsub: "[g, m, p]  x  xv  [g, m, p]  y  yv"
    by (meson BinaryExprE encodeeval.simps)
  have vval: "v = intval_less_than xv yv"
    by (metis BinaryExprE bin_eval.simps(14) encodeEvalDet encodeeval.simps evale evalsub repsub)
  then obtain b xvv where "xv = IntVal b xvv"
    by (metis bin_eval.simps(14) defined_eval_is_intval evale evaltree_not_undef is_IntVal_def)
  also have xvalid: "valid_value xv (stamps x)"
    by (meson assms(1) encodeeval.simps eval_in_ids evalsub wf_stamp.elims(2))
  then obtain xl xh where xstamp: "stamps x = IntegerStamp b xl xh"
    using calculation valid_value.simps apply (cases "stamps x"; auto)
    by presburger
  from vval obtain yvv where yint: "yv = IntVal b yvv"
    by (metis Value.collapse(1) bin_eval.simps(14) bool_to_val_bin.simps calculation defined_eval_is_intval evale evaltree_not_undef intval_less_than.simps(1))
  then have yvalid: "valid_value yv (stamps y)"
    using assms(1) encodeeval.simps evalsub no_encoding wf_stamp.simps by blast
  then obtain yl yh where ystamp: "stamps y = IntegerStamp b yl yh"
    using calculation yint valid_value.simps apply (cases "stamps y"; auto)
    by presburger
  have "xl  int_signed_value b xvv"
    using calculation valid_value.simps(1) xstamp xvalid by presburger
  moreover have "int_signed_value b yvv  yh"
    using valid_value.simps(1) yint ystamp yvalid by presburger
  moreover have "xl  yh"
    using assms(4) xstamp ystamp by auto
  ultimately have "int_signed_value b xvv  int_signed_value b yvv"
    by linarith
  then have "¬(val_to_bool (intval_less_than xv yv))"
    by (simp add: (xv::Value) = IntVal (b::nat) (xvv::64 word) yint)
  then show ?thesis
    by (simp add: vval)
qed

(* If tryFold predicts True for the node at nid under a well-formed stamp
   map, then the node evaluates to a true value.
   Proved by rule induction on tryFold with the result fixed to True; the
   cases are discharged using neverDistinct_valid and stampUnder_valid, or
   directly by automation. *)
theorem tryFoldTrue_valid:
  assumes "wf_stamp g stamps"
  assumes "tryFold (kind g nid) stamps True"
  assumes "[g, m, p]  nid  v"
  shows "val_to_bool v"
  using assms(2) proof (induction "kind g nid" stamps True rule: tryFold.induct)
case (1 stamps x y)
  then show ?case
    using alwaysDistinct_valid assms by force
next
  case (2 stamps x y)
  then show ?case
    by (smt (verit, best) one_neq_zero tryFold.cases neverDistinct_valid assms
        stampUnder_valid val_to_bool.simps(1))
next
  case (3 stamps x y)
  then show ?case
    by (smt (verit, best) one_neq_zero tryFold.cases neverDistinct_valid assms
        val_to_bool.simps(1) stampUnder_valid)
next
case (4 stamps x y)
  then show ?case
    by force
qed

(* If tryFold predicts False for the node at nid under a well-formed stamp
   map, then the node does not evaluate to a true value.
   Proved by rule induction on tryFold with the result fixed to False; the
   cases are discharged using alwaysDistinct_valid and stampOver_valid, or
   directly by automation. *)
theorem tryFoldFalse_valid:
  assumes "wf_stamp g stamps"
  assumes "tryFold (kind g nid) stamps False"
  assumes "[g, m, p]  nid  v"
  shows "¬(val_to_bool v)"
using assms(2) proof (induction "kind g nid" stamps False rule: tryFold.induct)
case (1 stamps x y)
  then show ?case
    by (smt (verit) stampOver_valid alwaysDistinct_valid tryFold.cases
        neverDistinct_valid val_to_bool.simps(1) assms)
next
case (2 stamps x y)
  then show ?case
    by blast
next
  case (3 stamps x y)
  then show ?case
    by blast
next
  case (4 stamps x y)
  then show ?case
    by (smt (verit, del_insts) tryFold.cases alwaysDistinct_valid val_to_bool.simps(1)
        stampOver_valid assms)
qed


subsection ‹Lift rules \label{sec:lift}›

(* Lifts the structural relations from a single condition to a set of known
   conditions: condset_implies conds cond b holds when a member ce of conds
   is related to cond by the implication relation (b = True) or by the
   "implies not" relation (b = False).
   NOTE(review): the quantifier over conds is not legible here - presumably
   existential; confirm. *)
inductive condset_implies :: "IRExpr set  IRExpr  bool  bool" where
  impliesTrue:
  "(ce  conds . (ce  cond))  condset_implies conds cond True" |
  impliesFalse:
  "(ce  conds . (ce ⇛¬ cond))  condset_implies conds cond False"

(* Generate executable code for condset_implies with all arguments as inputs. *)
code_pred (modes: i  i  i  bool) condset_implies .

text ‹
The @{term conds_implies} function combines the structural and type implication
rules into a single function.
›

(* Decide what is known about a condition, consulting both mechanisms:
   the structural rules over the known conditions `conds` (applied to the
   expression `cond`) and the stamp rules of tryFold (applied to the node
   `condNode`).
   Returns Some True / Some False when the respective test succeeds, checking
   for True first, and None when neither test succeeds. *)
fun conds_implies :: "IRExpr set  (ID  Stamp)  IRNode  IRExpr  bool option" where
  "conds_implies conds stamps condNode cond = 
    (if condset_implies conds cond True  tryFold condNode stamps True 
      then Some True
    else if condset_implies conds cond False  tryFold condNode stamps False
      then Some False
    else None)"

text ‹
Perform conditional elimination rewrites on the graph for a particular node
by lifting the individual implication rules to a relation that rewrites the
condition of \textsl{if} statements to constant values.

In order to determine conditional eliminations appropriately the rule needs two
data structures produced by static analysis.
The first parameter is the set of condition expressions (\texttt{IRExpr}s) that we
know result in a true value when evaluated.
The second parameter is a mapping from node identifiers to the flow-sensitive stamp.
›

(* One conditional-elimination rewrite at node `ifcond`, relating the graph
   before (g) and after (g').
   - impliesTrue / impliesFalse: ifcond is an IfNode whose condition node cid
     has expression `cond`, and conds_implies decides it; g' is g with the
     IfNode rewritten by constantCondition to the decided constant.
   - unknown: conds_implies returns None; the graph is left unchanged.
   - notIfNode: ifcond is not an IfNode; the graph is left unchanged. *)
inductive ConditionalEliminationStep :: 
  "IRExpr set  (ID  Stamp)  ID  IRGraph  IRGraph  bool"
  where
  impliesTrue:
  "kind g ifcond = (IfNode cid t f);
    g  cid  cond; 
    condNode = kind g cid;
    conds_implies conds stamps condNode cond = (Some True);
    g' = constantCondition True ifcond (kind g ifcond) g
      ConditionalEliminationStep conds stamps ifcond g g'" |

  impliesFalse:
  "kind g ifcond = (IfNode cid t f);
    g  cid  cond;
    condNode = kind g cid;
    conds_implies conds stamps condNode cond = (Some False);
    g' = constantCondition False ifcond (kind g ifcond) g
      ConditionalEliminationStep conds stamps ifcond g g'" |

  unknown:
  "kind g ifcond = (IfNode cid t f);
    g  cid  cond; 
    condNode = kind g cid;
    conds_implies conds stamps condNode cond = None
      ConditionalEliminationStep conds stamps ifcond g g" |

  notIfNode:
  "¬(is_IfNode (kind g ifcond)) 
    ConditionalEliminationStep conds stamps ifcond g g"


(* Generate executable code: the first four arguments are inputs, the
   resulting graph is the output. *)
code_pred (modes: i  i  i  i  o  bool) ConditionalEliminationStep .

(* Display the equation produced by the code generator setup above. *)
thm ConditionalEliminationStep.equation



subsection ‹Control-flow Graph Traversal \label{sec:traversal}›

(* Abbreviations for the state carried by the graph traversal. *)
(* Set of node identifiers. *)
type_synonym Seen = "ID set"
(* A single condition expression. *)
type_synonym Condition = "IRExpr"
(* List of condition expressions. *)
type_synonym Conditions = "Condition list"
(* List of stamp maps (node identifier to stamp). *)
type_synonym StampFlow = "(ID  Stamp) list"
(* List of node identifiers. *)
type_synonym ToVisit = "ID list"


text ‹@{term "nextEdge"} helps determine which node to traverse next 
by returning the first successor edge that isn't in the set of already visited nodes.
If there is not an appropriate successor, None is returned instead.
›
(* Filter the successors of node nid (successors_of (kind g nid)) by a test
   against the `seen` set and return the first survivor as Some, or None if
   the filtered list is empty. *)
fun nextEdge :: "Seen  ID  IRGraph  ID option" where
  "nextEdge seen nid g = 
    (let nids = (filter (λnid'. nid'  seen) (successors_of (kind g nid))) in 
     (if length nids > 0 then Some (hd nids) else None))"

text ‹@{term "pred"} determines which node, if any, acts as the predecessor of another.

Merge nodes represent a special case wherein the predecessors exist as
input edges of the merge node; to simplify the traversal we treat only
the first input end node as the predecessor, ignoring that multiple nodes
may act as a predecessor.

For all other nodes, the predecessor is the first element of the predecessors set.
Note that in a well-formed graph there should only be one element in the predecessor set.
›
(* List the predecessors of node nid.
   For a MergeNode this is its `ends` input list; for every other node kind
   it is the predecessor set of nid in g, turned into a list by
   sorted_list_of_set. *)
fun preds :: "IRGraph  ID  ID list" where
  "preds g nid = (case kind g nid of
    (MergeNode ends _ _)  ends |
    _  
      sorted_list_of_set (IRGraph.predecessors g nid)
  )"

(* First element of `preds g nid` wrapped in Some, or None when the node has
   no predecessors. *)
fun pred :: "IRGraph  ID  ID option" where
  "pred g nid = (case preds g nid of []  None | x # xs  Some x)"


text ‹
When the basic block of an if statement is entered, we know that the condition of the
preceding if statement must be true.
As in the GraalVM compiler, we introduce the \texttt{registerNewCondition} function
which roughly corresponds to \texttt{ConditionalEliminationPhase.registerNewCondition}.
This method updates the flow-sensitive stamp information based on the condition which
we know must be true. 
›
(* Lower the upper bound of an integer stamp to c when c is strictly below
   the current upper bound h; otherwise return the stamp unchanged.
   Any non-integer stamp is returned unchanged. *)
fun clip_upper :: "Stamp  int  Stamp" where
  "clip_upper (IntegerStamp b l h) c = 
          (if c < h then (IntegerStamp b l c) else (IntegerStamp b l h))" |
  "clip_upper s c = s"
(* Raise the lower bound of an integer stamp to c when c is strictly above
   the current lower bound l; otherwise return the stamp unchanged.
   Any non-integer stamp is returned unchanged.
   The "no clipping needed" branch must keep the original upper bound h,
   mirroring clip_upper; it must not replace it with c. *)
fun clip_lower :: "Stamp ⇒ int ⇒ Stamp" where
  "clip_lower (IntegerStamp b l h) c = 
          (if l < c then (IntegerStamp b c h) else (IntegerStamp b l h))" |
  "clip_lower s c = s"

(* For two integer stamps, return the first stamp with its lower bound
   replaced by the larger of the two lower bounds (bit width b1 and upper
   bound xh are kept). For any other combination the first stamp is returned
   unchanged. *)
fun max_lower :: "Stamp  Stamp  Stamp" where
  "max_lower (IntegerStamp b1 xl xh) (IntegerStamp b2 yl yh) =
        (IntegerStamp b1 (max xl yl) xh)" |
  "max_lower xs ys = xs"
(* For two integer stamps, return a stamp with the SECOND stamp's lower bound
   yl and the smaller of the two upper bounds; note that the bit width is
   taken from the first stamp (b1). For any other combination the second
   stamp is returned unchanged. *)
fun min_higher :: "Stamp  Stamp  Stamp" where
  "min_higher (IntegerStamp b1 xl xh) (IntegerStamp b2 yl yh) =
        (IntegerStamp b1 yl (min xh yh))" |
  "min_higher xs ys = ys"

(* Refine the stamp map with the knowledge that the condition node passed as
   the second argument holds. Both updated entries are computed from the
   ORIGINAL stamp map, not from each other.
   - IntegerEqualsNode x y: both inputs receive the join of their stamps.
   - IntegerLessThanNode x y: x < y and y is at most upper(y), so x is at most
     upper(y) - 1; x is at least lower(x), so y is at least lower(x) + 1.
     (Clipping against lower(y) - 1 / upper(x) + 1 instead would claim more
     than x < y establishes and yields empty stamps for overlapping ranges.)
   - LogicNegationNode of an IntegerLessThanNode x y: the lower bound of x is
     raised via max_lower and the upper bound of y is lowered via min_higher;
     a negation of any other node leaves the map unchanged.
   - Any other node kind leaves the map unchanged. *)
fun registerNewCondition :: "IRGraph ⇒ IRNode ⇒ (ID ⇒ Stamp) ⇒ (ID ⇒ Stamp)" where
  ― ‹constrain equality by joining the stamps›
  "registerNewCondition g (IntegerEqualsNode x y) stamps =
    (stamps
      (x := join (stamps x) (stamps y)))
      (y := join (stamps x) (stamps y))" |
  ― ‹constrain less than: x stays below the upper bound of y, y stays above the lower bound of x›
  "registerNewCondition g (IntegerLessThanNode x y) stamps =
    (stamps
      (x := clip_upper (stamps x) ((stpi_upper (stamps y)) - 1)))
      (y := clip_lower (stamps y) ((stpi_lower (stamps x)) + 1))" |
  "registerNewCondition g (LogicNegationNode c) stamps =
    (case (kind g c) of
      (IntegerLessThanNode x y) ⇒
        (stamps
          (x := max_lower (stamps x) (stamps y)))
          (y := min_higher (stamps x) (stamps y))
       | _ ⇒ stamps)" |
  "registerNewCondition g _ stamps = stamps"

(* Head of a list, or the supplied default `de` when the list is empty. *)
fun hdOr :: "'a list  'a  'a" where
  "hdOr (x # xs) de = x" |
  "hdOr [] de = de"

(*
fun isCFGNode :: "IRNode ⇒ bool" where
  "isCFGNode (BeginNode _) = True" |
  "isCFGNode (EndNode) = True" |
  "isCFGNode _ = False"

inductive CFGSuccessor ::
  "IRGraph ⇒ (ID × Seen × ToVisit) ⇒ (ID × Seen × ToVisit) ⇒ bool"
  for g where
  ― ‹
  Forward traversal transitively through successors until
  a CFG node is reached.›
  "⟦Some nid' = nextEdge seen nid g;
    ¬(isCFGNode (kind g nid'));
    CFGSuccessor g (nid', {nid} ∪ seen, nid # toVisit) (nid'', seen', toVisit')⟧ 
    ⟹ CFGSuccessor g (nid, seen, toVisit) (nid'', seen', toVisit')" |
  "⟦Some nid' = nextEdge seen nid g;
    isCFGNode (kind g nid')⟧
    ⟹ CFGSuccessor g (nid, seen, toVisit) (nid', {nid} ∪ seen, nid # toVisit)" |

  ― ‹
  Backwards traversal transitively through toVisit stack until
  a CFG node is reached.›
  "⟦toVisit = nid' # toVisit';
    CFGSuccessor g (nid', {nid} ∪ seen, nid # toVisit) (nid'', seen', toVisit')⟧ 
    ⟹ CFGSuccessor g (nid, seen, toVisit) (nid'', seen', toVisit')"

code_pred (modes: i ⇒ i ⇒ o ⇒ bool) CFGSuccessor .
*)

(* Partial map from a node identifier to a set of node identifiers; used by
   the `dominators` relation to remember results already computed. *)
type_synonym DominatorCache = "(ID, ID set) map"

(* Two mutually inductive relations that compute dominator sets while
   threading a cache through the computation.
   dominators g c nid (doms, c'):
     - a node without predecessors has the result {nid}, cache unchanged;
     - a node with no cache entry and at least one predecessor gets its
       result from {nid} combined with the sets collected by dominators_all
       over `preds g nid`, and that result is stored in the cache under nid
       (presumably union with the intersection of the collected sets -
       confirm, the set operators are not legible here);
     - a node with a cache entry returns the cached set, cache unchanged.
   dominators_all g c nid doms pre c' doms' pre':
     walks the list `pre`, calling `dominators` on each element and adding
     each resulting set to the accumulator `doms`; it stops when the list is
     empty, returning the cache, accumulator and list as they then stand. *)
inductive 
  dominators_all :: "IRGraph  DominatorCache  ID  ID set set  ID list  DominatorCache  ID set set  ID list  bool" and
  dominators :: "IRGraph  DominatorCache  ID  (ID set × DominatorCache)  bool" where

  "pre = []
     dominators_all g c nid doms pre c doms pre" |

  "pre = pr # xs;
    (dominators g c pr (doms', c'));
    dominators_all g c' pr (doms  {doms'}) xs c'' doms'' pre'
     dominators_all g c nid doms pre c'' doms'' pre'" |

  "preds g nid = []
     dominators g c nid ({nid}, c)" |
  
  "c nid = None;
    preds g nid = x # xs;
    dominators_all g c nid {} (preds g nid) c' doms pre';
    c'' = c'(nid  ({nid}  (doms)))
     dominators g c nid (({nid}  (doms)), c'')" |

  "c nid = Some doms
     dominators g c nid (doms, c)"

― ‹
Trying to simplify by removing the 3rd case won't work.
A base case for root nodes is required as @{term "{} = coset []"}
which swallows anything unioned with it.
›
(* Scratch evaluations of small set expressions; they define nothing.
   They appear to illustrate the remark above about the empty collection
   (presumably - confirm). *)
value "({}::nat set set)"
value "- ({}::nat set set)"
value "({{}, {0}}::nat set set)"
value "{0::nat}  ({})"

(* Generate executable code for both relations: the leading arguments are
   inputs, the trailing ones (marked o) are computed as outputs. *)
code_pred (modes: i  i  i  i  i  o  o  o  bool) dominators_all .
code_pred (modes: i  i  i  o  bool) dominators .

(* initial: ConditionalEliminationTest13_testSnippet2

   Example graph used by the values queries in this theory.
   Each entry is (node ID, node, stamp). The graph contains two IfNodes:
   node 8 (condition 5, an IntegerLessThanNode over nodes 1 and 4) and
   node 13 (condition 10, an IntegerEqualsNode over nodes 1 and 9).
   Node 13 is reached through BeginNode 6, which is the third argument of
   IfNode 8 (presumably its false successor, following the
   "IfNode cond t f" naming used elsewhere in this theory). *)
definition ConditionalEliminationTest13_testSnippet2_initial :: IRGraph where
  "ConditionalEliminationTest13_testSnippet2_initial = irgraph [
  (0, (StartNode (Some 2) 8), VoidStamp),
  (1, (ParameterNode 0), IntegerStamp 32 (-2147483648) (2147483647)),
  (2, (FrameState [] None None None), IllegalStamp),
  (3, (ConstantNode (new_int 32 (0))), IntegerStamp 32 (0) (0)),
  (4, (ConstantNode (new_int 32 (1))), IntegerStamp 32 (1) (1)),
  (5, (IntegerLessThanNode 1 4), VoidStamp),
  (6, (BeginNode 13), VoidStamp),
  (7, (BeginNode 23), VoidStamp),
  (8, (IfNode 5 7 6), VoidStamp),
  (9, (ConstantNode (new_int 32 (-1))), IntegerStamp 32 (-1) (-1)),
  (10, (IntegerEqualsNode 1 9), VoidStamp),
  (11, (BeginNode 17), VoidStamp),
  (12, (BeginNode 15), VoidStamp),
  (13, (IfNode 10 12 11), VoidStamp),
  (14, (ConstantNode (new_int 32 (-2))), IntegerStamp 32 (-2) (-2)),
  (15, (StoreFieldNode 15 ''org.graalvm.compiler.core.test.ConditionalEliminationTestBase::sink2'' 14 (Some 16) None 19), VoidStamp),
  (16, (FrameState [] None None None), IllegalStamp),
  (17, (EndNode), VoidStamp),
  (18, (MergeNode [17, 19] (Some 20) 21), VoidStamp),
  (19, (EndNode), VoidStamp),
  (20, (FrameState [] None None None), IllegalStamp),
  (21, (StoreFieldNode 21 ''org.graalvm.compiler.core.test.ConditionalEliminationTestBase::sink1'' 3 (Some 22) None 25), VoidStamp),
  (22, (FrameState [] None None None), IllegalStamp),
  (23, (EndNode), VoidStamp),
  (24, (MergeNode [23, 25] (Some 26) 27), VoidStamp),
  (25, (EndNode), VoidStamp),
  (26, (FrameState [] None None None), IllegalStamp),
  (27, (StoreFieldNode 27 ''org.graalvm.compiler.core.test.ConditionalEliminationTestBase::sink0'' 9 (Some 28) None 29), VoidStamp),
  (28, (FrameState [] None None None), IllegalStamp),
  (29, (ReturnNode None None), VoidStamp)
  ]"

(* :(
fun dominators :: "IRGraph ⇒ ID ⇒ ID set" where
  "dominators g nid = {nid} ∪ (⋂ y ∈ preds g nid. dominators g y)"
*)

(* Run the dominator computation for node 25 of the example graph, starting
   from an empty cache, and show the entry the resulting cache holds for
   node 13. *)
values "{(snd x) 13| x. dominators ConditionalEliminationTest13_testSnippet2_initial Map.empty 25 x}"

(*fun condition_of :: "IRGraph ⇒ ID ⇒ ID option" where
  "condition_of g nid = (case (kind g nid) of
    (IfNode c t f) ⇒ Some c |
    _ ⇒ None)"*)

(* condition_of g nid res: the branch condition known to hold on entry to
   node nid, if any.

   When the predecessor of nid (as given by pred) is an IfNode, the result is
   Some (ce', c) where
     - i is the index of nid among the successors of that IfNode,
     - for i = 0 the condition is taken as is: ce' is the expression
       representing cond and c is the condition node itself,
     - otherwise the condition is negated: ce' is wrapped in
       UnaryLogicNegation and c is LogicNegationNode cond.
   In all other cases (no predecessor, or a predecessor that is not an
   IfNode) the result is None. *)
inductive
  condition_of :: "IRGraph ⇒ ID ⇒ (IRExpr × IRNode) option ⇒ bool" where
  "⟦Some ifcond = pred g nid;
    kind g ifcond = IfNode cond t f;

    i = find_index nid (successors_of (kind g ifcond));
    c = (if i = 0 then kind g cond else LogicNegationNode cond);
    rep g cond ce;
    ce' = (if i = 0 then ce else UnaryExpr UnaryLogicNegation ce)⟧
  ⟹ condition_of g nid (Some (ce', c))" |

  "pred g nid = None ⟹ condition_of g nid None" |
  "⟦pred g nid = Some nid';
    ¬(is_IfNode (kind g nid'))⟧ ⟹ condition_of g nid None"

(* Executable version: graph and node are inputs, the optional condition is
   the output. *)
code_pred (modes: i ⇒ i ⇒ o ⇒ bool) condition_of .

(*inductive
  conditions_of_dominators :: "IRGraph ⇒ ID list ⇒ Conditions ⇒ Conditions ⇒ bool" where
  "⟦nids = []⟧
    ⟹ conditions_of_dominators g nids conditions conditions" |

  "⟦nids = nid # nids';
    condition_of g nid (Some (expr, _));
    conditions_of_dominators g nids' (expr # conditions) conditions'⟧
    ⟹ conditions_of_dominators g nids conditions conditions'" |

  "⟦nids = nid # nids';
    condition_of g nid None;
    conditions_of_dominators g nids' conditions conditions'⟧
    ⟹ conditions_of_dominators g nids conditions conditions'"*)

(* Collect the branch conditions attached to a list of nodes.
   For every node in the list the generated function condition_of_i_i_o is
   queried; when it yields a condition, its expression component is pushed
   onto the accumulator cds, otherwise the node is skipped. Conditions of
   later nodes therefore end up nearer the head of the result.
   Predicate.the is used to extract the result of the query, so this relies
   on condition_of producing a single result for each node. *)
fun conditions_of_dominators :: "IRGraph ⇒ ID list ⇒ Conditions ⇒ Conditions" where
  "conditions_of_dominators g [] cds = cds" |
  "conditions_of_dominators g (nid # nids) cds = 
    (case (Predicate.the (condition_of_i_i_o g nid)) of 
      None ⇒ conditions_of_dominators g nids cds |
      Some (expr, _) ⇒ conditions_of_dominators g nids (expr # cds))"

(*code_pred (modes: i ⇒ i ⇒ i ⇒ o ⇒ bool) conditions_of_dominators .*)

(*
inductive
  stamps_of_dominators :: "IRGraph ⇒ ID list ⇒ StampFlow ⇒ StampFlow ⇒ bool" where
  "⟦nids = []⟧
    ⟹ stamps_of_dominators g nids stamps stamps" |

  "⟦nids = nid # nids';
    condition_of g nid (Some (_, node));
    he = registerNewCondition g node (hd stamps);
    stamps_of_dominators g nids' (he # stamps) stamps'⟧
    ⟹ stamps_of_dominators g nids stamps stamps'" |

  "⟦nids = nid # nids';
    condition_of g nid None;
    stamps_of_dominators g nids' stamps stamps'⟧
    ⟹ stamps_of_dominators g nids stamps stamps'"
*)

(* Build the stack of stamp information induced by the conditions attached to
   a list of nodes.
   For every node with a condition, registerNewCondition refines the stamp
   information currently at the head of the stack, and the refined version is
   pushed on top. Nodes without a condition leave the stack unchanged.

   The stack must be non-empty whenever a condition is found, because its
   head is taken with hd; analyse seeds it with [stamp g]. *)
fun stamps_of_dominators :: "IRGraph ⇒ ID list ⇒ StampFlow ⇒ StampFlow" where
  "stamps_of_dominators g [] stamps = stamps" |
  "stamps_of_dominators g (nid # nids) stamps = 
    (case (Predicate.the (condition_of_i_i_o g nid)) of 
      None ⇒ stamps_of_dominators g nids stamps |
      Some (_, node) ⇒ stamps_of_dominators g nids 
        ((registerNewCondition g node (hd stamps)) # stamps))"

(*code_pred (modes: i ⇒ i ⇒ i ⇒ o ⇒ bool) stamps_of_dominators .*)

(* analyse g c nid (conds, stamps, c'): gather everything known on entry to
   node nid.
     - doms are the dominators of nid (computed with cache c, giving c'),
     - conds are the branch conditions of those dominators, visited in the
       order given by sorted_list_of_set,
     - stamps is the stamp stack obtained from the same traversal, seeded
       with the stamps of the graph itself. *)
inductive
  analyse :: "IRGraph ⇒ DominatorCache ⇒ ID ⇒ (Conditions × StampFlow × DominatorCache) ⇒ bool" where
  "⟦dominators g c nid (doms, c');
    conditions_of_dominators g (sorted_list_of_set doms) [] = conds;
    stamps_of_dominators g (sorted_list_of_set doms) [stamp g] = stamps⟧
    ⟹ analyse g c nid (conds, stamps, c')"

(* Executable version: graph, cache and node are inputs, the result triple is
   the output. *)
code_pred (modes: i ⇒ i ⇒ i ⇒ o ⇒ bool) analyse .

(* Example queries on the test graph, all starting from an empty cache:
   the dominators of node 13, the full analyse result for node 13, and the
   entry for node 1 in the top-most stamp information at nodes 13 and 27. *)
values "{x. dominators ConditionalEliminationTest13_testSnippet2_initial Map.empty 13 x}"
values "{(conds, stamps, c). 
analyse ConditionalEliminationTest13_testSnippet2_initial Map.empty 13 (conds, stamps, c)}"
values "{(hd stamps) 1| conds stamps c .
analyse ConditionalEliminationTest13_testSnippet2_initial Map.empty 13 (conds, stamps, c)}"
values "{(hd stamps) 1| conds stamps c .
analyse ConditionalEliminationTest13_testSnippet2_initial Map.empty 27 (conds, stamps, c)}"

(* Pick the node to visit after nid.
   For an EndNode the next node is taken from any_usage g nid (presumably the
   merge node that uses this end; confirm against any_usage). For every other
   node kind the choice is delegated to nextEdge, which takes the set of
   already seen nodes into account. *)
fun next_nid :: "IRGraph ⇒ ID set ⇒ ID ⇒ ID option" where
  "next_nid g seen nid = (case (kind g nid) of
    (EndNode) ⇒ Some (any_usage g nid) |
    _ ⇒ nextEdge seen nid g)"

(* Step g (nid, seen) res: one traversal step from node nid.
   The current node is first added to the seen set. The result is
   Some (nid', seen') when next_nid proposes a node that has not been seen
   yet, and None when there is no proposal or the proposed node has already
   been visited. *)
inductive Step
  :: "IRGraph ⇒ (ID × Seen) ⇒ (ID × Seen) option ⇒ bool"
  for g where
  ― ‹We can find a successor edge that is not in seen, go there›
  "⟦seen' = {nid} ∪ seen;

    Some nid' = next_nid g seen' nid;
    nid' ∉ seen'⟧
    ⟹ Step g (nid, seen) (Some (nid', seen'))" |

  ― ‹We cannot find a successor edge that is not in seen, give back None›
  "⟦seen' = {nid} ∪ seen;

    None = next_nid g seen' nid⟧
    ⟹ Step g (nid, seen) None" |

  ― ‹We've already seen this node, give back None›
  "⟦seen' = {nid} ∪ seen;

    Some nid' = next_nid g seen' nid;
    nid' ∈ seen'⟧ ⟹ Step g (nid, seen) None"

(* Executable version: the graph and the current state are inputs, the
   optional successor state is the output. *)
code_pred (modes: i ⇒ i ⇒ o ⇒ bool) Step .

(* Choose the next node to optimise: the smallest node ID of the graph that
   is not yet in seen (the unseen IDs are sorted by sorted_list_of_set and
   the head is taken). Returns that ID together with the seen set extended by
   it, or None once every node of the graph has been seen. *)
fun nextNode :: "IRGraph ⇒ Seen ⇒ (ID × Seen) option" where
  "nextNode g seen = 
    (let toSee = sorted_list_of_set {n ∈ ids g. n ∉ seen} in
      case toSee of [] ⇒ None | (x # xs) ⇒ Some (x, seen ∪ {x}))"

(* Example: a single traversal step from node 17 of the test graph with the
   given set of already seen nodes. *)
values "{x. Step ConditionalEliminationTest13_testSnippet2_initial (17, {17,11,25,21,18,19,15,12,13,6,29,27,24,23,7,8,0}) x}"


text ‹
The @{text "ConditionalEliminationPhase"} relation is responsible for combining
the individual traversal steps from the @{text "Step"} relation and the optimizations
from the @{text "ConditionalEliminationStep"} relation to perform a transformation of the
whole graph.
›

(* ConditionalEliminationPhase (seen, c) g g': g' is the result of running
   conditional elimination over every node of g not yet in seen.
   Nodes are chosen by nextNode; for each one the dominating conditions and
   stamp information are computed by analyse, and ConditionalEliminationStep
   is applied with the set of conditions and the top of the stamp stack.
   The seen set and dominator cache are threaded through the recursion.
   Note that analyse runs on the current (possibly already rewritten) graph. *)
inductive ConditionalEliminationPhase 
  :: "(Seen × DominatorCache) ⇒ IRGraph ⇒ IRGraph ⇒ bool"
  where

  ― ‹Can do a step and optimise for the current node›
  "⟦nextNode g seen = Some (nid, seen');
    
    analyse g c nid (conds, flow, c');
    ConditionalEliminationStep (set conds) (hd flow) nid g g';

    ConditionalEliminationPhase (seen', c') g' g''⟧
    ⟹ ConditionalEliminationPhase (seen, c) g g''" |

  ― ‹Every node has been visited: the graph is returned unchanged›
  "⟦nextNode g seen = None⟧
    ⟹ ConditionalEliminationPhase (seen, c) g g"

(* Executable version: the traversal state and input graph are inputs, the
   optimised graph is the output. *)
code_pred (modes: i ⇒ i ⇒ o ⇒ bool) ConditionalEliminationPhase . 

(* Functional entry point: run the whole phase on g, starting with no nodes
   seen and an empty dominator cache, and extract the resulting graph from
   the generated predicate with Predicate.the (which presupposes that the
   phase yields a single result). *)
definition runConditionalElimination :: "IRGraph ⇒ IRGraph" where
  "runConditionalElimination g = 
    (Predicate.the (ConditionalEliminationPhase_i_i_o ({}, Map.empty) g))"


(* Example queries on the test graph: the dominators and the analyse result
   for node 6 (both from an empty cache), and nextNode when all IDs 0..29 are
   already in the seen set. *)
values "{(doms, c')| doms c'.
dominators ConditionalEliminationTest13_testSnippet2_initial Map.empty 6 (doms, c')}"

values "{(conds, stamps, c)| conds stamps c .
analyse ConditionalEliminationTest13_testSnippet2_initial Map.empty 6 (conds, stamps, c)}"
value "
  (nextNode
      ConditionalEliminationTest13_testSnippet2_initial {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29})
"
(*
values "{g|g. (ConditionalEliminationPhase ({}, Map.empty) ConditionalEliminationTest13_testSnippet2_initial g)}"
*)
(*
inductive ConditionalEliminationPhaseWithTrace✐‹tag invisible›
  :: "IRGraph ⇒ (ID × Seen × Conditions × StampFlow) ⇒ ID list ⇒ IRGraph ⇒ ID list ⇒ Conditions ⇒ bool" where✐‹tag invisible›

  (* Can do a step and optimise for the current nid *)
  "⟦Step g (nid, seen, conds, flow) (Some (nid', seen', conds', flow'));
    ConditionalEliminationStep (set conds) (hdOr flow (stamp g)) nid g g';
    
    ConditionalEliminationPhaseWithTrace g' (nid', seen', conds', flow') (nid # t) g'' t' conds''⟧
    ⟹ ConditionalEliminationPhaseWithTrace g (nid, seen, conds, flow) t g'' t' conds''" |

  (* Can do a step, matches whether optimised or not causing non-determinism
     Need to find a way to negate ConditionalEliminationStep *)
  "⟦Step g (nid, seen, conds, flow) (Some (nid', seen', conds', flow'));
    
    ConditionalEliminationPhaseWithTrace g (nid', seen', conds', flow') (nid # t) g' t' conds''⟧
    ⟹ ConditionalEliminationPhaseWithTrace g (nid, seen, conds, flow) t g' t' conds''" |

  (* Can't do a step but there is a predecessor we can backtrace to *)
  "⟦Step g (nid, seen, conds, flow) None;
    Some nid' = pred g nid;
    seen' = {nid} ∪ seen;
    ConditionalEliminationPhaseWithTrace g (nid', seen', conds, flow) (nid # t) g' t' conds'⟧
    ⟹ ConditionalEliminationPhaseWithTrace g (nid, seen, conds, flow) t g' t' conds'" |

  (* Can't do a step and have no predecessors, so terminate *)
  "⟦Step g (nid, seen, conds, flow) None;
    None = pred g nid⟧
    ⟹ ConditionalEliminationPhaseWithTrace g (nid, seen, conds, flow) t g (nid # t) conds"

code_pred (modes: i ⇒ i ⇒ i ⇒ o ⇒ o ⇒ o ⇒ bool) ConditionalEliminationPhaseWithTrace .
*)

(* Fact about a step taken from an IfNode: the statement relates the step
   from nid to the node kind "IfNode cond tb fb", the chosen target
   "if val_to_bool val then tb else fb", the evaluation of cond to val, and
   the method state being unchanged (m' = m).
   Proved by smt from StepE, the IfNode step rule and stepDet.
   NOTE(review): the logical symbols of the statement appear to be missing in
   this copy of the file; confirm against the upstream source. *)
lemma IfNodeStepE: "g, p  (nid, m, h)  (nid', m', h) 
  (cond tb fb val.
        kind g nid = IfNode cond tb fb 
        nid' = (if val_to_bool val then tb else fb)  
        [g, m, p]  cond  val  m' = m)"
  using StepE
  by (smt (verit, best) IfNode Pair_inject stepDet)

(* Given the stuttering-step assumption on nid and that nid is an
   "IfNode cond t f", the condition cond has an evaluation result v.
   Proved from IfNodeStepE and case analysis on stutter.
   NOTE(review): symbols in the statement appear to be missing in this copy;
   confirm against the upstream source. *)
lemma ifNodeHasCondEvalStutter:
  assumes "(g m p h  nid  nid')"
  assumes "kind g nid = IfNode cond t f"
  shows " v. ([g, m, p]  cond  v)"
  using IfNodeStepE assms(1) assms(2)  stutter.cases unfolding encodeeval.simps
  by (smt (verit, ccfv_SIG) IfNodeCond)

(* Variant of ifNodeHasCondEvalStutter for a single step between full
   configurations (nid, m, h) and (nid', m', h'): the condition of the
   IfNode at nid has an evaluation result.
   The smt call depends on the listed IRNode.distinct/disc facts by number,
   so it is sensitive to changes of the IRNode datatype.
   NOTE(review): symbols in the statement appear to be missing in this copy;
   confirm against the upstream source. *)
lemma ifNodeHasCondEval:
  assumes "(g, p  (nid, m, h)  (nid', m', h'))"
  assumes "kind g nid = IfNode cond t f"
  shows " v. ([g, m, p]  cond  v)"
  using IfNodeStepE assms(1) assms(2) apply auto[1]
  by (smt (verit) IRNode.disc(1966) IRNode.distinct(1733) IRNode.distinct(1735) IRNode.distinct(1755) IRNode.distinct(1757) IRNode.distinct(1777) IRNode.distinct(1783) IRNode.distinct(1787) IRNode.distinct(1789) IRNode.distinct(401) IRNode.distinct(755) StutterStep fst_conv ifNodeHasCondEvalStutter is_AbstractEndNode.simps is_EndNode.simps(16) snd_conv step.cases)

(* Replacing the usages of an IfNode by its true successor tb relates the
   stuttering steps from nid in g and in g', provided the condition evaluates
   to a value that val_to_bool maps to True.
   NOTE(review): the connective between the two stutter judgements in the
   conclusion (and the quantifier on nid') is missing in this copy; confirm
   against the upstream source. *)
lemma replace_if_t:
  assumes "kind g nid = IfNode cond tb fb"
  assumes "[g, m, p]  cond  bool"
  assumes "val_to_bool bool"
  assumes g': "g' = replace_usages nid tb g"
  shows "nid' .(g m p h  nid  nid')  (g' m p h  nid  nid')"
proof -
  (* In the original graph the IfNode steps to its true successor. *)
  have g1step: "g, p  (nid, m, h)  (tb, m, h)"
    by (meson IfNode assms(1) assms(2) assms(3) encodeeval.simps)
  (* In the rewritten graph the same step is obtained via stepRefNode after
     unfolding replace_usages. *)
  have g2step: "g', p  (nid, m, h)  (tb, m, h)"
    using g' unfolding replace_usages.simps
    by (simp add: stepRefNode)
  from g1step g2step show ?thesis
    using StutterStep by blast
qed

(* Consequence of replace_if_t under the same assumptions; proved directly
   from it by blast.
   NOTE(review): as displayed, the conclusion is textually identical to that
   of replace_if_t because its symbols are missing; presumably this is the
   weaker (implication) form. Confirm against the upstream source. *)
lemma replace_if_t_imp:
  assumes "kind g nid = IfNode cond tb fb"
  assumes "[g, m, p]  cond  bool"
  assumes "val_to_bool bool"
  assumes g': "g' = replace_usages nid tb g"
  shows "nid' .(g m p h  nid  nid')  (g' m p h  nid  nid')"
  using replace_if_t assms by blast

(* Counterpart of replace_if_t for the false branch: when the condition
   evaluates to a value that val_to_bool maps to False, replacing the usages
   of the IfNode by its false successor fb relates the stuttering steps from
   nid in g and in g'.
   NOTE(review): symbols in the statement appear to be missing in this copy;
   confirm against the upstream source. *)
lemma replace_if_f:
  assumes "kind g nid = IfNode cond tb fb"
  assumes "[g, m, p]  cond  bool"
  assumes "¬(val_to_bool bool)"
  assumes g': "g' = replace_usages nid fb g"
  shows "nid' .(g m p h  nid  nid')  (g' m p h  nid  nid')"
proof -
  (* In the original graph the IfNode steps to its false successor. *)
  have g1step: "g, p  (nid, m, h)  (fb, m, h)"
    by (meson IfNode assms(1) assms(2) assms(3) encodeeval.simps)
  (* The rewritten graph takes the same step (stepRefNode). *)
  have g2step: "g', p  (nid, m, h)  (fb, m, h)"
    using g' unfolding replace_usages.simps
    by (simp add: stepRefNode)
  from g1step g2step show ?thesis
    using StutterStep by blast
qed

text ‹
Prove that the individual conditional elimination rules are correct
with respect to preservation of stuttering steps.
›
(* Assumptions: g is well formed (graph, stamps, values), nid is a node of g,
   every condition in conds satisfies the stated evaluation property, and g'
   results from one ConditionalEliminationStep at nid.
   The proof is by induction over the ConditionalEliminationStep rules:
     - impliesTrue / impliesFalse: case split on whether the IfNode has a
       stuttering step in g, then use the constantCondition* lemmas;
     - unknown / notIfNode: discharged by blast.
   NOTE(review): symbols in the statement appear to be missing in this copy;
   confirm against the upstream source. *)
lemma ConditionalEliminationStepProof:
  assumes wg: "wf_graph g"
  assumes ws: "wf_stamps g"
  assumes wv: "wf_values g"
  assumes nid: "nid  ids g"
  assumes conds_valid: " c  conds .  v. ([m, p]  c  v)  val_to_bool v"
  assumes ce: "ConditionalEliminationStep conds stamps nid g g'"

  shows "nid' .(g m p h  nid  nid')  (g' m p h  nid  nid')"
  using ce using assms
proof (induct nid g g' rule: ConditionalEliminationStep.induct)
  case (impliesTrue g ifcond cid t f cond conds g')
  show ?case proof (cases "nid'. (g m p h  ifcond  nid')")
    case True
    show ?thesis
      by (metis StutterStep constantConditionNoIf constantConditionTrue impliesTrue.hyps(5))
  next
    case False
    then show ?thesis by auto
  qed
next
  case (impliesFalse g ifcond cid t f cond conds g')
  then show ?case 
  proof (cases "nid'. (g m p h  ifcond  nid')")
    case True
    then show ?thesis
      by (metis StutterStep constantConditionFalse constantConditionNoIf impliesFalse.hyps(5))
  next
    case False
    then show ?thesis
      by auto
  qed
next
  case (unknown g ifcond cid t f cond condNode conds stamps)
  then show ?case
    by blast
next
  case (notIfNode g ifcond conds stamps)
  then show ?case
    by blast
qed


text ‹
Prove that the individual conditional elimination rules are correct
with respect to finding a bisimulation between the unoptimized and optimized graphs.
›
(* Assumptions: well-formedness of g (graph, stamps, values), nid is a node
   of g, the conditions in conds satisfy the stated evaluation property, g'
   results from one ConditionalEliminationStep at nid, and g can already take
   a step from nid (gstep).
   The proof is by induction over the ConditionalEliminationStep rules:
     - impliesTrue / impliesFalse: obtain the step of the IfNode to t
       (resp. f) in g, show the same step in g' with constantConditionTrue
       (resp. constantConditionFalse), and conclude with stepDet and the
       introduction rule of strong_noop_bisimilar;
     - unknown / notIfNode: follow from strong_noop_bisimilar.simps.
   NOTE(review): symbols in the statement appear to be missing in this copy;
   confirm against the upstream source. *)
lemma ConditionalEliminationStepProofBisimulation:
  assumes wf: "wf_graph g  wf_stamp g stamps  wf_values g"
  assumes nid: "nid  ids g"
  assumes conds_valid: " c  conds .  v. ([m, p]  c  v)  val_to_bool v"
  assumes ce: "ConditionalEliminationStep conds stamps nid g g'"
  assumes gstep: " h nid'. (g, p  (nid, m, h)  (nid', m, h))" (* we don't yet consider optimizations which produce a step that didn't already exist *)

  shows "nid | g  g'"
  using ce gstep using assms
proof (induct nid g g' rule: ConditionalEliminationStep.induct)
  case (impliesTrue g ifcond cid t f cond condNode conds stamps g')
  from impliesTrue(5) obtain h where gstep: "g, p  (ifcond, m, h)  (t, m, h)"
    using IfNode encodeeval.simps ifNodeHasCondEval impliesTrue.hyps(1) impliesTrue.hyps(2) impliesTrue.hyps(3) impliesTrue.prems(4) implies_impliesnot_valid implies_valid.simps repDet
    by (smt (verit) conds_implies.elims condset_implies.simps impliesTrue.hyps(4) impliesTrue.prems(1) impliesTrue.prems(2) option.distinct(1) option.inject tryFoldTrue_valid)
  have "g', p  (ifcond, m, h)  (t, m, h)"
    using constantConditionTrue impliesTrue.hyps(1) impliesTrue.hyps(5) by blast
  then show ?case using gstep
    by (metis stepDet strong_noop_bisimilar.intros)
next
  case (impliesFalse g ifcond cid t f cond condNode conds stamps g')
  from impliesFalse(5) obtain h where gstep: "g, p  (ifcond, m, h)  (f, m, h)"
    using IfNode encodeeval.simps ifNodeHasCondEval impliesFalse.hyps(1) impliesFalse.hyps(2) impliesFalse.hyps(3) impliesFalse.prems(4) implies_impliesnot_valid impliesnot_valid.simps repDet
    by (smt (verit) conds_implies.elims condset_implies.simps impliesFalse.hyps(4) impliesFalse.prems(1) impliesFalse.prems(2) option.distinct(1) option.inject tryFoldFalse_valid)
  have "g', p  (ifcond, m, h)  (f, m, h)"
    using constantConditionFalse impliesFalse.hyps(1) impliesFalse.hyps(5) by blast
  then show ?case using gstep
    by (metis stepDet strong_noop_bisimilar.intros)
next
  case (unknown g ifcond cid t f cond condNode conds stamps)
  then show ?case
    using strong_noop_bisimilar.simps by presburger
next
  case (notIfNode g ifcond conds stamps)
  then show ?case
    using strong_noop_bisimilar.simps by presburger
qed


experiment begin
(*lemma if_step:
  assumes "nid ∈ ids g"
  assumes "(kind g nid) ∈ control_nodes"
  shows "(g m p h ⊢ nid ↝ nid')"
  using assms apply (cases "kind g nid") sorry
*)
(*
definition blockNodes :: "IRGraph ⇒ Block ⇒ ID set" where
  "blockNodes g b = {n ∈ ids g. blockOf g n = b}"

lemma phiInCFG:
  "∀n ∈ blockNodes g nid. (g, p ⊢ (n, m, h) → (n', m', h'))"
*)

(* Relates membership of n' in the successors of n to membership of n in the
   predecessors of n' (for n a node of g); holds by simp.
   NOTE(review): quantifier and connective symbols appear to be missing in
   this copy; confirm against the upstream source. *)
lemma inverse_succ:
  "n'  (succ g n). n  ids g  n  (predecessors g n')"
  by simp

(* A statement about the successor list of a sequential node compared with
   the empty list, proved by case analysis on the node kind.
   NOTE(review): the relation symbol between "successors_of n" and "[]" is
   missing in this copy (presumably inequality, given how nid'_succ uses this
   lemma); confirm against the upstream source. *)
lemma sequential_successors:
  assumes "is_sequential_node n"
  shows "successors_of n  []"
  using assms by (cases n; auto)

(* For a step from nid0 to nid where nid0 is not an abstract end node, the
   statement relates nid to "succ g nid0".
   The proof is by induction over the step relation, with one case per step
   rule. Apart from SequentialNode, IfNode and EndNodes, every case follows
   the same pattern: establish that the successor set of nid0 is the
   singleton {nid} from the corresponding successors_of_* fact, then finish
   with blast. The EndNodes case is discharged using the assumption that nid0
   is not an abstract end node.
   NOTE(review): membership symbols in the assumptions and conclusion appear
   to be missing in this copy; confirm against the upstream source. *)
lemma nid'_succ:
  assumes "nid  ids g"
  assumes "¬(is_AbstractEndNode (kind g nid0))"
  assumes "g, p  (nid0, m0, h0)  (nid, m, h)"
  shows "nid  succ g nid0"
  using assms(3) proof (induction "(nid0, m0, h0)" "(nid, m, h)" rule: step.induct)
  case SequentialNode
  then show ?case
    by (metis length_greater_0_conv nth_mem sequential_successors succ.simps)
next
  case (FixedGuardNode cond before val)
  then have "{nid} = succ g nid0"
    using IRNodes.successors_of_FixedGuardNode unfolding succ.simps
    by (metis empty_set list.simps(15))
  then show ?case
    using FixedGuardNode.hyps(5) by blast
next
  case (BytecodeExceptionNode args st exceptionType ref)
  then have "{nid} = succ g nid0"
    using IRNodes.successors_of_BytecodeExceptionNode unfolding succ.simps
    by (metis empty_set list.simps(15))
  then show ?case
    by blast
next
  (* The IfNode has two successors; the step target is one of them. *)
  case (IfNode cond tb fb val)
  then have "{tb, fb} = succ g nid0"
    using IRNodes.successors_of_IfNode unfolding succ.simps
    by (metis empty_set list.simps(15))
  then show ?case
    by (metis IfNode.hyps(3) insert_iff)
next
  (* Excluded by the assumption that nid0 is not an abstract end node. *)
  case (EndNodes i phis inps vs)
  then show ?case using assms(2) by blast
next
  case (NewArrayNode len st length' arrayType h' ref refNo)
  then have "{nid} = succ g nid0"
    using IRNodes.successors_of_NewArrayNode unfolding succ.simps
    by (metis empty_set list.simps(15))
  then show ?case
    by blast
next
  case (ArrayLengthNode x ref arrayVal length')
  then have "{nid} = succ g nid0"
    using IRNodes.successors_of_ArrayLengthNode unfolding succ.simps
    by (metis empty_set list.simps(15))
  then show ?case
    by blast
next
  case (LoadIndexedNode index guard array indexVal ref arrayVal loaded)
  then have "{nid} = succ g nid0"
    using IRNodes.successors_of_LoadIndexedNode unfolding succ.simps
    by (metis empty_set list.simps(15))
  then show ?case
    by blast
next
  case (StoreIndexedNode check val st index guard array indexVal ref "value" arrayVal updated)
  then have "{nid} = succ g nid0"
    using IRNodes.successors_of_StoreIndexedNode unfolding succ.simps
    by (metis empty_set list.simps(15))
  then show ?case
    by blast
next
  case (NewInstanceNode cname obj ref)
  then have "{nid} = succ g nid0"
    using IRNodes.successors_of_NewInstanceNode unfolding succ.simps
    by (metis empty_set list.simps(15))
  then show ?case
    by blast
next
  case (LoadFieldNode f obj ref)
  then have "{nid} = succ g nid0"
    using IRNodes.successors_of_LoadFieldNode unfolding succ.simps
    by (metis empty_set list.simps(15))
  then show ?case
    by blast
next
  case (SignedDivNode x y zero sb v1 v2)
  then have "{nid} = succ g nid0"
    using IRNodes.successors_of_SignedDivNode unfolding succ.simps
    by (metis empty_set list.simps(15))
  then show ?case
    by blast
next
  case (SignedRemNode x y zero sb v1 v2)
  then have "{nid} = succ g nid0"
    using IRNodes.successors_of_SignedRemNode unfolding succ.simps
    by (metis empty_set list.simps(15))
  then show ?case
    by blast
next
  (* The static variants reuse the successors_of facts of LoadFieldNode and
     StoreFieldNode. *)
  case (StaticLoadFieldNode f)
  then have "{nid} = succ g nid0"
    using IRNodes.successors_of_LoadFieldNode unfolding succ.simps
    by (metis empty_set list.simps(15))
  then show ?case
    by blast
next 
  case (StoreFieldNode _ _ _ _ _ _) 
  then have "{nid} = succ g nid0"
    using IRNodes.successors_of_StoreFieldNode unfolding succ.simps
    by (metis empty_set list.simps(15))
  then show ?case
    by blast
next
  case (StaticStoreFieldNode _ _ _ _)
  then have "{nid} = succ g nid0"
    using IRNodes.successors_of_StoreFieldNode unfolding succ.simps
    by (metis empty_set list.simps(15))
  then show ?case
    by blast
qed

(* Predecessor-side counterpart of nid'_succ: under the same assumptions the
   statement relates nid0 to "predecessors g nid". Obtained by combining
   nid'_succ with inverse_succ and step_in_ids.
   NOTE(review): membership symbols appear to be missing in this copy;
   confirm against the upstream source. *)
lemma nid'_pred:
  assumes "nid  ids g"
  assumes "¬(is_AbstractEndNode (kind g nid0))"
  assumes "g, p  (nid0, m0, h0)  (nid, m, h)"
  shows "nid0  predecessors g nid"
  using assms
  by (meson inverse_succ nid'_succ step_in_ids)

(* Well-formedness condition on predecessors: every node of the graph has
   exactly one predecessor (the predecessor set has cardinality 1). *)
definition wf_pred:
  "wf_pred g = (∀n ∈ ids g. card (predecessors g n) = 1)"

(* Unfinished (sorry): intended to connect the predecessor set with the
   result of pred for nodes that are not merge nodes, in graphs satisfying
   wf_pred.
   NOTE(review): the assumption constrains n' but the conclusion mentions
   "predecessors g n" alongside "pred g n'"; check whether n was meant to be
   n'. Symbols in the conclusion also appear to be missing in this copy. *)
lemma
  assumes "¬(is_AbstractMergeNode (kind g n'))"
  assumes "wf_pred g"
  shows "v. predecessors g n = {v}  pred g n' = Some v"
  using assms unfolding pred.simps sorry

(* Unfinished (sorry): variant of inverse_succ phrased with pred instead of
   predecessors, for successors n' that are not abstract end nodes in graphs
   satisfying wf_pred.
   NOTE(review): symbols in the conclusion appear to be missing in this copy;
   confirm against the upstream source. *)
lemma inverse_succ1:
  assumes "¬(is_AbstractEndNode (kind g n'))"
  assumes "wf_pred g"
  shows "n'  (succ g n). n  ids g  Some n = (pred g n')"
  using assms sorry

(* Unfinished (two sorry steps): for a step arriving at nid whose predecessor
   ifcond is an "IfNode cond t f", the statement relates the index i of nid
   among the successors of ifcond (i = 0) to the evaluation of cond and
   val_to_bool of its value.
   NOTE(review): connectives in the conclusion appear to be missing in this
   copy; confirm against the upstream source. *)
lemma BeginNodeFlow:
  assumes "g, p  (nid0, m0, h0)  (nid, m, h)"
  assumes "Some ifcond = pred g nid"
  assumes "kind g ifcond = IfNode cond t f"
  assumes "i = find_index nid (successors_of (kind g ifcond))"
  shows "i = 0  ([g, m, p]  cond  v)  val_to_bool v"
proof -
  (* The IfNode has exactly two successors. *)
  obtain tb fb where "[tb, fb] = successors_of (kind g ifcond)"
    by (simp add: assms(3))
  (* Still to be proved: the step must have originated at the IfNode. *)
  have "nid0 = ifcond"
    using assms step.IfNode sorry
  show ?thesis sorry
qed

(*
lemma StepConditionsValid:
  assumes "∀ cond ∈ set conds. ([m, p] ⊢ cond ↦ v) ⟶ val_to_bool v"
  assumes "g, p ⊢ (nid0, m0, h0) → (nid, m, h)"
  assumes "Step g (nid, seen, conds, flow) (Some (nid', seen', conds', flow'))"
  shows "∀ cond ∈ set conds'. ([m, p] ⊢ cond ↦ v) ⟶ val_to_bool v"
  using assms(3)
proof (induction "(nid, seen, conds, flow)" "Some (nid', seen', conds', flow')" rule: Step.induct)
  case (1 ifcond cond t f i c ce ce' flow')
  assume "∃cv. [m, p] ⊢ ce ↦ cv"
  then obtain cv where "[m, p] ⊢ ce ↦ cv"
    by blast
  have cvt: "val_to_bool cv"
    using assms(2) sorry
  have "set conds' = {c} ∪ set conds"
    using "1.hyps"(8) by auto
  then show ?case using cv cvt assms(1) sorry
next
  case (2)
  from 2(5) have "set conds' ⊆ set conds"
    by (metis list.sel(2) list.set_sel(2) subsetI)
  then show ?case using assms(1)
    by blast
next
case (3)
  then show ?case
    using assms(1) by force
qed

lemma ConditionalEliminationPhaseProof:
  assumes "wf_graph g"
  assumes "wf_stamps g"
  assumes "ConditionalEliminationPhase g (0, {}, [], []) g'"
  
  shows "∃nid' .(g m p h ⊢ 0 ↝ nid') ⟶ (g' m p h ⊢ 0 ↝ nid')"
proof -
  have "0 ∈ ids g"
    using assms(1) wf_folds by blast
  show ?thesis
using assms(3) assms proof (induct rule: ConditionalEliminationPhase.induct)
case (1 g nid g' succs nid' g'')
  then show ?case sorry
next
  case (2 succs g nid nid' g'')
  then show ?case sorry
next
  case (3 succs g nid)
  then show ?case 
    by simp
next
  case (4)
  then show ?case sorry
qed
qed
*)

end

end