File ‹rewrites.ML›
(* Parameters of the rewrite DSL: the term-level constants from which the
   proof obligations for a rewrite are assembled. *)
signature RewriteSystem =
sig
  (* Head constant applied to the rewrite term in the preservation goal. *)
  val preservation: term
  (* Head constant applied to the rewrite term and the phase's measure term
     in the termination goal. *)
  val termination: term
  (* Head constant applied to a marked-up term in the optional intval goal. *)
  val intval: term
end
(* Public interface of the rewrite DSL produced by the DSL_Rewrites functor. *)
signature DSL_REWRITES =
sig
  (* Build a rewrite record from a binding, a marked-up rewrite term and the
     term it was produced from; raises TERM if the term is not a rewrite. *)
  val of_term: binding -> term -> term -> rewrite
  (* Checked proposition stating preservation for the rewrite. *)
  val preservation_of: Proof.context -> rewrite -> term
  (* Checked proposition stating termination for the rewrite. *)
  val termination_of: Proof.context -> rewrite -> term
  (* Proposition used to define the "<name>_code" constant of the rewrite. *)
  val code_of: Proof.context -> rewrite -> term
  (* Command entry point: ((name, options), rewrite source text) to a proof state. *)
  val rewrite_cmd: (binding * Token.src list) * string -> Proof.context -> Proof.state
end
functor DSL_Rewrites(System: RewriteSystem): DSL_REWRITES =
struct
(* Turns a marked-up rewrite term into a rewrite record.
   name   - binding stored in the record
   term   - term headed by Markup.Rewrite.Transform or Markup.Rewrite.Conditional
   source - stored unchanged in the record's source field
   The proofs and code fields start out empty.
   Raises TERM when term has neither of the two expected heads. *)
fun of_term name term source =
  let
    (* All fields except the rewrite body are the same in both cases. *)
    fun make body =
      {name = name, rewrite = body, proofs = [], code = [], source = source}
  in
    (case term of
      Const ("Markup.Rewrite.Transform", _) $ lhs $ rhs =>
        make (Transform (lhs, rhs))
    | Const ("Markup.Rewrite.Conditional", _) $ lhs $ rhs $ cond =>
        make (Conditional (lhs, rhs, cond))
    | _ => raise TERM ("optimization is not a rewrite", [term]))
  end
(* Rebuilds the term representation of the rewrite stored in a rewrite record:
   the matching Markup.Rewrite constant applied to the stored components.
   Raises TERM for any rewrite form other than Transform or Conditional. *)
fun to_term rewrite =
  let
    val transform_head =
      Const ("Markup.Rewrite.Transform", @{typ "IRExpr => IRExpr => IRExpr Rewrite"})
    val conditional_head =
      Const ("Markup.Rewrite.Conditional", @{typ "IRExpr => IRExpr ⇒ bool ⇒ IRExpr Rewrite"})
  in
    (case #rewrite rewrite of
      Transform (lhs, rhs) => transform_head $ lhs $ rhs
    | Conditional (lhs, rhs, cond) => conditional_head $ lhs $ rhs $ cond
    | _ => raise TERM ("rewrite cannot be translated yet", []))
  end
(* Left-hand side of a rewrite term, i.e. the first argument of
   Markup.Rewrite.Transform or Markup.Rewrite.Conditional.
   Raises TERM for any other term. *)
fun lhs (Const ("Markup.Rewrite.Transform", _) $ left $ _) = left
  | lhs (Const ("Markup.Rewrite.Conditional", _) $ left $ _ $ _) = left
  | lhs (other: term) = raise TERM ("optimization is not a rewrite", [other])
(* Right-hand side of a rewrite term, i.e. the second argument of
   Markup.Rewrite.Transform or Markup.Rewrite.Conditional.
   Raises TERM for any other term. *)
fun rhs (Const ("Markup.Rewrite.Transform", _) $ _ $ right) = right
  | rhs (Const ("Markup.Rewrite.Conditional", _) $ _ $ right $ _) = right
  | rhs (other: term) = raise TERM ("optimization is not a rewrite", [other])
(* Preservation obligation for a rewrite: System.preservation applied to the
   rewrite's term form, wrapped in Trueprop and checked in ctxt. *)
fun preservation_of ctxt rewrite =
  let
    val obligation = System.preservation $ to_term rewrite
  in
    Syntax.check_prop ctxt (@{const Trueprop} $ obligation)
  end
(* Termination obligation for a rewrite with respect to the given term trm:
   System.termination applied to the rewrite's term form and trm, wrapped in
   Trueprop and checked in ctxt. *)
fun termination_of' ctxt trm rewrite =
  let
    val obligation = System.termination $ to_term rewrite $ trm
  in
    Syntax.check_prop ctxt (@{const Trueprop} $ obligation)
  end
(* Intval obligation: System.intval applied to the given term, wrapped in
   Trueprop and checked in ctxt. *)
fun intval_of ctxt term =
  let
    val obligation = System.intval $ term
  in
    Syntax.check_prop ctxt (@{const Trueprop} $ obligation)
  end
(* Builds the meta-level equation
     val_<name> == (lhs' = rhs')
   where lhs' and rhs' are the two sides of the rewrite term after
   IntValMarkup.markup_expr, compared at type Value.
   name    - binding whose base name forms the free variable val_<name>
   rewrite - a rewrite term (it is taken apart with lhs and rhs) *)
fun value_def ctxt name rewrite =
  let
    val defined = Free ("val_" ^ Binding.name_of name, @{typ "bool"})
    (* Left side is marked up before the right side, as in the original. *)
    val left = IntValMarkup.markup_expr [] ctxt [lhs rewrite]
    val right = IntValMarkup.markup_expr [] ctxt [rhs rewrite]
    val body = Const ("HOL.eq", @{typ "Value ⇒ Value ⇒ bool"}) $ left $ right
  in
    Const ("Pure.eq", @{typ "bool ⇒ bool ⇒ prop"}) $ defined $ body
  end
(* Builds the proposition
     val_<name> = (c = c)
   where c is the fixed integer term (10::int) + (-(12::int)).
   Only name is used; the context and rewrite arguments are ignored. *)
fun val_def_const _ name _ =
  let
    val defined = Free ("val_" ^ Binding.name_of name, @{typ "bool"})
    val left = @{term "(10::int) + (-(12::int))"}
    val right = @{term "(10::int) + (-(12::int))"}
    val body = Const ("HOL.eq", @{typ "int ⇒ int ⇒ bool"}) $ left $ right
  in
    @{const Trueprop} $ (Const ("HOL.eq", @{typ "bool ⇒ bool ⇒ bool"}) $ defined $ body)
  end
(* Builds the proposition
     word_<name> = (lhs' = rhs')
   where lhs' and rhs' are the two sides of the rewrite term after
   WordMarkup.markup_expr, compared at a word type.
   name    - string appended to "word_" to form the free variable
   rewrite - a rewrite term (it is taken apart with lhs and rhs)

   Trueprop wraps the whole boolean equation. Previously it was applied to the
   inner word equation, which placed a prop where the outer HOL.eq (at type
   bool => bool => bool) expects a bool and so produced an ill-typed term.
   NOTE(review): the outer shape now mirrors val_def_const; confirm this is the
   intended form before putting this function to use. *)
fun word_def ctxt name rewrite =
  let
    val defined = Free ("word_" ^ name, @{typ "bool"})
    val left = WordMarkup.markup_expr [] ctxt [lhs rewrite]
    val right = WordMarkup.markup_expr [] ctxt [rhs rewrite]
    val body =
      Const ("HOL.eq", @{typ "('a::len) word ⇒ 'a word ⇒ bool"}) $ left $ right
  in
    @{const Trueprop} $ (Const ("HOL.eq", @{typ "bool ⇒ bool ⇒ bool"}) $ defined $ body)
  end
(* Termination obligation for a rewrite, using the trm field of the phase
   reported by RewritePhase.current for the theory underlying ctxt.
   Raises TERM when no phase is current. *)
fun termination_of ctxt rewrite =
  (case RewritePhase.current (Proof_Context.theory_of ctxt) of
    SOME phase => termination_of' ctxt (#trm phase) rewrite
  | NONE => raise TERM ("Optimization phase missing", []))
(* Proposition equating the free variable <name>_code, of type
   IRExpr => IRExpr option, with the function that returns None for every
   input. The context argument is ignored. *)
fun code_of _ rewrite =
  let
    val code_name = (Binding.name_of (#name rewrite)) ^ "_code"
    val code_var = Free (code_name, @{typ "IRExpr ⇒ IRExpr option"})
    val equals =
      Const ("HOL.eq", @{typ "(IRExpr ⇒ IRExpr option) ⇒ (IRExpr ⇒ IRExpr option) ⇒ bool"})
  in
    @{const Trueprop}
      $ (equals $ code_var $ @{term ‹λ (x::IRExpr) ⇒ (None::IRExpr option)›})
  end
(* Copy of a rewrite record with its proofs field replaced by proofs;
   all other fields are carried over unchanged. *)
fun add_proofs ({name, rewrite, code, source, ...}: rewrite) proofs =
  {name = name, rewrite = rewrite, proofs = proofs, code = code, source = source}
(* Entry point of the rewrite command.
   bind    - binding under which the rewrite is registered and its theorems noted
   options - option sources; the single-token options "intval" and "subgoals"
             are recognised
   opt     - source text of the rewrite, parsed with Syntax.parse_term
   ctxt    - context in which the command is issued

   Steps:
   1. Parse opt, mark it up with IRExprMarkup.markup_expr and build the
      rewrite record with of_term.
   2. With "subgoals", extend the context by the abbreviation from val_def_const.
   3. State the goals: the intval obligation (only with "intval"), the
      preservation obligation (with "intval" it is guarded by the intval
      obligation via Pure.imp) and the termination obligation.
   4. When the proof is finished, register the rewrite with its theorems via
      RewritePhase.register, define "<bind>_code" from code_of, and note the
      theorems under bind.
   5. Try the methods "unfold_optimization" and "unfold_size" on the initial
      state; a method that yields an error leaves the state as it was.

   Options are looked up among all given options, so "intval" and "subgoals"
   may be combined. Previously an option was only recognised when it was the
   sole option, and giving both silently disabled both. *)
fun rewrite_cmd ((bind, options), opt) ctxt =
  let
    (* An option is present if some option source is exactly one token with
       the given content. *)
    fun has_option name =
      List.exists (fn [token] => Token.content_of token = name | _ => false) options;
    val intval = has_option "intval";
    val subgoals = has_option "subgoals";

    val raw_term = Syntax.parse_term ctxt opt;
    val term = IRExprMarkup.markup_expr [] ctxt [raw_term];
    val rewrite = of_term bind term raw_term;

    val ctxt =
      if subgoals then
        Specification.abbreviation (Syntax.mode_default)
          NONE []
          (val_def_const ctxt (#name rewrite) raw_term)
          false ctxt
      else ctxt;

    (* The source is parsed again here, in the possibly extended context. *)
    val intval_preservation =
      if intval
      then intval_of ctxt (IntValMarkup.markup_expr [term] ctxt [(Syntax.parse_term ctxt opt)])
      else @{term True};
    val extra = if intval then [(intval_preservation, [intval_preservation])] else [];

    val preservation =
      let
        val goal = preservation_of ctxt rewrite
      in
        if intval then @{term "Pure.imp"} $ intval_preservation $ goal else goal
      end;
    val termination = termination_of ctxt rewrite;
    val code = code_of ctxt rewrite;

    fun register proofs = RewritePhase.register (bind, (add_proofs rewrite proofs));

    (* All goals are stated as one group, so hd thms is the list of their proofs. *)
    fun after_qed thms lthy =
      let
        val lthy' = Local_Theory.background_theory (register (hd thms)) lthy;
        val (_, lthy'') = Specification.definition
          NONE [] []
          ((Binding.suffix_name "_code" bind, []), code)
          lthy'
      in
        snd (Local_Theory.note ((bind, []), hd thms) lthy'')
      end;

    val target = Proof.theorem NONE after_qed
      [extra @ [(preservation, [preservation]), (termination, [termination])]] ctxt;

    (* Apply the named method to state and take the first outcome; if that
       outcome is an error, return fallback instead. *)
    fun try_method name state fallback =
      (case Seq.hd (Proof.apply
          ((Method.Source (Token.make_src (name, Position.start) [])),
            Position.no_range) state) of
        Seq.Result next => next
      | Seq.Error _ => fallback);

    val unfolded = try_method "unfold_optimization" target target;
  in
    (* The second method runs after deferring the first goal; on error the
       state from before the deferral is returned. *)
    try_method "unfold_size" (Proof.defer 1 unfolded) unfolded
  end
end