tf.dataset无法推断shape导致错误
深度学习
使用tensorflow.keras的时候,tf.dataset在执行model.fit的时候报错:
ValueError: Cannot take the length of shape with unknown rank.
这里大概率是因为tf.dataset中使用了tf.py_function导致无法自动推导出张 良的形状,所以需要自己手动设置形状。
解决方案
这里一定要使用tensorflow 1.x版本,2.0中我也没找到解决方案😓,使用tf.contrib.data.assert_element_shape 函数直接指定形状即可。
import tensorflow as tf
from tensorflow.python import keras
yolo_model = keras_yolo_mobilev2((240, 320, 3), 3, 20, 1., True)
shapes = (yolo_model.input.shape, tuple(out.shape for out in yolo_model.output))
h.train_dataset = h.train_dataset.apply(tf.contrib.data.assert_element_shape(shapes))
yolo_model.fit(h.train_dataset, epochs=max_nrof_epochs,
steps_per_epoch=h.train_epoch_step,callbacks=[tbcall])