Spaces:
Runtime error
Runtime error
| import tensorflow as tf | |
| class AugmentationFactory: | |
| def __init__(self, image_size) -> None: | |
| self.image_size = image_size | |
| def random_crop(self, input_image, enhanced_image): | |
| input_image_shape = tf.shape(input_image)[:2] | |
| low_w = tf.random.uniform( | |
| shape=(), maxval=input_image_shape[1] - self.image_size + 1, dtype=tf.int32 | |
| ) | |
| low_h = tf.random.uniform( | |
| shape=(), maxval=input_image_shape[0] - self.image_size + 1, dtype=tf.int32 | |
| ) | |
| enhanced_w = low_w | |
| enhanced_h = low_h | |
| input_image_cropped = input_image[ | |
| low_h : low_h + self.image_size, low_w : low_w + self.image_size | |
| ] | |
| enhanced_image_cropped = enhanced_image[ | |
| enhanced_h : enhanced_h + self.image_size, | |
| enhanced_w : enhanced_w + self.image_size, | |
| ] | |
| return input_image_cropped, enhanced_image_cropped | |
| def random_horizontal_flip(sefl, input_image, enhanced_image): | |
| return tf.cond( | |
| tf.random.uniform(shape=(), maxval=1) < 0.5, | |
| lambda: (input_image, enhanced_image), | |
| lambda: ( | |
| tf.image.flip_left_right(input_image), | |
| tf.image.flip_left_right(enhanced_image), | |
| ), | |
| ) | |
| def random_vertical_flip(self, input_image, enhanced_image): | |
| return tf.cond( | |
| tf.random.uniform(shape=(), maxval=1) < 0.5, | |
| lambda: (input_image, enhanced_image), | |
| lambda: ( | |
| tf.image.flip_up_down(input_image), | |
| tf.image.flip_up_down(enhanced_image), | |
| ), | |
| ) | |
| def random_rotate(self, input_image, enhanced_image): | |
| condition = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32) | |
| return tf.image.rot90(input_image, condition), tf.image.rot90( | |
| enhanced_image, condition | |
| ) | |