|
本文作者为 360 奇舞团前端开发工程师随着AI的火热发展,涌现了一些AI模特换装的前端工具(比如weshop网站),他们是怎么实现的呢?使用了什么技术呢?下文我们就来探索一下其实现原理。总体的实现流程如下:我们将下图中的这个模特的图片,使用Segment Anything Model在后端分割图层,然后将分割后的图层mask信息返回给前端处理。在前端中选择需要保留的图层信息(如下图中的模特的衣服图层),然后将选中的图层信息交给后端中的Stable Diffusion处理。后端使用原始图片结合选中的图层蒙版图片结合图生图的功能,可以实现weshop等网站的模特换衣等功能。本文先简单介绍一下使用SAM智能图层分割,然后主要介绍一下在前端中怎么对分割后的图层进行选择的处理流程。使用SAM识别图层首先我们需要对图层进行分割,在SAM出来之前,我们需要使用PS将模特的衣服选取出来,然后倒出衣服的模板,然后再使用其他工具进行替换。但是现在有了SAM后,我们可以对图片中的事物进去只能区分,获取各种物品的图层。Segment Anything Model(SAM)是一种尖端的图像分割模型,可以进行快速分割,为图像分析任务提供无与伦比的多功能性。SAM 的先进设计使其能够在无需先验知识的情况下适应新的图像分布和任务,这一功能称为零样本传输。SAM 使任何人都可以在不依赖标记数据的情况下为其数据创建分段掩码。要深入了解 Segment Anything 模型和 SA-1B 数据集,请访问Segment Anything 网站(https://segment-anything.com/)并查看研究论文Segment Anything(https://arxiv.org/abs/2304.02643)。我们使用SAM进行图像分割,将一个图片中的物体分割成不同的部分。defmask2rle(img):'''img:numpyarray,1-mask,0-backgroundReturnsrunlengthasstringformated'''pixels=img.T.flatten()pixels=np.concatenate([[0],pixels,[0]])runs=np.where(pixels[1:]!=pixels[:-1])[0]+1runs[1::2]-=runs[::2]return''.join(str(x)forxinruns)deftrans_anns(anns):iflen(anns)==0:returnsorted_anns=sorted(anns,key=(lambdax:x['area']),reverse=False)list=[]index=0#对每个注释进行处理foranninsorted_anns:bool_array=ann['segmentation']#将boolean类型的数组转换为int类型int_array=bool_array.astype(int)#转化为RLE格式rle=mask2rle(int_array)list.append({"index":index,"mask":rle})index+=1returnlistimage=cv2.imread('')importsyssys.path.append('')fromsegment_anythingimportsam_model_registry,SamAutomaticMaskGenerator,SamPredictor#sam模型路径sam_checkpoint=''#根据下载的模型,设置对应的类型model_type="vit_h"#device="cuda"sam=sam_model_registry[model_type](checkpoint=sam_checkpoint)#sam.to(device=device)mask_generator=SamAutomaticMaskGenerator(sam)masks=mask_generator.generate(image)#处理sam返回的图层信息mask_list=trans_anns(masks)mask_obj={"height":image.shape[0],"width":image.shape[1],"mask_list":mask_list}importjsonprint(json.dumps(mask_obj))运行以上python代码之前,需要配置sam的python环境,具体的配置描述请查看sam的官方描述。我们通过以上代码,将我们提供的图片,通过SAM处理后,返回图层分割数据。在trans_anns方法中,将图层按照area从小到大的顺序排序。遍历各个图层,将boolean类型的数组转换为 0 1 int类型,然后对二维numpy array类型的0 1二进制mask图像转换为RLE格式。RLE是一种简单的无损数据压缩算法,通常用于表示连续的相同值的序列。RLE编码的字符串通常用于在图像分割等任务中存储和传输二进制掩码信息,以便更有效地表示图像中的目标区域。并且方便数据压缩和传输。我们参照的这种编解码方式。也可以使用coco RLE的编解码方式。将编码后的各图层信息存储到list中,就可以通过接口传输给前端处理了。前端选择图层下面这些是本文的重点,在前端将刚才解析后的mask_list信息展示,并可以通过交互选取需要保留的模版,并生成最终合并选取的mask生成一个需要保留的服装模版。body中的基本组件为保存id为layer-box的div组件作为各个mask的父组件,用于查找和管理各个mask的隐藏和展示。其子组件中的第一个标签是展示原始的模特图片的。id为save的组件在点击时可以处理保存选中的各个mask为一个新的mask图片,用于处理图片合成。id为mergedCanvas的canvas是进行图片合成和展示合成后的图片的。解析SAM处理后的mask_list信息/***rle格式图片信息转换为mask信息*/functionrle2mask(mask_rle,shape=[500,500]){/*mask_rle:run-lengthasstringformatted(startlength)shape:[width,height]ofarraytoreturnReturnsanarray,1-mask,0-background*/consts=mask_rle.split("");letstarts=s.filter((_,index)=>index%2===0).map(Number);constlengths=s.filter((_,index)=>index%2!==0).map(Number);starts=starts.map(start=>start-1);constends=starts.map((start,index)=>start+lengths[index]);constimg=newArray(shape[0]*shape[1]).fill(0);for(leti=0;inewArray(shape[0]).fill(0));for(leti=0;itransformMaskImage(item,res.width,res.height));res是sam处理后返回的图层信息(由于篇幅限制,已省略,详情请看demo(https://github.com/yuhao1128/AI-model-mask-select-demo/blob/main/index.html)中的数据)。遍历mask_list,使用canvas保存各个mask的信息。由于前面sam处理后的mask_list是经过压缩编码的,所以在rle2mask方法中对rle编码后的数据解码为 0/1二维数组的格式。rle2mask中的解码方式请参考这种解码(https://www.kaggle.com/code/pestipeti/decoding-rle-masks)方式。然后遍历二维数组,将值为1的点填充颜色,此处是填充的rgba为"#4169eb"的颜色,可以根据需要自己修改为其他的颜色。此处填充的颜色会在下文中鼠标移动到mask上面时,在mask展示的时候呈现此颜色。最后在layers中存储各个mask的base64格式的图片信息和二维数组信息。将各个mask添加到图层constbox=document.querySelector("#layer-box");constbaseStyle="width:100%;height:100%;position:absolute;";//将各个mask添加为layer-box的子组件,并隐藏mask的展示layers.forEach((ele)=>{constimage=document.createElement("img");image.src=ele.imageData;image.style=`${baseStyle}opacity:0`;image.className="layer";box.append(image);});将各个mask添加的图片添加为layer-box组件的子组件,并且设置opacity为0,先隐藏这些mask的展示,在下文会监听鼠标的位置,通过设置mask的opacity属性来展示mask。监听鼠标的位置和点击//鼠标移入mask组件的区域时,展示maskbox.addEventListener("mousemove",(e)=>{const{clientX,clientY}=e;constX=box.getBoundingClientRect().left+document.body.scrollLeft;constY=box.getBoundingClientRect().top+document.body.scrollTop;constx=parseInt(res.width*(clientX-X)/box.getBoundingClientRect().width)consty=parseInt(res.height*(clientY-Y)/box.getBoundingClientRect().height)constallLayers=box.querySelectorAll(".layer");constindex=layers.findIndex((item)=>item.matrix.[y].[x]);allLayers.forEach((ele,i)=>{if(i===index){ele.style=`${baseStyle}opacity:0.7`;}else{//已经选中的不需要隐藏if(selectedIndexList.indexOf(i)===-1){ele.style=`${baseStyle}opacity:0`;}}});});//鼠标移出mask组件的区域时,隐藏maskbox.addEventListener("mouseout",(e)=>{console.log('mouseoutselectedIndexList',selectedIndexList);constallLayers=box.querySelectorAll(".layer");allLayers.forEach((ele,i)=>{//只有选中的才会展示if(selectedIndexList.indexOf(i)>-1){ele.style=`${baseStyle}opacity:0.7`;}else{ele.style=`${baseStyle}opacity:0`;}});});//用户点击时,保存用户选中的mask的indexbox.addEventListener("mousedown",(e)=>{const{clientX,clientY}=e;constX=box.getBoundingClientRect().left+document.body.scrollLeft;constY=box.getBoundingClientRect().top+document.body.scrollTop;constx=parseInt(res.width*(clientX-X)/box.getBoundingClientRect().width)consty=parseInt(res.height*(clientY-Y)/box.getBoundingClientRect().height)constindex=layers.findIndex((item)=>item.matrix.[y].[x]);if(selectedIndexList.indexOf(index)===-1){//保存点击选中的元素indexselectedIndexList.push(index)}});box就是上文的layer-box,是各个mask的父组件。layer-box监听鼠标的move事件和click事件,当move到对应的mask上时,将mask展示,移除mask时,隐藏mask。mask在list中是从小到大的顺序,所以遍历匹配mask时,会优先匹配面积小的组件,方便灵活选择。当点击mask的位置时,保存mask在list中的index到selectedIndexList中,方便后续导出保存选择,并高亮展示选中的mask。选中的mask合成图片//存储各个图层图片信息letlayers=[]//选择layer的indexconstselectedIndexList=[]//点击保存document.getElementById('save').onclick=function(){constimages=[];selectedIndexList.forEach(index=>{images.push(layers[index].imageData)})drawing(images)}/***图片合成*/functiondrawing(images){constcanvas=document.getElementById("mergedCanvas");canvas.width=500;//设置canvas宽canvas.height=500;//设置canvas高constctx=canvas.getContext("2d");letloadedImages=0;images.forEach(function(src){constimg=newImage();img.src=src;img.onload=function(){loadedImages++;//绘制每张图片到canvas上ctx.drawImage(img,0,0);//如果所有图片都加载完成,保存合并后的图片if(loadedImages===images.length){//获取图片的像素数据constimageData=ctx.getImageData(0,0,img.width,img.height);constdata=imageData.data;//转换为黑白效果for(leti=0;i
|
|