# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

import os
import sys
import inspect
import importlib
import imp
import numpy as np
from collections import OrderedDict
import tensorflow as tf

#----------------------------------------------------------------------------
# Convenience.

def run(*args, **kwargs): # Run the specified ops in the default session.
    return tf.get_default_session().run(*args, **kwargs)

def is_tf_expression(x):
    return isinstance(x, tf.Tensor) or isinstance(x, tf.Variable) or isinstance(x, tf.Operation)

def shape_to_list(shape):
    return [dim.value for dim in shape]

def flatten(x):
    with tf.name_scope('Flatten'):
        return tf.reshape(x, [-1])

def log2(x):
    with tf.name_scope('Log2'):
        return tf.log(x) * np.float32(1.0 / np.log(2.0))

def exp2(x):
    with tf.name_scope('Exp2'):
        return tf.exp(x * np.float32(np.log(2.0)))

def lerp(a, b, t):
    with tf.name_scope('Lerp'):
        return a + (b - a) * t

def lerp_clip(a, b, t):
    with tf.name_scope('LerpClip'):
        return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)

def absolute_name_scope(scope): # Forcefully enter the specified name scope, ignoring any surrounding scopes.
    return tf.name_scope(scope + '/')

#----------------------------------------------------------------------------
# Initialize TensorFlow graph and session using good default settings.

def init_tf(config_dict=dict()):
    if tf.get_default_session() is None:
        tf.set_random_seed(np.random.randint(1 << 31))
        create_session(config_dict, force_as_default=True)

#----------------------------------------------------------------------------
# Create tf.Session based on config dict of the form
# {'gpu_options.allow_growth': True}

def create_session(config_dict=dict(), force_as_default=False):
    config = tf.ConfigProto()
    for key, value in config_dict.items():
        fields = key.split('.')
        obj = config
        for field in fields[:-1]:
            obj = getattr(obj, field)
        setattr(obj, fields[-1], value)
    session = tf.Session(config=config)
    if force_as_default:
        session._default_session = session.as_default()
        session._default_session.enforce_nesting = False
        session._default_session.__enter__()
    return session
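
# Example (illustrative sketch, not part of the original API): typical session setup.
# The config key mirrors the documented form above; the constant and scope name below
# are placeholders chosen for the example, and the function is never called by this module.
def _example_session_setup():
    init_tf({'gpu_options.allow_growth': True}) # creates the default session on first call
    with tf.name_scope('Demo'):
        x = tf.constant(4.0)
        y = exp2(log2(x)) # round-trip through the convenience helpers defined above
    print(run(y)) # runs in the default session => prints ~4.0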

#----------------------------------------------------------------------------
# Initialize all tf.Variables that have not already been initialized.
# Equivalent to the following, but more efficient and does not bloat the tf graph:
# tf.variables_initializer(tf.report_uninitialized_variables()).run()

def init_uninited_vars(vars=None):
    if vars is None: vars = tf.global_variables()
    test_vars = []; test_ops = []
    with tf.control_dependencies(None): # ignore surrounding control_dependencies
        for var in vars:
            assert is_tf_expression(var)
            try:
                tf.get_default_graph().get_tensor_by_name(var.name.replace(':0', '/IsVariableInitialized:0'))
            except KeyError:
                # Op does not exist => variable may be uninitialized.
                test_vars.append(var)
                with absolute_name_scope(var.name.split(':')[0]):
                    test_ops.append(tf.is_variable_initialized(var))
    init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
    run([var.initializer for var in init_vars])

#----------------------------------------------------------------------------
# Set the values of given tf.Variables.
# Equivalent to the following, but more efficient and does not bloat the tf graph:
# tfutil.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])

def set_vars(var_to_value_dict):
    ops = []
    feed_dict = {}
    for var, value in var_to_value_dict.items():
        assert is_tf_expression(var)
        try:
            setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(':0', '/setter:0')) # look for existing op
        except KeyError:
            with absolute_name_scope(var.name.split(':')[0]):
                with tf.control_dependencies(None): # ignore surrounding control_dependencies
                    setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, 'new_value'), name='setter') # create new setter
        ops.append(setter)
        feed_dict[setter.op.inputs[1]] = value
    run(ops, feed_dict)
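
# Example (illustrative sketch, not part of the original API): lazy initialization and
# in-place assignment using the helpers above. The variable name and values are
# placeholders; the function is never called by this module.
def _example_var_helpers():
    init_tf()
    w = tf.Variable(tf.zeros([3]), name='example_w')
    init_uninited_vars() # initializes example_w only if it is still uninitialized
    set_vars({w: np.array([1.0, 2.0, 3.0], dtype=np.float32)}) # reuses the cached 'setter' op on later calls
    print(run(w)) # => [1. 2. 3.]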

#----------------------------------------------------------------------------
# Autosummary creates an identity op that internally keeps track of the input
# values and automatically shows up in TensorBoard. The reported value
# represents an average over input components. The average is accumulated
# constantly over time and flushed when save_summaries() is called.
#
# Notes:
# - The output tensor must be used as an input for something else in the
#   graph. Otherwise, the autosummary op will not get executed, and the average
#   value will not get accumulated.
# - It is perfectly fine to include autosummaries with the same name in
#   several places throughout the graph, even if they are executed concurrently.
# - It is ok to also pass in a python scalar or numpy array. In this case, it
#   is added to the average immediately.
# - See _example_autosummary_usage() after save_summaries() below for a usage sketch.

_autosummary_vars = OrderedDict() # name => [var, ...]
_autosummary_immediate = OrderedDict() # name => update_op, update_value
_autosummary_finalized = False

def autosummary(name, value):
    id = name.replace('/', '_')
    if is_tf_expression(value):
        with tf.name_scope('summary_' + id), tf.device(value.device):
            update_op = _create_autosummary_var(name, value)
            with tf.control_dependencies([update_op]):
                return tf.identity(value)
    else: # python scalar or numpy array
        if name not in _autosummary_immediate:
            with absolute_name_scope('Autosummary/' + id), tf.device(None), tf.control_dependencies(None):
                update_value = tf.placeholder(tf.float32)
                update_op = _create_autosummary_var(name, update_value)
                _autosummary_immediate[name] = update_op, update_value
        update_op, update_value = _autosummary_immediate[name]
        run(update_op, {update_value: np.float32(value)})
        return value

# Create the necessary ops to include autosummaries in TensorBoard report.
# Note: This should be done only once per graph.
def finalize_autosummaries():
    global _autosummary_finalized
    if _autosummary_finalized:
        return
    _autosummary_finalized = True
    init_uninited_vars([var for vars in _autosummary_vars.values() for var in vars])
    with tf.device(None), tf.control_dependencies(None):
        for name, vars in _autosummary_vars.items():
            id = name.replace('/', '_')
            with absolute_name_scope('Autosummary/' + id):
                sum = tf.add_n(vars)
                avg = sum[0] / sum[1]
                with tf.control_dependencies([avg]): # read before resetting
                    reset_ops = [tf.assign(var, tf.zeros(2)) for var in vars]
                    with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting
                        tf.summary.scalar(name, avg)

# Internal helper for creating autosummary accumulators.
def _create_autosummary_var(name, value_expr):
    assert not _autosummary_finalized
    v = tf.cast(value_expr, tf.float32)
    if v.shape.ndims == 0:
        v = [v, np.float32(1.0)]
    elif v.shape.ndims == 1:
        v = [tf.reduce_sum(v), tf.cast(tf.shape(v)[0], tf.float32)]
    else:
        v = [tf.reduce_sum(v), tf.reduce_prod(tf.cast(tf.shape(v), tf.float32))]
    v = tf.cond(tf.is_finite(v[0]), lambda: tf.stack(v), lambda: tf.zeros(2))
    with tf.control_dependencies(None):
        var = tf.Variable(tf.zeros(2)) # [numerator, denominator]
    update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))
    if name in _autosummary_vars:
        _autosummary_vars[name].append(var)
    else:
        _autosummary_vars[name] = [var]
    return update_op

#----------------------------------------------------------------------------
# Call filewriter.add_summary() with all summaries in the default graph,
# automatically finalizing and merging them on the first call.

_summary_merge_op = None

def save_summaries(filewriter, global_step=None):
    global _summary_merge_op
    if _summary_merge_op is None:
        finalize_autosummaries()
        with tf.device(None), tf.control_dependencies(None):
            _summary_merge_op = tf.summary.merge_all()
    filewriter.add_summary(_summary_merge_op.eval(), global_step)
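
# Example (illustrative sketch, not part of the original API): wiring autosummary into a
# minimal loop and flushing the accumulated averages to TensorBoard. The log directory
# 'logs' and the summary name 'Loss/train' are placeholders; the function is never
# called by this module.
def _example_autosummary_usage():
    init_tf()
    loss = tf.random_uniform([]) # stand-in for a real training loss
    loss = autosummary('Loss/train', loss) # identity op that also accumulates the value
    writer = tf.summary.FileWriter('logs', tf.get_default_graph())
    for step in range(100):
        run(loss) # the op must actually be executed for the average to update
        if step % 10 == 0:
            save_summaries(writer, global_step=step) # finalizes and merges summaries on the first call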

#----------------------------------------------------------------------------
# Utilities for importing modules and objects by name.

def import_module(module_or_obj_name):
    parts = module_or_obj_name.split('.')
    parts[0] = {'np': 'numpy', 'tf': 'tensorflow'}.get(parts[0], parts[0])
    for i in range(len(parts), 0, -1):
        try:
            module = importlib.import_module('.'.join(parts[:i]))
            relative_obj_name = '.'.join(parts[i:])
            return module, relative_obj_name
        except ImportError:
            pass
    raise ImportError(module_or_obj_name)

def find_obj_in_module(module, relative_obj_name):
    obj = module
    for part in relative_obj_name.split('.'):
        obj = getattr(obj, part)
    return obj

def import_obj(obj_name):
    module, relative_obj_name = import_module(obj_name)
    return find_obj_in_module(module, relative_obj_name)

def call_func_by_name(*args, func=None, **kwargs):
    assert func is not None
    return import_obj(func)(*args, **kwargs)
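
# Example (illustrative sketch, not part of the original API): resolving objects by their
# fully qualified names, as done by Optimizer and Network below. Never called by this module.
def _example_import_by_name():
    adam_class = import_obj('tf.train.AdamOptimizer') # 'tf' expands to 'tensorflow'
    assert adam_class is tf.train.AdamOptimizer
    total = call_func_by_name(np.ones([2, 3]), func='np.sum') # positional args are forwarded to np.sum
    assert total == 6.0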

#----------------------------------------------------------------------------
# Wrapper for tf.train.Optimizer that automatically takes care of:
# - Gradient averaging for multi-GPU training.
# - Dynamic loss scaling and typecasts for FP16 training.
# - Ignoring corrupted gradients that contain NaNs/Infs.
# - Reporting statistics.
# - Well-chosen default settings.
# (See _example_optimizer_usage() after the class for a usage sketch.)

class Optimizer:
    def __init__(
        self,
        name = 'Train',
        tf_optimizer = 'tf.train.AdamOptimizer',
        learning_rate = 0.001,
        use_loss_scaling = False,
        loss_scaling_init = 64.0,
        loss_scaling_inc = 0.0005,
        loss_scaling_dec = 1.0,
        **kwargs):

        # Init fields.
        self.name = name
        self.learning_rate = tf.convert_to_tensor(learning_rate)
        self.id = self.name.replace('/', '.')
        self.scope = tf.get_default_graph().unique_name(self.id)
        self.optimizer_class = import_obj(tf_optimizer)
        self.optimizer_kwargs = dict(kwargs)
        self.use_loss_scaling = use_loss_scaling
        self.loss_scaling_init = loss_scaling_init
        self.loss_scaling_inc = loss_scaling_inc
        self.loss_scaling_dec = loss_scaling_dec
        self._grad_shapes = None # [shape, ...]
        self._dev_opt = OrderedDict() # device => optimizer
        self._dev_grads = OrderedDict() # device => [[(grad, var), ...], ...]
        self._dev_ls_var = OrderedDict() # device => variable (log2 of loss scaling factor)
        self._updates_applied = False

    # Register the gradients of the given loss function with respect to the given variables.
    # Intended to be called once per GPU.
    def register_gradients(self, loss, vars):
        assert not self._updates_applied

        # Validate arguments.
        if isinstance(vars, dict):
            vars = list(vars.values()) # allow passing in Network.trainables as vars
        assert isinstance(vars, list) and len(vars) >= 1
        assert all(is_tf_expression(expr) for expr in vars + [loss])
        if self._grad_shapes is None:
            self._grad_shapes = [shape_to_list(var.shape) for var in vars]
        assert len(vars) == len(self._grad_shapes)
        assert all(shape_to_list(var.shape) == var_shape for var, var_shape in zip(vars, self._grad_shapes))
        dev = loss.device
        assert all(var.device == dev for var in vars)

        # Register device and compute gradients.
        with tf.name_scope(self.id + '_grad'), tf.device(dev):
            if dev not in self._dev_opt:
                opt_name = self.scope.replace('/', '_') + '_opt%d' % len(self._dev_opt)
                self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
                self._dev_grads[dev] = []
            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            grads = self._dev_opt[dev].compute_gradients(loss, vars, gate_gradients=tf.train.Optimizer.GATE_NONE) # disable gating to reduce memory usage
            grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads] # replace disconnected gradients with zeros
            self._dev_grads[dev].append(grads)

    # Construct training op to update the registered variables based on their gradients.
    def apply_updates(self):
        assert not self._updates_applied
        self._updates_applied = True
        devices = list(self._dev_grads.keys())
        total_grads = sum(len(grads) for grads in self._dev_grads.values())
        assert len(devices) >= 1 and total_grads >= 1
        ops = []
        with absolute_name_scope(self.scope):

            # Cast gradients to FP32 and calculate partial sum within each device.
            dev_grads = OrderedDict() # device => [(grad, var), ...]
            for dev_idx, dev in enumerate(devices):
                with tf.name_scope('ProcessGrads%d' % dev_idx), tf.device(dev):
                    sums = []
                    for gv in zip(*self._dev_grads[dev]):
                        assert all(v is gv[0][1] for g, v in gv)
                        g = [tf.cast(g, tf.float32) for g, v in gv]
                        g = g[0] if len(g) == 1 else tf.add_n(g)
                        sums.append((g, gv[0][1]))
                    dev_grads[dev] = sums

            # Sum gradients across devices.
            if len(devices) > 1:
                with tf.name_scope('SumAcrossGPUs'), tf.device(None):
                    for var_idx, grad_shape in enumerate(self._grad_shapes):
                        g = [dev_grads[dev][var_idx][0] for dev in devices]
                        if np.prod(grad_shape): # nccl does not support zero-sized tensors
                            g = tf.contrib.nccl.all_sum(g)
                        for dev, gg in zip(devices, g):
                            dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1])

            # Apply updates separately on each device.
            for dev_idx, (dev, grads) in enumerate(dev_grads.items()):
                with tf.name_scope('ApplyGrads%d' % dev_idx), tf.device(dev):

                    # Scale gradients as needed.
                    if self.use_loss_scaling or total_grads > 1:
                        with tf.name_scope('Scale'):
                            coef = tf.constant(np.float32(1.0 / total_grads), name='coef')
                            coef = self.undo_loss_scaling(coef)
                            grads = [(g * coef, v) for g, v in grads]

                    # Check for overflows.
                    with tf.name_scope('CheckOverflow'):
                        grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads]))

                    # Update weights and adjust loss scaling.
                    with tf.name_scope('UpdateWeights'):
                        opt = self._dev_opt[dev]
                        ls_var = self.get_loss_scaling_var(dev)
                        if not self.use_loss_scaling:
                            ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op))
                        else:
                            ops.append(tf.cond(grad_ok,
                                lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)),
                                lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec))))

                    # Report statistics on the last device.
                    if dev == devices[-1]:
                        with tf.name_scope('Statistics'):
                            ops.append(autosummary(self.id + '/learning_rate', self.learning_rate))
                            ops.append(autosummary(self.id + '/overflow_frequency', tf.where(grad_ok, 0, 1)))
                            if self.use_loss_scaling:
                                ops.append(autosummary(self.id + '/loss_scaling_log2', ls_var))

            # Initialize variables and group everything into a single op.
            self.reset_optimizer_state()
            init_uninited_vars(list(self._dev_ls_var.values()))
            return tf.group(*ops, name='TrainingOp')

    # Reset internal state of the underlying optimizer.
    def reset_optimizer_state(self):
        run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()])

    # Get or create variable representing log2 of the current dynamic loss scaling factor.
    def get_loss_scaling_var(self, device):
        if not self.use_loss_scaling:
            return None
        if device not in self._dev_ls_var:
            with absolute_name_scope(self.scope + '/LossScalingVars'), tf.control_dependencies(None):
                self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name='loss_scaling_var')
        return self._dev_ls_var[device]

    # Apply dynamic loss scaling for the given expression.
    def apply_loss_scaling(self, value):
        assert is_tf_expression(value)
        if not self.use_loss_scaling:
            return value
        return value * exp2(self.get_loss_scaling_var(value.device))

    # Undo the effect of dynamic loss scaling for the given expression.
    def undo_loss_scaling(self, value):
        assert is_tf_expression(value)
        if not self.use_loss_scaling:
            return value
        return value * exp2(-self.get_loss_scaling_var(value.device))
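
# Example (illustrative sketch, not part of the original API): single-GPU usage of the
# Optimizer wrapper with a toy loss. The variable, loss, and name below are placeholders;
# a real setup calls register_gradients() once per GPU before apply_updates(). Never
# called by this module.
def _example_optimizer_usage():
    init_tf()
    w = tf.Variable(tf.ones([4]), name='example_weights')
    loss = tf.reduce_sum(tf.square(w)) # toy loss to minimize
    opt = Optimizer(name='DemoTrain', learning_rate=0.01)
    opt.register_gradients(loss, [w]) # once per GPU
    train_op = opt.apply_updates() # builds the combined training op
    init_uninited_vars() # initialize w and any remaining helper variables
    for _ in range(10):
        run(train_op)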

#----------------------------------------------------------------------------
# Generic network abstraction.
#
# Acts as a convenience wrapper for a parameterized network construction
# function, providing several utility methods and convenient access to
# the inputs/outputs/weights.
#
# Network objects can be safely pickled and unpickled for long-term
# archival purposes. The pickling works reliably as long as the underlying
# network construction function is defined in a standalone Python module
# that has no side effects or application-specific imports.
# (See _example_network_usage() after the class for a usage sketch.)

network_import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import.
_network_import_modules = [] # Temporary modules created during pickle import.

class Network:
    def __init__(self,
        name=None, # Network name. Used to select TensorFlow name and variable scopes.
        func=None, # Fully qualified name of the underlying network construction function.
        **static_kwargs): # Keyword arguments to be passed in to the network construction function.

        self._init_fields()
        self.name = name
        self.static_kwargs = dict(static_kwargs)

        # Init build func.
        module, self._build_func_name = import_module(func)
        self._build_module_src = inspect.getsource(module)
        self._build_func = find_obj_in_module(module, self._build_func_name)

        # Init graph.
        self._init_graph()
        self.reset_vars()

    def _init_fields(self):
        self.name = None # User-specified name, defaults to build func name if None.
        self.scope = None # Unique TF graph scope, derived from the user-specified name.
        self.static_kwargs = dict() # Arguments passed to the user-supplied build func.
        self.num_inputs = 0 # Number of input tensors.
        self.num_outputs = 0 # Number of output tensors.
        self.input_shapes = [[]] # Input tensor shapes (NC or NCHW), including minibatch dimension.
        self.output_shapes = [[]] # Output tensor shapes (NC or NCHW), including minibatch dimension.
        self.input_shape = [] # Short-hand for input_shapes[0].
        self.output_shape = [] # Short-hand for output_shapes[0].
        self.input_templates = [] # Input placeholders in the template graph.
        self.output_templates = [] # Output tensors in the template graph.
        self.input_names = [] # Name string for each input.
        self.output_names = [] # Name string for each output.
        self.vars = OrderedDict() # All variables (localname => var).
        self.trainables = OrderedDict() # Trainable variables (localname => var).
        self._build_func = None # User-supplied build function that constructs the network.
        self._build_func_name = None # Name of the build function.
        self._build_module_src = None # Full source code of the module containing the build function.
        self._run_cache = dict() # Cached graph data for Network.run().

    def _init_graph(self):
        # Collect inputs.
        self.input_names = []
        for param in inspect.signature(self._build_func).parameters.values():
            if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
                self.input_names.append(param.name)
        self.num_inputs = len(self.input_names)
        assert self.num_inputs >= 1

        # Choose name and scope.
        if self.name is None:
            self.name = self._build_func_name
        self.scope = tf.get_default_graph().unique_name(self.name.replace('/', '_'), mark_as_used=False)

        # Build template graph.
        with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
            assert tf.get_variable_scope().name == self.scope
            with absolute_name_scope(self.scope): # ignore surrounding name_scope
                with tf.control_dependencies(None): # ignore surrounding control_dependencies
                    self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                    out_expr = self._build_func(*self.input_templates, is_template_graph=True, **self.static_kwargs)

        # Collect outputs.
        assert is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        self.output_templates = [out_expr] if is_tf_expression(out_expr) else list(out_expr)
        self.output_names = [t.name.split('/')[-1].split(':')[0] for t in self.output_templates]
        self.num_outputs = len(self.output_templates)
        assert self.num_outputs >= 1

        # Populate remaining fields.
        self.input_shapes = [shape_to_list(t.shape) for t in self.input_templates]
        self.output_shapes = [shape_to_list(t.shape) for t in self.output_templates]
        self.input_shape = self.input_shapes[0]
        self.output_shape = self.output_shapes[0]
        self.vars = OrderedDict([(self.get_var_localname(var), var) for var in tf.global_variables(self.scope + '/')])
        self.trainables = OrderedDict([(self.get_var_localname(var), var) for var in tf.trainable_variables(self.scope + '/')])

    # Run initializers for all variables defined by this network.
    def reset_vars(self):
        run([var.initializer for var in self.vars.values()])

    # Run initializers for all trainable variables defined by this network.
    def reset_trainables(self):
        run([var.initializer for var in self.trainables.values()])

    # Get TensorFlow expression(s) for the output(s) of this network, given the inputs.
    def get_output_for(self, *in_expr, return_as_list=False, **dynamic_kwargs):
        assert len(in_expr) == self.num_inputs
        all_kwargs = dict(self.static_kwargs)
        all_kwargs.update(dynamic_kwargs)
        with tf.variable_scope(self.scope, reuse=True):
            assert tf.get_variable_scope().name == self.scope
            named_inputs = [tf.identity(expr, name=name) for expr, name in zip(in_expr, self.input_names)]
            out_expr = self._build_func(*named_inputs, **all_kwargs)
        assert is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        if return_as_list:
            out_expr = [out_expr] if is_tf_expression(out_expr) else list(out_expr)
        return out_expr

    # Get the local name of a given variable, excluding any surrounding name scopes.
    def get_var_localname(self, var_or_globalname):
        assert is_tf_expression(var_or_globalname) or isinstance(var_or_globalname, str)
        globalname = var_or_globalname if isinstance(var_or_globalname, str) else var_or_globalname.name
        assert globalname.startswith(self.scope + '/')
        localname = globalname[len(self.scope) + 1:]
        localname = localname.split(':')[0]
        return localname

    # Find variable by local or global name.
    def find_var(self, var_or_localname):
        assert is_tf_expression(var_or_localname) or isinstance(var_or_localname, str)
        return self.vars[var_or_localname] if isinstance(var_or_localname, str) else var_or_localname

    # Get the value of a given variable as NumPy array.
    # Note: This method is very inefficient -- prefer to use tfutil.run(list_of_vars) whenever possible.
    def get_var(self, var_or_localname):
        return self.find_var(var_or_localname).eval()

    # Set the value of a given variable based on the given NumPy array.
    # Note: This method is very inefficient -- prefer to use tfutil.set_vars() whenever possible.
    def set_var(self, var_or_localname, new_value):
        return set_vars({self.find_var(var_or_localname): new_value})

    # Pickle export.
    def __getstate__(self):
        return {
            'version': 2,
            'name': self.name,
            'static_kwargs': self.static_kwargs,
            'build_module_src': self._build_module_src,
            'build_func_name': self._build_func_name,
            'variables': list(zip(self.vars.keys(), run(list(self.vars.values()))))}

    # Pickle import.
    def __setstate__(self, state):
        self._init_fields()

        # Execute custom import handlers.
        for handler in network_import_handlers:
            state = handler(state)

        # Set basic fields.
        assert state['version'] == 2
        self.name = state['name']
        self.static_kwargs = state['static_kwargs']
        self._build_module_src = state['build_module_src']
        self._build_func_name = state['build_func_name']

        # Parse imported module.
        module = imp.new_module('_tfutil_network_import_module_%d' % len(_network_import_modules))
        exec(self._build_module_src, module.__dict__)
        self._build_func = find_obj_in_module(module, self._build_func_name)
        _network_import_modules.append(module) # avoid gc

        # Init graph.
        self._init_graph()
        self.reset_vars()
        set_vars({self.find_var(name): value for name, value in state['variables']})

    # Create a clone of this network with its own copy of the variables.
    def clone(self, name=None):
        net = object.__new__(Network)
        net._init_fields()
        net.name = name if name is not None else self.name
        net.static_kwargs = dict(self.static_kwargs)
        net._build_module_src = self._build_module_src
        net._build_func_name = self._build_func_name
        net._build_func = self._build_func
        net._init_graph()
        net.copy_vars_from(self)
        return net

    # Copy the values of all variables from the given network.
    def copy_vars_from(self, src_net):
        assert isinstance(src_net, Network)
        name_to_value = run({name: src_net.find_var(name) for name in self.vars.keys()})
        set_vars({self.find_var(name): value for name, value in name_to_value.items()})

    # Copy the values of all trainable variables from the given network.
    def copy_trainables_from(self, src_net):
        assert isinstance(src_net, Network)
        name_to_value = run({name: src_net.find_var(name) for name in self.trainables.keys()})
        set_vars({self.find_var(name): value for name, value in name_to_value.items()})

    # Create new network with the given parameters, and copy all variables from this network.
    def convert(self, name=None, func=None, **static_kwargs):
        net = Network(name, func, **static_kwargs)
        net.copy_vars_from(self)
        return net

    # Construct a TensorFlow op that updates the variables of this network
    # to be slightly closer to those of the given network.
    def setup_as_moving_average_of(self, src_net, beta=0.99, beta_nontrainable=0.0):
        assert isinstance(src_net, Network)
        with absolute_name_scope(self.scope):
            with tf.name_scope('MovingAvg'):
                ops = []
                for name, var in self.vars.items():
                    if name in src_net.vars:
                        cur_beta = beta if name in self.trainables else beta_nontrainable
                        new_value = lerp(src_net.vars[name], var, cur_beta)
                        ops.append(var.assign(new_value))
                return tf.group(*ops)

    # Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
    def run(self, *in_arrays,
        return_as_list = False, # True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
        print_progress = False, # Print progress to the console? Useful for very large input arrays.
        minibatch_size = None, # Maximum minibatch size to use, None = disable batching.
        num_gpus = 1, # Number of GPUs to use.
        out_mul = 1.0, # Multiplicative constant to apply to the output(s).
        out_add = 0.0, # Additive constant to apply to the output(s).
        out_shrink = 1, # Shrink the spatial dimensions of the output(s) by the given factor.
        out_dtype = None, # Convert the output to the specified data type.
        **dynamic_kwargs): # Additional keyword arguments to pass into the network construction function.

        assert len(in_arrays) == self.num_inputs
        num_items = in_arrays[0].shape[0]
        if minibatch_size is None:
            minibatch_size = num_items
        key = str([list(sorted(dynamic_kwargs.items())), num_gpus, out_mul, out_add, out_shrink, out_dtype])

        # Build graph.
        if key not in self._run_cache:
            with absolute_name_scope(self.scope + '/Run'), tf.control_dependencies(None):
                in_split = list(zip(*[tf.split(x, num_gpus) for x in self.input_templates]))
                out_split = []
                for gpu in range(num_gpus):
                    with tf.device('/gpu:%d' % gpu):
                        out_expr = self.get_output_for(*in_split[gpu], return_as_list=True, **dynamic_kwargs)
                        if out_mul != 1.0:
                            out_expr = [x * out_mul for x in out_expr]
                        if out_add != 0.0:
                            out_expr = [x + out_add for x in out_expr]
                        if out_shrink > 1:
                            ksize = [1, 1, out_shrink, out_shrink]
                            out_expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding='VALID', data_format='NCHW') for x in out_expr]
                        if out_dtype is not None:
                            if tf.as_dtype(out_dtype).is_integer:
                                out_expr = [tf.round(x) for x in out_expr]
                            out_expr = [tf.saturate_cast(x, out_dtype) for x in out_expr]
                        out_split.append(out_expr)
                self._run_cache[key] = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]

        # Run minibatches.
        out_expr = self._run_cache[key]
        out_arrays = [np.empty([num_items] + shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr]
        for mb_begin in range(0, num_items, minibatch_size):
            if print_progress:
                print('\r%d / %d' % (mb_begin, num_items), end='')
            mb_end = min(mb_begin + minibatch_size, num_items)
            mb_in = [src[mb_begin : mb_end] for src in in_arrays]
            mb_out = tf.get_default_session().run(out_expr, dict(zip(self.input_templates, mb_in)))
            for dst, src in zip(out_arrays, mb_out):
                dst[mb_begin : mb_end] = src

        # Done.
        if print_progress:
            print('\r%d / %d' % (num_items, num_items))
        if not return_as_list:
            out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
        return out_arrays

    # Returns a list of (name, output_expr, trainable_vars) tuples corresponding to
    # individual layers of the network. Mainly intended to be used for reporting.
    def list_layers(self):
        patterns_to_ignore = ['/Setter', '/new_value', '/Shape', '/strided_slice', '/Cast', '/concat']
        all_ops = tf.get_default_graph().get_operations()
        all_ops = [op for op in all_ops if not any(p in op.name for p in patterns_to_ignore)]
        layers = []

        def recurse(scope, parent_ops, level):
            prefix = scope + '/'
            ops = [op for op in parent_ops if op.name == scope or op.name.startswith(prefix)]

            # Does not contain leaf nodes => expand immediate children.
            if level == 0 or all('/' in op.name[len(prefix):] for op in ops):
                visited = set()
                for op in ops:
                    suffix = op.name[len(prefix):]
                    if '/' in suffix:
                        suffix = suffix[:suffix.index('/')]
                    if suffix not in visited:
                        recurse(prefix + suffix, ops, level + 1)
                        visited.add(suffix)

            # Otherwise => interpret as a layer.
            else:
                layer_name = scope[len(self.scope)+1:]
                layer_output = ops[-1].outputs[0]
                layer_trainables = [op.outputs[0] for op in ops if op.type.startswith('Variable') and self.get_var_localname(op.name) in self.trainables]
                layers.append((layer_name, layer_output, layer_trainables))

        recurse(self.scope, all_ops, 0)
        return layers

    # Print a summary table of the network structure.
    def print_layers(self, title=None, hide_layers_with_no_params=False):
        if title is None: title = self.name
        print()
        print('%-28s%-12s%-24s%-24s' % (title, 'Params', 'OutputShape', 'WeightShape'))
        print('%-28s%-12s%-24s%-24s' % (('---',) * 4))

        total_params = 0
        for layer_name, layer_output, layer_trainables in self.list_layers():
            weights = [var for var in layer_trainables if var.name.endswith('/weight:0')]
            num_params = sum(np.prod(shape_to_list(var.shape)) for var in layer_trainables)
            total_params += num_params
            if hide_layers_with_no_params and num_params == 0:
                continue

            print('%-28s%-12s%-24s%-24s' % (
                layer_name,
                num_params if num_params else '-',
                layer_output.shape,
                weights[0].shape if len(weights) == 1 else '-'))

        print('%-28s%-12s%-24s%-24s' % (('---',) * 4))
        print('%-28s%-12s%-24s%-24s' % ('Total', total_params, '', ''))
        print()

    # Construct summary ops to include histograms of all trainable parameters in TensorBoard.
    def setup_weight_histograms(self, title=None):
        if title is None: title = self.name
        with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
            for localname, var in self.trainables.items():
                if '/' in localname:
                    p = localname.split('/')
                    name = title + '_' + p[-1] + '/' + '_'.join(p[:-1])
                else:
                    name = title + '_toplevel/' + localname
                tf.summary.histogram(name, var)
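
# Example (illustrative sketch, not part of the original API): constructing and using a
# Network. 'demo_networks.simple_net' is a hypothetical fully qualified build-func name
# and 128 an assumed latent size; the real codebase passes names such as 'networks.G_paper'.
# Never called by this module.
def _example_network_usage():
    init_tf()
    G = Network('Demo', func='demo_networks.simple_net', latent_size=128) # static kwargs go to the build func
    G.print_layers()
    latents = np.random.randn(16, 128).astype(np.float32) # assumed single input of shape [minibatch, 128]
    images = G.run(latents, minibatch_size=8) # NumPy in, NumPy out
    Gs = G.clone('DemoSmoothed') # independent copy of the variables
    run(Gs.setup_as_moving_average_of(G, beta=0.99)) # nudge Gs towards G
    return images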

#----------------------------------------------------------------------------