:- module(lda, [lda/5, lda2/5, unif/2, unif//2, mkunif/2, dirichlet/3]).

/* Samplers for Latent Dirichlet Allocation (no inference) */

:- use_module(library(ccprism/macros)).
:- use_module(library(ccprism/effects)).
:- use_module(library(math), [mul/3, add/3, stoch/3]).
:- use_module(library(listutils), [zip/3, enumerate/2]).
:- use_module(library(data/pair), [pair/3, fst/2]).
:- use_module(library(callutils)).
:- use_module(library(prob/tagless)).
:- use_module(library(lex/biglex)).

% Load-time execution of the body to get the head(s): a clause written as
% `Head := Goal` is replaced, while the file is being loaded, by one fact
% per solution of Goal (collected with findall/3).
% NOTE: each `%` comment must end at its own newline; if a clause shares a
% physical line with a preceding `%` comment it is silently commented out.
term_expansion(Head := G, Heads) :- findall(Head, call(G), Heads).

% Word lists captured once at load time via the `:=` expansion above.
% place/2, noun/4 and pname/2 are not defined in this file; presumably
% they come from library(lex/biglex) -- TODO confirm.
places(Xs) := findall(X, place(_,X), Xs).
nouns(Xs)  := findall(X, noun(_,_,X,_), Xs).
names(Xs)  := findall(X, pname(_,X), Xs).

%!  unif(:G, -X)// is nondet.
%
%   DCG form of unif/2. Obtains a list by calling G with one extra
%   argument and passes it to uniform/2. The threaded DCG state is
%   returned unchanged (both state arguments are the same variable).
unif(G, X, S, S) :-
    call(G, Xs),
    uniform(Xs, X).

%!  unif(:G, -X) is nondet.
%
%   Calls G with one extra argument to obtain a list Xs, then calls
%   uniform(Xs, X). uniform/2 is imported, not defined here; presumably
%   it draws X uniformly from Xs -- TODO confirm against its library.
unif(G, X) :-
    call(G, Xs),
    uniform(Xs, X).

%!  mkunif(:G, -D) is det.
%
%   Calls G with one extra argument to obtain a list Xs of length N and
%   builds D with maplist(pair(P), Xs, D), where P is 1/N. Presumably
%   pair/3 yields P-X terms, making D a uniform weighted list over Xs
%   -- confirm against library(data/pair).
%
%   @throws an arithmetic evaluation error if Xs is empty (P is 1/0).
mkunif(G, D) :-
    call(G, Xs),
    length(Xs, N),
    P is 1/N,
    maplist(pair(P), Xs, D).

%!  dirichlet(+Alpha, +Base, -Dist) is det.
%
%   Base and Dist are related through zip/3 to (Probs, Vals) and
%   (Probs1, Vals) respectively, so Dist carries the same values as Base
%   with new weights Probs1. Each base weight is combined with Alpha via
%   mul/3 (presumably multiplication -- confirm in library(math)) to give
%   Alphas, and Probs1 is obtained from sample(dirichlet(Alphas), Probs1).
dirichlet(Alpha, Base, Dist) :-
    zip(Probs, Vals, Base),
    zip(Probs1, Vals, Dist),
    maplist(mul(Alpha), Probs, Alphas),
    sample(dirichlet(Alphas), Probs1).

%!  lda(+Eta, +K, +Alpha, +L, ?Docs) is det.
%
%   Samples documents using weighted-list distributions.
%
%   1. Noun is the mkunif/2 distribution over the load-time noun list.
%   2. Topics is a list of K distributions, each produced by
%      dirichlet(Eta, Noun, _).
%   3. Topic is the mkunif/2 distribution over those K topics
%      (`=(Topics)` simply hands the Topics list to mkunif/2).
%   4. Every element of Docs is filled in by doc(Alpha, Topic, L, _).
%
%   Nothing here fixes the length of Docs, so the number of documents is
%   whatever list skeleton the caller passes in -- presumably a list of
%   fresh variables.
lda(Eta, K, Alpha, L, Docs) :-
    mkunif(nouns, Noun),
    length(Topics, K),
    maplist(dirichlet(Eta, Noun), Topics),
    mkunif(=(Topics), Topic),
    maplist(doc(Alpha, Topic, L), Docs).

% doc(+Alpha, +Topic, +L, -Doc): Doc is a list of L words. A per-document
% topic distribution TopicDist is produced by dirichlet(Alpha, Topic, _)
% and each word is then produced by topic_dist_word/2.
doc(Alpha, Topic, L, Doc) :-
    length(Doc, L),
    dirichlet(Alpha, Topic, TopicDist),
    maplist(topic_dist_word(TopicDist), Doc).

% topic_dist_word(+TopicDist, -Word): two chained dist/2 calls -- the
% first yields a Topic from TopicDist, the second yields Word from that
% Topic. dist/2 is imported; presumably it samples from a weighted list.
topic_dist_word(TopicDist, Word) :-
    dist(TopicDist, Topic),
    dist(Topic, Word).

% This version avoids a lot of (un)zipping and is faster: probabilities
% and values are kept in two parallel lists instead of one paired list.

% mkunif2(:G, -Probs, -Vals): Vals is the list obtained by calling G with
% one extra argument; Probs is built by maplist(const(P), Vals, Probs)
% with P is 1/N, N the length of Vals (presumably one copy of P per
% value -- confirm const/3 in library(callutils)).
% Throws an arithmetic evaluation error if Vals is empty (P is 1/0).
mkunif2(G, Probs, Vals) :-
    call(G, Vals),
    length(Vals, N),
    P is 1/N,
    maplist(const(P), Vals, Probs).

% dirichlet2(+Alpha, +Probs, -Probs1): same sampling step as dirichlet/3
% but operating directly on a probability list, with no zipping.
dirichlet2(Alpha, Probs, Probs1) :-
    maplist(mul(Alpha), Probs, Alphas),
    sample(dirichlet(Alphas), Probs1).

%!  lda2(+Eta, +K, +Alpha, +L, ?Docs) is det.
%
%   Same structure as lda/5, using the parallel-list representation:
%   Nouns/NounProbs from mkunif2/3, K topic probability vectors from
%   dirichlet2(Eta, NounProbs, _), a probability vector TopicProbs over
%   the topics, and doc2/6 for every element of Docs. As in lda/5, the
%   length of Docs is not set here and must come from the caller.
lda2(Eta, K, Alpha, L, Docs) :-
    mkunif2(nouns, NounProbs, Nouns),
    length(Topics, K),
    maplist(dirichlet2(Eta, NounProbs), Topics),
    mkunif2(=(Topics), TopicProbs, _),
    maplist(doc2(Nouns, Alpha, Topics, TopicProbs, L), Docs).
% doc2(+Nouns, +Alpha, +Topics, +TopicProbs, +L, -Doc)
%
% Builds one document as a list of L words, using parallel lists rather
% than paired lists. A per-document topic mixture is first obtained from
% dirichlet2/3, then every word slot is filled by topic_dist_word2/4.
doc2(Vocab, Conc, TopicRows, BaseProbs, Len, Words) :-
    length(Words, Len),
    dirichlet2(Conc, BaseProbs, Mixture),
    maplist(topic_dist_word2(Vocab, TopicRows, Mixture), Words).

% topic_dist_word2(+Nouns, +Topics, +TopicDist, -Word)
%
% Two chained dist/3 calls: the first yields one entry of the topic list
% using the mixture as its probability list; that entry is then used as
% the probability list over the vocabulary to yield Word. dist/3 is
% imported, not defined here; presumably dist(Probs, Vals, X) samples X
% from Vals with weights Probs -- TODO confirm against its library.
topic_dist_word2(Vocab, TopicRows, Mixture, Word) :-
    dist(Mixture, TopicRows, WordProbs),
    dist(WordProbs, Vocab, Word).