Mocking async repository calls

Mocking is a great technique for isolating the code that you want to test. By mocking the interfaces a code section depends on, you isolate that section so that only its own behavior is being tested.

In my example I have a ProfileService (BLL) that takes an IProfileRepository interface (DAL) as a constructor parameter. In the real application the dependency resolver would inject a repository for this interface. But in my unit testing example I want to mock the repository so it’s completely disconnected and the ProfileService is isolated for testing. Usually that would be accomplished with code like this:

[TestMethod]
public async Task ValuesAreCorrectAndReflectedInReturnedProfile()
{
  // Arrange
  string userId = Guid.NewGuid().ToString();
  string userEmail = "john.doe@outlook.com";
  string userName = "John Doe";
             
  List<Profile> profiles = new List<Profile>();

  var profileRepositoryMock = new Mock<IProfileRepository>();
  profileRepositoryMock
         .Setup(x => x.GetAll())
         .Returns(profiles.AsQueryable());

  IProfileService service = new ProfileService(profileRepositoryMock.Object);

  // Act
  Profile profile = await service
         .CreateNewProfileAsync(userId, userEmail, userName);

  // Assert
  Assert.IsNotNull(profile);
  Assert.AreEqual(userId, profile.UserId);
  Assert.AreEqual(userEmail, profile.UserEmail);
  Assert.AreEqual(userName, profile.UserName);
}

However, the test above will fail. The real repository returns an async-enabled IQueryable that the service layer builds its queries against:

public IQueryable<Profile> GetAll()
{
  return dbContext.Set<Profile>();
}

This is done to make use of the asynchronous features available in EF. It doesn’t work with the way I’ve mocked the repository above, by just calling .AsQueryable() on a list. An example of where it fails can be seen here:

public async Task<Profile> FindProfileAsync(string userId)
{
  return await ProfileRepository
                      .GetAll()
                      .Where(p => p.UserId == userId)
                      .FirstOrDefaultAsync();
}

The code above is a section of the ProfileService. It fails because GetAll() is mocked as an IQueryable that doesn’t support async calls. When we try to use it asynchronously, an exception is thrown:

System.InvalidOperationException: The provider for the source IQueryable doesn’t implement IDbAsyncQueryProvider. Only providers that implement IDbAsyncQueryProvider can be used for Entity Framework asynchronous operations.

As seen from the exception, this kind of mocking isn’t enough when using async calls to the repository.
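To make the exception message concrete, here is a minimal sketch of the kind of check it describes. The AsyncSupportCheck class and its methods are hypothetical names of mine, not part of EF, and the snippet assumes a reference to the EntityFramework 6 package. It only shows that the provider behind a list's .AsQueryable() does not implement IDbAsyncQueryProvider.

```csharp
using System.Collections.Generic;
using System.Data.Entity.Infrastructure; // IDbAsyncQueryProvider (EntityFramework 6 package)
using System.Linq;

// Hypothetical helper, for illustration only.
public static class AsyncSupportCheck
{
    // True when the query's provider implements the interface
    // named in the exception message.
    public static bool SupportsEfAsync(IQueryable query)
    {
        return query.Provider is IDbAsyncQueryProvider;
    }

    // A list-backed query uses the plain LINQ to Objects provider,
    // which knows nothing about EF, so this returns false.
    public static bool ListBackedQuerySupportsAsync()
    {
        IQueryable<string> query = new List<string> { "a", "b" }.AsQueryable();
        return SupportsEfAsync(query);
    }
}
```

So the mock has to hand back something whose Provider implements IDbAsyncQueryProvider before calls such as FirstOrDefaultAsync() can run against it.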

One way of solving this is to skip the mocking and instead create a local test database. You can then intercept the EF startup and reconfigure it to run a complete database migration against the test database for each test run. This makes the tests a bit slow, but you get a fully testable database connection for integration tests. I won’t go into that setup now. Maybe in a later post.

What I’m after in this example is instead to keep the mocking but still be able to use a local collection and async calls.

One solution is to implement a fake DbSet that can handle async but stores the values in a local collection. The following implementation by Tim Schmidt is one that I’ve used a couple of times.

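using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data.Entity;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;

public class FakeDbSet<T> : IDbSet<T>, IDbAsyncEnumerable<T> where T : class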
{
    readonly ObservableCollection<T> _data;
    readonly IQueryable _queryable;

    public FakeDbSet()
    {
        _data = new ObservableCollection<T>();
        _queryable = _data.AsQueryable();
    }

    public virtual T Find(params object[] keyValues)
    {
        throw new NotImplementedException("Derive from FakeDbSet<T> and override Find");
    }

    public Task<T> FindAsync(CancellationToken cancellationToken, params object[] keyValues)
    {
        throw new NotImplementedException();
    }

    public T Add(T item)
    {
        _data.Add(item);
        return item;
    }

    public T Remove(T item)
    {
        _data.Remove(item);
        return item;
    }

    public T Attach(T item)
    {
        _data.Add(item);
        return item;
    }

    public T Detach(T item)
    {
        _data.Remove(item);
        return item;
    }

    public T Create()
    {
        return Activator.CreateInstance<T>();
    }

    public TDerivedEntity Create<TDerivedEntity>() where TDerivedEntity : class, T
    {
        return Activator.CreateInstance<TDerivedEntity>();
    }

    public ObservableCollection<T> Local
    {
        get { return _data; }
    }

    Type IQueryable.ElementType
    {
        get { return _queryable.ElementType; }
    }

    System.Linq.Expressions.Expression IQueryable.Expression
    {
        get { return _queryable.Expression; }
    }

    IQueryProvider IQueryable.Provider
    {
        get { return new AsyncQueryProviderWrapper<T>(_queryable.Provider); }
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return _data.GetEnumerator();
    }

    IEnumerator<T> IEnumerable<T>.GetEnumerator()
    {
        return _data.GetEnumerator();
    }

    public int Count
    {
        get { return _data.Count; }
    }

    public IDbAsyncEnumerator<T> GetAsyncEnumerator()
    {
        return new AsyncEnumeratorWrapper<T>(_data.GetEnumerator());
    }

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        return GetAsyncEnumerator();
    }
}

internal class AsyncQueryProviderWrapper<T> : IDbAsyncQueryProvider
{
    private readonly IQueryProvider _inner;

    internal AsyncQueryProviderWrapper(IQueryProvider inner)
    {
        _inner = inner;
    }

    public IQueryable CreateQuery(Expression expression)
    {
        return new AsyncEnumerableQuery<T>(expression);
    }

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        return new AsyncEnumerableQuery<TElement>(expression);
    }

    public object Execute(Expression expression)
    {
        return _inner.Execute(expression);
    }

    public TResult Execute<TResult>(Expression expression)
    {
        return _inner.Execute<TResult>(expression);
    }

    public Task<object> ExecuteAsync(Expression expression, CancellationToken cancellationToken)
    {
        return Task.FromResult(Execute(expression));
    }

    public Task<TResult> ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
    {
        return Task.FromResult(Execute<TResult>(expression));
    }
}

public class AsyncEnumerableQuery<T> : EnumerableQuery<T>, IDbAsyncEnumerable<T>, IQueryable
{
    public AsyncEnumerableQuery(IEnumerable<T> enumerable) : base(enumerable)
    {
    }

    public AsyncEnumerableQuery(Expression expression) : base(expression)
    {
    }

    public IDbAsyncEnumerator<T> GetAsyncEnumerator()
    {
        return new AsyncEnumeratorWrapper<T>(this.AsEnumerable().GetEnumerator());
    }

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        return GetAsyncEnumerator();
    }

    IQueryProvider IQueryable.Provider
    {
        get { return new AsyncQueryProviderWrapper<T>(this); }
    }
}

public class AsyncEnumeratorWrapper<T> : IDbAsyncEnumerator<T>
{
    private readonly IEnumerator<T> _inner;

    public AsyncEnumeratorWrapper(IEnumerator<T> inner)
    {
        _inner = inner;
    }

    public void Dispose()
    {
        _inner.Dispose();
    }

    public Task<bool> MoveNextAsync(CancellationToken cancellationToken)
    {
        return Task.FromResult(_inner.MoveNext());
    }

    public T Current
    {
        get { return _inner.Current; }
    }

    object IDbAsyncEnumerator.Current
    {
        get { return Current; }
    }
}
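Note that Find in the listing above throws until you derive from FakeDbSet<T> and override it. A minimal sketch of such a subclass could look like the following. FakeProfileSet is a hypothetical name, and I'm assuming here that the string UserId property is what you look profiles up by; your real key may be different. It relies on the FakeDbSet<T> class above and on the Profile entity from this post.

```csharp
using System.Linq;

// Hypothetical subclass; assumes Profile.UserId (a string) is used as the key.
public class FakeProfileSet : FakeDbSet<Profile>
{
    public override Profile Find(params object[] keyValues)
    {
        // Expect exactly one key value and look it up in the local collection.
        var userId = keyValues.Single() as string;
        return Local.SingleOrDefault(p => p.UserId == userId);
    }
}
```

Code that calls Find on the set will then get the matching profile back, or null when nothing matches, instead of a NotImplementedException.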

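To use this FakeDbSet in your unit tests, do it like this. Compared to the first version, the List<Profile> is replaced by a FakeDbSet<Profile>, and the mock returns the dbSet directly instead of profiles.AsQueryable().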

[TestMethod]
public async Task ValuesAreCorrectAndReflectedInReturnedProfile()
{
  // Arrange
  string userId = Guid.NewGuid().ToString();
  string userEmail = "john.doe@outlook.com";
  string userName = "John Doe";

  var dbSet = new FakeDbSet<Profile>();
  //var existingProfile = Profile
  //     .Create(userId, userEmail, userName);
  //dbSet.Add(existingProfile);

  var profileRepositoryMock = new Mock<IProfileRepository>();
  profileRepositoryMock
         .Setup(x => x.GetAll())
         .Returns(dbSet);

  IProfileService service = new ProfileService(profileRepositoryMock.Object);

  // Act
  Profile profile = await service
         .CreateNewProfileAsync(userId, userEmail, userName);

  // Assert
  Assert.IsNotNull(profile);
  Assert.AreEqual(userId, profile.UserId);
  Assert.AreEqual(userEmail, profile.UserEmail);
  Assert.AreEqual(userName, profile.UserName);
}

Now the async code testing works fine.
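The commented-out lines in the test hint at seeding the fake set with an existing profile. As a sketch of how that could be used, here is a second test that exercises the FindProfileAsync method shown earlier. The test class and method names are hypothetical, and I'm assuming that Profile.Create(userId, userEmail, userName) exists with the signature from the commented-out code and sets UserId on the returned profile.

```csharp
using System;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Moq;

[TestClass]
public class ProfileServiceFindTests
{
    [TestMethod]
    public async Task ExistingProfileIsFoundByUserId()
    {
        // Arrange: seed the fake set with one profile.
        string userId = Guid.NewGuid().ToString();
        var dbSet = new FakeDbSet<Profile>();
        // Assumption: Profile.Create is the factory from the commented-out lines.
        dbSet.Add(Profile.Create(userId, "john.doe@outlook.com", "John Doe"));

        var profileRepositoryMock = new Mock<IProfileRepository>();
        profileRepositoryMock
            .Setup(x => x.GetAll())
            .Returns(dbSet);

        var service = new ProfileService(profileRepositoryMock.Object);

        // Act: this runs Where(...).FirstOrDefaultAsync() against the fake set.
        Profile found = await service.FindProfileAsync(userId);

        // Assert
        Assert.IsNotNull(found);
        Assert.AreEqual(userId, found.UserId);
    }
}
```

Because the query goes through the async provider wrapper, FirstOrDefaultAsync() completes against the in-memory collection without touching a database.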

Don’t forget to change the test method’s return type from the default void to async Task. If you don’t change it, you’ll get errors when running the tests.

Creds to Tim Schmidt for creating the FakeDbSet that I’ve used in the solution above.

Happy asynchronous testing!