此问题即对一个人是否与女生约会做预测,基本属性是:是否有钱,是否年轻,是否美丽。训练出来模型可以对输入的数据进行分类,最终输出一个人的偏好图。输入时是用的表格导入。
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
import copy
import os
import dtreeviz
# 指定结果保存的目录,确保结果文件能够存放在指定路径下
result_dir = '/Users/winter/Documents/VSCode/Python/DateOrNot/result'
if not os.path.exists(result_dir):
os.makedirs(result_dir) # 创建目录,如果目录已存在则不进行任何操作
# 从 Excel 文件中读取问卷和标准数据
date_download = pd.read_excel(r'/Users/winter/Documents/VSCode/Python/DateOrNot/机器学习原理-22级选修-周一-问卷统计详情.xlsx'
, sheet_name='问卷')
standard_download = pd.read_excel(r'/Users/winter/Documents/VSCode/Python/DateOrNot/YesOrNo.xlsx'
, sheet_name=0)
# 名字列表,从问卷数据中提取参与者的名字
namelist = date_download.iloc[5:, 0].dropna().values.tolist()
# 定义需要处理的列和对应的映射字典
cols = ['相貌', '见不见'] # 需要处理的列名
col_dicts = {
'相貌': {'佳': 3, '普通': 2, '差': 1}, # 将相貌进行数值化映射
'见不见': {'A': 1, 'B': 0} # 将见不见进行数值化映射
}
def processing(name):
"""
处理指定参与者的问卷数据,训练决策树模型,并保存结果。
参数:
name (str): 参与者的名字
"""
global date_download
standard_date = copy.deepcopy(standard_download) # 复制标准数据,以便后续填充
a = int()
pdi = standard_date # 用于存放处理后的数据
# 找到对应名字的行索引
for i in range(date_download.shape[0]):
if date_download.iloc[i, 0] == name:
a = i
# 填充数据,使用问卷数据中的相关信息
for i in range(len(standard_date.index)):
pdi.iloc[i, 4] = date_download.iloc[a, 8 + i] # 将问卷中相应的数据填充到标准数据中
# 映射相貌和见不见的列
for col in cols:
pdi[col] = pdi[col].map(col_dicts[col]) # 将字符串数据映射为数值数据
# 保存处理后的数据到 Excel 文件
pdi.to_excel(os.path.join(result_dir, f'{name}.xlsx'), index=False)
# 准备特征和标签
X = pdi.loc[:, ['工资', '年龄', '相貌']] # 特征数据
Y = pdi.loc[:, '见不见'] # 标签数据
# 更改特征列名以便于理解
X.columns = ['salary', 'age', 'face']
print(X) # 输出特征数据以供检查
if Y.isnull().sum() == 0: # 检查标签数据是否为空
print(Y) # 输出标签数据以供检查
# 创建决策树分类模型
credit_model = DecisionTreeClassifier(criterion='entropy', max_depth=3)
entropy_tree = credit_model.fit(X.values, Y) # 训练模型
# 可视化决策树模型
viz = dtreeviz.model(entropy_tree,
X_train=X,
y_train=Y,
target_name='Yes or No',
feature_names=['salary', 'age', 'face'],
class_names=['No', 'Yes'])
v = viz.view() # 显示可视化结果
v.show() # 展示可视化图形
v.save(os.path.join(result_dir, f'result_{name}.svg')) # 保存可视化结果为 SVG 文件
if __name__ == '__main__':
while len(namelist): # 循环直到处理完所有名单
print('名单序列:', namelist) # 输出当前名单序列
name = input('输入所要计算的人员名字(结束输入q,输出所有a):') # 接受用户输入
if name.lower() == 'q': # 如果用户输入 q,退出循环
break
elif name == 'a': # 如果用户输入 a,处理所有名单
for i in namelist:
print(i) # 输出当前处理的名字
processing(i) # 处理当前名字
break
elif name in namelist: # 如果输入的名字在名单中
namelist.remove(name) # 从名单中移除该名字
processing(name) # 处理该名字
else: # 如果输入的名字不在名单中
print('输入在名单中的名字!') # 提示用户输入有效的名字