(** Node colours for the coloured trees below, represented as booleans:
    [black] is [true] and [red] is [false]. *)
Definition color : Set := bool.
Definition black : color := true.
Definition red : color := false.

(** Binary trees. [Leaf] is the empty tree; [Node n c t0 t1] carries a
    [nat] value [n] (presumably a key), a colour [c], and two subtrees
    [t0] and [t1]. *)
Inductive tree : Set :=
| Leaf : tree
| Node : nat -> color -> tree -> tree -> tree.

(*
Check tree_ind.
Check (
forall P : tree -> Prop,
P Leaf ->
(forall (n : nat) (c : color) (t0 t1 : tree),
 P t0 -> P t1 -> P (Node n c t0 t1)) ->
forall t : tree, P t
).
*)

(** [count_nodes t] is the number of [Node] constructors in [t];
    every [Leaf] contributes zero. *)
Fixpoint count_nodes (t : tree) : nat :=
  match t with
  | Node _ _ l r => S (count_nodes l + count_nodes r)
  | Leaf => 0
  end.

(** [not_red_root t] holds when the root of [t] is not red, that is, [t]
    is a [Leaf] or a [Node] coloured [black]. *)
Inductive not_red_root : tree -> Prop :=
| not_red_root_leaf : not_red_root Leaf
| not_red_root_node : forall (n : nat) (t0 t1 : tree),
    not_red_root (Node n black t0 t1).

(** Computational counterpart of [not_red_root]: a [Leaf] satisfies it
    trivially, and a [Node] must carry the colour [black].

    This is a plain [Definition], not a [Fixpoint]: the function never
    calls itself, and newer Coq releases can warn about fixpoints that
    are not truly recursive. [simpl] still reduces it whenever the
    argument is a constructor, which is all the proofs below rely on. *)
Definition not_red_root' (t : tree) : Prop :=
  match t with
  | Leaf => True
  | Node _ c _ _ => c = black
  end.

(** The inductive predicate [not_red_root] and the function
    [not_red_root'] agree on every tree. *)
Lemma not_red_root_equiv :
  forall t, not_red_root t <-> not_red_root' t.
(* Case split on [t] (only the case split matters here) and on both
   directions of the iff; [auto] disposes of the trivial goals. *)
induction t; simpl; split; intros; subst; auto.
(* Leaf, right-to-left: the leaf constructor. *)
apply not_red_root_leaf.
(* Node, left-to-right: only [not_red_root_node] builds a node, so
   inverting [H] forces the colour to be [black]. *)
inversion H. auto.
(* Node, right-to-left: [subst] already replaced the colour by [black]. *)
apply not_red_root_node.
Qed.

(** Existential characterisation of a non-red root: either the tree is a
    [Leaf], or it is some [Node] whose colour is [black]. *)
Definition not_red_root'' (t : tree) : Prop :=
  t = Leaf \/ exists (n : nat) (t1 t2 : tree), t = Node n black t1 t2.

(** [not_red_root'] also agrees with the existential formulation
    [not_red_root'']. *)
Lemma not_red_root'_equiv :
  forall t, not_red_root' t <-> not_red_root'' t.
(* Leaf: both sides hold trivially and [tauto] closes the goal. *)
intros. unfold not_red_root''. destruct t; simpl. tauto.
(* Node: prove each direction separately. *)
split; intros.
(* Left-to-right: substitute the colour and reuse the fields of the
   node as the existential witnesses. *)
subst. right. exists n. exists t1. exists t2. auto.
(* Right-to-left: the Leaf disjunct is impossible for a node; in the
   existential case, unpack the witnesses and extract the colour
   equation with [injection]. *)
destruct H. discriminate H. destruct H. destruct H. destruct H.
injection H. auto.
Qed.

(** [okay t]: no red node in [t] has a red child. Every subtree must
    itself be [okay], and a red node additionally requires both children
    to satisfy [not_red_root]. No black-height condition is imposed here. *)
Inductive okay : tree -> Prop :=
| okay_leaf : okay Leaf
| okay_black : forall (n : nat) (t0 t1 : tree),
    okay t0 -> okay t1 ->
    okay (Node n black t0 t1)
| okay_red : forall (n : nat) (t0 t1 : tree),
    okay t0 -> okay t1 ->
    not_red_root t0 -> not_red_root t1 ->
    okay (Node n red t0 t1).

(** Variable identifiers are natural numbers. *)
Definition id : Set := nat.

(** Arithmetic expressions: numeric literals, variable reads, and
    subtraction. *)
Inductive aexp : Set :=
| ANum : nat -> aexp
| AId : id -> aexp
| AMinus : aexp -> aexp -> aexp.

(** Commands: skip, assignment, sequencing, conditionals and while-loops.
    Conditions are arithmetic expressions; the semantics [ceval] and
    [cstep] treat a nonzero value as true and [O] as false. *)
Inductive com : Set :=
| CSkip : com
| CAss : id -> aexp -> com
| CSeq : com -> com -> com
| CIf : aexp -> com -> com -> com
| CWhile : aexp -> com -> com.

(** A program state maps every identifier to a natural number. *)
Definition state : Set := id -> nat.

(** Big-step evaluation of the expression [a] in the state [s].
    Subtraction is natural-number subtraction, so it is truncated at 0. *)
Fixpoint aeval (a : aexp) (s : state) : nat :=
  match a with
  | AMinus lhs rhs => aeval lhs s - aeval rhs s
  | AId v => s v
  | ANum k => k
  end.

Require Import Peano_dec.

(** [update s x n] maps [x] to [n] and agrees with [s] on every other
    identifier. *)
Definition update (s : state) (x : id) (n : nat) : state :=
  fun y : id =>
    match eq_nat_dec x y with
    | left _ => n
    | right _ => s y
    end.

(** Big-step semantics of commands: [ceval c s s'] holds when executing
    [c] from state [s] finishes in state [s']. *)
Inductive ceval : com -> state -> state -> Prop :=
| E_Skip : forall s, ceval CSkip s s
(* Assignment stores the value of [a], evaluated in the old state. *)
| E_Ass : forall x a s n, aeval a s = n ->
    ceval (CAss x a) s (update s x n)
| E_Seq : forall c1 c2 s s' s'',
    ceval c1 s s' -> ceval c2 s' s'' ->
    ceval (CSeq c1 c2) s s''
(* A condition counts as true when it evaluates to a nonzero number. *)
| E_IfTrue : forall a c1 c2 s s',
    ~(aeval a s = O) -> ceval c1 s s' ->
    ceval (CIf a c1 c2) s s'
| E_IfFalse : forall a c1 c2 s s',
    aeval a s = O -> ceval c2 s s' ->
    ceval (CIf a c1 c2) s s'
(* One iteration: run the body, then the whole loop again. *)
| E_WhileLoop : forall a c s s' s'',
    ~(aeval a s = O) ->
    ceval c s s' -> ceval (CWhile a c) s' s'' ->
    ceval (CWhile a c) s s''
(* The loop stops, leaving the state unchanged, once the condition is 0. *)
| E_WhileEnd : forall a c s,
    aeval a s = O ->
    ceval (CWhile a c) s s.

(** Small-step reduction of arithmetic expressions, reading variables from
    the fixed state [s]. Subtraction is evaluated left to right. The left
    operand is reduced first, and the right operand only once the left is
    a literal. Two literals then reduce to their (truncated) difference. *)
Inductive astep : state -> aexp -> aexp -> Prop :=
| AS_Id : forall s x,
    astep s (AId x) (ANum (s x))
| AS_Minus1 : forall s a1 a1' a2,
    astep s a1 a1' ->
    astep s (AMinus a1 a2) (AMinus a1' a2)
| AS_Minus2 : forall s n a2 a2',
    astep s a2 a2' ->
    astep s (AMinus (ANum n) a2) (AMinus (ANum n) a2')
| AS_Minus : forall s n1 n2,
    astep s (AMinus (ANum n1) (ANum n2))
      (ANum (minus n1 n2)).

(** Small-step semantics of commands over (command, state) pairs.
    Assignments and conditionals first reduce their expression to a
    literal. A configuration whose command is [CSkip] has no step; it is
    the final configuration used by [csem_equiv]. *)
Inductive cstep : (com * state) -> (com * state) -> Prop :=
| CS_AssStep : forall x a a' s,
    astep s a a' ->
    cstep (CAss x a, s) (CAss x a', s)
| CS_Ass : forall x n s,
    cstep (CAss x (ANum n), s) (CSkip, update s x n)
| CS_SeqStep : forall c1 c1' c2 s s',
    cstep (c1, s) (c1', s') ->
    cstep (CSeq c1 c2, s) (CSeq c1' c2, s')
(* A finished first command is discarded from the sequence. *)
| CS_SeqFinish : forall c2 s,
    cstep (CSeq CSkip c2, s) (c2, s)
| CS_IfStep : forall a a' c1 c2 s,
    astep s a a' ->
    cstep (CIf a c1 c2, s) (CIf a' c1 c2, s)
| CS_IfTrue : forall n c1 c2 s,
    ~(n = O) ->
    cstep (CIf (ANum n) c1 c2, s) (c1, s)
| CS_IfFalse : forall c1 c2 s,
    cstep (CIf (ANum O) c1 c2, s) (c2, s)
(* A while-loop unfolds into a conditional whose true branch runs the
   body followed by the loop again. *)
| CS_While : forall a c s,
    cstep (CWhile a c, s)
      (CIf a (CSeq c (CWhile a c)) CSkip, s).

(** Reflexive-transitive closure of [R]: [star R x z] holds when [z] is
    reachable from [x] in zero or more [R]-steps. [star_step] prepends
    one step to an existing path. *)
Inductive star {X : Set} (R : X -> X -> Prop) : (X -> X -> Prop) :=
| star_refl : forall x : X, star R x x
| star_step : forall x y z : X,
    R x y -> star R y z -> star R x z.

(** A single [R]-step is a [star R] path. *)
Lemma star_single : forall  (X : Set) (R : X -> X -> Prop) x y,
  R x y -> star R x y.
intros. eapply star_step. apply H. apply star_refl.
Qed.

(** [star R] is transitive: two paths can be concatenated. *)
Lemma star_trans : forall (X : Set) (R : X -> X -> Prop) x y z,
  star R x y -> star R y z -> star R x z.
(* Induction on the first path. The empty path is immediate; otherwise
   keep its first step and append the rest via the induction hypothesis. *)
intros. induction H. auto.
eapply star_step. apply H. auto.
Qed.

(** A single small step does not change the value of an expression. *)
Lemma astep_aeval : forall s a a',
  astep s a a' -> aeval a s = aeval a' s.
(* Induction on [a], then invert the step; the resulting goals are
   closed by [simpl] and [auto]. *)
induction a; intros; inversion H; simpl; auto.
Qed.

(** Any number of small steps preserves the value of an expression. *)
Lemma star_astep_aeval : forall s a a',
  star (astep s) a a' -> aeval a s = aeval a' s.
(* Induction on the [star] derivation; each step is chained with the
   induction hypothesis through [astep_aeval] and [eq_trans]. *)
intros. induction H; auto.
eapply eq_trans. apply astep_aeval. apply H. auto.
Qed.

(** Soundness for expressions: if [a] reduces to the literal [n], then
    [aeval a s = n]. *)
Lemma asem_forward : forall a s n,
  star (astep s) a (ANum n) -> aeval a s = n.
(* Rewrite with [star_astep_aeval] applied to [H]; the value of the
   literal [ANum n] is then [n] by computation. *)
intros. erewrite star_astep_aeval. 2: apply H. auto.
Qed.

(** Congruence: reduction of the left operand lifts to [AMinus]. *)
Lemma aminus_cong_left : forall a1 a1' a2 s,
  star (astep s) a1 a1' ->
  star (astep s) (AMinus a1 a2) (AMinus a1' a2).
(* Induction on the [star] derivation, wrapping each step in [AS_Minus1]. *)
intros. induction H. apply star_refl.
eapply star_step. apply AS_Minus1. apply H. auto.
Qed.

(** Congruence: once the left operand is a literal, reduction of the
    right operand lifts to [AMinus]. *)
Lemma aminus_cong_right : forall n a2 a2' s,
  star (astep s) a2 a2' ->
  star (astep s) (AMinus (ANum n) a2) (AMinus (ANum n) a2').
(* Induction on the [star] derivation, wrapping each step in [AS_Minus2]. *)
intros. induction H. apply star_refl.
eapply star_step. apply AS_Minus2. apply H. auto.
Qed.

(** Completeness for expressions: every expression reduces, in zero or
    more steps, to the literal holding its value. *)
Lemma asem_backward : forall a s,
  star (astep s) a (ANum (aeval a s)).
induction a; simpl; intros.
(* ANum: already a literal. *)
apply star_refl.
(* AId: one [AS_Id] step. *)
eapply star_step. apply AS_Id. apply star_refl.
(* AMinus: reduce the left operand, then the right one, then subtract. *)
eapply star_trans. apply aminus_cong_left. apply IHa1.
eapply star_trans. apply aminus_cong_right. apply IHa2.
apply star_single. apply AS_Minus.
Qed.

(** Small-step reduction to a literal and big-step evaluation agree on
    expressions. *)
Lemma asem_equiv : forall a s n,
  star (astep s) a (ANum n) <-> aeval a s = n.
intros. split; intros; subst. apply asem_forward; auto.
apply asem_backward.
Qed.

Require Import Program.

(** One small step is compatible with the big-step semantics: if
    [(c, s)] steps to [(c', s')] and [c'] runs from [s'] to [s''], then
    [c] runs from [s] to [s''].

    NOTE(review): the script refers to hypotheses named by [inversion]
    ([H3], [H6]); such auto-generated names may change between Coq
    versions. *)
Lemma cstep_ceval : forall c c' s s' s'',
  cstep (c, s) (c', s') -> ceval c' s' s'' -> ceval c s s''.
(* Generalise over the final state so the induction hypothesis applies
   to any [s'']. [dependent induction] (from Program) is used because
   [cstep] is indexed by non-variable pairs. *)
intros. generalize s'' H0. clear s'' H0.
dependent induction H; intros.
(* CS_AssStep *)
inversion H0; subst. apply E_Ass; auto. apply astep_aeval; auto.
(* CS_Ass *)
inversion H0; subst. apply E_Ass; auto.
(* CS_SeqStep *)
inversion H0; subst. eapply E_Seq. apply IHcstep. apply H3. auto.
(* CS_SeqFinish *)
eapply E_Seq. apply E_Skip. auto.
(* CS_IfStep: the stepped conditional took either its true or false branch. *)
inversion H0; subst.
apply E_IfTrue; auto. erewrite astep_aeval. apply H6. auto.
apply E_IfFalse; auto. erewrite astep_aeval. apply H6. auto.
(* CS_IfTrue *)
apply E_IfTrue; auto.
(* CS_IfFalse *)
apply E_IfFalse; auto.
(* CS_While: invert the unfolded conditional, then the branch it took. *)
inversion H0; subst. inversion H6; subst. eapply E_WhileLoop; eauto.
inversion H6; subst. apply E_WhileEnd. auto.
Qed.

(** Lifting [cstep_ceval] to any number of small steps. *)
Lemma star_cstep_ceval : forall  c c' s s' s'',
  star cstep (c, s) (c', s') -> ceval c' s' s'' -> ceval c s s''.
(* Induction on the [star] derivation; the intermediate configuration
   [y] is split into its command and state before using [cstep_ceval]. *)
intros. dependent induction H; auto.
destruct y. eapply cstep_ceval. apply H. eapply IHstar; auto.
Qed.

(** Soundness for commands: if [(c, s)] reduces to [(CSkip, s')], then
    [c] runs from [s] to [s'] in the big-step semantics. *)
Lemma csem_forward : forall c s s',
  star cstep (c, s) (CSkip, s') -> ceval c s s'.
(* The final [CSkip] configuration is evaluated by [E_Skip].
   NOTE(review): the two [auto] calls look like no-ops in current Coq;
   confirm before removing them. *)
intros. eapply star_cstep_ceval. apply H. auto. auto. apply E_Skip.
Qed.

(** Congruence: reduction of the right-hand side lifts to an assignment. *)
Lemma ass_cong : forall s x a a',
  star (astep s) a a' ->
  star cstep (CAss x a, s) (CAss x a', s).
(* Induction on the [star] derivation, wrapping each step in [CS_AssStep]. *)
intros. induction H. apply star_refl.
eapply star_step. 2: apply IHstar. apply CS_AssStep. auto.
Qed.

(** Congruence: steps of the first command of a sequence lift to the
    whole sequence. *)
Lemma seq_cong_left : forall c1 c1' c2 s s',
  star cstep (c1, s) (c1', s') ->
  star cstep (CSeq c1 c2, s) (CSeq c1' c2, s').
(* Induction on the [star] derivation; each intermediate configuration is
   split into command and state, and each step is wrapped in [CS_SeqStep]. *)
intros. dependent induction H. apply star_refl.
destruct y. eapply star_step. apply CS_SeqStep. apply H.
apply IHstar; auto.
Qed.

(** Congruence: reduction of the condition lifts to a [CIf]. *)
Lemma if_cong : forall s a a' c1 c2,
  star (astep s) a a' ->
  star cstep (CIf a c1 c2, s) (CIf a' c1 c2, s).
(* Induction on the [star] derivation, wrapping each step in [CS_IfStep]. *)
intros. induction H. apply star_refl.
eapply star_step. 2: apply IHstar. apply CS_IfStep. auto.
Qed.

(** Completeness for commands: every big-step execution [ceval c s s']
    is matched by a small-step reduction from [(c, s)] to [(CSkip, s')]. *)
Lemma csem_backward : forall c s s',
  ceval c s s' -> star cstep (c, s) (CSkip, s').
(* Induction on the big-step derivation, one case per [ceval] constructor. *)
intros. dependent induction H.
(* E_Skip *)
apply star_refl.
(* E_Ass: reduce the expression to its value, then assign. *)
eapply star_trans. eapply ass_cong. apply asem_backward.
apply star_single. rewrite H. apply CS_Ass.
(* E_Seq: run the first command to CSkip, drop it, then run the second. *)
eapply star_trans. apply seq_cong_left. apply IHceval1.
eapply star_step. apply CS_SeqFinish. apply IHceval2.
(* E_IfTrue: reduce the condition to a literal, take the true branch. *)
eapply star_trans. eapply if_cong. apply asem_backward.
eapply star_step. apply CS_IfTrue; auto. auto.
(* E_IfFalse: reduce the condition to 0, take the false branch. *)
eapply star_trans. eapply if_cong. apply asem_backward.
eapply star_step. rewrite H. apply CS_IfFalse. auto.
(* E_WhileLoop: unfold, take the true branch, run one iteration of the
   body, then continue with the loop via the induction hypothesis. *)
eapply star_step. apply CS_While.
eapply star_trans. eapply if_cong. apply asem_backward.
eapply star_step. apply CS_IfTrue; auto.
eapply star_trans. apply seq_cong_left. apply IHceval1.
eapply star_step. apply CS_SeqFinish. auto.
(* E_WhileEnd: unfold, the condition reduces to 0, take the CSkip branch. *)
eapply star_step. apply CS_While.
eapply star_trans. eapply if_cong. apply asem_backward.
eapply star_step. rewrite H. apply CS_IfFalse. apply star_refl.
Qed.

(** Main result: big-step evaluation of a command coincides with
    small-step reduction to [CSkip]. *)
Lemma csem_equiv :
forall (c : com) (s s' : state),
ceval c s s' <-> star cstep (c, s) (CSkip, s').
intros. split. apply csem_backward. apply csem_forward.
Qed.

(*
Print Assumptions csem_equiv.
*)

