سر و صدا در مورد Google JAX چیست؟ بیاموزید که چگونه JAX Autograd و XLA را برای محاسبات عددی و تحقیقات یادگیری ماشینی فوقالعاده در مورد CPU، GPU و TPU ترکیب میکند.
- Autograd چیست؟
- XLA چیست؟
- شروع به کار با Google JAX
- نحوه نصب JAX
- API JAX
- محدودیتهای JAX
- استفاده از JAX برای شبکههای عصبی تسریع شده
- درباره JAX بیشتر بیاموزید
از جمله نوآوریهایی که پلتفرم منبع باز محبوب TensorFlow آموزش ماشین را تقویت میکند، تمایز خودکار است (Autograd) و کامپایلر بهینهسازی XLA (جبر خطی شتابدار) برای یادگیری عمیق.
Google JAX پروژه دیگری است که این دو فناوری را گرد هم می آورد و مزایای قابل توجهی برای سرعت و عملکرد ارائه می دهد. وقتی JAX روی GPU یا TPU اجرا میشود، میتواند جایگزین برنامههایی شود که NumPy را فراخوانی میکنند، اما برنامههای آن بسیار سریعتر اجرا میشوند. علاوه بر این، استفاده از JAX برای شبکههای عصبی میتواند افزودن قابلیتهای جدید را بسیار آسانتر از گسترش یک چارچوب بزرگتر مانند TensorFlow کند.
این مقاله Google JAX را معرفی میکند، از جمله مروری بر مزایا و محدودیتهای آن، دستورالعملهای نصب، و اولین نگاهی به شروع سریع Google JAX در Colab.
Autograd چیست؟
Autograd یک موتور تمایز خودکار است که به عنوان یک پروژه تحقیقاتی در گروه سیستم های احتمالی هوشمند هاروارد رایان آدامز آغاز شد. از زمان نوشتن این مقاله، موتور در حال نگهداری است اما دیگر به طور فعال توسعه نمی یابد. در عوض، توسعه دهندگان آن روی Google JAX کار می کنند که Autograd را با ویژگی های اضافی مانند کامپایل XLA JIT ترکیب می کند. موتور Autograd میتواند به طور خودکار کدهای Python و NumPy را متمایز کند. کاربرد اصلی مورد نظر آن بهینه سازی مبتنی بر گرادیان است.
API tf.GradientTape
TensorFlow بر اساس ایدههای مشابه Autograd است، اما پیادهسازی آن یکسان نیست. Autograd به طور کامل در Python نوشته شده است و گرادیان را مستقیماً از تابع محاسبه می کند، در حالی که عملکرد نوار گرادیان TensorFlow به زبان C++ با یک پوشش نازک پایتون نوشته شده است. TensorFlow از پس انتشار برای محاسبه تفاوتهای ضرر، تخمین گرادیان افت و پیشبینی بهترین مرحله بعدی استفاده میکند.
XLA چیست؟
XLA یک کامپایلر مخصوص دامنه برای جبر خطی است که توسط TensorFlow توسعه یافته است. با توجه به مستندات TensorFlow، XLA میتواند مدلهای TensorFlow را بدون تغییر کد منبع تسریع کند و سرعت و استفاده از حافظه را بهبود بخشد. یکی از نمونهها یک ارسال معیار BERT MLPerf در سال ۲۰۲۰ است، جایی که ۸ پردازنده گرافیکی Volta V100 با استفاده از XLA به بهبود عملکرد ~ ۷ برابری و بهبود اندازه دسته ای ~ ۵ برابری دست یافتند.
XLA یک نمودار TensorFlow را در دنباله ای از هسته های محاسباتی که به طور خاص برای مدل داده شده تولید شده اند، کامپایل می کند. از آنجایی که این هسته ها منحصر به مدل هستند، می توانند از اطلاعات خاص مدل برای بهینه سازی بهره برداری کنند. در TensorFlow، XLA همچنین کامپایلر JIT (فقط در زمان) نامیده می شود. می توانید آن را با یک پرچم در @tf.function
تزئین کننده پایتون فعال کنید، مانند:
@tf.function(jit_compile=True)
همچنین میتوانید XLA را در TensorFlow با تنظیم متغیر محیطی TF_XLA_FLAGS
یا با اجرای ابزار مستقل tfcompile
فعال کنید.
بهجز TensorFlow، برنامههای XLA میتوانند توسط:
تولید شوند
شروع به کار با Google JAX
من JAX Quickstart را در Colab مرور کردم که از یک GPU به صورت پیش فرض در صورت تمایل می توانید استفاده از TPU را انتخاب کنید، اما استفاده از TPU رایگان ماهانه محدود است. همچنین باید یک راه اندازی اولیه ویژه< /a> برای استفاده از Colab TPU برای Google JAX.
برای رسیدن به شروع سریع، دکمه Open in Colab را در بالای ارزیابی موازی در صفحه مستندات JAX. این شما را به محیط نوت بوک زنده سوئیچ می کند. سپس، دکمه اتصال را در نوت بوک برای اتصال به زمان اجرا میزبانی شده رها کنید.
اجرای سریع شروع با یک GPU مشخص کرد که JAX چقدر می تواند عملیات ماتریس و جبر خطی را تسریع کند. بعداً در نوت بوک، زمانهای شتابدادهشده با JIT را دیدم که در میکروثانیه اندازهگیری میشد. هنگامی که کد را می خوانید، ممکن است بیشتر آن حافظه شما را به عنوان بیان کننده عملکردهای رایج مورد استفاده در یادگیری عمیق بیان کند.
شکل ۱. یک مثال ریاضی ماتریسی در راه اندازی سریع Google JAX.
نحوه نصب JAX
نصب JAX باید با سیستم عامل شما و انتخاب نسخه CPU، GPU یا TPU مطابقت داشته باشد. برای CPU ها ساده است. برای مثال، اگر میخواهید JAX را روی لپتاپ خود اجرا کنید، وارد کنید:
pip install --upgrade pip
pip install --upgrade "jax[cpu]"
برای پردازندههای گرافیکی، باید CUDA و CuDNN به همراه یک درایور سازگار NVIDIA نصب شده است. شما به نسخههای نسبتاً جدید هر دو نیاز دارید. در لینوکس با نسخه های اخیر CUDA و CuDNN، می توانید چرخ های از پیش ساخته شده سازگار با CUDA را نصب کنید. در غیر این صورت، باید از منبع بسازید. p>
JAX همچنین چرخهای از پیش ساخته شده را برای Google Cloud TPUs. Cloud TPU جدیدتر از Colab TPU هستند و سازگار با عقب نیستند، اما محیط های Colab از قبل شامل JAX و پشتیبانی صحیح TPU هستند.
API JAX
در JAX API سه لایه وجود دارد. در بالاترین سطح، JAX آینه ای از NumPy API، jax.numpy
را پیاده سازی می کند. تقریباً هر کاری که با numpy
قابل انجام باشد را می توان با jax.numpy
انجام داد. محدودیت jax.numpy
این است که برخلاف آرایههای NumPy، آرایههای JAX غیرقابل تغییر هستند، به این معنی که پس از ایجاد محتوای آنها نمیتوان تغییر داد.
لایه میانی JAX API jax.lax
است که سختتر و اغلب قویتر از لایه NumPy است. تمام عملیات در jax.numpy
در نهایت بر حسب توابع تعریف شده در jax.lax
بیان میشوند. در حالی که jax.numpy
به طور ضمنی آرگومان هایی را برای اجازه دادن به عملیات بین انواع داده های مختلط ترویج می کند، jax.lax
این کار را نخواهد کرد. در عوض، توابع تبلیغاتی صریح را ارائه می کند.
پایین ترین لایه API XLA است. همه عملیات jax.lax
بستهبندی پایتون برای عملیات در XLA هستند. هر عملیات JAX در نهایت بر حسب این عملیات اساسی XLA بیان میشود، که کامپایل JIT را قادر میسازد.
محدودیت های JAX
تبدیلها و کامپایلهای JAX برای کار کردن فقط روی توابع پایتون طراحی شدهاند. که از نظر عملکردی خالص هستند. اگر یک تابع یک عارضه جانبی داشته باشد، حتی چیزی به سادگی عبارت print()
، اجراهای متعدد از طریق کد عوارض جانبی متفاوتی خواهند داشت. یک print()
در اجراهای بعدی چیزهای مختلف یا اصلاً هیچ چاپ نمیکند.
از دیگر محدودیتهای JAX میتوان به عدم اجازه جهش در محل اشاره کرد (زیرا آرایهها تغییرناپذیر هستند). این محدودیت با اجازه بهروزرسانی آرایههای بیجا کاهش مییابد:
updated_array = jax_array.at[1, :].set(1.0)
علاوه بر این، JAX بهصورت پیشفرض روی اعداد دقیق تکی (float32
) پیشفرض میشود، در حالی که NumPy پیشفرض دقت دو برابری (float64
) را دارد. اگر واقعاً به دقت مضاعف نیاز دارید، میتوانید JAX را روی حالت jax_enable_x64
تنظیم کنید. به طور کلی، محاسبات تک دقیق سریعتر اجرا می شوند و به حافظه GPU کمتری نیاز دارند.
استفاده از JAX برای شبکه های عصبی تسریع شده
در این مرحله، باید واضح باشد که میتوانید شبکههای عصبی شتابدار را در JAX پیادهسازی کنید. از سوی دیگر، چرا چرخ را دوباره اختراع کنیم؟ گروههای تحقیقاتی Google و DeepMind چندین کتابخانه شبکه عصبی مبتنی بر JAX را منبع باز کردهاند: Flax یک کتابخانه کاملاً ویژه برای آموزش شبکههای عصبی با مثالها و راهنماهایی است. هایکو برای ماژولهای شبکه عصبی است، Optax برای پردازش گرادیان و بهینهسازی است، RLax است برای الگوریتمهای RL (یادگیری تقویتی) و chex برای کد و آزمایش قابل اعتماد است. .
درباره JAX بیشتر بدانید
علاوه بر JAX Quickstart، JAX دارای مجموعه آموزشهایی که میتوانید (و باید) در Colab اجرا کنید. اولین آموزش به شما نشان می دهد که چگونه از توابع jax.numpy
، توابع grad
و value_and_grad
و @jit استفاده کنید. کد> دکوراتور. آموزش بعدی به عمق بیشتری در مورد کامپایل JIT می پردازد. با آخرین آموزش، شما در حال یادگیری نحوه کامپایل و پارتیشن بندی خودکار توابع در هر دو محیط تک و چند میزبان هستید.
شما می توانید (و باید) اسناد مرجع JAX را نیز بخوانید (با سؤالات متداول شروع می شود ) و آموزش های پیشرفته را اجرا کنید (شروع با کتاب آشپزی خودکار) در Colab. در نهایت، باید مستندات API را بخوانید و با بسته اصلی JAX شروع کنید.< /p>
پست های مرتبط
Google JAX چیست؟ NumPy در شتاب دهنده ها
Google JAX چیست؟ NumPy در شتاب دهنده ها
Google JAX چیست؟ NumPy در شتاب دهنده ها