1:- module(ccp_mcmc, [mc_evidence/4, mh_machine/4, gibbs_machine/5]).
5:- use_module(library(insist)). 6:- use_module(library(callutils), [(*)/4]). 7:- use_module(library(listutils), [enumerate/2]). 8:- use_module(library(math), [neg/2, add/3, sub/3, exp/2, map_sum/4]). 9:- use_module(library(data/pair), [is_pair/1, pair/3, fst/2, fsnd/3, snd/2]). 10:- use_module(library(plrand), [log_partition_dirichlet/2]). 11:- use_module(library(machines), [unfold/2, unfolder/3, mapper/3, scan0/4, (:>)/3, mean/2, op(600,yfx,:>)]). 12
13:- use_module(effects, [dist/2, uniform/2]). 14:- use_module(learn, [converge/5, learn/4]). 15:- use_module(switches, [ map_sum_sw/3, map_sum_sw/4, map_swc/4
16 , sw_expectations/2, sw_log_prob/3, sw_posteriors/3, sw_samples/2
17 ]). 18:- use_module(graph, [ top_goal/1, top_value/2, tree_stats/2, sw_trees_stats/3, graph_fold/4
19 , graph_inside/3, prune_graph/4, igraph_sample_tree/3
20 ]). 21
%! bernoulli(+P1, -X) is det.
%
%  Draws X from a Bernoulli distribution via the dist/2 effect:
%  X = 1 with probability P1 and X = 0 with probability 1-P1.
bernoulli(P1, X) :-
    Outcomes = [P0-0, P1-1],
    P0 is 1 - P1,
    dist(Outcomes, X).
23
%! mc_evidence(+Method, +Graph, +Prior, -Stream) is det.
%
%  Produces a lazy Stream of running estimates of the log evidence of the
%  data explained by Graph, using the identity
%
%      log p(D) = log p(D | Theta) + log p(Theta) - log p(Theta | D)
%
%  Theta is the vector of mean switch probabilities of the variational
%  Bayes posterior. That posterior is obtained by iterating learn/4 with
%  converge/5 under the stopping criterion rel(1e-6). The term
%  p(Theta | D) is estimated as the running mean of p(Theta | Post_i)
%  over the MCMC states produced by the machine that
%  method_machine_mapper/4 selects for Method (gibbs or mh).
%
%  NOTE(review): the same ThetaVB (probabilities) is passed as the initial
%  machine state for both methods. The gibbs machine uses the
%  `posterior' rotation, whose state is a posterior. Confirm whether
%  seeding it with probabilities, and including that seed in the mean,
%  is intended.
mc_evidence(Method, Graph, Prior, Stream) :-
    converge(rel(1e-6), learn(vb(Prior), io(log), Graph), _, Prior, PostVB),
    sw_expectations(PostVB, ThetaVB),
    % log p(D | Theta): fold the graph with the r(log,lse,add,cons)
    % semiring (presumably log-sum-exp), then read off the top goal.
    call(top_value*graph_fold(r(log,lse,add,cons),ThetaVB), Graph, LogLik),
    % log p(D, Theta) = log p(D | Theta) + log p(Theta) under Prior.
    call(add(LogLik)*sw_log_prob(Prior), ThetaVB, LogJoint),
    method_machine_mapper(Method, Prior, Machine, StateToPost),
    % Map each state to p(Theta | Post_i), average, then emit
    % log p(D, Theta) - log(mean).
    unfold(call(Machine, Graph, Prior, ThetaVB)
           :> mapper(p_params_given_post(ThetaVB)*StateToPost)
           :> mean
           :> mapper(add(LogJoint)*neg*log),
           Stream).
33
%! p_params_given_post(+Probs, +Post, -P) is det.
%
%  P is the density of the switch parameters Probs under the switch
%  posterior Post. It is computed as exp of sw_log_prob/3.
p_params_given_post(Probs, Post, P) :-
    sw_log_prob(Post, Probs, LogDensity),
    P is exp(LogDensity).
35
%! method_machine_mapper(?Method, +Prior, -Machine, -Mapper).
%
%  Associates an MCMC Method with two closures:
%
%  * Machine builds the state machine.
%  * Mapper turns each machine state into a switch posterior.
%
%  Gibbs states are already posteriors, so its Mapper is (=). For mh the
%  state counts (mcs_counts/2) are combined with Prior via sw_posteriors.
method_machine_mapper(gibbs, _Prior, Machine, Mapper) :-
    Machine = ccp_mcmc:gibbs_machine(posterior),
    Mapper = (=).
method_machine_mapper(mh, Prior, Machine, Mapper) :-
    Machine = ccp_mcmc:mh_machine,
    Mapper = ccp_mcmc:sw_posteriors(Prior)*mcs_counts.
38
%! gibbs_machine(+Rot, +Graph, +Prior, +P1, -M) is det.
%
%  Builds a Gibbs-sampling machine M over Graph.
%
%  The inside graph of Graph is computed once. Each transition composes
%  three closures, in the order chosen by Rot (see rotation/5):
%
%  * sw_samples
%  * gstep/4 over that inside graph
%  * sw_posteriors(Prior)
%
%  P1 is the initial state handed to scan0.
gibbs_machine(Rot, Graph, Prior, P1, M) :-
    graph_inside(Graph, Params, Inside),
    rotation(Rot, sw_posteriors(Prior), gstep(Params, Inside), sw_samples,
             Transition),
    unfolder(scan0(Transition), P1, M).
43
:- meta_predicate rotation(+,2,2,2,-).

%! rotation(+Which, :Post, :Step, :Sample, -Composite) is semidet.
%
%  Composes the three Gibbs sub-steps into a single transition closure,
%  starting the cycle at a different point for each Which:
%
%  * posterior: Post*Step*Sample
%  * counts:    Step*Sample*Post
%  * params:    Sample*Post*Step
%
%  Assuming (*)/4 from callutils runs its right operand first, the machine
%  state is, respectively, a posterior, a count table or a parameter sample.
rotation(posterior, Post, Step, Sample, Composite) :-
    Composite = Post*Step*Sample.
rotation(counts, Post, Step, Sample, Composite) :-
    Composite = Step*Sample*Post.
rotation(params, Post, Step, Sample, Composite) :-
    Composite = Sample*Post*Step.
48
%! gstep(+P0, +IG, +P1, -Counts) is det.
%
%  One tree-sampling step. P0 and IG are the parameter term and inside
%  graph produced by graph_inside/3.
%
%  A fresh copy of P0-IG is unified with P1-Fresh, so the copy's
%  parameters take the concrete switch probabilities P1. A tree is then
%  sampled for the top goal from the copied graph, and its switch counts
%  are returned via tree_stats/2. Copying leaves the shared IG template
%  unbound for the next step.
gstep(P0, IG, P1, Counts) :-
    copy_term(P0-IG, P1-Fresh),
    top_goal(Top),
    igraph_sample_tree(Fresh, Top, Sampled),
    tree_stats(Top-Sampled, Counts).
54
%! mh_machine(+Graph, +Prior, +Probs0, -M) is det.
%
%  Builds an MCMC machine M over explanation trees of Graph.
%
%  The graph is first reshaped so the top goal has a single conjunction
%  of factors. The initial trees come from graph_fold(best, Probs0)
%  (presumably the most probable explanation under Probs0). The machine
%  state is Totals-Map, as built by mcs_init/4.
mh_machine(Graph, Prior, Probs0, M) :-
    graph_as_conjunction(Graph, ConjGraph),
    call(snd * top_value * graph_fold(best, Probs0), ConjGraph, InitTrees),
    maplist(fst, Prior, SWs),
    mcs_init(SWs, InitTrees, Keys, State0),
    mh_machine_from(Keys, ConjGraph, SWs, Prior, State0, M).

% mh_machine_from(+Keys, +ConjGraph, +SWs, +Prior, +State0, -M)
%
%  With no resampleable factors (Keys = []) the machine keeps State0.
%  Otherwise it uses the gibbs stepper when mcs_unit_counts/1 holds for
%  State0, and the mh stepper otherwise.
mh_machine_from([], _, _, _, State0, M) :-
    !,
    unfolder(scan0(=), State0, M).
mh_machine_from(Keys, ConjGraph, SWs, Prior, State0, M) :-
    make_tree_sampler(ConjGraph, SampleGoal),
    (   mcs_unit_counts(State0)
    ->  Stepper = gibbs
    ;   Stepper = mh
    ),
    unfolder(scan0(mc_step(Stepper, Keys, SampleGoal, SWs, Prior)), State0, M).
65
%! graph_as_conjunction(+Graph, -Graph1) is semidet.
%
%  Ensures the top goal of Graph has exactly one explanation.
%
%  If top_value/2 already yields a single-element list, Graph1 = Graph.
%  Otherwise the top goal's entry Top-Expls is removed and replaced by two
%  entries:
%
%  * Top-[[Dummy]]
%  * Dummy-Expls
%
%  where Dummy is the goal '^mcmc':dummy.
graph_as_conjunction(Graph, Graph) :-
    top_value(Graph, [_]),
    !.
graph_as_conjunction(Graph, [Top-[[Dummy]], Dummy-Expls | Rest]) :-
    top_goal(Top),
    select(Top-Expls, Graph, Rest),
    Dummy = '^mcmc':dummy.
70
%! mc_sample(+SampleGoal, +SWs, +Probs, +T1, -T2) is det.
%
%  Resamples the tree for the goal of T1. SampleGoal is called with the
%  switch probabilities Probs, and the result is packaged with its switch
%  counts into T2 via mct_make/4.
mc_sample(SampleGoal, SWs, Probs, T1, T2) :-
    mct_goal(T1, Goal),
    call(SampleGoal, Probs, Goal, Explanation),
    mct_make(SWs, Goal, Explanation, T2).
74
%! make_tree_sampler(+G, -Sampler) is det.
%
%  Builds Sampler = ccp_mcmc:sample_goal(IGs) for the single top-level
%  conjunction of G. IGs holds one Goal-(InsideGraph-Params) entry per
%  distinct factor goal; each entry is built by sub_igraph/3.
make_tree_sampler(G, Sampler) :-
    top_value(G, [Factors]),
    sort(Factors, DistinctGoals),
    maplist(sub_igraph(G), DistinctGoals, IGs),
    Sampler = ccp_mcmc:sample_goal(IGs).
79
%! sub_igraph(+G, +Goal, -Entry) is det.
%
%  Entry = Goal-(IG-Ps), where IG and Ps are the inside graph and
%  parameter term (graph_inside/3) of G pruned to the part reachable
%  from Goal.
sub_igraph(G, Goal, Entry) :-
    prune_graph(=, Goal, G, Pruned),
    graph_inside(Pruned, Params, Inside),
    Entry = Goal-(Inside-Params).
83
%! sample_goal(+IGs, +PP, +Goal, -Trees) is semidet.
%
%  Samples explanation Trees for Goal.
%
%  The Goal's inside-graph template is looked up in IGs and copied. The
%  copy's parameters are bound from the full switch-probability list PP
%  via param_subset/2, which presumably requires both lists to be sorted
%  by switch key.
sample_goal(IGs, PP, Goal, Trees) :-
    memberchk(Goal-(Template-TemplateParams), IGs),
    copy_term(TemplateParams-Template, Params-Inside),
    param_subset(Params, PP),
    igraph_sample_tree(Inside, Goal, Trees).
89
%! param_subset(?Sub, +Full) is semidet.
%
%  Succeeds if every Key-Value in Sub has a matching entry in Full, and
%  unifies their values.
%
%  Both lists are walked in step using standard-order comparison of keys,
%  so both are expected to be sorted by key. Entries of Full whose key is
%  not in Sub are skipped. A key of Sub that is missing from Full makes
%  the call fail.
param_subset([], _).
param_subset([K1-V1|Rest1], [K2-V2|Rest2]) :-
    compare(Order, K1, K2),
    psub_aux(Order, K1, V1, V2, Rest1, Rest2).

% psub_aux(+Order, +K1, ?V1, ?V2, ?Rest1, +Rest2)
%
%  Order = (=): keys match, so unify the values and continue.
%  Order = (>): the current Full entry is skipped.
%  Order = (<): no clause applies, so the call fails.
psub_aux(=, _, V, V, Rest1, Rest2) :-
    param_subset(Rest1, Rest2).
psub_aux(>, K1, V1, _, Rest1, [K2-V2|Rest2]) :-
    compare(Order, K1, K2),
    psub_aux(Order, K1, V1, V2, Rest1, Rest2).
100
%! mc_step(+Stepper, +Keys, +SampleGoal, +SWs, +Prior, +State1, -State2).
%
%  One MCMC transition on the tree state.
%
%  A random factor's tree is taken out of State1. A replacement is
%  sampled from the mean switch probabilities of the posterior given the
%  remaining trees (see mc_propose/10). Then:
%
%  * gibbs: the replacement is always accepted.
%  * mh: it is accepted with probability min(1, exp(D)), where D is the
%    difference of tree_acceptance_weight/4 between the new and the old
%    tree. On rejection, State2 = State1.
mc_step(gibbs, Keys, SampleGoal, SWs, Prior, State1, State2) :-
    mc_propose(Keys, SampleGoal, SWs, Prior, State1,
               _OldTree, NewTree, Rest, _PostRest, _ProbsRest),
    mcs_rebuild(NewTree, Rest, State2).
mc_step(mh, Keys, SampleGoal, SWs, Prior, State1, State2) :-
    mc_propose(Keys, SampleGoal, SWs, Prior, State1,
               OldTree, NewTree, Rest, PostRest, ProbsRest),
    maplist(tree_acceptance_weight(PostRest, ProbsRest),
            [OldTree, NewTree], [WOld, WNew]),
    LogRatio is WNew - WOld,
    % Values just below zero are accepted outright, presumably to absorb
    % floating-point rounding.
    (   LogRatio >= -1e-13
    ->  Accept = 1
    ;   call(bernoulli*exp, LogRatio, Accept)
    ),
    (   Accept = 0
    ->  State2 = State1
    ;   mcs_rebuild(NewTree, Rest, State2)
    ).

% mc_propose(+Keys, +SampleGoal, +SWs, +Prior, +State, -OldTree, -NewTree,
%            -StateExK, -PostExK, -ProbsExK)
%
%  Proposal step shared by both clauses of mc_step/7:
%
%  * remove a uniformly chosen tree OldTree, leaving StateExK;
%  * form the posterior PostExK given the remaining counts, and its mean
%    probabilities ProbsExK;
%  * sample NewTree for the same goal under ProbsExK.
mc_propose(Keys, SampleGoal, SWs, Prior, State, OldTree, NewTree,
           StateExK, PostExK, ProbsExK) :-
    mcs_random_select(Keys, OldTree, State, StateExK),
    mcs_dcounts(StateExK, CountsExK),
    sw_posteriors(Prior, CountsExK, PostExK),
    sw_expectations(PostExK, ProbsExK),
    mc_sample(SampleGoal, SWs, ProbsExK, OldTree, NewTree).
118
%! tree_acceptance_weight(+PostExTree, +PProbs, +Tree, -W) is det.
%
%  Log weight used in the MH acceptance ratio of mc_step/7:
%
%      W = LogZ - LogProposal
%
%  * LogZ sums log_partition_dirichlet over the posterior obtained by
%    adding Tree's counts to PostExTree.
%  * LogProposal is the log-probability of those counts under the
%    proposal probabilities PProbs (sum of N*log(P), via log_mul/3).
tree_acceptance_weight(PostExTree, PProbs, Tree, W) :-
    mct_counts(Tree, Counts),
    map_sum_sw(map_sum(log_mul), PProbs, Counts, LogProposal),
    sw_posteriors(PostExTree, Counts, PostWithTree),
    map_sum_sw(log_partition_dirichlet, PostWithTree, LogZ),
    W is LogZ - LogProposal.
%! log_mul(+Prob, +N, -X) is det.
%
%  X is N * log(Prob): the log-probability of N independent occurrences
%  of an outcome of probability Prob.
log_mul(Prob, N, X) :-
    LogProb is log(Prob),
    X is N*LogProb.
126
%! mcs_init(+SWs, +VTrees, -Ks, -State) is det.
%
%  Builds the initial MCMC state, State = Totals-Map, and its keys Ks.
%
%  * Totals are the switch counts of all of VTrees (sw_trees_stats/3).
%  * Map is an rbtree from the enumerate/2 index of each pair-valued
%    element of VTrees to Goal-Counts, computed by map_stats/3.
%  * Ks are the keys of Map.
mcs_init(SWs, VTrees, Ks, Totals-Map) :-
    sw_trees_stats(SWs, VTrees, Totals),
    include(is_pair, VTrees, GoalTrees),
    map_stats(SWs, GoalTrees, GoalCounts),
    enumerate(GoalCounts, Indexed),
    list_to_rbtree(Indexed, Map),
    rb_keys(Map, Ks).
132
%! mcs_random_select(+Ks, -Tree, +State, -Removed) is semidet.
%
%  Picks a key K uniformly from Ks and deletes it from the state's Map.
%
%  * Tree is the deleted Goal-Counts entry.
%  * Removed = dmhs(K, CountsExK, MapExK), where MapExK is the remaining
%    map and CountsExK is map_swc(sub, Counts, Totals). This is
%    presumably Totals minus the tree's counts, switch by switch.
mcs_random_select(Ks, Tree, Totals-Map, dmhs(K, CountsExK, MapExK)) :-
    uniform(Ks, K),
    rb_delete(Map, K, Tree, MapExK),
    Tree = _Goal-TreeCounts,
    map_swc(sub, TreeCounts, Totals, CountsExK).
137
%! mcs_rebuild(+Tree, +Removed, -State) is semidet.
%
%  Inverse of mcs_random_select/4: puts Tree = Goal-Counts back under
%  key K.
%
%  The new Totals come from sw_posteriors/3 applied to Counts and
%  CountsExK. Here it is used to add the tree's counts back (presumably
%  an element-wise sum). rb_insert_new/4 fails if K is already present.
mcs_rebuild(Goal-Counts, dmhs(K, CountsExK, MapExK), Totals-Map) :-
    rb_insert_new(MapExK, K, Goal-Counts, Map),
    sw_posteriors(Counts, CountsExK, Totals).
141
%! mcs_dcounts(+Removed, -CountsExK) is semidet.
%
%  Counts of all trees except the removed one, taken from a
%  dmhs(K, CountsExK, MapExK) term.
mcs_dcounts(Removed, CountsExK) :-
    Removed = dmhs(_, CountsExK, _).

%! mcs_counts(+State, -Counts) is semidet.
%
%  Total switch counts of a Totals-Map state.
mcs_counts(State, Counts) :-
    State = Counts-_.
%! mcs_unit_counts(+State) is semidet.
%
%  True if, for every Goal-SwitchCounts entry in the state's Map, every
%  per-switch count list in SwitchCounts sums to exactly 1.
%  mh_machine/4 uses this to choose the gibbs stepper instead of mh.
%
%  Uses sum_list/2; sumlist/2 is a deprecated alias of it.
mcs_unit_counts(_-Map) :-
    forall(rb_in(_, _-SwitchCounts, Map),
           forall(member(_-Counts, SwitchCounts),
                  sum_list(Counts, 1))).
146
%! mct_goal(+Tree, -Goal) is semidet.
%
%  Goal of a Goal-Counts tree record.
mct_goal(Tree, Goal) :-
    Tree = Goal-_.

%! mct_make(+SWs, +Goal, +T, -Tree) is det.
%
%  Builds the tree record Goal-Counts, where Counts are the switch counts
%  of the explanation T (sw_trees_stats/3).
mct_make(SWs, Goal, T, Goal-Counts) :-
    sw_trees_stats(SWs, T, Counts).

%! mct_counts(+Tree, -Counts) is semidet.
%
%  Switch counts of a Goal-Counts tree record.
mct_counts(Tree, Counts) :-
    Tree = _-Counts.
150
%! map_stats(+SWs, +Trees, -Stats) is det.
%
%  Maps each Goal-Trees pair to Goal-Counts, with Counts computed by
%  sw_trees_stats(SWs). This relies on fsnd/3 from library(data/pair),
%  which presumably applies its closure to the second element of a pair.
%
%  The original clause was missing its terminating full stop, which is a
%  syntax error as soon as any further text follows.
map_stats(SWs, Trees, Stats) :-
    maplist(fsnd(sw_trees_stats(SWs)), Trees, Stats).
Gibbs and Metropolis-Hastings explanation samplers */