本文为稀土技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!
一、前言
在AIGC领域频繁出现着一个特殊名词“LoRA”,听上去有点像人名,可是这是一种模型练习的方法。LoRA全称Low-Rank Adaptation of Large Language Models,中文叫做大言语模型的低阶习惯。如今在stable diffusion中用地十分频繁。
因为大言语模型的参数量巨大,许多大公司都需求练习数月,由此提出了各种资源耗费较小的练习方法,LoRA便是其间一种。
本文将具体介绍LoRA的原理,并运用PyTorch实现小模型的LoRA练习。
二、模型练习
现在大多数模型练习都是采用梯度下降算法。梯度下降算法能够分为下面4个进程:
- 正向传达核算丢失值
- 反向传达核算梯度
- 运用梯度更新参数
- 重复1、2、3的进程,直到获取较小的丢失
以线性模型为例,模型参数为W,输入输出为x、y,丢失函数以均方差错为例。那么各个进程的核算如下,首先是正向传达,关于线性模型来说便是做一个矩阵乘法:
在求出丢失后,能够核算L对W的梯度,得到dW:
dW是一个矩阵,它会指向L上升最快的方向,可是咱们的意图是让L下降,因而让W减去dW。为了调整更新的脚步,还会乘上一个学习率,核算如下:
最终一直重复即刻。上述三个进程的伪代码如下:
# 4、重复1、2、3
for i in range(10000):
# 1、正向传达核算丢失
L = MSE(Wx, y)
# 2、反向传达核算梯度
dW = gradient(L, W)
# 3、运用梯度更新参数
W -= lr * dW
在更新完成后,得到新的参数W’。此刻咱们运用模型预测时,核算如下:
三、引进LoRA
咱们能够来思考一下W和W’之间的联系。W通常指基础模型的参数,而W’是在基础模型的基础上,经过几次矩阵加减得到的。假设在练习的进程中更新了10次,每次的dW别离为dW1、dW2、….、dW10,那么完整的更新进程能够写为一次运算:
其间dW是一个形状与W’共同的矩阵。咱们把-dW写成矩阵R,那么更新后的参数便是:
此刻练习的进程就被简化为原矩阵加上另一个矩阵R。可是求解矩阵R并没有更简略,而且也没有节省资源,此刻就引出LoRA了这一思想。
一个练习充沛的矩阵,通常是满秩或者根本满意秩的,即矩阵中没有一列是剩余的。在论文《Scaling Laws for Neural Language Model》中提出了数据集与参数巨细之间的联系,满意该联系且练习杰出,得到的模型是根本满秩的。在微调模型时,咱们会选取一个底模,该底模便是根本满秩的。而更新矩阵R秩的状况是怎么的呢?
咱们假定R矩阵是一个低秩矩阵,低秩矩阵有许多重复的列,因而能够分解为两个更小的矩阵。假设W的形状为mn,那么A的形状也是mn,咱们把矩阵R分解为AB(其间A形状为mr,B形状为rN),r通常会选取一个远小于m、n的值,如图所示:
将低秩矩阵分解为两个矩阵几点优点,首先是参数量明显减少。假设R矩阵的形状为100100,那么R的参数量为10000。当咱们选取秩为10时,此刻矩阵A的形状为10010,矩阵B的形状为10100,此刻参数量为2000,比R矩阵少了80%。
而且因为R是低秩矩阵,所以在练习充沛的状况下,A和B矩阵能够达到R的作用。这儿的矩阵AB便是咱们常说的LoRA模型。
在引进LoRA后,咱们的预测需求将x别离输入W和AB,此刻预测的核算为:
在预测时会比原始模型稍慢,可是在大模型中根本感觉不到差异。
四、实战
为了把握各个细节,这儿不运用大模型作为lora的实战,而是挑选运用vgg19这种小型网络来练习lora模型。导入需求用到的模块:
import os
import torch
from torch import optim, nn
from PIL import Image
from torch.utils import data
from torchvision import models
from torchvision.transforms import transforms
4.1 数据集预备
这儿运用vgg19在imagenet上的预练习权重作为底模,因而需求预备分类数据集。为了方便,这儿只预备了一个类别,且只预备了5张图片,图片在项目下的data/goldfish
下:
在imagenet中包含了goldfish类别,可是这儿选取的是插画版的goldfish,经过测验,预练习模型不能将上述图片正确分类。咱们的意图便是练习LoRA,让模型正确分类。
咱们创立一个LoraDataset:
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
class LoraDataset(data.Dataset):
def __init__(self, data_path="datas"):
categories = models.VGG19_Weights.IMAGENET1K_V1.value.meta["categories"]
self.files = []
self.labels = []
for dir in os.listdir(data_path):
dirname = os.path.join(data_path, dir)
for file in os.listdir(dirname):
self.files.append(os.path.join(dirname, file))
self.labels.append(categories.index(dir))
def __getitem__(self, item):
image = Image.open(self.files[item]).convert("RGB")
label = torch.zeros(1000, dtype=torch.float64)
label[self.labels[item]] = 1.
return transform(image), label
def __len__(self):
return len(self.files)
4.2 创立LoRA模型
咱们把LoRA封装成一个层,LoRA中只要两个需求练习的矩阵,LoRA的代码如下:
class Lora(nn.Module):
def __init__(self, m, n, rank=10):
super().__init__()
self.m = m
self.A = nn.Parameter(torch.randn(m, rank))
self.B = nn.Parameter(torch.zeros(rank, n))
def forward(self, inputs):
inputs = inputs.view(-1, self.m)
return torch.mm(torch.mm(inputs, self.A), self.B)
其间m是输入的巨细,n是输出的巨细,rank是秩的巨细,咱们能够设置一个较小的值。
在权重初始化时,咱们把A用高斯噪声初始化,而B用0矩阵初始化,这样的意图是保证从底模开端练习。因为AB是0矩阵,所以初始状态下,LoRA不起作用。
4.3 设置超参数并练习
接下来便是练习了,这儿和PyTorch常规练习代码根本共同,先看代码:
# 加载底模和lora
vgg19 = models.vgg19(models.VGG19_Weights.IMAGENET1K_V1)
for params in vgg19.parameters():
params.requires_grad = False
vgg19.eval()
lora = Lora(224 * 224 * 3, 1000)
# 加载数据
lora_loader = data.DataLoader(LoraDataset(), batch_size=batch_size, shuffle=True)
# 加载优化器
optimizer = optim.Adam(lora.parameters(), lr=lr)
# 定义丢失
loss_fn = nn.CrossEntropyLoss()
# 练习
for epoch in range(epochs):
for image, label in lora_loader:
# 正向传达
pred = vgg19(image) + lora(image)
loss = loss_fn(pred, label)
# 反向传达
loss.backward()
# 更新参数
optimizer.step()
optimizer.zero_grad()
print(f"loss: {loss.item()}")
这儿有两点需求留意,第一点是咱们把vgg19的权重设置为不可练习,这和搬迁学习很像,但其实是不一样的。
第二点则是正向传达时,咱们运用了下面代码:
pred = vgg19(image) + lora(image)
4.4 测验
下面来简略测验一下:
# 测验
for image, _ in lora_loader:
pred = vgg19(image) + lora(image)
idx = torch.argmax(pred, dim=1).item()
category = models.VGG19_Weights.IMAGENET1K_V1.value.meta["categories"][idx]
print(category)
torch.save(lora.state_dict(), 'lora.pth')
输出成果如下:
goldfish
goldfish
goldfish
goldfish
goldfish
根本预测正确了,不过这个测验成果并不能说明什么。最终咱们保存了一个5M的LoRA模型,相比vgg19的几十M算是十分小了。
五、总结
LoRA是针对大模型的一种高效的练习方法,而本文则将LoRA运用在小型的分类网络中,旨在让读者更明晰认识LoRA的具体实现(一起也因为跑不动大模型)。限于数据量,对LoRA的精度效率等问题没有具体讨论,读者能够参考相关材料深化了解。