本文咱们主要来看看ParameterServerStrategy怎么分发核算,也便是ClusterCoordinator怎么运作。这是TF分布式的最终一篇。

本系列其他文章如下:

[翻译] TensorFlow 分布式之论文篇 Large-Scale Machine Learning on Heterogeneous Distribute

[翻译] TensorFlow 分布式之论文篇 “Implementation of Control Flow in TensorFlow”

[源码解析] TensorFlow 分布式环境(1) — 整体架构

[源码解析] TensorFlow 分布式环境(2)—Master 静态逻辑

[源码解析] TensorFlow 分布式环境(3)— Worker 静态逻辑

[源码解析] TensorFlow 分布式环境(4) — WorkerCache

[源码解析] TensorFlow 分布式环境(5) — Session

[源码解析] TensorFlow 分布式环境(6) — Master 动态逻辑

[源码解析] TensorFlow 分布式环境(7) — Worker 动态逻辑

[源码解析] TensorFlow 分布式环境(8) — 通讯机制

[翻译] 运用 TensorFlow 进行分布式练习

[源码解析] TensorFlow 分布式 DistributedStrategy 之基础篇

[源码解析] TensorFlow 之 分布式变量

[源码解析] TensorFlow 分布式之 MirroredStrategy

[源码解析] TensorFlow 分布式之 ParameterServerStrategy V1

[源码解析] TensorFlow 分布式之 ParameterServerStrategy V2

1. 思路

TensorFlow 2 推荐运用一种基于中心和谐的架构来进行参数服务器练习。每个作业者和参数服务器都运转一个 tf.distribution.Server,在此基础上,一个和谐者使命担任在作业者和参数服务器上创立资源,调度功用,并和谐练习。和谐器运用 tf.distribution.experimental.coordinator.ClusterCoordinator 来和谐集群,运用 tf.distribution.experimental.ParameterServerStrategy 来界说参数服务器上的变量和作业者的核算。

ClusterCoordinator 是一个用于组织和和谐长途函数履行的目标。该类用于创立容错(fault-tolerant)资源和调度函数到长途 TensorFlow 服务器。现在该类不支撑独立运用,它应该与旨在与之协作的 tf.distribution 战略一同运用。ClusterCoordinator 类现在只适用于和 tf.distribution.experimental.ParameterServerStrategy 一同作业。

[源码解析] TensorFlow 分布式之 ClusterCoordinator

1.1 运用

在运用 ParameterServerStrategy 界说一切的核算后,用户能够运用 tf.distribution.experimental.coordinator.ClusterCoordinator 类来创立资源并将练习步骤分配给长途作业者。

首要,咱们来创立一个 ClusterCoordinator 目标并传入战略目标。

strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

其次,创一个属于每个作业者(per-worker)的数据集和一个迭代器。在下面代码的 per_worker_dataset_fn 中,主张将 dataset_fn 包裹到 strategy.distribution_datasets_from_function 中,以允许无缝高效的把数据预取(prefetching )到 GPU。

@tf.function
def per_worker_dataset_fn():
  return strategy.distribute_datasets_from_function(dataset_fn)
per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)

最终一步是运用 ClusterCoordinator.schedule 将核算分配给长途作业者。

  • schedule 办法把一个 tf.function 刺进行列,并当即回来一个 future-like 的 RemoteValue 。行列之中的函数将被派发给后台线程中的长途作业者,RemoteValue 将被异步填充成果。
  • 用户能够运用 join 办法( ClusterCoordinator.join )来等候一切被规划(scheduled)的函数履行。
@tf.function
def step_fn(iterator):
	return next(iterator)
num_epoches = 4
steps_per_epoch = 5
for i in range(num_epoches):
  accuracy.reset_states()
  for _ in range(steps_per_epoch):
    coordinator.schedule(step_fn, args=(per_worker_iterator,))
  # Wait at epoch boundaries.
  coordinator.join()
  print ("Finished epoch %d, accuracy is %f." % (i, accuracy.result().numpy()))

下面是怎么得到 RemoteValue 的成果。

loss = coordinator.schedule(step_fn, args=(per_worker_iterator,))
print ("Final loss is %f" % loss.fetch())

用户也能够发动一切的步骤(steps),并在等候完结时做一些作业。

for _ in range(total_steps):
  coordinator.schedule(step_fn, args=(per_worker_iterator,))
while not coordinator.done():
  time.sleep(10)
  # Do something like logging metrics or writing checkpoints.

1.2 问题点

根据前面的代码,咱们总结出来问题点如下:

  • Worker 怎么知道运用哪些设备?
  • 怎么详细履行用户函数?
  • 怎么获取数据?

接下来咱们就测验经过分析代码来回答这些问题。

2. 界说

ClusterCoordinator 的主要思路如下。

  • 和谐者不是练习作业者之一,相反,它担任创立资源,如变量和数据集,调度 “tf.function”,保存查看点等等。
  • 为了使练习作业顺利进行,和谐者差遣 “tf.function” 在长途作业者上履行。
  • 在收到和谐者的恳求后,作业者经过从参数服务器读取变量、履行操作和更新参数服务器上的变量来履行 “tf.function”。
  • 每个作业者只处理来自和谐者的恳求,并与参数服务器进行通讯。而不与集群中的其他作业者直接互动。

[源码解析] TensorFlow 分布式之 ClusterCoordinator

ClusterCoordinator 界说详细如下,咱们能够看到,其主要是装备了 _strategy 成员变量,生成了 _cluster 成员变量。

@tf_export("distribute.experimental.coordinator.ClusterCoordinator", v1=[])
class ClusterCoordinator(object):
  def __new__(cls, strategy):
    #  ClusterCoordinator  is kept as a single instance to a given  Strategy .
    if strategy._cluster_coordinator is None:
      strategy._cluster_coordinator = super(
          ClusterCoordinator, cls).__new__(cls)
    return strategy._cluster_coordinator
  def __init__(self, strategy):
    """Initialization of a  ClusterCoordinator  instance.
    Args:
      strategy: a supported  tf.distribute.Strategy  object. Currently, only
         tf.distribute.experimental.ParameterServerStrategy  is supported.
    Raises:
      ValueError: if the strategy being used is not supported.
    """
    if not getattr(self, "_has_initialized", False):
      if not isinstance(strategy,
                        parameter_server_strategy_v2.ParameterServerStrategyV2):
        raise ValueError(
            "Only  tf.distribute.experimental.ParameterServerStrategy  "
            "is supported to work with "
            " tf.distribute.experimental.coordinator.ClusterCoordinator  "
            "currently.")
      self._strategy = strategy
      self.strategy.extended._used_with_coordinator = True
      self._cluster = Cluster(strategy)
      self._has_initialized = True
  def __del__(self):
    self._cluster.stop()
  @property
  def strategy(self):
    """Returns the  Strategy  associated with the  ClusterCoordinator ."""
    return self._strategy

2.1 Schedule

由 ClusterCoordinator 目标供给的最重要的 API 是 schedule,其会分派 tf.function 到一个作业者,以便异步履行,详细如下:

  • 该办法对错堵塞的,因为它把 fn 刺进行列,并当即回来 tf.distribution.experimental.coordinator.RemoteValue 目标。fn 排队等候稍后履行。
  • 在行列之中排队的函数将被派发给后台线程中的长途作业者来异步履行,他们的 RemoteValue 将被异步赋值。
  • 因为 schedule 不需求分配一个作业者,传递进来的 tf.function 能够在任何可用的作业者上履行。
  • 能够调用 fetch 来等候函数履行完结,并从长途作业者那里获取其输出。另一方面,也能够调用 tf.distribution.experimental.coordinator.ClusterCoordinator.join 来等候一切预订的函数完结。

失利和容错的战略如下:

  • 因为作业者在履行函数的任何时分都或许失利,所以函数有或许被部分履行,可是 tf.distribution.experimental.coordinator.ClusterCoordinator 确保在这些事情中,函数最终将在任何可用的作业者上履行。
  • schedule 确保 fn 至少在作业者上履行一次;如果其对应的作业者在履行过程中失利,因为函数的履行不是原子性的,所以一个函数或许被履行多次。
  • 如果被履行的作业者在完毕之前变得不行用,该函数将在另一个可用的作业者上重试。
  • 如果任何从前组织的函数呈现过错,schedule 将抛出其间任何一个过错,并清除到现在为止搜集的过错。用户能够在回来的 tf.distribution.experimental.coordinator.RemoteValue 上调用 fetch 来查看它们是否现已履行、失利或撤销,如果需求,能够从头组织相应的函数。当 schedule 引发反常时,它确保没有任何函数仍在履行。

Schedule 的详细界说如下,数据迭代器作为参数之一会和 fn 一同被传入。

  def schedule(self, fn, args=None, kwargs=None):
    """Schedules  fn  to be dispatched to a worker for asynchronous execution.
    This method is non-blocking in that it queues the  fn  which will be
    executed later and returns a 
     tf.distribute.experimental.coordinator.RemoteValue  object immediately.
     fetch  can be called on it to wait for the function execution to finish
    and retrieve its output from a remote worker. On the other hand, call
     tf.distribute.experimental.coordinator.ClusterCoordinator.join  to wait for
    all scheduled functions to finish.
     schedule  guarantees that  fn  will be executed on a worker at least once;
    it could be more than once if its corresponding worker fails in the middle
    of its execution. Note that since worker can fail at any point when
    executing the function, it is possible that the function is partially
    executed, but  tf.distribute.experimental.coordinator.ClusterCoordinator 
    guarantees that in those events, the function will eventually be executed on
    any worker that is available.
    If any previously scheduled function raises an error,  schedule  will raise
    any one of those errors, and clear the errors collected so far. What happens
    here, some of the previously scheduled functions may have not been executed.
    User can call  fetch  on the returned
     tf.distribute.experimental.coordinator.RemoteValue  to inspect if they have
    executed, failed, or cancelled, and reschedule the corresponding function if
    needed.
    When  schedule  raises, it guarantees that there is no function that is
    still being executed.
    At this time, there is no support of worker assignment for function
    execution, or priority of the workers.
     args  and  kwargs  are the arguments passed into  fn , when  fn  is
    executed on a worker. They can be
     tf.distribute.experimental.coordinator.PerWorkerValues  and in this case,
    the argument will be substituted with the corresponding component on the
    target worker. Arguments that are not
     tf.distribute.experimental.coordinator.PerWorkerValues  will be passed into
     fn  as-is. Currently,  tf.distribute.experimental.coordinator.RemoteValue 
    is not supported to be input  args  or  kwargs .
    Args:
      fn: A  tf.function ; the function to be dispatched to a worker for
        execution asynchronously. Regular python funtion is not supported to be
        scheduled.
      args: Positional arguments for  fn .
      kwargs: Keyword arguments for  fn .
    Returns:
      A  tf.distribute.experimental.coordinator.RemoteValue  object that
      represents the output of the function scheduled.
    Raises:
      Exception: one of the exceptions caught by the coordinator from any
        previously scheduled function, since the last time an error was thrown
        or since the beginning of the program.
    """
    if not isinstance(fn,
                      (def_function.Function, tf_function.ConcreteFunction)):
      raise TypeError(
          " tf.distribute.experimental.coordinator.ClusterCoordinator.schedule "
          " only accepts a  tf.function  or a concrete function.")
    # Slot variables are usually created during function tracing time; thus
    #  schedule  needs to be called within the  strategy.scope() .
    with self.strategy.scope():
      self.strategy.extended._being_scheduled = True  
      remote_value = self._cluster.schedule(fn, args=args, kwargs=kwargs)
      self.strategy.extended._being_scheduled = False  
      return remote_value

2.2 Join

Join 办法的作用是堵塞直到一切预订的函数都履行完毕,其详细特色如下:

  • 如果任何从前组织的函数产生过错,join 将因为抛出一个过错而失利,并清除到现在为止搜集的过错。如果产生这种状况,一些从前组织的函数或许没有被履行。
  • 用户能够对回来的 tf.distribution.experimental.coordinator.RemoteValue 调用 fetch 来查看它们是否现已履行、失利或撤销了。
  • 如果一些现已撤销的函数需求从头组织,用户应该再次调用 schedule 。
  • 当 join 回来或抛出反常时,它确保没有任何函数仍在履行。

  def join(self):
    """Blocks until all the scheduled functions have finished execution.
    If any previously scheduled function raises an error,  join  will fail by
    raising any one of those errors, and clear the errors collected so far. If
    this happens, some of the previously scheduled functions may have not been
    executed. Users can call  fetch  on the returned
     tf.distribute.experimental.coordinator.RemoteValue  to inspect if they have
    executed, failed, or cancelled. If some that have been cancelled need to be
    rescheduled, users should call  schedule  with the function again.
    When  join  returns or raises, it guarantees that there is no function that
    is still being executed.
    Raises:
      Exception: one of the exceptions caught by the coordinator by any
        previously scheduled function since the last time an error was thrown or
        since the beginning of the program.
    """
    self._cluster.join()

2.3 Done

Done 办法回来一切分发的函数是否现已履行完毕。如果任何从前分发的函数引发过错,done’将会失利。

  def done(self):
    """Returns whether all the scheduled functions have finished execution.
    If any previously scheduled function raises an error,  done  will fail by
    raising any one of those errors.
    When  done  returns True or raises, it guarantees that there is no function
    that is still being executed.
    Returns:
      Whether all the scheduled functions have finished execution.
    Raises:
      Exception: one of the exceptions caught by the coordinator by any
        previously scheduled function since the last time an error was thrown or
        since the beginning of the program.
    """
    return self._cluster.done()

2.4 Fetch

Fetch 会获取 remote values 的成果。

  def fetch(self, val):
    """Blocking call to fetch results from the remote values.
    This is a wrapper around
     tf.distribute.experimental.coordinator.RemoteValue.fetch  for a
     RemoteValue  structure; it returns the execution results of
     RemoteValue s. If not ready, wait for them while blocking the caller.
    Example:
    ```python
    strategy = ...
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy)
    def dataset_fn():
      return tf.data.Dataset.from_tensor_slices([1, 1, 1])
    with strategy.scope():
      v = tf.Variable(initial_value=0)
    @tf.function
    def worker_fn(iterator):
      def replica_fn(x):
        v.assign_add(x)
        return v.read_value()
      return strategy.run(replica_fn, args=(next(iterator),))
    distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
    distributed_iterator = iter(distributed_dataset)
    result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
    assert coordinator.fetch(result) == 1
    ```
    Args:
      val: The value to fetch the results from. If this is structure of
         tf.distribute.experimental.coordinator.RemoteValue ,  fetch()  will be
        called on the individual
         tf.distribute.experimental.coordinator.RemoteValue  to get the result.
    Returns:
      If  val  is a  tf.distribute.experimental.coordinator.RemoteValue  or a
      structure of  tf.distribute.experimental.coordinator.RemoteValue s,
      return the fetched  tf.distribute.experimental.coordinator.RemoteValue 
      values immediately if they are available, or block the call until they are
      available, and return the fetched
       tf.distribute.experimental.coordinator.RemoteValue  values with the same
      structure. If  val  is other types, return it as-is.
    """
    def _maybe_fetch(val):
      if isinstance(val, RemoteValue):
        return val.fetch()
      else:
        return val
    return nest.map_structure(_maybe_fetch, val)

3. 数据

除了调度长途函数,ClusterCoordinator 还帮助在一切作业者上创立数据集,并当一个作业者从失利中康复时重建这些数据集。用户能够经过调用 dataset_fn 来在worker设备上创立数据集。运用比如如下:

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy=strategy)
@tf.function
def worker_fn(iterator):
  return next(iterator)
def per_worker_dataset_fn():
  return strategy.distribute_datasets_from_function(
      lambda x: tf.data.Dataset.from_tensor_slices([3] * 3))
per_worker_dataset = coordinator.create_per_worker_dataset(
    per_worker_dataset_fn)
per_worker_iter = iter(per_worker_dataset)
remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
assert remote_value.fetch() == 3

3.1 树立数据集

上面代码运用了 create_per_worker_dataset 在worker上创立数据集,这些数据集由 dataset_fn 生成,并回来一个代表这些数据集的调集。在这样的数据集调集上调用 iter 会回来一个 tf.distribution.experimental.coordinator.PerWorkerValues,它是一个迭代器的调集,其间的迭代器现已被放置在各个作业者上。

需求留意,不支撑在迭代器的 “PerWorkerValues”上直接调用 “next”。该迭代器应该是作为一个参数传递给 tf.distribution.experimental.coordinator.ClusterCoordinator.schedule 。当计划的函数即将被作业者履行时,该函数将收到与该作业者相对应的单个迭代器。该函数能够对该迭代器调用 next 办法。

现在,schedule 办法假定作业者都是相同的,因而假定不同作业者上的数据集是一样的,除非它们包括 dataset.shuffle 操作,而且没有设置随机种子,在这种状况下,它们的洗牌方式会不同。正因为如此,主张将数据集无限地重复,并组织有限的步骤,而不是依赖于数据集的 OutOfRangeError 来完毕。

  def create_per_worker_dataset(self, dataset_fn):
    """Create dataset on workers by calling  dataset_fn  on worker devices.
    This creates the given dataset generated by dataset_fn on workers
    and returns an object that represents the collection of those individual
    datasets. Calling  iter  on such collection of datasets returns a
     tf.distribute.experimental.coordinator.PerWorkerValues , which is a
    collection of iterators, where the iterators have been placed on respective
    workers.
    Calling  next  on a  PerWorkerValues  of iterator is unsupported. The
    iterator is meant to be passed as an argument into
     tf.distribute.experimental.coordinator.ClusterCoordinator.schedule . When
    the scheduled function is about to be executed by a worker, the
    function will receive the individual iterator that corresponds to the
    worker. The  next  method can be called on an iterator inside a
    scheduled function when the iterator is an input of the function.
    Currently the  schedule  method assumes workers are all the same and thus
    assumes the datasets on different workers are the same, except they may be
    shuffled differently if they contain a  dataset.shuffle  operation and a
    random seed is not set. Because of this, we also recommend the datasets to
    be repeated indefinitely and schedule a finite number of steps instead of
    relying on the  OutOfRangeError  from a dataset.
    Args:
      dataset_fn: The dataset function that returns a dataset. This is to be
        executed on the workers.
    Returns:
      An object that represents the collection of those individual
      datasets.  iter  is expected to be called on this object that returns
      a  tf.distribute.experimental.coordinator.PerWorkerValues  of the
      iterators (that are on the workers).
    """
    return values_lib.get_per_worker_dataset(dataset_fn, self)

get_per_worker_dataset 则回来 PerWorkerDatasetFromDataset 或许 PerWorkerDatasetFromDatasetFunction。

def get_per_worker_dataset(dataset_or_dataset_fn, coordinator):
  if callable(dataset_or_dataset_fn):
    return PerWorkerDatasetFromDatasetFunction(dataset_or_dataset_fn,
                                               coordinator)
  else:
    return PerWorkerDatasetFromDataset(dataset_or_dataset_fn, coordinator)

3.2 PerWorkerDistributedDataset

PerWorkerDistributedDataset 代表了从一个数据集树立的作业者运用的分布式数据集。

class PerWorkerDatasetFromDataset(PerWorkerDatasetFromDatasetFunction):
  """Represents worker-distributed datasets created from a dataset."""
  def __init__(self, dataset, coordinator):
    """Makes an iterable from datasets created by the given dataset.
    It creates a dataset_fn which deserializes a dataset from a graph under the
    hood.
    Args:
      dataset: A tf.data.Dataset, a DistributedDataset or a
        DistributedDatasetsFromFunction
      coordinator: a  ClusterCoordinator  object, used to create dataset
        resources.
    """
    if isinstance(dataset, input_lib.DistributedDataset):
      original_dataset = dataset._original_dataset
      serialized = serialize_dataset_to_graph(original_dataset)
      def dataset_fn():
        deserialized = deserialize_dataset_from_graph(
            serialized, original_dataset.element_spec)
        dataset.build(dataset_to_replace=deserialized)
        return dataset
    elif isinstance(dataset, input_lib.DistributedDatasetsFromFunction):
      def dataset_fn():
        dataset.build()
        return dataset
    elif isinstance(dataset, dataset_ops.Dataset):
      serialized = serialize_dataset_to_graph(dataset)
      def dataset_fn():
        return deserialize_dataset_from_graph(serialized, dataset.element_spec)
    else:
      raise ValueError("Unexpected dataset type!")
    super(PerWorkerDatasetFromDataset, self).__init__(dataset_fn, coordinator)

3.3 PerWorkerDatasetFromDatasetFunction

PerWorkerDistributedDataset 代表了从一个数据集办法树立的作业者运用的分布式数据集。

iter 之中有:

  • 调用 _create_per_worker_iterator 得到一个 iter(dataset)。

  • 调用 self._coordinator._create_per_worker_resources 为每作业者生成一个 iterator。

  • 最终回来一个 PerWorkerDistributedIterator。

class PerWorkerDatasetFromDatasetFunction(object):
  """Represents worker-distributed datasets created from dataset function."""
  def __init__(self, dataset_fn, coordinator):
    """Makes an iterable from datasets created by the given function.
    Args:
      dataset_fn: A function that returns a  Dataset .
      coordinator: a  ClusterCoordinator  object, used to create dataset
        resources.
    """
    def disallow_variable_creation(next_creator, **kwargs):
      raise ValueError("Creating variables in  dataset_fn  is not allowed.")
    if isinstance(dataset_fn, def_function.Function):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = dataset_fn.get_concrete_function()
    elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = def_function.function(dataset_fn).get_concrete_function()
    self._dataset_fn = dataset_fn
    self._coordinator = coordinator
    self._element_spec = None
  def __iter__(self):
    # We would like users to create iterators outside  tf.function s so that we
    # can track them.
    if (not context.executing_eagerly() or
        ops.get_default_graph().building_function):
      raise RuntimeError(
          "__iter__() is not supported inside of tf.function or in graph mode.")
    def _create_per_worker_iterator():
      dataset = self._dataset_fn()
      return iter(dataset)
    # If PerWorkerDatasetFromDatasetFunction.__iter__ is called multiple
    # times, for the same object it should only create and register resource
    # once. Using object id to distinguish different iterator resources.
    per_worker_iterator = self._coordinator._create_per_worker_resources(
        _create_per_worker_iterator)
    # Setting type_spec of each RemoteValue so that functions taking these
    # RemoteValues as inputs can be traced.
    for iterator_remote_value in per_worker_iterator._values:
      iterator_remote_value._type_spec = (
          input_lib.get_iterator_spec_from_dataset(
              self._coordinator.strategy, self._dataset_fn.structured_outputs))
    return PerWorkerDistributedIterator(per_worker_iterator._values)
  @property
  def element_spec(self):
    """The type specification of an element of this dataset.
    This property is subject to change without notice.
    """
    return self._dataset_fn.structured_outputs.element_spec

3.4 _create_per_worker_resources

_create_per_worker_resources 会调用各个作业者的办法来让每个作业者得到数据。

def _create_per_worker_resources(self, fn, args=None, kwargs=None):
  """Synchronously create resources on the workers.
  The resources are represented by
   tf.distribute.experimental.coordinator.RemoteValue s.
  Args:
    fn: The function to be dispatched to all workers for execution
      asynchronously.
    args: Positional arguments for  fn .
    kwargs: Keyword arguments for  fn .
  Returns:
    A  tf.distribute.experimental.coordinator.PerWorkerValues  object, which
    wraps a tuple of  tf.distribute.experimental.coordinator.RemoteValue 
    objects.
  """
  results = []
  for w in self._cluster.workers:
    results.append(w.create_resource(fn, args=args, kwargs=kwargs))  
  return PerWorkerValues(tuple(results))

3.5 PerWorkerValues

PerWorkerValues 是一个容纳 value 列表的容器,每个作业者对应一个 value。Tf.distribution.experimental.coordinator.PerWorkerValues 包括一个值的调集,其间每个值都坐落其相应的作业者上,当被用作 tf.distribution.experimental.coordinator.ClusterCoordinator.schedule() 的 args 或 kwargs 时,某一个作业者的特定值将被传递到该作业者上履行的函数中。

创立 tf.distribution.experimental.coordinator.PerWorkerValues 目标的仅有途径是经过在 ClusterCoordinator.create_per_worker_dataset 回来的分布式数据集实例上调用 iter 。现在还不支撑创立自界说 tf.distribution.experimental.coordinator.PerWorkerValues 的机制。

@tf_export("distribute.experimental.coordinator.PerWorkerValues", v1=[])
class PerWorkerValues(composite_tensor.CompositeTensor):
  """A container that holds a list of values, one value per worker.
   tf.distribute.experimental.coordinator.PerWorkerValues  contains a collection
  of values, where each of the values is located on its corresponding worker,
  and upon being used as one of the  args  or  kwargs  of
   tf.distribute.experimental.coordinator.ClusterCoordinator.schedule() , the
  value specific to a worker will be passed into the function being executed at
  that corresponding worker.
  Currently, the only supported path to create an object of
   tf.distribute.experimental.coordinator.PerWorkerValues  is through calling
   iter  on a  ClusterCoordinator.create_per_worker_dataset -returned
  distributed dataset instance. The mechanism to create a custom
   tf.distribute.experimental.coordinator.PerWorkerValues  is not yet supported.
  """
  def __init__(self, values):
    for v in values:
      if not isinstance(v, RemoteValue):
        raise AssertionError(
            " PerWorkerValues  should only take  RemoteValue s.")
    self._values = tuple(values)
  @property
  def _type_spec(self):
    return PerWorkerValuesTypeSpec(
        self._values[0]._type_spec,  
        type(self))

获取数据的逻辑如下:

[源码解析] TensorFlow 分布式之 ClusterCoordinator

4. Cluster

Cluster 才是事务履行者。

4.1 界说

Cluster 是一个作业者集群。在初始化办法之中,会做如下处理:

  • 设置怎么疏忽参数服务器暂时过错。
  • 设定作业者的设备姓名。
  • 生成一系列作业者。

这里要留意的是怎么疏忽因为作业者瞬时衔接过错而陈述的毛病。

  • 作业者和参数服务器之间的瞬时衔接问题会由作业者转达给和谐者,这将导致和谐者以为存在参数服务器毛病。
  • 瞬时与永久的参数服务器毛病之间的区别是作业者陈述的数量。当这个环境变量设置为正整数 K 时,和谐器疏忽最多 K 个失利陈述,也便是说,只有超越 K 个履行过错,而且这些过错是因为同一个参数服务器实例导致的,咱们才以为参数服务器实例遇到了失利。
class Cluster(object):
  """A cluster with workers.
  We assume all function errors are fatal and based on this assumption our
  error reporting logic is:
  1) Both  schedule  and  join  can raise a non-retryable error which is the
  first error seen by the coordinator from any previously scheduled functions.
  2) When an error is raised, there is no guarantee on how many previously
  scheduled functions have been executed; functions that have not been executed
  will be thrown away and marked as cancelled.
  3) After an error is raised, the internal state of error will be cleared.
  I.e. functions can continue to be scheduled and subsequent calls of  schedule 
  or  join  will not raise the same error again.
  Attributes:
    failure_handler: The failure handler used to handler worker preemption
      failure.
    workers: a list of  Worker  objects in the cluster.
  """
  def __init__(self, strategy):
    """Initializes the cluster instance."""
    self._num_workers = strategy._num_workers
    self._num_ps = strategy._num_ps
    # 怎么疏忽参数服务器暂时过错
    self._transient_ps_failures_threshold = int(
        os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3))
    self._potential_ps_failures_lock = threading.Lock()
    self._potential_ps_failures_count = [0] * self._num_ps
    self._closure_queue = _CoordinatedClosureQueue()
    self.failure_handler = WorkerPreemptionHandler(context.get_server_def(),
                                                   self)
    # 设定 worker 的设备姓名
    worker_device_strings = [
        "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
    ]
    # 生成 Workers
    self.workers = [
        Worker(i, w, self) for i, w in enumerate(worker_device_strings)
    ]

4.2 Schedule

这个类供给的最重要的API是 “schedule”/”join” 这对函数。”schedule” API对错堵塞的,它把一个 “tf.function “刺进行列,并当即回来一个 “RemoteValue”。

  def schedule(self, function, args, kwargs):
    """Schedules  function  to be dispatched to a worker for execution.
    Args:
      function: The function to be dispatched to a worker for execution
        asynchronously.
      args: Positional arguments for  fn .
      kwargs: Keyword arguments for  fn .
    Returns:
      A  RemoteValue  object.
    """
    closure = Closure(
        function,
        self._closure_queue._cancellation_mgr, 
        args=args,
        kwargs=kwargs)
    self._closure_queue.put(closure)
    return closure.output_remote_value
  def join(self):
    """Blocks until all scheduled functions are executed."""
    self._closure_queue.wait()

详细逻辑如下,虚线表示数据集被传入,这里的 Queue 是 from six.moves import queue 引进的 queue.Queue,咱们接下来在_CoordinatedClosureQueue之中会见到。

[源码解析] TensorFlow 分布式之 ClusterCoordinator

或许咱们从官方文档图来看,现在完结的是左面圆圈部分。

[源码解析] TensorFlow 分布式之 ClusterCoordinator

4.3 中止

中止代码如下,详细是调用行列的处理办法。

  def stop(self):
    """Stop worker, worker preemption threads, and the closure queue."""
    self.failure_handler.stop()
    for worker in self.workers:
      worker.stop()
    self._closure_queue.stop()
  def done(self):
    """Returns true if all scheduled functions are executed."""
    return self._closure_queue.done()

5. 使命 Closure

[源码解析] TensorFlow 分布式之 ClusterCoordinator

Closure 的作用是把使命封装起来,而且供给了其他功用。

class Closure(object):
  """Hold a function to be scheduled and its arguments."""
  def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
    if not callable(function):
      raise ValueError("Function passed to  ClusterCoordinator.schedule  must "
                       "be a callable object.")
    self._args = args or ()
    self._kwargs = kwargs or {}
    _disallow_remote_value_as_input(self._args)
    _disallow_remote_value_as_input(self._kwargs)
    if isinstance(function, def_function.Function):
      replica_args = _select_worker_slice(0, self._args)
      replica_kwargs = _select_worker_slice(0, self._kwargs)
      # Note: no need to handle function registration failure since this kind of
      # failure will not raise exceptions as designed in the runtime. The
      # coordinator has to rely on subsequent operations that raise to catch
      # function registration failure.
      # Record the function tracing overhead. Note that we pass in the tracing
      # count of the def_function.Function as a state tracker, so that metrics
      # will only record the time for actual function tracing (i.e., excluding
      # function cache lookups).
      with metric_utils.monitored_timer(
          "function_tracing", state_tracker=function._get_tracing_count):  
        self._concrete_function = function.get_concrete_function(
            *nest.map_structure(_maybe_as_type_spec, replica_args),
            **nest.map_structure(_maybe_as_type_spec, replica_kwargs))
    elif isinstance(function, tf_function.ConcreteFunction):
      self._concrete_function = function
    if hasattr(self, "_concrete_function"):
      # If we have a concrete function, we get to retrieve the output type spec
      # via the structured_output.
      output_type_spec = func_graph.convert_structure_to_signature(
          self._concrete_function.structured_outputs)
      self._function = cancellation_mgr.get_cancelable_function(
          self._concrete_function)
    else:
      # Otherwise (i.e. what is passed in is a regular python function), we have
      # no such information.
      output_type_spec = None
      self._function = function
    self.output_remote_value = RemoteValueImpl(self, output_type_spec)

5.1 履行

Closure 的 execute_on 担任运转,详细是在指定的设备上履行 self._function,便是用户自界说的 function。需求留意的是,with context.executor_scope(worker.executor) 运用了 context。

  def execute_on(self, worker):
    """Executes the closure on the given worker.
    Args:
      worker: a Worker object.
    """
    replica_args = _select_worker_slice(worker.worker_index, self._args)
    replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs)
    e = (
        _maybe_rebuild_remote_values(worker, replica_args) or
        _maybe_rebuild_remote_values(worker, replica_kwargs))
    if e:
      if not isinstance(e, InputError):
        e = InputError(e)
      self.output_remote_value._set_error(e) 
      return
    with ops.device(worker.device_name): # 在指定设备上
      with context.executor_scope(worker.executor): # 经过上下文
        with metric_utils.monitored_timer("closure_execution"):
          output_values = self._function( # 运转用户的参数
              *nest.map_structure(_maybe_get_remote_value, replica_args),
              **nest.map_structure(_maybe_get_remote_value, replica_kwargs))
    self.output_remote_value._set_values(output_values) 

Self._function 是用户自界说的 function,咱们再给出一个办法示例,能够看出来能够运用 strategy.run 把练习办法分发到远端作业者进行练习。

@tf.function
def worker_fn(iterator):
	def replica_fn(inputs):
      batch_data, labels = inputs
      # calculate gradient, applying gradient, metrics update etc.
	strategy.run(replica_fn, args=(next(iterator),))

5.2 撤销

用户能够设置撤销 Closure,便是在回来值之中做下设置。

  def mark_cancelled(self):
    self.output_remote_value._set_error(  
        errors.CancelledError(
            None, None, "The corresponding function is "
            "cancelled. Please reschedule the function."))

5.3 ResourceClosure

ResourceClosure 是派生类,把 Closure 用 RemoteValue 包装起来。实际上运用的都是 ResourceClosure。

class ResourceClosure(Closure):
  def build_output_remote_value(self):
    if self._output_remote_value_ref is None:
      # We need to remember the Closure object in the  RemoteValue  here.
      ret = RemoteValueImpl(self, self._output_type_spec)
      self._output_remote_value_ref = weakref.ref(ret)
      return ret
    else:
      return self._output_remote_value_ref()

6. 行列

_CoordinatedClosureQueue 是使命所在的行列。

6.1 界说

from six.moves import queue
class _CoordinatedClosureQueue(object):
  """Manage a queue of closures, inflight count and errors from execution.
  This class is thread-safe.
  """
  def __init__(self):
    #  self._inflight_closure_count  only tracks the number of inflight closures
    # that are "in generation". Once an error occurs, error generation is
    # incremented and all subsequent arriving closures (from inflight) are
    # considered "out of generation".
    self._inflight_closure_count = 0
    self._queue_lock = threading.Lock()
    # Condition indicating that all pending closures (either queued or inflight)
    # have been processed, failed, or cancelled.
    self._stop_waiting_condition = threading.Condition(self._queue_lock)
    # Condition indicating that an item becomes available in queue (not empty).
    self._closures_queued_condition = threading.Condition(self._queue_lock)
    self._should_process_closures = True
    # Condition indicating that a queue slot becomes available (not full).
    # Note that even with "infinite" queue size, there is still a "practical"
    # size limit for the queue depending on host memory capacity, and thus the
    # queue will eventually become full with a lot of enqueued closures.
    self._queue_free_slot_condition = threading.Condition(self._queue_lock)
    # Condition indicating there is no inflight closures.
    self._no_inflight_closure_condition = threading.Condition(self._queue_lock)
    # Use to cancel in-flight closures.
    self._cancellation_mgr = cancellation.CancellationManager()
    self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
    self._error = None
    # The following is a lock to make sure when  wait  is called and before it
    # returns no  put  can be executed during this period. It is because  wait 
    # won't know what to do with newly put closures. This lock adds an cutoff
    # for  wait  so that closures put into the queue while waiting would not be
    # taken responsible by this  wait .
    #
    # We cannot reuse the  self._queue_lock  since when  wait  waits for a
    # condition, the  self._queue_lock  will be released.
    #
    # We don't use a reader/writer's lock on purpose to reduce the complexity
    # of the code.
    self._put_wait_lock = threading.Lock()

6.2 刺进取出

Put 和 get 办法分别担任刺进和取出。

  def put(self, closure):
    """Put a closure into the queue for later execution.
    If  mark_failed  was called before  put , the error from the first
    invocation of  mark_failed  will be raised.
    Args:
      closure: The  Closure  to put into the queue.
    """
    with self._put_wait_lock, self._queue_lock:
      self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
      self._queue.put(closure, block=False)
      self._raise_if_error()
      self._closures_queued_condition.notify()
  def get(self, timeout=None):
    """Return a closure from the queue to be executed."""
    with self._queue_lock:
      while self._queue.empty() and self._should_process_closures:
        if not self._closures_queued_condition.wait(timeout=timeout):
          return None
      if not self._should_process_closures:
        return None
      closure = self._queue.get(block=False)
      self._queue_free_slot_condition.notify()
      self._inflight_closure_count += 1
      return closure

Put_back 则担任把 closure 从头放回queue。

  def put_back(self, closure):
    """Put the closure back into the queue as it was not properly executed."""
    with self._queue_lock:
      if self._inflight_closure_count < 1:
        raise AssertionError("There is no inflight closures to put_back.")
      if self._error:
        closure.mark_cancelled()
      else:
        self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
        self._queue.put(closure, block=False)
        self._closures_queued_condition.notify()
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notifyAll()

6.3 等候

办法 wait 会等候一切 closures 完毕。

  def wait(self, timeout=None):
    """Wait for all closures to be finished before returning.
    If  mark_failed  was called before or during  wait , the error from the
    first invocation of  mark_failed  will be raised.
    Args:
      timeout: A float specifying a timeout for the wait in seconds.
    Returns:
      True unless the given timeout expired, in which case it returns False.
    """
    with self._put_wait_lock, self._queue_lock:
      while (not self._error and
             (not self._queue.empty() or self._inflight_closure_count > 0)):
        if not self._stop_waiting_condition.wait(timeout=timeout):
          return False
      self._raise_if_error()
      return True

6.4 反常&完毕

Mark_failed 和 done 则是处理完毕和反常的一套组合。

  def mark_failed(self, e):
    """Sets error and unblocks any wait() call."""
    with self._queue_lock:
      # TODO(yuefengz): maybe record all failure and give users more
      # information?
      if self._inflight_closure_count < 1:
        raise AssertionError("There is no inflight closures to mark_failed.")
      if self._error is None:
        self._error = e
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notifyAll()
      self._stop_waiting_condition.notifyAll()
  def done(self):
    """Returns true if the queue is empty and there is no inflight closure.
    If  mark_failed  was called before  done , the error from the first
    invocation of  mark_failed  will be raised.
    """
    with self._queue_lock:
      self._raise_if_error()
      return self._queue.empty() and self._inflight_closure_count == 0

6.5 中止

Stop 和 _cancel_all_closures 担任暂停 closures。


  def stop(self):
    with self._queue_lock:
      self._should_process_closures = False
      self._closures_queued_condition.notifyAll()
  def _cancel_all_closures(self):
    """Clears the queue and sets remaining closures cancelled error.
    This method expects self._queue_lock to be held prior to entry.
    """
    self._cancellation_mgr.start_cancel()
    while self._inflight_closure_count > 0:
      self._no_inflight_closure_condition.wait()
    while True:
      try:
        closure = self._queue.get(block=False)
        self._queue_free_slot_condition.notify()
        closure.mark_cancelled()
      except queue.Empty:
        break
    # The cancellation manager cannot be reused once cancelled. After all
    # closures (queued or inflight) are cleaned up, recreate the cancellation
    # manager with clean state.
    # Note on thread-safety: this is triggered when one of theses
    # ClusterCoordinator APIs are called:  schedule ,  wait , and  done . At the
    # same time, no new closures can be constructed (which reads the
    # _cancellation_mgr to get cancellable functions).
    self._cancellation_mgr = cancellation.CancellationManager()
  def _raise_if_error(self):
    """Raises the error if one exists.
    If an error exists, cancel the closures in queue, raises it, and clear
    the error.
    This method expects self._queue_lock to be held prior to entry.
    """
    if self._error:
      logging.error("Start cancelling closures due to error %r: %s",
                    self._error, self._error)
      self._cancel_all_closures()
      try:
        raise self._error  
      finally:
        self._error = None

7.4 Worker

Worker 是函数的履行者。

7.1 界说

Worker 的界说如下,其发动了一个线程来运转 _process_queue。

class Worker(object):
  """A worker in a cluster.
  Attributes:
    worker_index: The index of the worker in the cluster.
    device_name: The device string of the worker, e.g. "/job:worker/task:1".
    executor: The worker's executor for remote function execution.
    failure_handler: The failure handler used to handler worker preemption
      failure.
  """
  def __init__(self, worker_index, device_name, cluster):
    self.worker_index = worker_index
    self.device_name = device_name
    # 这里会有一个executor
    self.executor = executor.new_executor(enable_async=False)
    self.failure_handler = cluster.failure_handler
    self._cluster = cluster
    self._resource_remote_value_refs = []
    self._should_worker_thread_run = True
    # Worker threads need to start after  Worker 's initialization.
    threading.Thread(target=self._process_queue,
                     name="WorkerClosureProcessingLoop-%d" % self.worker_index,
                     daemon=True).start()

New_executor 会调用 TFE_NewExecutor。

def new_executor(enable_async):
  handle = pywrap_tfe.TFE_NewExecutor(enable_async)
  return Executor(handle)

TFE_NewExecutor 界说在 tensorflow/c/eager/c_api_experimental.cc,其生成了 TFE_Executor。

TFE_Executor* TFE_NewExecutor(bool is_async) {
  return new TFE_Executor(is_async);
}

TFE_Executor 界说如下,Executor类是会话履行器的抽象,在 TF2 之中,也有 EagerExecutor。

struct TFE_Executor {
  explicit TFE_Executor(bool async)
      : owned_executor(new tensorflow::EagerExecutor(async)) {}
  explicit TFE_Executor(tensorflow::EagerExecutor* executor)
      : owned_executor(nullptr), unowned_executor(executor) {}
  tensorflow::EagerExecutor* executor() {
    return owned_executor == nullptr ? unowned_executor : owned_executor.get();
  }
  std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
  tensorflow::EagerExecutor* unowned_executor;
};

7.2 处理

_process_queue 办法会从 queue 之中取出 Closure,然后运转使命。

  def _process_queue(self):
    """Function running in a worker thread to process closure queues."""
    self._maybe_delay()
    while self._should_worker_thread_run:
      closure = self._cluster._closure_queue.get()  
      if not self._should_worker_thread_run or closure is None:
        return
      self._process_closure(closure)
      # To properly stop the worker and preemption threads, it is important that
      #  ClusterCoordinator  object is not held onto so its  __del__  can be
      # called. By removing the reference to the  closure  that has already been
      # processed, we ensure that the  closure  object is released, while
      # getting the next  closure  at above  self._cluster._closure_queue.get() 
      # call.
      del closure

7.2.1 等候

_process_queue 之中首要会调用 _maybe_delay 等候环境变量装备。

  def _maybe_delay(self):
    """Delay if corresponding env vars are set."""
    # If the following two env vars variables are set. Scheduling for workers
    # will start in a staggered manner. Worker i will wait for
    #  TF_COORDINATOR_SCHEDULE_START_DELAY  * i seconds, not exceeding
    #  TF_COORDINATOR_SCHEDULE_START_DELAY_MAX .
    delay_secs = int(os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY", "0"))
    delay_cap = int(
        os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY_MAX", "0"))
    if delay_cap:
      delay_secs = min(delay_secs * self.worker_index, delay_cap)
    if delay_secs > 0:
      logging.info("Worker %d sleeping for %d seconds before running function",
                   self.worker_index, delay_secs)
    time.sleep(delay_secs)

7.2.2 处理使命

_process_queue 之中接着会调用 _process_closure 来运转 closure。

  def _process_closure(self, closure):
    """Runs a closure with preemption handling."""
    try:
      with self._cluster.failure_handler.wait_on_failure(
          on_failure_fn=lambda: self._cluster._closure_queue.put_back(closure),  
          on_recovery_fn=self._set_resources_aborted,
          worker_device_name=self.device_name):
        closure.execute_on(self)
        with metric_utils.monitored_timer("remote_value_fetch"):
          # Copy the remote tensor to local (the coordinator) in case worker
          # becomes unavailable at a later time.
          closure.output_remote_value.get()
        self._cluster._closure_queue.mark_finished()  
    except Exception as e:  
      # Avoid logging the derived cancellation error
      if not isinstance(e, errors.CancelledError):
        logging.error(
            "/job:worker/task:%d encountered the following error when "
            "processing closure: %r:%s", self.worker_index, e, e)
      closure.output_remote_value._set_error(e)  
      self._cluster._closure_queue.mark_failed(e)  

7.3 数据

咱们接下来看看怎么把数据读取放到作业者上运转。前面提到了,在 _create_per_worker_resources 会调用 create_resource,为每一个作业者树立其自己的资源。

  def create_resource(self, function, args=None, kwargs=None):
    """Synchronously creates a per-worker resource represented by a  RemoteValue .
    Args:
      function: the resource function to be run remotely. It should be a
         tf.function , a concrete function or a Python function.
      args: positional arguments to be passed to the function.
      kwargs: keyword arguments to be passed to the function.
    Returns:
      one or several RemoteValue objects depending on the function return
      values.
    """
    # Some notes about the concurrency: currently all the activities related to
    # the same worker such as creating resources, setting resources' aborted
    # status, and executing closures happen on the same thread. This allows us
    # to have simpler logic of concurrency.
    closure = ResourceClosure(
        function,
        self._cluster.closure_queue._cancellation_mgr,  
        args=args,
        kwargs=kwargs)
    resource_remote_value = closure.build_output_remote_value()
    self._register_resource(resource_remote_value)
    # The following is a short-term solution to lazily create resources in
    # parallel.
    resource_remote_value._set_aborted() 
    return resource_remote_value

_register_resource 则会把每个 Worker 的资源注册到 Worker 之上。

def _register_resource(self, resource_remote_value):
  if not isinstance(resource_remote_value, RemoteValue):
    raise ValueError("Resource being registered is not of type "
                     " tf.distribute.experimental.coordinator.RemoteValue .")
  self._resource_remote_value_refs.append(weakref.ref(resource_remote_value))

逻辑如下,虚线表述数据流。用户经过 put 办法向行列之中放入 Closure,Worker 经过 put 办法从行列获取 Closure 履行。

[源码解析] TensorFlow 分布式之 ClusterCoordinator

7.4 中止

Stop 等一系列办法担任中止。

  def stop(self):
    """Ensure the worker thread is closed."""
    self._should_worker_thread_run = False
  def _set_resources_aborted(self):
    for weakref_resource in self._resource_remote_value_refs:
      resource = weakref_resource()
      if resource:
        resource._set_aborted()  # pylint: disable=protected-access
  def _set_dead(self):
    raise NotImplementedError("_set_dead is not implemented.")

7.5 与 Strategy 联络

至此,咱们其实还没有正式和 Strategy 联络起来,咱们再用一个比如来看看,这里会发现,传递给 coordinator 的办法之中,会调用 strategy.run(replica_fn, args=(next(iterator),)),这样就和 strategy 联络起来了。

    strategy = ...
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy)
    def dataset_fn():
      return tf.data.Dataset.from_tensor_slices([1, 1, 1])
    with strategy.scope():
      v = tf.Variable(initial_value=0)
    @tf.function
    def worker_fn(iterator):
      def replica_fn(x):
        v.assign_add(x)
        return v.read_value()
      return strategy.run(replica_fn, args=(next(iterator),)) # 这里正式联络起来
    distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
    distributed_iterator = iter(distributed_dataset)
    result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
    assert coordinator.fetch(result) == 1

8. Failover

8.1 战略

应对失利的整体战略大致如下:

  • 当发现一个作业者失利了,Coordinator 把 function 再次放入行列,然后发给另一个作业者履行,一起发动一个后台线程等候康复,如果康复了,则用资源来重建这个作业者,持续分配作业。

  • 因而,一些作业者的失利并不阻碍集群持续作业,这使得集群之中的实例能够偶然不行用(例如,可抢占或spot 实例)。可是和谐者和参数服务器有必要始终可用,这样集群才干获得发展。

[源码解析] TensorFlow 分布式之 ClusterCoordinator

8.2 作业者失利

当产生作业者失利(failure)时分,详细逻辑如下:

  • ClusterCoordinator 类与 tf.distribution.experimental.ParameterServerStrategy 一同运用时,具有内置的作业者毛病容错功用。也便是说,当一些作业者因为任何原因,和谐器无法联络上它们,这些作业者的练习进度将持续由其余作业者完结。
  • 在作业者康复时,之前供给的数据集函数(关于自界说练习循环,能够是 ClusterCoordinator.create_per_worker_dataset,或许是 tf.keras.utils.experimental.DatasetCreator 用于 Model.fit )将被调用到作业者身上,以从头创立数据集。
  • 当一个失利的作业者康复之后,在运用经过 create_per_worker_dataset 创立的数据被从头树立后,它将被添加到函数履行中。

8.3 参数服务器或许和谐器毛病

当参数服务器失利时,schedule,join 或 done 会引发 tf.errors.UnavailableError。在这种状况下,除了重置失利的参数服务器外,用户还应该从头发动和谐器,使其从头衔接到作业者和参数服务器,从头创立变量,并加载查看点。如果和谐器产生毛病,在用户把它重置回来之后,程序会主动衔接到作业者和参数服务器,并从查看点持续前进。因为和谐器自身也或许变得不行用。因而主张运用某些东西以便不丢掉练习进度:

  • 因而,在用户的程序中,有必要定时保存查看点文件,并在程序开端时康复。如果 “tf.keras.optimizers.Optimizer” 被应用 checkpoint,在从查看点康复后,其 “iterations” 特点会大致显现现已进行的步骤数。这能够用来决定在练习完结前还需求多少个 epochs 和步骤(steps)。
  • 关于 Model.fit,你应该运用 BackupAndRestore 回调,它能够主动处理进度的保存和康复。
  • 关于一个自界说的练习循环,你应该定时查看模型变量,并在练习开端前从查看点(如果有的话)加载模型变量。如果优化器有查看点,练习进度能够从 optimizer.iterations 中大致推断出来。
checkpoint_manager = tf.train.CheckpointManager(
    tf.train.Checkpoint(model=model, optimizer=optimizer),
    checkpoint_dir,
    max_to_keep=3)
if checkpoint_manager.latest_checkpoint:
  checkpoint = checkpoint_manager.checkpoint
  checkpoint.restore(
      checkpoint_manager.latest_checkpoint).assert_existing_objects_matched()
global_steps = int(optimizer.iterations.numpy())
starting_epoch = global_steps // steps_per_epoch
for _ in range(starting_epoch, num_epoches):
  for _ in range(steps_per_epoch):
    coordinator.schedule(step_fn, args=(per_worker_iterator,))
  coordinator.join()
  checkpoint_manager.save()

8.4 回来 RemoteValue

如果一个函数被成功履行,就能够成功获取到 RemoteValue。这是因为现在在履行完一个函数后,回来值会当即被复制到和谐器。如果在复制过程中呈现任何作业者毛病,该函数将在另一个可用的作业者上重试。因而,如果你想优化功能,你能够组织(schedule)一个没有回来值的函数。

8.5 过错陈述

一旦和谐器发现一个过错,如来自参数服务器的 UnavailableError 或其他应用过错,如来自 tf.debugging.check_numerics 的 InvalidArgument,它将在引发过错之前撤销一切 pending 和排队(queued)的函数。获取它们相应的 RemoteValue 将引发一个 CancelledError 。

在引发过错后,和谐器将不会引发相同的过错或任何引发一个来自已撤销函数的过错。

ClusterCoordinator 假定一切的函数过错都是致命的,基于这个假定,其的过错陈述逻辑是:

  • Schedule 和 join 都能够引发一个不行重试的过错,这是和谐者从任何从前组织的函数中看到的第一个过错。
  • 当一个过错被抛出时,不确保有多少从前组织的功用被履行;没有被履行的功用将被丢弃并被标记为撤销。
  • 在一个过错被抛出后,过错的内部状态将被清除。

8.6 WorkerPreemptionHandler

WorkerPreemptionHandler 是处理失利的主要模块,其界说如下:

class WorkerPreemptionHandler(object):
  """Handles worker preemptions."""
  def __init__(self, server_def, cluster):
    self._server_def = server_def
    self._cluster = cluster
    self._cluster_update_lock = threading.Lock()
    self._cluster_due_for_update_or_finish = threading.Event()
    self._worker_up_cond = threading.Condition(self._cluster_update_lock)
    self._error_from_recovery = None
    self._should_preemption_thread_run = True
    self._preemption_handler_thread = threading.Thread(
        target=self._preemption_handler,
        name="WorkerPreemptionHandler",
        daemon=True)
    self._preemption_handler_thread.start()

8.6.1 装备

在 Cluster 生成时,会把 WorkerPreemptionHandler 装备进来。

self.failure_handler = WorkerPreemptionHandler(context.get_server_def(), self)

8.6.2 等候

在处理 closure 时,会用 wait_on_failure 包裹一层用来处理过错。

  def _process_closure(self, closure):
    """Runs a closure with preemption handling."""
    assert closure is not None
    try:
      with self._cluster.failure_handler.wait_on_failure(
          on_failure_fn=lambda: self._cluster._closure_queue.put_back(closure),  
          on_recovery_fn=self._set_resources_aborted,
          worker_device_name=self.device_name):
        closure.execute_on(self)

WorkerPreemptionHandler 的 wait_on_failure 办法如下:

  @contextlib.contextmanager
  def wait_on_failure(self,
                      on_failure_fn=None,
                      on_transient_failure_fn=None,
                      on_recovery_fn=None,
                      worker_device_name="(unknown)"):
    """Catches worker preemption error and wait until failed workers are back.
    Args:
      on_failure_fn: an optional function to run if preemption happens.
      on_transient_failure_fn: an optional function to run if transient failure
        happens.
      on_recovery_fn: an optional function to run when a worker is recovered
        from preemption.
      worker_device_name: the device name of the worker instance that is passing
        through the failure.
    Yields:
      None.
    """
    try:
      yield
    except (errors.OpError, InputError) as e:
      # If the error is due to temporary connectivity issues between worker and
      # ps, put back closure, ignore error and do not mark worker as failure.
      if self._cluster._record_and_ignore_transient_ps_failure(e):  
        if on_transient_failure_fn:
          on_transient_failure_fn()
        return
      # Ignoring derived CancelledErrors to tolerate transient failures in
      # PS-worker communication, which initially exposed as an UnavailableError
      # and then lead to sub-function cancellation, subsequently getting
      # reported from worker to chief as CancelledError.
      # We do not mark either worker or PS as failed due to only CancelledError.
      # If there are real (non-transient) failures, they must also be reported
      # as other errors (UnavailableError most likely) in closure executions.
      if isinstance(e, errors.CancelledError) and "/job:" in str(e):
        if on_transient_failure_fn:
          on_transient_failure_fn()
        return
      # This reraises the error, if it's not considered recoverable; otherwise,
      # the following failure recovery logic run. At this time, only worker
      # unavailability is recoverable. PS unavailability as well as other
      # errors in the user function is not recoverable.
      self._validate_preemption_failure(e)
      if on_failure_fn:
        on_failure_fn()
      with self._cluster_update_lock:
        self._cluster_due_for_update_or_finish.set()
        self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC)
        if self._error_from_recovery:
          try:
            raise self._error_from_recovery
          finally:
            self._error_from_recovery = None
      if on_recovery_fn:
        with self.wait_on_failure(
            on_recovery_fn=on_recovery_fn,
            on_transient_failure_fn=on_transient_failure_fn,
            worker_device_name=worker_device_name):
          on_recovery_fn()

_validate_preemption_failure 界说如下:

  def _validate_preemption_failure(self, e):
    """Validates that the given exception represents worker preemption."""
    # Only categorize the failure as a worker preemption if the cancellation
    # manager did not attempt to cancel the blocking operations.
    if _is_worker_failure(e) and (
        not self._cluster._closure_queue._cancellation_mgr.is_cancelled):  
      return
    raise e

8.6.3 handler

WorkerPreemptionHandler 有一个后台线程 _preemption_handler_thread。

    self._preemption_handler_thread = threading.Thread(
        target=self._preemption_handler,
        name="WorkerPreemptionHandler",
        daemon=True)
    self._preemption_handler_thread.start()

_preemption_handler 会进行必要的过错处理。


  def _preemption_handler(self):
    """A loop that handles preemption.
    This loop waits for signal of worker preemption and upon worker preemption,
    it waits until all workers are back and updates the cluster about the
    restarted workers.
    """
    assert self._should_preemption_thread_run
    while True:
      self._cluster_due_for_update_or_finish.wait()
      if not self._should_preemption_thread_run:
        break
      with self._cluster_update_lock:
        try:
          context.context().update_server_def(self._server_def)
          # Cluster updated successfully, clear the update signal, and notify
          # all workers that they are recovered from failure.
          self._worker_up_cond.notify_all()
          # The check for _should_preemption_thread_run is necessary since the
          #  stop  may have already set _cluster_due_for_update_or_finish.
          if self._should_preemption_thread_run:
            self._cluster_due_for_update_or_finish.clear()
        except Exception as e:  
          try:
            self._validate_preemption_failure(e)
          except Exception as ps_e: 
            # In this case, a parameter server fails. So we raise this error to
            # the caller of  wait_on_failure .
            self._error_from_recovery = ps_e
            self._worker_up_cond.notify_all()
            if self._should_preemption_thread_run:
              self._cluster_due_for_update_or_finish.clear()
          # NOTE: Since the first RPC (GetStatus) of update_server_def is
          # currently blocking by default, error should only happen if:
          # (1) More workers failed while waiting for the previous workers to
          #     come back;
          # (2) Worker failed when exchanging subsequent RPCs after the first
          #     RPC returns.
          # Consider adding backoff retry logic if we see the error logged
          # too frequently.

9. 总结

根据前面的代码,咱们总结出来问题点如下:

  • Worker 怎么知道运用哪些设备?答案是:在集群树立作业者时分,会给每一个作业者设定一个设备。

  • 怎么详细履行用户函数?答案是:在作业者运转 Closure 时分,会在指定运转在本作业者设备上,然后运转指定的办法(Self._function)。Self._function 是用户自界说的 function,其间能够运用 strategy.run 把练习办法分发到远端作业者进行练习。

  • 怎么获取数据?答案是:为每个作业者树立一个 PerWorkerValues,PerWorkerValues 是一个容纳 value 列表的容器,每个作业者从对应 PerWorkerValues 之中获取数据。

0xEE 个人信息

★★★★★★关于生活和技能的考虑★★★★★★

微信公众账号:罗西的考虑

0xFF 参阅

tensorflow源码解析之distributed_runtime

TensorFlow分布式练习

  • TensorFlow内核分析
  • 源代码

Tensorflow分布式原理了解

TensorFlow架构与规划:概述

Tensorflow 跨设备通讯

TensorFlow 篇 | TensorFlow 2.x 分布式练习概览