TFL/tfl.sml
author wenzelm
Fri Mar 07 15:30:23 1997 +0100 (1997-03-07)
changeset 2768 bc6d915b8019
parent 2112 3902e9af752f
child 3191 14bd6e5985f1
permissions -rw-r--r--
renamed SYSTEM to RAW_ML_SYSTEM;
     1 functor TFL(structure Rules : Rules_sig
     2             structure Thry  : Thry_sig
     3             structure Thms  : Thms_sig
     4             sharing type Rules.binding = Thry.binding = 
     5                          Thry.USyntax.binding = Mask.binding
     6             sharing type Rules.Type = Thry.Type = Thry.USyntax.Type
     7             sharing type Rules.Preterm = Thry.Preterm = Thry.USyntax.Preterm
     8             sharing type Rules.Term = Thry.Term = Thry.USyntax.Term
     9             sharing type Thms.Thm = Rules.Thm = Thry.Thm) : TFL_sig  =
    10 struct
    11 
    12 (* Declarations: rebind the functor arguments inside this structure *)
    13 structure Thms = Thms;
    14 structure Rules = Rules;
    15 structure Thry = Thry;
    16 structure USyntax = Thry.USyntax;
    17 (* Type abbreviations taken from the argument structures *)
    18 type Preterm = Thry.USyntax.Preterm;
    19 type Term = Thry.USyntax.Term;
    20 type Thm = Thms.Thm;
    21 type Thry = Thry.Thry;
    22 type Tactic = Rules.Tactic;
    23    
    24 
    25 (* Abbreviations *)
    26 structure R = Rules;
    27 structure S = USyntax;
    28 structure U = S.Utils;
    29 
    30 (* Declares 'a binding datatype *)
    31 open Mask;
    32 (* Remove any infix status so the names can be rebound as plain identifiers *)
    33 nonfix mem --> |-> ##;
    34 val --> = S.-->;
    35 val ##    = U.##;
    36 (* Give the operators used in the rest of the file their infix status *)
    37 infixr 3 -->;
    38 infixr 7 |->;
    39 infix  4 ##; 
    40 (* concl / hyp: second and first component of the result of R.dest_thm *)
    41 val concl = #2 o R.dest_thm;
    42 val hyp = #1 o R.dest_thm;
    43 (* list_mk_type: combine a list of types with --> by means of U.end_itlist *)
    44 val list_mk_type = U.end_itlist (U.curry(op -->));
    45 (* flatten: append the members of a list of lists, left to right. *)
    46 fun flatten lists =
    47   case lists of [] => [] | (front::others) => front @ flatten others;
    48 
    49 (* gtake f (n,l): apply f to the first n elements of l; return them with the untouched rest. *)
    50 fun gtake f =
    51   let fun grab (count, items) =
    52         case (count, items)
    53           of (0, rest) => ([], rest)
    54            | (k, item::rest) =>
    55                let val (mapped, remainder) = grab (k-1, rest) in (f item::mapped, remainder) end
    56   in grab end;
    57 (* enumerate: tag each element of L with a running integer index that starts at 0. *)
    58 fun enumerate L = 
    59  let fun tag x (acc,k) = ((x,k)::acc, k+1) in rev(#1(U.rev_itlist tag L ([],0))) end;
    60 (* stringize: render a list of integers as text, separating the entries with ", ". *)
    61 fun stringize nums =
    62   case nums of [] => "" | [n] => U.int_to_string n
    63      | n::more => (U.int_to_string n^", "^stringize more);
    64 
    65 (* TFL_ERR: apply U.ERR to the given func and mesg with the module field fixed to "Tfl". *)
    66 fun TFL_ERR{func = f, mesg = m} = U.ERR{func = f, mesg = m, module = "Tfl"};
    67 
    68 
    69 (*---------------------------------------------------------------------------
    70  * The next function is common to pattern-match translation and 
    71  * proof of completeness of cases for the induction theorem.
    72  *
    73  * "gvvariant" makes variables that are guaranteed not to be in vlist and
    74  * furthermore, are guaranteed not to be equal to each other. The names of
    75  * the variables will start with "v" and end in a number.
    76  *---------------------------------------------------------------------------*)
    77 local val counter = ref 0   (* numeric suffix; one cell shared by all calls *)
    78 in
    79 fun gvvariant vlist =
    80   let val slist = ref (map (#Name o S.dest_var) vlist)   (* names already in use *)
    81       val mem = U.mem (U.curry (op=))
    82       val _ = counter := 0   (* numbering restarts every time gvvariant is applied *)
    83       fun pass str = 
    84          if (mem str (!slist))   (* name taken: bump the counter and try "v<counter>" *)
    85          then ( counter := !counter + 1;
    86                 pass (U.concat"v" (U.int_to_string(!counter))))
    87          else (slist := str :: !slist; str)   (* record the name so it is never handed out again *)
    88   in 
    89   fn ty => S.mk_var{Name=pass "v",  Ty=ty}   (* the returned generator: a fresh variable of type ty *)
    90   end
    91 end;
    92 
    93 
    94 (*---------------------------------------------------------------------------
    95  * Used in induction theorem production. This is the simple case of
    96  * partitioning up pattern rows by the leading constructor.
    97  *---------------------------------------------------------------------------*)
    98 fun ipartition gv (constructors,rows) =
    99   let fun pfail s = raise TFL_ERR{func = "partition.part", mesg = s}
   100       fun part {constrs = [],   rows = [],   A} = rev A   (* A is accumulated in reverse *)
   101         | part {constrs = [],   rows = _::_, A} = pfail"extra cases in defn"
   102         | part {constrs = _::_, rows = [],   A} = pfail"cases missing in defn"
   103         | part {constrs = c::crst, rows,     A} =
   104           let val {Name,Ty} = S.dest_const c
   105               val (L,_) = S.strip_type Ty   (* only length L is used below *)
   106               val (in_group, not_in_group) =
   107                U.itlist (fn (row as (p::rst, rhs)) =>   (* partial pattern: a row without patterns is not matched *)
   108                          fn (in_group,not_in_group) =>
   109                   let val (pc,args) = S.strip_comb p
   110                   in if (#Name(S.dest_const pc) = Name)
   111                      then ((args@rst, rhs)::in_group, not_in_group)   (* constructor arguments become leading columns *)
   112                      else (in_group, row::not_in_group)
   113                   end)      rows ([],[])
   114               val col_types = U.take S.type_of (length L, #1(hd in_group))   (* NOTE: hd fails if no row starts with c *)
   115           in 
   116           part{constrs = crst, rows = not_in_group, 
   117                A = {constructor = c, 
   118                     new_formals = map gv col_types,   (* one variable from gv per column type *)
   119                     group = in_group}::A}
   120           end
   121   in part{constrs = constructors, rows = rows, A = []}
   122   end;
   123 
   124 
   125 
   126 (*---------------------------------------------------------------------------
   127  * This datatype carries some information about the origin of a
   128  * clause in a function definition.
   129  *---------------------------------------------------------------------------*)
   130 datatype pattern = GIVEN   of S.Preterm * int   (* built in mk_functional from the enumerated right-hand sides *)
   131                  | OMITTED of S.Preterm * int   (* built in partition with index ~1 when a constructor has no rows *)
   132 (* psubst: apply S.subst theta to the term inside a pattern; constructor and index are kept. *)
   133 fun psubst theta pat =
   134   case pat of GIVEN (tm,i) => GIVEN(S.subst theta tm, i) | OMITTED (tm,i) => OMITTED(S.subst theta tm, i);
   135 (* dest_pattern: split a pattern into ((its constructor, its index), its term). *)
   136 fun dest_pattern pat =
   137   case pat of GIVEN (tm,i) => ((GIVEN,i),tm) | OMITTED (tm,i) => ((OMITTED,i),tm);
   138 (* Projections from dest_pattern: the term of a pattern, and its integer index. *)
   139 fun pat_of pat = #2 (dest_pattern pat);
   140 fun row_of_pat pat = #2 (#1 (dest_pattern pat));
   141 
   142 (*---------------------------------------------------------------------------
   143  * Produce an instance of a constructor, plus genvars for its arguments.
   144  *---------------------------------------------------------------------------*)
   145 fun fresh_constr ty_match colty gv c =
   146   let val (arg_tys, result_ty) = S.strip_type (#Ty (S.dest_const c))
   147       (* theta: what ty_match yields for the stripped result type against colty *)
   148       val theta = ty_match result_ty colty
   149       val inst_c = S.inst theta c
   150       fun fresh_arg arg_ty = S.inst theta (gv arg_ty)   (* gv first, then instantiate *)
   151   in (inst_c, map fresh_arg arg_tys)
   152   end;
   153 
   154 
   155 (*---------------------------------------------------------------------------
   156  * Goes through a list of rows and picks out the ones beginning with a
   157  * pattern with constructor = Name.
   158  *---------------------------------------------------------------------------*)
   159 fun mk_group Name rows =
   160   U.itlist (fn (row as ((prefix, p::rst), rhs)) =>   (* partial pattern: rows with no pattern left are not matched *)
   161             fn (in_group,not_in_group) =>
   162                let val (pc,args) = S.strip_comb p
   163                in if ((#Name(S.dest_const pc) = Name) handle _ => false)   (* any exception in the test counts as "no match" *)
   164                   then (((prefix,args@rst), rhs)::in_group, not_in_group)   (* arguments of the head become leading columns *)
   165                   else (in_group, row::not_in_group) end)
   166       rows ([],[]);
   167 
   168 (*---------------------------------------------------------------------------
   169  * Partition the rows. Not efficient: we should use hashing.
   170  *---------------------------------------------------------------------------*)
   171 fun partition _ _ (_,_,_,[]) = raise TFL_ERR{func="partition", mesg="no rows"}
   172   | partition gv ty_match
   173               (constructors, colty, res_ty, rows as (((prefix,_),_)::_)) =   (* prefix is taken from the first row *)
   174 let val fresh = fresh_constr ty_match colty gv
   175      fun part {constrs = [],      rows, A} = rev A   (* rows left over at this point are dropped *)
   176        | part {constrs = c::crst, rows, A} =
   177          let val (c',gvars) = fresh c
   178              val {Name,Ty} = S.dest_const c'
   179              val (in_group, not_in_group) = mk_group Name rows
   180              val in_group' =
   181                  if (null in_group)  (* Constructor not given *)
   182                  then [((prefix, #2(fresh c)), OMITTED (S.ARB res_ty, ~1))]   (* second call to fresh: variables other than gvars *)
   183                  else in_group
   184          in 
   185          part{constrs = crst, 
   186               rows = not_in_group, 
   187               A = {constructor = c', 
   188                    new_formals = gvars,
   189                    group = in_group'}::A}
   190          end
   191 in part{constrs=constructors, rows=rows, A=[]}
   192 end;
   193 
   194 (*---------------------------------------------------------------------------
   195  * Misc. routines used in mk_case
   196  *---------------------------------------------------------------------------*)
   197 (* mk_pat c: in every (prefix,tag,pats) row, fold the leading arguments of c back into one application of c. *)
   198 fun mk_pat c =
   199   let val (arg_tys,_) = S.strip_type(S.type_of c)
   200       val arity = length arg_tys
   201       fun rebuild (prefix,tag,plist) =
   202           let val (args,rest) = gtake U.I (arity, plist) in (prefix,tag,S.list_mk_comb(c,args)::rest) end
   203   in map rebuild
   204   end;
   205 (* v_to_prefix: move the leading pattern of a row onto the front of its prefix. *)
   206 fun v_to_prefix row = case row of (prefix, v::pats) => (v::prefix,pats)
   207      | _ => raise TFL_ERR{func="mk_case", mesg="v_to_prefix"};
   208 (* v_to_pats: move the head of the prefix back onto the front of the patterns. *)
   209 fun v_to_pats row = case row of (v::prefix,tag, pats) => (prefix, tag, v::pats)
   210      | _ => raise TFL_ERR{func="mk_case", mesg="v_to_pats"};
   211  
   212 
   213 (*----------------------------------------------------------------------------
   214  * Translation of pattern terms into nested case expressions.
   215  *
   216  * This performs the translation and also builds the full set of patterns. 
   217  * Thus it supports the construction of induction theorems even when an 
   218  * incomplete set of patterns is given.
   219  *---------------------------------------------------------------------------*)
   220 
   221 fun mk_case ty_info ty_match FV range_ty =
   222  let 
   223  fun mk_case_fail s = raise TFL_ERR{func = "mk_case", mesg = s}
   224  val fresh_var = gvvariant FV   (* variable generator whose names avoid those in FV *)
   225  val divide = partition fresh_var ty_match
   226  fun expand constructors ty ((_,[]), _) = mk_case_fail"expand_var_row"
   227    | expand constructors ty (row as ((prefix, p::rst), rhs)) = 
   228        if (S.is_var p)   (* a leading variable is replaced by one row per constructor *)
   229        then let val fresh = fresh_constr ty_match ty fresh_var
   230                 fun expnd (c,gvs) = 
   231                   let val capp = S.list_mk_comb(c,gvs)
   232                   in ((prefix, capp::rst), psubst[p |-> capp] rhs)
   233                   end
   234             in map expnd (map fresh constructors)  end
   235        else [row]
   236  fun mk{rows=[],...} = mk_case_fail"no rows"
   237    | mk{path=[], rows = ((prefix, []), rhs)::_} =  (* Done *)
   238         let val (tag,tm) = dest_pattern rhs   (* only the first row is used; later rows are ignored *)
   239         in ([(prefix,tag,[])], tm)
   240         end
   241    | mk{path=[], rows = _::_} = mk_case_fail"blunder"
   242    | mk{path as u::rstp, rows as ((prefix, []), rhs)::rst} =   (* first row ran out of patterns: pad with a fresh variable *)
   243         mk{path = path, 
   244            rows = ((prefix, [fresh_var(S.type_of u)]), rhs)::rst}
   245    | mk{path = u::rstp, rows as ((_, p::_), _)::_} =
   246      let val (pat_rectangle,rights) = U.unzip rows
   247          val col0 = map(hd o #2) pat_rectangle   (* leading pattern of every row *)
   248      in 
   249      if (U.all S.is_var col0)   (* all variables: substitute the path variable u and move on *)
   250      then let val rights' = map(fn(v,e) => psubst[v|->u] e) (U.zip col0 rights)
   251               val pat_rectangle' = map v_to_prefix pat_rectangle
   252               val (pref_patl,tm) = mk{path = rstp,
   253                                       rows = U.zip pat_rectangle' rights'}
   254           in (map v_to_pats pref_patl, tm)
   255           end
   256      else
   257      let val pty = S.type_of p
   258          val ty_name = (#Tyop o S.dest_type) pty
   259      in
   260      case (ty_info ty_name)
   261      of U.NONE => mk_case_fail("Not a known datatype: "^ty_name)
   262       | U.SOME{case_const,constructors} =>
   263         let val case_const_name = #Name(S.dest_const case_const)
   264             val nrows = flatten (map (expand constructors pty) rows)
   265             val subproblems = divide(constructors, pty, range_ty, nrows)   (* one subproblem per constructor *)
   266             val groups      = map #group subproblems
   267             and new_formals = map #new_formals subproblems
   268             and constructors' = map #constructor subproblems
   269             val news = map (fn (nf,rows) => {path = nf@rstp, rows=rows})
   270                            (U.zip new_formals groups)
   271             val rec_calls = map mk news
   272             val (pat_rect,dtrees) = U.unzip rec_calls
   273             val case_functions = map S.list_mk_abs(U.zip new_formals dtrees)
   274             val types = map S.type_of (case_functions@[u]) @ [range_ty]
   275             val case_const' = S.mk_const{Name = case_const_name,   (* case constant rebuilt at the type assembled above *)
   276                                          Ty   = list_mk_type types}
   277             val tree = S.list_mk_comb(case_const', case_functions@[u])
   278             val pat_rect1 = flatten(U.map2 mk_pat constructors' pat_rect)
   279         in (pat_rect1,tree)
   280         end 
   281      end end
   282  in mk
   283  end;
   284 
   285 (* FV_multiset: every variable occurrence of a term, left to right, duplicates kept. *)
   286 (* Repeated variable occurrences in a pattern are not allowed. *)
   287 fun FV_multiset tm = 
   288    let fun collect (S.VAR v) = [S.mk_var v]
   289          | collect (S.CONST _) = []
   290          | collect (S.COMB{Rator, Rand}) = FV_multiset Rator @ FV_multiset Rand
   291          | collect (S.LAMB _) = raise TFL_ERR{func = "FV_multiset", mesg = "lambda"}
   292    in collect (S.dest_term tm) end;
   293 (* no_repeat_vars: raise TFL_ERR if some variable occurs twice in pat; otherwise return true. *)
   294 fun no_repeat_vars thy pat =
   295  let fun complain v =
   296          raise TFL_ERR{func = "no_repeat_vars",
   297              mesg = U.concat(U.quote(#Name(S.dest_var v)))
   298                      (U.concat" occurs repeatedly in the pattern "
   299                          (U.quote(S.Term_to_string (Thry.typecheck thy pat))))}
   300      fun check vars =
   301        case vars of [] => true
   302           | v::rst => if (U.mem S.aconv v rst) then complain v else check rst
   303  in check (FV_multiset pat)
   304  end;
   305 (* mk_functional thy eqs: from a conjunction of equations build {functional, pats}; raises TFL_ERR on malformed input. *)
   306 local fun paired1{lhs,rhs} = (lhs,rhs) 
   307       and paired2{Rator,Rand} = (Rator,Rand)
   308       fun mk_functional_err s = raise TFL_ERR{func = "mk_functional", mesg=s}
   309 in
   310 fun mk_functional thy eqs =
   311  let val clauses = S.strip_conj eqs
   312      val (L,R) = U.unzip (map (paired1 o S.dest_eq o U.snd o S.strip_forall)
   313                               clauses)
   314      val (funcs,pats) = U.unzip(map (paired2 o S.dest_comb) L)
   315      val f = (case U.mk_set (S.aconv) funcs of [f] => f | _ => raise Match)
   316                handle _ => mk_functional_err "function name not unique"   (* a val pattern [f] would raise Bind outside this handler *)
   317      val _ = map (no_repeat_vars thy) pats
   318      val rows = U.zip (map (fn x => ([],[x])) pats) (map GIVEN (enumerate R))   (* each rhs tagged GIVEN with its index *)
   319      val fvs = S.free_varsl R
   320      val a = S.variant fvs (S.mk_var{Name="a", Ty = S.type_of(hd pats)})   (* argument variable of the functional *)
   321      val FV = a::fvs
   322      val ty_info = Thry.match_info thy
   323      val ty_match = Thry.match_type thy
   324      val range_ty = S.type_of (hd R)
   325      val (patts, case_tm) = mk_case ty_info ty_match FV range_ty 
   326                                     {path=[a], rows=rows}
   327      val patts1 = map (fn (_,(tag,i),[pat]) => tag (pat,i)) patts handle _ 
   328                   => mk_functional_err "error in pattern-match translation"
   329      val patts2 = U.sort(fn p1=>fn p2=> row_of_pat p1 < row_of_pat p2) patts1
   330      val finals = map row_of_pat patts2
   331      val originals = map (row_of_pat o #2) rows
   332      fun int_eq i1 (i2:int) =  (i1=i2)
   333      val _ = case (U.set_diff int_eq originals finals)   (* input rows whose index is absent from the translated patterns *)
   334              of [] => ()
   335           | L => mk_functional_err("The following rows (counting from zero)\
   336                                    \ are inaccessible: "^stringize L)
   337  in {functional = S.list_mk_abs ([f,a], case_tm),
   338      pats = patts2}
   339 end end;
   340 
   341 
   342 (*----------------------------------------------------------------------------
   343  *
   344  *                    PRINCIPLES OF DEFINITION
   345  *
   346  *---------------------------------------------------------------------------*)
   347 
   348 
   349 (*----------------------------------------------------------------------------
   350  * This basic principle of definition takes a functional M and a relation R
   351  * and specializes the following theorem
   352  *
   353  *    |- !M R f. (f = WFREC R M) ==> WF R ==> !x. f x = M (f%R,x) x
   354  *
   355  * to them (getting "th1", say). Then we make the definition "f = WFREC R M" 
   356  * and instantiate "th1" to the constant "f" (getting th2). Then we use the
   357  * definition to delete the first antecedent to th2. Hence the result in
   358  * the "corollary" field is 
   359  *
   360  *    |-  WF R ==> !x. f x = M (f%R,x) x
   361  *
   362  *---------------------------------------------------------------------------*)
   363 
   364 fun prim_wfrec_definition thy {R, functional} =
   365  let val tych = Thry.typecheck thy
   366      val {Bvar,...} = S.dest_abs functional
   367      val {Name,...} = S.dest_var Bvar  (* Intended name of definition *)
   368      val cor1 = R.ISPEC (tych functional) Thms.WFREC_COROLLARY   (* here R.ISPEC is the Rules structure; the value R is the relation *)
   369      val cor2 = R.ISPEC (tych R) cor1
   370      val f_eq_WFREC_R_M = (#ant o S.dest_imp o #Body 
   371                            o S.dest_forall o concl) cor2   (* the antecedent "f = WFREC R M" of cor2 *)
   372      val {lhs,rhs} = S.dest_eq f_eq_WFREC_R_M
   373      val {Ty, ...} = S.dest_var lhs
   374      val def_term = S.mk_eq{lhs = S.mk_var{Name=Name,Ty=Ty}, rhs=rhs}   (* same equation with the intended name on the left *)
   375      val (def_thm,thy1) = Thry.make_definition thy 
   376                                   (U.concat Name "_def") def_term
   377      val (_,[f,_]) = (S.strip_comb o concl) def_thm   (* f: first argument of the head of the definition's conclusion *)
   378      val cor3 = R.ISPEC (Thry.typecheck thy1 f) cor2
   379  in 
   380  {theory = thy1, def=def_thm, corollary=R.MP cor3 def_thm}   (* the definition removes the first antecedent *)
   381  end;
   382 
   383 
   384 (*---------------------------------------------------------------------------
   385  * This structure keeps track of congruence rules that aren't derived
   386  * from a datatype definition.
   387  *---------------------------------------------------------------------------*)
   388 structure Context =
   389 struct
   390   val non_datatype_context : Rules.Thm list ref = ref []   (* current list of extra rules *)
   391   fun read () = ! non_datatype_context                     (* current contents *)
   392   fun write thms = non_datatype_context := thms            (* replace the contents *)
   393 end;
   394 (* extraction_thms: the case rewrites of thy, and its case congruences followed by Context.read(). *)
   395 fun extraction_thms thy = 
   396  let val info = Thry.extract_info thy
   397  in (#case_rewrites info, #case_congs info @ Context.read())
   398  end;
   399 
   400 
   401 (*---------------------------------------------------------------------------
   402  * Pair patterns with termination conditions. The full list of patterns for
   403  * a definition is merged with the TCs arising from the user-given clauses.
   404  * There can be fewer clauses than the full list, if the user omitted some 
   405  * cases. This routine is used to prepare input for mk_induction.
   406  *---------------------------------------------------------------------------*)
   407 fun merge full_pats TCs =
   408 let fun insert (p,TCs) =
   409       let fun place entries =
   410             case entries   (* fill the first entry that is still empty and whose pattern is p *)
   411               of (entry as (h,[]))::rest => if (S.aconv p h) then (p,TCs)::rest else entry::place rest
   412                | entry::rest => entry::place rest
   413                | [] => raise TFL_ERR{func="merge.insert",mesg="pat not found"}
   414       in place end
   415     fun pass (pending, slots) = case pending of [] => slots
   416           | ptcs::tcl => pass(tcl, insert ptcs slots)
   417 in pass (TCs, map (fn p => (p,[])) full_pats)
   418 end;
   419 (* not_omitted: the term of a GIVEN pattern; raises TFL_ERR on an OMITTED one. *)
   420 fun not_omitted pat = case pat of GIVEN(tm,_) => tm
   421      | OMITTED _ => raise TFL_ERR{func="not_omitted",mesg=""}
   422 val givens = U.mapfilter not_omitted;
   423 
   424 
   425 (*--------------------------------------------------------------------------
   426  * This is a wrapper for "prim_wfrec_definition": it builds a functional,
   427  * calls "prim_wfrec_definition", then specializes the result. This gives a
   428  * list of rewrite rules where the right hand sides are quite ugly, so we 
   429  * simplify to get rid of the case statements. In essence, this function
   430  * performs pre- and post-processing for patterns. As well, after
   431  * simplification, termination conditions are extracted.
   432  *-------------------------------------------------------------------------*)
   433 
   434 fun gen_wfrec_definition thy {R, eqs} =
   435  let val {functional,pats}  = mk_functional thy eqs
   436      val given_pats = givens pats
   437      val {def,corollary,theory} = prim_wfrec_definition thy
   438                                         {R=R, functional=functional}
   439      val tych = Thry.typecheck theory   (* typecheck against the theory returned by the definition *)
   440      val {lhs=f,...} = S.dest_eq(concl def)
   441      val WFR = #ant(S.dest_imp(concl corollary))
   442      val corollary' = R.UNDISCH corollary  (* put WF R on assums *)
   443      val corollaries = map (U.C R.SPEC corollary' o tych) given_pats   (* one instance per given pattern *)
   444      val (case_rewrites,context_congs) = extraction_thms thy   (* NOTE(review): uses thy, not theory - confirm this is intended *)
   445      val corollaries' = map(R.simplify case_rewrites) corollaries
   446      fun xtract th = R.CONTEXT_REWRITE_RULE(f,R)
   447                          {thms = [(R.ISPECL o map tych)[f,R] Thms.CUT_LEMMA],
   448                          congs = context_congs,
   449                             th = th}
   450      val (rules, TCs) = U.unzip (map xtract corollaries')   (* rewritten rule and termination conditions per pattern *)
   451      val rules0 = map (R.simplify [Thms.CUT_DEF]) rules
   452      val mk_cond_rule = R.FILTER_DISCH_ALL(not o S.aconv WFR)   (* the filter rejects exactly the WFR term *)
   453      val rules1 = R.LIST_CONJ(map mk_cond_rule rules0)
   454  in
   455  {theory = theory,   (* holds def, if it's needed *)
   456   rules = rules1,
   457   full_pats_TCs = merge (map pat_of pats) (U.zip given_pats TCs), 
   458   TCs = TCs, 
   459   patterns = pats}
   460  end;
   461 
   462 
   463 (*---------------------------------------------------------------------------
   464  * Perform the extraction without making the definition. Definition and
   465  * extraction commute for the non-nested case. For hol90 users, this 
   466  * function can be invoked without being in draft mode.
   467  *---------------------------------------------------------------------------*)
   468 fun wfrec_eqns thy eqns =
   469  let val {functional,pats} = mk_functional thy eqns
   470      val given_pats = givens pats
   471      val {Bvar = f, Body} = S.dest_abs functional
   472      val {Bvar = x, ...} = S.dest_abs Body
   473      val {Name,Ty = fty} = S.dest_var f
   474      val {Tyop="fun", Args = [f_dty, f_rty]} = S.dest_type fty   (* binding fails unless Tyop is "fun" with two arguments; f_dty, f_rty are not used below *)
   475      val (case_rewrites,context_congs) = extraction_thms thy
   476      val tych = Thry.typecheck thy
   477      val WFREC_THM0 = R.ISPEC (tych functional) Thms.WFREC_COROLLARY
   478      val R = S.variant(S.free_vars eqns)   (* a variant of the bound variable that avoids the free variables of eqns *)
   479                       (#Bvar(S.dest_forall(concl WFREC_THM0)))
   480      val WFREC_THM = R.ISPECL [tych R, tych f] WFREC_THM0   (* R.ISPECL is the Rules structure; the value R is the variable above *)
   481      val ([proto_def, WFR],_) = S.strip_imp(concl WFREC_THM)
   482      val R1 = S.rand WFR
   483      val corollary' = R.UNDISCH(R.UNDISCH WFREC_THM)   (* both antecedents moved to the assumptions *)
   484      val corollaries = map (U.C R.SPEC corollary' o tych) given_pats
   485      val corollaries' = map (R.simplify case_rewrites) corollaries
   486      fun extract th = R.CONTEXT_REWRITE_RULE(f,R1)
   487                         {thms = [(R.ISPECL o map tych)[f,R1] Thms.CUT_LEMMA], 
   488                         congs = context_congs,
   489                           th  = th}
   490  in {proto_def=proto_def, 
   491      WFR=WFR, 
   492      pats=pats,
   493      extracta = map extract corollaries'}   (* one extraction result per given pattern *)
   494  end;
   495 
   496 
   497 (*---------------------------------------------------------------------------
   498  * Define the constant after extracting the termination conditions. The 
   499  * wellfounded relation used in the definition is computed by using the
   500  * choice operator on the extracted conditions (plus the condition that
   501  * such a relation must be wellfounded).
   502  *---------------------------------------------------------------------------*)
   503 fun lazyR_def thy eqns =
   504  let val {proto_def,WFR,pats,extracta} = wfrec_eqns thy eqns
   505      val R1 = S.rand WFR
   506      val f = S.lhs proto_def
   507      val {Name,...} = S.dest_var f
   508      val (extractants,TCl) = U.unzip extracta
   509      val TCs = U.Union S.aconv TCl   (* termination conditions of all clauses, merged up to aconv *)
   510      val full_rqt = WFR::TCs   (* everything required of the relation *)
   511      val R' = S.mk_select{Bvar=R1, Body=S.list_mk_conj full_rqt}   (* choice term over R1 for the conjoined requirements *)
   512      val R'abs = S.rand R'
   513      val (def,theory) = Thry.make_definition thy (U.concat Name "_def") 
   514                                                  (S.subst[R1 |-> R'] proto_def)
   515      val fconst = #lhs(S.dest_eq(concl def)) 
   516      val tych = Thry.typecheck theory
   517      val baz = R.DISCH (tych proto_def)
   518                  (U.itlist (R.DISCH o tych) full_rqt (R.LIST_CONJ extractants))
   519      val def' = R.MP (R.SPEC (tych fconst) 
   520                              (R.SPEC (tych R') (R.GENL[tych R1, tych f] baz)))   (* generalize over R1 and f, then specialize to R' and fconst *)
   521                      def
   522      val body_th = R.LIST_CONJ (map (R.ASSUME o tych) full_rqt)
   523      val bar = R.MP (R.BETA_RULE(R.ISPECL[tych R'abs, tych R1] Thms.SELECT_AX))
   524                      body_th
   525  in {theory = theory, R=R1,
   526      rules = U.rev_itlist (U.C R.MP) (R.CONJUNCTS bar) def',   (* apply R.MP to def' once per conjunct of bar *)
   527      full_pats_TCs = merge (map pat_of pats) (U.zip (givens pats) TCl),
   528      patterns = pats}
   529  end;
   530 
   531 
   532 
   533 (*----------------------------------------------------------------------------
   534  *
   535  *                           INDUCTION THEOREM
   536  *
   537  *---------------------------------------------------------------------------*)
   538 
   539 
   540 (*------------------------  Miscellaneous function  --------------------------
   541  *
   542  *           [x_1,...,x_n]     ?v_1...v_n. M[v_1,...,v_n]
   543  *     -----------------------------------------------------------
   544  *     ( M[x_1,...,x_n], [(x_i,?v_1...v_n. M[v_1,...,v_n]),
   545  *                        ... 
   546  *                        (x_j,?v_n. M[x_1,...,x_(n-1),v_n])] )
   547  *
   548  * This function is totally ad hoc. Used in the production of the induction 
   549  * theorem. The nchotomy theorem can have clauses that look like
   550  *
   551  *     ?v1..vn. z = C vn..v1
   552  *
   553  * in which the order of quantification is not the order of occurrence of the
   554  * quantified variables as arguments to C. Since we have no control over this
   555  * aspect of the nchotomy theorem, we make the correspondence explicit by
   556  * pairing the incoming new variable with the term it gets beta-reduced into.
   557  *---------------------------------------------------------------------------*)
   558 
   559 fun alpha_ex_unroll xlist tm =
   560   let val (qvars,body) = S.strip_exists tm
   561       val vlist = #2(S.strip_comb (S.rhs body))   (* arguments on the right-hand side of the body *)
   562       val plist = U.zip vlist xlist   (* pair each argument with the incoming variable at the same position *)
   563       val args = map (U.C (U.assoc1 (U.uncurry S.aconv)) plist) qvars   (* look up every quantified variable in plist *)
   564       val args' = map (fn U.SOME(_,v) => v 
   565                         | U.NONE => raise TFL_ERR{func = "alpha_ex_unroll",
   566                                        mesg = "no correspondence"}) args
   567       fun build ex [] = []
   568         | build ex (v::rst) =
   569            let val ex1 = S.beta_conv(S.mk_comb{Rator=S.rand ex, Rand=v})   (* apply the operand of ex to v and beta-reduce *)
   570            in ex1::build ex1 rst
   571            end
   572      val (nex::exl) = rev (tm::build tm args')   (* nex: the last term built *)
   573   in 
   574   (nex, U.zip args' (rev exl))
   575   end;
   576 
   577 
   578 
   579 (*----------------------------------------------------------------------------
   580  *
   581  *             PROVING COMPLETENESS OF PATTERNS
   582  *
   583  *---------------------------------------------------------------------------*)
   584 (* mk_case (completeness version): shadows the earlier mk_case; each row carries a (theorem, bindings) pair. *)
   585 fun mk_case ty_info FV thy =
   586  let 
   587  val divide = ipartition (gvvariant FV)
   588  val tych = Thry.typecheck thy
   589  fun tych_binding(x|->y) = (tych x |-> tych y)
   590  fun fail s = raise TFL_ERR{func = "mk_case", mesg = s}
   591  fun mk{rows=[],...} = fail"no rows"
   592    | mk{path=[], rows = [([], (thm, bindings))]} =   (* exactly one row left and no patterns: finish *)
   593                          R.IT_EXISTS (map tych_binding bindings) thm
   594    | mk{path = u::rstp, rows as (p::_, _)::_} =   (* other shapes of path/rows are not matched *)
   595      let val (pat_rectangle,rights) = U.unzip rows
   596          val col0 = map hd pat_rectangle
   597          val pat_rectangle' = map tl pat_rectangle
   598      in 
   599      if (U.all S.is_var col0) (* column 0 is all variables *)
   600      then let val rights' = map (fn ((thm,theta),v) => (thm,theta@[u|->v]))   (* record the binding u |-> v *)
   601                                 (U.zip rights col0)
   602           in mk{path = rstp, rows = U.zip pat_rectangle' rights'}
   603           end
   604      else                     (* column 0 is all constructors *)
   605      let val ty_name = (#Tyop o S.dest_type o S.type_of) p
   606      in
   607      case (ty_info ty_name)
   608      of U.NONE => fail("Not a known datatype: "^ty_name)
   609       | U.SOME{constructors,nchotomy} =>
   610         let val thm' = R.ISPEC (tych u) nchotomy   (* nchotomy theorem specialised to the path variable u *)
   611             val disjuncts = S.strip_disj (concl thm')
   612             val subproblems = divide(constructors, rows)
   613             val groups      = map #group subproblems
   614             and new_formals = map #new_formals subproblems
   615             val existentials = U.map2 alpha_ex_unroll new_formals disjuncts
   616             val constraints = map #1 existentials
   617             val vexl = map #2 existentials
   618             fun expnd tm (pats,(th,b)) = (pats,(R.SUBS[R.ASSUME(tych tm)]th,b))   (* substitute in th using the assumed constraint tm *)
   619             val news = map (fn (nf,rows,c) => {path = nf@rstp, 
   620                                                rows = map (expnd c) rows})
   621                            (U.zip3 new_formals groups constraints)
   622             val recursive_thms = map mk news
   623             val build_exists = U.itlist(R.CHOOSE o (tych##(R.ASSUME o tych)))
   624             val thms' = U.map2 build_exists vexl recursive_thms
   625             val same_concls = R.EVEN_ORS thms'
   626         in R.DISJ_CASESL thm' same_concls
   627         end 
   628      end end
   629  in mk
   630  end;
   631 
   632 (* complete_cases thy pats: the mk_case result for the rows of pats, generalized over a fresh variable a. *)
   633 fun complete_cases thy =
   634  let val tych = Thry.typecheck thy
   635      fun pmk_var n ty = S.mk_var{Name = n,Ty = ty}
   636      val ty_info = Thry.induct_info thy
   637  in fn pats =>
   638  let val FV0 = S.free_varsl pats
   639      val a = S.variant FV0 (pmk_var "a" (S.type_of(hd pats)))   (* hd: pats must be non-empty *)
   640      val v = S.variant (a::FV0) (pmk_var "v" (S.type_of a))
   641      val FV = a::v::FV0
   642      val a_eq_v = S.mk_eq{lhs = a, rhs = v}
   643      val ex_th0 = R.EXISTS ((tych##tych) (S.mk_exists{Bvar=v,Body=a_eq_v},a))   (* "?v. a = v", built from the reflexivity of a *)
   644                            (R.REFL (tych a))
   645      val th0 = R.ASSUME (tych a_eq_v)
   646      val rows = map (fn x => ([x], (th0,[]))) pats   (* one row per pattern, all starting from th0 with no bindings *)
   647  in
   648  R.GEN (tych a) 
   649        (R.RIGHT_ASSOC
   650           (R.CHOOSE(tych v, ex_th0)
   651                 (mk_case ty_info FV thy {path=[v], rows=rows})))
   652  end end;
   653 
   654 
   655 (*---------------------------------------------------------------------------
   656  * Constructing induction hypotheses: one for each recursive call.
   657  *
   658  * Note. R will never occur as a variable in the ind_clause, because
   659  * to do so, it would have to come from a nested definition, and we don't
   660  * allow nested definitions to have R as a variable.
   661  *
   662  * Note. When the context is empty, there can be no local variables.
   663  *---------------------------------------------------------------------------*)
   664 
   local
     (* Term builders used below: application "tm1 tm2" and implication. *)
     fun apply_to (tm1, tm2) = S.mk_comb {Rator = tm1, Rand = tm2}
     fun implies (tm1, tm2) = S.mk_imp {ant = tm1, conseq = tm2}
   in
   (* build_ih f P (pat, TCs) returns a pair (clause, TCs_locals).
    *
    *   clause     : "P pat", preceded by the conjunction of one induction
    *                hypothesis per termination condition when TCs is
    *                non-empty, and universally quantified over the free
    *                variables of pat.
    *   TCs_locals : each termination condition paired with the local
    *                variables that were quantified in its hypothesis
    *                (empty when the condition has no context).
    *)
   fun build_ih f P (pat, TCs) =
     let
       val globals = S.free_vars_lr pat
       (* Does the function constant f occur in tm?  Never raises. *)
       fun mentions_f tm = U.can (S.find_term (S.aconv f)) tm handle _ => false
       fun dest_TC tc =
         let
           val (context, R_y_pat) = S.strip_imp (#2 (S.strip_forall tc))
           val (_, y, _) = S.dest_relation R_y_pat
           val P_y = apply_to (P, y)
           (* Nested conditions keep the relation as an explicit antecedent. *)
           val hypothesis = if mentions_f tc then implies (R_y_pat, P_y) else P_y
         in
           case context of
             [] => (hypothesis, (tc, []))
           | _ =>
               let
                 val guarded = implies (S.list_mk_conj context, hypothesis)
                 val non_globals =
                       U.set_diff S.aconv (S.free_vars_lr guarded) globals
                 (* P itself is never quantified; if it is absent, keep all. *)
                 val locals =
                       #2 (U.pluck (S.aconv P) non_globals) handle _ => non_globals
               in
                 (S.list_mk_forall (locals, guarded), (tc, locals))
               end
         end
     in
       case TCs of
         [] => (S.list_mk_forall (globals, apply_to (P, pat)), [])
       | _ =>
           let
             val (ihs, TCs_locals) = U.unzip (map dest_TC TCs)
             val ind_clause = implies (S.list_mk_conj ihs, apply_to (P, pat))
           in
             (S.list_mk_forall (globals, ind_clause), TCs_locals)
           end
     end
   end;
   692 
   693 
   694 
   695 (*---------------------------------------------------------------------------
   696  * This function makes good on the promise made in "build_ih": we prove
   697  * the induction case "tm" from "thm", assuming the termination conditions.
   698  *
   699  * Input  is tm = "(!y. R y pat ==> P y) ==> P pat",  
   700  *           TCs = TC_1[pat] ... TC_n[pat]
   701  *           thm = ih1 /\ ... /\ ih_n |- ih[pat]
   702  *---------------------------------------------------------------------------*)
   (* prove_case f thy (tm, TCs_locals, thm)
    *
    * Discharges the antecedent of the implication tm over a theorem derived
    * from thm.  When the specialised thm is itself an implication (the clause
    * contains recursive calls), one hypothesis is built per termination
    * condition and the conjunction of them is used to detach its antecedent.
    *)
   fun prove_case f thy (tm, TCs_locals, thm) =
     let
       val tych = Thry.typecheck thy
       val antc = tych (#ant (S.dest_imp tm))
       val thm' = R.SPEC_ALL thm
       (* Does the function constant f occur in t?  Never raises. *)
       fun mentions_f t = U.can (S.find_term (S.aconv f)) t handle _ => false
       (* Antecedent of a termination-condition theorem, below its quantifiers. *)
       fun context_of tc_thm =
             tych (#ant (S.dest_imp (#2 (S.strip_forall (concl tc_thm)))))
       (* The "y" of the relation "R y pat" at the core of a condition. *)
       fun relation_arg tc =
             #2 (S.dest_relation (#2 (S.strip_imp (#2 (S.strip_forall tc)))))
       fun mk_ih ((tc_thm, locals), th2, is_nested) =
         let
           val generalise = R.GENL (map tych locals)
           val hypothesis =
                 if is_nested
                 then (R.DISCH (context_of tc_thm) th2 handle _ => th2)
                 else if S.is_imp (concl tc_thm)
                 then R.IMP_TRANS tc_thm th2
                 else R.MP th2 tc_thm
         in
           generalise hypothesis
         end
       (* Only evaluated when the clause contains recursive calls. *)
       fun with_recursive_calls () =
         let
           val th1 = R.ASSUME antc
           val TCs = map (fn (tc, _) => tc) TCs_locals
           val ylist = map relation_arg TCs
           val TClist =
                 map (fn (tc, lvs) => (R.SPEC_ALL (R.ASSUME (tych tc)), lvs))
                     TCs_locals
           val th2list = map (fn y => R.SPEC (tych y) th1) ylist
           val nlist = map mentions_f TCs
           val Pylist = map mk_ih (U.zip3 TClist th2list nlist)
         in
           R.MP thm' (R.LIST_CONJ Pylist)
         end
     in
       R.DISCH antc
         (if S.is_imp (concl thm') then with_recursive_calls () else thm')
     end;
   732 
   733 
   (*---------------------------------------------------------------------------
    * Existentially abstract the variables of a pattern on the left of the
    * turnstile:
    *
    *         x = (v1,...,vn)  |- M[x]
    *    ---------------------------------------------
    *      ?v1 ... vn. x = (v1,...,vn) |- M[x]
    *
    * The theorem must have exactly one hypothesis that is an equation;
    * otherwise the pattern binding below fails.
    *---------------------------------------------------------------------------*)
   fun LEFT_ABS_VSTRUCT tych thm =
     let
       val [veq] = U.filter (U.can S.dest_eq) (hyp thm)
       val {rhs = vstruct, ...} = S.dest_eq veq
       (* Wrap one more existential around the body and eliminate it again. *)
       fun choose_var v (body, th) =
         let
           val ex_tm = S.mk_exists {Bvar = v, Body = body}
         in
           (ex_tm, R.CHOOSE (tych v, R.ASSUME (tych ex_tm)) th)
         end
       val (_, abstracted) =
             U.itlist choose_var (S.free_vars_lr vstruct) (veq, thm)
     in
       abstracted
     end;
   751 
   752 
   (* combize M N builds the application of M to N. *)
   fun combize rator rand = S.mk_comb {Rator = rator, Rand = rand};
   (* eq v tm builds the equation with v on the left and tm on the right. *)
   fun eq left right = S.mk_eq {lhs = left, rhs = right};
   755 
   756 
   (*----------------------------------------------------------------------------
    * mk_induction thy f R [(pat1,TCs1),..., (patn,TCsn)]
    *
    * Instantiates WF_INDUCTION_THM with R and a fresh predicate variable P
    * (giving Sinduct), then tries to prove recursion induction (Rinduct) by
    * deriving the antecedent of Sinduct from the antecedent of Rinduct: one
    * case per pattern is proved by prove_case, and the cases are combined
    * using the theorem from complete_cases.
    *
    * Any exception raised during the derivation is replaced by TFL_ERR.
    *---------------------------------------------------------------------------*)
   fun mk_induction thy f R pat_TCs_list =
     let
       val tych = Thry.typecheck thy
       val wf_induct = R.UNDISCH (R.ISPEC (tych R) Thms.WF_INDUCTION_THM)
       val (pats, TCsl) = U.unzip pat_TCs_list
       val case_thm = complete_cases thy pats
       val domain = S.type_of (hd pats)
       (* P must not clash with any variable of the patterns or conditions. *)
       val P = S.variant (S.all_varsl (pats @ flatten TCsl))
                         (S.mk_var {Name = "P", Ty = domain --> S.bool})
       val Sinduct = R.SPEC (tych P) wf_induct
       val Sinduct_assumf = S.rand (#ant (S.dest_imp (concl Sinduct)))
       val (Rassums, TCl') = U.unzip (map (build_ih f P) pat_TCs_list)
       val Rinduct_assum = R.ASSUME (tych (S.list_mk_conj Rassums))
       (* Instantiate the body of Sinduct's antecedent at each pattern. *)
       val cases =
             map (fn pat => S.beta_conv (combize Sinduct_assumf pat)) pats
       val tasks = U.zip3 cases TCl' (R.CONJUNCTS Rinduct_assum)
       val proved_cases = map (prove_case f thy) tasks
       val v = S.variant (S.free_varsl (map concl proved_cases))
                         (S.mk_var {Name = "v", Ty = domain})
       val vtyped = tych v
       (* Rewrite each proved case with the assumption "pat = v". *)
       val substs = map (fn pat => R.SYM (R.ASSUME (tych (eq v pat)))) pats
       val proved_cases1 =
             U.map2 (fn th => R.SUBS [th]) substs proved_cases
       val abs_cases = map (LEFT_ABS_VSTRUCT tych) proved_cases1
       val dant =
             R.GEN vtyped (R.DISJ_CASESL (R.ISPEC vtyped case_thm) abs_cases)
       val dc = R.MP Sinduct dant
       val Parg_ty = S.type_of (#Bvar (S.dest_forall (concl dc)))
       (* gvvariant is applied to [P] once and the result mapped, as before. *)
       val vars = map (gvvariant [P]) (S.strip_prod_type Parg_ty)
       val dc' = U.itlist (fn var => R.GEN (tych var)) vars
                          (R.SPEC (tych (S.mk_vstruct Parg_ty vars)) dc)
     in
       R.GEN (tych P) (R.DISCH (tych (concl Rinduct_assum)) dc')
     end
     handle _ => raise TFL_ERR{func = "mk_induction", mesg = "failed derivation"};
   796 
   797 
   798 
   799 (*---------------------------------------------------------------------------
   800  * 
   801  *                        POST PROCESSING
   802  *
   803  *---------------------------------------------------------------------------*)
   804 
   805 
   (* simplify_induction thy hth ind
    *
    * hth is an equational theorem "tc = tc'".  The hypotheses of ind are
    * scanned in order; the first one for which Thry.match_term against tc
    * succeeds is discharged, rewritten through Thms.simp_thm and hth, and
    * undischarged again.  If no hypothesis qualifies, ind is returned
    * unchanged.
    *)
   fun simplify_induction thy hth ind =
     let
       val tych = Thry.typecheck thy
       val tc = S.lhs (concl hth)
       fun matches_tc asm = U.can (Thry.match_term thy asm) tc
       fun rewrite_with asm =
             R.UNDISCH
               (R.MATCH_MP
                  (R.MATCH_MP Thms.simp_thm (R.DISCH (tych asm) ind))
                  hth)
       fun first_match [] = ind
         | first_match (asm :: others) =
             if matches_tc asm then rewrite_with asm else first_match others
     in
       first_match (hyp ind)
     end;
   821 
   822 
   (*---------------------------------------------------------------------------
    * elim_tc tcthm (rule, induction)
    *
    * The termination condition is an antecedent of the rule and a hypothesis
    * of the induction theorem.  Given a theorem tcthm for it, detach it from
    * the rule by modus ponens and remove it from the induction theorem with
    * PROVE_HYP.
    *---------------------------------------------------------------------------*)
   fun elim_tc tcthm (rule, induction) =
     let
       val rule' = R.MP rule tcthm
       val induction' = R.PROVE_HYP tcthm induction
     in
       (rule', induction')
     end
   829 
   830 
   (* postprocess {WFtac, terminator, simplifier} theory {rules, induction, TCs}
    *
    * Tries to discharge the well-foundedness hypothesis with WFtac, then
    * simplifies (and where possible eliminates) the termination conditions
    * of every rule and of the induction theorem.  Returns the processed
    * induction theorem, the conjunction of the processed rules, and the
    * theorems obtained for nested termination conditions.
    *)
   fun postprocess {WFtac, terminator, simplifier} theory {rules, induction, TCs} =
     let
       val tych = Thry.typecheck theory

       (*---------------------------------------------------------------------
        * Attempt to eliminate the WF condition, taken to be the first
        * hypothesis of rules.  On any failure both theorems are kept as is.
        *---------------------------------------------------------------------*)
       val (rules1, induction1) =
             let
               val wf_thm = R.prove (tych (hd (hyp rules)), WFtac)
             in
               (R.PROVE_HYP wf_thm rules, R.PROVE_HYP wf_thm induction)
             end
             handle _ => (rules, induction)

       (* From |- tc = tc', prove tc' with the terminator and pass the result
        * through Thms.rev_eq_mp.  Raises if any step fails. *)
       fun prove_via_terminator tc_eq =
             R.MATCH_MP (R.MATCH_MP Thms.rev_eq_mp tc_eq)
                        (R.prove (tych (S.rhs (concl tc_eq)), terminator))

       (*----------------------------------------------------------------------
        * The termination condition (tc) is simplified to |- tc = tc' (there
        * might not be a change!) and then three strategies are tried in turn:
        *
        *   1. if |- tc = T, eliminate it with eqT; otherwise,
        *   2. apply the terminator to tc' and eliminate if that succeeds; else
        *   3. replace tc by tc' in both the rule and the induction theorem.
        *---------------------------------------------------------------------*)
       fun simplify_tc tc (r, ind) =
         let
           val tc_eq = simplifier (tych tc)
           fun by_eqT () = elim_tc (R.MATCH_MP Thms.eqT tc_eq) (r, ind)
           fun by_terminator () = elim_tc (prove_via_terminator tc_eq) (r, ind)
           fun by_rewriting () =
                 (R.UNDISCH (R.MATCH_MP (R.MATCH_MP Thms.simp_thm r) tc_eq),
                  simplify_induction theory tc_eq ind)
         in
           by_eqT ()
           handle _ => (by_terminator () handle _ => by_rewriting ())
         end

       (*----------------------------------------------------------------------
        * Nested termination conditions are harder to get at, since they are
        * left embedded in the body of the function (and in induction theorem
        * hypotheses).  They are simplified and a termination proof is
        * attempted, but applying the resulting theorem is left to a higher
        * level.  As in simplify_tc, tc is simplified to |- tc = tc' and then:
        *
        *   1. if |- tc = T, return the theorem obtained via eqT; otherwise,
        *   2. if the terminator proves tc', return the theorem obtained via
        *      rev_eq_mp; else
        *   3. return |- tc = tc'.
        *
        * The result is generalised with GEN_ALL in every case.
        *---------------------------------------------------------------------*)
       fun simplify_nested_tc tc =
         let
           val tc_eq = simplifier (tych (#2 (S.strip_forall tc)))
         in
           R.GEN_ALL
             (R.MATCH_MP Thms.eqT tc_eq
              handle _ => (prove_via_terminator tc_eq handle _ => tc_eq))
         end

       (*-------------------------------------------------------------------
        * Attempt to simplify the termination conditions in each rule and
        * in the induction theorem.
        *-------------------------------------------------------------------*)
       (* A negation is never taken apart as if it were an implication. *)
       fun guarded_strip_imp tm =
             if S.is_neg tm then ([], tm) else S.strip_imp tm

       (* Walks the (rule, conditions) pairs, threading the induction theorem,
        * the nested-condition theorems and the reversed list of finished
        * rules. *)
       fun process ([], nested_thms, finished, ind) =
             (rev finished, ind, nested_thms)
         | process ((r, ftcs) :: remaining, nested_thms, finished, ind) =
             let
               val (tcs, _) = guarded_strip_imp (concl r)
               (* Conditions recorded for the rule but absent from its
                * antecedents are the nested ones. *)
               val extra_tcs = U.set_diff S.aconv ftcs tcs
               val extra_tc_thms = map simplify_nested_tc extra_tcs
               val (r1, ind1) = U.rev_itlist simplify_tc tcs (r, ind)
               val r2 = R.FILTER_DISCH_ALL (not o S.is_WFR) r1
             in
               process (remaining, nested_thms @ extra_tc_thms,
                        r2 :: finished, ind1)
             end

       val rules_tcs = U.zip (R.CONJUNCTS rules1) TCs
       val (rules2, ind2, extras) = process (rules_tcs, [], [], induction1)
     in
       {induction = ind2, rules = R.LIST_CONJ rules2, nested_tcs = extras}
     end;
   906 
   907 
   908 (*---------------------------------------------------------------------------
   909  * Extract termination goals so that they can be put into a goalstack, or
   910  * have a tactic directly applied to them.
   911  *--------------------------------------------------------------------------*)
   local
     exception IS_NEG
     (* Refuse to take a negation apart as if it were an implication; the
      * exception is caught by the handler in termination_goals. *)
     fun strip_imp tm = if S.is_neg tm then raise IS_NEG else S.strip_imp tm
   in
   (* termination_goals rules
    *
    * Collects the antecedents of every conjunct of rules (below its outer
    * universal quantifiers) and puts them in front of the hypotheses of
    * rules.  A conjunct whose body is a negation, or on which destructuring
    * fails for any other reason, contributes nothing.
    *
    * The guarded strip_imp above was previously declared but never called:
    * S.strip_imp was applied directly, so the negation guard had no effect.
    * The guard is now actually used, in line with the negation check made
    * by postprocess.
    *)
   fun termination_goals rules =
       U.itlist (fn th => fn A =>
           let val tcl = (#1 o strip_imp o #2 o S.strip_forall o concl) th
           in tcl @ A
           end handle _ => A) (R.CONJUNCTS rules) (hyp rules)
   end;
   921 
   922 end; (* TFL *)