ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

torch_geometric.data 自建数据集

2021-04-04 19:30:58  阅读:1177  来源: 互联网

标签:idx graph self torch transform geometric data


前言

博客大部分都是搬运文档,是文档的翻译版,没什么意思。精细的内容还要结合文档去看。
这个只是给你大致概念不至于看文档看的头昏眼花不是手把手教。
文档:
https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html

一针见血

数据集有两种,一个只存一个图的ImMemory类型,另一个是要存多个图DataSet的,需要额外实现len和get函数。
ImMemory要实现的基本上就是官网给的:

import torch
from torch_geometric.data import InMemoryDataset


class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(MyOwnDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Download to `self.raw_dir`.

    def process(self):
        # Read data into huge `Data` list.
        data_list = [...]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

另一种无非再在继承类那地方改成torch_geometric.data.Dataset,继承这个类就是了,外加重写两个函数

	 def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
        return data

函数名称用途

  • download写怎么获得raw的dataset,显然我们要自定义数据集,往往是在本地就有的,这个可以直接pass return
  • raw_file_names这个函数给出多张graph所存的路径,假设有graph a,graph b,那么这里return的就应当是两幅图对应的文件名。
  • processed_paths写处理所有graph过后所存的路径,道理同raw_file_names
  • process处理数据,成规定格式。

规定的什么格式?

from torch_geometric.data import Data这个Data类型,就是你要处理成的格式。
一下内容可以在Data.py里面找到内容,我只是大体提一下。

人家必须要有的属性是:

  • y: label就是了,直接给one hot或者给数字类型的都行。
  • x: 节点属性
  • edge_index: 边关系,可以多种,一种是(id,id)的列表,一种是邻接表。都行。
    处理出来以上数据后,可以直接
# contiguous这个是(id,id)这种方式需要加的
graph=Data(x=features,edge_index=network.t().contiguous(),y=labels)

这样一个基本的graph的Data就完成了。
但其实还可以加其他的属性,就直接在他后面加就行:

# 加train_idx
train_idx = torch.tensor([id2inter_id[idx] for idx in herb_with_label_id], dtype=torch.long)
graph=Data(x=features,edge_index=network.t().contiguous(),y=labels)
graph.train_idx = train_idx

实现完自己的数据集运行后会出现什么?

会直接出现这些,processed就是存放运行process函数后的数据,raw是原始数据。
在这里插入图片描述

最后再给个我自己用的例子

import torch
import pickle
from torch_geometric.data import InMemoryDataset, Data

class TCMDataSet(InMemoryDataset):
    def __init__(self,root,name,feature_size,transform=None,pre_transform=None):
        self.feature_size=feature_size
        print(f'feature size: {feature_size}')

        super(TCMDataSet, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['tcm_dataset.pt',]

    @property
    def processed_file_names(self):
        return ['tcm_dataset.pt',]

    def download(self):
        pass

    def process(self):

        # do processing, get x, y, edge_index ready.   

        graph=Data(x=features,edge_index=network.t().contiguous(),y=labels)
        train_idx = torch.tensor([id2inter_id[idx] for idx in herb_with_label_id], dtype=torch.long)
        #加入新的属性
        graph.train_idx = train_idx

        if self.pre_filter is not None:
            graph = [data for data in graph if self.pre_filter(data)]

        if self.pre_transform is not None:
            graph = [self.pre_transform(data) for data in graph]

        data, slices = self.collate([graph])
        torch.save((data, slices), self.processed_paths[0])

标签:idx,graph,self,torch,transform,geometric,data
来源: https://blog.csdn.net/Yonggie/article/details/115431841

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有