A very rough memo on the standardize argument of cv.glmnet in the glmnet package
The help page ?glmnet has the following description.
standardize
Logical flag for x variable standardization, prior to fitting the model sequence. The coefficients are always returned on the original scale. Default is standardize=TRUE. If variables are in the same units already, you might not wish to standardize. See details below for y standardization with family="gaussian".
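In short, this is a logical flag for standardizing the x variables before the model is fitted, and the coefficients are always returned on the "original scale". The default is standardize=TRUE. If the variables are already in the same units, you may not need to standardize.
So if unstandardized data are passed with standardize=TRUE, they are presumably standardized during the computation, the coefficients are estimated on the standardized scale so that the penalty treats every variable evenly, and those are then converted back and returned as ordinary regression coefficients...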
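As a sketch of what that conversion would look like algebraically: assume each variable is standardized internally as $(x_j - \bar{x}_j)/s_j$, and write $\tilde{\beta}$ for the coefficients estimated on that scale. The symbols are my own notation, not taken from the glmnet documentation, and I have not checked this against the implementation.

```latex
\hat{y}
  = \tilde{\beta}_0 + \sum_j \tilde{\beta}_j \,\frac{x_j - \bar{x}_j}{s_j}
  = \underbrace{\Bigl(\tilde{\beta}_0 - \sum_j \frac{\tilde{\beta}_j \bar{x}_j}{s_j}\Bigr)}_{\beta_0}
    + \sum_j \underbrace{\frac{\tilde{\beta}_j}{s_j}}_{\beta_j}\, x_j
```

Under that assumption, the returned slope for each variable would be the standardized-scale slope divided by that variable's standard deviation, and the intercept would absorb the centering terms. The predictions are the same either way; only the parameterization changes.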
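The "original scale" presumably means the scale of the data at the point it is passed to the model, but I could not tell which scale is returned for which combination of variable state and argument setting... I need to study more. I don't have time to trace the implementation right now, so I simply brute-forced the patterns.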
library(MASS)    # provides the birthwt data
library(dplyr)
library(glmnet)

df <- birthwt %>% dplyr::select(lwt, age, race, smoke)
dy <- df %>% dplyr::pull(lwt)
dx <- df %>% dplyr::select(-lwt) %>% data.matrix()
head(df)
   lwt age race smoke
85 182  19    2     0
86 155  33    3     0
87 105  20    1     1
88 108  21    1     1
89 107  18    1     1
91 124  21    3     0

# Standardized data
dx_std <- df %>% dplyr::select(-lwt) %>% scale() %>% data.matrix()
head(dx_std)
          age       race     smoke
85 -0.7998401  0.1670828 -0.800046
86  1.8423284  1.2560014 -0.800046
87 -0.6111138 -0.9218359  1.243315
88 -0.4223875 -0.9218359  1.243315
89 -0.9885665 -0.9218359  1.243315
91 -0.4223875  1.2560014 -0.800046
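I prepare four models, crossing whether the data are standardized before fitting with whether standardize is TRUE or FALSE.
- Unstandardized variables × standardize=FALSE
- Unstandardized variables × standardize=TRUE
- Standardized variables × standardize=TRUE
- Standardized variables × standardize=FALSE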
set.seed(1)
cvfit1 <- cv.glmnet(x = dx, y = dy, standardize = FALSE, alpha = 1)
coef(cvfit1, s = "lambda.min")
4 x 1 sparse Matrix of class "dgCMatrix"
                     s1
(Intercept) 114.3915155
age           0.9013971
race         -2.9912102
smoke         .

set.seed(1)
cvfit2 <- cv.glmnet(x = dx, y = dy, standardize = TRUE, alpha = 1)
coef(cvfit2, s = "lambda.min")
4 x 1 sparse Matrix of class "dgCMatrix"
                     s1
(Intercept) 123.1038264
age           0.8293544
race         -5.5912696
smoke        -5.7134436

set.seed(1)
cvfit3 <- cv.glmnet(x = dx_std, y = dy, standardize = TRUE, alpha = 1)
coef(cvfit3, s = "lambda.min")
4 x 1 sparse Matrix of class "dgCMatrix"
                    s1
(Intercept) 129.814815
age           4.394482
race         -5.134699
smoke        -2.796101

set.seed(1)
cvfit4 <- cv.glmnet(x = dx_std, y = dy, standardize = FALSE, alpha = 1)
coef(cvfit4, s = "lambda.min")
4 x 1 sparse Matrix of class "dgCMatrix"
                    s1
(Intercept) 129.814815
age           4.394482
race         -5.134699
smoke        -2.796101
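The regression coefficients are identical for the following two combinations. Unsurprisingly, standardizing already-standardized data again during model fitting does not change the values.
- Standardized variables × standardize=TRUE
- Standardized variables × standardize=FALSE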
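So the "original scale" appears to mean the state of the variables at the moment they are passed to the model. When building a model from unstandardized variables, standardizing is recommended (and is the default), so it seems fine to fit with standardize=TRUE and use the returned model, whose coefficients are on the scale of the unstandardized variables. Prediction then, of course, uses the unstandardized variables.
- Unstandardized variables × standardize=TRUE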
# Unstandardized variables × standardize=FALSE
p1 <- df %>%
  dplyr::mutate(p1 = 114.3915155 + age * 0.9013971 + race * -2.9912102 + smoke * 0) %>%
  dplyr::select(lwt, p1)

# Unstandardized variables × standardize=TRUE
p2 <- df %>%
  dplyr::mutate(p2 = 123.1038264 + age * 0.8293544 + race * -5.5912696 + smoke * -5.7134436) %>%
  dplyr::select(p2)

# Standardized variables × standardize=TRUE
p3 <- dx_std %>%
  as.data.frame() %>%
  dplyr::mutate(p3 = 129.814815 + age * 4.394482 + race * -5.134699 + smoke * -2.796101) %>%
  dplyr::select(p3)

# Standardized variables × standardize=FALSE
p4 <- dx_std %>%
  as.data.frame() %>%
  dplyr::mutate(p4 = 129.814815 + age * 4.394482 + race * -5.134699 + smoke * -2.796101) %>%
  dplyr::select(p4)

pred_df <- cbind(p1, p2, p3, p4)
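Typing the printed coefficients in by hand works for a memo, but it is easy to mistype a digit. As a minimal sketch, the same predictions can be obtained with the predict() method for cv.glmnet objects. The data are rebuilt here with base R so the block runs on its own; pred_api, pred_manual and beta are names made up for this example.

```r
library(MASS)    # provides the birthwt data
library(glmnet)

# Same predictors and response as in the memo, rebuilt with base R.
dx <- data.matrix(birthwt[, c("age", "race", "smoke")])
dy <- birthwt$lwt

set.seed(1)
cvfit2 <- cv.glmnet(x = dx, y = dy, standardize = TRUE, alpha = 1)

# Predictions via predict(); newx is on the same (unstandardized) scale as x.
pred_api <- predict(cvfit2, newx = dx, s = "lambda.min")

# The same predictions assembled by hand from the returned coefficients.
beta <- as.matrix(coef(cvfit2, s = "lambda.min"))  # intercept followed by 3 slopes
pred_manual <- cbind(1, dx) %*% beta

# Largest absolute difference between the two approaches.
max(abs(pred_api - pred_manual))
```

The point of the comparison is that newx only has to be on whatever scale x was on when the model was fitted; no manual standardization step is involved.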
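For the following combinations (p2, p3, p4), the predicted values agree, apart from occasional differences in the fourth decimal place (for example 131.2106 vs 131.2107).
- Unstandardized variables × standardize=TRUE
- Standardized variables × standardize=TRUE
- Standardized variables × standardize=FALSE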
pred_df
    lwt       p1       p2       p3       p4
85  182 125.5356 127.6790 127.6790 127.6790
86  155 135.1640 133.6987 133.6987 133.6987
87  105 129.4282 128.3862 128.3862 128.3862
88  108 130.3296 129.2156 129.2156 129.2156
89  107 127.6255 126.7275 126.7275 126.7275
91  124 124.3472 123.7465 123.7465 123.7465
92  118 131.2310 135.7584 135.7584 135.7584
93  103 120.7416 120.4290 120.4290 120.4290
94  123 137.5408 135.8504 135.8504 135.8504
95  113 134.8366 133.3623 133.3623 133.3623
96   95 122.5444 122.0878 122.0878 122.0878
97  150 122.5444 122.0878 122.0878 122.0878
98   95 125.2486 124.5758 124.5758 124.5758
99  107 132.4598 131.2106 131.2107 131.2107
100 100 127.6255 126.7275 126.7275 126.7275
101 100 127.6255 126.7275 126.7275 126.7275
102  98 121.9301 124.3616 124.3616 124.3616
103 118 133.9352 132.5330 132.5330 132.5330
104 120 123.4458 122.9171 122.9171 122.9171
105 120 136.6394 135.0210 135.0210 135.0210
106 121 134.2626 132.8694 132.8694 132.8694
107 100 139.3436 143.2225 143.2225 143.2225
108 202 143.8506 147.3693 147.3693 147.3693
109 120 130.6570 129.5519 129.5519 129.5519
111 120 127.9528 127.0639 127.0639 127.0639
112 167 136.6394 140.7345 140.7345 140.7345
113 122 126.7241 125.8981 125.8981 125.8981
114 150 137.5408 141.5638 141.5638 141.5638
115 168 131.8454 127.7711 127.7711 127.7711
116 113 123.7328 126.0203 126.0203 126.0203
117 113 123.7328 126.0203 126.0203 126.0203
118  90 133.0338 131.7036 131.7036 131.7036
119 121 139.9580 135.2352 135.2352 135.2352
120 155 133.9352 138.2464 138.2464 138.2464
121 125 130.9440 132.6551 132.6551 132.6551
123 140 137.5408 135.8504 135.8504 135.8504
124 138 128.5269 127.5568 127.5568 127.5568
125 124 135.7380 134.1917 134.1917 134.1917
126 215 139.3436 137.5091 137.5091 137.5091
127 109 141.1464 139.1678 139.1678 139.1678
128 185 127.3384 123.6243 123.6243 123.6243
129 189 128.5269 133.2703 133.2703 133.2703
130 130 129.1412 130.9964 130.9964 130.9964
131 160 130.3296 134.9290 134.9290 134.9290
132  90 127.6255 126.7275 126.7275 126.7275
133  90 127.6255 126.7275 126.7275 126.7275
134 132 140.2450 144.0519 144.0519 144.0519
135 132 122.5444 122.0878 122.0878 122.0878
136 115 133.0338 137.4171 137.4171 137.4171
137  85 125.2486 118.8624 118.8624 118.8624
138 120 131.2310 135.7584 135.7584 135.7584
139 128 126.1500 125.4052 125.4052 125.4052
140 130 131.2310 130.0449 130.0449 130.0449
141  95 138.4422 136.6797 136.6797 136.6797
142 115 122.5444 122.0878 122.0878 122.0878
143 110 119.8402 119.5997 119.5997 119.5997
144 110 124.3472 118.0330 118.0330 118.0330
145 153 132.4598 131.2106 131.2107 131.2107
146 103 123.4458 122.9171 122.9171 122.9171
147 119 120.7416 120.4290 120.4290 120.4290
148 119 120.7416 120.4290 120.4290 120.4290
149 119 126.1500 125.4052 125.4052 125.4052
150 110 127.0514 126.2345 126.2345 126.2345
151 140 136.6394 140.7345 140.7345 140.7345
154 133 128.8542 122.1798 122.1798 122.1798
155 169 123.4458 122.9171 122.9171 122.9171
156 115 127.0514 126.2345 126.2345 126.2345
159 250 130.6570 123.8385 123.8385 123.8385
160 141 129.4282 134.0996 134.0996 134.0996
161 158 128.2398 130.1671 130.1671 130.1671
162 112 131.2310 130.0449 130.0449 130.0449
163 150 133.3612 126.3266 126.3266 126.3266
164 115 126.1500 119.6917 119.6917 119.6917
166 112 122.8314 125.1910 125.1910 125.1910
167 135 125.8227 125.0688 125.0688 125.0688
168 229 124.6342 126.8497 126.8497 126.8497
169 140 133.9352 138.2464 138.2464 138.2464
170 134 140.2450 138.3385 138.3385 138.3385
172 121 126.4370 122.7949 122.7949 122.7949
173 190 132.1324 136.5877 136.5877 136.5877
174 131 131.2310 135.7584 135.7584 135.7584
175 170 140.2450 144.0519 144.0519 144.0519
176 110 132.4598 131.2106 131.2107 131.2107
177 127 123.4458 122.9171 122.9171 122.9171
179 123 126.1500 125.4052 125.4052 125.4052
180 120 120.7416 114.7156 114.7156 114.7156
181 105 122.5444 122.0878 122.0878 122.0878
182 130 132.1324 136.5877 136.5877 136.5877
183 175 143.8506 147.3693 147.3693 147.3693
184 125 131.2310 135.7584 135.7584 135.7584
185 133 133.0338 137.4171 137.4171 137.4171
186 134 124.3472 123.7465 123.7465 123.7465
187 235 128.5269 127.5568 127.5568 127.5568
188  95 133.9352 132.5330 132.5330 132.5330
189 135 125.8227 125.0688 125.0688 125.0688
190 135 137.5408 141.5638 141.5638 141.5638
191 154 137.5408 141.5638 141.5638 141.5638
192 147 128.5269 127.5568 127.5568 127.5568
193 147 128.5269 127.5568 127.5568 127.5568
195 137 138.4422 142.3932 142.3932 142.3932
196 110 133.0338 137.4171 137.4171 137.4171
197 184 128.5269 127.5568 127.5568 127.5568
199 110 127.0514 126.2345 126.2345 126.2345
200 110 132.1324 136.5877 136.5877 136.5877
201 120 123.4458 122.9171 122.9171 122.9171
202 241 130.9440 132.6551 132.6551 132.6551
203 112 138.4422 142.3932 142.3932 142.3932
204 169 131.2310 135.7584 135.7584 135.7584
205 120 127.6255 126.7275 126.7275 126.7275
206 170 122.8314 125.1910 125.1910 125.1910
207 186 140.2450 144.0519 144.0519 144.0519
208 120 121.6430 121.2584 121.2584 121.2584
209 130 137.5408 135.8504 135.8504 135.8504
210 117 141.1464 144.8813 144.8813 144.8813
211 170 129.4282 128.3862 128.3862 128.3862
212 134 130.6570 129.5519 129.5519 129.5519
213 135 124.0199 129.1235 129.1235 129.1235
214 130 130.6570 129.5519 129.5519 129.5519
215 120 133.9352 138.2464 138.2464 138.2464
216  95 119.8402 119.5997 119.5997 119.5997
217 158 129.4282 134.0996 134.0996 134.0996
218 160 128.8542 127.8932 127.8932 127.8932
219 115 130.3296 134.9290 134.9290 134.9290
220 129 131.2310 135.7584 135.7584 135.7584
221 130 133.9352 138.2464 138.2464 138.2464
222 120 139.3436 143.2225 143.2225 143.2225
223 170 142.9492 146.5400 146.5400 146.5400
224 120 128.5269 127.5568 127.5568 127.5568
225 116 133.0338 137.4171 137.4171 137.4171
226 123 151.9632 154.8335 154.8335 154.8335
4   120 130.6570 123.8385 123.8385 123.8385
10  130 137.5408 141.5638 141.5638 141.5638
11  187 139.0566 134.4059 134.4059 134.4059
13  105 127.9528 127.0639 127.0639 127.0639
15   85 127.9528 127.0639 127.0639 127.0639
16  150 129.7556 128.7226 128.7226 128.7226
17   97 126.1500 125.4052 125.4052 125.4052
18  128 130.0426 131.8258 131.8258 131.8258
19  132 127.0514 126.2345 126.2345 126.2345
20  165 130.3296 129.2156 129.2156 129.2156
22  105 140.2450 138.3385 138.3385 138.3385
23   91 128.5269 127.5568 127.5568 127.5568
24  115 127.9528 127.0639 127.0639 127.0639
25  130 119.8402 119.5997 119.5997 119.5997
26   92 133.9352 132.5330 132.5330 132.5330
27  150 129.4282 128.3862 128.3862 128.3862
28  200 127.3384 129.3377 129.3377 129.3377
29  155 133.0338 131.7036 131.7036 131.7036
30  103 124.3472 123.7465 123.7465 123.7465
31  125 123.4458 122.9171 122.9171 122.9171
32   89 127.9528 127.0639 127.0639 127.0639
33  102 128.5269 133.2703 133.2703 133.2703
34  112 128.5269 127.5568 127.5568 127.5568
35  117 134.8366 133.3623 133.3623 133.3623
36  138 133.0338 137.4171 137.4171 137.4171
37  130 120.7416 114.7156 114.7156 114.7156
40  120 126.4370 122.7949 122.7949 122.7949
42  130 131.2310 130.0449 130.0449 130.0449
43  130 132.7468 134.3139 134.3139 134.3139
44   80 123.4458 117.2037 117.2037 117.2037
45  110 126.7241 125.8981 125.8981 125.8981
46  105 127.9528 127.0639 127.0639 127.0639
47  109 123.4458 122.9171 122.9171 122.9171
49  148 121.6430 121.2584 121.2584 121.2584
50  110 124.6342 121.1362 121.1362 121.1362
51  121 129.4282 128.3862 128.3862 128.3862
52  100 124.3472 123.7465 123.7465 123.7465
54   96 128.8542 127.8932 127.8932 127.8932
56  102 139.3436 137.5091 137.5091 137.5091
57  110 124.9213 129.9529 129.9529 129.9529
59  187 129.1412 125.2830 125.2830 125.2830
60  122 126.4370 122.7949 122.7949 122.7949
61  105 130.0426 126.1123 126.1124 126.1124
62  115 118.9388 118.7703 118.7703 118.7703
63  120 126.1500 125.4052 125.4052 125.4052
65  142 138.4422 136.6797 136.6797 136.6797
67  130 131.2310 130.0449 130.0449 130.0449
68  120 126.7241 125.8981 125.8981 125.8981
69  110 132.1324 130.8743 130.8743 130.8743
71  120 123.7328 126.0203 126.0203 126.0203
75  154 128.8542 127.8932 127.8932 127.8932
76  105 123.4458 122.9171 122.9171 122.9171
77  190 134.8366 133.3623 133.3623 133.3623
78  101 118.0374 112.2275 112.2275 112.2275
79   95 136.6394 135.0210 135.0210 135.0210
81  100 118.0374 117.9410 117.9410 117.9410
82   94 126.1500 119.6917 119.6917 119.6917
83  142 123.7328 126.0203 126.0203 126.0203
84  130 130.3296 129.2156 129.2156 129.2156
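One way to see why p2 matches p3 and p4: if the coefficients of the standardize=TRUE fit on unstandardized data are simply a rescaled version of the coefficients on standardized data, then multiplying each slope by the sample standard deviation that scale() used, and shifting the intercept by the means, should reproduce the standardized-scale coefficients. The following is a rough check under that assumption. The coefficient values are copied from the printed output of cvfit2 above, and the variable names are made up for this sketch.

```r
library(MASS)  # provides the birthwt data

x <- birthwt[, c("age", "race", "smoke")]
m <- sapply(x, mean)  # column means used by scale() for centering
s <- sapply(x, sd)    # sample SDs (n - 1 denominator) used by scale()

# Coefficients copied from the printed output for
# unstandardized variables x standardize=TRUE (cvfit2).
b0_raw <- 123.1038264
b_raw  <- c(age = 0.8293544, race = -5.5912696, smoke = -5.7134436)

# Implied coefficients on the standardized scale.
b_std  <- b_raw * s
b0_std <- b0_raw + sum(b_raw * m)

# Compare with the coefficients printed for cvfit3 and cvfit4.
round(c(intercept = b0_std, b_std), 4)
```

If the rounded values line up with the printed coefficients of cvfit3 and cvfit4, the three fits describe the same model in two parameterizations, which is why their predictions coincide.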
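Bonus
For p3 and p4, unsurprisingly, standardizing already-standardized data does not change the values, so the same regression coefficients are returned whichever way standardize is set.
- Standardized variables × standardize=TRUE
- Standardized variables × standardize=FALSE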
set.seed(1)
x_std1 <- scale(rnorm(10, 100, 10))
x_std2 <- scale(x_std1)
x_std1
             [,1]
 [1,] -0.97190653
 [2,]  0.06589991
 [3,] -1.23987805
 [4,]  1.87433300
 [5,]  0.25276523
 [6,] -1.22045645
 [7,]  0.45507643
 [8,]  0.77649606
 [9,]  0.56826358
[10,] -0.56059319
attr(,"scaled:center")
[1] 101.322
attr(,"scaled:scale")
[1] 7.80586
x_std2
             [,1]
 [1,] -0.97190653
 [2,]  0.06589991
 [3,] -1.23987805
 [4,]  1.87433300
 [5,]  0.25276523
 [6,] -1.22045645
 [7,]  0.45507643
 [8,]  0.77649606
 [9,]  0.56826358
[10,] -0.56059319
attr(,"scaled:center")
[1] 5.620504e-16
attr(,"scaled:scale")
[1] 1