`    1:- module(ccp_mcmc, [mc_evidence/4, mh_machine/4, gibbs_machine/5]).`

/** <module> Gibbs and Metropolis-Hastings explanation samplers */

```    5:- use_module(library(insist)).    6:- use_module(library(callutils),   [(*)/4]).    7:- use_module(library(listutils),   [enumerate/2]).    8:- use_module(library(math),        [neg/2, add/3, sub/3, exp/2, map_sum/4]).    9:- use_module(library(data/pair),   [is_pair/1, pair/3, fst/2, fsnd/3, snd/2]).   10:- use_module(library(plrand),      [log_partition_dirichlet/2]).   11:- use_module(library(machines),    [unfold/2, unfolder/3, mapper/3, scan0/4, (:>)/3, mean/2, op(600,yfx,:>)]).   12
13:- use_module(effects,    [dist/2, uniform/2]).   14:- use_module(learn,      [converge/5, learn/4]).   15:- use_module(switches,   [ map_sum_sw/3, map_sum_sw/4, map_swc/4
16                          , sw_expectations/2, sw_log_prob/3, sw_posteriors/3, sw_samples/2
17                          ]).   18:- use_module(graph,      [ top_goal/1, top_value/2, tree_stats/2, sw_trees_stats/3, graph_fold/4
19                          , graph_inside/3, prune_graph/4, igraph_sample_tree/3
20                          ]).   21
%! bernoulli(+P1, -X)
%  Draws X from the two outcomes 0 and 1 through dist/2, attaching
%  weight P1 to outcome 1 and weight 1-P1 to outcome 0.
bernoulli(P1, X) :-
   P0 is 1-P1,
   dist([P0-0, P1-1], X).
23
%! mc_evidence(+Method, +Graph, +Prior, -Stream)
%  Unfolds a sampler pipeline into Stream.
%
%  Steps, in order:
%    1. converge/5 is run on learn(vb(Prior), io(log), Graph) with
%       tolerance rel(1e-6), starting from Prior, yielding VBPost.
%    2. sw_expectations/2 turns VBPost into VBProbs.
%    3. method_machine_mapper/4 picks the machine constructor and the
%       mapper closure for Method (gibbs or mh).
%    4. The machine, started from VBProbs, is piped through
%       mapper(p_params_given_post(VBProbs)*Mapper) and then `mean`,
%       and unfold/2 produces Stream from that pipeline.
%
%  NOTE(review): the text of this clause was cut off after `:> mean`.
%  The closing `, Stream)` is restored because unfold/2 is imported with
%  arity 2 and Stream was otherwise never bound; confirm against the
%  upstream source that no further pipeline stage was lost.
mc_evidence(Method, Graph, Prior, Stream) :-
   converge(rel(1e-6), learn(vb(Prior), io(log), Graph), _, Prior, VBPost),
   sw_expectations(VBPost, VBProbs),
   method_machine_mapper(Method, Prior, Machine, Mapper),
   unfold(call(Machine, Graph, Prior, VBProbs)
          :> mapper(p_params_given_post(VBProbs)*Mapper) :> mean,
          Stream).
33
%! p_params_given_post(+Probs, +Post, -P)
%  P is the exponential of the log-probability that sw_log_prob/3
%  reports for Post and Probs.
p_params_given_post(Probs, Post, P) :-
   sw_log_prob(Post, Probs, LogP),
   P is exp(LogP).
35
%! method_machine_mapper(?Method, ?Prior, ?Machine, ?Mapper)
%  Table pairing each sampling method with the closure that builds its
%  machine and the closure applied to each machine state.
%    - gibbs: gibbs_machine(posterior), states passed through with (=).
%    - mh:    mh_machine, states mapped by sw_posteriors(Prior)*mcs_counts.
method_machine_mapper(gibbs, _Prior, Machine, Mapper) :-
   Machine = ccp_mcmc:gibbs_machine(posterior),
   Mapper  = (=).
method_machine_mapper(mh, Prior, Machine, Mapper) :-
   Machine = ccp_mcmc:mh_machine,
   Mapper  = ccp_mcmc:sw_posteriors(Prior)*mcs_counts.
38
%! gibbs_machine(+Rot, +Graph, +Prior, +P1, -M)
%  Builds machine M, an unfolder over scan0(Step) started from P1.
%  Step is the composition chosen by rotation/5 for Rot out of three
%  closures: sw_posteriors(Prior), gstep over the inside graph of Graph,
%  and sw_samples.
gibbs_machine(Rot, Graph, Prior, P1, M) :-
   graph_inside(Graph, InsideParams, InsideGraph),
   rotation(Rot,
            sw_posteriors(Prior),
            gstep(InsideParams, InsideGraph),
            sw_samples,
            Step),
   unfolder(scan0(Step), P1, M).
43
%! rotation(+Which, :Post, :Step, :Sample, -Cycle)
%  Cycle is the three closures joined with `*` in one of three cyclic
%  orders, selected by Which (posterior, counts or params).
%  NOTE(review): which closure runs first depends on the argument order
%  of callutils (*)/4 -- presumably the rightmost; confirm there.
:- meta_predicate rotation(+,2,2,2,-).
rotation(posterior, Post, Step, Sample, Cycle) :- Cycle = Post*Step*Sample.
rotation(counts,    Post, Step, Sample, Cycle) :- Cycle = Step*Sample*Post.
rotation(params,    Post, Step, Sample, Cycle) :- Cycle = Sample*Post*Step.
48
%! gstep(+P0, +IG, +P1, -Counts)
%  One tree-sampling step. A fresh copy of the pair P0-IG is made and
%  its parameter part is unified with P1, so the copied graph IGCopy
%  carries the values in P1. A tree for the top goal is then sampled
%  from IGCopy and summarised by tree_stats/2 into Counts.
gstep(P0, IG, P1, Counts) :-
   copy_term(P0-IG, P1-IGCopy),
   top_goal(TopGoal),
   igraph_sample_tree(IGCopy, TopGoal, SampledTrees),
   tree_stats(TopGoal-SampledTrees, Counts).
54
%! mh_machine(+Graph, +Prior, +Probs0, -M)
%  Builds machine M over a Monte Carlo state initialised from Graph.
%
%  Graph is first normalised by graph_as_conjunction/2. Initial trees
%  come from graph_fold(best, Probs0) followed by top_value and snd.
%  The switch list is the first components of Prior.
%  If mcs_init/4 yields no keys, the machine repeats the initial state
%  (scan0(=)). Otherwise each step is mc_step/7, using the gibbs stepper
%  when mcs_unit_counts/1 holds for the initial state and mh otherwise.
mh_machine(Graph, Prior, Probs0, M) :-
   graph_as_conjunction(Graph, ConjGraph),
   call(snd * top_value * graph_fold(best, Probs0), ConjGraph, InitTrees),
   maplist(fst, Prior, Switches),
   mcs_init(Switches, InitTrees, Keys, InitState),
   (  Keys = []
   -> unfolder(scan0(=), InitState, M)
   ;  make_tree_sampler(ConjGraph, Sampler),
      (  mcs_unit_counts(InitState)
      -> Kind = gibbs
      ;  Kind = mh
      ),
      unfolder(scan0(mc_step(Kind, Keys, Sampler, Switches, Prior)),
               InitState, M)
   ).
65
%! graph_as_conjunction(+Graph, -ConjGraph)
%  If top_value/2 gives a one-element list for Graph, Graph is returned
%  unchanged. Otherwise the Top-Expls entry is removed from Graph and
%  two entries are put in front of the rest: Top with the single
%  explanation ['^mcmc':dummy], and '^mcmc':dummy carrying the
%  explanations Expls that Top had.
graph_as_conjunction(Graph, Graph) :-
   top_value(Graph, [_]),
   !.
graph_as_conjunction(Graph, [Top-[[Dummy]], Dummy-Expls | Others]) :-
   top_goal(Top),
   Dummy = '^mcmc':dummy,
   select(Top-Expls, Graph, Others).
70
%! mc_sample(+SampleGoal, +SWs, +Probs, +T1, -T2)
%  Resamples the tree of the goal held in T1: the goal is read with
%  mct_goal/2, SampleGoal is called with Probs and that goal to draw a
%  new tree, and mct_make/4 packages the result as T2.
mc_sample(SampleGoal, SWs, Probs, T1, T2) :-
   mct_goal(T1, Goal),
   call(SampleGoal, Probs, Goal, NewTree),
   mct_make(SWs, Goal, NewTree, T2).
74
%! make_tree_sampler(+G, -Sampler)
%  Sampler is the closure ccp_mcmc:sample_goal(IGs), where IGs holds one
%  sub_igraph/3 entry per distinct factor of the top goal's single
%  explanation (sort/2 removes duplicates).
make_tree_sampler(G, ccp_mcmc:sample_goal(IGs)) :-
   top_value(G, [TopFactors]),
   sort(TopFactors, DistinctFactors),
   maplist(sub_igraph(G), DistinctFactors, IGs).
79
%! sub_igraph(+G, +Goal, -Entry)
%  Entry is Goal-(IG-Ps): G is pruned for Goal with prune_graph/4, and
%  graph_inside/3 on the pruned graph gives IG and its parameters Ps.
sub_igraph(G, Goal, Goal-(IG-Ps)) :-
   prune_graph(=, Goal, G, Pruned),
   graph_inside(Pruned, Ps, IG).
83
%! sample_goal(+IGs, +PP, +Goal, -Trees)
%  Samples Trees for Goal from its stored graph. The Goal entry is found
%  in IGs, the stored graph and parameters are copied so the stored
%  ones stay unbound, the copied parameters are bound from PP by
%  param_subset/2, and igraph_sample_tree/3 draws from the copy.
sample_goal(IGs, PP, Goal, Trees) :-
   % TODO: use rbtree for faster lookup
   memberchk(Goal-(StoredIG-StoredPs), IGs),
   copy_term(StoredPs-StoredIG, Ps-IG),
   param_subset(Ps, PP),
   igraph_sample_tree(IG, Goal, Trees).
89
%! param_subset(?Sub, +Super)
%  Sub and Super are lists of Key-Value pairs. Each key of Sub is looked
%  up in Super by scanning forward with compare/3, and its value is
%  unified with the value found. Super entries whose key orders before
%  the wanted key are skipped. The predicate fails if a Sub key orders
%  before the current Super key or Super runs out.
param_subset([], _).
param_subset([Key-Val|SubRest], [SKey-SVal|SuperRest]) :-
   compare(Order, Key, SKey),
   psub_aux(Order, Key, Val, SVal, SubRest, SuperRest).

%  psub_aux(+Order, +Key, ?Val, ?SVal, +SubRest, +SuperRest)
%  Dispatch on the comparison of Key with the current Super key.
%    >  the Super entry is skipped; the search continues in SuperRest.
%    =  the values are unified and both lists advance.
%  There is no clause for <, so that case fails.
psub_aux(>, Key, Val, _, SubRest, SuperRest) :-
   param_subset([Key-Val|SubRest], SuperRest).
psub_aux(=, _, Val, Val, SubRest, SuperRest) :-
   param_subset(SubRest, SuperRest).
100
%! mc_step(+Stepper, +Keys, +SampleGoal, +SWs, +Prior, +State1, -State2)
%  One transition of the Monte Carlo state.
%
%  Both steppers begin with the same proposal (see mc_propose/6): a
%  randomly selected tree is taken out of the state and a replacement is
%  sampled using the expectations of the posterior built from Prior and
%  the remaining counts.
%    - gibbs: the proposed tree is always put back into the state.
%    - mh:    the proposal is accepted when the difference D of
%             tree_acceptance_weight/4 values (proposed minus old) is
%             at least -1e-13, otherwise with probability exp(D) drawn
%             through bernoulli/2. On rejection State2 is State1.
mc_step(gibbs, Keys, SampleGoal, SWs, Prior, State1, State2) :-
   mc_propose(Keys, SampleGoal, SWs, Prior, State1,
              proposal(StateExK, _PostExK, _ProbsExK, _TK_O, TK_P)),
   mcs_rebuild(TK_P, StateExK, State2).

mc_step(mh, Keys, SampleGoal, SWs, Prior, State1, State2) :-
   mc_propose(Keys, SampleGoal, SWs, Prior, State1,
              proposal(StateExK, PostExK, ProbsExK, TK_O, TK_P)),
   maplist(tree_acceptance_weight(PostExK, ProbsExK), [TK_O, TK_P], [W_O, W_P]),
   D is W_P-W_O,
   % The small negative tolerance treats rounding-level differences as D >= 0.
   (D>= -1e-13 -> Accept=1; call(bernoulli*exp, D, Accept)),
   (Accept=0 -> State2=State1; mcs_rebuild(TK_P, StateExK, State2)).

%  mc_propose(+Keys, +SampleGoal, +SWs, +Prior, +State1, -Proposal)
%  Proposal step shared by both steppers. Proposal is
%  proposal(StateExK, PostExK, ProbsExK, TK_O, TK_P): the state without
%  the selected tree, the posterior and expectations computed from it,
%  the old tree and the newly sampled tree.
mc_propose(Keys, SampleGoal, SWs, Prior, State1,
           proposal(StateExK, PostExK, ProbsExK, TK_O, TK_P)) :-
   mcs_random_select(Keys, TK_O, State1, StateExK),
   mcs_dcounts(StateExK, CountsExK),
   sw_posteriors(Prior, CountsExK, PostExK),
   sw_expectations(PostExK, ProbsExK),
   mc_sample(SampleGoal, SWs, ProbsExK, TK_O, TK_P).
118
%! tree_acceptance_weight(+PostExTree, +PProbs, +Tree, -W)
%  W is LogZ - LogProposal, where
%    - LogZ sums log_partition_dirichlet over the result of
%      sw_posteriors/3 applied to PostExTree and the tree's counts;
%    - LogProposal sums log_mul/3 (count * log(probability)) over
%      PProbs paired with the tree's counts.
tree_acceptance_weight(PostExTree, PProbs, Tree, W) :-
   mct_counts(Tree, TreeCounts),
   sw_posteriors(PostExTree, TreeCounts, PostWithTree),
   map_sum_sw(log_partition_dirichlet, PostWithTree, LogZ),
   map_sum_sw(map_sum(log_mul), PProbs, TreeCounts, LogProposal),
   W is LogZ - LogProposal.
%! log_mul(+Prob, +N, -X)
%  X is N times the natural logarithm of Prob.
log_mul(Prob, N, X) :-
   X is N*log(Prob).
126
127% MCS: Monte Carlo state: rbtree to map K to tree, stash counts
%! mcs_init(+SWs, +VTrees, -Ks, -State)
%  Builds the initial state Totals-Map.
%    - Totals: sw_trees_stats/3 over all of VTrees.
%    - Map:    an rbtree produced by the pipeline include(is_pair),
%              map_stats(SWs), enumerate, list_to_rbtree applied to
%              VTrees (written right to left in the `*` composition).
%    - Ks:     the keys of Map.
mcs_init(SWs, VTrees, Ks, Totals-Map) :-
   sw_trees_stats(SWs, VTrees, Totals),
   call(list_to_rbtree * enumerate * map_stats(SWs) * include(is_pair),
        VTrees, Map),
   rb_keys(Map, Ks).
132
%! mcs_random_select(+Ks, -Tree, +State, -StateExK)
%  Picks key K from Ks with uniform/2, deletes its entry Goal-Counts
%  from the state's rbtree and subtracts Counts from the totals with
%  map_swc(sub, ...). The result is dmhs(K, CountsExK, MapExK): the
%  chosen key with the counts and map that exclude the selected tree.
mcs_random_select(Ks, Goal-TreeCounts, Totals-Map, dmhs(K, CountsExK, MapExK)) :-
   uniform(Ks, K),
   rb_delete(Map, K, Goal-TreeCounts, MapExK),
   map_swc(sub, TreeCounts, Totals, CountsExK).
137
%! mcs_rebuild(+Tree, +StateExK, -State)
%  Puts Tree (Goal-Counts) back under key K of a dmhs/3 state, giving a
%  full Totals-Map state. Totals come from sw_posteriors/3 applied to
%  the tree's counts and CountsExK -- presumably an element-wise sum;
%  confirm against module `switches`.
mcs_rebuild(Goal-TreeCounts, dmhs(K, CountsExK, MapExK), Totals-Map) :-
   sw_posteriors(TreeCounts, CountsExK, Totals),
   rb_insert_new(MapExK, K, Goal-TreeCounts, Map).
141
%! mcs_dcounts(+StateExK, -CountsExK)
%  CountsExK is the second field of a dmhs/3 state.
mcs_dcounts(StateExK, CountsExK) :-
   StateExK = dmhs(_, CountsExK, _).
%! mcs_counts(+State, -Counts)
%  Counts is the left component of a Counts-Map state.
mcs_counts(State, Counts) :-
   State = Counts-_.
%! mcs_unit_counts(+State) is semidet.
%  Succeeds when, for every value _-GCs stored in the state's rbtree,
%  every pair _-C in GCs has a count list C whose sum unifies with 1.
%  An empty map succeeds.
%  Uses sum_list/2 from library(lists) in place of the deprecated
%  sumlist/2 compatibility alias.
mcs_unit_counts(_-Map) :-
   forall(rb_in(_, _-GCs, Map),
          forall(member(_-C, GCs), sum_list(C, 1))).
146
%! mct_goal(+Tree, -Goal)
%  Goal is the left component of a Goal-Counts tree record.
mct_goal(Tree, Goal) :-
   Tree = Goal-_.
%! mct_make(+SWs, +Goal, +T, -Tree)
%  Tree is the record Goal-C, where C is what sw_trees_stats/3 computes
%  for T over the switches SWs.
mct_make(SWs, Goal, T, Goal-Counts) :-
   sw_trees_stats(SWs, T, Counts).
%! mct_counts(+Tree, -Counts)
%  Counts is the right component of a Goal-Counts tree record.
mct_counts(Tree, Counts) :-
   Tree = _-Counts.
150
%! map_stats(+SWs, +Trees, -Stats)
%  Maps fsnd(sw_trees_stats(SWs)) over Trees -- presumably replacing the
%  second component of each pair with its switch statistics while
%  keeping the first; confirm against fsnd/3 in library(data/pair).
%  The terminating full stop was missing from this clause and is
%  restored here.
map_stats(SWs, Trees, Stats) :-
   maplist(fsnd(sw_trees_stats(SWs)), Trees, Stats).