# 首页 > it技术
import numpy as np
import pandas as pd## 方法一:
def generate_C1(data_set):
    """Build the candidate 1-item sets (C1) from a collection of transactions.

    Args:
        data_set: Iterable of transactions, each an iterable of hashable items.

    Returns:
        A set containing one single-element ``frozenset`` per distinct item
        that appears in any transaction.
    """
    return {frozenset([item]) for items in data_set for item in items}


# Compute the support of Ck in data set D and return the item sets whose
# support reaches minSupport.
def get_supports(data_set, ck, min_support, supports):
    """Return the candidates in ``ck`` whose support reaches ``min_support``.

    Support of a candidate is the fraction of transactions in ``data_set``
    that contain every item of the candidate.

    Args:
        data_set: Sequence of transactions (each an iterable of hashable items).
        ck: Iterable of candidate item sets (``frozenset`` objects).
        min_support: Minimum support, compared with ``>=``.
        supports: Dict updated in place; each frequent item set is mapped to
            its support value.

    Returns:
        The set of candidates from ``ck`` that are frequent.
    """
    item_count = {}
    for data in data_set:
        # Build the transaction's set once rather than once per candidate.
        transaction = set(data)
        for item in ck:
            if item.issubset(transaction):
                item_count[item] = item_count.get(item, 0) + 1

    n = float(len(data_set))
    freq_set = set()
    for item, count in item_count.items():
        support = count / n
        if support >= min_support:
            freq_set.add(item)
            supports[item] = support
    return freq_set


# Pruning: build the candidate item sets for the next level.
def get_new_set(d, k):
    """Join pairs of frequent item sets into candidate item sets.

    Two item sets are merged (set union) when the first ``k - 2`` elements of
    their sorted item lists are equal.

    Args:
        d: Collection of frequent item sets (``frozenset`` objects) whose
            items can be sorted.
        k: Level being generated; controls how long the shared prefix must be.

    Returns:
        A set with the union of every qualifying pair.
    """
    from itertools import combinations

    # Sort each item set once up front instead of re-sorting for every pair.
    keyed = [(item_set, sorted(item_set)) for item_set in d]
    new_set = set()
    for (set_a, items_a), (set_b, items_b) in combinations(keyed, 2):
        # Merge the two sets when their first k-2 items are the same.
        if items_a[:k - 2] == items_b[:k - 2]:
            new_set.add(set_a | set_b)
    return new_set


def apriori(data_set, min_support, k):
    """Run the level-wise frequent item set search for levels 1 through ``k``.

    Args:
        data_set: Sequence of transactions.
        min_support: Minimum support passed on to ``get_supports``.
        k: Highest level to generate.

    Returns:
        A tuple ``(freq_sets, supports)``: ``freq_sets[i]`` is the set of
        frequent item sets found at level ``i + 1``; ``supports`` is the dict
        filled in by ``get_supports`` across all levels.
    """
    supports = {}
    current = get_supports(data_set, generate_C1(data_set), min_support, supports)
    freq_sets = [current]
    for size in range(2, k + 1):
        candidates = get_new_set(current, size)
        current = get_supports(data_set, candidates, min_support, supports)
        freq_sets.append(current)
    return freq_sets, supports


def find_rules(freq_sets, supports, min_conf):
    """Derive association rules between consecutive levels of frequent sets.

    For every item set at level ``i`` that is contained in an item set at
    level ``i + 1``, the confidence is
    ``supports[superset] / supports[subset]``. Rules whose confidence is at
    least ``min_conf`` are printed and collected.

    Args:
        freq_sets: List of sets of frequent item sets, one entry per level.
        supports: Mapping from item set to support.
        min_conf: Minimum confidence, compared with ``>=``.

    Returns:
        List of ``(antecedent, consequent, support, confidence)`` tuples where
        ``consequent`` is the superset minus the antecedent and ``support`` is
        the support of the superset.
    """
    rules = []
    for smaller_level, larger_level in zip(freq_sets, freq_sets[1:]):
        for freq_set in smaller_level:  # antecedent
            for sub_set in larger_level:  # frequent superset holding the consequent
                if not freq_set.issubset(sub_set):
                    continue
                support = supports[sub_set]
                conf = support / supports[freq_set]  # confidence of the rule
                if conf >= min_conf:
                    consequent = sub_set - freq_set
                    print(list(freq_set), "-->", list(consequent), 'support:', support, 'conf:', conf)
                    rules.append((freq_set, consequent, support, conf))
    return rules


if __name__ == '__main__':
    inputfile = 'G:/untitled/venv/tmp/apriori_rules2.csv'
    data = pd.read_csv(inputfile).drop_duplicates()

    # One transaction per date; keep only dates with at least two rows.
    data_set = []
    for _, frame in data.groupby(by='日期'):
        if len(frame) >= 2:
            data_set.append(frame['股票代码'].tolist())

    L, support = apriori(data_set, k=3, min_support=0.4)
    # Use a distinct name so the find_rules function is not shadowed.
    rules = find_rules(L, support, min_conf=0.6)
    print('关联规则:\n', rules)

    outputfile = 'G:/untitled/venv/tmp/apriori_rules3.csv'
    df = pd.DataFrame(rules)
    df.to_csv(outputfile, index=False, header=False)
# 推荐阅读