(* Theory TacticSolving *)

theory TacticSolving
  imports Common
begin

(* Termination measure for the TacticSolving phase (see `terminating size`).
   Weights: a unary node doubles the size of its argument; for BinAdd the
   right operand is counted twice, whereas any other binary operator simply
   sums both operands; a conditional adds 2 to the sum of its three parts;
   constants weigh 1 and all other leaves weigh 2.
   The BinAdd equation must stay before the generic BinaryExpr equation,
   because `fun` resolves overlapping patterns in order. *)
fun size :: "IRExpr ⇒ nat" where
  "size (UnaryExpr op e) = (size e) * 2" |
  "size (BinaryExpr BinAdd x y) = (size x) + ((size y) * 2)" |
  "size (BinaryExpr op x y) = (size x) + (size y)" |
  "size (ConditionalExpr cond t f) = (size cond) + (size t) + (size f) + 2" |
  "size (ConstantExpr c) = 1" |
  "size (ParameterExpr ind s) = 2" |
  "size (LeafExpr nid s) = 2" |
  "size (ConstantVar c) = 2" |
  "size (VariableExpr x s) = 2"

(* Every expression has strictly positive size.  Registered as a simp rule so
   that size comparisons can use it automatically. *)
lemma size_pos[simp]: "0 < size y"
  apply (induction y; auto?)
  (* One goal survives `auto`; it is closed by a case split on the operator
     `op` (presumably the BinaryExpr case, where BinAdd has its own
     equation). *)
  subgoal premises prems for op a b
    using prems by (induction op; auto)
  done

phase TacticSolving
  terminating size
begin

subsection ‹AddNode›
(*lemma val_add_left_negate_to_sub:
  "val[-x + y] ≈ val[y - x]"
  apply simp by (cases x; cases y; auto)

lemma exp_add_left_negate_to_sub:
  "exp[-x + y] ≥ exp[y - x]"
  using val_add_left_negate_to_sub by auto*)

(* Lifts a value-level approximation to an expression-level refinement.
   Assumptions 1-3 describe the intended proof pattern (both expressions
   evaluate to values related by ≈); the proof itself only uses assumption 4
   together with le_expr_def and evaltree_not_undef.
   NOTE(review): assumption 4 quantifies over every v2, which is a very strong
   premise -- confirm this is what is intended. *)
lemma value_approx_implies_refinement:
  assumes "lhs ≈ rhs"
  assumes "⋀m p v. ([m, p] ⊢ elhs ↦ v) ⟹ v = lhs"
  assumes "⋀m p v. ([m, p] ⊢ erhs ↦ v) ⟹ v = rhs"
  assumes "⋀m p v1 v2. ([m, p] ⊢ elhs ↦ v1) ⟹ ([m, p] ⊢ erhs ↦ v2)"
  shows "elhs ≥ erhs"
  by (metis assms(4) le_expr_def evaltree_not_undef)

(* Eisbach proof methods used to experiment with automated optimization
   proofs. *)

(* Case split on two values and finish with auto. *)
method explore_cases for x y :: Value =
  (cases x; cases y; auto)

(* Case split on a single expression and finish with auto. *)
method explore_cases_bin for x :: IRExpr =
  (cases x; auto)

(* Introduce `lhs ≈ rhs` as an extra premise of the current goal; the
   obligation to prove it is deferred to the end of the goal list and the
   case-split method is then run. *)
method obtain_approx_eq for lhs rhs x y :: Value =
  (rule meta_mp[where P="lhs ≈ rhs"], defer_tac, explore_cases x y)

(* Introduce the fact that `exp` can only evaluate to `val` as an extra
   premise; the obligation to prove it is deferred. *)
method obtain_eval for exp::IRExpr and val::Value =
  (rule meta_mp[where P="⋀m p v. ([m, p] ⊢ exp ↦ v) ⟹ v = val"], defer_tac)

(* Top-level tactic: first try to discharge a termination goal of the form
   `size _ < size _` with simp, then, if the goal is a refinement between two
   expressions, try obtain_approx_eq with the supplied value terms. *)
method solve for lhs rhs x y :: Value =
  (match conclusion in "size _ < size _" ⇒ ‹simp›)?,
  (match conclusion in "(elhs::IRExpr) ≥ (erhs::IRExpr)" for elhs erhs ⇒ ‹(obtain_approx_eq lhs rhs x y)?›)

print_methods
(*
    (simp del: well_formed_equal_def le_expr_def)?;
    ((rule allI)+)?›)*)
thm BinaryExprE
(* Rewrite `-x + y` to `y - x`.
   FIXME: this proof is not finished -- it ends in `sorry`.  The commented-out
   fragments below are earlier proof attempts kept for reference. *)
optimization opt_add_left_negate_to_sub:
  "-x + y ⟼ y - x"
  (*defer apply simp apply (rule allI)+ apply (rule impI)
   apply (subgoal_tac "∀x1. [m, p] ⊢ exp[-x + y] ↦ x1") defer
  *)
   apply (solve "val[-x1 + y1]" "val[y1 - x1]" x1 y1)
  apply simp apply auto using evaltree_not_undef sorry
(*
  apply (obtain_eval "exp[-x + y]" "val[-x1 + y1]")
  

  apply (rule BinaryExprE)
  apply (rule allI)+ sorry
  apply (auto simp: unfold_evaltree) sorry*)
  (*
   defer apply (test "val[-x1 + y1]" "val[y1 - x1]" x1 y1)
   apply (rule meta_mp[where P="val[-x1 + y1] ≈ val[y1 - x1]"])
    prefer 2 apply (cases x1; cases y1; auto)
   apply (subgoal_tac "val[-x1 + y1] ≈ val[y1 - x1]")
    apply (cases x1; cases y1; auto)
  using exp_add_left_negate_to_sub apply simp
  unfolding size.simps by simp*)

subsection ‹NegateNode›

(* Value level: negating a difference swaps its operands. *)
lemma val_distribute_sub: 
 "val[-(x-y)] ≈ val[y-x]" 
  by (cases x; cases y; auto) 

(* Expression-level rewrite `-(x-y)` to `(y-x)`, derived from the value-level
   lemma by unfolding the evaluation of binary and unary nodes. *)
optimization distribute_sub: "-(x-y) ⟼ (y-x)" 
  using val_distribute_sub unfold_binary unfold_unary by auto

(* For a 32-bit integer value, xor with itself approximates `false`. *)
lemma val_xor_self_is_false:
  assumes "x = IntVal 32 v"
  shows "val[x ⊕ x] ≈ val[false]"
  by (cases x; auto simp: assms)

(* An expression has a well-formed stamp when every value it can evaluate to
   (in any method state m and parameter list p) is valid for the stamp that
   stamp_expr computes for it. *)
definition wf_stamp :: "IRExpr ⇒ bool" where
  "wf_stamp e = (∀m p v. ([m, p] ⊢ e ↦ v) ⟶ valid_value v (stamp_expr e))"

(* Expression-level version of val_xor_self_is_false: for an operand with a
   well-formed 32-bit integer stamp, `x xor x` refines to `false`. *)
lemma exp_xor_self_is_false: 
  assumes "stamp_expr x = IntegerStamp 32 l h"
  assumes "wf_stamp x"
  shows "exp[x ⊕ x] >= exp[false]"
  by (smt (z3) wf_value_def bin_eval.simps(8) bin_eval_new_int constantAsStamp.simps(1) evalDet 
      int_signed_value_bounds new_int.simps new_int_take_bits unfold_binary unfold_const valid_int 
      valid_stamp.simps(1) valid_value.simps(1) well_formed_equal_defn val_xor_self_is_false 
      le_expr_def assms wf_stamp_def)

(* Commutativity of the bitwise value operations or / xor / and.
   NOTE(review): all three are declared [simp] although they are permutative
   equations; the exp_*_commutative lemmas rely on this when closing their
   goals with `auto`. *)
lemma val_or_commute[simp]:
   "val[x | y] = val[y | x]"
  by (cases x; cases y; auto simp: or.commute)

lemma val_xor_commute[simp]:
   "val[x ⊕ y] = val[y ⊕ x]"
  by (cases x; cases y; auto simp: word_bw_comms(3))

lemma val_and_commute[simp]:
   "val[x & y] = val[y & x]"
  by (cases x; cases y; auto simp: word_bw_comms(1))

(* Expression-level commutativity (as refinements) of or / xor / and.  Each
   goal is closed by `auto`, which uses the value-level commutativity simp
   rules. *)
lemma exp_or_commutative:
  "exp[x | y] ≥ exp[y | x]"
  by auto 

lemma exp_xor_commutative:
  "exp[x ⊕ y] ≥ exp[y ⊕ x]"
  by auto 

lemma exp_and_commutative:
  "exp[x & y] ≥ exp[y & x]"
  by auto 

text ‹--- New optimisations - submitted and added into Graal ---›
(* Value level: for a 32-bit integer n, `n | ~n` approximates the all-ones
   value `new_int 32 (-1)`. *)
lemma OrInverseVal:
  assumes "n = IntVal 32 v"
  shows "val[n | ~n] ≈ new_int 32 (-1)"
  apply (auto simp: assms)
  by (metis bit.disj_cancel_right mask_eq_take_bit_minus_one take_bit_or)

(* Rewrite `n | ~n` to the constant `new_int 32 (not 0)` when n has a
   well-formed 32-bit integer stamp.
   Proof outline: evaluate n to a value nv, show nv is an IntVal (otherwise
   `~nv` would be undefined) of width 32 (from wf_stamp), compute
   `nv | ~nv`, and conclude by determinism of evaluation. *)
optimization OrInverse: "exp[n | ~n] ⟼ (const (new_int 32 (not 0)))
                        when (stamp_expr n = IntegerStamp 32 l h ∧ wf_stamp n)"
   apply (auto simp: Suc_lessI)
  subgoal premises p for m p xa xaa
  proof -
    obtain nv where nv: "[m,p] ⊢ n ↦ nv"
      using p(3) by auto
    obtain nbits nvv where nvv: "nv = IntVal nbits nvv"
      by (metis evalDet evaltree_not_undef intval_logic_negation.cases intval_not.simps(3,4,5) nv
          p(5,6))
    then have width: "nbits = 32"
      by (metis Value.inject(1) nv p(1,2) valid_int wf_stamp_def)
    then have stamp: "constantAsStamp (IntVal 32 (mask 32)) =
                  (IntegerStamp 32 (int_signed_value 32 (mask 32)) (int_signed_value 32 (mask 32)))"
      by auto
    (* The replacement constant must itself be a well-formed value. *)
    have wf: "wf_value (IntVal 32 (mask 32))"
      unfolding wf_value_def stamp apply auto by eval+
    then have unfoldOr: "val[nv | ~nv] = (new_int 32 (or (not nvv) nvv))"
      using intval_or.simps OrInverseVal nvv width by auto
    then have eq: "val[nv | ~nv] = new_int 32 (not 0)"
      by (simp add: unfoldOr)
    then show ?thesis
      by (metis bit.compl_zero evalDet local.wf new_int.elims nv p(3,5) take_bit_minus_one_eq_mask
          unfold_const)
  qed
  done

(* Mirror image of OrInverse (`~n | n`), obtained via commutativity of or. *)
optimization OrInverse2: "exp[~n | n] ⟼ (const (new_int 32 (not 0)))
                        when (stamp_expr n = IntegerStamp 32 l h ∧ wf_stamp n)"
   using OrInverse exp_or_commutative by auto

(* Value level: for a 32-bit integer n, `n xor ~n` approximates the all-ones
   value `new_int 32 (-1)`. *)
lemma XorInverseVal:
  assumes "n = IntVal 32 v"
  shows "val[n ⊕ ~n] ≈ new_int 32 (-1)"
  apply (auto simp: assms)
  by (metis (no_types, opaque_lifting) bit.compl_zero bit.xor_compl_right bit.xor_self take_bit_xor
      mask_eq_take_bit_minus_one)

(* Rewrite `n xor ~n` to the constant `new_int 32 (not 0)` when n has a
   well-formed 32-bit integer stamp.
   Proof outline: evaluate n to xv, show it is a 32-bit IntVal, show that the
   replacement constant evaluates, compute `xv xor ~xv`, and conclude by
   determinism of evaluation. *)
optimization XorInverse: "exp[n ⊕ ~n] ⟼ (const (new_int 32 (not 0)))
                        when (stamp_expr n = IntegerStamp 32 l h ∧ wf_stamp n)"
  apply (auto simp: Suc_lessI)
  subgoal premises p for m p xa xaa
  proof-
    obtain xv where xv: "[m,p] ⊢ n ↦ xv"
      using p(3) by auto
    obtain xb xvv where xvv: "xv = IntVal xb xvv"
      by (metis evalDet evaltree_not_undef intval_logic_negation.cases intval_not.simps(3,4,5) xv
          p(5,6))
    have rhsDefined: "[m,p] ⊢ (ConstantExpr (IntVal 32 (mask 32))) ↦ (IntVal 32 (mask 32))"
      by (metis ConstantExpr add.right_neutral add_less_cancel_left neg_one_value numeral_Bit0
          new_int_unused_bits_zero not_numeral_less_zero validDefIntConst zero_less_numeral
          verit_comp_simplify1(3) wf_value_def)
    have w32: "xb=32"
      by (metis Value.inject(1) p(1,2) valid_int xv xvv wf_stamp_def)
    then have unfoldNot: "val[(¬xv)] = new_int xb (not xvv)"
      by (simp add: xvv)
    have unfoldXor: "val[xv ⊕ (¬xv)] =
                    (if xb=xb then (new_int xb (xor xvv (not xvv))) else UndefVal)"
      using intval_xor.simps(1) XorInverseVal w32 xvv by auto
    then have rhs: "val[xv ⊕ (¬xv)] = new_int 32 (mask 32)"
      using unfoldXor w32 by auto
    then show ?thesis
      by (metis evalDet neg_one.elims neg_one_value p(3,5) rhsDefined xv)
  qed
  done

(* Mirror image of XorInverse (`~n xor n`), obtained via commutativity of
   xor. *)
optimization XorInverse2: "exp[(~n) ⊕ n] ⟼ (const (new_int 32 (not 0)))
                        when (stamp_expr n = IntegerStamp 32 l h ∧ wf_stamp n)"
   using XorInverse exp_xor_commutative by auto

(* Value level: for a 32-bit integer n, `~n & n` is the zero value
   `new_int 32 0`. *)
lemma AndSelfVal:
  assumes "n = IntVal 32 v"
  shows "val[~n & n] = new_int 32 0"
  apply (auto simp: assms) 
  by (metis take_bit_and take_bit_of_0 word_and_not)

(* Rewrite `~n & n` to the constant zero when n has a well-formed 32-bit
   integer stamp; relies on the value-level lemma AndSelfVal. *)
optimization AndSelf: "exp[(~n) & n] ⟼ (const (new_int 32 (0)))
                        when (stamp_expr n = IntegerStamp 32 l h ∧ wf_stamp n)"
  apply (auto simp: Suc_lessI) unfolding size.simps
  by (metis (no_types) val_and_commute ConstantExpr IntVal0 Value.inject(1) evalDet wf_stamp_def
      eval_bits_1_64 new_int.simps validDefIntConst valid_int wf_value_def AndSelfVal)

(* Mirror image of AndSelf (`n & ~n`), obtained via commutativity of and. *)
optimization AndSelf2: "exp[n & (~n)] ⟼ (const (new_int 32 (0)))
                        when (stamp_expr n = IntegerStamp 32 l h ∧ wf_stamp n)"
  using AndSelf exp_and_commutative by auto

(* Value level: for 32-bit integers, `~x xor ~y` equals `x xor y`. *)
lemma NotXorToXorVal:
  assumes "x = IntVal 32 xv"
  assumes "y = IntVal 32 yv"
  shows "val[(~x) ⊕ (~y)] = val[x ⊕ y]" 
  apply (auto simp: assms) 
  by (metis (no_types, opaque_lifting) bit.xor_compl_left bit.xor_compl_right take_bit_xor 
      word_not_not) 

(* Expression level: `~x xor ~y` refines to `x xor y` when both operands have
   well-formed 32-bit integer stamps.  The proof evaluates both operands and
   applies NotXorToXorVal to the resulting values. *)
lemma NotXorToXorExp:
  assumes "stamp_expr x = IntegerStamp 32 lx hx"
  assumes "wf_stamp x"
  assumes "stamp_expr y = IntegerStamp 32 ly hy"
  assumes "wf_stamp y"
  shows "exp[(~x) ⊕ (~y)] ≥ exp[x ⊕ y]" 
  apply auto 
  subgoal premises p for m p xa xb
    proof -
      obtain xa where xa: "[m,p] ⊢ x ↦ xa"
        using p by blast
      obtain xb where xb: "[m,p] ⊢ y ↦ xb"
        using p by blast
      then have a: "val[(~xa) ⊕ (~xb)] = val[xa ⊕ xb]" 
        by (metis assms valid_int wf_stamp_def xa xb NotXorToXorVal)
      then show ?thesis
        by (metis BinaryExpr bin_eval.simps(8) evalDet p(1,2,4) xa xb)
    qed 
  done

(* Rewrite `~x xor ~y` to `x xor y` for operands with well-formed 32-bit
   integer stamps; follows directly from NotXorToXorExp. *)
optimization NotXorToXor: "exp[(~x) ⊕ (~y)] ⟼ (x ⊕ y)
                        when (stamp_expr x = IntegerStamp 32 lx hx ∧ wf_stamp x) ∧
                             (stamp_expr y = IntegerStamp 32 ly hy ∧ wf_stamp y)"
  using NotXorToXorExp by simp

end

text ‹--- New optimisations - submitted, not added into Graal yet ---›

context stamp_mask
begin

(* Extension to old Or optimisation 
   x | y ↦ -1 when (downMask x | downMask y == -1)
*)

(* A value that is valid for an integer stamp of width b is an IntVal of that
   width.  Stated for an arbitrary width b. *)
lemma ExpIntBecomesIntValArbitrary:
  assumes "stamp_expr x = IntegerStamp b xl xh"
  assumes "wf_stamp x"
  assumes "valid_value v (IntegerStamp b xl xh)"
  assumes "[m,p] ⊢ x ↦ v"
  shows "∃xv. v = IntVal b xv"
  using assms by (simp add: IRTreeEvalThms.valid_value_elims(3)) 

(* Generalised Or optimisation: `x | y` refines to the all-ones constant of
   width b when the down masks of x and y together cover every bit
   (downMask x | downMask y = -1).  Both operands and the Or node itself must
   have well-formed integer stamps of width b. *)
lemma OrGeneralization:
  assumes "stamp_expr x = IntegerStamp b xl xh"
  assumes "stamp_expr y = IntegerStamp b yl yh"
  assumes "stamp_expr exp[x | y] = IntegerStamp b el eh"
  assumes "wf_stamp x"
  assumes "wf_stamp y"
  assumes "wf_stamp exp[x | y]"
  assumes "(or (↓x) (↓y)) = not 0" 
  shows "exp[x | y] ≥ exp[(const (new_int b (not 0)))]"
  using assms apply auto
  subgoal premises p for m p xvv yvv
  proof -
    obtain xv where xv: "[m, p] ⊢ x ↦ IntVal b xv"
      by (metis p(1,3,9) valid_int wf_stamp_def)
    obtain yv where yv: "[m, p] ⊢ y ↦ IntVal b yv"
      by (metis p(2,4,10) valid_int wf_stamp_def)
    obtain evv where ev: "[m, p] ⊢ exp[x | y] ↦ IntVal b evv"
      by (metis BinaryExpr bin_eval.simps(7) unfold_binary p(5,9,10,11) valid_int wf_stamp_def
          assms(3))
    (* The replacement constant must be a well-formed value. *)
    then have rhsWf: "wf_value (new_int b (not 0))"
      by (metis eval_bits_1_64 new_int.simps new_int_take_bits validDefIntConst wf_value_def)
    then have rhs: "(new_int b (not 0)) = val[IntVal b xv | IntVal b yv]" 
      using assms word_ao_absorbs(1)
      by (metis (no_types, opaque_lifting) bit.de_Morgan_conj word_bw_comms(2) xv down_spec
          word_not_not yv bit.disj_conj_distrib intval_or.simps(1) new_int_bin.simps ucast_id
          or.right_neutral)
    then have notMaskEq: "(new_int b (not 0)) = (new_int b (mask b))"
      by auto
    then show ?thesis 
      by (metis neg_one.elims neg_one_value p(9,10) rhsWf unfold_const evalDet xv yv rhs)
    qed
    done
end

phase TacticSolving
  terminating size
begin

(* Add
   x + ~x ↦ -1 
*)

(* A constant expression holding a well-formed value evaluates to that
   value. *)
lemma constEvalIsConst: 
  assumes "wf_value n"
  shows "[m,p] ⊢ exp[(const (n))] ↦ n"  
  by (simp add: assms IRTreeEval.evaltree.ConstantExpr)

(* Addition is commutative at the expression level (as a refinement). *)
lemma ExpAddCommute:
  "exp[x + y] ≥ exp[y + x]"
  by (auto simp add: Values.intval_add_sym)

(* Value level: for an integer n of any width bv, `n + ~n` is the all-ones
   value `new_int bv (not 0)`. *)
lemma AddNotVal:
  assumes "n = IntVal bv v"
  shows "val[n + (~n)] = new_int bv (not 0)"
  by (auto simp: assms)

(* Expression level: `n + ~n` refines to the all-ones constant of width b when
   n has a well-formed integer stamp of width b.
   Proof outline: both evaluations of n found in the premises give the same
   value (determinism), that value is an IntVal of width b, the left-hand side
   therefore evaluates to `new_int b (not 0)`, and the constant on the
   right-hand side evaluates because that value is well formed. *)
lemma AddNotExp:
  assumes "stamp_expr n = IntegerStamp b l h"
  assumes "wf_stamp n"
  shows "exp[n + (~n)] ≥ exp[(const (new_int b (not 0)))]"
  apply auto
  subgoal premises p for m p x xa
  proof -
    have xaDef: "[m,p] ⊢ n ↦ xa"
      by (simp add: p)
    then have xaDef2: "[m,p] ⊢ n ↦ x"
      by (simp add: p)
    (* Named so the final step can refer to it without a literal fact. *)
    then have xaEq: "xa = x" 
      using p by (simp add: evalDet)
    then obtain xv where xv: "xa = IntVal b xv"
      by (metis valid_int wf_stamp_def xaDef2 assms)
    have toVal: "[m,p] ⊢ exp[n + (~n)] ↦ val[xa + (~xa)]"
      by (metis UnaryExpr bin_eval.simps(1) evalDet p unary_eval.simps(3) unfold_binary xaDef)
    have wfInt: "wf_value (new_int b (not 0))"
      using validDefIntConst xaDef by (simp add: eval_bits_1_64 xv wf_value_def) 
    have toValRHS: "[m,p] ⊢ exp[(const (new_int b (not 0)))] ↦ new_int b (not 0)"
      using wfInt by (simp add: constEvalIsConst)
    have isNeg1: "val[xa + (~xa)] = new_int b (not 0)"
      by (simp add: xv)
    then show ?thesis
      using toValRHS by (simp add: xaEq)
    qed 
   done

(* Rewrite `n + ~n` to the all-ones constant of width b (see AddNotExp). *)
optimization AddNot: "exp[n + (~n)] ⟼ (const (new_int b (not 0)))
                        when (stamp_expr n = IntegerStamp b l h ∧ wf_stamp n)"
   apply (simp add: Suc_lessI) using AddNotExp by force

(* Mirror image of AddNot (`~n + n`), obtained via commutativity of
   addition. *)
optimization AddNot2: "exp[(~n) + n] ⟼ (const (new_int b (not 0)))
                        when (stamp_expr n = IntegerStamp b l h ∧ wf_stamp n)"
   apply (simp add: Suc_lessI) using AddNot ExpAddCommute by simp

(* 
  ~e == e ↦ false
 *)

(* The low 32 bits of `not e` never equal e; the proof goes through the
   parity (evenness) of the two sides. *)
lemma TakeBitNotSelf:
  "(take_bit 32 (not e) = e) = False"
  by (metis even_not_iff even_take_bit_eq zero_neq_numeral)

(* Value level: for a 32-bit integer e, comparing `not e` with e for equality
   yields the value for False. *)
lemma ValNeverEqNotSelf:
  assumes "e = IntVal 32 ev"
  shows "val[intval_equals (¬e) e] = val[bool_to_val False]"
  by (simp add: TakeBitNotSelf assms)

(* A value that is valid for a 32-bit integer stamp is a 32-bit IntVal. *)
lemma ExpIntBecomesIntVal:
  assumes "stamp_expr x = IntegerStamp 32 xl xh"
  assumes "wf_stamp x"
  assumes "valid_value v (IntegerStamp 32 xl xh)"
  assumes "[m,p] ⊢ x ↦ v"
  shows "∃xv. v = IntVal 32 xv"
  using assms by (simp add: IRTreeEvalThms.valid_value_elims(3)) 

(* Expression level: `(not x) == x` refines to the constant False when x has a
   well-formed 32-bit integer stamp.
   Proof outline: evaluate x to a 32-bit IntVal, evaluate both sides of the
   refinement, and connect them with ValNeverEqNotSelf. *)
lemma ExpNeverNotSelf:
  assumes "stamp_expr x = IntegerStamp 32 xl xh"
  assumes "wf_stamp x"
  shows "exp[BinaryExpr BinIntegerEquals (¬x) x] ≥
         exp[(const (bool_to_val False))]" 
  using assms apply auto
  subgoal premises p for m p xa xaa
  proof -
    obtain xa where xa: "[m,p] ⊢ x ↦ xa"
      using p(5) by auto
    then obtain xv where xv: "xa = IntVal 32 xv"
      by (metis p(1,2) valid_int wf_stamp_def)
    then have lhsVal: "[m,p] ⊢ exp[BinaryExpr BinIntegerEquals (¬x) x] ↦ 
                               val[intval_equals (¬xa) xa]" 
      by (metis p(3,4,5,6) unary_eval.simps(3) evaltree.BinaryExpr bin_eval.simps(13) xa UnaryExpr 
          evalDet)
    (* The constant for False (IntVal 32 0) must be a well-formed value. *)
    have wfVal: "wf_value (IntVal 32 0)" 
      using wf_value_def apply rule 
      by (metis IntVal0 intval_word.simps nat_le_linear new_int.simps numeral_le_iff wf_value_def
          semiring_norm(71,76) validDefIntConst verit_comp_simplify1(3) zero_less_numeral)
    then have rhsVal: "[m,p] ⊢ exp[(const (bool_to_val False))] ↦ val[bool_to_val False]"
      by auto
    then have valEq: "val[intval_equals (¬xa) xa] = val[bool_to_val False]" 
      using ValNeverEqNotSelf by (simp add: xv)
    then show ?thesis
      by (metis bool_to_val.simps(2) evalDet p(3,5) rhsVal xa)
   qed
  done

(* Rewrite `(not x) == x` to the constant False for an operand with a
   well-formed 32-bit integer stamp (see ExpNeverNotSelf). *)
optimization NeverEqNotSelf: "exp[BinaryExpr BinIntegerEquals (¬x) x] ⟼ 
                              exp[(const (bool_to_val False))]
                        when (stamp_expr x = IntegerStamp 32 xl xh ∧ wf_stamp x)"
  apply (simp add: Suc_lessI) using ExpNeverNotSelf by force

text ‹--- New optimisations - not submitted / added into Graal yet ---›
(* 
  (x ^ y) == x ↦ y == 0
  x == (x ^ y) ↦ y == 0 
  (x ^ y) == y ↦ x == 0 
  y == (x ^ y) ↦ x == 0
 *)
(* Word level: `x xor y = x` holds exactly when y is zero. *)
lemma BinXorFallThrough:
  shows "bin[(x ⊕ y) = x] ⟷ bin[y = 0]"
  by (metis xor.assoc xor.left_neutral xor_self_eq)

(* Value-level xor facts for values built with new_int.  The definedness
   assumptions (≠ UndefVal) are part of the lemma interfaces. *)

(* x xor x is zero (32-bit). *)
lemma valXorEqual:
  assumes "x = new_int 32 xv"
  assumes "val[x ⊕ x] ≠ UndefVal"
  shows "val[x ⊕ x] = val[new_int 32 0]"
  using assms by (cases x; auto)

(* xor is associative on values of a common width b. *)
lemma valXorAssoc:
  assumes "x = new_int b xv"
  assumes "y = new_int b yv"
  assumes "z = new_int b zv"
  assumes "val[(x ⊕ y) ⊕ z] ≠ UndefVal"
  assumes "val[x ⊕ (y ⊕ z)] ≠ UndefVal"
  shows "val[(x ⊕ y) ⊕ z] = val[x ⊕ (y ⊕ z)]"
  by (simp add: xor.commute xor.left_commute assms)

(* Zero is a right-neutral element for xor. *)
lemma valNeutral:
  assumes "x = new_int b xv"
  assumes "val[x ⊕ (new_int b 0)] ≠ UndefVal"
  shows "val[x ⊕ (new_int b 0)] = val[x]"
  using assms by (auto; meson)

(* Value level: `(x xor y) == x` has the same result as `y == 0`, for x and y
   of a common width b.  Lifts BinXorFallThrough to values. *)
lemma ValXorFallThrough:
  assumes "x = new_int b xv"
  assumes "y = new_int b yv"
  shows "val[intval_equals (x ⊕ y) x] = val[intval_equals y (new_int b 0)]"
  by (simp add: assms BinXorFallThrough)

(* Integer equality on values is symmetric in its arguments.
   NOTE(review): despite the "Assoc" in the name, the statement is a
   commutativity property. *)
lemma ValEqAssoc:
  "val[intval_equals x y] = val[intval_equals y x]"
  apply (cases x; cases y; auto) by (metis (full_types) bool_to_val.simps)

(* Expression level: the operands of an integer equality may be swapped
   (commutativity, despite the "Assoc" in the name). *)
lemma ExpEqAssoc:
  "exp[BinaryExpr BinIntegerEquals x y] ≥ exp[BinaryExpr BinIntegerEquals y x]"
  by (auto simp add: ValEqAssoc)

(* The operands of an xor nested on the left of an equality may be swapped;
   follows from xor commutativity and monotonicity of binary nodes. *)
lemma ExpXorBinEqCommute:
  "exp[BinaryExpr BinIntegerEquals (x ⊕ y) y] ≥ exp[BinaryExpr BinIntegerEquals (y ⊕ x) y]"
  using exp_xor_commutative mono_binary by blast

(* Expression level: `(x xor y) == x` refines to `y == 0` when both operands
   have well-formed integer stamps of the same width b.
   NOTE(review): the first `obtain` introduces a fresh variable also called
   `b`, which shadows the stamp width from the lemma statement for the rest of
   the proof; the final metis step reconnects the two via the stamp
   premises. *)
lemma ExpXorFallThrough:
  assumes "stamp_expr x = IntegerStamp b xl xh"
  assumes "stamp_expr y = IntegerStamp b yl yh"
  assumes "wf_stamp x"
  assumes "wf_stamp y"
  shows "exp[BinaryExpr BinIntegerEquals (x ⊕ y) x] ≥
         exp[BinaryExpr BinIntegerEquals y (const (new_int b 0))]"
  using assms apply auto 
  subgoal premises p for m p xa xaa ya
  proof -
    obtain b xv where xa: "[m,p] ⊢ x ↦ new_int b xv"
      using intval_equals.elims 
      by (metis new_int.simps eval_unused_bits_zero p(1,3,5) wf_stamp_def valid_int)
    obtain yv where ya: "[m,p] ⊢ y ↦ new_int b yv"
      by (metis Value.inject(1) wf_stamp_def p(1,2,3,4,8) eval_unused_bits_zero xa new_int.simps
          valid_int)
    then have wfVal: "wf_value (new_int b 0)"
      by (metis eval_bits_1_64 new_int.simps new_int_take_bits validDefIntConst wf_value_def xa)
    then have eval: "[m,p] ⊢ exp[BinaryExpr BinIntegerEquals y (const (new_int b 0))] ↦ 
                             val[intval_equals (xa ⊕ ya) xa]" 
      by (metis (no_types, lifting) ValXorFallThrough constEvalIsConst bin_eval.simps(13) evalDet xa
          p(5,6,7,8) unfold_binary ya)
    then show ?thesis
      by (metis evalDet new_int.elims p(1,3,5,7) take_bit_of_0 valid_value.simps(1) wf_stamp_def xa)
   qed 
  done

(* Variant of ExpXorFallThrough comparing against the right xor operand:
   `(x xor y) == y` refines to `x == 0`.  Obtained by commuting the xor and
   chaining with ExpXorFallThrough. *)
lemma ExpXorFallThrough2:
  assumes "stamp_expr x = IntegerStamp b xl xh"
  assumes "stamp_expr y = IntegerStamp b yl yh"
  assumes "wf_stamp x"
  assumes "wf_stamp y"
  shows "exp[BinaryExpr BinIntegerEquals (x ⊕ y) y] ≥
         exp[BinaryExpr BinIntegerEquals x (const (new_int b 0))]"
  by (meson assms dual_order.trans ExpXorBinEqCommute ExpXorFallThrough)

(* The four xor fall-through rewrites:
     (x ^ y) == x  to  y == 0        x == (x ^ y)  to  y == 0
     (x ^ y) == y  to  x == 0        y == (x ^ y)  to  x == 0
   each for operands with well-formed integer stamps of a common width b.
   Variants 2 and 4 additionally swap the operands of the equality
   (ExpEqAssoc). *)
optimization XorFallThrough1: "exp[BinaryExpr BinIntegerEquals (x ⊕ y) x] ⟼ 
                               exp[BinaryExpr BinIntegerEquals y (const (new_int b 0))]
                        when (stamp_expr x = IntegerStamp b xl xh ∧ wf_stamp x) ∧ 
                             (stamp_expr y = IntegerStamp b yl yh ∧ wf_stamp y)"
  using ExpXorFallThrough by force

optimization XorFallThrough2: "exp[BinaryExpr BinIntegerEquals x (x ⊕ y)] ⟼ 
                               exp[BinaryExpr BinIntegerEquals y (const (new_int b 0))]
                        when (stamp_expr x = IntegerStamp b xl xh ∧ wf_stamp x) ∧ 
                             (stamp_expr y = IntegerStamp b yl yh ∧ wf_stamp y)"
  using ExpXorFallThrough ExpEqAssoc by force

optimization XorFallThrough3: "exp[BinaryExpr BinIntegerEquals (x ⊕ y) y] ⟼ 
                               exp[BinaryExpr BinIntegerEquals x (const (new_int b 0))]
                        when (stamp_expr x = IntegerStamp b xl xh ∧ wf_stamp x) ∧ 
                             (stamp_expr y = IntegerStamp b yl yh ∧ wf_stamp y)"
  using ExpXorFallThrough2 by force

optimization XorFallThrough4: "exp[BinaryExpr BinIntegerEquals y (x ⊕ y)] ⟼ 
                               exp[BinaryExpr BinIntegerEquals x (const (new_int b 0))]
                        when (stamp_expr x = IntegerStamp b xl xh ∧ wf_stamp x) ∧ 
                             (stamp_expr y = IntegerStamp b yl yh ∧ wf_stamp y)"
  using ExpXorFallThrough2 ExpEqAssoc by force

end

context stamp_mask
begin

(* Ian's optimisation, and its Or equivalent
    x & y ↦ x when x.up ∈ y.Down
    x | y ↦ y when x.up ∈ y.Down

    x.up ∈ y.Down means (x.up & y.Down = x.up), 
               equiv to (x.up | y.Down = y.Down)
*)

(* The two ways of writing "x.up is contained in ..." are equivalent:
   `a & b = a` holds exactly when `a | b = b`.  inEquivalence states it
   against a concrete word yv, inEquivalence2 against the down mask of y.
   The evaluation assumptions are not needed by the proofs, which are pure
   word-level absorption arguments. *)
lemma inEquivalence:
  assumes "[m, p] ⊢ y ↦ IntVal b yv"
  assumes "[m, p] ⊢ x ↦ IntVal b xv"
  shows "(and (↑x) yv) = (↑x) ⟷ (or (↑x) yv) = yv"
  by (metis word_ao_absorbs(3) word_ao_absorbs(4))

lemma inEquivalence2:
  assumes "[m, p] ⊢ y ↦ IntVal b yv"
  assumes "[m, p] ⊢ x ↦ IntVal b xv"
  shows "(and (↑x) (↓y)) = (↑x) ⟷ (or (↑x) (↓y)) = (↓y)"
  by (metis word_ao_absorbs(3) word_ao_absorbs(4))

(* x | y ↦ y when x.up ∈ y.Down *)
(* `x | y` refines to `y` when the up mask of x is contained in the down mask
   of y (x.up & y.Down = x.up, equivalently x.up | y.Down = y.Down).
   Proof outline: evaluate the Or node and both operands at a common width,
   show `yv = xv | yv` from the mask specifications (up_spec / down_spec),
   and conclude by determinism of evaluation. *)
lemma RemoveLHSOrMask:
  assumes "(and (↑x) (↓y)) = (↑x)"
  assumes "(or (↑x) (↓y)) = (↓y)"
  shows "exp[x | y] ≥ exp[y]"
  using assms apply auto
  subgoal premises p for m p v
  proof -
    obtain b ev where exp: "[m, p] ⊢ exp[x | y] ↦ IntVal b ev" 
      by (metis BinaryExpr bin_eval.simps(7) p(3,4,5) bin_eval_new_int new_int.simps)
    from exp obtain yv where yv: "[m, p] ⊢ y ↦ IntVal b yv"
      apply (subst (asm) unfold_binary_width) by force+
    from exp obtain xv where xv: "[m, p] ⊢ x ↦ IntVal b xv"
      apply (subst (asm) unfold_binary_width) by force+
    then have "yv = (or xv yv)"
      using assms yv xv apply auto
      by (metis (no_types, opaque_lifting) down_spec ucast_id up_spec word_ao_absorbs(1) word_or_not
          word_ao_equiv word_log_esimps(3) word_oa_dist word_oa_dist2)
    then have "(IntVal b yv) = val[(IntVal b xv) | (IntVal b yv)]"
      apply auto using eval_unused_bits_zero yv by presburger
    then show ?thesis    
      by (metis p(3,4) evalDet xv yv)
  qed
  done

(* x & y ↦ x when x.up ∈ y.Down *)
(* `x & y` refines to `x` when the up mask of x is contained in the down mask
   of y.  Same structure as RemoveLHSOrMask: evaluate the And node and both
   operands at a common width, show `xv = xv & yv`, and conclude by
   determinism of evaluation. *)
lemma RemoveRHSAndMask:
  assumes "(and (↑x) (↓y)) = (↑x)"
  assumes "(or (↑x) (↓y)) = (↓y)"
  shows "exp[x & y] ≥ exp[x]"
  using assms apply auto
  subgoal premises p for m p v
  proof -
    obtain b ev where exp: "[m, p] ⊢ exp[x & y] ↦ IntVal b ev"
      by (metis BinaryExpr bin_eval.simps(6) p(3,4,5) new_int.simps bin_eval_new_int)
    from exp obtain yv where yv: "[m, p] ⊢ y ↦ IntVal b yv"
      apply (subst (asm) unfold_binary_width) by force+
    from exp obtain xv where xv: "[m, p] ⊢ x ↦ IntVal b xv"
      apply (subst (asm) unfold_binary_width) by force+
    then have "IntVal b xv = val[(IntVal b xv) & (IntVal b yv)]"
      apply auto 
      by (smt (verit, ccfv_threshold) or.right_neutral not_down_up_mask_and_zero_implies_zero p(1)
          bit.conj_cancel_right word_bw_comms(1) eval_unused_bits_zero yv word_bw_assocs(1)
          word_ao_absorbs(4) or_eq_not_not_and)
    then show ?thesis     
      by (metis p(3,4) yv xv evalDet)
   qed
  done

(* Ian's new And optimisation
    x & y ↦ 0 when x.up & y.up = 0
*)
(* `x & y` refines to the zero constant of width b when the up masks of x and
   y share no bit (x.up & y.up = 0).  Both operands and the And node itself
   must have well-formed integer stamps of width b.
   Proof outline: evaluate both operands and the And node to IntVals of width
   b, show the result is zero via up_mask_and_zero_implies_zero, and conclude
   by determinism of evaluation. *)
lemma ReturnZeroAndMask:
  assumes "stamp_expr x = IntegerStamp b xl xh"
  assumes "stamp_expr y = IntegerStamp b yl yh"
  assumes "stamp_expr exp[x & y] = IntegerStamp b el eh"
  assumes "wf_stamp x"
  assumes "wf_stamp y"
  assumes "wf_stamp exp[x & y]"
  assumes "(and (↑x) (↑y)) = 0"
  shows "exp[x & y] ≥ exp[const (new_int b 0)]"
  using assms apply auto
  subgoal premises p for m p v
  proof -
    obtain yv where yv: "[m, p] ⊢ y ↦ IntVal b yv"
      by (metis valid_int wf_stamp_def assms(2,5) p(2,4,10))
    obtain xv where xv: "[m, p] ⊢ x ↦ IntVal b xv"
      by (metis valid_int wf_stamp_def assms(1,4) p(3,9))
    obtain ev where exp: "[m, p] ⊢ exp[x & y] ↦ IntVal b ev"
      by (metis BinaryExpr bin_eval.simps(6) p(5,9,10,11) assms(3) valid_int wf_stamp_def)
    (* The replacement constant must be a well-formed value. *)
    then have wfVal: "wf_value (new_int b 0)"
      by (metis eval_bits_1_64 new_int.simps new_int_take_bits validDefIntConst wf_value_def)
    then have lhsEq: "IntVal b ev = val[(IntVal b xv) & (IntVal b yv)]"
      by (metis bin_eval.simps(6) yv xv evalDet exp unfold_binary)
    then have newIntEquiv: "new_int b 0 = IntVal b ev" 
      apply auto by (smt (z3) p(6) eval_unused_bits_zero xv yv up_mask_and_zero_implies_zero)
    then have isZero: "ev = 0"
      by auto
    then show ?thesis
      by (metis evalDet lhsEq newIntEquiv p(9,10) unfold_const wfVal xv yv)
   qed
   done

end

phase TacticSolving
  terminating size
begin


(* 
 (x ^ y) == (x ^ z) ↦ y == z
 (x ^ y) == (z ^ x) ↦ y == z
 (y ^ x) == (x ^ z) ↦ y == z
 (y ^ x) == (z ^ x) ↦ y == z
 *)

(* Word level: xor with a fixed x is injective, i.e. `x xor y = x xor z`
   holds exactly when `y = z`. *)
lemma binXorIsEqual:
  "bin[((x ⊕ y) = (x ⊕ z))] ⟷ bin[(y = z)]"
  by (metis (no_types, opaque_lifting) BinXorFallThrough xor.left_commute xor_self_eq)

(* Contrapositive form: different right operands give different results. *)
lemma binXorIsDeterministic:
  assumes "y ≠ z"
  shows "bin[x ⊕ y] ≠ bin[x ⊕ z]"
  by (auto simp add: binXorIsEqual assms)

(* Value-level xor facts, each stated once for IntVal and once for values
   built with new_int. *)

(* x xor x is the zero value of the same width. *)
lemma ValXorSelfIsZero:
  assumes "x = IntVal b xv"
  shows "val[x ⊕ x] = IntVal b 0" 
  by (simp add: assms)

lemma ValXorSelfIsZero2:
  assumes "x = new_int b xv"
  shows "val[x ⊕ x] = IntVal b 0" 
  by (simp add: assms)

(* Re-association of `(x xor y) xor y` to `x xor (y xor y)`. *)
lemma ValXorIsAssociative:
  assumes "x = IntVal b xv"
  assumes "y = IntVal b yv"
  assumes "val[(x ⊕ y)] ≠ UndefVal"
  shows "val[(x ⊕ y) ⊕ y] = val[x ⊕ (y ⊕ y)]"
  by (auto simp add: word_bw_lcs(3) assms) 

lemma ValXorIsAssociative2:
  assumes "x = new_int b xv"
  assumes "y = new_int b yv"
  assumes "val[(x ⊕ y)] ≠ UndefVal"
  shows "val[(x ⊕ y) ⊕ y] = val[x ⊕ (y ⊕ y)]"
  using ValXorIsAssociative by (simp add: assms)

(* For a 64-bit value, xor with zero is the identity.  The only non-trivial
   step is that take_bit at the full word length leaves a 64-bit word
   unchanged. *)
lemma XorZeroIsSelf64:
  assumes "x = IntVal 64 xv"
  assumes "val[x ⊕ (IntVal 64 0)] ≠ UndefVal"
  shows  "val[x ⊕ (IntVal 64 0)] = x" 
  using assms apply (cases x; auto)
  subgoal
  proof -
    have "take_bit (LENGTH(64)) xv = xv"
      unfolding Word.take_bit_length_eq by simp
    then show ?thesis
      by auto
   qed
  done

(* 64-bit: `x xor (y xor y)` is x.  First replace `y xor y` by zero, then use
   XorZeroIsSelf64. *)
lemma ValXorElimSelf64:
  assumes "x = IntVal 64 xv"
  assumes "y = IntVal 64 yv"
  assumes "val[x ⊕ y] ≠ UndefVal"
  assumes "val[y ⊕ y] ≠ UndefVal"
  shows "val[x ⊕ (y ⊕ y)] = x"
  proof -
    have removeRhs: "val[x ⊕ (y ⊕ y)] = val[x ⊕ (IntVal 64 0)]"
      by (simp add: assms(2))
    then have XorZeroIsSelf: "val[x ⊕ (IntVal 64 0)] = x"
      using XorZeroIsSelf64 by (simp add: assms(1))
    then show ?thesis
      by (simp add: removeRhs)
  qed

(* 64-bit: xor is its own inverse -- if z = x xor y then z xor y = x. *)
lemma ValXorIsReverse64:
  assumes "x = IntVal 64 xv"
  assumes "y = IntVal 64 yv"
  assumes "z = IntVal 64 zv"
  assumes "z = val[x ⊕ y]"
  assumes "val[x ⊕ y] ≠ UndefVal"
  assumes "val[z ⊕ y] ≠ UndefVal"
  shows "val[z ⊕ y] = x"
  using ValXorIsAssociative ValXorElimSelf64 assms(1,2,4,5) by force

(* 64-bit value level: `(x xor y) == (x xor z)` has the same result as
   `y == z`.  After case analysis on the three values, the proof splits on
   whether the payloads yv and zv are equal; the remaining (unequal) case is
   handled by removing the full-width take_bit applications and using
   binXorIsEqual. *)
lemma valXorIsEqual_64:
  assumes "x = IntVal 64 xv"
  assumes "val[x ⊕ y] ≠ UndefVal"
  assumes "val[x ⊕ z] ≠ UndefVal"
  shows "val[intval_equals (x ⊕ y) (x ⊕ z)] = val[intval_equals y z]"
  using assms apply (cases x; cases y; cases z; auto)
  subgoal premises p for yv zv apply (cases "(yv = zv)"; simp)
  subgoal premises p
  proof -
    have isFalse: "bool_to_val (yv = zv) = bool_to_val False"
      by (simp add: p)
    then have unfoldTakebityv: "take_bit LENGTH(64) yv = yv"
      using take_bit_length_eq by blast
    then have unfoldTakebitzv: "take_bit LENGTH(64) zv = zv"
      using take_bit_length_eq by blast
    then have unfoldTakebitxv: "take_bit LENGTH(64) xv = xv"
      using take_bit_length_eq by blast
    then have lhs: "(xor (take_bit LENGTH(64) yv) (take_bit LENGTH(64) xv) =
                     xor (take_bit LENGTH(64) zv) (take_bit LENGTH(64) xv)) = (False)"
      unfolding unfoldTakebityv unfoldTakebitzv unfoldTakebitxv
      by (simp add: binXorIsEqual word_bw_comms(3) p)
    then show ?thesis
      by (simp add: isFalse)
    qed
   done
  done

(* 64-bit value level: xor-ing a fixed x with different payloads gives
   different results. *)
lemma ValXorIsDeterministic_64:
  assumes "x = IntVal 64 xv"
  assumes "y = IntVal 64 yv"
  assumes "z = IntVal 64 zv"
  assumes "val[x ⊕ y] ≠ UndefVal"
  assumes "val[x ⊕ z] ≠ UndefVal"
  assumes "yv ≠ zv"
  shows "val[x ⊕ y] ≠ val[x ⊕ z]" 
  by (smt (verit, best) ValXorElimSelf64 ValXorIsAssociative ValXorSelfIsZero Value.distinct(1)
      assms Value.inject(1) val_xor_commute valXorIsEqual_64)

(* A value that is valid for a 64-bit integer stamp is a 64-bit IntVal. *)
lemma ExpIntBecomesIntVal_64:
  assumes "stamp_expr x = IntegerStamp 64 xl xh"
  assumes "wf_stamp x"
  assumes "valid_value v (IntegerStamp 64 xl xh)"
  assumes "[m,p] ⊢ x ↦ v"
  shows "∃xv. v = IntVal 64 xv"
  using assms by (simp add: IRTreeEvalThms.valid_value_elims(3)) 

(* Expression level (64-bit): `(x xor y) == (x xor z)` refines to `y == z`
   when all three operands have well-formed 64-bit integer stamps.
   Proof outline: evaluate x, y and z, show x's value is a 64-bit IntVal, and
   apply valXorIsEqual_64 to relate the two comparisons. *)
lemma expXorIsEqual_64:
  assumes "stamp_expr x = IntegerStamp 64 xl xh"
  assumes "stamp_expr y = IntegerStamp 64 yl yh"
  assumes "stamp_expr z = IntegerStamp 64 zl zh"
  assumes "wf_stamp x"
  assumes "wf_stamp y"
  assumes "wf_stamp z"
    shows "exp[BinaryExpr BinIntegerEquals (x ⊕ y) (x ⊕ z)] ≥
           exp[BinaryExpr BinIntegerEquals y z]"
  using assms apply auto
  subgoal premises p for m p x1 y1 x2 z1
  proof -
    obtain xVal where xVal: "[m,p] ⊢ x ↦ xVal"
      using p(8) by simp
    obtain yVal where yVal: "[m,p] ⊢ y ↦ yVal"
      using p(9) by simp
    obtain zVal where zVal: "[m,p] ⊢ z ↦ zVal"
      using p(12) by simp
    obtain xv where xv: "xVal = IntVal 64 xv"
      by (metis p(1) p(4) wf_stamp_def xVal ExpIntBecomesIntVal_64)
    then have rhs: "[m,p] ⊢ exp[BinaryExpr BinIntegerEquals y z] ↦ val[intval_equals yVal zVal]"
      by (metis BinaryExpr bin_eval.simps(13) evalDet p(7,8,9,10,11,12,13) valXorIsEqual_64 xVal 
          yVal zVal)
    then show ?thesis
      by (metis xv evalDet p(8,9,10,11,12,13) valXorIsEqual_64 xVal yVal zVal)
  qed
  done

(* 64 bit versions *)
(* The four 64-bit rewrites that cancel a common xor operand on both sides of
   an equality:
     (x ^ y) == (x ^ z),  (x ^ y) == (z ^ x),
     (y ^ x) == (x ^ z),  (y ^ x) == (z ^ x)   each become   y == z.
   Variants 2-4 reduce to variant 1 by commuting the xor operands
   (exp_xor_commutative with mono_binary). *)
optimization XorIsEqual_64_1: "exp[BinaryExpr BinIntegerEquals (x ⊕ y) (x ⊕ z)] ⟼ 
                             exp[BinaryExpr BinIntegerEquals y z]
                        when (stamp_expr x = IntegerStamp 64 xl xh ∧ wf_stamp x) ∧ 
                             (stamp_expr y = IntegerStamp 64 yl yh ∧ wf_stamp y) ∧ 
                             (stamp_expr z = IntegerStamp 64 zl zh ∧ wf_stamp z)"
  using expXorIsEqual_64 by force

optimization XorIsEqual_64_2: "exp[BinaryExpr BinIntegerEquals (x ⊕ y) (z ⊕ x)] ⟼ 
                             exp[BinaryExpr BinIntegerEquals y z]
                        when (stamp_expr x = IntegerStamp 64 xl xh ∧ wf_stamp x) ∧ 
                             (stamp_expr y = IntegerStamp 64 yl yh ∧ wf_stamp y) ∧ 
                             (stamp_expr z = IntegerStamp 64 zl zh ∧ wf_stamp z)" 
  by (meson dual_order.trans mono_binary exp_xor_commutative expXorIsEqual_64) 

optimization XorIsEqual_64_3: "exp[BinaryExpr BinIntegerEquals (y ⊕ x) (x ⊕ z)] ⟼ 
                             exp[BinaryExpr BinIntegerEquals y z]
                        when (stamp_expr x = IntegerStamp 64 xl xh ∧ wf_stamp x) ∧ 
                             (stamp_expr y = IntegerStamp 64 yl yh ∧ wf_stamp y) ∧ 
                             (stamp_expr z = IntegerStamp 64 zl zh ∧ wf_stamp z)"
  by (meson dual_order.trans mono_binary exp_xor_commutative expXorIsEqual_64) 

optimization XorIsEqual_64_4: "exp[BinaryExpr BinIntegerEquals (y ⊕ x) (z ⊕ x)] ⟼ 
                             exp[BinaryExpr BinIntegerEquals y z]
                        when (stamp_expr x = IntegerStamp 64 xl xh ∧ wf_stamp x) ∧ 
                             (stamp_expr y = IntegerStamp 64 yl yh ∧ wf_stamp y) ∧ 
                             (stamp_expr z = IntegerStamp 64 zl zh ∧ wf_stamp z)"
  by (meson dual_order.trans mono_binary exp_xor_commutative expXorIsEqual_64) 

(*
XorEqZero
  (x ^ y) == 0 ↦ (x == y)
 *)

(* bool_to_val is injective, so an equation between two bool_to_val results
   can be replaced by an equation between the booleans. *)
lemma unwrap_bool_to_val:
  shows "(bool_to_val a = bool_to_val b) = (a = b)"
  apply auto using bool_to_val.elims by fastforce+

(* Rewrites the literal width 64 as LENGTH(64) for a 64-bit word, so that
   facts about take_bit at the full word length become applicable. *)
lemma take_bit_size_eq:
  shows "take_bit 64 a = take_bit LENGTH(64) (a::64 word)"
  by auto

(* Word level: `xv xor yv` is zero exactly when xv and yv are equal. *)
lemma xorZeroIsEq:
  "bin[(xor xv yv) = 0] = bin[xv = yv]"
  by (metis binXorIsEqual xor_self_eq)

(* 64-bit value level: `(x xor y) == 0` has the same result as `x == y`. *)
lemma valXorEqZero_64:
  assumes "val[(x ⊕ y)] ≠ UndefVal"
  assumes "x = IntVal 64 xv"
  assumes "y = IntVal 64 yv"
  shows "val[intval_equals (x ⊕ y) ((IntVal 64 0))] = val[intval_equals (x) (y)]"
  using assms apply (cases x; cases y; auto)
  unfolding unwrap_bool_to_val take_bit_size_eq Word.take_bit_length_eq by (simp add: xorZeroIsEq)

(* Expression level (64-bit): `(x xor y) == 0` refines to `x == y` when both
   operands have well-formed 64-bit integer stamps.
   Proof outline: evaluate both operands to 64-bit IntVals, show the
   right-hand side evaluates, and relate the results with valXorEqZero_64. *)
lemma expXorEqZero_64:
  assumes "stamp_expr x = IntegerStamp 64 xl xh"
  assumes "stamp_expr y = IntegerStamp 64 yl yh"
  assumes "wf_stamp x"
  assumes "wf_stamp y"
    shows "exp[BinaryExpr BinIntegerEquals (x ⊕ y) (const (IntVal 64 0))] ≥
           exp[BinaryExpr BinIntegerEquals (x) (y)]"
  using assms apply auto
  subgoal premises p for m p x1 y1
  proof -
    obtain xv where xv: "[m,p] ⊢ x ↦ xv"
      using p by blast
    obtain yv where yv: "[m,p] ⊢ y ↦ yv"
      using p by fast
    obtain xvv where xvv: "xv = IntVal 64 xvv"
      by (metis p(1,3) wf_stamp_def xv ExpIntBecomesIntVal_64)
    obtain yvv where yvv: "yv = IntVal 64 yvv"
      by (metis p(2,4) wf_stamp_def yv ExpIntBecomesIntVal_64)
    have rhs: "[m,p] ⊢ exp[BinaryExpr BinIntegerEquals (x) (y)] ↦ val[intval_equals xv yv]"
      by (smt (z3) BinaryExpr ValEqAssoc ValXorSelfIsZero Value.distinct(1) bin_eval.simps(13) xvv
          evalDet p(5,6,7,8) valXorIsEqual_64 xv yv)
    then show ?thesis
      by (metis evalDet p(6,7,8) valXorEqZero_64 xv xvv yv yvv)
  qed
  done

(* Rewrite `(x ^ y) == 0` to `x == y` for operands with well-formed 64-bit
   integer stamps (see expXorEqZero_64). *)
optimization XorEqZero_64: "exp[BinaryExpr BinIntegerEquals (x ⊕ y) (const (IntVal 64 0))] ⟼
                            exp[BinaryExpr BinIntegerEquals (x) (y)]
                      when (stamp_expr x = IntegerStamp 64 xl xh ∧ wf_stamp x) ∧
                           (stamp_expr y = IntegerStamp 64 yl yh ∧ wf_stamp y)"
  using expXorEqZero_64 by fast

(*
XorEqNeg1
  (x ^ y) == -1 ↦ (x == ¬y)
 *)

(* Word level: `xv xor yv` is all ones exactly when xv is the bitwise
   complement of yv.  Derived from xorZeroIsEq. *)
lemma xorNeg1IsEq:
  "bin[(xor xv yv) = (not 0)] = bin[xv = not yv]"
  using xorZeroIsEq by fastforce

(* 64-bit value level: `(x xor y) == -1` has the same result as
   `x == not y`. *)
lemma valXorEqNeg1_64:
  assumes "val[(x ⊕ y)] ≠ UndefVal"
  assumes "x = IntVal 64 xv"
  assumes "y = IntVal 64 yv"
  shows "val[intval_equals (x ⊕ y) (IntVal 64 (not 0))] = val[intval_equals (x) (¬y)]"
  using assms apply (cases x; cases y; auto)
  unfolding unwrap_bool_to_val take_bit_size_eq Word.take_bit_length_eq using xorNeg1IsEq by auto

(* Expression level (64-bit): `(x xor y) == -1` refines to `x == not y` when
   both operands have well-formed 64-bit integer stamps.
   Proof outline: evaluate both operands to 64-bit IntVals, show that `not y`
   evaluates to a 64-bit IntVal so the comparison on the right-hand side is
   defined, and relate the two sides with valXorEqNeg1_64. *)
lemma expXorEqNeg1_64:
  assumes "stamp_expr x = IntegerStamp 64 xl xh"
  assumes "stamp_expr y = IntegerStamp 64 yl yh"
  assumes "wf_stamp x"
  assumes "wf_stamp y"
    shows "exp[BinaryExpr BinIntegerEquals (x ⊕ y) (const (IntVal 64 (not 0)))] ≥
           exp[BinaryExpr BinIntegerEquals (x) (¬y)]"
  using assms apply auto
  subgoal premises p for m p x1 y1
  proof -
    obtain xv where xv: "[m,p] ⊢ x ↦ xv"
      using p by blast
    obtain yv where yv: "[m,p] ⊢ y ↦ yv"
      using p by fast
    obtain xvv where xvv: "xv = IntVal 64 xvv"
      by (metis p(1,3) wf_stamp_def xv ExpIntBecomesIntVal_64)
    obtain yvv where yvv: "yv = IntVal 64 yvv"
      by (metis p(2,4) wf_stamp_def yv ExpIntBecomesIntVal_64)
    obtain nyv where nyv: "[m,p] ⊢ exp[(¬y)] ↦ nyv"
      by (metis ValXorSelfIsZero2 Value.distinct(1) intval_not.simps(1) yv yvv intval_xor.simps(2)
          UnaryExpr unary_eval.simps(3))
    then have nyvEq: "val[¬yv] = nyv"
      using evalDet yv by fastforce
    obtain nyvv where nyvv: "nyv = IntVal 64 nyvv"
      using nyvEq intval_not.simps yvv by force
    have notUndef: "val[intval_equals xv (¬yv)] ≠ UndefVal"
      using bool_to_val.elims nyvEq nyvv xvv by auto
    have rhs: "[m,p] ⊢ exp[BinaryExpr BinIntegerEquals (x) (¬y)] ↦ val[intval_equals xv (¬yv)]"
      by (metis BinaryExpr bin_eval.simps(13) notUndef nyv nyvEq xv)
    then show ?thesis
      by (metis bit.compl_zero evalDet p(6,7,8) rhs valXorEqNeg1_64 xvv yvv xv yv)
  qed
  done

(* Rewrite `(x ^ y) == -1` to `x == not y` for operands with well-formed
   64-bit integer stamps (see expXorEqNeg1_64).
   FIXME: the termination obligation is not proved -- the proof ends in
   `sorry`. *)
optimization XorEqNeg1_64: "exp[BinaryExpr BinIntegerEquals (x ⊕ y) (const (IntVal 64 (not 0)))] ⟼
                            exp[BinaryExpr BinIntegerEquals (x) (¬y)]
                      when (stamp_expr x = IntegerStamp 64 xl xh ∧ wf_stamp x) ∧
                           (stamp_expr y = IntegerStamp 64 yl yh ∧ wf_stamp y)"
  using expXorEqNeg1_64 apply auto (* termination proof *) sorry

end

end