Cythonの応用技法#
import numpy as np
import helper.magics
%load_ext helper.cython
拡張型#
Pythonではclass
を使って定義されたクラスは、インスタンスの属性を辞書で保存します。しかし、属性へのアクセス速度を向上させるため、Pythonの組み込み型は属性をオブジェクトの構造体のフィールドに直接保存します。Cythonではcdef class
を使って拡張型を定義できます。拡張型はPythonの組み込み型と同様に、C言語の構造体を使ってオブジェクトの属性を保存するため、Cythonプログラム内でこれらの属性に高速にアクセスできます。拡張型はC言語の関数ライブラリをラップし、オブジェクト指向のPythonインターフェースを提供するのに適しています。
拡張型の基本構造#
以下のプログラムでは、cdef class
を使って拡張型Point2D
を定義し、cdef
を使って属性x
とy
を定義しています。Pythonのクラスとは異なり、拡張型の属性はクラス内で定義され、__init__()
メソッド内で生成されるわけではありません。
%%cython
cdef class Point2D:
cdef public double x, y
Cythonは自動的に以下の構造体を定義してPoint2D
オブジェクトを表現します。ob_refcnt
とob_type
の2つのフィールドはすべてのPythonオブジェクトに必要なため、PythonのC言語コードでは通常PyObject_HEAD
マクロを使って定義されます。
struct __pyx_obj_Point2D {
PyObject_HEAD
double x;
double y;
};
Cythonプログラム内では、オブジェクトp
の型がPoint2D
であることが明確にわかっている場合、p.x
はPoint2D
構造体のx
フィールドに直接アクセスするように変換されるため、拡張型の変数の属性へのアクセスは非常に高速です。PythonでPoint2D
のx
とy
属性にアクセスするためには、属性を宣言する際にpublic
キーワードを使用する必要があります。Cythonは整数型、浮動小数点型、文字列型、およびPythonオブジェクト型の4種類のcdef
属性に対して、属性アクセス用のディスクリプタを自動的に作成します。これらのディスクリプタには__get__()
と__set__()
メソッドが含まれており、属性の取得と設定を行います。読み取り専用属性の場合は、public
キーワードをreadonly
に置き換えることができます。以下のコードはx
に対応する属性アクセスディスクリプタを表示します:
print(type(Point2D.x))
print(Point2D.x.__get__)
print(Point2D.x.__set__)
<class 'getset_descriptor'>
<method-wrapper '__get__' of getset_descriptor object at 0x0000028B617D1B40>
<method-wrapper '__set__' of getset_descriptor object at 0x0000028B617D1B40>
関数の定義と同様に、拡張型ではdef
、cdef
、cpdef
を使ってオブジェクトのメソッドを定義できます。すべてのメソッドはCython内で呼び出すことができますが、def
とcpdef
で定義されたメソッドのみがPython内で呼び出すことができます。Cython内でcdef
とcpdef
メソッドを呼び出す場合、対応するC言語関数が直接呼び出されるため、def
メソッドよりも効率が大幅に向上します。
拡張型は他の拡張型からの継承をサポートしています。例えば、以下のPoint3D
はPoint2D
から継承し、フィールドz
を追加しています:
%%cython -a
cdef class Point2D:
cdef public double x, y
cdef class Point3D(Point2D):
cdef public double z
cdef Point3D p = Point3D()
p.x = 1.0
p.y = 2.0
p.z = 3.0
Generated by Cython 3.0.12
Yellow lines hint at Python interaction.
Click on a line that starts with a "+
" to see the C code that Cython generated for it.
01:
+02: cdef class Point2D:
__pyx_t_2 = __Pyx_PyDict_NewPresized(0); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 2, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_2); if (PyDict_SetItem(__pyx_d, __pyx_n_s_test, __pyx_t_2) < 0) __PYX_ERR(0, 2, __pyx_L1_error) __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; /* … */ struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D { PyObject_HEAD double x; double y; };
+03: cdef public double x, y
/* Python wrapper */ static PyObject *__pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1x_1__get__(PyObject *__pyx_v_self); /*proto*/ static PyObject *__pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1x_1__get__(PyObject *__pyx_v_self) { CYTHON_UNUSED PyObject *const *__pyx_kwvalues; PyObject *__pyx_r = 0; __Pyx_RefNannyDeclarations __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); __pyx_r = __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1x___get__(((struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D *)__pyx_v_self)); /* function exit code */ __Pyx_RefNannyFinishContext(); return __pyx_r; } static PyObject *__pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1x___get__(struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D *__pyx_v_self) { PyObject *__pyx_r = NULL; __Pyx_XDECREF(__pyx_r); __pyx_t_1 = PyFloat_FromDouble(__pyx_v_self->x); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 3, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_r = __pyx_t_1; __pyx_t_1 = 0; goto __pyx_L0; /* function exit code */ __pyx_L1_error:; __Pyx_XDECREF(__pyx_t_1); __Pyx_AddTraceback("_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e.Point2D.x.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); __pyx_r = NULL; __pyx_L0:; __Pyx_XGIVEREF(__pyx_r); __Pyx_RefNannyFinishContext(); return __pyx_r; } /* Python wrapper */ static int __pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1x_3__set__(PyObject *__pyx_v_self, PyObject *__pyx_v_value); /*proto*/ static int __pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1x_3__set__(PyObject *__pyx_v_self, PyObject *__pyx_v_value) { CYTHON_UNUSED PyObject *const *__pyx_kwvalues; int __pyx_r; __Pyx_RefNannyDeclarations __Pyx_RefNannySetupContext("__set__ (wrapper)", 0); __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); __pyx_r = __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1x_2__set__(((struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D *)__pyx_v_self), ((PyObject *)__pyx_v_value)); /* function exit code */ __Pyx_RefNannyFinishContext(); return __pyx_r; } static int __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1x_2__set__(struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D *__pyx_v_self, PyObject *__pyx_v_value) { int __pyx_r; __pyx_t_1 = __pyx_PyFloat_AsDouble(__pyx_v_value); if (unlikely((__pyx_t_1 == (double)-1) && PyErr_Occurred())) __PYX_ERR(0, 3, __pyx_L1_error) __pyx_v_self->x = __pyx_t_1; /* function exit code */ __pyx_r = 0; goto __pyx_L0; __pyx_L1_error:; __Pyx_AddTraceback("_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e.Point2D.x.__set__", __pyx_clineno, __pyx_lineno, __pyx_filename); __pyx_r = -1; __pyx_L0:; return __pyx_r; } /* Python wrapper */ static PyObject *__pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1y_1__get__(PyObject *__pyx_v_self); /*proto*/ static PyObject *__pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1y_1__get__(PyObject *__pyx_v_self) { CYTHON_UNUSED PyObject *const *__pyx_kwvalues; PyObject *__pyx_r = 0; __Pyx_RefNannyDeclarations __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); __pyx_r = __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1y___get__(((struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D *)__pyx_v_self)); /* function exit code */ __Pyx_RefNannyFinishContext(); return __pyx_r; } static PyObject *__pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1y___get__(struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D *__pyx_v_self) { PyObject *__pyx_r = NULL; __Pyx_XDECREF(__pyx_r); __pyx_t_1 = PyFloat_FromDouble(__pyx_v_self->y); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 3, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_r = __pyx_t_1; __pyx_t_1 = 0; goto __pyx_L0; /* function exit code */ __pyx_L1_error:; __Pyx_XDECREF(__pyx_t_1); __Pyx_AddTraceback("_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e.Point2D.y.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); __pyx_r = NULL; __pyx_L0:; __Pyx_XGIVEREF(__pyx_r); __Pyx_RefNannyFinishContext(); return __pyx_r; } /* Python wrapper */ static int __pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1y_3__set__(PyObject *__pyx_v_self, PyObject *__pyx_v_value); /*proto*/ static int __pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1y_3__set__(PyObject *__pyx_v_self, PyObject *__pyx_v_value) { CYTHON_UNUSED PyObject *const *__pyx_kwvalues; int __pyx_r; __Pyx_RefNannyDeclarations __Pyx_RefNannySetupContext("__set__ (wrapper)", 0); __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); __pyx_r = __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1y_2__set__(((struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D *)__pyx_v_self), ((PyObject *)__pyx_v_value)); /* function exit code */ __Pyx_RefNannyFinishContext(); return __pyx_r; } static int __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1y_2__set__(struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D *__pyx_v_self, PyObject *__pyx_v_value) { int __pyx_r; __pyx_t_1 = __pyx_PyFloat_AsDouble(__pyx_v_value); if (unlikely((__pyx_t_1 == (double)-1) && PyErr_Occurred())) __PYX_ERR(0, 3, __pyx_L1_error) __pyx_v_self->y = __pyx_t_1; /* function exit code */ __pyx_r = 0; goto __pyx_L0; __pyx_L1_error:; __Pyx_AddTraceback("_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e.Point2D.y.__set__", __pyx_clineno, __pyx_lineno, __pyx_filename); __pyx_r = -1; __pyx_L0:; return __pyx_r; }
04:
05: cdef class Point3D(Point2D):
+06: cdef public double z
/* Python wrapper */ static PyObject *__pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point3D_1z_1__get__(PyObject *__pyx_v_self); /*proto*/ static PyObject *__pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point3D_1z_1__get__(PyObject *__pyx_v_self) { CYTHON_UNUSED PyObject *const *__pyx_kwvalues; PyObject *__pyx_r = 0; __Pyx_RefNannyDeclarations __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); __pyx_r = __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point3D_1z___get__(((struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point3D *)__pyx_v_self)); /* function exit code */ __Pyx_RefNannyFinishContext(); return __pyx_r; } static PyObject *__pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point3D_1z___get__(struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point3D *__pyx_v_self) { PyObject *__pyx_r = NULL; __Pyx_XDECREF(__pyx_r); __pyx_t_1 = PyFloat_FromDouble(__pyx_v_self->z); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 6, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_r = __pyx_t_1; __pyx_t_1 = 0; goto __pyx_L0; /* function exit code */ __pyx_L1_error:; __Pyx_XDECREF(__pyx_t_1); __Pyx_AddTraceback("_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e.Point3D.z.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); __pyx_r = NULL; __pyx_L0:; __Pyx_XGIVEREF(__pyx_r); __Pyx_RefNannyFinishContext(); return __pyx_r; } /* Python wrapper */ static int __pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point3D_1z_3__set__(PyObject *__pyx_v_self, PyObject *__pyx_v_value); /*proto*/ static int __pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point3D_1z_3__set__(PyObject *__pyx_v_self, PyObject *__pyx_v_value) { CYTHON_UNUSED PyObject *const *__pyx_kwvalues; int __pyx_r; __Pyx_RefNannyDeclarations __Pyx_RefNannySetupContext("__set__ (wrapper)", 0); __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); __pyx_r = __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point3D_1z_2__set__(((struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point3D *)__pyx_v_self), ((PyObject *)__pyx_v_value)); /* function exit code */ __Pyx_RefNannyFinishContext(); return __pyx_r; } static int __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point3D_1z_2__set__(struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point3D *__pyx_v_self, PyObject *__pyx_v_value) { int __pyx_r; __pyx_t_1 = __pyx_PyFloat_AsDouble(__pyx_v_value); if (unlikely((__pyx_t_1 == (double)-1) && PyErr_Occurred())) __PYX_ERR(0, 6, __pyx_L1_error) __pyx_v_self->z = __pyx_t_1; /* function exit code */ __pyx_r = 0; goto __pyx_L0; __pyx_L1_error:; __Pyx_AddTraceback("_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e.Point3D.z.__set__", __pyx_clineno, __pyx_lineno, __pyx_filename); __pyx_r = -1; __pyx_L0:; return __pyx_r; }
07:
+08: cdef Point3D p = Point3D()
__pyx_t_2 = __Pyx_PyObject_CallNoArg(((PyObject *)__pyx_ptype_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point3D)); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 8, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_2); __Pyx_XGOTREF((PyObject *)__pyx_v_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_p); __Pyx_DECREF_SET(__pyx_v_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_p, ((struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point3D *)__pyx_t_2)); __Pyx_GIVEREF(__pyx_t_2); __pyx_t_2 = 0;
+09: p.x = 1.0
__pyx_v_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_p->__pyx_base.x = 1.0;
+10: p.y = 2.0
__pyx_v_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_p->__pyx_base.y = 2.0;
+11: p.z = 3.0
__pyx_v_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_p->z = 3.0;
Point3D
オブジェクトに対応するC言語構造体は以下のようになります:
struct __pyx_obj_Point3D {
struct __pyx_obj_Point2D __pyx_base;
double z;
};
最初のフィールド__pyx_base
は、その基底クラスPoint2D
に対応する構造体です。Cythonでは__pyx_base.x
を使って基底クラスで定義された属性x
にアクセスし、PythonではPoint3D.__base__
内のx
に対応するディスクリプタを使ってこの属性にアクセスします。
1次元浮動小数点ベクトル型#
次に、1次元浮動小数点ベクトルVector
型を例に、拡張型の使用方法を紹介します。Vector
オブジェクトは2つのプライベート属性を持っています:count
は配列の長さを表し、data
は配列データの先頭アドレスを保持します。
Note
長いCythonコードを部分ごとに取り出して説明するため、本書では提供する%include
マジックコマンドを利用します。
%include cython_examples/vector.pyx 1
cdef class Vector:
cdef int count
cdef double * data
新しいVector
オブジェクトを作成する際、ヒープメモリから構造体が割り当てられます。構造体のメモリ割り当てが完了した後、すぐにその中の属性を初期化する必要があります。この初期化作業は__cinit__()
によって行われ、C言語レベルの__init__()
と考えることができます。Vector
オブジェクトは2つの初期化方法をサポートしています:指定されたサイズのメモリを割り当てるか、シーケンスオブジェクトを使って配列の内容を初期化します。ここではPython/C APIのPyMem_Malloc()
を使って、Pythonが管理するヒープから配列データを保存するメモリを割り当てます。このAPI関数を使用するためには、from cpython cimport mem
を使ってメモリ管理API関数のヘッダファイルを読み込む必要があります。
%include cython_examples/vector.pyx 2
def __cinit__(self, data):
cdef int i
if isinstance(data, int):
self.count = data
else:
self.count = len(data)
self.data = <double *>mem.PyMem_Malloc(sizeof(double)*self.count)
if self.data is NULL:
raise MemoryError
if not isinstance(data, int):
for i in range(self.count):
self.data[i] = data[i]
オブジェクトの参照カウントが0になると、ガベージコレクションによって回収されます。オブジェクト構造体が回収される前に、Cythonは__dealloc__()
を呼び出し、ここでdata
属性が指すメモリを解放する必要があります:
%include cython_examples/vector.pyx 3
def __dealloc__(self):
if self.data is not NULL:
mem.PyMem_Free(self.data)
Vector
オブジェクトが整数の添字アクセスとループによる反復をサポートするためには、__len__()
、__getitem__()
、__setitem__()
などのメソッドを定義する必要があります。これらの2つのアンダースコアで始まり終わるメソッドはマジックメソッドと呼ばれ、これらのメソッドを定義することで、特定の構文におけるオブジェクトの動作を変更できます。例えば、オブジェクトの長さを調べる組み込み関数len(obj)
は実際にはobj.__len__()
の値を返し、obj[index]
はobj.__getitem__(index)
を呼び出し、obj[index] = value
はobj.__setitem__(index, value)
を呼び出します。ある型が__len__()
と__getitem__()
を定義している場合、そのオブジェクトは自動的にfor
ループによる要素の反復をサポートします。
以下のプログラムでは、__getitem__()
と__setitem__()
はその添字引数が整数型であることを要求するため、Vector
オブジェクトはスライス添字をサポートしません。負の添字と添字の範囲外チェックのサポートを_check_index()
メソッドに分離し、このメソッドの引数は整数型のポインタで、渡された添字変数を直接変更できます。添字が範囲外の場合、IndexError
例外を発生させます。Cythonではp[0]
を使ってポインタ変数p
が指すアドレスにアクセスします。
%include cython_examples/vector.pyx 4
def __len__(self):
return self.count
cdef _check_index(self, int *index):
if index[0] < 0:
index[0] = self.count + index[0]
if index[0] < 0 or index[0] > self.count - 1:
raise IndexError("Vector index out of range")
def __getitem__(self, int index):
self._check_index(&index)
return self.data[index]
def __setitem__(self, int index, double value):
self._check_index(&index)
self.data[index] = value
Vector
オブジェクトが加算演算子をサポートするためには、__add__()
メソッドを定義する必要があります:
%include cython_examples/vector.pyx 5
def _add(self, other):
cdef Vector new, _other
if isinstance(other, Vector): #❷
_other = <Vector>other #❸
if self.count != _other.count:
raise ValueError("Vector size not equal")
new = Vector(self.count) #❹
add_array(self.data, _other.data, new.data, self.count)
return new
new = Vector(self.count) #❹
add_number(self.data, <double>other, new.data, self.count)
return new
def __add__(self, other): #❶
return self._add(other)
def __radd__(self, other): #❶
return self._add(other)
足し算や掛け算のような二項演算では、演算子が左右のオペランドに対してどのように動作するかを制御するために、特定のメソッドが使用されます。例えば、足し算に関するマジックメソッド __add__
と __radd__
について、a + b
のような式で、まず a.__add__(b)
が呼び出され、もしそれが失敗した場合は b.__radd__(a)
が実行されます。
v
がVector
オブジェクトの場合、1 + v
を実行すると、整数 1
の __add__()
関数呼び出しが失敗し、Vector.__radd__()
が呼び出され、引数 other
は 1
となります。
プログラムでは、❶ __add__()
と__radd__()
はヘルプメソッド_add()
を呼び出します。❷__add__()
は数値とVector
オブジェクトの両方を処理できるため、other
オブジェクトの型に応じて異なる処理を行う必要があります。❸self
とother
変数には型宣言がないため、C言語構造体に保存されている属性を取得できません。<Vector>
を使ってPythonオブジェクトを型付きの変数_self
と_other
に変換し、これらの変数を通じてcount
とdata
属性にアクセスします。❹演算結果を保存するVector
オブジェクトを作成し、add_array()
またはadd_number()
を呼び出して計算を行います。これらの関数のコードは後で紹介します。
以下は+=
演算子に対応するマジックメソッド__iadd__()
です:
%include cython_examples/vector.pyx 6
def __iadd__(self, other):
cdef Vector _other
if isinstance(other, Vector):
_other = <Vector>other
if self.count != _other.count:
raise ValueError("Vector size not equal")
add_array(self.data, _other.data, self.data, self.count)
else:
add_number(self.data, <double>other, self.data, self.count)
return self
二項演算子とは異なり、__iadd__()
の最初の引数は現在のオブジェクトであるため、Cythonはその型を知っており、型変換を行う必要はありません。
以下では、cpdef
を使ってnorm()
メソッドを定義し、ベクトルの長さを計算し、__str__()
でそれを呼び出します。オブジェクトを文字列に変換する際、__str__()
メソッドが呼び出されます。cpdef
で定義されたメソッドは、CythonとPythonの両方の呼び出しインターフェースを生成します。__str__()
ではCythonの呼び出しインターフェースを使ってnorm()
を実行し、Pythonではより遅いインターフェースを使ってnorm()
を呼び出します。
%include cython_examples/vector.pyx 7
def __str__(self):
values = ", ".join(str(self.data[i]) for i in range(self.count))
norm = self.norm()
return "Vector[{}]({})".format(norm, values)
cpdef norm(self):
cdef double *p
cdef double s
cdef int i
s = 0
p = self.data
for i in range(self.count):
s += p[i] * p[i]
return s**0.5
さらに、拡張型はPythonで継承でき、基底クラスで定義されたdef
とcpdef
メソッドをオーバーライドできます。Pythonで定義されたオーバーライドメソッドは、Cython内で正しく呼び出すことができます。Cython内でcpdef
メソッドを呼び出す場合、このメソッドがオーバーライドされているかどうかをチェックする必要があるため、cdef
メソッドよりも若干呼び出し速度が遅くなります。
最後に、コードの先頭に計算を行うadd_array()
とadd_number()
の2つの関数を追加します。これらはcdef
で定義され、Cythonコード内でのみ呼び出すことができます。
%include cython_examples/vector.pyx 8
cdef add_array(double *op1, double *op2, double *res, int count):
cdef int i
for i in range(count):
res[i] = op1[i] + op2[i]
cdef add_number(double *op1, double op2, double *res, int count):
cdef int i
for i in range(count):
res[i] = op1[i] + op2
!cythonize -i cython_examples/vector.pyx
他の二項計算関数や要素アクセス関数を完成させてみてください。以下はVector
オブジェクトの使用例です:
from cython_examples.vector import Vector
v1 = Vector(range(5))
v2 = Vector(range(100, 105))
print(len(v1))
print(v1 + v2)
print(v1 + 2)
print(20 + v2)
print(v1.norm(), v2.norm())
print([x**2 for x in v1])
5
Vector[232.63705637752557](100.0, 102.0, 104.0, 106.0, 108.0)
Vector[9.486832980505138](2.0, 3.0, 4.0, 5.0, 6.0)
Vector[272.81862106535175](120.0, 121.0, 122.0, 123.0, 124.0)
5.477225575051661 228.10085488660502
[0.0, 1.0, 4.0, 9.0, 16.0]
以下はVector
オブジェクトとNumPy配列のベクトル加算の速度比較です:
v1 = Vector(range(10000))
v2 = Vector(range(10000))
%timeit v1 + v2
a1 = np.arange(10000, dtype=float)
a2 = np.arange(10000, dtype=float)
%timeit a1 + a2
7.05 μs ± 61.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
7.23 μs ± 133 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
以下はVectorオブジェクトとNumPy配列の要素アクセス速度の比較です:
%timeit v1[100]
%timeit v1[100] = 2.0
%timeit a1[100]
%timeit a1[100] = 2.0
59.7 ns ± 0.726 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
62.7 ns ± 1.45 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
137 ns ± 2.18 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
122 ns ± 2.14 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
multifastライブラリのラッピング#
拡張型は、C言語の関数ライブラリをラップし、オブジェクト指向のPythonインターフェースを提供するためによく使用されます。このセクションでは、「多パターンマッチングアルゴリズム」のC言語ライブラリmultifast
をラップする例を通じて、C言語関数ライブラリを拡張型でラップする方法を紹介します。
See also
http://multifast.sourceforge.net/v1_4_2/ 多パターンマッチングアルゴリズムを実装したmultifastプロジェクト
以下はこの拡張型の使用方法のデモンストレーションです。まず、一連のキーワードを使ってMultiSearch
オブジェクトを作成し、そのisin()
メソッドを呼び出して、ターゲットバイト列内でキーワードを検索します。キーワードのいずれかがバイト列内に存在する場合、True
を返し、そうでない場合はFalse
を返します。
from cython_examples.multisearch import MultiSearch
ms = MultiSearch([b"abc", b"xyz"])
print(ms.isin(b"123abcdef"))
print(ms.isin(b"123uvwxyz"))
print(ms.isin(b"123456789"))
True
True
False
search()
メソッドは、ターゲットのバイト列内でキーワードの位置を検索するために使用できます。その第2引数はコールバック関数で、マッチする位置が見つかるたびにその位置とマッチしたキーワードをコールバック関数に渡します。コールバック関数が0を返すと検索を続行し、1を返すと検索を終了します。
def process(pos, pattern):
print("found {0} at {1}".format(pattern, pos))
return 0
ms.search(b"123abc456xyz789abc", process)
found b'abc' at 3
found b'xyz' at 9
found b'abc' at 15
また、iter_search()
メソッドを使用してイテレータを返すこともできます:
for pos, pattern in ms.iter_search(b"123abc456xyz789abc"):
print("found {0} at {1}".format(pattern, pos))
found b'abc' at 3
found b'xyz' at 9
found b'abc' at 15
拡張タイプの作成を開始する前に、C言語でこのライブラリをどのように使用するかを見てみましょう:
#include <stdio.h>
#include "ahocorasick.h"
/* 検索キーワードリスト */
AC_ALPHABET_t * allstr[] = {
"recent", "from", "college"
};
#define PATTERN_NUMBER (sizeof(allstr)/sizeof(AC_ALPHABET_t *))
/* 検索テキスト */
AC_ALPHABET_t * input_text = {"She recently graduated from college"};
//*** マッチ時のコールバック関数
int match_handler(AC_MATCH_t * m, void * param)
{
unsigned int j;
printf ("@ %ld : %s
", m->position, m->patterns->astring);
/* 0を返すと検索を続行し、1を返すと検索を停止 */
return 0;
}
int main (int argc, char ** argv)
{
unsigned int i;
AC_AUTOMATA_t * acap;
AC_PATTERN_t tmp_patt;
AC_TEXT_t tmp_text;
//*** AC_AUTOMATA_t構造体を作成し、コールバック関数を渡す
acap = ac_automata_init();
//*** キーワードを追加
for (i=0; i<PATTERN_NUMBER; i++)
{
tmp_patt.astring = allstr[i];
tmp_patt.rep.number = i+1; // optional
tmp_patt.length = strlen(tmp_patt.astring);
ac_automata_add (acap, &tmp_patt);
}
//*** キーワードの追加を終了
ac_automata_finalize (acap);
//*** 検索対象のバイト列を設定
tmp_text.astring = input_text;
tmp_text.length = strlen(tmp_text.astring);
//*** 検索を実行
ac_automata_search (acap, &tmp_text, 0, match_handler, NULL);
//*** メモリを解放
ac_automata_release (acap);
return 0;
}
上記のプログラムからわかるように、このライブラリ全体はAC_AUTOMATA_t
構造体を中心に処理されています。これはC言語でデータをカプセル化する一般的な方法です。Cythonで拡張タイプを使用してこのようなライブラリをラップする場合、通常はこの構造体へのポインタ属性を作成し、__cinit__()
と__dealloc__()
でこの構造体を割り当てたり解放したりします。その後、C言語ライブラリが提供する各API関数を呼び出すいくつかのdef
メソッドを定義して、ラッピングを実現します。
以下では、C言語のライブラリを拡張タイプでラップする方法を段階的に説明します。
%include cython_examples/multisearch.pyx 1
cdef extern from "ahocorasick.h": #❶
ctypedef int (*AC_MATCH_CALBACK_f)(AC_MATCH_t *, void *) #❷
ctypedef enum AC_STATUS_t: #❸
ACERR_SUCCESS = 0
ACERR_DUPLICATE_PATTERN
ACERR_LONG_PATTERN
ACERR_ZERO_PATTERN
ACERR_AUTOMATA_CLOSED
ctypedef struct AC_MATCH_t: #❹
AC_PATTERN_t * patterns
long position
unsigned int match_num
ctypedef struct AC_AUTOMATA_t:
AC_MATCH_t match
ctypedef struct AC_PATTERN_t:
char * astring
unsigned int length
ctypedef struct AC_TEXT_t:
char * astring
unsigned int length
#❺
AC_AUTOMATA_t * ac_automata_init()
AC_STATUS_t ac_automata_add(AC_AUTOMATA_t * thiz, AC_PATTERN_t * pattern)
void ac_automata_finalize(AC_AUTOMATA_t * thiz)
int ac_automata_search(AC_AUTOMATA_t * thiz, AC_TEXT_t * text, int keep,
AC_MATCH_CALBACK_f callback, void * param)
void ac_automata_settext (AC_AUTOMATA_t * thiz, AC_TEXT_t * text, int keep)
AC_MATCH_t * ac_automata_findnext (AC_AUTOMATA_t * thiz)
void ac_automata_release(AC_AUTOMATA_t * thiz)
❶まず、cdef extern from ...
を使用してCythonに、コンパイル後のC言語プログラムにahocorasick.h
ヘッダーファイルを含める必要があることを伝えます。CythonはC言語のヘッダーファイルを自動的に解析しないため、使用する型、定数、関数プロトタイプをCythonの構文で宣言する必要があります。
❷関数ポインタ型MATCH_CALBACK_f
を定義します。これはコールバック関数へのポインタ型です。その第1引数は、マッチデータを保存する構造体へのポインタで、第2引数は任意の追加データを指すポインタです。C言語では通常、このようなvoid *
型のポインタを使用してユーザー定義データを渡します。
❸❹列挙型と構造体型を定義します。Cythonプログラムで使用する列挙メンバーと構造体のフィールドのみを定義すれば十分です。Cythonで構造体のフィールドにアクセスしない場合は、フィールドの定義の代わりにpass
キーワードを使用できます。
❺Cythonプログラムで呼び出す関数のプロトタイプを定義します。
上記のプログラムをC言語プログラムにコンパイルすると、#include "ahocorasick.h"
という1行だけが残り、残りの型宣言はCythonにこれらの型を操作するステートメントをどのようにコンパイルするかを指示します。例えば、構造体AC_PATTERN_t
のlength
フィールドはunsigned int
型として宣言されているため、必要に応じてCythonはPython/C APIを呼び出して、Pythonの整数オブジェクトとunsigned int
型の間で変換を行います。
次に、MultiSearch
拡張タイプの定義を示します:
%include cython_examples/multisearch.pyx 2
cdef class MultiSearch:
cdef AC_AUTOMATA_t * _auto #❶
cdef bint found
cdef object callback
cdef object exc_info
cdef list _keywords
def __cinit__(self, keywords):
self._auto = ac_automata_init()
if self._auto is NULL:
raise MemoryError
self._keywords = []
self.add(keywords) #❷
def __dealloc__(self):
if self._auto is not NULL:
ac_automata_release(self._auto)
cdef add(self, keywords):
cdef AC_PATTERN_t pattern
cdef bytes keyword
cdef AC_STATUS_t err
for keyword in keywords: #❸
self._keywords.append(keyword)
pattern.astring = <char *>keyword
pattern.length = len(keyword)
err = ac_automata_add(self._auto, &pattern)
if err != ACERR_SUCCESS:
raise ValueError("Error Code:%d" % err)
ac_automata_finalize(self._auto)
❶_auto
属性はAC_AUTOMATA_t
構造体へのポインタです。__cinit__()
でac_automata_init()
を呼び出してメモリを割り当て、__dealloc__()
でac_automata_release()
を呼び出してメモリを解放します。❷AC_AUTOMATA_t
構造体の割り当てが成功した後、cdef
メソッドadd()
を呼び出してすべてのキーワードをこの構造体に追加します。
❸add()
内部でkeywords
パラメータをイテレートし、各要素をbytes
型として処理し、<char *>
を使用してC言語の文字ポインタ型に変換し、その長さとともにAC_PATTERN_t
構造体にパッケージ化してac_automata_add()
に渡します。この関数が返された後、ahocorasick
内部の関数は文字ポインタが指す内容を使用しないため、この方法は安全です。後続の関数呼び出しで文字ポインタが指す内容を使用する必要がある場合は、keywords
内の各バイト列オブジェクトを参照して、それらが早期にガベージコレクションされないようにする必要があります。
次に、isin()
メソッドの定義を示します:
%include cython_examples/multisearch.pyx 3
def isin(self, bytes text, bint keep=False):
cdef AC_TEXT_t temp_text #❶
temp_text.astring = <char *>text
temp_text.length = len(text)
self.found = False #❷
ac_automata_search(self._auto, &temp_text, keep, isin_callback, <void *>self) #❸
return self.found
❶ターゲットのバイト列をAC_TEXT_t
構造体にパッケージ化し、❸ac_automata_search()
に渡して検索を実行します。検索の前に、❷found
属性をFalse
に設定します。isin_callback()
関数のアドレスをac_automata_search()
に渡して検索のコールバック関数として使用します。その最後のパラメータはコールバック関数に渡されるユーザーデータで、ここではMultiSearch
オブジェクトのアドレスを渡します。これにより、コールバック関数内でMultiSearch
オブジェクトの属性にアクセスできます。
以下はisin_callback()
コールバック関数の定義です。この関数はC言語ライブラリ内部で呼び出されるため、cdef
でのみ定義できることに注意してください:
%include cython_examples/multisearch.pyx 4
cdef int isin_callback(AC_MATCH_t * match, void * param) noexcept:
cdef MultiSearch ms = <MultiSearch> param #❶
ms.found = True #❷
return 1 #❸
isin_callback()
の第1引数はマッチ情報を記述する構造体へのポインタで、第2引数はMultiSearch
オブジェクトへのポインタです。❶まず、void *
型のポインタをMultiSearch
オブジェクトに変換し、ms
を通じてMultiSearch
拡張クラスで定義されたさまざまな属性やメソッドにアクセスできます。❷MultiSearch
オブジェクトのfound
属性をTrue
に設定し、マッチ位置が見つかったことを示します。❸isin()
は1つのマッチ位置を見つけるだけで十分なため、関数は1を返して検索を終了します。
次に、search()
を定義します。その第1引数は検索対象のバイト列で、第2引数はPythonの呼び出し可能オブジェクトです。マッチ位置が見つかるたびにこのオブジェクトを呼び出して処理します:
Tip
C言語のコールバック関数とPythonのコールバック関数の使用方法を紹介するために、ここではac_automata_search()
を使用していますが、実際には後述のac_automata_findnext()
を使用すると、isin()
とsearch()
関数をより簡単に記述できます。
%include cython_examples/multisearch.pyx 5
def search(self, bytes text, callback, bint keep=False):
cdef AC_TEXT_t temp_text
temp_text.astring = <char *>text
temp_text.length = len(text)
self.found = False
self.callback = callback #❶
self.exc_info = None
ac_automata_search(self._auto, &temp_text, keep, search_callback, <void *>self) #❷
if self.exc_info is not None:
raise self.exc_info[1], None, self.exc_info[2] #❸
❷ac_automata_search()
のコールバック関数はsearch_callback()
です。Pythonの呼び出し可能オブジェクトを呼び出すために、❶callback
パラメータをself.callback
に渡します。❸C言語の関数はPython関数がスローする例外情報を上位に伝えることができないため、exc_info
属性を使用してPythonコールバック関数でスローされる可能性のある例外を伝える必要があります。
以下は search_callback()
コールバック関数の定義です。noexcept
を指定することで、この関数内では例外が発生しないことを宣言します。
%include cython_examples/multisearch.pyx 6
cdef int search_callback(AC_MATCH_t * match, void * param) noexcept:
cdef MultiSearch ms = <MultiSearch> param
cdef bytes pattern = match.patterns.astring
cdef int res = 1
try:
res = ms.callback(match.position - len(pattern), pattern) #❶
except Exception as ex:
import sys
ms.exc_info = sys.exc_info() #❷
return res
第2引数をMultiSearch
オブジェクトに変換した後、❶callback
が指すPythonコールバック関数を呼び出し、スローされる可能性のある例外を捕捉します。❷例外とそのトレースバック情報をexc_info
属性に保存します。search()
の最後でexc_info
属性をチェックし、設定されている場合はその中の例外オブジェクトをスローします。ここでは、キャプチャされた例外を直接保存するのではなく、sys.exc_info()
を使用して例外情報を取得していることに注意してください。これにより、エラーの位置を正しく示すトレースバック情報が得られます。
multifastのC言語プログラムでは、ac_automata_settext()
とac_automata_findnext()
も提供されています。これら2つの関数を使用して、以下のようなジェネレータ関数iter_search()
を記述できます。
%include cython_examples/multisearch.pyx 7
def iter_search(self, bytes text, bint keep=False):
cdef AC_TEXT_t temp_text
cdef AC_MATCH_t * match
cdef bytes matched_pattern
temp_text.astring = <char *>text
temp_text.length = len(text)
ac_automata_settext(self._auto, &temp_text, keep)
while True:
match = ac_automata_findnext(self._auto)
if match == NULL:
break
matched_pattern = <bytes>match.patterns.astring
yield match.position - len(matched_pattern), matched_pattern
Cythonのテクニック集#
前の章の紹介を通じて、読者はCythonを使用してPythonプログラムの計算速度を向上させる基本的な方法をすでに習得していると思います。本章の最後の節として、Cythonのいくつかの高度な使用テクニックを見てみましょう。
ufunc 関数の作成#
NumPy の ufunc(ユニバーサル関数)は、配列の各要素に対して効率的に演算を行うための関数です。NumPy の C-API を利用すると、C 言語を用いて独自の ufunc 関数を作成できます。興味のある読者は、以下の URL から関連するチュートリアルを参照してください。
また、@cython.ufunc
を使用すると、通常の関数をufunc関数に変換できます。以下のプログラムは、ロジスティック関数 \(\frac{1}{1+e^{-x}}\) を計算する ufunc logistic()
を作成する例です。
%%cython --compile-args=-w
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
cimport cython
from libc.math cimport exp
@cython.ufunc
cdef double logistic(double x):
return 1.0 / (1.0 + exp(-x))
以下は logistic()
関数のテストです。double
型の計算関数のみを定義していますが、さまざまな型の配列やリストを処理できます。
logistic([-1, 0, 1])
array([0.26894142, 0.5 , 0.73105858])
float32
型の配列を入力しても、計算結果は float64
型になります。これは、logistic()
関数が double
型の入力と出力のみをサポートしているためです。
logistic(np.arange(-1, 2, dtype=np.float32))
array([0.26894142, 0.5 , 0.73105858])
float32
と float64
の両方をサポートしたい場合は、次のように融合型 float_num
を宣言し、関数の入出力のデータ型を float_num
で定義します。
%%cython --compile-args=-w
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
cimport cython
from libc.math cimport exp
ctypedef fused float_num:
double
float
@cython.ufunc
cdef float_num logistic2(float_num x):
return 1.0 / (1.0 + exp(-x))
整数を入力した場合の結果は float64
になりますが、float32
を入力した場合の出力は float32
になります。
print(logistic2([-1, 0, 1]).dtype)
print(logistic2(np.arange(-1, 2, dtype=np.float32)).dtype)
float64
float32
複数の引数を持つ関数もサポートできます。例えば、次の peaks()
関数は 2 次元関数を計算します。
%%cython --compile-args=-w
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
cimport cython
from libc.math cimport exp
@cython.ufunc
cdef double peaks(double x, double y):
return x * exp(-x*x - y*y)
Y, X = np.ogrid[-2:2:5j, -2:2:5j]
peaks(X, Y)
array([[-0.00067093, -0.00673795, 0. , 0.00673795, 0.00067093],
[-0.01347589, -0.13533528, 0. , 0.13533528, 0.01347589],
[-0.03663128, -0.36787944, 0. , 0.36787944, 0.03663128],
[-0.01347589, -0.13533528, 0. , 0.13533528, 0.01347589],
[-0.00067093, -0.00673795, 0. , 0.00673795, 0.00067093]])
複数の出力がある場合、タプルでリターン型を宣言します。次の例では、二次方程式の二つの解を求める ufunc 関数を定義します。引数のデータ型は double
ですが、リターン値は二つの complex
型の値です。
%%cython --compile-args=-w
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
cimport cython
from libc.complex cimport csqrt
@cython.ufunc
cdef (complex, complex) quadratic_roots(double a, double b, double c):
cdef complex t0, t1
t0 = csqrt(-4*a*c + b**2)
t1 = 0.5 / a
return t1*(-b - t0), t1*(-b + t0)
a, b, c = np.ogrid[-1:1:4j, -1:1:3j, -1:1:2j]
x0, x1 = quadratic_roots(a, b, c)
print(x0.shape)
print(x1.shape)
print(np.allclose(a * x0 ** 2 + b * x0 + c, 0))
print(np.allclose(a * x1 ** 2 + b * x1 + c, 0))
(4, 3, 2)
(4, 3, 2)
True
True
DLL内の関数の高速呼び出し#
Pythonの標準ライブラリctypes
または拡張ライブラリcffi
を使用すると、ダイナミックリンクライブラリ内の関数を簡単に呼び出すことができます。ただし、実際の関数を呼び出す前に多くの前処理が必要なため、関数呼び出しの効率は高くありません。ループ内で大量に呼び出す必要がある場合、この前処理によるオーバーヘッドがプログラムの実行速度に大きく影響します。このセクションでは、cffi
を使用してダイナミックリンクライブラリ内の関数のアドレスを取得し、そのアドレスをCythonの関数に渡してループ呼び出しを行うことで、関数呼び出しの効率を向上させる方法を紹介します。
まず、C言語で関数peaks()
を作成します。この関数のアドレスを正しく取得できることを確認するために、get_addr()
を使用してpeaks()
のアドレスを返します:
%%writefile peaks.c
#include <math.h>
double peaks(double x, double y)
{
return x * exp(-x*x - y*y);
}
unsigned long long get_addr()
{
return (unsigned long long)(void *)peaks;
}
Overwriting peaks.c
次に、gcc
を呼び出してpeaks.c
をpeaks.dll
にコンパイルします:
!gcc -Ofast -shared -o peaks.dll peaks.c
次に、cffi
を使用してpeaks.dll
を読み込み、関数シグネチャを設定し、最後にその関数を呼び出します:
import cffi
ffi = cffi.FFI()
ffi.cdef(
"""
unsigned long long get_addr();
double peaks(double x, double y);
"""
)
lib = ffi.dlopen("peaks.dll")
lib.peaks(1.0, 2.0)
0.006737946999085467
lib.peaks
は、C言語の関数アドレスをラップしたcffi
のオブジェクトです。ffi.addressof()
を使用してこの関数へのポインタを取得し、ffi.cast()
を使用してポインタを整数に変換できます。以下のプログラムは、C言語関数のアドレスを取得し、get_addr()
の戻り値と比較します:
peak_addr = ffi.cast("size_t", ffi.addressof(lib, "peaks"))
assert peak_addr == lib.get_addr()
以下は、Cythonでvectorize_2d()
関数を作成する例です。func_addr
パラメータは関数アドレスを表す整数で、x
とy
は2次元の連続した配列です。
❶まず、ctypedef
キーワードを使用して、二項倍精度浮動小数点数関数ポインタ型Function
を宣言します。❷次に、パラメータfunc_addr
をFunction
型の関数ポインタに変換し、その後、二重ループ内でこの関数ポインタを使用して、それが指すC言語関数を高速に呼び出すことができます。
%%cython
import cython
import numpy as np
ctypedef double(*Function)(double x, double y) #❶
@cython.wraparound(False)
@cython.boundscheck(False)
def vectorize_2d(size_t func_addr, double[:, ::1] x, double[:, ::1] y):
cdef double[:, ::1] res = np.zeros_like(x.base)
cdef Function func_ptr = <Function><void *>func_addr #❷
cdef int i, j
for i in range(x.shape[0]):
for j in range(x.shape[1]):
res[i, j] = func_ptr(x[i, j], y[i, j])
return res.base
以下は、vectorize_2d()
を使用してpeaks()
を呼び出し、vectorize()
を使用して作成したufunc関数の結果と比較する例です:
Y, X = np.mgrid[-2:2:200j, -2:2:200j]
vectorize_peaks = np.vectorize(lib.peaks, otypes=["f8"])
np.allclose(vectorize_peaks(X, Y), vectorize_2d(peak_addr, X, Y))
True
本書で提供されているFuncAddr
クラスを使用すると、ダイナミックリンクライブラリ内の関数のアドレスをより簡単に取得できます:
from helper.cffi import FuncAddr
msvcrt = FuncAddr("msvcrt.dll")
np.allclose(np.arctan2(X, Y), vectorize_2d(msvcrt.atan2, X, Y))
True
BLAS関数の呼び出し#
BLASは、基本的な線形代数ルーチンのAPI標準であり、SciPyの多くの高速線形代数演算関数は、Fortranで書かれたBLAS関数を内部で呼び出しています。これらのFortran関数の計算効率は高いですが、Pythonの呼び出しインターフェースによるオーバーヘッドは無視できず、大量のループで呼び出す場合にはこのオーバーヘッドが問題となります。Cythonを使用してこれらの関数をループ内で呼び出すことで、Pythonの呼び出しインターフェースの制約から完全に解放されます。
saxpy()
関数のラッピング#
BLASのAPI関数はscipy.linalg.blas
モジュールを通じてアクセスできます。以下は、その中のsaxpy()
を呼び出すデモンストレーションです:
from scipy.linalg import blas
import numpy as np
x = np.array([1, 2, 3], np.float32)
y = np.array([1, 3, 5], np.float32)
print(blas.saxpy)
blas.saxpy(x, y, a=0.5)
<fortran function saxpy>
array([1.5, 4. , 6.5], dtype=float32)
saxpy
は<fortran object>
であり、呼び出すたびにPythonのオブジェクトをFortran関数のパラメータに変換するため、ループ内で大量に呼び出すには適していません。
scipy.linalg
には、Cythonから呼び出すためのcython_blas
モジュールも提供されています。その中のsaxpy()
の関数シグネチャは以下の通りです:
ctypedef float s
cdef void saxpy(int *n, s *sa, s *sx, int *incx, s *sy, int *incy) nogil
Fortran言語は参照渡しを使用するため、関数呼び出し時に渡されるパラメータと関数内で受け取るパラメータは同じメモリアドレスです。したがって、関数シグネチャではすべてのパラメータがポインタとして定義されています。
以下のプログラムでは、blas_saxpy()
がcython_blas
モジュール内のsaxpy()
を簡単にラップしています。計算速度を比較するために、cython_saxpy()
はループを使用して計算を行います:
%%cython
import cython
cimport scipy.linalg.cython_blas as blas
def blas_saxpy(float[:] y, float a, float[:] x):
cdef int n = y.shape[0]
cdef int inc_x = x.strides[0] // sizeof(float)
cdef int inc_y = y.strides[0] // sizeof(float)
blas.saxpy(&n, &a, &x[0], &inc_x, &y[0], &inc_y)
@cython.wraparound(False)
@cython.boundscheck(False)
def cython_saxpy(float[:] y, float a, float[:] x):
cdef int i
for i in range(y.shape[0]):
y[i] += a * x[i]
以下で両者の実行速度を比較します。Cythonのループで実装した関数の方が速いです。
a = np.arange(100000, dtype=np.float32)
b = np.zeros_like(a)
%timeit blas_saxpy(b, 0.2, a)
%timeit cython_saxpy(b, 0.2, a)
164 μs ± 7.5 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
96 μs ± 1.39 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
dgemm()
高速行列積#
BLASのDGEMM()
は以下の行列積演算を実装しています。パラメータalpha
が1、beta
が0の場合、結果C
は行列A
とB
の積となります:
C = alpha*op(A)*op(B) + beta*C
ここで、op()
は行列を転置することができます。Fortran形式の配列とC言語形式の配列の軸の順序が逆であるため、2つのC言語の配列で表される行列の積を計算するには、op()
を転置に設定する必要があります。転置するかどうかに関わらず、演算結果C
はFortran形式の配列となります。
Fortran関数dgemm()
のパラメータは以下の通りです:
subroutine dgemm (
character TRANSA,
character TRANSB,
integer M,
integer N,
integer K,
double precision ALPHA,
double precision, dimension(lda,*) A,
integer LDA,
double precision, dimension(ldb,*) B,
integer LDB,
double precision BETA,
double precision, dimension(ldc,*) C,
integer LDC
)
以下は、cython_blas
モジュール内の関数シグネチャです:
ctypedef double d
cdef void dgemm(
char *transa,
char *transb,
int *m,
int *n,
int *k,
d *alpha,
d *a,
int *lda,
d *b,
int *ldb,
d *beta,
d *c,
int *ldc) nogil
以下のCython関数dgemm(A, B, index)
では、A
とB
はC言語形式の3次元配列で、形状はそれぞれ(La, M, K)
と(Lb, K, N)
です。index
は形状が(Lc, 2)
の整数配列です。この関数は、index
内の各整数ペアj, k
に対してC[i] = A[j] * B[k]
を計算します。ここで、i
はその整数ペアのインデックスです。したがって、関数の戻り値C
は形状が(Lc, N, M)
の3次元配列です。C
はLc
個のFortran形式の2次元配列と見なすことができます。
メモリビューの要素にアクセスする際にはPythonに関連する操作は行われず、各行列の積演算は互いに独立しているため、この部分を並列化することができます。CythonはOpenMPを使用して並列化を実装しているため、コンパイル時に-fopenmp
オプションを設定する必要があります。
❶並列化されたprange()
関数を読み込みます。この関数は並列化されたループにコンパイルされます。❷nogil
パラメータをTrue
に設定し、並列化中にPythonのグローバルロックを解放することを示します。
%%cython -c-Ofast -c-fopenmp --link-args=-fopenmp
from cython.parallel import prange #❶
import cython
import numpy as np
cimport scipy.linalg.cython_blas as blas
@cython.wraparound(False)
@cython.boundscheck(False)
def dgemm(double[:, :, :] A, double[:, :, :] B, int[:, ::1] index):
cdef int m, n, k, i, length, idx_a, idx_b
cdef double[:, :, :] C
cdef char ta, tb
cdef double alpha = 1.0
cdef double beta = 0.0
length = index.shape[0]
m, k, n = A.shape[1], A.shape[2], B.shape[2]
C = np.zeros((length, n, m))
ta = b"T"
tb = ta
for i in prange(length, nogil=True): #❷
idx_a = index[i, 0]
idx_b = index[i, 1]
blas.dgemm(&ta, &tb, &m, &n, &k, &alpha,
&A[idx_a, 0, 0], &k,
&B[idx_b, 0, 0], &n,
&beta,
&C[i, 0, 0], &m)
return C.base
NumPyに新しく追加されたgufunc関数は、単一の行列の演算をブロードキャストして配列全体に適用することができます。NumPyに新しく追加された行列積演算子@
は、行列積のブロードキャスト演算を実現します。同様の機能は、上記のdgemm()
を使用して実現することもできます。以下のmatrix_multiply(a, b)
は、2つの任意の次元数の配列の最後の2つの軸に対して行列積演算を行い、他の軸に対してブロードキャスト演算を行います。例えば、a
の形状が(12, 1, 10, 100, 30)
で、b
の形状が( 1, 15, 1, 30, 50)
の場合、最後の2つの軸に対応する行列積の結果の形状は(100, 50)
であり、他の軸のブロードキャスト後の形状は(12, 15, 10)
であるため、結果の配列の形状は(12, 15, 10, 100, 50)
となります。合計で\(12 \times 15 \times 10\)回の行列積演算が行われます。
このプログラムの実装の考え方は以下の通りです。詳細は読者自身で研究してください:
a
内の各行列に番号を付け、その番号の形状をa
のブロードキャスト部分の形状に変更してidx_a
を得ます。b
に対しても同様の操作を行い、idx_b
を得ます。broadcast_arrays()
を使用して、idx_a
とidx_b
のブロードキャスト後の配列を計算します。上記の2つの配列を平坦化して2列に並べ、
dgemm()
関数のindex
パラメータを得ます。a
とb
の形状を3次元配列に変更してdgemm()
関数に渡し、行列積を計算します。
def matrix_multiply(a, b):
if a.ndim <= 2 and b.ndim <= 2:
return np.dot(a, b)
a = np.ascontiguousarray(a).astype(np.float64, copy=False)
b = np.ascontiguousarray(b).astype(np.float64, copy=False)
if a.ndim == 2:
a = a[None, :, :]
if b.ndim == 2:
b = b[None, :, :]
shape_a = a.shape[:-2]
shape_b = b.shape[:-2]
len_a = np.prod(shape_a)
len_b = np.prod(shape_b)
idx_a = np.arange(len_a, dtype=np.int32).reshape(shape_a)
idx_b = np.arange(len_b, dtype=np.int32).reshape(shape_b)
idx_a, idx_b = np.broadcast_arrays(idx_a, idx_b)
index = np.column_stack((idx_a.ravel(), idx_b.ravel()))
bshape = idx_a.shape
if a.ndim > 3:
a = a.reshape(-1, a.shape[-2], a.shape[-1])
if b.ndim > 3:
b = b.reshape(-1, b.shape[-2], b.shape[-1])
if a.shape[-1] != b.shape[-2]:
raise ValueError("can't do matrix multiply because k isn't the same")
c = dgemm(a, b, index)
c = np.swapaxes(c, -2, -1)
c.shape = bshape + c.shape[-2:]
return c
以下でmatrix_multiply()
と@
演算子の計算結果を比較します:
a = np.random.rand(12, 1, 10, 100, 30)
b = np.random.rand(1, 15, 1, 30, 50)
np.allclose(matrix_multiply(a, b), a @ b)
True
以下は両者の実行速度の比較です。matrix_multiply()
はNumPyの組み込み行列積演算子より高速です:
%timeit matrix_multiply(a, b)
%timeit a @ b
177 ms ± 4.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
414 ms ± 116 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)