1111Examples: https://www.kaggle.com/code/earthian/apriori-association-rules-mining
1212"""
1313
14- from collections import Counter
1514from itertools import combinations
15+ from collections import defaultdict
1616
1717
1818def load_data () -> list [list [str ]]:
@@ -25,78 +25,91 @@ def load_data() -> list[list[str]]:
2525 return [["milk" ], ["milk" , "butter" ], ["milk" , "bread" ], ["milk" , "bread" , "chips" ]]
2626
2727
28- def prune (itemset : list , candidates : list , length : int ) -> list :
29- """
30- Prune candidate itemsets that are not frequent.
31- The goal of pruning is to filter out candidate itemsets that are not frequent. This
32- is done by checking if all the (k-1) subsets of a candidate itemset are present in
33- the frequent itemsets of the previous iteration (valid subsequences of the frequent
34- itemsets from the previous iteration).
35-
36- Prunes candidate itemsets that are not frequent.
37-
38- >>> itemset = ['X', 'Y', 'Z']
39- >>> candidates = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
40- >>> prune(itemset, candidates, 2)
41- [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
42-
43- >>> itemset = ['1', '2', '3', '4']
44- >>> candidates = ['1', '2', '4']
45- >>> prune(itemset, candidates, 3)
46- []
28+ # ---------- Helpers ----------
29+
30+ def get_support (itemset , transactions ):
31+ """Compute support count of an itemset efficiently."""
32+ return sum (1 for t in transactions if itemset .issubset (t ))
33+
34+
35+ def generate_candidates (prev_frequent , k ):
4736 """
48- itemset_counter = Counter (tuple (item ) for item in itemset )
49- pruned = []
50- for candidate in candidates :
51- is_subsequence = True
52- for item in candidate :
53- item_tuple = tuple (item )
54- if (
55- item_tuple not in itemset_counter
56- or itemset_counter [item_tuple ] < length - 1
57- ):
58- is_subsequence = False
59- break
60- if is_subsequence :
61- pruned .append (candidate )
62- return pruned
63-
64-
65- def apriori (data : list [list [str ]], min_support : int ) -> list [tuple [list [str ], int ]]:
37+ Generate candidate itemsets of size k from frequent itemsets of size k-1.
6638 """
67- Returns a list of frequent itemsets and their support counts.
39+ prev_list = list (prev_frequent )
40+ candidates = set ()
6841
69- >>> data = [['A', 'B', 'C'], ['A', 'B'], ['A', 'C'], ['A', 'D'], ['B', 'C']]
70- >>> apriori(data, 2)
71- [(['A', 'B'], 1), (['A', 'C'], 2), (['B', 'C'], 2)]
42+ for i in range (len (prev_list )):
43+ for j in range (i + 1 , len (prev_list )):
44+ union = prev_list [i ] | prev_list [j ]
45+ if len (union ) == k :
46+ candidates .add (union )
7247
73- >>> data = [['1', '2', '3'], ['1', '2'], ['1', '3'], ['1', '4'], ['2', '3']]
74- >>> apriori(data, 3)
75- []
48+ return candidates
49+
50+
51+ def has_infrequent_subset (candidate , prev_frequent ):
7652 """
77- itemset = [list (transaction ) for transaction in data ]
78- frequent_itemsets = []
79- length = 1
53+ Apriori pruning: all (k-1)-subsets must be frequent.
54+ """
55+ for subset in combinations (candidate , len (candidate ) - 1 ):
56+ if frozenset (subset ) not in prev_frequent :
57+ return True
58+ return False
59+
60+
61+ # ---------- Main Apriori ----------
62+
63+ def apriori (data : list [list [str ]], min_support : int ):
64+ transactions = [set (t ) for t in data ]
65+
66+ # 1. initial 1-itemsets
67+ item_counts = defaultdict (int )
68+ for t in transactions :
69+ for item in t :
70+ item_counts [frozenset ([item ])] += 1
71+
72+ frequent = {
73+ itemset for itemset , count in item_counts .items ()
74+ if count >= min_support
75+ }
76+
77+ all_frequents = [(list (i )[0 ], c ) for i , c in item_counts .items () if c >= min_support ]
78+
79+ k = 2
80+
81+ while frequent :
82+ # 2. generate candidates
83+ candidates = generate_candidates (frequent , k )
84+
85+ # 3. prune
86+ candidates = {
87+ c for c in candidates
88+ if not has_infrequent_subset (c , frequent )
89+ }
8090
81- while itemset :
82- # Count itemset support
83- counts = [0 ] * len (itemset )
84- for transaction in data :
85- for j , candidate in enumerate (itemset ):
86- if all (item in transaction for item in candidate ):
87- counts [j ] += 1
91+ # 4. count support
92+ candidate_counts = defaultdict (int )
93+ for t in transactions :
94+ for c in candidates :
95+ if c .issubset (t ):
96+ candidate_counts [c ] += 1
8897
89- # Prune infrequent itemsets
90- itemset = [item for i , item in enumerate (itemset ) if counts [i ] >= min_support ]
98+ # 5. filter frequent
99+ frequent = {
100+ c for c , count in candidate_counts .items ()
101+ if count >= min_support
102+ }
91103
92- # Append frequent itemsets (as a list to maintain order)
93- for i , item in enumerate (itemset ):
94- frequent_itemsets .append ((sorted (item ), counts [i ]))
104+ all_frequents .extend (
105+ (sorted (list (c )), count )
106+ for c , count in candidate_counts .items ()
107+ if count >= min_support
108+ )
95109
96- length += 1
97- itemset = prune (itemset , list (combinations (itemset , length )), length )
110+ k += 1
98111
99- return frequent_itemsets
112+ return all_frequents
100113
101114
102115if __name__ == "__main__" :
0 commit comments