Deep Learningを使った空気圧縮機の間欠動作検出

1.概要

Arduino Nano 33 BLE Senseに搭載されたデジタルマイクロフォンMP34DT05の出力データを使い、間欠動作する機器の動作異常音を検出するシステムの検討を行います。そのためには、間欠動作する機器が動作・停止のいずれの状態であるかを知る必要があります。
MP34DT05の出力である音データをFFT(高速フーリエ変換)し、その結果を機械学習(Deep Learning)させることで、動作中か停止中かを自動で認識させます。
MP34DT05の出力データをBLE(Bluetooth Low Energy)でPDH(Raspberry Pi)に転送し、PDHでFFTと機械学習結果の判定を行います。ArduinoのBLEで、1回で32bitのデータを送るのに約50ms必要です(これ以上速いと取りこぼしが発生します)。
そこで、以下の様な考え方でシステムを検討しました。
 1)MP34DT05のデータは16bit整数ですので、2回分を32bitにまとめて送ります。
 2)FFTの帯域はMP34DT05の最大性能を確保(8kHz=16kHz÷2)
 3)BLEでのデータ送信に時間がかかるため、データを連続で必要数計測したのちに送信しFFTを行う。
 4)FFTの間隔が転送時間(表1参照)に律速されることは許容する。
 5)転送時間の律速の影響をなりべく小さくするために、停止時には256サンプルの荒い周波数分解能でFFTを行い、起動をいち早く知る。そして、起動中は1024サンプルに変えて細かい周波数分解能で異常の検出を行えるようにする。

               表1.機器の状態とFFT条件

機器の状態

サンプル数

周波数分解能

FFT周期

停止時

256

62.5Hz(=16kHz÷256)

約6.4秒(=256×50ms÷2) *1

動作時

1024

15.625Hz(=16kHz÷1024)

約25.6秒(=1024×50ms÷2) *1

MP34DT05の出力データをPDHに転送しFFTを実施する内容に関しては、「第6巻 4-5.デジタルマイクMP34DT05データのPDHによるFFT解析」を参照ください。
また、FFT結果の機械学習(Deep Learning)による状態判定に関しては、「第9巻 ・音響データのFFT結果を使った機械学習(Deep Learning)」を参照ください。
ここでは、判定結果によるサンプル数の変更の仕組みに関して述べます。

2.システムブロック図

システムブロック図を次に示します。
本レポートでは、主にPDHの下側のライン(FFTの結果の判定からArduinoへのサンプル数の指示)を説明します。
1)Arduino Nano 33 BLE SenseとPDHの間は、BLEで双方向で通信します。
そのために、notify用とwrite用のサービスを準備しています。
2)PDHは、Pythonのプログラムが2本とNode-REDで構成しています。PDH内部は、MQTTで通信しています。トピックスは3種類あります。
3)python pro1は、Arduino Nanoとの通信を行うフロントエンド部です。
4)Node-RED部はpython pro1とpython pro2の起動停止制御を行います。そして、BLEで送られてきたデータの集約を行い、
python pro2に渡します。
5)python pro2は、音データのFFTを行います。そして、FFT結果を機械学習で決めtパラメータに基づいて判定部で状態を判定します。
判定結果が機器が起動状態であれば、1024サンプルのデータを要求します。停止状態であれば256サンプルを要求します。
6)これとは別に、機械学習(Deep Learning)用のpythonのプログラムをPDHで実行します。

3. 動作実験結果

上記のシステムで、空気圧縮機の動作を確認した結果を次に示します。
現在の状態を機械学習(Deep Learning)で習得したパラメータを使ってFFT結果をもとに分類し、状態の変化が起こった場合には、サンプル数を自動で切り替えてFFTの精度を上げる(1024サンプル)、もしくはFFTの周期を速くする(256サンプル)ことが可能になりました。


① 電源投入直後は、待機状態です。256サンプルのFFTでスタートします。
② 待機状態で空気が使用されると圧力が下がり、設定値を下回ると空気圧縮機が起動し圧縮状態になります。
↓ :256サンプルでのFFT結果をDeep Learningで分類し、圧縮状態を認識するとサンプル数を1024に切り替えます。
③ 1024サンプルの圧縮状態です。この状態の時に、圧縮用モータの異常検出を行います。
↓ :圧縮がしばらく続き、圧力が設定値に戻ると圧縮が終了します。
④ 1024サンプルの待機状態になりました。
↓:1024サンプルでのFFT結果をDeep Learningで分類し、待機状態を認識するとサンプル数を256に切り替えます。
① 256サンプルの待機状態に移ります。

4.Arduino Nano 33 BLE Senseのサンプル数制御

Arduino Nano 33 BLE SenseとPDH(Raspberry Pi)は、BLEで双方向の通信を行います。Nanoはペリフェラル機器、PDHはセントラル機器です。従って、PDHからNanoのデータをリードする際には、Notifyサービスを使用し、PDHからNanoへサンプル数を知らせる場合には、Writeサービスで行います。

ここでのデータのやり取りでポイントとなるのは、PDHから指定したサンプル数のデータが送られてくるのはいつからかということです。例えば256サンプルから1024サンプルに変更せよとPDHからNanoにBLEのWriteサービスで連絡したからといって、次に送られてくるのは、256サンプルの可能性もあれば1024サンプルの可能性もあります。なぜなら、Nanoが1024サンプルにせよとの指示を受けた時にすでに、256サンプルの採取が始まっているかもしれないからです。

対応として、デジタルマイクMP34DT05データのPDHによるFFT解析3.1.データ送信時の頭出しの工夫で説明したNanoがデータを送る際に、先頭であることを示すために“7FFFFFFFH”をデータの先頭に付加する方法を流用します。そして、以下の様に変更します。

サンプル数

先頭32bit

256

7FFFFFFFH(=32767)

1024

7FF0FFFFH(=32752)

これにより、PDHはNotifyで送られてくるデータに“7FFFFFFFH”か、”7FF0FFFFH”が無いかを探し、あれば、データの先頭だということが分かるとともに、このデータが256サンプルなのか1024サンプルなのかも区別できます。

5.Arduino Nano 33 BLE Senseの送受信フロー

Nanoのプログラムで、NotifyとWriteは以下の様なフローで動かしています。

① sample数読出でWriteサービスによるsample数を読みに行き、
  samples_numに格納します。
② sample数変更で、max_samplesにsamples_numを書込みます。
sample数変更とマイクデータ読出しは、cnt==0の時のみ実行されます。
これは読み取ったマイクデータを送信し終わったタイミングです。
(32bit単位で送るので、cnt≧max_samples÷2+1になった時1回分の送信終了)

Arduinoのプログラムは、Appendix.1に掲載しますので参考にしてください。
ポイントは、上記で説明した2点とWriteサービスのデータの読出しコマンドです。

読出しに関して補足します。
読出しは、readSamplesNum()という関数で行っています。

BLELongCharacteristic MP34DT05_PDM2(MP34DT05_PDM_Characteristic2_UUID, BLERead | BLEWrite);

LongCharacteristicとしてMP34DT05_PDM2を宣言。

Sensor_MP34DT05_Service.addCharacteristic(MP34DT05_PDM2);

サービスに、MP34DT05_PDM2を追加。

void readSamplesNum(){
  if (MP34DT05_PDM2.written()) {
    if (MP34DT05_PDM2.value()) {   // any value other than 0
      samples_num = 1024;
    } else {                              // a 0 value
      samples_num = 256;
    }
  }
}

MP34DT05_PDM2.written()で書き込まれたかを確認し、書き込まれていたら、MP34DT05_PDM2.value()で読み出します。
“1”なら1024サンプル、”0”なら256サンプルです。

6.PDHのBLE送信動作

次に、PDHのBLEのWriteサービスを使って、サンプル数を送信する動作を確認します。次のブロック図のPDH内のpython pro1で実現します。
python pro1は、BLEのNotifyサービスで送られてくるデータをJSON形式にしてtopic名”pdm_pic”でパブリッシュします。ブローカーは、PDH内のNode-REDに立ち上げていますので、アドレスはlocalhostです。それと並行して、python pr2からのトピック名“fft_condition”をサブスクライブします。そしてデータをサブスクライブしたら、BLEのWriteサービスでNanoに書込みます。

ポイントは、1) BLEのNotifyサービスとWriteサービスの双方向さーぶすへの対応と 2) MQTTのパブリッシングとサブスクライブの双方向動作の実現です。

6.1. BLEのNotifyサービスとWriteサービスの双方向サービスへの対応

まず、一つ目のBLEのNotifyサービスとWriteサービスの双方向サービスへの対応から見ていきます。
BLEは、PythonのBleakライブラリを使用します。そのために、asyncioライブラリも使用します。
asyncioを使って、BLEのNotify動作(client.start_notify())とサンプル数の値が変化したらWriteする動作(client.write_gatt_char())の2つの無限ループを動かしているイメージです。
フロー図を次に示します。

左側が、notifyサービスのループで、右側がWriteサービスのループです。

1)Notify サービス
Notifyは、以下の関数呼び出しでchar_uuidにnotifyがあると、norification_handler()が呼び出されます。

 await client.start_notify(char_uuid, notification_handler)

notification_handlerの中に、実行したいコマンドを記述します。
ここでは、jsonフォーマットに受信データをはめ込み、パブリッシュします。

2) Writeサービス

Writeサービスは、以下の関数の呼び出しでchar2_uuidにsampleNoが書き込まれます。sampleNoはbytearrayです。

client.write_gatt_char(char2_uuid, sampleNo)

動作としてはsetSampleNo()関数で、subscribe結果をsampleNoに渡します。その値が前回と異なるとWriteサービスを呼び出します。同じ場合は、何もしません。

6.2. MQTTのパブリッシングとサブスクライブの双方向動作の実現

もう一つのポイントのMQTTのパブリッシングとサブスクライブの双方向動作の実現について説明します。MQTTには、pythonライブラリpaho.mqttを使用します。

1) mqttブローカー接続

publish、subscribeのいずれかもしくは両方を実装するためには、まずブローカーに接続する必要があります。

import paho.mqtt.client as mqtt
mqtt_client = mqtt.Client()
mqtt_client.connect(broker_address, port)

1行目:paho.mqtt.clientをmqttとして呼び込みます。
2行目:mqtt_clientとしてmqtt.Client()のインスタンスを作成します。
この一つのインスタンスでpublish、scribeの両方の動作を行うことができます。
3行目:mqtt_clientのconnectメソッドを使って、ブローカーに接続します。ブローカーのアドレスと、ポート番号を指定します。
ここでは、それぞれlocalhostと1883を使っています。

2) パブリッシュ動作

パブリッシュ動作は、mqtt_clientのpublishメソッドを使って送信します。

mqtt_client.publish(topic, json_data)

トピック名とjsonデータを指定します。jsonデータは、pythonのリストをjson形式に変換してくれるjsonライブラリのdumpsというメソッドがあります。

json_data = json.dumps(sdata)

3) subscribe動作

サブスクライブ動作は、受け身の動作になるので、以下のメソッドを使用します。

mqtt_client.on_connect = on_connect
mqtt_client.on_message = on_message
mqtt_client.loop_start()

1行目:サブスクライブのコネクト時に実行する関数を指定します。今回のソースでは、resultコードの表示とトピック名の設定をしています。

def on_connect(client, userdata, flags, rc):
    print("Connected with result code "+str(rc))
    client.subscribe(topic2)

2行目:サブスクライブのメッセージが来た際に実行する関数を指定します。今回のソースでは、recv_sampleにmsg.payloadをpythonオブジェクト化したデータを代入します。

def on_message(client, userdata, msg):
    global recv_sample
    recv_sample = json.loads(msg.payload)

3行目:ループ動作を開始します。

7.Node-REDでのFFT前処理

R-CPS-HPの第6巻 4-5. デジタルマイクMP34DT05データのPDHによるFFT解析4.FFT前の準備(Node-RED)で説明していますが、Node-REDにFFTの前処理を行わせています。前述のシステムブロック図から関連するところ抜き出すと以下の様になります。
 4.FFT前の準備(Node-RED)で述べた内容から変わった部分は、python pro1から送られてくるデータのサンプル数が1024個の場合と256個の場合があるという点のみです。送られてくるデータを監視し、”7FFFFFFF”か“7FF0FFFF”でサンプル数が256個か1024個かを判断し、その数での集約を行い、FFTを行う後段のpython pro2にMQTTで送信します。

Node-REDの該当箇所のフローを示します。ここでは、変更になっている赤枠で囲んだfunctionノードのみ説明します。詳細は、こちら( 4.FFT前の準備(Node-RED))を参照ください。また、Appendix3.にフローのjsonデータを掲示しておりますのでそちらを参照ください。

以下に、functionノードの中のプログラムのフローを示します。実際には、変数のflow変数としていますが、ここでは煩雑になるのを避けるために、flowの記述を外しています。変更になっている箇所は、初期化部分のcount_max値が256になっているところと、赤枠で囲ったcount_maxの設定の部分です。

functionノードでは、javascriptのプログラムで以下の処理を行っています。
1) 先頭データ(0x7FFFFFFF or 0x7FF0FFFF)の検出。
2) 検出した先頭データが0x7FFFFFFF or 0x7FF0FFFFなら、1024個の配列array[1024]を0に初期化する
3) さらに先頭データが0x7FFFFFFFならcount_max=256個、0x7FF0FFFFならcount_max=1024個とする。
4) 先頭データの次のデータからcount_max個のデータを集める。
5) msg.payload.dataとして次段にデータを送る。

8. FFT計算処理と状態判定プログラム(python)

システムブロック図で一番右側にあるpythonのプログラム(python pro2)は、送られてきたデータのFFT処理とその処理結果を機械学習結果のパラメータを用いて分類(判定)します。分類(判定)した結果が、圧縮状態であり、かつ現状256サンプルであれば、送ってくるサンプル数を1024サンプルに変えるようにMQTTにパブリッシングします。また、分類(判定)した結果が、待機状態であり、かつ現状1024サンプルであれば、送ってくるサンプル数を256サンプルに変えるようにMQTTにパブリッシングします。

FFTのプログラムに関しては、第6巻 4-5. デジタルマイクMP34DT05データのPDHによるFFT解析で1024サンプルの場合の説明を行っています。また、機械学習(Deep Learning)に関しては、第9巻音響データのFFT結果を使った機械学習(Deep Learning)でこちらも1024サンプルの場合に関して説明しています。基本的にいずれも1024サンプルと256サンプルで大きな違いはありません(詳細はAppendixのプログラムを参照ください)。ここでは、FFTに関しては、256サンプルと1024サンプルの切換に関してを中心に説明します。そして、分類(判定)に関しては、機械学習(Deep Learning)結果のパラメータを使って判定する方法に関して説明します。

8.1.プログラムの概略フロー

プログラムの概略フローを以下に示します。
左側はmain関数の中身です。初めに、機械学習(Deep Learning)で取得したパラメータを、サンプル数N=256の場合とN=1024の場合を共に読み込みます。その後、MQTTブローカーに接続します。subscribeデータ入手時に実行する関数の設定を行い、loop_foreverメソッドでサブスクライブ待ちに入ります。
右側はサブスクライブデータ入手時に実行する関数の中身です。
msg.payloadをpythonのオブジェクトに変換後、count値を読み出しサンプル数Nを入手します。その後、送られてきたdataに対してFFTを実行し、結果のPlotを作成します。そして、そのFFT結果を機械学習結果のパラメータを使って分類(判定)します。分類結果と現在のサンプル数からサンプル数の変更が必要ならパブリッシュします。

8.2.FFT処理プログラム

FFTのプログラムに関しては、サンプル数Nの値が256と1024の2つもちうるうるだけで第6巻 4-5. デジタルマイクMP34DT05データのPDHによるFFT解析5. FFT計算処理プログラム(python)から変わっていません。

MQTTのトピック“fft_data”をサブスクライブしています。データが送られてくると、関数on_massage()が実行されます。

dict_data = json.loads(msg.payload)

  サブスクライブされたmsg.payloadをjsonのオブジェクトに変更します。

N = dict_data["count"] + 2

  送られてきたデータ数に2を足してサンプリング数Nにします。256時にはcount=254が、1024時にはcount=1022となっています。ここでサンプリング数が決まります。

window = np.hanning(N)

  窓関数にハニング窓を使います。

data = dict_data["data"]
y = np.array(data[:N])

  dataをnumpy配列に変換します。

y = y * window

  窓関数を掛け合わせます。

以降FFTと結果のプロットに関しては、第6巻 4-5. デジタルマイクMP34DT05データのPDHによるFFT解析と同じです。

8.3.機械学習(Deep Learning)結果を使った分類(判定)プログラム

第9巻 音響データのFFT結果を使った機械学習(Deep Learning)で、ミニバッチ学習を行い、その結果のパラメータをライブラリpickleを使ってファイルに保存しました。
ここでは、この保存したパラメータを読み込み、分類(判定)を行います。
そのために、第9巻 音響データのFFT結果を使った機械学習(Deep Learning)では、参考文献(「ゼロから作るDeep Learning」 斎藤 康毅 著 オライリー・ジャパン)のプログラムの“two_layer_net.py”をそのまま使いましたが、ここでは、データを読み込むためにメソッドを追加します。

追加するメソッドは以下です。

def setlayer(self):
        self.layers['Affine1'] = Affine(self.params['W1'], self.params['b1'])
        self.layers['Relu1'] = Relu()
        self.layers['Affine2'] = Affine(self.params['W2'], self.params['b2’])

パラメータをレイヤに設定するためのメソッドです。

話を、pytnon pro2に戻します。
main関数の冒頭で、N=1024とN=256の学習結果のパラメータファイルを読み込み、setlayer()を実行します。以下、

# Deep Learning Calc
network256  = TwoLayerNet(input_size=127, hidden_size=100, output_size=5)
network1024 = TwoLayerNet(input_size=511, hidden_size=100, output_size=5)

  TwoLayerNetを256用と1024用に作成します。inputサイズ= サンプル数N÷2 – 1となっています。DCデータが除去されているためです。

# Deep Learning用paramter読込み
with open('params_1024.pkl', 'rb') as f1:
    network1024.params = pickle.load(f1)
network1024.setlayer()

with open('params_256.pkl', 'rb') as f2:
    network256.params = pickle.load(f2)
network256.setlayer()

  ファイル名‘params_1024(256).pkl‘のライブラリpickleで保存したファイルを開けて、パラメータに読み込みます。  そして、setlayer()メソッドを実行します。

そして、判定以下のon_message()関数の中のFFT直後で行っています。

# Deep Learningによる判定
    if N == 256:
        y = network256.predict(np.tile(Amp[1:int(N/2)], (2,1)))
    else:
        y = network1024.predict(np.tile(Amp[1:int(N/2)], (2,1)))
    y = np.argmax(y, axis=1)
    tp = int(y[0])

3行目:N=256の場合、5行目:N=1024の場合のpredictメソッドを呼び出して、各状態である確率を計算させます。
6行目でその中で一番大きい値を選びます。
7行目でその状態の番号をtpに代入します。値は0~4です。

tpの値をベースに現在のサンプル数Nの値から次のサンプル数nextを決めます。next != Nであれば、nextの値をtopic”fft_condition”にパブリッシュします。

 # Sampling point数の変更
    if (tp == 2 or tp == 4):
        next = 1024
    else:
        next = 256

    if next != N:
        publish(str(next))

2行から5行目でnextを設定し、7行目から8行目でpublishを行っています。

Appendix.1 MP34DT05データBLE転送プログラム

Arduino Nano 33 BLE Senseに搭載されているMP34DT05の音声データをBLEで転送するプログラムを下記に載せます。
参考にしてください。2つのタブから構成されています。
メインタブ:peri-ble-pdm03.ino
サブタブ :update_MP34DT05.ino

1) peri-ble-pdm03.ino

/*
 * prei-ble-pdm03: 間欠動作対応 サンプル数256とサンプル数1024の切換
 * Arduino LSM9DS1 : send sensordata on BLE
 * Arduino MP34DT05 : Digital Microphone
 */
 
#include 
#include 

#define localNAME "peri_pdm"
#define DeviceNAME "PDM"

// UUID for MP34DT05(Digital MicroPhone)
#define MP34DT05_SERVICE_UUID "9997c274-9b20-4181-808f-ef2de31032a8"
#define MP34DT05_PDM_Characteristic_UUID "9997c274-9b21-4181-808f-ef2de31032a8"
#define MP34DT05_PDM_Characteristic2_UUID "9997c274-9b22-4181-808f-ef2de31032a8"

// send data size
#define MAX_SAMPLES 1024
int max_samples = 256;
int samples_num = 256; 
short sendData[MAX_SAMPLES];

//Variables for MP34DT05
short PDM_MP34DT05 = 0;
short sampleBuffer[256];
short sampleData;
short sampleDataBuffer[256];
unsigned long theTime;
volatile int samplesRead;
int cnt = 0;

// BLE Service
BLEService Sensor_MP34DT05_Service(MP34DT05_SERVICE_UUID);

// BLE Characteristic
BLELongCharacteristic MP34DT05_PDM(MP34DT05_PDM_Characteristic_UUID, BLERead | BLENotify);
BLELongCharacteristic MP34DT05_PDM2(MP34DT05_PDM_Characteristic2_UUID, BLERead | BLEWrite);

void setup() {
  Serial.begin(115200);
  //while (!Serial);
  Serial.println("Started");

  // initialize the built-in LED pin to indicate when a central is connected
  pinMode(LED_BUILTIN, OUTPUT); 
  
  // MP34DT05 begin initialization
  PDM.onReceive(onPDMdata);

  // optionally set the gain, defaults to 20
  // PDM.setGain(30);

  // initialize PDM with: one channel (mono mode) & 16 kHz sample rate
  if (!PDM.begin(1, 16000)) {
    Serial.println("Failed to start PDM!");
    while (1);
  }

  // BLE begin initialization
  if (!BLE.begin()) {
    Serial.println("starting BLE failed!");
    while (1);
  }
  BLE.setLocalName(localNAME);
  BLE.setDeviceName(DeviceNAME);

  // Initialize MP34DT05 service
  set_MP34DT05();

  // start advertising
  BLE.advertise();
  Serial.println("Bluetooth device active, waiting for connections...");
}

void loop() {
  // wait for a BLE central
  BLEDevice central = BLE.central();
 
  // if a central is connected to the peripheral:
  if (central) {
    Serial.print("Connected to central: ");
    // print the central's BT address:
    Serial.println(central.address());
    // turn on the LED to indicate the connection:
    digitalWrite(LED_BUILTIN, HIGH);
 
    // while the central is connected:
    while (central.connected()) {
      readSamplesNum();
      if (cnt == 0){
        max_samples = samples_num;
        Read_PDM_Data();   // reads 1024 data from MP34DT05
        //IncDataGen();      // Incremental data generation for sending test
      }
      RGB_LED();
      notify_PDM();       // SEND_BLE
      counter();
      delay(50);
    }

    // when the central disconnects
    digitalWrite(LED_BUILTIN, LOW);
    Serial.print("Disconnected from central: ");
    Serial.println(central.address());
  } else {
      readMP34DT05_nonBLE();    
  }
}

2) update_MP34DT05.ino

#include 
#include 

void onPDMdata() {
  // query the number of bytes available
  int bytesAvailable = PDM.available();

  // read into the sample buffer
  PDM.read(sampleBuffer, bytesAvailable);

  // 16-bit, 2 bytes per sample
  samplesRead = bytesAvailable / 2;
}

void set_MP34DT05(){
  // add the service UUID
  BLE.setAdvertisedService(Sensor_MP34DT05_Service);

  // add characteristic
  Sensor_MP34DT05_Service.addCharacteristic(MP34DT05_PDM);
  Sensor_MP34DT05_Service.addCharacteristic(MP34DT05_PDM2);

  // Add service
  BLE.addService(Sensor_MP34DT05_Service);
}

// turn on RGB LED according max_samples
void RGB_LED(){
  if (max_samples == 1024){
    digitalWrite(LEDR,LOW);
    digitalWrite(LEDG,HIGH);
    digitalWrite(LEDB,HIGH);
  } else if (max_samples == 256){
    digitalWrite(LEDR,HIGH);
    digitalWrite(LEDG,HIGH);
    digitalWrite(LEDB,LOW);
  }      
}

// read pdm data and set to the dimensions
void Read_PDM_Data(){
  // Wait for samples to be read
  for (int j = 0; j<(max_samples/256); j++){
    while (samplesRead!=256);
    for (int i = 0; i < samplesRead; i++) {
      sendData[j*256+i] = sampleBuffer[i];
    }
    // Clear the read count
    samplesRead = 0;
  }
}

// Increment data generator for sending test
void IncDataGen(){
  for (int i=0; i<max_samples; i++){
    sendData[i] = i;
  }
}

// counter for send data control
void counter(){
  cnt++;
  if (cnt >= (max_samples/2+1)) cnt = 0;
}

// read 
void notify_PDM(){
  int i = 2*(cnt - 1);
  long sending;
  short sendingL, sendingH;
  if (cnt == 0){
    if (max_samples == 256){
      sending=0x7FFFFFFF;  // 32767      
    } else {
      sending=0x7FF0FFFF;  // 32752
    }
  } else {
    sending=((long)sendData[i+1] << 16) | (sendData[i] & 0xFFFF);
  }
  MP34DT05_PDM.writeValue(sending);
  sendingH = (short)((sending >> 16) & 0xFFFF);
  sendingL = (short)(sending & 0xFFFF);
  Serial.print(cnt);
  Serial.print(": "); 
  Serial.print(sendingH);
  Serial.print(", "); 
  Serial.println(sendingL);
}

// read samples number indicated by PDM
void readSamplesNum(){
  if (MP34DT05_PDM2.written()) {
    if (MP34DT05_PDM2.value()) {   // any value other than 0
      samples_num = 1024;
    } else {                              // a 0 value
      samples_num = 256;
    }
  }
}

void readMP34DT05_nonBLE() {
  // print samples to the serial monitor or plotter
  for (int i = 0; i < samplesRead; i++) {
    sampleData = sampleBuffer[i];
    LED_Control();
  }
}

void LED_Control(){
  if (abs(sampleData)>=0 && abs(sampleData) < 50){
    digitalWrite(LEDR,HIGH);
    digitalWrite(LEDG,HIGH);
    digitalWrite(LEDB,HIGH);
  } else if (abs(sampleData) < 100){
    digitalWrite(LEDR,HIGH);
    digitalWrite(LEDG,HIGH);
    digitalWrite(LEDB,LOW);        
  } else if (abs(sampleData) < 150){
    digitalWrite(LEDR,HIGH);
    digitalWrite(LEDG,LOW);
    digitalWrite(LEDB,HIGH);        
  } else if (abs(sampleData) < 200){
    digitalWrite(LEDR,HIGH);
    digitalWrite(LEDG,LOW);
    digitalWrite(LEDB,LOW);        
  } else if (abs(sampleData) < 250){
    digitalWrite(LEDR,LOW);
    digitalWrite(LEDG,HIGH);
    digitalWrite(LEDB,HIGH);        
  } else if (abs(sampleData) < 300){
    digitalWrite(LEDR,LOW);
    digitalWrite(LEDG,HIGH);
    digitalWrite(LEDB,LOW);        
  } else if (abs(sampleData) < 350){
    digitalWrite(LEDR,LOW);
    digitalWrite(LEDG,LOW);
    digitalWrite(LEDB,HIGH);
  } else {
    digitalWrite(LEDR,LOW);
    digitalWrite(LEDG,LOW);
    digitalWrite(LEDB,LOW);
  }
}

Appendix.2 BLEでの通信用Pythonプログラム(Python Pro1)

ble_pdh06.pyという名称で保存して下さい。Node-REDからの起動で使用します。

# -*- coding: utf-8 -*-
#####
# Successively receive micro-phone data of int type
# Rev.0.03: 2023/11/16 無限loopにする。
# Rev.0.04: 2023/11/20 short X 2 = Longの転送に対応
# Rev.0.05: 2023/11/25 send sampling number
# Rev.0.06: 2023/12/20 FFT結果の受信とSampling数の変換指示追加
####
import sys
import signal
import asyncio
from bleak import BleakClient

import paho.mqtt.client as mqtt
import time
import json
import struct

# setting for BLE
ADDRESS = (
#    "93:B3:5F:16:CD:A7" # Arduino Nano 33 BLE sense Flat
    "33:07:90:D9:BB:5D" # Arduino Nano 33 BLE sense MinoMushi
)

# UUID for MP34DT05
CHARACTERISTIC_UUID = "9997c274-9b21-4181-808f-ef2de31032a8" # read & notify
CHARACTERISTIC2_UUID = "9997c274-9b22-4181-808f-ef2de31032a8" # read & write

# MQTT ブローカーの設定
broker_address = "localhost"
port = 1883
topic = "pdm_pic"
topic2 = "fft_condition"

# MQTT クライアントを作成
mqtt_client = mqtt.Client()

data = [0.0,0.0,0.0]
sdata = {    # 辞書型データ
    'sensor': 'MP34DT05',
    'func': 'none',
    'time': 'time',
    'data': data
}

def publish(sdata):
    # Publish on MQTT
    json_data = json.dumps(sdata)
    mqtt_client.publish(topic, json_data)

def bytes_to_long(data):
    Data = [0]*2
    for i in range(0, len(data), 4):
        # 4バイトずつ読み込み、それをlong型に変換
        long_data = struct.unpack('I', data[i:i+4])[0]
    # 上位と下位の16bitに分ける(符号付き)
    Data[1] = struct.unpack('h', struct.pack('H', (long_data >> 16) & 0xFFFF))[0]
    Data[0] = struct.unpack('h', struct.pack('H', long_data & 0xFFFF))[0]
    return Data

def bytes_to_int(data):
    return int.from_bytes(data, byteorder='little')

def bytes_to_double(data):
    return struct.unpack('d', data)[0]

def bytes_to_signed_int(data):
    return int.from_bytes(data, byteorder='little', signed=True)

def notification_handler(sender, data: bytearray):
    """Simple notification handler which sends MP34DT05 PDM data"""
    data1 = bytes_to_long(data)
    sdata['sensor'] = 'MP34DT05'
    sdata['func'] = 'PDM'
    sdata['time'] = time.time()
    sdata['data'] = data1
    #print(data1)
    publish(sdata)

sampleNo = bytearray([0])
sampleNoOld = bytearray([0])
recv_sample = 256

async def setSampleNo():
    global sampleNo
    if (recv_sample == 256):
        sampleNo = bytearray([0])
    elif (recv_sample == 1024):
        sampleNo = bytearray([1])
    else:
        sampleNo = bytearray([0])
    #print(f'{cycleNo}, {sampleNo}, {recv_sample}')

def on_connect(client, userdata, flags, rc):
    print("Connected with result code "+str(rc))
    client.subscribe(topic2)

def on_message(client, userdata, msg):
    global recv_sample
    recv_sample = json.loads(msg.payload)
    print(f'{msg.payload},{recv_sample}')
    
async def main(address):
    global sampleNoOld
    char_uuid = CHARACTERISTIC_UUID
    char2_uuid = CHARACTERISTIC2_UUID

    # MQTT ブローカーに接続
    mqtt_client.connect(broker_address, port)
    print(f"MQTT connected: {broker_address}")
    mqtt_client.on_connect = on_connect
    mqtt_client.on_message = on_message
    mqtt_client.loop_start()
    
    print(address, char_uuid)
    async with BleakClient(address) as client:
        print(f"Connected: {client.is_connected}")
        await client.start_notify(char_uuid, notification_handler)
        try:
            while True:
                await setSampleNo()
                if sampleNo != sampleNoOld:
                    await client.write_gatt_char(char2_uuid, sampleNo)
                    print(f'sampleNo={sampleNo}')
                    sampleNoOld = sampleNo
                await asyncio.sleep(1.0)
        except KeyboardInterrupt:
            print("Stop!! Key board Interrupt!!");
            #await client.stop_notify(char_uuid)
        finally:
            await client.stop_notify(char_uuid)

    print("Disconnected")
    # MQTT クライアントを切断
    mqtt_client.disconnect()
    print(f"MQTT disconnected: {broker_address}")
        
if __name__ == "__main__":
    asyncio.run(
        main(
            sys.argv[1] if len(sys.argv) > 1 else ADDRESS,
        )
    )

Appendix.3 Node-REDのフローファイル

本文で紹介しているNode-REDのフローファイルです。
Node.js v18.16.0, Node-RED v3.0.2で動作しています。
このフロー中には、MQTTのブローカーノードが入っていないので、同じPDH(Raspberry Pi)のフローの何処かにブローカーを配置して下さい。別の端末のブローカーを使う場合には、mqtt関係のIPアドレスを変更する必要があります。

ディレクトリ関係は環境に合わせて修正する必要があります。

FFT用のプログラムに関しては、Node-REDから実行するとNGになる場合があります。その場合には、端末画面から起動してください。

[{"id":"287721ebd0af0b63","type":"tab","label":"PDM_PIC2","disabled":false,"info":"","env":[]},{"id":"00823d1a50950b0a","type":"debug","z":"287721ebd0af0b63","name":"debug 81","active":true,"tosidebar":true,"console":false,"tostatus":false,"complete":"false","statusVal":"","statusType":"auto","x":420,"y":80,"wires":[]},{"id":"abe28cfc0dda4281","type":"exec","z":"287721ebd0af0b63","command":"python3 /home/pi/source/python/ble/ble_pdh06.py","addpay":"payload","append":"","useSpawn":"false","timer":"","winHide":false,"oldrc":false,"name":"ble_pdh06","x":270,"y":80,"wires":[[],["00823d1a50950b0a"],["00823d1a50950b0a"]]},{"id":"ce137ed9b7a8867b","type":"inject","z":"287721ebd0af0b63","name":"受信開始","props":[{"p":"filename","v":"/home/pi/Documents/pdm_imu/receive_","vt":"str"},{"p":"payload"}],"repeat":"","crontab":"","once":false,"onceDelay":0.1,"topic":"","payload":"33:07:90:D9:BB:5D","payloadType":"str","x":100,"y":100,"wires":[["885ca16d7a763aed","abe28cfc0dda4281","382c3786e2e8efe7","52491d06734e366b"]]},{"id":"0cd751654fc88190","type":"inject","z":"287721ebd0af0b63","name":"受信終了","props":[{"p":"kill","v":"","vt":"str"},{"p":"topic","vt":"str"}],"repeat":"","crontab":"","once":false,"onceDelay":0.1,"topic":"","x":100,"y":140,"wires":[["abe28cfc0dda4281"]]},{"id":"885ca16d7a763aed","type":"moment","z":"287721ebd0af0b63","name":"日時","topic":"","input":"","inputType":"date","inTz":"Asia/Tokyo","adjAmount":0,"adjType":"days","adjDir":"add","format":"YYYYMMDD_HHmmss","locale":"ja-JP","output":"datetime","outputType":"msg","outTz":"Asia/Tokyo","x":250,"y":140,"wires":[["41ec1a3b54d21e73"]]},{"id":"382c3786e2e8efe7","type":"function","z":"287721ebd0af0b63","name":"初期化","func":"flow.set('count', 0);\nflow.set('cycle', 0);\nflow.set('count_max', 256);","outputs":1,"noerr":0,"initialize":"","finalize":"","libs":[],"x":250,"y":180,"wires":[[]]},{"id":"52491d06734e366b","type":"function","z":"287721ebd0af0b63","name":"配列作成","func":"// 配列の要素数\nflow.set(\"length\", 1024);\n// 配列作成\n//if (!context.flow.array) {\ncontext.flow.array = new Array(flow.get(\"length\"));\n//}\n// 配列の初期化:データを0で埋める\ncontext.flow.array.fill(0);","outputs":1,"noerr":0,"initialize":"","finalize":"","libs":[],"x":260,"y":220,"wires":[[]]},{"id":"41ec1a3b54d21e73","type":"change","z":"287721ebd0af0b63","name":"filename","rules":[{"t":"set","p":"filename","pt":"flow","to":"filename&datetime&ext","tot":"jsonata"}],"action":"","property":"","from":"","to":"","reg":false,"x":380,"y":140,"wires":[[]]},{"id":"d802477288e1c9d3","type":"inject","z":"287721ebd0af0b63","name":"初期化","props":[{"p":"payload"},{"p":"topic","vt":"str"}],"repeat":"","crontab":"","once":true,"onceDelay":0.1,"topic":"","payload":"","payloadType":"date","x":100,"y":200,"wires":[["382c3786e2e8efe7","52491d06734e366b"]]},{"id":"4550065dbe0175df","type":"inject","z":"287721ebd0af0b63","name":"FFT終了","props":[{"p":"kill","v":"","vt":"str"},{"p":"topic","vt":"str"}],"repeat":"","crontab":"","once":false,"onceDelay":0.1,"topic":"","x":660,"y":120,"wires":[["cc57b01a40dff466"]]},{"id":"cc57b01a40dff466","type":"exec","z":"287721ebd0af0b63","command":"python3 ","addpay":"payload","append":"","useSpawn":"false","timer":"","winHide":false,"oldrc":false,"name":"fft_calc04","x":800,"y":80,"wires":[[],["eea006e5a7f7e122"],["eea006e5a7f7e122"]]},{"id":"1edbfdb4622b3e10","type":"inject","z":"287721ebd0af0b63","name":"FFT開始","props":[{"p":"payload"},{"p":"topic","vt":"str"}],"repeat":"","crontab":"","once":false,"onceDelay":0.1,"topic":"","payload":"/home/pi/source/python/deep_learning/fft/fft_calc04.py","payloadType":"str","x":660,"y":80,"wires":[["cc57b01a40dff466"]]},{"id":"eea006e5a7f7e122","type":"debug","z":"287721ebd0af0b63","name":"debug 82","active":true,"tosidebar":true,"console":false,"tostatus":false,"complete":"payload","targetType":"msg","statusVal":"","statusType":"auto","x":940,"y":80,"wires":[]},{"id":"77cb07192336961c","type":"moment","z":"287721ebd0af0b63","name":"日時","topic":"","input":"","inputType":"date","inTz":"Asia/Tokyo","adjAmount":0,"adjType":"days","adjDir":"add","format":"YYYY-MM-DD_HH_mm_ss","locale":"ja-JP","output":"payload.datetime","outputType":"msg","outTz":"Asia/Tokyo","x":350,"y":360,"wires":[["00b398a331e6550e","27a33a153566adeb"]]},{"id":"4195344a881735ae","type":"change","z":"287721ebd0af0b63","name":"削除","rules":[{"t":"delete","p":"payload.sensor","pt":"msg"},{"t":"delete","p":"payload.func","pt":"msg"},{"t":"delete","p":"payload.time","pt":"msg"}],"action":"","property":"","from":"","to":"","reg":false,"x":230,"y":360,"wires":[["77cb07192336961c"]]},{"id":"00b398a331e6550e","type":"function","z":"287721ebd0af0b63","name":"集約(ble_pdh04)","func":"var cycle = flow.get('cycle');\nvar count = flow.get('count');\nvar count_max = flow.get('count_max');\nif (msg.payload.data[1] >= (2 ** 15 - 16) || count > 1024){\n count = 0;\n context.flow.array.fill(0);\n if (msg.payload.data[1] >= 2**15-1){\n count_max = 256;\n } else if (msg.payload.data[1] >= 2**15-16){\n count_max = 1024;\n }\n node.warn(\"first data detexted(\"+String(msg.payload.data[1]+\")!!\"));\n flow.set('count_max',count_max);\n} else {\n context.flow.array[count] = msg.payload.data[0];\n context.flow.array[count+1] = msg.payload.data[1];\n msg.payload.count = count;\n msg.payload.cycle = cycle;\n msg.payload.full = false;\n count = count+2;\n}\nflow.set('count', count);\nif (count == count_max){\n msg.payload.data = context.flow.array;\n msg.payload.full = true;\n cycle++;\n flow.set('cycle', cycle);\n return msg;\n}\n","outputs":1,"noerr":0,"initialize":"","finalize":"","libs":[],"x":510,"y":360,"wires":[["0fc3768d2ca70b8c","6c7733f802402d1e"]]},{"id":"27a33a153566adeb","type":"debug","z":"287721ebd0af0b63","name":"debug 83","active":false,"tosidebar":true,"console":false,"tostatus":false,"complete":"false","statusVal":"","statusType":"auto","x":480,"y":400,"wires":[]},{"id":"947bc2a4a89f7186","type":"mqtt in","z":"287721ebd0af0b63","name":"","topic":"pdm_pic","qos":"2","datatype":"auto-detect","broker":"496bc63a53e5165b","nl":false,"rap":true,"rh":0,"inputs":0,"x":100,"y":360,"wires":[["4195344a881735ae"]]},{"id":"0fc3768d2ca70b8c","type":"change","z":"287721ebd0af0b63","name":"filename","rules":[{"t":"set","p":"filename","pt":"msg","to":"filename","tot":"flow"},{"t":"set","p":"filename","pt":"msg","to":"filename&\".csv\"","tot":"jsonata"}],"action":"","property":"","from":"","to":"","reg":false,"x":680,"y":360,"wires":[["46f796c3c04083b5"]]},{"id":"6c7733f802402d1e","type":"mqtt out","z":"287721ebd0af0b63","name":"","topic":"fft_data","qos":"","retain":"","respTopic":"","contentType":"","userProps":"","correl":"","expiry":"","broker":"496bc63a53e5165b","x":680,"y":400,"wires":[]},{"id":"46f796c3c04083b5","type":"csv","z":"287721ebd0af0b63","name":"","sep":",","hdrin":"","hdrout":"none","multi":"one","ret":"\\n","temp":"cycle,count,datetime,data","skip":"0","strings":true,"include_empty_strings":"","include_null_values":"","x":810,"y":360,"wires":[["50e9a5dbf4aab002"]]},{"id":"50e9a5dbf4aab002","type":"file","z":"287721ebd0af0b63","name":"","filename":"filename","filenameType":"msg","appendNewline":false,"createDir":true,"overwriteFile":"false","encoding":"none","x":940,"y":360,"wires":[[]]},{"id":"cb78e1cd8f1677a3","type":"comment","z":"287721ebd0af0b63","name":"送信データ集約","info":"","x":120,"y":320,"wires":[]},{"id":"3905c9ed879c6422","type":"mqtt in","z":"287721ebd0af0b63","name":"","topic":"fft_data","qos":"2","datatype":"auto-detect","broker":"496bc63a53e5165b","nl":false,"rap":true,"rh":0,"inputs":0,"x":90,"y":440,"wires":[["9cf1ca24ac162b8d"]]},{"id":"9cf1ca24ac162b8d","type":"debug","z":"287721ebd0af0b63","name":"debug 84","active":true,"tosidebar":true,"console":false,"tostatus":false,"complete":"false","statusVal":"","statusType":"auto","x":220,"y":440,"wires":[]},{"id":"3cc3a9ad104246b2","type":"watch","z":"287721ebd0af0b63","name":"png監視","files":"/home/pi/Documents/pdm_imu/pictures","recursive":"","x":100,"y":540,"wires":[["43975452a45c4645"]]},{"id":"43975452a45c4645","type":"function","z":"287721ebd0af0b63","name":"png file","func":"if (context.get(\"fileName\") === undefined) {\n context.set(\"fileName\", \"newfile.png\")\n}\nif (context.fileName == msg.payload){\n context.fileName = msg.payload;\n return msg;\n} else {\n context.fileName = msg.payload;\n}\n","outputs":1,"noerr":0,"initialize":"","finalize":"","libs":[],"x":240,"y":540,"wires":[["3e896ea2754d9266","85fce9f8600c3316"]]},{"id":"3e896ea2754d9266","type":"image viewer","z":"287721ebd0af0b63","name":"","width":160,"data":"payload","dataType":"msg","active":true,"x":370,"y":540,"wires":[[]]},{"id":"85fce9f8600c3316","type":"debug","z":"287721ebd0af0b63","name":"debug 85","active":false,"tosidebar":true,"console":false,"tostatus":false,"complete":"false","statusVal":"","statusType":"auto","x":380,"y":500,"wires":[]},{"id":"496bc63a53e5165b","type":"mqtt-broker","name":"","broker":"localhost","port":"1883","clientid":"","autoConnect":true,"usetls":false,"protocolVersion":"4","keepalive":"60","cleansession":true,"birthTopic":"","birthQos":"0","birthPayload":"","birthMsg":{},"closeTopic":"","closeQos":"0","closePayload":"","closeMsg":{},"willTopic":"","willQos":"0","willPayload":"","willMsg":{},"userProps":"","sessionExpiry":""}]

Appendix.4 FFT処理と機械学習結果を用いた分類プログラム(Python Pro2)

機械学習のプログラムは、本文でも記載していますように、参考文献(「ゼロから作るDeep Learning」 斎藤 康毅 著 オライリー・ジャパン)のプログラムをベースに使用しています。参考文献に記載のHPからダウンロードしてご使用ください。

ここでは、以下の2つのプログラムを掲載します。two_layer_net.pyは、参考文献のプログラムtw0_layer_net.pyにパラメータセット用のメソッドを追加しています。

1) fft_calc04.py

'''
fft_calc04.py : execute fft
  rev0.3: change sampling number according msg.payload.count
  rev0.4: fftの結果をDeep Learningで判断できるように修正
'''
import paho.mqtt.client as mqtt
import numpy as np
import matplotlib.pyplot as plt
import json
import time

# Deep Learning用
import numpy as np
import pickle
from two_layer_net import TwoLayerNet

# MQTT ブローカーの設定
broker_address = "localhost"
port = 1883
topic = "fft_data"
topic2 = "fft_condition"

# FFT Configulation
fs = 16000       # sampling frequency
N  = 1024       # number of samples
dt = 1/fs       # sampling period

# Window function
func_window = 'hanning'

# MQTT クライアントを作成
mqtt_client = mqtt.Client()

# Deep Learning Calc
network256  = TwoLayerNet(input_size=127, hidden_size=100, output_size=5)
network1024 = TwoLayerNet(input_size=511, hidden_size=100, output_size=5)

sdata = {    # 辞書型データ
    'current_SMP': 1024, 
    'next_SMP' : 1024
}

def publish(sdata):
    # Publish on MQTT
    #json_data = json.dumps(sdata)
    mqtt_client.publish(topic2, sdata)

def on_connect(client, userdata, flags, rc):
    print("Connected with result code "+str(rc))
    client.subscribe(topic)

def on_message(client, userdata, msg):
    dict_data = json.loads(msg.payload)
    N = dict_data["count"] + 2
    window = np.hanning(N)
    data = dict_data["data"]
    y = np.array(data[:N])
    #y = np.array(dict_data["data"])
    y = y * window
    if dict_data["full"] == 1:
        print(y)
    else:
        print(dict_data["full"])

    # FFTを実行
    y_fft = np.fft.fft(y)               # 離散フーリエ変換
    freq = np.fft.fftfreq(N, d=dt)      # 周波数を割り当てる
    Amp = abs(y_fft/(N/2))              # 音の大きさ(振幅の大きさ)

    # 窓補正
    acf=1/(sum(window)/N)
    Amp = acf*Amp
    
    output_FN = "/home/pi/Documents/pdm_imu/pictures/"+dict_data["datetime"]+".png"

    ### 音波のスペクトル ###
    plt.plot(freq[1:int(N/2)], Amp[1:int(N/2)]) # A-f グラフのプロット
    if N == 256:
        plt.xlabel("frequency [Hz](sp=256)")
    else:
        plt.xlabel("frequency [Hz](sp=1024)")
    plt.ylabel("amplitude [V/rtHz]")
    plt.xscale("log")                           # 横軸を対数軸にセット
    plt.yscale("log")                           # 縦軸を対数軸にセット
    plt.savefig(output_FN, dpi=300)
    plt.clf()

    # Deep Learningによる判定
    if N == 256:
        y = network256.predict(np.tile(Amp[1:int(N/2)], (2,1)))
    else:
        y = network1024.predict(np.tile(Amp[1:int(N/2)], (2,1)))
    y = np.argmax(y, axis=1)
    tp = int(y[0])

    # Sampling point数の変更
    if (tp == 2 or tp == 4):
        next = 1024
    else:
        next = 256

    if next != N:
        sdata['current_SMP'] = N
        sdata['next_SMP'] = next
        publish(str(next))

    print(f"fft type: {tp}, current_SMP: {N}, next_SMP: {next}")

    # NumPy配列をPythonリストに変換
    freq_list = freq[1:int(N/2)].tolist()
    amp_list = Amp[1:int(N/2)].tolist()

    # PythonリストをJSON形式に変換
    out_data = {"Num":N, "type":tp, "datatime":dict_data["datetime"], "freq": freq_list, "amp": amp_list}
    data_json = json.dumps(out_data)
    #print(data_json)

    outFile = "/home/pi/Documents/pdm_imu/fftResult/"+dict_data["datetime"]+".json"
    # JSONデータをファイルに出力
    with open(outFile, "w") as f:
        f.write(data_json)
    
def main():
    # Deep Learning用paramter読込み
    with open('params_1024.pkl', 'rb') as f1:
        network1024.params = pickle.load(f1)
    network1024.setlayer()

    with open('params_256.pkl', 'rb') as f2:
        network256.params = pickle.load(f2)
    network256.setlayer()

    # MQTT ブローカーに接続
    mqtt_client.connect(broker_address, port)
    print(f"MQTT connected: {broker_address}")
    mqtt_client.on_connect = on_connect
    mqtt_client.on_message = on_message

    try:
        mqtt_client.loop_forever()
    except KeyboardInterrupt:
        # MQTT クライアントを切断
        mqtt_client.disconnect()
        print(f"MQTT disconnected: {broker_address}")

if __name__ == "__main__":
    main()

2) two_layer_net.py

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 親ディレクトリのファイルをインポートするための設定
import numpy as np
from common.layers import *
from common.gradient import numerical_gradient
from collections import OrderedDict


class TwoLayerNet:

    def __init__(self, input_size, hidden_size, output_size, weight_init_std = 0.01):
        # 重みの初期化
        self.params = {}
        self.params['W1'] = weight_init_std * np.random.randn(input_size, hidden_size)
        self.params['b1'] = np.zeros(hidden_size)
        self.params['W2'] = weight_init_std * np.random.randn(hidden_size, output_size) 
        self.params['b2'] = np.zeros(output_size)

        # レイヤの生成
        self.layers = OrderedDict()
        self.layers['Affine1'] = Affine(self.params['W1'], self.params['b1'])
        self.layers['Relu1'] = Relu()
        self.layers['Affine2'] = Affine(self.params['W2'], self.params['b2'])

        self.lastLayer = SoftmaxWithLoss()

    def setlayer(self):
        self.layers['Affine1'] = Affine(self.params['W1'], self.params['b1'])
        self.layers['Relu1'] = Relu()
        self.layers['Affine2'] = Affine(self.params['W2'], self.params['b2'])
        
    def predict(self, x):
        for layer in self.layers.values():
            x = layer.forward(x)
        return x
        
    # x:入力データ, t:教師データ
    def loss(self, x, t):
        y = self.predict(x)
        return self.lastLayer.forward(y, t)
    
    def accuracy(self, x, t):
        y = self.predict(x)
        y = np.argmax(y, axis=1)
        if t.ndim != 1 : t = np.argmax(t, axis=1)
        
        accuracy = np.sum(y == t) / float(x.shape[0])
        return accuracy
        
    # x:入力データ, t:教師データ
    def numerical_gradient(self, x, t):
        loss_W = lambda W: self.loss(x, t)
        
        grads = {}
        grads['W1'] = numerical_gradient(loss_W, self.params['W1'])
        grads['b1'] = numerical_gradient(loss_W, self.params['b1'])
        grads['W2'] = numerical_gradient(loss_W, self.params['W2'])
        grads['b2'] = numerical_gradient(loss_W, self.params['b2'])
        
        return grads
        
    def gradient(self, x, t):
        # forward
        self.loss(x, t)

        # backward
        dout = 1
        dout = self.lastLayer.backward(dout)
        
        layers = list(self.layers.values())
        layers.reverse()
        for layer in layers:
            dout = layer.backward(dout)

        # 設定
        grads = {}
        grads['W1'], grads['b1'] = self.layers['Affine1'].dW, self.layers['Affine1'].db
        grads['W2'], grads['b2'] = self.layers['Affine2'].dW, self.layers['Affine2'].db

        return grads