标签:__ Data self DataLoader label dataset PyTorch Dataset data
import torch import torch.utils.data.dataset as Dataset import numpy as np import torch.utils.data.dataloader as DataLoader Data = np.asarray([[1, 2], [3, 4], [5, 6], [7, 8]]) Label = np.asarray([[0], [1], [0], [2]]) class SubDataSet(Dataset.Dataset): # 定义数据类型和标签 def __init__(self, Data, Label): self.Data = Data self.Label = Label # 返回数据集的大小 def __len__(self): return len(self.Data) # 得到数据内容和标签,一个一个返回的 def __getitem__(self, index): data = torch.Tensor(self.Data[index]) label = torch.Tensor(self.Label[index]) return data, label dataset = SubDataSet(Data, Label) print(dataset) print(f"dataset size: {dataset.__len__()}") print(dataset.__getitem__(0)) # data, label print(dataset[0]) # __getitem__(0) == dataset[0] # batch_size表示一次性从dataset取多少个作为一个批次大小、 # data和label是一一对应 # shuffle表示每个epoch是否乱序
# num_workers表示并行的线程数 dataloader = DataLoader.DataLoader(dataset, batch_size = 2,shuffle = False, num_workers = 2) print(enumerate(dataloader)) for i, item in enumerate(dataloader): data, label = item print(f"data: {data} \n, label: {label} \n")
标签:__,Data,self,DataLoader,label,dataset,PyTorch,Dataset,data 来源: https://www.cnblogs.com/xjtu-yzk/p/16369302.html
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。