src/HOL/Tools/Function/mutual.ML
changeset 53608 53bd62921c54
parent 53607 825b6a41411b
child 54566 5f3e9baa8f13
equal deleted inserted replaced
53607:825b6a41411b 53608:53bd62921c54
   246       end
   246       end
   247   in
   247   in
   248     fst (fold_map (project induct_inst) parts 0)
   248     fst (fold_map (project induct_inst) parts 0)
   249   end
   249   end
   250 
   250 
   251 fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, ...}) proof =
   251 fun mutual_cases_rule ctxt cases_rule n ST (MutualPart {i, cargTs, ...}) =
       
   252   let
       
   253     val arg_vars = 
       
   254       cargTs
       
   255       |> map_index (fn (i, T) => Free ("x" ^ string_of_int i, T)) (* FIXME: proper context *)
       
   256 
       
   257     val argsT = fastype_of (HOLogic.mk_tuple arg_vars)
       
   258     val args = Free ("x", argsT) (* FIXME: proper context *)
       
   259 
       
   260     val cert = cterm_of (Proof_Context.theory_of ctxt)
       
   261 
       
   262     val sumtree_inj = SumTree.mk_inj ST n i args
       
   263 
       
   264     val sum_elims =
       
   265       @{thms HOL.notE[OF Sum_Type.sum.distinct(1)] HOL.notE[OF Sum_Type.sum.distinct(2)]}
       
   266 
       
   267     fun prep_subgoal i =
       
   268       REPEAT (eresolve_tac @{thms Pair_inject Inl_inject[elim_format] Inr_inject[elim_format]} i)
       
   269       THEN REPEAT (Tactic.eresolve_tac sum_elims i)
       
   270   in
       
   271     cases_rule
       
   272     |> Thm.forall_elim @{cterm "P::bool"}
       
   273     |> Thm.forall_elim (cert sumtree_inj)
       
   274     |> Tactic.rule_by_tactic ctxt (ALLGOALS prep_subgoal)
       
   275     |> Thm.forall_intr (cert args)
       
   276     |> Thm.forall_intr @{cterm "P::bool"}
       
   277   end
       
   278 
       
   279 
       
   280 fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, n, ST, ...}) proof =
   252   let
   281   let
   253     val result = inner_cont proof
   282     val result = inner_cont proof
   254     val FunctionResult {G, R, cases, psimps, simple_pinducts=[simple_pinduct],
   283     val FunctionResult {G, R, cases=[cases_rule], psimps, simple_pinducts=[simple_pinduct],
   255       termination, domintros, dom, pelims, ...} = result
   284       termination, domintros, dom, pelims, ...} = result
   256 
   285 
   257     val (all_f_defs, fs) =
   286     val (all_f_defs, fs) =
   258       map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} =>
   287       map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} =>
   259         (mk_applied_form lthy cargTs (Thm.symmetric f_def), f))
   288         (mk_applied_form lthy cargTs (Thm.symmetric f_def), f))
   267       in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp
   296       in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp
   268 
   297 
   269     val rew_simpset = put_simpset HOL_basic_ss lthy addsimps all_f_defs
   298     val rew_simpset = put_simpset HOL_basic_ss lthy addsimps all_f_defs
   270     val mpsimps = map2 mk_mpsimp fqgars psimps
   299     val mpsimps = map2 mk_mpsimp fqgars psimps
   271     val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m
   300     val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m
       
   301     val mcases = map (mutual_cases_rule lthy cases_rule n ST) parts
   272     val mtermination = full_simplify rew_simpset termination
   302     val mtermination = full_simplify rew_simpset termination
   273     val mdomintros = Option.map (map (full_simplify rew_simpset)) domintros
   303     val mdomintros = Option.map (map (full_simplify rew_simpset)) domintros
   274 
   304 
   275   in
   305   in
   276     FunctionResult { fs=fs, G=G, R=R, dom=dom,
   306     FunctionResult { fs=fs, G=G, R=R, dom=dom,
   277       psimps=mpsimps, simple_pinducts=minducts,
   307       psimps=mpsimps, simple_pinducts=minducts,
   278       cases=cases, pelims=pelims, termination=mtermination,
   308       cases=mcases, pelims=pelims, termination=mtermination,
   279       domintros=mdomintros}
   309       domintros=mdomintros}
   280   end
   310   end
   281 
       
   282 
       
   283 fun postprocess_cases_rules ctxt cont proof =
       
   284   let val result = cont proof;
       
   285       val FunctionResult {fs, G, R, dom, psimps, simple_pinducts, cases, pelims,
       
   286                         termination, domintros, ...} = result;
       
   287       val n_fs = length fs;
       
   288       val domT = R |> dest_Free |> snd |> hd o snd o dest_Type
       
   289 
       
   290       fun postprocess_cases_rule (idx,f) =
       
   291         let val lhs_of =
       
   292               prop_of #> Logic.strip_assums_concl #> HOLogic.dest_Trueprop #> HOLogic.dest_eq #> fst
       
   293 
       
   294             val f_simps = filter (fn r => Term.head_of (lhs_of r) aconv f) psimps
       
   295             val arity = length (snd (strip_comb (lhs_of (hd f_simps))))
       
   296 
       
   297             val arg_vars = 
       
   298                 take arity (binder_types (fastype_of f))
       
   299                 |> map_index (fn (i, T) => Free ("x" ^ string_of_int i, T)) (* FIXME: proper context *)
       
   300 
       
   301             val argsT = fastype_of (HOLogic.mk_tuple arg_vars);
       
   302             val args = Free ("x", argsT); (* FIXME: proper context *)
       
   303 
       
   304             val cert = cterm_of (Proof_Context.theory_of ctxt);
       
   305 
       
   306             val sumtree_inj = SumTree.mk_inj domT n_fs (idx+1) args;
       
   307 
       
   308             val sum_elims = @{thms HOL.notE[OF Sum_Type.sum.distinct(1)]
       
   309                                    HOL.notE[OF Sum_Type.sum.distinct(2)]};
       
   310             fun prep_subgoal i =
       
   311               REPEAT (eresolve_tac @{thms Pair_inject Inl_inject[elim_format]
       
   312                                           Inr_inject[elim_format]} i)
       
   313               THEN REPEAT (Tactic.eresolve_tac sum_elims i);
       
   314 
       
   315         in
       
   316             hd cases
       
   317               |> Thm.forall_elim @{cterm "P::bool"}
       
   318               |> Thm.forall_elim (cert sumtree_inj)
       
   319               |> Tactic.rule_by_tactic ctxt (ALLGOALS prep_subgoal)
       
   320               |> Thm.forall_intr (cert args)
       
   321               |> Thm.forall_intr @{cterm "P::bool"}
       
   322 
       
   323         end;
       
   324 
       
   325   val cases' = map_index postprocess_cases_rule fs;
       
   326 
       
   327 in
       
   328   FunctionResult {fs=fs, G=G, R=R, dom=dom, psimps=psimps,
       
   329                   simple_pinducts=simple_pinducts,
       
   330                   cases=cases', pelims=pelims, termination=termination,
       
   331                   domintros=domintros}
       
   332 end;
       
   333 
   311 
   334 
   312 
   335 fun prepare_function_mutual config defname fixes eqss lthy =
   313 fun prepare_function_mutual config defname fixes eqss lthy =
   336   let
   314   let
   337     val mutual as Mutual {fsum_var=(n, T), qglrs, ...} =
   315     val mutual as Mutual {fsum_var=(n, T), qglrs, ...} =
   341       Function_Core.prepare_function config defname [((n, T), NoSyn)] qglrs lthy
   319       Function_Core.prepare_function config defname [((n, T), NoSyn)] qglrs lthy
   342 
   320 
   343     val (mutual', lthy'') = define_projections fixes mutual fsum lthy'
   321     val (mutual', lthy'') = define_projections fixes mutual fsum lthy'
   344 
   322 
   345     val cont' = mk_partial_rules_mutual lthy'' cont mutual'
   323     val cont' = mk_partial_rules_mutual lthy'' cont mutual'
   346     val cont'' = postprocess_cases_rules lthy'' cont'
   324   in
   347   in
   325     ((goalstate, cont'), lthy'')
   348     ((goalstate, cont''), lthy'')
       
   349   end
   326   end
   350 
   327 
   351 end
   328 end