1:- module(crp_tagged,
    2		[	empty_classes/1
    3		,	dec_class//3
    4		,	inc_class//1
    5		,	remove_class//1
    6		,	add_class//2
    7
    8		,	crp_prob/5
    9		,	crp_sample/5
   10		,	crp_sample_obs/7
   11		,	crp_sample_obs/8
   12		,	crp_sample_rm/5
   13
   14		,	dp_sampler_teh/3
   15		,	py_sampler_teh/4
   16		]).

/** <module> Chinese Restaurant Process utilities (alternative version)

This module provides some building blocks for implementing a family of random processes related to Dirichlet processes, including Pitman-Yor processes, the Chinese Restaurant process, and the stick-breaking model (GEM). The Dirichlet process takes a single concentration parameter, represented as dp(Conc), while the Pitman-Yor process takes a concentration parameter and a discount parameter, represented as py(Conc,Disc).

gem_param   ---> dp(nonneg) ; py(nonneg,0--1).
gamma_prior ---> gamma(nonneg, nonneg).
beta_prior  ---> beta(nonneg, nonneg).
classes(A)  ---> classes(natural, list(nonneg), list(A)).
action(A)   ---> new ; old(A, class_idx).
action      ---> new ; old(class_idx).

rndstate  == plrand:state
class_idx == natural
prob      == 0--1

param_sampler == pred(+gem_param, -gem_param, +rndstate, -rndstate).

This may seem like a very low-level library for building CRPs, leaving a lot for the implementer to do, but this is intentional: it gives the implementer freedom to decide how to manage the states (terms of type classes(_)) of one or more CRPs, as well as the state of the random generator, in whatever way is most appropriate. See the example implementation in test_crp.pl for one way to do this. */

   49% :- use_module(library(dcg_core)).
   50% :- use_module(library(dcg_macros)).
   51:- use_module(library(apply_macros)).   52:- use_module(library(plrand), [crp_prob/5]).   53:- use_module(library(math),   [sub/3, equal/3, stoch/3, mul/3]).   54:- use_module(library(prob/tagged), [discrete//2]).
 crp_prob(+GEM:gem_param, +Classes:classes(A), +X:A, +PBase:prob, -Prob:prob) is det
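Compute the probability Prob of observing X given a CRP with already observed values in Classes, if the probability of drawing X from the base distribution is PBase. This predicate is imported from library(plrand) and re-exported by this module; the Prolog definition below is commented out.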
   62% crp_prob( Alpha, classes(_,Counts,Vals), X, PProb, P) :-
   63% 	counts_dist( Alpha, Counts, Counts1),
   64% 	stoch( Counts1, Probs, _),
   65% 	maplist( equal(X), Vals, Mask),
   66% 	maplist( mul, [PProb | Mask], Probs, PostProbs),
   67% 	sumlist( PostProbs, P).
%!  crp_sample(+GEM:gem_param, +Classes:classes(A), -Action:action(A))// is det.
%
%   Sample a new draw from the CRP described by GEM and Classes.
%   Action is either `new`, meaning the caller should draw a fresh value
%   from the base distribution, or old(X, ID), where X is the value stored
%   in the existing class with index ID. Operates in the random-state DCG
%   (the last two arguments thread the random generator state).

crp_sample(GEM, classes(_, Counts, Vals), Action, S0, S) :-
    counts_dist(GEM, Counts, Weights),
    stoch(Weights, Probs, _),
    % Outcome 1 is the "new class" slot; outcome I > 1 is class I-1.
    discrete(Probs, Outcome, S0, S),
    (   Outcome =< 1
    ->  Action = new
    ;   succ(ClassId, Outcome),
        nth1(ClassId, Vals, Value),
        Action = old(Value, ClassId)
    ).

%!  crp_sample_obs(+GEM:gem_param, +Classes:classes(A), +X:A, +PBase:prob, -A:action)// is det.
%!  crp_sample_obs(+GEM:gem_param, +Classes:classes(A), +X:A, +PBase:prob, -A:action, -P:prob)// is det.
%
%   Sample the action appropriate for an observation of value X. PBase is
%   the probability of X under the base distribution. The action A is
%   either `new` or old(N), where N is the index of an existing class.
%   crp_sample_obs//6 additionally returns P, the probability of the
%   observation, equivalent to calling crp_prob/5 with X BEFORE calling
%   crp_sample_obs//5. Operates in the random-state DCG.

crp_sample_obs(GEM, Classes, X, PBase, Action, RS1, RS2) :-
    crp_obs_weights(GEM, Classes, X, PBase, _Prior, Weights),
    stoch(Weights, Probs, _),
    discrete(Probs, Z, RS1, RS2),
    crp_obs_action(Z, Action).

crp_sample_obs(GEM, Classes, X, PBase, Action, ProbX, RS1, RS2) :-
    crp_obs_weights(GEM, Classes, X, PBase, Prior, Weights),
    % ProbX = (posterior mass of X) / (total prior pseudo-count mass).
    sumlist(Prior, Total),
    stoch(Weights, Probs, TotalX),
    ProbX is TotalX / Total,
    discrete(Probs, Z, RS1, RS2),
    crp_obs_action(Z, Action).

% crp_obs_weights(+GEM, +Classes, +X, +PBase, -Prior, -Weights)
%
% Shared by both crp_sample_obs arities. Prior is the pseudo-count list
% produced by counts_dist/3 (new-class weight first). Weights is the
% matching list of unnormalised weights for observing X: the new-class
% weight is scaled by PBase, and each existing class keeps its weight only
% if its value matches X (see posterior_count/4).
crp_obs_weights(GEM, classes(_, Counts, Vals), X, PBase,
                [CNew|Counts1], [PNew|Counts2]) :-
    counts_dist(GEM, Counts, [CNew|Counts1]),
    PNew is CNew*PBase,
    maplist(posterior_count(X), Vals, Counts1, Counts2).

% crp_obs_action(+Z, -Action)
%
% Maps the sampled outcome index to an action: outcome 1 is the new-class
% slot, and outcome Z > 1 refers to existing class Z-1.
crp_obs_action(Z, Action) :-
    (   Z = 1
    ->  Action = new
    ;   succ(C, Z),
        Action = old(C)
    ).

%!  crp_sample_rm(+Classes:classes(A), +X:A, -N:class_idx)// is det.
%
%   Sample the index N of the class from which value X should be removed.
%   Each class is weighted by its count if its value matches X, and by 0
%   otherwise. Operates in the random-state DCG.

crp_sample_rm(Classes, X, Class, S0, S) :-
    Classes = classes(_, Counts, Vals),
    maplist(posterior_count(X), Vals, Counts, Weights),
    stoch(Weights, Probs, _),
    discrete(Probs, Class, S0, S).
  125
  126
  127
  128% --------------------------------------------------------------------------------
  129% classes data structure (basic CRP stuff)
%!  empty_classes(-Classes:classes(_)) is det.
%
%   Classes is a classes structure with no classes: a class count of 0 and
%   empty count and value lists.

empty_classes(Classes) :-
    Classes = classes(0, [], []).

%!  dec_class(+N:class_idx, -C:natural, -X:A, +C1:classes(A), -C2:classes(A)) is det.
%
%   Decrement the count associated with class N. C is the count after
%   decrementing, and X is the value associated with the Nth class. Fails
%   if there is no class N, or if its count is already 0.

dec_class(N, CI, X, classes(K, C1, Vs), classes(K, C2, Vs)) :-
    dec_nth(N, CI, C1, C2),
    nth1(N, Vs, X).

% dec_nth(+N, -Y, +List0, -List)
%
% List is List0 with its Nth element X replaced by Y, where succ(Y, X).
% The index is tested inside an if-then-else rather than by two
% overlapping clause heads, so a call with N = 1 does not leave a choice
% point behind.
dec_nth(N, Y, [X|T1], [H|T2]) :-
    (   N == 1
    ->  succ(Y, X),
        H = Y,
        T2 = T1
    ;   succ(M, N),
        H = X,
        dec_nth(M, Y, T1, T2)
    ).

%!  inc_class(+N:class_idx, +C1:classes(A), -C2:classes(A)) is det.
%
%   Increment the count associated with class N. The number of classes and
%   the values are unchanged.

inc_class(N, classes(K, C1, V), classes(K, C2, V)) :-
    inc_nth(N, C1, C2).

% inc_nth(+N, +List0, -List)
%
% List is List0 with its Nth element incremented by one. The index is
% tested inside an if-then-else, so a call with N = 1 is deterministic
% and leaves no choice point.
inc_nth(N, [X|T1], [H|T2]) :-
    (   N == 1
    ->  succ(X, H),
        T2 = T1
    ;   succ(M, N),
        H = X,
        inc_nth(M, T1, T2)
    ).

%!  remove_class(+N:class_idx, +C1:classes(A), -C2:classes(A)) is det.
%
%   Remove the Nth class: its count and its value are dropped from the
%   respective lists, and the class total is reduced by one.

remove_class(N, Before, After) :-
    Before = classes(K0, Counts0, Vals0),
    After  = classes(K,  Counts,  Vals),
    remove_nth(N, Counts0, Counts),
    remove_nth(N, Vals0, Vals),
    succ(K, K0).

%!  add_class(+X:A, -ID:class_idx, +C1:classes(A), -C2:classes(A)) is det.
%
%   Append a new class holding value X with an initial count of 1. ID is
%   the index of the new class, which is also the new number of classes.

add_class(Value, ID, classes(K0, Counts0, Vals0), classes(ID, Counts, Vals)) :-
    succ(K0, ID),
    append(Counts0, [1], Counts),
    append(Vals0, [Value], Vals).
  169
  170
  171remove_nth(1,[_|T],T).
  172remove_nth(N,[Y|T1],[Y|T2]) :- 
  173	(	var(N) 
  174	->	remove_nth(M,T1,T2), succ(M,N)
  175	;	succ(M,N), remove_nth(M,T1,T2)
  176	).
  177
  178posterior_count(X,Val,Count,PC) :- X=Val -> PC=Count; PC=0.
  179
  180% -----------------------------------------------------------
  181% Dirichlet process and Pitman-Yor process
  182% pseudo-counts models.
  183
  184counts_dist(dp(Alpha),Counts,[Alpha|Counts]) :- !.
  185counts_dist(py(_,_),[],[1]) :- !.
  186counts_dist(py(Alpha,Discount),Counts,[CNew|Counts1]) :- !,
  187	length(Counts,K),
  188	CNew is Alpha+Discount*K,
  189	maplist(sub(Discount),Counts,Counts1).
  190
  191% ---------------------------------------------------------------
  192% PARAMETER SAMPLING
  193% Initialisers in Prolog, samplers written in C.
%!  dp_sampler_teh(+Prior:gamma_prior, +Counts:list(list(natural)), -S:param_sampler) is det.
%
%   Prepares a predicate for sampling the concentration parameter of a
%   Dirichlet process. The sampler's gem_param arguments must be of the
%   form dp(_). Prior is the Gamma prior gamma(A,B) on the concentration
%   parameter, where A is the shape parameter and B is the rate parameter
%   (i.e. the inverse of the scale parameter). Counts is a list of class
%   count lists (presumably one per CRP sharing the parameter). S is the
%   goal plrand:sample_dp_teh(A+K, B, Totals), where K is the total number
%   of classes across all count lists and Totals holds each list's sum.

dp_sampler_teh(gamma(Shape, Rate), CountLists,
               plrand:sample_dp_teh(ShapePlusK, Rate, Totals)) :-
    maplist(sumlist, CountLists, Totals),
    maplist(length, CountLists, ClassesPerList),
    sumlist(ClassesPerList, K),
    ShapePlusK is Shape + K.
 py_sampler_teh(+ConcPrior:gamma_prior, +DiscPr:beta_prior, +Counts:list(natural), -S:param_sampler) is det
Prepares a predicate for sampling the concentration and discount parameters of a Pitman-Yor process. The sampler's gem_param arguments must be of the form py(_,_). See dp_sampler_teh/3 for the description of the gamma_prior type. DiscPr is a Beta distribution prior for the discount parameter.
  215py_sampler_teh( ThPrior, DiscPrior, CountsX, Sampler) :-
  216	Sampler = plrand:sample_py_teh( ThPrior, DiscPrior, CountsX)