文章目录
- 环境
- Libtorch下载
- Pytorch将.pth转为.pt文件
-
- python环境下的预测
-
- 输出结果:rose
- 新建pt模型生成文件
-
- 输出结果:rose
- C++调用pytorch模型
-
- 新建空项目pt_alex
- 项目属性配置
-
- 修改配置管理器
- 属性>VC++目录>包含目录
- 属性>VC++目录>库目录
- 属性>链接器>输入>附加依赖项
- 注意CUDA下的情况
- 属性>C/C++
- 项目下新建test.cpp
-
- 输出结果:rose
- C# Demo
-
- 新建C++空项目,封装DLL
-
- 源码
- 项目属性
- 点击生成解决方案,生成DLL
- 新建C#窗体应用
-
- DLLFun.cs
- 窗体Form2.cs核心代码
- 结果
参考:C++调用PyTorch模型:LibTorch
环境
Windows10
VS2017
CPUOpenCV3.0.0Pytorch1.10.2torchvision0.11.3
Libtorch1.10.2
Libtorch下载 Pytorch官网
文章图片
解压后:注意红框文件夹路径,之后需要添加到项目属性配置中。
文章图片
Pytorch将.pth转为.pt文件 所使用的模型为基于AlexNet的分类模型:AlexNet:论文阅读及pytorch网络搭建
python环境下的预测 输出结果:rose
文章图片
文章图片
新建pt模型生成文件
# tmp.pyimport os
import torch
from PIL import Image
from torchvision import transforms
from model import AlexNetdef main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))# create model
model = AlexNet(num_classes=5).to(device)image = Image.open("rose2.jpg").convert('RGB')
data_transform = transforms.Compose(
[transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
img = data_transform(image)
img = img.unsqueeze(dim=0)
print(img.shape)# load model weights
weights_path = "AlexNet.pth"
assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)testsize = 224if torch.cuda.is_available():
modelState = torch.load(weights_path, map_location='cuda')
model.load_state_dict(modelState, strict=False)
model = model.cuda()
model = model.eval()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, testsize, testsize)
example = example.cuda()
traced_script_module = torch.jit.trace(model, example)output = traced_script_module(img.cuda())
print(output.shape)
pred = torch.argmax(output, dim=1)
print(pred)traced_script_module.save('model_cuda.pt')
else:
modelState = torch.load(weights_path, map_location='cpu')
model.load_state_dict(modelState, strict=False)
example = torch.rand(1, 3, testsize, testsize)
example = example.cpu()
traced_script_module = torch.jit.trace(model, example)output = traced_script_module(img.cpu())
print(output.shape)
pred = torch.argmax(output, dim=1)
print(pred)traced_script_module.save('model.pt')if __name__ == '__main__':
main()
输出结果:rose
【C/C++/C#|C++(Windows平台下利用LibTorch调用PyTorch模型)】
文章图片
文章图片
C++调用pytorch模型 新建空项目pt_alex
文章图片
项目属性配置 修改配置管理器
Release/x64
文章图片
属性>VC++目录>包含目录
添加:(libtorch解压位置)
文章图片
注意还应有opencv目录:(继承值修改可参考)
文章图片
属性>VC++目录>库目录
添加:
文章图片
属性>链接器>输入>附加依赖项
添加:
文章图片
注意:
如果后续出现error:找不到c10.dll,
直接把该目录下的相应dll复制到项目pt_alex/x64/Release文件夹下。
注意还应有opencv目录:(Debug下为lib*d.lib)
文章图片
注意CUDA下的情况
链接器>命令行,添加:
/INCLUDE:?warp_size@cuda@at@@YAHXZ
属性>C/C++
常规>SDL检查:选择否
语言>符合模式:选择否
项目下新建test.cpp c++调用后分类结果不准确的参考:
Ptorch 与libTorch 使用过程中问题记录
注意python和c++中的图像预处理过程需要完全一致。
// test.cpp#include // One-stop header.
#include "torch/torch.h"
#include
#include "opencv2/core.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/imgcodecs.hpp"
#include
#include
#include
#include
#include
#include // class_list
/*
"0": "daisy",
"1": "dandelion",
"2": "roses",
"3": "sunflowers",
"4": "tulips"
*/std::string classList[5] = { "daisy", "dandelion", "rose", "sunflower", "tulip" };
std::string image_path = "rose2.jpg";
int main(int argc, const char* argv[]) { // Deserialize the ScriptModule from a file using torch::jit::load().
//std::shared_ptr module = torch::jit::load("../../model_resnet_jit.pt");
using torch::jit::script::Module;
Module module = torch::jit::load("model.pt");
std::cout << "测试图片:" << image_path << std::endl;
std::cout << "cuda support:" << (torch::cuda::is_available() ? "ture" : "false") << std::endl;
std::cout << "CUDNN:" << torch::cuda::cudnn_is_available() << std::endl;
std::cout << "GPU(s): " << torch::cuda::device_count() << std::endl;
// module.to(at::kCUDA);
//cpu下会在(auto image = cv::imread(image_path, cv::IMREAD_COLOR))行引起c10:error,未经处理的异常
module.eval();
module.to(at::kCPU);
//assert(module != nullptr);
//std::cout << "ok\n";
//输入图像
auto image = cv::imread(image_path, cv::IMREAD_COLOR);
cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
cv::Mat image_transfomed = cv::Mat(cv::Size(224, 224), image.type());
cv::resize(image, image_transfomed, cv::Size(224, 224));
//cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);
// 转换为Tensor
torch::Tensor tensor_image = torch::from_blob(image_transfomed.data,
{ image_transfomed.rows, image_transfomed.cols,3 }, torch::kByte);
tensor_image = tensor_image.permute({ 2,0,1 });
tensor_image = tensor_image.toType(torch::kFloat);
auto tensor_image_Tmp = torch::autograd::make_variable(tensor_image, false);
tensor_image = tensor_image.div(255);
tensor_image = tensor_image.unsqueeze(0);
// tensor_image = tensor_image.to(at::kCUDA);
tensor_image = tensor_image.to(at::kCPU);
// 网络前向计算
at::Tensor output = module.forward({ tensor_image }).toTensor();
std::cout << "output:" << output << std::endl;
auto prediction = output.argmax(1);
std::cout << "prediction:" << prediction << std::endl;
int maxk = 5;
auto top3 = std::get<1>(output.topk(maxk, 1, true, true));
std::cout << "top3: " << top3 << '\n';
std::vector res;
for (auto i = 0;
i < maxk;
i++) {
res.push_back(top3[0][i].item().toInt());
}
// for (auto i : res) {
//std::cout << i << " ";
// }
// std::cout << "\n";
int pre = torch::Tensor(prediction).item();
std::string result = classList[pre];
std::cout << "This is:" << result << std::endl;
cvWaitKey();
return 0;
// system("pause");
}
出现以下报错不影响项目生成:
文章图片
输出结果:rose
文章图片
C# Demo 新建C++空项目,封装DLL
- 传入图像路径
- 传出类别序号
- // class_list
/*
“0”: “daisy”,
“1”: “dandelion”,
“2”: “roses”,
“3”: “sunflowers”,
“4”: “tulips”
*/
// test.cpp#include // One-stop header.
#include "torch/torch.h"
#include
#include "opencv2/core.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/imgcodecs.hpp"
#include
#include
#include
#include
#include
#include #include "test.h"std::string classList[5] = { "daisy", "dandelion", "rose", "sunflower", "tulip" };
int TestAlex(char* img)
{
// Deserialize the ScriptModule from a file using torch::jit::load().
//std::shared_ptr module = torch::jit::load("../../model_resnet_jit.pt");
using torch::jit::script::Module;
Module module = torch::jit::load("D:/model.pt");
// ...... 略 int pre = torch::Tensor(prediction).item();
std::string result = classList[pre];
//std::cout << "This is:" << result << std::endl;
return pre;
// system("pause");
}
//test.h#pragma once#include
#include extern "C" __declspec(dllexport) int TestAlex(char* img);
项目属性
- 输出目录改为C#项目路径
文章图片
新建C#窗体应用 DLLFun.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Runtime.InteropServices;
namespace AlexDemo
{
class DllFun
{
public string img;
[DllImport("AlexDLL.dll", CallingConvention = CallingConvention.Cdecl)]
public extern static int TestAlex(string img);
// 注意 C++ char* 对应 C# string
}
}
窗体Form2.cs核心代码
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Data;
using System.Drawing;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Windows.Forms;
using System.IO;
using AlexDemo;
namespace Demo
{
public partial class Form2 : Form
{
public Form2()
{
InitializeComponent();
this.Load += new EventHandler(Form2_Load);
//窗体启动后自动执行事件
}private void Form2_Load(object sender, EventArgs e)
{
string[] classList = { "daisy", "dandelion", "rose", "sunflower", "tulip" };
string fname = "path.txt";
//StreamReader sr = new StreamReader(fname, Encoding.Default);
StreamReader sr = new StreamReader(fname, Encoding.GetEncoding("gb2312"));
string line = sr.ReadLine();
//读取txt文件
if (line != null)
{
this.pictureBox1.Image = Image.FromFile(line);
if (line.Contains("\\"))
{
line = line.Replace("\\", "/");
}
}
int result;
result = DllFun.TestAlex(line);
//string r = result.ToString();
label2.Text = classList[result];
//StringBuilder img;
//img = new StringBuilder(1024);
//img.Append(line);
//int r = DllFun.TestAlex(img);
//label2.Text = "123";
}private void label1_Click(object sender, EventArgs e)
{}}
}
结果
文章图片
推荐阅读
- C|【C语言|RUNOOB教程】100道经典例题详解(1~5题)
- C/C++课程设计代码|C语言实现的2048小游戏
- C/C++|孪生素数——C语言实现
- 常用的七大排序算法
- COMP2401 simulator
- C/C++程序员进阶课堂|纯新网络编程教学
- 学习态度|为什么C++比C语言麻烦这么多,程序员笑了(这些点你知道吗())
- MIPS simulator 项目
- STM32实验讲解|STM32控制舵机讲解,从入门到放弃。