Subclassing ndarray in NumPy

Programming Languages
Published

July 17, 2021

Notes on making better use of the features of NumPy's ndarray so that working with our own array types is more comfortable.

Code
import numpy as np

Defining a custom array container

Note that an array container does not have to subclass ndarray: implementing an __array__ method is enough. So we start by defining a plain class:

Code
class PEArray:
  def __init__(self, height, width, spad_size):
    self._h = height
    self._w = width
    self._spad = spad_size
    self._pe = np.random.rand(self._h, self._w)

  def __repr__(self):
    return f"{self.__class__.__name__}(h={self._h}, w={self._w})"

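  def __array__(self, dtype=None, copy=None):
    # honour a requested dtype; NumPy >= 2 may also pass `copy`
    arr = self._pe if dtype is None else self._pe.astype(dtype)
    return arr.copy() if copy else arr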

We can initialize it easily and retrieve its data with np.array / np.asarray:

Code
pe = PEArray(3, 4, 8)
pe
PEArray(h=3, w=4)
Code
np.asarray(pe)
array([[0.44450444, 0.85014876, 0.93706849, 0.9179388 ],
       [0.86145163, 0.11632653, 0.39719148, 0.23972649],
       [0.81139147, 0.46744501, 0.83977769, 0.76806018]])

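__array__ plays a role similar to an overloaded type-conversion operator in C++, so a PEArray object can be passed directly into NumPy's computational functions. Note, however, that the result is then a plain ndarray: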

Code
np.multiply(pe, 2)
array([[0.88900889, 1.70029753, 1.87413698, 1.83587761],
       [1.72290325, 0.23265307, 0.79438295, 0.47945298],
       [1.62278294, 0.93489002, 1.67955538, 1.53612035]])
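To make the type loss concrete, here is a minimal sketch with a hypothetical Wrapper class that defines nothing but __array__: NumPy converts it on the way in, and the result comes back as a plain ndarray.

```python
import numpy as np


class Wrapper:
    """Hypothetical container that only implements __array__."""

    def __init__(self, data):
        self._data = np.asarray(data)

    def __array__(self, dtype=None, copy=None):
        # NumPy calls this to convert the object; honour a requested dtype
        arr = self._data if dtype is None else self._data.astype(dtype)
        return arr.copy() if copy else arr


w = Wrapper([[1, 2], [3, 4]])
doubled = np.multiply(w, 2)
print(type(doubled).__name__)  # ndarray: the Wrapper type is lost
print(doubled.tolist())        # [[2, 4], [6, 8]]
```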

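If we want to use NumPy's functions but keep our own type, with the functions operating only on the data held inside the class, we need to adapt the class through __array_ufunc__ and __array_function__. Let's start with __array_ufunc__: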

__array_ufunc__

__array_ufunc__ is the hook for universal functions (ufuncs), i.e. functions that operate element-wise on arrays, such as add, subtract, multiply, log, sin and so on.

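__array_ufunc__ receives the following arguments:

  • ufunc: the ufunc object that was called, e.g. numpy.add

  • method: a string saying how the ufunc was invoked: "__call__" for a direct call, or the name of one of the ufunc's methods listed below

  • inputs: the input arguments, as a tuple

  • kwargs: the optional keyword arguments passed to the ufunc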
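A small sketch can show what actually arrives in these arguments. The hypothetical Tracer class below records the ufunc name, the method and the number of inputs, then unwraps itself and forwards the call with getattr(ufunc, method), a common forwarding pattern.

```python
import numpy as np


class Tracer:
    """Hypothetical class that logs what NumPy hands to __array_ufunc__."""

    def __init__(self, data):
        self.data = np.asarray(data)
        self.calls = []

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        self.calls.append((ufunc.__name__, method, len(inputs)))
        # unwrap our own objects and forward to the requested ufunc method
        args = [x.data if isinstance(x, Tracer) else x for x in inputs]
        return getattr(ufunc, method)(*args, **kwargs)


t = Tracer([1, 2, 3])
print(np.add(t, 1))      # [2 3 4]
print(np.add.reduce(t))  # 6
print(t.calls)           # [('add', '__call__', 2), ('add', 'reduce', 1)]
```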

All ufuncs share the same set of input arguments and attributes (see the documentation for details). What matters here is that every ufunc also provides the following five methods, each of which can show up as method:

| name | description |
|-|-|
| ufunc.reduce(array[, axis, dtype, out, …]) | Reduces array’s dimension by one, by applying ufunc along one axis. |
| ufunc.accumulate(array[, axis, dtype, out]) | Accumulate the result of applying the operator to all elements. |
| ufunc.reduceat(array, indices[, axis, …]) | Performs a (local) reduce with specified slices over a single axis. |
| ufunc.outer(A, B, /, **kwargs) | Apply the ufunc op to all pairs (a, b) with a in A and b in B. |
| ufunc.at(a, indices[, b]) | Performs unbuffered in place operation on operand ‘a’ for elements specified by ‘indices’. |
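For reference, this is what each of these methods does with np.add on a plain array (a quick illustration using only standard NumPy calls):

```python
import numpy as np

x = np.array([1, 2, 3, 4])

print(np.add(x, 1).tolist())                   # __call__:   [2, 3, 4, 5]
print(np.add.reduce(x))                        # reduce:     10
print(np.add.accumulate(x).tolist())           # accumulate: [1, 3, 6, 10]
print(np.add.reduceat(x, [0, 2]).tolist())     # reduceat:   [3, 7]
print(np.add.outer([1, 2], [10, 20]).tolist())  # outer:     [[11, 21], [12, 22]]

y = x.copy()
np.add.at(y, [0, 0], 1)  # unbuffered: index 0 is incremented twice
print(y.tolist())                              # at:         [3, 2, 3, 4]
```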

Next we adapt the __call__ method, i.e. the case where the ufunc is called directly:

Code
from numbers import Number


class PEArray:
  def __init__(self, height, width, spad_size, pe=None):
    self._h = height
    self._w = width
    self._spad = spad_size
    if pe is not None:
      self._pe = pe
    else:
      self._pe = np.random.rand(self._h, self._w)

  def __repr__(self):
    return f"{self.__class__.__name__}(h={self._h}, w={self._w})"

  def __array__(self, dtype=None, copy=None):
    # honour a requested dtype; NumPy >= 2 may also pass `copy`
    arr = self._pe if dtype is None else self._pe.astype(dtype)
    return arr.copy() if copy else arr

  def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
    if method != '__call__':
      # only direct calls are supported; NotImplemented makes NumPy raise TypeError
      return NotImplemented
    args = []
    for x in inputs:  # keep the original order (matters for subtract, divide, ...)
      if isinstance(x, Number):
        args.append(x)
      elif isinstance(x, self.__class__):
        if x._pe.shape != self._pe.shape:
          raise ValueError("inconsistent shape")
        args.append(x._pe)  # self arrives here as one of the inputs
      else:
        return NotImplemented  # unsupported operand type
    return self.__class__(self._h, self._w, self._spad, ufunc(*args, **kwargs))

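When writing this code, note that the object's own array is also passed in through inputs (self is one of the inputs), so don't pass self._pe in again by hand. Also give the class a suitable constructor that accepts an existing array, so that the result can be wrapped into a new object. Now the output is an object of the correct type: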

Code
a = PEArray(3,4,5)
b = 3.
c = PEArray(3,4,6)
print(np.add(a,b))
print(np.multiply(a,c))
PEArray(h=3, w=4)
PEArray(h=3, w=4)

There is still one problem, though: the class does not yet support Python's built-in operators:

Code
a + b
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-bd58363a63fc> in <module>
----> 1 a + b

TypeError: unsupported operand type(s) for +: 'PEArray' and 'float'

Implementing the operator methods one by one is tedious; instead we can inherit from the helper mixin that NumPy provides, numpy.lib.mixins.NDArrayOperatorsMixin:
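The mixin forwards each Python operator to the corresponding ufunc, which in turn reaches __array_ufunc__. A sketch with a hypothetical Spy class that only records the ufunc names makes this mapping visible:

```python
import numpy as np
from numpy.lib.mixins import NDArrayOperatorsMixin


class Spy(NDArrayOperatorsMixin):
    """Hypothetical class: records which ufunc each operator is routed to."""

    def __init__(self):
        self.seen = []

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        self.seen.append(ufunc.__name__)
        return None  # the result itself does not matter for this demo


s = Spy()
s + 1
2 * s  # reflected operator, still dispatched to np.multiply
-s
s < 3
s ** 2
print(s.seen)  # ['add', 'multiply', 'negative', 'less', 'power']
```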

Code
from numpy.lib.mixins import NDArrayOperatorsMixin


class PEArray(NDArrayOperatorsMixin):
  def __init__(self, height, width, spad_size, pe=None):
    self._h = height
    self._w = width
    self._spad = spad_size
    if pe is not None:
      self._pe = pe
    else:
      self._pe = np.random.rand(self._h, self._w)

  def __repr__(self):
    return f"{self.__class__.__name__}(h={self._h}, w={self._w})"

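  def __array__(self, dtype=None, copy=None):
    # honour a requested dtype; NumPy >= 2 may also pass `copy`
    arr = self._pe if dtype is None else self._pe.astype(dtype)
    return arr.copy() if copy else arr

  def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
    if method != '__call__':
      # only direct calls are supported; NotImplemented makes NumPy raise TypeError
      return NotImplemented
    args = []
    for x in inputs:  # keep the original order (matters for subtract, divide, ...)
      if isinstance(x, Number):
        args.append(x)
      elif isinstance(x, self.__class__):
        if x._pe.shape != self._pe.shape:
          raise ValueError("inconsistent shape")
        args.append(x._pe)  # self arrives here as one of the inputs
      else:
        return NotImplemented  # unsupported operand type
    return self.__class__(self._h, self._w, self._spad, ufunc(*args, **kwargs))
Code
a = PEArray(1,2,3)
b = 10
a + b
PEArray(h=1, w=2)

__array_function__

So far we have supported ufuncs. The same approach can also cover some non-ufunc functions: np.sum, for example, is carried out by np.add.reduce under the hood, so adding support for the reduce method to __array_ufunc__ would be enough. A more convenient way, though, is to override behaviour at the level of whole functions. For example, np.sum does not work yet:

Code
np.sum(a)  # raises TypeError: np.sum calls np.add.reduce, and __array_ufunc__ returns NotImplemented for 'reduce'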
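For comparison, here is a sketch of the ufunc-level route: a hypothetical, stripped-down MiniPE class whose __array_ufunc__ also accepts the reduce method, which is enough for np.sum to work without __array_function__. It forwards through getattr(ufunc, method); an out= argument is not handled in this sketch.

```python
import numpy as np
from numbers import Number
from numpy.lib.mixins import NDArrayOperatorsMixin


class MiniPE(NDArrayOperatorsMixin):
    """Hypothetical stripped-down container wrapping a 2-D array."""

    def __init__(self, data):
        self._pe = np.asarray(data)

    def __repr__(self):
        return f"MiniPE(shape={self._pe.shape})"

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # support direct calls and reductions only
        if method not in ("__call__", "reduce"):
            return NotImplemented
        args = []
        for x in inputs:  # unwrap while keeping the argument order
            if isinstance(x, MiniPE):
                args.append(x._pe)
            elif isinstance(x, (Number, np.ndarray)):
                args.append(x)
            else:
                return NotImplemented
        result = getattr(ufunc, method)(*args, **kwargs)
        # array results are wrapped again; a full reduction gives a scalar
        return MiniPE(result) if np.ndim(result) > 0 else result


m = MiniPE([[1.0, 2.0], [3.0, 4.0]])
print(np.sum(m))          # 10.0 (np.sum -> np.add.reduce -> __array_ufunc__)
print(np.sum(m, axis=0))  # MiniPE(shape=(2,))
print(1 - m)              # MiniPE(shape=(2, 2))
```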
Code
from typing import List
HANDLED_FUNCTIONS = {}


def register(np_function):
  def decorator(func):
    HANDLED_FUNCTIONS[np_function] = func
    return func
  return decorator


class PEArray(NDArrayOperatorsMixin):
  def __init__(self, height, width, spad_size, pe=None):
    self._h = height
    self._w = width
    self._spad = spad_size
    if pe is not None:
      self._pe = pe
    else:
      self._pe = np.random.rand(self._h, self._w)

  def __repr__(self):
    return f"{self.__class__.__name__}(h={self._h}, w={self._w})"

  def __array__(self, dtype=None, copy=None):
    # honour a requested dtype; NumPy >= 2 may also pass `copy`
    arr = self._pe if dtype is None else self._pe.astype(dtype)
    return arr.copy() if copy else arr

  def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
    if method != '__call__':
      # only direct calls are supported; NotImplemented makes NumPy raise TypeError
      return NotImplemented
    args = []
    for x in inputs:  # keep the original order (matters for subtract, divide, ...)
      if isinstance(x, Number):
        args.append(x)
      elif isinstance(x, self.__class__):
        if x._pe.shape != self._pe.shape:
          raise ValueError("inconsistent shape")
        args.append(x._pe)  # self arrives here as one of the inputs
      else:
        return NotImplemented  # unsupported operand type
    return self.__class__(self._h, self._w, self._spad, ufunc(*args, **kwargs))

  def __array_function__(self, func, types, args, kwargs):
    if func not in HANDLED_FUNCTIONS:
      return NotImplemented
    # Note: this allows subclasses that don't override
    # __array_function__ to handle PEArray objects.
    if not all(issubclass(t, self.__class__) for t in types):
      return NotImplemented
    return HANDLED_FUNCTIONS[func](*args, **kwargs)


@register(np.sum)
def pe_sum(arr: PEArray) -> float:
  # sum over the wrapped ndarray; returns a NumPy scalar
  return arr._pe.sum()


@register(np.concatenate)
def pe_concat(arrs: List[PEArray], axis: int = 0) -> PEArray:
  assert len(arrs) > 1
  assert -2 <= axis < 2  # the PE data is 2-D
  # all PE arrays must share the same scratchpad size
  assert all(arr._spad == arrs[0]._spad for arr in arrs[1:])
  new_pe = np.concatenate([arr._pe for arr in arrs], axis=axis)
  return PEArray(new_pe.shape[0], new_pe.shape[1], arrs[0]._spad, new_pe)
Code
a = PEArray(2, 4, 1)
b = PEArray(2, 3, 1)
c = PEArray(3, 4, 1)

np.sum(a)
5.520550933404442
Code
np.concatenate([a, b], axis=1)
PEArray(h=2, w=7)
Code
np.concatenate([a, c], axis=0)
PEArray(h=5, w=4)
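Functions that were never registered are rejected: __array_function__ returns NotImplemented, and NumPy then raises a TypeError instead of silently converting the object. A self-contained sketch of the same registry pattern, with a hypothetical Box class and implements helper:

```python
import numpy as np

HANDLED = {}


def implements(np_function):
    """Register func as the override of np_function (hypothetical helper)."""
    def decorator(func):
        HANDLED[np_function] = func
        return func
    return decorator


class Box:
    """Hypothetical minimal container using only __array_function__."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        if func not in HANDLED or not all(issubclass(t, Box) for t in types):
            return NotImplemented
        return HANDLED[func](*args, **kwargs)


@implements(np.sum)
def box_sum(b, axis=None):
    return b.data.sum(axis=axis)


b = Box([1, 2, 3])
print(np.sum(b))  # 6
try:
    np.mean(b)  # never registered, so dispatch fails
    mean_failed = False
except TypeError as err:
    mean_failed = True
    print("TypeError:", err)
```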

Summary

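Writing a custom array container is fairly convenient, and it lets us reuse NumPy's built-in interfaces as much as possible at two levels (ufuncs through __array_ufunc__, arbitrary NumPy functions through __array_function__), which keeps the abstraction consistent.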

Subclassing ndarray

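Subclassing ndarray is comparatively complicated, mainly because an ndarray can be created in several ways (think of C++'s move and copy constructors):

1. Explicit construction, e.g. PEArray(params)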

2. View casting, similar to dynamic_cast in C++.

Code
import numpy as np
# create a completely useless ndarray subclass
class C(np.ndarray): pass
# create a standard ndarray
arr = np.zeros((3,))
# take a view of it, as our useless subclass
c_arr = arr.view(C)
type(c_arr)
__main__.C

3. New from template: copies, slices and ufunc results are all created this way.

Code
v = c_arr[1:]
print(type(v)) # still the subclass, because a slice is just a view onto the original array
<class '__main__.C'>

How view casting relates to creating from a template

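View casting creates a new object of the subclass type from an existing, complete ndarray. Creating from a template creates a new object from an existing instance of our subclass; in that case the subclass's extra attributes usually have to be copied over.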

The difficulty with subclassing

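The difficulty is that we have to write the right handling for each of the three cases above; otherwise instances of the subclass easily lose their extra state and end up behaving like plain ndarrays, which breaks later calls.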
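A quick way to see the problem: in the hypothetical subclass below, the extra attribute is only set in __new__. Explicit construction works, but view casting and slicing bypass the subclass's __new__, so the results are still instances of the subclass yet lack the attribute.

```python
import numpy as np


class Tagged(np.ndarray):
    """Hypothetical subclass that sets `tag` only in __new__."""

    def __new__(cls, shape, tag=None):
        obj = super().__new__(cls, shape)
        obj.tag = tag
        return obj


t = Tagged((4,), tag="pe")
v = np.zeros(4).view(Tagged)  # view cast: Tagged.__new__ is not called
s = t[1:]                     # from template: Tagged.__new__ is not called
print(type(v).__name__, type(s).__name__)           # Tagged Tagged
print(t.tag, hasattr(v, "tag"), hasattr(s, "tag"))  # pe False False
```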

1. The __new__ method

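First of all, we cannot start from __init__, because ndarray already does its construction in __new__. __new__ may return any value, and the self argument that __init__ receives is in fact the object returned by __new__.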

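Constructing an object works like this: __new__ creates an object of a particular type; if that object is an instance of the class being called, it is passed to __init__, which sets its attributes; finally the object is returned to the caller. In other words, the object we get from pe = PEArray() is the one returned by __new__.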
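The flow just described can be modelled in a few lines. The construct function below is a hypothetical, simplified stand-in for what calling a class does (the real logic lives in type.__call__ in CPython); it is a sketch of the control flow, not the actual implementation.

```python
def construct(cls, *args, **kwargs):
    """Simplified model of cls(*args, **kwargs) (hypothetical helper)."""
    obj = cls.__new__(cls, *args, **kwargs)
    # __init__ only runs if __new__ returned an instance of cls
    if isinstance(obj, cls):
        type(obj).__init__(obj, *args, **kwargs)
    return obj


class Point:
    def __init__(self, x):
        self.x = x


class Weird:
    def __new__(cls):
        return 42  # not a Weird instance, so __init__ is skipped

    def __init__(self):
        raise RuntimeError("never reached")


print(construct(Point, 3).x)  # 3
print(construct(Weird))       # 42
print(Weird())                # 42, same as the real call
```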

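By overriding __new__, a class can return an object of a different type. In the following example, calling D() returns a C object; because the returned object is not an instance of D, __init__ is not called: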

Code
class C:
    def __new__(cls, *args):
        print('Cls in __new__:', cls)
        print('Args in __new__:', args)
        # The `object` type __new__ method takes a single argument.
        return object.__new__(cls)

    def __init__(self, *args):
        print('type(self) in __init__:', type(self))
        print('Args in __init__:', args)

class D(C):
    def __new__(cls, *args):
        print('D cls is:', cls)
        print('D args in __new__:', args)
        return C.__new__(C, *args)

    def __init__(self, *args):
        # we never get here
        print('In D __init__')
D()
D cls is: <class '__main__.D'>
D args in __new__: ()
Cls in __new__: <class '__main__.C'>
Args in __new__: ()
<__main__.C at 0x10fd0cb50>

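For view casting (and creation from a template), NumPy effectively calls obj = ndarray.__new__(subtype, shape, ...) at the C level, bypassing the subclass's own __new__. This returns an object of the subclass type, which is why slices and similar operations keep the subclass type.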

2. The __array_finalize__ method

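__array_finalize__ is the mechanism NumPy provides to let a subclass handle the various ways new instances are created. Since the subclass's __new__ is only called on explicit construction, this method is needed to handle the other creation paths.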

Code

class C(np.ndarray):
  def __new__(cls, *args, **kwargs):
    print('In __new__ with class %s' % cls)
    return super().__new__(cls, *args, **kwargs)

  def __init__(self, *args, **kwargs):
    # in practice you probably will not need or want an __init__
    # method for your subclass
    print('In __init__ with class %s' % self.__class__)

  def __array_finalize__(self, obj):
    print('In array_finalize:')
    print('   self type is %s' % type(self))
    print('   obj type is %s' % type(obj))
print("\nmethod 1 \n")
c = C((1,2,3))
print("\nmethod 2 \n")
np.arange(10).view(C)
print("\nmethod 3 \n")
cc = c[1:]

method 1 

In __new__ with class <class '__main__.C'>
In array_finalize:
   self type is <class '__main__.C'>
   obj type is <class 'NoneType'>
In __init__ with class <class '__main__.C'>

method 2 

In array_finalize:
   self type is <class '__main__.C'>
   obj type is <class 'numpy.ndarray'>

method 3 

In array_finalize:
   self type is <class '__main__.C'>
   obj type is <class '__main__.C'>

The example shows that __array_finalize__ is called on every construction path, and that the obj argument it receives differs between them:

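  1. On explicit construction, obj is None.
  2. On view casting, obj is the array being viewed, which may be an ndarray or an instance of any ndarray subclass.
  3. On creation from a template, obj is an instance of our subclass (the template), which we can use to update self.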

So __array_finalize__ is the right place to set the subclass's attributes on self.
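A compact sketch of that advice, modelled on the getattr-with-default pattern commonly used for ndarray subclasses (the Labeled class and its label attribute are hypothetical):

```python
import numpy as np


class Labeled(np.ndarray):
    """Hypothetical subclass carrying one extra attribute, `label`."""

    def __new__(cls, input_array, label=None):
        # explicit construction: view-cast the input, then set the attribute
        obj = np.asarray(input_array).view(cls)
        obj.label = label
        return obj

    def __array_finalize__(self, obj):
        # obj is None only for explicit construction via ndarray.__new__;
        # for view casts and templates copy the attribute (None if absent)
        if obj is None:
            return
        self.label = getattr(obj, "label", None)


a = Labeled([1, 2, 3, 4], label="pe")
print(a[1:].label)                       # pe   (from template)
print(np.arange(3).view(Labeled).label)  # None (view cast from ndarray)
```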

Example 1: adding extra attributes to an ndarray

Code
class PEArray(np.ndarray):
  def __new__(subtype, height, width, spad_size, max_height=12, max_width=14, dtype=float, buffer=None, offset=0,
              strides=None, order=None):
    obj = super().__new__(subtype, (height, width, spad_size), dtype=dtype, buffer=buffer,
                          offset=offset, strides=strides, order=order)
    obj.h = height
    obj.w = width
    obj.spad = spad_size
    obj.mh = max_height
    obj.mw = max_width
    return obj

  def __array_finalize__(self, obj):
    # 1. explicit construction: obj is None, __new__ sets the attributes itself
    if obj is None:
      return
    # the sizes come from the new array's own shape; pad with 0 because
    # indexing such as pearr[0] can drop dimensions
    self.h, self.w, self.spad = (tuple(self.shape) + (0, 0, 0))[:3]
    # 2. view cast: obj is a plain ndarray without limits, so default to 0
    # 3. from template: obj is a PEArray, so copy its limits
    self.mh = getattr(obj, 'mh', 0)
    self.mw = getattr(obj, 'mw', 0)


print('\nmethod 1:\n')
pearr = PEArray(2, 3, 8)
print(type(pearr))
print(pearr.h, pearr.w, pearr.spad, pearr.mh, pearr.mw)

print('\nmethod 2:\n')
r = np.random.rand(3, 4, 6)
rr = r.view(PEArray)
print(type(rr))
print(rr.h, rr.w, rr.spad, rr.mh, rr.mw)

print('\nmethod 3:\n')
pearr_sub = pearr[2:]
print(type(pearr_sub))
print(pearr_sub.h, pearr_sub.w, pearr_sub.spad, pearr_sub.mh, pearr_sub.mw)

method 1:

<class '__main__.PEArray'>
2 3 8 12 14

method 2:

<class '__main__.PEArray'>
3 4 6 0 0

method 3:

<class '__main__.PEArray'>
0 3 8 12 14

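Strictly speaking, we don't have to support view casting. For creation from a template, self is the already-sliced part of the array, but some of its attributes still live only on obj, so they have to be copied from there. In practice I suspect that subclasses carrying many extra parameters should forbid view casting, though I'm not sure whether that would cause other problems.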