TOC
创建一个简单的RNN网络
Environment:
python version: 3.7.4
pip version: 19.0.3
numpy version: 1.19.4
matplotlib version: 3.3.3
tensorflow version: 1.14.0
keras version: 2.1.5
代码如下:
import keras
from keras import backend as K
from keras.layers import RNN
class MinimalRNNCell(keras.layers.Layer):
    """Minimal linear RNN cell, meant to be wrapped by ``keras.layers.RNN``.

    For every timestep the cell computes::

        output = inputs . kernel (+ bias) + prev_output . recurrent_kernel

    and returns ``output`` both as the step output and as the new state.
    No activation function is applied, so the recurrence is linear.

    Args:
        units: Size of the output / recurrent state vector.
        use_bias: If True, create a bias vector of shape ``(units,)`` and
            add it to the input projection at every step.
        **kwargs: Forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, units, use_bias=True, **kwargs):
        # Initialise the base Layer first so its own attribute setup cannot
        # interfere with (or be confused by) the attributes set below.
        super(MinimalRNNCell, self).__init__(**kwargs)
        self.units = units
        # Part of the Keras RNN cell contract: the RNN wrapper reads
        # `state_size` to size the recurrent state it passes to `call`.
        self.state_size = units
        self.use_bias = use_bias

    def build(self, input_shape):
        """Create the cell's weights.

        Args:
            input_shape: Shape of the per-step input; only its last
                dimension (the input feature size) is used here.
        """
        input_dim = input_shape[-1]
        # Input-to-state projection.
        self.kernel = self.add_weight(shape=(input_dim, self.units),
                                      initializer='uniform',
                                      name='kernel')
        # State-to-state (recurrent) projection.
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.units,),
                                        name='bias',
                                        initializer='uniform')
        else:
            self.bias = None
        self.built = True

    def call(self, inputs, states):
        """Run a single timestep.

        Args:
            inputs: Input tensor for the current timestep.
            states: List of state tensors; element 0 is the previous output.

        Returns:
            Tuple ``(output, [output])``: the step output and the new
            state list (the output itself is carried forward as state).
        """
        prev_output = states[0]
        h = K.dot(inputs, self.kernel)
        # Apply the bias created in build(); previously it was allocated
        # (and counted as a trainable parameter) but never used.
        if self.bias is not None:
            h = K.bias_add(h, self.bias)
        output = h + K.dot(prev_output, self.recurrent_kernel)
        return output, [output]

    def get_config(self):
        """Return constructor arguments so the cell can be re-created
        when a model containing it is serialized and reloaded."""
        config = {'units': self.units, 'use_bias': self.use_bias}
        base_config = super(MinimalRNNCell, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
# Wrap the custom cell in the generic keras RNN layer, build a functional
# model around it, and print the layer/parameter summary.
cell = MinimalRNNCell(units=32)             # 32-dimensional recurrent state
x = keras.Input(shape=(None, 5))            # variable-length sequences, 5 features per step
layer = RNN(cell=cell)
y = layer(x)                                # output of the final timestep
model = keras.Model(inputs=x, outputs=y)
model.summary()
Output (1,216 params = 5×32 kernel + 32×32 recurrent_kernel + 32 bias):
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, None, 5) 0
_________________________________________________________________
rnn_1 (RNN) (None, 32) 1216
=================================================================
Total params: 1,216
Trainable params: 1,216
Non-trainable params: 0
_________________________________________________________________
REFERENCES:
「点个赞」
点个赞
使用微信扫描二维码完成支付
