1:- module(egraph, [add_term//2, union//2, saturate//1, saturate//2,
2 extract/1, extract//0, lookup/2]).
27:- use_module(library(dcg/high_order)). 28:- use_module(library(ordsets)). 29:- use_module(library(rbtrees)). 30:- use_module(library(clpr)). 31
32:- use_module(egraph/compile). 33
%!  cost:attr_unify_hook(+Cost, +Other) is det.
%
%   Unification hook for the `cost` attribute. It places no restriction
%   on what a variable carrying a `cost` attribute may be unified with,
%   so it always succeeds.
cost:attr_unify_hook(_Cost, _Other).
%!  const:attr_unify_hook(+Const, +Other) is semidet.
%
%   Unification hook for the `const` attribute, whose value Const is
%   compared arithmetically.
%
%     - If Other also carries a `const` attribute, the two values must be
%       arithmetically equal (`=:=`); otherwise domain_error/2 is called
%       with the two constants.
%     - If Other is a variable without a `const` attribute, Const is
%       copied onto it.
%     - Any other value is accepted as is.
const:attr_unify_hook(Const, Other) :-
    get_attr(Other, const, OtherConst),
    !,
    (   Const =:= OtherConst
    ->  true
    ;   domain_error(Const, OtherConst)
    ).
const:attr_unify_hook(Const, Other) :-
    var(Other),
    !,
    put_attr(Other, const, Const).
const:attr_unify_hook(_Const, _Other).
%!  lookup(+KeyValue, +Pairs) is semidet.
%
%   KeyValue is `Key-Val`. Succeeds when Pairs contains an entry whose
%   key is identical (`==`) to Key, unifying Val with the stored value.
%   The search uses compare/3 to skip ahead four (or two) entries at a
%   time, so it assumes Pairs is ordered by key in the standard order of
%   terms.
lookup(Key-Val, [K1-V1, K2-V2, K3-V3, K4-V4|Tail]) :-
    !,
    compare(Ord4, Key, K4),
    (   Ord4 = (=)
    ->  Val = V4
    ;   Ord4 = (>)
    ->  % Key sorts after the fourth entry: continue in the remainder.
        lookup(Key-Val, Tail)
    ;   % Key sorts before the fourth entry: it can only be one of the
        % first three.
        compare(Ord2, Key, K2),
        (   Ord2 = (=)
        ->  Val = V2
        ;   Ord2 = (<)
        ->  Key == K1,
            Val = V1
        ;   Key == K3,
            Val = V3
        )
    ).
lookup(Key-Val, [K1-V1, K2-V2|Tail]) :-
    !,
    compare(Ord2, Key, K2),
    (   Ord2 = (=)
    ->  Val = V2
    ;   Ord2 = (>)
    ->  lookup(Key-Val, Tail)
    ;   Key == K1,
        Val = V1
    ).
lookup(Key-Val, [K1-V1]) :-
    Key == K1,
    Val = V1.
%!  add_term(+Term, -Id)// is det.
%
%   Add Term to the e-graph threaded through the DCG state and unify Id
%   with the id of the node that represents it. The clauses use
%   single-sided unification (`==>`), so the first matching guard
%   commits:
%
%     - an unbound variable is stored as the node `'$VAR'(Term)`;
%     - a dict has each value added recursively and is stored as a dict
%       with the same tag and keys whose values are the child ids;
%     - any other compound has each argument added recursively and is
%       stored as the same functor applied to the child ids;
%     - everything else is stored as is.
add_term(Term, Id), var(Term) ==>
    add_node('$VAR'(Term), Id).
add_term(Dict, Id), is_dict(Dict) ==>
    % The dict guard comes before the compound guard.
    { dict_pairs(Dict, Tag, KeyValues),
      pairs_keys_values(KeyValues, Keys, Values)
    },
    foldl(add_term, Values, ChildIds),
    { pairs_keys_values(NodeKeyValues, Keys, ChildIds),
      dict_create(NodeDict, Tag, NodeKeyValues)
    },
    add_node(NodeDict, Id).
add_term(Compound, Id), compound(Compound) ==>
    { Compound =.. [Functor|Args] },
    foldl(add_term, Args, ChildIds),
    { Node =.. [Functor|ChildIds] },
    add_node(Node, Id).
add_term(Atomic, Id) ==>
    add_node(Atomic, Id).
112
%!  add_node(+NodeId, +In, -Out) is det.
%
%   Pair form of add_node/4: NodeId is `Node-Id`.
add_node(Node-Id, In, Out) :-
    add_node(Node, Id, In, Out).

%!  add_node(+Node, ?Id, +In, -Out) is det.
%
%   Make sure Node is present in the node list In, giving Out. If Node
%   is already there, Id is unified with its existing id and the list is
%   returned unchanged. Otherwise `Node-node(Id, 1)` is inserted with
%   ord_add_element/3, and when Node is a number its id is tagged with a
%   `const` attribute holding that number.
add_node(Node, Id, In, Out) :-
    lookup(Node-node(Id, _Cost), In),
    !,
    Out = In.
add_node(Node, Id, In, Out) :-
    ord_add_element(In, Node-node(Id, 1), Out),
    (   number(Node)
    ->  put_attr(Id, const, Node)
    ;   true
    ).
124
%!  rules(+Rules, +Index, +Node, +UnifsIn, -UnifsOut)// is nondet.
%
%   Run every rule in Rules against Node, which has the shape
%   `Pat-node(Id, Cost)`. Each rule is called as a DCG non-terminal with
%   the extra arguments `Pat, Id, Index, UnifsIn, UnifsOut`; the
%   unification list is threaded from one rule to the next.
rules([Rule|More], Index, Node, Unifs0, Unifs) -->
    { Node = Pat-node(Id, _Cost) },
    call(Rule, Pat, Id, Index, Unifs0, Unifs1),
    rules(More, Index, Node, Unifs1, Unifs).
rules([], _Index, _Node, Unifs, Unifs) --> [].
129
%!  make_index(+Nodes, -Index) is det.
%
%   Build Index, a red-black tree that maps each class id occurring in
%   Nodes to the list of nodes filed under that id.
make_index(Nodes, Index) :-
    index_pairs(Nodes, IdNodes0),
    keysort(IdNodes0, IdNodes),
    group_pairs_by_key(IdNodes, Classes),
    ord_list_to_rbtree(Classes, Index).
135
%!  index_pairs(+Nodes, -IdNodes) is det.
%
%   Turn every `Node-node(Id, Cost)` entry of Nodes into `Id-Node`,
%   dropping the cost and keeping the order.
index_pairs([], []).
index_pairs([Entry|Entries], [Id-Node|IdNodes]) :-
    Entry = Node-node(Id, _Cost),
    index_pairs(Entries, IdNodes).
139
%!  match(+Nodes, +Rules, +Index, +UnifsIn, -UnifsOut)// is nondet.
%
%   Apply all Rules to every node of Nodes in order by delegating to
%   rules//5, threading the unification list through each step.
match([], _Rules, _Index, Unifs, Unifs) --> [].
match([Node|Nodes], Rules, Index, Unifs0, Unifs) -->
    rules(Rules, Index, Node, Unifs0, Unifs1),
    match(Nodes, Rules, Index, Unifs1, Unifs).
%!  union(?A, ?B)// is semidet.
%
%   Merge two classes by unifying their ids A and B, then re-normalise
%   the node list with merge_nodes/2 so that nodes which became
%   identical are collapsed.
union(A, B) -->
    { A = B },
    merge_nodes.
154
%!  merge_nodes(+In, -Out) is semidet.
%
%   Sort the node list and merge entries that share the same key. The
%   pass is repeated for as long as some key had more than one entry,
%   because merging unifies class ids and can make further keys
%   identical. When nothing was merged, the sorted list is the result.
merge_nodes(In, Out) :-
    sort(In, Sorted),
    group_pairs_by_key(Sorted, Groups),
    merge_groups(Groups, Merged, false, Changed),
    (   Changed \== true
    ->  Out = Sorted
    ;   merge_nodes(Merged, Out)
    ).
163
%!  merge_groups(+Groups, -Merged, +Changed0, -Changed) is semidet.
%
%   Collapse each `Sig-Entries` group into a single `Sig-Node` pair using
%   merge_group/3. Changed is `true` as soon as a group held more than
%   one entry; otherwise it keeps the value of Changed0.
merge_groups([Sig-[First|Others]|Groups], [Sig-Node|Merged], Changed0, Changed) :-
    merge_group(Others, First, Node),
    (   Others == []
    ->  Changed1 = Changed0
    ;   Changed1 = true
    ),
    merge_groups(Groups, Merged, Changed1, Changed).
merge_groups([], [], Changed, Changed).
172
%!  merge_group(+Entries, +Node0, -Node) is semidet.
%
%   Fold the `node(Id, Cost)` terms in Entries into Node0: every id is
%   unified with the id of Node0 and the smallest cost is kept.
merge_group([], Node, Node).
merge_group([node(Id, Cost)|Entries], node(AccId, AccCost), Node) :-
    Id = AccId,
    Best is min(Cost, AccCost),
    merge_group(Entries, node(AccId, Best), Node).
177
%!  apply_unifs(+Unifs) is semidet.
%
%   Perform every pending unification in Unifs, a list of `A=B` terms.
%   Fails as soon as one pair cannot be unified.
apply_unifs([]).
apply_unifs([Left=Right|Unifs]) :-
    Left = Right,
    apply_unifs(Unifs).
181
%!  rebuild(+Matches, +Unifs, -Out) is semidet.
%
%   Perform the pending class unifications in Unifs, then normalise the
%   node list Matches with merge_nodes/2 to obtain Out.
rebuild(Nodes, Unifs, Out) :-
    apply_unifs(Unifs),
    merge_nodes(Nodes, Out).
%!  saturate(+Rules)// is nondet.
%
%   Same as saturate//2 with the iteration bound set to `inf`.
%   Written as a plain clause with the DCG state arguments made explicit.
saturate(Rules, In, Out) :-
    saturate(Rules, inf, In, Out).
%!  saturate(+Rules, +N)// is nondet.
%
%   Apply Rules to the e-graph repeatedly. Each iteration indexes the
%   current nodes, matches all rules against them, applies the resulting
%   unifications and merges the nodes. Iteration stops when N is not
%   greater than 0, or when an iteration leaves the number of nodes
%   unchanged. N is either a number or the atom `inf`, which is never
%   decremented.
saturate(Rules, N, In, Out) :-
    (   N > 0
    ->  make_index(In, Index),
        % Unifs is an open list closed with []; the DCG list Matches
        % has In as its tail.
        match(In, Rules, Index, Unifs, [], Matches, In),
        rebuild(Matches, Unifs, Next),
        length(In, SizeBefore),
        length(Next, SizeAfter),
        (   SizeBefore == SizeAfter
        ->  Out = Next
        ;   (   N == inf
            ->  Remaining = N
            ;   Remaining is N - 1
            ),
            saturate(Rules, Remaining, Next, Out)
        )
    ;   Out = In
    ).
%!  extract(+Nodes) is nondet.
%
%   Run extraction on the node list Nodes; see extract//0.
extract(Nodes) :-
    extract(Nodes, Nodes).

%!  extract// is nondet.
%
%   Extract a term for every class of the e-graph; the node list itself
%   is passed through unchanged. The nodes are regrouped by class id,
%   each class id gets a `cost` attribute holding a non-negative CLP(R)
%   variable, class costs are computed with compute_class_cost/2, and
%   finally each class id is bound by extract_class/1.
%
%   The lambda variables below are shared between the two lambdas in the
%   same way as in the original clause, on purpose.
extract(Nodes, Nodes) :-
    transpose_pairs(Nodes, ByValue),
    maplist([node(ClassId, C)-N, ClassId-node(C, N)]>>true, ByValue, ById),
    group_pairs_by_key(ById, Classes),
    maplist([ClassId-_Members]>>({C >= 0}, put_attr(ClassId, cost, C)), Classes),
    maplist(compute_class_cost, Classes, CostedClasses),
    maplist(extract_class, CostedClasses).
238
%!  extract_class(+Class) is nondet.
%
%   Class is `Id-Nodes`, where Nodes is a list of `node(Cost, Node)`
%   terms. The nodes are sorted, so lower costs come first, and Id is
%   bound to one of them; on backtracking the remaining nodes are tried
%   in that order. A `'$VAR'(Var)` node binds Id to Var itself, any
%   other node binds Id to the node term. If Id is still a variable
%   afterwards, its `cost` attribute is removed.
%
%   The predicate name in the head is restored here: extract/2 calls
%   this predicate as `extract_class/1`.
extract_class(Id-Nodes) :-
    sort(Nodes, SortedNodes),
    member(node(_Cost, Node), SortedNodes),
    (   Node = '$VAR'(Var)
    ->  Id = Var
    ;   Id = Node
    ),
    (   var(Id)
    ->  del_attr(Id, cost)
    ;   true
    ).
251
%!  compute_class_cost(+Class, -CostedClass) is semidet.
%
%   Class is `Id-Nodes`. Every node gets its cost from
%   compute_node_cost/3, and the CLP(R) minimum over those costs is
%   unified with the value of the `cost` attribute of Id. Fails if Nodes
%   is empty or Id has no `cost` attribute.
compute_class_cost(Id-Nodes, Id-CostedNodes) :-
    maplist(compute_node_cost, Nodes, CostedNodes, Costs),
    Costs = [Cost0|OtherCosts],
    foldl([C, Acc, Min]>>
          {Min = min(C, Acc)},
          OtherCosts, Cost0, ClassCost),
    get_attr(Id, cost, ClassCost).
%!  compute_node_cost(+Node0, -Node, -Cost) is semidet.
%
%   Node0 is `node(Offset, Term)`; Node is `node(Cost, Term)`.
%
%     - For a `'$VAR'(_)` term or an atomic term, Cost is Offset.
%     - For a dict or any other compound, Cost is constrained with CLP(R)
%       to Offset plus the sum of the `cost` attributes of its children
%       (the dict values or the compound arguments).
%
%   Fails if a child carries no `cost` attribute.
compute_node_cost(node(Offset, Node), node(Cost, Node), Cost) :-
    (   Node = '$VAR'(_)
    ->  Cost = Offset
    ;   is_dict(Node)
    ->  % The dict test comes before the compound test.
        dict_pairs(Node, _, Pairs),
        pairs_keys_values(Pairs, _, Ids),
        node_children_cost_(Ids, 0, ChildrenCost),
        { Cost = ChildrenCost + Offset }
    ;   compound(Node)
    ->  Node =.. [_|Ids],
        node_children_cost_(Ids, 0, ChildrenCost),
        { Cost = ChildrenCost + Offset }
    ;   Cost = Offset
    ).

%!  node_children_cost_(+Ids, +Sum0, -Sum) is semidet.
%
%   Sum is Sum0 plus the `cost` attribute of every class id in Ids,
%   accumulated left to right as CLP(R) constraints.
node_children_cost_([], Sum, Sum).
node_children_cost_([Id|Ids], Sum0, Sum) :-
    get_attr(Id, cost, IdCost),
    { Sum1 = Sum0 + IdCost },
    node_children_cost_(Ids, Sum1, Sum).
/** <module> E-graph implementation for term rewriting and saturation
This module implements an E-graph (Equivalence Graph) data structure, commonly used for efficient term rewriting, congruence closure, and e-matching. The E-graph state is typically threaded through operations using DCG notation.
Rewrite rules are automatically compiled into efficient DCG predicates via term expansion. See the
`egraph_compile` module for full details. The supported rule declarations are:

  - rewrite(Name, Lhs, Rhs)
  - rewrite(Name, Lhs, Rhs, RhsOptions)
  - rewrite(Name, Lhs, LhsOptions, Rhs, RhsOptions)
  - rewrite(Name, Lhs, LhsOptions, Rhs, RhsOptions) :- Body

Main predicates:
term(s) from the E-graph based on term costs.
*/