R関数の引数チェック
はじめに
R関数の引数チェックについて、社内勉強会用にまとめておく。本当は{rlang}
のabort()
、warn()
、inform()
あたりをまとめたかったのだが、時間がなかったので、基本的な部分だけまとめている。今後追記していく予定。
引数チェックとは
基本的に関数を書くときは、引数をチェックするためのスクリプトを記述する。しかし、意図と違うものが入力されると、ストップなりワーニングを返すようにしているが、全てをカバーしきれない。実力不足と言われれば…そうである…。なので、作った関数を誰かに渡すと、予期していないような使い方をするので、エラーが出る。エラーが出ればいいが、計算が進んで、変な結論を生むことは避けたい。
例えば、三角形の面積を求める関数を書いたとする。この段階では何ら問題はない。
trianle_area1 <- function(base, height){ ans <- (base * height) * 0.5 return(ans) } trianle_area1(base = 2, height = 10) [1] 10
しかし、誰かに渡すとこのように使われたりする。つまり、負の値が入力される…面積なので、負の値はありえない。
trianle_area1(base = 2, height = -10) [1] -10
なので、こうならないようにstop()
をおいておく。
trianle_area2 <- function(base, height){ if (0 > height) stop("`height` must be >= 0") if (0 > base) stop("`base` must be >= 0") ans <- (base * height) * 0.5 return(ans) } trianle_area2(base = 2, height = -10) trianle_area2(base = 2, height = -10) でエラー: `height` must be >= 0
すると、次はこうなったりする。負の値を含むベクトル(Rでは、長さ1の場合でもベクトルと呼ぶ…確か、スカラとは呼ばない買ったような…まぁどっちゃでもいいか。ここでは本質ではないので。)である。この場合、Rはワーニングは出すが、負の値ではない1番目の値だけで非負かどうかの判定がなされ、ベクトルの長さにあうように値がリサイクルされる。このようなあからさまにわかりやすい計算であれば、結果がおかしいと気づくが、難しい計算であれば、気づかないし、人によってはワーニングを無視する。
trianle_area2(base = c(1, -2), height = 5) [1] 2.5 -5.0 警告メッセージ: if (0 > base) stop("`base` must be >= 0") で: 条件が長さが 2 以上なので、最初の 1 つだけが使われます
なので、これでは困るために、インプットのベクトルの長さを1に制限する。
trianle_area3 <- function(base, height){ if (!length(base) == 1) stop("`base` length must be = 1") if (0 > base) stop("`base` must be >= 0") if (!length(height) == 1) stop("`height` length must be = 1") if (0 > height) stop("`height` must be >= 0") ans <- (base * height) * 0.5 return(ans) } trianle_area3(base = c(1, -2), height = 5) trianle_area3(base = c(1, -2), height = 5) でエラー: `base` length must be = 1
かといって、これではせっかくベクトル化されているのに、それを使わないようにするのはもったいない。なので、if(any(hoge))
と対処することもできる。負の値が含まれている場合は計算されない。
trianle_area4 <- function(base, height){ if (any(0 > base)) stop("`base` must be >= 0") if (any(0 > height)) stop("`height` must be >= 0") ans <- (base * height) * 0.5 return(ans) } trianle_area(base = c(1, 2), height = c(1,-2)) trianle_area(base = c(1, 2), height = c(1, -2)) でエラー: `height` must be >= 0 trianle_area(base = c(1, 2), height = c(1,2)) [1] 0.5 2.0
ベクトル化とは
Rの関数のほとんどは、引数としてベクトルをとって、ベクトルを返すようになっている。こうした関数の内側では、内部的にC言語のサブルーチンで処理することで、高速に計算されるように実装されています。非常にありがたい。なので、ベクトルを入力するとベクトルをfor()
などを使わずに返すようにしたかったりするが、if()
などはベクトルに対応していない。数年前に、なんとかプロジェクトでfor()
などが大幅に改善されたとかあったような気がする。うる覚えで申し訳ないが…。
そして、for()
を避ければよいという話ではない。モノによってはこっちのほうが速い。
例えば、負、ゼロ、正で値をわけてフラグを立てる関数my_sign()
。
my_sign <- function(x){ if(!is.numeric(x)){ stop("argument is not numeric : returning NA") } output <- NA_character_ if(is.na(x)){ output <- NA } else if(0 > x) { output <- "Negative" } else if(x == 0){ output <- "Zero" } else { output <- "Positive" } return(output) } my_sign(1) [1] "Positive" my_sign(0) [1] "Zero" my_sign(-1) [1] "Negative" my_sign("abc") my_sign("abc") でエラー: argument is not numeric : returning NA my_sign(NA) my_sign(NA) でエラー: argument is not numeric : returning NA
ベクトルを入力すると、フラグを立てるところにif()
があるので、ベクトルが返ってこない。
vals <- c(-5,0,5,NA_integer_) vals [1] -5 0 5 NA my_sign(vals) [1] "Negative" 警告メッセージ: 1: if (is.na(x)) { で: 条件が長さが 2 以上なので、最初の 1 つだけが使われます 2: if (0 > x) { で: 条件が長さが 2 以上なので、最初の 1 つだけが使われます
なので、これを対応できるようにfor()
を使うとする。気づいた人はいるかも知れないが、ベクトル化してくれる関数を使えばよいのでは?と思うかもしれない…。こんなやつである。
my_sign_vec <- base::Vectorize(FUN = my_sign) my_sign_vec(vals) [1] "Negative" "Zero" "Positive" NA sapply(X = vals, FUN = my_sign, simplify = TRUE) [1] "Negative" "Zero" "Positive" NA vals %>% map_chr(.x = ., .f = function(x){ my_sign(x) }) [1] "Negative" "Zero" "Positive" NA
さっきのことを忘れて、for()
で対応する。
my_sign2 <- function(x){ if(!is.numeric(x)){ stop("argument is not numeric : returning NA") } n_len <- length(x) output <- vector(mode = "character", length = n_len) for (i in seq_along(x)) { if(is.na(x[[i]])){ output[[i]] <- NA_character_ } else if(0 > x[[i]]) { output[[i]] <- "Negative" } else if(x[[i]] == 0){ output[[i]] <- "Zero" } else { output[[i]] <- "Positive" } } return(output) } my_sign2(vals) [1] "Negative" "Zero" "Positive" NA
なんだけれど、このような場合はif()
の代わりにcase_when()
で代替できるので、こっちを使ったほうがよい。5000万件くらい処理すると、10秒くらい変わってくる。
my_sign3 <- function(x){ if(!is.numeric(x)){ stop("argument is not numeric : returning NA") } output <- case_when(is.na(x) ~ NA_character_, 0 > x ~ "Negative", x == 0 ~ "Zero", TRUE ~ "Postive") return(output) } my_sign3(vals) system.time( my_sign2(rnorm(5e7)) ) ユーザ システム 経過 24.712 0.391 25.337 system.time( my_sign3(rnorm(5e7)) ) ユーザ システム 経過 12.103 3.958 16.845
1億件くらい処理すると、15秒くらい変わってくる。
system.time( my_sign2(rnorm(1e8)) ) ユーザ システム 経過 53.781 1.076 60.426 system.time( my_sign3(rnorm(1e8)) ) ユーザ システム 経過 24.987 12.033 44.488
再帰関数
再帰関数、再帰呼び出しについて。これは関数内で自分自身を呼び出す関数のこと。
rsum <- function(n){ if (n == 1){ return(1) } return(n + rsum(n - 1)) } rsum(5) [1] 15
やっていることは下記と同じではある。
sum(1:5) [1] 15 res = 0 for (i in 1:5) { res = res + i } res [1] 15
上記のnまでの数を合計する関数の挙動について考える。下記の画像のように自分自身であるrsum()
を呼び出している。
ここでは、終了条件として、 if (n == 1){return(1)}
があり、ここまで潜ると、return
が実行され、随時、結果が返され、画像のように動く。つまり、return
は下記のようになる。