ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

我如何将numpy数组中的分类数据加载到指标或嵌入列中?

2019-11-08 19:57:45  阅读:247  来源: 互联网

标签:tensorflow python-2-7 tensorflow-estimator python


使用Tensorflow 1.8.0时,每当尝试构建分类列时都会遇到问题.这是演示问题的完整示例.它按原样运行(仅使用数字列).取消注释指示符列定义和数据的注释会生成堆栈跟踪,结尾为tensorflow.python.framework.errors_impl.InternalError:无法将元素作为字节获取.

import tensorflow as tf
import numpy as np

def feature_numeric(key):
  return tf.feature_column.numeric_column(key=key, default_value=0)

def feature_indicator(key, vocabulary):
  return tf.feature_column.indicator_column(
    tf.feature_column.categorical_column_with_vocabulary_list(
      key=key, vocabulary_list=vocabulary ))


labels = ['Label1','Label2','Label3']

model = tf.estimator.DNNClassifier(
  feature_columns=[
    feature_numeric("number"),
    # feature_indicator("indicator", ["A","B","C"]),
  ],
  hidden_units=[64, 16, 8],
  model_dir='./models',
  n_classes=len(labels),
  label_vocabulary=labels)

def train(inputs, training):
  model.train(
    input_fn=tf.estimator.inputs.numpy_input_fn(
        x=inputs,
        y=training,
        shuffle=True
      ), steps=1)

inputs = {
  "number": np.array([1,2,3,4,5]),
  # "indicator": np.array([
  #     ["A"],
  #     ["B"],
  #     ["C"],
  #     ["A", "A"],
  #     ["A", "B", "C"],
  #   ]),
}

training = np.array(['Label1','Label2','Label3','Label2','Label1'])

train(inputs, training)

尝试使用嵌入票价不会更好.仅使用数字输入,我们就可以成功扩展到数千个输入节点,实际上,我们已经临时扩展了预处理器中的分类功能以模拟指标.

categorical_column _ *()和indicator_column()的文档中充斥着对我们确定不会使用的功能的引用(原型输入,无论bytes_list是什么),但也许我们错了吗?

解决方法:

这里的问题与“指示器”输入数组的参差不齐的形状有关(某些元素的长度为1,一个为长度2,一个为长度3).如果您用一些非词汇字符串填充输入列表(例如,由于您的词汇是“ A”,“ B”,“ C”,我就使用了“ Z”),您将获得预期的结果:

inputs = {
  "number": np.array([1,2,3,4,5]),
  "indicator": np.array([
    ["A", "Z", "Z"],
    ["B", "Z", "Z"],
    ["C", "Z", "Z"],
    ["A", "A", "Z"],
    ["A", "B", "C"]
  ])
}

您可以通过打印结果张量来验证此方法是否有效:

dense = tf.feature_column.input_layer(
  inputs,
  [
    feature_numeric("number"),
    feature_indicator("indicator", ["A","B","C"]),
  ])

with tf.train.MonitoredTrainingSession() as sess:
  print(dense)
  print(sess.run(dense))

标签:tensorflow,python-2-7,tensorflow-estimator,python
来源: https://codeday.me/bug/20191108/2010246.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有