TOC
nn_Module
1. Where are module parameters configured?
The parameters are stored in the network nodes, i.e. the individual units that make up a network layer.
A neural network layer is defined in the `__init__`
method of the module and needs to be assigned to an instance attribute (`self.<name>`), not to a local variable:
import torch.nn as nn
import numpy as np
class TorchDNN(nn.Module):
    """Counter-example: a layer bound to a local variable is NOT registered.

    The linear layer below is assigned to a plain local name instead of
    ``self.<name>``, so the module never sees it and ``state_dict()`` comes
    back empty.
    """

    def __init__(self, input, hidden, output):
        """Build the (intentionally broken) model.

        Args:
            input: Number of input features of the hidden layer.
            hidden: Number of output features of the hidden layer.
            output: Unused in this example.
        """
        super(TorchDNN, self).__init__()
        # Intentionally a local variable (no `self.`): the layer is discarded
        # when `__init__` returns, so no parameters are registered.
        layer_hidden = nn.Linear(input, hidden, bias=True)

    def forward(self, input_data):
        """Placeholder forward pass; does nothing and returns ``None``."""
        pass
# Sample input vector; only its length is used, to size the input layer.
x = np.array([1, 2, 3])
n_inputs, n_hidden, n_outputs = len(x), 5, 3
torch_model = TorchDNN(n_inputs, n_hidden, n_outputs)
# Show what the module has registered.
registered_state = torch_model.state_dict()
print(registered_state)
# OUTPUT:
# OrderedDict()
The network layer should be assigned as an attribute of the Module
instance (using `self.`):
import torch.nn as nn
import numpy as np
class TorchDNN(nn.Module):
    """Minimal module whose hidden layer is registered correctly.

    Assigning the layer to ``self.layer_hidden`` registers it as a submodule,
    so its weight and bias appear in ``state_dict()`` under the
    ``layer_hidden.`` prefix.
    """

    def __init__(self, input, hidden, output):
        """Build the model.

        Args:
            input: Number of input features of the hidden layer.
            hidden: Number of output features of the hidden layer.
            output: Unused in this example.
        """
        super(TorchDNN, self).__init__()
        # Using `self.` is what registers the layer (and its parameters).
        self.layer_hidden = nn.Linear(input, hidden, bias=True)

    def forward(self, input_data):
        """Placeholder forward pass; does nothing and returns ``None``."""
        pass
# Sample input vector; only its length is used, to size the input layer.
x = np.array([1, 2, 3])
n_inputs, n_hidden, n_outputs = len(x), 5, 3
torch_model = TorchDNN(n_inputs, n_hidden, n_outputs)
# Show what the module has registered.
registered_state = torch_model.state_dict()
print(registered_state)
# OUTPUT:
# OrderedDict([('layer_hidden.weight', tensor([[-5.5532e-02, 3.7091e-01, -6.2572e-02],
# [ 6.1063e-02, -3.9203e-01, 3.1489e-04],
# [-1.3280e-01, 1.6668e-01, 4.1367e-01],
# [ 4.5186e-02, 8.8564e-02, -5.0505e-01],
# [ 1.5705e-01, 1.1069e-02, 7.6944e-02]])), ('layer_hidden.bias', tensor([ 0.2683, -0.1999, 0.2895, 0.1940, 0.0580]))])
import torch.nn as nn
import numpy as np
class TorchDNN(nn.Module):
    """Counter-example: a layer created inside ``forward`` instead of ``__init__``.

    Only ``layer_hidden`` exists after construction. ``layer_hidden2`` is
    created when ``forward`` is called, so it is missing from
    ``state_dict()`` on a freshly built model.
    """

    def __init__(self, input, hidden, output):
        """Build the model with a single registered hidden layer.

        Args:
            input: Number of input features of the hidden layer.
            hidden: Number of output features of the hidden layer.
            output: Unused in this example.
        """
        super(TorchDNN, self).__init__()
        self.layer_hidden = nn.Linear(input, hidden, bias=True)

    def forward(self, input_data):
        """Create ``layer_hidden2`` as a side effect; returns ``None``.

        ``input_data`` is not used.
        """
        # NOTE: neural network layers should be defined in `__init__`.
        # Defined here, the layer does not exist until `forward` runs and is
        # replaced by a newly constructed layer on every call.
        self.layer_hidden2 = nn.Linear(3, 3, bias=True)
# Sample input vector; only its length is used, to size the input layer.
x = np.array([1, 2, 3])
n_inputs, n_hidden, n_outputs = len(x), 5, 3
torch_model = TorchDNN(n_inputs, n_hidden, n_outputs)
# `forward` is never called here, so only layers created in `__init__` exist.
registered_state = torch_model.state_dict()
print(registered_state)
# OUTPUT (There is only one layer!):
# OrderedDict([('layer_hidden.weight', tensor([[ 0.5433, 0.4628, 0.2992],
# [ 0.3414, -0.2000, 0.5657],
# [-0.3542, 0.2974, 0.0271],
# [-0.4149, -0.0669, -0.3499],
# [ 0.1169, -0.1990, -0.2368]])), ('layer_hidden.bias', tensor([ 0.3161, -0.1012, 0.2077, -0.1572, 0.2423]))])
「点个赞」
点个赞
使用微信扫描二维码完成支付
