算法应用–携程客户流失预测
本次弹性网络(Elastic Net)和 Lasso 所使用的数据仍为携程客户流失数据。
载入数据及训练好的模型
library(glmnet)
library(tidyverse)
library(pROC)
library(broom)
# Restore the saved train/test data, the tuning grid and the fitted models.
# NOTE(review): load() reads files written by save(), so these .RDs/.model
# files are presumably save() images rather than saveRDS() output - confirm.
# `envir = environment()` is evaluated here at top level, so the objects are
# restored into the same environment a bare load() call would use.
invisible(lapply(
  c(
    "~/Documents/GitHub/customer_loss/data/df_train.RDs",
    "~/Documents/GitHub/customer_loss/data/df_test.RDs",
    "~/Documents/GitHub/customer_loss/data/enet_md_cv.model",
    "~/Documents/GitHub/customer_loss/data/enet_md.model",
    "~/Documents/GitHub/customer_loss/data/grid_lg.RDs",
    "~/Documents/GitHub/customer_loss/data/lasso_md_cv.model",
    "~/Documents/GitHub/customer_loss/data/lasso_md.model",
    "~/Documents/GitHub/customer_loss/data/lasso_md_scale.model"
  ),
  load,
  envir = environment()
))
训练模型过程
# # With alpha held fixed, cross-validate over lambda and return a summary of the result.
# get_enet <- function(input, response, alpha,method = "linear") {
# if(method == "logistic"){
# enet <- cv.glmnet(x = input, y = response,family = "binomial", alpha = alpha,type.measure = "auc")
# }else{
# enet <- cv.glmnet(x = input, y = response, alpha = alpha)
# }
# list(
# alpha = alpha,
# # lambda with the best mean cross-validated measure
# lambda_min = enet$lambda.min,
# # number of non-zero coefficients at lambda = lambda.min
# coef_min = with(enet, nzero[lambda == lambda.min]),
# # mean cross-validated measure (cvm) at lambda = lambda.min
# # NOTE(review): with method = "logistic" the measure is AUC (type.measure = "auc"), not MSE, despite the mse_* names.
# mse_min = with(enet, cvm[lambda == lambda.min]),
# # lambda.1se: per the glmnet docs, the largest lambda whose error is within one standard error of the minimum
# lambda_1se = enet$lambda.1se,
# # NOTE(review): named `coef`, unlike coef_min / mse_1se; presumably meant to be coef_1se.
# coef = with(enet, nzero[lambda == lambda.1se]),
# mse_1se = with(enet, cvm[lambda == lambda.1se])
# )
# }
#
# alphas = seq(0, 1, .1)
#
# (grid_lg <- map_df(
# alphas,
# ~ get_enet(
# df_train[, -1] %>% as.matrix(),
# df_train$label,
# alpha = .x,
# method = "logistic"
# )
# ) %>%
# # NOTE(review): arrange() sorts ascending; mse_1se holds AUC here, so the lowest AUC comes first - confirm intent.
# arrange(mse_1se))
# # Take alpha and lambda from the first row of the sorted grid
# # NOTE(review): rows are ranked by mse_1se, but the lambda taken is lambda_min - confirm.
# best_alpha_lg <- grid_lg$alpha[1]
# best_lambda_lg <- grid_lg$lambda_min[1]
#
# # Fit the models
# enet_md <- glmnet(
# x = df_train[, -1] %>% as.matrix(),
# y = df_train$label,
# family = "binomial",
# alpha = best_alpha_lg,
# lambda = best_lambda_lg
# )
# lasso_md_cv <- cv.glmnet(
# x = df_train[, -1] %>% as.matrix(),
# y = df_train$label,
# alpha = 1,
# family = "binomial"
# )
#
# # NOTE(review): best_lambda_lasso is not defined in this listing; presumably lasso_md_cv$lambda.min - confirm.
# lasso_md <- glmnet(
# x = df_train[, -1] %>% as.matrix(),
# y = df_train$label,
# alpha = 1,
# family = "binomial",
# lambda = best_lambda_lasso
# )
#
# # NOTE(review): df_train_scale_x / df_train_scale_y are not defined in this listing; presumably the standardized predictors and labels.
# lasso_md_scale <- glmnet(
# x = df_train_scale_x %>% as.matrix(),
# y = df_train_scale_y,
# alpha = 1,
# family = "binomial",
# lambda = best_lambda_lasso
# )
Elastic Net 模型
# Number of non-zero coefficients reported by the cross-validated
# elastic-net fit at lambda.min.
enet_md_cv$nzero[enet_md_cv$lambda == enet_md_cv$lambda.min]
## s98
## 34
# Response-scale predictions of the elastic-net model on the training and
# test sets, flattened to plain vectors. The first column of each data frame
# is dropped before prediction (presumably the label column - confirm).
pred_train_enet <- as.vector(
  predict(enet_md, newx = as.matrix(df_train[, -1]), type = "response")
)
pred_test_enet <- as.vector(
  predict(enet_md, newx = as.matrix(df_test[, -1]), type = "response")
)
# Elastic net ROC & AUC (training set).
# The roc object is built once and reused by both auc() and plot(), instead
# of recomputing the curve for each call.
roc_train_enet <- roc(df_train$label, pred_train_enet)
auc(roc_train_enet)
## Area under the curve: 0.7005
plot(roc_train_enet, col = "blue", ylab = "enet_train_sensitivity")
# Elastic net ROC & AUC (test set).
roc_test_enet <- roc(df_test$label, pred_test_enet)
auc(roc_test_enet)
## Area under the curve: 0.7014
plot(roc_test_enet, col = "blue", ylab = "enet_test_sensitivity")
Lasso 模型
# Number of non-zero coefficients reported by the cross-validated lasso fit
# at lambda.min.
lasso_md_cv$nzero[lasso_md_cv$lambda == lasso_md_cv$lambda.min]
## s54
## 29
# Response-scale predictions of the lasso model on the training and test
# sets, flattened to plain vectors. The first column of each data frame is
# dropped before prediction (presumably the label column - confirm).
pred_train_lasso <- as.vector(
  predict(lasso_md, newx = as.matrix(df_train[, -1]), type = "response")
)
pred_test_lasso <- as.vector(
  predict(lasso_md, newx = as.matrix(df_test[, -1]), type = "response")
)
# Lasso ROC & AUC (training set).
# The roc object is built once and reused by both auc() and plot(), instead
# of recomputing the curve for each call.
roc_train_lasso <- roc(df_train$label, pred_train_lasso)
auc(roc_train_lasso)
## Area under the curve: 0.7007
plot(roc_train_lasso, col = "blue", ylab = "lasso_train_sensitivity")
# Lasso ROC & AUC (test set).
roc_test_lasso <- roc(df_test$label, pred_test_lasso)
auc(roc_test_lasso)
## Area under the curve: 0.7016
# The axis label previously said "train" on this test-set plot (copy-paste
# slip); it now matches the data being plotted.
plot(roc_test_lasso, col = "blue", ylab = "lasso_test_sensitivity")
数据标准化后的 Lasso 模型
# Variable importance for the lasso fitted on standardized data: the 20
# coefficients with the largest absolute value, strongest first.
# broom's sparse-matrix tidiers (tidy.dgCMatrix / tidy.dgTMatrix) are
# deprecated, so the long-format (row, column, value) table is built directly
# from the coefficient matrix instead of going through tidy().
lasso_scale_coef <- as.matrix(coef(lasso_md_scale))
data.frame(
  row = rep(rownames(lasso_scale_coef), times = ncol(lasso_scale_coef)),
  column = rep(colnames(lasso_scale_coef), each = nrow(lasso_scale_coef)),
  value = as.vector(lasso_scale_coef),
  stringsAsFactors = FALSE
) %>%
  # tidy() on a sparse matrix listed only the non-zero entries; keep that
  # behavior so coefficients shrunk to zero never enter the ranking.
  filter(value != 0, row != "(Intercept)") %>%
  top_n(20, wt = abs(value)) %>%
  mutate(wt = abs(value)) %>%
  arrange(-wt)
## Warning: 'tidy.dgCMatrix' is deprecated.
## See help("Deprecated")
## Warning: 'tidy.dgTMatrix' is deprecated.
## See help("Deprecated")
## row column value wt
## 1 intervals s0 -0.30010403 0.30010403
## 2 ordernum_oneyear s0 0.24528873 0.24528873
## 3 iforderpv_24h s0 0.24143794 0.24143794
## 4 cr s0 0.23502623 0.23502623
## 5 visitnum_oneyear s0 -0.23030131 0.23030131
## 6 h s0 -0.18201748 0.18201748
## 7 cityorders s0 0.08861190 0.08861190
## 8 hotelcr s0 0.08389900 0.08389900
## 9 lowestprice s0 -0.08179874 0.08179874
## 10 cancelrate s0 0.07669428 0.07669428
## 11 delta_price2 s0 0.07192665 0.07192665
## 12 hoteluv s0 -0.06455523 0.06455523
## 13 cr_pre s0 0.06367519 0.06367519
## 14 businessrate_pre2 s0 0.06320619 0.06320619
## 15 sid s0 -0.05927907 0.05927907
## 16 lowestprice_pre2 s0 -0.05612325 0.05612325
## 17 lastpvgap s0 0.05231896 0.05231896
## 18 ctrip_profits s0 0.04590612 0.04590612
## 19 avgprice s0 -0.04420086 0.04420086
## 20 novoters_pre s0 0.03913176 0.03913176