Creating a simple RNN network

Posted by yaohong on Monday, November 23, 2020

Environment:

python version: 3.7.4
pip version: 19.0.3
numpy version: 1.19.4
matplotlib version: 3.3.3
tensorflow version: 1.14.0
keras version: 2.1.5

The code is as follows:

import keras
from keras import backend as K
from keras.layers import RNN
class MinimalRNNCell(keras.layers.Layer):

	def __init__(self, units, use_bias=True, **kwargs):
		self.units = units
		self.state_size = units
		self.use_bias = use_bias
		super(MinimalRNNCell, self).__init__(**kwargs)

	def build(self, input_shape):
		self.kernel = self.add_weight(shape=(input_shape[-1], self.units), # input-to-hidden kernel
										initializer='uniform',
										name='kernel')
		self.recurrent_kernel = self.add_weight(  # hidden-to-hidden (recurrent) kernel
			shape=(self.units, self.units),
			initializer='uniform',
			name='recurrent_kernel')

		if self.use_bias:
			self.bias = self.add_weight(  # bias vector
				shape=(self.units,),
				name='bias',
				initializer='uniform',)
		else:
			self.bias = None

		self.built = True

	def call(self, inputs, states):
		prev_output = states[0]
		h = K.dot(inputs, self.kernel)
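		output = h + K.dot(prev_output, self.recurrent_kernel)
		if self.bias is not None:  # add the bias created in build() when use_bias=True
			output = K.bias_add(output, self.bias)
		return output, [output]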


# Let's use this cell in an RNN layer:

cell = MinimalRNNCell(32)
x = keras.Input((None, 5))
layer = RNN(cell)
y = layer(x)

model = keras.Model(x,y)
model.summary()
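To make the cell's arithmetic concrete, here is a plain-Python sketch (no Keras required) of what a cell like MinimalRNNCell computes at each time step, h_t = x_t · kernel + h_{t-1} · recurrent_kernel (+ bias), and of how wrapping it in RNN(cell) unrolls it over the time axis. It assumes, as Keras does by default, a zero initial state and that the layer returns only the final step's output (return_sequences=False). Note that this cell applies no activation, unlike Keras' SimpleRNN, which defaults to tanh. The helper names matvec, cell_step and run_rnn and the toy weights are hypothetical, for illustration only; the bias is optional, mirroring the cell's use_bias flag.

```python
# Plain-Python sketch of a linear RNN cell and its unrolling over time.
# Helper names and weight values are hypothetical, chosen for illustration.

def matvec(x, w):
    """Row vector x (length n) times matrix w (n x m) -> list of length m."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

def cell_step(x_t, h_prev, kernel, recurrent_kernel, bias=None):
    """One step: x_t @ kernel + h_prev @ recurrent_kernel (+ bias), no activation."""
    h = matvec(x_t, kernel)
    out = [a + b for a, b in zip(h, matvec(h_prev, recurrent_kernel))]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out

def run_rnn(sequence, kernel, recurrent_kernel, bias=None):
    """Start from a zero state, apply the cell to every time step in order,
    and return only the last output (like RNN(cell) by default)."""
    units = len(recurrent_kernel)
    state = [0.0] * units
    for x_t in sequence:
        state = cell_step(x_t, state, kernel, recurrent_kernel, bias)
    return state

# Toy example: input_dim = 2, units = 2 (hypothetical weights)
kernel = [[1.0, 0.0], [0.0, 1.0]]
recurrent_kernel = [[0.5, 0.0], [0.0, 0.5]]
sequence = [[1.0, 2.0], [3.0, 4.0]]  # two time steps
print(run_rnn(sequence, kernel, recurrent_kernel))              # [3.5, 5.0]
print(run_rnn(sequence, kernel, recurrent_kernel, [1.0, 1.0]))  # [5.0, 6.5]
```

The output of step 1 becomes the state fed into step 2, which is exactly the role of the returned [output] list in call(): it is the new state that the RNN layer passes back in as states on the next step.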

Output:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, None, 5)           0         
_________________________________________________________________
rnn_1 (RNN)                  (None, 32)                1216      
=================================================================
Total params: 1,216
Trainable params: 1,216
Non-trainable params: 0
_________________________________________________________________
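The 1,216 parameters reported for rnn_1 come entirely from the three weights created in build() (the InputLayer has none): a kernel of shape (5, 32), a recurrent kernel of shape (32, 32) and a bias of shape (32,). A small sketch of that arithmetic follows; the function name is hypothetical.

```python
def minimal_rnn_cell_params(input_dim, units, use_bias=True):
    """Count the trainable weights created in MinimalRNNCell.build()."""
    kernel = input_dim * units        # shape (input_dim, units)
    recurrent_kernel = units * units  # shape (units, units)
    bias = units if use_bias else 0   # shape (units,)
    return kernel + recurrent_kernel + bias

print(minimal_rnn_cell_params(5, 32))                  # 160 + 1024 + 32 = 1216
print(minimal_rnn_cell_params(5, 32, use_bias=False))  # 160 + 1024 = 1184
```

The count does not depend on the sequence length (the None in the input shape), because the same weights are reused at every time step.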

REFERENCES:

rnn

recurrent
