new implementation for case_of_simps based on Code_Lazy's pattern matching elimination algorithm
authorAndreas Lochbihler
Tue, 01 Jan 2019 17:04:53 +0100
changeset 69568 de09a7261120
parent 69567 6b4c41037649
child 69569 2d88bf80c84f
new implementation for case_of_simps based on Code_Lazy's pattern matching elimination algorithm
CONTRIBUTORS
NEWS
src/HOL/Library/Simps_Case_Conv.thy
src/HOL/Library/case_converter.ML
src/HOL/Library/code_lazy.ML
src/HOL/Library/simps_case_conv.ML
src/HOL/ex/Simps_Case_Conv_Examples.thy
--- a/CONTRIBUTORS	Sun Dec 30 10:30:41 2018 +0100
+++ b/CONTRIBUTORS	Tue Jan 01 17:04:53 2019 +0100
@@ -6,6 +6,10 @@
 Contributions to this Isabelle version
 --------------------------------------
 
+* January 2019: Andreas Lochbihler
+  New implementation for case_of_simps based on Code_Lazy's
+  pattern matching elimination algorithm.
+
 * October 2018: Mathias Fleury
   Proof reconstruction for the SMT solver veriT in the smt method
 
--- a/NEWS	Sun Dec 30 10:30:41 2018 +0100
+++ b/NEWS	Tue Jan 01 17:04:53 2019 +0100
@@ -91,6 +91,13 @@
 
 * SMT: reconstruction is now possible using the SMT solver veriT.
 
+* HOL-Library.Simps_Case_Conv: case_of_simps now supports overlapping 
+and non-exhaustive patterns and handles arbitrarily nested patterns.
+It uses on the same algorithm as HOL-Library.Code_Lazy, which assumes
+sequential left-to-right pattern matching. The generated
+equation no longer tuples the arguments on the right-hand side.
+INCOMPATIBILITY.
+
 * Session HOL-SPARK: .prv files are no longer written to the
 file-system, but exported to the session database. Results may be
 retrieved with the "isabelle export" command-line tool like this:
--- a/src/HOL/Library/Simps_Case_Conv.thy	Sun Dec 30 10:30:41 2018 +0100
+++ b/src/HOL/Library/Simps_Case_Conv.thy	Tue Jan 01 17:04:53 2019 +0100
@@ -3,7 +3,7 @@
 *)
 
 theory Simps_Case_Conv
-imports Main
+imports Case_Converter
   keywords "simps_of_case" "case_of_simps" :: thy_decl
   abbrevs "simps_of_case" "case_of_simps" = ""
 begin
--- a/src/HOL/Library/case_converter.ML	Sun Dec 30 10:30:41 2018 +0100
+++ b/src/HOL/Library/case_converter.ML	Tue Jan 01 17:04:53 2019 +0100
@@ -3,8 +3,11 @@
 
 signature CASE_CONVERTER =
 sig
-  val to_case: Proof.context -> (string * string -> bool) -> (string * typ -> int) ->
+  type elimination_strategy
+  val to_case: Proof.context -> elimination_strategy -> (string * typ -> int) ->
     thm list -> thm list option
+  val replace_by_type: (Proof.context -> string * string -> bool) -> elimination_strategy
+  val keep_constructor_context: elimination_strategy
 end;
 
 structure Case_Converter : CASE_CONVERTER =
@@ -60,30 +63,72 @@
     Coordinate (merge_consts xs ys)
   end;
 
-fun term_to_coordinates P term = 
-  let
-    val (ctr, args) = strip_comb term
-  in
-    case ctr of Const (s, T) =>
-      if P (body_type T |> dest_Type |> fst, s)
-      then SOME (End (body_type T))
-      else
-        let
-          fun f (i, t) = term_to_coordinates P t |> Option.map (pair i)
-          val tcos = map_filter I (map_index f args)
-        in
-          if null tcos then NONE
-          else SOME (Coordinate (map (pair s) tcos))
-        end
-    | _ => NONE
-  end;
-
 fun coordinates_to_list (End x) = [(x, [])]
   | coordinates_to_list (Coordinate xs) = 
   let
     fun f (s, (n, xss)) = map (fn (T, xs) => (T, (s, n) :: xs)) (coordinates_to_list xss)
   in flat (map f xs) end;
 
+type elimination_strategy = Proof.context -> term list -> term_coordinate list
+
+fun replace_by_type replace_ctr ctxt pats =
+  let
+    fun term_to_coordinates P term = 
+      let
+        val (ctr, args) = strip_comb term
+      in
+        case ctr of Const (s, T) =>
+          if P (body_type T |> dest_Type |> fst, s)
+          then SOME (End (body_type T))
+          else
+            let
+              fun f (i, t) = term_to_coordinates P t |> Option.map (pair i)
+              val tcos = map_filter I (map_index f args)
+            in
+              if null tcos then NONE
+              else SOME (Coordinate (map (pair s) tcos))
+            end
+        | _ => NONE
+      end
+    in
+      map_filter (term_to_coordinates (replace_ctr ctxt)) pats
+    end
+
+fun keep_constructor_context ctxt pats =
+  let
+    fun to_coordinates [] = NONE
+      | to_coordinates pats =
+        let
+          val (fs, argss) = map strip_comb pats |> split_list
+          val f = hd fs
+          fun is_single_ctr (Const (name, T)) = 
+              let
+                val tyco = body_type T |> dest_Type |> fst
+                val _ = Ctr_Sugar.ctr_sugar_of ctxt tyco |> the |> #ctrs
+              in
+                case Ctr_Sugar.ctr_sugar_of ctxt tyco of
+                  NONE => error ("Not a free constructor " ^ name ^ " in pattern")
+                | SOME info =>
+                  case #ctrs info of [Const (name', _)] => name = name'
+                    | _ => false
+              end
+            | is_single_ctr _ = false
+        in 
+          if not (is_single_ctr f) andalso forall (fn x => f = x) fs then
+            let
+              val patss = Ctr_Sugar_Util.transpose argss
+              fun recurse (i, pats) = to_coordinates pats |> Option.map (pair i)
+              val coords = map_filter I (map_index recurse patss)
+            in
+              if null coords then NONE
+              else SOME (Coordinate (map (pair (dest_Const f |> fst)) coords))
+            end
+          else SOME (End (body_type (fastype_of f)))
+          end
+    in
+      the_list (to_coordinates pats)
+    end
+
 
 (* AL: TODO: change from term to const_name *)
 fun find_ctr ctr1 xs =
@@ -453,7 +498,8 @@
       ctxt1)
   end;
 
-fun build_case_t replace_ctr ctr_count head lhss rhss ctxt =
+
+fun build_case_t elimination_strategy ctr_count head lhss rhss ctxt =
   let
     val num_eqs = length lhss
     val _ = if length rhss = num_eqs andalso num_eqs > 0 then ()
@@ -464,16 +510,17 @@
     val _ = if forall (fn m => length m = n) lhss then ()
       else raise Fail "expected equal number of arguments"
 
-    fun to_coordinates (n, ts) = case map_filter (term_to_coordinates replace_ctr) ts of
-        [] => NONE
-      | (tco :: tcos) => SOME (n, fold term_coordinate_merge tcos tco |> coordinates_to_list)
+    fun to_coordinates (n, ts) = 
+      case elimination_strategy ctxt ts of
+          [] => NONE
+        | (tco :: tcos) => SOME (n, fold term_coordinate_merge tcos tco |> coordinates_to_list)
     fun add_T (n, xss) = map (fn (T, xs) => (T, (n, xs))) xss
     val (typ_list, poss) = lhss
       |> Ctr_Sugar_Util.transpose
       |> map_index to_coordinates
       |> map_filter (Option.map add_T)
       |> flat
-      |> split_list 
+      |> split_list
   in
     if null poss then ([], [], ctxt)
     else terms_to_case ctxt (dest_Const #> ctr_count) head lhss rhss typ_list poss
--- a/src/HOL/Library/code_lazy.ML	Sun Dec 30 10:30:41 2018 +0100
+++ b/src/HOL/Library/code_lazy.ML	Tue Jan 01 17:04:53 2019 +0100
@@ -530,18 +530,23 @@
 fun transform_code_eqs _ [] = NONE
   | transform_code_eqs ctxt eqs =
     let
+      fun replace_ctr ctxt =
+        let 
+          val thy = Proof_Context.theory_of ctxt
+          val table = Laziness_Data.get thy
+        in fn (s1, s2) => case Symtab.lookup table s1 of
+            NONE => false
+          | SOME x => #active x andalso s2 <> (#ctr x |> dest_Const |> fst)
+        end
       val thy = Proof_Context.theory_of ctxt
       val table = Laziness_Data.get thy
-      fun eliminate (s1, s2) = case Symtab.lookup table s1 of
-          NONE => false
-        | SOME x => #active x andalso s2 <> (#ctr x |> dest_Const |> fst)
       fun num_consts_fun (_, T) =
         let
           val s = body_type T |> dest_Type |> fst
         in
           if Symtab.defined table s
-            then Ctr_Sugar.ctr_sugar_of ctxt s |> the |> #ctrs |> length
-            else Code.get_type thy s |> fst |> snd |> length
+          then Ctr_Sugar.ctr_sugar_of ctxt s |> the |> #ctrs |> length
+          else Code.get_type thy s |> fst |> snd |> length
         end
       val eqs = map (apfst (Thm.transfer thy)) eqs;
 
@@ -554,10 +559,10 @@
         handle THM _ => (([], eqs), false)
       val to_original_eq = if pure then map (apfst (fn x => x RS @{thm eq_reflection})) else I
     in
-      case Case_Converter.to_case ctxt eliminate num_consts_fun (map fst code_eqs) of
+      case Case_Converter.to_case ctxt (Case_Converter.replace_by_type replace_ctr) num_consts_fun (map fst code_eqs) of
           NONE => NONE
         | SOME thms => SOME (nbe_eqs @ map (rpair true) thms |> to_original_eq)
-    end handle THM ex => (Output.writeln (@{make_string} eqs); raise THM ex);
+    end
 
 val activate_lazy_type = set_active_lazy_type true;
 val deactivate_lazy_type = set_active_lazy_type false;
--- a/src/HOL/Library/simps_case_conv.ML	Sun Dec 30 10:30:41 2018 +0100
+++ b/src/HOL/Library/simps_case_conv.ML	Tue Jan 01 17:04:53 2019 +0100
@@ -31,88 +31,6 @@
 
 val strip_eq = Thm.prop_of #> HOLogic.dest_Trueprop #> HOLogic.dest_eq
 
-
-local
-
-  fun transpose [] = []
-    | transpose ([] :: xss) = transpose xss
-    | transpose xss = map hd xss :: transpose (map tl xss);
-
-  fun same_fun single_ctrs (ts as _ $ _ :: _) =
-      let
-        val (fs, argss) = map strip_comb ts |> split_list
-        val f = hd fs
-        fun is_single_ctr (Const (name, _)) = member (op =) single_ctrs name
-          | is_single_ctr _ = false
-      in if not (is_single_ctr f) andalso forall (fn x => f = x) fs then SOME (f, argss) else NONE end
-    | same_fun _ _ = NONE
-
-  (* pats must be non-empty *)
-  fun split_pat single_ctrs pats ctxt =
-      case same_fun single_ctrs pats of
-        NONE =>
-          let
-            val (name, ctxt') = yield_singleton Variable.variant_fixes "x" ctxt
-            val var = Free (name, fastype_of (hd pats))
-          in (((var, [var]), map single pats), ctxt') end
-      | SOME (f, argss) =>
-          let
-            val (((def_pats, def_frees), case_patss), ctxt') =
-              split_pats single_ctrs argss ctxt
-            val def_pat = list_comb (f, def_pats)
-          in (((def_pat, flat def_frees), case_patss), ctxt') end
-  and
-      split_pats single_ctrs patss ctxt =
-        let
-          val (splitted, ctxt') = fold_map (split_pat single_ctrs) (transpose patss) ctxt
-          val r = splitted |> split_list |> apfst split_list |> apsnd (transpose #> map flat)
-        in (r, ctxt') end
-
-(*
-  Takes a list lhss of left hand sides (which are lists of patterns)
-  and a list rhss of right hand sides. Returns
-    - a single equation with a (nested) case-expression on the rhs
-    - a list of all split-thms needed to split the rhs
-  Patterns which have the same outer context in all lhss remain
-  on the lhs of the computed equation.
-*)
-fun build_case_t fun_t lhss rhss ctxt =
-  let
-    val single_ctrs =
-      get_type_infos ctxt (map fastype_of (flat lhss))
-      |> map_filter (fn ti => case #ctrs ti of [Const (name, _)] => SOME name | _ => NONE)
-    val (((def_pats, def_frees), case_patss), ctxt') =
-      split_pats single_ctrs lhss ctxt
-    val pattern = map HOLogic.mk_tuple case_patss
-    val case_arg = HOLogic.mk_tuple (flat def_frees)
-    val cases = Case_Translation.make_case ctxt' Case_Translation.Warning Name.context
-      case_arg (pattern ~~ rhss)
-    val split_thms = get_split_ths ctxt' [fastype_of case_arg]
-    val t = (list_comb (fun_t, def_pats), cases)
-      |> HOLogic.mk_eq
-      |> HOLogic.mk_Trueprop
-  in ((t, split_thms), ctxt') end
-
-fun tac ctxt {splits, intros, defs} =
-  let val ctxt' = Classical.addSIs (ctxt, intros) in
-    REPEAT_DETERM1 (FIRSTGOAL (split_tac ctxt splits))
-    THEN Local_Defs.unfold_tac ctxt defs
-    THEN safe_tac ctxt'
-  end
-
-fun import [] ctxt = ([], ctxt)
-  | import (thm :: thms) ctxt =
-    let
-      val fun_ct = strip_eq #> fst #> strip_comb #> fst #> Logic.mk_term
-        #> Thm.cterm_of ctxt
-      val ct = fun_ct thm
-      val cts = map fun_ct thms
-      val pairs = map (fn s => (s,ct)) cts
-      val thms' = map (fn (th,p) => Thm.instantiate (Thm.match p) th) (thms ~~ pairs)
-    in Variable.import true (thm :: thms') ctxt |> apfst snd end
-
-in
-
 (*
   For a list
     f p_11 ... p_1n = t1
@@ -122,39 +40,24 @@
   of theorems, prove a single theorem
     f x1 ... xn = t
   where t is a (nested) case expression. f must not be a function
-  application. Moreover, the terms p_11, ..., p_mn must be non-overlapping
-  datatype patterns. The patterns must be exhausting up to common constructor
-  contexts.
+  application.
 *)
 fun to_case ctxt ths =
   let
-    val (iths, ctxt') = import ths ctxt
-    val fun_t = hd iths |> strip_eq |> fst |> head_of
-    val eqs = map (strip_eq #> apfst (snd o strip_comb)) iths
-
-    fun hide_rhs ((pat, rhs), name) lthy =
+    fun ctr_count (ctr, T) = 
       let
-        val frees = fold Term.add_frees pat []
-        val abs_rhs = fold absfree frees rhs
-        val ([(f, (_, def))], lthy') = lthy
-          |> Local_Defs.define [((Binding.name name, NoSyn), (Binding.empty_atts, abs_rhs))]
-      in ((list_comb (f, map Free (rev frees)), def), lthy') end
-
-    val ((def_ts, def_thms), ctxt2) =
-      let val names = Name.invent (Variable.names_of ctxt') "rhs" (length eqs)
-      in fold_map hide_rhs (eqs ~~ names) ctxt' |> apfst split_list end
-
-    val ((t, split_thms), ctxt3) = build_case_t fun_t (map fst eqs) def_ts ctxt2
-
-    val th = Goal.prove ctxt3 [] [] t (fn {context=ctxt, ...} =>
-          tac ctxt {splits=split_thms, intros=ths, defs=def_thms})
-  in th
-    |> singleton (Proof_Context.export ctxt3 ctxt)
-    |> Goal.norm_result ctxt
+        val tyco = body_type T |> dest_Type |> fst
+        val info = Ctr_Sugar.ctr_sugar_of ctxt tyco
+        val _ = if is_none info then error ("Pattern match on non-constructor constant " ^ ctr) else ()
+      in
+        info |> the |> #ctrs |> length
+      end
+    val thms = Case_Converter.to_case ctxt Case_Converter.keep_constructor_context ctr_count ths
+  in
+    case thms of SOME thms => hd thms
+      | _ => error ("Conversion to case expression failed.")
   end
 
-end
-
 local
 
 fun was_split t =
--- a/src/HOL/ex/Simps_Case_Conv_Examples.thy	Sun Dec 30 10:30:41 2018 +0100
+++ b/src/HOL/ex/Simps_Case_Conv_Examples.thy	Tue Jan 01 17:04:53 2019 +0100
@@ -30,27 +30,23 @@
 case_of_simps foo_cases1: foo.simps
 lemma
   fixes xs :: "'a list" and ys :: "'b list"
-  shows "foo xs ys = (case (xs, ys) of
-    ( [], []) \<Rightarrow> 3
-    | ([], y # ys) \<Rightarrow> 1
-    | (x # xs, []) \<Rightarrow> 0
-    | (x # xs, y # ys) \<Rightarrow> foo ([] :: 'a list) ([] :: 'b list))"
+  shows "foo xs ys = 
+   (case xs of [] \<Rightarrow> (case ys of [] \<Rightarrow> 3 | _ # _ \<Rightarrow> 1)
+    | _ # _ \<Rightarrow> (case ys of [] \<Rightarrow> 0 | _ # _ \<Rightarrow> foo ([] :: 'a list) ([] :: 'b list)))"
   by (fact foo_cases1)
 
 text \<open>Redundant equations are ignored\<close>
 case_of_simps foo_cases2: foo.simps foo.simps
 lemma
   fixes xs :: "'a list" and ys :: "'b list"
-  shows "foo xs ys = (case (xs, ys) of
-    ( [], []) \<Rightarrow> 3
-    | ([], y # ys) \<Rightarrow> 1
-    | (x # xs, []) \<Rightarrow> 0
-    | (x # xs, y # ys) \<Rightarrow> foo ([] :: 'a list) ([] :: 'b list))"
+  shows "foo xs ys = 
+   (case xs of [] \<Rightarrow> (case ys of [] \<Rightarrow> 3 | _ # _ \<Rightarrow> 1)
+    | _ # _ \<Rightarrow> (case ys of [] \<Rightarrow> 0 | _ # _ \<Rightarrow> foo ([] :: 'a list) ([] :: 'b list)))"
   by (fact foo_cases2)
 
 text \<open>Variable patterns\<close>
 case_of_simps bar_cases: bar.simps
-print_theorems
+lemma "bar x n y = (case n of 0 \<Rightarrow> 0 + x | Suc n' \<Rightarrow> n' + x)" by(fact bar_cases)
 
 text \<open>Case expression not at top level\<close>
 simps_of_case split_rule_test_simps: split_rule_test_def
@@ -96,24 +92,23 @@
 text \<open>Reversal\<close>
 case_of_simps test_def1: test_simps1
 lemma
-  "test x y = (case (x,y) of
-    (None, []) \<Rightarrow> 1
-  | (None, _#_) \<Rightarrow> 2
-  | (Some x, _) \<Rightarrow> x)"
+  "test x y = 
+   (case x of None \<Rightarrow> (case y of [] \<Rightarrow> 1 | _ # _ \<Rightarrow> 2)
+    | Some x' \<Rightarrow> x')"
   by (fact test_def1)
 
 text \<open>Case expressions on RHS\<close>
 case_of_simps test_def2: test_simps2
-lemma "test xs y = (case (xs, y) of (None, []) \<Rightarrow> 1 | (None, x # xa) \<Rightarrow> 2 | (Some x, y) \<Rightarrow> x)"
+lemma "test x y = 
+  (case x of None \<Rightarrow> (case y of [] \<Rightarrow> 1 | _ # _ \<Rightarrow> 2)
+   | Some x' \<Rightarrow> x')"
   by (fact test_def2)
 
 text \<open>Partial split of simps\<close>
 case_of_simps foo_cons_def: foo.simps(1,2)
 lemma
   fixes xs :: "'a list" and ys :: "'b list"
-  shows "foo (x # xs) ys = (case (x,xs,ys) of
-      (_,_,[]) \<Rightarrow> 0
-    | (_,_,_ # _) \<Rightarrow> foo ([] :: 'a list) ([] :: 'b list))"
+  shows "foo (x # xs) ys = (case ys of [] \<Rightarrow> 0 | _ # _ \<Rightarrow> foo ([] :: 'a list) ([] :: 'b list))"
   by (fact foo_cons_def)
 
 end