blob: 0cdaa7a0a2d84757bf35f8bdb611e2d3a9163263 [file] [log] [blame]
// Copyright (C) 2024 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! # safemath library
//!
//! This library provides an API to safely work with unsigned integers. At a high level, all math
//! operations are checked by default rather than having to remember to call specific `checked_*`
//! functions, so that the burden is on the programmer if they want to perform unchecked math
//! rather than the other way around:
//!
//! ```
//! use safemath::SafeNum;
//!
//! let safe = SafeNum::from(0);
//! let result = safe - 1;
//! assert!(u32::try_from(result).is_err());
//!
//! let safe_chain = (SafeNum::from(BIG_NUMBER) * HUGE_NUMBER) / MAYBE_ZERO;
//! // If any operation would have caused an overflow or division by zero,
//! // the number is flagged and the lexical location is specified for logging.
//! if safe_chain.has_error() {
//! eprintln!("safe_chain error = {:#?}", safe_chain);
//! }
//! ```
//!
//! In addition to checked-by-default arithmetic, the API exposed here supports
//! more natural usage than the `checked_*` functions by allowing chaining
//! of operations without having to check the result at each step.
//! This is similar to how floating-point `NaN` works - you can continue to use the
//! value, but continued operations will just propagate `NaN`.
//!
//! ## Supported Operations
//!
//! ### Arithmetic
//! The basic arithmetic operations are supported:
//! addition, subtraction, multiplication, division, and remainder.
//! The right-hand side may be another SafeNum or any integer,
//! and the result is always another SafeNum.
//! If the operation would result in an overflow or division by zero,
//! or if converting the right-hand operand to a `u64` would cause an error,
//! the result is an error-tagged SafeNum that tracks the lexical origin of the error.
//!
//! ### Conversion from and to SafeNum
//! SafeNums support conversion to and from all integer types.
//! Conversion to SafeNum from signed integers and from usize and u128
//! can fail, generating an error value that is then propagated.
//! Conversion from SafeNum to all integers is only exposed via `try_from`
//! in order to force the user to handle potential resultant errors.
//!
//! E.g.
//! ```
//! fn call_func(_: u32, _: u32) {
//! }
//!
//! fn do_a_thing(a: SafeNum) -> Result<(), safemath::Error> {
//! call_func(16, a.try_into()?);
//! Ok(())
//! }
//! ```
//!
//! ### Comparison
//! SafeNums can be checked for equality against each other.
//! Valid numbers are equal to other numbers of the same magnitude.
//! Errored SafeNums are only equal to themselves.
//! Note that because errors propagate from their first introduction in an
//! arithmetic chain, this can lead to surprising results.
//!
//! E.g.
//! ```
//! let overflow = SafeNum::MAX + 1;
//! let otherflow = SafeNum::MAX + 1;
//!
//! assert_ne!(overflow, otherflow);
//! assert_eq!(overflow + otherflow, overflow);
//! assert_eq!(otherflow + overflow, otherflow);
//! ```
//!
//! Inequality comparison operators are deliberately not provided.
//! By necessity they would have similar caveats to floating point comparisons,
//! which are easy to use incorrectly and unintuitive to use correctly.
//!
//! The required alternative is to convert to a real integer type before comparing,
//! forcing any errors upwards.
//!
//! E.g.
//! ```
//! impl From<safemath::Error> for &'static str {
//! fn from(_: safemath::Error) -> Self {
//! "checked arithmetic error"
//! }
//! }
//!
//! fn my_op(a: SafeNum, b: SafeNum, c: SafeNum, d: SafeNum) -> Result<bool, &'static str> {
//! Ok(safemath::Primitive::try_from(a)? < b.try_into()?
//! && safemath::Primitive::try_from(c)? >= d.try_into()?)
//! }
//! ```
//!
//! ### Miscellaneous
//! SafeNums also provide helper methods to round up or down
//! to the nearest multiple of another number
//! and helper predicate methods that indicate whether the SafeNum
//! is valid or is tracking an error.
//!
//! Also provided are constants `SafeNum::MAX`, `SafeNum::MIN`, and `SafeNum::ZERO`.
//!
//! Warning: SafeNums can help prevent, isolate, and detect arithmetic overflow
//! but they are not a panacea. In particular, chains of different operations
//! are not guaranteed to be associative or commutative.
//!
//! E.g.
//! ```
//! let a = SafeNum::MAX - 1 + 1;
//! let b = SafeNum::MAX + 1 - 1;
//! assert_ne!(a, b);
//! assert!(a.is_valid());
//! assert!(b.has_error());
//!
//! let c = (SafeNum::MAX + 31) / 31;
//! let d = SafeNum::MAX / 31 + 31 / 31;
//! assert_ne!(c, d);
//! assert!(c.has_error());
//! assert!(d.is_valid());
//! ```
//!
//! Note: SafeNum arithmetic is much slower than arithmetic on integer primitives.
//! If you are concerned about performance, be sure to run benchmarks.
#![cfg_attr(not(test), no_std)]
use core::convert::TryFrom;
use core::fmt;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Sub, SubAssign};
use core::panic::Location;
/// The unsigned integer type wrapped by every [`SafeNum`].
pub type Primitive = u64;
/// Error payload of a [`SafeNum`]: the source location at which the error was recorded.
pub type Error = &'static Location<'static>;
/// A [`Primitive`] value that remembers the first checked-arithmetic or conversion failure.
///
/// Holds `Ok(value)` while valid and `Err(location)` once an operation has failed.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct SafeNum(Result<Primitive, Error>);
impl fmt::Debug for SafeNum {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.0 {
Ok(val) => write!(f, "{}", val),
Err(location) => write!(f, "error at {}", location),
}
}
}
impl SafeNum {
    /// The largest representable value, `u64::MAX`.
    pub const MAX: SafeNum = SafeNum(Ok(u64::MAX));
    /// The smallest representable value, `u64::MIN` (zero).
    pub const MIN: SafeNum = SafeNum(Ok(u64::MIN));
    /// Zero.
    pub const ZERO: SafeNum = SafeNum(Ok(0));

    /// Round `self` down to the nearest multiple of `rhs`.
    ///
    /// The result carries an error if `self` or `rhs` already carries one, or if
    /// `rhs` is zero.
    #[track_caller]
    pub fn round_down<T>(self, rhs: T) -> Self
    where
        Self: Rem<T, Output = Self>,
    {
        // `self % rhs <= self`, so the subtraction itself can never underflow.
        self - (self % rhs)
    }

    /// Round `self` up to the nearest multiple of `rhs`.
    ///
    /// The result carries an error if `self` or `rhs` already carries one, if
    /// `rhs` is zero, or if the rounded value does not fit in a [`Primitive`].
    /// A value that is already a multiple of `rhs` is returned unchanged, even
    /// when it is close to [`SafeNum::MAX`].
    #[track_caller]
    pub fn round_up<T>(self, rhs: T) -> Self
    where
        Self: Add<T, Output = Self>,
        T: Copy + Into<Self>,
    {
        let rhs: Self = rhs.into();
        // Distance to the next multiple of `rhs`, or zero when already aligned.
        // Because `self % rhs < rhs`, neither step here can overflow. Only the
        // final addition can fail, and only when the true result exceeds
        // `Primitive::MAX`. (Computing `self + rhs - 1` first would overflow
        // spuriously, e.g. for `MAX.round_up(1)`.)
        let padding = (rhs - self % rhs) % rhs;
        self + padding
    }

    /// Returns whether self is the result of an operation that has errored.
    pub const fn has_error(&self) -> bool {
        self.0.is_err()
    }

    /// Returns whether self represents a valid, non-overflowed integer.
    pub const fn is_valid(&self) -> bool {
        self.0.is_ok()
    }
}
/// Implements `TryFrom<SafeNum>` for the given integer type.
///
/// An error already carried by the `SafeNum` is returned unchanged. A value that
/// does not fit in the target type yields an error recording the location of the
/// conversion call.
macro_rules! try_conversion_func {
    ($other_type:tt) => {
        impl TryFrom<SafeNum> for $other_type {
            type Error = Error;

            #[track_caller]
            fn try_from(val: SafeNum) -> Result<Self, Self::Error> {
                // Capture the location here, in the `#[track_caller]` body.
                // Calling `Location::caller()` inside the `map_err` closure would
                // report a position in this file instead of the caller's.
                let caller = Location::caller();
                Self::try_from(val.0?).map_err(|_| caller)
            }
        }
    };
}
/// Implements an infallible `From<$from_type> for SafeNum` plus the matching
/// `TryFrom<SafeNum>` impl, for source types that always fit in a `Primitive`.
macro_rules! conversion_func {
    ($from_type:tt) => {
        impl From<$from_type> for SafeNum {
            fn from(value: $from_type) -> SafeNum {
                // `Primitive: From<$from_type>` guarantees this cannot fail.
                let widened = Primitive::from(value);
                SafeNum(Ok(widened))
            }
        }

        try_conversion_func!($from_type);
    };
}
/// Implements a fallible `From<$from_type> for SafeNum` plus the matching
/// `TryFrom<SafeNum>` impl.
///
/// Values that do not fit in a `Primitive` produce an error-tagged `SafeNum`
/// recording the location of the conversion call.
macro_rules! conversion_func_maybe_error {
    ($from_type:tt) => {
        impl From<$from_type> for SafeNum {
            #[track_caller]
            fn from(val: $from_type) -> Self {
                // Capture the location here, in the `#[track_caller]` body.
                // Calling `Location::caller()` inside the `map_err` closure would
                // report a position in this file instead of the caller's.
                let caller = Location::caller();
                Self(Primitive::try_from(val).map_err(|_| caller))
            }
        }

        try_conversion_func!($from_type);
    };
}
/// Implements a checked binary operator and its compound-assignment form for
/// `SafeNum`.
///
/// `$func` is the `checked_*` method on `Primitive` that performs the operation.
/// An error already carried by the left operand takes precedence over one
/// carried by the right operand. A fresh failure (overflow, underflow, division
/// by zero) records the source location of the operator expression.
macro_rules! arithmetic_impl {
    ($trait_name:ident, $op:ident, $assign_trait_name:ident, $assign_op:ident, $func:ident) => {
        impl<T: Into<SafeNum>> $trait_name<T> for SafeNum {
            type Output = Self;

            #[track_caller]
            fn $op(self, rhs: T) -> Self {
                // `Location::caller()` must be evaluated directly in this
                // `#[track_caller]` body. Passing it as a callback (for example
                // `ok_or_else(Location::caller)`) invokes it through an untracked
                // `FnOnce` shim, so the caller's location would be lost.
                let caller = Location::caller();
                let rhs: Self = rhs.into();
                match (self.0, rhs.0) {
                    (Err(_), _) => self,
                    (_, Err(_)) => rhs,
                    (Ok(lhs), Ok(rhs)) => Self(lhs.$func(rhs).ok_or(caller)),
                }
            }
        }

        impl<T> $assign_trait_name<T> for SafeNum
        where
            Self: $trait_name<T, Output = Self>,
        {
            // Tracked so that errors point at the `op=` expression in caller code.
            #[track_caller]
            fn $assign_op(&mut self, rhs: T) {
                *self = self.$op(rhs)
            }
        }
    };
}
// Infallible conversions: these source types always fit in a `Primitive`.
conversion_func!(u8);
conversion_func!(u16);
conversion_func!(u32);
conversion_func!(u64);

// Fallible conversions routed through `Primitive::try_from`; out-of-range
// source values yield an error-tagged `SafeNum`.
conversion_func_maybe_error!(u128);
conversion_func_maybe_error!(usize);
conversion_func_maybe_error!(isize);
conversion_func_maybe_error!(i8);
conversion_func_maybe_error!(i16);
conversion_func_maybe_error!(i32);
conversion_func_maybe_error!(i64);
conversion_func_maybe_error!(i128);

// Checked arithmetic operators together with their compound-assignment forms.
arithmetic_impl!(Add, add, AddAssign, add_assign, checked_add);
arithmetic_impl!(Sub, sub, SubAssign, sub_assign, checked_sub);
arithmetic_impl!(Mul, mul, MulAssign, mul_assign, checked_mul);
arithmetic_impl!(Div, div, DivAssign, div_assign, checked_div);
arithmetic_impl!(Rem, rem, RemAssign, rem_assign, checked_rem);
// Unit tests for arithmetic, error propagation, conversions and rounding helpers.
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_addition() {
        let a: SafeNum = 2100.into();
        let b: SafeNum = 12.into();
        assert_eq!(a + b, 2112.into());
    }

    #[test]
    fn test_subtraction() {
        let a: SafeNum = 667.into();
        let b: SafeNum = 1.into();
        assert_eq!(a - b, 666.into());
    }

    #[test]
    fn test_multiplication() {
        let a: SafeNum = 17.into();
        let b: SafeNum = 3.into();
        assert_eq!(a * b, 51.into());
    }

    #[test]
    fn test_division() {
        let a: SafeNum = 1066.into();
        let b: SafeNum = 41.into();
        assert_eq!(a / b, 26.into());
    }

    #[test]
    fn test_remainder() {
        let a: SafeNum = 613.into();
        let b: SafeNum = 10.into();
        assert_eq!(a % b, 3.into());
    }

    #[test]
    fn test_addition_poison() {
        let base: SafeNum = 2.into();
        let poison = base + SafeNum::MAX;
        assert!(u64::try_from(poison).is_err());
        let a = poison - 1;
        let b = poison - 2;
        assert_eq!(a, poison);
        assert_eq!(b, poison);
    }

    #[test]
    fn test_subtraction_poison() {
        let base: SafeNum = 2.into();
        let poison = base - SafeNum::MAX;
        assert!(u64::try_from(poison).is_err());
        let a = poison + 1;
        let b = poison + 2;
        assert_eq!(a, poison);
        assert_eq!(b, poison);
    }

    #[test]
    fn test_multiplication_poison() {
        let base: SafeNum = 2.into();
        let poison = base * SafeNum::MAX;
        assert!(u64::try_from(poison).is_err());
        let a = poison / 2;
        let b = poison / 4;
        assert_eq!(a, poison);
        assert_eq!(b, poison);
    }

    #[test]
    fn test_division_poison() {
        let base: SafeNum = 2.into();
        let poison = base / 0;
        assert!(u64::try_from(poison).is_err());
        let a = poison * 2;
        let b = poison * 4;
        assert_eq!(a, poison);
        assert_eq!(b, poison);
    }

    #[test]
    fn test_remainder_poison() {
        let base: SafeNum = 2.into();
        let poison = base % 0;
        assert!(u64::try_from(poison).is_err());
        let a = poison * 2;
        let b = poison * 4;
        assert_eq!(a, poison);
        assert_eq!(b, poison);
    }

    // Round-trip and arithmetic checks instantiated for every supported integer type.
    macro_rules! conversion_test {
        ($name:ident) => {
            mod $name {
                use super::*;
                use core::convert::TryInto;

                #[test]
                fn test_between_safenum() {
                    let var: $name = 16;
                    let sn: SafeNum = var.into();
                    let res: $name = sn.try_into().unwrap();
                    assert_eq!(var, res);
                }

                #[test]
                fn test_arithmetic_safenum() {
                    let primitive: $name = ((((0 + 11) * 11) / 3) % 32) - 3;
                    let safe = ((((SafeNum::ZERO + $name::try_from(11u8).unwrap())
                        * $name::try_from(11u8).unwrap())
                        / $name::try_from(3u8).unwrap())
                        % $name::try_from(32u8).unwrap())
                        - $name::try_from(3u8).unwrap();
                    assert_eq!($name::try_from(safe).unwrap(), primitive);
                }
            }
        };
    }

    conversion_test!(u8);
    conversion_test!(u16);
    conversion_test!(u32);
    conversion_test!(u64);
    conversion_test!(u128);
    conversion_test!(usize);
    conversion_test!(i8);
    conversion_test!(i16);
    conversion_test!(i32);
    conversion_test!(i64);
    conversion_test!(i128);
    conversion_test!(isize);

    // Each operator must match its primitive counterpart, and its `op=` form must
    // agree with the plain operator, including on poisoned values.
    macro_rules! correctness_tests {
        ($name:ident, $operation:ident, $assign_operation:ident) => {
            mod $operation {
                use super::*;
                use core::ops::$name;

                #[test]
                fn test_correctness() {
                    let normal = 300u64;
                    let safe: SafeNum = normal.into();
                    let rhs = 7u64;
                    assert_eq!(
                        u64::try_from(safe.$operation(rhs)).unwrap(),
                        normal.$operation(rhs)
                    );
                }

                #[test]
                fn test_assign() {
                    let mut var: SafeNum = 2112.into();
                    let rhs = 666u64;
                    let expect = var.$operation(rhs);
                    var.$assign_operation(rhs);
                    assert_eq!(var, expect);
                }

                #[test]
                fn test_assign_poison() {
                    let mut var = SafeNum::MIN - 1;
                    let expected = var - 1;
                    var.$assign_operation(2);
                    // Poison saturates and doesn't perform additional changes
                    assert_eq!(var, expected);
                }
            }
        };
    }

    correctness_tests!(Add, add, add_assign);
    correctness_tests!(Sub, sub, sub_assign);
    correctness_tests!(Mul, mul, mul_assign);
    correctness_tests!(Div, div, div_assign);
    correctness_tests!(Rem, rem, rem_assign);

    #[test]
    fn test_round_down() {
        let x: SafeNum = 255.into();
        assert_eq!(x.round_down(32), 224.into());
        assert_eq!((x + 1).round_down(64), 256.into());
        assert_eq!(x.round_down(256), SafeNum::ZERO);
        assert!(x.round_down(SafeNum::MIN).has_error());
    }

    #[test]
    fn test_round_up() {
        let x: SafeNum = 255.into();
        assert_eq!(x.round_up(32), 256.into());
        assert_eq!(x.round_up(51), x);
        assert_eq!(SafeNum::ZERO.round_up(x), SafeNum::ZERO);
        assert!(SafeNum::MAX.round_up(32).has_error());
    }

    #[test]
    fn test_fallible_conversion_into_safenum() {
        // Values outside `Primitive`'s range become errors instead of wrapping.
        assert!(SafeNum::from(-1i32).has_error());
        assert!(SafeNum::from(i64::MIN).has_error());
        assert!(SafeNum::from(u128::MAX).has_error());
        // In-range values of fallible source types convert cleanly.
        assert_eq!(SafeNum::from(i8::MAX), 127u8.into());
        assert_eq!(SafeNum::from(u128::from(u64::MAX)), SafeNum::MAX);
    }

    #[test]
    fn test_narrowing_conversion_from_safenum() {
        let big: SafeNum = 256u32.into();
        assert!(u8::try_from(big).is_err());
        assert_eq!(u16::try_from(big), Ok(256));
        assert!(i64::try_from(SafeNum::MAX).is_err());
        assert_eq!(u128::try_from(SafeNum::MAX), Ok(u128::from(u64::MAX)));
    }

    #[test]
    fn test_conversion_propagates_existing_error() {
        let poison = SafeNum::MIN - 1;
        let origin = u64::try_from(poison).unwrap_err();
        // Converting an errored value reports where the error first arose.
        assert_eq!(u8::try_from(poison), Err(origin));
        assert_eq!(poison, SafeNum(Err(origin)));
    }

    #[test]
    fn test_round_poison_propagates() {
        let poison = SafeNum::MIN - 1;
        assert_eq!(poison.round_down(8), poison);
        assert_eq!(poison.round_up(8), poison);
        assert!(SafeNum::from(8u8).round_up(0).has_error());
        assert_eq!(SafeNum::ZERO.round_down(7), SafeNum::ZERO);
    }

    #[test]
    fn test_debug_format() {
        assert_eq!(format!("{:?}", SafeNum::from(42u8)), "42");
        let poison = SafeNum::MIN - 1;
        assert!(format!("{:?}", poison).starts_with("error at "));
    }
}