|
- import torch
- import torch.distributed as dist
- from torch import nn
- from torch.utils.dlpack import to_dlpack
- from torch.utils.dlpack import from_dlpack
- import time
- import numpy as np
-
class BirderState:
    """Mutable state for the Birder 1-bit-compression DDP communication hook.

    Holds the EMA moments used for stochastic quantization, the per-bucket
    error-feedback residuals, optional timing counters, and — when
    ``hierarchy`` is enabled — the intra-/inter-node process subgroups used
    by the two-level communication pipeline.
    """

    def __init__(self, beta, comp_flag=False, record_time=False,
                 packbits_by_cupy=False, hierarchy=False,
                 all_gather_by_chunks=False):
        # beta: decay factor for the Adam-style m/v moment estimates.
        self.beta = beta
        # comp_flag: enable error-feedback compensation of quantization error.
        self.comp_flag = comp_flag
        self.comp_delta = {}          # bucket index -> residual (worker side)
        self.comp_delta_server = {}   # bucket index -> residual (server stage)
        self.m = {}                   # bucket index -> EMA of the gradient
        self.v = {}                   # bucket index -> EMA of |gradient|
        self.index_flag = set()       # bucket indices seen so far
        self.not_first_iter_flag = False
        self.record_time = record_time
        self.time_counter = {}        # bucket index -> accumulated stage timings
        self.packbits_by_cupy = packbits_by_cupy
        if self.packbits_by_cupy:
            # Availability check only: fail fast here if cupy is missing.
            import cupy  # noqa: F401
        self.hierarchy = hierarchy
        self.all_gather_by_chunks = all_gather_by_chunks
        self.new_dist_groups()

    def new_dist_groups(self):
        """Create intra-node (stage 1) and inter-node (stage 2) subgroups."""
        if not self.hierarchy:
            return
        cur_intra, all_intra = dist.new_subgroups()
        self.stage_1_cur_subgroup = cur_intra
        self.stage_1_subgroups = all_intra
        intra_size = torch.cuda.device_count()
        world_size = dist.get_world_size()
        inter_size = world_size // intra_size
        # One inter-node group per local rank: local rank i on every node.
        rank_lists = [
            [i + j * intra_size for j in range(inter_size)]
            for i in range(intra_size)
        ]
        cur_inter, all_inter = dist.new_subgroups_by_enumeration(
            ranks_per_subgroup_list=rank_lists
        )
        self.stage_2_cur_subgroup = cur_inter
        self.stage_2_subgroups = all_inter

    def state_dict(self):
        """Return checkpointable hook state (moments and residuals)."""
        return {
            'comp_delta': self.comp_delta,
            'comp_delta_server': self.comp_delta_server,
            'm': self.m,
            'v': self.v,
        }

    def load_state_dict(self, state_dict):
        """Restore the state produced by :meth:`state_dict`."""
        self.comp_delta = state_dict['comp_delta']
        self.comp_delta_server = state_dict['comp_delta_server']
        self.m = state_dict['m']
        self.v = state_dict['v']
-
-
-
- def _packbits(torch_tensor, by_cupy=False):
- if by_cupy:
- cupy_device = torch_tensor.device.index
- with cupy.cuda.Device(cupy_device):
- cupy_tensor = cupy.fromDlpack(to_dlpack(torch_tensor))
- cupy_packed_tensor = cupy.packbits(cupy_tensor)
- torch_packed_tensor = from_dlpack(cupy_packed_tensor.toDlpack())
- return torch_packed_tensor
-
- else:
- empty_flag = torch_tensor.shape[0] % 8
- if empty_flag:
- empty_tensor = torch.zeros(8 - empty_flag, device=torch_tensor.device, dtype=torch_tensor.dtype)
- torch_tensor = torch.cat([torch_tensor, empty_tensor])
- mask = torch.tensor([2**(7-i) for i in range(8)], device=torch_tensor.device, dtype=torch_tensor.dtype)
- torch_packed_tensor = torch.sum(torch_tensor.view(-1, 8) * mask, axis=1, dtype=torch_tensor.dtype)
- return torch_packed_tensor
-
- def _unpackbits(torch_packed_tensor, by_cupy=False):
- if by_cupy:
- cupy_device = torch_packed_tensor.device.index
- with cupy.cuda.Device(cupy_device):
- cupy_packed_tensor = cupy.fromDlpack(to_dlpack(torch_packed_tensor))
- cupy_tensor = cupy.unpackbits(cupy_packed_tensor)
- torch_tensor = from_dlpack(cupy_tensor.toDlpack())
- return torch_tensor
- else:
- mask = torch.tensor([2**(7-i) for i in range(8)], device=torch_packed_tensor.device, dtype=torch_packed_tensor.dtype)
- torch_tensor = torch.bitwise_and(torch_packed_tensor.view(-1, 1).expand(-1, 8), mask).to(torch.bool).to(torch_packed_tensor.dtype)
- return torch_tensor.view(-1)
-
-
def _quantize_onebit_tensor_cuda(state, tensor, index, by_cupy=False):
    """Stochastically quantize ``tensor`` to 1 bit/element using EMA moments.

    Updates ``state.m``/``state.v`` (Adam-style moments keyed by bucket
    ``index``) in place, maps the normalized momentum into [0, 1], samples
    Bernoulli bits (with error feedback via ``state.comp_delta`` when
    ``state.comp_flag`` is set), and returns the bits packed by _packbits.
    """
    beta = state.beta
    if index in state.m:
        state.m[index].mul_(beta).add_(tensor, alpha=(1 - beta))
        state.v[index].mul_(beta).add_(tensor.abs(), alpha=(1 - beta))
    else:
        state.m[index] = tensor * (1 - beta)
        state.v[index] = tensor.abs() * (1 - beta)
    # |m| <= v elementwise, so this lands in [0, 1] up to the epsilon.
    sample_prob = (state.m[index] / (state.v[index] + 1.0e-8) + 1) / 2
    if not state.comp_flag:
        drawn_bits = torch.bernoulli(sample_prob)
    elif index in state.comp_delta:
        # Fold the accumulated quantization error back into the probability.
        drawn_bits = torch.bernoulli((sample_prob + state.comp_delta[index]).clamp(0, 1))
        state.comp_delta[index].add_(sample_prob - drawn_bits)
    else:
        drawn_bits = torch.bernoulli(sample_prob)
        state.comp_delta[index] = sample_prob - drawn_bits
    return _packbits(drawn_bits.to(torch.uint8), by_cupy)
-
def _quantize_onebit_tensor_cuda_server(state, tensor, index, by_cupy=False):
    """Server-stage 1-bit quantization of an already-aggregated tensor.

    No moment tracking here: ``tensor`` (expected in [-1, 1]) is mapped
    affinely to [0, 1] and sampled, with a separate error-feedback residual
    (``state.comp_delta_server``) when ``state.comp_flag`` is set.
    Returns the packed bits from _packbits.
    """
    sample_prob = (tensor + 1) / 2
    if not state.comp_flag:
        drawn_bits = torch.bernoulli(sample_prob)
    elif index in state.comp_delta_server:
        drawn_bits = torch.bernoulli((sample_prob + state.comp_delta_server[index]).clamp(0, 1))
        state.comp_delta_server[index].add_(sample_prob - drawn_bits)
    else:
        drawn_bits = torch.bernoulli(sample_prob)
        state.comp_delta_server[index] = sample_prob - drawn_bits
    return _packbits(drawn_bits.to(torch.uint8), by_cupy)
-
-
def _dequantize_onebit_tensor_cuda(torch_packed_tensor, shape, by_cupy=False):
    """Unpack 1-bit data and drop the zero padding, keeping ``shape`` elements."""
    return _unpackbits(torch_packed_tensor, by_cupy)[:shape]
-
- def _get_allgather_out_list(all_gather_in_list, world_size):
- out_list = [
- torch.zeros_like(
- all_gather_in_list,
- device=all_gather_in_list.device,
- dtype=all_gather_in_list.dtype,
- )
- for _ in range(world_size)
- ]
- return out_list
-
-
def quantization_birder_hook(state: BirderState, bucket):
    """DDP communication hook implementing Birder 1-bit gradient compression.

    The first time each bucket index is seen, falls back to a plain
    all-reduce average (warm-up).  Afterwards the bucket gradient is
    stochastically quantized to 1 bit per element and exchanged either with
    a flat all_gather, or — when ``state.hierarchy`` is set — via a
    multi-stage intra-node / inter-node pipeline chained with ``.then()``.
    Returns a torch.futures.Future resolving to the averaged bucket buffer.

    NOTE(review): indentation of this block was reconstructed from syntax;
    the nesting of the ``time_counter`` updates under ``record_time`` guards
    follows from ``time_list`` only existing when ``record_time`` is set.
    """
    if state.record_time:
        time_list = []
        time_list.append(time.time())

    group_to_use = dist.group.WORLD
    world_size = group_to_use.size()
    bucket_tensor = bucket.buffer()
    bucket_index = bucket.index()
    if state.hierarchy:
        # Pad the bucket length up to a multiple of (world_size * 8) so it
        # splits evenly across ranks and into whole bytes for packbits.
        bucket_tensor_comp_shape = bucket_tensor.shape[0]
        empty_flag = bucket_tensor.shape[0] % (dist.get_world_size() * 8)
        if empty_flag:
            bucket_tensor_comp_shape += (dist.get_world_size() * 8) - empty_flag
        # Shard sizes after the stage-1 (intra-node) and stage-2 (inter-node)
        # scatters.
        stage_1_comp_shape = bucket_tensor_comp_shape // dist.get_world_size(group=state.stage_1_cur_subgroup)
        stage_2_comp_shape = stage_1_comp_shape // dist.get_world_size(group=state.stage_2_cur_subgroup)

    # Warm-up bookkeeping: once a bucket index repeats, every bucket has been
    # seen and compression can start.
    if not state.not_first_iter_flag:
        if bucket_index in state.index_flag:
            state.not_first_iter_flag = True
        else:
            state.index_flag.add(bucket_index)

    if state.record_time:
        time_list.append(time.time())

    def first_iter_avg(fut):
        # Warm-up path: copy the all-reduced sum back and divide by world size.
        decompressed_tensor = bucket.buffer()
        tensor = fut.value()[0]
        decompressed_tensor.copy_(tensor)
        decompressed_tensor.div_(dist.group.WORLD.size())
        return decompressed_tensor

    def dequantize_and_aggregate(fut):
        # Flat (non-hierarchical) path: unpack every rank's bits, sum them,
        # and map the bit-count average from [0, 1] back to [-1, 1].
        decompressed_tensor = bucket.buffer()
        if dist.get_backend() == 'nccl':
            all_ranks_quantized_tensor = fut.wait()[0]
            if state.record_time or state.packbits_by_cupy:
                # Make sure the async gather has landed before timing /
                # handing the buffer to cupy.
                torch.cuda.synchronize()
        else:
            assert dist.get_backend() == 'gloo'
            all_ranks_quantized_tensor = fut.value()
        if state.record_time:
            time_list.append(time.time())
        aggregated_dequantized_tensor = torch.zeros_like(
            decompressed_tensor, device=decompressed_tensor.device, dtype=torch.int32
        )
        for quantized_tensor in all_ranks_quantized_tensor:
            aggregated_dequantized_tensor += _dequantize_onebit_tensor_cuda(quantized_tensor, decompressed_tensor.shape[0], by_cupy=state.packbits_by_cupy)
        if state.record_time:
            time_list.append(time.time())
        decompressed_tensor.copy_(aggregated_dequantized_tensor)
        # sum(bits) * 2/N - 1 == mean of the per-rank {-1,+1} values.
        decompressed_tensor.mul_(2.0 / dist.group.WORLD.size()).sub_(1.0)
        if state.record_time:
            time_list.append(time.time())
            if bucket_index not in state.time_counter:
                state.time_counter[bucket_index] = np.diff(np.array(time_list))
            else:
                state.time_counter[bucket_index] += np.diff(np.array(time_list))

        return decompressed_tensor

    def stage_2_process(fut):
        # Hierarchical stage 2: quantize the reduce-scattered shard and
        # all_gather it across the inter-node subgroup.
        assert dist.get_backend() == 'nccl'
        stage_2_input = fut.wait()[0]
        if state.record_time or state.packbits_by_cupy:
            torch.cuda.synchronize()
        stage_2_input = _quantize_onebit_tensor_cuda(state, stage_2_input, bucket_index, by_cupy=state.packbits_by_cupy)

        cur_subgroup = state.stage_2_cur_subgroup
        cur_size = dist.get_world_size(group=cur_subgroup)
        stage_2_out_list = _get_allgather_out_list(stage_2_input, cur_size)

        fut = dist.all_gather(
            stage_2_out_list,
            stage_2_input,
            group=cur_subgroup,
            async_op=True,
        ).get_future()

        return fut.wait()

    def stage_3_process(fut):
        # Hierarchical stage 3: dequantize + sum the inter-node gathers,
        # rescale to [-1, 1], then all_gather across the intra-node subgroup.
        aggregated_dequantized_tensor = torch.zeros(
            stage_1_comp_shape, device=bucket_tensor.device, dtype=bucket_tensor.dtype
        )

        assert dist.get_backend() == 'nccl'
        quantized_tensor_list = fut.wait()[0]
        if state.record_time or state.packbits_by_cupy:
            torch.cuda.synchronize()

        for quantized_tensor in quantized_tensor_list:
            aggregated_dequantized_tensor += _dequantize_onebit_tensor_cuda(quantized_tensor, stage_1_comp_shape, by_cupy=state.packbits_by_cupy)
        if state.record_time:
            time_list.append(time.time())
        aggregated_dequantized_tensor.mul_(2.0 / dist.get_world_size(group=state.stage_2_cur_subgroup)).sub_(1.0)
        if state.record_time:
            time_list.append(time.time())
            if bucket_index not in state.time_counter:
                state.time_counter[bucket_index] = np.diff(np.array(time_list))
            else:
                state.time_counter[bucket_index] += np.diff(np.array(time_list))

        cur_subgroup = state.stage_1_cur_subgroup
        cur_size = dist.get_world_size(group=cur_subgroup)
        stage_3_input = aggregated_dequantized_tensor
        stage_3_output_list = _get_allgather_out_list(stage_3_input, cur_size)

        fut = dist.all_gather(
            stage_3_output_list,
            stage_3_input,
            group=cur_subgroup,
            async_op=True,
        ).get_future()
        return fut.wait()

    def stage_4_process(fut):
        # Final stage: flatten the gathered shards and drop the padding
        # added before stage 1, writing back into the bucket buffer.
        decompressed_tensor = bucket.buffer()

        final_comp_tensor = fut.wait()[0]
        final_comp_tensor = final_comp_tensor.view(-1)[:decompressed_tensor.shape[0]]
        decompressed_tensor.copy_(final_comp_tensor)

        return decompressed_tensor

    def stage_2_all_gather_by_chunks_process(fut):
        # Chunked variant of stage 2: exchange each rank's quantized chunk
        # with all_to_all_single instead of all_gather.
        assert dist.get_backend() == 'nccl'
        stage_2_input = fut.wait()[0]

        if state.record_time or state.packbits_by_cupy:
            torch.cuda.synchronize()
        stage_2_input = _quantize_onebit_tensor_cuda(state, stage_2_input, bucket_index, by_cupy=state.packbits_by_cupy)
        stage_2_output = torch.zeros_like(
            stage_2_input, device=stage_2_input.device, dtype=stage_2_input.dtype
        )
        cur_subgroup = state.stage_2_cur_subgroup

        fut = dist.all_to_all_single(
            stage_2_output,
            stage_2_input,
            group=cur_subgroup,
            async_op=True,
        ).get_future()
        return fut.wait()

    def stage_3_all_gahter_by_chunks_process(fut):
        # Chunked variant of stage 3: dequantize + sum the all_to_all chunks,
        # re-quantize the server-side average, and all_gather the bits.
        # NOTE(review): "gahter" typo is preserved — the name is referenced
        # in the dispatch chain at the bottom of this function.
        cur_subgroup = state.stage_2_cur_subgroup
        cur_size = dist.get_world_size(group=cur_subgroup)
        assert dist.get_backend() == 'nccl'
        quantized_tensor_list = fut.wait()[0]

        quantized_tensor_list = quantized_tensor_list.chunk(cur_size)
        if state.record_time or state.packbits_by_cupy:
            torch.cuda.synchronize()

        aggregated_dequantized_tensor = torch.zeros(
            stage_2_comp_shape, device=bucket_tensor.device, dtype=bucket_tensor.dtype
        )
        for quantized_tensor in quantized_tensor_list:
            aggregated_dequantized_tensor += _dequantize_onebit_tensor_cuda(quantized_tensor, stage_2_comp_shape, by_cupy=state.packbits_by_cupy)

        if state.record_time:
            time_list.append(time.time())
        aggregated_dequantized_tensor.mul_(2.0 / dist.get_world_size(group=state.stage_2_cur_subgroup)).sub_(1.0)

        if state.record_time:
            time_list.append(time.time())
            if bucket_index not in state.time_counter:
                state.time_counter[bucket_index] = np.diff(np.array(time_list))
            else:
                state.time_counter[bucket_index] += np.diff(np.array(time_list))

        stage_3_input = _quantize_onebit_tensor_cuda_server(state, aggregated_dequantized_tensor, bucket_index, by_cupy=state.packbits_by_cupy)
        stage_3_output_list = _get_allgather_out_list(stage_3_input, cur_size)

        fut = dist.all_gather(
            stage_3_output_list,
            stage_3_input,
            group=cur_subgroup,
            async_op=True,
        ).get_future()
        return fut.wait()

    def stage_4_all_gather_by_chunks_process(fut):
        # Chunked variant of stage 4: unpack each inter-node chunk into its
        # slot, map {0,1} server bits back to {-1,+1}, then all_gather the
        # full stage-1 shard across the intra-node subgroup.
        cur_subgroup = state.stage_1_cur_subgroup
        cur_size = dist.get_world_size(group=cur_subgroup)

        assert dist.get_backend() == 'nccl'
        quantized_tensor_list = fut.wait()[0]
        if state.record_time or state.packbits_by_cupy:
            torch.cuda.synchronize()

        all_aggregated_dequantized_tensor = torch.zeros(
            stage_1_comp_shape, device=bucket_tensor.device, dtype=bucket_tensor.dtype
        )
        # chunk() returns views, so copying into each chunk fills the
        # underlying full tensor in place.
        aggregated_dequantized_tensor_list = all_aggregated_dequantized_tensor.chunk(
            dist.get_world_size(group=state.stage_2_cur_subgroup)
        )

        for quantized_tensor, aggregated_dequantized_tensor in zip(quantized_tensor_list, aggregated_dequantized_tensor_list):
            aggregated_dequantized_tensor.copy_(
                _dequantize_onebit_tensor_cuda(quantized_tensor, stage_2_comp_shape, by_cupy=state.packbits_by_cupy)
            )

        all_aggregated_dequantized_tensor.mul_(2.0).sub_(1.0)
        stage_4_input = all_aggregated_dequantized_tensor

        stage_4_output_list = _get_allgather_out_list(stage_4_input, cur_size)

        fut = dist.all_gather(
            stage_4_output_list,
            stage_4_input,
            group=cur_subgroup,
            async_op=True,
        ).get_future()
        return fut.wait()

    if state.not_first_iter_flag:
        if state.hierarchy:
            # Stage 1: pre-divide, pad to the compression-friendly length,
            # and reduce_scatter shards across the intra-node subgroup.
            cur_subgroup = state.stage_1_cur_subgroup
            cur_size = dist.get_world_size(group=cur_subgroup)
            cur_rank = dist.get_rank(group=cur_subgroup)
            stage_1_input = bucket.buffer()
            stage_1_input.div_(cur_size)
            empty_flag = stage_1_input.shape[0] % (dist.get_world_size() * 8)
            if empty_flag:
                empty_tensor = torch.zeros((dist.get_world_size() * 8) - empty_flag, device=stage_1_input.device, dtype=stage_1_input.dtype)
                stage_1_input = torch.cat([stage_1_input, empty_tensor])

            stage_1_input_list = list(stage_1_input.chunk(cur_size))
            stage_1_output = torch.zeros(
                stage_1_comp_shape, device=stage_1_input.device, dtype=stage_1_input.dtype
            )

            fut = dist.reduce_scatter(stage_1_output, stage_1_input_list, group=cur_subgroup, async_op=True).get_future()

            if state.all_gather_by_chunks:
                return fut.then(stage_2_all_gather_by_chunks_process).then(stage_3_all_gahter_by_chunks_process).then(stage_4_all_gather_by_chunks_process).then(stage_4_process)
            else:
                return fut.then(stage_2_process).then(stage_3_process).then(stage_4_process)

        else:
            # Flat path: quantize the whole bucket, all_gather every rank's
            # bits, then dequantize-and-average in the callback.
            quantized_tensor = _quantize_onebit_tensor_cuda(state, bucket_tensor, bucket_index, by_cupy=state.packbits_by_cupy)
            if state.record_time:
                time_list.append(time.time())

            out_list = _get_allgather_out_list(quantized_tensor, world_size)
            fut = dist.all_gather(
                out_list,
                quantized_tensor,
                group=group_to_use,
                async_op=True,
            ).get_future()
            return fut.then(dequantize_and_aggregate)
    else:
        # Warm-up iteration for this bucket: exact all-reduce average.
        fut = dist.all_reduce(bucket_tensor, group=group_to_use, async_op=True).get_future()
        return fut.then(first_iter_avg)
-
-
-
class SGDState:
    """Minimal hook state: optional per-bucket wall-clock timing only."""

    def __init__(self, record_time=False):
        # When True, my_allreduce_hook accumulates stage timings.
        self.record_time = record_time
        # bucket index -> accumulated np.diff of stage timestamps.
        self.time_counter = {}
-
-
def my_allreduce_hook(state, bucket):
    """Baseline DDP communication hook: pre-divided async all-reduce.

    The bucket is divided by the world size BEFORE the all-reduce, so the
    reduced sum is already the average.  When ``state.record_time`` is set,
    per-stage timestamps are diffed and accumulated into
    ``state.time_counter`` keyed by bucket index.
    Returns a torch.futures.Future resolving to the averaged bucket buffer.
    """
    if state.record_time:
        time_list = []
        time_list.append(time.time())
    tensor = bucket.buffer()

    bucket_index = bucket.index()
    group_to_use = dist.group.WORLD

    if state.record_time:
        time_list.append(time.time())

    # Pre-divide so the all-reduce sum equals the mean.
    tensor.div_(group_to_use.size())
    def count_time(fut):
        # Callback: copy the reduced result back into the bucket buffer and
        # (optionally) record timings.
        ar_tensor = bucket.buffer()
        fut_tensor = fut.value()[0]
        if state.record_time:
            if dist.get_backend() == 'nccl':
                # Timing is only meaningful after the async CUDA work lands.
                torch.cuda.synchronize()
            time_list.append(time.time())
        ar_tensor.copy_(fut_tensor)
        if state.record_time:
            time_list.append(time.time())
            if bucket_index not in state.time_counter:
                state.time_counter[bucket_index] = np.diff(np.array(time_list))
            else:
                state.time_counter[bucket_index] += np.diff(np.array(time_list))
        return ar_tensor
    return (
        dist.all_reduce(tensor, group=group_to_use, async_op=True)
        .get_future()
        .then(count_time)
    )
-
|