Reading the code: geo_prior (2)

2021-09-24 21:32:52


utils.py

encode_loc_time 

import math
import torch


def encode_loc_time(loc_ip, date_ip, concat_dim=1, params=None):
    # assumes input location and date features are in the range -1 to 1
    # location columns are (lon, lat)

    if params['loc_encode'] == 'encode_cos_sin':
        # (sin, cos) of every coordinate, which doubles the number of location features
        feats = torch.cat((torch.sin(math.pi*loc_ip), torch.cos(math.pi*loc_ip)), concat_dim)

    elif params['loc_encode'] == 'encode_3D':
        # X, Y, Z in 3D space
        if concat_dim == 1:
            cos_lon = torch.cos(math.pi*loc_ip[:, 0]).unsqueeze(-1)
            sin_lon = torch.sin(math.pi*loc_ip[:, 0]).unsqueeze(-1)
            cos_lat = torch.cos(math.pi*loc_ip[:, 1]).unsqueeze(-1)
            sin_lat = torch.sin(math.pi*loc_ip[:, 1]).unsqueeze(-1)
        elif concat_dim == 2:
            cos_lon = torch.cos(math.pi*loc_ip[:, :, 0]).unsqueeze(-1)
            sin_lon = torch.sin(math.pi*loc_ip[:, :, 0]).unsqueeze(-1)
            cos_lat = torch.cos(math.pi*loc_ip[:, :, 1]).unsqueeze(-1)
            sin_lat = torch.sin(math.pi*loc_ip[:, :, 1]).unsqueeze(-1)
        feats = torch.cat((cos_lon*cos_lat, sin_lon*cos_lat, sin_lat), concat_dim)

    elif params['loc_encode'] == 'encode_none':
        feats = loc_ip

    else:
        # fail early: feats would otherwise be undefined further down
        raise ValueError('error - no loc feat type defined')


    if params['use_date_feats']:
        if params['date_encode'] == 'encode_cos_sin':
            feats_date = torch.cat((torch.sin(math.pi*date_ip.unsqueeze(-1)),
                                    torch.cos(math.pi*date_ip.unsqueeze(-1))), concat_dim)
        elif params['date_encode'] == 'encode_none':
            feats_date = date_ip.unsqueeze(-1)
        else:
            # fail early: feats_date would otherwise be undefined on the next line
            raise ValueError('error - no date feat type defined')
        feats = torch.cat((feats, feats_date), concat_dim)

    return feats
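
To make the two location encodings concrete, here is a minimal pure-Python sketch for a single normalized (lon, lat) pair. It mirrors the arithmetic of the encode_cos_sin and encode_3D branches above but uses only the math module. The helper names encode_cos_sin_point and encode_3d_point are hypothetical and are not part of utils.py.

```python
import math


def encode_cos_sin_point(lon, lat):
    # Same ordering as torch.cat((sin(pi*loc), cos(pi*loc)), 1) for one row:
    # [sin(lon), sin(lat), cos(lon), cos(lat)]
    return [math.sin(math.pi * lon), math.sin(math.pi * lat),
            math.cos(math.pi * lon), math.cos(math.pi * lat)]


def encode_3d_point(lon, lat):
    # Same products as the encode_3D branch: an (x, y, z) point on the unit sphere.
    cos_lon, sin_lon = math.cos(math.pi * lon), math.sin(math.pi * lon)
    cos_lat, sin_lat = math.cos(math.pi * lat), math.sin(math.pi * lat)
    return [cos_lon * cos_lat, sin_lon * cos_lat, sin_lat]


# lon = 90 degrees east, lat = 0, after scaling to [-1, 1]
point = (0.5, 0.0)
feats_cs = encode_cos_sin_point(*point)   # 4 features per location
feats_3d = encode_3d_point(*point)        # 3 features per location
print(len(feats_cs), len(feats_3d))
print(sum(v * v for v in feats_3d))       # squared norm of the 3D encoding
```

In this sketch the encode_cos_sin form doubles the input width (2 columns become 4), while the encode_3D form always yields 3 columns whose squared norm is 1.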

The BalancedSampler class

import numpy as np
from torch.utils.data import Sampler


class BalancedSampler(Sampler):
    # sample "evenly" from each class
    def __init__(self, classes, num_per_class, use_replace=False, multi_label=False):
        self.class_dict = {}
        self.num_per_class = num_per_class
        self.use_replace = use_replace
        self.multi_label = multi_label

        if self.multi_label:
            self.class_dict = classes
        else:
            # standard classification
            un_classes = np.unique(classes)
            for cc in un_classes:
                self.class_dict[cc] = []

            for ii in range(len(classes)):
                self.class_dict[classes[ii]].append(ii)

        if self.use_replace:
            # use class_dict here: un_classes is not defined when multi_label is True
            self.num_exs = self.num_per_class*len(self.class_dict)
        else:
            self.num_exs = 0
            for cc in self.class_dict.keys():
                self.num_exs += np.minimum(len(self.class_dict[cc]), self.num_per_class)


    def __iter__(self):
        indices = []
        for cc in self.class_dict:
            if self.use_replace:
                indices.extend(np.random.choice(self.class_dict[cc], self.num_per_class).tolist())
            else:
                indices.extend(np.random.choice(self.class_dict[cc], np.minimum(len(self.class_dict[cc]),
                                                self.num_per_class), replace=False).tolist())
        # in the multi-label setting there will be duplicates at training time
        np.random.shuffle(indices)  # will remain a list
        return iter(indices)

    def __len__(self):
        return self.num_exs
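
The sampling logic in __iter__ can be sketched without PyTorch. The following is a minimal illustration, assuming single-label data and sampling without replacement. balanced_indices is a hypothetical helper, and it uses the standard random module rather than np.random.

```python
import random
from collections import defaultdict


def balanced_indices(labels, num_per_class, seed=0):
    # Group example indices by class label, as __init__ does for class_dict.
    class_dict = defaultdict(list)
    for idx, label in enumerate(labels):
        class_dict[label].append(idx)

    rng = random.Random(seed)
    indices = []
    for members in class_dict.values():
        # Take at most num_per_class examples from each class, without replacement.
        k = min(len(members), num_per_class)
        indices.extend(rng.sample(members, k))

    rng.shuffle(indices)  # mix the classes together
    return indices


labels = ['a'] * 10 + ['b'] * 3 + ['c'] * 1
picked = balanced_indices(labels, num_per_class=2)
print(len(picked))  # 2 + 2 + 1 = 5
```

With use_replace=True the class above instead draws exactly num_per_class indices from every class, so examples from rare classes can appear more than once.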

convert_loc_to_tensor 

def convert_loc_to_tensor(x, device=None):
    # input is in lon [-180, 180], lat [-90, 90]; output is scaled to [-1, 1]
    xt = x.astype(np.float32)  # astype returns a copy, so x is left unmodified
    xt[:,0] /= 180.0
    xt[:,1] /= 90.0
    xt = torch.from_numpy(xt)
    if device is not None:
        xt = xt.to(device)
    return xt
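
The scaling is just a division by the half-range of each coordinate, which is what puts the inputs into the -1 to 1 range that encode_loc_time assumes. A minimal sketch of the mapping and its inverse for one point, with hypothetical helper names:

```python
def normalize_loc(lon_deg, lat_deg):
    # Map lon in [-180, 180] and lat in [-90, 90] to [-1, 1].
    return lon_deg / 180.0, lat_deg / 90.0


def denormalize_loc(lon_n, lat_n):
    # Inverse mapping back to degrees.
    return lon_n * 180.0, lat_n * 90.0


lon_n, lat_n = normalize_loc(-122.25, 37.5)
print(lon_n, lat_n)
print(denormalize_loc(lon_n, lat_n))
```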

Some distance functions

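def distance_pw_euclidean(xx, yy):
    # equivalent to scipy.spatial.distance.cdist with its default Euclidean metric
    # uses the expansion ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2
    sq_dist = (xx**2).sum(1)[:, np.newaxis] - 2*xx.dot(yy.transpose()) + ((yy**2).sum(1)[np.newaxis, :])
    # rounding error can make tiny squared distances slightly negative, so clip before the sqrt
    dist = np.sqrt(np.maximum(sq_dist, 0.0))
    return dist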
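
The pairwise function relies on the identity ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2, which lets all N x M distances come out of one matrix product. A minimal pure-Python sketch of the same identity, compared against the direct definition, is given below. pairwise_euclidean is a hypothetical name, and the clamp at zero is an added safeguard against rounding error.

```python
import math


def pairwise_euclidean(xx, yy):
    # ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2, evaluated for every (x, y) pair.
    xx_sq = [sum(v * v for v in x) for x in xx]
    yy_sq = [sum(v * v for v in y) for y in yy]
    dist = []
    for x, x_sq in zip(xx, xx_sq):
        row = []
        for y, y_sq in zip(yy, yy_sq):
            dot = sum(a * b for a, b in zip(x, y))
            # Clamp at zero: rounding can push the expansion slightly negative.
            row.append(math.sqrt(max(x_sq - 2.0 * dot + y_sq, 0.0)))
        dist.append(row)
    return dist


xx = [[0.0, 0.0], [3.0, 4.0]]
yy = [[0.0, 0.0], [6.0, 8.0], [3.0, 0.0]]
dist = pairwise_euclidean(xx, yy)
print(dist)
```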


def distance_pw_haversine(xx, yy, radius=6372.8):
    # assumes shapes N x 2 and M x 2, where col 0 is lon, and col 1 is lat
    # input should be in radians
    # output is in km if radius = 6372.8

    d_lon = xx[:, 0][..., np.newaxis] - yy[:, 0][np.newaxis, ...]
    d_lat = xx[:, 1][..., np.newaxis] - yy[:, 1][np.newaxis, ...]

    cos_term = np.cos(xx[:,1])[..., np.newaxis]*np.cos(yy[:, 1])[np.newaxis, ...]
    dist = np.sin(d_lat/2.0)**2 + cos_term*np.sin(d_lon/2.0)**2
    dist = 2 * radius * np.arcsin(np.sqrt(dist))
    return dist


def euclidean_distance(xx, yy):
    return np.sqrt(((xx - yy)**2).sum(1))


def haversine_distance(xx, yy, radius=6371.4):
    # assumes xx has shape N x 2 and yy is a single point, where col 0 is lon, and col 1 is lat
    # input should be in radians
    # output is in km if radius = 6371.4
    # note that scikit-learn's haversine distance expects [latitude, longitude], not [longitude, latitude]

    d_lon = xx[:, 0] - yy[0]
    d_lat = xx[:, 1] - yy[1]

    cos_term = np.cos(xx[:,1])*np.cos(yy[1])
    dist = np.sin(d_lat/2.0)**2 + cos_term*np.sin(d_lon/2.0)**2
    dist = 2 * radius * np.arcsin(np.sqrt(dist + 1e-16))

    return dist
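
As a sanity check on the haversine formula used in both functions, here is a minimal scalar sketch using only the math module. haversine_km is a hypothetical name; it takes arguments in (lon, lat) order and in radians, matching the column usage in the code above.

```python
import math


def haversine_km(lon1, lat1, lon2, lat2, radius=6371.4):
    # All angles in radians; arguments are ordered (lon, lat).
    d_lon = lon1 - lon2
    d_lat = lat1 - lat2
    a = math.sin(d_lat / 2.0) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(d_lon / 2.0) ** 2
    return 2.0 * radius * math.asin(math.sqrt(a))


# Two points on the equator, 90 degrees of longitude apart:
# the great-circle distance is a quarter of the circumference, radius * pi / 2.
d = haversine_km(0.0, 0.0, math.radians(90.0), 0.0)
print(d, 6371.4 * math.pi / 2.0)
```

Passing degrees instead of radians, or swapping the lon and lat columns, produces plausible-looking but wrong distances, so a check against a known case like this one is worth keeping around.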

 

 

Source: https://blog.csdn.net/weixin_39627422/article/details/120463418
