src/HOL/Tools/function_package/context_tree.ML
changeset 21237 b803f9870e97
parent 21188 2aa15b663cd4
child 23819 2040846d1bbe
equal deleted inserted replaced
21236:890fafbcf8b0 21237:b803f9870e97
    53 
    53 
    54 
    54 
    55 (*** Dependency analysis for congruence rules ***)
    55 (*** Dependency analysis for congruence rules ***)
    56 
    56 
    57 fun branch_vars t = 
    57 fun branch_vars t = 
    58     let	
    58     let 
    59       val t' = snd (dest_all_all t)
    59       val t' = snd (dest_all_all t)
    60       val assumes = Logic.strip_imp_prems t'
    60       val assumes = Logic.strip_imp_prems t'
    61       val concl = Logic.strip_imp_concl t'
    61       val concl = Logic.strip_imp_concl t'
    62     in (fold (curry add_term_vars) assumes [], term_vars concl)
    62     in (fold (curry add_term_vars) assumes [], term_vars concl)
    63     end
    63     end
    64 
    64 
    65 fun cong_deps crule =
    65 fun cong_deps crule =
    66     let
    66     let
    67 	val branches = map branch_vars (prems_of crule)
    67   val branches = map branch_vars (prems_of crule)
    68 	val num_branches = (1 upto (length branches)) ~~ branches
    68   val num_branches = (1 upto (length branches)) ~~ branches
    69     in
    69     in
    70 	IntGraph.empty
    70   IntGraph.empty
    71 	    |> fold (fn (i,_)=> IntGraph.new_node (i,i)) num_branches
    71       |> fold (fn (i,_)=> IntGraph.new_node (i,i)) num_branches
    72 	    |> fold (fn ((i,(c1,_)),(j,(_, t2))) => if i = j orelse null (c1 inter t2) then I else IntGraph.add_edge_acyclic (i,j))
    72       |> fold (fn ((i,(c1,_)),(j,(_, t2))) => if i = j orelse null (c1 inter t2) then I else IntGraph.add_edge_acyclic (i,j))
    73 	    (product num_branches num_branches)
    73       (product num_branches num_branches)
    74     end
    74     end
    75     
    75     
    76 val add_congs = map (fn c => c RS eq_reflection) [cong, ext] 
    76 val add_congs = map (fn c => c RS eq_reflection) [cong, ext] 
    77 
    77 
    78 
    78 
    79 
    79 
    80 (* Called on the INSTANTIATED branches of the congruence rule *)
    80 (* Called on the INSTANTIATED branches of the congruence rule *)
    81 fun mk_branch ctx t = 
    81 fun mk_branch ctx t = 
    82     let
    82     let
    83 	val (ctx', fixes, impl) = dest_all_all_ctx ctx t
    83   val (ctx', fixes, impl) = dest_all_all_ctx ctx t
    84     in
    84     in
    85       (ctx', fixes, Logic.strip_imp_prems impl, rhs_of (Logic.strip_imp_concl impl))
    85       (ctx', fixes, Logic.strip_imp_prems impl, rhs_of (Logic.strip_imp_concl impl))
    86     end
    86     end
    87 
    87 
    88 fun find_cong_rule ctx fvar h ((r,dep)::rs) t =
    88 fun find_cong_rule ctx fvar h ((r,dep)::rs) t =
    94 
    94 
    95        val subst = Pattern.match (ProofContext.theory_of ctx) (c, tt') (Vartab.empty, Vartab.empty)
    95        val subst = Pattern.match (ProofContext.theory_of ctx) (c, tt') (Vartab.empty, Vartab.empty)
    96        val branches = map (mk_branch ctx o Envir.beta_norm o Envir.subst_vars subst) subs
    96        val branches = map (mk_branch ctx o Envir.beta_norm o Envir.subst_vars subst) subs
    97        val inst = map (fn v => (cterm_of thy (Var v), cterm_of thy (Envir.subst_vars subst (Var v)))) (Term.add_vars c [])
    97        val inst = map (fn v => (cterm_of thy (Var v), cterm_of thy (Envir.subst_vars subst (Var v)))) (Term.add_vars c [])
    98      in
    98      in
    99 	 (cterm_instantiate inst r, dep, branches)
    99    (cterm_instantiate inst r, dep, branches)
   100      end
   100      end
   101     handle Pattern.MATCH => find_cong_rule ctx fvar h rs t)
   101     handle Pattern.MATCH => find_cong_rule ctx fvar h rs t)
   102   | find_cong_rule _ _ _ [] _ = sys_error "function_package/context_tree.ML: No cong rule found!"
   102   | find_cong_rule _ _ _ [] _ = sys_error "function_package/context_tree.ML: No cong rule found!"
   103 
   103 
   104 
   104 
   109     case matchcall fvar t of
   109     case matchcall fvar t of
   110       SOME arg => RCall (t, mk_tree congs fvar h ctx arg)
   110       SOME arg => RCall (t, mk_tree congs fvar h ctx arg)
   111     | NONE => 
   111     | NONE => 
   112       if not (exists_subterm (fn Free v => v = fvar | _ => false) t) then Leaf t
   112       if not (exists_subterm (fn Free v => v = fvar | _ => false) t) then Leaf t
   113       else 
   113       else 
   114 	let val (r, dep, branches) = find_cong_rule ctx fvar h congs t in
   114   let val (r, dep, branches) = find_cong_rule ctx fvar h congs t in
   115 	  Cong (t, r, dep, 
   115     Cong (t, r, dep, 
   116                 map (fn (ctx', fixes, assumes, st) => 
   116                 map (fn (ctx', fixes, assumes, st) => 
   117 			(fixes, map (assume o cterm_of (ProofContext.theory_of ctx)) assumes, 
   117       (fixes, map (assume o cterm_of (ProofContext.theory_of ctx)) assumes, 
   118                          mk_tree congs fvar h ctx' st)) branches)
   118                          mk_tree congs fvar h ctx' st)) branches)
   119 	end
   119   end
   120 		
   120     
   121 
   121 
   122 fun inst_tree thy fvar f tr =
   122 fun inst_tree thy fvar f tr =
   123     let
   123     let
   124       val cfvar = cterm_of thy fvar
   124       val cfvar = cterm_of thy fvar
   125       val cf = cterm_of thy f
   125       val cf = cterm_of thy f
   140       inst_tree_aux tr
   140       inst_tree_aux tr
   141     end
   141     end
   142 
   142 
   143 
   143 
   144 
   144 
   145 (* FIXME: remove *)		
   145 (* FIXME: remove *)   
   146 fun add_context_varnames (Leaf _) = I
   146 fun add_context_varnames (Leaf _) = I
   147   | add_context_varnames (Cong (_, _, _, sub)) = fold (fn (fs, _, st) => fold (insert (op =) o fst) fs o add_context_varnames st) sub
   147   | add_context_varnames (Cong (_, _, _, sub)) = fold (fn (fs, _, st) => fold (insert (op =) o fst) fs o add_context_varnames st) sub
   148   | add_context_varnames (RCall (_,st)) = add_context_varnames st
   148   | add_context_varnames (RCall (_,st)) = add_context_varnames st
   149     
   149     
   150 
   150 
   162     fold (forall_elim o cterm_of thy o Free) fixes
   162     fold (forall_elim o cterm_of thy o Free) fixes
   163  #> fold implies_elim_swp athms
   163  #> fold implies_elim_swp athms
   164 
   164 
   165 fun assume_in_ctxt thy (fixes, athms) prop =
   165 fun assume_in_ctxt thy (fixes, athms) prop =
   166     let
   166     let
   167 	val global_assum = export_term (fixes, map prop_of athms) prop
   167   val global_assum = export_term (fixes, map prop_of athms) prop
   168     in
   168     in
   169 	(global_assum,
   169   (global_assum,
   170 	 assume (cterm_of thy global_assum) |> import_thm thy (fixes, athms))
   170    assume (cterm_of thy global_assum) |> import_thm thy (fixes, athms))
   171     end
   171     end
   172 
   172 
   173 
   173 
   174 (* folds in the order of the dependencies of a graph. *)
   174 (* folds in the order of the dependencies of a graph. *)
   175 fun fold_deps G f x =
   175 fun fold_deps G f x =
   176     let
   176     let
   177 	fun fill_table i (T, x) =
   177   fun fill_table i (T, x) =
   178 	    case Inttab.lookup T i of
   178       case Inttab.lookup T i of
   179 		SOME _ => (T, x)
   179     SOME _ => (T, x)
   180 	      | NONE => 
   180         | NONE => 
   181 		let
   181     let
   182 		    val (T', x') = fold fill_table (IntGraph.imm_succs G i) (T, x)
   182         val (T', x') = fold fill_table (IntGraph.imm_succs G i) (T, x)
   183 		    val (v, x'') = f (the o Inttab.lookup T') i x
   183         val (v, x'') = f (the o Inttab.lookup T') i x
   184 		in
   184     in
   185 		    (Inttab.update (i, v) T', x'')
   185         (Inttab.update (i, v) T', x'')
   186 		end
   186     end
   187 
   187 
   188 	val (T, x) = fold fill_table (IntGraph.keys G) (Inttab.empty, x)
   188   val (T, x) = fold fill_table (IntGraph.keys G) (Inttab.empty, x)
   189     in
   189     in
   190 	(Inttab.fold (cons o snd) T [], x)
   190   (Inttab.fold (cons o snd) T [], x)
   191     end
   191     end
   192 
   192 
   193 
   193 
   194 fun flatten xss = fold_rev append xss []
   194 fun flatten xss = fold_rev append xss []
   195 
   195 
   196 fun traverse_tree rcOp tr x =
   196 fun traverse_tree rcOp tr x =
   197     let 
   197     let 
   198 	fun traverse_help ctx (Leaf _) u x = ([], x)
   198   fun traverse_help ctx (Leaf _) u x = ([], x)
   199 	  | traverse_help ctx (RCall (t, st)) u x =
   199     | traverse_help ctx (RCall (t, st)) u x =
   200 	    rcOp ctx t u (traverse_help ctx st u x)
   200       rcOp ctx t u (traverse_help ctx st u x)
   201 	  | traverse_help ctx (Cong (t, crule, deps, branches)) u x =
   201     | traverse_help ctx (Cong (t, crule, deps, branches)) u x =
   202 	    let
   202       let
   203 		fun sub_step lu i x =
   203     fun sub_step lu i x =
   204 		    let
   204         let
   205 			val (fixes, assumes, subtree) = nth branches (i - 1)
   205       val (fixes, assumes, subtree) = nth branches (i - 1)
   206 			val used = fold_rev (append o lu) (IntGraph.imm_succs deps i) u
   206       val used = fold_rev (append o lu) (IntGraph.imm_succs deps i) u
   207 			val (subs, x') = traverse_help (compose ctx (fixes, assumes)) subtree used x
   207       val (subs, x') = traverse_help (compose ctx (fixes, assumes)) subtree used x
   208 			val exported_subs = map (apfst (compose (fixes, assumes))) subs
   208       val exported_subs = map (apfst (compose (fixes, assumes))) subs
   209 		    in
   209         in
   210 			(exported_subs, x')
   210       (exported_subs, x')
   211 		    end
   211         end
   212 	    in
   212       in
   213 		fold_deps deps sub_step x
   213     fold_deps deps sub_step x
   214 			  |> apfst flatten
   214         |> apfst flatten
   215 	    end
   215       end
   216     in
   216     in
   217 	snd (traverse_help ([], []) tr [] x)
   217   snd (traverse_help ([], []) tr [] x)
   218     end
   218     end
   219 
   219 
   220 
   220 
   221 fun is_refl thm = let val (l,r) = Logic.dest_equals (prop_of thm) in l = r end
   221 fun is_refl thm = let val (l,r) = Logic.dest_equals (prop_of thm) in l = r end
   222 
   222 
   223 fun rewrite_by_tree thy h ih x tr =
   223 fun rewrite_by_tree thy h ih x tr =
   224     let
   224     let
   225 	fun rewrite_help fix f_as h_as x (Leaf t) = (reflexive (cterm_of thy t), x)
   225       fun rewrite_help fix f_as h_as x (Leaf t) = (reflexive (cterm_of thy t), x)
   226 	  | rewrite_help fix f_as h_as x (RCall (_ $ arg, st)) =
   226         | rewrite_help fix f_as h_as x (RCall (_ $ arg, st)) =
   227 	    let
   227           let
   228 		val (inner, (lRi,ha)::x') = rewrite_help fix f_as h_as x st
   228             val (inner, (lRi,ha)::x') = rewrite_help fix f_as h_as x st
   229 					   
   229                                                      
   230 					   (* Need not use the simplifier here. Can use primitive steps! *)
   230              (* Need not use the simplifier here. Can use primitive steps! *)
   231 		val rew_ha = if is_refl inner then I else simplify (HOL_basic_ss addsimps [inner])
   231             val rew_ha = if is_refl inner then I else simplify (HOL_basic_ss addsimps [inner])
   232 			     
   232            
   233 		val h_a_eq_h_a' = combination (reflexive (cterm_of thy h)) inner
   233             val h_a_eq_h_a' = combination (reflexive (cterm_of thy h)) inner
   234 		val iha = import_thm thy (fix, h_as) ha (* (a', h a') : G *)
   234             val iha = import_thm thy (fix, h_as) ha (* (a', h a') : G *)
   235 				     |> rew_ha
   235                                  |> rew_ha
   236 
   236                       
   237 		val inst_ih = instantiate' [] [SOME (cterm_of thy arg)] ih
   237             val inst_ih = instantiate' [] [SOME (cterm_of thy arg)] ih
   238 		val eq = implies_elim (implies_elim inst_ih lRi) iha
   238             val eq = implies_elim (implies_elim inst_ih lRi) iha
   239 			 
   239                      
   240 		val h_a'_eq_f_a' = eq RS eq_reflection
   240             val h_a'_eq_f_a' = eq RS eq_reflection
   241 		val result = transitive h_a_eq_h_a' h_a'_eq_f_a'
   241             val result = transitive h_a_eq_h_a' h_a'_eq_f_a'
   242 	    in
   242           in
   243 		(result, x')
   243             (result, x')
   244 	    end
   244           end
   245 	  | rewrite_help fix f_as h_as x (Cong (t, crule, deps, branches)) =
   245         | rewrite_help fix f_as h_as x (Cong (t, crule, deps, branches)) =
   246 	    let
   246           let
   247 		fun sub_step lu i x =
   247             fun sub_step lu i x =
   248 		    let
   248                 let
   249 			val (fixes, assumes, st) = nth branches (i - 1)
   249                   val (fixes, assumes, st) = nth branches (i - 1)
   250 			val used = fold_rev (cons o lu) (IntGraph.imm_succs deps i) []
   250                   val used = fold_rev (cons o lu) (IntGraph.imm_succs deps i) []
   251 			val used_rev = map (fn u_eq => (u_eq RS sym) RS eq_reflection) used
   251                   val used_rev = map (fn u_eq => (u_eq RS sym) RS eq_reflection) used
   252 			val assumes' = map (simplify (HOL_basic_ss addsimps (filter_out is_refl used_rev))) assumes
   252                   val assumes' = map (simplify (HOL_basic_ss addsimps (filter_out is_refl used_rev))) assumes
   253 
   253                                  
   254 			val (subeq, x') = rewrite_help (fix @ fixes) (f_as @ assumes) (h_as @ assumes') x st
   254                   val (subeq, x') = rewrite_help (fix @ fixes) (f_as @ assumes) (h_as @ assumes') x st
   255 			val subeq_exp = export_thm thy (fixes, map prop_of assumes) (subeq RS meta_eq_to_obj_eq)
   255                   val subeq_exp = export_thm thy (fixes, map prop_of assumes) (subeq RS meta_eq_to_obj_eq)
   256 		    in
   256                 in
   257 			(subeq_exp, x')
   257                   (subeq_exp, x')
   258 		    end
   258                 end
   259 		    
   259                 
   260 		val (subthms, x') = fold_deps deps sub_step x
   260             val (subthms, x') = fold_deps deps sub_step x
   261 	    in
   261           in
   262 		(fold_rev (curry op COMP) subthms crule, x')
   262             (fold_rev (curry op COMP) subthms crule, x')
   263 	    end
   263           end
   264 	    
   264     in
   265     in
   265       rewrite_help [] [] [] x tr
   266 	rewrite_help [] [] [] x tr
   266     end
   267     end
   267     
   268 
       
   269 end
   268 end