python – 将Estimator转换为TPUEstimator
内容导读
互联网集市收集整理的这篇技术教程文章主要介绍了python – 将Estimator转换为TPUEstimator,小编现在分享给大家,供广大互联网技能从业者学习和参考。文章包含3826字,纯文字阅读大概需要6分钟。
内容图文
![python – 将Estimator转换为TPUEstimator](/upload/InfoBanner/zyjiaocheng/728/92d8f76d289a4283acb845596090e2f4.jpg)
是否有可能在TensorFlow中将Estimator转换为TPUEstimator而无需重写其功能?我有一个Estimator形式的模型,可以在CPU上很好地工作,但是我不知道将它转换为TPUEstimator的简便方法,而不必重写model_fn和input_fn.
这需要手动做大量工作的原因是我使用Keras创建我的模型,然后使用以下辅助函数来创建Estimator:
my_keras_model.compile(
optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),
loss='categorical_crossentropy',
metric='accuracy')
estimator = tf.keras.estimator.model_to_estimator(keras_model=my_keras_model)
如果我可以做像estimator.to_TPU_estimator()之类的东西,那会很棒 – 也许有人知道解决方案?
解决方法:
不能有这样的功能,因为model_fn规范在两个估计器中是不同的.有些差异很深,比如这个(从TPU tutorial开始):
When training on a cloud TPU you must wrap the optimizer in a
tf.contrib.tpu.CrossShardOptimizer
, which uses anallreduce
to
aggregate gradients and broadcast the result to each shard (each TPU
core).
这意味着修补keras优化器的内部并更新操作.
推荐的方法是为GPU和TPU模型提供不同的model_fn包装器,这似乎是最快的方式.在您的情况下,它意味着重写TPU估算器的keras model_to_estimator功能.
第一个也是最简单的近似是:
def model_to_estimator(keras_model=None,
keras_model_path=None,
custom_objects=None,
model_dir=None,
config=None):
keras_weights = keras_model.get_weights()
keras_model_fn = _create_keras_tpu_model_fn(keras_model, custom_objects)
est = tf.contrib.tpu.TPUEstimator(keras_model_fn, model_dir=model_dir, config=config)
_save_first_checkpoint(keras_model, est, custom_objects, keras_weights)
return est
这里,_save_first_checkpoint调用实际上是可选的,但如果您想保留它,请从tensorflow.python.keras._impl.keras.estimator导入此函数.
真正的工作发生在_create_keras_tpu_model_fn函数中,它取代了_create_keras_model_fn.更改是:
>如前所述,内部张量流优化器必须用CrossShardOptimizer包装,并且
>内部函数必须返回TPUEstimatorSpec.
可能还需要修补更多的行,但它看起来还不错.完整版本如下:
from tensorflow.python.keras._impl.keras.estimator import _save_first_checkpoint, _clone_and_build_model
def model_to_estimator(keras_model=None,
keras_model_path=None,
custom_objects=None,
model_dir=None,
config=None):
keras_weights = keras_model.get_weights()
keras_model_fn = _create_keras_tpu_model_fn(keras_model, custom_objects)
est = tf.contrib.tpu.TPUEstimator(keras_model_fn, model_dir=model_dir, config=config)
_save_first_checkpoint(keras_model, est, custom_objects, keras_weights)
return est
def _create_keras_tpu_model_fn(keras_model, custom_objects=None):
def model_fn(features, labels, mode):
"""model_fn for keras Estimator."""
model = _clone_and_build_model(mode, keras_model, custom_objects, features,
labels)
predictions = dict(zip(model.output_names, model.outputs))
loss = None
train_op = None
eval_metric_ops = None
# Set loss and metric only during train and evaluate.
if mode is not tf.estimator.ModeKeys.PREDICT:
model.optimizer.optimizer = tf.contrib.tpu.CrossShardOptimizer(model.optimizer.optimizer)
model._make_train_function() # pylint: disable=protected-access
loss = model.total_loss
if model.metrics:
eval_metric_ops = {}
# When each metric maps to an output
if isinstance(model.metrics, dict):
for i, output_name in enumerate(model.metrics.keys()):
metric_name = model.metrics[output_name]
if callable(metric_name):
metric_name = metric_name.__name__
# When some outputs use the same metric
if list(model.metrics.values()).count(metric_name) > 1:
metric_name += '_' + output_name
eval_metric_ops[metric_name] = tf.metrics.mean(
model.metrics_tensors[i - len(model.metrics)])
else:
for i, metric_name in enumerate(model.metrics):
if callable(metric_name):
metric_name = metric_name.__name__
eval_metric_ops[metric_name] = tf.metrics.mean(
model.metrics_tensors[i])
if mode is tf.estimator.ModeKeys.TRAIN:
train_op = model.train_function.updates_op
return tf.contrib.tpu.TPUEstimatorSpec(
mode=mode,
predictions=predictions,
loss=loss,
train_op=train_op,
eval_metric_ops=eval_metric_ops)
return model_fn
内容总结
以上是互联网集市为您收集整理的python – 将Estimator转换为TPUEstimator全部内容,希望文章能够帮你解决python – 将Estimator转换为TPUEstimator所遇到的程序开发问题。 如果觉得互联网集市技术教程内容还不错,欢迎将互联网集市网站推荐给程序员好友。
内容备注
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 gblab@vip.qq.com 举报,一经查实,本站将立刻删除。
内容手机端
扫描二维码推送至手机访问。