Skip to content

Commit 63b316e

Browse files
committed
refactor: streamline candidate generation and pruning in Apriori algorithm
1 parent 791deb4 commit 63b316e

1 file changed

Lines changed: 76 additions & 63 deletions

File tree

machine_learning/apriori_algorithm.py

Lines changed: 76 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
Examples: https://www.kaggle.com/code/earthian/apriori-association-rules-mining
1212
"""
1313

14-
from collections import Counter
1514
from itertools import combinations
15+
from collections import defaultdict
1616

1717

1818
def load_data() -> list[list[str]]:
    """Return a small sample transaction dataset for the demo run."""
    return [["milk"], ["milk", "butter"], ["milk", "bread"], ["milk", "bread", "chips"]]
2626

2727

28-
def prune(itemset: list, candidates: list, length: int) -> list:
29-
"""
30-
Prune candidate itemsets that are not frequent.
31-
The goal of pruning is to filter out candidate itemsets that are not frequent. This
32-
is done by checking if all the (k-1) subsets of a candidate itemset are present in
33-
the frequent itemsets of the previous iteration (valid subsequences of the frequent
34-
itemsets from the previous iteration).
35-
36-
Prunes candidate itemsets that are not frequent.
37-
38-
>>> itemset = ['X', 'Y', 'Z']
39-
>>> candidates = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
40-
>>> prune(itemset, candidates, 2)
41-
[['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
42-
43-
>>> itemset = ['1', '2', '3', '4']
44-
>>> candidates = ['1', '2', '4']
45-
>>> prune(itemset, candidates, 3)
46-
[]
28+
# ---------- Helpers ----------


def get_support(itemset, transactions):
    """Count how many of the given transactions contain every item of ``itemset``."""
    support = 0
    for transaction in transactions:
        if itemset <= transaction:  # subset test, same as itemset.issubset(...)
            support += 1
    return support
33+
34+
35+
def generate_candidates(prev_frequent, k):
    """
    Join step: build candidate itemsets of size k by unioning pairs of
    frequent (k-1)-itemsets, keeping only unions of exactly k items.
    """
    return {
        first | second
        for first, second in combinations(prev_frequent, 2)
        if len(first | second) == k
    }
49+
50+
51+
def has_infrequent_subset(candidate, prev_frequent):
    """
    Apriori pruning test: a candidate can only be frequent when every one of
    its (k-1)-subsets is frequent.  Return True if any subset is missing.
    """
    return any(
        frozenset(sub) not in prev_frequent
        for sub in combinations(candidate, len(candidate) - 1)
    )
59+
60+
61+
# ---------- Main Apriori ----------


def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]:
    """
    Return every frequent itemset together with its support count.

    Every itemset — including 1-itemsets — is reported as a sorted list of
    items paired with its count, so the result shape is uniformly
    ``(list[str], int)``.  Smaller itemsets come first; within a size level
    the itemsets are sorted lexicographically so the output is deterministic
    regardless of hash randomization.

    >>> data = [['A', 'B', 'C'], ['A', 'B'], ['A', 'C'], ['A', 'D'], ['B', 'C']]
    >>> apriori(data, 2)
    [(['A'], 4), (['B'], 3), (['C'], 3), (['A', 'B'], 2), (['A', 'C'], 2), (['B', 'C'], 2)]
    >>> apriori(data, 6)
    []
    """
    transactions = [set(t) for t in data]

    # 1. count 1-itemsets
    item_counts: defaultdict = defaultdict(int)
    for transaction in transactions:
        for item in transaction:
            item_counts[frozenset([item])] += 1

    frequent = {
        itemset for itemset, count in item_counts.items() if count >= min_support
    }

    # Report 1-itemsets in the same (sorted-list, count) shape as larger ones;
    # sort for a deterministic output order.
    all_frequents = sorted(
        (sorted(itemset), count)
        for itemset, count in item_counts.items()
        if count >= min_support
    )

    k = 2

    while frequent:
        # 2. join step: generate size-k candidates from frequent (k-1)-itemsets
        candidates = generate_candidates(frequent, k)

        # 3. prune step: drop candidates with any infrequent (k-1)-subset
        candidates = {c for c in candidates if not has_infrequent_subset(c, frequent)}

        # 4. count support of the surviving candidates
        candidate_counts: defaultdict = defaultdict(int)
        for transaction in transactions:
            for candidate in candidates:
                if candidate.issubset(transaction):
                    candidate_counts[candidate] += 1

        # 5. keep the frequent candidates for the next iteration
        frequent = {
            c for c, count in candidate_counts.items() if count >= min_support
        }

        # Record this level's frequent itemsets, sorted for determinism.
        all_frequents.extend(
            sorted(
                (sorted(c), count)
                for c, count in candidate_counts.items()
                if count >= min_support
            )
        )

        k += 1

    return all_frequents
100113

101114

102115
if __name__ == "__main__":

0 commit comments

Comments
 (0)