Lasso 回归分析

1.原理

LASSO,全称 Least Absolute Shrinkage and Selection Operator,是一种带正则化的回归方法:在常用的多元线性回归的损失函数中加入 L1 惩罚项,不断压缩回归系数,从而达到精简模型的目的,以缓解共线性和过拟合。当某个变量的系数被压缩为 0 时,该变量即被剔除出模型,因此同时达到筛选变量的效果。

​ LASSO回归高效解决了筛选变量的难题:区别于传统的逐步回归stepwise前进、后退变量筛选方法,LASSO回归可以利用较少样本量,高效筛选较多变量。比如在基因组学、影像学、以及其他小样本分析中,LASSO回归都可以派上大用场。

2.实战 R代码

添加Times New Roman的字体格式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Register the Times New Roman font files with showtext and build `my_theme`,
# a ggplot2 theme fragment that sets every text element to that family at
# size 16.
library(showtext)
library(sysfonts)

# List the font families known so far (Times New Roman is added further down).
font_families()
# Turn on showtext rendering for graphics devices.
showtext_auto()

# Prefix a font file name with the Windows font directory.
path <- function(x) {
  paste0("C:/Windows/Fonts/", x)
}

font_add(
  family = "Times New Roman",
  regular = path("times.ttf"),
  bold = path("timesbd.ttf"),
  italic = path("timesi.ttf"),
  bolditalic = path("timesbi.ttf")
)

# Every element gets the same explicit family and size, so the element is
# built once and reused; local() keeps the helper out of the global workspace.
my_theme <- local({
  times_16 <- ggplot2::element_text(family = "Times New Roman", size = 16)
  ggplot2::theme(
    text = times_16,
    plot.title = times_16,
    axis.title = times_16,
    axis.text = times_16,
    legend.title = times_16,
    legend.text = times_16,
    strip.text = times_16,
    axis.text.x = times_16,
    axis.text.y = times_16
  )
})

lasso回归建立

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
if (FALSE) {
  # Disabled by default: run the lines interactively (or flip the condition)
  # to fit the models. Bare expressions such as coef(lasso) only print their
  # value when executed line by line.
  library(glmnet)

  # `mldata` is the analysis data set: predictors in all columns but the last,
  # binary outcome in the last column, which is read by name as `outcome`.
  # Guard against the outcome sitting elsewhere, which would leak it into the
  # predictor matrix.
  stopifnot(identical(names(mldata)[ncol(mldata)], "outcome"))

  # Build the design matrix and response BEFORE any model is fitted.
  # seq_len(ncol - 1) states "all but the last column" explicitly;
  # `1:ncol(mldata)-1` only worked because it evaluates to 0:(ncol - 1) and
  # index 0 is silently dropped.
  lasso_x <- as.matrix(mldata[, seq_len(ncol(mldata) - 1L), drop = FALSE])
  lasso_y <- mldata$outcome

  # LASSO path over glmnet's default lambda sequence (no lambda chosen yet).
  lasso <- glmnet(lasso_x, lasso_y, family = "binomial", alpha = 1)
  coef(lasso)
  plot(lasso, xvar = "lambda", label = TRUE, lwd = 2)

  # 5-fold cross-validation to search for the best lambda.
  # Fold assignment is random: call set.seed() beforehand if the selected
  # lambda has to be reproducible.
  lasso_cv <- cv.glmnet(
    lasso_x, lasso_y,
    family = "binomial",
    nfolds = 5,
    alpha = 1,
    standardize = TRUE
  )
  plot(lasso_cv)

  # Take lambda.min as the optimal lambda and refit at that single value.
  best_lambda <- lasso_cv$lambda.min
  best_lambda
  lasso_best <- glmnet(
    lasso_x, lasso_y,
    family = "binomial",
    alpha = 1,
    lambda = best_lambda
  )
  coef(lasso_best)
}

lasso回归优化图形

p2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81

if (FALSE) {
  # Disabled by default: run interactively (or flip the condition) to redraw
  # the cross-validation curve and the coefficient paths with ggplot2.
  # Expects `lasso_cv` (the cv.glmnet fit), `lasso` (the glmnet path fit) and
  # `my_theme` to exist in the workspace already.
  library(ggsci)
  library(tidyverse)
  library(reshape2)
  library(RColorBrewer)

  # ---- Figure 1: cross-validation curve ----
  cvfit <- lasso_cv
  aaa <- data.frame(
    lamda = round(log(cvfit[["lambda"]]), 3),
    cvm = cvfit[["cvm"]],
    cvup = cvfit[["cvup"]],
    cvlo = cvfit[["cvlo"]],
    nzero = as.numeric(cvfit[["nzero"]])
  )
  # Discrete copy of the non-zero count, used for colouring.
  aaa$nzero2 <- factor(aaa$nzero)

  # Secondary (top) x axis: same positions as log(lambda), but labelled with
  # the number of non-zero coefficients. Every 10th lambda is used; the index
  # is derived from the actual path length because glmnet may stop before
  # reaching 100 lambdas, and a fixed seq(1, 100, 10) would then yield NAs.
  cv_idx <- seq(1, nrow(aaa), by = 10)
  xbreaks <- aaa$lamda[cv_idx]
  xlabels <- as.character(aaa$nzero[cv_idx])

  figure3C1 <- ggplot(data = aaa) +
    geom_errorbar(
      aes(x = lamda, ymin = cvlo, ymax = cvup, color = nzero2),
      linewidth = .5, width = 0.1, alpha = 0.5
    ) +
    geom_point(aes(x = lamda, y = cvm, color = nzero2)) +
    scale_color_d3("category20") +
    # top x axis
    scale_x_continuous(
      sec.axis = sec_axis(~., breaks = xbreaks, labels = xlabels)
    ) +
    # `color` sets the legend title
    labs(x = "Log Lambda", y = "Binomial Deviance", color = "vars") +
    # dashed reference lines at lambda.min and lambda.1se
    geom_vline(
      xintercept = log(c(cvfit$lambda.min, cvfit$lambda.1se)),
      lty = 2, col = "grey50"
    ) +
    theme_classic() +
    theme(
      axis.title = element_text(size = 15),
      axis.text = element_text(size = 12),
      axis.line = element_line(linewidth = .5),
      legend.position = c(0.7, 0.6),
      legend.background = element_blank()
    ) +
    # split the legend into 2 columns
    guides(colour = guide_legend(ncol = 2, byrow = FALSE))
  figure3C1

  # ---- Figure 2: coefficient paths ----
  fit <- lasso
  # Fraction of deviance explained at each lambda.
  dev <- fit[["dev.ratio"]]
  # Number of non-zero coefficients at each lambda.
  df <- fit[["df"]]
  # log(lambda)
  lamda <- log(fit[["lambda"]])
  # Coefficient matrix: one row per variable, one column per lambda.
  aaa2 <- as.matrix(fit[["beta"]])
  # Drop variables that are never selected (coefficient is 0 along the whole
  # path). Testing the absolute values keeps variables whose positive and
  # negative coefficients happen to sum to zero; drop = FALSE keeps a matrix
  # even when a single variable survives.
  aaa2 <- aaa2[rowSums(abs(aaa2)) != 0, , drop = FALSE]
  # L1 norm of the coefficient vector at each lambda.
  norm <- colSums(abs(aaa2))

  # Long format for line plotting. melt() walks the matrix column by column,
  # so each per-lambda value is repeated once for every remaining variable.
  bbb <- melt(aaa2)
  bbb$lamda <- rep(lamda, each = nrow(aaa2))
  bbb$dev <- rep(dev, each = nrow(aaa2))
  bbb$df <- rep(df, each = nrow(aaa2))
  bbb$norm <- rep(norm, each = nrow(aaa2))

  # Top x axis: every 20th lambda, labelled with the non-zero count; again
  # indexed by the actual path length rather than a hard-coded 100.
  path_idx <- seq(1, length(lamda), by = 20)
  xbreaks <- lamda[path_idx]
  xlabels <- df[path_idx]

  # 30 colours in total; more than 30 plotted variables would exhaust it.
  mypalette <- c(
    brewer.pal(11, "Set3"),
    brewer.pal(11, "Spectral"),
    brewer.pal(8, "Accent")
  )

  # Only one colour scale is added: the former scale_color_ucscgb() call was
  # immediately replaced by scale_color_manual() and had no effect on the plot.
  figure3C2 <- ggplot(data = bbb, aes(x = lamda, y = value, color = Var1)) +
    geom_smooth(linewidth = 0.7, se = FALSE) +
    labs(x = "Log Lambda", y = "Coefficients", color = " ") +
    scale_x_continuous(
      sec.axis = sec_axis(~., breaks = xbreaks, labels = xlabels)
    ) +
    # Fixed y range: coefficients outside [-12, 15] are not drawn.
    scale_y_continuous(limits = c(-12, 15)) +
    theme_classic() +
    theme(legend.background = element_blank()) +
    guides(colour = guide_legend(ncol = 2, byrow = FALSE)) +
    my_theme +
    scale_color_manual(name = " ", values = mypalette) +
    # Added after my_theme so the smaller legend text size takes effect.
    theme(
      legend.text = ggplot2::element_text(family = "Times New Roman", size = 8)
    )

  figure3C2
}