(* Theory TreeToGraph *)

section ‹Tree to Graph›

theory TreeToGraph
  imports 
    Semantics.IRTreeEval
    Graph.IRGraph
    Snippets.Snipping
begin

subsection ‹Subgraph to Data-flow Tree›

(* Look up a node of g whose kind and stamp both match the given pair.
   The IDs of g are scanned in ascending order (sorted_list_of_set), so the
   smallest matching ID is returned; None means no node of g matches. *)
fun find_node_and_stamp :: "IRGraph ⇒ (IRNode × Stamp) ⇒ ID option" where
  "find_node_and_stamp g (n,s) =
     find (λi. kind g i = n ∧ stamp g i = s) (sorted_list_of_set(ids g))"

export_code find_node_and_stamp

(* These kinds of nodes are evaluated during the control flow, so are already in MapState.
   is_preevaluated holds exactly for the node kinds listed below and is False
   for every other kind; the rep relation turns such nodes into LeafExpr leaves
   rather than unfolding their inputs. *)
fun is_preevaluated :: "IRNode ⇒ bool" where
  "is_preevaluated (InvokeNode n _ _ _ _ _) = True" |
  "is_preevaluated (InvokeWithExceptionNode n _ _ _ _ _ _) = True" |
  "is_preevaluated (NewInstanceNode n _ _ _) = True" |
  "is_preevaluated (LoadFieldNode n _ _ _) = True" |
  "is_preevaluated (SignedDivNode n _ _ _ _ _) = True" |
  "is_preevaluated (SignedRemNode n _ _ _ _ _) = True" |
  "is_preevaluated (ValuePhiNode n _ _) = True" |
  "is_preevaluated (BytecodeExceptionNode n _ _) = True" |
  "is_preevaluated (NewArrayNode n _ _) = True" |
  "is_preevaluated (ArrayLengthNode n _) = True" |
  "is_preevaluated (LoadIndexedNode n _ _ _) = True" |
  "is_preevaluated (StoreIndexedNode n _ _ _ _ _ _) = True" |
  "is_preevaluated _ = False"

(* g ⊢ n ≃ e: node n of graph g represents the data-flow expression tree e.
   Each rule matches on kind g n and recursively relates the node's inputs.
   Nodes satisfying is_preevaluated become LeafExpr leaves, and PiNode/RefNode
   are transparent: they represent whatever their input node represents. *)
inductive
  rep :: "IRGraph ⇒ ID ⇒ IRExpr ⇒ bool" ("_ ⊢ _ ≃ _" 55)
  for g where

  ConstantNode:
  "⟦kind g n = ConstantNode c⟧
    ⟹ g ⊢ n ≃ (ConstantExpr c)" |

  ParameterNode:
  "⟦kind g n = ParameterNode i;
    stamp g n = s⟧
    ⟹ g ⊢ n ≃ (ParameterExpr i s)" |

  ConditionalNode:
  "⟦kind g n = ConditionalNode c t f;
    g ⊢ c ≃ ce;
    g ⊢ t ≃ te;
    g ⊢ f ≃ fe⟧
    ⟹ g ⊢ n ≃ (ConditionalExpr ce te fe)" |

(* Unary nodes *)
  AbsNode:
  "⟦kind g n = AbsNode x;
    g ⊢ x ≃ xe⟧
    ⟹ g ⊢ n ≃ (UnaryExpr UnaryAbs xe)" |

  ReverseBytesNode:
  "⟦kind g n = ReverseBytesNode x;
    g ⊢ x ≃ xe⟧
    ⟹ g ⊢ n ≃ (UnaryExpr UnaryReverseBytes xe)" |

  BitCountNode:
  "⟦kind g n = BitCountNode x;
    g ⊢ x ≃ xe⟧
    ⟹ g ⊢ n ≃ (UnaryExpr UnaryBitCount xe)" |

  NotNode:
  "⟦kind g n = NotNode x;
    g ⊢ x ≃ xe⟧
    ⟹ g ⊢ n ≃ (UnaryExpr UnaryNot xe)" |

  NegateNode:
  "⟦kind g n = NegateNode x;
    g ⊢ x ≃ xe⟧
    ⟹ g ⊢ n ≃ (UnaryExpr UnaryNeg xe)" |

  LogicNegationNode:
  "⟦kind g n = LogicNegationNode x;
    g ⊢ x ≃ xe⟧
    ⟹ g ⊢ n ≃ (UnaryExpr UnaryLogicNegation xe)" |

(* Binary nodes *)
  AddNode:
  "⟦kind g n = AddNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinAdd xe ye)" |

  MulNode:
  "⟦kind g n = MulNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinMul xe ye)" |

  DivNode:
  "⟦kind g n = SignedFloatingIntegerDivNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinDiv xe ye)" |

  ModNode:
  "⟦kind g n = SignedFloatingIntegerRemNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinMod xe ye)" |

  SubNode:
  "⟦kind g n = SubNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinSub xe ye)" |

  AndNode:
  "⟦kind g n = AndNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinAnd xe ye)" |

  OrNode:
  "⟦kind g n = OrNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinOr xe ye)" |

  XorNode:
  "⟦kind g n = XorNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinXor xe ye)" |

  ShortCircuitOrNode:
  "⟦kind g n = ShortCircuitOrNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinShortCircuitOr xe ye)" |

  LeftShiftNode:
  "⟦kind g n = LeftShiftNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinLeftShift xe ye)" |

  RightShiftNode:
  "⟦kind g n = RightShiftNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinRightShift xe ye)" |

  UnsignedRightShiftNode:
  "⟦kind g n = UnsignedRightShiftNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinURightShift xe ye)" |

  IntegerBelowNode:
  "⟦kind g n = IntegerBelowNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinIntegerBelow xe ye)" |

  IntegerEqualsNode:
  "⟦kind g n = IntegerEqualsNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinIntegerEquals xe ye)" |

  IntegerLessThanNode:
  "⟦kind g n = IntegerLessThanNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinIntegerLessThan xe ye)" |

  IntegerTestNode:
  "⟦kind g n = IntegerTestNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinIntegerTest xe ye)" |

  IntegerNormalizeCompareNode:
  "⟦kind g n = IntegerNormalizeCompareNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinIntegerNormalizeCompare xe ye)" |

  IntegerMulHighNode:
  "⟦kind g n = IntegerMulHighNode x y;
    g ⊢ x ≃ xe;
    g ⊢ y ≃ ye⟧
    ⟹ g ⊢ n ≃ (BinaryExpr BinIntegerMulHigh xe ye)" |

(* Convert Nodes *)
  NarrowNode:
  "⟦kind g n = NarrowNode inputBits resultBits x;
    g ⊢ x ≃ xe⟧
    ⟹ g ⊢ n ≃ (UnaryExpr (UnaryNarrow inputBits resultBits) xe)" |

  SignExtendNode:
  "⟦kind g n = SignExtendNode inputBits resultBits x;
    g ⊢ x ≃ xe⟧
    ⟹ g ⊢ n ≃ (UnaryExpr (UnarySignExtend inputBits resultBits) xe)" |

  ZeroExtendNode:
  "⟦kind g n = ZeroExtendNode inputBits resultBits x;
    g ⊢ x ≃ xe⟧
    ⟹ g ⊢ n ≃ (UnaryExpr (UnaryZeroExtend inputBits resultBits) xe)" |

(* Leaf Node
    TODO: For now, BytecodeExceptionNode is treated as a LeafNode.
*)
  LeafNode:
  "⟦is_preevaluated (kind g n);
    stamp g n = s⟧
    ⟹ g ⊢ n ≃ (LeafExpr n s)" |

(* TODO: For now, ignore narrowing. *)
  PiNode:
  "⟦kind g n = PiNode n' guard;
    g ⊢ n' ≃ e⟧
    ⟹ g ⊢ n ≃ e" |

(* Ref Node *)
  RefNode:
  "⟦kind g n = RefNode n';
    g ⊢ n' ≃ e⟧
    ⟹ g ⊢ n ≃ e" |

(* IsNull Node *)
  IsNullNode:
  "⟦kind g n = IsNullNode v;
    g ⊢ v ≃ lfn⟧
    ⟹ g ⊢ n ≃ (UnaryExpr UnaryIsNull lfn)"

code_pred (modes: i ⇒ i ⇒ o ⇒ bool as exprE) rep .

(* g ⊢ ns [≃] es: list lifting of rep.  The i-th node of ns represents the
   i-th expression of es, and the two lists have the same length. *)
inductive
  replist :: "IRGraph ⇒ ID list ⇒ IRExpr list ⇒ bool" ("_ ⊢ _ [≃] _" 55)
  for g where

  RepNil:
  "g ⊢ [] [≃] []" |

  RepCons:
  "⟦g ⊢ x ≃ xe;
    g ⊢ xs [≃] xse⟧
    ⟹ g ⊢ x#xs [≃] xe#xse"

code_pred (modes: i ⇒ i ⇒ o ⇒ bool as exprListE) replist .

(* Node n of g is a well-formed term under state m and parameters p when it
   represents some expression tree that evaluates to some value. *)
definition wf_term_graph :: "MapState ⇒ Params ⇒ IRGraph ⇒ ID ⇒ bool" where
  "wf_term_graph m p g n = (∃ e. (g ⊢ n ≃ e) ∧ (∃ v. ([m, p] ⊢ e ↦ v)))"

(* Enumerate the expression trees represented by node 4 of the example graph eg2_sq. *)
values "{t. eg2_sq ⊢ 4 ≃ t}"

(* Elimination rules for rep, one per shape of the represented expression.
   They are declared [elim!] (safe elimination rules) and hidden from the
   generated document via the invisible tag. *)
inductive_cases ConstantNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (ConstantExpr c)"
inductive_cases ParameterNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (ParameterExpr i s)"
inductive_cases ConditionalNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (ConditionalExpr ce te fe)"
inductive_cases AbsNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (UnaryExpr UnaryAbs xe)"
inductive_cases ReverseBytesNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (UnaryExpr UnaryReverseBytes xe)"
inductive_cases BitCountNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (UnaryExpr UnaryBitCount xe)"
inductive_cases NotNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (UnaryExpr UnaryNot xe)"
inductive_cases NegateNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (UnaryExpr UnaryNeg xe)"
inductive_cases LogicNegationNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (UnaryExpr UnaryLogicNegation xe)"
inductive_cases AddNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinAdd xe ye)"
inductive_cases MulNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinMul xe ye)"
inductive_cases DivNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinDiv xe ye)"
inductive_cases ModNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinMod xe ye)"
inductive_cases SubNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinSub xe ye)"
inductive_cases AndNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinAnd xe ye)"
inductive_cases OrNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinOr xe ye)"
inductive_cases XorNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinXor xe ye)"
inductive_cases ShortCircuitOrE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinShortCircuitOr xe ye)"
inductive_cases LeftShiftNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinLeftShift xe ye)"
inductive_cases RightShiftNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinRightShift xe ye)"
inductive_cases UnsignedRightShiftNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinURightShift xe ye)"
inductive_cases IntegerBelowNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinIntegerBelow xe ye)"
inductive_cases IntegerEqualsNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinIntegerEquals xe ye)"
inductive_cases IntegerLessThanNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinIntegerLessThan xe ye)"
inductive_cases IntegerMulHighNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinIntegerMulHigh xe ye)"
inductive_cases IntegerTestNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinIntegerTest xe ye)"
inductive_cases IntegerNormalizeCompareNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (BinaryExpr BinIntegerNormalizeCompare xe ye)"
inductive_cases NarrowNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (UnaryExpr (UnaryNarrow ib rb) xe)"
inductive_cases SignExtendNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (UnaryExpr (UnarySignExtend ib rb) xe)"
inductive_cases ZeroExtendNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (UnaryExpr (UnaryZeroExtend ib rb) xe)"
inductive_cases LeafNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (LeafExpr n s)"
inductive_cases IsNullNodeE[elim!]:\<^marker>‹tag invisible›
  "g ⊢ n ≃ (UnaryExpr UnaryIsNull lfn)"

(* group the rep elimination rules (the inductive_cases facts) into one named fact set *)
lemmas RepE\<^marker>‹tag invisible› = 
  ConstantNodeE
  ParameterNodeE
  ConditionalNodeE
  AbsNodeE
  ReverseBytesNodeE
  BitCountNodeE
  NotNodeE
  NegateNodeE
  LogicNegationNodeE
  AddNodeE
  MulNodeE
  DivNodeE
  ModNodeE
  SubNodeE
  AndNodeE
  OrNodeE
  XorNodeE
  ShortCircuitOrE
  LeftShiftNodeE
  RightShiftNodeE
  UnsignedRightShiftNodeE
  IntegerBelowNodeE
  IntegerEqualsNodeE
  IntegerLessThanNodeE
  IntegerMulHighNodeE
  IntegerTestNodeE
  IntegerNormalizeCompareNodeE
  NarrowNodeE
  SignExtendNodeE
  ZeroExtendNodeE
  LeafNodeE
  IsNullNodeE 

subsection ‹Data-flow Tree to Subgraph›

(* Creates the IRNode for a given unary operator applied to input node v.
   Each equation mirrors the corresponding unary rule of rep, e.g. AbsNode v
   represents UnaryExpr UnaryAbs. *)
fun unary_node :: "IRUnaryOp ⇒ ID ⇒ IRNode" where
  "unary_node UnaryAbs v = AbsNode v" |
  "unary_node UnaryNot v = NotNode v" |
  "unary_node UnaryNeg v = NegateNode v" |
  "unary_node UnaryLogicNegation v = LogicNegationNode v" |
  "unary_node (UnaryNarrow ib rb) v = NarrowNode ib rb v" |
  "unary_node (UnarySignExtend ib rb) v = SignExtendNode ib rb v" |
  "unary_node (UnaryZeroExtend ib rb) v = ZeroExtendNode ib rb v" |
  "unary_node UnaryIsNull v = IsNullNode v" |
  "unary_node UnaryReverseBytes v = ReverseBytesNode v" |
  "unary_node UnaryBitCount v = BitCountNode v"

(* Creates the appropriate IRNode for a given binary operator applied to
   input nodes x and y.  Each equation mirrors the corresponding binary rule
   of rep (note BinDiv/BinMod map to the SignedFloatingInteger Div/Rem nodes). *)
fun bin_node :: "IRBinaryOp ⇒ ID ⇒ ID ⇒ IRNode" where
  "bin_node BinAdd x y = AddNode x y" |
  "bin_node BinMul x y = MulNode x y" |
  "bin_node BinDiv x y = SignedFloatingIntegerDivNode x y" |
  "bin_node BinMod x y = SignedFloatingIntegerRemNode x y" |
  "bin_node BinSub x y = SubNode x y" |
  "bin_node BinAnd x y = AndNode x y" |
  "bin_node BinOr  x y = OrNode x y" |
  "bin_node BinXor x y = XorNode x y" |
  "bin_node BinShortCircuitOr x y = ShortCircuitOrNode x y" |
  "bin_node BinLeftShift x y = LeftShiftNode x y" |
  "bin_node BinRightShift x y = RightShiftNode x y" |
  "bin_node BinURightShift x y = UnsignedRightShiftNode x y" |
  "bin_node BinIntegerEquals x y = IntegerEqualsNode x y" |
  "bin_node BinIntegerLessThan x y = IntegerLessThanNode x y" |
  "bin_node BinIntegerBelow x y = IntegerBelowNode x y" |
  "bin_node BinIntegerTest x y = IntegerTestNode x y" |
  "bin_node BinIntegerNormalizeCompare x y = IntegerNormalizeCompareNode x y" |
  "bin_node BinIntegerMulHigh x y = IntegerMulHighNode x y"

(* fresh_id g n holds when n is not already used as a node ID of g. *)
inductive fresh_id :: "IRGraph ⇒ ID ⇒ bool" where
  "n ∉ ids g ⟹ fresh_id g n"

code_pred fresh_id .

(* This generates a specific fresh ID (max+1), in a code-friendly way:
   sorted_list_of_set lists the IDs of g in ascending order, so its last
   element is the largest ID.
   NOTE(review): if g has no nodes the list is [] and last [] is unspecified
   (generated code has no equation for it), so this presumably assumes a
   non-empty graph — confirm with callers. *)
fun get_fresh_id :: "IRGraph ⇒ ID" where
(* Previous attempts - cannot generate code due to nat not Enum. 
  "get_fresh_id g = 100"
  "get_fresh_id g = (ffold max (0::nat) (f_ids g))"
  "get_fresh_id g = fst(last(as_list g))"
  "get_fresh_id g = last(sorted_list_of_set (dom (Rep_IRGraph g)))"
*)
  "get_fresh_id g = last(sorted_list_of_set(ids g)) + 1"

export_code get_fresh_id
(* these examples return 6 and 7 respectively *)
value "get_fresh_id eg2_sq"
value "get_fresh_id (add_node 6 (ParameterNode 2, default_stamp) eg2_sq)"

(* unique g node (g', n): n is a node of g' with the given (kind, stamp) pair.
   Exists: if find_node_and_stamp finds such a node in g it is reused, and
   the graph is returned unchanged.
   New: otherwise the pair is added to g under the ID get_fresh_id g. *)
inductive unique :: "IRGraph ⇒ (IRNode × Stamp) ⇒ (IRGraph × ID) ⇒ bool" where
  Exists:
  "find_node_and_stamp g node = Some n
    ⟹ unique g node (g, n)" |
  New:
  "⟦find_node_and_stamp g node = None;
    n = get_fresh_id g;
    g' = add_node n node g⟧
    ⟹ unique g node (g', n)"

code_pred (modes: i ⇒ i ⇒ o ⇒ bool as uniqueE) unique .

(* Second version of tree insertion into graph:

      g ⊕ expr ↝ (g',n) re-inserts expr into g and returns the new root n.

   This means that we can try to re-use existing nodes by finding node IDs bottom-up.
   Sub-expressions are inserted left to right, threading the graph through,
   and every constant, parameter, conditional, unary and binary node is
   created via unique, so an existing node with the same kind and stamp is
   reused.  A LeafExpr n s is only accepted when node n of g is pre-evaluated
   and has stamp s; the graph is then returned unchanged.
*)
inductive
  unrep :: "IRGraph ⇒ IRExpr ⇒ (IRGraph × ID) ⇒ bool" ("_ ⊕ _ ↝ _" 55)
  where

  UnrepConstantNode:
  "unique g (ConstantNode c, constantAsStamp c) (g1, n)
    ⟹ g ⊕ (ConstantExpr c) ↝ (g1, n)" |

  UnrepParameterNode:
  "unique g (ParameterNode i, s) (g1, n)
    ⟹ g ⊕ (ParameterExpr i s) ↝ (g1, n)" |

  UnrepConditionalNode:
  "⟦g ⊕ ce ↝ (g1, c);
    g1 ⊕ te ↝ (g2, t);
    g2 ⊕ fe ↝ (g3, f);
    s' = meet (stamp g3 t) (stamp g3 f);
    unique g3 (ConditionalNode c t f, s') (g4, n)⟧
    ⟹ g ⊕ (ConditionalExpr ce te fe) ↝ (g4, n)" |

  UnrepUnaryNode:
  "⟦g ⊕ xe ↝ (g1, x);
    s' = stamp_unary op (stamp g1 x);
    unique g1 (unary_node op x, s') (g2, n)⟧
    ⟹ g ⊕ (UnaryExpr op xe) ↝ (g2, n)" |

  UnrepBinaryNode:
  "⟦g ⊕ xe ↝ (g1, x);
    g1 ⊕ ye ↝ (g2, y);
    s' = stamp_binary op (stamp g2 x) (stamp g2 y);
    unique g2 (bin_node op x y, s') (g3, n)⟧
    ⟹ g ⊕ (BinaryExpr op xe ye) ↝ (g3, n)" |

  AllLeafNodes:
  "⟦stamp g n = s;
    is_preevaluated (kind g n)⟧
    ⟹ g ⊕ (LeafExpr n s) ↝ (g, n)"

(*  UnrepNil:
  "g ◃L [] ↝ (g, [])" |

(* TODO: this will fail for [xe,ye] where they are not equal.
         Figure out how to generate code for this?
*)
  UnrepCons:
  "⟦g ◃ xe ↝ (g2, x);
    g2 ◃L xes ↝ (g3, xs)⟧
    ⟹ g ◃L (xe#xes) ↝ (g3, x#xs)"*)

code_pred (modes: i ⇒ i ⇒ o ⇒ bool as unrepE)
(*
  [show_steps,show_mode_inference,show_intermediate_results] 
*)  unrep .

snipbegin ‹uniqueRules›
text ‹
\begin{center}
@{thm[mode=Rule] unique.Exists [no_vars]}\\[8px]
@{thm[mode=Rule] unique.New [no_vars]}\\[8px]
\end{center}
›
snipend -

snipbegin ‹unrepRules›
text ‹
\begin{center}
@{thm[mode=Rule] unrep.UnrepConstantNode [no_vars]}\\[8px]
@{thm[mode=Rule] unrep.UnrepParameterNode [no_vars]}\\[8px]
@{thm[mode=Rule] unrep.UnrepConditionalNode [no_vars]}\\[8px]
@{thm[mode=Rule] unrep.UnrepBinaryNode [no_vars]}\\[8px]
@{thm[mode=Rule] unrep.UnrepUnaryNode [no_vars]}\\[8px]
@{thm[mode=Rule] unrep.AllLeafNodes [no_vars]}\\[8px]
\end{center}
›
snipend -

(*
instantiation IRGraph :: equal begin

definition "(g1 = g2) ⟷ 
              (f_ids g1 = f_ids g2 ∧
               (∀n. (n ∈ ids g1 ⟹ (Rep_IRGraph g1 n = Rep_IRGraph g2 n))))"
instance proof 
  fix x y :: IRGraph
  show "x = y ⟷ x ≤ y ∧ ¬ (y ≤ x)" by (simp add: equiv_exprs_def; auto)
  show "x ≤ x" by simp
  show "x ≤ y ⟹ y ≤ z ⟹ x ≤ z" by simp 
qed
end
*)

(*values "{(n, g) . (eg2_sq ⊕ sq_param0 ↝ (g, n))}"*)

subsection ‹Lift Data-flow Tree Semantics›

(* [g, m, p] ⊢ n ↦ v: node n of graph g evaluates to v under state m and
   parameters p.  This holds when n represents some expression tree e and
   e evaluates to v. *)
inductive encodeeval :: "IRGraph ⇒ MapState ⇒ Params ⇒ ID ⇒ Value ⇒ bool" 
  ("[_,_,_] ⊢ _ ↦ _" 50)
  where
  "(g ⊢ n ≃ e) ∧ ([m,p] ⊢ e ↦ v) ⟹ [g, m, p] ⊢ n ↦ v"

code_pred (modes: i ⇒ i ⇒ i ⇒ i ⇒ o ⇒ bool) encodeeval .


(* List version of encodeeval: the node list nids represents the expression
   list es (via replist), and es evaluates to the value list vs. *)
inductive encodeEvalAll :: "IRGraph ⇒ MapState ⇒ Params ⇒ ID list ⇒ Value list ⇒ bool"
  ("[_,_,_] ⊢ _ [↦] _" 60) where
  "(g ⊢ nids [≃] es) ∧ ([m, p] ⊢ es [↦] vs) ⟹ ([g, m, p] ⊢ nids [↦] vs)"

code_pred (modes: i ⇒ i ⇒ i ⇒ i ⇒ o ⇒ bool) encodeEvalAll .


subsection ‹Graph Refinement›

(* g ⊢ n ⊴ e: node n of g represents some expression tree e' with e' ≤ e
   (the expression ordering unfolded by le_expr_def). *)
definition graph_represents_expression :: "IRGraph ⇒ ID ⇒ IRExpr ⇒ bool" 
  ("_ ⊢ _ ⊴ _" 50)
  where
  "(g ⊢ n ⊴ e) = (∃e' . (g ⊢ n ≃ e') ∧ (e' ≤ e))"

(* graph_refinement g1 g2 holds when every node ID of g1 is also in g2, and for
   each such node, every expression e it represents in g1 is matched in g2 by
   a represented expression e' with e' ≤ e (graph_represents_expression). *)
definition graph_refinement :: "IRGraph ⇒ IRGraph ⇒ bool" where
  "graph_refinement g1 g2 = 
        ((ids g1 ⊆ ids g2) ∧
        (∀ n . n ∈ ids g1 ⟶ (∀e. (g1 ⊢ n ≃ e) ⟶ (g2 ⊢ n ⊴ e))))"

(* Graph refinement preserves node evaluation: any value a node of g1
   evaluates to under m and p, the same node of g2 also evaluates to. *)
lemma graph_refinement:
  "graph_refinement g1 g2 ⟹ 
   (∀n m p v. n ∈ ids g1 ⟶ ([g1, m, p] ⊢ n ↦ v) ⟶ ([g2, m, p] ⊢ n ↦ v))"
  by (meson encodeeval.simps graph_refinement_def graph_represents_expression_def le_expr_def)

subsection ‹Maximal Sharing›

(* g has maximal sharing when any two nodes in true_ids g that represent the
   same expression tree and carry the same stamp are the same node. *)
definition maximal_sharing:
  "maximal_sharing g = (∀ n1 n2 . n1 ∈ true_ids g ∧ n2 ∈ true_ids g ⟶ 
      (∀ e. (g ⊢ n1 ≃ e) ∧ (g ⊢ n2 ≃ e) ∧ (stamp g n1 = stamp g n2) ⟶ n1 = n2))"

end