TensorFlow can't get image.shape from a function called inside dataset.map(mapFn)



I'm trying to do the TensorFlow equivalent of torch.transforms.Resize(TRAIN_IMAGE_SIZE), which resizes the smallest image dimension to TRAIN_IMAGE_SIZE, like this:

def transforms(filename):
  parts = tf.strings.split(filename, '/')
  label = parts[-2]

  image = tf.io.read_file(filename)
  image = tf.image.decode_jpeg(image)
  image = tf.image.convert_image_dtype(image, tf.float32)

  # this doesn't work with Dataset.map() because image.shape=(None,None,3) from Dataset.map()
  image = largest_sq_crop(image) 

  image = tf.image.resize(image, (256,256))
  return image, label

list_ds = tf.data.Dataset.list_files('{}/*/*'.format(DATASET_PATH))
images_ds = list_ds.map(transforms).batch(4)

The simple answer is here: Tensorflow: Crop largest central square region of image

But when I use that method via tf.data.Dataset.map(transforms), I get shape=(None,None,3) from inside largest_sq_crop(image). The method works fine when I call it normally.
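The behavior described above can be reproduced with a minimal sketch (not the asker's actual code): inside Dataset.map() the mapping function is traced as a graph, so the *static* shape of a decoded JPEG has unknown spatial dimensions, while tf.shape() still yields the real values at runtime.

```python
import tensorflow as tf

# Build a small in-memory JPEG so the example is self-contained.
img = tf.cast(tf.random.uniform((8, 10, 3), maxval=255, dtype=tf.int32), tf.uint8)
jpeg_bytes = tf.io.encode_jpeg(img)

seen = {}

def mapper(b):
    x = tf.image.decode_jpeg(b, channels=3)
    # During tracing the static shape is [None, None, 3]:
    seen['static'] = x.shape.as_list()
    # The dynamic shape is a Tensor that is resolved at runtime:
    return tf.shape(x)

ds = tf.data.Dataset.from_tensor_slices([jpeg_bytes]).map(mapper)
dyn = next(iter(ds))
print(seen['static'])   # [None, None, 3]
print(dyn.numpy())      # [ 8 10  3]
```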


I think the problem has to do with the fact that shapes are not available for EagerTensors inside Dataset.map(). Is there a workaround?
Michael

Can you include the definition of largest_sq_crop?
jakub

Answers:



I found the answer. It has to do with the fact that my resize method executes fine when, e.g., tf.executing_eagerly()==True, but fails when used inside dataset.map(). Apparently in that execution environment, tf.executing_eagerly()==False.

My error was in the way I was unpacking the image shape to get the scaling dimensions. TensorFlow graph execution does not seem to support accessing the tensor.shape tuple.

  # wrong
  b,h,w,c = img.shape
  print("ERR> ", h,w,c)
  # ERR>  None None 3

  # also wrong
  b = img.shape[0]
  h = img.shape[1]
  w = img.shape[2]
  c = img.shape[3]
  print("ERR> ", h,w,c)
  # ERR>  None None 3

  # but this works!!!
  shape = tf.shape(img)
  b = shape[0]
  h = shape[1]
  w = shape[2]
  c = shape[3]
  img = tf.reshape( img, (-1,h,w,c))
  print("OK> ", h,w,c)
  # OK>  Tensor("strided_slice_2:0", shape=(), dtype=int32) Tensor("strided_slice_3:0", shape=(), dtype=int32) Tensor("strided_slice_4:0", shape=(), dtype=int32)

I was using the shape dimensions downstream of the dataset.map() function, and it raised the following exception because it was getting None instead of actual values.

TypeError: Failed to convert object of type <class 'tuple'> to Tensor. Contents: (-1, None, None, 3). Consider casting elements to a supported type.

When I switched to unpacking the shape manually from tf.shape(), everything worked fine.
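Putting it together, a graph-safe version of largest_sq_crop (the definition below is a sketch, not the asker's original code) can be written entirely with tf.shape() so it works both eagerly and inside Dataset.map():

```python
import tensorflow as tf

def largest_sq_crop(image):
    """Crop the largest centered square from an HWC image tensor."""
    shape = tf.shape(image)          # dynamic shape: valid under graph tracing
    h, w = shape[0], shape[1]
    side = tf.minimum(h, w)          # edge length of the largest square
    top = (h - side) // 2            # center the crop vertically
    left = (w - side) // 2           # center the crop horizontally
    return tf.image.crop_to_bounding_box(image, top, left, side, side)
```

Because every size is a runtime Tensor rather than a Python int pulled from image.shape, the same function can be dropped into the transforms() pipeline from the question without hitting the None-shape problem.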
