ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

基于tensor2tensor的注意力可视化

2020-12-27 04:01:06  阅读:421  来源: 互联网

标签:inputs name self attention list tensor2tensor 可视化 model 注意力


根据训练好的 Transformer 模型得到注意力矩阵,并对注意力进行可视化。

首先安装依赖:tensorflow 1.13.1 和 tensor2tensor 1.13.1。

 

# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared code for visualizing transformer attentions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np

# To register the hparams set
from tensor2tensor import models  # pylint: disable=unused-import
from tensor2tensor import problems
from tensor2tensor.utils import registry
from tensor2tensor.utils import trainer_lib

import tensorflow.compat.v1 as tf
from tensor2tensor.utils import usr_dir
EOS_ID = 1

class AttentionVisualizer2(object):
  """Helper object for creating Attention visualizations.

  The constructor builds the model graph through `build_model` and keeps the
  resulting placeholders/tensors together with the problem's feature encoders,
  so that a string can be translated and its attention weights fetched with a
  single call to `get_vis_data_from_string`.
  """

  def __init__(
      self, hparams_set,hparams,t2t_usr_dir, model_name, data_dir, problem_name, beam_size=1):
    """Builds the graph and loads the feature encoders.

    Args:
      hparams_set: Forwarded to `build_model`.
      hparams: Forwarded to `build_model`.
      t2t_usr_dir: Forwarded to `build_model`.
      model_name: Forwarded to `build_model`.
      data_dir: Forwarded to `build_model`; also used to load the problem's
          feature encoders.
      problem_name: Forwarded to `build_model`; also used to look up the
          problem.
      beam_size: (Optional) Forwarded to `build_model`.
    """
    inputs, targets, samples, att_mats = build_model(
        hparams_set,hparams, t2t_usr_dir, model_name, data_dir, problem_name, beam_size=beam_size)

    # Fetch the problem
    ende_problem = problems.problem(problem_name)
    encoders = ende_problem.feature_encoders(data_dir)

    self.inputs = inputs
    self.targets = targets
    self.att_mats = att_mats
    self.samples = samples
    self.encoders = encoders

  @staticmethod
  def _squeezed_ids(integers):
    """Returns `integers` as a list with all size-1 axes removed.

    `np.squeeze` turns an array holding exactly one id into a 0-d array, which
    `list()` cannot iterate over. `np.atleast_1d` restores one axis in that
    case and is a no-op otherwise, so a single id yields a one-element list.
    """
    return list(np.atleast_1d(np.squeeze(integers)))

  def encode(self, input_str):
    """Input str to a batch of ids, ready to feed the inputs placeholder.

    Args:
      input_str: The sentence to encode with the "inputs" encoder.
    Returns:
      The encoded ids followed by EOS_ID, reshaped to (1, length, 1, 1).
    """
    inputs = self.encoders["inputs"].encode(input_str) + [EOS_ID]
    batch_inputs = np.reshape(inputs, [1, -1, 1, 1])  # Make it 4D.
    return batch_inputs

  def decode(self, integers):
    """List of ints to str, using the "targets" encoder."""
    return self.encoders["targets"].decode(self._squeezed_ids(integers))

  def encode_list(self, integers):
    """List of ints to list of str, using the "inputs" encoder."""
    return self.encoders["inputs"].decode_list(self._squeezed_ids(integers))

  def decode_list(self, integers):
    """List of ints to list of str, using the "targets" encoder."""
    return self.encoders["targets"].decode_list(self._squeezed_ids(integers))

  def get_vis_data_from_string(self, sess, input_string):
    """Constructs the data needed for visualizing attentions.
    Args:
      sess: A tf.Session object.
      input_string: The input sentence to be translated and visualized.
    Returns:
      Tuple of (
          output_string: The translated sentence.
          input_list: Tokenized input sentence.
          output_list: Tokenized translation.
          att_mats: The evaluated `self.att_mats`, i.e. a tuple of
            (enc_atts, dec_atts, encdec_atts). Each is expected to be a list
            of `num_layers` numpy arrays of size
            (batch_size, num_heads, inp_len, inp_len),
            (batch_size, num_heads, out_len, out_len) and
            (batch_size, num_heads, out_len, inp_len) respectively.
      )
    """
    encoded_inputs = self.encode(input_string)

    # Run inference graph to get the translation.
    out = sess.run(self.samples, {
        self.inputs: encoded_inputs,
    })

    # Run the decoded translation through the training graph to get the
    # attention tensors.
    att_mats = sess.run(self.att_mats, {
        self.inputs: encoded_inputs,
        self.targets: np.reshape(out, [1, -1, 1, 1]),
    })

    output_string = self.decode(out)
    input_list = self.encode_list(encoded_inputs)
    output_list = self.decode_list(out)

    return output_string, input_list, output_list, att_mats


def build_model(hparams_set, hparams,t2t_usr_dir, model_name, data_dir, problem_name, beam_size=1):
  """Builds the graph needed to translate and to fetch attention weights.
  Args:
    hparams_set: HParams set to build the model with.
    hparams: Passed as the second positional argument to
        `trainer_lib.create_hparams`, after `hparams_set`.
    t2t_usr_dir: Passed to `usr_dir.import_usr_dir` before the hparams and
        the model are looked up.
    model_name: Name of model.
    data_dir: Path to directory containing training data.
    problem_name: Name of problem.
    beam_size: (Optional) Number of beams to use when decoding a translation.
        If set to 1 (default) then greedy decoding is used.
  Returns:
    Tuple of (
        inputs: Input placeholder to feed in ids to be translated.
        targets: Targets placeholder to feed to translation when fetching
            attention weights.
        samples: Tensor representing the ids of the translation.
        att_mats: Tensors representing the attention weights.
    )
  """
  print(model_name)
  usr_dir.import_usr_dir(t2t_usr_dir)
  model_hparams = trainer_lib.create_hparams(
      hparams_set, hparams, data_dir=data_dir, problem_name=problem_name)

  model_cls = registry.model(model_name)
  translate_model = model_cls(model_hparams, tf.estimator.ModeKeys.EVAL)

  # Both placeholders take a single example shaped (1, length, 1, 1).
  feed_shape = (1, None, 1, 1)
  inputs = tf.placeholder(tf.int32, shape=feed_shape, name="inputs")
  targets = tf.placeholder(tf.int32, shape=feed_shape, name="targets")
  training_features = {"inputs": inputs, "targets": targets}
  translate_model(training_features)

  # The attention dict is only populated once the training graph exists, so
  # read it now. It must also be read BEFORE the inference graph is created:
  # decoding would refill the dict with tensors living inside a tf.while_loop,
  # which are marked unfetchable.
  atts = get_att_mats(translate_model, model_name)

  # Reuse the variables created above for the inference graph.
  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    decoded = translate_model.infer({"inputs": inputs}, beam_size=beam_size)
  samples = decoded["outputs"]

  return inputs, targets, samples, atts


def get_att_mats(translate_model,model_name):
  """Collects the per-layer attention entries recorded on a built model.
  The entries are read from the `attention_weights` dict of the model object,
  using keys of the form
  "<model_name>/body/<encoder|decoder>/layer_<i>/<kind><suffix>".
  Args:
    translate_model: Model object exposing `hparams.self_attention_type`,
        `hparams.num_hidden_layers` and the `attention_weights` dict.
    model_name: Name of the model; used as the leading component of every
        key looked up in `attention_weights`.
  Returns:
    Tuple of three lists, each with one entry per hidden layer; (
        enc_atts: Encoder self attention weights.
        dec_atts: Decoder self attention weights.
        encdec_atts: Encoder-Decoder attention weights.
    )
    Once evaluated, the entries are expected to have sizes
    (batch_size, num_heads, inp_len, inp_len),
    (batch_size, num_heads, out_len, out_len) and
    (batch_size, num_heads, out_len, inp_len) respectively.
  Raises:
    KeyError: If an expected key is missing from `attention_weights`.
  """
  scope = "%s/body/"%(model_name)
  plain_suffix = "/multihead_attention/dot_product_attention"
  relative_suffix = ("/multihead_attention/"
                     "dot_product_attention_relative")
  uses_relative = (
      translate_model.hparams.self_attention_type == "dot_product_relative")
  # Only self attention switches to the relative variant; encoder-decoder
  # attention always uses the plain suffix.
  self_suffix = relative_suffix if uses_relative else plain_suffix

  def lookup(template, layer, suffix):
    return translate_model.attention_weights[template % (scope, layer, suffix)]

  enc_atts, dec_atts, encdec_atts = [], [], []
  for layer in range(translate_model.hparams.num_hidden_layers):
    enc_atts.append(
        lookup("%sencoder/layer_%i/self_attention%s", layer, self_suffix))
    dec_atts.append(
        lookup("%sdecoder/layer_%i/self_attention%s", layer, self_suffix))
    encdec_atts.append(
        lookup("%sdecoder/layer_%i/encdec_attention%s", layer, plain_suffix))

  return enc_atts, dec_atts, encdec_atts


import os
from tensor2tensor import problems
from tensor2tensor.bin import t2t_decoder  # To register the hparams set
# from tensor2tensor.utils import registry
from tensor2tensor.utils import trainer_lib
from tensor2tensor.visualization import attention
# from src.visualization import visualization
# Driver script: builds the visualizer, restores a checkpoint, translates one
# sentence and displays its attention weights.

# Expose only the GPU with index 2 to CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# CHECKPOINT = '/home/usrname/collaboration/t2t_train/translate_envi_iwslt32k/collaboration-collaboration_tiny-v6_0.5_0_normd'
# Problem, model and hyperparameter settings passed to AttentionVisualizer2.
problem_name = 'translate_envi_iwslt32k'
data_dir = os.path.expanduser('/home/usrname/collaboration/t2t_data/%s'%(problem_name))
model_name = "collaboration"
hparams_set = "collaboration_tiny"
hparams = 'max_length=128,num_hidden_layers=6,usedegray=0.5,reuse_n=0'
t2t_usr_dir = './src/'   # Path to the user's custom model code.

# Building the visualizer constructs the model graph.
visualizer = AttentionVisualizer2(hparams_set,hparams, t2t_usr_dir,model_name, data_dir, problem_name, beam_size=1)

# Vocabulary file for this problem, kept for reference:
#/home/usrname/collaboration/t2t_data/translate_envi_iwslt32k/vocab.translate_envi_iwslt32k.32768.subwords

# Created before the Saver below, so the Saver includes it.
# NOTE(review): presumably mirrors the `global_step` stored in the checkpoint
# -- confirm it is required for the restore.
tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')

saver = tf.train.Saver()
with tf.Session() as sess:
  # Alternative: resolve the latest checkpoint under CHECKPOINT.
  # ckpts = tf.train.get_checkpoint_state(CHECKPOINT)
  # ckpt = ckpts.model_checkpoint_path
  ckpt = 'averaged.ckpt-0'  # Model checkpoint to restore.
  print(ckpt)
  saver.restore(sess, ckpt)


  # Translate one sentence and fetch the attention weights for it.
  input_sentence = "My family was not poor , and myself , I had never experienced hunger ."
  output_string, inp_text, out_text, att_mats = visualizer.get_vis_data_from_string(sess, input_sentence)
  print(output_string)
  print(att_mats)

  # att_mats unpacks to (enc_atts, dec_atts, encdec_atts).
  attention.show(inp_text, out_text, *att_mats)

  

标签:inputs,name,self,attention,list,tensor2tensor,可视化,model,注意力
来源: https://www.cnblogs.com/huadongw/p/14195355.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有