数据预处理

In [2]:
import pandas as pd
import numpy as np
In [4]:
data = pd.read_csv('./input/dress-recommendation.csv')
data.head()
Out[4]:
Dress_ID Style Price Rating Size Season NeckLine SleeveLength waiseline Material FabricType Decoration Pattern Type Recommendation
0 1006032852 Sexy Low 4.6 M Summer o-neck sleevless empire null chiffon ruffles animal 1
1 1212192089 Casual Low 0.0 L Summer o-neck Petal natural microfiber null ruffles animal 0
2 1190380701 vintage High 0.0 L Automn o-neck full natural polyster null null print 0
3 966005983 Brief Average 4.6 L Spring o-neck full natural silk chiffon embroidary print 1
4 876339541 cute Low 4.5 M Summer o-neck butterfly natural chiffonfabric chiffon bow dot 0
In [3]:
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 500 entries, 0 to 499
Data columns (total 14 columns):
Dress_ID          500 non-null int64
Style             500 non-null object
Price             498 non-null object
Rating            500 non-null float64
Size              500 non-null object
Season            498 non-null object
NeckLine          497 non-null object
SleeveLength      498 non-null object
waiseline         499 non-null object
Material          499 non-null object
FabricType        499 non-null object
Decoration        499 non-null object
Pattern Type      499 non-null object
Recommendation    500 non-null int64
dtypes: float64(1), int64(2), object(11)
memory usage: 54.8+ KB
In [6]:
data['Style'].value_counts()
Out[6]:
Casual      232
Sexy         69
party        51
cute         45
vintage      25
bohemian     24
Brief        18
work         17
Novelty       8
sexy          7
Flare         2
OL            1
fashion       1
Name: Style, dtype: int64

特征编码

In [23]:
dummpy_list = ['Style', 'Price', 'Size', 'Season', 'NeckLine', 'SleeveLength', 'waiseline','Material', 'FabricType','Decoration','Pattern Type']
transfer_list = [pd.get_dummies(data[[feature]], prefix = feature+'_') for feature in dummpy_list]
In [27]:
transfer_list.append(data[['Rating','Recommendation']])
In [30]:
new_data = pd.concat(transfer_list, axis = 1)

数据标签的均衡性

In [31]:
new_data['Recommendation'].value_counts()
Out[31]:
0    290
1    210
Name: Recommendation, dtype: int64

自定义评价函数

In [38]:
def evaluate(pred,test_y):
    
    import seaborn as sns
    import matplotlib.pyplot as plt
    import sklearn.metrics as metrics
    # 输出分类的准确率
    print("Accuracy: %.4f"  % (metrics.accuracy_score(test_y,pred)))
    
    # 输出衡量分类效果的各项指标
    print(metrics.classification_report(test_y, pred)) 
    
    # 更直观的,我们通过seaborn画出混淆矩阵
    %matplotlib inline
    plt.figure(figsize=(6,4))
    colorMetrics = metrics.confusion_matrix(test_y,pred)
    
    # 坐标y代表test_y,即真实的类别,坐标x代表估计出的类别pred
    sns.heatmap(colorMetrics, annot=True, fmt='d', xticklabels=[0,1],yticklabels=[0,1])
    sns.plt.show()

高斯混合模型

In [32]:
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(new_data.iloc[:,:-1], new_data['Recommendation'], test_size = .3, random_state = 0)

协方差矩阵的取值为: {‘full’, ‘tied’, ‘diag’, ‘spherical’},默认取值为full

  • 'full'

高斯混合模型的每个组成部分(component)都有自己独立的协方差矩阵

  • 'tied'

高斯混合模型的所有组成部分(component)都分享相同的协方差矩阵

  • 'diag'

高斯混合模型的每个组成部分(component)都有自己的对角化的协方差矩阵

  • 'spherical'

斯混合模型的每个组成部分(component)拥有一个方差

In [65]:
cov_type = ['full', 'tied', 'diag', 'spherical']

model_list = [GaussianMixture(n_components = 2, covariance_type = cov_name, init_params='random', max_iter=20) for cov_name in cov_type]
In [68]:
# 标签
reco_label = np.unique(y_train)

mean_list = np.array([X_train[y_train == value].mean(axis = 0) for value in reco_label])
mean_list
Out[68]:
array([[ 0.04477612,  0.47761194,  0.00497512,  0.01492537,  0.00497512,
         0.14925373,  0.04477612,  0.07462687,  0.        ,  0.05472637,
         0.00995025,  0.07462687,  0.04477612,  0.55721393,  0.01492537,
         0.25373134,  0.039801  ,  0.039801  ,  0.08457711,  0.00995025,
         0.21890547,  0.33333333,  0.05472637,  0.039801  ,  0.35323383,
         0.        ,  0.        ,  0.14427861,  0.0199005 ,  0.14925373,
         0.37313433,  0.22885572,  0.00497512,  0.        ,  0.07960199,
         0.        ,  0.01492537,  0.00497512,  0.02487562,  0.02985075,
         0.        ,  0.00497512,  0.58208955,  0.00497512,  0.01492537,
         0.        ,  0.04975124,  0.0199005 ,  0.00497512,  0.03482587,
         0.20895522,  0.        ,  0.        ,  0.00497512,  0.00497512,
         0.15920398,  0.00497512,  0.07960199,  0.24875622,  0.00497512,
         0.01492537,  0.41293532,  0.        ,  0.03482587,  0.        ,
         0.02487562,  0.        ,  0.00497512,  0.00995025,  0.17412935,
         0.64676617,  0.16915423,  0.        ,  0.00995025,  0.00497512,
         0.06965174,  0.30348259,  0.00497512,  0.00497512,  0.01492537,
         0.00995025,  0.00497512,  0.01492537,  0.02985075,  0.00497512,
         0.        ,  0.19402985,  0.02487562,  0.        ,  0.21393035,
         0.00995025,  0.00497512,  0.04477612,  0.00497512,  0.01492537,
         0.00497512,  0.00497512,  0.00497512,  0.        ,  0.05472637,
         0.26865672,  0.00497512,  0.00497512,  0.00497512,  0.0199005 ,
         0.        ,  0.00497512,  0.        ,  0.56716418,  0.        ,
         0.        ,  0.        ,  0.        ,  0.01492537,  0.01492537,
         0.        ,  0.        ,  0.00497512,  0.        ,  0.02985075,
         0.00497512,  0.02985075,  0.02487562,  0.02487562,  0.01492537,
         0.00497512,  0.00497512,  0.00497512,  0.00995025,  0.        ,
         0.00497512,  0.05472637,  0.11940299,  0.00497512,  0.47761194,
         0.00497512,  0.00497512,  0.        ,  0.01492537,  0.01492537,
         0.00995025,  0.04477612,  0.10447761,  0.01492537,  0.        ,
         0.039801  ,  0.00497512,  0.00995025,  0.00995025,  0.00497512,
         0.        ,  0.00497512,  0.        ,  0.16915423,  0.09452736,
         0.00995025,  0.15422886,  0.44278607,  0.00497512,  0.04975124,
         3.49353234],
       [ 0.02684564,  0.42281879,  0.        ,  0.        ,  0.        ,
         0.1409396 ,  0.06711409,  0.11409396,  0.        ,  0.16778523,
         0.01342282,  0.03355705,  0.01342282,  0.44295302,  0.00671141,
         0.26845638,  0.08724832,  0.02684564,  0.08053691,  0.0738255 ,
         0.16778523,  0.3557047 ,  0.08053691,  0.02013423,  0.37583893,
         0.        ,  0.        ,  0.10067114,  0.00671141,  0.34899329,
         0.25503356,  0.18791946,  0.00671141,  0.        ,  0.08724832,
         0.        ,  0.05369128,  0.        ,  0.06040268,  0.01342282,
         0.        ,  0.        ,  0.47651007,  0.00671141,  0.        ,
         0.00671141,  0.04697987,  0.        ,  0.        ,  0.01342282,
         0.30872483,  0.        ,  0.        ,  0.00671141,  0.00671141,
         0.24832215,  0.        ,  0.04697987,  0.12751678,  0.00671141,
         0.00671141,  0.48322148,  0.        ,  0.02684564,  0.00671141,
         0.02013423,  0.00671141,  0.        ,  0.01342282,  0.22818792,
         0.52348993,  0.22818792,  0.        ,  0.        ,  0.01342282,
         0.04697987,  0.2885906 ,  0.        ,  0.        ,  0.        ,
         0.00671141,  0.00671141,  0.00671141,  0.02684564,  0.        ,
         0.00671141,  0.30872483,  0.02684564,  0.        ,  0.14765101,
         0.03355705,  0.00671141,  0.05369128,  0.        ,  0.00671141,
         0.00671141,  0.        ,  0.        ,  0.00671141,  0.04697987,
         0.27516779,  0.00671141,  0.        ,  0.        ,  0.03355705,
         0.00671141,  0.        ,  0.        ,  0.4966443 ,  0.00671141,
         0.        ,  0.00671141,  0.        ,  0.00671141,  0.02684564,
         0.00671141,  0.00671141,  0.        ,  0.00671141,  0.05369128,
         0.        ,  0.04697987,  0.06040268,  0.03355705,  0.01342282,
         0.        ,  0.        ,  0.        ,  0.02013423,  0.01342282,
         0.01342282,  0.01342282,  0.16107383,  0.00671141,  0.4295302 ,
         0.        ,  0.00671141,  0.00671141,  0.01342282,  0.        ,
         0.00671141,  0.03355705,  0.06711409,  0.04026846,  0.00671141,
         0.04697987,  0.        ,  0.04026846,  0.        ,  0.        ,
         0.00671141,  0.00671141,  0.00671141,  0.27516779,  0.10067114,
         0.        ,  0.12080537,  0.36912752,  0.        ,  0.02013423,
         3.75436242]])

评价高斯混合模型

In [69]:
for gmm_model in model_list:
    
    gmm_model.means_ = mean_list
    gmm_model.fit(X_train)
    
    y_pred = gmm_model.predict(X_test)
    evaluate(y_test, y_pred)
Accuracy: 0.5400
             precision    recall  f1-score   support

          0       0.71      0.59      0.65       106
          1       0.30      0.41      0.34        44

avg / total       0.59      0.54      0.56       150

Accuracy: 0.5200
             precision    recall  f1-score   support

          0       0.55      0.60      0.58        81
          1       0.48      0.42      0.45        69

avg / total       0.52      0.52      0.52       150

Accuracy: 0.4267
             precision    recall  f1-score   support

          0       0.21      0.54      0.31        35
          1       0.74      0.39      0.51       115

avg / total       0.62      0.43      0.46       150

Accuracy: 0.4467
             precision    recall  f1-score   support

          0       0.27      0.57      0.37        42
          1       0.70      0.40      0.51       108

avg / total       0.58      0.45      0.47       150