for ループを使わずに条件式を回したい
ウェブで Python の for 文について調べてみると、for ループは時間がかかるとか、for 文は読みにくいとか、極端な例では for 文を使ったら負け、みたいなことが書いてあります。
それはわからなくもないのですが、では具体的にどうすれば良いのかを説明しているウェブサイトが少ないと思ったので(少なくとも私はかなり苦労しました)備忘の意味も込めて具体的な方法を書くことにしました。
Chat GPT のような優秀な AI が簡単に使える時代にどれくらいの人がウェブ検索でプログラミングについて調べているのかわかりませんが、プログラミング初心者の方のスキルアップにつながればと思います。
この記事の内容
結論:numpy.where 関数を使う
for 文と if 文を組み合わせて、多数の要素を持つ配列の中から特定の条件を満たす要素だけを取り出すことがあると思います。
その場合は NumPy の where 関数(以下、np.where)を使います。
np.where の使い方は次のとおりです。
- np.where(条件式, 条件式を満たす場合の処理, 満たさない場合の処理)
公式ドキュメントはこちらです(英語)。
注意事項として、numpy 配列にしか使えません。
Python 標準のリストなどの配列変数は、np.where 関数を使う前に numpy配列に変換しましょう。
カッコ内のうしろ2つ、条件式を満たす場合と満たさない場合の処理は無くても構いません。
その場合、ある配列変数に対して np.where 関数を使うと、条件式を満たす要素の配列番号を抽出し、それを並べた配列を返します。
条件式を満たした場合、満たさない場合の処理は、例えば数値や文字列などを入力すると、それらに置換した配列を返します。
以下で具体的な使い方を見ていきましょう。
np.where 関数の基本
1 条件を満たす配列番号を返す
ある配列に対して条件式のみを入力した場合、その条件を満たす配列番号を並べた配列を返します。
以下は、0 から 100 までの 11個の要素を持つ配列変数(名前は “data“)のうち、50 を超える要素の配列番号を np.where 関数で取得します。
import numpy as np
# 0 から 100 までの 11個の要素を持つ numpy配列を作成する。
data = np.array([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100])
# numpy.where 関数で指定した条件の要素を持つ配列番号を表示する。
# この例では 50を超える要素を持つ配列番号を表示する。
num_array = np.where(data > 50)
print(num_array)
実行結果:以下のように 50 を超える要素を持つ配列番号を返します。
この配列番号を使うことで、必要なデータだけを抽出できます(後述)。
2 条件を満たす/満たさない要素にそれぞれ指定した数値を代入する
条件式を満たす場合を 1、満たさない場合を 0 に置換した配列を取得したい場合は、以下のようにします。
import numpy as np
# 0 から 100 までの 11個の要素を持つ numpy配列を作成する。
data = np.array([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100])
# numpy.where 関数で指定した条件を満たす場合は 1、
# 満たさない場合は 0 に置換した配列を作る。
# この例では 50を超える要素を 1に、それ以外を 0に置換した配列を返す。
count_array = np.where(data > 50, 1, 0)
print(count_array)
実行結果:以下のように 50 を超える要素を 1に、それ以外を 0に置換した配列を返します。
これを使うことで、条件を満たす要素の個数を数えることができます(後述)。
3 条件を満たす/満たさない要素にそれぞれ指定した文字列を代入する
数値ではなく、文字列に置換することもできます。以下は条件を満たす場合は「OK」、満たさない場合は「NG」の文字列に置換した配列を返します。
import numpy as np
# 0 から 100 までの 11個の要素を持つ numpy配列を作成する。
data = np.array([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100])
# numpy.where 関数で指定した条件を満たす場合は文字列 "OK"、
# 満たさない場合は文字列 "NG" に置換した配列を作る。
str_array = np.where(data > 50, "OK", "NG")
print(str_array)
実行結果:50 を超える要素を「OK」に、それ以外を「NG」に置換した配列を返します。
4 複数条件(論理積・論理和)を使う
論理積(&:いわゆる AND)や論理和(|:いわゆる OR)を使って複数条件を与えることもできます。
以下の例でははじめに &、次に | を使っています。
import numpy as np
# 0 から 100 までの 11個の要素を持つ numpy配列を作成する。
data = np.array([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100])
# numpy.where 関数で指定した条件を満たす場合は 1、
# 満たさない場合は 0 に置換した配列を作る。
# この例では 10を超えて、かつ 60未満の要素を 1に、それ以外を 0に置換した配列を返す。
count_array1 = np.where((data > 10) & (data < 60), 1, 0)
print(count_array1)
# numpy.where 関数で指定した条件を満たす場合は 1、
# 満たさない場合は 0 に置換した配列を作る。
# この例では 30未満、または 70 を超える要素を 1に、それ以外を 0に置換した配列を返す。
count_array2 = np.where((data < 30) | (data > 70), 1, 0)
print(count_array2)
実行結果:どちらも指定した条件通りになりました。
5 2次元配列に適用する
当然ですが 2次元の配列にも使えます。スライスを使って条件を与えます。
以下の例では、3行 6列の 2次元配列で指定した行または列について条件を与えています。
import numpy as np
# 2次元のnumpy配列を作成する。
data_2d = np.array([[0, 1, 2, 3, 4, 5],
[10, 11, 12, 13, 14, 15],
[20, 21, 22, 23, 24, 25]])
# 2次元配列のうち 2行目(行番号は 1)のデータから 12を超える要素を持つ
# 列の番号を表示する。
num_array1 = np.where(data_2d[1, :] > 12)
print(num_array1)
# 2次元配列のうち 3列目(列番号は 2)のデータから 10 を超え、
# かつ 20未満の要素を持つ行の番号を表示する。
num_array2 = np.where((data_2d[:, 2] > 10) & (data_2d[:, 2] < 20))
print(num_array2)
実行結果:それぞれ指定した行または列について、条件に当てはまる要素番号を返しました。
実践的な使い方
1 条件を満たす要素の数を数える
for 文と if 文を組み合わせて、条件を満たす要素の数を数えることがあると思います。
np.where を使えば、for 文を使うよりもシンプルな記述でかつ高速に処理ですます。
以下の例では、10000個のランダムな数値をもつNumPy配列から条件を満たす要素を 1、満たさないものを 0に置換し、その後配列内の要素の合計を計算することで個数を数えます。
import numpy as np
# 0 から 1 の間の乱数を 1万個持つ配列を作る。
random_data = np.random.random(10000)
print(random_data)
# numpy.where 関数で 0.9以上の要素を 1に、それ以外を 0 にして
# 条件を満たす要素の数を数える。
count_array = np.where(random_data >= 0.9, 1, 0)
count = sum(count_array)
print(count)
実行結果:0 から 1 までのランダムな数値の中から、条件(0.9 以上)を満たす要素の数を無事に数えることができました。
1行目に配列の中身(最初と最後の 3つだけ)を表示し、2行目に条件を満たす要素の数を示しています。
2 2次元配列で別の行の要素を取り出す
年間の売上データのようなものを考えます。
1900 | 1901 | 1902 | 1903 | 1904 | ・・・ |
250,341 | 319,247 | 280,728 | 303,864 | 266,193 | ・・・ |
上のデータから、例えば 30万を超える年を取り出したい、という場合も np.where を使えば簡単に処理できます。
初めに np.where で条件を満たす列の番号を取り出し、その番号を使って別の行の要素を取り出します。下の実行例では、1行目に「年」、2行目に 0から 1のランダムな数値を持つ 2次元配列(2行)を作り、その中から指定した数値をもつ「年」だけを取り出します。
import numpy as np
# 1900 から 2020 までの要素を持つ行と、それに対応した個数の
# 乱数を持つ行を組み合わせた 2次元配列を作る。
line0 = np.arange(1900, 2021)
line1 = np.random.random(len(line0))
data_2d = np.array([line0, line1])
# 2行目(行番号は 1)の配列から 0.9以上の要素を持つ列番号を抽出する。
conditional_line1 = np.where(data_2d[1, :] >= 0.9)
# 1行目(行番号は 0)の配列から、上の条件を満たした列番号に対応する
# 要素を抽出する。
satisfied_line0 = data_2d[0, conditional_line1]
print(satisfied_line0)
実行結果:条件を満たす「年」だけの配列を得られました。この方法は、割と色々なデータ解析で使えると思います。
しかも、実際にデータを処理しているコードはたったの 2行なので、np.where の使い方さえ覚えれば for + if 文の記述よりも圧倒的に見やすくなります。
処理速度の比較
NumPy の where 関数を使うことで、実際に forループ + if 文と比べてどれくらい処理速度が変わるのかをテストしてみました。
1 「1900年」から「2020年」までのうち条件を満たす「年」を抽出する
上で紹介した 2行の配列からのデータ抽出について、
(1) forループ + if文
(2) np.where
のそれぞれを使った場合について、処理速度を比較してみます。
比較するデータは全く同じです。最初に用意し (1) と (2) それぞれの処理開始前後で時間を測定してみます。
時間は ms(ミリ秒)で表示します。
以下のコードを使いました。
import numpy as np
import time
# 1900 から 2020 までの要素を持つ行と、それに対応した個数の
# 乱数を持つ行を組み合わせた 2次元配列を作る。
line0 = np.arange(1900, 2021)
line1 = np.random.random(len(line0))
data_2d = np.array([line0, line1])
# for文と if文で 2行目の配列が 0.9以上の要素を持つ 1行目の要素を抽出する。
# 計算開始時間の取得
time1 = time.time()
# 結果を格納するための空配列を用意する。
result_for = []
# for文を配列の列数分だけ回す。
for i in range(0, data_2d.shape[1]):
# if文で 2行目の配列が 0.9以上かどうかを判定する。
if data_2d[1, i] >= 0.9:
# 1行目の同じ列にある要素を初めに用意した空配列に格納する。
result_for.append(data_2d[0, i])
print(result_for)
# 計算終了時間の取得
time2 = time.time()
# 計算時間を ms で表示する。
calc_time1 = 1000*(time2 - time1)
print("for + if 文による計算時間: ", calc_time1, " ms \n")
# where文で 2行目の配列が 0.9以上の要素を持つ 1行目の要素を抽出する。
# 計算開始時間の取得
time3 = time.time()
# where 関数で2行目の配列から 0.9以上の要素を持つ列番号を抽出する。
conditional = np.where(data_2d[1, :] >= 0.9)
# 1行目の配列から、上の条件を満たした列番号に対応する要素を抽出する。
result_where = data_2d[0, conditional]
print(result_where)
# 計算終了時間の取得
time4 = time.time()
# 計算時間を ms で表示する。
calc_time2 = 1000*(time4 - time3)
print("where 関数による計算時間: ", calc_time2, " ms")
実行結果:なんと np.where で処理した結果は、処理時間が短すぎて 0.0 ms となってしまいました。もちろん環境(パソコン)によって違うとは思いますが。
これでは比較にならないので、次はより大きなデータを用意します。
2 「1900年」から「2020年」までのうち条件を満たす「年」を抽出する
次は 2行 × 100万列のデータを用意し、同じ処理を
(1) forループ + if文
(2) np.where
のそれぞれで実行して比較しました。
コードは以下のとおりです。
import numpy as np
import time
# 0から 999,999までの数値をもつ行と、それに対応した 0から 1までの
# 乱数を持つ行を組み合わせた 2次元配列を作る。
line0 = np.arange(0, 1000000)
line1 = np.random.random(len(line0))
data_2d = np.array([line0, line1])
# for文と if文で 2行目の配列が 0.9以上の要素を持つ 1行目の要素を抽出する。
# 計算開始時間の取得
time1 = time.time()
# 結果を格納するための空配列を用意する。
result_for = []
# for文を配列の列数分だけ回す。
for i in range(0, data_2d.shape[1]):
# if文で 2行目の配列が 0.9以上かどうかを判定する。
if data_2d[1, i] >= 0.9:
# 1行目の同じ列にある要素を初めに用意した空配列に格納する。
result_for.append(data_2d[0, i])
print(result_for)
# 計算終了時間の取得
time2 = time.time()
# 計算時間を ms で表示する。
calc_time1 = 1000*(time2 - time1)
print("for + if 文による計算時間: ", calc_time1, " ms \n")
# where文で 2行目の配列が 0.9以上の要素を持つ 1行目の要素を抽出する。
# 計算開始時間の取得
time3 = time.time()
# where 関数で2行目の配列から 0.9以上の要素を持つ列番号を抽出する。
conditional = np.where(data_2d[1, :] >= 0.9)
# 1行目の配列から、上の条件を満たした列番号に対応する要素を抽出する。
result_where = data_2d[0, conditional]
print(result_where)
# 計算終了時間の取得
time4 = time.time()
# 計算時間を ms で表示する。
calc_time2 = 1000*(time4 - time3)
print("where 関数による計算時間: ", calc_time2, " ms")import numpy as np
import time
# 0から 999,999までの数値をもつ行と、それに対応した 0から 1までの
# 乱数を持つ行を組み合わせた 2次元配列を作る。
line0 = np.arange(0, 1000000)
line1 = np.random.random(len(line0))
data_2d = np.array([line0, line1])
# for文と if文で 2行目の配列が 0.9以上の要素を持つ 1行目の要素を抽出する。
# 計算開始時間の取得
time1 = time.time()
# 結果を格納するための空配列を用意する。
result_for = []
# for文を配列の列数分だけ回す。
for i in range(0, data_2d.shape[1]):
# if文で 2行目の配列が 0.9以上かどうかを判定する。
if data_2d[1, i] >= 0.9:
# 1行目の同じ列にある要素を初めに用意した空配列に格納する。
result_for.append(data_2d[0, i])
print(result_for)
# 計算終了時間の取得
time2 = time.time()
# 計算時間を ms で表示する。
calc_time1 = 1000*(time2 - time1)
print("for + if 文による計算時間: ", calc_time1, " ms \n")
# where文で 2行目の配列が 0.9以上の要素を持つ 1行目の要素を抽出する。
# 計算開始時間の取得
time3 = time.time()
# where 関数で2行目の配列から 0.9以上の要素を持つ列番号を抽出する。
conditional = np.where(data_2d[1, :] >= 0.9)
# 1行目の配列から、上の条件を満たした列番号に対応する要素を抽出する。
result_where = data_2d[0, conditional]
print(result_where)
# 計算終了時間の取得
time4 = time.time()
# 計算時間を ms で表示する。
calc_time2 = 1000*(time4 - time3)
print("where 関数による計算時間: ", calc_time2, " ms")
実行結果:
for + if 文 184.9 ミリ秒
np.where 7.3 ミリ秒
なんと、np.where の方が for + if 文よりも 25倍も処理速度が速いという結果になりました。
処理速度は当然環境によって変わりますが、それでも forループを回避した方が圧倒的に処理が速いことがわかりました。
for ループ自体は C や Java など他の言語でも使用するので、知っていて損はないと思います。
しかし Python で for ループを回すのはやはり時間がかかるようです。
今回の例のように、データ配列の中から条件を満たすものを抽出する処理の場合は、NumPy の where関数を積極的に使ってみてください。
今後、forループ回避の他の例も紹介していければと思います。