python – Tensorflow – 保存模型
内容导读
互联网集市收集整理的这篇技术教程文章主要介绍了python – Tensorflow – 保存模型,小编现在分享给大家,供广大互联网技能从业者学习和参考。文章包含3127字,纯文字阅读大概需要5分钟。
内容图文
我有以下代码,并在尝试保存模型时出错.我可能做错了什么,我该如何解决这个问题?
import tensorflow as tf
data, labels = cifar_tools.read_data('C:\\Users\\abc\\Desktop\\Testing')
x = tf.placeholder(tf.float32, [None, 150 * 150])
y = tf.placeholder(tf.float32, [None, 2])
w1 = tf.Variable(tf.random_normal([5, 5, 1, 64]))
b1 = tf.Variable(tf.random_normal([64]))
w2 = tf.Variable(tf.random_normal([5, 5, 64, 64]))
b2 = tf.Variable(tf.random_normal([64]))
w3 = tf.Variable(tf.random_normal([38*38*64, 1024]))
b3 = tf.Variable(tf.random_normal([1024]))
w_out = tf.Variable(tf.random_normal([1024, 2]))
b_out = tf.Variable(tf.random_normal([2]))
def conv_layer(x,w,b):
conv = tf.nn.conv2d(x,w,strides=[1,1,1,1], padding = 'SAME')
conv_with_b = tf.nn.bias_add(conv,b)
conv_out = tf.nn.relu(conv_with_b)
return conv_out
def maxpool_layer(conv,k=2):
return tf.nn.max_pool(conv, ksize=[1,k,k,1], strides=[1,k,k,1], padding='SAME')
def model():
x_reshaped = tf.reshape(x, shape=[-1, 150, 150, 1])
conv_out1 = conv_layer(x_reshaped, w1, b1)
maxpool_out1 = maxpool_layer(conv_out1)
norm1 = tf.nn.lrn(maxpool_out1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
conv_out2 = conv_layer(norm1, w2, b2)
norm2 = tf.nn.lrn(conv_out2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
maxpool_out2 = maxpool_layer(norm2)
maxpool_reshaped = tf.reshape(maxpool_out2, [-1, w3.get_shape().as_list()[0]])
local = tf.add(tf.matmul(maxpool_reshaped, w3), b3)
local_out = tf.nn.relu(local)
out = tf.add(tf.matmul(local_out, w_out), b_out)
return out
model_op = model()
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(model_op, y))
train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)
correct_pred = tf.equal(tf.argmax(model_op, 1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
onehot_labels = tf.one_hot(labels, 2, on_value=1.,off_value=0.,axis=-1)
onehot_vals = sess.run(onehot_labels)
batch_size = 1
saver = tf.train.Saver()
saved_path = saver.save(sess, 'mymodel')
print("The model is in this file: ", saved_path)
for j in range(0, 5):
print('EPOCH', j)
for i in range(0, len(data), batch_size):
batch_data = data[i:i+batch_size, :]
batch_onehot_vals = onehot_vals[i:i+batch_size, :]
_, accuracy_val = sess.run([train_op, accuracy], feed_dict={x: batch_data, y: batch_onehot_vals})
print(i, accuracy_val)
print('DONE WITH EPOCH')
编辑-1
忘了说出我遇到的错误:-)
Traceback (most recent call last):
File "cnn.py", line 67, in <module>
save_path = saver.save(sess, 'mymodel')
File "C:\Python35\lib\site-packages\tensorflow\python\training\saver.py", line 1314, in save
"Parent directory of {} doesn't exist, can't save.".format(save_path))
ValueError: Parent directory of mymodel doesn't exist, can't save.
谢谢.
解决方法:
看起来您要存储模型的文件夹不存在(可能会检查您当前的工作目录是什么).为了避免这些问题,我会使用绝对路径并在保存之前执行以下操作:
save_path = ...
if not os.path.exists(save_path):
os.makedirs(save_path)
...
saver = tf.train.Saver()
with tf.Session() as sess:
...
saved_path = saver.save(sess, os.path.join(save_path, 'my_model')
print("The model is in this file: ", saved_path)
内容总结
以上是互联网集市为您收集整理的python – Tensorflow – 保存模型全部内容,希望文章能够帮你解决python – Tensorflow – 保存模型所遇到的程序开发问题。 如果觉得互联网集市技术教程内容还不错,欢迎将互联网集市网站推荐给程序员好友。
内容备注
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 gblab@vip.qq.com 举报,一经查实,本站将立刻删除。
内容手机端
扫描二维码推送至手机访问。