src/HOL/Tools/function_package/lexicographic_order.ML
author bulwahn
Tue Feb 13 10:09:21 2007 +0100 (2007-02-13)
changeset 22309 87ec1ca65312
parent 22258 0967b03844b5
child 22325 be61bd159a99
permissions -rw-r--r--
improved lexicographic order termination tactic
bulwahn@21131
     1
(*  Title:       HOL/Tools/function_package/lexicographic_order.ML
krauss@21201
     2
    ID:          $Id$
bulwahn@21131
     3
    Author:      Lukas Bulwahn, TU Muenchen
bulwahn@21131
     4
bulwahn@21131
     5
Method for termination proofs with lexicographic orderings.
bulwahn@21131
     6
*)
bulwahn@21131
     7
bulwahn@21131
     8
signature LEXICOGRAPHIC_ORDER =
sig
  (* proof method that searches for a lexicographic combination of
     measure functions and uses it to prove termination *)
  val lexicographic_order : Proof.context -> Method.method

  (* exported for use by size-change termination prototype.
     FIXME: provide a common interface later *)
  val mk_base_funs : typ -> term list

  (* registers the "lexicographic_order" proof method in a theory *)
  val setup: theory -> theory
end
bulwahn@21131
    18
bulwahn@21131
    19
structure LexicographicOrder : LEXICOGRAPHIC_ORDER =
bulwahn@21131
    20
struct
bulwahn@21131
    21
bulwahn@21131
    22
(* Theory dependencies: the name of the "measures" constant and the
   lemmas about it that are used to instantiate the termination relation
   and to discharge the resulting goals. *)
val measures = "List.measures"
val wf_measures = thm "wf_measures"
val measures_less = thm "measures_less"
val measures_lesseq = thm "measures_lesseq"
krauss@21237
    27
                      
bulwahn@21131
    28
(* Remove the element at position n (0-based) from a list.  A position
   past the end leaves the list unchanged; a non-positive position drops
   the head. *)
fun del_index (_, []) = []
  | del_index (pos, y :: rest) =
      if pos <= 0 then rest
      else y :: del_index (pos - 1, rest)
bulwahn@21131
    31
bulwahn@21131
    32
(* Transpose a list of equally long rows into a list of columns.
   The empty list of rows transposes to the empty list; without the
   first clause, transpose [] would recurse forever building [] :: ... .
   Rows of unequal length raise Empty (from hd/tl). *)
fun transpose [] = []
  | transpose ([] :: _) = []
  | transpose xss = map hd xss :: transpose (map tl xss)
bulwahn@21131
    34
bulwahn@21131
    35
(* Combine f1 : A => B and f2 : C => B into sum_case f1 f2 : A + C => B.
   Raises TERM if either argument is not a function or the range types
   differ. *)
fun mk_sum_case (f1, f2) =
    case (fastype_of f1, fastype_of f2) of
      (Type ("fun", [dom1, ran1]), Type ("fun", [dom2, ran2])) =>
        if ran1 <> ran2 then
          raise TERM ("mk_sum_case: range type mismatch", [f1, f2])
        else
          let
            val sumT = Type ("+", [dom1, dom2])
            val caseT = (dom1 --> ran1) --> (dom2 --> ran2) --> sumT --> ran1
          in
            Const ("Datatype.sum.sum_case", caseT) $ f1 $ f2
          end
    | _ => raise TERM ("mk_sum_case", [f1, f2])
krauss@21237
    42
                 
bulwahn@21131
    43
(* Strip the "wf" constant: wf r yields r; anything else raises TERM. *)
fun dest_wf t =
    case t of
      Const ("Wellfounded_Recursion.wf", _) $ r => r
    | _ => raise TERM ("dest_wf", [t])
krauss@21237
    45
                      
bulwahn@21131
    46
(* Result of comparing one measure on one call subgoal:
   Less    - the strict decrease was proved,
   LessEq  - only the non-strict decrease was proved,
   False   - auto reduced the non-strict goal to False,
   None    - neither goal was proved.
   Each constructor carries the corresponding (possibly unfinished) theorem. *)
datatype cell = Less of thm | LessEq of thm | None of thm | False of thm;

fun is_Less (Less _) = true
  | is_Less _ = false

fun is_LessEq (LessEq _) = true
  | is_LessEq _ = false

(* The theorem stored in a cell, whatever its kind. *)
fun thm_of_cell (Less th) = th
  | thm_of_cell (LessEq th) = th
  | thm_of_cell (False th) = th
  | thm_of_cell (None th) = th
krauss@21237
    58
                  
bulwahn@21131
    59
(* Decompose a (nested) pair type tt: for every non-pair component,
   return the term selecting it from t via fst/snd, together with the
   component's type. *)
fun mk_base_fun_bodys (t : term) (tt : typ) =
    case tt of
      Type ("*", [ft, st]) =>
        let
          val fst_parts = mk_base_fun_bodys (Const ("fst", tt --> ft) $ t) ft
          val snd_parts = mk_base_fun_bodys (Const ("snd", tt --> st) $ t) st
        in
          fst_parts @ snd_parts
        end
    | _ => [(t, tt)]
krauss@21237
    63
           
bulwahn@21131
    64
(* Build the measure %x::fulltyp. size t, where t is a selector term
   (from mk_base_fun_bodys) of type typ. *)
fun mk_base_fun_header fulltyp (t, typ) =
    let
      val size_const = Const ("Nat.size", typ --> HOLogic.natT)
    in
      Abs ("x", fulltyp, size_const $ t)
    end
krauss@21237
    66
         
bulwahn@21131
    67
(* One size measure per non-pair component of tt; the selectors are built
   on Bound 0, which mk_base_fun_header binds with its abstraction. *)
fun mk_base_funs (tt: typ) =
    map (mk_base_fun_header tt) (mk_base_fun_bodys (Bound 0) tt)
bulwahn@22258
    70
bulwahn@22309
    71
(* Indicator measures for a (nested) sum type tt.
   With one = true, the result has one function per leaf summand; it maps
   that summand to Suc 0 and every other summand to 0.
   With one = false, the result is the single function that is 0
   everywhere.  The summand functions are glued together with sum_case. *)
fun mk_funorder_funs (tt : typ) (one : bool) : Term.term list =
    case tt of
      Type ("+", [ft, st]) =>
        let
          val left = mk_funorder_funs ft
          val right = mk_funorder_funs st
          val pairs =
              if one
              then product (left true) (right false) @ product (left false) (right true)
              else product (left false) (right false)
        in
          map mk_sum_case pairs
        end
    | _ => [Abs ("x", tt, if one then HOLogic.Suc_zero else HOLogic.zero)]
bulwahn@22258
    84
bulwahn@21131
    85
(* Size measures extended to sum types.  For A + B, every measure for A
   is paired with every measure for B (recursively) and combined with
   sum_case; for any other type these are just the base size measures. *)
fun mk_ext_base_funs (tt : typ) =
    case tt of
      Type ("+", [ft, st]) =>
        map mk_sum_case (product (mk_ext_base_funs ft) (mk_ext_base_funs st))
    | _ => mk_base_funs tt
bulwahn@22258
    91
bulwahn@22258
    92
(* All candidate measures for domain type tt.  For sum types these are the
   combined size measures followed by the summand indicator functions;
   otherwise they are only the base size measures. *)
fun mk_all_measure_funs (tt as Type ("+", _)) =
      mk_ext_base_funs tt @ mk_funorder_funs tt true
  | mk_all_measure_funs tt = mk_base_funs tt
krauss@21237
    96
           
bulwahn@21131
    97
(* Split a call subgoal of the shape  !!vars. prems ==> (lhs, rhs) : rel
   into (vars, prems, lhs, rhs, rel). *)
fun dest_term (t : term) =
    let
      val (vars, prop) = FundefLib.dest_all_all t
      val prems = Logic.strip_imp_prems prop
      val concl = Logic.strip_imp_concl prop
      val (tuple, rel) = HOLogic.dest_mem (HOLogic.dest_Trueprop concl)
      val (lhs, rhs) = HOLogic.dest_prod tuple
    in
      (vars, prems, lhs, rhs, rel)
    end
krauss@21237
   108
    
bulwahn@21131
   109
(* Build the goal  !!vars. prems ==> rel lhs rhs, where rel is the name
   of a binary relation constant. *)
fun mk_goal (vars, prems, lhs, rhs) rel =
    let
      val rel_prop = HOLogic.mk_Trueprop (HOLogic.mk_binrel rel (lhs, rhs))
      val body = Logic.list_implies (prems, rel_prop)
    in
      fold_rev FundefLib.mk_forall vars body
    end
krauss@21237
   116
    
bulwahn@21131
   117
(* Initialise a goal state for t and run auto on it.
   The result is the goal state after auto.  It has no premises if auto
   solved the goal; otherwise the remaining subgoals are its premises.
   If auto produces no result at all, the untouched initial goal state is
   returned.  The caller then records an unsolved cell instead of dying
   on an uncaught Option exception. *)
fun prove (thy: theory) (t: term) =
    let
      val init_st = Goal.init (cterm_of thy t)
    in
      case SINGLE (CLASIMPSET auto_tac) init_st of
        SOME st => st
      | NONE => init_st
    end
krauss@21237
   120
    
bulwahn@21131
   121
(* Classify one (measure lhs, measure rhs) comparison under the call's
   context (vars, prems).
   - Try "<" first; if auto closes it, the result is Less.
   - Otherwise try "<="; if that closes, the result is LessEq.
   - If auto reduced "<=" to exactly False, the result is False.
   - Otherwise the result is None. *)
fun mk_cell (thy : theory) (vars, prems) (lhs, rhs) =
    let
      fun attempt rel = prove thy (mk_goal (vars, prems, lhs, rhs) rel)
      val less_thm = attempt "Orderings.less"
    in
      if Thm.no_prems less_thm then Less (Goal.finish less_thm)
      else
        let
          val lesseq_thm = attempt "Orderings.less_eq"
          val only_false = [HOLogic.Trueprop $ HOLogic.false_const]
        in
          if Thm.no_prems lesseq_thm then LessEq (Goal.finish lesseq_thm)
          else if prems_of lesseq_thm = only_false then False lesseq_thm
          else None lesseq_thm
        end
    end
krauss@21237
   139
    
bulwahn@22258
   140
(* One table row for the call subgoal t: for each measure function f,
   the cell comparing f lhs with f rhs under the subgoal's context. *)
fun mk_row (thy: theory) measure_funs (t : term) =
    let
      val (vars, prems, lhs, rhs, _) = dest_term t
      fun cell_for f = mk_cell thy (vars, prems) (f $ lhs, f $ rhs)
    in
      map cell_for measure_funs
    end
bulwahn@22309
   148
    
bulwahn@22309
   149
(* Four-character rendering of a cell for the table printout. *)
fun pr_cell (Less _) = " <  "
  | pr_cell (LessEq _) = " <= "
  | pr_cell (None _) = " N  "
  | pr_cell (False _) = " F  "
bulwahn@22309
   153
bulwahn@22309
   154
(* Print the whole table, one line per row (debugging aid). *)
fun pr_table table =
    let
      fun pr_row cells = concat (map pr_cell cells)
    in
      writeln (cat_lines (map pr_row table))
    end
bulwahn@22309
   155
bulwahn@22309
   156
(* A column is usable if every cell is Less or LessEq and at least one is
   Less.  Given the first condition, "not all LessEq" is the same as
   "some Less". *)
fun check_col ls =
    forall (fn c => is_Less c orelse is_LessEq c) ls
    andalso List.exists is_Less ls
bulwahn@22309
   157
bulwahn@22309
   158
(* After choosing column col: drop the rows it strictly decreases, then
   delete column col from the remaining rows. *)
fun transform_table table col =
    let
      val remaining = filter_out (fn row => is_Less (nth row col)) table
    in
      map (fn row => del_index (col, row)) remaining
    end
bulwahn@22309
   159
bulwahn@22309
   160
(* Re-map column indices of a reduced table (column col removed) back to
   indices of the table before the removal. *)
fun transform_order col order =
    map (fn idx => if idx < col then idx else idx + 1) order
krauss@21237
   161
      
bulwahn@21131
   162
(* Greedy search for a lexicographic order.
   Pick the first usable column (see check_col), then remove it together
   with the rows it strictly decreases, and recurse on the rest; there is
   no backtracking.  The result is the list of chosen column indices,
   relative to the original table.  It is NONE if at some stage no usable
   column remains. *)
fun search_table [] = SOME []
  | search_table table =
    let
      val col = find_index check_col (transpose table)
    in
      if col = ~1 then NONE
      else
        case search_table (transform_table table col) of
          NONE => NONE
        | SOME rest_order => SOME (col :: transform_order col rest_order)
    end
bulwahn@22258
   179
bulwahn@22258
   180
(* All (0-based) positions of list elements satisfying pred, in
   increasing order. *)
fun find_index_list pred xs =
    let
      fun collect (_, [], acc) = rev acc
        | collect (i, y :: ys, acc) =
            collect (i + 1, ys, if pred y then i :: acc else acc)
    in
      collect (0, xs, [])
    end
bulwahn@22258
   185
bulwahn@22258
   186
(* Breadth-first variant of search_table.
   Queue nodes are pairs (order, table):
   - order is the list of chosen columns, most recent first, each index
     relative to the table it was chosen in;
   - table is what remains after those choices.
   Nodes are expanded level by level, so the first node with an empty
   table gives an order with the fewest columns.  That order is rebuilt
   into indices of the original table by folding transform_order from the
   most recent choice back to the first; this uses the curried fold, like
   the rest of this file.
   If the node being expanded has no usable column, the search returns
   NONE at once instead of trying the other queued nodes.  Given that,
   every expanded node either returns or enqueues at least one child, so
   the queue never becomes empty.
   NOTE(review): lexicographic_order_tac calls search_table, not this
   function; appending children to the queue with @ is quadratic. *)
fun bfs_search_table nodes =
    case nodes of
      [] => sys_error "INTERNAL ERROR IN lexicographic order termination tactic - fun search_table (breadth search finished)"
    | (order, table) :: rest =>
      (case table of
         [] => SOME (fold (fn c => fn acc => c :: transform_order c acc) order [])
       | _ =>
         (case find_index_list check_col (transpose table) of
            [] => NONE
          | cols =>
            let
              val children = map (fn c => (c :: order, transform_table table c)) cols
            in
              bfs_search_table (rest @ children)
            end))

(* Run the breadth-first search starting from the full table. *)
fun nsearch_table table = bfs_search_table [([], table)]
bulwahn@22258
   211
bulwahn@21131
   212
(* Discharge the goal for one call, given its cells in the chosen order.
   Each leading LessEq cell is consumed by resolving with measures_lesseq;
   the first Less cell closes the goal via measures_less.  An empty row, or
   a None/False cell reached before any Less, is an internal error. *)
fun prove_row [] (_ : thm) =
      sys_error "INTERNAL ERROR IN lexicographic order termination tactic - fun prove_row (row is empty)"
  | prove_row (Less less_thm :: _) st =
      implies_elim (the (SINGLE (rtac measures_less 1) st)) less_thm
  | prove_row (LessEq lesseq_thm :: tail) st =
      prove_row tail (implies_elim (the (SINGLE (rtac measures_lesseq 1) st)) lesseq_thm)
  | prove_row _ _ =
      sys_error "INTERNAL ERROR IN lexicographic order termination tactic - fun prove_row (Only expecting Less or LessEq)"
krauss@21237
   231
             
bulwahn@21131
   232
(* Pretty-printed goal states of all cells that are neither Less nor
   LessEq, i.e. the comparisons auto could not establish. *)
fun pr_unprovable_subgoals table =
    let
      fun unfinished c = not (is_Less c orelse is_LessEq c)
      fun pr_thm th =
          Pretty.string_of (Pretty.chunks (Display.pretty_goals (Thm.nprems_of th) th))
    in
      map (pr_thm o thm_of_cell) (filter unfinished (flat table))
    end
krauss@21237
   235
    
bulwahn@21131
   236
(* One line describing call subgoal t with label i: "i) lhs '<' rhs".
   The premises of the subgoal are not shown. *)
fun pr_goal thy t i =
    let
      val (_, _, lhs, rhs, _) = dest_term t
      fun show tm = string_of_cterm (cterm_of thy tm)
    in
      i ^ ") " ^ show lhs ^ " '<' " ^ show rhs
    end
krauss@21237
   244
    
bulwahn@21131
   245
(* One line describing measure function t with number i: "i) t". *)
fun pr_fun thy t i =
    concat [string_of_int i, ") ", string_of_cterm (cterm_of thy t)]
krauss@21237
   247
                                             
bulwahn@21131
   248
(* Lines of the report used when no order is found:
   - the cell table, with rows labelled a, b, ... by goal and columns
     numbered by measure;
   - the goals;
   - the measure functions;
   - the goal states that auto left unfinished. *)
fun pr_err table thy tl measure_funs =
    let
      val goal_labels = map (fn i => chr (i + 96)) (1 upto length table)
      val measure_nums = 1 upto length measure_funs
      val header = "   " ^ concat (map (fn i => " " ^ string_of_int i ^ "  ") measure_nums)
      val rows = map2 (fn r => fn lbl => lbl ^ ": " ^ concat (map pr_cell r)) table goal_labels
      val goal_lines = "Goals:" :: map2 (pr_goal thy) tl goal_labels
      val measure_lines = "Measures:" :: map2 (pr_fun thy) measure_funs measure_nums
      val unfinished_lines = "Unfinished subgoals:" :: pr_unprovable_subgoals table
    in
      header :: rows @ goal_lines @ measure_lines @ unfinished_lines
    end
krauss@21237
   261
      
bulwahn@21131
   262
(* Main tactic.
   1. Apply the termination rule of the context.
   2. Build the measure/call table for the resulting call subgoals.
   3. Search the table for a lexicographic order.
   4. Instantiate the relation with "measures" over the chosen functions.
   5. Discharge wf with wf_measures and each call goal with prove_row.
   FIXME proper goal addressing -- do not hardwire 1 *)
fun lexicographic_order_tac ctxt (st: thm) =
    let
      val thy = theory_of_thm st
      val termination_thm = the (FundefCommon.get_termination_rule ctxt)
      val next_st = the (SINGLE (rtac termination_thm 1) st)
    in
      case prems_of next_st of
        [] => error "invalid number of subgoals for this tactic - expecting at least 1 subgoal"
      | wf :: call_goals =>
        let
          val (var, prop) = FundefLib.dest_all wf
          val rel = head_of (dest_wf (HOLogic.dest_Trueprop prop))
          val crel = cterm_of thy rel
          val domT = fastype_of var
          val measure_funs = mk_all_measure_funs domT
          val _ = writeln "Creating table"
          val table = map (mk_row thy measure_funs) call_goals
          val _ = writeln "Searching for lexicographic order"
        in
          case search_table table of
            NONE =>
              error (cat_lines ("Could not find lexicographic termination order:"
                                :: pr_err table thy call_goals measure_funs))
          | SOME order =>
            let
              val clean_table = map (fn row => map (nth row) order) table
              val funs = map (nth measure_funs) order
              val list = HOLogic.mk_list (domT --> HOLogic.natT) funs
              val measuresT = fastype_of list --> range_type (fastype_of rel)
              val relterm = Abs ("x", domT, Const (measures, measuresT) $ list)
              val crelterm = cterm_of thy relterm
              val _ = writeln ("Instantiating R with " ^ string_of_cterm crelterm)
              val _ = writeln "Proving subgoals"
              val wf_done = the (SINGLE (rtac wf_measures 1) (cterm_instantiate [(crel, crelterm)] next_st))
            in
              Seq.single (fold prove_row clean_table wf_done)
            end
        end
    end
bulwahn@21131
   302
krauss@21319
   303
(* Proof method wrapping lexicographic_order_tac. *)
fun lexicographic_order ctxt = Method.SIMPLE_METHOD (lexicographic_order_tac ctxt)
krauss@21201
   304
krauss@21319
   305
(* Register the "lexicographic_order" proof method. *)
val setup =
    Method.add_methods
      [("lexicographic_order",
        Method.ctxt_args lexicographic_order,
        "termination prover for lexicographic orderings")]
bulwahn@21131
   306
wenzelm@21590
   307
end