Python单元测试中的assertAlmostEqual用于浮点数的集合


81

Python的单元测试框架中assertAlmostEqual(x,y)方法在假设和为浮点数的情况下测试和近似相等。xy

问题assertAlmostEqual()在于它仅适用于浮点数。我正在寻找一种类似的方法assertAlmostEqual(),可用于浮点数列表,浮点数集,浮点数字典,浮点数元组,浮点数元组列表,浮点数列表集等。

例如,我们x = 0.1234567890y = 0.1234567891xy几乎相等,因为他们对每一个除了最后一个数字一致。因此self.assertAlmostEqual(x, y)True因为assertAlmostEqual()适用于彩车。

我正在寻找一种更通用的方法assertAlmostEquals(),该方法还可以评估对的以下调用True

  • self.assertAlmostEqual_generic([x, x, x], [y, y, y])
  • self.assertAlmostEqual_generic({1: x, 2: x, 3: x}, {1: y, 2: y, 3: y})
  • self.assertAlmostEqual_generic([(x,x)], [(y,y)])

有这种方法还是我必须自己实现?

说明:

  • assertAlmostEquals()有一个名为的可选参数,places并且通过计算四舍五入到十进制数的差来比较数字places。默认情况下places=7,因此self.assertAlmostEqual(0.5, 0.4)为False,而self.assertAlmostEqual(0.12345678, 0.12345679)为True。我的投机assertAlmostEqual_generic()应该具有相同的功能。

  • 如果两个列表具有完全相同的顺序几乎相等的数字,则认为它们几乎相等。正式地,for i in range(n): self.assertAlmostEqual(list1[i], list2[i])

  • 同样,如果两个集合可以转换为几乎相等的列表(通过为每个集合分配顺序),则认为它们几乎相等。

  • 类似地,如果每个字典的键集几乎等于另一个字典的键集,则两个字典被认为几乎相等,并且对于每个这样的几乎相等的键对,都有一个相应的几乎相等的值。

  • 总的来说:我认为两个集合如果相等就几乎相等,除了一些对应的float彼此几乎相等。换句话说,我真的想比较对象,但是在沿途比较浮点时精度(自定义)较低。


float在字典中使用键有什么意义?由于不能确保获得完全相同的浮点数,因此永远都不会使用查找来找到项目。而且,如果您不使用查找,为什么不使用元组列表而不是字典呢?相同的论点适用于集合。
最大

只是链接到源assertAlmostEqual
djvg

Answers:


71

如果您不介意使用NumPy(Python(x,y)随附),则可能需要查看np.testing定义以下内容的模块:assert_almost_equal函数。

签名是 np.testing.assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True)

>>> x = 1.000001
>>> y = 1.000002
>>> np.testing.assert_almost_equal(x, y)
AssertionError: 
Arrays are not almost equal to 7 decimals
ACTUAL: 1.000001
DESIRED: 1.000002
>>> np.testing.assert_almost_equal(x, y, 5)
>>> np.testing.assert_almost_equal([x, x, x], [y, y, y], 5)
>>> np.testing.assert_almost_equal((x, x, x), (y, y, y), 5)

4
这很接近,但是numpy.testing几乎相等的方法仅适用于数字,数组,元组和列表。它们不适用于字典,集合集和集合。
snakile 2012年

确实,但这只是一个开始。此外,您可以访问源代码,可以对其进行修改以允许比较字典,集合等。np.testing.assert_equal确实会将字典识别为参数(即使比较是由a进行的==,但对您不起作用)。
皮埃尔·通用

当然,如@BrenBarn所述,在比较集合时您仍然会遇到麻烦。
Pierre GM

需要注意的是当前文档assert_array_almost_equal建议使用assert_allcloseassert_array_almost_equal_nulpassert_array_max_ulp代替。
phunehehe


9

这是我实现通用is_almost_equal(first, second)函数的方式

首先,复制需要比较的对象(firstsecond),但不要进行精确的复制:切掉对象内部遇到的任何浮点数的无关紧要的十进制数字。

现在,你有副本first,并second为这微不足道的小数位数都消失了,只是比较firstsecond使用的==运营商。

假设我们有一个cut_insignificant_digits_recursively(obj, places)函数可以复制,obj但只将places每个float的最高有效十进制数字保留在original中obj。这是的有效实现is_almost_equals(first, second, places)

from insignificant_digit_cutter import cut_insignificant_digits_recursively

def is_almost_equal(first, second, places):
    '''returns True if first and second equal. 
    returns true if first and second aren't equal but have exactly the same
    structure and values except for a bunch of floats which are just almost
    equal (floats are almost equal if they're equal when we consider only the
    [places] most significant digits of each).'''
    if first == second: return True
    cut_first = cut_insignificant_digits_recursively(first, places)
    cut_second = cut_insignificant_digits_recursively(second, places)
    return cut_first == cut_second

这是一个有效的实现cut_insignificant_digits_recursively(obj, places)

def cut_insignificant_digits(number, places):
    '''cut the least significant decimal digits of a number, 
    leave only [places] decimal digits'''
    if  type(number) != float: return number
    number_as_str = str(number)
    end_of_number = number_as_str.find('.')+places+1
    if end_of_number > len(number_as_str): return number
    return float(number_as_str[:end_of_number])

def cut_insignificant_digits_lazy(iterable, places):
    for obj in iterable:
        yield cut_insignificant_digits_recursively(obj, places)

def cut_insignificant_digits_recursively(obj, places):
    '''return a copy of obj except that every float loses its least significant 
    decimal digits remaining only [places] decimal digits'''
    t = type(obj)
    if t == float: return cut_insignificant_digits(obj, places)
    if t in (list, tuple, set):
        return t(cut_insignificant_digits_lazy(obj, places))
    if t == dict:
        return {cut_insignificant_digits_recursively(key, places):
                cut_insignificant_digits_recursively(val, places)
                for key,val in obj.items()}
    return obj

该代码及其单元测试可在此处获得:https : //github.com/snakile/approximate_comparator。我欢迎任何改进和错误修复。


不是比较浮点数,而是比较字符串?好的...但是,设置通用格式会更容易吗?喜欢fmt="{{0:{0}f}}".format(decimals),并使用这种fmt格式对您的浮动内容进行“字符串化”吗?
皮埃尔·通用

1
这看起来不错,但有一点要注意:places给出小数位数,而不是有效数字的位数。例如,比较1024.1231023.999与3个有效位应返回相等,但与3个小数位则不相等。
罗德尼·理查森

1
@pir,许可证确实未定义。请参阅snalile在此问题上的回答,他说他没有时间选择/添加许可证,但授予使用/修改权限。谢谢分享,顺便说一句。
杰罗姆

1
@RodneyRichardson,是的,这是小数位,例如assertAlmostEqual中的内容:“请注意,这些方法将值四舍五入到给定的小数位数(例如,round()函数),而不是有效数字。”
杰罗姆

2
@Jérôme,谢谢您的评论。我刚刚添加了MIT许可证。
snakile

5

如果您不介意使用该numpy软件包numpy.testing,请assert_array_almost_equal方法。

这适用于 array_like对象,因此浮点数的数组,列表和元组,但不适用于集合和字典。

文档在这里


4

没有这种方法,您必须自己做。

对于列表和元组,定义是显而易见的,但是请注意,您提到的其他情况并不明显,因此也就难怪没有提供这样的功能。例如,是{1.00001: 1.00002}几乎等于{1.00002: 1.00001}?处理此类情况需要选择是否接近取决于键或值还是两者都取决于。对于集合,您不太可能找到有意义的定义,因为集合是无序的,因此没有“对应”元素的概念。


布伦·巴恩(BrenBarn):我已经对问题进行了澄清。回答你的问题是,{1.00001: 1.00002}几乎等于{1.00002: 1.00001}当且仅当1.00001几乎等于1.00002。默认情况下,它们几乎不相等(因为默认精度为7位小数),但对于一个足够小的值,places它们几乎相等。
snakile 2012年

1
@BrenBarn:IMO,float出于明显的原因,不建议使用dict中的type键(甚至禁止使用)。dict的近似相等应仅基于值;测试框架不必担心floatfor键的错误使用。对于集合,可以在比较之前对它们进行排序,然后可以对排序后的列表进行比较。
最大

2

您可能必须自己实现它,虽然列表和集合可以用相同的方式迭代,字典是不同的故事,您迭代了它们的键而不是值,但第三个示例对我来说似乎有点模棱两可,您的意思是比较集合中的每个值或每个集合中的每个值。

这是一个简单的代码段。

def almost_equal(value_1, value_2, accuracy = 10**-8):
    return abs(value_1 - value_2) < accuracy

x = [1,2,3,4]
y = [1,2,4,5]
assert all(almost_equal(*values) for values in zip(x, y))

谢谢,该解决方案适用于列表和元组,但不适用于其他类型的集合(或嵌套集合)。请参阅我添加到问题中的说明。我希望我的意图是明确的。如果在数字不是很精确地测量的世界中,如果将两组视为相等,则两组几乎相等。
snakile 2012年

0

这些答案都不适合我。以下代码应适用于python集合,类,数据类和namedtuple。我可能已经忘记了一些东西,但是到目前为止,这对我有用。

import unittest
from collections import namedtuple, OrderedDict
from dataclasses import dataclass
from typing import Any


def are_almost_equal(o1: Any, o2: Any, max_abs_ratio_diff: float, max_abs_diff: float) -> bool:
    """
    Compares two objects by recursively walking them trough. Equality is as usual except for floats.
    Floats are compared according to the two measures defined below.

    :param o1: The first object.
    :param o2: The second object.
    :param max_abs_ratio_diff: The maximum allowed absolute value of the difference.
    `abs(1 - (o1 / o2)` and vice-versa if o2 == 0.0. Ignored if < 0.
    :param max_abs_diff: The maximum allowed absolute difference `abs(o1 - o2)`. Ignored if < 0.
    :return: Whether the two objects are almost equal.
    """
    if type(o1) != type(o2):
        return False

    composite_type_passed = False

    if hasattr(o1, '__slots__'):
        if len(o1.__slots__) != len(o2.__slots__):
            return False
        if any(not are_almost_equal(getattr(o1, s1), getattr(o2, s2),
                                    max_abs_ratio_diff, max_abs_diff)
            for s1, s2 in zip(sorted(o1.__slots__), sorted(o2.__slots__))):
            return False
        else:
            composite_type_passed = True

    if hasattr(o1, '__dict__'):
        if len(o1.__dict__) != len(o2.__dict__):
            return False
        if any(not are_almost_equal(k1, k2, max_abs_ratio_diff, max_abs_diff)
            or not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
            for ((k1, v1), (k2, v2))
            in zip(sorted(o1.__dict__.items()), sorted(o2.__dict__.items()))
            if not k1.startswith('__')):  # avoid infinite loops
            return False
        else:
            composite_type_passed = True

    if isinstance(o1, dict):
        if len(o1) != len(o2):
            return False
        if any(not are_almost_equal(k1, k2, max_abs_ratio_diff, max_abs_diff)
            or not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
            for ((k1, v1), (k2, v2)) in zip(sorted(o1.items()), sorted(o2.items()))):
            return False

    elif any(issubclass(o1.__class__, c) for c in (list, tuple, set)):
        if len(o1) != len(o2):
            return False
        if any(not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
            for v1, v2 in zip(o1, o2)):
            return False

    elif isinstance(o1, float):
        if o1 == o2:
            return True
        else:
            if max_abs_ratio_diff > 0:  # if max_abs_ratio_diff < 0, max_abs_ratio_diff is ignored
                if o2 != 0:
                    if abs(1.0 - (o1 / o2)) > max_abs_ratio_diff:
                        return False
                else:  # if both == 0, we already returned True
                    if abs(1.0 - (o2 / o1)) > max_abs_ratio_diff:
                        return False
            if 0 < max_abs_diff < abs(o1 - o2):  # if max_abs_diff < 0, max_abs_diff is ignored
                return False
            return True

    else:
        if not composite_type_passed:
            return o1 == o2

    return True


class EqualityTest(unittest.TestCase):

    def test_floats(self) -> None:
        o1 = ('hi', 3, 3.4)
        o2 = ('hi', 3, 3.400001)
        self.assertTrue(are_almost_equal(o1, o2, 0.0001, 0.0001))
        self.assertFalse(are_almost_equal(o1, o2, 0.00000001, 0.00000001))

    def test_ratio_only(self):
        o1 = ['hey', 10000, 123.12]
        o2 = ['hey', 10000, 123.80]
        self.assertTrue(are_almost_equal(o1, o2, 0.01, -1))
        self.assertFalse(are_almost_equal(o1, o2, 0.001, -1))

    def test_diff_only(self):
        o1 = ['hey', 10000, 1234567890.12]
        o2 = ['hey', 10000, 1234567890.80]
        self.assertTrue(are_almost_equal(o1, o2, -1, 1))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.1))

    def test_both_ignored(self):
        o1 = ['hey', 10000, 1234567890.12]
        o2 = ['hey', 10000, 0.80]
        o3 = ['hi', 10000, 0.80]
        self.assertTrue(are_almost_equal(o1, o2, -1, -1))
        self.assertFalse(are_almost_equal(o1, o3, -1, -1))

    def test_different_lengths(self):
        o1 = ['hey', 1234567890.12, 10000]
        o2 = ['hey', 1234567890.80]
        self.assertFalse(are_almost_equal(o1, o2, 1, 1))

    def test_classes(self):
        class A:
            d = 12.3

            def __init__(self, a, b, c):
                self.a = a
                self.b = b
                self.c = c

        o1 = A(2.34, 'str', {1: 'hey', 345.23: [123, 'hi', 890.12]})
        o2 = A(2.34, 'str', {1: 'hey', 345.231: [123, 'hi', 890.121]})
        self.assertTrue(are_almost_equal(o1, o2, 0.1, 0.1))
        self.assertFalse(are_almost_equal(o1, o2, 0.0001, 0.0001))

        o2.hello = 'hello'
        self.assertFalse(are_almost_equal(o1, o2, -1, -1))

    def test_namedtuples(self):
        B = namedtuple('B', ['x', 'y'])
        o1 = B(3.3, 4.4)
        o2 = B(3.4, 4.5)
        self.assertTrue(are_almost_equal(o1, o2, 0.2, 0.2))
        self.assertFalse(are_almost_equal(o1, o2, 0.001, 0.001))

    def test_classes_with_slots(self):
        class C(object):
            __slots__ = ['a', 'b']

            def __init__(self, a, b):
                self.a = a
                self.b = b

        o1 = C(3.3, 4.4)
        o2 = C(3.4, 4.5)
        self.assertTrue(are_almost_equal(o1, o2, 0.3, 0.3))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.01))

    def test_dataclasses(self):
        @dataclass
        class D:
            s: str
            i: int
            f: float

        @dataclass
        class E:
            f2: float
            f4: str
            d: D

        o1 = E(12.3, 'hi', D('hello', 34, 20.01))
        o2 = E(12.1, 'hi', D('hello', 34, 20.0))
        self.assertTrue(are_almost_equal(o1, o2, -1, 0.4))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.001))

        o3 = E(12.1, 'hi', D('ciao', 34, 20.0))
        self.assertFalse(are_almost_equal(o2, o3, -1, -1))

    def test_ordereddict(self):
        o1 = OrderedDict({1: 'hey', 345.23: [123, 'hi', 890.12]})
        o2 = OrderedDict({1: 'hey', 345.23: [123, 'hi', 890.0]})
        self.assertTrue(are_almost_equal(o1, o2, 0.01, -1))
        self.assertFalse(are_almost_equal(o1, o2, 0.0001, -1))

0

我仍然会使用self.assertEqual()它,当狗屎撞到风扇时,它仍能提供最丰富的信息。您可以通过四舍五入来做到这一点,例如。

self.assertEqual(round_tuple((13.949999999999999, 1.121212), 2), (13.95, 1.12))

这里round_tuple

def round_tuple(t: tuple, ndigits: int) -> tuple:
    return tuple(round(e, ndigits=ndigits) for e in t)

def round_list(l: list, ndigits: int) -> list:
    return [round(e, ndigits=ndigits) for e in l]

根据python文档(请参阅https://stackoverflow.com/a/41407651/1031191),您可以绕开13.94999999之类的舍入问题,因为 13.94999999 == 13.95is True


-1

一种替代方法是将数据转换为可比较的形式,例如,将每个浮点数转换为具有固定精度的字符串。

def comparable(data):
    """Converts `data` to a comparable structure by converting any floats to a string with fixed precision."""
    if isinstance(data, (int, str)):
        return data
    if isinstance(data, float):
        return '{:.4f}'.format(data)
    if isinstance(data, list):
        return [comparable(el) for el in data]
    if isinstance(data, tuple):
        return tuple([comparable(el) for el in data])
    if isinstance(data, dict):
        return {k: comparable(v) for k, v in data.items()}

那么你就可以:

self.assertEquals(comparable(value1), comparable(value2))
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.