src/HOL/Tools/SMT/smt_datatypes.ML
changeset 57226 c22ad39c3b4b
parent 57213 9daec42f6784
     1.1 --- a/src/HOL/Tools/SMT/smt_datatypes.ML	Thu Jun 12 01:00:49 2014 +0200
     1.2 +++ b/src/HOL/Tools/SMT/smt_datatypes.ML	Thu Jun 12 01:00:49 2014 +0200
     1.3 @@ -17,56 +17,23 @@
     1.4  
     1.5  val lhs_head_of = Term.head_of o fst o Logic.dest_equals o Thm.prop_of
     1.6  
     1.7 -fun mk_selectors T Ts ctxt =
     1.8 -  let
     1.9 -    val (sels, ctxt') =
    1.10 -      Variable.variant_fixes (replicate (length Ts) "select") ctxt
    1.11 -  in (map2 (fn n => fn U => Free (n, T --> U)) sels Ts, ctxt') end
    1.12 +fun mk_selectors T Ts =
    1.13 +  Variable.variant_fixes (replicate (length Ts) "select")
    1.14 +  #>> map2 (fn U => fn n => Free (n, T --> U)) Ts
    1.15  
    1.16  
    1.17 -(* datatype declarations *)
    1.18 -
    1.19 -fun get_datatype_decl ({descr, ...} : Datatype.info) n Ts ctxt =
    1.20 -  let
    1.21 -    fun get_vars (_, (m, vs, _)) = if m = n then SOME vs else NONE
    1.22 -    val vars = the (get_first get_vars descr) ~~ Ts
    1.23 -    val lookup_var = the o AList.lookup (op =) vars
    1.24 -
    1.25 -    fun typ_of (dt as Datatype.DtTFree _) = lookup_var dt
    1.26 -      | typ_of (Datatype.DtType (m, dts)) = Type (m, map typ_of dts)
    1.27 -      | typ_of (Datatype.DtRec i) =
    1.28 -          the (AList.lookup (op =) descr i)
    1.29 -          |> (fn (m, dts, _) => Type (m, map typ_of dts))
    1.30 -
    1.31 -    fun mk_constr T (m, dts) ctxt =
    1.32 -      let
    1.33 -        val Ts = map typ_of dts
    1.34 -        val constr = Const (m, Ts ---> T)
    1.35 -        val (selects, ctxt') = mk_selectors T Ts ctxt
    1.36 -      in ((constr, selects), ctxt') end
    1.37 +(* free constructor type declarations *)
    1.38  
    1.39 -    fun mk_decl (i, (_, _, constrs)) ctxt =
    1.40 -      let
    1.41 -        val T = typ_of (Datatype.DtRec i)
    1.42 -        val (css, ctxt') = fold_map (mk_constr T) constrs ctxt
    1.43 -      in ((T, css), ctxt') end
    1.44 -
    1.45 -  in fold_map mk_decl descr ctxt end
    1.46 -
    1.47 -
    1.48 -(* record declarations *)
    1.49 -
    1.50 -val record_name_of = Long_Name.implode o fst o split_last o Long_Name.explode
    1.51 -
    1.52 -fun get_record_decl ({ext_def, ...} : Record.info) T ctxt =
    1.53 +fun get_ctr_sugar_decl ({ctrs, ...} : Ctr_Sugar.ctr_sugar) T Ts ctxt =
    1.54    let
    1.55 -    val (con, _) = Term.dest_Const (lhs_head_of ext_def)
    1.56 -    val (fields, more) = Record.get_extT_fields (Proof_Context.theory_of ctxt) T
    1.57 -    val fieldTs = map snd fields @ [snd more]
    1.58 -
    1.59 -    val constr = Const (con, fieldTs ---> T)
    1.60 -    val (selects, ctxt') = mk_selectors T fieldTs ctxt
    1.61 -  in ((T, [(constr, selects)]), ctxt') end
    1.62 +    fun mk_constr ctr0 =
    1.63 +      let val ctr = Ctr_Sugar.mk_ctr Ts ctr0 in
    1.64 +        mk_selectors T (binder_types (fastype_of ctr)) #>> pair ctr
    1.65 +      end
    1.66 +  in
    1.67 +    fold_map mk_constr ctrs ctxt
    1.68 +    |>> (pair T #> single)
    1.69 +  end
    1.70  
    1.71  
    1.72  (* typedef declarations *)
    1.73 @@ -91,18 +58,12 @@
    1.74  fun declared' dss T = exists (exists (equal T o fst) o snd) dss
    1.75  
    1.76  fun get_decls T n Ts ctxt =
    1.77 -  let val thy = Proof_Context.theory_of ctxt
    1.78 -  in
    1.79 -    (case Datatype.get_info thy n of
    1.80 -      SOME info => get_datatype_decl info n Ts ctxt
    1.81 -    | NONE =>
    1.82 -        (case Record.get_info thy (record_name_of n) of
    1.83 -          SOME info => get_record_decl info T ctxt |>> single
    1.84 -        | NONE =>
    1.85 -            (case Typedef.get_info ctxt n of
    1.86 -              [] => ([], ctxt)
    1.87 -            | info :: _ => (get_typedef_decl info T Ts, ctxt))))
    1.88 -  end
    1.89 +  (case Ctr_Sugar.ctr_sugar_of ctxt n of
    1.90 +    SOME ctr_sugar => get_ctr_sugar_decl ctr_sugar T Ts ctxt
    1.91 +  | NONE =>
    1.92 +      (case Typedef.get_info ctxt n of
    1.93 +        [] => ([], ctxt)
    1.94 +      | info :: _ => (get_typedef_decl info T Ts, ctxt)))
    1.95  
    1.96  fun add_decls T (declss, ctxt) =
    1.97    let
    1.98 @@ -132,5 +93,4 @@
    1.99              in fold add Us (ins dss, ctxt2) end))
   1.100    in add T ([], ctxt) |>> append declss o map snd end
   1.101  
   1.102 -
   1.103  end