+-
tfa.seq2seq.TrainingSampler 理解总结
首页 专栏 tensorflow 文章详情
0

tfa.seq2seq.TrainingSampler 理解总结

楚知行 发布于 2 月 1 日

概述

tfa.seq2seq.TrainingSampler,简单读取输入的训练采样器。
调用trainingSampler.initialize(input_tensors)时,取各batch中time_step=0的数据,拼接成一个数据集,返回。
下一次调用sampler.next_inputs函数时,会取各batch中time_step++的数据,拼接成一个数据集,返回。

举例说明

官网例子修改版:

import tensorflow_addons as tfa
import tensorflow as tf


def tfa_seq2seq_TrainingSampler_test():
    batch_size = 2
    max_time = 3
    word_vector_len = 4
    hidden_size = 5

    sampler = tfa.seq2seq.TrainingSampler()
    cell = tf.keras.layers.LSTMCell(hidden_size)

    input_tensors = tf.random.uniform([batch_size, max_time, word_vector_len])
    initial_finished, initial_inputs = sampler.initialize(input_tensors)

    cell_input = initial_inputs
    cell_state = cell.get_initial_state(initial_inputs)

    for time_step in tf.range(max_time):
        cell_output, cell_state = cell(cell_input, cell_state)
        sample_ids = sampler.sample(time_step, cell_output, cell_state)
        finished, cell_input, cell_state = sampler.next_inputs(
            time_step, cell_output, cell_state, sample_ids)
        if tf.reduce_all(finished):
            break
        print(time_step)

if __name__ == '__main__':
    pass;
    tfa_seq2seq_TrainingSampler_test()

以上面的代码为例,

# 假设输入数值上如下所示, 输入各维度含义, [batch_size, time_step, feature_length(或者word_vector_length)]

input_tensors = tf.Tensor(
[[[0.9346709  0.13170087 0.6356932  0.13167298]
  [0.4919318  0.44602418 0.49046385 0.28244007]
  [0.9263021  0.9984634  0.10324025 0.653986  ]]

 [[0.8260417  0.269673   0.37965262 0.86320114]
  [0.88838446 0.28112316 0.5868691  0.4174199 ]
  [0.61980057 0.2420206  0.17553246 0.9765543 ]]], shape=(2, 3, 4), dtype=float32)

当运行完sampler.initialize(input_tensors)时,得到如下的采样结果,即两个batch中,每个batch中time_step=0的数据,拼接而成。

initial_inputs = tf.Tensor(
[[0.9346709  0.13170087 0.6356932  0.13167298]
 [0.8260417  0.269673   0.37965262 0.86320114]], shape=(2, 4), dtype=float32)

第一次运行完sampler.next_inputs时,得到如下的采样结果,即两个batch中,每个batch中time_step=1的数据,拼接而成。

initial_inputs = tf.Tensor(
[[0.4919318  0.44602418 0.49046385 0.28244007]
 [0.88838446 0.28112316 0.5868691  0.4174199 ]], shape=(2, 4), dtype=float32)

第二次运行完sampler.next_inputs时,得到如下的采样结果,即两个batch中,每个batch中time_step=2的数据,拼接而成。

initial_inputs = tf.Tensor(
[[0.9263021  0.9984634  0.10324025 0.653986  ]
 [0.61980057 0.2420206  0.17553246 0.9765543 ]], shape=(2, 4), dtype=float32)

sample_ids的含义,RNN输出,每一批中,数值最大的逻辑位对应的下标。

# 当LSTMCell的输出如下所示时,
cell_output = tf.Tensor(
[[-0.07552935  0.07034459  0.12033001 -0.1792231   0.05634112]
 [-0.10488522  0.06370427  0.17486209 -0.10092633  0.09584342]], shape=(2, 5), dtype=float32)
 
 # 显然,第一批与第二批中都是下标=2的逻辑位数值最大
sample_ids = tf.Tensor([2 2], shape=(2,), dtype=int32)

参考文献

https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/Sampler?hl=zh-cn (tfa.seq2seq.Sampler  |  TensorFlow Addons)
https://tensorflow.google.cn/addons/api_docs/python/tfa/seq2seq/TrainingSampler (tfa.seq2seq.TrainingSampler  |  TensorFlow Addons)

tensorflow
阅读 45 发布于 2 月 1 日
收藏
分享
本作品系原创, 采用《署名-非商业性使用-禁止演绎 4.0 国际》许可协议
Java与大数据技术
我的Java与大数据开发经验总结与分享~
关注专栏
avatar
楚知行
18 声望
3 粉丝
关注作者
0 条评论
得票 时间
提交评论
avatar
楚知行
18 声望
3 粉丝
关注作者
宣传栏
目录

概述

tfa.seq2seq.TrainingSampler,简单读取输入的训练采样器。
调用trainingSampler.initialize(input_tensors)时,取各batch中time_step=0的数据,拼接成一个数据集,返回。
下一次调用sampler.next_inputs函数时,会取各batch中time_step++的数据,拼接成一个数据集,返回。

举例说明

官网例子修改版:

import tensorflow_addons as tfa
import tensorflow as tf


def tfa_seq2seq_TrainingSampler_test():
    batch_size = 2
    max_time = 3
    word_vector_len = 4
    hidden_size = 5

    sampler = tfa.seq2seq.TrainingSampler()
    cell = tf.keras.layers.LSTMCell(hidden_size)

    input_tensors = tf.random.uniform([batch_size, max_time, word_vector_len])
    initial_finished, initial_inputs = sampler.initialize(input_tensors)

    cell_input = initial_inputs
    cell_state = cell.get_initial_state(initial_inputs)

    for time_step in tf.range(max_time):
        cell_output, cell_state = cell(cell_input, cell_state)
        sample_ids = sampler.sample(time_step, cell_output, cell_state)
        finished, cell_input, cell_state = sampler.next_inputs(
            time_step, cell_output, cell_state, sample_ids)
        if tf.reduce_all(finished):
            break
        print(time_step)

if __name__ == '__main__':
    pass;
    tfa_seq2seq_TrainingSampler_test()

以上面的代码为例,

# 假设输入数值上如下所示, 输入各维度含义, [batch_size, time_step, feature_length(或者word_vector_length)]

input_tensors = tf.Tensor(
[[[0.9346709  0.13170087 0.6356932  0.13167298]
  [0.4919318  0.44602418 0.49046385 0.28244007]
  [0.9263021  0.9984634  0.10324025 0.653986  ]]

 [[0.8260417  0.269673   0.37965262 0.86320114]
  [0.88838446 0.28112316 0.5868691  0.4174199 ]
  [0.61980057 0.2420206  0.17553246 0.9765543 ]]], shape=(2, 3, 4), dtype=float32)

当运行完sampler.initialize(input_tensors)时,得到如下的采样结果,即两个batch中,每个batch中time_step=0的数据,拼接而成。

initial_inputs = tf.Tensor(
[[0.9346709  0.13170087 0.6356932  0.13167298]
 [0.8260417  0.269673   0.37965262 0.86320114]], shape=(2, 4), dtype=float32)

第一次运行完sampler.next_inputs时,得到如下的采样结果,即两个batch中,每个batch中time_step=1的数据,拼接而成。

initial_inputs = tf.Tensor(
[[0.4919318  0.44602418 0.49046385 0.28244007]
 [0.88838446 0.28112316 0.5868691  0.4174199 ]], shape=(2, 4), dtype=float32)

第二次运行完sampler.next_inputs时,得到如下的采样结果,即两个batch中,每个batch中time_step=2的数据,拼接而成。

initial_inputs = tf.Tensor(
[[0.9263021  0.9984634  0.10324025 0.653986  ]
 [0.61980057 0.2420206  0.17553246 0.9765543 ]], shape=(2, 4), dtype=float32)

sample_ids的含义,RNN输出,每一批中,数值最大的逻辑位对应的下标。

# 当LSTMCell的输出如下所示时,
cell_output = tf.Tensor(
[[-0.07552935  0.07034459  0.12033001 -0.1792231   0.05634112]
 [-0.10488522  0.06370427  0.17486209 -0.10092633  0.09584342]], shape=(2, 5), dtype=float32)
 
 # 显然,第一批与第二批中都是下标=2的逻辑位数值最大
sample_ids = tf.Tensor([2 2], shape=(2,), dtype=int32)

参考文献

https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/Sampler?hl=zh-cn (tfa.seq2seq.Sampler  |  TensorFlow Addons)
https://tensorflow.google.cn/addons/api_docs/python/tfa/seq2seq/TrainingSampler (tfa.seq2seq.TrainingSampler  |  TensorFlow Addons)