# 读代码：geo_prior（2）

2021-09-24 21:32:52  阅读：186  来源： 互联网

## utils.py

### encode_loc_time

``````def encode_loc_time(loc_ip, date_ip, concat_dim=1, params=None):
# assumes inputs location and date features are in range -1 to 1
# location is lon, lat

if params['loc_encode'] == 'encode_cos_sin':
feats = torch.cat((torch.sin(math.pi*loc_ip), torch.cos(math.pi*loc_ip)), concat_dim)

elif params['loc_encode'] == 'encode_3D':
# X, Y, Z in 3D space
if concat_dim == 1:
cos_lon = torch.cos(math.pi*loc_ip[:, 0]).unsqueeze(-1)
sin_lon = torch.sin(math.pi*loc_ip[:, 0]).unsqueeze(-1)
cos_lat = torch.cos(math.pi*loc_ip[:, 1]).unsqueeze(-1)
sin_lat = torch.sin(math.pi*loc_ip[:, 1]).unsqueeze(-1)
if concat_dim == 2:
cos_lon = torch.cos(math.pi*loc_ip[:, :, 0]).unsqueeze(-1)
sin_lon = torch.sin(math.pi*loc_ip[:, :, 0]).unsqueeze(-1)
cos_lat = torch.cos(math.pi*loc_ip[:, :, 1]).unsqueeze(-1)
sin_lat = torch.sin(math.pi*loc_ip[:, :, 1]).unsqueeze(-1)
feats = torch.cat((cos_lon*cos_lat, sin_lon*cos_lat, sin_lat), concat_dim)

elif params['loc_encode'] == 'encode_none':
feats = loc_ip

else:
print('error - no loc feat type defined')

if params['use_date_feats']:
if params['date_encode'] == 'encode_cos_sin':
feats_date = torch.cat((torch.sin(math.pi*date_ip.unsqueeze(-1)),
torch.cos(math.pi*date_ip.unsqueeze(-1))), concat_dim)
elif params['date_encode'] == 'encode_none':
feats_date = date_ip.unsqueeze(-1)
else:
print('error - no date feat type defined')
feats = torch.cat((feats, feats_date), concat_dim)

return feats``````

### BalancedSampler类

``````class BalancedSampler(Sampler):
# sample "evenly" from each from class
def __init__(self, classes, num_per_class, use_replace=False, multi_label=False):
self.class_dict = {}
self.num_per_class = num_per_class
self.use_replace = use_replace
self.multi_label = multi_label

if self.multi_label:
self.class_dict = classes
else:
# standard classification
un_classes = np.unique(classes)
for cc in un_classes:
self.class_dict[cc] = []

for ii in range(len(classes)):
self.class_dict[classes[ii]].append(ii)

if self.use_replace:
self.num_exs = self.num_per_class*len(un_classes)
else:
self.num_exs = 0
for cc in self.class_dict.keys():
self.num_exs += np.minimum(len(self.class_dict[cc]), self.num_per_class)

def __iter__(self):
indices = []
for cc in self.class_dict:
if self.use_replace:
indices.extend(np.random.choice(self.class_dict[cc], self.num_per_class).tolist())
else:
indices.extend(np.random.choice(self.class_dict[cc], np.minimum(len(self.class_dict[cc]),
self.num_per_class), replace=False).tolist())
# in the multi label setting there will be duplictes at training time
np.random.shuffle(indices)  # will remain a list
return iter(indices)

def __len__(self):
return self.num_exs``````

### convert_loc_to_tensor

``````def convert_loc_to_tensor(x, device=None):
# intput is in lon {-180, 180}, lat {90, -90}
xt = x.astype(np.float32)
xt[:,0] /= 180.0
xt[:,1] /= 90.0
xt = torch.from_numpy(xt)
if device is not None:
xt = xt.to(device)
return xt``````

### 一些距离函数

``````def distance_pw_euclidean(xx, yy):
# equivalent to scipy.spatial.distance.cdist
dist = np.sqrt((xx**2).sum(1)[:, np.newaxis] - 2*xx.dot(yy.transpose()) + ((yy**2).sum(1)[np.newaxis, :]))
return dist

# input should be in radians
# output is in km's if radius = 6372.8

d_lon = xx[:, 0][..., np.newaxis] - yy[:, 0][np.newaxis, ...]
d_lat = xx[:, 1][..., np.newaxis] - yy[:, 1][np.newaxis, ...]

cos_term = np.cos(xx[:,1])[..., np.newaxis]*np.cos(yy[:, 1])[np.newaxis, ...]
dist = np.sin(d_lat/2.0)**2 + cos_term*np.sin(d_lon/2.0)**2
dist = 2 * radius * np.arcsin(np.sqrt(dist))
return dist

def euclidean_distance(xx, yy):
return np.sqrt(((xx - yy)**2).sum(1))

# assumes shape N x 2, where col 0 is lat, and col 1 is lon
# input should be in radians
# output is in km's if radius = 6371.4
# note that SKLearns haversine distance is [latitude, longitude] not [longitude, latitude]

d_lon = xx[:, 0] - yy[0]
d_lat = xx[:, 1] - yy[1]

cos_term = np.cos(xx[:,1])*np.cos(yy[1])
dist = np.sin(d_lat/2.0)**2 + cos_term*np.sin(d_lon/2.0)**2
dist = 2 * radius * np.arcsin(np.sqrt(dist + 1e-16))

return dist``````

2. 关于本站的所有留言、评论、转载及引用，纯属内容发起人的个人观点，与本站观点和立场无关；
3. 关于本站的所有言论和文字，纯属内容发起人的个人观点，与本站观点和立场无关；
4. 本站文章均是网友提供，不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属；如您发现该文章侵犯了您的权益，可联系我们第一时间进行删除；
5. 本站为非盈利性的个人网站，所有内容不会用来进行牟利，也不会利用任何形式的广告来间接获益，纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。