ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

FSL-GNN代码解读

2021-10-06 15:04:19  阅读:293  来源: 互联网

标签:enc nn models metric args FSL 解读 test GNN


FSL-GNN代码解读

main.py(主函数)

1、加载数据集:

train_loader = generator.Generator(args.dataset_root, args, partition='train', dataset=args.dataset)

2、初始化或加载模型:

enc_nn = models.load_model('enc_nn', args, io)
metric_nn = models.load_model('metric_nn', args, io)

if enc_nn is None or metric_nn is None:
	enc_nn, metric_nn = models.create_models(args=args)
    softmax_module = models.SoftmaxModule()

models.create_models(args=args) : in models.py

def create_models(args):
    print (args.dataset)
    if 'omniglot' == args.dataset:
        enc_nn = EmbeddingOmniglot(args, 64)
    elif 'mini_imagenet' == args.dataset:
        enc_nn = EmbeddingImagenet(args, 128)
    else:
        raise NameError('Dataset ' + args.dataset + ' not knows')
    return enc_nn, MetricNN(args, emb_size=enc_nn.emb_size)

class EmbeddingOmniglot():				# 特征提取
class EmbeddingImagenet():				# 略

class MetricNN(nn.Module):
	if self.metric_network == 'gnn_iclr_nl':……		# 正常的网络
	self.gnn_obj = gnn_iclr.GNN_nl()			# in gnn_iclr.py
	
	elif self.metric_network == 'gnn_iclr_active':……	# 主动学习
	self.gnn_obj = gnn_iclr.GNN_active()# in gnn_iclr.py
	
class SoftmaxModule():		# 线性分类

class GNN_nl(nn.Module) & class GNN_active(nn.Module) : in gnn_iclr.py

class GNN_nl(nn.Module):		# 图网络主要部分
	class Wcompute(nn.Module)	# W邻接矩阵计算
    class Gconv(nn.Module)		# 组图
		def gmul(input)		# 更新图节点特征,W直接返回

3、训练

# 权重衰减
weight_decay = 1e-6

# 优化器
opt_enc_nn = optim.Adam(enc_nn.parameters(), lr=args.lr, weight_decay=weight_decay)
opt_metric_nn = optim.Adam(metric_nn.parameters(), lr=args.lr, weight_decay=weight_decay)

# 梯度置零,也就是把loss关于weight的导数变成0
opt_enc_nn.zero_grad()
opt_metric_nn.zero_grad()

# 训练
loss_d_metric = train_batch(
	model=[enc_nn, metric_nn, 
	softmax_module],
	data=[batch_x, label_x, batches_xi, labels_yi, oracles_yi, hidden_labels])

# 更新参数
opt_enc_nn.step()
opt_metric_nn.step()

# 自适应参数
adjust_learning_rate(optimizers=[opt_enc_nn, opt_metric_nn], lr=args.lr, iter=batch_idx)

# 显示训练中loss的更新
if batch_idx % args.log_interval == 0:
	display_str = 'Train Iter: {}'.format(batch_idx)
	display_str += '\tLoss_d_metric: {:.6f}'.format(total_loss/counter)
	io.cprint(display_str)

# 测试
def test_one_shot(args, model, test_samples=5000, partition='test') 定义于 test.py 中
val_acc_aux = test.test_one_shot	# 验证集上测试
test_acc_aux = test.test_one_shot	# 测试集上测试
test.test_one_shot(					# 训练集上测试
	args, 
	model=[enc_nn, metric_nn, softmax_module],
	test_samples=test_samples, 
	partition='train')				

# 测试完毕,将模型设置回训练状态
enc_nn.train()
metric_nn.train()

# 若在验证集上的效果继续变好,则更新
if val_acc_aux is not None and val_acc_aux >= val_acc:

# 保存模型
torch.save(enc_nn, 'checkpoints/%s/models/enc_nn.t7' % args.exp_name)
torch.save(metric_nn, 'checkpoints/%s/models/metric_nn.t7' % args.exp_name)

# 全部训练完毕后进行测试
test.test_one_shot

标签:enc,nn,models,metric,args,FSL,解读,test,GNN
来源: https://www.cnblogs.com/SethDeng/p/15371286.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有