The Ultimate Guide to Writing Functions
1.视频 https://www.youtube.com/watch?v=yatgY4NpZXE
2.代码 https://github.com/ArjanCodes/2022-funcguide
Python高质量函数编写指南
1. 一次做好一件事
from dataclasses import dataclass
from datetime import datetime
@dataclass
class Customer:
    """A customer record together with their credit-card details."""

    name: str
    phone: str
    cc_number: str  # card number as a string; validate_card converts each character with int()
    cc_exp_month: int  # used as the month argument of datetime(...) in validate_card
    cc_exp_year: int  # used as the year argument of datetime(...) in validate_card
    cc_valid: bool = False  # overwritten by validate_card with the validation result
# validate_card does too many things
def validate_card(customer: Customer) -> bool:
    """Validate the customer's card and store the result on the customer.

    The card is valid when its number is non-empty, passes the Luhn check and
    its expiry month has not yet ended. The result is written to
    ``customer.cc_valid`` and also returned.

    Args:
        customer: The customer whose ``cc_number``, ``cc_exp_month`` and
            ``cc_exp_year`` are checked.

    Returns:
        The value stored in ``customer.cc_valid``.

    Raises:
        ValueError: If ``cc_number`` contains a non-digit character, or if
            the checksum passes and the expiry month/year cannot form a date.
    """

    def digits_of(number: str) -> list[int]:
        return [int(d) for d in number]

    digits = digits_of(customer.cc_number)
    odd_digits = digits[-1::-2]
    even_digits = digits[-2::-2]
    checksum = sum(odd_digits)
    for digit in even_digits:
        # Doubling can give a two-digit value; add its digits individually.
        checksum += sum(digits_of(str(digit * 2)))

    # A card stays usable through the end of its expiry month, so compare
    # against the start of the current month rather than the current instant.
    start_of_this_month = datetime.now().replace(
        day=1, hour=0, minute=0, second=0, microsecond=0
    )
    customer.cc_valid = (
        bool(digits)  # an empty number sums to 0 and must not count as valid
        and checksum % 10 == 0
        and datetime(customer.cc_exp_year, customer.cc_exp_month, 1)
        >= start_of_this_month
    )
    return customer.cc_valid
def main() -> None:
    """Create a sample customer, validate the card, and print the outcome."""
    alice_details = {
        "name": "Alice",
        "phone": "2341",
        "cc_number": "1249190007575069",
        "cc_exp_month": 1,
        "cc_exp_year": 2024,
    }
    alice = Customer(**alice_details)
    is_valid = validate_card(alice)
    print(f"Is Alice's card valid? {is_valid}")
    print(alice)


if __name__ == "__main__":
    main()
我们发现 `validate_card` 函数做了两件事:验证卡号的校验和(Luhn 算法)是否有效,以及验证卡片是否在有效期内。
我们把校验和的验证拆分成一个单独的函数 `luhn_checksum`,并在 `validate_card` 中调用它。
修改后:
from dataclasses import dataclass
from datetime import datetime
# Checksum function
def luhn_checksum(card_number: str) -> bool:
    """Return True when ``card_number`` passes the Luhn check.

    Counting from the rightmost digit, every second digit is doubled and the
    digits of each doubled value are added; the other digits are added as
    they are. The number passes when the total is a multiple of 10.

    Args:
        card_number: The card number as a string of decimal digits.

    Returns:
        True if the number is non-empty and its Luhn total is divisible by 10.

    Raises:
        ValueError: If ``card_number`` contains a non-digit character.
    """
    # An empty string would sum to 0 and wrongly be reported as valid.
    if not card_number:
        return False

    def digits_of(number: str) -> list[int]:
        return [int(d) for d in number]

    digits = digits_of(card_number)
    odd_digits = digits[-1::-2]
    even_digits = digits[-2::-2]
    checksum = sum(odd_digits)
    for digit in even_digits:
        # Doubling can give a two-digit value; add its digits individually.
        checksum += sum(digits_of(str(digit * 2)))
    return checksum % 10 == 0
@dataclass
class Customer:
    """A customer record, abbreviated in this excerpt."""

    name: str
    phone: str
    # Placeholder left by the guide. The validate_card that follows still reads
    # cc_number, cc_exp_month and cc_exp_year and writes cc_valid, so
    # presumably those fields are unchanged from the first listing.
    ...
def validate_card(customer: Customer) -> bool:
    """Validate the customer's card and record the result on the customer.

    The card is valid when its number passes ``luhn_checksum`` and its expiry
    month has not yet ended. The result is written to ``customer.cc_valid``
    and also returned.

    Args:
        customer: The customer whose card fields are checked.

    Returns:
        The value stored in ``customer.cc_valid``.

    Raises:
        ValueError: Propagated from ``luhn_checksum``, or from ``datetime(...)``
            when the expiry month/year cannot form a date.
    """
    # A card stays usable through the end of its expiry month, so compare
    # against the start of the current month rather than the current instant.
    start_of_this_month = datetime.now().replace(
        day=1, hour=0, minute=0, second=0, microsecond=0
    )
    customer.cc_valid = (
        luhn_checksum(customer.cc_number)
        and datetime(customer.cc_exp_year, customer.cc_exp_month, 1)
        >= start_of_this_month
    )
    return customer.cc_valid
2. 分离命令和查询(command and query)
`validate_card` 中同时进行了查询和赋值两个操作,这样不好。
我们将查询和赋值拆分成两个步骤:
`validate_card` 只返回卡是否有效,而赋值操作 `alice.cc_valid = validate_card(alice)` 移动到了主函数中。
from dataclasses import dataclass
from datetime import datetime
def luhn_checksum(card_number: str) -> bool:
    """Return True when ``card_number`` passes the Luhn check.

    Counting from the rightmost digit, every second digit is doubled and the
    digits of each doubled value are added; the other digits are added as
    they are. The number passes when the total is a multiple of 10.

    Args:
        card_number: The card number as a string of decimal digits.

    Returns:
        True if the number is non-empty and its Luhn total is divisible by 10.

    Raises:
        ValueError: If ``card_number`` contains a non-digit character.
    """
    # An empty string would sum to 0 and wrongly be reported as valid.
    if not card_number:
        return False

    def digits_of(number: str) -> list[int]:
        return [int(d) for d in number]

    digits = digits_of(card_number)
    odd_digits = digits[-1::-2]
    even_digits = digits[-2::-2]
    checksum = sum(odd_digits)
    for digit in even_digits:
        # Doubling can give a two-digit value; add its digits individually.
        checksum += sum(digits_of(str(digit * 2)))
    return checksum % 10 == 0
@dataclass
class Customer:
    """A customer record together with their credit-card details."""

    name: str
    phone: str
    cc_number: str  # card number as a string; passed to luhn_checksum by validate_card
    cc_exp_month: int  # used as the month argument of datetime(...) in validate_card
    cc_exp_year: int  # used as the year argument of datetime(...) in validate_card
    cc_valid: bool = False  # assigned in main() from validate_card's return value
# Query
def validate_card(customer: Customer) -> bool:
    """Report whether the customer's card is valid, without modifying the customer.

    The card is valid when its number passes ``luhn_checksum`` and its expiry
    month has not yet ended.

    Args:
        customer: The customer whose card fields are checked.

    Returns:
        True if the card is valid, False otherwise.

    Raises:
        ValueError: Propagated from ``luhn_checksum``, or from ``datetime(...)``
            when the expiry month/year cannot form a date.
    """
    # A card stays usable through the end of its expiry month, so compare
    # against the start of the current month rather than the current instant.
    start_of_this_month = datetime.now().replace(
        day=1, hour=0, minute=0, second=0, microsecond=0
    )
    return (
        luhn_checksum(customer.cc_number)
        and datetime(customer.cc_exp_year, customer.cc_exp_month, 1)
        >= start_of_this_month
    )
def main() -> None:
    """Create a sample customer, store the validation result on it, and print both."""
    alice = Customer("Alice", "2341", "1249190007575069", 1, 2024)
    # Command: the assignment happens here, outside the query function
    card_ok = validate_card(alice)
    alice.cc_valid = card_ok
    print(f"Is Alice's card valid? {alice.cc_valid}")
    print(alice)


if __name__ == "__main__":
    main()
3. 只请求你需要的
函数 `validate_card` 实际上只需要3个参数(而不需要整个 `Customer` 对象)。
因此只请求3个参数:`def validate_card(*, number: str, exp_month: int, exp_year: int) -> bool:`
from dataclasses import dataclass
from datetime import datetime
def luhn_checksum(card_number: str) -> bool:
    """Return True when ``card_number`` passes the Luhn check.

    Counting from the rightmost digit, every second digit is doubled and the
    digits of each doubled value are added; the other digits are added as
    they are. The number passes when the total is a multiple of 10.

    Args:
        card_number: The card number as a string of decimal digits.

    Returns:
        True if the number is non-empty and its Luhn total is divisible by 10.

    Raises:
        ValueError: If ``card_number`` contains a non-digit character.
    """
    # An empty string would sum to 0 and wrongly be reported as valid.
    if not card_number:
        return False

    def digits_of(number: str) -> list[int]:
        return [int(d) for d in number]

    digits = digits_of(card_number)
    odd_digits = digits[-1::-2]
    even_digits = digits[-2::-2]
    checksum = sum(odd_digits)
    for digit in even_digits:
        # Doubling can give a two-digit value; add its digits individually.
        checksum += sum(digits_of(str(digit * 2)))
    return checksum % 10 == 0
@dataclass
class Customer:
    """A customer record together with their credit-card details."""

    name: str
    phone: str
    cc_number: str  # passed to validate_card as ``number`` by main()
    cc_exp_month: int  # passed to validate_card as ``exp_month`` by main()
    cc_exp_year: int  # passed to validate_card as ``exp_year`` by main()
    cc_valid: bool = False  # assigned in main() from validate_card's return value
# Ask only for the parameters you need
def validate_card(*, number: str, exp_month: int, exp_year: int) -> bool:
    """Report whether a card is valid.

    The card is valid when ``number`` passes ``luhn_checksum`` and the expiry
    month has not yet ended. All arguments are keyword-only.

    Args:
        number: The card number as a string of digits.
        exp_month: Expiry month.
        exp_year: Expiry year.

    Returns:
        True if the card is valid, False otherwise.

    Raises:
        ValueError: Propagated from ``luhn_checksum``, or from ``datetime(...)``
            when the expiry month/year cannot form a date.
    """
    # A card stays usable through the end of its expiry month, so compare
    # against the start of the current month rather than the current instant.
    start_of_this_month = datetime.now().replace(
        day=1, hour=0, minute=0, second=0, microsecond=0
    )
    return (
        luhn_checksum(number)
        and datetime(exp_year, exp_month, 1) >= start_of_this_month
    )
def main() -> None:
    """Create a sample customer, validate the card from its three fields, and print the result."""
    alice = Customer(
        name="Alice",
        phone="2341",
        cc_number="1249190007575069",
        cc_exp_month=1,
        cc_exp_year=2024,
    )
    # Hand over only the three values validate_card asks for.
    card_fields = {
        "number": alice.cc_number,
        "exp_month": alice.cc_exp_month,
        "exp_year": alice.cc_exp_year,
    }
    alice.cc_valid = validate_card(**card_fields)
    print(f"Is Alice's card valid? {alice.cc_valid}")
    print(alice)


if __name__ == "__main__":
    main()
4. 保持最小参数量
参数量很多时,调用时传参会比较麻烦。另一方面,函数需要很多参数,则暗示该函数可能做了很多事情。
下面我们抽象出 `Card` 类,减少了 `Customer` 和 `validate_card` 的参数量。
from dataclasses import dataclass
from datetime import datetime
from typing import Protocol
def luhn_checksum(card_number: str) -> bool:
    """Return True when ``card_number`` passes the Luhn check.

    Counting from the rightmost digit, every second digit is doubled and the
    digits of each doubled value are added; the other digits are added as
    they are. The number passes when the total is a multiple of 10.

    Args:
        card_number: The card number as a string of decimal digits.

    Returns:
        True if the number is non-empty and its Luhn total is divisible by 10.

    Raises:
        ValueError: If ``card_number`` contains a non-digit character.
    """
    # An empty string would sum to 0 and wrongly be reported as valid.
    if not card_number:
        return False

    def digits_of(number: str) -> list[int]:
        return [int(d) for d in number]

    digits = digits_of(card_number)
    odd_digits = digits[-1::-2]
    even_digits = digits[-2::-2]
    checksum = sum(odd_digits)
    for digit in even_digits:
        # Doubling can give a two-digit value; add its digits individually.
        checksum += sum(digits_of(str(digit * 2)))
    return checksum % 10 == 0
@dataclass
class Card:
    """Credit-card details, split out of Customer."""

    number: str  # card number as a string; passed to luhn_checksum by validate_card
    exp_month: int  # used as the month argument of datetime(...) in validate_card
    exp_year: int  # used as the year argument of datetime(...) in validate_card
    valid: bool = False  # assigned in main() from validate_card's return value
@dataclass
class Customer:
    """A customer record that holds its card as a single Card object."""

    name: str
    phone: str
    card: Card
    # NOTE(review): main() in this listing records the result on card.valid and
    # never writes this field, so it looks redundant. Confirm before relying on it.
    card_valid: bool = False
class CardInfo(Protocol):
    """Structural type listing the three attributes validate_card reads from a card."""

    @property
    def number(self) -> str:
        """Card number as a string."""
        ...

    @property
    def exp_month(self) -> int:
        """Expiry month."""
        ...

    @property
    def exp_year(self) -> int:
        """Expiry year."""
        ...
def validate_card(card: CardInfo) -> bool:
    """Report whether ``card`` is valid.

    The card is valid when its number passes ``luhn_checksum`` and its expiry
    month has not yet ended.

    Args:
        card: Any object exposing ``number``, ``exp_month`` and ``exp_year``.

    Returns:
        True if the card is valid, False otherwise.

    Raises:
        ValueError: Propagated from ``luhn_checksum``, or from ``datetime(...)``
            when the expiry month/year cannot form a date.
    """
    # A card stays usable through the end of its expiry month, so compare
    # against the start of the current month rather than the current instant.
    start_of_this_month = datetime.now().replace(
        day=1, hour=0, minute=0, second=0, microsecond=0
    )
    return (
        luhn_checksum(card.number)
        and datetime(card.exp_year, card.exp_month, 1) >= start_of_this_month
    )
def main() -> None:
    """Create a card and its owner, validate the card, and print the result."""
    card = Card("1249190007575069", 1, 2024)
    # The customer now receives one card object instead of three card fields.
    alice = Customer("Alice", "2341", card)
    # validate_card takes the card itself.
    card_ok = validate_card(card)
    card.valid = card_ok
    print(f"Is Alice's card valid? {card.valid}")
    print(alice)


if __name__ == "__main__":
    main()
5. 不要在同一个地方创建并使用对象
不要在函数内创建对象并使用,更好的方式是在外面创建对象并作为参数传递给函数。
import logging
class StripePaymentHandler:
    """Payment handler that logs a Stripe charge."""

    def handle_payment(self, amount: int) -> None:
        """Log the charge, rendering ``amount`` as ``amount/100`` with two decimals."""
        charge_message = f"Charging ${amount/100:.2f} using Stripe"
        logging.info(charge_message)
# Menu prices as integers; order_food formats the summed total as total/100 with two decimals.
PRICES = {
    "burger": 10_00,
    "fries": 5_00,
    "drink": 2_00,
    "salad": 15_00,
}
# !! The payment handler is created and used in the same function
def order_food(items: list[str]) -> None:
    """Total the ordered items, charge the total through Stripe, and log each step.

    Args:
        items: Item names; each must be a key of ``PRICES``.

    Raises:
        KeyError: If an item is not in ``PRICES``.
    """
    total = 0
    for item in items:
        total += PRICES[item]
    logging.info(f"Order total is ${total/100:.2f}.")
    payment_handler = StripePaymentHandler()  # the object is created here ...
    payment_handler.handle_payment(total)  # ... and used right away
    logging.info("Order completed.")
def main() -> None:
    """Enable INFO logging and place a sample order."""
    logging.basicConfig(level=logging.INFO)
    ordered_items = ["burger", "fries", "drink"]
    order_food(ordered_items)


if __name__ == "__main__":
    main()
修改后:
import logging
from typing import Protocol
class StripePaymentHandler:
    """Payment handler that logs a Stripe charge."""

    def handle_payment(self, amount: int) -> None:
        """Log the charge, rendering ``amount`` as ``amount/100`` with two decimals."""
        charge_message = f"Charging ${amount/100:.2f} using Stripe"
        logging.info(charge_message)
# Menu prices as integers; order_food formats the summed total as total/100 with two decimals.
PRICES = {
    "burger": 10_00,
    "fries": 5_00,
    "drink": 2_00,
    "salad": 15_00,
}
class PaymentHandler(Protocol):
    """Structural interface for the handler object that order_food charges through."""

    def handle_payment(self, amount: int) -> None:
        """Handle a payment of ``amount``."""
        ...
# !! The handler object is now passed in as an argument
def order_food(items: list[str], payment_handler: PaymentHandler) -> None:
    """Total the ordered items, charge the total via ``payment_handler``, and log each step.

    Args:
        items: Item names; each must be a key of ``PRICES``.
        payment_handler: Object whose ``handle_payment`` receives the total.

    Raises:
        KeyError: If an item is not in ``PRICES``.
    """
    total = 0
    for item in items:
        total += PRICES[item]
    logging.info(f"Order total is ${total/100:.2f}.")
    payment_handler.handle_payment(total)
    logging.info("Order completed.")
def main() -> None:
    """Enable INFO logging and place a sample order paid through Stripe."""
    logging.basicConfig(level=logging.INFO)
    # The handler is created by the caller and handed to order_food.
    stripe_handler = StripePaymentHandler()
    order_food(["burger", "salad", "drink"], stripe_handler)


if __name__ == "__main__":
    main()
6. 不要用flag参数
flag参数意味着函数处理两种情况,函数会变得复杂。建议将两种情况拆分成单独的函数。
from dataclasses import dataclass
from enum import StrEnum, auto
# Number of vacation days deducted by one payout in Employee.take_a_holiday.
FIXED_VACATION_DAYS_PAYOUT = 5
class Role(StrEnum):
    """Job roles; with StrEnum, auto() gives each member its lower-cased name as its value."""

    PRESIDENT = auto()
    VICEPRESIDENT = auto()
    MANAGER = auto()
    LEAD = auto()
    ENGINEER = auto()
    INTERN = auto()
@dataclass
class Employee:
    """An employee with a balance of vacation days."""

    name: str
    role: Role
    vacation_days: int = 25

    def take_a_holiday(self, payout: bool, nr_days: int = 1) -> None:
        """Pay out a fixed block of vacation days, or take ``nr_days`` off.

        The ``payout`` flag makes this one method handle two separate cases.

        Args:
            payout: When True, deduct ``FIXED_VACATION_DAYS_PAYOUT`` days and
                ignore ``nr_days``. When False, deduct ``nr_days`` days.
            nr_days: Number of days to take off when ``payout`` is False.

        Raises:
            ValueError: If the balance is too low for the requested action, or
                if ``nr_days`` is negative.
        """
        if payout:
            if self.vacation_days < FIXED_VACATION_DAYS_PAYOUT:
                # Adjacent literals keep the two sentences separated by exactly
                # one space; a backslash continuation inside the string would
                # fuse them or embed the source indentation.
                raise ValueError(
                    "You don't have enough holidays left over for a payout. "
                    f"Remaining holidays: {self.vacation_days}."
                )
            self.vacation_days -= FIXED_VACATION_DAYS_PAYOUT
            print(f"Paying out a holiday. Holidays left: {self.vacation_days}")
        else:
            # A negative request would silently increase the balance.
            if nr_days < 0:
                raise ValueError("nr_days must not be negative.")
            if self.vacation_days < nr_days:
                raise ValueError(
                    "You don't have any holidays left. Now back to work, you!"
                )
            self.vacation_days -= nr_days
            print("Have fun on your holiday. Don't forget to check your emails!")
def main() -> None:
    """Create a sample employee and request a holiday payout."""
    employee = Employee("John Doe", Role.ENGINEER)
    employee.take_a_holiday(payout=True)


if __name__ == "__main__":
    main()
修改后:
from dataclasses import dataclass
from enum import StrEnum, auto
# Number of vacation days deducted by one call to Employee.payout_holiday.
FIXED_VACATION_DAYS_PAYOUT = 5
class Role(StrEnum):
    """Job roles; with StrEnum, auto() gives each member its lower-cased name as its value."""

    PRESIDENT = auto()
    VICEPRESIDENT = auto()
    MANAGER = auto()
    LEAD = auto()
    ENGINEER = auto()
    INTERN = auto()
@dataclass
class Employee:
    """An employee with a balance of vacation days."""

    name: str
    role: Role
    vacation_days: int = 25

    def payout_holiday(self) -> None:
        """Pay out ``FIXED_VACATION_DAYS_PAYOUT`` vacation days.

        Raises:
            ValueError: If fewer than ``FIXED_VACATION_DAYS_PAYOUT`` days remain.
        """
        if self.vacation_days < FIXED_VACATION_DAYS_PAYOUT:
            # Adjacent literals keep the two sentences separated by exactly
            # one space; a backslash continuation inside the string would
            # fuse them or embed the source indentation.
            raise ValueError(
                "You don't have enough holidays left over for a payout. "
                f"Remaining holidays: {self.vacation_days}."
            )
        self.vacation_days -= FIXED_VACATION_DAYS_PAYOUT
        print(f"Paying out a holiday. Holidays left: {self.vacation_days}")

    def take_holiday(self, nr_days: int = 1) -> None:
        """Take ``nr_days`` vacation days off.

        Args:
            nr_days: Number of days to deduct from the balance.

        Raises:
            ValueError: If ``nr_days`` is negative or exceeds the balance.
        """
        # A negative request would silently increase the balance.
        if nr_days < 0:
            raise ValueError("nr_days must not be negative.")
        if self.vacation_days < nr_days:
            raise ValueError("You don't have any holidays left. Now back to work, you!")
        self.vacation_days -= nr_days
        print("Have fun on your holiday. Don't forget to check your emails!")
def main() -> None:
    """Create a sample employee and request a holiday payout."""
    employee = Employee("John Doe", Role.ENGINEER)
    employee.payout_holiday()


if __name__ == "__main__":
    main()
7. 函数也是对象
函数也是对象,因此可以作为参数传递,作为函数返回值。
import logging
from functools import partial
from typing import Callable
def handle_payment_stripe(amount: int) -> None:
    """Log a Stripe charge, rendering ``amount`` as ``amount/100`` with two decimals."""
    charge_message = f"Charging ${amount/100:.2f} using Stripe"
    logging.info(charge_message)
# Menu prices as integers; order_food formats the summed total as total/100 with two decimals.
PRICES = {
    "burger": 10_00,
    "fries": 5_00,
    "drink": 2_00,
    "salad": 15_00,
}
# Type alias: a payment handler is any callable that takes an int amount and returns None.
HandlePaymentFn = Callable[[int], None]
# A function is passed in as an argument
def order_food(items: list[str], payment_handler: HandlePaymentFn) -> None:
    """Total the ordered items, charge the total by calling ``payment_handler``, and log each step.

    Args:
        items: Item names; each must be a key of ``PRICES``.
        payment_handler: Callable invoked with the order total.

    Raises:
        KeyError: If an item is not in ``PRICES``.
    """
    item_prices = [PRICES[item] for item in items]
    total = sum(item_prices)
    logging.info(f"Order total is ${total/100:.2f}.")
    payment_handler(total)
    logging.info("Order completed.")
# order_food with payment_handler pre-bound to handle_payment_stripe, so callers pass only the items.
order_food_stripe = partial(order_food, payment_handler=handle_payment_stripe)
def main() -> None:
    """Enable INFO logging and place a sample order through the pre-bound Stripe variant."""
    logging.basicConfig(level=logging.INFO)
    # Equivalent direct call: order_food(["burger", "salad", "drink"], handle_payment_stripe)
    ordered_items = ["burger", "salad", "drink"]
    order_food_stripe(ordered_items)


if __name__ == "__main__":
    main()