# Source code for colorex.cex_lut

# Modifications Copyright 2021 Alex Harvill
# SPDX-License-Identifier: Apache-2.0

# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
Color lookup table op forked and modified from:
https://github.com/tensorflow/probability/tree/master/tensorflow_probability/python/math/interpolation.py
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

# Dependency imports
import numpy as np

import tensorflow as tf

from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static
from tensorflow_probability.python.internal import tensorshape_util

__all__ = ['color_lookup_table']


def color_lookup_table(
    x,
    x_ref_min,
    x_ref_max,
    y_ref,
    axis=-4,
    fill_value='constant_extension',
    name=None,
):
    '''
    apply a batch of 3d lookup tables to a batch of color tensors.

    the following shapes are expected to work:
      N: batch size
      D: lookup table cube size ( lut total element count is D*D*D*3 )
      HW: number of colors in a batch ( probably flattened image height*width )

      x:         shape [N,HW,3]
      x_ref_min: shape [N,3]
      x_ref_max: shape [N,3]
      y_ref:     shape [N,D,D,D,3]

    Args:
      x: colors to transform, shape [N,HW,3].
      x_ref_min: per-batch color that maps to table index 0 along each axis.
      x_ref_max: per-batch color that maps to table index D-1 along each axis.
      y_ref: per-batch lookup tables.
      axis: accepted only for signature compatibility with the upstream tfp
        op this was forked from. It is ignored: the table axes always
        directly follow the single batch axis.
      fill_value: either the string 'constant_extension', in which case
        colors outside [x_ref_min, x_ref_max] are clamped to the table edge,
        or a scalar that is written to every output row whose color is out
        of range in any component.
      name: optional name scope; defaults to 'color_lookup_table'.

    Returns:
      Interpolated table values of shape [N,HW] + y_ref.shape[4:], i.e.
      [N,HW,3] for the shapes above.

    Raises:
      ValueError: for an unsupported fill_value string, a non-scalar
        fill_value, ranks that are wrong or not statically known, more than
        one batch dimension, a color size other than 3, or batch sizes that
        disagree between x, x_ref_min, x_ref_max and y_ref.
    '''
    with tf.name_scope(name or 'color_lookup_table'):
        dtype = dtype_util.common_dtype(
            [x, x_ref_min, x_ref_max, y_ref], dtype_hint=tf.float32)

        # Arg checking.
        if isinstance(fill_value, str):
            if fill_value != 'constant_extension':
                raise ValueError(
                    'A fill value ({}) was not an allowed string ({})'.format(
                        fill_value, 'constant_extension'))
        else:
            fill_value = tf.convert_to_tensor(
                fill_value, name='fill_value', dtype=dtype)
            _assert_ndims_statically(fill_value, expect_ndims=0)

        # x.shape = [..., nd]. The rank must be known statically because the
        # number of batch dims is derived from it below.
        x = tf.convert_to_tensor(x, name='x', dtype=dtype)
        _assert_ndims_statically(x, expect_ndims_at_least=2, expect_static=True)

        # y_ref.shape = [..., C1,...,Cnd, B1,...,BM]
        y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

        # x_ref_min.shape = x_ref_max.shape = [..., nd]
        x_ref_min = tf.convert_to_tensor(x_ref_min, name='x_ref_min', dtype=dtype)
        x_ref_max = tf.convert_to_tensor(x_ref_max, name='x_ref_max', dtype=dtype)
        _assert_ndims_statically(
            x_ref_min, expect_ndims_at_least=1, expect_static=True)
        _assert_ndims_statically(
            x_ref_max, expect_ndims_at_least=1, expect_static=True)

        # nd is the number of dimensions indexing the interpolation table.
        nd = tf.compat.dimension_value(x_ref_min.shape[-1])
        if nd is None:
            raise ValueError('`x_ref_min.shape[-1]` must be known statically.')
        tensorshape_util.assert_is_compatible_with(
            x_ref_max.shape[-1:], x_ref_min.shape[-1:])

        # `axis` is intentionally not consulted (see docstring); the previous
        # conversion/normalization of it was dead work.

        # Explicit exceptions rather than `assert`, so the checks are not
        # stripped when Python runs with -O.
        batch_dims = tensorshape_util.rank(x.shape) - 2
        if batch_dims != 1:
            raise ValueError('only 1 batch dimension supported')
        if nd != 3:
            raise ValueError('only 3 color components tested')

        # Compare leading dims by compatibility so that an unknown (None)
        # batch size in a traced graph is not mistaken for a mismatch.
        batch_shape = x.shape[:1]
        for ref in (x_ref_min, x_ref_max, y_ref):
            if not batch_shape.is_compatible_with(ref.shape[:1]):
                raise ValueError('unequal batch dimensions')

        return _batch_interp_with_gather_nd(
            x=x,
            x_ref_min=x_ref_min,
            x_ref_max=x_ref_max,
            y_ref=y_ref,
            nd=nd,
            fill_value=fill_value,
            batch_dims=batch_dims,
        )
def _batch_interp_with_gather_nd(
    x,
    x_ref_min,
    x_ref_max,
    y_ref,
    nd,
    fill_value,
    batch_dims,
):
    '''N-D interpolation that works with leading batch dims.

    Reformatted duplicate of tfp.math._batch_interp_regular_nd_grid.

    Each output value is the sum, over the 2**nd corners of the table cell
    containing x, of y_ref at that corner times the volume of the sub-cell
    opposite that corner (multilinear interpolation).

    Args:
      x: query points, shape [A1, ..., An, D, nd] with n = batch_dims.
      x_ref_min: per-batch coordinate mapped to table index 0, shape
        [A1, ..., An, nd].
      x_ref_max: per-batch coordinate mapped to table index C_k - 1 along
        each dim k, shape [A1, ..., An, nd].
      y_ref: table, shape [A1, ..., An, C1, ..., Cnd, B1, ..., BM].
      nd: python int, number of table (interpolation) dims.
      fill_value: only when this is a numeric tensor are out-of-range points
        replaced by it; otherwise out-of-range points keep the value from
        the clamped (edge-extended) index.
      batch_dims: python int, number of leading batch dims n.

    Returns:
      y of shape [A1, ..., An, D, B1, ..., BM]. A NaN fractional index
      (e.g. from NaN in x) makes the result NaN, unless that row is
      replaced by a numeric fill_value because another component is out of
      range.
    '''
    dtype = x.dtype

    # In this function,
    # x.shape = [A1, ..., An, D, nd], where n = batch_dims
    # and
    # y_ref.shape = [A1, ..., An, C1, C2,..., Cnd, B1,...,BM]
    # y_ref[A1, ..., An, i1,...,ind] is a shape [B1,...,BM] Tensor with the value
    # at index [i1,...,ind] in the interpolation table.
    # x_ref_min and x_ref_max have shapes [A1, ..., An, nd].

    # ny[k] is number of y reference points in interp dim k.
    ny = tf.cast(tf.shape(y_ref)[batch_dims:batch_dims + nd], dtype)

    # Map [x_ref_min, x_ref_max] to [0, ny - 1].
    # This is the (fractional) index of x.
    # x_idx_unclipped[A1, ..., An, d, k] is the fractional index into dim k of
    # interpolation table for the dth x value.
    # The bounds get a singleton D axis so they broadcast against x.
    x_ref_min_expanded = tf.expand_dims(x_ref_min, axis=-2)
    x_ref_max_expanded = tf.expand_dims(x_ref_max, axis=-2)
    # NOTE(review): nothing guards against x_ref_max == x_ref_min here; that
    # divides by zero and yields inf/NaN indices.
    x_idx_unclipped = (ny - 1) * (x - x_ref_min_expanded) / (
        x_ref_max_expanded - x_ref_min_expanded)

    # Wherever x is NaN, x_idx_unclipped will be NaN as well.
    # Keep track of the nan indices here (so we can impute NaN later).
    # Also eliminate any NaN indices, since the indices are cast to int32
    # below and int32 has no NaN.
    nan_idx = tf.math.is_nan(x_idx_unclipped)
    x_idx_unclipped = tf.where(nan_idx, tf.cast(0., dtype=dtype), x_idx_unclipped)

    # x_idx.shape = [A1, ..., An, D, nd]
    # Clipping to [0, ny - 1] is what gives 'constant_extension' behavior.
    x_idx = tf.clip_by_value(x_idx_unclipped, tf.zeros((), dtype=dtype), ny - 1)

    # Get the index above and below x_idx.
    # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
    # however, this results in idx_below == idx_above whenever x is on a grid.
    # This in turn results in y_ref_below == y_ref_above, and then the gradient
    # at this point is zero. So here we 'jitter' one of idx_below, idx_above,
    # so that they are at different values. This jittering does not affect the
    # interpolated value, but does make the gradient nonzero (unless of course
    # the y_ref values are the same).
    idx_below = tf.floor(x_idx)
    idx_above = tf.minimum(idx_below + 1, ny - 1)
    idx_below = tf.maximum(idx_above - 1, 0)

    # These are the values of y_ref corresponding to above/below indices.
    # idx_below_int32.shape = x.shape[:-1] + [nd]
    idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
    idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)

    # idx_below_list is a length nd list of shape x.shape[:-1] int32 tensors.
    idx_below_list = tf.unstack(idx_below_int32, axis=-1)
    idx_above_list = tf.unstack(idx_above_int32, axis=-1)

    # Use t to get a convex combination of the below/above values.
    # t.shape = [A1, ..., An, D, nd]
    t = x_idx - idx_below

    # x, and tensors shaped like x, need to be added to, and selected with
    # (using tf.where) the output y. This requires appending singletons.
    def _expand_x_fn(tensor):
        # Reshape tensor to tensor.shape + [1] * M, where M = rank of the
        # trailing B1,...,BM dims of y_ref.
        extended_shape = tf.concat(
            [tf.shape(tensor), tf.ones_like(tf.shape(y_ref)[batch_dims + nd:])],
            axis=0)
        return tf.reshape(tensor, extended_shape)

    # Now, t.shape = [A1, ..., An, D, nd] + [1] * (rank(y_ref) - nd - batch_dims)
    t = _expand_x_fn(t)
    s = 1 - t

    # Re-insert NaN wherever x was NaN.
    # s was computed above from the NaN-free t; only t carries the NaN, which
    # is enough because every term's volume uses t or s for each dim, and half
    # of the terms use t for the NaN dim.
    nan_idx = _expand_x_fn(nan_idx)
    t = tf.where(nan_idx, tf.constant(np.nan, dtype), t)

    terms = []
    # Our work above has located x's fractional index inside a cube of above/below
    # indices. The distance to the below indices is t, and to the above indices
    # is s.
    # Drawing lines from x to the cube walls, we get 2**nd smaller cubes. Each
    # term in the result is a product of a reference point, gathered from y_ref,
    # multiplied by a volume. The volume is that of the cube opposite to the
    # reference point. E.g. if the reference point is below x in every axis, the
    # volume is that of the cube with corner above x in every axis,
    # s[0]*...*s[nd-1].
    # We could probably do this with one massive gather, but that would be very
    # unreadable and un-debuggable. It also would create a large Tensor.
    for zero_ones_list in _binary_count(nd):
        gather_from_y_ref_idx = []
        opposite_volume_t_idx = []
        opposite_volume_s_idx = []
        for k, zero_or_one in enumerate(zero_ones_list):
            if zero_or_one == 0:
                # If the kth iterate has zero_or_one = 0,
                # Will gather from the 'below' reference point along axis k.
                gather_from_y_ref_idx.append(idx_below_list[k])
                # Now append the index to gather for computing opposite_volume.
                # This could be done by initializing opposite_volume to 1, then here:
                #  opposite_volume *= tf.gather(s, indices=k, axis=tf.rank(x) - 1)
                # but that puts a gather in the 'inner loop.'  Better to append the
                # index and do one larger gather down below.
                opposite_volume_s_idx.append(k)
            else:
                gather_from_y_ref_idx.append(idx_above_list[k])
                # Append an index to gather, having the same effect as
                #  opposite_volume *= tf.gather(t, indices=k, axis=tf.rank(x) - 1)
                opposite_volume_t_idx.append(k)

        # Compute opposite_volume (volume of cube opposite the ref point):
        # Recall t.shape = s.shape = [A1, ..., An, D, nd] + [1, ..., 1]
        # Gather from t and s along the 'nd' axis, which is rank(x) - 1.
        # When an index list is empty the gather is empty along that axis and
        # reduce_prod over it yields 1.
        # NOTE(review): ov_axis does not depend on the loop variable and could
        # be computed once before the loop.
        ov_axis = tf.rank(x) - 1
        opposite_volume = (tf.reduce_prod(
            tf.gather(
                t,
                indices=tf.cast(opposite_volume_t_idx, dtype=tf.int32),
                axis=ov_axis,
            ),
            axis=ov_axis,
        ) * tf.reduce_prod(
            tf.gather(
                s,
                indices=tf.cast(opposite_volume_s_idx, dtype=tf.int32),
                axis=ov_axis,
            ),
            axis=ov_axis,
        ))

        # One table corner per query point; batch_dims keeps each batch
        # entry gathering from its own table.
        y_ref_pt = tf.gather_nd(
            y_ref,
            tf.stack(gather_from_y_ref_idx, axis=-1),
            batch_dims=batch_dims,
        )

        terms.append(y_ref_pt * opposite_volume)

    y = tf.math.add_n(terms)

    if tf.debugging.is_numeric_tensor(fill_value):
        # Recall x_idx_unclipped.shape = [A1, ..., An, D, nd],
        # so here we check if it was out of bounds in any of the nd dims.
        # Thus, oob_idx.shape = [A1, ..., An, D].
        # NaN positions were zeroed above, so they are never flagged here.
        oob_idx = tf.reduce_any(
            (x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1),
            axis=-1,
        )

        # Now, y.shape = [A1, ..., An, D, B1,...,BM], so we'll have to
        # broadcast oob_idx.

        oob_idx = _expand_x_fn(oob_idx)  # Shape [A1, ..., An, D, 1,...,1]
        oob_idx |= tf.fill(tf.shape(y), False)
        y = tf.where(oob_idx, fill_value, y)
    return y


def _assert_ndims_statically(
    x,
    expect_ndims=None,
    expect_ndims_at_least=None,
    expect_static=False,
):
    '''Assert that Tensor x has expected number of dimensions.

    Args:
      x: tensor whose static rank is checked.
      expect_ndims: if not None, the exact rank required.
      expect_ndims_at_least: if not None, the minimum rank required.
      expect_static: if True, an unknown static rank is an error; if False,
        an unknown static rank silently skips the remaining checks.

    Raises:
      ValueError: if any requested check fails.
    '''
    ndims = tensorshape_util.rank(x.shape)
    if ndims is None:
        if expect_static:
            raise ValueError('Expected static ndims. Found: {}'.format(x))
        return
    if expect_ndims is not None and ndims != expect_ndims:
        raise ValueError('ndims must be {}. Found: {}'.format(expect_ndims, ndims))
    if expect_ndims_at_least is not None and ndims < expect_ndims_at_least:
        raise ValueError('ndims must be at least {}. Found {}'.format(
            expect_ndims_at_least, ndims))


def _make_expand_x_fn_for_non_batch_interpolation(y_ref, axis):
    '''Make func to expand left/right (of axis) dims of tensors shaped like x.

    NOTE(review): this helper is not called anywhere in this file as shown;
    it appears to be carried over from the upstream tfp module.

    Args:
      y_ref: reference tensor whose dims left/right of `axis` define the
        singleton padding.
      axis: axis of y_ref that x's dims replace; presumably already
        non-negative (negative values would mis-slice) — confirm at call sites.

    Returns:
      expand_ends(x, broadcast=False): reshapes x to
      [1]*axis + x.shape + [1]*rank(y_ref.shape[axis+1:]); with
      broadcast=True it is additionally broadcast to
      y_ref.shape[:axis] + x.shape + y_ref.shape[axis+1:].
    '''
    # This expansion is to help x broadcast with `y`, the output.
    # In the non-batch case, the output shape is going to be
    #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis+1:]

    # Assumes the caller already made axis non-negative.
    y_ref_shape = tf.shape(y_ref)
    y_ref_shape_left = y_ref_shape[:axis]
    y_ref_shape_right = y_ref_shape[axis + 1:]

    def expand_ends(x, broadcast=False):
        '''Expand x so it can bcast w/ tensors of output shape.'''
        # Assume out_shape = A + x.shape + B, and rank(A) = axis.
        # Expand with singletons with same rank as A, B.
        expanded_shape = tf.pad(
            tensor=tf.shape(x),
            paddings=[[axis, tf.size(y_ref_shape_right)]],
            constant_values=1,
        )
        x_expanded = tf.reshape(x, expanded_shape)

        if broadcast:
            out_shape = tf.concat(
                (
                    y_ref_shape_left,
                    tf.shape(x),
                    y_ref_shape_right,
                ),
                axis=0,
            )
            # Bool tensors cannot be added to, so broadcast them with `|`
            # against an all-False tensor instead.
            if dtype_util.is_bool(x.dtype):
                x_expanded = x_expanded | tf.cast(tf.zeros(out_shape), tf.bool)
            else:
                x_expanded += tf.zeros(out_shape, dtype=x.dtype)
        return x_expanded

    return expand_ends


def _binary_count(n):
    '''Count `n` binary digits from [0...0] to [1...1].

    Returns all 2**n length-n tuples of 0/1 in lexicographic order, e.g.
    _binary_count(2) == [(0, 0), (0, 1), (1, 0), (1, 1)].
    '''
    return list(itertools.product([0, 1], repeat=n))