tensorflow学习笔记-SavedModel文件解释及TFServing的模型加载、使用
tensorflow基本概念:https://www.cnblogs.com/wanyu416/p/8954098.html 这里是一系列文章
Tensorflow SavedModel 模型的保存和加载 https://www.jianshu.com/p/83cfd9571158
Tensorflow如何加载离线模型 https://www.zhihu.com/question/300914772
TensorFlow模型的跨平台部署 https://zhuanlan.zhihu.com/p/40481765
TensorFlow程序结构 http://c.biancheng.net/view/1883.html
SavedModel的格式:https://www.tensorflow.org/guide/saved_model
SavedModel 是一个包含序列化签名和运行这些签名所需的状态的目录,其中包括变量值和词汇表。
目录如下:
saved_model.pb 文件用于存储实际 TensorFlow 程序或模型,以及一组已命名的签名(signatures)——每个签名标识一个接受tensor输入和产生tensor输出的函数。
variables 目录包含一个标准训练检查点(checkpoint)
名词解释:
signatures: 使用SavedModel保存的签名。只适用于“tf”格式,详情查看 tf.saved_model.save
checkpoint: 检查点,保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练。这种在训练中保存模型,习惯上称之为保存检查点。
TensorFlow——Checkpoint为模型添加检查点 https://www.cnblogs.com/baby-lily/p/10930591.html
TFS模型加载:
tensorflow/tensorflow/cc/saved_model/loader.cc
/// Checks whether the provided directory could contain a SavedModel. Note that
/// the method does not load any data by itself. If the method returns `false`,
/// the export directory definitely does not contain a SavedModel. If the method
/// returns `true`, the export directory may contain a SavedModel but provides
/// no guarantee that it can be loaded.
bool MaybeSavedModelDirectory(const string& export_dir);
检查提供的目录是否可以包含SavedModel。 请注意,该方法本身不会加载任何数据。 如果该方法返回false
,则导出目录肯定不包含SavedModel。 如果该方法返回true
,则导出目录可能包含SavedModel,但不保证可以加载它。
/// Loads a SavedModel from the specified export directory. The meta graph def
/// to be loaded is identified by the supplied tags, corresponding exactly to
/// the set of tags used at SavedModel build time. Returns a SavedModel bundle
/// with a session and the requested meta graph def, if found.
Status LoadSavedModel(const SessionOptions& session_options,
const RunOptions& run_options, const string& export_dir,
const std::unordered_set<string>& tags,
SavedModelBundle* const bundle);
从指定的导出目录加载SavedModel。 所要加载的meta graph def由所提供的tag标识,该标签恰好与SavedModel构建时使用的标签集相对应。 如果找到,返回带有会话和请求的meta graph def的SavedModel bundle。
Eg:例子: https://gist.github.com/OneRaynyDay/c79346890dda095aecc6e9249a9ff3e1
tensorflow::MaybeSavedModelDirectory
tensorflow::LoadSavedModel
bundle.session->Run
点击查看例子
#include <tensorflow/cc/saved_model/loader.h>
#include <tensorflow/cc/saved_model/tag_constants.h>
#include <tensorflow/core/public/session_options.h>
#include <tensorflow/core/framework/tensor.h>
#include <xtensor/xarray.hpp>
#include <xtensor/xnpy.hpp>
#include <string>
#include <iostream>
#include <vector>
#include <cfloat>
static const int IMG_SIZE = 784;
static const int NUM_SAMPLES = 10000;
tensorflow::Tensor load_npy_img(const std::string& filename) {
auto data = xt::load_npy<float>(filename);
tensorflow::Tensor t(tensorflow::DT_FLOAT, tensorflow::TensorShape({NUM_SAMPLES, IMG_SIZE}));
for (int i = 0; i < NUM_SAMPLES; i++)
for (int j = 0; j < IMG_SIZE; j++)
t.tensor<float, 2>()(i,j) = data(i, j);
return t;
}
std::vector<int> get_tensor_shape(const tensorflow::Tensor& tensor)
{
std::vector<int> shape;
auto num_dimensions = tensor.shape().dims();
for(int i=0; i < num_dimensions; i++) {
shape.push_back(tensor.shape().dim_size(i));
}
return shape;
}
template <typename M>
void print_keys(const M& sig_map) {
for (auto const& p : sig_map) {
std::cout << "key : " << p.first << std::endl;
}
}
template <typename K, typename M>
bool assert_in(const K& k, const M& m) {
return !(m.find(k) == m.end());
}
std::string _input_name = "digits";
std::string _output_name = "predictions";
int main() {
// This is passed into LoadSavedModel to be populated.
tensorflow::SavedModelBundle bundle;
// From docs: "If 'target' is empty or unspecified, the local TensorFlow runtime
// implementation will be used. Otherwise, the TensorFlow engine
// defined by 'target' will be used to perform all computations."
tensorflow::SessionOptions session_options;
// Run option flags here: https://www.tensorflow.org/api_docs/python/tf/compat/v1/RunOptions
// We don't need any of these yet.
tensorflow::RunOptions run_options;
// Fills in this from a session run call
std::vector<tensorflow::Tensor> out;
std::string dir = "pyfiles/foo";
std::string npy_file = "pyfiles/data.npy";
std::string prediction_npy_file = "pyfiles/predictions.npy";
std::cout << "Found model: " << tensorflow::MaybeSavedModelDirectory(dir) << std::endl;
// TF_CHECK_OK takes the status and checks whether it works.
TF_CHECK_OK(tensorflow::LoadSavedModel(session_options,
run_options,
dir,
// Refer to tag_constants. We just want to serve the model.
{tensorflow::kSavedModelTagServe},
&bundle));
auto sig_map = bundle.meta_graph_def.signature_def();
// not sure why it's called this but upon running this for loop to check for keys we see it.
print_keys(sig_map);
std::string sig_def = "serving_default";
auto model_def = sig_map.at(sig_def);
auto inputs = model_def.inputs().at(_input_name);
auto input_name = inputs.name();
auto outputs = model_def.outputs().at(_output_name);
auto output_name = outputs.name();
auto input = load_npy_img(npy_file);
TF_CHECK_OK(bundle.session->Run({{input_name, input}},
{output_name},
{},
&out));
std::cout << out[0].DebugString() << std::endl;
auto res = out[0];
auto shape = get_tensor_shape(res);
// we only care about the first dimension of shape
xt::xarray<float> predictions = xt::zeros<float>({shape[0]});
for(int row = 0; row < shape[0]; row++) {
float max = FLT_MIN;
int max_idx = -1;
for(int col = 0; col < shape[1]; col++) {
auto val = res.tensor<float, 2>()(row, col);
if(max < val) {
max_idx = col;
max = val;
}
}
predictions(row) = max_idx;
}
xt::dump_npy(prediction_npy_file, predictions);
}
最新评论