3 min read

Elastic Net & Lasso

算法应用–携程客户流失预测

本次弹性网络（Elastic Net）与 Lasso 所用的数据依然是携程客户流失数据。

载入数据及训练好的模型

library(glmnet)
library(tidyverse)
library(pROC)
library(broom)

# Load the train/test data and the previously fitted models from one data
# directory (change `data_dir` to point at another checkout).
# Each file is read with load(), so it must have been written with save();
# the objects it restores (df_train, df_test, enet_md_cv, enet_md, grid_lg,
# lasso_md_cv, lasso_md, lasso_md_scale are used below) keep their saved
# names -- which file holds which object is presumed from the file names.
data_dir <- "~/Documents/GitHub/customer_loss/data"
data_files <- c(
    "df_train.RDs",
    "df_test.RDs",
    "enet_md_cv.model",
    "enet_md.model",
    "grid_lg.RDs",
    "lasso_md_cv.model",
    "lasso_md.model",
    "lasso_md_scale.model"
)
for (data_file in data_files) {
    # load() restores into parent.frame(); inside a top-level for loop that is
    # the global environment. A purrr::walk()/lapply() wrapper would instead
    # load into its own throw-away function frame.
    load(file.path(data_dir, data_file))
}

训练模型过程

# # 用于在alpha固定的情况下,对lambda进行交叉验证,并输出相应的结果
# get_enet <- function(input, response, alpha,method = "linear") {
#     if(method == "logistic"){
#         enet <- cv.glmnet(x = input, y = response,family = "binomial", alpha = alpha,type.measure = "auc")
#     }else{
#         enet <- cv.glmnet(x = input, y = response, alpha = alpha)
#     }
#     list(
#         alpha = alpha,
#         # 交叉验证表现最优的lambda（误差类指标取最小；type.measure = "auc" 时取AUC最大）
#         lambda_min = enet$lambda.min,
#         # lambda = lambda.min时非0系数的个数
#         coef_min = with(enet, nzero[lambda == lambda.min]),
#         # lambda = lambda.min 时的平均交叉验证指标 cvm（logistic 且 type.measure = "auc" 时为 AUC，并非均方误差）
#         mse_min = with(enet, cvm[lambda == lambda.min]),
#         # lambda.1se：交叉验证指标与最优值相差不超过 1 个标准误的最大 lambda（对应更稀疏的模型）
#         lambda_1se = enet$lambda.1se,
#         coef = with(enet, nzero[lambda == lambda.1se]),
#         mse_1se = with(enet, cvm[lambda == lambda.1se])
#     )
# }
# 
# alphas = seq(0, 1, .1)
# 
# (grid_lg <- map_df(
#     alphas,
#     ~ get_enet(
#         df_train[, -1] %>% as.matrix(),
#         df_train$label,
#         alpha = .x,
#         method = "logistic"
#         )
#     ) %>% 
#     arrange(mse_1se))

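# # 输出最优alpha和lambda
# # 注意：logistic 分支使用 type.measure = "auc"，此时 mse_min / mse_1se 实为交叉验证 AUC（越大越好），
# # 而上面的 arrange(mse_1se) 为升序排序，grid_lg$alpha[1] 取到的是 AUC 最低的 alpha；应改为 arrange(desc(mse_1se))。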
# best_alpha_lg <- grid_lg$alpha[1]
# best_lambda_lg <- grid_lg$lambda_min[1]
# 
# # 训练模型
# enet_md <- glmnet(
#     x = df_train[, -1] %>% as.matrix(),
#     y = df_train$label,
#     family = "binomial",
#     alpha = best_alpha_lg,
#     lambda = best_lambda_lg
# )

# lasso_md_cv <- cv.glmnet(
#     x = df_train[, -1] %>% as.matrix(),
#     y = df_train$label,
#     alpha = 1,
#     family = "binomial"
#     )
# 
# lasso_md <- glmnet(
#     x = df_train[, -1] %>% as.matrix(),
#     y = df_train$label,
#     alpha = 1,
#     family = "binomial",
#     lambda = best_lambda_lasso
#     )
# 
# lasso_md_scale <- glmnet(
#     x = df_train_scale_x %>% as.matrix(),
#     y = df_train_scale_y,
#     alpha = 1,
#     family = "binomial",
#     lambda = best_lambda_lasso
# )

Elastic Net 模型

# Number of non-zero coefficients in the cross-validated elastic-net fit
# (enet_md_cv) at its lambda.min. The result is a named integer whose name is
# glmnet's step label for that lambda (e.g. `s98`).
enet_md_cv$nzero[enet_md_cv$lambda == enet_md_cv$lambda.min]
## s98 
##  34
# Predictions of the elastic-net model (type = "response"; the commented-out
# training code fits it with family = "binomial", so these are probabilities)
# for the training and test sets, flattened to plain numeric vectors.
# Column 1 is dropped from both data frames -- presumably `label`, so the
# remaining columns match the model's predictors (TODO confirm).
pred_train_enet <- as.vector(
    predict(enet_md, newx = as.matrix(df_train[, -1]), type = "response")
)
pred_test_enet <- as.vector(
    predict(enet_md, newx = as.matrix(df_test[, -1]), type = "response")
)

# enet ROC & AUC (training set)
# Build the ROC curve once and reuse it for both the AUC and the plot.
# direction = "<" states explicitly that a higher predicted probability means
# the positive (case) class, instead of letting pROC auto-pick the direction
# per data set. Assumes glmnet's response probability and pROC's default
# "case" level both refer to the second level of `label` -- TODO confirm.
roc_train_enet <- roc(df_train$label, pred_train_enet, direction = "<")
auc(roc_train_enet)
## Area under the curve: 0.7005
plot(roc_train_enet, col = "blue", ylab = "enet_train_sensitivity")

# enet ROC & AUC (test set)
# Build the ROC curve once and reuse it for both the AUC and the plot.
# direction = "<" is fixed rather than auto-detected: choosing the direction
# from the held-out data itself can make the test AUC look better than it is.
# Assumes glmnet's response probability and pROC's default "case" level both
# refer to the second level of `label` -- TODO confirm.
roc_test_enet <- roc(df_test$label, pred_test_enet, direction = "<")
auc(roc_test_enet)
## Area under the curve: 0.7014
plot(roc_test_enet, col = "blue", ylab = "enet_test_sensitivity")

Lasso 模型

# Number of non-zero coefficients in the cross-validated lasso fit
# (lasso_md_cv) at its lambda.min. The result is a named integer whose name is
# glmnet's step label for that lambda (e.g. `s54`).
lasso_md_cv$nzero[lasso_md_cv$lambda == lasso_md_cv$lambda.min]
## s54 
##  29
# Predictions of the lasso model (type = "response"; the commented-out
# training code fits it with family = "binomial", so these are probabilities)
# for the training and test sets, flattened to plain numeric vectors.
# Column 1 is dropped from both data frames -- presumably `label`, so the
# remaining columns match the model's predictors (TODO confirm).
pred_train_lasso <- as.vector(
    predict(lasso_md, newx = as.matrix(df_train[, -1]), type = "response")
)
pred_test_lasso <- as.vector(
    predict(lasso_md, newx = as.matrix(df_test[, -1]), type = "response")
)

# lasso ROC & AUC (training set)
# Build the ROC curve once and reuse it for both the AUC and the plot.
# direction = "<" states explicitly that a higher predicted probability means
# the positive (case) class, instead of letting pROC auto-pick the direction
# per data set. Assumes glmnet's response probability and pROC's default
# "case" level both refer to the second level of `label` -- TODO confirm.
roc_train_lasso <- roc(df_train$label, pred_train_lasso, direction = "<")
auc(roc_train_lasso)
## Area under the curve: 0.7007
plot(roc_train_lasso, col = "blue", ylab = "lasso_train_sensitivity")

# lasso ROC & AUC (test set)
# Build the ROC curve once and reuse it for both the AUC and the plot.
# direction = "<" is fixed rather than auto-detected: choosing the direction
# from the held-out data itself can make the test AUC look better than it is.
# Assumes glmnet's response probability and pROC's default "case" level both
# refer to the second level of `label` -- TODO confirm.
roc_test_lasso <- roc(df_test$label, pred_test_lasso, direction = "<")
auc(roc_test_lasso)
## Area under the curve: 0.7016
# This is the TEST-set curve, so label the axis accordingly (it previously
# said "lasso_train_sensitivity" by copy-paste).
plot(roc_test_lasso, col = "blue", ylab = "lasso_test_sensitivity")

数据标准化后的 Lasso 模型（变量重要性）

# Variable importance for lasso_md_scale: the 20 non-intercept coefficients
# with the largest absolute value (ties at the cut-off are all kept).
# coef() returns a sparse matrix; broom's tidy() method for it is deprecated
# (it emitted deprecation warnings), so the matrix is reshaped with
# tibble/tidyr instead. Zero coefficients are dropped to match the
# non-zero-only listing tidy() produced, and the result is returned as a
# plain data.frame so it prints in the same layout as before.
coef(lasso_md_scale) %>%
    as.matrix() %>%
    as_tibble(rownames = "row") %>%
    pivot_longer(-row, names_to = "column", values_to = "value") %>%
    filter(row != "(Intercept)", value != 0) %>%
    mutate(wt = abs(value)) %>%
    slice_max(wt, n = 20) %>%
    arrange(desc(wt)) %>%
    as.data.frame()
##                  row column       value         wt
## 1          intervals     s0 -0.30010403 0.30010403
## 2   ordernum_oneyear     s0  0.24528873 0.24528873
## 3      iforderpv_24h     s0  0.24143794 0.24143794
## 4                 cr     s0  0.23502623 0.23502623
## 5   visitnum_oneyear     s0 -0.23030131 0.23030131
## 6                  h     s0 -0.18201748 0.18201748
## 7         cityorders     s0  0.08861190 0.08861190
## 8            hotelcr     s0  0.08389900 0.08389900
## 9        lowestprice     s0 -0.08179874 0.08179874
## 10        cancelrate     s0  0.07669428 0.07669428
## 11      delta_price2     s0  0.07192665 0.07192665
## 12           hoteluv     s0 -0.06455523 0.06455523
## 13            cr_pre     s0  0.06367519 0.06367519
## 14 businessrate_pre2     s0  0.06320619 0.06320619
## 15               sid     s0 -0.05927907 0.05927907
## 16  lowestprice_pre2     s0 -0.05612325 0.05612325
## 17         lastpvgap     s0  0.05231896 0.05231896
## 18     ctrip_profits     s0  0.04590612 0.04590612
## 19          avgprice     s0 -0.04420086 0.04420086
## 20      novoters_pre     s0  0.03913176 0.03913176