|
- import torch
- import torch.nn as nn
- from torch.utils.data import Dataset,ConcatDataset,DataLoader
- import numpy as np
- from data import *
- import os
- import pandas as pd
- import torchvision
- from functools import partial
- import math
- import torch
- import torch.nn as nn
-
- from timm.models.vision_transformer import PatchEmbed, Block
-
- # from util.pos_embed import get_2d_sincos_pos_embed
-
- class Transformer(nn.Module):
- """ hubert with VisionTransformer backbone
- """
-
- def __init__(self, x_size=150000, in_chans=1,
- embed_dim=768, depth=12, num_heads=8, mlp_ratio=4,
- norm_layer=nn.LayerNorm, norm_pix_loss=False):
- super().__init__()
- self.num_patches = math.floor(x_size / embed_dim)
- self.embed_dim = embed_dim
- self.Con1 = nn.Conv1d(1, embed_dim, kernel_size=embed_dim, stride=embed_dim)
- # self.Con2 = nn.ModuleList([
- # nn.Conv1d(512, 512, kernel_size=3, stride=2)
- # for i in range(4)])
- # self.Con3 = nn.ModuleList([
- # nn.Conv1d(512, 512, kernel_size=2, stride=2)
- # for i in range(2)])
- self.layer_norm1 = norm_layer(embed_dim)
- # self.proj1 = nn.Linear(512, embed_dim)
-
- self.blocks = nn.ModuleList([
- Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
- for i in range(depth)])
- # self.proj2 = nn.Linear(embed_dim, 256)
- # self.dropout2 = nn.Dropout(0.1)
- # self.proj3 = nn.Linear(3840, 1024)
- # self.act1 = nn.ReLU(inplace=True)
- # self.dropout3 = nn.Dropout(0.1)
- # self.proj4 = nn.Linear(1024, 1024)
- # self.act2 = nn.ReLU(inplace=True)
- # self.proj5 = nn.Linear(1024, 83)
- self.proj6 = nn.Linear(self.embed_dim, 170)
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
- self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
- self.initialize_weights()
- self.get_1d_sincos_pos_embed_from_grid(embed_dim, self.num_patches, cls_token=True)
-
-
- def initialize_weights(self):
- # initialize (and freeze) pos_embed by sin-cos embedding
- pos_embed = self.get_1d_sincos_pos_embed_from_grid(self.embed_dim, self.num_patches,cls_token=True)
- self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
- torch.nn.init.normal_(self.cls_token, std=.02) #使用从正态分布中提取的值填充输入Tensor
- # torch.nn.init.normal_(self.mask_token, std=.02)
-
- # initialize nn.Linear and nn.LayerNorm and nn.Conv1d
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- # we use xavier_uniform following official JAX ViT:
- torch.nn.init.xavier_uniform_(m.weight)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
- elif isinstance(m, nn.Conv1d):
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
-
- def get_1d_sincos_pos_embed_from_grid(self,embed_dim, pos, cls_token=True):
- """
- embed_dim: output dimension for each position
- pos: a list of positions to be encoded: size (M,)
- out: (M, D)
- """
- assert embed_dim % 2 == 0
- omega = np.arange(embed_dim // 2, dtype=np.float)
- omega /= embed_dim / 2.
- omega = 1. / 10000**omega # (D/2,)
- grid_d = np.arange(pos, dtype=np.float32)
- grid_d = grid_d.reshape(-1) # (M,)
- out = np.einsum('m,d->md', grid_d, omega) # (M, D/2), outer product
-
- emb_sin = np.sin(out) # (M, D/2)
- emb_cos = np.cos(out) # (M, D/2)
-
- emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
- if cls_token:
- pos_embed = np.concatenate([np.zeros([1, embed_dim]), emb], axis=0)
- return pos_embed
-
- def forward(self,x):
- # print('x.size0:',x.size())
- x = self.Con1(x)
- # for Con2 in self.Con2:
- # x = Con2(x)
- # for Con3 in self.Con3:
- # x = Con3(x)
- x = x.transpose(1,2)
- x = self.layer_norm1(x)
- # x = self.proj1(x)
- # print("x.size1:",x.size())
- # print("self.pos_embed:",self.pos_embed.size())
- x = x + self.pos_embed[:, 1:, :]
- # append cls token, [1, 1, D] + [1, 1, D]
- cls_token = self.cls_token + self.pos_embed[:, :1, :]
- # [B, 1, D]
- cls_tokens = cls_token.expand(x.shape[0], -1, -1)
- # print("cls_tokens:",cls_tokens.size())
- # [B, 1+N, D]
- x = torch.cat((cls_tokens, x), dim=1)
- # print("x.size2:",x.size())
- for blk in self.blocks:
- x = blk(x)
- x = x[:, 0, :]
- x = self.proj6(x)
- return x
-
-
def Transformer_base_emd768d12b():
    """Base configuration: 768-dim embedding, 12 blocks, 8 heads."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return Transformer(
        x_size=150000,
        in_chans=1,
        embed_dim=768,
        depth=12,
        num_heads=8,
        mlp_ratio=4,
        norm_layer=layer_norm,
        norm_pix_loss=False,
    )
-
def Transformer_large_emd1024d24b():
    """Large configuration: 1024-dim embedding, 24 blocks, 16 heads."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return Transformer(
        x_size=5000,
        in_chans=1,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        norm_pix_loss=False,
    )
-
def Transformer_huge_emd1280d48b():
    """Huge configuration: 1280-dim embedding, 48 blocks, 16 heads."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return Transformer(
        x_size=5000,
        in_chans=1,
        embed_dim=1280,
        depth=48,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        norm_pix_loss=False,
    )
-
# Recommended architecture aliases (only the base config is enabled).
Transformer_base = Transformer_base_emd768d12b  # encoder: 768 dim, 12 blocks
# Transformer_large = Transformer_large_emd1024d24b  # encoder: 1024 dim, 24 blocks
# Transformer_huge = Transformer_huge_emd1280d48b  # encoder: 1280 dim, 48 blocks
|