Skip to content

Commit 7bb97a8

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent ac19e67 commit 7bb97a8

1 file changed

Lines changed: 6 additions & 13 deletions

File tree

machine_learning/apriori_algorithm.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def load_data() -> list[list[str]]:
2727

2828
# ---------- Helpers ----------
2929

30+
3031
def get_support(itemset, transactions):
3132
"""Compute support count of an itemset efficiently."""
3233
return sum(1 for t in transactions if itemset.issubset(t))
@@ -60,6 +61,7 @@ def has_infrequent_subset(candidate, prev_frequent):
6061

6162
# ---------- Main Apriori ----------
6263

64+
6365
def apriori(data: list[list[str]], min_support: int):
6466
transactions = [set(t) for t in data]
6567

@@ -70,14 +72,11 @@ def apriori(data: list[list[str]], min_support: int):
7072
item_counts[frozenset([item])] += 1
7173

7274
frequent = {
73-
itemset for itemset, count in item_counts.items()
74-
if count >= min_support
75+
itemset for itemset, count in item_counts.items() if count >= min_support
7576
}
7677

7778
all_frequents = [
78-
(next(iter(i)), c)
79-
for i, c in item_counts.items()
80-
if c >= min_support
79+
(next(iter(i)), c) for i, c in item_counts.items() if c >= min_support
8180
]
8281

8382
k = 2
@@ -87,10 +86,7 @@ def apriori(data: list[list[str]], min_support: int):
8786
candidates = generate_candidates(frequent, k)
8887

8988
# 3. prune
90-
candidates = {
91-
c for c in candidates
92-
if not has_infrequent_subset(c, frequent)
93-
}
89+
candidates = {c for c in candidates if not has_infrequent_subset(c, frequent)}
9490

9591
# 4. count support
9692
candidate_counts = defaultdict(int)
@@ -100,10 +96,7 @@ def apriori(data: list[list[str]], min_support: int):
10096
candidate_counts[c] += 1
10197

10298
# 5. filter frequent
103-
frequent = {
104-
c for c, count in candidate_counts.items()
105-
if count >= min_support
106-
}
99+
frequent = {c for c, count in candidate_counts.items() if count >= min_support}
107100

108101
all_frequents.extend(
109102
(sorted(c), count)

0 commit comments

Comments
 (0)