Theory ConditionalPhase

subsection ‹ConditionalNode Phase›

theory ConditionalPhase
  imports
    Common
    Proofs.StampEvalThms
begin

phase ConditionalNode
  terminating size
begin

(* negates: relates the truth value of an integer value e (of the form
   IntVal b v with b > 0) to the truth value of its logical negation val[!e].
   The conclusion pairs val_to_bool (val[e]) with the negation of
   val_to_bool (val[!e]).
   NOTE(review): the binder and connectives of this statement are not visible
   in this text; presumably it reads: e is true exactly when !e is not true.
   Confirm against the rendered theory.
   Proved by one metis call over the defining equations of
   intval_logic_negation, logic_negate, new_int and val_to_bool, plus a few
   take_bit / of_bool arithmetic facts. *)
lemma negates: "v b. e = IntVal b v  b > 0  val_to_bool (val[e])  ¬(val_to_bool (val[!e]))"
  by (metis (mono_tags, lifting) intval_logic_negation.simps(1) logic_negate_def new_int.simps 
      of_bool_eq(2) one_neq_zero take_bit_of_0 take_bit_of_1 val_to_bool.simps(1))

(* negation_condition_intval: at the value level, a conditional whose condition
   is the negation of e equals the conditional on e with its two branches
   swapped.
   Assumptions: e is an integer value IntVal b ie, and its first component b
   (presumably the bit width - confirm) is positive.
   Follows from the defining equations of intval_conditional together with
   lemma negates. *)
lemma negation_condition_intval: 
  assumes "e = IntVal b ie"
  assumes "0 < b"
  shows "val[(!e) ? x : y] = val[e ? y : x]"
  by (metis assms intval_conditional.simps negates)

(* negation_preserve_eval: from an evaluation of the negated expression
   exp[!e] to a value v (in context m, p), recover an evaluation of e itself
   to a value v' such that v = val[!v'].
   NOTE(review): the binder in front of v' and the connective between the two
   conjuncts are not visible in this text; presumably an existential and a
   conjunction - confirm.
   Discharged by auto from the assumption. *)
lemma negation_preserve_eval:
  assumes "[m, p]  exp[!e]  v"
  shows "v'. ([m, p]  exp[e]  v')  v = val[!v']"
  using assms by auto

(* negation_preserve_eval_intval: variant of negation_preserve_eval that
   describes the shape of the operand value. If exp[!e] evaluates to v, then
   e evaluates to some v' that is an integer value IntVal b vv with b > 0.
   NOTE(review): binder and connectives are not visible in this text;
   presumably existential over v' b vv with conjunctions - confirm.
   Proved by metis from negation_preserve_eval, unfold_unary, eval_bits_1_64
   and the elimination rules of intval_logic_negation. *)
lemma negation_preserve_eval_intval:
  assumes "[m, p]  exp[!e]  v"
  shows "v' b vv. ([m, p]  exp[e]  v')  v' = IntVal b vv  b > 0"
  by (metis assms eval_bits_1_64 intval_logic_negation.elims negation_preserve_eval unfold_unary)

(* NegateConditionFlipBranches: rewrites a conditional on a negated condition,
   (!e) ? x : y, into the conditional on e with the branches swapped,
   e ? y : x.
   The proof simplifies, introduces three universally quantified variables and
   one implication, and then handles the remaining goal with its premises
   named p. *)
optimization NegateConditionFlipBranches: "((!e) ? x : y)  (e ? y : x)"
  apply simp apply (rule allI; rule allI; rule allI; rule impI)
  subgoal premises p for m p v
  proof -
    (* Value of the un-negated condition e, extracted from the premises. *)
    obtain ev where ev: "[m,p]  e  ev"
      using p by blast
    (* Name for the logical negation of that value. *)
    obtain notEv where notEv: "notEv = intval_logic_negation ev"
      by simp
    (* Value of the left-hand side conditional, written with explicit
       constructors. *)
    obtain lhs where lhs: "[m,p]  ConditionalExpr (UnaryExpr UnaryLogicNegation e) x y  lhs"
      using p by auto
    (* Values of both branches, obtained from the evaluation of the
       left-hand side. *)
    obtain xv where xv: "[m,p]  x  xv"
      using lhs by blast
    obtain yv where yv: "[m,p]  y  yv"
      using lhs by blast
    (* The closing smt call is chained on yv and otherwise cites the premises
       p, the negation lemmas above and the evaluation rules for conditionals;
       the facts ev, notEv, lhs and xv are not cited by name. *)
    then show ?thesis
      by (smt (z3) le_expr_def ConditionalExpr ConditionalExprE Value.distinct(1) evalDet negates p
          negation_preserve_eval negation_preserve_eval_intval)
  qed
  done

(* DefaultTrueBranch: a conditional whose condition is the literal true is
   replaced by its first branch x. The proof obligations are closed by the
   immediate proof step alone. *)
optimization DefaultTrueBranch: "(true ? x : y)  x" .

(* DefaultFalseBranch: a conditional whose condition is the literal false is
   replaced by its second branch y. The proof obligations are closed by the
   immediate proof step alone. *)
optimization DefaultFalseBranch: "(false ? x : y)  y" .

(* ConditionalEqualBranches: a conditional with the same expression x in both
   branches is replaced by x, whatever the condition e is. The proof
   obligations are closed by the immediate proof step alone. *)
optimization ConditionalEqualBranches: "(e ? x : x)  x" .

(* condition_bounds_x: the conditional (u < v) ? x : y is replaced by its
   first branch x.
   Side condition: stamp_under holds for the stamps of u and v (in that
   order), and both u and v satisfy wf_stamp. stamp_under is defined outside
   this file; presumably it expresses that the stamp of u lies entirely below
   the stamp of v - confirm.
   Proved by fastforce from stamp_under_defn. *)
optimization condition_bounds_x: "((u < v) ? x : y)  x 
    when (stamp_under (stamp_expr u) (stamp_expr v)  wf_stamp u  wf_stamp v)"
  using stamp_under_defn by fastforce

(* condition_bounds_y: the conditional (u < v) ? x : y is replaced by its
   second branch y.
   Side condition: stamp_under holds for the stamps of v and u (note the
   order, which is the reverse of condition_bounds_x), and both u and v
   satisfy wf_stamp.
   Proved by fastforce from stamp_under_defn_inverse. *)
optimization condition_bounds_y: "((u < v) ? x : y)  y 
    when (stamp_under (stamp_expr v) (stamp_expr u)  wf_stamp u  wf_stamp v)"
  using stamp_under_defn_inverse by fastforce

(** Start of new proofs **)

(* Value-level proofs *)
(* val_optimise_integer_test: value-level fact behind the OptimiseIntegerTest
   rewrite. For x of the form IntVal 32 v, selecting 0 when (x & 1) equals 0
   and 1 otherwise yields the same value as x & 1 itself.
   NOTE(review): the binder in the assumption is not visible in this text;
   presumably existential over v - confirm.
   After auto, the remaining goals are closed by two metis calls using the
   defining equations of bool_to_val and val_to_bool and parity facts
   (even_iff_mod_2_eq_zero, odd_iff_mod_2_eq_one, and_one_eq). *)
lemma val_optimise_integer_test: 
  assumes "v. x = IntVal 32 v"
  shows "val[((x & (IntVal 32 1)) eq (IntVal 32 0)) ? (IntVal 32 0) : (IntVal 32 1)] = 
         val[x & IntVal 32 1]"
  using assms apply auto
  apply (metis (full_types) bool_to_val.simps(2) val_to_bool.simps(1))
  by (metis (mono_tags, lifting) bool_to_val.simps(1) val_to_bool.simps(1) even_iff_mod_2_eq_zero
      odd_iff_mod_2_eq_one and_one_eq)

(* ConditionalEliminateKnownLess: the conditional (x < y) ? x : y is replaced
   by x.
   Side condition: stamp_under holds for the stamps of x and y (in that
   order), and both x and y satisfy wf_stamp.
   Same proof as condition_bounds_x: fastforce from stamp_under_defn. *)
optimization ConditionalEliminateKnownLess: "((x < y) ? x : y)  x 
                                 when (stamp_under (stamp_expr x) (stamp_expr y)
                                       wf_stamp x  wf_stamp y)"
  using stamp_under_defn by fastforce

(* ExpIntBecomesIntVal: a value v that is valid for the integer stamp
   IntegerStamp b xl xh has the form IntVal b xv for some xv.
   The assumptions also record that x carries that stamp, that x satisfies
   wf_stamp, and that x evaluates to v.
   NOTE(review): the binder before xv is not visible in this text; presumably
   existential - confirm.
   Proved by simp with valid_value_elims(3). *)
lemma ExpIntBecomesIntVal:
  assumes "stamp_expr x = IntegerStamp b xl xh"
  assumes "wf_stamp x"
  assumes "valid_value v (IntegerStamp b xl xh)"
  assumes "[m,p]  x  v"
  shows "xv. v = IntVal b xv"
  using assms by (simp add: IRTreeEvalThms.valid_value_elims(3))

(* Optimisations *)

(* intval_self_is_true: comparing an integer value with itself using
   intval_equals gives IntVal 32 1.
   Assumptions: yv is of the form IntVal b yvv; the first assumption relates
   yv to UndefVal (relation symbol not visible in this text; presumably
   inequality - confirm).
   Proved by case analysis on yv followed by auto. *)
lemma intval_self_is_true:
  assumes "yv  UndefVal"
  assumes "yv = IntVal b yvv"
  shows "intval_equals yv yv = IntVal 32 1"
  using assms by (cases yv; auto)

(* intval_commute: intval_equals gives the same result with its arguments in
   either order.
   Both assumptions relate the result of intval_equals (one per argument
   order) to UndefVal; the relation symbol is not visible in this text,
   presumably inequality - confirm.
   Proved by case analysis on both values, auto, and a final smt call. *)
lemma intval_commute:
  assumes "intval_equals yv xv  UndefVal"
  assumes "intval_equals xv yv  UndefVal"
  shows "intval_equals yv xv = intval_equals xv yv"
  using assms apply (cases yv; cases xv; auto) by (smt (verit, best))

(* isBoolean e: predicate on expressions stating that the values cond which e
   evaluates to (in contexts m, p) belong to the two-element set
   {IntVal 32 0, IntVal 32 1}.
   NOTE(review): the binder, the connective and the membership symbol are not
   visible in this text; presumably: for all m p cond, if e evaluates to cond
   then cond is a member of that set - confirm. Later proofs in this file use
   isBoolean_def to derive exactly this two-way case split. *)
definition isBoolean :: "IRExpr  bool" where
  "isBoolean e = (m p cond. (([m,p]  e  cond)  (cond  {IntVal 32 0, IntVal 32 1})))"

(* preserveBoolean: isBoolean is preserved by logical negation. If c satisfies
   isBoolean then so does exp[!c].
   Proved by unfolding isBoolean_def with auto, then a metis call using the
   IntVal0 / IntVal1 facts and the definition of logical negation. *)
lemma preserveBoolean:
  assumes "isBoolean c"
  shows "isBoolean exp[!c]"
  using assms isBoolean_def apply auto
  by (metis (no_types, lifting) IntVal0 IntVal1 intval_logic_negation.simps(1) logic_negate_def)

(* ConditionalIntegerEquals_1: the integer equality test between the
   conditional (c ? x : y) and its first branch x is replaced by the
   condition c.
   Side conditions: x and y both carry integer stamps with the same first
   component b, both satisfy wf_stamp, their stamps satisfy alwaysDistinct,
   and c satisfies isBoolean.
   NOTE(review): several logical symbols are not visible in this text; the
   step comments below give the presumable reading - confirm.
   The first apply uses size lemmas (cond_size, size_binary_lhs), so it
   presumably discharges the termination obligation of the phase - confirm.
   The remaining goal is proved in the structured block. *)
optimization ConditionalIntegerEquals_1: "exp[BinaryExpr BinIntegerEquals (c ? x : y) (x)]  c
                                          when stamp_expr x = IntegerStamp b xl xh  wf_stamp x 
                                               stamp_expr y = IntegerStamp b yl yh  wf_stamp y 
                                               (alwaysDistinct (stamp_expr x) (stamp_expr y)) 
                                               isBoolean c"
  apply (metis Canonicalization.cond_size add_lessD1 size_binary_lhs) apply auto
  subgoal premises p for m p cExpr xv cond
  proof -
    (* Value of the condition c; this locally obtained cond shadows the
       subgoal variable of the same name. *)
    obtain cond where cond: "[m,p]  c  cond"
      using p by blast
    (* By isBoolean, cond is one of the two values IntVal 32 0 / IntVal 32 1. *)
    have cRange: "cond = IntVal 32 0  cond = IntVal 32 1"
      using p cond isBoolean_def by blast
    (* Value of the second branch y. *)
    then obtain yv where yVal: "[m,p]  y  yv"
      using p(15) by auto
    (* Both branch values are integer values with first component b. *)
    obtain xvv where xvv: "xv = IntVal b xvv"
      by (metis p(1,2,7) valid_int wf_stamp_def)
    obtain yvv where yvv: "yv = IntVal b yvv"
      by (metis ExpIntBecomesIntVal p(3,4) wf_stamp_def yVal)
    (* The two payloads are related (presumably: differ), derived from the
       premises via wf_stamp_def and valid_int_signed_range. *)
    have yxDiff: "xvv  yvv"
      by (smt (verit, del_insts) yVal xvv wf_stamp_def valid_int_signed_range p yvv)
    (* Hence comparing yv with xv yields the value IntVal 32 0. *)
    have eqEvalFalse: "intval_equals yv xv = (IntVal 32 0)"
      unfolding xvv yvv apply auto by (metis (mono_tags) bool_to_val.simps(2) yxDiff)
    (* Value-level core: cond equals the comparison of the selected branch
       value with xv; shown by splitting on whether cond is IntVal 32 0. *)
    then have valEvalSame: "cond = intval_equals val[cond ? xv : yv] xv"
      apply (cases "cond = IntVal 32 0"; simp) using cRange xvv by auto
    (* Connect the truth value of cond with which branch value cExpr is. *)
    then have condTrue: "val_to_bool cond  cExpr = xv"
      by (metis (mono_tags, lifting) cond evalDet p(11) p(7) p(9))
    then have condFalse: "¬(val_to_bool cond)  cExpr = yv"
      by (metis (full_types) cond evalDet p(11) p(9) yVal)
    (* So c evaluates to the comparison of cExpr with xv, which is the goal. *)
    then have "[m,p]  c  intval_equals cExpr xv"
      using cond condTrue valEvalSame by fastforce
    then show ?thesis
      by blast
  qed
  done

(* Helpers *)
(* negation_preserve_eval0: if e evaluates to a value v and e satisfies
   isBoolean, then the negated expression exp[!e] evaluates to some value v'.
   NOTE(review): the binder before v' is not visible in this text; presumably
   existential - confirm. *)
lemma negation_preserve_eval0:
  assumes "[m, p]  exp[e]  v"
  assumes "isBoolean e"
  shows "v'. ([m, p]  exp[!e]  v')"
  using assms
proof -
  (* From isBoolean, v is an integer value. *)
  obtain b vv where vIntVal: "v = IntVal b vv"
    using isBoolean_def assms by blast
  (* Relates the logical negation of v to UndefVal (relation symbol not
     visible here; presumably inequality, as the fact name suggests). *)
  then have negationDefined: "intval_logic_negation v  UndefVal"
    by simp
  (* Combine the evaluation of e with the fact above to obtain the claim. *)
  show ?thesis
    using assms(1) negationDefined by fastforce
qed

(* negation_preserve_eval2: for an expression e that satisfies isBoolean and
   evaluates to v, the negated expression exp[!e] evaluates to some v' with
   v = val[!v'], i.e. v is the logical negation of the negated value.
   NOTE(review): binder and connective in the shows clause are not visible in
   this text; presumably existential and conjunction - confirm. *)
lemma negation_preserve_eval2:
  assumes "([m, p]  exp[e]  v)"
  assumes "(isBoolean e)"
  shows "v'. ([m, p]  exp[!e]  v')  v = val[!v']"
  using assms
proof -
  (* The negated expression has a value, by negation_preserve_eval0. *)
  obtain notEval where notEval: "([m, p]  exp[!e]  notEval)"
    by (metis assms negation_preserve_eval0)
  (* That value is the logical negation of v (via evalDet and the unary
     evaluation equation for negation). *)
  then have logicNegateEquiv: "notEval = intval_logic_negation v"
    using evalDet  assms(1) unary_eval.simps(4) by blast
  (* By isBoolean, v is one of IntVal 32 0 / IntVal 32 1. *)
  then have vRange: "v = IntVal 32 0  v = IntVal 32 1"
    using assms by (auto simp add: isBoolean_def)
  (* On those two values, negating the negated value gives back v. *)
  have evaluateNot: "v = intval_logic_negation notEval"
    by (metis IntVal0 IntVal1 intval_logic_negation.simps(1) logicNegateEquiv logic_negate_def
        vRange)
  then show ?thesis
    using notEval by auto
qed

(* ConditionalIntegerEquals_2: the integer equality test between the
   conditional (c ? x : y) and its second branch y is replaced by the negated
   condition !c.
   Side conditions are the same as for ConditionalIntegerEquals_1: x and y
   carry integer stamps with the same first component b, both satisfy
   wf_stamp, their stamps satisfy alwaysDistinct, and c satisfies isBoolean.
   NOTE(review): several logical symbols are not visible in this text; the
   step comments below give the presumable reading - confirm.
   The first apply uses size lemmas (cond_size, size_binary_lhs) and
   arithmetic facts, so it presumably discharges the termination obligation
   of the phase - confirm. *)
optimization ConditionalIntegerEquals_2: "exp[BinaryExpr BinIntegerEquals (c ? x : y) (y)]  (!c)
                                          when stamp_expr x = IntegerStamp b xl xh  wf_stamp x 
                                               stamp_expr y = IntegerStamp b yl yh  wf_stamp y 
                                               (alwaysDistinct (stamp_expr x) (stamp_expr y)) 
                                               isBoolean c"
  apply (smt (verit) not_add_less1 max_less_iff_conj max.absorb3 linorder_less_linear add_2_eq_Suc'
         add_less_cancel_right size_binary_lhs add_lessD1 Canonicalization.cond_size)
  apply auto
  subgoal premises p for m p cExpr yv cond trE faE
  proof -
    (* Value of the condition c; shadows the subgoal variable cond. *)
    obtain cond where cond: "[m,p]  c  cond"
      using p by blast
    then have condNotUndef: "cond  UndefVal"
      by (simp add: evaltree_not_undef)
    (* Value of the negated condition, from negation_preserve_eval2. *)
    then obtain notCond where notCond: "[m,p]  exp[!c]  notCond"
      by (meson p(6) negation_preserve_eval2 cond)
    (* Both cond and notCond are one of IntVal 32 0 / IntVal 32 1. *)
    have cRange: "cond = IntVal 32 0  cond = IntVal 32 1"
      using p cond by (simp add: isBoolean_def)
    then have cNotRange:  "notCond = IntVal 32 0  notCond = IntVal 32 1"
      by (metis (no_types, lifting) IntVal0 IntVal1 cond evalDet intval_logic_negation.simps(1)
          logic_negate_def negation_preserve_eval notCond)
    (* Value of the first branch x. *)
    then obtain xv where xv: "[m,p]  x  xv"
      using p by auto
    (* Case notCond = IntVal 32 1: the conditional evaluates to yv. *)
    then have trueCond: "(notCond = IntVal 32 1)  [m,p]  (ConditionalExpr c x y)  yv"
      by (smt (verit, best) cRange evalDet negates negation_preserve_eval notCond p(7) cond
          zero_less_numeral val_to_bool.simps(1) evaltree_not_undef ConditionalExpr
          ConditionalExprE)
    obtain xvv where xvv: "xv = IntVal b xvv"
      by (metis p(1,2) valid_int wf_stamp_def xv)
    (* notCond is exactly the logical negation of cond. *)
    then have opposites: "notCond = intval_logic_negation cond"
      by (metis cond evalDet negation_preserve_eval notCond)
    then have negate: "(intval_logic_negation cond = IntVal 32 0)  (cond = IntVal 32 1)"
      using cRange intval_logic_negation.simps negates by fastforce
    (* Case notCond = IntVal 32 0: the conditional evaluates to xv. *)
    have falseCond: "(notCond = IntVal 32 0)  [m,p]  (ConditionalExpr c x y)  xv"
      unfolding opposites using negate cond evalDet p(13,14,15,16) xv by auto
    obtain yvv where yvv: "yv = IntVal b yvv"
      by (metis p(3,4,7) wf_stamp_def ExpIntBecomesIntVal)
    (* The two branch values are related (presumably: differ), derived from
       the stamp premises via wf_stamp_def and valid_int_signed_range. *)
    have yxDiff: "xv  yv"
      by (metis linorder_not_less max.absorb1 max.absorb4 max_less_iff_conj min_def xv yvv
          wf_stamp_def valid_int_signed_range p(1,2,3,4,5,7))
    (* Evaluation of the whole equality test in each of the two cases:
       it compares yv with yv, respectively xv with yv. *)
    then have trueEvalCond: "(cond = IntVal 32 0) 
                         [m,p]  exp[BinaryExpr BinIntegerEquals (c ? x : y) (y)]
                                intval_equals yv yv"
      by (smt (verit) cNotRange trueCond ConditionalExprE cond bin_eval.simps(13) evalDet p
          falseCond unfold_binary val_to_bool.simps(1))
    then have falseEval: "(notCond = IntVal 32 0) 
                         [m,p]  exp[BinaryExpr BinIntegerEquals (c ? x : y) (y)]
                                intval_equals xv yv"
      using p by (metis ConditionalExprE bin_eval.simps(13) evalDet falseCond unfold_binary)
    (* Comparing the two branch values yields IntVal 32 0. *)
    have eqEvalFalse: "intval_equals yv xv = (IntVal 32 0)"
      unfolding xvv yvv apply auto by (metis (mono_tags) bool_to_val.simps(2) yxDiff yvv xvv)
    (* Therefore the equality test evaluates to notCond. Case analysis on
       notCond, with the second case handled first (prefer 2). *)
    have trueEvalEquiv: "[m,p]  exp[BinaryExpr BinIntegerEquals (c ? x : y) (y)]  notCond"
      apply (cases notCond) prefer 2
      apply (metis IntVal0 Value.distinct(1) eqEvalFalse evalDet evaltree_not_undef falseEval p(6)
             intval_commute intval_logic_negation.simps(1) intval_self_is_true logic_negate_def
             negation_preserve_eval2 notCond trueEvalCond yvv cNotRange cond)
      using notCond cNotRange by auto
    (* Assemble the case facts into the goal. *)
    show ?thesis
      using ConditionalExprE
      by (metis cNotRange falseEval notCond trueEvalEquiv trueCond falseCond intval_self_is_true
          yvv p(9,11) evalDet)
  qed
  done

(* ConditionalExtractCondition: the conditional c ? true : false is replaced
   by its condition c, provided c satisfies isBoolean.
   Proved by fastforce after supplying isBoolean_def. *)
optimization ConditionalExtractCondition: "exp[(c ? true : false)]  c
                                          when isBoolean c"
  using isBoolean_def by fastforce

(* ConditionalExtractCondition2: the conditional c ? false : true is replaced
   by the negated condition !c, provided c satisfies isBoolean.
   NOTE(review): several logical symbols are not visible in this text; the
   step comments below give the presumable reading - confirm.
   In the structured proof, cExpr is the value of the conditional and the
   goal is that !c evaluates to that same value. *)
optimization ConditionalExtractCondition2: "exp[(c ? false : true)]  !c
                                          when isBoolean c"
  apply auto
  subgoal premises p for m p cExpr cond
  proof-
    (* Value of the condition c; shadows the subgoal variable cond. *)
    obtain cond where cond: "[m,p]  c  cond"
      using p(2) by auto
    (* Value of the negated condition, from negation_preserve_eval2. *)
    obtain notCond where notCond: "[m,p]  exp[!c]  notCond"
      by (metis cond negation_preserve_eval2 p(1))
    (* Both cond and cExpr are one of IntVal 32 0 / IntVal 32 1. *)
    then have cRange: "cond = IntVal 32 0  cond = IntVal 32 1"
      using isBoolean_def cond p(1) by auto
    then have cExprRange: "cExpr = IntVal 32 0  cExpr = IntVal 32 1"
      by (metis (full_types) ConstantExprE p(4))
    (* The conditional yields the opposite constant of cond in each case. *)
    then have condTrue: "cond = IntVal 32 1  cExpr = IntVal 32 0"
      using cond evalDet p(2) p(4) by fastforce
    then have condFalse: "cond = IntVal 32 0  cExpr = IntVal 32 1"
      using p cond evalDet by fastforce
    (* Hence cond is the logical negation of cExpr ... *)
    then have opposite: "cond = intval_logic_negation cExpr"
      by (metis (full_types) IntVal0 IntVal1 cRange condTrue intval_logic_negation.simps(1)
          logic_negate_def)
    (* ... and the value of !c coincides with cExpr. *)
    then have eq: "notCond = cExpr"
      by (metis (no_types, lifting) IntVal0 IntVal1 cExprRange cond evalDet negation_preserve_eval
          intval_logic_negation.simps(1) logic_negate_def notCond)
    then show ?thesis
      using notCond by auto
  qed
  done

(* ConditionalEqualIsRHS: the conditional (x eq y) ? x : y is replaced by y.
   This rule has no side condition.
   NOTE(review): several logical symbols are not visible in this text; the
   step comments below give the presumable reading - confirm.
   In the structured proof, v is the value of the conditional and the goal is
   that y evaluates to v. *)
optimization ConditionalEqualIsRHS: "((x eq y) ? x : y)  y"
  apply auto
  subgoal premises p for m p v true false xa ya
  proof-
    (* Values of the two operands. *)
    obtain xv where xv: "[m,p]  x  xv"
      using p(8) by auto
    obtain yv where yv: "[m,p]  y  yv"
      using p(9) by auto
    (* Relates both operand values, and the result of comparing them, to
       UndefVal (relation symbols not visible here; presumably inequality). *)
    have notUndef: "xv  UndefVal  yv  UndefVal"
      using evaltree_not_undef xv yv by blast
    have evalNotUndef: "intval_equals xv yv  UndefVal"
      by (metis evalDet p(1,8,9) xv yv)
    (* From these, both operand values are integer values. *)
    obtain xb xvv where xvv: "xv = IntVal xb xvv"
      by (metis Value.exhaust evalNotUndef intval_equals.simps(3,4,5) notUndef)
    obtain yb yvv where yvv: "yv = IntVal yb yvv"
      by (metis evalNotUndef intval_equals.simps(7,8,9) intval_logic_negation.cases notUndef)
    (* Value vv of the branch selected by the comparison result. *)
    obtain vv where evalLHS: "[m,p]  if val_to_bool (intval_equals xv yv) then x else y  vv"
      by (metis (full_types) p(4) yv)
    (* Name the comparison result. *)
    obtain equ where equ: "equ = intval_equals xv yv"
      by fastforce
    (* If the comparison is IntVal 32 1 the selected value is xv,
       if it is IntVal 32 0 the selected value is yv. *)
    have trueEval: "equ = IntVal 32 1  vv = xv"
      using evalLHS by (simp add: evalDet xv equ)
    have falseEval: "equ = IntVal 32 0  vv = yv"
      using evalLHS by (simp add: evalDet yv equ)
    (* The selected value is the value v of the whole conditional. *)
    then have "vv = v"
      by (metis evalDet evalLHS p(2,8,9) xv yv)
    (* Conclude using the definition of intval_equals on integer values. *)
    then show ?thesis
      by (metis (full_types) bool_to_val.simps(1,2) bool_to_val_bin.simps equ evalNotUndef falseEval
          intval_equals.simps(1) trueEval xvv yv yvv)
  qed
  done

(* todo not sure if this is done properly *)
(* normalizeX: the conditional that yields the constant 0 when x equals the
   constant 0, and the constant 1 otherwise, is replaced by x itself.
   Side conditions: x has stamp IntegerStamp 32 0 1, x satisfies wf_stamp,
   and x satisfies isBoolean.
   NOTE(review): several logical symbols are not visible in this text; the
   step comments below give the presumable reading - confirm.
   In the structured proof, v is the value of the conditional and the goal is
   that x evaluates to v. *)
optimization normalizeX: "((x eq const (IntVal 32 0)) ? 
                                (const (IntVal 32 0)) : (const (IntVal 32 1)))  x
                                when stamp_expr x = IntegerStamp 32 0 1  wf_stamp x 
                                     isBoolean x"
  apply auto
  subgoal premises p for m p v
    proof -
      (* Value of x. *)
      obtain xa where xa: "[m,p]  x  xa"
        using p by blast
       (* The constant selected by comparing xa with 0 evaluates to v. *)
       have eval: "[m,p]  if val_to_bool (intval_equals xa (IntVal 32 0))
                        then ConstantExpr (IntVal 32 0)
                        else ConstantExpr (IntVal 32 1)  v"
         using evalDet p(3,4,5,6,7) xa by blast
       (* By isBoolean, xa is one of IntVal 32 0 / IntVal 32 1. *)
       then have xaRange: "xa = IntVal 32 0  xa = IntVal 32 1"
         using isBoolean_def p(3) xa by blast
      (* In both cases the selected constant equals xa. *)
      then have 6: "v = xa"
        using eval xaRange by auto
      then show ?thesis
        by (auto simp: xa)
    qed
  done

(* todo not sure if this is done properly *)
(* normalizeX2: the conditional that yields the constant 1 when x equals the
   constant 1, and the constant 0 otherwise, is replaced by x itself.
   Side condition: x is literally one of the two constant expressions
   ConstantExpr (IntVal 32 0) or ConstantExpr (IntVal 32 1), so the rule as
   stated only covers those two syntactic cases.
   The proof obligations are closed by the immediate proof step alone. *)
optimization normalizeX2: "((x eq (const (IntVal 32 1))) ? 
                                  (const (IntVal 32 1)) : (const (IntVal 32 0)))  x
                                   when (x = ConstantExpr (IntVal 32 0) | 
                                        (x = ConstantExpr (IntVal 32 1)))" .

(* todo not sure if this is done properly *)
(* flipX: the conditional that yields the constant 1 when x equals the
   constant 0, and the constant 0 otherwise, is replaced by a binary
   expression over x and the constant 1.
   NOTE(review): the binary operator between x and the constant is not
   visible in this text; an exclusive-or would match the rule name - confirm.
   Side condition: x is literally one of the two constant expressions
   ConstantExpr (IntVal 32 0) or ConstantExpr (IntVal 32 1).
   The proof obligations are closed by the immediate proof step alone. *)
optimization flipX: "((x eq (const (IntVal 32 0))) ? 
                            (const (IntVal 32 1)) : (const (IntVal 32 0)))  x  (const (IntVal 32 1))
                             when (x = ConstantExpr (IntVal 32 0) | 
                                  (x = ConstantExpr (IntVal 32 1)))" .

(* todo not sure if this is done properly *)
(* flipX2: the conditional that yields the constant 0 when x equals the
   constant 1, and the constant 1 otherwise, is replaced by a binary
   expression over x and the constant 1.
   NOTE(review): the binary operator between x and the constant is not
   visible in this text; an exclusive-or would match the rule name - confirm.
   Side condition: x is literally one of the two constant expressions
   ConstantExpr (IntVal 32 0) or ConstantExpr (IntVal 32 1).
   The proof obligations are closed by the immediate proof step alone. *)
optimization flipX2: "((x eq (const (IntVal 32 1))) ? 
                             (const (IntVal 32 0)) : (const (IntVal 32 1)))  x  (const (IntVal 32 1))
                              when (x = ConstantExpr (IntVal 32 0) | 
                                   (x = ConstantExpr (IntVal 32 1)))" .

(* stamp_of_default: for an expression x whose stamp is default_stamp and
   which satisfies wf_stamp, a value v that x evaluates to has the form
   IntVal 32 vv.
   NOTE(review): the connective and the binder in the shows clause are not
   visible in this text; presumably an implication with an existential over
   vv - confirm.
   Proved by metis from default_stamp, valid_value_elims(3) and wf_stamp_def. *)
lemma stamp_of_default:
  assumes "stamp_expr x = default_stamp"
  assumes "wf_stamp x"
  shows "([m, p]  x  v)  (vv. v = IntVal 32 vv)"
  by (metis assms default_stamp valid_value_elims(3) wf_stamp_def)

(* OptimiseIntegerTest: the conditional that yields the constant 0 when
   (x & 1) equals 0, and the constant 1 otherwise, is replaced by x & 1.
   Side conditions: x has stamp default_stamp and satisfies wf_stamp.
   The proof lifts the value-level lemma val_optimise_integer_test to
   expressions: both sides are evaluated, expressed as value-level terms over
   the value of x, and shown equal.
   NOTE(review): several logical symbols are not visible in this text; the
   step comments below give the presumable reading - confirm. *)
optimization OptimiseIntegerTest: 
     "(((x & (const (IntVal 32 1))) eq (const (IntVal 32 0))) ? 
      (const (IntVal 32 0)) : (const (IntVal 32 1)))  
       x & (const (IntVal 32 1))
       when (stamp_expr x = default_stamp  wf_stamp x)"
  apply (simp; rule impI; (rule allI)+; rule impI)
  subgoal premises eval for m p v
proof -
  (* Value of x. *)
  obtain xv where xv: "[m, p]  x  xv"
    using eval by fast
  (* By stamp_of_default, that value is of the form IntVal 32 v. *)
  then have x32: "v. xv = IntVal 32 v"
    using stamp_of_default eval by auto
  (* Value of the left-hand side conditional ... *)
  obtain lhs where lhs: "[m, p]  exp[(((x & (const (IntVal 32 1))) eq (const (IntVal 32 0))) ? 
                                 (const (IntVal 32 0)) : (const (IntVal 32 1)))]  lhs"
    using eval(2) by auto
  (* ... expressed as the corresponding value-level term over xv. *)
  then have lhsV: "lhs = val[((xv & (IntVal 32 1)) eq (IntVal 32 0)) ? 
                        (IntVal 32 0) : (IntVal 32 1)]"
    using ConditionalExprE ConstantExprE bin_eval.simps(4,11) evalDet xv unfold_binary
          intval_conditional.simps
    by fastforce
  (* Value of the right-hand side, likewise expressed over xv. *)
  obtain rhs where rhs: "[m, p]  exp[x & (const (IntVal 32 1))]  rhs"
    using eval(2) by blast
  then have rhsV: "rhs = val[xv & IntVal 32 1]"
    by (metis BinaryExprE ConstantExprE bin_eval.simps(6) evalDet xv)
  (* The two value-level terms coincide by val_optimise_integer_test. *)
  have "lhs = rhs" 
    using val_optimise_integer_test x32 lhsV rhsV by presburger
  (* Transfer the equality back to the evaluation judgement via evalDet. *)
  then show ?thesis
    by (metis eval(2) evalDet lhs rhs)
qed
  done

(* todo not sure if this is done properly *)
(* opt_optimise_integer_test_2: the conditional that yields the constant 0
   when (x & 1) equals 0, and the constant 1 otherwise, is replaced by x
   itself.
   Side condition: x is literally one of the two constant expressions
   ConstantExpr (IntVal 32 0) or ConstantExpr (IntVal 32 1), so the rule as
   stated only covers those two syntactic cases.
   The proof obligations are closed by the immediate proof step alone. *)
optimization opt_optimise_integer_test_2: 
     "(((x & (const (IntVal 32 1))) eq (const (IntVal 32 0))) ? 
             (const (IntVal 32 0)) : (const (IntVal 32 1)))  x
              when (x = ConstantExpr (IntVal 32 0) | (x = ConstantExpr (IntVal 32 1)))" .

(*
optimization opt_conditional_eliminate_known_less: "((x < y) ? x : y) ⟼ x 
                                 when (((stamp_under (stamp_expr x) (stamp_expr y)) |
                                      ((stpi_upper (stamp_expr x)) = (stpi_lower (stamp_expr y))))
                                      ∧ wf_stamp x ∧ wf_stamp y)"
   apply auto using stamp_under_defn
  apply simp sorry
*)

(*
optimization opt_normalize_x_original: "((BinaryExpr BinIntegerEquals x (ConstantExpr (IntVal32 0))) ? 
                                (ConstantExpr (IntVal32 0)) : (ConstantExpr (IntVal32 1))) ⟼ x
                                when (stamp_expr x = IntegerStamp 32 0 1 ∧ 
                                      wf_stamp x)"
   apply unfold_optimization apply simp_all
  using wf_stamp_def apply (cases x; simp) 
  
  sorry
*)

(** End of new proofs **)

end

end