Cythonの応用技法#

import numpy as np
import helper.magics

%load_ext helper.cython

拡張型#

Pythonではclassを使って定義されたクラスは、インスタンスの属性を辞書で保存します。しかし、属性へのアクセス速度を向上させるため、Pythonの組み込み型は属性をオブジェクトの構造体のフィールドに直接保存します。Cythonではcdef classを使って拡張型を定義できます。拡張型はPythonの組み込み型と同様に、C言語の構造体を使ってオブジェクトの属性を保存するため、Cythonプログラム内でこれらの属性に高速にアクセスできます。拡張型はC言語の関数ライブラリをラップし、オブジェクト指向のPythonインターフェースを提供するのに適しています。

拡張型の基本構造#

以下のプログラムでは、cdef classを使って拡張型Point2Dを定義し、cdefを使って属性xyを定義しています。Pythonのクラスとは異なり、拡張型の属性はクラス内で定義され、__init__()メソッド内で生成されるわけではありません。

%%cython

cdef class Point2D:
    cdef public double x, y

Cythonは自動的に以下の構造体を定義してPoint2Dオブジェクトを表現します。ob_refcntob_typeの2つのフィールドはすべてのPythonオブジェクトに必要なため、PythonのC言語コードでは通常PyObject_HEADマクロを使って定義されます。

struct __pyx_obj_Point2D {
  PyObject_HEAD
  double x;
  double y;
};

Cythonプログラム内では、オブジェクトpの型がPoint2Dであることが明確にわかっている場合、p.xPoint2D構造体のxフィールドに直接アクセスするように変換されるため、拡張型の変数の属性へのアクセスは非常に高速です。PythonでPoint2Dxy属性にアクセスするためには、属性を宣言する際にpublicキーワードを使用する必要があります。Cythonは整数型、浮動小数点型、文字列型、およびPythonオブジェクト型の4種類のcdef属性に対して、属性アクセス用のディスクリプタを自動的に作成します。これらのディスクリプタには__get__()__set__()メソッドが含まれており、属性の取得と設定を行います。読み取り専用属性の場合は、publicキーワードをreadonlyに置き換えることができます。以下のコードはxに対応する属性アクセスディスクリプタを表示します:

print(type(Point2D.x))
print(Point2D.x.__get__)
print(Point2D.x.__set__)
<class 'getset_descriptor'>
<method-wrapper '__get__' of getset_descriptor object at 0x0000028B617D1B40>
<method-wrapper '__set__' of getset_descriptor object at 0x0000028B617D1B40>

関数の定義と同様に、拡張型ではdefcdefcpdefを使ってオブジェクトのメソッドを定義できます。すべてのメソッドはCython内で呼び出すことができますが、defcpdefで定義されたメソッドのみがPython内で呼び出すことができます。Cython内でcdefcpdefメソッドを呼び出す場合、対応するC言語関数が直接呼び出されるため、defメソッドよりも効率が大幅に向上します。

拡張型は他の拡張型からの継承をサポートしています。例えば、以下のPoint3DPoint2Dから継承し、フィールドzを追加しています:

%%cython -a

cdef class Point2D:
    cdef public double x, y

cdef class Point3D(Point2D):
    cdef public double z
    
cdef Point3D p = Point3D()
p.x = 1.0
p.y = 2.0
p.z = 3.0

Generated by Cython 3.0.12

Yellow lines hint at Python interaction.
Click on a line that starts with a "+" to see the C code that Cython generated for it.

 01: 
+02: cdef class Point2D:
  __pyx_t_2 = __Pyx_PyDict_NewPresized(0); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 2, __pyx_L1_error)
  __Pyx_GOTREF(__pyx_t_2);
  if (PyDict_SetItem(__pyx_d, __pyx_n_s_test, __pyx_t_2) < 0) __PYX_ERR(0, 2, __pyx_L1_error)
  __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
/* … */
struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D {
  PyObject_HEAD
  double x;
  double y;
};

+03:     cdef public double x, y
/* Python wrapper */
static PyObject *__pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1x_1__get__(PyObject *__pyx_v_self); /*proto*/
static PyObject *__pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1x_1__get__(PyObject *__pyx_v_self) {
  CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
  PyObject *__pyx_r = 0;
  __Pyx_RefNannyDeclarations
  __Pyx_RefNannySetupContext("__get__ (wrapper)", 0);
  __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs);
  __pyx_r = __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1x___get__(((struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D *)__pyx_v_self));

  /* function exit code */
  __Pyx_RefNannyFinishContext();
  return __pyx_r;
}

static PyObject *__pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1x___get__(struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D *__pyx_v_self) {
  PyObject *__pyx_r = NULL;
  __Pyx_XDECREF(__pyx_r);
  __pyx_t_1 = PyFloat_FromDouble(__pyx_v_self->x); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 3, __pyx_L1_error)
  __Pyx_GOTREF(__pyx_t_1);
  __pyx_r = __pyx_t_1;
  __pyx_t_1 = 0;
  goto __pyx_L0;

  /* function exit code */
  __pyx_L1_error:;
  __Pyx_XDECREF(__pyx_t_1);
  __Pyx_AddTraceback("_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e.Point2D.x.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename);
  __pyx_r = NULL;
  __pyx_L0:;
  __Pyx_XGIVEREF(__pyx_r);
  __Pyx_RefNannyFinishContext();
  return __pyx_r;
}

/* Python wrapper */
static int __pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1x_3__set__(PyObject *__pyx_v_self, PyObject *__pyx_v_value); /*proto*/
static int __pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1x_3__set__(PyObject *__pyx_v_self, PyObject *__pyx_v_value) {
  CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
  int __pyx_r;
  __Pyx_RefNannyDeclarations
  __Pyx_RefNannySetupContext("__set__ (wrapper)", 0);
  __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs);
  __pyx_r = __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1x_2__set__(((struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D *)__pyx_v_self), ((PyObject *)__pyx_v_value));

  /* function exit code */
  __Pyx_RefNannyFinishContext();
  return __pyx_r;
}

static int __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1x_2__set__(struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D *__pyx_v_self, PyObject *__pyx_v_value) {
  int __pyx_r;
  __pyx_t_1 = __pyx_PyFloat_AsDouble(__pyx_v_value); if (unlikely((__pyx_t_1 == (double)-1) && PyErr_Occurred())) __PYX_ERR(0, 3, __pyx_L1_error)
  __pyx_v_self->x = __pyx_t_1;

  /* function exit code */
  __pyx_r = 0;
  goto __pyx_L0;
  __pyx_L1_error:;
  __Pyx_AddTraceback("_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e.Point2D.x.__set__", __pyx_clineno, __pyx_lineno, __pyx_filename);
  __pyx_r = -1;
  __pyx_L0:;
  return __pyx_r;
}

/* Python wrapper */
static PyObject *__pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1y_1__get__(PyObject *__pyx_v_self); /*proto*/
static PyObject *__pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1y_1__get__(PyObject *__pyx_v_self) {
  CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
  PyObject *__pyx_r = 0;
  __Pyx_RefNannyDeclarations
  __Pyx_RefNannySetupContext("__get__ (wrapper)", 0);
  __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs);
  __pyx_r = __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1y___get__(((struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D *)__pyx_v_self));

  /* function exit code */
  __Pyx_RefNannyFinishContext();
  return __pyx_r;
}

static PyObject *__pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1y___get__(struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D *__pyx_v_self) {
  PyObject *__pyx_r = NULL;
  __Pyx_XDECREF(__pyx_r);
  __pyx_t_1 = PyFloat_FromDouble(__pyx_v_self->y); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 3, __pyx_L1_error)
  __Pyx_GOTREF(__pyx_t_1);
  __pyx_r = __pyx_t_1;
  __pyx_t_1 = 0;
  goto __pyx_L0;

  /* function exit code */
  __pyx_L1_error:;
  __Pyx_XDECREF(__pyx_t_1);
  __Pyx_AddTraceback("_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e.Point2D.y.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename);
  __pyx_r = NULL;
  __pyx_L0:;
  __Pyx_XGIVEREF(__pyx_r);
  __Pyx_RefNannyFinishContext();
  return __pyx_r;
}

/* Python wrapper */
static int __pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1y_3__set__(PyObject *__pyx_v_self, PyObject *__pyx_v_value); /*proto*/
static int __pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1y_3__set__(PyObject *__pyx_v_self, PyObject *__pyx_v_value) {
  CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
  int __pyx_r;
  __Pyx_RefNannyDeclarations
  __Pyx_RefNannySetupContext("__set__ (wrapper)", 0);
  __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs);
  __pyx_r = __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1y_2__set__(((struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D *)__pyx_v_self), ((PyObject *)__pyx_v_value));

  /* function exit code */
  __Pyx_RefNannyFinishContext();
  return __pyx_r;
}

static int __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point2D_1y_2__set__(struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point2D *__pyx_v_self, PyObject *__pyx_v_value) {
  int __pyx_r;
  __pyx_t_1 = __pyx_PyFloat_AsDouble(__pyx_v_value); if (unlikely((__pyx_t_1 == (double)-1) && PyErr_Occurred())) __PYX_ERR(0, 3, __pyx_L1_error)
  __pyx_v_self->y = __pyx_t_1;

  /* function exit code */
  __pyx_r = 0;
  goto __pyx_L0;
  __pyx_L1_error:;
  __Pyx_AddTraceback("_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e.Point2D.y.__set__", __pyx_clineno, __pyx_lineno, __pyx_filename);
  __pyx_r = -1;
  __pyx_L0:;
  return __pyx_r;
}
 04: 
 05: cdef class Point3D(Point2D):
+06:     cdef public double z
/* Python wrapper */
static PyObject *__pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point3D_1z_1__get__(PyObject *__pyx_v_self); /*proto*/
static PyObject *__pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point3D_1z_1__get__(PyObject *__pyx_v_self) {
  CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
  PyObject *__pyx_r = 0;
  __Pyx_RefNannyDeclarations
  __Pyx_RefNannySetupContext("__get__ (wrapper)", 0);
  __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs);
  __pyx_r = __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point3D_1z___get__(((struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point3D *)__pyx_v_self));

  /* function exit code */
  __Pyx_RefNannyFinishContext();
  return __pyx_r;
}

static PyObject *__pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point3D_1z___get__(struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point3D *__pyx_v_self) {
  PyObject *__pyx_r = NULL;
  __Pyx_XDECREF(__pyx_r);
  __pyx_t_1 = PyFloat_FromDouble(__pyx_v_self->z); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 6, __pyx_L1_error)
  __Pyx_GOTREF(__pyx_t_1);
  __pyx_r = __pyx_t_1;
  __pyx_t_1 = 0;
  goto __pyx_L0;

  /* function exit code */
  __pyx_L1_error:;
  __Pyx_XDECREF(__pyx_t_1);
  __Pyx_AddTraceback("_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e.Point3D.z.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename);
  __pyx_r = NULL;
  __pyx_L0:;
  __Pyx_XGIVEREF(__pyx_r);
  __Pyx_RefNannyFinishContext();
  return __pyx_r;
}

/* Python wrapper */
static int __pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point3D_1z_3__set__(PyObject *__pyx_v_self, PyObject *__pyx_v_value); /*proto*/
static int __pyx_pw_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point3D_1z_3__set__(PyObject *__pyx_v_self, PyObject *__pyx_v_value) {
  CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
  int __pyx_r;
  __Pyx_RefNannyDeclarations
  __Pyx_RefNannySetupContext("__set__ (wrapper)", 0);
  __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs);
  __pyx_r = __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point3D_1z_2__set__(((struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point3D *)__pyx_v_self), ((PyObject *)__pyx_v_value));

  /* function exit code */
  __Pyx_RefNannyFinishContext();
  return __pyx_r;
}

static int __pyx_pf_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_7Point3D_1z_2__set__(struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point3D *__pyx_v_self, PyObject *__pyx_v_value) {
  int __pyx_r;
  __pyx_t_1 = __pyx_PyFloat_AsDouble(__pyx_v_value); if (unlikely((__pyx_t_1 == (double)-1) && PyErr_Occurred())) __PYX_ERR(0, 6, __pyx_L1_error)
  __pyx_v_self->z = __pyx_t_1;

  /* function exit code */
  __pyx_r = 0;
  goto __pyx_L0;
  __pyx_L1_error:;
  __Pyx_AddTraceback("_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e.Point3D.z.__set__", __pyx_clineno, __pyx_lineno, __pyx_filename);
  __pyx_r = -1;
  __pyx_L0:;
  return __pyx_r;
}
 07: 
+08: cdef Point3D p = Point3D()
  __pyx_t_2 = __Pyx_PyObject_CallNoArg(((PyObject *)__pyx_ptype_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point3D)); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 8, __pyx_L1_error)
  __Pyx_GOTREF(__pyx_t_2);
  __Pyx_XGOTREF((PyObject *)__pyx_v_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_p);
  __Pyx_DECREF_SET(__pyx_v_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_p, ((struct __pyx_obj_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_Point3D *)__pyx_t_2));
  __Pyx_GIVEREF(__pyx_t_2);
  __pyx_t_2 = 0;
+09: p.x = 1.0
  __pyx_v_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_p->__pyx_base.x = 1.0;
+10: p.y = 2.0
  __pyx_v_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_p->__pyx_base.y = 2.0;
+11: p.z = 3.0
  __pyx_v_54_cython_magic_390498a5e4057d93d29df0bbd7ea9444e0e5747e_p->z = 3.0;

Point3Dオブジェクトに対応するC言語構造体は以下のようになります:

struct __pyx_obj_Point3D {
  struct __pyx_obj_Point2D __pyx_base;
  double z;
};

最初のフィールド__pyx_baseは、その基底クラスPoint2Dに対応する構造体です。Cythonでは__pyx_base.xを使って基底クラスで定義された属性xにアクセスし、PythonではPoint3D.__base__内のxに対応するディスクリプタを使ってこの属性にアクセスします。

1次元浮動小数点ベクトル型#

次に、1次元浮動小数点ベクトルVector型を例に、拡張型の使用方法を紹介します。Vectorオブジェクトは2つのプライベート属性を持っています:countは配列の長さを表し、dataは配列データの先頭アドレスを保持します。

Note

長いCythonコードを部分ごとに取り出して説明するため、本書では提供する%includeマジックコマンドを利用します。

%include cython_examples/vector.pyx 1
cdef class Vector:
    cdef int count
    cdef double * data

新しいVectorオブジェクトを作成する際、ヒープメモリから構造体が割り当てられます。構造体のメモリ割り当てが完了した後、すぐにその中の属性を初期化する必要があります。この初期化作業は__cinit__()によって行われ、C言語レベルの__init__()と考えることができます。Vectorオブジェクトは2つの初期化方法をサポートしています:指定されたサイズのメモリを割り当てるか、シーケンスオブジェクトを使って配列の内容を初期化します。ここではPython/C APIのPyMem_Malloc()を使って、Pythonが管理するヒープから配列データを保存するメモリを割り当てます。このAPI関数を使用するためには、from cpython cimport memを使ってメモリ管理API関数のヘッダファイルを読み込む必要があります。

%include cython_examples/vector.pyx 2
    def __cinit__(self, data):
        cdef int i
        if isinstance(data, int):
            self.count = data
        else:
            self.count = len(data)
        self.data = <double *>mem.PyMem_Malloc(sizeof(double)*self.count)
        if self.data is NULL:
            raise MemoryError
        
        if not isinstance(data, int):
            for i in range(self.count):
                self.data[i] = data[i]

オブジェクトの参照カウントが0になると、ガベージコレクションによって回収されます。オブジェクト構造体が回収される前に、Cythonは__dealloc__()を呼び出し、ここでdata属性が指すメモリを解放する必要があります:

%include cython_examples/vector.pyx 3
    def __dealloc__(self):
        if self.data is not NULL:
            mem.PyMem_Free(self.data)

Vectorオブジェクトが整数の添字アクセスとループによる反復をサポートするためには、__len__()__getitem__()__setitem__()などのメソッドを定義する必要があります。これらの2つのアンダースコアで始まり終わるメソッドはマジックメソッドと呼ばれ、これらのメソッドを定義することで、特定の構文におけるオブジェクトの動作を変更できます。例えば、オブジェクトの長さを調べる組み込み関数len(obj)は実際にはobj.__len__()の値を返し、obj[index]obj.__getitem__(index)を呼び出し、obj[index] = valueobj.__setitem__(index, value)を呼び出します。ある型が__len__()__getitem__()を定義している場合、そのオブジェクトは自動的にforループによる要素の反復をサポートします。

以下のプログラムでは、__getitem__()__setitem__()はその添字引数が整数型であることを要求するため、Vectorオブジェクトはスライス添字をサポートしません。負の添字と添字の範囲外チェックのサポートを_check_index()メソッドに分離し、このメソッドの引数は整数型のポインタで、渡された添字変数を直接変更できます。添字が範囲外の場合、IndexError例外を発生させます。Cythonではp[0]を使ってポインタ変数pが指すアドレスにアクセスします。

%include cython_examples/vector.pyx 4
    def __len__(self):
        return self.count
    
    cdef _check_index(self, int *index):
        if index[0] < 0:
            index[0] = self.count + index[0]
        if index[0] < 0  or index[0] > self.count - 1:
            raise IndexError("Vector index out of range")
    
    def __getitem__(self, int index):
        self._check_index(&index)
        return self.data[index]
    
    def __setitem__(self, int index, double value):
        self._check_index(&index)
        self.data[index] = value

Vectorオブジェクトが加算演算子をサポートするためには、__add__()メソッドを定義する必要があります:

%include cython_examples/vector.pyx 5
    def _add(self, other):        
        cdef Vector new, _other

        if isinstance(other, Vector): #❷       
            _other = <Vector>other #❸
            if self.count != _other.count:
                raise ValueError("Vector size not equal")
            new = Vector(self.count) #❹
            add_array(self.data, _other.data, new.data, self.count)
            return new
        new = Vector(self.count) #❹
        add_number(self.data, <double>other, new.data, self.count)
        return new

    def __add__(self, other): #❶
        return self._add(other)

    def __radd__(self, other): #❶
        return self._add(other)

足し算や掛け算のような二項演算では、演算子が左右のオペランドに対してどのように動作するかを制御するために、特定のメソッドが使用されます。例えば、足し算に関するマジックメソッド __add____radd__ について、a + b のような式で、まず a.__add__(b) が呼び出され、もしそれが失敗した場合は b.__radd__(a) が実行されます。

vVectorオブジェクトの場合、1 + v を実行すると、整数 1__add__() 関数呼び出しが失敗し、Vector.__radd__() が呼び出され、引数 other1 となります。

プログラムでは、❶ __add__()__radd__()はヘルプメソッド_add()を呼び出します。❷__add__()は数値とVectorオブジェクトの両方を処理できるため、otherオブジェクトの型に応じて異なる処理を行う必要があります。❸selfother変数には型宣言がないため、C言語構造体に保存されている属性を取得できません。<Vector>を使ってPythonオブジェクトを型付きの変数_self_otherに変換し、これらの変数を通じてcountdata属性にアクセスします。❹演算結果を保存するVectorオブジェクトを作成し、add_array()またはadd_number()を呼び出して計算を行います。これらの関数のコードは後で紹介します。

以下は+=演算子に対応するマジックメソッド__iadd__()です:

%include cython_examples/vector.pyx 6
    def __iadd__(self, other):
        cdef Vector _other
        if isinstance(other, Vector):
            _other = <Vector>other
            if self.count != _other.count:
                raise ValueError("Vector size not equal")
            add_array(self.data, _other.data, self.data, self.count)
        else:
            add_number(self.data, <double>other, self.data, self.count)
        return self

二項演算子とは異なり、__iadd__()の最初の引数は現在のオブジェクトであるため、Cythonはその型を知っており、型変換を行う必要はありません。

以下では、cpdefを使ってnorm()メソッドを定義し、ベクトルの長さを計算し、__str__()でそれを呼び出します。オブジェクトを文字列に変換する際、__str__()メソッドが呼び出されます。cpdefで定義されたメソッドは、CythonとPythonの両方の呼び出しインターフェースを生成します。__str__()ではCythonの呼び出しインターフェースを使ってnorm()を実行し、Pythonではより遅いインターフェースを使ってnorm()を呼び出します。

%include cython_examples/vector.pyx 7
    def __str__(self):
        values = ", ".join(str(self.data[i]) for i in range(self.count))
        norm = self.norm()
        return "Vector[{}]({})".format(norm, values)
    
    cpdef norm(self):
        cdef double *p
        cdef double s
        cdef int i
        s = 0
        p = self.data
        for i in range(self.count):
            s += p[i] * p[i]
        return s**0.5

さらに、拡張型はPythonで継承でき、基底クラスで定義されたdefcpdefメソッドをオーバーライドできます。Pythonで定義されたオーバーライドメソッドは、Cython内で正しく呼び出すことができます。Cython内でcpdefメソッドを呼び出す場合、このメソッドがオーバーライドされているかどうかをチェックする必要があるため、cdefメソッドよりも若干呼び出し速度が遅くなります。

最後に、コードの先頭に計算を行うadd_array()add_number()の2つの関数を追加します。これらはcdefで定義され、Cythonコード内でのみ呼び出すことができます。

%include cython_examples/vector.pyx 8
cdef add_array(double *op1, double *op2, double *res, int count):
    cdef int i
    for i in range(count):
        res[i] = op1[i] + op2[i]

cdef add_number(double *op1, double op2, double *res, int count):
    cdef int i
    for i in range(count):
        res[i] = op1[i] + op2
!cythonize -i cython_examples/vector.pyx

他の二項計算関数や要素アクセス関数を完成させてみてください。以下はVectorオブジェクトの使用例です:

from cython_examples.vector import Vector

v1 = Vector(range(5))
v2 = Vector(range(100, 105))
print(len(v1))
print(v1 + v2)
print(v1 + 2)
print(20 + v2)
print(v1.norm(), v2.norm())
print([x**2 for x in v1])
5
Vector[232.63705637752557](100.0, 102.0, 104.0, 106.0, 108.0)
Vector[9.486832980505138](2.0, 3.0, 4.0, 5.0, 6.0)
Vector[272.81862106535175](120.0, 121.0, 122.0, 123.0, 124.0)
5.477225575051661 228.10085488660502
[0.0, 1.0, 4.0, 9.0, 16.0]

以下はVectorオブジェクトとNumPy配列のベクトル加算の速度比較です:

v1 = Vector(range(10000))
v2 = Vector(range(10000))
%timeit v1 + v2

a1 = np.arange(10000, dtype=float)
a2 = np.arange(10000, dtype=float)
%timeit a1 + a2
7.05 μs ± 61.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
7.23 μs ± 133 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

以下はVectorオブジェクトとNumPy配列の要素アクセス速度の比較です:

%timeit v1[100]
%timeit v1[100] = 2.0
%timeit a1[100]
%timeit a1[100] = 2.0
59.7 ns ± 0.726 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
62.7 ns ± 1.45 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
137 ns ± 2.18 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
122 ns ± 2.14 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

multifastライブラリのラッピング#

拡張型は、C言語の関数ライブラリをラップし、オブジェクト指向のPythonインターフェースを提供するためによく使用されます。このセクションでは、「多パターンマッチングアルゴリズム」のC言語ライブラリmultifastをラップする例を通じて、C言語関数ライブラリを拡張型でラップする方法を紹介します。

See also

http://multifast.sourceforge.net/v1_4_2/ 多パターンマッチングアルゴリズムを実装したmultifastプロジェクト

以下はこの拡張型の使用方法のデモンストレーションです。まず、一連のキーワードを使ってMultiSearchオブジェクトを作成し、そのisin()メソッドを呼び出して、ターゲットバイト列内でキーワードを検索します。キーワードのいずれかがバイト列内に存在する場合、Trueを返し、そうでない場合はFalseを返します。

from cython_examples.multisearch import MultiSearch

ms = MultiSearch([b"abc", b"xyz"])
print(ms.isin(b"123abcdef"))
print(ms.isin(b"123uvwxyz"))
print(ms.isin(b"123456789"))
True
True
False

search()メソッドは、ターゲットのバイト列内でキーワードの位置を検索するために使用できます。その第2引数はコールバック関数で、マッチする位置が見つかるたびにその位置とマッチしたキーワードをコールバック関数に渡します。コールバック関数が0を返すと検索を続行し、1を返すと検索を終了します。

def process(pos, pattern):
    print("found {0} at {1}".format(pattern, pos))
    return 0


ms.search(b"123abc456xyz789abc", process)
found b'abc' at 3
found b'xyz' at 9
found b'abc' at 15

また、iter_search()メソッドを使用してイテレータを返すこともできます:

for pos, pattern in ms.iter_search(b"123abc456xyz789abc"):
    print("found {0} at {1}".format(pattern, pos))
found b'abc' at 3
found b'xyz' at 9
found b'abc' at 15

拡張タイプの作成を開始する前に、C言語でこのライブラリをどのように使用するかを見てみましょう:

#include <stdio.h>
#include "ahocorasick.h"

/* 検索キーワードリスト */
AC_ALPHABET_t * allstr[] = {
    "recent", "from", "college"
};

#define PATTERN_NUMBER (sizeof(allstr)/sizeof(AC_ALPHABET_t *))

/* 検索テキスト */
AC_ALPHABET_t * input_text = {"She recently graduated from college"};

//*** マッチ時のコールバック関数
int match_handler(AC_MATCH_t * m, void * param)
{
    unsigned int j;

    printf ("@ %ld : %s
", m->position, m->patterns->astring);
    /* 0を返すと検索を続行し、1を返すと検索を停止 */
    return 0;
}

int main (int argc, char ** argv)
{
    unsigned int i;

    AC_AUTOMATA_t * acap;
    AC_PATTERN_t tmp_patt;
    AC_TEXT_t tmp_text;

    //*** AC_AUTOMATA_t構造体を作成し、コールバック関数を渡す
    acap = ac_automata_init();

    //*** キーワードを追加
    for (i=0; i<PATTERN_NUMBER; i++)
    {
        tmp_patt.astring = allstr[i];
        tmp_patt.rep.number = i+1; // optional
        tmp_patt.length = strlen(tmp_patt.astring);
        ac_automata_add (acap, &tmp_patt);
    }

    //*** キーワードの追加を終了
    ac_automata_finalize (acap);

    //*** 検索対象のバイト列を設定
    tmp_text.astring = input_text;
    tmp_text.length = strlen(tmp_text.astring);

    //*** 検索を実行
    ac_automata_search (acap, &tmp_text, 0, match_handler, NULL);

    //*** メモリを解放
    ac_automata_release (acap);
    return 0;
}

上記のプログラムからわかるように、このライブラリ全体はAC_AUTOMATA_t構造体を中心に処理されています。これはC言語でデータをカプセル化する一般的な方法です。Cythonで拡張タイプを使用してこのようなライブラリをラップする場合、通常はこの構造体へのポインタ属性を作成し、__cinit__()__dealloc__()でこの構造体を割り当てたり解放したりします。その後、C言語ライブラリが提供する各API関数を呼び出すいくつかのdefメソッドを定義して、ラッピングを実現します。

以下では、C言語のライブラリを拡張タイプでラップする方法を段階的に説明します。

%include cython_examples/multisearch.pyx 1
cdef extern from "ahocorasick.h": #❶
    ctypedef int (*AC_MATCH_CALBACK_f)(AC_MATCH_t *, void *) #❷
    ctypedef enum AC_STATUS_t: #❸
        ACERR_SUCCESS = 0
        ACERR_DUPLICATE_PATTERN
        ACERR_LONG_PATTERN
        ACERR_ZERO_PATTERN
        ACERR_AUTOMATA_CLOSED

    ctypedef struct AC_MATCH_t: #❹
        AC_PATTERN_t * patterns
        long position
        unsigned int match_num

    ctypedef struct AC_AUTOMATA_t:
        AC_MATCH_t match

    ctypedef struct AC_PATTERN_t:
        char * astring
        unsigned int length

    ctypedef struct AC_TEXT_t:
        char * astring
        unsigned int length

    #❺
    AC_AUTOMATA_t * ac_automata_init() 
    AC_STATUS_t ac_automata_add(AC_AUTOMATA_t * thiz, AC_PATTERN_t * pattern)
    void ac_automata_finalize(AC_AUTOMATA_t * thiz)
    int ac_automata_search(AC_AUTOMATA_t * thiz, AC_TEXT_t * text, int keep, 
        AC_MATCH_CALBACK_f callback, void * param)
    void ac_automata_settext (AC_AUTOMATA_t * thiz, AC_TEXT_t * text, int keep)
    AC_MATCH_t * ac_automata_findnext (AC_AUTOMATA_t * thiz)        
    void ac_automata_release(AC_AUTOMATA_t * thiz)

❶まず、cdef extern from ...を使用してCythonに、コンパイル後のC言語プログラムにahocorasick.hヘッダーファイルを含める必要があることを伝えます。CythonはC言語のヘッダーファイルを自動的に解析しないため、使用する型、定数、関数プロトタイプをCythonの構文で宣言する必要があります。

❷関数ポインタ型MATCH_CALBACK_fを定義します。これはコールバック関数へのポインタ型です。その第1引数は、マッチデータを保存する構造体へのポインタで、第2引数は任意の追加データを指すポインタです。C言語では通常、このようなvoid *型のポインタを使用してユーザー定義データを渡します。

❸❹列挙型と構造体型を定義します。Cythonプログラムで使用する列挙メンバーと構造体のフィールドのみを定義すれば十分です。Cythonで構造体のフィールドにアクセスしない場合は、フィールドの定義の代わりにpassキーワードを使用できます。

❺Cythonプログラムで呼び出す関数のプロトタイプを定義します。

上記のプログラムをC言語プログラムにコンパイルすると、#include "ahocorasick.h"という1行だけが残り、残りの型宣言はCythonにこれらの型を操作するステートメントをどのようにコンパイルするかを指示します。例えば、構造体AC_PATTERN_tlengthフィールドはunsigned int型として宣言されているため、必要に応じてCythonはPython/C APIを呼び出して、Pythonの整数オブジェクトとunsigned int型の間で変換を行います。

次に、MultiSearch拡張タイプの定義を示します:

%include cython_examples/multisearch.pyx 2
cdef class MultiSearch:
    
    cdef AC_AUTOMATA_t * _auto #❶
    cdef bint found
    cdef object callback
    cdef object exc_info
    cdef list _keywords

    def __cinit__(self, keywords):
        self._auto = ac_automata_init()
        if self._auto is NULL:
            raise MemoryError
        self._keywords = []
        self.add(keywords) #❷

    def __dealloc__(self):
        if self._auto is not NULL:
            ac_automata_release(self._auto)
            
    cdef add(self, keywords):
        cdef AC_PATTERN_t pattern
        cdef bytes keyword
        cdef AC_STATUS_t err
        
        for keyword in keywords: #❸
            self._keywords.append(keyword)
            pattern.astring = <char *>keyword
            pattern.length = len(keyword)
            err = ac_automata_add(self._auto, &pattern)
            if err != ACERR_SUCCESS:
                raise ValueError("Error Code:%d" % err)

        ac_automata_finalize(self._auto)

_auto属性はAC_AUTOMATA_t構造体へのポインタです。__cinit__()ac_automata_init()を呼び出してメモリを割り当て、__dealloc__()ac_automata_release()を呼び出してメモリを解放します。❷AC_AUTOMATA_t構造体の割り当てが成功した後、cdefメソッドadd()を呼び出してすべてのキーワードをこの構造体に追加します。

add()内部でkeywordsパラメータをイテレートし、各要素をbytes型として処理し、<char *>を使用してC言語の文字ポインタ型に変換し、その長さとともにAC_PATTERN_t構造体にパッケージ化してac_automata_add()に渡します。この関数が返された後、ahocorasick内部の関数は文字ポインタが指す内容を使用しないため、この方法は安全です。後続の関数呼び出しで文字ポインタが指す内容を使用する必要がある場合は、keywords内の各バイト列オブジェクトを参照して、それらが早期にガベージコレクションされないようにする必要があります。

次に、isin()メソッドの定義を示します:

%include cython_examples/multisearch.pyx 3
    def isin(self, bytes text, bint keep=False):
        cdef AC_TEXT_t temp_text   #❶
        temp_text.astring = <char *>text
        temp_text.length = len(text)
        self.found = False         #❷
        ac_automata_search(self._auto, &temp_text, keep, isin_callback, <void *>self) #❸
        return self.found

❶ターゲットのバイト列をAC_TEXT_t構造体にパッケージ化し、❸ac_automata_search()に渡して検索を実行します。検索の前に、❷found属性をFalseに設定します。isin_callback()関数のアドレスをac_automata_search()に渡して検索のコールバック関数として使用します。その最後のパラメータはコールバック関数に渡されるユーザーデータで、ここではMultiSearchオブジェクトのアドレスを渡します。これにより、コールバック関数内でMultiSearchオブジェクトの属性にアクセスできます。

以下はisin_callback()コールバック関数の定義です。この関数はC言語ライブラリ内部で呼び出されるため、cdefでのみ定義できることに注意してください:

%include cython_examples/multisearch.pyx 4
cdef int isin_callback(AC_MATCH_t * match, void * param) noexcept:
    cdef MultiSearch ms = <MultiSearch> param #❶
    ms.found = True  #❷
    return 1  #❸

isin_callback()の第1引数はマッチ情報を記述する構造体へのポインタで、第2引数はMultiSearchオブジェクトへのポインタです。❶まず、void *型のポインタをMultiSearchオブジェクトに変換し、msを通じてMultiSearch拡張クラスで定義されたさまざまな属性やメソッドにアクセスできます。❷MultiSearchオブジェクトのfound属性をTrueに設定し、マッチ位置が見つかったことを示します。❸isin()は1つのマッチ位置を見つけるだけで十分なため、関数は1を返して検索を終了します。

次に、search()を定義します。その第1引数は検索対象のバイト列で、第2引数はPythonの呼び出し可能オブジェクトです。マッチ位置が見つかるたびにこのオブジェクトを呼び出して処理します:

Tip

C言語のコールバック関数とPythonのコールバック関数の使用方法を紹介するために、ここではac_automata_search()を使用していますが、実際には後述のac_automata_findnext()を使用すると、isin()search()関数をより簡単に記述できます。

%include cython_examples/multisearch.pyx 5
    def search(self, bytes text, callback, bint keep=False):
        cdef AC_TEXT_t temp_text
        temp_text.astring = <char *>text
        temp_text.length = len(text)
        self.found = False
        self.callback = callback  #❶
        self.exc_info = None
        ac_automata_search(self._auto, &temp_text, keep, search_callback, <void *>self) #❷
        if self.exc_info is not None:
            raise self.exc_info[1], None, self.exc_info[2]  #❸

ac_automata_search()のコールバック関数はsearch_callback()です。Pythonの呼び出し可能オブジェクトを呼び出すために、❶callbackパラメータをself.callbackに渡します。❸C言語の関数はPython関数がスローする例外情報を上位に伝えることができないため、exc_info属性を使用してPythonコールバック関数でスローされる可能性のある例外を伝える必要があります。

以下は search_callback() コールバック関数の定義です。noexcept を指定することで、この関数内では例外が発生しないことを宣言します。

%include cython_examples/multisearch.pyx 6
cdef int search_callback(AC_MATCH_t * match, void * param) noexcept:
    cdef MultiSearch ms = <MultiSearch> param
    cdef bytes pattern = match.patterns.astring
    cdef int res = 1
    try:
        res = ms.callback(match.position - len(pattern), pattern)  #❶
    except Exception as ex:
        import sys
        ms.exc_info = sys.exc_info()  #❷
    return res

第2引数をMultiSearchオブジェクトに変換した後、❶callbackが指すPythonコールバック関数を呼び出し、スローされる可能性のある例外を捕捉します。❷例外とそのトレースバック情報をexc_info属性に保存します。search()の最後でexc_info属性をチェックし、設定されている場合はその中の例外オブジェクトをスローします。ここでは、キャプチャされた例外を直接保存するのではなく、sys.exc_info()を使用して例外情報を取得していることに注意してください。これにより、エラーの位置を正しく示すトレースバック情報が得られます。

multifastのC言語プログラムでは、ac_automata_settext()ac_automata_findnext()も提供されています。これら2つの関数を使用して、以下のようなジェネレータ関数iter_search()を記述できます。

%include cython_examples/multisearch.pyx 7
    def iter_search(self, bytes text, bint keep=False):
        cdef AC_TEXT_t temp_text
        cdef AC_MATCH_t * match
        cdef bytes matched_pattern
        temp_text.astring = <char *>text
        temp_text.length = len(text)
        ac_automata_settext(self._auto, &temp_text, keep)
        while True:
            match = ac_automata_findnext(self._auto)
            if match == NULL:
                break
            matched_pattern = <bytes>match.patterns.astring
            yield match.position - len(matched_pattern), matched_pattern

Cythonのテクニック集#

前の章の紹介を通じて、読者はCythonを使用してPythonプログラムの計算速度を向上させる基本的な方法をすでに習得していると思います。本章の最後の節として、Cythonのいくつかの高度な使用テクニックを見てみましょう。

ufunc 関数の作成#

NumPy の ufunc(ユニバーサル関数)は、配列の各要素に対して効率的に演算を行うための関数です。NumPy の C-API を利用すると、C 言語を用いて独自の ufunc 関数を作成できます。興味のある読者は、以下の URL から関連するチュートリアルを参照してください。

また、@cython.ufunc を使用すると、通常の関数をufunc関数に変換できます。以下のプログラムは、ロジスティック関数 \(\frac{1}{1+e^{-x}}\) を計算する ufunc logistic() を作成する例です。

%%cython --compile-args=-w
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
cimport cython
from libc.math cimport exp

@cython.ufunc
cdef double logistic(double x):
    return 1.0 / (1.0 + exp(-x))

以下は logistic() 関数のテストです。double 型の計算関数のみを定義していますが、さまざまな型の配列やリストを処理できます。

logistic([-1, 0, 1])
array([0.26894142, 0.5       , 0.73105858])

float32 型の配列を入力しても、計算結果は float64 型になります。これは、logistic() 関数が double 型の入力と出力のみをサポートしているためです。

logistic(np.arange(-1, 2, dtype=np.float32))
array([0.26894142, 0.5       , 0.73105858])

float32float64 の両方をサポートしたい場合は、次のように融合型 float_num を宣言し、関数の入出力のデータ型を float_num で定義します。

%%cython --compile-args=-w
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
cimport cython
from libc.math cimport exp

ctypedef fused float_num:
    double
    float

@cython.ufunc
cdef float_num logistic2(float_num x):
    return 1.0 / (1.0 + exp(-x))

整数を入力した場合の結果は float64 になりますが、float32 を入力した場合の出力は float32 になります。

print(logistic2([-1, 0, 1]).dtype)
print(logistic2(np.arange(-1, 2, dtype=np.float32)).dtype)
float64
float32

複数の引数を持つ関数もサポートできます。例えば、次の peaks() 関数は 2 次元関数を計算します。

%%cython --compile-args=-w
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
cimport cython
from libc.math cimport exp


@cython.ufunc
cdef double peaks(double x, double y):
    return x * exp(-x*x - y*y)
Y, X = np.ogrid[-2:2:5j, -2:2:5j]
peaks(X, Y)
array([[-0.00067093, -0.00673795,  0.        ,  0.00673795,  0.00067093],
       [-0.01347589, -0.13533528,  0.        ,  0.13533528,  0.01347589],
       [-0.03663128, -0.36787944,  0.        ,  0.36787944,  0.03663128],
       [-0.01347589, -0.13533528,  0.        ,  0.13533528,  0.01347589],
       [-0.00067093, -0.00673795,  0.        ,  0.00673795,  0.00067093]])

複数の出力がある場合、タプルでリターン型を宣言します。次の例では、二次方程式の二つの解を求める ufunc 関数を定義します。引数のデータ型は double ですが、リターン値は二つの complex 型の値です。

%%cython --compile-args=-w
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
cimport cython
from libc.complex cimport csqrt

@cython.ufunc
cdef (complex, complex) quadratic_roots(double a, double b, double c):
    cdef complex t0, t1
    t0 = csqrt(-4*a*c + b**2)
    t1 = 0.5 / a
    return t1*(-b - t0), t1*(-b + t0)
a, b, c = np.ogrid[-1:1:4j, -1:1:3j, -1:1:2j]
x0, x1 = quadratic_roots(a, b, c)
print(x0.shape)
print(x1.shape)
print(np.allclose(a * x0 ** 2 + b * x0 + c, 0))
print(np.allclose(a * x1 ** 2 + b * x1 + c, 0))
(4, 3, 2)
(4, 3, 2)
True
True

DLL内の関数の高速呼び出し#

Pythonの標準ライブラリctypesまたは拡張ライブラリcffiを使用すると、ダイナミックリンクライブラリ内の関数を簡単に呼び出すことができます。ただし、実際の関数を呼び出す前に多くの前処理が必要なため、関数呼び出しの効率は高くありません。ループ内で大量に呼び出す必要がある場合、この前処理によるオーバーヘッドがプログラムの実行速度に大きく影響します。このセクションでは、cffiを使用してダイナミックリンクライブラリ内の関数のアドレスを取得し、そのアドレスをCythonの関数に渡してループ呼び出しを行うことで、関数呼び出しの効率を向上させる方法を紹介します。

まず、C言語で関数peaks()を作成します。この関数のアドレスを正しく取得できることを確認するために、get_addr()を使用してpeaks()のアドレスを返します:

%%writefile peaks.c
#include <math.h>
double peaks(double x, double y)
{
    return x * exp(-x*x - y*y);
}

unsigned long long get_addr()
{
    return (unsigned long long)(void *)peaks;
}
Overwriting peaks.c

次に、gccを呼び出してpeaks.cpeaks.dllにコンパイルします:

!gcc -Ofast -shared -o peaks.dll peaks.c

次に、cffiを使用してpeaks.dllを読み込み、関数シグネチャを設定し、最後にその関数を呼び出します:

import cffi

ffi = cffi.FFI()
ffi.cdef(
    """
unsigned long long get_addr();
double peaks(double x, double y);
"""
)
lib = ffi.dlopen("peaks.dll")
lib.peaks(1.0, 2.0)
0.006737946999085467

lib.peaksは、C言語の関数アドレスをラップしたcffiのオブジェクトです。ffi.addressof()を使用してこの関数へのポインタを取得し、ffi.cast()を使用してポインタを整数に変換できます。以下のプログラムは、C言語関数のアドレスを取得し、get_addr()の戻り値と比較します:

peak_addr = ffi.cast("size_t", ffi.addressof(lib, "peaks"))
assert peak_addr == lib.get_addr()

以下は、Cythonでvectorize_2d()関数を作成する例です。func_addrパラメータは関数アドレスを表す整数で、xyは2次元の連続した配列です。

❶まず、ctypedefキーワードを使用して、二項倍精度浮動小数点数関数ポインタ型Functionを宣言します。❷次に、パラメータfunc_addrFunction型の関数ポインタに変換し、その後、二重ループ内でこの関数ポインタを使用して、それが指すC言語関数を高速に呼び出すことができます。

%%cython
import cython
import numpy as np

ctypedef double(*Function)(double x, double y)      #❶

@cython.wraparound(False)
@cython.boundscheck(False)
def vectorize_2d(size_t func_addr, double[:, ::1] x, double[:, ::1] y):
    cdef double[:, ::1] res = np.zeros_like(x.base)
    cdef Function func_ptr = <Function><void *>func_addr #❷
    cdef int i, j
    
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            res[i, j] = func_ptr(x[i, j], y[i, j])
            
    return res.base

以下は、vectorize_2d()を使用してpeaks()を呼び出し、vectorize()を使用して作成したufunc関数の結果と比較する例です:

Y, X = np.mgrid[-2:2:200j, -2:2:200j]
vectorize_peaks = np.vectorize(lib.peaks, otypes=["f8"])
np.allclose(vectorize_peaks(X, Y), vectorize_2d(peak_addr, X, Y))
True

本書で提供されているFuncAddrクラスを使用すると、ダイナミックリンクライブラリ内の関数のアドレスをより簡単に取得できます:

from helper.cffi import FuncAddr

msvcrt = FuncAddr("msvcrt.dll")
np.allclose(np.arctan2(X, Y), vectorize_2d(msvcrt.atan2, X, Y))
True

BLAS関数の呼び出し#

BLASは、基本的な線形代数ルーチンのAPI標準であり、SciPyの多くの高速線形代数演算関数は、Fortranで書かれたBLAS関数を内部で呼び出しています。これらのFortran関数の計算効率は高いですが、Pythonの呼び出しインターフェースによるオーバーヘッドは無視できず、大量のループで呼び出す場合にはこのオーバーヘッドが問題となります。Cythonを使用してこれらの関数をループ内で呼び出すことで、Pythonの呼び出しインターフェースの制約から完全に解放されます。

saxpy()関数のラッピング#

BLASのAPI関数はscipy.linalg.blasモジュールを通じてアクセスできます。以下は、その中のsaxpy()を呼び出すデモンストレーションです:

from scipy.linalg import blas
import numpy as np

x = np.array([1, 2, 3], np.float32)
y = np.array([1, 3, 5], np.float32)
print(blas.saxpy)
blas.saxpy(x, y, a=0.5)
<fortran function saxpy>
array([1.5, 4. , 6.5], dtype=float32)

saxpy<fortran object>であり、呼び出すたびにPythonのオブジェクトをFortran関数のパラメータに変換するため、ループ内で大量に呼び出すには適していません。

scipy.linalgには、Cythonから呼び出すためのcython_blasモジュールも提供されています。その中のsaxpy()の関数シグネチャは以下の通りです:

ctypedef float s
cdef void saxpy(int *n, s *sa, s *sx, int *incx, s *sy, int *incy) nogil

Fortran言語は参照渡しを使用するため、関数呼び出し時に渡されるパラメータと関数内で受け取るパラメータは同じメモリアドレスです。したがって、関数シグネチャではすべてのパラメータがポインタとして定義されています。

以下のプログラムでは、blas_saxpy()cython_blasモジュール内のsaxpy()を簡単にラップしています。計算速度を比較するために、cython_saxpy()はループを使用して計算を行います:

%%cython
import cython
cimport scipy.linalg.cython_blas as blas

def blas_saxpy(float[:] y, float a, float[:] x):
    cdef int n = y.shape[0]
    cdef int inc_x = x.strides[0] // sizeof(float)
    cdef int inc_y = y.strides[0] // sizeof(float)
    blas.saxpy(&n, &a, &x[0], &inc_x, &y[0], &inc_y)
    
@cython.wraparound(False)
@cython.boundscheck(False)
def cython_saxpy(float[:] y, float a, float[:] x):
    cdef int i
    for i in range(y.shape[0]):
        y[i] += a * x[i]

以下で両者の実行速度を比較します。Cythonのループで実装した関数の方が速いです。

a = np.arange(100000, dtype=np.float32)
b = np.zeros_like(a)
%timeit blas_saxpy(b, 0.2, a)
%timeit cython_saxpy(b, 0.2, a)
164 μs ± 7.5 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
96 μs ± 1.39 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

dgemm()高速行列積#

BLASのDGEMM()は以下の行列積演算を実装しています。パラメータalphaが1、betaが0の場合、結果Cは行列ABの積となります:

C = alpha*op(A)*op(B) + beta*C

ここで、op()は行列を転置することができます。Fortran形式の配列とC言語形式の配列の軸の順序が逆であるため、2つのC言語の配列で表される行列の積を計算するには、op()を転置に設定する必要があります。転置するかどうかに関わらず、演算結果CはFortran形式の配列となります。

Fortran関数dgemm()のパラメータは以下の通りです:

subroutine dgemm	(
    character 	TRANSA,
    character 	TRANSB,
    integer 	M,
    integer 	N,
    integer 	K,
    double precision 	ALPHA,
    double precision, dimension(lda,*) 	A,
    integer 	LDA,
    double precision, dimension(ldb,*) 	B,
    integer 	LDB,
    double precision 	BETA,
    double precision, dimension(ldc,*) 	C,
    integer 	LDC 
)

以下は、cython_blasモジュール内の関数シグネチャです:

ctypedef double d
cdef void dgemm(
    char *transa, 
    char *transb, 
    int *m, 
    int *n, 
    int *k, 
    d *alpha, 
    d *a, 
    int *lda, 
    d *b, 
    int *ldb, 
    d *beta, 
    d *c, 
    int *ldc) nogil

以下のCython関数dgemm(A, B, index)では、ABはC言語形式の3次元配列で、形状はそれぞれ(La, M, K)(Lb, K, N)です。indexは形状が(Lc, 2)の整数配列です。この関数は、index内の各整数ペアj, kに対してC[i] = A[j] * B[k]を計算します。ここで、iはその整数ペアのインデックスです。したがって、関数の戻り値Cは形状が(Lc, N, M)の3次元配列です。CLc個のFortran形式の2次元配列と見なすことができます。

メモリビューの要素にアクセスする際にはPythonに関連する操作は行われず、各行列の積演算は互いに独立しているため、この部分を並列化することができます。CythonはOpenMPを使用して並列化を実装しているため、コンパイル時に-fopenmpオプションを設定する必要があります。

❶並列化されたprange()関数を読み込みます。この関数は並列化されたループにコンパイルされます。❷nogilパラメータをTrueに設定し、並列化中にPythonのグローバルロックを解放することを示します。

%%cython -c-Ofast -c-fopenmp --link-args=-fopenmp

from cython.parallel import prange #❶
import cython
import numpy as np
cimport scipy.linalg.cython_blas as blas

@cython.wraparound(False)
@cython.boundscheck(False)
def dgemm(double[:, :, :] A, double[:, :, :] B, int[:, ::1] index):
    cdef int m, n, k, i, length, idx_a, idx_b
    cdef double[:, :, :] C
    cdef char ta, tb
    cdef double alpha = 1.0
    cdef double beta = 0.0
        
    length = index.shape[0]
    m, k, n  = A.shape[1], A.shape[2], B.shape[2]        
    
    C = np.zeros((length, n, m))
    
    ta = b"T"
    tb = ta
    
    for i in prange(length, nogil=True): #❷
        idx_a = index[i, 0]
        idx_b = index[i, 1]
        blas.dgemm(&ta, &tb, &m, &n, &k, &alpha, 
               &A[idx_a, 0, 0], &k, 
               &B[idx_b, 0, 0], &n, 
               &beta,
               &C[i, 0, 0], &m)
    
    return C.base

NumPyに新しく追加されたgufunc関数は、単一の行列の演算をブロードキャストして配列全体に適用することができます。NumPyに新しく追加された行列積演算子@は、行列積のブロードキャスト演算を実現します。同様の機能は、上記のdgemm()を使用して実現することもできます。以下のmatrix_multiply(a, b)は、2つの任意の次元数の配列の最後の2つの軸に対して行列積演算を行い、他の軸に対してブロードキャスト演算を行います。例えば、aの形状が(12,  1, 10, 100, 30)で、bの形状が( 1, 15,  1,  30, 50)の場合、最後の2つの軸に対応する行列積の結果の形状は(100, 50)であり、他の軸のブロードキャスト後の形状は(12, 15, 10)であるため、結果の配列の形状は(12, 15, 10, 100, 50)となります。合計で\(12 \times 15 \times 10\)回の行列積演算が行われます。

このプログラムの実装の考え方は以下の通りです。詳細は読者自身で研究してください:

  • a内の各行列に番号を付け、その番号の形状をaのブロードキャスト部分の形状に変更してidx_aを得ます。

  • bに対しても同様の操作を行い、idx_bを得ます。

  • broadcast_arrays()を使用して、idx_aidx_bのブロードキャスト後の配列を計算します。

  • 上記の2つの配列を平坦化して2列に並べ、dgemm()関数のindexパラメータを得ます。

  • abの形状を3次元配列に変更してdgemm()関数に渡し、行列積を計算します。

def matrix_multiply(a, b):
    if a.ndim <= 2 and b.ndim <= 2:
        return np.dot(a, b)

    a = np.ascontiguousarray(a).astype(np.float64, copy=False)
    b = np.ascontiguousarray(b).astype(np.float64, copy=False)
    if a.ndim == 2:
        a = a[None, :, :]
    if b.ndim == 2:
        b = b[None, :, :]

    shape_a = a.shape[:-2]
    shape_b = b.shape[:-2]
    len_a = np.prod(shape_a)
    len_b = np.prod(shape_b)

    idx_a = np.arange(len_a, dtype=np.int32).reshape(shape_a)
    idx_b = np.arange(len_b, dtype=np.int32).reshape(shape_b)
    idx_a, idx_b = np.broadcast_arrays(idx_a, idx_b)

    index = np.column_stack((idx_a.ravel(), idx_b.ravel()))
    bshape = idx_a.shape

    if a.ndim > 3:
        a = a.reshape(-1, a.shape[-2], a.shape[-1])
    if b.ndim > 3:
        b = b.reshape(-1, b.shape[-2], b.shape[-1])

    if a.shape[-1] != b.shape[-2]:
        raise ValueError("can't do matrix multiply because k isn't the same")

    c = dgemm(a, b, index)
    c = np.swapaxes(c, -2, -1)
    c.shape = bshape + c.shape[-2:]
    return c

以下でmatrix_multiply()@演算子の計算結果を比較します:

a = np.random.rand(12, 1, 10, 100, 30)
b = np.random.rand(1, 15, 1, 30, 50)
np.allclose(matrix_multiply(a, b), a @ b)
True

以下は両者の実行速度の比較です。matrix_multiply()はNumPyの組み込み行列積演算子より高速です:

%timeit matrix_multiply(a, b)
%timeit a @ b
177 ms ± 4.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
414 ms ± 116 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)