K-Fold Target Encoding
はじめに
ここではK-Fold Target Encodingについてまとめておく。下記のサイトでは、K-Fold Target Encodingの説明とPythonでの実装が乗っているので、それを参考にRで雑に書き直してみた。いつの日か関数化しよう…。
K-Fold Target Encoding
One-hot EncodingやDummy Encodingは、カテゴリ変数のカーディナリティが多くなると、データセットの次元が大きくなるため、役に立たない場合があるらしい。そこで、Target Encodingなどが役立つが、なにも考えずにTarget Encodingすると、リークしてしまい、モデルの汎化性能が良くなくなる。それを回避しつつ、Target Encodingするために、K-Fold Target Encodingを行う。
K-Fold Target Encodingとは、クロスバリデーションを行いながら、ターゲットの変数の値を、カテゴリカル特徴量のグループごとの平均などでエンコードすることで、カテゴリカルな特徴量を新たな特徴量に変換するエンコード方法。クロスバリデーションしないと行けない理由としては、単純に全データを使ってTarget Encodingしてしまうと、リークする可能性がでてくるため。なので、Target Encodingする際には、エンコードされた値を付与したいレコード群(フォールド群)には、自身のレコードのターゲットの値を含めず、アウトオブフォールド群で計算された平均値を付与する。それをクロスバリデーションの要領で行うことで、トレーニング用のモデルの特徴量として、モデルを構築する。テストデータに対するTarget Encodingの値は、テストデータに付与したK-Fold Target Encodingの値をカテゴリ毎に平均して付与する。
サンプルデータ
K-Fold Target Encodingの値と一致しているかを確認するため、サンプルデータ自体もK-Fold Target Encodingのデータを再現させていただいている。
library(cvTools) df <- tibble(id = 1:25, feature = c("A","B","B","B","B","A","B","A","A","B", "A","A","B","A","A","B","B","B","A","A","B","B","B","A","A"), target = c(1,0,0,1,1,1,0,0,0,0,1,0,1,0,1,0,0,0,1,1,NA,NA,NA,NA,NA), flg = rep(c("train", "test"), c(20, 5))) train_df <- df %>% dplyr::filter(flg == "train") test_df <- df %>% dplyr::filter(flg == "test")
K-Fold Target Encodingの雑な実装
フォールドのインデックスとクロスバリデーション回数は、参考サイトの図解にあわせて、k=5で、それにあわせてフォールドは連番で作成している。これはあくまでも解説のためなので、本番のデータの状態にあわせて変更する。たぶん、{purrr}
で書き直せばもっとすっきりしそう。
# type = c("random", "consecutive", "interleaved") k <- 5 folds <- cvFolds(nrow(train_df), K = k, type = "consecutive") train_df <- train_df %>% dplyr::bind_cols(tibble(folds = folds$which)) res <- tibble() for (i in 1:k) { fold_data <- train_df %>% dplyr::filter(folds == i) out_fold_data <- train_df %>% dplyr::filter(folds != i) tmp <- out_fold_data %>% dplyr::mutate(folds = i) %>% dplyr::group_by(feature, folds) %>% dplyr::summarise(cv_target_encoding = mean(target, na.rm = TRUE)) res <- res %>% dplyr::bind_rows(tmp) } res # A tibble: 10 x 3 feature folds cv_target_encoding <chr> <int> <dbl> 1 A 1 0.556 2 B 1 0.286 3 A 2 0.625 4 B 2 0.25 5 A 3 0.714 6 B 3 0.333 7 A 4 0.625 8 B 4 0.25 9 A 5 0.5 10 B 5 0.375
あとはこれをtrain_df
にもどし、そこからTarget Encodingの値を訓練データから平均して求め、テストデータに返す。left_join()
のところでサブクエリっぽくなっている部分で訓練データから平均を計算している。雑ですんません。
train_df <- train_df %>% dplyr::left_join(x = ., y = res, by = c("feature", "folds")) test_df <- test_df %>% dplyr::left_join(x = ., y = train_df %>% dplyr::group_by(feature) %>% dplyr::summarise(cv_target_encoding = mean(cv_target_encoding)), by = c("feature")) df2 <- train_df %>% dplyr::bind_rows(test_df) df2 %>% # 表示のため as.data.frame() id feature target flg folds cv_target_encoding 1 1 A 1 train 1 0.5555556 2 2 B 0 train 1 0.2857143 3 3 B 0 train 1 0.2857143 4 4 B 1 train 1 0.2857143 5 5 B 1 train 2 0.2500000 6 6 A 1 train 2 0.6250000 7 7 B 0 train 2 0.2500000 8 8 A 0 train 2 0.6250000 9 9 A 0 train 3 0.7142857 10 10 B 0 train 3 0.3333333 11 11 A 1 train 3 0.7142857 12 12 A 0 train 3 0.7142857 13 13 B 1 train 4 0.2500000 14 14 A 0 train 4 0.6250000 15 15 A 1 train 4 0.6250000 16 16 B 0 train 4 0.2500000 17 17 B 0 train 5 0.3750000 18 18 B 0 train 5 0.3750000 19 19 A 1 train 5 0.5000000 20 20 A 1 train 5 0.5000000 21 21 B NA test NA 0.2940476 22 22 B NA test NA 0.2940476 23 23 B NA test NA 0.2940476 24 24 A NA test NA 0.6198413 25 25 A NA test NA 0.6198413
S3クラスのまとめ
はじめに
RのS3クラスシステムについて、他の言語をやっていると、少しごっちゃごちゃになってきたので、簡単にまとめておく。
S3クラス
Rの基本となるクラスシステムはS3クラス。他にもS4とかR5とかあるけどもここでは、S3クラスのことをまとめる。このシステムによって、Rでは、異なるクラスをどのように扱うのかをコントロールしている。このクラスシステムは、class
の他に、names
、levels
などを持てる属性、ジェネリック関数、メソッドから構成される。
そのオブジェクトがどのようなクラスを持っているかはclass()
で確認できる。
df <- data.frame(x = 1:10) class(df) [1] "data.frame" class(lm(x ~ x)) [1] "lm" class(Sys.Date()) [1] "Date" class(Sys.time()) [1] "POSIXct" "POSIXt"
そのクラスに応じてジェネリック関数(総称関数)は、そのオブジェクトをどのように扱うかを決めている。例えば、data.frame
クラスをもつオブジェクトをprint()
してみると、下記のように表示される。
print(df) x 1 1 2 2 3 3 4 4 5 5 6 6 7 7 8 8 9 9 10 10
これは、print()
のprint.data.frame
メソッドが適用され、data.frame
クラスのオブジェクトに合わして表示がコントロールされる。このクラスにあわせてメソッドを適用する仕組みをメソッドディスパッチという。その関数がジェネリック関数かどうかは、その関数名のみを実行して、UseMethod
うんたらと表示されればジェネリック関数である。そのジェネリック関数が、どのようなメソッドを持っているかはmethods()
にジェネリック関数の名前のみをいれれば確認できる。
print function (x, ...) UseMethod("print") <bytecode: 0x7fd43cae1cf0> <environment: namespace:base> methods(print) [1] print.acf* print.AES* [3] print.all_vars* print.anova* [5] print.anova.lme* print.ansi_string* [7] print.ansi_style* print.any_vars* [9] print.aov* print.aovlist* ----【略】
実際に、data.frame
クラスをもつオブジェクトをprint()
すると、下記のメソッドが適用される。
methods(print)[93] [1] "print.data.frame" print.data.frame function (x, ..., digits = NULL, quote = FALSE, right = TRUE, row.names = TRUE, max = NULL) { n <- length(row.names(x)) if (length(x) == 0L) { cat(sprintf(ngettext(n, "data frame with 0 columns and %d row", "data frame with 0 columns and %d rows"), n), "\n", sep = "") } else if (n == 0L) { print.default(names(x), quote = FALSE) cat(gettext("<0 rows> (or 0-length row.names)\n")) } else { if (is.null(max)) max <- getOption("max.print", 99999L) if (!is.finite(max)) stop("invalid 'max' / getOption(\"max.print\"): ", max) omit <- (n0 <- max%/%length(x)) < n m <- as.matrix(format.data.frame(if (omit) x[seq_len(n0), , drop = FALSE] else x, digits = digits, na.encode = FALSE)) if (!isTRUE(row.names)) dimnames(m)[[1L]] <- if (isFALSE(row.names)) rep.int("", if (omit) n0 else n) else row.names print(m, ..., quote = quote, right = right, max = max) if (omit) cat(" [ reached 'max' / getOption(\"max.print\") -- omitted", n - n0, "rows ]\n") } invisible(x) } <bytecode: 0x7fd43dfeb010> <environment: namespace:base>
なので、このS3クラスシステムを利用し、独自の関数を定義した際に、クラスを与えてメソッドを作ることができる。例えば、ここではmoney
クラスを作ってみる。このクラスはオブジェクトの先頭に$
マークを付ける、というもの。オブジェクトにクラスを設定したいときは、class(オブジェクト) <- "クラス名"
とするだけで良い。
x <- 1:10 class(x) <- "money" class(x) [1] "money"
あとは、print()
にprint.money
というメソッドを追加する。メソッドを追加するときは、print.money
という関数を作ることで追加できる。名前のルールはジェネリック関数.クラス名
。
print.money <- function(x) { paste0("$", x) }
あとはこれで、money
クラスのオブジェクトにprint()
を使えば、money
クラス用のメソッドが適用される。
print(x) [1] "$1" "$2" "$3" "$4" "$5" "$6" "$7" "$8" "$9" "$10"
こんな変なこともクラス属性が後付で付与できるので、やろうと思えば、できてしまう。
y <- TRUE print(y) [1] TRUE class(y) <- "money" print(y) [1] "$TRUE"
クラス設定~ジェネリック関数作成
クラスの設定からジェネリック関数の作成までやってみる。さきほどのまでの流れだと、オブジェクトを作ってクラスを付与していたのだけれど、これは幾分、めんどくさい。
s <- list(name = "Taro", age = 21L, score = 30L) class(s) <- "student"
なので、関数を作って、オブジェクトを作ったと同時にクラスを付与しておく。
student <- function(name, age, score) { if(score > 100 || score < 0){ stop("score must be between 0 and 100") } value <- list(name = name, age = age, score = score) class(value) <- "student" value }
これでstudent()
を使えば、クラスがstudent
クラスになる。
s1 <- student(name = "Tanaka", age = 26L, score = 80L) class(s1) [1] "student"
このstudent
クラスに対して、print()
のメソッドを作ってみる。さきほど同様、メソッドを追加するときは、print.student
という関数を作ることで追加できる。名前のルールはジェネリック関数.クラス名
。
print.student <- function(x) { cat("Name : ", x$name, "\n") cat("Age : " , x$age, "\n") cat("Score : " , x$score, "\n") } print(s1) Name : Tanaka Age : 26 Score : 80 s2 <- student(name = "Suzuki", age = 30L, score = 50L) print(s2) Name : Suzuki Age : 30 Score : 50
ジェネリック関数というのは、既存の関数だけでなく、自分で作ることができる。score
というジェネリック関数を作ってみる。UseMethod()
を使って、下記のように書けば、ジェネリック関数が作成できる。
score <- function(x) { UseMethod("score") }
あとは、score
というジェネリック関数が各クラスに対して、どのように振る舞うのか、メソッドを作っておく。
score.default <- function(x) { cat("This is a generic function") } score.student <- function(x) { cat("Your Score is", x$score, "\n") }
score.default
というのは、ジェネリック関数がクラスを受けたとき、該当するクラスが見つからない場合、score.default
というクラスで処理するためのもの。print()
にもprint.default
とにも用意されている。
print.default function (x, digits = NULL, quote = TRUE, na.print = NULL, print.gap = NULL, right = FALSE, max = NULL, useSource = TRUE, ...) { args <- pairlist(digits = digits, quote = quote, na.print = na.print, print.gap = print.gap, right = right, max = max, useSource = useSource, ...) missings <- c(missing(digits), missing(quote), missing(na.print), missing(print.gap), missing(right), missing(max), missing(useSource)) .Internal(print.default(x, args, missings)) } <bytecode: 0x7f90e8000628> <environment: namespace:base>
これで準備は整ったので、student
クラスのオブジェクトをscore()
にわたすと、score.student
メソッドが適用されるようになる。
score(s1) Your Score is 80 score(s2) Your Score is 50
score.default
メソッドが動くかどうか確認するために、student
クラス以外のオブジェクトをscore()
にわたしてみる。
score(rnorm(10)) This is a generic function
継承
S3クラスの継承について。下記のようにstudent
クラスを作ったとする。ジェネリック関数のprint()
について、次のようにstudent
クラスのメソッドを定義する。
student <- function(name, age, score) { value <- list(name = name, age = age, score = score) attr(value, "class") <- "student" value } print.student <- function(obj) { cat(obj$name, "\n") cat(obj$age, "years old\n") cat("score:", obj$score, "\n") }
ここでリストで下記のようなデータを作ったとする。もちろんクラスはlist
。
s <- list(name = "Tom", age = 26, score = 90, country = "japan") class(s) [1] "list"
オブジェクトs
に2つのクラスを付与する。
class(s) <- c("InternationalStudent", "student") class(s) [1] "InternationalStudent" "student"
この状態だと、print(s)
とすると、print.student
が呼び出される。
print(s) Tom 26 years old score: 90
ここで、InternationalStudent
クラスのメソッドを定義すると、print(s)
はprint.InternationalStudent
を呼び出すようになる。つまり、以下のようにクラスstudent
に定義されたメソッドが上書きされる。
print.InternationalStudent <- function(obj) { cat(obj$name, "is from", obj$country, "\n") } print(s) Tom is from japan
inherits()
かis()
を使うと継承されているかどうかわかる。
inherits(s,"student") [1] TRUE is(s,"student") [1] TRUE
クラスのベクトル順序によって上書きされるかどうかは決まる。
s <- list(name = "Tom", age = 26, score = 90, country = "japan") class(s) [1] "list" class(s) <- c("student", "InternationalStudent") class(s) [1] "student" "InternationalStudent" print(s) Tom 26 years old score: 90 # NOT OVERWRITE print.InternationalStudent <- function(obj) { cat(obj$name, "is from", obj$country, "\n") } print(s) Tom 26 years old score: 90 inherits(s, "student") [1] TRUE
Rプログラミング本格入門―達人データサイエンティストへの道― で紹介されている継承の部分をメモしておく。Vehicle()
はクラスとして、引数のclass
とvehicle
を持つ。そして、この関数を使って、vehicle
を継承するcar
クラスを生成するCar()
を作る。
Vehicle <- function(class, name, speed){ obj <- new.env(parent = emptyenv()) obj$name <- name obj$speed <- speed obj$position <- c(0,0,0) class(obj) <- c(class, "vehicle") obj } Car <- function(...){ Vehicle(class = "car", ...) } car <- Car(name = "Model-S", speed = 100) class(car) [1] "car" "vehicle"
そして、vehicle
クラスに対するメソッドprint.vehicle
を定義する。
print.vehicle <- function(x, ...) { cat(sprintf("<vehicle: %s>\n", class(x)[[1]])) cat("name:", x$name, "\n") cat("speed:", x$speed, "km/h\n") cat("position:", paste(x$position, collapse = ", ")) }
car
クラスはvehicle
クラスを継承しているので、この関数を使うことで、carクラス
でもprint.vehicle
が呼び出される。
print(car) <vehicle: car> name: Model-S speed: 100 km/h position: 0, 0, 0 sloop::s3_dispatch(print(car)) print.car => print.vehicle * print.default