python – 将input_fn用于tf.contrib.learn.Estimator时设置batch_size
内容导读
互联网集市收集整理的这篇技术教程文章主要介绍了python – 将input_fn用于tf.contrib.learn.Estimator时设置batch_size,小编现在分享给大家,供广大互联网技能从业者学习和参考。文章包含2355字,纯文字阅读大概需要4分钟。
内容图文
![python – 将input_fn用于tf.contrib.learn.Estimator时设置batch_size](/upload/InfoBanner/zyjiaocheng/797/c472c479e8e84bf1bee124625013fbfa.jpg)
我在TF上使用高级Estimator:
estim = tf.contrib.learn.Estimator(...)
estim.fit ( some_input )
如果some_input有x,y和batch_size,则代码会运行,但会显示警告;所以我尝试使用input_fn,并设法通过此input_fn发送x,y,但不发送batch_size.没有找到任何例子.
任何人都可以共享一个使用input_fn作为estim.fit / estim.evaluate输入的简单示例,并使用batch_size吗?
我必须使用tf.train.batch吗?如果是这样,它如何合并到更高级别的实现(tf.layers) – 我不知道图形的tf.Graph()或会话?
以下是我收到的警告:
WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/monitors.py:657: calling evaluate
(from tensorflow.contrib.learn.python.learn.estimators.estimator) with y is deprecated and will be removed after 2016-12-01.
Instructions for updating:
Estimator is decoupled from Scikit Learn interface by moving into
separate class SKCompat. Arguments x, y and batch_size are only
available in the SKCompat class, Estimator will only accept input_fn.Example conversion:
est = Estimator(…) -> est = SKCompat(Estimator(…))
解决方法:
link provided in Roi’s own comment确实很有帮助.由于我一直在努力解决同样的问题,我想总结上面链接提供的答案作为参考:
def batched_input_fn(dataset_x, dataset_y, batch_size):
def _input_fn():
all_x = tf.constant(dataset_x, shape=dataset_x.shape, dtype=tf.float32)
all_y = tf.constant(dataset_y, shape=dataset_y.shape, dtype=tf.float32)
sliced_input = tf.train.slice_input_producer([all_x, all_y])
return tf.train.batch(sliced_input, batch_size=batch_size)
return _input_fn
然后可以像这个例子一样使用它(使用TensorFlow v1.1):
model = CustomModel(FLAGS.learning_rate)
estimator= tf.estimator.Estimator(model_fn=model.build(), params=model.params())
estimator.train(input_fn=batched_input_fn(
train.features,
train.labels,
FLAGS.batch_size),
steps=FLAGS.train_steps)
不幸的是,与使用TensorFlows低级API相比,使用整个数据集和使用train.shape [0] == batch_size而不使用train.sliced_input_producer()和train.batch(与使用TensorFlow低级API相比),这种方法慢了约10倍. ).至少在我的机器上(仅限CPU).我真的很想知道为什么这种方法太慢了.有任何想法吗?
编辑:
我可以通过使用num_threads>来加快速度. 1作为train.batch()的参数.在具有2个CPU的VM上,与默认的num_threads = 1相比,我可以使用此批处理机制将性能提高一倍.但是,它比手动喂食慢5倍.
但是在本机系统或使用输入管道的所有CPU核心和用于模型计算的GPU的系统上,结果可能会有所不同.如果有人可以在评论中发表他的结果,那将会很棒.
内容总结
以上是互联网集市为您收集整理的python – 将input_fn用于tf.contrib.learn.Estimator时设置batch_size全部内容,希望文章能够帮你解决python – 将input_fn用于tf.contrib.learn.Estimator时设置batch_size所遇到的程序开发问题。 如果觉得互联网集市技术教程内容还不错,欢迎将互联网集市网站推荐给程序员好友。
内容备注
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 gblab@vip.qq.com 举报,一经查实,本站将立刻删除。
内容手机端
扫描二维码推送至手机访问。