۲۹ شهریور ۱۴۰۳

Techboy

اخبار و اطلاعات روز تکنولوژی

Google JAX چیست؟ NumPy در شتاب دهنده ها

سر و صدا در مورد Google JAX چیست؟ بیاموزید که چگونه JAX Autograd و XLA را برای محاسبات عددی و تحقیقات یادگیری ماشینی فوق‌العاده در مورد CPU، GPU و TPU ترکیب می‌کند.

سر و صدا در مورد Google JAX چیست؟ بیاموزید که چگونه JAX Autograd و XLA را برای محاسبات عددی و تحقیقات یادگیری ماشینی فوق‌العاده در مورد CPU، GPU و TPU ترکیب می‌کند.

از جمله نوآوری‌هایی که پلتفرم منبع باز محبوب TensorFlow آموزش ماشین را تقویت می‌کند، تمایز خودکار است (Autograd) و کامپایلر بهینه‌سازی XLA (جبر خطی شتاب‌دار) برای یادگیری عمیق.

Google JAX پروژه دیگری است که این دو فناوری را گرد هم می آورد و مزایای قابل توجهی برای سرعت و عملکرد ارائه می دهد. وقتی JAX روی GPU یا TPU اجرا می‌شود، می‌تواند جایگزین برنامه‌هایی شود که NumPy را فراخوانی می‌کنند، اما برنامه‌های آن بسیار سریع‌تر اجرا می‌شوند. علاوه بر این، استفاده از JAX برای شبکه‌های عصبی می‌تواند افزودن قابلیت‌های جدید را بسیار آسان‌تر از گسترش یک چارچوب بزرگ‌تر مانند TensorFlow کند.

این مقاله Google JAX را معرفی می‌کند، از جمله مروری بر مزایا و محدودیت‌های آن، دستورالعمل‌های نصب، و اولین نگاهی به شروع سریع Google JAX در Colab.

Autograd چیست؟

Autograd یک موتور تمایز خودکار است که به عنوان یک پروژه تحقیقاتی در گروه سیستم های احتمالی هوشمند هاروارد رایان آدامز آغاز شد. از زمان نوشتن این مقاله، موتور در حال نگهداری است اما دیگر به طور فعال توسعه نمی یابد. در عوض، توسعه دهندگان آن روی Google JAX کار می کنند که Autograd را با ویژگی های اضافی مانند کامپایل XLA JIT ترکیب می کند. موتور Autograd می‌تواند به طور خودکار کدهای Python و NumPy را متمایز کند. کاربرد اصلی مورد نظر آن بهینه سازی مبتنی بر گرادیان است.

API tf.GradientTape TensorFlow بر اساس ایده‌های مشابه Autograd است، اما پیاده‌سازی آن یکسان نیست. Autograd به طور کامل در Python نوشته شده است و گرادیان را مستقیماً از تابع محاسبه می کند، در حالی که عملکرد نوار گرادیان TensorFlow به زبان C++ با یک پوشش نازک پایتون نوشته شده است. TensorFlow از پس انتشار برای محاسبه تفاوت‌های ضرر، تخمین گرادیان افت و پیش‌بینی بهترین مرحله بعدی استفاده می‌کند.

نظرسنجی Stack Overflow می گوید که Rust بیشترین زبان را تحسین می کند

XLA چیست؟

XLA یک کامپایلر مخصوص دامنه برای جبر خطی است که توسط TensorFlow توسعه یافته است. با توجه به مستندات TensorFlow، XLA می‌تواند مدل‌های TensorFlow را بدون تغییر کد منبع تسریع کند و سرعت و استفاده از حافظه را بهبود بخشد. یکی از نمونه‌ها یک ارسال معیار BERT MLPerf در سال ۲۰۲۰ است، جایی که ۸ پردازنده گرافیکی Volta V100 با استفاده از XLA به بهبود عملکرد ~ ۷ برابری و بهبود اندازه دسته ای ~ ۵ برابری دست یافتند.

XLA یک نمودار TensorFlow را در دنباله ای از هسته های محاسباتی که به طور خاص برای مدل داده شده تولید شده اند، کامپایل می کند. از آنجایی که این هسته ها منحصر به مدل هستند، می توانند از اطلاعات خاص مدل برای بهینه سازی بهره برداری کنند. در TensorFlow، XLA همچنین کامپایلر JIT (فقط در زمان) نامیده می شود. می توانید آن را با یک پرچم در @tf.function تزئین کننده پایتون فعال کنید، مانند:


@tf.function(jit_compile=True)

همچنین می‌توانید XLA را در TensorFlow با تنظیم متغیر محیطی TF_XLA_FLAGS یا با اجرای ابزار مستقل tfcompile فعال کنید.

به‌جز TensorFlow، برنامه‌های XLA می‌توانند توسط:

تولید شوند

شروع به کار با Google JAX

من JAX Quickstart را در Colab مرور کردم که از یک GPU به صورت پیش فرض در صورت تمایل می توانید استفاده از TPU را انتخاب کنید، اما استفاده از TPU رایگان ماهانه محدود است. همچنین باید یک راه اندازی اولیه ویژه< /a> برای استفاده از Colab TPU برای Google JAX.

برای رسیدن به شروع سریع، دکمه Open in Colab را در بالای ارزیابی موازی در صفحه مستندات JAX. این شما را به محیط نوت بوک زنده سوئیچ می کند. سپس، دکمه اتصال را در نوت بوک برای اتصال به زمان اجرا میزبانی شده رها کنید.

اجرای سریع شروع با یک GPU مشخص کرد که JAX چقدر می تواند عملیات ماتریس و جبر خطی را تسریع کند. بعداً در نوت بوک، زمان‌های شتاب‌داده‌شده با JIT را دیدم که در میکروثانیه اندازه‌گیری می‌شد. هنگامی که کد را می خوانید، ممکن است بیشتر آن حافظه شما را به عنوان بیان کننده عملکردهای رایج مورد استفاده در یادگیری عمیق بیان کند.

عرضه هوش مصنوعی بسیار جلوتر از تقاضای هوش مصنوعی است

Google JAX

شکل ۱. یک مثال ریاضی ماتریسی در راه اندازی سریع Google JAX.

نحوه نصب JAX

نصب JAX باید با سیستم عامل شما و انتخاب نسخه CPU، GPU یا TPU مطابقت داشته باشد. برای CPU ها ساده است. برای مثال، اگر می‌خواهید JAX را روی لپ‌تاپ خود اجرا کنید، وارد کنید:


pip install --upgrade pip

pip install --upgrade "jax[cpu]"

برای پردازنده‌های گرافیکی، باید CUDA و CuDNN به همراه یک درایور سازگار NVIDIA نصب شده است. شما به نسخه‌های نسبتاً جدید هر دو نیاز دارید. در لینوکس با نسخه های اخیر CUDA و CuDNN، می توانید چرخ های از پیش ساخته شده سازگار با CUDA را نصب کنید. در غیر این صورت، باید از منبع بسازید. p>

JAX همچنین چرخ‌های از پیش ساخته شده را برای Google Cloud TPUs. Cloud TPU جدیدتر از Colab TPU هستند و سازگار با عقب نیستند، اما محیط های Colab از قبل شامل JAX و پشتیبانی صحیح TPU هستند.

API JAX

در JAX API سه لایه وجود دارد. در بالاترین سطح، JAX آینه ای از NumPy API، jax.numpy را پیاده سازی می کند. تقریباً هر کاری که با numpy قابل انجام باشد را می توان با jax.numpy انجام داد. محدودیت jax.numpy این است که برخلاف آرایه‌های NumPy، آرایه‌های JAX غیرقابل تغییر هستند، به این معنی که پس از ایجاد محتوای آنها نمی‌توان تغییر داد.

لایه میانی JAX API jax.lax است که سخت‌تر و اغلب قوی‌تر از لایه NumPy است. تمام عملیات در jax.numpy در نهایت بر حسب توابع تعریف شده در jax.lax بیان می‌شوند. در حالی که jax.numpy به طور ضمنی آرگومان هایی را برای اجازه دادن به عملیات بین انواع داده های مختلط ترویج می کند، jax.lax این کار را نخواهد کرد. در عوض، توابع تبلیغاتی صریح را ارائه می کند.

پایین ترین لایه API XLA است. همه عملیات jax.lax بسته‌بندی پایتون برای عملیات در XLA هستند. هر عملیات JAX در نهایت بر حسب این عملیات اساسی XLA بیان می‌شود، که کامپایل JIT را قادر می‌سازد.

محدودیت های JAX

تبدیل‌ها و کامپایل‌های JAX برای کار کردن فقط روی توابع پایتون طراحی شده‌اند. که از نظر عملکردی خالص هستند. اگر یک تابع یک عارضه جانبی داشته باشد، حتی چیزی به سادگی عبارت print()، اجراهای متعدد از طریق کد عوارض جانبی متفاوتی خواهند داشت. یک print() در اجراهای بعدی چیزهای مختلف یا اصلاً هیچ چاپ نمی‌کند.

بررسی: Databricks Lakehouse Platform

از دیگر محدودیت‌های JAX می‌توان به عدم اجازه جهش در محل اشاره کرد (زیرا آرایه‌ها تغییرناپذیر هستند). این محدودیت با اجازه به‌روزرسانی آرایه‌های بی‌جا کاهش می‌یابد:


updated_array = jax_array.at[1, :].set(1.0)

علاوه بر این، JAX به‌صورت پیش‌فرض روی اعداد دقیق تکی (float32) پیش‌فرض می‌شود، در حالی که NumPy پیش‌فرض دقت دو برابری (float64) را دارد. اگر واقعاً به دقت مضاعف نیاز دارید، می‌توانید JAX را روی حالت jax_enable_x64 تنظیم کنید. به طور کلی، محاسبات تک دقیق سریعتر اجرا می شوند و به حافظه GPU کمتری نیاز دارند.

استفاده از JAX برای شبکه های عصبی تسریع شده

در این مرحله، باید واضح باشد که می‌توانید شبکه‌های عصبی شتاب‌دار را در JAX پیاده‌سازی کنید. از سوی دیگر، چرا چرخ را دوباره اختراع کنیم؟ گروه‌های تحقیقاتی Google و DeepMind چندین کتابخانه شبکه عصبی مبتنی بر JAX را منبع باز کرده‌اند: Flax یک کتابخانه کاملاً ویژه برای آموزش شبکه‌های عصبی با مثال‌ها و راهنماهایی است. هایکو برای ماژول‌های شبکه عصبی است، Optax برای پردازش گرادیان و بهینه‌سازی است، RLax است برای الگوریتم‌های RL (یادگیری تقویتی) و chex برای کد و آزمایش قابل اعتماد است. .

درباره JAX بیشتر بدانید

علاوه بر JAX Quickstart، JAX دارای مجموعه آموزش‌هایی که می‌توانید (و باید) در Colab اجرا کنید. اولین آموزش به شما نشان می دهد که چگونه از توابع jax.numpy، توابع grad و value_and_grad و @jit استفاده کنید. کد> دکوراتور. آموزش بعدی به عمق بیشتری در مورد کامپایل JIT می پردازد. با آخرین آموزش، شما در حال یادگیری نحوه کامپایل و پارتیشن بندی خودکار توابع در هر دو محیط تک و چند میزبان هستید.

شما می توانید (و باید) اسناد مرجع JAX را نیز بخوانید (با سؤالات متداول شروع می شود ) و آموزش های پیشرفته را اجرا کنید (شروع با کتاب آشپزی خودکار) در Colab. در نهایت، باید مستندات API را بخوانید و با بسته اصلی JAX شروع کنید.< /p>