OpenPose人体姿态估计

深度学习
Published

May 16, 2020

我参考tf-pose-estimation重新构建一个人体姿态估计模型,下面主要记录一下坑点.

标签制作

  1. tensorflow重写heatmap失败.

这个姿态估计其实和CenterNet是差不多的实现思路,构建标签的时候都是通过对图像的对应位置放置高斯分布的方式制作出heatmap,我一开始看懂了之后就准备直接用tensorflow的方式重写,结果如下:

  def get_heatmap(self, im_h, im_w, joint_list, th=4.6052, sigma=8.):

    heatmap: tf.Variable = tf.Variable(tf.zeros((self.hparams.parts, im_h, im_w)), trainable=False)

    for joints in joint_list:
      for i, center in enumerate(joints):
        if center[0] < 0 or center[1] < 0:
          continue

        delta = tf.sqrt(th * 2)
        # p0 -> x,y    p1 -> x,y
        im_wh = tf.cast((im_w, im_h), tf.float32)
        p0 = tf.cast(tf.maximum(0., center - delta * sigma), tf.int32)
        p1 = tf.cast(tf.minimum(im_wh, center + delta * sigma), tf.int32)

        x = tf.range(p0[0], p1[0])[None, :, None]
        y = tf.range(p0[1], p1[1])[:, None, None]

        p = tf.concat([x + tf.zeros_like(y), tf.zeros_like(x) + y], axis=-1)
        exp = tf.reduce_sum(tf.square(tf.cast(p, tf.float32) - center), -1) / (2. * sigma * sigma)
        # use indices update point area
        indices = tf.concat([tf.ones(p.shape[:-1] + [1], tf.int32) * i,
                             p[..., ::-1]], -1)
        # NOTE p is [x,y] , but `gather_nd` and `scatter_nd` require [y,x]
        old_center_area = tf.gather_nd(heatmap, indices)
        center_area = tf.minimum(tf.maximum(old_center_area, tf.exp(-exp)), 1.0)
        center_area = tf.where(exp > th, old_center_area, center_area)

        heatmap.scatter_nd_update(indices, center_area)
    # use indices update heatmap background NOTE scatter_nd can't use -1
    heatmap.scatter_update(tf.IndexedSlices(
        tf.clip_by_value(1. - tf.reduce_max(heatmap, axis=0), 0., 1.),
        self.hparams.parts - 1))

    heatmap_tensor = tf.transpose(heatmap, (1, 2, 0))

    if self.target_hw:
      heatmap_tensor = tf.image.resize(heatmap_tensor, self.target_hw)

    return heatmap_tensor

到这里其实是可以正确运行的,但是问题在于joint_list(关键点注释)他是个数是不确定的,所以还得把这个函数改成动态迭代的(用tf.while_loop),调这个函数我都调到吐了..所以我暂时先不搞全tf化.

  1. 向量化制作heatmap

之前tf化输入管道太麻烦我就直接把他全部的输入管道复制过来,但是性能问题让我吐血…而且有个很奇特问题,我本来是用tensor_slices制作tf.dataset的,虽然速度慢但还是十几秒还是有一个batch.然后我改成tfrecord的方式,发现可能是前端读的数据太快,后面标签制作太慢反而导致一分钟才能读取一个batch

然后我又得重写numpy版的制作heatmap.

def get_heatmap_v(self, target_hw, joint_list_mask):
  heatmap = np.zeros((self.coco_parts, self.height, self.width), dtype=np.float32)
  height, width = self.height, self.width
  for (joints, masks) in zip(self.joint_list, joint_list_mask):
    for idx, (point, mask) in enumerate(zip(joints, masks)):
      if mask:
        th = 4.6052
        delta = np.sqrt(th * 2)
        p0 = np.maximum(0., point - delta * self.sigma).astype('int32')
        p1 = np.minimum([width, height], point + delta * self.sigma).astype('int32')

        x = np.arange(p0[0], p1[0])
        y = np.arange(p0[1], p1[1])

        xv, yv = np.meshgrid(x, y, sparse=False, indexing='xy')

        exp = (((xv - point[0]) ** 2 + (yv - point[1]) ** 2) /
                (2.0 * self.sigma * self.sigma))

        yidx = yv[exp < th]
        xidx = xv[exp < th]
        exp_valid = exp[exp < th]

        heatmap[idx, yidx, xidx] = np.minimum(
            np.maximum(heatmap[idx, yidx, xidx],
                        np.exp(-exp_valid)), 1.0)

  heatmap = heatmap.transpose((1, 2, 0))

  # background
  heatmap[:, :, -1] = np.clip(1 - np.amax(heatmap, axis=2), 0.0, 1.0)

  if target_hw:
    heatmap = cv2.resize(heatmap, target_hw[::-1], interpolation=cv2.INTER_LINEAR)

  return heatmap

def get_heatmap(self, target_hw):
  heatmap = np.zeros((self.coco_parts, self.height, self.width), dtype=np.float32)

  for joints in self.joint_list:
    for idx, point in enumerate(joints):
      if point[0] < 0 or point[1] < 0:
        continue
      ImageMeta.put_heatmap(heatmap, idx, point, self.sigma)

  heatmap = heatmap.transpose((1, 2, 0))

  # background
  heatmap[:, :, -1] = np.clip(1 - np.amax(heatmap, axis=2), 0.0, 1.0)

  if target_hw:
    heatmap = cv2.resize(heatmap, target_hw[::-1], interpolation=cv2.INTER_LINEAR)

  return heatmap

@staticmethod
def put_heatmap(heatmap, plane_idx, center, sigma):
  center_x, center_y = center  # point
  _, height, width = heatmap.shape[:3]  # 热图大小

  th = 4.6052
  delta = math.sqrt(th * 2)

  x0 = int(max(0, center_x - delta * sigma))
  y0 = int(max(0, center_y - delta * sigma))

  x1 = int(min(width, center_x + delta * sigma))
  y1 = int(min(height, center_y + delta * sigma))

  cnt = 0
  uncnt = 0
  for y in range(y0, y1):
    for x in range(x0, x1):
      d = (x - center_x) ** 2 + (y - center_y) ** 2
      exp = d / 2.0 / sigma / sigma
      # NOTE 如果这个点不是靠近图像边沿,exp就是th的两倍
      if exp > th:
        continue
      heatmap[plane_idx][y][x] = max(heatmap[plane_idx][y][x], math.exp(-exp))
      heatmap[plane_idx][y][x] = min(heatmap[plane_idx][y][x], 1.0)

然后看一下耗时对比,直接提高32倍~

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   666                                           @func_line_time([])
   667                                           def dev_vector_get_heatmap():
   668                                             """ 构造一个向量化的heatmap函数,并测试速度 """
   669         1     653201.0 653201.0      1.3    h, ds = get_raw_dataset()
   670                                             h: OpenPoseHelper
   671         1       8555.0   8555.0      0.0    iters = iter(ds)
   672       101        142.0      1.4      0.0    for i in range(100):
   673       100     319597.0   3196.0      0.6      img, joint_list = next(iters)
   674       100      12173.0    121.7      0.0      img, joint_list = img.numpy(), joint_list.numpy()
   675                                           
   676       100       3520.0     35.2      0.0      meta = ImageMeta(img, joint_list, h.in_hw, h.hparams.parts, h.hparams.vecs, h.hparams.sigma)
   677                                           
   678       100   48977247.0 489772.5     95.1      heatmap = meta.get_heatmap(h.target_hw)
   679       100       4721.0     47.2      0.0      joint_list_mask = np.logical_not(np.all(joint_list == -1000, -1, keepdims=True))
   680       100    1489782.0  14897.8      2.9      heatmap_v = meta.get_heatmap_v(h.target_hw, joint_list_mask)
   681                                           
   682       100      25532.0    255.3      0.0      print(np.allclose(heatmap, heatmap_v))

我发现我有时候就是太喜欢过早优化代码..这真的是万恶之源,很容易吃力不讨好.一开始直接用numpy重写就能顺利提高速度还不用费那么多时期.