python – Tensorflow和cifar 10,测试单个图像
内容导读
互联网集市收集整理的这篇技术教程文章主要介绍了python – Tensorflow和cifar 10,测试单个图像,小编现在分享给大家,供广大互联网技能从业者学习和参考。文章包含3195字,纯文字阅读大概需要5分钟。
内容图文
我试图用tensorflow的cifar-10预测单个图像的类.
我找到了这个代码,但它失败了这个错误:
分配要求两个张量的形状匹配. lhs shape = [18,384] rhs shape = [2304,384]
我理解这是因为批次的大小只有1.(使用expand_dims我创建一个假批次.)
但我不知道如何解决这个问题?
我到处搜索但没有解决方案..
提前致谢!
from PIL import Image
import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10
width = 24
height = 24
categories = ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck" ]
filename = "path/to/jpg" # absolute path to input image
im = Image.open(filename)
im.save(filename, format='JPEG', subsampling=0, quality=100)
input_img = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
tf_cast = tf.cast(input_img, tf.float32)
float_image = tf.image.resize_image_with_crop_or_pad(tf_cast, height, width)
images = tf.expand_dims(float_image, 0)
logits = cifar10.inference(images)
_, top_k_pred = tf.nn.top_k(logits, k=5)
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state('/tmp/cifar10_train')
if ckpt and ckpt.model_checkpoint_path:
print("ckpt.model_checkpoint_path ", ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path)
else:
print('No checkpoint file found')
exit(0)
sess.run(init_op)
_, top_indices = sess.run([_, top_k_pred])
for key, value in enumerate(top_indices[0]):
print (categories[value] + ", " + str(_[0][key]))
编辑
我尝试放置一个占位符,在第一个形状中使用None,但是我收到了这个错误:
必须完全定义新变量(local3 / weights)的形状,而不是(?,384).
现在我真的迷路了……
这是新代码:
from PIL import Image
import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10
import itertools
width = 24
height = 24
categories = [ "airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck" ]
filename = "toto.jpg" # absolute path to input image
im = Image.open(filename)
im.save(filename, format='JPEG', subsampling=0, quality=100)
x = tf.placeholder(tf.float32, [None, 24, 24, 3])
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
# Restore variables from training checkpoint.
input_img = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
tf_cast = tf.cast(input_img, tf.float32)
float_image = tf.image.resize_image_with_crop_or_pad(tf_cast, height, width)
images = tf.expand_dims(float_image, 0)
i = images.eval()
print (i)
sess.run(init_op, feed_dict={x: i})
logits = cifar10.inference(x)
_, top_k_pred = tf.nn.top_k(logits, k=5)
variable_averages = tf.train.ExponentialMovingAverage(
cifar10.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
ckpt = tf.train.get_checkpoint_state('/tmp/cifar10_train')
if ckpt and ckpt.model_checkpoint_path:
print("ckpt.model_checkpoint_path ", ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path)
else:
print('No checkpoint file found')
exit(0)
_, top_indices = sess.run([_, top_k_pred])
for key, value in enumerate(top_indices[0]):
print (categories[value] + ", " + str(_[0][key]))
解决方法:
我认为这是因为tf.Variable或tf.get_variable获取的变量必须具有完整定义的形状.您可以检查代码并提供完整定义的形状.
内容总结
以上是互联网集市为您收集整理的python – Tensorflow和cifar 10,测试单个图像全部内容,希望文章能够帮你解决python – Tensorflow和cifar 10,测试单个图像所遇到的程序开发问题。 如果觉得互联网集市技术教程内容还不错,欢迎将互联网集市网站推荐给程序员好友。
内容备注
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 gblab@vip.qq.com 举报,一经查实,本站将立刻删除。
内容手机端
扫描二维码推送至手机访问。