Theory IRTreeEval

section ‹Data-flow Semantics›

theory IRTreeEval
  imports
    Graph.Stamp
begin

text ‹
We define a tree representation of data-flow nodes, as an abstraction of the graph view.

Data-flow trees are evaluated in the context of a method state
(currently called MapState in the theories for historical reasons).

The method state consists of the values of the method parameters. References to
method parameters use the index of the parameter within the parameter list; as such,
we store a list of parameter values, which are looked up at parameter references.

The method state also stores a mapping of node ids to values. The contents of this
mapping are calculated during the traversal of the control-flow graph.

As a concrete example, because a @{term SignedDivNode} can have side effects
(on division by zero), it is treated as part of the control flow, since
the data-flow phase is specified to be side-effect free.
As a result, the control-flow semantics for @{term SignedDivNode} calculates the
value of a node and maps the node identifier to the value within the method state.
The data-flow semantics then just reads the value stored in the method state for the node.
›

(* Node identifiers of the method graph. *)
type_synonym ID = nat
(* The method state: maps each node id to the value computed for that node
   (the name MapState is historical, see the text above). *)
type_synonym MapState = "ID  Value"
(* The values of the method parameters, indexed by parameter position. *)
type_synonym Params = "Value list"

(* The initial method state, before any node has been given a value:
   every node id is mapped to UndefVal. *)
definition new_map_state :: MapState where
  "new_map_state = (λ_. UndefVal)"

(* ======================== START OF NEW TREE STUFF ==========================*)
subsection ‹Data-flow Tree Representation›

(* Unary operators of the data-flow expression language.
   Their values are given by unary_eval and their result stamps by stamp_unary
   (both below). The width-changing operators (narrow, sign/zero extend) record
   both the input and the result bit width. *)
datatype IRUnaryOp =
    UnaryAbs
  | UnaryNeg
  | UnaryNot
  | UnaryLogicNegation
  | UnaryNarrow (ir_inputBits: nat) (ir_resultBits: nat)
  | UnarySignExtend (ir_inputBits: nat) (ir_resultBits: nat)
  | UnaryZeroExtend (ir_inputBits: nat) (ir_resultBits: nat)
  | UnaryIsNull
  | UnaryReverseBytes
  | UnaryBitCount

(* Binary operators of the data-flow expression language.
   Their values are given by bin_eval and their result stamps by stamp_binary.
   For stamp computation every operator is placed in exactly one of the classes
   binary_normal, binary_fixed_32_ops, binary_fixed_ops or binary_shift_ops
   (see lemma binary_ops_all below). *)
datatype IRBinaryOp =
    BinAdd
  | BinSub
  | BinMul
  | BinDiv
  | BinMod
  | BinAnd
  | BinOr
  | BinXor
  | BinShortCircuitOr
  | BinLeftShift
  | BinRightShift
  | BinURightShift
  | BinIntegerEquals
  | BinIntegerLessThan
  | BinIntegerBelow
  | BinIntegerTest
  | BinIntegerNormalizeCompare
  | BinIntegerMulHigh

(* Data-flow expression trees.
   - UnaryExpr, BinaryExpr and ConditionalExpr apply an operator / select between
     sub-trees.
   - ParameterExpr i s refers to the i-th method parameter; evaluation requires the
     parameter value to be valid for stamp s.
   - LeafExpr n s reads the value already stored for node n in the method state.
   - ConstantExpr is a ground constant.
   - ConstantVar and VariableExpr are pattern variables (used when writing
     optimisation rules); evaltree below has no rule for them.
   The discs_sels option generates discriminators and the field selectors. *)
datatype (discs_sels) IRExpr =
    UnaryExpr (ir_uop: IRUnaryOp) (ir_value: IRExpr)
  | BinaryExpr (ir_op: IRBinaryOp) (ir_x: IRExpr) (ir_y: IRExpr)
  | ConditionalExpr (ir_condition: IRExpr) (ir_trueValue: IRExpr) (ir_falseValue: IRExpr)
(* TODO
  | IsNullNode (ir_value: IRExpr)
  | RefNode ?
  NOTE(review): IsNullNode may already be covered by UnaryExpr UnaryIsNull - confirm.
*)
  | ParameterExpr (ir_index: nat) (ir_stamp: Stamp)
(* Not needed?
  | PiNode (ir_object: IRExpr) (ir_guard_opt: "IRExpr option")
  | ShortCircuitOrNode (ir_x: IRExpr) (ir_y: IRExpr)
  NOTE(review): ShortCircuitOrNode looks covered by BinaryExpr BinShortCircuitOr - confirm.
*)
(* Not needed?
  | UnwindNode (ir_exception: IRExpr)
  | ValueProxyNode (ir_value: IRExpr) (ir_loopExit: IRExpr)
*)
  | LeafExpr (ir_nid: ID) (ir_stamp: Stamp)
  (* LeafExpr is for pre-evaluated nodes, like LoadFieldNode and SignedDivNode, whose
     values are computed by the control-flow semantics and stored in the method state. *)
  | ConstantExpr (ir_const: Value) (* Ground constant *)
  | ConstantVar (ir_name: String.literal)  (* Pattern variable for constant *)
  | VariableExpr (ir_name: String.literal) (ir_stamp: Stamp) (* Pattern variable for expression *)

(* An expression is ground when it contains no pattern variables: ConstantVar and
   VariableExpr are not ground, other leaves are, and operator / conditional nodes
   are ground when all their sub-trees are.
   NOTE: keep the equation order - the GroundExpr typedef below refers to
   is_ground.simps(6), the ConstantExpr equation. *)
fun is_ground :: "IRExpr  bool" where
  "is_ground (UnaryExpr op e) = is_ground e" |
  "is_ground (BinaryExpr op e1 e2) = (is_ground e1  is_ground e2)" |
  "is_ground (ConditionalExpr b e1 e2) = (is_ground b  is_ground e1  is_ground e2)" |
  "is_ground (ParameterExpr i s) = True" |
  "is_ground (LeafExpr n s) = True" |
  "is_ground (ConstantExpr v) = True" |
  "is_ground (ConstantVar name) = False" |
  "is_ground (VariableExpr name s) = False"

(* The subtype of ground expressions. Non-emptiness is witnessed by a ConstantExpr,
   via the is_ground equation for ConstantExpr (is_ground.simps(6)). *)
typedef GroundExpr = "{ e :: IRExpr . is_ground e }"
  using is_ground.simps(6) by blast

subsection ‹Functions for re-calculating stamps› 

text ‹Note: in Java all integer calculations are done as 32 or 64 bit calculations.
  However, here we generalise the operators to allow calculations of any size.
  Many operators have the same number of output bits as their inputs.
  However, the unary integer operators that are not $normal\_unary$ either are
  narrowing or widening operators, whose number of result bits is specified by the
  operator, or always output 32 bits ($unary\_fixed\_32\_ops$ and $boolean\_unary$).
  The binary integer operators are divided into four groups:
  (1) $binary\_fixed\_32\_ops$ operators always output 32 bits,
  (2) $binary\_shift\_ops$ operators have an output size determined by their left argument,
  (3) $binary\_fixed\_ops$ operators output the same number of bits as their inputs,
  but are only defined for 32 or 64 bit inputs,
  and (4) the remaining ($binary\_normal$) operators output the same number of bits as both their inputs.
›

(* Binary operators whose result has the same bit width as their (equal-width)
   operands; see stamp_binary. *)
abbreviation binary_normal :: "IRBinaryOp set" where
  "binary_normal  {BinAdd, BinMul, BinDiv, BinMod, BinSub, BinAnd, BinOr, BinXor}"

(* Binary operators whose result stamp is always 32 bits wide (comparisons and
   short-circuit or); see stamp_binary. *)
abbreviation binary_fixed_32_ops :: "IRBinaryOp set" where
  "binary_fixed_32_ops  {BinShortCircuitOr, BinIntegerEquals, BinIntegerLessThan, BinIntegerBelow, BinIntegerTest, BinIntegerNormalizeCompare}"

(* Shift operators: the result width is taken from the left operand only. *)
abbreviation binary_shift_ops :: "IRBinaryOp set" where
  "binary_shift_ops  {BinLeftShift, BinRightShift, BinURightShift}"

(* TODO add a note into the text above about this?

  Describes operators whose output width matches the input (for 32 or 64), otherwise the output is
  UndefVal.
*)
abbreviation binary_fixed_ops :: "IRBinaryOp set" where
  "binary_fixed_ops  {BinIntegerMulHigh}"

(* Unary operators whose result has the same bit width as their operand;
   see stamp_unary. *)
abbreviation normal_unary :: "IRUnaryOp set" where
  "normal_unary  {UnaryAbs, UnaryNeg, UnaryNot, UnaryLogicNegation, UnaryReverseBytes}"

(* TODO add a note into the text above about this ? *)
(* Unary operators whose result stamp is always 32 bits wide; see stamp_unary. *)
abbreviation unary_fixed_32_ops :: "IRUnaryOp set" where
  "unary_fixed_32_ops  {UnaryBitCount}"

(* TODO add a note into the text above about this ? *)
(* Unary operators producing a boolean result (stamp_unary gives UnaryIsNull the
   stamp IntegerStamp 32 0 1). *)
abbreviation boolean_unary :: "IRUnaryOp set" where
  "boolean_unary  {UnaryIsNull}"

(* Helpful set lemmas *)
(* binary_ops_all: every binary operator belongs to one of the four classes.
   The binary_ops_distinct_* lemmas state that membership in one class excludes
   membership in the others, so case splits over the classes are exhaustive and
   non-overlapping. *)

lemma binary_ops_all:
  shows "op  binary_normal  op  binary_fixed_32_ops  op  binary_fixed_ops  op  binary_shift_ops"
  by (cases op; auto)

lemma binary_ops_distinct_normal:
  shows "op  binary_normal  op  binary_fixed_32_ops  op  binary_fixed_ops  op  binary_shift_ops"
  by auto

lemma binary_ops_distinct_fixed_32:
  shows "op  binary_fixed_32_ops  op  binary_normal  op  binary_fixed_ops  op  binary_shift_ops"
  by auto

lemma binary_ops_distinct_fixed:
  shows "op  binary_fixed_ops  op  binary_fixed_32_ops  op  binary_normal  op  binary_shift_ops"
  by auto

lemma binary_ops_distinct_shift:
  shows  "op  binary_shift_ops  op  binary_fixed_32_ops  op  binary_fixed_ops  op  binary_normal"
  by auto

(* The unary classes are pairwise disjoint. Unlike the binary case they are not
   exhaustive: UnaryNarrow, UnarySignExtend and UnaryZeroExtend belong to none of them. *)
lemma unary_ops_distinct:
  shows "op  normal_unary  op  boolean_unary  op  unary_fixed_32_ops"
  and   "op  boolean_unary  op  normal_unary  op  unary_fixed_32_ops"
  and   "op  unary_fixed_32_ops  op  boolean_unary  op  normal_unary"
  by auto

(* Result stamp of a unary operator, given the stamp of its operand.
   The equations are tried in order:
   - UnaryIsNull always yields IntegerStamp 32 0 1, whatever its input stamp;
   - for an integer operand of b bits, the result is unrestricted_stamp of an
     IntegerStamp whose width is b for normal_unary operators, 32 for
     unary_fixed_32_ops, and otherwise the operator's ir_resultBits (only
     UnaryNarrow, UnarySignExtend and UnaryZeroExtend reach that last case);
   - any other operand stamp yields IllegalStamp.
   Note: the boolean_unary branch of the second equation is currently unreachable,
   because its only member, UnaryIsNull, is matched by the first equation. *)
fun stamp_unary :: "IRUnaryOp  Stamp  Stamp" where
(* WAS:
  "stamp_unary op (IntegerStamp b lo hi) =
    (let bits = (if op ∈ normal_unary
                 then (if b=64 then 64 else 32)
                 else (ir_resultBits op)) in
    unrestricted_stamp (IntegerStamp bits lo hi))" |
*)
(* TODO update to generalise all boolean_unary operators to return IntegerStamp 32 0 1 *)
  "stamp_unary UnaryIsNull _ = (IntegerStamp 32 0 1)" |
  "stamp_unary op (IntegerStamp b lo hi) =
     unrestricted_stamp (IntegerStamp 
                        (if op  normal_unary       then b  else
                         if op  boolean_unary      then 32 else
                         if op  unary_fixed_32_ops then 32 else
                          (ir_resultBits op)) lo hi)" |
  (* for now... non-integer operand stamps are not supported *)
  "stamp_unary op _ = IllegalStamp"

(* Result stamp of a binary operator, given the stamps of its two operands.
   For two integer stamps:
   - shift operators take their width (and input range) from the left operand;
     the right operand's width is not required to match;
   - for all other operators, operands of different widths give IllegalStamp;
   - binary_fixed_32_ops produce a 32-bit result;
   - the remaining operators (binary_normal and binary_fixed_ops) keep the
     operand width.
   Any non-integer operand stamp gives IllegalStamp. *)
fun stamp_binary :: "IRBinaryOp  Stamp  Stamp  Stamp" where
  "stamp_binary op (IntegerStamp b1 lo1 hi1) (IntegerStamp b2 lo2 hi2) =
    (if op  binary_shift_ops then unrestricted_stamp (IntegerStamp b1 lo1 hi1)
     else if b1  b2 then IllegalStamp else
      (if op  binary_fixed_32_ops
       then unrestricted_stamp (IntegerStamp 32 lo1 hi1)
       else unrestricted_stamp (IntegerStamp b1 lo1 hi1)))" |
  (* for now... non-integer operand stamps are not supported *)
  "stamp_binary op _ _ = IllegalStamp"

(* Stamp of an expression tree, computed bottom-up: operator nodes combine the
   stamps of their operands via stamp_unary / stamp_binary, leaves and parameters
   use their recorded stamp, and constants use constantAsStamp. A conditional gets
   the meet of its two branch stamps; the condition's stamp is not used.
   No equations are given for the pattern variables ConstantVar and VariableExpr,
   so their stamp is left unspecified. *)
fun stamp_expr :: "IRExpr  Stamp" where
  "stamp_expr (UnaryExpr op x) = stamp_unary op (stamp_expr x)" |
  "stamp_expr (BinaryExpr bop x y) = stamp_binary bop (stamp_expr x) (stamp_expr y)" |
  "stamp_expr (ConstantExpr val) = constantAsStamp val" |
  "stamp_expr (LeafExpr i s) = s" |
  "stamp_expr (ParameterExpr i s) = s" |
  "stamp_expr (ConditionalExpr c t f) = meet (stamp_expr t) (stamp_expr f)"

(* Presumably a sanity check that the stamp functions are executable. *)
export_code stamp_unary stamp_binary stamp_expr

subsection ‹Data-flow Tree Evaluation›

(* Evaluation of a unary operator on a value: each operator is delegated to the
   corresponding intval_* function. Every constructor of IRUnaryOp has its own
   equation, so the catch-all equation below is kept commented out. *)
fun unary_eval :: "IRUnaryOp  Value  Value" where
  "unary_eval UnaryAbs v = intval_abs v" |
  "unary_eval UnaryNeg v = intval_negate v" |
  "unary_eval UnaryNot v = intval_not v" |
  "unary_eval UnaryLogicNegation v = intval_logic_negation v" |
  "unary_eval (UnaryNarrow inBits outBits) v = intval_narrow inBits outBits v" |
  "unary_eval (UnarySignExtend inBits outBits) v = intval_sign_extend inBits outBits v" |
  "unary_eval (UnaryZeroExtend inBits outBits) v = intval_zero_extend inBits outBits v" |
  "unary_eval UnaryIsNull v = intval_is_null v" |
  "unary_eval UnaryReverseBytes v = intval_reverse_bytes v" |
  "unary_eval UnaryBitCount v = intval_bit_count v"
(*  "unary_eval op v1 = UndefVal" *)

(* Evaluation of a binary operator on two values: each operator is delegated to
   the corresponding intval_* function. Every constructor of IRBinaryOp has its
   own equation, so the catch-all equation below is kept commented out. *)
fun bin_eval :: "IRBinaryOp  Value  Value  Value" where
  "bin_eval BinAdd v1 v2 = intval_add v1 v2" |
  "bin_eval BinSub v1 v2 = intval_sub v1 v2" |
  "bin_eval BinMul v1 v2 = intval_mul v1 v2" |
  "bin_eval BinDiv v1 v2 = intval_div v1 v2" |
  "bin_eval BinMod v1 v2 = intval_mod v1 v2" |
  "bin_eval BinAnd v1 v2 = intval_and v1 v2" |
  "bin_eval BinOr  v1 v2 = intval_or v1 v2" |
  "bin_eval BinXor v1 v2 = intval_xor v1 v2" |
  "bin_eval BinShortCircuitOr v1 v2 = intval_short_circuit_or v1 v2" |
  "bin_eval BinLeftShift v1 v2 = intval_left_shift v1 v2" |
  "bin_eval BinRightShift v1 v2 = intval_right_shift v1 v2" |
  "bin_eval BinURightShift v1 v2 = intval_uright_shift v1 v2" |
  "bin_eval BinIntegerEquals v1 v2 = intval_equals v1 v2" |
  "bin_eval BinIntegerLessThan v1 v2 = intval_less_than v1 v2" |
  "bin_eval BinIntegerBelow v1 v2 = intval_below v1 v2" |
  "bin_eval BinIntegerTest v1 v2 = intval_test v1 v2" |
  "bin_eval BinIntegerNormalizeCompare v1 v2 = intval_normalize_compare v1 v2" |
  "bin_eval BinIntegerMulHigh v1 v2 = intval_mul_high v1 v2"
(*  "bin_eval op v1 v2 = UndefVal" *)

(* If a binary operator yields a defined (non-UndefVal) result, then both of its
   operands are integer values. Proved by exhaustive case analysis. *)
lemma defined_eval_is_intval:
  shows "bin_eval op x y  UndefVal  (is_IntVal x  is_IntVal y)"
  by (cases op; cases x; cases y; auto)

(* Simplification rules of intval_* operations, collected for use in proofs.
   NOTE(review): this set does not include every operation used by unary_eval and
   bin_eval (e.g. intval_div, intval_mod, intval_short_circuit_or, intval_test,
   intval_normalize_compare, intval_mul_high, intval_is_null,
   intval_reverse_bytes, intval_bit_count). Presumably intentional - confirm
   before extending it, as extra rules may change proofs that use eval_thms. *)
lemmas eval_thms =
  intval_abs.simps intval_negate.simps intval_not.simps
  intval_logic_negation.simps intval_narrow.simps
  intval_sign_extend.simps intval_zero_extend.simps
  intval_add.simps intval_mul.simps intval_sub.simps
  intval_and.simps intval_or.simps intval_xor.simps
  intval_left_shift.simps intval_right_shift.simps
  intval_uright_shift.simps intval_equals.simps
  intval_less_than.simps intval_below.simps

(* not_undef_or_fail x y holds exactly when x and y are the same value and that
   value is not UndefVal. In LaTeX output it is printed as an equation, so that
   results of intval_* functions can be presented as if they were partial functions. *)
inductive not_undef_or_fail :: "Value  Value  bool" where
  "value  UndefVal  not_undef_or_fail value value"

notation (latex output) (* we can pretend intval_* are partial functions *)
  not_undef_or_fail ("_ = _")

(* Big-step evaluation of data-flow trees: evaltree m p e v holds when expression e
   evaluates to value v under method state m and parameter list p.
   - ConstantExpr evaluates to its (well-formed) constant.
   - ParameterExpr i s reads the i-th parameter, which must exist and be valid for s.
   - UnaryExpr / BinaryExpr apply unary_eval / bin_eval to the operand values; the
     result must not be UndefVal.
   - LeafExpr n s reads the value stored for node n in m, which must be valid for s.
   - ConstantVar and VariableExpr have no rule, so pattern variables never evaluate. *)
inductive
  evaltree :: "MapState  Params  IRExpr  Value  bool" ("[_,_]  _  _" 55)
  for m p where

  ConstantExpr:
  "wf_value c
     [m,p]  (ConstantExpr c)  c" |

  ParameterExpr:
  "i < length p; valid_value (p!i) s
     [m,p]  (ParameterExpr i s)  p!i" |

  (* We need to add this to prove certain optimizations
     but it also requires more work to show monotonicity of refinement.
  compatible (stamp_expr te) (stamp_expr fe);*)
  (* The condition selects a branch via val_to_bool. Note that BOTH branches must
     evaluate to defined values, not only the selected one. *)
  ConditionalExpr:
  "[m,p]  ce  cond;
    cond  UndefVal;
    branch = (if val_to_bool cond then te else fe);
    [m,p]  branch  result;
    result  UndefVal;

    [m,p]  te  true;  true   UndefVal;
    [m,p]  fe  false; false  UndefVal
     [m,p]  (ConditionalExpr ce te fe)  result" |

  UnaryExpr:
  "[m,p]  xe  x;
    result = (unary_eval op x);
    result  UndefVal
     [m,p]  (UnaryExpr op xe)  result" |

  BinaryExpr:
  "[m,p]  xe  x;
    [m,p]  ye  y;
    result = (bin_eval op x y);
    result  UndefVal
     [m,p]  (BinaryExpr op xe ye)  result" |

  LeafExpr:
  "val = m n;
    valid_value val s
     [m,p]  LeafExpr n s  val"

(* Compile evaltree into an executable predicate evalT: the state, the parameters
   and the expression are inputs, the value is the output. This is what makes the
   values command below work. *)
code_pred (modes: i ⇒ i ⇒ i ⇒ o ⇒ bool as evalT)
  [show_steps,show_mode_inference,show_intermediate_results] 
  evaltree .

(* Pointwise evaluation of a list of expressions: the i-th expression evaluates
   (by evaltree) to the i-th value, and both lists have the same length. *)
inductive
  evaltrees :: "MapState  Params  IRExpr list  Value list  bool" ("[_,_]  _ [↦] _" 55)
  for m p where

  EvalNil:
  "[m,p]  [] [↦] []" |

  EvalCons:
  "[m,p]  x  xval;
    [m,p]  yy [↦] yyval
     [m,p]  (x#yy) [↦] (xval#yyval)"

(* Executable version evalTs, with the value list as the only output. *)
code_pred (modes: i ⇒ i ⇒ i ⇒ o ⇒ bool as evalTs)
  evaltrees .

(* Example expression: parameter 0 multiplied by itself, where both parameter
   references carry the full signed 32-bit range as their stamp. The values command
   runs the executable semantics on it with the single parameter IntVal 32 5, as a
   quick sanity check of the code_pred setup above. *)
definition sq_param0 :: IRExpr where
  "sq_param0 = BinaryExpr BinMul 
    (ParameterExpr 0 (IntegerStamp 32 (- 2147483648) 2147483647))
    (ParameterExpr 0 (IntegerStamp 32 (- 2147483648) 2147483647))"

values "{v. evaltree new_map_state [IntVal 32 5] sq_param0 v}"

(* We add all the inductive rules as unsafe intro rules. *)
declare evaltree.intros [intro]
declare evaltrees.intros [intro]

(* We derive a safe elimination (forward) reasoning rule for each case.
  Note that each pattern is as general as possible.
  Since evaltree has no rule for ConstantVar or VariableExpr, the eliminations for
  those two presumably just close any goal that assumes such an evaluation. *)
inductive_cases ConstantExprE[elim!]:tag invisible›
  "[m,p]  (ConstantExpr c)  val"
inductive_cases ParameterExprE[elim!]:tag invisible›
  "[m,p]  (ParameterExpr i s)  val"
inductive_cases ConditionalExprE[elim!]:tag invisible›
  "[m,p]  (ConditionalExpr c t f)  val"
inductive_cases UnaryExprE[elim!]:tag invisible›
  "[m,p]  (UnaryExpr op xe)  val"
inductive_cases BinaryExprE[elim!]:tag invisible›
  "[m,p]  (BinaryExpr op xe ye)  val"
inductive_cases LeafExprE[elim!]:tag invisible›
  "[m,p]  (LeafExpr n s)  val"
inductive_cases ConstantVarE[elim!]:tag invisible›
  "[m,p]  (ConstantVar x)  val"
inductive_cases VariableExprE[elim!]:tag invisible›
  "[m,p]  (VariableExpr x s)  val"
inductive_cases EvalNilE[elim!]:tag invisible›
  "[m,p]  [] [↦] vals"
inductive_cases EvalConsE[elim!]:tag invisible›
  "[m,p]  (x#yy) [↦] vals"

(* group these forward rules into a named set, so they can be used together *)
lemmas EvalTreeEtag invisible› = 
  ConstantExprE
  ParameterExprE
  ConditionalExprE
  UnaryExprE
  BinaryExprE
  LeafExprE
  ConstantVarE
  VariableExprE
  EvalNilE
  EvalConsE

subsection ‹Data-flow Tree Refinement›

text ‹We define the induced semantic equivalence relation between expressions.
  Note that syntactic equality implies semantic equivalence, but not vice versa.
›
(* e1 and e2 are semantically equivalent when, for every method state, parameter
   list and value, e1 evaluates to that value exactly when e2 does. *)
definition equiv_exprs :: "IRExpr  IRExpr  bool" ("_  _" 55) where
  "(e1  e2) = ( m p v. (([m,p]  e1  v)  ([m,p]  e2  v)))"

text ‹We also prove that this is a total equivalence relation (@{term "equivp equiv_exprs"})
  (HOL.Equiv\_Relations), so that we can reuse standard results about equivalence relations.
›
(* equiv_exprs is a (total) equivalence relation. The fact is named so that
   standard results about equivp can actually be applied to it, e.g.
   equivp_reflp, equivp_symp and equivp_transp from HOL.Equiv_Relations. *)
lemma equiv_exprs_equivp: "equivp equiv_exprs"
  apply (auto simp add: equivp_def equiv_exprs_def) by (metis equiv_exprs_def)+

text ‹We define a refinement ordering over IRExpr and show that it is a preorder.
  Note that it is asymmetric because e2 may refer to fewer variables than e1.
›
(* Refinement ordering on expressions, making IRExpr a preorder.
   e2 <= e1 (written with the infix notation declared below) presumably means that
   in every method state and parameter list, every value e1 evaluates to is also a
   value of e2 - i.e. e2 refines e1. Strict refinement additionally requires the two
   expressions not to be semantically equivalent. Both definitions are [simp], so
   the preorder laws follow by simp. *)
instantiation IRExpr :: preorder begin

notation less_eq (infix "" 65)

definition
  le_expr_def [simp]:
    "(e2  e1)  ( m p v. (([m,p]  e1  v)  ([m,p]  e2  v)))"

definition
  lt_expr_def [simp]:
    "(e1 < e2)  (e1  e2  ¬ (e1  e2))"

instance proof 
  fix x y z :: IRExpr
  show "x < y  x  y  ¬ (y  x)" by (simp add: equiv_exprs_def; auto)
  show "x  x" by simp
  show "x  y  y  z  x  z" by simp 
qed

end

(* Output-only notation: the refinement ordering e2 <= e1 is displayed with its
   arguments flipped, as "e1 is refined by e2". *)
abbreviation (output) Refines :: "IRExpr  IRExpr  bool" (infix "" 64)
  where "e1  e2  (e2  e1)"

subsection ‹Stamp Masks›

text ‹
A stamp can contain additional range information in the form of masks.
A stamp has an up mask and a down mask,
corresponding to the bits that may be set and the bits that must be set, respectively.

Examples:
  A stamp where no range information is known will have:
    an up mask of -1, as all bits may be set, and
    a down mask of 0, as no bits must be set.

  A stamp known to be one should have:
    an up mask of 1, as only the lowest bit may be set and no others, and
    a down mask of 1, as the lowest bit must be set and no other bit must be.

We currently don't carry mask information in stamps,
and instead assume correct masks to prove optimizations.
›

(* Locale for stamp masks: for each expression, `up` gives the bits that may be set
   in its value and `down` the bits that must be set.
   Whenever e evaluates to an integer value v:
   - up_spec: v has no bit set outside the up mask of e;
   - down_spec: every bit of the down mask of e is set in v.
   The 64-bit masks are converted with ucast to the word type of v. *)
locale stamp_mask =
  fixes up :: "IRExpr  int64" ("")
  fixes down :: "IRExpr  int64" ("")
  assumes up_spec: "[m, p]  e  IntVal b v  (and v (not ((ucast (e))))) = 0"
      and down_spec: "[m, p]  e  IntVal b v  (and (not v) (ucast (e))) = 0"
begin

(* Trivial: any bit of v is either False or True. *)
lemma may_implies_either:
  "[m, p]  e  IntVal b v  bit (e) n  bit v n = False  bit v n = True"
  by simp

(* A bit outside the up mask is never set in the value (from up_spec). *)
lemma not_may_implies_false:
  "[m, p]  e  IntVal b v  ¬(bit (e) n)  bit v n = False"
  by (metis (no_types, lifting) bit.double_compl up_spec bit_and_iff bit_not_iff bit_unsigned_iff 
      down_spec)

(* A bit in the down mask is always set in the value (from down_spec). *)
lemma must_implies_true:
  "[m, p]  e  IntVal b v  bit (e) n  bit v n = True"
  by (metis bit.compl_one bit_and_iff bit_minus_1_iff bit_not_iff impossible_bit ucast_id down_spec)

(* Trivial: any bit of v is either False or True. *)
lemma not_must_implies_either:
  "[m, p]  e  IntVal b v  ¬(bit (e) n)  bit v n = False  bit v n = True"
  by simp

(* For an expression that evaluates, a bit that must be set is also a bit that may
   be set (down mask within up mask).
   NOTE(review): the n < 32 hypothesis does not appear to be used by the proof,
   whose two lemmas carry no such condition; it is kept so the statement used by
   existing proofs stays unchanged. *)
lemma must_implies_may:
  "[m, p]  e  IntVal b v  n < 32  bit (e) n  bit (e) n"
  by (meson must_implies_true not_may_implies_false)

(* If the masks in the first assumption (presumably the up masks of x and y) share
   no bits, then the values of x and y share no set bits either. *)
lemma up_mask_and_zero_implies_zero:
  assumes "and (x) (y) = 0"
  assumes "[m, p]  x  IntVal b xv"
  assumes "[m, p]  y  IntVal b yv"
  shows "and xv yv = 0"
  by (smt (z3) assms and.commute and.right_neutral bit.compl_zero bit.conj_cancel_right ucast_id
      bit.conj_disj_distribs(1) up_spec word_bw_assocs(1) word_not_dist(2) word_ao_absorbs(8)
      and_eq_not_not_or)

(* If every bit that may be set in y (presumably y's up mask) must be set in x
   (presumably x's down mask), then masking yv with xv leaves yv unchanged. *)
lemma not_down_up_mask_and_zero_implies_zero:
  assumes "and (not (x)) (y) = 0"
  assumes "[m, p]  x  IntVal b xv"
  assumes "[m, p]  y  IntVal b yv"
  shows "and xv yv = yv"
  by (metis (no_types, opaque_lifting) assms bit.conj_cancel_left bit.conj_disj_distribs(1,2)
      bit.de_Morgan_disj ucast_id down_spec or_eq_not_not_and up_spec word_ao_absorbs(2,8)
      word_bw_lcs(1) word_not_dist(2))

end

(* Trivial up mask carrying no information: every bit may be set (not 0, all ones). *)
definition IRExpr_up :: "IRExpr  int64" where
  "IRExpr_up e = not 0"

(* Trivial down mask carrying no information: no bit must be set. *)
definition IRExpr_down :: "IRExpr  int64" where
  "IRExpr_down e = 0"

(* Casting 0 from a 64-bit to a 32-bit word gives 0. *)
lemma ucast_zero: "(ucast (0::int64)::int32) = 0"
  by simp

(* Casting -1 (all bits set) from a 64-bit to a 32-bit word gives -1;
   used in the simple_mask interpretation below. *)
lemma ucast_minus_one: "(ucast (-1::int64)::int32) = -1"
  apply transfer by auto

(* The trivial masks (everything may be set, nothing must be set) satisfy the
   stamp_mask assumptions, so the locale lemmas are available with prefix simple_mask. *)
interpretation simple_mask: stamp_mask
  "IRExpr_up :: IRExpr  int64"
  "IRExpr_down :: IRExpr  int64"
  apply unfold_locales
  by (simp add: ucast_minus_one IRExpr_up_def IRExpr_down_def)+

end