tf.estimator是一个高级的训练框架,可以不用手写大量和训练有关的代码
import tensorflow as tf
#np是一个数据预处理的库
import numpy as np
#定义特征列表表这里只有一个特征x
feature_columns = [tf.feature_column.numeric_column("x", shape=[1])]
#创建一个线性回归拟合
estimator = tf.estimator.LinearRegressor(feature_columns=feature_columns)
#设置训练数据
x_train = np.array([1., 2., 3., 4.])
y_train = np.array([0., -1., -2., -3.])
#设置评估数据
x_eval = np.array([2., 5., 8., 1.])
y_eval = np.array([-1.01, -4.1, -7, 0.])
#设置训练函数
input_fn = tf.estimator.inputs.numpy_input_fn(
{"x": x_train}, y_train, batch_size=4, num_epochs=None, shuffle=True)
#设置评估函数
eval_input_fn = tf.estimator.inputs.numpy_input_fn(
{"x": x_eval}, y_eval, batch_size=4, shuffle=False)
#训练
estimator.train(input_fn=input_fn, steps=1000)
#评估
eval_metrics = estimator.evaluate(input_fn=eval_input_fn)
print("eval metrics: %r"% eval_metrics)