|
1.前言注册机制是一种在编程中常见的设计模式,它允许程序在运行时动态地将函数、类或其他对象注册到某个中心管理器中,以便随后可以使用这些注册的对象。在Python中,注册机制通常用于实现插件系统、扩展性架构以及回调函数的管理。通俗的说,当我们的项目中需要成批量的函数和类,且这些函数和类功能上相似或并行时,为了方便管理,我们可以把这些指定的函数和类整合到一个字典。我们可以用函数名或类名作为字典的key,也可用使用自定义的名字作为key,对应的函数或类作为value。构建这样一个字典的过程就是注册(Registry),Python引入注册器机制保证了这个字典可以自动维护,增加或删除新的函数或类时,不需要手动去修改字典。Python注册器机制本质上是用装饰器(decorator)来实现的。下面我们将从基本的Python函数出发,逐步介绍装饰器,最后来学习注册器。1.理解Python函数1.1函数的不同调用首先定义一个函数,然后用不同的方式调用它。deffoo():return"IamLe0v1n"print(f"foo():{foo()}")fn=foo#这里foo后面没有小括号,不是函数调用,而是将foo函数赋值给变量fnprint(f"fn():{fn()}")12345678foo():IamLe0v1nfn():IamLe0v1n121.2函数中的函数在函数体中还可以定义函数(函数的函数😂),只是这个函数体内的函数不能在函数体外被直接调用:deffoo():print("foo函数正在运行...")#定义函数中的函数defbar():return"foo.bar函数正在运行..."defbam():return"foo.bam函数正在运行..."#调用函数中的函数print(bar())print(bam())print("foo函数即将结束!")if__name__=="__main__":foo()123456789101112131415161718foo函数正在运行...foo.bar函数正在运行...foo.bam函数正在运行...foo函数即将结束!1234上面的结果没有什么意思,但如果我们直接调用函数的函数,会发生什么?#如果我们调用函数中的函数try:bar()exceptExceptionase:print(f"报错啦:{e}")try:bam()exceptExceptionase:print(f"报错啦:{e}")12345678910报错啦:name'bar'isnotdefined报错啦:name'bam'isnotdefined121.3函数中函数的外部调用函数体内的函数虽然不能在函数体外被直接调用,但是可以将它们返回出来。deffoo(choice='bar'):print("foo函数正在运行...")#定义函数中的函数defbar():return"foo.bar函数正在运行..."defbam():return"foo.bam函数正在运行..."print("foo函数即将结束!")ifchoice=='bar':returnbarelifchoice=='bam':returnbamelse:raiseNotImplementedError("choice必须是bar或bam!")if__name__=="__main__":fn1=foo(choice='bar')fn2=foo(choice='bam')print(fn1)print(fn2)print(fn1())print(fn2())123456789101112131415161718192021222324252627foo函数正在运行...foo函数即将结束!foo函数正在运行...foo函数即将结束!.barat0x000001D72F1ECE50>.bamat0x000001D72F2558B0>foo.bar函数正在运行...foo.bam函数正在运行...12345678注意到返回的bar和bam后面没有小括号,那它就可以被传递,并且可以赋值给别的变量而被执行,如果有小括号,那它就会被执行。1.4函数作为函数参数我们还可以将函数作为参数传递给另一个函数:deffoo():return"Iamfoo"defbar(fn):print("Iambar")print(fn())if__name__=="__main__":bar(foo)print()try:bar(foo())exceptExceptionase:print(f"报错啦:{e}")12345678910111213141516IambarIamfooIambar报错啦:'str'objectisnotcallable123451.5函数的包装有了这样的印象之后,我们再写一个更加复杂一点的例子:defdecorator(fn):defwrapper():print("----------函数调用前----------")fn()#调用函数print("----------函数调用后----------")returnwrapperdeffoo():print("Iamfoo!")if__name__=="__main__":#直接调用函数print("直接调用函数:",end="")foo()print()#调用装饰器包装后函数fn=decorator(foo)#将foo函数用装饰器包装->fnprint("调用包装后的foo函数:")fn()#调用包装后的foo函数print()1234567891011121314151617181920212223直接调用函数:Iamfoo!调用包装后的foo函数:----------函数调用前----------Iamfoo!----------函数调用后----------1234562.理解Python装饰器2.1定义Python装饰器是一种高阶函数,用于修改其他函数的行为或添加额外功能。装饰器本质上是一个函数,它接受一个函数作为参数,然后返回一个新的函数,通常扩展了或修改了原始函数的行为。2.2装饰器的初步使用上一节的最后一个例子我们封装了一个函数foo,并且用另一个函数decorate去修改这个函数的行为,这个功能其实就是Python装饰器(Decorate)所做的事情,只是我们以函数的形式显式的写了出来。Python中的装饰器提供了更简洁的方式来实现同样的功能,装饰器的写法是在被装饰的函数前使用@装饰器名。现在我们用装饰器的写法来实现同样的功能:defdecorator(fn):defwrapper():print("----------函数调用前----------")fn()print("----------函数调用后----------")returnwrapper@decorator#@装饰器名称deffoo():print("Iamfoo")if__name__=="__main__":#直接调用被装饰的函数print("直接调用函数:")foo()print()print(f"函数的名称:{foo.__name__}")123456789101112131415161718192021直接调用函数:----------函数调用前----------Iamfoo----------函数调用后----------函数的名称:wrapper123456可以看到,当我们使用@装饰器名称对foo函数进行装饰后,直接调用foo函数就可以达到之前的效果,这无疑更加方便。decorate:英[ˈdekəreɪt]美[ˈdekəreɪt]v.装饰;装潢;点缀;装点;油漆;粉刷;糊墙纸;授给(某人)勋章(或奖章);wrapper:英[ˈræpə(r)]美[ˈræpər]n.封套;(食品等的)包装材料;包装纸;封皮;Q:为什么@decorate不能加()?Adecorator不能加()是因为@decorator是用来应用装饰器的语法糖,而不是直接调用装饰器函数。装饰器通常是函数,它接受一个函数作为参数,然后返回一个新的函数,用于包装原始函数。因此,在使用装饰器时,应该省略()。当我们使用@decorator这种语法时,Python实际上会将被装饰的函数(foo)作为参数传递给decorator函数。然后,decorator函数返回一个包装了foo的新函数wrapper,而不是直接调用decorator。这允许我们在foo函数的前后添加额外的操作,而不需要显式地调用decorator。如果我们在@decorator后面加上(),就变成了直接调用decorator函数,而不是应用装饰器。这通常不是我们的意图,因为我们的目标是装饰foo函数,而不是调用decorator。正确的做法是使用@decorator而不带括号,如我们在提供的示例中所示。2.3装饰器的问题与此同时,我们也发现了一个问题:当我们输出被装饰函数的名字时,它被wrapper函数替代了。如果我们需要获取调用函数的名称,此时输出wrapper是不合适的。Python为了解决这个问题,提供了一个简单的函数functools.wraps。fromfunctoolsimportwrapsdefdecorator(fn)wraps(fn)defwrapper():print("----------函数调用前----------")fn()print("----------函数调用后----------")returnwrapper@decoratordeffoo():print("Iamfoo")if__name__=="__main__":print("直接调用函数:")foo()print()print(f"函数的名称:{foo.__name__}")1234567891011121314151617181920212223直接调用函数:----------函数调用前----------Iamfoo----------函数调用后----------函数的名称:foo123456Q:为什么@wraps要加(fn)?A:@wraps装饰器用于保留原始函数的元信息,例如函数的名称(__name__)、文档字符串(__doc__)等,以确保包装函数在行为上与原始函数一致,并且在使用工具或调试时提供准确的信息。@wraps需要传递原始函数(fn)作为参数,以便它知道应该保留哪个函数的元信息。在我们的示例中,@wraps(fn)装饰了wrapper函数,其中fn是被装饰的原始函数,也就是foo。这告诉@wraps装饰器将wrapper的元信息设置为与foo相关的元信息,以确保wrapper在元信息上与foo一致。如果不使用@wraps装饰器,wrapper函数将继承decorator函数的元信息,而不是foo函数的元信息,这可能导致在使用foo时出现不一致或不正确的元信息。所以,@wraps(fn)用于解决这个问题,确保包装函数的正确元信息。2.4类的装饰器不仅仅只有函数可以构建装饰器,类也可以用于构建装饰器,在构建装饰器类时,需要将原本装饰器函数的部分实现于__call__函数中即可:fromfunctoolsimportwrapsclassDecorate:def__init__(self,fn)->None:self.fn=fndef__call__(self)wraps(self.fn)defwrapper(*args,**kwargs):print("----------函数调用前----------")self.fn(*args,**kwargs)print("----------函数调用后----------")returnwrapper@Decorate#用类来装饰函数,那么函数也变为了类deffoo(param1,param2):print(f"Iamfoo.\n"f"Myparametersare:\n"f"param1:{param1}|param2:{param2}")if__name__=="__main__":#实例化类对象obj=foo()#调用对象的方法obj("参数1","参数2")1234567891011121314151617181920212223242526272829----------函数调用前----------Iamfoo.Myparametersare:param1:参数1|param2:参数2----------函数调用后----------123453.Python注册器——Registry3.1实现一个手动注册器有了装饰器的基础之后,我们现在要走入注册器的世界了。Python的注册器本质上就是用装饰器的原理实现的。Registry提供了字符串到函数或类的映射,这个映射会被整合到一个字典中,开发者只要输入输入相应的字符串(为函数或类起的名字)和参数,就能获得一个函数或初始化好的类。为了说明Registry的好处,我们首先看一下用一个字典存放字符串到函数的映射:deffoo():...deffn(x):returnx**2classExampleClass:...if__name__=="__main__":#创建注册字典register=dict()#开始为函数和类进行注册register[foo.__name__]=fooregister[fn.__name__]=fnregister[ExampleClass.__name__]=ExampleClassprint(register)123456789101112131415161718192021{'foo':,'':at0x000001D730752D30>,'ExmpleClass':}123虽然这样也可以创建一个注册器,但这样做的缺点是我们需要手动维护register这个字典,当增加或删除新的函数或类时,我们需要手动修改register这个字典,因此我们需要一个可以自动维护的字典,在我们定义一个函数或类的时候就自动把它整合到字典中。为了达到这一目的,这里就使用到了装饰器,在装饰器中将我们新定义的函数或类存放的字典中,这个过程我们称之为注册。3.2实现一个半自动注册器3.2.1代码这里我们需要定义一个装饰器类Register,其中核心部分就是成员函数register,它作为一个装饰器函数:classRegister(dict):def__init__(self,*args,**kwargs):super(Register,self).__init__(*args,**kwargs)self._dict=dict()#创建一个字典用于保存注册的可调用对象defregister(self,target):defadd_item(key,value):ifkeyinself._dict:#如果key已经存在print(f"\033[31m"f"WARNING:{value.__name__}已经存在!"f"\033[0m")#进行注册,将key和value添加到字典中self[key]=valuereturnvalue#传入的target可调用-->没有给注册名-->传入的函数名或类名作为注册名ifcallable(target):#key为函数/类的名称;value为函数/类本体returnadd_item(key=target.__name__,value=target)else:#传入的target不可调用-->抛出异常raiseTypeError("\033[31mOnlysupportcallableobject,e.g.functionorclass\033[0m")def__call__(self,target):returnself.register(target)def__setitem__(self,key,value):#将键值对添加到_dict字典中self._dict[key]=valuedef__getitem__(self,key):#从_dict字典中获取注册的可调用对象returnself._dict[key]def__contains__(self,key):#检查给定的注册名是否存在于_dict字典中returnkeyinself._dictdef__str__(self):#返回_dict字典的字符串表示returnstr(self._dict)defkeys(self):#返回_dict字典中的所有键returnself._dict.keys()defvalues(self):#返回_dict字典中的所有值returnself._dict.values()defitems(self):#返回_dict字典中的所有键值对returnself._dict.items()if__name__=="__main__":register_obj=Register()@register_obj#不用再register_obj.register了deffn1_add(a,b):returna+b@register_obj#不用再register_obj.register了deffn2_subject(a,b):returna-b@register_obj#不用再register_obj.register了deffn3_multiply(a,b):returna*b@register_obj#不用再register_obj.register了deffn4_divide(a,b):returna/b#我们再重复定义一个函数@register_obj#不用再register_obj.register了deffn2_subject(a,b):returnb-a#尝试使用register方法注册不可调用的对象try:register_obj.register("传入字符串,它是不可调用的")exceptExceptionase:print(f"报错啦:{e}")print("所有函数均已注册!\n")#我们查看一个注册器中有哪些元素print(f"\033[34mkey\t\tvalue\033[0m")fork,vinregister_obj.items():#fork,vinregister_obj._dict.items()print(f"{k}:\t{v}")12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182833.3实现一个全自动注册器3.3.1代码实现如果不想手动调用register()函数,可以在Register类中添加一个__call__()函数:classRegister(dict):def__init__(self,*args,**kwargs):super(Register,self).__init__(*args,**kwargs)self._dict=dict()#创建一个字典用于保存注册的可调用对象defregister(self,target):defadd_item(key,value):ifkeyinself._dict:#如果key已经存在print(f"\033[34m"f"WARNING:{value.__name__}已经存在!"f"\033[0m")#进行注册,将key和value添加到字典中self[key]=valuereturnvalue#传入的target可调用-->没有给注册名-->传入的函数名或类名作为注册名ifcallable(target):#key为函数/类的名称;value为函数/类本体returnadd_item(key=target.__name__,value=target)else:#传入的target不可调用-->抛出异常raiseTypeError("\033[31mOnlysupportcallableobject,e.g.functionorclass\033[0m")def__call__(self,target):returnself.register(target)def__setitem__(self,key,value):#将键值对添加到_dict字典中self._dict[key]=valuedef__getitem__(self,key):#从_dict字典中获取注册的可调用对象returnself._dict[key]def__contains__(self,key):#检查给定的注册名是否存在于_dict字典中returnkeyinself._dictdef__str__(self):#返回_dict字典的字符串表示returnstr(self._dict)defkeys(self):#返回_dict字典中的所有键returnself._dict.keys()defvalues(self):#返回_dict字典中的所有值returnself._dict.values()defitems(self):#返回_dict字典中的所有键值对returnself._dict.items()if__name__=="__main__":register_obj=Register()@register_obj#不用再register_obj.register了deffn1_add(a,b):returna+b@register_obj#不用再register_obj.register了deffn2_subject(a,b):returna-b@register_obj#不用再register_obj.register了deffn3_multiply(a,b):returna*b@register_obj#不用再register_obj.register了deffn4_divide(a,b):returna/b#我们再重复定义一个函数@register_obj#不用再register_obj.register了deffn2_subject(a,b):returnb-a#尝试使用register方法注册不可调用的对象try:register_obj("传入字符串,它是不可调用的")#register_obj.register("传入字符串,它是不可调用的")#因为我们实现了__call__()函数exceptExceptionase:print(f"报错啦:{e}")print("\n所有函数均已注册!\n")#我们查看一个注册器中有哪些元素print(f"\033[34mkey\t\tvalue\033[0m")fork,vinregister_obj.items():#fork,vinregister_obj._dict.items()print(f"{k}:\t{v}")1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283843.3.2代码分析高清图片链接:图片+源文件Register类继承了内置的dict类,并在其构造函数中初始化一个名为_dict的字典,用于保存注册的可调用对象。register方法用于注册可调用对象。它接受一个参数target,这可以是可调用对象或者是一个注册名。如果target是可调用对象,它会将函数或类名作为注册名。如果target不可调用,它会将传入的注册名与传入的可调用对象关联。add_item内部函数检查可调用对象是否可被调用,如果不可调用会引发异常。它还检查注册名是否已存在,如果存在则发出警告。⭐️__call__方法允许对象实例(register_obj)像函数一样被调用,实际上是将调用委托给register方法。其余的魔法方法(__setitem__,__getitem__,__contains__,__str__,keys(),values(),items())覆盖了字典的行为,以便访问和管理内部的_dict字典。在主程序中,Register类的一个实例register_obj被创建。使用装饰器@register_obj,将多个函数注册到register_obj实例中,每个函数都有一个注册名。如果函数的注册名已经存在,会打印警告信息。最后,程序输出注册器中的所有注册名和可调用对象。4.Python注册器在深度学习中的应用4.1应用场景在深度学习和机器学习中,注册器模式可以有一些有趣的应用,尤其是在构建自定义层、损失函数、优化器或其他模型组件时。以下是在深度学习中使用注册器的一些潜在应用示例:自定义层和模型自定义损失函数自定义优化器数据预处理步骤回调函数这些示例说明了如何使用注册器模式来管理和选择深度学习中的各种组件,从而使模型的构建和训练更加灵活和可配置。通过注册器,我们可以轻松地扩展和定制深度学习模型的各个部分。4.2自定义层和模型我们可以使用注册器来注册自定义神经网络层或模型结构。这在构建自定义神经网络架构时非常有用。例如,我们可以构建一个注册器,用于注册各种自定义层,如卷积层、循环层等。然后,我们可以在模型构建过程中按名称选择并使用这些自定义层。importtorch.nnasnn#实现一个注册器classLayerRegistry:def__init__(self):self.layers=dict()defregister(self,layer_name):#让装饰器接受layer参数defdecorator(layer):#开始注册self.layers[layer_name]=layerreturnlayer#返回注册的层returndecoratordefget_layer(self,layer_name):iflayer_nameinself.layers:returnself.layers[layer_name]else:raiseKeyError(f"未注册的层'{layer_name}'.")#实例化自定义层注册器layer_register=LayerRegistry()#自定义层类@layer_register.register("ConvBNReLU")classConvBNReLU(nn.Module):def__init__(self,in_channels,out_channels,kernel_size,stride,padding,*args,**kwargs):super(ConvBNReLU,self).__init__()self.layers=nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,*args,**kwargs),nn.BatchNorm2d(out_channels),nn.ReLU())defforward(self,x):returnself.layers(x)if__name__=="__main__":#在创建层的使用可以使用注册器中的层example_layer=layer_register.get_layer("ConvBNReLU")#创建具体的层实例specific_layer=example_layer(in_channels=3,out_channels=64,kernel_size=3,stride=1,padding=1)#打印具体层的信息print(specific_layer)1234567891011121314151617181920212223242526272829303132333435363738394041424344454647ConvBNReLU((layers):Sequential((0):Conv2d(3,64,kernel_size=(3,3),stride=(1,1),padding=(1,1))(1):BatchNorm2d(64,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)(2):ReLU()))1234567但是这样似乎没有什么意思,是的,这样的确没有什么意义,一般注册器和配置文件一起用的,下面我们看一下例子:importtorch.nnasnnclassLayerRegistry:#实现一个注册器def__init__(self):self.layers=dict()defregister(self,layer_name):#让装饰器接受layer参数defdecorator(layer):#开始注册self.layers[layer_name]=layerreturnlayer#返回注册的层returndecoratordefget_layer(self,layer_name):iflayer_nameinself.layers:returnself.layers[layer_name]else:raiseKeyError(f"未注册的层'{layer_name}'.")#实例化自定义层注册器layer_register=LayerRegistry()@layer_register.register("ConvBNReLU")classConvBNReLU(nn.Module):#自定义层类def__init__(self,in_channels,out_channels,kernel_size,stride,padding):super(ConvBNReLU,self).__init__()self.layers=nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding),nn.BatchNorm2d(out_channels),nn.ReLU())defforward(self,x):returnself.layers(x)#继续注册其他模块@layer_register.register("BatchNorm2d")classBatchNorm2d(nn.Module):def__init__(self,num_features,*args,**kwargs):super(BatchNorm2d,self).__init__()self.bn=nn.BatchNorm2d(num_features,*args,**kwargs)defforward(self,x):returnself.bn(x)@layer_register.register("ReLU")classReLU(nn.Module):def__init__(self,*args,**kwargs):super(ReLU,self).__init__()self.relu=nn.ReLU(*args,**kwargs)defforward(self,x):returnself.relu(x)@layer_register.register("MaxPooling")classMaxPooling(nn.Module):def__init__(self,kernel_size,stride=1,padding=0):super(MaxPooling,self).__init__()self.maxpool=nn.MaxPool2d(kernel_size,stride=stride,padding=padding)defforward(self,x):returnself.maxpool(x)@layer_register.register("AvgPooling")classAvgPooling(nn.Module):def__init__(self,kernel_size,stride=1,padding=0):super(AvgPooling,self).__init__()self.avgpool=nn.AvgPool2d(kernel_size,stride=stride,padding=padding)defforward(self,x):returnself.avgpool(x)#定义网络配置(cfg)来构建完整的网络cfg=[('ConvBNReLU',3,64,3,1),#传递4个参数('MaxPooling',2,2,0),('ConvBNReLU',64,128,3,1),('MaxPooling',2,2,0),('ConvBNReLU',128,256,3,1),('AvgPooling',4,1,0),]#构建网络classCustomNet(nn.Module):def__init__(self,cfg):super(CustomNet,self).__init__()self.layers=nn.ModuleList()in_channels=3#输入通道数forlayer_cfgincfg:layer_name,*layer_params=layer_cfglayer=layer_register.get_layer(layer_name)iflayer_namein['ConvBNReLU','BatchNorm2d']:self.layers.append(layer(in_channels,*layer_params))in_channels=layer_params[1]else:self.layers.append(layer(*layer_params))#创建完整的网络实例custom_net=CustomNet(cfg)#打印网络结构print(custom_net)123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116CustomNet((layers):ModuleList((0):ConvBNReLU((layers):Sequential((0):Conv2d(3,3,kernel_size=(64,64),stride=(3,3),padding=(1,1))(1):BatchNorm2d(3,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)(2):ReLU()))(1):MaxPooling((maxpool):MaxPool2d(kernel_size=2,stride=2,padding=0,dilation=1,ceil_mode=False))(2):ConvBNReLU((layers):Sequential((0):Conv2d(64,64,kernel_size=(128,128),stride=(3,3),padding=(1,1))(1):BatchNorm2d(64,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)(2):ReLU()))(3):MaxPooling((maxpool):MaxPool2d(kernel_size=2,stride=2,padding=0,dilation=1,ceil_mode=False))(4):ConvBNReLU((layers):Sequential((0):Conv2d(128,128,kernel_size=(256,256),stride=(3,3),padding=(1,1))...(avgpool):AvgPool2d(kernel_size=4,stride=1,padding=0))))123456789101112131415161718192021222324252627282930我们可以看到,模型根据cfg变量搭建了模型,那么我们就可以通过读取.cfg/.config/.yaml/.yml/.json等格式的配置文件,从而非常方便的搭建模型或者修改模型。4.3自定义损失函数深度学习任务通常需要特定的损失函数。通过使用注册器,我们可以注册和管理各种自定义损失函数,并在模型编译时选择要使用的损失函数。我们可以仿照上面的例子,使用配置文件的方式注册损失函数。首先,我们需要创建一个类似的注册器类,然后在配置文件中定义不同的损失函数以及它们的参数。下面是一个示例:首先,创建一个注册器类LossRegistry:classLossRegistry:def__init__(self):self.losses=dict()defregister(self,loss_name):defdecorator(loss_fn):self.losses[loss_name]=loss_fnreturnloss_fnreturndecoratordefget_loss(self,loss_name):ifloss_nameinself.losses:returnself.losses[loss_name]else:raiseKeyError(f"未注册的损失函数'{loss_name}'.")#实例化自定义损失函数注册器loss_register=LossRegistry()123456789101112131415161718然后,我们可以创建不同的损失函数并注册它们:@loss_register.register("MSE")classMeanSquaredErrorLoss(nn.Module):defforward(self,input,target):returnnn.functional.mse_loss(input,target)@loss_register.register("CE")classCrossEntropyLoss(nn.Module):defforward(self,input,target):returnnn.functional.cross_entropy(input,target)123456789接下来,我们可以使用配置文件定义不同的损失函数:loss_config=[('MSE',None),#使用默认参数('CE',None),#使用默认参数]#通过配置文件构建损失函数列表loss_functions=[loss_register.get_loss(loss_name)forloss_name,_inloss_config]loss_fn_1=loss_functions[0]()loss_fn_2=loss_functions[1]()print(loss_fn_1)print(loss_fn_2)123456789101112MeanSquaredErrorLoss()CrossEntropyLoss()12现在,我们可以使用loss_functions列表中的损失函数来定义我们的损失函数组合。这样,我们可以根据配置文件轻松切换不同的损失函数,而无需更改网络代码。4.4自定义优化器与损失函数一样,我们可以注册自定义优化器并在模型编译时选择要使用的优化器。这可以让我们尝试不同的优化算法,并根据任务选择最合适的优化器。我们可以使用与损失函数注册类似的方法来注册不同的优化器。首先,创建一个注册器类OptimizerRegistry,然后在配置文件中定义不同的优化器及其参数。以下是一个示例:首先,创建一个注册器类OptimizerRegistry:importtorch.optimasoptimclassOptimizerRegistry:def__init__(self):self.optimizers=dict()defregister(self,optimizer_name):defdecorator(optimizer_fn):self.optimizers[optimizer_name]=optimizer_fnreturnoptimizer_fnreturndecoratordefget_optimizer(self,optimizer_name,model_parameters,*args,**kwargs):ifoptimizer_nameinself.optimizers:returnself.optimizers[optimizer_name](model_parameters,*args,**kwargs)else:raiseKeyError(f"未注册的优化器'{optimizer_name}'.")#实例化自定义优化器注册器optimizer_register=OptimizerRegistry()1234567891011121314151617181920然后,我们可以创建不同的优化器并注册它们:@optimizer_register.register("SGD")classSGDOptimizer:def__init__(self,model_parameters,lr,momentum):self.optimizer=optim.SGD(model_parameters,lr=lr,momentum=momentum)defstep(self):self.optimizer.step()defzero_grad(self):self.optimizer.zero_grad()@optimizer_register.register("Adam")classAdamOptimizer:def__init__(self,model_parameters,lr,betas):self.optimizer=optim.Adam(model_parameters,lr=lr,betas=betas)defstep(self):self.optimizer.step()defzero_grad(self):self.optimizer.zero_grad()123456789101112131415161718192021接下来,我们可以使用配置文件定义不同的优化器:optimizer_config=[('SGD',{'lr':0.01,'momentum':0.9}),('Adam',{'lr':0.001,'betas'0.9,0.999)})]#通过配置文件构建优化器列表optimizers=[optimizer_register.get_optimizer(optimizer_name,model_parameters,**params)foroptimizer_name,paramsinoptimizer_config]foroptimizerinoptimizersptimizer.zero_grad()#清空梯度optimizer.step()#下一步1234567891011现在,我们可以使用optimizers列表中的不同优化器来为模型定义不同的优化器。这使我们可以根据配置文件轻松切换不同的优化器,而无需更改网络代码。4.5数据预处理步骤在深度学习中,数据预处理对于模型性能非常重要。我们可以注册各种数据预处理步骤,例如图像增强、标准化方法等,然后根据需要应用它们。数据预处理是深度学习中的重要步骤之一。我们可以使用与前面示例类似的方法来注册不同的数据预处理步骤。首先,创建一个注册器类PreprocessingRegistry,然后在配置文件中定义不同的数据预处理步骤及其参数。以下是一个示例:首先,创建一个注册器类PreprocessingRegistry:classPreprocessingRegistry:def__init__(self):self.preprocessing_steps=dict()defregister(self,step_name):defdecorator(preprocessing_fn):self.preprocessing_steps[step_name]=preprocessing_fnreturnpreprocessing_fnreturndecoratordefget_preprocessing_step(self,step_name,*args,**kwargs):ifstep_nameinself.preprocessing_steps:returnself.preprocessing_steps[step_name](*args,**kwargs)else:raiseKeyError(f"未注册的数据预处理步骤'{step_name}'.")#实例化自定义数据预处理步骤注册器preprocessing_register=PreprocessingRegistry()123456789101112131415161718然后,我们可以创建不同的数据预处理步骤并注册它们:importnumpyasnp@preprocessing_register.register("Normalize")classNormalize:def__init__(self,mean,std):self.mean=meanself.std=stddef__call__(self,data):return(data-self.mean)/self.std@preprocessing_register.register("RandomCrop")classRandomCrop:def__init__(self,crop_size):self.crop_size=crop_sizedef__call__(self,data):h,w,c=data.shapex=np.random.randint(0,h-self.crop_size)y=np.random.randint(0,w-self.crop_size)returndata[x:x+self.crop_size,y:y+self.crop_size,:]@preprocessing_register.register("Resize")classResize:def__init__(self,target_size):self.target_size=target_sizedef__call__(self,data):returncv2.resize(data,(self.target_size,self.target_size))1234567891011121314151617181920212223242526272829接下来,我们可以使用配置文件定义不同的数据预处理步骤:preprocessing_config=[('Normalize',{'mean':[0.485,0.456,0.406],'std':[0.229,0.224,0.225]}),('RandomCrop',{'crop_size':224}),('Resize',{'target_size':256})]#通过配置文件构建数据预处理步骤列表preprocessing_steps=[preprocessing_register.get_preprocessing_step(step_name,**params)forstep_name,paramsinpreprocessing_config]12345678现在,我们可以使用preprocessing_steps列表中的不同数据预处理步骤来预处理我们的数据。这使我们可以根据配置文件轻松切换不同的数据预处理步骤,而无需更改数据处理代码。#假设我们有一张原始图像original_image=cv2.imread('./lena.png')#读取原始图像#应用数据预处理步骤preprocessed_data=original_image.copy()#创建副本以保存经过预处理的数据forpreprocessing_stepinpreprocessing_steps:preprocessed_data=preprocessing_step(preprocessed_data)#preprocessed_data现在包含了经过预处理的数据print(preprocessed_data.shape)#(256,256,3)#现在可以将preprocessed_data用于深度学习模型的训练或推理123456789101112134.6回调函数在深度学习训练中,回调函数用于执行各种操作,如保存模型检查点、记录训练指标、可视化等。我们可以使用注册器来注册各种回调函数,并在模型训练时选择适当的回调函数。首先,创建一个回调函数注册器类CallbackRegistry:classCallbackRegistry:def__init__(self):self.callbacks=dict()defregister(self,callback_name):defdecorator(callback_fn):self.callbacks[callback_name]=callback_fnreturncallback_fnreturndecoratordefget_callback(self,callback_name):ifcallback_nameinself.callbacks:returnself.callbacks[callback_name]else:raiseKeyError(f"未注册的回调函数'{callback_name}'.")#实例化自定义回调函数注册器callback_register=CallbackRegistry()123456789101112131415161718接下来,我们可以创建不同的回调函数并注册它们,例如,一个用于保存模型检查点的回调函数和一个用于记录训练指标的回调函数:importos@callback_register.register("ModelCheckpoint")classModelCheckpointCallback:def__init__(self,save_dir,save_freq):self.save_dir=save_dirself.save_freq=save_freqself.best_accuracy=0.0self.model=Nonedefon_epoch_end(self,model,epoch,val_accuracy):ifval_accuracy>self.best_accuracy:self.best_accuracy=val_accuracyself.model=modelmodel_save_path=os.path.join(self.save_dir,f"best_model_epoch_{epoch}.pth")torch.save(model.state_dict(),model_save_path)@callback_register.register("RecordMetrics")classRecordMetricsCallback:def__init__(self,log_file):self.log_file=log_filedefon_epoch_end(self,epoch,train_loss,val_loss,train_accuracy,val_accuracy):withopen(self.log_file,'a')asfile:file.write(f"Epoch{epoch}-TrainLoss:{train_loss},ValLoss:{val_loss},TrainAccuracy:{train_accuracy},ValAccuracy:{val_accuracy}\n")12345678910111213141516171819202122232425现在,我们可以在模型训练时选择适当的回调函数,并在适当的时机调用它们:#在训练循环中选择适当的回调函数forepochinrange(num_epochs):train_loss,train_accuracy=train_one_epoch()val_loss,val_accuracy=validate()#在每个epoch结束时调用回调函数forcallback_name,callback_instanceincallback_instances.items():callback_instance.on_epoch_end(epoch,train_loss,val_loss,train_accuracy,val_accuracy)12345678上述示例中,我们注册了两种回调函数,一个用于保存模型检查点,另一个用于记录训练指标。然后,在训练循环中,在每个epoch结束时调用这些回调函数,以执行相应的操作。我们可以根据需要定义更多的回调函数,并根据模型训练的具体需求来选择和调用它们。知识来源【Python】Python的Registry机制
|
|