# Source code for tensorforce.core.optimizers.synchronization

# Copyright 2018 Tensorforce Team. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import tensorflow as tf

from tensorforce import util
from tensorforce.core import Module, parameter_modules
from tensorforce.core.optimizers import Optimizer

class Synchronization(Optimizer):
    """
    Synchronization optimizer, which updates variables periodically to the value of a corresponding
    set of source variables (specification key: `synchronization`).

    On a synchronization step, each target variable receives the delta
    `update_weight * (source - target)`; on all other steps the returned deltas are zero.

    Args:
        name (string): Module name
            (<span style="color:#0000C0"><b>internal use</b></span>).
        sync_frequency (parameter, int > 0): Timestep interval between updates which also perform
            a synchronization step (<span style="color:#00C000"><b>default</b></span>: every time).
        update_weight (parameter, 0.0 < float <= 1.0): Update weight
            (<span style="color:#00C000"><b>default</b></span>: 1.0).
        summary_labels ('all' | iter[string]): Labels of summaries to record
            (<span style="color:#00C000"><b>default</b></span>: inherit value of parent module).
    """

    def __init__(self, name, sync_frequency=1, update_weight=1.0, summary_labels=None):
        super().__init__(name=name, summary_labels=summary_labels)

        self.sync_frequency = self.add_module(
            name='sync-frequency', module=sync_frequency, modules=parameter_modules, dtype='long'
        )
        self.update_weight = self.add_module(
            name='update-weight', module=update_weight, modules=parameter_modules, dtype='float'
        )

    def tf_initialize(self):
        super().tf_initialize()

        # Timestep of the most recent synchronization. The initial value -1 means "never
        # synchronized", which tf_step uses to force a synchronization on the first step.
        self.last_sync = self.add_variable(
            name='last-sync', dtype='long', shape=(), is_trainable=False, initializer=-1
        )

    def tf_step(self, variables, source_variables, **kwargs):
        """
        Moves `variables` towards `source_variables` if a synchronization step is due.

        Args:
            variables: Sequence of target variables to update.
            source_variables: Sequence of source variables, matched pairwise with `variables`
                and required to have identical shapes.
            **kwargs: Ignored.

        Returns:
            List with one delta per target variable: `update_weight * (source - target)` if a
            synchronization step was performed (and applied via `apply_step`), zeros otherwise.
        """
        # zip() silently truncates, so check the lengths explicitly in addition to the shapes;
        # otherwise mismatched lists would synchronize only a prefix of the variables.
        assert len(source_variables) == len(variables)
        assert all(
            util.shape(source) == util.shape(target)
            for source, target in zip(source_variables, variables)
        )

        timestep = Module.retrieve_tensor(name='timestep')

        def apply_sync():
            update_weight = self.update_weight.value()
            deltas = [
                update_weight * (source - target)
                for source, target in zip(source_variables, variables)
            ]
            applied = self.apply_step(variables=variables, deltas=deltas)
            last_sync_updated = self.last_sync.assign(value=timestep)

            with tf.control_dependencies(control_inputs=(applied, last_sync_updated)):
                # Trivial identity ops created inside the control-dependency context, so the
                # returned deltas only resolve after the step is applied and last_sync updated.
                return util.fmap(function=util.identity_operation, xs=deltas)

        def no_sync():
            float_dtype = util.tf_dtype(dtype='float')
            return [
                tf.zeros(shape=util.shape(variable), dtype=float_dtype) for variable in variables
            ]

        sync_frequency = self.sync_frequency.value()
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        # Skip when fewer than sync_frequency timesteps have elapsed since the last sync, except
        # when no sync has happened yet (last_sync still holds its initial -1).
        skip_sync = tf.math.less(x=(timestep - self.last_sync), y=sync_frequency)
        skip_sync = tf.math.logical_and(
            x=skip_sync, y=tf.math.greater_equal(x=self.last_sync, y=zero)
        )

        return self.cond(pred=skip_sync, true_fn=no_sync, false_fn=apply_sync)