tf.estimator的使用

作者: admin 日期: 2017-10-10 15:57:47 人气: - 评论: 0

tf.estimator是一个高级的训练框架,可以不用手写大量和训练有关的代码


import tensorflow as tf
#np是一个数据预处理的库
import numpy as np

#定义特征列表表这里只有一个特征x
feature_columns = [tf.feature_column.numeric_column("x", shape=[1])]

#创建一个线性回归拟合
estimator = tf.estimator.LinearRegressor(feature_columns=feature_columns)

#设置训练数据
x_train = np.array([1., 2., 3., 4.])
y_train = np.array([0., -1., -2., -3.])

#设置评估数据
x_eval = np.array([2., 5., 8., 1.])
y_eval = np.array([-1.01, -4.1, -7, 0.])

#设置训练函数
input_fn = tf.estimator.inputs.numpy_input_fn(
{"x": x_train}, y_train, batch_size=4, num_epochs=None, shuffle=True)

#设置评估函数
eval_input_fn = tf.estimator.inputs.numpy_input_fn(
{"x": x_eval}, y_eval, batch_size=4, shuffle=False)

#训练
estimator.train(input_fn=input_fn, steps=1000)

#评估
eval_metrics = estimator.evaluate(input_fn=eval_input_fn)
print("eval metrics: %r"% eval_metrics)


相关内容

发表评论
更多 网友评论0 条评论)
暂无评论

Copyright © 2012-2014 我的代码板 Inc. 保留所有权利。

页面耗时0.0411秒, 内存占用1.82 MB, 访问数据库13次

闽ICP备15009223号-1