ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

gym库中from gym.wrappers import FlattenObservation的理解

2022-03-21 16:03:09  阅读:242  来源: 互联网

标签:wrappers FlattenObservation observation space gym spaces flatten env np


 

看代码的过程中看到有这样的调用:

 

from gym.wrappers import FlattenObservation

if sinstance(env.observation_space, gym.spaces.Dict):
     env = FlattenObservation(env)

 

 

不是很理解这个代码的意思。

 

 

 

 

===============================================

 

 

 

查看gym源码中类:

FlattenObservation(ObservationWrapper)

 

import numpy as np
import gym.spaces as spaces
from gym import ObservationWrapper


class FlattenObservation(ObservationWrapper):
    r"""Observation wrapper that flattens the observation."""
    def __init__(self, env):
        super(FlattenObservation, self).__init__(env)

        flatdim = spaces.flatdim(env.observation_space)
        self.observation_space = spaces.Box(low=-float('inf'), high=float('inf'), shape=(flatdim,), dtype=np.float32)

    def observation(self, observation):
        return spaces.flatten(self.env.observation_space, observation)

 

从gym的状态空间的转换可以看出这个类是要将observation的状态空间进行flatten操作。

 

具体的flatten操作调用:

spaces.flatten(self.env.observation_space, observation)

 

 

 

 

查看spaces.flatten源代码:

def flatten(space, x):
    if isinstance(space, Box):
        return np.asarray(x, dtype=np.float32).flatten()
    elif isinstance(space, Discrete):
        onehot = np.zeros(space.n, dtype=np.float32)
        onehot[x] = 1.0
        return onehot
    elif isinstance(space, Tuple):
        return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
    elif isinstance(space, Dict):
        return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
    elif isinstance(space, MultiBinary):
        return np.asarray(x).flatten()
    elif isinstance(space, MultiDiscrete):
        return np.asarray(x).flatten()
    else:
        raise NotImplementedError

 

可以知道如果 env.observation_space属于Box类型,则直接调用np.array的flatten操作。

 

如果 env.observation_space属于Discrete类型,则直接进行onehot编码的方法进行flatten操作。

 

env.observation_space如果属于多个Box类型或Discrete类型组合而成的,也就是属于Tuple, Dict, 那么需要将其中的每个类型的状态空间都进行flatten操作后在进行拼接操作。

即:(取出组合空间中的各个子状态空间迭代调用flatten操作从而实现对组合中的各个子observation_space进行flatten)

    elif isinstance(space, Tuple):
        return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
    elif isinstance(space, Dict):
        return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])

 

 

 

 

 

MultiBinary, MultiDiscrete类型直接转为np.array类型的数据再进行flatten操作。

 

 

 

 

 

 

===================================================

 

标签:wrappers,FlattenObservation,observation,space,gym,spaces,flatten,env,np
来源: https://www.cnblogs.com/devilmaycry812839668/p/16035111.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有