Theory IRGraph

subsection ‹IR Graph Type›

theory IRGraph
  imports 
    IRNodeHierarchy
    Stamp
    "HOL-Library.FSet"
    "HOL.Relation"
begin

text ‹This theory defines the main Graal data structure - an entire IR Graph.›

text ‹
IRGraph is defined as a partial map from node IDs to pairs of a node and its stamp,
with a finite domain.
The finite domain is required so that code can be generated and an interpreter produced.
›
(* The representation type: a partial map from node IDs to (node, stamp) pairs,
   restricted to maps whose domain is finite. *)
typedef IRGraph = "{g :: ID  (IRNode × Stamp) . finite (dom g)}"
proof -
  (* Non-emptiness of the carrier set, witnessed by the empty map. *)
  have "finite(dom(Map.empty))  ran Map.empty = {}" by auto
  then show ?thesis
    by fastforce
qed

(* Register the typedef with the Lifting package, so that the lift_definition
   commands below can define graph operations directly on the representation map. *)
setup_lifting type_definition_IRGraph

(* The node IDs present in the graph: every ID in the domain of the map except
   those whose entry holds NoNode. Lemma ids_some below shows that x is in ids g
   exactly when kind g x is not NoNode. *)
lift_definition ids :: "IRGraph  ID set"
  is "λg. {nid  dom g . s. g nid = (Some (NoNode, s))}" .

(* Turn a partial map m into a total function of the key k. When m has no entry
   for k the result is def; otherwise conv is applied to the stored value. *)
fun with_default :: "'c  ('b  'c)  (('a  'b)  'a  'c)" where
  "with_default def conv = (λm k.
    (case m k of None  def | Some v  conv v))"

(* The node stored at an ID, or NoNode when the ID has no entry. *)
lift_definition kind :: "IRGraph  (ID  IRNode)"
  is "with_default NoNode fst" .

(* The stamp stored at an ID, or IllegalStamp when the ID has no entry. *)
lift_definition stamp :: "IRGraph  ID  Stamp"
  is "with_default IllegalStamp snd" .

(* Store the (node, stamp) pair k at nid, overwriting any previous entry.
   A pair whose node is NoNode is ignored and the graph is returned unchanged,
   so this operation never stores NoNode. The "by simp" obligation shows that
   the result still has a finite domain. *)
lift_definition add_node :: "ID  (IRNode × Stamp)  IRGraph  IRGraph"
  is "λnid k g. if fst k = NoNode then g else g(nid  k)" by simp

(* Delete the entry for nid, if there is one. *)
lift_definition remove_node :: "ID  IRGraph  IRGraph"
  is "λnid g. g(nid := None)" by simp

(* Overwrite the entry at nid with k. A pair whose node is NoNode is ignored.
   NOTE(review): the defining term is identical to add_node's. The separate
   constant presumably expresses intent at call sites; its code equation and
   lookup lemmas are stated separately below. *)
lift_definition replace_node :: "ID  (IRNode × Stamp)  IRGraph  IRGraph"
  is "λnid k g. if fst k = NoNode then g else g(nid  k)" by simp

(* All entries of the graph as (id, node, stamp) triples, in ascending order of ID.
   The lookup with "the" is safe because k ranges over dom g. Note that the list
   ranges over dom g, not over ids g. *)
lift_definition as_list :: "IRGraph  (ID × IRNode × Stamp) list"
  is "λg. map (λk. (k, the (g k))) (sorted_list_of_set (dom g))" .

(* Remove every list entry whose node is NoNode (see lemma no_node_clears). *)
fun no_node :: "(ID × (IRNode × Stamp)) list  (ID × (IRNode × Stamp)) list" where
  "no_node g = filter (λn. fst (snd n)  NoNode) g"

(* Build a graph from an association list of (id, (node, stamp)) entries, after
   discarding the entries whose node is NoNode. Because map_of is used, when an ID
   occurs several times the first occurrence wins. This is the constructor used
   for code generation (see the code_datatype declaration below). *)
lift_definition irgraph :: "(ID × (IRNode × Stamp)) list  IRGraph"
  is "map_of  no_node"
  by (simp add: finite_dom_map_of)

(* The graph's entries as a set of (id, node, stamp) triples, one for each ID in ids g. *)
definition as_set :: "IRGraph  (ID × (IRNode × Stamp)) set" where
  "as_set g = {(n, kind g n, stamp g n) | n . n  ids g}"

(* The IDs of the graph, minus those whose node is a RefNode (with any target n'). *)
definition true_ids :: "IRGraph  ID set" where
  "true_ids g = ids g - {n  ids g. n' . kind g n = RefNode n'}"

(* Domain subtraction on relations, declared as an infix operator with priority 30.
   Presumably, as in the Z notation, the result keeps exactly the pairs of r whose
   first component is not in s. Confirm against the operator symbols in the definition. *)
definition domain_subtraction :: "'a set  ('a × 'b) set  ('a × 'b) set"
  (infix "" 30) where
  "domain_subtraction s r = {(x, y) . (x, y)  r  x  s}"

(* In LaTeX document output, print domain_subtraction using the \ndres macro
   between its two arguments. *)
notation (latex)
  domain_subtraction ("_ latex‹$\\ndres$› _")

(* Code generation setup: declare irgraph as the constructor of IRGraph for the code
   generator, so that generated code represents a graph by the association list
   passed to irgraph. The [code] lemmas below state equations on graphs of the form
   irgraph g. *)
code_datatype irgraph

(* The same set comprehension that defines ids, stated on a plain partial map.
   It is used in the proof of the code equation for ids (lemma irgraph below). *)
fun filter_none where
  "filter_none g = {nid  dom g . s. g nid = (Some (NoNode, s))}"

(* After no_node, no remaining entry has NoNode as its node. *)
lemma no_node_clears:
  "res = no_node xs  (x  set res. fst (snd x)  NoNode)"
  by simp

(* If no entry of xs has NoNode as its node, filter_none removes nothing from the
   domain of map_of xs. *)
lemma dom_eq:
  assumes "x  set xs. fst (snd x)  NoNode"
  shows "filter_none (map_of xs) = dom (map_of xs)"
  using assms map_of_SomeD by fastforce

(* Combining the two lemmas above: for a list filtered by no_node, filter_none
   yields exactly the set of its keys. *)
lemma fil_eq:
  "filter_none (map_of (no_node xs)) = set (map fst (no_node xs))"
  by (metis no_node_clears dom_eq dom_map_of_conv_image_fst list.set_map)

(* Code equation for ids: the IDs of irgraph m are the keys that remain in m after
   removing the NoNode entries. *)
lemma irgraph[code]: "ids (irgraph m) = set (map fst (no_node m))"
  by (metis fil_eq Rep_IRGraph eq_onp_same_args filter_none.simps ids.abs_eq irgraph.abs_eq
      irgraph.rep_eq mem_Collect_eq)

(* Code equation for the representation: the map underlying irgraph m is map_of of
   the filtered list. Presumably this is what lets operations defined on the
   representation (via lift_definition) be evaluated on graphs built with irgraph. *)
lemma [code]: "Rep_IRGraph (irgraph m) = map_of (no_node m)"
  by (simp add: irgraph.rep_eq)

― ‹The set of input IDs of the node at nid, i.e. its inputs_of list as a set›
fun inputs :: "IRGraph  ID  ID set" where
  "inputs g nid = set (inputs_of (kind g nid))"
― ‹The set of successor IDs of the node at nid, i.e. its successors_of list as a set›
fun succ :: "IRGraph  ID  ID set" where
  "succ g nid = set (successors_of (kind g nid))"
― ‹Input-edge relation of the graph: a pair (i, j) for every node i in ids g and every input j of i›
fun input_edges :: "IRGraph  ID rel" where
  "input_edges g = ( i  ids g. {(i,j)|j. j  (inputs g i)})"
― ‹The usages of nid: the nodes in the graph that have nid among their inputs›
fun usages :: "IRGraph  ID  ID set" where
  "usages g nid = {i. i  ids g  nid  inputs g i}"
― ‹Successor-edge relation of the graph: a pair (i, j) for every node i in ids g and every successor j of i›
fun successor_edges :: "IRGraph  ID rel" where
  "successor_edges g = ( i  ids g. {(i,j)|j . j  (succ  g i)})"
― ‹The predecessors of nid: the nodes in the graph that have nid among their successors›
fun predecessors :: "IRGraph  ID  ID set" where
  "predecessors g nid = {i. i  ids g  nid  succ g i}"
― ‹The IDs in the graph whose node satisfies the predicate sel›
fun nodes_of :: "IRGraph  (IRNode  bool)  ID set" where
  "nodes_of g sel = {nid  ids g . sel (kind g nid)}"
― ‹Apply sel to the node stored at nid (presumably sel selects one of the node's edge fields)›
fun edge :: "(IRNode  'a)  ID  IRGraph  'a" where
  "edge sel nid g = sel (kind g nid)"

― ‹The inputs of nid, in inputs_of order, whose node satisfies f›
fun filtered_inputs :: "IRGraph  ID  (IRNode  bool)  ID list" where
  "filtered_inputs g nid f = filter (f  (kind g)) (inputs_of (kind g nid))"
― ‹The successors of nid, in successors_of order, whose node satisfies f›
fun filtered_successors :: "IRGraph  ID  (IRNode  bool)  ID list" where
  "filtered_successors g nid f = filter (f  (kind g)) (successors_of (kind g nid))"
― ‹The usages of nid whose node satisfies f›
fun filtered_usages :: "IRGraph  ID  (IRNode  bool)  ID set" where
  "filtered_usages g nid f = {n  (usages g nid). f (kind g n)}"

― ‹True when the graph contains no node IDs›
fun is_empty :: "IRGraph  bool" where
  "is_empty g = (ids g = {})"

― ‹The smallest usage ID of nid. When nid has no usages, the sorted list is empty and
   the result is hd [], an unspecified ID, so callers should check usages g nid first›
fun any_usage :: "IRGraph  ID  ID" where
  "any_usage g nid = hd (sorted_list_of_set (usages g nid))"

(* Membership in ids coincides with having a node other than NoNode.
   Each direction is proved separately by unfolding the lifted definitions. *)
lemma ids_some[simp]: "x  ids g  kind g x  NoNode" 
proof -
  (* Forward direction: an ID in ids g does not hold NoNode. *)
  have that: "x  ids g  kind g x  NoNode"
    by (auto simp add: kind.rep_eq ids.rep_eq)
  (* Backward direction, by case analysis on whether x has an entry in the map. *)
  have "kind g x  NoNode  x  ids g"
    by (cases "Rep_IRGraph g x = None"; auto simp add: ids_def kind_def)
  from this that show ?thesis 
    by auto
qed

(* An ID outside the graph maps to NoNode. This follows from ids_some, which is a simp rule. *)
lemma not_in_g: 
  assumes "nid  ids g"
  shows "kind g nid = NoNode"
  using assms by simp

(* Abs_IRGraph followed by Rep_IRGraph is the identity on maps with a finite domain. *)
lemma valid_creation[simp]:
  "finite (dom g)  Rep_IRGraph (Abs_IRGraph g) = g"
  by (metis Abs_IRGraph_inverse Rep_IRGraph mem_Collect_eq)

(* The ids of any graph form a finite set. *)
lemma [simp]: "finite (ids g)" 
  using Rep_IRGraph by (simp add: ids.rep_eq)

(* The same, for graphs built with irgraph. *)
lemma [simp]: "finite (ids (irgraph g))" 
  by (simp add: finite_dom_map_of)

(* Unfold ids, kind and stamp on a graph built with Abs_IRGraph directly from a
   map with a finite domain. *)
lemma [simp]: "finite (dom g)  ids (Abs_IRGraph g) = {nid  dom g . s. g nid = Some (NoNode, s)}"
  by (simp add: ids.rep_eq)

lemma [simp]: "finite (dom g)  kind (Abs_IRGraph g) = (λx . (case g x of None  NoNode | Some n  fst n))"
  by (simp add: kind.rep_eq)

lemma [simp]: "finite (dom g)  stamp (Abs_IRGraph g) = (λx . (case g x of None  IllegalStamp | Some n  snd n))"
  by (simp add: stamp.rep_eq)

(* Simplification rules that compute ids, kind and stamp of a graph built with
   irgraph directly from the NoNode-filtered association list. The ids rule restates
   lemma irgraph as a simp rule. *)
lemma [simp]: "ids (irgraph g) = set (map fst (no_node g))" 
  by (simp add: irgraph)

lemma [simp]: "kind (irgraph g) = (λnid. (case (map_of (no_node g)) nid of None  NoNode | Some n  fst n))" 
  by (simp add: kind.rep_eq irgraph.rep_eq)

lemma [simp]: "stamp (irgraph g) = (λnid. (case (map_of (no_node g)) nid of None  IllegalStamp | Some n  snd n))" 
  by (simp add: stamp.rep_eq irgraph.rep_eq)

(* Updating map_of g at k with v is the same as prepending (k, v) to g. *)
lemma map_of_upd: "(map_of g)(k  v) = (map_of ((k, v) # g))"
  by simp

(* Code equation for replace_node: replacing nid in irgraph g amounts to prepending
   (nid, k) to the association list. Under map_of the new first entry shadows older
   entries for nid. When the node of k is NoNode, no_node discards the new entry again,
   and the graph is unchanged, matching the definition of replace_node.
   TODO: this proof should be simpler. *)
lemma [code]: "replace_node nid k (irgraph g) = (irgraph ( ((nid, k) # g)))"
proof (cases "fst k = NoNode")
  case True
  then show ?thesis
    by (metis (mono_tags, lifting) Rep_IRGraph_inject filter.simps(2) irgraph.abs_eq no_node.simps 
        replace_node.rep_eq snd_conv)
next
  case False
  then show ?thesis
    by (smt (verit, ccfv_SIG) irgraph_def Rep_IRGraph comp_apply eq_onp_same_args filter.simps(2)
        id_def irgraph.rep_eq map_fun_apply map_of_upd mem_Collect_eq no_node.elims replace_node_def
        replace_node.abs_eq snd_eqD)
qed

(* Code equation for add_node: adding (nid, k) to irgraph g amounts to prepending
   (nid, k) to the association list. *)
lemma [code]: "add_node nid k (irgraph g) = (irgraph (((nid, k) # g)))"
  by (smt (verit) Rep_IRGraph_inject add_node.rep_eq filter.simps(2) irgraph.rep_eq map_of_upd
      snd_conv no_node.simps)

(* Looking up nid after add_node. If the new node is not NoNode, kind and stamp
   return the new pair; otherwise the node at nid is unchanged. *)
lemma add_node_lookup:
  "gup = add_node nid (k, s) g  
    (if k  NoNode then kind gup nid = k  stamp gup nid = s else kind gup nid = kind g nid)"
proof (cases "k = NoNode")
  case True
  then show ?thesis
    by (simp add: add_node.rep_eq kind.rep_eq)
next
  case False
  then show ?thesis
    by (simp add: kind.rep_eq add_node.rep_eq stamp.rep_eq)
qed

(* After remove_node, nid has no node (NoNode) and its stamp is IllegalStamp. *)
lemma remove_node_lookup:
  "gup = remove_node nid g  kind gup nid = NoNode  stamp gup nid = IllegalStamp"
  by (simp add: kind.rep_eq remove_node.rep_eq stamp.rep_eq)

(* After replace_node with a node other than NoNode, kind and stamp at nid return the new pair. *)
lemma replace_node_lookup[simp]:
  "gup = replace_node nid (k, s) g  k  NoNode  kind gup nid = k  stamp gup nid = s"
  by (simp add: replace_node.rep_eq kind.rep_eq stamp.rep_eq)

(* replace_node leaves the other IDs untouched: every ID of g other than nid is
   still in the new graph and holds the same node. *)
lemma replace_node_unchanged:
  "gup = replace_node nid (k, s) g  ( n  (ids g - {nid}) . n  ids g  n  ids gup  kind g n = kind gup n)" 
  by (simp add: kind.rep_eq replace_node.rep_eq)

subsubsection "Example Graphs"
text "Example 1: minimal graph with just a start node and a return (end) node"
(* Node 0 is a StartNode pointing at node 1 (presumably its successor). Node 1 is
   a ReturnNode with both arguments None. Both nodes carry VoidStamp. *)
definition start_end_graph:: IRGraph where
  "start_end_graph = irgraph [(0, StartNode None 1, VoidStamp), (1, ReturnNode None None, VoidStamp)]"

text ‹Example 2:
  public static int sq(int x) { return x * x; }

             [1 P(0)]
               \ /
  [0 Start]   [4 *]
       |      /
       V     /
      [5 Return]
›
(* Encodes the diagram above. The parameter node 1 is used twice as input to the
   multiplication node 4, whose result is returned by node 5. The start node 0
   points at node 5. *)
definition eg2_sq :: "IRGraph" where
  "eg2_sq = irgraph [
    (0, StartNode None 5, VoidStamp),
    (1, ParameterNode 0, default_stamp),
    (4, MulNode 1 1, default_stamp),
    (5, ReturnNode (Some 4) None, default_stamp)
   ]"

(* TODO: to include the float type (used by stamps) we need
         a code equation for float_of but it is not clear how
         to implement this correctly
lemma[code]: "float_of n = 0"
*)

(* Test the code generation by evaluating input_edges and the usages of node 1
   on the example graph eg2_sq. *)
value "input_edges eg2_sq"
value "usages eg2_sq 1"

end