tensorflow2.0——CIFAR100卷积+全连接实战
内容导读
互联网集市收集整理的这篇技术教程文章主要介绍了tensorflow2.0——CIFAR100卷积+全连接实战,小编现在分享给大家,供广大互联网技能从业者学习和参考。文章包含3832字,纯文字阅读大概需要6分钟。
内容图文
![tensorflow2.0——CIFAR100卷积+全连接实战](/upload/InfoBanner/zyjiaocheng/1065/040348eaf71f42188b6421d36b666950.jpg)
import tensorflow as tf # 设置相关底层配置 physical_devices = tf.config.experimental.list_physical_devices(‘GPU‘) assert len(physical_devices) > 0, "Not enough GPU hardware devices available" tf.config.experimental.set_memory_growth(physical_devices[0], True) def preprocess(x,y): x = tf.cast(x,dtype=tf.float32) / 255 y = tf.cast(y,dtype=tf.int32) return x,y # ###############数据加载以及处理############# (x,y),(x_test,y_test) = tf.keras.datasets.cifar100.load_data() # 将y的1维度去掉 y = tf.squeeze(y,axis=1) y_test = tf.squeeze(y_test,axis=1) print(‘x.shape,y.shape,x_test.shape,y_test.shape:‘) print(x.shape,y.shape,x_test.shape,y_test.shape) train_db = tf.data.Dataset.from_tensor_slices((x,y)) train_db = train_db.shuffle(1000).batch(64) test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test)) test_db = test_db.shuffle(1000).batch(200) # 打印看下数据的形状 sample = next(iter(train_db)) print(‘sample:‘,sample[0].shape,sample[1].shape ,tf.reduce_min(sample[0]),tf.reduce_max(sample[0])) if__name__ == ‘__main__‘: # 卷积网络结构 conv_layers = [ # 第一部分(两卷积一池化) tf.keras.layers.Conv2D(64, kernel_size=[3, 3], padding=‘same‘, activation=tf.nn.relu), tf.keras.layers.Conv2D(64, kernel_size=[3, 3], padding=‘same‘, activation=tf.nn.relu), tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2, padding=‘same‘), # 第二部分(两卷积一池化) tf.keras.layers.Conv2D(128, kernel_size=[3, 3], padding=‘same‘, activation=tf.nn.relu), tf.keras.layers.Conv2D(128, kernel_size=[3, 3], padding=‘same‘, activation=tf.nn.relu), tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2, padding=‘same‘), # 第三部分(两卷积一池化) tf.keras.layers.Conv2D(256, kernel_size=[3, 3], padding=‘same‘, activation=tf.nn.relu), tf.keras.layers.Conv2D(256, kernel_size=[3, 3], padding=‘same‘, activation=tf.nn.relu), tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2, padding=‘same‘), # 第四部分(两卷积一池化) tf.keras.layers.Conv2D(512, kernel_size=[3, 3], padding=‘same‘, activation=tf.nn.relu), tf.keras.layers.Conv2D(512, kernel_size=[3, 3], padding=‘same‘, activation=tf.nn.relu), tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2, padding=‘same‘), # 第五部分(两卷积一池化) tf.keras.layers.Conv2D(512, kernel_size=[3, 3], padding=‘same‘, activation=tf.nn.relu), tf.keras.layers.Conv2D(512, kernel_size=[3, 3], padding=‘same‘, activation=tf.nn.relu), tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2, padding=‘same‘), ] # [b,32,32,3] => [b,1,1,512] 卷积层操作 conv_net = tf.keras.Sequential(conv_layers) conv_net.build(input_shape=[None,32,32,3]) x = tf.random.normal([4,32,32,3]) out = conv_net(x) print(out.shape) # 全连接层操作 fc_net = tf.keras.Sequential([ tf.keras.layers.Dense(256,activation=tf.nn.relu), tf.keras.layers.Dense(128, activation=tf.nn.relu), tf.keras.layers.Dense(100, activation=None) ]) fc_net.build(input_shape=[None,512]) # 把卷积和全连接层的参数合并 ‘+’可以把两个列表直接合并 variables = conv_net.trainable_variables + fc_net.trainable_variables # 定义优化器 optimizer = tf.optimizers.Adam(lr=1e-4) # 训练for epoch in range(50): for step,(x,y) in enumerate(train_db): with tf.GradientTape() as tape: # [b,32,32,3] => [b,1,1,512] out = conv_net(x) # flatten out = tf.reshape(out,[-1,512]) # [b,512] =>[b,100] logits = fc_net(out) # y_onehot = tf.one_hot(y,depth=100) loss = tf.losses.categorical_crossentropy(y_onehot,logits,from_logits=True) loss = tf.reduce_mean(loss) grads = tape.gradient(loss,variables) optimizer.apply_gradients(zip(grads,variables)) if step % 100 == 0: print(epoch,step,‘loss:‘,float(loss)) for x,y in test_db: out = conv_net(x) out = tf.reshape(out,[-1,512]) logits = fc_net(out) prob = tf.nn.softmax(logits,axis=1) pred = tf.argmax(prob,axis=1) pred = tf.cast(pred,tf.int32) correct = tf.cast(tf.equal(pred,y),dtype=tf.int32) correct = tf.reduce_mean(tf.cast(correct,dtype=tf.float32)) print(‘acc:‘,float(correct))
原文:https://www.cnblogs.com/cxhzy/p/13732972.html
内容总结
以上是互联网集市为您收集整理的tensorflow2.0——CIFAR100卷积+全连接实战全部内容,希望文章能够帮你解决tensorflow2.0——CIFAR100卷积+全连接实战所遇到的程序开发问题。 如果觉得互联网集市技术教程内容还不错,欢迎将互联网集市网站推荐给程序员好友。
内容备注
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 gblab@vip.qq.com 举报,一经查实,本站将立刻删除。
内容手机端
扫描二维码推送至手机访问。