Tail recursion in Scala
ভূমিকা
আপনারা অনেকেই হয়ত recursion ব্যবহার করেছেন। কিছু কিছু সমস্যা সমাধানে recursion এর ব্যবহার খুবই স্বাভাবিক, যেমন gcd, fibonacci, factorial ইত্যাদি।
def factorial (n:Int): Int =
if (n == 0) 1
else (n * factorial(n-1))
Scala তে recursive ফাংশান লিখার সময় একটি জিনিস আমাদের জানা খুব জরুরি, তা হল tail recursion। আমাদের লিখা recursive ফাংশান যদি tail recursive না হয়, তাহলে প্রোগ্রাম চলার সময় StackOverflowError হবার সম্ভাবনা থেকে যায়। চলুন দেখি বিস্তারিত।
Tail recursion কি?
আমরা জানি recursive ফাংশান হল সেই ফাংশান, যে তার বডি এর মধ্যে নিজেকেই আবার কল করে। কোন recursive ফাংশান এর একদম শেষের এক্সপ্রেশনটি যদি হয় তার নিজেকে কল করা, তাহলে সেই ফাংশান কে আমরা বলি tail recursive ফাংশান। যেমন নিচের gcd ফাংশানটি একটি tail recursive ফাংশান, কারণ ফাংশানটি তার শেষ এক্সপ্রেশেনে নিজেকে কল করেছে ।
def gcd(a: Int, b: Int): Int = {
if (b == 0) a
else gcd(b, a % b)
}
Tail recursion এর গুরুত্ব
জাভা ভার্চুয়াল মেশিন এ প্রতিটি ফাংশান কল এর জন্য একটি করে stack ফ্রেম ব্যবহার হয়। recursive ফাংশান এর ক্ষেত্রে ব্যাপারটা অনেকটা নিচের মত।
- যখন একটি
recursiveফাংশান নিজেকে কল করে, তখন ঐ ফাংশানটির ইনফর্মেশন এর একটি কপিstackএ পুশ করা হয়। - প্রতিবার ফাংশানটি যখন নিজেকে কল করে, নতুন করে ফাংশানটির ইনফর্মেশন
stackএ পুশ করা হয়। এ কারণেrecursionএর প্রত্যেকটি লেভেল এর জন্য একটি করেstackফ্রেম দরকার হয়। একটিrecursiveফাংশান যদি ১ লাখ বার নিজেকে কল করে তাহলে ১ লাখটিstackফ্রেম তৈরি হবে। - সমস্যা হল জাভা ভার্চুয়াল মেশিন প্রতিটি থ্রেড এর জন্য
stackএর আকার সীমাবদ্ধ করে দিয়েছে। যখন সেটার অধিক মেমরি আমাদের ফাংশান নিতে চাইবে, ভার্চুয়াল মেশিন তখনStackOverflowErrorথ্রো করবে।
Tail recursive ফাংশান এর সুবিধা হল, এর প্রত্যেকটি লেভেল এর জন্য নতুন করে কোন stack ফ্রেম এর প্রয়োজন হয় না, কাজেই StackOverflowError হবার কোন সম্ভাবনা নেই। চলুন একটি উদাহরণ এর মধ্যমে ব্যাপারটি দেখে নেওয়া যাক।
def sum(ls: List[Int]): Long = ls match {
case Nil => 0 // if the list is empty, return 0
case x :: xs => x + sum(xs) // otherwise with the current element, add the rest of the element's sum to get the result
}
উপরের ফাংশানটি একটি recursive ফাংশান, যা একটি লিস্ট এর যোগফল বের করছে। বলুন দেখি উপরের ফাংশানটি tail recursive কিনা?
আপাতদৃষ্টিতে মনে হচ্ছে ফাংশানটি যেহেতু তার শেষ লাইন এ নিজেকে কল করেছে, সেহেতু ফাংশানটি হয়ত tail recursive, কিন্তু আসলে তা নয়। ফাংশানটির শেষ লাইন করে ভেঙ্গে যদি আমরা নিচের মত করে লিখি তাহলেই ব্যাপারটি ধরতে পারব।
val current = x
val restSum = sum(xs)
current + restSum
যেহেতু ফাংশানটি tail recursive নয়, সেহেতু আমার মেশিন এ (-Xss ছাড়া, java 8, scala 2.12) আমি নিচের মত করে ফাংশানটিকে ব্যবহার করতে গেলে StackOverflowError পাই।
val myList = (1 to 8000).toList
val result = sum(myList)
...
Exception in thread "main" java.lang.StackOverflowError
এবার চলুন আমরা এই ফাংশানটির একটি tail recursive ভার্সন লিখি, এবং সেটা চালিয়ে দেখি।
def sumT(ls: List[Int]): Long = {
@tailrec
def sumInternal(ls: List[Int], acc: Long) : Long = ls match {
case Nil => acc
case x :: xs => sumInternal(xs, x + acc)
}
sumInternal(ls, 0)
}
val myList = (1 to 8000).toList
val result = sumT(myList)
// no error this time, result is computed fine.
এখানে Scala compiler যখন recursive ফাংশানটিকে tail recursive হিসাবে চিহ্নিত করতে পারে, তখন যে বাইটকোড তৈরি করে সেখানে মাত্র একটি stack frame ব্যবহার করেই পুরো ফাংশান এর কাজটি সম্পন্ন করে ফেলে। এক্ষেত্রে compiler recursive ফাংশানটিকে একটি loop এ রূপান্তরিত করে (প্রতিটি নতুন ফাংশান কল, হয়ে যায় একটি goto ইন্সট্রাকশন, পরিশিষ্ট দেখুন)। যে কারণে প্রতিটি recursion লেভেল এর জন্য আলাদা আলাদা stack frame এর প্রয়োজন হয় না, এবং StackOverfloError ও এড়ানো যায়। এই পুরো ব্যাপারটিকে বলা হয় tail call optimization।
কাজেই যখন আমরা নিশ্চিত হতে পারব না যে আমাদের recursive ফাংশানটিকে কি পরিমাণ data নিয়ে কাজ করতে হবে, তখন আমরা অবশ্যই চেষ্টা করব যাতে ফাংশানটিকে tail recursive করে লিখা যায়। নইলে রানটাইমে গিয়ে StackOverfloError হবার সম্ভাবনা প্রচুর।
@tailrec annotation
যদিও scala compiler নিজে থেকেই tail recursive ফাংশান চিহ্নিত করতে পারে, তারপরেও আমরা আমাদের tail recursive ফাংশান এ @tailrec annotation টি ব্যবহার করতে পারি। এই annotation টি ব্যবহারের সুবিধা হল, যদি আমরা কোন ফাংশান যেটি tail recursive নয়, সেটাতে এই annotation টি ব্যবহার করি, তাইলে কম্পাইলেশন ফেইল হবে। যেমন আমরা যদি প্রথম sum ফাংশানটিতে এই annotation ব্যবহার করি তাহলে নিচের ঘটনা ঘটবে।
Error: could not optimize @tailrec annotated method sum: it contains a recursive call not in tail position
কাজেই এই annotation ব্যবহার করার পরে যদি আমাদের recursive ফাংশান এর কম্পাইলেশন ঠিকঠাক মত হয়, তাহলে আমরা বুঝতে পারব যে আমাদের ফাংশানটি একটি tail recursive ফাংশান হয়েছে।
কখন প্রযোজ্য নয়
নিম্নলিখিত ক্ষেত্রে Scala compiler, tail call optimize করবে না।
- যদি মেথডটিকে
overrideকরা যায় - যদি
indirect recursionঅথবাmutual recursionহয় - শেষ কলটি একটা
function valueতে যায়
শুধুমাত্র private অথবা final অথবা অন্য কোনও মেথড এর ভিতরের মেথড tail call optimization এর জন্য বিবেচিত হবে।
উপসংহার
প্রয়োজন না হলে recursive function না লিখাই ভাল। আগে দেখতে হবে যেসব library ফাংশান দেয়া আছে, যেমন fold, map, reduce ইত্যাদি দিয়ে কাজ হচ্ছে কিনা। আর যদি recursive ফাংশান না লিখে কোনও ভাবেই হাতে থাকা সমস্যাটি সমাধান করা না যায়, তাহলে চেষ্টা করতে হবে যাতে recursive ফাংশানটিকে tail recursive বানানো যায়।
এ বিষয়ে আরও বিস্তারিত জানতে নিচের লিঙ্কগুলো দেখতে পারেন।
- Martin Odersky এর বইয়ের এই অংশটুকু
- Stack frame নিয়ে বিস্তারিত
- Scala তে কিছু recursive function এর উদাহরণ
পরিশিষ্ট
tail recursion ছাড়া ফাংশান
$ cat ListSum.scala
object ListSum {
def sum(ls: List[Int]): Long = ls match {
case Nil => 0
case x :: xs => x + sum(xs)
}
}
$ scalac ListSum.scala
$ javap -p -c ListSum\$.class
Compiled from "ListSum.scala"
public final class ListSum$ {
public static final ListSum$ MODULE$;
...
public long sum(scala.collection.immutable.List<java.lang.Object>);
Code:
0: aload_1
...
6: invokevirtual #23 // Method java/lang/Object.equals:(Ljava/lang/Object;)Z
...
32: invokevirtual #29 // Method scala/collection/immutable/$colon$colon.head:()Ljava/lang/Object;
...
53: invokevirtual #41 // Method sum:(Lscala/collection/immutable/List;)J
...
68: athrow
...
}
// many lines are removed due to brevity
এখানে নিচের লাইন টি দেখে বুঝা যাচ্ছে যে sum ফাংশানটি আবার নিজেকেই কল করছে।
53: invokevirtual #41 // Method sum:(Lscala/collection/immutable/List;)J
tail recursion সহ ফাংশান
$ cat ListSumTailRec.scala
import scala.annotation.tailrec
object ListSumTailRec {
def sumT(ls: List[Int]): Long = {
@tailrec
def sumInternal(ls: List[Int], acc: Long): Long = ls match {
case Nil => acc
case x :: xs => sumInternal(xs, x + acc)
}
sumInternal(ls, 0)
}
}
$ scalac ListSumTailRec.scala
$ javap -p -c ListSumTailRec\$.class
Compiled from "ListSumTailRec.scala"
public final class ListSumTailRec$ {
public static final ListSumTailRec$ MODULE$;
...
public long sumT(scala.collection.immutable.List<java.lang.Object>);
Code:
0: aload_0
...
private final long sumInternal$1(scala.collection.immutable.List, long);
Code:
0: aload_1
...
8: invokevirtual #30 // Method java/lang/Object.equals:(Ljava/lang/Object;)Z
...
35: aload 8
37: invokevirtual #36 // Method scala/collection/immutable/$colon$colon.head:()Ljava/lang/Object;
40: invokestatic #42 // Method scala/runtime/BoxesRunTime.unboxToInt:(Ljava/lang/Object;)I
43: istore 9
45: aload 8
47: invokevirtual #46 // Method scala/collection/immutable/$colon$colon.tl$1:()Lscala/collection/immutable/List;
50: astore 10
...
59: lstore_2
60: astore_1
61: goto 0
...
73: athrow
...
}
// many lines are removed due to brevity
এবং এক্ষেত্রে নিচের লাইনটি দেখে বুঝা যাচ্ছে যে আমাদের tail recursive ফাংশান এর recursive কলটি goto দিয়ে প্রতিস্থাপিত হয়েছে।
61: goto 0