Administrator
发布于 2026-05-17 / 2 阅读
0

tensorflow自定义网络结构

自定义网络层 ====== 自定义层需要继承tf.keras.layers.Layer类,重写init,build,call * __init__,执行与输入无关的初始化 * build,了解输入张量的形状,定义需要什么输入 * call,进行正向计算 ``` class MyDense(tf.keras.layers.Layer): def __init__(self,units): # units 神经元个数 super().__init__() # 必须写 self.units = units def build(self,input_shape): self.w = self.add_variable( name="w", shape=[input_shape[-1],self.units], initializer = tf.initializers.RandomNormal() ) self.b = self.add_variable(name="b",shape=[self.units],initializer = tf.initializers.Zeros()) # b一般是全0 def call(self,input): # wx+b return input @ self.w + self.b return tf.nn.relu(input @ self.w + self.b) ``` 自定义模型类 ====== ``` class MyModel(tf.keras.Model): def __init__(self): super().__init__() self.fc1 = MyDense(512) self.fc2 = MyDense(256) self.fc3 = MyDense(128) self.fc4 = MyDense(10) def call(self,input): self.fc1.out = self.fc1(input) self.fc2.out = self.fc2(self.fc1.out) self.fc3.out = self.fc3(self.fc2.out) self.fc4.out = self.fc4(self.fc3.out) return self.fc4.out myModel = MyModel() myModel.build(input_shape=(None,784)) myModel.summary() ``` ###### 注: ``` # 模型保存 # 1,保存模型 # model.save("xxx.h5") # tensorflow.keras.models.load_model("xxxx.h5") ​ # 2,保存权重参数 # model.save_weights("xxxx.ckpt") # model.load_weights("xxxx.ckpt") ​ # 3,save_model 此时保存的模型具有平台无关性,移植性好 1.15及之后版本 # tensorflow.keras.models.save_model(model,"foldername") 生成文件夹,里面有pb文件 # tensorflow.keras.models.load_model("foldername") # 此时只导入的只有model结构与weight参数 model.compile还需要自己写 ```