(* File ‹markup.ML› *)

(*  Title:      Optimizations/DSL/markup.ML
    Author:     Brae Webb

Lift shorthand expressions to the full type.
*)

(* Abstract tokens for the operators and constants that the shorthand notation
   can denote. A DSL_TRANSLATION implementation maps each token to a term. *)
structure DSL_Tokens =
struct

(* One constructor per operator/constant; the constructors carry no arguments. *)
datatype Tokens =
    Add
  | Sub
  | Mul
  | Div
  | Rem
  | And
  | Or
  | Xor
  | ShortCircuitOr
  | Less
  | Abs
  | Equals
  | Not
  | Negate
  | LogicNegate
  | LeftShift
  | RightShift
  | UnsignedRightShift
  | Conditional
  | Constant
  | TrueConstant
  | FalseConstant
end

(* Interface a target language supplies to DSL_Markup: it decides which term
   stands for each abstract token. *)
signature DSL_TRANSLATION =
sig
(* Term that represents the given token in the target expression type. *)
val markup: DSL_Tokens.Tokens -> term
end

(* Functions exported by the DSL_Markup functor. Each takes its term arguments
   as a list and raises TERM when the list does not have the expected shape. *)
signature DSL_MARKUP =
sig
  (* Wrap a single term as ConstantExpr (IntVal t). *)
  val const_expr: Proof.context -> term list -> term
  (* Apply the translator's Equals term to exactly two operands. *)
  val equals_expr: Proof.context -> term list -> term

  (* Lift shorthand notation inside one term; the first argument is the list
     of terms handed on to fresh-name selection for free variables. *)
  val markup_expr: term list -> Proof.context -> term list -> term

  (* Build a Transform term from [pre, post], lifting both sides. *)
  val rewrite: term list -> Proof.context -> term list -> term
  (* Build a Conditional term from [pre, post, cond]; cond is left untouched. *)
  val conditional_rewrite: term list -> Proof.context -> term list -> term
end;

functor DSL_Markup(Translator: DSL_TRANSLATION): DSL_MARKUP =
struct

(* Wrap a single constant or application as ConstantExpr (IntVal t).
   The context argument is ignored. Any other argument shape (wrong list
   length, or a term that is neither a Const nor an application) raises TERM. *)
fun const_expr _ ts =
  let
    (* Shared wrapper for both accepted term shapes. *)
    fun lift t = @{const ConstantExpr} $ (@{const IntVal} $ t)
  in
    case ts of
      [c as Const _] => lift c
    | [app as _ $ _] => lift app
    | _ => raise TERM ("const_expr", ts)
  end

(* Apply the translator's term for Equals to exactly two operands.
   The context argument is ignored; any other list length raises TERM. *)
fun equals_expr _ ts =
  case ts of
    [lhs, rhs] => Translator.markup DSL_Tokens.Equals $ lhs $ rhs
  | _ => raise TERM ("equals_expr", ts)


(* Shorthand expressions *)
(* Translate a constant into its DSL term.

   The constant's name `str` is looked up in a fixed table of notation and
   operator names. When it is listed, the result is the translator's term for
   the corresponding token and `typ` is discarded. Otherwise the constant is
   returned unchanged as Const (str, typ).

   Every name appears twice, once plain and once with a "const" prefix.
   NOTE(review): the "const" prefix looks like Isabelle's syntax marker for
   constants with its control symbol lost -- confirm against the original
   file. The strings are reproduced exactly as found. *)
fun markup_constant (str, typ) =
  let
    val token =
      case str of
        "constMarkup.ExtraNotation.ConditionalNotation" => SOME DSL_Tokens.Conditional
      | "Markup.ExtraNotation.ConditionalNotation" => SOME DSL_Tokens.Conditional
      | "constMarkup.ExtraNotation.EqualsNotation" => SOME DSL_Tokens.Equals
      | "Markup.ExtraNotation.EqualsNotation" => SOME DSL_Tokens.Equals
      | "constMarkup.ExtraNotation.ConstantNotation" => SOME DSL_Tokens.Constant
      | "Markup.ExtraNotation.ConstantNotation" => SOME DSL_Tokens.Constant
      | "constMarkup.ExtraNotation.TrueNotation" => SOME DSL_Tokens.TrueConstant
      | "Markup.ExtraNotation.TrueNotation" => SOME DSL_Tokens.TrueConstant
      | "constMarkup.ExtraNotation.FalseNotation" => SOME DSL_Tokens.FalseConstant
      | "Markup.ExtraNotation.FalseNotation" => SOME DSL_Tokens.FalseConstant
      | "constMarkup.ExtraNotation.LogicNegationNotation" => SOME DSL_Tokens.LogicNegate
      | "Markup.ExtraNotation.LogicNegationNotation" => SOME DSL_Tokens.LogicNegate
      | "constGroups.plus_class.plus" => SOME DSL_Tokens.Add
      | "Groups.plus_class.plus" => SOME DSL_Tokens.Add
      | "constGroups.minus_class.minus" => SOME DSL_Tokens.Sub
      | "Groups.minus_class.minus" => SOME DSL_Tokens.Sub
      | "constGroups.times_class.times" => SOME DSL_Tokens.Mul
      | "Groups.times_class.times" => SOME DSL_Tokens.Mul
      | "constRings.divide_class.divide" => SOME DSL_Tokens.Div
      | "Rings.divide_class.divide" => SOME DSL_Tokens.Div
      | "constFields.inverse_class.inverse_divide" => SOME DSL_Tokens.Div
      | "Fields.inverse_class.inverse_divide" => SOME DSL_Tokens.Div
        (* Remainder notation must yield Rem; it was previously mapped to Div. *)
      | "constMarkup.ExtraNotation.Remainder" => SOME DSL_Tokens.Rem
      | "Markup.ExtraNotation.Remainder" => SOME DSL_Tokens.Rem
      | "constHOL.conj" => SOME DSL_Tokens.And
      | "HOL.conj" => SOME DSL_Tokens.And
      | "constHOL.disj" => SOME DSL_Tokens.Or
      | "HOL.disj" => SOME DSL_Tokens.Or
      | "constMarkup.ExtraNotation.ShortCircuitOr" => SOME DSL_Tokens.ShortCircuitOr
      | "Markup.ExtraNotation.ShortCircuitOr" => SOME DSL_Tokens.ShortCircuitOr
      | "constMarkup.ExtraNotation.ExclusiveOr" => SOME DSL_Tokens.Xor
      | "Markup.ExtraNotation.ExclusiveOr" => SOME DSL_Tokens.Xor
      | "constGroups.uminus_class.uminus" => SOME DSL_Tokens.Negate
      | "Groups.uminus_class.uminus" => SOME DSL_Tokens.Negate
      | "constHOL.Not" => SOME DSL_Tokens.Not
      | "HOL.Not" => SOME DSL_Tokens.Not
      | "constOrderings.ord_class.less" => SOME DSL_Tokens.Less
      | "Orderings.ord_class.less" => SOME DSL_Tokens.Less
      | "constGroups.abs_class.abs" => SOME DSL_Tokens.Abs
      | "Groups.abs_class.abs" => SOME DSL_Tokens.Abs
      | "constJavaWords.shiftl" => SOME DSL_Tokens.LeftShift
      | "JavaWords.shiftl" => SOME DSL_Tokens.LeftShift
      | "constJavaWords.sshiftr" => SOME DSL_Tokens.RightShift
      | "JavaWords.sshiftr" => SOME DSL_Tokens.RightShift
      | "constJavaWords.shiftr" => SOME DSL_Tokens.UnsignedRightShift
      | "JavaWords.shiftr" => SOME DSL_Tokens.UnsignedRightShift
      | _ => NONE
  in
    case token of
      SOME tok => Translator.markup tok
    | NONE => Const (str, typ)  (* unknown constants pass through untouched *)
  end

(* Pick a variant of the name `goal` that does not clash with the names
   gathered from the terms in `taken`.

   NOTE(review): the names gathered are those of type variables
   (Term.add_tfree_names), not of term-level free variables -- confirm
   whether Term.add_free_names was intended. *)
fun find_free (taken, goal) =
  let
    val used = fold Term.add_tfree_names taken []
  in
    singleton (Name.variant_list used) goal
  end

(* Translate a free variable.
   A free variable literally named "abs" becomes the translator's Abs term.
   Any other name is kept as a Free of the same type, with its name passed
   through find_free against `taken`. *)
fun markup_free taken (str, typ) =
  if str = "abs"
  then Translator.markup DSL_Tokens.Abs
  else Free (find_free (taken, str), typ)

(* Lift shorthand notation inside a single term.

   taken - terms passed on to markup_free for fresh-name selection
   ctxt  - proof context; only threaded through the recursion
   [trm] - the term to translate; any other list length raises TERM

   The term is traversed structurally:
   - a bare constant goes through markup_constant;
   - an application headed by ConstantExpr is returned untouched;
   - a Conditional rewrite has its two sides translated, its condition kept;
   - other applications and abstractions are translated recursively;
   - free variables go through markup_free;
   - anything else (bound variables, schematic variables) is returned as is. *)
fun markup_expr taken ctxt [trm] =
      let
        fun recurse t = markup_expr taken ctxt [t]
      in
        case trm of
          Const (str, typ) => markup_constant (str, typ)
          (* Ignore everything within a constant. Both spellings of the name
             are matched, mirroring the Conditional cases below.
             NOTE(review): the unprefixed name is assumed to be what
             @{const ConstantExpr} resolves to -- confirm. *)
        | Const ("constIRTreeEval.IRExpr.ConstantExpr", _) $ _ => trm
        | Const ("IRTreeEval.IRExpr.ConstantExpr", _) $ _ => trm
        | (e as Const ("constMarkup.Rewrite.Conditional", _)) $ lhs $ rhs $ cond =>
            e $ recurse lhs $ recurse rhs $ cond
        | (e as Const ("Markup.Rewrite.Conditional", _)) $ lhs $ rhs $ cond =>
            e $ recurse lhs $ recurse rhs $ cond
        | x $ y => recurse x $ recurse y
        | Free (str, typ) => markup_free taken (str, typ)
        | Abs (str, typ, body) => Abs (str, typ, recurse body)
        | _ => trm
      end
  | markup_expr _ _ ts = raise TERM ("markup_expr", ts)


(* Rewrite syntax *)
(* Build an unconditional rewrite term from [pre, post].
   Both sides are lifted with markup_expr and passed to the constant
   "Transform". Any other list length raises TERM.

   The constant takes two arguments, so its type needs two function arrows;
   the second arrow was missing from the type string. *)
fun rewrite taken ctxt [pre, post] =
  Const ("Transform", @{typ "'a => 'a => 'a Rewrite"})
    $ markup_expr taken ctxt [pre]
    $ markup_expr taken ctxt [post]

  | rewrite _ _ ts = raise TERM ("rewrite", ts)

(* Build a conditional rewrite term from [pre, post, cond].
   The two sides are lifted with markup_expr; the condition is passed through
   unchanged. Any other list length raises TERM.

   The constant "Conditional" takes three arguments, so its type needs three
   function arrows; all of them were missing from the type string. *)
fun conditional_rewrite taken ctxt [pre, post, cond] =
  Const ("Conditional", @{typ "'a => 'a => bool => 'a Rewrite"})
    $ markup_expr taken ctxt [pre]
    $ markup_expr taken ctxt [post]
    $ cond

  | conditional_rewrite _ _ ts = raise TERM ("conditional_rewrite", ts)

end