src/Pure/Isar/code_unit.ML
changeset 24917 8b97a94ab187
parent 24848 5dbbd33c3236
child 25336 027a63deb61c
equal deleted inserted replaced
24916:dc56dd1b3cda 24917:8b97a94ab187
   370   end;
   370   end;
   371 
   371 
   372 
   372 
   373 (* case cerificates *)
   373 (* case cerificates *)
   374 
   374 
   375 fun case_cert thm =
   375 fun case_certificate thm =
   376   let
   376   let
   377     (*FIXME rework this code*)
   377     val thy = Thm.theory_of_thm thm;
   378     val thy = Thm.theory_of_thm thm;
   378     val ((head, raw_case_expr), cases) = (apfst Logic.dest_equals
   379     val (cas, t_pats) = (Logic.dest_implies o Thm.prop_of) thm;
   379       o apsnd Logic.dest_conjunctions o Logic.dest_implies o Thm.prop_of) thm;
   380     val pats = Logic.dest_conjunctions t_pats;
       
   381     val (head, proto_case_expr) = Logic.dest_equals cas;
       
   382     val _ = case head of Free _ => true
   380     val _ = case head of Free _ => true
   383       | Var _ => true
   381       | Var _ => true
   384       | _ => raise TERM ("case_cert", []);
   382       | _ => raise TERM ("case_cert", []);
   385     val ([(case_expr_v, _)], case_expr) = Term.strip_abs_eta 1 proto_case_expr;
   383     val ([(case_var, _)], case_expr) = Term.strip_abs_eta 1 raw_case_expr;
   386     val (Const (c_case, _), raw_params) = strip_comb case_expr;
   384     val (Const (case_const, _), raw_params) = strip_comb case_expr;
   387     val i = find_index
   385     val n = find_index (fn Free (v, _) => v = case_var | _ => false) raw_params;
   388       (fn Free (v, _) => v = case_expr_v | _ => false) raw_params;
   386     val _ = if n = ~1 then raise TERM ("case_cert", []) else ();
   389     val _ = if i = ~1 then raise TERM ("case_cert", []) else ();
   387     val params = map (fst o dest_Var) (nth_drop n raw_params);
   390     val t_params = nth_drop i raw_params;
   388     fun dest_case t =
   391     val params = map (fst o dest_Var) t_params;
       
   392     fun dest_pat t =
       
   393       let
   389       let
   394         val (head' $ t_co, rhs) = Logic.dest_equals t;
   390         val (head' $ t_co, rhs) = Logic.dest_equals t;
   395         val _ = if head' = head then () else raise TERM ("case_cert", []);
   391         val _ = if head' = head then () else raise TERM ("case_cert", []);
   396         val (Const (co, _), args) = strip_comb t_co;
   392         val (Const (co, _), args) = strip_comb t_co;
   397         val (Var (param, _), args') = strip_comb rhs;
   393         val (Var (param, _), args') = strip_comb rhs;
   398         val _ = if args' = args then () else raise TERM ("case_cert", []);
   394         val _ = if args' = args then () else raise TERM ("case_cert", []);
   399       in (param, co) end;
   395       in (param, co) end;
   400     fun analyze_pats pats =
   396     fun analyze_cases cases =
   401       let
   397       let
   402         val co_list = fold (AList.update (op =) o dest_pat) pats [];
   398         val co_list = fold (AList.update (op =) o dest_case) cases [];
   403       in map (the o AList.lookup (op =) co_list) params end;
   399       in map (the o AList.lookup (op =) co_list) params end;
   404     fun analyze_let t =
   400     fun analyze_let t =
   405       let
   401       let
   406         val (head' $ arg, Var (param', _) $ arg') = Logic.dest_equals t;
   402         val (head' $ arg, Var (param', _) $ arg') = Logic.dest_equals t;
   407         val _ = if head' = head then () else raise TERM ("case_cert", []);
   403         val _ = if head' = head then () else raise TERM ("case_cert", []);
   408         val _ = if arg' = arg then () else raise TERM ("case_cert", []);
   404         val _ = if arg' = arg then () else raise TERM ("case_cert", []);
   409         val _ = if [param'] = params then () else raise TERM ("case_cert", []);
   405         val _ = if [param'] = params then () else raise TERM ("case_cert", []);
   410       in [] end;
   406       in [] end;
   411   in (c_case, (i, case pats
   407     fun analyze (cases as [let_case]) =
   412    of [pat] => (analyze_pats pats handle Bind => analyze_let pat)
   408           (analyze_cases cases handle Bind => analyze_let let_case)
   413     | _ :: _ => analyze_pats pats))
   409       | analyze cases = analyze_cases cases;
   414   end handle Bind => error "bad case certificate"
   410   in (case_const, (n, analyze cases)) end;
   415     | TERM _ => error "bad case certificate";
   411 
       
   412 fun case_cert thm = case_certificate thm
       
   413   handle Bind => error "bad case certificate"
       
   414       | TERM _ => error "bad case certificate";
   416 
   415 
   417 end;
   416 end;