位置: IT常识 - 正文

Pytorch - 弹性训练原理(pytorch如何训练模型)

编辑:rootadmin
Pytorch - 弹性训练原理

推荐整理分享Pytorch - 弹性训练原理(pytorch如何训练模型),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch metric learning,pytorch tanh,pytorch 训练模型,pytorch functional,pytorch metric learning,pytorch functional,pytorch 训练模型,pytorch 训练模型,内容如对您有帮助,希望把文章链接给更多的朋友!

Pytorch在1.9.0引入了torchrun,用其替代1.9.0以前版本的torch.distributed.launch。torchrun在torch.distributed.launch 功能的基础上主要新增了两个功能:

Failover: 当worker训练失败时,会自动重新启动所有worker继续进行训练;

Elastic: 可以动态增加或或删除node节点;

弹性训练代码同DDP代码编写的思路基本一致,只要在DDP代码上增加以下两点即可:

checkpoint处理:由于再每次增加或删除node时,会将所有worker kill掉,然后再重新启动所有worker进行训练。因此,在训练代码中要对训练的状态进行保存,以保证重启后能接着上次的状态继续训练。

超参调解:由于node节点数的变化,会导致global batch size的变化,因此我们的learning rate一般也要做相应的调整,保证训练出的模质量不受影响。

代码见第二节 最下面

当编写完弹性训练代码后,我们可以使用torchrun来启动弹性训练任务:

--nnodes=1:3 :表示当前训练任务接受最少1个node,最多3个node参与分布式训练;

--nproc_per_node=4:表示每个node上节点有4个process

--max_restarts=3: worker group最大的重启次数;这里需要注意的是,node fail、node scale down和node scale up都会导致restart;

--rdzv_id=1:一个unique的job id,所有node均使用同一个job id;

--rdzv_backend: rendezvous的backend实现,默认支持c10d和etcd两种;rendezvous用于多个node之间的通信和协调;

--rdzv_endpoint:rendezvous的地址,应该为一个node的host ip和port;

torchrun \ --nnodes=1:3\ --nproc_per_node=4\ --max_restarts=3\ --rdzv_id=1\ --rdzv_backend=c10d\ --rdzv_endpoint="192.0.0.1:1234"\ train_elastic.py3 整体架构

弹性调度的架构如上图所示,其中最关键角色为elastic agent。在每个Node上面都有一个elastic agent进程,其负责管理当前Node上面的所有workers。

当我们调用torchrun 命令启动弹性训练任务后:

首先,elastic agent会触发rendezvous 流程; rendezvous的功能是在所有elastic agent间做协调和同步,该接口会一直阻塞直到至少min个elastic agent加入进来后返回;

然后,elastic agent会启动当前Node的所有workers

最后,elastic agent会监控当前Node上所有workers的运行状态,并根据workers的状态进行相应的处理(例如restart worker)

4 Elastic Agent

本小结,我们详细分析下Elastic Agent的实现。Elastic Agent在Pytorch代码中由以下对象构成:

Elastic Agent是抽象基类

SimpleElasticAgent提供了更完整的Agent接口,并且实现了部分接口

LocalElasticAgent则是实现剩余的接口

Elastic Agent在代码中的调用逻辑如下:

torch.distributed.launcher.api:launch_agent() 弹性训练逻辑的入口;

首先、会构建一个RendezvousParameters来描述Rendezvous调用时所需要的参数,例如min_nodes/max_nodes/endpoint等;

然后、构建WorkerSpec描述当前Node上启动Wokers的信息, 例如max_restart/entrypoint等;

再然后,构建LocalElasticAgent对象;

最后,调用LocalElasticAgent的run接口启动当前node的workers进行弹性训练;

Elastic run接口主要由两个部分逻辑组成:

若process group的状态为succeeded:调用_exit_barrier接口等待所有node上agent相应并退出

若process group的状态为unhealthy或failed: 如果重试次数小于_remaining_restart则restart所有worker进程,否则stop所有worker,并退出;

若process group的状态为healthy: 则判断当前是否有node等待加入,如果有则restart_worker;(注:restart worker的实现逻辑是先stop 所有worker,然后在调用_initialize_workers)

SimpleElasticAgent._initialize_workers:先调用_rendezvous等待至少min 个node加入,然后调用_start_workers接口在当前node上启动worker process

while loop monitor worker:while循环,监控上一步启动process的状态

5 Rendezvous5.1 基本概念

Pytorch中Rendezvous的实现涉及到很多概念,我们这里先把这些概念一一介绍下,然后再介绍Rendezvous的实现这样会清晰很多。

首先是_RendezvousState,每个ElasticAgent上都会存储一份_RendezvousState,并会在必要时进行彼此间的同步,_RendezvousState存储的内容如下:

round: The current round of the rendezvous.

complete: A boolean value indicating whether the current round of the rendezvous is complete.

deadline: The time at which the current round of the rendezvous will be considered complete if it is still waiting for nodes to join.

closed: A boolean value indicating whether the rendezvous is closed.

participants: A dictionary of the participants and their corresponding ranks.

wait_list:A set of nodes that are waiting to participate in the next round of the rendezvous.

last_heartbeats: A dictionary containing each node's last heartbeat time.

那_RendezvousState是如何在所有ElasticAgent间进行同步的呢,Pytorch中又提出了Store的概念,在Pytorch中有TCPStore、FileStore和HashStore三种类型,在弹性训练场景,默认使用TCPStore。

TCPStore的典型用法如下:

其是一个典型的server-client架构,我们在process1上启动server,在proess2上启动client,通过TCPStore的set和get接口可以进行数据的设置和获取

在Rendezvous实现中即是通过TCPStore来对_RendezvousState进行设置和获取的。

import torch.distributed as distfrom datetime import timedelta# Run on process 1 (server)server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30))# Run on process 2 (client)client_store = dist.TCPStore("127.0.0.1", 1234, 2, False)# Use any of the store methods from either the client or server after initializationserver_store.set("first_key", "first_value")client_store.get("first_key")

Pytorch的Rendezvous实现中,通过C10dRendezvousBackend对TCPStore进行了封装,并提供了set_state和get_state接口,方便state的操作。(注:Pytorch中还提供了EtcdRendezvousBackend,该类型的RendezvousBackend通过Etcd来进行_RendezvousState的同步)。

C10dRendezvousBackend的主要实现如下,可以很清晰的看到get_state和set_state的实现,均是对store接口的调用.

class C10dRendezvousBackend(RendezvousBackend):    def get_state(self) -> Optional[Tuple[bytes, Token]]:        """See base class."""        base64_state: bytes = self._call_store("get", self._key)        return self._decode_state(base64_state)    def set_state(        self, state: bytes, token: Optional[Token] = None    ) -> Optional[Tuple[bytes, Token, bool]]:        """See base class."""        base64_state_str: str = b64encode(state).decode()        if token:            # Shortcut if we know for sure that the token is not valid.            if not isinstance(token, bytes):                result = self.get_state()                if result is not None:                    tmp = *result, False                    # Python 3.6 does not support tuple unpacking in return                    # statements.                    return tmp                return None            token = token.decode()        else:            token = self._NULL_SENTINEL        base64_state: bytes = self._call_store("compare_set", self._key, token, base64_state_str)        state_token_pair = self._decode_state(base64_state)        if state_token_pair is None:            return None        new_state, new_token = state_token_pair        # C10d Store's compare_set method does not offer an easy way to find out        # whether our write attempt was successful. As a brute-force solution we        # perform a bitwise comparison of our local state and the remote state.        return new_state, new_token, new_state == state        def _call_store(self, store_op: str, *args, **kwargs) -> Any:        try:            return getattr(self._store, store_op)(*args, **kwargs)        except (ValueError, RuntimeError, TimeoutError) as exc:            raise RendezvousConnectionError(                "The connection to the C10d store has failed. See inner exception for details."            ) from exc    

在RendezvousBackend的基础上,Pytorch提出了一个更偏向业务层面的概念**_RendezvousStateHolder**,其提供了_RendezvousState进行获取、同步、标记更新的接口,这些接口的实现均是调用RendezvousBackend的set_state和get_state完成的。

_RendezvousStateHolder的定义如下:

class _RendezvousStateHolder(ABC):    """Holds the shared rendezvous state synced with other nodes."""    def state(self) -> _RendezvousState:        """Gets the local state."""    def sync(self) -> Optional[bool]:        """Reads or writes the latest state.        Returns:            A boolean value indicating whether the local state, in case marked            as dirty, was successfully synced with other nodes.        """    def mark_dirty(self) -> None:        """Marks the local state as dirty."""

Rendezvous的基础设置都准备好了,状态在 _RendezvousState中保存,状态的同步通过 _RendezvousStateHolder来完成,此时还差一项,就是Rendezvous state的是如何变更的。这个变更通过 _RendezvousXXXOp和 _RendezvousOpExecutor共同来完成。

Pytorch首先提供了_RendezvousExitOp/_RendezvousJoinOp/_RendezvousCloseOp/_RendezvousKeepAliveOp来对应ElasticAgent的退出、加入、Rendezvous关闭和心跳保保持四个操作。这些OP的实现逻辑是根据OP的类型和当前_RendezvousState的内容来决定来返回一个action,_RendezvousOpExecutor则执行对应的action。

例如_RendezvousExitOp 对应ElasticAgent的退出操作

如果当前节点仍旧在participants列表中,则返回一个REMOVE_FROM_PARTICIPANTS,_RendezvousOpExecutor在接收到这个action后会执行_remove_from_participants逻辑;

如果当前节点没有在participants列表中,返回FINISH,这个状态_RendezvousOpExecutor不会做任何操作;

class _RendezvousExitOp:    """Represents a rendezvous exit operation."""    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:        if ctx.node in ctx.state.participants:            if time.monotonic() > deadline:                return _Action.ERROR_TIMEOUT            return _Action.REMOVE_FROM_PARTICIPANTS        return _Action.FINISH     

_DistributedRendezvousOpExecutor的核心接口如下:

run提供了执行Rendezvous op的总入口

其他接口则对应了Rendezvous op返回的action的实现。这些action的实现本质上都是对_RendezvousState内容的修改,例如_mark_rendezvous_closed是将_RendezvousState的close字段设置为了True。

class _DistributedRendezvousOpExecutor:  def run(self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float,) -> None:  def _keep_alive(self) -> None:  def _add_to_participants(self)   def _add_to_wait_list(self)  def _remove_from_participants(self)  def _remove_from_wait_list(self)  def _mark_rendezvous_complete(self)  def _mark_rendezvous_closed(self):        self._state.closed = True

最后一个要介绍的概念是RendezvousHandler,其是Rendezvous系统最上层的对外接口,ElasticAgent通过该接口来在所有节点间进行协调。在Pytorch中提供了DynamicRendezvousHandler、EtcdRendezvousHandler和StaticTCPRendezvous三种实现,这里我们仅关注DynamicRendezvousHandler。

RendezvousHandler中最核心的接口是next_rendezvous,ElasticAgent会调用该接口来等待至少min个node的加入。他们实现我们后面再进行讲解。

上面介绍的这些概念,可以通过如下的关系图来进行描述。

5.2 实现逻辑

在熟系完Rendezvous的基本概念后,我们现在可以来看其实现逻辑了。

首先,我们看DynamicRendezvousHandler.next_rendezvous的实现逻辑(注:ElasticAgent通过调用该接口实现的node间的协调)。DynamicRendezvousHandler.next_rendezvous 一共由5个步骤组成:

DynamicRendezvousHandler._stop_heartbeats():停止先TCPStore的心跳操作,通过调用定时器_PeriodicTimer的cancel接口实现;

Execute Exit OP:执行退出逻辑,如果当前node已经在participants中了,则先把当前节点从_RendezvousState的participants列表中删除;

Execute Join OP: 下图仅描述了一个常规的场景,源码中还有一些特殊情况需要处理;

将自己加入到_RendezvousState的participants列表中;

向TCPStore发起心跳,等待至少min个node加入;

当_RendezvousState的participants的个数大于min时,mark rendezvous;

此时,Join OP执行完成,返回给_RendezvousOpExecutor 个Finish action;

DynamicRendezvousHandler._start_heartbeats(): 开启心跳,这个逻辑通过_PeriodicTimer定期执行_RendezvousKeepAliveOp实现;_RendezvousKeepAliveOp的操作则是对_RendezvousState的last_heartbeats进行更新来实现;

DynamicRendezvousHandler._get_world():从_RendezvousState中获取当前rank和work_size信息;

下面我们再看下Rendezvous的OP是如何执行的。上文提到OP是通过_DistributedRendezvousOpExecutor.run()接口统一来完成的。

主流程包裹在while循环中,直到OP的action为finish方可退出循环;

首先,会调用_BackendRendezvousStateHolder.sync()接口在所有node间进行_RendezvousState的同步;

Pytorch - 弹性训练原理(pytorch如何训练模型)

若当前node有内容需要更新,则调用C10dRendezvousBackend.set_state()来更新;若没有,则调用C10dRendezvousBackend.get_state()来获取最新的state;

若获取了最新的state,则对当前node上存储的state进行更新;

然后,调用当前需要执行的OP,OP接口会返回一个ACTION,_DistributedRendezvousOpExecutor则根据ACTION的内容执行keep_alive/add_to_participants/add_to_wait_list等操作;

6 Failover

Failover分为两种情况:

ElasticAgent Process正常,但是worker process 出错

ElasticAgent Process 异常退出

6.1 Worker Fail

对于worker fail的场景,worker process的异常状态会被ElasticAgent捕获,实现逻辑在SimpleElasticAgent的_invoke_run接口中。

该接口实现中会循环monitor 当前node上所有worker process的状态,如果process 异常,则会进行入UNHEALTHY/FAILED状态的处理流程。

如果当前重试的次数小于_remain_restart,则会发起restart worker的流程

restart worker的实现逻辑也很清晰: whaosoft aiot http://143ai.com

先stop 点前node上所有worker

然后重新走_initialize_workers逻辑来进行Rendezvous和start worker

    def _restart_workers(self, worker_group: WorkerGroup) -> None:        """        Restarts (stops, rendezvous, starts) all local workers in the group.        """        role = worker_group.spec.role        log.info(f"[{role}] Stopping worker group")        self._stop_workers(worker_group)        worker_group.state = WorkerState.STOPPED        self._initialize_workers(worker_group)6.2 ElasticAgent Fail

首先,我们看下当一个node Fail掉后,弹性训练是如何运行的。这有两个node:node0和node1,开始node0和node1同时进行分布式训练,当训练到一定时间后,我们将node1 kill掉。

这是node1上的日志:

[763] epoch 14 (rank = 4, local_rank = 0) loss = 1.2388396263122559[765] epoch 14 (rank = 6, local_rank = 2) loss = 1.4543075561523438[766] epoch 14 (rank = 7, local_rank = 3) loss = 1.0290627479553223[764] epoch 14 (rank = 5, local_rank = 1) loss = 1.1143463850021362^CTraceback (most recent call last):Traceback (most recent call last): File "/opt/conda/bin/torchrun", line 33, in <module> sys.exit(load_entry_point('torch==1.11.0', 'console_scripts', 'torchrun')()) File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper return f(*args, **kwargs) File "/opt/conda/lib/python3.8/site-packages/torch/distributed/run.py", line 724, in main run(args) File "/opt/conda/lib/python3.8/site-packages/torch/distributed/run.py", line 715, in run elastic_launch( File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 131, in __call__ return launch_agent(self._config, self._entrypoint, list(args)) File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 236, in launch_agent result = agent.run() File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/metrics/api.py", line 125, in wrapper result = f(*args, **kwargs) File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 709, in run result = self._invoke_run(role) File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 850, in _invoke_run time.sleep(monitor_interval) File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 60, in _terminate_process_handler raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)torch.distributed.elastic.multiprocessing.api.SignalException: Process 759 got signal: 2

这是node0上的日志,我们可以得出以下结论:

当Elastic Agent退出时,会导致其他存活的Elastic Agent中的process 运行失败;这是因为剩余process无法在正常进行collective communication了;

存活的Elastic Agent会按照UNHEALTHY/FAILED的处理逻辑来重启本机的worker;若失败的Elastic Agent没有重启,则剩余的Elastic Agent重新构建worker group继续进行训练,若失败的Elastic Agent重新启动(例如kubernetes中job提供重启的机制),则会重新加入到整个训练任务中;

# 1) 此时node0和node1共同进行分布式训练...[11762] epoch 14 (rank = 2, local_rank = 2) loss = 1.1763713359832764 [702/1958][11760] epoch 14 (rank = 0, local_rank = 0) loss = 1.324049949645996# 2) 此时node1被kill掉,因此当执行collective communication时,会报出异常[E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.terminate called after throwing an instance of 'std::runtime_error' what(): NCCL error: unhandled system error, NCCL version 21.0.3ncclSystemError: System call (socket, malloc, munmap, etc) failed.# 3)stop 其他三个processWARNING:torch.distributed.elastic.multiprocessing.api:Sending process 11761 closing signal SIGTERMWARNING:torch.distributed.elastic.multiprocessing.api:Sending process 11762 closing signal SIGTERMWARNING:torch.distributed.elastic.multiprocessing.api:Sending process 11763 closing signal SIGTERMERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 11760) of binary: /opt/conda/bin/python# 4)重新走_initialize_workers逻辑[11828] Initializing process group with: {'MASTER_ADDR': 'iZ2ze9q3ftqtxtqlkrk6tuZ', 'MASTER_PORT': '40539', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}[11825] Initializing process group with: {'MASTER_ADDR': 'iZ2ze9q3ftqtxtqlkrk6tuZ', 'MASTER_PORT': '40539', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}[11826] Initializing process group with: {'MASTER_ADDR': 'iZ2ze9q3ftqtxtqlkrk6tuZ', 'MASTER_PORT': '40539', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}[11827] Initializing process group with: {'MASTER_ADDR': 'iZ2ze9q3ftqtxtqlkrk6tuZ', 'MASTER_PORT': '40539', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}[11827] (rank = 2, local_rank = 2) train worker starting...[11828] (rank = 3, local_rank = 3) train worker starting...[11825] (rank = 0, local_rank = 0) train worker starting...[11826] (rank = 1, local_rank = 1) train worker starting...# 5)node0 独自进行分布式训练load checkpoint from checkpoint.ptload checkpoint from checkpoint.ptload checkpoint from checkpoint.ptload checkpoint from checkpoint.pt[11826] epoch 14 (rank = 1, local_rank = 1) loss = 0.839302122592926[11828] epoch 14 (rank = 3, local_rank = 3) loss = 0.8971960544586182[11825] epoch 14 (rank = 0, local_rank = 0) loss = 1.33822691440582287 Scale Up/Down

Scale Down的可以理解为上文中Elastic Agent退出,但是没有重启的场景,因此这里不再赘述。

Scale UP这里要再介绍一下,Scale UP的流程仍旧可以用上图进行描述:

当有新的节点加入时,由于当前Elastic已经建立一个的Rendezvous,其无法加入,所以当前Node会被加入到_RendezvousState的wait_list中

当ElasticAgent和对应的worker process都正常运行时,monitor会返回Healthy的状态;此时,ElasticAgent会检查_RendezvousState的waiting list的node个数,发现waiting list大于0,则出发restart worker来发起新一轮的Rendezvous以将新的加入,这样新的Node加入到了worker group中;

二 \ 代码----

著名物理学家,诺贝尔奖得主Richard Feynman办公室的黑板上写了:"What I cannot create, I do not understand."。在程序员界也经常有"show me the code"的口号。因此,我打算写一系列的分布式训练的文章,将以往抽象的分布式训练的概念以代码的形式展现出来,并保证每个代码可执行、可验证、可复现,并贡献出来源码让大家相互交流。

经过调研发现pytorch对于分布式训练做好很好的抽象且接口完善,因此本系列文章将以pytorch为主要框架进行,文章中的例子很多都来自pytorch的文档,并在此基础上进行了调试和扩充。

最后,由于分布式训练的理论介绍网络上已经很多了,理论部分的介绍不会是本系列文章的重点,我会将重点放在代码层面的介绍上面。

Pytorch - 分布式训练极简体验:https://zhuanlan.zhihu.com/p/477073906

Pytorch - 分布式通信原语(附源码):https://zhuanlan.zhihu.com/p/478953028

Pytorch - 手写allreduce分布式训练(附源码):https://zhuanlan.zhihu.com/p/482557067

Pytorch - 算子间并行极简实现(附源码):https://zhuanlan.zhihu.com/p/483640235

Pytorch - 多机多卡极简实现(附源码):https://zhuanlan.zhihu.com/p/486130584

1. 介绍

Pytorch在1.9.0引入了torchrun,用其替代1.9.0以前版本的torch.distributed.launch。torchrun在torch.distributed.launch 功能的基础上主要新增了两个功能:

Failover: 当worker训练失败时,会自动重新启动所有worker继续进行训练;

Elastic: 可以动态增加或或删除node节点,本文将通过一个例子说明Elastic Training应该如何使用;

本例中会先在Node0上启动4 GPU的worker group ,等其训练一段时间后,会在Node1上再启动4 GPU的workers,并与Node1上的workers构成一个新的worker group,最终构成一个2机8卡的分布式训练。

2. 模型构建

一个简单的全连接模型神经网络模型

class ToyModel(nn.Module):    def __init__(self):        super(ToyModel, self).__init__()        self.net1 = nn.Linear(10, 10)        self.relu = nn.ReLU()        self.net2 = nn.Linear(10, 5)    def forward(self, x):        return self.net2(self.relu(self.net1(x)))3. checkpoint 处理

由于再每次增加或删除node时,会将所有worker kill掉,然后再重新启动所有worker进行训练。因此,在训练代码中要对训练的状态进行保存,以保证重启后能接着上次的状态继续训练。

需要保存的信息一般有如下内容:

model :模型的参数信息

optimizer :优化器的参数信心

epoch:当前执行到第几个epoch

save和load的代码如下所示

torch.save:利用python的pickle将python的object 进行序列化,并保存到本地文件;

torch.load : 将torch.save后的本地文件进行反序列化,并加载到内存中;

model.state_dict(): 存储了model 每个layer和其对应的param信息

optimizer.state_dict():存储了优化器的参数信信息

def save_checkpoint(epoch, model, optimizer, path):    torch.save({    "epoch": epoch,    "model_state_dict": model.state_dict(),    "optimize_state_dict": optimizer.state_dict(),}, path)def load_checkpoint(path):    checkpoint = torch.load(path)    return checkpoint4. 训练代码

初始化逻辑如下:

1~3行: 输出当前worker的关键环境变量,用于后面的结果展示

5~8行:创建模型、优化器和损失函数

10~12行:初始化参数信息

14~19行:如果存在checkpoint,则加载checkpoint,并赋值给model、optimizer和firt_epoch

    local_rank = int(os.environ["LOCAL_RANK"])    rank = int(os.environ["RANK"])    print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) train worker starting...")        model = ToyModel().cuda(local_rank)    ddp_model = DDP(model, [local_rank])    loss_fn = nn.MSELoss()    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)    optimizer.zero_grad()    max_epoch = 100    first_epoch = 0    ckp_path = "checkpoint.pt"        if os.path.exists(ckp_path):        print(f"load checkpoint from {ckp_path}")        checkpoint = load_checkpoint(ckp_path)        model.load_state_dict(checkpoint["model_state_dict"])        optimizer.load_state_dict(checkpoint["optimize_state_dict"])        first_epoch = checkpoint["epoch"]

训练逻辑:

1行:epoch执行的次数为first_epoch到max_epoch,以便能够在worker被重启后继续原有的epoch继续训练;

2行:为了展示动态添加node效果,这里添加sleep函数来降低训练的速度;

3~8行:模型训练流程;

9行:为了简单,文本每个epoch进行一次checkpoint保存;将当前的epoch,model和optimizer保存到checkpoint中;

    for i in range(first_epoch, max_epoch):        time.sleep(1) # 为了展示动态添加node效果,这里添加sleep函数来降低训练的速度        outputs = ddp_model(torch.randn(20, 10).to(local_rank))        labels = torch.randn(20, 5).to(local_rank)        loss = loss_fn(outputs, labels)        loss.backward()        print(f"[{os.getpid()}] epoch {i} (rank = {rank}, local_rank = {local_rank}) loss = {loss.item()}\n")        optimizer.step()        save_checkpoint(i, model, optimizer, ckp_path)5. 启动方式

由于我们使用torchrun来启动多机多卡任务,无需使用spawn接口来启动多个进程(torchrun会负责将我们的python script启动为一个process),因此直接调用上文编写的train函数,并在前后分别添加DistributedDataParallel的初始化和效果函数即可。

下面代码描述了上文train接口的调用。

def run():    env_dict = {        key: os.environ[key]        for key in ("MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "LOCAL_WORLD_SIZE")    }    print(f"[{os.getpid()}] Initializing process group with: {env_dict}")    dist.init_process_group(backend="nccl")    train()    dist.destroy_process_group()if __name__ == "__main__":    run()

本例中使用torchrun来执行多机多卡的分布式训练任务(注:torch.distributed.launch已经被pytorch淘汰了,尽量不要再使用)。启动脚本描述如下(注:node0和node1均通过该脚本进行启动)

--nnodes=1:3 :表示当前训练任务接受最少1个node,最多3个node参与分布式训练;

--nproc_per_node=4:表示每个node上节点有4个process

--max_restarts=3: worker group最大的重启次数;这里需要注意的是,node fail、node scale down和node scale up都会导致restart;

--rdzv_id=1:一个unique的job id,所有node均使用同一个job id;

--rdzv_backend: rendezvous的backend实现,默认支持c10d和etcd两种;rendezvous用于多个node之间的通信和协调;

--rdzv_endpoint:rendezvous的地址,应该为一个node的host ip和port;

torchrun \ --nnodes=1:3\ --nproc_per_node=4\ --max_restarts=3\ --rdzv_id=1\ --rdzv_backend=c10d\ --rdzv_endpoint="192.0.0.1:1234"\ train_elastic.py6. 结果分析

代码:BetterDL - train_elastic.py:https://github.com/tingshua-yts/BetterDL/blob/master/test/pytorch/DDP/train_elastic.py

运行环境: 2台4卡 v100机器

image: pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtimegpu: v100

先在node0上执行执行启动脚本

torchrun \ --nnodes=1:3\ --nproc_per_node=4\ --max_restarts=3\ --rdzv_id=1\ --rdzv_backend=c10d\ --rdzv_endpoint="192.0.0.1:1234"\ train_elastic.py

得到如下结果

2~5行:当前启动的是单机4卡的训练任务,因此WORLD_SIZE为4, LOCAL_WORKD_SIZE也为4

6~9行:共有4个rank参与了分布式训练,rank0~rank3

10~18行: rank0~rank3 均从epoch=0开始训练

r/workspace/DDP# sh run_elastic.sh[4031] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}[4029] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}[4030] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}[4032] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}[4029] (rank = 0, local_rank = 0) train worker starting...[4030] (rank = 1, local_rank = 1) train worker starting...[4032] (rank = 3, local_rank = 3) train worker starting...[4031] (rank = 2, local_rank = 2) train worker starting...[4101] epoch 0 (rank = 1, local_rank = 1) loss = 0.9288564920425415[4103] epoch 0 (rank = 3, local_rank = 3) loss = 0.9711472988128662[4102] epoch 0 (rank = 2, local_rank = 2) loss = 1.0727070569992065[4100] epoch 0 (rank = 0, local_rank = 0) loss = 0.9402943253517151[4100] epoch 1 (rank = 0, local_rank = 0) loss = 1.0327017307281494[4101] epoch 1 (rank = 1, local_rank = 1) loss = 1.4485043287277222[4103] epoch 1 (rank = 3, local_rank = 3) loss = 1.0959293842315674[4102] epoch 1 (rank = 2, local_rank = 2) loss = 1.0669530630111694...

在node1上执行与上面相同的脚本

torchrun \ --nnodes=1:3\ --nproc_per_node=4\ --max_restarts=3\ --rdzv_id=1\ --rdzv_backend=c10d\ --rdzv_endpoint="192.0.0.1:1234"\ train_elastic.py

node1上结果如下:

2~5行:由于添加node1,当前执行的是2机8卡的分布式训练任务,因此WORLD_SIZE=8, LOCAL_WORLD_SIZE=4

6~9行:当前node1上workers的rank为rank4 ~rank7

13~20行: 由于node1是在node0上work训练到epoch35的时候加入的,因此其接着epoch 35开始训练

/workspace/DDP# sh run_elastic.sh[696] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}[697] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}[695] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}[694] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}[697] (rank = 7, local_rank = 3) train worker starting...[695] (rank = 5, local_rank = 1) train worker starting...[694] (rank = 4, local_rank = 0) train worker starting...[696] (rank = 6, local_rank = 2) train worker starting...load checkpoint from checkpoint.ptload checkpoint from checkpoint.ptload checkpoint from checkpoint.ptload checkpoint from checkpoint.pt[697] epoch 35 (rank = 7, local_rank = 3) loss = 1.1888569593429565[694] epoch 35 (rank = 4, local_rank = 0) loss = 0.8916441202163696[695] epoch 35 (rank = 5, local_rank = 1) loss = 1.5685604810714722[696] epoch 35 (rank = 6, local_rank = 2) loss = 1.11683189868927[696] epoch 36 (rank = 6, local_rank = 2) loss = 1.3724170923233032[694] epoch 36 (rank = 4, local_rank = 0) loss = 1.061527967453003[695] epoch 36 (rank = 5, local_rank = 1) loss = 0.96876460313797[697] epoch 36 (rank = 7, local_rank = 3) loss = 0.8060566782951355...

node0上结果如下:

6~9行: node0上的works在执行到epoch 35时,node1上执行了训练脚本,请求加入到训练任务中

10~13行:所有workers重新启动,由于添加了node1,当前执行的是2机8卡的分布式训练任务,因此WORLD_SIZE=8, LOCAL_WORLD_SIZE=4

14~17行:当前node1上works的rank为rank0~rank3

18~21行:加载checkpoint

22~30行:接着checkpoint中的model、optimizer和epoch继续训练

...[4100] epoch 35 (rank = 0, local_rank = 0) loss = 1.0746158361434937[4101] epoch 35 (rank = 1, local_rank = 1) loss = 1.1712706089019775[4103] epoch 35 (rank = 3, local_rank = 3) loss = 1.1774182319641113[4102] epoch 35 (rank = 2, local_rank = 2) loss = 1.0898035764694214WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4100 closing signal SIGTERMWARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4101 closing signal SIGTERMWARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4102 closing signal SIGTERMWARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4103 closing signal SIGTERM[4164] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}[4165] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}[4162] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}[4163] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}[4162] (rank = 0, local_rank = 0) train worker starting...[4163] (rank = 1, local_rank = 1) train worker starting...[4164] (rank = 2, local_rank = 2) train worker starting...[4165] (rank = 3, local_rank = 3) train worker starting...load checkpoint from checkpoint.ptload checkpoint from checkpoint.ptload checkpoint from checkpoint.ptload checkpoint from checkpoint.pt[4165] epoch 35 (rank = 3, local_rank = 3) loss = 1.3437936305999756[4162] epoch 35 (rank = 0, local_rank = 0) loss = 1.5693414211273193[4163] epoch 35 (rank = 1, local_rank = 1) loss = 1.199862003326416[4164] epoch 35 (rank = 2, local_rank = 2) loss = 1.0465545654296875[4163] epoch 36 (rank = 1, local_rank = 1) loss = 0.9741991758346558[4162] epoch 36 (rank = 0, local_rank = 0) loss = 1.3609280586242676[4164] epoch 36 (rank = 2, local_rank = 2) loss = 0.9585908055305481[4165] epoch 36 (rank = 3, local_rank = 3) loss = 0.9169824123382568...
本文链接地址:https://www.jiuchutong.com/zhishi/294462.html 转载请保留说明!

上一篇:莫尔国家公园中的天蚕蛾,加纳拉拉班加 (© Robert Thompson/Minden Pictures)(莫尔道嘎湿地公园)

下一篇:三万字硬核详解:yolov1、yolov2、yolov3、yolov4、yolov5、yolov7(三万个字多久写完)

  • 抖音几分钟前在线可以进行设置吗(抖音几分钟在线和今天在线什么区别)

    抖音几分钟前在线可以进行设置吗(抖音几分钟在线和今天在线什么区别)

  • 笔记本摔了一下会坏吗(笔记本摔了一下会内伤吗)

    笔记本摔了一下会坏吗(笔记本摔了一下会内伤吗)

  • 手机处理器太低能换吗(手机处理器太低玩游戏卡怎么办)

    手机处理器太低能换吗(手机处理器太低玩游戏卡怎么办)

  • v1732a是什么型号(v1732a是什么型号怎么换屏)

    v1732a是什么型号(v1732a是什么型号怎么换屏)

  • 拼多多能不能一次买几样东西(拼多多能不能一个店铺卖多样东西)

    拼多多能不能一次买几样东西(拼多多能不能一个店铺卖多样东西)

  • 华为nova7手机耳机孔在哪里(华为nova7手机耳机插哪)

    华为nova7手机耳机孔在哪里(华为nova7手机耳机插哪)

  • 强制启用4xmsaa是什么意思(强制启用4xmsaa什么用)

    强制启用4xmsaa是什么意思(强制启用4xmsaa什么用)

  • 苹果停用连接itunes怎么办恢复好要密码吗(苹果停用连接itunes会不会删除手机内存)

    苹果停用连接itunes怎么办恢复好要密码吗(苹果停用连接itunes会不会删除手机内存)

  • 华为双系统占内存吗(华为双系统占用内存吗?)

    华为双系统占内存吗(华为双系统占用内存吗?)

  • iphone11来电不显示通讯录名字(苹果11来电不显示号码怎么办)

    iphone11来电不显示通讯录名字(苹果11来电不显示号码怎么办)

  • 如何制作抖音短视频(如何制作抖音短视频照片)

    如何制作抖音短视频(如何制作抖音短视频照片)

  • 新办的手机号是别人以前用过的(新办的手机号是别人以前用过的,怎么办)

    新办的手机号是别人以前用过的(新办的手机号是别人以前用过的,怎么办)

  • 剪映两个视频怎么合成一个(剪映两个视频怎么一上一下)

    剪映两个视频怎么合成一个(剪映两个视频怎么一上一下)

  • 手机卡lte网络是什么意思(手机卡出现lte)

    手机卡lte网络是什么意思(手机卡出现lte)

  • bn37电池是小米几的电池(小米电池型号bn31是什么手机)

    bn37电池是小米几的电池(小米电池型号bn31是什么手机)

  • imel是什么意思(imal是什么意思)

    imel是什么意思(imal是什么意思)

  • 电容mf是什么意思(电容的mf是什么意思)

    电容mf是什么意思(电容的mf是什么意思)

  • word文档中的英文字体(word文档中的英文字体咋样变换字体)

    word文档中的英文字体(word文档中的英文字体咋样变换字体)

  • vivoz5x手机怎么截屏(vivoz5x手机怎么样值得入手吗)

    vivoz5x手机怎么截屏(vivoz5x手机怎么样值得入手吗)

  • 抖音号注销了会怎么样(抖音号注销了会显示什么状态)

    抖音号注销了会怎么样(抖音号注销了会显示什么状态)

  • 菜鸟裹裹关联是双方能看到吗(菜鸟裹裹关联是相互的吗)

    菜鸟裹裹关联是双方能看到吗(菜鸟裹裹关联是相互的吗)

  • iphone8plus支持nfc吗(iphone8plus支持nfc给公交卡充值吗)

    iphone8plus支持nfc吗(iphone8plus支持nfc给公交卡充值吗)

  • 滴滴费用怎么算(滴滴车费计算公式)

    滴滴费用怎么算(滴滴车费计算公式)

  • 3300毫安电池能用多久(3300毫安电池能玩多久)

    3300毫安电池能用多久(3300毫安电池能玩多久)

  • 交管12123验证方式错误怎么办(交管12123验证失败是怎么回事)

    交管12123验证方式错误怎么办(交管12123验证失败是怎么回事)

  • 应交税费转结哪里去
  • 金蝶能够反年结账吗
  • 代扣代缴增值税是什么意思
  • 财务管理考试时间多长
  • 委托加工物资的会计科目
  • 周转材料登三栏式明细账吗
  • 幼儿园申报税种及税率
  • 存货跌价准备计提原则
  • 自己提供原材料让别人加工
  • 逾期支付工程款利息计算
  • 分公司注销存货处理
  • 收入跨期审计调整分录如何滚调
  • 公司进行债务重构的原因可能包括
  • 公司法人信息变更是先去税务局还是先去银行
  • 工程交税必须在工程地点交吗
  • 培训费用开具什么发票
  • 建安业一般纳税人企业所得税率是多少
  • 工会没有税号怎么开普票
  • 印花税滞纳金计算方法
  • 车辆保险车船税每年交多少
  • 出口专用发票可抵扣吗
  • 资金清算款项
  • 限售股转让所得
  • 税前金额是不含税金额
  • 如果被客户骗了货款怎么办
  • 固定资产出租需要交什么税
  • window10环境变量
  • 无形资产如何评估作价
  • 车船使用税进哪个会计科目
  • 计提劳务派遣人员社保收到发票后没有付款的会计分录
  • 股票的交易费用是怎么算的
  • 跨期发票怎么作废
  • 电脑卡慢咋办
  • 被白雪覆盖的彩虹歌词
  • 暂估价是单价还是总价
  • 不能抵扣的福利发票要勾选吗为什么
  • 公司从银行提取现金4000元备作零星开支
  • 主营业务成本和生产成本的关系
  • 什么时候工程物资什么时候在建工程
  • 一般纳税人混凝土税率
  • 外购存货的账务处理
  • 发出商品 会计科目
  • 公司投资者如何避免风险
  • 生产企业电费怎么做账
  • 管理费用怎样分摊归集到产品
  • 企业出售产品
  • 网上购物退款后未退回物品怎么投诉
  • 飞机票保险发票是什么样子的
  • 一般纳税人提供劳务税率是多少
  • 一次性收入怎么计税
  • 如何申请高新技术企业认定
  • 工业企业的材料销售收入应计入什么
  • mysql数据库内存占用高
  • 修改mysql配置的两种方法
  • sql转换
  • 远程管理是什么意思
  • brasil.exe是什么进程
  • xp注册表损坏怎么修复
  • 文件夹隐藏属性怎么弄
  • macbookair如何删除
  • .exe是什么意思
  • 2021年win10累积更新
  • linux下tar.gz、tar、bz2、zip等解压缩、压缩命令小结
  • win10周年版
  • fragment懒加载原理
  • opengl绘制地面
  • win10自带终端
  • ms-dos7.10如何安装
  • python 入门指南
  • unity 技术
  • nodejs.
  • flask框架菜鸟教程
  • u3d脚本语言
  • android开发之apritag
  • javascript基础入门视频教程
  • js实现拖拽元素改编顺序
  • javascript面向对象编程指南
  • 稽查查补税款享受增值税免税优惠吗
  • 电子税务局官网湖南省
  • 东莞国税咨询电话
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设