一覧に戻る

WASMのパフォーマンス最適化の勘所と使い所考察

#JavaScript#Rust#WebAssembly

TL;DR

  • JavaScriptとWebAssemblyの間の値渡しには気を使おう(なるべく値のコピーを避けよう)。
  • JSはとても高速で、単純なループではWASM並の速度が出るので、WASMの使い所はよくよく考えるべき。

はじめに

1000^2級の画像の全ピクセルをループして、(簡単に言うと)RGBA値の最も大きい値を抽出する、という処理をブラウザ上で突然したくなりました。ピクセル数が1,000,000だと、RGBAなので配列長は4,000,000となります。ブラウザで扱いたくないサイズ感です。

ここで、①WASMで高速化、ダメなら②サーバーサイドで計算させる…という算段をして、とりあえずWASMを試してみました。色々チューニングした結果、ブラウザ上で現実的な速度が出ることがわかりました(数十msの世界)。

RGBA値の計算について

本記事では深く説明しませんが、今回やりたい処理は「下記式により求まる実数値の最大値を探す」というものです。

value = -10000 + ((R * 256 * 256 + G * 256 + B) * 0.1)

TerrainRGBという標高値のエンコーディングです、詳しくは下記 https://qiita.com/Kanahiro/items/e22594b738655a189c1d#rgb%E5%80%A4%E3%81%AE%E6%A8%99%E9%AB%98%E6%8F%9B%E7%AE%97

WebAssemblyの実装

  • パフォーマンスチューニングの勘所は、引数や計算結果の渡し方です。単純な計算では明らかにWASMが速いですが、WASMの初期化や関数呼び出しなどのオーバーヘッドがあります。初期化のロスは避けられませんが、後者は値のやり取りに気を使うことで高速化が図れます。
  • なおサンプルコードでは、WASMはRust+wasm-bindgenで書くものとします。
  • 画像サイズは1400x1815とします(ピクセル数2541000)。

JSからWASMへの配列の渡し方

  • WASMへ要素数400万級の配列を関数の引数として素直に渡すと、値のコピーですごく遅くなります。
    • このサイズのUint8Arrayをfn somefunc(arr: &[u8])のように単純な引数として渡すと関数実行までで500msくらいかかります。
  • 下記のように渡すと高速です。
  1. JS側で、WASMに渡したい配列を用意する(ここでは画像のピクセル値の配列)
  2. WASM側で、JSから受け取る配列のサイズに合わせてを配列を初期化する(メモリをアロケートする)
  3. WASMからJSへ、WASM側で初期化した配列のポインタを渡す
  4. JS側で、WASMから受け取った配列のポインタをもとにWASM上のメモリのビューとして配列を初期化する
  5. 4の配列に、1で用意した配列をコピーする(set()を用いて高速にコピー)

WASMからJSへの計算結果の渡し方

  • JSからWASMへの配列の渡し方の③以降と全く同じ考え方で、ポインタをJSに渡してビューとして配列を初期化します。
  • そもそもWASMからJSには良い感じに配列を返す方法がこれ以外になさそうでした。
  1. WASM側で、JSに渡したい配列を初期化し(メモリをアロケートする)、値を書き込む
  2. WASMからJSへ、WASM側で初期化した配列のポインタを渡す
  3. JS側で、WASMから受け取った配列のポインタをもとにWASM上のメモリのビューとして配列を初期化する

サンプルコード

JS<->WASM間の配列のやりとりが実装されています。

mod utils;

use wasm_bindgen::prelude::*;

// When the `wee_alloc` feature is enabled, use `wee_alloc` as the global
// allocator.
#[cfg(feature = "wee_alloc")]
#[global_allocator]
static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;

#[wasm_bindgen]
pub struct TerrainRgb {
    rgba: Vec<u8>, // RGBA配列、JSから値を受け取る
    elevations: Vec<f64>, // 標高配列、WASMでの計算結果が入る、JSへ渡す
    pub pointer_to_rgba: *const u8, // JSからRGBA配列を参照するためのポインタ
    pub pointer_to_elevations: *const f64, // JSから計算結果の標高配列を参照するためのポインタ
}

#[wasm_bindgen]
impl TerrainRgb {
    #[wasm_bindgen(constructor)]
    pub fn new(pixel_length: usize) -> Self {
        let mut rgba: Vec<u8> = Vec::with_capacity(pixel_length); // 配列の初期化
        unsafe { rgba.set_len(pixel_length) } // unsafeで配列長を確定してしまうことで添字アクセスを可能に
        let pointer_to_rgba = rgba.as_mut_ptr(); // RGBA配列へのポインタ=メモリアドレスを取得

        let mut elevations: Vec<f64> = Vec::with_capacity(pixel_length / 4);
        unsafe { elevations.set_len(pixel_length / 4) }
        let pointer_to_elevations = elevations.as_mut_ptr();

        Self {
            rgba,
            elevations,
            pointer_to_rgba,
            pointer_to_elevations,
        }
    }

    pub fn decode_elevations(&mut self) {
        // RGBA値から標高値を計算し配列の値を更新
        for i in 0..(self.rgba.len() / 4) {
            self.elevations[i] = -10000.
                + 6553.6 * self.rgba[4 * i] as f64
                + 25.6 * self.rgba[4 * i + 1] as f64
                + 0.1 * self.rgba[4 * i + 2] as f64;
        }
    }
}

import wasm, { TerrainRgb } from './pkg/trial.js';
const image = new Image();
image.crossOrigin = '';
image.onload = () => {
    const canvas = document.createElement('canvas');
    canvas.width = image.width;
    canvas.height = image.height;

    const context = canvas.getContext('2d');
    context.drawImage(image, 0, 0);
    const imageData = context.getImageData(
        0,
        0,
        canvas.width,
        canvas.height,
    );
    wasm().then((instance) => {
        let start = performance.now();
        const terrainrgb = new TerrainRgb(imageData.data.length);
        const rgba = new Uint8Array(
            instance.memory.buffer,
            terrainrgb.pointer_to_rgba,
            imageData.data.length,
        ); // ポインタをもとに、WASM側で初期化した配列の「ビュー」としてJS配列を初期化
        rgba.set(imageData.data); // 画像の配列を丸々コピーする
        terrainrgb.decode_elevations(); // RGBA値から標高値を計算する

        const elevations = new Float64Array(
            instance.memory.buffer,
            terrainrgb.pointer_to_elevation,
            imageData.data.length / 4,
        ); // ポインタをもとに、WASM側で初期化した配列の「ビュー」としてJS配列を初期化し計算結果を参照する
        
        console.log('elevations', elevations);
        console.log('wasm finish', performance.now() - start);
    });
};
image.src = './terrain.png'; // 1400x1815の画像
結果
elevations Float64Array(2541000) […]
wasm finish 76.90000009536743

まとめ

  • どちらの方向でも共通しているのは、WASM上にメモリを確保しJS側から参照することで、引数による値のコピーを避けているということです。
  • 以上により、数十msで1,000,000ピクセルの画像の全ピクセルのループ処理が実現しました、速いですね!

オチ:異常言語JavaScript

ついにWASMをプロダクションで使えるかもなー、これでナウいフロントエンドえんじにゃーだぜ とか思いつつ、じゃあJavaScriptだとどれだけかかるのか、ベンチマークしてみました。

const image = new Image();
image.crossOrigin = '';
image.onload = () => {
    const canvas = document.createElement('canvas');
    canvas.width = image.width;
    canvas.height = image.height;

    const context = canvas.getContext('2d');
    context.drawImage(image, 0, 0);
    const imageData = context.getImageData(
        0,
        0,
        canvas.width,
        canvas.height,
    );

    let start = performance.now();

    const jsDecodeElevation = (arr) => {
        const elevs = new Float64Array(arr.length / 4);
        for (let i = 0; i < arr.length / 4; i++) {
            elevs[i] =
                -10000 +
                6553.6 * arr[4 * i] +
                25.6 * arr[4 * i + 1] +
                0.1 * arr[4 * i + 2];
        }
        return elevs;
    };

    const elevations = jsDecodeElevation(imageData.data)
    console.log('elevations', elevations)
    console.log('finish', performance.now() - start);

};
image.src = './terrain.png';
elevations Float64Array(2541000) […]
js finish 30.200000047683716

は??

素のJavaScriptの方が2倍速かったというオチ 軽量言語でネイティブ並の速度が出るJavaScriptは異常(キレそう) JSはやっぱり最高だな!!

終わりに

  • WASMの関数呼び出しでは、なるべく値のコピーが発生しないように気を使おう。
  • その処理のパフォーマンス向上のためにほんとうにWASMが必要か、よく考えよう。案外JavaScriptだけで最適化出来たりします。
    • この文は全然まとまっていなくてメモレベルですが、8bit-uintのループ処理はJSとWASMで大差はなく、64bit-floatだとWASMが明らかに速くなっていた、はず。もう気力がないので本記事には書きませんが、大量の実数値計算ではWASMを使うとパフォーマンス改善に大きく寄与するかもしれません。