import argparse
import os
import json5
import numpy as np
import torch
from torch.utils.data import DataLoader
from util.utils import initialize_config


def main(config, resume):
    # main() takes two caller-supplied arguments: the parsed config and the resume flag.
    torch.manual_seed(config["seed"])  # seed PyTorch's RNG (both CPU and GPU) with the value predefined in the JSON config
    np.random.seed(config["seed"])     # seed NumPy's RNG with the same value

    # Without a seed, every run produces different random numbers:
    # torch.rand(3)  # run 1: tensor([0.4387, 0.0385, 0.9119])
    # torch.rand(3)  # run 2: tensor([0.1345, 0.7892, 0.6543])
    # With a fixed seed, the sequence is reproducible:
    # torch.manual_seed(42)
    # torch.rand(3)  # first call:  tensor([0.8823, 0.9150, 0.3829])
    # torch.manual_seed(42)
    # torch.rand(3)  # second call: tensor([0.8823, 0.9150, 0.3829])  # identical!

    train_dataloader = DataLoader(
        dataset=initialize_config(config["train_dataset"]),
        batch_size=config["train_dataloader"]["batch_size"],
        num_workers=config["train_dataloader"]["num_workers"],
        shuffle=config["train_dataloader"]["shuffle"],
        pin_memory=config["train_dataloader"]["pin_memory"]
    )

    valid_dataloader = DataLoader(
        dataset=initialize_config(config["validation_dataset"]),
        num_workers=1,
        batch_size=1
    )

    model = initialize_config(config["model"])

    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=config["optimizer"]["lr"],
        betas=(config["optimizer"]["beta1"], config["optimizer"]["beta2"])
    )

    loss_function = initialize_config(config["loss_function"])

    trainer_class = initialize_config(config["trainer"], pass_args=False)

    trainer = trainer_class(
        config=config,
        resume=resume,
        model=model,
        loss_function=loss_function,
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        validation_dataloader=valid_dataloader
    )

    trainer.train()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Wave-U-Net for Speech Enhancement")
    parser.add_argument("-C", "--configuration", required=True, type=str, help="Configuration (*.json).")
    parser.add_argument("-R", "--resume", action="store_true", help="Resume experiment from latest checkpoint.")
    args = parser.parse_args()

    configuration = json5.load(open(args.configuration))
    configuration["experiment_name"], _ = os.path.splitext(os.path.basename(args.configuration))
    configuration["config_path"] = args.configuration

    main(configuration, resume=args.resume)
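The factory initialize_config comes from util.utils and is not listed here. As a rough sketch of what such a config-driven factory usually does (the key names "module", "main", and "args" are assumptions for illustration, not necessarily the repository's actual schema):

import importlib

def initialize_config(module_cfg, pass_args=True):
    # Hypothetical sketch: import a module by its dotted path, look up the
    # target class by name, and either instantiate it with the configured
    # args or return the class itself (pass_args=False, as for the trainer).
    module = importlib.import_module(module_cfg["module"])  # e.g. "model.unet" (illustrative)
    cls = getattr(module, module_cfg["main"])               # e.g. "Model" (illustrative)
    if pass_args:
        return cls(**module_cfg.get("args", {}))
    return cls

The script would then be launched with something like python train.py -C config/train.json5 (the file names are illustrative), adding -R to resume from the latest checkpoint.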
Below is the U-Net model:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DownSamplingLayer(nn.Module):
    # A down-sampling block; it inherits from nn.Module.
    def __init__(self, channel_in, channel_out, dilation=1, kernel_size=15, stride=1, padding=7):
        # When DownSamplingLayer() is instantiated, channel_in and channel_out must be
        # supplied; dilation=1, kernel_size=15, stride=1, padding=7 are filled in as defaults.
        super(DownSamplingLayer, self).__init__()  # call nn.Module's initializer; the standard Python inheritance idiom
        self.main = nn.Sequential(
            nn.Conv1d(channel_in, channel_out, kernel_size=kernel_size,
                      stride=stride, padding=padding, dilation=dilation),
            nn.BatchNorm1d(channel_out),
            nn.LeakyReLU(negative_slope=0.1)
        )

    def forward(self, ipt):  # ipt: the input tensor
        return self.main(ipt)


class UpSamplingLayer(nn.Module):
    def __init__(self, channel_in, channel_out, kernel_size=5, stride=1, padding=2):
        # padding = (kernel_size - 1) // 2 = 2, so the time length is preserved
        super(UpSamplingLayer, self).__init__()
        self.main = nn.Sequential(
            nn.Conv1d(channel_in, channel_out, kernel_size=kernel_size,
                      stride=stride, padding=padding),
            nn.BatchNorm1d(channel_out),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
        )

    def forward(self, ipt):
        return self.main(ipt)


class Model(nn.Module):
    def __init__(self, n_layers=12, channels_interval=24):
        # Whatever fixed hyperparameters are needed later are declared here and reused below.
        super(Model, self).__init__()
        self.n_layers = n_layers
        self.channels_interval = channels_interval

        # Encoder input widths: 1 channel for the raw waveform, then i * channels_interval
        # for i = 1..11, i.e. layer 1 outputs 24 channels, layer 2 outputs 48, and so on.
        encoder_in_channels_list = [1] + [i * self.channels_interval for i in range(1, self.n_layers)]
        encoder_out_channels_list = [i * self.channels_interval for i in range(1, self.n_layers + 1)]

        self.encoder = nn.ModuleList()
        for i in range(self.n_layers):
            # For each of the 12 levels, append one down-sampling layer to the encoder;
            # each layer's input width is the matching index of encoder_in_channels_list.
            self.encoder.append(
                DownSamplingLayer(
                    channel_in=encoder_in_channels_list[i],
                    channel_out=encoder_out_channels_list[i]
                )
            )

        self.middle = nn.Sequential(
            nn.Conv1d(self.n_layers * self.channels_interval, self.n_layers * self.channels_interval, 15, stride=1,
                      padding=7),
            nn.BatchNorm1d(self.n_layers * self.channels_interval),
            nn.LeakyReLU(negative_slope=0.1, inplace=True)
        )

        decoder_in_channels_list = [(2 * i + 1) * self.channels_interval for i in range(1, self.n_layers)] + [
            2 * self.n_layers * self.channels_interval]
        decoder_in_channels_list = decoder_in_channels_list[::-1]
        decoder_out_channels_list = encoder_out_channels_list[::-1]

        self.decoder = nn.ModuleList()
        for i in range(self.n_layers):
            self.decoder.append(
                UpSamplingLayer(
                    channel_in=decoder_in_channels_list[i],
                    channel_out=decoder_out_channels_list[i]
                )
            )

        self.out = nn.Sequential(
            nn.Conv1d(1 + self.channels_interval, 1, kernel_size=1, stride=1),
            nn.Tanh()
        )
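As a quick sanity check (a minimal sketch, assuming the classes above are in scope): with kernel_size=15, stride=1, padding=7, Conv1d gives L_out = (L_in + 2*7 - 14 - 1) + 1 = L_in, so both blocks preserve the time dimension and all actual down/up sampling happens explicitly in forward():

import torch

layer = DownSamplingLayer(channel_in=1, channel_out=24)
x = torch.randn(2, 1, 1024)   # [batch, channels, time]
print(layer(x).shape)         # torch.Size([2, 24, 1024]): time length unchanged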
## Above, the down- and up-sampling blocks are defined first, then the whole encoder-decoder network (including the non-linear output layer) is assembled.
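To make the channel bookkeeping concrete, here is what the four lists evaluate to for the defaults n_layers=12, channels_interval=24 (plain Python, runnable on its own):

n_layers, channels_interval = 12, 24

encoder_in = [1] + [i * channels_interval for i in range(1, n_layers)]
encoder_out = [i * channels_interval for i in range(1, n_layers + 1)]
decoder_in = ([(2 * i + 1) * channels_interval for i in range(1, n_layers)]
              + [2 * n_layers * channels_interval])[::-1]
decoder_out = encoder_out[::-1]

print(encoder_in)   # [1, 24, 48, 72, 96, 120, 144, 168, 192, 216, 240, 264]
print(encoder_out)  # [24, 48, 72, 96, 120, 144, 168, 192, 216, 240, 264, 288]
print(decoder_in)   # [576, 552, 504, 456, 408, 360, 312, 264, 216, 168, 120, 72]
print(decoder_out)  # [288, 264, 240, 216, 192, 168, 144, 120, 96, 72, 48, 24]

Each decoder input width is the previous decoder output plus the channels of the matching encoder skip connection, e.g. 288 (from the middle block) + 288 (tmp[11]) = 576 at the first decoder layer.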
## Below is the full forward pass:

    def forward(self, input):
        tmp = []   # start with an empty list to collect the skip-connection tensors
        o = input  # o is shorthand for the running activation, starting from the input

        # Encoder (down-sampling path): for each of the 12 layers, run the encoder
        # block, save its output in tmp for the later skip connection, then decimate,
        # so each layer's input is the previous layer's decimated output.
        for i in range(self.n_layers):
            o = self.encoder[i](o)
            tmp.append(o)
            # [batch_size, channels, T // 2]
            o = o[:, :, ::2]  # halve the time dimension by keeping every other sample

        o = self.middle(o)  # pass the encoded signal through the middle block

        # Decoder (up-sampling path)
        for i in range(self.n_layers):
            # [batch_size, channels, T * 2]
            o = F.interpolate(o, scale_factor=2, mode="linear", align_corners=True)
            # F.interpolate performs linear interpolation, an upsampling method that
            # fills the gaps between neighbouring samples, doubling the time resolution.
            # Skip connection: cat is short for concatenate, i.e. gluing tensors
            # together. dim=1 is the channel axis, so the current decoder input is
            # stacked channel-wise with the matching encoder output into one tensor.
            o = torch.cat([o, tmp[self.n_layers - i - 1]], dim=1)
            o = self.decoder[i](o)  # feed the merged tensor to the next decoder layer,
                                    # until all 12 layers have run

        o = torch.cat([o, input], dim=1)  # finally, concatenate the decoder output with the original input
        o = self.out(o)                   # the 1x1 conv + tanh produce the network output
        return o
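A small sketch to confirm the shape arithmetic (assuming the Model class above is in scope): the strided slice halves the time axis, F.interpolate doubles it back, and after 12 halvings the input length must be divisible by 2**12 = 4096 for the shapes to line up again:

import torch
import torch.nn.functional as F

x = torch.randn(1, 24, 8)
down = x[:, :, ::2]  # keep every other sample: [1, 24, 4]
up = F.interpolate(down, scale_factor=2, mode="linear", align_corners=True)
print(down.shape, up.shape)  # torch.Size([1, 24, 4]) torch.Size([1, 24, 8])

model = Model(n_layers=12, channels_interval=24)
wav = torch.randn(2, 1, 16384)  # [batch, 1, time]; 16384 = 4 * 2**12
print(model(wav).shape)         # torch.Size([2, 1, 16384]), same length as the input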
Below are some notes on the skip connection (torch.cat):
import torch
# Two building blocks
block_a = torch.tensor([[1, 2], [3, 4]])  # 2x2
block_b = torch.tensor([[5, 6], [7, 8]])  # 2x2

# Stack vertically (dim=0, appending rows)
cat_dim0 = torch.cat([block_a, block_b], dim=0)
"""
[[1, 2],
 [3, 4],
 [5, 6],   <- block_b attached below
 [7, 8]]
shape: 4x2
"""

# Concatenate horizontally (dim=1, appending columns)
cat_dim1 = torch.cat([block_a, block_b], dim=1)
"""
[[1, 2, 5, 6],   <- block_b attached on the right
 [3, 4, 7, 8]]
shape: 2x4
"""