370 end; |
370 end; |
371 |
371 |
372 |
372 |
373 (* case certificates *) |
373 (* case certificates *) |
374 |
374 |
(* NOTE(review): this span looks like a side-by-side diff rendering of two
   revisions of one definition, interleaved row by row; every row carries a
   revision line number and a trailing "|".  Confirm before treating it as
   compilable source.

   Older revision: a single function case_cert whose let-expression carries
   the Bind/TERM handlers itself.
   Newer revision: the same analysis under the name case_certificate, plus a
   thin wrapper case_cert that owns the handlers.

   Both revisions take a theorem thm and evaluate to
     (case constant name, (index of the scrutinee argument, constructor names)).
   Local names differ between the revisions:
     cas / t_pats / pats   vs  cases (one composed destructuring)
     proto_case_expr       vs  raw_case_expr
     case_expr_v           vs  case_var
     c_case                vs  case_const
     i                     vs  n
     dest_pat              vs  dest_case
     analyze_pats          vs  analyze_cases *)
375 fun case_cert thm = |
375 fun case_certificate thm = |
376 let |
376 let |
377 (*FIXME rework this code*) |
(* thy is bound in both revisions but is not referenced in the visible rows. *)
377 val thy = Thm.theory_of_thm thm; |
378 val thy = Thm.theory_of_thm thm; |
(* The proposition of thm is split with Logic.dest_implies; its first
   component is split with Logic.dest_equals into (head, case expression), its
   second component with Logic.dest_conjunctions into the list of clauses.
   The newer revision does this in one composed expression, the older one in
   three separate bindings. *)
378 val ((head, raw_case_expr), cases) = (apfst Logic.dest_equals |
379 val (cas, t_pats) = (Logic.dest_implies o Thm.prop_of) thm; |
379 o apsnd Logic.dest_conjunctions o Logic.dest_implies o Thm.prop_of) thm; |
380 val pats = Logic.dest_conjunctions t_pats; |
|
381 val (head, proto_case_expr) = Logic.dest_equals cas; |
|
(* head must be a Free or a Var; anything else raises TERM. *)
382 val _ = case head of Free _ => true |
380 val _ = case head of Free _ => true |
383 | Var _ => true |
381 | Var _ => true |
384 | _ => raise TERM ("case_cert", []); |
382 | _ => raise TERM ("case_cert", []); |
(* Exactly one abstraction is stripped from the case expression; the pattern
   expects a single bound variable, and its name is kept. *)
385 val ([(case_expr_v, _)], case_expr) = Term.strip_abs_eta 1 proto_case_expr; |
383 val ([(case_var, _)], case_expr) = Term.strip_abs_eta 1 raw_case_expr; |
(* The stripped body must be a Const applied to arguments; a mismatch of this
   pattern binding is what the Bind handlers further down account for. *)
386 val (Const (c_case, _), raw_params) = strip_comb case_expr; |
384 val (Const (case_const, _), raw_params) = strip_comb case_expr; |
(* Position, among the arguments, of the Free whose name equals the stripped
   bound variable; an index of ~1 is rejected with TERM. *)
387 val i = find_index |
385 val n = find_index (fn Free (v, _) => v = case_var | _ => false) raw_params; |
388 (fn Free (v, _) => v = case_expr_v | _ => false) raw_params; |
386 val _ = if n = ~1 then raise TERM ("case_cert", []) else (); |
389 val _ = if i = ~1 then raise TERM ("case_cert", []) else (); |
(* The remaining arguments (that position dropped) are passed through
   dest_Var; params keeps the first component of each result. *)
387 val params = map (fst o dest_Var) (nth_drop n raw_params); |
390 val t_params = nth_drop i raw_params; |
(* dest_case / dest_pat: one clause must have the shape
     head' $ t_co == rhs
   with head' equal to head, t_co a Const co applied to args, and rhs a Var
   applied to exactly the same args.  The result pairs that Var's first
   component with the constant name co. *)
388 fun dest_case t = |
391 val params = map (fst o dest_Var) t_params; |
|
392 fun dest_pat t = |
|
393 let |
389 let |
394 val (head' $ t_co, rhs) = Logic.dest_equals t; |
390 val (head' $ t_co, rhs) = Logic.dest_equals t; |
395 val _ = if head' = head then () else raise TERM ("case_cert", []); |
391 val _ = if head' = head then () else raise TERM ("case_cert", []); |
396 val (Const (co, _), args) = strip_comb t_co; |
392 val (Const (co, _), args) = strip_comb t_co; |
397 val (Var (param, _), args') = strip_comb rhs; |
393 val (Var (param, _), args') = strip_comb rhs; |
398 val _ = if args' = args then () else raise TERM ("case_cert", []); |
394 val _ = if args' = args then () else raise TERM ("case_cert", []); |
399 in (param, co) end; |
395 in (param, co) end; |
(* analyze_cases / analyze_pats: folds the (param, co) pairs of all clauses
   into an association list, then looks up every entry of params in order.
   The lookup is wrapped in "the", so each entry is expected to be present. *)
400 fun analyze_pats pats = |
396 fun analyze_cases cases = |
401 let |
397 let |
402 val co_list = fold (AList.update (op =) o dest_pat) pats []; |
398 val co_list = fold (AList.update (op =) o dest_case) cases []; |
403 in map (the o AList.lookup (op =) co_list) params end; |
399 in map (the o AList.lookup (op =) co_list) params end; |
(* analyze_let: alternative reading of a single clause of the shape
     head' $ arg == Var param' $ arg'
   with head' equal to head, arg' equal to arg, and param' the only entry of
   params.  It yields the empty list. *)
404 fun analyze_let t = |
400 fun analyze_let t = |
405 let |
401 let |
406 val (head' $ arg, Var (param', _) $ arg') = Logic.dest_equals t; |
402 val (head' $ arg, Var (param', _) $ arg') = Logic.dest_equals t; |
407 val _ = if head' = head then () else raise TERM ("case_cert", []); |
403 val _ = if head' = head then () else raise TERM ("case_cert", []); |
408 val _ = if arg' = arg then () else raise TERM ("case_cert", []); |
404 val _ = if arg' = arg then () else raise TERM ("case_cert", []); |
409 val _ = if [param'] = params then () else raise TERM ("case_cert", []); |
405 val _ = if [param'] = params then () else raise TERM ("case_cert", []); |
410 in [] end; |
406 in [] end; |
(* Dispatch: a singleton clause list is first tried as ordinary clauses and
   falls back to analyze_let when that raises Bind; longer lists are analysed
   as clauses only.  The older revision's inline "case pats of" has no branch
   for the empty list, while the newer "analyze" accepts any list through its
   second clause. *)
411 in (c_case, (i, case pats |
407 fun analyze (cases as [let_case]) = |
412 of [pat] => (analyze_pats pats handle Bind => analyze_let pat) |
408 (analyze_cases cases handle Bind => analyze_let let_case) |
413 | _ :: _ => analyze_pats pats)) |
409 | analyze cases = analyze_cases cases; |
(* Both revisions turn Bind and TERM into the same user-level error message;
   the older one attaches the handlers to the let-expression itself, the newer
   one moves them into the separate case_cert wrapper below. *)
414 end handle Bind => error "bad case certificate" |
410 in (case_const, (n, analyze cases)) end; |
415 | TERM _ => error "bad case certificate"; |
411 |
|
412 fun case_cert thm = case_certificate thm |
|
413 handle Bind => error "bad case certificate" |
|
414 | TERM _ => error "bad case certificate"; |
416 |
415 |
417 end; |
416 end; |