ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

MeLU模型复现

2022-09-13 00:01:52  阅读:217  来源: 互联网

标签:Linear weight 模型 fast MeLU 复现 grad self


MeLU算是推荐系统冷启动中非常经典的一个模型,在近两年很多冷启动相关的论文都拿它做baseline。以下总结一些个人觉得值得关注的地方。代码参考自MELU_pytorch

class Linear(nn.Linear):
    def __init__(self, in_features, out_features):
        super(Linear, self).__init__(in_features, out_features)
        self.weight.fast = None
        self.bias.fast = None
    def forward(self, x):
        if self.weight.fast is not None and self.bias.fast is not None:
            out = F.linear(x, self.weight.fast, self.bias.fast)
        else:
            out = super(Linear, self).forward(x)
        return out

首先是Linear的重写,因为MeLU中涉及到元学习中的MAML,会涉及到两个梯度,普通的Linear无法实现这种操作,因此在原有的Linear上又加入了fast,fast是inner loop更新后的参数。

fast_parameters = []
for k, weight in enumerate(model.final_part.parameters()):
    if weight.fast is None:
        weight.fast = weight - args.lr_inner * grad[k]
    else:
        weight.fast = weight.fast - args.lr_inner * grad[k]
    fast_parameters.append(weight.fast)

inner loop的更新,这里只更新除了用户与物品属性之外的参数的embedding。

logits_q = model(x_qry[i])
loss_q = F.mse_loss(logits_q, y_qry[i])
loss_after.append(loss_q.item())
task_grad_test = torch.autograd.grad(loss_q, model.parameters())

for g in range(len(task_grad_test)):
    meta_grad[g] += task_grad_test[g].detach()

meta_optimizer.zero_grad()

for c, param in enumerate(model.parameters()):
    param.grad = meta_grad[c] / float(args.tasks_per_metaupdate)
    param.grad.data.clamp_(-10, 10)

meta_optimizer.step()

outer loop的更新,这里包括梯度截断等操作。
MeLU总体的代码还是比较容易看懂的,代码中一半都是用来处理数据,实际的模型代码并不长,核心的部分就是上述内容。

MeLU有着比较明显的优缺点,优点是它使用了元学习中的MAML,可以为冷启动用户或物品生成一个较为通用的表示,使得仅使用少量数据就可以使冷启动用户和物品快速适应推荐系统。缺点是在推荐时,仅使用到了用户和物品的相关属性,没有使用到富有价值的用户历史交互序列这一信息,而且MeLU为每一类属性生成的embedding都是固定的,有可能存在相同属性的用户偏爱不同类型的物品,这种情况下MeLU效果就会很糟糕。

标签:Linear,weight,模型,fast,MeLU,复现,grad,self
来源: https://www.cnblogs.com/ambition-hhn/p/16687762.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有