آموزش هرس در کتابخانه های Pytorch و Tensorflow
تکنیک های پیشرفته یادگیری عمیق بر مدل های بیش از حد پارامتریزه شده تکیه دارند که به سختی به کار می روند. در مقابل، شبکههای عصبی بیولوژیکی از اتصال پراکنده کارآمد استفاده میکنند. شناسایی تکنیکهای بهینه برای فشردهسازی مدلها با کاهش تعداد پارامترهای موجود در آنها به منظور کاهش مصرف حافظه، باتری و سختافزار بدون کاهش دقت، استقرار مدلهای سبک وزن بر روی دستگاه و تضمین حریم خصوصی با محاسبات شخصی روی دستگاه مهم است. در شاخه تحقیقاتی، هرس برای بررسی تفاوتهای پویایی یادگیری بین شبکههای بیش از حد پارامتر و شبکههای کم پارامتر، برای مطالعه نقش زیرشبکههای پراکنده و مقداردهی اولیه (به صورت اتفاقی) به عنوان یک تکنیک جستجوی معماری عصبی مخرب استفاده میشود.
در این آموزش، شما یاد خواهید گرفت که چگونه از شبکه های عصبیtorch.nn.utils.prune
خود استفاده کنید و چگونه آن را برای پیاده سازی تکنیک هرس سفارشی خود گسترش دهید.
الزامات برنامه :
import torch from torch import nn import torch.nn.utils.prune as prune import torch.nn.functional as F
یک مدل ایجاد کنید :
در این آموزش، ما از معماری LeNet از LeCun و همکاران، 1998 استفاده می کنیم.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1 input image channel, 6 output channels, 3x3 square conv kernel
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6, 16, 3)
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet().to(device=device)
یک ماژول را بررسی کنید :
بیاییدلایه conv1
(هرس نشده) را در مدل LeNet خود بررسی کنیم. این شامل دو پارامتر weight
و bias
، و بدون بافر، در حال حاضر خواهد بود.
module = model.conv1 print(list(module.named_parameters()))
خروجی:
[('weight', Parameter containing:
tensor([[[[ 0.0547, -0.1490, 0.2548],
[ 0.1011, -0.0678, -0.1258],
[-0.2431, 0.0490, 0.2079]]],
[[[ 0.2246, -0.0719, 0.2456],
[-0.0275, 0.0436, 0.0349],
[ 0.0376, -0.2445, 0.2093]]],
[[[-0.2927, -0.1899, 0.0973],
[ 0.3133, -0.2256, -0.0648],
[ 0.0768, -0.0155, -0.2793]]],
[[[-0.1462, 0.1492, 0.1009],
[ 0.0474, -0.0459, -0.0308],
[ 0.1235, 0.1081, 0.2985]]],
[[[-0.2447, 0.1707, 0.3060],
[-0.1629, 0.2025, 0.1867],
[ 0.2173, 0.1944, -0.0762]]],
[[[ 0.1163, -0.1372, 0.1721],
[ 0.0272, 0.0657, -0.1102],
[ 0.0223, 0.1550, 0.0909]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.2666, 0.2486, -0.2918, 0.3310, 0.3075, -0.1668], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
خروجی:
[]
هرس یک ماژول :
برای هرس کردن یک ماژول (در این مثال، لایهconv1
معماری LeNet ما)، ابتدا یک تکنیک هرس را از بین روشهای موجود در آن انتخاب کنید torch.nn.utils.prune
(یا تکنیک خود را با زیر کلاسبندی پیادهسازیBasePruningMethod
کنید ). سپس، ماژول و نام پارامتر مورد نظر را در آن ماژول مشخص کنید. در نهایت، با استفاده از آرگومان های کافی کلمه کلیدی مورد نیاز تکنیک هرس انتخاب شده، پارامترهای هرس را مشخص کنید.
در این مثال، 30 درصد از اتصالات پارامترweight
نامگذاری شده در لایهconv1
را به صورت تصادفی هرس می کنیم. ماژول به عنوان اولین آرگومان به تابع ارسال می شود. پارامترname
را در آن ماژول با استفاده از شناسه رشته آن شناسایی می کند. و درصدamount
اتصالات برای هرس (اگر یک شناور بین 0. و 1 باشد)، یا تعداد مطلق اتصالات برای هرس (اگر یک عدد صحیح غیر منفی باشد) را نشان می دهد.
prune.random_unstructured(module, name="weight", amount=0.3)
هرس با حذف weight
از پارامترها و جایگزینی آن با یک پارامتر جدید به نام weight_orig
(یعنی الحاق "_orig"
به پارامتر اولیه name
) عمل می کند. نسخه هرس نشده تانسورweight_orig
را ذخیره می کند.bias
هرس نشده است، بنابراین bias
دست نخورده باقی خواهد ماند.
print(list(module.named_parameters()))
خروجی:
[('bias', Parameter containing:
tensor([-0.2666, 0.2486, -0.2918, 0.3310, 0.3075, -0.1668], device='cuda:0',
requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.0547, -0.1490, 0.2548],
[ 0.1011, -0.0678, -0.1258],
[-0.2431, 0.0490, 0.2079]]],
[[[ 0.2246, -0.0719, 0.2456],
[-0.0275, 0.0436, 0.0349],
[ 0.0376, -0.2445, 0.2093]]],
[[[-0.2927, -0.1899, 0.0973],
[ 0.3133, -0.2256, -0.0648],
[ 0.0768, -0.0155, -0.2793]]],
[[[-0.1462, 0.1492, 0.1009],
[ 0.0474, -0.0459, -0.0308],
[ 0.1235, 0.1081, 0.2985]]],
[[[-0.2447, 0.1707, 0.3060],
[-0.1629, 0.2025, 0.1867],
[ 0.2173, 0.1944, -0.0762]]],
[[[ 0.1163, -0.1372, 0.1721],
[ 0.0272, 0.0657, -0.1102],
[ 0.0223, 0.1550, 0.0909]]]], device='cuda:0', requires_grad=True))]
ماسک هرس تولید شده توسط تکنیک هرس انجام شده در بالا به عنوان یک بافر ماژول با نام weight_mask
(یعنی الحاق "_mask"
به پارامتر اولیه name
) ذخیره می شود.
print(list(module.named_buffers()))
خروجی:
[('weight_mask', tensor([[[[1., 0., 0.],
[1., 1., 1.],
[1., 0., 1.]]],
[[[1., 0., 1.],
[1., 1., 1.],
[0., 1., 0.]]],
[[[1., 0., 1.],
[1., 1., 1.],
[0., 1., 0.]]],
[[[1., 1., 1.],
[0., 1., 0.],
[1., 1., 0.]]],
[[[0., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[0., 0., 1.],
[1., 1., 1.],
[1., 0., 1.]]]], device='cuda:0'))]
برای اینکه حرکت رو به جلو بدون تغییر کار کند، ویژگی weight
باید وجود داشته باشد. تکنیک های هرس اجرا شده در torch.nn.utils.prune
محاسبه نسخه هرس شده وزن (با ترکیب ماسک با پارامتر اصلی) و ذخیره آنها در ویژگی weight
. توجه داشته باشید، این دیگر یک پارامتر از module نیست
, بلکه صرفاً یک ویژگی است.
print(module.weight)
خروجی:
tensor([[[[ 0.0547, -0.0000, 0.0000],
[ 0.1011, -0.0678, -0.1258],
[-0.2431, 0.0000, 0.2079]]],
[[[ 0.2246, -0.0000, 0.2456],
[-0.0275, 0.0436, 0.0349],
[ 0.0000, -0.2445, 0.0000]]],
[[[-0.2927, -0.0000, 0.0973],
[ 0.3133, -0.2256, -0.0648],
[ 0.0000, -0.0155, -0.0000]]],
[[[-0.1462, 0.1492, 0.1009],
[ 0.0000, -0.0459, -0.0000],
[ 0.1235, 0.1081, 0.0000]]],
[[[-0.0000, 0.1707, 0.3060],
[-0.1629, 0.2025, 0.1867],
[ 0.2173, 0.1944, -0.0762]]],
[[[ 0.0000, -0.0000, 0.1721],
[ 0.0272, 0.0657, -0.1102],
[ 0.0223, 0.0000, 0.0909]]]], device='cuda:0',
grad_fn=<MulBackward0>)
در نهایت، هرس قبل از هر پاس رو به جلو با استفاده از PyTorch اعمال می شود forward_pre_hooks
. به طور خاص، هنگامی که module
هرس می شود، همانطور که در اینجا انجام دادیم، forward_pre_hook
برای هر پارامتر مرتبط با آن که هرس می شود، یک عدد بدست می آورد. در این حالت، از آنجایی که ما تاکنون فقط پارامتر اصلی با نام weight
را هرس کردهایم ، تنها یک قالب وجود خواهد داشت.
print(module._forward_pre_hooks)
خروجی:
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fb1c6262590>)])
برای کاملتر شدن، اکنون میتوانیم bias را
نیز هرس کنیم تا ببینیم پارامترها، بافرها، هوکها و ویژگیهای module
تغییراتش چگونه هستند. فقط به خاطر امتحان کردن تکنیک هرس دیگری، در اینجا ما 3 ورودی کوچکتر در بایاس را با مقدار مناسب L1 هرس می کنیم، همانطور که در l1_unstructured
تابع هرس پیاده سازی شده است.
prune.l1_unstructured(module, name="bias", amount=3)
اکنون انتظار داریم که پارامترهای نامگذاری شده هر دوbias_orig
وweight_orig
(از قبل) . بافرها شامل weight_mask
و bias_mask
. نسخههای هرسشده دو تانسور بهعنوان ویژگیهای ماژول وجود خواهند داشت، و ماژول forward_pre_hooks
اکنون دو تانسور خواهد داشت .
print(list(module.named_parameters()))
خروجی:
[('weight_orig', Parameter containing:
tensor([[[[ 0.0547, -0.1490, 0.2548],
[ 0.1011, -0.0678, -0.1258],
[-0.2431, 0.0490, 0.2079]]],
[[[ 0.2246, -0.0719, 0.2456],
[-0.0275, 0.0436, 0.0349],
[ 0.0376, -0.2445, 0.2093]]],
[[[-0.2927, -0.1899, 0.0973],
[ 0.3133, -0.2256, -0.0648],
[ 0.0768, -0.0155, -0.2793]]],
[[[-0.1462, 0.1492, 0.1009],
[ 0.0474, -0.0459, -0.0308],
[ 0.1235, 0.1081, 0.2985]]],
[[[-0.2447, 0.1707, 0.3060],
[-0.1629, 0.2025, 0.1867],
[ 0.2173, 0.1944, -0.0762]]],
[[[ 0.1163, -0.1372, 0.1721],
[ 0.0272, 0.0657, -0.1102],
[ 0.0223, 0.1550, 0.0909]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.2666, 0.2486, -0.2918, 0.3310, 0.3075, -0.1668], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
خروجی:
[('weight_mask', tensor([[[[1., 0., 0.],
[1., 1., 1.],
[1., 0., 1.]]],
[[[1., 0., 1.],
[1., 1., 1.],
[0., 1., 0.]]],
[[[1., 0., 1.],
[1., 1., 1.],
[0., 1., 0.]]],
[[[1., 1., 1.],
[0., 1., 0.],
[1., 1., 0.]]],
[[[0., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[0., 0., 1.],
[1., 1., 1.],
[1., 0., 1.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 1., 0.], device='cuda:0'))]
print(module.bias)
خروجی:
tensor([-0.0000, 0.0000, -0.2918, 0.3310, 0.3075, -0.0000], device='cuda:0',
grad_fn=<MulBackward0>)
print(module._forward_pre_hooks)
خروجی:
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fb1c6262590>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7fb1c5266310>)])
هرس تکراری :
یک پارامتر یکسان در یک ماژول را می توان چندین بار هرس کرد، با تأثیر فراخوان های مختلف هرس برابر با ترکیب ماسک های مختلف اعمال شده به صورت سری است. ترکیب یک ماسک جدید با ماسک قدیمی به روشcompute_mask
PruningContainer
‘s انجام می شود.
مثلاً بگویید که اکنون میخواهیمmodule.weight را
بیشتر هرس کنیم، این بار با استفاده از هرس ساختاری در امتداد محور 0 تانسور (محور 0 مربوط به کانالهای خروجی لایه کانولوشن است و دارای ابعاد 6 برای conv1
) بر اساس کانالها. مقدار مناسب L2. این را می توان با استفاده از تابعln_structured
، با n=2
وdim=0
به دست آورد .
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0) # As we can verify, this will zero out all the connections corresponding to # 50% (3 out of 6) of the channels, while preserving the action of the # previous mask. print(module.weight)
tensor([[[[ 0.0000, -0.0000, 0.0000],
[ 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000]]],
[[[ 0.2246, -0.0000, 0.2456],
[-0.0275, 0.0436, 0.0349],
[ 0.0000, -0.2445, 0.0000]]],
[[[-0.2927, -0.0000, 0.0973],
[ 0.3133, -0.2256, -0.0648],
[ 0.0000, -0.0155, -0.0000]]],
[[[-0.0000, 0.0000, 0.0000],
[ 0.0000, -0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000]]],
[[[-0.0000, 0.1707, 0.3060],
[-0.1629, 0.2025, 0.1867],
[ 0.2173, 0.1944, -0.0762]]],
[[[ 0.0000, -0.0000, 0.0000],
[ 0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000]]]], device='cuda:0',
grad_fn=<MulBackward0>)
قالب مربوطه اکنون از نوع torch.nn.utils.prune.PruningContainer
خواهد بود و سابقه هرس اعمال شده روی پارامترweight
را ذخیره می کند.
for hook in module._forward_pre_hooks.values(): if hook._tensor_name == "weight": # select out the correct hook break print(list(hook)) # pruning history in the container
خروجی:
[<torch.nn.utils.prune.RandomUnstructured object at 0x7fb1c6262590>, <torch.nn.utils.prune.LnStructured object at 0x7fb1b0f90110>]
سریال سازی یک مدل هرس شده :
تمام تانسورهای مربوطه، از جمله بافرهای ماسک و پارامترهای اصلی که برای محاسبه تانسورهای هرس شده استفاده میشوند، در مدل state_dict
ذخیره میشوند و بنابراین میتوان به راحتی سریالسازی و در صورت نیاز ذخیره کرد.
print(model.state_dict().keys())
خروجی:
odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
پارامترسازی مجدد هرس را حذف کنید :
برای دائمی کردن هرس، پارامترسازی مجدد را بر حسب weight_orig
و weight_mask
حذف کرده و می توانیم از forward_pre_hook
remove
عملکردtorch.nn.utils.prune
استفاده کنیم . توجه داشته باشید که این کار هرس را خنثی نمی کند، گویی هرگز اتفاق نیفتاده است. در عوض، با تخصیص مجدد پارامتر weight
به پارامترهای مدل، در نسخه هرس شده آن، به سادگی آن را دائمی می کند.
قبل از حذف پارامترسازی مجدد:
print(list(module.named_parameters()))
خروجی:
[('weight_orig', Parameter containing:
tensor([[[[ 0.0547, -0.1490, 0.2548],
[ 0.1011, -0.0678, -0.1258],
[-0.2431, 0.0490, 0.2079]]],
[[[ 0.2246, -0.0719, 0.2456],
[-0.0275, 0.0436, 0.0349],
[ 0.0376, -0.2445, 0.2093]]],
[[[-0.2927, -0.1899, 0.0973],
[ 0.3133, -0.2256, -0.0648],
[ 0.0768, -0.0155, -0.2793]]],
[[[-0.1462, 0.1492, 0.1009],
[ 0.0474, -0.0459, -0.0308],
[ 0.1235, 0.1081, 0.2985]]],
[[[-0.2447, 0.1707, 0.3060],
[-0.1629, 0.2025, 0.1867],
[ 0.2173, 0.1944, -0.0762]]],
[[[ 0.1163, -0.1372, 0.1721],
[ 0.0272, 0.0657, -0.1102],
[ 0.0223, 0.1550, 0.0909]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.2666, 0.2486, -0.2918, 0.3310, 0.3075, -0.1668], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
خروجی:
[('weight_mask', tensor([[[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]],
[[[1., 0., 1.],
[1., 1., 1.],
[0., 1., 0.]]],
[[[1., 0., 1.],
[1., 1., 1.],
[0., 1., 0.]]],
[[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]],
[[[0., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 1., 0.], device='cuda:0'))]
print(module.weight)
خروجی:
tensor([[[[ 0.0000, -0.0000, 0.0000],
[ 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000]]],
[[[ 0.2246, -0.0000, 0.2456],
[-0.0275, 0.0436, 0.0349],
[ 0.0000, -0.2445, 0.0000]]],
[[[-0.2927, -0.0000, 0.0973],
[ 0.3133, -0.2256, -0.0648],
[ 0.0000, -0.0155, -0.0000]]],
[[[-0.0000, 0.0000, 0.0000],
[ 0.0000, -0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000]]],
[[[-0.0000, 0.1707, 0.3060],
[-0.1629, 0.2025, 0.1867],
[ 0.2173, 0.1944, -0.0762]]],
[[[ 0.0000, -0.0000, 0.0000],
[ 0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000]]]], device='cuda:0',
grad_fn=<MulBackward0>)
پس از حذف پارامترسازی مجدد:
prune.remove(module, 'weight') print(list(module.named_parameters()))
خروجی:
[('bias_orig', Parameter containing:
tensor([-0.2666, 0.2486, -0.2918, 0.3310, 0.3075, -0.1668], device='cuda:0',
requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0000, -0.0000, 0.0000],
[ 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000]]],
[[[ 0.2246, -0.0000, 0.2456],
[-0.0275, 0.0436, 0.0349],
[ 0.0000, -0.2445, 0.0000]]],
[[[-0.2927, -0.0000, 0.0973],
[ 0.3133, -0.2256, -0.0648],
[ 0.0000, -0.0155, -0.0000]]],
[[[-0.0000, 0.0000, 0.0000],
[ 0.0000, -0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000]]],
[[[-0.0000, 0.1707, 0.3060],
[-0.1629, 0.2025, 0.1867],
[ 0.2173, 0.1944, -0.0762]]],
[[[ 0.0000, -0.0000, 0.0000],
[ 0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000]]]], device='cuda:0', requires_grad=True))]
print(list(module.named_buffers()))
خروجی:
[('bias_mask', tensor([0., 0., 1., 1., 1., 0.], device='cuda:0'))]
هرس چند پارامتر در یک مدل :
با تعیین تکنیک و پارامترهای هرس مورد نظر، می توانیم به راحتی چندین تانسور را در یک شبکه، شاید با توجه به نوع آنها، هرس کنیم، همانطور که در این مثال خواهیم دید.
new_model = LeNet() for name, module in new_model.named_modules(): # prune 20% of connections in all 2D-conv layers if isinstance(module, torch.nn.Conv2d): prune.l1_unstructured(module, name='weight', amount=0.2) # prune 40% of connections in all linear layers elif isinstance(module, torch.nn.Linear): prune.l1_unstructured(module, name='weight', amount=0.4) print(dict(new_model.named_buffers()).keys()) # to verify that all masks exist
خروجی:
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])
هرس فراگیر :
تا کنون، ما فقط به آنچه معمولاً به عنوان هرس «محلی» گفته میشود، یعنی تمرین هرس کردن تانسورها در یک مدل یک به یک، با مقایسه آمار (قدر وزن، فعالسازی، گرادیان و غیره) هر ورودی منحصراً نگاه کردیم. به ورودی های دیگر در آن تانسور. با این حال، یک تکنیک رایج و شاید قدرتمندتر این است که مدل را به یکباره هرس کنید، با حذف (به عنوان مثال) کمترین 20٪ از اتصالات در کل مدل، به جای حذف کمترین 20٪ از اتصالات در هر لایه. این احتمالاً منجر به درصدهای مختلف هرس در هر لایه می شود. بیایید ببینیم چگونه با استفاده global_unstructured
از torch.nn.utils.prune
انجام میشود.
model = LeNet() parameters_to_prune = ( (model.conv1, 'weight'), (model.conv2, 'weight'), (model.fc1, 'weight'), (model.fc2, 'weight'), (model.fc3, 'weight'), ) prune.global_unstructured( parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2, )
اکنون میتوانیم پراکندگی ایجاد شده در هر پارامتر هرس شده را بررسی کنیم که در هر لایه برابر با 20% نخواهد بود. با این حال، پراکندگی فراگیر (تقریبا) 20٪ خواهد بود.
print( "Sparsity in conv1.weight: {:.2f}%".format( 100. * float(torch.sum(model.conv1.weight == 0)) / float(model.conv1.weight.nelement()) ) ) print( "Sparsity in conv2.weight: {:.2f}%".format( 100. * float(torch.sum(model.conv2.weight == 0)) / float(model.conv2.weight.nelement()) ) ) print( "Sparsity in fc1.weight: {:.2f}%".format( 100. * float(torch.sum(model.fc1.weight == 0)) / float(model.fc1.weight.nelement()) ) ) print( "Sparsity in fc2.weight: {:.2f}%".format( 100. * float(torch.sum(model.fc2.weight == 0)) / float(model.fc2.weight.nelement()) ) ) print( "Sparsity in fc3.weight: {:.2f}%".format( 100. * float(torch.sum(model.fc3.weight == 0)) / float(model.fc3.weight.nelement()) ) ) print( "Global sparsity: {:.2f}%".format( 100. * float( torch.sum(model.conv1.weight == 0) + torch.sum(model.conv2.weight == 0) + torch.sum(model.fc1.weight == 0) + torch.sum(model.fc2.weight == 0) + torch.sum(model.fc3.weight == 0) ) / float( model.conv1.weight.nelement() + model.conv2.weight.nelement() + model.fc1.weight.nelement() + model.fc2.weight.nelement() + model.fc3.weight.nelement() ) ) )
خروجی:
Sparsity in conv1.weight: 0.00%
Sparsity in conv2.weight: 8.91%
Sparsity in fc1.weight: 22.05%
Sparsity in fc2.weight: 12.11%
Sparsity in fc3.weight: 10.48%
Global sparsity: 20.00%
گسترش با توابع هرس سفارشیtorch.nn.utils.prune
برای پیاده سازی تابع هرس خود، می توانید ماژولnn.utils.prune
را با زیر کلاس بندی کلاس پایهBasePruningMethod
گسترش دهید، به همان روشی که سایر روش های هرس انجام می دهند. کلاس پایه متدهای زیر را برای شما پیاده سازی می کند: __call__
, apply_mask
, apply
, prune
و remove
. فراتر از برخی موارد خاص، لازم نیست این روش ها را برای تکنیک هرس جدید خود دوباره اجرا کنید. با این حال، شما باید __init__
(سازنده) و compute_mask
(دستورالعملهایی در مورد نحوه محاسبه ماسک برای تانسور دادهشده با توجه به منطق تکنیک هرس خود) را پیادهسازی کنید. علاوه بر این، شما باید مشخص کنید که این تکنیک کدام نوع هرس را اجرا می کند (گزینه های پشتیبانی شده عبارتند از global
، structured
وunstructured
). این برای تعیین نحوه ترکیب ماسک ها در مواردی که هرس به طور مکرر اعمال می شود، مورد نیاز است. به عبارت دیگر، هنگام هرس کردن یک پارامتر از پیش هرس شده، انتظار می رود روش هرس فعلی روی قسمت هرس نشده پارامتر عمل کند. تعیین PRUNING_TYPEPruningContainer
کردن (که کاربرد تکراری ماسکهای هرس را انجام میدهد) را قادر میسازد تا تکهای از پارامتر مورد نظر را به درستی شناسایی کند.
برای مثال، فرض کنید که میخواهید یک تکنیک هرس را پیادهسازی کنید که هر ورودی دیگر را در یک تانسور (یا – اگر تانسور قبلاً هرس شده است – در قسمت هرس نشده باقیمانده تانسور) هرس کند. این PRUNING_TYPE='unstructured'
به این دلیل است که بر روی اتصالات جداگانه در یک لایه عمل می کند و نه بر روی کل واحدها / کانال ها ( 'structured'
) یا در پارامترهای مختلف ( 'global'
).
class FooBarPruningMethod(prune.BasePruningMethod): """Prune every other entry in a tensor """ PRUNING_TYPE = 'unstructured' def compute_mask(self, t, default_mask): mask = default_mask.clone() mask.view(-1)[::2] = 0 return mask
حال، برای اعمال این پارامتر در یک nn.Module
، باید یک تابع ساده نیز ارائه کنید که متد را نمونهسازی کرده و آن را اعمال میکند.
def foobar_unstructured(module, name): """Prunes tensor corresponding to parameter called `name` in `module` by removing every other entry in the tensors. Modifies module in place (and also return the modified module) by: 1) adding a named buffer called `name+'_mask'` corresponding to the binary mask applied to the parameter `name` by the pruning method. The parameter `name` is replaced by its pruned version, while the original (unpruned) parameter is stored in a new parameter named `name+'_orig'`. Args: module (nn.Module): module containing the tensor to prune name (string): parameter name within `module` on which pruning will act. Returns: module (nn.Module): modified (i.e. pruned) version of the input module Examples: >>> m = nn.Linear(3, 4) >>> foobar_unstructured(m, name='bias') """ FooBarPruningMethod.apply(module, name) return module
بیایید آن را امتحان کنیم!
model = LeNet() foobar_unstructured(model.fc3, name='bias') print(model.fc3.bias_mask)
خروجی:
tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
پایان
دیدگاهتان را بنویسید