当前位置: 代码网 > it编程>硬件开发>stm32 > CUBEAI详细使用教程(STM32运行神经网络)---以手写识别为例

CUBEAI详细使用教程(STM32运行神经网络)---以手写识别为例

2024年07月31日 stm32 我要评论
实验效果,通过上位机上传图像到单片机识别后返回识别结果CUBEAI(Cube Artificial Intelligence)是一种人工智能(AI)中间件,旨在为嵌入式系统提供高效、灵活的神经网络推理能力。该中间件的设计目标是在资源有限的嵌入式设备上实现深度学习推理,从而为物联网(IoT)设备、嵌入式系统和边缘计算场景提供强大的人工智能支持。

系列文章目录



前言

实验效果,通过上位机上传图像到单片机识别后返回识别结果
在这里插入图片描述


cubeai(cube artificial intelligence)是一种人工智能(ai)中间件,旨在为嵌入式系统提供高效、灵活的神经网络推理能力。该中间件的设计目标是在资源有限的嵌入式设备上实现深度学习推理,从而为物联网(iot)设备、嵌入式系统和边缘计算场景提供强大的人工智能支持。

一、cubemx配置步骤

下载x-cube-ai工具包
示例:pandas 是基于numpy 的一种工具,该工具是为了解决数据分析任务而创建的。

在这里插入图片描述

创建项目选择能运行神经网络的mcu,stm32f4以上系列的都支持部署神经网络。

在这里插入图片描述

在这里插入图片描述
设置串口,注意一定要把串口中断打开要不然接收不到数据。
在这里插入图片描述
选择ai工具包,选择最近的cubeai工具包。最新的是8.1版本的,不同版本的生成的代码有些会不一样,有的模型老版本可以部署新版本就不行了,挺让人头疼的。
在这里插入图片描述
在这里插入图片描述

利用cubeai工具分析模型,这里并没有选择去优化压缩,因为模型本来就很小。cubeai提供的是一种无数据压缩,压缩效率并不是很高,也会损失精度。
在这里插入图片描述
进行分析得到,模型分析后的结果。会有模型的flash和ram
在这里插入图片描述

二、模型结构及模型存储方式

网络结构利用netron这个网址可以将模型进行可视化。模型结构如下图所示。
在这里插入图片描述
在这里插入图片描述
模型如何在单片机上进行存储

  1. 权重参数:存放在flash中
  2. 激活值:存在mcu自带的sram中,也可以使用外部的sd卡或者外接sram进行存储。需要用户自己定义
  3. 输入输出数据:存放在sram中需要用户自己定义。如下图所示

在这里插入图片描述
以下是官方文档中翻译过来的

三、常用api函数

cubeai模型推理库不开源只开放了一些api接口,因此我们必须了解一些常用的api使用方法。
在这里插入图片描述
在这里插入图片描述

1.ai_(name)_create()

2.ai_(name)_init

3.ai_(name)_create_and_init()

应用案例:进行创建和初始化网络模型

#include "network.h"
#include "network_data.h"

ai_aligned(32)
static ai_u8 activations[ai_network_data_activations_size];
...
const ai_handle acts[] = { activations };

ai_network_create_and_init(&network, acts, null);

3.ai_(name)_run()

官方提供的示例代码

根据示例代码可知,自己需要修改的地方主要有三点。
1、定义网络句柄、输入、输出和激活缓冲区数据buf 和管理输入和输出数据的指针。
2、获取和处理数据并进行推理。如果是图像分类任务,图像数据可能是来自摄像头模块获取。其他任务例如运动检测任务,数据可能会来自六轴加速度传感器。
3、对输出数据进行后处理。对于分类任务,模型输出结果是每个类别的概率,有时我们要输出准确率最高的类别就要写一个求最大值的函数

#include <stdio.h>

#include "network.h"
#include "network_data.h"

/* global handle to reference the instantiated c-model *//*引用实例化的c-model的全局句柄*/
static ai_handle network = ai_handle_null;

/* global c-array to handle the activations buffer *//*用于处理激活缓冲区的全局c数组*/
ai_aligned(32)
static ai_u8 activations[ai_network_data_activations_size];

/* array to store the data of the input tensor *//*用于存储输入张量数据的数组*/
ai_aligned(32)
static ai_float in_data[ai_network_in_1_size];
/* or static ai_u8 in_data[ai_network_in_1_size_bytes]; *//*用于存储输出张量数据的c数组*/

/* c-array to store the data of the output tensor */
ai_aligned(32)
static ai_float out_data[ai_network_out_1_size];
/* static ai_u8 out_data[ai_network_out_1_size_bytes]; */

/* array of pointer to manage the model's input/output tensors */
/*用于管理模型输入/输出张量的指针数组*/
static ai_buffer *ai_input;
static ai_buffer *ai_output;

/* 
 * bootstrap
 */
int aiinit(void) {
  ai_error err;
  
  /* create and initialize the c-model *//*创建并初始化c-model */
  const ai_handle acts[] = { activations };
  err = ai_network_create_and_init(&network, acts, null);
  if (err.type != ai_error_none) { ... };

  /* reteive pointers to the model's input/output tensors *//*获取指向模型输入/输出张量的指针*/
  ai_input = ai_network_inputs_get(network, null);
  ai_output = ai_network_outputs_get(network, null);

  return 0;
}

/* 
 * run inference
 */
int airun(const void *in_data, void *out_data) {
  ai_i32 n_batch;
  ai_error err;
  
  /* 1 - update io handlers with the data payload *//* 1 -用数据有效载荷更新io处理程序*/
  ai_input[0].data = ai_handle_ptr(in_data);
  ai_output[0].data = ai_handle_ptr(out_data);

  /* 2 - perform the inference *//* 2 -执行推理*/
  n_batch = ai_network_run(network, &ai_input[0], &ai_output[0]);
  if (n_batch != 1) {
      err = ai_network_get_error(network);
      ...
  };
  
  return 0;
}

/* 
 * example of main loop function
 */
void main_loop() {
  /* the stm32 crc ip clock should be enabled to use the network runtime library */
  /*应启用stm32 crc ip时钟,以使用网络运行时库*/
  __hal_rcc_crc_clk_enable();

  aiinit();

  while (1) {
    /* 1 - acquire, pre-process and fill the input buffers */
    /* 1 -获取、预处理和填充输入缓冲区*/
    acquire_and_process_data(in_data);

    /* 2 - call inference engine */
    /* 2 -调用推理引擎*/
    airun(in_data, out_data);

    /* 3 - post-process the predictions */
    /* 3——对预测进行后处理*/
    post_process(out_data);
  }
}

四、如何获取官方开发文档

在实际应用中发现cubeai在每个版本会有些差别,我训练好的模型在cubeai7.3可以使用但是在cubeai8.1就无法使用,每个版本也会新增加一些算子。可以在下载安装包中找到相对应的文档说明,例如在我的电脑中的这个目录获取文档
file:///c:/users/wg/stm32cube/repository/packs/stmicroelectronics/x-cube-ai/8.1.0/documentation/embedded_client_api.html
在这里插入图片描述

在这里插入图片描述

五、手写识别案例

代码来自https://github.com/colin2135/stm32g070_ai_test.git 大家可以去star一下。感谢作者的开源。作者使用的是stm32g0系列的单片机,我是使用的是stm32f7和正点原子的h5mini,f4系列的都可以。目前还没有尝试f1系列的单片机。

一、代码思路:1. 利用上位机将手写数据通过串口发送给单片机。2.单片机进行获取数据利用神经网络进行判断3.将输出结果发送给上位机。上位机链接https://github.com/colin2135/handwriteapp,下载源码后,找到这个文件路径。在这里插入图片描述

二、代码实现
1、定义网络句柄、输入、输出和激活缓冲区数据buf 和管理输入和输出数据的指针。

ai_handle network;
float aiindata[ai_network_in_1_size];
float aioutdata[ai_network_out_1_size];
ai_u8 activations[ai_network_data_activations_size];

ai_buffer * ai_input;
ai_buffer * ai_output;

2、获取和处理数据
通过串口获取数据

/* user code begin 4 */利用回调函数接收上位机发来的手写数字的数据
void hal_uart_rxcpltcallback(uart_handletypedef *uarthandle)
{
	if(gorunning ==0)
	{
		if (uart_rx_length < uart_buff_len)
		{
			uart_rx_buffer[uart_rx_length] = uart_rx_byte;
			uart_rx_length++;

			if (uart_rx_byte == '\n')
			{
				gorunning = 1;
			}
		}
		else
		{
			//rt_kprintf("rx len over");
			uart_rx_length = 0;
		}
	}
	hal_uart_receive_it(&huart1, (uint8_t *)&uart_rx_byte, 1);
}

处理数据,通过上位机发送的数据是8位数据,由于模型参数是32位浮点数因此输入数据要转换成32位浮点数

void picturechararraytofloat(uint8_t *srcbuf,float *dstbuf,int len)
{
	for(int i=0;i<len;i++)
	{
		dstbuf[i] = srcbuf[i];//==1?0:1;
	}
}

神经网络推理

static void ai_run(float *pin, float *pout)
{
	char logstr[100];
	int count = 0;
	float max = 0;
  ai_i32 batch;
  ai_error err;

  /* update io handlers with the data payload */
  ai_input[0].data = ai_handle_ptr(pin);
  ai_output[0].data = ai_handle_ptr(pout);

  batch = ai_network_run(network, ai_input, ai_output);
  if (batch != 1) {
    err = ai_network_get_error(network);
    printf("ai ai_network_run error - type=%d code=%d\r\n", err.type, err.code);
    error_handler();
  }
  for (uint32_t i = 0; i < ai_network_out_1_size; i++) {

	  sprintf(logstr,"%d  %8.6f\r\n",i,aioutdata[i]);
	  uart_send(logstr);
	  if(max<aioutdata[i])
	  {
		  count = i;
		  max= aioutdata[i];
	  }
  }
  sprintf(logstr,"current number is %d\r\n",count);
  uart_send(logstr);
}

3、对输出数据进行后处理
将输出结果和最大值通过串口进行发送

 for (uint32_t i = 0; i < ai_network_out_1_size; i++) {

	  sprintf(logstr,"%d  %8.6f\r\n",i,aioutdata[i]);
	  uart_send(logstr);
	  if(max<aioutdata[i])
	  {
		  count = i;
		  max= aioutdata[i];
	  }
  }

3.whlie(1)代码

  while (1)
  {
    /* user code end while */

    /* user code begin 3 */
	  uart_send(message);
	  char str[10];
	  if(gorunning>0)
	  {
		  if(uart_rx_length == one_frame_len)
		  {
			  picturechararraytofloat(uart_rx_buffer+1,aiindata,28*28);
			  ai_run(aiindata, aioutdata);

		  }
		  memset(uart_rx_buffer,0,784);
		  gorunning = 0;
		  uart_rx_length = 0;
	  }
  }
  /* user code end 3 */

完整代码

/* user code begin header */
/**
  ******************************************************************************
  * @file           : main.c
  * @brief          : main program body
  ******************************************************************************
  * @attention
  *
  * copyright (c) 2024 stmicroelectronics.
  * all rights reserved.
  *
  * this software is licensed under terms that can be found in the license file
  * in the root directory of this software component.
  * if no license file comes with this software, it is provided as-is.
  *
  ******************************************************************************
  */
/* user code end header */
/* includes ------------------------------------------------------------------*/
#include "main.h"

/* private includes ----------------------------------------------------------*/
/* user code begin includes */
#include "stdio.h"
#include "ai_platform.h"
#include "network.h"
#include "network_data.h"

/* user code end includes */

/* private typedef -----------------------------------------------------------*/
/* user code begin ptd */

/* user code end ptd */

/* private define ------------------------------------------------------------*/
/* user code begin pd */

/* user code end pd */

/* private macro -------------------------------------------------------------*/
/* user code begin pm */

/* user code end pm */

/* private variables ---------------------------------------------------------*/

crc_handletypedef hcrc;

uart_handletypedef huart1;

/* user code begin pv */

/* user code end pv */

/* private function prototypes -----------------------------------------------*/
void systemclock_config(void);
static void mx_gpio_init(void);
static void mx_crc_init(void);
static void mx_usart1_uart_init(void);
/* user code begin pfp */

/* user code end pfp */

/* private user code ---------------------------------------------------------*/
/* user code begin 0 */
ai_handle network;
float aiindata[ai_network_in_1_size];
float aioutdata[ai_network_out_1_size];
ai_u8 activations[ai_network_data_activations_size];

ai_buffer * ai_input;
ai_buffer * ai_output;

static void ai_init(void);
static void ai_run(float *pin, float *pout);
void picturechararraytofloat(uint8_t *srcbuf,float *dstbuf,int len);


void uart_send(char * str);
#define uart_buff_len 1024
#define one_frame_len 1+784+2
uint16_t uart_rx_length = 0;
uint8_t uart_rx_byte = 0;
uint8_t uart_rx_buffer[uart_buff_len];
volatile uint8_t gorunning = 0;

char message[]="hello";

/* user code end 0 */

/**
  * @brief  the application entry point.
  * @retval int
  */
int main(void)
{
  /* user code begin 1 */

  /* user code end 1 */

  /* enable i-cache---------------------------------------------------------*/
//  scb_enableicache();
//
//  /* enable d-cache---------------------------------------------------------*/
//  scb_enabledcache();

  /* mcu configuration--------------------------------------------------------*/

  /* reset of all peripherals, initializes the flash interface and the systick. */
  hal_init();

  /* user code begin init */

  /* user code end init */

  /* configure the system clock */
  systemclock_config();

  /* user code begin sysinit */

  /* user code end sysinit */

  /* initialize all configured peripherals */
  mx_gpio_init();
  mx_crc_init();
  mx_usart1_uart_init();
  /* user code begin 2 */
  ai_init();
  memset(uart_rx_buffer,0,784);
  hal_uart_receive_it(&huart1, (uint8_t *)&uart_rx_byte, 1);


  /* user code end 2 */

  /* infinite loop */
  /* user code begin while */
  while (1)
  {
    /* user code end while */

    /* user code begin 3 */
	  uart_send(message);
	  char str[10];
	  if(gorunning>0)
	  {
		  if(uart_rx_length == one_frame_len)
		  {
			  picturechararraytofloat(uart_rx_buffer+1,aiindata,28*28);
			  ai_run(aiindata, aioutdata);

		  }
		  memset(uart_rx_buffer,0,784);
		  gorunning = 0;
		  uart_rx_length = 0;
	  }
  }
  /* user code end 3 */
}

/**
  * @brief system clock configuration
  * @retval none
  */
void systemclock_config(void)
{
  rcc_oscinittypedef rcc_oscinitstruct = {0};
  rcc_clkinittypedef rcc_clkinitstruct = {0};

  /** configure the main internal regulator output voltage
  */
  __hal_rcc_pwr_clk_enable();
  __hal_pwr_voltagescaling_config(pwr_regulator_voltage_scale1);

  /** initializes the rcc oscillators according to the specified parameters
  * in the rcc_oscinittypedef structure.
  */
  rcc_oscinitstruct.oscillatortype = rcc_oscillatortype_hsi;
  rcc_oscinitstruct.hsistate = rcc_hsi_on;
  rcc_oscinitstruct.hsicalibrationvalue = rcc_hsicalibration_default;
  rcc_oscinitstruct.pll.pllstate = rcc_pll_on;
  rcc_oscinitstruct.pll.pllsource = rcc_pllsource_hsi;
  rcc_oscinitstruct.pll.pllm = 8;
  rcc_oscinitstruct.pll.plln = 216;
  rcc_oscinitstruct.pll.pllp = rcc_pllp_div2;
  rcc_oscinitstruct.pll.pllq = 2;
  if (hal_rcc_oscconfig(&rcc_oscinitstruct) != hal_ok)
  {
    error_handler();
  }

  /** activate the over-drive mode
  */
  if (hal_pwrex_enableoverdrive() != hal_ok)
  {
    error_handler();
  }

  /** initializes the cpu, ahb and apb buses clocks
  */
  rcc_clkinitstruct.clocktype = rcc_clocktype_hclk|rcc_clocktype_sysclk
                              |rcc_clocktype_pclk1|rcc_clocktype_pclk2;
  rcc_clkinitstruct.sysclksource = rcc_sysclksource_pllclk;
  rcc_clkinitstruct.ahbclkdivider = rcc_sysclk_div1;
  rcc_clkinitstruct.apb1clkdivider = rcc_hclk_div4;
  rcc_clkinitstruct.apb2clkdivider = rcc_hclk_div2;

  if (hal_rcc_clockconfig(&rcc_clkinitstruct, flash_latency_7) != hal_ok)
  {
    error_handler();
  }
}

/**
  * @brief crc initialization function
  * @param none
  * @retval none
  */
static void mx_crc_init(void)
{

  /* user code begin crc_init 0 */

  /* user code end crc_init 0 */

  /* user code begin crc_init 1 */

  /* user code end crc_init 1 */
  hcrc.instance = crc;
  hcrc.init.defaultpolynomialuse = default_polynomial_enable;
  hcrc.init.defaultinitvalueuse = default_init_value_enable;
  hcrc.init.inputdatainversionmode = crc_inputdata_inversion_none;
  hcrc.init.outputdatainversionmode = crc_outputdata_inversion_disable;
  hcrc.inputdataformat = crc_inputdata_format_bytes;
  if (hal_crc_init(&hcrc) != hal_ok)
  {
    error_handler();
  }
  /* user code begin crc_init 2 */

  /* user code end crc_init 2 */

}

/**
  * @brief usart1 initialization function
  * @param none
  * @retval none
  */
static void mx_usart1_uart_init(void)
{

  /* user code begin usart1_init 0 */

  /* user code end usart1_init 0 */

  /* user code begin usart1_init 1 */

  /* user code end usart1_init 1 */
  huart1.instance = usart1;
  huart1.init.baudrate = 115200;
  huart1.init.wordlength = uart_wordlength_8b;
  huart1.init.stopbits = uart_stopbits_1;
  huart1.init.parity = uart_parity_none;
  huart1.init.mode = uart_mode_tx_rx;
  huart1.init.hwflowctl = uart_hwcontrol_none;
  huart1.init.oversampling = uart_oversampling_16;
  huart1.init.onebitsampling = uart_one_bit_sample_disable;
  huart1.advancedinit.advfeatureinit = uart_advfeature_no_init;
  if (hal_uart_init(&huart1) != hal_ok)
  {
    error_handler();
  }
  /* user code begin usart1_init 2 */

  /* user code end usart1_init 2 */

}

/**
  * @brief gpio initialization function
  * @param none
  * @retval none
  */
static void mx_gpio_init(void)
{
/* user code begin mx_gpio_init_1 */
/* user code end mx_gpio_init_1 */

  /* gpio ports clock enable */
  __hal_rcc_gpioa_clk_enable();
  __hal_rcc_gpiob_clk_enable();
  __hal_rcc_gpioh_clk_enable();

/* user code begin mx_gpio_init_2 */
/* user code end mx_gpio_init_2 */
}

/* user code begin 4 */
void hal_uart_rxcpltcallback(uart_handletypedef *uarthandle)
{
	if(gorunning ==0)
	{
		if (uart_rx_length < uart_buff_len)
		{
			uart_rx_buffer[uart_rx_length] = uart_rx_byte;
			uart_rx_length++;

			if (uart_rx_byte == '\n')
			{
				gorunning = 1;
			}
		}
		else
		{
			//rt_kprintf("rx len over");
			uart_rx_length = 0;
		}
	}
	hal_uart_receive_it(&huart1, (uint8_t *)&uart_rx_byte, 1);
}

void uart_send(char * str)
{
	hal_uart_transmit(&huart1, (uint8_t *)str, strlen(str),0xffff);
}

static void ai_init(void)
{
  ai_error err;

  /* create a local array with the addresses of the activations buffers */
  const ai_handle act_addr[] = { activations };
  /* create an instance of the model */
  err = ai_network_create_and_init(&network, act_addr, null);
  if (err.type != ai_error_none) {
    printf("ai_network_create error - type=%d code=%d\r\n", err.type, err.code);
    error_handler();
  }
  ai_input = ai_network_inputs_get(network, null);
  ai_output = ai_network_outputs_get(network, null);
}

static void ai_run(float *pin, float *pout)
{
	char logstr[100];
	int count = 0;
	float max = 0;
  ai_i32 batch;
  ai_error err;

  /* update io handlers with the data payload */
  ai_input[0].data = ai_handle_ptr(pin);
  ai_output[0].data = ai_handle_ptr(pout);

  batch = ai_network_run(network, ai_input, ai_output);
  if (batch != 1) {
    err = ai_network_get_error(network);
    printf("ai ai_network_run error - type=%d code=%d\r\n", err.type, err.code);
    error_handler();
  }
  for (uint32_t i = 0; i < ai_network_out_1_size; i++) {

	  sprintf(logstr,"%d  %8.6f\r\n",i,aioutdata[i]);
	  uart_send(logstr);
	  if(max<aioutdata[i])
	  {
		  count = i;
		  max= aioutdata[i];
	  }
  }
  sprintf(logstr,"current number is %d\r\n",count);
  uart_send(logstr);
}

void picturechararraytofloat(uint8_t *srcbuf,float *dstbuf,int len)
{
	for(int i=0;i<len;i++)
	{
		dstbuf[i] = srcbuf[i];//==1?0:1;
	}
}


/* user code end 4 */

/**
  * @brief  this function is executed in case of error occurrence.
  * @retval none
  */
void error_handler(void)
{
  /* user code begin error_handler_debug */
  /* user can add his own implementation to report the hal error return state */
  __disable_irq();
  while (1)
  {
  }
  /* user code end error_handler_debug */
}

#ifdef  use_full_assert
/**
  * @brief  reports the name of the source file and the source line number
  *         where the assert_param error has occurred.
  * @param  file: pointer to the source file name
  * @param  line: assert_param error line source number
  * @retval none
  */
void assert_failed(uint8_t *file, uint32_t line)
{
  /* user code begin 6 */
  /* user can add his own implementation to report the file name and line number,
     ex: printf("wrong parameters value: file %s on line %d\r\n", file, line) */
  /* user code end 6 */
}
#endif /* use_full_assert */


(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com