src/HOL/Tools/numeral.ML
changeset 47108 2a1953f0d20d
parent 46497 89ccf66aa73d
child 48072 ace701efe203
equal deleted inserted replaced
47107:35807a5d8dc2 47108:2a1953f0d20d
    14 structure Numeral: NUMERAL =
    14 structure Numeral: NUMERAL =
    15 struct
    15 struct
    16 
    16 
    17 (* numeral *)
    17 (* numeral *)
    18 
    18 
    19 fun mk_cbit 0 = @{cterm "Int.Bit0"}
    19 fun mk_cbit 0 = @{cterm "Num.Bit0"}
    20   | mk_cbit 1 = @{cterm "Int.Bit1"}
    20   | mk_cbit 1 = @{cterm "Num.Bit1"}
    21   | mk_cbit _ = raise CTERM ("mk_cbit", []);
    21   | mk_cbit _ = raise CTERM ("mk_cbit", []);
    22 
    22 
    23 fun mk_cnumeral 0 = @{cterm "Int.Pls"}
    23 fun mk_cnumeral i =
    24   | mk_cnumeral ~1 = @{cterm "Int.Min"}
    24   let
    25   | mk_cnumeral i =
    25     fun mk 1 = @{cterm "Num.One"}
       
    26       | mk i =
    26       let val (q, r) = Integer.div_mod i 2 in
    27       let val (q, r) = Integer.div_mod i 2 in
    27         Thm.apply (mk_cbit r) (mk_cnumeral q)
    28         Thm.apply (mk_cbit r) (mk q)
    28       end;
    29       end
       
    30   in
       
    31     if i > 0 then mk i else raise CTERM ("mk_cnumeral: negative input", [])
       
    32   end
    29 
    33 
    30 
    34 
    31 (* number *)
    35 (* number *)
    32 
    36 
    33 local
    37 local
    36 val zeroT = Thm.ctyp_of_term zero;
    40 val zeroT = Thm.ctyp_of_term zero;
    37 
    41 
    38 val one = @{cpat "1"};
    42 val one = @{cpat "1"};
    39 val oneT = Thm.ctyp_of_term one;
    43 val oneT = Thm.ctyp_of_term one;
    40 
    44 
    41 val number_of = @{cpat "number_of"};
    45 val numeral = @{cpat "numeral"};
    42 val numberT = Thm.ctyp_of @{theory} (Term.range_type (Thm.typ_of (Thm.ctyp_of_term number_of)));
    46 val numeralT = Thm.ctyp_of @{theory} (Term.range_type (Thm.typ_of (Thm.ctyp_of_term numeral)));
       
    47 
       
    48 val neg_numeral = @{cpat "neg_numeral"};
       
    49 val neg_numeralT = Thm.ctyp_of @{theory} (Term.range_type (Thm.typ_of (Thm.ctyp_of_term neg_numeral)));
    43 
    50 
    44 fun instT T V = Thm.instantiate_cterm ([(V, T)], []);
    51 fun instT T V = Thm.instantiate_cterm ([(V, T)], []);
    45 
    52 
    46 in
    53 in
    47 
    54 
    48 fun mk_cnumber T 0 = instT T zeroT zero
    55 fun mk_cnumber T 0 = instT T zeroT zero
    49   | mk_cnumber T 1 = instT T oneT one
    56   | mk_cnumber T 1 = instT T oneT one
    50   | mk_cnumber T i = Thm.apply (instT T numberT number_of) (mk_cnumeral i);
    57   | mk_cnumber T i =
       
    58     if i > 0 then Thm.apply (instT T numeralT numeral) (mk_cnumeral i)
       
    59     else Thm.apply (instT T neg_numeralT neg_numeral) (mk_cnumeral (~i));
    51 
    60 
    52 end;
    61 end;
    53 
    62 
    54 
    63 
    55 (* code generator *)
    64 (* code generator *)
    56 
    65 
    57 local open Basic_Code_Thingol in
    66 local open Basic_Code_Thingol in
    58 
    67 
    59 fun add_code number_of negative print target thy =
    68 fun add_code number_of negative print target thy =
    60   let
    69   let
    61     fun dest_numeral pls' min' bit0' bit1' thm =
    70     fun dest_numeral one' bit0' bit1' thm t =
    62       let
    71       let
    63         fun dest_bit (IConst (c, _)) = if c = bit0' then 0
    72         fun dest_bit (IConst (c, _)) = if c = bit0' then 0
    64               else if c = bit1' then 1
    73               else if c = bit1' then 1
    65               else Code_Printer.eqn_error thm "Illegal numeral expression: illegal bit"
    74               else Code_Printer.eqn_error thm "Illegal numeral expression: illegal bit"
    66           | dest_bit _ = Code_Printer.eqn_error thm "Illegal numeral expression: illegal bit";
    75           | dest_bit _ = Code_Printer.eqn_error thm "Illegal numeral expression: illegal bit";
    67         fun dest_num (IConst (c, _)) = if c = pls' then SOME 0
    76         fun dest_num (IConst (c, _)) = if c = one' then 1
    68               else if c = min' then
       
    69                 if negative then SOME ~1 else NONE
       
    70               else Code_Printer.eqn_error thm "Illegal numeral expression: illegal leading digit"
    77               else Code_Printer.eqn_error thm "Illegal numeral expression: illegal leading digit"
    71           | dest_num (t1 `$ t2) =
    78           | dest_num (t1 `$ t2) = 2 * dest_num t2 + dest_bit t1
    72               let val (n, b) = (dest_num t2, dest_bit t1)
       
    73               in case n of SOME n => SOME (2 * n + b) | NONE => NONE end
       
    74           | dest_num _ = Code_Printer.eqn_error thm "Illegal numeral expression: illegal term";
    79           | dest_num _ = Code_Printer.eqn_error thm "Illegal numeral expression: illegal term";
    75       in dest_num end;
    80       in if negative then ~ (dest_num t) else dest_num t end;
    76     fun pretty literals [pls', min', bit0', bit1'] _ thm _ _ [(t, _)] =
    81     fun pretty literals [one', bit0', bit1'] _ thm _ _ [(t, _)] =
    77       (Code_Printer.str o print literals o the_default 0 o dest_numeral pls' min' bit0' bit1' thm) t;
    82       (Code_Printer.str o print literals o dest_numeral one' bit0' bit1' thm) t;
    78   in
    83   in
    79     thy |> Code_Target.add_const_syntax target number_of
    84     thy |> Code_Target.add_const_syntax target number_of
    80       (SOME (Code_Printer.complex_const_syntax (1, ([@{const_name Int.Pls}, @{const_name Int.Min},
    85       (SOME (Code_Printer.complex_const_syntax (1, ([@{const_name Num.One},
    81         @{const_name Int.Bit0}, @{const_name Int.Bit1}], pretty))))
    86         @{const_name Num.Bit0}, @{const_name Num.Bit1}], pretty))))
    82   end;
    87   end;
    83 
    88 
    84 end; (*local*)
    89 end; (*local*)
    85 
    90 
    86 end;
    91 end;