8月17日、「PyTorchは死んだ、JAX万歳」と題した記事が公開され、海外で人気を博している。この記事では、PyTorchの問題点と、それに対するJAXの優位性について、コードサンプルを交えて詳しく分析されている。
この記事では、PyTorchが科学計算の分野でどれほどの生産性を低下させ、技術的負債を増大させたかを強調している。PyTorchは本来、迅速なプロトタイピングを目的として開発されたが、その設計は大規模な分散システムでの使用に適していないという。
設計思想の違い
PyTorchの設計思想は、TensorFlowの静的でパフォーマンス重視のアプローチに対して、動的でデバッグが容易、そして「Pythonらしい(Pythonic)」フレームワークを目指すというものである。しかしこの設計思想は、特に大規模なプロジェクトにおいて、スケーラビリティとパフォーマンスの面で妥協を強いられる。
一方で、JAXはその設計段階からスケーラビリティと高性能を追求しており、特に分散コンピューティングにおいて優れた性能を発揮する。JAXでは、@jax.jit
デコレータを使用して関数をJITコンパイルすることで、計算の最適化と並列化が自動的に行われる。これにより、開発者は手動での最適化や複雑な設定作業を省くことができる。
以下は、JAXでの簡単なサンプルコードである。
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
return x * x + 2
x = jnp.array([1.0, 2.0, 3.0])
print(f(x))
このコードでは、関数f
がJITコンパイルされ、XLA
コンパイラを通じて最適化されている。PyTorch
で同様の最適化を行うには、手動で多くの設定が必要である。
JAXの「コンパイラ駆動」開発がもたらす利点
JAXの最大の強みは、そのコンパイラ駆動の開発アプローチにある。XLA
(Accelerated Linear Algebra)は、Googleによって開発された高性能なコンパイラで、深層学習や科学計算の負荷を劇的に軽減することを目的としている。XLA
は、特に複雑な計算グラフの最適化、演算の並列化、自動シャーディング(データの分割と分散処理)などの機能を持っている。
XLA
の魅力は、その柔軟性と汎用性にある。たとえば、JAXでは以下のようなコードで分散計算が簡単に行える。
# デバイス間での値の分散
sharding = jax.experimental.maps.Mesh((8,))
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
y = jax.device_put(x, sharding)
XLA
は、計算の分割と並列化を自動的に行い、ユーザーが細かな設定に煩わされることなく、最適なパフォーマンスを発揮できる環境を提供する。これは、PyTorchの手動設定に頼らざるを得ないアプローチとは対照的である。
さらに、XLA
は異なるハードウェアバックエンド(TPU、GPU、CPUなど)での計算をシームレスにサポートしており、コードの移植性を高める。これにより、同じコードが異なるデバイスで簡単に実行できるため、開発者はハードウェアの制約に煩わされることなく、研究に集中できる。
PyTorchは、マルチバックエンドのサポートが複雑化を助長している
PyTorchが直面する最大の課題の一つは、そのマルチバックエンドサポートの複雑さである。PyTorchは、多様なハードウェアでの動作をサポートするために、複数のバックエンドをサポートする設計を採用している。これには、GPUやCPUだけでなく、GoogleのTPU、そして将来的には他のハードウェアも含まれる。しかし、このアプローチは理論的には魅力的に見えるものの、実際にはいくつかの問題を引き起こしている。
第一に、各バックエンド間でのAPIの一貫性が保たれておらず、ユーザーは異なるバックエンド間での切り替え時に多くの手動調整やデバッグが必要となる。例えば、同じPyTorchコードが異なるハードウェアで動作する際に、性能が大きく変わる可能性があり、これがユーザーにとって大きな負担となっている。
第二に、各バックエンドのサポートはしばしば不完全であり、新しいハードウェアに対応するためには、PyTorchの内部での多くの改修が必要となる。この結果、ユーザーはしばしば、公式ドキュメントやサポートを頼りにしなければならず、これが開発の遅延やフラストレーションを引き起こしている。
関数型プログラミングの採用によるJAXのアドバンテージ
JAXは、関数型プログラミングのアプローチを採用しており、これがコードの再利用性とメンテナンス性を大幅に向上させている。
関数型プログラミングでは、関数が「純粋」であることが求められる。これは、関数が副作用を持たず、同じ入力に対して常に同じ出力を返すことを意味する。この設計により、コードは予測可能でテストが容易になり、デバッグが簡単になる。JAXでは、この純粋性が高度に尊重されており、これにより、複雑な操作もシンプルかつ直感的に記述できる。
例えばJAXにはvmap
という強力な機能があり、これを使用することで、バッチ処理を効率的に実行できる。vmap
は、ベクトル化されたマッピングを提供し、同じ操作を複数の入力に対して一括で適用できる。以下は、vmap
を使用したJAXのコード例である。
import jax
import jax.numpy as jnp
def apply_function(x):
return jnp.sin(x)
# バッチ処理の実行
batched_apply_function = jax.vmap(apply_function)
x = jnp.array([0.1, 0.2, 0.3])
result = batched_apply_function(x)
print(result)
この例では、apply_function
をvmap
でベクトル化し、バッチ処理を一括で行っている。vmap
の内部では、JAXが自動的に並列化を行うため、バッチ処理は非常に高速に実行される。
一方、PyTorchでは、同様のバッチ処理や並列化を行うために、多くの手動設定やコードの調整が必要となる。例えば、分散処理を行うためには、torch.distributed
モジュールを使用し、手動でプロセスの同期やデバイス間の通信を管理する必要がある。これにより、コードが複雑化し、デバッグやメンテナンスが難しくなる。
さらに、PyTorchのコードはしばしばオブジェクト指向プログラミング(OOP)に依存しており、これがコードの断片化を招く一因となっている。OOPでは、状態を持つオブジェクトやクラスが中心となるため、コードのモジュール化が進む一方で、各モジュール間の依存関係が複雑化しやすい。この結果、コードの再利用性や保守性が低下し、大規模なプロジェクトでは特に管理が難しくなる。
再現性
再現性は科学研究において非常に重要であり、JAXはこの点でも優れている。JAXでは、ランダム性の制御が厳密に管理されており、コードが確実に再現可能な形で実行される。これに対し、PyTorchでは、シードの設定ミスが再現性に影響を与えることがある。
まとめ
JAXの課題は、オープンソースプロジェクトとしてのガバナンスがまだ確立されていないため、開発元であるGoogleの影響を強く受ける可能性があること、一部のAPIがまだ発展途上である点などが挙げられる。しかし、これらの問題にもかかわらず、JAXのコンパイラ駆動アプローチは、PyTorchに比べて非常に強力である。
この記事の筆者は、PyTorchが現代の科学計算や大規模なAI研究にはもはや適していないと強調している。その代わりに、JAXのようなコンパイラ駆動のフレームワークが、今後の研究の標準となるべきだと結論付けている。PyTorchは柔軟性を追求する一方で、その結果として複雑化し、エコシステム全体が断片化している。一方、JAXはその設計哲学を一貫して守り、シンプルで強力なツールとして機能している。
詳細は[PyTorch is dead. Long live JAX.]を参照していただきたい。